From 4b2734889ce80a8317bde90364f40e02f1690388 Mon Sep 17 00:00:00 2001 From: Talmaj Marinc Date: Wed, 25 Mar 2026 20:39:37 +0100 Subject: [PATCH] Remove dedicated ARLoader. --- comfy/supported_models.py | 12 +++- comfy_extras/nodes_ar_video.py | 101 +-------------------------------- 2 files changed, 12 insertions(+), 101 deletions(-) diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 4c5159fbe..aa66e035f 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -1166,10 +1166,20 @@ class WAN21_T2V(supported_models_base.BASE): return supported_models_base.ClipTarget(comfy.text_encoders.wan.WanT5Tokenizer, comfy.text_encoders.wan.te(**t5_detect)) class WAN21_CausalAR_T2V(WAN21_T2V): + unet_config = { + "image_model": "wan2.1", + "model_type": "t2v", + "causal_ar": True, + } + sampling_settings = { "shift": 5.0, } + def __init__(self, unet_config): + super().__init__(unet_config) + self.unet_config.pop("causal_ar", None) + def get_model(self, state_dict, prefix="", device=None): return model_base.WAN21_CausalAR(self, device=device) @@ -1743,6 +1753,6 @@ class LongCatImage(supported_models_base.BASE): hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref)) return supported_models_base.ClipTarget(comfy.text_encoders.longcat_image.LongCatImageTokenizer, comfy.text_encoders.longcat_image.te(**hunyuan_detect)) -models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, LongCatImage, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImagePixelSpace, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, WAN21_SCAIL, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima] +models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, LongCatImage, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImagePixelSpace, ZImage, Lumina2, WAN22_T2V, WAN21_CausalAR_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, WAN21_SCAIL, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima] models += [SVD_img2vid] diff --git a/comfy_extras/nodes_ar_video.py b/comfy_extras/nodes_ar_video.py index 41bed9414..be9f2eaec 100644 --- a/comfy_extras/nodes_ar_video.py +++ b/comfy_extras/nodes_ar_video.py @@ -1,112 +1,14 @@ """ ComfyUI nodes for autoregressive video generation (Causal Forcing, Self-Forcing, etc.). - - LoadARVideoModel: load original HF/training or pre-converted checkpoints - via the standard BaseModel + ModelPatcher pipeline + - EmptyARVideoLatent: create 5D [B, C, T, H, W] video latent tensors """ import torch -import logging -import folder_paths from typing_extensions import override import comfy.model_management -import comfy.utils -import comfy.ops -import comfy.model_patcher -from comfy.ldm.wan.ar_convert import extract_state_dict -from comfy.supported_models import WAN21_CausalAR_T2V from comfy_api.latest import ComfyExtension, io -# ── Model size presets derived from Wan 2.1 configs ────────────────────────── -WAN_CONFIGS = { - # dim → (ffn_dim, num_heads, num_layers, text_dim) - 1536: (8960, 12, 30, 4096), # 1.3B - 2048: (8192, 16, 32, 4096), # ~2B - 5120: (13824, 40, 40, 4096), # 14B -} - - -class LoadARVideoModel(io.ComfyNode): - @classmethod - def define_schema(cls): - return io.Schema( - node_id="LoadARVideoModel", - category="loaders/video_models", - inputs=[ - io.Combo.Input("ckpt_name", options=folder_paths.get_filename_list("diffusion_models")), - io.Int.Input("num_frame_per_block", default=1, min=1, max=21), - ], - outputs=[ - io.Model.Output(display_name="MODEL"), - ], - ) - - @classmethod - def execute(cls, ckpt_name, num_frame_per_block) -> io.NodeOutput: - ckpt_path = folder_paths.get_full_path_or_raise("diffusion_models", ckpt_name) - raw = comfy.utils.load_torch_file(ckpt_path) - sd = extract_state_dict(raw, use_ema=True) - del raw - - dim = sd["head.modulation"].shape[-1] - out_dim = sd["head.head.weight"].shape[0] // 4 - in_dim = sd["patch_embedding.weight"].shape[1] - num_layers = 0 - while f"blocks.{num_layers}.self_attn.q.weight" in sd: - num_layers += 1 - - if dim in WAN_CONFIGS: - ffn_dim, num_heads, _, text_dim = WAN_CONFIGS[dim] - else: - num_heads = dim // 128 - ffn_dim = sd["blocks.0.ffn.0.weight"].shape[0] - text_dim = 4096 - logging.warning(f"ARVideo: unknown dim={dim}, inferring num_heads={num_heads}, ffn_dim={ffn_dim}") - - cross_attn_norm = "blocks.0.norm3.weight" in sd - - unet_config = { - "image_model": "wan2.1", - "model_type": "t2v", - "dim": dim, - "ffn_dim": ffn_dim, - "num_heads": num_heads, - "num_layers": num_layers, - "in_dim": in_dim, - "out_dim": out_dim, - "text_dim": text_dim, - "cross_attn_norm": cross_attn_norm, - } - - model_config = WAN21_CausalAR_T2V(unet_config) - unet_dtype = comfy.model_management.unet_dtype( - model_params=comfy.utils.calculate_parameters(sd), - supported_dtypes=model_config.supported_inference_dtypes, - ) - manual_cast_dtype = comfy.model_management.unet_manual_cast( - unet_dtype, - comfy.model_management.get_torch_device(), - model_config.supported_inference_dtypes, - ) - model_config.set_inference_dtype(unet_dtype, manual_cast_dtype) - - model = model_config.get_model(sd, "") - load_device = comfy.model_management.get_torch_device() - offload_device = comfy.model_management.unet_offload_device() - - model_patcher = comfy.model_patcher.ModelPatcher( - model, load_device=load_device, offload_device=offload_device, - ) - if not comfy.model_management.is_device_cpu(offload_device): - model.to(offload_device) - model.load_model_weights(sd, "") - - model_patcher.model_options.setdefault("transformer_options", {})["ar_config"] = { - "num_frame_per_block": num_frame_per_block, - } - - return io.NodeOutput(model_patcher) - class EmptyARVideoLatent(io.ComfyNode): @classmethod @@ -139,7 +41,6 @@ class ARVideoExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[io.ComfyNode]]: return [ - LoadARVideoModel, EmptyARVideoLatent, ]