diff --git a/comfy/ldm/flipflop_transformer.py b/comfy/ldm/flipflop_transformer.py
index 7bbb08208..9e9c28468 100644
--- a/comfy/ldm/flipflop_transformer.py
+++ b/comfy/ldm/flipflop_transformer.py
@@ -6,9 +6,10 @@ import comfy.model_management
 
 
 class FlipFlopModule(torch.nn.Module):
-    def __init__(self, block_types: tuple[str, ...]):
+    def __init__(self, block_types: tuple[str, ...], enable_flipflop: bool = True):
         super().__init__()
         self.block_types = block_types
+        self.enable_flipflop = enable_flipflop
         self.flipflop: dict[str, FlipFlopHolder] = {}
 
     def setup_flipflop_holders(self, block_info: dict[str, tuple[int, int]], load_device: torch.device, offload_device: torch.device):
diff --git a/comfy/ldm/wan/model.py b/comfy/ldm/wan/model.py
index 0dc650ced..7f8ca65d7 100644
--- a/comfy/ldm/wan/model.py
+++ b/comfy/ldm/wan/model.py
@@ -7,6 +7,7 @@ import torch.nn as nn
 from einops import rearrange
 
 from comfy.ldm.modules.attention import optimized_attention
+from comfy.ldm.flipflop_transformer import FlipFlopModule
 from comfy.ldm.flux.layers import EmbedND
 from comfy.ldm.flux.math import apply_rope1
 import comfy.ldm.common_dit
@@ -384,7 +385,7 @@ class MLPProj(torch.nn.Module):
         return clip_extra_context_tokens
 
 
-class WanModel(torch.nn.Module):
+class WanModel(FlipFlopModule):
     r"""
     Wan diffusion backbone supporting both text-to-video and image-to-video.
     """
@@ -412,6 +413,7 @@
                  device=None,
                  dtype=None,
                  operations=None,
+                 enable_flipflop=True,
                  ):
         r"""
         Initialize the diffusion model backbone.
@@ -449,7 +451,7 @@
                 Epsilon value for normalization layers
         """
 
-        super().__init__()
+        super().__init__(block_types=("blocks",), enable_flipflop=enable_flipflop)
 
         self.dtype = dtype
         operation_settings = {"operations": operations, "device": device, "dtype": dtype}
@@ -506,6 +508,18 @@
         else:
             self.ref_conv = None
 
+    def indiv_block_fwd(self, i, block, x, e0, freqs, context, context_img_len, blocks_replace, transformer_options):
+        if ("double_block", i) in blocks_replace:
+            def block_wrap(args):
+                out = {}
+                out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len, transformer_options=args["transformer_options"])
+                return out
+            out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap})
+            x = out["img"]
+        else:
+            x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len, transformer_options=transformer_options)
+        return x
+
     def forward_orig(
         self,
         x,
@@ -567,16 +581,8 @@
 
         patches_replace = transformer_options.get("patches_replace", {})
         blocks_replace = patches_replace.get("dit", {})
-        for i, block in enumerate(self.blocks):
-            if ("double_block", i) in blocks_replace:
-                def block_wrap(args):
-                    out = {}
-                    out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len, transformer_options=args["transformer_options"])
-                    return out
-                out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap})
-                x = out["img"]
-            else:
-                x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len, transformer_options=transformer_options)
+        # execute blocks
+        x = self.execute_blocks("blocks", self.indiv_block_fwd, x, e0, freqs, context, context_img_len, blocks_replace, transformer_options)
 
         # head
         x = self.head(x, e)
@@ -688,7 +694,7 @@ class VaceWanModel(WanModel):
                  operations=None,
                  ):
 
-        super().__init__(model_type='t2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, image_model=image_model, device=device, dtype=dtype, operations=operations)
+        super().__init__(model_type='t2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, image_model=image_model, device=device, dtype=dtype, operations=operations, enable_flipflop=False)
 
         operation_settings = {"operations": operations, "device": device, "dtype": dtype}
         # Vace
@@ -808,7 +814,7 @@ class CameraWanModel(WanModel):
         else:
             model_type = 't2v'
 
-        super().__init__(model_type=model_type, patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, image_model=image_model, device=device, dtype=dtype, operations=operations)
+        super().__init__(model_type=model_type, patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, image_model=image_model, device=device, dtype=dtype, operations=operations, enable_flipflop=False)
         operation_settings = {"operations": operations, "device": device, "dtype": dtype}
 
         self.control_adapter = WanCamAdapter(in_dim_control_adapter, dim, kernel_size=patch_size[1:], stride=patch_size[1:], operation_settings=operation_settings)
@@ -1211,7 +1217,7 @@ class WanModel_S2V(WanModel):
                  operations=None,
                  ):
 
-        super().__init__(model_type='t2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, image_model=image_model, device=device, dtype=dtype, operations=operations)
+        super().__init__(model_type='t2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, image_model=image_model, device=device, dtype=dtype, operations=operations, enable_flipflop=False)
 
         self.trainable_cond_mask = operations.Embedding(3, self.dim, device=device, dtype=dtype)
 
@@ -1511,7 +1517,7 @@ class HumoWanModel(WanModel):
                  operations=None,
                  ):
 
-        super().__init__(model_type='t2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, wan_attn_block_class=WanAttentionBlockAudio, image_model=image_model, device=device, dtype=dtype, operations=operations)
+        super().__init__(model_type='t2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, wan_attn_block_class=WanAttentionBlockAudio, image_model=image_model, device=device, dtype=dtype, operations=operations, enable_flipflop=False)
 
         self.audio_proj = AudioProjModel(seq_len=8, blocks=5, channels=1280, intermediate_dim=512, output_dim=1536, context_tokens=audio_token_num, dtype=dtype, device=device, operations=operations)
 
diff --git a/comfy/ldm/wan/model_animate.py b/comfy/ldm/wan/model_animate.py
index 7c87835d4..2c09420c5 100644
--- a/comfy/ldm/wan/model_animate.py
+++ b/comfy/ldm/wan/model_animate.py
@@ -426,7 +426,7 @@ class AnimateWanModel(WanModel):
                  operations=None,
                  ):
 
-        super().__init__(model_type='i2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, image_model=image_model, device=device, dtype=dtype, operations=operations)
+        super().__init__(model_type='i2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, image_model=image_model, device=device, dtype=dtype, operations=operations, enable_flipflop=False)
 
         self.pose_patch_embedding = operations.Conv3d(
             16, dim, kernel_size=patch_size, stride=patch_size, device=device, dtype=dtype
diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py
index 142b810b3..08055b65c 100644
--- a/comfy/model_patcher.py
+++ b/comfy/model_patcher.py
@@ -25,7 +25,7 @@ import logging
 import math
 import uuid
 from typing import Callable, Optional
-
+import time  # TODO remove
 import torch
 
 import comfy.float
@@ -577,7 +577,7 @@
             sd.pop(k)
         return sd
 
-    def patch_weight_to_device(self, key, device_to=None, inplace_update=False):
+    def patch_weight_to_device(self, key, device_to=None, inplace_update=False, device_final=None):
         if key not in self.patches:
             return
 
@@ -597,18 +597,22 @@
         out_weight = comfy.lora.calculate_weight(self.patches[key], temp_weight, key)
         if set_func is None:
             out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype, seed=string_to_seed(key))
+            if device_final is not None:
+                out_weight = out_weight.to(device_final)
             if inplace_update:
                 comfy.utils.copy_to_param(self.model, key, out_weight)
             else:
                 comfy.utils.set_attr_param(self.model, key, out_weight)
         else:
+            if device_final is not None:
+                out_weight = out_weight.to(device_final)
             set_func(out_weight, inplace_update=inplace_update, seed=string_to_seed(key))
 
     def supports_flipflop(self):
         # flipflop requires diffusion_model, explicit flipflop support, NVIDIA CUDA streams, and loading/offloading VRAM
         if not hasattr(self.model, "diffusion_model"):
             return False
-        if not hasattr(self.model.diffusion_model, "flipflop"):
+        if not getattr(self.model.diffusion_model, "enable_flipflop", False):
             return False
         if not comfy.model_management.is_nvidia():
             return False
@@ -797,6 +801,7 @@
 
         # handle flipflop
         if len(load_flipflop) > 0:
+            start_time = time.perf_counter()
            load_flipflop.sort(reverse=True)
             for x in load_flipflop:
                 n = x[1]
@@ -806,10 +811,15 @@
                 if m.comfy_patched_weights == True:
                     continue
                 for param in params:
-                    self.patch_weight_to_device("{}.{}".format(n, param), device_to=self.offload_device)
+                    self.patch_weight_to_device("{}.{}".format(n, param), device_to=device_to, device_final=self.offload_device)
                 logging.debug("lowvram: loaded module for flipflop {} {}".format(n, m))
 
+            end_time = time.perf_counter()
+            logging.info(f"flipflop load time: {end_time - start_time:.2f} seconds")
+            start_time = time.perf_counter()
             self.init_flipflop_block_copies()
+            end_time = time.perf_counter()
+            logging.info(f"flipflop block init time: {end_time - start_time:.2f} seconds")
 
         if lowvram_counter > 0 or flipflop_counter > 0:
             if flipflop_counter > 0:
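
Note on the block-execution change: `WanModel.forward_orig` now routes its per-block loop through `execute_blocks` (inherited from `FlipFlopModule`; its body is not part of this diff), which can overlap weight uploads with compute when flipflop is enabled. The sketch below shows the general flip-flop (double-buffered) offload pattern such a method could implement; it is a minimal illustration, not ComfyUI's actual `FlipFlopHolder` code, and the `FlipFlopRunner` class, its shell/event layout, and all names in it are assumptions.

import copy
import torch

class FlipFlopRunner:
    """Hypothetical double-buffered block executor (illustrative only).

    `blocks` is an nn.ModuleList whose weights stay on the offload device
    (pinned host memory is needed for the copies to be truly async). Two
    GPU-resident "shells" with the same architecture are alternately
    refilled on a side CUDA stream while the compute stream runs the other.
    """

    def __init__(self, blocks, load_device):
        self.blocks = blocks
        self.copy_stream = torch.cuda.Stream(device=load_device)
        # two reusable GPU copies of the block architecture: "flip" and "flop"
        self.shells = [copy.deepcopy(blocks[0]).to(load_device) for _ in range(2)]
        self.filled = [torch.cuda.Event() for _ in range(2)]    # H2D copy done
        self.consumed = [torch.cuda.Event() for _ in range(2)]  # compute done

    def _prefetch(self, slot, block):
        # refill one shell on the side stream, once compute has released it
        # (waiting on a never-recorded event is a no-op, so iteration 0 is fine)
        self.copy_stream.wait_event(self.consumed[slot])
        with torch.cuda.stream(self.copy_stream):
            for dst, src in zip(self.shells[slot].parameters(), block.parameters()):
                dst.data.copy_(src.data, non_blocking=True)
        self.filled[slot].record(self.copy_stream)

    def execute_blocks(self, block_fn, x, *args):
        compute = torch.cuda.current_stream()
        self._prefetch(0, self.blocks[0])
        for i in range(len(self.blocks)):
            slot = i % 2
            if i + 1 < len(self.blocks):
                self._prefetch((i + 1) % 2, self.blocks[i + 1])  # queue next upload
            compute.wait_event(self.filled[slot])   # block i's weights have landed
            x = block_fn(i, self.shells[slot], x, *args)
            self.consumed[slot].record(compute)     # shell may be overwritten now
        return x

Only two block-sized buffers ever occupy VRAM, and the upload of block i+1 hides behind the compute of block i. This also suggests why the subclasses (VaceWanModel, CameraWanModel, WanModel_S2V, HumoWanModel, AnimateWanModel) pass `enable_flipflop=False`: their forward paths do per-block work outside this loop, so opting them out is the conservative choice.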
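Separately, the new `device_final` argument to `patch_weight_to_device` splits "where the patch math runs" from "where the patched weight ends up": the flipflop load path now passes `device_to=device_to` and `device_final=self.offload_device`, presumably so LoRA weights are calculated at load-device speed and the result is parked on the offload device for later streaming (the new `time.perf_counter()` logs appear to measure exactly this). A toy illustration of that split, with made-up names (`patch_then_park`, `lora_fn` are not ComfyUI APIs):

import torch

def patch_then_park(base_weight: torch.Tensor, lora_fn, compute_device, final_device):
    w = base_weight.to(compute_device)   # move the base weight to the fast device
    w = lora_fn(w)                       # apply the patch math at full speed
    return w.to(final_device)            # park the patched result on the offload device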