From 7dbd5dfe91f057b83dcba0c127f712f6d71f7def Mon Sep 17 00:00:00 2001 From: Christian Byrne Date: Mon, 1 Dec 2025 10:27:17 -0800 Subject: [PATCH 01/29] bump comfyui-frontend-package to 1.32.10 (#11018) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 386477808..045b2ac54 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -comfyui-frontend-package==1.32.9 +comfyui-frontend-package==1.32.10 comfyui-workflow-templates==0.7.25 comfyui-embedded-docs==0.3.1 torch From 2640acb31ccfddee57ba22d5245bf456e8dffe53 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 1 Dec 2025 14:13:48 -0800 Subject: [PATCH 02/29] Update qwen tokenizer to add qwen 3 tokens. (#11029) Doesn't actually change anything for current workflows because none of the current models have a template with the think tokens. --- .../qwen25_tokenizer/tokenizer_config.json | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/comfy/text_encoders/qwen25_tokenizer/tokenizer_config.json b/comfy/text_encoders/qwen25_tokenizer/tokenizer_config.json index 67688e82c..df5b5d7fe 100644 --- a/comfy/text_encoders/qwen25_tokenizer/tokenizer_config.json +++ b/comfy/text_encoders/qwen25_tokenizer/tokenizer_config.json @@ -179,36 +179,36 @@ "special": false }, "151665": { - "content": "<|img|>", + "content": "", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, - "special": true + "special": false }, "151666": { - "content": "<|endofimg|>", + "content": "", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, - "special": true + "special": false }, "151667": { - "content": "<|meta|>", + "content": "", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, - "special": true + "special": false }, "151668": { - "content": "<|endofmeta|>", + "content": "", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, - "special": true + "special": false } }, "additional_special_tokens": [ From 1cb7e22a95701f2619d1ddf5683ea221b58a0c13 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Tue, 2 Dec 2025 02:11:52 +0200 Subject: [PATCH 03/29] [API Nodes] add Kling O1 model support (#11025) * feat(api-nodes): add Kling O1 model support * fix: increase max allowed duration to 10.05 seconds * fix(VideoInput): respect "format" argument --- comfy_api/latest/_input_impl/video_types.py | 5 +- comfy_api_nodes/apis/kling_api.py | 66 +++ comfy_api_nodes/nodes_kling.py | 444 +++++++++++++++++++- comfy_api_nodes/util/upload_helpers.py | 3 +- 4 files changed, 499 insertions(+), 19 deletions(-) create mode 100644 comfy_api_nodes/apis/kling_api.py diff --git a/comfy_api/latest/_input_impl/video_types.py b/comfy_api/latest/_input_impl/video_types.py index bde37f90a..7231bf13c 100644 --- a/comfy_api/latest/_input_impl/video_types.py +++ b/comfy_api/latest/_input_impl/video_types.py @@ -336,7 +336,10 @@ class VideoFromComponents(VideoInput): raise ValueError("Only MP4 format is supported for now") if codec != VideoCodec.AUTO and codec != VideoCodec.H264: raise ValueError("Only H264 codec is supported for now") - with av.open(path, mode='w', options={'movflags': 'use_metadata_tags'}) as output: + extra_kwargs = {} + if format != VideoContainer.AUTO: + extra_kwargs["format"] = format.value + with av.open(path, mode='w', options={'movflags': 'use_metadata_tags'}, **extra_kwargs) as output: # Add metadata before writing any streams if metadata is not None: for key, value in metadata.items(): diff --git a/comfy_api_nodes/apis/kling_api.py b/comfy_api_nodes/apis/kling_api.py new file mode 100644 index 000000000..0a3b447c5 --- /dev/null +++ b/comfy_api_nodes/apis/kling_api.py @@ -0,0 +1,66 @@ +from pydantic import BaseModel, Field + + +class OmniProText2VideoRequest(BaseModel): + model_name: str = Field(..., description="kling-video-o1") + aspect_ratio: str = Field(..., description="'16:9', '9:16' or '1:1'") + duration: str = Field(..., description="'5' or '10'") + prompt: str = Field(...) + mode: str = Field("pro") + + +class OmniParamImage(BaseModel): + image_url: str = Field(...) + type: str | None = Field(None, description="Can be 'first_frame' or 'end_frame'") + + +class OmniParamVideo(BaseModel): + video_url: str = Field(...) + refer_type: str | None = Field(..., description="Can be 'base' or 'feature'") + keep_original_sound: str = Field(..., description="'yes' or 'no'") + + +class OmniProFirstLastFrameRequest(BaseModel): + model_name: str = Field(..., description="kling-video-o1") + image_list: list[OmniParamImage] = Field(..., min_length=1, max_length=7) + duration: str = Field(..., description="'5' or '10'") + prompt: str = Field(...) + mode: str = Field("pro") + + +class OmniProReferences2VideoRequest(BaseModel): + model_name: str = Field(..., description="kling-video-o1") + aspect_ratio: str | None = Field(..., description="'16:9', '9:16' or '1:1'") + image_list: list[OmniParamImage] | None = Field( + None, max_length=7, description="Max length 4 when video is present." + ) + video_list: list[OmniParamVideo] | None = Field(None, max_length=1) + duration: str | None = Field(..., description="From 3 to 10.") + prompt: str = Field(...) + mode: str = Field("pro") + + +class TaskStatusVideoResult(BaseModel): + duration: str | None = Field(None, description="Total video duration") + id: str | None = Field(None, description="Generated video ID") + url: str | None = Field(None, description="URL for generated video") + + +class TaskStatusVideoResults(BaseModel): + videos: list[TaskStatusVideoResult] | None = Field(None) + + +class TaskStatusVideoResponseData(BaseModel): + created_at: int | None = Field(None, description="Task creation time") + updated_at: int | None = Field(None, description="Task update time") + task_status: str | None = None + task_status_msg: str | None = Field(None, description="Additional failure reason. Only for polling endpoint.") + task_id: str | None = Field(None, description="Task ID") + task_result: TaskStatusVideoResults | None = Field(None) + + +class TaskStatusVideoResponse(BaseModel): + code: int | None = Field(None, description="Error code") + message: str | None = Field(None, description="Error message") + request_id: str | None = Field(None, description="Request ID") + data: TaskStatusVideoResponseData | None = Field(None) diff --git a/comfy_api_nodes/nodes_kling.py b/comfy_api_nodes/nodes_kling.py index 23a7f55f1..850c44db6 100644 --- a/comfy_api_nodes/nodes_kling.py +++ b/comfy_api_nodes/nodes_kling.py @@ -4,13 +4,13 @@ For source of truth on the allowed permutations of request fields, please refere - [Compatibility Table](https://app.klingai.com/global/dev/document-api/apiReference/model/skillsMap) """ -import math import logging - -from typing_extensions import override +import math import torch +from typing_extensions import override +from comfy_api.latest import IO, ComfyExtension, Input, InputImpl from comfy_api_nodes.apis import ( KlingCameraControl, KlingCameraConfig, @@ -48,23 +48,31 @@ from comfy_api_nodes.apis import ( KlingCharacterEffectModelName, KlingSingleImageEffectModelName, ) +from comfy_api_nodes.apis.kling_api import ( + OmniParamImage, + OmniParamVideo, + OmniProFirstLastFrameRequest, + OmniProReferences2VideoRequest, + OmniProText2VideoRequest, + TaskStatusVideoResponse, +) from comfy_api_nodes.util import ( - validate_image_dimensions, + ApiEndpoint, + download_url_to_image_tensor, + download_url_to_video_output, + get_number_of_images, + poll_op, + sync_op, + tensor_to_base64_string, + upload_audio_to_comfyapi, + upload_images_to_comfyapi, + upload_video_to_comfyapi, validate_image_aspect_ratio, + validate_image_dimensions, + validate_string, validate_video_dimensions, validate_video_duration, - tensor_to_base64_string, - validate_string, - upload_audio_to_comfyapi, - download_url_to_image_tensor, - upload_video_to_comfyapi, - download_url_to_video_output, - sync_op, - ApiEndpoint, - poll_op, ) -from comfy_api.input_impl import VideoFromFile -from comfy_api.latest import ComfyExtension, IO, Input KLING_API_VERSION = "v1" PATH_TEXT_TO_VIDEO = f"/proxy/kling/{KLING_API_VERSION}/videos/text2video" @@ -202,6 +210,20 @@ VOICES_CONFIG = { } +async def finish_omni_video_task(cls: type[IO.ComfyNode], response: TaskStatusVideoResponse) -> IO.NodeOutput: + if response.code: + raise RuntimeError( + f"Kling request failed. Code: {response.code}, Message: {response.message}, Data: {response.data}" + ) + final_response = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/kling/v1/videos/omni-video/{response.data.task_id}"), + response_model=TaskStatusVideoResponse, + status_extractor=lambda r: (r.data.task_status if r.data else None), + ) + return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url)) + + def is_valid_camera_control_configs(configs: list[float]) -> bool: """Verifies that at least one camera control configuration is non-zero.""" return any(not math.isclose(value, 0.0) for value in configs) @@ -449,7 +471,7 @@ async def execute_video_effect( image_1: torch.Tensor, image_2: torch.Tensor | None = None, model_mode: KlingVideoGenMode | None = None, -) -> tuple[VideoFromFile, str, str]: +) -> tuple[InputImpl.VideoFromFile, str, str]: if dual_character: request_input_field = KlingDualCharacterEffectInput( model_name=model_name, @@ -736,6 +758,386 @@ class KlingTextToVideoNode(IO.ComfyNode): ) +class OmniProTextToVideoNode(IO.ComfyNode): + + @classmethod + def define_schema(cls) -> IO.Schema: + return IO.Schema( + node_id="KlingOmniProTextToVideoNode", + display_name="Kling Omni Text to Video (Pro)", + category="api node/video/Kling", + description="Use text prompts to generate videos with the latest Kling model.", + inputs=[ + IO.Combo.Input("model_name", options=["kling-video-o1"]), + IO.String.Input( + "prompt", + multiline=True, + tooltip="A text prompt describing the video content. " + "This can include both positive and negative descriptions.", + ), + IO.Combo.Input("aspect_ratio", options=["16:9", "9:16", "1:1"]), + IO.Combo.Input("duration", options=[5, 10]), + ], + outputs=[ + IO.Video.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, + model_name: str, + prompt: str, + aspect_ratio: str, + duration: int, + ) -> IO.NodeOutput: + validate_string(prompt, min_length=1, max_length=2500) + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"), + response_model=TaskStatusVideoResponse, + data=OmniProText2VideoRequest( + model_name=model_name, + prompt=prompt, + aspect_ratio=aspect_ratio, + duration=str(duration), + ), + ) + return await finish_omni_video_task(cls, response) + + +class OmniProFirstLastFrameNode(IO.ComfyNode): + + @classmethod + def define_schema(cls) -> IO.Schema: + return IO.Schema( + node_id="KlingOmniProFirstLastFrameNode", + display_name="Kling Omni First-Last-Frame to Video (Pro)", + category="api node/video/Kling", + description="Use a start frame, an optional end frame, or reference images with the latest Kling model.", + inputs=[ + IO.Combo.Input("model_name", options=["kling-video-o1"]), + IO.String.Input( + "prompt", + multiline=True, + tooltip="A text prompt describing the video content. " + "This can include both positive and negative descriptions.", + ), + IO.Combo.Input("duration", options=["5", "10"]), + IO.Image.Input("first_frame"), + IO.Image.Input( + "end_frame", + optional=True, + tooltip="An optional end frame for the video. " + "This cannot be used simultaneously with 'reference_images'.", + ), + IO.Image.Input( + "reference_images", + optional=True, + tooltip="Up to 6 additional reference images.", + ), + ], + outputs=[ + IO.Video.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, + model_name: str, + prompt: str, + duration: int, + first_frame: Input.Image, + end_frame: Input.Image | None = None, + reference_images: Input.Image | None = None, + ) -> IO.NodeOutput: + validate_string(prompt, min_length=1, max_length=2500) + if end_frame is not None and reference_images is not None: + raise ValueError("The 'end_frame' input cannot be used simultaneously with 'reference_images'.") + validate_image_dimensions(first_frame, min_width=300, min_height=300) + validate_image_aspect_ratio(first_frame, (1, 2.5), (2.5, 1)) + image_list: list[OmniParamImage] = [ + OmniParamImage( + image_url=(await upload_images_to_comfyapi(cls, first_frame, wait_label="Uploading first frame"))[0], + type="first_frame", + ) + ] + if end_frame is not None: + validate_image_dimensions(end_frame, min_width=300, min_height=300) + validate_image_aspect_ratio(end_frame, (1, 2.5), (2.5, 1)) + image_list.append( + OmniParamImage( + image_url=(await upload_images_to_comfyapi(cls, end_frame, wait_label="Uploading end frame"))[0], + type="end_frame", + ) + ) + if reference_images is not None: + if get_number_of_images(reference_images) > 6: + raise ValueError("The maximum number of reference images allowed is 6.") + for i in reference_images: + validate_image_dimensions(i, min_width=300, min_height=300) + validate_image_aspect_ratio(i, (1, 2.5), (2.5, 1)) + for i in await upload_images_to_comfyapi(cls, reference_images, wait_label="Uploading reference frame(s)"): + image_list.append(OmniParamImage(image_url=i)) + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"), + response_model=TaskStatusVideoResponse, + data=OmniProFirstLastFrameRequest( + model_name=model_name, + prompt=prompt, + duration=str(duration), + image_list=image_list, + ), + ) + return await finish_omni_video_task(cls, response) + + +class OmniProImageToVideoNode(IO.ComfyNode): + + @classmethod + def define_schema(cls) -> IO.Schema: + return IO.Schema( + node_id="KlingOmniProImageToVideoNode", + display_name="Kling Omni Image to Video (Pro)", + category="api node/video/Kling", + description="Use up to 7 reference images to generate a video with the latest Kling model.", + inputs=[ + IO.Combo.Input("model_name", options=["kling-video-o1"]), + IO.String.Input( + "prompt", + multiline=True, + tooltip="A text prompt describing the video content. " + "This can include both positive and negative descriptions.", + ), + IO.Combo.Input("aspect_ratio", options=["16:9", "9:16", "1:1"]), + IO.Int.Input("duration", default=3, min=3, max=10, display_mode=IO.NumberDisplay.slider), + IO.Image.Input( + "reference_images", + tooltip="Up to 7 reference images.", + ), + ], + outputs=[ + IO.Video.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, + model_name: str, + prompt: str, + aspect_ratio: str, + duration: int, + reference_images: Input.Image, + ) -> IO.NodeOutput: + validate_string(prompt, min_length=1, max_length=2500) + if get_number_of_images(reference_images) > 7: + raise ValueError("The maximum number of reference images is 7.") + for i in reference_images: + validate_image_dimensions(i, min_width=300, min_height=300) + validate_image_aspect_ratio(i, (1, 2.5), (2.5, 1)) + image_list: list[OmniParamImage] = [] + for i in await upload_images_to_comfyapi(cls, reference_images, wait_label="Uploading reference image"): + image_list.append(OmniParamImage(image_url=i)) + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"), + response_model=TaskStatusVideoResponse, + data=OmniProReferences2VideoRequest( + model_name=model_name, + prompt=prompt, + aspect_ratio=aspect_ratio, + duration=str(duration), + image_list=image_list, + ), + ) + return await finish_omni_video_task(cls, response) + + +class OmniProVideoToVideoNode(IO.ComfyNode): + + @classmethod + def define_schema(cls) -> IO.Schema: + return IO.Schema( + node_id="KlingOmniProVideoToVideoNode", + display_name="Kling Omni Video to Video (Pro)", + category="api node/video/Kling", + description="Use a video and up to 4 reference images to generate a video with the latest Kling model.", + inputs=[ + IO.Combo.Input("model_name", options=["kling-video-o1"]), + IO.String.Input( + "prompt", + multiline=True, + tooltip="A text prompt describing the video content. " + "This can include both positive and negative descriptions.", + ), + IO.Combo.Input("aspect_ratio", options=["16:9", "9:16", "1:1"]), + IO.Int.Input("duration", default=3, min=3, max=10, display_mode=IO.NumberDisplay.slider), + IO.Video.Input("reference_video", tooltip="Video to use as a reference."), + IO.Boolean.Input("keep_original_sound", default=True), + IO.Image.Input( + "reference_images", + tooltip="Up to 4 additional reference images.", + optional=True, + ), + ], + outputs=[ + IO.Video.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, + model_name: str, + prompt: str, + aspect_ratio: str, + duration: int, + reference_video: Input.Video, + keep_original_sound: bool, + reference_images: Input.Image | None = None, + ) -> IO.NodeOutput: + validate_string(prompt, min_length=1, max_length=2500) + validate_video_duration(reference_video, min_duration=3.0, max_duration=10.05) + validate_video_dimensions(reference_video, min_width=720, min_height=720, max_width=2160, max_height=2160) + image_list: list[OmniParamImage] = [] + if reference_images is not None: + if get_number_of_images(reference_images) > 4: + raise ValueError("The maximum number of reference images allowed with a video input is 4.") + for i in reference_images: + validate_image_dimensions(i, min_width=300, min_height=300) + validate_image_aspect_ratio(i, (1, 2.5), (2.5, 1)) + for i in await upload_images_to_comfyapi(cls, reference_images, wait_label="Uploading reference image"): + image_list.append(OmniParamImage(image_url=i)) + video_list = [ + OmniParamVideo( + video_url=await upload_video_to_comfyapi(cls, reference_video, wait_label="Uploading reference video"), + refer_type="feature", + keep_original_sound="yes" if keep_original_sound else "no", + ) + ] + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"), + response_model=TaskStatusVideoResponse, + data=OmniProReferences2VideoRequest( + model_name=model_name, + prompt=prompt, + aspect_ratio=aspect_ratio, + duration=str(duration), + image_list=image_list if image_list else None, + video_list=video_list, + ), + ) + return await finish_omni_video_task(cls, response) + + +class OmniProEditVideoNode(IO.ComfyNode): + + @classmethod + def define_schema(cls) -> IO.Schema: + return IO.Schema( + node_id="KlingOmniProEditVideoNode", + display_name="Kling Omni Edit Video (Pro)", + category="api node/video/Kling", + description="Edit an existing video with the latest model from Kling.", + inputs=[ + IO.Combo.Input("model_name", options=["kling-video-o1"]), + IO.String.Input( + "prompt", + multiline=True, + tooltip="A text prompt describing the video content. " + "This can include both positive and negative descriptions.", + ), + IO.Video.Input("video", tooltip="Video for editing. The output video length will be the same."), + IO.Boolean.Input("keep_original_sound", default=True), + IO.Image.Input( + "reference_images", + tooltip="Up to 4 additional reference images.", + optional=True, + ), + ], + outputs=[ + IO.Video.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, + model_name: str, + prompt: str, + video: Input.Video, + keep_original_sound: bool, + reference_images: Input.Image | None = None, + ) -> IO.NodeOutput: + validate_string(prompt, min_length=1, max_length=2500) + validate_video_duration(video, min_duration=3.0, max_duration=10.05) + validate_video_dimensions(video, min_width=720, min_height=720, max_width=2160, max_height=2160) + image_list: list[OmniParamImage] = [] + if reference_images is not None: + if get_number_of_images(reference_images) > 4: + raise ValueError("The maximum number of reference images allowed with a video input is 4.") + for i in reference_images: + validate_image_dimensions(i, min_width=300, min_height=300) + validate_image_aspect_ratio(i, (1, 2.5), (2.5, 1)) + for i in await upload_images_to_comfyapi(cls, reference_images, wait_label="Uploading reference image"): + image_list.append(OmniParamImage(image_url=i)) + video_list = [ + OmniParamVideo( + video_url=await upload_video_to_comfyapi(cls, video, wait_label="Uploading base video"), + refer_type="base", + keep_original_sound="yes" if keep_original_sound else "no", + ) + ] + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"), + response_model=TaskStatusVideoResponse, + data=OmniProReferences2VideoRequest( + model_name=model_name, + prompt=prompt, + aspect_ratio=None, + duration=None, + image_list=image_list if image_list else None, + video_list=video_list, + ), + ) + return await finish_omni_video_task(cls, response) + + class KlingCameraControlT2VNode(IO.ComfyNode): """ Kling Text to Video Camera Control Node. This node is a text to video node, but it supports controlling the camera. @@ -1162,7 +1564,10 @@ class KlingSingleImageVideoEffectNode(IO.ComfyNode): category="api node/video/Kling", description="Achieve different special effects when generating a video based on the effect_scene.", inputs=[ - IO.Image.Input("image", tooltip=" Reference Image. URL or Base64 encoded string (without data:image prefix). File size cannot exceed 10MB, resolution not less than 300*300px, aspect ratio between 1:2.5 ~ 2.5:1"), + IO.Image.Input( + "image", + tooltip=" Reference Image. URL or Base64 encoded string (without data:image prefix). File size cannot exceed 10MB, resolution not less than 300*300px, aspect ratio between 1:2.5 ~ 2.5:1", + ), IO.Combo.Input( "effect_scene", options=[i.value for i in KlingSingleImageEffectsScene], @@ -1525,6 +1930,11 @@ class KlingExtension(ComfyExtension): KlingImageGenerationNode, KlingSingleImageVideoEffectNode, KlingDualCharacterVideoEffectNode, + OmniProTextToVideoNode, + OmniProFirstLastFrameNode, + OmniProImageToVideoNode, + OmniProVideoToVideoNode, + OmniProEditVideoNode, ] diff --git a/comfy_api_nodes/util/upload_helpers.py b/comfy_api_nodes/util/upload_helpers.py index b9019841f..0532bea9a 100644 --- a/comfy_api_nodes/util/upload_helpers.py +++ b/comfy_api_nodes/util/upload_helpers.py @@ -103,6 +103,7 @@ async def upload_video_to_comfyapi( container: VideoContainer = VideoContainer.MP4, codec: VideoCodec = VideoCodec.H264, max_duration: Optional[int] = None, + wait_label: str | None = "Uploading", ) -> str: """ Uploads a single video to ComfyUI API and returns its download URL. @@ -127,7 +128,7 @@ async def upload_video_to_comfyapi( video.save_to(video_bytes_io, format=container, codec=codec) video_bytes_io.seek(0) - return await upload_file_to_comfyapi(cls, video_bytes_io, filename, upload_mime_type) + return await upload_file_to_comfyapi(cls, video_bytes_io, filename, upload_mime_type, wait_label) async def upload_file_to_comfyapi( From 30c259cac8c08ff8d015f9aff3151cb525c9b702 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 1 Dec 2025 20:25:35 -0500 Subject: [PATCH 04/29] ComfyUI version v0.3.76 --- comfyui_version.py | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/comfyui_version.py b/comfyui_version.py index fa4b4f4b0..4b039356e 100644 --- a/comfyui_version.py +++ b/comfyui_version.py @@ -1,3 +1,3 @@ # This file is automatically generated by the build process when version is # updated in pyproject.toml. -__version__ = "0.3.75" +__version__ = "0.3.76" diff --git a/pyproject.toml b/pyproject.toml index 9009e65fe..02b94a0ce 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ComfyUI" -version = "0.3.75" +version = "0.3.76" readme = "README.md" license = { file = "LICENSE" } requires-python = ">=3.9" From 878db3a727c1c6049bc1c4959cdfabc35eaf3d56 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 1 Dec 2025 17:56:17 -0800 Subject: [PATCH 05/29] Implement the Ovis image model. (#11030) --- comfy/ldm/chroma/model.py | 3 +- comfy/ldm/flux/layers.py | 68 +++++++++++++++++++++-------------- comfy/ldm/flux/model.py | 21 ++++++++--- comfy/model_detection.py | 10 +++++- comfy/sd.py | 13 +++++-- comfy/text_encoders/llama.py | 31 ++++++++++++++++ comfy/text_encoders/ovis.py | 69 ++++++++++++++++++++++++++++++++++++ nodes.py | 2 +- 8 files changed, 182 insertions(+), 35 deletions(-) create mode 100644 comfy/text_encoders/ovis.py diff --git a/comfy/ldm/chroma/model.py b/comfy/ldm/chroma/model.py index a72f8cc47..2e8ef0687 100644 --- a/comfy/ldm/chroma/model.py +++ b/comfy/ldm/chroma/model.py @@ -40,7 +40,8 @@ class ChromaParams: out_dim: int hidden_dim: int n_layers: int - + txt_ids_dims: list + vec_in_dim: int diff --git a/comfy/ldm/flux/layers.py b/comfy/ldm/flux/layers.py index 2472ab79c..60f2bdae2 100644 --- a/comfy/ldm/flux/layers.py +++ b/comfy/ldm/flux/layers.py @@ -57,6 +57,35 @@ class MLPEmbedder(nn.Module): def forward(self, x: Tensor) -> Tensor: return self.out_layer(self.silu(self.in_layer(x))) +class YakMLP(nn.Module): + def __init__(self, hidden_size: int, intermediate_size: int, dtype=None, device=None, operations=None): + super().__init__() + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.gate_proj = operations.Linear(self.hidden_size, self.intermediate_size, bias=True, dtype=dtype, device=device) + self.up_proj = operations.Linear(self.hidden_size, self.intermediate_size, bias=True, dtype=dtype, device=device) + self.down_proj = operations.Linear(self.intermediate_size, self.hidden_size, bias=True, dtype=dtype, device=device) + self.act_fn = nn.SiLU() + + def forward(self, x: Tensor) -> Tensor: + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + +def build_mlp(hidden_size, mlp_hidden_dim, mlp_silu_act=False, yak_mlp=False, dtype=None, device=None, operations=None): + if yak_mlp: + return YakMLP(hidden_size, mlp_hidden_dim, dtype=dtype, device=device, operations=operations) + if mlp_silu_act: + return nn.Sequential( + operations.Linear(hidden_size, mlp_hidden_dim * 2, bias=False, dtype=dtype, device=device), + SiLUActivation(), + operations.Linear(mlp_hidden_dim, hidden_size, bias=False, dtype=dtype, device=device), + ) + else: + return nn.Sequential( + operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device), + nn.GELU(approximate="tanh"), + operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device), + ) class RMSNorm(torch.nn.Module): def __init__(self, dim: int, dtype=None, device=None, operations=None): @@ -140,7 +169,7 @@ class SiLUActivation(nn.Module): class DoubleStreamBlock(nn.Module): - def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, modulation=True, mlp_silu_act=False, proj_bias=True, dtype=None, device=None, operations=None): + def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, modulation=True, mlp_silu_act=False, proj_bias=True, yak_mlp=False, dtype=None, device=None, operations=None): super().__init__() mlp_hidden_dim = int(hidden_size * mlp_ratio) @@ -156,18 +185,7 @@ class DoubleStreamBlock(nn.Module): self.img_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) - if mlp_silu_act: - self.img_mlp = nn.Sequential( - operations.Linear(hidden_size, mlp_hidden_dim * 2, bias=False, dtype=dtype, device=device), - SiLUActivation(), - operations.Linear(mlp_hidden_dim, hidden_size, bias=False, dtype=dtype, device=device), - ) - else: - self.img_mlp = nn.Sequential( - operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device), - nn.GELU(approximate="tanh"), - operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device), - ) + self.img_mlp = build_mlp(hidden_size, mlp_hidden_dim, mlp_silu_act=mlp_silu_act, yak_mlp=yak_mlp, dtype=dtype, device=device, operations=operations) if self.modulation: self.txt_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device, operations=operations) @@ -177,18 +195,7 @@ class DoubleStreamBlock(nn.Module): self.txt_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) - if mlp_silu_act: - self.txt_mlp = nn.Sequential( - operations.Linear(hidden_size, mlp_hidden_dim * 2, bias=False, dtype=dtype, device=device), - SiLUActivation(), - operations.Linear(mlp_hidden_dim, hidden_size, bias=False, dtype=dtype, device=device), - ) - else: - self.txt_mlp = nn.Sequential( - operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device), - nn.GELU(approximate="tanh"), - operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device), - ) + self.txt_mlp = build_mlp(hidden_size, mlp_hidden_dim, mlp_silu_act=mlp_silu_act, yak_mlp=yak_mlp, dtype=dtype, device=device, operations=operations) self.flipped_img_txt = flipped_img_txt @@ -275,6 +282,7 @@ class SingleStreamBlock(nn.Module): modulation=True, mlp_silu_act=False, bias=True, + yak_mlp=False, dtype=None, device=None, operations=None @@ -288,12 +296,17 @@ class SingleStreamBlock(nn.Module): self.mlp_hidden_dim = int(hidden_size * mlp_ratio) self.mlp_hidden_dim_first = self.mlp_hidden_dim + self.yak_mlp = yak_mlp if mlp_silu_act: self.mlp_hidden_dim_first = int(hidden_size * mlp_ratio * 2) self.mlp_act = SiLUActivation() else: self.mlp_act = nn.GELU(approximate="tanh") + if self.yak_mlp: + self.mlp_hidden_dim_first *= 2 + self.mlp_act = nn.SiLU() + # qkv and mlp_in self.linear1 = operations.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim_first, bias=bias, dtype=dtype, device=device) # proj and mlp_out @@ -325,7 +338,10 @@ class SingleStreamBlock(nn.Module): attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options) del q, k, v # compute activation in mlp stream, cat again and run second linear layer - mlp = self.mlp_act(mlp) + if self.yak_mlp: + mlp = self.mlp_act(mlp[..., self.mlp_hidden_dim_first // 2:]) * mlp[..., :self.mlp_hidden_dim_first // 2] + else: + mlp = self.mlp_act(mlp) output = self.linear2(torch.cat((attn, mlp), 2)) x += apply_mod(output, mod.gate, None, modulation_dims) if x.dtype == torch.float16: diff --git a/comfy/ldm/flux/model.py b/comfy/ldm/flux/model.py index d5674dea6..f40c2a7a9 100644 --- a/comfy/ldm/flux/model.py +++ b/comfy/ldm/flux/model.py @@ -15,7 +15,8 @@ from .layers import ( MLPEmbedder, SingleStreamBlock, timestep_embedding, - Modulation + Modulation, + RMSNorm ) @dataclass @@ -34,11 +35,14 @@ class FluxParams: patch_size: int qkv_bias: bool guidance_embed: bool + txt_ids_dims: list global_modulation: bool = False mlp_silu_act: bool = False ops_bias: bool = True default_ref_method: str = "offset" ref_index_scale: float = 1.0 + yak_mlp: bool = False + txt_norm: bool = False class Flux(nn.Module): @@ -76,6 +80,11 @@ class Flux(nn.Module): ) self.txt_in = operations.Linear(params.context_in_dim, self.hidden_size, bias=params.ops_bias, dtype=dtype, device=device) + if params.txt_norm: + self.txt_norm = RMSNorm(params.context_in_dim, dtype=dtype, device=device, operations=operations) + else: + self.txt_norm = None + self.double_blocks = nn.ModuleList( [ DoubleStreamBlock( @@ -86,6 +95,7 @@ class Flux(nn.Module): modulation=params.global_modulation is False, mlp_silu_act=params.mlp_silu_act, proj_bias=params.ops_bias, + yak_mlp=params.yak_mlp, dtype=dtype, device=device, operations=operations ) for _ in range(params.depth) @@ -94,7 +104,7 @@ class Flux(nn.Module): self.single_blocks = nn.ModuleList( [ - SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, modulation=params.global_modulation is False, mlp_silu_act=params.mlp_silu_act, bias=params.ops_bias, dtype=dtype, device=device, operations=operations) + SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, modulation=params.global_modulation is False, mlp_silu_act=params.mlp_silu_act, bias=params.ops_bias, yak_mlp=params.yak_mlp, dtype=dtype, device=device, operations=operations) for _ in range(params.depth_single_blocks) ] ) @@ -150,6 +160,8 @@ class Flux(nn.Module): y = torch.zeros((img.shape[0], self.params.vec_in_dim), device=img.device, dtype=img.dtype) vec = vec + self.vector_in(y[:, :self.params.vec_in_dim]) + if self.txt_norm is not None: + txt = self.txt_norm(txt) txt = self.txt_in(txt) vec_orig = vec @@ -332,8 +344,9 @@ class Flux(nn.Module): txt_ids = torch.zeros((bs, context.shape[1], len(self.params.axes_dim)), device=x.device, dtype=torch.float32) - if len(self.params.axes_dim) == 4: # Flux 2 - txt_ids[:, :, 3] = torch.linspace(0, context.shape[1] - 1, steps=context.shape[1], device=x.device, dtype=torch.float32) + if len(self.params.txt_ids_dims) > 0: + for i in self.params.txt_ids_dims: + txt_ids[:, :, i] = torch.linspace(0, context.shape[1] - 1, steps=context.shape[1], device=x.device, dtype=torch.float32) out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None)) out = out[:, :img_tokens] diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 7afe4a798..7d0517e61 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -208,12 +208,12 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["theta"] = 2000 dit_config["out_channels"] = 128 dit_config["global_modulation"] = True - dit_config["vec_in_dim"] = None dit_config["mlp_silu_act"] = True dit_config["qkv_bias"] = False dit_config["ops_bias"] = False dit_config["default_ref_method"] = "index" dit_config["ref_index_scale"] = 10.0 + dit_config["txt_ids_dims"] = [3] patch_size = 1 else: dit_config["image_model"] = "flux" @@ -223,6 +223,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["theta"] = 10000 dit_config["out_channels"] = 16 dit_config["qkv_bias"] = True + dit_config["txt_ids_dims"] = [] patch_size = 2 dit_config["in_channels"] = 16 @@ -245,6 +246,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): vec_in_key = '{}vector_in.in_layer.weight'.format(key_prefix) if vec_in_key in state_dict_keys: dit_config["vec_in_dim"] = state_dict[vec_in_key].shape[1] + else: + dit_config["vec_in_dim"] = None dit_config["depth"] = count_blocks(state_dict_keys, '{}double_blocks.'.format(key_prefix) + '{}.') dit_config["depth_single_blocks"] = count_blocks(state_dict_keys, '{}single_blocks.'.format(key_prefix) + '{}.') @@ -270,6 +273,11 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["nerf_embedder_dtype"] = torch.float32 else: dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys + dit_config["yak_mlp"] = '{}double_blocks.0.img_mlp.gate_proj.weight'.format(key_prefix) in state_dict_keys + dit_config["txt_norm"] = "{}txt_norm.scale".format(key_prefix) in state_dict_keys + if dit_config["yak_mlp"] and dit_config["txt_norm"]: # Ovis model + dit_config["txt_ids_dims"] = [1, 2] + return dit_config if '{}t5_yproj.weight'.format(key_prefix) in state_dict_keys: #Genmo mochi preview diff --git a/comfy/sd.py b/comfy/sd.py index 9eeb0c45a..f9e5efab5 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -53,6 +53,7 @@ import comfy.text_encoders.omnigen2 import comfy.text_encoders.qwen_image import comfy.text_encoders.hunyuan_image import comfy.text_encoders.z_image +import comfy.text_encoders.ovis import comfy.model_patcher import comfy.lora @@ -956,6 +957,7 @@ class CLIPType(Enum): QWEN_IMAGE = 18 HUNYUAN_IMAGE = 19 HUNYUAN_VIDEO_15 = 20 + OVIS = 21 def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}): @@ -987,6 +989,7 @@ class TEModel(Enum): MISTRAL3_24B = 14 MISTRAL3_24B_PRUNED_FLUX2 = 15 QWEN3_4B = 16 + QWEN3_2B = 17 def detect_te_model(sd): @@ -1020,9 +1023,12 @@ def detect_te_model(sd): if weight.shape[0] == 512: return TEModel.QWEN25_7B if "model.layers.0.post_attention_layernorm.weight" in sd: - if 'model.layers.0.self_attn.q_norm.weight' in sd: - return TEModel.QWEN3_4B weight = sd['model.layers.0.post_attention_layernorm.weight'] + if 'model.layers.0.self_attn.q_norm.weight' in sd: + if weight.shape[0] == 2560: + return TEModel.QWEN3_4B + elif weight.shape[0] == 2048: + return TEModel.QWEN3_2B if weight.shape[0] == 5120: if "model.layers.39.post_attention_layernorm.weight" in sd: return TEModel.MISTRAL3_24B @@ -1150,6 +1156,9 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip elif te_model == TEModel.QWEN3_4B: clip_target.clip = comfy.text_encoders.z_image.te(**llama_detect(clip_data)) clip_target.tokenizer = comfy.text_encoders.z_image.ZImageTokenizer + elif te_model == TEModel.QWEN3_2B: + clip_target.clip = comfy.text_encoders.ovis.te(**llama_detect(clip_data)) + clip_target.tokenizer = comfy.text_encoders.ovis.OvisTokenizer else: # clip_l if clip_type == CLIPType.SD3: diff --git a/comfy/text_encoders/llama.py b/comfy/text_encoders/llama.py index cd4b5f76c..0d07ac8c6 100644 --- a/comfy/text_encoders/llama.py +++ b/comfy/text_encoders/llama.py @@ -100,6 +100,28 @@ class Qwen3_4BConfig: rope_scale = None final_norm: bool = True +@dataclass +class Ovis25_2BConfig: + vocab_size: int = 151936 + hidden_size: int = 2048 + intermediate_size: int = 6144 + num_hidden_layers: int = 28 + num_attention_heads: int = 16 + num_key_value_heads: int = 8 + max_position_embeddings: int = 40960 + rms_norm_eps: float = 1e-6 + rope_theta: float = 1000000.0 + transformer_type: str = "llama" + head_dim = 128 + rms_norm_add = False + mlp_activation = "silu" + qkv_bias = False + rope_dims = None + q_norm = "gemma3" + k_norm = "gemma3" + rope_scale = None + final_norm: bool = True + @dataclass class Qwen25_7BVLI_Config: vocab_size: int = 152064 @@ -542,6 +564,15 @@ class Qwen3_4B(BaseLlama, torch.nn.Module): self.model = Llama2_(config, device=device, dtype=dtype, ops=operations) self.dtype = dtype +class Ovis25_2B(BaseLlama, torch.nn.Module): + def __init__(self, config_dict, dtype, device, operations): + super().__init__() + config = Ovis25_2BConfig(**config_dict) + self.num_layers = config.num_hidden_layers + + self.model = Llama2_(config, device=device, dtype=dtype, ops=operations) + self.dtype = dtype + class Qwen25_7BVLI(BaseLlama, torch.nn.Module): def __init__(self, config_dict, dtype, device, operations): super().__init__() diff --git a/comfy/text_encoders/ovis.py b/comfy/text_encoders/ovis.py new file mode 100644 index 000000000..81c9bd51c --- /dev/null +++ b/comfy/text_encoders/ovis.py @@ -0,0 +1,69 @@ +from transformers import Qwen2Tokenizer +import comfy.text_encoders.llama +from comfy import sd1_clip +import os +import torch +import numbers + +class Qwen3Tokenizer(sd1_clip.SDTokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "qwen25_tokenizer") + super().__init__(tokenizer_path, pad_with_end=False, embedding_size=2048, embedding_key='qwen3_2b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=284, pad_token=151643, tokenizer_data=tokenizer_data) + + +class OvisTokenizer(sd1_clip.SD1Tokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="qwen3_2b", tokenizer=Qwen3Tokenizer) + self.llama_template = "<|im_start|>user\nDescribe the image by detailing the color, quantity, text, shape, size, texture, spatial relationships of the objects and background: {}<|im_end|>\n<|im_start|>assistant\n\n\n\n\n" + + def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, **kwargs): + if llama_template is None: + llama_text = self.llama_template.format(text) + else: + llama_text = llama_template.format(text) + + tokens = super().tokenize_with_weights(llama_text, return_word_ids=return_word_ids, disable_weights=True, **kwargs) + return tokens + +class Ovis25_2BModel(sd1_clip.SDClipModel): + def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options={}): + super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Ovis25_2B, enable_attention_masks=attention_mask, return_attention_masks=False, zero_out_masked=True, model_options=model_options) + + +class OvisTEModel(sd1_clip.SD1ClipModel): + def __init__(self, device="cpu", dtype=None, model_options={}): + super().__init__(device=device, dtype=dtype, name="qwen3_2b", clip_model=Ovis25_2BModel, model_options=model_options) + + def encode_token_weights(self, token_weight_pairs, template_end=-1): + out, pooled = super().encode_token_weights(token_weight_pairs) + tok_pairs = token_weight_pairs["qwen3_2b"][0] + count_im_start = 0 + if template_end == -1: + for i, v in enumerate(tok_pairs): + elem = v[0] + if not torch.is_tensor(elem): + if isinstance(elem, numbers.Integral): + if elem == 4004 and count_im_start < 1: + template_end = i + count_im_start += 1 + + if out.shape[1] > (template_end + 1): + if tok_pairs[template_end + 1][0] == 25: + template_end += 1 + + out = out[:, template_end:] + return out, pooled, {} + + +def te(dtype_llama=None, llama_scaled_fp8=None, llama_quantization_metadata=None): + class OvisTEModel_(OvisTEModel): + def __init__(self, device="cpu", dtype=None, model_options={}): + if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options: + model_options = model_options.copy() + model_options["scaled_fp8"] = llama_scaled_fp8 + if dtype_llama is not None: + dtype = dtype_llama + if llama_quantization_metadata is not None: + model_options["quantization_metadata"] = llama_quantization_metadata + super().__init__(device=device, dtype=dtype, model_options=model_options) + return OvisTEModel_ diff --git a/nodes.py b/nodes.py index 495dec806..d5e5dc228 100644 --- a/nodes.py +++ b/nodes.py @@ -939,7 +939,7 @@ class CLIPLoader: @classmethod def INPUT_TYPES(s): return {"required": { "clip_name": (folder_paths.get_filename_list("text_encoders"), ), - "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image", "flux2"], ), + "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image", "flux2", "ovis"], ), }, "optional": { "device": (["default", "cpu"], {"advanced": True}), From c55dc857d5da5af203caf720ed7056047d382544 Mon Sep 17 00:00:00 2001 From: Christian Byrne Date: Mon, 1 Dec 2025 17:56:38 -0800 Subject: [PATCH 06/29] bump comfyui-frontend-package to 1.33.10 (#11028) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 045b2ac54..f98848e20 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -comfyui-frontend-package==1.32.10 +comfyui-frontend-package==1.33.10 comfyui-workflow-templates==0.7.25 comfyui-embedded-docs==0.3.1 torch From b4a20acc54b0b94dc05a1bd09dc0b54dd12203f1 Mon Sep 17 00:00:00 2001 From: "Dr.Lt.Data" <128333288+ltdrdata@users.noreply.github.com> Date: Tue, 2 Dec 2025 12:32:52 +0900 Subject: [PATCH 07/29] feat: Support ComfyUI-Manager for pip version (#7555) --- comfy/cli_args.py | 7 +++++++ comfy_api/feature_flags.py | 1 + main.py | 30 ++++++++++++++++++++++++++++++ manager_requirements.txt | 1 + nodes.py | 9 +++++++++ server.py | 8 +++++++- 6 files changed, 55 insertions(+), 1 deletion(-) create mode 100644 manager_requirements.txt diff --git a/comfy/cli_args.py b/comfy/cli_args.py index 5f0dfaa10..209fc185b 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -121,6 +121,12 @@ upcast.add_argument("--force-upcast-attention", action="store_true", help="Force upcast.add_argument("--dont-upcast-attention", action="store_true", help="Disable all upcasting of attention. Should be unnecessary except for debugging.") +parser.add_argument("--enable-manager", action="store_true", help="Enable the ComfyUI-Manager feature.") +manager_group = parser.add_mutually_exclusive_group() +manager_group.add_argument("--disable-manager-ui", action="store_true", help="Disables only the ComfyUI-Manager UI and endpoints. Scheduled installations and similar background tasks will still operate.") +manager_group.add_argument("--enable-manager-legacy-ui", action="store_true", help="Enables the legacy UI of ComfyUI-Manager") + + vram_group = parser.add_mutually_exclusive_group() vram_group.add_argument("--gpu-only", action="store_true", help="Store and run everything (text encoders/CLIP models, etc... on the GPU).") vram_group.add_argument("--highvram", action="store_true", help="By default models will be unloaded to CPU memory after being used. This option keeps them in GPU memory.") @@ -168,6 +174,7 @@ parser.add_argument("--multi-user", action="store_true", help="Enables per-user parser.add_argument("--verbose", default='INFO', const='DEBUG', nargs="?", choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], help='Set the logging level') parser.add_argument("--log-stdout", action="store_true", help="Send normal process output to stdout instead of stderr (default).") + # The default built-in provider hosted under web/ DEFAULT_VERSION_STRING = "comfyanonymous/ComfyUI@latest" diff --git a/comfy_api/feature_flags.py b/comfy_api/feature_flags.py index 0d4389a6e..bfb77eb5f 100644 --- a/comfy_api/feature_flags.py +++ b/comfy_api/feature_flags.py @@ -13,6 +13,7 @@ from comfy.cli_args import args SERVER_FEATURE_FLAGS: Dict[str, Any] = { "supports_preview_metadata": True, "max_upload_size": args.max_upload_size * 1024 * 1024, # Convert MB to bytes + "extension": {"manager": {"supports_v4": True}}, } diff --git a/main.py b/main.py index e1b0f1620..0cd815d9e 100644 --- a/main.py +++ b/main.py @@ -15,6 +15,7 @@ from comfy_execution.progress import get_progress_state from comfy_execution.utils import get_executing_context from comfy_api import feature_flags + if __name__ == "__main__": #NOTE: These do not do anything on core ComfyUI, they are for custom nodes. os.environ['HF_HUB_DISABLE_TELEMETRY'] = '1' @@ -22,6 +23,23 @@ if __name__ == "__main__": setup_logger(log_level=args.verbose, use_stdout=args.log_stdout) + +def handle_comfyui_manager_unavailable(): + if not args.windows_standalone_build: + logging.warning(f"\n\nYou appear to be running comfyui-manager from source, this is not recommended. Please install comfyui-manager using the following command:\ncommand:\n\t{sys.executable} -m pip install --pre comfyui_manager\n") + args.enable_manager = False + + +if args.enable_manager: + if importlib.util.find_spec("comfyui_manager"): + import comfyui_manager + + if not comfyui_manager.__file__ or not comfyui_manager.__file__.endswith('__init__.py'): + handle_comfyui_manager_unavailable() + else: + handle_comfyui_manager_unavailable() + + def apply_custom_paths(): # extra model paths extra_model_paths_config_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "extra_model_paths.yaml") @@ -79,6 +97,11 @@ def execute_prestartup_script(): for possible_module in possible_modules: module_path = os.path.join(custom_node_path, possible_module) + + if args.enable_manager: + if comfyui_manager.should_be_disabled(module_path): + continue + if os.path.isfile(module_path) or module_path.endswith(".disabled") or module_path == "__pycache__": continue @@ -101,6 +124,10 @@ def execute_prestartup_script(): logging.info("") apply_custom_paths() + +if args.enable_manager: + comfyui_manager.prestartup() + execute_prestartup_script() @@ -323,6 +350,9 @@ def start_comfyui(asyncio_loop=None): asyncio.set_event_loop(asyncio_loop) prompt_server = server.PromptServer(asyncio_loop) + if args.enable_manager and not args.disable_manager_ui: + comfyui_manager.start() + hook_breaker_ac10a0.save_functions() asyncio_loop.run_until_complete(nodes.init_extra_nodes( init_custom_nodes=(not args.disable_all_custom_nodes) or len(args.whitelist_custom_nodes) > 0, diff --git a/manager_requirements.txt b/manager_requirements.txt new file mode 100644 index 000000000..52cc5389c --- /dev/null +++ b/manager_requirements.txt @@ -0,0 +1 @@ +comfyui_manager==4.0.3b3 diff --git a/nodes.py b/nodes.py index d5e5dc228..4c910a34b 100644 --- a/nodes.py +++ b/nodes.py @@ -43,6 +43,9 @@ import folder_paths import latent_preview import node_helpers +if args.enable_manager: + import comfyui_manager + def before_node_execution(): comfy.model_management.throw_exception_if_processing_interrupted() @@ -2243,6 +2246,12 @@ async def init_external_custom_nodes(): if args.disable_all_custom_nodes and possible_module not in args.whitelist_custom_nodes: logging.info(f"Skipping {possible_module} due to disable_all_custom_nodes and whitelist_custom_nodes") continue + + if args.enable_manager: + if comfyui_manager.should_be_disabled(module_path): + logging.info(f"Blocked by policy: {module_path}") + continue + time_before = time.perf_counter() success = await load_custom_node(module_path, base_node_names, module_parent="custom_nodes") node_import_times.append((time.perf_counter() - time_before, module_path, success)) diff --git a/server.py b/server.py index fca5050bd..e3bd056d9 100644 --- a/server.py +++ b/server.py @@ -44,6 +44,9 @@ from protocol import BinaryEventTypes # Import cache control middleware from middleware.cache_middleware import cache_control +if args.enable_manager: + import comfyui_manager + async def send_socket_catch_exception(function, message): try: await function(message) @@ -212,6 +215,9 @@ class PromptServer(): if args.disable_api_nodes: middlewares.append(create_block_external_middleware()) + if args.enable_manager: + middlewares.append(comfyui_manager.create_middleware()) + max_upload_size = round(args.max_upload_size * 1024 * 1024) self.app = web.Application(client_max_size=max_upload_size, middlewares=middlewares) self.sockets = dict() @@ -599,7 +605,7 @@ class PromptServer(): system_stats = { "system": { - "os": os.name, + "os": sys.platform, "ram_total": ram_total, "ram_free": ram_free, "comfyui_version": __version__, From a17cf1c3871ad582c85c2bb6fddb63ec9c6df0ce Mon Sep 17 00:00:00 2001 From: Yoland Yan <4950057+yoland68@users.noreply.github.com> Date: Mon, 1 Dec 2025 19:40:44 -0800 Subject: [PATCH 08/29] Add @guill as a code owner (#11031) --- CODEOWNERS | 1 + 1 file changed, 1 insertion(+) diff --git a/CODEOWNERS b/CODEOWNERS index b7aca9b26..51acc4986 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -1,3 +1,4 @@ # Admins * @comfyanonymous * @kosinkadink +* @guill From 44baa0b7f32dd0c2ff0a9898aeb6c7929d855cd3 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Tue, 2 Dec 2025 11:46:29 -0800 Subject: [PATCH 09/29] Fix CODEOWNERS formatting to have all on the same line, otherwise only last line applies (#11053) --- CODEOWNERS | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/CODEOWNERS b/CODEOWNERS index 51acc4986..4d5448636 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -1,4 +1,2 @@ # Admins -* @comfyanonymous -* @kosinkadink -* @guill +* @comfyanonymous @kosinkadink @guill From 33d6aec3b70bc6f3e5bba26c85bd8f3bb1380d08 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Tue, 2 Dec 2025 21:50:13 +0200 Subject: [PATCH 10/29] add check for the format arg type in VideoFromComponents.save_to function (#11046) * add check for the format var type in VideoFromComponents.save_to function * convert "format" to VideoContainer enum --- comfy_api/latest/_input_impl/video_types.py | 2 +- comfy_extras/nodes_video.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/comfy_api/latest/_input_impl/video_types.py b/comfy_api/latest/_input_impl/video_types.py index 7231bf13c..a4cd3737d 100644 --- a/comfy_api/latest/_input_impl/video_types.py +++ b/comfy_api/latest/_input_impl/video_types.py @@ -337,7 +337,7 @@ class VideoFromComponents(VideoInput): if codec != VideoCodec.AUTO and codec != VideoCodec.H264: raise ValueError("Only H264 codec is supported for now") extra_kwargs = {} - if format != VideoContainer.AUTO: + if isinstance(format, VideoContainer) and format != VideoContainer.AUTO: extra_kwargs["format"] = format.value with av.open(path, mode='w', options={'movflags': 'use_metadata_tags'}, **extra_kwargs) as output: # Add metadata before writing any streams diff --git a/comfy_extras/nodes_video.py b/comfy_extras/nodes_video.py index 69fabb12e..6cf6e39bf 100644 --- a/comfy_extras/nodes_video.py +++ b/comfy_extras/nodes_video.py @@ -88,7 +88,7 @@ class SaveVideo(io.ComfyNode): ) @classmethod - def execute(cls, video: VideoInput, filename_prefix, format, codec) -> io.NodeOutput: + def execute(cls, video: VideoInput, filename_prefix, format: str, codec) -> io.NodeOutput: width, height = video.get_dimensions() full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path( filename_prefix, @@ -108,7 +108,7 @@ class SaveVideo(io.ComfyNode): file = f"{filename}_{counter:05}_.{VideoContainer.get_extension(format)}" video.save_to( os.path.join(full_output_folder, file), - format=format, + format=VideoContainer(format), codec=codec, metadata=saved_metadata ) From daaceac769a1355ab975758ede064317ea7514b4 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 2 Dec 2025 14:11:58 -0800 Subject: [PATCH 11/29] Hack to make zimage work in fp16. (#11057) --- comfy/ldm/lumina/model.py | 18 +++++++++++------- comfy/supported_models.py | 2 ++ 2 files changed, 13 insertions(+), 7 deletions(-) diff --git a/comfy/ldm/lumina/model.py b/comfy/ldm/lumina/model.py index 7d7e9112c..070b5da09 100644 --- a/comfy/ldm/lumina/model.py +++ b/comfy/ldm/lumina/model.py @@ -22,6 +22,10 @@ def modulate(x, scale): # Core NextDiT Model # ############################################################################# +def clamp_fp16(x): + if x.dtype == torch.float16: + return torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504) + return x class JointAttention(nn.Module): """Multi-head attention module.""" @@ -169,7 +173,7 @@ class FeedForward(nn.Module): # @torch.compile def _forward_silu_gating(self, x1, x3): - return F.silu(x1) * x3 + return clamp_fp16(F.silu(x1) * x3) def forward(self, x): return self.w2(self._forward_silu_gating(self.w1(x), self.w3(x))) @@ -273,27 +277,27 @@ class JointTransformerBlock(nn.Module): scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(adaln_input).chunk(4, dim=1) x = x + gate_msa.unsqueeze(1).tanh() * self.attention_norm2( - self.attention( + clamp_fp16(self.attention( modulate(self.attention_norm1(x), scale_msa), x_mask, freqs_cis, transformer_options=transformer_options, - ) + )) ) x = x + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2( - self.feed_forward( + clamp_fp16(self.feed_forward( modulate(self.ffn_norm1(x), scale_mlp), - ) + )) ) else: assert adaln_input is None x = x + self.attention_norm2( - self.attention( + clamp_fp16(self.attention( self.attention_norm1(x), x_mask, freqs_cis, transformer_options=transformer_options, - ) + )) ) x = x + self.ffn_norm2( self.feed_forward( diff --git a/comfy/supported_models.py b/comfy/supported_models.py index af8120400..afd97160b 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -1027,6 +1027,8 @@ class ZImage(Lumina2): memory_usage_factor = 1.7 + supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32] + def clip_target(self, state_dict={}): pref = self.text_encoder_key_prefix[0] hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen3_4b.transformer.".format(pref)) From 277237ccc1499bac7fcd221a666dfe7a32ac4206 Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Wed, 3 Dec 2025 08:24:19 +1000 Subject: [PATCH 12/29] attention: use flag based OOM fallback (#11038) Exception ref all local variables for the lifetime of exception context. Just set a flag and then if to dump the exception before falling back. --- comfy/ldm/modules/attention.py | 3 +++ comfy/ldm/modules/diffusionmodules/model.py | 3 +++ 2 files changed, 6 insertions(+) diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index 7437e0567..a8800ded0 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -517,6 +517,7 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha @wrap_attn def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs): + exception_fallback = False if skip_reshape: b, _, _, dim_head = q.shape tensor_layout = "HND" @@ -541,6 +542,8 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape= out = sageattn(q, k, v, attn_mask=mask, is_causal=False, tensor_layout=tensor_layout) except Exception as e: logging.error("Error running sage attention: {}, using pytorch attention instead.".format(e)) + exception_fallback = True + if exception_fallback: if tensor_layout == "NHD": q, k, v = map( lambda t: t.transpose(1, 2), diff --git a/comfy/ldm/modules/diffusionmodules/model.py b/comfy/ldm/modules/diffusionmodules/model.py index 4245eedca..de1e01cc8 100644 --- a/comfy/ldm/modules/diffusionmodules/model.py +++ b/comfy/ldm/modules/diffusionmodules/model.py @@ -279,6 +279,7 @@ def pytorch_attention(q, k, v): orig_shape = q.shape B = orig_shape[0] C = orig_shape[1] + oom_fallback = False q, k, v = map( lambda t: t.view(B, 1, C, -1).transpose(2, 3).contiguous(), (q, k, v), @@ -289,6 +290,8 @@ def pytorch_attention(q, k, v): out = out.transpose(2, 3).reshape(orig_shape) except model_management.OOM_EXCEPTION: logging.warning("scaled_dot_product_attention OOMed: switched to slice attention") + oom_fallback = True + if oom_fallback: out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(orig_shape) return out From b94d394a64dd0af06bca44b96c66549bb463331d Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 2 Dec 2025 18:38:31 -0800 Subject: [PATCH 13/29] Support Z Image alibaba pai fun controlnets. (#11062) These are not actual controlnets so put it in the models/model_patches folder and use the ModelPatchLoader + QwenImageDiffsynthControlnet node to use it. --- comfy/ldm/lumina/controlnet.py | 113 ++++++++++++++++++++++++++++++ comfy/ldm/lumina/model.py | 24 ++++--- comfy_extras/nodes_model_patch.py | 101 +++++++++++++++++++++++++- 3 files changed, 229 insertions(+), 9 deletions(-) create mode 100644 comfy/ldm/lumina/controlnet.py diff --git a/comfy/ldm/lumina/controlnet.py b/comfy/ldm/lumina/controlnet.py new file mode 100644 index 000000000..fd7ce3b5c --- /dev/null +++ b/comfy/ldm/lumina/controlnet.py @@ -0,0 +1,113 @@ +import torch +from torch import nn + +from .model import JointTransformerBlock + +class ZImageControlTransformerBlock(JointTransformerBlock): + def __init__( + self, + layer_id: int, + dim: int, + n_heads: int, + n_kv_heads: int, + multiple_of: int, + ffn_dim_multiplier: float, + norm_eps: float, + qk_norm: bool, + modulation=True, + block_id=0, + operation_settings=None, + ): + super().__init__(layer_id, dim, n_heads, n_kv_heads, multiple_of, ffn_dim_multiplier, norm_eps, qk_norm, modulation, z_image_modulation=True, operation_settings=operation_settings) + self.block_id = block_id + if block_id == 0: + self.before_proj = operation_settings.get("operations").Linear(self.dim, self.dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + self.after_proj = operation_settings.get("operations").Linear(self.dim, self.dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + + def forward(self, c, x, **kwargs): + if self.block_id == 0: + c = self.before_proj(c) + x + c = super().forward(c, **kwargs) + c_skip = self.after_proj(c) + return c_skip, c + +class ZImage_Control(torch.nn.Module): + def __init__( + self, + dim: int = 3840, + n_heads: int = 30, + n_kv_heads: int = 30, + multiple_of: int = 256, + ffn_dim_multiplier: float = (8.0 / 3.0), + norm_eps: float = 1e-5, + qk_norm: bool = True, + dtype=None, + device=None, + operations=None, + **kwargs + ): + super().__init__() + operation_settings = {"operations": operations, "device": device, "dtype": dtype} + + self.additional_in_dim = 0 + self.control_in_dim = 16 + n_refiner_layers = 2 + self.n_control_layers = 6 + self.control_layers = nn.ModuleList( + [ + ZImageControlTransformerBlock( + i, + dim, + n_heads, + n_kv_heads, + multiple_of, + ffn_dim_multiplier, + norm_eps, + qk_norm, + block_id=i, + operation_settings=operation_settings, + ) + for i in range(self.n_control_layers) + ] + ) + + all_x_embedder = {} + patch_size = 2 + f_patch_size = 1 + x_embedder = operations.Linear(f_patch_size * patch_size * patch_size * self.control_in_dim, dim, bias=True, device=device, dtype=dtype) + all_x_embedder[f"{patch_size}-{f_patch_size}"] = x_embedder + + self.control_all_x_embedder = nn.ModuleDict(all_x_embedder) + self.control_noise_refiner = nn.ModuleList( + [ + JointTransformerBlock( + layer_id, + dim, + n_heads, + n_kv_heads, + multiple_of, + ffn_dim_multiplier, + norm_eps, + qk_norm, + modulation=True, + z_image_modulation=True, + operation_settings=operation_settings, + ) + for layer_id in range(n_refiner_layers) + ] + ) + + def forward(self, cap_feats, control_context, x_freqs_cis, adaln_input): + patch_size = 2 + f_patch_size = 1 + pH = pW = patch_size + B, C, H, W = control_context.shape + control_context = self.control_all_x_embedder[f"{patch_size}-{f_patch_size}"](control_context.view(B, C, H // pH, pH, W // pW, pW).permute(0, 2, 4, 3, 5, 1).flatten(3).flatten(1, 2)) + + x_attn_mask = None + for layer in self.control_noise_refiner: + control_context = layer(control_context, x_attn_mask, x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input) + return control_context + + def forward_control_block(self, layer_id, control_context, x, x_attn_mask, x_freqs_cis, adaln_input): + return self.control_layers[layer_id](control_context, x, x_mask=x_attn_mask, freqs_cis=x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input=adaln_input) diff --git a/comfy/ldm/lumina/model.py b/comfy/ldm/lumina/model.py index 070b5da09..f1c1a0ec3 100644 --- a/comfy/ldm/lumina/model.py +++ b/comfy/ldm/lumina/model.py @@ -568,7 +568,7 @@ class NextDiT(nn.Module): ).execute(x, timesteps, context, num_tokens, attention_mask, **kwargs) # def forward(self, x, t, cap_feats, cap_mask): - def _forward(self, x, timesteps, context, num_tokens, attention_mask=None, **kwargs): + def _forward(self, x, timesteps, context, num_tokens, attention_mask=None, transformer_options={}, **kwargs): t = 1.0 - timesteps cap_feats = context cap_mask = attention_mask @@ -585,16 +585,24 @@ class NextDiT(nn.Module): cap_feats = self.cap_embedder(cap_feats) # (N, L, D) # todo check if able to batchify w.o. redundant compute + patches = transformer_options.get("patches", {}) transformer_options = kwargs.get("transformer_options", {}) x_is_tensor = isinstance(x, torch.Tensor) - x, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, t, num_tokens, transformer_options=transformer_options) - freqs_cis = freqs_cis.to(x.device) + img, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, t, num_tokens, transformer_options=transformer_options) + freqs_cis = freqs_cis.to(img.device) - for layer in self.layers: - x = layer(x, mask, freqs_cis, adaln_input, transformer_options=transformer_options) + for i, layer in enumerate(self.layers): + img = layer(img, mask, freqs_cis, adaln_input, transformer_options=transformer_options) + if "double_block" in patches: + for p in patches["double_block"]: + out = p({"img": img[:, cap_size[0]:], "txt": img[:, :cap_size[0]], "pe": freqs_cis[:, cap_size[0]:], "vec": adaln_input, "x": x, "block_index": i, "transformer_options": transformer_options}) + if "img" in out: + img[:, cap_size[0]:] = out["img"] + if "txt" in out: + img[:, :cap_size[0]] = out["txt"] - x = self.final_layer(x, adaln_input) - x = self.unpatchify(x, img_size, cap_size, return_tensor=x_is_tensor)[:,:,:h,:w] + img = self.final_layer(img, adaln_input) + img = self.unpatchify(img, img_size, cap_size, return_tensor=x_is_tensor)[:, :, :h, :w] - return -x + return -img diff --git a/comfy_extras/nodes_model_patch.py b/comfy_extras/nodes_model_patch.py index 783c59b6b..c61810dbf 100644 --- a/comfy_extras/nodes_model_patch.py +++ b/comfy_extras/nodes_model_patch.py @@ -6,6 +6,7 @@ import comfy.ops import comfy.model_management import comfy.ldm.common_dit import comfy.latent_formats +import comfy.ldm.lumina.controlnet class BlockWiseControlBlock(torch.nn.Module): @@ -189,6 +190,35 @@ class SigLIPMultiFeatProjModel(torch.nn.Module): return embedding +def z_image_convert(sd): + replace_keys = {".attention.to_out.0.bias": ".attention.out.bias", + ".attention.norm_k.weight": ".attention.k_norm.weight", + ".attention.norm_q.weight": ".attention.q_norm.weight", + ".attention.to_out.0.weight": ".attention.out.weight" + } + + out_sd = {} + for k in sorted(sd.keys()): + w = sd[k] + + k_out = k + if k_out.endswith(".attention.to_k.weight"): + cc = [w] + continue + if k_out.endswith(".attention.to_q.weight"): + cc = [w] + cc + continue + if k_out.endswith(".attention.to_v.weight"): + cc = cc + [w] + w = torch.cat(cc, dim=0) + k_out = k_out.replace(".attention.to_v.weight", ".attention.qkv.weight") + + for r, rr in replace_keys.items(): + k_out = k_out.replace(r, rr) + out_sd[k_out] = w + + return out_sd + class ModelPatchLoader: @classmethod def INPUT_TYPES(s): @@ -211,6 +241,9 @@ class ModelPatchLoader: elif 'feature_embedder.mid_layer_norm.bias' in sd: sd = comfy.utils.state_dict_prefix_replace(sd, {"feature_embedder.": ""}, filter_keys=True) model = SigLIPMultiFeatProjModel(device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast) + elif 'control_all_x_embedder.2-1.weight' in sd: # alipai z image fun controlnet + sd = z_image_convert(sd) + model = comfy.ldm.lumina.controlnet.ZImage_Control(device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast) model.load_state_dict(sd) model = comfy.model_patcher.ModelPatcher(model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device()) @@ -263,6 +296,69 @@ class DiffSynthCnetPatch: def models(self): return [self.model_patch] +class ZImageControlPatch: + def __init__(self, model_patch, vae, image, strength): + self.model_patch = model_patch + self.vae = vae + self.image = image + self.strength = strength + self.encoded_image = self.encode_latent_cond(image) + self.encoded_image_size = (image.shape[1], image.shape[2]) + self.temp_data = None + + def encode_latent_cond(self, image): + latent_image = comfy.latent_formats.Flux().process_in(self.vae.encode(image)) + return latent_image + + def __call__(self, kwargs): + x = kwargs.get("x") + img = kwargs.get("img") + txt = kwargs.get("txt") + pe = kwargs.get("pe") + vec = kwargs.get("vec") + block_index = kwargs.get("block_index") + spacial_compression = self.vae.spacial_compression_encode() + if self.encoded_image is None or self.encoded_image_size != (x.shape[-2] * spacial_compression, x.shape[-1] * spacial_compression): + image_scaled = comfy.utils.common_upscale(self.image.movedim(-1, 1), x.shape[-1] * spacial_compression, x.shape[-2] * spacial_compression, "area", "center") + loaded_models = comfy.model_management.loaded_models(only_currently_used=True) + self.encoded_image = self.encode_latent_cond(image_scaled.movedim(1, -1)) + self.encoded_image_size = (image_scaled.shape[-2], image_scaled.shape[-1]) + comfy.model_management.load_models_gpu(loaded_models) + + cnet_index = (block_index // 5) + cnet_index_float = (block_index / 5) + + kwargs.pop("img") # we do ops in place + kwargs.pop("txt") + + cnet_blocks = self.model_patch.model.n_control_layers + if cnet_index_float > (cnet_blocks - 1): + self.temp_data = None + return kwargs + + if self.temp_data is None or self.temp_data[0] > cnet_index: + self.temp_data = (-1, (None, self.model_patch.model(txt, self.encoded_image.to(img.dtype), pe, vec))) + + while self.temp_data[0] < cnet_index and (self.temp_data[0] + 1) < cnet_blocks: + next_layer = self.temp_data[0] + 1 + self.temp_data = (next_layer, self.model_patch.model.forward_control_block(next_layer, self.temp_data[1][1], img[:, :self.temp_data[1][1].shape[1]], None, pe, vec)) + + if cnet_index_float == self.temp_data[0]: + img[:, :self.temp_data[1][0].shape[1]] += (self.temp_data[1][0] * self.strength) + if cnet_blocks == self.temp_data[0] + 1: + self.temp_data = None + + return kwargs + + def to(self, device_or_dtype): + if isinstance(device_or_dtype, torch.device): + self.encoded_image = self.encoded_image.to(device_or_dtype) + self.temp_data = None + return self + + def models(self): + return [self.model_patch] + class QwenImageDiffsynthControlnet: @classmethod def INPUT_TYPES(s): @@ -289,7 +385,10 @@ class QwenImageDiffsynthControlnet: mask = mask.unsqueeze(2) mask = 1.0 - mask - model_patched.set_model_double_block_patch(DiffSynthCnetPatch(model_patch, vae, image, strength, mask)) + if isinstance(model_patch.model, comfy.ldm.lumina.controlnet.ZImage_Control): + model_patched.set_model_double_block_patch(ZImageControlPatch(model_patch, vae, image, strength)) + else: + model_patched.set_model_double_block_patch(DiffSynthCnetPatch(model_patch, vae, image, strength, mask)) return (model_patched,) From 3f512f5659cfbb3c53999cde6ff557591740252b Mon Sep 17 00:00:00 2001 From: Jim Heising Date: Tue, 2 Dec 2025 19:29:27 -0800 Subject: [PATCH 14/29] Added PATCH method to CORS headers (#11066) Added PATCH http method to access-control-allow-header-methods header because there are now PATCH endpoints exposed in the API. See https://github.com/comfyanonymous/ComfyUI/blob/277237ccc1499bac7fcd221a666dfe7a32ac4206/api_server/routes/internal/internal_routes.py#L34 for an example of an API endpoint that uses the PATCH method. --- server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server.py b/server.py index e3bd056d9..ac4f42222 100644 --- a/server.py +++ b/server.py @@ -98,7 +98,7 @@ def create_cors_middleware(allowed_origin: str): response = await handler(request) response.headers['Access-Control-Allow-Origin'] = allowed_origin - response.headers['Access-Control-Allow-Methods'] = 'POST, GET, DELETE, PUT, OPTIONS' + response.headers['Access-Control-Allow-Methods'] = 'POST, GET, DELETE, PUT, OPTIONS, PATCH' response.headers['Access-Control-Allow-Headers'] = 'Content-Type, Authorization' response.headers['Access-Control-Allow-Credentials'] = 'true' return response From 73f5649196f472d3719e2e7513e0a9d029cc3e38 Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Wed, 3 Dec 2025 13:49:29 +1000 Subject: [PATCH 15/29] Implement temporal rolling VAE (Major VRAM reductions in Hunyuan and Kandinsky) (#10995) * hunyuan upsampler: rework imports Remove the transitive import of VideoConv3d and Resnet and takes these from actual implementation source. * model: remove unused give_pre_end According to git grep, this is not used now, and was not used in the initial commit that introduced it (see below). This semantic is difficult to implement temporal roll VAE for (and would defeat the purpose). Rather than implement the complex if, just delete the unused feature. (venv) rattus@rattus-box2:~/ComfyUI$ git log --oneline 220afe33 (HEAD) Initial commit. (venv) rattus@rattus-box2:~/ComfyUI$ git grep give_pre comfy/ldm/modules/diffusionmodules/model.py: resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False, comfy/ldm/modules/diffusionmodules/model.py: self.give_pre_end = give_pre_end comfy/ldm/modules/diffusionmodules/model.py: if self.give_pre_end: (venv) rattus@rattus-box2:~/ComfyUI$ git co origin/master Previous HEAD position was 220afe33 Initial commit. HEAD is now at 9d8a8179 Enable async offloading by default on Nvidia. (#10953) (venv) rattus@rattus-box2:~/ComfyUI$ git grep give_pre comfy/ldm/modules/diffusionmodules/model.py: resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False, comfy/ldm/modules/diffusionmodules/model.py: self.give_pre_end = give_pre_end comfy/ldm/modules/diffusionmodules/model.py: if self.give_pre_end: * move refiner VAE temporal roller to core Move the carrying conv op to the common VAE code and give it a better name. Roll the carry implementation logic for Resnet into the base class and scrap the Hunyuan specific subclass. * model: Add temporal roll to main VAE decoder If there are no attention layers, its a standard resnet and VideoConv3d is asked for, substitute in the temporal rolloing VAE algorithm. This reduces VAE usage by the temporal dimension (can be huge VRAM savings). * model: Add temporal roll to main VAE encoder If there are no attention layers, its a standard resnet and VideoConv3d is asked for, substitute in the temporal rolling VAE algorithm. This reduces VAE usage by the temporal dimension (can be huge VRAM savings). --- comfy/ldm/hunyuan_video/upsampler.py | 3 +- comfy/ldm/hunyuan_video/vae_refiner.py | 94 +++------ comfy/ldm/modules/diffusionmodules/model.py | 207 ++++++++++++++------ 3 files changed, 174 insertions(+), 130 deletions(-) diff --git a/comfy/ldm/hunyuan_video/upsampler.py b/comfy/ldm/hunyuan_video/upsampler.py index 9f5e91a59..85f515f67 100644 --- a/comfy/ldm/hunyuan_video/upsampler.py +++ b/comfy/ldm/hunyuan_video/upsampler.py @@ -1,7 +1,8 @@ import torch import torch.nn as nn import torch.nn.functional as F -from comfy.ldm.hunyuan_video.vae_refiner import RMS_norm, ResnetBlock, VideoConv3d +from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, VideoConv3d +from comfy.ldm.hunyuan_video.vae_refiner import RMS_norm import model_management, model_patcher class SRResidualCausalBlock3D(nn.Module): diff --git a/comfy/ldm/hunyuan_video/vae_refiner.py b/comfy/ldm/hunyuan_video/vae_refiner.py index 9f750dcc4..ddf77cd0e 100644 --- a/comfy/ldm/hunyuan_video/vae_refiner.py +++ b/comfy/ldm/hunyuan_video/vae_refiner.py @@ -1,42 +1,12 @@ import torch import torch.nn as nn import torch.nn.functional as F -from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, AttnBlock, VideoConv3d, Normalize +from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, AttnBlock, CarriedConv3d, Normalize, conv_carry_causal_3d, torch_cat_if_needed import comfy.ops import comfy.ldm.models.autoencoder import comfy.model_management ops = comfy.ops.disable_weight_init -class NoPadConv3d(nn.Module): - def __init__(self, n_channels, out_channels, kernel_size, stride=1, dilation=1, padding=0, **kwargs): - super().__init__() - self.conv = ops.Conv3d(n_channels, out_channels, kernel_size, stride=stride, dilation=dilation, **kwargs) - - def forward(self, x): - return self.conv(x) - - -def conv_carry_causal_3d(xl, op, conv_carry_in=None, conv_carry_out=None): - - x = xl[0] - xl.clear() - - if conv_carry_out is not None: - to_push = x[:, :, -2:, :, :].clone() - conv_carry_out.append(to_push) - - if isinstance(op, NoPadConv3d): - if conv_carry_in is None: - x = torch.nn.functional.pad(x, (1, 1, 1, 1, 2, 0), mode = 'replicate') - else: - carry_len = conv_carry_in[0].shape[2] - x = torch.cat([conv_carry_in.pop(0), x], dim=2) - x = torch.nn.functional.pad(x, (1, 1, 1, 1, 2 - carry_len, 0), mode = 'replicate') - - out = op(x) - - return out - class RMS_norm(nn.Module): def __init__(self, dim): @@ -49,7 +19,7 @@ class RMS_norm(nn.Module): return F.normalize(x, dim=1) * self.scale * comfy.model_management.cast_to(self.gamma, dtype=x.dtype, device=x.device) class DnSmpl(nn.Module): - def __init__(self, ic, oc, tds=True, refiner_vae=True, op=VideoConv3d): + def __init__(self, ic, oc, tds, refiner_vae, op): super().__init__() fct = 2 * 2 * 2 if tds else 1 * 2 * 2 assert oc % fct == 0 @@ -109,7 +79,7 @@ class DnSmpl(nn.Module): class UpSmpl(nn.Module): - def __init__(self, ic, oc, tus=True, refiner_vae=True, op=VideoConv3d): + def __init__(self, ic, oc, tus, refiner_vae, op): super().__init__() fct = 2 * 2 * 2 if tus else 1 * 2 * 2 self.conv = op(ic, oc * fct, kernel_size=3, stride=1, padding=1) @@ -163,23 +133,6 @@ class UpSmpl(nn.Module): return h + x -class HunyuanRefinerResnetBlock(ResnetBlock): - def __init__(self, in_channels, out_channels, conv_op=NoPadConv3d, norm_op=RMS_norm): - super().__init__(in_channels=in_channels, out_channels=out_channels, temb_channels=0, conv_op=conv_op, norm_op=norm_op) - - def forward(self, x, conv_carry_in=None, conv_carry_out=None): - h = x - h = [ self.swish(self.norm1(x)) ] - h = conv_carry_causal_3d(h, self.conv1, conv_carry_in=conv_carry_in, conv_carry_out=conv_carry_out) - - h = [ self.dropout(self.swish(self.norm2(h))) ] - h = conv_carry_causal_3d(h, self.conv2, conv_carry_in=conv_carry_in, conv_carry_out=conv_carry_out) - - if self.in_channels != self.out_channels: - x = self.nin_shortcut(x) - - return x+h - class Encoder(nn.Module): def __init__(self, in_channels, z_channels, block_out_channels, num_res_blocks, ffactor_spatial, ffactor_temporal, downsample_match_channel=True, refiner_vae=True, **_): @@ -191,7 +144,7 @@ class Encoder(nn.Module): self.refiner_vae = refiner_vae if self.refiner_vae: - conv_op = NoPadConv3d + conv_op = CarriedConv3d norm_op = RMS_norm else: conv_op = ops.Conv3d @@ -206,9 +159,10 @@ class Encoder(nn.Module): for i, tgt in enumerate(block_out_channels): stage = nn.Module() - stage.block = nn.ModuleList([HunyuanRefinerResnetBlock(in_channels=ch if j == 0 else tgt, - out_channels=tgt, - conv_op=conv_op, norm_op=norm_op) + stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt, + out_channels=tgt, + temb_channels=0, + conv_op=conv_op, norm_op=norm_op) for j in range(num_res_blocks)]) ch = tgt if i < depth: @@ -218,9 +172,9 @@ class Encoder(nn.Module): self.down.append(stage) self.mid = nn.Module() - self.mid.block_1 = HunyuanRefinerResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op) + self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op) self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=norm_op) - self.mid.block_2 = HunyuanRefinerResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op) + self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op) self.norm_out = norm_op(ch) self.conv_out = conv_op(ch, z_channels << 1, 3, 1, 1) @@ -246,22 +200,20 @@ class Encoder(nn.Module): conv_carry_out = [] if i == len(x) - 1: conv_carry_out = None + x1 = [ x1 ] x1 = conv_carry_causal_3d(x1, self.conv_in, conv_carry_in, conv_carry_out) for stage in self.down: for blk in stage.block: - x1 = blk(x1, conv_carry_in, conv_carry_out) + x1 = blk(x1, None, conv_carry_in, conv_carry_out) if hasattr(stage, 'downsample'): x1 = stage.downsample(x1, conv_carry_in, conv_carry_out) out.append(x1) conv_carry_in = conv_carry_out - if len(out) > 1: - out = torch.cat(out, dim=2) - else: - out = out[0] + out = torch_cat_if_needed(out, dim=2) x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(out))) del out @@ -288,7 +240,7 @@ class Decoder(nn.Module): self.refiner_vae = refiner_vae if self.refiner_vae: - conv_op = NoPadConv3d + conv_op = CarriedConv3d norm_op = RMS_norm else: conv_op = ops.Conv3d @@ -298,9 +250,9 @@ class Decoder(nn.Module): self.conv_in = conv_op(z_channels, ch, kernel_size=3, stride=1, padding=1) self.mid = nn.Module() - self.mid.block_1 = HunyuanRefinerResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op) + self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op) self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=norm_op) - self.mid.block_2 = HunyuanRefinerResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op) + self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op) self.up = nn.ModuleList() depth = (ffactor_spatial >> 1).bit_length() @@ -308,9 +260,10 @@ class Decoder(nn.Module): for i, tgt in enumerate(block_out_channels): stage = nn.Module() - stage.block = nn.ModuleList([HunyuanRefinerResnetBlock(in_channels=ch if j == 0 else tgt, - out_channels=tgt, - conv_op=conv_op, norm_op=norm_op) + stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt, + out_channels=tgt, + temb_channels=0, + conv_op=conv_op, norm_op=norm_op) for j in range(num_res_blocks + 1)]) ch = tgt if i < depth: @@ -340,7 +293,7 @@ class Decoder(nn.Module): conv_carry_out = None for stage in self.up: for blk in stage.block: - x1 = blk(x1, conv_carry_in, conv_carry_out) + x1 = blk(x1, None, conv_carry_in, conv_carry_out) if hasattr(stage, 'upsample'): x1 = stage.upsample(x1, conv_carry_in, conv_carry_out) @@ -350,10 +303,7 @@ class Decoder(nn.Module): conv_carry_in = conv_carry_out del x - if len(out) > 1: - out = torch.cat(out, dim=2) - else: - out = out[0] + out = torch_cat_if_needed(out, dim=2) if not self.refiner_vae: if z.shape[-3] == 1: diff --git a/comfy/ldm/modules/diffusionmodules/model.py b/comfy/ldm/modules/diffusionmodules/model.py index de1e01cc8..681a55db5 100644 --- a/comfy/ldm/modules/diffusionmodules/model.py +++ b/comfy/ldm/modules/diffusionmodules/model.py @@ -13,6 +13,12 @@ if model_management.xformers_enabled_vae(): import xformers import xformers.ops +def torch_cat_if_needed(xl, dim): + if len(xl) > 1: + return torch.cat(xl, dim) + else: + return xl[0] + def get_timestep_embedding(timesteps, embedding_dim): """ This matches the implementation in Denoising Diffusion Probabilistic Models: @@ -43,6 +49,37 @@ def Normalize(in_channels, num_groups=32): return ops.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True) +class CarriedConv3d(nn.Module): + def __init__(self, n_channels, out_channels, kernel_size, stride=1, dilation=1, padding=0, **kwargs): + super().__init__() + self.conv = ops.Conv3d(n_channels, out_channels, kernel_size, stride=stride, dilation=dilation, **kwargs) + + def forward(self, x): + return self.conv(x) + + +def conv_carry_causal_3d(xl, op, conv_carry_in=None, conv_carry_out=None): + + x = xl[0] + xl.clear() + + if isinstance(op, CarriedConv3d): + if conv_carry_in is None: + x = torch.nn.functional.pad(x, (1, 1, 1, 1, 2, 0), mode = 'replicate') + else: + carry_len = conv_carry_in[0].shape[2] + x = torch.nn.functional.pad(x, (1, 1, 1, 1, 2 - carry_len, 0), mode = 'replicate') + x = torch.cat([conv_carry_in.pop(0), x], dim=2) + + if conv_carry_out is not None: + to_push = x[:, :, -2:, :, :].clone() + conv_carry_out.append(to_push) + + out = op(x) + + return out + + class VideoConv3d(nn.Module): def __init__(self, n_channels, out_channels, kernel_size, stride=1, dilation=1, padding_mode='replicate', padding=1, **kwargs): super().__init__() @@ -89,29 +126,24 @@ class Upsample(nn.Module): stride=1, padding=1) - def forward(self, x): + def forward(self, x, conv_carry_in=None, conv_carry_out=None): scale_factor = self.scale_factor if isinstance(scale_factor, (int, float)): scale_factor = (scale_factor,) * (x.ndim - 2) if x.ndim == 5 and scale_factor[0] > 1.0: - t = x.shape[2] - if t > 1: - a, b = x.split((1, t - 1), dim=2) - del x - b = interpolate_up(b, scale_factor) - else: - a = x - - a = interpolate_up(a.squeeze(2), scale_factor=scale_factor[1:]).unsqueeze(2) - if t > 1: - x = torch.cat((a, b), dim=2) - else: - x = a + results = [] + if conv_carry_in is None: + first = x[:, :, :1, :, :] + results.append(interpolate_up(first.squeeze(2), scale_factor=scale_factor[1:]).unsqueeze(2)) + x = x[:, :, 1:, :, :] + if x.shape[2] > 0: + results.append(interpolate_up(x, scale_factor)) + x = torch_cat_if_needed(results, dim=2) else: x = interpolate_up(x, scale_factor) if self.with_conv: - x = self.conv(x) + x = conv_carry_causal_3d([x], self.conv, conv_carry_in, conv_carry_out) return x @@ -127,17 +159,20 @@ class Downsample(nn.Module): stride=stride, padding=0) - def forward(self, x): + def forward(self, x, conv_carry_in=None, conv_carry_out=None): if self.with_conv: - if x.ndim == 4: + if isinstance(self.conv, CarriedConv3d): + x = conv_carry_causal_3d([x], self.conv, conv_carry_in, conv_carry_out) + elif x.ndim == 4: pad = (0, 1, 0, 1) mode = "constant" x = torch.nn.functional.pad(x, pad, mode=mode, value=0) + x = self.conv(x) elif x.ndim == 5: pad = (1, 1, 1, 1, 2, 0) mode = "replicate" x = torch.nn.functional.pad(x, pad, mode=mode) - x = self.conv(x) + x = self.conv(x) else: x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) return x @@ -183,23 +218,23 @@ class ResnetBlock(nn.Module): stride=1, padding=0) - def forward(self, x, temb=None): + def forward(self, x, temb=None, conv_carry_in=None, conv_carry_out=None): h = x h = self.norm1(h) - h = self.swish(h) - h = self.conv1(h) + h = [ self.swish(h) ] + h = conv_carry_causal_3d(h, self.conv1, conv_carry_in=conv_carry_in, conv_carry_out=conv_carry_out) if temb is not None: h = h + self.temb_proj(self.swish(temb))[:,:,None,None] h = self.norm2(h) h = self.swish(h) - h = self.dropout(h) - h = self.conv2(h) + h = [ self.dropout(h) ] + h = conv_carry_causal_3d(h, self.conv2, conv_carry_in=conv_carry_in, conv_carry_out=conv_carry_out) if self.in_channels != self.out_channels: if self.use_conv_shortcut: - x = self.conv_shortcut(x) + x = conv_carry_causal_3d([x], self.conv_shortcut, conv_carry_in=conv_carry_in, conv_carry_out=conv_carry_out) else: x = self.nin_shortcut(x) @@ -520,9 +555,14 @@ class Encoder(nn.Module): self.num_res_blocks = num_res_blocks self.resolution = resolution self.in_channels = in_channels + self.carried = False if conv3d: - conv_op = VideoConv3d + if not attn_resolutions: + conv_op = CarriedConv3d + self.carried = True + else: + conv_op = VideoConv3d mid_attn_conv_op = ops.Conv3d else: conv_op = ops.Conv2d @@ -535,6 +575,7 @@ class Encoder(nn.Module): stride=1, padding=1) + self.time_compress = 1 curr_res = resolution in_ch_mult = (1,)+tuple(ch_mult) self.in_ch_mult = in_ch_mult @@ -561,10 +602,15 @@ class Encoder(nn.Module): if time_compress is not None: if (self.num_resolutions - 1 - i_level) > math.log2(time_compress): stride = (1, 2, 2) + else: + self.time_compress *= 2 down.downsample = Downsample(block_in, resamp_with_conv, stride=stride, conv_op=conv_op) curr_res = curr_res // 2 self.down.append(down) + if time_compress is not None: + self.time_compress = time_compress + # middle self.mid = nn.Module() self.mid.block_1 = ResnetBlock(in_channels=block_in, @@ -590,15 +636,42 @@ class Encoder(nn.Module): def forward(self, x): # timestep embedding temb = None - # downsampling - h = self.conv_in(x) - for i_level in range(self.num_resolutions): - for i_block in range(self.num_res_blocks): - h = self.down[i_level].block[i_block](h, temb) - if len(self.down[i_level].attn) > 0: - h = self.down[i_level].attn[i_block](h) - if i_level != self.num_resolutions-1: - h = self.down[i_level].downsample(h) + + if self.carried: + xl = [x[:, :, :1, :, :]] + if x.shape[2] > self.time_compress: + tc = self.time_compress + xl += torch.split(x[:, :, 1: 1 + ((x.shape[2] - 1) // tc) * tc, :, :], tc * 2, dim = 2) + x = xl + else: + x = [x] + out = [] + + conv_carry_in = None + + for i, x1 in enumerate(x): + conv_carry_out = [] + if i == len(x) - 1: + conv_carry_out = None + + # downsampling + x1 = [ x1 ] + h1 = conv_carry_causal_3d(x1, self.conv_in, conv_carry_in, conv_carry_out) + + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h1 = self.down[i_level].block[i_block](h1, temb, conv_carry_in, conv_carry_out) + if len(self.down[i_level].attn) > 0: + assert i == 0 #carried should not happen if attn exists + h1 = self.down[i_level].attn[i_block](h1) + if i_level != self.num_resolutions-1: + h1 = self.down[i_level].downsample(h1, conv_carry_in, conv_carry_out) + + out.append(h1) + conv_carry_in = conv_carry_out + + h = torch_cat_if_needed(out, dim=2) + del out # middle h = self.mid.block_1(h, temb) @@ -607,15 +680,15 @@ class Encoder(nn.Module): # end h = self.norm_out(h) - h = nonlinearity(h) - h = self.conv_out(h) + h = [ nonlinearity(h) ] + h = conv_carry_causal_3d(h, self.conv_out) return h class Decoder(nn.Module): def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, - resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False, + resolution, z_channels, tanh_out=False, use_linear_attn=False, conv_out_op=ops.Conv2d, resnet_op=ResnetBlock, attn_op=AttnBlock, @@ -629,12 +702,18 @@ class Decoder(nn.Module): self.num_res_blocks = num_res_blocks self.resolution = resolution self.in_channels = in_channels - self.give_pre_end = give_pre_end self.tanh_out = tanh_out + self.carried = False if conv3d: - conv_op = VideoConv3d - conv_out_op = VideoConv3d + if not attn_resolutions and resnet_op == ResnetBlock: + conv_op = CarriedConv3d + conv_out_op = CarriedConv3d + self.carried = True + else: + conv_op = VideoConv3d + conv_out_op = VideoConv3d + mid_attn_conv_op = ops.Conv3d else: conv_op = ops.Conv2d @@ -709,29 +788,43 @@ class Decoder(nn.Module): temb = None # z to block_in - h = self.conv_in(z) + h = conv_carry_causal_3d([z], self.conv_in) # middle h = self.mid.block_1(h, temb, **kwargs) h = self.mid.attn_1(h, **kwargs) h = self.mid.block_2(h, temb, **kwargs) + if self.carried: + h = torch.split(h, 2, dim=2) + else: + h = [ h ] + out = [] + + conv_carry_in = None + # upsampling - for i_level in reversed(range(self.num_resolutions)): - for i_block in range(self.num_res_blocks+1): - h = self.up[i_level].block[i_block](h, temb, **kwargs) - if len(self.up[i_level].attn) > 0: - h = self.up[i_level].attn[i_block](h, **kwargs) - if i_level != 0: - h = self.up[i_level].upsample(h) + for i, h1 in enumerate(h): + conv_carry_out = [] + if i == len(h) - 1: + conv_carry_out = None + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks+1): + h1 = self.up[i_level].block[i_block](h1, temb, conv_carry_in, conv_carry_out, **kwargs) + if len(self.up[i_level].attn) > 0: + assert i == 0 #carried should not happen if attn exists + h1 = self.up[i_level].attn[i_block](h1, **kwargs) + if i_level != 0: + h1 = self.up[i_level].upsample(h1, conv_carry_in, conv_carry_out) - # end - if self.give_pre_end: - return h + h1 = self.norm_out(h1) + h1 = [ nonlinearity(h1) ] + h1 = conv_carry_causal_3d(h1, self.conv_out, conv_carry_in, conv_carry_out) + if self.tanh_out: + h1 = torch.tanh(h1) + out.append(h1) + conv_carry_in = conv_carry_out - h = self.norm_out(h) - h = nonlinearity(h) - h = self.conv_out(h, **kwargs) - if self.tanh_out: - h = torch.tanh(h) - return h + out = torch_cat_if_needed(out, dim=2) + + return out From c120eee5bacca643062657d2a7efad83c7d4d828 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Tue, 2 Dec 2025 21:17:13 -0800 Subject: [PATCH 16/29] Add MatchType, DynamicCombo, and Autogrow support to V3 Schema (#10832) * Added output_matchtypes to generated json for v3, initial backend support for MatchType, created nodes_logic.py and added SwitchNode * Fixed providing list of allowed_types * Add workaround in validation.py for V3 Combo outputs not working as Combo inputs * Make match type receive_type pass validation * Also add MatchType check to input_type in validation - will likely trigger when connecting to non-lazy stuff * Make sure this PR only has MatchType stuff * Initial work on DynamicCombo * Add get_dynamic function, not yet filled out correctly * Mark Switch node as Beta * Make sure other unfinished dynamic types are not accidentally used * Send DynamicCombo.Option inputs in the same format as normal v1 inputs * add dynamic combo test node * Support validation of inputs and outputs * Add missing input params to DynamicCombo.Input * Add get_all function to inputs for id validation purposes * Fix imports for v3 returning everything when doing io/ui/IO/UI instead of what is in __all__ of _io.py and _ui.py * Modifying behavior of get_dynamic in V3 + serialization so can be used in execution code * Fix v3 schema validation code after changes * Refactor hidden_values for v3 in execution.py to be more general v3_data, add helper functions for dynamic behavior, preparing for restructuring dynamic type into object (not finished yet) * Add nesting of inputs on DynamicCombo during execution * Work with latest frontend commits * Fix cringe arrows * frontend will no longer namespace dynamic inputs widgets so reflect that in code, refactor build_nested_inputs * Prepare Autogrow support for the love of the game * satisfy ruff * Create test nodes for Autogrow to collab with frontend development * Add nested combo to DCTestNode * Remove array support from build_nested_inputs, properly handle missing expected values * Make execution.validate_inputs properly validate required dynamic inputs, renamed dynamic_data to dynamic_paths for clarity * MatchType does not need any DynamicInput/Output features on backend; will increase compatibility with dynamic types * Probably need this for ruff check * Change MatchType to have template be the first and only required param; output id's do nothing right now, so no need * Fix merge regression with LatentUpscaleModel type not being put in __all__ for _io.py, fix invalid type hint for validate_inputs * Make Switch node inputs optional, disallow both inputs from being missing, and still work properly with lazy; when one input is missing, use the other no matter what the switch is set to * Satisfy ruff * Move MatchType code above the types that inherit from DynamicInput * Add DynamicSlot type, awaiting frontend support * Make curr_prefix creation happen in Autogrow, move curr_prefix in DynamicCombo to only be created if input exists in live_inputs * I was confused, fixing accidentally redundant curr_prefix addition in Autogrow * Make sure Autogrow inputs are force_input = True when WidgetInput, fix runtime validation by removing original input from expected inputs, fix min/max bounds, change test nodes slightly * Remove unnecessary id usage in Autogrow test node outputs * Commented out Switch node + test nodes * Remove commented out code from Autogrow * Make TemplatePrefix max more clear, allow max == 1 * Replace all dict[str] with dict[str, Any] * Renamed add_to_dict_live_inputs to expand_schema_for_dynamic * Fixed typo in DynamicSlot input code * note about live_inputs not being present soon in get_v1_info (internal function anyway) * For now, hide DynamicCombo and Autogrow from public interface * Removed comment --- comfy_api/latest/__init__.py | 4 +- comfy_api/latest/_io.py | 416 ++++++++++++++++++++++++++------- comfy_api/latest/_io_public.py | 1 + comfy_api/latest/_ui_public.py | 1 + comfy_api/v0_0_2/__init__.py | 6 +- comfy_execution/validation.py | 6 + comfy_extras/nodes_logic.py | 155 ++++++++++++ execution.py | 40 ++-- nodes.py | 1 + 9 files changed, 525 insertions(+), 105 deletions(-) create mode 100644 comfy_api/latest/_io_public.py create mode 100644 comfy_api/latest/_ui_public.py create mode 100644 comfy_extras/nodes_logic.py diff --git a/comfy_api/latest/__init__.py b/comfy_api/latest/__init__.py index 176ae36e0..0fa01d1e7 100644 --- a/comfy_api/latest/__init__.py +++ b/comfy_api/latest/__init__.py @@ -8,8 +8,8 @@ from comfy_api.internal.async_to_sync import create_sync_class from comfy_api.latest._input import ImageInput, AudioInput, MaskInput, LatentInput, VideoInput from comfy_api.latest._input_impl import VideoFromFile, VideoFromComponents from comfy_api.latest._util import VideoCodec, VideoContainer, VideoComponents, MESH, VOXEL -from . import _io as io -from . import _ui as ui +from . import _io_public as io +from . import _ui_public as ui # from comfy_api.latest._resources import _RESOURCES as resources #noqa: F401 from comfy_execution.utils import get_executing_context from comfy_execution.progress import get_progress_state, PreviewImageTuple diff --git a/comfy_api/latest/_io.py b/comfy_api/latest/_io.py index 79c0722a9..257f07c42 100644 --- a/comfy_api/latest/_io.py +++ b/comfy_api/latest/_io.py @@ -4,6 +4,7 @@ import copy import inspect from abc import ABC, abstractmethod from collections import Counter +from collections.abc import Iterable from dataclasses import asdict, dataclass from enum import Enum from typing import Any, Callable, Literal, TypedDict, TypeVar, TYPE_CHECKING @@ -150,6 +151,9 @@ class _IO_V3: def __init__(self): pass + def validate(self): + pass + @property def io_type(self): return self.Parent.io_type @@ -182,6 +186,9 @@ class Input(_IO_V3): def get_io_type(self): return _StringIOType(self.io_type) + def get_all(self) -> list[Input]: + return [self] + class WidgetInput(Input): ''' Base class for a V3 Input with widget. @@ -814,13 +821,61 @@ class MultiType: else: return super().as_dict() +@comfytype(io_type="COMFY_MATCHTYPE_V3") +class MatchType(ComfyTypeIO): + class Template: + def __init__(self, template_id: str, allowed_types: _ComfyType | list[_ComfyType] = AnyType): + self.template_id = template_id + # account for syntactic sugar + if not isinstance(allowed_types, Iterable): + allowed_types = [allowed_types] + for t in allowed_types: + if not isinstance(t, type): + if not isinstance(t, _ComfyType): + raise ValueError(f"Allowed types must be a ComfyType or a list of ComfyTypes, got {t.__class__.__name__}") + else: + if not issubclass(t, _ComfyType): + raise ValueError(f"Allowed types must be a ComfyType or a list of ComfyTypes, got {t.__name__}") + self.allowed_types = allowed_types + + def as_dict(self): + return { + "template_id": self.template_id, + "allowed_types": ",".join([t.io_type for t in self.allowed_types]), + } + + class Input(Input): + def __init__(self, id: str, template: MatchType.Template, + display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None): + super().__init__(id, display_name, optional, tooltip, lazy, extra_dict) + self.template = template + + def as_dict(self): + return super().as_dict() | prune_dict({ + "template": self.template.as_dict(), + }) + + class Output(Output): + def __init__(self, template: MatchType.Template, id: str=None, display_name: str=None, tooltip: str=None, + is_output_list=False): + super().__init__(id, display_name, tooltip, is_output_list) + self.template = template + + def as_dict(self): + return super().as_dict() | prune_dict({ + "template": self.template.as_dict(), + }) + class DynamicInput(Input, ABC): ''' Abstract class for dynamic input registration. ''' - @abstractmethod def get_dynamic(self) -> list[Input]: - ... + return [] + + def expand_schema_for_dynamic(self, d: dict[str, Any], live_inputs: dict[str, Any], curr_prefix=''): + pass + class DynamicOutput(Output, ABC): ''' @@ -830,99 +885,223 @@ class DynamicOutput(Output, ABC): is_output_list=False): super().__init__(id, display_name, tooltip, is_output_list) - @abstractmethod def get_dynamic(self) -> list[Output]: - ... + return [] @comfytype(io_type="COMFY_AUTOGROW_V3") -class AutogrowDynamic(ComfyTypeI): - Type = list[Any] - class Input(DynamicInput): - def __init__(self, id: str, template_input: Input, min: int=1, max: int=None, - display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None): - super().__init__(id, display_name, optional, tooltip, lazy, extra_dict) - self.template_input = template_input - if min is not None: - assert(min >= 1) - if max is not None: - assert(max >= 1) +class Autogrow(ComfyTypeI): + Type = dict[str, Any] + _MaxNames = 100 # NOTE: max 100 names for sanity + + class _AutogrowTemplate: + def __init__(self, input: Input): + # dynamic inputs are not allowed as the template input + assert(not isinstance(input, DynamicInput)) + self.input = copy.copy(input) + if isinstance(self.input, WidgetInput): + self.input.force_input = True + self.names: list[str] = [] + self.cached_inputs = {} + + def _create_input(self, input: Input, name: str): + new_input = copy.copy(self.input) + new_input.id = name + return new_input + + def _create_cached_inputs(self): + for name in self.names: + self.cached_inputs[name] = self._create_input(self.input, name) + + def get_all(self) -> list[Input]: + return list(self.cached_inputs.values()) + + def as_dict(self): + return prune_dict({ + "input": create_input_dict_v1([self.input]), + }) + + def validate(self): + self.input.validate() + + def expand_schema_for_dynamic(self, d: dict[str, Any], live_inputs: dict[str, Any], curr_prefix=''): + real_inputs = [] + for name, input in self.cached_inputs.items(): + if name in live_inputs: + real_inputs.append(input) + add_to_input_dict_v1(d, real_inputs, live_inputs, curr_prefix) + add_dynamic_id_mapping(d, real_inputs, curr_prefix) + + class TemplatePrefix(_AutogrowTemplate): + def __init__(self, input: Input, prefix: str, min: int=1, max: int=10): + super().__init__(input) + self.prefix = prefix + assert(min >= 0) + assert(max >= 1) + assert(max <= Autogrow._MaxNames) self.min = min self.max = max + self.names = [f"{self.prefix}{i}" for i in range(self.max)] + self._create_cached_inputs() + + def as_dict(self): + return super().as_dict() | prune_dict({ + "prefix": self.prefix, + "min": self.min, + "max": self.max, + }) + + class TemplateNames(_AutogrowTemplate): + def __init__(self, input: Input, names: list[str], min: int=1): + super().__init__(input) + self.names = names[:Autogrow._MaxNames] + assert(min >= 0) + self.min = min + self._create_cached_inputs() + + def as_dict(self): + return super().as_dict() | prune_dict({ + "names": self.names, + "min": self.min, + }) + + class Input(DynamicInput): + def __init__(self, id: str, template: Autogrow.TemplatePrefix | Autogrow.TemplateNames, + display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None): + super().__init__(id, display_name, optional, tooltip, lazy, extra_dict) + self.template = template + + def as_dict(self): + return super().as_dict() | prune_dict({ + "template": self.template.as_dict(), + }) def get_dynamic(self) -> list[Input]: - curr_count = 1 - new_inputs = [] - for i in range(self.min): - new_input = copy.copy(self.template_input) - new_input.id = f"{new_input.id}{curr_count}_${self.id}_ag$" - if new_input.display_name is not None: - new_input.display_name = f"{new_input.display_name}{curr_count}" - new_input.optional = self.optional or new_input.optional - if isinstance(self.template_input, WidgetInput): - new_input.force_input = True - new_inputs.append(new_input) - curr_count += 1 - # pretend to expand up to max - for i in range(curr_count-1, self.max): - new_input = copy.copy(self.template_input) - new_input.id = f"{new_input.id}{curr_count}_${self.id}_ag$" - if new_input.display_name is not None: - new_input.display_name = f"{new_input.display_name}{curr_count}" - new_input.optional = True - if isinstance(self.template_input, WidgetInput): - new_input.force_input = True - new_inputs.append(new_input) - curr_count += 1 - return new_inputs + return self.template.get_all() -@comfytype(io_type="COMFY_COMBODYNAMIC_V3") -class ComboDynamic(ComfyTypeI): - class Input(DynamicInput): - def __init__(self, id: str): - pass + def get_all(self) -> list[Input]: + return [self] + self.template.get_all() -@comfytype(io_type="COMFY_MATCHTYPE_V3") -class MatchType(ComfyTypeIO): - class Template: - def __init__(self, template_id: str, allowed_types: _ComfyType | list[_ComfyType]): - self.template_id = template_id - self.allowed_types = [allowed_types] if isinstance(allowed_types, _ComfyType) else allowed_types + def validate(self): + self.template.validate() + + def expand_schema_for_dynamic(self, d: dict[str, Any], live_inputs: dict[str, Any], curr_prefix=''): + curr_prefix = f"{curr_prefix}{self.id}." + # need to remove self from expected inputs dictionary; replaced by template inputs in frontend + for inner_dict in d.values(): + if self.id in inner_dict: + del inner_dict[self.id] + self.template.expand_schema_for_dynamic(d, live_inputs, curr_prefix) + +@comfytype(io_type="COMFY_DYNAMICCOMBO_V3") +class DynamicCombo(ComfyTypeI): + Type = dict[str, Any] + + class Option: + def __init__(self, key: str, inputs: list[Input]): + self.key = key + self.inputs = inputs def as_dict(self): return { - "template_id": self.template_id, - "allowed_types": "".join(t.io_type for t in self.allowed_types), + "key": self.key, + "inputs": create_input_dict_v1(self.inputs), } class Input(DynamicInput): - def __init__(self, id: str, template: MatchType.Template, + def __init__(self, id: str, options: list[DynamicCombo.Option], display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None): super().__init__(id, display_name, optional, tooltip, lazy, extra_dict) - self.template = template + self.options = options + + def expand_schema_for_dynamic(self, d: dict[str, Any], live_inputs: dict[str, Any], curr_prefix=''): + # check if dynamic input's id is in live_inputs + if self.id in live_inputs: + curr_prefix = f"{curr_prefix}{self.id}." + key = live_inputs[self.id] + selected_option = None + for option in self.options: + if option.key == key: + selected_option = option + break + if selected_option is not None: + add_to_input_dict_v1(d, selected_option.inputs, live_inputs, curr_prefix) + add_dynamic_id_mapping(d, selected_option.inputs, curr_prefix, self) def get_dynamic(self) -> list[Input]: - return [self] + return [input for option in self.options for input in option.inputs] + + def get_all(self) -> list[Input]: + return [self] + [input for option in self.options for input in option.inputs] def as_dict(self): return super().as_dict() | prune_dict({ - "template": self.template.as_dict(), + "options": [o.as_dict() for o in self.options], }) - class Output(DynamicOutput): - def __init__(self, id: str, template: MatchType.Template, display_name: str=None, tooltip: str=None, - is_output_list=False): - super().__init__(id, display_name, tooltip, is_output_list) - self.template = template + def validate(self): + # make sure all nested inputs are validated + for option in self.options: + for input in option.inputs: + input.validate() - def get_dynamic(self) -> list[Output]: - return [self] +@comfytype(io_type="COMFY_DYNAMICSLOT_V3") +class DynamicSlot(ComfyTypeI): + Type = dict[str, Any] + + class Input(DynamicInput): + def __init__(self, slot: Input, inputs: list[Input], + display_name: str=None, tooltip: str=None, lazy: bool=None, extra_dict=None): + assert(not isinstance(slot, DynamicInput)) + self.slot = copy.copy(slot) + self.slot.display_name = slot.display_name if slot.display_name is not None else display_name + optional = True + self.slot.tooltip = slot.tooltip if slot.tooltip is not None else tooltip + self.slot.lazy = slot.lazy if slot.lazy is not None else lazy + self.slot.extra_dict = slot.extra_dict if slot.extra_dict is not None else extra_dict + super().__init__(slot.id, self.slot.display_name, optional, self.slot.tooltip, self.slot.lazy, self.slot.extra_dict) + self.inputs = inputs + self.force_input = None + # force widget inputs to have no widgets, otherwise this would be awkward + if isinstance(self.slot, WidgetInput): + self.force_input = True + self.slot.force_input = True + + def expand_schema_for_dynamic(self, d: dict[str, Any], live_inputs: dict[str, Any], curr_prefix=''): + if self.id in live_inputs: + curr_prefix = f"{curr_prefix}{self.id}." + add_to_input_dict_v1(d, self.inputs, live_inputs, curr_prefix) + add_dynamic_id_mapping(d, [self.slot] + self.inputs, curr_prefix) + + def get_dynamic(self) -> list[Input]: + return [self.slot] + self.inputs + + def get_all(self) -> list[Input]: + return [self] + [self.slot] + self.inputs def as_dict(self): return super().as_dict() | prune_dict({ - "template": self.template.as_dict(), + "slotType": str(self.slot.get_io_type()), + "inputs": create_input_dict_v1(self.inputs), + "forceInput": self.force_input, }) + def validate(self): + self.slot.validate() + for input in self.inputs: + input.validate() + +def add_dynamic_id_mapping(d: dict[str, Any], inputs: list[Input], curr_prefix: str, self: DynamicInput=None): + dynamic = d.setdefault("dynamic_paths", {}) + if self is not None: + dynamic[self.id] = f"{curr_prefix}{self.id}" + for i in inputs: + if not isinstance(i, DynamicInput): + dynamic[f"{i.id}"] = f"{curr_prefix}{i.id}" + +class V3Data(TypedDict): + hidden_inputs: dict[str, Any] + dynamic_paths: dict[str, Any] class HiddenHolder: def __init__(self, unique_id: str, prompt: Any, @@ -984,6 +1163,7 @@ class NodeInfoV1: output_is_list: list[bool]=None output_name: list[str]=None output_tooltips: list[str]=None + output_matchtypes: list[str]=None name: str=None display_name: str=None description: str=None @@ -1061,7 +1241,11 @@ class Schema: '''Validate the schema: - verify ids on inputs and outputs are unique - both internally and in relation to each other ''' - input_ids = [i.id for i in self.inputs] if self.inputs is not None else [] + nested_inputs: list[Input] = [] + if self.inputs is not None: + for input in self.inputs: + nested_inputs.extend(input.get_all()) + input_ids = [i.id for i in nested_inputs] if nested_inputs is not None else [] output_ids = [o.id for o in self.outputs] if self.outputs is not None else [] input_set = set(input_ids) output_set = set(output_ids) @@ -1077,6 +1261,13 @@ class Schema: issues.append(f"Ids must be unique between inputs and outputs, but {intersection} are not.") if len(issues) > 0: raise ValueError("\n".join(issues)) + # validate inputs and outputs + if self.inputs is not None: + for input in self.inputs: + input.validate() + if self.outputs is not None: + for output in self.outputs: + output.validate() def finalize(self): """Add hidden based on selected schema options, and give outputs without ids default ids.""" @@ -1102,19 +1293,10 @@ class Schema: if output.id is None: output.id = f"_{i}_{output.io_type}_" - def get_v1_info(self, cls) -> NodeInfoV1: + def get_v1_info(self, cls, live_inputs: dict[str, Any]=None) -> NodeInfoV1: + # NOTE: live_inputs will not be used anymore very soon and this will be done another way # get V1 inputs - input = { - "required": {} - } - if self.inputs: - for i in self.inputs: - if isinstance(i, DynamicInput): - dynamic_inputs = i.get_dynamic() - for d in dynamic_inputs: - add_to_dict_v1(d, input) - else: - add_to_dict_v1(i, input) + input = create_input_dict_v1(self.inputs, live_inputs) if self.hidden: for hidden in self.hidden: input.setdefault("hidden", {})[hidden.name] = (hidden.value,) @@ -1123,12 +1305,24 @@ class Schema: output_is_list = [] output_name = [] output_tooltips = [] + output_matchtypes = [] + any_matchtypes = False if self.outputs: for o in self.outputs: output.append(o.io_type) output_is_list.append(o.is_output_list) output_name.append(o.display_name if o.display_name else o.io_type) output_tooltips.append(o.tooltip if o.tooltip else None) + # special handling for MatchType + if isinstance(o, MatchType.Output): + output_matchtypes.append(o.template.template_id) + any_matchtypes = True + else: + output_matchtypes.append(None) + + # clear out lists that are all None + if not any_matchtypes: + output_matchtypes = None info = NodeInfoV1( input=input, @@ -1137,6 +1331,7 @@ class Schema: output_is_list=output_is_list, output_name=output_name, output_tooltips=output_tooltips, + output_matchtypes=output_matchtypes, name=self.node_id, display_name=self.display_name, category=self.category, @@ -1182,16 +1377,57 @@ class Schema: return info -def add_to_dict_v1(i: Input, input: dict): +def create_input_dict_v1(inputs: list[Input], live_inputs: dict[str, Any]=None) -> dict: + input = { + "required": {} + } + add_to_input_dict_v1(input, inputs, live_inputs) + return input + +def add_to_input_dict_v1(d: dict[str, Any], inputs: list[Input], live_inputs: dict[str, Any]=None, curr_prefix=''): + for i in inputs: + if isinstance(i, DynamicInput): + add_to_dict_v1(i, d) + if live_inputs is not None: + i.expand_schema_for_dynamic(d, live_inputs, curr_prefix) + else: + add_to_dict_v1(i, d) + +def add_to_dict_v1(i: Input, d: dict, dynamic_dict: dict=None): key = "optional" if i.optional else "required" as_dict = i.as_dict() # for v1, we don't want to include the optional key as_dict.pop("optional", None) - input.setdefault(key, {})[i.id] = (i.get_io_type(), as_dict) + if dynamic_dict is None: + value = (i.get_io_type(), as_dict) + else: + value = (i.get_io_type(), as_dict, dynamic_dict) + d.setdefault(key, {})[i.id] = value def add_to_dict_v3(io: Input | Output, d: dict): d[io.id] = (io.get_io_type(), io.as_dict()) +def build_nested_inputs(values: dict[str, Any], v3_data: V3Data): + paths = v3_data.get("dynamic_paths", None) + if paths is None: + return values + values = values.copy() + result = {} + + for key, path in paths.items(): + parts = path.split(".") + current = result + + for i, p in enumerate(parts): + is_last = (i == len(parts) - 1) + + if is_last: + current[p] = values.pop(key, None) + else: + current = current.setdefault(p, {}) + + values.update(result) + return values class _ComfyNodeBaseInternal(_ComfyNodeInternal): @@ -1311,12 +1547,12 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal): @final @classmethod - def PREPARE_CLASS_CLONE(cls, hidden_inputs: dict) -> type[ComfyNode]: + def PREPARE_CLASS_CLONE(cls, v3_data: V3Data) -> type[ComfyNode]: """Creates clone of real node class to prevent monkey-patching.""" c_type: type[ComfyNode] = cls if is_class(cls) else type(cls) type_clone: type[ComfyNode] = shallow_clone_class(c_type) # set hidden - type_clone.hidden = HiddenHolder.from_dict(hidden_inputs) + type_clone.hidden = HiddenHolder.from_dict(v3_data["hidden_inputs"]) return type_clone @final @@ -1433,14 +1669,18 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal): @final @classmethod - def INPUT_TYPES(cls, include_hidden=True, return_schema=False) -> dict[str, dict] | tuple[dict[str, dict], Schema]: + def INPUT_TYPES(cls, include_hidden=True, return_schema=False, live_inputs=None) -> dict[str, dict] | tuple[dict[str, dict], Schema, V3Data]: schema = cls.FINALIZE_SCHEMA() - info = schema.get_v1_info(cls) + info = schema.get_v1_info(cls, live_inputs) input = info.input if not include_hidden: input.pop("hidden", None) if return_schema: - return input, schema + v3_data: V3Data = {} + dynamic = input.pop("dynamic_paths", None) + if dynamic is not None: + v3_data["dynamic_paths"] = dynamic + return input, schema, v3_data return input @final @@ -1513,7 +1753,7 @@ class ComfyNode(_ComfyNodeBaseInternal): raise NotImplementedError @classmethod - def validate_inputs(cls, **kwargs) -> bool: + def validate_inputs(cls, **kwargs) -> bool | str: """Optionally, define this function to validate inputs; equivalent to V1's VALIDATE_INPUTS.""" raise NotImplementedError @@ -1628,6 +1868,7 @@ __all__ = [ "StyleModel", "Gligen", "UpscaleModel", + "LatentUpscaleModel", "Audio", "Video", "SVG", @@ -1651,6 +1892,10 @@ __all__ = [ "SEGS", "AnyType", "MultiType", + # Dynamic Types + "MatchType", + # "DynamicCombo", + # "Autogrow", # Other classes "HiddenHolder", "Hidden", @@ -1661,4 +1906,5 @@ __all__ = [ "NodeOutput", "add_to_dict_v1", "add_to_dict_v3", + "V3Data", ] diff --git a/comfy_api/latest/_io_public.py b/comfy_api/latest/_io_public.py new file mode 100644 index 000000000..43c7680f3 --- /dev/null +++ b/comfy_api/latest/_io_public.py @@ -0,0 +1 @@ +from ._io import * # noqa: F403 diff --git a/comfy_api/latest/_ui_public.py b/comfy_api/latest/_ui_public.py new file mode 100644 index 000000000..85b11d78b --- /dev/null +++ b/comfy_api/latest/_ui_public.py @@ -0,0 +1 @@ +from ._ui import * # noqa: F403 diff --git a/comfy_api/v0_0_2/__init__.py b/comfy_api/v0_0_2/__init__.py index de0f95001..c4fa1d971 100644 --- a/comfy_api/v0_0_2/__init__.py +++ b/comfy_api/v0_0_2/__init__.py @@ -6,7 +6,7 @@ from comfy_api.latest import ( ) from typing import Type, TYPE_CHECKING from comfy_api.internal.async_to_sync import create_sync_class -from comfy_api.latest import io, ui, ComfyExtension #noqa: F401 +from comfy_api.latest import io, ui, IO, UI, ComfyExtension #noqa: F401 class ComfyAPIAdapter_v0_0_2(ComfyAPI_latest): @@ -42,4 +42,8 @@ __all__ = [ "InputImpl", "Types", "ComfyExtension", + "io", + "IO", + "ui", + "UI", ] diff --git a/comfy_execution/validation.py b/comfy_execution/validation.py index cec105fc9..24c0b4ed7 100644 --- a/comfy_execution/validation.py +++ b/comfy_execution/validation.py @@ -1,4 +1,5 @@ from __future__ import annotations +from comfy_api.latest import IO def validate_node_input( @@ -23,6 +24,11 @@ def validate_node_input( if not received_type != input_type: return True + # If the received type or input_type is a MatchType, we can return True immediately; + # validation for this is handled by the frontend + if received_type == IO.MatchType.io_type or input_type == IO.MatchType.io_type: + return True + # Not equal, and not strings if not isinstance(received_type, str) or not isinstance(input_type, str): return False diff --git a/comfy_extras/nodes_logic.py b/comfy_extras/nodes_logic.py new file mode 100644 index 000000000..95a6ba788 --- /dev/null +++ b/comfy_extras/nodes_logic.py @@ -0,0 +1,155 @@ +from typing import TypedDict +from typing_extensions import override +from comfy_api.latest import ComfyExtension, io +from comfy_api.latest import _io + + + +class SwitchNode(io.ComfyNode): + @classmethod + def define_schema(cls): + template = io.MatchType.Template("switch") + return io.Schema( + node_id="ComfySwitchNode", + display_name="Switch", + category="logic", + is_experimental=True, + inputs=[ + io.Boolean.Input("switch"), + io.MatchType.Input("on_false", template=template, lazy=True, optional=True), + io.MatchType.Input("on_true", template=template, lazy=True, optional=True), + ], + outputs=[ + io.MatchType.Output(template=template, display_name="output"), + ], + ) + + @classmethod + def check_lazy_status(cls, switch, on_false=..., on_true=...): + # We use ... instead of None, as None is passed for connected-but-unevaluated inputs. + # This trick allows us to ignore the value of the switch and still be able to run execute(). + + # One of the inputs may be missing, in which case we need to evaluate the other input + if on_false is ...: + return ["on_true"] + if on_true is ...: + return ["on_false"] + # Normal lazy switch operation + if switch and on_true is None: + return ["on_true"] + if not switch and on_false is None: + return ["on_false"] + + @classmethod + def validate_inputs(cls, switch, on_false=..., on_true=...): + # This check happens before check_lazy_status(), so we can eliminate the case where + # both inputs are missing. + if on_false is ... and on_true is ...: + return "At least one of on_false or on_true must be connected to Switch node" + return True + + @classmethod + def execute(cls, switch, on_true=..., on_false=...) -> io.NodeOutput: + if on_true is ...: + return io.NodeOutput(on_false) + if on_false is ...: + return io.NodeOutput(on_true) + return io.NodeOutput(on_true if switch else on_false) + + +class DCTestNode(io.ComfyNode): + class DCValues(TypedDict): + combo: str + string: str + integer: int + image: io.Image.Type + subcombo: dict[str] + + @classmethod + def define_schema(cls): + return io.Schema( + node_id="DCTestNode", + display_name="DCTest", + category="logic", + is_output_node=True, + inputs=[_io.DynamicCombo.Input("combo", options=[ + _io.DynamicCombo.Option("option1", [io.String.Input("string")]), + _io.DynamicCombo.Option("option2", [io.Int.Input("integer")]), + _io.DynamicCombo.Option("option3", [io.Image.Input("image")]), + _io.DynamicCombo.Option("option4", [ + _io.DynamicCombo.Input("subcombo", options=[ + _io.DynamicCombo.Option("opt1", [io.Float.Input("float_x"), io.Float.Input("float_y")]), + _io.DynamicCombo.Option("opt2", [io.Mask.Input("mask1", optional=True)]), + ]) + ])] + )], + outputs=[io.AnyType.Output()], + ) + + @classmethod + def execute(cls, combo: DCValues) -> io.NodeOutput: + combo_val = combo["combo"] + if combo_val == "option1": + return io.NodeOutput(combo["string"]) + elif combo_val == "option2": + return io.NodeOutput(combo["integer"]) + elif combo_val == "option3": + return io.NodeOutput(combo["image"]) + elif combo_val == "option4": + return io.NodeOutput(f"{combo['subcombo']}") + else: + raise ValueError(f"Invalid combo: {combo_val}") + + +class AutogrowNamesTestNode(io.ComfyNode): + @classmethod + def define_schema(cls): + template = _io.Autogrow.TemplateNames(input=io.Float.Input("float"), names=["a", "b", "c"]) + return io.Schema( + node_id="AutogrowNamesTestNode", + display_name="AutogrowNamesTest", + category="logic", + inputs=[ + _io.Autogrow.Input("autogrow", template=template) + ], + outputs=[io.String.Output()], + ) + + @classmethod + def execute(cls, autogrow: _io.Autogrow.Type) -> io.NodeOutput: + vals = list(autogrow.values()) + combined = ",".join([str(x) for x in vals]) + return io.NodeOutput(combined) + +class AutogrowPrefixTestNode(io.ComfyNode): + @classmethod + def define_schema(cls): + template = _io.Autogrow.TemplatePrefix(input=io.Float.Input("float"), prefix="float", min=1, max=10) + return io.Schema( + node_id="AutogrowPrefixTestNode", + display_name="AutogrowPrefixTest", + category="logic", + inputs=[ + _io.Autogrow.Input("autogrow", template=template) + ], + outputs=[io.String.Output()], + ) + + @classmethod + def execute(cls, autogrow: _io.Autogrow.Type) -> io.NodeOutput: + vals = list(autogrow.values()) + combined = ",".join([str(x) for x in vals]) + return io.NodeOutput(combined) + +class LogicExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + # SwitchNode, + # DCTestNode, + # AutogrowNamesTestNode, + # AutogrowPrefixTestNode, + ] + +async def comfy_entrypoint() -> LogicExtension: + return LogicExtension() diff --git a/execution.py b/execution.py index 17c77beab..c2186ac98 100644 --- a/execution.py +++ b/execution.py @@ -34,7 +34,7 @@ from comfy_execution.validation import validate_node_input from comfy_execution.progress import get_progress_state, reset_progress_state, add_progress_handler, WebUIProgressHandler from comfy_execution.utils import CurrentNodeContext from comfy_api.internal import _ComfyNodeInternal, _NodeOutputInternal, first_real_override, is_class, make_locked_method_func -from comfy_api.latest import io +from comfy_api.latest import io, _io class ExecutionResult(Enum): @@ -76,7 +76,7 @@ class IsChangedCache: return self.is_changed[node_id] # Intentionally do not use cached outputs here. We only want constants in IS_CHANGED - input_data_all, _, hidden_inputs = get_input_data(node["inputs"], class_def, node_id, None) + input_data_all, _, v3_data = get_input_data(node["inputs"], class_def, node_id, None) try: is_changed = await _async_map_node_over_list(self.prompt_id, node_id, class_def, input_data_all, is_changed_name) is_changed = await resolve_map_node_over_list_results(is_changed) @@ -146,8 +146,9 @@ SENSITIVE_EXTRA_DATA_KEYS = ("auth_token_comfy_org", "api_key_comfy_org") def get_input_data(inputs, class_def, unique_id, execution_list=None, dynprompt=None, extra_data={}): is_v3 = issubclass(class_def, _ComfyNodeInternal) + v3_data: io.V3Data = {} if is_v3: - valid_inputs, schema = class_def.INPUT_TYPES(include_hidden=False, return_schema=True) + valid_inputs, schema, v3_data = class_def.INPUT_TYPES(include_hidden=False, return_schema=True, live_inputs=inputs) else: valid_inputs = class_def.INPUT_TYPES() input_data_all = {} @@ -207,7 +208,8 @@ def get_input_data(inputs, class_def, unique_id, execution_list=None, dynprompt= input_data_all[x] = [extra_data.get("auth_token_comfy_org", None)] if h[x] == "API_KEY_COMFY_ORG": input_data_all[x] = [extra_data.get("api_key_comfy_org", None)] - return input_data_all, missing_keys, hidden_inputs_v3 + v3_data["hidden_inputs"] = hidden_inputs_v3 + return input_data_all, missing_keys, v3_data map_node_over_list = None #Don't hook this please @@ -223,7 +225,7 @@ async def resolve_map_node_over_list_results(results): raise exc return [x.result() if isinstance(x, asyncio.Task) else x for x in results] -async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None, hidden_inputs=None): +async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None, v3_data=None): # check if node wants the lists input_is_list = getattr(obj, "INPUT_IS_LIST", False) @@ -259,13 +261,16 @@ async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, f if is_class(obj): type_obj = obj obj.VALIDATE_CLASS() - class_clone = obj.PREPARE_CLASS_CLONE(hidden_inputs) + class_clone = obj.PREPARE_CLASS_CLONE(v3_data) # otherwise, use class instance to populate/reuse some fields else: type_obj = type(obj) type_obj.VALIDATE_CLASS() - class_clone = type_obj.PREPARE_CLASS_CLONE(hidden_inputs) + class_clone = type_obj.PREPARE_CLASS_CLONE(v3_data) f = make_locked_method_func(type_obj, func, class_clone) + # in case of dynamic inputs, restructure inputs to expected nested dict + if v3_data is not None: + inputs = _io.build_nested_inputs(inputs, v3_data) # V1 else: f = getattr(obj, func) @@ -320,8 +325,8 @@ def merge_result_data(results, obj): output.append([o[i] for o in results]) return output -async def get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=None, pre_execute_cb=None, hidden_inputs=None): - return_values = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, hidden_inputs=hidden_inputs) +async def get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=None, pre_execute_cb=None, v3_data=None): + return_values = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data) has_pending_task = any(isinstance(r, asyncio.Task) and not r.done() for r in return_values) if has_pending_task: return return_values, {}, False, has_pending_task @@ -460,7 +465,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, has_subgraph = False else: get_progress_state().start_progress(unique_id) - input_data_all, missing_keys, hidden_inputs = get_input_data(inputs, class_def, unique_id, execution_list, dynprompt, extra_data) + input_data_all, missing_keys, v3_data = get_input_data(inputs, class_def, unique_id, execution_list, dynprompt, extra_data) if server.client_id is not None: server.last_node_id = display_node_id server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id }, server.client_id) @@ -475,7 +480,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, else: lazy_status_present = getattr(obj, "check_lazy_status", None) is not None if lazy_status_present: - required_inputs = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, "check_lazy_status", allow_interrupt=True, hidden_inputs=hidden_inputs) + required_inputs = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, "check_lazy_status", allow_interrupt=True, v3_data=v3_data) required_inputs = await resolve_map_node_over_list_results(required_inputs) required_inputs = set(sum([r for r in required_inputs if isinstance(r,list)], [])) required_inputs = [x for x in required_inputs if isinstance(x,str) and ( @@ -507,7 +512,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, def pre_execute_cb(call_index): # TODO - How to handle this with async functions without contextvars (which requires Python 3.12)? GraphBuilder.set_default_prefix(unique_id, call_index, 0) - output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, hidden_inputs=hidden_inputs) + output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data) if has_pending_tasks: pending_async_nodes[unique_id] = output_data unblock = execution_list.add_external_block(unique_id) @@ -745,18 +750,17 @@ async def validate_inputs(prompt_id, prompt, item, validated): class_type = prompt[unique_id]['class_type'] obj_class = nodes.NODE_CLASS_MAPPINGS[class_type] - class_inputs = obj_class.INPUT_TYPES() - valid_inputs = set(class_inputs.get('required',{})).union(set(class_inputs.get('optional',{}))) - errors = [] valid = True validate_function_inputs = [] validate_has_kwargs = False if issubclass(obj_class, _ComfyNodeInternal): + class_inputs, _, _ = obj_class.INPUT_TYPES(include_hidden=False, return_schema=True, live_inputs=inputs) validate_function_name = "validate_inputs" validate_function = first_real_override(obj_class, validate_function_name) else: + class_inputs = obj_class.INPUT_TYPES() validate_function_name = "VALIDATE_INPUTS" validate_function = getattr(obj_class, validate_function_name, None) if validate_function is not None: @@ -765,6 +769,8 @@ async def validate_inputs(prompt_id, prompt, item, validated): validate_has_kwargs = argspec.varkw is not None received_types = {} + valid_inputs = set(class_inputs.get('required',{})).union(set(class_inputs.get('optional',{}))) + for x in valid_inputs: input_type, input_category, extra_info = get_input_info(obj_class, x, class_inputs) assert extra_info is not None @@ -935,7 +941,7 @@ async def validate_inputs(prompt_id, prompt, item, validated): continue if len(validate_function_inputs) > 0 or validate_has_kwargs: - input_data_all, _, hidden_inputs = get_input_data(inputs, obj_class, unique_id) + input_data_all, _, v3_data = get_input_data(inputs, obj_class, unique_id) input_filtered = {} for x in input_data_all: if x in validate_function_inputs or validate_has_kwargs: @@ -943,7 +949,7 @@ async def validate_inputs(prompt_id, prompt, item, validated): if 'input_types' in validate_function_inputs: input_filtered['input_types'] = [received_types] - ret = await _async_map_node_over_list(prompt_id, unique_id, obj_class, input_filtered, validate_function_name, hidden_inputs=hidden_inputs) + ret = await _async_map_node_over_list(prompt_id, unique_id, obj_class, input_filtered, validate_function_name, v3_data=v3_data) ret = await resolve_map_node_over_list_results(ret) for x in input_filtered: for i, r in enumerate(ret): diff --git a/nodes.py b/nodes.py index 4c910a34b..356aa63df 100644 --- a/nodes.py +++ b/nodes.py @@ -2355,6 +2355,7 @@ async def init_builtin_extra_nodes(): "nodes_easycache.py", "nodes_audio_encoder.py", "nodes_rope.py", + "nodes_logic.py", "nodes_nop.py", ] From 861817d22d2659099811b56005c9eaea18d64c73 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 2 Dec 2025 21:47:51 -0800 Subject: [PATCH 17/29] Fix issue with portable updater. (#11070) This should fix the problem with the portable updater not working with portables created from a separate branch on the repo. This does not affect any current portables who were all created on the master branch. --- .ci/update_windows/update.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/.ci/update_windows/update.py b/.ci/update_windows/update.py index 51a263203..59ece5130 100755 --- a/.ci/update_windows/update.py +++ b/.ci/update_windows/update.py @@ -66,8 +66,10 @@ if branch is None: try: ref = repo.lookup_reference('refs/remotes/origin/master') except: - print("pulling.") # noqa: T201 - pull(repo) + print("fetching.") # noqa: T201 + for remote in repo.remotes: + if remote.name == "origin": + remote.fetch() ref = repo.lookup_reference('refs/remotes/origin/master') repo.checkout(ref) branch = repo.lookup_branch('master') @@ -149,3 +151,4 @@ try: shutil.copy(stable_update_script, stable_update_script_to) except: pass + From 519c9411653df99761053c30e101816e0ca3c24b Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Wed, 3 Dec 2025 17:28:45 +1000 Subject: [PATCH 18/29] Prs/lora reservations (reduce massive Lora reservations especially on Flux2) (#11069) * mp: only count the offload cost of math once This was previously bundling the combined weight storage and computation cost * ops: put all post async transfer compute on the main stream Some models have massive weights that need either complex dequantization or lora patching. Don't do these patchings on the offload stream, instead do them on the main stream to syncrhonize the potentially large vram spikes for these compute processes. This avoids having to assume a worst case scenario of multiple offload streams all spiking VRAM is parallel with whatever the main stream is doing. --- comfy/model_patcher.py | 4 ++-- comfy/ops.py | 39 ++++++++++++++++++++++----------------- 2 files changed, 24 insertions(+), 19 deletions(-) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 3eac77275..df2d8e827 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -704,7 +704,7 @@ class ModelPatcher: lowvram_weight = False - potential_offload = max(offload_buffer, module_offload_mem * (comfy.model_management.NUM_STREAMS + 1)) + potential_offload = max(offload_buffer, module_offload_mem + (comfy.model_management.NUM_STREAMS * module_mem)) lowvram_fits = mem_counter + module_mem + potential_offload < lowvram_model_memory weight_key = "{}.weight".format(n) @@ -883,7 +883,7 @@ class ModelPatcher: break module_offload_mem, module_mem, n, m, params = unload - potential_offload = (comfy.model_management.NUM_STREAMS + 1) * module_offload_mem + potential_offload = module_offload_mem + (comfy.model_management.NUM_STREAMS * module_mem) lowvram_possible = hasattr(m, "comfy_cast_weights") if hasattr(m, "comfy_patched_weights") and m.comfy_patched_weights == True: diff --git a/comfy/ops.py b/comfy/ops.py index 61a2f0754..eae434e68 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -111,22 +111,24 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of if s.bias is not None: bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=bias_has_function, stream=offload_stream) - if bias_has_function: - with wf_context: - for f in s.bias_function: - bias = f(bias) + comfy.model_management.sync_stream(device, offload_stream) + + bias_a = bias + weight_a = weight + + if s.bias is not None: + for f in s.bias_function: + bias = f(bias) if weight_has_function or weight.dtype != dtype: - with wf_context: - weight = weight.to(dtype=dtype) - if isinstance(weight, QuantizedTensor): - weight = weight.dequantize() - for f in s.weight_function: - weight = f(weight) + weight = weight.to(dtype=dtype) + if isinstance(weight, QuantizedTensor): + weight = weight.dequantize() + for f in s.weight_function: + weight = f(weight) - comfy.model_management.sync_stream(device, offload_stream) if offloadable: - return weight, bias, offload_stream + return weight, bias, (offload_stream, weight_a, bias_a) else: #Legacy function signature return weight, bias @@ -135,13 +137,16 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of def uncast_bias_weight(s, weight, bias, offload_stream): if offload_stream is None: return - if weight is not None: - device = weight.device + os, weight_a, bias_a = offload_stream + if os is None: + return + if weight_a is not None: + device = weight_a.device else: - if bias is None: + if bias_a is None: return - device = bias.device - offload_stream.wait_stream(comfy.model_management.current_stream(device)) + device = bias_a.device + os.wait_stream(comfy.model_management.current_stream(device)) class CastWeightBiasOp: From 19f2192d69d13445131b72ad1d87167f59b66fc4 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Wed, 3 Dec 2025 18:37:35 +0200 Subject: [PATCH 19/29] fix(V3-Schema): use empty list defaults for Schema.inputs/outputs/hidden to avoid None issues (#11083) --- comfy_api/latest/_io.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/comfy_api/latest/_io.py b/comfy_api/latest/_io.py index 257f07c42..866c3e0eb 100644 --- a/comfy_api/latest/_io.py +++ b/comfy_api/latest/_io.py @@ -5,7 +5,7 @@ import inspect from abc import ABC, abstractmethod from collections import Counter from collections.abc import Iterable -from dataclasses import asdict, dataclass +from dataclasses import asdict, dataclass, field from enum import Enum from typing import Any, Callable, Literal, TypedDict, TypeVar, TYPE_CHECKING from typing_extensions import NotRequired, final @@ -1199,9 +1199,9 @@ class Schema: """Display name of node.""" category: str = "sd" """The category of the node, as per the "Add Node" menu.""" - inputs: list[Input]=None - outputs: list[Output]=None - hidden: list[Hidden]=None + inputs: list[Input] = field(default_factory=list) + outputs: list[Output] = field(default_factory=list) + hidden: list[Hidden] = field(default_factory=list) description: str="" """Node description, shown as a tooltip when hovering over the node.""" is_input_list: bool = False From 87c104bfc1928f0b018a50f5867f425e10482929 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Wed, 3 Dec 2025 18:55:44 +0200 Subject: [PATCH 20/29] add support for "@image" reference format in Kling Omni API nodes (#11082) --- comfy_api_nodes/apis/kling_api.py | 30 +++++-- comfy_api_nodes/nodes_kling.py | 138 ++++++++++++++++++++++++++++-- 2 files changed, 155 insertions(+), 13 deletions(-) diff --git a/comfy_api_nodes/apis/kling_api.py b/comfy_api_nodes/apis/kling_api.py index 0a3b447c5..d8949f8ac 100644 --- a/comfy_api_nodes/apis/kling_api.py +++ b/comfy_api_nodes/apis/kling_api.py @@ -46,21 +46,41 @@ class TaskStatusVideoResult(BaseModel): url: str | None = Field(None, description="URL for generated video") -class TaskStatusVideoResults(BaseModel): +class TaskStatusImageResult(BaseModel): + index: int = Field(..., description="Image Number,0-9") + url: str = Field(..., description="URL for generated image") + + +class OmniTaskStatusResults(BaseModel): videos: list[TaskStatusVideoResult] | None = Field(None) + images: list[TaskStatusImageResult] | None = Field(None) -class TaskStatusVideoResponseData(BaseModel): +class OmniTaskStatusResponseData(BaseModel): created_at: int | None = Field(None, description="Task creation time") updated_at: int | None = Field(None, description="Task update time") task_status: str | None = None task_status_msg: str | None = Field(None, description="Additional failure reason. Only for polling endpoint.") task_id: str | None = Field(None, description="Task ID") - task_result: TaskStatusVideoResults | None = Field(None) + task_result: OmniTaskStatusResults | None = Field(None) -class TaskStatusVideoResponse(BaseModel): +class OmniTaskStatusResponse(BaseModel): code: int | None = Field(None, description="Error code") message: str | None = Field(None, description="Error message") request_id: str | None = Field(None, description="Request ID") - data: TaskStatusVideoResponseData | None = Field(None) + data: OmniTaskStatusResponseData | None = Field(None) + + +class OmniImageParamImage(BaseModel): + image: str = Field(...) + + +class OmniProImageRequest(BaseModel): + model_name: str = Field(..., description="kling-image-o1") + resolution: str = Field(..., description="'1k' or '2k'") + aspect_ratio: str | None = Field(...) + prompt: str = Field(...) + mode: str = Field("pro") + n: int | None = Field(1, le=9) + image_list: list[OmniImageParamImage] | None = Field(..., max_length=10) diff --git a/comfy_api_nodes/nodes_kling.py b/comfy_api_nodes/nodes_kling.py index 850c44db6..6c840dc47 100644 --- a/comfy_api_nodes/nodes_kling.py +++ b/comfy_api_nodes/nodes_kling.py @@ -6,6 +6,7 @@ For source of truth on the allowed permutations of request fields, please refere import logging import math +import re import torch from typing_extensions import override @@ -49,12 +50,14 @@ from comfy_api_nodes.apis import ( KlingSingleImageEffectModelName, ) from comfy_api_nodes.apis.kling_api import ( + OmniImageParamImage, OmniParamImage, OmniParamVideo, OmniProFirstLastFrameRequest, + OmniProImageRequest, OmniProReferences2VideoRequest, OmniProText2VideoRequest, - TaskStatusVideoResponse, + OmniTaskStatusResponse, ) from comfy_api_nodes.util import ( ApiEndpoint, @@ -210,7 +213,36 @@ VOICES_CONFIG = { } -async def finish_omni_video_task(cls: type[IO.ComfyNode], response: TaskStatusVideoResponse) -> IO.NodeOutput: +def normalize_omni_prompt_references(prompt: str) -> str: + """ + Rewrites Kling Omni-style placeholders used in the app, like: + + @image, @image1, @image2, ... @imageN + @video, @video1, @video2, ... @videoN + + into the API-compatible form: + + <<>>, <<>>, ... + <<>>, <<>>, ... + + This is a UX shim for ComfyUI so users can type the same syntax as in the Kling app. + """ + if not prompt: + return prompt + + def _image_repl(match): + return f"<<>>" + + def _video_repl(match): + return f"<<>>" + + # (? and not @imageFoo + prompt = re.sub(r"(?\d*)(?!\w)", _image_repl, prompt) + return re.sub(r"(?\d*)(?!\w)", _video_repl, prompt) + + +async def finish_omni_video_task(cls: type[IO.ComfyNode], response: OmniTaskStatusResponse) -> IO.NodeOutput: if response.code: raise RuntimeError( f"Kling request failed. Code: {response.code}, Message: {response.message}, Data: {response.data}" @@ -218,8 +250,9 @@ async def finish_omni_video_task(cls: type[IO.ComfyNode], response: TaskStatusVi final_response = await poll_op( cls, ApiEndpoint(path=f"/proxy/kling/v1/videos/omni-video/{response.data.task_id}"), - response_model=TaskStatusVideoResponse, + response_model=OmniTaskStatusResponse, status_extractor=lambda r: (r.data.task_status if r.data else None), + max_poll_attempts=160, ) return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url)) @@ -801,7 +834,7 @@ class OmniProTextToVideoNode(IO.ComfyNode): response = await sync_op( cls, ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"), - response_model=TaskStatusVideoResponse, + response_model=OmniTaskStatusResponse, data=OmniProText2VideoRequest( model_name=model_name, prompt=prompt, @@ -864,6 +897,7 @@ class OmniProFirstLastFrameNode(IO.ComfyNode): end_frame: Input.Image | None = None, reference_images: Input.Image | None = None, ) -> IO.NodeOutput: + prompt = normalize_omni_prompt_references(prompt) validate_string(prompt, min_length=1, max_length=2500) if end_frame is not None and reference_images is not None: raise ValueError("The 'end_frame' input cannot be used simultaneously with 'reference_images'.") @@ -895,7 +929,7 @@ class OmniProFirstLastFrameNode(IO.ComfyNode): response = await sync_op( cls, ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"), - response_model=TaskStatusVideoResponse, + response_model=OmniTaskStatusResponse, data=OmniProFirstLastFrameRequest( model_name=model_name, prompt=prompt, @@ -950,6 +984,7 @@ class OmniProImageToVideoNode(IO.ComfyNode): duration: int, reference_images: Input.Image, ) -> IO.NodeOutput: + prompt = normalize_omni_prompt_references(prompt) validate_string(prompt, min_length=1, max_length=2500) if get_number_of_images(reference_images) > 7: raise ValueError("The maximum number of reference images is 7.") @@ -962,7 +997,7 @@ class OmniProImageToVideoNode(IO.ComfyNode): response = await sync_op( cls, ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"), - response_model=TaskStatusVideoResponse, + response_model=OmniTaskStatusResponse, data=OmniProReferences2VideoRequest( model_name=model_name, prompt=prompt, @@ -1023,6 +1058,7 @@ class OmniProVideoToVideoNode(IO.ComfyNode): keep_original_sound: bool, reference_images: Input.Image | None = None, ) -> IO.NodeOutput: + prompt = normalize_omni_prompt_references(prompt) validate_string(prompt, min_length=1, max_length=2500) validate_video_duration(reference_video, min_duration=3.0, max_duration=10.05) validate_video_dimensions(reference_video, min_width=720, min_height=720, max_width=2160, max_height=2160) @@ -1045,7 +1081,7 @@ class OmniProVideoToVideoNode(IO.ComfyNode): response = await sync_op( cls, ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"), - response_model=TaskStatusVideoResponse, + response_model=OmniTaskStatusResponse, data=OmniProReferences2VideoRequest( model_name=model_name, prompt=prompt, @@ -1103,6 +1139,7 @@ class OmniProEditVideoNode(IO.ComfyNode): keep_original_sound: bool, reference_images: Input.Image | None = None, ) -> IO.NodeOutput: + prompt = normalize_omni_prompt_references(prompt) validate_string(prompt, min_length=1, max_length=2500) validate_video_duration(video, min_duration=3.0, max_duration=10.05) validate_video_dimensions(video, min_width=720, min_height=720, max_width=2160, max_height=2160) @@ -1125,7 +1162,7 @@ class OmniProEditVideoNode(IO.ComfyNode): response = await sync_op( cls, ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"), - response_model=TaskStatusVideoResponse, + response_model=OmniTaskStatusResponse, data=OmniProReferences2VideoRequest( model_name=model_name, prompt=prompt, @@ -1138,6 +1175,90 @@ class OmniProEditVideoNode(IO.ComfyNode): return await finish_omni_video_task(cls, response) +class OmniProImageNode(IO.ComfyNode): + + @classmethod + def define_schema(cls) -> IO.Schema: + return IO.Schema( + node_id="KlingOmniProImageNode", + display_name="Kling Omni Image (Pro)", + category="api node/image/Kling", + description="Create or edit images with the latest model from Kling.", + inputs=[ + IO.Combo.Input("model_name", options=["kling-image-o1"]), + IO.String.Input( + "prompt", + multiline=True, + tooltip="A text prompt describing the image content. " + "This can include both positive and negative descriptions.", + ), + IO.Combo.Input("resolution", options=["1K", "2K"]), + IO.Combo.Input( + "aspect_ratio", + options=["16:9", "9:16", "1:1", "4:3", "3:4", "3:2", "2:3", "21:9"], + ), + IO.Image.Input( + "reference_images", + tooltip="Up to 10 additional reference images.", + optional=True, + ), + ], + outputs=[ + IO.Image.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, + model_name: str, + prompt: str, + resolution: str, + aspect_ratio: str, + reference_images: Input.Image | None = None, + ) -> IO.NodeOutput: + prompt = normalize_omni_prompt_references(prompt) + validate_string(prompt, min_length=1, max_length=2500) + image_list: list[OmniImageParamImage] = [] + if reference_images is not None: + if get_number_of_images(reference_images) > 10: + raise ValueError("The maximum number of reference images is 10.") + for i in reference_images: + validate_image_dimensions(i, min_width=300, min_height=300) + validate_image_aspect_ratio(i, (1, 2.5), (2.5, 1)) + for i in await upload_images_to_comfyapi(cls, reference_images, wait_label="Uploading reference image"): + image_list.append(OmniImageParamImage(image=i)) + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/kling/v1/images/omni-image", method="POST"), + response_model=OmniTaskStatusResponse, + data=OmniProImageRequest( + model_name=model_name, + prompt=prompt, + resolution=resolution.lower(), + aspect_ratio=aspect_ratio, + image_list=image_list if image_list else None, + ), + ) + if response.code: + raise RuntimeError( + f"Kling request failed. Code: {response.code}, Message: {response.message}, Data: {response.data}" + ) + final_response = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/kling/v1/images/omni-image/{response.data.task_id}"), + response_model=OmniTaskStatusResponse, + status_extractor=lambda r: (r.data.task_status if r.data else None), + ) + return IO.NodeOutput(await download_url_to_image_tensor(final_response.data.task_result.images[0].url)) + + class KlingCameraControlT2VNode(IO.ComfyNode): """ Kling Text to Video Camera Control Node. This node is a text to video node, but it supports controlling the camera. @@ -1935,6 +2056,7 @@ class KlingExtension(ComfyExtension): OmniProImageToVideoNode, OmniProVideoToVideoNode, OmniProEditVideoNode, + # OmniProImageNode, # need support from backend ] From 440268d3940eb14a01595439bbc05c4aacde9c72 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Wed, 3 Dec 2025 23:52:31 +0200 Subject: [PATCH 21/29] convert nodes_load_3d.py to V3 schema (#10990) --- comfy_api/latest/_ui.py | 13 +++- comfy_extras/nodes_load_3d.py | 127 ++++++++++++++++------------------ 2 files changed, 71 insertions(+), 69 deletions(-) diff --git a/comfy_api/latest/_ui.py b/comfy_api/latest/_ui.py index b0bbabe2a..6d1bea599 100644 --- a/comfy_api/latest/_ui.py +++ b/comfy_api/latest/_ui.py @@ -3,6 +3,7 @@ from __future__ import annotations import json import os import random +import uuid from io import BytesIO from typing import Type @@ -436,9 +437,19 @@ class PreviewUI3D(_UIOutput): def __init__(self, model_file, camera_info, **kwargs): self.model_file = model_file self.camera_info = camera_info + self.bg_image_path = None + bg_image = kwargs.get("bg_image", None) + if bg_image is not None: + img_array = (bg_image[0].cpu().numpy() * 255).astype(np.uint8) + img = PILImage.fromarray(img_array) + temp_dir = folder_paths.get_temp_directory() + filename = f"bg_{uuid.uuid4().hex}.png" + bg_image_path = os.path.join(temp_dir, filename) + img.save(bg_image_path, compress_level=1) + self.bg_image_path = f"temp/{filename}" def as_dict(self): - return {"result": [self.model_file, self.camera_info]} + return {"result": [self.model_file, self.camera_info, self.bg_image_path]} class PreviewText(_UIOutput): diff --git a/comfy_extras/nodes_load_3d.py b/comfy_extras/nodes_load_3d.py index 54c66ef68..545588ef8 100644 --- a/comfy_extras/nodes_load_3d.py +++ b/comfy_extras/nodes_load_3d.py @@ -2,22 +2,18 @@ import nodes import folder_paths import os -from comfy.comfy_types import IO -from comfy_api.input_impl import VideoFromFile +from typing_extensions import override +from comfy_api.latest import IO, ComfyExtension, InputImpl, UI from pathlib import Path -from PIL import Image -import numpy as np - -import uuid def normalize_path(path): return path.replace('\\', '/') -class Load3D(): +class Load3D(IO.ComfyNode): @classmethod - def INPUT_TYPES(s): + def define_schema(cls): input_dir = os.path.join(folder_paths.get_input_directory(), "3d") os.makedirs(input_dir, exist_ok=True) @@ -30,23 +26,29 @@ class Load3D(): for file_path in input_path.rglob("*") if file_path.suffix.lower() in {'.gltf', '.glb', '.obj', '.fbx', '.stl'} ] + return IO.Schema( + node_id="Load3D", + display_name="Load 3D & Animation", + category="3d", + is_experimental=True, + inputs=[ + IO.Combo.Input("model_file", options=sorted(files), upload=IO.UploadType.model), + IO.Load3D.Input("image"), + IO.Int.Input("width", default=1024, min=1, max=4096, step=1), + IO.Int.Input("height", default=1024, min=1, max=4096, step=1), + ], + outputs=[ + IO.Image.Output(display_name="image"), + IO.Mask.Output(display_name="mask"), + IO.String.Output(display_name="mesh_path"), + IO.Image.Output(display_name="normal"), + IO.Load3DCamera.Output(display_name="camera_info"), + IO.Video.Output(display_name="recording_video"), + ], + ) - return {"required": { - "model_file": (sorted(files), {"file_upload": True}), - "image": ("LOAD_3D", {}), - "width": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}), - "height": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}), - }} - - RETURN_TYPES = ("IMAGE", "MASK", "STRING", "IMAGE", "LOAD3D_CAMERA", IO.VIDEO) - RETURN_NAMES = ("image", "mask", "mesh_path", "normal", "camera_info", "recording_video") - - FUNCTION = "process" - EXPERIMENTAL = True - - CATEGORY = "3d" - - def process(self, model_file, image, **kwargs): + @classmethod + def execute(cls, model_file, image, **kwargs) -> IO.NodeOutput: image_path = folder_paths.get_annotated_filepath(image['image']) mask_path = folder_paths.get_annotated_filepath(image['mask']) normal_path = folder_paths.get_annotated_filepath(image['normal']) @@ -61,58 +63,47 @@ class Load3D(): if image['recording'] != "": recording_video_path = folder_paths.get_annotated_filepath(image['recording']) - video = VideoFromFile(recording_video_path) + video = InputImpl.VideoFromFile(recording_video_path) - return output_image, output_mask, model_file, normal_image, image['camera_info'], video + return IO.NodeOutput(output_image, output_mask, model_file, normal_image, image['camera_info'], video) -class Preview3D(): + process = execute # TODO: remove + + +class Preview3D(IO.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { - "model_file": ("STRING", {"default": "", "multiline": False}), - }, - "optional": { - "camera_info": ("LOAD3D_CAMERA", {}), - "bg_image": ("IMAGE", {}) - }} + def define_schema(cls): + return IO.Schema( + node_id="Preview3D", + display_name="Preview 3D & Animation", + category="3d", + is_experimental=True, + is_output_node=True, + inputs=[ + IO.String.Input("model_file", default="", multiline=False), + IO.Load3DCamera.Input("camera_info", optional=True), + IO.Image.Input("bg_image", optional=True), + ], + outputs=[], + ) - OUTPUT_NODE = True - RETURN_TYPES = () - - CATEGORY = "3d" - - FUNCTION = "process" - EXPERIMENTAL = True - - def process(self, model_file, **kwargs): + @classmethod + def execute(cls, model_file, **kwargs) -> IO.NodeOutput: camera_info = kwargs.get("camera_info", None) bg_image = kwargs.get("bg_image", None) + return IO.NodeOutput(ui=UI.PreviewUI3D(model_file, camera_info, bg_image=bg_image)) - bg_image_path = None - if bg_image is not None: + process = execute # TODO: remove - img_array = (bg_image[0].cpu().numpy() * 255).astype(np.uint8) - img = Image.fromarray(img_array) - temp_dir = folder_paths.get_temp_directory() - filename = f"bg_{uuid.uuid4().hex}.png" - bg_image_path = os.path.join(temp_dir, filename) - img.save(bg_image_path, compress_level=1) +class Load3DExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[IO.ComfyNode]]: + return [ + Load3D, + Preview3D, + ] - bg_image_path = f"temp/{filename}" - return { - "ui": { - "result": [model_file, camera_info, bg_image_path] - } - } - -NODE_CLASS_MAPPINGS = { - "Load3D": Load3D, - "Preview3D": Preview3D, -} - -NODE_DISPLAY_NAME_MAPPINGS = { - "Load3D": "Load 3D & Animation", - "Preview3D": "Preview 3D & Animation", -} +async def comfy_entrypoint() -> Load3DExtension: + return Load3DExtension() From dce518c2b4f99634b5fdde1924d9b0bd468fe1ce Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Thu, 4 Dec 2025 03:35:04 +0200 Subject: [PATCH 22/29] convert nodes_audio.py to V3 schema (#10798) --- comfy_api/latest/_ui.py | 9 +- comfy_extras/nodes_audio.py | 744 ++++++++++++++++++------------------ 2 files changed, 382 insertions(+), 371 deletions(-) diff --git a/comfy_api/latest/_ui.py b/comfy_api/latest/_ui.py index 6d1bea599..5a75a3aae 100644 --- a/comfy_api/latest/_ui.py +++ b/comfy_api/latest/_ui.py @@ -319,9 +319,10 @@ class AudioSaveHelper: for key, value in metadata.items(): output_container.metadata[key] = value + layout = "mono" if waveform.shape[0] == 1 else "stereo" # Set up the output stream with appropriate properties if format == "opus": - out_stream = output_container.add_stream("libopus", rate=sample_rate) + out_stream = output_container.add_stream("libopus", rate=sample_rate, layout=layout) if quality == "64k": out_stream.bit_rate = 64000 elif quality == "96k": @@ -333,7 +334,7 @@ class AudioSaveHelper: elif quality == "320k": out_stream.bit_rate = 320000 elif format == "mp3": - out_stream = output_container.add_stream("libmp3lame", rate=sample_rate) + out_stream = output_container.add_stream("libmp3lame", rate=sample_rate, layout=layout) if quality == "V0": # TODO i would really love to support V3 and V5 but there doesn't seem to be a way to set the qscale level, the property below is a bool out_stream.codec_context.qscale = 1 @@ -342,12 +343,12 @@ class AudioSaveHelper: elif quality == "320k": out_stream.bit_rate = 320000 else: # format == "flac": - out_stream = output_container.add_stream("flac", rate=sample_rate) + out_stream = output_container.add_stream("flac", rate=sample_rate, layout=layout) frame = av.AudioFrame.from_ndarray( waveform.movedim(0, 1).reshape(1, -1).float().numpy(), format="flt", - layout="mono" if waveform.shape[0] == 1 else "stereo", + layout=layout, ) frame.sample_rate = sample_rate frame.pts = 0 diff --git a/comfy_extras/nodes_audio.py b/comfy_extras/nodes_audio.py index 2ed7e0b22..812301fb7 100644 --- a/comfy_extras/nodes_audio.py +++ b/comfy_extras/nodes_audio.py @@ -6,65 +6,80 @@ import torch import comfy.model_management import folder_paths import os -import io -import json -import random import hashlib import node_helpers import logging -from comfy.cli_args import args -from comfy.comfy_types import FileLocator +from typing_extensions import override +from comfy_api.latest import ComfyExtension, IO, UI -class EmptyLatentAudio: - def __init__(self): - self.device = comfy.model_management.intermediate_device() +class EmptyLatentAudio(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="EmptyLatentAudio", + display_name="Empty Latent Audio", + category="latent/audio", + inputs=[ + IO.Float.Input("seconds", default=47.6, min=1.0, max=1000.0, step=0.1), + IO.Int.Input( + "batch_size", default=1, min=1, max=4096, tooltip="The number of latent images in the batch." + ), + ], + outputs=[IO.Latent.Output()], + ) @classmethod - def INPUT_TYPES(s): - return {"required": {"seconds": ("FLOAT", {"default": 47.6, "min": 1.0, "max": 1000.0, "step": 0.1}), - "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096, "tooltip": "The number of latent images in the batch."}), - }} - RETURN_TYPES = ("LATENT",) - FUNCTION = "generate" - - CATEGORY = "latent/audio" - - def generate(self, seconds, batch_size): + def execute(cls, seconds, batch_size) -> IO.NodeOutput: length = round((seconds * 44100 / 2048) / 2) * 2 - latent = torch.zeros([batch_size, 64, length], device=self.device) - return ({"samples":latent, "type": "audio"}, ) + latent = torch.zeros([batch_size, 64, length], device=comfy.model_management.intermediate_device()) + return IO.NodeOutput({"samples":latent, "type": "audio"}) -class ConditioningStableAudio: + generate = execute # TODO: remove + + +class ConditioningStableAudio(IO.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": {"positive": ("CONDITIONING", ), - "negative": ("CONDITIONING", ), - "seconds_start": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1000.0, "step": 0.1}), - "seconds_total": ("FLOAT", {"default": 47.0, "min": 0.0, "max": 1000.0, "step": 0.1}), - }} + def define_schema(cls): + return IO.Schema( + node_id="ConditioningStableAudio", + category="conditioning", + inputs=[ + IO.Conditioning.Input("positive"), + IO.Conditioning.Input("negative"), + IO.Float.Input("seconds_start", default=0.0, min=0.0, max=1000.0, step=0.1), + IO.Float.Input("seconds_total", default=47.0, min=0.0, max=1000.0, step=0.1), + ], + outputs=[ + IO.Conditioning.Output(display_name="positive"), + IO.Conditioning.Output(display_name="negative"), + ], + ) - RETURN_TYPES = ("CONDITIONING","CONDITIONING") - RETURN_NAMES = ("positive", "negative") - - FUNCTION = "append" - - CATEGORY = "conditioning" - - def append(self, positive, negative, seconds_start, seconds_total): + @classmethod + def execute(cls, positive, negative, seconds_start, seconds_total) -> IO.NodeOutput: positive = node_helpers.conditioning_set_values(positive, {"seconds_start": seconds_start, "seconds_total": seconds_total}) negative = node_helpers.conditioning_set_values(negative, {"seconds_start": seconds_start, "seconds_total": seconds_total}) - return (positive, negative) + return IO.NodeOutput(positive, negative) -class VAEEncodeAudio: + append = execute # TODO: remove + + +class VAEEncodeAudio(IO.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "audio": ("AUDIO", ), "vae": ("VAE", )}} - RETURN_TYPES = ("LATENT",) - FUNCTION = "encode" + def define_schema(cls): + return IO.Schema( + node_id="VAEEncodeAudio", + display_name="VAE Encode Audio", + category="latent/audio", + inputs=[ + IO.Audio.Input("audio"), + IO.Vae.Input("vae"), + ], + outputs=[IO.Latent.Output()], + ) - CATEGORY = "latent/audio" - - def encode(self, vae, audio): + @classmethod + def execute(cls, vae, audio) -> IO.NodeOutput: sample_rate = audio["sample_rate"] if 44100 != sample_rate: waveform = torchaudio.functional.resample(audio["waveform"], sample_rate, 44100) @@ -72,213 +87,134 @@ class VAEEncodeAudio: waveform = audio["waveform"] t = vae.encode(waveform.movedim(1, -1)) - return ({"samples":t}, ) + return IO.NodeOutput({"samples":t}) -class VAEDecodeAudio: + encode = execute # TODO: remove + + +class VAEDecodeAudio(IO.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "samples": ("LATENT", ), "vae": ("VAE", )}} - RETURN_TYPES = ("AUDIO",) - FUNCTION = "decode" + def define_schema(cls): + return IO.Schema( + node_id="VAEDecodeAudio", + display_name="VAE Decode Audio", + category="latent/audio", + inputs=[ + IO.Latent.Input("samples"), + IO.Vae.Input("vae"), + ], + outputs=[IO.Audio.Output()], + ) - CATEGORY = "latent/audio" - - def decode(self, vae, samples): + @classmethod + def execute(cls, vae, samples) -> IO.NodeOutput: audio = vae.decode(samples["samples"]).movedim(-1, 1) std = torch.std(audio, dim=[1,2], keepdim=True) * 5.0 std[std < 1.0] = 1.0 audio /= std - return ({"waveform": audio, "sample_rate": 44100}, ) + return IO.NodeOutput({"waveform": audio, "sample_rate": 44100}) + + decode = execute # TODO: remove -def save_audio(self, audio, filename_prefix="ComfyUI", format="flac", prompt=None, extra_pnginfo=None, quality="128k"): - - filename_prefix += self.prefix_append - full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir) - results: list[FileLocator] = [] - - # Prepare metadata dictionary - metadata = {} - if not args.disable_metadata: - if prompt is not None: - metadata["prompt"] = json.dumps(prompt) - if extra_pnginfo is not None: - for x in extra_pnginfo: - metadata[x] = json.dumps(extra_pnginfo[x]) - - # Opus supported sample rates - OPUS_RATES = [8000, 12000, 16000, 24000, 48000] - - for (batch_number, waveform) in enumerate(audio["waveform"].cpu()): - filename_with_batch_num = filename.replace("%batch_num%", str(batch_number)) - file = f"{filename_with_batch_num}_{counter:05}_.{format}" - output_path = os.path.join(full_output_folder, file) - - # Use original sample rate initially - sample_rate = audio["sample_rate"] - - # Handle Opus sample rate requirements - if format == "opus": - if sample_rate > 48000: - sample_rate = 48000 - elif sample_rate not in OPUS_RATES: - # Find the next highest supported rate - for rate in sorted(OPUS_RATES): - if rate > sample_rate: - sample_rate = rate - break - if sample_rate not in OPUS_RATES: # Fallback if still not supported - sample_rate = 48000 - - # Resample if necessary - if sample_rate != audio["sample_rate"]: - waveform = torchaudio.functional.resample(waveform, audio["sample_rate"], sample_rate) - - # Create output with specified format - output_buffer = io.BytesIO() - output_container = av.open(output_buffer, mode='w', format=format) - - # Set metadata on the container - for key, value in metadata.items(): - output_container.metadata[key] = value - - layout = 'mono' if waveform.shape[0] == 1 else 'stereo' - # Set up the output stream with appropriate properties - if format == "opus": - out_stream = output_container.add_stream("libopus", rate=sample_rate, layout=layout) - if quality == "64k": - out_stream.bit_rate = 64000 - elif quality == "96k": - out_stream.bit_rate = 96000 - elif quality == "128k": - out_stream.bit_rate = 128000 - elif quality == "192k": - out_stream.bit_rate = 192000 - elif quality == "320k": - out_stream.bit_rate = 320000 - elif format == "mp3": - out_stream = output_container.add_stream("libmp3lame", rate=sample_rate, layout=layout) - if quality == "V0": - #TODO i would really love to support V3 and V5 but there doesn't seem to be a way to set the qscale level, the property below is a bool - out_stream.codec_context.qscale = 1 - elif quality == "128k": - out_stream.bit_rate = 128000 - elif quality == "320k": - out_stream.bit_rate = 320000 - else: #format == "flac": - out_stream = output_container.add_stream("flac", rate=sample_rate, layout=layout) - - frame = av.AudioFrame.from_ndarray(waveform.movedim(0, 1).reshape(1, -1).float().numpy(), format='flt', layout=layout) - frame.sample_rate = sample_rate - frame.pts = 0 - output_container.mux(out_stream.encode(frame)) - - # Flush encoder - output_container.mux(out_stream.encode(None)) - - # Close containers - output_container.close() - - # Write the output to file - output_buffer.seek(0) - with open(output_path, 'wb') as f: - f.write(output_buffer.getbuffer()) - - results.append({ - "filename": file, - "subfolder": subfolder, - "type": self.type - }) - counter += 1 - - return { "ui": { "audio": results } } - -class SaveAudio: - def __init__(self): - self.output_dir = folder_paths.get_output_directory() - self.type = "output" - self.prefix_append = "" +class SaveAudio(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="SaveAudio", + display_name="Save Audio (FLAC)", + category="audio", + inputs=[ + IO.Audio.Input("audio"), + IO.String.Input("filename_prefix", default="audio/ComfyUI"), + ], + hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo], + is_output_node=True, + ) @classmethod - def INPUT_TYPES(s): - return {"required": { "audio": ("AUDIO", ), - "filename_prefix": ("STRING", {"default": "audio/ComfyUI"}), - }, - "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}, - } + def execute(cls, audio, filename_prefix="ComfyUI", format="flac") -> IO.NodeOutput: + return IO.NodeOutput( + ui=UI.AudioSaveHelper.get_save_audio_ui(audio, filename_prefix=filename_prefix, cls=cls, format=format) + ) - RETURN_TYPES = () - FUNCTION = "save_flac" + save_flac = execute # TODO: remove - OUTPUT_NODE = True - CATEGORY = "audio" - - def save_flac(self, audio, filename_prefix="ComfyUI", format="flac", prompt=None, extra_pnginfo=None): - return save_audio(self, audio, filename_prefix, format, prompt, extra_pnginfo) - -class SaveAudioMP3: - def __init__(self): - self.output_dir = folder_paths.get_output_directory() - self.type = "output" - self.prefix_append = "" +class SaveAudioMP3(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="SaveAudioMP3", + display_name="Save Audio (MP3)", + category="audio", + inputs=[ + IO.Audio.Input("audio"), + IO.String.Input("filename_prefix", default="audio/ComfyUI"), + IO.Combo.Input("quality", options=["V0", "128k", "320k"], default="V0"), + ], + hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo], + is_output_node=True, + ) @classmethod - def INPUT_TYPES(s): - return {"required": { "audio": ("AUDIO", ), - "filename_prefix": ("STRING", {"default": "audio/ComfyUI"}), - "quality": (["V0", "128k", "320k"], {"default": "V0"}), - }, - "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}, - } + def execute(cls, audio, filename_prefix="ComfyUI", format="mp3", quality="128k") -> IO.NodeOutput: + return IO.NodeOutput( + ui=UI.AudioSaveHelper.get_save_audio_ui( + audio, filename_prefix=filename_prefix, cls=cls, format=format, quality=quality + ) + ) - RETURN_TYPES = () - FUNCTION = "save_mp3" + save_mp3 = execute # TODO: remove - OUTPUT_NODE = True - CATEGORY = "audio" - - def save_mp3(self, audio, filename_prefix="ComfyUI", format="mp3", prompt=None, extra_pnginfo=None, quality="128k"): - return save_audio(self, audio, filename_prefix, format, prompt, extra_pnginfo, quality) - -class SaveAudioOpus: - def __init__(self): - self.output_dir = folder_paths.get_output_directory() - self.type = "output" - self.prefix_append = "" +class SaveAudioOpus(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="SaveAudioOpus", + display_name="Save Audio (Opus)", + category="audio", + inputs=[ + IO.Audio.Input("audio"), + IO.String.Input("filename_prefix", default="audio/ComfyUI"), + IO.Combo.Input("quality", options=["64k", "96k", "128k", "192k", "320k"], default="128k"), + ], + hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo], + is_output_node=True, + ) @classmethod - def INPUT_TYPES(s): - return {"required": { "audio": ("AUDIO", ), - "filename_prefix": ("STRING", {"default": "audio/ComfyUI"}), - "quality": (["64k", "96k", "128k", "192k", "320k"], {"default": "128k"}), - }, - "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}, - } + def execute(cls, audio, filename_prefix="ComfyUI", format="opus", quality="V3") -> IO.NodeOutput: + return IO.NodeOutput( + ui=UI.AudioSaveHelper.get_save_audio_ui( + audio, filename_prefix=filename_prefix, cls=cls, format=format, quality=quality + ) + ) - RETURN_TYPES = () - FUNCTION = "save_opus" + save_opus = execute # TODO: remove - OUTPUT_NODE = True - CATEGORY = "audio" - - def save_opus(self, audio, filename_prefix="ComfyUI", format="opus", prompt=None, extra_pnginfo=None, quality="V3"): - return save_audio(self, audio, filename_prefix, format, prompt, extra_pnginfo, quality) - -class PreviewAudio(SaveAudio): - def __init__(self): - self.output_dir = folder_paths.get_temp_directory() - self.type = "temp" - self.prefix_append = "_temp_" + ''.join(random.choice("abcdefghijklmnopqrstupvxyz") for x in range(5)) +class PreviewAudio(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="PreviewAudio", + display_name="Preview Audio", + category="audio", + inputs=[ + IO.Audio.Input("audio"), + ], + hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo], + is_output_node=True, + ) @classmethod - def INPUT_TYPES(s): - return {"required": - {"audio": ("AUDIO", ), }, - "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}, - } + def execute(cls, audio) -> IO.NodeOutput: + return IO.NodeOutput(ui=UI.PreviewAudio(audio, cls=cls)) + + save_flac = execute # TODO: remove + def f32_pcm(wav: torch.Tensor) -> torch.Tensor: """Convert audio to float 32 bits PCM format.""" @@ -316,26 +252,30 @@ def load(filepath: str) -> tuple[torch.Tensor, int]: wav = f32_pcm(wav) return wav, sr -class LoadAudio: +class LoadAudio(IO.ComfyNode): @classmethod - def INPUT_TYPES(s): + def define_schema(cls): input_dir = folder_paths.get_input_directory() files = folder_paths.filter_files_content_types(os.listdir(input_dir), ["audio", "video"]) - return {"required": {"audio": (sorted(files), {"audio_upload": True})}} + return IO.Schema( + node_id="LoadAudio", + display_name="Load Audio", + category="audio", + inputs=[ + IO.Combo.Input("audio", upload=IO.UploadType.audio, options=sorted(files)), + ], + outputs=[IO.Audio.Output()], + ) - CATEGORY = "audio" - - RETURN_TYPES = ("AUDIO", ) - FUNCTION = "load" - - def load(self, audio): + @classmethod + def execute(cls, audio) -> IO.NodeOutput: audio_path = folder_paths.get_annotated_filepath(audio) waveform, sample_rate = load(audio_path) audio = {"waveform": waveform.unsqueeze(0), "sample_rate": sample_rate} - return (audio, ) + return IO.NodeOutput(audio) @classmethod - def IS_CHANGED(s, audio): + def fingerprint_inputs(cls, audio): image_path = folder_paths.get_annotated_filepath(audio) m = hashlib.sha256() with open(image_path, 'rb') as f: @@ -343,46 +283,69 @@ class LoadAudio: return m.digest().hex() @classmethod - def VALIDATE_INPUTS(s, audio): + def validate_inputs(cls, audio): if not folder_paths.exists_annotated_filepath(audio): return "Invalid audio file: {}".format(audio) return True -class RecordAudio: + load = execute # TODO: remove + + +class RecordAudio(IO.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": {"audio": ("AUDIO_RECORD", {})}} + def define_schema(cls): + return IO.Schema( + node_id="RecordAudio", + display_name="Record Audio", + category="audio", + inputs=[ + IO.Custom("AUDIO_RECORD").Input("audio"), + ], + outputs=[IO.Audio.Output()], + ) - CATEGORY = "audio" - - RETURN_TYPES = ("AUDIO", ) - FUNCTION = "load" - - def load(self, audio): + @classmethod + def execute(cls, audio) -> IO.NodeOutput: audio_path = folder_paths.get_annotated_filepath(audio) waveform, sample_rate = load(audio_path) audio = {"waveform": waveform.unsqueeze(0), "sample_rate": sample_rate} - return (audio, ) + return IO.NodeOutput(audio) + + load = execute # TODO: remove -class TrimAudioDuration: +class TrimAudioDuration(IO.ComfyNode): @classmethod - def INPUT_TYPES(cls): - return { - "required": { - "audio": ("AUDIO",), - "start_index": ("FLOAT", {"default": 0.0, "min": -0xffffffffffffffff, "max": 0xffffffffffffffff, "step": 0.01, "tooltip": "Start time in seconds, can be negative to count from the end (supports sub-seconds)."}), - "duration": ("FLOAT", {"default": 60.0, "min": 0.0, "step": 0.01, "tooltip": "Duration in seconds"}), - }, - } + def define_schema(cls): + return IO.Schema( + node_id="TrimAudioDuration", + display_name="Trim Audio Duration", + description="Trim audio tensor into chosen time range.", + category="audio", + inputs=[ + IO.Audio.Input("audio"), + IO.Float.Input( + "start_index", + default=0.0, + min=-0xffffffffffffffff, + max=0xffffffffffffffff, + step=0.01, + tooltip="Start time in seconds, can be negative to count from the end (supports sub-seconds).", + ), + IO.Float.Input( + "duration", + default=60.0, + min=0.0, + step=0.01, + tooltip="Duration in seconds", + ), + ], + outputs=[IO.Audio.Output()], + ) - FUNCTION = "trim" - RETURN_TYPES = ("AUDIO",) - CATEGORY = "audio" - DESCRIPTION = "Trim audio tensor into chosen time range." - - def trim(self, audio, start_index, duration): + @classmethod + def execute(cls, audio, start_index, duration) -> IO.NodeOutput: waveform = audio["waveform"] sample_rate = audio["sample_rate"] audio_length = waveform.shape[-1] @@ -399,23 +362,30 @@ class TrimAudioDuration: if start_frame >= end_frame: raise ValueError("AudioTrim: Start time must be less than end time and be within the audio length.") - return ({"waveform": waveform[..., start_frame:end_frame], "sample_rate": sample_rate},) + return IO.NodeOutput({"waveform": waveform[..., start_frame:end_frame], "sample_rate": sample_rate}) + + trim = execute # TODO: remove -class SplitAudioChannels: +class SplitAudioChannels(IO.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { - "audio": ("AUDIO",), - }} + def define_schema(cls): + return IO.Schema( + node_id="SplitAudioChannels", + display_name="Split Audio Channels", + description="Separates the audio into left and right channels.", + category="audio", + inputs=[ + IO.Audio.Input("audio"), + ], + outputs=[ + IO.Audio.Output(display_name="left"), + IO.Audio.Output(display_name="right"), + ], + ) - RETURN_TYPES = ("AUDIO", "AUDIO") - RETURN_NAMES = ("left", "right") - FUNCTION = "separate" - CATEGORY = "audio" - DESCRIPTION = "Separates the audio into left and right channels." - - def separate(self, audio): + @classmethod + def execute(cls, audio) -> IO.NodeOutput: waveform = audio["waveform"] sample_rate = audio["sample_rate"] @@ -425,7 +395,9 @@ class SplitAudioChannels: left_channel = waveform[..., 0:1, :] right_channel = waveform[..., 1:2, :] - return ({"waveform": left_channel, "sample_rate": sample_rate}, {"waveform": right_channel, "sample_rate": sample_rate}) + return IO.NodeOutput({"waveform": left_channel, "sample_rate": sample_rate}, {"waveform": right_channel, "sample_rate": sample_rate}) + + separate = execute # TODO: remove def match_audio_sample_rates(waveform_1, sample_rate_1, waveform_2, sample_rate_2): @@ -443,21 +415,29 @@ def match_audio_sample_rates(waveform_1, sample_rate_1, waveform_2, sample_rate_ return waveform_1, waveform_2, output_sample_rate -class AudioConcat: +class AudioConcat(IO.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { - "audio1": ("AUDIO",), - "audio2": ("AUDIO",), - "direction": (['after', 'before'], {"default": 'after', "tooltip": "Whether to append audio2 after or before audio1."}), - }} + def define_schema(cls): + return IO.Schema( + node_id="AudioConcat", + display_name="Audio Concat", + description="Concatenates the audio1 to audio2 in the specified direction.", + category="audio", + inputs=[ + IO.Audio.Input("audio1"), + IO.Audio.Input("audio2"), + IO.Combo.Input( + "direction", + options=['after', 'before'], + default="after", + tooltip="Whether to append audio2 after or before audio1.", + ) + ], + outputs=[IO.Audio.Output()], + ) - RETURN_TYPES = ("AUDIO",) - FUNCTION = "concat" - CATEGORY = "audio" - DESCRIPTION = "Concatenates the audio1 to audio2 in the specified direction." - - def concat(self, audio1, audio2, direction): + @classmethod + def execute(cls, audio1, audio2, direction) -> IO.NodeOutput: waveform_1 = audio1["waveform"] waveform_2 = audio2["waveform"] sample_rate_1 = audio1["sample_rate"] @@ -477,26 +457,33 @@ class AudioConcat: elif direction == 'before': concatenated_audio = torch.cat((waveform_2, waveform_1), dim=2) - return ({"waveform": concatenated_audio, "sample_rate": output_sample_rate},) + return IO.NodeOutput({"waveform": concatenated_audio, "sample_rate": output_sample_rate}) + + concat = execute # TODO: remove -class AudioMerge: +class AudioMerge(IO.ComfyNode): @classmethod - def INPUT_TYPES(cls): - return { - "required": { - "audio1": ("AUDIO",), - "audio2": ("AUDIO",), - "merge_method": (["add", "mean", "subtract", "multiply"], {"tooltip": "The method used to combine the audio waveforms."}), - }, - } + def define_schema(cls): + return IO.Schema( + node_id="AudioMerge", + display_name="Audio Merge", + description="Combine two audio tracks by overlaying their waveforms.", + category="audio", + inputs=[ + IO.Audio.Input("audio1"), + IO.Audio.Input("audio2"), + IO.Combo.Input( + "merge_method", + options=["add", "mean", "subtract", "multiply"], + tooltip="The method used to combine the audio waveforms.", + ) + ], + outputs=[IO.Audio.Output()], + ) - FUNCTION = "merge" - RETURN_TYPES = ("AUDIO",) - CATEGORY = "audio" - DESCRIPTION = "Combine two audio tracks by overlaying their waveforms." - - def merge(self, audio1, audio2, merge_method): + @classmethod + def execute(cls, audio1, audio2, merge_method) -> IO.NodeOutput: waveform_1 = audio1["waveform"] waveform_2 = audio2["waveform"] sample_rate_1 = audio1["sample_rate"] @@ -530,85 +517,108 @@ class AudioMerge: if max_val > 1.0: waveform = waveform / max_val - return ({"waveform": waveform, "sample_rate": output_sample_rate},) + return IO.NodeOutput({"waveform": waveform, "sample_rate": output_sample_rate}) + + merge = execute # TODO: remove -class AudioAdjustVolume: +class AudioAdjustVolume(IO.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { - "audio": ("AUDIO",), - "volume": ("INT", {"default": 1.0, "min": -100, "max": 100, "tooltip": "Volume adjustment in decibels (dB). 0 = no change, +6 = double, -6 = half, etc"}), - }} + def define_schema(cls): + return IO.Schema( + node_id="AudioAdjustVolume", + display_name="Audio Adjust Volume", + category="audio", + inputs=[ + IO.Audio.Input("audio"), + IO.Int.Input( + "volume", + default=1, + min=-100, + max=100, + tooltip="Volume adjustment in decibels (dB). 0 = no change, +6 = double, -6 = half, etc", + ) + ], + outputs=[IO.Audio.Output()], + ) - RETURN_TYPES = ("AUDIO",) - FUNCTION = "adjust_volume" - CATEGORY = "audio" - - def adjust_volume(self, audio, volume): + @classmethod + def execute(cls, audio, volume) -> IO.NodeOutput: if volume == 0: - return (audio,) + return IO.NodeOutput(audio) waveform = audio["waveform"] sample_rate = audio["sample_rate"] gain = 10 ** (volume / 20) waveform = waveform * gain - return ({"waveform": waveform, "sample_rate": sample_rate},) + return IO.NodeOutput({"waveform": waveform, "sample_rate": sample_rate}) + + adjust_volume = execute # TODO: remove -class EmptyAudio: +class EmptyAudio(IO.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { - "duration": ("FLOAT", {"default": 60.0, "min": 0.0, "max": 0xffffffffffffffff, "step": 0.01, "tooltip": "Duration of the empty audio clip in seconds"}), - "sample_rate": ("INT", {"default": 44100, "tooltip": "Sample rate of the empty audio clip."}), - "channels": ("INT", {"default": 2, "min": 1, "max": 2, "tooltip": "Number of audio channels (1 for mono, 2 for stereo)."}), - }} + def define_schema(cls): + return IO.Schema( + node_id="EmptyAudio", + display_name="Empty Audio", + category="audio", + inputs=[ + IO.Float.Input( + "duration", + default=60.0, + min=0.0, + max=0xffffffffffffffff, + step=0.01, + tooltip="Duration of the empty audio clip in seconds", + ), + IO.Float.Input( + "sample_rate", + default=44100, + tooltip="Sample rate of the empty audio clip.", + ), + IO.Float.Input( + "channels", + default=2, + min=1, + max=2, + tooltip="Number of audio channels (1 for mono, 2 for stereo).", + ), + ], + outputs=[IO.Audio.Output()], + ) - RETURN_TYPES = ("AUDIO",) - FUNCTION = "create_empty_audio" - CATEGORY = "audio" - - def create_empty_audio(self, duration, sample_rate, channels): + @classmethod + def execute(cls, duration, sample_rate, channels) -> IO.NodeOutput: num_samples = int(round(duration * sample_rate)) waveform = torch.zeros((1, channels, num_samples), dtype=torch.float32) - return ({"waveform": waveform, "sample_rate": sample_rate},) + return IO.NodeOutput({"waveform": waveform, "sample_rate": sample_rate}) + + create_empty_audio = execute # TODO: remove -NODE_CLASS_MAPPINGS = { - "EmptyLatentAudio": EmptyLatentAudio, - "VAEEncodeAudio": VAEEncodeAudio, - "VAEDecodeAudio": VAEDecodeAudio, - "SaveAudio": SaveAudio, - "SaveAudioMP3": SaveAudioMP3, - "SaveAudioOpus": SaveAudioOpus, - "LoadAudio": LoadAudio, - "PreviewAudio": PreviewAudio, - "ConditioningStableAudio": ConditioningStableAudio, - "RecordAudio": RecordAudio, - "TrimAudioDuration": TrimAudioDuration, - "SplitAudioChannels": SplitAudioChannels, - "AudioConcat": AudioConcat, - "AudioMerge": AudioMerge, - "AudioAdjustVolume": AudioAdjustVolume, - "EmptyAudio": EmptyAudio, -} +class AudioExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[IO.ComfyNode]]: + return [ + EmptyLatentAudio, + VAEEncodeAudio, + VAEDecodeAudio, + SaveAudio, + SaveAudioMP3, + SaveAudioOpus, + LoadAudio, + PreviewAudio, + ConditioningStableAudio, + RecordAudio, + TrimAudioDuration, + SplitAudioChannels, + AudioConcat, + AudioMerge, + AudioAdjustVolume, + EmptyAudio, + ] -NODE_DISPLAY_NAME_MAPPINGS = { - "EmptyLatentAudio": "Empty Latent Audio", - "VAEEncodeAudio": "VAE Encode Audio", - "VAEDecodeAudio": "VAE Decode Audio", - "PreviewAudio": "Preview Audio", - "LoadAudio": "Load Audio", - "SaveAudio": "Save Audio (FLAC)", - "SaveAudioMP3": "Save Audio (MP3)", - "SaveAudioOpus": "Save Audio (Opus)", - "RecordAudio": "Record Audio", - "TrimAudioDuration": "Trim Audio Duration", - "SplitAudioChannels": "Split Audio Channels", - "AudioConcat": "Audio Concat", - "AudioMerge": "Audio Merge", - "AudioAdjustVolume": "Audio Adjust Volume", - "EmptyAudio": "Empty Audio", -} +async def comfy_entrypoint() -> AudioExtension: + return AudioExtension() From ecdc8697d53919a9178bf53ef327a110582db8ea Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 3 Dec 2025 19:49:28 -0800 Subject: [PATCH 23/29] Qwen Image Lora training fix from #11090 (#11094) --- comfy_extras/nodes_train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy_extras/nodes_train.py b/comfy_extras/nodes_train.py index cb24ab709..19b8baaf4 100644 --- a/comfy_extras/nodes_train.py +++ b/comfy_extras/nodes_train.py @@ -623,7 +623,7 @@ class TrainLoraNode(io.ComfyNode): noise = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(seed) if multi_res: # use first latent as dummy latent if multi_res - latents = latents[0].repeat(num_images, 1, 1, 1) + latents = latents[0].repeat((num_images,) + ((1,) * (latents[0].ndim - 1))) guider.sample( noise.generate_noise({"samples": latents}), latents, From ea17add3c62197b10fd0b71d9169d339adc55c47 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 3 Dec 2025 20:15:15 -0800 Subject: [PATCH 24/29] Fix case where text encoders where running on the CPU instead of GPU. (#11095) --- comfy/sd.py | 2 ++ comfy/sd1_clip.py | 9 ++++++++- 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/comfy/sd.py b/comfy/sd.py index f9e5efab5..734bd2845 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -193,6 +193,7 @@ class CLIP: self.cond_stage_model.set_clip_options({"projected_pooled": False}) self.load_model() + self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device}) all_hooks.reset() self.patcher.patch_hooks(None) if show_pbar: @@ -240,6 +241,7 @@ class CLIP: self.cond_stage_model.set_clip_options({"projected_pooled": False}) self.load_model() + self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device}) o = self.cond_stage_model.encode_token_weights(tokens) cond, pooled = o[:2] if return_dict: diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index 0fc9ab3db..503a51843 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -147,6 +147,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): self.layer_norm_hidden_state = layer_norm_hidden_state self.return_projected_pooled = return_projected_pooled self.return_attention_masks = return_attention_masks + self.execution_device = None if layer == "hidden": assert layer_idx is not None @@ -163,6 +164,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): def set_clip_options(self, options): layer_idx = options.get("layer", self.layer_idx) self.return_projected_pooled = options.get("projected_pooled", self.return_projected_pooled) + self.execution_device = options.get("execution_device", self.execution_device) if isinstance(self.layer, list) or self.layer == "all": pass elif layer_idx is None or abs(layer_idx) > self.num_layers: @@ -175,6 +177,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): self.layer = self.options_default[0] self.layer_idx = self.options_default[1] self.return_projected_pooled = self.options_default[2] + self.execution_device = None def process_tokens(self, tokens, device): end_token = self.special_tokens.get("end", None) @@ -258,7 +261,11 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): return torch.cat(embeds_out), torch.tensor(attention_masks, device=device, dtype=torch.long), num_tokens, embeds_info def forward(self, tokens): - device = self.transformer.get_input_embeddings().weight.device + if self.execution_device is None: + device = self.transformer.get_input_embeddings().weight.device + else: + device = self.execution_device + embeds, attention_mask, num_tokens, embeds_info = self.process_tokens(tokens, device) attention_mask_model = None From 6be85c7920224b45bbc6417e00147815e78c12a9 Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Thu, 4 Dec 2025 14:28:44 +1000 Subject: [PATCH 25/29] mp: use look-ahead actuals for stream offload VRAM calculation (#11096) TIL that the WAN TE has a 2GB weight followed by 16MB as the next size down. This means that team 8GB VRAM would fully offload the TE in async offload mode as it just multiplied this giant size my the num streams. Do the more complex logic of summing up the upcoming to-load weight sizes to avoid triple counting this massive weight. partial unload does the converse of recording the NS most recent unloads as they go. --- comfy/model_patcher.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index df2d8e827..3dcac3eef 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -699,12 +699,12 @@ class ModelPatcher: offloaded = [] offload_buffer = 0 loading.sort(reverse=True) - for x in loading: + for i, x in enumerate(loading): module_offload_mem, module_mem, n, m, params = x lowvram_weight = False - potential_offload = max(offload_buffer, module_offload_mem + (comfy.model_management.NUM_STREAMS * module_mem)) + potential_offload = max(offload_buffer, module_offload_mem + sum([ x1[1] for x1 in loading[i+1:i+1+comfy.model_management.NUM_STREAMS]])) lowvram_fits = mem_counter + module_mem + potential_offload < lowvram_model_memory weight_key = "{}.weight".format(n) @@ -876,14 +876,18 @@ class ModelPatcher: patch_counter = 0 unload_list = self._load_list() unload_list.sort() + offload_buffer = self.model.model_offload_buffer_memory + if len(unload_list) > 0: + NS = comfy.model_management.NUM_STREAMS + offload_weight_factor = [ min(offload_buffer / (NS + 1), unload_list[0][1]) ] * NS for unload in unload_list: if memory_to_free + offload_buffer - self.model.model_offload_buffer_memory < memory_freed: break module_offload_mem, module_mem, n, m, params = unload - potential_offload = module_offload_mem + (comfy.model_management.NUM_STREAMS * module_mem) + potential_offload = module_offload_mem + sum(offload_weight_factor) lowvram_possible = hasattr(m, "comfy_cast_weights") if hasattr(m, "comfy_patched_weights") and m.comfy_patched_weights == True: @@ -935,6 +939,8 @@ class ModelPatcher: m.comfy_patched_weights = False memory_freed += module_mem offload_buffer = max(offload_buffer, potential_offload) + offload_weight_factor.append(module_mem) + offload_weight_factor.pop(0) logging.debug("freed {}".format(n)) for param in params: From f4bdf5f8302ef10db99644a8672e614ddb29c473 Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Fri, 5 Dec 2025 03:50:04 +1000 Subject: [PATCH 26/29] sd: revise hy VAE VRAM (#11105) This was recently collapsed down to rolling VAE through temporal. Clamp The time dimension. --- comfy/sd.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/comfy/sd.py b/comfy/sd.py index 734bd2845..fe4dd65f8 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -483,8 +483,10 @@ class VAE: self.latent_dim = 3 self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.conv.weight"].shape[1] self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=sd['post_quant_conv.weight'].shape[1]) - self.memory_used_decode = lambda shape, dtype: (1500 * shape[2] * shape[3] * shape[4] * (4 * 8 * 8)) * model_management.dtype_size(dtype) - self.memory_used_encode = lambda shape, dtype: (900 * max(shape[2], 2) * shape[3] * shape[4]) * model_management.dtype_size(dtype) + #This is likely to significantly over-estimate with single image or low frame counts as the + #implementation is able to completely skip caching. Rework if used as an image only VAE + self.memory_used_decode = lambda shape, dtype: (2800 * min(8, ((shape[2] - 1) * 4) + 1) * shape[3] * shape[4] * (8 * 8)) * model_management.dtype_size(dtype) + self.memory_used_encode = lambda shape, dtype: (1400 * min(9, shape[2]) * shape[3] * shape[4]) * model_management.dtype_size(dtype) self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32] elif "decoder.unpatcher3d.wavelets" in sd: self.upscale_ratio = (lambda a: max(0, a * 8 - 7), 8, 8) From 9bc893c5bbd2838bdd15ebd40e3a3e548ce3e4f0 Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Fri, 5 Dec 2025 03:50:36 +1000 Subject: [PATCH 27/29] sd: bump HY1.5 VAE estimate (#11107) Im able to push vram above estimate on partial unload. Bump the estimate. This is experimentally determined with a 720P and 480P datapoint calibrating for 24GB VRAM total. --- comfy/sd.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/sd.py b/comfy/sd.py index fe4dd65f8..03bdb33d5 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -471,7 +471,7 @@ class VAE: decoder_config={'target': "comfy.ldm.hunyuan_video.vae_refiner.Decoder", 'params': ddconfig}) self.memory_used_encode = lambda shape, dtype: (1400 * 9 * shape[-2] * shape[-1]) * model_management.dtype_size(dtype) - self.memory_used_decode = lambda shape, dtype: (2800 * 4 * shape[-2] * shape[-1] * 16 * 16) * model_management.dtype_size(dtype) + self.memory_used_decode = lambda shape, dtype: (3600 * 4 * shape[-2] * shape[-1] * 16 * 16) * model_management.dtype_size(dtype) elif "decoder.conv_in.conv.weight" in sd: ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0} ddconfig["conv3d"] = True From 3c8456223c5f6a41af7d99219b391c8c58acb552 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Fri, 5 Dec 2025 00:05:28 +0200 Subject: [PATCH 28/29] [API Nodes]: fixes and refactor (#11104) * chore(api-nodes): applied ruff's pyupgrade(python3.10) to api-nodes client's to folder * chore(api-nodes): add validate_video_frame_count function from LTX PR * chore(api-nodes): replace deprecated V1 imports * fix(api-nodes): the types returned by the "poll_op" function are now correct. --- comfy_api_nodes/util/__init__.py | 2 + comfy_api_nodes/util/_helpers.py | 14 +-- comfy_api_nodes/util/client.py | 145 ++++++++++++----------- comfy_api_nodes/util/conversions.py | 21 ++-- comfy_api_nodes/util/download_helpers.py | 20 ++-- comfy_api_nodes/util/request_logger.py | 2 - comfy_api_nodes/util/upload_helpers.py | 16 ++- comfy_api_nodes/util/validation_utils.py | 61 ++++++---- 8 files changed, 146 insertions(+), 135 deletions(-) diff --git a/comfy_api_nodes/util/__init__.py b/comfy_api_nodes/util/__init__.py index 80292fb3c..4cc22abfb 100644 --- a/comfy_api_nodes/util/__init__.py +++ b/comfy_api_nodes/util/__init__.py @@ -47,6 +47,7 @@ from .validation_utils import ( validate_string, validate_video_dimensions, validate_video_duration, + validate_video_frame_count, ) __all__ = [ @@ -94,6 +95,7 @@ __all__ = [ "validate_string", "validate_video_dimensions", "validate_video_duration", + "validate_video_frame_count", # Misc functions "get_fs_object_size", ] diff --git a/comfy_api_nodes/util/_helpers.py b/comfy_api_nodes/util/_helpers.py index 328fe5227..491e6b6a8 100644 --- a/comfy_api_nodes/util/_helpers.py +++ b/comfy_api_nodes/util/_helpers.py @@ -2,8 +2,8 @@ import asyncio import contextlib import os import time +from collections.abc import Callable from io import BytesIO -from typing import Callable, Optional, Union from comfy.cli_args import args from comfy.model_management import processing_interrupted @@ -35,12 +35,12 @@ def default_base_url() -> str: async def sleep_with_interrupt( seconds: float, - node_cls: Optional[type[IO.ComfyNode]], - label: Optional[str] = None, - start_ts: Optional[float] = None, - estimated_total: Optional[int] = None, + node_cls: type[IO.ComfyNode] | None, + label: str | None = None, + start_ts: float | None = None, + estimated_total: int | None = None, *, - display_callback: Optional[Callable[[type[IO.ComfyNode], str, int, Optional[int]], None]] = None, + display_callback: Callable[[type[IO.ComfyNode], str, int, int | None], None] | None = None, ): """ Sleep in 1s slices while: @@ -65,7 +65,7 @@ def mimetype_to_extension(mime_type: str) -> str: return mime_type.split("/")[-1].lower() -def get_fs_object_size(path_or_object: Union[str, BytesIO]) -> int: +def get_fs_object_size(path_or_object: str | BytesIO) -> int: if isinstance(path_or_object, str): return os.path.getsize(path_or_object) return len(path_or_object.getvalue()) diff --git a/comfy_api_nodes/util/client.py b/comfy_api_nodes/util/client.py index bf01d7d36..bf37cba5f 100644 --- a/comfy_api_nodes/util/client.py +++ b/comfy_api_nodes/util/client.py @@ -4,10 +4,11 @@ import json import logging import time import uuid +from collections.abc import Callable, Iterable from dataclasses import dataclass from enum import Enum from io import BytesIO -from typing import Any, Callable, Iterable, Literal, Optional, Type, TypeVar, Union +from typing import Any, Literal, TypeVar from urllib.parse import urljoin, urlparse import aiohttp @@ -37,8 +38,8 @@ class ApiEndpoint: path: str, method: Literal["GET", "POST", "PUT", "DELETE", "PATCH"] = "GET", *, - query_params: Optional[dict[str, Any]] = None, - headers: Optional[dict[str, str]] = None, + query_params: dict[str, Any] | None = None, + headers: dict[str, str] | None = None, ): self.path = path self.method = method @@ -52,18 +53,18 @@ class _RequestConfig: endpoint: ApiEndpoint timeout: float content_type: str - data: Optional[dict[str, Any]] - files: Optional[Union[dict[str, Any], list[tuple[str, Any]]]] - multipart_parser: Optional[Callable] + data: dict[str, Any] | None + files: dict[str, Any] | list[tuple[str, Any]] | None + multipart_parser: Callable | None max_retries: int retry_delay: float retry_backoff: float wait_label: str = "Waiting" monitor_progress: bool = True - estimated_total: Optional[int] = None - final_label_on_success: Optional[str] = "Completed" - progress_origin_ts: Optional[float] = None - price_extractor: Optional[Callable[[dict[str, Any]], Optional[float]]] = None + estimated_total: int | None = None + final_label_on_success: str | None = "Completed" + progress_origin_ts: float | None = None + price_extractor: Callable[[dict[str, Any]], float | None] | None = None @dataclass @@ -71,10 +72,10 @@ class _PollUIState: started: float status_label: str = "Queued" is_queued: bool = True - price: Optional[float] = None - estimated_duration: Optional[int] = None + price: float | None = None + estimated_duration: int | None = None base_processing_elapsed: float = 0.0 # sum of completed active intervals - active_since: Optional[float] = None # start time of current active interval (None if queued) + active_since: float | None = None # start time of current active interval (None if queued) _RETRY_STATUS = {408, 429, 500, 502, 503, 504} @@ -87,20 +88,20 @@ async def sync_op( cls: type[IO.ComfyNode], endpoint: ApiEndpoint, *, - response_model: Type[M], - price_extractor: Optional[Callable[[M], Optional[float]]] = None, - data: Optional[BaseModel] = None, - files: Optional[Union[dict[str, Any], list[tuple[str, Any]]]] = None, + response_model: type[M], + price_extractor: Callable[[M | Any], float | None] | None = None, + data: BaseModel | None = None, + files: dict[str, Any] | list[tuple[str, Any]] | None = None, content_type: str = "application/json", timeout: float = 3600.0, - multipart_parser: Optional[Callable] = None, + multipart_parser: Callable | None = None, max_retries: int = 3, retry_delay: float = 1.0, retry_backoff: float = 2.0, wait_label: str = "Waiting for server", - estimated_duration: Optional[int] = None, - final_label_on_success: Optional[str] = "Completed", - progress_origin_ts: Optional[float] = None, + estimated_duration: int | None = None, + final_label_on_success: str | None = "Completed", + progress_origin_ts: float | None = None, monitor_progress: bool = True, ) -> M: raw = await sync_op_raw( @@ -131,22 +132,22 @@ async def poll_op( cls: type[IO.ComfyNode], poll_endpoint: ApiEndpoint, *, - response_model: Type[M], - status_extractor: Callable[[M], Optional[Union[str, int]]], - progress_extractor: Optional[Callable[[M], Optional[int]]] = None, - price_extractor: Optional[Callable[[M], Optional[float]]] = None, - completed_statuses: Optional[list[Union[str, int]]] = None, - failed_statuses: Optional[list[Union[str, int]]] = None, - queued_statuses: Optional[list[Union[str, int]]] = None, - data: Optional[BaseModel] = None, + response_model: type[M], + status_extractor: Callable[[M | Any], str | int | None], + progress_extractor: Callable[[M | Any], int | None] | None = None, + price_extractor: Callable[[M | Any], float | None] | None = None, + completed_statuses: list[str | int] | None = None, + failed_statuses: list[str | int] | None = None, + queued_statuses: list[str | int] | None = None, + data: BaseModel | None = None, poll_interval: float = 5.0, max_poll_attempts: int = 120, timeout_per_poll: float = 120.0, max_retries_per_poll: int = 3, retry_delay_per_poll: float = 1.0, retry_backoff_per_poll: float = 2.0, - estimated_duration: Optional[int] = None, - cancel_endpoint: Optional[ApiEndpoint] = None, + estimated_duration: int | None = None, + cancel_endpoint: ApiEndpoint | None = None, cancel_timeout: float = 10.0, ) -> M: raw = await poll_op_raw( @@ -178,22 +179,22 @@ async def sync_op_raw( cls: type[IO.ComfyNode], endpoint: ApiEndpoint, *, - price_extractor: Optional[Callable[[dict[str, Any]], Optional[float]]] = None, - data: Optional[Union[dict[str, Any], BaseModel]] = None, - files: Optional[Union[dict[str, Any], list[tuple[str, Any]]]] = None, + price_extractor: Callable[[dict[str, Any]], float | None] | None = None, + data: dict[str, Any] | BaseModel | None = None, + files: dict[str, Any] | list[tuple[str, Any]] | None = None, content_type: str = "application/json", timeout: float = 3600.0, - multipart_parser: Optional[Callable] = None, + multipart_parser: Callable | None = None, max_retries: int = 3, retry_delay: float = 1.0, retry_backoff: float = 2.0, wait_label: str = "Waiting for server", - estimated_duration: Optional[int] = None, + estimated_duration: int | None = None, as_binary: bool = False, - final_label_on_success: Optional[str] = "Completed", - progress_origin_ts: Optional[float] = None, + final_label_on_success: str | None = "Completed", + progress_origin_ts: float | None = None, monitor_progress: bool = True, -) -> Union[dict[str, Any], bytes]: +) -> dict[str, Any] | bytes: """ Make a single network request. - If as_binary=False (default): returns JSON dict (or {'_raw': ''} if non-JSON). @@ -229,21 +230,21 @@ async def poll_op_raw( cls: type[IO.ComfyNode], poll_endpoint: ApiEndpoint, *, - status_extractor: Callable[[dict[str, Any]], Optional[Union[str, int]]], - progress_extractor: Optional[Callable[[dict[str, Any]], Optional[int]]] = None, - price_extractor: Optional[Callable[[dict[str, Any]], Optional[float]]] = None, - completed_statuses: Optional[list[Union[str, int]]] = None, - failed_statuses: Optional[list[Union[str, int]]] = None, - queued_statuses: Optional[list[Union[str, int]]] = None, - data: Optional[Union[dict[str, Any], BaseModel]] = None, + status_extractor: Callable[[dict[str, Any]], str | int | None], + progress_extractor: Callable[[dict[str, Any]], int | None] | None = None, + price_extractor: Callable[[dict[str, Any]], float | None] | None = None, + completed_statuses: list[str | int] | None = None, + failed_statuses: list[str | int] | None = None, + queued_statuses: list[str | int] | None = None, + data: dict[str, Any] | BaseModel | None = None, poll_interval: float = 5.0, max_poll_attempts: int = 120, timeout_per_poll: float = 120.0, max_retries_per_poll: int = 3, retry_delay_per_poll: float = 1.0, retry_backoff_per_poll: float = 2.0, - estimated_duration: Optional[int] = None, - cancel_endpoint: Optional[ApiEndpoint] = None, + estimated_duration: int | None = None, + cancel_endpoint: ApiEndpoint | None = None, cancel_timeout: float = 10.0, ) -> dict[str, Any]: """ @@ -261,7 +262,7 @@ async def poll_op_raw( consumed_attempts = 0 # counts only non-queued polls progress_bar = utils.ProgressBar(100) if progress_extractor else None - last_progress: Optional[int] = None + last_progress: int | None = None state = _PollUIState(started=started, estimated_duration=estimated_duration) stop_ticker = asyncio.Event() @@ -420,10 +421,10 @@ async def poll_op_raw( def _display_text( node_cls: type[IO.ComfyNode], - text: Optional[str], + text: str | None, *, - status: Optional[Union[str, int]] = None, - price: Optional[float] = None, + status: str | int | None = None, + price: float | None = None, ) -> None: display_lines: list[str] = [] if status: @@ -440,13 +441,13 @@ def _display_text( def _display_time_progress( node_cls: type[IO.ComfyNode], - status: Optional[Union[str, int]], + status: str | int | None, elapsed_seconds: int, - estimated_total: Optional[int] = None, + estimated_total: int | None = None, *, - price: Optional[float] = None, - is_queued: Optional[bool] = None, - processing_elapsed_seconds: Optional[int] = None, + price: float | None = None, + is_queued: bool | None = None, + processing_elapsed_seconds: int | None = None, ) -> None: if estimated_total is not None and estimated_total > 0 and is_queued is False: pe = processing_elapsed_seconds if processing_elapsed_seconds is not None else elapsed_seconds @@ -488,7 +489,7 @@ def _unpack_tuple(t: tuple) -> tuple[str, Any, str]: raise ValueError("files tuple must be (filename, file[, content_type])") -def _merge_params(endpoint_params: dict[str, Any], method: str, data: Optional[dict[str, Any]]) -> dict[str, Any]: +def _merge_params(endpoint_params: dict[str, Any], method: str, data: dict[str, Any] | None) -> dict[str, Any]: params = dict(endpoint_params or {}) if method.upper() == "GET" and data: for k, v in data.items(): @@ -534,9 +535,9 @@ def _generate_operation_id(method: str, path: str, attempt: int) -> str: def _snapshot_request_body_for_logging( content_type: str, method: str, - data: Optional[dict[str, Any]], - files: Optional[Union[dict[str, Any], list[tuple[str, Any]]]], -) -> Optional[Union[dict[str, Any], str]]: + data: dict[str, Any] | None, + files: dict[str, Any] | list[tuple[str, Any]] | None, +) -> dict[str, Any] | str | None: if method.upper() == "GET": return None if content_type == "multipart/form-data": @@ -586,13 +587,13 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool): attempt = 0 delay = cfg.retry_delay operation_succeeded: bool = False - final_elapsed_seconds: Optional[int] = None - extracted_price: Optional[float] = None + final_elapsed_seconds: int | None = None + extracted_price: float | None = None while True: attempt += 1 stop_event = asyncio.Event() - monitor_task: Optional[asyncio.Task] = None - sess: Optional[aiohttp.ClientSession] = None + monitor_task: asyncio.Task | None = None + sess: aiohttp.ClientSession | None = None operation_id = _generate_operation_id(method, cfg.endpoint.path, attempt) logging.debug("[DEBUG] HTTP %s %s (attempt %d)", method, url, attempt) @@ -887,7 +888,7 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool): ) -def _validate_or_raise(response_model: Type[M], payload: Any) -> M: +def _validate_or_raise(response_model: type[M], payload: Any) -> M: try: return response_model.model_validate(payload) except Exception as e: @@ -902,9 +903,9 @@ def _validate_or_raise(response_model: Type[M], payload: Any) -> M: def _wrap_model_extractor( - response_model: Type[M], - extractor: Optional[Callable[[M], Any]], -) -> Optional[Callable[[dict[str, Any]], Any]]: + response_model: type[M], + extractor: Callable[[M], Any] | None, +) -> Callable[[dict[str, Any]], Any] | None: """Wrap a typed extractor so it can be used by the dict-based poller. Validates the dict into `response_model` before invoking `extractor`. Uses a small per-wrapper cache keyed by `id(dict)` to avoid re-validating @@ -929,10 +930,10 @@ def _wrap_model_extractor( return _wrapped -def _normalize_statuses(values: Optional[Iterable[Union[str, int]]]) -> set[Union[str, int]]: +def _normalize_statuses(values: Iterable[str | int] | None) -> set[str | int]: if not values: return set() - out: set[Union[str, int]] = set() + out: set[str | int] = set() for v in values: nv = _normalize_status_value(v) if nv is not None: @@ -940,7 +941,7 @@ def _normalize_statuses(values: Optional[Iterable[Union[str, int]]]) -> set[Unio return out -def _normalize_status_value(val: Union[str, int, None]) -> Union[str, int, None]: +def _normalize_status_value(val: str | int | None) -> str | int | None: if isinstance(val, str): return val.strip().lower() return val diff --git a/comfy_api_nodes/util/conversions.py b/comfy_api_nodes/util/conversions.py index 971dc57de..c57457580 100644 --- a/comfy_api_nodes/util/conversions.py +++ b/comfy_api_nodes/util/conversions.py @@ -4,7 +4,6 @@ import math import mimetypes import uuid from io import BytesIO -from typing import Optional import av import numpy as np @@ -12,8 +11,7 @@ import torch from PIL import Image from comfy.utils import common_upscale -from comfy_api.latest import Input, InputImpl -from comfy_api.util import VideoCodec, VideoContainer +from comfy_api.latest import Input, InputImpl, Types from ._helpers import mimetype_to_extension @@ -57,7 +55,7 @@ def image_tensor_pair_to_batch(image1: torch.Tensor, image2: torch.Tensor) -> to def tensor_to_bytesio( image: torch.Tensor, - name: Optional[str] = None, + name: str | None = None, total_pixels: int = 2048 * 2048, mime_type: str = "image/png", ) -> BytesIO: @@ -177,8 +175,8 @@ def audio_to_base64_string(audio: Input.Audio, container_format: str = "mp4", co def video_to_base64_string( video: Input.Video, - container_format: VideoContainer = None, - codec: VideoCodec = None + container_format: Types.VideoContainer | None = None, + codec: Types.VideoCodec | None = None, ) -> str: """ Converts a video input to a base64 string. @@ -189,12 +187,11 @@ def video_to_base64_string( codec: Optional codec to use (defaults to video.codec if available) """ video_bytes_io = BytesIO() - - # Use provided format/codec if specified, otherwise use video's own if available - format_to_use = container_format if container_format is not None else getattr(video, 'container', VideoContainer.MP4) - codec_to_use = codec if codec is not None else getattr(video, 'codec', VideoCodec.H264) - - video.save_to(video_bytes_io, format=format_to_use, codec=codec_to_use) + video.save_to( + video_bytes_io, + format=container_format or getattr(video, "container", Types.VideoContainer.MP4), + codec=codec or getattr(video, "codec", Types.VideoCodec.H264), + ) video_bytes_io.seek(0) return base64.b64encode(video_bytes_io.getvalue()).decode("utf-8") diff --git a/comfy_api_nodes/util/download_helpers.py b/comfy_api_nodes/util/download_helpers.py index 14207dc68..3e0d0352d 100644 --- a/comfy_api_nodes/util/download_helpers.py +++ b/comfy_api_nodes/util/download_helpers.py @@ -3,15 +3,15 @@ import contextlib import uuid from io import BytesIO from pathlib import Path -from typing import IO, Optional, Union +from typing import IO from urllib.parse import urljoin, urlparse import aiohttp import torch from aiohttp.client_exceptions import ClientError, ContentTypeError -from comfy_api.input_impl import VideoFromFile from comfy_api.latest import IO as COMFY_IO +from comfy_api.latest import InputImpl from . import request_logger from ._helpers import ( @@ -29,9 +29,9 @@ _RETRY_STATUS = {408, 429, 500, 502, 503, 504} async def download_url_to_bytesio( url: str, - dest: Optional[Union[BytesIO, IO[bytes], str, Path]], + dest: BytesIO | IO[bytes] | str | Path | None, *, - timeout: Optional[float] = None, + timeout: float | None = None, max_retries: int = 5, retry_delay: float = 1.0, retry_backoff: float = 2.0, @@ -71,10 +71,10 @@ async def download_url_to_bytesio( is_path_sink = isinstance(dest, (str, Path)) fhandle = None - session: Optional[aiohttp.ClientSession] = None - stop_evt: Optional[asyncio.Event] = None - monitor_task: Optional[asyncio.Task] = None - req_task: Optional[asyncio.Task] = None + session: aiohttp.ClientSession | None = None + stop_evt: asyncio.Event | None = None + monitor_task: asyncio.Task | None = None + req_task: asyncio.Task | None = None try: with contextlib.suppress(Exception): @@ -234,11 +234,11 @@ async def download_url_to_video_output( timeout: float = None, max_retries: int = 5, cls: type[COMFY_IO.ComfyNode] = None, -) -> VideoFromFile: +) -> InputImpl.VideoFromFile: """Downloads a video from a URL and returns a `VIDEO` output.""" result = BytesIO() await download_url_to_bytesio(video_url, result, timeout=timeout, max_retries=max_retries, cls=cls) - return VideoFromFile(result) + return InputImpl.VideoFromFile(result) async def download_url_as_bytesio( diff --git a/comfy_api_nodes/util/request_logger.py b/comfy_api_nodes/util/request_logger.py index ac52e2eab..e0cb4428d 100644 --- a/comfy_api_nodes/util/request_logger.py +++ b/comfy_api_nodes/util/request_logger.py @@ -1,5 +1,3 @@ -from __future__ import annotations - import datetime import hashlib import json diff --git a/comfy_api_nodes/util/upload_helpers.py b/comfy_api_nodes/util/upload_helpers.py index 0532bea9a..b8d33f4d1 100644 --- a/comfy_api_nodes/util/upload_helpers.py +++ b/comfy_api_nodes/util/upload_helpers.py @@ -4,15 +4,13 @@ import logging import time import uuid from io import BytesIO -from typing import Optional from urllib.parse import urlparse import aiohttp import torch from pydantic import BaseModel, Field -from comfy_api.latest import IO, Input -from comfy_api.util import VideoCodec, VideoContainer +from comfy_api.latest import IO, Input, Types from . import request_logger from ._helpers import is_processing_interrupted, sleep_with_interrupt @@ -32,7 +30,7 @@ from .conversions import ( class UploadRequest(BaseModel): file_name: str = Field(..., description="Filename to upload") - content_type: Optional[str] = Field( + content_type: str | None = Field( None, description="Mime type of the file. For example: image/png, image/jpeg, video/mp4, etc.", ) @@ -56,7 +54,7 @@ async def upload_images_to_comfyapi( Uploads images to ComfyUI API and returns download URLs. To upload multiple images, stack them in the batch dimension first. """ - # if batch, try to upload each file if max_images is greater than 0 + # if batched, try to upload each file if max_images is greater than 0 download_urls: list[str] = [] is_batch = len(image.shape) > 3 batch_len = image.shape[0] if is_batch else 1 @@ -100,9 +98,9 @@ async def upload_video_to_comfyapi( cls: type[IO.ComfyNode], video: Input.Video, *, - container: VideoContainer = VideoContainer.MP4, - codec: VideoCodec = VideoCodec.H264, - max_duration: Optional[int] = None, + container: Types.VideoContainer = Types.VideoContainer.MP4, + codec: Types.VideoCodec = Types.VideoCodec.H264, + max_duration: int | None = None, wait_label: str | None = "Uploading", ) -> str: """ @@ -220,7 +218,7 @@ async def upload_file( return monitor_task = asyncio.create_task(_monitor()) - sess: Optional[aiohttp.ClientSession] = None + sess: aiohttp.ClientSession | None = None try: try: request_logger.log_request_response( diff --git a/comfy_api_nodes/util/validation_utils.py b/comfy_api_nodes/util/validation_utils.py index ec7006aed..f01edea96 100644 --- a/comfy_api_nodes/util/validation_utils.py +++ b/comfy_api_nodes/util/validation_utils.py @@ -1,9 +1,7 @@ import logging -from typing import Optional import torch -from comfy_api.input.video_types import VideoInput from comfy_api.latest import Input @@ -18,10 +16,10 @@ def get_image_dimensions(image: torch.Tensor) -> tuple[int, int]: def validate_image_dimensions( image: torch.Tensor, - min_width: Optional[int] = None, - max_width: Optional[int] = None, - min_height: Optional[int] = None, - max_height: Optional[int] = None, + min_width: int | None = None, + max_width: int | None = None, + min_height: int | None = None, + max_height: int | None = None, ): height, width = get_image_dimensions(image) @@ -37,8 +35,8 @@ def validate_image_dimensions( def validate_image_aspect_ratio( image: torch.Tensor, - min_ratio: Optional[tuple[float, float]] = None, # e.g. (1, 4) - max_ratio: Optional[tuple[float, float]] = None, # e.g. (4, 1) + min_ratio: tuple[float, float] | None = None, # e.g. (1, 4) + max_ratio: tuple[float, float] | None = None, # e.g. (4, 1) *, strict: bool = True, # True -> (min, max); False -> [min, max] ) -> float: @@ -54,8 +52,8 @@ def validate_image_aspect_ratio( def validate_images_aspect_ratio_closeness( first_image: torch.Tensor, second_image: torch.Tensor, - min_rel: float, # e.g. 0.8 - max_rel: float, # e.g. 1.25 + min_rel: float, # e.g. 0.8 + max_rel: float, # e.g. 1.25 *, strict: bool = False, # True -> (min, max); False -> [min, max] ) -> float: @@ -84,8 +82,8 @@ def validate_images_aspect_ratio_closeness( def validate_aspect_ratio_string( aspect_ratio: str, - min_ratio: Optional[tuple[float, float]] = None, # e.g. (1, 4) - max_ratio: Optional[tuple[float, float]] = None, # e.g. (4, 1) + min_ratio: tuple[float, float] | None = None, # e.g. (1, 4) + max_ratio: tuple[float, float] | None = None, # e.g. (4, 1) *, strict: bool = False, # True -> (min, max); False -> [min, max] ) -> float: @@ -97,10 +95,10 @@ def validate_aspect_ratio_string( def validate_video_dimensions( video: Input.Video, - min_width: Optional[int] = None, - max_width: Optional[int] = None, - min_height: Optional[int] = None, - max_height: Optional[int] = None, + min_width: int | None = None, + max_width: int | None = None, + min_height: int | None = None, + max_height: int | None = None, ): try: width, height = video.get_dimensions() @@ -120,8 +118,8 @@ def validate_video_dimensions( def validate_video_duration( video: Input.Video, - min_duration: Optional[float] = None, - max_duration: Optional[float] = None, + min_duration: float | None = None, + max_duration: float | None = None, ): try: duration = video.get_duration() @@ -136,6 +134,23 @@ def validate_video_duration( raise ValueError(f"Video duration must be at most {max_duration}s, got {duration}s") +def validate_video_frame_count( + video: Input.Video, + min_frame_count: int | None = None, + max_frame_count: int | None = None, +): + try: + frame_count = video.get_frame_count() + except Exception as e: + logging.error("Error getting frame count of video: %s", e) + return + + if min_frame_count is not None and min_frame_count > frame_count: + raise ValueError(f"Video frame count must be at least {min_frame_count}, got {frame_count}") + if max_frame_count is not None and frame_count > max_frame_count: + raise ValueError(f"Video frame count must be at most {max_frame_count}, got {frame_count}") + + def get_number_of_images(images): if isinstance(images, torch.Tensor): return images.shape[0] if images.ndim >= 4 else 1 @@ -144,8 +159,8 @@ def get_number_of_images(images): def validate_audio_duration( audio: Input.Audio, - min_duration: Optional[float] = None, - max_duration: Optional[float] = None, + min_duration: float | None = None, + max_duration: float | None = None, ) -> None: sr = int(audio["sample_rate"]) dur = int(audio["waveform"].shape[-1]) / sr @@ -177,7 +192,7 @@ def validate_string( ) -def validate_container_format_is_mp4(video: VideoInput) -> None: +def validate_container_format_is_mp4(video: Input.Video) -> None: """Validates video container format is MP4.""" container_format = video.get_container_format() if container_format not in ["mp4", "mov,mp4,m4a,3gp,3g2,mj2"]: @@ -194,8 +209,8 @@ def _ratio_from_tuple(r: tuple[float, float]) -> float: def _assert_ratio_bounds( ar: float, *, - min_ratio: Optional[tuple[float, float]] = None, - max_ratio: Optional[tuple[float, float]] = None, + min_ratio: tuple[float, float] | None = None, + max_ratio: tuple[float, float] | None = None, strict: bool = True, ) -> None: """Validate a numeric aspect ratio against optional min/max ratio bounds.""" From 35fa091340c60612dfb71cb6822dc23b99a5dac2 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Thu, 4 Dec 2025 19:52:09 -0800 Subject: [PATCH 29/29] Forgot to put this in README. (#11112) --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index 91fb510e1..ed857df9f 100644 --- a/README.md +++ b/README.md @@ -81,6 +81,7 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith - [Hunyuan Video](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_video/) - [Wan 2.1](https://comfyanonymous.github.io/ComfyUI_examples/wan/) - [Wan 2.2](https://comfyanonymous.github.io/ComfyUI_examples/wan22/) + - [Hunyuan Video 1.5](https://docs.comfy.org/tutorials/video/hunyuan/hunyuan-video-1-5) - Audio Models - [Stable Audio](https://comfyanonymous.github.io/ComfyUI_examples/audio/) - [ACE Step](https://comfyanonymous.github.io/ComfyUI_examples/audio/)