ops/mp: implement aimdo

Implement a model patcher and caster for aimdo.

A new ModelPatcher implementation which backs onto comfy-aimdo to implement varying model load levels that can be adjusted during model use. The patcher defers all load processes to lazily load the model during use (e.g. the first step of a ksampler) and automatically negotiates a load level during the inference to maximize VRAM usage without OOMing. If inference requires more VRAM than is available weights are offloaded to make space before the OOM happens.

As for loading the weight onto the GPU, that happens via comfy_cast_weights which is now used in all cases. cast_bias_weight checks whether the VBAR assigned to the model has space for the weight (based on the same load priority semantics as the original ModelPatcher). If it does, the VRAM as returned by the Aimdo allocator is used as the parameter GPU side. The caster is responsible for populating the weight data. This is done using the usual offload_stream (which mean we now have asynchronous load overlapping first use compute).

Pinning works a little differently. When a weight is detected during load as unable to fit, a pin is allocated at the time of casting and the weight as used by the layer is DMAd back to the the pin using the GPU DMA TX engine, also using the asynchronous offload streams. This means you get to pin the Lora modified and requantized weights which can be a major speedup for offload+quantize+lora use cases, This works around the JIT Lora + FP8 exclusion and brings FP8MM to heavy offloading users (who probably really need it with more modest GPUs). There is a performance risk in that a CPU+RAM patch has been replace with a GPU+RAM patch but my initial performance results look good. Most users as likely to have a GPU that outruns their CPU in these woods.

Some common code is written to consolidate a layers tensors for aimdo mapping, pinning, and DMA transfers. interpret_gathered_like() allows unpacking a raw buffer as a set of tensors. This is used consistently to bundle and pack weights, quantization metadata (QuantizedTensor bits) and biases into one payload for DMA in the load process reducing Cuda overhead a little. Some Quantization metadata was missing async offload is some cases which is now added. This also pins quantization metadata and consolidates the number of cuda_host_register calls (which can be expensive).
This commit is contained in:
Rattus 2026-01-13 15:36:09 +10:00
parent f74661edc6
commit f75765721d
3 changed files with 400 additions and 12 deletions

View File

@ -27,8 +27,12 @@ import weakref
import gc
import os
from contextlib import nullcontext
import comfy.utils
import comfy.quant_ops
import comfy_aimdo.torch
import comfy_aimdo.model_vbar
class VRAMState(Enum):
DISABLED = 0 #No vram present: no need to move models to vram
NO_VRAM = 1 #Very low vram: enable all the options to save vram
@ -1157,7 +1161,59 @@ def sync_stream(device, stream):
return
current_stream(device).wait_stream(stream)
def cast_to_gathered(tensors, r, non_blocking=False, stream=None):
wf_context = nullcontext()
if stream is not None:
wf_context = stream
if hasattr(wf_context, "as_context"):
wf_context = wf_context.as_context(stream)
dest_views = comfy.memory_management.interpret_gathered_like(tensors, r)
with wf_context:
for tensor in tensors:
dest_view = dest_views.pop(0)
if tensor is None:
continue
dest_view.copy_(tensor, non_blocking=non_blocking)
def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None, r=None):
if hasattr(weight, "_v"):
#Unexpected usage patterns. There is no reason these don't work but they
#have no testing and no callers do this.
assert r is None
assert stream is None
r = torch.empty_like(weight, dtype=dtype, device=device)
signature = comfy_aimdo.model_vbar.vbar_fault(weight._v)
if signature is not None:
raw_tensor = comfy_aimdo.torch.aimdo_to_tensor(weight._v, device)
v_tensor = comfy.memory_management.interpret_gathered_like([weight], raw_tensor)[0]
if comfy_aimdo.model_vbar.vbar_signature_compare(signature, weight._v_signature):
#always take a deep copy even if _v is good, as we have no reasonable point to unpin
#a non comfy weight
r.copy_(v_tensor)
comfy_aimdo.model_vbar.vbar_unpin(weight._v)
return r
r.copy_(weight, non_blocking=non_blocking)
#FIXME: remove hooks before PR
if hasattr(weight, "comfy_hook"):
dtype = r.dtype
r = weight.comfy_hook(r)
if r.dtype != dtype:
r = comfy.float.stochastic_rounding(r, dtype, seed=comfy.utils.string_to_seed(weight.seed_key))
if signature is not None:
v_tensor.copy_(r)
comfy_aimdo.model_vbar.vbar_unpin(weight._v)
return r
if device is None or weight.device == device:
if not copy:
if dtype is None or weight.dtype == dtype:

View File

@ -39,6 +39,7 @@ from comfy.comfy_types import UnetWrapperFunction
from comfy.quant_ops import QuantizedTensor
from comfy.patcher_extension import CallbacksMP, PatcherInjection, WrappersMP
import comfy_aimdo.model_vbar
def set_model_options_patch_replace(model_options, patch, name, block_name, number, transformer_index=None):
to = model_options["transformer_options"].copy()
@ -1397,3 +1398,216 @@ class ModelPatcher:
self.unpin_all_weights()
self.detach(unpatch_all=False)
class ModelPatcherDynamic(ModelPatcher):
def __new__(cls, model, load_device, offload_device, size=0, weight_inplace_update=False):
if comfy.model_management.is_device_cpu(load_device):
#reroute to default MP for CPUs
return ModelPatcher(model, load_device, offload_device, size, weight_inplace_update)
return super().__new__(cls)
def __init__(self, model, load_device, offload_device, size=0, weight_inplace_update=False):
super().__init__(model, load_device, offload_device, size, weight_inplace_update)
#this is now way more dynamic and we dont support the same base model for both Dynamic
#and non-dynamic patchers.
if hasattr(self.model, "model_loaded_weight_memory"):
del self.model.model_loaded_weight_memory
if not hasattr(self.model, "dynamic_vbars"):
self.model.dynamic_vbars = {}
assert load_device is not None
def is_dynamic(self):
return True
def _vbar_get(self, create=False):
if self.load_device == torch.device("cpu"):
return None
vbar = self.model.dynamic_vbars.get(self.load_device, None)
if create and vbar is None:
vbar = comfy_aimdo.model_vbar.ModelVBAR(self.model_size() * 1.2, self.load_device.index)
self.model.dynamic_vbars[self.load_device] = vbar
return vbar
def loaded_size(self):
vbar = self._vbar_get()
if vbar is None:
return 0
return vbar.loaded_size()
def get_free_memory(self, device):
#NOTE: on high condition / batch counts, estimate should have already vacated
#all non-dynamic models so this is safe even if its not 100% true that this
#would all be avaiable for inference use.
return comfy.model_management.get_total_memory(device) - self.model_size()
#Pinning is deferred to ops time. Assert against this API to avoid pin leaks.
def pin_weight_to_device(self, key):
raise RuntimeError("pin_weight_to_device invalid for dymamic weight loading")
def unpin_weight(self, key):
raise RuntimeError("unpin_weight invalid for dymamic weight loading")
def unpin_all_weights(self):
pass
def memory_required(self, input_shape):
#Pad this significantly. We are trying to get away from precise estimates. This
#estimate is only used when using the ModelPatcherDynamic after ModelPatcher. If you
#use all ModelPatcherDynamic this is ignored and its all done dynamically.
return super().memory_required(input_shape=input_shape) * 1.3 + (1024 ** 3)
def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False, dirty=False):
#Force patching doesn't make sense in Dynamic loading, as you dont know what does and
#doesn't need to be forced at this stage. The only thing you could do would be patch
#it all on CPU which consumes huge RAM.
assert not force_patch_weights
#Full load doesn't make sense as we dont actually have any loader capability here and
#now.
assert not full_load;
assert device_to == self.load_device
num_patches = 0
allocated_size = 0
with self.use_ejected():
self.unpatch_hooks()
vbar = self._vbar_get(create=True)
if vbar is not None:
vbar.prioritize()
#We have way more tools for acceleration on comfy weight offloading, so always
#prioritize the non-comfy weights (note the order reverse).
loading = self._load_list(prio_comfy_cast_weights=True)
loading.sort(reverse=True)
for x in loading:
_, _, _, n, m, params = x
def set_dirty(item, dirty):
if dirty or not hasattr(item, "_v_signature"):
item._v_signature = None
if dirty:
comfy.pinned_memory.unpin_memory(item)
def setup_param(self, m, n, param_key):
nonlocal num_patches
key = "{}.{}".format(n, param_key)
weight_function = []
weight, _, _ = get_key_weight(self.model, key)
if key in self.patches:
setattr(m, param_key + "_lowvram_function", LowVramPatch(key, self.patches))
num_patches += 1
else:
setattr(m, param_key + "_lowvram_function", None)
if key in self.weight_wrapper_patches:
weight_function.extend(self.weight_wrapper_patches[key])
setattr(m, param_key + "_function", weight_function)
return comfy.memory_management.vram_aligned_size(weight)
if hasattr(m, "comfy_cast_weights"):
m.comfy_cast_weights = True
m.pin_failed = False
m.seed_key = n
set_dirty(m, dirty)
v_weight_size = 0
v_weight_size += setup_param(self, m, n, "weight")
v_weight_size += setup_param(self, m, n, "bias")
if vbar is not None and not hasattr(m, "_v"):
m._v = vbar.alloc(v_weight_size)
allocated_size += v_weight_size
else:
for param in params:
key = "{}.{}".format(n, param)
weight, _, _ = get_key_weight(self.model, key)
weight.seed_key = key
set_dirty(weight, dirty)
weight_size = weight.numel() * weight.element_size()
if vbar is not None and not hasattr(weight, "_v"):
weight._v = vbar.alloc(weight_size)
allocated_size += weight_size
logging.info(f"Model {self.model.__class__.__name__} prepared for dynamic VRAM loading. {allocated_size // (1024 ** 2)}MB Staged. {num_patches} patches attached.")
self.model.device = device_to
self.model.current_weight_patches_uuid = self.patches_uuid
for callback in self.get_all_callbacks(CallbacksMP.ON_LOAD):
#These are all super dangerous. Who knows what the custom nodes actually do here...
callback(self, device_to, lowvram_model_memory, force_patch_weights, full_load)
self.apply_hooks(self.forced_hooks, force_apply=True)
def partially_unload(self, device_to, memory_to_free=0, force_patch_weights=False):
assert not force_patch_weights #See above
assert self.load_device != torch.device("cpu")
vbar = self._vbar_get()
return 0 if vbar is None else vbar.free_memory(memory_to_free)
def partially_unload_ram(self, ram_to_unload):
loading = self._load_list(prio_comfy_cast_weights=True)
for x in loading:
_, _, _, _, m, _ = x
ram_to_unload -= comfy.pinned_memory.unpin_memory(m)
if ram_to_unload <= 0:
return
def patch_model(self, device_to=None, lowvram_model_memory=0, load_weights=True, force_patch_weights=False):
#This isn't used by the core at all and can only be to load a model out of
#the control of proper model_managment. If you are a custom node author reading
#this, the correct pattern is to call load_models_gpu() to get a proper
#managed load of your model.
assert not load_weights
return super().patch_model(load_weights=load_weights, force_patch_weights=force_patch_weights)
def unpatch_model(self, device_to=None, unpatch_weights=True):
super().unpatch_model(device_to=None, unpatch_weights=False)
if unpatch_weights:
self.partially_unload_ram(1e32)
self.partially_unload(None)
def partially_load(self, device_to, extra_memory=0, force_patch_weights=False):
assert not force_patch_weights #See above
with self.use_ejected(skip_and_inject_on_exit_only=True):
dirty = self.model.current_weight_patches_uuid is not None and (self.model.current_weight_patches_uuid != self.patches_uuid)
self.unpatch_model(self.offload_device, unpatch_weights=False)
self.patch_model(load_weights=False)
try:
self.load(device_to, dirty=dirty)
except Exception as e:
self.detach()
raise e
#ModelPatcher::partially_load returns a number on what got loaded but
#nothing in core uses this and we have no data in the Dynamic world. Hit
#the custom node devs with a None rather than a 0 that would mislead any
#logic they might have.
return None
def patch_cached_hook_weights(self, cached_weights: dict, key: str, memory_counter: MemoryCounter):
assert False #Should be unreachable - we dont ever cache in the new implementation
def patch_hook_weight_to_device(self, hooks: comfy.hooks.HookGroup, combined_patches: dict, key: str, original_weights: dict, memory_counter: MemoryCounter):
if key not in combined_patches:
return
raise RuntimeError("Hooks not implemented in ModelPatcherDynamic. Please remove --fast arguments form ComfyUI startup")
def unpatch_hooks(self, whitelist_keys_set: set[str]=None) -> None:
pass
CoreModelPatcher = ModelPatcher

