mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-30 13:33:42 +08:00
Remove dedicated ARLoader.
This commit is contained in:
parent
e649a3bc72
commit
4b2734889c
@ -1166,10 +1166,20 @@ class WAN21_T2V(supported_models_base.BASE):
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.wan.WanT5Tokenizer, comfy.text_encoders.wan.te(**t5_detect))
|
||||
|
||||
class WAN21_CausalAR_T2V(WAN21_T2V):
    """Wan 2.1 text-to-video config variant routed to the causal
    autoregressive model class instead of the plain WAN21 one."""

    # "causal_ar" is a detection marker only; keys and order are part of the
    # config-matching contract and must not change.
    unet_config = {
        "image_model": "wan2.1",
        "model_type": "t2v",
        "causal_ar": True,
    }

    sampling_settings = {
        "shift": 5.0,
    }

    def __init__(self, unet_config):
        super().__init__(unet_config)
        # The marker key routed detection to this class; strip it from the
        # instance copy so it never reaches the diffusion model constructor.
        if "causal_ar" in self.unet_config:
            del self.unet_config["causal_ar"]

    def get_model(self, state_dict, prefix="", device=None):
        # state_dict/prefix are accepted for signature parity with the base
        # class but are not consumed here.
        return model_base.WAN21_CausalAR(self, device=device)
|
||||
|
||||
@ -1743,6 +1753,6 @@ class LongCatImage(supported_models_base.BASE):
|
||||
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref))
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.longcat_image.LongCatImageTokenizer, comfy.text_encoders.longcat_image.te(**hunyuan_detect))
|
||||
|
||||
# Registry of supported model configs. Order matters for detection: more
# specific configs (e.g. WAN21_CausalAR_T2V, which subclasses WAN21_T2V)
# are listed before the generic configs they specialize.
# NOTE: the pre-commit duplicate assignment (a dead list immediately
# overwritten by this one) has been removed.
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, LongCatImage, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImagePixelSpace, ZImage, Lumina2, WAN22_T2V, WAN21_CausalAR_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, WAN21_SCAIL, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima]

models += [SVD_img2vid]
|
||||
|
||||
@ -1,112 +1,14 @@
|
||||
"""
|
||||
ComfyUI nodes for autoregressive video generation (Causal Forcing, Self-Forcing, etc.).
|
||||
- LoadARVideoModel: load original HF/training or pre-converted checkpoints
|
||||
via the standard BaseModel + ModelPatcher pipeline
|
||||
- EmptyARVideoLatent: create 5D [B, C, T, H, W] video latent tensors
|
||||
"""
|
||||
|
||||
import torch
|
||||
import logging
|
||||
import folder_paths
|
||||
from typing_extensions import override
|
||||
|
||||
import comfy.model_management
|
||||
import comfy.utils
|
||||
import comfy.ops
|
||||
import comfy.model_patcher
|
||||
from comfy.ldm.wan.ar_convert import extract_state_dict
|
||||
from comfy.supported_models import WAN21_CausalAR_T2V
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
|
||||
# Known Wan 2.1 model-size presets, keyed by hidden dimension.
# Each entry maps dim -> (ffn_dim, num_heads, num_layers, text_dim);
# used to fill in architecture hyperparameters during checkpoint detection.
WAN_CONFIGS = {
    1536: (8960, 12, 30, 4096),   # ~1.3B parameter variant
    2048: (8192, 16, 32, 4096),   # ~2B parameter variant
    5120: (13824, 40, 40, 4096),  # ~14B parameter variant
}
|
||||
|
||||
|
||||
class LoadARVideoModel(io.ComfyNode):
    """ComfyUI node: load an autoregressive Wan video checkpoint.

    Accepts original HF/training or pre-converted checkpoints, infers the
    Wan architecture hyperparameters directly from tensor shapes in the
    state dict, and returns a standard ModelPatcher-wrapped MODEL.
    """

    @classmethod
    def define_schema(cls):
        # Inputs: a diffusion-model checkpoint filename and the number of
        # frames generated per autoregressive block (1..21, default 1).
        return io.Schema(
            node_id="LoadARVideoModel",
            category="loaders/video_models",
            inputs=[
                io.Combo.Input("ckpt_name", options=folder_paths.get_filename_list("diffusion_models")),
                io.Int.Input("num_frame_per_block", default=1, min=1, max=21),
            ],
            outputs=[
                io.Model.Output(display_name="MODEL"),
            ],
        )

    @classmethod
    def execute(cls, ckpt_name, num_frame_per_block) -> io.NodeOutput:
        """Load the checkpoint, build a WAN21_CausalAR_T2V config from its
        tensor shapes, and return the patched model."""
        ckpt_path = folder_paths.get_full_path_or_raise("diffusion_models", ckpt_name)
        raw = comfy.utils.load_torch_file(ckpt_path)
        # Normalize training/HF layouts to the plain model state dict
        # (preferring EMA weights when present — see ar_convert).
        sd = extract_state_dict(raw, use_ema=True)
        del raw  # free the raw checkpoint before building the model

        # Infer architecture geometry from tensor shapes.
        # Assumes the last axis of head.modulation is the hidden dim — TODO confirm.
        dim = sd["head.modulation"].shape[-1]
        # Presumably //4 undoes the patchified output channels — verify
        # against the Wan patch size.
        out_dim = sd["head.head.weight"].shape[0] // 4
        in_dim = sd["patch_embedding.weight"].shape[1]
        # Count transformer blocks by probing consecutive key names.
        num_layers = 0
        while f"blocks.{num_layers}.self_attn.q.weight" in sd:
            num_layers += 1

        if dim in WAN_CONFIGS:
            # Known size preset; num_layers from the preset is ignored in
            # favor of the count probed above.
            ffn_dim, num_heads, _, text_dim = WAN_CONFIGS[dim]
        else:
            # Unknown width: fall back to shape-derived guesses (128-dim heads).
            num_heads = dim // 128
            ffn_dim = sd["blocks.0.ffn.0.weight"].shape[0]
            text_dim = 4096
            logging.warning(f"ARVideo: unknown dim={dim}, inferring num_heads={num_heads}, ffn_dim={ffn_dim}")

        # Presence of the norm3 weight signals cross-attention normalization.
        cross_attn_norm = "blocks.0.norm3.weight" in sd

        unet_config = {
            "image_model": "wan2.1",
            "model_type": "t2v",
            "dim": dim,
            "ffn_dim": ffn_dim,
            "num_heads": num_heads,
            "num_layers": num_layers,
            "in_dim": in_dim,
            "out_dim": out_dim,
            "text_dim": text_dim,
            "cross_attn_norm": cross_attn_norm,
        }

        model_config = WAN21_CausalAR_T2V(unet_config)
        # Pick inference/manual-cast dtypes before instantiating the model.
        unet_dtype = comfy.model_management.unet_dtype(
            model_params=comfy.utils.calculate_parameters(sd),
            supported_dtypes=model_config.supported_inference_dtypes,
        )
        manual_cast_dtype = comfy.model_management.unet_manual_cast(
            unet_dtype,
            comfy.model_management.get_torch_device(),
            model_config.supported_inference_dtypes,
        )
        model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)

        model = model_config.get_model(sd, "")
        load_device = comfy.model_management.get_torch_device()
        offload_device = comfy.model_management.unet_offload_device()

        model_patcher = comfy.model_patcher.ModelPatcher(
            model, load_device=load_device, offload_device=offload_device,
        )
        # Move to the offload device before loading weights (skipped when
        # offload is CPU, where the model already lives).
        if not comfy.model_management.is_device_cpu(offload_device):
            model.to(offload_device)
        model.load_model_weights(sd, "")

        # Stash AR sampling parameters where the transformer can read them.
        model_patcher.model_options.setdefault("transformer_options", {})["ar_config"] = {
            "num_frame_per_block": num_frame_per_block,
        }

        return io.NodeOutput(model_patcher)
|
||||
|
||||
|
||||
class EmptyARVideoLatent(io.ComfyNode):
|
||||
@classmethod
|
||||
@ -139,7 +41,6 @@ class ARVideoExtension(ComfyExtension):
|
||||
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
    """Return the node classes this extension registers with ComfyUI."""
    nodes: list[type[io.ComfyNode]] = [
        LoadARVideoModel,
        EmptyARVideoLatent,
    ]
    return nodes
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user