Add flipflop support to (base) WAN, fix issue with applying loras to flipflop weights being done on CPU instead of GPU, left some timing functions as the lora application time could use some reduction

This commit is contained in:
Jedrzej Kosinski 2025-10-02 22:02:50 -07:00
parent 831c3cf05e
commit ee01002e63
4 changed files with 39 additions and 22 deletions

View File

@ -6,9 +6,10 @@ import comfy.model_management
class FlipFlopModule(torch.nn.Module):
def __init__(self, block_types: tuple[str, ...]):
def __init__(self, block_types: tuple[str, ...], enable_flipflop: bool = True):
super().__init__()
self.block_types = block_types
self.enable_flipflop = enable_flipflop
self.flipflop: dict[str, FlipFlopHolder] = {}
def setup_flipflop_holders(self, block_info: dict[str, tuple[int, int]], load_device: torch.device, offload_device: torch.device):

View File

@ -7,6 +7,7 @@ import torch.nn as nn
from einops import rearrange
from comfy.ldm.modules.attention import optimized_attention
from comfy.ldm.flipflop_transformer import FlipFlopModule
from comfy.ldm.flux.layers import EmbedND
from comfy.ldm.flux.math import apply_rope1
import comfy.ldm.common_dit
@ -384,7 +385,7 @@ class MLPProj(torch.nn.Module):
return clip_extra_context_tokens
class WanModel(torch.nn.Module):
class WanModel(FlipFlopModule):
r"""
Wan diffusion backbone supporting both text-to-video and image-to-video.
"""
@ -412,6 +413,7 @@ class WanModel(torch.nn.Module):
device=None,
dtype=None,
operations=None,
enable_flipflop=True,
):
r"""
Initialize the diffusion model backbone.
@ -449,7 +451,7 @@ class WanModel(torch.nn.Module):
Epsilon value for normalization layers
"""
super().__init__()
super().__init__(block_types=("blocks",), enable_flipflop=enable_flipflop)
self.dtype = dtype
operation_settings = {"operations": operations, "device": device, "dtype": dtype}
@ -506,6 +508,18 @@ class WanModel(torch.nn.Module):
else:
self.ref_conv = None
def indiv_block_fwd(self, i, block, x, e0, freqs, context, context_img_len, blocks_replace, transformer_options):
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len, transformer_options=args["transformer_options"])
return out
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap})
x = out["img"]
else:
x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len, transformer_options=transformer_options)
return x
def forward_orig(
self,
x,
@ -567,16 +581,8 @@ class WanModel(torch.nn.Module):
patches_replace = transformer_options.get("patches_replace", {})
blocks_replace = patches_replace.get("dit", {})
for i, block in enumerate(self.blocks):
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len, transformer_options=args["transformer_options"])
return out
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap})
x = out["img"]
else:
x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len, transformer_options=transformer_options)
# execute blocks
x = self.execute_blocks("blocks", self.indiv_block_fwd, x, e0, freqs, context, context_img_len, blocks_replace, transformer_options)
# head
x = self.head(x, e)
@ -688,7 +694,7 @@ class VaceWanModel(WanModel):
operations=None,
):
super().__init__(model_type='t2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, image_model=image_model, device=device, dtype=dtype, operations=operations)
super().__init__(model_type='t2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, image_model=image_model, device=device, dtype=dtype, operations=operations, enable_flipflop=False)
operation_settings = {"operations": operations, "device": device, "dtype": dtype}
# Vace
@ -808,7 +814,7 @@ class CameraWanModel(WanModel):
else:
model_type = 't2v'
super().__init__(model_type=model_type, patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, image_model=image_model, device=device, dtype=dtype, operations=operations)
super().__init__(model_type=model_type, patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, image_model=image_model, device=device, dtype=dtype, operations=operations, enable_flipflop=False)
operation_settings = {"operations": operations, "device": device, "dtype": dtype}
self.control_adapter = WanCamAdapter(in_dim_control_adapter, dim, kernel_size=patch_size[1:], stride=patch_size[1:], operation_settings=operation_settings)
@ -1211,7 +1217,7 @@ class WanModel_S2V(WanModel):
operations=None,
):
super().__init__(model_type='t2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, image_model=image_model, device=device, dtype=dtype, operations=operations)
super().__init__(model_type='t2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, image_model=image_model, device=device, dtype=dtype, operations=operations, enable_flipflop=False)
self.trainable_cond_mask = operations.Embedding(3, self.dim, device=device, dtype=dtype)
@ -1511,7 +1517,7 @@ class HumoWanModel(WanModel):
operations=None,
):
super().__init__(model_type='t2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, wan_attn_block_class=WanAttentionBlockAudio, image_model=image_model, device=device, dtype=dtype, operations=operations)
super().__init__(model_type='t2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, wan_attn_block_class=WanAttentionBlockAudio, image_model=image_model, device=device, dtype=dtype, operations=operations, enable_flipflop=False)
self.audio_proj = AudioProjModel(seq_len=8, blocks=5, channels=1280, intermediate_dim=512, output_dim=1536, context_tokens=audio_token_num, dtype=dtype, device=device, operations=operations)

View File

@ -426,7 +426,7 @@ class AnimateWanModel(WanModel):
operations=None,
):
super().__init__(model_type='i2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, image_model=image_model, device=device, dtype=dtype, operations=operations)
super().__init__(model_type='i2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, image_model=image_model, device=device, dtype=dtype, operations=operations, enable_flipflop=False)
self.pose_patch_embedding = operations.Conv3d(
16, dim, kernel_size=patch_size, stride=patch_size, device=device, dtype=dtype

View File

@ -25,7 +25,7 @@ import logging
import math
import uuid
from typing import Callable, Optional
import time # TODO remove
import torch
import comfy.float
@ -577,7 +577,7 @@ class ModelPatcher:
sd.pop(k)
return sd
def patch_weight_to_device(self, key, device_to=None, inplace_update=False):
def patch_weight_to_device(self, key, device_to=None, inplace_update=False, device_final=None):
if key not in self.patches:
return
@ -597,18 +597,22 @@ class ModelPatcher:
out_weight = comfy.lora.calculate_weight(self.patches[key], temp_weight, key)
if set_func is None:
out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype, seed=string_to_seed(key))
if device_final is not None:
out_weight = out_weight.to(device_final)
if inplace_update:
comfy.utils.copy_to_param(self.model, key, out_weight)
else:
comfy.utils.set_attr_param(self.model, key, out_weight)
else:
if device_final is not None:
out_weight = out_weight.to(device_final)
set_func(out_weight, inplace_update=inplace_update, seed=string_to_seed(key))
def supports_flipflop(self):
# flipflop requires diffusion_model, explicit flipflop support, NVIDIA CUDA streams, and loading/offloading VRAM
if not hasattr(self.model, "diffusion_model"):
return False
if not hasattr(self.model.diffusion_model, "flipflop"):
if not getattr(self.model.diffusion_model, "enable_flipflop", False):
return False
if not comfy.model_management.is_nvidia():
return False
@ -797,6 +801,7 @@ class ModelPatcher:
# handle flipflop
if len(load_flipflop) > 0:
start_time = time.perf_counter()
load_flipflop.sort(reverse=True)
for x in load_flipflop:
n = x[1]
@ -806,10 +811,15 @@ class ModelPatcher:
if m.comfy_patched_weights == True:
continue
for param in params:
self.patch_weight_to_device("{}.{}".format(n, param), device_to=self.offload_device)
self.patch_weight_to_device("{}.{}".format(n, param), device_to=device_to, device_final=self.offload_device)
logging.debug("lowvram: loaded module for flipflop {} {}".format(n, m))
end_time = time.perf_counter()
logging.info(f"flipflop load time: {end_time - start_time:.2f} seconds")
start_time = time.perf_counter()
self.init_flipflop_block_copies()
end_time = time.perf_counter()
logging.info(f"flipflop block init time: {end_time - start_time:.2f} seconds")
if lowvram_counter > 0 or flipflop_counter > 0:
if flipflop_counter > 0: