From ee01002e6377531506058dfae9915ac051694feb Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Thu, 2 Oct 2025 22:02:50 -0700 Subject: [PATCH] Add flipflop support to (base) WAN, fix issue with applying loras to flipflop weights being done on CPU instead of GPU, left some timing functions as the lora application time could use some reduction --- comfy/ldm/flipflop_transformer.py | 3 ++- comfy/ldm/wan/model.py | 38 ++++++++++++++++++------------- comfy/ldm/wan/model_animate.py | 2 +- comfy/model_patcher.py | 18 +++++++++++---- 4 files changed, 39 insertions(+), 22 deletions(-) diff --git a/comfy/ldm/flipflop_transformer.py b/comfy/ldm/flipflop_transformer.py index 7bbb08208..9e9c28468 100644 --- a/comfy/ldm/flipflop_transformer.py +++ b/comfy/ldm/flipflop_transformer.py @@ -6,9 +6,10 @@ import comfy.model_management class FlipFlopModule(torch.nn.Module): - def __init__(self, block_types: tuple[str, ...]): + def __init__(self, block_types: tuple[str, ...], enable_flipflop: bool = True): super().__init__() self.block_types = block_types + self.enable_flipflop = enable_flipflop self.flipflop: dict[str, FlipFlopHolder] = {} def setup_flipflop_holders(self, block_info: dict[str, tuple[int, int]], load_device: torch.device, offload_device: torch.device): diff --git a/comfy/ldm/wan/model.py b/comfy/ldm/wan/model.py index 0dc650ced..7f8ca65d7 100644 --- a/comfy/ldm/wan/model.py +++ b/comfy/ldm/wan/model.py @@ -7,6 +7,7 @@ import torch.nn as nn from einops import rearrange from comfy.ldm.modules.attention import optimized_attention +from comfy.ldm.flipflop_transformer import FlipFlopModule from comfy.ldm.flux.layers import EmbedND from comfy.ldm.flux.math import apply_rope1 import comfy.ldm.common_dit @@ -384,7 +385,7 @@ class MLPProj(torch.nn.Module): return clip_extra_context_tokens -class WanModel(torch.nn.Module): +class WanModel(FlipFlopModule): r""" Wan diffusion backbone supporting both text-to-video and image-to-video. """ @@ -412,6 +413,7 @@ class WanModel(torch.nn.Module): device=None, dtype=None, operations=None, + enable_flipflop=True, ): r""" Initialize the diffusion model backbone. @@ -449,7 +451,7 @@ class WanModel(torch.nn.Module): Epsilon value for normalization layers """ - super().__init__() + super().__init__(block_types=("blocks",), enable_flipflop=enable_flipflop) self.dtype = dtype operation_settings = {"operations": operations, "device": device, "dtype": dtype} @@ -506,6 +508,18 @@ class WanModel(torch.nn.Module): else: self.ref_conv = None + def indiv_block_fwd(self, i, block, x, e0, freqs, context, context_img_len, blocks_replace, transformer_options): + if ("double_block", i) in blocks_replace: + def block_wrap(args): + out = {} + out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len, transformer_options=args["transformer_options"]) + return out + out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap}) + x = out["img"] + else: + x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len, transformer_options=transformer_options) + return x + def forward_orig( self, x, @@ -567,16 +581,8 @@ class WanModel(torch.nn.Module): patches_replace = transformer_options.get("patches_replace", {}) blocks_replace = patches_replace.get("dit", {}) - for i, block in enumerate(self.blocks): - if ("double_block", i) in blocks_replace: - def block_wrap(args): - out = {} - out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len, transformer_options=args["transformer_options"]) - return out - out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap}) - x = out["img"] - else: - x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len, transformer_options=transformer_options) + # execute blocks + x = self.execute_blocks("blocks", self.indiv_block_fwd, x, e0, freqs, context, context_img_len, blocks_replace, transformer_options) # head x = self.head(x, e) @@ -688,7 +694,7 @@ class VaceWanModel(WanModel): operations=None, ): - super().__init__(model_type='t2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, image_model=image_model, device=device, dtype=dtype, operations=operations) + super().__init__(model_type='t2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, image_model=image_model, device=device, dtype=dtype, operations=operations, enable_flipflop=False) operation_settings = {"operations": operations, "device": device, "dtype": dtype} # Vace @@ -808,7 +814,7 @@ class CameraWanModel(WanModel): else: model_type = 't2v' - super().__init__(model_type=model_type, patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, image_model=image_model, device=device, dtype=dtype, operations=operations) + super().__init__(model_type=model_type, patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, image_model=image_model, device=device, dtype=dtype, operations=operations, enable_flipflop=False) operation_settings = {"operations": operations, "device": device, "dtype": dtype} self.control_adapter = WanCamAdapter(in_dim_control_adapter, dim, kernel_size=patch_size[1:], stride=patch_size[1:], operation_settings=operation_settings) @@ -1211,7 +1217,7 @@ class WanModel_S2V(WanModel): operations=None, ): - super().__init__(model_type='t2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, image_model=image_model, device=device, dtype=dtype, operations=operations) + super().__init__(model_type='t2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, image_model=image_model, device=device, dtype=dtype, operations=operations, enable_flipflop=False) self.trainable_cond_mask = operations.Embedding(3, self.dim, device=device, dtype=dtype) @@ -1511,7 +1517,7 @@ class HumoWanModel(WanModel): operations=None, ): - super().__init__(model_type='t2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, wan_attn_block_class=WanAttentionBlockAudio, image_model=image_model, device=device, dtype=dtype, operations=operations) + super().__init__(model_type='t2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, wan_attn_block_class=WanAttentionBlockAudio, image_model=image_model, device=device, dtype=dtype, operations=operations, enable_flipflop=False) self.audio_proj = AudioProjModel(seq_len=8, blocks=5, channels=1280, intermediate_dim=512, output_dim=1536, context_tokens=audio_token_num, dtype=dtype, device=device, operations=operations) diff --git a/comfy/ldm/wan/model_animate.py b/comfy/ldm/wan/model_animate.py index 7c87835d4..2c09420c5 100644 --- a/comfy/ldm/wan/model_animate.py +++ b/comfy/ldm/wan/model_animate.py @@ -426,7 +426,7 @@ class AnimateWanModel(WanModel): operations=None, ): - super().__init__(model_type='i2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, image_model=image_model, device=device, dtype=dtype, operations=operations) + super().__init__(model_type='i2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, image_model=image_model, device=device, dtype=dtype, operations=operations, enable_flipflop=False) self.pose_patch_embedding = operations.Conv3d( 16, dim, kernel_size=patch_size, stride=patch_size, device=device, dtype=dtype diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 142b810b3..08055b65c 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -25,7 +25,7 @@ import logging import math import uuid from typing import Callable, Optional - +import time # TODO remove import torch import comfy.float @@ -577,7 +577,7 @@ class ModelPatcher: sd.pop(k) return sd - def patch_weight_to_device(self, key, device_to=None, inplace_update=False): + def patch_weight_to_device(self, key, device_to=None, inplace_update=False, device_final=None): if key not in self.patches: return @@ -597,18 +597,22 @@ class ModelPatcher: out_weight = comfy.lora.calculate_weight(self.patches[key], temp_weight, key) if set_func is None: out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype, seed=string_to_seed(key)) + if device_final is not None: + out_weight = out_weight.to(device_final) if inplace_update: comfy.utils.copy_to_param(self.model, key, out_weight) else: comfy.utils.set_attr_param(self.model, key, out_weight) else: + if device_final is not None: + out_weight = out_weight.to(device_final) set_func(out_weight, inplace_update=inplace_update, seed=string_to_seed(key)) def supports_flipflop(self): # flipflop requires diffusion_model, explicit flipflop support, NVIDIA CUDA streams, and loading/offloading VRAM if not hasattr(self.model, "diffusion_model"): return False - if not hasattr(self.model.diffusion_model, "flipflop"): + if not getattr(self.model.diffusion_model, "enable_flipflop", False): return False if not comfy.model_management.is_nvidia(): return False @@ -797,6 +801,7 @@ class ModelPatcher: # handle flipflop if len(load_flipflop) > 0: + start_time = time.perf_counter() load_flipflop.sort(reverse=True) for x in load_flipflop: n = x[1] @@ -806,10 +811,15 @@ class ModelPatcher: if m.comfy_patched_weights == True: continue for param in params: - self.patch_weight_to_device("{}.{}".format(n, param), device_to=self.offload_device) + self.patch_weight_to_device("{}.{}".format(n, param), device_to=device_to, device_final=self.offload_device) logging.debug("lowvram: loaded module for flipflop {} {}".format(n, m)) + end_time = time.perf_counter() + logging.info(f"flipflop load time: {end_time - start_time:.2f} seconds") + start_time = time.perf_counter() self.init_flipflop_block_copies() + end_time = time.perf_counter() + logging.info(f"flipflop block init time: {end_time - start_time:.2f} seconds") if lowvram_counter > 0 or flipflop_counter > 0: if flipflop_counter > 0: