Compare commits

...

2 Commits

Author SHA1 Message Date
Rattus
38d7484166 fix phase2 name 2026-04-30 10:10:36 +10:00
Rattus
f3ad2b7f2e Rabbit 2026-04-30 09:15:27 +10:00
2 changed files with 4 additions and 11 deletions

View File

@ -4,7 +4,6 @@ import comfy.ops
PREFETCH_QUEUES = []
def cleanup_prefetched_modules(comfy_modules):
for s in comfy_modules:
prefetch = getattr(s, "_prefetch", None)
@ -17,12 +16,13 @@ def cleanup_prefetched_modules(comfy_modules):
if prefetch["signature"] is not None:
comfy_aimdo.model_vbar.vbar_unpin(s._v)
delattr(s, "_prefetch")
def cleanup_prefetch_queues():
global PREFETCH_QUEUES
for queue in PREFETCH_QUEUES:
for entry in queue:
if entry is None:
if entry is None or not isinstance(entry, tuple):
continue
_, prefetch_state = entry
comfy_modules = prefetch_state[1]
@ -30,7 +30,6 @@ def cleanup_prefetch_queues():
cleanup_prefetched_modules(comfy_modules)
PREFETCH_QUEUES = []
def prefetch_queue_pop(queue, device, module):
if queue is None:
return
@ -51,9 +50,6 @@ def prefetch_queue_pop(queue, device, module):
comfy_modules.append(s)
offload_stream = comfy.ops.cast_modules_with_vbar(comfy_modules, None, device, None, True)
if offload_stream is None:
queue[0] = None
return
comfy.model_management.sync_stream(device, offload_stream)
queue[0] = (offload_stream, (prefetch, comfy_modules))

View File

@ -189,12 +189,9 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blockin
return offload_stream
def phase_2(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requant):
del non_blocking
def resolve_cast_module_with_vbar(s, dtype, device, bias_dtype, compute_dtype, want_requant):
prefetch = getattr(s, "_prefetch", None)
if prefetch is None:
raise RuntimeError("phase_2 called without a VBAR prefetch state")
if prefetch["resident"]:
weight = s._v_weight
@ -302,7 +299,7 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
offload_stream = cast_modules_with_vbar([s], dtype, device, bias_dtype, non_blocking)
comfy.model_management.sync_stream(device, offload_stream)
weight, bias = phase_2(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requant)
weight, bias = resolve_cast_module_with_vbar(s, dtype, device, bias_dtype, compute_dtype, want_requant)
if not prefetched:
if getattr(s, "_prefetch")["signature"] is not None: