Rename causal forcing to use the more general autoregressive naming convention.

This commit is contained in:
Talmaj Marinc 2026-03-20 21:05:23 +01:00
parent 0836390c27
commit 2f30a821c5
4 changed files with 86 additions and 16 deletions

View File

@ -0,0 +1,70 @@
"""
State dict conversion for Causal Forcing checkpoints.
Handles three checkpoint layouts:
1. Training checkpoint with top-level generator_ema / generator keys
2. Already-flattened state dict with model.* prefixes
3. Already-converted ComfyUI state dict (bare model keys)
Strips prefixes so the result matches the standard Wan 2.1 / CausalWanModel key layout
(e.g. blocks.0.self_attn.q.weight, head.modulation, etc.)
"""
import logging

log = logging.getLogger(__name__)

# Wrapper prefixes produced by training frameworks; longest first so the
# FSDP-wrapped form is stripped before the plain "model." form.
PREFIXES_TO_STRIP = ["model._fsdp_wrapped_module.", "model."]

# Key prefixes that identify genuine model weights after stripping.
_MODEL_KEY_PREFIXES = (
    "blocks.", "head.", "patch_embedding.", "text_embedding.",
    "time_embedding.", "time_projection.", "img_emb.", "rope_embedder.",
)


def extract_state_dict(state_dict: dict, use_ema: bool = True) -> dict:
    """
    Extract and clean a Causal Forcing state dict from a training checkpoint.
    Returns a state dict with keys matching the CausalWanModel / WanModel layout.
    """
    # Case 3: the dict is already converted -- bare model keys are present,
    # so it can be returned untouched.
    if "head.modulation" in state_dict and "blocks.0.self_attn.q.weight" in state_dict:
        return state_dict

    # Case 1: a training checkpoint nests the weights under a wrapper key.
    # Prefer the EMA weights unless the caller asked for the raw generator.
    preference = ("generator_ema", "generator") if use_ema else ("generator", "generator_ema")
    raw_sd = None
    for candidate in preference:
        if candidate in state_dict:
            raw_sd = state_dict[candidate]
            log.info("Causal Forcing: extracted '%s' with %d keys", candidate, len(raw_sd))
            break

    # Case 2: a flat dict whose keys carry "model." prefixes.
    if raw_sd is None:
        if not any(key.startswith("model.") for key in state_dict):
            raise KeyError(
                f"Cannot detect Causal Forcing checkpoint layout. "
                f"Top-level keys: {list(state_dict.keys())[:20]}"
            )
        raw_sd = state_dict

    # Strip wrapper prefixes; keep unprefixed keys only when they look like
    # model weights (anything else -- optimizer state etc. -- is dropped).
    converted = {}
    for key, tensor in raw_sd.items():
        stripped = key
        matched = False
        for prefix in PREFIXES_TO_STRIP:
            if stripped.startswith(prefix):
                stripped = stripped[len(prefix):]
                matched = True
                break
        if not matched and not stripped.startswith(_MODEL_KEY_PREFIXES):
            log.debug("Causal Forcing: skipping non-model key: %s", key)
            continue
        converted[stripped] = tensor

    # Sanity check: a valid conversion always produces the head modulation key.
    if "head.modulation" not in converted:
        raise ValueError("Conversion failed: 'head.modulation' not found in output keys")
    return converted

View File

@ -1,8 +1,8 @@
"""
ComfyUI nodes for Causal Forcing autoregressive video generation.
- LoadCausalForcingModel: load original HF/training or pre-converted checkpoints
ComfyUI nodes for autoregressive video generation (Causal Forcing, Self-Forcing, etc.).
- LoadARVideoModel: load original HF/training or pre-converted checkpoints
(auto-detects format and converts state dict at runtime)
- CausalForcingSampler: autoregressive frame-by-frame sampling with KV cache
- ARVideoSampler: autoregressive frame-by-frame sampling with KV cache
"""
import torch
@ -15,8 +15,8 @@ import comfy.utils
import comfy.ops
import comfy.latent_formats
from comfy.model_patcher import ModelPatcher
from comfy.ldm.wan.causal_model import CausalWanModel
from comfy.ldm.wan.causal_convert import extract_state_dict
from comfy.ldm.wan.ar_model import CausalWanModel
from comfy.ldm.wan.ar_convert import extract_state_dict
from comfy_api.latest import ComfyExtension, io
# ── Model size presets derived from Wan 2.1 configs ──────────────────────────
@ -28,11 +28,11 @@ WAN_CONFIGS = {
}
class LoadCausalForcingModel(io.ComfyNode):
class LoadARVideoModel(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="LoadCausalForcingModel",
node_id="LoadARVideoModel",
category="loaders/video_models",
inputs=[
io.Combo.Input("ckpt_name", options=folder_paths.get_filename_list("diffusion_models")),
@ -62,7 +62,7 @@ class LoadCausalForcingModel(io.ComfyNode):
num_heads = dim // 128
ffn_dim = sd["blocks.0.ffn.0.weight"].shape[0]
text_dim = 4096
logging.warning(f"CausalForcing: unknown dim={dim}, inferring num_heads={num_heads}, ffn_dim={ffn_dim}")
logging.warning(f"ARVideo: unknown dim={dim}, inferring num_heads={num_heads}, ffn_dim={ffn_dim}")
cross_attn_norm = "blocks.0.norm3.weight" in sd
@ -101,11 +101,11 @@ class LoadCausalForcingModel(io.ComfyNode):
return io.NodeOutput(patcher)
class CausalForcingSampler(io.ComfyNode):
class ARVideoSampler(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="CausalForcingSampler",
node_id="ARVideoSampler",
category="sampling",
inputs=[
io.Model.Input("model"),
@ -258,14 +258,14 @@ def _lookup_sigma(sigmas, timesteps, t_val):
return sigmas[idx]
class CausalForcingExtension(ComfyExtension):
class ARVideoExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
LoadCausalForcingModel,
CausalForcingSampler,
LoadARVideoModel,
ARVideoSampler,
]
async def comfy_entrypoint() -> CausalForcingExtension:
return CausalForcingExtension()
async def comfy_entrypoint() -> ARVideoExtension:
return ARVideoExtension()

View File

@ -2443,7 +2443,7 @@ async def init_builtin_extra_nodes():
"nodes_nop.py",
"nodes_kandinsky5.py",
"nodes_wanmove.py",
"nodes_causal_forcing.py",
"nodes_ar_video.py",
"nodes_image_compare.py",
"nodes_zimage.py",
"nodes_glsl.py",