From 9cfd71e8212fc17a03be3e73f268ce3e25b3305b Mon Sep 17 00:00:00 2001 From: Rattus Date: Mon, 16 Feb 2026 18:49:56 +1000 Subject: [PATCH 01/10] mm: Use Aimdo raw allocator for cast buffers pytorch manages allocation of growing buffers on streams poorly. Pyt has no windows support for the expandable segments allocator (which is the right tool for this job), while also segmenting the memory by stream such that it can be generally re-used. So kick the problem to aimdo which can just grow a virtual region thats freed per stream. --- comfy/model_management.py | 22 ++++++++++++++++++++-- comfy/ops.py | 9 ++++++--- 2 files changed, 26 insertions(+), 5 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 3b39d6080..afda2f086 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -31,6 +31,7 @@ from contextlib import nullcontext import comfy.memory_management import comfy.utils import comfy.quant_ops +import comfy_aimdo.vram_buffer class VRAMState(Enum): DISABLED = 0 #No vram present: no need to move models to vram @@ -1181,6 +1182,10 @@ stream_counters = {} STREAM_CAST_BUFFERS = {} LARGEST_CASTED_WEIGHT = (None, 0) +STREAM_AIMDO_CAST_BUFFERS = {} +LARGEST_AIMDO_CASTED_WEIGHT = (None, 0) + +DEFAULT_AIMDO_CAST_BUFFER_RESERVATION_SIZE = 16 * 1024 ** 3 def get_cast_buffer(offload_stream, device, size, ref): global LARGEST_CASTED_WEIGHT @@ -1214,13 +1219,26 @@ def get_cast_buffer(offload_stream, device, size, ref): return cast_buffer +def get_aimdo_cast_buffer(offload_stream, device): + cast_buffer = STREAM_AIMDO_CAST_BUFFERS.get(offload_stream, None) + if cast_buffer is None: + cast_buffer = comfy_aimdo.vram_buffer.VRAMBuffer(DEFAULT_AIMDO_CAST_BUFFER_RESERVATION_SIZE, device.index) + STREAM_AIMDO_CAST_BUFFERS[offload_stream] = cast_buffer + + return cast_buffer + def reset_cast_buffers(): global LARGEST_CASTED_WEIGHT + global LARGEST_AIMDO_CASTED_WEIGHT + LARGEST_CASTED_WEIGHT = (None, 0) - for offload_stream in STREAM_CAST_BUFFERS: - offload_stream.synchronize() + LARGEST_AIMDO_CASTED_WEIGHT = (None, 0) + for offload_stream in set(STREAM_CAST_BUFFERS) | set(STREAM_AIMDO_CAST_BUFFERS): + if offload_stream is not None: + offload_stream.synchronize() synchronize() STREAM_CAST_BUFFERS.clear() + STREAM_AIMDO_CAST_BUFFERS.clear() soft_empty_cache() def get_offload_stream(device): diff --git a/comfy/ops.py b/comfy/ops.py index 050f7cda0..febde458d 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -138,10 +138,13 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu dest_size = comfy.memory_management.vram_aligned_size(xfer_source) offload_stream = comfy.model_management.get_offload_stream(device) if xfer_dest is None and offload_stream is not None: - xfer_dest = comfy.model_management.get_cast_buffer(offload_stream, device, dest_size, s) - if xfer_dest is None: + cast_buffer = comfy.model_management.get_aimdo_cast_buffer(offload_stream, device) + if cast_buffer.size() < dest_size and s is comfy.model_management.LARGEST_AIMDO_CASTED_WEIGHT[0]: offload_stream = comfy.model_management.get_offload_stream(device) - xfer_dest = comfy.model_management.get_cast_buffer(offload_stream, device, dest_size, s) + cast_buffer = comfy.model_management.get_aimdo_cast_buffer(offload_stream, device) + xfer_dest = comfy_aimdo.torch.aimdo_to_tensor(cast_buffer.get(dest_size), device) + if dest_size > comfy.model_management.LARGEST_AIMDO_CASTED_WEIGHT[1]: + comfy.model_management.LARGEST_AIMDO_CASTED_WEIGHT = (s, dest_size) if xfer_dest is None: xfer_dest = torch.empty((dest_size,), dtype=torch.uint8, device=device) offload_stream = None From 312f08e44f0cc5f349096041a21a3466bb000aca Mon Sep 17 00:00:00 2001 From: Rattus Date: Fri, 24 Apr 2026 17:49:08 +1000 Subject: [PATCH 02/10] plan --- comfy/ops.py | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/comfy/ops.py b/comfy/ops.py index febde458d..cd1f24c78 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -102,6 +102,10 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu bias = s.bias.to(dtype=bias_dtype, copy=True) return weight, bias, (None, None, None) + #plan = [] + #Some sort of loop here like what you did + #for module in comfy_modules: ... + offload_stream = None xfer_dest = None @@ -160,8 +164,13 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu xfer_source = [ pin ] #send it over comfy.model_management.cast_to_gathered(xfer_source, xfer_dest, non_blocking=non_blocking, stream=offload_stream) + #attach prefetch info to the module inside the loop .. + + #this sync is conceptually the last thing this function does - after the loop comfy.model_management.sync_stream(device, offload_stream) + + #all compute stuff need to be deferred to the new second phase if cast_dest is not None: for pre_cast, post_cast in zip(comfy.memory_management.interpret_gathered_like([s.weight, s.bias ], xfer_dest), comfy.memory_management.interpret_gathered_like(cast_geometry, cast_dest)): @@ -177,6 +186,7 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu s._v_bias = bias s._v_signature=signature + #factor this our like you did before. def post_cast(s, param_key, x, dtype, resident, update_weight): lowvram_fn = getattr(s, param_key + "_lowvram_function", None) fns = getattr(s, param_key + "_function", []) @@ -236,7 +246,14 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of non_blocking = comfy.model_management.device_supports_non_blocking(device) if hasattr(s, "_v"): - return cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requant) + + #that cpu guard goes here. + + #check for a prefetch result here. Something like: + #if not prefetch: + #cast_modules([s], ...) + #this is the phase 2 call like you made before ... + return phase_2(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requant) if offloadable and (device != s.weight.device or (s.bias is not None and device != s.bias.device)): From 132c9f3ac66b227ae3346254c71b270f0b76f9f4 Mon Sep 17 00:00:00 2001 From: Rattus Date: Fri, 24 Apr 2026 19:24:13 +1000 Subject: [PATCH 03/10] ops: move cpu handler up to the caller --- comfy/ops.py | 28 ++++++++++++---------------- 1 file changed, 12 insertions(+), 16 deletions(-) diff --git a/comfy/ops.py b/comfy/ops.py index cd1f24c78..6a4189791 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -87,21 +87,6 @@ def materialize_meta_param(s, param_keys): def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requant): - #vbar doesn't support CPU weights, but some custom nodes have weird paths - #that might switch the layer to the CPU and expect it to work. We have to take - #a clone conservatively as we are mmapped and some SFT files are packed misaligned - #If you are a custom node author reading this, please move your layer to the GPU - #or declare your ModelPatcher as CPU in the first place. - if comfy.model_management.is_device_cpu(device): - materialize_meta_param(s, ["weight", "bias"]) - weight = s.weight.to(dtype=dtype, copy=True) - if isinstance(weight, QuantizedTensor): - weight = weight.dequantize() - bias = None - if s.bias is not None: - bias = s.bias.to(dtype=bias_dtype, copy=True) - return weight, bias, (None, None, None) - #plan = [] #Some sort of loop here like what you did #for module in comfy_modules: ... @@ -247,7 +232,18 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of if hasattr(s, "_v"): - #that cpu guard goes here. + #vbar doesn't support CPU weights, but some custom nodes have weird paths + #that might switch the layer to the CPU and expect it to work. We have to take + #a clone conservatively as we are mmapped and some SFT files are packed misaligned + #If you are a custom node author reading this, please move your layer to the GPU + #or declare your ModelPatcher as CPU in the first place. + if comfy.model_management.is_device_cpu(device): + materialize_meta_param(s, ["weight", "bias"]) + weight = s.weight.to(dtype=dtype, copy=True) + if isinstance(weight, QuantizedTensor): + weight = weight.dequantize() + bias = s.bias.to(dtype=bias_dtype, copy=True) if s.bias is not None else None + return (weight, bias, (None, None, None)) if offloadable else (weight, bias) #check for a prefetch result here. Something like: #if not prefetch: From 0e93c88c670a0bd5b9efe18e8e8525ae9546f31f Mon Sep 17 00:00:00 2001 From: Rattus Date: Mon, 27 Apr 2026 22:52:07 +1000 Subject: [PATCH 04/10] ops: split up prefetch from weight prep block prefetching API Split up the casting and weight formating/lora stuff in prep for arbitrary prefetch support. --- comfy/ops.py | 129 ++++++++++++++++++++++++++++++++------------------- 1 file changed, 80 insertions(+), 49 deletions(-) diff --git a/comfy/ops.py b/comfy/ops.py index 6a4189791..6baad439f 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -86,27 +86,29 @@ def materialize_meta_param(s, param_keys): setattr(s, param_key, torch.nn.Parameter(torch.zeros(param.shape, dtype=param.dtype), requires_grad=param.requires_grad)) -def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requant): - #plan = [] - #Some sort of loop here like what you did - #for module in comfy_modules: ... - +# FIXME: add n=1 cache hit fast path +def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blocking): offload_stream = None - xfer_dest = None + cast_buffer = None + cast_buffer_offset = 0 + + for s in comfy_modules: + signature = comfy_aimdo.model_vbar.vbar_fault(s._v) + resident = comfy_aimdo.model_vbar.vbar_signature_compare(signature, s._v_signature) + prefetch = { + "signature": signature, + "resident": resident, + } - signature = comfy_aimdo.model_vbar.vbar_fault(s._v) - resident = comfy_aimdo.model_vbar.vbar_signature_compare(signature, s._v_signature) - if signature is not None: if resident: - weight = s._v_weight - bias = s._v_bias - else: - xfer_dest = comfy_aimdo.torch.aimdo_to_tensor(s._v, device) + s._prefetch = prefetch + continue - if not resident: materialize_meta_param(s, ["weight", "bias"]) + xfer_dest = comfy_aimdo.torch.aimdo_to_tensor(s._v, device) if signature is not None else None cast_geometry = comfy.memory_management.tensors_to_geometries([ s.weight, s.bias ]) cast_dest = None + needs_cast = False xfer_source = [ s.weight, s.bias ] @@ -118,25 +120,29 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu if data is None: continue if data.dtype != geometry.dtype: + needs_cast = True cast_dest = xfer_dest - if cast_dest is None: - cast_dest = torch.empty((comfy.memory_management.vram_aligned_size(cast_geometry),), dtype=torch.uint8, device=device) xfer_dest = None break dest_size = comfy.memory_management.vram_aligned_size(xfer_source) - offload_stream = comfy.model_management.get_offload_stream(device) - if xfer_dest is None and offload_stream is not None: - cast_buffer = comfy.model_management.get_aimdo_cast_buffer(offload_stream, device) + if offload_stream is None: + offload_stream = comfy.model_management.get_offload_stream(device) + if xfer_dest is None and offload_stream is not None and cast_buffer is None: + cast_buffer = comfy.model_management.get_aimdo_cast_buffer(offload_stream, device) + if len(comfy_modules) == 1: if cast_buffer.size() < dest_size and s is comfy.model_management.LARGEST_AIMDO_CASTED_WEIGHT[0]: offload_stream = comfy.model_management.get_offload_stream(device) cast_buffer = comfy.model_management.get_aimdo_cast_buffer(offload_stream, device) - xfer_dest = comfy_aimdo.torch.aimdo_to_tensor(cast_buffer.get(dest_size), device) if dest_size > comfy.model_management.LARGEST_AIMDO_CASTED_WEIGHT[1]: comfy.model_management.LARGEST_AIMDO_CASTED_WEIGHT = (s, dest_size) if xfer_dest is None: - xfer_dest = torch.empty((dest_size,), dtype=torch.uint8, device=device) - offload_stream = None + if cast_buffer is not None: + xfer_dest = comfy_aimdo.torch.aimdo_to_tensor(cast_buffer.get(dest_size, cast_buffer_offset), device) + cast_buffer_offset += dest_size + else: + xfer_dest = torch.empty((dest_size,), dtype=torch.uint8, device=device) + offload_stream = None if signature is None and pin is None: comfy.pinned_memory.pin_memory(s) @@ -149,29 +155,45 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu xfer_source = [ pin ] #send it over comfy.model_management.cast_to_gathered(xfer_source, xfer_dest, non_blocking=non_blocking, stream=offload_stream) - #attach prefetch info to the module inside the loop .. + prefetch["xfer_dest"] = xfer_dest + prefetch["cast_dest"] = cast_dest + prefetch["cast_geometry"] = cast_geometry + prefetch["needs_cast"] = needs_cast + s._prefetch = prefetch - #this sync is conceptually the last thing this function does - after the loop - comfy.model_management.sync_stream(device, offload_stream) + return offload_stream - #all compute stuff need to be deferred to the new second phase - if cast_dest is not None: + + +def phase_2(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requant): + del non_blocking + + prefetch = getattr(s, "_prefetch", None) + if prefetch is None: + raise RuntimeError("phase_2 called without a VBAR prefetch state") + + if prefetch["resident"]: + weight = s._v_weight + bias = s._v_bias + else: + xfer_dest = prefetch["xfer_dest"] + if prefetch["needs_cast"]: + cast_dest = prefetch["cast_dest"] if prefetch["cast_dest"] is not None else torch.empty((comfy.memory_management.vram_aligned_size(prefetch["cast_geometry"]),), dtype=torch.uint8, device=device) for pre_cast, post_cast in zip(comfy.memory_management.interpret_gathered_like([s.weight, s.bias ], xfer_dest), - comfy.memory_management.interpret_gathered_like(cast_geometry, cast_dest)): + comfy.memory_management.interpret_gathered_like(prefetch["cast_geometry"], cast_dest)): if post_cast is not None: post_cast.copy_(pre_cast) xfer_dest = cast_dest - params = comfy.memory_management.interpret_gathered_like(cast_geometry, xfer_dest) + params = comfy.memory_management.interpret_gathered_like(prefetch["cast_geometry"], xfer_dest) weight = params[0] bias = params[1] - if signature is not None: + if prefetch["signature"] is not None: s._v_weight = weight s._v_bias = bias - s._v_signature=signature + s._v_signature = prefetch["signature"] - #factor this our like you did before. def post_cast(s, param_key, x, dtype, resident, update_weight): lowvram_fn = getattr(s, param_key + "_lowvram_function", None) fns = getattr(s, param_key + "_function", []) @@ -203,14 +225,13 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu x = f(x) return x - update_weight = signature is not None - - weight = post_cast(s, "weight", weight, dtype, resident, update_weight) + update_weight = prefetch["signature"] is not None + weight = post_cast(s, "weight", weight, dtype, prefetch["resident"], update_weight) + bias = None if s.bias is not None: - bias = post_cast(s, "bias", bias, bias_dtype, resident, update_weight) + bias = post_cast(s, "bias", bias, bias_dtype, prefetch["resident"], update_weight) - #FIXME: weird offload return protocol - return weight, bias, (offload_stream, device if signature is not None else None, None) + return weight, bias def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, offloadable=False, compute_dtype=None, want_requant=False): @@ -228,6 +249,10 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of if device is None: device = input.device + def format_return(result, offloadable): + weight, bias, offload_stream = result + return (weight, bias, offload_stream) if offloadable else (weight, bias) + non_blocking = comfy.model_management.device_supports_non_blocking(device) if hasattr(s, "_v"): @@ -243,13 +268,23 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of if isinstance(weight, QuantizedTensor): weight = weight.dequantize() bias = s.bias.to(dtype=bias_dtype, copy=True) if s.bias is not None else None - return (weight, bias, (None, None, None)) if offloadable else (weight, bias) + return format_return((weight, bias, (None, None, None)), offloadable) + + prefetched = hasattr(s, "_prefetch") + offload_stream = None + offload_device = None + if not prefetched: + offload_stream = cast_modules_with_vbar([s], dtype, device, bias_dtype, non_blocking) + comfy.model_management.sync_stream(device, offload_stream) + + weight, bias = phase_2(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requant) + + if not prefetched: + if getattr(s, "_prefetch")["signature"] is not None: + offload_device = device + delattr(s, "_prefetch") + return format_return((weight, bias, (offload_stream, offload_device, None)), offloadable) - #check for a prefetch result here. Something like: - #if not prefetch: - #cast_modules([s], ...) - #this is the phase 2 call like you made before ... - return phase_2(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requant) if offloadable and (device != s.weight.device or (s.bias is not None and device != s.bias.device)): @@ -296,11 +331,7 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of for f in s.weight_function: weight = f(weight) - if offloadable: - return weight, bias, (offload_stream, weight_a, bias_a) - else: - #Legacy function signature - return weight, bias + return format_return((weight, bias, (offload_stream, weight_a, bias_a)), offloadable) def uncast_bias_weight(s, weight, bias, offload_stream): From 74261f12f28ee7186f4c13a55cef732b161b0253 Mon Sep 17 00:00:00 2001 From: Rattus Date: Mon, 27 Apr 2026 22:52:07 +1000 Subject: [PATCH 05/10] ops: implement block prefetching API allow a model to construct a prefetch list and operate it for increased async offload. --- comfy/model_management.py | 28 ++++++++++++++++++++++++ comfy/ops.py | 45 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 73 insertions(+) diff --git a/comfy/model_management.py b/comfy/model_management.py index afda2f086..73c18c139 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -31,6 +31,7 @@ from contextlib import nullcontext import comfy.memory_management import comfy.utils import comfy.quant_ops +import comfy_aimdo.model_vbar import comfy_aimdo.vram_buffer class VRAMState(Enum): @@ -1184,6 +1185,7 @@ STREAM_CAST_BUFFERS = {} LARGEST_CASTED_WEIGHT = (None, 0) STREAM_AIMDO_CAST_BUFFERS = {} LARGEST_AIMDO_CASTED_WEIGHT = (None, 0) +PREFETCH_QUEUES = [] DEFAULT_AIMDO_CAST_BUFFER_RESERVATION_SIZE = 16 * 1024 ** 3 @@ -1227,9 +1229,30 @@ def get_aimdo_cast_buffer(offload_stream, device): return cast_buffer + +def cleanup_prefetched_modules(comfy_modules): + for s in comfy_modules: + prefetch = getattr(s, "_prefetch", None) + if prefetch is None: + continue + if prefetch["signature"] is not None: + comfy_aimdo.model_vbar.vbar_unpin(s._v) + delattr(s, "_prefetch") + + +def cleanup_prefetch_queue(queue): + for entry in queue: + if entry is None or not isinstance(entry, tuple): + continue + _, prefetch_state = entry + comfy_modules = prefetch_state[1] + if comfy_modules is not None: + cleanup_prefetched_modules(comfy_modules) + def reset_cast_buffers(): global LARGEST_CASTED_WEIGHT global LARGEST_AIMDO_CASTED_WEIGHT + global PREFETCH_QUEUES LARGEST_CASTED_WEIGHT = (None, 0) LARGEST_AIMDO_CASTED_WEIGHT = (None, 0) @@ -1237,6 +1260,11 @@ def reset_cast_buffers(): if offload_stream is not None: offload_stream.synchronize() synchronize() + + for queue in PREFETCH_QUEUES: + cleanup_prefetch_queue(queue) + PREFETCH_QUEUES = [] + STREAM_CAST_BUFFERS.clear() STREAM_AIMDO_CAST_BUFFERS.clear() soft_empty_cache() diff --git a/comfy/ops.py b/comfy/ops.py index 6baad439f..2e0b45be4 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -164,6 +164,51 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blockin return offload_stream +def cast_prefetch_all(module, device): + prefetch_state = (module, None) + if (comfy.model_management.is_device_cpu(device) + or not comfy.model_management.device_supports_non_blocking(device)): + return (None, prefetch_state) + + comfy_modules = [] + for s in module.modules(): + if hasattr(s, "_v"): + comfy_modules.append(s) + + offload_stream = cast_modules_with_vbar(comfy_modules, None, device, None, True) + return (offload_stream, (module, comfy_modules)) + + +def uncast_prefetch_all(prefetch_state): + _, comfy_modules = prefetch_state + if comfy_modules is not None: + comfy.model_management.cleanup_prefetched_modules(comfy_modules) + + +def prefetch_queue_pop(queue, device, module): + consumed = queue.pop(0) + if consumed is not None: + offload_stream, prefetch_state = consumed + if offload_stream is not None: + offload_stream.wait_stream(comfy.model_management.current_stream(device)) + uncast_prefetch_all(prefetch_state) + + active = queue[0] + if active is not None: + offload_stream, prefetch_state = active + assert prefetch_state[0] is module + if offload_stream is not None: + comfy.model_management.sync_stream(device, offload_stream) + + prefetch = queue[1] + if prefetch is not None: + queue[1] = cast_prefetch_all(prefetch, device) + + +def make_prefetch_queue(queue): + queue = [None, None] + queue + [None, None] + comfy.model_management.PREFETCH_QUEUES.append(queue) + return queue def phase_2(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requant): From 9242551e9b2e2dfdb7311ede306ee7d2a66a5dc2 Mon Sep 17 00:00:00 2001 From: Rattus Date: Mon, 27 Apr 2026 22:49:03 +1000 Subject: [PATCH 06/10] ltxv2: Implement block prefetching --- comfy/ldm/lightricks/av_model.py | 5 +++++ comfy/ops.py | 14 ++++---------- 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/comfy/ldm/lightricks/av_model.py b/comfy/ldm/lightricks/av_model.py index 6f2ba41ef..55d25d897 100644 --- a/comfy/ldm/lightricks/av_model.py +++ b/comfy/ldm/lightricks/av_model.py @@ -16,6 +16,7 @@ from comfy.ldm.lightricks.model import ( from comfy.ldm.lightricks.symmetric_patchifier import AudioPatchifier from comfy.ldm.lightricks.embeddings_connector import Embeddings1DConnector import comfy.ldm.common_dit +import comfy.ops class CompressedTimestep: """Store video timestep embeddings in compressed form using per-frame indexing.""" @@ -907,9 +908,11 @@ class LTXAVModel(LTXVModel): """Process transformer blocks for LTXAV.""" patches_replace = transformer_options.get("patches_replace", {}) blocks_replace = patches_replace.get("dit", {}) + prefetch_queue = comfy.ops.make_prefetch_queue(list(self.transformer_blocks)) # Process transformer blocks for i, block in enumerate(self.transformer_blocks): + comfy.ops.prefetch_queue_pop(prefetch_queue, vx.device, block) if ("double_block", i) in blocks_replace: def block_wrap(args): @@ -982,6 +985,8 @@ class LTXAVModel(LTXVModel): a_prompt_timestep=a_prompt_timestep, ) + comfy.ops.prefetch_queue_pop(prefetch_queue, vx.device, None) + return [vx, ax] def _process_output(self, x, embedded_timestep, keyframe_idxs, **kwargs): diff --git a/comfy/ops.py b/comfy/ops.py index 2e0b45be4..34b70a1bf 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -176,6 +176,7 @@ def cast_prefetch_all(module, device): comfy_modules.append(s) offload_stream = cast_modules_with_vbar(comfy_modules, None, device, None, True) + comfy.model_management.sync_stream(device, offload_stream) return (offload_stream, (module, comfy_modules)) @@ -193,20 +194,13 @@ def prefetch_queue_pop(queue, device, module): offload_stream.wait_stream(comfy.model_management.current_stream(device)) uncast_prefetch_all(prefetch_state) - active = queue[0] - if active is not None: - offload_stream, prefetch_state = active - assert prefetch_state[0] is module - if offload_stream is not None: - comfy.model_management.sync_stream(device, offload_stream) - - prefetch = queue[1] + prefetch = queue[0] if prefetch is not None: - queue[1] = cast_prefetch_all(prefetch, device) + queue[0] = cast_prefetch_all(prefetch, device) def make_prefetch_queue(queue): - queue = [None, None] + queue + [None, None] + queue = [None] + queue + [None] comfy.model_management.PREFETCH_QUEUES.append(queue) return queue From d73b819eda127d2d2a7c7627748dfca312f66e39 Mon Sep 17 00:00:00 2001 From: Rattus Date: Tue, 28 Apr 2026 19:50:28 +1000 Subject: [PATCH 07/10] Implement lora async offload Implement async offload of loras. --- comfy/ldm/lightricks/av_model.py | 8 +-- comfy/lora.py | 15 +++++ comfy/model_management.py | 28 -------- comfy/model_patcher.py | 13 +++- comfy/model_prefetch.py | 68 +++++++++++++++++++ comfy/ops.py | 108 ++++++++++++++----------------- execution.py | 2 + 7 files changed, 150 insertions(+), 92 deletions(-) create mode 100644 comfy/model_prefetch.py diff --git a/comfy/ldm/lightricks/av_model.py b/comfy/ldm/lightricks/av_model.py index 55d25d897..7987c8cf2 100644 --- a/comfy/ldm/lightricks/av_model.py +++ b/comfy/ldm/lightricks/av_model.py @@ -16,7 +16,7 @@ from comfy.ldm.lightricks.model import ( from comfy.ldm.lightricks.symmetric_patchifier import AudioPatchifier from comfy.ldm.lightricks.embeddings_connector import Embeddings1DConnector import comfy.ldm.common_dit -import comfy.ops +import comfy.model_prefetch class CompressedTimestep: """Store video timestep embeddings in compressed form using per-frame indexing.""" @@ -908,11 +908,11 @@ class LTXAVModel(LTXVModel): """Process transformer blocks for LTXAV.""" patches_replace = transformer_options.get("patches_replace", {}) blocks_replace = patches_replace.get("dit", {}) - prefetch_queue = comfy.ops.make_prefetch_queue(list(self.transformer_blocks)) + prefetch_queue = comfy.model_prefetch.make_prefetch_queue(list(self.transformer_blocks), vx.device) # Process transformer blocks for i, block in enumerate(self.transformer_blocks): - comfy.ops.prefetch_queue_pop(prefetch_queue, vx.device, block) + comfy.model_prefetch.prefetch_queue_pop(prefetch_queue, vx.device, block) if ("double_block", i) in blocks_replace: def block_wrap(args): @@ -985,7 +985,7 @@ class LTXAVModel(LTXVModel): a_prompt_timestep=a_prompt_timestep, ) - comfy.ops.prefetch_queue_pop(prefetch_queue, vx.device, None) + comfy.model_prefetch.prefetch_queue_pop(prefetch_queue, vx.device, None) return [vx, ax] diff --git a/comfy/lora.py b/comfy/lora.py index 63ee85323..5be361a4a 100644 --- a/comfy/lora.py +++ b/comfy/lora.py @@ -17,6 +17,7 @@ """ from __future__ import annotations +import comfy.memory_management import comfy.utils import comfy.model_management import comfy.model_base @@ -467,3 +468,17 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32, ori weight = old_weight return weight + +def prefetch_prepared_value(value, allocate_buffer, stream): + if isinstance(value, torch.Tensor): + dest = allocate_buffer(comfy.memory_management.vram_aligned_size(value)) + comfy.model_management.cast_to_gathered([value], dest, non_blocking=True, stream=stream) + return comfy.memory_management.interpret_gathered_like([value], dest)[0] + elif isinstance(value, weight_adapter.WeightAdapterBase): + return type(value)(value.loaded_keys, prefetch_prepared_value(value.weights, allocate_buffer, stream)) + elif isinstance(value, tuple): + return tuple(prefetch_prepared_value(item, allocate_buffer, stream) for item in value) + elif isinstance(value, list): + return [prefetch_prepared_value(item, allocate_buffer, stream) for item in value] + + return value diff --git a/comfy/model_management.py b/comfy/model_management.py index 73c18c139..edd73ebfc 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -31,7 +31,6 @@ from contextlib import nullcontext import comfy.memory_management import comfy.utils import comfy.quant_ops -import comfy_aimdo.model_vbar import comfy_aimdo.vram_buffer class VRAMState(Enum): @@ -1185,7 +1184,6 @@ STREAM_CAST_BUFFERS = {} LARGEST_CASTED_WEIGHT = (None, 0) STREAM_AIMDO_CAST_BUFFERS = {} LARGEST_AIMDO_CASTED_WEIGHT = (None, 0) -PREFETCH_QUEUES = [] DEFAULT_AIMDO_CAST_BUFFER_RESERVATION_SIZE = 16 * 1024 ** 3 @@ -1228,31 +1226,9 @@ def get_aimdo_cast_buffer(offload_stream, device): STREAM_AIMDO_CAST_BUFFERS[offload_stream] = cast_buffer return cast_buffer - - -def cleanup_prefetched_modules(comfy_modules): - for s in comfy_modules: - prefetch = getattr(s, "_prefetch", None) - if prefetch is None: - continue - if prefetch["signature"] is not None: - comfy_aimdo.model_vbar.vbar_unpin(s._v) - delattr(s, "_prefetch") - - -def cleanup_prefetch_queue(queue): - for entry in queue: - if entry is None or not isinstance(entry, tuple): - continue - _, prefetch_state = entry - comfy_modules = prefetch_state[1] - if comfy_modules is not None: - cleanup_prefetched_modules(comfy_modules) - def reset_cast_buffers(): global LARGEST_CASTED_WEIGHT global LARGEST_AIMDO_CASTED_WEIGHT - global PREFETCH_QUEUES LARGEST_CASTED_WEIGHT = (None, 0) LARGEST_AIMDO_CASTED_WEIGHT = (None, 0) @@ -1261,10 +1237,6 @@ def reset_cast_buffers(): offload_stream.synchronize() synchronize() - for queue in PREFETCH_QUEUES: - cleanup_prefetch_queue(queue) - PREFETCH_QUEUES = [] - STREAM_CAST_BUFFERS.clear() STREAM_AIMDO_CAST_BUFFERS.clear() soft_empty_cache() diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index e259aed63..7d2d6883f 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -121,9 +121,20 @@ class LowVramPatch: self.patches = patches self.convert_func = convert_func # TODO: remove self.set_func = set_func + self.prepared_patches = None + + def prepare(self, allocate_buffer, stream): + self.prepared_patches = [ + (patch[0], comfy.lora.prefetch_prepared_value(patch[1], allocate_buffer, stream), patch[2], patch[3], patch[4]) + for patch in self.patches[self.key] + ] + + def clear_prepared(self): + self.prepared_patches = None def __call__(self, weight): - return comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=weight.dtype) + patches = self.prepared_patches if self.prepared_patches is not None else self.patches[self.key] + return comfy.lora.calculate_weight(patches, weight, self.key, intermediate_dtype=weight.dtype) LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR = 2 diff --git a/comfy/model_prefetch.py b/comfy/model_prefetch.py new file mode 100644 index 000000000..982c599b4 --- /dev/null +++ b/comfy/model_prefetch.py @@ -0,0 +1,68 @@ +import comfy_aimdo.model_vbar +import comfy.model_management +import comfy.ops + +PREFETCH_QUEUES = [] + + +def cleanup_prefetched_modules(comfy_modules): + for s in comfy_modules: + prefetch = getattr(s, "_prefetch", None) + if prefetch is None: + continue + for param_key in ("weight", "bias"): + lowvram_fn = getattr(s, param_key + "_lowvram_function", None) + if lowvram_fn is not None: + lowvram_fn.clear_prepared() + if prefetch["signature"] is not None: + comfy_aimdo.model_vbar.vbar_unpin(s._v) + delattr(s, "_prefetch") +def cleanup_prefetch_queues(): + global PREFETCH_QUEUES + + for queue in PREFETCH_QUEUES: + for entry in queue: + if entry is None: + continue + _, prefetch_state = entry + comfy_modules = prefetch_state[1] + if comfy_modules is not None: + cleanup_prefetched_modules(comfy_modules) + PREFETCH_QUEUES = [] + + +def prefetch_queue_pop(queue, device, module): + if queue is None: + return + + consumed = queue.pop(0) + if consumed is not None: + offload_stream, prefetch_state = consumed + offload_stream.wait_stream(comfy.model_management.current_stream(device)) + _, comfy_modules = prefetch_state + if comfy_modules is not None: + cleanup_prefetched_modules(comfy_modules) + + prefetch = queue[0] + if prefetch is not None: + comfy_modules = [] + for s in prefetch.modules(): + if hasattr(s, "_v"): + comfy_modules.append(s) + + offload_stream = comfy.ops.cast_modules_with_vbar(comfy_modules, None, device, None, True) + if offload_stream is None: + queue[0] = None + return + comfy.model_management.sync_stream(device, offload_stream) + queue[0] = (offload_stream, (prefetch, comfy_modules)) + +def make_prefetch_queue(queue, device): + if (comfy.model_management.NUM_STREAMS == 0 + or comfy.model_management.is_device_cpu(device) + or not comfy.model_management.device_supports_non_blocking(device)): + return None + + queue = [None] + queue + [None] + PREFETCH_QUEUES.append(queue) + return queue diff --git a/comfy/ops.py b/comfy/ops.py index 34b70a1bf..6693c11c5 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -92,6 +92,38 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blockin cast_buffer = None cast_buffer_offset = 0 + def ensure_offload_stream(module, required_size, check_largest): + nonlocal offload_stream + nonlocal cast_buffer + + if offload_stream is None: + offload_stream = comfy.model_management.get_offload_stream(device) + if offload_stream is None or not check_largest or len(comfy_modules) != 1: + return + + current_size = 0 if cast_buffer is None else cast_buffer.size() + if current_size < required_size and module is comfy.model_management.LARGEST_AIMDO_CASTED_WEIGHT[0]: + offload_stream = comfy.model_management.get_offload_stream(device) + cast_buffer = None + if required_size > comfy.model_management.LARGEST_AIMDO_CASTED_WEIGHT[1]: + comfy.model_management.LARGEST_AIMDO_CASTED_WEIGHT = (module, required_size) + + def get_cast_buffer(buffer_size): + nonlocal offload_stream + nonlocal cast_buffer + nonlocal cast_buffer_offset + + if buffer_size == 0: + return None + + if offload_stream is None: + return torch.empty((buffer_size,), dtype=torch.uint8, device=device) + + cast_buffer = comfy.model_management.get_aimdo_cast_buffer(offload_stream, device) + buffer = comfy_aimdo.torch.aimdo_to_tensor(cast_buffer.get(buffer_size, cast_buffer_offset), device) + cast_buffer_offset += buffer_size + return buffer + for s in comfy_modules: signature = comfy_aimdo.model_vbar.vbar_fault(s._v) resident = comfy_aimdo.model_vbar.vbar_signature_compare(signature, s._v_signature) @@ -126,23 +158,9 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blockin break dest_size = comfy.memory_management.vram_aligned_size(xfer_source) - if offload_stream is None: - offload_stream = comfy.model_management.get_offload_stream(device) - if xfer_dest is None and offload_stream is not None and cast_buffer is None: - cast_buffer = comfy.model_management.get_aimdo_cast_buffer(offload_stream, device) - if len(comfy_modules) == 1: - if cast_buffer.size() < dest_size and s is comfy.model_management.LARGEST_AIMDO_CASTED_WEIGHT[0]: - offload_stream = comfy.model_management.get_offload_stream(device) - cast_buffer = comfy.model_management.get_aimdo_cast_buffer(offload_stream, device) - if dest_size > comfy.model_management.LARGEST_AIMDO_CASTED_WEIGHT[1]: - comfy.model_management.LARGEST_AIMDO_CASTED_WEIGHT = (s, dest_size) + ensure_offload_stream(s, dest_size if xfer_dest is None else 0, True) if xfer_dest is None: - if cast_buffer is not None: - xfer_dest = comfy_aimdo.torch.aimdo_to_tensor(cast_buffer.get(dest_size, cast_buffer_offset), device) - cast_buffer_offset += dest_size - else: - xfer_dest = torch.empty((dest_size,), dtype=torch.uint8, device=device) - offload_stream = None + xfer_dest = get_cast_buffer(dest_size) if signature is None and pin is None: comfy.pinned_memory.pin_memory(s) @@ -155,6 +173,13 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blockin xfer_source = [ pin ] #send it over comfy.model_management.cast_to_gathered(xfer_source, xfer_dest, non_blocking=non_blocking, stream=offload_stream) + + for param_key in ("weight", "bias"): + lowvram_fn = getattr(s, param_key + "_lowvram_function", None) + if lowvram_fn is not None: + ensure_offload_stream(s, cast_buffer_offset, False) + lowvram_fn.prepare(lambda size: get_cast_buffer(size), offload_stream) + prefetch["xfer_dest"] = xfer_dest prefetch["cast_dest"] = cast_dest prefetch["cast_geometry"] = cast_geometry @@ -164,47 +189,6 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blockin return offload_stream -def cast_prefetch_all(module, device): - prefetch_state = (module, None) - if (comfy.model_management.is_device_cpu(device) - or not comfy.model_management.device_supports_non_blocking(device)): - return (None, prefetch_state) - - comfy_modules = [] - for s in module.modules(): - if hasattr(s, "_v"): - comfy_modules.append(s) - - offload_stream = cast_modules_with_vbar(comfy_modules, None, device, None, True) - comfy.model_management.sync_stream(device, offload_stream) - return (offload_stream, (module, comfy_modules)) - - -def uncast_prefetch_all(prefetch_state): - _, comfy_modules = prefetch_state - if comfy_modules is not None: - comfy.model_management.cleanup_prefetched_modules(comfy_modules) - - -def prefetch_queue_pop(queue, device, module): - consumed = queue.pop(0) - if consumed is not None: - offload_stream, prefetch_state = consumed - if offload_stream is not None: - offload_stream.wait_stream(comfy.model_management.current_stream(device)) - uncast_prefetch_all(prefetch_state) - - prefetch = queue[0] - if prefetch is not None: - queue[0] = cast_prefetch_all(prefetch, device) - - -def make_prefetch_queue(queue): - queue = [None] + queue + [None] - comfy.model_management.PREFETCH_QUEUES.append(queue) - return queue - - def phase_2(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requant): del non_blocking @@ -237,6 +221,9 @@ def phase_2(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requ lowvram_fn = getattr(s, param_key + "_lowvram_function", None) fns = getattr(s, param_key + "_function", []) + if x is None: + return None + orig = x def to_dequant(tensor, dtype): @@ -266,8 +253,7 @@ def phase_2(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requ update_weight = prefetch["signature"] is not None weight = post_cast(s, "weight", weight, dtype, prefetch["resident"], update_weight) - bias = None - if s.bias is not None: + if bias is not None: bias = post_cast(s, "bias", bias, bias_dtype, prefetch["resident"], update_weight) return weight, bias @@ -321,6 +307,10 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of if not prefetched: if getattr(s, "_prefetch")["signature"] is not None: offload_device = device + for param_key in ("weight", "bias"): + lowvram_fn = getattr(s, param_key + "_lowvram_function", None) + if lowvram_fn is not None: + lowvram_fn.clear_prepared() delattr(s, "_prefetch") return format_return((weight, bias, (offload_stream, offload_device, None)), offloadable) diff --git a/execution.py b/execution.py index e15eb4bda..de4b14a2b 100644 --- a/execution.py +++ b/execution.py @@ -15,6 +15,7 @@ import torch from comfy.cli_args import args import comfy.memory_management import comfy.model_management +import comfy.model_prefetch import comfy_aimdo.model_vbar from latent_preview import set_preview_method @@ -537,6 +538,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, if args.verbose == "DEBUG": comfy_aimdo.control.analyze() comfy.model_management.reset_cast_buffers() + comfy.model_prefetch.cleanup_prefetch_queues() comfy_aimdo.model_vbar.vbars_reset_watermark_limits() if has_pending_tasks: From dfdb4071a637013a3b33fd939c5b77704694f2e1 Mon Sep 17 00:00:00 2001 From: Rattus Date: Wed, 29 Apr 2026 17:33:11 +1000 Subject: [PATCH 08/10] model_base: support non-dynamic disable cases --- comfy/ldm/lightricks/av_model.py | 2 +- comfy/model_base.py | 5 +++++ comfy/model_prefetch.py | 5 +++-- 3 files changed, 9 insertions(+), 3 deletions(-) diff --git a/comfy/ldm/lightricks/av_model.py b/comfy/ldm/lightricks/av_model.py index 7987c8cf2..3fb87b4a3 100644 --- a/comfy/ldm/lightricks/av_model.py +++ b/comfy/ldm/lightricks/av_model.py @@ -908,7 +908,7 @@ class LTXAVModel(LTXVModel): """Process transformer blocks for LTXAV.""" patches_replace = transformer_options.get("patches_replace", {}) blocks_replace = patches_replace.get("dit", {}) - prefetch_queue = comfy.model_prefetch.make_prefetch_queue(list(self.transformer_blocks), vx.device) + prefetch_queue = comfy.model_prefetch.make_prefetch_queue(list(self.transformer_blocks), vx.device, transformer_options) # Process transformer blocks for i, block in enumerate(self.transformer_blocks): diff --git a/comfy/model_base.py b/comfy/model_base.py index 787ea1145..cc8ba69cb 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -210,6 +210,11 @@ class BaseModel(torch.nn.Module): if "latent_shapes" in extra_conds: xc = utils.unpack_latents(xc, extra_conds.pop("latent_shapes")) + transformer_options = transformer_options.copy() + transformer_options["prefetch_dynamic_vbars"] = ( + self.current_patcher is not None and self.current_patcher.is_dynamic() + ) + model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds) if len(model_output) > 1 and not torch.is_tensor(model_output): model_output, _ = utils.pack_latents(model_output) diff --git a/comfy/model_prefetch.py b/comfy/model_prefetch.py index 982c599b4..b705029f6 100644 --- a/comfy/model_prefetch.py +++ b/comfy/model_prefetch.py @@ -57,8 +57,9 @@ def prefetch_queue_pop(queue, device, module): comfy.model_management.sync_stream(device, offload_stream) queue[0] = (offload_stream, (prefetch, comfy_modules)) -def make_prefetch_queue(queue, device): - if (comfy.model_management.NUM_STREAMS == 0 +def make_prefetch_queue(queue, device, transformer_options): + if (not transformer_options.get("prefetch_dynamic_vbars", False) + or comfy.model_management.NUM_STREAMS == 0 or comfy.model_management.is_device_cpu(device) or not comfy.model_management.device_supports_non_blocking(device)): return None From f3ad2b7f2eecfff2e83119746dd2ba1ffa292473 Mon Sep 17 00:00:00 2001 From: Rattus Date: Thu, 30 Apr 2026 09:14:14 +1000 Subject: [PATCH 09/10] Rabbit --- comfy/model_prefetch.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/comfy/model_prefetch.py b/comfy/model_prefetch.py index b705029f6..0ad35deb5 100644 --- a/comfy/model_prefetch.py +++ b/comfy/model_prefetch.py @@ -4,7 +4,6 @@ import comfy.ops PREFETCH_QUEUES = [] - def cleanup_prefetched_modules(comfy_modules): for s in comfy_modules: prefetch = getattr(s, "_prefetch", None) @@ -17,12 +16,13 @@ def cleanup_prefetched_modules(comfy_modules): if prefetch["signature"] is not None: comfy_aimdo.model_vbar.vbar_unpin(s._v) delattr(s, "_prefetch") + def cleanup_prefetch_queues(): global PREFETCH_QUEUES for queue in PREFETCH_QUEUES: for entry in queue: - if entry is None: + if entry is None or not isinstance(entry, tuple): continue _, prefetch_state = entry comfy_modules = prefetch_state[1] @@ -30,7 +30,6 @@ def cleanup_prefetch_queues(): cleanup_prefetched_modules(comfy_modules) PREFETCH_QUEUES = [] - def prefetch_queue_pop(queue, device, module): if queue is None: return @@ -51,9 +50,6 @@ def prefetch_queue_pop(queue, device, module): comfy_modules.append(s) offload_stream = comfy.ops.cast_modules_with_vbar(comfy_modules, None, device, None, True) - if offload_stream is None: - queue[0] = None - return comfy.model_management.sync_stream(device, offload_stream) queue[0] = (offload_stream, (prefetch, comfy_modules)) From 38d7484166b835403390061db03d4f2feec4552e Mon Sep 17 00:00:00 2001 From: Rattus Date: Thu, 30 Apr 2026 10:10:36 +1000 Subject: [PATCH 10/10] fix phase2 name --- comfy/ops.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/comfy/ops.py b/comfy/ops.py index 6693c11c5..96db1411c 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -189,12 +189,9 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blockin return offload_stream -def phase_2(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requant): - del non_blocking +def resolve_cast_module_with_vbar(s, dtype, device, bias_dtype, compute_dtype, want_requant): prefetch = getattr(s, "_prefetch", None) - if prefetch is None: - raise RuntimeError("phase_2 called without a VBAR prefetch state") if prefetch["resident"]: weight = s._v_weight @@ -302,7 +299,7 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of offload_stream = cast_modules_with_vbar([s], dtype, device, bias_dtype, non_blocking) comfy.model_management.sync_stream(device, offload_stream) - weight, bias = phase_2(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requant) + weight, bias = resolve_cast_module_with_vbar(s, dtype, device, bias_dtype, compute_dtype, want_requant) if not prefetched: if getattr(s, "_prefetch")["signature"] is not None: