Remove dedicated ARLoader.

This commit is contained in:
Talmaj Marinc 2026-03-25 20:39:37 +01:00
parent e649a3bc72
commit 4b2734889c
2 changed files with 12 additions and 101 deletions

View File

@ -1166,10 +1166,20 @@ class WAN21_T2V(supported_models_base.BASE):
return supported_models_base.ClipTarget(comfy.text_encoders.wan.WanT5Tokenizer, comfy.text_encoders.wan.te(**t5_detect))
class WAN21_CausalAR_T2V(WAN21_T2V):
    """Causal autoregressive variant of the WAN 2.1 text-to-video config.

    Identical to WAN21_T2V except for the extra "causal_ar" detection key
    and the model class returned by get_model().
    """

    # Detection config: the "causal_ar" key distinguishes this variant from
    # the plain WAN21_T2V config during checkpoint matching.
    unet_config = {
        "image_model": "wan2.1",
        "model_type": "t2v",
        "causal_ar": True,
    }

    sampling_settings = {
        "shift": 5.0,
    }

    def __init__(self, unet_config):
        super().__init__(unet_config)
        # "causal_ar" is only a detection marker; drop it so it is not
        # forwarded downstream (presumably to the diffusion-model
        # constructor — TODO confirm against the base class).
        self.unet_config.pop("causal_ar", None)

    def get_model(self, state_dict, prefix="", device=None):
        # Build the causal-AR model wrapper instead of the base WAN 2.1 one.
        return model_base.WAN21_CausalAR(self, device=device)
@ -1743,6 +1753,6 @@ class LongCatImage(supported_models_base.BASE):
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.longcat_image.LongCatImageTokenizer, comfy.text_encoders.longcat_image.te(**hunyuan_detect))
# Registry of supported model configs. NOTE(review): the original had two
# consecutive assignments to ``models``; the first was dead code, immediately
# overwritten by this list (identical except that it lacked
# WAN21_CausalAR_T2V). The dead assignment is removed here.
# WAN21_CausalAR_T2V is listed before WAN21_T2V — presumably so the more
# specific "causal_ar" config is matched first; confirm detection order.
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, LongCatImage, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImagePixelSpace, ZImage, Lumina2, WAN22_T2V, WAN21_CausalAR_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, WAN21_SCAIL, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima]

models += [SVD_img2vid]

View File

