From 6968a70e603a0d2c80387aa139b06276635b36e3 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Mon, 27 Apr 2026 19:53:08 +0300 Subject: [PATCH 001/102] [Partner Nodes] HappyHorse model (#13582) * feat(api-nodes): add nodes for HappyHorse model Signed-off-by: bigcat88 * fix price badges Signed-off-by: bigcat88 * fix: allow durations up to 15 s Signed-off-by: bigcat88 --------- Signed-off-by: bigcat88 --- comfy_api_nodes/apis/wan.py | 4 +- comfy_api_nodes/nodes_wan.py | 555 +++++++++++++++++++++++++++++++++++ 2 files changed, 557 insertions(+), 2 deletions(-) diff --git a/comfy_api_nodes/apis/wan.py b/comfy_api_nodes/apis/wan.py index 44b65e4f6..c64acae97 100644 --- a/comfy_api_nodes/apis/wan.py +++ b/comfy_api_nodes/apis/wan.py @@ -118,7 +118,7 @@ class Wan27ReferenceVideoInputField(BaseModel): class Wan27ReferenceVideoParametersField(BaseModel): resolution: str = Field(...) ratio: str | None = Field(None) - duration: int = Field(5, ge=2, le=10) + duration: int = Field(5, ge=2, le=15) watermark: bool = Field(False) seed: int = Field(..., ge=0, le=2147483647) @@ -157,7 +157,7 @@ class Wan27VideoEditInputField(BaseModel): class Wan27VideoEditParametersField(BaseModel): resolution: str = Field(...) ratio: str | None = Field(None) - duration: int = Field(0) + duration: int | None = Field(0) audio_setting: str = Field("auto") watermark: bool = Field(False) seed: int = Field(..., ge=0, le=2147483647) diff --git a/comfy_api_nodes/nodes_wan.py b/comfy_api_nodes/nodes_wan.py index d1470894a..7d7466fb6 100644 --- a/comfy_api_nodes/nodes_wan.py +++ b/comfy_api_nodes/nodes_wan.py @@ -1646,6 +1646,557 @@ class Wan2ReferenceVideoApi(IO.ComfyNode): return IO.NodeOutput(await download_url_to_video_output(response.output.video_url)) +class HappyHorseTextToVideoApi(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="HappyHorseTextToVideoApi", + display_name="HappyHorse Text to Video", + category="api node/video/Wan", + description="Generates a video based on a text prompt using the HappyHorse model.", + inputs=[ + IO.DynamicCombo.Input( + "model", + options=[ + IO.DynamicCombo.Option( + "happyhorse-1.0-t2v", + [ + IO.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Prompt describing the elements and visual features. 
" + "Supports English and Chinese.", + ), + IO.Combo.Input( + "resolution", + options=["720P", "1080P"], + ), + IO.Combo.Input( + "ratio", + options=["16:9", "9:16", "1:1", "4:3", "3:4"], + ), + IO.Int.Input( + "duration", + default=5, + min=3, + max=15, + step=1, + display_mode=IO.NumberDisplay.number, + ), + ], + ), + ], + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + step=1, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + tooltip="Seed to use for generation.", + ), + IO.Boolean.Input( + "watermark", + default=False, + tooltip="Whether to add an AI-generated watermark to the result.", + advanced=True, + ), + ], + outputs=[ + IO.Video.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["model", "model.resolution", "model.duration"]), + expr=""" + ( + $res := $lookup(widgets, "model.resolution"); + $dur := $lookup(widgets, "model.duration"); + $ppsTable := { "720p": 0.14, "1080p": 0.24 }; + $pps := $lookup($ppsTable, $res); + { "type": "usd", "usd": $pps * $dur } + ) + """, + ), + ) + + @classmethod + async def execute( + cls, + model: dict, + seed: int, + watermark: bool, + ): + validate_string(model["prompt"], strip_whitespace=False, min_length=1) + initial_response = await sync_op( + cls, + ApiEndpoint( + path="/proxy/wan/api/v1/services/aigc/video-generation/video-synthesis", + method="POST", + ), + response_model=TaskCreationResponse, + data=Wan27Text2VideoTaskCreationRequest( + model=model["model"], + input=Text2VideoInputField( + prompt=model["prompt"], + negative_prompt=None, + ), + parameters=Wan27Text2VideoParametersField( + resolution=model["resolution"], + ratio=model["ratio"], + duration=model["duration"], + seed=seed, + watermark=watermark, + ), + ), + ) + if not initial_response.output: + raise Exception(f"An unknown error occurred: {initial_response.code} - {initial_response.message}") + response = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/wan/api/v1/tasks/{initial_response.output.task_id}"), + response_model=VideoTaskStatusResponse, + status_extractor=lambda x: x.output.task_status, + poll_interval=7, + ) + return IO.NodeOutput(await download_url_to_video_output(response.output.video_url)) + + +class HappyHorseImageToVideoApi(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="HappyHorseImageToVideoApi", + display_name="HappyHorse Image to Video", + category="api node/video/Wan", + description="Generate a video from a first-frame image using the HappyHorse model.", + inputs=[ + IO.DynamicCombo.Input( + "model", + options=[ + IO.DynamicCombo.Option( + "happyhorse-1.0-i2v", + [ + IO.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Prompt describing the elements and visual features. " + "Supports English and Chinese.", + ), + IO.Combo.Input( + "resolution", + options=["720P", "1080P"], + ), + IO.Int.Input( + "duration", + default=5, + min=3, + max=15, + step=1, + display_mode=IO.NumberDisplay.number, + ), + ], + ), + ], + ), + IO.Image.Input( + "first_frame", + tooltip="First frame image. 
The output aspect ratio is derived from this image.", + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + step=1, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + tooltip="Seed to use for generation.", + ), + IO.Boolean.Input( + "watermark", + default=False, + tooltip="Whether to add an AI-generated watermark to the result.", + advanced=True, + ), + ], + outputs=[ + IO.Video.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["model", "model.resolution", "model.duration"]), + expr=""" + ( + $res := $lookup(widgets, "model.resolution"); + $dur := $lookup(widgets, "model.duration"); + $ppsTable := { "720p": 0.14, "1080p": 0.24 }; + $pps := $lookup($ppsTable, $res); + { "type": "usd", "usd": $pps * $dur } + ) + """, + ), + ) + + @classmethod + async def execute( + cls, + model: dict, + first_frame: Input.Image, + seed: int, + watermark: bool, + ): + media = [ + Wan27MediaItem( + type="first_frame", + url=await upload_image_to_comfyapi(cls, image=first_frame), + ) + ] + initial_response = await sync_op( + cls, + ApiEndpoint( + path="/proxy/wan/api/v1/services/aigc/video-generation/video-synthesis", + method="POST", + ), + response_model=TaskCreationResponse, + data=Wan27ImageToVideoTaskCreationRequest( + model=model["model"], + input=Wan27ImageToVideoInputField( + prompt=model["prompt"] or None, + negative_prompt=None, + media=media, + ), + parameters=Wan27ImageToVideoParametersField( + resolution=model["resolution"], + duration=model["duration"], + seed=seed, + watermark=watermark, + ), + ), + ) + if not initial_response.output: + raise Exception(f"An unknown error occurred: {initial_response.code} - {initial_response.message}") + response = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/wan/api/v1/tasks/{initial_response.output.task_id}"), + response_model=VideoTaskStatusResponse, + status_extractor=lambda x: x.output.task_status, + poll_interval=7, + ) + return IO.NodeOutput(await download_url_to_video_output(response.output.video_url)) + + +class HappyHorseVideoEditApi(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="HappyHorseVideoEditApi", + display_name="HappyHorse Video Edit", + category="api node/video/Wan", + description="Edit a video using text instructions or reference images with the HappyHorse model. " + "Output duration is 3-15s and matches the input video; inputs longer than 15s are truncated.", + inputs=[ + IO.DynamicCombo.Input( + "model", + options=[ + IO.DynamicCombo.Option( + "happyhorse-1.0-video-edit", + [ + IO.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Editing instructions or style transfer requirements.", + ), + IO.Combo.Input( + "resolution", + options=["720P", "1080P"], + ), + IO.Combo.Input( + "ratio", + options=["16:9", "9:16", "1:1", "4:3", "3:4"], + tooltip="Aspect ratio. 
If not changed, approximates the input video ratio.", + ), + IO.Autogrow.Input( + "reference_images", + template=IO.Autogrow.TemplateNames( + IO.Image.Input("reference_image"), + names=[ + "image1", + "image2", + "image3", + "image4", + "image5", + ], + min=0, + ), + ), + ], + ), + ], + ), + IO.Video.Input( + "video", + tooltip="The video to edit.", + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + step=1, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + tooltip="Seed to use for generation.", + ), + IO.Boolean.Input( + "watermark", + default=False, + tooltip="Whether to add an AI-generated watermark to the result.", + advanced=True, + ), + ], + outputs=[ + IO.Video.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["model", "model.resolution"]), + expr=""" + ( + $res := $lookup(widgets, "model.resolution"); + $ppsTable := { "720p": 0.14, "1080p": 0.24 }; + $pps := $lookup($ppsTable, $res); + { "type": "usd", "usd": $pps, "format": { "suffix": "/second" } } + ) + """, + ), + ) + + @classmethod + async def execute( + cls, + model: dict, + video: Input.Video, + seed: int, + watermark: bool, + ): + validate_string(model["prompt"], strip_whitespace=False, min_length=1) + validate_video_duration(video, min_duration=3, max_duration=60) + media = [Wan27MediaItem(type="video", url=await upload_video_to_comfyapi(cls, video))] + reference_images = model.get("reference_images", {}) + for key in reference_images: + media.append( + Wan27MediaItem( + type="reference_image", url=await upload_image_to_comfyapi(cls, image=reference_images[key]) + ) + ) + initial_response = await sync_op( + cls, + ApiEndpoint( + path="/proxy/wan/api/v1/services/aigc/video-generation/video-synthesis", + method="POST", + ), + response_model=TaskCreationResponse, + data=Wan27VideoEditTaskCreationRequest( + model=model["model"], + input=Wan27VideoEditInputField(prompt=model["prompt"], media=media), + parameters=Wan27VideoEditParametersField( + resolution=model["resolution"], + ratio=model["ratio"], + duration=None, + watermark=watermark, + seed=seed, + ), + ), + ) + if not initial_response.output: + raise Exception(f"An unknown error occurred: {initial_response.code} - {initial_response.message}") + response = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/wan/api/v1/tasks/{initial_response.output.task_id}"), + response_model=VideoTaskStatusResponse, + status_extractor=lambda x: x.output.task_status, + poll_interval=7, + ) + return IO.NodeOutput(await download_url_to_video_output(response.output.video_url)) + + +class HappyHorseReferenceVideoApi(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="HappyHorseReferenceVideoApi", + display_name="HappyHorse Reference to Video", + category="api node/video/Wan", + description="Generate a video featuring a person or object from reference materials with the HappyHorse " + "model. Supports single-character performances and multi-character interactions.", + inputs=[ + IO.DynamicCombo.Input( + "model", + options=[ + IO.DynamicCombo.Option( + "happyhorse-1.0-r2v", + [ + IO.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Prompt describing the video. 
Use identifiers such as 'character1' and " + "'character2' to refer to the reference characters.", + ), + IO.Combo.Input( + "resolution", + options=["720P", "1080P"], + ), + IO.Combo.Input( + "ratio", + options=["16:9", "9:16", "1:1", "4:3", "3:4"], + ), + IO.Int.Input( + "duration", + default=5, + min=3, + max=15, + step=1, + display_mode=IO.NumberDisplay.number, + ), + IO.Autogrow.Input( + "reference_images", + template=IO.Autogrow.TemplateNames( + IO.Image.Input("reference_image"), + names=[ + "image1", + "image2", + "image3", + "image4", + "image5", + "image6", + "image7", + "image8", + "image9", + ], + min=1, + ), + ), + ], + ), + ], + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + step=1, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + tooltip="Seed to use for generation.", + ), + IO.Boolean.Input( + "watermark", + default=False, + tooltip="Whether to add an AI-generated watermark to the result.", + advanced=True, + ), + ], + outputs=[ + IO.Video.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["model", "model.resolution", "model.duration"]), + expr=""" + ( + $res := $lookup(widgets, "model.resolution"); + $dur := $lookup(widgets, "model.duration"); + $ppsTable := { "720p": 0.14, "1080p": 0.24 }; + $pps := $lookup($ppsTable, $res); + { "type": "usd", "usd": $pps * $dur } + ) + """, + ), + ) + + @classmethod + async def execute( + cls, + model: dict, + seed: int, + watermark: bool, + ): + validate_string(model["prompt"], strip_whitespace=False, min_length=1) + media = [] + reference_images = model.get("reference_images", {}) + for key in reference_images: + media.append( + Wan27MediaItem( + type="reference_image", + url=await upload_image_to_comfyapi(cls, image=reference_images[key]), + ) + ) + if not media: + raise ValueError("At least one reference reference image must be provided.") + + initial_response = await sync_op( + cls, + ApiEndpoint( + path="/proxy/wan/api/v1/services/aigc/video-generation/video-synthesis", + method="POST", + ), + response_model=TaskCreationResponse, + data=Wan27ReferenceVideoTaskCreationRequest( + model=model["model"], + input=Wan27ReferenceVideoInputField( + prompt=model["prompt"], + negative_prompt=None, + media=media, + ), + parameters=Wan27ReferenceVideoParametersField( + resolution=model["resolution"], + ratio=model["ratio"], + duration=model["duration"], + watermark=watermark, + seed=seed, + ), + ), + ) + if not initial_response.output: + raise Exception(f"An unknown error occurred: {initial_response.code} - {initial_response.message}") + response = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/wan/api/v1/tasks/{initial_response.output.task_id}"), + response_model=VideoTaskStatusResponse, + status_extractor=lambda x: x.output.task_status, + poll_interval=7, + ) + return IO.NodeOutput(await download_url_to_video_output(response.output.video_url)) + + class WanApiExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[IO.ComfyNode]]: @@ -1660,6 +2211,10 @@ class WanApiExtension(ComfyExtension): Wan2VideoContinuationApi, Wan2VideoEditApi, Wan2ReferenceVideoApi, + HappyHorseTextToVideoApi, + HappyHorseImageToVideoApi, + HappyHorseVideoEditApi, + HappyHorseReferenceVideoApi, ] From 1233f077b1b96ec1f8c7c39e83bbe1a734b36424 Mon Sep 17 00:00:00 2001 From: "Daxiong (Lin)" Date: Tue, 28 Apr 2026 01:06:03 +0800 Subject: [PATCH 
002/102] chore: update workflow templates to v0.9.63 (#13586) Co-authored-by: Jedrzej Kosinski --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 6c7457e03..66a130a9b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.42.15 -comfyui-workflow-templates==0.9.62 +comfyui-workflow-templates==0.9.63 comfyui-embedded-docs==0.4.4 torch torchsde From 75143eeb06b14bc93db71de207945f6f888be4e0 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 27 Apr 2026 13:24:36 -0400 Subject: [PATCH 003/102] ComfyUI v0.20.0 --- comfyui_version.py | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/comfyui_version.py b/comfyui_version.py index 2a1eb9905..9c547a228 100644 --- a/comfyui_version.py +++ b/comfyui_version.py @@ -1,3 +1,3 @@ # This file is automatically generated by the build process when version is # updated in pyproject.toml. -__version__ = "0.19.3" +__version__ = "0.20.0" diff --git a/pyproject.toml b/pyproject.toml index 8fa92ecbe..785837c09 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ComfyUI" -version = "0.19.3" +version = "0.20.0" readme = "README.md" license = { file = "LICENSE" } requires-python = ">=3.10" From 64b8457f55cd7fb54ca7a956d9c73b505e903e0c Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 27 Apr 2026 16:10:14 -0400 Subject: [PATCH 004/102] ComfyUI v0.20.1 because github is broken again and messed up my release. --- comfyui_version.py | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/comfyui_version.py b/comfyui_version.py index 9c547a228..53e7156e3 100644 --- a/comfyui_version.py +++ b/comfyui_version.py @@ -1,3 +1,3 @@ # This file is automatically generated by the build process when version is # updated in pyproject.toml. -__version__ = "0.20.0" +__version__ = "0.20.1" diff --git a/pyproject.toml b/pyproject.toml index 785837c09..633dac517 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ComfyUI" -version = "0.20.0" +version = "0.20.1" readme = "README.md" license = { file = "LICENSE" } requires-python = ">=3.10" From 3cbf015578ac04c30b10078887a774a4b4e45fe4 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 27 Apr 2026 16:44:12 -0700 Subject: [PATCH 005/102] Read audio and video at the same time in video loader node. 
(#13591) --- comfy_api/latest/_input_impl/video_types.py | 135 ++++++++++++-------- 1 file changed, 82 insertions(+), 53 deletions(-) diff --git a/comfy_api/latest/_input_impl/video_types.py b/comfy_api/latest/_input_impl/video_types.py index eb4d3701d..812b3eb30 100644 --- a/comfy_api/latest/_input_impl/video_types.py +++ b/comfy_api/latest/_input_impl/video_types.py @@ -12,6 +12,7 @@ import numpy as np import math import torch from .._util import VideoContainer, VideoCodec, VideoComponents +import logging def container_to_output_format(container_format: str | None) -> str | None: @@ -238,32 +239,86 @@ class VideoFromFile(VideoInput): start_time = max(self._get_raw_duration() + self.__start_time, 0) else: start_time = self.__start_time + # Get video frames frames = [] + audio_frames = [] alphas = None start_pts = int(start_time / video_stream.time_base) end_pts = int((start_time + self.__duration) / video_stream.time_base) - container.seek(start_pts, stream=video_stream) - image_format = 'gbrpf32le' - for frame in container.decode(video_stream): - if alphas is None: - for comp in frame.format.components: - if comp.is_alpha: - alphas = [] - image_format = 'gbrapf32le' - break - if frame.pts < start_pts: - continue - if self.__duration and frame.pts >= end_pts: + if start_pts != 0: + container.seek(start_pts, stream=video_stream) + + image_format = 'gbrpf32le' + audio = None + + streams = [video_stream] + has_first_audio_frame = False + checked_alpha = False + + # Default to False so we decode until EOF if duration is 0 + video_done = False + audio_done = True + + if len(container.streams.audio): + audio_stream = container.streams.audio[-1] + streams += [audio_stream] + resampler = av.audio.resampler.AudioResampler(format='fltp') + audio_done = False + + for packet in container.demux(*streams): + if video_done and audio_done: break - img = frame.to_ndarray(format=image_format) # shape: (H, W, 4) - if alphas is None: - frames.append(torch.from_numpy(img)) - else: - frames.append(torch.from_numpy(img[..., :-1])) - alphas.append(torch.from_numpy(img[..., -1:])) + if packet.stream.type == "video": + if video_done: + continue + try: + for frame in packet.decode(): + if frame.pts < start_pts: + continue + if self.__duration and frame.pts >= end_pts: + video_done = True + break + + if not checked_alpha: + for comp in frame.format.components: + if comp.is_alpha: + alphas = [] + image_format = 'gbrapf32le' + break + checked_alpha = True + + img = frame.to_ndarray(format=image_format) # shape: (H, W, 4) + if alphas is None: + frames.append(torch.from_numpy(img)) + else: + frames.append(torch.from_numpy(img[..., :-1])) + alphas.append(torch.from_numpy(img[..., -1:])) + except av.error.InvalidDataError: + logging.info("pyav decode error") + + elif packet.stream.type == "audio": + if audio_done: + continue + + aframes = itertools.chain.from_iterable( + map(resampler.resample, packet.decode()) + ) + for frame in aframes: + if self.__duration and frame.time > start_time + self.__duration: + audio_done = True + break + + if not has_first_audio_frame: + offset_seconds = start_time - frame.pts * audio_stream.time_base + to_skip = max(0, int(offset_seconds * audio_stream.sample_rate)) + if to_skip < frame.samples: + has_first_audio_frame = True + audio_frames.append(frame.to_ndarray()[..., to_skip:]) + else: + audio_frames.append(frame.to_ndarray()) images = torch.stack(frames) if len(frames) > 0 else torch.zeros(0, 0, 0, 3) if alphas is not None: @@ -272,42 +327,16 @@ class VideoFromFile(VideoInput): # Get 
frame rate frame_rate = Fraction(video_stream.average_rate) if video_stream.average_rate else Fraction(1) - # Get audio if available - audio = None - container.seek(start_pts, stream=video_stream) - # Use last stream for consistency - if len(container.streams.audio): - audio_stream = container.streams.audio[-1] - audio_frames = [] - resample = av.audio.resampler.AudioResampler(format='fltp').resample - frames = itertools.chain.from_iterable( - map(resample, container.decode(audio_stream)) - ) + if len(audio_frames) > 0: + audio_data = np.concatenate(audio_frames, axis=1) # shape: (channels, total_samples) + if self.__duration: + audio_data = audio_data[..., :int(self.__duration * audio_stream.sample_rate)] - has_first_frame = False - for frame in frames: - offset_seconds = start_time - frame.pts * audio_stream.time_base - to_skip = max(0, int(offset_seconds * audio_stream.sample_rate)) - if to_skip < frame.samples: - has_first_frame = True - break - if has_first_frame: - audio_frames.append(frame.to_ndarray()[..., to_skip:]) - - for frame in frames: - if self.__duration and frame.time > start_time + self.__duration: - break - audio_frames.append(frame.to_ndarray()) # shape: (channels, samples) - if len(audio_frames) > 0: - audio_data = np.concatenate(audio_frames, axis=1) # shape: (channels, total_samples) - if self.__duration: - audio_data = audio_data[..., :int(self.__duration * audio_stream.sample_rate)] - - audio_tensor = torch.from_numpy(audio_data).unsqueeze(0) # shape: (1, channels, total_samples) - audio = AudioInput({ - "waveform": audio_tensor, - "sample_rate": int(audio_stream.sample_rate) if audio_stream.sample_rate else 1, - }) + audio_tensor = torch.from_numpy(audio_data).unsqueeze(0) # shape: (1, channels, total_samples) + audio = AudioInput({ + "waveform": audio_tensor, + "sample_rate": int(audio_stream.sample_rate) if audio_stream.sample_rate else 1, + }) metadata = container.metadata return VideoComponents(images=images, alpha=alphas, audio=audio, frame_rate=frame_rate, metadata=metadata) From b47f15f25a2a96b5e9fd7efb4ffa5d988038d6ff Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Tue, 28 Apr 2026 12:22:31 +1000 Subject: [PATCH 006/102] fix: Handle un-inited meta-tensors in models (fixes a CPU TE crash) (CORE-67) (#13578) --- comfy/model_patcher.py | 5 ++++- comfy/ops.py | 16 +++++++++++++++- 2 files changed, 19 insertions(+), 2 deletions(-) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index ee56f8523..e259aed63 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -31,6 +31,7 @@ import comfy.float import comfy.hooks import comfy.lora import comfy.model_management +import comfy.ops import comfy.patcher_extension import comfy.utils from comfy.comfy_types import UnetWrapperFunction @@ -856,7 +857,9 @@ class ModelPatcher: if m.comfy_patched_weights == True: continue - for param in params: + for param, param_value in params.items(): + if hasattr(m, "comfy_cast_weights") and getattr(param_value, "is_meta", False): + comfy.ops.disable_weight_init._zero_init_parameter(m, param) key = key_param_name_to_key(n, param) self.unpin_weight(key) self.patch_weight_to_device(key, device_to=device_to) diff --git a/comfy/ops.py b/comfy/ops.py index 7a9b4b84c..050f7cda0 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -79,14 +79,21 @@ def cast_to_input(weight, input, non_blocking=False, copy=True): return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy) -def 
cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requant): +def materialize_meta_param(s, param_keys): + for param_key in param_keys: + param = getattr(s, param_key, None) + if param is not None and getattr(param, "is_meta", False): + setattr(s, param_key, torch.nn.Parameter(torch.zeros(param.shape, dtype=param.dtype), requires_grad=param.requires_grad)) + +def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requant): #vbar doesn't support CPU weights, but some custom nodes have weird paths #that might switch the layer to the CPU and expect it to work. We have to take #a clone conservatively as we are mmapped and some SFT files are packed misaligned #If you are a custom node author reading this, please move your layer to the GPU #or declare your ModelPatcher as CPU in the first place. if comfy.model_management.is_device_cpu(device): + materialize_meta_param(s, ["weight", "bias"]) weight = s.weight.to(dtype=dtype, copy=True) if isinstance(weight, QuantizedTensor): weight = weight.dequantize() @@ -108,6 +115,7 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu xfer_dest = comfy_aimdo.torch.aimdo_to_tensor(s._v, device) if not resident: + materialize_meta_param(s, ["weight", "bias"]) cast_geometry = comfy.memory_management.tensors_to_geometries([ s.weight, s.bias ]) cast_dest = None @@ -306,6 +314,12 @@ class CastWeightBiasOp: bias_function = [] class disable_weight_init: + @staticmethod + def _zero_init_parameter(module, name): + param = getattr(module, name) + device = None if getattr(param, "is_meta", False) else param.device + setattr(module, name, torch.nn.Parameter(torch.zeros(param.shape, device=device, dtype=param.dtype), requires_grad=False)) + @staticmethod def _lazy_load_from_state_dict(module, state_dict, prefix, local_metadata, missing_keys, unexpected_keys, weight_shape, From ed201fff08fbbd3dbcc500b252a9f41e8051c256 Mon Sep 17 00:00:00 2001 From: Matt Miller Date: Mon, 27 Apr 2026 19:51:33 -0700 Subject: [PATCH 007/102] ci: dispatch tag push to Comfy-Org/cloud (#13541) Fires on v* tag push (earlier than release.published, which can lag) and triggers a repository_dispatch on Comfy-Org/cloud with event_type comfyui_tag_pushed. Legacy desktop dispatch in release-webhook.yml is left untouched. --- .github/workflows/tag-dispatch-cloud.yml | 45 ++++++++++++++++++++++++ 1 file changed, 45 insertions(+) create mode 100644 .github/workflows/tag-dispatch-cloud.yml diff --git a/.github/workflows/tag-dispatch-cloud.yml b/.github/workflows/tag-dispatch-cloud.yml new file mode 100644 index 000000000..53a0e91d6 --- /dev/null +++ b/.github/workflows/tag-dispatch-cloud.yml @@ -0,0 +1,45 @@ +name: Tag Dispatch to Cloud + +on: + push: + tags: + - 'v*' + +jobs: + dispatch-cloud: + runs-on: ubuntu-latest + steps: + - name: Send repository dispatch to cloud + env: + DISPATCH_TOKEN: ${{ secrets.CLOUD_REPO_DISPATCH_TOKEN }} + RELEASE_TAG: ${{ github.ref_name }} + run: | + set -euo pipefail + + if [ -z "${DISPATCH_TOKEN:-}" ]; then + echo "::error::CLOUD_REPO_DISPATCH_TOKEN is required but not set." 
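+            # fail fast: the authenticated dispatch call below cannot succeed without the secret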
+ exit 1 + fi + + RELEASE_URL="https://github.com/${{ github.repository }}/releases/tag/${RELEASE_TAG}" + + PAYLOAD="$(jq -n \ + --arg release_tag "$RELEASE_TAG" \ + --arg release_url "$RELEASE_URL" \ + '{ + event_type: "comfyui_tag_pushed", + client_payload: { + release_tag: $release_tag, + release_url: $release_url + } + }')" + + curl -fsSL \ + -X POST \ + -H "Accept: application/vnd.github+json" \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer ${DISPATCH_TOKEN}" \ + https://api.github.com/repos/Comfy-Org/cloud/dispatches \ + -d "$PAYLOAD" + + echo "✅ Dispatched ComfyUI tag ${RELEASE_TAG} to Comfy-Org/cloud" From c0d77a5d53828b8027a4f333e41473253150b614 Mon Sep 17 00:00:00 2001 From: "Daxiong (Lin)" Date: Tue, 28 Apr 2026 15:59:59 +0800 Subject: [PATCH 008/102] Change the `save 3d model` node's filename prefix to `3d/ComfyUI` (CORE-106) (#12826) * Change save 3d model's filename prefix to 3d/ComfyUI As this node has already changed from `Save GLB` to `Save 3D Model`, using the filename prefix `3d` will be better than `mesh` * use lowercase --------- --- comfy_extras/nodes_hunyuan3d.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy_extras/nodes_hunyuan3d.py b/comfy_extras/nodes_hunyuan3d.py index df0c3e4b1..fa55ead59 100644 --- a/comfy_extras/nodes_hunyuan3d.py +++ b/comfy_extras/nodes_hunyuan3d.py @@ -637,7 +637,7 @@ class SaveGLB(IO.ComfyNode): ], tooltip="Mesh or 3D file to save", ), - IO.String.Input("filename_prefix", default="mesh/ComfyUI"), + IO.String.Input("filename_prefix", default="3d/ComfyUI"), ], hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo] ) From 24de8dc01bc6c857be12f25ba24fb753a48cb0c2 Mon Sep 17 00:00:00 2001 From: Gilad Schreiber Date: Tue, 28 Apr 2026 11:18:19 +0300 Subject: [PATCH 009/102] Fix SolidMask and MaskComposite device mismatch with --gpu-only (#13296) SolidMask had a hardcoded device="cpu" while other nodes (e.g. EmptyImage) follow intermediate_device(). This causes a RuntimeError when MaskComposite combines masks from different device sources under --gpu-only. 
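
A rough sketch of the failure this fixes (shapes are illustrative; assumes
a CUDA device is available):

    import torch

    destination = torch.ones(1, 64, 64, device="cuda")  # mask from a GPU node under --gpu-only
    source = torch.full((1, 64, 64), 0.5)                # old SolidMask output: hardcoded device="cpu"
    destination * source  # RuntimeError: Expected all tensors to be on the same device
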
- SolidMask: use intermediate_device() instead of hardcoded "cpu"
- MaskComposite: align source device to destination before operating

Co-authored-by: Alexis Rolland
Co-authored-by: Jedrzej Kosinski
---
 comfy_extras/nodes_mask.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/comfy_extras/nodes_mask.py b/comfy_extras/nodes_mask.py
index c44602597..8ca947718 100644
--- a/comfy_extras/nodes_mask.py
+++ b/comfy_extras/nodes_mask.py
@@ -2,6 +2,7 @@ import numpy as np
 import scipy.ndimage
 import torch
 import comfy.utils
+import comfy.model_management
 import node_helpers
 from typing_extensions import override
 from comfy_api.latest import ComfyExtension, IO, UI
@@ -188,7 +189,7 @@ class SolidMask(IO.ComfyNode):
 
     @classmethod
     def execute(cls, value, width, height) -> IO.NodeOutput:
-        out = torch.full((1, height, width), value, dtype=torch.float32, device="cpu")
+        out = torch.full((1, height, width), value, dtype=torch.float32, device=comfy.model_management.intermediate_device())
         return IO.NodeOutput(out)
 
     solid = execute  # TODO: remove
@@ -262,6 +263,7 @@ class MaskComposite(IO.ComfyNode):
     def execute(cls, destination, source, x, y, operation) -> IO.NodeOutput:
         output = destination.reshape((-1, destination.shape[-2], destination.shape[-1])).clone()
         source = source.reshape((-1, source.shape[-2], source.shape[-1]))
+        source = source.to(output.device)
 
         left, top = (x, y,)
         right, bottom = (min(left + source.shape[-1], destination.shape[-1]), min(top + source.shape[-2], destination.shape[-2]))

From 13519934ba4220bba47e51c185a63fc837c3d6e2 Mon Sep 17 00:00:00 2001
From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com>
Date: Tue, 28 Apr 2026 13:27:42 -0700
Subject: [PATCH 010/102] Handle metadata rotation in pyav code. (#13605)

---
 comfy_api/latest/_input_impl/video_types.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/comfy_api/latest/_input_impl/video_types.py b/comfy_api/latest/_input_impl/video_types.py
index 812b3eb30..b2daa3d7d 100644
--- a/comfy_api/latest/_input_impl/video_types.py
+++ b/comfy_api/latest/_input_impl/video_types.py
@@ -291,6 +291,9 @@ class VideoFromFile(VideoInput):
                             checked_alpha = True
 
                         img = frame.to_ndarray(format=image_format)  # shape: (H, W, 4)
+                        if frame.rotation != 0:
+                            k = int(round(frame.rotation // 90))
+                            img = np.rot90(img, k=k, axes=(0, 1)).copy()
                         if alphas is None:
                             frames.append(torch.from_numpy(img))
                         else:

From e514119e1e3b73d5f4190295f3847f07ba228ea8 Mon Sep 17 00:00:00 2001
From: rattus <46076784+rattus128@users.noreply.github.com>
Date: Wed, 29 Apr 2026 06:34:37 +1000
Subject: [PATCH 011/102] comfy-aimdo v0.3.0 (#13604)

Comfy-aimdo 0.3.0 contains several major new features:

* multi-GPU support
* ARM support
* AMD support

Refactorings include:

* Linkless architecture - linkage is now performed purely at runtime to
  stop host library lookups completely and only interact with the
  torch-loaded Nvidia stack.
* Elimination of cudart integration on Linux. It's now consistent with
  Windows.

Misc bugfixes and minor features.
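
A quick way to confirm an environment picked up the new pin, using only the
standard library (run after reinstalling requirements):

    from importlib.metadata import version
    print(version("comfy-aimdo"))  # expect "0.3.0"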
--- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 66a130a9b..12c5ff7a9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -23,7 +23,7 @@ SQLAlchemy>=2.0 filelock av>=14.2.0 comfy-kitchen>=0.2.8 -comfy-aimdo==0.2.14 +comfy-aimdo==0.3.0 requests simpleeval>=1.0.0 blake3 From c7a517c2f9d182ea777c7e625ef532865dcff8b6 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 28 Apr 2026 14:59:55 -0700 Subject: [PATCH 012/102] Make pyav loading code handle tRNS PNG. (#13607) --- comfy_api/latest/_input_impl/video_types.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy_api/latest/_input_impl/video_types.py b/comfy_api/latest/_input_impl/video_types.py index b2daa3d7d..6ed41bba8 100644 --- a/comfy_api/latest/_input_impl/video_types.py +++ b/comfy_api/latest/_input_impl/video_types.py @@ -284,7 +284,7 @@ class VideoFromFile(VideoInput): if not checked_alpha: for comp in frame.format.components: - if comp.is_alpha: + if comp.is_alpha or frame.format.name == "pal8": alphas = [] image_format = 'gbrapf32le' break From dae3d3475179fd796e2901e7d1f9e00aeb515a2f Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 28 Apr 2026 15:15:06 -0700 Subject: [PATCH 013/102] Use pyav to load images instead of pillow. (#13594) On failure (ex: animated webp files) fallback to old pillow code. This should fix the extra precision in high bit depth images (like 16 bit PNG) being discarded when loaded by Pillow and potentially add support for more image formats. --- nodes.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/nodes.py b/nodes.py index fb83da896..e73a0712e 100644 --- a/nodes.py +++ b/nodes.py @@ -32,7 +32,7 @@ import comfy.controlnet from comfy.comfy_types import IO, ComfyNodeABC, InputTypeDict, FileLocator from comfy_api.internal import register_versions, ComfyAPIWithVersion from comfy_api.version_list import supported_versions -from comfy_api.latest import io, ComfyExtension +from comfy_api.latest import io, ComfyExtension, InputImpl import comfy.clip_vision @@ -1716,6 +1716,10 @@ class LoadImage: def load_image(self, image): image_path = folder_paths.get_annotated_filepath(image) + components = InputImpl.VideoFromFile(image_path).get_components() + if components.images.shape[0] > 0: + return (components.images, 1.0 - components.alpha[..., -1] if components.alpha is not None else torch.zeros((components.images.shape[0], 64, 64), dtype=torch.float32, device="cpu")) + img = node_helpers.pillow(Image.open, image_path) output_images = [] From fce0398470fe3ecdb7ab4c5c69555ad0fcbdc09e Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Wed, 29 Apr 2026 09:15:02 +1000 Subject: [PATCH 014/102] dynamicVRAM + --cache-ram 2 (CORE-117) (#13603) * pinned_memory: remove JIT RAM pressure release This doesn't work, as freeing intermediates for pins needs to be higher-priority than freeing pins-for-pins if and when you are going to do that. So this is too late as pins-for-pins is model load time and we dont have JIT pins-for-pins. * cacheing: Add a filter to only free intermediates from inactive wfs This is to get priorities in amongst pins straight. * mm: free inactive-ram from RAM cache first Stuff from inactive workflows should be freed before anything else. 
* caching: purge old ModelPatchers first Dont try and score them, just dump them at the first sign of trouble if they arent part of the workflow. --- comfy/model_management.py | 1 + comfy/pinned_memory.py | 6 ------ comfy_execution/caching.py | 8 +++++++- execution.py | 2 +- 4 files changed, 9 insertions(+), 8 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 3b39d6080..95af40012 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -663,6 +663,7 @@ def minimum_inference_memory(): def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins_required=0, ram_required=0): cleanup_models_gc() + comfy.memory_management.extra_ram_release(max(pins_required, ram_required)) unloaded_model = [] can_unload = [] unloaded_models = [] diff --git a/comfy/pinned_memory.py b/comfy/pinned_memory.py index 6f142282d..6d3ba367a 100644 --- a/comfy/pinned_memory.py +++ b/comfy/pinned_memory.py @@ -2,7 +2,6 @@ import comfy.model_management import comfy.memory_management import comfy_aimdo.host_buffer import comfy_aimdo.torch -import psutil from comfy.cli_args import args @@ -12,11 +11,6 @@ def get_pin(module): def pin_memory(module): if module.pin_failed or args.disable_pinned_memory or get_pin(module) is not None: return - #FIXME: This is a RAM cache trigger event - ram_headroom = comfy.memory_management.RAM_CACHE_HEADROOM - #we split the difference and assume half the RAM cache headroom is for us - if ram_headroom > 0 and psutil.virtual_memory().available < (ram_headroom * 0.5): - comfy.memory_management.extra_ram_release(ram_headroom) size = comfy.memory_management.vram_aligned_size([ module.weight, module.bias ]) diff --git a/comfy_execution/caching.py b/comfy_execution/caching.py index f9c913bdb..ba1e8bc84 100644 --- a/comfy_execution/caching.py +++ b/comfy_execution/caching.py @@ -5,6 +5,7 @@ import psutil import time import torch from typing import Sequence, Mapping, Dict +from comfy.model_patcher import ModelPatcher from comfy_execution.graph import DynamicPrompt from abc import ABC, abstractmethod @@ -523,13 +524,15 @@ class RAMPressureCache(LRUCache): self.timestamps[self.cache_key_set.get_data_key(node_id)] = time.time() super().set_local(node_id, value) - def ram_release(self, target): + def ram_release(self, target, free_active=False): if psutil.virtual_memory().available >= target: return clean_list = [] for key, cache_entry in self.cache.items(): + if not free_active and self.used_generation[key] == self.generation: + continue oom_score = RAM_CACHE_OLD_WORKFLOW_OOM_MULTIPLIER ** (self.generation - self.used_generation[key]) ram_usage = RAM_CACHE_DEFAULT_RAM_USAGE @@ -542,6 +545,9 @@ class RAMPressureCache(LRUCache): scan_list_for_ram_usage(output) elif isinstance(output, torch.Tensor) and output.device.type == 'cpu': ram_usage += output.numel() * output.element_size() + elif isinstance(output, ModelPatcher) and self.used_generation[key] != self.generation: + #old ModelPatchers are the first to go + ram_usage = 1e30 scan_list_for_ram_usage(cache_entry.outputs) oom_score *= ram_usage diff --git a/execution.py b/execution.py index e15eb4bda..5a6d3404c 100644 --- a/execution.py +++ b/execution.py @@ -779,7 +779,7 @@ class PromptExecutor: if self.cache_type == CacheType.RAM_PRESSURE: comfy.model_management.free_memory(0, None, pins_required=ram_headroom, ram_required=ram_headroom) - comfy.memory_management.extra_ram_release(ram_headroom) + ram_release_callback(ram_headroom, free_active=True) else: # Only execute when the 
while-loop ends without break # Send cached UI for intermediate output nodes that weren't executed From 0e25a6936ef41a56af87a4af174fa519da73b37c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jukka=20Sepp=C3=A4nen?= <40791699+kijai@users.noreply.github.com> Date: Wed, 29 Apr 2026 22:15:10 +0300 Subject: [PATCH 015/102] Reduce video tiny VAE peak VRAM and decode time (CORE-127) (#13617) * Update taehv.py * Simplify * Simplify pixel_unshuffle dispatch --- comfy/taesd/taehv.py | 37 ++++++++++++++++++++++--------------- 1 file changed, 22 insertions(+), 15 deletions(-) diff --git a/comfy/taesd/taehv.py b/comfy/taesd/taehv.py index 6c06ce19d..696013200 100644 --- a/comfy/taesd/taehv.py +++ b/comfy/taesd/taehv.py @@ -7,6 +7,7 @@ from tqdm.auto import tqdm from collections import namedtuple, deque import comfy.ops +import comfy.model_management operations=comfy.ops.disable_weight_init DecoderResult = namedtuple("DecoderResult", ("frame", "memory")) @@ -47,11 +48,14 @@ class TGrow(nn.Module): x = self.conv(x) return x.reshape(-1, C, H, W) -def apply_model_with_memblocks(model, x, parallel, show_progress_bar): +def apply_model_with_memblocks(model, x, parallel, show_progress_bar, output_device=None, + patch_size=1, decode=False): B, T, C, H, W = x.shape if parallel: x = x.reshape(B*T, C, H, W) + if not decode and patch_size > 1: + x = F.pixel_unshuffle(x, patch_size) # parallel over input timesteps, iterate over blocks for b in tqdm(model, disable=not show_progress_bar): if isinstance(b, MemBlock): @@ -62,20 +66,27 @@ def apply_model_with_memblocks(model, x, parallel, show_progress_bar): x = b(x, mem) else: x = b(x) - BT, C, H, W = x.shape - T = BT // B - x = x.view(B, T, C, H, W) + if decode and patch_size > 1: + x = F.pixel_shuffle(x, patch_size) + x = x.view(B, x.shape[0] // B, *x.shape[1:]) + x = x.to(output_device) else: out = [] - work_queue = deque([TWorkItem(xt, 0) for t, xt in enumerate(x.reshape(B, T * C, H, W).chunk(T, dim=1))]) + # Chunk along the time dim directly (chunks are [B,1,C,H,W] views, squeeze to [B,C,H,W] views). + # Avoids forcing a contiguous copy when x is non-contiguous (e.g. after movedim in encode/decode). 
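+        # Each TWorkItem pairs a frame tensor with the index of the next model
+        # block to apply to it; every frame enters the queue at block index 0.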
+ work_queue = deque([TWorkItem(xt.squeeze(1), 0) for xt in x.chunk(T, dim=1)]) progress_bar = tqdm(range(T), disable=not show_progress_bar) mem = [None] * len(model) while work_queue: xt, i = work_queue.popleft() if i == 0: progress_bar.update(1) + if not decode and patch_size > 1: + xt = F.pixel_unshuffle(xt, patch_size) if i == len(model): - out.append(xt) + if decode and patch_size > 1: + xt = F.pixel_shuffle(xt, patch_size) + out.append(xt.to(output_device)) del xt else: b = model[i] @@ -165,24 +176,20 @@ class TAEHV(nn.Module): def encode(self, x, **kwargs): x = x.movedim(2, 1) # [B, C, T, H, W] -> [B, T, C, H, W] - if self.patch_size > 1: - B, T, C, H, W = x.shape - x = x.reshape(B * T, C, H, W) - x = F.pixel_unshuffle(x, self.patch_size) - x = x.reshape(B, T, C * self.patch_size ** 2, H // self.patch_size, W // self.patch_size) if x.shape[1] % self.t_downscale != 0: # pad at end to multiple of t_downscale n_pad = self.t_downscale - x.shape[1] % self.t_downscale padding = x[:, -1:].repeat_interleave(n_pad, dim=1) x = torch.cat([x, padding], 1) - x = apply_model_with_memblocks(self.encoder, x, self.parallel, self.show_progress_bar).movedim(2, 1) + x = apply_model_with_memblocks(self.encoder, x, self.parallel, self.show_progress_bar, + patch_size=self.patch_size).movedim(2, 1) return self.process_out(x) def decode(self, x, **kwargs): x = x.unsqueeze(0) if x.ndim == 4 else x # [T, C, H, W] -> [1, T, C, H, W] x = x.movedim(1, 2) if x.shape[1] != self.latent_channels else x # [B, T, C, H, W] or [B, C, T, H, W] x = self.process_in(x).movedim(2, 1) # [B, C, T, H, W] -> [B, T, C, H, W] - x = apply_model_with_memblocks(self.decoder, x, self.parallel, self.show_progress_bar) - if self.patch_size > 1: - x = F.pixel_shuffle(x, self.patch_size) + x = apply_model_with_memblocks(self.decoder, x, self.parallel, self.show_progress_bar, + output_device=comfy.model_management.intermediate_device(), + patch_size=self.patch_size, decode=True) return x[:, self.frames_to_trim:].movedim(2, 1) From 5eeae3f1d823e3f072896d6c72185e3c84373739 Mon Sep 17 00:00:00 2001 From: Talmaj Date: Thu, 30 Apr 2026 01:30:08 +0200 Subject: [PATCH 016/102] Cogvideox (#13402) --------- Co-authored-by: kijai <40791699+kijai@users.noreply.github.com> Co-authored-by: Talmaj Marinc --- comfy/latent_formats.py | 7 + comfy/ldm/cogvideo/__init__.py | 0 comfy/ldm/cogvideo/model.py | 573 ++++++++++++++++++++++++++++++++ comfy/ldm/cogvideo/vae.py | 566 +++++++++++++++++++++++++++++++ comfy/model_base.py | 60 ++++ comfy/model_detection.py | 48 +++ comfy/model_sampling.py | 24 ++ comfy/sd.py | 12 + comfy/supported_models.py | 49 ++- comfy/text_encoders/cogvideo.py | 6 + nodes.py | 2 +- 11 files changed, 1345 insertions(+), 2 deletions(-) create mode 100644 comfy/ldm/cogvideo/__init__.py create mode 100644 comfy/ldm/cogvideo/model.py create mode 100644 comfy/ldm/cogvideo/vae.py create mode 100644 comfy/text_encoders/cogvideo.py diff --git a/comfy/latent_formats.py b/comfy/latent_formats.py index 6a57bca1c..0f4059ebe 100644 --- a/comfy/latent_formats.py +++ b/comfy/latent_formats.py @@ -783,3 +783,10 @@ class ZImagePixelSpace(ChromaRadiance): No VAE encoding/decoding — the model operates directly on RGB pixels. 
""" pass + +class CogVideoX(LatentFormat): + latent_channels = 16 + latent_dimensions = 3 + + def __init__(self): + self.scale_factor = 1.15258426 diff --git a/comfy/ldm/cogvideo/__init__.py b/comfy/ldm/cogvideo/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/comfy/ldm/cogvideo/model.py b/comfy/ldm/cogvideo/model.py new file mode 100644 index 000000000..fb475ed53 --- /dev/null +++ b/comfy/ldm/cogvideo/model.py @@ -0,0 +1,573 @@ +# CogVideoX 3D Transformer - ported to ComfyUI native ops +# Architecture reference: diffusers CogVideoXTransformer3DModel +# Style reference: comfy/ldm/wan/model.py + +import math +import torch +import torch.nn as nn +import torch.nn.functional as F + +from comfy.ldm.modules.attention import optimized_attention +import comfy.patcher_extension +import comfy.ldm.common_dit + + +def _get_1d_rotary_pos_embed(dim, pos, theta=10000.0): + """Returns (cos, sin) each with shape [seq_len, dim]. + + Frequencies are computed at dim//2 resolution then repeat_interleaved + to full dim, matching CogVideoX's interleaved (real, imag) pair format. + """ + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float32, device=pos.device) / dim)) + angles = torch.outer(pos.float(), freqs.float()) + cos = angles.cos().repeat_interleave(2, dim=-1).float() + sin = angles.sin().repeat_interleave(2, dim=-1).float() + return (cos, sin) + + +def apply_rotary_emb(x, freqs_cos_sin): + """Apply CogVideoX rotary embedding to query or key tensor. + + x: [B, heads, seq_len, head_dim] + freqs_cos_sin: (cos, sin) each [seq_len, head_dim//2] + + Uses interleaved pair rotation (same as diffusers CogVideoX/Flux). + head_dim is reshaped to (-1, 2) pairs, rotated, then flattened back. + """ + cos, sin = freqs_cos_sin + cos = cos[None, None, :, :].to(x.device) + sin = sin[None, None, :, :].to(x.device) + + # Interleaved pairs: [B, H, S, D] -> [B, H, S, D//2, 2] -> (real, imag) + x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) + x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3) + + return (x.float() * cos + x_rotated.float() * sin).to(x.dtype) + + +def get_timestep_embedding(timesteps, dim, flip_sin_to_cos=True, downscale_freq_shift=0, scale=1, max_period=10000): + half = dim // 2 + freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=timesteps.device) / half) + args = timesteps[:, None].float() * freqs[None] * scale + embedding = torch.cat([torch.sin(args), torch.cos(args)], dim=-1) + if flip_sin_to_cos: + embedding = torch.cat([embedding[:, half:], embedding[:, :half]], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + +def get_3d_sincos_pos_embed(embed_dim, spatial_size, temporal_size, spatial_interpolation_scale=1.0, temporal_interpolation_scale=1.0, device=None): + if isinstance(spatial_size, int): + spatial_size = (spatial_size, spatial_size) + + grid_w = torch.arange(spatial_size[0], dtype=torch.float32, device=device) / spatial_interpolation_scale + grid_h = torch.arange(spatial_size[1], dtype=torch.float32, device=device) / spatial_interpolation_scale + grid_t = torch.arange(temporal_size, dtype=torch.float32, device=device) / temporal_interpolation_scale + + grid_t, grid_h, grid_w = torch.meshgrid(grid_t, grid_h, grid_w, indexing="ij") + + embed_dim_spatial = 2 * (embed_dim // 3) + embed_dim_temporal = embed_dim // 3 + + pos_embed_spatial = _get_2d_sincos_pos_embed(embed_dim_spatial, grid_h, grid_w, device=device) + 
pos_embed_temporal = _get_1d_sincos_pos_embed(embed_dim_temporal, grid_t[:, 0, 0], device=device) + + T, H, W = grid_t.shape + pos_embed_temporal = pos_embed_temporal.unsqueeze(1).unsqueeze(1).expand(-1, H, W, -1) + pos_embed = torch.cat([pos_embed_temporal, pos_embed_spatial], dim=-1) + + return pos_embed + + +def _get_2d_sincos_pos_embed(embed_dim, grid_h, grid_w, device=None): + T, H, W = grid_h.shape + half_dim = embed_dim // 2 + pos_h = _get_1d_sincos_pos_embed(half_dim, grid_h.reshape(-1), device=device).reshape(T, H, W, half_dim) + pos_w = _get_1d_sincos_pos_embed(half_dim, grid_w.reshape(-1), device=device).reshape(T, H, W, half_dim) + return torch.cat([pos_h, pos_w], dim=-1) + + +def _get_1d_sincos_pos_embed(embed_dim, pos, device=None): + half = embed_dim // 2 + freqs = torch.exp(-math.log(10000.0) * torch.arange(start=0, end=half, dtype=torch.float32, device=device) / half) + args = pos.float().reshape(-1)[:, None] * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if embed_dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + + +class CogVideoXPatchEmbed(nn.Module): + def __init__(self, patch_size=2, patch_size_t=None, in_channels=16, dim=1920, + text_dim=4096, bias=True, sample_width=90, sample_height=60, + sample_frames=49, temporal_compression_ratio=4, + max_text_seq_length=226, spatial_interpolation_scale=1.875, + temporal_interpolation_scale=1.0, use_positional_embeddings=True, + use_learned_positional_embeddings=True, + device=None, dtype=None, operations=None): + super().__init__() + self.patch_size = patch_size + self.patch_size_t = patch_size_t + self.dim = dim + self.sample_height = sample_height + self.sample_width = sample_width + self.sample_frames = sample_frames + self.temporal_compression_ratio = temporal_compression_ratio + self.max_text_seq_length = max_text_seq_length + self.spatial_interpolation_scale = spatial_interpolation_scale + self.temporal_interpolation_scale = temporal_interpolation_scale + self.use_positional_embeddings = use_positional_embeddings + self.use_learned_positional_embeddings = use_learned_positional_embeddings + + if patch_size_t is None: + self.proj = operations.Conv2d(in_channels, dim, kernel_size=patch_size, stride=patch_size, bias=bias, device=device, dtype=dtype) + else: + self.proj = operations.Linear(in_channels * patch_size * patch_size * patch_size_t, dim, device=device, dtype=dtype) + + self.text_proj = operations.Linear(text_dim, dim, device=device, dtype=dtype) + + if use_positional_embeddings or use_learned_positional_embeddings: + persistent = use_learned_positional_embeddings + pos_embedding = self._get_positional_embeddings(sample_height, sample_width, sample_frames) + self.register_buffer("pos_embedding", pos_embedding, persistent=persistent) + + def _get_positional_embeddings(self, sample_height, sample_width, sample_frames, device=None): + post_patch_height = sample_height // self.patch_size + post_patch_width = sample_width // self.patch_size + post_time_compression_frames = (sample_frames - 1) // self.temporal_compression_ratio + 1 + if self.patch_size_t is not None: + post_time_compression_frames = post_time_compression_frames // self.patch_size_t + num_patches = post_patch_height * post_patch_width * post_time_compression_frames + + pos_embedding = get_3d_sincos_pos_embed( + self.dim, + (post_patch_width, post_patch_height), + post_time_compression_frames, + self.spatial_interpolation_scale, + 
self.temporal_interpolation_scale, + device=device, + ) + pos_embedding = pos_embedding.reshape(-1, self.dim) + joint_pos_embedding = pos_embedding.new_zeros( + 1, self.max_text_seq_length + num_patches, self.dim, requires_grad=False + ) + joint_pos_embedding.data[:, self.max_text_seq_length:].copy_(pos_embedding) + return joint_pos_embedding + + def forward(self, text_embeds, image_embeds): + input_dtype = text_embeds.dtype + text_embeds = self.text_proj(text_embeds.to(self.text_proj.weight.dtype)).to(input_dtype) + batch_size, num_frames, channels, height, width = image_embeds.shape + + proj_dtype = self.proj.weight.dtype + if self.patch_size_t is None: + image_embeds = image_embeds.reshape(-1, channels, height, width) + image_embeds = self.proj(image_embeds.to(proj_dtype)).to(input_dtype) + image_embeds = image_embeds.view(batch_size, num_frames, *image_embeds.shape[1:]) + image_embeds = image_embeds.flatten(3).transpose(2, 3) + image_embeds = image_embeds.flatten(1, 2) + else: + p = self.patch_size + p_t = self.patch_size_t + image_embeds = image_embeds.permute(0, 1, 3, 4, 2) + image_embeds = image_embeds.reshape( + batch_size, num_frames // p_t, p_t, height // p, p, width // p, p, channels + ) + image_embeds = image_embeds.permute(0, 1, 3, 5, 7, 2, 4, 6).flatten(4, 7).flatten(1, 3) + image_embeds = self.proj(image_embeds.to(proj_dtype)).to(input_dtype) + + embeds = torch.cat([text_embeds, image_embeds], dim=1).contiguous() + + if self.use_positional_embeddings or self.use_learned_positional_embeddings: + text_seq_length = text_embeds.shape[1] + num_image_patches = image_embeds.shape[1] + + if self.use_learned_positional_embeddings: + image_pos = self.pos_embedding[ + :, self.max_text_seq_length:self.max_text_seq_length + num_image_patches + ].to(device=embeds.device, dtype=embeds.dtype) + else: + image_pos = get_3d_sincos_pos_embed( + self.dim, + (width // self.patch_size, height // self.patch_size), + num_image_patches // ((height // self.patch_size) * (width // self.patch_size)), + self.spatial_interpolation_scale, + self.temporal_interpolation_scale, + device=embeds.device, + ).reshape(1, num_image_patches, self.dim).to(dtype=embeds.dtype) + + # Build joint: zeros for text + sincos for image + joint_pos = torch.zeros(1, text_seq_length + num_image_patches, self.dim, device=embeds.device, dtype=embeds.dtype) + joint_pos[:, text_seq_length:] = image_pos + embeds = embeds + joint_pos + + return embeds + + +class CogVideoXLayerNormZero(nn.Module): + def __init__(self, time_dim, dim, elementwise_affine=True, eps=1e-5, bias=True, + device=None, dtype=None, operations=None): + super().__init__() + self.silu = nn.SiLU() + self.linear = operations.Linear(time_dim, 6 * dim, bias=bias, device=device, dtype=dtype) + self.norm = operations.LayerNorm(dim, eps=eps, elementwise_affine=elementwise_affine, device=device, dtype=dtype) + + def forward(self, hidden_states, encoder_hidden_states, temb): + shift, scale, gate, enc_shift, enc_scale, enc_gate = self.linear(self.silu(temb)).chunk(6, dim=1) + hidden_states = self.norm(hidden_states) * (1 + scale)[:, None, :] + shift[:, None, :] + encoder_hidden_states = self.norm(encoder_hidden_states) * (1 + enc_scale)[:, None, :] + enc_shift[:, None, :] + return hidden_states, encoder_hidden_states, gate[:, None, :], enc_gate[:, None, :] + + +class CogVideoXAdaLayerNorm(nn.Module): + def __init__(self, time_dim, dim, elementwise_affine=True, eps=1e-5, + device=None, dtype=None, operations=None): + super().__init__() + self.silu = nn.SiLU() + self.linear = 
operations.Linear(time_dim, 2 * dim, device=device, dtype=dtype) + self.norm = operations.LayerNorm(dim, eps=eps, elementwise_affine=elementwise_affine, device=device, dtype=dtype) + + def forward(self, x, temb): + temb = self.linear(self.silu(temb)) + shift, scale = temb.chunk(2, dim=1) + x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :] + return x + + +class CogVideoXBlock(nn.Module): + def __init__(self, dim, num_heads, head_dim, time_dim, + eps=1e-5, ff_inner_dim=None, ff_bias=True, + device=None, dtype=None, operations=None): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.head_dim = head_dim + + self.norm1 = CogVideoXLayerNormZero(time_dim, dim, eps=eps, device=device, dtype=dtype, operations=operations) + + # Self-attention (joint text + latent) + self.q = operations.Linear(dim, dim, bias=True, device=device, dtype=dtype) + self.k = operations.Linear(dim, dim, bias=True, device=device, dtype=dtype) + self.v = operations.Linear(dim, dim, bias=True, device=device, dtype=dtype) + self.norm_q = operations.LayerNorm(head_dim, eps=1e-6, elementwise_affine=True, device=device, dtype=dtype) + self.norm_k = operations.LayerNorm(head_dim, eps=1e-6, elementwise_affine=True, device=device, dtype=dtype) + self.attn_out = operations.Linear(dim, dim, bias=True, device=device, dtype=dtype) + + self.norm2 = CogVideoXLayerNormZero(time_dim, dim, eps=eps, device=device, dtype=dtype, operations=operations) + + # Feed-forward (GELU approximate) + inner_dim = ff_inner_dim or dim * 4 + self.ff_proj = operations.Linear(dim, inner_dim, bias=ff_bias, device=device, dtype=dtype) + self.ff_out = operations.Linear(inner_dim, dim, bias=ff_bias, device=device, dtype=dtype) + + def forward(self, hidden_states, encoder_hidden_states, temb, image_rotary_emb=None, transformer_options=None): + if transformer_options is None: + transformer_options = {} + text_seq_length = encoder_hidden_states.size(1) + + # Norm & modulate + norm_hidden, norm_encoder, gate_msa, enc_gate_msa = self.norm1(hidden_states, encoder_hidden_states, temb) + + # Joint self-attention + qkv_input = torch.cat([norm_encoder, norm_hidden], dim=1) + b, s, _ = qkv_input.shape + n, d = self.num_heads, self.head_dim + + q = self.q(qkv_input).view(b, s, n, d) + k = self.k(qkv_input).view(b, s, n, d) + v = self.v(qkv_input) + + q = self.norm_q(q).view(b, s, n, d) + k = self.norm_k(k).view(b, s, n, d) + + # Apply rotary embeddings to image tokens only (diffusers format: [B, heads, seq, head_dim]) + if image_rotary_emb is not None: + q_img = q[:, text_seq_length:].transpose(1, 2) # [B, heads, img_seq, head_dim] + k_img = k[:, text_seq_length:].transpose(1, 2) + q_img = apply_rotary_emb(q_img, image_rotary_emb) + k_img = apply_rotary_emb(k_img, image_rotary_emb) + q = torch.cat([q[:, :text_seq_length], q_img.transpose(1, 2)], dim=1) + k = torch.cat([k[:, :text_seq_length], k_img.transpose(1, 2)], dim=1) + + attn_out = optimized_attention( + q.reshape(b, s, n * d), + k.reshape(b, s, n * d), + v, + heads=self.num_heads, + transformer_options=transformer_options, + ) + + attn_out = self.attn_out(attn_out) + + attn_encoder, attn_hidden = attn_out.split([text_seq_length, s - text_seq_length], dim=1) + + hidden_states = hidden_states + gate_msa * attn_hidden + encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_encoder + + # Norm & modulate for FF + norm_hidden, norm_encoder, gate_ff, enc_gate_ff = self.norm2(hidden_states, encoder_hidden_states, temb) + + # Feed-forward (GELU on concatenated text + latent) + 
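+        # Text and latent tokens share one feed-forward pass over the
+        # concatenated sequence and are re-split below, mirroring the joint
+        # attention above.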
ff_input = torch.cat([norm_encoder, norm_hidden], dim=1) + ff_output = self.ff_out(F.gelu(self.ff_proj(ff_input), approximate="tanh")) + + hidden_states = hidden_states + gate_ff * ff_output[:, text_seq_length:] + encoder_hidden_states = encoder_hidden_states + enc_gate_ff * ff_output[:, :text_seq_length] + + return hidden_states, encoder_hidden_states + + +class CogVideoXTransformer3DModel(nn.Module): + def __init__(self, + num_attention_heads=30, + attention_head_dim=64, + in_channels=16, + out_channels=16, + flip_sin_to_cos=True, + freq_shift=0, + time_embed_dim=512, + ofs_embed_dim=None, + text_embed_dim=4096, + num_layers=30, + dropout=0.0, + attention_bias=True, + sample_width=90, + sample_height=60, + sample_frames=49, + patch_size=2, + patch_size_t=None, + temporal_compression_ratio=4, + max_text_seq_length=226, + spatial_interpolation_scale=1.875, + temporal_interpolation_scale=1.0, + use_rotary_positional_embeddings=False, + use_learned_positional_embeddings=False, + patch_bias=True, + image_model=None, + device=None, + dtype=None, + operations=None, + ): + super().__init__() + self.dtype = dtype + dim = num_attention_heads * attention_head_dim + self.dim = dim + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + self.in_channels = in_channels + self.out_channels = out_channels + self.patch_size = patch_size + self.patch_size_t = patch_size_t + self.max_text_seq_length = max_text_seq_length + self.use_rotary_positional_embeddings = use_rotary_positional_embeddings + + # 1. Patch embedding + self.patch_embed = CogVideoXPatchEmbed( + patch_size=patch_size, + patch_size_t=patch_size_t, + in_channels=in_channels, + dim=dim, + text_dim=text_embed_dim, + bias=patch_bias, + sample_width=sample_width, + sample_height=sample_height, + sample_frames=sample_frames, + temporal_compression_ratio=temporal_compression_ratio, + max_text_seq_length=max_text_seq_length, + spatial_interpolation_scale=spatial_interpolation_scale, + temporal_interpolation_scale=temporal_interpolation_scale, + use_positional_embeddings=not use_rotary_positional_embeddings, + use_learned_positional_embeddings=use_learned_positional_embeddings, + device=device, dtype=torch.float32, operations=operations, + ) + + # 2. Time embedding + self.time_proj_dim = dim + self.time_proj_flip = flip_sin_to_cos + self.time_proj_shift = freq_shift + self.time_embedding_linear_1 = operations.Linear(dim, time_embed_dim, device=device, dtype=dtype) + self.time_embedding_act = nn.SiLU() + self.time_embedding_linear_2 = operations.Linear(time_embed_dim, time_embed_dim, device=device, dtype=dtype) + + # Optional OFS embedding (CogVideoX 1.5 I2V) + self.ofs_proj_dim = ofs_embed_dim + if ofs_embed_dim: + self.ofs_embedding_linear_1 = operations.Linear(ofs_embed_dim, ofs_embed_dim, device=device, dtype=dtype) + self.ofs_embedding_act = nn.SiLU() + self.ofs_embedding_linear_2 = operations.Linear(ofs_embed_dim, ofs_embed_dim, device=device, dtype=dtype) + else: + self.ofs_embedding_linear_1 = None + + # 3. Transformer blocks + self.blocks = nn.ModuleList([ + CogVideoXBlock( + dim=dim, + num_heads=num_attention_heads, + head_dim=attention_head_dim, + time_dim=time_embed_dim, + eps=1e-5, + device=device, dtype=dtype, operations=operations, + ) + for _ in range(num_layers) + ]) + + self.norm_final = operations.LayerNorm(dim, eps=1e-5, elementwise_affine=True, device=device, dtype=dtype) + + # 4. 
Output + self.norm_out = CogVideoXAdaLayerNorm( + time_dim=time_embed_dim, dim=dim, eps=1e-5, + device=device, dtype=dtype, operations=operations, + ) + + if patch_size_t is None: + output_dim = patch_size * patch_size * out_channels + else: + output_dim = patch_size * patch_size * patch_size_t * out_channels + + self.proj_out = operations.Linear(dim, output_dim, device=device, dtype=dtype) + + self.spatial_interpolation_scale = spatial_interpolation_scale + self.temporal_interpolation_scale = temporal_interpolation_scale + self.temporal_compression_ratio = temporal_compression_ratio + + def forward(self, x, timestep, context, ofs=None, transformer_options=None, **kwargs): + if transformer_options is None: + transformer_options = {} + return comfy.patcher_extension.WrapperExecutor.new_class_executor( + self._forward, + self, + comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options) + ).execute(x, timestep, context, ofs, transformer_options, **kwargs) + + def _forward(self, x, timestep, context, ofs=None, transformer_options=None, **kwargs): + if transformer_options is None: + transformer_options = {} + # ComfyUI passes [B, C, T, H, W] + batch_size, channels, t, h, w = x.shape + + # Pad to patch size (temporal + spatial), same pattern as WAN + p_t = self.patch_size_t if self.patch_size_t is not None else 1 + x = comfy.ldm.common_dit.pad_to_patch_size(x, (p_t, self.patch_size, self.patch_size)) + + # CogVideoX expects [B, T, C, H, W] + x = x.permute(0, 2, 1, 3, 4) + batch_size, num_frames, channels, height, width = x.shape + + # Time embedding + t_emb = get_timestep_embedding(timestep, self.time_proj_dim, self.time_proj_flip, self.time_proj_shift) + t_emb = t_emb.to(dtype=x.dtype) + emb = self.time_embedding_linear_2(self.time_embedding_act(self.time_embedding_linear_1(t_emb))) + + if self.ofs_embedding_linear_1 is not None and ofs is not None: + ofs_emb = get_timestep_embedding(ofs, self.ofs_proj_dim, self.time_proj_flip, self.time_proj_shift) + ofs_emb = ofs_emb.to(dtype=x.dtype) + ofs_emb = self.ofs_embedding_linear_2(self.ofs_embedding_act(self.ofs_embedding_linear_1(ofs_emb))) + emb = emb + ofs_emb + + # Patch embedding + hidden_states = self.patch_embed(context, x) + + text_seq_length = context.shape[1] + encoder_hidden_states = hidden_states[:, :text_seq_length] + hidden_states = hidden_states[:, text_seq_length:] + + # Rotary embeddings (if used) + image_rotary_emb = None + if self.use_rotary_positional_embeddings: + post_patch_height = height // self.patch_size + post_patch_width = width // self.patch_size + if self.patch_size_t is None: + post_time = num_frames + else: + post_time = num_frames // self.patch_size_t + image_rotary_emb = self._get_rotary_emb(post_patch_height, post_patch_width, post_time, device=x.device) + + # Transformer blocks + for i, block in enumerate(self.blocks): + hidden_states, encoder_hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=emb, + image_rotary_emb=image_rotary_emb, + transformer_options=transformer_options, + ) + + hidden_states = self.norm_final(hidden_states) + + # Output projection + hidden_states = self.norm_out(hidden_states, temb=emb) + hidden_states = self.proj_out(hidden_states) + + # Unpatchify + p = self.patch_size + p_t = self.patch_size_t + + if p_t is None: + output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p) + output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4) 
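+            # Shape walkthrough for p = 2 (illustrative):
+            #   [B, T*(H/2)*(W/2), 4*C] -> reshape -> [B, T, H/2, W/2, C, 2, 2]
+            #   -> permute -> [B, T, C, H/2, 2, W/2, 2]
+            #   -> flatten -> [B, T, C, H, W]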
+ else: + output = hidden_states.reshape( + batch_size, (num_frames + p_t - 1) // p_t, height // p, width // p, -1, p_t, p, p + ) + output = output.permute(0, 1, 5, 4, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(1, 2) + + # Back to ComfyUI format [B, C, T, H, W] and crop padding + output = output.permute(0, 2, 1, 3, 4)[:, :, :t, :h, :w] + return output + + def _get_rotary_emb(self, h, w, t, device): + """Compute CogVideoX 3D rotary positional embeddings. + + For CogVideoX 1.5 (patch_size_t != None): uses "slice" mode — grid positions + are integer arange computed at max_size, then sliced to actual size. + For CogVideoX 1.0 (patch_size_t == None): uses "linspace" mode with crop coords + scaled by spatial_interpolation_scale. + """ + d = self.attention_head_dim + dim_t = d // 4 + dim_h = d // 8 * 3 + dim_w = d // 8 * 3 + + if self.patch_size_t is not None: + # CogVideoX 1.5: "slice" mode — positions are simple integer indices + # Compute at max(sample_size, actual_size) then slice to actual + base_h = self.patch_embed.sample_height // self.patch_size + base_w = self.patch_embed.sample_width // self.patch_size + max_h = max(base_h, h) + max_w = max(base_w, w) + + grid_h = torch.arange(max_h, device=device, dtype=torch.float32) + grid_w = torch.arange(max_w, device=device, dtype=torch.float32) + grid_t = torch.arange(t, device=device, dtype=torch.float32) + else: + # CogVideoX 1.0: "linspace" mode with interpolation scale + grid_h = torch.linspace(0, h - 1, h, device=device, dtype=torch.float32) * self.spatial_interpolation_scale + grid_w = torch.linspace(0, w - 1, w, device=device, dtype=torch.float32) * self.spatial_interpolation_scale + grid_t = torch.arange(t, device=device, dtype=torch.float32) + + freqs_t = _get_1d_rotary_pos_embed(dim_t, grid_t) + freqs_h = _get_1d_rotary_pos_embed(dim_h, grid_h) + freqs_w = _get_1d_rotary_pos_embed(dim_w, grid_w) + + t_cos, t_sin = freqs_t + h_cos, h_sin = freqs_h + w_cos, w_sin = freqs_w + + # Slice to actual size (for "slice" mode where grids may be larger) + t_cos, t_sin = t_cos[:t], t_sin[:t] + h_cos, h_sin = h_cos[:h], h_sin[:h] + w_cos, w_sin = w_cos[:w], w_sin[:w] + + # Broadcast and concatenate into [T*H*W, head_dim] + t_cos = t_cos[:, None, None, :].expand(-1, h, w, -1) + t_sin = t_sin[:, None, None, :].expand(-1, h, w, -1) + h_cos = h_cos[None, :, None, :].expand(t, -1, w, -1) + h_sin = h_sin[None, :, None, :].expand(t, -1, w, -1) + w_cos = w_cos[None, None, :, :].expand(t, h, -1, -1) + w_sin = w_sin[None, None, :, :].expand(t, h, -1, -1) + + cos = torch.cat([t_cos, h_cos, w_cos], dim=-1).reshape(t * h * w, -1) + sin = torch.cat([t_sin, h_sin, w_sin], dim=-1).reshape(t * h * w, -1) + return (cos, sin) diff --git a/comfy/ldm/cogvideo/vae.py b/comfy/ldm/cogvideo/vae.py new file mode 100644 index 000000000..d4e6f321e --- /dev/null +++ b/comfy/ldm/cogvideo/vae.py @@ -0,0 +1,566 @@ +# CogVideoX VAE - ported to ComfyUI native ops +# Architecture reference: diffusers AutoencoderKLCogVideoX +# Style reference: comfy/ldm/wan/vae.py + +import numpy as np + +import torch +import torch.nn as nn +import torch.nn.functional as F + +import comfy.ops +ops = comfy.ops.disable_weight_init + + +class CausalConv3d(nn.Module): + """Causal 3D convolution with temporal padding. + + Uses comfy.ops.Conv3d with autopad='causal_zero' fast path: when input has + a single temporal frame and no cache, the 3D conv weight is sliced to act + as a 2D conv, avoiding computation on zero-padded temporal dimensions. 
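+
+    Note: the temporal padding implemented in forward() repeats the first
+    input frame (replicate-style) rather than inserting zeros, and the
+    single-frame fast path sums the conv weight over its temporal axis,
+    which is numerically equivalent to convolving over that repeated frame.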
+ """ + def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, pad_mode="constant"): + super().__init__() + if isinstance(kernel_size, int): + kernel_size = (kernel_size,) * 3 + + time_kernel, height_kernel, width_kernel = kernel_size + self.time_kernel_size = time_kernel + self.pad_mode = pad_mode + + height_pad = (height_kernel - 1) // 2 + width_pad = (width_kernel - 1) // 2 + self.time_causal_padding = (width_pad, width_pad, height_pad, height_pad, time_kernel - 1, 0) + + stride = stride if isinstance(stride, tuple) else (stride, 1, 1) + dilation = (dilation, 1, 1) + self.conv = ops.Conv3d( + in_channels, out_channels, kernel_size, + stride=stride, dilation=dilation, + padding=(0, height_pad, width_pad), + ) + + def forward(self, x, conv_cache=None): + if self.pad_mode == "replicate": + x = F.pad(x, self.time_causal_padding, mode="replicate") + conv_cache = None + else: + kernel_t = self.time_kernel_size + if kernel_t > 1: + if conv_cache is None and x.shape[2] == 1: + # Fast path: single frame, no cache. All temporal padding + # frames are copies of the input (replicate-style), so the + # 3D conv reduces to a 2D conv with summed temporal kernel. + w = comfy.ops.cast_to_input(self.conv.weight, x) + b = comfy.ops.cast_to_input(self.conv.bias, x) if self.conv.bias is not None else None + w2d = w.sum(dim=2, keepdim=True) + out = F.conv3d(x, w2d, b, + self.conv.stride, self.conv.padding, + self.conv.dilation, self.conv.groups) + return out, None + cached = [conv_cache] if conv_cache is not None else [x[:, :, :1]] * (kernel_t - 1) + x = torch.cat(cached + [x], dim=2) + conv_cache = x[:, :, -self.time_kernel_size + 1:].clone() if self.time_kernel_size > 1 else None + + out = self.conv(x) + return out, conv_cache + + +def _interpolate_zq(zq, target_size): + """Interpolate latent z to target (T, H, W), matching CogVideoX's first-frame-special handling.""" + t = target_size[0] + if t > 1 and t % 2 == 1: + z_first = F.interpolate(zq[:, :, :1], size=(1, target_size[1], target_size[2])) + z_rest = F.interpolate(zq[:, :, 1:], size=(t - 1, target_size[1], target_size[2])) + return torch.cat([z_first, z_rest], dim=2) + return F.interpolate(zq, size=target_size) + + +class SpatialNorm3D(nn.Module): + """Spatially conditioned normalization.""" + def __init__(self, f_channels, zq_channels, groups=32): + super().__init__() + self.norm_layer = ops.GroupNorm(num_channels=f_channels, num_groups=groups, eps=1e-6, affine=True) + self.conv_y = CausalConv3d(zq_channels, f_channels, kernel_size=1, stride=1) + self.conv_b = CausalConv3d(zq_channels, f_channels, kernel_size=1, stride=1) + + def forward(self, f, zq, conv_cache=None): + new_cache = {} + conv_cache = conv_cache or {} + + if zq.shape[-3:] != f.shape[-3:]: + zq = _interpolate_zq(zq, f.shape[-3:]) + + conv_y, new_cache["conv_y"] = self.conv_y(zq, conv_cache=conv_cache.get("conv_y")) + conv_b, new_cache["conv_b"] = self.conv_b(zq, conv_cache=conv_cache.get("conv_b")) + + return self.norm_layer(f) * conv_y + conv_b, new_cache + + +class ResnetBlock3D(nn.Module): + """3D ResNet block with optional spatial norm.""" + def __init__(self, in_channels, out_channels=None, temb_channels=512, groups=32, + eps=1e-6, act_fn="silu", spatial_norm_dim=None, pad_mode="first"): + super().__init__() + out_channels = out_channels or in_channels + self.in_channels = in_channels + self.out_channels = out_channels + self.spatial_norm_dim = spatial_norm_dim + + if act_fn == "silu": + self.nonlinearity = nn.SiLU() + elif act_fn == "swish": + 
self.nonlinearity = nn.SiLU() + else: + self.nonlinearity = nn.SiLU() + + if spatial_norm_dim is None: + self.norm1 = ops.GroupNorm(num_channels=in_channels, num_groups=groups, eps=eps) + self.norm2 = ops.GroupNorm(num_channels=out_channels, num_groups=groups, eps=eps) + else: + self.norm1 = SpatialNorm3D(in_channels, spatial_norm_dim, groups=groups) + self.norm2 = SpatialNorm3D(out_channels, spatial_norm_dim, groups=groups) + + self.conv1 = CausalConv3d(in_channels, out_channels, kernel_size=3, pad_mode=pad_mode) + + if temb_channels > 0: + self.temb_proj = ops.Linear(temb_channels, out_channels) + + self.conv2 = CausalConv3d(out_channels, out_channels, kernel_size=3, pad_mode=pad_mode) + + if in_channels != out_channels: + self.conv_shortcut = ops.Conv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) + else: + self.conv_shortcut = None + + def forward(self, x, temb=None, zq=None, conv_cache=None): + new_cache = {} + conv_cache = conv_cache or {} + residual = x + + if zq is not None: + x, new_cache["norm1"] = self.norm1(x, zq, conv_cache=conv_cache.get("norm1")) + else: + x = self.norm1(x) + + x = self.nonlinearity(x) + x, new_cache["conv1"] = self.conv1(x, conv_cache=conv_cache.get("conv1")) + + if temb is not None and hasattr(self, "temb_proj"): + x = x + self.temb_proj(self.nonlinearity(temb))[:, :, None, None, None] + + if zq is not None: + x, new_cache["norm2"] = self.norm2(x, zq, conv_cache=conv_cache.get("norm2")) + else: + x = self.norm2(x) + + x = self.nonlinearity(x) + x, new_cache["conv2"] = self.conv2(x, conv_cache=conv_cache.get("conv2")) + + if self.conv_shortcut is not None: + residual = self.conv_shortcut(residual) + + return x + residual, new_cache + + +class Downsample3D(nn.Module): + """3D downsampling with optional temporal compression.""" + def __init__(self, in_channels, out_channels, kernel_size=3, stride=2, padding=0, compress_time=False): + super().__init__() + self.conv = ops.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding) + self.compress_time = compress_time + + def forward(self, x): + if self.compress_time: + b, c, t, h, w = x.shape + x = x.permute(0, 3, 4, 1, 2).reshape(b * h * w, c, t) + if t % 2 == 1: + x_first, x_rest = x[..., 0], x[..., 1:] + if x_rest.shape[-1] > 0: + x_rest = F.avg_pool1d(x_rest, kernel_size=2, stride=2) + x = torch.cat([x_first[..., None], x_rest], dim=-1) + x = x.reshape(b, h, w, c, x.shape[-1]).permute(0, 3, 4, 1, 2) + else: + x = F.avg_pool1d(x, kernel_size=2, stride=2) + x = x.reshape(b, h, w, c, x.shape[-1]).permute(0, 3, 4, 1, 2) + + pad = (0, 1, 0, 1) + x = F.pad(x, pad, mode="constant", value=0) + b, c, t, h, w = x.shape + x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w) + x = self.conv(x) + x = x.reshape(b, t, x.shape[1], x.shape[2], x.shape[3]).permute(0, 2, 1, 3, 4) + return x + + +class Upsample3D(nn.Module): + """3D upsampling with optional temporal decompression.""" + def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, compress_time=False): + super().__init__() + self.conv = ops.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding) + self.compress_time = compress_time + + def forward(self, x): + if self.compress_time: + if x.shape[2] > 1 and x.shape[2] % 2 == 1: + x_first, x_rest = x[:, :, 0], x[:, :, 1:] + x_first = F.interpolate(x_first, scale_factor=2.0) + x_rest = F.interpolate(x_rest, scale_factor=2.0) + x = torch.cat([x_first[:, :, None, :, :], x_rest], dim=2) + elif x.shape[2] > 1: + x = 
F.interpolate(x, scale_factor=2.0) + else: + x = x.squeeze(2) + x = F.interpolate(x, scale_factor=2.0) + x = x[:, :, None, :, :] + else: + b, c, t, h, w = x.shape + x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w) + x = F.interpolate(x, scale_factor=2.0) + x = x.reshape(b, t, c, *x.shape[2:]).permute(0, 2, 1, 3, 4) + + b, c, t, h, w = x.shape + x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w) + x = self.conv(x) + x = x.reshape(b, t, *x.shape[1:]).permute(0, 2, 1, 3, 4) + return x + + +class DownBlock3D(nn.Module): + def __init__(self, in_channels, out_channels, temb_channels=0, num_layers=1, + eps=1e-6, act_fn="silu", groups=32, add_downsample=True, + compress_time=False, pad_mode="first"): + super().__init__() + self.resnets = nn.ModuleList([ + ResnetBlock3D( + in_channels=in_channels if i == 0 else out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + groups=groups, eps=eps, act_fn=act_fn, pad_mode=pad_mode, + ) + for i in range(num_layers) + ]) + self.downsamplers = nn.ModuleList([Downsample3D(out_channels, out_channels, compress_time=compress_time)]) if add_downsample else None + + def forward(self, x, temb=None, zq=None, conv_cache=None): + new_cache = {} + conv_cache = conv_cache or {} + for i, resnet in enumerate(self.resnets): + x, new_cache[f"resnet_{i}"] = resnet(x, temb, zq, conv_cache=conv_cache.get(f"resnet_{i}")) + if self.downsamplers is not None: + for ds in self.downsamplers: + x = ds(x) + return x, new_cache + + +class MidBlock3D(nn.Module): + def __init__(self, in_channels, temb_channels=0, num_layers=1, + eps=1e-6, act_fn="silu", groups=32, spatial_norm_dim=None, pad_mode="first"): + super().__init__() + self.resnets = nn.ModuleList([ + ResnetBlock3D( + in_channels=in_channels, out_channels=in_channels, + temb_channels=temb_channels, groups=groups, eps=eps, + act_fn=act_fn, spatial_norm_dim=spatial_norm_dim, pad_mode=pad_mode, + ) + for _ in range(num_layers) + ]) + + def forward(self, x, temb=None, zq=None, conv_cache=None): + new_cache = {} + conv_cache = conv_cache or {} + for i, resnet in enumerate(self.resnets): + x, new_cache[f"resnet_{i}"] = resnet(x, temb, zq, conv_cache=conv_cache.get(f"resnet_{i}")) + return x, new_cache + + +class UpBlock3D(nn.Module): + def __init__(self, in_channels, out_channels, temb_channels=0, num_layers=1, + eps=1e-6, act_fn="silu", groups=32, spatial_norm_dim=16, + add_upsample=True, compress_time=False, pad_mode="first"): + super().__init__() + self.resnets = nn.ModuleList([ + ResnetBlock3D( + in_channels=in_channels if i == 0 else out_channels, + out_channels=out_channels, + temb_channels=temb_channels, groups=groups, eps=eps, + act_fn=act_fn, spatial_norm_dim=spatial_norm_dim, pad_mode=pad_mode, + ) + for i in range(num_layers) + ]) + self.upsamplers = nn.ModuleList([Upsample3D(out_channels, out_channels, compress_time=compress_time)]) if add_upsample else None + + def forward(self, x, temb=None, zq=None, conv_cache=None): + new_cache = {} + conv_cache = conv_cache or {} + for i, resnet in enumerate(self.resnets): + x, new_cache[f"resnet_{i}"] = resnet(x, temb, zq, conv_cache=conv_cache.get(f"resnet_{i}")) + if self.upsamplers is not None: + for us in self.upsamplers: + x = us(x) + return x, new_cache + + +class Encoder3D(nn.Module): + def __init__(self, in_channels=3, out_channels=16, + block_out_channels=(128, 256, 256, 512), + layers_per_block=3, act_fn="silu", + eps=1e-6, groups=32, pad_mode="first", + temporal_compression_ratio=4): + super().__init__() + temporal_compress_level = 
int(np.log2(temporal_compression_ratio)) + + self.conv_in = CausalConv3d(in_channels, block_out_channels[0], kernel_size=3, pad_mode=pad_mode) + + self.down_blocks = nn.ModuleList() + output_channel = block_out_channels[0] + for i in range(len(block_out_channels)): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final = i == len(block_out_channels) - 1 + compress_time = i < temporal_compress_level + + self.down_blocks.append(DownBlock3D( + in_channels=input_channel, out_channels=output_channel, + temb_channels=0, num_layers=layers_per_block, + eps=eps, act_fn=act_fn, groups=groups, + add_downsample=not is_final, compress_time=compress_time, + )) + + self.mid_block = MidBlock3D( + in_channels=block_out_channels[-1], temb_channels=0, + num_layers=2, eps=eps, act_fn=act_fn, groups=groups, pad_mode=pad_mode, + ) + + self.norm_out = ops.GroupNorm(groups, block_out_channels[-1], eps=1e-6) + self.conv_act = nn.SiLU() + self.conv_out = CausalConv3d(block_out_channels[-1], 2 * out_channels, kernel_size=3, pad_mode=pad_mode) + + def forward(self, x, conv_cache=None): + new_cache = {} + conv_cache = conv_cache or {} + + x, new_cache["conv_in"] = self.conv_in(x, conv_cache=conv_cache.get("conv_in")) + + for i, block in enumerate(self.down_blocks): + key = f"down_block_{i}" + x, new_cache[key] = block(x, None, None, conv_cache.get(key)) + + x, new_cache["mid_block"] = self.mid_block(x, None, None, conv_cache=conv_cache.get("mid_block")) + + x = self.norm_out(x) + x = self.conv_act(x) + x, new_cache["conv_out"] = self.conv_out(x, conv_cache=conv_cache.get("conv_out")) + + return x, new_cache + + +class Decoder3D(nn.Module): + def __init__(self, in_channels=16, out_channels=3, + block_out_channels=(128, 256, 256, 512), + layers_per_block=3, act_fn="silu", + eps=1e-6, groups=32, pad_mode="first", + temporal_compression_ratio=4): + super().__init__() + reversed_channels = list(reversed(block_out_channels)) + temporal_compress_level = int(np.log2(temporal_compression_ratio)) + + self.conv_in = CausalConv3d(in_channels, reversed_channels[0], kernel_size=3, pad_mode=pad_mode) + + self.mid_block = MidBlock3D( + in_channels=reversed_channels[0], temb_channels=0, + num_layers=2, eps=eps, act_fn=act_fn, groups=groups, + spatial_norm_dim=in_channels, pad_mode=pad_mode, + ) + + self.up_blocks = nn.ModuleList() + output_channel = reversed_channels[0] + for i in range(len(block_out_channels)): + prev_channel = output_channel + output_channel = reversed_channels[i] + is_final = i == len(block_out_channels) - 1 + compress_time = i < temporal_compress_level + + self.up_blocks.append(UpBlock3D( + in_channels=prev_channel, out_channels=output_channel, + temb_channels=0, num_layers=layers_per_block + 1, + eps=eps, act_fn=act_fn, groups=groups, + spatial_norm_dim=in_channels, + add_upsample=not is_final, compress_time=compress_time, + )) + + self.norm_out = SpatialNorm3D(reversed_channels[-1], in_channels, groups=groups) + self.conv_act = nn.SiLU() + self.conv_out = CausalConv3d(reversed_channels[-1], out_channels, kernel_size=3, pad_mode=pad_mode) + + def forward(self, sample, conv_cache=None): + new_cache = {} + conv_cache = conv_cache or {} + + x, new_cache["conv_in"] = self.conv_in(sample, conv_cache=conv_cache.get("conv_in")) + + x, new_cache["mid_block"] = self.mid_block(x, None, sample, conv_cache=conv_cache.get("mid_block")) + + for i, block in enumerate(self.up_blocks): + key = f"up_block_{i}" + x, new_cache[key] = block(x, None, sample, conv_cache=conv_cache.get(key)) + + x, 
new_cache["norm_out"] = self.norm_out(x, sample, conv_cache=conv_cache.get("norm_out")) + x = self.conv_act(x) + x, new_cache["conv_out"] = self.conv_out(x, conv_cache=conv_cache.get("conv_out")) + + return x, new_cache + + + +class AutoencoderKLCogVideoX(nn.Module): + """CogVideoX VAE. Spatial tiling/slicing handled by ComfyUI's VAE wrapper. + + Uses rolling temporal decode: conv_in + mid_block + temporal up_blocks run + on the full (low-res) tensor, then the expensive spatial-only up_blocks + + norm_out + conv_out are processed in small temporal chunks with conv_cache + carrying causal state between chunks. This keeps peak VRAM proportional to + chunk_size rather than total frame count. + """ + + def __init__(self, + in_channels=3, out_channels=3, + block_out_channels=(128, 256, 256, 512), + latent_channels=16, layers_per_block=3, + act_fn="silu", eps=1e-6, groups=32, + temporal_compression_ratio=4, + ): + super().__init__() + self.latent_channels = latent_channels + self.temporal_compression_ratio = temporal_compression_ratio + + self.encoder = Encoder3D( + in_channels=in_channels, out_channels=latent_channels, + block_out_channels=block_out_channels, layers_per_block=layers_per_block, + act_fn=act_fn, eps=eps, groups=groups, + temporal_compression_ratio=temporal_compression_ratio, + ) + self.decoder = Decoder3D( + in_channels=latent_channels, out_channels=out_channels, + block_out_channels=block_out_channels, layers_per_block=layers_per_block, + act_fn=act_fn, eps=eps, groups=groups, + temporal_compression_ratio=temporal_compression_ratio, + ) + + self.num_latent_frames_batch_size = 2 + self.num_sample_frames_batch_size = 8 + + def encode(self, x): + t = x.shape[2] + frame_batch = self.num_sample_frames_batch_size + remainder = t % frame_batch + conv_cache = None + enc = [] + + # Process remainder frames first so only the first chunk can have an + # odd temporal dimension — where Downsample3D's first-frame-special + # handling in temporal compression is actually correct. + if remainder > 0: + chunk, conv_cache = self.encoder(x[:, :, :remainder], conv_cache=conv_cache) + enc.append(chunk.to(x.device)) + + for start in range(remainder, t, frame_batch): + chunk, conv_cache = self.encoder(x[:, :, start:start + frame_batch], conv_cache=conv_cache) + enc.append(chunk.to(x.device)) + + enc = torch.cat(enc, dim=2) + mean, _ = enc.chunk(2, dim=1) + return mean + + def decode(self, z): + return self._decode_rolling(z) + + def _decode_batched(self, z): + """Original batched decode - processes 2 latent frames through full decoder.""" + t = z.shape[2] + frame_batch = self.num_latent_frames_batch_size + num_batches = max(t // frame_batch, 1) + conv_cache = None + dec = [] + for i in range(num_batches): + remaining = t % frame_batch + start = frame_batch * i + (0 if i == 0 else remaining) + end = frame_batch * (i + 1) + remaining + chunk, conv_cache = self.decoder(z[:, :, start:end], conv_cache=conv_cache) + dec.append(chunk.cpu()) + return torch.cat(dec, dim=2).to(z.device) + + def _decode_rolling(self, z): + """Rolling decode - processes low-res layers on full tensor, then rolls + through expensive high-res layers in temporal chunks.""" + decoder = self.decoder + device = z.device + + # Determine which up_blocks have temporal upsample vs spatial-only. + # Temporal up_blocks are cheap (low res), spatial-only are expensive. 
+ temporal_compress_level = int(np.log2(self.temporal_compression_ratio)) + split_at = temporal_compress_level # first N up_blocks do temporal upsample + + # Phase 1: conv_in + mid_block + temporal up_blocks on full tensor (low/medium res) + x, _ = decoder.conv_in(z) + x, _ = decoder.mid_block(x, None, z) + + for i in range(split_at): + x, _ = decoder.up_blocks[i](x, None, z) + + # Phase 2: remaining spatial-only up_blocks + norm_out + conv_out in temporal chunks + remaining_blocks = list(range(split_at, len(decoder.up_blocks))) + chunk_size = 4 # pixel frames per chunk through high-res layers + t_expanded = x.shape[2] + + if t_expanded <= chunk_size or len(remaining_blocks) == 0: + # Small enough to process in one go + for i in remaining_blocks: + x, _ = decoder.up_blocks[i](x, None, z) + x, _ = decoder.norm_out(x, z) + x = decoder.conv_act(x) + x, _ = decoder.conv_out(x) + return x + + # Expand z temporally once to match Phase 2's time dimension. + # z stays at latent spatial resolution so this is small (~16 MB vs ~1.3 GB + # for the old approach of pre-interpolating to every pixel resolution). + z_time_expanded = _interpolate_zq(z, (t_expanded, z.shape[3], z.shape[4])) + + # Process in temporal chunks, interpolating spatially per-chunk to avoid + # allocating full [B, C, t_expanded, H, W] tensors at each resolution. + dec_out = [] + conv_caches = {} + + for chunk_start in range(0, t_expanded, chunk_size): + chunk_end = min(chunk_start + chunk_size, t_expanded) + x_chunk = x[:, :, chunk_start:chunk_end] + z_t_chunk = z_time_expanded[:, :, chunk_start:chunk_end] + z_spatial_cache = {} + + for i in remaining_blocks: + block = decoder.up_blocks[i] + cache_key = f"up_block_{i}" + hw_key = (x_chunk.shape[3], x_chunk.shape[4]) + if hw_key not in z_spatial_cache: + if z_t_chunk.shape[3] == hw_key[0] and z_t_chunk.shape[4] == hw_key[1]: + z_spatial_cache[hw_key] = z_t_chunk + else: + z_spatial_cache[hw_key] = F.interpolate(z_t_chunk, size=(z_t_chunk.shape[2], hw_key[0], hw_key[1])) + x_chunk, new_cache = block(x_chunk, None, z_spatial_cache[hw_key], conv_cache=conv_caches.get(cache_key)) + conv_caches[cache_key] = new_cache + + hw_key = (x_chunk.shape[3], x_chunk.shape[4]) + if hw_key not in z_spatial_cache: + z_spatial_cache[hw_key] = F.interpolate(z_t_chunk, size=(z_t_chunk.shape[2], hw_key[0], hw_key[1])) + x_chunk, new_cache = decoder.norm_out(x_chunk, z_spatial_cache[hw_key], conv_cache=conv_caches.get("norm_out")) + conv_caches["norm_out"] = new_cache + x_chunk = decoder.conv_act(x_chunk) + x_chunk, new_cache = decoder.conv_out(x_chunk, conv_cache=conv_caches.get("conv_out")) + conv_caches["conv_out"] = new_cache + + dec_out.append(x_chunk.cpu()) + del z_spatial_cache + + del x, z_time_expanded + return torch.cat(dec_out, dim=2).to(device) diff --git a/comfy/model_base.py b/comfy/model_base.py index 787ea1145..50dab5782 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -52,6 +52,7 @@ import comfy.ldm.qwen_image.model import comfy.ldm.kandinsky5.model import comfy.ldm.anima.model import comfy.ldm.ace.ace_step15 +import comfy.ldm.cogvideo.model import comfy.ldm.rt_detr.rtdetr_v4 import comfy.ldm.ernie.model import comfy.ldm.sam3.detector @@ -81,6 +82,7 @@ class ModelType(Enum): IMG_TO_IMG = 9 FLOW_COSMOS = 10 IMG_TO_IMG_FLOW = 11 + V_PREDICTION_DDPM = 12 def model_sampling(model_config, model_type): @@ -115,6 +117,8 @@ def model_sampling(model_config, model_type): s = comfy.model_sampling.ModelSamplingCosmosRFlow elif model_type == ModelType.IMG_TO_IMG_FLOW: c = 
comfy.model_sampling.IMG_TO_IMG_FLOW + elif model_type == ModelType.V_PREDICTION_DDPM: + c = comfy.model_sampling.V_PREDICTION_DDPM class ModelSampling(s, c): pass @@ -1979,3 +1983,59 @@ class ErnieImage(BaseModel): class SAM3(BaseModel): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.sam3.detector.SAM3Model) + +class CogVideoX(BaseModel): + def __init__(self, model_config, model_type=ModelType.V_PREDICTION_DDPM, image_to_video=False, device=None): + super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.cogvideo.model.CogVideoXTransformer3DModel) + self.image_to_video = image_to_video + + def concat_cond(self, **kwargs): + noise = kwargs.get("noise", None) + # Detect extra channels needed (e.g. 32 - 16 = 16 for ref latent) + extra_channels = self.diffusion_model.in_channels - noise.shape[1] + if extra_channels == 0: + return None + + image = kwargs.get("concat_latent_image", None) + device = kwargs["device"] + + if image is None: + shape = list(noise.shape) + shape[1] = extra_channels + return torch.zeros(shape, dtype=noise.dtype, layout=noise.layout, device=noise.device) + + latent_dim = self.latent_format.latent_channels + image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center") + + if noise.ndim == 5 and image.ndim == 5: + if image.shape[-3] < noise.shape[-3]: + image = torch.nn.functional.pad(image, (0, 0, 0, 0, 0, noise.shape[-3] - image.shape[-3]), "constant", 0) + elif image.shape[-3] > noise.shape[-3]: + image = image[:, :, :noise.shape[-3]] + + for i in range(0, image.shape[1], latent_dim): + image[:, i:i + latent_dim] = self.process_latent_in(image[:, i:i + latent_dim]) + image = utils.resize_to_batch_size(image, noise.shape[0]) + + if image.shape[1] > extra_channels: + image = image[:, :extra_channels] + elif image.shape[1] < extra_channels: + repeats = extra_channels // image.shape[1] + remainder = extra_channels % image.shape[1] + parts = [image] * repeats + if remainder > 0: + parts.append(image[:, :remainder]) + image = torch.cat(parts, dim=1) + + return image + + def extra_conds(self, **kwargs): + out = super().extra_conds(**kwargs) + # OFS embedding (CogVideoX 1.5 I2V), default 2.0 as used by SparkVSR + if self.diffusion_model.ofs_proj_dim is not None: + ofs = kwargs.get("ofs", None) + if ofs is None: + noise = kwargs.get("noise", None) + ofs = torch.full((noise.shape[0],), 2.0, device=noise.device, dtype=noise.dtype) + out['ofs'] = comfy.conds.CONDRegular(ofs) + return out diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 724a241bf..d9b67dcdf 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -490,6 +490,54 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): return dit_config + if '{}blocks.0.norm1.linear.weight'.format(key_prefix) in state_dict_keys: # CogVideoX + dit_config = {} + dit_config["image_model"] = "cogvideox" + + # Extract config from weight shapes + norm1_weight = state_dict['{}blocks.0.norm1.linear.weight'.format(key_prefix)] + time_embed_dim = norm1_weight.shape[1] + dim = norm1_weight.shape[0] // 6 + + dit_config["num_attention_heads"] = dim // 64 + dit_config["attention_head_dim"] = 64 + dit_config["time_embed_dim"] = time_embed_dim + dit_config["num_layers"] = count_blocks(state_dict_keys, '{}blocks.'.format(key_prefix) + '{}.') + + # Detect in_channels from patch_embed + patch_proj_key = 
'{}patch_embed.proj.weight'.format(key_prefix) + if patch_proj_key in state_dict_keys: + w = state_dict[patch_proj_key] + if w.ndim == 4: + # Conv2d: [out, in, kh, kw] — CogVideoX 1.0 + dit_config["in_channels"] = w.shape[1] + dit_config["patch_size"] = w.shape[2] + elif w.ndim == 2: + # Linear: [out, in_channels * patch_size * patch_size * patch_size_t] — CogVideoX 1.5 + dit_config["patch_size"] = 2 + dit_config["patch_size_t"] = 2 + dit_config["in_channels"] = w.shape[1] // (2 * 2 * 2) # 256 // 8 = 32 + + text_proj_key = '{}patch_embed.text_proj.weight'.format(key_prefix) + if text_proj_key in state_dict_keys: + dit_config["text_embed_dim"] = state_dict[text_proj_key].shape[1] + + # Detect OFS embedding + ofs_key = '{}ofs_embedding_linear_1.weight'.format(key_prefix) + if ofs_key in state_dict_keys: + dit_config["ofs_embed_dim"] = state_dict[ofs_key].shape[1] + + # Detect positional embedding type + pos_key = '{}patch_embed.pos_embedding'.format(key_prefix) + if pos_key in state_dict_keys: + dit_config["use_learned_positional_embeddings"] = True + dit_config["use_rotary_positional_embeddings"] = False + else: + dit_config["use_learned_positional_embeddings"] = False + dit_config["use_rotary_positional_embeddings"] = True + + return dit_config + if '{}head.modulation'.format(key_prefix) in state_dict_keys: # Wan 2.1 dit_config = {} dit_config["image_model"] = "wan2.1" diff --git a/comfy/model_sampling.py b/comfy/model_sampling.py index 13860e6a2..cf2b5db5f 100644 --- a/comfy/model_sampling.py +++ b/comfy/model_sampling.py @@ -54,6 +54,30 @@ class V_PREDICTION(EPS): sigma = reshape_sigma(sigma, model_output.ndim) return model_input * self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2) - model_output * sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5 +class V_PREDICTION_DDPM: + """CogVideoX v-prediction: model receives raw x_t (unscaled), predicts velocity v. 
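+    Using the discrete-schedule identity sigma^2 = (1 - alpha) / alpha
+    (so sqrt(alpha) = 1 / sqrt(sigma^2 + 1) and
+    sqrt(1 - alpha) = sigma / sqrt(sigma^2 + 1)):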
+ x_0 = sqrt(alpha) * x_t - sqrt(1-alpha) * v + = x_t / sqrt(sigma^2 + 1) - v * sigma / sqrt(sigma^2 + 1) + """ + def calculate_input(self, sigma, noise): + return noise + + def calculate_denoised(self, sigma, model_output, model_input): + sigma = reshape_sigma(sigma, model_output.ndim) + return model_input / (sigma ** 2 + 1.0) ** 0.5 - model_output * sigma / (sigma ** 2 + 1.0) ** 0.5 + + def noise_scaling(self, sigma, noise, latent_image, max_denoise=False): + sigma = reshape_sigma(sigma, noise.ndim) + if max_denoise: + noise = noise * torch.sqrt(1.0 + sigma ** 2.0) + else: + noise = noise * sigma + noise += latent_image + return noise + + def inverse_noise_scaling(self, sigma, latent): + return latent + class EDM(V_PREDICTION): def calculate_denoised(self, sigma, model_output, model_input): sigma = reshape_sigma(sigma, model_output.ndim) diff --git a/comfy/sd.py b/comfy/sd.py index 736fe35de..9158317f1 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -18,6 +18,7 @@ import comfy.ldm.wan.vae import comfy.ldm.wan.vae2_2 import comfy.ldm.hunyuan3d.vae import comfy.ldm.ace.vae.music_dcae_pipeline +import comfy.ldm.cogvideo.vae import comfy.ldm.hunyuan_video.vae import comfy.ldm.mmaudio.vae.autoencoder import comfy.pixel_space_convert @@ -652,6 +653,17 @@ class VAE: self.memory_used_encode = lambda shape, dtype: (1400 * 9 * shape[-2] * shape[-1]) * model_management.dtype_size(dtype) self.memory_used_decode = lambda shape, dtype: (3600 * 4 * shape[-2] * shape[-1] * 16 * 16) * model_management.dtype_size(dtype) + elif "decoder.conv_in.conv.weight" in sd and "decoder.mid_block.resnets.0.norm1.norm_layer.weight" in sd: # CogVideoX VAE + self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8) + self.upscale_index_formula = (4, 8, 8) + self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 8, 8) + self.downscale_index_formula = (4, 8, 8) + self.latent_dim = 3 + self.latent_channels = sd["encoder.conv_out.conv.weight"].shape[0] // 2 + self.first_stage_model = comfy.ldm.cogvideo.vae.AutoencoderKLCogVideoX(latent_channels=self.latent_channels) + self.memory_used_decode = lambda shape, dtype: (2800 * max(2, ((shape[2] - 1) * 4) + 1) * shape[3] * shape[4] * (8 * 8)) * model_management.dtype_size(dtype) + self.memory_used_encode = lambda shape, dtype: (1400 * max(1, shape[2]) * shape[3] * shape[4]) * model_management.dtype_size(dtype) + self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32] elif "decoder.conv_in.conv.weight" in sd: ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0} ddconfig["conv3d"] = True diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 8886f32d5..92d0305c5 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -27,6 +27,7 @@ import comfy.text_encoders.anima import comfy.text_encoders.ace15 import comfy.text_encoders.longcat_image import comfy.text_encoders.ernie +import comfy.text_encoders.cogvideo from . import supported_models_base from . 
import latent_formats @@ -1832,6 +1833,52 @@ class SAM31(SAM3): unet_config = {"image_model": "SAM31"} -models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, LongCatImage, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImagePixelSpace, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, WAN21_SCAIL, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima, RT_DETR_v4, ErnieImage, SAM3, SAM31] +class CogVideoX_T2V(supported_models_base.BASE): + unet_config = { + "image_model": "cogvideox", + } + + sampling_settings = { + "linear_start": 0.00085, + "linear_end": 0.012, + "beta_schedule": "linear", + "zsnr": True, + } + + unet_extra_config = {} + latent_format = latent_formats.CogVideoX + + supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32] + + vae_key_prefix = ["vae."] + text_encoder_key_prefix = ["text_encoders."] + + def get_model(self, state_dict, prefix="", device=None): + # CogVideoX 1.5 (patch_size_t=2) has different training base dimensions for RoPE + if self.unet_config.get("patch_size_t") is not None: + self.unet_config.setdefault("sample_height", 96) + self.unet_config.setdefault("sample_width", 170) + self.unet_config.setdefault("sample_frames", 81) + out = model_base.CogVideoX(self, device=device) + return out + + def clip_target(self, state_dict={}): + return supported_models_base.ClipTarget(comfy.text_encoders.cogvideo.CogVideoXT5Tokenizer, comfy.text_encoders.sd3_clip.T5XXLModel) + +class CogVideoX_I2V(CogVideoX_T2V): + unet_config = { + "image_model": "cogvideox", + "in_channels": 32, + } + + def get_model(self, state_dict, prefix="", device=None): + if self.unet_config.get("patch_size_t") is not None: + self.unet_config.setdefault("sample_height", 96) + self.unet_config.setdefault("sample_width", 170) + self.unet_config.setdefault("sample_frames", 81) + out = model_base.CogVideoX(self, image_to_video=True, device=device) + return out + +models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, LongCatImage, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImagePixelSpace, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, WAN21_SCAIL, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, 
Anima, RT_DETR_v4, ErnieImage, SAM3, SAM31, CogVideoX_I2V, CogVideoX_T2V] models += [SVD_img2vid] diff --git a/comfy/text_encoders/cogvideo.py b/comfy/text_encoders/cogvideo.py new file mode 100644 index 000000000..f1e8e3f5d --- /dev/null +++ b/comfy/text_encoders/cogvideo.py @@ -0,0 +1,6 @@ +import comfy.text_encoders.sd3_clip + + +class CogVideoXT5Tokenizer(comfy.text_encoders.sd3_clip.T5XXLTokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, min_length=226) diff --git a/nodes.py b/nodes.py index e73a0712e..7aeb05b32 100644 --- a/nodes.py +++ b/nodes.py @@ -2463,7 +2463,7 @@ async def init_builtin_extra_nodes(): "nodes_curve.py", "nodes_rtdetr.py", "nodes_frame_interpolation.py", - "nodes_sam3.py" + "nodes_sam3.py", ] import_failed = [] From a164c82913d3e04d92d0f6630fc4c850ec184ef3 Mon Sep 17 00:00:00 2001 From: blepping <157360029+blepping@users.noreply.github.com> Date: Wed, 29 Apr 2026 17:37:30 -0600 Subject: [PATCH 017/102] Add high quality preview support for Flux2 latents (#13496) --- comfy/latent_formats.py | 1 + comfy/sd.py | 5 +- comfy/taesd/taesd.py | 102 ++++++++++++++++++++++++++++++---------- nodes.py | 53 +++++++-------------- 4 files changed, 100 insertions(+), 61 deletions(-) diff --git a/comfy/latent_formats.py b/comfy/latent_formats.py index 0f4059ebe..3dac5be18 100644 --- a/comfy/latent_formats.py +++ b/comfy/latent_formats.py @@ -224,6 +224,7 @@ class Flux2(LatentFormat): self.latent_rgb_factors_bias = [-0.0329, -0.0718, -0.0851] self.latent_rgb_factors_reshape = lambda t: t.reshape(t.shape[0], 32, 2, 2, t.shape[-2], t.shape[-1]).permute(0, 1, 4, 2, 5, 3).reshape(t.shape[0], 32, t.shape[-2] * 2, t.shape[-1] * 2) + self.taesd_decoder_name = "taef2_decoder" def process_in(self, latent): return latent diff --git a/comfy/sd.py b/comfy/sd.py index 9158317f1..ee66490f5 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -479,7 +479,10 @@ class VAE: encoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Encoder", 'params': encoder_config}, decoder_config={'target': "comfy.ldm.modules.temporal_ae.VideoDecoder", 'params': decoder_config}) elif "taesd_decoder.1.weight" in sd: - self.latent_channels = sd["taesd_decoder.1.weight"].shape[1] + if isinstance(metadata, dict) and "tae_latent_channels" in metadata: + self.latent_channels = metadata["tae_latent_channels"] + else: + self.latent_channels = sd["taesd_decoder.1.weight"].shape[1] self.first_stage_model = comfy.taesd.taesd.TAESD(latent_channels=self.latent_channels) elif "vquantizer.codebook.weight" in sd: #VQGan: stage a of stable cascade self.first_stage_model = StageA() diff --git a/comfy/taesd/taesd.py b/comfy/taesd/taesd.py index ce36f1a84..05d370209 100644 --- a/comfy/taesd/taesd.py +++ b/comfy/taesd/taesd.py @@ -17,32 +17,79 @@ class Clamp(nn.Module): return torch.tanh(x / 3) * 3 class Block(nn.Module): - def __init__(self, n_in, n_out): + def __init__(self, n_in: int, n_out: int, use_midblock_gn: bool = False): super().__init__() self.conv = nn.Sequential(conv(n_in, n_out), nn.ReLU(), conv(n_out, n_out), nn.ReLU(), conv(n_out, n_out)) self.skip = comfy.ops.disable_weight_init.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity() self.fuse = nn.ReLU() - def forward(self, x): + if not use_midblock_gn: + self.pool = None + return + n_gn = n_in * 4 + self.pool = nn.Sequential( + comfy.ops.disable_weight_init.Conv2d(n_in, n_gn, 1, bias=False), + 
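+                # 1x1 expand -> GroupNorm -> ReLU -> 1x1 project back; applied
+                # as a residual (x + pool(x)) before the usual conv/skip fuse.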
comfy.ops.disable_weight_init.GroupNorm(4, n_gn), + nn.ReLU(inplace=True), + comfy.ops.disable_weight_init.Conv2d(n_gn, n_in, 1, bias=False), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.pool is not None: + x = x + self.pool(x) return self.fuse(self.conv(x) + self.skip(x)) -def Encoder(latent_channels=4): - return nn.Sequential( - conv(3, 64), Block(64, 64), - conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64), - conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64), - conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64), - conv(64, latent_channels), - ) +class Encoder(nn.Sequential): + def __init__(self, latent_channels: int = 4, use_gn: bool = False): + super().__init__( + conv(3, 64), Block(64, 64), + conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64), + conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64), + conv(64, 64, stride=2, bias=False), Block(64, 64, use_gn), Block(64, 64, use_gn), Block(64, 64, use_gn), + conv(64, latent_channels), + ) +class Decoder(nn.Sequential): + def __init__(self, latent_channels: int = 4, use_gn: bool = False): + super().__init__( + Clamp(), conv(latent_channels, 64), nn.ReLU(), + Block(64, 64, use_gn), Block(64, 64, use_gn), Block(64, 64, use_gn), nn.Upsample(scale_factor=2), conv(64, 64, bias=False), + Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False), + Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False), + Block(64, 64), conv(64, 3), + ) + +class DecoderFlux2(Decoder): + def __init__(self, latent_channels: int = 128, use_gn: bool = True): + if latent_channels != 128 or not use_gn: + raise ValueError("Unexpected parameters for Flux2 TAE module") + super().__init__(latent_channels=32, use_gn=True) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + B, C, H, W = x.shape + x = ( + x + .reshape(B, 32, 2, 2, H, W) + .permute(0, 1, 4, 2, 5, 3) + .reshape(B, 32, H * 2, W * 2) + ) + return super().forward(x) + +class EncoderFlux2(Encoder): + def __init__(self, latent_channels: int = 128, use_gn: bool = True): + if latent_channels != 128 or not use_gn: + raise ValueError("Unexpected parameters for Flux2 TAE module") + super().__init__(latent_channels=32, use_gn=True) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + result = super().forward(x) + B, C, H, W = result.shape + return ( + result + .reshape(B, C, H // 2, 2, W // 2, 2) + .permute(0, 1, 3, 5, 2, 4) + .reshape(B, 128, H // 2, W // 2) + ) -def Decoder(latent_channels=4): - return nn.Sequential( - Clamp(), conv(latent_channels, 64), nn.ReLU(), - Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False), - Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False), - Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False), - Block(64, 64), conv(64, 3), - ) class TAESD(nn.Module): latent_magnitude = 3 @@ -51,8 +98,15 @@ class TAESD(nn.Module): def __init__(self, encoder_path=None, decoder_path=None, latent_channels=4): """Initialize pretrained TAESD on the given device from the given checkpoints.""" super().__init__() - self.taesd_encoder = Encoder(latent_channels=latent_channels) - self.taesd_decoder = Decoder(latent_channels=latent_channels) + if latent_channels == 128: + encoder_class = EncoderFlux2 + 
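+            # 128 latent channels currently signals the Flux2 TAE variant,
+            # which folds 128ch <-> 32ch via a 2x2 spatial (un)shuffle in
+            # EncoderFlux2/DecoderFlux2 above.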
decoder_class = DecoderFlux2 + else: + encoder_class = Encoder + decoder_class = Decoder + self.taesd_encoder = encoder_class(latent_channels=latent_channels) + self.taesd_decoder = decoder_class(latent_channels=latent_channels) + self.vae_scale = torch.nn.Parameter(torch.tensor(1.0)) self.vae_shift = torch.nn.Parameter(torch.tensor(0.0)) if encoder_path is not None: @@ -61,19 +115,19 @@ class TAESD(nn.Module): self.taesd_decoder.load_state_dict(comfy.utils.load_torch_file(decoder_path, safe_load=True)) @staticmethod - def scale_latents(x): + def scale_latents(x: torch.Tensor) -> torch.Tensor: """raw latents -> [0, 1]""" return x.div(2 * TAESD.latent_magnitude).add(TAESD.latent_shift).clamp(0, 1) @staticmethod - def unscale_latents(x): + def unscale_latents(x: torch.Tensor) -> torch.Tensor: """[0, 1] -> raw latents""" return x.sub(TAESD.latent_shift).mul(2 * TAESD.latent_magnitude) - def decode(self, x): + def decode(self, x: torch.Tensor) -> torch.Tensor: x_sample = self.taesd_decoder((x - self.vae_shift) * self.vae_scale) x_sample = x_sample.sub(0.5).mul(2) return x_sample - def encode(self, x): + def encode(self, x: torch.Tensor) -> torch.Tensor: return (self.taesd_encoder(x * 0.5 + 0.5) / self.vae_scale) + self.vae_shift diff --git a/nodes.py b/nodes.py index 7aeb05b32..99dc07227 100644 --- a/nodes.py +++ b/nodes.py @@ -728,50 +728,26 @@ class LoraLoaderModelOnly(LoraLoader): class VAELoader: video_taes = ["taehv", "lighttaew2_2", "lighttaew2_1", "lighttaehy1_5", "taeltx_2"] - image_taes = ["taesd", "taesdxl", "taesd3", "taef1"] + image_taes = ["taesd", "taesdxl", "taesd3", "taef1", "taef2"] + @staticmethod def vae_list(s): vaes = folder_paths.get_filename_list("vae") approx_vaes = folder_paths.get_filename_list("vae_approx") - sdxl_taesd_enc = False - sdxl_taesd_dec = False - sd1_taesd_enc = False - sd1_taesd_dec = False - sd3_taesd_enc = False - sd3_taesd_dec = False - f1_taesd_enc = False - f1_taesd_dec = False - + have_img_encoder, have_img_decoder = set(), set() for v in approx_vaes: - if v.startswith("taesd_decoder."): - sd1_taesd_dec = True - elif v.startswith("taesd_encoder."): - sd1_taesd_enc = True - elif v.startswith("taesdxl_decoder."): - sdxl_taesd_dec = True - elif v.startswith("taesdxl_encoder."): - sdxl_taesd_enc = True - elif v.startswith("taesd3_decoder."): - sd3_taesd_dec = True - elif v.startswith("taesd3_encoder."): - sd3_taesd_enc = True - elif v.startswith("taef1_encoder."): - f1_taesd_dec = True - elif v.startswith("taef1_decoder."): - f1_taesd_enc = True - else: + parts = v.split("_", 1) + if len(parts) != 2 or parts[0] not in s.image_taes: for tae in s.video_taes: if v.startswith(tae): vaes.append(v) - - if sd1_taesd_dec and sd1_taesd_enc: - vaes.append("taesd") - if sdxl_taesd_dec and sdxl_taesd_enc: - vaes.append("taesdxl") - if sd3_taesd_dec and sd3_taesd_enc: - vaes.append("taesd3") - if f1_taesd_dec and f1_taesd_enc: - vaes.append("taef1") + break + continue + if parts[1].startswith("encoder."): + have_img_encoder.add(parts[0]) + elif parts[1].startswith("decoder."): + have_img_decoder.add(parts[0]) + vaes += [k for k in have_img_decoder if k in have_img_encoder] vaes.append("pixel_space") return vaes @@ -827,6 +803,11 @@ class VAELoader: else: vae_path = folder_paths.get_full_path_or_raise("vae", vae_name) sd, metadata = comfy.utils.load_torch_file(vae_path, return_metadata=True) + if vae_name == "taef2": + if metadata is None: + metadata = {"tae_latent_channels": 128} + else: + metadata["tae_latent_channels"] = 128 vae = comfy.sd.VAE(sd=sd, 
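+        # For "taef2" the metadata injected above carries tae_latent_channels,
+        # which comfy.sd.VAE uses to build the 128-channel Flux2 TAE modules.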
metadata=metadata) vae.throw_exception_if_invalid() return (vae,) From d10fc2d6524043d2322968b518168910b1e9b530 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 29 Apr 2026 20:05:31 -0700 Subject: [PATCH 018/102] Lower peak mem usage for 8 bit formats with pyav. (#13626) --- comfy_api/latest/_input_impl/video_types.py | 21 ++++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/comfy_api/latest/_input_impl/video_types.py b/comfy_api/latest/_input_impl/video_types.py index 6ed41bba8..9a107fb76 100644 --- a/comfy_api/latest/_input_impl/video_types.py +++ b/comfy_api/latest/_input_impl/video_types.py @@ -251,6 +251,7 @@ class VideoFromFile(VideoInput): container.seek(start_pts, stream=video_stream) image_format = 'gbrpf32le' + process_image_format = lambda a: a audio = None streams = [video_stream] @@ -283,11 +284,25 @@ class VideoFromFile(VideoInput): break if not checked_alpha: + alpha_channel = False for comp in frame.format.components: if comp.is_alpha or frame.format.name == "pal8": alphas = [] - image_format = 'gbrapf32le' + alpha_channel = True break + if frame.format.name in ("yuvj420p", "rgb24", "rgba", "pal8"): + process_image_format = lambda a: a.float() / 255.0 + if alpha_channel: + image_format = 'rgba' + else: + image_format = 'rgb24' + else: + process_image_format = lambda a: a + if alpha_channel: + image_format = 'gbrapf32le' + else: + image_format = 'gbrpf32le' + checked_alpha = True img = frame.to_ndarray(format=image_format) # shape: (H, W, 4) @@ -323,9 +338,9 @@ class VideoFromFile(VideoInput): else: audio_frames.append(frame.to_ndarray()) - images = torch.stack(frames) if len(frames) > 0 else torch.zeros(0, 0, 0, 3) + images = process_image_format(torch.stack(frames)) if len(frames) > 0 else torch.zeros(0, 0, 0, 3) if alphas is not None: - alphas = torch.stack(alphas) if len(alphas) > 0 else torch.zeros(0, 0, 0, 1) + alphas = process_image_format(torch.stack(alphas)) if len(alphas) > 0 else torch.zeros(0, 0, 0, 1) # Get frame rate frame_rate = Fraction(video_stream.average_rate) if video_stream.average_rate else Fraction(1) From a7d82baa06e6b2e3d19c38c244118909fe270d49 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Wed, 29 Apr 2026 20:30:01 -0700 Subject: [PATCH 019/102] Fix SQLAlchemy version format in requirements.txt (#13547) Change SQLAlchemy>=2.0 to SQLAlchemy>=2.0.0 to satisfy the X.Y.Z version format expected by install_util.is_valid_version(). 
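For context, a minimal sketch of the kind of X.Y.Z check this refers to,
assuming a simple regex-based validator (the actual implementation of
install_util.is_valid_version may differ):

import re

def is_valid_version(version: str) -> bool:
    # Accept exactly MAJOR.MINOR.PATCH (e.g. "2.0.0"); reject "2.0".
    return re.fullmatch(r"\d+\.\d+\.\d+", version) is not None

assert is_valid_version("2.0.0")
assert not is_valid_version("2.0")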
--- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 12c5ff7a9..c3d51e2fa 100644 --- a/requirements.txt +++ b/requirements.txt @@ -19,7 +19,7 @@ scipy tqdm psutil alembic -SQLAlchemy>=2.0 +SQLAlchemy>=2.0.0 filelock av>=14.2.0 comfy-kitchen>=0.2.8 From 38ecad8f8af30965eb1017b0eb6a552c751b84a4 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Thu, 30 Apr 2026 11:09:33 +0300 Subject: [PATCH 020/102] feat(api-nodes): allow custom resolutions for GPTImage2 node (#13631) Signed-off-by: bigcat88 --- comfy_api_nodes/nodes_openai.py | 51 +++++++++++++++++++++++++++++---- 1 file changed, 46 insertions(+), 5 deletions(-) diff --git a/comfy_api_nodes/nodes_openai.py b/comfy_api_nodes/nodes_openai.py index bbb758068..843681817 100644 --- a/comfy_api_nodes/nodes_openai.py +++ b/comfy_api_nodes/nodes_openai.py @@ -415,8 +415,9 @@ class OpenAIGPTImage1(IO.ComfyNode): "1152x2048", "3840x2160", "2160x3840", + "Custom", ], - tooltip="Image size", + tooltip="Image size. Select 'Custom' to use the custom width and height (GPT Image 2 only).", optional=True, ), IO.Int.Input( @@ -445,6 +446,26 @@ class OpenAIGPTImage1(IO.ComfyNode): default="gpt-image-2", optional=True, ), + IO.Int.Input( + "custom_width", + default=1024, + min=1024, + max=3840, + step=16, + tooltip="Used only when `size` is 'Custom'. Must be a multiple of 16 (GPT Image 2 only).", + optional=True, + advanced=True, + ), + IO.Int.Input( + "custom_height", + default=1024, + min=1024, + max=3840, + step=16, + tooltip="Used only when `size` is 'Custom'. Must be a multiple of 16 (GPT Image 2 only).", + optional=True, + advanced=True, + ), ], outputs=[ IO.Image.Output(), @@ -471,9 +492,9 @@ class OpenAIGPTImage1(IO.ComfyNode): "high": [0.133, 0.22] }, "gpt-image-2": { - "low": [0.0048, 0.012], - "medium": [0.041, 0.112], - "high": [0.165, 0.43] + "low": [0.0048, 0.019], + "medium": [0.041, 0.168], + "high": [0.165, 0.67] } }; $range := $lookup($lookup($ranges, widgets.model), widgets.quality); @@ -503,6 +524,8 @@ class OpenAIGPTImage1(IO.ComfyNode): mask: Input.Image | None = None, n: int = 1, size: str = "1024x1024", + custom_width: int = 1024, + custom_height: int = 1024, model: str = "gpt-image-1", ) -> IO.NodeOutput: validate_string(prompt, strip_whitespace=False) @@ -510,7 +533,25 @@ class OpenAIGPTImage1(IO.ComfyNode): if mask is not None and image is None: raise ValueError("Cannot use a mask without an input image") - if model in ("gpt-image-1", "gpt-image-1.5"): + if size == "Custom": + if model != "gpt-image-2": + raise ValueError("Custom resolution is only supported by GPT Image 2 model") + if custom_width % 16 != 0 or custom_height % 16 != 0: + raise ValueError(f"Custom width and height must be multiples of 16, got {custom_width}x{custom_height}") + if max(custom_width, custom_height) > 3840: + raise ValueError(f"Custom resolution max edge must be <= 3840, got {custom_width}x{custom_height}") + ratio = max(custom_width, custom_height) / min(custom_width, custom_height) + if ratio > 3: + raise ValueError( + f"Custom resolution aspect ratio must not exceed 3:1, got {custom_width}x{custom_height}" + ) + total_pixels = custom_width * custom_height + if not 655_360 <= total_pixels <= 8_294_400: + raise ValueError( + f"Custom resolution total pixels must be between 655,360 and 8,294,400, got {total_pixels}" + ) + size = f"{custom_width}x{custom_height}" + elif model in ("gpt-image-1", "gpt-image-1.5"): if size not in 
("auto", "1024x1024", "1024x1536", "1536x1024"): raise ValueError(f"Resolution {size} is only supported by GPT Image 2 model") From b633244635e577e199944cd4f027df79afa16dbf Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Thu, 30 Apr 2026 21:49:08 +0300 Subject: [PATCH 021/102] [Partner Nodes] ByteDance: virtual portrait library for regular images (#13638) * feat(api-nodes-bytedance): use the virtual portrait library for regular images Signed-off-by: bigcat88 * fix: include shape in image dedup hash Signed-off-by: bigcat88 --------- Signed-off-by: bigcat88 --- comfy_api_nodes/apis/bytedance.py | 5 ++++ comfy_api_nodes/nodes_bytedance.py | 38 ++++++++++++++++++++++++++---- 2 files changed, 39 insertions(+), 4 deletions(-) diff --git a/comfy_api_nodes/apis/bytedance.py b/comfy_api_nodes/apis/bytedance.py index eafabbefe..c05bd6893 100644 --- a/comfy_api_nodes/apis/bytedance.py +++ b/comfy_api_nodes/apis/bytedance.py @@ -157,6 +157,11 @@ class SeedanceCreateAssetResponse(BaseModel): asset_id: str = Field(...) +class SeedanceVirtualLibraryCreateAssetRequest(BaseModel): + url: str = Field(..., description="Publicly accessible URL of the image asset to upload.") + hash: str = Field(..., description="Dedup key. Re-submitting the same hash returns the existing asset id.") + + # Dollars per 1K tokens, keyed by (model_id, has_video_input). SEEDANCE2_PRICE_PER_1K_TOKENS = { ("dreamina-seedance-2-0-260128", False): 0.007, diff --git a/comfy_api_nodes/nodes_bytedance.py b/comfy_api_nodes/nodes_bytedance.py index de192c5ac..fee0ab888 100644 --- a/comfy_api_nodes/nodes_bytedance.py +++ b/comfy_api_nodes/nodes_bytedance.py @@ -1,3 +1,4 @@ +import hashlib import logging import math import re @@ -20,6 +21,7 @@ from comfy_api_nodes.apis.bytedance import ( SeedanceCreateAssetResponse, SeedanceCreateVisualValidateSessionResponse, SeedanceGetVisualValidateSessionResponse, + SeedanceVirtualLibraryCreateAssetRequest, Seedream4Options, Seedream4TaskCreationRequest, TaskAudioContent, @@ -271,6 +273,30 @@ async def _wait_for_asset_active(cls: type[IO.ComfyNode], asset_id: str, group_i ) +async def _seedance_virtual_library_upload_image_asset( + cls: type[IO.ComfyNode], + image: torch.Tensor, + *, + wait_label: str = "Uploading image", +) -> str: + """Upload an image into the caller's per-customer Seedance virtual library.""" + public_url = await upload_image_to_comfyapi(cls, image, wait_label=wait_label) + normalized = image.detach().cpu().contiguous().to(torch.float32) + digest = hashlib.sha256() + digest.update(str(tuple(normalized.shape)).encode("utf-8")) + digest.update(b"\0") + digest.update(normalized.numpy().tobytes()) + image_hash = digest.hexdigest() + create_resp = await sync_op( + cls, + ApiEndpoint(path="/proxy/seedance/virtual-library/assets", method="POST"), + response_model=SeedanceCreateAssetResponse, + data=SeedanceVirtualLibraryCreateAssetRequest(url=public_url, hash=image_hash), + ) + await _wait_for_asset_active(cls, create_resp.asset_id, group_id="virtual-library") + return f"asset://{create_resp.asset_id}" + + def _seedance2_price_extractor(model_id: str, has_video_input: bool): """Returns a price_extractor closure for Seedance 2.0 poll_op.""" rate = SEEDANCE2_PRICE_PER_1K_TOKENS.get((model_id, has_video_input)) @@ -1507,7 +1533,9 @@ class ByteDance2FirstLastFrameNode(IO.ComfyNode): if first_frame_asset_id: first_frame_url = image_assets[first_frame_asset_id] else: - first_frame_url = await upload_image_to_comfyapi(cls, first_frame, 
wait_label="Uploading first frame.") + first_frame_url = await _seedance_virtual_library_upload_image_asset( + cls, first_frame, wait_label="Uploading first frame." + ) content: list[TaskTextContent | TaskImageContent] = [ TaskTextContent(text=model["prompt"]), @@ -1527,7 +1555,9 @@ class ByteDance2FirstLastFrameNode(IO.ComfyNode): content.append( TaskImageContent( image_url=TaskImageContentUrl( - url=await upload_image_to_comfyapi(cls, last_frame, wait_label="Uploading last frame.") + url=await _seedance_virtual_library_upload_image_asset( + cls, last_frame, wait_label="Uploading last frame." + ) ), role="last_frame", ), @@ -1805,9 +1835,9 @@ class ByteDance2ReferenceNode(IO.ComfyNode): content.append( TaskImageContent( image_url=TaskImageContentUrl( - url=await upload_image_to_comfyapi( + url=await _seedance_virtual_library_upload_image_asset( cls, - image=reference_images[key], + reference_images[key], wait_label=f"Uploading image {i}", ), ), From e6e0936128858608c5cc45585be3583176d748b2 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Thu, 30 Apr 2026 16:33:09 -0700 Subject: [PATCH 022/102] Load other jpeg formats without taking so much memory. (#13642) --- comfy_api/latest/_input_impl/video_types.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy_api/latest/_input_impl/video_types.py b/comfy_api/latest/_input_impl/video_types.py index 9a107fb76..942278d88 100644 --- a/comfy_api/latest/_input_impl/video_types.py +++ b/comfy_api/latest/_input_impl/video_types.py @@ -290,7 +290,7 @@ class VideoFromFile(VideoInput): alphas = [] alpha_channel = True break - if frame.format.name in ("yuvj420p", "rgb24", "rgba", "pal8"): + if frame.format.name in ("yuvj420p", "yuvj422p", "yuvj444p", "rgb24", "rgba", "pal8"): process_image_format = lambda a: a.float() / 255.0 if alpha_channel: image_format = 'rgba' From e9c311b2458a327585a7e387558377c8190eebb0 Mon Sep 17 00:00:00 2001 From: Rainer Date: Fri, 1 May 2026 02:33:41 +0300 Subject: [PATCH 023/102] OneTainer ERNIE LoRA support (#13640) --- comfy/lora.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/comfy/lora.py b/comfy/lora.py index 63ee85323..e4337c729 100644 --- a/comfy/lora.py +++ b/comfy/lora.py @@ -342,6 +342,12 @@ def model_lora_keys_unet(model, key_map={}): key_map["base_model.model.{}".format(key_lora)] = k # Official base model loras key_map["lycoris_{}".format(key_lora.replace(".", "_"))] = k # LyCORIS/LoKR format + if isinstance(model, comfy.model_base.ErnieImage): + for k in sdk: + if k.startswith("diffusion_model.") and k.endswith(".weight"): + key_lora = k[len("diffusion_model."):-len(".weight")] + key_map["transformer.{}".format(key_lora)] = k + return key_map From e8e8fee22476a926090df9f719acd0a553ff8165 Mon Sep 17 00:00:00 2001 From: "Daxiong (Lin)" Date: Fri, 1 May 2026 09:14:28 +0800 Subject: [PATCH 024/102] chore: update workflow templates to v0.9.65 (#13644) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index c3d51e2fa..cb85d970b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.42.15 -comfyui-workflow-templates==0.9.63 +comfyui-workflow-templates==0.9.65 comfyui-embedded-docs==0.4.4 torch torchsde From 97f58baaaf89e2232b735fab2a3f2d4e24d134c3 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Thu, 30 Apr 2026 18:49:31 -0700 Subject: [PATCH 025/102] Add alexisrolland and rattus128 as code owners (#13648) --- 
CODEOWNERS | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CODEOWNERS b/CODEOWNERS index 4d5448636..e693955a0 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -1,2 +1,2 @@ # Admins -* @comfyanonymous @kosinkadink @guill +* @comfyanonymous @kosinkadink @guill @alexisrolland @rattus128 From 96f1cee9f5304c1f4e3a176ed02a44cf0a0116ad Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Fri, 1 May 2026 09:15:11 +0300 Subject: [PATCH 026/102] chore(api-nodes): always display the custom width and height in GPTImage2 node (#13651) Signed-off-by: bigcat88 --- comfy_api_nodes/nodes_openai.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/comfy_api_nodes/nodes_openai.py b/comfy_api_nodes/nodes_openai.py index 843681817..21fe470ce 100644 --- a/comfy_api_nodes/nodes_openai.py +++ b/comfy_api_nodes/nodes_openai.py @@ -454,7 +454,6 @@ class OpenAIGPTImage1(IO.ComfyNode): step=16, tooltip="Used only when `size` is 'Custom'. Must be a multiple of 16 (GPT Image 2 only).", optional=True, - advanced=True, ), IO.Int.Input( "custom_height", @@ -464,7 +463,6 @@ class OpenAIGPTImage1(IO.ComfyNode): step=16, tooltip="Used only when `size` is 'Custom'. Must be a multiple of 16 (GPT Image 2 only).", optional=True, - advanced=True, ), ], outputs=[ From cf9cbec5960e38368393137419637f6b9ca7691b Mon Sep 17 00:00:00 2001 From: Talmaj Date: Fri, 1 May 2026 11:20:11 +0200 Subject: [PATCH 027/102] Reformat models variable into multiline array CORE-59 (#13513) Co-authored-by: Talmaj Marinc --- comfy/supported_models.py | 84 ++++++++++++++++++++++++++++++++++++++- 1 file changed, 82 insertions(+), 2 deletions(-) diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 92d0305c5..e6c17fb98 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -1879,6 +1879,86 @@ class CogVideoX_I2V(CogVideoX_T2V): out = model_base.CogVideoX(self, image_to_video=True, device=device) return out -models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, LongCatImage, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImagePixelSpace, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, WAN21_SCAIL, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima, RT_DETR_v4, ErnieImage, SAM3, SAM31, CogVideoX_I2V, CogVideoX_T2V] -models += [SVD_img2vid] +models = [ + LotusD, + Stable_Zero123, + SD15_instructpix2pix, + SD15, + SD20, + SD21UnclipL, + SD21UnclipH, + SDXL_instructpix2pix, + SDXLRefiner, + SDXL, + SSD1B, + KOALA_700M, + KOALA_1B, + Segmind_Vega, + SD_X4Upscaler, + Stable_Cascade_C, + Stable_Cascade_B, + SV3D_u, + SV3D_p, + SD3, + StableAudio, + AuraFlow, + PixArtAlpha, + PixArtSigma, + HunyuanDiT, + HunyuanDiT1, + FluxInpaint, + Flux, + LongCatImage, + FluxSchnell, + GenmoMochi, + LTXV, + LTXAV, + HunyuanVideo15_SR_Distilled, + HunyuanVideo15, + HunyuanImage21Refiner, + 
HunyuanImage21, + HunyuanVideoSkyreelsI2V, + HunyuanVideoI2V, + HunyuanVideo, + CosmosT2V, + CosmosI2V, + CosmosT2IPredict2, + CosmosI2VPredict2, + ZImagePixelSpace, + ZImage, + Lumina2, + WAN22_T2V, + WAN21_T2V, + WAN21_I2V, + WAN21_FunControl2V, + WAN21_Vace, + WAN21_Camera, + WAN22_Camera, + WAN22_S2V, + WAN21_HuMo, + WAN22_Animate, + WAN21_FlowRVS, + WAN21_SCAIL, + Hunyuan3Dv2mini, + Hunyuan3Dv2, + Hunyuan3Dv2_1, + HiDream, + Chroma, + ChromaRadiance, + ACEStep, + ACEStep15, + Omnigen2, + QwenImage, + Flux2, + Kandinsky5Image, + Kandinsky5, + Anima, + RT_DETR_v4, + ErnieImage, + SAM3, + SAM31, + CogVideoX_I2V, + CogVideoX_T2V, + SVD_img2vid, +] From fa7553138e3c75befe6aaf988048d4a0a95c1a32 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Fri, 1 May 2026 21:09:25 +0300 Subject: [PATCH 028/102] chore(api-nodes): remove Moonvalley API nodes (#13659) Signed-off-by: bigcat88 --- comfy_api_nodes/apis/moonvalley.py | 152 -------- comfy_api_nodes/nodes_moonvalley.py | 534 ---------------------------- 2 files changed, 686 deletions(-) delete mode 100644 comfy_api_nodes/apis/moonvalley.py delete mode 100644 comfy_api_nodes/nodes_moonvalley.py diff --git a/comfy_api_nodes/apis/moonvalley.py b/comfy_api_nodes/apis/moonvalley.py deleted file mode 100644 index 7ec7a4ade..000000000 --- a/comfy_api_nodes/apis/moonvalley.py +++ /dev/null @@ -1,152 +0,0 @@ -from enum import Enum -from typing import Optional, Dict, Any - -from pydantic import BaseModel, Field, StrictBytes - - -class MoonvalleyPromptResponse(BaseModel): - error: Optional[Dict[str, Any]] = None - frame_conditioning: Optional[Dict[str, Any]] = None - id: Optional[str] = None - inference_params: Optional[Dict[str, Any]] = None - meta: Optional[Dict[str, Any]] = None - model_params: Optional[Dict[str, Any]] = None - output_url: Optional[str] = None - prompt_text: Optional[str] = None - status: Optional[str] = None - - -class MoonvalleyTextToVideoInferenceParams(BaseModel): - add_quality_guidance: Optional[bool] = Field( - True, description='Whether to add quality guidance' - ) - caching_coefficient: Optional[float] = Field( - 0.3, description='Caching coefficient for optimization' - ) - caching_cooldown: Optional[int] = Field( - 3, description='Number of caching cooldown steps' - ) - caching_warmup: Optional[int] = Field( - 3, description='Number of caching warmup steps' - ) - clip_value: Optional[float] = Field( - 3, description='CLIP value for generation control' - ) - conditioning_frame_index: Optional[int] = Field( - 0, description='Index of the conditioning frame' - ) - cooldown_steps: Optional[int] = Field( - 75, description='Number of cooldown steps (calculated based on num_frames)' - ) - fps: Optional[int] = Field( - 24, description='Frames per second of the generated video' - ) - guidance_scale: Optional[float] = Field( - 10, description='Guidance scale for generation control' - ) - height: Optional[int] = Field( - 1080, description='Height of the generated video in pixels' - ) - negative_prompt: Optional[str] = Field(None, description='Negative prompt text') - num_frames: Optional[int] = Field(64, description='Number of frames to generate') - seed: Optional[int] = Field( - None, description='Random seed for generation (default: random)' - ) - shift_value: Optional[float] = Field( - 3, description='Shift value for generation control' - ) - steps: Optional[int] = Field(80, description='Number of denoising steps') - use_guidance_schedule: Optional[bool] = Field( - True, 
description='Whether to use guidance scheduling' - ) - use_negative_prompts: Optional[bool] = Field( - False, description='Whether to use negative prompts' - ) - use_timestep_transform: Optional[bool] = Field( - True, description='Whether to use timestep transformation' - ) - warmup_steps: Optional[int] = Field( - 0, description='Number of warmup steps (calculated based on num_frames)' - ) - width: Optional[int] = Field( - 1920, description='Width of the generated video in pixels' - ) - - -class MoonvalleyTextToVideoRequest(BaseModel): - image_url: Optional[str] = None - inference_params: Optional[MoonvalleyTextToVideoInferenceParams] = None - prompt_text: Optional[str] = None - webhook_url: Optional[str] = None - - -class MoonvalleyUploadFileRequest(BaseModel): - file: Optional[StrictBytes] = None - - -class MoonvalleyUploadFileResponse(BaseModel): - access_url: Optional[str] = None - - -class MoonvalleyVideoToVideoInferenceParams(BaseModel): - add_quality_guidance: Optional[bool] = Field( - True, description='Whether to add quality guidance' - ) - caching_coefficient: Optional[float] = Field( - 0.3, description='Caching coefficient for optimization' - ) - caching_cooldown: Optional[int] = Field( - 3, description='Number of caching cooldown steps' - ) - caching_warmup: Optional[int] = Field( - 3, description='Number of caching warmup steps' - ) - clip_value: Optional[float] = Field( - 3, description='CLIP value for generation control' - ) - conditioning_frame_index: Optional[int] = Field( - 0, description='Index of the conditioning frame' - ) - cooldown_steps: Optional[int] = Field( - 36, description='Number of cooldown steps (calculated based on num_frames)' - ) - guidance_scale: Optional[float] = Field( - 15, description='Guidance scale for generation control' - ) - negative_prompt: Optional[str] = Field(None, description='Negative prompt text') - seed: Optional[int] = Field( - None, description='Random seed for generation (default: random)' - ) - shift_value: Optional[float] = Field( - 3, description='Shift value for generation control' - ) - steps: Optional[int] = Field(80, description='Number of denoising steps') - use_guidance_schedule: Optional[bool] = Field( - True, description='Whether to use guidance scheduling' - ) - use_negative_prompts: Optional[bool] = Field( - False, description='Whether to use negative prompts' - ) - use_timestep_transform: Optional[bool] = Field( - True, description='Whether to use timestep transformation' - ) - warmup_steps: Optional[int] = Field( - 24, description='Number of warmup steps (calculated based on num_frames)' - ) - - -class ControlType(str, Enum): - motion_control = 'motion_control' - pose_control = 'pose_control' - - -class MoonvalleyVideoToVideoRequest(BaseModel): - control_type: ControlType = Field( - ..., description='Supported types for video control' - ) - inference_params: Optional[MoonvalleyVideoToVideoInferenceParams] = None - prompt_text: str = Field(..., description='Describes the video to generate') - video_url: str = Field(..., description='Url to control video') - webhook_url: Optional[str] = Field( - None, description='Optional webhook URL for notifications' - ) diff --git a/comfy_api_nodes/nodes_moonvalley.py b/comfy_api_nodes/nodes_moonvalley.py deleted file mode 100644 index 78a230529..000000000 --- a/comfy_api_nodes/nodes_moonvalley.py +++ /dev/null @@ -1,534 +0,0 @@ -import logging - -from typing_extensions import override - -from comfy_api.latest import IO, ComfyExtension, Input -from comfy_api_nodes.apis.moonvalley 
import ( - MoonvalleyPromptResponse, - MoonvalleyTextToVideoInferenceParams, - MoonvalleyTextToVideoRequest, - MoonvalleyVideoToVideoInferenceParams, - MoonvalleyVideoToVideoRequest, -) -from comfy_api_nodes.util import ( - ApiEndpoint, - download_url_to_video_output, - poll_op, - sync_op, - trim_video, - upload_images_to_comfyapi, - upload_video_to_comfyapi, - validate_container_format_is_mp4, - validate_image_dimensions, - validate_string, -) - -API_UPLOADS_ENDPOINT = "/proxy/moonvalley/uploads" -API_PROMPTS_ENDPOINT = "/proxy/moonvalley/prompts" -API_VIDEO2VIDEO_ENDPOINT = "/proxy/moonvalley/prompts/video-to-video" -API_TXT2VIDEO_ENDPOINT = "/proxy/moonvalley/prompts/text-to-video" -API_IMG2VIDEO_ENDPOINT = "/proxy/moonvalley/prompts/image-to-video" - -MIN_WIDTH = 300 -MIN_HEIGHT = 300 - -MAX_WIDTH = 10000 -MAX_HEIGHT = 10000 - -MIN_VID_WIDTH = 300 -MIN_VID_HEIGHT = 300 - -MAX_VID_WIDTH = 10000 -MAX_VID_HEIGHT = 10000 - -MAX_VIDEO_SIZE = 1024 * 1024 * 1024 # 1 GB max for in-memory video processing - -MOONVALLEY_MAREY_MAX_PROMPT_LENGTH = 5000 - - -def is_valid_task_creation_response(response: MoonvalleyPromptResponse) -> bool: - """Verifies that the initial response contains a task ID.""" - return bool(response.id) - - -def validate_task_creation_response(response) -> None: - if not is_valid_task_creation_response(response): - error_msg = f"Moonvalley Marey API: Initial request failed. Code: {response.code}, Message: {response.message}, Data: {response}" - logging.error(error_msg) - raise RuntimeError(error_msg) - - -def validate_video_to_video_input(video: Input.Video) -> Input.Video: - """ - Validates and processes video input for Moonvalley Video-to-Video generation. - - Args: - video: Input video to validate - - Returns: - Validated and potentially trimmed video - - Raises: - ValueError: If video doesn't meet requirements - MoonvalleyApiError: If video duration is too short - """ - width, height = _get_video_dimensions(video) - _validate_video_dimensions(width, height) - validate_container_format_is_mp4(video) - - return _validate_and_trim_duration(video) - - -def _get_video_dimensions(video: Input.Video) -> tuple[int, int]: - """Extracts video dimensions with error handling.""" - try: - return video.get_dimensions() - except Exception as e: - logging.error("Error getting dimensions of video: %s", e) - raise ValueError(f"Cannot get video dimensions: {e}") from e - - -def _validate_video_dimensions(width: int, height: int) -> None: - """Validates video dimensions meet Moonvalley V2V requirements.""" - supported_resolutions = { - (1920, 1080), - (1080, 1920), - (1152, 1152), - (1536, 1152), - (1152, 1536), - } - - if (width, height) not in supported_resolutions: - supported_list = ", ".join([f"{w}x{h}" for w, h in sorted(supported_resolutions)]) - raise ValueError(f"Resolution {width}x{height} not supported. 
Supported: {supported_list}") - - -def _validate_and_trim_duration(video: Input.Video) -> Input.Video: - """Validates video duration and trims to 5 seconds if needed.""" - duration = video.get_duration() - _validate_minimum_duration(duration) - return _trim_if_too_long(video, duration) - - -def _validate_minimum_duration(duration: float) -> None: - """Ensures video is at least 5 seconds long.""" - if duration < 5: - raise ValueError("Input video must be at least 5 seconds long.") - - -def _trim_if_too_long(video: Input.Video, duration: float) -> Input.Video: - """Trims video to 5 seconds if longer.""" - if duration > 5: - return trim_video(video, 5) - return video - - -def parse_width_height_from_res(resolution: str): - # Accepts a string like "16:9 (1920 x 1080)" and returns width, height as a dict - res_map = { - "16:9 (1920 x 1080)": {"width": 1920, "height": 1080}, - "9:16 (1080 x 1920)": {"width": 1080, "height": 1920}, - "1:1 (1152 x 1152)": {"width": 1152, "height": 1152}, - "4:3 (1536 x 1152)": {"width": 1536, "height": 1152}, - "3:4 (1152 x 1536)": {"width": 1152, "height": 1536}, - # "21:9 (2560 x 1080)": {"width": 2560, "height": 1080}, - } - return res_map.get(resolution, {"width": 1920, "height": 1080}) - - -def parse_control_parameter(value): - control_map = { - "Motion Transfer": "motion_control", - "Canny": "canny_control", - "Pose Transfer": "pose_control", - "Depth": "depth_control", - } - return control_map.get(value, control_map["Motion Transfer"]) - - -async def get_response(cls: type[IO.ComfyNode], task_id: str) -> MoonvalleyPromptResponse: - return await poll_op( - cls, - ApiEndpoint(path=f"{API_PROMPTS_ENDPOINT}/{task_id}"), - response_model=MoonvalleyPromptResponse, - status_extractor=lambda r: (r.status if r and r.status else None), - poll_interval=16.0, - max_poll_attempts=240, - ) - - -class MoonvalleyImg2VideoNode(IO.ComfyNode): - - @classmethod - def define_schema(cls) -> IO.Schema: - return IO.Schema( - node_id="MoonvalleyImg2VideoNode", - display_name="Moonvalley Marey Image to Video", - category="api node/video/Moonvalley Marey", - description="Moonvalley Marey Image to Video Node", - inputs=[ - IO.Image.Input( - "image", - tooltip="The reference image used to generate the video", - ), - IO.String.Input( - "prompt", - multiline=True, - ), - IO.String.Input( - "negative_prompt", - multiline=True, - default=" gopro, bright, contrast, static, overexposed, vignette, " - "artifacts, still, noise, texture, scanlines, videogame, 360 camera, VR, transition, " - "flare, saturation, distorted, warped, wide angle, saturated, vibrant, glowing, " - "cross dissolve, cheesy, ugly hands, mutated hands, mutant, disfigured, extra fingers, " - "blown out, horrible, blurry, worst quality, bad, dissolve, melt, fade in, fade out, " - "wobbly, weird, low quality, plastic, stock footage, video camera, boring", - tooltip="Negative prompt text", - ), - IO.Combo.Input( - "resolution", - options=[ - "16:9 (1920 x 1080)", - "9:16 (1080 x 1920)", - "1:1 (1152 x 1152)", - "4:3 (1536 x 1152)", - "3:4 (1152 x 1536)", - # "21:9 (2560 x 1080)", - ], - default="16:9 (1920 x 1080)", - tooltip="Resolution of the output video", - ), - IO.Float.Input( - "prompt_adherence", - default=4.5, - min=1.0, - max=20.0, - step=1.0, - tooltip="Guidance scale for generation control", - ), - IO.Int.Input( - "seed", - default=9, - min=0, - max=4294967295, - step=1, - display_mode=IO.NumberDisplay.number, - tooltip="Random seed value", - control_after_generate=True, - ), - IO.Int.Input( - "steps", - default=80, 
- min=75, # steps should be greater or equal to cooldown_steps(75) + warmup_steps(0) - max=100, - step=1, - tooltip="Number of denoising steps", - ), - ], - outputs=[IO.Video.Output()], - hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, - IO.Hidden.unique_id, - ], - is_api_node=True, - price_badge=IO.PriceBadge( - depends_on=IO.PriceBadgeDepends(), - expr="""{"type":"usd","usd": 1.5}""", - ), - ) - - @classmethod - async def execute( - cls, - image: Input.Image, - prompt: str, - negative_prompt: str, - resolution: str, - prompt_adherence: float, - seed: int, - steps: int, - ) -> IO.NodeOutput: - validate_image_dimensions(image, min_width=300, min_height=300, max_height=MAX_HEIGHT, max_width=MAX_WIDTH) - validate_string(prompt, min_length=1, max_length=MOONVALLEY_MAREY_MAX_PROMPT_LENGTH) - validate_string(negative_prompt, field_name="negative_prompt", max_length=MOONVALLEY_MAREY_MAX_PROMPT_LENGTH) - width_height = parse_width_height_from_res(resolution) - - inference_params = MoonvalleyTextToVideoInferenceParams( - negative_prompt=negative_prompt, - steps=steps, - seed=seed, - guidance_scale=prompt_adherence, - width=width_height["width"], - height=width_height["height"], - use_negative_prompts=True, - ) - - # Get MIME type from tensor - assuming PNG format for image tensors - mime_type = "image/png" - image_url = (await upload_images_to_comfyapi(cls, image, max_images=1, mime_type=mime_type))[0] - task_creation_response = await sync_op( - cls, - endpoint=ApiEndpoint(path=API_IMG2VIDEO_ENDPOINT, method="POST"), - response_model=MoonvalleyPromptResponse, - data=MoonvalleyTextToVideoRequest( - image_url=image_url, prompt_text=prompt, inference_params=inference_params - ), - ) - validate_task_creation_response(task_creation_response) - final_response = await get_response(cls, task_creation_response.id) - video = await download_url_to_video_output(final_response.output_url) - return IO.NodeOutput(video) - - -class MoonvalleyVideo2VideoNode(IO.ComfyNode): - - @classmethod - def define_schema(cls) -> IO.Schema: - return IO.Schema( - node_id="MoonvalleyVideo2VideoNode", - display_name="Moonvalley Marey Video to Video", - category="api node/video/Moonvalley Marey", - description="", - inputs=[ - IO.String.Input( - "prompt", - multiline=True, - tooltip="Describes the video to generate", - ), - IO.String.Input( - "negative_prompt", - multiline=True, - default=" gopro, bright, contrast, static, overexposed, vignette, " - "artifacts, still, noise, texture, scanlines, videogame, 360 camera, VR, transition, " - "flare, saturation, distorted, warped, wide angle, saturated, vibrant, glowing, " - "cross dissolve, cheesy, ugly hands, mutated hands, mutant, disfigured, extra fingers, " - "blown out, horrible, blurry, worst quality, bad, dissolve, melt, fade in, fade out, " - "wobbly, weird, low quality, plastic, stock footage, video camera, boring", - tooltip="Negative prompt text", - ), - IO.Int.Input( - "seed", - default=9, - min=0, - max=4294967295, - step=1, - display_mode=IO.NumberDisplay.number, - tooltip="Random seed value", - control_after_generate=False, - ), - IO.Video.Input( - "video", - tooltip="The reference video used to generate the output video. Must be at least 5 seconds long. " - "Videos longer than 5s will be automatically trimmed. 
Only MP4 format supported.", - ), - IO.Combo.Input( - "control_type", - options=["Motion Transfer", "Pose Transfer"], - default="Motion Transfer", - optional=True, - ), - IO.Int.Input( - "motion_intensity", - default=100, - min=0, - max=100, - step=1, - tooltip="Only used if control_type is 'Motion Transfer'", - optional=True, - ), - IO.Int.Input( - "steps", - default=60, - min=60, # steps should be greater or equal to cooldown_steps(36) + warmup_steps(24) - max=100, - step=1, - display_mode=IO.NumberDisplay.number, - tooltip="Number of inference steps", - ), - ], - outputs=[IO.Video.Output()], - hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, - IO.Hidden.unique_id, - ], - is_api_node=True, - price_badge=IO.PriceBadge( - depends_on=IO.PriceBadgeDepends(), - expr="""{"type":"usd","usd": 2.25}""", - ), - ) - - @classmethod - async def execute( - cls, - prompt: str, - negative_prompt: str, - seed: int, - video: Input.Video | None = None, - control_type: str = "Motion Transfer", - motion_intensity: int | None = 100, - steps=60, - prompt_adherence=4.5, - ) -> IO.NodeOutput: - validated_video = validate_video_to_video_input(video) - video_url = await upload_video_to_comfyapi(cls, validated_video) - validate_string(prompt, min_length=1, max_length=MOONVALLEY_MAREY_MAX_PROMPT_LENGTH) - validate_string(negative_prompt, field_name="negative_prompt", max_length=MOONVALLEY_MAREY_MAX_PROMPT_LENGTH) - - # Only include motion_intensity for Motion Transfer - control_params = {} - if control_type == "Motion Transfer" and motion_intensity is not None: - control_params["motion_intensity"] = motion_intensity - - inference_params = MoonvalleyVideoToVideoInferenceParams( - negative_prompt=negative_prompt, - seed=seed, - control_params=control_params, - steps=steps, - guidance_scale=prompt_adherence, - ) - - task_creation_response = await sync_op( - cls, - endpoint=ApiEndpoint(path=API_VIDEO2VIDEO_ENDPOINT, method="POST"), - response_model=MoonvalleyPromptResponse, - data=MoonvalleyVideoToVideoRequest( - control_type=parse_control_parameter(control_type), - video_url=video_url, - prompt_text=prompt, - inference_params=inference_params, - ), - ) - validate_task_creation_response(task_creation_response) - final_response = await get_response(cls, task_creation_response.id) - return IO.NodeOutput(await download_url_to_video_output(final_response.output_url)) - - -class MoonvalleyTxt2VideoNode(IO.ComfyNode): - - @classmethod - def define_schema(cls) -> IO.Schema: - return IO.Schema( - node_id="MoonvalleyTxt2VideoNode", - display_name="Moonvalley Marey Text to Video", - category="api node/video/Moonvalley Marey", - description="", - inputs=[ - IO.String.Input( - "prompt", - multiline=True, - ), - IO.String.Input( - "negative_prompt", - multiline=True, - default=" gopro, bright, contrast, static, overexposed, vignette, " - "artifacts, still, noise, texture, scanlines, videogame, 360 camera, VR, transition, " - "flare, saturation, distorted, warped, wide angle, saturated, vibrant, glowing, " - "cross dissolve, cheesy, ugly hands, mutated hands, mutant, disfigured, extra fingers, " - "blown out, horrible, blurry, worst quality, bad, dissolve, melt, fade in, fade out, " - "wobbly, weird, low quality, plastic, stock footage, video camera, boring", - tooltip="Negative prompt text", - ), - IO.Combo.Input( - "resolution", - options=[ - "16:9 (1920 x 1080)", - "9:16 (1080 x 1920)", - "1:1 (1152 x 1152)", - "4:3 (1536 x 1152)", - "3:4 (1152 x 1536)", - "21:9 (2560 x 1080)", - ], - default="16:9 (1920 x 
1080)", - tooltip="Resolution of the output video", - ), - IO.Float.Input( - "prompt_adherence", - default=4.0, - min=1.0, - max=20.0, - step=1.0, - tooltip="Guidance scale for generation control", - ), - IO.Int.Input( - "seed", - default=9, - min=0, - max=4294967295, - step=1, - display_mode=IO.NumberDisplay.number, - control_after_generate=True, - tooltip="Random seed value", - ), - IO.Int.Input( - "steps", - default=80, - min=75, # steps should be greater or equal to cooldown_steps(75) + warmup_steps(0) - max=100, - step=1, - tooltip="Inference steps", - ), - ], - outputs=[IO.Video.Output()], - hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, - IO.Hidden.unique_id, - ], - is_api_node=True, - price_badge=IO.PriceBadge( - depends_on=IO.PriceBadgeDepends(), - expr="""{"type":"usd","usd": 1.5}""", - ), - ) - - @classmethod - async def execute( - cls, - prompt: str, - negative_prompt: str, - resolution: str, - prompt_adherence: float, - seed: int, - steps: int, - ) -> IO.NodeOutput: - validate_string(prompt, min_length=1, max_length=MOONVALLEY_MAREY_MAX_PROMPT_LENGTH) - validate_string(negative_prompt, field_name="negative_prompt", max_length=MOONVALLEY_MAREY_MAX_PROMPT_LENGTH) - width_height = parse_width_height_from_res(resolution) - - inference_params = MoonvalleyTextToVideoInferenceParams( - negative_prompt=negative_prompt, - steps=steps, - seed=seed, - guidance_scale=prompt_adherence, - num_frames=128, - width=width_height["width"], - height=width_height["height"], - ) - - task_creation_response = await sync_op( - cls, - endpoint=ApiEndpoint(path=API_TXT2VIDEO_ENDPOINT, method="POST"), - response_model=MoonvalleyPromptResponse, - data=MoonvalleyTextToVideoRequest(prompt_text=prompt, inference_params=inference_params), - ) - validate_task_creation_response(task_creation_response) - final_response = await get_response(cls, task_creation_response.id) - return IO.NodeOutput(await download_url_to_video_output(final_response.output_url)) - - -class MoonvalleyExtension(ComfyExtension): - @override - async def get_node_list(self) -> list[type[IO.ComfyNode]]: - return [ - MoonvalleyImg2VideoNode, - MoonvalleyTxt2VideoNode, - MoonvalleyVideo2VideoNode, - ] - - -async def comfy_entrypoint() -> MoonvalleyExtension: - return MoonvalleyExtension() From 10b45a71cdac2898693bb42aa0a21e2cb23a2daa Mon Sep 17 00:00:00 2001 From: "Daxiong (Lin)" Date: Sat, 2 May 2026 03:11:30 +0800 Subject: [PATCH 029/102] chore: update workflow templates to v0.9.66 (#13662) Co-authored-by: Jedrzej Kosinski --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index cb85d970b..932034076 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.42.15 -comfyui-workflow-templates==0.9.65 +comfyui-workflow-templates==0.9.66 comfyui-embedded-docs==0.4.4 torch torchsde From cf758bd2566a04a156496fa77ec2c7fa76ff8873 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Fri, 1 May 2026 22:48:41 +0300 Subject: [PATCH 030/102] chore(api-nodes): increase default timeout for partner API node tasks (#13663) Signed-off-by: bigcat88 Co-authored-by: Jedrzej Kosinski --- comfy_api_nodes/nodes_bytedance.py | 3 --- comfy_api_nodes/nodes_hitpaw.py | 2 -- comfy_api_nodes/nodes_kling.py | 3 --- comfy_api_nodes/nodes_magnific.py | 5 ----- comfy_api_nodes/nodes_topaz.py | 1 - comfy_api_nodes/nodes_vidu.py | 3 +-- comfy_api_nodes/nodes_wan.py | 1 - 
comfy_api_nodes/nodes_wavespeed.py | 2 -- comfy_api_nodes/util/client.py | 4 ++-- 9 files changed, 3 insertions(+), 21 deletions(-) diff --git a/comfy_api_nodes/nodes_bytedance.py b/comfy_api_nodes/nodes_bytedance.py index fee0ab888..2f241a775 100644 --- a/comfy_api_nodes/nodes_bytedance.py +++ b/comfy_api_nodes/nodes_bytedance.py @@ -1403,7 +1403,6 @@ class ByteDance2TextToVideoNode(IO.ComfyNode): status_extractor=lambda r: r.status, price_extractor=_seedance2_price_extractor(model_id, has_video_input=False), poll_interval=9, - max_poll_attempts=180, ) return IO.NodeOutput(await download_url_to_video_output(response.content.video_url)) @@ -1585,7 +1584,6 @@ class ByteDance2FirstLastFrameNode(IO.ComfyNode): status_extractor=lambda r: r.status, price_extractor=_seedance2_price_extractor(model_id, has_video_input=False), poll_interval=9, - max_poll_attempts=180, ) return IO.NodeOutput(await download_url_to_video_output(response.content.video_url)) @@ -1907,7 +1905,6 @@ class ByteDance2ReferenceNode(IO.ComfyNode): status_extractor=lambda r: r.status, price_extractor=_seedance2_price_extractor(model_id, has_video_input=has_video_input), poll_interval=9, - max_poll_attempts=180, ) return IO.NodeOutput(await download_url_to_video_output(response.content.video_url)) diff --git a/comfy_api_nodes/nodes_hitpaw.py b/comfy_api_nodes/nodes_hitpaw.py index 488080a74..bca5170e4 100644 --- a/comfy_api_nodes/nodes_hitpaw.py +++ b/comfy_api_nodes/nodes_hitpaw.py @@ -178,7 +178,6 @@ class HitPawGeneralImageEnhance(IO.ComfyNode): status_extractor=lambda x: x.data.status, price_extractor=lambda x: request_price, poll_interval=10.0, - max_poll_attempts=480, ) return IO.NodeOutput(await download_url_to_image_tensor(final_response.data.res_url)) @@ -324,7 +323,6 @@ class HitPawVideoEnhance(IO.ComfyNode): status_extractor=lambda x: x.data.status, price_extractor=lambda x: request_price, poll_interval=10.0, - max_poll_attempts=320, ) return IO.NodeOutput(await download_url_to_video_output(final_response.data.res_url)) diff --git a/comfy_api_nodes/nodes_kling.py b/comfy_api_nodes/nodes_kling.py index 709b3726c..efd58fac3 100644 --- a/comfy_api_nodes/nodes_kling.py +++ b/comfy_api_nodes/nodes_kling.py @@ -276,7 +276,6 @@ async def finish_omni_video_task(cls: type[IO.ComfyNode], response: TaskStatusRe cls, ApiEndpoint(path=f"/proxy/kling/v1/videos/omni-video/{response.data.task_id}"), response_model=TaskStatusResponse, - max_poll_attempts=280, status_extractor=lambda r: (r.data.task_status if r.data else None), ) return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url)) @@ -3062,7 +3061,6 @@ class KlingVideoNode(IO.ComfyNode): cls, ApiEndpoint(path=poll_path), response_model=TaskStatusResponse, - max_poll_attempts=280, status_extractor=lambda r: (r.data.task_status if r.data else None), ) return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url)) @@ -3188,7 +3186,6 @@ class KlingFirstLastFrameNode(IO.ComfyNode): cls, ApiEndpoint(path=f"/proxy/kling/v1/videos/image2video/{response.data.task_id}"), response_model=TaskStatusResponse, - max_poll_attempts=280, status_extractor=lambda r: (r.data.task_status if r.data else None), ) return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url)) diff --git a/comfy_api_nodes/nodes_magnific.py b/comfy_api_nodes/nodes_magnific.py index 0f53208d4..38b881fea 100644 --- a/comfy_api_nodes/nodes_magnific.py +++ b/comfy_api_nodes/nodes_magnific.py @@ 
-230,7 +230,6 @@ class MagnificImageUpscalerCreativeNode(IO.ComfyNode): status_extractor=lambda x: x.status, price_extractor=lambda _: price_usd, poll_interval=10.0, - max_poll_attempts=480, ) return IO.NodeOutput(await download_url_to_image_tensor(final_response.generated[0])) @@ -391,7 +390,6 @@ class MagnificImageUpscalerPreciseV2Node(IO.ComfyNode): status_extractor=lambda x: x.status, price_extractor=lambda _: price_usd, poll_interval=10.0, - max_poll_attempts=480, ) return IO.NodeOutput(await download_url_to_image_tensor(final_response.generated[0])) @@ -541,7 +539,6 @@ class MagnificImageStyleTransferNode(IO.ComfyNode): response_model=TaskResponse, status_extractor=lambda x: x.status, poll_interval=10.0, - max_poll_attempts=480, ) return IO.NodeOutput(await download_url_to_image_tensor(final_response.generated[0])) @@ -782,7 +779,6 @@ class MagnificImageRelightNode(IO.ComfyNode): response_model=TaskResponse, status_extractor=lambda x: x.status, poll_interval=10.0, - max_poll_attempts=480, ) return IO.NodeOutput(await download_url_to_image_tensor(final_response.generated[0])) @@ -924,7 +920,6 @@ class MagnificImageSkinEnhancerNode(IO.ComfyNode): response_model=TaskResponse, status_extractor=lambda x: x.status, poll_interval=10.0, - max_poll_attempts=480, ) return IO.NodeOutput(await download_url_to_image_tensor(final_response.generated[0])) diff --git a/comfy_api_nodes/nodes_topaz.py b/comfy_api_nodes/nodes_topaz.py index b18b31af1..fe3666ec9 100644 --- a/comfy_api_nodes/nodes_topaz.py +++ b/comfy_api_nodes/nodes_topaz.py @@ -453,7 +453,6 @@ class TopazVideoEnhance(IO.ComfyNode): progress_extractor=lambda x: getattr(x, "progress", 0), price_extractor=lambda x: (x.estimates.cost[0] * 0.08 if x.estimates and x.estimates.cost[0] else None), poll_interval=10.0, - max_poll_attempts=320, ) return IO.NodeOutput(await download_url_to_video_output(final_response.download.url)) diff --git a/comfy_api_nodes/nodes_vidu.py b/comfy_api_nodes/nodes_vidu.py index f04407eb5..8d90cefeb 100644 --- a/comfy_api_nodes/nodes_vidu.py +++ b/comfy_api_nodes/nodes_vidu.py @@ -38,7 +38,7 @@ async def execute_task( cls: type[IO.ComfyNode], vidu_endpoint: str, payload: TaskCreationRequest | TaskExtendCreationRequest | TaskMultiFrameCreationRequest, - max_poll_attempts: int = 320, + max_poll_attempts: int = 480, ) -> list[TaskResult]: task_creation_response = await sync_op( cls, @@ -1097,7 +1097,6 @@ class ViduExtendVideoNode(IO.ComfyNode): video_url=await upload_video_to_comfyapi(cls, video, wait_label="Uploading video"), images=[image_url] if image_url else None, ), - max_poll_attempts=480, ) return IO.NodeOutput(await download_url_to_video_output(results[0].url)) diff --git a/comfy_api_nodes/nodes_wan.py b/comfy_api_nodes/nodes_wan.py index 7d7466fb6..68061bb5c 100644 --- a/comfy_api_nodes/nodes_wan.py +++ b/comfy_api_nodes/nodes_wan.py @@ -818,7 +818,6 @@ class WanReferenceVideoApi(IO.ComfyNode): response_model=VideoTaskStatusResponse, status_extractor=lambda x: x.output.task_status, poll_interval=6, - max_poll_attempts=280, ) return IO.NodeOutput(await download_url_to_video_output(response.output.video_url)) diff --git a/comfy_api_nodes/nodes_wavespeed.py b/comfy_api_nodes/nodes_wavespeed.py index c59fafd3b..65e45f60a 100644 --- a/comfy_api_nodes/nodes_wavespeed.py +++ b/comfy_api_nodes/nodes_wavespeed.py @@ -84,7 +84,6 @@ class WavespeedFlashVSRNode(IO.ComfyNode): response_model=TaskResultResponse, status_extractor=lambda x: "failed" if x.data is None else x.data.status, poll_interval=10.0, - 
max_poll_attempts=480, ) if final_response.code != 200: raise ValueError( @@ -156,7 +155,6 @@ class WavespeedImageUpscaleNode(IO.ComfyNode): response_model=TaskResultResponse, status_extractor=lambda x: "failed" if x.data is None else x.data.status, poll_interval=10.0, - max_poll_attempts=480, ) if final_response.code != 200: raise ValueError( diff --git a/comfy_api_nodes/util/client.py b/comfy_api_nodes/util/client.py index b0cf97ae4..a0b8d35e1 100644 --- a/comfy_api_nodes/util/client.py +++ b/comfy_api_nodes/util/client.py @@ -148,7 +148,7 @@ async def poll_op( queued_statuses: list[str | int] | None = None, data: BaseModel | None = None, poll_interval: float = 5.0, - max_poll_attempts: int = 160, + max_poll_attempts: int = 480, timeout_per_poll: float = 120.0, max_retries_per_poll: int = 10, retry_delay_per_poll: float = 1.0, @@ -254,7 +254,7 @@ async def poll_op_raw( queued_statuses: list[str | int] | None = None, data: dict[str, Any] | BaseModel | None = None, poll_interval: float = 5.0, - max_poll_attempts: int = 160, + max_poll_attempts: int = 480, timeout_per_poll: float = 120.0, max_retries_per_poll: int = 10, retry_delay_per_poll: float = 1.0, From 63103d519ec960701438e8617452ef64b02609c7 Mon Sep 17 00:00:00 2001 From: Simon Lui <502929+simonlui@users.noreply.github.com> Date: Fri, 1 May 2026 14:16:41 -0700 Subject: [PATCH 031/102] Remove IPEX and clean up checks and add missing synchronize during empty cache. (#13653) --- comfy/cli_args.py | 1 - comfy/model_management.py | 18 +++--------------- 2 files changed, 3 insertions(+), 16 deletions(-) diff --git a/comfy/cli_args.py b/comfy/cli_args.py index dbaadf723..cef1a5e6b 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -90,7 +90,6 @@ parser.add_argument("--force-channels-last", action="store_true", help="Force ch parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE", const=-1, help="Use torch-directml.") parser.add_argument("--oneapi-device-selector", type=str, default=None, metavar="SELECTOR_STRING", help="Sets the oneAPI device(s) this instance will use.") -parser.add_argument("--disable-ipex-optimize", action="store_true", help="Disables ipex.optimize default when loading models with Intel's Extension for Pytorch.") parser.add_argument("--supports-fp8-compute", action="store_true", help="ComfyUI will act like if the device supports fp8 compute.") class LatentPreviewMethod(enum.Enum): diff --git a/comfy/model_management.py b/comfy/model_management.py index 95af40012..f86e2a4aa 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -112,10 +112,6 @@ if args.directml is not None: # torch_directml.disable_tiled_resources(True) lowvram_available = False #TODO: need to find a way to get free memory in directml before this can be enabled by default. 
-try: - import intel_extension_for_pytorch as ipex # noqa: F401 -except: - pass try: _ = torch.xpu.device_count() @@ -583,9 +579,6 @@ class LoadedModel: real_model = self.model.model - if is_intel_xpu() and not args.disable_ipex_optimize and 'ipex' in globals() and real_model is not None: - with torch.no_grad(): - real_model = ipex.optimize(real_model.eval(), inplace=True, graph_mode=True, concat_linear=True) self.real_model = weakref.ref(real_model) self.model_finalizer = weakref.finalize(real_model, cleanup_models) @@ -1581,10 +1574,7 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma return False if is_intel_xpu(): - if torch_version_numeric < (2, 3): - return True - else: - return torch.xpu.get_device_properties(device).has_fp16 + return torch.xpu.get_device_properties(device).has_fp16 if is_ascend_npu(): return True @@ -1650,10 +1640,7 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma return False if is_intel_xpu(): - if torch_version_numeric < (2, 3): - return True - else: - return torch.xpu.is_bf16_supported() + return torch.xpu.is_bf16_supported() if is_ascend_npu(): return True @@ -1784,6 +1771,7 @@ def soft_empty_cache(force=False): if cpu_state == CPUState.MPS: torch.mps.empty_cache() elif is_intel_xpu(): + torch.xpu.synchronize() torch.xpu.empty_cache() elif is_ascend_npu(): torch.npu.empty_cache() From b5921c8ac2d3cd1171bb33245f4343b1471224ff Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jukka=20Sepp=C3=A4nen?= <40791699+kijai@users.noreply.github.com> Date: Sat, 2 May 2026 00:17:25 +0300 Subject: [PATCH 032/102] SDPose: resize fix (#13656) --- comfy_extras/nodes_sdpose.py | 38 ++++++++++++++++-------------------- 1 file changed, 17 insertions(+), 21 deletions(-) diff --git a/comfy_extras/nodes_sdpose.py b/comfy_extras/nodes_sdpose.py index 7d54967d5..96b6821bd 100644 --- a/comfy_extras/nodes_sdpose.py +++ b/comfy_extras/nodes_sdpose.py @@ -459,27 +459,23 @@ class SDPoseKeypointExtractor(io.ComfyNode): total_images = image.shape[0] captured_feat = None - model_h = int(head.heatmap_size[0]) * 4 # e.g. 192 * 4 = 768 - model_w = int(head.heatmap_size[1]) * 4 # e.g. 256 * 4 = 1024 + model_w = int(head.heatmap_size[0]) * 4 # 192 * 4 = 768 + model_h = int(head.heatmap_size[1]) * 4 # 256 * 4 = 1024 def _resize_to_model(imgs): - """Aspect-preserving resize + zero-pad BHWC images to (model_h, model_w). 
Returns (resized_bhwc, scale, pad_top, pad_left).""" + """Stretch BHWC images to (model_h, model_w), model expects no aspect preservation.""" h, w = imgs.shape[-3], imgs.shape[-2] - scale = min(model_h / h, model_w / w) - sh, sw = int(round(h * scale)), int(round(w * scale)) - pt, pl = (model_h - sh) // 2, (model_w - sw) // 2 + method = "area" if (model_h <= h and model_w <= w) else "bilinear" chw = imgs.permute(0, 3, 1, 2).float() - scaled = comfy.utils.common_upscale(chw, sw, sh, upscale_method="bilinear", crop="disabled") - padded = torch.zeros(scaled.shape[0], scaled.shape[1], model_h, model_w, dtype=scaled.dtype, device=scaled.device) - padded[:, :, pt:pt + sh, pl:pl + sw] = scaled - return padded.permute(0, 2, 3, 1), scale, pt, pl + scaled = comfy.utils.common_upscale(chw, model_w, model_h, upscale_method=method, crop="disabled") + return scaled.permute(0, 2, 3, 1), model_w / w, model_h / h - def _remap_keypoints(kp, scale, pad_top, pad_left, offset_x=0, offset_y=0): + def _remap_keypoints(kp, scale_x, scale_y, offset_x=0, offset_y=0): """Remap keypoints from model space back to original image space.""" kp = kp.copy() if isinstance(kp, np.ndarray) else np.array(kp, dtype=np.float32) invalid = kp[..., 0] < 0 - kp[..., 0] = (kp[..., 0] - pad_left) / scale + offset_x - kp[..., 1] = (kp[..., 1] - pad_top) / scale + offset_y + kp[..., 0] = kp[..., 0] / scale_x + offset_x + kp[..., 1] = kp[..., 1] / scale_y + offset_y kp[invalid] = -1 return kp @@ -529,18 +525,18 @@ class SDPoseKeypointExtractor(io.ComfyNode): continue crop = img[:, y1:y2, x1:x2, :] # (1, crop_h, crop_w, C) - crop_resized, scale, pad_top, pad_left = _resize_to_model(crop) + crop_resized, sx, sy = _resize_to_model(crop) latent_crop = vae.encode(crop_resized) kp_batch, sc_batch = _run_on_latent(latent_crop) - kp = _remap_keypoints(kp_batch[0], scale, pad_top, pad_left, x1, y1) + kp = _remap_keypoints(kp_batch[0], sx, sy, x1, y1) img_keypoints.append(kp) img_scores.append(sc_batch[0]) else: - img_resized, scale, pad_top, pad_left = _resize_to_model(img) + img_resized, sx, sy = _resize_to_model(img) latent_img = vae.encode(img_resized) kp_batch, sc_batch = _run_on_latent(latent_img) - img_keypoints.append(_remap_keypoints(kp_batch[0], scale, pad_top, pad_left)) + img_keypoints.append(_remap_keypoints(kp_batch[0], sx, sy)) img_scores.append(sc_batch[0]) all_keypoints.append(img_keypoints) @@ -549,12 +545,12 @@ class SDPoseKeypointExtractor(io.ComfyNode): else: # full-image mode, batched for batch_start in tqdm(range(0, total_images, batch_size), desc="Extracting keypoints"): - batch_resized, scale, pad_top, pad_left = _resize_to_model(image[batch_start:batch_start + batch_size]) + batch_resized, sx, sy = _resize_to_model(image[batch_start:batch_start + batch_size]) latent_batch = vae.encode(batch_resized) kp_batch, sc_batch = _run_on_latent(latent_batch) for kp, sc in zip(kp_batch, sc_batch): - all_keypoints.append([_remap_keypoints(kp, scale, pad_top, pad_left)]) + all_keypoints.append([_remap_keypoints(kp, sx, sy)]) all_scores.append([sc]) pbar.update(len(kp_batch)) @@ -727,13 +723,13 @@ class CropByBBoxes(io.ComfyNode): scale = min(output_width / crop_w, output_height / crop_h) scaled_w = int(round(crop_w * scale)) scaled_h = int(round(crop_h * scale)) - scaled = comfy.utils.common_upscale(crop_chw, scaled_w, scaled_h, upscale_method="bilinear", crop="disabled") + scaled = comfy.utils.common_upscale(crop_chw, scaled_w, scaled_h, upscale_method="area", crop="disabled") pad_left = (output_width - scaled_w) // 2 pad_top = 
(output_height - scaled_h) // 2 resized = torch.zeros(1, num_ch, output_height, output_width, dtype=image.dtype, device=image.device) resized[:, :, pad_top:pad_top + scaled_h, pad_left:pad_left + scaled_w] = scaled else: # "stretch" - resized = comfy.utils.common_upscale(crop_chw, output_width, output_height, upscale_method="bilinear", crop="disabled") + resized = comfy.utils.common_upscale(crop_chw, output_width, output_height, upscale_method="area", crop="disabled") crops.append(resized) if not crops: From 0230e0e7cc389979e509cd6237a7b9244798e69c Mon Sep 17 00:00:00 2001 From: Alexis Rolland Date: Sat, 2 May 2026 06:37:18 +0800 Subject: [PATCH 033/102] Adding kijai (#13664) Co-authored-by: Jedrzej Kosinski --- CODEOWNERS | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CODEOWNERS b/CODEOWNERS index e693955a0..946dbf946 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -1,2 +1,2 @@ # Admins -* @comfyanonymous @kosinkadink @guill @alexisrolland @rattus128 +* @comfyanonymous @kosinkadink @guill @alexisrolland @rattus128 @kijai From 67f6cb35273d00278d2b1ef2a8c3efe21238f22d Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Fri, 1 May 2026 17:19:32 -0700 Subject: [PATCH 034/102] List all the portable downloads in the README section. (#13666) --- README.md | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index f05311421..3b5114633 100644 --- a/README.md +++ b/README.md @@ -193,13 +193,15 @@ If you have trouble extracting it, right click the file -> properties -> unblock The portable above currently comes with python 3.13 and pytorch cuda 13.0. Update your Nvidia drivers if it doesn't start. -#### Alternative Downloads: +#### All Official Portable Downloads: [Portable for AMD GPUs](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_amd.7z) -[Experimental portable for Intel GPUs](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_intel.7z) +[Portable for Intel GPUs](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_intel.7z) -[Portable with pytorch cuda 12.6 and python 3.12](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_nvidia_cu126.7z) (Supports Nvidia 10 series and older GPUs). +[Portable for Nvidia GPUs](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_nvidia.7z) (supports 20 series and above). + +[Portable for Nvidia GPUs with pytorch cuda 12.6 and python 3.12](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_nvidia_cu126.7z) (Supports Nvidia 10 series and older GPUs). #### How do I share models between another UI and ComfyUI? From 3e3ed8cc2aaa142711e89e1e799956e1e57af62f Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Fri, 1 May 2026 17:19:46 -0700 Subject: [PATCH 035/102] Add script in AMD portable to launch with dynamic vram. 
(#13667) --- ...ble_smart_memory.bat => run_amd_gpu_enable_dynamic_vram.bat} | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) rename .ci/windows_amd_base_files/{run_amd_gpu_disable_smart_memory.bat => run_amd_gpu_enable_dynamic_vram.bat} (66%) diff --git a/.ci/windows_amd_base_files/run_amd_gpu_disable_smart_memory.bat b/.ci/windows_amd_base_files/run_amd_gpu_enable_dynamic_vram.bat similarity index 66% rename from .ci/windows_amd_base_files/run_amd_gpu_disable_smart_memory.bat rename to .ci/windows_amd_base_files/run_amd_gpu_enable_dynamic_vram.bat index cece0aeb2..94ad31942 100755 --- a/.ci/windows_amd_base_files/run_amd_gpu_disable_smart_memory.bat +++ b/.ci/windows_amd_base_files/run_amd_gpu_enable_dynamic_vram.bat @@ -1,2 +1,2 @@ -.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build --disable-smart-memory +.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build --enable-dynamic-vram pause From b783782d5d742a7bc38dd0b661e030813bc50839a Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Sun, 3 May 2026 09:23:24 +1000 Subject: [PATCH 036/102] Implement block prefetch + Lora Async load + and adopt in LTX (Speedup!) (CORE-111) (#13618) * mm: Use Aimdo raw allocator for cast buffers PyTorch manages allocation of growing buffers on streams poorly. PyTorch has no Windows support for the expandable segments allocator (which is the right tool for this job), while also segmenting the memory by stream such that it cannot be generally re-used. So kick the problem to aimdo, which can just grow a virtual region that is freed per stream. * plan * ops: move cpu handler up to the caller * ops: split up prefetch from weight prep block prefetching API Split up the casting and weight formatting/lora stuff in prep for arbitrary prefetch support. * ops: implement block prefetching API allow a model to construct a prefetch list and operate it for increased async offload. * ltxv2: Implement block prefetching * Implement lora async offload Implement async offload of loras.
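The LTX adoption in the av_model.py hunk below shows the intended driving pattern for the new API: build one prefetch queue over the block list per forward pass, pop before running each block so the next block's weights are staged on the offload stream, then pop once more after the loop to release the final entry. A condensed sketch of that pattern follows; the comfy.model_prefetch calls match the new module in this patch, while the surrounding model method is invented for illustration:

    import comfy.model_prefetch

    def _run_blocks(self, x, transformer_options):
        # make_prefetch_queue returns None when prefetching is unavailable
        # (CPU device, no extra streams, or a non-dynamic patcher); the pop
        # helper treats a None queue as a no-op.
        queue = comfy.model_prefetch.make_prefetch_queue(
            list(self.transformer_blocks), x.device, transformer_options)
        for block in self.transformer_blocks:
            # Wait on this block's prefetched weights, then start staging
            # the next block on the offload stream.
            comfy.model_prefetch.prefetch_queue_pop(queue, x.device, block)
            x = block(x)
        # Final pop (no next block) cleans up the last prefetched entry.
        comfy.model_prefetch.prefetch_queue_pop(queue, x.device, None)
        return x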
--- comfy/ldm/lightricks/av_model.py | 5 + comfy/lora.py | 15 +++ comfy/model_base.py | 5 + comfy/model_management.py | 22 +++- comfy/model_patcher.py | 13 ++- comfy/model_prefetch.py | 65 +++++++++++ comfy/ops.py | 181 ++++++++++++++++++++++--------- execution.py | 2 + 8 files changed, 251 insertions(+), 57 deletions(-) create mode 100644 comfy/model_prefetch.py diff --git a/comfy/ldm/lightricks/av_model.py b/comfy/ldm/lightricks/av_model.py index 6f2ba41ef..3fb87b4a3 100644 --- a/comfy/ldm/lightricks/av_model.py +++ b/comfy/ldm/lightricks/av_model.py @@ -16,6 +16,7 @@ from comfy.ldm.lightricks.model import ( from comfy.ldm.lightricks.symmetric_patchifier import AudioPatchifier from comfy.ldm.lightricks.embeddings_connector import Embeddings1DConnector import comfy.ldm.common_dit +import comfy.model_prefetch class CompressedTimestep: """Store video timestep embeddings in compressed form using per-frame indexing.""" @@ -907,9 +908,11 @@ class LTXAVModel(LTXVModel): """Process transformer blocks for LTXAV.""" patches_replace = transformer_options.get("patches_replace", {}) blocks_replace = patches_replace.get("dit", {}) + prefetch_queue = comfy.model_prefetch.make_prefetch_queue(list(self.transformer_blocks), vx.device, transformer_options) # Process transformer blocks for i, block in enumerate(self.transformer_blocks): + comfy.model_prefetch.prefetch_queue_pop(prefetch_queue, vx.device, block) if ("double_block", i) in blocks_replace: def block_wrap(args): @@ -982,6 +985,8 @@ class LTXAVModel(LTXVModel): a_prompt_timestep=a_prompt_timestep, ) + comfy.model_prefetch.prefetch_queue_pop(prefetch_queue, vx.device, None) + return [vx, ax] def _process_output(self, x, embedded_timestep, keyframe_idxs, **kwargs): diff --git a/comfy/lora.py b/comfy/lora.py index e4337c729..db8f16bcb 100644 --- a/comfy/lora.py +++ b/comfy/lora.py @@ -17,6 +17,7 @@ """ from __future__ import annotations +import comfy.memory_management import comfy.utils import comfy.model_management import comfy.model_base @@ -473,3 +474,17 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32, ori weight = old_weight return weight + +def prefetch_prepared_value(value, allocate_buffer, stream): + if isinstance(value, torch.Tensor): + dest = allocate_buffer(comfy.memory_management.vram_aligned_size(value)) + comfy.model_management.cast_to_gathered([value], dest, non_blocking=True, stream=stream) + return comfy.memory_management.interpret_gathered_like([value], dest)[0] + elif isinstance(value, weight_adapter.WeightAdapterBase): + return type(value)(value.loaded_keys, prefetch_prepared_value(value.weights, allocate_buffer, stream)) + elif isinstance(value, tuple): + return tuple(prefetch_prepared_value(item, allocate_buffer, stream) for item in value) + elif isinstance(value, list): + return [prefetch_prepared_value(item, allocate_buffer, stream) for item in value] + + return value diff --git a/comfy/model_base.py b/comfy/model_base.py index 50dab5782..b61a2aa09 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -214,6 +214,11 @@ class BaseModel(torch.nn.Module): if "latent_shapes" in extra_conds: xc = utils.unpack_latents(xc, extra_conds.pop("latent_shapes")) + transformer_options = transformer_options.copy() + transformer_options["prefetch_dynamic_vbars"] = ( + self.current_patcher is not None and self.current_patcher.is_dynamic() + ) + model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds) if len(model_output) > 1 and 
not torch.is_tensor(model_output): model_output, _ = utils.pack_latents(model_output) diff --git a/comfy/model_management.py b/comfy/model_management.py index f86e2a4aa..02ad66656 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -31,6 +31,7 @@ from contextlib import nullcontext import comfy.memory_management import comfy.utils import comfy.quant_ops +import comfy_aimdo.vram_buffer class VRAMState(Enum): DISABLED = 0 #No vram present: no need to move models to vram @@ -1175,6 +1176,10 @@ stream_counters = {} STREAM_CAST_BUFFERS = {} LARGEST_CASTED_WEIGHT = (None, 0) +STREAM_AIMDO_CAST_BUFFERS = {} +LARGEST_AIMDO_CASTED_WEIGHT = (None, 0) + +DEFAULT_AIMDO_CAST_BUFFER_RESERVATION_SIZE = 16 * 1024 ** 3 def get_cast_buffer(offload_stream, device, size, ref): global LARGEST_CASTED_WEIGHT @@ -1208,13 +1213,26 @@ def get_cast_buffer(offload_stream, device, size, ref): return cast_buffer +def get_aimdo_cast_buffer(offload_stream, device): + cast_buffer = STREAM_AIMDO_CAST_BUFFERS.get(offload_stream, None) + if cast_buffer is None: + cast_buffer = comfy_aimdo.vram_buffer.VRAMBuffer(DEFAULT_AIMDO_CAST_BUFFER_RESERVATION_SIZE, device.index) + STREAM_AIMDO_CAST_BUFFERS[offload_stream] = cast_buffer + + return cast_buffer def reset_cast_buffers(): global LARGEST_CASTED_WEIGHT + global LARGEST_AIMDO_CASTED_WEIGHT + LARGEST_CASTED_WEIGHT = (None, 0) - for offload_stream in STREAM_CAST_BUFFERS: - offload_stream.synchronize() + LARGEST_AIMDO_CASTED_WEIGHT = (None, 0) + for offload_stream in set(STREAM_CAST_BUFFERS) | set(STREAM_AIMDO_CAST_BUFFERS): + if offload_stream is not None: + offload_stream.synchronize() synchronize() + STREAM_CAST_BUFFERS.clear() + STREAM_AIMDO_CAST_BUFFERS.clear() soft_empty_cache() def get_offload_stream(device): diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index e259aed63..7d2d6883f 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -121,9 +121,20 @@ class LowVramPatch: self.patches = patches self.convert_func = convert_func # TODO: remove self.set_func = set_func + self.prepared_patches = None + + def prepare(self, allocate_buffer, stream): + self.prepared_patches = [ + (patch[0], comfy.lora.prefetch_prepared_value(patch[1], allocate_buffer, stream), patch[2], patch[3], patch[4]) + for patch in self.patches[self.key] + ] + + def clear_prepared(self): + self.prepared_patches = None def __call__(self, weight): - return comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=weight.dtype) + patches = self.prepared_patches if self.prepared_patches is not None else self.patches[self.key] + return comfy.lora.calculate_weight(patches, weight, self.key, intermediate_dtype=weight.dtype) LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR = 2 diff --git a/comfy/model_prefetch.py b/comfy/model_prefetch.py new file mode 100644 index 000000000..0ad35deb5 --- /dev/null +++ b/comfy/model_prefetch.py @@ -0,0 +1,65 @@ +import comfy_aimdo.model_vbar +import comfy.model_management +import comfy.ops + +PREFETCH_QUEUES = [] + +def cleanup_prefetched_modules(comfy_modules): + for s in comfy_modules: + prefetch = getattr(s, "_prefetch", None) + if prefetch is None: + continue + for param_key in ("weight", "bias"): + lowvram_fn = getattr(s, param_key + "_lowvram_function", None) + if lowvram_fn is not None: + lowvram_fn.clear_prepared() + if prefetch["signature"] is not None: + comfy_aimdo.model_vbar.vbar_unpin(s._v) + delattr(s, "_prefetch") + +def cleanup_prefetch_queues(): + global PREFETCH_QUEUES + + for queue in 
PREFETCH_QUEUES: + for entry in queue: + if entry is None or not isinstance(entry, tuple): + continue + _, prefetch_state = entry + comfy_modules = prefetch_state[1] + if comfy_modules is not None: + cleanup_prefetched_modules(comfy_modules) + PREFETCH_QUEUES = [] + +def prefetch_queue_pop(queue, device, module): + if queue is None: + return + + consumed = queue.pop(0) + if consumed is not None: + offload_stream, prefetch_state = consumed + offload_stream.wait_stream(comfy.model_management.current_stream(device)) + _, comfy_modules = prefetch_state + if comfy_modules is not None: + cleanup_prefetched_modules(comfy_modules) + + prefetch = queue[0] + if prefetch is not None: + comfy_modules = [] + for s in prefetch.modules(): + if hasattr(s, "_v"): + comfy_modules.append(s) + + offload_stream = comfy.ops.cast_modules_with_vbar(comfy_modules, None, device, None, True) + comfy.model_management.sync_stream(device, offload_stream) + queue[0] = (offload_stream, (prefetch, comfy_modules)) + +def make_prefetch_queue(queue, device, transformer_options): + if (not transformer_options.get("prefetch_dynamic_vbars", False) + or comfy.model_management.NUM_STREAMS == 0 + or comfy.model_management.is_device_cpu(device) + or not comfy.model_management.device_supports_non_blocking(device)): + return None + + queue = [None] + queue + [None] + PREFETCH_QUEUES.append(queue) + return queue diff --git a/comfy/ops.py b/comfy/ops.py index 050f7cda0..96db1411c 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -86,38 +86,61 @@ def materialize_meta_param(s, param_keys): setattr(s, param_key, torch.nn.Parameter(torch.zeros(param.shape, dtype=param.dtype), requires_grad=param.requires_grad)) -def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requant): - #vbar doesn't support CPU weights, but some custom nodes have weird paths - #that might switch the layer to the CPU and expect it to work. We have to take - #a clone conservatively as we are mmapped and some SFT files are packed misaligned - #If you are a custom node author reading this, please move your layer to the GPU - #or declare your ModelPatcher as CPU in the first place. 
- if comfy.model_management.is_device_cpu(device): - materialize_meta_param(s, ["weight", "bias"]) - weight = s.weight.to(dtype=dtype, copy=True) - if isinstance(weight, QuantizedTensor): - weight = weight.dequantize() - bias = None - if s.bias is not None: - bias = s.bias.to(dtype=bias_dtype, copy=True) - return weight, bias, (None, None, None) - +# FIXME: add n=1 cache hit fast path +def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blocking): offload_stream = None - xfer_dest = None + cast_buffer = None + cast_buffer_offset = 0 + + def ensure_offload_stream(module, required_size, check_largest): + nonlocal offload_stream + nonlocal cast_buffer + + if offload_stream is None: + offload_stream = comfy.model_management.get_offload_stream(device) + if offload_stream is None or not check_largest or len(comfy_modules) != 1: + return + + current_size = 0 if cast_buffer is None else cast_buffer.size() + if current_size < required_size and module is comfy.model_management.LARGEST_AIMDO_CASTED_WEIGHT[0]: + offload_stream = comfy.model_management.get_offload_stream(device) + cast_buffer = None + if required_size > comfy.model_management.LARGEST_AIMDO_CASTED_WEIGHT[1]: + comfy.model_management.LARGEST_AIMDO_CASTED_WEIGHT = (module, required_size) + + def get_cast_buffer(buffer_size): + nonlocal offload_stream + nonlocal cast_buffer + nonlocal cast_buffer_offset + + if buffer_size == 0: + return None + + if offload_stream is None: + return torch.empty((buffer_size,), dtype=torch.uint8, device=device) + + cast_buffer = comfy.model_management.get_aimdo_cast_buffer(offload_stream, device) + buffer = comfy_aimdo.torch.aimdo_to_tensor(cast_buffer.get(buffer_size, cast_buffer_offset), device) + cast_buffer_offset += buffer_size + return buffer + + for s in comfy_modules: + signature = comfy_aimdo.model_vbar.vbar_fault(s._v) + resident = comfy_aimdo.model_vbar.vbar_signature_compare(signature, s._v_signature) + prefetch = { + "signature": signature, + "resident": resident, + } - signature = comfy_aimdo.model_vbar.vbar_fault(s._v) - resident = comfy_aimdo.model_vbar.vbar_signature_compare(signature, s._v_signature) - if signature is not None: if resident: - weight = s._v_weight - bias = s._v_bias - else: - xfer_dest = comfy_aimdo.torch.aimdo_to_tensor(s._v, device) + s._prefetch = prefetch + continue - if not resident: materialize_meta_param(s, ["weight", "bias"]) + xfer_dest = comfy_aimdo.torch.aimdo_to_tensor(s._v, device) if signature is not None else None cast_geometry = comfy.memory_management.tensors_to_geometries([ s.weight, s.bias ]) cast_dest = None + needs_cast = False xfer_source = [ s.weight, s.bias ] @@ -129,22 +152,15 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu if data is None: continue if data.dtype != geometry.dtype: + needs_cast = True cast_dest = xfer_dest - if cast_dest is None: - cast_dest = torch.empty((comfy.memory_management.vram_aligned_size(cast_geometry),), dtype=torch.uint8, device=device) xfer_dest = None break dest_size = comfy.memory_management.vram_aligned_size(xfer_source) - offload_stream = comfy.model_management.get_offload_stream(device) - if xfer_dest is None and offload_stream is not None: - xfer_dest = comfy.model_management.get_cast_buffer(offload_stream, device, dest_size, s) - if xfer_dest is None: - offload_stream = comfy.model_management.get_offload_stream(device) - xfer_dest = comfy.model_management.get_cast_buffer(offload_stream, device, dest_size, s) + ensure_offload_stream(s, dest_size if xfer_dest 
is None else 0, True) if xfer_dest is None: - xfer_dest = torch.empty((dest_size,), dtype=torch.uint8, device=device) - offload_stream = None + xfer_dest = get_cast_buffer(dest_size) if signature is None and pin is None: comfy.pinned_memory.pin_memory(s) @@ -157,27 +173,54 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu xfer_source = [ pin ] #send it over comfy.model_management.cast_to_gathered(xfer_source, xfer_dest, non_blocking=non_blocking, stream=offload_stream) - comfy.model_management.sync_stream(device, offload_stream) - if cast_dest is not None: + for param_key in ("weight", "bias"): + lowvram_fn = getattr(s, param_key + "_lowvram_function", None) + if lowvram_fn is not None: + ensure_offload_stream(s, cast_buffer_offset, False) + lowvram_fn.prepare(lambda size: get_cast_buffer(size), offload_stream) + + prefetch["xfer_dest"] = xfer_dest + prefetch["cast_dest"] = cast_dest + prefetch["cast_geometry"] = cast_geometry + prefetch["needs_cast"] = needs_cast + s._prefetch = prefetch + + return offload_stream + + +def resolve_cast_module_with_vbar(s, dtype, device, bias_dtype, compute_dtype, want_requant): + + prefetch = getattr(s, "_prefetch", None) + + if prefetch["resident"]: + weight = s._v_weight + bias = s._v_bias + else: + xfer_dest = prefetch["xfer_dest"] + if prefetch["needs_cast"]: + cast_dest = prefetch["cast_dest"] if prefetch["cast_dest"] is not None else torch.empty((comfy.memory_management.vram_aligned_size(prefetch["cast_geometry"]),), dtype=torch.uint8, device=device) for pre_cast, post_cast in zip(comfy.memory_management.interpret_gathered_like([s.weight, s.bias ], xfer_dest), - comfy.memory_management.interpret_gathered_like(cast_geometry, cast_dest)): + comfy.memory_management.interpret_gathered_like(prefetch["cast_geometry"], cast_dest)): if post_cast is not None: post_cast.copy_(pre_cast) xfer_dest = cast_dest - params = comfy.memory_management.interpret_gathered_like(cast_geometry, xfer_dest) + params = comfy.memory_management.interpret_gathered_like(prefetch["cast_geometry"], xfer_dest) weight = params[0] bias = params[1] - if signature is not None: + if prefetch["signature"] is not None: s._v_weight = weight s._v_bias = bias - s._v_signature=signature + s._v_signature = prefetch["signature"] def post_cast(s, param_key, x, dtype, resident, update_weight): lowvram_fn = getattr(s, param_key + "_lowvram_function", None) fns = getattr(s, param_key + "_function", []) + if x is None: + return None + orig = x def to_dequant(tensor, dtype): @@ -205,14 +248,12 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu x = f(x) return x - update_weight = signature is not None + update_weight = prefetch["signature"] is not None + weight = post_cast(s, "weight", weight, dtype, prefetch["resident"], update_weight) + if bias is not None: + bias = post_cast(s, "bias", bias, bias_dtype, prefetch["resident"], update_weight) - weight = post_cast(s, "weight", weight, dtype, resident, update_weight) - if s.bias is not None: - bias = post_cast(s, "bias", bias, bias_dtype, resident, update_weight) - - #FIXME: weird offload return protocol - return weight, bias, (offload_stream, device if signature is not None else None, None) + return weight, bias def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, offloadable=False, compute_dtype=None, want_requant=False): @@ -230,10 +271,46 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of if device is None: device = input.device 
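+ # Return protocol (illustrative): with offloadable=True the caller receives
+ # (weight, bias, offload_info), where offload_info is the triple later handed
+ # to uncast_bias_weight(); legacy callers get the bare (weight, bias) pair.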
+ def format_return(result, offloadable): + weight, bias, offload_stream = result + return (weight, bias, offload_stream) if offloadable else (weight, bias) + non_blocking = comfy.model_management.device_supports_non_blocking(device) if hasattr(s, "_v"): - return cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requant) + + #vbar doesn't support CPU weights, but some custom nodes have weird paths + #that might switch the layer to the CPU and expect it to work. We have to take + #a clone conservatively as we are mmapped and some SFT files are packed misaligned + #If you are a custom node author reading this, please move your layer to the GPU + #or declare your ModelPatcher as CPU in the first place. + if comfy.model_management.is_device_cpu(device): + materialize_meta_param(s, ["weight", "bias"]) + weight = s.weight.to(dtype=dtype, copy=True) + if isinstance(weight, QuantizedTensor): + weight = weight.dequantize() + bias = s.bias.to(dtype=bias_dtype, copy=True) if s.bias is not None else None + return format_return((weight, bias, (None, None, None)), offloadable) + + prefetched = hasattr(s, "_prefetch") + offload_stream = None + offload_device = None + if not prefetched: + offload_stream = cast_modules_with_vbar([s], dtype, device, bias_dtype, non_blocking) + comfy.model_management.sync_stream(device, offload_stream) + + weight, bias = resolve_cast_module_with_vbar(s, dtype, device, bias_dtype, compute_dtype, want_requant) + + if not prefetched: + if getattr(s, "_prefetch")["signature"] is not None: + offload_device = device + for param_key in ("weight", "bias"): + lowvram_fn = getattr(s, param_key + "_lowvram_function", None) + if lowvram_fn is not None: + lowvram_fn.clear_prepared() + delattr(s, "_prefetch") + return format_return((weight, bias, (offload_stream, offload_device, None)), offloadable) + if offloadable and (device != s.weight.device or (s.bias is not None and device != s.bias.device)): @@ -280,11 +357,7 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of for f in s.weight_function: weight = f(weight) - if offloadable: - return weight, bias, (offload_stream, weight_a, bias_a) - else: - #Legacy function signature - return weight, bias + return format_return((weight, bias, (offload_stream, weight_a, bias_a)), offloadable) def uncast_bias_weight(s, weight, bias, offload_stream): diff --git a/execution.py b/execution.py index 5a6d3404c..654db8426 100644 --- a/execution.py +++ b/execution.py @@ -15,6 +15,7 @@ import torch from comfy.cli_args import args import comfy.memory_management import comfy.model_management +import comfy.model_prefetch import comfy_aimdo.model_vbar from latent_preview import set_preview_method @@ -537,6 +538,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, if args.verbose == "DEBUG": comfy_aimdo.control.analyze() comfy.model_management.reset_cast_buffers() + comfy.model_prefetch.cleanup_prefetch_queues() comfy_aimdo.model_vbar.vbars_reset_watermark_limits() if has_pending_tasks: From ef6722f6be7bf073d225d21da47354905a6abd2b Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sat, 2 May 2026 17:34:27 -0700 Subject: [PATCH 037/102] Some cleanups to the load image node. 
(#13677) --- nodes.py | 29 ++++++++++------------------- 1 file changed, 10 insertions(+), 19 deletions(-) diff --git a/nodes.py b/nodes.py index 99dc07227..710cccffe 100644 --- a/nodes.py +++ b/nodes.py @@ -1694,26 +1694,27 @@ class LoadImage: RETURN_TYPES = ("IMAGE", "MASK") FUNCTION = "load_image" + def load_image(self, image): image_path = folder_paths.get_annotated_filepath(image) + dtype = comfy.model_management.intermediate_dtype() + device = comfy.model_management.intermediate_device() + components = InputImpl.VideoFromFile(image_path).get_components() if components.images.shape[0] > 0: - return (components.images, 1.0 - components.alpha[..., -1] if components.alpha is not None else torch.zeros((components.images.shape[0], 64, 64), dtype=torch.float32, device="cpu")) + return (components.images.to(device=device, dtype=dtype), (1.0 - components.alpha[..., -1]).to(device=device, dtype=dtype) if components.alpha is not None else torch.zeros((components.images.shape[0], 64, 64), dtype=dtype, device=device)) + # This code is left here to handle animated webp which pyav does not support loading img = node_helpers.pillow(Image.open, image_path) output_images = [] output_masks = [] w, h = None, None - dtype = comfy.model_management.intermediate_dtype() - for i in ImageSequence.Iterator(img): i = node_helpers.pillow(ImageOps.exif_transpose, i) - if i.mode == 'I': - i = i.point(lambda i: i * (1 / 255)) image = i.convert("RGB") if len(output_images) == 0: @@ -1728,25 +1729,15 @@ class LoadImage: if 'A' in i.getbands(): mask = np.array(i.getchannel('A')).astype(np.float32) / 255.0 mask = 1. - torch.from_numpy(mask) - elif i.mode == 'P' and 'transparency' in i.info: - mask = np.array(i.convert('RGBA').getchannel('A')).astype(np.float32) / 255.0 - mask = 1. 
- torch.from_numpy(mask) else: - mask = torch.zeros((64,64), dtype=torch.float32, device="cpu") + mask = torch.zeros((64, 64), dtype=torch.float32, device="cpu") output_images.append(image.to(dtype=dtype)) output_masks.append(mask.unsqueeze(0).to(dtype=dtype)) - if img.format == "MPO": - break # ignore all frames except the first one for MPO format + output_image = torch.cat(output_images, dim=0) + output_mask = torch.cat(output_masks, dim=0) - if len(output_images) > 1: - output_image = torch.cat(output_images, dim=0) - output_mask = torch.cat(output_masks, dim=0) - else: - output_image = output_images[0] - output_mask = output_masks[0] - - return (output_image, output_mask) + return (output_image.to(device=device, dtype=dtype), output_mask.to(device=device, dtype=dtype)) @classmethod def IS_CHANGED(s, image): From 1d23a875ed0d4644538265635c1259be08a3370e Mon Sep 17 00:00:00 2001 From: "Daxiong (Lin)" Date: Sun, 3 May 2026 10:06:55 +0800 Subject: [PATCH 038/102] chore: update workflow templates to v0.9.68 (#13678) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 932034076..32826e25a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.42.15 -comfyui-workflow-templates==0.9.66 +comfyui-workflow-templates==0.9.68 comfyui-embedded-docs==0.4.4 torch torchsde From f756d801a1e5fbafe81cdfdf8a1c0aadf54c9bea Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Sun, 3 May 2026 05:29:00 +0300 Subject: [PATCH 039/102] [Partner Nodes] Topaz Astra 2 model (#13672) * feat(api-nodes): add Topaz Astra 2 model Signed-off-by: bigcat88 * feat(api-nodes): make Astra 2 the default Topaz upscaler model Reorder UPSCALER_MODELS_MAP and the upscaler_model dynamic combo so "Astra 2" appears first, surfacing it as the default selection. --------- Signed-off-by: bigcat88 Co-authored-by: Marwan Mostafa --- comfy_api_nodes/apis/topaz.py | 9 +- comfy_api_nodes/nodes_topaz.py | 361 ++++++++++++++++++++++++++++++++- 2 files changed, 365 insertions(+), 5 deletions(-) diff --git a/comfy_api_nodes/apis/topaz.py b/comfy_api_nodes/apis/topaz.py index a9e6235a7..f91980e3d 100644 --- a/comfy_api_nodes/apis/topaz.py +++ b/comfy_api_nodes/apis/topaz.py @@ -1,4 +1,4 @@ -from typing import Optional, Union +from typing import Optional from pydantic import BaseModel, Field @@ -72,8 +72,11 @@ class VideoEnhancementFilter(BaseModel): grain: Optional[float] = Field(None, description="Grain after AI model processing") grainSize: Optional[float] = Field(None, description="Size of generated grain") recoverOriginalDetailValue: Optional[float] = Field(None, description="Source details into the output video") - creativity: Optional[str] = Field(None, description="Creativity level(high, low) for slc-1 only") + creativity: float | str | None = Field(None, description="slc-1/slp-2.5: enum (low/middle/high). ast-2: decimal 0.0-1.0.") isOptimizedMode: Optional[bool] = Field(None, description="Set to true for Starlight Creative (slc-1) only") + prompt: str | None = Field(None, description="Descriptive scene prompt (ast-2 only)") + sharp: float | None = Field(None, description="ast-2 pre-enhance sharpness") + realism: float | None = Field(None, description="ast-2 realism control") class OutputInformationVideo(BaseModel): @@ -90,7 +93,7 @@ class Overrides(BaseModel): class CreateVideoRequest(BaseModel): source: CreateVideoRequestSource = Field(...) 
- filters: list[Union[VideoFrameInterpolationFilter, VideoEnhancementFilter]] = Field(...) + filters: list[VideoFrameInterpolationFilter | VideoEnhancementFilter] = Field(...) output: OutputInformationVideo = Field(...) overrides: Overrides = Field(Overrides(isPaidDiffusion=True)) diff --git a/comfy_api_nodes/nodes_topaz.py b/comfy_api_nodes/nodes_topaz.py index fe3666ec9..e79c16d3c 100644 --- a/comfy_api_nodes/nodes_topaz.py +++ b/comfy_api_nodes/nodes_topaz.py @@ -36,11 +36,15 @@ from comfy_api_nodes.util import ( ) UPSCALER_MODELS_MAP = { + "Astra 2": "ast-2", "Starlight (Astra) Fast": "slf-1", "Starlight (Astra) Creative": "slc-1", "Starlight Precise 2.5": "slp-2.5", } +AST2_MAX_FRAMES = 9000 +AST2_MAX_FRAMES_WITH_PROMPT = 450 + class TopazImageEnhance(IO.ComfyNode): @classmethod @@ -230,13 +234,20 @@ class TopazVideoEnhance(IO.ComfyNode): def define_schema(cls): return IO.Schema( node_id="TopazVideoEnhance", - display_name="Topaz Video Enhance", + display_name="Topaz Video Enhance (Legacy)", category="api node/video/Topaz", description="Breathe new life into video with powerful upscaling and recovery technology.", inputs=[ IO.Video.Input("video"), IO.Boolean.Input("upscaler_enabled", default=True), - IO.Combo.Input("upscaler_model", options=list(UPSCALER_MODELS_MAP.keys())), + IO.Combo.Input( + "upscaler_model", + options=[ + "Starlight (Astra) Fast", + "Starlight (Astra) Creative", + "Starlight Precise 2.5", + ], + ), IO.Combo.Input("upscaler_resolution", options=["FullHD (1080p)", "4K (2160p)"]), IO.Combo.Input( "upscaler_creativity", @@ -304,6 +315,7 @@ class TopazVideoEnhance(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + is_deprecated=True, ) @classmethod @@ -457,12 +469,357 @@ class TopazVideoEnhance(IO.ComfyNode): return IO.NodeOutput(await download_url_to_video_output(final_response.download.url)) +class TopazVideoEnhanceV2(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="TopazVideoEnhanceV2", + display_name="Topaz Video Enhance", + category="api node/video/Topaz", + description="Breathe new life into video with powerful upscaling and recovery technology.", + inputs=[ + IO.Video.Input("video"), + IO.DynamicCombo.Input( + "upscaler_model", + options=[ + IO.DynamicCombo.Option( + "Astra 2", + [ + IO.Combo.Input("upscaler_resolution", options=["FullHD (1080p)", "4K (2160p)"]), + IO.Float.Input( + "creativity", + default=0.5, + min=0.0, + max=1.0, + step=0.1, + display_mode=IO.NumberDisplay.slider, + tooltip="Creative strength of the upscale.", + ), + IO.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Optional descriptive (not instructive) scene prompt." + f"Capping input at {AST2_MAX_FRAMES_WITH_PROMPT} frames (~15s @ 30fps) when set.", + ), + IO.Float.Input( + "sharp", + default=0.5, + min=0.0, + max=1.0, + step=0.01, + display_mode=IO.NumberDisplay.slider, + tooltip="Pre-enhance sharpness: " + "0.0=Gaussian blur, 0.5=passthrough (default), 1.0=USM sharpening.", + advanced=True, + ), + IO.Float.Input( + "realism", + default=0.0, + min=0.0, + max=1.0, + step=0.01, + display_mode=IO.NumberDisplay.slider, + tooltip="Pulls output toward photographic realism." 
+ "Leave at 0 for the model default.", + advanced=True, + ), + ], + ), + IO.DynamicCombo.Option( + "Starlight (Astra) Fast", + [IO.Combo.Input("upscaler_resolution", options=["FullHD (1080p)", "4K (2160p)"]),], + ), + IO.DynamicCombo.Option( + "Starlight (Astra) Creative", + [ + IO.Combo.Input("upscaler_resolution", options=["FullHD (1080p)", "4K (2160p)"]), + IO.Combo.Input( + "creativity", + options=["low", "middle", "high"], + default="low", + tooltip="Creative strength of the upscale.", + ), + ], + ), + IO.DynamicCombo.Option( + "Starlight Precise 2.5", + [IO.Combo.Input("upscaler_resolution", options=["FullHD (1080p)", "4K (2160p)"])], + ), + IO.DynamicCombo.Option("Disabled", []), + ], + ), + IO.DynamicCombo.Input( + "interpolation_model", + options=[ + IO.DynamicCombo.Option("Disabled", []), + IO.DynamicCombo.Option( + "apo-8", + [ + IO.Int.Input( + "interpolation_frame_rate", + default=60, + min=15, + max=240, + display_mode=IO.NumberDisplay.number, + tooltip="Output frame rate.", + ), + IO.Int.Input( + "interpolation_slowmo", + default=1, + min=1, + max=16, + display_mode=IO.NumberDisplay.number, + tooltip="Slow-motion factor applied to the input video. " + "For example, 2 makes the output twice as slow and doubles the duration.", + advanced=True, + ), + IO.Boolean.Input( + "interpolation_duplicate", + default=False, + tooltip="Analyze the input for duplicate frames and remove them.", + advanced=True, + ), + IO.Float.Input( + "interpolation_duplicate_threshold", + default=0.01, + min=0.001, + max=0.1, + step=0.001, + display_mode=IO.NumberDisplay.number, + tooltip="Detection sensitivity for duplicate frames.", + advanced=True, + ), + ], + ), + ], + ), + IO.Combo.Input( + "dynamic_compression_level", + options=["Low", "Mid", "High"], + default="Low", + tooltip="CQP level.", + optional=True, + ), + ], + outputs=[ + IO.Video.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=[ + "upscaler_model", + "upscaler_model.upscaler_resolution", + "interpolation_model", + ]), + expr=""" + ( + $model := $lookup(widgets, "upscaler_model"); + $res := $lookup(widgets, "upscaler_model.upscaler_resolution"); + $interp := $lookup(widgets, "interpolation_model"); + $is4k := $contains($res, "4k"); + $hasInterp := $interp != "disabled"; + $rates := { + "starlight (astra) fast": {"hd": 0.43, "uhd": 0.85}, + "starlight precise 2.5": {"hd": 0.70, "uhd": 1.54}, + "astra 2": {"hd": 1.72, "uhd": 2.85}, + "starlight (astra) creative": {"hd": 2.25, "uhd": 3.99} + }; + $surcharge := $is4k ? 0.28 : 0.14; + $entry := $lookup($rates, $model); + $base := $is4k ? $entry.uhd : $entry.hd; + $hi := $base + ($hasInterp ? $surcharge : 0); + $model = "disabled" + ? {"type":"text","text":"Interpolation only"} + : ($hasInterp + ? 
{"type":"text","text":"~" & $string($base) & "–" & $string($hi) & " credits/src frame"} + : {"type":"text","text":"~" & $string($base) & " credits/src frame"}) + ) + """, + ), + ) + + @classmethod + async def execute( + cls, + video: Input.Video, + upscaler_model: dict, + interpolation_model: dict, + dynamic_compression_level: str = "Low", + ) -> IO.NodeOutput: + upscaler_choice = upscaler_model["upscaler_model"] + interpolation_choice = interpolation_model["interpolation_model"] + if upscaler_choice == "Disabled" and interpolation_choice == "Disabled": + raise ValueError("There is nothing to do: both upscaling and interpolation are disabled.") + validate_container_format_is_mp4(video) + src_width, src_height = video.get_dimensions() + src_frame_rate = int(video.get_frame_rate()) + duration_sec = video.get_duration() + src_video_stream = video.get_stream_source() + target_width = src_width + target_height = src_height + target_frame_rate = src_frame_rate + filters = [] + if upscaler_choice != "Disabled": + if "1080p" in upscaler_model["upscaler_resolution"]: + target_pixel_p = 1080 + max_long_side = 1920 + else: + target_pixel_p = 2160 + max_long_side = 3840 + ar = src_width / src_height + if src_width >= src_height: + # Landscape or Square; Attempt to set height to target (e.g., 2160), calculate width + target_height = target_pixel_p + target_width = int(target_height * ar) + # Check if width exceeds standard bounds (for ultra-wide e.g., 21:9 ARs) + if target_width > max_long_side: + target_width = max_long_side + target_height = int(target_width / ar) + else: + # Portrait; Attempt to set width to target (e.g., 2160), calculate height + target_width = target_pixel_p + target_height = int(target_width / ar) + # Check if height exceeds standard bounds + if target_height > max_long_side: + target_height = max_long_side + target_width = int(target_height * ar) + if target_width % 2 != 0: + target_width += 1 + if target_height % 2 != 0: + target_height += 1 + model_id = UPSCALER_MODELS_MAP[upscaler_choice] + if model_id == "slc-1": + filters.append( + VideoEnhancementFilter( + model=model_id, + creativity=upscaler_model["creativity"], + isOptimizedMode=True, + ) + ) + elif model_id == "ast-2": + n_frames = video.get_frame_count() + ast2_prompt = (upscaler_model["prompt"] or "").strip() + if ast2_prompt and n_frames > AST2_MAX_FRAMES_WITH_PROMPT: + raise ValueError( + f"Astra 2 with a prompt is limited to {AST2_MAX_FRAMES_WITH_PROMPT} input frames " + f"(~15s @ 30fps); video has {n_frames}. Clear the prompt or shorten the clip." 
+ ) + if n_frames > AST2_MAX_FRAMES: + raise ValueError(f"Astra 2 is limited to {AST2_MAX_FRAMES} input frames; video has {n_frames}.") + realism = upscaler_model["realism"] + filters.append( + VideoEnhancementFilter( + model=model_id, + creativity=upscaler_model["creativity"], + prompt=(ast2_prompt or None), + sharp=upscaler_model["sharp"], + realism=(realism if realism > 0 else None), + ) + ) + else: + filters.append(VideoEnhancementFilter(model=model_id)) + if interpolation_choice != "Disabled": + target_frame_rate = interpolation_model["interpolation_frame_rate"] + filters.append( + VideoFrameInterpolationFilter( + model=interpolation_choice, + slowmo=interpolation_model["interpolation_slowmo"], + fps=interpolation_model["interpolation_frame_rate"], + duplicate=interpolation_model["interpolation_duplicate"], + duplicate_threshold=interpolation_model["interpolation_duplicate_threshold"], + ), + ) + initial_res = await sync_op( + cls, + ApiEndpoint(path="/proxy/topaz/video/", method="POST"), + response_model=CreateVideoResponse, + data=CreateVideoRequest( + source=CreateVideoRequestSource( + container="mp4", + size=get_fs_object_size(src_video_stream), + duration=int(duration_sec), + frameCount=video.get_frame_count(), + frameRate=src_frame_rate, + resolution=Resolution(width=src_width, height=src_height), + ), + filters=filters, + output=OutputInformationVideo( + resolution=Resolution(width=target_width, height=target_height), + frameRate=target_frame_rate, + audioCodec="AAC", + audioTransfer="Copy", + dynamicCompressionLevel=dynamic_compression_level, + ), + ), + wait_label="Creating task", + final_label_on_success="Task created", + ) + upload_res = await sync_op( + cls, + ApiEndpoint( + path=f"/proxy/topaz/video/{initial_res.requestId}/accept", + method="PATCH", + ), + response_model=VideoAcceptResponse, + wait_label="Preparing upload", + final_label_on_success="Upload started", + ) + if len(upload_res.urls) > 1: + raise NotImplementedError( + "Large files are not currently supported. Please open an issue in the ComfyUI repository." 
+ ) + async with aiohttp.ClientSession(headers={"Content-Type": "video/mp4"}) as session: + if isinstance(src_video_stream, BytesIO): + src_video_stream.seek(0) + async with session.put(upload_res.urls[0], data=src_video_stream, raise_for_status=True) as res: + upload_etag = res.headers["Etag"] + else: + with builtins.open(src_video_stream, "rb") as video_file: + async with session.put(upload_res.urls[0], data=video_file, raise_for_status=True) as res: + upload_etag = res.headers["Etag"] + await sync_op( + cls, + ApiEndpoint( + path=f"/proxy/topaz/video/{initial_res.requestId}/complete-upload", + method="PATCH", + ), + response_model=VideoCompleteUploadResponse, + data=VideoCompleteUploadRequest( + uploadResults=[ + VideoCompleteUploadRequestPart( + partNum=1, + eTag=upload_etag, + ), + ], + ), + wait_label="Finalizing upload", + final_label_on_success="Upload completed", + ) + final_response = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/topaz/video/{initial_res.requestId}/status"), + response_model=VideoStatusResponse, + status_extractor=lambda x: x.status, + progress_extractor=lambda x: getattr(x, "progress", 0), + price_extractor=lambda x: (x.estimates.cost[0] * 0.08 if x.estimates and x.estimates.cost[0] else None), + poll_interval=10.0, + ) + return IO.NodeOutput(await download_url_to_video_output(final_response.download.url)) + + class TopazExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[IO.ComfyNode]]: return [ TopazImageEnhance, TopazVideoEnhance, + TopazVideoEnhanceV2, ] From be95871adccfac92a91ebdc06e52a85511f7b85c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jukka=20Sepp=C3=A4nen?= <40791699+kijai@users.noreply.github.com> Date: Sun, 3 May 2026 05:46:15 +0300 Subject: [PATCH 040/102] feat: Gemma4 text generation support (CORE-30) (#13376) * initial gemma4 support * parity with reference implementation outputs can 100% match transformers with same sdpa flags, checkpoint this and then optimize * Cleanup, video fixes * cleanup, enable fused rms norm by default * update comment * Cleanup * Update sd.py * Various fixes * Add fp8 scaled embedding support * small fixes * Translate think tokens * Fix image encoder attention mask type So it works with basic attention * Handle thinking tokens different only for Gemma4 * Code cleanup * Update nodes_textgen.py * Use embed scale class instead of buffer Slight difference to HF, but technically more accurate and simpler code * Default to fused rms_norm * Update gemma4.py --- comfy/ldm/modules/attention.py | 24 +- comfy/ops.py | 87 +++ comfy/rmsnorm.py | 1 + comfy/sd.py | 17 + comfy/text_encoders/gemma4.py | 1298 ++++++++++++++++++++++++++++++++ comfy/text_encoders/llama.py | 40 +- comfy/text_encoders/lt.py | 3 +- comfy/text_encoders/lumina2.py | 3 +- comfy/text_encoders/qwen35.py | 2 - comfy/utils.py | 7 - comfy_extras/nodes_textgen.py | 13 +- 11 files changed, 1453 insertions(+), 42 deletions(-) create mode 100644 comfy/text_encoders/gemma4.py diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index b193fe5e8..a68cb8439 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -14,6 +14,8 @@ from .sub_quadratic_attention import efficient_dot_product_attention from comfy import model_management +TORCH_HAS_GQA = model_management.torch_version_numeric >= (2, 5) + if model_management.xformers_enabled(): import xformers import xformers.ops @@ -150,7 +152,12 @@ def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape b, _, dim_head = 
q.shape dim_head //= heads - scale = dim_head ** -0.5 + if kwargs.get("enable_gqa", False) and q.shape[-3] != k.shape[-3]: + n_rep = q.shape[-3] // k.shape[-3] + k = k.repeat_interleave(n_rep, dim=-3) + v = v.repeat_interleave(n_rep, dim=-3) + + scale = kwargs.get("scale", dim_head ** -0.5) h = heads if skip_reshape: @@ -219,6 +226,10 @@ def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None, b, _, dim_head = query.shape dim_head //= heads + if "scale" in kwargs: + # Pre-scale query to match requested scale (cancels internal 1/sqrt(dim_head)) + query = query * (kwargs["scale"] * dim_head ** 0.5) + if skip_reshape: query = query.reshape(b * heads, -1, dim_head) value = value.reshape(b * heads, -1, dim_head) @@ -290,7 +301,7 @@ def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape b, _, dim_head = q.shape dim_head //= heads - scale = dim_head ** -0.5 + scale = kwargs.get("scale", dim_head ** -0.5) if skip_reshape: q, k, v = map( @@ -500,8 +511,13 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha if mask.ndim == 3: mask = mask.unsqueeze(1) + # Pass through extra SDPA kwargs (scale, enable_gqa) if provided + # enable_gqa requires PyTorch 2.5+; older versions use manual KV expansion above + sdpa_keys = ("scale", "enable_gqa") if TORCH_HAS_GQA else ("scale",) + sdpa_extra = {k: v for k, v in kwargs.items() if k in sdpa_keys} + if SDP_BATCH_LIMIT >= b: - out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False) + out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False, **sdpa_extra) if not skip_output_reshape: out = ( out.transpose(1, 2).reshape(b, -1, heads * dim_head) @@ -519,7 +535,7 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha k[i : i + SDP_BATCH_LIMIT], v[i : i + SDP_BATCH_LIMIT], attn_mask=m, - dropout_p=0.0, is_causal=False + dropout_p=0.0, is_causal=False, **sdpa_extra ).transpose(1, 2).reshape(-1, q.shape[2], heads * dim_head) return out diff --git a/comfy/ops.py b/comfy/ops.py index 96db1411c..4f0338346 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -1246,6 +1246,93 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec self._buffers[key] = fn(buf) return self + class Embedding(manual_cast.Embedding): + def _load_from_state_dict(self, state_dict, prefix, local_metadata, + strict, missing_keys, unexpected_keys, error_msgs): + weight_key = f"{prefix}weight" + layer_conf = state_dict.pop(f"{prefix}comfy_quant", None) + if layer_conf is not None: + layer_conf = json.loads(layer_conf.numpy().tobytes()) + + # Only fp8 makes sense for embeddings (per-row dequant via index select). + # Block-scaled formats (NVFP4, MXFP8) can't do per-row lookup efficiently. 
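+ # e.g. with a per-tensor fp8 scale, F.embedding(ids, w_fp8).to(orig_dtype) * scale
+ # equals dequantize(w)[ids] while reading only the selected rows; block-scaled
+ # layouts would have to decode the whole table first.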
+ quant_format = layer_conf.get("format", None) if layer_conf is not None else None + if quant_format in ["float8_e4m3fn", "float8_e5m2"] and weight_key in state_dict: + self.quant_format = quant_format + qconfig = QUANT_ALGOS[quant_format] + layout_cls = get_layout_class(qconfig["comfy_tensor_layout"]) + weight = state_dict.pop(weight_key) + manually_loaded_keys = [weight_key] + + scale_key = f"{prefix}weight_scale" + scale = state_dict.pop(scale_key, None) + if scale is not None: + scale = scale.float() + manually_loaded_keys.append(scale_key) + + params = layout_cls.Params( + scale=scale if scale is not None else torch.ones((), dtype=torch.float32), + orig_dtype=MixedPrecisionOps._compute_dtype, + orig_shape=(self.num_embeddings, self.embedding_dim), + ) + self.weight = torch.nn.Parameter( + QuantizedTensor(weight.to(dtype=qconfig["storage_t"]), qconfig["comfy_tensor_layout"], params), + requires_grad=False) + + super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) + for k in manually_loaded_keys: + if k in missing_keys: + missing_keys.remove(k) + else: + if layer_conf is not None: + state_dict[f"{prefix}comfy_quant"] = torch.tensor(list(json.dumps(layer_conf).encode('utf-8')), dtype=torch.uint8) + super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) + + def state_dict(self, *args, destination=None, prefix="", **kwargs): + if destination is not None: + sd = destination + else: + sd = {} + + if not hasattr(self, 'weight') or self.weight is None: + return sd + + if isinstance(self.weight, QuantizedTensor): + sd_out = self.weight.state_dict("{}weight".format(prefix)) + for k in sd_out: + sd[k] = sd_out[k] + + quant_conf = {"format": self.quant_format} + sd["{}comfy_quant".format(prefix)] = torch.tensor(list(json.dumps(quant_conf).encode('utf-8')), dtype=torch.uint8) + else: + sd["{}weight".format(prefix)] = self.weight + return sd + + def forward_comfy_cast_weights(self, input, out_dtype=None): + weight = self.weight + + # Optimized path: lookup in fp8, dequantize only the selected rows. + if isinstance(weight, QuantizedTensor) and len(self.weight_function) == 0: + qdata, _, offload_stream = cast_bias_weight(self, device=input.device, dtype=weight.dtype, offloadable=True) + if isinstance(qdata, QuantizedTensor): + scale = qdata._params.scale + qdata = qdata._qdata + else: + scale = None + + x = torch.nn.functional.embedding( + input, qdata, self.padding_idx, self.max_norm, + self.norm_type, self.scale_grad_by_freq, self.sparse) + uncast_bias_weight(self, qdata, None, offload_stream) + target_dtype = out_dtype if out_dtype is not None else weight._params.orig_dtype + x = x.to(dtype=target_dtype) + if scale is not None and scale != 1.0: + x = x * scale.to(dtype=target_dtype) + return x + + # Fallback for non-quantized or weight_function (LoRA) case + return super().forward_comfy_cast_weights(input, out_dtype=out_dtype) + return MixedPrecisionOps def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, model_config=None): diff --git a/comfy/rmsnorm.py b/comfy/rmsnorm.py index ab7cf14fa..e54be98d6 100644 --- a/comfy/rmsnorm.py +++ b/comfy/rmsnorm.py @@ -3,6 +3,7 @@ import comfy.model_management RMSNorm = torch.nn.RMSNorm +# Note: torch's fused F.rms_norm is faster but produces slightly different output than manual implementations (rsqrt/reduction rounding). 
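+# Manual reference for comparison: x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * weight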
def rms_norm(x, weight=None, eps=1e-6): if weight is None: return torch.nn.functional.rms_norm(x, (x.shape[-1],), eps=eps) diff --git a/comfy/sd.py b/comfy/sd.py index ee66490f5..9fce0e7d0 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -65,6 +65,7 @@ import comfy.text_encoders.ace15 import comfy.text_encoders.longcat_image import comfy.text_encoders.qwen35 import comfy.text_encoders.ernie +import comfy.text_encoders.gemma4 import comfy.model_patcher import comfy.lora @@ -1271,6 +1272,9 @@ class TEModel(Enum): QWEN35_9B = 26 QWEN35_27B = 27 MINISTRAL_3_3B = 28 + GEMMA_4_E4B = 29 + GEMMA_4_E2B = 30 + GEMMA_4_31B = 31 def detect_te_model(sd): @@ -1296,6 +1300,12 @@ def detect_te_model(sd): return TEModel.BYT5_SMALL_GLYPH return TEModel.T5_BASE if 'model.layers.0.post_feedforward_layernorm.weight' in sd: + if 'model.layers.59.self_attn.q_norm.weight' in sd: + return TEModel.GEMMA_4_31B + if 'model.layers.41.self_attn.q_norm.weight' in sd and 'model.layers.47.self_attn.q_norm.weight' not in sd: + return TEModel.GEMMA_4_E4B + if 'model.layers.34.self_attn.q_norm.weight' in sd and 'model.layers.41.self_attn.q_norm.weight' not in sd: + return TEModel.GEMMA_4_E2B if 'model.layers.47.self_attn.q_norm.weight' in sd: return TEModel.GEMMA_3_12B if 'model.layers.0.self_attn.q_norm.weight' in sd: @@ -1435,6 +1445,13 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip else: clip_target.clip = comfy.text_encoders.sa_t5.SAT5Model clip_target.tokenizer = comfy.text_encoders.sa_t5.SAT5Tokenizer + elif te_model in (TEModel.GEMMA_4_E4B, TEModel.GEMMA_4_E2B, TEModel.GEMMA_4_31B): + variant = {TEModel.GEMMA_4_E4B: comfy.text_encoders.gemma4.Gemma4_E4B, + TEModel.GEMMA_4_E2B: comfy.text_encoders.gemma4.Gemma4_E2B, + TEModel.GEMMA_4_31B: comfy.text_encoders.gemma4.Gemma4_31B}[te_model] + clip_target.clip = comfy.text_encoders.gemma4.gemma4_te(**llama_detect(clip_data), model_class=variant) + clip_target.tokenizer = variant.tokenizer + tokenizer_data["tokenizer_json"] = clip_data[0].get("tokenizer_json", None) elif te_model == TEModel.GEMMA_2_2B: clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data)) clip_target.tokenizer = comfy.text_encoders.lumina2.LuminaTokenizer diff --git a/comfy/text_encoders/gemma4.py b/comfy/text_encoders/gemma4.py new file mode 100644 index 000000000..f050061ed --- /dev/null +++ b/comfy/text_encoders/gemma4.py @@ -0,0 +1,1298 @@ +import torch +import torch.nn as nn +import numpy as np +from dataclasses import dataclass +import math + +from comfy import sd1_clip +import comfy.model_management +from comfy.ldm.modules.attention import optimized_attention_for_device +from comfy.rmsnorm import rms_norm +from comfy.text_encoders.llama import RMSNorm, MLP, BaseLlama, BaseGenerate, _make_scaled_embedding + + +# Intentional minor divergences from the transformers reference implementation: +# - Embedding sqrt(hidden_size) scale applied as a Python scalar (full precision) instead of a dtype-matched buffer tensor.
+# - RMSNorm uses torch fused F.rms_norm, very slight numerical differences, but considerably faster +# - Input image and audio resizing/resampling slightly different numerically + + +GEMMA4_VISION_CONFIG = {"hidden_size": 768, "image_size": 896, "intermediate_size": 3072, "num_attention_heads": 12, "num_hidden_layers": 16, "patch_size": 16, "head_dim": 64, "rms_norm_eps": 1e-6, "position_embedding_size": 10240, "pooling_kernel_size": 3} +GEMMA4_VISION_31B_CONFIG = {"hidden_size": 1152, "image_size": 896, "intermediate_size": 4304, "num_attention_heads": 16, "num_hidden_layers": 27, "patch_size": 16, "head_dim": 72, "rms_norm_eps": 1e-6, "position_embedding_size": 10240, "pooling_kernel_size": 3} +GEMMA4_AUDIO_CONFIG = {"hidden_size": 1024, "num_hidden_layers": 12, "num_attention_heads": 8, "intermediate_size": 4096, "conv_kernel_size": 5, "attention_chunk_size": 12, "attention_context_left": 13, "attention_context_right": 0, "attention_logit_cap": 50.0, "output_proj_dims": 1536, "rms_norm_eps": 1e-6, "residual_weight": 0.5} + +@dataclass +class Gemma4Config: + vocab_size: int = 262144 + hidden_size: int = 2560 + intermediate_size: int = 10240 + num_hidden_layers: int = 42 + num_attention_heads: int = 8 + num_key_value_heads: int = 2 + max_position_embeddings: int = 131072 + rms_norm_eps: float = 1e-6 + rope_theta = [1000000.0, 10000.0] + transformer_type: str = "gemma4" + head_dim = 256 + global_head_dim = 512 + rms_norm_add = False + mlp_activation = "gelu_pytorch_tanh" + qkv_bias = False + rope_dims = None + q_norm = "gemma3" + k_norm = "gemma3" + sliding_attention = [512, 512, 512, 512, 512, False] + rope_scale = None + partial_rotary_factor: float = 0.25 + final_norm: bool = True + lm_head: bool = False + final_logit_softcapping: float = 30.0 + hidden_size_per_layer_input: int = 256 + num_kv_shared_layers: int = 18 + use_double_wide_mlp: bool = False + stop_tokens = [1, 50, 106] + vision_config = GEMMA4_VISION_CONFIG + audio_config = GEMMA4_AUDIO_CONFIG + mm_tokens_per_image = 280 + +@dataclass +class Gemma4_E2B_Config(Gemma4Config): + hidden_size: int = 1536 + intermediate_size: int = 6144 + num_hidden_layers: int = 35 + num_key_value_heads: int = 1 + sliding_attention = [512, 512, 512, 512, False] + num_kv_shared_layers: int = 20 + use_double_wide_mlp: bool = True + +@dataclass +class Gemma4_31B_Config(Gemma4Config): + hidden_size: int = 5376 + intermediate_size: int = 21504 + num_hidden_layers: int = 60 + num_attention_heads: int = 32 + num_key_value_heads: int = 16 + sliding_attention = [1024, 1024, 1024, 1024, 1024, False] + hidden_size_per_layer_input: int = 0 + num_kv_shared_layers: int = 0 + audio_config = None + vision_config = GEMMA4_VISION_31B_CONFIG + + +# unfused RoPE as addcmul_ RoPE diverges from reference code +def _apply_rotary_pos_emb(x, freqs_cis): + cos, sin = freqs_cis[0], freqs_cis[1] + half = x.shape[-1] // 2 + out = x * cos + out[..., :half] -= x[..., half:] * sin[..., :half] + out[..., half:] += x[..., :half] * sin[..., half:] + return out + +class Gemma4Attention(nn.Module): + def __init__(self, config, head_dim, device=None, dtype=None, ops=None): + super().__init__() + self.num_heads = config.num_attention_heads + self.num_kv_heads = config.num_key_value_heads + self.hidden_size = config.hidden_size + self.head_dim = head_dim + self.inner_size = self.num_heads * head_dim + + self.q_proj = ops.Linear(config.hidden_size, self.inner_size, bias=config.qkv_bias, device=device, dtype=dtype) + self.k_proj = ops.Linear(config.hidden_size, self.num_kv_heads * 
head_dim, bias=config.qkv_bias, device=device, dtype=dtype) + self.v_proj = ops.Linear(config.hidden_size, self.num_kv_heads * head_dim, bias=config.qkv_bias, device=device, dtype=dtype) + self.o_proj = ops.Linear(self.inner_size, config.hidden_size, bias=False, device=device, dtype=dtype) + + self.q_norm = None + self.k_norm = None + if config.q_norm == "gemma3": + self.q_norm = RMSNorm(head_dim, eps=config.rms_norm_eps, device=device, dtype=dtype) + if config.k_norm == "gemma3": + self.k_norm = RMSNorm(head_dim, eps=config.rms_norm_eps, device=device, dtype=dtype) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask=None, + freqs_cis=None, + past_key_value=None, + sliding_window=None, + shared_kv=None, + ): + batch_size, seq_length, _ = hidden_states.shape + + xq = self.q_proj(hidden_states) + xq = xq.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) + if self.q_norm is not None: + xq = self.q_norm(xq) + + if shared_kv is not None: + xk, xv = shared_kv + # Apply RoPE to Q only (K already has RoPE from source layer) + xq = _apply_rotary_pos_emb(xq, freqs_cis) + present_key_value = None + shareable_kv = None + else: + xk = self.k_proj(hidden_states).view(batch_size, seq_length, self.num_kv_heads, self.head_dim) + xv = self.v_proj(hidden_states).view(batch_size, seq_length, self.num_kv_heads, self.head_dim) + if self.k_norm is not None: + xk = self.k_norm(xk) + xv = rms_norm(xv) + xk = xk.transpose(1, 2) + xv = xv.transpose(1, 2) + xq = _apply_rotary_pos_emb(xq, freqs_cis) + xk = _apply_rotary_pos_emb(xk, freqs_cis) + + present_key_value = None + if past_key_value is not None: + cumulative_len = 0 + if len(past_key_value) > 0: + past_key, past_value, cumulative_len = past_key_value + xk = torch.cat((past_key, xk), dim=2) + xv = torch.cat((past_value, xv), dim=2) + new_cumulative = cumulative_len + seq_length + if sliding_window is not None and xk.shape[2] > sliding_window - 1: + cache_k = xk[:, :, -(sliding_window - 1):] + cache_v = xv[:, :, -(sliding_window - 1):] + else: + cache_k = xk + cache_v = xv + present_key_value = (cache_k, cache_v, new_cumulative) + + # KV for sharing: full xk/xv that SDPA sees (not evicted cache) + shareable_kv = (xk, xv) + + # GQA: pass unexpanded KV with enable_gqa when no sliding mask, + # expand heads when sliding mask is present + # has to be done within SDPA itself to match the reference code, pre-scaling expansion causes numerical differences + expand_kv = (self.num_heads != self.num_kv_heads and + sliding_window is not None and + xk.shape[2] >= sliding_window) + if expand_kv: + xk = xk.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1) + xv = xv.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1) + gqa_kwargs = {} if expand_kv else ({"enable_gqa": True} if self.num_heads != self.num_kv_heads else {}) + output = optimized_attention_for_device(xq.device, mask=attention_mask is not None, small_input=True)(xq, xk, xv, self.num_heads, mask=attention_mask, skip_reshape=True, scale=1.0, **gqa_kwargs) + + return self.o_proj(output), present_key_value, shareable_kv + + +class TransformerBlockGemma4(nn.Module): + def __init__(self, config, index, device=None, dtype=None, ops=None): + super().__init__() + if config.sliding_attention is not None: + self.sliding_attention = config.sliding_attention[index % len(config.sliding_attention)] + else: + self.sliding_attention = False + + head_dim = config.head_dim if self.sliding_attention else config.global_head_dim + + self.self_attn = 
Gemma4Attention(config, head_dim=head_dim, device=device, dtype=dtype, ops=ops) + + num_kv_shared = config.num_kv_shared_layers + first_kv_shared = config.num_hidden_layers - num_kv_shared + mlp_size = config.intermediate_size * 2 if config.use_double_wide_mlp and index >= first_kv_shared else None + self.mlp = MLP(config, device=device, dtype=dtype, ops=ops, intermediate_size=mlp_size) + + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype) + self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype) + self.pre_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype) + self.post_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype) + + self.hidden_size_per_layer_input = config.hidden_size_per_layer_input + if self.hidden_size_per_layer_input: + self.per_layer_input_gate = ops.Linear(config.hidden_size, self.hidden_size_per_layer_input, bias=False, device=device, dtype=dtype) + self.per_layer_projection = ops.Linear(self.hidden_size_per_layer_input, config.hidden_size, bias=False, device=device, dtype=dtype) + self.post_per_layer_input_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype) + self.register_buffer("layer_scalar", torch.ones(1, device=device, dtype=dtype)) + else: + self.layer_scalar = None + + def forward(self, x, attention_mask=None, freqs_cis=None, past_key_value=None, per_layer_input=None, shared_kv=None): + sliding_window = None + if self.sliding_attention: + sliding_window = self.sliding_attention + # For prefill > sliding window, add sliding window restriction to the causal mask. + if x.shape[1] > self.sliding_attention: + sw_mask = torch.zeros(x.shape[1], x.shape[1], dtype=x.dtype, device=x.device) + sw_mask.masked_fill_(torch.ones_like(sw_mask, dtype=torch.bool).tril_(-self.sliding_attention), torch.finfo(x.dtype).min) + attention_mask = attention_mask + sw_mask if attention_mask is not None else sw_mask + freqs_cis = freqs_cis[1] + else: + freqs_cis = freqs_cis[0] + + residual = x + x = self.input_layernorm(x) + x, present_key_value, shareable_kv = self.self_attn( + hidden_states=x, attention_mask=attention_mask, freqs_cis=freqs_cis, + past_key_value=past_key_value, sliding_window=sliding_window, shared_kv=shared_kv, + ) + x = self.post_attention_layernorm(x) + x = residual + x + + residual = x + x = self.pre_feedforward_layernorm(x) + x = self.mlp(x) + x = self.post_feedforward_layernorm(x) + x = residual + x + + if self.hidden_size_per_layer_input and per_layer_input is not None: + residual = x + x = self.per_layer_input_gate(x) + x = torch.nn.functional.gelu(x, approximate="tanh") + x = x * per_layer_input + x = self.per_layer_projection(x) + x = self.post_per_layer_input_norm(x) + x = residual + x + + if self.layer_scalar is not None: + x = x * self.layer_scalar + + return x, present_key_value, shareable_kv + + +class Gemma4Transformer(nn.Module): + def __init__(self, config, device=None, dtype=None, ops=None): + super().__init__() + self.config = config + + self.embed_tokens = _make_scaled_embedding(ops, config.vocab_size, config.hidden_size, config.hidden_size ** 0.5, device, dtype) + + self.layers = nn.ModuleList([ + TransformerBlockGemma4(config, index=i, device=device, dtype=dtype, ops=ops) + for i in range(config.num_hidden_layers) + ]) + + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, 
dtype=dtype) if config.final_norm else None + + # Precompute RoPE inv_freq on CPU to match reference code's exact value + rope_angles_global = int(config.partial_rotary_factor * config.global_head_dim // 2) + nope_global = config.global_head_dim // 2 - rope_angles_global + global_inv = 1.0 / (config.rope_theta[0] ** (torch.arange(0, 2 * rope_angles_global, 2).float() / config.global_head_dim)) + if nope_global > 0: + global_inv = torch.cat([global_inv, torch.zeros(nope_global)]) + self.register_buffer("_global_inv_freq", global_inv, persistent=False) + + sliding_inv = 1.0 / (config.rope_theta[1] ** (torch.arange(0, config.head_dim, 2).float() / config.head_dim)) + self.register_buffer("_sliding_inv_freq", sliding_inv, persistent=False) + + # Per-layer input mechanism + self.hidden_size_per_layer_input = config.hidden_size_per_layer_input + if self.hidden_size_per_layer_input: + self.embed_tokens_per_layer = _make_scaled_embedding(ops, config.vocab_size, config.num_hidden_layers * self.hidden_size_per_layer_input, self.hidden_size_per_layer_input ** 0.5, device, dtype) + self.per_layer_model_projection = ops.Linear( + config.hidden_size, config.num_hidden_layers * self.hidden_size_per_layer_input, + bias=False, device=device, dtype=dtype) + self.per_layer_projection_norm = RMSNorm( + self.hidden_size_per_layer_input, eps=config.rms_norm_eps, + device=device, dtype=dtype) + + def get_past_len(self, past_key_values): + for kv in past_key_values: + if len(kv) >= 3: + return kv[2] + return 0 + + def _freqs_from_inv(self, inv_freq, position_ids, device, dtype): + """Compute cos/sin from stored inv_freq""" + inv_exp = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(device) + pos_exp = position_ids[:, None, :].float() + freqs = (inv_exp @ pos_exp).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + return emb.cos().unsqueeze(1).to(dtype), emb.sin().unsqueeze(1).to(dtype) + + def compute_freqs_cis(self, position_ids, device, dtype=None): + global_freqs = self._freqs_from_inv(self._global_inv_freq, position_ids, device, dtype) + sliding_freqs = self._freqs_from_inv(self._sliding_inv_freq, position_ids, device, dtype) + return [global_freqs, sliding_freqs] + + def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, + final_layer_norm_intermediate=True, dtype=None, position_ids=None, embeds_info=None, + past_key_values=None, input_ids=None): + if embeds is not None: + x = embeds + else: + x = self.embed_tokens(x, out_dtype=dtype) + + seq_len = x.shape[1] + past_len = 0 + if past_key_values is not None and len(past_key_values) > 0: + past_len = self.get_past_len(past_key_values) + + if position_ids is None: + position_ids = torch.arange(past_len, past_len + seq_len, device=x.device).unsqueeze(0) + + freqs_cis = self.compute_freqs_cis(position_ids, x.device, dtype=x.dtype) + + mask = None + min_val = torch.finfo(x.dtype).min + if attention_mask is not None: + mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, seq_len, attention_mask.shape[-1]) + mask = mask.masked_fill(mask.to(torch.bool), min_val) + + if seq_len > 1: + causal_mask = torch.zeros(past_len + seq_len, past_len + seq_len, dtype=x.dtype, device=x.device) + causal_mask.masked_fill_(torch.ones_like(causal_mask, dtype=torch.bool).triu_(1), min_val) + mask = mask + causal_mask if mask is not None else causal_mask + + # Per-layer inputs + per_layer_inputs = None + if 
self.hidden_size_per_layer_input: + num_layers = self.config.num_hidden_layers + hpl = self.hidden_size_per_layer_input + per_layer_proj = self.per_layer_model_projection(x) * (1.0 / (self.config.hidden_size ** 0.5)) + per_layer_proj = self.per_layer_projection_norm(per_layer_proj.reshape(*x.shape[:-1], num_layers, hpl)) + if input_ids is not None and input_ids.shape[1] == x.shape[1]: + per_layer_emb = self.embed_tokens_per_layer(input_ids).reshape(*input_ids.shape, num_layers, hpl) + per_layer_inputs = (per_layer_proj + per_layer_emb) * (0.5 ** 0.5) + else: + per_layer_inputs = per_layer_proj + + # KV sharing: later layers reuse KV from the last non-shared sliding/global layer + num_kv_shared = self.config.num_kv_shared_layers + first_kv_shared = self.config.num_hidden_layers - num_kv_shared if num_kv_shared > 0 else self.config.num_hidden_layers + shared_sliding_kv = None # KV from last non-shared sliding layer + shared_global_kv = None # KV from last non-shared global layer + + intermediate = None + next_key_values = [] + for i, layer in enumerate(self.layers): + past_kv = past_key_values[i] if past_key_values is not None and len(past_key_values) > 0 else None + + layer_kwargs = {} + if per_layer_inputs is not None: + layer_kwargs['per_layer_input'] = per_layer_inputs[:, :, i, :] + + is_sliding = hasattr(layer, 'sliding_attention') and layer.sliding_attention + if i >= first_kv_shared and num_kv_shared > 0: + shared = shared_sliding_kv if is_sliding else shared_global_kv + if shared is not None: + layer_kwargs['shared_kv'] = shared + + x, current_kv, shareable_kv = layer(x=x, attention_mask=mask, freqs_cis=freqs_cis, past_key_value=past_kv, **layer_kwargs) + + next_key_values.append(current_kv if current_kv is not None else ()) + + # Only track the last sliding/global before the sharing boundary + if i < first_kv_shared and shareable_kv is not None: + if is_sliding: + shared_sliding_kv = shareable_kv + else: + shared_global_kv = shareable_kv + + if i == intermediate_output: + intermediate = x.clone() + + if self.norm is not None: + x = self.norm(x) + + if len(next_key_values) > 0: + return x, intermediate, next_key_values + return x, intermediate + + +class Gemma4Base(BaseLlama, BaseGenerate, torch.nn.Module): + """Common base for all Gemma4 variants: text model + vision.""" + def _init_model(self, config, dtype, device, operations): + self.num_layers = config.num_hidden_layers + self.model = Gemma4Transformer(config, device=device, dtype=dtype, ops=operations) + self.dtype = dtype + self.multi_modal_projector = Gemma4MultiModalProjector(config, dtype=dtype, device=device, ops=operations) + self.vision_model = Gemma4VisionEncoder(config.vision_config, dtype=dtype, device=device, ops=operations) + + def logits(self, x): + logits = super().logits(x) + cap = self.model.config.final_logit_softcapping + if cap: + logits = cap * torch.tanh(logits / cap) + return logits + + def init_kv_cache(self, batch, max_cache_len, device, execution_dtype): + past_key_values = [] + for _ in range(self.model.config.num_hidden_layers): + past_key_values.append(()) + return past_key_values + + def preprocess_embed(self, embed, device): + if embed["type"] == "image": + image = embed.pop("data").movedim(-1, 1) # [B, H, W, C] -> [B, C, H, W] + max_soft_tokens = embed.get("max_soft_tokens", None) + vision_out = self.vision_model(image.to(device, dtype=torch.float32), max_soft_tokens=max_soft_tokens) + return self.multi_modal_projector(vision_out), None + return None, None + + +class Gemma4AudioMixin: + """Adds 
audio support to a Gemma4 model.""" + def _init_audio(self, config, dtype, device, operations): + self.audio_model = Gemma4AudioEncoder(config.audio_config, dtype=dtype, device=device, ops=operations) + self.audio_projector = Gemma4AudioProjector({"audio_output_proj_dims": config.audio_config["output_proj_dims"], "text_hidden_size": config.hidden_size, "rms_norm_eps": config.rms_norm_eps}, dtype=dtype, device=device, ops=operations) + + def preprocess_embed(self, embed, device): + result, extra = super().preprocess_embed(embed, device) + if result is not None: + return result, extra + if embed["type"] == "audio": + audio = embed.pop("data").to(device, dtype=torch.float32) + audio_mask = embed.pop("mask", None) + if audio_mask is not None: + audio_mask = audio_mask.to(device) + audio_out = self.audio_model(audio, audio_mask=audio_mask) + return self.audio_projector(audio_out), None + return None, None + + +# Vision Encoder + +def _compute_vision_2d_rope(head_dim, pixel_position_ids, theta=100.0, device=None): + """Compute 2D RoPE for vision: separate frequencies for x and y dimensions. + + Args: + head_dim: dimension per head (e.g. 64) + pixel_position_ids: [batch, num_patches, 2] with (x, y) coords + theta: RoPE base frequency + Returns: + (cos, sin) each of shape [batch, num_patches, head_dim] + """ + rotary_dim_per_axis = head_dim // 2 + freq_indices = torch.arange(0, rotary_dim_per_axis, 2, device=device).float() + inv_freq = 1.0 / (theta ** (freq_indices / rotary_dim_per_axis)) + + all_cos, all_sin = [], [] + for i in range(2): # x and y + dim_positions = pixel_position_ids[:, :, i].float() # [batch, num_patches] + freqs = torch.einsum('bi,j->bij', dim_positions, inv_freq.to(device)) # [batch, num_patches, rotary_dim/2] + emb = torch.cat([freqs, freqs], dim=-1) # [batch, num_patches, rotary_dim] + all_cos.append(emb.cos()) + all_sin.append(emb.sin()) + + cos = torch.cat(all_cos, dim=-1).to(pixel_position_ids.device) # [batch, num_patches, head_dim] + sin = torch.cat(all_sin, dim=-1).to(pixel_position_ids.device) + return cos, sin + + +def _apply_vision_2d_rope(x, freqs): + """Apply 2D RoPE (multidimensional) to vision query/key states. + + Splits x and cos/sin into ndim=2 parts, applies 1D RoPE to each independently. + + x: [batch, heads, seq, head_dim] + freqs: (cos, sin) each [batch, seq, head_dim] + """ + cos = freqs[0].unsqueeze(1) # [batch, 1, seq, head_dim] + sin = freqs[1].unsqueeze(1) + half = x.shape[-1] // 2 + a = _apply_rotary_pos_emb(x[..., :half], (cos[..., :half], sin[..., :half])) + b = _apply_rotary_pos_emb(x[..., half:], (cos[..., half:], sin[..., half:])) + return torch.cat([a, b], dim=-1) + + +class ClippedLinear(nn.Module): + """Linear layer with activation clipping (from quantization-aware training). + + Stores input_max/min and output_max/min as buffers loaded from checkpoint. 
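+    Inputs are clamped to [input_min, input_max] before the matmul and outputs
+    to [output_min, output_max] after it; until checkpoint values are loaded,
+    the +/-inf defaults leave both clamps as no-ops.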
+ """ + def __init__(self, in_features, out_features, bias=False, device=None, dtype=None, ops=None): + super().__init__() + self.linear = ops.Linear(in_features, out_features, bias=bias, device=device, dtype=dtype) + self.register_buffer('input_max', torch.tensor(float('inf'), device=device, dtype=dtype)) + self.register_buffer('input_min', torch.tensor(float('-inf'), device=device, dtype=dtype)) + self.register_buffer('output_max', torch.tensor(float('inf'), device=device, dtype=dtype)) + self.register_buffer('output_min', torch.tensor(float('-inf'), device=device, dtype=dtype)) + + @property + def weight(self): + return self.linear.weight + + def forward(self, x): + x = x.clamp(min=self.input_min, max=self.input_max) + x = self.linear(x) + return x.clamp_(min=self.output_min, max=self.output_max) + + +class Gemma4VisionMLP(nn.Module): + """SwiGLU MLP matching gate_proj/up_proj/down_proj structure.""" + def __init__(self, config, device=None, dtype=None, ops=None): + super().__init__() + hidden_size = config["hidden_size"] + intermediate_size = config["intermediate_size"] + self.gate_proj = ClippedLinear(hidden_size, intermediate_size, device=device, dtype=dtype, ops=ops) + self.up_proj = ClippedLinear(hidden_size, intermediate_size, device=device, dtype=dtype, ops=ops) + self.down_proj = ClippedLinear(intermediate_size, hidden_size, device=device, dtype=dtype, ops=ops) + + def forward(self, x): + return self.down_proj(torch.nn.functional.gelu(self.gate_proj(x), approximate="tanh") * self.up_proj(x)) + + +class Gemma4VisionAttention(nn.Module): + def __init__(self, config, device=None, dtype=None, ops=None): + super().__init__() + self.hidden_size = config["hidden_size"] + self.num_heads = config["num_attention_heads"] + self.head_dim = config.get("head_dim", self.hidden_size // self.num_heads) + + self.q_proj = ClippedLinear(self.hidden_size, self.num_heads * self.head_dim, device=device, dtype=dtype, ops=ops) + self.k_proj = ClippedLinear(self.hidden_size, self.num_heads * self.head_dim, device=device, dtype=dtype, ops=ops) + self.v_proj = ClippedLinear(self.hidden_size, self.num_heads * self.head_dim, device=device, dtype=dtype, ops=ops) + self.o_proj = ClippedLinear(self.num_heads * self.head_dim, self.hidden_size, device=device, dtype=dtype, ops=ops) + + self.q_norm = RMSNorm(self.head_dim, eps=config["rms_norm_eps"], device=device, dtype=dtype) + self.k_norm = RMSNorm(self.head_dim, eps=config["rms_norm_eps"], device=device, dtype=dtype) + + def forward(self, x, freqs, attention_mask=None): + batch_size, seq_length, _ = x.shape + + xq = self.q_proj(x).view(batch_size, seq_length, self.num_heads, self.head_dim) + xk = self.k_proj(x).view(batch_size, seq_length, self.num_heads, self.head_dim) + xv = self.v_proj(x).view(batch_size, seq_length, self.num_heads, self.head_dim) + + xq = self.q_norm(xq).transpose(1, 2) + xk = self.k_norm(xk).transpose(1, 2) + xv = rms_norm(xv) + + xq = _apply_vision_2d_rope(xq, freqs) + xk = _apply_vision_2d_rope(xk, freqs) + + xv = xv.to(xq.dtype).transpose(1, 2) + + output = optimized_attention_for_device(xq.device, mask=attention_mask is not None, small_input=True)(xq, xk, xv, self.num_heads, mask=attention_mask, skip_reshape=True, scale=1.0) + return self.o_proj(output) + + +class Gemma4VisionLayer(nn.Module): + def __init__(self, config, device=None, dtype=None, ops=None): + super().__init__() + self.self_attn = Gemma4VisionAttention(config, device=device, dtype=dtype, ops=ops) + self.mlp = Gemma4VisionMLP(config, device=device, dtype=dtype, ops=ops) 
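+        # Sandwich normalization: each sub-block is wrapped in a pre-norm and a
+        # post-norm, with the residual added after the post-norm (see forward).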
+ norm_kwargs = dict(eps=config["rms_norm_eps"], device=device, dtype=dtype) + hidden = config["hidden_size"] + self.input_layernorm = RMSNorm(hidden, **norm_kwargs) + self.post_attention_layernorm = RMSNorm(hidden, **norm_kwargs) + self.pre_feedforward_layernorm = RMSNorm(hidden, **norm_kwargs) + self.post_feedforward_layernorm = RMSNorm(hidden, **norm_kwargs) + + def forward(self, x, freqs, attention_mask=None): + residual = x + x = self.input_layernorm(x) + x = self.self_attn(x, freqs, attention_mask=attention_mask) + x = self.post_attention_layernorm(x) + x = residual + x + + residual = x + x = self.pre_feedforward_layernorm(x) + x = self.mlp(x) + x = self.post_feedforward_layernorm(x) + x = residual + x + return x + + +class Gemma4PatchEmbedder(nn.Module): + """Patch embedding with learned 2D position embeddings via one-hot lookup.""" + def __init__(self, config, device=None, dtype=None, ops=None): + super().__init__() + hidden_size = config["hidden_size"] + patch_size = config["patch_size"] + self.patch_size = patch_size + self.position_embedding_size = config.get("position_embedding_size", 10240) + + self.input_proj = ops.Linear(3 * patch_size * patch_size, hidden_size, bias=False, device=device, dtype=dtype) + self.position_embedding_table = nn.Parameter( + torch.empty(2, self.position_embedding_size, hidden_size, device=device, dtype=dtype) + ) + + def forward(self, patches, pixel_position_ids): + """ + patches: [B, num_patches, 3*patch_size²] in [0,1] range (normalized to [-1,1] inside, matching HF) + pixel_position_ids: [B, num_patches, 2] with (x,y) positions, (-1,-1) for padding + """ + hidden_states = self.input_proj((2.0 * (patches - 0.5)).to(self.input_proj.weight.dtype)) + + clamped_positions = pixel_position_ids.clamp(min=0) + pos_table = comfy.model_management.cast_to_device(self.position_embedding_table, hidden_states.device, hidden_states.dtype) + position_embeddings = pos_table[0][clamped_positions[..., 0]] + pos_table[1][clamped_positions[..., 1]] + + # Zero out position embeddings for padding patches (matching HF) + padding_positions = (pixel_position_ids == -1).all(dim=-1) + position_embeddings = torch.where(padding_positions.unsqueeze(-1), 0.0, position_embeddings) + + return hidden_states + position_embeddings + + +class Gemma4VisionEncoderLayers(nn.Module): + """Wrapper to produce state dict keys as encoder.layers.X.*""" + def __init__(self, config, dtype=None, device=None, ops=None): + super().__init__() + self.layers = nn.ModuleList([ + Gemma4VisionLayer(config, device=device, dtype=dtype, ops=ops) + for _ in range(config["num_hidden_layers"]) + ]) + + +class Gemma4VisionEncoder(nn.Module): + def __init__(self, config, dtype=None, device=None, ops=None): + super().__init__() + self.config = config + self.hidden_size = config["hidden_size"] + self.head_dim = config.get("head_dim", config["hidden_size"] // config["num_attention_heads"]) + self.patch_size = config["patch_size"] + self.pooling_kernel_size = config.get("pooling_kernel_size", 3) + self.root_hidden_size = self.hidden_size ** 0.5 + + self.patch_embedder = Gemma4PatchEmbedder(config, device=device, dtype=dtype, ops=ops) + self.encoder = Gemma4VisionEncoderLayers(config, dtype=dtype, device=device, ops=ops) + + def forward(self, pixel_values, max_soft_tokens=None): + """ + pixel_values: [B, C, H, W] in [0,1] range + max_soft_tokens: if provided, pad to max_soft_tokens * k² total patches + """ + batch_size, _, height, width = pixel_values.shape + ps = self.patch_size + k = self.pooling_kernel_size + 
patches_h, patches_w = height // ps, width // ps + num_patches = patches_h * patches_w + output_length = max_soft_tokens if max_soft_tokens is not None else num_patches // (k * k) + n_padding = output_length * k * k - num_patches + + # Patchify and build position grid + patches = pixel_values.reshape(batch_size, -1, patches_h, ps, patches_w, ps) + patches = patches.permute(0, 2, 4, 3, 5, 1).reshape(batch_size, num_patches, -1) + grid_y, grid_x = torch.meshgrid(torch.arange(patches_h, device=pixel_values.device), torch.arange(patches_w, device=pixel_values.device), indexing='ij') + position_ids = torch.stack([grid_x.flatten(), grid_y.flatten()], dim=-1).unsqueeze(0).expand(batch_size, -1, -1) + + # Append zero-pixel padding with (-1,-1) positions + if n_padding > 0: + patches = torch.cat([patches, patches.new_zeros(batch_size, n_padding, patches.shape[-1])], dim=1) + position_ids = torch.cat([position_ids, position_ids.new_full((batch_size, n_padding, 2), -1)], dim=1) + + padding = (position_ids == -1).all(dim=-1) + + # Embed, encode, pool + x = self.patch_embedder(patches, position_ids) + freqs = _compute_vision_2d_rope(self.head_dim, position_ids, device=pixel_values.device) + freqs = tuple(t.to(x.dtype) for t in freqs) + if n_padding > 0: + mask = padding.unsqueeze(1).unsqueeze(2).expand(-1, 1, position_ids.shape[1], -1) + mask = torch.zeros_like(mask, dtype=x.dtype).masked_fill_(mask, torch.finfo(x.dtype).min) + else: + mask = None + + for layer in self.encoder.layers: + x = layer(x, freqs, attention_mask=mask) + + if n_padding > 0: + x = x.masked_fill(padding.unsqueeze(-1), 0.0) + + # Average pool by spatial position + clamped = position_ids.clamp(min=0) + max_x = clamped[:, :, 0].max(dim=-1, keepdim=True)[0] + 1 + ki = torch.div(clamped, k, rounding_mode="floor") + ki = ki[:, :, 0] + (max_x // k) * ki[:, :, 1] + weights = torch.nn.functional.one_hot(ki.long(), output_length).float() / (k * k) + x = (weights.transpose(1, 2) @ x.float()).to(x.dtype) + + # Strip empty output tokens + valid_out = ~((weights == 0).all(dim=1)) + if valid_out.any() and not valid_out.all(): + x = x[:, valid_out[0]] if batch_size > 1 else x[valid_out].unsqueeze(0) + + return x * self.root_hidden_size + + +class Gemma4RMSNormProjector(nn.Module): + """Shared projector: parameterless RMSNorm → linear. 
Used for both vision and audio.""" + def __init__(self, in_dim, out_dim, dtype=None, device=None, ops=None): + super().__init__() + self.embedding_projection = ops.Linear(in_dim, out_dim, bias=False, device=device, dtype=dtype) + + def forward(self, x): + return self.embedding_projection(rms_norm(x)) + + +class Gemma4MultiModalProjector(Gemma4RMSNormProjector): + def __init__(self, config, dtype=None, device=None, ops=None): + super().__init__(config.vision_config["hidden_size"], config.hidden_size, dtype=dtype, device=device, ops=ops) + + +# Audio Encoder + +class Gemma4AudioConvSubsampler(nn.Module): + """2D convolution subsampling for audio features""" + def __init__(self, config, device=None, dtype=None, ops=None): + super().__init__() + eps = config["rms_norm_eps"] + self.layer0 = nn.ModuleDict({ + 'conv': ops.Conv2d(1, 128, kernel_size=3, stride=2, padding=1, bias=False, device=device, dtype=dtype), + 'norm': ops.LayerNorm(128, eps=eps, elementwise_affine=True, bias=False, device=device, dtype=dtype), + }) + self.layer1 = nn.ModuleDict({ + 'conv': ops.Conv2d(128, 32, kernel_size=3, stride=2, padding=1, bias=False, device=device, dtype=dtype), + 'norm': ops.LayerNorm(32, eps=eps, elementwise_affine=True, bias=False, device=device, dtype=dtype), + }) + # proj_input_dim = (128 // 4) * 32 = 1024 + self.input_proj_linear = ops.Linear(1024, config["hidden_size"], bias=False, device=device, dtype=dtype) + + def _conv_layer(self, x, layer, mask): + if mask is not None: + x = x * mask[:, None, :, None].to(x.device) + x = layer['conv'](x.to(layer['conv'].weight.dtype)) + x = torch.relu(layer['norm'](x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2).contiguous()) + if mask is not None: + mask = mask[:, ::2] + return x, mask + + def forward(self, x, mask=None): + x = x.unsqueeze(1) + x, mask = self._conv_layer(x, self.layer0, mask) + x, mask = self._conv_layer(x, self.layer1, mask) + batch_size, _, seq_len, _ = x.shape + x = x.permute(0, 2, 3, 1).contiguous().reshape(batch_size, seq_len, -1) + return self.input_proj_linear(x), mask + + +class Gemma4AudioFeedForward(nn.Module): + """Conformer feed-forward with residual scaling.""" + def __init__(self, config, device=None, dtype=None, ops=None): + super().__init__() + hidden_size = config["hidden_size"] + intermediate_size = config.get("intermediate_size", hidden_size * 4) + self.pre_layer_norm = RMSNorm(hidden_size, eps=config["rms_norm_eps"], device=device, dtype=dtype) + self.ffw_layer_1 = ClippedLinear(hidden_size, intermediate_size, device=device, dtype=dtype, ops=ops) + self.ffw_layer_2 = ClippedLinear(intermediate_size, hidden_size, device=device, dtype=dtype, ops=ops) + self.post_layer_norm = RMSNorm(hidden_size, eps=config["rms_norm_eps"], device=device, dtype=dtype) + self.post_layer_scale = config.get("residual_weight", 0.5) + + def forward(self, x): + residual = x + x = self.pre_layer_norm(x) + x = torch.nn.functional.silu(self.ffw_layer_1(x)) + x = self.ffw_layer_2(x) + x = self.post_layer_norm(x) + x = x * self.post_layer_scale + return x + residual + + +class Gemma4AudioRelPositionalEncoding(nn.Module): + """Sinusoidal relative positional encoding for audio attention.""" + def __init__(self, config, device=None, dtype=None): + super().__init__() + hidden_size = config["hidden_size"] + context_left = config.get("attention_context_left", 13) + context_right = config.get("attention_context_right", 0) + self.chunk_size = config.get("attention_chunk_size", 12) + self.context_size = self.chunk_size + context_left - 1 + context_right + + 
num_timescales = hidden_size // 2 + log_inc = math.log(10000.0) / max(num_timescales - 1, 1) + inv_timescales = torch.exp(torch.arange(num_timescales) * -log_inc).to(dtype=dtype).unsqueeze(0).unsqueeze(0) + self.register_buffer("inv_timescales", inv_timescales, persistent=False) + + def forward(self, hidden_states): + positions = torch.arange(self.chunk_size, -1, -1, device=hidden_states.device).unsqueeze(-1) + scaled = positions * self.inv_timescales.to(device=hidden_states.device) + return torch.cat([torch.sin(scaled), torch.cos(scaled)], dim=-1).to(dtype=hidden_states.dtype) + + +class Gemma4AudioAttention(nn.Module): + """Chunked block attention with relative position bias and softcap.""" + def __init__(self, config, device=None, dtype=None, ops=None): + super().__init__() + self.hidden_size = config["hidden_size"] + self.num_heads = config["num_attention_heads"] + self.head_dim = self.hidden_size // self.num_heads + self.chunk_size = config.get("attention_chunk_size", 12) + self.max_past_horizon = config.get("attention_context_left", 13) - 1 + self.max_future_horizon = config.get("attention_context_right", 0) + self.context_size = self.chunk_size + self.max_past_horizon + self.max_future_horizon + + self.q_scale = (self.head_dim ** -0.5) / math.log(2) + self.k_scale = math.log(1 + math.e) / math.log(2) + self.register_buffer("softcap", torch.tensor(config.get("attention_logit_cap", 50.0), dtype=dtype), persistent=False) + + self.q_proj = ClippedLinear(self.hidden_size, self.hidden_size, device=device, dtype=dtype, ops=ops) + self.k_proj = ClippedLinear(self.hidden_size, self.hidden_size, device=device, dtype=dtype, ops=ops) + self.v_proj = ClippedLinear(self.hidden_size, self.hidden_size, device=device, dtype=dtype, ops=ops) + self.post = ClippedLinear(self.hidden_size, self.hidden_size, device=device, dtype=dtype, ops=ops) + self.per_dim_scale = nn.Parameter(torch.empty(self.head_dim, device=device, dtype=dtype)) + self.relative_k_proj = ops.Linear(self.hidden_size, self.hidden_size, bias=False, device=device, dtype=dtype) + + def _convert_to_block(self, x): + B, S, H, D = x.shape + num_blocks = (S + self.chunk_size - 1) // self.chunk_size + pad = num_blocks * self.chunk_size - S + x = torch.nn.functional.pad(x, (0, 0, 0, 0, 0, pad)) + return x.reshape(B, num_blocks, self.chunk_size, H, D).contiguous() + + def _extract_block_context(self, x): + x = torch.nn.functional.pad(x, (0, 0, 0, 0, self.max_past_horizon, self.max_future_horizon + self.chunk_size - 1)) + x = x.unfold(1, self.context_size, self.chunk_size) + return torch.movedim(x, -1, 2).contiguous() + + def _rel_shift(self, x): + B, H, NB, BS, PL = x.shape + CS = self.context_size + x = torch.nn.functional.pad(x, (0, CS + 1 - PL)) + x = x.view(B, H, NB, BS * (CS + 1)) + x = x[..., :BS * CS] + return x.view(B, H, NB, BS, CS) + + def _build_blocked_mask(self, seq_len, num_blocks, device, audio_mask=None): + """Build 5D boolean blocked attention mask (True=attend, False=mask)""" + q = torch.arange(seq_len, device=device) + dist = q[:, None] - q[None, :] + mask = (dist >= 0) & (dist < self.max_past_horizon) + if self.max_future_horizon > 0: + mask = mask | ((dist < 0) & ((-dist) < self.max_future_horizon)) + if audio_mask is not None: + mask = mask & audio_mask[0, None, :].bool() + m = mask[None, None] + # Reshape to blocked 5D matching reference code + p = num_blocks * self.chunk_size - seq_len + m = torch.nn.functional.pad(m, (0, p, 0, p), value=False) + m = m.reshape(1, 1, num_blocks, self.chunk_size, -1) + m = 
torch.nn.functional.pad(m, (self.max_past_horizon, self.max_future_horizon), value=False) + idx = (torch.arange(num_blocks, device=device) * self.chunk_size)[:, None] + torch.arange(self.context_size, device=device)[None, :] + return m.gather(-1, idx[None, None, :, None, :].expand(1, 1, -1, self.chunk_size, -1)) + + def forward(self, x, position_embeddings=None, attn_mask=None): + B, S, _ = x.shape + + q = self.q_proj(x).float().view(B, S, self.num_heads, self.head_dim) + k = self.k_proj(x).float().view(B, S, self.num_heads, self.head_dim) + v = self.v_proj(x).float().view(B, S, self.num_heads, self.head_dim) + + q = q * self.q_scale * torch.nn.functional.softplus(self.per_dim_scale) + k = k * self.k_scale + + q_blocks = self._convert_to_block(q) + k_context = self._extract_block_context(k) + v_context = self._extract_block_context(v) + num_blocks = q_blocks.shape[1] + + rel_k = self.relative_k_proj(position_embeddings).view(-1, self.num_heads, self.head_dim).to(q.dtype) + + queries = q_blocks.permute(0, 3, 1, 2, 4) # [B, H, NB, CS, D] + matrix_ac = queries @ k_context.permute(0, 3, 1, 4, 2) + + queries_flat = queries.reshape(B, self.num_heads, -1, self.head_dim) + matrix_bd = queries_flat @ rel_k.permute(1, 2, 0) + matrix_bd = matrix_bd.reshape(B, self.num_heads, num_blocks, self.chunk_size, -1) + matrix_bd = self._rel_shift(matrix_bd) + + attn_weights = matrix_ac + matrix_bd + attn_weights = torch.tanh(attn_weights / self.softcap) * self.softcap + + # Mask out invalid positions in chunk context (matching reference's masked_fill approach) + if attn_mask is None: + attn_mask = self._build_blocked_mask(S, num_blocks, x.device) + attn_weights = attn_weights.masked_fill(attn_mask.logical_not(), -1e9) + + attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(v.dtype) + out = attn_weights @ v_context.permute(0, 3, 1, 2, 4) + out = out.permute(0, 2, 3, 1, 4).reshape(B, num_blocks * self.chunk_size, -1) + out = out[:, :S].contiguous() + return self.post(out.to(self.post.linear.weight.dtype)) + + +class Gemma4AudioLConv1d(nn.Module): + """Lightweight convolution with standard GLU.""" + def __init__(self, config, device=None, dtype=None, ops=None): + super().__init__() + hidden_size = config["hidden_size"] + conv_kernel_size = config.get("conv_kernel_size", 5) + self.pre_layer_norm = RMSNorm(hidden_size, eps=config["rms_norm_eps"], device=device, dtype=dtype) + self.linear_start = ClippedLinear(hidden_size, hidden_size * 2, device=device, dtype=dtype, ops=ops) + # Causal conv: left-pad only + self.depthwise_conv1d = ops.Conv1d(hidden_size, hidden_size, kernel_size=conv_kernel_size, padding=0, groups=hidden_size, bias=False, device=device, dtype=dtype) + self.conv_left_pad = conv_kernel_size - 1 # causal: pad left by kernel-1 + self.conv_norm = RMSNorm(hidden_size, eps=config["rms_norm_eps"], device=device, dtype=dtype) + self.linear_end = ClippedLinear(hidden_size, hidden_size, device=device, dtype=dtype, ops=ops) + + def forward(self, x): + residual = x + x = self.pre_layer_norm(x) + x = self.linear_start(x) + x = torch.nn.functional.glu(x, dim=-1) + x = x.transpose(1, 2) + x = torch.nn.functional.pad(x, (self.conv_left_pad, 0)) + x = self.depthwise_conv1d(x).transpose(1, 2) + x = self.conv_norm(x) + x = torch.nn.functional.silu(x) + x = self.linear_end(x) + return x + residual + + +class Gemma4AudioLayer(nn.Module): + """Conformer block: FFN1 -> Attention -> LConv -> FFN2.""" + def __init__(self, config, device=None, dtype=None, ops=None): + super().__init__() + 
self.feed_forward1 = Gemma4AudioFeedForward(config, device=device, dtype=dtype, ops=ops) + self.self_attn = Gemma4AudioAttention(config, device=device, dtype=dtype, ops=ops) + norm_kwargs = dict(eps=config["rms_norm_eps"], device=device, dtype=dtype) + hidden_size = config["hidden_size"] + self.norm_pre_attn = RMSNorm(hidden_size, **norm_kwargs) + self.norm_post_attn = RMSNorm(hidden_size, **norm_kwargs) + self.lconv1d = Gemma4AudioLConv1d(config, device=device, dtype=dtype, ops=ops) + self.feed_forward2 = Gemma4AudioFeedForward(config, device=device, dtype=dtype, ops=ops) + self.norm_out = RMSNorm(hidden_size, **norm_kwargs) + + def forward(self, x, position_embeddings=None, attn_mask=None): + x = self.feed_forward1(x) + + residual = x + x = self.norm_pre_attn(x) + x = self.self_attn(x, position_embeddings=position_embeddings, attn_mask=attn_mask) + x = self.norm_post_attn(x) + x = x + residual + + x = self.lconv1d(x) + x = self.feed_forward2(x) + + x = self.norm_out(x) + return x + + +class Gemma4AudioEncoder(nn.Module): + def __init__(self, config, dtype=None, device=None, ops=None): + super().__init__() + self.hidden_size = config["hidden_size"] + self.output_proj_dims = config.get("output_proj_dims", 1536) + + self.subsample_conv_projection = Gemma4AudioConvSubsampler(config, device=device, dtype=dtype, ops=ops) + self.rel_pos_enc = Gemma4AudioRelPositionalEncoding(config, device=device, dtype=dtype) + + self.layers = nn.ModuleList([ + Gemma4AudioLayer(config, device=device, dtype=dtype, ops=ops) + for _ in range(config["num_hidden_layers"]) + ]) + + self.output_proj = ops.Linear(self.hidden_size, self.output_proj_dims, bias=True, device=device, dtype=dtype) + + def forward(self, audio_features, audio_mask=None): + x, audio_mask = self.subsample_conv_projection(audio_features, audio_mask) + position_embeddings = self.rel_pos_enc(x) + + # Build blocked attention mask once for all layers + attn_mask = self.layers[0].self_attn._build_blocked_mask( + x.shape[1], (x.shape[1] + self.layers[0].self_attn.chunk_size - 1) // self.layers[0].self_attn.chunk_size, + x.device, audio_mask=audio_mask) + + for layer in self.layers: + x = layer(x, position_embeddings=position_embeddings, attn_mask=attn_mask) + + x = self.output_proj(x) + return x + + +class Gemma4AudioProjector(Gemma4RMSNormProjector): + def __init__(self, config, dtype=None, device=None, ops=None): + super().__init__(config.get("audio_output_proj_dims", 1536), config.get("text_hidden_size", 2560), dtype=dtype, device=device, ops=ops) + + +# Tokenizer and Wrappers + +class Gemma4_Tokenizer(): + tokenizer_json_data = None + + def state_dict(self): + if self.tokenizer_json_data is not None: + return {"tokenizer_json": self.tokenizer_json_data} + return {} + + def _extract_mel_spectrogram(self, waveform, sample_rate): + """Extract 128-bin log mel spectrogram. + Uses numpy for FFT/matmul/log to produce bit-identical results with reference code. 
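+        Returns (log_mel, frame_mask): a [T, 128] tensor of log-mel features and
+        a [T] boolean mask marking frames that contain real (non-padding) audio.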
+ """ + # Mix to mono first, then resample to 16kHz + if waveform.dim() > 1 and waveform.shape[0] > 1: + waveform = waveform.mean(dim=0, keepdim=True) + if waveform.dim() == 1: + waveform = waveform.unsqueeze(0) + audio = waveform.squeeze(0).float().numpy() + if sample_rate != 16000: + # Use scipy's resample_poly with a high-quality FIR filter to get as close as possible to librosa's resampling (while still not full match) + from scipy.signal import resample_poly, firwin + from math import gcd + g = gcd(sample_rate, 16000) + up, down = 16000 // g, sample_rate // g + L = max(up, down) + h = firwin(160 * L + 1, 0.96 / L, window=('kaiser', 6.5)) + audio = resample_poly(audio, up, down, window=h).astype(np.float32) + n = len(audio) + + # Pad to multiple of 128, build sample-level mask + if n % 128 != 0: + audio = np.pad(audio, (0, 128 - n % 128)) + mask_raw = np.ones(len(audio), dtype=np.float32) + mask_raw[n:] = 0.0 + + # Semicausal padding: 160 zeros prepended + audio = np.pad(audio, (160, 0)) + mask_raw = np.pad(mask_raw, (160, 0)) + + # Extract 321-sample frames via stride tricks, drop last → 320 + nf = (len(audio) - 321) // 160 + 1 + strides = (audio.strides[0] * 160, audio.strides[0]) + frames = np.lib.stride_tricks.as_strided(audio, (nf, 321), strides)[..., :-1].copy() + + # Periodic Hann window, FFT magnitude, mel filterbank, log + window = (0.5 - 0.5 * np.cos(2 * np.pi * np.arange(320) / 320)).astype(np.float32) + magnitude = np.abs(np.fft.rfft(frames * window, n=512, axis=-1)) + mel_fb = self._build_mel_filterbank() + log_mel = np.log(np.matmul(magnitude, mel_fb) + np.float64(0.001)).astype(np.float32) + + # Frame mask: valid when last sample in window is real audio + mask = mask_raw[np.arange(nf) * 160 + 320].astype(bool) + log_mel = log_mel * mask[:, None] + return torch.from_numpy(log_mel), torch.from_numpy(mask) # [T, 128], [T] + + @staticmethod + def _build_mel_filterbank(): + """Build 128-bin HTK mel filterbank [257, 128] for 512-pt FFT at 16kHz.""" + mel_freqs = np.linspace(0.0, 2595.0 * np.log10(1.0 + 8000.0 / 700.0), 130) + filter_freqs = 700.0 * (10.0 ** (mel_freqs / 2595.0) - 1.0) + fft_freqs = np.linspace(0, 16000 // 2, 257) + filter_diff = np.diff(filter_freqs) + slopes = np.expand_dims(filter_freqs, 0) - np.expand_dims(fft_freqs, 1) + down_slopes = -slopes[:, :-2] / filter_diff[:-1] + up_slopes = slopes[:, 2:] / filter_diff[1:] + return np.maximum(np.zeros(1), np.minimum(down_slopes, up_slopes)) + + def tokenize_with_weights(self, text, return_word_ids=False, image=None, audio=None, video=None, llama_template=None, skip_template=True, thinking=False, **kwargs): + + # Process audio + audio_features = [] + if audio is not None: + waveform = audio["waveform"].squeeze(0) if hasattr(audio, "__getitem__") else audio + sample_rate = audio.get("sample_rate", 16000) if hasattr(audio, "get") else 16000 + mel, mel_mask = self._extract_mel_spectrogram(waveform, sample_rate) + audio_features = [(mel.unsqueeze(0), mel_mask.unsqueeze(0))] # ([1, T, 128], [1, T]) + + # Process image/video frames + is_video = video is not None + source = video if is_video else image + images = [] + if source is not None: + samples = source.movedim(-1, 1) # [B, C, H, W] + num_frames = samples.shape[0] + + # Subsample video to 1fps + if is_video: + fps = kwargs.get("fps", 24) + step = max(1, round(fps)) + indices = list(range(0, num_frames, step)) + if len(indices) == 0: + indices = [0] + samples = samples[indices] + num_frames = len(indices) + + h, w = samples.shape[2], samples.shape[3] + patch_size = 16 
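+            # Resize so each side is a multiple of pooling_k * patch_size and the
+            # total patch count stays within the max_soft_tokens budget used below.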
+ pooling_k = 3 + max_soft_tokens = 70 if is_video else 280 # video uses smaller token budget per frame + max_patches = max_soft_tokens * pooling_k * pooling_k + target_px = max_patches * patch_size * patch_size + factor = (target_px / (h * w)) ** 0.5 + side_mult = pooling_k * patch_size + target_h = max(int(factor * h // side_mult) * side_mult, side_mult) + target_w = max(int(factor * w // side_mult) * side_mult, side_mult) + + import torchvision.transforms.functional as TVF + for i in range(num_frames): + # rescaling to match reference code + s = (samples[i].clamp(0, 1) * 255).to(torch.uint8) # [C, H, W] uint8 + if target_h != h or target_w != w: + s = TVF.resize(s, [target_h, target_w], interpolation=TVF.InterpolationMode.BICUBIC, antialias=True) + s = s.float() * (1.0 / 255.0) + images.append({"pixels": s.unsqueeze(0).movedim(1, -1)[:, :, :, :3], "max_soft_tokens": max_soft_tokens}) + + if text.startswith('<|turn>'): + skip_template = True + + if skip_template: + llama_text = text + else: + if llama_template is not None: + llama_text = llama_template.format(text) + else: + # Build template from modalities present + system = "<|turn>system\n<|think|>\n" if thinking else "" + media = "" + if len(images) > 0: + if is_video: + media += "\n\n" + for i in range(len(images)): + ts = f"{int(i // 60):02d}:{int(i % 60):02d}" + sep = "" if i == 0 else " " + media += f"{sep}{ts} <|image><|video|>" + media += "\n\n" + else: + media += "\n\n" + for i in range(len(images)): + if i > 0: + media += "\n\n\n\n" + media += "<|image><|image|>" + media += "\n\n" + if len(audio_features) > 0: + # Compute audio token count (always at 16kHz) + num_samples = int(waveform.shape[-1] * 16000 / sample_rate) if sample_rate != 16000 else waveform.shape[-1] + _fl = 320 # int(round(16000 * 20.0 / 1000.0)) + _hl = 160 # int(round(16000 * 10.0 / 1000.0)) + _nmel = (num_samples + _fl // 2 - (_fl + 1)) // _hl + 1 + _t = _nmel + for _ in range(2): + _t = (_t + 2 - 3) // 2 + 1 + n_audio_tokens = min(_t, 750) + media += "<|audio>" + "<|audio|>" * n_audio_tokens + "" + llama_text = f"{system}<|turn>user\n{media}{text}\n<|turn>model\n" + + text_tokens = super().tokenize_with_weights(llama_text, return_word_ids) + + def _replace_placeholders(token_list, token_id, embeds): + """Replace first placeholder with embed dict, remove remaining consecutive ones.""" + embed_idx = 0 + i = 0 + while i < len(token_list): + if token_list[i][0] == token_id and embed_idx < len(embeds): + token_list[i] = (embeds[embed_idx],) + token_list[i][1:] + embed_idx += 1 + i += 1 + while i < len(token_list) and token_list[i][0] == token_id: + token_list.pop(i) + else: + i += 1 + + if len(images) > 0: + img_token_id = 258884 if is_video else 258880 + img_embeds = [{"type": "image", "data": img["pixels"], "max_soft_tokens": img["max_soft_tokens"]} for img in images] + for r in text_tokens: + _replace_placeholders(r, img_token_id, img_embeds) + + if len(audio_features) > 0: + aud_embeds = [{"type": "audio", "data": mel, "mask": mask} for mel, mask in audio_features] + for r in text_tokens: + _replace_placeholders(r, 258881, aud_embeds) + + return text_tokens + + +class _Gemma4Tokenizer: + """Tokenizer using the tokenizers (Gemma4 doesn't come with sentencepiece model)""" + def __init__(self, tokenizer_json_bytes=None, **kwargs): + from tokenizers import Tokenizer + if isinstance(tokenizer_json_bytes, torch.Tensor): + tokenizer_json_bytes = bytes(tokenizer_json_bytes.tolist()) + self.tokenizer = Tokenizer.from_str(tokenizer_json_bytes.decode("utf-8")) + + 
@classmethod + def from_pretrained(cls, tokenizer_data, **kwargs): + return cls(tokenizer_json_bytes=tokenizer_data, **kwargs) + + def __call__(self, text): + return {"input_ids": self.tokenizer.encode(text, add_special_tokens=False).ids} + + def get_vocab(self): + return self.tokenizer.get_vocab() + + def convert_tokens_to_ids(self, tokens): + return [self.tokenizer.token_to_id(t) for t in tokens] + + def decode(self, ids, **kwargs): + return self.tokenizer.decode(ids, skip_special_tokens=kwargs.get("skip_special_tokens", False)) + + +# Tokenizer +class Gemma4SDTokenizer(Gemma4_Tokenizer, sd1_clip.SDTokenizer): + embedding_size = 2560 + def __init__(self, embedding_directory=None, tokenizer_data={}): + tokenizer_json = tokenizer_data.get("tokenizer_json", None) + self.tokenizer_json_data = tokenizer_json + super().__init__(tokenizer_json, pad_with_end=False, embedding_size=self.embedding_size, embedding_key='gemma4', tokenizer_class=_Gemma4Tokenizer, has_start_token=True, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, pad_left=True, disable_weights=True, start_token=2, tokenizer_data=tokenizer_data) + + def decode(self, token_ids, **kwargs): + text = super().decode(token_ids, skip_special_tokens=False) + # Translate thinking channel markers to standard / tags + text = text.replace("<|channel>thought\n", "\n") + text = text.replace("", "") + # Strip remaining special tokens + text = text.replace("", "").replace("", "").strip() + return text + + +class Gemma4Tokenizer(sd1_clip.SD1Tokenizer): + tokenizer_class = Gemma4SDTokenizer + def __init__(self, embedding_directory=None, tokenizer_data={}): + super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="gemma4", tokenizer=self.tokenizer_class) + + +# Model wrappers +class Gemma4Model(sd1_clip.SDClipModel): + model_class = None + def __init__(self, device="cpu", layer="all", layer_idx=None, dtype=None, attention_mask=True, model_options={}): + self.dtypes = set() + self.dtypes.add(dtype) + super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=self.model_class, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options) + + def process_tokens(self, tokens, device): + embeds, _, _, _ = super().process_tokens(tokens, device) + return embeds + + def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, presence_penalty=0.0): + if isinstance(tokens, dict): + tokens = next(iter(tokens.values())) + tokens_only = [[t[0] for t in b] for b in tokens] + embeds, _, _, embeds_info = sd1_clip.SDClipModel.process_tokens(self, tokens_only, self.execution_device) + seq_len = embeds.shape[1] + ids = [0] * seq_len + expanded_idx = 0 + embed_map = {info["index"]: info["size"] for info in embeds_info} + for t in tokens_only[0]: + if expanded_idx in embed_map: + expanded_idx += embed_map[expanded_idx] + elif isinstance(t, int): + if expanded_idx < seq_len: + ids[expanded_idx] = t + expanded_idx += 1 + else: + expanded_idx += 1 + initial_token_ids = [ids] + input_ids = torch.tensor(initial_token_ids, device=self.execution_device) + return self.transformer.generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, initial_tokens=initial_token_ids[0], presence_penalty=presence_penalty, initial_input_ids=input_ids) + + +def 
gemma4_te(dtype_llama=None, llama_quantization_metadata=None, model_class=None): + clip_model = type('Gemma4Model_', (Gemma4Model,), {'model_class': model_class}) + class Gemma4TEModel_(sd1_clip.SD1ClipModel): + def __init__(self, device="cpu", dtype=None, model_options={}): + if llama_quantization_metadata is not None: + model_options = model_options.copy() + model_options["quantization_metadata"] = llama_quantization_metadata + if dtype_llama is not None: + dtype = dtype_llama + super().__init__(device=device, dtype=dtype, name="gemma4", clip_model=clip_model, model_options=model_options) + return Gemma4TEModel_ + + +# Variants + +def _make_variant(config_cls): + audio = config_cls.audio_config is not None + bases = (Gemma4AudioMixin, Gemma4Base) if audio else (Gemma4Base,) + class Variant(*bases): + def __init__(self, config_dict, dtype, device, operations): + super().__init__() + self._init_model(config_cls(**config_dict), dtype, device, operations) + if audio: + self._init_audio(self.model.config, dtype, device, operations) + embedding_size = config_cls.hidden_size + if embedding_size != Gemma4SDTokenizer.embedding_size: + tok_cls = type('T', (Gemma4SDTokenizer,), {'embedding_size': embedding_size}) + class Tokenizer(Gemma4Tokenizer): + tokenizer_class = tok_cls + Variant.tokenizer = Tokenizer + else: + Variant.tokenizer = Gemma4Tokenizer + return Variant + +Gemma4_E4B = _make_variant(Gemma4Config) +Gemma4_E2B = _make_variant(Gemma4_E2B_Config) +Gemma4_31B = _make_variant(Gemma4_31B_Config) diff --git a/comfy/text_encoders/llama.py b/comfy/text_encoders/llama.py index 6ea8e36b1..a34c41144 100644 --- a/comfy/text_encoders/llama.py +++ b/comfy/text_encoders/llama.py @@ -521,7 +521,7 @@ class Attention(nn.Module): else: present_key_value = (xk, xv, index + num_tokens) - if sliding_window is not None and xk.shape[2] > sliding_window: + if sliding_window is not None and xk.shape[2] > sliding_window and seq_length == 1: xk = xk[:, :, -sliding_window:] xv = xv[:, :, -sliding_window:] attention_mask = attention_mask[..., -sliding_window:] if attention_mask is not None else None @@ -533,12 +533,12 @@ class Attention(nn.Module): return self.o_proj(output), present_key_value class MLP(nn.Module): - def __init__(self, config: Llama2Config, device=None, dtype=None, ops: Any = None): + def __init__(self, config: Llama2Config, device=None, dtype=None, ops: Any = None, intermediate_size=None): super().__init__() - ops = ops or nn - self.gate_proj = ops.Linear(config.hidden_size, config.intermediate_size, bias=False, device=device, dtype=dtype) - self.up_proj = ops.Linear(config.hidden_size, config.intermediate_size, bias=False, device=device, dtype=dtype) - self.down_proj = ops.Linear(config.intermediate_size, config.hidden_size, bias=False, device=device, dtype=dtype) + intermediate_size = intermediate_size or config.intermediate_size + self.gate_proj = ops.Linear(config.hidden_size, intermediate_size, bias=False, device=device, dtype=dtype) + self.up_proj = ops.Linear(config.hidden_size, intermediate_size, bias=False, device=device, dtype=dtype) + self.down_proj = ops.Linear(intermediate_size, config.hidden_size, bias=False, device=device, dtype=dtype) if config.mlp_activation == "silu": self.activation = torch.nn.functional.silu elif config.mlp_activation == "gelu_pytorch_tanh": @@ -647,24 +647,25 @@ class TransformerBlockGemma2(nn.Module): return x, present_key_value +def _make_scaled_embedding(ops, vocab_size, hidden_size, scale, device, dtype): + class ScaledEmbedding(ops.Embedding): + def 
forward(self, input_ids, out_dtype=None): + return super().forward(input_ids, out_dtype=out_dtype) * scale + return ScaledEmbedding(vocab_size, hidden_size, device=device, dtype=dtype) + + class Llama2_(nn.Module): def __init__(self, config, device=None, dtype=None, ops=None): super().__init__() self.config = config self.vocab_size = config.vocab_size - self.embed_tokens = ops.Embedding( - config.vocab_size, - config.hidden_size, - device=device, - dtype=dtype - ) if self.config.transformer_type == "gemma2" or self.config.transformer_type == "gemma3": transformer = TransformerBlockGemma2 - self.normalize_in = True + self.embed_tokens = _make_scaled_embedding(ops, config.vocab_size, config.hidden_size, config.hidden_size ** 0.5, device, dtype) else: transformer = TransformerBlock - self.normalize_in = False + self.embed_tokens = ops.Embedding(config.vocab_size, config.hidden_size, device=device, dtype=dtype) self.layers = nn.ModuleList([ transformer(config, index=i, device=device, dtype=dtype, ops=ops) @@ -690,15 +691,12 @@ class Llama2_(nn.Module): self.config.rope_dims, device=device) - def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, position_ids=None, embeds_info=[], past_key_values=None): + def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, position_ids=None, embeds_info=[], past_key_values=None, input_ids=None): if embeds is not None: x = embeds else: x = self.embed_tokens(x, out_dtype=dtype) - if self.normalize_in: - x *= self.config.hidden_size ** 0.5 - seq_len = x.shape[1] past_len = 0 if past_key_values is not None and len(past_key_values) > 0: @@ -850,7 +848,7 @@ class BaseGenerate: torch.empty([batch, model_config.num_key_value_heads, max_cache_len, model_config.head_dim], device=device, dtype=execution_dtype), 0)) return past_key_values - def generate(self, embeds=None, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.9, min_p=0.0, repetition_penalty=1.0, seed=42, stop_tokens=None, initial_tokens=[], execution_dtype=None, min_tokens=0, presence_penalty=0.0): + def generate(self, embeds=None, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.9, min_p=0.0, repetition_penalty=1.0, seed=42, stop_tokens=None, initial_tokens=[], execution_dtype=None, min_tokens=0, presence_penalty=0.0, initial_input_ids=None): device = embeds.device if stop_tokens is None: @@ -875,14 +873,16 @@ class BaseGenerate: pbar = comfy.utils.ProgressBar(max_length) # Generation loop + current_input_ids = initial_input_ids for step in tqdm(range(max_length), desc="Generating tokens"): - x, _, past_key_values = self.model.forward(None, embeds=embeds, attention_mask=None, past_key_values=past_key_values) + x, _, past_key_values = self.model.forward(None, embeds=embeds, attention_mask=None, past_key_values=past_key_values, input_ids=current_input_ids) logits = self.logits(x)[:, -1] next_token = self.sample_token(logits, temperature, top_k, top_p, min_p, repetition_penalty, initial_tokens + generated_token_ids, generator, do_sample=do_sample, presence_penalty=presence_penalty) token_id = next_token[0].item() generated_token_ids.append(token_id) embeds = self.model.embed_tokens(next_token).to(execution_dtype) + current_input_ids = next_token if initial_input_ids is not None else None pbar.update(1) if token_id in stop_tokens: diff --git a/comfy/text_encoders/lt.py b/comfy/text_encoders/lt.py 
index 5aee1f4c0..bc5cbae28 100644 --- a/comfy/text_encoders/lt.py +++ b/comfy/text_encoders/lt.py @@ -93,8 +93,7 @@ class Gemma3_12BModel(sd1_clip.SDClipModel): def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, presence_penalty): tokens_only = [[t[0] for t in b] for b in tokens] - embeds, _, _, embeds_info = self.process_tokens(tokens_only, self.execution_device) - comfy.utils.normalize_image_embeddings(embeds, embeds_info, self.transformer.model.config.hidden_size ** 0.5) + embeds, _, _, _ = self.process_tokens(tokens_only, self.execution_device) return self.transformer.generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, stop_tokens=[106], presence_penalty=presence_penalty) # 106 is class DualLinearProjection(torch.nn.Module): diff --git a/comfy/text_encoders/lumina2.py b/comfy/text_encoders/lumina2.py index 01ebdfabe..b1f1dbb9f 100644 --- a/comfy/text_encoders/lumina2.py +++ b/comfy/text_encoders/lumina2.py @@ -50,8 +50,7 @@ class Gemma3_4B_Vision_Model(sd1_clip.SDClipModel): super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma3_4B_Vision, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options) def process_tokens(self, tokens, device): - embeds, _, _, embeds_info = super().process_tokens(tokens, device) - comfy.utils.normalize_image_embeddings(embeds, embeds_info, self.transformer.model.config.hidden_size ** 0.5) + embeds, _, _, _ = super().process_tokens(tokens, device) return embeds class LuminaModel(sd1_clip.SD1ClipModel): diff --git a/comfy/text_encoders/qwen35.py b/comfy/text_encoders/qwen35.py index ce9b07464..d8ed9cd32 100644 --- a/comfy/text_encoders/qwen35.py +++ b/comfy/text_encoders/qwen35.py @@ -408,8 +408,6 @@ class Qwen35Transformer(Llama2_): nn.Module.__init__(self) self.config = config self.vocab_size = config.vocab_size - self.normalize_in = False - self.embed_tokens = ops.Embedding(config.vocab_size, config.hidden_size, device=device, dtype=dtype) self.layers = nn.ModuleList([ Qwen35TransformerBlock(config, index=i, device=device, dtype=dtype, ops=ops) diff --git a/comfy/utils.py b/comfy/utils.py index 78c491b98..7b7faad3a 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -1446,10 +1446,3 @@ def deepcopy_list_dict(obj, memo=None): memo[obj_id] = res return res -def normalize_image_embeddings(embeds, embeds_info, scale_factor): - """Normalize image embeddings to match text embedding scale""" - for info in embeds_info: - if info.get("type") == "image": - start_idx = info["index"] - end_idx = start_idx + info["size"] - embeds[:, start_idx:end_idx, :] /= scale_factor diff --git a/comfy_extras/nodes_textgen.py b/comfy_extras/nodes_textgen.py index 1f46d820f..1661a1011 100644 --- a/comfy_extras/nodes_textgen.py +++ b/comfy_extras/nodes_textgen.py @@ -32,6 +32,8 @@ class TextGenerate(io.ComfyNode): io.Clip.Input("clip"), io.String.Input("prompt", multiline=True, dynamic_prompts=True, default=""), io.Image.Input("image", optional=True), + io.Image.Input("video", optional=True, tooltip="Video frames as image batch. 
Assumed to be 24 FPS; subsampled to 1 FPS internally."),
+                io.Audio.Input("audio", optional=True),
                 io.Int.Input("max_length", default=256, min=1, max=2048),
                 io.DynamicCombo.Input("sampling_mode", options=sampling_options, display_name="Sampling Mode"),
                 io.Boolean.Input("thinking", optional=True, default=False, tooltip="Operate in thinking mode if the model supports it."),
@@ -43,9 +45,9 @@
     )
 
     @classmethod
-    def execute(cls, clip, prompt, max_length, sampling_mode, image=None, thinking=False, use_default_template=True) -> io.NodeOutput:
+    def execute(cls, clip, prompt, max_length, sampling_mode, image=None, thinking=False, use_default_template=True, video=None, audio=None) -> io.NodeOutput:
 
-        tokens = clip.tokenize(prompt, image=image, skip_template=not use_default_template, min_length=1, thinking=thinking)
+        tokens = clip.tokenize(prompt, image=image, skip_template=not use_default_template, min_length=1, thinking=thinking, video=video, audio=audio)
 
         # Get sampling parameters from dynamic combo
         do_sample = sampling_mode.get("sampling_mode") == "on"
@@ -70,7 +72,8 @@
             seed=seed
         )
 
-        generated_text = clip.decode(generated_ids, skip_special_tokens=True)
+        generated_text = clip.decode(generated_ids)
+
         return io.NodeOutput(generated_text)
 
 
@@ -161,12 +164,12 @@
     )
 
     @classmethod
-    def execute(cls, clip, prompt, max_length, sampling_mode, image=None, thinking=False, use_default_template=True) -> io.NodeOutput:
+    def execute(cls, clip, prompt, max_length, sampling_mode, image=None, thinking=False, use_default_template=True, video=None, audio=None) -> io.NodeOutput:
         if image is None:
             formatted_prompt = f"system\n{LTX2_T2V_SYSTEM_PROMPT.strip()}\nuser\nUser Raw Input Prompt: {prompt}.\nmodel\n"
         else:
             formatted_prompt = f"system\n{LTX2_I2V_SYSTEM_PROMPT.strip()}\nuser\n\n\n\nUser Raw Input Prompt: {prompt}.\nmodel\n"
-        return super().execute(clip, formatted_prompt, max_length, sampling_mode, image, thinking, use_default_template)
+        return super().execute(clip, formatted_prompt, max_length, sampling_mode, image=image, thinking=thinking, use_default_template=use_default_template, video=video, audio=audio)
 
 
 class TextgenExtension(ComfyExtension):

From f6d5068ac0163e7f626c9cec2e7c663cf6fa64a8 Mon Sep 17 00:00:00 2001
From: Alexis Rolland
Date: Sun, 3 May 2026 12:20:17 +0800
Subject: [PATCH 041/102] Update README (#13679)

Updated the README to include a new screenshot, an improved description, and
Ernie Image in the list of supported models.
---
 README.md | 12 +++++++++---
 1 file changed, 9 insertions(+), 3 deletions(-)

diff --git a/README.md b/README.md
index 3b5114633..ee68e8bb8 100644
--- a/README.md
+++ b/README.md
@@ -1,7 +1,7 @@
# ComfyUI -**The most powerful and modular visual AI engine and application.** +**The most powerful and modular AI engine for content creation.** [![Website][website-shield]][website-url] @@ -31,10 +31,15 @@ [github-downloads-latest-shield]: https://img.shields.io/github/downloads/comfyanonymous/ComfyUI/latest/total?style=flat&label=downloads%40latest [github-downloads-link]: https://github.com/comfyanonymous/ComfyUI/releases -![ComfyUI Screenshot](https://github.com/user-attachments/assets/7ccaf2c1-9b72-41ae-9a89-5688c94b7abe) +ComfyUI Screenshot
-ComfyUI lets you design and execute advanced stable diffusion pipelines using a graph/nodes/flowchart based interface. Available on Windows, Linux, and macOS. +ComfyUI is the AI creation engine for visual professionals who demand control over every model, every parameter, and every output. Its powerful and modular node graph interface empowers creatives to generate images, videos, 3D models, audio, and more... +- ComfyUI natively supports the latest open-source state of the art models. +- API nodes provide access to the best closed source models such as Nano Banana, Seedance, Hunyuan3D, etc. +- It is available on Windows, Linux, and macOS, locally with our desktop application or on our cloud. +- The most sophisticated workflows can be exposed through a simple UI thanks to App Mode. +- It integrates seamlessly into production pipelines with our API endpoints. ## Get Started @@ -77,6 +82,7 @@ See what ComfyUI can do with the [newer template workflows](https://comfy.org/wo - [Hunyuan Image 2.1](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_image/) - [Flux 2](https://comfyanonymous.github.io/ComfyUI_examples/flux2/) - [Z Image](https://comfyanonymous.github.io/ComfyUI_examples/z_image/) + - Ernie Image - Image Editing Models - [Omnigen 2](https://comfyanonymous.github.io/ComfyUI_examples/omnigen/) - [Flux Kontext](https://comfyanonymous.github.io/ComfyUI_examples/flux/#flux-kontext-image-editing-model) From b5bb83c964519b7574ce9229b2314e04c17592c0 Mon Sep 17 00:00:00 2001 From: Alexis Rolland Date: Sun, 3 May 2026 18:17:08 +0800 Subject: [PATCH 042/102] Fix issue blend images with alpha (#13615) Make ImageBlend and ImageCompositeMasked nodes handle images with different channel counts --- node_helpers.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/node_helpers.py b/node_helpers.py index d3d834516..cac4e88dd 100644 --- a/node_helpers.py +++ b/node_helpers.py @@ -86,6 +86,6 @@ def image_alpha_fix(destination, source): if destination.shape[-1] < source.shape[-1]: source = source[...,:destination.shape[-1]] elif destination.shape[-1] > source.shape[-1]: - destination = torch.nn.functional.pad(destination, (0, 1)) - destination[..., -1] = 1.0 + source = torch.nn.functional.pad(source, (0, 1)) + source[..., -1] = 1.0 return destination, source From d0f0b15cf5d1fbff67390c8d90ec8654c2582f7a Mon Sep 17 00:00:00 2001 From: Alexis Rolland Date: Sun, 3 May 2026 18:48:58 +0800 Subject: [PATCH 043/102] Update ComfyUI screenshot in README (#13683) Update ComfyUI screenshot to showcase a more modern workflow --- README.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index ee68e8bb8..a3bd3ba0a 100644 --- a/README.md +++ b/README.md @@ -31,7 +31,8 @@ [github-downloads-latest-shield]: https://img.shields.io/github/downloads/comfyanonymous/ComfyUI/latest/total?style=flat&label=downloads%40latest [github-downloads-link]: https://github.com/comfyanonymous/ComfyUI/releases -ComfyUI Screenshot +ComfyUI Screenshot +
ComfyUI is the AI creation engine for visual professionals who demand control over every model, every parameter, and every output. Its powerful and modular node graph interface empowers creatives to generate images, videos, 3D models, audio, and more... From 867b8d2408a8f3062f25bd6707a4b96755d70e1d Mon Sep 17 00:00:00 2001 From: Luke Mino-Altherr Date: Sun, 3 May 2026 05:44:20 -0700 Subject: [PATCH 044/102] fix: gracefully handle port-in-use error on server startup (#13001) Catch EADDRINUSE OSError when binding the TCP site and exit with a clear error message instead of an unhandled traceback. --- server.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/server.py b/server.py index 881da8e66..2f3b438bb 100644 --- a/server.py +++ b/server.py @@ -1,3 +1,4 @@ +import errno import os import sys import asyncio @@ -1245,7 +1246,13 @@ class PromptServer(): address = addr[0] port = addr[1] site = web.TCPSite(runner, address, port, ssl_context=ssl_ctx) - await site.start() + try: + await site.start() + except OSError as e: + if e.errno == errno.EADDRINUSE: + logging.error(f"Port {port} is already in use on address {address}. Please close the other application or use a different port with --port.") + raise SystemExit(1) + raise if not hasattr(self, 'address'): self.address = address #TODO: remove this From 025e6792ee64181ddce8a84411e0c7311e00b179 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jukka=20Sepp=C3=A4nen?= <40791699+kijai@users.noreply.github.com> Date: Sun, 3 May 2026 16:30:00 +0300 Subject: [PATCH 045/102] Batch broadcasting in JoinImageWithAlpha node (#13686) * Batch broadcasting in JoinImageWithAlpha node --- comfy_extras/nodes_compositing.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/comfy_extras/nodes_compositing.py b/comfy_extras/nodes_compositing.py index 3bc9fccb3..5b4423734 100644 --- a/comfy_extras/nodes_compositing.py +++ b/comfy_extras/nodes_compositing.py @@ -202,14 +202,11 @@ class JoinImageWithAlpha(io.ComfyNode): @classmethod def execute(cls, image: torch.Tensor, alpha: torch.Tensor) -> io.NodeOutput: - batch_size = min(len(image), len(alpha)) - out_images = [] - + batch_size = max(len(image), len(alpha)) alpha = 1.0 - resize_mask(alpha, image.shape[1:]) - for i in range(batch_size): - out_images.append(torch.cat((image[i][:,:,:3], alpha[i].unsqueeze(2)), dim=2)) - - return io.NodeOutput(torch.stack(out_images)) + alpha = comfy.utils.repeat_to_batch_size(alpha, batch_size) + image = comfy.utils.repeat_to_batch_size(image, batch_size) + return io.NodeOutput(torch.cat((image[..., :3], alpha.unsqueeze(-1)), dim=-1)) class CompositingExtension(ComfyExtension): From b138133ffa43541c85b5f9ca57f449c8345ca005 Mon Sep 17 00:00:00 2001 From: Silver <65376327+silveroxides@users.noreply.github.com> Date: Sun, 3 May 2026 20:07:21 +0200 Subject: [PATCH 046/102] Enable triton comfy kitchen via cli-arg (#12730) --- comfy/cli_args.py | 1 + comfy/quant_ops.py | 12 +++++++++++- 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/comfy/cli_args.py b/comfy/cli_args.py index cef1a5e6b..d2fde8b67 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -91,6 +91,7 @@ parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE" parser.add_argument("--oneapi-device-selector", type=str, default=None, metavar="SELECTOR_STRING", help="Sets the oneAPI device(s) this instance will use.") parser.add_argument("--supports-fp8-compute", action="store_true", help="ComfyUI will act like if the device supports fp8 compute.") 
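+# Opt-in: comfy-kitchen's Triton kernels stay disabled unless this flag is
+# passed (e.g. launching with `python main.py --enable-triton-backend`).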
+parser.add_argument("--enable-triton-backend", action="store_true", help="ComfyUI will enable the use of Triton backend in comfy-kitchen. Is disabled at launch by default.") class LatentPreviewMethod(enum.Enum): NoPreviews = "none" diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py index 42ee08fb2..b90bcfd25 100644 --- a/comfy/quant_ops.py +++ b/comfy/quant_ops.py @@ -1,6 +1,8 @@ import torch import logging +from comfy.cli_args import args + try: import comfy_kitchen as ck from comfy_kitchen.tensor import ( @@ -21,7 +23,15 @@ try: ck.registry.disable("cuda") logging.warning("WARNING: You need pytorch with cu130 or higher to use optimized CUDA operations.") - ck.registry.disable("triton") + if args.enable_triton_backend: + try: + import triton + logging.info("Found triton %s. Enabling comfy-kitchen triton backend.", triton.__version__) + except ImportError as e: + logging.error(f"Failed to import triton, Error: {e}, the comfy-kitchen triton backend will not be available.") + ck.registry.disable("triton") + else: + ck.registry.disable("triton") for k, v in ck.list_backends().items(): logging.info(f"Found comfy_kitchen backend {k}: {v}") except ImportError as e: From cea8d0925febb4dd32e400bbbf94243f55af3371 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sun, 3 May 2026 13:18:27 -0700 Subject: [PATCH 047/102] Refactor LoadImageMask to use LoadImage code. (#13687) --- nodes.py | 66 +++++++++++++++++++++++++------------------------------- 1 file changed, 29 insertions(+), 37 deletions(-) diff --git a/nodes.py b/nodes.py index 710cccffe..8f8f90cf6 100644 --- a/nodes.py +++ b/nodes.py @@ -1754,57 +1754,49 @@ class LoadImage: return True -class LoadImageMask: + +class LoadImageMask(LoadImage): ESSENTIALS_CATEGORY = "Image Tools" SEARCH_ALIASES = ["import mask", "alpha mask", "channel mask"] _color_channels = ["alpha", "red", "green", "blue"] + @classmethod def INPUT_TYPES(s): - input_dir = folder_paths.get_input_directory() - files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))] - return {"required": - {"image": (sorted(files), {"image_upload": True}), - "channel": (s._color_channels, ), } - } + types = super().INPUT_TYPES() + return { + "required": { + **types["required"], + "channel": (s._color_channels, ) + } + } CATEGORY = "mask" - RETURN_TYPES = ("MASK",) - FUNCTION = "load_image" - def load_image(self, image, channel): - image_path = folder_paths.get_annotated_filepath(image) - i = node_helpers.pillow(Image.open, image_path) - i = node_helpers.pillow(ImageOps.exif_transpose, i) - if i.getbands() != ("R", "G", "B", "A"): - if i.mode == 'I': - i = i.point(lambda i: i * (1 / 255)) - i = i.convert("RGBA") - mask = None + FUNCTION = "load_image_mask" + + def load_image_mask(self, image, channel): + image_tensor, mask_tensor = super().load_image(image) c = channel[0].upper() - if c in i.getbands(): - mask = np.array(i.getchannel(c)).astype(np.float32) / 255.0 - mask = torch.from_numpy(mask) - if c == 'A': - mask = 1. 
- mask + + if c == 'A': + return (mask_tensor,) + + channel_idx = {'R': 0, 'G': 1, 'B': 2}.get(c, 0) + + if channel_idx < image_tensor.shape[-1]: + return (image_tensor[..., channel_idx].clone(),) else: - mask = torch.zeros((64,64), dtype=torch.float32, device="cpu") - return (mask.unsqueeze(0),) + empty_mask = torch.zeros( + image_tensor.shape[:-1], + dtype=image_tensor.dtype, + device=image_tensor.device + ) + return (empty_mask,) @classmethod def IS_CHANGED(s, image, channel): - image_path = folder_paths.get_annotated_filepath(image) - m = hashlib.sha256() - with open(image_path, 'rb') as f: - m.update(f.read()) - return m.digest().hex() - - @classmethod - def VALIDATE_INPUTS(s, image): - if not folder_paths.exists_annotated_filepath(image): - return "Invalid image file: {}".format(image) - - return True + return super().IS_CHANGED(image) class LoadImageOutput(LoadImage): From 2806163f6e06465bacb1b16906cd17a8b78c9610 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Sun, 3 May 2026 16:21:34 -0700 Subject: [PATCH 048/102] Default control_after_generate to fixed in PrimitiveInt node (#13690) --- comfy_extras/nodes_primitive.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy_extras/nodes_primitive.py b/comfy_extras/nodes_primitive.py index 9c2e98758..3c8f90b19 100644 --- a/comfy_extras/nodes_primitive.py +++ b/comfy_extras/nodes_primitive.py @@ -49,7 +49,7 @@ class Int(io.ComfyNode): display_name="Int", category="utils/primitive", inputs=[ - io.Int.Input("value", min=-sys.maxsize, max=sys.maxsize, control_after_generate=True), + io.Int.Input("value", min=-sys.maxsize, max=sys.maxsize, control_after_generate=io.ControlAfterGenerate.fixed), ], outputs=[io.Int.Output()], ) From 5538f62b0b81102c382849fd90469283c725b212 Mon Sep 17 00:00:00 2001 From: Alexis Rolland Date: Mon, 4 May 2026 12:33:11 +0800 Subject: [PATCH 049/102] fix: Update ColorTransfer node ref_image to be mandatory (#13691) --- comfy_extras/nodes_post_processing.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/comfy_extras/nodes_post_processing.py b/comfy_extras/nodes_post_processing.py index c932b747a..345fdb695 100644 --- a/comfy_extras/nodes_post_processing.py +++ b/comfy_extras/nodes_post_processing.py @@ -666,12 +666,13 @@ class ColorTransfer(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="ColorTransfer", + display_name="Color Transfer", category="image/postprocessing", description="Match the colors of one image to another using various algorithms.", search_aliases=["color match", "color grading", "color correction", "match colors", "color transform", "mkl", "reinhard", "histogram"], inputs=[ io.Image.Input("image_target", tooltip="Image(s) to apply the color transform to."), - io.Image.Input("image_ref", optional=True, tooltip="Reference image(s) to match colors to. If not provided, processing is skipped"), + io.Image.Input("image_ref", tooltip="Reference image(s) to match colors to."), io.Combo.Input("method", options=['reinhard_lab', 'mkl_lab', 'histogram'],), io.DynamicCombo.Input("source_stats", tooltip="per_frame: each frame matched to image_ref individually. uniform: pool stats across all source frames as baseline, match to image_ref. 
target_frame: use one chosen frame as the baseline for the transform to image_ref, applied uniformly to all frames (preserves relative differences)", From f3ea976cba8743a87efeb9fbca717309e3d65c47 Mon Sep 17 00:00:00 2001 From: Soof Golan <83900570+soof-golan@users.noreply.github.com> Date: Mon, 4 May 2026 10:01:46 +0200 Subject: [PATCH 050/102] Fix a1111 typo in extra_model_paths.yaml (#2720) --- extra_model_paths.yaml.example | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/extra_model_paths.yaml.example b/extra_model_paths.yaml.example index 34df01681..9c395c0b2 100644 --- a/extra_model_paths.yaml.example +++ b/extra_model_paths.yaml.example @@ -28,7 +28,7 @@ #config for a1111 ui #all you have to do is uncomment this (remove the #) and change the base_path to where yours is installed -#a111: +#a1111: # base_path: path/to/stable-diffusion-webui/ # checkpoints: models/Stable-diffusion # configs: models/Stable-diffusion From c33d26c283ea53b8ba3e42ef3dca1f03ddf4d7b1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jukka=20Sepp=C3=A4nen?= <40791699+kijai@users.noreply.github.com> Date: Mon, 4 May 2026 20:20:40 +0300 Subject: [PATCH 051/102] fix: Proper memory estimation for frame interpolation when not using dynamic VRAM (#13698) --- comfy_extras/frame_interpolation_models/film_net.py | 3 +++ comfy_extras/frame_interpolation_models/ifnet.py | 3 +++ comfy_extras/nodes_frame_interpolation.py | 11 ++++------- 3 files changed, 10 insertions(+), 7 deletions(-) diff --git a/comfy_extras/frame_interpolation_models/film_net.py b/comfy_extras/frame_interpolation_models/film_net.py index cf4f6e1e1..36bc79dc3 100644 --- a/comfy_extras/frame_interpolation_models/film_net.py +++ b/comfy_extras/frame_interpolation_models/film_net.py @@ -199,6 +199,9 @@ class FILMNet(nn.Module): def get_dtype(self): return self.extract.extract_sublevels.convs[0][0].conv.weight.dtype + def memory_used_forward(self, shape, dtype): + return 1700 * shape[1] * shape[2] * dtype.itemsize + def _build_warp_grids(self, H, W, device): """Pre-compute warp grids for all pyramid levels.""" if (H, W) in self._warp_grids: diff --git a/comfy_extras/frame_interpolation_models/ifnet.py b/comfy_extras/frame_interpolation_models/ifnet.py index 03cb34c50..ad6edbec9 100644 --- a/comfy_extras/frame_interpolation_models/ifnet.py +++ b/comfy_extras/frame_interpolation_models/ifnet.py @@ -74,6 +74,9 @@ class IFNet(nn.Module): def get_dtype(self): return self.encode.cnn0.weight.dtype + def memory_used_forward(self, shape, dtype): + return 300 * shape[1] * shape[2] * dtype.itemsize + def _build_warp_grids(self, H, W, device): if (H, W) in self._warp_grids: return diff --git a/comfy_extras/nodes_frame_interpolation.py b/comfy_extras/nodes_frame_interpolation.py index a3b00d36e..fa49c203a 100644 --- a/comfy_extras/nodes_frame_interpolation.py +++ b/comfy_extras/nodes_frame_interpolation.py @@ -37,7 +37,7 @@ class FrameInterpolationModelLoader(io.ComfyNode): model = cls._detect_and_load(sd) dtype = torch.float16 if model_management.should_use_fp16(model_management.get_torch_device()) else torch.float32 model.eval().to(dtype) - patcher = comfy.model_patcher.ModelPatcher( + patcher = comfy.model_patcher.CoreModelPatcher( model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device(), @@ -98,16 +98,13 @@ class FrameInterpolate(io.ComfyNode): if num_frames < 2 or multiplier < 2: return io.NodeOutput(images) - model_management.load_model_gpu(interp_model) device = interp_model.load_device dtype = 
interp_model.model_dtype()
         inference_model = interp_model.model
-
-        # Free VRAM for inference activations (model weights + ~20x a single frame's worth)
-        H, W = images.shape[1], images.shape[2]
-        activation_mem = H * W * 3 * images.element_size() * 20
-        model_management.free_memory(activation_mem, device)
+        activation_mem = inference_model.memory_used_forward(images.shape, dtype)
+        model_management.load_models_gpu([interp_model], memory_required=activation_mem)

         align = getattr(inference_model, "pad_align", 1)
+        H, W = images.shape[1], images.shape[2]

         # Prepare a single padded frame on device for determining output dimensions
         def prepare_frame(idx):

From c47633f3befbc32bf5aeece6a899c20d55a9feb1 Mon Sep 17 00:00:00 2001
From: rattus <46076784+rattus128@users.noreply.github.com>
Date: Tue, 5 May 2026 05:56:05 +1000
Subject: [PATCH 052/102] prefetch: guard against no offload (#13703)

cast_ will return no stream if there is no work to do. Guard against
this in the consume logic.
---
 comfy/model_prefetch.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/comfy/model_prefetch.py b/comfy/model_prefetch.py
index 0ad35deb5..72e11dec6 100644
--- a/comfy/model_prefetch.py
+++ b/comfy/model_prefetch.py
@@ -37,7 +37,8 @@ def prefetch_queue_pop(queue, device, module):
     consumed = queue.pop(0)
     if consumed is not None:
         offload_stream, prefetch_state = consumed
-        offload_stream.wait_stream(comfy.model_management.current_stream(device))
+        if offload_stream is not None:
+            offload_stream.wait_stream(comfy.model_management.current_stream(device))
         _, comfy_modules = prefetch_state
         if comfy_modules is not None:
             cleanup_prefetched_modules(comfy_modules)

From 1ac78180b3f797f09f5805e5a923debe77638889 Mon Sep 17 00:00:00 2001
From: rattus <46076784+rattus128@users.noreply.github.com>
Date: Tue, 5 May 2026 05:58:06 +1000
Subject: [PATCH 053/102] make control-net load order deterministic (#13701)

Make this deterministic so speeds don't change based on load order.

Load them in reverse order so whatever the caller lists first is the
top priority.
---
 comfy/model_management.py | 10 ++++++----
 comfy/sampler_helpers.py  |  3 ++-
 2 files changed, 8 insertions(+), 5 deletions(-)

diff --git a/comfy/model_management.py b/comfy/model_management.py
index 02ad66656..21738a4c7 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -721,13 +721,15 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
     else:
         minimum_memory_required = max(inference_memory, minimum_memory_required + extra_reserved_memory())

-    models_temp = set()
+    # Order-preserving dedup. A plain set() would randomize iteration order across runs
+    models_temp = {}
     for m in models:
-        models_temp.add(m)
+        models_temp[m] = None
         for mm in m.model_patches_models():
-            models_temp.add(mm)
+            models_temp[mm] = None

-    models = models_temp
+    models = list(models_temp)
+    models.reverse()

     models_to_load = []

diff --git a/comfy/sampler_helpers.py b/comfy/sampler_helpers.py
index bbba09e26..3782fd2d5 100644
--- a/comfy/sampler_helpers.py
+++ b/comfy/sampler_helpers.py
@@ -89,7 +89,8 @@ def get_additional_models(conds, dtype):
             gligen += get_models_from_cond(conds[k], "gligen")
             add_models += get_models_from_cond(conds[k], "additional_models")

-    control_nets = set(cnets)
+    # Order-preserving dedup.
A plain set() would randomize iteration order across runs + control_nets = list(dict.fromkeys(cnets)) inference_memory = 0 control_models = [] From 1265955b34dd63622caecb71fdae550fe5cb44fb Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Tue, 5 May 2026 09:40:57 +1000 Subject: [PATCH 054/102] ops: handle multi-compute of the same weight (#13705) If the same weight is used multiple times within the same prefetch window, it should only apply compute state mutations once. Mark the weight as fully resident on the first pass accordingly. --- comfy/ops.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/comfy/ops.py b/comfy/ops.py index 4f0338346..585c185a3 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -253,6 +253,9 @@ def resolve_cast_module_with_vbar(s, dtype, device, bias_dtype, compute_dtype, w if bias is not None: bias = post_cast(s, "bias", bias, bias_dtype, prefetch["resident"], update_weight) + if prefetch["signature"] is not None: + prefetch["resident"] = True + return weight, bias From 15a4494a4e5299e1210ccad6d49d3253555ef3e6 Mon Sep 17 00:00:00 2001 From: Alexis Rolland Date: Tue, 5 May 2026 08:37:25 +0800 Subject: [PATCH 055/102] chore: Update display names and categories (CORE-151) (#13693) * Standardize DEPRECATED label in node display name * Promote category image/video to root level video/ * Update images and masks names and categories --- comfy_api_nodes/nodes_sora.py | 2 +- comfy_extras/nodes_frame_interpolation.py | 2 +- comfy_extras/nodes_image_compare.py | 2 +- comfy_extras/nodes_images.py | 17 ++++++++++----- comfy_extras/nodes_mask.py | 9 ++++++-- comfy_extras/nodes_morphology.py | 6 ++++-- comfy_extras/nodes_post_processing.py | 6 ++++-- comfy_extras/nodes_video.py | 13 ++++++------ nodes.py | 25 ++++++++++++----------- 9 files changed, 50 insertions(+), 32 deletions(-) diff --git a/comfy_api_nodes/nodes_sora.py b/comfy_api_nodes/nodes_sora.py index 4d9075dcf..c1d485188 100644 --- a/comfy_api_nodes/nodes_sora.py +++ b/comfy_api_nodes/nodes_sora.py @@ -33,7 +33,7 @@ class OpenAIVideoSora2(IO.ComfyNode): def define_schema(cls): return IO.Schema( node_id="OpenAIVideoSora2", - display_name="OpenAI Sora - Video (Deprecated)", + display_name="OpenAI Sora - Video (DEPRECATED)", category="api node/video/Sora", description=( "OpenAI video and audio generation.\n\n" diff --git a/comfy_extras/nodes_frame_interpolation.py b/comfy_extras/nodes_frame_interpolation.py index fa49c203a..9dd34cfb8 100644 --- a/comfy_extras/nodes_frame_interpolation.py +++ b/comfy_extras/nodes_frame_interpolation.py @@ -78,7 +78,7 @@ class FrameInterpolate(io.ComfyNode): return io.Schema( node_id="FrameInterpolate", display_name="Frame Interpolate", - category="image/video", + category="video", search_aliases=["rife", "film", "frame interpolation", "slow motion", "interpolate frames", "vfi"], inputs=[ FrameInterpolationModel.Input("interp_model"), diff --git a/comfy_extras/nodes_image_compare.py b/comfy_extras/nodes_image_compare.py index 3d943be67..58af9ae82 100644 --- a/comfy_extras/nodes_image_compare.py +++ b/comfy_extras/nodes_image_compare.py @@ -11,7 +11,7 @@ class ImageCompare(IO.ComfyNode): def define_schema(cls): return IO.Schema( node_id="ImageCompare", - display_name="Image Compare", + display_name="Compare Images", description="Compares two images side by side with a slider.", category="image", essentials_category="Image Tools", diff --git a/comfy_extras/nodes_images.py b/comfy_extras/nodes_images.py index a77f0641f..68916ab75 100644 --- 
a/comfy_extras/nodes_images.py +++ b/comfy_extras/nodes_images.py @@ -24,7 +24,7 @@ class ImageCrop(IO.ComfyNode): return IO.Schema( node_id="ImageCrop", search_aliases=["trim"], - display_name="Image Crop (Deprecated)", + display_name="Crop Image (DEPRECATED)", category="image/transform", is_deprecated=True, essentials_category="Image Tools", @@ -56,7 +56,7 @@ class ImageCropV2(IO.ComfyNode): return IO.Schema( node_id="ImageCropV2", search_aliases=["trim"], - display_name="Image Crop", + display_name="Crop Image", category="image/transform", essentials_category="Image Tools", has_intermediate_output=True, @@ -109,6 +109,7 @@ class RepeatImageBatch(IO.ComfyNode): return IO.Schema( node_id="RepeatImageBatch", search_aliases=["duplicate image", "clone image"], + display_name="Repeat Image Batch", category="image/batch", inputs=[ IO.Image.Input("image"), @@ -131,6 +132,7 @@ class ImageFromBatch(IO.ComfyNode): return IO.Schema( node_id="ImageFromBatch", search_aliases=["select image", "pick from batch", "extract image"], + display_name="Get Image from Batch", category="image/batch", inputs=[ IO.Image.Input("image"), @@ -157,7 +159,8 @@ class ImageAddNoise(IO.ComfyNode): return IO.Schema( node_id="ImageAddNoise", search_aliases=["film grain"], - category="image", + display_name="Add Noise to Image", + category="image/postprocessing", inputs=[ IO.Image.Input("image"), IO.Int.Input( @@ -259,7 +262,7 @@ class ImageStitch(IO.ComfyNode): return IO.Schema( node_id="ImageStitch", search_aliases=["combine images", "join images", "concatenate images", "side by side"], - display_name="Image Stitch", + display_name="Stitch Images", description="Stitches image2 to image1 in the specified direction.\n" "If image2 is not provided, returns image1 unchanged.\n" "Optional spacing can be added between images.", @@ -434,6 +437,7 @@ class ResizeAndPadImage(IO.ComfyNode): return IO.Schema( node_id="ResizeAndPadImage", search_aliases=["fit to size"], + display_name="Resize And Pad Image", category="image/transform", inputs=[ IO.Image.Input("image"), @@ -485,6 +489,7 @@ class SaveSVGNode(IO.ComfyNode): return IO.Schema( node_id="SaveSVGNode", search_aliases=["export vector", "save vector graphics"], + display_name="Save SVG", description="Save SVG files on disk.", category="image/save", inputs=[ @@ -591,7 +596,7 @@ class ImageRotate(IO.ComfyNode): def define_schema(cls): return IO.Schema( node_id="ImageRotate", - display_name="Image Rotate", + display_name="Rotate Image", search_aliases=["turn", "flip orientation"], category="image/transform", essentials_category="Image Tools", @@ -624,6 +629,7 @@ class ImageFlip(IO.ComfyNode): return IO.Schema( node_id="ImageFlip", search_aliases=["mirror", "reflect"], + display_name="Flip Image", category="image/transform", inputs=[ IO.Image.Input("image"), @@ -650,6 +656,7 @@ class ImageScaleToMaxDimension(IO.ComfyNode): def define_schema(cls): return IO.Schema( node_id="ImageScaleToMaxDimension", + display_name="Scale Image to Max Dimension", category="image/upscaling", inputs=[ IO.Image.Input("image"), diff --git a/comfy_extras/nodes_mask.py b/comfy_extras/nodes_mask.py index 8ca947718..43a933dac 100644 --- a/comfy_extras/nodes_mask.py +++ b/comfy_extras/nodes_mask.py @@ -80,7 +80,8 @@ class ImageCompositeMasked(IO.ComfyNode): def define_schema(cls): return IO.Schema( node_id="ImageCompositeMasked", - search_aliases=["paste image", "overlay", "layer"], + search_aliases=["overlay", "layer", "paste image", "images composition"], + display_name="Image Composite Masked", 
category="image", inputs=[ IO.Image.Input("destination"), @@ -201,6 +202,7 @@ class InvertMask(IO.ComfyNode): return IO.Schema( node_id="InvertMask", search_aliases=["reverse mask", "flip mask"], + display_name="Invert Mask", category="mask", inputs=[ IO.Mask.Input("mask"), @@ -222,6 +224,7 @@ class CropMask(IO.ComfyNode): return IO.Schema( node_id="CropMask", search_aliases=["cut mask", "extract mask region", "mask slice"], + display_name="Crop Mask", category="mask", inputs=[ IO.Mask.Input("mask"), @@ -247,7 +250,8 @@ class MaskComposite(IO.ComfyNode): def define_schema(cls): return IO.Schema( node_id="MaskComposite", - search_aliases=["combine masks", "blend masks", "layer masks"], + search_aliases=["combine masks", "blend masks", "layer masks", "masks composition"], + display_name="Combine Masks", category="mask", inputs=[ IO.Mask.Input("destination"), @@ -298,6 +302,7 @@ class FeatherMask(IO.ComfyNode): return IO.Schema( node_id="FeatherMask", search_aliases=["soft edge mask", "blur mask edges", "gradient mask edge"], + display_name="Feather Mask", category="mask", inputs=[ IO.Mask.Input("mask"), diff --git a/comfy_extras/nodes_morphology.py b/comfy_extras/nodes_morphology.py index 4ab2fb7e8..c01b9436d 100644 --- a/comfy_extras/nodes_morphology.py +++ b/comfy_extras/nodes_morphology.py @@ -59,7 +59,8 @@ class ImageRGBToYUV(io.ComfyNode): return io.Schema( node_id="ImageRGBToYUV", search_aliases=["color space conversion"], - category="image/batch", + display_name="Image RGB to YUV", + category="image/color", inputs=[ io.Image.Input("image"), ], @@ -81,7 +82,8 @@ class ImageYUVToRGB(io.ComfyNode): return io.Schema( node_id="ImageYUVToRGB", search_aliases=["color space conversion"], - category="image/batch", + display_name="Image YUV to RGB", + category="image/color", inputs=[ io.Image.Input("Y"), io.Image.Input("U"), diff --git a/comfy_extras/nodes_post_processing.py b/comfy_extras/nodes_post_processing.py index 345fdb695..d938a2035 100644 --- a/comfy_extras/nodes_post_processing.py +++ b/comfy_extras/nodes_post_processing.py @@ -20,7 +20,8 @@ class Blend(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="ImageBlend", - display_name="Image Blend", + search_aliases=["mix images"], + display_name="Blend Images", category="image/postprocessing", essentials_category="Image Tools", inputs=[ @@ -224,6 +225,7 @@ class ImageScaleToTotalPixels(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="ImageScaleToTotalPixels", + display_name="Scale Image to Total Pixels", category="image/upscaling", inputs=[ io.Image.Input("image"), @@ -568,7 +570,7 @@ class BatchImagesNode(io.ComfyNode): return io.Schema( node_id="BatchImagesNode", display_name="Batch Images", - category="image", + category="image/batch", essentials_category="Image Tools", search_aliases=["batch", "image batch", "batch images", "combine images", "merge images", "stack images"], inputs=[ diff --git a/comfy_extras/nodes_video.py b/comfy_extras/nodes_video.py index 5c096c232..719acf2f1 100644 --- a/comfy_extras/nodes_video.py +++ b/comfy_extras/nodes_video.py @@ -17,7 +17,8 @@ class SaveWEBM(io.ComfyNode): return io.Schema( node_id="SaveWEBM", search_aliases=["export webm"], - category="image/video", + display_name="Save WEBM", + category="video", is_experimental=True, inputs=[ io.Image.Input("images"), @@ -72,7 +73,7 @@ class SaveVideo(io.ComfyNode): node_id="SaveVideo", search_aliases=["export video"], display_name="Save Video", - category="image/video", + category="video", essentials_category="Basics", 
description="Saves the input images to your ComfyUI output directory.", inputs=[ @@ -121,7 +122,7 @@ class CreateVideo(io.ComfyNode): node_id="CreateVideo", search_aliases=["images to video"], display_name="Create Video", - category="image/video", + category="video", description="Create a video from images.", inputs=[ io.Image.Input("images", tooltip="The images to create a video from."), @@ -146,7 +147,7 @@ class GetVideoComponents(io.ComfyNode): node_id="GetVideoComponents", search_aliases=["extract frames", "split video", "video to images", "demux"], display_name="Get Video Components", - category="image/video", + category="video", description="Extracts all components from a video: frames, audio, and framerate.", inputs=[ io.Video.Input("video", tooltip="The video to extract components from."), @@ -174,7 +175,7 @@ class LoadVideo(io.ComfyNode): node_id="LoadVideo", search_aliases=["import video", "open video", "video file"], display_name="Load Video", - category="image/video", + category="video", essentials_category="Basics", inputs=[ io.Combo.Input("file", options=sorted(files), upload=io.UploadType.video), @@ -216,7 +217,7 @@ class VideoSlice(io.ComfyNode): "frame load cap", "start time", ], - category="image/video", + category="video", essentials_category="Video Tools", inputs=[ io.Video.Input("video"), diff --git a/nodes.py b/nodes.py index 8f8f90cf6..eebc2fa76 100644 --- a/nodes.py +++ b/nodes.py @@ -1887,7 +1887,7 @@ class ImageInvert: RETURN_TYPES = ("IMAGE",) FUNCTION = "invert" - CATEGORY = "image" + CATEGORY = "image/color" def invert(self, image): s = 1.0 - image @@ -1903,7 +1903,7 @@ class ImageBatch: RETURN_TYPES = ("IMAGE",) FUNCTION = "batch" - CATEGORY = "image" + CATEGORY = "image/batch" DEPRECATED = True def batch(self, image1, image2): @@ -1960,7 +1960,7 @@ class ImagePadForOutpaint: RETURN_TYPES = ("IMAGE", "MASK") FUNCTION = "expand_image" - CATEGORY = "image" + CATEGORY = "image/transform" def expand_image(self, image, left, top, right, bottom, feathering): d1, d2, d3, d4 = image.size() @@ -2103,7 +2103,7 @@ NODE_DISPLAY_NAME_MAPPINGS = { "ConditioningSetArea": "Conditioning (Set Area)", "ConditioningSetAreaPercentage": "Conditioning (Set Area with Percentage)", "ConditioningSetMask": "Conditioning (Set Mask)", - "ControlNetApply": "Apply ControlNet (OLD)", + "ControlNetApply": "Apply ControlNet (DEPRECATED)", "ControlNetApplyAdvanced": "Apply ControlNet", # Latent "VAEEncodeForInpaint": "VAE Encode (for Inpainting)", @@ -2121,6 +2121,7 @@ NODE_DISPLAY_NAME_MAPPINGS = { "LatentFromBatch" : "Latent From Batch", "RepeatLatentBatch": "Repeat Latent Batch", # Image + "EmptyImage": "Empty Image", "SaveImage": "Save Image", "PreviewImage": "Preview Image", "LoadImage": "Load Image", @@ -2128,15 +2129,15 @@ NODE_DISPLAY_NAME_MAPPINGS = { "LoadImageOutput": "Load Image (from Outputs)", "ImageScale": "Upscale Image", "ImageScaleBy": "Upscale Image By", - "ImageInvert": "Invert Image", + "ImageInvert": "Invert Image Colors", "ImagePadForOutpaint": "Pad Image for Outpainting", - "ImageBatch": "Batch Images", - "ImageCrop": "Image Crop", - "ImageStitch": "Image Stitch", - "ImageBlend": "Image Blend", - "ImageBlur": "Image Blur", - "ImageQuantize": "Image Quantize", - "ImageSharpen": "Image Sharpen", + "ImageBatch": "Batch Images (DEPRECATED)", + "ImageCrop": "Crop Image", + "ImageStitch": "Stitch Images", + "ImageBlend": "Blend Images", + "ImageBlur": "Blur Image", + "ImageQuantize": "Quantize Image", + "ImageSharpen": "Sharpen Image", "ImageScaleToTotalPixels": "Scale Image 
to Total Pixels", "GetImageSize": "Get Image Size", # _for_testing From 35819e35a8f55425ffc23a3e901256cc22bad724 Mon Sep 17 00:00:00 2001 From: Matt Miller Date: Mon, 4 May 2026 18:28:21 -0700 Subject: [PATCH 056/102] fix(spec): mark DeviceStats.index and NodeInfo.essentials_category as nullable (#13706) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix(spec): mark DeviceStats.index and NodeInfo.essentials_category as nullable Two fields in openapi.yaml are declared as required/non-nullable but the Python implementation legitimately returns `null` for them, so any client that response-validates against the spec will fail. `DeviceStats.index` (used by GET /api/system_stats): - server.py emits `"index": device.index` unconditionally - For the CPU device (--cpu mode), `torch.device("cpu").index` is `None` - → JSON response includes `"index": null` for CPU devices `NodeInfo.essentials_category` (used by GET /api/object_info): - The V3 schema-based path (comfy_api/latest/_io.py:1654) unconditionally passes `essentials_category=self.essentials_category` into NodeInfoV1 and serializes via dataclasses.asdict(), so the key is always present - Schema's `essentials_category` defaults to `None` for nodes that don't set it in `define_schema` (e.g. the APG node) - → JSON response includes `"essentials_category": null` for those nodes - (The V1 path in server.py uses `hasattr` and so omits the key entirely when not set, but the V3 path is the one that produces nulls) Both fields keep their existing `required` status — they're always present in the response, the value is just nullable. Descriptions expanded to spell out when `null` is expected. * docs(spec): clarify essentials_category presence rules The previous description said "null for nodes that don't set ESSENTIALS_CATEGORY (V1)" — that's wrong. server.py:739-740 uses `hasattr` and OMITS the key when the V1 attribute isn't defined; null only happens if the attribute is explicitly set to None. Spell out all three legal shapes (string / null / absent) and which path produces which. --- openapi.yaml | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/openapi.yaml b/openapi.yaml index 77d0e2318..3b602e0f6 100644 --- a/openapi.yaml +++ b/openapi.yaml @@ -2347,7 +2347,12 @@ components: description: Device type (cuda, mps, cpu, etc.) index: type: number - description: Device index + nullable: true + description: | + Device index within its type (e.g. CUDA ordinal for `cuda:0`, + `cuda:1`). `null` for devices with no index, including the CPU + device returned in `--cpu` mode (PyTorch's `torch.device('cpu').index` + is `None`). vram_total: type: number description: Total VRAM in bytes @@ -2503,7 +2508,18 @@ components: description: Alternative search terms for finding this node essentials_category: type: string - description: Category override used by the essentials pack + nullable: true + description: | + Category override used by the essentials pack. The + `essentials_category` key may be present with a string value, + present and `null`, or absent entirely: + + - V1 nodes: `essentials_category` is **omitted** when the node + class doesn't define an `ESSENTIALS_CATEGORY` attribute, and + **`null`** if the attribute is explicitly set to `None`. + - V3 nodes (`comfy_api.latest.io`): `essentials_category` is + **always present**, and **`null`** for nodes whose `Schema` + doesn't populate it. 
# ------------------------------------------------------------------- # Models From 413e250ccd04d830c3fa27f8b1957885ea0b8e1b Mon Sep 17 00:00:00 2001 From: Matt Miller Date: Mon, 4 May 2026 18:59:48 -0700 Subject: [PATCH 057/102] spec: add workflow_id / workflow_version_id to PromptRequest with x-runtime tag (#13709) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds two optional, nullable UUID fields to PromptRequest for runtimes that wrap workflow execution in a workflow-version entity (the hosted-cloud runtime does this; local ComfyUI does not). Both fields are tagged `x-runtime: [cloud]` to mark them as runtime-specific — local ComfyUI returns `null` (or omits them entirely) and that's correct behavior, not drift. ## Why these fields belong in the OSS spec Hosted-cloud's frontend and backend share `openapi.yaml` as their single source of truth via auto-generated client types. Without the fields declared in the spec, the cloud runtime has to either: 1. Hand-edit a vendored copy of openapi.yaml (drift between vendor and upstream — unsustainable). 2. Maintain a separate cloud-only spec file (forks the contract, defeats the point of a shared OSS spec). Both options have been tried and both produce maintenance pain. The shape that scales is: cloud-only fields live in OSS spec under their intended path, declared nullable, with an explicit `x-runtime` tag so local-only readers can ignore them programmatically and human readers can see what each runtime populates. ## About the `x-runtime` extension This is the first use of `x-runtime` in this spec. Convention: - `x-runtime: [cloud]` — only the hosted-cloud runtime populates the field; local returns null or omits. - `x-runtime: [local]` — only local populates; cloud returns null. - Tag absent — both runtimes populate the field (the default). This is a vendor extension (`x-` prefix) and is ignored by spec validators that don't recognize it, including `kin-openapi`. Local clients reading the spec see two extra optional nullable fields, which is forward-compatible with all existing readers. ## What this does not change - No Python code changes. `PromptRequest` already accepts arbitrary optional fields (`extra_data: additionalProperties: true` on the same schema is a stronger guarantee). The Python server already silently accepts and ignores both fields today. - No required-fields change. Both fields stay outside `required`, so older clients that don't know about them keep validating. - No nullability widening on existing fields. ## Verification - YAML parses (`yaml.safe_load`). - `kin-openapi` `loader.LoadFromFile` accepts the modified spec. - `openapi3filter.ValidateRequest` on a PromptRequest with both fields set to `null`, set to a valid UUID, or omitted — all pass. --- openapi.yaml | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/openapi.yaml b/openapi.yaml index 3b602e0f6..30f85b6ad 100644 --- a/openapi.yaml +++ b/openapi.yaml @@ -1999,6 +1999,26 @@ components: items: type: string description: List of node IDs to execute (partial graph execution) + workflow_id: + type: string + format: uuid + nullable: true + x-runtime: [cloud] + description: | + UUID identifying a hosted-cloud workflow entity to associate with this + job. Local ComfyUI doesn't track workflow entities and returns `null` + (or omits the field). The `x-runtime: [cloud]` extension marks this + as populated only by the hosted-cloud runtime; absence of the tag + means a field is populated by all runtimes. 
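
Since this patch introduces the first `x-runtime` tag, here is a sketch of how a local-only client generator might consume it. The helper below is hypothetical (not part of the patch or of ComfyUI) and assumes PyYAML is available:

import yaml

def strip_cloud_only_fields(spec_path: str) -> dict:
    """Return the parsed spec with properties tagged `x-runtime: [cloud]` removed."""
    with open(spec_path, encoding="utf-8") as f:
        spec = yaml.safe_load(f)
    for schema in spec.get("components", {}).get("schemas", {}).values():
        props = schema.get("properties", {})
        cloud_only = [name for name, prop in props.items()
                      if isinstance(prop, dict) and "cloud" in prop.get("x-runtime", [])]
        for name in cloud_only:
            del props[name]
    return spec

# Run against this spec, strip_cloud_only_fields("openapi.yaml") would drop
# workflow_id and workflow_version_id from PromptRequest and leave the rest intact.
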
+ workflow_version_id: + type: string + format: uuid + nullable: true + x-runtime: [cloud] + description: | + UUID identifying a hosted-cloud workflow version to associate with + this job. Local ComfyUI returns `null` (or omits the field). See + `workflow_id` above for `x-runtime` semantics. PromptResponse: type: object From ae457da84bfaf074d68b73881c565e3dd2b20b98 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Mon, 4 May 2026 19:50:26 -0700 Subject: [PATCH 058/102] feat: add generic --feature-flag CLI arg and --list-feature-flags registry (#13685) --- comfy/cli_args.py | 2 + comfy_api/feature_flags.py | 92 +++++++++++++++++++++++++++++++- main.py | 10 +++- tests-unit/feature_flags_test.py | 85 +++++++++++++++++++++++++++++ 4 files changed, 186 insertions(+), 3 deletions(-) diff --git a/comfy/cli_args.py b/comfy/cli_args.py index d2fde8b67..9dadb0093 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -238,6 +238,8 @@ database_default_path = os.path.abspath( ) parser.add_argument("--database-url", type=str, default=f"sqlite:///{database_default_path}", help="Specify the database URL, e.g. for an in-memory database you can use 'sqlite:///:memory:'.") parser.add_argument("--enable-assets", action="store_true", help="Enable the assets system (API routes, database synchronization, and background scanning).") +parser.add_argument("--feature-flag", type=str, action='append', default=[], metavar="KEY[=VALUE]", help="Set a server feature flag. Use KEY=VALUE to set an explicit value, or bare KEY to set it to true. Can be specified multiple times. Boolean values (true/false) and numbers are auto-converted. Examples: --feature-flag show_signin_button=true or --feature-flag show_signin_button") +parser.add_argument("--list-feature-flags", action="store_true", help="Print the registry of known CLI-settable feature flags as JSON and exit.") if comfy.options.args_parsing: args = parser.parse_args() diff --git a/comfy_api/feature_flags.py b/comfy_api/feature_flags.py index 9f6918315..adb5a3144 100644 --- a/comfy_api/feature_flags.py +++ b/comfy_api/feature_flags.py @@ -5,12 +5,95 @@ This module handles capability negotiation between frontend and backend, allowing graceful protocol evolution while maintaining backward compatibility. """ -from typing import Any +import logging +from typing import Any, TypedDict from comfy.cli_args import args + +class FeatureFlagInfo(TypedDict): + type: str + default: Any + description: str + + +# Registry of known CLI-settable feature flags. +# Launchers can query this via --list-feature-flags to discover valid flags. +CLI_FEATURE_FLAG_REGISTRY: dict[str, FeatureFlagInfo] = { + "show_signin_button": { + "type": "bool", + "default": False, + "description": "Show the sign-in button in the frontend even when not signed in", + }, +} + + +def _coerce_bool(v: str) -> bool: + """Strict bool coercion: only 'true'/'false' (case-insensitive). + + Anything else raises ValueError so the caller can warn and drop the flag, + rather than silently treating typos like 'ture' or 'yes' as False. + """ + lower = v.lower() + if lower == "true": + return True + if lower == "false": + return False + raise ValueError(f"expected 'true' or 'false', got {v!r}") + + +_COERCE_FNS: dict[str, Any] = { + "bool": _coerce_bool, + "int": lambda v: int(v), + "float": lambda v: float(v), +} + + +def _coerce_flag_value(key: str, raw_value: str) -> Any: + """Coerce a raw string value using the registry type, or keep as string. 
+ + Returns the raw string if the key is unregistered or the type is unknown. + Raises ValueError/TypeError if the key is registered with a known type but + the value cannot be coerced; callers are expected to warn and drop the flag. + """ + info = CLI_FEATURE_FLAG_REGISTRY.get(key) + if info is None: + return raw_value + coerce = _COERCE_FNS.get(info["type"]) + if coerce is None: + return raw_value + return coerce(raw_value) + + +def _parse_cli_feature_flags() -> dict[str, Any]: + """Parse --feature-flag key=value pairs from CLI args into a dict. + + Items without '=' default to the value 'true' (bare flag form). + Flags whose value cannot be coerced to the registered type are dropped + with a warning, so a typo like '--feature-flag some_bool=ture' does not + silently take effect as the wrong value. + """ + result: dict[str, Any] = {} + for item in getattr(args, "feature_flag", []): + key, sep, raw_value = item.partition("=") + key = key.strip() + if not key: + continue + if not sep: + raw_value = "true" + try: + result[key] = _coerce_flag_value(key, raw_value.strip()) + except (ValueError, TypeError) as e: + info = CLI_FEATURE_FLAG_REGISTRY.get(key, {}) + logging.warning( + "Could not coerce --feature-flag %s=%r to %s (%s); dropping flag.", + key, raw_value.strip(), info.get("type", "?"), e, + ) + return result + + # Default server capabilities -SERVER_FEATURE_FLAGS: dict[str, Any] = { +_CORE_FEATURE_FLAGS: dict[str, Any] = { "supports_preview_metadata": True, "max_upload_size": args.max_upload_size * 1024 * 1024, # Convert MB to bytes "extension": {"manager": {"supports_v4": True}}, @@ -18,6 +101,11 @@ SERVER_FEATURE_FLAGS: dict[str, Any] = { "assets": args.enable_assets, } +# CLI-provided flags cannot overwrite core flags +_cli_flags = {k: v for k, v in _parse_cli_feature_flags().items() if k not in _CORE_FEATURE_FLAGS} + +SERVER_FEATURE_FLAGS: dict[str, Any] = {**_CORE_FEATURE_FLAGS, **_cli_flags} + def get_connection_feature( sockets_metadata: dict[str, dict[str, Any]], diff --git a/main.py b/main.py index dbaf2745c..a6fdaf43c 100644 --- a/main.py +++ b/main.py @@ -1,13 +1,21 @@ import comfy.options comfy.options.enable_args_parsing() +from comfy.cli_args import args + +if args.list_feature_flags: + import json + from comfy_api.feature_flags import CLI_FEATURE_FLAG_REGISTRY + print(json.dumps(CLI_FEATURE_FLAG_REGISTRY, indent=2)) # noqa: T201 + raise SystemExit(0) + import os import importlib.util import shutil import importlib.metadata import folder_paths import time -from comfy.cli_args import args, enables_dynamic_vram +from comfy.cli_args import enables_dynamic_vram from app.logger import setup_logger setup_logger(log_level=args.verbose, use_stdout=args.log_stdout) diff --git a/tests-unit/feature_flags_test.py b/tests-unit/feature_flags_test.py index f2702cfc8..8ec52a124 100644 --- a/tests-unit/feature_flags_test.py +++ b/tests-unit/feature_flags_test.py @@ -1,10 +1,15 @@ """Tests for feature flags functionality.""" +import pytest + from comfy_api.feature_flags import ( get_connection_feature, supports_feature, get_server_features, + CLI_FEATURE_FLAG_REGISTRY, SERVER_FEATURE_FLAGS, + _coerce_flag_value, + _parse_cli_feature_flags, ) @@ -96,3 +101,83 @@ class TestFeatureFlags: result = get_connection_feature(sockets_metadata, "sid1", "any_feature") assert result is False assert supports_feature(sockets_metadata, "sid1", "any_feature") is False + + +class TestCoerceFlagValue: + """Test suite for _coerce_flag_value.""" + + def test_registered_bool_true(self): + assert 
_coerce_flag_value("show_signin_button", "true") is True + assert _coerce_flag_value("show_signin_button", "True") is True + + def test_registered_bool_false(self): + assert _coerce_flag_value("show_signin_button", "false") is False + assert _coerce_flag_value("show_signin_button", "FALSE") is False + + def test_unregistered_key_stays_string(self): + assert _coerce_flag_value("unknown_flag", "true") == "true" + assert _coerce_flag_value("unknown_flag", "42") == "42" + + def test_bool_typo_raises(self): + """Strict bool: typos like 'ture' or 'yes' must raise so the flag can be dropped.""" + with pytest.raises(ValueError): + _coerce_flag_value("show_signin_button", "ture") + with pytest.raises(ValueError): + _coerce_flag_value("show_signin_button", "yes") + with pytest.raises(ValueError): + _coerce_flag_value("show_signin_button", "1") + with pytest.raises(ValueError): + _coerce_flag_value("show_signin_button", "") + + def test_failed_int_coercion_raises(self, monkeypatch): + """Malformed values for typed flags must raise; caller decides what to do.""" + monkeypatch.setitem( + CLI_FEATURE_FLAG_REGISTRY, + "test_int_flag", + {"type": "int", "default": 0, "description": "test"}, + ) + with pytest.raises(ValueError): + _coerce_flag_value("test_int_flag", "not_a_number") + + +class TestParseCliFeatureFlags: + """Test suite for _parse_cli_feature_flags.""" + + def test_single_flag(self, monkeypatch): + monkeypatch.setattr("comfy_api.feature_flags.args", type("Args", (), {"feature_flag": ["show_signin_button=true"]})()) + result = _parse_cli_feature_flags() + assert result == {"show_signin_button": True} + + def test_missing_equals_defaults_to_true(self, monkeypatch): + """Bare flag without '=' is treated as the string 'true' (and coerced if registered).""" + monkeypatch.setattr("comfy_api.feature_flags.args", type("Args", (), {"feature_flag": ["show_signin_button", "valid=1"]})()) + result = _parse_cli_feature_flags() + assert result == {"show_signin_button": True, "valid": "1"} + + def test_empty_key_skipped(self, monkeypatch): + monkeypatch.setattr("comfy_api.feature_flags.args", type("Args", (), {"feature_flag": ["=value", "valid=1"]})()) + result = _parse_cli_feature_flags() + assert result == {"valid": "1"} + + def test_invalid_bool_value_dropped(self, monkeypatch, caplog): + """A typo'd bool value must be dropped entirely, not silently set to False + and not stored as a raw string. 
A warning must be logged.""" + monkeypatch.setattr( + "comfy_api.feature_flags.args", + type("Args", (), {"feature_flag": ["show_signin_button=ture", "valid=1"]})(), + ) + with caplog.at_level("WARNING"): + result = _parse_cli_feature_flags() + assert result == {"valid": "1"} + assert "show_signin_button" not in result + assert any("show_signin_button" in r.message and "drop" in r.message.lower() for r in caplog.records) + + +class TestCliFeatureFlagRegistry: + """Test suite for the CLI feature flag registry.""" + + def test_registry_entries_have_required_fields(self): + for key, info in CLI_FEATURE_FLAG_REGISTRY.items(): + assert "type" in info, f"{key} missing 'type'" + assert "default" in info, f"{key} missing 'default'" + assert "description" in info, f"{key} missing 'description'" From e758594e3b93a0851018347a59bb0fd35f54205a Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Mon, 4 May 2026 20:17:56 -0700 Subject: [PATCH 059/102] Add deploy environment header (Comfy-Env) to partner node API calls (#13425) --- .gitignore | 1 + comfy/deploy_environment.py | 34 ++++++++ comfy_api_nodes/util/client.py | 3 + tests-unit/deploy_environment_test.py | 109 ++++++++++++++++++++++++++ 4 files changed, 147 insertions(+) create mode 100644 comfy/deploy_environment.py create mode 100644 tests-unit/deploy_environment_test.py diff --git a/.gitignore b/.gitignore index 0ab4ba75e..fc426eda4 100644 --- a/.gitignore +++ b/.gitignore @@ -23,3 +23,4 @@ web_custom_versions/ .DS_Store filtered-openapi.yaml uv.lock +.comfy_environment diff --git a/comfy/deploy_environment.py b/comfy/deploy_environment.py new file mode 100644 index 000000000..8c99a3584 --- /dev/null +++ b/comfy/deploy_environment.py @@ -0,0 +1,34 @@ +import functools +import logging +import os + +logger = logging.getLogger(__name__) + +_DEFAULT_DEPLOY_ENV = "local-git" +_ENV_FILENAME = ".comfy_environment" + +# Resolve the ComfyUI install directory (the parent of this `comfy/` package). +# We deliberately avoid `folder_paths.base_path` here because that is overridden +# by the `--base-directory` CLI arg to a user-supplied path, whereas the +# `.comfy_environment` marker is written by launchers/installers next to the +# ComfyUI install itself. +_COMFY_INSTALL_DIR = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) + + +@functools.cache +def get_deploy_environment() -> str: + env_file = os.path.join(_COMFY_INSTALL_DIR, _ENV_FILENAME) + try: + with open(env_file, encoding="utf-8") as f: + # Cap the read so a malformed or maliciously crafted file (e.g. + # a single huge line with no newline) can't blow up memory. + first_line = f.readline(128).strip() + value = "".join(c for c in first_line if 32 <= ord(c) < 127) + if value: + return value + except FileNotFoundError: + pass + except Exception as e: + logger.error("Failed to read %s: %s", env_file, e) + + return _DEFAULT_DEPLOY_ENV diff --git a/comfy_api_nodes/util/client.py b/comfy_api_nodes/util/client.py index a0b8d35e1..8e1ba91ba 100644 --- a/comfy_api_nodes/util/client.py +++ b/comfy_api_nodes/util/client.py @@ -19,6 +19,8 @@ from comfy import utils from comfy_api.latest import IO from server import PromptServer +from comfy.deploy_environment import get_deploy_environment + from . 
import request_logger from ._helpers import ( default_base_url, @@ -624,6 +626,7 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool): payload_headers = {"Accept": "*/*"} if expect_binary else {"Accept": "application/json"} if not parsed_url.scheme and not parsed_url.netloc: # is URL relative? payload_headers.update(get_auth_header(cfg.node_cls)) + payload_headers["Comfy-Env"] = get_deploy_environment() if cfg.endpoint.headers: payload_headers.update(cfg.endpoint.headers) diff --git a/tests-unit/deploy_environment_test.py b/tests-unit/deploy_environment_test.py new file mode 100644 index 000000000..c3497fbb0 --- /dev/null +++ b/tests-unit/deploy_environment_test.py @@ -0,0 +1,109 @@ +"""Tests for comfy.deploy_environment.""" + +import os + +import pytest + +from comfy import deploy_environment +from comfy.deploy_environment import get_deploy_environment + + +@pytest.fixture(autouse=True) +def _reset_cache_and_install_dir(tmp_path, monkeypatch): + """Reset the functools cache and point the ComfyUI install dir at a tmp dir for each test.""" + get_deploy_environment.cache_clear() + monkeypatch.setattr(deploy_environment, "_COMFY_INSTALL_DIR", str(tmp_path)) + yield + get_deploy_environment.cache_clear() + + +def _write_env_file(tmp_path, content: str) -> str: + """Write the env file with exact content (no newline translation). + + `newline=""` disables Python's text-mode newline translation so the bytes + on disk match the literal string passed in, regardless of host OS. + Newline-style tests (CRLF, lone CR) rely on this. + """ + path = os.path.join(str(tmp_path), ".comfy_environment") + with open(path, "w", encoding="utf-8", newline="") as f: + f.write(content) + return path + + +class TestGetDeployEnvironment: + def test_returns_local_git_when_file_missing(self): + assert get_deploy_environment() == "local-git" + + def test_reads_value_from_file(self, tmp_path): + _write_env_file(tmp_path, "local-desktop2-standalone\n") + assert get_deploy_environment() == "local-desktop2-standalone" + + def test_strips_trailing_whitespace_and_newline(self, tmp_path): + _write_env_file(tmp_path, " local-desktop2-standalone \n") + assert get_deploy_environment() == "local-desktop2-standalone" + + def test_only_first_line_is_used(self, tmp_path): + _write_env_file(tmp_path, "first-line\nsecond-line\n") + assert get_deploy_environment() == "first-line" + + def test_crlf_line_ending(self, tmp_path): + # Windows editors often save text files with CRLF line endings. + # The CR must not end up in the returned value. + _write_env_file(tmp_path, "local-desktop2-standalone\r\n") + assert get_deploy_environment() == "local-desktop2-standalone" + + def test_crlf_multiline_only_first_line_used(self, tmp_path): + _write_env_file(tmp_path, "first-line\r\nsecond-line\r\n") + assert get_deploy_environment() == "first-line" + + def test_crlf_with_surrounding_whitespace(self, tmp_path): + _write_env_file(tmp_path, " local-desktop2-standalone \r\n") + assert get_deploy_environment() == "local-desktop2-standalone" + + def test_lone_cr_line_ending(self, tmp_path): + # Classic-Mac / some legacy editors use a bare CR. + # Universal-newlines decoding treats it as a line terminator too. 
+ _write_env_file(tmp_path, "local-desktop2-standalone\r") + assert get_deploy_environment() == "local-desktop2-standalone" + + def test_empty_file_falls_back_to_default(self, tmp_path): + _write_env_file(tmp_path, "") + assert get_deploy_environment() == "local-git" + + def test_empty_after_whitespace_strip_falls_back_to_default(self, tmp_path): + _write_env_file(tmp_path, " \n") + assert get_deploy_environment() == "local-git" + + def test_strips_control_chars_within_first_line(self, tmp_path): + # Embedded NUL/control chars in the value should be stripped + # (header-injection / smuggling protection). + _write_env_file(tmp_path, "abc\x00\x07xyz\n") + assert get_deploy_environment() == "abcxyz" + + def test_strips_non_ascii_characters(self, tmp_path): + _write_env_file(tmp_path, "café-é\n") + assert get_deploy_environment() == "caf-" + + def test_caps_read_at_128_bytes(self, tmp_path): + # A single huge line with no newline must not be fully read into memory. + huge = "x" * 10_000 + _write_env_file(tmp_path, huge) + result = get_deploy_environment() + assert result == "x" * 128 + + def test_result_is_cached_across_calls(self, tmp_path): + path = _write_env_file(tmp_path, "first_value\n") + assert get_deploy_environment() == "first_value" + # Overwrite the file — cached value should still be returned. + with open(path, "w", encoding="utf-8") as f: + f.write("second_value\n") + assert get_deploy_environment() == "first_value" + + def test_unreadable_file_falls_back_to_default(self, tmp_path, monkeypatch): + _write_env_file(tmp_path, "should_not_be_used\n") + + def _boom(*args, **kwargs): + raise OSError("simulated read failure") + + monkeypatch.setattr("builtins.open", _boom) + assert get_deploy_environment() == "local-git" From 9aef025fb0c67c9982e6ea52e7a866fd23625c5e Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 4 May 2026 20:45:48 -0700 Subject: [PATCH 060/102] Document core release frequency is now ~2 weeks. (#13710) --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index a3bd3ba0a..0fd317d0a 100644 --- a/README.md +++ b/README.md @@ -133,7 +133,7 @@ Workflow examples can be found on the [Examples page](https://comfyanonymous.git ComfyUI follows a weekly release cycle targeting Monday but this regularly changes because of model releases or large changes to the codebase. There are three interconnected repositories: 1. **[ComfyUI Core](https://github.com/comfyanonymous/ComfyUI)** - - Releases a new stable version (e.g., v0.7.0) roughly every week. + - Releases a new major stable version (e.g., v0.7.0) roughly every 2 weeks. - Starting from v0.4.0 patch versions will be used for fixes backported onto the current stable release. - Minor versions will be used for releases off the master branch. - Patch versions may still be used for releases on the master branch in cases where a backport would not make sense. 
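
Tying together the Comfy-Env pieces from PATCH 059 above, a minimal usage sketch; the module and header name come from that patch, while the file path and value are hypothetical examples borrowed from its tests:

# Launcher side: an installer writes the marker next to the ComfyUI install,
# e.g. a .comfy_environment file whose first line is "local-desktop2-standalone".
from comfy.deploy_environment import get_deploy_environment

# Client side: partner-node API requests attach the marker as a header.
headers = {"Accept": "application/json"}
headers["Comfy-Env"] = get_deploy_environment()  # falls back to "local-git" when no marker file exists
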
From fed8d5efa6b70d5b24c4c33cb643bfccc39d45b5 Mon Sep 17 00:00:00 2001 From: Talmaj Date: Tue, 5 May 2026 06:01:22 +0200 Subject: [PATCH 061/102] feat: Auto-regressive video generation (CORE-25) (#13082) --- comfy/k_diffusion/sampling.py | 99 ++++++++++++ comfy/ldm/wan/ar_model.py | 276 +++++++++++++++++++++++++++++++++ comfy/model_base.py | 8 + comfy/supported_models.py | 20 +++ comfy_extras/nodes_ar_video.py | 84 ++++++++++ nodes.py | 1 + 6 files changed, 488 insertions(+) create mode 100644 comfy/ldm/wan/ar_model.py create mode 100644 comfy_extras/nodes_ar_video.py diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py index 6978eb717..d33bc7199 100644 --- a/comfy/k_diffusion/sampling.py +++ b/comfy/k_diffusion/sampling.py @@ -1810,3 +1810,102 @@ def sample_sa_solver(model, x, sigmas, extra_args=None, callback=None, disable=F def sample_sa_solver_pece(model, x, sigmas, extra_args=None, callback=None, disable=False, tau_func=None, s_noise=1.0, noise_sampler=None, predictor_order=3, corrector_order=4, simple_order_2=False): """Stochastic Adams Solver with PECE (Predict–Evaluate–Correct–Evaluate) mode (NeurIPS 2023).""" return sample_sa_solver(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, tau_func=tau_func, s_noise=s_noise, noise_sampler=noise_sampler, predictor_order=predictor_order, corrector_order=corrector_order, use_pece=True, simple_order_2=simple_order_2) + + +@torch.no_grad() +def sample_ar_video(model, x, sigmas, extra_args=None, callback=None, disable=None, + num_frame_per_block=1): + """ + Autoregressive video sampler: block-by-block denoising with KV cache + and flow-match re-noising for Causal Forcing / Self-Forcing models. + + Requires a Causal-WAN compatible model (diffusion_model must expose + init_kv_caches / init_crossattn_caches) and 5-D latents [B,C,T,H,W]. + + All AR-loop parameters are passed via the SamplerARVideo node, not read + from the checkpoint or transformer_options. + """ + extra_args = {} if extra_args is None else extra_args + model_options = extra_args.get("model_options", {}) + transformer_options = model_options.get("transformer_options", {}) + + if x.ndim != 5: + raise ValueError( + f"ar_video sampler requires 5-D video latents [B,C,T,H,W], got {x.ndim}-D tensor with shape {x.shape}. " + "This sampler is only compatible with autoregressive video models (e.g. Causal-WAN)." + ) + + inner_model = model.inner_model.inner_model + causal_model = inner_model.diffusion_model + + if not (hasattr(causal_model, "init_kv_caches") and hasattr(causal_model, "init_crossattn_caches")): + raise TypeError( + "ar_video sampler requires a Causal-WAN compatible model whose diffusion_model " + "exposes init_kv_caches() and init_crossattn_caches(). The loaded checkpoint " + "does not support this interface — choose a different sampler." 
+ ) + + seed = extra_args.get("seed", 0) + + bs, c, lat_t, lat_h, lat_w = x.shape + frame_seq_len = -(-lat_h // 2) * -(-lat_w // 2) # ceiling division + num_blocks = -(-lat_t // num_frame_per_block) # ceiling division + device = x.device + model_dtype = inner_model.get_dtype() + + kv_caches = causal_model.init_kv_caches(bs, lat_t * frame_seq_len, device, model_dtype) + crossattn_caches = causal_model.init_crossattn_caches(bs, device, model_dtype) + + output = torch.zeros_like(x) + s_in = x.new_ones([x.shape[0]]) + current_start_frame = 0 + num_sigma_steps = len(sigmas) - 1 + total_real_steps = num_blocks * num_sigma_steps + step_count = 0 + + try: + for block_idx in trange(num_blocks, disable=disable): + bf = min(num_frame_per_block, lat_t - current_start_frame) + fs, fe = current_start_frame, current_start_frame + bf + noisy_input = x[:, :, fs:fe] + + ar_state = { + "start_frame": current_start_frame, + "kv_caches": kv_caches, + "crossattn_caches": crossattn_caches, + } + transformer_options["ar_state"] = ar_state + + for i in range(num_sigma_steps): + denoised = model(noisy_input, sigmas[i] * s_in, **extra_args) + + if callback is not None: + scaled_i = step_count * num_sigma_steps // total_real_steps + callback({"x": noisy_input, "i": scaled_i, "sigma": sigmas[i], + "sigma_hat": sigmas[i], "denoised": denoised}) + + if sigmas[i + 1] == 0: + noisy_input = denoised + else: + sigma_next = sigmas[i + 1] + torch.manual_seed(seed + block_idx * 1000 + i) + fresh_noise = torch.randn_like(denoised) + noisy_input = (1.0 - sigma_next) * denoised + sigma_next * fresh_noise + + for cache in kv_caches: + cache["end"] -= bf * frame_seq_len + + step_count += 1 + + output[:, :, fs:fe] = noisy_input + + for cache in kv_caches: + cache["end"] -= bf * frame_seq_len + zero_sigma = sigmas.new_zeros([1]) + _ = model(noisy_input, zero_sigma * s_in, **extra_args) + + current_start_frame += bf + finally: + transformer_options.pop("ar_state", None) + + return output diff --git a/comfy/ldm/wan/ar_model.py b/comfy/ldm/wan/ar_model.py new file mode 100644 index 000000000..d72f53602 --- /dev/null +++ b/comfy/ldm/wan/ar_model.py @@ -0,0 +1,276 @@ +""" +CausalWanModel: Wan 2.1 backbone with KV-cached causal self-attention for +autoregressive (frame-by-frame) video generation via Causal Forcing. + +Weight-compatible with the standard WanModel -- same layer names, same shapes. +The difference is purely in the forward pass: this model processes one temporal +block at a time and maintains a KV cache across blocks. 
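+
+Cache bookkeeping (see the sample_ar_video sampler above): each forward call
+appends the block's RoPE-rotated keys and raw values at kv_cache["end"]; the
+sampler rewinds "end" after every intra-block denoising call and finishes each
+block with a zero-sigma pass, so the cache retains only keys/values computed
+from the clean, fully denoised frames.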
+ +Reference: https://github.com/thu-ml/Causal-Forcing +""" + +import torch +import torch.nn as nn + +from comfy.ldm.modules.attention import optimized_attention +from comfy.ldm.flux.math import apply_rope1 +from comfy.ldm.wan.model import ( + sinusoidal_embedding_1d, + repeat_e, + WanModel, + WanAttentionBlock, +) +import comfy.ldm.common_dit +import comfy.model_management + + +class CausalWanSelfAttention(nn.Module): + """Self-attention with KV cache support for autoregressive inference.""" + + def __init__(self, dim, num_heads, window_size=(-1, -1), qk_norm=True, + eps=1e-6, operation_settings={}): + assert dim % num_heads == 0 + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.qk_norm = qk_norm + self.eps = eps + + ops = operation_settings.get("operations") + device = operation_settings.get("device") + dtype = operation_settings.get("dtype") + + self.q = ops.Linear(dim, dim, device=device, dtype=dtype) + self.k = ops.Linear(dim, dim, device=device, dtype=dtype) + self.v = ops.Linear(dim, dim, device=device, dtype=dtype) + self.o = ops.Linear(dim, dim, device=device, dtype=dtype) + self.norm_q = ops.RMSNorm(dim, eps=eps, elementwise_affine=True, device=device, dtype=dtype) if qk_norm else nn.Identity() + self.norm_k = ops.RMSNorm(dim, eps=eps, elementwise_affine=True, device=device, dtype=dtype) if qk_norm else nn.Identity() + + def forward(self, x, freqs, kv_cache=None, transformer_options={}): + b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim + + q = apply_rope1(self.norm_q(self.q(x)).view(b, s, n, d), freqs) + k = apply_rope1(self.norm_k(self.k(x)).view(b, s, n, d), freqs) + v = self.v(x).view(b, s, n, d) + + if kv_cache is None: + x = optimized_attention( + q.view(b, s, n * d), + k.view(b, s, n * d), + v.view(b, s, n * d), + heads=self.num_heads, + transformer_options=transformer_options, + ) + else: + end = kv_cache["end"] + new_end = end + s + + # Roped K and plain V go into cache + kv_cache["k"][:, end:new_end] = k + kv_cache["v"][:, end:new_end] = v + kv_cache["end"] = new_end + + x = optimized_attention( + q.view(b, s, n * d), + kv_cache["k"][:, :new_end].view(b, new_end, n * d), + kv_cache["v"][:, :new_end].view(b, new_end, n * d), + heads=self.num_heads, + transformer_options=transformer_options, + ) + + x = self.o(x) + return x + + +class CausalWanAttentionBlock(WanAttentionBlock): + """Transformer block with KV-cached self-attention and cross-attention caching.""" + + def __init__(self, cross_attn_type, dim, ffn_dim, num_heads, + window_size=(-1, -1), qk_norm=True, cross_attn_norm=False, + eps=1e-6, operation_settings={}): + super().__init__(cross_attn_type, dim, ffn_dim, num_heads, + window_size, qk_norm, cross_attn_norm, eps, + operation_settings=operation_settings) + self.self_attn = CausalWanSelfAttention( + dim, num_heads, window_size, qk_norm, eps, + operation_settings=operation_settings) + + def forward(self, x, e, freqs, context, context_img_len=257, + kv_cache=None, crossattn_cache=None, transformer_options={}): + if e.ndim < 4: + e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device) + e).chunk(6, dim=1) + else: + e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device).unsqueeze(0) + e).unbind(2) + + # Self-attention with optional KV cache + x = x.contiguous() + y = self.self_attn( + torch.addcmul(repeat_e(e[0], x), self.norm1(x), 1 + repeat_e(e[1], x)), + freqs, kv_cache=kv_cache, transformer_options=transformer_options) + x = 
torch.addcmul(x, y, repeat_e(e[2], x)) + del y + + # Cross-attention with optional caching + if crossattn_cache is not None and crossattn_cache.get("is_init"): + q = self.cross_attn.norm_q(self.cross_attn.q(self.norm3(x))) + x_ca = optimized_attention( + q, crossattn_cache["k"], crossattn_cache["v"], + heads=self.num_heads, transformer_options=transformer_options) + x = x + self.cross_attn.o(x_ca) + else: + x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len, transformer_options=transformer_options) + if crossattn_cache is not None: + crossattn_cache["k"] = self.cross_attn.norm_k(self.cross_attn.k(context)) + crossattn_cache["v"] = self.cross_attn.v(context) + crossattn_cache["is_init"] = True + + # FFN + y = self.ffn(torch.addcmul(repeat_e(e[3], x), self.norm2(x), 1 + repeat_e(e[4], x))) + x = torch.addcmul(x, y, repeat_e(e[5], x)) + return x + + +class CausalWanModel(WanModel): + """ + Wan 2.1 diffusion backbone with causal KV-cache support. + + Same weight structure as WanModel -- loads identical state dicts. + Adds forward_block() for frame-by-frame autoregressive inference. + """ + + def __init__(self, + model_type='t2v', + patch_size=(1, 2, 2), + text_len=512, + in_dim=16, + dim=2048, + ffn_dim=8192, + freq_dim=256, + text_dim=4096, + out_dim=16, + num_heads=16, + num_layers=32, + window_size=(-1, -1), + qk_norm=True, + cross_attn_norm=True, + eps=1e-6, + image_model=None, + device=None, + dtype=None, + operations=None): + super().__init__( + model_type=model_type, patch_size=patch_size, text_len=text_len, + in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, + text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, + num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, + cross_attn_norm=cross_attn_norm, eps=eps, image_model=image_model, + wan_attn_block_class=CausalWanAttentionBlock, + device=device, dtype=dtype, operations=operations) + + def forward_block(self, x, timestep, context, start_frame, + kv_caches, crossattn_caches, clip_fea=None): + """ + Forward one temporal block for autoregressive inference. 
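+
+        Called once per denoising step per block; a sketch of the caller side
+        (this mirrors the forward() override at the bottom of this class):
+
+            t = timestep.unsqueeze(1).expand(bs, block_frames)  # per-frame t
+            pred = model.forward_block(x_blk, t, context, start_frame,
+                                       kv_caches, crossattn_caches)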
+ + Args: + x: [B, C, block_frames, H, W] input latent for the current block + timestep: [B, block_frames] per-frame timesteps + context: [B, L, text_dim] raw text embeddings (pre-text_embedding) + start_frame: temporal frame index for RoPE offset + kv_caches: list of per-layer KV cache dicts + crossattn_caches: list of per-layer cross-attention cache dicts + clip_fea: optional CLIP features for I2V + + Returns: + flow_pred: [B, C_out, block_frames, H, W] flow prediction + """ + x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size) + bs, c, t, h, w = x.shape + + x = self.patch_embedding(x.float()).to(x.dtype) + grid_sizes = x.shape[2:] + x = x.flatten(2).transpose(1, 2) + + # Per-frame time embedding + e = self.time_embedding( + sinusoidal_embedding_1d(self.freq_dim, timestep.flatten()).to(dtype=x.dtype)) + e = e.reshape(timestep.shape[0], -1, e.shape[-1]) + e0 = self.time_projection(e).unflatten(2, (6, self.dim)) + + # Text embedding (reuses crossattn_cache after first block) + context = self.text_embedding(context) + + context_img_len = None + if clip_fea is not None and self.img_emb is not None: + context_clip = self.img_emb(clip_fea) + context = torch.concat([context_clip, context], dim=1) + context_img_len = clip_fea.shape[-2] + + # RoPE for current block's temporal position + freqs = self.rope_encode(t, h, w, t_start=start_frame, device=x.device, dtype=x.dtype) + + # Transformer blocks + for i, block in enumerate(self.blocks): + x = block(x, e=e0, freqs=freqs, context=context, + context_img_len=context_img_len, + kv_cache=kv_caches[i], + crossattn_cache=crossattn_caches[i]) + + # Head + x = self.head(x, e) + + # Unpatchify + x = self.unpatchify(x, grid_sizes) + return x[:, :, :t, :h, :w] + + def init_kv_caches(self, batch_size, max_seq_len, device, dtype): + """Create fresh KV caches for all layers.""" + caches = [] + for _ in range(self.num_layers): + caches.append({ + "k": torch.zeros(batch_size, max_seq_len, self.num_heads, self.head_dim, device=device, dtype=dtype), + "v": torch.zeros(batch_size, max_seq_len, self.num_heads, self.head_dim, device=device, dtype=dtype), + "end": 0, + }) + return caches + + def init_crossattn_caches(self, batch_size, device, dtype): + """Create fresh cross-attention caches for all layers.""" + caches = [] + for _ in range(self.num_layers): + caches.append({"is_init": False}) + return caches + + def reset_kv_caches(self, kv_caches): + """Reset KV caches to empty (reuse allocated memory).""" + for cache in kv_caches: + cache["end"] = 0 + + def reset_crossattn_caches(self, crossattn_caches): + """Reset cross-attention caches.""" + for cache in crossattn_caches: + cache["is_init"] = False + + @property + def head_dim(self): + return self.dim // self.num_heads + + def forward(self, x, timestep, context, clip_fea=None, time_dim_concat=None, transformer_options={}, **kwargs): + ar_state = transformer_options.get("ar_state") + if ar_state is not None: + bs = x.shape[0] + block_frames = x.shape[2] + t_per_frame = timestep.unsqueeze(1).expand(bs, block_frames) + return self.forward_block( + x=x, timestep=t_per_frame, context=context, + start_frame=ar_state["start_frame"], + kv_caches=ar_state["kv_caches"], + crossattn_caches=ar_state["crossattn_caches"], + clip_fea=clip_fea, + ) + + return super().forward(x, timestep, context, clip_fea=clip_fea, + time_dim_concat=time_dim_concat, + transformer_options=transformer_options, **kwargs) diff --git a/comfy/model_base.py b/comfy/model_base.py index b61a2aa09..57a1e44d2 100644 --- a/comfy/model_base.py +++ 
b/comfy/model_base.py @@ -42,6 +42,7 @@ import comfy.ldm.cosmos.predict2 import comfy.ldm.lumina.model import comfy.ldm.wan.model import comfy.ldm.wan.model_animate +import comfy.ldm.wan.ar_model import comfy.ldm.hunyuan3d.model import comfy.ldm.hidream.model import comfy.ldm.chroma.model @@ -1365,6 +1366,13 @@ class WAN21(BaseModel): return out +class WAN21_CausalAR(WAN21): + def __init__(self, model_config, model_type=ModelType.FLOW, device=None): + super(WAN21, self).__init__(model_config, model_type, device=device, + unet_model=comfy.ldm.wan.ar_model.CausalWanModel) + self.image_to_video = False + + class WAN21_Vace(WAN21): def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None): super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.VaceWanModel) diff --git a/comfy/supported_models.py b/comfy/supported_models.py index e6c17fb98..dff40461f 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -1167,6 +1167,25 @@ class WAN21_T2V(supported_models_base.BASE): t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}umt5xxl.transformer.".format(pref)) return supported_models_base.ClipTarget(comfy.text_encoders.wan.WanT5Tokenizer, comfy.text_encoders.wan.te(**t5_detect)) +class WAN21_CausalAR_T2V(WAN21_T2V): + unet_config = { + "image_model": "wan2.1", + "model_type": "t2v", + "causal_ar": True, + } + + sampling_settings = { + "shift": 5.0, + } + + def __init__(self, unet_config): + super().__init__(unet_config) + self.unet_config.pop("causal_ar", None) + + def get_model(self, state_dict, prefix="", device=None): + return model_base.WAN21_CausalAR(self, device=device) + + class WAN21_I2V(WAN21_T2V): unet_config = { "image_model": "wan2.1", @@ -1929,6 +1948,7 @@ models = [ ZImage, Lumina2, WAN22_T2V, + WAN21_CausalAR_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, diff --git a/comfy_extras/nodes_ar_video.py b/comfy_extras/nodes_ar_video.py new file mode 100644 index 000000000..09ee886fd --- /dev/null +++ b/comfy_extras/nodes_ar_video.py @@ -0,0 +1,84 @@ +""" +ComfyUI nodes for autoregressive video generation (Causal Forcing, Self-Forcing, etc.). + - EmptyARVideoLatent: create 5D [B, C, T, H, W] video latent tensors + - SamplerARVideo: SAMPLER for the block-by-block autoregressive denoising loop +""" + +import torch +from typing_extensions import override + +import comfy.model_management +import comfy.samplers +from comfy_api.latest import ComfyExtension, io + + +class EmptyARVideoLatent(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="EmptyARVideoLatent", + category="latent/video", + inputs=[ + io.Int.Input("width", default=832, min=16, max=8192, step=16), + io.Int.Input("height", default=480, min=16, max=8192, step=16), + io.Int.Input("length", default=81, min=1, max=1024, step=4), + io.Int.Input("batch_size", default=1, min=1, max=64), + ], + outputs=[ + io.Latent.Output(display_name="LATENT"), + ], + ) + + @classmethod + def execute(cls, width, height, length, batch_size) -> io.NodeOutput: + lat_t = ((length - 1) // 4) + 1 + latent = torch.zeros( + [batch_size, 16, lat_t, height // 8, width // 8], + device=comfy.model_management.intermediate_device(), + ) + return io.NodeOutput({"samples": latent}) + + +class SamplerARVideo(io.ComfyNode): + """Sampler for autoregressive video models (Causal Forcing, Self-Forcing). + + All AR-loop parameters are owned by this node so they live in the workflow. 
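+
+    The returned SAMPLER wraps the "ar_video" sampler (sample_ar_video in
+    comfy/k_diffusion/sampling.py); the widget values are forwarded roughly as:
+
+        comfy.samplers.ksampler("ar_video", {"num_frame_per_block": n})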
+ Add new widgets here as the AR sampler grows new options. + """ + + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SamplerARVideo", + display_name="Sampler AR Video", + category="sampling/custom_sampling/samplers", + inputs=[ + io.Int.Input( + "num_frame_per_block", + default=1, min=1, max=64, + tooltip="Frames per autoregressive block. 1 = framewise, " + "3 = chunkwise. Must match the checkpoint's training mode.", + ), + ], + outputs=[io.Sampler.Output()], + ) + + @classmethod + def execute(cls, num_frame_per_block) -> io.NodeOutput: + extra_options = { + "num_frame_per_block": num_frame_per_block, + } + return io.NodeOutput(comfy.samplers.ksampler("ar_video", extra_options)) + + +class ARVideoExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + EmptyARVideoLatent, + SamplerARVideo, + ] + + +async def comfy_entrypoint() -> ARVideoExtension: + return ARVideoExtension() diff --git a/nodes.py b/nodes.py index eebc2fa76..1e41b2ae0 100644 --- a/nodes.py +++ b/nodes.py @@ -2412,6 +2412,7 @@ async def init_builtin_extra_nodes(): "nodes_nop.py", "nodes_kandinsky5.py", "nodes_wanmove.py", + "nodes_ar_video.py", "nodes_image_compare.py", "nodes_zimage.py", "nodes_glsl.py", From 8d752113007178d6fbdf09ef01473c5233802cf2 Mon Sep 17 00:00:00 2001 From: Alvin Tang Date: Tue, 5 May 2026 20:29:11 +0800 Subject: [PATCH 062/102] fix: SplitImageToTileList and ImageMergeTileList to use tile_height for vertical stride minimum (#12882) --- comfy_extras/nodes_images.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy_extras/nodes_images.py b/comfy_extras/nodes_images.py index 68916ab75..1ac740d1d 100644 --- a/comfy_extras/nodes_images.py +++ b/comfy_extras/nodes_images.py @@ -716,7 +716,7 @@ class SplitImageToTileList(IO.ComfyNode): def get_grid_coords(width, height, tile_width, tile_height, overlap): coords = [] stride_x = round(max(tile_width * 0.25, tile_width - overlap)) - stride_y = round(max(tile_width * 0.25, tile_height - overlap)) + stride_y = round(max(tile_height * 0.25, tile_height - overlap)) y = 0 while y < height: From c55ff8524373940a404a130394fa7078ff64f9cd Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Tue, 5 May 2026 16:49:07 +0300 Subject: [PATCH 063/102] feat(api-nodes): add Luma UNI-1 models (#13614) Signed-off-by: bigcat88 Co-authored-by: Alexis Rolland --- comfy_api_nodes/apis/luma.py | 38 ++++- comfy_api_nodes/nodes_luma.py | 306 +++++++++++++++++++++++++++++++++- 2 files changed, 330 insertions(+), 14 deletions(-) diff --git a/comfy_api_nodes/apis/luma.py b/comfy_api_nodes/apis/luma.py index 632c4ab96..8c6db2022 100644 --- a/comfy_api_nodes/apis/luma.py +++ b/comfy_api_nodes/apis/luma.py @@ -1,15 +1,12 @@ from __future__ import annotations - -import torch - from enum import Enum from typing import Optional, Union +import torch from pydantic import BaseModel, Field, confloat - class LumaIO: LUMA_REF = "LUMA_REF" LUMA_CONCEPTS = "LUMA_CONCEPTS" @@ -183,13 +180,13 @@ class LumaAssets(BaseModel): class LumaImageRef(BaseModel): - '''Used for image gen''' + """Used for image gen""" url: str = Field(..., description='The URL of the image reference') weight: confloat(ge=0.0, le=1.0) = Field(..., description='The weight of the image reference') class LumaImageReference(BaseModel): - '''Used for video gen''' + """Used for video gen""" type: Optional[str] = Field('image', description='Input type, defaults to image') url: str = Field(..., 
description='The URL of the image') @@ -251,3 +248,32 @@ class LumaGeneration(BaseModel): assets: Optional[LumaAssets] = Field(None, description='The assets of the generation') model: str = Field(..., description='The model used for the generation') request: Union[LumaGenerationRequest, LumaImageGenerationRequest] = Field(..., description="The request used for the generation") + + +class Luma2ImageRef(BaseModel): + url: str | None = None + data: str | None = None + media_type: str | None = None + + +class Luma2GenerationRequest(BaseModel): + prompt: str = Field(..., min_length=1, max_length=6000) + model: str | None = None + type: str | None = None + aspect_ratio: str | None = None + style: str | None = None + output_format: str | None = None + web_search: bool | None = None + image_ref: list[Luma2ImageRef] | None = None + source: Luma2ImageRef | None = None + + +class Luma2Generation(BaseModel): + id: str | None = None + type: str | None = None + state: str | None = None + model: str | None = None + created_at: str | None = None + output: list[LumaImageReference] | None = None + failure_reason: str | None = None + failure_code: str | None = None diff --git a/comfy_api_nodes/nodes_luma.py b/comfy_api_nodes/nodes_luma.py index 9ed6cd299..d92a7c382 100644 --- a/comfy_api_nodes/nodes_luma.py +++ b/comfy_api_nodes/nodes_luma.py @@ -1,10 +1,11 @@ -from typing import Optional - import torch from typing_extensions import override -from comfy_api.latest import IO, ComfyExtension +from comfy_api.latest import IO, ComfyExtension, Input from comfy_api_nodes.apis.luma import ( + Luma2Generation, + Luma2GenerationRequest, + Luma2ImageRef, LumaAspectRatio, LumaCharacterRef, LumaConceptChain, @@ -30,6 +31,7 @@ from comfy_api_nodes.util import ( download_url_to_video_output, poll_op, sync_op, + upload_image_to_comfyapi, upload_images_to_comfyapi, validate_string, ) @@ -212,9 +214,9 @@ class LumaImageGenerationNode(IO.ComfyNode): aspect_ratio: str, seed, style_image_weight: float, - image_luma_ref: Optional[LumaReferenceChain] = None, - style_image: Optional[torch.Tensor] = None, - character_image: Optional[torch.Tensor] = None, + image_luma_ref: LumaReferenceChain | None = None, + style_image: torch.Tensor | None = None, + character_image: torch.Tensor | None = None, ) -> IO.NodeOutput: validate_string(prompt, strip_whitespace=True, min_length=3) # handle image_luma_ref @@ -434,7 +436,7 @@ class LumaTextToVideoGenerationNode(IO.ComfyNode): duration: str, loop: bool, seed, - luma_concepts: Optional[LumaConceptChain] = None, + luma_concepts: LumaConceptChain | None = None, ) -> IO.NodeOutput: validate_string(prompt, strip_whitespace=False, min_length=3) duration = duration if model != LumaVideoModel.ray_1_6 else None @@ -533,7 +535,6 @@ class LumaImageToVideoGenerationNode(IO.ComfyNode): ], is_api_node=True, price_badge=PRICE_BADGE_VIDEO, - ) @classmethod @@ -644,6 +645,293 @@ PRICE_BADGE_VIDEO = IO.PriceBadge( ) +def _luma2_uni1_common_inputs(max_image_refs: int) -> list: + return [ + IO.Combo.Input( + "style", + options=["auto", "manga"], + default="auto", + tooltip="Style preset. 
'auto' picks based on the prompt; " + "'manga' applies a manga/anime aesthetic and requires a portrait " + "aspect ratio (2:3, 9:16, 1:2, 1:3).", + ), + IO.Boolean.Input( + "web_search", + default=False, + tooltip="Search the web for visual references before generating.", + ), + IO.Autogrow.Input( + "image_ref", + template=IO.Autogrow.TemplateNames( + IO.Image.Input("image"), + names=[f"image_{i}" for i in range(1, max_image_refs + 1)], + min=0, + ), + optional=True, + tooltip=f"Up to {max_image_refs} reference images for style/content guidance.", + ), + ] + + +async def _luma2_upload_image_refs( + cls: type[IO.ComfyNode], + refs: dict | None, + max_count: int, +) -> list[Luma2ImageRef] | None: + if not refs: + return None + out: list[Luma2ImageRef] = [] + for key in refs: + url = await upload_image_to_comfyapi(cls, refs[key]) + out.append(Luma2ImageRef(url=url)) + if len(out) > max_count: + raise ValueError(f"Maximum {max_count} reference images are allowed.") + return out or None + + +async def _luma2_submit_and_poll( + cls: type[IO.ComfyNode], + request: Luma2GenerationRequest, +) -> Input.Image: + initial = await sync_op( + cls, + ApiEndpoint(path="/proxy/luma_2/generations", method="POST"), + response_model=Luma2Generation, + data=request, + ) + if not initial.id: + raise RuntimeError("Luma 2 API did not return a generation id.") + final = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/luma_2/generations/{initial.id}", method="GET"), + response_model=Luma2Generation, + status_extractor=lambda r: r.state, + progress_extractor=lambda r: None, + ) + if not final.output: + msg = final.failure_reason or "no output returned" + raise RuntimeError(f"Luma 2 generation failed: {msg}") + url = final.output[0].url + if not url: + raise RuntimeError("Luma 2 generation completed without an output URL.") + return await download_url_to_image_tensor(url) + + +class LumaImageNode(IO.ComfyNode): + + @classmethod + def define_schema(cls) -> IO.Schema: + return IO.Schema( + node_id="LumaImageNode2", + display_name="Luma UNI-1 Image", + category="api node/image/Luma", + description="Generate images from text using the Luma UNI-1 model.", + inputs=[ + IO.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Text description of the desired image. 1–6000 characters.", + ), + IO.DynamicCombo.Input( + "model", + options=[ + IO.DynamicCombo.Option( + "uni-1", + [ + IO.Combo.Input( + "aspect_ratio", + options=[ + "auto", + "3:1", + "2:1", + "16:9", + "3:2", + "1:1", + "2:3", + "9:16", + "1:2", + "1:3", + ], + default="auto", + tooltip="Output image aspect ratio. 'auto' lets " + "the model pick based on the prompt.", + ), + *_luma2_uni1_common_inputs(max_image_refs=9), + ], + ), + IO.DynamicCombo.Option( + "uni-1-max", + [ + IO.Combo.Input( + "aspect_ratio", + options=[ + "auto", + "3:1", + "2:1", + "16:9", + "3:2", + "1:1", + "2:3", + "9:16", + "1:2", + "1:3", + ], + default="auto", + tooltip="Output image aspect ratio. 
'auto' lets " + "the model pick based on the prompt.", + ), + *_luma2_uni1_common_inputs(max_image_refs=9), + ], + ), + ], + tooltip="Model to use for generation.", + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + control_after_generate=True, + tooltip="Seed controls whether the node should re-run; " + "results are non-deterministic regardless of seed.", + ), + ], + outputs=[IO.Image.Output()], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["model"], input_groups=["model.image_ref"]), + expr=""" + ( + $m := widgets.model; + $refs := $lookup(inputGroups, "model.image_ref"); + $base := $m = "uni-1-max" ? 0.1 : 0.0404; + {"type":"usd","usd": $round($base + 0.003 * $refs, 4)} + ) + """, + ), + ) + + @classmethod + async def execute( + cls, + prompt: str, + model: dict, + seed: int, + ) -> IO.NodeOutput: + validate_string(prompt, min_length=1, max_length=6000) + aspect_ratio = model["aspect_ratio"] + style = model["style"] + allowed_manga_ratios = {"2:3", "9:16", "1:2", "1:3"} + if style == "manga" and aspect_ratio != "auto" and aspect_ratio not in allowed_manga_ratios: + raise ValueError( + f"'manga' style requires a portrait aspect ratio " + f"({', '.join(sorted(allowed_manga_ratios))}) or 'auto'; got '{aspect_ratio}'." + ) + request = Luma2GenerationRequest( + prompt=prompt, + model=model["model"], + type="image", + aspect_ratio=aspect_ratio if aspect_ratio != "auto" else None, + style=style if style != "auto" else None, + output_format="png", + web_search=model["web_search"], + image_ref=await _luma2_upload_image_refs(cls, model.get("image_ref"), max_count=9), + ) + return IO.NodeOutput(await _luma2_submit_and_poll(cls, request)) + + +class LumaImageEditNode(IO.ComfyNode): + + @classmethod + def define_schema(cls) -> IO.Schema: + return IO.Schema( + node_id="LumaImageEditNode2", + display_name="Luma UNI-1 Image Edit", + category="api node/image/Luma", + description="Edit an existing image with a text prompt using the Luma UNI-1 model.", + inputs=[ + IO.Image.Input( + "source", + tooltip="Source image to edit.", + ), + IO.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Description of the desired edit. 1–6000 characters.", + ), + IO.DynamicCombo.Input( + "model", + options=[ + IO.DynamicCombo.Option( + "uni-1", + _luma2_uni1_common_inputs(max_image_refs=8), + ), + IO.DynamicCombo.Option( + "uni-1-max", + _luma2_uni1_common_inputs(max_image_refs=8), + ), + ], + tooltip="Model to use for editing.", + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + control_after_generate=True, + tooltip="Seed controls whether the node should re-run; " + "results are non-deterministic regardless of seed.", + ), + ], + outputs=[IO.Image.Output()], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["model"], input_groups=["model.image_ref"]), + expr=""" + ( + $m := widgets.model; + $refs := $lookup(inputGroups, "model.image_ref"); + $base := $m = "uni-1-max" ? 
0.103 : 0.0434; + {"type":"usd","usd": $round($base + 0.003 * $refs, 4)} + ) + """, + ), + ) + + @classmethod + async def execute( + cls, + source: Input.Image, + prompt: str, + model: dict, + seed: int, + ) -> IO.NodeOutput: + validate_string(prompt, min_length=1, max_length=6000) + request = Luma2GenerationRequest( + prompt=prompt, + model=model["model"], + type="image_edit", + source=Luma2ImageRef(url=await upload_image_to_comfyapi(cls, source)), + style=model["style"] if model["style"] != "auto" else None, + output_format="png", + web_search=model["web_search"], + image_ref=await _luma2_upload_image_refs(cls, model.get("image_ref"), max_count=8), + ) + return IO.NodeOutput(await _luma2_submit_and_poll(cls, request)) + + class LumaExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[IO.ComfyNode]]: @@ -654,6 +942,8 @@ class LumaExtension(ComfyExtension): LumaImageToVideoGenerationNode, LumaReferenceNode, LumaConceptsNode, + LumaImageNode, + LumaImageEditNode, ] From 6917bce1281232a83c079a38540fca30d0fc279e Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Tue, 5 May 2026 16:53:19 +0300 Subject: [PATCH 064/102] [Partner Nodes] add Gpt 5.5 and 5.5-pro LLM models (#13673) * feat(api-nodes): add Gpt 5.5 and 5.5-pro LLM models Signed-off-by: bigcat88 --- comfy_api_nodes/apis/openai.py | 6 +++--- comfy_api_nodes/nodes_openai.py | 26 +++++++++++++++++++------- 2 files changed, 22 insertions(+), 10 deletions(-) diff --git a/comfy_api_nodes/apis/openai.py b/comfy_api_nodes/apis/openai.py index b85ef252b..bee75d639 100644 --- a/comfy_api_nodes/apis/openai.py +++ b/comfy_api_nodes/apis/openai.py @@ -56,14 +56,14 @@ class ModelResponseProperties(BaseModel): instructions: str | None = Field(None) max_output_tokens: int | None = Field(None) model: str | None = Field(None) - temperature: float | None = Field(1, description="Controls randomness in the response", ge=0.0, le=2.0) + temperature: float | None = Field(None, description="Controls randomness in the response", ge=0.0, le=2.0) top_p: float | None = Field( - 1, + None, description="Controls diversity of the response via nucleus sampling", ge=0.0, le=1.0, ) - truncation: str | None = Field("disabled", description="Allowed values: 'auto' or 'disabled'") + truncation: str | None = Field(None, description="Allowed values: 'auto' or 'disabled'") class ResponseProperties(BaseModel): diff --git a/comfy_api_nodes/nodes_openai.py b/comfy_api_nodes/nodes_openai.py index 21fe470ce..daed495da 100644 --- a/comfy_api_nodes/nodes_openai.py +++ b/comfy_api_nodes/nodes_openai.py @@ -39,16 +39,18 @@ STARTING_POINT_ID_PATTERN = r"" class SupportedOpenAIModel(str, Enum): - o4_mini = "o4-mini" - o1 = "o1" - o3 = "o3" - o1_pro = "o1-pro" - gpt_4_1 = "gpt-4.1" - gpt_4_1_mini = "gpt-4.1-mini" - gpt_4_1_nano = "gpt-4.1-nano" + gpt_5_5_pro = "gpt-5.5-pro" + gpt_5_5 = "gpt-5.5" gpt_5 = "gpt-5" gpt_5_mini = "gpt-5-mini" gpt_5_nano = "gpt-5-nano" + gpt_4_1 = "gpt-4.1" + gpt_4_1_mini = "gpt-4.1-mini" + gpt_4_1_nano = "gpt-4.1-nano" + o4_mini = "o4-mini" + o3 = "o3" + o1_pro = "o1-pro" + o1 = "o1" async def validate_and_cast_response(response, timeout: int = None) -> torch.Tensor: @@ -739,6 +741,16 @@ class OpenAIChatNode(IO.ComfyNode): "usd": [0.002, 0.008], "format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" } } + : $contains($m, "gpt-5.5-pro") ? 
{ + "type": "list_usd", + "usd": [0.03, 0.18], + "format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" } + } + : $contains($m, "gpt-5.5") ? { + "type": "list_usd", + "usd": [0.005, 0.03], + "format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" } + } : $contains($m, "gpt-5-nano") ? { "type": "list_usd", "usd": [0.00005, 0.0004], From d794b62939ed82c88160d569854c41a42186bd9a Mon Sep 17 00:00:00 2001 From: "Daxiong (Lin)" Date: Tue, 5 May 2026 22:57:27 +0900 Subject: [PATCH 065/102] Update workflow templates to v0.9.69 (#13714) * chore: update workflow templates to v0.9.69 * Update comfyui-workflow-templates to version 0.9.70 * Downgrade comfyui-workflow-templates to 0.9.69 --------- Co-authored-by: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 32826e25a..e9415f2fd 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.42.15 -comfyui-workflow-templates==0.9.68 +comfyui-workflow-templates==0.9.69 comfyui-embedded-docs==0.4.4 torch torchsde From 639f631a0848f27497c6a29d2fb7d06c921c744d Mon Sep 17 00:00:00 2001 From: Alexis Rolland Date: Tue, 5 May 2026 22:31:24 +0800 Subject: [PATCH 066/102] chore: Update display names and categories for text nodes (CORE-155) (#13712) --- comfy_extras/nodes_primitive.py | 6 ++- comfy_extras/nodes_string.py | 66 ++++++++++++++++----------------- 2 files changed, 37 insertions(+), 35 deletions(-) diff --git a/comfy_extras/nodes_primitive.py b/comfy_extras/nodes_primitive.py index 3c8f90b19..33373266b 100644 --- a/comfy_extras/nodes_primitive.py +++ b/comfy_extras/nodes_primitive.py @@ -9,7 +9,8 @@ class String(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="PrimitiveString", - display_name="String", + search_aliases=["text", "string", "text box", "prompt"], + display_name="Text String", category="utils/primitive", inputs=[ io.String.Input("value"), @@ -27,7 +28,8 @@ class StringMultiline(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="PrimitiveStringMultiline", - display_name="String (Multiline)", + search_aliases=["text", "string", "text multiline", "string multiline", "text box", "prompt"], + display_name="Text String (Multiline)", category="utils/primitive", essentials_category="Basics", inputs=[ diff --git a/comfy_extras/nodes_string.py b/comfy_extras/nodes_string.py index 604076c4e..925a40da8 100644 --- a/comfy_extras/nodes_string.py +++ b/comfy_extras/nodes_string.py @@ -10,9 +10,9 @@ class StringConcatenate(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="StringConcatenate", - display_name="Text Concatenate", - category="utils/string", - search_aliases=["Concatenate", "text concat", "join text", "merge text", "combine strings", "concat", "concatenate", "append text", "combine text", "string"], + search_aliases=["concatenate", "text concat", "join text", "merge text", "combine strings", "string concat", "append text", "combine text"], + display_name="Concatenate Text", + category="text", inputs=[ io.String.Input("string_a", multiline=True), io.String.Input("string_b", multiline=True), @@ -33,9 +33,9 @@ class StringSubstring(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="StringSubstring", - search_aliases=["Substring", "extract text", "text portion"], - display_name="Text Substring", - category="utils/string", + search_aliases=["substring", "extract text", 
"text portion"], + display_name="Substring", + category="text", inputs=[ io.String.Input("string", multiline=True), io.Int.Input("start"), @@ -58,7 +58,7 @@ class StringLength(io.ComfyNode): node_id="StringLength", search_aliases=["character count", "text size", "string length"], display_name="Text Length", - category="utils/string", + category="text", inputs=[ io.String.Input("string", multiline=True), ], @@ -77,9 +77,9 @@ class CaseConverter(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="CaseConverter", - search_aliases=["Case Converter", "text case", "uppercase", "lowercase", "capitalize"], - display_name="Text Case Converter", - category="utils/string", + search_aliases=["case converter", "text case", "uppercase", "lowercase", "capitalize"], + display_name="Convert Text Case", + category="text", inputs=[ io.String.Input("string", multiline=True), io.Combo.Input("mode", options=["UPPERCASE", "lowercase", "Capitalize", "Title Case"]), @@ -110,9 +110,9 @@ class StringTrim(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="StringTrim", - search_aliases=["Trim", "clean whitespace", "remove whitespace", "strip"], - display_name="Text Trim", - category="utils/string", + search_aliases=["trim", "clean whitespace", "remove whitespace", "remove spaces","strip"], + display_name="Trim Text", + category="text", inputs=[ io.String.Input("string", multiline=True), io.Combo.Input("mode", options=["Both", "Left", "Right"]), @@ -141,9 +141,9 @@ class StringReplace(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="StringReplace", - search_aliases=["Replace", "find and replace", "substitute", "swap text"], - display_name="Text Replace", - category="utils/string", + search_aliases=["replace", "find and replace", "substitute", "swap text"], + display_name="Replace Text", + category="text", inputs=[ io.String.Input("string", multiline=True), io.String.Input("find", multiline=True), @@ -164,9 +164,9 @@ class StringContains(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="StringContains", - search_aliases=["Contains", "text includes", "string includes"], - display_name="Text Contains", - category="utils/string", + search_aliases=["contains", "text includes", "string includes"], + display_name="Contains Text", + category="text", inputs=[ io.String.Input("string", multiline=True), io.String.Input("substring", multiline=True), @@ -192,9 +192,9 @@ class StringCompare(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="StringCompare", - search_aliases=["Compare", "text match", "string equals", "starts with", "ends with"], - display_name="Text Compare", - category="utils/string", + search_aliases=["compare", "text match", "string equals", "starts with", "ends with"], + display_name="Compare Text", + category="text", inputs=[ io.String.Input("string_a", multiline=True), io.String.Input("string_b", multiline=True), @@ -228,9 +228,9 @@ class RegexMatch(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="RegexMatch", - search_aliases=["Regex Match", "regex", "pattern match", "text contains", "string match"], - display_name="Text Match", - category="utils/string", + search_aliases=["regex match", "regex", "pattern match", "text contains", "string match"], + display_name="Match Text", + category="text", inputs=[ io.String.Input("string", multiline=True), io.String.Input("regex_pattern", multiline=True), @@ -269,9 +269,9 @@ class RegexExtract(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="RegexExtract", - 
search_aliases=["Regex Extract", "regex", "pattern extract", "text parser", "parse text"], - display_name="Text Extract Substring", - category="utils/string", + search_aliases=["regex extract", "regex", "pattern extract", "text parser", "parse text"], + display_name="Extract Text", + category="text", inputs=[ io.String.Input("string", multiline=True), io.String.Input("regex_pattern", multiline=True), @@ -344,9 +344,9 @@ class RegexReplace(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="RegexReplace", - search_aliases=["Regex Replace", "regex", "pattern replace", "regex replace", "substitution"], - display_name="Text Replace (Regex)", - category="utils/string", + search_aliases=["regex replace", "regex", "pattern replace", "substitution"], + display_name="Replace Text (Regex)", + category="text", description="Find and replace text using regex patterns.", inputs=[ io.String.Input("string", multiline=True), @@ -381,8 +381,8 @@ class JsonExtractString(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="JsonExtractString", - display_name="Extract String from JSON", - category="utils/string", + display_name="Extract Text from JSON", + category="text", search_aliases=["json", "extract json", "parse json", "json value", "read json"], inputs=[ io.String.Input("json_string", multiline=True), From ea6880b04b88629b9dd07774298bdffea6923f9b Mon Sep 17 00:00:00 2001 From: THE MACHINE Date: Wed, 6 May 2026 02:00:03 +0800 Subject: [PATCH 067/102] Fix Content-Disposition header missing 'attachment;' prefix (#13093) Add missing 'attachment;' directive to Content-Disposition headers in server.py to ensure browsers properly download files instead of attempting to display them inline. Fixes 4 instances in the file download endpoint. Co-authored-by: guill --- server.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/server.py b/server.py index 2f3b438bb..0e85635d3 100644 --- a/server.py +++ b/server.py @@ -560,7 +560,7 @@ class PromptServer(): buffer.seek(0) return web.Response(body=buffer.read(), content_type=f'image/{image_format}', - headers={"Content-Disposition": f"filename=\"{filename}\""}) + headers={"Content-Disposition": f"attachment; filename=\"{filename}\""}) if 'channel' not in request.rel_url.query: channel = 'rgba' @@ -580,7 +580,7 @@ class PromptServer(): buffer.seek(0) return web.Response(body=buffer.read(), content_type='image/png', - headers={"Content-Disposition": f"filename=\"{filename}\""}) + headers={"Content-Disposition": f"attachment; filename=\"{filename}\""}) elif channel == 'a': with Image.open(file) as img: @@ -597,7 +597,7 @@ class PromptServer(): alpha_buffer.seek(0) return web.Response(body=alpha_buffer.read(), content_type='image/png', - headers={"Content-Disposition": f"filename=\"{filename}\""}) + headers={"Content-Disposition": f"attachment; filename=\"{filename}\""}) else: # Use the content type from asset resolution if available, # otherwise guess from the filename. 
@@ -614,7 +614,7 @@ class PromptServer(): return web.FileResponse( file, headers={ - "Content-Disposition": f"filename=\"{filename}\"", + "Content-Disposition": f"attachment; filename=\"{filename}\"", "Content-Type": content_type } ) From 41d73ad18094ddf9c91e40f548a52d013d07e894 Mon Sep 17 00:00:00 2001 From: drozbay <17261091+drozbay@users.noreply.github.com> Date: Tue, 5 May 2026 12:33:16 -0600 Subject: [PATCH 068/102] fix(audio): drop sample_rate key from LTXVEmptyLatentAudio (CORE-157) (#13716) --- comfy_extras/nodes_lt_audio.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/comfy_extras/nodes_lt_audio.py b/comfy_extras/nodes_lt_audio.py index 3ec635c75..2c1f63afb 100644 --- a/comfy_extras/nodes_lt_audio.py +++ b/comfy_extras/nodes_lt_audio.py @@ -147,7 +147,6 @@ class LTXVEmptyLatentAudio(io.ComfyNode): z_channels = audio_vae.latent_channels audio_freq = audio_vae.first_stage_model.latent_frequency_bins - sampling_rate = int(audio_vae.first_stage_model.sample_rate) num_audio_latents = audio_vae.first_stage_model.num_of_latents_from_frames(frames_number, frame_rate) @@ -159,7 +158,6 @@ class LTXVEmptyLatentAudio(io.ComfyNode): return io.NodeOutput( { "samples": audio_latents, - "sample_rate": sampling_rate, "type": "audio", } ) From 1ac60da2c9c8f83654204b2a1db13908cf7614f7 Mon Sep 17 00:00:00 2001 From: Matt Miller Date: Tue, 5 May 2026 13:21:36 -0700 Subject: [PATCH 069/102] Add Spectral lint CI gate for openapi.yaml (#13410) * Add Spectral lint CI gate for openapi.yaml Adds a blocking Spectral lint check that runs on PRs touching openapi.yaml or the ruleset itself. The ruleset mirrors the one used for other Comfy-Org service specs: spectral:oas plus conventions for snake_case properties, camelCase operationIds, and response/schema shape. Gate runs at --fail-severity=error, which the spec currently passes with zero errors (a small number of non-blocking warnings/hints remain for WebSocket 101 responses, the existing loose error schema, and two snake_case wire fields). * ci: set least-privilege contents:read permissions on openapi-lint workflow Per CodeRabbit review on #13410. The job only checks out the repo and runs Spectral, so contents:read is sufficient and avoids inheriting any permissive repo/org default token scope. 
--------- Co-authored-by: guill --- .github/workflows/openapi-lint.yml | 31 ++++++++++ .spectral.yaml | 91 ++++++++++++++++++++++++++++++ 2 files changed, 122 insertions(+) create mode 100644 .github/workflows/openapi-lint.yml create mode 100644 .spectral.yaml diff --git a/.github/workflows/openapi-lint.yml b/.github/workflows/openapi-lint.yml new file mode 100644 index 000000000..be949de2a --- /dev/null +++ b/.github/workflows/openapi-lint.yml @@ -0,0 +1,31 @@ +name: OpenAPI Lint + +on: + pull_request: + paths: + - 'openapi.yaml' + - '.spectral.yaml' + - '.github/workflows/openapi-lint.yml' + +permissions: + contents: read + +jobs: + spectral: + name: Run Spectral + runs-on: ubuntu-latest + + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Set up Node.js + uses: actions/setup-node@v4 + with: + node-version: '20' + + - name: Install Spectral + run: npm install -g @stoplight/spectral-cli@6 + + - name: Lint openapi.yaml + run: spectral lint openapi.yaml --ruleset .spectral.yaml --fail-severity=error diff --git a/.spectral.yaml b/.spectral.yaml new file mode 100644 index 000000000..4bb4a4a94 --- /dev/null +++ b/.spectral.yaml @@ -0,0 +1,91 @@ +extends: + - spectral:oas + +# Severity levels: error, warn, info, hint, off +# Rules from the built-in "spectral:oas" ruleset are active by default. +# Below we tune severity and add custom rules for our conventions. +# +# This ruleset mirrors Comfy-Org/cloud/.spectral.yaml so specs across the +# organization are linted against a single consistent standard. + +rules: + # ----------------------------------------------------------------------- + # Built-in rule severity overrides + # ----------------------------------------------------------------------- + operation-operationId: error + operation-description: warn + operation-tag-defined: error + info-contact: off + info-description: warn + no-eval-in-markdown: error + no-$ref-siblings: error + + # ----------------------------------------------------------------------- + # Custom rules: naming conventions + # ----------------------------------------------------------------------- + + # Property names should be snake_case + property-name-snake-case: + description: Property names must be snake_case + severity: warn + given: "$.components.schemas.*.properties[*]~" + then: + function: pattern + functionOptions: + match: "^[a-z][a-z0-9]*(_[a-z0-9]+)*$" + + # Operation IDs should be camelCase + operation-id-camel-case: + description: Operation IDs must be camelCase + severity: warn + given: "$.paths.*.*.operationId" + then: + function: pattern + functionOptions: + match: "^[a-z][a-zA-Z0-9]*$" + + # ----------------------------------------------------------------------- + # Custom rules: response conventions + # ----------------------------------------------------------------------- + + # Error responses (4xx, 5xx) should use a consistent shape + error-response-schema: + description: Error responses should reference a standard error schema + severity: hint + given: "$.paths.*.*.responses[?(@property >= '400' && @property < '600')].content['application/json'].schema" + then: + field: "$ref" + function: truthy + + # All 2xx responses with JSON body should have a schema + response-schema-defined: + description: Success responses with JSON content should define a schema + severity: warn + given: "$.paths.*.*.responses[?(@property >= '200' && @property < '300')].content['application/json']" + then: + field: schema + function: truthy + + # 
----------------------------------------------------------------------- + # Custom rules: best practices + # ----------------------------------------------------------------------- + + # Path parameters must have a description + path-param-description: + description: Path parameters should have a description + severity: warn + given: + - "$.paths.*.parameters[?(@.in == 'path')]" + - "$.paths.*.*.parameters[?(@.in == 'path')]" + then: + field: description + function: truthy + + # Schemas should have a description + schema-description: + description: Component schemas should have a description + severity: hint + given: "$.components.schemas.*" + then: + field: description + function: truthy From 431fadb520bbd2d18cbbd4067e06222301f1b4fe Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Tue, 5 May 2026 13:58:32 -0700 Subject: [PATCH 070/102] fix(api-io): serialize MultiCombo multi_select as object config (#13484) * fix(api-io): serialize MultiCombo multi_select as object config * fix: remove dead code and redundant top-level keys from MultiCombo serialization * fix: correct skip warning to mention comfy_entrypoint, remove nonexistent NODES_LIST * fix: validate MultiCombo list values against options individually * fix: gate multiselect validation on schema config, improve error message, add tests --------- Co-authored-by: Ni-zav Co-authored-by: guill --- comfy_api/latest/_io.py | 13 ++-- execution.py | 9 ++- nodes.py | 2 +- .../multicombo_serialization_test.py | 78 +++++++++++++++++++ 4 files changed, 93 insertions(+), 9 deletions(-) create mode 100644 tests-unit/comfy_api_test/multicombo_serialization_test.py diff --git a/comfy_api/latest/_io.py b/comfy_api/latest/_io.py index 4942ed46c..e50266bc5 100644 --- a/comfy_api/latest/_io.py +++ b/comfy_api/latest/_io.py @@ -395,7 +395,6 @@ class Combo(ComfyTypeIO): @comfytype(io_type="COMBO") class MultiCombo(ComfyTypeI): '''Multiselect Combo input (dropdown for selecting potentially more than one value).''' - # TODO: something is wrong with the serialization, frontend does not recognize it as multiselect Type = list[str] class Input(Combo.Input): def __init__(self, id: str, options: list[str], display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, @@ -408,12 +407,14 @@ class MultiCombo(ComfyTypeI): self.default: list[str] def as_dict(self): - to_return = super().as_dict() | prune_dict({ - "multi_select": self.multiselect, - "placeholder": self.placeholder, - "chip": self.chip, + # Frontend expects `multi_select` to be an object config (not a boolean). + # Keep top-level `multiselect` from Combo.Input for backwards compatibility. 
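+            # e.g. with placeholder="Select providers" and chip=True this yields
+            # (per the unit tests added below; other keys elided):
+            #   {"multiselect": True,
+            #    "multi_select": {"placeholder": "Select providers", "chip": True}}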
+ return super().as_dict() | prune_dict({ + "multi_select": prune_dict({ + "placeholder": self.placeholder, + "chip": self.chip, + }), }) - return to_return @comfytype(io_type="IMAGE") class Image(ComfyTypeIO): diff --git a/execution.py b/execution.py index 654db8426..f37d0360d 100644 --- a/execution.py +++ b/execution.py @@ -1019,7 +1019,12 @@ async def validate_inputs(prompt_id, prompt, item, validated, visiting=None): combo_options = extra_info.get("options", []) else: combo_options = input_type - if val not in combo_options: + is_multiselect = extra_info.get("multiselect", False) + if is_multiselect and isinstance(val, list): + invalid_vals = [v for v in val if v not in combo_options] + else: + invalid_vals = [val] if val not in combo_options else [] + if invalid_vals: input_config = info list_info = "" @@ -1034,7 +1039,7 @@ async def validate_inputs(prompt_id, prompt, item, validated, visiting=None): error = { "type": "value_not_in_list", "message": "Value not in list", - "details": f"{x}: '{val}' not in {list_info}", + "details": f"{x}: {', '.join(repr(v) for v in invalid_vals)} not in {list_info}", "extra_info": { "input_name": x, "input_config": input_config, diff --git a/nodes.py b/nodes.py index 1e41b2ae0..cf61d9df0 100644 --- a/nodes.py +++ b/nodes.py @@ -2262,7 +2262,7 @@ async def load_custom_node(module_path: str, ignore=set(), module_parent="custom logging.warning(f"Error while calling comfy_entrypoint in {module_path}: {e}") return False else: - logging.warning(f"Skip {module_path} module for custom nodes due to the lack of NODE_CLASS_MAPPINGS or NODES_LIST (need one).") + logging.warning(f"Skip {module_path} module for custom nodes due to the lack of NODE_CLASS_MAPPINGS or comfy_entrypoint (need one).") return False except Exception as e: logging.warning(traceback.format_exc()) diff --git a/tests-unit/comfy_api_test/multicombo_serialization_test.py b/tests-unit/comfy_api_test/multicombo_serialization_test.py new file mode 100644 index 000000000..421c65a0d --- /dev/null +++ b/tests-unit/comfy_api_test/multicombo_serialization_test.py @@ -0,0 +1,78 @@ +from comfy_api.latest._io import Combo, MultiCombo + + +def test_multicombo_serializes_multi_select_as_object(): + multi_combo = MultiCombo.Input( + id="providers", + options=["a", "b", "c"], + default=["a"], + ) + + serialized = multi_combo.as_dict() + + assert serialized["multiselect"] is True + assert "multi_select" in serialized + assert serialized["multi_select"] == {} + + +def test_multicombo_serializes_multi_select_with_placeholder_and_chip(): + multi_combo = MultiCombo.Input( + id="providers", + options=["a", "b", "c"], + default=["a"], + placeholder="Select providers", + chip=True, + ) + + serialized = multi_combo.as_dict() + + assert serialized["multiselect"] is True + assert serialized["multi_select"] == { + "placeholder": "Select providers", + "chip": True, + } + + +def test_combo_does_not_serialize_multiselect(): + """Regular Combo should not have multiselect in its serialized output.""" + combo = Combo.Input( + id="choice", + options=["a", "b", "c"], + ) + + serialized = combo.as_dict() + + # Combo sets multiselect=False, but prune_dict keeps False (not None), + # so it should be present but False + assert serialized.get("multiselect") is False + assert "multi_select" not in serialized + + +def _validate_combo_values(val, combo_options, is_multiselect): + """Reproduce the validation logic from execution.py for testing.""" + if is_multiselect and isinstance(val, list): + return [v for v in val if v not in 
combo_options] + else: + return [val] if val not in combo_options else [] + + +def test_multicombo_validation_accepts_valid_list(): + options = ["a", "b", "c"] + assert _validate_combo_values(["a", "b"], options, True) == [] + + +def test_multicombo_validation_rejects_invalid_values(): + options = ["a", "b", "c"] + assert _validate_combo_values(["a", "x"], options, True) == ["x"] + + +def test_multicombo_validation_accepts_empty_list(): + options = ["a", "b", "c"] + assert _validate_combo_values([], options, True) == [] + + +def test_combo_validation_rejects_list_even_with_valid_items(): + """A regular Combo should not accept a list value.""" + options = ["a", "b", "c"] + invalid = _validate_combo_values(["a", "b"], options, False) + assert len(invalid) > 0 From 89014792c966b04bf18f7ba62aee5169f9094e84 Mon Sep 17 00:00:00 2001 From: Matt Miller Date: Tue, 5 May 2026 14:20:09 -0700 Subject: [PATCH 071/102] feat: add cloud-specific fields to OSS openapi.yaml as nullable (#13623) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * feat: add cloud-specific fields to OSS openapi.yaml as nullable Add cross-runtime fields with x-runtime: [cloud] extension and [cloud-only] description prefix per the convention established in BE-613. All new fields are nullable and not in required arrays, so they are purely additive. /api/features response: - max_upload_size (integer, int64) - free_tier_credits (integer, int32) - posthog_api_host (string, uri) - max_concurrent_jobs (integer, int32) - workflow_templates_version (string) - workflow_templates_source (string, enum) PromptRequest schema: - workflow_id (string, uuid) - workflow_version_id (string, uuid) POST /api/assets: - id field (uuid) on multipart/form-data for idempotent creation - application/json alternate content-type for URL-based uploads POST /api/assets/from-hash: - mime_type (string) to preserve type without re-inspection PUT /api/assets/{id}: - mime_type (string) for overriding auto-detection GET /api/assets additional query parameters: - job_ids (string) — filter by associated job UUIDs - include_public (boolean) — include workspace-public assets - asset_hash (string) — filter by exact content hash Resolves: BE-613 Blocks: BE-364, BE-361, BE-363 Co-authored-by: Matt Miller * fix(openapi): address CodeRabbit feedback (BE-613) - max_upload_size is set in both runtimes via SERVER_FEATURE_FLAGS; drop the cloud-only / nullable tagging. - Require `url` on the application/json POST /api/assets body so the contract is enforceable by validators and codegen. --------- Co-authored-by: Matt Miller --- openapi.yaml | 122 ++++++++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 111 insertions(+), 11 deletions(-) diff --git a/openapi.yaml b/openapi.yaml index 30f85b6ad..29b5f544b 100644 --- a/openapi.yaml +++ b/openapi.yaml @@ -631,7 +631,7 @@ paths: operationId: getFeatures tags: [system] summary: Get enabled feature flags - description: Returns a dictionary of feature flag names to their enabled state. + description: Returns a dictionary of feature flag names to their enabled state. Cloud deployments may include additional typed fields alongside the boolean flags. responses: "200": description: Feature flags @@ -641,6 +641,43 @@ paths: type: object additionalProperties: type: boolean + properties: + max_upload_size: + type: integer + format: int64 + minimum: 0 + description: "Maximum file upload size in bytes." 
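+                  # Illustrative local-ComfyUI response (example values; boolean flags elided):
+                  #   {"max_upload_size": 209715200, "free_tier_credits": null,
+                  #    "max_concurrent_jobs": null}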
+ free_tier_credits: + type: integer + format: int32 + minimum: 0 + nullable: true + x-runtime: [cloud] + description: "[cloud-only] Credits available to free-tier users. Local ComfyUI returns null." + posthog_api_host: + type: string + format: uri + nullable: true + x-runtime: [cloud] + description: "[cloud-only] PostHog analytics proxy URL for frontend telemetry. Local ComfyUI returns null." + max_concurrent_jobs: + type: integer + format: int32 + minimum: 0 + nullable: true + x-runtime: [cloud] + description: "[cloud-only] Maximum concurrent jobs the authenticated user can run. Local ComfyUI returns null." + workflow_templates_version: + type: string + nullable: true + x-runtime: [cloud] + description: "[cloud-only] Version identifier for the workflow templates bundle. Local ComfyUI returns null." + workflow_templates_source: + type: string + nullable: true + enum: [dynamic_config_override, workflow_templates_version_json] + x-runtime: [cloud] + description: "[cloud-only] How the templates version was resolved. Local ComfyUI returns null." # --------------------------------------------------------------------------- # Node / Object Info @@ -1497,6 +1534,24 @@ paths: type: string enum: [asc, desc] description: Sort direction + - name: job_ids + in: query + schema: + type: string + x-runtime: [cloud] + description: "[cloud-only] Comma-separated UUIDs to filter assets by associated job." + - name: include_public + in: query + schema: + type: boolean + x-runtime: [cloud] + description: "[cloud-only] Include workspace-public assets in addition to the caller's own." + - name: asset_hash + in: query + schema: + type: string + x-runtime: [cloud] + description: "[cloud-only] Filter by exact content hash." responses: "200": description: Asset list @@ -1542,6 +1597,49 @@ paths: type: string format: uuid description: ID of an existing asset to use as the preview image + id: + type: string + format: uuid + nullable: true + x-runtime: [cloud] + description: "[cloud-only] Client-supplied asset ID for idempotent creation. If an asset with this ID already exists, the existing asset is returned." + application/json: + schema: + type: object + x-runtime: [cloud] + description: "[cloud-only] URL-based asset upload. Caller supplies a URL instead of a file body; the server fetches the content." + required: + - url + properties: + url: + type: string + format: uri + description: "[cloud-only] URL of the file to import as an asset" + name: + type: string + description: Display name for the asset + tags: + type: string + description: Comma-separated tags + user_metadata: + type: string + description: JSON-encoded user metadata + hash: + type: string + description: "Blake3 hash of the file content (e.g. blake3:abc123...)" + mime_type: + type: string + description: MIME type of the file (overrides auto-detected type) + preview_id: + type: string + format: uuid + description: ID of an existing asset to use as the preview image + id: + type: string + format: uuid + nullable: true + x-runtime: [cloud] + description: "[cloud-only] Client-supplied asset ID for idempotent creation. If an asset with this ID already exists, the existing asset is returned." responses: "201": description: Asset created @@ -1580,6 +1678,11 @@ paths: user_metadata: type: object additionalProperties: true + mime_type: + type: string + nullable: true + x-runtime: [cloud] + description: "[cloud-only] MIME type of the content, so the type is preserved without re-inspecting content. Ignored by local ComfyUI." 
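+              # Example body (illustrative values):
+              #   {"hash": "blake3:abc123...", "mime_type": "image/png"}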
responses: "201": description: Asset created from hash @@ -1644,6 +1747,11 @@ paths: type: string format: uuid description: ID of the asset to use as the preview + mime_type: + type: string + nullable: true + x-runtime: [cloud] + description: "[cloud-only] MIME type override when auto-detection was wrong. Ignored by local ComfyUI." responses: "200": description: Asset updated @@ -2004,21 +2112,13 @@ components: format: uuid nullable: true x-runtime: [cloud] - description: | - UUID identifying a hosted-cloud workflow entity to associate with this - job. Local ComfyUI doesn't track workflow entities and returns `null` - (or omits the field). The `x-runtime: [cloud]` extension marks this - as populated only by the hosted-cloud runtime; absence of the tag - means a field is populated by all runtimes. + description: "[cloud-only] Cloud workflow entity ID for tracking and gallery association. Ignored by local ComfyUI." workflow_version_id: type: string format: uuid nullable: true x-runtime: [cloud] - description: | - UUID identifying a hosted-cloud workflow version to associate with - this job. Local ComfyUI returns `null` (or omits the field). See - `workflow_id` above for `x-runtime` semantics. + description: "[cloud-only] Cloud workflow version ID for pinning execution to a specific version. Ignored by local ComfyUI." PromptResponse: type: object From 1655f8089a23232a94b36129286942c33e740168 Mon Sep 17 00:00:00 2001 From: drozbay <17261091+drozbay@users.noreply.github.com> Date: Tue, 5 May 2026 17:30:00 -0600 Subject: [PATCH 072/102] Add temporal_downscale_ratio to LatentFormat (#13702) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: ozbayb <17261091+ozbayb@users.noreply.github.com> Co-authored-by: Alexis Rolland Co-authored-by: Jukka Seppänen <40791699+kijai@users.noreply.github.com> Co-authored-by: Jedrzej Kosinski --- comfy/latent_formats.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/comfy/latent_formats.py b/comfy/latent_formats.py index 3dac5be18..60c0dfd7e 100644 --- a/comfy/latent_formats.py +++ b/comfy/latent_formats.py @@ -9,6 +9,7 @@ class LatentFormat: latent_rgb_factors_reshape = None taesd_decoder_name = None spacial_downscale_ratio = 8 + temporal_downscale_ratio = 1 def process_in(self, latent): return latent * self.scale_factor @@ -235,6 +236,7 @@ class Flux2(LatentFormat): class Mochi(LatentFormat): latent_channels = 12 latent_dimensions = 3 + temporal_downscale_ratio = 6 def __init__(self): self.scale_factor = 1.0 @@ -278,6 +280,7 @@ class LTXV(LatentFormat): latent_channels = 128 latent_dimensions = 3 spacial_downscale_ratio = 32 + temporal_downscale_ratio = 8 def __init__(self): self.latent_rgb_factors = [ @@ -421,6 +424,7 @@ class LTXAV(LTXV): class HunyuanVideo(LatentFormat): latent_channels = 16 latent_dimensions = 3 + temporal_downscale_ratio = 4 scale_factor = 0.476986 latent_rgb_factors = [ [-0.0395, -0.0331, 0.0445], @@ -447,6 +451,7 @@ class HunyuanVideo(LatentFormat): class Cosmos1CV8x8x8(LatentFormat): latent_channels = 16 latent_dimensions = 3 + temporal_downscale_ratio = 8 latent_rgb_factors = [ [ 0.1817, 0.2284, 0.2423], @@ -472,6 +477,7 @@ class Cosmos1CV8x8x8(LatentFormat): class Wan21(LatentFormat): latent_channels = 16 latent_dimensions = 3 + temporal_downscale_ratio = 4 latent_rgb_factors = [ [-0.1299, -0.1692, 0.2932], @@ -734,6 +740,7 @@ class HunyuanVideo15(LatentFormat): latent_channels = 32 latent_dimensions = 3 spacial_downscale_ratio = 16 + temporal_downscale_ratio = 4 
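+    # Standard causal video-VAE arithmetic implied by the two ratios above
+    # (a sketch, not a new API): a T-frame, HxW pixel video encodes to
+    # ((T - 1) // 4) + 1 latent frames at (H // 16) x (W // 16) resolution.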
scale_factor = 1.03682 taesd_decoder_name = "lighttaehy1_5" @@ -788,6 +795,7 @@ class ZImagePixelSpace(ChromaRadiance): class CogVideoX(LatentFormat): latent_channels = 16 latent_dimensions = 3 + temporal_downscale_ratio = 4 def __init__(self): self.scale_factor = 1.15258426 From e5369c0eec8b1b9c6d1b12a2e4167b46c7fd1c1e Mon Sep 17 00:00:00 2001 From: drozbay <17261091+drozbay@users.noreply.github.com> Date: Tue, 5 May 2026 17:40:53 -0600 Subject: [PATCH 073/102] feat: Context windows - add causal_window_fix to improve blending of context windows (CORE-100) (#13563) * Context windows: add causal_window_fix toggle * Fix slice_cond to correctly handle causal anchor index for temporal offsets --- comfy/context_windows.py | 33 ++++++++++++++++++++++++--- comfy_extras/nodes_context_windows.py | 6 +++-- 2 files changed, 34 insertions(+), 5 deletions(-) diff --git a/comfy/context_windows.py b/comfy/context_windows.py index cb44ee6e8..db57537a2 100644 --- a/comfy/context_windows.py +++ b/comfy/context_windows.py @@ -63,7 +63,11 @@ class IndexListContextWindow(ContextWindowABC): dim = self.dim if dim == 0 and full.shape[dim] == 1: return full - idx = tuple([slice(None)] * dim + [self.index_list]) + indices = self.index_list + anchor_idx = getattr(self, 'causal_anchor_index', None) + if anchor_idx is not None and anchor_idx >= 0: + indices = [anchor_idx] + list(indices) + idx = tuple([slice(None)] * dim + [indices]) window = full[idx] if retain_index_list: idx = tuple([slice(None)] * dim + [retain_index_list]) @@ -113,7 +117,14 @@ def slice_cond(cond_value, window: IndexListContextWindow, x_in: torch.Tensor, d # skip leading latent positions that have no corresponding conditioning (e.g. reference frames) if temporal_offset > 0: - indices = [i - temporal_offset for i in window.index_list[temporal_offset:]] + anchor_idx = getattr(window, 'causal_anchor_index', None) + if anchor_idx is not None and anchor_idx >= 0: + # anchor occupies one of the no-cond positions, so skip one fewer from window.index_list + skip_count = temporal_offset - 1 + else: + skip_count = temporal_offset + + indices = [i - temporal_offset for i in window.index_list[skip_count:]] indices = [i for i in indices if 0 <= i] else: indices = list(window.index_list) @@ -150,7 +161,8 @@ class ContextFuseMethod: ContextResults = collections.namedtuple("ContextResults", ['window_idx', 'sub_conds_out', 'sub_conds', 'window']) class IndexListContextHandler(ContextHandlerABC): def __init__(self, context_schedule: ContextSchedule, fuse_method: ContextFuseMethod, context_length: int=1, context_overlap: int=0, context_stride: int=1, - closed_loop: bool=False, dim:int=0, freenoise: bool=False, cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False): + closed_loop: bool=False, dim:int=0, freenoise: bool=False, cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False, + causal_window_fix: bool=True): self.context_schedule = context_schedule self.fuse_method = fuse_method self.context_length = context_length @@ -162,6 +174,7 @@ class IndexListContextHandler(ContextHandlerABC): self.freenoise = freenoise self.cond_retain_index_list = [int(x.strip()) for x in cond_retain_index_list.split(",")] if cond_retain_index_list else [] self.split_conds_to_windows = split_conds_to_windows + self.causal_window_fix = causal_window_fix self.callbacks = {} @@ -318,6 +331,14 @@ class IndexListContextHandler(ContextHandlerABC): # allow processing to end between context window executions for faster Cancel 
comfy.model_management.throw_exception_if_processing_interrupted() + # causal_window_fix: prepend a pre-window frame that will be stripped post-forward + anchor_applied = False + if self.causal_window_fix: + anchor_idx = window.index_list[0] - 1 + if 0 <= anchor_idx < x_in.size(self.dim): + window.causal_anchor_index = anchor_idx + anchor_applied = True + for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EVALUATE_CONTEXT_WINDOWS, self.callbacks): callback(self, model, x_in, conds, timestep, model_options, window_idx, window, model_options, device, first_device) @@ -332,6 +353,12 @@ class IndexListContextHandler(ContextHandlerABC): if device is not None: for i in range(len(sub_conds_out)): sub_conds_out[i] = sub_conds_out[i].to(x_in.device) + + # strip causal_window_fix anchor if applied + if anchor_applied: + for i in range(len(sub_conds_out)): + sub_conds_out[i] = sub_conds_out[i].narrow(self.dim, 1, sub_conds_out[i].shape[self.dim] - 1) + results.append(ContextResults(window_idx, sub_conds_out, sub_conds, window)) return results diff --git a/comfy_extras/nodes_context_windows.py b/comfy_extras/nodes_context_windows.py index 0e43f2e44..fefc56d26 100644 --- a/comfy_extras/nodes_context_windows.py +++ b/comfy_extras/nodes_context_windows.py @@ -29,6 +29,7 @@ class ContextWindowsManualNode(io.ComfyNode): io.Boolean.Input("freenoise", default=False, tooltip="Whether to apply FreeNoise noise shuffling, improves window blending."), io.String.Input("cond_retain_index_list", default="", tooltip="List of latent indices to retain in the conditioning tensors for each window, for example setting this to '0' will use the initial start image for each window."), io.Boolean.Input("split_conds_to_windows", default=False, tooltip="Whether to split multiple conditionings (created by ConditionCombine) to each window based on region index."), + io.Boolean.Input("causal_window_fix", default=True, tooltip="Whether to add a causal fix frame to non-0-indexed context windows."), ], outputs=[ io.Model.Output(tooltip="The model with context windows applied during sampling."), @@ -38,7 +39,7 @@ class ContextWindowsManualNode(io.ComfyNode): @classmethod def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, context_stride: int, closed_loop: bool, fuse_method: str, dim: int, freenoise: bool, - cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False) -> io.Model: + cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False, causal_window_fix: bool=True) -> io.Model: model = model.clone() model.model_options["context_handler"] = comfy.context_windows.IndexListContextHandler( context_schedule=comfy.context_windows.get_matching_context_schedule(context_schedule), @@ -50,7 +51,8 @@ class ContextWindowsManualNode(io.ComfyNode): dim=dim, freenoise=freenoise, cond_retain_index_list=cond_retain_index_list, - split_conds_to_windows=split_conds_to_windows + split_conds_to_windows=split_conds_to_windows, + causal_window_fix=causal_window_fix, ) # make memory usage calculation only take into account the context window latents comfy.context_windows.create_prepare_sampling_wrapper(model) From c168960a12213df8123a2f234f04a4cf55bbe30d Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 5 May 2026 17:00:11 -0700 Subject: [PATCH 074/102] First step of supporting save filenames without trailing _ (#13722) get_save_image_path now properly supports filenames without trailing 
underscores. This is the resulting save pattern when old-format and
new-format save image nodes write to the same prefix:

ComfyUI_00001_.png
ComfyUI_00002.png
ComfyUI_00003.png
ComfyUI_00004_.png
---
 folder_paths.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/folder_paths.py b/folder_paths.py
index 80f4b291a..039f72636 100644
--- a/folder_paths.py
+++ b/folder_paths.py
@@ -432,7 +432,9 @@ def get_save_image_path(filename_prefix: str, output_dir: str, image_width=0, im
         prefix_len = len(os.path.basename(filename_prefix))
         prefix = filename[:prefix_len + 1]
         try:
-            digits = int(filename[prefix_len + 1:].split('_')[0])
+            remainder = filename[prefix_len + 1:]
+            base_remainder = remainder.split('.')[0]
+            digits = int(base_remainder.split('_')[0])
         except:
             digits = 0
         return digits, prefix

From 160b95f75c9cf60b04fbbf4ec0b8f35f474ffb2a Mon Sep 17 00:00:00 2001
From: iChrist
Date: Wed, 6 May 2026 05:47:57 +0300
Subject: [PATCH 075/102] Update language options in nodes_ace.py (#12578)

* Update language options in nodes_ace.py

Modified the list to include all 51 language options that ACE-Step 1.5
supports, instead of the 23 ComfyUI originally had.

* re-arrange list by popularity

Reordered the languages by popularity: "en" comes first and is the
default; "unknown" is last.

* Update comfy_extras/nodes_ace.py

---
 comfy_extras/nodes_ace.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/comfy_extras/nodes_ace.py b/comfy_extras/nodes_ace.py
index 1602add84..affcf3b71 100644
--- a/comfy_extras/nodes_ace.py
+++ b/comfy_extras/nodes_ace.py
@@ -42,7 +42,7 @@ class TextEncodeAceStepAudio15(IO.ComfyNode):
                 IO.Int.Input("bpm", default=120, min=10, max=300),
                 IO.Float.Input("duration", default=120.0, min=0.0, max=2000.0, step=0.1),
                 IO.Combo.Input("timesignature", options=['2', '3', '4', '6']),
-                IO.Combo.Input("language", options=["en", "ja", "zh", "es", "de", "fr", "pt", "ru", "it", "nl", "pl", "tr", "vi", "cs", "fa", "id", "ko", "uk", "hu", "ar", "sv", "ro", "el"]),
+                IO.Combo.Input("language", options=['ar', 'az', 'bg', 'bn', 'ca', 'cs', 'da', 'de', 'el', 'en', 'es', 'fa', 'fi', 'fr', 'he', 'hi', 'hr', 'ht', 'hu', 'id', 'is', 'it', 'ja', 'ko', 'la', 'lt', 'ms', 'ne', 'nl', 'no', 'pa', 'pl', 'pt', 'ro', 'ru', 'sa', 'sk', 'sr', 'sv', 'sw', 'ta', 'te', 'th', 'tl', 'tr', 'uk', 'ur', 'vi', 'yue', 'zh', 'unknown'], default='en'),
                 IO.Combo.Input("keyscale", options=[f"{root} {quality}" for quality in ["major", "minor"] for root in ["C", "C#", "Db", "D", "D#", "Eb", "E", "F", "F#", "Gb", "G", "G#", "Ab", "A", "A#", "Bb", "B"]]),
                 IO.Boolean.Input("generate_audio_codes", default=True, tooltip="Enable the LLM that generates audio codes. This can be slow but will increase the quality of the generated audio. 
Turn this off if you are giving the model an audio reference.", advanced=True), IO.Float.Input("cfg_scale", default=2.0, min=0.0, max=100.0, step=0.1, advanced=True), From 2b63add0ad975d5f2f0cdc3d4fd8e71ae6553cbf Mon Sep 17 00:00:00 2001 From: Luke Mino-Altherr Date: Tue, 5 May 2026 19:56:09 -0700 Subject: [PATCH 076/102] fix: return millisecond timestamps from get_file_info() (#12996) --- app/user_manager.py | 4 ++-- tests-unit/prompt_server_test/user_manager_test.py | 6 +++++- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/app/user_manager.py b/app/user_manager.py index e18afb71b..0517b3344 100644 --- a/app/user_manager.py +++ b/app/user_manager.py @@ -28,8 +28,8 @@ def get_file_info(path: str, relative_to: str) -> FileInfo: return { "path": os.path.relpath(path, relative_to).replace(os.sep, '/'), "size": os.path.getsize(path), - "modified": os.path.getmtime(path), - "created": os.path.getctime(path) + "modified": int(os.path.getmtime(path) * 1000), + "created": int(os.path.getctime(path) * 1000), } diff --git a/tests-unit/prompt_server_test/user_manager_test.py b/tests-unit/prompt_server_test/user_manager_test.py index b939d8e68..27118400f 100644 --- a/tests-unit/prompt_server_test/user_manager_test.py +++ b/tests-unit/prompt_server_test/user_manager_test.py @@ -69,7 +69,11 @@ async def test_listuserdata_full_info(aiohttp_client, app, tmp_path): assert len(result) == 1 assert result[0]["path"] == "file1.txt" assert "size" in result[0] - assert "modified" in result[0] + assert isinstance(result[0]["modified"], int) + assert isinstance(result[0]["created"], int) + # Verify millisecond magnitude (timestamps after year 2000 in ms are > 946684800000) + assert result[0]["modified"] > 946684800000 + assert result[0]["created"] > 946684800000 async def test_listuserdata_split_path(aiohttp_client, app, tmp_path): From 78b3096bf36ef32378a3d4299473b820211f1601 Mon Sep 17 00:00:00 2001 From: Talmaj Date: Wed, 6 May 2026 04:59:04 +0200 Subject: [PATCH 077/102] Void model - pass 1 & 2 (CORE-38) (#13403) --- comfy/latent_formats.py | 18 + comfy/sd.py | 5 + comfy/supported_models.py | 23 + comfy/text_encoders/cogvideo.py | 42 ++ comfy_extras/nodes_void.py | 483 +++++++++++++++++ comfy_extras/void_noise_warp.py | 494 ++++++++++++++++++ folder_paths.py | 2 + .../optical_flow/put_optical_flow_models_here | 0 nodes.py | 5 +- 9 files changed, 1070 insertions(+), 2 deletions(-) create mode 100644 comfy_extras/nodes_void.py create mode 100644 comfy_extras/void_noise_warp.py create mode 100644 models/optical_flow/put_optical_flow_models_here diff --git a/comfy/latent_formats.py b/comfy/latent_formats.py index 60c0dfd7e..91bebed3d 100644 --- a/comfy/latent_formats.py +++ b/comfy/latent_formats.py @@ -793,9 +793,27 @@ class ZImagePixelSpace(ChromaRadiance): pass class CogVideoX(LatentFormat): + """Latent format for CogVideoX-2b (THUDM/CogVideoX-2b). + + scale_factor matches the vae/config.json scaling_factor for the 2b variant. + The 5b-class checkpoints (CogVideoX-5b, CogVideoX-1.5-5B, CogVideoX-Fun-V1.5-*) + use a different value; see CogVideoX1_5 below. + """ latent_channels = 16 latent_dimensions = 3 temporal_downscale_ratio = 4 def __init__(self): self.scale_factor = 1.15258426 + + +class CogVideoX1_5(CogVideoX): + """Latent format for 5b-class CogVideoX checkpoints. + + Covers THUDM/CogVideoX-5b, THUDM/CogVideoX-1.5-5B, and the CogVideoX-Fun + V1.5-5b family (including VOID inpainting). All of these have + scaling_factor=0.7 in their vae/config.json. 
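+    In LatentFormat terms, process_in multiplies raw VAE latents by 0.7 and
+    process_out divides by 0.7; that is the reading of diffusers'
+    `latents * scaling_factor` convention assumed here.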
Auto-selected in + supported_models.CogVideoX_T2V based on transformer hidden dim. + """ + def __init__(self): + self.scale_factor = 0.7 diff --git a/comfy/sd.py b/comfy/sd.py index 9fce0e7d0..749bdd710 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -66,6 +66,7 @@ import comfy.text_encoders.longcat_image import comfy.text_encoders.qwen35 import comfy.text_encoders.ernie import comfy.text_encoders.gemma4 +import comfy.text_encoders.cogvideo import comfy.model_patcher import comfy.lora @@ -1224,6 +1225,7 @@ class CLIPType(Enum): NEWBIE = 24 FLUX2 = 25 LONGCAT_IMAGE = 26 + COGVIDEOX = 27 @@ -1428,6 +1430,9 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip clip_target.clip = comfy.text_encoders.hidream.hidream_clip(**t5xxl_detect(clip_data), clip_l=False, clip_g=False, t5=True, llama=False, dtype_llama=None) clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer + elif clip_type == CLIPType.COGVIDEOX: + clip_target.clip = comfy.text_encoders.cogvideo.cogvideo_te(**t5xxl_detect(clip_data)) + clip_target.tokenizer = comfy.text_encoders.cogvideo.CogVideoXTokenizer else: #CLIPType.MOCHI clip_target.clip = comfy.text_encoders.genmo.mochi_te(**t5xxl_detect(clip_data)) clip_target.tokenizer = comfy.text_encoders.genmo.MochiT5Tokenizer diff --git a/comfy/supported_models.py b/comfy/supported_models.py index dff40461f..6a9613602 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -1872,6 +1872,14 @@ class CogVideoX_T2V(supported_models_base.BASE): vae_key_prefix = ["vae."] text_encoder_key_prefix = ["text_encoders."] + def __init__(self, unet_config): + # 2b-class (dim=1920, heads=30) uses scale_factor=1.15258426. + # 5b-class (dim=3072, heads=48) — incl. CogVideoX-5b, 1.5-5B, and + # Fun-V1.5 inpainting — uses scale_factor=0.7 per vae/config.json. + if unet_config.get("num_attention_heads", 0) >= 48: + self.latent_format = latent_formats.CogVideoX1_5 + super().__init__(unet_config) + def get_model(self, state_dict, prefix="", device=None): # CogVideoX 1.5 (patch_size_t=2) has different training base dimensions for RoPE if self.unet_config.get("patch_size_t") is not None: @@ -1898,6 +1906,20 @@ class CogVideoX_I2V(CogVideoX_T2V): out = model_base.CogVideoX(self, image_to_video=True, device=device) return out +class CogVideoX_Inpaint(CogVideoX_T2V): + unet_config = { + "image_model": "cogvideox", + "in_channels": 48, + } + + def get_model(self, state_dict, prefix="", device=None): + if self.unet_config.get("patch_size_t") is not None: + self.unet_config.setdefault("sample_height", 96) + self.unet_config.setdefault("sample_width", 170) + self.unet_config.setdefault("sample_frames", 81) + out = model_base.CogVideoX(self, image_to_video=True, device=device) + return out + models = [ LotusD, @@ -1978,6 +2000,7 @@ models = [ ErnieImage, SAM3, SAM31, + CogVideoX_Inpaint, CogVideoX_I2V, CogVideoX_T2V, SVD_img2vid, diff --git a/comfy/text_encoders/cogvideo.py b/comfy/text_encoders/cogvideo.py index f1e8e3f5d..b97310709 100644 --- a/comfy/text_encoders/cogvideo.py +++ b/comfy/text_encoders/cogvideo.py @@ -1,6 +1,48 @@ import comfy.text_encoders.sd3_clip +from comfy import sd1_clip class CogVideoXT5Tokenizer(comfy.text_encoders.sd3_clip.T5XXLTokenizer): + """Inner T5 tokenizer for CogVideoX. + + CogVideoX was trained with T5 embeddings padded to 226 tokens (not 77 like SD3). + Used both directly by supported_models.CogVideoX_T2V.clip_target (paired with + the raw T5XXLModel) and by the CogVideoXTokenizer outer wrapper below. 
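+
+    Rough sketch of the effect (assuming stock T5 tokenizer settings): a
+    short prompt tokenizes to its ids plus the </s> end token, then is
+    padded with pad id 0 out to min_length=226, so CogVideoX conditioning
+    always spans at least 226 token positions.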
+ """ def __init__(self, embedding_directory=None, tokenizer_data={}): super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, min_length=226) + + +class CogVideoXTokenizer(sd1_clip.SD1Tokenizer): + """Outer tokenizer wrapper for CLIPLoader (type="cogvideox").""" + def __init__(self, embedding_directory=None, tokenizer_data={}): + super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, + clip_name="t5xxl", tokenizer=CogVideoXT5Tokenizer) + + +class CogVideoXT5XXL(sd1_clip.SD1ClipModel): + """Outer T5XXL model wrapper for CLIPLoader (type="cogvideox"). + + Wraps the raw T5XXL model in the SD1ClipModel interface so that CLIP.__init__ + (which reads self.dtypes) works correctly. The inner model is the standard + sd3_clip.T5XXLModel (no attention_mask change needed for CogVideoX). + """ + def __init__(self, device="cpu", dtype=None, model_options={}): + super().__init__(device=device, dtype=dtype, name="t5xxl", + clip_model=comfy.text_encoders.sd3_clip.T5XXLModel, + model_options=model_options) + + +def cogvideo_te(dtype_t5=None, t5_quantization_metadata=None): + """Factory that returns a CogVideoXT5XXL class configured with the detected + T5 dtype and optional quantization metadata, for use in load_text_encoder_state_dicts. + """ + class CogVideoXTEModel_(CogVideoXT5XXL): + def __init__(self, device="cpu", dtype=None, model_options={}): + if t5_quantization_metadata is not None: + model_options = model_options.copy() + model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata + if dtype_t5 is not None: + dtype = dtype_t5 + super().__init__(device=device, dtype=dtype, model_options=model_options) + return CogVideoXTEModel_ diff --git a/comfy_extras/nodes_void.py b/comfy_extras/nodes_void.py new file mode 100644 index 000000000..e7a8f3757 --- /dev/null +++ b/comfy_extras/nodes_void.py @@ -0,0 +1,483 @@ +import logging + +import torch + +import comfy +import comfy.model_management +import comfy.model_patcher +import comfy.samplers +import comfy.utils +import folder_paths +import node_helpers +import nodes +from comfy.utils import model_trange as trange +from comfy_api.latest import ComfyExtension, io +from torchvision.models.optical_flow import raft_large +from typing_extensions import override + + +from comfy_extras.void_noise_warp import RaftOpticalFlow, get_noise_from_video + +OpticalFlow = io.Custom("OPTICAL_FLOW") + +TEMPORAL_COMPRESSION = 4 +PATCH_SIZE_T = 2 + + +def _valid_void_length(length: int) -> int: + """Round ``length`` down to a value that produces an even latent_t. + + VOID / CogVideoX-Fun-V1.5 uses patch_size_t=2, so the VAE-encoded latent + must have an even temporal dimension. If latent_t is odd, the transformer + pad_to_patch_size circular-wraps an extra latent frame onto the end; after + the post-transformer crop the last real latent frame has been influenced + by the wrapped phantom frame, producing visible jitter and "disappearing" + subjects near the end of the decoded video. Rounding down fixes this. + """ + latent_t = ((length - 1) // TEMPORAL_COMPRESSION) + 1 + if latent_t % PATCH_SIZE_T == 0: + return length + # Round latent_t down to the nearest multiple of PATCH_SIZE_T, then invert + # the ((length - 1) // TEMPORAL_COMPRESSION) + 1 formula. Floor at 1 frame + # so we never return a non-positive length. 
+ target_latent_t = max(PATCH_SIZE_T, (latent_t // PATCH_SIZE_T) * PATCH_SIZE_T) + return (target_latent_t - 1) * TEMPORAL_COMPRESSION + 1 + + +class OpticalFlowLoader(io.ComfyNode): + """Load an optical flow model from ``models/optical_flow/``. + + Only torchvision's RAFT-large format is recognized today (the model used + by VOIDWarpedNoise). The checkpoint must be placed under + ``models/optical_flow/`` — ComfyUI never downloads optical-flow weights + at runtime. + """ + + @classmethod + def define_schema(cls): + return io.Schema( + node_id="OpticalFlowLoader", + display_name="Load Optical Flow Model", + category="loaders", + inputs=[ + io.Combo.Input( + "model_name", + options=folder_paths.get_filename_list("optical_flow"), + tooltip=( + "Optical flow model to load. Files must be placed in the " + "'optical_flow' folder. Today only torchvision's " + "raft_large.pth is supported." + ), + ), + ], + outputs=[ + OpticalFlow.Output(), + ], + ) + + @classmethod + def execute(cls, model_name) -> io.NodeOutput: + + model_path = folder_paths.get_full_path_or_raise("optical_flow", model_name) + sd = comfy.utils.load_torch_file(model_path, safe_load=True) + + has_raft_keys = ( + any(k.startswith("feature_encoder.") for k in sd) + and any(k.startswith("context_encoder.") for k in sd) + and any(k.startswith("update_block.") for k in sd) + ) + if not has_raft_keys: + raise ValueError( + "Unrecognized optical flow model format: expected a torchvision " + "RAFT-large state dict with 'feature_encoder.', 'context_encoder.' " + "and 'update_block.' prefixes." + ) + + model = raft_large(weights=None, progress=False) + model.load_state_dict(sd) + model.eval().to(torch.float32) + + patcher = comfy.model_patcher.ModelPatcher( + model, + load_device=comfy.model_management.get_torch_device(), + offload_device=comfy.model_management.unet_offload_device(), + ) + return io.NodeOutput(patcher) + + +class VOIDQuadmaskPreprocess(io.ComfyNode): + """Preprocess a quadmask video for VOID inpainting. + + Quantizes mask values to four semantic levels, inverts, and normalizes: + 0 -> primary object to remove + 63 -> overlap of primary + affected + 127 -> affected region (interactions) + 255 -> background (keep) + + After inversion and normalization, the output mask has values in [0, 1] + with four discrete levels: 1.0 (remove), ~0.75, ~0.50, 0.0 (keep). 
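+
+    Worked values (derived from the quantize/invert steps in execute):
+        input   0 -> level   0 -> output 1.0            (remove)
+        input  63 -> level  63 -> output 192/255 ~ 0.753
+        input 127 -> level 127 -> output 128/255 ~ 0.502
+        input 255 -> level 255 -> output 0.0            (keep)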
+ """ + + @classmethod + def define_schema(cls): + return io.Schema( + node_id="VOIDQuadmaskPreprocess", + category="mask/video", + inputs=[ + io.Mask.Input("mask"), + io.Int.Input("dilate_width", default=0, min=0, max=50, step=1, + tooltip="Dilation radius for the primary mask region (0 = no dilation)"), + ], + outputs=[ + io.Mask.Output(display_name="quadmask"), + ], + ) + + @classmethod + def execute(cls, mask, dilate_width=0) -> io.NodeOutput: + m = mask.clone() + + if m.max() <= 1.0: + m = m * 255.0 + + if dilate_width > 0 and m.ndim >= 3: + binary = (m < 128).float() + kernel_size = dilate_width * 2 + 1 + if binary.ndim == 3: + binary = binary.unsqueeze(1) + dilated = torch.nn.functional.max_pool2d( + binary, kernel_size=kernel_size, stride=1, padding=dilate_width + ) + if dilated.ndim == 4: + dilated = dilated.squeeze(1) + m = torch.where(dilated > 0.5, torch.zeros_like(m), m) + + m = torch.where(m <= 31, torch.zeros_like(m), m) + m = torch.where((m > 31) & (m <= 95), torch.full_like(m, 63), m) + m = torch.where((m > 95) & (m <= 191), torch.full_like(m, 127), m) + m = torch.where(m > 191, torch.full_like(m, 255), m) + + m = (255.0 - m) / 255.0 + + return io.NodeOutput(m) + + +class VOIDInpaintConditioning(io.ComfyNode): + """Build VOID inpainting conditioning for CogVideoX. + + Encodes the processed quadmask and masked source video through the VAE, + producing a 32-channel concat conditioning (16ch mask + 16ch masked video) + that gets concatenated with the 16ch noise latent by the model. + """ + + @classmethod + def define_schema(cls): + return io.Schema( + node_id="VOIDInpaintConditioning", + category="conditioning/video_models", + inputs=[ + io.Conditioning.Input("positive"), + io.Conditioning.Input("negative"), + io.Vae.Input("vae"), + io.Image.Input("video", tooltip="Source video frames [T, H, W, 3]"), + io.Mask.Input("quadmask", tooltip="Preprocessed quadmask from VOIDQuadmaskPreprocess [T, H, W]"), + io.Int.Input("width", default=672, min=16, max=nodes.MAX_RESOLUTION, step=8), + io.Int.Input("height", default=384, min=16, max=nodes.MAX_RESOLUTION, step=8), + io.Int.Input("length", default=45, min=1, max=nodes.MAX_RESOLUTION, step=1, + tooltip="Number of pixel frames to process. For CogVideoX-Fun-V1.5 " + "(patch_size_t=2), latent_t must be even — lengths that " + "produce odd latent_t are rounded down (e.g. 49 → 45)."), + io.Int.Input("batch_size", default=1, min=1, max=64), + ], + outputs=[ + io.Conditioning.Output(display_name="positive"), + io.Conditioning.Output(display_name="negative"), + io.Latent.Output(display_name="latent"), + ], + ) + + @classmethod + def execute(cls, positive, negative, vae, video, quadmask, + width, height, length, batch_size) -> io.NodeOutput: + + adjusted_length = _valid_void_length(length) + if adjusted_length != length: + logging.warning( + "VOIDInpaintConditioning: rounding length %d down to %d so that " + "latent_t is even (required by CogVideoX-Fun-V1.5 patch_size_t=2). 
" + "Using odd latent_t causes the last frame to be corrupted by " + "circular padding.", length, adjusted_length, + ) + length = adjusted_length + + latent_t = ((length - 1) // TEMPORAL_COMPRESSION) + 1 + latent_h = height // 8 + latent_w = width // 8 + + vid = video[:length] + vid = comfy.utils.common_upscale( + vid.movedim(-1, 1), width, height, "bilinear", "center" + ).movedim(1, -1) + + qm = quadmask[:length] + if qm.ndim == 3: + qm = qm.unsqueeze(-1) + qm = comfy.utils.common_upscale( + qm.movedim(-1, 1), width, height, "bilinear", "center" + ).movedim(1, -1) + if qm.ndim == 4 and qm.shape[-1] == 1: + qm = qm.squeeze(-1) + + mask_condition = qm + if mask_condition.ndim == 3: + mask_condition_3ch = mask_condition.unsqueeze(-1).expand(-1, -1, -1, 3) + else: + mask_condition_3ch = mask_condition + + inverted_mask_3ch = 1.0 - mask_condition_3ch + masked_video = vid[:, :, :, :3] * (1.0 - mask_condition_3ch) + + mask_latents = vae.encode(inverted_mask_3ch) + masked_video_latents = vae.encode(masked_video) + + def _match_temporal(lat, target_t): + if lat.shape[2] > target_t: + return lat[:, :, :target_t] + elif lat.shape[2] < target_t: + pad = target_t - lat.shape[2] + return torch.cat([lat, lat[:, :, -1:].repeat(1, 1, pad, 1, 1)], dim=2) + return lat + + mask_latents = _match_temporal(mask_latents, latent_t) + masked_video_latents = _match_temporal(masked_video_latents, latent_t) + + inpaint_latents = torch.cat([mask_latents, masked_video_latents], dim=1) + + # No explicit scaling needed here: the model's CogVideoX.concat_cond() + # applies process_latent_in (×latent_format.scale_factor) to each 16-ch + # block of the stored conditioning. For 5b-class checkpoints (incl. the + # VOID/CogVideoX-Fun-V1.5 inpainting model) that scale_factor is auto- + # selected as 0.7 in supported_models.CogVideoX_T2V, which matches the + # diffusers vae/config.json scaling_factor VOID was trained with. + + positive = node_helpers.conditioning_set_values( + positive, {"concat_latent_image": inpaint_latents} + ) + negative = node_helpers.conditioning_set_values( + negative, {"concat_latent_image": inpaint_latents} + ) + + noise_latent = torch.zeros( + [batch_size, 16, latent_t, latent_h, latent_w], + device=comfy.model_management.intermediate_device() + ) + + return io.NodeOutput(positive, negative, {"samples": noise_latent}) + + +class VOIDWarpedNoise(io.ComfyNode): + """Generate optical-flow warped noise for VOID Pass 2 refinement. + + Takes the Pass 1 output video and produces temporally-correlated noise + by warping Gaussian noise along optical flow vectors. This noise is used + as the initial latent for Pass 2, resulting in better temporal consistency. + """ + + @classmethod + def define_schema(cls): + return io.Schema( + node_id="VOIDWarpedNoise", + category="latent/video", + inputs=[ + OpticalFlow.Input( + "optical_flow", + tooltip="Optical flow model from OpticalFlowLoader (RAFT-large).", + ), + io.Image.Input("video", tooltip="Pass 1 output video frames [T, H, W, 3]"), + io.Int.Input("width", default=672, min=16, max=nodes.MAX_RESOLUTION, step=8), + io.Int.Input("height", default=384, min=16, max=nodes.MAX_RESOLUTION, step=8), + io.Int.Input("length", default=45, min=1, max=nodes.MAX_RESOLUTION, step=1, + tooltip="Number of pixel frames. Rounded down to make latent_t " + "even (patch_size_t=2 requirement), e.g. 
49 → 45."), + io.Int.Input("batch_size", default=1, min=1, max=64), + ], + outputs=[ + io.Latent.Output(display_name="warped_noise"), + ], + ) + + @classmethod + def execute(cls, optical_flow, video, width, height, length, batch_size) -> io.NodeOutput: + + adjusted_length = _valid_void_length(length) + if adjusted_length != length: + logging.warning( + "VOIDWarpedNoise: rounding length %d down to %d so that " + "latent_t is even (required by CogVideoX-Fun-V1.5 patch_size_t=2).", + length, adjusted_length, + ) + length = adjusted_length + + latent_t = ((length - 1) // TEMPORAL_COMPRESSION) + 1 + latent_h = height // 8 + latent_w = width // 8 + + # RAFT + noise warp is real compute, not an "intermediate" buffer, so + # we want the actual torch device (CUDA/MPS). The final latent is + # moved back to intermediate_device() before returning to match the + # rest of the ComfyUI pipeline. + device = comfy.model_management.get_torch_device() + + comfy.model_management.load_model_gpu(optical_flow) + raft = RaftOpticalFlow(optical_flow.model, device=device) + + vid = video[:length].to(device) + vid = comfy.utils.common_upscale( + vid.movedim(-1, 1), width, height, "bilinear", "center" + ).movedim(1, -1) + vid_uint8 = (vid.clamp(0, 1) * 255).to(torch.uint8) + + FRAME = 2**-1 + FLOW = 2**3 + LATENT_SCALE = 8 + + warped = get_noise_from_video( + vid_uint8, + raft, + noise_channels=16, + resize_frames=FRAME, + resize_flow=FLOW, + downscale_factor=round(FRAME * FLOW) * LATENT_SCALE, + device=device, + ) + + if warped.shape[0] != latent_t: + indices = torch.linspace(0, warped.shape[0] - 1, latent_t, + device=device).long() + warped = warped[indices] + + if warped.shape[1] != latent_h or warped.shape[2] != latent_w: + # (T, H, W, C) → (T, C, H, W) → bilinear resize → back + warped = warped.permute(0, 3, 1, 2) + warped = torch.nn.functional.interpolate( + warped, size=(latent_h, latent_w), + mode="bilinear", align_corners=False, + ) + warped = warped.permute(0, 2, 3, 1) + + # (T, H, W, C) → (B, C, T, H, W) + warped_tensor = warped.permute(3, 0, 1, 2).unsqueeze(0) + if batch_size > 1: + warped_tensor = warped_tensor.repeat(batch_size, 1, 1, 1, 1) + + warped_tensor = warped_tensor.to(comfy.model_management.intermediate_device()) + return io.NodeOutput({"samples": warped_tensor}) + + +class Noise_FromLatent: + """Wraps a pre-computed LATENT tensor as a NOISE source.""" + def __init__(self, latent_dict): + self.seed = 0 + self._samples = latent_dict["samples"] + + def generate_noise(self, input_latent): + return self._samples.clone().cpu() + + +class VOIDWarpedNoiseSource(io.ComfyNode): + """Convert a LATENT (e.g. from VOIDWarpedNoise) into a NOISE source + for use with SamplerCustomAdvanced.""" + + @classmethod + def define_schema(cls): + return io.Schema( + node_id="VOIDWarpedNoiseSource", + category="sampling/custom_sampling/noise", + inputs=[ + io.Latent.Input("warped_noise", + tooltip="Warped noise latent from VOIDWarpedNoise"), + ], + outputs=[io.Noise.Output()], + ) + + @classmethod + def execute(cls, warped_noise) -> io.NodeOutput: + return io.NodeOutput(Noise_FromLatent(warped_noise)) + + +class VOID_DDIM(comfy.samplers.Sampler): + """DDIM sampler for VOID inpainting models. + + VOID was trained with the diffusers CogVideoXDDIMScheduler which operates in + alpha-space (input std ≈ 1). The standard KSampler applies noise_scaling that + multiplies by sqrt(1+sigma^2) ≈ 4500x, which is incompatible with VOID's + training. 
This sampler skips noise_scaling and implements the DDIM update rule + directly using sigma-to-alpha conversion. + """ + + def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False): + x = noise.to(torch.float32) + model_options = extra_args.get("model_options", {}) + seed = extra_args.get("seed", None) + s_in = x.new_ones([x.shape[0]]) + + for i in trange(len(sigmas) - 1, disable=disable_pbar): + sigma = sigmas[i] + sigma_next = sigmas[i + 1] + + denoised = model_wrap(x, sigma * s_in, model_options=model_options, seed=seed) + + if callback is not None: + callback(i, denoised, x, len(sigmas) - 1) + + if sigma_next == 0: + x = denoised + else: + alpha_t = 1.0 / (1.0 + sigma ** 2) + alpha_prev = 1.0 / (1.0 + sigma_next ** 2) + + pred_eps = (x - (alpha_t ** 0.5) * denoised) / (1.0 - alpha_t) ** 0.5 + x = (alpha_prev ** 0.5) * denoised + (1.0 - alpha_prev) ** 0.5 * pred_eps + + return x + + +class VOIDSampler(io.ComfyNode): + """VOID DDIM sampler for use with SamplerCustom / SamplerCustomAdvanced. + + Required for VOID inpainting models. Implements the same DDIM loop that VOID + was trained with (diffusers CogVideoXDDIMScheduler), without the noise_scaling + that the standard KSampler applies. Use with RandomNoise or VOIDWarpedNoiseSource. + """ + + @classmethod + def define_schema(cls): + return io.Schema( + node_id="VOIDSampler", + category="sampling/custom_sampling/samplers", + inputs=[], + outputs=[io.Sampler.Output()], + ) + + @classmethod + def execute(cls) -> io.NodeOutput: + return io.NodeOutput(VOID_DDIM()) + + get_sampler = execute + + +class VOIDExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + OpticalFlowLoader, + VOIDQuadmaskPreprocess, + VOIDInpaintConditioning, + VOIDWarpedNoise, + VOIDWarpedNoiseSource, + VOIDSampler, + ] + + +async def comfy_entrypoint() -> VOIDExtension: + return VOIDExtension() diff --git a/comfy_extras/void_noise_warp.py b/comfy_extras/void_noise_warp.py new file mode 100644 index 000000000..fcc9a5f8b --- /dev/null +++ b/comfy_extras/void_noise_warp.py @@ -0,0 +1,494 @@ +""" +Optical-flow-warped noise for VOID Pass 2 refinement. + +Adapted from RyannDaGreat/CommonSource (MIT License, Ryan Burgert): + https://github.com/RyannDaGreat/CommonSource + - noise_warp.py (NoiseWarper / warp_xyωc / regaussianize / get_noise_from_video) + - raft.py (RaftOpticalFlow) + +Only the code paths that ``comfy_extras/nodes_void.py::VOIDWarpedNoise`` actually +uses (torch THWC uint8 input, no background removal, no visualization, no disk +I/O, default warp/noise params) have been inlined. External ``rp`` utilities +have been replaced with equivalents from torch.nn.functional / einops. The +RAFT optical-flow model itself is loaded offline via ``OpticalFlowLoader`` in +``nodes_void.py`` and passed into ``get_noise_from_video`` by the caller; this +module never downloads weights at runtime. +""" + +import logging +from typing import Optional + +import torch +import torch.nn.functional as F +from einops import rearrange + +import comfy.model_management + + +# --------------------------------------------------------------------------- +# Low-level torch image helpers (drop-in replacements for rp.torch_* primitives) +# --------------------------------------------------------------------------- + +def _torch_resize_chw(image, size, interp, copy=True): + """Resize a CHW tensor. + + ``size`` is either a scalar factor or a (h, w) tuple. 
``interp`` is one + of ``"bilinear"``, ``"nearest"``, ``"area"``. When ``copy`` is False and + the requested size matches the input, returns the input tensor as is + (faster but callers must not mutate the result). + """ + if image.ndim != 3: + raise ValueError( + f"_torch_resize_chw expects a 3D CHW tensor, got shape {tuple(image.shape)}" + ) + _, in_h, in_w = image.shape + if isinstance(size, (int, float)) and not isinstance(size, bool): + new_h = max(1, int(in_h * size)) + new_w = max(1, int(in_w * size)) + else: + new_h, new_w = size + + if (new_h, new_w) == (in_h, in_w): + return image.clone() if copy else image + + kwargs = {} + if interp in ("bilinear", "bicubic"): + kwargs["align_corners"] = False + out = F.interpolate(image[None], size=(new_h, new_w), mode=interp, **kwargs)[0] + return out + + +def _torch_remap_relative(image, dx, dy, interp="bilinear"): + """Relative remap of a CHW image via ``F.grid_sample``. + + Equivalent to ``rp.torch_remap_image(image, dx, dy, relative=True, interp=interp)`` + for ``interp`` in {"bilinear", "nearest"}. Out-of-bounds samples are 0. + """ + if image.ndim != 3: + raise ValueError( + f"_torch_remap_relative expects a 3D CHW tensor, got shape {tuple(image.shape)}" + ) + if dx.shape != dy.shape: + raise ValueError( + f"_torch_remap_relative: dx and dy must match, got {tuple(dx.shape)} vs {tuple(dy.shape)}" + ) + _, h, w = image.shape + + x_abs = dx + torch.arange(w, device=dx.device, dtype=dx.dtype) + y_abs = dy + torch.arange(h, device=dy.device, dtype=dy.dtype)[:, None] + + x_norm = (x_abs / (w - 1)) * 2 - 1 + y_norm = (y_abs / (h - 1)) * 2 - 1 + + grid = torch.stack([x_norm, y_norm], dim=-1)[None].to(image.dtype) + out = F.grid_sample( + image[None], grid, mode=interp, align_corners=True, padding_mode="zeros" + )[0] + return out + + +def _torch_scatter_add_relative(image, dx, dy): + """Scatter-add a CHW image using relative floor-rounded (dx, dy) offsets. + + Equivalent to ``rp.torch_scatter_add_image(image, dx, dy, relative=True, + interp='floor')``. Out-of-bounds targets are dropped. + """ + if image.ndim != 3: + raise ValueError( + f"_torch_scatter_add_relative expects a 3D CHW tensor, got shape {tuple(image.shape)}" + ) + in_c, in_h, in_w = image.shape + if dx.shape != (in_h, in_w) or dy.shape != (in_h, in_w): + raise ValueError( + f"_torch_scatter_add_relative: dx/dy must be ({in_h}, {in_w}), " + f"got dx={tuple(dx.shape)} dy={tuple(dy.shape)}" + ) + + x = dx.long() + torch.arange(in_w, device=dx.device, dtype=torch.long) + y = dy.long() + torch.arange(in_h, device=dy.device, dtype=torch.long)[:, None] + + valid = ((y >= 0) & (y < in_h) & (x >= 0) & (x < in_w)).reshape(-1) + indices = (y * in_w + x).reshape(-1)[valid] + + flat_image = rearrange(image, "c h w -> (h w) c")[valid] + out = torch.zeros((in_h * in_w, in_c), dtype=image.dtype, device=image.device) + out.index_add_(0, indices, flat_image) + return rearrange(out, "(h w) c -> c h w", h=in_h, w=in_w) + + +# --------------------------------------------------------------------------- +# Noise warping primitives (ported from noise_warp.py) +# --------------------------------------------------------------------------- + +def unique_pixels(image): + """Find unique pixel values in a CHW tensor. + + Returns ``(unique_colors [U, C], counts [U], index_matrix [H, W])`` where + ``index_matrix[i, j]`` is the index of the unique color at that pixel. 
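+
+    Small example (one valid ordering; sorted=False leaves the order of
+    unique_colors unspecified):
+        image = tensor([[[5., 5.], [5., 7.]]])   # shape (1, 2, 2)
+        -> unique_colors [[5.], [7.]], counts [3, 1],
+           index_matrix [[0, 0], [0, 1]]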
+ """ + _, h, w = image.shape + flat = rearrange(image, "c h w -> (h w) c") + unique_colors, inverse_indices, counts = torch.unique( + flat, dim=0, return_inverse=True, return_counts=True, sorted=False, + ) + index_matrix = rearrange(inverse_indices, "(h w) -> h w", h=h, w=w) + return unique_colors, counts, index_matrix + + +def sum_indexed_values(image, index_matrix): + """For each unique index, sum the CHW image values at its pixels.""" + _, h, w = image.shape + u = int(index_matrix.max().item()) + 1 + flat = rearrange(image, "c h w -> (h w) c") + out = torch.zeros((u, flat.shape[1]), dtype=flat.dtype, device=flat.device) + out.index_add_(0, index_matrix.view(-1), flat) + return out + + +def indexed_to_image(index_matrix, unique_colors): + """Build a CHW image from an index matrix and a (U, C) color table.""" + h, w = index_matrix.shape + flat = unique_colors[index_matrix.view(-1)] + return rearrange(flat, "(h w) c -> c h w", h=h, w=w) + + +def regaussianize(noise): + """Variance-preserving re-sampling of a CHW noise tensor. + + Wherever the noise contains groups of identical pixel values (e.g. after + a nearest-neighbor warp that duplicated source pixels), adds zero-mean + foreign noise within each group and scales by ``1/sqrt(count)`` so the + output is unit-variance gaussian again. + """ + _, hs, ws = noise.shape + _, counts, index_matrix = unique_pixels(noise[:1]) + + foreign_noise = torch.randn_like(noise) + summed = sum_indexed_values(foreign_noise, index_matrix) + meaned = indexed_to_image(index_matrix, summed / rearrange(counts, "u -> u 1")) + zeroed_foreign = foreign_noise - meaned + + counts_image = indexed_to_image(index_matrix, rearrange(counts, "u -> u 1")) + + output = noise / counts_image ** 0.5 + zeroed_foreign + return output, counts_image + + +def xy_meshgrid_like_image(image): + """Return a (2, H, W) tensor of (x, y) pixel coordinates matching ``image``.""" + _, h, w = image.shape + y, x = torch.meshgrid( + torch.arange(h, device=image.device, dtype=image.dtype), + torch.arange(w, device=image.device, dtype=image.dtype), + indexing="ij", + ) + return torch.stack([x, y]) + + +def noise_to_state(noise): + """Pack a (C, H, W) noise tensor into a state tensor (3+C, H, W) = [dx, dy, ω, noise].""" + zeros = torch.zeros_like(noise[:1]) + ones = torch.ones_like(noise[:1]) + return torch.cat([zeros, zeros, ones, noise]) + + +def state_to_noise(state): + """Unpack the noise channels from a state tensor.""" + return state[3:] + + +def warp_state(state, flow): + """Warp a noise-warper state tensor along the given optical flow. + + ``state`` has shape ``(3+c, h, w)`` (= dx, dy, ω, c noise channels). + ``flow`` has shape ``(2, h, w)`` (= dx, dy). 
+ """ + if flow.device != state.device: + raise ValueError( + f"warp_state: flow and state must be on the same device, " + f"got flow={flow.device} state={state.device}" + ) + if state.ndim != 3: + raise ValueError( + f"warp_state: state must be 3D (3+C, H, W), got shape {tuple(state.shape)}" + ) + xyoc, h, w = state.shape + if flow.shape != (2, h, w): + raise ValueError( + f"warp_state: flow must have shape (2, {h}, {w}), got {tuple(flow.shape)}" + ) + device = state.device + + x_ch, y_ch = 0, 1 + xy = 2 # state[:xy] = [dx, dy] + xyw = 3 # state[:xyw] = [dx, dy, ω] + w_ch = 2 # state[w_ch] = ω + c = xyoc - xyw + oc = xyoc - xy + if c <= 0: + raise ValueError( + f"warp_state: state has no noise channels (expected 3+C with C>0, got {xyoc} channels)" + ) + if not (state[w_ch] > 0).all(): + raise ValueError("warp_state: all weights in state[2] must be > 0") + + grid = xy_meshgrid_like_image(state) + + init = torch.empty_like(state) + init[:xy] = 0 + init[w_ch] = 1 + init[-c:] = 0 + + # --- Expansion branch: nearest-neighbor remap with negated flow --- + pre_expand = torch.empty_like(state) + pre_expand[:xy] = _torch_remap_relative(state[:xy], -flow[0], -flow[1], "nearest") + pre_expand[-oc:] = _torch_remap_relative(state[-oc:], -flow[0], -flow[1], "nearest") + pre_expand[w_ch][pre_expand[w_ch] == 0] = 1 + + # --- Shrink branch: scatter-add state into new positions --- + pre_shrink = state.clone() + pre_shrink[:xy] += flow + + pos = (grid + pre_shrink[:xy]).round() + in_bounds = (pos[x_ch] >= 0) & (pos[x_ch] < w) & (pos[y_ch] >= 0) & (pos[y_ch] < h) + pre_shrink = torch.where(~in_bounds[None], init, pre_shrink) + + scat_xy = pre_shrink[:xy].round() + pre_shrink[:xy] -= scat_xy + pre_shrink[:xy] = 0 # xy_mode='none' in upstream + + def scat(tensor): + return _torch_scatter_add_relative(tensor, scat_xy[0], scat_xy[1]) + + # rp.torch_scatter_add_image on a bool tensor errors on modern torch; + # scatter-sum a float ones tensor and threshold to get the mask instead. + shrink_mask = scat(torch.ones(1, h, w, dtype=state.dtype, device=device)) > 0 + + # Drop expansion samples at positions that will be filled by shrink. + pre_expand = torch.where(shrink_mask, init, pre_expand) + + # Regaussianize both branches together so duplicated-source groups are + # counted globally, then split back apart. + concat = torch.cat([pre_shrink, pre_expand], dim=2) # along width + concat[-c:], counts_image = regaussianize(concat[-c:]) + concat[w_ch] = concat[w_ch] / counts_image[0] + concat[w_ch] = concat[w_ch].nan_to_num() + pre_shrink, expand = torch.chunk(concat, chunks=2, dim=2) + + shrink = torch.empty_like(pre_shrink) + shrink[w_ch] = scat(pre_shrink[w_ch][None])[0] + shrink[:xy] = scat(pre_shrink[:xy] * pre_shrink[w_ch][None]) / shrink[w_ch][None] + shrink[-c:] = scat(pre_shrink[-c:] * pre_shrink[w_ch][None]) / scat( + pre_shrink[w_ch][None] ** 2 + ).sqrt() + + output = torch.where(shrink_mask, shrink, expand) + output[w_ch] = output[w_ch] / output[w_ch].mean() + output[w_ch] += 1e-5 + output[w_ch] **= 0.9999 + return output + + +class NoiseWarper: + """Maintain a warpable noise state and emit gaussian noise per frame. + + Simplified from RyannDaGreat/CommonSource/noise_warp.py::NoiseWarper: + ``scale_factor``, ``post_noise_alpha``, ``progressive_noise_alpha``, and + ``warp_kwargs`` are all dropped since VOIDWarpedNoise always uses defaults. 
+ """ + + def __init__(self, c, h, w, device, dtype=torch.float32): + if c <= 0 or h <= 0 or w <= 0: + raise ValueError( + f"NoiseWarper: c/h/w must all be positive, got c={c} h={h} w={w}" + ) + self.c = c + self.h = h + self.w = w + self.device = device + self.dtype = dtype + + noise = torch.randn(c, h, w, dtype=dtype, device=device) + self._state = noise_to_state(noise) + + @property + def noise(self): + # With scale_factor=1 the "downsample to respect weights" step is a + # size-preserving no-op; the weight-variance correction math still + # runs to stay faithful to upstream. + n = state_to_noise(self._state) + weights = self._state[2:3] + return n * weights / (weights ** 2).sqrt() + + def __call__(self, dx, dy): + if dx.shape != dy.shape: + raise ValueError( + f"NoiseWarper: dx and dy must match, got {tuple(dx.shape)} vs {tuple(dy.shape)}" + ) + flow = torch.stack([dx, dy]).to(self.device, self.dtype) + _, oflowh, ofloww = flow.shape + + flow = _torch_resize_chw(flow, (self.h, self.w), "bilinear", copy=True) + flowh, floww = flow.shape[-2:] + + # Upstream scales flow[0] by flowh/oflowh and flow[1] by floww/ofloww + # (channel-order appears swapped but harmless when H and W are scaled + # by the same factor, which is always the case for our callers). + flow[0] *= flowh / oflowh + flow[1] *= floww / ofloww + + self._state = warp_state(self._state, flow) + return self + + +# --------------------------------------------------------------------------- +# RAFT optical flow wrapper (ported from raft.py) +# --------------------------------------------------------------------------- + +class RaftOpticalFlow: + """RAFT-large wrapper around a pre-loaded torchvision model. + + ``model`` must be the ``torchvision.models.optical_flow.raft_large`` module + with its weights already populated; this class is load-agnostic so the + caller owns downloading/offload concerns (see ``OpticalFlowLoader`` in + ``nodes_void.py``). ``__call__`` returns a ``(2, H, W)`` flow. 
+ """ + + def __init__(self, model, device=None): + if device is None: + device = comfy.model_management.get_torch_device() + device = torch.device(device) if not isinstance(device, torch.device) else device + + model = model.to(device) + model.eval() + self.device = device + self.model = model + + def _preprocess(self, image_chw): + image = image_chw.to(self.device, torch.float32) + _, h, w = image.shape + new_h = (h // 8) * 8 + new_w = (w // 8) * 8 + image = _torch_resize_chw(image, (new_h, new_w), "bilinear", copy=False) + image = image * 2 - 1 + return image[None] + + def __call__(self, from_image, to_image): + """``from_image``, ``to_image``: CHW float tensors in [0, 1].""" + if from_image.shape != to_image.shape: + raise ValueError( + f"RaftOpticalFlow: from_image and to_image must match, " + f"got {tuple(from_image.shape)} vs {tuple(to_image.shape)}" + ) + _, h, w = from_image.shape + with torch.no_grad(): + img1 = self._preprocess(from_image) + img2 = self._preprocess(to_image) + list_of_flows = self.model(img1, img2) + flow = list_of_flows[-1][0] # (2, new_h, new_w) + if flow.shape[-2:] != (h, w): + flow = _torch_resize_chw(flow, (h, w), "bilinear", copy=False) + return flow + + +# --------------------------------------------------------------------------- +# Narrow entry point used by VOIDWarpedNoise +# --------------------------------------------------------------------------- + +def get_noise_from_video( + video_frames: torch.Tensor, + raft: RaftOpticalFlow, + *, + noise_channels: int = 16, + resize_frames: float = 0.5, + resize_flow: int = 8, + downscale_factor: int = 32, + device: Optional[torch.device] = None, +) -> torch.Tensor: + """Produce optical-flow-warped gaussian noise from a video. + + Args: + video_frames: ``(T, H, W, 3)`` uint8 torch tensor. + raft: Pre-loaded RAFT optical-flow wrapper (see ``RaftOpticalFlow``). + noise_channels: Channels in the output noise. + resize_frames: Pre-RAFT frame scale factor. + resize_flow: Post-flow up-scale factor applied to the optical flow; + the internal noise state is allocated at + ``(resize_flow * resize_frames * H, resize_flow * resize_frames * W)``. + downscale_factor: Area-pool factor applied to the noise before return; + should evenly divide the internal noise resolution. + device: Target device. Defaults to ``comfy.model_management.get_torch_device()``. + + Returns: + ``(T, H', W', noise_channels)`` float32 noise tensor on ``device``. + """ + if not isinstance(resize_flow, int) or resize_flow < 1: + raise ValueError( + f"get_noise_from_video: resize_flow must be a positive int, got {resize_flow!r}" + ) + if video_frames.ndim != 4 or video_frames.shape[-1] != 3: + raise ValueError( + "get_noise_from_video: video_frames must have shape (T, H, W, 3), " + f"got {tuple(video_frames.shape)}" + ) + if video_frames.dtype != torch.uint8: + raise TypeError( + "get_noise_from_video: video_frames must be uint8 in [0, 255], " + f"got dtype {video_frames.dtype}" + ) + + if device is None: + device = comfy.model_management.get_torch_device() + device = torch.device(device) if not isinstance(device, torch.device) else device + + if device.type == "cpu": + logging.warning( + "VOIDWarpedNoise: running get_noise_from_video on CPU; this will be " + "slow (minutes for ~45 frames). Use CUDA for interactive use." 
+ ) + + T = video_frames.shape[0] + frames = video_frames.to(device).permute(0, 3, 1, 2).to(torch.float32) / 255.0 + if resize_frames != 1.0: + new_h = max(1, int(frames.shape[2] * resize_frames)) + new_w = max(1, int(frames.shape[3] * resize_frames)) + frames = F.interpolate(frames, size=(new_h, new_w), mode="area") + + _, _, H, W = frames.shape + internal_h = resize_flow * H + internal_w = resize_flow * W + if internal_h % downscale_factor or internal_w % downscale_factor: + logging.warning( + "VOIDWarpedNoise: internal noise size %dx%d is not divisible by " + "downscale_factor %d; output noise may have artifacts.", + internal_h, internal_w, downscale_factor, + ) + + with torch.no_grad(): + warper = NoiseWarper( + c=noise_channels, h=internal_h, w=internal_w, device=device, + ) + down_h = warper.h // downscale_factor + down_w = warper.w // downscale_factor + output = torch.empty( + (T, down_h, down_w, noise_channels), dtype=torch.float32, device=device, + ) + + def downscale(noise_chw): + # Area-pool to 1/downscale_factor then multiply by downscale_factor + # to adjust std (sqrt of pool area == downscale_factor for a + # square pool). + down = _torch_resize_chw(noise_chw, 1.0 / downscale_factor, "area", copy=False) + return down * downscale_factor + + output[0] = downscale(warper.noise).permute(1, 2, 0) + + prev = frames[0] + for i in range(1, T): + curr = frames[i] + flow = raft(prev, curr).to(device) + warper(flow[0], flow[1]) + output[i] = downscale(warper.noise).permute(1, 2, 0) + prev = curr + + return output diff --git a/folder_paths.py b/folder_paths.py index 039f72636..98d3b1880 100644 --- a/folder_paths.py +++ b/folder_paths.py @@ -54,6 +54,8 @@ folder_names_and_paths["audio_encoders"] = ([os.path.join(models_dir, "audio_enc folder_names_and_paths["frame_interpolation"] = ([os.path.join(models_dir, "frame_interpolation")], supported_pt_extensions) +folder_names_and_paths["optical_flow"] = ([os.path.join(models_dir, "optical_flow")], supported_pt_extensions) + output_directory = os.path.join(base_path, "output") temp_directory = os.path.join(base_path, "temp") input_directory = os.path.join(base_path, "input") diff --git a/models/optical_flow/put_optical_flow_models_here b/models/optical_flow/put_optical_flow_models_here new file mode 100644 index 000000000..e69de29bb diff --git a/nodes.py b/nodes.py index cf61d9df0..ad0cbc675 100644 --- a/nodes.py +++ b/nodes.py @@ -958,7 +958,7 @@ class CLIPLoader: @classmethod def INPUT_TYPES(s): return {"required": { "clip_name": (folder_paths.get_filename_list("text_encoders"), ), - "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image", "flux2", "ovis", "longcat_image"], ), + "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image", "flux2", "ovis", "longcat_image", "cogvideox"], ), }, "optional": { "device": (["default", "cpu"], {"advanced": True}), @@ -968,7 +968,7 @@ class CLIPLoader: CATEGORY = "advanced/loaders" - DESCRIPTION = "[Recipes]\n\nstable_diffusion: clip-l\nstable_cascade: clip-g\nsd3: t5 xxl/ clip-g / clip-l\nstable_audio: t5 base\nmochi: t5 xxl\ncosmos: old t5 xxl\nlumina2: gemma 2 2B\nwan: umt5 xxl\n hidream: llama-3.1 (Recommend) or t5\nomnigen2: qwen vl 2.5 3B" + DESCRIPTION = "[Recipes]\n\nstable_diffusion: clip-l\nstable_cascade: clip-g\nsd3: 
t5 xxl/ clip-g / clip-l\nstable_audio: t5 base\nmochi: t5 xxl\ncogvideox: t5 xxl (226-token padding)\ncosmos: old t5 xxl\nlumina2: gemma 2 2B\nwan: umt5 xxl\n hidream: llama-3.1 (Recommend) or t5\nomnigen2: qwen vl 2.5 3B" def load_clip(self, clip_name, type="stable_diffusion", device="default"): clip_type = getattr(comfy.sd.CLIPType, type.upper(), comfy.sd.CLIPType.STABLE_DIFFUSION) @@ -2430,6 +2430,7 @@ async def init_builtin_extra_nodes(): "nodes_rtdetr.py", "nodes_frame_interpolation.py", "nodes_sam3.py", + "nodes_void.py", ] import_failed = [] From 9c34f5f36a3815af7d21d8b42b0a5776b7406685 Mon Sep 17 00:00:00 2001 From: Comfy Org PR Bot Date: Wed, 6 May 2026 14:22:48 +0900 Subject: [PATCH 078/102] Bump comfyui-frontend-package to 1.43.17 (#13723) Co-authored-by: github-actions[bot] Co-authored-by: Alexander Brown --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index e9415f2fd..e7aa92c31 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -comfyui-frontend-package==1.42.15 +comfyui-frontend-package==1.43.17 comfyui-workflow-templates==0.9.69 comfyui-embedded-docs==0.4.4 torch From 6bcd8b96ab4650db6e834dcbb54357ebf72edfe6 Mon Sep 17 00:00:00 2001 From: guill Date: Wed, 6 May 2026 10:08:35 -0700 Subject: [PATCH 079/102] Revert "Fix Content-Disposition header missing 'attachment;' prefix (#13093)" (#13733) This reverts commit ea6880b04b88629b9dd07774298bdffea6923f9b. --- server.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/server.py b/server.py index 0e85635d3..2f3b438bb 100644 --- a/server.py +++ b/server.py @@ -560,7 +560,7 @@ class PromptServer(): buffer.seek(0) return web.Response(body=buffer.read(), content_type=f'image/{image_format}', - headers={"Content-Disposition": f"attachment; filename=\"{filename}\""}) + headers={"Content-Disposition": f"filename=\"{filename}\""}) if 'channel' not in request.rel_url.query: channel = 'rgba' @@ -580,7 +580,7 @@ class PromptServer(): buffer.seek(0) return web.Response(body=buffer.read(), content_type='image/png', - headers={"Content-Disposition": f"attachment; filename=\"{filename}\""}) + headers={"Content-Disposition": f"filename=\"{filename}\""}) elif channel == 'a': with Image.open(file) as img: @@ -597,7 +597,7 @@ class PromptServer(): alpha_buffer.seek(0) return web.Response(body=alpha_buffer.read(), content_type='image/png', - headers={"Content-Disposition": f"attachment; filename=\"{filename}\""}) + headers={"Content-Disposition": f"filename=\"{filename}\""}) else: # Use the content type from asset resolution if available, # otherwise guess from the filename. 
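
For context on the hunks above and below: the Content-Disposition header is what decides whether a browser renders a response or downloads it. With the "attachment;" prefix the file is always downloaded; with only a filename hint, browsers generally display known image types inline, which is the behavior this revert restores. A minimal sketch of the two variants as a hypothetical aiohttp handler (not part of this patch):

import os
from aiohttp import web

async def serve_image(request: web.Request) -> web.FileResponse:
    filename = "preview.png"  # hypothetical file under the output directory
    if request.rel_url.query.get("download") == "1":
        # Force a save dialog regardless of content type.
        disposition = f'attachment; filename="{filename}"'
    else:
        # Filename hint only; browsers typically render the image inline.
        disposition = f'filename="{filename}"'
    return web.FileResponse(
        os.path.join("output", filename),
        headers={"Content-Disposition": disposition},
    )
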
@@ -614,7 +614,7 @@ class PromptServer(): return web.FileResponse( file, headers={ - "Content-Disposition": f"attachment; filename=\"{filename}\"", + "Content-Disposition": f"filename=\"{filename}\"", "Content-Type": content_type } ) From cd8c7a2306be98bf93cd6632384a675afe750a55 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jukka=20Sepp=C3=A4nen?= <40791699+kijai@users.noreply.github.com> Date: Thu, 7 May 2026 05:41:13 +0300 Subject: [PATCH 080/102] Throttle dynamic VRAM prepare logging (#13704) --- comfy/model_patcher.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 7d2d6883f..33bdedfb1 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -26,6 +26,7 @@ import uuid from typing import Callable, Optional import torch +import tqdm import comfy.float import comfy.hooks @@ -1651,7 +1652,11 @@ class ModelPatcherDynamic(ModelPatcher): self.model.model_loaded_weight_memory += casted_buf.numel() * casted_buf.element_size() force_load_stat = f" Force pre-loaded {len(self.backup)} weights: {self.model.model_loaded_weight_memory // 1024} KB." if len(self.backup) > 0 else "" - logging.info(f"Model {self.model.__class__.__name__} prepared for dynamic VRAM loading. {allocated_size // (1024 ** 2)}MB Staged. {num_patches} patches attached.{force_load_stat}") + log_key = (self.patches_uuid, allocated_size, num_patches, len(self.backup), self.model.model_loaded_weight_memory) + in_loop = bool(getattr(tqdm.tqdm, "_instances", None)) + level = logging.DEBUG if in_loop and getattr(self, "_last_prepare_log_key", None) == log_key else logging.INFO + self._last_prepare_log_key = log_key + logging.log(level, f"Model {self.model.__class__.__name__} prepared for dynamic VRAM loading. {allocated_size // (1024 ** 2)}MB Staged. {num_patches} patches attached.{force_load_stat}") self.model.device = device_to self.model.current_weight_patches_uuid = self.patches_uuid From e35348aa53563cabdcd9e5f67d0cb77b5259c903 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 6 May 2026 19:51:01 -0700 Subject: [PATCH 081/102] Add .comfy_environment to portable. (#13746) --- .github/workflows/stable-release.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/workflows/stable-release.yml b/.github/workflows/stable-release.yml index f501b7b31..bc64ed74d 100644 --- a/.github/workflows/stable-release.yml +++ b/.github/workflows/stable-release.yml @@ -145,6 +145,8 @@ jobs: cp -r ComfyUI/.ci/windows_${{ inputs.rel_name }}_base_files/* ./ cp ../update_comfyui_and_python_dependencies.bat ./update/ + echo 'local-portable' > ComfyUI/.comfy_environment + cd .. 
"C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma2 -mx=9 -mfb=128 -md=768m -ms=on -mf=BCJ2 ComfyUI_windows_portable.7z ComfyUI_windows_portable From 1b25f1289e6f48081b727083425791876ed0f39b Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Thu, 7 May 2026 09:45:59 +0300 Subject: [PATCH 082/102] [Partner Nodes] add grok-imagine-image-quality model (#13725) * feat(api-nodes): add grok-imagine-image-quality model Signed-off-by: bigcat88 * fixed price badges Signed-off-by: bigcat88 * fix: adjust price badges Signed-off-by: bigcat88 --------- Signed-off-by: bigcat88 Co-authored-by: Jedrzej Kosinski --- comfy_api_nodes/nodes_grok.py | 34 +++++++++++++++++++++++++++------- 1 file changed, 27 insertions(+), 7 deletions(-) diff --git a/comfy_api_nodes/nodes_grok.py b/comfy_api_nodes/nodes_grok.py index f42d84616..dd5d7e249 100644 --- a/comfy_api_nodes/nodes_grok.py +++ b/comfy_api_nodes/nodes_grok.py @@ -54,7 +54,12 @@ class GrokImageNode(IO.ComfyNode): inputs=[ IO.Combo.Input( "model", - options=["grok-imagine-image-pro", "grok-imagine-image", "grok-imagine-image-beta"], + options=[ + "grok-imagine-image-quality", + "grok-imagine-image-pro", + "grok-imagine-image", + "grok-imagine-image-beta", + ], ), IO.String.Input( "prompt", @@ -111,10 +116,12 @@ class GrokImageNode(IO.ComfyNode): ], is_api_node=True, price_badge=IO.PriceBadge( - depends_on=IO.PriceBadgeDepends(widgets=["model", "number_of_images"]), + depends_on=IO.PriceBadgeDepends(widgets=["model", "number_of_images", "resolution"]), expr=""" ( - $rate := $contains(widgets.model, "pro") ? 0.07 : 0.02; + $rate := widgets.model = "grok-imagine-image-quality" + ? (widgets.resolution = "1k" ? 0.05 : 0.07) + : ($contains(widgets.model, "pro") ? 0.07 : 0.02); {"type":"usd","usd": $rate * widgets.number_of_images} ) """, @@ -167,7 +174,12 @@ class GrokImageEditNode(IO.ComfyNode): inputs=[ IO.Combo.Input( "model", - options=["grok-imagine-image-pro", "grok-imagine-image", "grok-imagine-image-beta"], + options=[ + "grok-imagine-image-quality", + "grok-imagine-image-pro", + "grok-imagine-image", + "grok-imagine-image-beta", + ], ), IO.Image.Input("image", display_name="images"), IO.String.Input( @@ -228,11 +240,19 @@ class GrokImageEditNode(IO.ComfyNode): ], is_api_node=True, price_badge=IO.PriceBadge( - depends_on=IO.PriceBadgeDepends(widgets=["model", "number_of_images"]), + depends_on=IO.PriceBadgeDepends(widgets=["model", "number_of_images", "resolution"]), expr=""" ( - $rate := $contains(widgets.model, "pro") ? 0.07 : 0.02; - {"type":"usd","usd": 0.002 + $rate * widgets.number_of_images} + $isQualityModel := widgets.model = "grok-imagine-image-quality"; + $isPro := $contains(widgets.model, "pro"); + $rate := $isQualityModel + ? (widgets.resolution = "1k" ? 0.05 : 0.07) + : ($isPro ? 0.07 : 0.02); + $base := $isQualityModel ? 0.01 : 0.002; + $output := $rate * widgets.number_of_images; + $isPro + ? 
{"type":"usd","usd": $base + $output} + : {"type":"range_usd","min_usd": $base + $output, "max_usd": 3 * $base + $output} ) """, ), From 25757a53c93281e8e2462ced8795373f09e675bf Mon Sep 17 00:00:00 2001 From: "Daxiong (Lin)" Date: Thu, 7 May 2026 16:28:18 +0900 Subject: [PATCH 083/102] chore: update workflow templates to v0.9.72 (#13732) Co-authored-by: Jedrzej Kosinski --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index e7aa92c31..5c7ff76be 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.43.17 -comfyui-workflow-templates==0.9.69 +comfyui-workflow-templates==0.9.72 comfyui-embedded-docs==0.4.4 torch torchsde From c945a433ae09423f7a2a6e9631538e55b9375f78 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Thu, 7 May 2026 21:55:09 +0300 Subject: [PATCH 084/102] fix(api-nodes): fixed price badge for Kling V3 model in the Motion Control node (#13790) Signed-off-by: bigcat88 --- comfy_api_nodes/nodes_kling.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/comfy_api_nodes/nodes_kling.py b/comfy_api_nodes/nodes_kling.py index efd58fac3..7586f1816 100644 --- a/comfy_api_nodes/nodes_kling.py +++ b/comfy_api_nodes/nodes_kling.py @@ -2787,11 +2787,15 @@ class MotionControl(IO.ComfyNode): ], is_api_node=True, price_badge=IO.PriceBadge( - depends_on=IO.PriceBadgeDepends(widgets=["mode"]), + depends_on=IO.PriceBadgeDepends(widgets=["mode", "model"]), expr=""" ( - $prices := {"std": 0.07, "pro": 0.112}; - {"type":"usd","usd": $lookup($prices, widgets.mode), "format":{"suffix":"/second"}} + $prices := { + "kling-v3": {"std": 0.126, "pro": 0.168}, + "kling-v2-6": {"std": 0.07, "pro": 0.112} + }; + $modelPrices := $lookup($prices, widgets.model); + {"type":"usd","usd": $lookup($modelPrices, widgets.mode), "format":{"suffix":"/second"}} ) """, ), From c011fb520c79b9dfbe7f885d613771774f746eef Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Thu, 7 May 2026 22:19:44 +0300 Subject: [PATCH 085/102] [Partner Nodes] new NanoBanana2 node with DynamicCombo/Autogrow (#13753) * feat(api-nodes): new NanoBanana2 node with DynamicCombo/Autogrow Signed-off-by: bigcat88 * feat: improved status text on uploading Signed-off-by: bigcat88 * feat: improved status text on uploading (2) Signed-off-by: bigcat88 --------- Signed-off-by: bigcat88 --- comfy_api_nodes/nodes_gemini.py | 242 +++++++++++++++++++++++++++++--- 1 file changed, 222 insertions(+), 20 deletions(-) diff --git a/comfy_api_nodes/nodes_gemini.py b/comfy_api_nodes/nodes_gemini.py index 2b77a022e..d18c958a8 100644 --- a/comfy_api_nodes/nodes_gemini.py +++ b/comfy_api_nodes/nodes_gemini.py @@ -83,13 +83,16 @@ class GeminiImageModel(str, Enum): async def create_image_parts( cls: type[IO.ComfyNode], - images: Input.Image, + images: Input.Image | list[Input.Image], image_limit: int = 0, ) -> list[GeminiPart]: image_parts: list[GeminiPart] = [] if image_limit < 0: raise ValueError("image_limit must be greater than or equal to 0 when creating Gemini image parts.") - total_images = get_number_of_images(images) + + # Accept either a single (possibly-batched) tensor or a list of them; share URL budget across all. 
+ images_list: list[Input.Image] = images if isinstance(images, list) else [images]
+ total_images = sum(get_number_of_images(img) for img in images_list)
 if total_images <= 0:
 raise ValueError("No images provided to create_image_parts; at least one image is required.")

@@ -98,10 +101,18 @@
 # Number of images we'll send as URLs (fileData)
 num_url_images = min(effective_max, 10) # Vertex API max number of image links

+ upload_kwargs: dict = {"wait_label": "Uploading reference images"}
+ if effective_max > num_url_images:
+ # Split path (e.g. 11+ images): suppress per-image counter to avoid a confusing dual-fraction label.
+ upload_kwargs = {
+ "wait_label": f"Uploading reference images ({num_url_images}+)",
+ "show_batch_index": False,
+ }
 reference_images_urls = await upload_images_to_comfyapi(
 cls,
- images,
+ images_list,
 max_images=num_url_images,
+ **upload_kwargs,
 )
 for reference_image_url in reference_images_urls:
 image_parts.append(
@@ -112,15 +123,22 @@
 )
 )
 )

- for idx in range(num_url_images, effective_max):
- image_parts.append(
- GeminiPart(
- inlineData=GeminiInlineData(
- mimeType=GeminiMimeType.image_png,
- data=tensor_to_base64_string(images[idx]),
+ if effective_max > num_url_images:
+ flat: list[torch.Tensor] = []
+ for tensor in images_list:
+ if len(tensor.shape) == 4:
+ flat.extend(tensor[i] for i in range(tensor.shape[0]))
+ else:
+ flat.append(tensor)
+ for idx in range(num_url_images, effective_max):
+ image_parts.append(
+ GeminiPart(
+ inlineData=GeminiInlineData(
+ mimeType=GeminiMimeType.image_png,
+ data=tensor_to_base64_string(flat[idx]),
+ )
 )
 )
- )
 return image_parts

@@ -891,10 +909,6 @@ class GeminiNanoBanana2(IO.ComfyNode):
 "9:16",
 "16:9",
 "21:9",
- # "1:4",
- # "4:1",
- # "8:1",
- # "1:8",
 ],
 default="auto",
 tooltip="If set to 'auto', matches your input image's aspect ratio; "
@@ -902,12 +916,7 @@
 ),
 IO.Combo.Input(
 "resolution",
- options=[
- # "512px",
- "1K",
- "2K",
- "4K",
- ],
+ options=["1K", "2K", "4K"],
 tooltip="Target output resolution. For 2K/4K the native Gemini upscaler is used.",
 ),
 IO.Combo.Input(
@@ -956,6 +965,7 @@
 ],
 is_api_node=True,
 price_badge=GEMINI_IMAGE_2_PRICE_BADGE,
+ is_deprecated=True,
 )

 @classmethod
@@ -1016,6 +1026,197 @@
 )


+def _nano_banana_2_v2_model_inputs():
+ return [
+ IO.Combo.Input(
+ "aspect_ratio",
+ options=[
+ "auto",
+ "1:1",
+ "2:3",
+ "3:2",
+ "3:4",
+ "4:3",
+ "4:5",
+ "5:4",
+ "9:16",
+ "16:9",
+ "21:9",
+ "1:4",
+ "4:1",
+ "8:1",
+ "1:8",
+ ],
+ default="auto",
+ tooltip="If set to 'auto', matches your input image's aspect ratio; "
+ "if no image is provided, a 1:1 square is usually generated.",
+ ),
+ IO.Combo.Input(
+ "resolution",
+ options=["1K", "2K", "4K"],
+ tooltip="Target output resolution. For 2K/4K the native Gemini upscaler is used.",
+ ),
+ IO.Combo.Input(
+ "thinking_level",
+ options=["MINIMAL", "HIGH"],
+ ),
+ IO.Autogrow.Input(
+ "images",
+ template=IO.Autogrow.TemplateNames(
+ IO.Image.Input("image"),
+ names=[f"image_{i}" for i in range(1, 15)],
+ min=0,
+ ),
+ tooltip="Optional reference image(s). Up to 14 images total.",
+ ),
+ IO.Custom("GEMINI_INPUT_FILES").Input(
+ "files",
+ optional=True,
+ tooltip="Optional file(s) to use as context for the model. 
" + "Accepts inputs from the Gemini Generate Content Input Files node.", + ), + ] + + +class GeminiNanoBanana2V2(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="GeminiNanoBanana2V2", + display_name="Nano Banana 2", + category="api node/image/Gemini", + description="Generate or edit images synchronously via Google Vertex API.", + inputs=[ + IO.String.Input( + "prompt", + multiline=True, + tooltip="Text prompt describing the image to generate or the edits to apply. " + "Include any constraints, styles, or details the model should follow.", + default="", + ), + IO.DynamicCombo.Input( + "model", + options=[ + IO.DynamicCombo.Option( + "Nano Banana 2 (Gemini 3.1 Flash Image)", + _nano_banana_2_v2_model_inputs(), + ), + ], + ), + IO.Int.Input( + "seed", + default=42, + min=0, + max=0xFFFFFFFFFFFFFFFF, + control_after_generate=True, + tooltip="When the seed is fixed to a specific value, the model makes a best effort to provide " + "the same response for repeated requests. Deterministic output isn't guaranteed. " + "Also, changing the model or parameter settings, such as the temperature, " + "can cause variations in the response even when you use the same seed value. " + "By default, a random seed value is used.", + ), + IO.Combo.Input( + "response_modalities", + options=["IMAGE", "IMAGE+TEXT"], + advanced=True, + ), + IO.String.Input( + "system_prompt", + multiline=True, + default=GEMINI_IMAGE_SYS_PROMPT, + optional=True, + tooltip="Foundational instructions that dictate an AI's behavior.", + advanced=True, + ), + ], + outputs=[ + IO.Image.Output(), + IO.String.Output(), + IO.Image.Output( + display_name="thought_image", + tooltip="First image from the model's thinking process. " + "Only available with thinking_level HIGH and IMAGE+TEXT modality.", + ), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["model", "model.resolution"]), + expr=""" + ( + $r := $lookup(widgets, "model.resolution"); + $prices := {"1k": 0.0696, "2k": 0.1014, "4k": 0.154}; + {"type":"usd","usd": $lookup($prices, $r), "format":{"suffix":"/Image","approximate":true}} + ) + """, + ), + ) + + @classmethod + async def execute( + cls, + prompt: str, + model: dict, + seed: int, + response_modalities: str, + system_prompt: str = "", + ) -> IO.NodeOutput: + validate_string(prompt, strip_whitespace=True, min_length=1) + model_choice = model["model"] + if model_choice == "Nano Banana 2 (Gemini 3.1 Flash Image)": + model_id = "gemini-3.1-flash-image-preview" + else: + model_id = model_choice + + images = model.get("images") or {} + parts: list[GeminiPart] = [GeminiPart(text=prompt)] + if images: + image_tensors: list[Input.Image] = [t for t in images.values() if t is not None] + if image_tensors: + if sum(get_number_of_images(t) for t in image_tensors) > 14: + raise ValueError("The current maximum number of supported images is 14.") + parts.extend(await create_image_parts(cls, image_tensors)) + files = model.get("files") + if files is not None: + parts.extend(files) + + image_config = GeminiImageConfig(imageSize=model["resolution"]) + if model["aspect_ratio"] != "auto": + image_config.aspectRatio = model["aspect_ratio"] + + gemini_system_prompt = None + if system_prompt: + gemini_system_prompt = GeminiSystemInstructionContent(parts=[GeminiTextPart(text=system_prompt)], role=None) + + response = await sync_op( + cls, + 
ApiEndpoint(path=f"/proxy/vertexai/gemini/{model_id}", method="POST"), + data=GeminiImageGenerateContentRequest( + contents=[ + GeminiContent(role=GeminiRole.user, parts=parts), + ], + generationConfig=GeminiImageGenerationConfig( + responseModalities=(["IMAGE"] if response_modalities == "IMAGE" else ["TEXT", "IMAGE"]), + imageConfig=image_config, + thinkingConfig=GeminiThinkingConfig(thinkingLevel=model["thinking_level"]), + ), + systemInstruction=gemini_system_prompt, + ), + response_model=GeminiGenerateContentResponse, + price_extractor=calculate_tokens_price, + ) + return IO.NodeOutput( + await get_image_from_response(response), + get_text_from_response(response), + await get_image_from_response(response, thought=True), + ) + + class GeminiExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[IO.ComfyNode]]: @@ -1024,6 +1225,7 @@ class GeminiExtension(ComfyExtension): GeminiImage, GeminiImage2, GeminiNanoBanana2, + GeminiNanoBanana2V2, GeminiInputFiles, ] From 8dc3f3f2094121c0a013e21d89136ebc331d2974 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jukka=20Sepp=C3=A4nen?= <40791699+kijai@users.noreply.github.com> Date: Fri, 8 May 2026 03:18:28 +0300 Subject: [PATCH 086/102] Improve SAM3 large input handling (#13767) --- comfy/ldm/sam3/detector.py | 9 ++++--- comfy/ldm/sam3/tracker.py | 49 +++++++++++++++++++++++++------------- comfy_extras/nodes_sam3.py | 24 +++++++++++-------- 3 files changed, 53 insertions(+), 29 deletions(-) diff --git a/comfy/ldm/sam3/detector.py b/comfy/ldm/sam3/detector.py index 12d3a01ab..23a972ac7 100644 --- a/comfy/ldm/sam3/detector.py +++ b/comfy/ldm/sam3/detector.py @@ -561,7 +561,8 @@ class SAM3Model(nn.Module): return high_res_masks def forward_video(self, images, initial_masks, pbar=None, text_prompts=None, - new_det_thresh=0.5, max_objects=0, detect_interval=1): + new_det_thresh=0.5, max_objects=0, detect_interval=1, + target_device=None, target_dtype=None): """Track video with optional per-frame text-prompted detection.""" bb = self.detector.backbone["vision_backbone"] @@ -589,8 +590,10 @@ class SAM3Model(nn.Module): return self.tracker.track_video_with_detection( backbone_fn, images, initial_masks, detect_fn, new_det_thresh=new_det_thresh, max_objects=max_objects, - detect_interval=detect_interval, backbone_obj=bb, pbar=pbar) + detect_interval=detect_interval, backbone_obj=bb, pbar=pbar, + target_device=target_device, target_dtype=target_dtype) # SAM3 (non-multiplex) — no detection support, requires initial masks if initial_masks is None: raise ValueError("SAM3 (non-multiplex) requires initial_mask for video tracking") - return self.tracker.track_video(backbone_fn, images, initial_masks, pbar=pbar, backbone_obj=bb) + return self.tracker.track_video(backbone_fn, images, initial_masks, pbar=pbar, backbone_obj=bb, + target_device=target_device, target_dtype=target_dtype) diff --git a/comfy/ldm/sam3/tracker.py b/comfy/ldm/sam3/tracker.py index 8f7481003..8456e90a6 100644 --- a/comfy/ldm/sam3/tracker.py +++ b/comfy/ldm/sam3/tracker.py @@ -200,8 +200,13 @@ def pack_masks(masks): def unpack_masks(packed): """Unpack bit-packed [*, H, W//8] uint8 to bool [*, H, W*8].""" - shifts = torch.arange(8, device=packed.device) - return ((packed.unsqueeze(-1) >> shifts) & 1).view(*packed.shape[:-1], -1).bool() + bits = torch.tensor([1, 2, 4, 8, 16, 32, 64, 128], dtype=torch.uint8, device=packed.device) + return (packed.unsqueeze(-1) & bits).bool().view(*packed.shape[:-1], -1) + + +def _prep_frame(images, idx, device, dt, size): + """Slice CPU full-res 
frames, transfer to GPU in target dtype, and resize to (size, size).""" + return comfy.utils.common_upscale(images[idx].to(device=device, dtype=dt), size, size, "bicubic", crop="disabled") def _compute_backbone(backbone_fn, frame, frame_idx=None): @@ -1078,16 +1083,19 @@ class SAM3Tracker(nn.Module): # SAM3: drop last FPN level return vision_feats[:-1], vision_pos[:-1], feat_sizes[:-1] - def _track_single_object(self, backbone_fn, images, initial_mask, pbar=None): + def _track_single_object(self, backbone_fn, images, initial_mask, pbar=None, + target_device=None, target_dtype=None): """Track one object, computing backbone per frame to save VRAM.""" N = images.shape[0] - device, dt = images.device, images.dtype + device = target_device if target_device is not None else images.device + dt = target_dtype if target_dtype is not None else images.dtype + size = self.image_size output_dict = {"cond_frame_outputs": {}, "non_cond_frame_outputs": {}} all_masks = [] for frame_idx in tqdm(range(N), desc="tracking"): vision_feats, vision_pos, feat_sizes = self._compute_backbone_frame( - backbone_fn, images[frame_idx:frame_idx + 1], frame_idx=frame_idx) + backbone_fn, _prep_frame(images, slice(frame_idx, frame_idx + 1), device, dt, size), frame_idx=frame_idx) mask_input = None if frame_idx == 0: mask_input = F.interpolate(initial_mask.to(device=device, dtype=dt), @@ -1114,12 +1122,13 @@ class SAM3Tracker(nn.Module): return torch.cat(all_masks, dim=0) # [N, 1, H, W] - def track_video(self, backbone_fn, images, initial_masks, pbar=None, **kwargs): + def track_video(self, backbone_fn, images, initial_masks, pbar=None, + target_device=None, target_dtype=None, **kwargs): """Track one or more objects across video frames. Args: backbone_fn: callable that returns (sam2_features, sam2_positions, trunk_out) for a frame - images: [N, 3, 1008, 1008] video frames + images: [N, 3, H, W] CPU full-res video frames (resized per-frame to self.image_size) initial_masks: [N_obj, 1, H, W] binary masks for first frame (one per object) pbar: optional progress bar @@ -1130,7 +1139,8 @@ class SAM3Tracker(nn.Module): per_object = [] for obj_idx in range(N_obj): obj_masks = self._track_single_object( - backbone_fn, images, initial_masks[obj_idx:obj_idx + 1], pbar=pbar) + backbone_fn, images, initial_masks[obj_idx:obj_idx + 1], pbar=pbar, + target_device=target_device, target_dtype=target_dtype) per_object.append(obj_masks) return torch.cat(per_object, dim=1) # [N, N_obj, H, W] @@ -1632,11 +1642,18 @@ class SAM31Tracker(nn.Module): return det_scores[new_dets].tolist() if det_scores is not None else [0.0] * new_dets.sum().item() return [] + INTERNAL_MAX_OBJECTS = 64 # Hard ceiling on accumulated tracks; max_objects=0 or any value above this is clamped here. + def track_video_with_detection(self, backbone_fn, images, initial_masks, detect_fn=None, new_det_thresh=0.5, max_objects=0, detect_interval=1, - backbone_obj=None, pbar=None): + backbone_obj=None, pbar=None, target_device=None, target_dtype=None): """Track with optional per-frame detection. 
Returns [N, max_N_obj, H, W] mask logits.""" - N, device, dt = images.shape[0], images.device, images.dtype + if max_objects <= 0 or max_objects > self.INTERNAL_MAX_OBJECTS: + max_objects = self.INTERNAL_MAX_OBJECTS + N = images.shape[0] + device = target_device if target_device is not None else images.device + dt = target_dtype if target_dtype is not None else images.dtype + size = self.image_size output_dict = {"cond_frame_outputs": {}, "non_cond_frame_outputs": {}} all_masks = [] idev = comfy.model_management.intermediate_device() @@ -1656,7 +1673,7 @@ class SAM31Tracker(nn.Module): prefetch = True except RuntimeError: pass - cur_bb = self._compute_backbone_frame(backbone_fn, images[0:1], frame_idx=0) + cur_bb = self._compute_backbone_frame(backbone_fn, _prep_frame(images, slice(0, 1), device, dt, size), frame_idx=0) for frame_idx in tqdm(range(N), desc="tracking"): vision_feats, vision_pos, feat_sizes, high_res_prop, trunk_out = cur_bb @@ -1666,7 +1683,7 @@ class SAM31Tracker(nn.Module): backbone_stream.wait_stream(torch.cuda.current_stream(device)) with torch.cuda.stream(backbone_stream): next_bb = self._compute_backbone_frame( - backbone_fn, images[frame_idx + 1:frame_idx + 2], frame_idx=frame_idx + 1) + backbone_fn, _prep_frame(images, slice(frame_idx + 1, frame_idx + 2), device, dt, size), frame_idx=frame_idx + 1) # Per-frame detection with NMS (skip if no detect_fn, or interval/max not met) det_masks = torch.empty(0, device=device) @@ -1687,7 +1704,7 @@ class SAM31Tracker(nn.Module): current_out = self._condition_with_masks( initial_masks.to(device=device, dtype=dt), frame_idx, vision_feats, vision_pos, feat_sizes, high_res_prop, output_dict, N, mux_state, backbone_obj, - images[frame_idx:frame_idx + 1], trunk_out) + _prep_frame(images, slice(frame_idx, frame_idx + 1), device, dt, size), trunk_out) last_occluded = torch.full((mux_state.total_valid_entries,), -1, device=device, dtype=torch.long) obj_scores = [1.0] * mux_state.total_valid_entries if keep_alive is not None: @@ -1702,7 +1719,7 @@ class SAM31Tracker(nn.Module): current_out = self._condition_with_masks( det_masks, frame_idx, vision_feats, vision_pos, feat_sizes, high_res_prop, output_dict, N, mux_state, backbone_obj, - images[frame_idx:frame_idx + 1], trunk_out, threshold=0.0) + _prep_frame(images, slice(frame_idx, frame_idx + 1), device, dt, size), trunk_out, threshold=0.0) last_occluded = torch.full((mux_state.total_valid_entries,), -1, device=device, dtype=torch.long) obj_scores = det_scores[:mux_state.total_valid_entries].tolist() if keep_alive is not None: @@ -1718,7 +1735,7 @@ class SAM31Tracker(nn.Module): torch.cuda.current_stream(device).wait_stream(backbone_stream) cur_bb = next_bb else: - cur_bb = self._compute_backbone_frame(backbone_fn, images[frame_idx + 1:frame_idx + 2], frame_idx=frame_idx + 1) + cur_bb = self._compute_backbone_frame(backbone_fn, _prep_frame(images, slice(frame_idx + 1, frame_idx + 2), device, dt, size), frame_idx=frame_idx + 1) continue else: N_obj = mux_state.total_valid_entries @@ -1768,7 +1785,7 @@ class SAM31Tracker(nn.Module): torch.cuda.current_stream(device).wait_stream(backbone_stream) cur_bb = next_bb else: - cur_bb = self._compute_backbone_frame(backbone_fn, images[frame_idx + 1:frame_idx + 2], frame_idx=frame_idx + 1) + cur_bb = self._compute_backbone_frame(backbone_fn, _prep_frame(images, slice(frame_idx + 1, frame_idx + 2), device, dt, size), frame_idx=frame_idx + 1) if not all_masks or all(m is None for m in all_masks): return {"packed_masks": None, "n_frames": N, 
"scores": []} diff --git a/comfy_extras/nodes_sam3.py b/comfy_extras/nodes_sam3.py index 5cf92ccb3..c460506bf 100644 --- a/comfy_extras/nodes_sam3.py +++ b/comfy_extras/nodes_sam3.py @@ -272,8 +272,8 @@ class SAM3_VideoTrack(io.ComfyNode): io.Model.Input("model", display_name="model"), io.Mask.Input("initial_mask", display_name="initial_mask", optional=True, tooltip="Mask(s) for the first frame to track (one per object)"), io.Conditioning.Input("conditioning", display_name="conditioning", optional=True, tooltip="Text conditioning for detecting new objects during tracking"), - io.Float.Input("detection_threshold", display_name="detection_threshold", default=0.5, min=0.0, max=1.0, step=0.01, tooltip="Score threshold for text-prompted detection"), - io.Int.Input("max_objects", display_name="max_objects", default=0, min=0, tooltip="Max tracked objects (0=unlimited). Initial masks count toward this limit."), + io.Float.Input("detection_threshold", display_name="detection_threshold", default=0.5, min=0.0, max=1.0, step=0.01, tooltip="Score threshold for text-prompted detection."), + io.Int.Input("max_objects", display_name="max_objects", default=4, min=0, max=64, tooltip="Max tracked objects. Initial masks count toward this limit. 0 uses the internal cap of 64."), io.Int.Input("detect_interval", display_name="detect_interval", default=1, min=1, tooltip="Run detection every N frames (1=every frame). Higher values save compute."), ], outputs=[ @@ -290,8 +290,7 @@ class SAM3_VideoTrack(io.ComfyNode): dtype = model.model.get_dtype() sam3_model = model.model.diffusion_model - frames = images[..., :3].movedim(-1, 1) - frames_in = comfy.utils.common_upscale(frames, 1008, 1008, "bilinear", crop="disabled").to(device=device, dtype=dtype) + frames_in = images[..., :3].movedim(-1, 1) init_masks = None if initial_mask is not None: @@ -308,7 +307,7 @@ class SAM3_VideoTrack(io.ComfyNode): result = sam3_model.forward_video( images=frames_in, initial_masks=init_masks, pbar=pbar, text_prompts=text_prompts, new_det_thresh=detection_threshold, max_objects=max_objects, - detect_interval=detect_interval) + detect_interval=detect_interval, target_device=device, target_dtype=dtype) result["orig_size"] = (H, W) return io.NodeOutput(result) @@ -449,14 +448,18 @@ class SAM3_TrackPreview(io.ComfyNode): cx = (bool_masks * grid_x).sum(dim=(-1, -2)) // area has = area > 1 scores = track_data.get("scores", []) + label_scale = max(3, H // 240) # Scale font with resolutio + size_caps = (area.float().sqrt() / 15).clamp_(min=1).long().tolist() #cap per-object so the number doesn't dwarf small masks for obj_idx in range(N_obj): if has[obj_idx]: _cx, _cy = int(cx[obj_idx]), int(cy[obj_idx]) color = cls.COLORS[obj_idx % len(cls.COLORS)] - SAM3_TrackPreview._draw_number_gpu(frame_gpu, obj_idx, _cx, _cy, color) + obj_scale = min(label_scale, size_caps[obj_idx]) + score_scale = max(1, obj_scale * 2 // 3) + SAM3_TrackPreview._draw_number_gpu(frame_gpu, obj_idx, _cx, _cy, color, scale=obj_scale) if obj_idx < len(scores) and scores[obj_idx] < 1.0: SAM3_TrackPreview._draw_number_gpu(frame_gpu, int(scores[obj_idx] * 100), - _cx, _cy + 5 * 3 + 3, color, scale=2) + _cx, _cy + 5 * obj_scale + 3, color, scale=score_scale) frame_cpu.copy_(frame_gpu.clamp_(0, 1).mul_(255).byte()) else: frame_cpu.copy_(frame.clamp_(0, 1).mul_(255).byte()) @@ -507,9 +510,10 @@ class SAM3_TrackToMask(io.ComfyNode): if not indices: return io.NodeOutput(torch.zeros(N, H, W, device=comfy.model_management.intermediate_device())) - selected = packed[:, indices] - binary 
= unpack_masks(selected) # [N, len(indices), Hm, Wm] bool - union = binary.any(dim=1, keepdim=True).float() + union_packed = packed[:, indices[0]].clone() + for i in indices[1:]: + union_packed |= packed[:, i] + union = unpack_masks(union_packed).unsqueeze(1).float() # [N, 1, Hm, Wm] mask_out = F.interpolate(union, size=(H, W), mode="bilinear", align_corners=False)[:, 0] return io.NodeOutput(mask_out) From ef8f25601a8504647caf9c9213a7c41a9f414901 Mon Sep 17 00:00:00 2001 From: Talmaj Date: Fri, 8 May 2026 03:38:36 +0200 Subject: [PATCH 087/102] Add I2V for causal forcing model. (#13719) --- comfy/k_diffusion/sampling.py | 17 +++++++++++ comfy_extras/nodes_ar_video.py | 52 ++++++++++++++++++++++++++++++++++ 2 files changed, 69 insertions(+) diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py index d33bc7199..c53ac4b2b 100644 --- a/comfy/k_diffusion/sampling.py +++ b/comfy/k_diffusion/sampling.py @@ -1859,6 +1859,23 @@ def sample_ar_video(model, x, sigmas, extra_args=None, callback=None, disable=No output = torch.zeros_like(x) s_in = x.new_ones([x.shape[0]]) current_start_frame = 0 + + # I2V: seed KV cache with the initial image latent before the denoising loop + initial_latent = transformer_options.get("ar_config", {}).get("initial_latent", None) + if initial_latent is not None: + initial_latent = inner_model.process_latent_in(initial_latent).to(device=device, dtype=model_dtype) + n_init = initial_latent.shape[2] + output[:, :, :n_init] = initial_latent + + ar_state = {"start_frame": 0, "kv_caches": kv_caches, "crossattn_caches": crossattn_caches} + transformer_options["ar_state"] = ar_state + zero_sigma = sigmas.new_zeros([1]) + _ = model(initial_latent, zero_sigma * s_in, **extra_args) + + current_start_frame = n_init + remaining = lat_t - n_init + num_blocks = -(-remaining // num_frame_per_block) + num_sigma_steps = len(sigmas) - 1 total_real_steps = num_blocks * num_sigma_steps step_count = 0 diff --git a/comfy_extras/nodes_ar_video.py b/comfy_extras/nodes_ar_video.py index 09ee886fd..b36588b14 100644 --- a/comfy_extras/nodes_ar_video.py +++ b/comfy_extras/nodes_ar_video.py @@ -2,6 +2,7 @@ ComfyUI nodes for autoregressive video generation (Causal Forcing, Self-Forcing, etc.). - EmptyARVideoLatent: create 5D [B, C, T, H, W] video latent tensors - SamplerARVideo: SAMPLER for the block-by-block autoregressive denoising loop + - ARVideoI2V: image-to-video conditioning for AR models (seeds KV cache with start image) """ import torch @@ -9,6 +10,7 @@ from typing_extensions import override import comfy.model_management import comfy.samplers +import comfy.utils from comfy_api.latest import ComfyExtension, io @@ -71,12 +73,62 @@ class SamplerARVideo(io.ComfyNode): return io.NodeOutput(comfy.samplers.ksampler("ar_video", extra_options)) +class ARVideoI2V(io.ComfyNode): + """Image-to-video setup for AR video models (Causal Forcing, Self-Forcing). + + VAE-encodes the start image and stores it in the model's transformer_options + so that sample_ar_video can seed the KV cache before denoising. + Uses the same T2V model checkpoint -- no separate I2V architecture needed. 
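+
+ Worked example with the node defaults below: length=81 gives
+ lat_t = ((81 - 1) // 4) + 1 = 21 latent frames; the encoded start image
+ typically supplies the first of them, and the sampler denoises the
+ remaining 20 block-by-block after the KV cache is seeded.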
+ """ + + @classmethod + def define_schema(cls): + return io.Schema( + node_id="ARVideoI2V", + category="conditioning/video_models", + inputs=[ + io.Model.Input("model"), + io.Vae.Input("vae"), + io.Image.Input("start_image"), + io.Int.Input("width", default=832, min=16, max=8192, step=16), + io.Int.Input("height", default=480, min=16, max=8192, step=16), + io.Int.Input("length", default=81, min=1, max=1024, step=4), + io.Int.Input("batch_size", default=1, min=1, max=64), + ], + outputs=[ + io.Model.Output(display_name="MODEL"), + io.Latent.Output(display_name="LATENT"), + ], + ) + + @classmethod + def execute(cls, model, vae, start_image, width, height, length, batch_size) -> io.NodeOutput: + start_image = comfy.utils.common_upscale( + start_image[:1].movedim(-1, 1), width, height, "bilinear", "center" + ).movedim(1, -1) + + initial_latent = vae.encode(start_image[:, :, :, :3]) + + m = model.clone() + to = m.model_options.setdefault("transformer_options", {}) + ar_cfg = to.setdefault("ar_config", {}) + ar_cfg["initial_latent"] = initial_latent + + lat_t = ((length - 1) // 4) + 1 + latent = torch.zeros( + [batch_size, 16, lat_t, height // 8, width // 8], + device=comfy.model_management.intermediate_device(), + ) + return io.NodeOutput(m, {"samples": latent}) + + class ARVideoExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[io.ComfyNode]]: return [ EmptyARVideoLatent, SamplerARVideo, + ARVideoI2V, ] From df7bf1d3dc852365593786497123d92440ac1852 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Thu, 7 May 2026 19:04:30 -0700 Subject: [PATCH 088/102] Update warning message for ComfyUI frontend installation. (#13796) --- app/frontend_management.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/app/frontend_management.py b/app/frontend_management.py index f753ef0de..7108bd35a 100644 --- a/app/frontend_management.py +++ b/app/frontend_management.py @@ -27,7 +27,7 @@ def frontend_install_warning_message(): return f""" {get_missing_requirements_message()} -This error is happening because the ComfyUI frontend is no longer shipped as part of the main repo but as a pip package instead. +The ComfyUI frontend is shipped in a pip package so it needs to be updated separately from the ComfyUI code. """.strip() def parse_version(version: str) -> tuple[int, int, int]: From c8673542f762910766691345401e09caef2bc9a6 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Thu, 7 May 2026 19:21:12 -0700 Subject: [PATCH 089/102] fix: make NodeReplaceManager.register() idempotent (#13596) --- app/node_replace_manager.py | 20 ++++- .../app_test/node_replace_manager_test.py | 90 +++++++++++++++++++ 2 files changed, 108 insertions(+), 2 deletions(-) create mode 100644 tests-unit/app_test/node_replace_manager_test.py diff --git a/app/node_replace_manager.py b/app/node_replace_manager.py index d9aab5b22..72e8ac2b1 100644 --- a/app/node_replace_manager.py +++ b/app/node_replace_manager.py @@ -1,5 +1,7 @@ from __future__ import annotations +import logging + from aiohttp import web from typing import TYPE_CHECKING, TypedDict @@ -31,8 +33,22 @@ class NodeReplaceManager: self._replacements: dict[str, list[NodeReplace]] = {} def register(self, node_replace: NodeReplace): - """Register a node replacement mapping.""" - self._replacements.setdefault(node_replace.old_node_id, []).append(node_replace) + """Register a node replacement mapping. 
+ + Idempotent: if a replacement with the same (old_node_id, new_node_id) + is already registered, the duplicate is ignored. This prevents stale + entries from accumulating when custom nodes are reloaded in the same + process (e.g. via ComfyUI-Manager). + """ + existing = self._replacements.setdefault(node_replace.old_node_id, []) + for entry in existing: + if entry.new_node_id == node_replace.new_node_id: + logging.debug( + "Node replacement %s -> %s already registered, ignoring duplicate.", + node_replace.old_node_id, node_replace.new_node_id, + ) + return + existing.append(node_replace) def get_replacement(self, old_node_id: str) -> list[NodeReplace] | None: """Get replacements for an old node ID.""" diff --git a/tests-unit/app_test/node_replace_manager_test.py b/tests-unit/app_test/node_replace_manager_test.py new file mode 100644 index 000000000..8a3fd18bb --- /dev/null +++ b/tests-unit/app_test/node_replace_manager_test.py @@ -0,0 +1,90 @@ +"""Tests for NodeReplaceManager registration behavior.""" +import importlib +import sys +import types + +import pytest + + +@pytest.fixture +def NodeReplaceManager(monkeypatch): + """Provide NodeReplaceManager with `nodes` stubbed. + + `app.node_replace_manager` does `import nodes` at module level, which pulls in + torch + the full ComfyUI graph. register() doesn't actually need it, so we + stub `nodes` per-test (via monkeypatch so it's torn down) and reload the + module so it picks up the stub instead of any cached real import. + """ + fake_nodes = types.ModuleType("nodes") + fake_nodes.NODE_CLASS_MAPPINGS = {} + monkeypatch.setitem(sys.modules, "nodes", fake_nodes) + monkeypatch.delitem(sys.modules, "app.node_replace_manager", raising=False) + module = importlib.import_module("app.node_replace_manager") + yield module.NodeReplaceManager + # Drop the freshly-imported module so the next test (or a later real import + # of `nodes`) starts from a clean slate. 
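+ # (This teardown mirrors an in-process custom node reload, e.g. via
+ # ComfyUI-Manager -- the duplicate-registration scenario that the
+ # idempotent register() guards against.)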
+ sys.modules.pop("app.node_replace_manager", None) + + +class FakeNodeReplace: + """Lightweight stand-in for comfy_api.latest._io.NodeReplace.""" + def __init__(self, new_node_id, old_node_id, old_widget_ids=None, + input_mapping=None, output_mapping=None): + self.new_node_id = new_node_id + self.old_node_id = old_node_id + self.old_widget_ids = old_widget_ids + self.input_mapping = input_mapping + self.output_mapping = output_mapping + + +def test_register_adds_replacement(NodeReplaceManager): + manager = NodeReplaceManager() + manager.register(FakeNodeReplace(new_node_id="NewNode", old_node_id="OldNode")) + assert manager.has_replacement("OldNode") + assert len(manager.get_replacement("OldNode")) == 1 + + +def test_register_allows_multiple_alternatives_for_same_old_node(NodeReplaceManager): + """Different new_node_ids for the same old_node_id should all be kept.""" + manager = NodeReplaceManager() + manager.register(FakeNodeReplace(new_node_id="AltA", old_node_id="OldNode")) + manager.register(FakeNodeReplace(new_node_id="AltB", old_node_id="OldNode")) + replacements = manager.get_replacement("OldNode") + assert len(replacements) == 2 + assert {r.new_node_id for r in replacements} == {"AltA", "AltB"} + + +def test_register_is_idempotent_for_duplicate_pair(NodeReplaceManager): + """Re-registering the same (old_node_id, new_node_id) should be a no-op.""" + manager = NodeReplaceManager() + manager.register(FakeNodeReplace(new_node_id="NewNode", old_node_id="OldNode")) + manager.register(FakeNodeReplace(new_node_id="NewNode", old_node_id="OldNode")) + manager.register(FakeNodeReplace(new_node_id="NewNode", old_node_id="OldNode")) + assert len(manager.get_replacement("OldNode")) == 1 + + +def test_register_idempotent_preserves_first_registration(NodeReplaceManager): + """First registration wins; later duplicates with different mappings are ignored.""" + manager = NodeReplaceManager() + first = FakeNodeReplace( + new_node_id="NewNode", old_node_id="OldNode", + input_mapping=[{"new_id": "a", "old_id": "x"}], + ) + second = FakeNodeReplace( + new_node_id="NewNode", old_node_id="OldNode", + input_mapping=[{"new_id": "b", "old_id": "y"}], + ) + manager.register(first) + manager.register(second) + replacements = manager.get_replacement("OldNode") + assert len(replacements) == 1 + assert replacements[0] is first + + +def test_register_dedupe_does_not_affect_other_old_nodes(NodeReplaceManager): + manager = NodeReplaceManager() + manager.register(FakeNodeReplace(new_node_id="NewA", old_node_id="OldA")) + manager.register(FakeNodeReplace(new_node_id="NewA", old_node_id="OldA")) + manager.register(FakeNodeReplace(new_node_id="NewB", old_node_id="OldB")) + assert len(manager.get_replacement("OldA")) == 1 + assert len(manager.get_replacement("OldB")) == 1 From 594de378fe1d2e32128338f5cc57864ee1d9d96f Mon Sep 17 00:00:00 2001 From: Alexis Rolland Date: Fri, 8 May 2026 13:02:55 +0800 Subject: [PATCH 090/102] Update nodes categories and display names (CORE-89) (#13786) --- comfy_extras/nodes_advanced_samplers.py | 2 +- comfy_extras/nodes_attention_multiply.py | 8 ++++---- comfy_extras/nodes_audio_encoder.py | 1 + comfy_extras/nodes_camera_trajectory.py | 2 +- comfy_extras/nodes_cond.py | 4 ++-- comfy_extras/nodes_context_windows.py | 2 +- comfy_extras/nodes_custom_sampler.py | 4 ++-- comfy_extras/nodes_differential_diffusion.py | 2 +- comfy_extras/nodes_fresca.py | 2 +- comfy_extras/nodes_hunyuan.py | 4 ++++ comfy_extras/nodes_hunyuan3d.py | 6 ++++-- comfy_extras/nodes_hypernetwork.py | 1 + 
comfy_extras/nodes_lora_extract.py | 2 +- comfy_extras/nodes_lt.py | 3 ++- comfy_extras/nodes_mahiro.py | 2 +- comfy_extras/nodes_math.py | 2 +- comfy_extras/nodes_number_convert.py | 2 +- comfy_extras/nodes_perpneg.py | 7 ++++--- comfy_extras/nodes_photomaker.py | 4 ++-- comfy_extras/nodes_post_processing.py | 4 +++- comfy_extras/nodes_rtdetr.py | 4 ++-- comfy_extras/nodes_sag.py | 2 +- comfy_extras/nodes_sam3.py | 8 ++++---- comfy_extras/nodes_stable_cascade.py | 2 +- comfy_extras/nodes_textgen.py | 4 +++- comfy_extras/nodes_torch_compile.py | 2 +- comfy_extras/nodes_train.py | 2 +- comfy_extras/nodes_video_model.py | 2 +- custom_nodes/websocket_image_save.py | 6 +++++- nodes.py | 14 +++++++------ .../testing-pack/api_test_nodes.py | 4 ++-- .../testing-pack/async_test_nodes.py | 20 +++++++++---------- .../testing-pack/specific_tests.py | 6 +++--- 33 files changed, 80 insertions(+), 60 deletions(-) diff --git a/comfy_extras/nodes_advanced_samplers.py b/comfy_extras/nodes_advanced_samplers.py index 7f716cd76..7e8411fa4 100644 --- a/comfy_extras/nodes_advanced_samplers.py +++ b/comfy_extras/nodes_advanced_samplers.py @@ -92,7 +92,7 @@ class SamplerEulerCFGpp(io.ComfyNode): return io.Schema( node_id="SamplerEulerCFGpp", display_name="SamplerEulerCFG++", - category="_for_testing", # "sampling/custom_sampling/samplers" + category="experimental", # "sampling/custom_sampling/samplers" inputs=[ io.Combo.Input("version", options=["regular", "alternative"], advanced=True), ], diff --git a/comfy_extras/nodes_attention_multiply.py b/comfy_extras/nodes_attention_multiply.py index 060a5c9be..f4ee6a689 100644 --- a/comfy_extras/nodes_attention_multiply.py +++ b/comfy_extras/nodes_attention_multiply.py @@ -25,7 +25,7 @@ class UNetSelfAttentionMultiply(io.ComfyNode): def define_schema(cls) -> io.Schema: return io.Schema( node_id="UNetSelfAttentionMultiply", - category="_for_testing/attention_experiments", + category="experimental/attention_experiments", inputs=[ io.Model.Input("model"), io.Float.Input("q", default=1.0, min=0.0, max=10.0, step=0.01, advanced=True), @@ -48,7 +48,7 @@ class UNetCrossAttentionMultiply(io.ComfyNode): def define_schema(cls) -> io.Schema: return io.Schema( node_id="UNetCrossAttentionMultiply", - category="_for_testing/attention_experiments", + category="experimental/attention_experiments", inputs=[ io.Model.Input("model"), io.Float.Input("q", default=1.0, min=0.0, max=10.0, step=0.01, advanced=True), @@ -72,7 +72,7 @@ class CLIPAttentionMultiply(io.ComfyNode): return io.Schema( node_id="CLIPAttentionMultiply", search_aliases=["clip attention scale", "text encoder attention"], - category="_for_testing/attention_experiments", + category="experimental/attention_experiments", inputs=[ io.Clip.Input("clip"), io.Float.Input("q", default=1.0, min=0.0, max=10.0, step=0.01, advanced=True), @@ -106,7 +106,7 @@ class UNetTemporalAttentionMultiply(io.ComfyNode): def define_schema(cls) -> io.Schema: return io.Schema( node_id="UNetTemporalAttentionMultiply", - category="_for_testing/attention_experiments", + category="experimental/attention_experiments", inputs=[ io.Model.Input("model"), io.Float.Input("self_structural", default=1.0, min=0.0, max=10.0, step=0.01, advanced=True), diff --git a/comfy_extras/nodes_audio_encoder.py b/comfy_extras/nodes_audio_encoder.py index 13aacd41a..6a85da89b 100644 --- a/comfy_extras/nodes_audio_encoder.py +++ b/comfy_extras/nodes_audio_encoder.py @@ -10,6 +10,7 @@ class AudioEncoderLoader(io.ComfyNode): def define_schema(cls) -> io.Schema: return io.Schema( 
node_id="AudioEncoderLoader", + display_name="Load Audio Encoder", category="loaders", inputs=[ io.Combo.Input( diff --git a/comfy_extras/nodes_camera_trajectory.py b/comfy_extras/nodes_camera_trajectory.py index e7efa29ba..34b78e81b 100644 --- a/comfy_extras/nodes_camera_trajectory.py +++ b/comfy_extras/nodes_camera_trajectory.py @@ -153,7 +153,7 @@ class WanCameraEmbedding(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="WanCameraEmbedding", - category="camera", + category="conditioning/video_models", inputs=[ io.Combo.Input( "camera_pose", diff --git a/comfy_extras/nodes_cond.py b/comfy_extras/nodes_cond.py index 86426a780..b745a43af 100644 --- a/comfy_extras/nodes_cond.py +++ b/comfy_extras/nodes_cond.py @@ -8,7 +8,7 @@ class CLIPTextEncodeControlnet(io.ComfyNode): def define_schema(cls) -> io.Schema: return io.Schema( node_id="CLIPTextEncodeControlnet", - category="_for_testing/conditioning", + category="experimental/conditioning", inputs=[ io.Clip.Input("clip"), io.Conditioning.Input("conditioning"), @@ -35,7 +35,7 @@ class T5TokenizerOptions(io.ComfyNode): def define_schema(cls) -> io.Schema: return io.Schema( node_id="T5TokenizerOptions", - category="_for_testing/conditioning", + category="experimental/conditioning", inputs=[ io.Clip.Input("clip"), io.Int.Input("min_padding", default=0, min=0, max=10000, step=1, advanced=True), diff --git a/comfy_extras/nodes_context_windows.py b/comfy_extras/nodes_context_windows.py index fefc56d26..f7ca833dc 100644 --- a/comfy_extras/nodes_context_windows.py +++ b/comfy_extras/nodes_context_windows.py @@ -10,7 +10,7 @@ class ContextWindowsManualNode(io.ComfyNode): return io.Schema( node_id="ContextWindowsManual", display_name="Context Windows (Manual)", - category="context", + category="model_patches", description="Manually set context windows.", inputs=[ io.Model.Input("model", tooltip="The model to apply context windows to during sampling."), diff --git a/comfy_extras/nodes_custom_sampler.py b/comfy_extras/nodes_custom_sampler.py index 1e957c09b..c67145d2d 100644 --- a/comfy_extras/nodes_custom_sampler.py +++ b/comfy_extras/nodes_custom_sampler.py @@ -984,7 +984,7 @@ class AddNoise(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="AddNoise", - category="_for_testing/custom_sampling/noise", + category="experimental/custom_sampling/noise", is_experimental=True, inputs=[ io.Model.Input("model"), @@ -1034,7 +1034,7 @@ class ManualSigmas(io.ComfyNode): return io.Schema( node_id="ManualSigmas", search_aliases=["custom noise schedule", "define sigmas"], - category="_for_testing/custom_sampling", + category="experimental/custom_sampling", is_experimental=True, inputs=[ io.String.Input("sigmas", default="1, 0.5", multiline=False) diff --git a/comfy_extras/nodes_differential_diffusion.py b/comfy_extras/nodes_differential_diffusion.py index 34ffb9a89..4fa61ad0e 100644 --- a/comfy_extras/nodes_differential_diffusion.py +++ b/comfy_extras/nodes_differential_diffusion.py @@ -13,7 +13,7 @@ class DifferentialDiffusion(io.ComfyNode): node_id="DifferentialDiffusion", search_aliases=["inpaint gradient", "variable denoise strength"], display_name="Differential Diffusion", - category="_for_testing", + category="experimental", inputs=[ io.Model.Input("model"), io.Float.Input( diff --git a/comfy_extras/nodes_fresca.py b/comfy_extras/nodes_fresca.py index eab4f303f..173f42154 100644 --- a/comfy_extras/nodes_fresca.py +++ b/comfy_extras/nodes_fresca.py @@ -60,7 +60,7 @@ class FreSca(io.ComfyNode): node_id="FreSca", 
search_aliases=["frequency guidance"], display_name="FreSca", - category="_for_testing", + category="experimental", description="Applies frequency-dependent scaling to the guidance", inputs=[ io.Model.Input("model"), diff --git a/comfy_extras/nodes_hunyuan.py b/comfy_extras/nodes_hunyuan.py index 4ea93a499..9e4873be5 100644 --- a/comfy_extras/nodes_hunyuan.py +++ b/comfy_extras/nodes_hunyuan.py @@ -131,6 +131,8 @@ class HunyuanVideo15SuperResolution(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="HunyuanVideo15SuperResolution", + display_name="Hunyuan Video 1.5 Super Resolution", + category="conditioning/video_models", inputs=[ io.Conditioning.Input("positive"), io.Conditioning.Input("negative"), @@ -381,6 +383,8 @@ class HunyuanRefinerLatent(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="HunyuanRefinerLatent", + display_name="Hunyuan Latent Refiner", + category="conditioning/video_models", inputs=[ io.Conditioning.Input("positive"), io.Conditioning.Input("negative"), diff --git a/comfy_extras/nodes_hunyuan3d.py b/comfy_extras/nodes_hunyuan3d.py index fa55ead59..bf18ecb88 100644 --- a/comfy_extras/nodes_hunyuan3d.py +++ b/comfy_extras/nodes_hunyuan3d.py @@ -40,7 +40,7 @@ class Hunyuan3Dv2Conditioning(IO.ComfyNode): def define_schema(cls): return IO.Schema( node_id="Hunyuan3Dv2Conditioning", - category="conditioning/video_models", + category="conditioning/3d_models", inputs=[ IO.ClipVisionOutput.Input("clip_vision_output"), ], @@ -65,7 +65,7 @@ class Hunyuan3Dv2ConditioningMultiView(IO.ComfyNode): def define_schema(cls): return IO.Schema( node_id="Hunyuan3Dv2ConditioningMultiView", - category="conditioning/video_models", + category="conditioning/3d_models", inputs=[ IO.ClipVisionOutput.Input("front", optional=True), IO.ClipVisionOutput.Input("left", optional=True), @@ -424,6 +424,7 @@ class VoxelToMeshBasic(IO.ComfyNode): def define_schema(cls): return IO.Schema( node_id="VoxelToMeshBasic", + display_name="Voxel to Mesh (Basic)", category="3d", inputs=[ IO.Voxel.Input("voxel"), @@ -453,6 +454,7 @@ class VoxelToMesh(IO.ComfyNode): def define_schema(cls): return IO.Schema( node_id="VoxelToMesh", + display_name="Voxel to Mesh", category="3d", inputs=[ IO.Voxel.Input("voxel"), diff --git a/comfy_extras/nodes_hypernetwork.py b/comfy_extras/nodes_hypernetwork.py index 2a6a87a81..44a9c6f97 100644 --- a/comfy_extras/nodes_hypernetwork.py +++ b/comfy_extras/nodes_hypernetwork.py @@ -102,6 +102,7 @@ class HypernetworkLoader(IO.ComfyNode): def define_schema(cls): return IO.Schema( node_id="HypernetworkLoader", + display_name="Load Hypernetwork", category="loaders", inputs=[ IO.Model.Input("model"), diff --git a/comfy_extras/nodes_lora_extract.py b/comfy_extras/nodes_lora_extract.py index 975f90f45..bcd249c29 100644 --- a/comfy_extras/nodes_lora_extract.py +++ b/comfy_extras/nodes_lora_extract.py @@ -91,7 +91,7 @@ class LoraSave(io.ComfyNode): node_id="LoraSave", search_aliases=["export lora"], display_name="Extract and Save Lora", - category="_for_testing", + category="experimental", inputs=[ io.String.Input("filename_prefix", default="loras/ComfyUI_extracted_lora"), io.Int.Input("rank", default=8, min=1, max=4096, step=1, advanced=True), diff --git a/comfy_extras/nodes_lt.py b/comfy_extras/nodes_lt.py index 19d8a387f..ab1359fdb 100644 --- a/comfy_extras/nodes_lt.py +++ b/comfy_extras/nodes_lt.py @@ -594,7 +594,8 @@ class LTXVPreprocess(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="LTXVPreprocess", - category="image", + display_name="LTXV 
Preprocess", + category="video/preprocessors", inputs=[ io.Image.Input("image"), io.Int.Input( diff --git a/comfy_extras/nodes_mahiro.py b/comfy_extras/nodes_mahiro.py index a25226e6d..7bd5f6652 100644 --- a/comfy_extras/nodes_mahiro.py +++ b/comfy_extras/nodes_mahiro.py @@ -11,7 +11,7 @@ class Mahiro(io.ComfyNode): return io.Schema( node_id="Mahiro", display_name="Positive-Biased Guidance", - category="_for_testing", + category="experimental", description="Modify the guidance to scale more on the 'direction' of the positive prompt rather than the difference between the negative prompt.", inputs=[ io.Model.Input("model"), diff --git a/comfy_extras/nodes_math.py b/comfy_extras/nodes_math.py index 6417bacf1..8f6e687d2 100644 --- a/comfy_extras/nodes_math.py +++ b/comfy_extras/nodes_math.py @@ -70,7 +70,7 @@ class MathExpressionNode(io.ComfyNode): return io.Schema( node_id="ComfyMathExpression", display_name="Math Expression", - category="math", + category="logic", search_aliases=[ "expression", "formula", "calculate", "calculator", "eval", "math", diff --git a/comfy_extras/nodes_number_convert.py b/comfy_extras/nodes_number_convert.py index cac7e736d..ab3f2aa8a 100644 --- a/comfy_extras/nodes_number_convert.py +++ b/comfy_extras/nodes_number_convert.py @@ -21,7 +21,7 @@ class NumberConvertNode(io.ComfyNode): return io.Schema( node_id="ComfyNumberConvert", display_name="Number Convert", - category="math", + category="utils", search_aliases=[ "int to float", "float to int", "number convert", "int2float", "float2int", "cast", "parse number", diff --git a/comfy_extras/nodes_perpneg.py b/comfy_extras/nodes_perpneg.py index ed1467de9..a7a72d1bc 100644 --- a/comfy_extras/nodes_perpneg.py +++ b/comfy_extras/nodes_perpneg.py @@ -24,8 +24,8 @@ class PerpNeg(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="PerpNeg", - display_name="Perp-Neg (DEPRECATED by PerpNegGuider)", - category="_for_testing", + display_name="Perp-Neg (DEPRECATED by Perp-Neg Guider)", + category="experimental", inputs=[ io.Model.Input("model"), io.Conditioning.Input("empty_conditioning"), @@ -127,7 +127,8 @@ class PerpNegGuider(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="PerpNegGuider", - category="_for_testing", + display_name="Perp-Neg Guider", + category="experimental", inputs=[ io.Model.Input("model"), io.Conditioning.Input("positive"), diff --git a/comfy_extras/nodes_photomaker.py b/comfy_extras/nodes_photomaker.py index 228183c07..8a2248572 100644 --- a/comfy_extras/nodes_photomaker.py +++ b/comfy_extras/nodes_photomaker.py @@ -123,7 +123,7 @@ class PhotoMakerLoader(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="PhotoMakerLoader", - category="_for_testing/photomaker", + category="experimental/photomaker", inputs=[ io.Combo.Input("photomaker_model_name", options=folder_paths.get_filename_list("photomaker")), ], @@ -149,7 +149,7 @@ class PhotoMakerEncode(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="PhotoMakerEncode", - category="_for_testing/photomaker", + category="experimental/photomaker", inputs=[ io.Photomaker.Input("photomaker"), io.Image.Input("image"), diff --git a/comfy_extras/nodes_post_processing.py b/comfy_extras/nodes_post_processing.py index d938a2035..1fa14d2d2 100644 --- a/comfy_extras/nodes_post_processing.py +++ b/comfy_extras/nodes_post_processing.py @@ -116,6 +116,7 @@ class Quantize(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="ImageQuantize", + display_name="Quantize Image", category="image/postprocessing", 
inputs=[ io.Image.Input("image"), @@ -181,6 +182,7 @@ class Sharpen(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="ImageSharpen", + display_name="Sharpen Image", category="image/postprocessing", inputs=[ io.Image.Input("image"), @@ -436,7 +438,7 @@ class ResizeImageMaskNode(io.ComfyNode): node_id="ResizeImageMaskNode", display_name="Resize Image/Mask", description="Resize an image or mask using various scaling methods.", - category="transform", + category="image/transform", search_aliases=["resize", "resize image", "resize mask", "scale", "scale image", "scale mask", "image resize", "change size", "dimensions", "shrink", "enlarge"], inputs=[ io.MatchType.Input("input", template=template), diff --git a/comfy_extras/nodes_rtdetr.py b/comfy_extras/nodes_rtdetr.py index 7feaf3ab3..a321577c7 100644 --- a/comfy_extras/nodes_rtdetr.py +++ b/comfy_extras/nodes_rtdetr.py @@ -15,7 +15,7 @@ class RTDETR_detect(io.ComfyNode): return io.Schema( node_id="RTDETR_detect", display_name="RT-DETR Detect", - category="detection/", + category="detection", search_aliases=["bbox", "bounding box", "object detection", "coco"], inputs=[ io.Model.Input("model", display_name="model"), @@ -71,7 +71,7 @@ class DrawBBoxes(io.ComfyNode): return io.Schema( node_id="DrawBBoxes", display_name="Draw BBoxes", - category="detection/", + category="detection", search_aliases=["bbox", "bounding box", "object detection", "rt_detr", "visualize detections", "coco"], inputs=[ io.Image.Input("image", optional=True), diff --git a/comfy_extras/nodes_sag.py b/comfy_extras/nodes_sag.py index d9c47851c..9dbf1b6f9 100644 --- a/comfy_extras/nodes_sag.py +++ b/comfy_extras/nodes_sag.py @@ -113,7 +113,7 @@ class SelfAttentionGuidance(io.ComfyNode): return io.Schema( node_id="SelfAttentionGuidance", display_name="Self-Attention Guidance", - category="_for_testing", + category="experimental", inputs=[ io.Model.Input("model"), io.Float.Input("scale", default=0.5, min=-2.0, max=5.0, step=0.01), diff --git a/comfy_extras/nodes_sam3.py b/comfy_extras/nodes_sam3.py index c460506bf..4ea9221e9 100644 --- a/comfy_extras/nodes_sam3.py +++ b/comfy_extras/nodes_sam3.py @@ -93,7 +93,7 @@ class SAM3_Detect(io.ComfyNode): return io.Schema( node_id="SAM3_Detect", display_name="SAM3 Detect", - category="detection/", + category="detection", search_aliases=["sam3", "segment anything", "open vocabulary", "text detection", "segment"], inputs=[ io.Model.Input("model", display_name="model"), @@ -265,7 +265,7 @@ class SAM3_VideoTrack(io.ComfyNode): return io.Schema( node_id="SAM3_VideoTrack", display_name="SAM3 Video Track", - category="detection/", + category="detection", search_aliases=["sam3", "video", "track", "propagate"], inputs=[ io.Image.Input("images", display_name="images", tooltip="Video frames as batched images"), @@ -320,7 +320,7 @@ class SAM3_TrackPreview(io.ComfyNode): return io.Schema( node_id="SAM3_TrackPreview", display_name="SAM3 Track Preview", - category="detection/", + category="detection", inputs=[ SAM3TrackData.Input("track_data", display_name="track_data"), io.Image.Input("images", display_name="images", optional=True), @@ -478,7 +478,7 @@ class SAM3_TrackToMask(io.ComfyNode): return io.Schema( node_id="SAM3_TrackToMask", display_name="SAM3 Track to Mask", - category="detection/", + category="detection", inputs=[ SAM3TrackData.Input("track_data", display_name="track_data"), io.String.Input("object_indices", display_name="object_indices", default="", diff --git a/comfy_extras/nodes_stable_cascade.py 
b/comfy_extras/nodes_stable_cascade.py index 8c1aebca9..0dc6c9fcd 100644 --- a/comfy_extras/nodes_stable_cascade.py +++ b/comfy_extras/nodes_stable_cascade.py @@ -119,7 +119,7 @@ class StableCascade_SuperResolutionControlnet(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="StableCascade_SuperResolutionControlnet", - category="_for_testing/stable_cascade", + category="experimental/stable_cascade", is_experimental=True, inputs=[ io.Image.Input("image"), diff --git a/comfy_extras/nodes_textgen.py b/comfy_extras/nodes_textgen.py index 1661a1011..d52faf815 100644 --- a/comfy_extras/nodes_textgen.py +++ b/comfy_extras/nodes_textgen.py @@ -26,7 +26,8 @@ class TextGenerate(io.ComfyNode): return io.Schema( node_id="TextGenerate", - category="textgen", + display_name="Generate Text", + category="text", search_aliases=["LLM", "gemma"], inputs=[ io.Clip.Input("clip"), @@ -157,6 +158,7 @@ class TextGenerateLTX2Prompt(TextGenerate): parent_schema = super().define_schema() return io.Schema( node_id="TextGenerateLTX2Prompt", + display_name="Generate LTX2 Prompt", category=parent_schema.category, inputs=parent_schema.inputs, outputs=parent_schema.outputs, diff --git a/comfy_extras/nodes_torch_compile.py b/comfy_extras/nodes_torch_compile.py index c9e2e0026..d4506b1a9 100644 --- a/comfy_extras/nodes_torch_compile.py +++ b/comfy_extras/nodes_torch_compile.py @@ -10,7 +10,7 @@ class TorchCompileModel(io.ComfyNode): def define_schema(cls) -> io.Schema: return io.Schema( node_id="TorchCompileModel", - category="_for_testing", + category="experimental", inputs=[ io.Model.Input("model"), io.Combo.Input( diff --git a/comfy_extras/nodes_train.py b/comfy_extras/nodes_train.py index 0616dfc2d..e9871369b 100644 --- a/comfy_extras/nodes_train.py +++ b/comfy_extras/nodes_train.py @@ -1361,7 +1361,7 @@ class SaveLoRA(io.ComfyNode): node_id="SaveLoRA", search_aliases=["export lora"], display_name="Save LoRA Weights", - category="loaders", + category="advanced/model_merging", is_experimental=True, is_output_node=True, inputs=[ diff --git a/comfy_extras/nodes_video_model.py b/comfy_extras/nodes_video_model.py index bf98e6b82..0f3881a24 100644 --- a/comfy_extras/nodes_video_model.py +++ b/comfy_extras/nodes_video_model.py @@ -15,7 +15,7 @@ class ImageOnlyCheckpointLoader: RETURN_TYPES = ("MODEL", "CLIP_VISION", "VAE") FUNCTION = "load_checkpoint" - CATEGORY = "loaders/video_models" + CATEGORY = "loaders" def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True): ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name) diff --git a/custom_nodes/websocket_image_save.py b/custom_nodes/websocket_image_save.py index 15f87f9f5..6a8646d0e 100644 --- a/custom_nodes/websocket_image_save.py +++ b/custom_nodes/websocket_image_save.py @@ -22,7 +22,7 @@ class SaveImageWebsocket: OUTPUT_NODE = True - CATEGORY = "api/image" + CATEGORY = "image" def save_images(self, images): pbar = comfy.utils.ProgressBar(images.shape[0]) @@ -42,3 +42,7 @@ class SaveImageWebsocket: NODE_CLASS_MAPPINGS = { "SaveImageWebsocket": SaveImageWebsocket, } + +NODE_DISPLAY_NAME_MAPPINGS = { + "SaveImageWebsocket": "Save Image (Websocket)", +} \ No newline at end of file diff --git a/nodes.py b/nodes.py index ad0cbc675..ae9e70cb9 100644 --- a/nodes.py +++ b/nodes.py @@ -330,7 +330,7 @@ class VAEDecodeTiled: RETURN_TYPES = ("IMAGE",) FUNCTION = "decode" - CATEGORY = "_for_testing" + CATEGORY = "experimental" def decode(self, vae, samples, tile_size, overlap=64, temporal_size=64, temporal_overlap=8): if tile_size < 
overlap * 4: @@ -377,7 +377,7 @@ class VAEEncodeTiled: RETURN_TYPES = ("LATENT",) FUNCTION = "encode" - CATEGORY = "_for_testing" + CATEGORY = "experimental" def encode(self, vae, pixels, tile_size, overlap, temporal_size=64, temporal_overlap=8): t = vae.encode_tiled(pixels, tile_x=tile_size, tile_y=tile_size, overlap=overlap, tile_t=temporal_size, overlap_t=temporal_overlap) @@ -493,7 +493,7 @@ class SaveLatent: OUTPUT_NODE = True - CATEGORY = "_for_testing" + CATEGORY = "experimental" def save(self, samples, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None): full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir) @@ -538,7 +538,7 @@ class LoadLatent: files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f)) and f.endswith(".latent")] return {"required": {"latent": [sorted(files), ]}, } - CATEGORY = "_for_testing" + CATEGORY = "experimental" RETURN_TYPES = ("LATENT", ) FUNCTION = "load" @@ -1443,7 +1443,7 @@ class LatentBlend: RETURN_TYPES = ("LATENT",) FUNCTION = "blend" - CATEGORY = "_for_testing" + CATEGORY = "experimental" def blend(self, samples1, samples2, blend_factor:float, blend_mode: str="normal"): @@ -2092,6 +2092,8 @@ NODE_DISPLAY_NAME_MAPPINGS = { "StyleModelLoader": "Load Style Model", "CLIPVisionLoader": "Load CLIP Vision", "UNETLoader": "Load Diffusion Model", + "unCLIPCheckpointLoader": "Load unCLIP Checkpoint", + "GLIGENLoader": "Load GLIGEN Model", # Conditioning "CLIPVisionEncode": "CLIP Vision Encode", "StyleModelApply": "Apply Style Model", @@ -2140,7 +2142,7 @@ NODE_DISPLAY_NAME_MAPPINGS = { "ImageSharpen": "Sharpen Image", "ImageScaleToTotalPixels": "Scale Image to Total Pixels", "GetImageSize": "Get Image Size", - # _for_testing + # experimental "VAEDecodeTiled": "VAE Decode (Tiled)", "VAEEncodeTiled": "VAE Encode (Tiled)", } diff --git a/tests/execution/testing_nodes/testing-pack/api_test_nodes.py b/tests/execution/testing_nodes/testing-pack/api_test_nodes.py index b2eaae05e..70c2a9e95 100644 --- a/tests/execution/testing_nodes/testing-pack/api_test_nodes.py +++ b/tests/execution/testing_nodes/testing-pack/api_test_nodes.py @@ -21,7 +21,7 @@ class TestAsyncProgressUpdate(ComfyNodeABC): RETURN_TYPES = (IO.ANY,) FUNCTION = "execute" - CATEGORY = "_for_testing/async" + CATEGORY = "experimental/async" async def execute(self, value, sleep_seconds): start = time.time() @@ -51,7 +51,7 @@ class TestSyncProgressUpdate(ComfyNodeABC): RETURN_TYPES = (IO.ANY,) FUNCTION = "execute" - CATEGORY = "_for_testing/async" + CATEGORY = "experimental/async" def execute(self, value, sleep_seconds): start = time.time() diff --git a/tests/execution/testing_nodes/testing-pack/async_test_nodes.py b/tests/execution/testing_nodes/testing-pack/async_test_nodes.py index 547eea6f4..589dabf17 100644 --- a/tests/execution/testing_nodes/testing-pack/async_test_nodes.py +++ b/tests/execution/testing_nodes/testing-pack/async_test_nodes.py @@ -21,7 +21,7 @@ class TestAsyncValidation(ComfyNodeABC): RETURN_TYPES = ("IMAGE",) FUNCTION = "process" - CATEGORY = "_for_testing/async" + CATEGORY = "experimental/async" @classmethod async def VALIDATE_INPUTS(cls, value, threshold): @@ -53,7 +53,7 @@ class TestAsyncError(ComfyNodeABC): RETURN_TYPES = (IO.ANY,) FUNCTION = "error_execution" - CATEGORY = "_for_testing/async" + CATEGORY = "experimental/async" async def error_execution(self, value, error_after): await asyncio.sleep(error_after) @@ -74,7 +74,7 @@ class 
TestAsyncValidationError(ComfyNodeABC): RETURN_TYPES = ("IMAGE",) FUNCTION = "process" - CATEGORY = "_for_testing/async" + CATEGORY = "experimental/async" @classmethod async def VALIDATE_INPUTS(cls, value, max_value): @@ -105,7 +105,7 @@ class TestAsyncTimeout(ComfyNodeABC): RETURN_TYPES = (IO.ANY,) FUNCTION = "timeout_execution" - CATEGORY = "_for_testing/async" + CATEGORY = "experimental/async" async def timeout_execution(self, value, timeout, operation_time): try: @@ -129,7 +129,7 @@ class TestSyncError(ComfyNodeABC): RETURN_TYPES = (IO.ANY,) FUNCTION = "sync_error" - CATEGORY = "_for_testing/async" + CATEGORY = "experimental/async" def sync_error(self, value): raise RuntimeError("Intentional sync execution error for testing") @@ -150,7 +150,7 @@ class TestAsyncLazyCheck(ComfyNodeABC): RETURN_TYPES = ("IMAGE",) FUNCTION = "process" - CATEGORY = "_for_testing/async" + CATEGORY = "experimental/async" async def check_lazy_status(self, condition, input1, input2): # Simulate async checking (e.g., querying remote service) @@ -184,7 +184,7 @@ class TestDynamicAsyncGeneration(ComfyNodeABC): RETURN_TYPES = ("IMAGE",) FUNCTION = "generate_async_workflow" - CATEGORY = "_for_testing/async" + CATEGORY = "experimental/async" def generate_async_workflow(self, image1, image2, num_async_nodes, sleep_duration): g = GraphBuilder() @@ -229,7 +229,7 @@ class TestAsyncResourceUser(ComfyNodeABC): RETURN_TYPES = (IO.ANY,) FUNCTION = "use_resource" - CATEGORY = "_for_testing/async" + CATEGORY = "experimental/async" async def use_resource(self, value, resource_id, duration): # Check if resource is already in use @@ -265,7 +265,7 @@ class TestAsyncBatchProcessing(ComfyNodeABC): RETURN_TYPES = ("IMAGE",) FUNCTION = "process_batch" - CATEGORY = "_for_testing/async" + CATEGORY = "experimental/async" async def process_batch(self, images, process_time_per_item, unique_id): batch_size = images.shape[0] @@ -305,7 +305,7 @@ class TestAsyncConcurrentLimit(ComfyNodeABC): RETURN_TYPES = (IO.ANY,) FUNCTION = "limited_execution" - CATEGORY = "_for_testing/async" + CATEGORY = "experimental/async" async def limited_execution(self, value, duration, node_id): async with self._semaphore: diff --git a/tests/execution/testing_nodes/testing-pack/specific_tests.py b/tests/execution/testing_nodes/testing-pack/specific_tests.py index 4f8f01ae4..2eb5d520e 100644 --- a/tests/execution/testing_nodes/testing-pack/specific_tests.py +++ b/tests/execution/testing_nodes/testing-pack/specific_tests.py @@ -409,7 +409,7 @@ class TestSleep(ComfyNodeABC): RETURN_TYPES = (IO.ANY,) FUNCTION = "sleep" - CATEGORY = "_for_testing" + CATEGORY = "experimental" async def sleep(self, value, seconds, unique_id): pbar = ProgressBar(seconds, node_id=unique_id) @@ -440,7 +440,7 @@ class TestParallelSleep(ComfyNodeABC): } RETURN_TYPES = ("IMAGE",) FUNCTION = "parallel_sleep" - CATEGORY = "_for_testing" + CATEGORY = "experimental" OUTPUT_NODE = True def parallel_sleep(self, image1, image2, image3, sleep1, sleep2, sleep3, unique_id): @@ -474,7 +474,7 @@ class TestOutputNodeWithSocketOutput: } RETURN_TYPES = ("IMAGE",) FUNCTION = "process" - CATEGORY = "_for_testing" + CATEGORY = "experimental" OUTPUT_NODE = True def process(self, image, value): From 56c74094c7c2ccbcf23f2aca1e4000199934da13 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Fri, 8 May 2026 09:39:13 +0300 Subject: [PATCH 091/102] [Partner Nodes] use "adaptive" aspect ratio for SD2 nodes (#13800) Signed-off-by: bigcat88 --- 
comfy_api_nodes/nodes_bytedance.py | 27 ++++++++++++++++++++------- 1 file changed, 20 insertions(+), 7 deletions(-) diff --git a/comfy_api_nodes/nodes_bytedance.py b/comfy_api_nodes/nodes_bytedance.py index 2f241a775..5f74f4a14 100644 --- a/comfy_api_nodes/nodes_bytedance.py +++ b/comfy_api_nodes/nodes_bytedance.py @@ -1271,7 +1271,7 @@ PRICE_BADGE_VIDEO = IO.PriceBadge( ) -def _seedance2_text_inputs(resolutions: list[str]): +def _seedance2_text_inputs(resolutions: list[str], default_ratio: str = "16:9"): return [ IO.String.Input( "prompt", @@ -1287,6 +1287,7 @@ def _seedance2_text_inputs(resolutions: list[str]): IO.Combo.Input( "ratio", options=["16:9", "4:3", "1:1", "3:4", "9:16", "21:9", "adaptive"], + default=default_ratio, tooltip="Aspect ratio of the output video.", ), IO.Int.Input( @@ -1420,8 +1421,14 @@ class ByteDance2FirstLastFrameNode(IO.ComfyNode): IO.DynamicCombo.Input( "model", options=[ - IO.DynamicCombo.Option("Seedance 2.0", _seedance2_text_inputs(["480p", "720p", "1080p"])), - IO.DynamicCombo.Option("Seedance 2.0 Fast", _seedance2_text_inputs(["480p", "720p"])), + IO.DynamicCombo.Option( + "Seedance 2.0", + _seedance2_text_inputs(["480p", "720p", "1080p"], default_ratio="adaptive"), + ), + IO.DynamicCombo.Option( + "Seedance 2.0 Fast", + _seedance2_text_inputs(["480p", "720p"], default_ratio="adaptive"), + ), ], tooltip="Seedance 2.0 for maximum quality; Seedance 2.0 Fast for speed optimization.", ), @@ -1588,9 +1595,9 @@ class ByteDance2FirstLastFrameNode(IO.ComfyNode): return IO.NodeOutput(await download_url_to_video_output(response.content.video_url)) -def _seedance2_reference_inputs(resolutions: list[str]): +def _seedance2_reference_inputs(resolutions: list[str], default_ratio: str = "16:9"): return [ - *_seedance2_text_inputs(resolutions), + *_seedance2_text_inputs(resolutions, default_ratio=default_ratio), IO.Autogrow.Input( "reference_images", template=IO.Autogrow.TemplateNames( @@ -1668,8 +1675,14 @@ class ByteDance2ReferenceNode(IO.ComfyNode): IO.DynamicCombo.Input( "model", options=[ - IO.DynamicCombo.Option("Seedance 2.0", _seedance2_reference_inputs(["480p", "720p", "1080p"])), - IO.DynamicCombo.Option("Seedance 2.0 Fast", _seedance2_reference_inputs(["480p", "720p"])), + IO.DynamicCombo.Option( + "Seedance 2.0", + _seedance2_reference_inputs(["480p", "720p", "1080p"], default_ratio="adaptive"), + ), + IO.DynamicCombo.Option( + "Seedance 2.0 Fast", + _seedance2_reference_inputs(["480p", "720p"], default_ratio="adaptive"), + ), ], tooltip="Seedance 2.0 for maximum quality; Seedance 2.0 Fast for speed optimization.", ), From bac6fc35fbf3fb2a6fc7e54fce17203215bcfff5 Mon Sep 17 00:00:00 2001 From: omahs <73983677+omahs@users.noreply.github.com> Date: Fri, 8 May 2026 11:14:45 +0200 Subject: [PATCH 092/102] Fix typos (#10986) --- comfy/hooks.py | 2 +- comfy/ldm/modules/diffusionmodules/util.py | 2 +- comfy_extras/nodes_flux.py | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/comfy/hooks.py b/comfy/hooks.py index 1a76c7ba4..5458fc3d8 100644 --- a/comfy/hooks.py +++ b/comfy/hooks.py @@ -93,7 +93,7 @@ class Hook: self.hook_scope = hook_scope '''Scope of where this hook should apply in terms of the conds used in sampling run.''' self.custom_should_register = default_should_register - '''Can be overriden with a compatible function to decide if this hook should be registered without the need to override .should_register''' + '''Can be overridden with a compatible function to decide if this hook should be registered without the need to override 
.should_register''' @property def strength(self): diff --git a/comfy/ldm/modules/diffusionmodules/util.py b/comfy/ldm/modules/diffusionmodules/util.py index 233011dc9..aed5c149c 100644 --- a/comfy/ldm/modules/diffusionmodules/util.py +++ b/comfy/ldm/modules/diffusionmodules/util.py @@ -140,7 +140,7 @@ def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True): alphas = alphacums[ddim_timesteps] alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist()) - # according the the formula provided in https://arxiv.org/abs/2010.02502 + # according to the formula provided in https://arxiv.org/abs/2010.02502 sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)) if verbose: logging.info(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}') diff --git a/comfy_extras/nodes_flux.py b/comfy_extras/nodes_flux.py index 3a23c7d04..5e04a5f77 100644 --- a/comfy_extras/nodes_flux.py +++ b/comfy_extras/nodes_flux.py @@ -102,7 +102,7 @@ class FluxDisableGuidance(io.ComfyNode): append = execute # TODO: remove -PREFERED_KONTEXT_RESOLUTIONS = [ +PREFERRED_KONTEXT_RESOLUTIONS = [ (672, 1568), (688, 1504), (720, 1456), @@ -143,7 +143,7 @@ class FluxKontextImageScale(io.ComfyNode): width = image.shape[2] height = image.shape[1] aspect_ratio = width / height - _, width, height = min((abs(aspect_ratio - w / h), w, h) for w, h in PREFERED_KONTEXT_RESOLUTIONS) + _, width, height = min((abs(aspect_ratio - w / h), w, h) for w, h in PREFERRED_KONTEXT_RESOLUTIONS) image = comfy.utils.common_upscale(image.movedim(-1, 1), width, height, "lanczos", "center").movedim(1, -1) return io.NodeOutput(image) From d3c18c163665a6f94e7dc56823aabcb93ebf7e5e Mon Sep 17 00:00:00 2001 From: "Yousef R. Gamaleldin" <81116377+yousef-rafat@users.noreply.github.com> Date: Fri, 8 May 2026 12:59:24 +0300 Subject: [PATCH 093/102] Add support for BiRefNet background remove model (CORE-46) (#12747) --- comfy/background_removal/birefnet.json | 7 + comfy/background_removal/birefnet.py | 689 ++++++++++++++++++ comfy/bg_removal_model.py | 78 ++ comfy/ops.py | 22 + comfy_api/latest/_io.py | 7 + comfy_extras/nodes_bg_removal.py | 60 ++ comfy_extras/nodes_mask.py | 27 +- folder_paths.py | 2 + .../put_background_removal_models_here | 0 nodes.py | 1 + 10 files changed, 887 insertions(+), 6 deletions(-) create mode 100644 comfy/background_removal/birefnet.json create mode 100644 comfy/background_removal/birefnet.py create mode 100644 comfy/bg_removal_model.py create mode 100644 comfy_extras/nodes_bg_removal.py create mode 100644 models/background_removal/put_background_removal_models_here diff --git a/comfy/background_removal/birefnet.json b/comfy/background_removal/birefnet.json new file mode 100644 index 000000000..f0960af39 --- /dev/null +++ b/comfy/background_removal/birefnet.json @@ -0,0 +1,7 @@ +{ + "model_type": "birefnet", + "image_std": [1.0, 1.0, 1.0], + "image_mean": [0.0, 0.0, 0.0], + "image_size": 1024, + "resize_to_original": true +} diff --git a/comfy/background_removal/birefnet.py b/comfy/background_removal/birefnet.py new file mode 100644 index 000000000..df54b2b90 --- /dev/null +++ b/comfy/background_removal/birefnet.py @@ -0,0 +1,689 @@ +import torch +import comfy.ops +import numpy as np +import torch.nn as nn +from functools import partial +import torch.nn.functional as F +from torchvision.ops import deform_conv2d +from comfy.ldm.modules.attention import optimized_attention_for_device + +CXT = [3072, 1536, 768, 384][1:][::-1][-3:] + +class 
Attention(nn.Module): + def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, device=None, dtype=None, operations=None): + super().__init__() + + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + self.q = operations.Linear(dim, dim, bias=qkv_bias, device=device, dtype=dtype) + self.kv = operations.Linear(dim, dim * 2, bias=qkv_bias, device=device, dtype=dtype) + self.proj = operations.Linear(dim, dim, device=device, dtype=dtype) + + def forward(self, x): + B, N, C = x.shape + optimized_attention = optimized_attention_for_device(x.device, mask=False, small_input=True) + q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) + kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + k, v = kv[0], kv[1] + + x = optimized_attention( + q, k, v, heads=self.num_heads, skip_output_reshape=True, skip_reshape=True + ).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + + return x + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, device=None, dtype=None, operations=None): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = operations.Linear(in_features, hidden_features, device=device, dtype=dtype) + self.act = nn.GELU() + self.fc2 = operations.Linear(hidden_features, out_features, device=device, dtype=dtype) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.fc2(x) + return x + + +def window_partition(x, window_size): + B, H, W, C = x.shape + x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows + + +def window_reverse(windows, window_size, H, W): + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +class WindowAttention(nn.Module): + def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, device=None, dtype=None, operations=None): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads, device=device, dtype=dtype)) + + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing='ij')) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + + self.qkv = operations.Linear(dim, dim * 3, bias=qkv_bias, device=device, dtype=dtype) + self.proj = operations.Linear(dim, dim, device=device, dtype=dtype) + self.softmax = 
nn.Softmax(dim=-1) + + def forward(self, x, mask=None): + B_, N, C = x.shape + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.long().view(-1)].view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + return x + + +class SwinTransformerBlock(nn.Module): + def __init__(self, dim, num_heads, window_size=7, shift_size=0, + mlp_ratio=4., qkv_bias=True, qk_scale=None, + norm_layer=nn.LayerNorm, device=None, dtype=None, operations=None): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + + self.norm1 = norm_layer(dim, device=device, dtype=dtype) + self.attn = WindowAttention( + dim, window_size=(self.window_size, self.window_size), num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, device=device, dtype=dtype, operations=operations) + + self.norm2 = norm_layer(dim, device=device, dtype=dtype) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, device=device, dtype=dtype, operations=operations) + + self.H = None + self.W = None + + def forward(self, x, mask_matrix): + B, L, C = x.shape + H, W = self.H, self.W + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + pad_l = pad_t = 0 + pad_r = (self.window_size - W % self.window_size) % self.window_size + pad_b = (self.window_size - H % self.window_size) % self.window_size + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hp, Wp, _ = x.shape + + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + attn_mask = mask_matrix + else: + shifted_x = x + attn_mask = None + + x_windows = window_partition(shifted_x, self.window_size) + x_windows = x_windows.view(-1, self.window_size * self.window_size, C) + + attn_windows = self.attn(x_windows, mask=attn_mask) + + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) + shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp) # B H' W' C + + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + x = shifted_x + + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + + x = x.view(B, H * W, C) + + x = shortcut + x + x = x + self.mlp(self.norm2(x)) + + return x + + +class PatchMerging(nn.Module): + def __init__(self, dim, device=None, dtype=None, operations=None): + super().__init__() + self.dim = dim + self.reduction = operations.Linear(4 * dim, 2 * dim, bias=False, device=device, dtype=dtype) + self.norm = operations.LayerNorm(4 * dim, device=device, dtype=dtype) + + def forward(self, x, H, W): + B, L, C = x.shape + x = x.view(B, H, W, C) + + # padding + pad_input = (H % 2 == 1) or (W % 2 == 1) + if pad_input: + x = 
F.pad(x, (0, 0, 0, W % 2, 0, H % 2)) + + x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C + x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C + x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C + + x = self.norm(x) + x = self.reduction(x) + + return x + + +class BasicLayer(nn.Module): + def __init__(self, + dim, + depth, + num_heads, + window_size=7, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + norm_layer=nn.LayerNorm, + downsample=None, + device=None, dtype=None, operations=None): + super().__init__() + self.window_size = window_size + self.shift_size = window_size // 2 + self.depth = depth + + # build blocks + self.blocks = nn.ModuleList([ + SwinTransformerBlock( + dim=dim, + num_heads=num_heads, + window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + norm_layer=norm_layer, + device=device, dtype=dtype, operations=operations) + for i in range(depth)]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(dim=dim, device=device, dtype=dtype, operations=operations) + else: + self.downsample = None + + def forward(self, x, H, W): + Hp = int(np.ceil(H / self.window_size)) * self.window_size + Wp = int(np.ceil(W / self.window_size)) * self.window_size + img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1 + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, self.window_size) + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + + for blk in self.blocks: + blk.H, blk.W = H, W + x = blk(x, attn_mask) + if self.downsample is not None: + x_down = self.downsample(x, H, W) + Wh, Ww = (H + 1) // 2, (W + 1) // 2 + return x, H, W, x_down, Wh, Ww + else: + return x, H, W, x, H, W + + +class PatchEmbed(nn.Module): + def __init__(self, patch_size=4, in_channels=3, embed_dim=96, norm_layer=None, device=None, dtype=None, operations=None): + super().__init__() + patch_size = (patch_size, patch_size) + self.patch_size = patch_size + + self.in_channels = in_channels + self.embed_dim = embed_dim + + self.proj = operations.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size, device=device, dtype=dtype) + if norm_layer is not None: + self.norm = norm_layer(embed_dim, device=device, dtype=dtype) + else: + self.norm = None + + def forward(self, x): + _, _, H, W = x.size() + if W % self.patch_size[1] != 0: + x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1])) + if H % self.patch_size[0] != 0: + x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0])) + + x = self.proj(x) # B C Wh Ww + if self.norm is not None: + Wh, Ww = x.size(2), x.size(3) + x = x.flatten(2).transpose(1, 2) + x = self.norm(x) + x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww) + + return x + + +class SwinTransformer(nn.Module): + def __init__(self, + pretrain_img_size=224, + patch_size=4, + in_channels=3, + embed_dim=96, + depths=[2, 
2, 6, 2], + num_heads=[3, 6, 12, 24], + window_size=7, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + patch_norm=True, + out_indices=(0, 1, 2, 3), + frozen_stages=-1, + device=None, dtype=None, operations=None): + super().__init__() + + norm_layer = partial(operations.LayerNorm, device=device, dtype=dtype) + self.pretrain_img_size = pretrain_img_size + self.num_layers = len(depths) + self.embed_dim = embed_dim + self.patch_norm = patch_norm + self.out_indices = out_indices + self.frozen_stages = frozen_stages + + self.patch_embed = PatchEmbed( + patch_size=patch_size, in_channels=in_channels, embed_dim=embed_dim, + device=device, dtype=dtype, operations=operations, + norm_layer=norm_layer if self.patch_norm else None) + + self.layers = nn.ModuleList() + for i_layer in range(self.num_layers): + layer = BasicLayer( + dim=int(embed_dim * 2 ** i_layer), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + norm_layer=norm_layer, + downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, + device=device, dtype=dtype, operations=operations) + self.layers.append(layer) + + num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)] + self.num_features = num_features + + for i_layer in out_indices: + layer = norm_layer(num_features[i_layer]) + layer_name = f'norm{i_layer}' + self.add_module(layer_name, layer) + + + def forward(self, x): + x = self.patch_embed(x) + + Wh, Ww = x.size(2), x.size(3) + + outs = [] + x = x.flatten(2).transpose(1, 2) + for i in range(self.num_layers): + layer = self.layers[i] + x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww) + + if i in self.out_indices: + norm_layer = getattr(self, f'norm{i}') + x_out = norm_layer(x_out) + + out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous() + outs.append(out) + + return tuple(outs) + +class DeformableConv2d(nn.Module): + def __init__(self, + in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1, + bias=False, device=None, dtype=None, operations=None): + + super(DeformableConv2d, self).__init__() + + kernel_size = kernel_size if type(kernel_size) is tuple else (kernel_size, kernel_size) + self.stride = stride if type(stride) is tuple else (stride, stride) + self.padding = padding + + self.offset_conv = operations.Conv2d(in_channels, + 2 * kernel_size[0] * kernel_size[1], + kernel_size=kernel_size, + stride=stride, + padding=self.padding, + bias=True, device=device, dtype=dtype) + + self.modulator_conv = operations.Conv2d(in_channels, + 1 * kernel_size[0] * kernel_size[1], + kernel_size=kernel_size, + stride=stride, + padding=self.padding, + bias=True, device=device, dtype=dtype) + + self.regular_conv = operations.Conv2d(in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=self.padding, + bias=bias, device=device, dtype=dtype) + + def forward(self, x): + offset = self.offset_conv(x) + modulator = 2. 
* torch.sigmoid(self.modulator_conv(x)) + weight, bias, offload_info = comfy.ops.cast_bias_weight(self.regular_conv, x, offloadable=True) + + x = deform_conv2d( + input=x, + offset=offset, + weight=weight, + bias=None, + padding=self.padding, + mask=modulator, + stride=self.stride, + ) + comfy.ops.uncast_bias_weight(self.regular_conv, weight, bias, offload_info) + return x + +class BasicDecBlk(nn.Module): + def __init__(self, in_channels=64, out_channels=64, inter_channels=64, device=None, dtype=None, operations=None): + super(BasicDecBlk, self).__init__() + inter_channels = 64 + self.conv_in = operations.Conv2d(in_channels, inter_channels, 3, 1, padding=1, device=device, dtype=dtype) + self.relu_in = nn.ReLU(inplace=True) + self.dec_att = ASPPDeformable(in_channels=inter_channels, device=device, dtype=dtype, operations=operations) + self.conv_out = operations.Conv2d(inter_channels, out_channels, 3, 1, padding=1, device=device, dtype=dtype) + self.bn_in = operations.BatchNorm2d(inter_channels, device=device, dtype=dtype) + self.bn_out = operations.BatchNorm2d(out_channels, device=device, dtype=dtype) + + def forward(self, x): + x = self.conv_in(x) + x = self.bn_in(x) + x = self.relu_in(x) + x = self.dec_att(x) + x = self.conv_out(x) + x = self.bn_out(x) + return x + + +class BasicLatBlk(nn.Module): + def __init__(self, in_channels=64, out_channels=64, device=None, dtype=None, operations=None): + super(BasicLatBlk, self).__init__() + self.conv = operations.Conv2d(in_channels, out_channels, 1, 1, 0, device=device, dtype=dtype) + + def forward(self, x): + x = self.conv(x) + return x + + +class _ASPPModuleDeformable(nn.Module): + def __init__(self, in_channels, planes, kernel_size, padding, device, dtype, operations): + super(_ASPPModuleDeformable, self).__init__() + self.atrous_conv = DeformableConv2d(in_channels, planes, kernel_size=kernel_size, + stride=1, padding=padding, bias=False, device=device, dtype=dtype, operations=operations) + self.bn = operations.BatchNorm2d(planes, device=device, dtype=dtype) + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + x = self.atrous_conv(x) + x = self.bn(x) + + return self.relu(x) + + +class ASPPDeformable(nn.Module): + def __init__(self, in_channels, out_channels=None, parallel_block_sizes=[1, 3, 7], device=None, dtype=None, operations=None): + super(ASPPDeformable, self).__init__() + self.down_scale = 1 + if out_channels is None: + out_channels = in_channels + self.in_channelster = 256 // self.down_scale + + self.aspp1 = _ASPPModuleDeformable(in_channels, self.in_channelster, 1, padding=0, device=device, dtype=dtype, operations=operations) + self.aspp_deforms = nn.ModuleList([ + _ASPPModuleDeformable(in_channels, self.in_channelster, conv_size, padding=int(conv_size//2), device=device, dtype=dtype, operations=operations) + for conv_size in parallel_block_sizes + ]) + + self.global_avg_pool = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)), + operations.Conv2d(in_channels, self.in_channelster, 1, stride=1, bias=False, device=device, dtype=dtype), + operations.BatchNorm2d(self.in_channelster, device=device, dtype=dtype), + nn.ReLU(inplace=True)) + self.conv1 = operations.Conv2d(self.in_channelster * (2 + len(self.aspp_deforms)), out_channels, 1, bias=False, device=device, dtype=dtype) + self.bn1 = operations.BatchNorm2d(out_channels, device=device, dtype=dtype) + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + x1 = self.aspp1(x) + x_aspp_deforms = [aspp_deform(x) for aspp_deform in self.aspp_deforms] + x5 = self.global_avg_pool(x) 
+ x5 = F.interpolate(x5, size=x1.size()[2:], mode='bilinear', align_corners=True) + x = torch.cat((x1, *x_aspp_deforms, x5), dim=1) + + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + + return x + +class BiRefNet(nn.Module): + def __init__(self, config=None, dtype=None, device=None, operations=None): + super(BiRefNet, self).__init__() + self.bb = SwinTransformer(embed_dim=192, depths=[2, 2, 18, 2], num_heads=[6, 12, 24, 48], window_size=12, device=device, dtype=dtype, operations=operations) + + channels = [1536, 768, 384, 192] + channels = [c * 2 for c in channels] + self.cxt = channels[1:][::-1][-3:] + self.squeeze_module = nn.Sequential(*[ + BasicDecBlk(channels[0]+sum(self.cxt), channels[0], device=device, dtype=dtype, operations=operations) + for _ in range(1) + ]) + + self.decoder = Decoder(channels, device=device, dtype=dtype, operations=operations) + + def forward_enc(self, x): + x1, x2, x3, x4 = self.bb(x) + B, C, H, W = x.shape + x1_, x2_, x3_, x4_ = self.bb(F.interpolate(x, size=(H//2, W//2), mode='bilinear', align_corners=True)) + x1 = torch.cat([x1, F.interpolate(x1_, size=x1.shape[2:], mode='bilinear', align_corners=True)], dim=1) + x2 = torch.cat([x2, F.interpolate(x2_, size=x2.shape[2:], mode='bilinear', align_corners=True)], dim=1) + x3 = torch.cat([x3, F.interpolate(x3_, size=x3.shape[2:], mode='bilinear', align_corners=True)], dim=1) + x4 = torch.cat([x4, F.interpolate(x4_, size=x4.shape[2:], mode='bilinear', align_corners=True)], dim=1) + x4 = torch.cat( + ( + *[ + F.interpolate(x1, size=x4.shape[2:], mode='bilinear', align_corners=True), + F.interpolate(x2, size=x4.shape[2:], mode='bilinear', align_corners=True), + F.interpolate(x3, size=x4.shape[2:], mode='bilinear', align_corners=True), + ][-len(CXT):], + x4 + ), + dim=1 + ) + return (x1, x2, x3, x4) + + def forward_ori(self, x): + (x1, x2, x3, x4) = self.forward_enc(x) + x4 = self.squeeze_module(x4) + features = [x, x1, x2, x3, x4] + scaled_preds = self.decoder(features) + return scaled_preds + + def forward(self, pixel_values, intermediate_output=None): + scaled_preds = self.forward_ori(pixel_values) + return scaled_preds + + +class Decoder(nn.Module): + def __init__(self, channels, device, dtype, operations): + super(Decoder, self).__init__() + # factory kwargs + fk = {"device":device, "dtype":dtype, "operations":operations} + DecoderBlock = partial(BasicDecBlk, **fk) + LateralBlock = partial(BasicLatBlk, **fk) + DBlock = partial(SimpleConvs, **fk) + + self.split = True + N_dec_ipt = 64 + ic = 64 + ipt_cha_opt = 1 + self.ipt_blk5 = DBlock(2**10*3 if self.split else 3, [N_dec_ipt, channels[0]//8][ipt_cha_opt], inter_channels=ic) + self.ipt_blk4 = DBlock(2**8*3 if self.split else 3, [N_dec_ipt, channels[0]//8][ipt_cha_opt], inter_channels=ic) + self.ipt_blk3 = DBlock(2**6*3 if self.split else 3, [N_dec_ipt, channels[1]//8][ipt_cha_opt], inter_channels=ic) + self.ipt_blk2 = DBlock(2**4*3 if self.split else 3, [N_dec_ipt, channels[2]//8][ipt_cha_opt], inter_channels=ic) + self.ipt_blk1 = DBlock(2**0*3 if self.split else 3, [N_dec_ipt, channels[3]//8][ipt_cha_opt], inter_channels=ic) + + self.decoder_block4 = DecoderBlock(channels[0]+([N_dec_ipt, channels[0]//8][ipt_cha_opt]), channels[1]) + self.decoder_block3 = DecoderBlock(channels[1]+([N_dec_ipt, channels[0]//8][ipt_cha_opt]), channels[2]) + self.decoder_block2 = DecoderBlock(channels[2]+([N_dec_ipt, channels[1]//8][ipt_cha_opt]), channels[3]) + self.decoder_block1 = DecoderBlock(channels[3]+([N_dec_ipt, channels[2]//8][ipt_cha_opt]), channels[3]//2) + + fk 
= {"device":device, "dtype":dtype} + + self.conv_out1 = nn.Sequential(operations.Conv2d(channels[3]//2+([N_dec_ipt, channels[3]//8][ipt_cha_opt]), 1, 1, 1, 0, **fk)) + + self.lateral_block4 = LateralBlock(channels[1], channels[1]) + self.lateral_block3 = LateralBlock(channels[2], channels[2]) + self.lateral_block2 = LateralBlock(channels[3], channels[3]) + + self.conv_ms_spvn_4 = operations.Conv2d(channels[1], 1, 1, 1, 0, **fk) + self.conv_ms_spvn_3 = operations.Conv2d(channels[2], 1, 1, 1, 0, **fk) + self.conv_ms_spvn_2 = operations.Conv2d(channels[3], 1, 1, 1, 0, **fk) + + _N = 16 + + self.gdt_convs_4 = nn.Sequential(operations.Conv2d(channels[0] // 2, _N, 3, 1, 1, **fk), operations.BatchNorm2d(_N, **fk), nn.ReLU(inplace=True)) + self.gdt_convs_3 = nn.Sequential(operations.Conv2d(channels[1] // 2, _N, 3, 1, 1, **fk), operations.BatchNorm2d(_N, **fk), nn.ReLU(inplace=True)) + self.gdt_convs_2 = nn.Sequential(operations.Conv2d(channels[2] // 2, _N, 3, 1, 1, **fk), operations.BatchNorm2d(_N, **fk), nn.ReLU(inplace=True)) + + [setattr(self, f"gdt_convs_pred_{i}", nn.Sequential(operations.Conv2d(_N, 1, 1, 1, 0, **fk))) for i in range(2, 5)] + [setattr(self, f"gdt_convs_attn_{i}", nn.Sequential(operations.Conv2d(_N, 1, 1, 1, 0, **fk))) for i in range(2, 5)] + + def get_patches_batch(self, x, p): + _size_h, _size_w = p.shape[2:] + patches_batch = [] + for idx in range(x.shape[0]): + columns_x = torch.split(x[idx], split_size_or_sections=_size_w, dim=-1) + patches_x = [] + for column_x in columns_x: + patches_x += [p.unsqueeze(0) for p in torch.split(column_x, split_size_or_sections=_size_h, dim=-2)] + patch_sample = torch.cat(patches_x, dim=1) + patches_batch.append(patch_sample) + return torch.cat(patches_batch, dim=0) + + def forward(self, features): + x, x1, x2, x3, x4 = features + + patches_batch = self.get_patches_batch(x, x4) if self.split else x + x4 = torch.cat((x4, self.ipt_blk5(F.interpolate(patches_batch, size=x4.shape[2:], mode='bilinear', align_corners=True))), 1) + p4 = self.decoder_block4(x4) + p4_gdt = self.gdt_convs_4(p4) + gdt_attn_4 = self.gdt_convs_attn_4(p4_gdt).sigmoid() + p4 = p4 * gdt_attn_4 + _p4 = F.interpolate(p4, size=x3.shape[2:], mode='bilinear', align_corners=True) + _p3 = _p4 + self.lateral_block4(x3) + + patches_batch = self.get_patches_batch(x, _p3) if self.split else x + _p3 = torch.cat((_p3, self.ipt_blk4(F.interpolate(patches_batch, size=x3.shape[2:], mode='bilinear', align_corners=True))), 1) + p3 = self.decoder_block3(_p3) + + p3_gdt = self.gdt_convs_3(p3) + gdt_attn_3 = self.gdt_convs_attn_3(p3_gdt).sigmoid() + p3 = p3 * gdt_attn_3 + _p3 = F.interpolate(p3, size=x2.shape[2:], mode='bilinear', align_corners=True) + _p2 = _p3 + self.lateral_block3(x2) + + patches_batch = self.get_patches_batch(x, _p2) if self.split else x + _p2 = torch.cat((_p2, self.ipt_blk3(F.interpolate(patches_batch, size=x2.shape[2:], mode='bilinear', align_corners=True))), 1) + p2 = self.decoder_block2(_p2) + + p2_gdt = self.gdt_convs_2(p2) + gdt_attn_2 = self.gdt_convs_attn_2(p2_gdt).sigmoid() + p2 = p2 * gdt_attn_2 + + _p2 = F.interpolate(p2, size=x1.shape[2:], mode='bilinear', align_corners=True) + _p1 = _p2 + self.lateral_block2(x1) + + patches_batch = self.get_patches_batch(x, _p1) if self.split else x + _p1 = torch.cat((_p1, self.ipt_blk2(F.interpolate(patches_batch, size=x1.shape[2:], mode='bilinear', align_corners=True))), 1) + _p1 = self.decoder_block1(_p1) + _p1 = F.interpolate(_p1, size=x.shape[2:], mode='bilinear', align_corners=True) + + patches_batch = 
self.get_patches_batch(x, _p1) if self.split else x + _p1 = torch.cat((_p1, self.ipt_blk1(F.interpolate(patches_batch, size=x.shape[2:], mode='bilinear', align_corners=True))), 1) + p1_out = self.conv_out1(_p1) + return p1_out + + +class SimpleConvs(nn.Module): + def __init__( + self, in_channels: int, out_channels: int, inter_channels=64, device=None, dtype=None, operations=None + ) -> None: + super().__init__() + self.conv1 = operations.Conv2d(in_channels, inter_channels, 3, 1, 1, device=device, dtype=dtype) + self.conv_out = operations.Conv2d(inter_channels, out_channels, 3, 1, 1, device=device, dtype=dtype) + + def forward(self, x): + return self.conv_out(self.conv1(x)) diff --git a/comfy/bg_removal_model.py b/comfy/bg_removal_model.py new file mode 100644 index 000000000..cb7c2ee53 --- /dev/null +++ b/comfy/bg_removal_model.py @@ -0,0 +1,78 @@ +from .utils import load_torch_file +import os +import json +import torch +import logging + +import comfy.ops +import comfy.model_patcher +import comfy.model_management +import comfy.clip_model +import comfy.background_removal.birefnet + +BG_REMOVAL_MODELS = { + "birefnet": comfy.background_removal.birefnet.BiRefNet +} + +class BackgroundRemovalModel(): + def __init__(self, json_config): + with open(json_config) as f: + config = json.load(f) + + self.image_size = config.get("image_size", 1024) + self.image_mean = config.get("image_mean", [0.0, 0.0, 0.0]) + self.image_std = config.get("image_std", [1.0, 1.0, 1.0]) + self.model_type = config.get("model_type", "birefnet") + self.config = config.copy() + model_class = BG_REMOVAL_MODELS.get(self.model_type) + + self.load_device = comfy.model_management.text_encoder_device() + offload_device = comfy.model_management.text_encoder_offload_device() + self.dtype = comfy.model_management.text_encoder_dtype(self.load_device) + self.model = model_class(config, self.dtype, offload_device, comfy.ops.manual_cast) + self.model.eval() + + self.patcher = comfy.model_patcher.CoreModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device) + + def load_sd(self, sd): + return self.model.load_state_dict(sd, strict=False, assign=self.patcher.is_dynamic()) + + def get_sd(self): + return self.model.state_dict() + + def encode_image(self, image): + comfy.model_management.load_model_gpu(self.patcher) + H, W = image.shape[1], image.shape[2] + pixel_values = comfy.clip_model.clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std, crop=False) + out = self.model(pixel_values=pixel_values) + out = torch.nn.functional.interpolate(out, size=(H, W), mode="bicubic", antialias=False) + + mask = out.sigmoid() + if mask.ndim == 3: + mask = mask.unsqueeze(0) + if mask.shape[1] != 1: + mask = mask.movedim(-1, 1) + + return mask + + +def load_background_removal_model(sd): + if "bb.layers.1.blocks.0.attn.relative_position_index" in sd: + json_config = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "background_removal"), "birefnet.json") + else: + return None + + bg_model = BackgroundRemovalModel(json_config) + m, u = bg_model.load_sd(sd) + if len(m) > 0: + logging.warning("missing background removal: {}".format(m)) + u = set(u) + keys = list(sd.keys()) + for k in keys: + if k not in u: + sd.pop(k) + return bg_model + +def load(ckpt_path): + sd = load_torch_file(ckpt_path) + return load_background_removal_model(sd) diff --git a/comfy/ops.py b/comfy/ops.py index 585c185a3..77ad1d527 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -562,6 
+562,25 @@ class disable_weight_init: else: return super().forward(*args, **kwargs) + class BatchNorm2d(torch.nn.BatchNorm2d, CastWeightBiasOp): + def reset_parameters(self): + return None + + def forward_comfy_cast_weights(self, input): + weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True) + running_mean = self.running_mean.to(device=input.device, dtype=weight.dtype) if self.running_mean is not None else None + running_var = self.running_var.to(device=input.device, dtype=weight.dtype) if self.running_var is not None else None + x = torch.nn.functional.batch_norm(input, running_mean, running_var, weight, bias, self.training, self.momentum, self.eps) + uncast_bias_weight(self, weight, bias, offload_stream) + return x + + def forward(self, *args, **kwargs): + run_every_op() + if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0: + return self.forward_comfy_cast_weights(*args, **kwargs) + else: + return super().forward(*args, **kwargs) + class LayerNorm(torch.nn.LayerNorm, CastWeightBiasOp): def reset_parameters(self): return None @@ -749,6 +768,9 @@ class manual_cast(disable_weight_init): class Conv3d(disable_weight_init.Conv3d): comfy_cast_weights = True + class BatchNorm2d(disable_weight_init.BatchNorm2d): + comfy_cast_weights = True + class GroupNorm(disable_weight_init.GroupNorm): comfy_cast_weights = True diff --git a/comfy_api/latest/_io.py b/comfy_api/latest/_io.py index e50266bc5..5ed968960 100644 --- a/comfy_api/latest/_io.py +++ b/comfy_api/latest/_io.py @@ -17,6 +17,7 @@ if TYPE_CHECKING: from spandrel import ImageModelDescriptor from comfy.clip_vision import ClipVisionModel from comfy.clip_vision import Output as ClipVisionOutput_ + from comfy.bg_removal_model import BackgroundRemovalModel from comfy.controlnet import ControlNet from comfy.hooks import HookGroup, HookKeyframeGroup from comfy.model_patcher import ModelPatcher @@ -614,6 +615,11 @@ class Model(ComfyTypeIO): if TYPE_CHECKING: Type = ModelPatcher +@comfytype(io_type="BACKGROUND_REMOVAL") +class BackgroundRemoval(ComfyTypeIO): + if TYPE_CHECKING: + Type = BackgroundRemovalModel + @comfytype(io_type="CLIP_VISION") class ClipVision(ComfyTypeIO): if TYPE_CHECKING: @@ -2257,6 +2263,7 @@ __all__ = [ "ModelPatch", "ClipVision", "ClipVisionOutput", + "BackgroundRemoval", "AudioEncoder", "AudioEncoderOutput", "StyleModel", diff --git a/comfy_extras/nodes_bg_removal.py b/comfy_extras/nodes_bg_removal.py new file mode 100644 index 000000000..8d046b8d4 --- /dev/null +++ b/comfy_extras/nodes_bg_removal.py @@ -0,0 +1,60 @@ +import folder_paths +from typing_extensions import override +from comfy_api.latest import ComfyExtension, IO +from comfy.bg_removal_model import load + + +class LoadBackgroundRemovalModel(IO.ComfyNode): + @classmethod + def define_schema(cls): + files = folder_paths.get_filename_list("background_removal") + return IO.Schema( + node_id="LoadBackgroundRemovalModel", + display_name="Load Background Removal Model", + category="loaders", + inputs=[ + IO.Combo.Input("bg_removal_name", options=sorted(files), tooltip="The model used to remove backgrounds from images"), + ], + outputs=[ + IO.BackgroundRemoval.Output("bg_model") + ] + ) + @classmethod + def execute(cls, bg_removal_name): + path = folder_paths.get_full_path_or_raise("background_removal", bg_removal_name) + bg = load(path) + if bg is None: + raise RuntimeError("ERROR: background model file is invalid and does not contain a valid background removal model.") + return IO.NodeOutput(bg) + +class 
RemoveBackground(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="RemoveBackground", + display_name="Remove Background", + category="image/background removal", + inputs=[ + IO.Image.Input("image", tooltip="Input image to remove the background from"), + IO.BackgroundRemoval.Input("bg_removal_model", tooltip="Background removal model used to generate the mask") + ], + outputs=[ + IO.Mask.Output("mask", tooltip="Generated foreground mask") + ] + ) + @classmethod + def execute(cls, image, bg_removal_model): + mask = bg_removal_model.encode_image(image) + return IO.NodeOutput(mask) + +class BackgroundRemovalExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[IO.ComfyNode]]: + return [ + LoadBackgroundRemovalModel, + RemoveBackground + ] + + +async def comfy_entrypoint() -> BackgroundRemovalExtension: + return BackgroundRemovalExtension() diff --git a/comfy_extras/nodes_mask.py b/comfy_extras/nodes_mask.py index 43a933dac..c9b2a84d9 100644 --- a/comfy_extras/nodes_mask.py +++ b/comfy_extras/nodes_mask.py @@ -40,10 +40,21 @@ def composite(destination, source, x, y, mask = None, multiplier = 8, resize_sou inverse_mask = torch.ones_like(mask) - mask - source_portion = mask * source[..., :visible_height, :visible_width] - destination_portion = inverse_mask * destination[..., top:bottom, left:right] + source_rgb = source[:, :3, :visible_height, :visible_width] + dest_slice = destination[..., top:bottom, left:right] + + if destination.shape[1] == 4: + if torch.max(dest_slice) == 0: + destination[:, :3, top:bottom, left:right] = source_rgb + destination[:, 3:4, top:bottom, left:right] = mask + else: + destination[:, :3, top:bottom, left:right] = (mask * source_rgb) + (inverse_mask * dest_slice[:, :3]) + destination[:, 3:4, top:bottom, left:right] = torch.max(mask, dest_slice[:, 3:4]) + else: + source_portion = mask * source_rgb + destination_portion = inverse_mask * dest_slice + destination[..., top:bottom, left:right] = source_portion + destination_portion - destination[..., top:bottom, left:right] = source_portion + destination_portion return destination class LatentCompositeMasked(IO.ComfyNode): @@ -84,18 +95,23 @@ class ImageCompositeMasked(IO.ComfyNode): display_name="Image Composite Masked", category="image", inputs=[ - IO.Image.Input("destination"), IO.Image.Input("source"), IO.Int.Input("x", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1), IO.Int.Input("y", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1), IO.Boolean.Input("resize_source", default=False), + IO.Image.Input("destination", optional=True), IO.Mask.Input("mask", optional=True), ], outputs=[IO.Image.Output()], ) @classmethod - def execute(cls, destination, source, x, y, resize_source, mask = None) -> IO.NodeOutput: + def execute(cls, source, x, y, resize_source, destination = None, mask = None) -> IO.NodeOutput: + if destination is None: # transparent rgba + B, H, W, C = source.shape + destination = torch.zeros((B, H, W, 4), dtype=source.dtype, device=source.device) + if C == 3: + source = torch.nn.functional.pad(source, (0, 1), value=1.0) destination, source = node_helpers.image_alpha_fix(destination, source) destination = destination.clone().movedim(-1, 1) output = composite(destination, source.movedim(-1, 1), x, y, mask, 1, resize_source).movedim(1, -1) @@ -381,7 +397,6 @@ class GrowMask(IO.ComfyNode): expand_mask = execute # TODO: remove - class ThresholdMask(IO.ComfyNode): @classmethod def define_schema(cls): diff --git a/folder_paths.py 
b/folder_paths.py index 98d3b1880..92e8df3cf 100644 --- a/folder_paths.py +++ b/folder_paths.py @@ -52,6 +52,8 @@ folder_names_and_paths["model_patches"] = ([os.path.join(models_dir, "model_patc folder_names_and_paths["audio_encoders"] = ([os.path.join(models_dir, "audio_encoders")], supported_pt_extensions) +folder_names_and_paths["background_removal"] = ([os.path.join(models_dir, "background_removal")], supported_pt_extensions) + folder_names_and_paths["frame_interpolation"] = ([os.path.join(models_dir, "frame_interpolation")], supported_pt_extensions) folder_names_and_paths["optical_flow"] = ([os.path.join(models_dir, "optical_flow")], supported_pt_extensions) diff --git a/models/background_removal/put_background_removal_models_here b/models/background_removal/put_background_removal_models_here new file mode 100644 index 000000000..e69de29bb diff --git a/nodes.py b/nodes.py index ae9e70cb9..5755f0bb8 100644 --- a/nodes.py +++ b/nodes.py @@ -2429,6 +2429,7 @@ async def init_builtin_extra_nodes(): "nodes_number_convert.py", "nodes_painter.py", "nodes_curve.py", + "nodes_bg_removal.py", "nodes_rtdetr.py", "nodes_frame_interpolation.py", "nodes_sam3.py", From 05cd076bc1d9386ec77414c96d1460008f653f7c Mon Sep 17 00:00:00 2001 From: drozbay <17261091+drozbay@users.noreply.github.com> Date: Fri, 8 May 2026 08:48:59 -0600 Subject: [PATCH 094/102] fix: Make LTXVAddGuide center-crop guide images to match other LTXV nodes (#13794) --- comfy_extras/nodes_lt.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy_extras/nodes_lt.py b/comfy_extras/nodes_lt.py index ab1359fdb..f1f4d5319 100644 --- a/comfy_extras/nodes_lt.py +++ b/comfy_extras/nodes_lt.py @@ -236,7 +236,7 @@ class LTXVAddGuide(io.ComfyNode): def encode(cls, vae, latent_width, latent_height, images, scale_factors): time_scale_factor, width_scale_factor, height_scale_factor = scale_factors images = images[:(images.shape[0] - 1) // time_scale_factor * time_scale_factor + 1] - pixels = comfy.utils.common_upscale(images.movedim(-1, 1), latent_width * width_scale_factor, latent_height * height_scale_factor, "bilinear", crop="disabled").movedim(1, -1) + pixels = comfy.utils.common_upscale(images.movedim(-1, 1), latent_width * width_scale_factor, latent_height * height_scale_factor, "bilinear", crop="center").movedim(1, -1) encode_pixels = pixels[:, :, :, :3] t = vae.encode(encode_pixels) return encode_pixels, t From 9864f5ac86778221d730c9626952a1ee15c16994 Mon Sep 17 00:00:00 2001 From: drozbay <17261091+drozbay@users.noreply.github.com> Date: Fri, 8 May 2026 09:02:17 -0600 Subject: [PATCH 095/102] fix: Stop LTXVImgToVideoInplace from mutating input latents and dropping noise_mask (#13793) --- comfy_extras/nodes_lt.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/comfy_extras/nodes_lt.py b/comfy_extras/nodes_lt.py index f1f4d5319..a4c85db77 100644 --- a/comfy_extras/nodes_lt.py +++ b/comfy_extras/nodes_lt.py @@ -106,12 +106,12 @@ class LTXVImgToVideoInplace(io.ComfyNode): if bypass: return (latent,) - samples = latent["samples"] + samples = latent["samples"].clone() _, height_scale_factor, width_scale_factor = ( vae.downscale_index_formula ) - batch, _, latent_frames, latent_height, latent_width = samples.shape + _, _, _, latent_height, latent_width = samples.shape width = latent_width * width_scale_factor height = latent_height * height_scale_factor @@ -124,11 +124,7 @@ class LTXVImgToVideoInplace(io.ComfyNode): samples[:, :, :t.shape[2]] = t - conditioning_latent_frames_mask = torch.ones( - 
(batch, 1, latent_frames, 1, 1), - dtype=torch.float32, - device=samples.device, - ) + conditioning_latent_frames_mask = get_noise_mask(latent).clone() conditioning_latent_frames_mask[:, :, :t.shape[2]] = 1.0 - strength return io.NodeOutput({"samples": samples, "noise_mask": conditioning_latent_frames_mask}) From c5ecd231a2aa41124ec6a958416d166d7dcb81fb Mon Sep 17 00:00:00 2001 From: Alexis Rolland Date: Fri, 8 May 2026 23:06:29 +0800 Subject: [PATCH 096/102] fix: Fix bug when mask not on same device (CORE-181) (#13801) --- comfy/bg_removal_model.py | 2 +- comfy_extras/nodes_compositing.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/comfy/bg_removal_model.py b/comfy/bg_removal_model.py index cb7c2ee53..7877afd7f 100644 --- a/comfy/bg_removal_model.py +++ b/comfy/bg_removal_model.py @@ -47,7 +47,7 @@ class BackgroundRemovalModel(): out = self.model(pixel_values=pixel_values) out = torch.nn.functional.interpolate(out, size=(H, W), mode="bicubic", antialias=False) - mask = out.sigmoid() + mask = out.sigmoid().to(device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype()) if mask.ndim == 3: mask = mask.unsqueeze(0) if mask.shape[1] != 1: diff --git a/comfy_extras/nodes_compositing.py b/comfy_extras/nodes_compositing.py index 5b4423734..720efc629 100644 --- a/comfy_extras/nodes_compositing.py +++ b/comfy_extras/nodes_compositing.py @@ -203,7 +203,7 @@ class JoinImageWithAlpha(io.ComfyNode): @classmethod def execute(cls, image: torch.Tensor, alpha: torch.Tensor) -> io.NodeOutput: batch_size = max(len(image), len(alpha)) - alpha = 1.0 - resize_mask(alpha, image.shape[1:]) + alpha = 1.0 - resize_mask(alpha.to(image), image.shape[1:]) alpha = comfy.utils.repeat_to_batch_size(alpha, batch_size) image = comfy.utils.repeat_to_batch_size(image, batch_size) return io.NodeOutput(torch.cat((image[..., :3], alpha.unsqueeze(-1)), dim=-1)) From 87878f354f4d49446ed81b5ebfb98b12dda37c7c Mon Sep 17 00:00:00 2001 From: Matt Miller Date: Fri, 8 May 2026 12:39:16 -0700 Subject: [PATCH 097/102] Add cloud-runtime FE-facing operations to spec (#13734) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Add cloud-runtime FE-facing operations to openapi.yaml Add ~67 cloud-runtime FE-facing path operations to the core OpenAPI spec, each tagged with x-runtime: [cloud] at the operation level. These operations are served by the cloud runtime; the local runtime returns 404 for all of these paths. Domain groups added: - Jobs / prompts: /api/job/*, /api/jobs/*/cancel, /api/prompt/*, etc. - History v2: /api/history_v2, /api/history_v2/{prompt_id} - Cloud logs: /api/logs - Asset extensions: /api/assets/download, export, import, etc. - Custom nodes: /api/experiment/nodes (cloud install/uninstall) - Hub: /api/hub/profiles, /api/hub/workflows, /api/hub/labels, etc. - Workflows: /api/workflows CRUD, versioning, fork, publish - Auth/session: /api/auth/session, /api/auth/token, /.well-known/jwks.json - Billing: /api/billing/balance, plans, subscribe, topup, etc. - Workspace: /api/workspace/*, /api/workspaces/* - User/settings/misc: /api/user, /api/secrets, /api/feedback, etc. Also adds corresponding cloud-only component schemas (CloudJob, CloudWorkflow, BillingPlan, Workspace, HubProfile, AuthSession, etc.), all tagged with x-runtime: [cloud]. Spectral lint passes under the existing ruleset with zero new warnings. &#13;
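For reviewers who want to sanity-check the tagging, a minimal sketch that prints every operation carrying x-runtime: [cloud]. It assumes PyYAML is installed and the spec sits at openapi.yaml in the working directory; neither is part of this change:

    # List the cloud-only operations declared in the spec.
    import yaml  # assumption: PyYAML is available

    with open("openapi.yaml") as f:
        spec = yaml.safe_load(f)

    METHODS = {"get", "post", "put", "patch", "delete", "head", "options"}
    for path, item in (spec.get("paths") or {}).items():
        for method, op in item.items():
            # Skip path-level keys (parameters, summary, ...) and keep only
            # operations explicitly tagged as cloud-served.
            if method in METHODS and isinstance(op, dict) and op.get("x-runtime") == ["cloud"]:
                print(f"{method.upper():8s}{path}")

Run against this change, it should list the ~67 operations enumerated above.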
* Add job_id field to Asset schema and deprecate prompt_id (#13736) - Add job_id as a nullable UUID field to the Asset schema - Mark prompt_id as deprecated with note pointing to job_id - No x-runtime tag needed as both runtimes populate the field * Add hash field to Asset schemas and deprecate asset_hash (#13738) - Add 'hash' as a nullable string field to Asset and AssetUpdated schemas - Mark 'asset_hash' as deprecated with a note pointing to 'hash' - AssetCreated inherits 'hash' via allOf from Asset - Spectral lint clean (no new warnings) * Fix method drift on cloud-runtime endpoints Three PUT operations were added that should be PATCH (cloud serves PATCH for partial updates): - /api/workflows/{workflow_id} - /api/workspaces/{id} - /api/workspace/members/{userId} Two POST operations were added that should be GET (cloud serves GET with query params): - /api/assets/remote-metadata (url moves to query param) - /api/files/mask-layers (response shape replaced — operation queries related mask layer filenames, not file uploads) * Add missing cloud-runtime operations and schemas PR review surfaced operations the cloud runtime serves that weren't covered by the initial spec push, plus one path family missed entirely. New methods on existing paths: - /api/auth/session: add POST (create session cookie) and DELETE (logout) - /api/secrets/{id}: add GET (read metadata) and PATCH (update) - /api/hub/profiles: add POST (create profile) - /api/hub/workflows: add POST (publish to hub) - /api/hub/workflows/{share_id}: add DELETE (unpublish) - /api/workspaces/{id}: add DELETE (soft-delete workspace) - /api/workspace/members/{user_id}/api-keys: add DELETE (bulk revoke) - /api/workflows/{workflow_id}/versions: add POST (create new version) - /api/userdata/{file}/publish: add GET (read publish info) New path family: - /api/tasks (GET list) and /api/tasks/{task_id} (GET detail) for the background task framework New component schemas (all tagged x-runtime: [cloud]): CreateSessionResponse, DeleteSessionResponse, UpdateSecretRequest, BulkRevokeAPIKeysResponse, CreateHubProfileRequest, PublishHubWorkflowRequest, HubWorkflowDetail, AssetInfo, CreateWorkflowVersionRequest, WorkflowVersionResponse, WorkflowPublishInfo, TaskEntry, TaskResponse, TasksListResponse. Existing SecretMeta extended with provider and last_used_at fields the cloud runtime actually returns. New tag: task. Spectral lint passes with zero errors. * Add job_id and prompt_id to AssetUpdated schema Mirrors the Asset schema's deprecation pattern: prompt_id is marked deprecated with a description pointing to job_id; job_id is the new preferred field. PUT /api/assets/{id} responses can now carry both fields consistent with the other Asset-returning endpoints. * feat: add width and height fields to Asset schema (#13745) Add nullable integer fields 'width' and 'height' to the Asset schema in openapi.yaml. These expose original image dimensions in pixels for clients that need pre-thumbnail size info. Both fields are null for non-image assets or assets ingested before dimension extraction. Co-authored-by: Matt Miller * Remove /api/job/{job_id} and /api/job/{job_id}/outputs These two paths are not actually served by the cloud runtime — they return 404 with a redirect message pointing callers to the canonical `/api/jobs/{job_id}` (plural). Declaring them with `x-runtime: [cloud]` and a 200 response schema is incorrect. `/api/job/{job_id}/status` stays — it is a real cloud-served endpoint. 
Also drops the now-orphaned `CloudJob` and `CloudJobOutputs` component schemas. `CloudJobStatus` is retained. --- openapi.yaml | 4716 +++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 4714 insertions(+), 2 deletions(-) diff --git a/openapi.yaml b/openapi.yaml index 29b5f544b..4216c1a6c 100644 --- a/openapi.yaml +++ b/openapi.yaml @@ -62,6 +62,19 @@ tags: - name: assets description: Asset management (feature-gated behind enable-assets) + - name: auth + description: Authentication and session management (cloud-only) + - name: billing + description: Billing, subscriptions, and payment management (cloud-only) + - name: workspace + description: Workspace and team management (cloud-only) + - name: hub + description: "ComfyUI Hub: profiles, shared workflows, and labels (cloud-only)" + - name: workflows + description: Cloud workflow management and versioning (cloud-only) + - name: task + description: Background task management (cloud-only) + paths: # --------------------------------------------------------------------------- # WebSocket @@ -2056,6 +2069,3449 @@ paths: type: integer description: Number of assets marked as missing + + # =========================================================================== + # Cloud-runtime FE-facing operations + # + # These operations are served by the cloud runtime. The local runtime returns + # 404 for all of these paths. Each operation is tagged x-runtime: [cloud]. + # =========================================================================== + + # --------------------------------------------------------------------------- + # Jobs / prompts (cloud) + # --------------------------------------------------------------------------- + /api/jobs/{job_id}/cancel: + post: + operationId: cancelJob + tags: [queue] + summary: Cancel a running or pending job + description: "[cloud-only] Requests cancellation of a job. If the job is currently executing, execution is interrupted. If it is pending in the queue, it is removed." + x-runtime: [cloud] + parameters: + - name: job_id + in: path + required: true + schema: + type: string + format: uuid + description: The job ID to cancel. + responses: + "200": + description: Cancellation accepted + content: + application/json: + schema: + $ref: "#/components/schemas/CloudJobStatus" + "401": + description: Unauthorized + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + "404": + description: Not found + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + + /api/job/{job_id}/status: + get: + operationId: getCloudJobStatus + tags: [queue] + summary: Get status of a cloud job + description: "[cloud-only] Returns the current execution status of a cloud job." + x-runtime: [cloud] + parameters: + - name: job_id + in: path + required: true + schema: + type: string + format: uuid + description: The job ID to check status for. + responses: + "200": + description: Job status + content: + application/json: + schema: + $ref: "#/components/schemas/CloudJobStatus" + "401": + description: Unauthorized + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + "404": + description: Not found + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + + /api/prompt/{prompt_id}: + get: + operationId: getCloudPrompt + tags: [prompt] + summary: Get a cloud prompt by ID + description: "[cloud-only] Returns the full prompt record for a cloud-executed prompt, including the submitted workflow graph and execution metadata." 
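+      # Illustrative call; the response payload is defined by the CloudPrompt
+      # schema, and the UUID below is a made-up placeholder:
+      #   GET /api/prompt/7f3a9b2c-1111-2222-3333-444455556666 -> 200 CloudPrompt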
+ x-runtime: [cloud] + parameters: + - name: prompt_id + in: path + required: true + schema: + type: string + format: uuid + description: The prompt ID to fetch. + responses: + "200": + description: Cloud prompt detail + content: + application/json: + schema: + $ref: "#/components/schemas/CloudPrompt" + "401": + description: Unauthorized + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + "404": + description: Not found + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + + /api/history_v2: + get: + operationId: getHistoryV2 + tags: [history] + summary: Get paginated execution history (v2) + description: "[cloud-only] Returns a paginated list of execution history entries in the v2 format, with richer metadata than the legacy history endpoint." + x-runtime: [cloud] + parameters: + - name: limit + in: query + schema: + type: integer + default: 20 + description: Maximum number of results + - name: offset + in: query + schema: + type: integer + default: 0 + description: Pagination offset + - name: status + in: query + schema: + type: string + description: Filter by execution status + responses: + "200": + description: History list + content: + application/json: + schema: + $ref: "#/components/schemas/HistoryV2Response" + "401": + description: Unauthorized + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + + /api/history_v2/{prompt_id}: + get: + operationId: getHistoryV2ByPromptId + tags: [history] + summary: Get v2 history for a specific prompt + description: "[cloud-only] Returns the v2 history entry for a specific prompt execution." + x-runtime: [cloud] + parameters: + - name: prompt_id + in: path + required: true + schema: + type: string + format: uuid + description: The prompt ID to fetch history for. + responses: + "200": + description: History entry + content: + application/json: + schema: + $ref: "#/components/schemas/HistoryV2Entry" + "401": + description: Unauthorized + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + "404": + description: Not found + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + + /api/logs: + get: + operationId: getCloudLogs + tags: [system] + summary: Get cloud execution logs + description: "[cloud-only] Returns execution logs for the authenticated user's cloud jobs." + x-runtime: [cloud] + parameters: + - name: job_id + in: query + schema: + type: string + description: Filter logs by job ID + - name: limit + in: query + schema: + type: integer + default: 100 + description: Maximum number of log entries + - name: offset + in: query + schema: + type: integer + default: 0 + description: Pagination offset + responses: + "200": + description: Log entries + content: + application/json: + schema: + $ref: "#/components/schemas/CloudLogsResponse" + "401": + description: Unauthorized + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + + # --------------------------------------------------------------------------- + # Assets extensions (cloud) + # --------------------------------------------------------------------------- + /api/assets/download: + post: + operationId: downloadAssets + tags: [assets] + summary: Download assets to cloud runtime + description: "[cloud-only] Initiates a download of one or more assets to the cloud runtime environment. Returns a task ID for tracking download progress via WebSocket." 
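+      # Sketch of the flow; request item fields are defined by the
+      # AssetDownloadRequest schema, and the task ID is a placeholder:
+      #   POST /api/assets/download {"assets": [ ... ]} -> 200 {"task_id": "abc123"}
+      #   Progress events for "abc123" then arrive over the WebSocket.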
+ x-runtime: [cloud] + requestBody: + required: true + content: + application/json: + schema: + type: object + required: + - assets + properties: + assets: + type: array + items: + $ref: "#/components/schemas/AssetDownloadRequest" + description: Assets to download + responses: + "200": + description: Download initiated + content: + application/json: + schema: + type: object + properties: + task_id: + type: string + description: Task ID for tracking progress via WebSocket + "400": + description: Bad request + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + "401": + description: Unauthorized + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + + /api/assets/export: + post: + operationId: exportAssets + tags: [assets] + summary: Export assets as a downloadable archive + description: "[cloud-only] Initiates a bulk export of assets. Returns a task ID for tracking progress via WebSocket. When complete, the export can be downloaded via the exports endpoint." + x-runtime: [cloud] + requestBody: + required: true + content: + application/json: + schema: + type: object + required: + - asset_ids + properties: + asset_ids: + type: array + items: + type: string + format: uuid + description: IDs of assets to export + export_name: + type: string + description: Name for the export archive + responses: + "200": + description: Export initiated + content: + application/json: + schema: + type: object + properties: + task_id: + type: string + export_name: + type: string + "400": + description: Bad request + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + "401": + description: Unauthorized + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + + /api/assets/exports/{exportName}: + get: + operationId: getAssetExport + tags: [assets] + summary: Download a completed asset export + description: "[cloud-only] Returns the archive file for a completed asset export." + x-runtime: [cloud] + parameters: + - name: exportName + in: path + required: true + schema: + type: string + description: Name of the export to download + responses: + "200": + description: Export archive file + content: + application/zip: + schema: + type: string + format: binary + "401": + description: Unauthorized + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + "404": + description: Not found + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + + /api/assets/from-workflow: + post: + operationId: createAssetsFromWorkflow + tags: [assets] + summary: Create asset records from a workflow execution + description: "[cloud-only] Registers output files from a workflow execution as assets in the asset database." 
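+      # Illustrative request using the fields declared below; the UUID and
+      # tag values are placeholders:
+      #   POST /api/assets/from-workflow {"prompt_id": "9c1d...", "tags": ["txt2img"]}
+      #   -> 201 {"assets": [ ...one Asset per registered output... ]}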
+ x-runtime: [cloud] + requestBody: + required: true + content: + application/json: + schema: + type: object + required: + - prompt_id + properties: + prompt_id: + type: string + format: uuid + description: Prompt ID whose outputs should be registered as assets + tags: + type: array + items: + type: string + description: Tags to apply to the created assets + responses: + "201": + description: Assets created + content: + application/json: + schema: + type: object + properties: + assets: + type: array + items: + $ref: "#/components/schemas/Asset" + "400": + description: Bad request + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + "401": + description: Unauthorized + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + "404": + description: Not found + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + + /api/assets/import: + post: + operationId: importAssets + tags: [assets] + summary: Import assets from external URLs + description: "[cloud-only] Imports one or more assets from external URLs into the cloud asset store." + x-runtime: [cloud] + requestBody: + required: true + content: + application/json: + schema: + type: object + required: + - imports + properties: + imports: + type: array + items: + $ref: "#/components/schemas/AssetImportRequest" + description: Assets to import + responses: + "200": + description: Import initiated + content: + application/json: + schema: + type: object + properties: + assets: + type: array + items: + $ref: "#/components/schemas/Asset" + "400": + description: Bad request + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + "401": + description: Unauthorized + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + + /api/assets/remote-metadata: + get: + operationId: getAssetRemoteMetadata + tags: [assets] + summary: Fetch metadata for a remote asset URL + description: "[cloud-only] Fetches and returns metadata (content type, size, filename) for a remote URL without downloading the full content." + x-runtime: [cloud] + parameters: + - name: url + in: query + required: true + schema: + type: string + format: uri + description: URL to inspect + responses: + "200": + description: Remote metadata + content: + application/json: + schema: + $ref: "#/components/schemas/RemoteAssetMetadata" + "400": + description: Bad request + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + "401": + description: Unauthorized + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + + # --------------------------------------------------------------------------- + # Custom nodes / hub (cloud) + # --------------------------------------------------------------------------- + /api/experiment/nodes: + get: + operationId: listCloudNodes + tags: [node] + summary: List installed custom nodes + description: "[cloud-only] Returns the list of custom node packages installed in the cloud runtime." 
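+      # Illustrative call using the pagination params declared below:
+      #   GET /api/experiment/nodes?limit=20&offset=0 -> 200 CloudNodeList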
+ x-runtime: [cloud] + parameters: + - name: limit + in: query + schema: + type: integer + description: Maximum number of results + - name: offset + in: query + schema: + type: integer + description: Pagination offset + responses: + "200": + description: Custom node list + content: + application/json: + schema: + $ref: "#/components/schemas/CloudNodeList" + "401": + description: Unauthorized + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + post: + operationId: installCloudNode + tags: [node] + summary: Install a custom node package + description: "[cloud-only] Installs a custom node package in the cloud runtime by ID or repository URL." + x-runtime: [cloud] + requestBody: + required: true + content: + application/json: + schema: + type: object + required: + - id + properties: + id: + type: string + description: Node package ID or repository URL + version: + type: string + description: Specific version to install + responses: + "200": + description: Node installed + content: + application/json: + schema: + $ref: "#/components/schemas/CloudNode" + "400": + description: Bad request + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + "401": + description: Unauthorized + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + "404": + description: Not found + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + + /api/experiment/nodes/{id}: + get: + operationId: getCloudNode + tags: [node] + summary: Get details of an installed custom node + description: "[cloud-only] Returns details about a specific installed custom node package." + x-runtime: [cloud] + parameters: + - name: id + in: path + required: true + schema: + type: string + description: Custom node package ID + responses: + "200": + description: Node detail + content: + application/json: + schema: + $ref: "#/components/schemas/CloudNode" + "401": + description: Unauthorized + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + "404": + description: Not found + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + delete: + operationId: uninstallCloudNode + tags: [node] + summary: Uninstall a custom node package + description: "[cloud-only] Removes a custom node package from the cloud runtime." + x-runtime: [cloud] + parameters: + - name: id + in: path + required: true + schema: + type: string + description: Custom node package ID + responses: + "204": + description: Node uninstalled + "401": + description: Unauthorized + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + "404": + description: Not found + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + + /api/hub/assets/upload-url: + post: + operationId: getHubAssetUploadUrl + tags: [hub] + summary: Get a pre-signed upload URL for a hub asset + description: "[cloud-only] Returns a pre-signed URL that can be used to upload an asset file directly to storage." 
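+      # Sketch of the intended two-step flow; step 2 is the usual pre-signed
+      # upload pattern and is an assumption, not spelled out by this spec:
+      #   1. POST /api/hub/assets/upload-url
+      #        {"filename": "cover.png", "content_type": "image/png", "size": 123456}
+      #      -> 200 {"upload_url": "...", "asset_url": "..."}
+      #   2. PUT the raw file bytes to upload_url, then reference asset_url.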
+ x-runtime: [cloud] + requestBody: + required: true + content: + application/json: + schema: + type: object + required: + - filename + - content_type + properties: + filename: + type: string + description: Name of the file to upload + content_type: + type: string + description: MIME type of the file + size: + type: integer + format: int64 + description: File size in bytes + responses: + "200": + description: Upload URL + content: + application/json: + schema: + type: object + properties: + upload_url: + type: string + format: uri + description: Pre-signed upload URL + asset_url: + type: string + format: uri + description: Public URL after upload completes + "400": + description: Bad request + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + "401": + description: Unauthorized + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + + /api/hub/labels: + get: + operationId: listHubLabels + tags: [hub] + summary: List available hub labels + description: "[cloud-only] Returns the list of labels/categories available for tagging hub content." + x-runtime: [cloud] + responses: + "200": + description: Label list + content: + application/json: + schema: + type: array + items: + $ref: "#/components/schemas/HubLabel" + + /api/hub/profiles: + get: + operationId: listHubProfiles + tags: [hub] + summary: List hub user profiles + description: "[cloud-only] Returns a paginated list of public hub user profiles." + x-runtime: [cloud] + parameters: + - name: limit + in: query + schema: + type: integer + description: Maximum number of results + - name: offset + in: query + schema: + type: integer + description: Pagination offset + - name: search + in: query + schema: + type: string + description: Search by username or display name + responses: + "200": + description: Profile list + content: + application/json: + schema: + type: object + properties: + profiles: + type: array + items: + $ref: "#/components/schemas/HubProfile" + total: + type: integer + has_more: + type: boolean + post: + operationId: createHubProfile + tags: [hub] + summary: Create a Hub profile + description: "[cloud-only] Creates a hub profile for the specified workspace. Username is immutable after creation." + x-runtime: [cloud] + requestBody: + required: true + content: + application/json: + schema: + $ref: "#/components/schemas/CreateHubProfileRequest" + responses: + "201": + description: Hub profile created + content: + application/json: + schema: + $ref: "#/components/schemas/HubProfile" + "400": + description: Bad request (e.g. invalid username) + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + "401": + description: Unauthorized + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + "404": + description: Not found + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + "409": + description: Username already taken or profile already exists + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + + /api/hub/profiles/{username}: + get: + operationId: getHubProfile + tags: [hub] + summary: Get a hub profile by username + description: "[cloud-only] Returns the public hub profile for the given username." 
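+      # Illustrative call; the username is a placeholder:
+      #   GET /api/hub/profiles/some_user -> 200 HubProfile (404 if unknown)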
+ x-runtime: [cloud] + parameters: + - name: username + in: path + required: true + schema: + type: string + description: Hub username + responses: + "200": + description: Profile + content: + application/json: + schema: + $ref: "#/components/schemas/HubProfile" + "404": + description: Not found + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + + /api/hub/profiles/check: + get: + operationId: checkHubProfileUsername + tags: [hub] + summary: Check if a hub username is available + description: "[cloud-only] Returns whether the given username is available for registration." + x-runtime: [cloud] + parameters: + - name: username + in: query + required: true + schema: + type: string + description: Username to check + responses: + "200": + description: Availability result + content: + application/json: + schema: + type: object + properties: + available: + type: boolean + username: + type: string + + /api/hub/profiles/me: + get: + operationId: getMyHubProfile + tags: [hub] + summary: Get the authenticated user's hub profile + description: "[cloud-only] Returns the hub profile of the currently authenticated user." + x-runtime: [cloud] + responses: + "200": + description: Profile + content: + application/json: + schema: + $ref: "#/components/schemas/HubProfile" + "401": + description: Unauthorized + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + put: + operationId: updateMyHubProfile + tags: [hub] + summary: Update the authenticated user's hub profile + description: "[cloud-only] Updates the hub profile of the currently authenticated user." + x-runtime: [cloud] + requestBody: + required: true + content: + application/json: + schema: + type: object + properties: + username: + type: string + display_name: + type: string + bio: + type: string + avatar_url: + type: string + format: uri + links: + type: array + items: + type: string + format: uri + responses: + "200": + description: Updated profile + content: + application/json: + schema: + $ref: "#/components/schemas/HubProfile" + "400": + description: Bad request + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + "401": + description: Unauthorized + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + "409": + description: Conflict + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + + /api/hub/workflows: + get: + operationId: listHubWorkflows + tags: [hub] + summary: List published hub workflows + description: "[cloud-only] Returns a paginated list of publicly shared workflows on the hub." + x-runtime: [cloud] + parameters: + - name: limit + in: query + schema: + type: integer + description: Maximum number of results + - name: offset + in: query + schema: + type: integer + description: Pagination offset + - name: sort + in: query + schema: + type: string + description: Sort field (e.g. 
created_at, likes) + - name: order + in: query + schema: + type: string + enum: [asc, desc] + description: Sort direction + - name: search + in: query + schema: + type: string + description: Search by title or description + - name: labels + in: query + schema: + type: string + description: Filter by label IDs (comma-separated) + responses: + "200": + description: Hub workflow list + content: + application/json: + schema: + $ref: "#/components/schemas/HubWorkflowList" + post: + operationId: publishHubWorkflow + tags: [hub] + summary: Publish a workflow to the hub + description: "[cloud-only] Publishes a workflow to the hub with metadata, thumbnail, and sample images." + x-runtime: [cloud] + requestBody: + required: true + content: + application/json: + schema: + $ref: "#/components/schemas/PublishHubWorkflowRequest" + responses: + "200": + description: Workflow published to hub + content: + application/json: + schema: + $ref: "#/components/schemas/HubWorkflowDetail" + "400": + description: Bad request + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + "401": + description: Unauthorized + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + "404": + description: Workflow or profile not found + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + + /api/hub/workflows/{share_id}: + get: + operationId: getHubWorkflow + tags: [hub] + summary: Get a published hub workflow by share ID + description: "[cloud-only] Returns the full details of a published workflow on the hub." + x-runtime: [cloud] + parameters: + - name: share_id + in: path + required: true + schema: + type: string + description: Workflow share ID + responses: + "200": + description: Hub workflow + content: + application/json: + schema: + $ref: "#/components/schemas/HubWorkflow" + "404": + description: Not found + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + delete: + operationId: deleteHubWorkflow + tags: [hub] + summary: Unpublish a workflow from the hub + description: "[cloud-only] Removes a workflow from the hub listing." + x-runtime: [cloud] + parameters: + - name: share_id + in: path + required: true + schema: + type: string + description: Workflow share ID + responses: + "204": + description: Successfully unpublished + "401": + description: Unauthorized + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + "404": + description: Workflow not found + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + + /api/hub/workflows/index: + get: + operationId: getHubWorkflowIndex + tags: [hub] + summary: Get the hub workflow index + description: "[cloud-only] Returns the lightweight index of all hub workflows for client-side search and navigation." + x-runtime: [cloud] + responses: + "200": + description: Workflow index + content: + application/json: + schema: + type: array + items: + $ref: "#/components/schemas/HubWorkflowIndexEntry" + + # --------------------------------------------------------------------------- + # Workflows (cloud) + # --------------------------------------------------------------------------- + /api/workflows: + get: + operationId: listCloudWorkflows + tags: [workflows] + summary: List cloud workflows + description: "[cloud-only] Returns a paginated list of the authenticated user's cloud workflows." 
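+      # Illustrative call using the query params declared below; the sort
+      # field value is a placeholder, since the spec does not enumerate them:
+      #   GET /api/workflows?limit=20&offset=0&sort=updated_at&order=desc&search=upscale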
+ x-runtime: [cloud] + parameters: + - name: limit + in: query + schema: + type: integer + description: Maximum number of results + - name: offset + in: query + schema: + type: integer + description: Pagination offset + - name: sort + in: query + schema: + type: string + description: Sort field + - name: order + in: query + schema: + type: string + enum: [asc, desc] + description: Sort direction + - name: search + in: query + schema: + type: string + description: Search by workflow name + responses: + "200": + description: Workflow list + content: + application/json: + schema: + $ref: "#/components/schemas/CloudWorkflowList" + "401": + description: Unauthorized + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + post: + operationId: createCloudWorkflow + tags: [workflows] + summary: Create a new cloud workflow + description: "[cloud-only] Creates a new cloud workflow with the provided name and optional initial content." + x-runtime: [cloud] + requestBody: + required: true + content: + application/json: + schema: + type: object + required: + - name + properties: + name: + type: string + description: Workflow name + description: + type: string + description: Workflow description + content: + type: object + additionalProperties: true + description: Initial workflow graph JSON + responses: + "201": + description: Workflow created + content: + application/json: + schema: + $ref: "#/components/schemas/CloudWorkflow" + "400": + description: Bad request + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + "401": + description: Unauthorized + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + + /api/workflows/{workflow_id}: + get: + operationId: getCloudWorkflow + tags: [workflows] + summary: Get a cloud workflow by ID + description: "[cloud-only] Returns the metadata for a cloud workflow." + x-runtime: [cloud] + parameters: + - name: workflow_id + in: path + required: true + schema: + type: string + format: uuid + description: The workflow ID. + responses: + "200": + description: Workflow detail + content: + application/json: + schema: + $ref: "#/components/schemas/CloudWorkflow" + "401": + description: Unauthorized + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + "404": + description: Not found + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + patch: + operationId: updateCloudWorkflow + tags: [workflows] + summary: Update a cloud workflow + description: "[cloud-only] Updates the metadata (name, description) of an existing cloud workflow." + x-runtime: [cloud] + parameters: + - name: workflow_id + in: path + required: true + schema: + type: string + format: uuid + description: The workflow ID. 
+ requestBody: + required: true + content: + application/json: + schema: + type: object + properties: + name: + type: string + description: + type: string + responses: + "200": + description: Workflow updated + content: + application/json: + schema: + $ref: "#/components/schemas/CloudWorkflow" + "400": + description: Bad request + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + "401": + description: Unauthorized + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + "404": + description: Not found + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + delete: + operationId: deleteCloudWorkflow + tags: [workflows] + summary: Delete a cloud workflow + description: "[cloud-only] Deletes a cloud workflow and all its versions." + x-runtime: [cloud] + parameters: + - name: workflow_id + in: path + required: true + schema: + type: string + format: uuid + description: The workflow ID. + responses: + "204": + description: Workflow deleted + "401": + description: Unauthorized + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + "404": + description: Not found + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + + /api/workflows/{workflow_id}/content: + get: + operationId: getCloudWorkflowContent + tags: [workflows] + summary: Get the content of a cloud workflow + description: "[cloud-only] Returns the full workflow graph JSON for the latest version of a cloud workflow." + x-runtime: [cloud] + parameters: + - name: workflow_id + in: path + required: true + schema: + type: string + format: uuid + description: The workflow ID. + - name: version_id + in: query + schema: + type: string + description: Specific version ID to fetch + responses: + "200": + description: Workflow content + content: + application/json: + schema: + type: object + additionalProperties: true + description: The full workflow graph JSON + "401": + description: Unauthorized + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + "404": + description: Not found + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + put: + operationId: updateCloudWorkflowContent + tags: [workflows] + summary: Update the content of a cloud workflow + description: "[cloud-only] Saves new workflow graph JSON as a new version of the cloud workflow." + x-runtime: [cloud] + parameters: + - name: workflow_id + in: path + required: true + schema: + type: string + format: uuid + description: The workflow ID. + requestBody: + required: true + content: + application/json: + schema: + type: object + additionalProperties: true + description: The workflow graph JSON to save + responses: + "200": + description: Content updated + content: + application/json: + schema: + $ref: "#/components/schemas/CloudWorkflowVersion" + "400": + description: Bad request + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + "401": + description: Unauthorized + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + "404": + description: Not found + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + + /api/workflows/{workflow_id}/fork: + post: + operationId: forkCloudWorkflow + tags: [workflows] + summary: Fork a cloud workflow + description: "[cloud-only] Creates a copy of a cloud workflow under the authenticated user's account." 
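+      # Illustrative call; the body is optional and the name is a placeholder:
+      #   POST /api/workflows/{workflow_id}/fork {"name": "My fork"} -> 201 CloudWorkflow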
+ x-runtime: [cloud] + parameters: + - name: workflow_id + in: path + required: true + schema: + type: string + format: uuid + description: The workflow ID to fork. + requestBody: + required: false + content: + application/json: + schema: + type: object + properties: + name: + type: string + description: Name for the forked workflow (defaults to original name) + responses: + "201": + description: Forked workflow + content: + application/json: + schema: + $ref: "#/components/schemas/CloudWorkflow" + "401": + description: Unauthorized + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + "404": + description: Not found + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + + /api/workflows/{workflow_id}/versions: + get: + operationId: listCloudWorkflowVersions + tags: [workflows] + summary: List versions of a cloud workflow + description: "[cloud-only] Returns the version history of a cloud workflow." + x-runtime: [cloud] + parameters: + - name: workflow_id + in: path + required: true + schema: + type: string + format: uuid + description: The workflow ID. + - name: limit + in: query + schema: + type: integer + description: Maximum number of results + - name: offset + in: query + schema: + type: integer + description: Pagination offset + responses: + "200": + description: Version list + content: + application/json: + schema: + type: object + properties: + versions: + type: array + items: + $ref: "#/components/schemas/CloudWorkflowVersion" + total: + type: integer + has_more: + type: boolean + "401": + description: Unauthorized + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + "404": + description: Not found + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + post: + operationId: createCloudWorkflowVersion + tags: [workflows] + summary: Create a new cloud workflow version + description: "[cloud-only] Creates a new workflow version with updated workflow JSON. Uses optimistic concurrency via base_version." + x-runtime: [cloud] + parameters: + - name: workflow_id + in: path + required: true + schema: + type: string + format: uuid + description: The workflow ID. + requestBody: + required: true + content: + application/json: + schema: + $ref: "#/components/schemas/CreateWorkflowVersionRequest" + responses: + "201": + description: Version created + content: + application/json: + schema: + $ref: "#/components/schemas/WorkflowVersionResponse" + "401": + description: Unauthorized + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + "403": + description: Forbidden — not the workflow owner + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + "404": + description: Not found + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + "409": + description: Version conflict — base_version does not match latest + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + + /api/workflows/published/{share_id}: + get: + operationId: getPublishedWorkflow + tags: [workflows] + summary: Get a published workflow by share ID + description: "[cloud-only] Returns a publicly published cloud workflow by its share identifier." + x-runtime: [cloud] + parameters: + - name: share_id + in: path + required: true + schema: + type: string + description: The workflow share ID. 
+ responses: + "200": + description: Published workflow + content: + application/json: + schema: + $ref: "#/components/schemas/CloudWorkflow" + "404": + description: Not found + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + + # --------------------------------------------------------------------------- + # Auth / session (cloud) + # --------------------------------------------------------------------------- + /api/auth/session: + get: + operationId: getAuthSession + tags: [auth] + summary: Get the current authentication session + description: "[cloud-only] Returns the current session state for the authenticated user, including user identity and active workspace." + x-runtime: [cloud] + responses: + "200": + description: Session info + content: + application/json: + schema: + $ref: "#/components/schemas/AuthSession" + "401": + description: Unauthorized + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + post: + operationId: createAuthSession + tags: [auth] + summary: Create a session cookie + description: "[cloud-only] Creates a session cookie from the bearer token in the Authorization header. Returns a Set-Cookie header with a secure HttpOnly session cookie. Cookie authentication is not allowed for this endpoint." + x-runtime: [cloud] + responses: + "200": + description: Session created + content: + application/json: + schema: + $ref: "#/components/schemas/CreateSessionResponse" + "400": + description: Bad request — invalid or expired ID token + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + "401": + description: Unauthorized + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + delete: + operationId: deleteAuthSession + tags: [auth] + summary: Delete session cookie (logout) + description: "[cloud-only] Clears the session cookie and optionally revokes the session on the server." + x-runtime: [cloud] + responses: + "200": + description: Session deleted + content: + application/json: + schema: + $ref: "#/components/schemas/DeleteSessionResponse" + + /api/auth/token: + post: + operationId: createAuthToken + tags: [auth] + summary: Exchange credentials for an access token + description: "[cloud-only] Exchanges authentication credentials (e.g. an authorization code) for an access token." + x-runtime: [cloud] + requestBody: + required: true + content: + application/json: + schema: + type: object + required: + - grant_type + properties: + grant_type: + type: string + enum: [authorization_code, refresh_token] + description: OAuth2 grant type + code: + type: string + description: Authorization code (for authorization_code grant) + refresh_token: + type: string + description: Refresh token (for refresh_token grant) + redirect_uri: + type: string + format: uri + description: Redirect URI used in the authorization request + responses: + "200": + description: Token response + content: + application/json: + schema: + $ref: "#/components/schemas/AuthTokenResponse" + "400": + description: Bad request + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + "401": + description: Unauthorized + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + + /.well-known/jwks.json: + get: + operationId: getJwks + tags: [auth] + summary: Get JSON Web Key Set + description: "[cloud-only] Returns the JSON Web Key Set (JWKS) used to verify JWTs issued by the cloud authentication service." 
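+      # Standard RFC 7517 document shape; the key below is illustrative and
+      # not a real credential:
+      #   {"keys": [{"kty": "RSA", "kid": "abc", "use": "sig", "alg": "RS256",
+      #              "n": "...", "e": "AQAB"}]}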
+ x-runtime: [cloud] + responses: + "200": + description: JWKS + content: + application/json: + schema: + $ref: "#/components/schemas/JwksResponse" + + # --------------------------------------------------------------------------- + # Billing (cloud) + # --------------------------------------------------------------------------- + /api/billing/balance: + get: + operationId: getBillingBalance + tags: [billing] + summary: Get current credit balance + description: "[cloud-only] Returns the authenticated user's current credit balance and usage summary." + x-runtime: [cloud] + responses: + "200": + description: Balance info + content: + application/json: + schema: + $ref: "#/components/schemas/BillingBalance" + "401": + description: Unauthorized + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + + /api/billing/events: + get: + operationId: listBillingEvents + tags: [billing] + summary: List billing events + description: "[cloud-only] Returns a paginated list of billing events (charges, credits, refunds) for the authenticated user." + x-runtime: [cloud] + parameters: + - name: limit + in: query + schema: + type: integer + description: Maximum number of results + - name: offset + in: query + schema: + type: integer + description: Pagination offset + - name: type + in: query + schema: + type: string + description: Filter by event type + responses: + "200": + description: Billing events + content: + application/json: + schema: + $ref: "#/components/schemas/BillingEventList" + "401": + description: Unauthorized + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + + /api/billing/ops/{id}: + get: + operationId: getBillingOp + tags: [billing] + summary: Get a billing operation by ID + description: "[cloud-only] Returns details of a specific billing operation." + x-runtime: [cloud] + parameters: + - name: id + in: path + required: true + schema: + type: string + description: The billing operation ID. + responses: + "200": + description: Billing operation + content: + application/json: + schema: + $ref: "#/components/schemas/BillingOp" + "401": + description: Unauthorized + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + "404": + description: Not found + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + + /api/billing/payment-portal: + post: + operationId: createPaymentPortalSession + tags: [billing] + summary: Create a payment portal session + description: "[cloud-only] Creates a Stripe customer portal session for managing payment methods and invoices. Returns a URL to redirect the user to." + x-runtime: [cloud] + responses: + "200": + description: Portal session + content: + application/json: + schema: + type: object + properties: + url: + type: string + format: uri + description: Stripe portal URL + "401": + description: Unauthorized + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + + /api/billing/plans: + get: + operationId: listBillingPlans + tags: [billing] + summary: List available billing plans + description: "[cloud-only] Returns the list of available subscription plans and their pricing." 
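+      # Illustrative response; the field names are placeholders, as the actual
+      # shape is defined by the BillingPlan schema:
+      #   GET /api/billing/plans -> 200 [{"id": "plan_basic", ...}, ...]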
+ x-runtime: [cloud] + responses: + "200": + description: Plan list + content: + application/json: + schema: + type: array + items: + $ref: "#/components/schemas/BillingPlan" + + /api/billing/preview-subscribe: + post: + operationId: previewSubscription + tags: [billing] + summary: Preview a subscription change + description: "[cloud-only] Returns a preview of what a subscription change would cost, including prorations." + x-runtime: [cloud] + requestBody: + required: true + content: + application/json: + schema: + type: object + required: + - plan_id + properties: + plan_id: + type: string + description: ID of the plan to preview + responses: + "200": + description: Subscription preview + content: + application/json: + schema: + $ref: "#/components/schemas/SubscriptionPreview" + "400": + description: Bad request + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + "401": + description: Unauthorized + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + + /api/billing/status: + get: + operationId: getBillingStatus + tags: [billing] + summary: Get billing status + description: "[cloud-only] Returns the authenticated user's current billing and subscription status." + x-runtime: [cloud] + responses: + "200": + description: Billing status + content: + application/json: + schema: + $ref: "#/components/schemas/BillingStatus" + "401": + description: Unauthorized + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + + /api/billing/subscribe: + post: + operationId: createSubscription + tags: [billing] + summary: Subscribe to a billing plan + description: "[cloud-only] Creates a new subscription to the specified billing plan." + x-runtime: [cloud] + requestBody: + required: true + content: + application/json: + schema: + type: object + required: + - plan_id + properties: + plan_id: + type: string + description: ID of the plan to subscribe to + payment_method_id: + type: string + description: Stripe payment method ID + responses: + "200": + description: Subscription created + content: + application/json: + schema: + $ref: "#/components/schemas/BillingSubscription" + "400": + description: Bad request + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + "401": + description: Unauthorized + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + + /api/billing/subscription/cancel: + post: + operationId: cancelSubscription + tags: [billing] + summary: Cancel the active subscription + description: "[cloud-only] Cancels the authenticated user's active subscription. The subscription remains active until the end of the current billing period." + x-runtime: [cloud] + responses: + "200": + description: Subscription cancelled + content: + application/json: + schema: + $ref: "#/components/schemas/BillingSubscription" + "401": + description: Unauthorized + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + + /api/billing/subscription/resubscribe: + post: + operationId: resubscribe + tags: [billing] + summary: Resubscribe after cancellation + description: "[cloud-only] Reactivates a subscription that was previously cancelled but has not yet expired." 
+ x-runtime: [cloud] + responses: + "200": + description: Subscription reactivated + content: + application/json: + schema: + $ref: "#/components/schemas/BillingSubscription" + "401": + description: Unauthorized + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + + /api/billing/topup: + post: + operationId: topUpCredits + tags: [billing] + summary: Purchase additional credits + description: "[cloud-only] Purchases a one-time credit top-up using the user's payment method on file." + x-runtime: [cloud] + requestBody: + required: true + content: + application/json: + schema: + type: object + required: + - amount + properties: + amount: + type: integer + description: Number of credits to purchase + responses: + "200": + description: Top-up successful + content: + application/json: + schema: + $ref: "#/components/schemas/BillingBalance" + "400": + description: Bad request + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + "401": + description: Unauthorized + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + + # --------------------------------------------------------------------------- + # Workspace (cloud) + # --------------------------------------------------------------------------- + /api/workspace/api-keys: + get: + operationId: listWorkspaceApiKeys + tags: [workspace] + summary: List workspace API keys + description: "[cloud-only] Returns the list of API keys for the current workspace." + x-runtime: [cloud] + responses: + "200": + description: API key list + content: + application/json: + schema: + type: array + items: + $ref: "#/components/schemas/WorkspaceApiKey" + "401": + description: Unauthorized + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + "403": + description: Forbidden + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + post: + operationId: createWorkspaceApiKey + tags: [workspace] + summary: Create a workspace API key + description: "[cloud-only] Creates a new API key for the current workspace." + x-runtime: [cloud] + requestBody: + required: true + content: + application/json: + schema: + type: object + required: + - name + properties: + name: + type: string + description: Display name for the API key + responses: + "201": + description: API key created + content: + application/json: + schema: + $ref: "#/components/schemas/WorkspaceApiKeyCreated" + "400": + description: Bad request + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + "401": + description: Unauthorized + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + "403": + description: Forbidden + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + + /api/workspace/api-keys/{id}: + delete: + operationId: deleteWorkspaceApiKey + tags: [workspace] + summary: Delete a workspace API key + description: "[cloud-only] Revokes and deletes a workspace API key." + x-runtime: [cloud] + parameters: + - name: id + in: path + required: true + schema: + type: string + description: The API key ID. 
+ responses: + "204": + description: API key deleted + "401": + description: Unauthorized + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + "403": + description: Forbidden + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + "404": + description: Not found + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + + /api/workspace/invites: + get: + operationId: listWorkspaceInvites + tags: [workspace] + summary: List pending workspace invites + description: "[cloud-only] Returns the list of pending invitations for the current workspace." + x-runtime: [cloud] + responses: + "200": + description: Invite list + content: + application/json: + schema: + type: array + items: + $ref: "#/components/schemas/WorkspaceInvite" + "401": + description: Unauthorized + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + "403": + description: Forbidden + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + post: + operationId: createWorkspaceInvite + tags: [workspace] + summary: Invite a user to the workspace + description: "[cloud-only] Creates an invitation for a user to join the current workspace." + x-runtime: [cloud] + requestBody: + required: true + content: + application/json: + schema: + type: object + required: + - email + properties: + email: + type: string + format: email + description: Email address to invite + role: + type: string + enum: [admin, member] + description: Role to assign + responses: + "201": + description: Invite created + content: + application/json: + schema: + $ref: "#/components/schemas/WorkspaceInvite" + "400": + description: Bad request + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + "401": + description: Unauthorized + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + "403": + description: Forbidden + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + "409": + description: Conflict + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + + /api/workspace/invites/{inviteId}: + delete: + operationId: deleteWorkspaceInvite + tags: [workspace] + summary: Cancel a workspace invite + description: "[cloud-only] Cancels a pending workspace invitation." + x-runtime: [cloud] + parameters: + - name: inviteId + in: path + required: true + schema: + type: string + description: The invite ID. + responses: + "204": + description: Invite cancelled + "401": + description: Unauthorized + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + "403": + description: Forbidden + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + "404": + description: Not found + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + + /api/workspace/leave: + post: + operationId: leaveWorkspace + tags: [workspace] + summary: Leave the current workspace + description: "[cloud-only] Removes the authenticated user from the current workspace." 
+ x-runtime: [cloud] + responses: + "204": + description: Left workspace + "401": + description: Unauthorized + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + "403": + description: Forbidden + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + + /api/workspace/members: + get: + operationId: listWorkspaceMembers + tags: [workspace] + summary: List workspace members + description: "[cloud-only] Returns the list of members in the current workspace." + x-runtime: [cloud] + responses: + "200": + description: Member list + content: + application/json: + schema: + type: array + items: + $ref: "#/components/schemas/WorkspaceMember" + "401": + description: Unauthorized + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + "403": + description: Forbidden + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + + /api/workspace/members/{user_id}/api-keys: + get: + operationId: listMemberApiKeys + tags: [workspace] + summary: List API keys for a workspace member + description: "[cloud-only] Returns the API keys belonging to a specific workspace member. Requires admin role." + x-runtime: [cloud] + parameters: + - name: user_id + in: path + required: true + schema: + type: string + description: The member's user ID. + responses: + "200": + description: API key list + content: + application/json: + schema: + type: array + items: + $ref: "#/components/schemas/WorkspaceApiKey" + "401": + description: Unauthorized + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + "403": + description: Forbidden + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + "404": + description: Not found + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + delete: + operationId: bulkRevokeMemberApiKeys + tags: [workspace] + summary: Bulk revoke a member's API keys + description: "[cloud-only] Revokes all active API keys for a specific workspace member. Only workspace owners can perform this action." + x-runtime: [cloud] + parameters: + - name: user_id + in: path + required: true + schema: + type: string + minLength: 1 + description: The member's user ID. + responses: + "200": + description: Keys revoked + content: + application/json: + schema: + $ref: "#/components/schemas/BulkRevokeAPIKeysResponse" + "401": + description: Unauthorized + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + "403": + description: Forbidden — must be workspace owner + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + + /api/workspace/members/{userId}: + patch: + operationId: updateWorkspaceMember + tags: [workspace] + summary: Update a workspace member's role + description: "[cloud-only] Updates the role of a workspace member. Requires admin role." + x-runtime: [cloud] + parameters: + - name: userId + in: path + required: true + schema: + type: string + description: The member's user ID. 
+ requestBody: + required: true + content: + application/json: + schema: + type: object + required: + - role + properties: + role: + type: string + enum: [admin, member] + description: New role to assign + responses: + "200": + description: Member updated + content: + application/json: + schema: + $ref: "#/components/schemas/WorkspaceMember" + "400": + description: Bad request + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + "401": + description: Unauthorized + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + "403": + description: Forbidden + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + "404": + description: Not found + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + delete: + operationId: removeWorkspaceMember + tags: [workspace] + summary: Remove a member from the workspace + description: "[cloud-only] Removes a member from the current workspace. Requires admin role." + x-runtime: [cloud] + parameters: + - name: userId + in: path + required: true + schema: + type: string + description: The member's user ID. + responses: + "204": + description: Member removed + "401": + description: Unauthorized + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + "403": + description: Forbidden + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + "404": + description: Not found + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + + /api/workspaces: + get: + operationId: listWorkspaces + tags: [workspace] + summary: List workspaces the user belongs to + description: "[cloud-only] Returns the list of workspaces the authenticated user is a member of." + x-runtime: [cloud] + responses: + "200": + description: Workspace list + content: + application/json: + schema: + type: array + items: + $ref: "#/components/schemas/Workspace" + "401": + description: Unauthorized + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + post: + operationId: createWorkspace + tags: [workspace] + summary: Create a new workspace + description: "[cloud-only] Creates a new workspace. The authenticated user becomes the owner." + x-runtime: [cloud] + requestBody: + required: true + content: + application/json: + schema: + type: object + required: + - name + properties: + name: + type: string + description: Workspace name + responses: + "201": + description: Workspace created + content: + application/json: + schema: + $ref: "#/components/schemas/Workspace" + "400": + description: Bad request + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + "401": + description: Unauthorized + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + + /api/workspaces/{id}: + get: + operationId: getWorkspace + tags: [workspace] + summary: Get a workspace by ID + description: "[cloud-only] Returns details of a workspace the user is a member of." + x-runtime: [cloud] + parameters: + - name: id + in: path + required: true + schema: + type: string + description: The workspace ID. 
+ responses: + "200": + description: Workspace detail + content: + application/json: + schema: + $ref: "#/components/schemas/Workspace" + "401": + description: Unauthorized + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + "403": + description: Forbidden + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + "404": + description: Not found + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + patch: + operationId: updateWorkspace + tags: [workspace] + summary: Update workspace settings + description: "[cloud-only] Updates the name or settings of a workspace. Requires admin role." + x-runtime: [cloud] + parameters: + - name: id + in: path + required: true + schema: + type: string + description: The workspace ID. + requestBody: + required: true + content: + application/json: + schema: + type: object + properties: + name: + type: string + description: New workspace name + responses: + "200": + description: Workspace updated + content: + application/json: + schema: + $ref: "#/components/schemas/Workspace" + "400": + description: Bad request + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + "401": + description: Unauthorized + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + "403": + description: Forbidden + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + "404": + description: Not found + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + delete: + operationId: deleteWorkspace + tags: [workspace] + summary: Delete a workspace + description: "[cloud-only] Soft-deletes a workspace. Requires owner role. Personal workspaces cannot be deleted." + x-runtime: [cloud] + parameters: + - name: id + in: path + required: true + schema: + type: string + description: The workspace ID. + responses: + "204": + description: Workspace deleted + "401": + description: Unauthorized + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + "403": + description: Forbidden — must be workspace owner + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + "404": + description: Not found + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + + # --------------------------------------------------------------------------- + # User / settings / misc (cloud) + # --------------------------------------------------------------------------- + /api/feedback: + post: + operationId: submitFeedback + tags: [user] + summary: Submit user feedback + description: "[cloud-only] Submits feedback from the user about their experience with the cloud runtime." 
+ x-runtime: [cloud] + requestBody: + required: true + content: + application/json: + schema: + type: object + required: + - message + properties: + message: + type: string + description: Feedback message + rating: + type: integer + minimum: 1 + maximum: 5 + description: Optional satisfaction rating + context: + type: object + additionalProperties: true + description: Additional context metadata + responses: + "200": + description: Feedback submitted + content: + application/json: + schema: + type: object + properties: + id: + type: string + status: + type: string + "400": + description: Bad request + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + "401": + description: Unauthorized + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + + /api/files/mask-layers: + get: + operationId: getMaskLayers + tags: [assets] + summary: Get related mask layer filenames + description: "[cloud-only] Given a mask file (any of the 4 layers), returns all related mask layer filenames. Used by the mask editor to load the paint, mask, and painted layers when reopening a previously edited mask." + x-runtime: [cloud] + parameters: + - name: filename + in: query + required: true + schema: + type: string + description: Hash filename of any mask layer file + responses: + "200": + description: Related mask layers + content: + application/json: + schema: + type: object + properties: + mask: + type: string + description: Filename of the mask layer + nullable: true + paint: + type: string + description: Filename of the paint strokes layer + nullable: true + painted: + type: string + description: Filename of the painted image layer + nullable: true + painted_masked: + type: string + description: Filename of the final composite layer + nullable: true + "401": + description: Unauthorized + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + "404": + description: File not found or not a mask file + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + + /api/internal/cloud_analytics: + post: + operationId: postCloudAnalytics + tags: [internal] + summary: Post client analytics events + description: "[cloud-only] Receives analytics events from the frontend for processing by the cloud analytics pipeline." + x-runtime: [cloud] + requestBody: + required: true + content: + application/json: + schema: + type: object + required: + - events + properties: + events: + type: array + items: + type: object + required: + - event_name + properties: + event_name: + type: string + timestamp: + type: string + format: date-time + properties: + type: object + additionalProperties: true + responses: + "200": + description: Events accepted + "400": + description: Bad request + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + "401": + description: Unauthorized + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + + /api/invites/{token}/accept: + post: + operationId: acceptInvite + tags: [workspace] + summary: Accept a workspace invitation + description: "[cloud-only] Accepts a workspace invitation using the invite token. The authenticated user is added to the workspace." + x-runtime: [cloud] + parameters: + - name: token + in: path + required: true + schema: + type: string + description: The invitation token. 
+ responses: + "200": + description: Invite accepted + content: + application/json: + schema: + $ref: "#/components/schemas/Workspace" + "400": + description: Bad request + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + "401": + description: Unauthorized + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + "404": + description: Not found + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + + /api/secrets: + get: + operationId: listSecrets + tags: [settings] + summary: List user secrets + description: "[cloud-only] Returns the list of secrets (API keys for third-party services) stored for the authenticated user. Secret values are redacted." + x-runtime: [cloud] + responses: + "200": + description: Secret list + content: + application/json: + schema: + type: array + items: + $ref: "#/components/schemas/SecretMeta" + "401": + description: Unauthorized + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + post: + operationId: createSecret + tags: [settings] + summary: Create or update a secret + description: "[cloud-only] Stores a new secret or updates an existing one. Secrets are encrypted at rest." + x-runtime: [cloud] + requestBody: + required: true + content: + application/json: + schema: + type: object + required: + - name + - value + properties: + name: + type: string + description: Secret name (unique per user) + value: + type: string + description: Secret value + responses: + "201": + description: Secret created + content: + application/json: + schema: + $ref: "#/components/schemas/SecretMeta" + "400": + description: Bad request + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + "401": + description: Unauthorized + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + + /api/secrets/{id}: + get: + operationId: getSecret + tags: [settings] + summary: Get secret metadata + description: "[cloud-only] Returns metadata for a specific secret. Does not return the plaintext secret value." + x-runtime: [cloud] + parameters: + - name: id + in: path + required: true + schema: + type: string + format: uuid + description: The secret ID. + responses: + "200": + description: Secret metadata + content: + application/json: + schema: + $ref: "#/components/schemas/SecretMeta" + "401": + description: Unauthorized + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + "404": + description: Not found + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + patch: + operationId: updateSecret + tags: [settings] + summary: Update a secret + description: "[cloud-only] Updates an existing secret's name and/or value. Both fields are optional; only provided fields are updated." + x-runtime: [cloud] + parameters: + - name: id + in: path + required: true + schema: + type: string + format: uuid + description: The secret ID. 
+ requestBody: + required: true + content: + application/json: + schema: + $ref: "#/components/schemas/UpdateSecretRequest" + responses: + "200": + description: Secret updated + content: + application/json: + schema: + $ref: "#/components/schemas/SecretMeta" + "400": + description: Bad request + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + "401": + description: Unauthorized + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + "404": + description: Not found + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + "409": + description: Conflict — a secret with this name already exists + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + delete: + operationId: deleteSecret + tags: [settings] + summary: Delete a secret + description: "[cloud-only] Permanently deletes a stored secret." + x-runtime: [cloud] + parameters: + - name: id + in: path + required: true + schema: + type: string + description: The secret ID. + responses: + "204": + description: Secret deleted + "401": + description: Unauthorized + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + "404": + description: Not found + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + + /api/user: + get: + operationId: getCloudUser + tags: [user] + summary: Get the authenticated cloud user + description: "[cloud-only] Returns the profile and account information for the currently authenticated user." + x-runtime: [cloud] + responses: + "200": + description: User profile + content: + application/json: + schema: + $ref: "#/components/schemas/CloudUser" + "401": + description: Unauthorized + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + put: + operationId: updateCloudUser + tags: [user] + summary: Update the authenticated cloud user profile + description: "[cloud-only] Updates the profile information for the currently authenticated user." + x-runtime: [cloud] + requestBody: + required: true + content: + application/json: + schema: + type: object + properties: + display_name: + type: string + avatar_url: + type: string + format: uri + responses: + "200": + description: Updated profile + content: + application/json: + schema: + $ref: "#/components/schemas/CloudUser" + "400": + description: Bad request + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + "401": + description: Unauthorized + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + + /api/userdata/{file}/publish: + get: + operationId: getUserdataFilePublish + tags: [userdata] + summary: Get publish info for a userdata file + description: "[cloud-only] Returns the publish status and share info for a userdata workflow file." 
+ x-runtime: [cloud] + parameters: + - name: file + in: path + required: true + schema: + type: string + description: File path relative to user data directory + responses: + "200": + description: Publish info (publish_time is null if never published) + content: + application/json: + schema: + $ref: "#/components/schemas/WorkflowPublishInfo" + "401": + description: Unauthorized + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + "404": + description: Workflow not found + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + post: + operationId: publishUserdataFile + tags: [userdata] + summary: Publish a userdata file to the cloud + description: "[cloud-only] Makes a userdata file available via a public URL for sharing or embedding." + x-runtime: [cloud] + parameters: + - name: file + in: path + required: true + schema: + type: string + description: File path relative to user data directory + responses: + "200": + description: Published file URL + content: + application/json: + schema: + type: object + properties: + url: + type: string + format: uri + description: Public URL of the published file + "401": + description: Unauthorized + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + "404": + description: Not found + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + + /api/vhs/queryvideo: + get: + operationId: queryVhsVideo + tags: [view] + summary: Query VHS video metadata + description: "[cloud-only] Returns metadata about a video file processed by the VHS (Video Helper Suite) integration." + x-runtime: [cloud] + parameters: + - name: filename + in: query + required: true + schema: + type: string + description: Video filename + - name: type + in: query + schema: + type: string + enum: [input, output, temp] + description: Directory type + - name: subfolder + in: query + schema: + type: string + description: Subfolder within the directory + responses: + "200": + description: Video metadata + content: + application/json: + schema: + type: object + additionalProperties: true + "401": + description: Unauthorized + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + "404": + description: Not found + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + + /api/vhs/viewaudio: + get: + operationId: viewVhsAudio + tags: [view] + summary: View or download VHS audio + description: "[cloud-only] Returns audio content from a VHS-processed file." + x-runtime: [cloud] + parameters: + - name: filename + in: query + required: true + schema: + type: string + description: Audio filename + - name: type + in: query + schema: + type: string + enum: [input, output, temp] + description: Directory type + - name: subfolder + in: query + schema: + type: string + description: Subfolder within the directory + responses: + "200": + description: Audio content + content: + audio/*: + schema: + type: string + format: binary + "401": + description: Unauthorized + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + "404": + description: Not found + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + + /api/vhs/viewvideo: + get: + operationId: viewVhsVideo + tags: [view] + summary: View or download VHS video + description: "[cloud-only] Returns video content from a VHS-processed file." 
+ x-runtime: [cloud] + parameters: + - name: filename + in: query + required: true + schema: + type: string + description: Video filename + - name: type + in: query + schema: + type: string + enum: [input, output, temp] + description: Directory type + - name: subfolder + in: query + schema: + type: string + description: Subfolder within the directory + responses: + "200": + description: Video content + content: + video/*: + schema: + type: string + format: binary + "401": + description: Unauthorized + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + "404": + description: Not found + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + + /api/viewvideo: + get: + operationId: viewVideo + tags: [view] + summary: View or download a video file + description: "[cloud-only] Serves a video file from the output directory. Used by the frontend video player." + x-runtime: [cloud] + parameters: + - name: filename + in: query + required: true + schema: + type: string + description: Video filename + - name: type + in: query + schema: + type: string + enum: [input, output, temp] + description: Directory type + - name: subfolder + in: query + schema: + type: string + description: Subfolder within the directory + responses: + "200": + description: Video content + content: + video/*: + schema: + type: string + format: binary + "401": + description: Unauthorized + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + "404": + description: Not found + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + + /api/tasks: + get: + operationId: listTasks + tags: [task] + summary: List background tasks + description: "[cloud-only] Retrieve a paginated list of background tasks for the authenticated user. Supports filtering by task type, status, and creation time." + x-runtime: [cloud] + parameters: + - name: task_name + in: query + schema: + type: string + description: Filter by task type name (exact match). + - name: idempotency_key + in: query + schema: + type: string + description: Filter by idempotency key (exact match). + - name: status + in: query + schema: + type: string + description: Filter by one or more statuses (comma-separated). + - name: created_after + in: query + schema: + type: string + format: date-time + description: Filter tasks created after this timestamp. + - name: created_before + in: query + schema: + type: string + format: date-time + description: Filter tasks created before this timestamp. + - name: sort_order + in: query + schema: + type: string + enum: [asc, desc] + default: desc + description: Sort direction by create_time. + - name: offset + in: query + schema: + type: integer + minimum: 0 + default: 0 + description: Pagination offset (0-based). + - name: limit + in: query + schema: + type: integer + minimum: 1 + maximum: 100 + default: 20 + description: Maximum items per page (1-100). + responses: + "200": + description: Tasks retrieved + content: + application/json: + schema: + $ref: "#/components/schemas/TasksListResponse" + "401": + description: Unauthorized + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + "422": + description: Validation error + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + + /api/tasks/{task_id}: + get: + operationId: getTask + tags: [task] + summary: Get task details + description: "[cloud-only] Retrieve full details for a specific background task." 
+ x-runtime: [cloud] + parameters: + - name: task_id + in: path + required: true + schema: + type: string + format: uuid + description: Task identifier (UUID). + responses: + "200": + description: Task details + content: + application/json: + schema: + $ref: "#/components/schemas/TaskResponse" + "401": + description: Unauthorized + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + "404": + description: Task not found + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + + components: parameters: ComfyUserHeader: @@ -2823,14 +6279,29 @@ components: name: type: string description: Name of the asset file + hash: + type: string + nullable: true + description: Blake3 content hash of the asset (preferred over asset_hash) + pattern: "^blake3:[a-f0-9]{64}$" asset_hash: type: string - description: Blake3 hash of the asset content + nullable: true + deprecated: true + description: "Deprecated: use `hash` instead. Blake3 hash of the asset content." pattern: "^blake3:[a-f0-9]{64}$" size: type: integer format: int64 description: Size of the asset in bytes + width: + type: integer + nullable: true + description: "Original image width in pixels. Null for non-image assets or assets ingested before dimension extraction." + height: + type: integer + nullable: true + description: "Original image height in pixels. Null for non-image assets or assets ingested before dimension extraction." mime_type: type: string description: MIME type of the asset @@ -2859,7 +6330,14 @@ components: prompt_id: type: string format: uuid - description: ID of the prompt that created this asset + nullable: true + deprecated: true + description: "Deprecated: use job_id instead. ID of the prompt that created this asset." + job_id: + type: string + format: uuid + nullable: true + description: ID of the job that created this asset created_at: type: string format: date-time @@ -2897,8 +6375,16 @@ components: format: uuid name: type: string + hash: + type: string + nullable: true + description: Blake3 content hash of the asset (preferred over asset_hash) + pattern: "^blake3:[a-f0-9]{64}$" asset_hash: type: string + nullable: true + deprecated: true + description: "Deprecated: use `hash` instead. Blake3 hash of the asset content." pattern: "^blake3:[a-f0-9]{64}$" tags: type: array @@ -2909,6 +6395,17 @@ components: user_metadata: type: object additionalProperties: true + prompt_id: + type: string + format: uuid + nullable: true + deprecated: true + description: "Deprecated: use job_id instead. ID of the prompt that created this asset." + job_id: + type: string + format: uuid + nullable: true + description: ID of the job that created this asset updated_at: type: string format: date-time @@ -3365,3 +6862,1218 @@ components: enum: [created, running, completed, failed] error: type: string + + + # ------------------------------------------------------------------- + # Cloud-runtime schemas + # + # These schemas are exclusively referenced by cloud-runtime operations. + # Tagged x-runtime: [cloud]. + # ------------------------------------------------------------------- + CloudError: + type: object + x-runtime: [cloud] + description: "[cloud-only] Standard error response from cloud endpoints." 
+ required: + - error + properties: + error: + type: string + description: Error message + code: + type: string + description: Machine-readable error code + details: + type: object + additionalProperties: true + description: Additional error context + + CloudJobStatus: + type: object + x-runtime: [cloud] + description: "[cloud-only] Status of a cloud job." + required: + - id + - status + properties: + id: + type: string + format: uuid + status: + type: string + enum: [pending, running, completed, failed, cancelled] + progress: + type: number + minimum: 0 + maximum: 1 + description: "Execution progress (0.0 to 1.0)" + started_at: + type: string + format: date-time + nullable: true + completed_at: + type: string + format: date-time + nullable: true + + CloudPrompt: + type: object + x-runtime: [cloud] + description: "[cloud-only] A cloud-executed prompt record." + required: + - id + - status + properties: + id: + type: string + format: uuid + status: + type: string + workflow: + type: object + additionalProperties: true + outputs: + type: object + additionalProperties: true + created_at: + type: string + format: date-time + completed_at: + type: string + format: date-time + nullable: true + + HistoryV2Response: + type: object + x-runtime: [cloud] + description: "[cloud-only] Paginated execution history in v2 format." + required: + - items + - total + - has_more + properties: + items: + type: array + items: + $ref: "#/components/schemas/HistoryV2Entry" + total: + type: integer + has_more: + type: boolean + + HistoryV2Entry: + type: object + x-runtime: [cloud] + description: "[cloud-only] A single execution history entry in v2 format." + required: + - id + - status + properties: + id: + type: string + format: uuid + status: + type: string + workflow: + type: object + additionalProperties: true + outputs: + type: object + additionalProperties: true + created_at: + type: string + format: date-time + started_at: + type: string + format: date-time + nullable: true + completed_at: + type: string + format: date-time + nullable: true + preview_output: + type: object + additionalProperties: true + + CloudLogsResponse: + type: object + x-runtime: [cloud] + description: "[cloud-only] Paginated cloud execution logs." + required: + - entries + properties: + entries: + type: array + items: + type: object + properties: + timestamp: + type: string + format: date-time + level: + type: string + enum: [debug, info, warn, error] + message: + type: string + job_id: + type: string + format: uuid + total: + type: integer + has_more: + type: boolean + + AssetDownloadRequest: + type: object + x-runtime: [cloud] + description: "[cloud-only] A single asset to download to the cloud runtime." + required: + - asset_id + properties: + asset_id: + type: string + format: uuid + description: ID of the asset to download + target_path: + type: string + description: Target path on the runtime filesystem + + AssetImportRequest: + type: object + x-runtime: [cloud] + description: "[cloud-only] A single asset to import from an external URL." + required: + - url + properties: + url: + type: string + format: uri + description: URL of the asset to import + name: + type: string + description: Display name for the imported asset + tags: + type: array + items: + type: string + + RemoteAssetMetadata: + type: object + x-runtime: [cloud] + description: "[cloud-only] Metadata fetched from a remote asset URL." 
+ properties: + content_type: + type: string + description: MIME type of the remote file + content_length: + type: integer + format: int64 + description: Size in bytes + filename: + type: string + description: Suggested filename from Content-Disposition or URL + + CloudNode: + type: object + x-runtime: [cloud] + description: "[cloud-only] An installed custom node package in the cloud runtime." + required: + - id + - name + properties: + id: + type: string + name: + type: string + version: + type: string + description: + type: string + author: + type: string + repository: + type: string + format: uri + installed_at: + type: string + format: date-time + enabled: + type: boolean + + CloudNodeList: + type: object + x-runtime: [cloud] + description: "[cloud-only] Paginated list of installed custom node packages." + required: + - nodes + properties: + nodes: + type: array + items: + $ref: "#/components/schemas/CloudNode" + total: + type: integer + has_more: + type: boolean + + HubLabel: + type: object + x-runtime: [cloud] + description: "[cloud-only] A label/category used for tagging hub content." + required: + - id + - name + properties: + id: + type: string + name: + type: string + description: + type: string + color: + type: string + description: Hex color code for the label + + HubProfile: + type: object + x-runtime: [cloud] + description: "[cloud-only] A public user profile on the ComfyUI Hub." + required: + - username + properties: + username: + type: string + display_name: + type: string + bio: + type: string + avatar_url: + type: string + format: uri + links: + type: array + items: + type: string + format: uri + workflow_count: + type: integer + created_at: + type: string + format: date-time + + HubWorkflow: + type: object + x-runtime: [cloud] + description: "[cloud-only] A published workflow on the ComfyUI Hub." + required: + - share_id + - name + properties: + share_id: + type: string + name: + type: string + description: + type: string + author: + $ref: "#/components/schemas/HubProfile" + labels: + type: array + items: + $ref: "#/components/schemas/HubLabel" + thumbnail_url: + type: string + format: uri + content: + type: object + additionalProperties: true + description: Workflow graph JSON + likes: + type: integer + views: + type: integer + forks: + type: integer + created_at: + type: string + format: date-time + updated_at: + type: string + format: date-time + + HubWorkflowList: + type: object + x-runtime: [cloud] + description: "[cloud-only] Paginated list of hub workflows." + required: + - workflows + - total + - has_more + properties: + workflows: + type: array + items: + $ref: "#/components/schemas/HubWorkflow" + total: + type: integer + has_more: + type: boolean + + HubWorkflowIndexEntry: + type: object + x-runtime: [cloud] + description: "[cloud-only] Lightweight entry in the hub workflow index for client-side search." + required: + - share_id + - name + properties: + share_id: + type: string + name: + type: string + author_username: + type: string + labels: + type: array + items: + type: string + likes: + type: integer + updated_at: + type: string + format: date-time + + CloudWorkflow: + type: object + x-runtime: [cloud] + description: "[cloud-only] A cloud-managed workflow with version history." 
+ required: + - id + - name + properties: + id: + type: string + format: uuid + name: + type: string + description: + type: string + share_id: + type: string + nullable: true + description: Public share identifier if published + latest_version_id: + type: string + format: uuid + nullable: true + thumbnail_url: + type: string + format: uri + nullable: true + created_at: + type: string + format: date-time + updated_at: + type: string + format: date-time + + CloudWorkflowList: + type: object + x-runtime: [cloud] + description: "[cloud-only] Paginated list of cloud workflows." + required: + - workflows + - total + - has_more + properties: + workflows: + type: array + items: + $ref: "#/components/schemas/CloudWorkflow" + total: + type: integer + has_more: + type: boolean + + CloudWorkflowVersion: + type: object + x-runtime: [cloud] + description: "[cloud-only] A version of a cloud workflow." + required: + - id + - workflow_id + properties: + id: + type: string + format: uuid + workflow_id: + type: string + format: uuid + version_number: + type: integer + created_at: + type: string + format: date-time + + AuthSession: + type: object + x-runtime: [cloud] + description: "[cloud-only] Current authentication session state." + required: + - user + properties: + user: + $ref: "#/components/schemas/CloudUser" + workspace: + $ref: "#/components/schemas/Workspace" + expires_at: + type: string + format: date-time + + AuthTokenResponse: + type: object + x-runtime: [cloud] + description: "[cloud-only] OAuth2 token response." + required: + - access_token + - token_type + properties: + access_token: + type: string + token_type: + type: string + description: Always "Bearer" + expires_in: + type: integer + description: Token lifetime in seconds + refresh_token: + type: string + nullable: true + scope: + type: string + + JwksResponse: + type: object + x-runtime: [cloud] + description: "[cloud-only] JSON Web Key Set for JWT verification." + required: + - keys + properties: + keys: + type: array + items: + type: object + required: + - kty + - kid + - use + properties: + kty: + type: string + description: Key type (e.g. RSA) + kid: + type: string + description: Key ID + use: + type: string + description: Key use (e.g. sig) + alg: + type: string + description: Algorithm (e.g. RS256) + n: + type: string + description: RSA modulus (base64url) + e: + type: string + description: RSA exponent (base64url) + additionalProperties: true + + BillingBalance: + type: object + x-runtime: [cloud] + description: "[cloud-only] Current credit balance and usage summary." + required: + - credits_remaining + properties: + credits_remaining: + type: integer + description: Available credits + credits_used: + type: integer + description: Credits used in current billing period + credits_total: + type: integer + description: Total credits allocated in current period + + BillingEvent: + type: object + x-runtime: [cloud] + description: "[cloud-only] A billing event (charge, credit, refund)." + required: + - id + - type + - amount + - created_at + properties: + id: + type: string + type: + type: string + enum: [charge, credit, refund, topup, subscription] + amount: + type: integer + description: Amount in credits + description: + type: string + job_id: + type: string + format: uuid + nullable: true + created_at: + type: string + format: date-time + + BillingEventList: + type: object + x-runtime: [cloud] + description: "[cloud-only] Paginated list of billing events." 
+ required: + - events + - total + - has_more + properties: + events: + type: array + items: + $ref: "#/components/schemas/BillingEvent" + total: + type: integer + has_more: + type: boolean + + BillingOp: + type: object + x-runtime: [cloud] + description: "[cloud-only] A billing operation record." + required: + - id + - status + properties: + id: + type: string + status: + type: string + enum: [pending, completed, failed] + type: + type: string + amount: + type: integer + created_at: + type: string + format: date-time + completed_at: + type: string + format: date-time + nullable: true + + BillingPlan: + type: object + x-runtime: [cloud] + description: "[cloud-only] A subscription plan with pricing details." + required: + - id + - name + properties: + id: + type: string + name: + type: string + description: + type: string + credits_per_month: + type: integer + price_cents: + type: integer + description: Monthly price in cents (USD) + currency: + type: string + default: usd + features: + type: array + items: + type: string + description: List of plan features + + BillingStatus: + type: object + x-runtime: [cloud] + description: "[cloud-only] Overall billing and subscription status." + properties: + subscription: + $ref: "#/components/schemas/BillingSubscription" + balance: + $ref: "#/components/schemas/BillingBalance" + has_payment_method: + type: boolean + + BillingSubscription: + type: object + x-runtime: [cloud] + description: "[cloud-only] Active subscription details." + required: + - id + - status + - plan_id + properties: + id: + type: string + status: + type: string + enum: [active, cancelled, past_due, trialing] + plan_id: + type: string + plan_name: + type: string + current_period_start: + type: string + format: date-time + current_period_end: + type: string + format: date-time + cancel_at_period_end: + type: boolean + + SubscriptionPreview: + type: object + x-runtime: [cloud] + description: "[cloud-only] Preview of a subscription change including prorations." + properties: + plan_id: + type: string + plan_name: + type: string + amount_due: + type: integer + description: Amount due in cents + proration_amount: + type: integer + description: Proration adjustment in cents + currency: + type: string + next_billing_date: + type: string + format: date-time + + Workspace: + type: object + x-runtime: [cloud] + description: "[cloud-only] A cloud workspace for team collaboration." + required: + - id + - name + properties: + id: + type: string + name: + type: string + owner_id: + type: string + member_count: + type: integer + created_at: + type: string + format: date-time + updated_at: + type: string + format: date-time + + WorkspaceMember: + type: object + x-runtime: [cloud] + description: "[cloud-only] A member of a cloud workspace." + required: + - user_id + - role + properties: + user_id: + type: string + email: + type: string + format: email + display_name: + type: string + avatar_url: + type: string + format: uri + role: + type: string + enum: [owner, admin, member] + joined_at: + type: string + format: date-time + + WorkspaceInvite: + type: object + x-runtime: [cloud] + description: "[cloud-only] A pending workspace invitation." 
+ required: + - id + - email + - role + properties: + id: + type: string + email: + type: string + format: email + role: + type: string + enum: [admin, member] + invited_by: + type: string + created_at: + type: string + format: date-time + expires_at: + type: string + format: date-time + + WorkspaceApiKey: + type: object + x-runtime: [cloud] + description: "[cloud-only] A workspace API key (secret value redacted)." + required: + - id + - name + properties: + id: + type: string + name: + type: string + prefix: + type: string + description: First few characters of the key for identification + created_at: + type: string + format: date-time + last_used_at: + type: string + format: date-time + nullable: true + created_by: + type: string + + WorkspaceApiKeyCreated: + type: object + x-runtime: [cloud] + description: "[cloud-only] A newly created workspace API key, including the full secret value (shown only once)." + required: + - id + - name + - key + properties: + id: + type: string + name: + type: string + key: + type: string + description: Full API key value (only returned on creation) + prefix: + type: string + created_at: + type: string + format: date-time + + CloudUser: + type: object + x-runtime: [cloud] + description: "[cloud-only] A cloud-authenticated user profile." + required: + - id + - email + properties: + id: + type: string + email: + type: string + format: email + display_name: + type: string + avatar_url: + type: string + format: uri + created_at: + type: string + format: date-time + + SecretMeta: + type: object + x-runtime: [cloud] + description: "[cloud-only] Metadata for a stored secret (value is never returned)." + required: + - id + - name + properties: + id: + type: string + name: + type: string + provider: + type: string + description: "[cloud-only] Provider identifier (e.g., huggingface, civitai)." + x-runtime: [cloud] + last_used_at: + type: string + format: date-time + description: "[cloud-only] When the secret was last used for decryption." + x-runtime: [cloud] + created_at: + type: string + format: date-time + updated_at: + type: string + format: date-time + + UpdateSecretRequest: + type: object + x-runtime: [cloud] + description: "[cloud-only] Request body for updating an existing user secret." + properties: + name: + type: string + description: New name for the secret + secret_value: + type: string + description: New secret value (API key, token, etc.) + + CreateSessionResponse: + type: object + x-runtime: [cloud] + description: "[cloud-only] Response after creating a session cookie." + required: + - success + properties: + success: + type: boolean + expiresIn: + type: integer + description: Session expiration time in seconds. + + DeleteSessionResponse: + type: object + x-runtime: [cloud] + description: "[cloud-only] Response after deleting a session cookie." + required: + - success + properties: + success: + type: boolean + + CreateHubProfileRequest: + type: object + x-runtime: [cloud] + description: "[cloud-only] Request body for creating a new Hub profile." + required: + - workspace_id + - username + properties: + workspace_id: + type: string + username: + type: string + description: Unique URL-safe slug. Immutable after creation. + display_name: + type: string + description: + type: string + avatar_token: + type: string + website_urls: + type: array + items: + type: string + + PublishHubWorkflowRequest: + type: object + x-runtime: [cloud] + description: "[cloud-only] Request body for publishing or updating a workflow on the Hub." 
+ required: + - username + - name + - workflow_filename + - asset_ids + properties: + username: + type: string + name: + type: string + workflow_filename: + type: string + asset_ids: + type: array + items: + type: string + description: + type: string + tags: + type: array + items: + type: string + models: + type: array + items: + type: string + custom_nodes: + type: array + items: + type: string + tutorial_url: + type: string + metadata: + type: object + additionalProperties: true + thumbnail_type: + type: string + enum: [image, video, image_comparison] + thumbnail_token_or_url: + type: string + thumbnail_comparison_token_or_url: + type: string + sample_image_tokens_or_urls: + type: array + items: + type: string + + HubWorkflowDetail: + type: object + x-runtime: [cloud] + description: "[cloud-only] Full Hub workflow detail including versions, assets, and statistics." + required: + - share_id + - workflow_id + - name + - workflow_json + - assets + - profile + - status + properties: + share_id: + type: string + workflow_id: + type: string + name: + type: string + status: + type: string + enum: [pending, approved, rejected, deprecated] + description: + type: string + thumbnail_type: + type: string + enum: [image, video, image_comparison] + thumbnail_url: + type: string + thumbnail_comparison_url: + type: string + tutorial_url: + type: string + metadata: + type: object + additionalProperties: true + sample_image_urls: + type: array + items: + type: string + publish_time: + type: string + format: date-time + nullable: true + workflow_json: + type: object + additionalProperties: true + assets: + type: array + items: + $ref: "#/components/schemas/AssetInfo" + profile: + $ref: "#/components/schemas/HubProfile" + + AssetInfo: + type: object + x-runtime: [cloud] + description: "[cloud-only] Lightweight asset reference used in workflow publishing payloads." + required: + - id + - filename + properties: + id: + type: string + filename: + type: string + mime_type: + type: string + size_bytes: + type: integer + format: int64 + + BulkRevokeAPIKeysResponse: + type: object + x-runtime: [cloud] + description: "[cloud-only] Response after bulk-revoking API keys for a workspace member." + required: + - revoked_count + properties: + revoked_count: + type: integer + minimum: 0 + + CreateWorkflowVersionRequest: + type: object + x-runtime: [cloud] + description: "[cloud-only] Request body for creating a new version of a saved workflow." + required: + - base_version + - workflow_json + properties: + base_version: + type: integer + description: Version number this change is based on (for optimistic concurrency). + workflow_json: + type: object + additionalProperties: true + + WorkflowVersionResponse: + type: object + x-runtime: [cloud] + description: "[cloud-only] Metadata for a single workflow version." + required: + - id + - version + - latest_version + - created_by + - created_at + properties: + id: + type: string + version: + type: integer + latest_version: + type: integer + created_by: + type: string + created_at: + type: string + format: date-time + + WorkflowPublishInfo: + type: object + x-runtime: [cloud] + description: "[cloud-only] Publishing metadata for a workflow shared to the Hub." 
+ required: + - workflow_id + - share_id + - listed + - assets + properties: + workflow_id: + type: string + share_id: + type: string + publish_time: + type: string + format: date-time + nullable: true + listed: + type: boolean + assets: + type: array + items: + $ref: "#/components/schemas/AssetInfo" + + TaskEntry: + type: object + x-runtime: [cloud] + description: "[cloud-only] Task data for list views." + required: + - id + - task_name + - status + - create_time + properties: + id: + type: string + format: uuid + task_name: + type: string + status: + type: string + enum: [created, running, completed, failed] + create_time: + type: string + format: date-time + started_at: + type: string + format: date-time + completed_at: + type: string + format: date-time + + TaskResponse: + type: object + x-runtime: [cloud] + description: "[cloud-only] Full task details including payload and result." + required: + - id + - idempotency_key + - task_name + - payload + - status + - create_time + - update_time + properties: + id: + type: string + format: uuid + idempotency_key: + type: string + task_name: + type: string + payload: + type: object + additionalProperties: true + status: + type: string + enum: [created, running, completed, failed] + result: + type: object + additionalProperties: true + create_time: + type: string + format: date-time + update_time: + type: string + format: date-time + started_at: + type: string + format: date-time + completed_at: + type: string + format: date-time + error: + type: string + + TasksListResponse: + type: object + x-runtime: [cloud] + description: "[cloud-only] Paginated list of background tasks for the authenticated user." + required: + - tasks + - pagination + properties: + tasks: + type: array + items: + $ref: "#/components/schemas/TaskEntry" + pagination: + $ref: "#/components/schemas/PaginationInfo" \ No newline at end of file From 65045730a60af0bf75cec2a738555a952da2ea4e Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Fri, 8 May 2026 23:11:52 +0300 Subject: [PATCH 098/102] [Partner Nodes] additionally use Baidu server to detect the accessibility of internet (#13803) Signed-off-by: bigcat88 --- comfy_api_nodes/util/client.py | 26 +++++++++++++++++++++++--- 1 file changed, 23 insertions(+), 3 deletions(-) diff --git a/comfy_api_nodes/util/client.py b/comfy_api_nodes/util/client.py index 8e1ba91ba..052301c33 100644 --- a/comfy_api_nodes/util/client.py +++ b/comfy_api_nodes/util/client.py @@ -488,10 +488,30 @@ async def _diagnose_connectivity() -> dict[str, bool]: "api_accessible": False, } timeout = aiohttp.ClientTimeout(total=5.0) + + # Probe Google and Baidu in parallel: Google is blocked by the GFW in mainland China, so a Baidu probe is required + # to correctly detect that Chinese users with working internet do have working internet. 
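+ # The probes race: the first URL to respond with a status below 500 marks the
+ # internet as reachable, and the finally block cancels any probe still pending.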
+ internet_probe_urls = ("https://www.google.com", "https://www.baidu.com") + async with aiohttp.ClientSession(timeout=timeout) as session: - with contextlib.suppress(ClientError, OSError): - async with session.get("https://www.google.com") as resp: - results["internet_accessible"] = resp.status < 500 + async def _probe(url: str) -> bool: + try: + async with session.get(url) as resp: + return resp.status < 500 + except (ClientError, OSError, asyncio.TimeoutError): + return False + + probe_tasks = [asyncio.create_task(_probe(u)) for u in internet_probe_urls] + try: + for fut in asyncio.as_completed(probe_tasks): + if await fut: + results["internet_accessible"] = True + break + finally: + for t in probe_tasks: + if not t.done(): + t.cancel() + await asyncio.gather(*probe_tasks, return_exceptions=True) if not results["internet_accessible"]: return results From 66669b2ded7d8f362fdf64bb1c77a8df0f684e2f Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Fri, 8 May 2026 17:32:14 -0700 Subject: [PATCH 099/102] I don't think there was any because nobody complained. (#13807) --- comfy/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/utils.py b/comfy/utils.py index 7b7faad3a..91e1ba3d3 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -1390,7 +1390,7 @@ def convert_old_quants(state_dict, model_prefix="", metadata={}): k_out = "{}.weight_scale".format(layer) if layer is not None: - layer_conf = {"format": "float8_e4m3fn"} # TODO: check if anyone did some non e4m3fn scaled checkpoints + layer_conf = {"format": "float8_e4m3fn"} if full_precision_matrix_mult: layer_conf["full_precision_matrix_mult"] = full_precision_matrix_mult layers[layer] = layer_conf From 4e823431cc8291deced4fc2dcf3967be2549e4c0 Mon Sep 17 00:00:00 2001 From: Matt Miller Date: Fri, 8 May 2026 19:14:23 -0700 Subject: [PATCH 100/102] Add cloud-runtime experiment node-schema endpoints to spec (#13806) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Add cloud-runtime experiment node-schema endpoints to spec Replace the GET operations at /api/experiment/nodes and /api/experiment/nodes/{id} with getNodeInfoSchema and getNodeByID — the optimized, ETag-tagged object_info schema endpoints the cloud frontend depends on for the workflow editor. Each operation is tagged x-runtime: [cloud] and uses the runtime-only tag for cloud-side codegen exclusion. Response headers document the ETag and Cache-Control validators; 304 Not Modified is declared for RFC 7232 conditional GETs. Remove the now-unused CloudNodeList schema to keep Spectral clean. Co-authored-by: Matt Miller * spec: document If-None-Match header on conditional GET endpoints Both `getNodeInfoSchema` and `getNodeByID` advertise `ETag` response headers and a `304 Not Modified` response, but the spec didn't declare the `If-None-Match` request header that triggers conditional validation. Adding it as an optional header parameter on both ops so client codegen exposes the conditional-GET pattern. 
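For reference, the conditional-GET round trip now documented by the spec looks
roughly like this from a client (illustrative aiohttp sketch, not code from this
repo; the session wiring and cache handling are assumptions — only the
If-None-Match / ETag / 304 contract comes from the spec):

    import aiohttp

    async def fetch_node_schema(session: aiohttp.ClientSession, cached_etag: str | None):
        # Replay the validator from the previous response, if we have one.
        headers = {"If-None-Match": cached_etag} if cached_etag else {}
        async with session.get("/api/experiment/nodes", headers=headers) as resp:
            if resp.status == 304:
                return None, cached_etag  # cached schema is still fresh
            resp.raise_for_status()
            return await resp.json(), resp.headers.get("ETag")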
--- openapi.yaml | 106 +++++++++++++++++++++++++-------------------------- 1 file changed, 51 insertions(+), 55 deletions(-) diff --git a/openapi.yaml b/openapi.yaml index 4216c1a6c..d4c9e67ca 100644 --- a/openapi.yaml +++ b/openapi.yaml @@ -74,6 +74,8 @@ tags: description: Cloud workflow management and versioning (cloud-only) - name: task description: Background task management (cloud-only) + - name: runtime-only + description: Operations served exclusively by the cloud runtime with no local equivalent paths: # --------------------------------------------------------------------------- @@ -2573,35 +2575,38 @@ paths: # --------------------------------------------------------------------------- /api/experiment/nodes: get: - operationId: listCloudNodes - tags: [node] - summary: List installed custom nodes - description: "[cloud-only] Returns the list of custom node packages installed in the cloud runtime." + operationId: getNodeInfoSchema + tags: [runtime-only] + summary: Get pre-rendered node info schema + description: "[cloud-only] Returns the static ComfyUI object_info schema, identical for every caller, rendered once at startup with empty model/user-file context. Served by a raw HTTP handler that writes pre-rendered bytes with ETag + Cache-Control validators for RFC 7232 conditional GETs." x-runtime: [cloud] parameters: - - name: limit - in: query + - name: If-None-Match + in: header + required: false schema: - type: integer - description: Maximum number of results - - name: offset - in: query - schema: - type: integer - description: Pagination offset + type: string + description: Entity tag previously returned by this endpoint. When present and matching, the server returns 304 Not Modified. responses: "200": - description: Custom node list + description: Node info schema + headers: + ETag: + schema: + type: string + description: Entity tag for conditional request validation + Cache-Control: + schema: + type: string + description: Cache directives for the response content: application/json: schema: - $ref: "#/components/schemas/CloudNodeList" - "401": - description: Unauthorized - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" + type: object + additionalProperties: + $ref: "#/components/schemas/NodeInfo" + "304": + description: Not Modified — returned when the client sends a matching If-None-Match header post: operationId: installCloudNode tags: [node] @@ -2651,10 +2656,10 @@ paths: /api/experiment/nodes/{id}: get: - operationId: getCloudNode - tags: [node] - summary: Get details of an installed custom node - description: "[cloud-only] Returns details about a specific installed custom node package." + operationId: getNodeByID + tags: [runtime-only] + summary: Get a single node definition by ID + description: "[cloud-only] Returns one node's definition from the pre-indexed object_info schema. Served by a raw HTTP handler that writes pre-rendered bytes with ETag + Cache-Control validators for RFC 7232 conditional GETs." x-runtime: [cloud] parameters: - name: id @@ -2662,26 +2667,33 @@ paths: required: true schema: type: string - description: Custom node package ID + description: Node class identifier + - name: If-None-Match + in: header + required: false + schema: + type: string + description: Entity tag previously returned by this endpoint. When present and matching, the server returns 304 Not Modified. 
responses: "200": - description: Node detail + description: Single node definition + headers: + ETag: + schema: + type: string + description: Entity tag for conditional request validation + Cache-Control: + schema: + type: string + description: Cache directives for the response content: application/json: schema: - $ref: "#/components/schemas/CloudNode" - "401": - description: Unauthorized - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" + $ref: "#/components/schemas/NodeInfo" + "304": + description: Not Modified — returned when the client sends a matching If-None-Match header "404": - description: Not found - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" + description: Node not found delete: operationId: uninstallCloudNode tags: [node] @@ -7100,22 +7112,6 @@ components: enabled: type: boolean - CloudNodeList: - type: object - x-runtime: [cloud] - description: "[cloud-only] Paginated list of installed custom node packages." - required: - - nodes - properties: - nodes: - type: array - items: - $ref: "#/components/schemas/CloudNode" - total: - type: integer - has_more: - type: boolean - HubLabel: type: object x-runtime: [cloud] From 8b08bfdcbe2b4cd8f4426bd1111aaf17b118e33d Mon Sep 17 00:00:00 2001 From: lin-bot23 Date: Sat, 9 May 2026 12:26:13 +0900 Subject: [PATCH 101/102] Add description field to blueprint subgraphs (#13797) * Add description field to all blueprint subgraphs Sets the 'description' field on every subgraph blueprint node, which will show on the node preview and tooltip. Covers all 51 blueprint files under blueprints/. * Update blueprint descriptions with researched model info * Refine blueprint descriptions with researched model specs from docs Updates subgraph descriptions across all 51 blueprints with accurate model details drawn from ComfyUI docs, including: - Flux.1 Dev: 12B open-weights, Pro-level quality - Flux.2 Klein 4B: fastest Flux, distilled architecture - Qwen-Image: 20B MMDiT, multilingual text rendering - Z-Image-Turbo: distilled 6B DiT, sub-second inference - LTX-2/2.3: 19B DiT audio-video foundation model - Wan2.2: open-source, 14B/1.3B variants - ACE-Step 1.5: ~1s full-song generation - GPU shader nodes consistently labeled as fragment shaders * Strip marketing fluff and license info from descriptions * Fix Canny to Video (LTX 2.0) description * Remove 'local-' prefix from subgraph names * Preserve UTF-8 encoding in JSON files (ensure_ascii=False) * Apply review suggestions from alexisrolland - Rename 'Image to Model (Hunyuan3d 2.1)' -> 'Image to 3D Model (Hunyuan3d 2.1)' - Rename 'Image Upscale(Z-image-Turbo)' -> 'Image Upscale (Z-image-Turbo)' - Rename 'Video Inpaint(Wan2.1 VACE)' -> 'Video Inpaint (Wan 2.1 VACE)' - Use 'Black Forest Labs' branding in Flux descriptions - Use 'Google's Gemini' with possessive in captioning nodes - Normalize 'Wan 2.2' and 'Wan 2.1' spacing in descriptions * fix: revert Color Adjustment.json to preserve original GLSL shader content Only adds the 'description' field without modifying the shader code (which contained Unicode escape \\u2192 that should be preserved). 
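A minimal sketch of the rewrite pattern the encoding fix implies
(hypothetical helper: the 'description' field and ensure_ascii=False are
from this patch; the function name, top-level placement, and indent width
are illustrative, since the real script targets the subgraph entries):

    import json
    from pathlib import Path

    def set_blueprint_description(path: Path, text: str) -> None:
        blueprint = json.loads(path.read_text(encoding="utf-8"))
        blueprint["description"] = text  # simplified; the field sits on each subgraph
        # ensure_ascii=False writes non-ASCII (e.g. the U+2192 arrow) as
        # literal UTF-8 rather than \u2192 escapes. That is also why
        # Color Adjustment.json was reverted and edited minimally: a full
        # round-trip would have rewritten the escaped form its shader
        # string originally used.
        path.write_text(
            json.dumps(blueprint, indent=2, ensure_ascii=False),
            encoding="utf-8",
        )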
* Apply CodeRabbit review suggestions - Color Adjustment: include vibrance in description - Image Blur: expand to Gaussian/Box/Radial modes - Flux.2 Klein 4B: narrow to image edit only (no T2I) - NetaYume Lumina: correct model base (Neta Lumina, not Lumina-Next) --------- Co-authored-by: linmoumou Co-authored-by: Daxiong (Lin) --- blueprints/Brightness and Contrast.json | 5 +++-- blueprints/Canny to Image (Z-Image-Turbo).json | 7 ++++--- blueprints/Canny to Video (LTX 2.0).json | 7 ++++--- blueprints/Chromatic Aberration.json | 5 +++-- blueprints/Color Adjustment.json | 3 ++- blueprints/Color Balance.json | 3 ++- blueprints/Color Curves.json | 3 ++- blueprints/Crop Images 2x2.json | 3 ++- blueprints/Crop Images 3x3.json | 3 ++- blueprints/Depth to Image (Z-Image-Turbo).json | 6 ++++-- blueprints/Depth to Video (ltx 2.0).json | 6 ++++-- blueprints/Edge-Preserving Blur.json | 5 +++-- blueprints/Film Grain.json | 5 +++-- blueprints/First-Last-Frame to Video (LTX-2.3).json | 3 ++- blueprints/Glow.json | 5 +++-- blueprints/Hue and Saturation.json | 5 +++-- blueprints/Image Blur.json | 3 ++- blueprints/Image Captioning (gemini).json | 3 ++- blueprints/Image Channels.json | 5 +++-- blueprints/Image Edit (FireRed Image Edit 1.1).json | 3 ++- blueprints/Image Edit (Flux.2 Klein 4B).json | 8 +++++--- blueprints/Image Edit (LongCat Image Edit).json | 3 ++- blueprints/Image Edit (Qwen 2511).json | 7 ++++--- blueprints/Image Inpainting (Flux.1 Fill Dev).json | 5 +++-- blueprints/Image Inpainting (Qwen-image).json | 6 ++++-- blueprints/Image Levels.json | 5 +++-- blueprints/Image Outpainting (Qwen-Image).json | 9 ++++++--- blueprints/Image Upscale(Z-image-Turbo).json | 5 +++-- blueprints/Image to Depth Map (Lotus).json | 7 ++++--- blueprints/Image to Layers(Qwen-Image-Layered).json | 3 ++- blueprints/Image to Model (Hunyuan3d 2.1).json | 5 +++-- blueprints/Image to Video (LTX-2.3).json | 3 ++- blueprints/Image to Video (Wan 2.2).json | 5 +++-- blueprints/Pose to Image (Z-Image-Turbo).json | 7 ++++--- blueprints/Pose to Video (LTX 2.0).json | 3 ++- blueprints/Prompt Enhance.json | 5 +++-- blueprints/Sharpen.json | 5 +++-- blueprints/Text to Audio (ACE-Step 1.5).json | 7 ++++--- blueprints/Text to Image (Flux.1 Dev).json | 5 +++-- blueprints/Text to Image (Flux.1 Krea Dev).json | 5 +++-- blueprints/Text to Image (NetaYume Lumina).json | 8 +++++--- blueprints/Text to Image (Qwen-Image 2512).json | 3 ++- blueprints/Text to Image (Qwen-Image).json | 3 ++- blueprints/Text to Image (Z-Image-Turbo).json | 7 ++++--- blueprints/Text to Video (LTX-2.3).json | 3 ++- blueprints/Text to Video (Wan 2.2).json | 5 +++-- blueprints/Unsharp Mask.json | 5 +++-- blueprints/Video Captioning (Gemini).json | 3 ++- blueprints/Video Inpaint(Wan2.1 VACE).json | 5 +++-- blueprints/Video Stitch.json | 5 +++-- blueprints/Video Upscale(GAN x4).json | 5 +++-- 51 files changed, 153 insertions(+), 95 deletions(-) diff --git a/blueprints/Brightness and Contrast.json b/blueprints/Brightness and Contrast.json index 90bfe999d..78fc52f29 100644 --- a/blueprints/Brightness and Contrast.json +++ b/blueprints/Brightness and Contrast.json @@ -431,9 +431,10 @@ "extra": { "workflowRendererVersion": "LG" }, - "category": "Image Tools/Color adjust" + "category": "Image Tools/Color adjust", + "description": "Adjusts image brightness and contrast using a real-time GPU fragment shader." 
} ] }, "extra": {} -} +} \ No newline at end of file diff --git a/blueprints/Canny to Image (Z-Image-Turbo).json b/blueprints/Canny to Image (Z-Image-Turbo).json index ff9717308..14deb64cc 100644 --- a/blueprints/Canny to Image (Z-Image-Turbo).json +++ b/blueprints/Canny to Image (Z-Image-Turbo).json @@ -162,7 +162,7 @@ }, "revision": 0, "config": {}, - "name": "local-Canny to Image (Z-Image-Turbo)", + "name": "Canny to Image (Z-Image-Turbo)", "inputNode": { "id": -10, "bounding": [ @@ -1553,7 +1553,8 @@ "VHS_MetadataImage": true, "VHS_KeepIntermediate": true }, - "category": "Image generation and editing/Canny to image" + "category": "Image generation and editing/Canny to image", + "description": "Generates an image from a Canny edge map using Z-Image-Turbo, with text conditioning." } ] }, @@ -1574,4 +1575,4 @@ } }, "version": 0.4 -} +} \ No newline at end of file diff --git a/blueprints/Canny to Video (LTX 2.0).json b/blueprints/Canny to Video (LTX 2.0).json index fae8321b9..a9682c8a4 100644 --- a/blueprints/Canny to Video (LTX 2.0).json +++ b/blueprints/Canny to Video (LTX 2.0).json @@ -192,7 +192,7 @@ }, "revision": 0, "config": {}, - "name": "local-Canny to Video (LTX 2.0)", + "name": "Canny to Video (LTX 2.0)", "inputNode": { "id": -10, "bounding": [ @@ -3600,7 +3600,8 @@ "extra": { "workflowRendererVersion": "LG" }, - "category": "Video generation and editing/Canny to video" + "category": "Video generation and editing/Canny to video", + "description": "Generates video from Canny edge maps using LTX-2, with optional synchronized audio." } ] }, @@ -3616,4 +3617,4 @@ } }, "version": 0.4 -} +} \ No newline at end of file diff --git a/blueprints/Chromatic Aberration.json b/blueprints/Chromatic Aberration.json index ae8037b1b..893fb1190 100644 --- a/blueprints/Chromatic Aberration.json +++ b/blueprints/Chromatic Aberration.json @@ -377,8 +377,9 @@ "extra": { "workflowRendererVersion": "LG" }, - "category": "Image Tools/Color adjust" + "category": "Image Tools/Color adjust", + "description": "Adds lens-style chromatic aberration (color fringing) using a real-time GPU fragment shader." } ] } -} +} \ No newline at end of file diff --git a/blueprints/Color Adjustment.json b/blueprints/Color Adjustment.json index 622bf28af..5abbf8baa 100644 --- a/blueprints/Color Adjustment.json +++ b/blueprints/Color Adjustment.json @@ -596,7 +596,8 @@ "extra": { "workflowRendererVersion": "LG" }, - "category": "Image Tools/Color adjust" + "category": "Image Tools/Color adjust", + "description": "Adjusts saturation, temperature, tint, and vibrance using a real-time GPU fragment shader." } ] } diff --git a/blueprints/Color Balance.json b/blueprints/Color Balance.json index 21d6319ed..d921eab37 100644 --- a/blueprints/Color Balance.json +++ b/blueprints/Color Balance.json @@ -1129,7 +1129,8 @@ "extra": { "workflowRendererVersion": "LG" }, - "category": "Image Tools/Color adjust" + "category": "Image Tools/Color adjust", + "description": "Balances colors across shadows, midtones, and highlights using a real-time GPU fragment shader." } ] } diff --git a/blueprints/Color Curves.json b/blueprints/Color Curves.json index 1461cf396..b9bfb7029 100644 --- a/blueprints/Color Curves.json +++ b/blueprints/Color Curves.json @@ -608,7 +608,8 @@ "extra": { "workflowRendererVersion": "LG" }, - "category": "Image Tools/Color adjust" + "category": "Image Tools/Color adjust", + "description": "Fine-tunes tone and color with per-channel curve adjustments using a real-time GPU fragment shader." 
} ] } diff --git a/blueprints/Crop Images 2x2.json b/blueprints/Crop Images 2x2.json index 2aa42cfc3..99b89b608 100644 --- a/blueprints/Crop Images 2x2.json +++ b/blueprints/Crop Images 2x2.json @@ -1609,7 +1609,8 @@ } ], "extra": {}, - "category": "Image Tools/Crop" + "category": "Image Tools/Crop", + "description": "Splits an image into a 2×2 grid of four equal tiles." } ] }, diff --git a/blueprints/Crop Images 3x3.json b/blueprints/Crop Images 3x3.json index 3a3615ac8..6ac636da4 100644 --- a/blueprints/Crop Images 3x3.json +++ b/blueprints/Crop Images 3x3.json @@ -2946,7 +2946,8 @@ } ], "extra": {}, - "category": "Image Tools/Crop" + "category": "Image Tools/Crop", + "description": "Splits an image into a 3×3 grid of nine equal tiles." } ] }, diff --git a/blueprints/Depth to Image (Z-Image-Turbo).json b/blueprints/Depth to Image (Z-Image-Turbo).json index 4f69a8149..fe9ef0f72 100644 --- a/blueprints/Depth to Image (Z-Image-Turbo).json +++ b/blueprints/Depth to Image (Z-Image-Turbo).json @@ -1579,7 +1579,8 @@ "VHS_MetadataImage": true, "VHS_KeepIntermediate": true }, - "category": "Image generation and editing/Depth to image" + "category": "Image generation and editing/Depth to image", + "description": "Generates an image from a depth map using Z-Image-Turbo with text conditioning." }, { "id": "458bdf3c-4b58-421c-af50-c9c663a4d74c", @@ -2461,7 +2462,8 @@ ] }, "workflowRendererVersion": "LG" - } + }, + "description": "Estimates a monocular depth map from an input image using the Lotus depth estimation model." } ] }, diff --git a/blueprints/Depth to Video (ltx 2.0).json b/blueprints/Depth to Video (ltx 2.0).json index f15212520..bb28695a2 100644 --- a/blueprints/Depth to Video (ltx 2.0).json +++ b/blueprints/Depth to Video (ltx 2.0).json @@ -4233,7 +4233,8 @@ "extra": { "workflowRendererVersion": "LG" }, - "category": "Video generation and editing/Depth to video" + "category": "Video generation and editing/Depth to video", + "description": "Generates video from depth maps using LTX-2, with optional synchronized audio." }, { "id": "38b60539-50a7-42f9-a5fe-bdeca26272e2", @@ -5192,7 +5193,8 @@ ], "extra": { "workflowRendererVersion": "LG" - } + }, + "description": "Estimates a monocular depth map from an input image using the Lotus depth estimation model." } ] }, diff --git a/blueprints/Edge-Preserving Blur.json b/blueprints/Edge-Preserving Blur.json index 18012beb1..fbda9f126 100644 --- a/blueprints/Edge-Preserving Blur.json +++ b/blueprints/Edge-Preserving Blur.json @@ -450,9 +450,10 @@ "extra": { "workflowRendererVersion": "LG" }, - "category": "Image Tools/Blur" + "category": "Image Tools/Blur", + "description": "Applies bilateral (edge-preserving) blur to soften images while retaining detail." } ] }, "extra": {} -} +} \ No newline at end of file diff --git a/blueprints/Film Grain.json b/blueprints/Film Grain.json index a680b3ece..3226ea9aa 100644 --- a/blueprints/Film Grain.json +++ b/blueprints/Film Grain.json @@ -580,8 +580,9 @@ "extra": { "workflowRendererVersion": "LG" }, - "category": "Image Tools/Color adjust" + "category": "Image Tools/Color adjust", + "description": "Adds procedural film grain texture for a cinematic look via GPU fragment shader." 
} ] } -} +} \ No newline at end of file diff --git a/blueprints/First-Last-Frame to Video (LTX-2.3).json b/blueprints/First-Last-Frame to Video (LTX-2.3).json index 8ec9ed61a..f509aefe0 100644 --- a/blueprints/First-Last-Frame to Video (LTX-2.3).json +++ b/blueprints/First-Last-Frame to Video (LTX-2.3).json @@ -3350,7 +3350,8 @@ } ], "extra": {}, - "category": "Video generation and editing/First-Last-Frame to Video" + "category": "Video generation and editing/First-Last-Frame to Video", + "description": "Generates a video interpolating between first and last keyframes using LTX-2.3." } ] }, diff --git a/blueprints/Glow.json b/blueprints/Glow.json index 1dafb2d35..2bbfdee51 100644 --- a/blueprints/Glow.json +++ b/blueprints/Glow.json @@ -575,8 +575,9 @@ "extra": { "workflowRendererVersion": "LG" }, - "category": "Image Tools/Color adjust" + "category": "Image Tools/Color adjust", + "description": "Adds a glow/bloom effect around bright image areas via GPU fragment shader." } ] } -} +} \ No newline at end of file diff --git a/blueprints/Hue and Saturation.json b/blueprints/Hue and Saturation.json index 1a2df8937..cddf0154a 100644 --- a/blueprints/Hue and Saturation.json +++ b/blueprints/Hue and Saturation.json @@ -752,8 +752,9 @@ "extra": { "workflowRendererVersion": "LG" }, - "category": "Image Tools/Color adjust" + "category": "Image Tools/Color adjust", + "description": "Adjusts hue, saturation, and lightness of an image using a real-time GPU fragment shader." } ] } -} +} \ No newline at end of file diff --git a/blueprints/Image Blur.json b/blueprints/Image Blur.json index 3c7a784b0..0ca8d9931 100644 --- a/blueprints/Image Blur.json +++ b/blueprints/Image Blur.json @@ -374,7 +374,8 @@ "extra": { "workflowRendererVersion": "LG" }, - "category": "Image Tools/Blur" + "category": "Image Tools/Blur", + "description": "Applies Gaussian, Box, or Radial blur to soften images and create stylized depth or motion effects." } ] } diff --git a/blueprints/Image Captioning (gemini).json b/blueprints/Image Captioning (gemini).json index 98cfb8999..2fc5d6746 100644 --- a/blueprints/Image Captioning (gemini).json +++ b/blueprints/Image Captioning (gemini).json @@ -310,7 +310,8 @@ "extra": { "workflowRendererVersion": "LG" }, - "category": "Text generation/Image Captioning" + "category": "Text generation/Image Captioning", + "description": "Generates descriptive captions for images using Google's Gemini multimodal LLM." } ] } diff --git a/blueprints/Image Channels.json b/blueprints/Image Channels.json index 9c7b675b2..b6fdff5be 100644 --- a/blueprints/Image Channels.json +++ b/blueprints/Image Channels.json @@ -315,8 +315,9 @@ "extra": { "workflowRendererVersion": "LG" }, - "category": "Image Tools/Color adjust" + "category": "Image Tools/Color adjust", + "description": "Manipulates individual RGBA channels for masking, compositing, and channel effects." } ] } -} +} \ No newline at end of file diff --git a/blueprints/Image Edit (FireRed Image Edit 1.1).json b/blueprints/Image Edit (FireRed Image Edit 1.1).json index c34246ce6..14310353c 100644 --- a/blueprints/Image Edit (FireRed Image Edit 1.1).json +++ b/blueprints/Image Edit (FireRed Image Edit 1.1).json @@ -2138,7 +2138,8 @@ "extra": { "workflowRendererVersion": "LG" }, - "category": "Image generation and editing/Edit image" + "category": "Image generation and editing/Edit image", + "description": "Edits images via text instructions using FireRed Image Edit 1.1, a diffusion-based instruction-following editing model." 
} ] }, diff --git a/blueprints/Image Edit (Flux.2 Klein 4B).json b/blueprints/Image Edit (Flux.2 Klein 4B).json index 6f2f7dc01..7f6fa7a4b 100644 --- a/blueprints/Image Edit (Flux.2 Klein 4B).json +++ b/blueprints/Image Edit (Flux.2 Klein 4B).json @@ -1472,7 +1472,8 @@ "extra": { "workflowRendererVersion": "LG" }, - "category": "Image generation and editing/Edit image" + "category": "Image generation and editing/Edit image", + "description": "Edits an input image via text instructions using FLUX.2 [klein] 4B." }, { "id": "6007e698-2ebd-4917-84d8-299b35d7b7ab", @@ -1821,7 +1822,8 @@ ], "extra": { "workflowRendererVersion": "LG" - } + }, + "description": "Applies reference image conditioning for style/identity transfer (Flux.2 Klein 4B)." } ] }, @@ -1837,4 +1839,4 @@ } }, "version": 0.4 -} \ No newline at end of file +} diff --git a/blueprints/Image Edit (LongCat Image Edit).json b/blueprints/Image Edit (LongCat Image Edit).json index 5b4eb18f0..de1c155a2 100644 --- a/blueprints/Image Edit (LongCat Image Edit).json +++ b/blueprints/Image Edit (LongCat Image Edit).json @@ -1417,7 +1417,8 @@ } ], "extra": {}, - "category": "Image generation and editing/Edit image" + "category": "Image generation and editing/Edit image", + "description": "Edits images via text instructions using LongCat Image Edit, an instruction-following image editing diffusion model." } ] }, diff --git a/blueprints/Image Edit (Qwen 2511).json b/blueprints/Image Edit (Qwen 2511).json index 582171fa0..1aa7e5765 100644 --- a/blueprints/Image Edit (Qwen 2511).json +++ b/blueprints/Image Edit (Qwen 2511).json @@ -132,7 +132,7 @@ }, "revision": 0, "config": {}, - "name": "local-Image Edit (Qwen 2511)", + "name": "Image Edit (Qwen 2511)", "inputNode": { "id": -10, "bounding": [ @@ -1468,7 +1468,8 @@ "VHS_MetadataImage": true, "VHS_KeepIntermediate": true }, - "category": "Image generation and editing/Edit image" + "category": "Image generation and editing/Edit image", + "description": "Edits images via text instructions using Qwen-Image-Edit-2511 with improved character consistency and integrated LoRA." } ] }, @@ -1489,4 +1490,4 @@ } }, "version": 0.4 -} +} \ No newline at end of file diff --git a/blueprints/Image Inpainting (Flux.1 Fill Dev).json b/blueprints/Image Inpainting (Flux.1 Fill Dev).json index d40d63594..c1326ed3d 100644 --- a/blueprints/Image Inpainting (Flux.1 Fill Dev).json +++ b/blueprints/Image Inpainting (Flux.1 Fill Dev).json @@ -1188,7 +1188,8 @@ "extra": { "workflowRendererVersion": "LG" }, - "category": "Image generation and editing/Inpaint image" + "category": "Image generation and editing/Inpaint image", + "description": "Inpaints masked image regions using Flux.1 fill [dev], Black Forest Labs' inpainting/outpainting model." } ] }, @@ -1202,4 +1203,4 @@ }, "ue_links": [] } -} \ No newline at end of file +} diff --git a/blueprints/Image Inpainting (Qwen-image).json b/blueprints/Image Inpainting (Qwen-image).json index 95b2909fa..a06d57e19 100644 --- a/blueprints/Image Inpainting (Qwen-image).json +++ b/blueprints/Image Inpainting (Qwen-image).json @@ -1548,7 +1548,8 @@ "extra": { "workflowRendererVersion": "LG" }, - "category": "Image generation and editing/Inpaint image" + "category": "Image generation and editing/Inpaint image", + "description": "Inpaints masked regions using Qwen-Image, extending its multilingual text rendering to inpainting tasks." 
}, { "id": "56a1f603-fbd2-40ed-94ef-c9ecbd96aca8", @@ -1907,7 +1908,8 @@ ], "extra": { "workflowRendererVersion": "LG" - } + }, + "description": "Expands and softens mask edges to reduce visible seams after image processing." } ] }, diff --git a/blueprints/Image Levels.json b/blueprints/Image Levels.json index ef256a1aa..1a1b18932 100644 --- a/blueprints/Image Levels.json +++ b/blueprints/Image Levels.json @@ -742,9 +742,10 @@ "extra": { "workflowRendererVersion": "LG" }, - "category": "Image Tools/Color adjust" + "category": "Image Tools/Color adjust", + "description": "Adjusts black point, white point, and gamma for tonal range control via GPU shader." } ] }, "extra": {} -} +} \ No newline at end of file diff --git a/blueprints/Image Outpainting (Qwen-Image).json b/blueprints/Image Outpainting (Qwen-Image).json index 218fdc775..6c07227c0 100644 --- a/blueprints/Image Outpainting (Qwen-Image).json +++ b/blueprints/Image Outpainting (Qwen-Image).json @@ -1919,7 +1919,8 @@ "extra": { "workflowRendererVersion": "LG" }, - "category": "Image generation and editing/Outpaint image" + "category": "Image generation and editing/Outpaint image", + "description": "Outpaints beyond image boundaries using Qwen-Image's outpainting capabilities." }, { "id": "f93c215e-c393-460e-9534-ed2c3d8a652e", @@ -2278,7 +2279,8 @@ ], "extra": { "workflowRendererVersion": "LG" - } + }, + "description": "Expands and softens mask edges to reduce visible seams after image processing." }, { "id": "2a4b2cc0-db37-4302-a067-da392f38f06b", @@ -2733,7 +2735,8 @@ ], "extra": { "workflowRendererVersion": "LG" - } + }, + "description": "Scales both image and mask together while preserving alignment for editing workflows." } ] }, diff --git a/blueprints/Image Upscale(Z-image-Turbo).json b/blueprints/Image Upscale(Z-image-Turbo).json index 0d2b6e240..bd803a0b1 100644 --- a/blueprints/Image Upscale(Z-image-Turbo).json +++ b/blueprints/Image Upscale(Z-image-Turbo).json @@ -141,7 +141,7 @@ }, "revision": 0, "config": {}, - "name": "local-Image Upscale(Z-image-Turbo)", + "name": "Image Upscale (Z-image-Turbo)", "inputNode": { "id": -10, "bounding": [ @@ -1302,7 +1302,8 @@ "extra": { "workflowRendererVersion": "LG" }, - "category": "Image generation and editing/Enhance" + "category": "Image generation and editing/Enhance", + "description": "Upscales images to higher resolution using Z-Image-Turbo." } ] }, diff --git a/blueprints/Image to Depth Map (Lotus).json b/blueprints/Image to Depth Map (Lotus).json index 089f2cd42..12f10ba5b 100644 --- a/blueprints/Image to Depth Map (Lotus).json +++ b/blueprints/Image to Depth Map (Lotus).json @@ -99,7 +99,7 @@ }, "revision": 0, "config": {}, - "name": "local-Image to Depth Map (Lotus)", + "name": "Image to Depth Map (Lotus)", "inputNode": { "id": -10, "bounding": [ @@ -948,7 +948,8 @@ "extra": { "workflowRendererVersion": "LG" }, - "category": "Image generation and editing/Depth to image" + "category": "Image generation and editing/Depth to image", + "description": "Estimates a monocular depth map from an input image using the Lotus depth estimation model." 
} ] }, @@ -964,4 +965,4 @@ "workflowRendererVersion": "LG" }, "version": 0.4 -} +} \ No newline at end of file diff --git a/blueprints/Image to Layers(Qwen-Image-Layered).json b/blueprints/Image to Layers(Qwen-Image-Layered).json index 8a525e7a5..7b44f0563 100644 --- a/blueprints/Image to Layers(Qwen-Image-Layered).json +++ b/blueprints/Image to Layers(Qwen-Image-Layered).json @@ -1586,7 +1586,8 @@ "extra": { "workflowRendererVersion": "LG" }, - "category": "Image generation and editing/Image to layers" + "category": "Image generation and editing/Image to layers", + "description": "Decomposes an image into variable-resolution RGBA layers for independent editing using Qwen-Image-Layered." } ] }, diff --git a/blueprints/Image to Model (Hunyuan3d 2.1).json b/blueprints/Image to Model (Hunyuan3d 2.1).json index 4705603a8..ee5552656 100644 --- a/blueprints/Image to Model (Hunyuan3d 2.1).json +++ b/blueprints/Image to Model (Hunyuan3d 2.1).json @@ -72,7 +72,7 @@ }, "revision": 0, "config": {}, - "name": "local-Image to Model (Hunyuan3d 2.1)", + "name": "Image to 3D Model (Hunyuan3d 2.1)", "inputNode": { "id": -10, "bounding": [ @@ -765,7 +765,8 @@ "extra": { "workflowRendererVersion": "LG" }, - "category": "3D/Image to 3D Model" + "category": "3D/Image to 3D Model", + "description": "Generates 3D mesh models from a single input image using Hunyuan3D 2.0/2.1." } ] }, diff --git a/blueprints/Image to Video (LTX-2.3).json b/blueprints/Image to Video (LTX-2.3).json index 86a601130..3db524ea0 100644 --- a/blueprints/Image to Video (LTX-2.3).json +++ b/blueprints/Image to Video (LTX-2.3).json @@ -4223,7 +4223,8 @@ "extra": { "workflowRendererVersion": "Vue-corrected" }, - "category": "Video generation and editing/Image to video" + "category": "Video generation and editing/Image to video", + "description": "Generates video from a single input image using LTX-2.3." } ] }, diff --git a/blueprints/Image to Video (Wan 2.2).json b/blueprints/Image to Video (Wan 2.2).json index a8dafd3c9..3510aad18 100644 --- a/blueprints/Image to Video (Wan 2.2).json +++ b/blueprints/Image to Video (Wan 2.2).json @@ -206,7 +206,7 @@ }, "revision": 0, "config": {}, - "name": "local-Image to Video (Wan 2.2)", + "name": "Image to Video (Wan 2.2)", "inputNode": { "id": -10, "bounding": [ @@ -2027,7 +2027,8 @@ "extra": { "workflowRendererVersion": "LG" }, - "category": "Video generation and editing/Image to video" + "category": "Video generation and editing/Image to video", + "description": "Generates video from an image and text prompt using Wan 2.2, supporting T2V and I2V." } ] }, diff --git a/blueprints/Pose to Image (Z-Image-Turbo).json b/blueprints/Pose to Image (Z-Image-Turbo).json index a55410ba4..5c2749efe 100644 --- a/blueprints/Pose to Image (Z-Image-Turbo).json +++ b/blueprints/Pose to Image (Z-Image-Turbo).json @@ -134,7 +134,7 @@ }, "revision": 0, "config": {}, - "name": "local-Pose to Image (Z-Image-Turbo)", + "name": "Pose to Image (Z-Image-Turbo)", "inputNode": { "id": -10, "bounding": [ @@ -1298,7 +1298,8 @@ "VHS_MetadataImage": true, "VHS_KeepIntermediate": true }, - "category": "Image generation and editing/Pose to image" + "category": "Image generation and editing/Pose to image", + "description": "Generates an image from pose keypoints using Z-Image-Turbo with text conditioning." 
} ] }, @@ -1319,4 +1320,4 @@ } }, "version": 0.4 -} +} \ No newline at end of file diff --git a/blueprints/Pose to Video (LTX 2.0).json b/blueprints/Pose to Video (LTX 2.0).json index 580900bc0..1ce49351a 100644 --- a/blueprints/Pose to Video (LTX 2.0).json +++ b/blueprints/Pose to Video (LTX 2.0).json @@ -3870,7 +3870,8 @@ "extra": { "workflowRendererVersion": "LG" }, - "category": "Video generation and editing/Pose to video" + "category": "Video generation and editing/Pose to video", + "description": "Generates video from pose reference frames using LTX-2, with optional synchronized audio." } ] }, diff --git a/blueprints/Prompt Enhance.json b/blueprints/Prompt Enhance.json index 5e57548ff..e260b1203 100644 --- a/blueprints/Prompt Enhance.json +++ b/blueprints/Prompt Enhance.json @@ -270,9 +270,10 @@ "extra": { "workflowRendererVersion": "LG" }, - "category": "Text generation/Prompt enhance" + "category": "Text generation/Prompt enhance", + "description": "Expands short text prompts into detailed descriptions using a text generation model for better generation quality." } ] }, "extra": {} -} +} \ No newline at end of file diff --git a/blueprints/Sharpen.json b/blueprints/Sharpen.json index f332400fd..3c4099c6b 100644 --- a/blueprints/Sharpen.json +++ b/blueprints/Sharpen.json @@ -302,8 +302,9 @@ "extra": { "workflowRendererVersion": "LG" }, - "category": "Image Tools/Sharpen" + "category": "Image Tools/Sharpen", + "description": "Sharpens image details using a GPU fragment shader for enhanced clarity." } ] } -} +} \ No newline at end of file diff --git a/blueprints/Text to Audio (ACE-Step 1.5).json b/blueprints/Text to Audio (ACE-Step 1.5).json index 206cf16be..5b8b8626f 100644 --- a/blueprints/Text to Audio (ACE-Step 1.5).json +++ b/blueprints/Text to Audio (ACE-Step 1.5).json @@ -222,7 +222,7 @@ }, "revision": 0, "config": {}, - "name": "local-Text to Audio (ACE-Step 1.5)", + "name": "Text to Audio (ACE-Step 1.5)", "inputNode": { "id": -10, "bounding": [ @@ -1502,7 +1502,8 @@ "extra": { "workflowRendererVersion": "LG" }, - "category": "Audio/Music generation" + "category": "Audio/Music generation", + "description": "Generates audio/music from text prompts using ACE-Step 1.5, a diffusion-based audio generation model." } ] }, @@ -1518,4 +1519,4 @@ } }, "version": 0.4 -} +} \ No newline at end of file diff --git a/blueprints/Text to Image (Flux.1 Dev).json b/blueprints/Text to Image (Flux.1 Dev).json index 04c3cb95a..45f68f508 100644 --- a/blueprints/Text to Image (Flux.1 Dev).json +++ b/blueprints/Text to Image (Flux.1 Dev).json @@ -1029,7 +1029,8 @@ "extra": { "workflowRendererVersion": "LG" }, - "category": "Image generation and editing/Text to image" + "category": "Image generation and editing/Text to image", + "description": "Generates images from text prompts using Flux.1 [dev], Black Forest Labs' 12B diffusion model." } ] }, @@ -1043,4 +1044,4 @@ }, "ue_links": [] } -} \ No newline at end of file +} diff --git a/blueprints/Text to Image (Flux.1 Krea Dev).json b/blueprints/Text to Image (Flux.1 Krea Dev).json index fe4db1cfc..30a78dca1 100644 --- a/blueprints/Text to Image (Flux.1 Krea Dev).json +++ b/blueprints/Text to Image (Flux.1 Krea Dev).json @@ -1023,7 +1023,8 @@ "extra": { "workflowRendererVersion": "LG" }, - "category": "Image generation and editing/Text to image" + "category": "Image generation and editing/Text to image", + "description": "Generates images from text prompts using Flux.1 Krea Dev, a Black Forest Labs × Krea collaboration variant." 
} ] }, @@ -1037,4 +1038,4 @@ }, "ue_links": [] } -} \ No newline at end of file +} diff --git a/blueprints/Text to Image (NetaYume Lumina).json b/blueprints/Text to Image (NetaYume Lumina).json index 394ad1608..9e11b7a86 100644 --- a/blueprints/Text to Image (NetaYume Lumina).json +++ b/blueprints/Text to Image (NetaYume Lumina).json @@ -1104,7 +1104,8 @@ "extra": { "workflowRendererVersion": "LG" }, - "category": "Image generation and editing/Text to image" + "category": "Image generation and editing/Text to image", + "description": "Generates images from text prompts using NetaYume Lumina, fine-tuned from Neta Lumina for anime-style and illustration generation." }, { "id": "a07fdf06-1bda-4dac-bdbd-63ee8ebca1c9", @@ -1458,11 +1459,12 @@ ], "extra": { "workflowRendererVersion": "LG" - } + }, + "description": "Encodes a negative text prompt via CLIP for classifier-free guidance in anime-style generation (NetaYume Lumina)." } ] }, "extra": { "ue_links": [] } -} \ No newline at end of file +} diff --git a/blueprints/Text to Image (Qwen-Image 2512).json b/blueprints/Text to Image (Qwen-Image 2512).json index f52ea2ef2..09612be8b 100644 --- a/blueprints/Text to Image (Qwen-Image 2512).json +++ b/blueprints/Text to Image (Qwen-Image 2512).json @@ -1941,7 +1941,8 @@ "extra": { "workflowRendererVersion": "Vue-corrected" }, - "category": "Image generation and editing/Text to image" + "category": "Image generation and editing/Text to image", + "description": "Generates images from text prompts using Qwen-Image-2512, with enhanced human realism and finer natural detail over the base version." } ] }, diff --git a/blueprints/Text to Image (Qwen-Image).json b/blueprints/Text to Image (Qwen-Image).json index 70b4b44b3..e78d5a962 100644 --- a/blueprints/Text to Image (Qwen-Image).json +++ b/blueprints/Text to Image (Qwen-Image).json @@ -1873,7 +1873,8 @@ "extra": { "workflowRendererVersion": "LG" }, - "category": "Image generation and editing/Text to image" + "category": "Image generation and editing/Text to image", + "description": "Generates images from text prompts using Qwen-Image, Alibaba's 20B MMDiT model with excellent multilingual text rendering." } ] }, diff --git a/blueprints/Text to Image (Z-Image-Turbo).json b/blueprints/Text to Image (Z-Image-Turbo).json index 6aa80e327..6975151ea 100644 --- a/blueprints/Text to Image (Z-Image-Turbo).json +++ b/blueprints/Text to Image (Z-Image-Turbo).json @@ -149,7 +149,7 @@ }, "revision": 0, "config": {}, - "name": "local-Text to Image (Z-Image-Turbo)", + "name": "Text to Image (Z-Image-Turbo)", "inputNode": { "id": -10, "bounding": [ @@ -1054,7 +1054,8 @@ "extra": { "workflowRendererVersion": "LG" }, - "category": "Image generation and editing/Text to image" + "category": "Image generation and editing/Text to image", + "description": "Generates images from text prompts using Z-Image-Turbo, Alibaba's distilled 6B DiT model." } ] }, @@ -1075,4 +1076,4 @@ } }, "version": 0.4 -} +} \ No newline at end of file diff --git a/blueprints/Text to Video (LTX-2.3).json b/blueprints/Text to Video (LTX-2.3).json index ff9bc6ccf..f44a216dd 100644 --- a/blueprints/Text to Video (LTX-2.3).json +++ b/blueprints/Text to Video (LTX-2.3).json @@ -4286,7 +4286,8 @@ "extra": { "workflowRendererVersion": "Vue-corrected" }, - "category": "Video generation and editing/Text to video" + "category": "Video generation and editing/Text to video", + "description": "Generates video from text prompts using LTX-2.3, Lightricks' video diffusion model." 
} ] }, diff --git a/blueprints/Text to Video (Wan 2.2).json b/blueprints/Text to Video (Wan 2.2).json index 0ce485b67..a264a490d 100644 --- a/blueprints/Text to Video (Wan 2.2).json +++ b/blueprints/Text to Video (Wan 2.2).json @@ -1572,7 +1572,8 @@ "extra": { "workflowRendererVersion": "LG" }, - "category": "Video generation and editing/Text to video" + "category": "Video generation and editing/Text to video", + "description": "Generates video from text prompts using Wan2.2, Alibaba's diffusion video model." } ] }, @@ -1586,4 +1587,4 @@ "VHS_KeepIntermediate": true }, "version": 0.4 -} +} \ No newline at end of file diff --git a/blueprints/Unsharp Mask.json b/blueprints/Unsharp Mask.json index 137acaa43..79a4c954f 100644 --- a/blueprints/Unsharp Mask.json +++ b/blueprints/Unsharp Mask.json @@ -434,8 +434,9 @@ "extra": { "workflowRendererVersion": "LG" }, - "category": "Image Tools/Sharpen" + "category": "Image Tools/Sharpen", + "description": "Enhances edge contrast via unsharp masking for a sharper image appearance." } ] } -} +} \ No newline at end of file diff --git a/blueprints/Video Captioning (Gemini).json b/blueprints/Video Captioning (Gemini).json index ea6dc8bee..7642b23c1 100644 --- a/blueprints/Video Captioning (Gemini).json +++ b/blueprints/Video Captioning (Gemini).json @@ -307,7 +307,8 @@ "extra": { "workflowRendererVersion": "LG" }, - "category": "Text generation/Video Captioning" + "category": "Text generation/Video Captioning", + "description": "Generates descriptive captions for video input using Google's Gemini multimodal LLM." } ] } diff --git a/blueprints/Video Inpaint(Wan2.1 VACE).json b/blueprints/Video Inpaint(Wan2.1 VACE).json index f404e6773..a658be5f8 100644 --- a/blueprints/Video Inpaint(Wan2.1 VACE).json +++ b/blueprints/Video Inpaint(Wan2.1 VACE).json @@ -165,7 +165,7 @@ }, "revision": 0, "config": {}, - "name": "local-Video Inpaint(Wan2.1 VACE)", + "name": "Video Inpaint (Wan 2.1 VACE)", "inputNode": { "id": -10, "bounding": [ @@ -2368,7 +2368,8 @@ "extra": { "workflowRendererVersion": "LG" }, - "category": "Video generation and editing/Inpaint video" + "category": "Video generation and editing/Inpaint video", + "description": "Inpaints masked regions in video frames using Wan 2.1 VACE." } ] }, diff --git a/blueprints/Video Stitch.json b/blueprints/Video Stitch.json index 020896d78..6eb0f0bbf 100644 --- a/blueprints/Video Stitch.json +++ b/blueprints/Video Stitch.json @@ -584,8 +584,9 @@ "extra": { "workflowRendererVersion": "LG" }, - "category": "Video Tools/Stitch videos" + "category": "Video Tools/Stitch videos", + "description": "Stitches multiple video clips into a single sequential video file." } ] } -} +} \ No newline at end of file diff --git a/blueprints/Video Upscale(GAN x4).json b/blueprints/Video Upscale(GAN x4).json index b61dc88d7..73476e36b 100644 --- a/blueprints/Video Upscale(GAN x4).json +++ b/blueprints/Video Upscale(GAN x4).json @@ -412,9 +412,10 @@ "extra": { "workflowRendererVersion": "LG" }, - "category": "Video generation and editing/Enhance video" + "category": "Video generation and editing/Enhance video", + "description": "Upscales video to 4× resolution using a GAN-based upscaling model." 
} ] }, "extra": {} -} +} \ No newline at end of file From 7bbf1e8169fa3080841b83914fa9901793b66b71 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Sat, 9 May 2026 07:38:17 +0300 Subject: [PATCH 102/102] [Partner Nodes] Tripo3D 3.1 model (#13788) * feat(api-nodes): add Tripo3D 3.1 model Signed-off-by: bigcat88 * fix: price badges algo Signed-off-by: bigcat88 * [Partner Nodes] deprecate "quad" param for the TripoMultiviewToModel node Signed-off-by: bigcat88 --------- Signed-off-by: bigcat88 --- comfy_api_nodes/apis/tripo.py | 30 ++++-------- comfy_api_nodes/nodes_tripo.py | 84 +++++++++++----------------------- 2 files changed, 36 insertions(+), 78 deletions(-) diff --git a/comfy_api_nodes/apis/tripo.py b/comfy_api_nodes/apis/tripo.py index ffaaa7dc1..bce6b0e89 100644 --- a/comfy_api_nodes/apis/tripo.py +++ b/comfy_api_nodes/apis/tripo.py @@ -1,10 +1,11 @@ -from __future__ import annotations from enum import Enum -from typing import Optional, List, Dict, Any, Union +from typing import Optional, Any from pydantic import BaseModel, Field, RootModel + class TripoModelVersion(str, Enum): + v3_1_20260211 = 'v3.1-20260211' v3_0_20250812 = 'v3.0-20250812' v2_5_20250123 = 'v2.5-20250123' v2_0_20240919 = 'v2.0-20240919' @@ -142,7 +143,7 @@ class TripoFileEmptyReference(BaseModel): pass class TripoFileReference(RootModel): - root: Union[TripoFileTokenReference, TripoUrlReference, TripoObjectReference, TripoFileEmptyReference] + root: TripoFileTokenReference | TripoUrlReference | TripoObjectReference | TripoFileEmptyReference class TripoGetStsTokenRequest(BaseModel): format: str = Field(..., description='The format of the image') @@ -183,7 +184,7 @@ class TripoImageToModelRequest(BaseModel): class TripoMultiviewToModelRequest(BaseModel): type: TripoTaskType = TripoTaskType.MULTIVIEW_TO_MODEL - files: List[TripoFileReference] = Field(..., description='The file references to convert to a model') + files: list[TripoFileReference] = Field(..., description='The file references to convert to a model') model_version: Optional[TripoModelVersion] = Field(None, description='The model version to use for generation') orthographic_projection: Optional[bool] = Field(False, description='Whether to use orthographic projection') face_limit: Optional[int] = Field(None, description='The number of faces to limit the generation to') @@ -251,27 +252,13 @@ class TripoConvertModelRequest(BaseModel): with_animation: Optional[bool] = Field(None, description='Whether to include animations') pack_uv: Optional[bool] = Field(None, description='Whether to pack the UVs') bake: Optional[bool] = Field(None, description='Whether to bake the model') - part_names: Optional[List[str]] = Field(None, description='The names of the parts to include') + part_names: Optional[list[str]] = Field(None, description='The names of the parts to include') fbx_preset: Optional[TripoFbxPreset] = Field(None, description='The preset for the FBX export') export_vertex_colors: Optional[bool] = Field(None, description='Whether to export the vertex colors') export_orientation: Optional[TripoOrientation] = Field(None, description='The orientation for the export') animate_in_place: Optional[bool] = Field(None, description='Whether to animate in place') -class TripoTaskRequest(RootModel): - root: Union[ - TripoTextToModelRequest, - TripoImageToModelRequest, - TripoMultiviewToModelRequest, - TripoTextureModelRequest, - TripoRefineModelRequest, - TripoAnimatePrerigcheckRequest, - TripoAnimateRigRequest, - 
TripoAnimateRetargetRequest, - TripoStylizeModelRequest, - TripoConvertModelRequest - ] - class TripoTaskOutput(BaseModel): model: Optional[str] = Field(None, description='URL to the model') base_model: Optional[str] = Field(None, description='URL to the base model') @@ -283,12 +270,13 @@ class TripoTask(BaseModel): task_id: str = Field(..., description='The task ID') type: Optional[str] = Field(None, description='The type of task') status: Optional[TripoTaskStatus] = Field(None, description='The status of the task') - input: Optional[Dict[str, Any]] = Field(None, description='The input parameters for the task') + input: Optional[dict[str, Any]] = Field(None, description='The input parameters for the task') output: Optional[TripoTaskOutput] = Field(None, description='The output of the task') progress: Optional[int] = Field(None, description='The progress of the task', ge=0, le=100) create_time: Optional[int] = Field(None, description='The creation time of the task') running_left_time: Optional[int] = Field(None, description='The estimated time left for the task') queue_position: Optional[int] = Field(None, description='The position in the queue') + consumed_credit: int | None = Field(None) class TripoTaskResponse(BaseModel): code: int = Field(0, description='The response code') @@ -296,7 +284,7 @@ class TripoTaskResponse(BaseModel): class TripoGeneralResponse(BaseModel): code: int = Field(0, description='The response code') - data: Dict[str, str] = Field(..., description='The task ID data') + data: dict[str, str] = Field(..., description='The task ID data') class TripoBalanceData(BaseModel): balance: float = Field(..., description='The account balance') diff --git a/comfy_api_nodes/nodes_tripo.py b/comfy_api_nodes/nodes_tripo.py index 9f4298dce..d6501dee4 100644 --- a/comfy_api_nodes/nodes_tripo.py +++ b/comfy_api_nodes/nodes_tripo.py @@ -60,6 +60,7 @@ async def poll_until_finished( ], status_extractor=lambda x: x.data.status, progress_extractor=lambda x: x.data.progress, + price_extractor=lambda x: x.data.consumed_credit * 0.01 if x.data.consumed_credit else None, estimated_duration=average_duration, ) if response_poll.data.status == TripoTaskStatus.SUCCESS: @@ -113,7 +114,6 @@ class TripoTextToModelNode(IO.ComfyNode): depends_on=IO.PriceBadgeDepends( widgets=[ "model_version", - "style", "texture", "pbr", "quad", @@ -124,20 +124,17 @@ class TripoTextToModelNode(IO.ComfyNode): expr=""" ( $isV14 := $contains(widgets.model_version,"v1.4"); - $style := widgets.style; - $hasStyle := ($style != "" and $style != "none"); + $isV3OrLater := $contains(widgets.model_version,"v3."); $withTexture := widgets.texture or widgets.pbr; $isHdTexture := (widgets.texture_quality = "detailed"); $isDetailedGeometry := (widgets.geometry_quality = "detailed"); - $baseCredits := - $isV14 ? 20 : ($withTexture ? 20 : 10); - $credits := - $baseCredits - + ($hasStyle ? 5 : 0) + $credits := $isV14 ? 20 : ( + ($withTexture ? 20 : 10) + (widgets.quad ? 5 : 0) + ($isHdTexture ? 10 : 0) - + ($isDetailedGeometry ? 20 : 0); - {"type":"usd","usd": $round($credits * 0.01, 2)} + + (($isDetailedGeometry and $isV3OrLater) ? 
20 : 0) + ); + {"type":"usd","usd": $round($credits * 0.01, 2), "format": {"approximate": true}} ) """, ), @@ -239,7 +236,6 @@ class TripoImageToModelNode(IO.ComfyNode): depends_on=IO.PriceBadgeDepends( widgets=[ "model_version", - "style", "texture", "pbr", "quad", @@ -250,20 +246,17 @@ class TripoImageToModelNode(IO.ComfyNode): expr=""" ( $isV14 := $contains(widgets.model_version,"v1.4"); - $style := widgets.style; - $hasStyle := ($style != "" and $style != "none"); + $isV3OrLater := $contains(widgets.model_version,"v3."); $withTexture := widgets.texture or widgets.pbr; $isHdTexture := (widgets.texture_quality = "detailed"); $isDetailedGeometry := (widgets.geometry_quality = "detailed"); - $baseCredits := - $isV14 ? 30 : ($withTexture ? 30 : 20); - $credits := - $baseCredits - + ($hasStyle ? 5 : 0) + $credits := $isV14 ? 30 : ( + ($withTexture ? 30 : 20) + (widgets.quad ? 5 : 0) + ($isHdTexture ? 10 : 0) - + ($isDetailedGeometry ? 20 : 0); - {"type":"usd","usd": $round($credits * 0.01, 2)} + + (($isDetailedGeometry and $isV3OrLater) ? 20 : 0) + ); + {"type":"usd","usd": $round($credits * 0.01, 2), "format": {"approximate": true}} ) """, ), @@ -358,7 +351,7 @@ class TripoMultiviewToModelNode(IO.ComfyNode): "texture_alignment", default="original_image", options=["original_image", "geometry"], optional=True, advanced=True ), IO.Int.Input("face_limit", default=-1, min=-1, max=500000, optional=True, advanced=True), - IO.Boolean.Input("quad", default=False, optional=True, advanced=True), + IO.Boolean.Input("quad", default=False, optional=True, advanced=True, tooltip="This parameter is deprecated and does nothing."), IO.Combo.Input("geometry_quality", default="standard", options=["standard", "detailed"], optional=True, advanced=True), ], outputs=[ @@ -379,7 +372,6 @@ class TripoMultiviewToModelNode(IO.ComfyNode): "model_version", "texture", "pbr", - "quad", "texture_quality", "geometry_quality", ], @@ -387,17 +379,16 @@ class TripoMultiviewToModelNode(IO.ComfyNode): expr=""" ( $isV14 := $contains(widgets.model_version,"v1.4"); + $isV3OrLater := $contains(widgets.model_version,"v3."); $withTexture := widgets.texture or widgets.pbr; $isHdTexture := (widgets.texture_quality = "detailed"); $isDetailedGeometry := (widgets.geometry_quality = "detailed"); - $baseCredits := - $isV14 ? 30 : ($withTexture ? 30 : 20); - $credits := - $baseCredits - + (widgets.quad ? 5 : 0) + $credits := $isV14 ? 30 : ( + ($withTexture ? 30 : 20) + ($isHdTexture ? 10 : 0) - + ($isDetailedGeometry ? 20 : 0); - {"type":"usd","usd": $round($credits * 0.01, 2)} + + (($isDetailedGeometry and $isV3OrLater) ? 20 : 0) + ); + {"type":"usd","usd": $round($credits * 0.01, 2), "format": {"approximate": true}} ) """, ), @@ -457,7 +448,7 @@ class TripoMultiviewToModelNode(IO.ComfyNode): geometry_quality=geometry_quality, texture_alignment=texture_alignment, face_limit=face_limit if face_limit != -1 else None, - quad=quad, + quad=None, ), ) return await poll_until_finished(cls, response, average_duration=80) @@ -498,7 +489,7 @@ class TripoTextureNode(IO.ComfyNode): expr=""" ( $tq := widgets.texture_quality; - {"type":"usd","usd": ($contains($tq,"detailed") ? 0.2 : 0.1)} + {"type":"usd","usd": ($contains($tq,"detailed") ? 
0.2 : 0.1), "format": {"approximate": true}} ) """, ), @@ -555,7 +546,7 @@ class TripoRefineNode(IO.ComfyNode): is_api_node=True, is_output_node=True, price_badge=IO.PriceBadge( - expr="""{"type":"usd","usd":0.3}""", + expr="""{"type":"usd","usd":0.3, "format": {"approximate": true}}""", ), ) @@ -592,7 +583,7 @@ class TripoRigNode(IO.ComfyNode): is_api_node=True, is_output_node=True, price_badge=IO.PriceBadge( - expr="""{"type":"usd","usd":0.25}""", + expr="""{"type":"usd","usd":0.25, "format": {"approximate": true}}""", ), ) @@ -652,7 +643,7 @@ class TripoRetargetNode(IO.ComfyNode): is_api_node=True, is_output_node=True, price_badge=IO.PriceBadge( - expr="""{"type":"usd","usd":0.1}""", + expr="""{"type":"usd","usd":0.1, "format": {"approximate": true}}""", ), ) @@ -761,19 +752,10 @@ class TripoConversionNode(IO.ComfyNode): "face_limit", "texture_size", "texture_format", - "force_symmetry", "flatten_bottom", "flatten_bottom_threshold", "pivot_to_center_bottom", "scale_factor", - "with_animation", - "pack_uv", - "bake", - "part_names", - "fbx_preset", - "export_vertex_colors", - "export_orientation", - "animate_in_place", ], ), expr=""" @@ -783,28 +765,16 @@ class TripoConversionNode(IO.ComfyNode): $flatThresh := (widgets.flatten_bottom_threshold != null) ? widgets.flatten_bottom_threshold : 0; $scale := (widgets.scale_factor != null) ? widgets.scale_factor : 1; $texFmt := (widgets.texture_format != "" ? widgets.texture_format : "jpeg"); - $part := widgets.part_names; - $fbx := (widgets.fbx_preset != "" ? widgets.fbx_preset : "blender"); - $orient := (widgets.export_orientation != "" ? widgets.export_orientation : "default"); $advanced := widgets.quad or - widgets.force_symmetry or widgets.flatten_bottom or widgets.pivot_to_center_bottom or - widgets.with_animation or - widgets.pack_uv or - widgets.bake or - widgets.export_vertex_colors or - widgets.animate_in_place or ($face != -1) or ($texSize != 4096) or ($flatThresh != 0) or ($scale != 1) or - ($texFmt != "jpeg") or - ($part != "") or - ($fbx != "blender") or - ($orient != "default"); - {"type":"usd","usd": ($advanced ? 0.1 : 0.05)} + ($texFmt != "jpeg"); + {"type":"usd","usd": ($advanced ? 0.1 : 0.05), "format": {"approximate": true}} ) """, ),