mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-30 13:33:42 +08:00
Rename causual forcing to using more general auto regressive naming convention.
This commit is contained in:
parent
0836390c27
commit
2f30a821c5
70
comfy/ldm/wan/ar_convert.py
Normal file
70
comfy/ldm/wan/ar_convert.py
Normal file
@ -0,0 +1,70 @@
|
|||||||
|
"""
|
||||||
|
State dict conversion for Causal Forcing checkpoints.
|
||||||
|
|
||||||
|
Handles three checkpoint layouts:
|
||||||
|
1. Training checkpoint with top-level generator_ema / generator keys
|
||||||
|
2. Already-flattened state dict with model.* prefixes
|
||||||
|
3. Already-converted ComfyUI state dict (bare model keys)
|
||||||
|
|
||||||
|
Strips prefixes so the result matches the standard Wan 2.1 / CausalWanModel key layout
|
||||||
|
(e.g. blocks.0.self_attn.q.weight, head.modulation, etc.)
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
PREFIXES_TO_STRIP = ["model._fsdp_wrapped_module.", "model."]
|
||||||
|
|
||||||
|
_MODEL_KEY_PREFIXES = (
|
||||||
|
"blocks.", "head.", "patch_embedding.", "text_embedding.",
|
||||||
|
"time_embedding.", "time_projection.", "img_emb.", "rope_embedder.",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def extract_state_dict(state_dict: dict, use_ema: bool = True) -> dict:
|
||||||
|
"""
|
||||||
|
Extract and clean a Causal Forcing state dict from a training checkpoint.
|
||||||
|
|
||||||
|
Returns a state dict with keys matching the CausalWanModel / WanModel layout.
|
||||||
|
"""
|
||||||
|
# Case 3: already converted -- keys are bare model keys
|
||||||
|
if "head.modulation" in state_dict and "blocks.0.self_attn.q.weight" in state_dict:
|
||||||
|
return state_dict
|
||||||
|
|
||||||
|
# Case 1: training checkpoint with wrapper key
|
||||||
|
raw_sd = None
|
||||||
|
order = ["generator_ema", "generator"] if use_ema else ["generator", "generator_ema"]
|
||||||
|
for wrapper_key in order:
|
||||||
|
if wrapper_key in state_dict:
|
||||||
|
raw_sd = state_dict[wrapper_key]
|
||||||
|
log.info("Causal Forcing: extracted '%s' with %d keys", wrapper_key, len(raw_sd))
|
||||||
|
break
|
||||||
|
|
||||||
|
# Case 2: flat dict with model.* prefixes
|
||||||
|
if raw_sd is None:
|
||||||
|
if any(k.startswith("model.") for k in state_dict):
|
||||||
|
raw_sd = state_dict
|
||||||
|
else:
|
||||||
|
raise KeyError(
|
||||||
|
f"Cannot detect Causal Forcing checkpoint layout. "
|
||||||
|
f"Top-level keys: {list(state_dict.keys())[:20]}"
|
||||||
|
)
|
||||||
|
|
||||||
|
out_sd = {}
|
||||||
|
for k, v in raw_sd.items():
|
||||||
|
new_k = k
|
||||||
|
for prefix in PREFIXES_TO_STRIP:
|
||||||
|
if new_k.startswith(prefix):
|
||||||
|
new_k = new_k[len(prefix):]
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
if not new_k.startswith(_MODEL_KEY_PREFIXES):
|
||||||
|
log.debug("Causal Forcing: skipping non-model key: %s", k)
|
||||||
|
continue
|
||||||
|
out_sd[new_k] = v
|
||||||
|
|
||||||
|
if "head.modulation" not in out_sd:
|
||||||
|
raise ValueError("Conversion failed: 'head.modulation' not found in output keys")
|
||||||
|
|
||||||
|
return out_sd
|
||||||
@ -1,8 +1,8 @@
|
|||||||
"""
|
"""
|
||||||
ComfyUI nodes for Causal Forcing autoregressive video generation.
|
ComfyUI nodes for autoregressive video generation (Causal Forcing, Self-Forcing, etc.).
|
||||||
- LoadCausalForcingModel: load original HF/training or pre-converted checkpoints
|
- LoadARVideoModel: load original HF/training or pre-converted checkpoints
|
||||||
(auto-detects format and converts state dict at runtime)
|
(auto-detects format and converts state dict at runtime)
|
||||||
- CausalForcingSampler: autoregressive frame-by-frame sampling with KV cache
|
- ARVideoSampler: autoregressive frame-by-frame sampling with KV cache
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@ -15,8 +15,8 @@ import comfy.utils
|
|||||||
import comfy.ops
|
import comfy.ops
|
||||||
import comfy.latent_formats
|
import comfy.latent_formats
|
||||||
from comfy.model_patcher import ModelPatcher
|
from comfy.model_patcher import ModelPatcher
|
||||||
from comfy.ldm.wan.causal_model import CausalWanModel
|
from comfy.ldm.wan.ar_model import CausalWanModel
|
||||||
from comfy.ldm.wan.causal_convert import extract_state_dict
|
from comfy.ldm.wan.ar_convert import extract_state_dict
|
||||||
from comfy_api.latest import ComfyExtension, io
|
from comfy_api.latest import ComfyExtension, io
|
||||||
|
|
||||||
# ── Model size presets derived from Wan 2.1 configs ──────────────────────────
|
# ── Model size presets derived from Wan 2.1 configs ──────────────────────────
|
||||||
@ -28,11 +28,11 @@ WAN_CONFIGS = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class LoadCausalForcingModel(io.ComfyNode):
|
class LoadARVideoModel(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def define_schema(cls):
|
def define_schema(cls):
|
||||||
return io.Schema(
|
return io.Schema(
|
||||||
node_id="LoadCausalForcingModel",
|
node_id="LoadARVideoModel",
|
||||||
category="loaders/video_models",
|
category="loaders/video_models",
|
||||||
inputs=[
|
inputs=[
|
||||||
io.Combo.Input("ckpt_name", options=folder_paths.get_filename_list("diffusion_models")),
|
io.Combo.Input("ckpt_name", options=folder_paths.get_filename_list("diffusion_models")),
|
||||||
@ -62,7 +62,7 @@ class LoadCausalForcingModel(io.ComfyNode):
|
|||||||
num_heads = dim // 128
|
num_heads = dim // 128
|
||||||
ffn_dim = sd["blocks.0.ffn.0.weight"].shape[0]
|
ffn_dim = sd["blocks.0.ffn.0.weight"].shape[0]
|
||||||
text_dim = 4096
|
text_dim = 4096
|
||||||
logging.warning(f"CausalForcing: unknown dim={dim}, inferring num_heads={num_heads}, ffn_dim={ffn_dim}")
|
logging.warning(f"ARVideo: unknown dim={dim}, inferring num_heads={num_heads}, ffn_dim={ffn_dim}")
|
||||||
|
|
||||||
cross_attn_norm = "blocks.0.norm3.weight" in sd
|
cross_attn_norm = "blocks.0.norm3.weight" in sd
|
||||||
|
|
||||||
@ -101,11 +101,11 @@ class LoadCausalForcingModel(io.ComfyNode):
|
|||||||
return io.NodeOutput(patcher)
|
return io.NodeOutput(patcher)
|
||||||
|
|
||||||
|
|
||||||
class CausalForcingSampler(io.ComfyNode):
|
class ARVideoSampler(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def define_schema(cls):
|
def define_schema(cls):
|
||||||
return io.Schema(
|
return io.Schema(
|
||||||
node_id="CausalForcingSampler",
|
node_id="ARVideoSampler",
|
||||||
category="sampling",
|
category="sampling",
|
||||||
inputs=[
|
inputs=[
|
||||||
io.Model.Input("model"),
|
io.Model.Input("model"),
|
||||||
@ -258,14 +258,14 @@ def _lookup_sigma(sigmas, timesteps, t_val):
|
|||||||
return sigmas[idx]
|
return sigmas[idx]
|
||||||
|
|
||||||
|
|
||||||
class CausalForcingExtension(ComfyExtension):
|
class ARVideoExtension(ComfyExtension):
|
||||||
@override
|
@override
|
||||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||||
return [
|
return [
|
||||||
LoadCausalForcingModel,
|
LoadARVideoModel,
|
||||||
CausalForcingSampler,
|
ARVideoSampler,
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
async def comfy_entrypoint() -> CausalForcingExtension:
|
async def comfy_entrypoint() -> ARVideoExtension:
|
||||||
return CausalForcingExtension()
|
return ARVideoExtension()
|
||||||
2
nodes.py
2
nodes.py
@ -2443,7 +2443,7 @@ async def init_builtin_extra_nodes():
|
|||||||
"nodes_nop.py",
|
"nodes_nop.py",
|
||||||
"nodes_kandinsky5.py",
|
"nodes_kandinsky5.py",
|
||||||
"nodes_wanmove.py",
|
"nodes_wanmove.py",
|
||||||
"nodes_causal_forcing.py",
|
"nodes_ar_video.py",
|
||||||
"nodes_image_compare.py",
|
"nodes_image_compare.py",
|
||||||
"nodes_zimage.py",
|
"nodes_zimage.py",
|
||||||
"nodes_glsl.py",
|
"nodes_glsl.py",
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user