diff --git a/comfy/ldm/wan/model.py b/comfy/ldm/wan/model.py index ea123acb4..d1ca9926e 100644 --- a/comfy/ldm/wan/model.py +++ b/comfy/ldm/wan/model.py @@ -1621,3 +1621,249 @@ class HumoWanModel(WanModel): # unpatchify x = self.unpatchify(x, grid_sizes) return x + +class SCAILWanModel(WanModel): + r""" + Wan diffusion backbone supporting both text-to-video and image-to-video. + """ + + def __init__(self, + model_type='scail', + patch_size=(1, 2, 2), + text_len=512, + in_dim=16, + dim=5120, + ffn_dim=8192, + freq_dim=256, + text_dim=4096, + out_dim=16, + num_heads=16, + num_layers=32, + window_size=(-1, -1), + qk_norm=True, + cross_attn_norm=True, + eps=1e-6, + flf_pos_embed_token_number=None, + in_dim_ref_conv=None, + wan_attn_block_class=WanAttentionBlock, + image_model=None, + device=None, + dtype=None, + operations=None, + ): + + super().__init__(model_type='i2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, wan_attn_block_class=WanAttentionBlockAudio, image_model=image_model, device=device, dtype=dtype, operations=operations) + + self.dtype = dtype + operation_settings = {"operations": operations, "device": device, "dtype": dtype} + + self.model_type = model_type + + self.patch_size = patch_size + self.text_len = text_len + self.in_dim = in_dim + self.dim = dim + self.ffn_dim = ffn_dim + self.freq_dim = freq_dim + self.text_dim = text_dim + self.out_dim = out_dim + self.num_heads = num_heads + self.num_layers = num_layers + self.window_size = window_size + self.qk_norm = qk_norm + self.cross_attn_norm = cross_attn_norm + self.eps = eps + + # embeddings + self.patch_embedding = operations.Conv3d(in_dim, dim, kernel_size=patch_size, stride=patch_size, device=operation_settings.get("device"), dtype=torch.float32) + self.patch_embedding_pose = operations.Conv3d(in_dim, dim, kernel_size=patch_size, stride=patch_size, device=operation_settings.get("device"), dtype=torch.float32) + + self.text_embedding = nn.Sequential( + operations.Linear(text_dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")), nn.GELU(approximate='tanh'), + operations.Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))) + + self.time_embedding = nn.Sequential( + operations.Linear(freq_dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")), nn.SiLU(), operations.Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))) + self.time_projection = nn.Sequential(nn.SiLU(), operations.Linear(dim, dim * 6, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))) + + # blocks + cross_attn_type = 't2v_cross_attn' if model_type == 't2v' else 'i2v_cross_attn' + self.blocks = nn.ModuleList([ + wan_attn_block_class(cross_attn_type, dim, ffn_dim, num_heads, + window_size, qk_norm, cross_attn_norm, eps, operation_settings=operation_settings) + for i in range(num_layers) + ]) + + # head + self.head = Head(dim, out_dim, patch_size, eps, operation_settings=operation_settings) + + d = dim // num_heads + self.rope_embedder = EmbedND(dim=d, theta=10000.0, axes_dim=[d - 4 * (d // 6), 2 * (d // 6), 2 * (d // 6)]) + + self.img_emb = MLPProj(1280, dim, flf_pos_embed_token_number=flf_pos_embed_token_number, operation_settings=operation_settings) + + def forward_orig( + self, + x, + t, + context, + clip_fea=None, + freqs=None, + transformer_options={}, + pose_latents=None, + **kwargs, + ): + + reference_latent = kwargs.get("reference_latent", None) + if reference_latent is not None: + x = torch.cat((reference_latent, x), dim=2) + + # embeddings + x = self.patch_embedding(x.float()).to(x.dtype) + grid_sizes = x.shape[2:] + transformer_options["grid_sizes"] = grid_sizes + x = x.flatten(2).transpose(1, 2) + + scail_pose_seq_len = 0 + if pose_latents is not None: + scail_x = self.patch_embedding_pose(pose_latents.float()).to(x.dtype) + scail_x = scail_x.flatten(2).transpose(1, 2) + scail_pose_seq_len = scail_x.shape[1] + x = torch.cat([x, scail_x], dim=1) + del scail_x + + # time embeddings + e = self.time_embedding(sinusoidal_embedding_1d(self.freq_dim, t.flatten()).to(dtype=x[0].dtype)) + e = e.reshape(t.shape[0], -1, e.shape[-1]) + e0 = self.time_projection(e).unflatten(2, (6, self.dim)) + + # context + context = self.text_embedding(context) + + context_img_len = None + if clip_fea is not None: + if self.img_emb is not None: + context_clip = self.img_emb(clip_fea) # bs x 257 x dim + context = torch.cat([context_clip, context], dim=1) + context_img_len = clip_fea.shape[-2] + + patches_replace = transformer_options.get("patches_replace", {}) + blocks_replace = patches_replace.get("dit", {}) + transformer_options["total_blocks"] = len(self.blocks) + transformer_options["block_type"] = "double" + for i, block in enumerate(self.blocks): + transformer_options["block_index"] = i + if ("double_block", i) in blocks_replace: + def block_wrap(args): + out = {} + out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len, transformer_options=args["transformer_options"]) + return out + out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap}) + x = out["img"] + else: + x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len, transformer_options=transformer_options) + + # head + x = self.head(x, e) + + if scail_pose_seq_len > 0: + x = x[:, :-scail_pose_seq_len] + + # unpatchify + x = self.unpatchify(x, grid_sizes) + + if reference_latent is not None: + x = x[:, :, reference_latent.shape[2]:] + + return x + + def rope_encode(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None, pose_latents=None, transformer_options={}): + patch_size = self.patch_size + t_len = ((t + (patch_size[0] // 2)) // patch_size[0]) + h_len = ((h + (patch_size[1] // 2)) // patch_size[1]) + w_len = ((w + (patch_size[2] // 2)) // patch_size[2]) + + if steps_t is None: + steps_t = t_len + if steps_h is None: + steps_h = h_len + if steps_w is None: + steps_w = w_len + + h_start = 0 + w_start = 0 + rope_options = transformer_options.get("rope_options", None) + if rope_options is not None: + t_len = (t_len - 1.0) * rope_options.get("scale_t", 1.0) + 1.0 + h_len = (h_len - 1.0) * rope_options.get("scale_y", 1.0) + 1.0 + w_len = (w_len - 1.0) * rope_options.get("scale_x", 1.0) + 1.0 + + t_start += rope_options.get("shift_t", 0.0) + h_start += rope_options.get("shift_y", 0.0) + w_start += rope_options.get("shift_x", 0.0) + + img_ids = torch.zeros((steps_t, steps_h, steps_w, 3), device=device, dtype=dtype) + img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(t_start, t_start + (t_len - 1), steps=steps_t, device=device, dtype=dtype).reshape(-1, 1, 1) + img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(h_start, h_start + (h_len - 1), steps=steps_h, device=device, dtype=dtype).reshape(1, -1, 1) + img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(w_start, w_start + (w_len - 1), steps=steps_w, device=device, dtype=dtype).reshape(1, 1, -1) + img_ids = img_ids.reshape(1, -1, img_ids.shape[-1]) + + segments = [img_ids] # Start with main frames + + # Pose frames position IDs + if pose_latents is not None: + pose_frame_shape = pose_latents.shape + F_pose, H_pose, W_pose = pose_frame_shape[-3], pose_frame_shape[-2], pose_frame_shape[-1] + + downscale = H_pose != h + pose_f_len_full = ((F_pose + (self.patch_size[0] // 2)) // self.patch_size[0]) + pose_h_len_full = (((H_pose * (2 if downscale else 1)) + (self.patch_size[1] // 2)) // self.patch_size[1]) # 2x height + pose_w_len_full = (((W_pose * (2 if downscale else 1)) + (self.patch_size[2] // 2)) // self.patch_size[2]) # 2x width + + pose_img_ids = torch.zeros((pose_f_len_full, pose_h_len_full, pose_w_len_full, 3), device=device, dtype=dtype) + global_h_offset, global_w_offset = 0, 120 # global spatial offset to separate pose from main frames spatially (SCAIL uses 120 as offset) + pose_img_ids[:, :, :, 0] = pose_img_ids[:, :, :, 0] + torch.linspace(t_start, t_start + (pose_f_len_full - 1), steps=pose_f_len_full, device=device, dtype=dtype).reshape(-1, 1, 1) + pose_img_ids[:, :, :, 1] = pose_img_ids[:, :, :, 1] + torch.linspace(global_h_offset, global_h_offset + pose_h_len_full - 1, steps=pose_h_len_full, device=device, dtype=dtype).reshape(1, -1, 1) + pose_img_ids[:, :, :, 2] = pose_img_ids[:, :, :, 2] + torch.linspace(global_w_offset, global_w_offset + pose_w_len_full - 1, steps=pose_w_len_full, device=device, dtype=dtype).reshape(1, 1, -1) + + segments.append(pose_img_ids.reshape(1, -1, pose_img_ids.shape[-1])) + + combined_img_ids = torch.cat(segments, dim=1) + + freqs = self.rope_embedder(combined_img_ids).movedim(1, 2) + + # Downsample pose frequencies to match actual pose input resolution + if pose_latents is not None and downscale: + pose_h_len_actual = ((H_pose + (self.patch_size[1] // 2)) // self.patch_size[1]) + pose_w_len_actual = ((W_pose + (self.patch_size[2] // 2)) // self.patch_size[2]) + + pose_start_idx = freqs.shape[1] - pose_f_len_full * pose_h_len_full * pose_w_len_full + main_freqs, pose_freqs = freqs[:, :pose_start_idx], freqs[:, pose_start_idx:] + + B, _, heads, dim, _, _ = pose_freqs.shape + # Reshape and pool: (B, L, heads, dim, 2, 2) -> pool H,W -> (B, L', heads, dim, 2, 2) + pose_freqs = pose_freqs.reshape(B, pose_f_len_full, pose_h_len_full, pose_w_len_full, heads, dim, 2, 2) + pose_freqs = pose_freqs.permute(0, 1, 4, 5, 6, 7, 2, 3).reshape(-1, pose_h_len_full, pose_w_len_full) + pose_freqs = torch.nn.functional.avg_pool2d(pose_freqs, kernel_size=2, stride=2) + pose_freqs = pose_freqs.reshape(B, pose_f_len_full, heads, dim, 2, 2, pose_h_len_actual, pose_w_len_actual) + pose_freqs = pose_freqs.permute(0, 1, 6, 7, 2, 3, 4, 5).reshape(B, -1, heads, dim, 2, 2) + + freqs = torch.cat([main_freqs, pose_freqs], dim=1) + + return freqs + + def _forward(self, x, timestep, context, clip_fea=None, time_dim_concat=None, transformer_options={}, pose_latents=None, **kwargs): + bs, c, t, h, w = x.shape + x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size) + + t_len = t + if time_dim_concat is not None: + time_dim_concat = comfy.ldm.common_dit.pad_to_patch_size(time_dim_concat, self.patch_size) + x = torch.cat([x, time_dim_concat], dim=2) + t_len = x.shape[2] + + if "reference_latent" in kwargs: + t_len += 1 + + freqs = self.rope_encode(t_len, h, w, device=x.device, dtype=x.dtype, transformer_options=transformer_options, pose_latents=pose_latents) + return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, transformer_options=transformer_options, pose_latents=pose_latents, **kwargs)[:, :, :t, :h, :w] diff --git a/comfy/model_base.py b/comfy/model_base.py index 2f49578f6..5b844340e 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1466,6 +1466,55 @@ class WAN22(WAN21): def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs): return latent_image +class WAN21_SCAIL(WAN21): + def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None): + super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.SCAILWanModel) + self.memory_usage_factor_conds = ("reference_latent", "pose_latents") + self.memory_usage_shape_process = {"pose_latents": lambda shape: [shape[0], shape[1], 1.5, shape[-2], shape[-1]]} + self.image_to_video = image_to_video + + def extra_conds(self, **kwargs): + out = super().extra_conds(**kwargs) + + reference_latents = kwargs.get("reference_latents", None) + if reference_latents is not None: + ref_latent = self.process_latent_in(reference_latents[-1]) + ref_mask = torch.ones_like(ref_latent[:, :4]) + ref_latent = torch.cat([ref_latent, ref_mask], dim=1) + out['reference_latent'] = comfy.conds.CONDRegular(ref_latent) + + pose_latents = kwargs.get("pose_video_latent", None) + if pose_latents is not None: + pose_latents = self.process_latent_in(pose_latents) + pose_mask = torch.ones_like(pose_latents[:, :4]) + pose_latents = torch.cat([pose_latents, pose_mask], dim=1) + out['pose_latents'] = comfy.conds.CONDRegular(pose_latents) + + pose_start = kwargs.get("pose_start", 0.0) + pose_end = kwargs.get("pose_end", 1.0) + out['pose_start'] = comfy.conds.CONDConstant(pose_start) + out['pose_end'] = comfy.conds.CONDConstant(pose_end) + + return out + + def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, pose_start=0.0, pose_end=1.0, **kwargs): + if t >= self.model_sampling.percent_to_sigma(pose_start) or t <= self.model_sampling.percent_to_sigma(pose_end): + kwargs.pop("pose_latents", None) + + return comfy.patcher_extension.WrapperExecutor.new_class_executor( + self._apply_model, + self, + comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.APPLY_MODEL, transformer_options) + ).execute(x, t, c_concat, c_crossattn, control, transformer_options, **kwargs) + + def extra_conds_shapes(self, **kwargs): + out = {} + ref_latents = kwargs.get("reference_latents", None) + if ref_latents is not None: + out['reference_latent'] = list([1, 16, sum(map(lambda a: math.prod(a.size()), ref_latents)) // 16]) + + return out + class Hunyuan3Dv2(BaseModel): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan3d.model.Hunyuan3Dv2) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 30ea03e8e..306c70703 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -496,6 +496,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["model_type"] = "humo" elif '{}face_adapter.fuser_blocks.0.k_norm.weight'.format(key_prefix) in state_dict_keys: dit_config["model_type"] = "animate" + elif '{}patch_embedding_pose.weight'.format(key_prefix) in state_dict_keys: + dit_config["model_type"] = "scail" else: if '{}img_emb.proj.0.bias'.format(key_prefix) in state_dict_keys: dit_config["model_type"] = "i2v" diff --git a/comfy/supported_models.py b/comfy/supported_models.py index c28be1716..240b6dbe5 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -1256,6 +1256,16 @@ class WAN22_T2V(WAN21_T2V): out = model_base.WAN22(self, image_to_video=True, device=device) return out +class WAN21_SCAIL(WAN21_T2V): + unet_config = { + "image_model": "wan2.1", + "model_type": "scail", + } + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.WAN21_SCAIL(self, image_to_video=False, device=device) + return out + class Hunyuan3Dv2(supported_models_base.BASE): unet_config = { "image_model": "hunyuan3d2", @@ -1667,6 +1677,6 @@ class ACEStep15(supported_models_base.BASE): return supported_models_base.ClipTarget(comfy.text_encoders.ace15.ACE15Tokenizer, comfy.text_encoders.ace15.te(**detect)) -models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima] +models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_SCAIL, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima] models += [SVD_img2vid] diff --git a/comfy_extras/nodes_wan.py b/comfy_extras/nodes_wan.py index effa994d1..eff21ecce 100644 --- a/comfy_extras/nodes_wan.py +++ b/comfy_extras/nodes_wan.py @@ -1456,6 +1456,66 @@ class WanInfiniteTalkToVideo(io.ComfyNode): return io.NodeOutput(model_patched, positive, negative, out_latent, trim_image) +class WanSCAILToVideo(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="WanSCAILToVideo", + category="conditioning/video_models", + inputs=[ + io.Conditioning.Input("positive"), + io.Conditioning.Input("negative"), + io.Vae.Input("vae"), + io.Int.Input("width", default=512, min=32, max=nodes.MAX_RESOLUTION, step=32), + io.Int.Input("height", default=896, min=32, max=nodes.MAX_RESOLUTION, step=32), + io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4), + io.Int.Input("batch_size", default=1, min=1, max=4096), + io.ClipVisionOutput.Input("clip_vision_output", optional=True), + io.Image.Input("reference_image", optional=True), + io.Image.Input("pose_video", optional=True, tooltip="Video used for pose conditioning. Will be downscaled to half the resolution of the main video."), + io.Float.Input("pose_strength", default=1.0, min=0.0, max=10.0, step=0.01, tooltip="Strength of the pose latent."), + io.Float.Input("pose_start", default=0.0, min=0.0, max=10.0, step=0.01, tooltip="Start step to use pose conditioning."), + io.Float.Input("pose_end", default=1.0, min=0.0, max=10.0, step=0.01, tooltip="End step to use pose conditioning."), + ], + outputs=[ + io.Conditioning.Output(display_name="positive"), + io.Conditioning.Output(display_name="negative"), + io.Latent.Output(display_name="latent", tooltip="Empty latent of the generation size."), + ], + is_experimental=True, + ) + + @classmethod + def execute(cls, positive, negative, vae, width, height, length, batch_size, pose_strength, pose_start, pose_end, reference_image=None, clip_vision_output=None, pose_video=None) -> io.NodeOutput: + latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) + + if reference_image is None: + reference_image = torch.zeros((1, height, width, 3)) + + ref_latent = None + if reference_image is not None: + reference_image = comfy.utils.common_upscale(reference_image[:1].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) + ref_latent = vae.encode(reference_image[:, :, :, :3]) + + if ref_latent is not None: + positive = node_helpers.conditioning_set_values(positive, {"reference_latents": [ref_latent]}, append=True) + negative = node_helpers.conditioning_set_values(negative, {"reference_latents": [torch.zeros_like(ref_latent)]}, append=True) + + if clip_vision_output is not None: + positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output}) + negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output}) + + if pose_video is not None: + pose_video = comfy.utils.common_upscale(pose_video[:length].movedim(-1, 1), width // 2, height // 2, "area", "center").movedim(1, -1) + pose_video_latent = vae.encode(pose_video[:, :, :, :3]) * pose_strength + positive = node_helpers.conditioning_set_values(positive, {"pose_video_latent": pose_video_latent, "pose_start": pose_start, "pose_end": pose_end}) + negative = node_helpers.conditioning_set_values(negative, {"pose_video_latent": pose_video_latent, "pose_start": pose_start, "pose_end": pose_end}) + + out_latent = {} + out_latent["samples"] = latent + return io.NodeOutput(positive, negative, out_latent) + + class WanExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[io.ComfyNode]]: @@ -1476,6 +1536,7 @@ class WanExtension(ComfyExtension): WanAnimateToVideo, Wan22ImageToVideoLatent, WanInfiniteTalkToVideo, + WanSCAILToVideo, ] async def comfy_entrypoint() -> WanExtension: