From ed15d62a6c6f7ed645e26b5917a242922df1c5a0 Mon Sep 17 00:00:00 2001 From: Rattus Date: Fri, 15 May 2026 22:26:17 +1000 Subject: [PATCH] ops: refactor in prep for chunking --- comfy/ops.py | 21 ++++++++++----------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/comfy/ops.py b/comfy/ops.py index be744a030..4b436f4a7 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -190,25 +190,25 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blockin else: comfy.model_management.cast_to_gathered(xfer_source, xfer_dest, non_blocking=non_blocking, stream=stream) - def handle_pin_miss(m, source, dest, subset="weights", size=None): - pin = None + def handle_pin(m, pin, source, dest, subset="weights", size=None): + if pin is not None: + cast_maybe_lowvram_patch([pin], dest, offload_stream) + return if signature is None: comfy.pinned_memory.pin_memory(m, subset=subset, size=size) pin = comfy.pinned_memory.get_pin(m, subset=subset) if pin is not None: cast_maybe_lowvram_patch(source, pin, None) - return [ pin ] + cast_maybe_lowvram_patch([ pin ], dest, offload_stream) + return if pin is None: pin_offset = get_stream_pin_buffer_offset(size) if pin_offset is not None: stream_pin_queue.append((source, pin_offset, size, dest)) - return None - return source + return + cast_maybe_lowvram_patch(source, dest, offload_stream) - if pin is None: - xfer_source = handle_pin_miss(s, xfer_source, xfer_dest, size=dest_size) - - cast_maybe_lowvram_patch(xfer_source, xfer_dest, offload_stream) + handle_pin(s, pin, xfer_source, xfer_dest, size=dest_size) for param_key in ("weight", "bias"): lowvram_source = getattr(s, param_key + "_lowvram_function", None) @@ -219,9 +219,8 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blockin lowvram_source.prepare(lowvram_dest, None, copy=False, commit=True) pin = comfy.pinned_memory.get_pin(lowvram_source, subset="patches") - lowvram_source = handle_pin_miss(lowvram_source, lowvram_source, lowvram_dest, subset="patches", size=lowvram_size) if pin is None else [ pin ] + handle_pin(lowvram_source, pin, lowvram_source, lowvram_dest, subset="patches", size=lowvram_size) - cast_maybe_lowvram_patch(lowvram_source, lowvram_dest, offload_stream) prefetch["xfer_dest"] = xfer_dest prefetch["cast_dest"] = cast_dest