diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py index 6978eb717..b1a8f80ab 100644 --- a/comfy/k_diffusion/sampling.py +++ b/comfy/k_diffusion/sampling.py @@ -1810,3 +1810,100 @@ def sample_sa_solver(model, x, sigmas, extra_args=None, callback=None, disable=F def sample_sa_solver_pece(model, x, sigmas, extra_args=None, callback=None, disable=False, tau_func=None, s_noise=1.0, noise_sampler=None, predictor_order=3, corrector_order=4, simple_order_2=False): """Stochastic Adams Solver with PECE (Predict–Evaluate–Correct–Evaluate) mode (NeurIPS 2023).""" return sample_sa_solver(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, tau_func=tau_func, s_noise=s_noise, noise_sampler=noise_sampler, predictor_order=predictor_order, corrector_order=corrector_order, use_pece=True, simple_order_2=simple_order_2) + + +@torch.no_grad() +def sample_ar_video(model, x, sigmas, extra_args=None, callback=None, disable=None): + """ + Autoregressive video sampler: block-by-block denoising with KV cache + and flow-match re-noising for Causal Forcing / Self-Forcing models. + + Requires a Causal-WAN compatible model (diffusion_model must expose + init_kv_caches / init_crossattn_caches) and 5-D latents [B,C,T,H,W]. + """ + extra_args = {} if extra_args is None else extra_args + model_options = extra_args.get("model_options", {}) + transformer_options = model_options.get("transformer_options", {}) + ar_config = transformer_options.get("ar_config", {}) + + if x.ndim != 5: + raise ValueError( + f"ar_video sampler requires 5-D video latents [B,C,T,H,W], got {x.ndim}-D tensor with shape {x.shape}. " + "This sampler is only compatible with autoregressive video models (e.g. Causal-WAN)." + ) + + inner_model = model.inner_model.inner_model + causal_model = inner_model.diffusion_model + + if not (hasattr(causal_model, "init_kv_caches") and hasattr(causal_model, "init_crossattn_caches")): + raise TypeError( + "ar_video sampler requires a Causal-WAN compatible model whose diffusion_model " + "exposes init_kv_caches() and init_crossattn_caches(). The loaded checkpoint " + "does not support this interface — choose a different sampler." + ) + + num_frame_per_block = ar_config.get("num_frame_per_block", 1) + seed = extra_args.get("seed", 0) + + bs, c, lat_t, lat_h, lat_w = x.shape + frame_seq_len = -(-lat_h // 2) * -(-lat_w // 2) # ceiling division + num_blocks = -(-lat_t // num_frame_per_block) # ceiling division + device = x.device + model_dtype = inner_model.get_dtype() + + kv_caches = causal_model.init_kv_caches(bs, lat_t * frame_seq_len, device, model_dtype) + crossattn_caches = causal_model.init_crossattn_caches(bs, device, model_dtype) + + output = torch.zeros_like(x) + s_in = x.new_ones([x.shape[0]]) + current_start_frame = 0 + num_sigma_steps = len(sigmas) - 1 + total_real_steps = num_blocks * num_sigma_steps + step_count = 0 + + try: + for block_idx in trange(num_blocks, disable=disable): + bf = min(num_frame_per_block, lat_t - current_start_frame) + fs, fe = current_start_frame, current_start_frame + bf + noisy_input = x[:, :, fs:fe] + + ar_state = { + "start_frame": current_start_frame, + "kv_caches": kv_caches, + "crossattn_caches": crossattn_caches, + } + transformer_options["ar_state"] = ar_state + + for i in range(num_sigma_steps): + denoised = model(noisy_input, sigmas[i] * s_in, **extra_args) + + if callback is not None: + scaled_i = step_count * num_sigma_steps // total_real_steps + callback({"x": noisy_input, "i": scaled_i, "sigma": sigmas[i], + "sigma_hat": sigmas[i], "denoised": denoised}) + + if sigmas[i + 1] == 0: + noisy_input = denoised + else: + sigma_next = sigmas[i + 1] + torch.manual_seed(seed + block_idx * 1000 + i) + fresh_noise = torch.randn_like(denoised) + noisy_input = (1.0 - sigma_next) * denoised + sigma_next * fresh_noise + + for cache in kv_caches: + cache["end"] -= bf * frame_seq_len + + step_count += 1 + + output[:, :, fs:fe] = noisy_input + + for cache in kv_caches: + cache["end"] -= bf * frame_seq_len + zero_sigma = sigmas.new_zeros([1]) + _ = model(noisy_input, zero_sigma * s_in, **extra_args) + + current_start_frame += bf + finally: + transformer_options.pop("ar_state", None) + + return output diff --git a/comfy/ldm/wan/ar_model.py b/comfy/ldm/wan/ar_model.py new file mode 100644 index 000000000..d72f53602 --- /dev/null +++ b/comfy/ldm/wan/ar_model.py @@ -0,0 +1,276 @@ +""" +CausalWanModel: Wan 2.1 backbone with KV-cached causal self-attention for +autoregressive (frame-by-frame) video generation via Causal Forcing. + +Weight-compatible with the standard WanModel -- same layer names, same shapes. +The difference is purely in the forward pass: this model processes one temporal +block at a time and maintains a KV cache across blocks. + +Reference: https://github.com/thu-ml/Causal-Forcing +""" + +import torch +import torch.nn as nn + +from comfy.ldm.modules.attention import optimized_attention +from comfy.ldm.flux.math import apply_rope1 +from comfy.ldm.wan.model import ( + sinusoidal_embedding_1d, + repeat_e, + WanModel, + WanAttentionBlock, +) +import comfy.ldm.common_dit +import comfy.model_management + + +class CausalWanSelfAttention(nn.Module): + """Self-attention with KV cache support for autoregressive inference.""" + + def __init__(self, dim, num_heads, window_size=(-1, -1), qk_norm=True, + eps=1e-6, operation_settings={}): + assert dim % num_heads == 0 + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.qk_norm = qk_norm + self.eps = eps + + ops = operation_settings.get("operations") + device = operation_settings.get("device") + dtype = operation_settings.get("dtype") + + self.q = ops.Linear(dim, dim, device=device, dtype=dtype) + self.k = ops.Linear(dim, dim, device=device, dtype=dtype) + self.v = ops.Linear(dim, dim, device=device, dtype=dtype) + self.o = ops.Linear(dim, dim, device=device, dtype=dtype) + self.norm_q = ops.RMSNorm(dim, eps=eps, elementwise_affine=True, device=device, dtype=dtype) if qk_norm else nn.Identity() + self.norm_k = ops.RMSNorm(dim, eps=eps, elementwise_affine=True, device=device, dtype=dtype) if qk_norm else nn.Identity() + + def forward(self, x, freqs, kv_cache=None, transformer_options={}): + b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim + + q = apply_rope1(self.norm_q(self.q(x)).view(b, s, n, d), freqs) + k = apply_rope1(self.norm_k(self.k(x)).view(b, s, n, d), freqs) + v = self.v(x).view(b, s, n, d) + + if kv_cache is None: + x = optimized_attention( + q.view(b, s, n * d), + k.view(b, s, n * d), + v.view(b, s, n * d), + heads=self.num_heads, + transformer_options=transformer_options, + ) + else: + end = kv_cache["end"] + new_end = end + s + + # Roped K and plain V go into cache + kv_cache["k"][:, end:new_end] = k + kv_cache["v"][:, end:new_end] = v + kv_cache["end"] = new_end + + x = optimized_attention( + q.view(b, s, n * d), + kv_cache["k"][:, :new_end].view(b, new_end, n * d), + kv_cache["v"][:, :new_end].view(b, new_end, n * d), + heads=self.num_heads, + transformer_options=transformer_options, + ) + + x = self.o(x) + return x + + +class CausalWanAttentionBlock(WanAttentionBlock): + """Transformer block with KV-cached self-attention and cross-attention caching.""" + + def __init__(self, cross_attn_type, dim, ffn_dim, num_heads, + window_size=(-1, -1), qk_norm=True, cross_attn_norm=False, + eps=1e-6, operation_settings={}): + super().__init__(cross_attn_type, dim, ffn_dim, num_heads, + window_size, qk_norm, cross_attn_norm, eps, + operation_settings=operation_settings) + self.self_attn = CausalWanSelfAttention( + dim, num_heads, window_size, qk_norm, eps, + operation_settings=operation_settings) + + def forward(self, x, e, freqs, context, context_img_len=257, + kv_cache=None, crossattn_cache=None, transformer_options={}): + if e.ndim < 4: + e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device) + e).chunk(6, dim=1) + else: + e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device).unsqueeze(0) + e).unbind(2) + + # Self-attention with optional KV cache + x = x.contiguous() + y = self.self_attn( + torch.addcmul(repeat_e(e[0], x), self.norm1(x), 1 + repeat_e(e[1], x)), + freqs, kv_cache=kv_cache, transformer_options=transformer_options) + x = torch.addcmul(x, y, repeat_e(e[2], x)) + del y + + # Cross-attention with optional caching + if crossattn_cache is not None and crossattn_cache.get("is_init"): + q = self.cross_attn.norm_q(self.cross_attn.q(self.norm3(x))) + x_ca = optimized_attention( + q, crossattn_cache["k"], crossattn_cache["v"], + heads=self.num_heads, transformer_options=transformer_options) + x = x + self.cross_attn.o(x_ca) + else: + x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len, transformer_options=transformer_options) + if crossattn_cache is not None: + crossattn_cache["k"] = self.cross_attn.norm_k(self.cross_attn.k(context)) + crossattn_cache["v"] = self.cross_attn.v(context) + crossattn_cache["is_init"] = True + + # FFN + y = self.ffn(torch.addcmul(repeat_e(e[3], x), self.norm2(x), 1 + repeat_e(e[4], x))) + x = torch.addcmul(x, y, repeat_e(e[5], x)) + return x + + +class CausalWanModel(WanModel): + """ + Wan 2.1 diffusion backbone with causal KV-cache support. + + Same weight structure as WanModel -- loads identical state dicts. + Adds forward_block() for frame-by-frame autoregressive inference. + """ + + def __init__(self, + model_type='t2v', + patch_size=(1, 2, 2), + text_len=512, + in_dim=16, + dim=2048, + ffn_dim=8192, + freq_dim=256, + text_dim=4096, + out_dim=16, + num_heads=16, + num_layers=32, + window_size=(-1, -1), + qk_norm=True, + cross_attn_norm=True, + eps=1e-6, + image_model=None, + device=None, + dtype=None, + operations=None): + super().__init__( + model_type=model_type, patch_size=patch_size, text_len=text_len, + in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, + text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, + num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, + cross_attn_norm=cross_attn_norm, eps=eps, image_model=image_model, + wan_attn_block_class=CausalWanAttentionBlock, + device=device, dtype=dtype, operations=operations) + + def forward_block(self, x, timestep, context, start_frame, + kv_caches, crossattn_caches, clip_fea=None): + """ + Forward one temporal block for autoregressive inference. + + Args: + x: [B, C, block_frames, H, W] input latent for the current block + timestep: [B, block_frames] per-frame timesteps + context: [B, L, text_dim] raw text embeddings (pre-text_embedding) + start_frame: temporal frame index for RoPE offset + kv_caches: list of per-layer KV cache dicts + crossattn_caches: list of per-layer cross-attention cache dicts + clip_fea: optional CLIP features for I2V + + Returns: + flow_pred: [B, C_out, block_frames, H, W] flow prediction + """ + x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size) + bs, c, t, h, w = x.shape + + x = self.patch_embedding(x.float()).to(x.dtype) + grid_sizes = x.shape[2:] + x = x.flatten(2).transpose(1, 2) + + # Per-frame time embedding + e = self.time_embedding( + sinusoidal_embedding_1d(self.freq_dim, timestep.flatten()).to(dtype=x.dtype)) + e = e.reshape(timestep.shape[0], -1, e.shape[-1]) + e0 = self.time_projection(e).unflatten(2, (6, self.dim)) + + # Text embedding (reuses crossattn_cache after first block) + context = self.text_embedding(context) + + context_img_len = None + if clip_fea is not None and self.img_emb is not None: + context_clip = self.img_emb(clip_fea) + context = torch.concat([context_clip, context], dim=1) + context_img_len = clip_fea.shape[-2] + + # RoPE for current block's temporal position + freqs = self.rope_encode(t, h, w, t_start=start_frame, device=x.device, dtype=x.dtype) + + # Transformer blocks + for i, block in enumerate(self.blocks): + x = block(x, e=e0, freqs=freqs, context=context, + context_img_len=context_img_len, + kv_cache=kv_caches[i], + crossattn_cache=crossattn_caches[i]) + + # Head + x = self.head(x, e) + + # Unpatchify + x = self.unpatchify(x, grid_sizes) + return x[:, :, :t, :h, :w] + + def init_kv_caches(self, batch_size, max_seq_len, device, dtype): + """Create fresh KV caches for all layers.""" + caches = [] + for _ in range(self.num_layers): + caches.append({ + "k": torch.zeros(batch_size, max_seq_len, self.num_heads, self.head_dim, device=device, dtype=dtype), + "v": torch.zeros(batch_size, max_seq_len, self.num_heads, self.head_dim, device=device, dtype=dtype), + "end": 0, + }) + return caches + + def init_crossattn_caches(self, batch_size, device, dtype): + """Create fresh cross-attention caches for all layers.""" + caches = [] + for _ in range(self.num_layers): + caches.append({"is_init": False}) + return caches + + def reset_kv_caches(self, kv_caches): + """Reset KV caches to empty (reuse allocated memory).""" + for cache in kv_caches: + cache["end"] = 0 + + def reset_crossattn_caches(self, crossattn_caches): + """Reset cross-attention caches.""" + for cache in crossattn_caches: + cache["is_init"] = False + + @property + def head_dim(self): + return self.dim // self.num_heads + + def forward(self, x, timestep, context, clip_fea=None, time_dim_concat=None, transformer_options={}, **kwargs): + ar_state = transformer_options.get("ar_state") + if ar_state is not None: + bs = x.shape[0] + block_frames = x.shape[2] + t_per_frame = timestep.unsqueeze(1).expand(bs, block_frames) + return self.forward_block( + x=x, timestep=t_per_frame, context=context, + start_frame=ar_state["start_frame"], + kv_caches=ar_state["kv_caches"], + crossattn_caches=ar_state["crossattn_caches"], + clip_fea=clip_fea, + ) + + return super().forward(x, timestep, context, clip_fea=clip_fea, + time_dim_concat=time_dim_concat, + transformer_options=transformer_options, **kwargs) diff --git a/comfy/model_base.py b/comfy/model_base.py index 94579fa3e..db5836022 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -42,6 +42,7 @@ import comfy.ldm.cosmos.predict2 import comfy.ldm.lumina.model import comfy.ldm.wan.model import comfy.ldm.wan.model_animate +import comfy.ldm.wan.ar_model import comfy.ldm.hunyuan3d.model import comfy.ldm.hidream.model import comfy.ldm.chroma.model @@ -1353,6 +1354,13 @@ class WAN21(BaseModel): return out +class WAN21_CausalAR(WAN21): + def __init__(self, model_config, model_type=ModelType.FLOW, device=None): + super(WAN21, self).__init__(model_config, model_type, device=device, + unet_model=comfy.ldm.wan.ar_model.CausalWanModel) + self.image_to_video = False + + class WAN21_Vace(WAN21): def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None): super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.VaceWanModel) diff --git a/comfy/samplers.py b/comfy/samplers.py index 0a4d062db..6ee50181c 100755 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -719,11 +719,15 @@ class Sampler: sigma = float(sigmas[0]) return math.isclose(max_sigma, sigma, rel_tol=1e-05) or sigma > max_sigma +# "ar_video" is model-specific (requires Causal-WAN KV-cache interface + 5-D latents) +# but is kept here so it appears in standard sampler dropdowns; sample_ar_video +# validates at runtime and raises a clear error for incompatible checkpoints. KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_cfg_pp", "heun", "heunpp2", "exp_heun_2_x0", "exp_heun_2_x0_sde", "dpm_2", "dpm_2_ancestral", "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_2s_ancestral_cfg_pp", "dpmpp_sde", "dpmpp_sde_gpu", "dpmpp_2m", "dpmpp_2m_cfg_pp", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_2m_sde_heun", "dpmpp_2m_sde_heun_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm", "ipndm", "ipndm_v", "deis", "res_multistep", "res_multistep_cfg_pp", "res_multistep_ancestral", "res_multistep_ancestral_cfg_pp", - "gradient_estimation", "gradient_estimation_cfg_pp", "er_sde", "seeds_2", "seeds_3", "sa_solver", "sa_solver_pece"] + "gradient_estimation", "gradient_estimation_cfg_pp", "er_sde", "seeds_2", "seeds_3", "sa_solver", "sa_solver_pece", + "ar_video"] class KSAMPLER(Sampler): def __init__(self, sampler_function, extra_options={}, inpaint_options={}): diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 07feb31b3..aa66e035f 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -1165,6 +1165,25 @@ class WAN21_T2V(supported_models_base.BASE): t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}umt5xxl.transformer.".format(pref)) return supported_models_base.ClipTarget(comfy.text_encoders.wan.WanT5Tokenizer, comfy.text_encoders.wan.te(**t5_detect)) +class WAN21_CausalAR_T2V(WAN21_T2V): + unet_config = { + "image_model": "wan2.1", + "model_type": "t2v", + "causal_ar": True, + } + + sampling_settings = { + "shift": 5.0, + } + + def __init__(self, unet_config): + super().__init__(unet_config) + self.unet_config.pop("causal_ar", None) + + def get_model(self, state_dict, prefix="", device=None): + return model_base.WAN21_CausalAR(self, device=device) + + class WAN21_I2V(WAN21_T2V): unet_config = { "image_model": "wan2.1", @@ -1734,6 +1753,6 @@ class LongCatImage(supported_models_base.BASE): hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref)) return supported_models_base.ClipTarget(comfy.text_encoders.longcat_image.LongCatImageTokenizer, comfy.text_encoders.longcat_image.te(**hunyuan_detect)) -models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, LongCatImage, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImagePixelSpace, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, WAN21_SCAIL, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima] +models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, LongCatImage, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImagePixelSpace, ZImage, Lumina2, WAN22_T2V, WAN21_CausalAR_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, WAN21_SCAIL, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima] models += [SVD_img2vid] diff --git a/comfy_extras/nodes_ar_video.py b/comfy_extras/nodes_ar_video.py new file mode 100644 index 000000000..be9f2eaec --- /dev/null +++ b/comfy_extras/nodes_ar_video.py @@ -0,0 +1,49 @@ +""" +ComfyUI nodes for autoregressive video generation (Causal Forcing, Self-Forcing, etc.). + - EmptyARVideoLatent: create 5D [B, C, T, H, W] video latent tensors +""" + +import torch +from typing_extensions import override + +import comfy.model_management +from comfy_api.latest import ComfyExtension, io + + +class EmptyARVideoLatent(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="EmptyARVideoLatent", + category="latent/video", + inputs=[ + io.Int.Input("width", default=832, min=16, max=8192, step=16), + io.Int.Input("height", default=480, min=16, max=8192, step=16), + io.Int.Input("length", default=81, min=1, max=1024, step=4), + io.Int.Input("batch_size", default=1, min=1, max=64), + ], + outputs=[ + io.Latent.Output(display_name="LATENT"), + ], + ) + + @classmethod + def execute(cls, width, height, length, batch_size) -> io.NodeOutput: + lat_t = ((length - 1) // 4) + 1 + latent = torch.zeros( + [batch_size, 16, lat_t, height // 8, width // 8], + device=comfy.model_management.intermediate_device(), + ) + return io.NodeOutput({"samples": latent}) + + +class ARVideoExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + EmptyARVideoLatent, + ] + + +async def comfy_entrypoint() -> ARVideoExtension: + return ARVideoExtension() diff --git a/nodes.py b/nodes.py index 37ceac2fc..4d674617f 100644 --- a/nodes.py +++ b/nodes.py @@ -2443,6 +2443,7 @@ async def init_builtin_extra_nodes(): "nodes_nop.py", "nodes_kandinsky5.py", "nodes_wanmove.py", + "nodes_ar_video.py", "nodes_image_compare.py", "nodes_zimage.py", "nodes_glsl.py",