fix phase2 name

This commit is contained in:
Rattus 2026-04-30 10:10:36 +10:00
parent f3ad2b7f2e
commit 38d7484166

View File

@ -189,12 +189,9 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blockin
return offload_stream
def phase_2(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requant):
del non_blocking
def resolve_cast_module_with_vbar(s, dtype, device, bias_dtype, compute_dtype, want_requant):
prefetch = getattr(s, "_prefetch", None)
if prefetch is None:
raise RuntimeError("phase_2 called without a VBAR prefetch state")
if prefetch["resident"]:
weight = s._v_weight
@ -302,7 +299,7 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
offload_stream = cast_modules_with_vbar([s], dtype, device, bias_dtype, non_blocking)
comfy.model_management.sync_stream(device, offload_stream)
weight, bias = phase_2(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requant)
weight, bias = resolve_cast_module_with_vbar(s, dtype, device, bias_dtype, compute_dtype, want_requant)
if not prefetched:
if getattr(s, "_prefetch")["signature"] is not None: