diff --git a/comfy/ldm/wan/ar_convert.py b/comfy/ldm/wan/ar_convert.py
new file mode 100644
index 000000000..bec5c1014
--- /dev/null
+++ b/comfy/ldm/wan/ar_convert.py
@@ -0,0 +1,70 @@
+"""
+State dict conversion for Causal Forcing checkpoints.
+
+Handles three checkpoint layouts:
+    1. Training checkpoint with top-level generator_ema / generator keys
+    2. Already-flattened state dict with model.* prefixes
+    3. Already-converted ComfyUI state dict (bare model keys)
+
+Strips prefixes so the result matches the standard Wan 2.1 / CausalWanModel key layout
+(e.g. blocks.0.self_attn.q.weight, head.modulation, etc.)
+"""
+
+import logging
+
+log = logging.getLogger(__name__)
+
+PREFIXES_TO_STRIP = ["model._fsdp_wrapped_module.", "model."]
+
+_MODEL_KEY_PREFIXES = (
+    "blocks.", "head.", "patch_embedding.", "text_embedding.",
+    "time_embedding.", "time_projection.", "img_emb.", "rope_embedder.",
+)
+
+
+def extract_state_dict(state_dict: dict, use_ema: bool = True) -> dict:
+    """
+    Extract and clean a Causal Forcing state dict from a training checkpoint.
+
+    Returns a state dict with keys matching the CausalWanModel / WanModel layout.
+    """
+    # Case 3: already converted -- keys are bare model keys
+    if "head.modulation" in state_dict and "blocks.0.self_attn.q.weight" in state_dict:
+        return state_dict
+
+    # Case 1: training checkpoint with wrapper key
+    raw_sd = None
+    order = ["generator_ema", "generator"] if use_ema else ["generator", "generator_ema"]
+    for wrapper_key in order:
+        if wrapper_key in state_dict:
+            raw_sd = state_dict[wrapper_key]
+            log.info("Causal Forcing: extracted '%s' with %d keys", wrapper_key, len(raw_sd))
+            break
+
+    # Case 2: flat dict with model.* prefixes
+    if raw_sd is None:
+        if any(k.startswith("model.") for k in state_dict):
+            raw_sd = state_dict
+        else:
+            raise KeyError(
+                f"Cannot detect Causal Forcing checkpoint layout. "
+                f"Top-level keys: {list(state_dict.keys())[:20]}"
+            )
+
+    out_sd = {}
+    for k, v in raw_sd.items():
+        new_k = k
+        for prefix in PREFIXES_TO_STRIP:
+            if new_k.startswith(prefix):
+                new_k = new_k[len(prefix):]
+                break
+        else:
+            if not new_k.startswith(_MODEL_KEY_PREFIXES):
+                log.debug("Causal Forcing: skipping non-model key: %s", k)
+                continue
+        out_sd[new_k] = v
+
+    if "head.modulation" not in out_sd:
+        raise ValueError("Conversion failed: 'head.modulation' not found in output keys")
+
+    return out_sd
diff --git a/comfy/ldm/wan/causal_model.py b/comfy/ldm/wan/ar_model.py
similarity index 100%
rename from comfy/ldm/wan/causal_model.py
rename to comfy/ldm/wan/ar_model.py
diff --git a/comfy_extras/nodes_causal_forcing.py b/comfy_extras/nodes_ar_video.py
similarity index 92%
rename from comfy_extras/nodes_causal_forcing.py
rename to comfy_extras/nodes_ar_video.py
index 23c7049a4..08010a6ac 100644
--- a/comfy_extras/nodes_causal_forcing.py
+++ b/comfy_extras/nodes_ar_video.py
@@ -1,8 +1,8 @@
 """
-ComfyUI nodes for Causal Forcing autoregressive video generation.
- - LoadCausalForcingModel: load original HF/training or pre-converted checkpoints
+ComfyUI nodes for autoregressive video generation (Causal Forcing, Self-Forcing, etc.).
+ - LoadARVideoModel: load original HF/training or pre-converted checkpoints
     (auto-detects format and converts state dict at runtime)
- - CausalForcingSampler: autoregressive frame-by-frame sampling with KV cache
+ - ARVideoSampler: autoregressive frame-by-frame sampling with KV cache
 """
 
 import torch
@@ -15,8 +15,8 @@ import comfy.utils
 import comfy.ops
 import comfy.latent_formats
 from comfy.model_patcher import ModelPatcher
-from comfy.ldm.wan.causal_model import CausalWanModel
-from comfy.ldm.wan.causal_convert import extract_state_dict
+from comfy.ldm.wan.ar_model import CausalWanModel
+from comfy.ldm.wan.ar_convert import extract_state_dict
 from comfy_api.latest import ComfyExtension, io
 
 # ── Model size presets derived from Wan 2.1 configs ──────────────────────────
@@ -28,11 +28,11 @@ WAN_CONFIGS = {
 }
 
 
-class LoadCausalForcingModel(io.ComfyNode):
+class LoadARVideoModel(io.ComfyNode):
     @classmethod
     def define_schema(cls):
         return io.Schema(
-            node_id="LoadCausalForcingModel",
+            node_id="LoadARVideoModel",
             category="loaders/video_models",
             inputs=[
                 io.Combo.Input("ckpt_name", options=folder_paths.get_filename_list("diffusion_models")),
@@ -62,7 +62,7 @@ class LoadCausalForcingModel(io.ComfyNode):
             num_heads = dim // 128
             ffn_dim = sd["blocks.0.ffn.0.weight"].shape[0]
             text_dim = 4096
-            logging.warning(f"CausalForcing: unknown dim={dim}, inferring num_heads={num_heads}, ffn_dim={ffn_dim}")
+            logging.warning(f"ARVideo: unknown dim={dim}, inferring num_heads={num_heads}, ffn_dim={ffn_dim}")
 
         cross_attn_norm = "blocks.0.norm3.weight" in sd
 
@@ -101,11 +101,11 @@ class LoadCausalForcingModel(io.ComfyNode):
         return io.NodeOutput(patcher)
 
 
-class CausalForcingSampler(io.ComfyNode):
+class ARVideoSampler(io.ComfyNode):
     @classmethod
     def define_schema(cls):
         return io.Schema(
-            node_id="CausalForcingSampler",
+            node_id="ARVideoSampler",
             category="sampling",
             inputs=[
                 io.Model.Input("model"),
@@ -258,14 +258,14 @@ def _lookup_sigma(sigmas, timesteps, t_val):
     return sigmas[idx]
 
 
-class CausalForcingExtension(ComfyExtension):
+class ARVideoExtension(ComfyExtension):
     @override
     async def get_node_list(self) -> list[type[io.ComfyNode]]:
         return [
-            LoadCausalForcingModel,
-            CausalForcingSampler,
+            LoadARVideoModel,
+            ARVideoSampler,
         ]
 
 
-async def comfy_entrypoint() -> CausalForcingExtension:
-    return CausalForcingExtension()
+async def comfy_entrypoint() -> ARVideoExtension:
+    return ARVideoExtension()
diff --git a/nodes.py b/nodes.py
index 66528c24d..4d674617f 100644
--- a/nodes.py
+++ b/nodes.py
@@ -2443,7 +2443,7 @@ async def init_builtin_extra_nodes():
         "nodes_nop.py",
         "nodes_kandinsky5.py",
         "nodes_wanmove.py",
-        "nodes_causal_forcing.py",
+        "nodes_ar_video.py",
         "nodes_image_compare.py",
         "nodes_zimage.py",
         "nodes_glsl.py",