Refactor CausalWanModel to inherit from WanModel.

This commit is contained in:
Talmaj Marinc 2026-03-25 17:53:59 +01:00
parent 3a9547192e
commit e649a3bc72

View File

@ -9,19 +9,16 @@ block at a time and maintains a KV cache across blocks.
Reference: https://github.com/thu-ml/Causal-Forcing Reference: https://github.com/thu-ml/Causal-Forcing
""" """
import math
import torch import torch
import torch.nn as nn import torch.nn as nn
from comfy.ldm.modules.attention import optimized_attention from comfy.ldm.modules.attention import optimized_attention
from comfy.ldm.flux.layers import EmbedND
from comfy.ldm.flux.math import apply_rope1 from comfy.ldm.flux.math import apply_rope1
from comfy.ldm.wan.model import ( from comfy.ldm.wan.model import (
sinusoidal_embedding_1d, sinusoidal_embedding_1d,
WAN_CROSSATTENTION_CLASSES,
Head,
MLPProj,
repeat_e, repeat_e,
WanModel,
WanAttentionBlock,
) )
import comfy.ldm.common_dit import comfy.ldm.common_dit
import comfy.model_management import comfy.model_management
@ -87,33 +84,18 @@ class CausalWanSelfAttention(nn.Module):
return x return x
class CausalWanAttentionBlock(nn.Module): class CausalWanAttentionBlock(WanAttentionBlock):
"""Transformer block with KV-cached self-attention and cross-attention caching.""" """Transformer block with KV-cached self-attention and cross-attention caching."""
def __init__(self, cross_attn_type, dim, ffn_dim, num_heads, def __init__(self, cross_attn_type, dim, ffn_dim, num_heads,
window_size=(-1, -1), qk_norm=True, cross_attn_norm=False, window_size=(-1, -1), qk_norm=True, cross_attn_norm=False,
eps=1e-6, operation_settings={}): eps=1e-6, operation_settings={}):
super().__init__() super().__init__(cross_attn_type, dim, ffn_dim, num_heads,
self.dim = dim window_size, qk_norm, cross_attn_norm, eps,
self.ffn_dim = ffn_dim operation_settings=operation_settings)
self.num_heads = num_heads self.self_attn = CausalWanSelfAttention(
dim, num_heads, window_size, qk_norm, eps,
ops = operation_settings.get("operations") operation_settings=operation_settings)
device = operation_settings.get("device")
dtype = operation_settings.get("dtype")
self.norm1 = ops.LayerNorm(dim, eps, elementwise_affine=False, device=device, dtype=dtype)
self.self_attn = CausalWanSelfAttention(dim, num_heads, window_size, qk_norm, eps, operation_settings=operation_settings)
self.norm3 = ops.LayerNorm(dim, eps, elementwise_affine=True, device=device, dtype=dtype) if cross_attn_norm else nn.Identity()
self.cross_attn = WAN_CROSSATTENTION_CLASSES[cross_attn_type](
dim, num_heads, (-1, -1), qk_norm, eps, operation_settings=operation_settings)
self.norm2 = ops.LayerNorm(dim, eps, elementwise_affine=False, device=device, dtype=dtype)
self.ffn = nn.Sequential(
ops.Linear(dim, ffn_dim, device=device, dtype=dtype),
nn.GELU(approximate='tanh'),
ops.Linear(ffn_dim, dim, device=device, dtype=dtype))
self.modulation = nn.Parameter(torch.empty(1, 6, dim, device=device, dtype=dtype))
def forward(self, x, e, freqs, context, context_img_len=257, def forward(self, x, e, freqs, context, context_img_len=257,
kv_cache=None, crossattn_cache=None, transformer_options={}): kv_cache=None, crossattn_cache=None, transformer_options={}):
@ -150,7 +132,7 @@ class CausalWanAttentionBlock(nn.Module):
return x return x
class CausalWanModel(torch.nn.Module): class CausalWanModel(WanModel):
""" """
Wan 2.1 diffusion backbone with causal KV-cache support. Wan 2.1 diffusion backbone with causal KV-cache support.
@ -178,82 +160,14 @@ class CausalWanModel(torch.nn.Module):
device=None, device=None,
dtype=None, dtype=None,
operations=None): operations=None):
super().__init__() super().__init__(
self.dtype = dtype model_type=model_type, patch_size=patch_size, text_len=text_len,
operation_settings = {"operations": operations, "device": device, "dtype": dtype} in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim,
text_dim=text_dim, out_dim=out_dim, num_heads=num_heads,
self.model_type = model_type num_layers=num_layers, window_size=window_size, qk_norm=qk_norm,
self.patch_size = patch_size cross_attn_norm=cross_attn_norm, eps=eps, image_model=image_model,
self.text_len = text_len wan_attn_block_class=CausalWanAttentionBlock,
self.in_dim = in_dim device=device, dtype=dtype, operations=operations)
self.dim = dim
self.ffn_dim = ffn_dim
self.freq_dim = freq_dim
self.text_dim = text_dim
self.out_dim = out_dim
self.num_heads = num_heads
self.num_layers = num_layers
self.window_size = window_size
self.qk_norm = qk_norm
self.cross_attn_norm = cross_attn_norm
self.eps = eps
self.patch_embedding = operations.Conv3d(
in_dim, dim, kernel_size=patch_size, stride=patch_size,
device=device, dtype=dtype)
self.text_embedding = nn.Sequential(
operations.Linear(text_dim, dim, device=device, dtype=dtype),
nn.GELU(approximate='tanh'),
operations.Linear(dim, dim, device=device, dtype=dtype))
self.time_embedding = nn.Sequential(
operations.Linear(freq_dim, dim, device=device, dtype=dtype),
nn.SiLU(),
operations.Linear(dim, dim, device=device, dtype=dtype))
self.time_projection = nn.Sequential(
nn.SiLU(),
operations.Linear(dim, dim * 6, device=device, dtype=dtype))
cross_attn_type = 't2v_cross_attn' if model_type == 't2v' else 'i2v_cross_attn'
self.blocks = nn.ModuleList([
CausalWanAttentionBlock(
cross_attn_type, dim, ffn_dim, num_heads,
window_size, qk_norm, cross_attn_norm, eps,
operation_settings=operation_settings)
for _ in range(num_layers)
])
self.head = Head(dim, out_dim, patch_size, eps, operation_settings=operation_settings)
d = dim // num_heads
self.rope_embedder = EmbedND(
dim=d, theta=10000.0,
axes_dim=[d - 4 * (d // 6), 2 * (d // 6), 2 * (d // 6)])
if model_type == 'i2v':
self.img_emb = MLPProj(1280, dim, operation_settings=operation_settings)
else:
self.img_emb = None
self.ref_conv = None
def rope_encode(self, t, h, w, t_start=0, device=None, dtype=None):
patch_size = self.patch_size
t_len = ((t + (patch_size[0] // 2)) // patch_size[0])
h_len = ((h + (patch_size[1] // 2)) // patch_size[1])
w_len = ((w + (patch_size[2] // 2)) // patch_size[2])
img_ids = torch.zeros((t_len, h_len, w_len, 3), device=device, dtype=dtype)
img_ids[:, :, :, 0] += torch.linspace(
t_start, t_start + (t_len - 1), steps=t_len, device=device, dtype=dtype
).reshape(-1, 1, 1)
img_ids[:, :, :, 1] += torch.linspace(
0, h_len - 1, steps=h_len, device=device, dtype=dtype
).reshape(1, -1, 1)
img_ids[:, :, :, 2] += torch.linspace(
0, w_len - 1, steps=w_len, device=device, dtype=dtype
).reshape(1, 1, -1)
img_ids = img_ids.reshape(1, -1, img_ids.shape[-1])
return self.rope_embedder(img_ids).movedim(1, 2)
def forward_block(self, x, timestep, context, start_frame, def forward_block(self, x, timestep, context, start_frame,
kv_caches, crossattn_caches, clip_fea=None): kv_caches, crossattn_caches, clip_fea=None):
@ -275,11 +189,11 @@ class CausalWanModel(torch.nn.Module):
x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size) x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)
bs, c, t, h, w = x.shape bs, c, t, h, w = x.shape
x = self.patch_embedding(x) x = self.patch_embedding(x.float()).to(x.dtype)
grid_sizes = x.shape[2:] grid_sizes = x.shape[2:]
x = x.flatten(2).transpose(1, 2) x = x.flatten(2).transpose(1, 2)
# Per-frame time embedding → [B, block_frames, 6, dim] # Per-frame time embedding
e = self.time_embedding( e = self.time_embedding(
sinusoidal_embedding_1d(self.freq_dim, timestep.flatten()).to(dtype=x.dtype)) sinusoidal_embedding_1d(self.freq_dim, timestep.flatten()).to(dtype=x.dtype))
e = e.reshape(timestep.shape[0], -1, e.shape[-1]) e = e.reshape(timestep.shape[0], -1, e.shape[-1])
@ -311,14 +225,6 @@ class CausalWanModel(torch.nn.Module):
x = self.unpatchify(x, grid_sizes) x = self.unpatchify(x, grid_sizes)
return x[:, :, :t, :h, :w] return x[:, :, :t, :h, :w]
def unpatchify(self, x, grid_sizes):
c = self.out_dim
b = x.shape[0]
u = x[:, :math.prod(grid_sizes)].view(b, *grid_sizes, *self.patch_size, c)
u = torch.einsum('bfhwpqrc->bcfphqwr', u)
u = u.reshape(b, c, *[i * j for i, j in zip(grid_sizes, self.patch_size)])
return u
def init_kv_caches(self, batch_size, max_seq_len, device, dtype): def init_kv_caches(self, batch_size, max_seq_len, device, dtype):
"""Create fresh KV caches for all layers.""" """Create fresh KV caches for all layers."""
caches = [] caches = []
@ -365,39 +271,6 @@ class CausalWanModel(torch.nn.Module):
clip_fea=clip_fea, clip_fea=clip_fea,
) )
bs, c, t, h, w = x.shape return super().forward(x, timestep, context, clip_fea=clip_fea,
x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size) time_dim_concat=time_dim_concat,
transformer_options=transformer_options, **kwargs)
t_len = t
if time_dim_concat is not None:
time_dim_concat = comfy.ldm.common_dit.pad_to_patch_size(time_dim_concat, self.patch_size)
x = torch.cat([x, time_dim_concat], dim=2)
t_len = x.shape[2]
x = self.patch_embedding(x)
grid_sizes = x.shape[2:]
x = x.flatten(2).transpose(1, 2)
freqs = self.rope_encode(t_len, h, w, device=x.device, dtype=x.dtype)
e = self.time_embedding(
sinusoidal_embedding_1d(self.freq_dim, timestep.flatten()).to(dtype=x.dtype))
e = e.reshape(timestep.shape[0], -1, e.shape[-1])
e0 = self.time_projection(e).unflatten(2, (6, self.dim))
context = self.text_embedding(context)
context_img_len = None
if clip_fea is not None and self.img_emb is not None:
context_clip = self.img_emb(clip_fea)
context = torch.concat([context_clip, context], dim=1)
context_img_len = clip_fea.shape[-2]
for block in self.blocks:
x = block(x, e=e0, freqs=freqs, context=context,
context_img_len=context_img_len,
transformer_options=transformer_options)
x = self.head(x, e)
x = self.unpatchify(x, grid_sizes)
return x[:, :, :t, :h, :w]