Add flipflop support to (base) WAN; fix LoRA application for flipflop weights being done on the CPU instead of the GPU; leave in some timing functions, as the LoRA application time could still use some reduction

Jedrzej Kosinski 2025-10-02 22:02:50 -07:00
parent 831c3cf05e
commit ee01002e63
4 changed files with 39 additions and 22 deletions
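FlipFlopHolder and the block-streaming machinery are not part of this diff, so the following is only a rough sketch of the flip/flop idea as the names and the CUDA-stream requirement in supports_flipflop suggest it: two GPU-resident block buffers alternate, one computing while a side stream refills the other with an upcoming block's weights. Every name here (flipflop_run, slots, etc.) is illustrative, not the real API.

import copy
import torch
import torch.nn as nn

def flipflop_run(blocks: nn.ModuleList, x: torch.Tensor, load_device: torch.device) -> torch.Tensor:
    # Sketch only. Assumes all blocks share one architecture, len(blocks) >= 2,
    # and CPU weights in pinned memory so the async copies actually overlap.
    copy_stream = torch.cuda.Stream(device=load_device)
    # Two GPU-resident slots ("flip" and "flop"), primed with blocks 0 and 1.
    slots = [copy.deepcopy(blocks[0]).to(load_device),
             copy.deepcopy(blocks[1]).to(load_device)]
    copy_done = [torch.cuda.Event(), torch.cuda.Event()]
    compute_done = [torch.cuda.Event(), torch.cuda.Event()]
    for e in copy_done + compute_done:
        e.record()  # both slots start out ready
    main = torch.cuda.current_stream(load_device)
    for i in range(len(blocks)):
        slot = i % 2
        main.wait_event(copy_done[slot])   # block i's weights are in place
        x = slots[slot](x)
        compute_done[slot].record(main)
        nxt = i + 2                        # the next block destined for this slot
        if nxt < len(blocks):
            with torch.cuda.stream(copy_stream):
                # Don't clobber weights still in use by the compute just issued.
                copy_stream.wait_event(compute_done[slot])
                for dst, src in zip(slots[slot].parameters(), blocks[nxt].parameters()):
                    dst.data.copy_(src, non_blocking=True)
                copy_done[slot].record(copy_stream)
    main.wait_stream(copy_stream)
    return x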

View File

@@ -6,9 +6,10 @@ import comfy.model_management

 class FlipFlopModule(torch.nn.Module):
-    def __init__(self, block_types: tuple[str, ...]):
+    def __init__(self, block_types: tuple[str, ...], enable_flipflop: bool = True):
         super().__init__()
         self.block_types = block_types
+        self.enable_flipflop = enable_flipflop
         self.flipflop: dict[str, FlipFlopHolder] = {}

     def setup_flipflop_holders(self, block_info: dict[str, tuple[int, int]], load_device: torch.device, offload_device: torch.device):
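For reference, this is the opt-in pattern the WanModel changes below follow; ToyDiT is a hypothetical stand-in model, not part of the commit:

import torch.nn as nn
from comfy.ldm.flipflop_transformer import FlipFlopModule

class ToyDiT(FlipFlopModule):  # hypothetical model, for illustration only
    def __init__(self, num_layers: int = 4, dim: int = 64, enable_flipflop: bool = True):
        # "blocks" names the ModuleList attribute that flipflop may stream;
        # subclasses that are not flipflop-ready pass enable_flipflop=False.
        super().__init__(block_types=("blocks",), enable_flipflop=enable_flipflop)
        self.blocks = nn.ModuleList(nn.Linear(dim, dim) for _ in range(num_layers))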

View File

@@ -7,6 +7,7 @@ import torch.nn as nn
 from einops import rearrange

 from comfy.ldm.modules.attention import optimized_attention
+from comfy.ldm.flipflop_transformer import FlipFlopModule
 from comfy.ldm.flux.layers import EmbedND
 from comfy.ldm.flux.math import apply_rope1
 import comfy.ldm.common_dit
@@ -384,7 +385,7 @@ class MLPProj(torch.nn.Module):
         return clip_extra_context_tokens


-class WanModel(torch.nn.Module):
+class WanModel(FlipFlopModule):
     r"""
     Wan diffusion backbone supporting both text-to-video and image-to-video.
     """
@@ -412,6 +413,7 @@ class WanModel(torch.nn.Module):
                  device=None,
                  dtype=None,
                  operations=None,
+                 enable_flipflop=True,
                  ):
         r"""
         Initialize the diffusion model backbone.
@@ -449,7 +451,7 @@ class WanModel(torch.nn.Module):
             Epsilon value for normalization layers
         """

-        super().__init__()
+        super().__init__(block_types=("blocks",), enable_flipflop=enable_flipflop)
         self.dtype = dtype
         operation_settings = {"operations": operations, "device": device, "dtype": dtype}
@@ -506,6 +508,18 @@ class WanModel(torch.nn.Module):
         else:
             self.ref_conv = None

+    def indiv_block_fwd(self, i, block, x, e0, freqs, context, context_img_len, blocks_replace, transformer_options):
+        if ("double_block", i) in blocks_replace:
+            def block_wrap(args):
+                out = {}
+                out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len, transformer_options=args["transformer_options"])
+                return out
+            out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap})
+            x = out["img"]
+        else:
+            x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len, transformer_options=transformer_options)
+        return x
+
     def forward_orig(
         self,
         x,
@@ -567,16 +581,8 @@ class WanModel(torch.nn.Module):
         patches_replace = transformer_options.get("patches_replace", {})
         blocks_replace = patches_replace.get("dit", {})

-        for i, block in enumerate(self.blocks):
-            if ("double_block", i) in blocks_replace:
-                def block_wrap(args):
-                    out = {}
-                    out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len, transformer_options=args["transformer_options"])
-                    return out
-                out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap})
-                x = out["img"]
-            else:
-                x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len, transformer_options=transformer_options)
+        # execute blocks
+        x = self.execute_blocks("blocks", self.indiv_block_fwd, x, e0, freqs, context, context_img_len, blocks_replace, transformer_options)

         # head
         x = self.head(x, e)
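execute_blocks is defined on FlipFlopModule and is not shown in this diff. Assuming it degrades to the removed loop when flipflop is inactive (which the unchanged blocks_replace plumbing implies), its fallback path would look roughly like:

def execute_blocks(self, block_type: str, fn, x, *args):
    # Assumed fallback semantics (not shown in this diff): with flipflop
    # disabled, this is exactly the loop removed above, with
    # fn = WanModel.indiv_block_fwd doing the per-block work.
    blocks = getattr(self, block_type)
    if not self.enable_flipflop or block_type not in self.flipflop:
        for i, block in enumerate(blocks):
            x = fn(i, block, x, *args)
        return x
    # Otherwise the FlipFlopHolder streams weights through two GPU-resident
    # block copies while computation proceeds (see the sketch above).
    raise NotImplementedError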
@@ -688,7 +694,7 @@ class VaceWanModel(WanModel):
                  operations=None,
                  ):
-        super().__init__(model_type='t2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, image_model=image_model, device=device, dtype=dtype, operations=operations)
+        super().__init__(model_type='t2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, image_model=image_model, device=device, dtype=dtype, operations=operations, enable_flipflop=False)
         operation_settings = {"operations": operations, "device": device, "dtype": dtype}

         # Vace
@@ -808,7 +814,7 @@ class CameraWanModel(WanModel):
         else:
             model_type = 't2v'
-        super().__init__(model_type=model_type, patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, image_model=image_model, device=device, dtype=dtype, operations=operations)
+        super().__init__(model_type=model_type, patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, image_model=image_model, device=device, dtype=dtype, operations=operations, enable_flipflop=False)
         operation_settings = {"operations": operations, "device": device, "dtype": dtype}

         self.control_adapter = WanCamAdapter(in_dim_control_adapter, dim, kernel_size=patch_size[1:], stride=patch_size[1:], operation_settings=operation_settings)
@@ -1211,7 +1217,7 @@ class WanModel_S2V(WanModel):
                  operations=None,
                  ):
-        super().__init__(model_type='t2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, image_model=image_model, device=device, dtype=dtype, operations=operations)
+        super().__init__(model_type='t2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, image_model=image_model, device=device, dtype=dtype, operations=operations, enable_flipflop=False)

         self.trainable_cond_mask = operations.Embedding(3, self.dim, device=device, dtype=dtype)
@@ -1511,7 +1517,7 @@ class HumoWanModel(WanModel):
                  operations=None,
                  ):
-        super().__init__(model_type='t2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, wan_attn_block_class=WanAttentionBlockAudio, image_model=image_model, device=device, dtype=dtype, operations=operations)
+        super().__init__(model_type='t2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, wan_attn_block_class=WanAttentionBlockAudio, image_model=image_model, device=device, dtype=dtype, operations=operations, enable_flipflop=False)

         self.audio_proj = AudioProjModel(seq_len=8, blocks=5, channels=1280, intermediate_dim=512, output_dim=1536, context_tokens=audio_token_num, dtype=dtype, device=device, operations=operations)

View File

@@ -426,7 +426,7 @@ class AnimateWanModel(WanModel):
                  operations=None,
                  ):
-        super().__init__(model_type='i2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, image_model=image_model, device=device, dtype=dtype, operations=operations)
+        super().__init__(model_type='i2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, image_model=image_model, device=device, dtype=dtype, operations=operations, enable_flipflop=False)

         self.pose_patch_embedding = operations.Conv3d(
             16, dim, kernel_size=patch_size, stride=patch_size, device=device, dtype=dtype

View File

@@ -25,7 +25,7 @@ import logging
 import math
 import uuid
 from typing import Callable, Optional
+import time # TODO remove

 import torch
 import comfy.float
@@ -577,7 +577,7 @@ class ModelPatcher:
                 sd.pop(k)
         return sd

-    def patch_weight_to_device(self, key, device_to=None, inplace_update=False):
+    def patch_weight_to_device(self, key, device_to=None, inplace_update=False, device_final=None):
         if key not in self.patches:
             return
@@ -597,18 +597,22 @@ class ModelPatcher:
         out_weight = comfy.lora.calculate_weight(self.patches[key], temp_weight, key)
         if set_func is None:
             out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype, seed=string_to_seed(key))
+            if device_final is not None:
+                out_weight = out_weight.to(device_final)
             if inplace_update:
                 comfy.utils.copy_to_param(self.model, key, out_weight)
             else:
                 comfy.utils.set_attr_param(self.model, key, out_weight)
         else:
+            if device_final is not None:
+                out_weight = out_weight.to(device_final)
             set_func(out_weight, inplace_update=inplace_update, seed=string_to_seed(key))

     def supports_flipflop(self):
         # flipflop requires diffusion_model, explicit flipflop support, NVIDIA CUDA streams, and loading/offloading VRAM
         if not hasattr(self.model, "diffusion_model"):
             return False
-        if not hasattr(self.model.diffusion_model, "flipflop"):
+        if not getattr(self.model.diffusion_model, "enable_flipflop", False):
            return False
         if not comfy.model_management.is_nvidia():
             return False
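The new device_final argument is what fixes the CPU LoRA issue in the flipflop load path below: comfy.lora.calculate_weight now runs on device_to (the GPU), and only the finished tensor is moved to the offload device. The same shape of fix, reduced to plain tensors, with apply_lora as a stand-in for calculate_weight:

import torch

def apply_lora(weight: torch.Tensor, down: torch.Tensor, up: torch.Tensor) -> torch.Tensor:
    # Stand-in for comfy.lora.calculate_weight: W + up @ down.
    return weight + up @ down

def patch_weight(weight, down, up, device_to, device_final=None):
    # The matmul runs on device_to; previously the weight was patched directly
    # on the offload device, so this math happened on the CPU.
    out = apply_lora(weight.to(device_to), down.to(device_to), up.to(device_to))
    return out.to(device_final) if device_final is not None else out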
@@ -797,6 +801,7 @@ class ModelPatcher:

         # handle flipflop
         if len(load_flipflop) > 0:
+            start_time = time.perf_counter()
             load_flipflop.sort(reverse=True)
             for x in load_flipflop:
                 n = x[1]
@@ -806,10 +811,15 @@ class ModelPatcher:
                 if m.comfy_patched_weights == True:
                     continue
                 for param in params:
-                    self.patch_weight_to_device("{}.{}".format(n, param), device_to=self.offload_device)
+                    self.patch_weight_to_device("{}.{}".format(n, param), device_to=device_to, device_final=self.offload_device)
                 logging.debug("lowvram: loaded module for flipflop {} {}".format(n, m))
+            end_time = time.perf_counter()
+            logging.info(f"flipflop load time: {end_time - start_time:.2f} seconds")
+            start_time = time.perf_counter()
             self.init_flipflop_block_copies()
+            end_time = time.perf_counter()
+            logging.info(f"flipflop block init time: {end_time - start_time:.2f} seconds")

         if lowvram_counter > 0 or flipflop_counter > 0:
             if flipflop_counter > 0: