rattus 2025-11-23 13:06:45 -08:00 committed by GitHub
commit 342dd2e9a5
3 changed files with 115 additions and 29 deletions

View File

@@ -391,6 +391,9 @@ class QwenImageTransformer2DModel(nn.Module):
         hidden_states, img_ids, orig_shape = self.process_img(x)
         num_embeds = hidden_states.shape[1]
+        prefetch_queue = comfy.ops.make_prefetch_queue(list(self.transformer_blocks))
+        comfy.ops.prefetch_queue_pop(prefetch_queue, x.device, None)
+
         if ref_latents is not None:
             h = 0
             w = 0
@@ -440,6 +443,7 @@ class QwenImageTransformer2DModel(nn.Module):
         blocks_replace = patches_replace.get("dit", {})
         for i, block in enumerate(self.transformer_blocks):
+            comfy.ops.prefetch_queue_pop(prefetch_queue, x.device, block)
             if ("double_block", i) in blocks_replace:
                 def block_wrap(args):
                     out = {}
@@ -471,6 +475,8 @@ class QwenImageTransformer2DModel(nn.Module):
                 if add is not None:
                     hidden_states[:, :add.shape[1]] += add
+
+        comfy.ops.prefetch_queue_pop(prefetch_queue, x.device, block)
         hidden_states = self.norm_out(hidden_states, temb)
         hidden_states = self.proj_out(hidden_states)

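The three call sites added above form one fixed pattern: a pop before the block loop (with module=None) to kick off the transfer of the first block's weights, a pop at the top of each iteration to free the previous block's copies, wait for the current block's weights, and start the next transfer, and a final pop after the loop to release the last block. The WanModel hunks below wire up the same pattern. A minimal sketch of that wiring for a generic block list (blocks, run_block, and x are illustrative placeholders, not ComfyUI API):

prefetch_queue = comfy.ops.make_prefetch_queue(list(blocks))
comfy.ops.prefetch_queue_pop(prefetch_queue, x.device, None)   # start copying blocks[0]
for block in blocks:
    # Free the previous block's prefetched weights, wait for this block's
    # copies to land, then kick off the transfer of the next block.
    comfy.ops.prefetch_queue_pop(prefetch_queue, x.device, block)
    x = run_block(block, x)
comfy.ops.prefetch_queue_pop(prefetch_queue, x.device, block)  # release the last block's copies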
View File

@@ -538,6 +538,8 @@ class WanModel(torch.nn.Module):
             List[Tensor]:
                 List of denoised video tensors with original input shapes [C_out, F, H / 8, W / 8]
         """
+        prefetch_queue = comfy.ops.make_prefetch_queue(list(self.blocks))
+        comfy.ops.prefetch_queue_pop(prefetch_queue, x.device, None)
         # embeddings
         x = self.patch_embedding(x.float()).to(x.dtype)
         grid_sizes = x.shape[2:]
@@ -569,6 +571,7 @@ class WanModel(torch.nn.Module):
         patches_replace = transformer_options.get("patches_replace", {})
         blocks_replace = patches_replace.get("dit", {})
         for i, block in enumerate(self.blocks):
+            comfy.ops.prefetch_queue_pop(prefetch_queue, x.device, block)
             if ("double_block", i) in blocks_replace:
                 def block_wrap(args):
                     out = {}
@@ -578,6 +581,7 @@ class WanModel(torch.nn.Module):
                 x = out["img"]
             else:
                 x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len, transformer_options=transformer_options)
+        comfy.ops.prefetch_queue_pop(prefetch_queue, x.device, block)
         # head
         x = self.head(x, e)

View File

@@ -22,7 +22,6 @@ import comfy.model_management
 from comfy.cli_args import args, PerformanceFeature
 import comfy.float
 import comfy.rmsnorm
-import contextlib

 def run_every_op():
     if torch.compiler.is_compiling():
@@ -72,6 +71,93 @@ def cast_to_input(weight, input, non_blocking=False, copy=True):
     return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy)

+def cast_prefetch_all(module, device):
+    if not comfy.model_management.device_supports_non_blocking(device):
+        # Adios! Prefetching works against you if you can't get the CPU past it.
+        return None
+    offload_stream = None
+    for n, m in module.named_modules():
+        if hasattr(m, "comfy_cast_weights"):
+            if m.weight is not None and m.weight.device != device and not hasattr(m, "weight_prefetch"):
+                if offload_stream is None:
+                    offload_stream = comfy.model_management.get_offload_stream(device)
+                    if offload_stream is None:
+                        return None
+                m.weight_prefetch = comfy.model_management.cast_to(m.weight, None, device, non_blocking=True, copy=True, stream=offload_stream)
+            if m.bias is not None and m.bias.device != device and not hasattr(m, "bias_prefetch"):
+                if offload_stream is None:
+                    offload_stream = comfy.model_management.get_offload_stream(device)
+                    if offload_stream is None:
+                        return None
+                m.bias_prefetch = comfy.model_management.cast_to(m.bias, None, device, non_blocking=True, copy=True, stream=offload_stream)
+    return offload_stream
+
+def uncast_prefetch_all(module):
+    for n, m in module.named_modules():
+        if hasattr(m, "comfy_cast_weights"):
+            if hasattr(m, "weight_prefetch"):
+                delattr(m, "weight_prefetch")
+            if hasattr(m, "bias_prefetch"):
+                delattr(m, "bias_prefetch")
+
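A hedged usage sketch of the two helpers above, assuming a CUDA device and ComfyUI's comfy.ops / comfy.model_management modules (the Linear constructor arguments are illustrative):

import torch
import comfy.ops
import comfy.model_management

device = torch.device("cuda")
# A leaf module flagged with comfy_cast_weights; its weights stay on the CPU.
block = comfy.ops.disable_weight_init.Linear(1024, 1024, device="cpu", dtype=torch.float16)

stream = comfy.ops.cast_prefetch_all(block, device)    # enqueue async H2D copies, or None
if stream is not None:
    # Compute must wait for the copies before using weight_prefetch/bias_prefetch;
    # cast_bias_weight() then picks those tensors up instead of re-copying.
    comfy.model_management.sync_stream(device, stream)
    # ... run the block ...
    comfy.ops.uncast_prefetch_all(block)               # drop the prefetched copies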
+def prefetch_queue_pop(queue, device, module):
+    consumed = queue.pop(0)
+    if consumed is not None:
+        offload_stream, m = consumed
+        # Sync the offload stream with compute so that, by the time it starts
+        # freeing the prefetches, the compute stream has finished with them.
+        if offload_stream is not None:
+            offload_stream.wait_stream(comfy.model_management.current_stream(device))
+        uncast_prefetch_all(m)
+    active = queue[0]
+    if active is not None:
+        offload_stream, m = active
+        assert m == module
+        # Wait for the prefetch to complete before using the data.
+        if offload_stream is not None:
+            comfy.model_management.sync_stream(device, offload_stream)
+    prefetch = queue[1]
+    if prefetch is not None:
+        offload_stream = comfy.ops.cast_prefetch_all(prefetch, device)
+        queue[1] = (offload_stream, prefetch)
+
+def make_prefetch_queue(queue):
+    return [None, None] + queue + [None, None]
+
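The [None, None] padding on both ends is what lets prefetch_queue_pop read queue[0] and queue[1] unconditionally. A worked trace with two blocks b0 and b1, popping exactly as the model code above does (sN is the offload stream returned while copying block N):

make_prefetch_queue([b0, b1])       -> [None, None, b0, b1, None, None]
pop(q, dev, None)   # before loop   -> [None, (s0, b0), b1, None, None]    start copying b0
pop(q, dev, b0)     # iteration 0   -> [(s0, b0), (s1, b1), None, None]    wait for b0, start b1
pop(q, dev, b1)     # iteration 1   -> [(s1, b1), None, None]              free b0, wait for b1
pop(q, dev, b1)     # after loop    -> [None, None]                        free b1

The two leading Nones give the kick-off pop something harmless to consume and activate; the two trailing Nones keep queue[0] and queue[1] valid on the final pop.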
+def move_bias_weight(s, device, offloadable=False):
+    bias_has_function = len(s.bias_function) > 0
+    weight_has_function = len(s.weight_function) > 0
+    if offloadable and (
+            s.weight.device != device or (s.bias is not None and s.bias.device != device) or
+            bias_has_function or weight_has_function):
+        offload_stream = comfy.model_management.get_offload_stream(device)
+    else:
+        offload_stream = None
+    bias = None
+    non_blocking = comfy.model_management.device_supports_non_blocking(device)
+    weight = comfy.model_management.cast_to(s.weight, None, device, non_blocking=non_blocking, copy=weight_has_function, stream=offload_stream)
+    if s.bias is not None:
+        bias = comfy.model_management.cast_to(s.bias, None, device, non_blocking=non_blocking, copy=bias_has_function, stream=offload_stream)
+    comfy.model_management.sync_stream(device, offload_stream)
+    return weight, bias, offload_stream
+
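For readers unfamiliar with the stream discipline move_bias_weight depends on, here is the same producer/consumer pattern in plain PyTorch (illustrative, not ComfyUI API; assumes CUDA, and pinned host memory, which is what makes non_blocking copies truly asynchronous):

import torch

device = torch.device("cuda")
copy_stream = torch.cuda.Stream(device)

w_cpu = torch.randn(4096, 4096, pin_memory=True)   # pinned => real async H2D copy
with torch.cuda.stream(copy_stream):
    w_gpu = w_cpu.to(device, non_blocking=True)    # copy is enqueued on copy_stream

# The compute stream must wait on the copy stream before touching w_gpu;
# sync_stream(device, offload_stream) plays this role above.
torch.cuda.current_stream(device).wait_stream(copy_stream)
y = w_gpu.sum()                                    # ordered after the copy

The reverse edge also matters: before prefetched copies are freed, the offload stream waits on compute, which is the offload_stream.wait_stream(...) call in prefetch_queue_pop.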
 def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, offloadable=False):
     # NOTE: offloadable=False is legacy; if you are a custom node author reading this, please pass
     # offloadable=True and call uncast_bias_weight() after your last usage of the weight/bias. This
@@ -87,40 +173,30 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
     if device is None:
         device = input.device
-    if offloadable and (device != s.weight.device or
-                        (s.bias is not None and device != s.bias.device)):
-        offload_stream = comfy.model_management.get_offload_stream(device)
-    else:
-        offload_stream = None
-    if offload_stream is not None:
-        wf_context = offload_stream
-    else:
-        wf_context = contextlib.nullcontext()
-    non_blocking = comfy.model_management.device_supports_non_blocking(device)
-    weight_has_function = len(s.weight_function) > 0
     bias_has_function = len(s.bias_function) > 0
+    weight_has_function = len(s.weight_function) > 0
-    weight = comfy.model_management.cast_to(s.weight, None, device, non_blocking=non_blocking, copy=weight_has_function, stream=offload_stream)
+    if hasattr(s, "weight_prefetch") or hasattr(s, "bias_prefetch"):
+        weight = getattr(s, "weight_prefetch", None)
+        bias = getattr(s, "bias_prefetch", None)
+        offload_stream = None
+    else:
+        weight, bias, offload_stream = move_bias_weight(s, device, offloadable=offloadable)
-    bias = None
-    if s.bias is not None:
-        bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=bias_has_function, stream=offload_stream)
+    if weight_has_function:
+        weight = weight.to(dtype=dtype)
+        for f in s.weight_function:
+            weight = f(weight)
-    if bias_has_function:
-        with wf_context:
-            for f in s.bias_function:
-                bias = f(bias)
+    if s.bias is not None and bias_has_function:
+        bias = bias.to(dtype=bias_dtype)
+        for f in s.bias_function:
+            bias = f(bias)
-    if weight_has_function or weight.dtype != dtype:
-        with wf_context:
-            weight = weight.to(dtype=dtype)
-            for f in s.weight_function:
-                weight = f(weight)
+    weight = weight.to(dtype=dtype)
+    if bias is not None:
+        bias = bias.to(dtype=bias_dtype)
-    comfy.model_management.sync_stream(device, offload_stream)
     if offloadable:
         return weight, bias, offload_stream
     else:
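A sketch of the offloadable=True convention the NOTE above asks custom node authors to adopt. uncast_bias_weight() is referenced by that comment but not shown in this diff, so treat its exact signature here as an assumption and check comfy/ops.py:

class MyLinear(comfy.ops.disable_weight_init.Linear):
    def forward(self, input):
        weight, bias, offload_stream = comfy.ops.cast_bias_weight(self, input, offloadable=True)
        out = torch.nn.functional.linear(input, weight, bias)
        # Release the casted/prefetched copies after the last use of weight/bias.
        comfy.ops.uncast_bias_weight(self, weight, bias, offload_stream)  # signature assumed
        return out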