Mirror of https://github.com/comfyanonymous/ComfyUI.git, synced 2026-01-25 13:50:15 +08:00
ops: Implement prefetching API
Implement an API that allows instrumenting a model with a prefetch queue. Units of work are on the nn.Module level.
This commit is contained in:
parent c350009236
commit e279e1f26e

comfy/ops.py | 134 lines changed
@@ -22,7 +22,6 @@ import comfy.model_management
 from comfy.cli_args import args, PerformanceFeature
 import comfy.float
 import comfy.rmsnorm
-import contextlib
 
 def run_every_op():
     if torch.compiler.is_compiling():
@@ -71,6 +70,93 @@ def cast_to_input(weight, input, non_blocking=False, copy=True):
     return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy)
 
+
+def cast_prefetch_all(module, device):
+    if not comfy.model_management.device_supports_non_blocking(device):
+        #Adios! prefetching works against you if you can't get the CPU past it
+        return None
+
+    offload_stream = None
+
+    for n, m in module.named_modules():
+        if hasattr(m, "comfy_cast_weights"):
+            if m.weight is not None and m.weight.device != device and not hasattr(m, "weight_prefetch"):
+                if offload_stream is None:
+                    offload_stream = comfy.model_management.get_offload_stream(device)
+                    if offload_stream is None:
+                        return None
+                m.weight_prefetch = comfy.model_management.cast_to(m.weight, None, device, non_blocking=True, copy=True, stream=offload_stream)
+            if m.bias is not None and m.bias.device != device and not hasattr(m, "bias_prefetch"):
+                if offload_stream is None:
+                    offload_stream = comfy.model_management.get_offload_stream(device)
+                    if offload_stream is None:
+                        return None
+                m.bias_prefetch = comfy.model_management.cast_to(m.bias, None, device, non_blocking=True, copy = True, stream=offload_stream)
+
+    return offload_stream
+
+
+def uncast_prefetch_all(module):
+    for n, m in module.named_modules():
+        if hasattr(m, "comfy_cast_weights"):
+            if hasattr(m, "weight_prefetch"):
+                delattr(m, "weight_prefetch")
+            if hasattr(m, "bias_prefetch"):
+                delattr(m, "bias_prefetch")
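Taken together, cast_prefetch_all() and uncast_prefetch_all() bracket one prefetched unit of work: the first attaches weight_prefetch / bias_prefetch copies to every comfy_cast_weights submodule and returns the offload stream the copies were issued on, the second drops those copies again. A minimal sketch of that pairing, assuming a caller that drives a single block directly (block, x and device are illustrative names; this driver is not part of the diff):

    import comfy.model_management
    import comfy.ops

    stream = comfy.ops.cast_prefetch_all(block, device)     # queue async weight/bias copies
    if stream is not None:
        comfy.model_management.sync_stream(device, stream)   # wait before anything reads them
    out = block(x)                                           # cast_bias_weight() below picks up the prefetched copies
    comfy.ops.uncast_prefetch_all(block)                     # drop the copies once the block is done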
+
+
+def prefetch_queue_pop(queue, device, module):
+    consumed = queue.pop(0)
+    if consumed is not None:
+        offload_stream, m = consumed
+        #Sync the offload stream with compute so when it starts
+        #freeing the prefetches the compute stream has finished
+        if offload_stream is not None:
+            offload_stream.wait_stream(comfy.model_management.current_stream(device))
+        uncast_prefetch_all(m)
+
+    active = queue[0]
+    if active is not None:
+        offload_stream, m = active
+        assert m == module
+        #wait for the prefetch to complete before using the data
+        if offload_stream is not None:
+            comfy.model_management.sync_stream(device, offload_stream)
+
+    prefetch = queue[1]
+    if prefetch is not None:
+        offload_stream = comfy.ops.cast_prefetch_all(prefetch, device)
+        queue[1] = (offload_stream, prefetch)
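prefetch_queue_pop() treats the front of the queue as three roles: the unit that just finished (consumed), the unit about to run (active), and the unit to start copying next (prefetch). Assuming one call per unit of work right before that unit runs, the queue looks roughly like this before the pop for block_k (entry shapes follow the code above; the block_* names are illustrative):

    # queue[0] -> (offload_stream_{k-1}, block_{k-1})  consumed: popped; its stream waits on compute,
    #                                                  then uncast_prefetch_all() frees its copies
    # queue[1] -> (offload_stream_k, block_k)          active: asserted to match the module argument,
    #                                                  its stream is synced so the tensors are ready
    # queue[2] -> block_{k+1}                          prefetch: cast_prefetch_all() is started for it and
    #                                                  the slot becomes (offload_stream_{k+1}, block_{k+1})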
+
+
+def make_prefetch_queue(queue):
+    return [None, None] + queue + [None, None]
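make_prefetch_queue() pads the unit list with two leading and two trailing None entries so the first pop has nothing to consume and the last pops have nothing left to prefetch. A minimal sketch of a driver loop consistent with the three functions above, assuming the caller owns an ordered, non-empty list of blocks (the driver itself lives outside this file and is not part of this commit; run_blocks_with_prefetch, blocks and x are assumed names):

    import comfy.ops

    def run_blocks_with_prefetch(blocks, x, device):
        queue = comfy.ops.make_prefetch_queue(list(blocks))     # [None, None, b0, ..., bN-1, None, None]
        comfy.ops.prefetch_queue_pop(queue, device, blocks[0])  # priming pop: starts copying b0
        for block in blocks:
            # waits on this block's copies, frees the previous block, starts copying the next one
            comfy.ops.prefetch_queue_pop(queue, device, block)
            x = block(x)
        comfy.ops.prefetch_queue_pop(queue, device, None)       # final pop: frees the last block's copies
        return x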
+
+
+def move_bias_weight(s, device, offloadable=False):
+    bias_has_function = len(s.bias_function) > 0
+    weight_has_function = len(s.weight_function) > 0
+
+    if offloadable and (
+            s.weight.device != device or (s.bias is not None and s.bias.device != device) or
+            bias_has_function or weight_has_function):
+        offload_stream = comfy.model_management.get_offload_stream(device)
+    else:
+        offload_stream = None
+
+    bias = None
+    non_blocking = comfy.model_management.device_supports_non_blocking(device)
+
+    weight = comfy.model_management.cast_to(s.weight, None, device, non_blocking=non_blocking, copy=weight_has_function, stream=offload_stream)
+
+    if s.bias is not None:
+        bias = comfy.model_management.cast_to(s.bias, None, device, non_blocking=non_blocking, copy=bias_has_function, stream=offload_stream)
+
+    comfy.model_management.sync_stream(device, offload_stream)
+
+    return weight, bias, offload_stream
+
 
 def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, offloadable=False):
     # NOTE: offloadable=False is a legacy and if you are a custom node author reading this please pass
     # offloadable=True and call uncast_bias_weight() after your last usage of the weight/bias. This
@@ -83,40 +169,30 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, offloadable=False):
     if device is None:
         device = input.device
 
-    if offloadable and (device != s.weight.device or
-                        (s.bias is not None and device != s.bias.device)):
-        offload_stream = comfy.model_management.get_offload_stream(device)
-    else:
-        offload_stream = None
-
-    if offload_stream is not None:
-        wf_context = offload_stream
-    else:
-        wf_context = contextlib.nullcontext()
-
-    non_blocking = comfy.model_management.device_supports_non_blocking(device)
-
-    weight_has_function = len(s.weight_function) > 0
     bias_has_function = len(s.bias_function) > 0
+    weight_has_function = len(s.weight_function) > 0
 
-    weight = comfy.model_management.cast_to(s.weight, None, device, non_blocking=non_blocking, copy=weight_has_function, stream=offload_stream)
-
-    bias = None
-    if s.bias is not None:
-        bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=bias_has_function, stream=offload_stream)
-        if bias_has_function:
-            with wf_context:
-                for f in s.bias_function:
-                    bias = f(bias)
+    if hasattr(s, "weight_prefetch") or hasattr(s, "bias_prefetch"):
+        weight = getattr(s, "weight_prefetch", None)
+        bias = getattr(s, "bias_prefetch", None)
+        offload_stream = None
+    else:
+        weight, bias, offload_stream = move_bias_weight(s, device, offloadable=offloadable)
 
-    if weight_has_function or weight.dtype != dtype:
-        with wf_context:
-            weight = weight.to(dtype=dtype)
-            for f in s.weight_function:
-                weight = f(weight)
+    if weight_has_function:
+        weight = weight.to(dtype=dtype)
+        for f in s.weight_function:
+            weight = f(weight)
 
-    comfy.model_management.sync_stream(device, offload_stream)
+    if s.bias is not None and bias_has_function:
+        bias = bias.to(dtype=bias_dtype)
+        for f in s.bias_function:
+            bias = f(bias)
+
+    weight = weight.to(dtype=dtype)
+    if bias is not None:
+        bias = bias.to(dtype=bias_dtype)
 
     if offloadable:
         return weight, bias, offload_stream
     else:
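The NOTE at the top of cast_bias_weight() asks custom node authors to pass offloadable=True and to hand the tensors back with uncast_bias_weight() after the last use. A minimal sketch of that pattern in a layer override, assuming the usual comfy.ops layer classes; uncast_bias_weight() is not shown in this diff, so its exact signature is an assumption here:

    import torch
    import comfy.ops

    class OffloadFriendlyLinear(comfy.ops.disable_weight_init.Linear):
        def forward_comfy_cast_weights(self, input):
            weight, bias, offload_stream = comfy.ops.cast_bias_weight(self, input, offloadable=True)
            out = torch.nn.functional.linear(input, weight, bias)
            # signature assumed from the NOTE and the offloadable return value; check comfy/ops.py
            comfy.ops.uncast_bias_weight(self, weight, bias, offload_stream)
            return out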