diff --git a/comfy/ldm/qwen_image/model.py b/comfy/ldm/qwen_image/model.py
index 427ea19c1..4579e2e38 100644
--- a/comfy/ldm/qwen_image/model.py
+++ b/comfy/ldm/qwen_image/model.py
@@ -391,6 +391,9 @@ class QwenImageTransformer2DModel(nn.Module):
         hidden_states, img_ids, orig_shape = self.process_img(x)
         num_embeds = hidden_states.shape[1]

+        prefetch_queue = comfy.ops.make_prefetch_queue(list(self.transformer_blocks))
+        comfy.ops.prefetch_queue_pop(prefetch_queue, x.device, None)
+
         if ref_latents is not None:
             h = 0
             w = 0
@@ -440,6 +443,7 @@ class QwenImageTransformer2DModel(nn.Module):
         blocks_replace = patches_replace.get("dit", {})

         for i, block in enumerate(self.transformer_blocks):
+            comfy.ops.prefetch_queue_pop(prefetch_queue, x.device, block)
             if ("double_block", i) in blocks_replace:
                 def block_wrap(args):
                     out = {}
@@ -471,6 +475,8 @@ class QwenImageTransformer2DModel(nn.Module):
             if add is not None:
                 hidden_states[:, :add.shape[1]] += add

+        comfy.ops.prefetch_queue_pop(prefetch_queue, x.device, block)
+
         hidden_states = self.norm_out(hidden_states, temb)
         hidden_states = self.proj_out(hidden_states)

diff --git a/comfy/ldm/wan/model.py b/comfy/ldm/wan/model.py
index a9d5e10d9..fa2e90c9f 100644
--- a/comfy/ldm/wan/model.py
+++ b/comfy/ldm/wan/model.py
@@ -538,6 +538,8 @@ class WanModel(torch.nn.Module):
             List[Tensor]:
                 List of denoised video tensors with original input shapes [C_out, F, H / 8, W / 8]
         """
+        prefetch_queue = comfy.ops.make_prefetch_queue(list(self.blocks))
+        comfy.ops.prefetch_queue_pop(prefetch_queue, x.device, None)
         # embeddings
         x = self.patch_embedding(x.float()).to(x.dtype)
         grid_sizes = x.shape[2:]
@@ -569,6 +571,7 @@ class WanModel(torch.nn.Module):
         patches_replace = transformer_options.get("patches_replace", {})
         blocks_replace = patches_replace.get("dit", {})
         for i, block in enumerate(self.blocks):
+            comfy.ops.prefetch_queue_pop(prefetch_queue, x.device, block)
             if ("double_block", i) in blocks_replace:
                 def block_wrap(args):
                     out = {}
@@ -578,6 +581,7 @@ class WanModel(torch.nn.Module):
                 x = out["img"]
             else:
                 x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len, transformer_options=transformer_options)
+        comfy.ops.prefetch_queue_pop(prefetch_queue, x.device, block)

         # head
         x = self.head(x, e)
diff --git a/comfy/ops.py b/comfy/ops.py
index 640622fd1..0b8dfb595 100644
--- a/comfy/ops.py
+++ b/comfy/ops.py
@@ -22,7 +22,6 @@ import comfy.model_management
 from comfy.cli_args import args, PerformanceFeature
 import comfy.float
 import comfy.rmsnorm
-import contextlib

 def run_every_op():
     if torch.compiler.is_compiling():
@@ -72,6 +71,93 @@ def cast_to_input(weight, input, non_blocking=False, copy=True):
     return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy)


+def cast_prefetch_all(module, device):
+    if not comfy.model_management.device_supports_non_blocking(device):
+        #Adios! Prefetching works against you if the CPU can't run ahead of the device.
+        return None
+
+    offload_stream = None
+
+    for n, m in module.named_modules():
+        if hasattr(m, "comfy_cast_weights"):
+            if m.weight is not None and m.weight.device != device and not hasattr(m, "weight_prefetch"):
+                if offload_stream is None:
+                    offload_stream = comfy.model_management.get_offload_stream(device)
+                    if offload_stream is None:
+                        return None
+                m.weight_prefetch = comfy.model_management.cast_to(m.weight, None, device, non_blocking=True, copy=True, stream=offload_stream)
+            if m.bias is not None and m.bias.device != device and not hasattr(m, "bias_prefetch"):
+                if offload_stream is None:
+                    offload_stream = comfy.model_management.get_offload_stream(device)
+                    if offload_stream is None:
+                        return None
+                m.bias_prefetch = comfy.model_management.cast_to(m.bias, None, device, non_blocking=True, copy=True, stream=offload_stream)
+
+    return offload_stream
+
+
+def uncast_prefetch_all(module):
+    for n, m in module.named_modules():
+        if hasattr(m, "comfy_cast_weights"):
+            if hasattr(m, "weight_prefetch"):
+                delattr(m, "weight_prefetch")
+            if hasattr(m, "bias_prefetch"):
+                delattr(m, "bias_prefetch")
+
+
+def prefetch_queue_pop(queue, device, module):
+    consumed = queue.pop(0)
+    if consumed is not None:
+        offload_stream, m = consumed
+        #Sync the offload stream with compute so that by the time it
+        #starts freeing the prefetches the compute stream is done with them.
+        if offload_stream is not None:
+            offload_stream.wait_stream(comfy.model_management.current_stream(device))
+        uncast_prefetch_all(m)
+
+    active = queue[0]
+    if active is not None:
+        offload_stream, m = active
+        assert m == module
+        #Wait for the prefetch to complete before using the data.
+        if offload_stream is not None:
+            comfy.model_management.sync_stream(device, offload_stream)
+
+    prefetch = queue[1]
+    if prefetch is not None:
+        offload_stream = comfy.ops.cast_prefetch_all(prefetch, device)
+        queue[1] = (offload_stream, prefetch)
+
+
+def make_prefetch_queue(queue):
+    return [None, None] + queue + [None, None]
+
+
+def move_bias_weight(s, device, offloadable=False):
+    #Move this op's weight/bias to device, through the offload stream when one is available.
+    bias_has_function = len(s.bias_function) > 0
+    weight_has_function = len(s.weight_function) > 0
+
+    if offloadable and (
+            s.weight.device != device or (s.bias is not None and s.bias.device != device) or
+            bias_has_function or weight_has_function):
+        offload_stream = comfy.model_management.get_offload_stream(device)
+    else:
+        offload_stream = None
+
+    bias = None
+    non_blocking = comfy.model_management.device_supports_non_blocking(device)
+
+    weight = comfy.model_management.cast_to(s.weight, None, device, non_blocking=non_blocking, copy=weight_has_function, stream=offload_stream)
+
+    if s.bias is not None:
+        bias = comfy.model_management.cast_to(s.bias, None, device, non_blocking=non_blocking, copy=bias_has_function, stream=offload_stream)
+
+    comfy.model_management.sync_stream(device, offload_stream)
+
+    return weight, bias, offload_stream
+
+
 def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, offloadable=False):
     # NOTE: offloadable=False is a legacy and if you are a custom node author reading this please pass
     # offloadable=True and call uncast_bias_weight() after your last usage of the weight/bias. This
@@ -87,40 +173,30 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, offloadable=False):
     if device is None:
         device = input.device

-    if offloadable and (device != s.weight.device or
-                        (s.bias is not None and device != s.bias.device)):
-        offload_stream = comfy.model_management.get_offload_stream(device)
-    else:
-        offload_stream = None
-
-    if offload_stream is not None:
-        wf_context = offload_stream
-    else:
-        wf_context = contextlib.nullcontext()
-
-    non_blocking = comfy.model_management.device_supports_non_blocking(device)
-
-    weight_has_function = len(s.weight_function) > 0
     bias_has_function = len(s.bias_function) > 0
+    weight_has_function = len(s.weight_function) > 0

-    weight = comfy.model_management.cast_to(s.weight, None, device, non_blocking=non_blocking, copy=weight_has_function, stream=offload_stream)
+    if hasattr(s, "weight_prefetch") or hasattr(s, "bias_prefetch"):
+        weight = getattr(s, "weight_prefetch", None)
+        bias = getattr(s, "bias_prefetch", None)
+        offload_stream = None
+    else:
+        weight, bias, offload_stream = move_bias_weight(s, device, offloadable=offloadable)

-    bias = None
-    if s.bias is not None:
-        bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=bias_has_function, stream=offload_stream)
+    if weight_has_function:
+        weight = weight.to(dtype=dtype)
+        for f in s.weight_function:
+            weight = f(weight)

-    if bias_has_function:
-        with wf_context:
-            for f in s.bias_function:
-                bias = f(bias)
+    if s.bias is not None and bias_has_function:
+        bias = bias.to(dtype=bias_dtype)
+        for f in s.bias_function:
+            bias = f(bias)

-    if weight_has_function or weight.dtype != dtype:
-        with wf_context:
-            weight = weight.to(dtype=dtype)
-            for f in s.weight_function:
-                weight = f(weight)
+    weight = weight.to(dtype=dtype)
+    if bias is not None:
+        bias = bias.to(dtype=bias_dtype)

-    comfy.model_management.sync_stream(device, offload_stream)
     if offloadable:
         return weight, bias, offload_stream
     else:
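
---

How the prefetch queue works: each pop retires the entry at the front of the queue (the block that ran since the previous pop, whose prefetched copies can now be freed), waits on the prefetch for queue[0] (the block about to run), and starts the copy for queue[1] (the block after that). The [None, None] padding on both ends lets the same function serve as the warm-up call before the loop and the drain call after it. Below is a minimal runnable sketch of just the queue discipline; the stream synchronization and the (offload_stream, module) tuples of the real code are reduced to prints:

    def make_prefetch_queue(blocks):
        #Two leading Nones: the warm-up pop frees and runs nothing, it only starts a prefetch.
        #Two trailing Nones: the final pops have nothing left to prefetch, only copies to free.
        return [None, None] + list(blocks) + [None, None]

    def prefetch_queue_pop(queue, block):
        consumed = queue.pop(0)  #ran since the previous pop: free its device copies
        if consumed is not None:
            print("free ", consumed)
        active = queue[0]        #about to run now: wait for its prefetch to land
        if active is not None:
            assert active == block
            print("use  ", active)
        prefetch = queue[1]      #runs after that: start copying it now
        if prefetch is not None:
            print("fetch", prefetch)

    blocks = ["block0", "block1", "block2"]
    queue = make_prefetch_queue(blocks)
    prefetch_queue_pop(queue, None)       #warm-up: starts block0's copy
    for block in blocks:
        prefetch_queue_pop(queue, block)  #wait for this block, prefetch the next
        #... the block itself would run here ...
    prefetch_queue_pop(queue, block)      #drain: frees the last block's copies

This is why both forward passes issue one pop before the loop, one pop per block, and one extra pop after the loop that reuses the loop's final block; at steady state at most two blocks' prefetched weights are resident at a time.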
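On the comfy/ops.py side, cast_bias_weight now has two paths: when prefetch_queue_pop has already attached weight_prefetch/bias_prefetch to an op, the resident copies are used directly and offload_stream comes back as None, so nothing is synced or freed per op and the copies live until the queue retires that block; ops that were not prefetched fall back to move_bias_weight and its per-op sync_stream. For custom node authors, the NOTE in cast_bias_weight amounts to the pattern below. This is a sketch, not part of the patch: PatchedLinear is a hypothetical name, and uncast_bias_weight's signature is not shown in this diff, so it is assumed here to take the triple that cast_bias_weight returns.

    import torch
    import comfy.ops

    class PatchedLinear(comfy.ops.disable_weight_init.Linear):
        def forward_comfy_cast_weights(self, input):
            #Offloadable path: fetch, use, then release after the last use.
            weight, bias, offload_stream = comfy.ops.cast_bias_weight(self, input, offloadable=True)
            out = torch.nn.functional.linear(input, weight, bias)
            comfy.ops.uncast_bias_weight(self, weight, bias, offload_stream)
            return out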