mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-19 05:27:24 +08:00
plan
This commit is contained in:
parent
9cfd71e821
commit
312f08e44f
19
comfy/ops.py
19
comfy/ops.py
@ -102,6 +102,10 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
|
||||
bias = s.bias.to(dtype=bias_dtype, copy=True)
|
||||
return weight, bias, (None, None, None)
|
||||
|
||||
#plan = []
|
||||
#Some sort of loop here like what you did
|
||||
#for module in comfy_modules: ...
|
||||
|
||||
offload_stream = None
|
||||
xfer_dest = None
|
||||
|
||||
@ -160,8 +164,13 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
|
||||
xfer_source = [ pin ]
|
||||
#send it over
|
||||
comfy.model_management.cast_to_gathered(xfer_source, xfer_dest, non_blocking=non_blocking, stream=offload_stream)
|
||||
#attach prefetch info to the module inside the loop ..
|
||||
|
||||
#this sync is conceptually the last thing this function does - after the loop
|
||||
comfy.model_management.sync_stream(device, offload_stream)
|
||||
|
||||
|
||||
#all compute stuff need to be deferred to the new second phase
|
||||
if cast_dest is not None:
|
||||
for pre_cast, post_cast in zip(comfy.memory_management.interpret_gathered_like([s.weight, s.bias ], xfer_dest),
|
||||
comfy.memory_management.interpret_gathered_like(cast_geometry, cast_dest)):
|
||||
@ -177,6 +186,7 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
|
||||
s._v_bias = bias
|
||||
s._v_signature=signature
|
||||
|
||||
#factor this our like you did before.
|
||||
def post_cast(s, param_key, x, dtype, resident, update_weight):
|
||||
lowvram_fn = getattr(s, param_key + "_lowvram_function", None)
|
||||
fns = getattr(s, param_key + "_function", [])
|
||||
@ -236,7 +246,14 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
|
||||
non_blocking = comfy.model_management.device_supports_non_blocking(device)
|
||||
|
||||
if hasattr(s, "_v"):
|
||||
return cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requant)
|
||||
|
||||
#that cpu guard goes here.
|
||||
|
||||
#check for a prefetch result here. Something like:
|
||||
#if not prefetch:
|
||||
#cast_modules([s], ...)
|
||||
#this is the phase 2 call like you made before ...
|
||||
return phase_2(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requant)
|
||||
|
||||
if offloadable and (device != s.weight.device or
|
||||
(s.bias is not None and device != s.bias.device)):
|
||||
|
||||
Loading…
Reference in New Issue
Block a user