# Mirror of https://github.com/comfyanonymous/ComfyUI.git
# Synced 2026-03-30 13:33:42 +08:00 · 149 lines · 5.1 KiB · Python
"""
ComfyUI nodes for autoregressive video generation (Causal Forcing, Self-Forcing, etc.).
  - LoadARVideoModel: load original HF/training or pre-converted checkpoints
    via the standard BaseModel + ModelPatcher pipeline
"""
|
|
|
|
import torch
|
|
import logging
|
|
import folder_paths
|
|
from typing_extensions import override
|
|
|
|
import comfy.model_management
|
|
import comfy.utils
|
|
import comfy.ops
|
|
import comfy.model_patcher
|
|
from comfy.ldm.wan.ar_convert import extract_state_dict
|
|
from comfy.supported_models import WAN21_CausalAR_T2V
|
|
from comfy_api.latest import ComfyExtension, io
|
|
|
|
# ── Model size presets derived from Wan 2.1 configs ──────────────────────────
# Keyed by the transformer width (`dim`). Each value is the tuple
# (ffn_dim, num_heads, num_layers, text_dim) for that model size.
WAN_CONFIGS = {
    # dim → (ffn_dim, num_heads, num_layers, text_dim)
    1536: (8960, 12, 30, 4096),   # 1.3B
    2048: (8192, 16, 32, 4096),   # ~2B
    5120: (13824, 40, 40, 4096),  # 14B
}
|
|
|
|
|
|
class LoadARVideoModel(io.ComfyNode):
    """Loader node for causal autoregressive Wan 2.1 text-to-video checkpoints.

    The architecture hyper-parameters are derived from tensor shapes found in
    the checkpoint's state dict, with `WAN_CONFIGS` supplying presets for
    known model widths.
    """

    @classmethod
    def define_schema(cls):
        """Declare the node id, category, inputs and the single MODEL output."""
        return io.Schema(
            node_id="LoadARVideoModel",
            category="loaders/video_models",
            inputs=[
                io.Combo.Input("ckpt_name", options=folder_paths.get_filename_list("diffusion_models")),
                io.Int.Input("num_frame_per_block", default=1, min=1, max=21),
            ],
            outputs=[
                io.Model.Output(display_name="MODEL"),
            ],
        )

    @classmethod
    def execute(cls, ckpt_name, num_frame_per_block) -> io.NodeOutput:
        """Load `ckpt_name` and wrap it in a ModelPatcher.

        Args:
            ckpt_name: File name inside the "diffusion_models" folder.
            num_frame_per_block: Stored under
                ``transformer_options["ar_config"]`` on the returned patcher.

        Returns:
            An ``io.NodeOutput`` holding the configured ModelPatcher.

        Raises:
            KeyError: If the state dict lacks a tensor needed to infer the
                architecture (e.g. a non-Wan checkpoint was selected).
        """
        ckpt_path = folder_paths.get_full_path_or_raise("diffusion_models", ckpt_name)
        raw = comfy.utils.load_torch_file(ckpt_path)
        # NOTE(review): presumably picks the EMA weights when the checkpoint
        # carries them — confirm against extract_state_dict.
        sd = extract_state_dict(raw, use_ema=True)
        del raw  # drop our reference to the raw checkpoint as early as possible

        # The combo box lists every diffusion model, so fail with an
        # actionable message when the selected file is not the expected kind.
        required_keys = ("head.modulation", "head.head.weight", "patch_embedding.weight")
        missing = [key for key in required_keys if key not in sd]
        if missing:
            raise KeyError(
                f"ARVideo: checkpoint '{ckpt_name}' is missing expected tensors {missing}; "
                "it does not look like a Wan 2.1 autoregressive video model"
            )

        # Infer the architecture from tensor shapes.
        dim = sd["head.modulation"].shape[-1]
        # NOTE(review): the divisor 4 is presumably the patch volume folded
        # into the head's output rows — confirm against the model definition.
        out_dim = sd["head.head.weight"].shape[0] // 4
        in_dim = sd["patch_embedding.weight"].shape[1]

        # Count consecutive transformer blocks present in the state dict.
        num_layers = 0
        while f"blocks.{num_layers}.self_attn.q.weight" in sd:
            num_layers += 1

        if dim in WAN_CONFIGS:
            # The preset's layer count is ignored; the counted value wins.
            ffn_dim, num_heads, _, text_dim = WAN_CONFIGS[dim]
        else:
            # Unknown width: fall back to 128-wide attention heads and read
            # the feed-forward width from the first block.
            num_heads = dim // 128
            ffn_dim = sd["blocks.0.ffn.0.weight"].shape[0]
            text_dim = 4096
            logging.warning(f"ARVideo: unknown dim={dim}, inferring num_heads={num_heads}, ffn_dim={ffn_dim}")

        # Enabled when the first block carries a third norm's weight.
        cross_attn_norm = "blocks.0.norm3.weight" in sd

        unet_config = {
            "image_model": "wan2.1",
            "model_type": "t2v",
            "dim": dim,
            "ffn_dim": ffn_dim,
            "num_heads": num_heads,
            "num_layers": num_layers,
            "in_dim": in_dim,
            "out_dim": out_dim,
            "text_dim": text_dim,
            "cross_attn_norm": cross_attn_norm,
        }

        model_config = WAN21_CausalAR_T2V(unet_config)
        unet_dtype = comfy.model_management.unet_dtype(
            model_params=comfy.utils.calculate_parameters(sd),
            supported_dtypes=model_config.supported_inference_dtypes,
        )
        # Queried once and reused for both the cast decision and the patcher.
        load_device = comfy.model_management.get_torch_device()
        manual_cast_dtype = comfy.model_management.unet_manual_cast(
            unet_dtype,
            load_device,
            model_config.supported_inference_dtypes,
        )
        model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)

        model = model_config.get_model(sd, "")
        offload_device = comfy.model_management.unet_offload_device()

        model_patcher = comfy.model_patcher.ModelPatcher(
            model, load_device=load_device, offload_device=offload_device,
        )
        if not comfy.model_management.is_device_cpu(offload_device):
            model.to(offload_device)
        model.load_model_weights(sd, "")

        model_patcher.model_options.setdefault("transformer_options", {})["ar_config"] = {
            "num_frame_per_block": num_frame_per_block,
        }

        return io.NodeOutput(model_patcher)
|
|
|
|
|
|
class EmptyARVideoLatent(io.ComfyNode):
    """Node that produces an all-zero video latent of the requested size."""

    @classmethod
    def define_schema(cls):
        """Declare the node id, category, size inputs and the LATENT output."""
        return io.Schema(
            node_id="EmptyARVideoLatent",
            category="latent/video",
            inputs=[
                io.Int.Input("width", default=832, min=16, max=8192, step=16),
                io.Int.Input("height", default=480, min=16, max=8192, step=16),
                io.Int.Input("length", default=81, min=1, max=1024, step=4),
                io.Int.Input("batch_size", default=1, min=1, max=64),
            ],
            outputs=[
                io.Latent.Output(display_name="LATENT"),
            ],
        )

    @classmethod
    def execute(cls, width, height, length, batch_size) -> io.NodeOutput:
        """Build a zero tensor shaped [batch_size, 16, frames, height // 8, width // 8].

        Args:
            width: Output width in pixels; divided by 8 for the latent.
            height: Output height in pixels; divided by 8 for the latent.
            length: Number of video frames requested.
            batch_size: Size of the leading batch dimension.

        Returns:
            An ``io.NodeOutput`` wrapping ``{"samples": latent}``.
        """
        # One latent frame for the first video frame, plus one per further
        # group of 4 frames (e.g. length=81 gives 21 latent frames).
        lat_t = ((length - 1) // 4) + 1
        latent = torch.zeros(
            [batch_size, 16, lat_t, height // 8, width // 8],
            device=comfy.model_management.intermediate_device(),
        )
        return io.NodeOutput({"samples": latent})
|
|
|
|
|
|
class ARVideoExtension(ComfyExtension):
    """Extension object exposing this module's nodes."""

    @override
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        """Return the node classes defined in this module."""
        return [
            LoadARVideoModel,
            EmptyARVideoLatent,
        ]
|
|
|
|
|
|
async def comfy_entrypoint() -> ARVideoExtension:
    """Create and return this module's `ARVideoExtension` instance."""
    return ARVideoExtension()
|