mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-29 02:17:52 +08:00
plan
This commit is contained in:
parent
9cfd71e821
commit
312f08e44f
19
comfy/ops.py
19
comfy/ops.py
@ -102,6 +102,10 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
|
|||||||
bias = s.bias.to(dtype=bias_dtype, copy=True)
|
bias = s.bias.to(dtype=bias_dtype, copy=True)
|
||||||
return weight, bias, (None, None, None)
|
return weight, bias, (None, None, None)
|
||||||
|
|
||||||
|
#plan = []
|
||||||
|
#Some sort of loop here like what you did
|
||||||
|
#for module in comfy_modules: ...
|
||||||
|
|
||||||
offload_stream = None
|
offload_stream = None
|
||||||
xfer_dest = None
|
xfer_dest = None
|
||||||
|
|
||||||
@ -160,8 +164,13 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
|
|||||||
xfer_source = [ pin ]
|
xfer_source = [ pin ]
|
||||||
#send it over
|
#send it over
|
||||||
comfy.model_management.cast_to_gathered(xfer_source, xfer_dest, non_blocking=non_blocking, stream=offload_stream)
|
comfy.model_management.cast_to_gathered(xfer_source, xfer_dest, non_blocking=non_blocking, stream=offload_stream)
|
||||||
|
#attach prefetch info to the module inside the loop ..
|
||||||
|
|
||||||
|
#this sync is conceptually the last thing this function does - after the loop
|
||||||
comfy.model_management.sync_stream(device, offload_stream)
|
comfy.model_management.sync_stream(device, offload_stream)
|
||||||
|
|
||||||
|
|
||||||
|
#all compute stuff needs to be deferred to the new second phase
|
||||||
if cast_dest is not None:
|
if cast_dest is not None:
|
||||||
for pre_cast, post_cast in zip(comfy.memory_management.interpret_gathered_like([s.weight, s.bias ], xfer_dest),
|
for pre_cast, post_cast in zip(comfy.memory_management.interpret_gathered_like([s.weight, s.bias ], xfer_dest),
|
||||||
comfy.memory_management.interpret_gathered_like(cast_geometry, cast_dest)):
|
comfy.memory_management.interpret_gathered_like(cast_geometry, cast_dest)):
|
||||||
@ -177,6 +186,7 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
|
|||||||
s._v_bias = bias
|
s._v_bias = bias
|
||||||
s._v_signature=signature
|
s._v_signature=signature
|
||||||
|
|
||||||
|
#factor this out like you did before.
|
||||||
def post_cast(s, param_key, x, dtype, resident, update_weight):
|
def post_cast(s, param_key, x, dtype, resident, update_weight):
|
||||||
lowvram_fn = getattr(s, param_key + "_lowvram_function", None)
|
lowvram_fn = getattr(s, param_key + "_lowvram_function", None)
|
||||||
fns = getattr(s, param_key + "_function", [])
|
fns = getattr(s, param_key + "_function", [])
|
||||||
@ -236,7 +246,14 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
|
|||||||
non_blocking = comfy.model_management.device_supports_non_blocking(device)
|
non_blocking = comfy.model_management.device_supports_non_blocking(device)
|
||||||
|
|
||||||
if hasattr(s, "_v"):
|
if hasattr(s, "_v"):
|
||||||
return cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requant)
|
|
||||||
|
#that cpu guard goes here.
|
||||||
|
|
||||||
|
#check for a prefetch result here. Something like:
|
||||||
|
#if not prefetch:
|
||||||
|
#cast_modules([s], ...)
|
||||||
|
#this is the phase 2 call like you made before ...
|
||||||
|
return phase_2(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requant)
|
||||||
|
|
||||||
if offloadable and (device != s.weight.device or
|
if offloadable and (device != s.weight.device or
|
||||||
(s.bias is not None and device != s.bias.device)):
|
(s.bias is not None and device != s.bias.device)):
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user