ops: refactor in prep for chunking

This commit is contained in:
Rattus 2026-05-15 22:26:17 +10:00
parent 0242954aaa
commit ed15d62a6c

View File

@ -190,25 +190,25 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blockin
else:
comfy.model_management.cast_to_gathered(xfer_source, xfer_dest, non_blocking=non_blocking, stream=stream)
def handle_pin_miss(m, source, dest, subset="weights", size=None):
pin = None
def handle_pin(m, pin, source, dest, subset="weights", size=None):
if pin is not None:
cast_maybe_lowvram_patch([pin], dest, offload_stream)
return
if signature is None:
comfy.pinned_memory.pin_memory(m, subset=subset, size=size)
pin = comfy.pinned_memory.get_pin(m, subset=subset)
if pin is not None:
cast_maybe_lowvram_patch(source, pin, None)
return [ pin ]
cast_maybe_lowvram_patch([ pin ], dest, offload_stream)
return
if pin is None:
pin_offset = get_stream_pin_buffer_offset(size)
if pin_offset is not None:
stream_pin_queue.append((source, pin_offset, size, dest))
return None
return source
return
cast_maybe_lowvram_patch(source, dest, offload_stream)
if pin is None:
xfer_source = handle_pin_miss(s, xfer_source, xfer_dest, size=dest_size)
cast_maybe_lowvram_patch(xfer_source, xfer_dest, offload_stream)
handle_pin(s, pin, xfer_source, xfer_dest, size=dest_size)
for param_key in ("weight", "bias"):
lowvram_source = getattr(s, param_key + "_lowvram_function", None)
@ -219,9 +219,8 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blockin
lowvram_source.prepare(lowvram_dest, None, copy=False, commit=True)
pin = comfy.pinned_memory.get_pin(lowvram_source, subset="patches")
lowvram_source = handle_pin_miss(lowvram_source, lowvram_source, lowvram_dest, subset="patches", size=lowvram_size) if pin is None else [ pin ]
handle_pin(lowvram_source, pin, lowvram_source, lowvram_dest, subset="patches", size=lowvram_size)
cast_maybe_lowvram_patch(lowvram_source, lowvram_dest, offload_stream)
prefetch["xfer_dest"] = xfer_dest
prefetch["cast_dest"] = cast_dest