From caf6c6aada3125afa3cfd73144b98ac025e6b2b8 Mon Sep 17 00:00:00 2001 From: Rattus Date: Thu, 1 Jan 2026 00:36:06 +1000 Subject: [PATCH 01/18] Reduce RAM and compute time in model saving with Loras Get the model saving logic away from force_patch_weights and instead do the patching JIT during safetensors saving. Firstly switch off force_patch_weights in the load for save which avoids creating CPU side tensors with loras calculated. Then at save time, wrap the tensor to catch safetensors call to .to() and patch it live. This avoids having to ever have a lora-calculated copy of offloaded weights on the CPU. Also take advantage of the presence of the GPU when doing this Lora calculation. The former force_patch_weights would just do eveyrthing on the CPU. Its generally faster to go the GPU and back even if its just a Lora application. --- comfy/model_base.py | 9 +++---- comfy/model_patcher.py | 55 ++++++++++++++++++++++++++++++++++++------ comfy/sd.py | 4 +-- 3 files changed, 53 insertions(+), 15 deletions(-) diff --git a/comfy/model_base.py b/comfy/model_base.py index 49efd700b..b9e2abdb9 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -321,7 +321,7 @@ class BaseModel(torch.nn.Module): def process_latent_out(self, latent): return self.latent_format.process_out(latent) - def state_dict_for_saving(self, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None): + def state_dict_for_saving(self, unet_state_dict, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None): extra_sds = [] if clip_state_dict is not None: extra_sds.append(self.model_config.process_clip_state_dict_for_saving(clip_state_dict)) @@ -329,10 +329,7 @@ class BaseModel(torch.nn.Module): extra_sds.append(self.model_config.process_vae_state_dict_for_saving(vae_state_dict)) if clip_vision_state_dict is not None: extra_sds.append(self.model_config.process_clip_vision_state_dict_for_saving(clip_vision_state_dict)) - - unet_state_dict = self.diffusion_model.state_dict() unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict) - if self.model_type == ModelType.V_PREDICTION: unet_state_dict["v_pred"] = torch.tensor([]) @@ -775,8 +772,8 @@ class StableAudio1(BaseModel): out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) return out - def state_dict_for_saving(self, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None): - sd = super().state_dict_for_saving(clip_state_dict=clip_state_dict, vae_state_dict=vae_state_dict, clip_vision_state_dict=clip_vision_state_dict) + def state_dict_for_saving(self, unet_state_dict, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None): + sd = super().state_dict_for_saving(unet_state_dict, clip_state_dict=clip_state_dict, vae_state_dict=vae_state_dict, clip_vision_state_dict=clip_vision_state_dict) d = {"conditioner.conditioners.seconds_start.": self.seconds_start_embedder.state_dict(), "conditioner.conditioners.seconds_total.": self.seconds_total_embedder.state_dict()} for k in d: s = d[k] diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index f6b80a40f..30ca39b2a 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -24,6 +24,7 @@ import inspect import logging import math import uuid +import types from typing import Callable, Optional import torch @@ -212,6 +213,27 @@ class MemoryCounter: def decrement(self, used: int): self.value -= used +CustomTorchDevice = collections.namedtuple("FakeDevice", ["type", "index"])("comfy-lazy-caster", 0) + +class LazyCastingParam(torch.nn.Parameter): + def __new__(cls, model, key, tensor): + return super().__new__(cls, tensor) + + def __init__(self, model, key, tensor): + self.model = model + self.key = key + + @property + def device(self): + return CustomTorchDevice + + #safetensors will .to() us to the cpu which we catch here to cast on demand. The returned tensor is + #then just a short lived thing in the safetensors serialization logic inside its big for loop over + #all weights getting garbage collected per-weight + def to(self, *args, **kwargs): + return self.model.patch_weight_to_device(self.key, device_to=self.model.load_device, return_weight=True).to("cpu") + + class ModelPatcher: def __init__(self, model, load_device, offload_device, size=0, weight_inplace_update=False): self.size = size @@ -611,14 +633,14 @@ class ModelPatcher: sd.pop(k) return sd - def patch_weight_to_device(self, key, device_to=None, inplace_update=False): - if key not in self.patches: - return - + def patch_weight_to_device(self, key, device_to=None, inplace_update=False, return_weight=False): weight, set_func, convert_func = get_key_weight(self.model, key) + if key not in self.patches: + return weight + inplace_update = self.weight_inplace_update or inplace_update - if key not in self.backup: + if key not in self.backup and not return_weight: self.backup[key] = collections.namedtuple('Dimension', ['weight', 'inplace_update'])(weight.to(device=self.offload_device, copy=inplace_update), inplace_update) temp_dtype = comfy.model_management.lora_compute_dtype(device_to) @@ -632,12 +654,14 @@ class ModelPatcher: out_weight = comfy.lora.calculate_weight(self.patches[key], temp_weight, key) if set_func is None: out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype, seed=string_to_seed(key)) - if inplace_update: + if return_weight: + return out_weight + elif inplace_update: comfy.utils.copy_to_param(self.model, key, out_weight) else: comfy.utils.set_attr_param(self.model, key, out_weight) else: - set_func(out_weight, inplace_update=inplace_update, seed=string_to_seed(key)) + return set_func(out_weight, inplace_update=inplace_update, seed=string_to_seed(key), return_weight=return_weight) def pin_weight_to_device(self, key): weight, set_func, convert_func = get_key_weight(self.model, key) @@ -1355,6 +1379,23 @@ class ModelPatcher: self.unpatch_hooks() self.clear_cached_hook_weights() + def state_dict_for_saving(self, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None): + unet_state_dict = self.model.diffusion_model.state_dict() + for k, v in unet_state_dict.items(): + op_keys = k.rsplit('.', 1) + if (len(op_keys) < 2) or not op_keys[1] in ["weight", "bias"]: + continue + try: + op = comfy.utils.get_attr(self.model.diffusion_model, op_keys[0]) + except: + continue + if not op or not hasattr(op, "comfy_cast_weights") or \ + (hasattr(op, "comfy_patched_weights") and op.comfy_patched_weights == True): + continue + key = "diffusion_model." + k + unet_state_dict[k] = LazyCastingParam(self, key, comfy.utils.get_attr(self.model, key)) + return self.model.state_dict_for_saving(unet_state_dict) + def __del__(self): self.unpin_all_weights() self.detach(unpatch_all=False) diff --git a/comfy/sd.py b/comfy/sd.py index b689c0dfc..b71157343 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -1575,9 +1575,9 @@ def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, m if metadata is None: metadata = {} - model_management.load_models_gpu(load_models, force_patch_weights=True) + model_management.load_models_gpu(load_models) clip_vision_sd = clip_vision.get_sd() if clip_vision is not None else None - sd = model.model.state_dict_for_saving(clip_sd, vae_sd, clip_vision_sd) + sd = model.state_dict_for_saving(clip_sd, vae_sd, clip_vision_sd) for k in extra_keys: sd[k] = extra_keys[k] From 967f848df273acb059f1bee5f17b5aeca19c58f2 Mon Sep 17 00:00:00 2001 From: Rattus Date: Mon, 22 Dec 2025 19:40:22 +1000 Subject: [PATCH 02/18] ops: Do bias dtype conversion on compute stream For consistency with weights. --- comfy/ops.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/comfy/ops.py b/comfy/ops.py index 415c39e92..6aec1cdaf 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -102,7 +102,7 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of bias = None if s.bias is not None: - bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=bias_has_function, stream=offload_stream) + bias = comfy.model_management.cast_to(s.bias, None, device, non_blocking=non_blocking, copy=bias_has_function, stream=offload_stream) comfy.model_management.sync_stream(device, offload_stream) @@ -110,6 +110,7 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of weight_a = weight if s.bias is not None: + bias = bias.to(dtype=bias_dtype) for f in s.bias_function: bias = f(bias) From babccae951766e9375cc12fb12c9c9a61d5fcbea Mon Sep 17 00:00:00 2001 From: Rattus Date: Tue, 23 Dec 2025 11:01:32 +1000 Subject: [PATCH 03/18] mm: Implement cast buffer allocations --- comfy/memory_management.py | 51 ++++++++++++++++++++++++++++++++++ comfy/model_management.py | 57 ++++++++++++++++++++++++++++++++++++-- comfy/ops.py | 20 +++++++++++-- cuda_malloc.py | 5 +++- 4 files changed, 126 insertions(+), 7 deletions(-) create mode 100644 comfy/memory_management.py diff --git a/comfy/memory_management.py b/comfy/memory_management.py new file mode 100644 index 000000000..f8bca5263 --- /dev/null +++ b/comfy/memory_management.py @@ -0,0 +1,51 @@ +import torch +from comfy.quant_ops import QuantizedTensor + +def vram_aligned_size(tensor): + if isinstance(tensor, list): + return sum([vram_aligned_size(t) for t in tensor]) + + if isinstance(tensor, QuantizedTensor): + inner_tensors, _ = tensor.__tensor_flatten__() + return vram_aligned_size([ getattr(tensor, attr) for attr in inner_tensors ]) + + if tensor is None: + return 0 + + size = tensor.numel() * tensor.element_size() + aligment_req = 1024 + return (size + aligment_req - 1) // aligment_req * aligment_req + +def interpret_gathered_like(tensors, gathered): + offset = 0 + dest_views = [] + + if gathered.dim() != 1 or gathered.element_size() != 1: + raise ValueError(f"Buffer must be 1D and single-byte (got {gathered.dim()}D {gathered.dtype})") + + for tensor in tensors: + + if tensor is None: + dest_views.append(None) + continue + + if isinstance(tensor, QuantizedTensor): + inner_tensors, qt_ctx = tensor.__tensor_flatten__() + templates = { attr: getattr(tensor, attr) for attr in inner_tensors } + else: + templates = { "data": tensor } + + actuals = {} + for attr, template in templates.items(): + size = template.numel() * template.element_size() + if offset + size > gathered.numel(): + raise ValueError(f"Buffer too small: needs {offset + size} bytes, but only has {gathered.numel()}. ") + actuals[attr] = gathered[offset:offset+size].view(dtype=template.dtype).view(template.shape) + offset += vram_aligned_size(template) + + if isinstance(tensor, QuantizedTensor): + dest_views.append(QuantizedTensor.__tensor_unflatten__(actuals, qt_ctx, 0, 0)) + else: + dest_views.append(actuals["data"]) + + return dest_views diff --git a/comfy/model_management.py b/comfy/model_management.py index 9d39be7b2..790236ede 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -26,6 +26,8 @@ import platform import weakref import gc import os +from contextlib import nullcontext +import comfy.quant_ops class VRAMState(Enum): DISABLED = 0 #No vram present: no need to move models to vram @@ -732,6 +734,9 @@ def loaded_models(only_currently_used=False): def cleanup_models_gc(): do_gc = False + + reset_cast_buffers() + for i in range(len(current_loaded_models)): cur = current_loaded_models[i] if cur.is_dead(): @@ -1051,6 +1056,49 @@ def current_stream(device): return None stream_counters = {} + +STREAM_CAST_BUFFERS = {} +LARGEST_CASTED_WEIGHT = (None, 0) + +def get_cast_buffer(offload_stream, device, size, ref): + global LARGEST_CASTED_WEIGHT + + if offload_stream is not None: + wf_context = offload_stream + if hasattr(wf_context, "as_context"): + wf_context = wf_context.as_context(offload_stream) + else: + wf_context = nullcontext() + + cast_buffer = STREAM_CAST_BUFFERS.get(offload_stream, None) + if cast_buffer is None or cast_buffer.numel() < size: + if ref is LARGEST_CASTED_WEIGHT[0]: + #If there is one giant weight we do not want both streams to + #allocate a buffer for it. It's up to the caster to get the other + #offload stream in this corner case + return None + if cast_buffer is not None and cast_buffer.numel() > 50 * (1024 ** 2): + #I want my wrongly sized 50MB+ of VRAM back from the caching allocator right now + del STREAM_CAST_BUFFERS[offload_stream] + del cast_buffer + torch.cuda.synchronize() + torch.cuda.empty_cache() + with wf_context: + cast_buffer = torch.empty((size), dtype=torch.int8, device=device) + STREAM_CAST_BUFFERS[offload_stream] = cast_buffer + + if size > LARGEST_CASTED_WEIGHT[1]: + LARGEST_CASTED_WEIGHT = (ref, size) + + return cast_buffer + +def reset_cast_buffers(): + global LARGEST_CASTED_WEIGHT + LARGEST_CASTED_WEIGHT = (None, 0) + STREAM_CAST_BUFFERS.clear() + torch.cuda.synchronize() + torch.cuda.empty_cache() + def get_offload_stream(device): stream_counter = stream_counters.get(device, 0) if NUM_STREAMS == 0: @@ -1093,7 +1141,7 @@ def sync_stream(device, stream): return current_stream(device).wait_stream(stream) -def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None): +def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None, r=None): if device is None or weight.device == device: if not copy: if dtype is None or weight.dtype == dtype: @@ -1112,10 +1160,12 @@ def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, str if hasattr(wf_context, "as_context"): wf_context = wf_context.as_context(stream) with wf_context: - r = torch.empty_like(weight, dtype=dtype, device=device) + if r is None: + r = torch.empty_like(weight, dtype=dtype, device=device) r.copy_(weight, non_blocking=non_blocking) else: - r = torch.empty_like(weight, dtype=dtype, device=device) + if r is None: + r = torch.empty_like(weight, dtype=dtype, device=device) r.copy_(weight, non_blocking=non_blocking) return r @@ -1557,6 +1607,7 @@ def soft_empty_cache(force=False): elif is_mlu(): torch.mlu.empty_cache() elif torch.cuda.is_available(): + torch.cuda.synchronize() torch.cuda.empty_cache() torch.cuda.ipc_collect() diff --git a/comfy/ops.py b/comfy/ops.py index 6aec1cdaf..2c35b21d0 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -23,6 +23,7 @@ from comfy.cli_args import args, PerformanceFeature import comfy.float import comfy.rmsnorm import json +import comfy.memory_management def run_every_op(): if torch.compiler.is_compiling(): @@ -93,16 +94,29 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of else: offload_stream = None + bias = None + weight = None + + if offload_stream is not None and not args.cuda_malloc: + cast_buffer_size = comfy.memory_management.vram_aligned_size([ s.weight, s.bias ]) + cast_buffer = comfy.model_management.get_cast_buffer(offload_stream, device, cast_buffer_size, s) + #The streams can be uneven in buffer capability and reject us. Retry to get the other stream + if cast_buffer is None: + offload_stream = comfy.model_management.get_offload_stream(device) + cast_buffer = comfy.model_management.get_cast_buffer(offload_stream, device, cast_buffer_size, s) + params = interpret_gathered_like([ s.weight, s.bias ], cast_buffer) + weight = params[0] + bias = params[1] + non_blocking = comfy.model_management.device_supports_non_blocking(device) weight_has_function = len(s.weight_function) > 0 bias_has_function = len(s.bias_function) > 0 - weight = comfy.model_management.cast_to(s.weight, None, device, non_blocking=non_blocking, copy=weight_has_function, stream=offload_stream) + weight = comfy.model_management.cast_to(s.weight, None, device, non_blocking=non_blocking, copy=weight_has_function, stream=offload_stream, r=weight) - bias = None if s.bias is not None: - bias = comfy.model_management.cast_to(s.bias, None, device, non_blocking=non_blocking, copy=bias_has_function, stream=offload_stream) + bias = comfy.model_management.cast_to(s.bias, None, device, non_blocking=non_blocking, copy=bias_has_function, stream=offload_stream, r=bias) comfy.model_management.sync_stream(device, offload_stream) diff --git a/cuda_malloc.py b/cuda_malloc.py index ee2bc4b69..00ee7b633 100644 --- a/cuda_malloc.py +++ b/cuda_malloc.py @@ -86,7 +86,10 @@ if not args.cuda_malloc: pass -if args.cuda_malloc and not args.disable_cuda_malloc: +if args.disable_cuda_malloc: + args.cuda_malloc = False + +if args.cuda_malloc: env_var = os.environ.get('PYTORCH_CUDA_ALLOC_CONF', None) if env_var is None: env_var = "backend:cudaMallocAsync" From f1e8ccae5c36896933c5f115bcb53692c90da27b Mon Sep 17 00:00:00 2001 From: Rattus Date: Thu, 8 Jan 2026 18:21:50 +1000 Subject: [PATCH 04/18] move string_to_seed to utils.py This needs to be visible by ops which may want to do stochastic rounding on the fly. --- comfy/model_patcher.py | 21 ++++----------------- comfy/utils.py | 13 +++++++++++++ 2 files changed, 17 insertions(+), 17 deletions(-) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 30ca39b2a..46dcf5be8 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -40,19 +40,6 @@ from comfy.quant_ops import QuantizedTensor from comfy.patcher_extension import CallbacksMP, PatcherInjection, WrappersMP -def string_to_seed(data): - crc = 0xFFFFFFFF - for byte in data: - if isinstance(byte, str): - byte = ord(byte) - crc ^= byte - for _ in range(8): - if crc & 1: - crc = (crc >> 1) ^ 0xEDB88320 - else: - crc >>= 1 - return crc ^ 0xFFFFFFFF - def set_model_options_patch_replace(model_options, patch, name, block_name, number, transformer_index=None): to = model_options["transformer_options"].copy() @@ -653,7 +640,7 @@ class ModelPatcher: out_weight = comfy.lora.calculate_weight(self.patches[key], temp_weight, key) if set_func is None: - out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype, seed=string_to_seed(key)) + out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype, seed=comfy.utils.string_to_seed(key)) if return_weight: return out_weight elif inplace_update: @@ -661,7 +648,7 @@ class ModelPatcher: else: comfy.utils.set_attr_param(self.model, key, out_weight) else: - return set_func(out_weight, inplace_update=inplace_update, seed=string_to_seed(key), return_weight=return_weight) + return set_func(out_weight, inplace_update=inplace_update, seed=comfy.utils.string_to_seed(key), return_weight=return_weight) def pin_weight_to_device(self, key): weight, set_func, convert_func = get_key_weight(self.model, key) @@ -1341,10 +1328,10 @@ class ModelPatcher: key, original_weights=original_weights) del original_weights[key] if set_func is None: - out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype, seed=string_to_seed(key)) + out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype, seed=comfy.utils.string_to_seed(key)) comfy.utils.copy_to_param(self.model, key, out_weight) else: - set_func(out_weight, inplace_update=True, seed=string_to_seed(key)) + set_func(out_weight, inplace_update=True, seed=comfy.utils.string_to_seed(key)) if self.hook_mode == comfy.hooks.EnumHookMode.MaxSpeed: # TODO: disable caching if not enough system RAM to do so target_device = self.offload_device diff --git a/comfy/utils.py b/comfy/utils.py index ffa98c9b1..635dc4146 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -1267,3 +1267,16 @@ def convert_old_quants(state_dict, model_prefix="", metadata={}): state_dict["{}.comfy_quant".format(k)] = torch.tensor(list(json.dumps(v).encode('utf-8')), dtype=torch.uint8) return state_dict, metadata + +def string_to_seed(data): + crc = 0xFFFFFFFF + for byte in data: + if isinstance(byte, str): + byte = ord(byte) + crc ^= byte + for _ in range(8): + if crc & 1: + crc = (crc >> 1) ^ 0xEDB88320 + else: + crc >>= 1 + return crc ^ 0xFFFFFFFF From 439c178c2cbb5cb07445314ab99384b86a3c3c5e Mon Sep 17 00:00:00 2001 From: Rattus Date: Tue, 13 Jan 2026 15:16:41 +1000 Subject: [PATCH 05/18] pinned_memory: add python Add a python for managing pinned memory of the weight/bias module level. This allocates, pins and attached a tensor to a module for the pin for this module. It does not set the weight, just allocates a singular ram buffer for population and bulk DMA transfer. --- comfy/model_management.py | 2 +- comfy/pinned_memory.py | 34 ++++++++++++++++++++++++++++++++++ 2 files changed, 35 insertions(+), 1 deletion(-) create mode 100644 comfy/pinned_memory.py diff --git a/comfy/model_management.py b/comfy/model_management.py index 790236ede..21761d971 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -1185,7 +1185,7 @@ if not args.disable_pinned_memory: MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.95 logging.info("Enabled pinned memory {}".format(MAX_PINNED_MEMORY // (1024 * 1024))) -PINNING_ALLOWED_TYPES = set(["Parameter", "QuantizedTensor"]) +PINNING_ALLOWED_TYPES = set(["Tensor", "Parameter", "QuantizedTensor"]) def discard_cuda_async_error(): try: diff --git a/comfy/pinned_memory.py b/comfy/pinned_memory.py new file mode 100644 index 000000000..be303b4f1 --- /dev/null +++ b/comfy/pinned_memory.py @@ -0,0 +1,34 @@ +import torch +import logging +import comfy.model_management +import comfy.memory_management + +from comfy.cli_args import args + +def get_pin(module): + return getattr(module, "_pin", None) + +def pin_memory(module): + if module.pin_failed or args.disable_pinned_memory or get_pin(module) is not None: + return + #FIXME: This is a RAM cache trigger event + params = [ module.weight, module.bias ] + size = comfy.memory_management.vram_aligned_size(params) + try: + pin = torch.empty((size,), dtype=torch.uint8) + if comfy.model_management.pin_memory(pin): + module._pin = pin + else: + module.pin_failed = True + return False + except: + module.pin_failed = True + return False + return True + +def unpin_memory(module): + if get_pin(module) is None: + return 0 + size = module._pin.numel() * module._pin.element_size() + comfy.model_management.unpin_memory(module._pin) + return size From f5113675290ec6cae2051f599d592022bdb8e0fb Mon Sep 17 00:00:00 2001 From: Rattus Date: Tue, 13 Jan 2026 15:20:47 +1000 Subject: [PATCH 06/18] mp: wrap get_free_memory Dynamic load needs to adjust these numbers based on future movements, so wrap this in a MP API. --- comfy/model_patcher.py | 3 +++ comfy/samplers.py | 2 +- comfy/sd.py | 4 ++-- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 46dcf5be8..24e3e5fcd 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -293,6 +293,9 @@ class ModelPatcher: def lowvram_patch_counter(self): return self.model.lowvram_patch_counter + def get_free_memory(self, device): + return comfy.model_management.get_free_memory(device) + def clone(self): n = self.__class__(self.model, self.load_device, self.offload_device, self.model_size(), weight_inplace_update=self.weight_inplace_update) n.patches = {} diff --git a/comfy/samplers.py b/comfy/samplers.py index 1989ef107..d495ca203 100755 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -260,7 +260,7 @@ def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tens to_batch_temp.reverse() to_batch = to_batch_temp[:1] - free_memory = model_management.get_free_memory(x_in.device) + free_memory = model.current_patcher.get_free_memory(x_in.device) for i in range(1, len(to_batch_temp) + 1): batch_amount = to_batch_temp[:len(to_batch_temp)//i] input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:] diff --git a/comfy/sd.py b/comfy/sd.py index b71157343..f83164d23 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -797,7 +797,7 @@ class VAE: try: memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype) model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload) - free_memory = model_management.get_free_memory(self.device) + free_memory = self.patcher.get_free_memory(self.device) batch_number = int(free_memory / memory_used) batch_number = max(1, batch_number) @@ -871,7 +871,7 @@ class VAE: try: memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype) model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload) - free_memory = model_management.get_free_memory(self.device) + free_memory = self.patcher.get_free_memory(self.device) batch_number = int(free_memory / max(1, memory_used)) batch_number = max(1, batch_number) samples = None From b9ee4c6ee5d0fa0efb03c09d2d8fe05dac1e946c Mon Sep 17 00:00:00 2001 From: Rattus Date: Tue, 13 Jan 2026 15:23:49 +1000 Subject: [PATCH 07/18] mp/mm: APi expansions for dynamic loading Add two api expansions, a flag for whether a model patcher is dynamic a a very basic RAM freeing system. Implement the semantics of the dynamic model patcher which never frees VRAM ahead of time for the sake of another dynamic model patcher. At the same time add an API for clearing out pins on a reservation of model size x2 heuristic, as pins consume RAM in their own right in the dynamic patcher. This is actually less about OOMing RAM and more about performance, as with assign=True load semantics there needs to be plenty headroom for the OS to load models to dosk cache on demand so err on the side of kicking old pins out. --- comfy/model_management.py | 34 +++++++++++++++++++++++++--------- comfy/model_patcher.py | 6 ++++++ 2 files changed, 31 insertions(+), 9 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 21761d971..b1bf3bd59 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -594,7 +594,7 @@ def extra_reserved_memory(): def minimum_inference_memory(): return (1024 * 1024 * 1024) * 0.8 + extra_reserved_memory() -def free_memory(memory_required, device, keep_loaded=[]): +def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, ram_required=0): cleanup_models_gc() unloaded_model = [] can_unload = [] @@ -609,15 +609,22 @@ def free_memory(memory_required, device, keep_loaded=[]): for x in sorted(can_unload): i = x[-1] - memory_to_free = None + memory_to_free = 1e32 + ram_to_free = 1e32 if not DISABLE_SMART_MEMORY: - free_mem = get_free_memory(device) - if free_mem > memory_required: - break - memory_to_free = memory_required - free_mem + memory_to_free = memory_required - get_free_memory(device) + ram_to_free = ram_required - psutil.virtual_memory().available + + if current_loaded_models[i].model.is_dynamic() and for_dynamic: + #don't actually unload dynamic models for the sake of other dynamic models + #as that works on-demand. + memory_required -= current_loaded_models[i].model.loaded_size() + continue logging.debug(f"Unloading {current_loaded_models[i].model.model.__class__.__name__}") - if current_loaded_models[i].model_unload(memory_to_free): + if memory_to_free > 0 and current_loaded_models[i].model_unload(memory_to_free): unloaded_model.append(i) + if ram_to_free > 0: + current_loaded_models[i].model.partially_unload_ram(ram_to_free) for i in sorted(unloaded_model, reverse=True): unloaded_models.append(current_loaded_models.pop(i)) @@ -652,7 +659,10 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu models_to_load = [] + free_for_dynamic=True for x in models: + if not x.is_dynamic(): + free_for_dynamic = False loaded_model = LoadedModel(x) try: loaded_model_index = current_loaded_models.index(loaded_model) @@ -678,19 +688,25 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu model_to_unload.model.detach(unpatch_all=False) model_to_unload.model_finalizer.detach() + total_memory_required = {} + total_ram_required = {} for loaded_model in models_to_load: total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device) + #x2, one to make sure the OS can fit the model for loading in disk cache, and for us to do any pinning we + #want to do. + #FIXME: This should subtract off the to_load current pin consumption. + total_ram_required[loaded_model.device] = total_ram_required.get(loaded_model.device, 0) + loaded_model.model_memory() * 2 for device in total_memory_required: if device != torch.device("cpu"): - free_memory(total_memory_required[device] * 1.1 + extra_mem, device) + free_memory(total_memory_required[device] * 1.1 + extra_mem, device, for_dynamic=free_for_dynamic, ram_required=total_ram_required[device]) for device in total_memory_required: if device != torch.device("cpu"): free_mem = get_free_memory(device) if free_mem < minimum_memory_required: - models_l = free_memory(minimum_memory_required, device) + models_l = free_memory(minimum_memory_required, device, for_dynamic=free_for_dynamic) logging.info("{} models unloaded.".format(len(models_l))) for loaded_model in models_to_load: diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 24e3e5fcd..57cec2747 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -278,6 +278,9 @@ class ModelPatcher: if not hasattr(self.model, 'model_offload_buffer_memory'): self.model.model_offload_buffer_memory = 0 + def is_dynamic(self): + return False + def model_size(self): if self.size > 0: return self.size @@ -998,6 +1001,9 @@ class ModelPatcher: return self.model.model_loaded_weight_memory - current_used + def partially_unload_ram(self, ram_to_unload): + pass + def detach(self, unpatch_all=True): self.eject_model() self.model_patches_to(self.offload_device) From f74661edc683d231dc2f3859f5867214f3124bcd Mon Sep 17 00:00:00 2001 From: Rattus Date: Tue, 13 Jan 2026 15:33:18 +1000 Subject: [PATCH 08/18] mp: add mode for non comfy weight prioritization non-comfy weights dont get async offload and a few other performance limitations. Load them at top priority accordingly. --- comfy/model_patcher.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 57cec2747..0263f133d 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -671,7 +671,7 @@ class ModelPatcher: for key in list(self.pinned): self.unpin_weight(key) - def _load_list(self): + def _load_list(self, prio_comfy_cast_weights=False): loading = [] for n, m in self.model.named_modules(): params = [] @@ -698,7 +698,8 @@ class ModelPatcher: return 0 module_offload_mem += check_module_offload_mem("{}.weight".format(n)) module_offload_mem += check_module_offload_mem("{}.bias".format(n)) - loading.append((module_offload_mem, module_mem, n, m, params)) + prepend = (not hasattr(m, "comfy_cast_weights"),) if prio_comfy_cast_weights else () + loading.append(prepend + (module_offload_mem, module_mem, n, m, params)) return loading def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False): From f75765721d008a43581492582b4244efa827ee9d Mon Sep 17 00:00:00 2001 From: Rattus Date: Tue, 13 Jan 2026 15:36:09 +1000 Subject: [PATCH 09/18] ops/mp: implement aimdo Implement a model patcher and caster for aimdo. A new ModelPatcher implementation which backs onto comfy-aimdo to implement varying model load levels that can be adjusted during model use. The patcher defers all load processes to lazily load the model during use (e.g. the first step of a ksampler) and automatically negotiates a load level during the inference to maximize VRAM usage without OOMing. If inference requires more VRAM than is available weights are offloaded to make space before the OOM happens. As for loading the weight onto the GPU, that happens via comfy_cast_weights which is now used in all cases. cast_bias_weight checks whether the VBAR assigned to the model has space for the weight (based on the same load priority semantics as the original ModelPatcher). If it does, the VRAM as returned by the Aimdo allocator is used as the parameter GPU side. The caster is responsible for populating the weight data. This is done using the usual offload_stream (which mean we now have asynchronous load overlapping first use compute). Pinning works a little differently. When a weight is detected during load as unable to fit, a pin is allocated at the time of casting and the weight as used by the layer is DMAd back to the the pin using the GPU DMA TX engine, also using the asynchronous offload streams. This means you get to pin the Lora modified and requantized weights which can be a major speedup for offload+quantize+lora use cases, This works around the JIT Lora + FP8 exclusion and brings FP8MM to heavy offloading users (who probably really need it with more modest GPUs). There is a performance risk in that a CPU+RAM patch has been replace with a GPU+RAM patch but my initial performance results look good. Most users as likely to have a GPU that outruns their CPU in these woods. Some common code is written to consolidate a layers tensors for aimdo mapping, pinning, and DMA transfers. interpret_gathered_like() allows unpacking a raw buffer as a set of tensors. This is used consistently to bundle and pack weights, quantization metadata (QuantizedTensor bits) and biases into one payload for DMA in the load process reducing Cuda overhead a little. Some Quantization metadata was missing async offload is some cases which is now added. This also pins quantization metadata and consolidates the number of cuda_host_register calls (which can be expensive). --- comfy/model_management.py | 56 ++++++++++ comfy/model_patcher.py | 214 ++++++++++++++++++++++++++++++++++++++ comfy/ops.py | 142 ++++++++++++++++++++++--- 3 files changed, 400 insertions(+), 12 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index b1bf3bd59..c5a22e04c 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -27,8 +27,12 @@ import weakref import gc import os from contextlib import nullcontext +import comfy.utils import comfy.quant_ops +import comfy_aimdo.torch +import comfy_aimdo.model_vbar + class VRAMState(Enum): DISABLED = 0 #No vram present: no need to move models to vram NO_VRAM = 1 #Very low vram: enable all the options to save vram @@ -1157,7 +1161,59 @@ def sync_stream(device, stream): return current_stream(device).wait_stream(stream) + +def cast_to_gathered(tensors, r, non_blocking=False, stream=None): + wf_context = nullcontext() + if stream is not None: + wf_context = stream + if hasattr(wf_context, "as_context"): + wf_context = wf_context.as_context(stream) + + dest_views = comfy.memory_management.interpret_gathered_like(tensors, r) + with wf_context: + for tensor in tensors: + dest_view = dest_views.pop(0) + if tensor is None: + continue + dest_view.copy_(tensor, non_blocking=non_blocking) + + def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None, r=None): + if hasattr(weight, "_v"): + #Unexpected usage patterns. There is no reason these don't work but they + #have no testing and no callers do this. + assert r is None + assert stream is None + + r = torch.empty_like(weight, dtype=dtype, device=device) + + signature = comfy_aimdo.model_vbar.vbar_fault(weight._v) + if signature is not None: + raw_tensor = comfy_aimdo.torch.aimdo_to_tensor(weight._v, device) + v_tensor = comfy.memory_management.interpret_gathered_like([weight], raw_tensor)[0] + + if comfy_aimdo.model_vbar.vbar_signature_compare(signature, weight._v_signature): + #always take a deep copy even if _v is good, as we have no reasonable point to unpin + #a non comfy weight + r.copy_(v_tensor) + comfy_aimdo.model_vbar.vbar_unpin(weight._v) + return r + + r.copy_(weight, non_blocking=non_blocking) + + #FIXME: remove hooks before PR + if hasattr(weight, "comfy_hook"): + dtype = r.dtype + r = weight.comfy_hook(r) + if r.dtype != dtype: + r = comfy.float.stochastic_rounding(r, dtype, seed=comfy.utils.string_to_seed(weight.seed_key)) + + if signature is not None: + v_tensor.copy_(r) + comfy_aimdo.model_vbar.vbar_unpin(weight._v) + + return r + if device is None or weight.device == device: if not copy: if dtype is None or weight.dtype == dtype: diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 0263f133d..77e7eec90 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -39,6 +39,7 @@ from comfy.comfy_types import UnetWrapperFunction from comfy.quant_ops import QuantizedTensor from comfy.patcher_extension import CallbacksMP, PatcherInjection, WrappersMP +import comfy_aimdo.model_vbar def set_model_options_patch_replace(model_options, patch, name, block_name, number, transformer_index=None): to = model_options["transformer_options"].copy() @@ -1397,3 +1398,216 @@ class ModelPatcher: self.unpin_all_weights() self.detach(unpatch_all=False) +class ModelPatcherDynamic(ModelPatcher): + + def __new__(cls, model, load_device, offload_device, size=0, weight_inplace_update=False): + if comfy.model_management.is_device_cpu(load_device): + #reroute to default MP for CPUs + return ModelPatcher(model, load_device, offload_device, size, weight_inplace_update) + return super().__new__(cls) + + def __init__(self, model, load_device, offload_device, size=0, weight_inplace_update=False): + super().__init__(model, load_device, offload_device, size, weight_inplace_update) + #this is now way more dynamic and we dont support the same base model for both Dynamic + #and non-dynamic patchers. + if hasattr(self.model, "model_loaded_weight_memory"): + del self.model.model_loaded_weight_memory + if not hasattr(self.model, "dynamic_vbars"): + self.model.dynamic_vbars = {} + assert load_device is not None + + def is_dynamic(self): + return True + + def _vbar_get(self, create=False): + if self.load_device == torch.device("cpu"): + return None + vbar = self.model.dynamic_vbars.get(self.load_device, None) + if create and vbar is None: + vbar = comfy_aimdo.model_vbar.ModelVBAR(self.model_size() * 1.2, self.load_device.index) + self.model.dynamic_vbars[self.load_device] = vbar + return vbar + + def loaded_size(self): + vbar = self._vbar_get() + if vbar is None: + return 0 + return vbar.loaded_size() + + def get_free_memory(self, device): + #NOTE: on high condition / batch counts, estimate should have already vacated + #all non-dynamic models so this is safe even if its not 100% true that this + #would all be avaiable for inference use. + return comfy.model_management.get_total_memory(device) - self.model_size() + + #Pinning is deferred to ops time. Assert against this API to avoid pin leaks. + + def pin_weight_to_device(self, key): + raise RuntimeError("pin_weight_to_device invalid for dymamic weight loading") + + def unpin_weight(self, key): + raise RuntimeError("unpin_weight invalid for dymamic weight loading") + + def unpin_all_weights(self): + pass + + def memory_required(self, input_shape): + #Pad this significantly. We are trying to get away from precise estimates. This + #estimate is only used when using the ModelPatcherDynamic after ModelPatcher. If you + #use all ModelPatcherDynamic this is ignored and its all done dynamically. + return super().memory_required(input_shape=input_shape) * 1.3 + (1024 ** 3) + + + def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False, dirty=False): + + #Force patching doesn't make sense in Dynamic loading, as you dont know what does and + #doesn't need to be forced at this stage. The only thing you could do would be patch + #it all on CPU which consumes huge RAM. + assert not force_patch_weights + + #Full load doesn't make sense as we dont actually have any loader capability here and + #now. + assert not full_load; + + assert device_to == self.load_device + + num_patches = 0 + allocated_size = 0 + + with self.use_ejected(): + self.unpatch_hooks() + + vbar = self._vbar_get(create=True) + if vbar is not None: + vbar.prioritize() + + #We have way more tools for acceleration on comfy weight offloading, so always + #prioritize the non-comfy weights (note the order reverse). + loading = self._load_list(prio_comfy_cast_weights=True) + loading.sort(reverse=True) + + for x in loading: + _, _, _, n, m, params = x + + def set_dirty(item, dirty): + if dirty or not hasattr(item, "_v_signature"): + item._v_signature = None + if dirty: + comfy.pinned_memory.unpin_memory(item) + + def setup_param(self, m, n, param_key): + nonlocal num_patches + key = "{}.{}".format(n, param_key) + + weight_function = [] + + weight, _, _ = get_key_weight(self.model, key) + if key in self.patches: + setattr(m, param_key + "_lowvram_function", LowVramPatch(key, self.patches)) + num_patches += 1 + else: + setattr(m, param_key + "_lowvram_function", None) + + if key in self.weight_wrapper_patches: + weight_function.extend(self.weight_wrapper_patches[key]) + setattr(m, param_key + "_function", weight_function) + return comfy.memory_management.vram_aligned_size(weight) + + if hasattr(m, "comfy_cast_weights"): + m.comfy_cast_weights = True + m.pin_failed = False + m.seed_key = n + set_dirty(m, dirty) + + v_weight_size = 0 + v_weight_size += setup_param(self, m, n, "weight") + v_weight_size += setup_param(self, m, n, "bias") + + if vbar is not None and not hasattr(m, "_v"): + m._v = vbar.alloc(v_weight_size) + allocated_size += v_weight_size + + else: + for param in params: + key = "{}.{}".format(n, param) + weight, _, _ = get_key_weight(self.model, key) + weight.seed_key = key + set_dirty(weight, dirty) + weight_size = weight.numel() * weight.element_size() + if vbar is not None and not hasattr(weight, "_v"): + weight._v = vbar.alloc(weight_size) + allocated_size += weight_size + + logging.info(f"Model {self.model.__class__.__name__} prepared for dynamic VRAM loading. {allocated_size // (1024 ** 2)}MB Staged. {num_patches} patches attached.") + + self.model.device = device_to + self.model.current_weight_patches_uuid = self.patches_uuid + + for callback in self.get_all_callbacks(CallbacksMP.ON_LOAD): + #These are all super dangerous. Who knows what the custom nodes actually do here... + callback(self, device_to, lowvram_model_memory, force_patch_weights, full_load) + + self.apply_hooks(self.forced_hooks, force_apply=True) + + def partially_unload(self, device_to, memory_to_free=0, force_patch_weights=False): + assert not force_patch_weights #See above + assert self.load_device != torch.device("cpu") + + vbar = self._vbar_get() + return 0 if vbar is None else vbar.free_memory(memory_to_free) + + def partially_unload_ram(self, ram_to_unload): + loading = self._load_list(prio_comfy_cast_weights=True) + for x in loading: + _, _, _, _, m, _ = x + ram_to_unload -= comfy.pinned_memory.unpin_memory(m) + if ram_to_unload <= 0: + return + + def patch_model(self, device_to=None, lowvram_model_memory=0, load_weights=True, force_patch_weights=False): + #This isn't used by the core at all and can only be to load a model out of + #the control of proper model_managment. If you are a custom node author reading + #this, the correct pattern is to call load_models_gpu() to get a proper + #managed load of your model. + assert not load_weights + return super().patch_model(load_weights=load_weights, force_patch_weights=force_patch_weights) + + def unpatch_model(self, device_to=None, unpatch_weights=True): + super().unpatch_model(device_to=None, unpatch_weights=False) + + if unpatch_weights: + self.partially_unload_ram(1e32) + self.partially_unload(None) + + def partially_load(self, device_to, extra_memory=0, force_patch_weights=False): + assert not force_patch_weights #See above + with self.use_ejected(skip_and_inject_on_exit_only=True): + dirty = self.model.current_weight_patches_uuid is not None and (self.model.current_weight_patches_uuid != self.patches_uuid) + + self.unpatch_model(self.offload_device, unpatch_weights=False) + self.patch_model(load_weights=False) + + try: + self.load(device_to, dirty=dirty) + except Exception as e: + self.detach() + raise e + #ModelPatcher::partially_load returns a number on what got loaded but + #nothing in core uses this and we have no data in the Dynamic world. Hit + #the custom node devs with a None rather than a 0 that would mislead any + #logic they might have. + return None + + def patch_cached_hook_weights(self, cached_weights: dict, key: str, memory_counter: MemoryCounter): + assert False #Should be unreachable - we dont ever cache in the new implementation + + def patch_hook_weight_to_device(self, hooks: comfy.hooks.HookGroup, combined_patches: dict, key: str, original_weights: dict, memory_counter: MemoryCounter): + if key not in combined_patches: + return + + raise RuntimeError("Hooks not implemented in ModelPatcherDynamic. Please remove --fast arguments form ComfyUI startup") + + def unpatch_hooks(self, whitelist_keys_set: set[str]=None) -> None: + pass + +CoreModelPatcher = ModelPatcher diff --git a/comfy/ops.py b/comfy/ops.py index 2c35b21d0..825f9ae82 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -24,6 +24,11 @@ import comfy.float import comfy.rmsnorm import json import comfy.memory_management +import comfy.pinned_memory +import comfy.utils + +import comfy_aimdo.model_vbar +import comfy_aimdo.torch def run_every_op(): if torch.compiler.is_compiling(): @@ -73,7 +78,108 @@ def cast_to_input(weight, input, non_blocking=False, copy=True): return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy) -def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, offloadable=False): +def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype): + offload_stream = None + xfer_dest = None + + signature = comfy_aimdo.model_vbar.vbar_fault(s._v) + if signature is not None: + xfer_dest = comfy_aimdo.torch.aimdo_to_tensor(s._v, device) + resident = comfy_aimdo.model_vbar.vbar_signature_compare(signature, s._v_signature) + + if not resident: + + xfer_source = [ s.weight, s.bias ] + + pin = comfy.pinned_memory.get_pin(s) + if pin is not None: + xfer_source = [ pin ] + resident = True #If pinned data exists, it always has LowVram already applied + + dest_size = comfy.memory_management.vram_aligned_size(xfer_source) + offload_stream = comfy.model_management.get_offload_stream(device) + if xfer_dest is None and offload_stream is not None: + xfer_dest = comfy.model_management.get_cast_buffer(offload_stream, device, dest_size, s) + if xfer_dest is None: + offload_stream = comfy.model_management.get_offload_stream(device) + xfer_dest = comfy.model_management.get_cast_buffer(offload_stream, device, dest_size, s) + if xfer_dest is None: + xfer_dest = torch.empty((dest_size,), dtype=torch.uint8, device=device) + offload_stream = None + + #send it over + comfy.model_management.cast_to_gathered(xfer_source, xfer_dest, non_blocking=non_blocking, stream=offload_stream) + comfy.model_management.sync_stream(device, offload_stream) + + pin = None + if signature is not None: + #If we are able to increase our load level (e.g. user reduces resolution or batch number) + #reclaim the pin previously used for offload. + comfy.pinned_memory.unpin_memory(s) + elif not resident: + #prepare a new pin + assert comfy.pinned_memory.get_pin(s) is None + comfy.pinned_memory.pin_memory(s) + pin = comfy.pinned_memory.get_pin(s) + + params = comfy.memory_management.interpret_gathered_like([s.weight, s.bias], xfer_dest) + weight = params[0] + bias = params[1] + + def post_cast(s, param_key, x, dtype, resident, update_weight): + lowvram_fn = getattr(s, param_key + "_lowvram_function", None) + hook_fn = getattr(s, param_key + "_hooks", None) + fns = getattr(s, param_key + "_function", []) + + orig = x + q_layout = None + + def to_dequant(tensor, dtype): + tensor = tensor.to(dtype=dtype) + if isinstance(tensor, QuantizedTensor): + tensor = tensor.dequantize() + return tensor + + if orig.dtype != dtype or len(fns) > 0: + x = to_dequant(x, dtype) + if not resident and lowvram_fn is not None: + x = to_dequant(x, dtype if compute_dtype is None else compute_dtype) + #FIXME: this is not accurate, we need to be sensitive to the compute dtype + x = lowvram_fn(x) + if (isinstance(orig, QuantizedTensor) and + (orig.dtype == dtype and len(fns) == 0 or update_weight)): + seed = comfy.utils.string_to_seed(s.seed_key) + y = QuantizedTensor.from_float(x, s.layout_type, scale="recalculate", stochastic_rounding=seed) + if orig.dtype == dtype and len(fns) == 0: + #The layer actually wants our freshly saved QT + x = y + else: + y = x + if update_weight: + orig.copy_(y) + for f in fns: + x = f(x) + return x + + update_weight = signature is not None or pin is not None + + weight = post_cast(s, "weight", weight, dtype, resident, update_weight) + if s.bias is not None: + bias = post_cast(s, "bias", bias, bias_dtype, resident, update_weight) + s._v_signature=signature + + if pin is not None: + xfer_dest = comfy.memory_management.interpret_gathered_like([ pin ], xfer_dest)[0] + if offload_stream is not None: + #FIXME: if post cast didnt do anything this sync is un-needed + offload_stream.wait_stream(comfy.model_management.current_stream(device)) + comfy.model_management.cast_to(xfer_dest, device=pin.device, non_blocking=non_blocking, stream=offload_stream, r=pin) + + #FIXME: weird offload return protocol + return weight, bias, (offload_stream, device if signature is not None else None, None) + + +def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, offloadable=False, compute_dtype=None): # NOTE: offloadable=False is a a legacy and if you are a custom node author reading this please pass # offloadable=True and call uncast_bias_weight() after your last usage of the weight/bias. This # will add async-offload support to your cast and improve performance. @@ -88,6 +194,11 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of if device is None: device = input.device + non_blocking = comfy.model_management.device_supports_non_blocking(device) + + if hasattr(s, "_v"): + return cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype) + if offloadable and (device != s.weight.device or (s.bias is not None and device != s.bias.device)): offload_stream = comfy.model_management.get_offload_stream(device) @@ -108,8 +219,6 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of weight = params[0] bias = params[1] - non_blocking = comfy.model_management.device_supports_non_blocking(device) - weight_has_function = len(s.weight_function) > 0 bias_has_function = len(s.bias_function) > 0 @@ -146,14 +255,20 @@ def uncast_bias_weight(s, weight, bias, offload_stream): if offload_stream is None: return os, weight_a, bias_a = offload_stream + device=None + #FIXME: This is not good RTTI + if not isinstance(weight_a, torch.Tensor): + comfy_aimdo.model_vbar.vbar_unpin(s._v) + device = weight_a if os is None: return - if weight_a is not None: - device = weight_a.device - else: - if bias_a is None: - return - device = bias_a.device + if device is None: + if weight_a is not None: + device = weight_a.device + else: + if bias_a is None: + return + device = bias_a.device os.wait_stream(comfy.model_management.current_stream(device)) @@ -668,8 +783,8 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec def _forward(self, input, weight, bias): return torch.nn.functional.linear(input, weight, bias) - def forward_comfy_cast_weights(self, input): - weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True) + def forward_comfy_cast_weights(self, input, compute_dtype=None): + weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True, compute_dtype=compute_dtype) x = self._forward(input, weight, bias) uncast_bias_weight(self, weight, bias, offload_stream) return x @@ -679,6 +794,8 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec input_shape = input.shape reshaped_3d = False + #If cast needs to apply lora, it should be done in the compute dtype + compute_dtype = input.dtype if (getattr(self, 'layout_type', None) is not None and not isinstance(input, QuantizedTensor) and not self._full_precision_mm and @@ -697,7 +814,8 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec scale = comfy.model_management.cast_to_device(scale, input.device, None) input = QuantizedTensor.from_float(input_reshaped, self.layout_type, scale=scale) - output = self.forward_comfy_cast_weights(input) + + output = self.forward_comfy_cast_weights(input, compute_dtype) # Reshape output back to 3D if input was 3D if reshaped_3d: From 3597b27515ee827190765596ab39a78e3f03f1ae Mon Sep 17 00:00:00 2001 From: Rattus Date: Tue, 13 Jan 2026 15:46:39 +1000 Subject: [PATCH 10/18] models: Use CoreModelPatcher Use CoreModelPatcher for all internal ModelPatcher implementations. This drives conditional use of the aimdo feature, while making sure custom node packs get to keep ModelPatcher unchanged for the moment. --- comfy/audio_encoders/audio_encoders.py | 4 +-- comfy/clip_vision.py | 4 +-- comfy/controlnet.py | 2 +- comfy/ldm/hunyuan_video/upsampler.py | 4 +-- comfy/model_base.py | 4 +-- comfy/sd.py | 39 +++++++++++++++----------- comfy_extras/nodes_model_patch.py | 6 ++-- 7 files changed, 34 insertions(+), 29 deletions(-) diff --git a/comfy/audio_encoders/audio_encoders.py b/comfy/audio_encoders/audio_encoders.py index 46ef21c95..16998af94 100644 --- a/comfy/audio_encoders/audio_encoders.py +++ b/comfy/audio_encoders/audio_encoders.py @@ -25,11 +25,11 @@ class AudioEncoderModel(): elif model_type == "whisper3": self.model = WhisperLargeV3(**model_config) self.model.eval() - self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device) + self.patcher = comfy.model_patcher.CoreModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device) self.model_sample_rate = 16000 def load_sd(self, sd): - return self.model.load_state_dict(sd, strict=False) + return self.model.load_state_dict(sd, strict=False, assign=self.patcher.is_dynamic()) def get_sd(self): return self.model.state_dict() diff --git a/comfy/clip_vision.py b/comfy/clip_vision.py index 66f2a9d9c..e4dd4e678 100644 --- a/comfy/clip_vision.py +++ b/comfy/clip_vision.py @@ -47,10 +47,10 @@ class ClipVisionModel(): self.model = model_class(config, self.dtype, offload_device, comfy.ops.manual_cast) self.model.eval() - self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device) + self.patcher = comfy.model_patcher.CoreModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device) def load_sd(self, sd): - return self.model.load_state_dict(sd, strict=False) + return self.model.load_state_dict(sd, strict=False, assign=self.patcher.is_dynamic()) def get_sd(self): return self.model.state_dict() diff --git a/comfy/controlnet.py b/comfy/controlnet.py index 0b5e30f52..9e1e704e0 100644 --- a/comfy/controlnet.py +++ b/comfy/controlnet.py @@ -203,7 +203,7 @@ class ControlNet(ControlBase): self.control_model = control_model self.load_device = load_device if control_model is not None: - self.control_model_wrapped = comfy.model_patcher.ModelPatcher(self.control_model, load_device=load_device, offload_device=comfy.model_management.unet_offload_device()) + self.control_model_wrapped = comfy.model_patcher.CoreModelPatcher(self.control_model, load_device=load_device, offload_device=comfy.model_management.unet_offload_device()) self.compression_ratio = compression_ratio self.global_average_pooling = global_average_pooling diff --git a/comfy/ldm/hunyuan_video/upsampler.py b/comfy/ldm/hunyuan_video/upsampler.py index 51b6d1da8..1f68144e2 100644 --- a/comfy/ldm/hunyuan_video/upsampler.py +++ b/comfy/ldm/hunyuan_video/upsampler.py @@ -109,10 +109,10 @@ class HunyuanVideo15SRModel(): self.model_class = UPSAMPLERS.get(model_type) self.model = self.model_class(**config).eval() - self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device) + self.patcher = comfy.model_patcher.CoreModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device) def load_sd(self, sd): - return self.model.load_state_dict(sd, strict=True) + return self.model.load_state_dict(sd, strict=True, assign=self.patcher.is_dynamic()) def get_sd(self): return self.model.state_dict() diff --git a/comfy/model_base.py b/comfy/model_base.py index b9e2abdb9..f4b037bb9 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -298,7 +298,7 @@ class BaseModel(torch.nn.Module): return out - def load_model_weights(self, sd, unet_prefix=""): + def load_model_weights(self, sd, unet_prefix="", assign=False): to_load = {} keys = list(sd.keys()) for k in keys: @@ -306,7 +306,7 @@ class BaseModel(torch.nn.Module): to_load[k[len(unet_prefix):]] = sd.pop(k) to_load = self.model_config.process_unet_state_dict(to_load) - m, u = self.diffusion_model.load_state_dict(to_load, strict=False) + m, u = self.diffusion_model.load_state_dict(to_load, strict=False, assign=assign) if len(m) > 0: logging.warning("unet missing: {}".format(m)) diff --git a/comfy/sd.py b/comfy/sd.py index f83164d23..ca268cbbe 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -128,7 +128,7 @@ class CLIP: logging.warning("Had to shift TE back.") self.tokenizer = tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data) - self.patcher = comfy.model_patcher.ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device) + self.patcher = comfy.model_patcher.CoreModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device) #Match torch.float32 hardcode upcast in TE implemention self.patcher.set_model_compute_dtype(torch.float32) self.patcher.hook_mode = comfy.hooks.EnumHookMode.MinVram @@ -288,7 +288,7 @@ class CLIP: def load_sd(self, sd, full_model=False): if full_model: - return self.cond_stage_model.load_state_dict(sd, strict=False) + return self.cond_stage_model.load_state_dict(sd, strict=False, assign=self.patcher.is_dynamic()) else: return self.cond_stage_model.load_sd(sd) @@ -665,13 +665,6 @@ class VAE: self.first_stage_model = AutoencoderKL(**(config['params'])) self.first_stage_model = self.first_stage_model.eval() - m, u = self.first_stage_model.load_state_dict(sd, strict=False) - if len(m) > 0: - logging.warning("Missing VAE keys {}".format(m)) - - if len(u) > 0: - logging.debug("Leftover VAE keys {}".format(u)) - if device is None: device = model_management.vae_device() self.device = device @@ -682,7 +675,18 @@ class VAE: self.first_stage_model.to(self.vae_dtype) self.output_device = model_management.intermediate_device() - self.patcher = comfy.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device) + mp = comfy.model_patcher.CoreModelPatcher + if self.disable_offload: + mp = comfy.model_patcher.ModelPatcher + self.patcher = mp(self.first_stage_model, load_device=self.device, offload_device=offload_device) + + m, u = self.first_stage_model.load_state_dict(sd, strict=False, assign=self.patcher.is_dynamic()) + if len(m) > 0: + logging.warning("Missing VAE keys {}".format(m)) + + if len(u) > 0: + logging.debug("Leftover VAE keys {}".format(u)) + logging.info("VAE load device: {}, offload device: {}, dtype: {}".format(self.device, offload_device, self.vae_dtype)) self.model_size() @@ -1315,7 +1319,7 @@ def load_gligen(ckpt_path): model = gligen.load_gligen(data) if model_management.should_use_fp16(): model = model.half() - return comfy.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device()) + return comfy.model_patcher.CoreModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device()) def model_detection_error_hint(path, state_dict): filename = os.path.basename(path) @@ -1403,7 +1407,8 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c if output_model: inital_load_device = model_management.unet_inital_load_device(parameters, unet_dtype) model = model_config.get_model(sd, diffusion_model_prefix, device=inital_load_device) - model.load_model_weights(sd, diffusion_model_prefix) + model_patcher = comfy.model_patcher.CoreModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device()) + model.load_model_weights(sd, diffusion_model_prefix, assign=model_patcher.is_dynamic()) if output_vae: vae_sd = comfy.utils.state_dict_prefix_replace(sd, {k: "" for k in model_config.vae_key_prefix}, filter_keys=True) @@ -1446,7 +1451,6 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c logging.debug("left over keys: {}".format(left_over)) if output_model: - model_patcher = comfy.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device()) if inital_load_device != torch.device("cpu"): logging.info("loaded diffusion model directly to GPU") model_management.load_models_gpu([model_patcher], force_full_load=True) @@ -1538,13 +1542,14 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None): model_config.optimizations["fp8"] = True model = model_config.get_model(new_sd, "") - model = model.to(offload_device) - model.load_model_weights(new_sd, "") + model_patcher = comfy.model_patcher.CoreModelPatcher(model, load_device=load_device, offload_device=offload_device) + if not model_management.is_device_cpu(offload_device): + model.to(offload_device) + model.load_model_weights(new_sd, "", assign=model_patcher.is_dynamic()) left_over = sd.keys() if len(left_over) > 0: logging.info("left over keys in diffusion model: {}".format(left_over)) - return comfy.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=offload_device) - + return model_patcher def load_diffusion_model(unet_path, model_options={}): sd, metadata = comfy.utils.load_torch_file(unet_path, return_metadata=True) diff --git a/comfy_extras/nodes_model_patch.py b/comfy_extras/nodes_model_patch.py index 1355b3c93..a3303ea1a 100644 --- a/comfy_extras/nodes_model_patch.py +++ b/comfy_extras/nodes_model_patch.py @@ -254,9 +254,9 @@ class ModelPatchLoader: config['broken'] = True model = comfy.ldm.lumina.controlnet.ZImage_Control(device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast, **config) - model.load_state_dict(sd) - model = comfy.model_patcher.ModelPatcher(model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device()) - return (model,) + model_patcher = comfy.model_patcher.CoreModelPatcher(model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device()) + model.load_state_dict(sd, assign=self.model_patcher.is_dynamic()) + return (model_patcher,) class DiffSynthCnetPatch: From 64c2541b0545de224e3dda660d41467be02bc4a2 Mon Sep 17 00:00:00 2001 From: Rattus Date: Tue, 13 Jan 2026 15:49:07 +1000 Subject: [PATCH 11/18] execution: add aimdo primary pytorch cache integration We need to general pytorch cache defragmentation on an appropriate level for aimdo. Do in here on the per node basis, which has a reasonable chance of purging stale shapes out of the pytorch caching allocator and saving VRAM without costing too much garbage collector thrash. This looks like a lot of GC but because aimdo never fails from pytorch and saves the pytorch allocator from ever need to defrag out of demand, but it needs a oil change every now and then so we gotta do it. Doing it here also means the pytorch temps are cleared from task manager VRAM usage so user anxiety can go down a little when they see their vram drop back at the end of workflows inline with inference usage (rather than assuming full VRAM leaks). --- comfy/memory_management.py | 6 ++++++ execution.py | 20 +++++++++++++++++++- 2 files changed, 25 insertions(+), 1 deletion(-) diff --git a/comfy/memory_management.py b/comfy/memory_management.py index f8bca5263..88b6da1e3 100644 --- a/comfy/memory_management.py +++ b/comfy/memory_management.py @@ -1,6 +1,10 @@ import torch from comfy.quant_ops import QuantizedTensor +import comfy_aimdo.torch + +import logging + def vram_aligned_size(tensor): if isinstance(tensor, list): return sum([vram_aligned_size(t) for t in tensor]) @@ -49,3 +53,5 @@ def interpret_gathered_like(tensors, gathered): dest_views.append(actuals["data"]) return dest_views + +aimdo_allocator = comfy_aimdo.torch.CUDAPluggableAllocator() diff --git a/execution.py b/execution.py index 648f204ec..fe162db26 100644 --- a/execution.py +++ b/execution.py @@ -1,3 +1,4 @@ +import gc import copy import heapq import inspect @@ -9,9 +10,12 @@ import traceback from enum import Enum from typing import List, Literal, NamedTuple, Optional, Union import asyncio +from contextlib import nullcontext import torch +import comfy.pinned_memory +import comfy.memory_management import comfy.model_management from latent_preview import set_preview_method import nodes @@ -515,7 +519,21 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, def pre_execute_cb(call_index): # TODO - How to handle this with async functions without contextvars (which requires Python 3.12)? GraphBuilder.set_default_prefix(unique_id, call_index, 0) - output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data) + + #Do comfy_aimdo mempool chunking here on the per-node level. Multi-model workflows + #will cause all sorts of incompatible memory shapes to fragment the pytorch alloc + #that we just want to cull out each model run. + allocator = comfy.memory_management.aimdo_allocator + with nullcontext() if allocator is None else torch.cuda.use_mem_pool(torch.cuda.MemPool(allocator.allocator())): + output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data) + torch.cuda.synchronize() + if allocator is not None: + #FIXME: this is probably a little zealous + # Torch code comments says some stuff about not actually freeing tensors on mempool + #context release. Explicitly garbage collect now. + gc.collect() + torch.cuda.empty_cache() + if has_pending_tasks: pending_async_nodes[unique_id] = output_data unblock = execution_list.add_external_block(unique_id) From 389c3346317b0128b47f9ca1dfe8f2f68884f76a Mon Sep 17 00:00:00 2001 From: Rattus Date: Tue, 13 Jan 2026 15:54:49 +1000 Subject: [PATCH 12/18] main: Go live with --fast dynamic_vram Add the optional command line switch --fast dynamic_vram. This is mutually exclusing --high-vram and --gpu-only which contradict aimdos underlying feature. Add appropriate installation warning and a startup message, match the comfy debug level inconfiguring aimdo. Add comfy-aimdo pip requirement. This will safely stub to a nop for unsupported platforms. --- comfy/cli_args.py | 4 ++++ cuda_malloc.py | 7 ++++++- main.py | 35 ++++++++++++++++++++++++++++++++++- requirements.txt | 1 + 4 files changed, 45 insertions(+), 2 deletions(-) diff --git a/comfy/cli_args.py b/comfy/cli_args.py index 1716c3de7..63daca861 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -159,6 +159,7 @@ class PerformanceFeature(enum.Enum): Fp8MatrixMultiplication = "fp8_matrix_mult" CublasOps = "cublas_ops" AutoTune = "autotune" + DynamicVRAM = "dynamic_vram" parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. This is used to test new features so using it might crash your comfyui. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: {}".format(" ".join(map(lambda c: c.value, PerformanceFeature)))) @@ -257,3 +258,6 @@ elif args.fast == []: # '--fast' is provided with a list of performance features, use that list else: args.fast = set(args.fast) + +def enables_dynamic_vram(): + return PerformanceFeature.DynamicVRAM in args.fast and not args.highvram and not args.gpu_only diff --git a/cuda_malloc.py b/cuda_malloc.py index 00ee7b633..3c7c8593e 100644 --- a/cuda_malloc.py +++ b/cuda_malloc.py @@ -1,8 +1,10 @@ import os import importlib.util -from comfy.cli_args import args, PerformanceFeature +from comfy.cli_args import args, PerformanceFeature, enables_dynamic_vram import subprocess +import comfy_aimdo.control + #Can't use pytorch to get the GPU names because the cuda malloc has to be set before the first import. def get_gpu_names(): if os.name == 'nt': @@ -85,6 +87,9 @@ if not args.cuda_malloc: except: pass +if enables_dynamic_vram() and comfy_aimdo.control.lib is not None: + args.cuda_malloc = False + os.environ['PYTORCH_CUDA_ALLOC_CONF'] = "" if args.disable_cuda_malloc: args.cuda_malloc = False diff --git a/main.py b/main.py index 37b06c1fa..52f11bfff 100644 --- a/main.py +++ b/main.py @@ -5,7 +5,7 @@ import os import importlib.util import folder_paths import time -from comfy.cli_args import args +from comfy.cli_args import args, enables_dynamic_vram from app.logger import setup_logger from app.assets.scanner import seed_assets import itertools @@ -173,6 +173,30 @@ import gc if 'torch' in sys.modules: logging.warning("WARNING: Potential Error in code: Torch already imported, torch should never be imported before this point.") + +has_aimdo = False + +import comfy_aimdo.control + +if comfy_aimdo.control.lib is not None: + if args.verbose == 'DEBUG': + comfy_aimdo.control.set_log_debug() + elif args.verbose == 'CRITICAL': + comfy_aimdo.control.set_log_critical() + elif args.verbose == 'ERROR': + comfy_aimdo.control.set_log_error() + elif args.verbose == 'WARNING': + comfy_aimdo.control.set_log_warning() + else: #INFO + comfy_aimdo.control.set_log_info() + + if enables_dynamic_vram(): + logging.info("DynamicVRAM support detected and enabled") + has_aimdo = True +else: + if enables_dynamic_vram(): + logging.info("No native comfy-aimdo install detected. Falling back to legacy ModelPatcher. VRAM estimates may be unreliable especially on Windows") + import comfy.utils import execution @@ -184,6 +208,15 @@ import comfyui_version import app.logger import hook_breaker_ac10a0 +import comfy.memory_management +import comfy.model_patcher + +if has_aimdo: + comfy.model_patcher.CoreModelPatcher = comfy.model_patcher.ModelPatcherDynamic + comfy_aimdo.control.init_vram_guard(comfy.model_management.get_torch_device().index) +else: + comfy.memory_management.aimdo_allocator = None + def cuda_malloc_warning(): device = comfy.model_management.get_torch_device() device_name = comfy.model_management.get_torch_device_name(device) diff --git a/requirements.txt b/requirements.txt index 8650d28ec..950f51c15 100644 --- a/requirements.txt +++ b/requirements.txt @@ -22,6 +22,7 @@ alembic SQLAlchemy av>=14.2.0 comfy-kitchen>=0.2.6 +comfy-aimdo>=0.1.0 #non essential dependencies: kornia>=0.7.1 From e2b440b25e7f1eeff4ab14afe90315ec0cb91613 Mon Sep 17 00:00:00 2001 From: Rattus Date: Tue, 13 Jan 2026 19:37:46 +1000 Subject: [PATCH 13/18] mm: fix sync Sync before deleting anything. --- comfy/model_management.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index c5a22e04c..af59592ea 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -1099,9 +1099,9 @@ def get_cast_buffer(offload_stream, device, size, ref): return None if cast_buffer is not None and cast_buffer.numel() > 50 * (1024 ** 2): #I want my wrongly sized 50MB+ of VRAM back from the caching allocator right now + torch.cuda.synchronize() del STREAM_CAST_BUFFERS[offload_stream] del cast_buffer - torch.cuda.synchronize() torch.cuda.empty_cache() with wf_context: cast_buffer = torch.empty((size), dtype=torch.int8, device=device) @@ -1115,8 +1115,8 @@ def get_cast_buffer(offload_stream, device, size, ref): def reset_cast_buffers(): global LARGEST_CASTED_WEIGHT LARGEST_CASTED_WEIGHT = (None, 0) - STREAM_CAST_BUFFERS.clear() torch.cuda.synchronize() + STREAM_CAST_BUFFERS.clear() torch.cuda.empty_cache() def get_offload_stream(device): From e2d29926a903540dcff59ac7173dd32cf22151cd Mon Sep 17 00:00:00 2001 From: Rattus Date: Tue, 13 Jan 2026 19:38:36 +1000 Subject: [PATCH 14/18] write better tx commentary --- comfy/ops.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/comfy/ops.py b/comfy/ops.py index 825f9ae82..24d275216 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -170,8 +170,11 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu if pin is not None: xfer_dest = comfy.memory_management.interpret_gathered_like([ pin ], xfer_dest)[0] + #FIXME: This might be the wrong thing to do. Some reading suggests the DMA engine + #is posted writes and the compute stream could just fire and forget here. That + #would save this sync and some stalling on the offload stream that is better off + #running ahead to the next layer to read. if offload_stream is not None: - #FIXME: if post cast didnt do anything this sync is un-needed offload_stream.wait_stream(comfy.model_management.current_stream(device)) comfy.model_management.cast_to(xfer_dest, device=pin.device, non_blocking=non_blocking, stream=offload_stream, r=pin) From 2c554a112eabb361ed2e410579d478335ae3e3ba Mon Sep 17 00:00:00 2001 From: Rattus Date: Tue, 13 Jan 2026 19:39:00 +1000 Subject: [PATCH 15/18] add missing del on unpin --- comfy/pinned_memory.py | 1 + 1 file changed, 1 insertion(+) diff --git a/comfy/pinned_memory.py b/comfy/pinned_memory.py index be303b4f1..923872dac 100644 --- a/comfy/pinned_memory.py +++ b/comfy/pinned_memory.py @@ -31,4 +31,5 @@ def unpin_memory(module): return 0 size = module._pin.numel() * module._pin.element_size() comfy.model_management.unpin_memory(module._pin) + del module._pin return size From 94a709d813faaa86bbdadbce1bd3b0888d89fe1e Mon Sep 17 00:00:00 2001 From: Rattus Date: Tue, 13 Jan 2026 19:40:52 +1000 Subject: [PATCH 16/18] misc cleanup --- comfy/pinned_memory.py | 12 ++++-------- execution.py | 1 - requirements.txt | 2 +- 3 files changed, 5 insertions(+), 10 deletions(-) diff --git a/comfy/pinned_memory.py b/comfy/pinned_memory.py index 923872dac..dfb5fcfcd 100644 --- a/comfy/pinned_memory.py +++ b/comfy/pinned_memory.py @@ -14,14 +14,10 @@ def pin_memory(module): #FIXME: This is a RAM cache trigger event params = [ module.weight, module.bias ] size = comfy.memory_management.vram_aligned_size(params) - try: - pin = torch.empty((size,), dtype=torch.uint8) - if comfy.model_management.pin_memory(pin): - module._pin = pin - else: - module.pin_failed = True - return False - except: + pin = torch.empty((size,), dtype=torch.uint8) + if comfy.model_management.pin_memory(pin): + module._pin = pin + else: module.pin_failed = True return False return True diff --git a/execution.py b/execution.py index fe162db26..197e13f38 100644 --- a/execution.py +++ b/execution.py @@ -14,7 +14,6 @@ from contextlib import nullcontext import torch -import comfy.pinned_memory import comfy.memory_management import comfy.model_management from latent_preview import set_preview_method diff --git a/requirements.txt b/requirements.txt index 950f51c15..2574f415b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -22,7 +22,7 @@ alembic SQLAlchemy av>=14.2.0 comfy-kitchen>=0.2.6 -comfy-aimdo>=0.1.0 +comfy-aimdo>=0.1.1 #non essential dependencies: kornia>=0.7.1 From 013f132085e915173135efea96aa94923956a8b3 Mon Sep 17 00:00:00 2001 From: Rattus Date: Tue, 13 Jan 2026 20:29:13 +1000 Subject: [PATCH 17/18] ruff --- comfy/memory_management.py | 3 --- comfy/model_patcher.py | 5 ++--- comfy/ops.py | 4 +--- comfy/pinned_memory.py | 1 - comfy/samplers.py | 1 - 5 files changed, 3 insertions(+), 11 deletions(-) diff --git a/comfy/memory_management.py b/comfy/memory_management.py index 88b6da1e3..4169e853c 100644 --- a/comfy/memory_management.py +++ b/comfy/memory_management.py @@ -1,10 +1,7 @@ -import torch from comfy.quant_ops import QuantizedTensor import comfy_aimdo.torch -import logging - def vram_aligned_size(tensor): if isinstance(tensor, list): return sum([vram_aligned_size(t) for t in tensor]) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 77e7eec90..b30a9c63d 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -24,7 +24,6 @@ import inspect import logging import math import uuid -import types from typing import Callable, Optional import torch @@ -1381,7 +1380,7 @@ class ModelPatcher: unet_state_dict = self.model.diffusion_model.state_dict() for k, v in unet_state_dict.items(): op_keys = k.rsplit('.', 1) - if (len(op_keys) < 2) or not op_keys[1] in ["weight", "bias"]: + if (len(op_keys) < 2) or op_keys[1] not in ["weight", "bias"]: continue try: op = comfy.utils.get_attr(self.model.diffusion_model, op_keys[0]) @@ -1467,7 +1466,7 @@ class ModelPatcherDynamic(ModelPatcher): #Full load doesn't make sense as we dont actually have any loader capability here and #now. - assert not full_load; + assert not full_load assert device_to == self.load_device diff --git a/comfy/ops.py b/comfy/ops.py index 24d275216..1057cda6b 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -128,11 +128,9 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu def post_cast(s, param_key, x, dtype, resident, update_weight): lowvram_fn = getattr(s, param_key + "_lowvram_function", None) - hook_fn = getattr(s, param_key + "_hooks", None) fns = getattr(s, param_key + "_function", []) orig = x - q_layout = None def to_dequant(tensor, dtype): tensor = tensor.to(dtype=dtype) @@ -218,7 +216,7 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of if cast_buffer is None: offload_stream = comfy.model_management.get_offload_stream(device) cast_buffer = comfy.model_management.get_cast_buffer(offload_stream, device, cast_buffer_size, s) - params = interpret_gathered_like([ s.weight, s.bias ], cast_buffer) + params = comfy.memory_management.interpret_gathered_like([ s.weight, s.bias ], cast_buffer) weight = params[0] bias = params[1] diff --git a/comfy/pinned_memory.py b/comfy/pinned_memory.py index dfb5fcfcd..650e27a10 100644 --- a/comfy/pinned_memory.py +++ b/comfy/pinned_memory.py @@ -1,5 +1,4 @@ import torch -import logging import comfy.model_management import comfy.memory_management diff --git a/comfy/samplers.py b/comfy/samplers.py index d495ca203..8b9782956 100755 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -9,7 +9,6 @@ if TYPE_CHECKING: import torch from functools import partial import collections -from comfy import model_management import math import logging import comfy.sampler_helpers From 0eff43261b94887aca3bdfac0b061136eaf15ca6 Mon Sep 17 00:00:00 2001 From: Rattus Date: Tue, 13 Jan 2026 21:13:35 +1000 Subject: [PATCH 18/18] sd: empty cache on tiler fallback This is needed for aimdo where the cache cant self recover from fragmentation. It is however a good thing to do anyway after an OOM so make it unconditional. --- comfy/sd.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/comfy/sd.py b/comfy/sd.py index ca268cbbe..12f1373c5 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -820,6 +820,7 @@ class VAE: do_tile = True if do_tile: + torch.cuda.empty_cache() dims = samples_in.ndim - 2 if dims == 1 or self.extra_1d_channel is not None: pixel_samples = self.decode_tiled_1d(samples_in) @@ -895,6 +896,7 @@ class VAE: do_tile = True if do_tile: + torch.cuda.empty_cache() if self.latent_dim == 3: tile = 256 overlap = tile // 4