Mirror of https://github.com/comfyanonymous/ComfyUI.git
Add flipflop support to (base) WAN; fix LoRAs being applied to flipflop weights on the CPU instead of the GPU; timing functions are left in for now, since LoRA application time could still use some reduction
This commit is contained in:
parent 831c3cf05e
commit ee01002e63
@@ -6,9 +6,10 @@ import comfy.model_management
 
 
 class FlipFlopModule(torch.nn.Module):
-    def __init__(self, block_types: tuple[str, ...]):
+    def __init__(self, block_types: tuple[str, ...], enable_flipflop: bool = True):
         super().__init__()
         self.block_types = block_types
+        self.enable_flipflop = enable_flipflop
         self.flipflop: dict[str, FlipFlopHolder] = {}
 
     def setup_flipflop_holders(self, block_info: dict[str, tuple[int, int]], load_device: torch.device, offload_device: torch.device):
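
The new enable_flipflop attribute is what ModelPatcher.supports_flipflop() reads further down via getattr, so a module that never sets the flag behaves the same as one that disables it. A minimal illustration of that gate (assuming ComfyUI is importable; Bare is an invented stand-in):

    import torch
    from comfy.ldm.flipflop_transformer import FlipFlopModule

    class Bare(torch.nn.Module):
        pass

    # a module without the attribute is treated as opted out
    getattr(Bare(), "enable_flipflop", False)                                   # False
    # FlipFlopModule defaults the flag to True
    getattr(FlipFlopModule(block_types=("blocks",)), "enable_flipflop", False)  # True
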
@@ -7,6 +7,7 @@ import torch.nn as nn
 from einops import rearrange
 
 from comfy.ldm.modules.attention import optimized_attention
+from comfy.ldm.flipflop_transformer import FlipFlopModule
 from comfy.ldm.flux.layers import EmbedND
 from comfy.ldm.flux.math import apply_rope1
 import comfy.ldm.common_dit
@@ -384,7 +385,7 @@ class MLPProj(torch.nn.Module):
         return clip_extra_context_tokens
 
 
-class WanModel(torch.nn.Module):
+class WanModel(FlipFlopModule):
     r"""
     Wan diffusion backbone supporting both text-to-video and image-to-video.
     """
@@ -412,6 +413,7 @@ class WanModel(torch.nn.Module):
         device=None,
         dtype=None,
         operations=None,
+        enable_flipflop=True,
     ):
         r"""
         Initialize the diffusion model backbone.
@@ -449,7 +451,7 @@ class WanModel(torch.nn.Module):
             Epsilon value for normalization layers
         """
 
-        super().__init__()
+        super().__init__(block_types=("blocks",), enable_flipflop=enable_flipflop)
         self.dtype = dtype
         operation_settings = {"operations": operations, "device": device, "dtype": dtype}
 
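
WanModel wires itself up by naming its block container: block_types=("blocks",) must match the attribute that holds the transformer blocks (self.blocks, a ModuleList). A toy subclass following the same contract, as a hedged sketch (ToyModel and its layers are invented for illustration):

    import torch
    from comfy.ldm.flipflop_transformer import FlipFlopModule

    class ToyModel(FlipFlopModule):
        def __init__(self, enable_flipflop: bool = True):
            # "blocks" names the ModuleList attribute below, mirroring
            # WanModel's super().__init__(block_types=("blocks",), ...)
            super().__init__(block_types=("blocks",), enable_flipflop=enable_flipflop)
            self.blocks = torch.nn.ModuleList(torch.nn.Linear(8, 8) for _ in range(4))
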
@@ -506,6 +508,18 @@ class WanModel(torch.nn.Module):
         else:
             self.ref_conv = None
 
+    def indiv_block_fwd(self, i, block, x, e0, freqs, context, context_img_len, blocks_replace, transformer_options):
+        if ("double_block", i) in blocks_replace:
+            def block_wrap(args):
+                out = {}
+                out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len, transformer_options=args["transformer_options"])
+                return out
+            out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap})
+            x = out["img"]
+        else:
+            x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len, transformer_options=transformer_options)
+        return x
+
     def forward_orig(
         self,
         x,
@@ -567,16 +581,8 @@ class WanModel(torch.nn.Module):
 
         patches_replace = transformer_options.get("patches_replace", {})
         blocks_replace = patches_replace.get("dit", {})
-        for i, block in enumerate(self.blocks):
-            if ("double_block", i) in blocks_replace:
-                def block_wrap(args):
-                    out = {}
-                    out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len, transformer_options=args["transformer_options"])
-                    return out
-                out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap})
-                x = out["img"]
-            else:
-                x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len, transformer_options=transformer_options)
+        # execute blocks
+        x = self.execute_blocks("blocks", self.indiv_block_fwd, x, e0, freqs, context, context_img_len, blocks_replace, transformer_options)
 
         # head
         x = self.head(x, e)
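
The per-block body now lives in indiv_block_fwd, and iteration is delegated to FlipFlopModule.execute_blocks, whose implementation is not part of this diff. A hedged sketch of the shape such a dispatcher could take; only the plain fallback is spelled out, and the flipflop branch is deliberately elided:

    # Sketch only: the real execute_blocks lives in comfy.ldm.flipflop_transformer
    # and also implements the overlapped-transfer ("flipflop") path.
    def execute_blocks(self, name, fn, x, *args):
        blocks = getattr(self, name)         # e.g. self.blocks on WanModel
        if name not in self.flipflop:        # no holder: plain sequential path
            for i, block in enumerate(blocks):
                x = fn(i, block, x, *args)   # fn is e.g. WanModel.indiv_block_fwd
            return x
        # flipflop path (assumed): run block i from one of two GPU-resident
        # copies while the next block's weights stream in on a side CUDA stream
        raise NotImplementedError
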
@@ -688,7 +694,7 @@ class VaceWanModel(WanModel):
         operations=None,
     ):
 
-        super().__init__(model_type='t2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, image_model=image_model, device=device, dtype=dtype, operations=operations)
+        super().__init__(model_type='t2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, image_model=image_model, device=device, dtype=dtype, operations=operations, enable_flipflop=False)
         operation_settings = {"operations": operations, "device": device, "dtype": dtype}
 
         # Vace
@@ -808,7 +814,7 @@ class CameraWanModel(WanModel):
         else:
             model_type = 't2v'
 
-        super().__init__(model_type=model_type, patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, image_model=image_model, device=device, dtype=dtype, operations=operations)
+        super().__init__(model_type=model_type, patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, image_model=image_model, device=device, dtype=dtype, operations=operations, enable_flipflop=False)
         operation_settings = {"operations": operations, "device": device, "dtype": dtype}
 
         self.control_adapter = WanCamAdapter(in_dim_control_adapter, dim, kernel_size=patch_size[1:], stride=patch_size[1:], operation_settings=operation_settings)
@@ -1211,7 +1217,7 @@ class WanModel_S2V(WanModel):
         operations=None,
     ):
 
-        super().__init__(model_type='t2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, image_model=image_model, device=device, dtype=dtype, operations=operations)
+        super().__init__(model_type='t2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, image_model=image_model, device=device, dtype=dtype, operations=operations, enable_flipflop=False)
 
         self.trainable_cond_mask = operations.Embedding(3, self.dim, device=device, dtype=dtype)
 
@@ -1511,7 +1517,7 @@ class HumoWanModel(WanModel):
         operations=None,
     ):
 
-        super().__init__(model_type='t2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, wan_attn_block_class=WanAttentionBlockAudio, image_model=image_model, device=device, dtype=dtype, operations=operations)
+        super().__init__(model_type='t2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, wan_attn_block_class=WanAttentionBlockAudio, image_model=image_model, device=device, dtype=dtype, operations=operations, enable_flipflop=False)
 
         self.audio_proj = AudioProjModel(seq_len=8, blocks=5, channels=1280, intermediate_dim=512, output_dim=1536, context_tokens=audio_token_num, dtype=dtype, device=device, operations=operations)
 
@@ -426,7 +426,7 @@ class AnimateWanModel(WanModel):
         operations=None,
     ):
 
-        super().__init__(model_type='i2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, image_model=image_model, device=device, dtype=dtype, operations=operations)
+        super().__init__(model_type='i2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, image_model=image_model, device=device, dtype=dtype, operations=operations, enable_flipflop=False)
 
         self.pose_patch_embedding = operations.Conv3d(
             16, dim, kernel_size=patch_size, stride=patch_size, device=device, dtype=dtype
@@ -25,7 +25,8 @@ import logging
 import math
 import uuid
 from typing import Callable, Optional
+import time # TODO remove
 
 import torch
 
 import comfy.float
@@ -577,7 +577,7 @@ class ModelPatcher:
                 sd.pop(k)
         return sd
 
-    def patch_weight_to_device(self, key, device_to=None, inplace_update=False):
+    def patch_weight_to_device(self, key, device_to=None, inplace_update=False, device_final=None):
         if key not in self.patches:
             return
 
@@ -597,18 +597,22 @@ class ModelPatcher:
         out_weight = comfy.lora.calculate_weight(self.patches[key], temp_weight, key)
         if set_func is None:
             out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype, seed=string_to_seed(key))
+            if device_final is not None:
+                out_weight = out_weight.to(device_final)
             if inplace_update:
                 comfy.utils.copy_to_param(self.model, key, out_weight)
             else:
                 comfy.utils.set_attr_param(self.model, key, out_weight)
         else:
+            if device_final is not None:
+                out_weight = out_weight.to(device_final)
             set_func(out_weight, inplace_update=inplace_update, seed=string_to_seed(key))
 
     def supports_flipflop(self):
         # flipflop requires diffusion_model, explicit flipflop support, NVIDIA CUDA streams, and loading/offloading VRAM
         if not hasattr(self.model, "diffusion_model"):
             return False
-        if not hasattr(self.model.diffusion_model, "flipflop"):
+        if not getattr(self.model.diffusion_model, "enable_flipflop", False):
             return False
         if not comfy.model_management.is_nvidia():
             return False
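
The new device_final parameter is the actual LoRA-on-GPU fix: previously flipflop weights were patched with device_to=self.offload_device (see the hunk below), so comfy.lora.calculate_weight ran on the CPU; now the math runs on device_to and only the finished weight is moved to device_final. An illustrative call, assuming `patcher` is a ModelPatcher and the key below is hypothetical:

    import torch

    # LoRA math happens on device_to (fast GPU path); the patched weight is
    # then parked on device_final, typically the offload device.
    patcher.patch_weight_to_device(
        "diffusion_model.blocks.0.self_attn.q.weight",  # hypothetical key
        device_to=torch.device("cuda:0"),
        device_final=torch.device("cpu"),
    )
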
@@ -797,6 +801,7 @@ class ModelPatcher:
 
             # handle flipflop
             if len(load_flipflop) > 0:
+                start_time = time.perf_counter()
                 load_flipflop.sort(reverse=True)
                 for x in load_flipflop:
                     n = x[1]
@@ -806,10 +811,15 @@ class ModelPatcher:
                     if m.comfy_patched_weights == True:
                         continue
                     for param in params:
-                        self.patch_weight_to_device("{}.{}".format(n, param), device_to=self.offload_device)
+                        self.patch_weight_to_device("{}.{}".format(n, param), device_to=device_to, device_final=self.offload_device)
 
                     logging.debug("lowvram: loaded module for flipflop {} {}".format(n, m))
+                end_time = time.perf_counter()
+                logging.info(f"flipflop load time: {end_time - start_time:.2f} seconds")
+                start_time = time.perf_counter()
                 self.init_flipflop_block_copies()
+                end_time = time.perf_counter()
+                logging.info(f"flipflop block init time: {end_time - start_time:.2f} seconds")
 
         if lowvram_counter > 0 or flipflop_counter > 0:
             if flipflop_counter > 0:
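
The perf_counter pairs above are marked TODO-remove in the imports; if the timing ends up staying, it could be folded into a small helper. A hypothetical sketch, not part of this commit:

    import logging
    import time
    from contextlib import contextmanager

    @contextmanager
    def log_time(label):
        # logs elapsed wall-clock time for the enclosed block
        start = time.perf_counter()
        try:
            yield
        finally:
            logging.info(f"{label}: {time.perf_counter() - start:.2f} seconds")

    # usage:
    # with log_time("flipflop load time"):
    #     ...  # load and patch flipflop weights
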