mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-30 13:33:42 +08:00
Initial commit causual_forcing.
This commit is contained in:
parent
b53b10ea61
commit
c0de57725b
392
comfy/ldm/wan/causal_model.py
Normal file
392
comfy/ldm/wan/causal_model.py
Normal file
@ -0,0 +1,392 @@
|
||||
"""
|
||||
CausalWanModel: Wan 2.1 backbone with KV-cached causal self-attention for
|
||||
autoregressive (frame-by-frame) video generation via Causal Forcing.
|
||||
|
||||
Weight-compatible with the standard WanModel -- same layer names, same shapes.
|
||||
The difference is purely in the forward pass: this model processes one temporal
|
||||
block at a time and maintains a KV cache across blocks.
|
||||
|
||||
Reference: https://github.com/thu-ml/Causal-Forcing
|
||||
"""
|
||||
|
||||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
from comfy.ldm.flux.layers import EmbedND
|
||||
from comfy.ldm.flux.math import apply_rope1
|
||||
from comfy.ldm.wan.model import (
|
||||
sinusoidal_embedding_1d,
|
||||
WanT2VCrossAttention,
|
||||
WAN_CROSSATTENTION_CLASSES,
|
||||
Head,
|
||||
MLPProj,
|
||||
repeat_e,
|
||||
)
|
||||
import comfy.ldm.common_dit
|
||||
import comfy.model_management
|
||||
|
||||
|
||||
class CausalWanSelfAttention(nn.Module):
|
||||
"""Self-attention with KV cache support for autoregressive inference."""
|
||||
|
||||
def __init__(self, dim, num_heads, window_size=(-1, -1), qk_norm=True,
|
||||
eps=1e-6, operation_settings={}):
|
||||
assert dim % num_heads == 0
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = dim // num_heads
|
||||
self.qk_norm = qk_norm
|
||||
self.eps = eps
|
||||
|
||||
ops = operation_settings.get("operations")
|
||||
device = operation_settings.get("device")
|
||||
dtype = operation_settings.get("dtype")
|
||||
|
||||
self.q = ops.Linear(dim, dim, device=device, dtype=dtype)
|
||||
self.k = ops.Linear(dim, dim, device=device, dtype=dtype)
|
||||
self.v = ops.Linear(dim, dim, device=device, dtype=dtype)
|
||||
self.o = ops.Linear(dim, dim, device=device, dtype=dtype)
|
||||
self.norm_q = ops.RMSNorm(dim, eps=eps, elementwise_affine=True, device=device, dtype=dtype) if qk_norm else nn.Identity()
|
||||
self.norm_k = ops.RMSNorm(dim, eps=eps, elementwise_affine=True, device=device, dtype=dtype) if qk_norm else nn.Identity()
|
||||
|
||||
def forward(self, x, freqs, kv_cache=None, transformer_options={}):
|
||||
b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
|
||||
|
||||
q = apply_rope1(self.norm_q(self.q(x)).view(b, s, n, d), freqs)
|
||||
k = apply_rope1(self.norm_k(self.k(x)).view(b, s, n, d), freqs)
|
||||
v = self.v(x).view(b, s, n, d)
|
||||
|
||||
if kv_cache is None:
|
||||
x = optimized_attention(
|
||||
q.view(b, s, n * d),
|
||||
k.view(b, s, n * d),
|
||||
v.view(b, s, n * d),
|
||||
heads=self.num_heads,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
else:
|
||||
end = kv_cache["end"].item()
|
||||
new_end = end + s
|
||||
|
||||
# Roped K and plain V go into cache
|
||||
kv_cache["k"][:, end:new_end] = k
|
||||
kv_cache["v"][:, end:new_end] = v
|
||||
kv_cache["end"].fill_(new_end)
|
||||
|
||||
x = optimized_attention(
|
||||
q.view(b, s, n * d),
|
||||
kv_cache["k"][:, :new_end].view(b, new_end, n * d),
|
||||
kv_cache["v"][:, :new_end].view(b, new_end, n * d),
|
||||
heads=self.num_heads,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
|
||||
x = self.o(x)
|
||||
return x
|
||||
|
||||
|
||||
class CausalWanAttentionBlock(nn.Module):
|
||||
"""Transformer block with KV-cached self-attention and cross-attention caching."""
|
||||
|
||||
def __init__(self, cross_attn_type, dim, ffn_dim, num_heads,
|
||||
window_size=(-1, -1), qk_norm=True, cross_attn_norm=False,
|
||||
eps=1e-6, operation_settings={}):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.ffn_dim = ffn_dim
|
||||
self.num_heads = num_heads
|
||||
|
||||
ops = operation_settings.get("operations")
|
||||
device = operation_settings.get("device")
|
||||
dtype = operation_settings.get("dtype")
|
||||
|
||||
self.norm1 = ops.LayerNorm(dim, eps, elementwise_affine=False, device=device, dtype=dtype)
|
||||
self.self_attn = CausalWanSelfAttention(dim, num_heads, window_size, qk_norm, eps, operation_settings=operation_settings)
|
||||
self.norm3 = ops.LayerNorm(dim, eps, elementwise_affine=True, device=device, dtype=dtype) if cross_attn_norm else nn.Identity()
|
||||
self.cross_attn = WAN_CROSSATTENTION_CLASSES[cross_attn_type](
|
||||
dim, num_heads, (-1, -1), qk_norm, eps, operation_settings=operation_settings)
|
||||
self.norm2 = ops.LayerNorm(dim, eps, elementwise_affine=False, device=device, dtype=dtype)
|
||||
self.ffn = nn.Sequential(
|
||||
ops.Linear(dim, ffn_dim, device=device, dtype=dtype),
|
||||
nn.GELU(approximate='tanh'),
|
||||
ops.Linear(ffn_dim, dim, device=device, dtype=dtype))
|
||||
|
||||
self.modulation = nn.Parameter(torch.empty(1, 6, dim, device=device, dtype=dtype))
|
||||
|
||||
def forward(self, x, e, freqs, context, context_img_len=257,
|
||||
kv_cache=None, crossattn_cache=None, transformer_options={}):
|
||||
if e.ndim < 4:
|
||||
e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device) + e).chunk(6, dim=1)
|
||||
else:
|
||||
e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device).unsqueeze(0) + e).unbind(2)
|
||||
|
||||
# Self-attention with optional KV cache
|
||||
x = x.contiguous()
|
||||
y = self.self_attn(
|
||||
torch.addcmul(repeat_e(e[0], x), self.norm1(x), 1 + repeat_e(e[1], x)),
|
||||
freqs, kv_cache=kv_cache, transformer_options=transformer_options)
|
||||
x = torch.addcmul(x, y, repeat_e(e[2], x))
|
||||
del y
|
||||
|
||||
# Cross-attention with optional caching
|
||||
if crossattn_cache is not None and crossattn_cache.get("is_init"):
|
||||
q = self.cross_attn.norm_q(self.cross_attn.q(self.norm3(x)))
|
||||
x_ca = optimized_attention(
|
||||
q, crossattn_cache["k"], crossattn_cache["v"],
|
||||
heads=self.num_heads, transformer_options=transformer_options)
|
||||
x = x + self.cross_attn.o(x_ca)
|
||||
else:
|
||||
x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len, transformer_options=transformer_options)
|
||||
if crossattn_cache is not None:
|
||||
crossattn_cache["k"] = self.cross_attn.norm_k(self.cross_attn.k(context))
|
||||
crossattn_cache["v"] = self.cross_attn.v(context)
|
||||
crossattn_cache["is_init"] = True
|
||||
|
||||
# FFN
|
||||
y = self.ffn(torch.addcmul(repeat_e(e[3], x), self.norm2(x), 1 + repeat_e(e[4], x)))
|
||||
x = torch.addcmul(x, y, repeat_e(e[5], x))
|
||||
return x
|
||||
|
||||
|
||||
class CausalWanModel(torch.nn.Module):
|
||||
"""
|
||||
Wan 2.1 diffusion backbone with causal KV-cache support.
|
||||
|
||||
Same weight structure as WanModel -- loads identical state dicts.
|
||||
Adds forward_block() for frame-by-frame autoregressive inference.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
model_type='t2v',
|
||||
patch_size=(1, 2, 2),
|
||||
text_len=512,
|
||||
in_dim=16,
|
||||
dim=2048,
|
||||
ffn_dim=8192,
|
||||
freq_dim=256,
|
||||
text_dim=4096,
|
||||
out_dim=16,
|
||||
num_heads=16,
|
||||
num_layers=32,
|
||||
window_size=(-1, -1),
|
||||
qk_norm=True,
|
||||
cross_attn_norm=True,
|
||||
eps=1e-6,
|
||||
image_model=None,
|
||||
device=None,
|
||||
dtype=None,
|
||||
operations=None):
|
||||
super().__init__()
|
||||
self.dtype = dtype
|
||||
operation_settings = {"operations": operations, "device": device, "dtype": dtype}
|
||||
|
||||
self.model_type = model_type
|
||||
self.patch_size = patch_size
|
||||
self.text_len = text_len
|
||||
self.in_dim = in_dim
|
||||
self.dim = dim
|
||||
self.ffn_dim = ffn_dim
|
||||
self.freq_dim = freq_dim
|
||||
self.text_dim = text_dim
|
||||
self.out_dim = out_dim
|
||||
self.num_heads = num_heads
|
||||
self.num_layers = num_layers
|
||||
self.window_size = window_size
|
||||
self.qk_norm = qk_norm
|
||||
self.cross_attn_norm = cross_attn_norm
|
||||
self.eps = eps
|
||||
|
||||
self.patch_embedding = operations.Conv3d(
|
||||
in_dim, dim, kernel_size=patch_size, stride=patch_size,
|
||||
device=device, dtype=dtype)
|
||||
self.text_embedding = nn.Sequential(
|
||||
operations.Linear(text_dim, dim, device=device, dtype=dtype),
|
||||
nn.GELU(approximate='tanh'),
|
||||
operations.Linear(dim, dim, device=device, dtype=dtype))
|
||||
self.time_embedding = nn.Sequential(
|
||||
operations.Linear(freq_dim, dim, device=device, dtype=dtype),
|
||||
nn.SiLU(),
|
||||
operations.Linear(dim, dim, device=device, dtype=dtype))
|
||||
self.time_projection = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
operations.Linear(dim, dim * 6, device=device, dtype=dtype))
|
||||
|
||||
cross_attn_type = 't2v_cross_attn' if model_type == 't2v' else 'i2v_cross_attn'
|
||||
self.blocks = nn.ModuleList([
|
||||
CausalWanAttentionBlock(
|
||||
cross_attn_type, dim, ffn_dim, num_heads,
|
||||
window_size, qk_norm, cross_attn_norm, eps,
|
||||
operation_settings=operation_settings)
|
||||
for _ in range(num_layers)
|
||||
])
|
||||
|
||||
self.head = Head(dim, out_dim, patch_size, eps, operation_settings=operation_settings)
|
||||
|
||||
d = dim // num_heads
|
||||
self.rope_embedder = EmbedND(
|
||||
dim=d, theta=10000.0,
|
||||
axes_dim=[d - 4 * (d // 6), 2 * (d // 6), 2 * (d // 6)])
|
||||
|
||||
if model_type == 'i2v':
|
||||
self.img_emb = MLPProj(1280, dim, operation_settings=operation_settings)
|
||||
else:
|
||||
self.img_emb = None
|
||||
|
||||
self.ref_conv = None
|
||||
|
||||
def rope_encode(self, t, h, w, t_start=0, device=None, dtype=None):
|
||||
patch_size = self.patch_size
|
||||
t_len = ((t + (patch_size[0] // 2)) // patch_size[0])
|
||||
h_len = ((h + (patch_size[1] // 2)) // patch_size[1])
|
||||
w_len = ((w + (patch_size[2] // 2)) // patch_size[2])
|
||||
|
||||
img_ids = torch.zeros((t_len, h_len, w_len, 3), device=device, dtype=dtype)
|
||||
img_ids[:, :, :, 0] += torch.linspace(
|
||||
t_start, t_start + (t_len - 1), steps=t_len, device=device, dtype=dtype
|
||||
).reshape(-1, 1, 1)
|
||||
img_ids[:, :, :, 1] += torch.linspace(
|
||||
0, h_len - 1, steps=h_len, device=device, dtype=dtype
|
||||
).reshape(1, -1, 1)
|
||||
img_ids[:, :, :, 2] += torch.linspace(
|
||||
0, w_len - 1, steps=w_len, device=device, dtype=dtype
|
||||
).reshape(1, 1, -1)
|
||||
img_ids = img_ids.reshape(1, -1, img_ids.shape[-1])
|
||||
return self.rope_embedder(img_ids).movedim(1, 2)
|
||||
|
||||
def forward_block(self, x, timestep, context, start_frame,
|
||||
kv_caches, crossattn_caches, clip_fea=None):
|
||||
"""
|
||||
Forward one temporal block for autoregressive inference.
|
||||
|
||||
Args:
|
||||
x: [B, C, block_frames, H, W] input latent for the current block
|
||||
timestep: [B, block_frames] per-frame timesteps
|
||||
context: [B, L, text_dim] raw text embeddings (pre-text_embedding)
|
||||
start_frame: temporal frame index for RoPE offset
|
||||
kv_caches: list of per-layer KV cache dicts
|
||||
crossattn_caches: list of per-layer cross-attention cache dicts
|
||||
clip_fea: optional CLIP features for I2V
|
||||
|
||||
Returns:
|
||||
flow_pred: [B, C_out, block_frames, H, W] flow prediction
|
||||
"""
|
||||
x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)
|
||||
bs, c, t, h, w = x.shape
|
||||
|
||||
x = self.patch_embedding(x)
|
||||
grid_sizes = x.shape[2:]
|
||||
x = x.flatten(2).transpose(1, 2)
|
||||
|
||||
# Per-frame time embedding → [B, block_frames, 6, dim]
|
||||
e = self.time_embedding(
|
||||
sinusoidal_embedding_1d(self.freq_dim, timestep.flatten()))
|
||||
e = e.reshape(timestep.shape[0], -1, e.shape[-1])
|
||||
e0 = self.time_projection(e).unflatten(2, (6, self.dim))
|
||||
|
||||
# Text embedding (reuses crossattn_cache after first block)
|
||||
context = self.text_embedding(context)
|
||||
|
||||
context_img_len = None
|
||||
if clip_fea is not None and self.img_emb is not None:
|
||||
context_clip = self.img_emb(clip_fea)
|
||||
context = torch.concat([context_clip, context], dim=1)
|
||||
context_img_len = clip_fea.shape[-2]
|
||||
|
||||
# RoPE for current block's temporal position
|
||||
freqs = self.rope_encode(t, h, w, t_start=start_frame, device=x.device, dtype=x.dtype)
|
||||
|
||||
# Transformer blocks
|
||||
for i, block in enumerate(self.blocks):
|
||||
x = block(x, e=e0, freqs=freqs, context=context,
|
||||
context_img_len=context_img_len,
|
||||
kv_cache=kv_caches[i],
|
||||
crossattn_cache=crossattn_caches[i])
|
||||
|
||||
# Head
|
||||
x = self.head(x, e)
|
||||
|
||||
# Unpatchify
|
||||
x = self.unpatchify(x, grid_sizes)
|
||||
return x[:, :, :t, :h, :w]
|
||||
|
||||
def unpatchify(self, x, grid_sizes):
|
||||
c = self.out_dim
|
||||
b = x.shape[0]
|
||||
u = x[:, :math.prod(grid_sizes)].view(b, *grid_sizes, *self.patch_size, c)
|
||||
u = torch.einsum('bfhwpqrc->bcfphqwr', u)
|
||||
u = u.reshape(b, c, *[i * j for i, j in zip(grid_sizes, self.patch_size)])
|
||||
return u
|
||||
|
||||
def init_kv_caches(self, batch_size, max_seq_len, device, dtype):
|
||||
"""Create fresh KV caches for all layers."""
|
||||
caches = []
|
||||
for _ in range(self.num_layers):
|
||||
caches.append({
|
||||
"k": torch.zeros(batch_size, max_seq_len, self.num_heads, self.head_dim, device=device, dtype=dtype),
|
||||
"v": torch.zeros(batch_size, max_seq_len, self.num_heads, self.head_dim, device=device, dtype=dtype),
|
||||
"end": torch.tensor([0], dtype=torch.long, device=device),
|
||||
})
|
||||
return caches
|
||||
|
||||
def init_crossattn_caches(self, batch_size, device, dtype):
|
||||
"""Create fresh cross-attention caches for all layers."""
|
||||
caches = []
|
||||
for _ in range(self.num_layers):
|
||||
caches.append({"is_init": False})
|
||||
return caches
|
||||
|
||||
def reset_kv_caches(self, kv_caches):
|
||||
"""Reset KV caches to empty (reuse allocated memory)."""
|
||||
for cache in kv_caches:
|
||||
cache["end"].fill_(0)
|
||||
|
||||
def reset_crossattn_caches(self, crossattn_caches):
|
||||
"""Reset cross-attention caches."""
|
||||
for cache in crossattn_caches:
|
||||
cache["is_init"] = False
|
||||
|
||||
@property
|
||||
def head_dim(self):
|
||||
return self.dim // self.num_heads
|
||||
|
||||
# Standard forward for non-causal use (compatibility with ComfyUI infrastructure)
|
||||
def forward(self, x, timestep, context, clip_fea=None, time_dim_concat=None, transformer_options={}, **kwargs):
|
||||
bs, c, t, h, w = x.shape
|
||||
x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)
|
||||
|
||||
t_len = t
|
||||
if time_dim_concat is not None:
|
||||
time_dim_concat = comfy.ldm.common_dit.pad_to_patch_size(time_dim_concat, self.patch_size)
|
||||
x = torch.cat([x, time_dim_concat], dim=2)
|
||||
t_len = x.shape[2]
|
||||
|
||||
x = self.patch_embedding(x)
|
||||
grid_sizes = x.shape[2:]
|
||||
x = x.flatten(2).transpose(1, 2)
|
||||
|
||||
freqs = self.rope_encode(t_len, h, w, device=x.device, dtype=x.dtype)
|
||||
|
||||
e = self.time_embedding(
|
||||
sinusoidal_embedding_1d(self.freq_dim, timestep.flatten()))
|
||||
e = e.reshape(timestep.shape[0], -1, e.shape[-1])
|
||||
e0 = self.time_projection(e).unflatten(2, (6, self.dim))
|
||||
|
||||
context = self.text_embedding(context)
|
||||
|
||||
context_img_len = None
|
||||
if clip_fea is not None and self.img_emb is not None:
|
||||
context_clip = self.img_emb(clip_fea)
|
||||
context = torch.concat([context_clip, context], dim=1)
|
||||
context_img_len = clip_fea.shape[-2]
|
||||
|
||||
for block in self.blocks:
|
||||
x = block(x, e=e0, freqs=freqs, context=context,
|
||||
context_img_len=context_img_len,
|
||||
transformer_options=transformer_options)
|
||||
|
||||
x = self.head(x, e)
|
||||
x = self.unpatchify(x, grid_sizes)
|
||||
return x[:, :, :t, :h, :w]
|
||||
272
comfy_extras/nodes_causal_forcing.py
Normal file
272
comfy_extras/nodes_causal_forcing.py
Normal file
@ -0,0 +1,272 @@
|
||||
"""
|
||||
ComfyUI nodes for Causal Forcing autoregressive video generation.
|
||||
- LoadCausalForcingModel: load original HF/training or pre-converted checkpoints
|
||||
(auto-detects format and converts state dict at runtime)
|
||||
- CausalForcingSampler: autoregressive frame-by-frame sampling with KV cache
|
||||
"""
|
||||
|
||||
import torch
|
||||
import logging
|
||||
import folder_paths
|
||||
from typing_extensions import override
|
||||
|
||||
import comfy.model_management
|
||||
import comfy.utils
|
||||
import comfy.ops
|
||||
import comfy.latent_formats
|
||||
from comfy.model_patcher import ModelPatcher
|
||||
from comfy.ldm.wan.causal_model import CausalWanModel
|
||||
from comfy.ldm.wan.causal_convert import extract_state_dict
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
|
||||
# ── Model size presets derived from Wan 2.1 configs ──────────────────────────
|
||||
WAN_CONFIGS = {
|
||||
# dim → (ffn_dim, num_heads, num_layers, text_dim)
|
||||
1536: (8960, 12, 30, 4096), # 1.3B
|
||||
2048: (8192, 16, 32, 4096), # ~2B
|
||||
5120: (13824, 40, 40, 4096), # 14B
|
||||
}
|
||||
|
||||
|
||||
class LoadCausalForcingModel(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="LoadCausalForcingModel",
|
||||
category="loaders/video_models",
|
||||
inputs=[
|
||||
io.Combo.Input("ckpt_name", options=folder_paths.get_filename_list("diffusion_models")),
|
||||
],
|
||||
outputs=[
|
||||
io.Model.Output(display_name="MODEL"),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, ckpt_name) -> io.NodeOutput:
|
||||
ckpt_path = folder_paths.get_full_path_or_raise("diffusion_models", ckpt_name)
|
||||
raw = comfy.utils.load_torch_file(ckpt_path)
|
||||
sd = extract_state_dict(raw, use_ema=True)
|
||||
del raw
|
||||
|
||||
dim = sd["head.modulation"].shape[-1]
|
||||
out_dim = sd["head.head.weight"].shape[0] // 4 # prod(patch_size) * out_dim
|
||||
in_dim = sd["patch_embedding.weight"].shape[1]
|
||||
num_layers = 0
|
||||
while f"blocks.{num_layers}.self_attn.q.weight" in sd:
|
||||
num_layers += 1
|
||||
|
||||
if dim in WAN_CONFIGS:
|
||||
ffn_dim, num_heads, expected_layers, text_dim = WAN_CONFIGS[dim]
|
||||
else:
|
||||
num_heads = dim // 128
|
||||
ffn_dim = sd["blocks.0.ffn.0.weight"].shape[0]
|
||||
text_dim = 4096
|
||||
logging.warning(f"CausalForcing: unknown dim={dim}, inferring num_heads={num_heads}, ffn_dim={ffn_dim}")
|
||||
|
||||
cross_attn_norm = "blocks.0.norm3.weight" in sd
|
||||
|
||||
load_device = comfy.model_management.get_torch_device()
|
||||
offload_device = comfy.model_management.unet_offload_device()
|
||||
ops = comfy.ops.disable_weight_init
|
||||
|
||||
model = CausalWanModel(
|
||||
model_type='t2v',
|
||||
patch_size=(1, 2, 2),
|
||||
text_len=512,
|
||||
in_dim=in_dim,
|
||||
dim=dim,
|
||||
ffn_dim=ffn_dim,
|
||||
freq_dim=256,
|
||||
text_dim=text_dim,
|
||||
out_dim=out_dim,
|
||||
num_heads=num_heads,
|
||||
num_layers=num_layers,
|
||||
window_size=(-1, -1),
|
||||
qk_norm=True,
|
||||
cross_attn_norm=cross_attn_norm,
|
||||
eps=1e-6,
|
||||
device=offload_device,
|
||||
dtype=torch.bfloat16,
|
||||
operations=ops,
|
||||
)
|
||||
|
||||
model.load_state_dict(sd, strict=False)
|
||||
model.eval()
|
||||
|
||||
model_size = comfy.model_management.module_size(model)
|
||||
patcher = ModelPatcher(model, load_device=load_device,
|
||||
offload_device=offload_device, size=model_size)
|
||||
patcher.model.latent_format = comfy.latent_formats.Wan21()
|
||||
return io.NodeOutput(patcher)
|
||||
|
||||
|
||||
class CausalForcingSampler(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="CausalForcingSampler",
|
||||
category="sampling",
|
||||
inputs=[
|
||||
io.Model.Input("model"),
|
||||
io.Conditioning.Input("positive"),
|
||||
io.Int.Input("seed", default=0, min=0, max=0xffffffffffffffff, control_after_generate=True),
|
||||
io.Int.Input("width", default=832, min=16, max=8192, step=16),
|
||||
io.Int.Input("height", default=480, min=16, max=8192, step=16),
|
||||
io.Int.Input("num_frames", default=81, min=1, max=1024, step=4),
|
||||
io.Int.Input("num_frame_per_block", default=1, min=1, max=21),
|
||||
io.Float.Input("timestep_shift", default=5.0, min=0.1, max=20.0, step=0.1),
|
||||
io.String.Input("denoising_steps", default="1000,750,500,250"),
|
||||
],
|
||||
outputs=[
|
||||
io.Latent.Output(display_name="LATENT"),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, model, positive, seed, width, height,
|
||||
num_frames, num_frame_per_block, timestep_shift,
|
||||
denoising_steps) -> io.NodeOutput:
|
||||
|
||||
device = comfy.model_management.get_torch_device()
|
||||
|
||||
# Parse denoising steps
|
||||
step_values = [int(s.strip()) for s in denoising_steps.split(",")]
|
||||
|
||||
# Build scheduler sigmas (FlowMatch with shift)
|
||||
num_train_timesteps = 1000
|
||||
raw_sigmas = torch.linspace(1.0, 0.003 / 1.002, num_train_timesteps + 1)[:-1]
|
||||
sigmas = timestep_shift * raw_sigmas / (1.0 + (timestep_shift - 1.0) * raw_sigmas)
|
||||
timesteps = sigmas * num_train_timesteps
|
||||
|
||||
# Warp denoising step indices to actual timestep values
|
||||
all_timesteps = torch.cat([timesteps, torch.tensor([0.0])])
|
||||
warped_steps = all_timesteps[num_train_timesteps - torch.tensor(step_values, dtype=torch.long)]
|
||||
|
||||
# Get the CausalWanModel from the patcher
|
||||
comfy.model_management.load_model_gpu(model)
|
||||
causal_model = model.model
|
||||
dtype = torch.bfloat16
|
||||
|
||||
# Extract text embeddings from conditioning
|
||||
cond = positive[0][0].to(device=device, dtype=dtype)
|
||||
if cond.ndim == 2:
|
||||
cond = cond.unsqueeze(0)
|
||||
|
||||
# Latent dimensions
|
||||
lat_h = height // 8
|
||||
lat_w = width // 8
|
||||
lat_t = ((num_frames - 1) // 4) + 1 # Wan VAE temporal compression
|
||||
in_channels = 16
|
||||
|
||||
# Generate noise
|
||||
generator = torch.Generator(device="cpu").manual_seed(seed)
|
||||
noise = torch.randn(1, in_channels, lat_t, lat_h, lat_w,
|
||||
generator=generator, device="cpu").to(device=device, dtype=dtype)
|
||||
|
||||
assert lat_t % num_frame_per_block == 0, \
|
||||
f"Latent frames ({lat_t}) must be divisible by num_frame_per_block ({num_frame_per_block})"
|
||||
num_blocks = lat_t // num_frame_per_block
|
||||
|
||||
# Tokens per frame: (H/patch_h) * (W/patch_w) per temporal patch
|
||||
frame_seq_len = (lat_h // 2) * (lat_w // 2) # patch_size = (1,2,2)
|
||||
max_seq_len = lat_t * frame_seq_len
|
||||
|
||||
# Initialize caches
|
||||
kv_caches = causal_model.init_kv_caches(1, max_seq_len, device, dtype)
|
||||
crossattn_caches = causal_model.init_crossattn_caches(1, device, dtype)
|
||||
|
||||
output = torch.zeros_like(noise)
|
||||
pbar = comfy.utils.ProgressBar(num_blocks * len(warped_steps) + num_blocks)
|
||||
|
||||
current_start_frame = 0
|
||||
for block_idx in range(num_blocks):
|
||||
block_frames = num_frame_per_block
|
||||
frame_start = current_start_frame
|
||||
frame_end = current_start_frame + block_frames
|
||||
|
||||
# Noise slice for this block: [B, C, block_frames, H, W]
|
||||
noisy_input = noise[:, :, frame_start:frame_end]
|
||||
|
||||
# Denoising loop (e.g. 4 steps)
|
||||
for step_idx, current_timestep in enumerate(warped_steps):
|
||||
t_val = current_timestep.item()
|
||||
|
||||
# Per-frame timestep tensor [B, block_frames]
|
||||
timestep_tensor = torch.full(
|
||||
(1, block_frames), t_val, device=device, dtype=dtype)
|
||||
|
||||
# Model forward
|
||||
flow_pred = causal_model.forward_block(
|
||||
x=noisy_input,
|
||||
timestep=timestep_tensor,
|
||||
context=cond,
|
||||
start_frame=current_start_frame,
|
||||
kv_caches=kv_caches,
|
||||
crossattn_caches=crossattn_caches,
|
||||
)
|
||||
|
||||
# x0 = input - sigma * flow_pred
|
||||
sigma_t = _lookup_sigma(sigmas, timesteps, t_val)
|
||||
denoised = noisy_input - sigma_t * flow_pred
|
||||
|
||||
if step_idx < len(warped_steps) - 1:
|
||||
# Add noise for next step
|
||||
next_t = warped_steps[step_idx + 1].item()
|
||||
sigma_next = _lookup_sigma(sigmas, timesteps, next_t)
|
||||
fresh_noise = torch.randn_like(denoised)
|
||||
noisy_input = (1.0 - sigma_next) * denoised + sigma_next * fresh_noise
|
||||
|
||||
# Roll back KV cache end pointer so next step re-writes same positions
|
||||
for cache in kv_caches:
|
||||
cache["end"].fill_(cache["end"].item() - block_frames * frame_seq_len)
|
||||
else:
|
||||
noisy_input = denoised
|
||||
|
||||
pbar.update(1)
|
||||
|
||||
output[:, :, frame_start:frame_end] = noisy_input
|
||||
|
||||
# Cache update: forward at t=0 with clean output to fill KV cache
|
||||
with torch.no_grad():
|
||||
# Reset cache end to before this block so the t=0 pass writes clean K/V
|
||||
for cache in kv_caches:
|
||||
cache["end"].fill_(cache["end"].item() - block_frames * frame_seq_len)
|
||||
|
||||
t_zero = torch.zeros(1, block_frames, device=device, dtype=dtype)
|
||||
causal_model.forward_block(
|
||||
x=noisy_input,
|
||||
timestep=t_zero,
|
||||
context=cond,
|
||||
start_frame=current_start_frame,
|
||||
kv_caches=kv_caches,
|
||||
crossattn_caches=crossattn_caches,
|
||||
)
|
||||
|
||||
pbar.update(1)
|
||||
current_start_frame += block_frames
|
||||
|
||||
# Apply latent format scaling
|
||||
latent_format = comfy.latent_formats.Wan21()
|
||||
output_scaled = latent_format.process_in(output.float().cpu())
|
||||
|
||||
return io.NodeOutput({"samples": output_scaled})
|
||||
|
||||
|
||||
def _lookup_sigma(sigmas, timesteps, t_val):
|
||||
"""Find the sigma corresponding to a timestep value."""
|
||||
idx = torch.argmin((timesteps - t_val).abs()).item()
|
||||
return sigmas[idx]
|
||||
|
||||
|
||||
class CausalForcingExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
return [
|
||||
LoadCausalForcingModel,
|
||||
CausalForcingSampler,
|
||||
]
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> CausalForcingExtension:
|
||||
return CausalForcingExtension()
|
||||
Loading…
Reference in New Issue
Block a user