From c0de57725b28c60de1f0cd72d6ac6bdd3e8103c2 Mon Sep 17 00:00:00 2001 From: Talmaj Marinc Date: Fri, 20 Mar 2026 15:01:27 +0100 Subject: [PATCH 01/12] Initial commit causal_forcing. --- comfy/ldm/wan/causal_model.py | 392 +++++++++++++++++++++++++++ comfy_extras/nodes_causal_forcing.py | 272 +++++++++++++++++++ nodes.py | 1 + 3 files changed, 665 insertions(+) create mode 100644 comfy/ldm/wan/causal_model.py create mode 100644 comfy_extras/nodes_causal_forcing.py diff --git a/comfy/ldm/wan/causal_model.py b/comfy/ldm/wan/causal_model.py new file mode 100644 index 000000000..268f7ac34 --- /dev/null +++ b/comfy/ldm/wan/causal_model.py @@ -0,0 +1,392 @@ +""" +CausalWanModel: Wan 2.1 backbone with KV-cached causal self-attention for +autoregressive (frame-by-frame) video generation via Causal Forcing. + +Weight-compatible with the standard WanModel -- same layer names, same shapes. +The difference is purely in the forward pass: this model processes one temporal +block at a time and maintains a KV cache across blocks. 
+ +Reference: https://github.com/thu-ml/Causal-Forcing +""" + +import math +import torch +import torch.nn as nn + +from comfy.ldm.modules.attention import optimized_attention +from comfy.ldm.flux.layers import EmbedND +from comfy.ldm.flux.math import apply_rope1 +from comfy.ldm.wan.model import ( + sinusoidal_embedding_1d, + WanT2VCrossAttention, + WAN_CROSSATTENTION_CLASSES, + Head, + MLPProj, + repeat_e, +) +import comfy.ldm.common_dit +import comfy.model_management + + +class CausalWanSelfAttention(nn.Module): + """Self-attention with KV cache support for autoregressive inference.""" + + def __init__(self, dim, num_heads, window_size=(-1, -1), qk_norm=True, + eps=1e-6, operation_settings={}): + assert dim % num_heads == 0 + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.qk_norm = qk_norm + self.eps = eps + + ops = operation_settings.get("operations") + device = operation_settings.get("device") + dtype = operation_settings.get("dtype") + + self.q = ops.Linear(dim, dim, device=device, dtype=dtype) + self.k = ops.Linear(dim, dim, device=device, dtype=dtype) + self.v = ops.Linear(dim, dim, device=device, dtype=dtype) + self.o = ops.Linear(dim, dim, device=device, dtype=dtype) + self.norm_q = ops.RMSNorm(dim, eps=eps, elementwise_affine=True, device=device, dtype=dtype) if qk_norm else nn.Identity() + self.norm_k = ops.RMSNorm(dim, eps=eps, elementwise_affine=True, device=device, dtype=dtype) if qk_norm else nn.Identity() + + def forward(self, x, freqs, kv_cache=None, transformer_options={}): + b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim + + q = apply_rope1(self.norm_q(self.q(x)).view(b, s, n, d), freqs) + k = apply_rope1(self.norm_k(self.k(x)).view(b, s, n, d), freqs) + v = self.v(x).view(b, s, n, d) + + if kv_cache is None: + x = optimized_attention( + q.view(b, s, n * d), + k.view(b, s, n * d), + v.view(b, s, n * d), + heads=self.num_heads, + transformer_options=transformer_options, + ) + 
else: + end = kv_cache["end"].item() + new_end = end + s + + # Roped K and plain V go into cache + kv_cache["k"][:, end:new_end] = k + kv_cache["v"][:, end:new_end] = v + kv_cache["end"].fill_(new_end) + + x = optimized_attention( + q.view(b, s, n * d), + kv_cache["k"][:, :new_end].view(b, new_end, n * d), + kv_cache["v"][:, :new_end].view(b, new_end, n * d), + heads=self.num_heads, + transformer_options=transformer_options, + ) + + x = self.o(x) + return x + + +class CausalWanAttentionBlock(nn.Module): + """Transformer block with KV-cached self-attention and cross-attention caching.""" + + def __init__(self, cross_attn_type, dim, ffn_dim, num_heads, + window_size=(-1, -1), qk_norm=True, cross_attn_norm=False, + eps=1e-6, operation_settings={}): + super().__init__() + self.dim = dim + self.ffn_dim = ffn_dim + self.num_heads = num_heads + + ops = operation_settings.get("operations") + device = operation_settings.get("device") + dtype = operation_settings.get("dtype") + + self.norm1 = ops.LayerNorm(dim, eps, elementwise_affine=False, device=device, dtype=dtype) + self.self_attn = CausalWanSelfAttention(dim, num_heads, window_size, qk_norm, eps, operation_settings=operation_settings) + self.norm3 = ops.LayerNorm(dim, eps, elementwise_affine=True, device=device, dtype=dtype) if cross_attn_norm else nn.Identity() + self.cross_attn = WAN_CROSSATTENTION_CLASSES[cross_attn_type]( + dim, num_heads, (-1, -1), qk_norm, eps, operation_settings=operation_settings) + self.norm2 = ops.LayerNorm(dim, eps, elementwise_affine=False, device=device, dtype=dtype) + self.ffn = nn.Sequential( + ops.Linear(dim, ffn_dim, device=device, dtype=dtype), + nn.GELU(approximate='tanh'), + ops.Linear(ffn_dim, dim, device=device, dtype=dtype)) + + self.modulation = nn.Parameter(torch.empty(1, 6, dim, device=device, dtype=dtype)) + + def forward(self, x, e, freqs, context, context_img_len=257, + kv_cache=None, crossattn_cache=None, transformer_options={}): + if e.ndim < 4: + e = 
(comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device) + e).chunk(6, dim=1) + else: + e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device).unsqueeze(0) + e).unbind(2) + + # Self-attention with optional KV cache + x = x.contiguous() + y = self.self_attn( + torch.addcmul(repeat_e(e[0], x), self.norm1(x), 1 + repeat_e(e[1], x)), + freqs, kv_cache=kv_cache, transformer_options=transformer_options) + x = torch.addcmul(x, y, repeat_e(e[2], x)) + del y + + # Cross-attention with optional caching + if crossattn_cache is not None and crossattn_cache.get("is_init"): + q = self.cross_attn.norm_q(self.cross_attn.q(self.norm3(x))) + x_ca = optimized_attention( + q, crossattn_cache["k"], crossattn_cache["v"], + heads=self.num_heads, transformer_options=transformer_options) + x = x + self.cross_attn.o(x_ca) + else: + x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len, transformer_options=transformer_options) + if crossattn_cache is not None: + crossattn_cache["k"] = self.cross_attn.norm_k(self.cross_attn.k(context)) + crossattn_cache["v"] = self.cross_attn.v(context) + crossattn_cache["is_init"] = True + + # FFN + y = self.ffn(torch.addcmul(repeat_e(e[3], x), self.norm2(x), 1 + repeat_e(e[4], x))) + x = torch.addcmul(x, y, repeat_e(e[5], x)) + return x + + +class CausalWanModel(torch.nn.Module): + """ + Wan 2.1 diffusion backbone with causal KV-cache support. + + Same weight structure as WanModel -- loads identical state dicts. + Adds forward_block() for frame-by-frame autoregressive inference. 
+ """ + + def __init__(self, + model_type='t2v', + patch_size=(1, 2, 2), + text_len=512, + in_dim=16, + dim=2048, + ffn_dim=8192, + freq_dim=256, + text_dim=4096, + out_dim=16, + num_heads=16, + num_layers=32, + window_size=(-1, -1), + qk_norm=True, + cross_attn_norm=True, + eps=1e-6, + image_model=None, + device=None, + dtype=None, + operations=None): + super().__init__() + self.dtype = dtype + operation_settings = {"operations": operations, "device": device, "dtype": dtype} + + self.model_type = model_type + self.patch_size = patch_size + self.text_len = text_len + self.in_dim = in_dim + self.dim = dim + self.ffn_dim = ffn_dim + self.freq_dim = freq_dim + self.text_dim = text_dim + self.out_dim = out_dim + self.num_heads = num_heads + self.num_layers = num_layers + self.window_size = window_size + self.qk_norm = qk_norm + self.cross_attn_norm = cross_attn_norm + self.eps = eps + + self.patch_embedding = operations.Conv3d( + in_dim, dim, kernel_size=patch_size, stride=patch_size, + device=device, dtype=dtype) + self.text_embedding = nn.Sequential( + operations.Linear(text_dim, dim, device=device, dtype=dtype), + nn.GELU(approximate='tanh'), + operations.Linear(dim, dim, device=device, dtype=dtype)) + self.time_embedding = nn.Sequential( + operations.Linear(freq_dim, dim, device=device, dtype=dtype), + nn.SiLU(), + operations.Linear(dim, dim, device=device, dtype=dtype)) + self.time_projection = nn.Sequential( + nn.SiLU(), + operations.Linear(dim, dim * 6, device=device, dtype=dtype)) + + cross_attn_type = 't2v_cross_attn' if model_type == 't2v' else 'i2v_cross_attn' + self.blocks = nn.ModuleList([ + CausalWanAttentionBlock( + cross_attn_type, dim, ffn_dim, num_heads, + window_size, qk_norm, cross_attn_norm, eps, + operation_settings=operation_settings) + for _ in range(num_layers) + ]) + + self.head = Head(dim, out_dim, patch_size, eps, operation_settings=operation_settings) + + d = dim // num_heads + self.rope_embedder = EmbedND( + dim=d, theta=10000.0, + 
axes_dim=[d - 4 * (d // 6), 2 * (d // 6), 2 * (d // 6)]) + + if model_type == 'i2v': + self.img_emb = MLPProj(1280, dim, operation_settings=operation_settings) + else: + self.img_emb = None + + self.ref_conv = None + + def rope_encode(self, t, h, w, t_start=0, device=None, dtype=None): + patch_size = self.patch_size + t_len = ((t + (patch_size[0] // 2)) // patch_size[0]) + h_len = ((h + (patch_size[1] // 2)) // patch_size[1]) + w_len = ((w + (patch_size[2] // 2)) // patch_size[2]) + + img_ids = torch.zeros((t_len, h_len, w_len, 3), device=device, dtype=dtype) + img_ids[:, :, :, 0] += torch.linspace( + t_start, t_start + (t_len - 1), steps=t_len, device=device, dtype=dtype + ).reshape(-1, 1, 1) + img_ids[:, :, :, 1] += torch.linspace( + 0, h_len - 1, steps=h_len, device=device, dtype=dtype + ).reshape(1, -1, 1) + img_ids[:, :, :, 2] += torch.linspace( + 0, w_len - 1, steps=w_len, device=device, dtype=dtype + ).reshape(1, 1, -1) + img_ids = img_ids.reshape(1, -1, img_ids.shape[-1]) + return self.rope_embedder(img_ids).movedim(1, 2) + + def forward_block(self, x, timestep, context, start_frame, + kv_caches, crossattn_caches, clip_fea=None): + """ + Forward one temporal block for autoregressive inference. 
+ + Args: + x: [B, C, block_frames, H, W] input latent for the current block + timestep: [B, block_frames] per-frame timesteps + context: [B, L, text_dim] raw text embeddings (pre-text_embedding) + start_frame: temporal frame index for RoPE offset + kv_caches: list of per-layer KV cache dicts + crossattn_caches: list of per-layer cross-attention cache dicts + clip_fea: optional CLIP features for I2V + + Returns: + flow_pred: [B, C_out, block_frames, H, W] flow prediction + """ + x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size) + bs, c, t, h, w = x.shape + + x = self.patch_embedding(x) + grid_sizes = x.shape[2:] + x = x.flatten(2).transpose(1, 2) + + # Per-frame time embedding → [B, block_frames, 6, dim] + e = self.time_embedding( + sinusoidal_embedding_1d(self.freq_dim, timestep.flatten())) + e = e.reshape(timestep.shape[0], -1, e.shape[-1]) + e0 = self.time_projection(e).unflatten(2, (6, self.dim)) + + # Text embedding (reuses crossattn_cache after first block) + context = self.text_embedding(context) + + context_img_len = None + if clip_fea is not None and self.img_emb is not None: + context_clip = self.img_emb(clip_fea) + context = torch.concat([context_clip, context], dim=1) + context_img_len = clip_fea.shape[-2] + + # RoPE for current block's temporal position + freqs = self.rope_encode(t, h, w, t_start=start_frame, device=x.device, dtype=x.dtype) + + # Transformer blocks + for i, block in enumerate(self.blocks): + x = block(x, e=e0, freqs=freqs, context=context, + context_img_len=context_img_len, + kv_cache=kv_caches[i], + crossattn_cache=crossattn_caches[i]) + + # Head + x = self.head(x, e) + + # Unpatchify + x = self.unpatchify(x, grid_sizes) + return x[:, :, :t, :h, :w] + + def unpatchify(self, x, grid_sizes): + c = self.out_dim + b = x.shape[0] + u = x[:, :math.prod(grid_sizes)].view(b, *grid_sizes, *self.patch_size, c) + u = torch.einsum('bfhwpqrc->bcfphqwr', u) + u = u.reshape(b, c, *[i * j for i, j in zip(grid_sizes, self.patch_size)]) + 
return u + + def init_kv_caches(self, batch_size, max_seq_len, device, dtype): + """Create fresh KV caches for all layers.""" + caches = [] + for _ in range(self.num_layers): + caches.append({ + "k": torch.zeros(batch_size, max_seq_len, self.num_heads, self.head_dim, device=device, dtype=dtype), + "v": torch.zeros(batch_size, max_seq_len, self.num_heads, self.head_dim, device=device, dtype=dtype), + "end": torch.tensor([0], dtype=torch.long, device=device), + }) + return caches + + def init_crossattn_caches(self, batch_size, device, dtype): + """Create fresh cross-attention caches for all layers.""" + caches = [] + for _ in range(self.num_layers): + caches.append({"is_init": False}) + return caches + + def reset_kv_caches(self, kv_caches): + """Reset KV caches to empty (reuse allocated memory).""" + for cache in kv_caches: + cache["end"].fill_(0) + + def reset_crossattn_caches(self, crossattn_caches): + """Reset cross-attention caches.""" + for cache in crossattn_caches: + cache["is_init"] = False + + @property + def head_dim(self): + return self.dim // self.num_heads + + # Standard forward for non-causal use (compatibility with ComfyUI infrastructure) + def forward(self, x, timestep, context, clip_fea=None, time_dim_concat=None, transformer_options={}, **kwargs): + bs, c, t, h, w = x.shape + x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size) + + t_len = t + if time_dim_concat is not None: + time_dim_concat = comfy.ldm.common_dit.pad_to_patch_size(time_dim_concat, self.patch_size) + x = torch.cat([x, time_dim_concat], dim=2) + t_len = x.shape[2] + + x = self.patch_embedding(x) + grid_sizes = x.shape[2:] + x = x.flatten(2).transpose(1, 2) + + freqs = self.rope_encode(t_len, h, w, device=x.device, dtype=x.dtype) + + e = self.time_embedding( + sinusoidal_embedding_1d(self.freq_dim, timestep.flatten())) + e = e.reshape(timestep.shape[0], -1, e.shape[-1]) + e0 = self.time_projection(e).unflatten(2, (6, self.dim)) + + context = self.text_embedding(context) + 
+ context_img_len = None + if clip_fea is not None and self.img_emb is not None: + context_clip = self.img_emb(clip_fea) + context = torch.concat([context_clip, context], dim=1) + context_img_len = clip_fea.shape[-2] + + for block in self.blocks: + x = block(x, e=e0, freqs=freqs, context=context, + context_img_len=context_img_len, + transformer_options=transformer_options) + + x = self.head(x, e) + x = self.unpatchify(x, grid_sizes) + return x[:, :, :t, :h, :w] diff --git a/comfy_extras/nodes_causal_forcing.py b/comfy_extras/nodes_causal_forcing.py new file mode 100644 index 000000000..3646cec37 --- /dev/null +++ b/comfy_extras/nodes_causal_forcing.py @@ -0,0 +1,272 @@ +""" +ComfyUI nodes for Causal Forcing autoregressive video generation. + - LoadCausalForcingModel: load original HF/training or pre-converted checkpoints + (auto-detects format and converts state dict at runtime) + - CausalForcingSampler: autoregressive frame-by-frame sampling with KV cache +""" + +import torch +import logging +import folder_paths +from typing_extensions import override + +import comfy.model_management +import comfy.utils +import comfy.ops +import comfy.latent_formats +from comfy.model_patcher import ModelPatcher +from comfy.ldm.wan.causal_model import CausalWanModel +from comfy.ldm.wan.causal_convert import extract_state_dict +from comfy_api.latest import ComfyExtension, io + +# ── Model size presets derived from Wan 2.1 configs ────────────────────────── +WAN_CONFIGS = { + # dim → (ffn_dim, num_heads, num_layers, text_dim) + 1536: (8960, 12, 30, 4096), # 1.3B + 2048: (8192, 16, 32, 4096), # ~2B + 5120: (13824, 40, 40, 4096), # 14B +} + + +class LoadCausalForcingModel(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="LoadCausalForcingModel", + category="loaders/video_models", + inputs=[ + io.Combo.Input("ckpt_name", options=folder_paths.get_filename_list("diffusion_models")), + ], + outputs=[ + io.Model.Output(display_name="MODEL"), + ], + ) + + 
@classmethod + def execute(cls, ckpt_name) -> io.NodeOutput: + ckpt_path = folder_paths.get_full_path_or_raise("diffusion_models", ckpt_name) + raw = comfy.utils.load_torch_file(ckpt_path) + sd = extract_state_dict(raw, use_ema=True) + del raw + + dim = sd["head.modulation"].shape[-1] + out_dim = sd["head.head.weight"].shape[0] // 4 # prod(patch_size) * out_dim + in_dim = sd["patch_embedding.weight"].shape[1] + num_layers = 0 + while f"blocks.{num_layers}.self_attn.q.weight" in sd: + num_layers += 1 + + if dim in WAN_CONFIGS: + ffn_dim, num_heads, expected_layers, text_dim = WAN_CONFIGS[dim] + else: + num_heads = dim // 128 + ffn_dim = sd["blocks.0.ffn.0.weight"].shape[0] + text_dim = 4096 + logging.warning(f"CausalForcing: unknown dim={dim}, inferring num_heads={num_heads}, ffn_dim={ffn_dim}") + + cross_attn_norm = "blocks.0.norm3.weight" in sd + + load_device = comfy.model_management.get_torch_device() + offload_device = comfy.model_management.unet_offload_device() + ops = comfy.ops.disable_weight_init + + model = CausalWanModel( + model_type='t2v', + patch_size=(1, 2, 2), + text_len=512, + in_dim=in_dim, + dim=dim, + ffn_dim=ffn_dim, + freq_dim=256, + text_dim=text_dim, + out_dim=out_dim, + num_heads=num_heads, + num_layers=num_layers, + window_size=(-1, -1), + qk_norm=True, + cross_attn_norm=cross_attn_norm, + eps=1e-6, + device=offload_device, + dtype=torch.bfloat16, + operations=ops, + ) + + model.load_state_dict(sd, strict=False) + model.eval() + + model_size = comfy.model_management.module_size(model) + patcher = ModelPatcher(model, load_device=load_device, + offload_device=offload_device, size=model_size) + patcher.model.latent_format = comfy.latent_formats.Wan21() + return io.NodeOutput(patcher) + + +class CausalForcingSampler(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="CausalForcingSampler", + category="sampling", + inputs=[ + io.Model.Input("model"), + io.Conditioning.Input("positive"), + io.Int.Input("seed", 
default=0, min=0, max=0xffffffffffffffff, control_after_generate=True), + io.Int.Input("width", default=832, min=16, max=8192, step=16), + io.Int.Input("height", default=480, min=16, max=8192, step=16), + io.Int.Input("num_frames", default=81, min=1, max=1024, step=4), + io.Int.Input("num_frame_per_block", default=1, min=1, max=21), + io.Float.Input("timestep_shift", default=5.0, min=0.1, max=20.0, step=0.1), + io.String.Input("denoising_steps", default="1000,750,500,250"), + ], + outputs=[ + io.Latent.Output(display_name="LATENT"), + ], + ) + + @classmethod + def execute(cls, model, positive, seed, width, height, + num_frames, num_frame_per_block, timestep_shift, + denoising_steps) -> io.NodeOutput: + + device = comfy.model_management.get_torch_device() + + # Parse denoising steps + step_values = [int(s.strip()) for s in denoising_steps.split(",")] + + # Build scheduler sigmas (FlowMatch with shift) + num_train_timesteps = 1000 + raw_sigmas = torch.linspace(1.0, 0.003 / 1.002, num_train_timesteps + 1)[:-1] + sigmas = timestep_shift * raw_sigmas / (1.0 + (timestep_shift - 1.0) * raw_sigmas) + timesteps = sigmas * num_train_timesteps + + # Warp denoising step indices to actual timestep values + all_timesteps = torch.cat([timesteps, torch.tensor([0.0])]) + warped_steps = all_timesteps[num_train_timesteps - torch.tensor(step_values, dtype=torch.long)] + + # Get the CausalWanModel from the patcher + comfy.model_management.load_model_gpu(model) + causal_model = model.model + dtype = torch.bfloat16 + + # Extract text embeddings from conditioning + cond = positive[0][0].to(device=device, dtype=dtype) + if cond.ndim == 2: + cond = cond.unsqueeze(0) + + # Latent dimensions + lat_h = height // 8 + lat_w = width // 8 + lat_t = ((num_frames - 1) // 4) + 1 # Wan VAE temporal compression + in_channels = 16 + + # Generate noise + generator = torch.Generator(device="cpu").manual_seed(seed) + noise = torch.randn(1, in_channels, lat_t, lat_h, lat_w, + generator=generator, 
device="cpu").to(device=device, dtype=dtype) + + assert lat_t % num_frame_per_block == 0, \ + f"Latent frames ({lat_t}) must be divisible by num_frame_per_block ({num_frame_per_block})" + num_blocks = lat_t // num_frame_per_block + + # Tokens per frame: (H/patch_h) * (W/patch_w) per temporal patch + frame_seq_len = (lat_h // 2) * (lat_w // 2) # patch_size = (1,2,2) + max_seq_len = lat_t * frame_seq_len + + # Initialize caches + kv_caches = causal_model.init_kv_caches(1, max_seq_len, device, dtype) + crossattn_caches = causal_model.init_crossattn_caches(1, device, dtype) + + output = torch.zeros_like(noise) + pbar = comfy.utils.ProgressBar(num_blocks * len(warped_steps) + num_blocks) + + current_start_frame = 0 + for block_idx in range(num_blocks): + block_frames = num_frame_per_block + frame_start = current_start_frame + frame_end = current_start_frame + block_frames + + # Noise slice for this block: [B, C, block_frames, H, W] + noisy_input = noise[:, :, frame_start:frame_end] + + # Denoising loop (e.g. 
4 steps) + for step_idx, current_timestep in enumerate(warped_steps): + t_val = current_timestep.item() + + # Per-frame timestep tensor [B, block_frames] + timestep_tensor = torch.full( + (1, block_frames), t_val, device=device, dtype=dtype) + + # Model forward + flow_pred = causal_model.forward_block( + x=noisy_input, + timestep=timestep_tensor, + context=cond, + start_frame=current_start_frame, + kv_caches=kv_caches, + crossattn_caches=crossattn_caches, + ) + + # x0 = input - sigma * flow_pred + sigma_t = _lookup_sigma(sigmas, timesteps, t_val) + denoised = noisy_input - sigma_t * flow_pred + + if step_idx < len(warped_steps) - 1: + # Add noise for next step + next_t = warped_steps[step_idx + 1].item() + sigma_next = _lookup_sigma(sigmas, timesteps, next_t) + fresh_noise = torch.randn_like(denoised) + noisy_input = (1.0 - sigma_next) * denoised + sigma_next * fresh_noise + + # Roll back KV cache end pointer so next step re-writes same positions + for cache in kv_caches: + cache["end"].fill_(cache["end"].item() - block_frames * frame_seq_len) + else: + noisy_input = denoised + + pbar.update(1) + + output[:, :, frame_start:frame_end] = noisy_input + + # Cache update: forward at t=0 with clean output to fill KV cache + with torch.no_grad(): + # Reset cache end to before this block so the t=0 pass writes clean K/V + for cache in kv_caches: + cache["end"].fill_(cache["end"].item() - block_frames * frame_seq_len) + + t_zero = torch.zeros(1, block_frames, device=device, dtype=dtype) + causal_model.forward_block( + x=noisy_input, + timestep=t_zero, + context=cond, + start_frame=current_start_frame, + kv_caches=kv_caches, + crossattn_caches=crossattn_caches, + ) + + pbar.update(1) + current_start_frame += block_frames + + # Apply latent format scaling + latent_format = comfy.latent_formats.Wan21() + output_scaled = latent_format.process_in(output.float().cpu()) + + return io.NodeOutput({"samples": output_scaled}) + + +def _lookup_sigma(sigmas, timesteps, t_val): + """Find 
the sigma corresponding to a timestep value.""" + idx = torch.argmin((timesteps - t_val).abs()).item() + return sigmas[idx] + + +class CausalForcingExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + LoadCausalForcingModel, + CausalForcingSampler, + ] + + +async def comfy_entrypoint() -> CausalForcingExtension: + return CausalForcingExtension() diff --git a/nodes.py b/nodes.py index 37ceac2fc..66528c24d 100644 --- a/nodes.py +++ b/nodes.py @@ -2443,6 +2443,7 @@ async def init_builtin_extra_nodes(): "nodes_nop.py", "nodes_kandinsky5.py", "nodes_wanmove.py", + "nodes_causal_forcing.py", "nodes_image_compare.py", "nodes_zimage.py", "nodes_glsl.py", From 0836390c27ebf1c3f193e12c0d391073ae9c91f7 Mon Sep 17 00:00:00 2001 From: Talmaj Marinc Date: Fri, 20 Mar 2026 20:37:49 +0100 Subject: [PATCH 02/12] Fix CausalForcingSampler. --- comfy_extras/nodes_causal_forcing.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/comfy_extras/nodes_causal_forcing.py b/comfy_extras/nodes_causal_forcing.py index 3646cec37..23c7049a4 100644 --- a/comfy_extras/nodes_causal_forcing.py +++ b/comfy_extras/nodes_causal_forcing.py @@ -246,11 +246,10 @@ class CausalForcingSampler(io.ComfyNode): pbar.update(1) current_start_frame += block_frames - # Apply latent format scaling + # Denormalize latents because VAEDecode expects raw latents. latent_format = comfy.latent_formats.Wan21() - output_scaled = latent_format.process_in(output.float().cpu()) - - return io.NodeOutput({"samples": output_scaled}) + output_denorm = latent_format.process_out(output.float().cpu()) + return io.NodeOutput({"samples": output_denorm}) def _lookup_sigma(sigmas, timesteps, t_val): From 2f30a821c59b688d44638e58b8faee5765fdcc59 Mon Sep 17 00:00:00 2001 From: Talmaj Marinc Date: Fri, 20 Mar 2026 21:05:23 +0100 Subject: [PATCH 03/12] Rename causal forcing to use the more general autoregressive naming convention. 
--- comfy/ldm/wan/ar_convert.py | 70 +++++++++++++++++++ .../ldm/wan/{causal_model.py => ar_model.py} | 0 ...es_causal_forcing.py => nodes_ar_video.py} | 30 ++++---- nodes.py | 2 +- 4 files changed, 86 insertions(+), 16 deletions(-) create mode 100644 comfy/ldm/wan/ar_convert.py rename comfy/ldm/wan/{causal_model.py => ar_model.py} (100%) rename comfy_extras/{nodes_causal_forcing.py => nodes_ar_video.py} (92%) diff --git a/comfy/ldm/wan/ar_convert.py b/comfy/ldm/wan/ar_convert.py new file mode 100644 index 000000000..bec5c1014 --- /dev/null +++ b/comfy/ldm/wan/ar_convert.py @@ -0,0 +1,70 @@ +""" +State dict conversion for Causal Forcing checkpoints. + +Handles three checkpoint layouts: + 1. Training checkpoint with top-level generator_ema / generator keys + 2. Already-flattened state dict with model.* prefixes + 3. Already-converted ComfyUI state dict (bare model keys) + +Strips prefixes so the result matches the standard Wan 2.1 / CausalWanModel key layout +(e.g. blocks.0.self_attn.q.weight, head.modulation, etc.) +""" + +import logging + +log = logging.getLogger(__name__) + +PREFIXES_TO_STRIP = ["model._fsdp_wrapped_module.", "model."] + +_MODEL_KEY_PREFIXES = ( + "blocks.", "head.", "patch_embedding.", "text_embedding.", + "time_embedding.", "time_projection.", "img_emb.", "rope_embedder.", +) + + +def extract_state_dict(state_dict: dict, use_ema: bool = True) -> dict: + """ + Extract and clean a Causal Forcing state dict from a training checkpoint. + + Returns a state dict with keys matching the CausalWanModel / WanModel layout. 
+ """ + # Case 3: already converted -- keys are bare model keys + if "head.modulation" in state_dict and "blocks.0.self_attn.q.weight" in state_dict: + return state_dict + + # Case 1: training checkpoint with wrapper key + raw_sd = None + order = ["generator_ema", "generator"] if use_ema else ["generator", "generator_ema"] + for wrapper_key in order: + if wrapper_key in state_dict: + raw_sd = state_dict[wrapper_key] + log.info("Causal Forcing: extracted '%s' with %d keys", wrapper_key, len(raw_sd)) + break + + # Case 2: flat dict with model.* prefixes + if raw_sd is None: + if any(k.startswith("model.") for k in state_dict): + raw_sd = state_dict + else: + raise KeyError( + f"Cannot detect Causal Forcing checkpoint layout. " + f"Top-level keys: {list(state_dict.keys())[:20]}" + ) + + out_sd = {} + for k, v in raw_sd.items(): + new_k = k + for prefix in PREFIXES_TO_STRIP: + if new_k.startswith(prefix): + new_k = new_k[len(prefix):] + break + else: + if not new_k.startswith(_MODEL_KEY_PREFIXES): + log.debug("Causal Forcing: skipping non-model key: %s", k) + continue + out_sd[new_k] = v + + if "head.modulation" not in out_sd: + raise ValueError("Conversion failed: 'head.modulation' not found in output keys") + + return out_sd diff --git a/comfy/ldm/wan/causal_model.py b/comfy/ldm/wan/ar_model.py similarity index 100% rename from comfy/ldm/wan/causal_model.py rename to comfy/ldm/wan/ar_model.py diff --git a/comfy_extras/nodes_causal_forcing.py b/comfy_extras/nodes_ar_video.py similarity index 92% rename from comfy_extras/nodes_causal_forcing.py rename to comfy_extras/nodes_ar_video.py index 23c7049a4..08010a6ac 100644 --- a/comfy_extras/nodes_causal_forcing.py +++ b/comfy_extras/nodes_ar_video.py @@ -1,8 +1,8 @@ """ -ComfyUI nodes for Causal Forcing autoregressive video generation. - - LoadCausalForcingModel: load original HF/training or pre-converted checkpoints +ComfyUI nodes for autoregressive video generation (Causal Forcing, Self-Forcing, etc.). 
+ - LoadARVideoModel: load original HF/training or pre-converted checkpoints (auto-detects format and converts state dict at runtime) - - CausalForcingSampler: autoregressive frame-by-frame sampling with KV cache + - ARVideoSampler: autoregressive frame-by-frame sampling with KV cache """ import torch @@ -15,8 +15,8 @@ import comfy.utils import comfy.ops import comfy.latent_formats from comfy.model_patcher import ModelPatcher -from comfy.ldm.wan.causal_model import CausalWanModel -from comfy.ldm.wan.causal_convert import extract_state_dict +from comfy.ldm.wan.ar_model import CausalWanModel +from comfy.ldm.wan.ar_convert import extract_state_dict from comfy_api.latest import ComfyExtension, io # ── Model size presets derived from Wan 2.1 configs ────────────────────────── @@ -28,11 +28,11 @@ WAN_CONFIGS = { } -class LoadCausalForcingModel(io.ComfyNode): +class LoadARVideoModel(io.ComfyNode): @classmethod def define_schema(cls): return io.Schema( - node_id="LoadCausalForcingModel", + node_id="LoadARVideoModel", category="loaders/video_models", inputs=[ io.Combo.Input("ckpt_name", options=folder_paths.get_filename_list("diffusion_models")), @@ -62,7 +62,7 @@ class LoadCausalForcingModel(io.ComfyNode): num_heads = dim // 128 ffn_dim = sd["blocks.0.ffn.0.weight"].shape[0] text_dim = 4096 - logging.warning(f"CausalForcing: unknown dim={dim}, inferring num_heads={num_heads}, ffn_dim={ffn_dim}") + logging.warning(f"ARVideo: unknown dim={dim}, inferring num_heads={num_heads}, ffn_dim={ffn_dim}") cross_attn_norm = "blocks.0.norm3.weight" in sd @@ -101,11 +101,11 @@ class LoadCausalForcingModel(io.ComfyNode): return io.NodeOutput(patcher) -class CausalForcingSampler(io.ComfyNode): +class ARVideoSampler(io.ComfyNode): @classmethod def define_schema(cls): return io.Schema( - node_id="CausalForcingSampler", + node_id="ARVideoSampler", category="sampling", inputs=[ io.Model.Input("model"), @@ -258,14 +258,14 @@ def _lookup_sigma(sigmas, timesteps, t_val): return sigmas[idx] 
-class CausalForcingExtension(ComfyExtension): +class ARVideoExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[io.ComfyNode]]: return [ - LoadCausalForcingModel, - CausalForcingSampler, + LoadARVideoModel, + ARVideoSampler, ] -async def comfy_entrypoint() -> CausalForcingExtension: - return CausalForcingExtension() +async def comfy_entrypoint() -> ARVideoExtension: + return ARVideoExtension() diff --git a/nodes.py b/nodes.py index 66528c24d..4d674617f 100644 --- a/nodes.py +++ b/nodes.py @@ -2443,7 +2443,7 @@ async def init_builtin_extra_nodes(): "nodes_nop.py", "nodes_kandinsky5.py", "nodes_wanmove.py", - "nodes_causal_forcing.py", + "nodes_ar_video.py", "nodes_image_compare.py", "nodes_zimage.py", "nodes_glsl.py", From 6f9af338ae1fb5ee4893eebdeaf0edcaaae56ebe Mon Sep 17 00:00:00 2001 From: Talmaj Marinc Date: Fri, 20 Mar 2026 22:26:42 +0100 Subject: [PATCH 04/12] Apply ruff. --- comfy/ldm/wan/ar_model.py | 1 - 1 file changed, 1 deletion(-) diff --git a/comfy/ldm/wan/ar_model.py b/comfy/ldm/wan/ar_model.py index 268f7ac34..775d675b7 100644 --- a/comfy/ldm/wan/ar_model.py +++ b/comfy/ldm/wan/ar_model.py @@ -18,7 +18,6 @@ from comfy.ldm.flux.layers import EmbedND from comfy.ldm.flux.math import apply_rope1 from comfy.ldm.wan.model import ( sinusoidal_embedding_1d, - WanT2VCrossAttention, WAN_CROSSATTENTION_CLASSES, Head, MLPProj, From 3a9547192edb6e4a704850c5b6454aabdcfda39a Mon Sep 17 00:00:00 2001 From: Talmaj Marinc Date: Tue, 24 Mar 2026 13:23:06 +0100 Subject: [PATCH 05/12] Rewrite causal forcing using a custom sampler with the KSampler node. 
--- comfy/k_diffusion/sampling.py | 81 ++++++++++++ comfy/ldm/wan/ar_model.py | 18 ++- comfy/model_base.py | 8 ++ comfy/samplers.py | 3 +- comfy/supported_models.py | 9 ++ comfy_extras/nodes_ar_video.py | 233 ++++++++------------------------- 6 files changed, 170 insertions(+), 182 deletions(-) diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py index 6978eb717..870bff369 100644 --- a/comfy/k_diffusion/sampling.py +++ b/comfy/k_diffusion/sampling.py @@ -1810,3 +1810,84 @@ def sample_sa_solver(model, x, sigmas, extra_args=None, callback=None, disable=F def sample_sa_solver_pece(model, x, sigmas, extra_args=None, callback=None, disable=False, tau_func=None, s_noise=1.0, noise_sampler=None, predictor_order=3, corrector_order=4, simple_order_2=False): """Stochastic Adams Solver with PECE (Predict–Evaluate–Correct–Evaluate) mode (NeurIPS 2023).""" return sample_sa_solver(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, tau_func=tau_func, s_noise=s_noise, noise_sampler=noise_sampler, predictor_order=predictor_order, corrector_order=corrector_order, use_pece=True, simple_order_2=simple_order_2) + + +@torch.no_grad() +def sample_ar_video(model, x, sigmas, extra_args=None, callback=None, disable=None): + """ + Autoregressive video sampler: block-by-block denoising with KV cache + and flow-match re-noising for Causal Forcing / Self-Forcing models. 
+ """ + extra_args = {} if extra_args is None else extra_args + model_options = extra_args.get("model_options", {}) + transformer_options = model_options.get("transformer_options", {}) + ar_config = transformer_options.get("ar_config", {}) + + num_frame_per_block = ar_config.get("num_frame_per_block", 1) + seed = extra_args.get("seed", 0) + + bs, c, lat_t, lat_h, lat_w = x.shape + frame_seq_len = (lat_h // 2) * (lat_w // 2) + num_blocks = lat_t // num_frame_per_block + + inner_model = model.inner_model.inner_model + causal_model = inner_model.diffusion_model + device = x.device + model_dtype = inner_model.get_dtype() + + kv_caches = causal_model.init_kv_caches(bs, lat_t * frame_seq_len, device, model_dtype) + crossattn_caches = causal_model.init_crossattn_caches(bs, device, model_dtype) + + output = torch.zeros_like(x) + s_in = x.new_ones([x.shape[0]]) + current_start_frame = 0 + num_sigma_steps = len(sigmas) - 1 + total_real_steps = num_blocks * num_sigma_steps + step_count = 0 + + for block_idx in trange(num_blocks, disable=disable): + bf = num_frame_per_block + fs, fe = current_start_frame, current_start_frame + bf + noisy_input = x[:, :, fs:fe] + + ar_state = { + "start_frame": current_start_frame, + "kv_caches": kv_caches, + "crossattn_caches": crossattn_caches, + } + transformer_options["ar_state"] = ar_state + + for i in range(num_sigma_steps): + denoised = model(noisy_input, sigmas[i] * s_in, **extra_args) + + if callback is not None: + # Scale step_count to [0, num_sigma_steps) so the progress bar fills gradually + scaled_i = step_count * num_sigma_steps // total_real_steps + callback({"x": noisy_input, "i": scaled_i, "sigma": sigmas[i], + "sigma_hat": sigmas[i], "denoised": denoised}) + + if sigmas[i + 1] == 0: + noisy_input = denoised + else: + sigma_next = sigmas[i + 1] + torch.manual_seed(seed + block_idx * 1000 + i) + fresh_noise = torch.randn_like(denoised) + noisy_input = (1.0 - sigma_next) * denoised + sigma_next * fresh_noise + + for cache in 
kv_caches: + cache["end"].fill_(cache["end"].item() - bf * frame_seq_len) + + step_count += 1 + + output[:, :, fs:fe] = noisy_input + + # Cache update: run model at t=0 with clean output to fill KV cache + for cache in kv_caches: + cache["end"].fill_(cache["end"].item() - bf * frame_seq_len) + zero_sigma = sigmas.new_zeros([1]) + _ = model(noisy_input, zero_sigma * s_in, **extra_args) + + current_start_frame += bf + + transformer_options.pop("ar_state", None) + return output diff --git a/comfy/ldm/wan/ar_model.py b/comfy/ldm/wan/ar_model.py index 775d675b7..0fe2a585c 100644 --- a/comfy/ldm/wan/ar_model.py +++ b/comfy/ldm/wan/ar_model.py @@ -281,7 +281,7 @@ class CausalWanModel(torch.nn.Module): # Per-frame time embedding → [B, block_frames, 6, dim] e = self.time_embedding( - sinusoidal_embedding_1d(self.freq_dim, timestep.flatten())) + sinusoidal_embedding_1d(self.freq_dim, timestep.flatten()).to(dtype=x.dtype)) e = e.reshape(timestep.shape[0], -1, e.shape[-1]) e0 = self.time_projection(e).unflatten(2, (6, self.dim)) @@ -351,8 +351,20 @@ class CausalWanModel(torch.nn.Module): def head_dim(self): return self.dim // self.num_heads - # Standard forward for non-causal use (compatibility with ComfyUI infrastructure) def forward(self, x, timestep, context, clip_fea=None, time_dim_concat=None, transformer_options={}, **kwargs): + ar_state = transformer_options.get("ar_state") + if ar_state is not None: + bs = x.shape[0] + block_frames = x.shape[2] + t_per_frame = timestep.unsqueeze(1).expand(bs, block_frames) + return self.forward_block( + x=x, timestep=t_per_frame, context=context, + start_frame=ar_state["start_frame"], + kv_caches=ar_state["kv_caches"], + crossattn_caches=ar_state["crossattn_caches"], + clip_fea=clip_fea, + ) + bs, c, t, h, w = x.shape x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size) @@ -369,7 +381,7 @@ class CausalWanModel(torch.nn.Module): freqs = self.rope_encode(t_len, h, w, device=x.device, dtype=x.dtype) e = self.time_embedding( - 
sinusoidal_embedding_1d(self.freq_dim, timestep.flatten())) + sinusoidal_embedding_1d(self.freq_dim, timestep.flatten()).to(dtype=x.dtype)) e = e.reshape(timestep.shape[0], -1, e.shape[-1]) e0 = self.time_projection(e).unflatten(2, (6, self.dim)) diff --git a/comfy/model_base.py b/comfy/model_base.py index 70aff886e..0bdfefba5 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -42,6 +42,7 @@ import comfy.ldm.cosmos.predict2 import comfy.ldm.lumina.model import comfy.ldm.wan.model import comfy.ldm.wan.model_animate +import comfy.ldm.wan.ar_model import comfy.ldm.hunyuan3d.model import comfy.ldm.hidream.model import comfy.ldm.chroma.model @@ -1353,6 +1354,13 @@ class WAN21(BaseModel): return out +class WAN21_CausalAR(WAN21): + def __init__(self, model_config, model_type=ModelType.FLOW, device=None): + super(WAN21, self).__init__(model_config, model_type, device=device, + unet_model=comfy.ldm.wan.ar_model.CausalWanModel) + self.image_to_video = False + + class WAN21_Vace(WAN21): def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None): super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.VaceWanModel) diff --git a/comfy/samplers.py b/comfy/samplers.py index 0a4d062db..03a07ec68 100755 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -723,7 +723,8 @@ KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_c "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_2s_ancestral_cfg_pp", "dpmpp_sde", "dpmpp_sde_gpu", "dpmpp_2m", "dpmpp_2m_cfg_pp", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_2m_sde_heun", "dpmpp_2m_sde_heun_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm", "ipndm", "ipndm_v", "deis", "res_multistep", "res_multistep_cfg_pp", "res_multistep_ancestral", "res_multistep_ancestral_cfg_pp", - "gradient_estimation", "gradient_estimation_cfg_pp", "er_sde", "seeds_2", "seeds_3", "sa_solver", "sa_solver_pece"] + "gradient_estimation", 
"gradient_estimation_cfg_pp", "er_sde", "seeds_2", "seeds_3", "sa_solver", "sa_solver_pece", + "ar_video"] class KSAMPLER(Sampler): def __init__(self, sampler_function, extra_options={}, inpaint_options={}): diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 07feb31b3..4c5159fbe 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -1165,6 +1165,15 @@ class WAN21_T2V(supported_models_base.BASE): t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}umt5xxl.transformer.".format(pref)) return supported_models_base.ClipTarget(comfy.text_encoders.wan.WanT5Tokenizer, comfy.text_encoders.wan.te(**t5_detect)) +class WAN21_CausalAR_T2V(WAN21_T2V): + sampling_settings = { + "shift": 5.0, + } + + def get_model(self, state_dict, prefix="", device=None): + return model_base.WAN21_CausalAR(self, device=device) + + class WAN21_I2V(WAN21_T2V): unet_config = { "image_model": "wan2.1", diff --git a/comfy_extras/nodes_ar_video.py b/comfy_extras/nodes_ar_video.py index 08010a6ac..41bed9414 100644 --- a/comfy_extras/nodes_ar_video.py +++ b/comfy_extras/nodes_ar_video.py @@ -1,8 +1,7 @@ """ ComfyUI nodes for autoregressive video generation (Causal Forcing, Self-Forcing, etc.). 
- LoadARVideoModel: load original HF/training or pre-converted checkpoints - (auto-detects format and converts state dict at runtime) - - ARVideoSampler: autoregressive frame-by-frame sampling with KV cache + via the standard BaseModel + ModelPatcher pipeline """ import torch @@ -13,10 +12,9 @@ from typing_extensions import override import comfy.model_management import comfy.utils import comfy.ops -import comfy.latent_formats -from comfy.model_patcher import ModelPatcher -from comfy.ldm.wan.ar_model import CausalWanModel +import comfy.model_patcher from comfy.ldm.wan.ar_convert import extract_state_dict +from comfy.supported_models import WAN21_CausalAR_T2V from comfy_api.latest import ComfyExtension, io # ── Model size presets derived from Wan 2.1 configs ────────────────────────── @@ -36,6 +34,7 @@ class LoadARVideoModel(io.ComfyNode): category="loaders/video_models", inputs=[ io.Combo.Input("ckpt_name", options=folder_paths.get_filename_list("diffusion_models")), + io.Int.Input("num_frame_per_block", default=1, min=1, max=21), ], outputs=[ io.Model.Output(display_name="MODEL"), @@ -43,21 +42,21 @@ class LoadARVideoModel(io.ComfyNode): ) @classmethod - def execute(cls, ckpt_name) -> io.NodeOutput: + def execute(cls, ckpt_name, num_frame_per_block) -> io.NodeOutput: ckpt_path = folder_paths.get_full_path_or_raise("diffusion_models", ckpt_name) raw = comfy.utils.load_torch_file(ckpt_path) sd = extract_state_dict(raw, use_ema=True) del raw dim = sd["head.modulation"].shape[-1] - out_dim = sd["head.head.weight"].shape[0] // 4 # prod(patch_size) * out_dim + out_dim = sd["head.head.weight"].shape[0] // 4 in_dim = sd["patch_embedding.weight"].shape[1] num_layers = 0 while f"blocks.{num_layers}.self_attn.q.weight" in sd: num_layers += 1 if dim in WAN_CONFIGS: - ffn_dim, num_heads, expected_layers, text_dim = WAN_CONFIGS[dim] + ffn_dim, num_heads, _, text_dim = WAN_CONFIGS[dim] else: num_heads = dim // 128 ffn_dim = sd["blocks.0.ffn.0.weight"].shape[0] @@ -66,57 +65,60 @@ 
class LoadARVideoModel(io.ComfyNode): cross_attn_norm = "blocks.0.norm3.weight" in sd + unet_config = { + "image_model": "wan2.1", + "model_type": "t2v", + "dim": dim, + "ffn_dim": ffn_dim, + "num_heads": num_heads, + "num_layers": num_layers, + "in_dim": in_dim, + "out_dim": out_dim, + "text_dim": text_dim, + "cross_attn_norm": cross_attn_norm, + } + + model_config = WAN21_CausalAR_T2V(unet_config) + unet_dtype = comfy.model_management.unet_dtype( + model_params=comfy.utils.calculate_parameters(sd), + supported_dtypes=model_config.supported_inference_dtypes, + ) + manual_cast_dtype = comfy.model_management.unet_manual_cast( + unet_dtype, + comfy.model_management.get_torch_device(), + model_config.supported_inference_dtypes, + ) + model_config.set_inference_dtype(unet_dtype, manual_cast_dtype) + + model = model_config.get_model(sd, "") load_device = comfy.model_management.get_torch_device() offload_device = comfy.model_management.unet_offload_device() - ops = comfy.ops.disable_weight_init - model = CausalWanModel( - model_type='t2v', - patch_size=(1, 2, 2), - text_len=512, - in_dim=in_dim, - dim=dim, - ffn_dim=ffn_dim, - freq_dim=256, - text_dim=text_dim, - out_dim=out_dim, - num_heads=num_heads, - num_layers=num_layers, - window_size=(-1, -1), - qk_norm=True, - cross_attn_norm=cross_attn_norm, - eps=1e-6, - device=offload_device, - dtype=torch.bfloat16, - operations=ops, + model_patcher = comfy.model_patcher.ModelPatcher( + model, load_device=load_device, offload_device=offload_device, ) + if not comfy.model_management.is_device_cpu(offload_device): + model.to(offload_device) + model.load_model_weights(sd, "") - model.load_state_dict(sd, strict=False) - model.eval() + model_patcher.model_options.setdefault("transformer_options", {})["ar_config"] = { + "num_frame_per_block": num_frame_per_block, + } - model_size = comfy.model_management.module_size(model) - patcher = ModelPatcher(model, load_device=load_device, - offload_device=offload_device, size=model_size) - 
patcher.model.latent_format = comfy.latent_formats.Wan21() - return io.NodeOutput(patcher) + return io.NodeOutput(model_patcher) -class ARVideoSampler(io.ComfyNode): +class EmptyARVideoLatent(io.ComfyNode): @classmethod def define_schema(cls): return io.Schema( - node_id="ARVideoSampler", - category="sampling", + node_id="EmptyARVideoLatent", + category="latent/video", inputs=[ - io.Model.Input("model"), - io.Conditioning.Input("positive"), - io.Int.Input("seed", default=0, min=0, max=0xffffffffffffffff, control_after_generate=True), io.Int.Input("width", default=832, min=16, max=8192, step=16), io.Int.Input("height", default=480, min=16, max=8192, step=16), - io.Int.Input("num_frames", default=81, min=1, max=1024, step=4), - io.Int.Input("num_frame_per_block", default=1, min=1, max=21), - io.Float.Input("timestep_shift", default=5.0, min=0.1, max=20.0, step=0.1), - io.String.Input("denoising_steps", default="1000,750,500,250"), + io.Int.Input("length", default=81, min=1, max=1024, step=4), + io.Int.Input("batch_size", default=1, min=1, max=64), ], outputs=[ io.Latent.Output(display_name="LATENT"), @@ -124,138 +126,13 @@ class ARVideoSampler(io.ComfyNode): ) @classmethod - def execute(cls, model, positive, seed, width, height, - num_frames, num_frame_per_block, timestep_shift, - denoising_steps) -> io.NodeOutput: - - device = comfy.model_management.get_torch_device() - - # Parse denoising steps - step_values = [int(s.strip()) for s in denoising_steps.split(",")] - - # Build scheduler sigmas (FlowMatch with shift) - num_train_timesteps = 1000 - raw_sigmas = torch.linspace(1.0, 0.003 / 1.002, num_train_timesteps + 1)[:-1] - sigmas = timestep_shift * raw_sigmas / (1.0 + (timestep_shift - 1.0) * raw_sigmas) - timesteps = sigmas * num_train_timesteps - - # Warp denoising step indices to actual timestep values - all_timesteps = torch.cat([timesteps, torch.tensor([0.0])]) - warped_steps = all_timesteps[num_train_timesteps - torch.tensor(step_values, dtype=torch.long)] - - 
# Get the CausalWanModel from the patcher - comfy.model_management.load_model_gpu(model) - causal_model = model.model - dtype = torch.bfloat16 - - # Extract text embeddings from conditioning - cond = positive[0][0].to(device=device, dtype=dtype) - if cond.ndim == 2: - cond = cond.unsqueeze(0) - - # Latent dimensions - lat_h = height // 8 - lat_w = width // 8 - lat_t = ((num_frames - 1) // 4) + 1 # Wan VAE temporal compression - in_channels = 16 - - # Generate noise - generator = torch.Generator(device="cpu").manual_seed(seed) - noise = torch.randn(1, in_channels, lat_t, lat_h, lat_w, - generator=generator, device="cpu").to(device=device, dtype=dtype) - - assert lat_t % num_frame_per_block == 0, \ - f"Latent frames ({lat_t}) must be divisible by num_frame_per_block ({num_frame_per_block})" - num_blocks = lat_t // num_frame_per_block - - # Tokens per frame: (H/patch_h) * (W/patch_w) per temporal patch - frame_seq_len = (lat_h // 2) * (lat_w // 2) # patch_size = (1,2,2) - max_seq_len = lat_t * frame_seq_len - - # Initialize caches - kv_caches = causal_model.init_kv_caches(1, max_seq_len, device, dtype) - crossattn_caches = causal_model.init_crossattn_caches(1, device, dtype) - - output = torch.zeros_like(noise) - pbar = comfy.utils.ProgressBar(num_blocks * len(warped_steps) + num_blocks) - - current_start_frame = 0 - for block_idx in range(num_blocks): - block_frames = num_frame_per_block - frame_start = current_start_frame - frame_end = current_start_frame + block_frames - - # Noise slice for this block: [B, C, block_frames, H, W] - noisy_input = noise[:, :, frame_start:frame_end] - - # Denoising loop (e.g. 
4 steps) - for step_idx, current_timestep in enumerate(warped_steps): - t_val = current_timestep.item() - - # Per-frame timestep tensor [B, block_frames] - timestep_tensor = torch.full( - (1, block_frames), t_val, device=device, dtype=dtype) - - # Model forward - flow_pred = causal_model.forward_block( - x=noisy_input, - timestep=timestep_tensor, - context=cond, - start_frame=current_start_frame, - kv_caches=kv_caches, - crossattn_caches=crossattn_caches, - ) - - # x0 = input - sigma * flow_pred - sigma_t = _lookup_sigma(sigmas, timesteps, t_val) - denoised = noisy_input - sigma_t * flow_pred - - if step_idx < len(warped_steps) - 1: - # Add noise for next step - next_t = warped_steps[step_idx + 1].item() - sigma_next = _lookup_sigma(sigmas, timesteps, next_t) - fresh_noise = torch.randn_like(denoised) - noisy_input = (1.0 - sigma_next) * denoised + sigma_next * fresh_noise - - # Roll back KV cache end pointer so next step re-writes same positions - for cache in kv_caches: - cache["end"].fill_(cache["end"].item() - block_frames * frame_seq_len) - else: - noisy_input = denoised - - pbar.update(1) - - output[:, :, frame_start:frame_end] = noisy_input - - # Cache update: forward at t=0 with clean output to fill KV cache - with torch.no_grad(): - # Reset cache end to before this block so the t=0 pass writes clean K/V - for cache in kv_caches: - cache["end"].fill_(cache["end"].item() - block_frames * frame_seq_len) - - t_zero = torch.zeros(1, block_frames, device=device, dtype=dtype) - causal_model.forward_block( - x=noisy_input, - timestep=t_zero, - context=cond, - start_frame=current_start_frame, - kv_caches=kv_caches, - crossattn_caches=crossattn_caches, - ) - - pbar.update(1) - current_start_frame += block_frames - - # Denormalize latents because VAEDecode expects raw latents. 
- latent_format = comfy.latent_formats.Wan21() - output_denorm = latent_format.process_out(output.float().cpu()) - return io.NodeOutput({"samples": output_denorm}) - - -def _lookup_sigma(sigmas, timesteps, t_val): - """Find the sigma corresponding to a timestep value.""" - idx = torch.argmin((timesteps - t_val).abs()).item() - return sigmas[idx] + def execute(cls, width, height, length, batch_size) -> io.NodeOutput: + lat_t = ((length - 1) // 4) + 1 + latent = torch.zeros( + [batch_size, 16, lat_t, height // 8, width // 8], + device=comfy.model_management.intermediate_device(), + ) + return io.NodeOutput({"samples": latent}) class ARVideoExtension(ComfyExtension): @@ -263,7 +140,7 @@ class ARVideoExtension(ComfyExtension): async def get_node_list(self) -> list[type[io.ComfyNode]]: return [ LoadARVideoModel, - ARVideoSampler, + EmptyARVideoLatent, ] From e649a3bc72b5a9b5b2cc9673fc0dab18f384d79c Mon Sep 17 00:00:00 2001 From: Talmaj Marinc Date: Wed, 25 Mar 2026 17:53:59 +0100 Subject: [PATCH 06/12] Refactor CausalWanModel to inherit from WanModel. --- comfy/ldm/wan/ar_model.py | 173 +++++--------------------------------- 1 file changed, 23 insertions(+), 150 deletions(-) diff --git a/comfy/ldm/wan/ar_model.py b/comfy/ldm/wan/ar_model.py index 0fe2a585c..54a2ef704 100644 --- a/comfy/ldm/wan/ar_model.py +++ b/comfy/ldm/wan/ar_model.py @@ -9,19 +9,16 @@ block at a time and maintains a KV cache across blocks. 
Reference: https://github.com/thu-ml/Causal-Forcing """ -import math import torch import torch.nn as nn from comfy.ldm.modules.attention import optimized_attention -from comfy.ldm.flux.layers import EmbedND from comfy.ldm.flux.math import apply_rope1 from comfy.ldm.wan.model import ( sinusoidal_embedding_1d, - WAN_CROSSATTENTION_CLASSES, - Head, - MLPProj, repeat_e, + WanModel, + WanAttentionBlock, ) import comfy.ldm.common_dit import comfy.model_management @@ -87,33 +84,18 @@ class CausalWanSelfAttention(nn.Module): return x -class CausalWanAttentionBlock(nn.Module): +class CausalWanAttentionBlock(WanAttentionBlock): """Transformer block with KV-cached self-attention and cross-attention caching.""" def __init__(self, cross_attn_type, dim, ffn_dim, num_heads, window_size=(-1, -1), qk_norm=True, cross_attn_norm=False, eps=1e-6, operation_settings={}): - super().__init__() - self.dim = dim - self.ffn_dim = ffn_dim - self.num_heads = num_heads - - ops = operation_settings.get("operations") - device = operation_settings.get("device") - dtype = operation_settings.get("dtype") - - self.norm1 = ops.LayerNorm(dim, eps, elementwise_affine=False, device=device, dtype=dtype) - self.self_attn = CausalWanSelfAttention(dim, num_heads, window_size, qk_norm, eps, operation_settings=operation_settings) - self.norm3 = ops.LayerNorm(dim, eps, elementwise_affine=True, device=device, dtype=dtype) if cross_attn_norm else nn.Identity() - self.cross_attn = WAN_CROSSATTENTION_CLASSES[cross_attn_type]( - dim, num_heads, (-1, -1), qk_norm, eps, operation_settings=operation_settings) - self.norm2 = ops.LayerNorm(dim, eps, elementwise_affine=False, device=device, dtype=dtype) - self.ffn = nn.Sequential( - ops.Linear(dim, ffn_dim, device=device, dtype=dtype), - nn.GELU(approximate='tanh'), - ops.Linear(ffn_dim, dim, device=device, dtype=dtype)) - - self.modulation = nn.Parameter(torch.empty(1, 6, dim, device=device, dtype=dtype)) + super().__init__(cross_attn_type, dim, ffn_dim, num_heads, + 
window_size, qk_norm, cross_attn_norm, eps, + operation_settings=operation_settings) + self.self_attn = CausalWanSelfAttention( + dim, num_heads, window_size, qk_norm, eps, + operation_settings=operation_settings) def forward(self, x, e, freqs, context, context_img_len=257, kv_cache=None, crossattn_cache=None, transformer_options={}): @@ -150,7 +132,7 @@ class CausalWanAttentionBlock(nn.Module): return x -class CausalWanModel(torch.nn.Module): +class CausalWanModel(WanModel): """ Wan 2.1 diffusion backbone with causal KV-cache support. @@ -178,82 +160,14 @@ class CausalWanModel(torch.nn.Module): device=None, dtype=None, operations=None): - super().__init__() - self.dtype = dtype - operation_settings = {"operations": operations, "device": device, "dtype": dtype} - - self.model_type = model_type - self.patch_size = patch_size - self.text_len = text_len - self.in_dim = in_dim - self.dim = dim - self.ffn_dim = ffn_dim - self.freq_dim = freq_dim - self.text_dim = text_dim - self.out_dim = out_dim - self.num_heads = num_heads - self.num_layers = num_layers - self.window_size = window_size - self.qk_norm = qk_norm - self.cross_attn_norm = cross_attn_norm - self.eps = eps - - self.patch_embedding = operations.Conv3d( - in_dim, dim, kernel_size=patch_size, stride=patch_size, - device=device, dtype=dtype) - self.text_embedding = nn.Sequential( - operations.Linear(text_dim, dim, device=device, dtype=dtype), - nn.GELU(approximate='tanh'), - operations.Linear(dim, dim, device=device, dtype=dtype)) - self.time_embedding = nn.Sequential( - operations.Linear(freq_dim, dim, device=device, dtype=dtype), - nn.SiLU(), - operations.Linear(dim, dim, device=device, dtype=dtype)) - self.time_projection = nn.Sequential( - nn.SiLU(), - operations.Linear(dim, dim * 6, device=device, dtype=dtype)) - - cross_attn_type = 't2v_cross_attn' if model_type == 't2v' else 'i2v_cross_attn' - self.blocks = nn.ModuleList([ - CausalWanAttentionBlock( - cross_attn_type, dim, ffn_dim, num_heads, - 
window_size, qk_norm, cross_attn_norm, eps, - operation_settings=operation_settings) - for _ in range(num_layers) - ]) - - self.head = Head(dim, out_dim, patch_size, eps, operation_settings=operation_settings) - - d = dim // num_heads - self.rope_embedder = EmbedND( - dim=d, theta=10000.0, - axes_dim=[d - 4 * (d // 6), 2 * (d // 6), 2 * (d // 6)]) - - if model_type == 'i2v': - self.img_emb = MLPProj(1280, dim, operation_settings=operation_settings) - else: - self.img_emb = None - - self.ref_conv = None - - def rope_encode(self, t, h, w, t_start=0, device=None, dtype=None): - patch_size = self.patch_size - t_len = ((t + (patch_size[0] // 2)) // patch_size[0]) - h_len = ((h + (patch_size[1] // 2)) // patch_size[1]) - w_len = ((w + (patch_size[2] // 2)) // patch_size[2]) - - img_ids = torch.zeros((t_len, h_len, w_len, 3), device=device, dtype=dtype) - img_ids[:, :, :, 0] += torch.linspace( - t_start, t_start + (t_len - 1), steps=t_len, device=device, dtype=dtype - ).reshape(-1, 1, 1) - img_ids[:, :, :, 1] += torch.linspace( - 0, h_len - 1, steps=h_len, device=device, dtype=dtype - ).reshape(1, -1, 1) - img_ids[:, :, :, 2] += torch.linspace( - 0, w_len - 1, steps=w_len, device=device, dtype=dtype - ).reshape(1, 1, -1) - img_ids = img_ids.reshape(1, -1, img_ids.shape[-1]) - return self.rope_embedder(img_ids).movedim(1, 2) + super().__init__( + model_type=model_type, patch_size=patch_size, text_len=text_len, + in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, + text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, + num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, + cross_attn_norm=cross_attn_norm, eps=eps, image_model=image_model, + wan_attn_block_class=CausalWanAttentionBlock, + device=device, dtype=dtype, operations=operations) def forward_block(self, x, timestep, context, start_frame, kv_caches, crossattn_caches, clip_fea=None): @@ -275,11 +189,11 @@ class CausalWanModel(torch.nn.Module): x = comfy.ldm.common_dit.pad_to_patch_size(x, 
self.patch_size) bs, c, t, h, w = x.shape - x = self.patch_embedding(x) + x = self.patch_embedding(x.float()).to(x.dtype) grid_sizes = x.shape[2:] x = x.flatten(2).transpose(1, 2) - # Per-frame time embedding → [B, block_frames, 6, dim] + # Per-frame time embedding e = self.time_embedding( sinusoidal_embedding_1d(self.freq_dim, timestep.flatten()).to(dtype=x.dtype)) e = e.reshape(timestep.shape[0], -1, e.shape[-1]) @@ -311,14 +225,6 @@ class CausalWanModel(torch.nn.Module): x = self.unpatchify(x, grid_sizes) return x[:, :, :t, :h, :w] - def unpatchify(self, x, grid_sizes): - c = self.out_dim - b = x.shape[0] - u = x[:, :math.prod(grid_sizes)].view(b, *grid_sizes, *self.patch_size, c) - u = torch.einsum('bfhwpqrc->bcfphqwr', u) - u = u.reshape(b, c, *[i * j for i, j in zip(grid_sizes, self.patch_size)]) - return u - def init_kv_caches(self, batch_size, max_seq_len, device, dtype): """Create fresh KV caches for all layers.""" caches = [] @@ -365,39 +271,6 @@ class CausalWanModel(torch.nn.Module): clip_fea=clip_fea, ) - bs, c, t, h, w = x.shape - x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size) - - t_len = t - if time_dim_concat is not None: - time_dim_concat = comfy.ldm.common_dit.pad_to_patch_size(time_dim_concat, self.patch_size) - x = torch.cat([x, time_dim_concat], dim=2) - t_len = x.shape[2] - - x = self.patch_embedding(x) - grid_sizes = x.shape[2:] - x = x.flatten(2).transpose(1, 2) - - freqs = self.rope_encode(t_len, h, w, device=x.device, dtype=x.dtype) - - e = self.time_embedding( - sinusoidal_embedding_1d(self.freq_dim, timestep.flatten()).to(dtype=x.dtype)) - e = e.reshape(timestep.shape[0], -1, e.shape[-1]) - e0 = self.time_projection(e).unflatten(2, (6, self.dim)) - - context = self.text_embedding(context) - - context_img_len = None - if clip_fea is not None and self.img_emb is not None: - context_clip = self.img_emb(clip_fea) - context = torch.concat([context_clip, context], dim=1) - context_img_len = clip_fea.shape[-2] - - for block in 
self.blocks: - x = block(x, e=e0, freqs=freqs, context=context, - context_img_len=context_img_len, - transformer_options=transformer_options) - - x = self.head(x, e) - x = self.unpatchify(x, grid_sizes) - return x[:, :, :t, :h, :w] + return super().forward(x, timestep, context, clip_fea=clip_fea, + time_dim_concat=time_dim_concat, + transformer_options=transformer_options, **kwargs) From 4b2734889ce80a8317bde90364f40e02f1690388 Mon Sep 17 00:00:00 2001 From: Talmaj Marinc Date: Wed, 25 Mar 2026 20:39:37 +0100 Subject: [PATCH 07/12] Remove dedicated ARLoader. --- comfy/supported_models.py | 12 +++- comfy_extras/nodes_ar_video.py | 101 +-------------------------------- 2 files changed, 12 insertions(+), 101 deletions(-) diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 4c5159fbe..aa66e035f 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -1166,10 +1166,20 @@ class WAN21_T2V(supported_models_base.BASE): return supported_models_base.ClipTarget(comfy.text_encoders.wan.WanT5Tokenizer, comfy.text_encoders.wan.te(**t5_detect)) class WAN21_CausalAR_T2V(WAN21_T2V): + unet_config = { + "image_model": "wan2.1", + "model_type": "t2v", + "causal_ar": True, + } + sampling_settings = { "shift": 5.0, } + def __init__(self, unet_config): + super().__init__(unet_config) + self.unet_config.pop("causal_ar", None) + def get_model(self, state_dict, prefix="", device=None): return model_base.WAN21_CausalAR(self, device=device) @@ -1743,6 +1753,6 @@ class LongCatImage(supported_models_base.BASE): hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref)) return supported_models_base.ClipTarget(comfy.text_encoders.longcat_image.LongCatImageTokenizer, comfy.text_encoders.longcat_image.te(**hunyuan_detect)) -models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, 
SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, LongCatImage, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImagePixelSpace, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, WAN21_SCAIL, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima] +models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, LongCatImage, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImagePixelSpace, ZImage, Lumina2, WAN22_T2V, WAN21_CausalAR_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, WAN21_SCAIL, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima] models += [SVD_img2vid] diff --git a/comfy_extras/nodes_ar_video.py b/comfy_extras/nodes_ar_video.py index 41bed9414..be9f2eaec 100644 --- a/comfy_extras/nodes_ar_video.py +++ b/comfy_extras/nodes_ar_video.py @@ -1,112 +1,14 @@ """ ComfyUI nodes for 
autoregressive video generation (Causal Forcing, Self-Forcing, etc.). - - LoadARVideoModel: load original HF/training or pre-converted checkpoints - via the standard BaseModel + ModelPatcher pipeline + - EmptyARVideoLatent: create 5D [B, C, T, H, W] video latent tensors """ import torch -import logging -import folder_paths from typing_extensions import override import comfy.model_management -import comfy.utils -import comfy.ops -import comfy.model_patcher -from comfy.ldm.wan.ar_convert import extract_state_dict -from comfy.supported_models import WAN21_CausalAR_T2V from comfy_api.latest import ComfyExtension, io -# ── Model size presets derived from Wan 2.1 configs ────────────────────────── -WAN_CONFIGS = { - # dim → (ffn_dim, num_heads, num_layers, text_dim) - 1536: (8960, 12, 30, 4096), # 1.3B - 2048: (8192, 16, 32, 4096), # ~2B - 5120: (13824, 40, 40, 4096), # 14B -} - - -class LoadARVideoModel(io.ComfyNode): - @classmethod - def define_schema(cls): - return io.Schema( - node_id="LoadARVideoModel", - category="loaders/video_models", - inputs=[ - io.Combo.Input("ckpt_name", options=folder_paths.get_filename_list("diffusion_models")), - io.Int.Input("num_frame_per_block", default=1, min=1, max=21), - ], - outputs=[ - io.Model.Output(display_name="MODEL"), - ], - ) - - @classmethod - def execute(cls, ckpt_name, num_frame_per_block) -> io.NodeOutput: - ckpt_path = folder_paths.get_full_path_or_raise("diffusion_models", ckpt_name) - raw = comfy.utils.load_torch_file(ckpt_path) - sd = extract_state_dict(raw, use_ema=True) - del raw - - dim = sd["head.modulation"].shape[-1] - out_dim = sd["head.head.weight"].shape[0] // 4 - in_dim = sd["patch_embedding.weight"].shape[1] - num_layers = 0 - while f"blocks.{num_layers}.self_attn.q.weight" in sd: - num_layers += 1 - - if dim in WAN_CONFIGS: - ffn_dim, num_heads, _, text_dim = WAN_CONFIGS[dim] - else: - num_heads = dim // 128 - ffn_dim = sd["blocks.0.ffn.0.weight"].shape[0] - text_dim = 4096 - logging.warning(f"ARVideo: 
unknown dim={dim}, inferring num_heads={num_heads}, ffn_dim={ffn_dim}") - - cross_attn_norm = "blocks.0.norm3.weight" in sd - - unet_config = { - "image_model": "wan2.1", - "model_type": "t2v", - "dim": dim, - "ffn_dim": ffn_dim, - "num_heads": num_heads, - "num_layers": num_layers, - "in_dim": in_dim, - "out_dim": out_dim, - "text_dim": text_dim, - "cross_attn_norm": cross_attn_norm, - } - - model_config = WAN21_CausalAR_T2V(unet_config) - unet_dtype = comfy.model_management.unet_dtype( - model_params=comfy.utils.calculate_parameters(sd), - supported_dtypes=model_config.supported_inference_dtypes, - ) - manual_cast_dtype = comfy.model_management.unet_manual_cast( - unet_dtype, - comfy.model_management.get_torch_device(), - model_config.supported_inference_dtypes, - ) - model_config.set_inference_dtype(unet_dtype, manual_cast_dtype) - - model = model_config.get_model(sd, "") - load_device = comfy.model_management.get_torch_device() - offload_device = comfy.model_management.unet_offload_device() - - model_patcher = comfy.model_patcher.ModelPatcher( - model, load_device=load_device, offload_device=offload_device, - ) - if not comfy.model_management.is_device_cpu(offload_device): - model.to(offload_device) - model.load_model_weights(sd, "") - - model_patcher.model_options.setdefault("transformer_options", {})["ar_config"] = { - "num_frame_per_block": num_frame_per_block, - } - - return io.NodeOutput(model_patcher) - class EmptyARVideoLatent(io.ComfyNode): @classmethod @@ -139,7 +41,6 @@ class ARVideoExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[io.ComfyNode]]: return [ - LoadARVideoModel, EmptyARVideoLatent, ] From de66e64ec23638d2b0ddf0ea93e5aa3ff211799c Mon Sep 17 00:00:00 2001 From: Talmaj Marinc Date: Wed, 25 Mar 2026 21:13:23 +0100 Subject: [PATCH 08/12] Fix 'Process the tail block instead of truncating it', fix 'Don't mutate the patcher's shared transformer_options in place'. 
--- comfy/k_diffusion/sampling.py | 72 +++++++++++++++++------------------ 1 file changed, 36 insertions(+), 36 deletions(-) diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py index 870bff369..646a6ae93 100644 --- a/comfy/k_diffusion/sampling.py +++ b/comfy/k_diffusion/sampling.py @@ -1828,7 +1828,7 @@ def sample_ar_video(model, x, sigmas, extra_args=None, callback=None, disable=No bs, c, lat_t, lat_h, lat_w = x.shape frame_seq_len = (lat_h // 2) * (lat_w // 2) - num_blocks = lat_t // num_frame_per_block + num_blocks = -(-lat_t // num_frame_per_block) # ceiling division inner_model = model.inner_model.inner_model causal_model = inner_model.diffusion_model @@ -1845,49 +1845,49 @@ def sample_ar_video(model, x, sigmas, extra_args=None, callback=None, disable=No total_real_steps = num_blocks * num_sigma_steps step_count = 0 - for block_idx in trange(num_blocks, disable=disable): - bf = num_frame_per_block - fs, fe = current_start_frame, current_start_frame + bf - noisy_input = x[:, :, fs:fe] + try: + for block_idx in trange(num_blocks, disable=disable): + bf = min(num_frame_per_block, lat_t - current_start_frame) + fs, fe = current_start_frame, current_start_frame + bf + noisy_input = x[:, :, fs:fe] - ar_state = { - "start_frame": current_start_frame, - "kv_caches": kv_caches, - "crossattn_caches": crossattn_caches, - } - transformer_options["ar_state"] = ar_state + ar_state = { + "start_frame": current_start_frame, + "kv_caches": kv_caches, + "crossattn_caches": crossattn_caches, + } + transformer_options["ar_state"] = ar_state - for i in range(num_sigma_steps): - denoised = model(noisy_input, sigmas[i] * s_in, **extra_args) + for i in range(num_sigma_steps): + denoised = model(noisy_input, sigmas[i] * s_in, **extra_args) - if callback is not None: - # Scale step_count to [0, num_sigma_steps) so the progress bar fills gradually - scaled_i = step_count * num_sigma_steps // total_real_steps - callback({"x": noisy_input, "i": scaled_i, "sigma": 
sigmas[i], - "sigma_hat": sigmas[i], "denoised": denoised}) + if callback is not None: + scaled_i = step_count * num_sigma_steps // total_real_steps + callback({"x": noisy_input, "i": scaled_i, "sigma": sigmas[i], + "sigma_hat": sigmas[i], "denoised": denoised}) - if sigmas[i + 1] == 0: - noisy_input = denoised - else: - sigma_next = sigmas[i + 1] - torch.manual_seed(seed + block_idx * 1000 + i) - fresh_noise = torch.randn_like(denoised) - noisy_input = (1.0 - sigma_next) * denoised + sigma_next * fresh_noise + if sigmas[i + 1] == 0: + noisy_input = denoised + else: + sigma_next = sigmas[i + 1] + torch.manual_seed(seed + block_idx * 1000 + i) + fresh_noise = torch.randn_like(denoised) + noisy_input = (1.0 - sigma_next) * denoised + sigma_next * fresh_noise - for cache in kv_caches: - cache["end"].fill_(cache["end"].item() - bf * frame_seq_len) + for cache in kv_caches: + cache["end"].fill_(cache["end"].item() - bf * frame_seq_len) - step_count += 1 + step_count += 1 - output[:, :, fs:fe] = noisy_input + output[:, :, fs:fe] = noisy_input - # Cache update: run model at t=0 with clean output to fill KV cache - for cache in kv_caches: - cache["end"].fill_(cache["end"].item() - bf * frame_seq_len) - zero_sigma = sigmas.new_zeros([1]) - _ = model(noisy_input, zero_sigma * s_in, **extra_args) + for cache in kv_caches: + cache["end"].fill_(cache["end"].item() - bf * frame_seq_len) + zero_sigma = sigmas.new_zeros([1]) + _ = model(noisy_input, zero_sigma * s_in, **extra_args) - current_start_frame += bf + current_start_frame += bf + finally: + transformer_options.pop("ar_state", None) - transformer_options.pop("ar_state", None) return output From 3440c57f67bee67a24db4917e02d9b03e74c22c1 Mon Sep 17 00:00:00 2001 From: Talmaj Marinc Date: Wed, 25 Mar 2026 21:44:17 +0100 Subject: [PATCH 09/12] Remove ar_convert, now present in hg repackaged model repo. 
--- comfy/ldm/wan/ar_convert.py | 70 ------------------------------------- 1 file changed, 70 deletions(-) delete mode 100644 comfy/ldm/wan/ar_convert.py diff --git a/comfy/ldm/wan/ar_convert.py b/comfy/ldm/wan/ar_convert.py deleted file mode 100644 index bec5c1014..000000000 --- a/comfy/ldm/wan/ar_convert.py +++ /dev/null @@ -1,70 +0,0 @@ -""" -State dict conversion for Causal Forcing checkpoints. - -Handles three checkpoint layouts: - 1. Training checkpoint with top-level generator_ema / generator keys - 2. Already-flattened state dict with model.* prefixes - 3. Already-converted ComfyUI state dict (bare model keys) - -Strips prefixes so the result matches the standard Wan 2.1 / CausalWanModel key layout -(e.g. blocks.0.self_attn.q.weight, head.modulation, etc.) -""" - -import logging - -log = logging.getLogger(__name__) - -PREFIXES_TO_STRIP = ["model._fsdp_wrapped_module.", "model."] - -_MODEL_KEY_PREFIXES = ( - "blocks.", "head.", "patch_embedding.", "text_embedding.", - "time_embedding.", "time_projection.", "img_emb.", "rope_embedder.", -) - - -def extract_state_dict(state_dict: dict, use_ema: bool = True) -> dict: - """ - Extract and clean a Causal Forcing state dict from a training checkpoint. - - Returns a state dict with keys matching the CausalWanModel / WanModel layout. 
- """ - # Case 3: already converted -- keys are bare model keys - if "head.modulation" in state_dict and "blocks.0.self_attn.q.weight" in state_dict: - return state_dict - - # Case 1: training checkpoint with wrapper key - raw_sd = None - order = ["generator_ema", "generator"] if use_ema else ["generator", "generator_ema"] - for wrapper_key in order: - if wrapper_key in state_dict: - raw_sd = state_dict[wrapper_key] - log.info("Causal Forcing: extracted '%s' with %d keys", wrapper_key, len(raw_sd)) - break - - # Case 2: flat dict with model.* prefixes - if raw_sd is None: - if any(k.startswith("model.") for k in state_dict): - raw_sd = state_dict - else: - raise KeyError( - f"Cannot detect Causal Forcing checkpoint layout. " - f"Top-level keys: {list(state_dict.keys())[:20]}" - ) - - out_sd = {} - for k, v in raw_sd.items(): - new_k = k - for prefix in PREFIXES_TO_STRIP: - if new_k.startswith(prefix): - new_k = new_k[len(prefix):] - break - else: - if not new_k.startswith(_MODEL_KEY_PREFIXES): - log.debug("Causal Forcing: skipping non-model key: %s", k) - continue - out_sd[new_k] = v - - if "head.modulation" not in out_sd: - raise ValueError("Conversion failed: 'head.modulation' not found in output keys") - - return out_sd From 08bf8f4d958dedbadbc768030b80c13862c25848 Mon Sep 17 00:00:00 2001 From: Talmaj Marinc Date: Wed, 25 Mar 2026 21:59:22 +0100 Subject: [PATCH 10/12] Move KV cache end counter to Python int to avoid per-step host synchronization in AR sampling loops. 
--- comfy/k_diffusion/sampling.py | 4 ++-- comfy/ldm/wan/ar_model.py | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py index 646a6ae93..5bab263bd 100644 --- a/comfy/k_diffusion/sampling.py +++ b/comfy/k_diffusion/sampling.py @@ -1875,14 +1875,14 @@ def sample_ar_video(model, x, sigmas, extra_args=None, callback=None, disable=No noisy_input = (1.0 - sigma_next) * denoised + sigma_next * fresh_noise for cache in kv_caches: - cache["end"].fill_(cache["end"].item() - bf * frame_seq_len) + cache["end"] -= bf * frame_seq_len step_count += 1 output[:, :, fs:fe] = noisy_input for cache in kv_caches: - cache["end"].fill_(cache["end"].item() - bf * frame_seq_len) + cache["end"] -= bf * frame_seq_len zero_sigma = sigmas.new_zeros([1]) _ = model(noisy_input, zero_sigma * s_in, **extra_args) diff --git a/comfy/ldm/wan/ar_model.py b/comfy/ldm/wan/ar_model.py index 54a2ef704..d72f53602 100644 --- a/comfy/ldm/wan/ar_model.py +++ b/comfy/ldm/wan/ar_model.py @@ -64,13 +64,13 @@ class CausalWanSelfAttention(nn.Module): transformer_options=transformer_options, ) else: - end = kv_cache["end"].item() + end = kv_cache["end"] new_end = end + s # Roped K and plain V go into cache kv_cache["k"][:, end:new_end] = k kv_cache["v"][:, end:new_end] = v - kv_cache["end"].fill_(new_end) + kv_cache["end"] = new_end x = optimized_attention( q.view(b, s, n * d), @@ -232,7 +232,7 @@ class CausalWanModel(WanModel): caches.append({ "k": torch.zeros(batch_size, max_seq_len, self.num_heads, self.head_dim, device=device, dtype=dtype), "v": torch.zeros(batch_size, max_seq_len, self.num_heads, self.head_dim, device=device, dtype=dtype), - "end": torch.tensor([0], dtype=torch.long, device=device), + "end": 0, }) return caches @@ -246,7 +246,7 @@ class CausalWanModel(WanModel): def reset_kv_caches(self, kv_caches): """Reset KV caches to empty (reuse allocated memory).""" for cache in kv_caches: - cache["end"].fill_(0) + 
cache["end"] = 0 def reset_crossattn_caches(self, crossattn_caches): """Reset cross-attention caches.""" From e9cf4659d2071afb47a16712d0df7681da293b07 Mon Sep 17 00:00:00 2001 From: Talmaj Marinc Date: Wed, 25 Mar 2026 22:05:12 +0100 Subject: [PATCH 11/12] Base frame_seq_len on the padded token grid. --- comfy/k_diffusion/sampling.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py index 5bab263bd..33c8552e3 100644 --- a/comfy/k_diffusion/sampling.py +++ b/comfy/k_diffusion/sampling.py @@ -1827,8 +1827,8 @@ def sample_ar_video(model, x, sigmas, extra_args=None, callback=None, disable=No seed = extra_args.get("seed", 0) bs, c, lat_t, lat_h, lat_w = x.shape - frame_seq_len = (lat_h // 2) * (lat_w // 2) - num_blocks = -(-lat_t // num_frame_per_block) # ceiling division + frame_seq_len = -(-lat_h // 2) * -(-lat_w // 2) # ceiling division + num_blocks = -(-lat_t // num_frame_per_block) # ceiling division inner_model = model.inner_model.inner_model causal_model = inner_model.diffusion_model From 28416847000abf97b520f8305d38512846dc1c2f Mon Sep 17 00:00:00 2001 From: Talmaj Marinc Date: Wed, 25 Mar 2026 22:15:44 +0100 Subject: [PATCH 12/12] Add better error handling for a custom ar_video sampler. --- comfy/k_diffusion/sampling.py | 22 +++++++++++++++++++--- comfy/samplers.py | 3 +++ 2 files changed, 22 insertions(+), 3 deletions(-) diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py index 33c8552e3..b1a8f80ab 100644 --- a/comfy/k_diffusion/sampling.py +++ b/comfy/k_diffusion/sampling.py @@ -1817,21 +1817,37 @@ def sample_ar_video(model, x, sigmas, extra_args=None, callback=None, disable=No """ Autoregressive video sampler: block-by-block denoising with KV cache and flow-match re-noising for Causal Forcing / Self-Forcing models. + + Requires a Causal-WAN compatible model (diffusion_model must expose + init_kv_caches / init_crossattn_caches) and 5-D latents [B,C,T,H,W]. 
""" extra_args = {} if extra_args is None else extra_args model_options = extra_args.get("model_options", {}) transformer_options = model_options.get("transformer_options", {}) ar_config = transformer_options.get("ar_config", {}) + if x.ndim != 5: + raise ValueError( + f"ar_video sampler requires 5-D video latents [B,C,T,H,W], got {x.ndim}-D tensor with shape {x.shape}. " + "This sampler is only compatible with autoregressive video models (e.g. Causal-WAN)." + ) + + inner_model = model.inner_model.inner_model + causal_model = inner_model.diffusion_model + + if not (hasattr(causal_model, "init_kv_caches") and hasattr(causal_model, "init_crossattn_caches")): + raise TypeError( + "ar_video sampler requires a Causal-WAN compatible model whose diffusion_model " + "exposes init_kv_caches() and init_crossattn_caches(). The loaded checkpoint " + "does not support this interface — choose a different sampler." + ) + num_frame_per_block = ar_config.get("num_frame_per_block", 1) seed = extra_args.get("seed", 0) bs, c, lat_t, lat_h, lat_w = x.shape frame_seq_len = -(-lat_h // 2) * -(-lat_w // 2) # ceiling division num_blocks = -(-lat_t // num_frame_per_block) # ceiling division - - inner_model = model.inner_model.inner_model - causal_model = inner_model.diffusion_model device = x.device model_dtype = inner_model.get_dtype() diff --git a/comfy/samplers.py b/comfy/samplers.py index 03a07ec68..6ee50181c 100755 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -719,6 +719,9 @@ class Sampler: sigma = float(sigmas[0]) return math.isclose(max_sigma, sigma, rel_tol=1e-05) or sigma > max_sigma +# "ar_video" is model-specific (requires Causal-WAN KV-cache interface + 5-D latents) +# but is kept here so it appears in standard sampler dropdowns; sample_ar_video +# validates at runtime and raises a clear error for incompatible checkpoints. 
KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_cfg_pp", "heun", "heunpp2", "exp_heun_2_x0", "exp_heun_2_x0_sde", "dpm_2", "dpm_2_ancestral", "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_2s_ancestral_cfg_pp", "dpmpp_sde", "dpmpp_sde_gpu", "dpmpp_2m", "dpmpp_2m_cfg_pp", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_2m_sde_heun", "dpmpp_2m_sde_heun_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm",