This commit is contained in:
Rattus 2026-04-24 17:49:08 +10:00
parent 9cfd71e821
commit 312f08e44f

View File

@ -102,6 +102,10 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
bias = s.bias.to(dtype=bias_dtype, copy=True)
return weight, bias, (None, None, None)
#plan = []
#Some sort of loop here like what you did
#for module in comfy_modules: ...
offload_stream = None
xfer_dest = None
@ -160,8 +164,13 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
xfer_source = [ pin ]
#send it over
comfy.model_management.cast_to_gathered(xfer_source, xfer_dest, non_blocking=non_blocking, stream=offload_stream)
#attach prefetch info to the module inside the loop ..
#this sync is conceptually the last thing this function does - after the loop
comfy.model_management.sync_stream(device, offload_stream)
#all compute stuff need to be deferred to the new second phase
if cast_dest is not None:
for pre_cast, post_cast in zip(comfy.memory_management.interpret_gathered_like([s.weight, s.bias ], xfer_dest),
comfy.memory_management.interpret_gathered_like(cast_geometry, cast_dest)):
@ -177,6 +186,7 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
s._v_bias = bias
s._v_signature=signature
#factor this our like you did before.
def post_cast(s, param_key, x, dtype, resident, update_weight):
lowvram_fn = getattr(s, param_key + "_lowvram_function", None)
fns = getattr(s, param_key + "_function", [])
@ -236,7 +246,14 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
non_blocking = comfy.model_management.device_supports_non_blocking(device)
if hasattr(s, "_v"):
return cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requant)
#that cpu guard goes here.
#check for a prefetch result here. Something like:
#if not prefetch:
#cast_modules([s], ...)
#this is the phase 2 call like you made before ...
return phase_2(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requant)
if offloadable and (device != s.weight.device or
(s.bias is not None and device != s.bias.device)):