View File

@ -24,6 +24,11 @@ import comfy.float
import comfy.rmsnorm
import json
import comfy.memory_management
import comfy.pinned_memory
import comfy.utils
import comfy_aimdo.model_vbar
import comfy_aimdo.torch
def run_every_op():
if torch.compiler.is_compiling():
@ -73,7 +78,108 @@ def cast_to_input(weight, input, non_blocking=False, copy=True):
return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy)
def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, offloadable=False):
def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype):
offload_stream = None
xfer_dest = None
signature = comfy_aimdo.model_vbar.vbar_fault(s._v)
if signature is not None:
xfer_dest = comfy_aimdo.torch.aimdo_to_tensor(s._v, device)
resident = comfy_aimdo.model_vbar.vbar_signature_compare(signature, s._v_signature)
if not resident:
xfer_source = [ s.weight, s.bias ]
pin = comfy.pinned_memory.get_pin(s)
if pin is not None:
xfer_source = [ pin ]
resident = True #If pinned data exists, it always has LowVram already applied
dest_size = comfy.memory_management.vram_aligned_size(xfer_source)
offload_stream = comfy.model_management.get_offload_stream(device)
if xfer_dest is None and offload_stream is not None:
xfer_dest = comfy.model_management.get_cast_buffer(offload_stream, device, dest_size, s)
if xfer_dest is None:
offload_stream = comfy.model_management.get_offload_stream(device)
xfer_dest = comfy.model_management.get_cast_buffer(offload_stream, device, dest_size, s)
if xfer_dest is None:
xfer_dest = torch.empty((dest_size,), dtype=torch.uint8, device=device)
offload_stream = None
#send it over
comfy.model_management.cast_to_gathered(xfer_source, xfer_dest, non_blocking=non_blocking, stream=offload_stream)
comfy.model_management.sync_stream(device, offload_stream)
pin = None
if signature is not None:
#If we are able to increase our load level (e.g. user reduces resolution or batch number)
#reclaim the pin previously used for offload.
comfy.pinned_memory.unpin_memory(s)
elif not resident:
#prepare a new pin
assert comfy.pinned_memory.get_pin(s) is None
comfy.pinned_memory.pin_memory(s)
pin = comfy.pinned_memory.get_pin(s)
params = comfy.memory_management.interpret_gathered_like([s.weight, s.bias], xfer_dest)
weight = params[0]
bias = params[1]
def post_cast(s, param_key, x, dtype, resident, update_weight):
lowvram_fn = getattr(s, param_key + "_lowvram_function", None)
hook_fn = getattr(s, param_key + "_hooks", None)
fns = getattr(s, param_key + "_function", [])
orig = x
q_layout = None
def to_dequant(tensor, dtype):
tensor = tensor.to(dtype=dtype)
if isinstance(tensor, QuantizedTensor):
tensor = tensor.dequantize()
return tensor
if orig.dtype != dtype or len(fns) > 0:
x = to_dequant(x, dtype)
if not resident and lowvram_fn is not None:
x = to_dequant(x, dtype if compute_dtype is None else compute_dtype)
#FIXME: this is not accurate, we need to be sensitive to the compute dtype
x = lowvram_fn(x)
if (isinstance(orig, QuantizedTensor) and
(orig.dtype == dtype and len(fns) == 0 or update_weight)):
seed = comfy.utils.string_to_seed(s.seed_key)
y = QuantizedTensor.from_float(x, s.layout_type, scale="recalculate", stochastic_rounding=seed)
if orig.dtype == dtype and len(fns) == 0:
#The layer actually wants our freshly saved QT
x = y
else:
y = x
if update_weight:
orig.copy_(y)
for f in fns:
x = f(x)
return x
update_weight = signature is not None or pin is not None
weight = post_cast(s, "weight", weight, dtype, resident, update_weight)
if s.bias is not None:
bias = post_cast(s, "bias", bias, bias_dtype, resident, update_weight)
s._v_signature=signature
if pin is not None:
xfer_dest = comfy.memory_management.interpret_gathered_like([ pin ], xfer_dest)[0]
if offload_stream is not None:
#FIXME: if post cast didnt do anything this sync is un-needed
offload_stream.wait_stream(comfy.model_management.current_stream(device))
comfy.model_management.cast_to(xfer_dest, device=pin.device, non_blocking=non_blocking, stream=offload_stream, r=pin)
#FIXME: weird offload return protocol
return weight, bias, (offload_stream, device if signature is not None else None, None)
def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, offloadable=False, compute_dtype=None):
# NOTE: offloadable=False is a a legacy and if you are a custom node author reading this please pass
# offloadable=True and call uncast_bias_weight() after your last usage of the weight/bias. This
# will add async-offload support to your cast and improve performance.
@ -88,6 +194,11 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
if device is None:
device = input.device
non_blocking = comfy.model_management.device_supports_non_blocking(device)
if hasattr(s, "_v"):
return cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype)
if offloadable and (device != s.weight.device or
(s.bias is not None and device != s.bias.device)):
offload_stream = comfy.model_management.get_offload_stream(device)
@ -108,8 +219,6 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
weight = params[0]
bias = params[1]
non_blocking = comfy.model_management.device_supports_non_blocking(device)
weight_has_function = len(s.weight_function) > 0
bias_has_function = len(s.bias_function) > 0
@ -146,14 +255,20 @@ def uncast_bias_weight(s, weight, bias, offload_stream):
if offload_stream is None:
return
os, weight_a, bias_a = offload_stream
device=None
#FIXME: This is not good RTTI
if not isinstance(weight_a, torch.Tensor):
comfy_aimdo.model_vbar.vbar_unpin(s._v)
device = weight_a
if os is None:
return
if weight_a is not None:
device = weight_a.device
else:
if bias_a is None:
return
device = bias_a.device
if device is None:
if weight_a is not None:
device = weight_a.device
else:
if bias_a is None:
return
device = bias_a.device
os.wait_stream(comfy.model_management.current_stream(device))
@ -668,8 +783,8 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
def _forward(self, input, weight, bias):
return torch.nn.functional.linear(input, weight, bias)
def forward_comfy_cast_weights(self, input):
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
def forward_comfy_cast_weights(self, input, compute_dtype=None):
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True, compute_dtype=compute_dtype)
x = self._forward(input, weight, bias)
uncast_bias_weight(self, weight, bias, offload_stream)
return x
@ -679,6 +794,8 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
input_shape = input.shape
reshaped_3d = False
#If cast needs to apply lora, it should be done in the compute dtype
compute_dtype = input.dtype
if (getattr(self, 'layout_type', None) is not None and
not isinstance(input, QuantizedTensor) and not self._full_precision_mm and
@ -697,7 +814,8 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
scale = comfy.model_management.cast_to_device(scale, input.device, None)
input = QuantizedTensor.from_float(input_reshaped, self.layout_type, scale=scale)
output = self.forward_comfy_cast_weights(input)
output = self.forward_comfy_cast_weights(input, compute_dtype)
# Reshape output back to 3D if input was 3D
if reshaped_3d: