From 65e2103b09f66e45438445fb0e99709ae7639869 Mon Sep 17 00:00:00 2001
From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com>
Date: Tue, 16 Dec 2025 23:51:48 +0200
Subject: [PATCH] feat(api-nodes): add Wan2.6 model to video nodes (#11357)

---
 comfy_api_nodes/nodes_wan.py | 162 ++++++++++++++++++++---------------
 1 file changed, 95 insertions(+), 67 deletions(-)

diff --git a/comfy_api_nodes/nodes_wan.py b/comfy_api_nodes/nodes_wan.py
index 2aab3c2ff..17b680e13 100644
--- a/comfy_api_nodes/nodes_wan.py
+++ b/comfy_api_nodes/nodes_wan.py
@@ -1,7 +1,5 @@
 import re
-from typing import Optional
 
-import torch
 from pydantic import BaseModel, Field
 from typing_extensions import override
 
@@ -21,26 +19,26 @@ from comfy_api_nodes.util import (
 
 class Text2ImageInputField(BaseModel):
     prompt: str = Field(...)
-    negative_prompt: Optional[str] = Field(None)
+    negative_prompt: str | None = Field(None)
 
 
 class Image2ImageInputField(BaseModel):
     prompt: str = Field(...)
-    negative_prompt: Optional[str] = Field(None)
+    negative_prompt: str | None = Field(None)
     images: list[str] = Field(..., min_length=1, max_length=2)
 
 
 class Text2VideoInputField(BaseModel):
     prompt: str = Field(...)
-    negative_prompt: Optional[str] = Field(None)
-    audio_url: Optional[str] = Field(None)
+    negative_prompt: str | None = Field(None)
+    audio_url: str | None = Field(None)
 
 
 class Image2VideoInputField(BaseModel):
     prompt: str = Field(...)
-    negative_prompt: Optional[str] = Field(None)
+    negative_prompt: str | None = Field(None)
     img_url: str = Field(...)
-    audio_url: Optional[str] = Field(None)
+    audio_url: str | None = Field(None)
 
 
 class Txt2ImageParametersField(BaseModel):
@@ -52,7 +50,7 @@ class Txt2ImageParametersField(BaseModel):
 
 
 class Image2ImageParametersField(BaseModel):
-    size: Optional[str] = Field(None)
+    size: str | None = Field(None)
     n: int = Field(1, description="Number of images to generate.")  # we support only value=1
     seed: int = Field(..., ge=0, le=2147483647)
     watermark: bool = Field(True)
@@ -61,19 +59,21 @@ class Image2ImageParametersField(BaseModel):
 class Text2VideoParametersField(BaseModel):
     size: str = Field(...)
     seed: int = Field(..., ge=0, le=2147483647)
-    duration: int = Field(5, ge=5, le=10)
+    duration: int = Field(5, ge=5, le=15)
     prompt_extend: bool = Field(True)
     watermark: bool = Field(True)
-    audio: bool = Field(False, description="Should be audio generated automatically")
+    audio: bool = Field(False, description="Whether to generate audio automatically.")
+    shot_type: str = Field("single")
 
 
 class Image2VideoParametersField(BaseModel):
     resolution: str = Field(...)
     seed: int = Field(..., ge=0, le=2147483647)
-    duration: int = Field(5, ge=5, le=10)
+    duration: int = Field(5, ge=5, le=15)
     prompt_extend: bool = Field(True)
     watermark: bool = Field(True)
-    audio: bool = Field(False, description="Should be audio generated automatically")
+    audio: bool = Field(False, description="Whether to generate audio automatically.")
+    shot_type: str = Field("single")
 
 
 class Text2ImageTaskCreationRequest(BaseModel):
@@ -106,39 +106,39 @@ class TaskCreationOutputField(BaseModel):
 
 
 class TaskCreationResponse(BaseModel):
-    output: Optional[TaskCreationOutputField] = Field(None)
+    output: TaskCreationOutputField | None = Field(None)
     request_id: str = Field(...)
-    code: Optional[str] = Field(None, description="The error code of the failed request.")
-    message: Optional[str] = Field(None, description="Details of the failed request.")
+    code: str | None = Field(None, description="Error code for the failed request.")
+    message: str | None = Field(None, description="Details about the failed request.")
 
 
 class TaskResult(BaseModel):
-    url: Optional[str] = Field(None)
-    code: Optional[str] = Field(None)
-    message: Optional[str] = Field(None)
+    url: str | None = Field(None)
+    code: str | None = Field(None)
+    message: str | None = Field(None)
 
 
 class ImageTaskStatusOutputField(TaskCreationOutputField):
     task_id: str = Field(...)
     task_status: str = Field(...)
-    results: Optional[list[TaskResult]] = Field(None)
+    results: list[TaskResult] | None = Field(None)
 
 
 class VideoTaskStatusOutputField(TaskCreationOutputField):
     task_id: str = Field(...)
     task_status: str = Field(...)
-    video_url: Optional[str] = Field(None)
-    code: Optional[str] = Field(None)
-    message: Optional[str] = Field(None)
+    video_url: str | None = Field(None)
+    code: str | None = Field(None)
+    message: str | None = Field(None)
 
 
 class ImageTaskStatusResponse(BaseModel):
-    output: Optional[ImageTaskStatusOutputField] = Field(None)
+    output: ImageTaskStatusOutputField | None = Field(None)
     request_id: str = Field(...)
 
 
 class VideoTaskStatusResponse(BaseModel):
-    output: Optional[VideoTaskStatusOutputField] = Field(None)
+    output: VideoTaskStatusOutputField | None = Field(None)
     request_id: str = Field(...)
 
 
@@ -152,7 +152,7 @@ class WanTextToImageApi(IO.ComfyNode):
             node_id="WanTextToImageApi",
             display_name="Wan Text to Image",
             category="api node/image/Wan",
-            description="Generates image based on text prompt.",
+            description="Generates an image based on a text prompt.",
             inputs=[
                 IO.Combo.Input(
                     "model",
@@ -164,13 +164,13 @@ class WanTextToImageApi(IO.ComfyNode):
                     "prompt",
                     multiline=True,
                     default="",
-                    tooltip="Prompt used to describe the elements and visual features, supports English/Chinese.",
+                    tooltip="Prompt describing the elements and visual features. Supports English and Chinese.",
                 ),
                 IO.String.Input(
                     "negative_prompt",
                     multiline=True,
                     default="",
-                    tooltip="Negative text prompt to guide what to avoid.",
+                    tooltip="Negative prompt describing what to avoid.",
                     optional=True,
                 ),
                 IO.Int.Input(
@@ -209,7 +209,7 @@ class WanTextToImageApi(IO.ComfyNode):
                 IO.Boolean.Input(
                     "watermark",
                     default=True,
-                    tooltip='Whether to add an "AI generated" watermark to the result.',
+                    tooltip="Whether to add an AI-generated watermark to the result.",
                     optional=True,
                 ),
             ],
@@ -252,7 +252,7 @@ class WanTextToImageApi(IO.ComfyNode):
             ),
         )
         if not initial_response.output:
-            raise Exception(f"Unknown error occurred: {initial_response.code} - {initial_response.message}")
+            raise Exception(f"An unknown error occurred: {initial_response.code} - {initial_response.message}")
         response = await poll_op(
             cls,
             ApiEndpoint(path=f"/proxy/wan/api/v1/tasks/{initial_response.output.task_id}"),
@@ -272,7 +272,7 @@ class WanImageToImageApi(IO.ComfyNode):
             display_name="Wan Image to Image",
             category="api node/image/Wan",
             description="Generates an image from one or two input images and a text prompt. "
-            "The output image is currently fixed at 1.6 MP; its aspect ratio matches the input image(s).",
+            "The output image is currently fixed at 1.6 MP, and its aspect ratio matches the input image(s).",
             inputs=[
                 IO.Combo.Input(
                     "model",
@@ -282,19 +282,19 @@ class WanImageToImageApi(IO.ComfyNode):
                 ),
                 IO.Image.Input(
                     "image",
-                    tooltip="Single-image editing or multi-image fusion, maximum 2 images.",
+                    tooltip="Single-image editing or multi-image fusion. Maximum 2 images.",
                 ),
                 IO.String.Input(
                     "prompt",
                     multiline=True,
                     default="",
-                    tooltip="Prompt used to describe the elements and visual features, supports English/Chinese.",
+                    tooltip="Prompt describing the elements and visual features. Supports English and Chinese.",
                 ),
                 IO.String.Input(
                     "negative_prompt",
                     multiline=True,
                     default="",
-                    tooltip="Negative text prompt to guide what to avoid.",
+                    tooltip="Negative prompt describing what to avoid.",
                     optional=True,
                 ),
                 # redo this later as an optional combo of recommended resolutions
@@ -328,7 +328,7 @@ class WanImageToImageApi(IO.ComfyNode):
                 IO.Boolean.Input(
                     "watermark",
                     default=True,
-                    tooltip='Whether to add an "AI generated" watermark to the result.',
+                    tooltip="Whether to add an AI-generated watermark to the result.",
                     optional=True,
                 ),
             ],
@@ -347,7 +347,7 @@ class WanImageToImageApi(IO.ComfyNode):
     async def execute(
         cls,
         model: str,
-        image: torch.Tensor,
+        image: Input.Image,
         prompt: str,
         negative_prompt: str = "",
         # width: int = 1024,
@@ -357,7 +357,7 @@ class WanImageToImageApi(IO.ComfyNode):
     ):
         n_images = get_number_of_images(image)
        if n_images not in (1, 2):
-            raise ValueError(f"Expected 1 or 2 input images, got {n_images}.")
+            raise ValueError(f"Expected 1 or 2 input images, but got {n_images}.")
         images = []
         for i in image:
             images.append("data:image/png;base64," + tensor_to_base64_string(i, total_pixels=4096 * 4096))
@@ -376,7 +376,7 @@ class WanImageToImageApi(IO.ComfyNode):
             ),
         )
         if not initial_response.output:
-            raise Exception(f"Unknown error occurred: {initial_response.code} - {initial_response.message}")
+            raise Exception(f"An unknown error occurred: {initial_response.code} - {initial_response.message}")
         response = await poll_op(
             cls,
             ApiEndpoint(path=f"/proxy/wan/api/v1/tasks/{initial_response.output.task_id}"),
@@ -395,25 +395,25 @@ class WanTextToVideoApi(IO.ComfyNode):
             node_id="WanTextToVideoApi",
             display_name="Wan Text to Video",
             category="api node/video/Wan",
-            description="Generates video based on text prompt.",
+            description="Generates a video based on a text prompt.",
             inputs=[
                 IO.Combo.Input(
                     "model",
-                    options=["wan2.5-t2v-preview"],
-                    default="wan2.5-t2v-preview",
+                    options=["wan2.5-t2v-preview", "wan2.6-t2v"],
+                    default="wan2.6-t2v",
                     tooltip="Model to use.",
                 ),
                 IO.String.Input(
                     "prompt",
                     multiline=True,
                     default="",
-                    tooltip="Prompt used to describe the elements and visual features, supports English/Chinese.",
+                    tooltip="Prompt describing the elements and visual features. Supports English and Chinese.",
                 ),
                 IO.String.Input(
                     "negative_prompt",
                     multiline=True,
                     default="",
-                    tooltip="Negative text prompt to guide what to avoid.",
+                    tooltip="Negative prompt describing what to avoid.",
                     optional=True,
                 ),
                 IO.Combo.Input(
@@ -433,23 +433,23 @@ class WanTextToVideoApi(IO.ComfyNode):
                         "1080p: 4:3 (1632x1248)",
                         "1080p: 3:4 (1248x1632)",
                     ],
-                    default="480p: 1:1 (624x624)",
+                    default="720p: 1:1 (960x960)",
                     optional=True,
                 ),
                 IO.Int.Input(
                     "duration",
                     default=5,
                     min=5,
-                    max=10,
+                    max=15,
                     step=5,
                     display_mode=IO.NumberDisplay.number,
-                    tooltip="Available durations: 5 and 10 seconds",
+                    tooltip="A 15-second duration is available only for the Wan 2.6 model.",
                     optional=True,
                 ),
                 IO.Audio.Input(
                     "audio",
                     optional=True,
-                    tooltip="Audio must contain a clear, loud voice, without extraneous noise, background music.",
+                    tooltip="Audio must contain a clear, loud voice, without extraneous noise or background music.",
                 ),
                 IO.Int.Input(
                     "seed",
@@ -466,7 +466,7 @@ class WanTextToVideoApi(IO.ComfyNode):
                     "generate_audio",
                     default=False,
                     optional=True,
-                    tooltip="If there is no audio input, generate audio automatically.",
+                    tooltip="If no audio input is provided, generate audio automatically.",
                 ),
                 IO.Boolean.Input(
                     "prompt_extend",
@@ -477,7 +477,15 @@ class WanTextToVideoApi(IO.ComfyNode):
                 IO.Boolean.Input(
                     "watermark",
                     default=True,
-                    tooltip='Whether to add an "AI generated" watermark to the result.',
+                    tooltip="Whether to add an AI-generated watermark to the result.",
+                    optional=True,
+                ),
+                IO.Combo.Input(
+                    "shot_type",
+                    options=["single", "multi"],
+                    tooltip="Specifies the shot type for the generated video, that is, whether the video is a "
+                    "single continuous shot or multiple shots with cuts. "
+                    "This parameter takes effect only when prompt_extend is True.",
                     optional=True,
                 ),
             ],
@@ -498,14 +506,19 @@ class WanTextToVideoApi(IO.ComfyNode):
         model: str,
         prompt: str,
         negative_prompt: str = "",
-        size: str = "480p: 1:1 (624x624)",
+        size: str = "720p: 1:1 (960x960)",
         duration: int = 5,
-        audio: Optional[Input.Audio] = None,
+        audio: Input.Audio | None = None,
         seed: int = 0,
         generate_audio: bool = False,
         prompt_extend: bool = True,
         watermark: bool = True,
+        shot_type: str = "single",
     ):
+        if "480p" in size and model == "wan2.6-t2v":
+            raise ValueError("The Wan 2.6 model does not support 480p.")
+        if duration == 15 and model == "wan2.5-t2v-preview":
+            raise ValueError("A 15-second duration is supported only by the Wan 2.6 model.")
         width, height = RES_IN_PARENS.search(size).groups()
         audio_url = None
         if audio is not None:
@@ -526,11 +539,12 @@ class WanTextToVideoApi(IO.ComfyNode):
                     audio=generate_audio,
                     prompt_extend=prompt_extend,
                     watermark=watermark,
+                    shot_type=shot_type,
                 ),
             ),
         )
         if not initial_response.output:
-            raise Exception(f"Unknown error occurred: {initial_response.code} - {initial_response.message}")
+            raise Exception(f"An unknown error occurred: {initial_response.code} - {initial_response.message}")
         response = await poll_op(
             cls,
             ApiEndpoint(path=f"/proxy/wan/api/v1/tasks/{initial_response.output.task_id}"),
@@ -549,12 +563,12 @@ class WanImageToVideoApi(IO.ComfyNode):
             node_id="WanImageToVideoApi",
             display_name="Wan Image to Video",
             category="api node/video/Wan",
-            description="Generates video based on the first frame and text prompt.",
+            description="Generates a video from the first frame and a text prompt.",
             inputs=[
                 IO.Combo.Input(
                     "model",
-                    options=["wan2.5-i2v-preview"],
-                    default="wan2.5-i2v-preview",
+                    options=["wan2.5-i2v-preview", "wan2.6-i2v"],
default="wan2.6-i2v", tooltip="Model to use.", ), IO.Image.Input( @@ -564,13 +578,13 @@ class WanImageToVideoApi(IO.ComfyNode): "prompt", multiline=True, default="", - tooltip="Prompt used to describe the elements and visual features, supports English/Chinese.", + tooltip="Prompt describing the elements and visual features. Supports English and Chinese.", ), IO.String.Input( "negative_prompt", multiline=True, default="", - tooltip="Negative text prompt to guide what to avoid.", + tooltip="Negative prompt describing what to avoid.", optional=True, ), IO.Combo.Input( @@ -580,23 +594,23 @@ class WanImageToVideoApi(IO.ComfyNode): "720P", "1080P", ], - default="480P", + default="720P", optional=True, ), IO.Int.Input( "duration", default=5, min=5, - max=10, + max=15, step=5, display_mode=IO.NumberDisplay.number, - tooltip="Available durations: 5 and 10 seconds", + tooltip="Duration 15 available only for WAN2.6 model.", optional=True, ), IO.Audio.Input( "audio", optional=True, - tooltip="Audio must contain a clear, loud voice, without extraneous noise, background music.", + tooltip="Audio must contain a clear, loud voice, without extraneous noise or background music.", ), IO.Int.Input( "seed", @@ -613,7 +627,7 @@ class WanImageToVideoApi(IO.ComfyNode): "generate_audio", default=False, optional=True, - tooltip="If there is no audio input, generate audio automatically.", + tooltip="If no audio input is provided, generate audio automatically.", ), IO.Boolean.Input( "prompt_extend", @@ -624,7 +638,15 @@ class WanImageToVideoApi(IO.ComfyNode): IO.Boolean.Input( "watermark", default=True, - tooltip='Whether to add an "AI generated" watermark to the result.', + tooltip="Whether to add an AI-generated watermark to the result.", + optional=True, + ), + IO.Combo.Input( + "shot_type", + options=["single", "multi"], + tooltip="Specifies the shot type for the generated video, that is, whether the video is a " + "single continuous shot or multiple shots with cuts. " + "This parameter takes effect only when prompt_extend is True.", optional=True, ), ], @@ -643,19 +665,24 @@ class WanImageToVideoApi(IO.ComfyNode): async def execute( cls, model: str, - image: torch.Tensor, + image: Input.Image, prompt: str, negative_prompt: str = "", - resolution: str = "480P", + resolution: str = "720P", duration: int = 5, - audio: Optional[Input.Audio] = None, + audio: Input.Audio | None = None, seed: int = 0, generate_audio: bool = False, prompt_extend: bool = True, watermark: bool = True, + shot_type: str = "single", ): if get_number_of_images(image) != 1: raise ValueError("Exactly one input image is required.") + if "480P" in resolution and model == "wan2.6-i2v": + raise ValueError("The Wan 2.6 model does not support 480P.") + if duration == 15 and model == "wan2.5-i2v-preview": + raise ValueError("A 15-second duration is supported only by the Wan 2.6 model.") image_url = "data:image/png;base64," + tensor_to_base64_string(image, total_pixels=2000 * 2000) audio_url = None if audio is not None: @@ -677,11 +704,12 @@ class WanImageToVideoApi(IO.ComfyNode): audio=generate_audio, prompt_extend=prompt_extend, watermark=watermark, + shot_type=shot_type, ), ), ) if not initial_response.output: - raise Exception(f"Unknown error occurred: {initial_response.code} - {initial_response.message}") + raise Exception(f"An unknown error occurred: {initial_response.code} - {initial_response.message}") response = await poll_op( cls, ApiEndpoint(path=f"/proxy/wan/api/v1/tasks/{initial_response.output.task_id}"),