From f75765721d008a43581492582b4244efa827ee9d Mon Sep 17 00:00:00 2001 From: Rattus Date: Tue, 13 Jan 2026 15:36:09 +1000 Subject: [PATCH] ops/mp: implement aimdo Implement a model patcher and caster for aimdo. A new ModelPatcher implementation which backs onto comfy-aimdo to implement varying model load levels that can be adjusted during model use. The patcher defers all load processes to lazily load the model during use (e.g. the first step of a ksampler) and automatically negotiates a load level during the inference to maximize VRAM usage without OOMing. If inference requires more VRAM than is available weights are offloaded to make space before the OOM happens. As for loading the weight onto the GPU, that happens via comfy_cast_weights which is now used in all cases. cast_bias_weight checks whether the VBAR assigned to the model has space for the weight (based on the same load priority semantics as the original ModelPatcher). If it does, the VRAM as returned by the Aimdo allocator is used as the parameter GPU side. The caster is responsible for populating the weight data. This is done using the usual offload_stream (which mean we now have asynchronous load overlapping first use compute). Pinning works a little differently. When a weight is detected during load as unable to fit, a pin is allocated at the time of casting and the weight as used by the layer is DMAd back to the the pin using the GPU DMA TX engine, also using the asynchronous offload streams. This means you get to pin the Lora modified and requantized weights which can be a major speedup for offload+quantize+lora use cases, This works around the JIT Lora + FP8 exclusion and brings FP8MM to heavy offloading users (who probably really need it with more modest GPUs). There is a performance risk in that a CPU+RAM patch has been replace with a GPU+RAM patch but my initial performance results look good. Most users as likely to have a GPU that outruns their CPU in these woods. Some common code is written to consolidate a layers tensors for aimdo mapping, pinning, and DMA transfers. interpret_gathered_like() allows unpacking a raw buffer as a set of tensors. This is used consistently to bundle and pack weights, quantization metadata (QuantizedTensor bits) and biases into one payload for DMA in the load process reducing Cuda overhead a little. Some Quantization metadata was missing async offload is some cases which is now added. This also pins quantization metadata and consolidates the number of cuda_host_register calls (which can be expensive). --- comfy/model_management.py | 56 ++++++++++ comfy/model_patcher.py | 214 ++++++++++++++++++++++++++++++++++++++ comfy/ops.py | 142 ++++++++++++++++++++++--- 3 files changed, 400 insertions(+), 12 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index b1bf3bd59..c5a22e04c 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -27,8 +27,12 @@ import weakref import gc import os from contextlib import nullcontext +import comfy.utils import comfy.quant_ops +import comfy_aimdo.torch +import comfy_aimdo.model_vbar + class VRAMState(Enum): DISABLED = 0 #No vram present: no need to move models to vram NO_VRAM = 1 #Very low vram: enable all the options to save vram @@ -1157,7 +1161,59 @@ def sync_stream(device, stream): return current_stream(device).wait_stream(stream) + +def cast_to_gathered(tensors, r, non_blocking=False, stream=None): + wf_context = nullcontext() + if stream is not None: + wf_context = stream + if hasattr(wf_context, "as_context"): + wf_context = wf_context.as_context(stream) + + dest_views = comfy.memory_management.interpret_gathered_like(tensors, r) + with wf_context: + for tensor in tensors: + dest_view = dest_views.pop(0) + if tensor is None: + continue + dest_view.copy_(tensor, non_blocking=non_blocking) + + def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None, r=None): + if hasattr(weight, "_v"): + #Unexpected usage patterns. There is no reason these don't work but they + #have no testing and no callers do this. + assert r is None + assert stream is None + + r = torch.empty_like(weight, dtype=dtype, device=device) + + signature = comfy_aimdo.model_vbar.vbar_fault(weight._v) + if signature is not None: + raw_tensor = comfy_aimdo.torch.aimdo_to_tensor(weight._v, device) + v_tensor = comfy.memory_management.interpret_gathered_like([weight], raw_tensor)[0] + + if comfy_aimdo.model_vbar.vbar_signature_compare(signature, weight._v_signature): + #always take a deep copy even if _v is good, as we have no reasonable point to unpin + #a non comfy weight + r.copy_(v_tensor) + comfy_aimdo.model_vbar.vbar_unpin(weight._v) + return r + + r.copy_(weight, non_blocking=non_blocking) + + #FIXME: remove hooks before PR + if hasattr(weight, "comfy_hook"): + dtype = r.dtype + r = weight.comfy_hook(r) + if r.dtype != dtype: + r = comfy.float.stochastic_rounding(r, dtype, seed=comfy.utils.string_to_seed(weight.seed_key)) + + if signature is not None: + v_tensor.copy_(r) + comfy_aimdo.model_vbar.vbar_unpin(weight._v) + + return r + if device is None or weight.device == device: if not copy: if dtype is None or weight.dtype == dtype: diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 0263f133d..77e7eec90 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -39,6 +39,7 @@ from comfy.comfy_types import UnetWrapperFunction from comfy.quant_ops import QuantizedTensor from comfy.patcher_extension import CallbacksMP, PatcherInjection, WrappersMP +import comfy_aimdo.model_vbar def set_model_options_patch_replace(model_options, patch, name, block_name, number, transformer_index=None): to = model_options["transformer_options"].copy() @@ -1397,3 +1398,216 @@ class ModelPatcher: self.unpin_all_weights() self.detach(unpatch_all=False) +class ModelPatcherDynamic(ModelPatcher): + + def __new__(cls, model, load_device, offload_device, size=0, weight_inplace_update=False): + if comfy.model_management.is_device_cpu(load_device): + #reroute to default MP for CPUs + return ModelPatcher(model, load_device, offload_device, size, weight_inplace_update) + return super().__new__(cls) + + def __init__(self, model, load_device, offload_device, size=0, weight_inplace_update=False): + super().__init__(model, load_device, offload_device, size, weight_inplace_update) + #this is now way more dynamic and we dont support the same base model for both Dynamic + #and non-dynamic patchers. + if hasattr(self.model, "model_loaded_weight_memory"): + del self.model.model_loaded_weight_memory + if not hasattr(self.model, "dynamic_vbars"): + self.model.dynamic_vbars = {} + assert load_device is not None + + def is_dynamic(self): + return True + + def _vbar_get(self, create=False): + if self.load_device == torch.device("cpu"): + return None + vbar = self.model.dynamic_vbars.get(self.load_device, None) + if create and vbar is None: + vbar = comfy_aimdo.model_vbar.ModelVBAR(self.model_size() * 1.2, self.load_device.index) + self.model.dynamic_vbars[self.load_device] = vbar + return vbar + + def loaded_size(self): + vbar = self._vbar_get() + if vbar is None: + return 0 + return vbar.loaded_size() + + def get_free_memory(self, device): + #NOTE: on high condition / batch counts, estimate should have already vacated + #all non-dynamic models so this is safe even if its not 100% true that this + #would all be avaiable for inference use. + return comfy.model_management.get_total_memory(device) - self.model_size() + + #Pinning is deferred to ops time. Assert against this API to avoid pin leaks. + + def pin_weight_to_device(self, key): + raise RuntimeError("pin_weight_to_device invalid for dymamic weight loading") + + def unpin_weight(self, key): + raise RuntimeError("unpin_weight invalid for dymamic weight loading") + + def unpin_all_weights(self): + pass + + def memory_required(self, input_shape): + #Pad this significantly. We are trying to get away from precise estimates. This + #estimate is only used when using the ModelPatcherDynamic after ModelPatcher. If you + #use all ModelPatcherDynamic this is ignored and its all done dynamically. + return super().memory_required(input_shape=input_shape) * 1.3 + (1024 ** 3) + + + def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False, dirty=False): + + #Force patching doesn't make sense in Dynamic loading, as you dont know what does and + #doesn't need to be forced at this stage. The only thing you could do would be patch + #it all on CPU which consumes huge RAM. + assert not force_patch_weights + + #Full load doesn't make sense as we dont actually have any loader capability here and + #now. + assert not full_load; + + assert device_to == self.load_device + + num_patches = 0 + allocated_size = 0 + + with self.use_ejected(): + self.unpatch_hooks() + + vbar = self._vbar_get(create=True) + if vbar is not None: + vbar.prioritize() + + #We have way more tools for acceleration on comfy weight offloading, so always + #prioritize the non-comfy weights (note the order reverse). + loading = self._load_list(prio_comfy_cast_weights=True) + loading.sort(reverse=True) + + for x in loading: + _, _, _, n, m, params = x + + def set_dirty(item, dirty): + if dirty or not hasattr(item, "_v_signature"): + item._v_signature = None + if dirty: + comfy.pinned_memory.unpin_memory(item) + + def setup_param(self, m, n, param_key): + nonlocal num_patches + key = "{}.{}".format(n, param_key) + + weight_function = [] + + weight, _, _ = get_key_weight(self.model, key) + if key in self.patches: + setattr(m, param_key + "_lowvram_function", LowVramPatch(key, self.patches)) + num_patches += 1 + else: + setattr(m, param_key + "_lowvram_function", None) + + if key in self.weight_wrapper_patches: + weight_function.extend(self.weight_wrapper_patches[key]) + setattr(m, param_key + "_function", weight_function) + return comfy.memory_management.vram_aligned_size(weight) + + if hasattr(m, "comfy_cast_weights"): + m.comfy_cast_weights = True + m.pin_failed = False + m.seed_key = n + set_dirty(m, dirty) + + v_weight_size = 0 + v_weight_size += setup_param(self, m, n, "weight") + v_weight_size += setup_param(self, m, n, "bias") + + if vbar is not None and not hasattr(m, "_v"): + m._v = vbar.alloc(v_weight_size) + allocated_size += v_weight_size + + else: + for param in params: + key = "{}.{}".format(n, param) + weight, _, _ = get_key_weight(self.model, key) + weight.seed_key = key + set_dirty(weight, dirty) + weight_size = weight.numel() * weight.element_size() + if vbar is not None and not hasattr(weight, "_v"): + weight._v = vbar.alloc(weight_size) + allocated_size += weight_size + + logging.info(f"Model {self.model.__class__.__name__} prepared for dynamic VRAM loading. {allocated_size // (1024 ** 2)}MB Staged. {num_patches} patches attached.") + + self.model.device = device_to + self.model.current_weight_patches_uuid = self.patches_uuid + + for callback in self.get_all_callbacks(CallbacksMP.ON_LOAD): + #These are all super dangerous. Who knows what the custom nodes actually do here... + callback(self, device_to, lowvram_model_memory, force_patch_weights, full_load) + + self.apply_hooks(self.forced_hooks, force_apply=True) + + def partially_unload(self, device_to, memory_to_free=0, force_patch_weights=False): + assert not force_patch_weights #See above + assert self.load_device != torch.device("cpu") + + vbar = self._vbar_get() + return 0 if vbar is None else vbar.free_memory(memory_to_free) + + def partially_unload_ram(self, ram_to_unload): + loading = self._load_list(prio_comfy_cast_weights=True) + for x in loading: + _, _, _, _, m, _ = x + ram_to_unload -= comfy.pinned_memory.unpin_memory(m) + if ram_to_unload <= 0: + return + + def patch_model(self, device_to=None, lowvram_model_memory=0, load_weights=True, force_patch_weights=False): + #This isn't used by the core at all and can only be to load a model out of + #the control of proper model_managment. If you are a custom node author reading + #this, the correct pattern is to call load_models_gpu() to get a proper + #managed load of your model. + assert not load_weights + return super().patch_model(load_weights=load_weights, force_patch_weights=force_patch_weights) + + def unpatch_model(self, device_to=None, unpatch_weights=True): + super().unpatch_model(device_to=None, unpatch_weights=False) + + if unpatch_weights: + self.partially_unload_ram(1e32) + self.partially_unload(None) + + def partially_load(self, device_to, extra_memory=0, force_patch_weights=False): + assert not force_patch_weights #See above + with self.use_ejected(skip_and_inject_on_exit_only=True): + dirty = self.model.current_weight_patches_uuid is not None and (self.model.current_weight_patches_uuid != self.patches_uuid) + + self.unpatch_model(self.offload_device, unpatch_weights=False) + self.patch_model(load_weights=False) + + try: + self.load(device_to, dirty=dirty) + except Exception as e: + self.detach() + raise e + #ModelPatcher::partially_load returns a number on what got loaded but + #nothing in core uses this and we have no data in the Dynamic world. Hit + #the custom node devs with a None rather than a 0 that would mislead any + #logic they might have. + return None + + def patch_cached_hook_weights(self, cached_weights: dict, key: str, memory_counter: MemoryCounter): + assert False #Should be unreachable - we dont ever cache in the new implementation + + def patch_hook_weight_to_device(self, hooks: comfy.hooks.HookGroup, combined_patches: dict, key: str, original_weights: dict, memory_counter: MemoryCounter): + if key not in combined_patches: + return + + raise RuntimeError("Hooks not implemented in ModelPatcherDynamic. Please remove --fast arguments form ComfyUI startup") + + def unpatch_hooks(self, whitelist_keys_set: set[str]=None) -> None: + pass + +CoreModelPatcher = ModelPatcher diff --git a/comfy/ops.py b/comfy/ops.py index 2c35b21d0..825f9ae82 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -24,6 +24,11 @@ import comfy.float import comfy.rmsnorm import json import comfy.memory_management +import comfy.pinned_memory +import comfy.utils + +import comfy_aimdo.model_vbar +import comfy_aimdo.torch def run_every_op(): if torch.compiler.is_compiling(): @@ -73,7 +78,108 @@ def cast_to_input(weight, input, non_blocking=False, copy=True): return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy) -def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, offloadable=False): +def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype): + offload_stream = None + xfer_dest = None + + signature = comfy_aimdo.model_vbar.vbar_fault(s._v) + if signature is not None: + xfer_dest = comfy_aimdo.torch.aimdo_to_tensor(s._v, device) + resident = comfy_aimdo.model_vbar.vbar_signature_compare(signature, s._v_signature) + + if not resident: + + xfer_source = [ s.weight, s.bias ] + + pin = comfy.pinned_memory.get_pin(s) + if pin is not None: + xfer_source = [ pin ] + resident = True #If pinned data exists, it always has LowVram already applied + + dest_size = comfy.memory_management.vram_aligned_size(xfer_source) + offload_stream = comfy.model_management.get_offload_stream(device) + if xfer_dest is None and offload_stream is not None: + xfer_dest = comfy.model_management.get_cast_buffer(offload_stream, device, dest_size, s) + if xfer_dest is None: + offload_stream = comfy.model_management.get_offload_stream(device) + xfer_dest = comfy.model_management.get_cast_buffer(offload_stream, device, dest_size, s) + if xfer_dest is None: + xfer_dest = torch.empty((dest_size,), dtype=torch.uint8, device=device) + offload_stream = None + + #send it over + comfy.model_management.cast_to_gathered(xfer_source, xfer_dest, non_blocking=non_blocking, stream=offload_stream) + comfy.model_management.sync_stream(device, offload_stream) + + pin = None + if signature is not None: + #If we are able to increase our load level (e.g. user reduces resolution or batch number) + #reclaim the pin previously used for offload. + comfy.pinned_memory.unpin_memory(s) + elif not resident: + #prepare a new pin + assert comfy.pinned_memory.get_pin(s) is None + comfy.pinned_memory.pin_memory(s) + pin = comfy.pinned_memory.get_pin(s) + + params = comfy.memory_management.interpret_gathered_like([s.weight, s.bias], xfer_dest) + weight = params[0] + bias = params[1] + + def post_cast(s, param_key, x, dtype, resident, update_weight): + lowvram_fn = getattr(s, param_key + "_lowvram_function", None) + hook_fn = getattr(s, param_key + "_hooks", None) + fns = getattr(s, param_key + "_function", []) + + orig = x + q_layout = None + + def to_dequant(tensor, dtype): + tensor = tensor.to(dtype=dtype) + if isinstance(tensor, QuantizedTensor): + tensor = tensor.dequantize() + return tensor + + if orig.dtype != dtype or len(fns) > 0: + x = to_dequant(x, dtype) + if not resident and lowvram_fn is not None: + x = to_dequant(x, dtype if compute_dtype is None else compute_dtype) + #FIXME: this is not accurate, we need to be sensitive to the compute dtype + x = lowvram_fn(x) + if (isinstance(orig, QuantizedTensor) and + (orig.dtype == dtype and len(fns) == 0 or update_weight)): + seed = comfy.utils.string_to_seed(s.seed_key) + y = QuantizedTensor.from_float(x, s.layout_type, scale="recalculate", stochastic_rounding=seed) + if orig.dtype == dtype and len(fns) == 0: + #The layer actually wants our freshly saved QT + x = y + else: + y = x + if update_weight: + orig.copy_(y) + for f in fns: + x = f(x) + return x + + update_weight = signature is not None or pin is not None + + weight = post_cast(s, "weight", weight, dtype, resident, update_weight) + if s.bias is not None: + bias = post_cast(s, "bias", bias, bias_dtype, resident, update_weight) + s._v_signature=signature + + if pin is not None: + xfer_dest = comfy.memory_management.interpret_gathered_like([ pin ], xfer_dest)[0] + if offload_stream is not None: + #FIXME: if post cast didnt do anything this sync is un-needed + offload_stream.wait_stream(comfy.model_management.current_stream(device)) + comfy.model_management.cast_to(xfer_dest, device=pin.device, non_blocking=non_blocking, stream=offload_stream, r=pin) + + #FIXME: weird offload return protocol + return weight, bias, (offload_stream, device if signature is not None else None, None) + + +def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, offloadable=False, compute_dtype=None): # NOTE: offloadable=False is a a legacy and if you are a custom node author reading this please pass # offloadable=True and call uncast_bias_weight() after your last usage of the weight/bias. This # will add async-offload support to your cast and improve performance. @@ -88,6 +194,11 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of if device is None: device = input.device + non_blocking = comfy.model_management.device_supports_non_blocking(device) + + if hasattr(s, "_v"): + return cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype) + if offloadable and (device != s.weight.device or (s.bias is not None and device != s.bias.device)): offload_stream = comfy.model_management.get_offload_stream(device) @@ -108,8 +219,6 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of weight = params[0] bias = params[1] - non_blocking = comfy.model_management.device_supports_non_blocking(device) - weight_has_function = len(s.weight_function) > 0 bias_has_function = len(s.bias_function) > 0 @@ -146,14 +255,20 @@ def uncast_bias_weight(s, weight, bias, offload_stream): if offload_stream is None: return os, weight_a, bias_a = offload_stream + device=None + #FIXME: This is not good RTTI + if not isinstance(weight_a, torch.Tensor): + comfy_aimdo.model_vbar.vbar_unpin(s._v) + device = weight_a if os is None: return - if weight_a is not None: - device = weight_a.device - else: - if bias_a is None: - return - device = bias_a.device + if device is None: + if weight_a is not None: + device = weight_a.device + else: + if bias_a is None: + return + device = bias_a.device os.wait_stream(comfy.model_management.current_stream(device)) @@ -668,8 +783,8 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec def _forward(self, input, weight, bias): return torch.nn.functional.linear(input, weight, bias) - def forward_comfy_cast_weights(self, input): - weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True) + def forward_comfy_cast_weights(self, input, compute_dtype=None): + weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True, compute_dtype=compute_dtype) x = self._forward(input, weight, bias) uncast_bias_weight(self, weight, bias, offload_stream) return x @@ -679,6 +794,8 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec input_shape = input.shape reshaped_3d = False + #If cast needs to apply lora, it should be done in the compute dtype + compute_dtype = input.dtype if (getattr(self, 'layout_type', None) is not None and not isinstance(input, QuantizedTensor) and not self._full_precision_mm and @@ -697,7 +814,8 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec scale = comfy.model_management.cast_to_device(scale, input.device, None) input = QuantizedTensor.from_float(input_reshaped, self.layout_type, scale=scale) - output = self.forward_comfy_cast_weights(input) + + output = self.forward_comfy_cast_weights(input, compute_dtype) # Reshape output back to 3D if input was 3D if reshaped_3d: