diff --git a/comfy/ldm/wan/causal_model.py b/comfy/ldm/wan/causal_model.py
new file mode 100644
index 000000000..268f7ac34
--- /dev/null
+++ b/comfy/ldm/wan/causal_model.py
@@ -0,0 +1,421 @@
+"""
+CausalWanModel: Wan 2.1 backbone with KV-cached causal self-attention for
+autoregressive (frame-by-frame) video generation via Causal Forcing.
+
+Weight-compatible with the standard WanModel -- same layer names, same shapes.
+The difference is purely in the forward pass: this model processes one temporal
+block at a time and maintains a KV cache across blocks.
+
+Reference: https://github.com/thu-ml/Causal-Forcing
+"""
+
+import math
+import torch
+import torch.nn as nn
+
+from comfy.ldm.modules.attention import optimized_attention
+from comfy.ldm.flux.layers import EmbedND
+from comfy.ldm.flux.math import apply_rope1
+from comfy.ldm.wan.model import (
+    sinusoidal_embedding_1d,
+    WanT2VCrossAttention,
+    WAN_CROSSATTENTION_CLASSES,
+    Head,
+    MLPProj,
+    repeat_e,
+)
+import comfy.ldm.common_dit
+import comfy.model_management
+
+
+class CausalWanSelfAttention(nn.Module):
+    """Self-attention with KV cache support for autoregressive inference."""
+
+    def __init__(self, dim, num_heads, window_size=(-1, -1), qk_norm=True,
+                 eps=1e-6, operation_settings={}):
+        assert dim % num_heads == 0
+        super().__init__()
+        self.dim = dim
+        self.num_heads = num_heads
+        self.head_dim = dim // num_heads
+        self.qk_norm = qk_norm
+        self.eps = eps
+
+        ops = operation_settings.get("operations")
+        device = operation_settings.get("device")
+        dtype = operation_settings.get("dtype")
+
+        self.q = ops.Linear(dim, dim, device=device, dtype=dtype)
+        self.k = ops.Linear(dim, dim, device=device, dtype=dtype)
+        self.v = ops.Linear(dim, dim, device=device, dtype=dtype)
+        self.o = ops.Linear(dim, dim, device=device, dtype=dtype)
+        self.norm_q = ops.RMSNorm(dim, eps=eps, elementwise_affine=True, device=device, dtype=dtype) if qk_norm else nn.Identity()
+        self.norm_k = ops.RMSNorm(dim, eps=eps, elementwise_affine=True, device=device, dtype=dtype) if qk_norm else nn.Identity()
+
+    def forward(self, x, freqs, kv_cache=None, transformer_options={}):
+        """Attend over the current tokens and, if given, previously cached ones.
+
+        Args:
+            x: token tensor; its first two dims are read as (batch, sequence).
+            freqs: rotary embedding applied to q and k via ``apply_rope1``.
+            kv_cache: optional dict with ``"k"``/``"v"`` buffers and a one-element
+                ``"end"`` counter. New keys/values are written at ``end`` and the
+                counter is advanced in place. ``None`` attends over ``x`` only.
+            transformer_options: forwarded to ``optimized_attention``.
+
+        Raises:
+            ValueError: if the cache cannot hold the new tokens.
+        """
+        b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
+
+        q = apply_rope1(self.norm_q(self.q(x)).view(b, s, n, d), freqs)
+        k = apply_rope1(self.norm_k(self.k(x)).view(b, s, n, d), freqs)
+        v = self.v(x).view(b, s, n, d)
+
+        if kv_cache is None:
+            x = optimized_attention(
+                q.view(b, s, n * d),
+                k.view(b, s, n * d),
+                v.view(b, s, n * d),
+                heads=self.num_heads,
+                transformer_options=transformer_options,
+            )
+        else:
+            end = kv_cache["end"].item()
+            new_end = end + s
+            capacity = kv_cache["k"].shape[1]
+            if new_end > capacity:
+                # Fail with a clear message instead of a shape-mismatch error
+                # from the slice assignment below.
+                raise ValueError(
+                    f"KV cache overflow: need {new_end} positions, cache holds {capacity}")
+
+            # Roped K and plain V go into cache
+            kv_cache["k"][:, end:new_end] = k
+            kv_cache["v"][:, end:new_end] = v
+            kv_cache["end"].fill_(new_end)
+
+            x = optimized_attention(
+                q.view(b, s, n * d),
+                kv_cache["k"][:, :new_end].view(b, new_end, n * d),
+                kv_cache["v"][:, :new_end].view(b, new_end, n * d),
+                heads=self.num_heads,
+                transformer_options=transformer_options,
+            )
+
+        x = self.o(x)
+        return x
+
+
+class CausalWanAttentionBlock(nn.Module):
+    """Transformer block with KV-cached self-attention and cross-attention caching."""
+
+    def __init__(self, cross_attn_type, dim, ffn_dim, num_heads,
+                 window_size=(-1, -1), qk_norm=True, cross_attn_norm=False,
+                 eps=1e-6, operation_settings={}):
+        super().__init__()
+        self.dim = dim
+        self.ffn_dim = ffn_dim
+        self.num_heads = num_heads
+
+        ops = operation_settings.get("operations")
+        device = operation_settings.get("device")
+        dtype = operation_settings.get("dtype")
+
+        self.norm1 = ops.LayerNorm(dim, eps, elementwise_affine=False, device=device, dtype=dtype)
+        self.self_attn = CausalWanSelfAttention(dim, num_heads, window_size, qk_norm, eps, operation_settings=operation_settings)
+        self.norm3 = ops.LayerNorm(dim, eps, elementwise_affine=True, device=device, dtype=dtype) if cross_attn_norm else nn.Identity()
+        self.cross_attn = WAN_CROSSATTENTION_CLASSES[cross_attn_type](
+            dim, num_heads, (-1, -1), qk_norm, eps, operation_settings=operation_settings)
+        self.norm2 = ops.LayerNorm(dim, eps, elementwise_affine=False, device=device, dtype=dtype)
+        self.ffn = nn.Sequential(
+            ops.Linear(dim, ffn_dim, device=device, dtype=dtype),
+            nn.GELU(approximate='tanh'),
+            ops.Linear(ffn_dim, dim, device=device, dtype=dtype))
+
+        self.modulation = nn.Parameter(torch.empty(1, 6, dim, device=device, dtype=dtype))
+
+    def forward(self, x, e, freqs, context, context_img_len=257,
+                kv_cache=None, crossattn_cache=None, transformer_options={}):
+        """Run self-attention, cross-attention and FFN with modulation from ``e``.
+
+        ``crossattn_cache`` (optional dict) is filled with the projected context
+        keys/values on the first call and reused on later calls.
+        """
+        if e.ndim < 4:
+            e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device) + e).chunk(6, dim=1)
+        else:
+            e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device).unsqueeze(0) + e).unbind(2)
+
+        # Self-attention with optional KV cache
+        x = x.contiguous()
+        y = self.self_attn(
+            torch.addcmul(repeat_e(e[0], x), self.norm1(x), 1 + repeat_e(e[1], x)),
+            freqs, kv_cache=kv_cache, transformer_options=transformer_options)
+        x = torch.addcmul(x, y, repeat_e(e[2], x))
+        del y
+
+        # Cross-attention with optional caching
+        # NOTE(review): the cached branch attends over a single K/V pair built
+        # from the whole context; confirm this matches the i2v cross-attention
+        # class before enabling the cache for model_type='i2v'.
+        if crossattn_cache is not None and crossattn_cache.get("is_init"):
+            q = self.cross_attn.norm_q(self.cross_attn.q(self.norm3(x)))
+            x_ca = optimized_attention(
+                q, crossattn_cache["k"], crossattn_cache["v"],
+                heads=self.num_heads, transformer_options=transformer_options)
+            x = x + self.cross_attn.o(x_ca)
+        else:
+            x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len, transformer_options=transformer_options)
+            if crossattn_cache is not None:
+                crossattn_cache["k"] = self.cross_attn.norm_k(self.cross_attn.k(context))
+                crossattn_cache["v"] = self.cross_attn.v(context)
+                crossattn_cache["is_init"] = True
+
+        # FFN
+        y = self.ffn(torch.addcmul(repeat_e(e[3], x), self.norm2(x), 1 + repeat_e(e[4], x)))
+        x = torch.addcmul(x, y, repeat_e(e[5], x))
+        return x
+
+
+class CausalWanModel(torch.nn.Module):
+    """
+    Wan 2.1 diffusion backbone with causal KV-cache support.
+
+    Same weight structure as WanModel -- loads identical state dicts.
+    Adds forward_block() for frame-by-frame autoregressive inference.
+    """
+
+    def __init__(self,
+                 model_type='t2v',
+                 patch_size=(1, 2, 2),
+                 text_len=512,
+                 in_dim=16,
+                 dim=2048,
+                 ffn_dim=8192,
+                 freq_dim=256,
+                 text_dim=4096,
+                 out_dim=16,
+                 num_heads=16,
+                 num_layers=32,
+                 window_size=(-1, -1),
+                 qk_norm=True,
+                 cross_attn_norm=True,
+                 eps=1e-6,
+                 image_model=None,
+                 device=None,
+                 dtype=None,
+                 operations=None):
+        super().__init__()
+        self.dtype = dtype
+        operation_settings = {"operations": operations, "device": device, "dtype": dtype}
+
+        self.model_type = model_type
+        self.patch_size = patch_size
+        self.text_len = text_len
+        self.in_dim = in_dim
+        self.dim = dim
+        self.ffn_dim = ffn_dim
+        self.freq_dim = freq_dim
+        self.text_dim = text_dim
+        self.out_dim = out_dim
+        self.num_heads = num_heads
+        self.num_layers = num_layers
+        self.window_size = window_size
+        self.qk_norm = qk_norm
+        self.cross_attn_norm = cross_attn_norm
+        self.eps = eps
+
+        self.patch_embedding = operations.Conv3d(
+            in_dim, dim, kernel_size=patch_size, stride=patch_size,
+            device=device, dtype=dtype)
+        self.text_embedding = nn.Sequential(
+            operations.Linear(text_dim, dim, device=device, dtype=dtype),
+            nn.GELU(approximate='tanh'),
+            operations.Linear(dim, dim, device=device, dtype=dtype))
+        self.time_embedding = nn.Sequential(
+            operations.Linear(freq_dim, dim, device=device, dtype=dtype),
+            nn.SiLU(),
+            operations.Linear(dim, dim, device=device, dtype=dtype))
+        self.time_projection = nn.Sequential(
+            nn.SiLU(),
+            operations.Linear(dim, dim * 6, device=device, dtype=dtype))
+
+        cross_attn_type = 't2v_cross_attn' if model_type == 't2v' else 'i2v_cross_attn'
+        self.blocks = nn.ModuleList([
+            CausalWanAttentionBlock(
+                cross_attn_type, dim, ffn_dim, num_heads,
+                window_size, qk_norm, cross_attn_norm, eps,
+                operation_settings=operation_settings)
+            for _ in range(num_layers)
+        ])
+
+        self.head = Head(dim, out_dim, patch_size, eps, operation_settings=operation_settings)
+
+        d = dim // num_heads
+        self.rope_embedder = EmbedND(
+            dim=d, theta=10000.0,
+            axes_dim=[d - 4 * (d // 6), 2 * (d // 6), 2 * (d // 6)])
+
+        if model_type == 'i2v':
+            self.img_emb = MLPProj(1280, dim, operation_settings=operation_settings)
+        else:
+            self.img_emb = None
+
+        self.ref_conv = None
+
+    def rope_encode(self, t, h, w, t_start=0, device=None, dtype=None):
+        """Build rotary embeddings for a (t, h, w) token grid whose temporal ids start at ``t_start``."""
+        patch_size = self.patch_size
+        t_len = ((t + (patch_size[0] // 2)) // patch_size[0])
+        h_len = ((h + (patch_size[1] // 2)) // patch_size[1])
+        w_len = ((w + (patch_size[2] // 2)) // patch_size[2])
+
+        img_ids = torch.zeros((t_len, h_len, w_len, 3), device=device, dtype=dtype)
+        img_ids[:, :, :, 0] += torch.linspace(
+            t_start, t_start + (t_len - 1), steps=t_len, device=device, dtype=dtype
+        ).reshape(-1, 1, 1)
+        img_ids[:, :, :, 1] += torch.linspace(
+            0, h_len - 1, steps=h_len, device=device, dtype=dtype
+        ).reshape(1, -1, 1)
+        img_ids[:, :, :, 2] += torch.linspace(
+            0, w_len - 1, steps=w_len, device=device, dtype=dtype
+        ).reshape(1, 1, -1)
+        img_ids = img_ids.reshape(1, -1, img_ids.shape[-1])
+        return self.rope_embedder(img_ids).movedim(1, 2)
+
+    def forward_block(self, x, timestep, context, start_frame,
+                      kv_caches, crossattn_caches, clip_fea=None):
+        """
+        Forward one temporal block for autoregressive inference.
+
+        Args:
+            x: [B, C, block_frames, H, W] input latent for the current block
+            timestep: [B, block_frames] per-frame timesteps
+            context: [B, L, text_dim] raw text embeddings (pre-text_embedding)
+            start_frame: temporal frame index for RoPE offset
+            kv_caches: list of per-layer KV cache dicts
+            crossattn_caches: list of per-layer cross-attention cache dicts
+            clip_fea: optional CLIP features for I2V
+
+        Returns:
+            flow_pred: [B, C_out, block_frames, H, W] flow prediction
+        """
+        # Remember the unpadded size so the final crop strips patch padding.
+        t, h, w = x.shape[2:]
+        x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)
+        t_pad, h_pad, w_pad = x.shape[2:]
+
+        x = self.patch_embedding(x)
+        grid_sizes = x.shape[2:]
+        x = x.flatten(2).transpose(1, 2)
+
+        # Per-frame time embedding → [B, block_frames, 6, dim]
+        # Cast the sinusoidal features to the activation dtype before the MLP.
+        e = self.time_embedding(
+            sinusoidal_embedding_1d(self.freq_dim, timestep.flatten()).to(dtype=x.dtype))
+        e = e.reshape(timestep.shape[0], -1, e.shape[-1])
+        e0 = self.time_projection(e).unflatten(2, (6, self.dim))
+
+        # Text embedding (reuses crossattn_cache after first block)
+        context = self.text_embedding(context)
+
+        context_img_len = None
+        if clip_fea is not None and self.img_emb is not None:
+            context_clip = self.img_emb(clip_fea)
+            context = torch.concat([context_clip, context], dim=1)
+            context_img_len = clip_fea.shape[-2]
+
+        # RoPE for current block's temporal position
+        freqs = self.rope_encode(t_pad, h_pad, w_pad, t_start=start_frame, device=x.device, dtype=x.dtype)
+
+        # Transformer blocks
+        for i, block in enumerate(self.blocks):
+            x = block(x, e=e0, freqs=freqs, context=context,
+                      context_img_len=context_img_len,
+                      kv_cache=kv_caches[i],
+                      crossattn_cache=crossattn_caches[i])
+
+        # Head
+        x = self.head(x, e)
+
+        # Unpatchify
+        x = self.unpatchify(x, grid_sizes)
+        return x[:, :, :t, :h, :w]
+
+    def unpatchify(self, x, grid_sizes):
+        """Rearrange patch tokens into a [B, C_out, T, H, W] tensor."""
+        c = self.out_dim
+        b = x.shape[0]
+        u = x[:, :math.prod(grid_sizes)].view(b, *grid_sizes, *self.patch_size, c)
+        u = torch.einsum('bfhwpqrc->bcfphqwr', u)
+        u = u.reshape(b, c, *[i * j for i, j in zip(grid_sizes, self.patch_size)])
+        return u
+
+    def init_kv_caches(self, batch_size, max_seq_len, device, dtype):
+        """Create fresh KV caches for all layers."""
+        caches = []
+        for _ in range(self.num_layers):
+            caches.append({
+                "k": torch.zeros(batch_size, max_seq_len, self.num_heads, self.head_dim, device=device, dtype=dtype),
+                "v": torch.zeros(batch_size, max_seq_len, self.num_heads, self.head_dim, device=device, dtype=dtype),
+                "end": torch.tensor([0], dtype=torch.long, device=device),
+            })
+        return caches
+
+    def init_crossattn_caches(self, batch_size, device, dtype):
+        """Create fresh cross-attention caches for all layers."""
+        return [{"is_init": False} for _ in range(self.num_layers)]
+
+    def reset_kv_caches(self, kv_caches):
+        """Reset KV caches to empty (reuse allocated memory)."""
+        for cache in kv_caches:
+            cache["end"].fill_(0)
+
+    def reset_crossattn_caches(self, crossattn_caches):
+        """Reset cross-attention caches."""
+        for cache in crossattn_caches:
+            cache["is_init"] = False
+
+    @property
+    def head_dim(self):
+        return self.dim // self.num_heads
+
+    # Standard forward for non-causal use (compatibility with ComfyUI infrastructure)
+    def forward(self, x, timestep, context, clip_fea=None, time_dim_concat=None, transformer_options={}, **kwargs):
+        bs, c, t, h, w = x.shape
+        x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)
+
+        t_len = t
+        if time_dim_concat is not None:
+            time_dim_concat = comfy.ldm.common_dit.pad_to_patch_size(time_dim_concat, self.patch_size)
+            x = torch.cat([x, time_dim_concat], dim=2)
+            t_len = x.shape[2]
+
+        x = self.patch_embedding(x)
+        grid_sizes = x.shape[2:]
+        x = x.flatten(2).transpose(1, 2)
+
+        freqs = self.rope_encode(t_len, h, w, device=x.device, dtype=x.dtype)
+
+        e = self.time_embedding(
+            sinusoidal_embedding_1d(self.freq_dim, timestep.flatten()).to(dtype=x.dtype))
+        e = e.reshape(timestep.shape[0], -1, e.shape[-1])
+        e0 = self.time_projection(e).unflatten(2, (6, self.dim))
+
+        context = self.text_embedding(context)
+
+        context_img_len = None
+        if clip_fea is not None and self.img_emb is not None:
+            context_clip = self.img_emb(clip_fea)
+            context = torch.concat([context_clip, context], dim=1)
+            context_img_len = clip_fea.shape[-2]
+
+        for block in self.blocks:
+            x = block(x, e=e0, freqs=freqs, context=context,
+                      context_img_len=context_img_len,
+                      transformer_options=transformer_options)
+
+        x = self.head(x, e)
+        x = self.unpatchify(x, grid_sizes)
+        return x[:, :, :t, :h, :w]
diff --git a/comfy_extras/nodes_causal_forcing.py b/comfy_extras/nodes_causal_forcing.py
new file mode 100644
index 000000000..3646cec37
--- /dev/null
+++ b/comfy_extras/nodes_causal_forcing.py
@@ -0,0 +1,314 @@
+"""
+ComfyUI nodes for Causal Forcing autoregressive video generation.
+  - LoadCausalForcingModel: load original HF/training or pre-converted checkpoints
+    (auto-detects format and converts state dict at runtime)
+  - CausalForcingSampler: autoregressive frame-by-frame sampling with KV cache
+"""
+
+import torch
+import logging
+import folder_paths
+from typing_extensions import override
+
+import comfy.model_management
+import comfy.utils
+import comfy.ops
+import comfy.latent_formats
+from comfy.model_patcher import ModelPatcher
+from comfy.ldm.wan.causal_model import CausalWanModel
+from comfy.ldm.wan.causal_convert import extract_state_dict
+from comfy_api.latest import ComfyExtension, io
+
+# ── Model size presets derived from Wan 2.1 configs ──────────────────────────
+WAN_CONFIGS = {
+    # dim → (ffn_dim, num_heads, num_layers, text_dim)
+    1536: (8960, 12, 30, 4096),   # 1.3B
+    2048: (8192, 16, 32, 4096),   # ~2B
+    5120: (13824, 40, 40, 4096),  # 14B
+}
+
+
+class LoadCausalForcingModel(io.ComfyNode):
+    """Loader node that builds a CausalWanModel from a diffusion_models checkpoint."""
+
+    @classmethod
+    def define_schema(cls):
+        return io.Schema(
+            node_id="LoadCausalForcingModel",
+            category="loaders/video_models",
+            inputs=[
+                io.Combo.Input("ckpt_name", options=folder_paths.get_filename_list("diffusion_models")),
+            ],
+            outputs=[
+                io.Model.Output(display_name="MODEL"),
+            ],
+        )
+
+    @classmethod
+    def execute(cls, ckpt_name) -> io.NodeOutput:
+        """Load ``ckpt_name``, infer the architecture from its tensors, and wrap it in a ModelPatcher."""
+        ckpt_path = folder_paths.get_full_path_or_raise("diffusion_models", ckpt_name)
+        raw = comfy.utils.load_torch_file(ckpt_path)
+        sd = extract_state_dict(raw, use_ema=True)
+        del raw
+
+        dim = sd["head.modulation"].shape[-1]
+        out_dim = sd["head.head.weight"].shape[0] // 4  # prod(patch_size) * out_dim
+        in_dim = sd["patch_embedding.weight"].shape[1]
+        num_layers = 0
+        while f"blocks.{num_layers}.self_attn.q.weight" in sd:
+            num_layers += 1
+        if num_layers == 0:
+            raise ValueError(f"CausalForcing: no transformer blocks found in {ckpt_name!r}")
+
+        if dim in WAN_CONFIGS:
+            ffn_dim, num_heads, expected_layers, text_dim = WAN_CONFIGS[dim]
+            if num_layers != expected_layers:
+                logging.warning(
+                    f"CausalForcing: dim={dim} preset expects {expected_layers} layers, "
+                    f"checkpoint has {num_layers}")
+        else:
+            num_heads = dim // 128
+            ffn_dim = sd["blocks.0.ffn.0.weight"].shape[0]
+            text_dim = 4096
+            logging.warning(f"CausalForcing: unknown dim={dim}, inferring num_heads={num_heads}, ffn_dim={ffn_dim}")
+
+        cross_attn_norm = "blocks.0.norm3.weight" in sd
+
+        load_device = comfy.model_management.get_torch_device()
+        offload_device = comfy.model_management.unet_offload_device()
+        ops = comfy.ops.disable_weight_init
+
+        model = CausalWanModel(
+            model_type='t2v',
+            patch_size=(1, 2, 2),
+            text_len=512,
+            in_dim=in_dim,
+            dim=dim,
+            ffn_dim=ffn_dim,
+            freq_dim=256,
+            text_dim=text_dim,
+            out_dim=out_dim,
+            num_heads=num_heads,
+            num_layers=num_layers,
+            window_size=(-1, -1),
+            qk_norm=True,
+            cross_attn_norm=cross_attn_norm,
+            eps=1e-6,
+            device=offload_device,
+            dtype=torch.bfloat16,
+            operations=ops,
+        )
+
+        # strict=False tolerates key differences; report them instead of hiding them.
+        missing, unexpected = model.load_state_dict(sd, strict=False)
+        if missing:
+            logging.warning(f"CausalForcing: missing keys: {missing}")
+        if unexpected:
+            logging.warning(f"CausalForcing: unexpected keys: {unexpected}")
+        model.eval()
+
+        model_size = comfy.model_management.module_size(model)
+        patcher = ModelPatcher(model, load_device=load_device,
+                               offload_device=offload_device, size=model_size)
+        patcher.model.latent_format = comfy.latent_formats.Wan21()
+        return io.NodeOutput(patcher)
+
+
+class CausalForcingSampler(io.ComfyNode):
+    """Sampler node that denoises a video latent block by block with a KV cache."""
+
+    @classmethod
+    def define_schema(cls):
+        return io.Schema(
+            node_id="CausalForcingSampler",
+            category="sampling",
+            inputs=[
+                io.Model.Input("model"),
+                io.Conditioning.Input("positive"),
+                io.Int.Input("seed", default=0, min=0, max=0xffffffffffffffff, control_after_generate=True),
+                io.Int.Input("width", default=832, min=16, max=8192, step=16),
+                io.Int.Input("height", default=480, min=16, max=8192, step=16),
+                io.Int.Input("num_frames", default=81, min=1, max=1024, step=4),
+                io.Int.Input("num_frame_per_block", default=1, min=1, max=21),
+                io.Float.Input("timestep_shift", default=5.0, min=0.1, max=20.0, step=0.1),
+                io.String.Input("denoising_steps", default="1000,750,500,250"),
+            ],
+            outputs=[
+                io.Latent.Output(display_name="LATENT"),
+            ],
+        )
+
+    @classmethod
+    def execute(cls, model, positive, seed, width, height,
+                num_frames, num_frame_per_block, timestep_shift,
+                denoising_steps) -> io.NodeOutput:
+        """Sample a latent video autoregressively, one temporal block at a time.
+
+        Raises:
+            ValueError: if ``denoising_steps`` is malformed or the latent frame
+                count is not divisible by ``num_frame_per_block``.
+        """
+        device = comfy.model_management.get_torch_device()
+
+        # Parse denoising steps
+        num_train_timesteps = 1000
+        step_values = _parse_denoising_steps(denoising_steps, num_train_timesteps)
+
+        # Latent dimensions
+        lat_h = height // 8
+        lat_w = width // 8
+        lat_t = ((num_frames - 1) // 4) + 1  # Wan VAE temporal compression
+        in_channels = 16
+
+        # Validate with an exception: asserts are stripped under ``python -O``.
+        if lat_t % num_frame_per_block != 0:
+            raise ValueError(
+                f"Latent frames ({lat_t}) must be divisible by num_frame_per_block ({num_frame_per_block})")
+        num_blocks = lat_t // num_frame_per_block
+
+        # Build scheduler sigmas (FlowMatch with shift)
+        raw_sigmas = torch.linspace(1.0, 0.003 / 1.002, num_train_timesteps + 1)[:-1]
+        sigmas = timestep_shift * raw_sigmas / (1.0 + (timestep_shift - 1.0) * raw_sigmas)
+        timesteps = sigmas * num_train_timesteps
+
+        # Warp denoising step indices to actual timestep values
+        all_timesteps = torch.cat([timesteps, torch.tensor([0.0])])
+        warped_steps = all_timesteps[num_train_timesteps - torch.tensor(step_values, dtype=torch.long)]
+
+        # Get the CausalWanModel from the patcher
+        comfy.model_management.load_model_gpu(model)
+        causal_model = model.model
+        dtype = torch.bfloat16
+
+        # Extract text embeddings from conditioning
+        cond = positive[0][0].to(device=device, dtype=dtype)
+        if cond.ndim == 2:
+            cond = cond.unsqueeze(0)
+
+        # Generate noise
+        generator = torch.Generator(device="cpu").manual_seed(seed)
+        noise = torch.randn(1, in_channels, lat_t, lat_h, lat_w,
+                            generator=generator, device="cpu").to(device=device, dtype=dtype)
+
+        # Tokens per frame after spatial patching (rounded up; assumes the model pads up to the patch size)
+        _, patch_h, patch_w = causal_model.patch_size
+        frame_seq_len = ((lat_h + patch_h - 1) // patch_h) * ((lat_w + patch_w - 1) // patch_w)
+        max_seq_len = lat_t * frame_seq_len
+
+        # Initialize caches
+        kv_caches = causal_model.init_kv_caches(1, max_seq_len, device, dtype)
+        crossattn_caches = causal_model.init_crossattn_caches(1, device, dtype)
+
+        output = torch.zeros_like(noise)
+        pbar = comfy.utils.ProgressBar(num_blocks * len(warped_steps) + num_blocks)
+        last_step_idx = len(warped_steps) - 1
+
+        with torch.no_grad():
+            for block_idx in range(num_blocks):
+                block_frames = num_frame_per_block
+                frame_start = block_idx * block_frames
+                frame_end = frame_start + block_frames
+                # Cache position at which this block's tokens begin.
+                block_token_start = frame_start * frame_seq_len
+
+                # Noise slice for this block: [B, C, block_frames, H, W]
+                noisy_input = noise[:, :, frame_start:frame_end]
+
+                # Denoising loop (e.g. 4 steps)
+                for step_idx, current_timestep in enumerate(warped_steps):
+                    t_val = current_timestep.item()
+
+                    # Per-frame timestep tensor [B, block_frames]; float32 because
+                    # bfloat16 would round timesteps in this range to multiples of 4.
+                    timestep_tensor = torch.full(
+                        (1, block_frames), t_val, device=device, dtype=torch.float32)
+
+                    # Model forward
+                    flow_pred = causal_model.forward_block(
+                        x=noisy_input,
+                        timestep=timestep_tensor,
+                        context=cond,
+                        start_frame=frame_start,
+                        kv_caches=kv_caches,
+                        crossattn_caches=crossattn_caches,
+                    )
+
+                    # x0 = input - sigma * flow_pred
+                    sigma_t = _lookup_sigma(sigmas, timesteps, t_val)
+                    denoised = noisy_input - sigma_t * flow_pred
+
+                    if step_idx < last_step_idx:
+                        # Add noise for next step, drawn from the seeded generator
+                        # so a given seed reproduces the whole trajectory.
+                        next_t = warped_steps[step_idx + 1].item()
+                        sigma_next = _lookup_sigma(sigmas, timesteps, next_t)
+                        fresh_noise = torch.randn(
+                            denoised.shape, generator=generator, device="cpu",
+                        ).to(device=denoised.device, dtype=denoised.dtype)
+                        noisy_input = (1.0 - sigma_next) * denoised + sigma_next * fresh_noise
+                    else:
+                        noisy_input = denoised
+
+                    # Rewind the cache so the next pass over this block (next denoising
+                    # step, or the clean t=0 pass below) rewrites the same positions.
+                    for cache in kv_caches:
+                        cache["end"].fill_(block_token_start)
+
+                    pbar.update(1)
+
+                output[:, :, frame_start:frame_end] = noisy_input
+
+                # Cache update: forward at t=0 with clean output to fill KV cache
+                t_zero = torch.zeros(1, block_frames, device=device, dtype=torch.float32)
+                causal_model.forward_block(
+                    x=noisy_input,
+                    timestep=t_zero,
+                    context=cond,
+                    start_frame=frame_start,
+                    kv_caches=kv_caches,
+                    crossattn_caches=crossattn_caches,
+                )
+
+                pbar.update(1)
+
+        # Apply latent format scaling
+        # NOTE(review): built-in samplers are believed to return latents through
+        # process_out; confirm process_in is the intended direction here.
+        latent_format = comfy.latent_formats.Wan21()
+        output_scaled = latent_format.process_in(output.float().cpu())
+
+        return io.NodeOutput({"samples": output_scaled})
+
+
+def _parse_denoising_steps(text, max_step):
+    """Parse comma-separated integer steps and check each lies in [0, max_step]."""
+    try:
+        steps = [int(part.strip()) for part in text.split(",") if part.strip()]
+    except ValueError as err:
+        raise ValueError(f"denoising_steps must be comma-separated integers, got {text!r}") from err
+    if not steps:
+        raise ValueError("denoising_steps must contain at least one step")
+    out_of_range = [step for step in steps if not 0 <= step <= max_step]
+    if out_of_range:
+        raise ValueError(f"denoising_steps must be within [0, {max_step}], got {out_of_range}")
+    return steps
+
+
+def _lookup_sigma(sigmas, timesteps, t_val):
+    """Find the sigma corresponding to a timestep value."""
+    idx = torch.argmin((timesteps - t_val).abs()).item()
+    return sigmas[idx]
+
+
+class CausalForcingExtension(ComfyExtension):
+    @override
+    async def get_node_list(self) -> list[type[io.ComfyNode]]:
+        return [
+            LoadCausalForcingModel,
+            CausalForcingSampler,
+        ]
+
+
+async def comfy_entrypoint() -> CausalForcingExtension:
+    return CausalForcingExtension()
diff --git a/nodes.py b/nodes.py
index 37ceac2fc..66528c24d 100644
--- a/nodes.py
+++ b/nodes.py
@@ -2443,6 +2443,7 @@ async def init_builtin_extra_nodes():
         "nodes_nop.py",
         "nodes_kandinsky5.py",
         "nodes_wanmove.py",
+        "nodes_causal_forcing.py",
         "nodes_image_compare.py",
         "nodes_zimage.py",
         "nodes_glsl.py",