mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-11 00:37:53 +08:00
ops: refactor in prep for chunking
This commit is contained in:
parent
0242954aaa
commit
ed15d62a6c
21
comfy/ops.py
21
comfy/ops.py
@ -190,25 +190,25 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blockin
|
||||
else:
|
||||
comfy.model_management.cast_to_gathered(xfer_source, xfer_dest, non_blocking=non_blocking, stream=stream)
|
||||
|
||||
def handle_pin_miss(m, source, dest, subset="weights", size=None):
|
||||
pin = None
|
||||
def handle_pin(m, pin, source, dest, subset="weights", size=None):
|
||||
if pin is not None:
|
||||
cast_maybe_lowvram_patch([pin], dest, offload_stream)
|
||||
return
|
||||
if signature is None:
|
||||
comfy.pinned_memory.pin_memory(m, subset=subset, size=size)
|
||||
pin = comfy.pinned_memory.get_pin(m, subset=subset)
|
||||
if pin is not None:
|
||||
cast_maybe_lowvram_patch(source, pin, None)
|
||||
return [ pin ]
|
||||
cast_maybe_lowvram_patch([ pin ], dest, offload_stream)
|
||||
return
|
||||
if pin is None:
|
||||
pin_offset = get_stream_pin_buffer_offset(size)
|
||||
if pin_offset is not None:
|
||||
stream_pin_queue.append((source, pin_offset, size, dest))
|
||||
return None
|
||||
return source
|
||||
return
|
||||
cast_maybe_lowvram_patch(source, dest, offload_stream)
|
||||
|
||||
if pin is None:
|
||||
xfer_source = handle_pin_miss(s, xfer_source, xfer_dest, size=dest_size)
|
||||
|
||||
cast_maybe_lowvram_patch(xfer_source, xfer_dest, offload_stream)
|
||||
handle_pin(s, pin, xfer_source, xfer_dest, size=dest_size)
|
||||
|
||||
for param_key in ("weight", "bias"):
|
||||
lowvram_source = getattr(s, param_key + "_lowvram_function", None)
|
||||
@ -219,9 +219,8 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blockin
|
||||
lowvram_source.prepare(lowvram_dest, None, copy=False, commit=True)
|
||||
|
||||
pin = comfy.pinned_memory.get_pin(lowvram_source, subset="patches")
|
||||
lowvram_source = handle_pin_miss(lowvram_source, lowvram_source, lowvram_dest, subset="patches", size=lowvram_size) if pin is None else [ pin ]
|
||||
handle_pin(lowvram_source, pin, lowvram_source, lowvram_dest, subset="patches", size=lowvram_size)
|
||||
|
||||
cast_maybe_lowvram_patch(lowvram_source, lowvram_dest, offload_stream)
|
||||
|
||||
prefetch["xfer_dest"] = xfer_dest
|
||||
prefetch["cast_dest"] = cast_dest
|
||||
|
||||
Loading…
Reference in New Issue
Block a user