diff --git a/comfy/ops.py b/comfy/ops.py index febde458d..cd1f24c78 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -102,6 +102,10 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu bias = s.bias.to(dtype=bias_dtype, copy=True) return weight, bias, (None, None, None) + #plan = [] + #Some sort of loop here like what you did + #for module in comfy_modules: ... + offload_stream = None xfer_dest = None @@ -160,8 +164,13 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu xfer_source = [ pin ] #send it over comfy.model_management.cast_to_gathered(xfer_source, xfer_dest, non_blocking=non_blocking, stream=offload_stream) + #attach prefetch info to the module inside the loop .. + + #this sync is conceptually the last thing this function does - after the loop comfy.model_management.sync_stream(device, offload_stream) + + #all compute stuff need to be deferred to the new second phase if cast_dest is not None: for pre_cast, post_cast in zip(comfy.memory_management.interpret_gathered_like([s.weight, s.bias ], xfer_dest), comfy.memory_management.interpret_gathered_like(cast_geometry, cast_dest)): @@ -177,6 +186,7 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu s._v_bias = bias s._v_signature=signature + #factor this our like you did before. def post_cast(s, param_key, x, dtype, resident, update_weight): lowvram_fn = getattr(s, param_key + "_lowvram_function", None) fns = getattr(s, param_key + "_function", []) @@ -236,7 +246,14 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of non_blocking = comfy.model_management.device_supports_non_blocking(device) if hasattr(s, "_v"): - return cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requant) + + #that cpu guard goes here. + + #check for a prefetch result here. Something like: + #if not prefetch: + #cast_modules([s], ...) + #this is the phase 2 call like you made before ... + return phase_2(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requant) if offloadable and (device != s.weight.device or (s.bias is not None and device != s.bias.device)):