diff --git a/comfy/ldm/wan/ar_model.py b/comfy/ldm/wan/ar_model.py index 0fe2a585c..54a2ef704 100644 --- a/comfy/ldm/wan/ar_model.py +++ b/comfy/ldm/wan/ar_model.py @@ -9,19 +9,16 @@ block at a time and maintains a KV cache across blocks. Reference: https://github.com/thu-ml/Causal-Forcing """ -import math import torch import torch.nn as nn from comfy.ldm.modules.attention import optimized_attention -from comfy.ldm.flux.layers import EmbedND from comfy.ldm.flux.math import apply_rope1 from comfy.ldm.wan.model import ( sinusoidal_embedding_1d, - WAN_CROSSATTENTION_CLASSES, - Head, - MLPProj, repeat_e, + WanModel, + WanAttentionBlock, ) import comfy.ldm.common_dit import comfy.model_management @@ -87,33 +84,18 @@ class CausalWanSelfAttention(nn.Module): return x -class CausalWanAttentionBlock(nn.Module): +class CausalWanAttentionBlock(WanAttentionBlock): """Transformer block with KV-cached self-attention and cross-attention caching.""" def __init__(self, cross_attn_type, dim, ffn_dim, num_heads, window_size=(-1, -1), qk_norm=True, cross_attn_norm=False, eps=1e-6, operation_settings={}): - super().__init__() - self.dim = dim - self.ffn_dim = ffn_dim - self.num_heads = num_heads - - ops = operation_settings.get("operations") - device = operation_settings.get("device") - dtype = operation_settings.get("dtype") - - self.norm1 = ops.LayerNorm(dim, eps, elementwise_affine=False, device=device, dtype=dtype) - self.self_attn = CausalWanSelfAttention(dim, num_heads, window_size, qk_norm, eps, operation_settings=operation_settings) - self.norm3 = ops.LayerNorm(dim, eps, elementwise_affine=True, device=device, dtype=dtype) if cross_attn_norm else nn.Identity() - self.cross_attn = WAN_CROSSATTENTION_CLASSES[cross_attn_type]( - dim, num_heads, (-1, -1), qk_norm, eps, operation_settings=operation_settings) - self.norm2 = ops.LayerNorm(dim, eps, elementwise_affine=False, device=device, dtype=dtype) - self.ffn = nn.Sequential( - ops.Linear(dim, ffn_dim, device=device, dtype=dtype), - nn.GELU(approximate='tanh'), - ops.Linear(ffn_dim, dim, device=device, dtype=dtype)) - - self.modulation = nn.Parameter(torch.empty(1, 6, dim, device=device, dtype=dtype)) + super().__init__(cross_attn_type, dim, ffn_dim, num_heads, + window_size, qk_norm, cross_attn_norm, eps, + operation_settings=operation_settings) + self.self_attn = CausalWanSelfAttention( + dim, num_heads, window_size, qk_norm, eps, + operation_settings=operation_settings) def forward(self, x, e, freqs, context, context_img_len=257, kv_cache=None, crossattn_cache=None, transformer_options={}): @@ -150,7 +132,7 @@ class CausalWanAttentionBlock(nn.Module): return x -class CausalWanModel(torch.nn.Module): +class CausalWanModel(WanModel): """ Wan 2.1 diffusion backbone with causal KV-cache support. @@ -178,82 +160,14 @@ class CausalWanModel(torch.nn.Module): device=None, dtype=None, operations=None): - super().__init__() - self.dtype = dtype - operation_settings = {"operations": operations, "device": device, "dtype": dtype} - - self.model_type = model_type - self.patch_size = patch_size - self.text_len = text_len - self.in_dim = in_dim - self.dim = dim - self.ffn_dim = ffn_dim - self.freq_dim = freq_dim - self.text_dim = text_dim - self.out_dim = out_dim - self.num_heads = num_heads - self.num_layers = num_layers - self.window_size = window_size - self.qk_norm = qk_norm - self.cross_attn_norm = cross_attn_norm - self.eps = eps - - self.patch_embedding = operations.Conv3d( - in_dim, dim, kernel_size=patch_size, stride=patch_size, - device=device, dtype=dtype) - self.text_embedding = nn.Sequential( - operations.Linear(text_dim, dim, device=device, dtype=dtype), - nn.GELU(approximate='tanh'), - operations.Linear(dim, dim, device=device, dtype=dtype)) - self.time_embedding = nn.Sequential( - operations.Linear(freq_dim, dim, device=device, dtype=dtype), - nn.SiLU(), - operations.Linear(dim, dim, device=device, dtype=dtype)) - self.time_projection = nn.Sequential( - nn.SiLU(), - operations.Linear(dim, dim * 6, device=device, dtype=dtype)) - - cross_attn_type = 't2v_cross_attn' if model_type == 't2v' else 'i2v_cross_attn' - self.blocks = nn.ModuleList([ - CausalWanAttentionBlock( - cross_attn_type, dim, ffn_dim, num_heads, - window_size, qk_norm, cross_attn_norm, eps, - operation_settings=operation_settings) - for _ in range(num_layers) - ]) - - self.head = Head(dim, out_dim, patch_size, eps, operation_settings=operation_settings) - - d = dim // num_heads - self.rope_embedder = EmbedND( - dim=d, theta=10000.0, - axes_dim=[d - 4 * (d // 6), 2 * (d // 6), 2 * (d // 6)]) - - if model_type == 'i2v': - self.img_emb = MLPProj(1280, dim, operation_settings=operation_settings) - else: - self.img_emb = None - - self.ref_conv = None - - def rope_encode(self, t, h, w, t_start=0, device=None, dtype=None): - patch_size = self.patch_size - t_len = ((t + (patch_size[0] // 2)) // patch_size[0]) - h_len = ((h + (patch_size[1] // 2)) // patch_size[1]) - w_len = ((w + (patch_size[2] // 2)) // patch_size[2]) - - img_ids = torch.zeros((t_len, h_len, w_len, 3), device=device, dtype=dtype) - img_ids[:, :, :, 0] += torch.linspace( - t_start, t_start + (t_len - 1), steps=t_len, device=device, dtype=dtype - ).reshape(-1, 1, 1) - img_ids[:, :, :, 1] += torch.linspace( - 0, h_len - 1, steps=h_len, device=device, dtype=dtype - ).reshape(1, -1, 1) - img_ids[:, :, :, 2] += torch.linspace( - 0, w_len - 1, steps=w_len, device=device, dtype=dtype - ).reshape(1, 1, -1) - img_ids = img_ids.reshape(1, -1, img_ids.shape[-1]) - return self.rope_embedder(img_ids).movedim(1, 2) + super().__init__( + model_type=model_type, patch_size=patch_size, text_len=text_len, + in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, + text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, + num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, + cross_attn_norm=cross_attn_norm, eps=eps, image_model=image_model, + wan_attn_block_class=CausalWanAttentionBlock, + device=device, dtype=dtype, operations=operations) def forward_block(self, x, timestep, context, start_frame, kv_caches, crossattn_caches, clip_fea=None): @@ -275,11 +189,11 @@ class CausalWanModel(torch.nn.Module): x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size) bs, c, t, h, w = x.shape - x = self.patch_embedding(x) + x = self.patch_embedding(x.float()).to(x.dtype) grid_sizes = x.shape[2:] x = x.flatten(2).transpose(1, 2) - # Per-frame time embedding → [B, block_frames, 6, dim] + # Per-frame time embedding e = self.time_embedding( sinusoidal_embedding_1d(self.freq_dim, timestep.flatten()).to(dtype=x.dtype)) e = e.reshape(timestep.shape[0], -1, e.shape[-1]) @@ -311,14 +225,6 @@ class CausalWanModel(torch.nn.Module): x = self.unpatchify(x, grid_sizes) return x[:, :, :t, :h, :w] - def unpatchify(self, x, grid_sizes): - c = self.out_dim - b = x.shape[0] - u = x[:, :math.prod(grid_sizes)].view(b, *grid_sizes, *self.patch_size, c) - u = torch.einsum('bfhwpqrc->bcfphqwr', u) - u = u.reshape(b, c, *[i * j for i, j in zip(grid_sizes, self.patch_size)]) - return u - def init_kv_caches(self, batch_size, max_seq_len, device, dtype): """Create fresh KV caches for all layers.""" caches = [] @@ -365,39 +271,6 @@ class CausalWanModel(torch.nn.Module): clip_fea=clip_fea, ) - bs, c, t, h, w = x.shape - x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size) - - t_len = t - if time_dim_concat is not None: - time_dim_concat = comfy.ldm.common_dit.pad_to_patch_size(time_dim_concat, self.patch_size) - x = torch.cat([x, time_dim_concat], dim=2) - t_len = x.shape[2] - - x = self.patch_embedding(x) - grid_sizes = x.shape[2:] - x = x.flatten(2).transpose(1, 2) - - freqs = self.rope_encode(t_len, h, w, device=x.device, dtype=x.dtype) - - e = self.time_embedding( - sinusoidal_embedding_1d(self.freq_dim, timestep.flatten()).to(dtype=x.dtype)) - e = e.reshape(timestep.shape[0], -1, e.shape[-1]) - e0 = self.time_projection(e).unflatten(2, (6, self.dim)) - - context = self.text_embedding(context) - - context_img_len = None - if clip_fea is not None and self.img_emb is not None: - context_clip = self.img_emb(clip_fea) - context = torch.concat([context_clip, context], dim=1) - context_img_len = clip_fea.shape[-2] - - for block in self.blocks: - x = block(x, e=e0, freqs=freqs, context=context, - context_img_len=context_img_len, - transformer_options=transformer_options) - - x = self.head(x, e) - x = self.unpatchify(x, grid_sizes) - return x[:, :, :t, :h, :w] + return super().forward(x, timestep, context, clip_fea=clip_fea, + time_dim_concat=time_dim_concat, + transformer_options=transformer_options, **kwargs)