@ -1,112 +1,14 @@
"""
ComfyUI nodes for autoregressive video generation (Causal Forcing, Self-Forcing, etc.).
- LoadARVideoModel: load original HF/training or pre-converted checkpoints
via the standard BaseModel + ModelPatcher pipeline
- EmptyARVideoLatent: create 5D [B, C, T, H, W] video latent tensors
"""
import torch
import logging
import folder_paths
from typing_extensions import override
import comfy.model_management
import comfy.utils
import comfy.ops
import comfy.model_patcher
from comfy.ldm.wan.ar_convert import extract_state_dict
from comfy.supported_models import WAN21_CausalAR_T2V
from comfy_api.latest import ComfyExtension, io
# ---------------------------------------------------------------------------
# Wan 2.1 model-size presets.
#
# Maps the transformer hidden size (``dim``, readable from the checkpoint)
# to the remaining architecture hyper-parameters:
#     dim -> (ffn_dim, num_heads, num_layers, text_dim)
# ---------------------------------------------------------------------------
WAN_CONFIGS = {
    1536: (8960, 12, 30, 4096),    # 1.3B variant
    2048: (8192, 16, 32, 4096),    # ~2B variant
    5120: (13824, 40, 40, 4096),   # 14B variant
}
class LoadARVideoModel(io.ComfyNode):
    """ComfyUI node that loads an autoregressive WAN 2.1 video checkpoint.

    Architecture hyper-parameters are inferred directly from tensor shapes
    in the state dict, with WAN_CONFIGS providing presets for known sizes.
    Returns a ModelPatcher carrying the loaded model plus an "ar_config"
    entry in its transformer_options.
    """

    @classmethod
    def define_schema(cls):
        # Node schema: one checkpoint picker + the AR block size, one MODEL output.
        return io.Schema(
            node_id="LoadARVideoModel",
            category="loaders/video_models",
            inputs=[
                io.Combo.Input("ckpt_name", options=folder_paths.get_filename_list("diffusion_models")),
                io.Int.Input("num_frame_per_block", default=1, min=1, max=21),
            ],
            outputs=[
                io.Model.Output(display_name="MODEL"),
            ],
        )

    @classmethod
    def execute(cls, ckpt_name, num_frame_per_block) -> io.NodeOutput:
        ckpt_path = folder_paths.get_full_path_or_raise("diffusion_models", ckpt_name)
        raw = comfy.utils.load_torch_file(ckpt_path)
        # Unwrap HF/training-format checkpoints; prefers EMA weights.
        sd = extract_state_dict(raw, use_ema=True)
        del raw  # release the raw checkpoint before building the model

        # --- Infer the architecture from tensor shapes in the state dict ---
        dim = sd["head.modulation"].shape[-1]
        # // 4: head output covers 4x the latent channels — presumably the
        # patchified output volume; TODO confirm against the WAN model code.
        out_dim = sd["head.head.weight"].shape[0] // 4
        in_dim = sd["patch_embedding.weight"].shape[1]
        # Count transformer blocks by probing consecutive block indices.
        num_layers = 0
        while f"blocks.{num_layers}.self_attn.q.weight" in sd:
            num_layers += 1

        if dim in WAN_CONFIGS:
            ffn_dim, num_heads, _, text_dim = WAN_CONFIGS[dim]
        else:
            # Unknown size: assume 128-dim attention heads, read ffn_dim
            # straight from the first block's weights, default text_dim.
            num_heads = dim // 128
            ffn_dim = sd["blocks.0.ffn.0.weight"].shape[0]
            text_dim = 4096
            logging.warning(f"ARVideo: unknown dim={dim}, inferring num_heads={num_heads}, ffn_dim={ffn_dim}")
        # Presence of the norm3 weight indicates cross-attention norm is used.
        cross_attn_norm = "blocks.0.norm3.weight" in sd

        unet_config = {
            "image_model": "wan2.1",
            "model_type": "t2v",
            "dim": dim,
            "ffn_dim": ffn_dim,
            "num_heads": num_heads,
            "num_layers": num_layers,
            "in_dim": in_dim,
            "out_dim": out_dim,
            "text_dim": text_dim,
            "cross_attn_norm": cross_attn_norm,
        }
        model_config = WAN21_CausalAR_T2V(unet_config)

        # --- Dtype selection, mirroring the stock diffusion-model loader ---
        unet_dtype = comfy.model_management.unet_dtype(
            model_params=comfy.utils.calculate_parameters(sd),
            supported_dtypes=model_config.supported_inference_dtypes,
        )
        manual_cast_dtype = comfy.model_management.unet_manual_cast(
            unet_dtype,
            comfy.model_management.get_torch_device(),
            model_config.supported_inference_dtypes,
        )
        model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)

        # --- Instantiate, move to offload device, then load the weights ---
        model = model_config.get_model(sd, "")
        load_device = comfy.model_management.get_torch_device()
        offload_device = comfy.model_management.unet_offload_device()
        model_patcher = comfy.model_patcher.ModelPatcher(
            model, load_device=load_device, offload_device=offload_device,
        )
        if not comfy.model_management.is_device_cpu(offload_device):
            model.to(offload_device)
        model.load_model_weights(sd, "")

        # Stash the AR sampling settings in transformer_options — presumably
        # consumed by the AR sampling path; verify against the sampler code.
        model_patcher.model_options.setdefault("transformer_options", {})["ar_config"] = {
            "num_frame_per_block": num_frame_per_block,
        }
        return io.NodeOutput(model_patcher)
class EmptyARVideoLatent(io.ComfyNode):
@classmethod
@ -139,7 +41,6 @@ class ARVideoExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
    """Return the node classes this extension registers with ComfyUI."""
    nodes: list[type[io.ComfyNode]] = [LoadARVideoModel, EmptyARVideoLatent]
    return nodes