mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-30 21:43:43 +08:00
Refactor CausalWanModel to inherit from WanModel.
This commit is contained in:
parent
3a9547192e
commit
e649a3bc72
@ -9,19 +9,16 @@ block at a time and maintains a KV cache across blocks.
|
||||
Reference: https://github.com/thu-ml/Causal-Forcing
|
||||
"""
|
||||
|
||||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
from comfy.ldm.flux.layers import EmbedND
|
||||
from comfy.ldm.flux.math import apply_rope1
|
||||
from comfy.ldm.wan.model import (
|
||||
sinusoidal_embedding_1d,
|
||||
WAN_CROSSATTENTION_CLASSES,
|
||||
Head,
|
||||
MLPProj,
|
||||
repeat_e,
|
||||
WanModel,
|
||||
WanAttentionBlock,
|
||||
)
|
||||
import comfy.ldm.common_dit
|
||||
import comfy.model_management
|
||||
@ -87,33 +84,18 @@ class CausalWanSelfAttention(nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
class CausalWanAttentionBlock(nn.Module):
|
||||
class CausalWanAttentionBlock(WanAttentionBlock):
|
||||
"""Transformer block with KV-cached self-attention and cross-attention caching."""
|
||||
|
||||
def __init__(self, cross_attn_type, dim, ffn_dim, num_heads,
|
||||
window_size=(-1, -1), qk_norm=True, cross_attn_norm=False,
|
||||
eps=1e-6, operation_settings={}):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.ffn_dim = ffn_dim
|
||||
self.num_heads = num_heads
|
||||
|
||||
ops = operation_settings.get("operations")
|
||||
device = operation_settings.get("device")
|
||||
dtype = operation_settings.get("dtype")
|
||||
|
||||
self.norm1 = ops.LayerNorm(dim, eps, elementwise_affine=False, device=device, dtype=dtype)
|
||||
self.self_attn = CausalWanSelfAttention(dim, num_heads, window_size, qk_norm, eps, operation_settings=operation_settings)
|
||||
self.norm3 = ops.LayerNorm(dim, eps, elementwise_affine=True, device=device, dtype=dtype) if cross_attn_norm else nn.Identity()
|
||||
self.cross_attn = WAN_CROSSATTENTION_CLASSES[cross_attn_type](
|
||||
dim, num_heads, (-1, -1), qk_norm, eps, operation_settings=operation_settings)
|
||||
self.norm2 = ops.LayerNorm(dim, eps, elementwise_affine=False, device=device, dtype=dtype)
|
||||
self.ffn = nn.Sequential(
|
||||
ops.Linear(dim, ffn_dim, device=device, dtype=dtype),
|
||||
nn.GELU(approximate='tanh'),
|
||||
ops.Linear(ffn_dim, dim, device=device, dtype=dtype))
|
||||
|
||||
self.modulation = nn.Parameter(torch.empty(1, 6, dim, device=device, dtype=dtype))
|
||||
super().__init__(cross_attn_type, dim, ffn_dim, num_heads,
|
||||
window_size, qk_norm, cross_attn_norm, eps,
|
||||
operation_settings=operation_settings)
|
||||
self.self_attn = CausalWanSelfAttention(
|
||||
dim, num_heads, window_size, qk_norm, eps,
|
||||
operation_settings=operation_settings)
|
||||
|
||||
def forward(self, x, e, freqs, context, context_img_len=257,
|
||||
kv_cache=None, crossattn_cache=None, transformer_options={}):
|
||||
@ -150,7 +132,7 @@ class CausalWanAttentionBlock(nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
class CausalWanModel(torch.nn.Module):
|
||||
class CausalWanModel(WanModel):
|
||||
"""
|
||||
Wan 2.1 diffusion backbone with causal KV-cache support.
|
||||
|
||||
@ -178,82 +160,14 @@ class CausalWanModel(torch.nn.Module):
|
||||
device=None,
|
||||
dtype=None,
|
||||
operations=None):
|
||||
super().__init__()
|
||||
self.dtype = dtype
|
||||
operation_settings = {"operations": operations, "device": device, "dtype": dtype}
|
||||
|
||||
self.model_type = model_type
|
||||
self.patch_size = patch_size
|
||||
self.text_len = text_len
|
||||
self.in_dim = in_dim
|
||||
self.dim = dim
|
||||
self.ffn_dim = ffn_dim
|
||||
self.freq_dim = freq_dim
|
||||
self.text_dim = text_dim
|
||||
self.out_dim = out_dim
|
||||
self.num_heads = num_heads
|
||||
self.num_layers = num_layers
|
||||
self.window_size = window_size
|
||||
self.qk_norm = qk_norm
|
||||
self.cross_attn_norm = cross_attn_norm
|
||||
self.eps = eps
|
||||
|
||||
self.patch_embedding = operations.Conv3d(
|
||||
in_dim, dim, kernel_size=patch_size, stride=patch_size,
|
||||
device=device, dtype=dtype)
|
||||
self.text_embedding = nn.Sequential(
|
||||
operations.Linear(text_dim, dim, device=device, dtype=dtype),
|
||||
nn.GELU(approximate='tanh'),
|
||||
operations.Linear(dim, dim, device=device, dtype=dtype))
|
||||
self.time_embedding = nn.Sequential(
|
||||
operations.Linear(freq_dim, dim, device=device, dtype=dtype),
|
||||
nn.SiLU(),
|
||||
operations.Linear(dim, dim, device=device, dtype=dtype))
|
||||
self.time_projection = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
operations.Linear(dim, dim * 6, device=device, dtype=dtype))
|
||||
|
||||
cross_attn_type = 't2v_cross_attn' if model_type == 't2v' else 'i2v_cross_attn'
|
||||
self.blocks = nn.ModuleList([
|
||||
CausalWanAttentionBlock(
|
||||
cross_attn_type, dim, ffn_dim, num_heads,
|
||||
window_size, qk_norm, cross_attn_norm, eps,
|
||||
operation_settings=operation_settings)
|
||||
for _ in range(num_layers)
|
||||
])
|
||||
|
||||
self.head = Head(dim, out_dim, patch_size, eps, operation_settings=operation_settings)
|
||||
|
||||
d = dim // num_heads
|
||||
self.rope_embedder = EmbedND(
|
||||
dim=d, theta=10000.0,
|
||||
axes_dim=[d - 4 * (d // 6), 2 * (d // 6), 2 * (d // 6)])
|
||||
|
||||
if model_type == 'i2v':
|
||||
self.img_emb = MLPProj(1280, dim, operation_settings=operation_settings)
|
||||
else:
|
||||
self.img_emb = None
|
||||
|
||||
self.ref_conv = None
|
||||
|
||||
def rope_encode(self, t, h, w, t_start=0, device=None, dtype=None):
|
||||
patch_size = self.patch_size
|
||||
t_len = ((t + (patch_size[0] // 2)) // patch_size[0])
|
||||
h_len = ((h + (patch_size[1] // 2)) // patch_size[1])
|
||||
w_len = ((w + (patch_size[2] // 2)) // patch_size[2])
|
||||
|
||||
img_ids = torch.zeros((t_len, h_len, w_len, 3), device=device, dtype=dtype)
|
||||
img_ids[:, :, :, 0] += torch.linspace(
|
||||
t_start, t_start + (t_len - 1), steps=t_len, device=device, dtype=dtype
|
||||
).reshape(-1, 1, 1)
|
||||
img_ids[:, :, :, 1] += torch.linspace(
|
||||
0, h_len - 1, steps=h_len, device=device, dtype=dtype
|
||||
).reshape(1, -1, 1)
|
||||
img_ids[:, :, :, 2] += torch.linspace(
|
||||
0, w_len - 1, steps=w_len, device=device, dtype=dtype
|
||||
).reshape(1, 1, -1)
|
||||
img_ids = img_ids.reshape(1, -1, img_ids.shape[-1])
|
||||
return self.rope_embedder(img_ids).movedim(1, 2)
|
||||
super().__init__(
|
||||
model_type=model_type, patch_size=patch_size, text_len=text_len,
|
||||
in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim,
|
||||
text_dim=text_dim, out_dim=out_dim, num_heads=num_heads,
|
||||
num_layers=num_layers, window_size=window_size, qk_norm=qk_norm,
|
||||
cross_attn_norm=cross_attn_norm, eps=eps, image_model=image_model,
|
||||
wan_attn_block_class=CausalWanAttentionBlock,
|
||||
device=device, dtype=dtype, operations=operations)
|
||||
|
||||
def forward_block(self, x, timestep, context, start_frame,
|
||||
kv_caches, crossattn_caches, clip_fea=None):
|
||||
@ -275,11 +189,11 @@ class CausalWanModel(torch.nn.Module):
|
||||
x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)
|
||||
bs, c, t, h, w = x.shape
|
||||
|
||||
x = self.patch_embedding(x)
|
||||
x = self.patch_embedding(x.float()).to(x.dtype)
|
||||
grid_sizes = x.shape[2:]
|
||||
x = x.flatten(2).transpose(1, 2)
|
||||
|
||||
# Per-frame time embedding → [B, block_frames, 6, dim]
|
||||
# Per-frame time embedding
|
||||
e = self.time_embedding(
|
||||
sinusoidal_embedding_1d(self.freq_dim, timestep.flatten()).to(dtype=x.dtype))
|
||||
e = e.reshape(timestep.shape[0], -1, e.shape[-1])
|
||||
@ -311,14 +225,6 @@ class CausalWanModel(torch.nn.Module):
|
||||
x = self.unpatchify(x, grid_sizes)
|
||||
return x[:, :, :t, :h, :w]
|
||||
|
||||
def unpatchify(self, x, grid_sizes):
|
||||
c = self.out_dim
|
||||
b = x.shape[0]
|
||||
u = x[:, :math.prod(grid_sizes)].view(b, *grid_sizes, *self.patch_size, c)
|
||||
u = torch.einsum('bfhwpqrc->bcfphqwr', u)
|
||||
u = u.reshape(b, c, *[i * j for i, j in zip(grid_sizes, self.patch_size)])
|
||||
return u
|
||||
|
||||
def init_kv_caches(self, batch_size, max_seq_len, device, dtype):
|
||||
"""Create fresh KV caches for all layers."""
|
||||
caches = []
|
||||
@ -365,39 +271,6 @@ class CausalWanModel(torch.nn.Module):
|
||||
clip_fea=clip_fea,
|
||||
)
|
||||
|
||||
bs, c, t, h, w = x.shape
|
||||
x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)
|
||||
|
||||
t_len = t
|
||||
if time_dim_concat is not None:
|
||||
time_dim_concat = comfy.ldm.common_dit.pad_to_patch_size(time_dim_concat, self.patch_size)
|
||||
x = torch.cat([x, time_dim_concat], dim=2)
|
||||
t_len = x.shape[2]
|
||||
|
||||
x = self.patch_embedding(x)
|
||||
grid_sizes = x.shape[2:]
|
||||
x = x.flatten(2).transpose(1, 2)
|
||||
|
||||
freqs = self.rope_encode(t_len, h, w, device=x.device, dtype=x.dtype)
|
||||
|
||||
e = self.time_embedding(
|
||||
sinusoidal_embedding_1d(self.freq_dim, timestep.flatten()).to(dtype=x.dtype))
|
||||
e = e.reshape(timestep.shape[0], -1, e.shape[-1])
|
||||
e0 = self.time_projection(e).unflatten(2, (6, self.dim))
|
||||
|
||||
context = self.text_embedding(context)
|
||||
|
||||
context_img_len = None
|
||||
if clip_fea is not None and self.img_emb is not None:
|
||||
context_clip = self.img_emb(clip_fea)
|
||||
context = torch.concat([context_clip, context], dim=1)
|
||||
context_img_len = clip_fea.shape[-2]
|
||||
|
||||
for block in self.blocks:
|
||||
x = block(x, e=e0, freqs=freqs, context=context,
|
||||
context_img_len=context_img_len,
|
||||
transformer_options=transformer_options)
|
||||
|
||||
x = self.head(x, e)
|
||||
x = self.unpatchify(x, grid_sizes)
|
||||
return x[:, :, :t, :h, :w]
|
||||
return super().forward(x, timestep, context, clip_fea=clip_fea,
|
||||
time_dim_concat=time_dim_concat,
|
||||
transformer_options=transformer_options, **kwargs)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user