From 2640acb31ccfddee57ba22d5245bf456e8dffe53 Mon Sep 17 00:00:00 2001
From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com>
Date: Mon, 1 Dec 2025 14:13:48 -0800
Subject: [PATCH 01/81] Update qwen tokenizer to add qwen 3 tokens. (#11029)
Doesn't actually change anything for current workflows because none of the
current models have a template with the think tokens.
---
.../qwen25_tokenizer/tokenizer_config.json | 16 ++++++++--------
1 file changed, 8 insertions(+), 8 deletions(-)
diff --git a/comfy/text_encoders/qwen25_tokenizer/tokenizer_config.json b/comfy/text_encoders/qwen25_tokenizer/tokenizer_config.json
index 67688e82c..df5b5d7fe 100644
--- a/comfy/text_encoders/qwen25_tokenizer/tokenizer_config.json
+++ b/comfy/text_encoders/qwen25_tokenizer/tokenizer_config.json
@@ -179,36 +179,36 @@
"special": false
},
"151665": {
- "content": "<|img|>",
+ "content": "",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
- "special": true
+ "special": false
},
"151666": {
- "content": "<|endofimg|>",
+ "content": "",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
- "special": true
+ "special": false
},
"151667": {
- "content": "<|meta|>",
+ "content": "",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
- "special": true
+ "special": false
},
"151668": {
- "content": "<|endofmeta|>",
+ "content": "",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
- "special": true
+ "special": false
}
},
"additional_special_tokens": [
From 1cb7e22a95701f2619d1ddf5683ea221b58a0c13 Mon Sep 17 00:00:00 2001
From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com>
Date: Tue, 2 Dec 2025 02:11:52 +0200
Subject: [PATCH 02/81] [API Nodes] add Kling O1 model support (#11025)
* feat(api-nodes): add Kling O1 model support
* fix: increase max allowed duration to 10.05 seconds
* fix(VideoInput): respect "format" argument
---
comfy_api/latest/_input_impl/video_types.py | 5 +-
comfy_api_nodes/apis/kling_api.py | 66 +++
comfy_api_nodes/nodes_kling.py | 444 +++++++++++++++++++-
comfy_api_nodes/util/upload_helpers.py | 3 +-
4 files changed, 499 insertions(+), 19 deletions(-)
create mode 100644 comfy_api_nodes/apis/kling_api.py
diff --git a/comfy_api/latest/_input_impl/video_types.py b/comfy_api/latest/_input_impl/video_types.py
index bde37f90a..7231bf13c 100644
--- a/comfy_api/latest/_input_impl/video_types.py
+++ b/comfy_api/latest/_input_impl/video_types.py
@@ -336,7 +336,10 @@ class VideoFromComponents(VideoInput):
raise ValueError("Only MP4 format is supported for now")
if codec != VideoCodec.AUTO and codec != VideoCodec.H264:
raise ValueError("Only H264 codec is supported for now")
- with av.open(path, mode='w', options={'movflags': 'use_metadata_tags'}) as output:
+ extra_kwargs = {}
+ if format != VideoContainer.AUTO:
+ extra_kwargs["format"] = format.value
+ with av.open(path, mode='w', options={'movflags': 'use_metadata_tags'}, **extra_kwargs) as output:
# Add metadata before writing any streams
if metadata is not None:
for key, value in metadata.items():
diff --git a/comfy_api_nodes/apis/kling_api.py b/comfy_api_nodes/apis/kling_api.py
new file mode 100644
index 000000000..0a3b447c5
--- /dev/null
+++ b/comfy_api_nodes/apis/kling_api.py
@@ -0,0 +1,66 @@
+from pydantic import BaseModel, Field
+
+
+class OmniProText2VideoRequest(BaseModel):
+ model_name: str = Field(..., description="kling-video-o1")
+ aspect_ratio: str = Field(..., description="'16:9', '9:16' or '1:1'")
+ duration: str = Field(..., description="'5' or '10'")
+ prompt: str = Field(...)
+ mode: str = Field("pro")
+
+
+class OmniParamImage(BaseModel):
+ image_url: str = Field(...)
+ type: str | None = Field(None, description="Can be 'first_frame' or 'end_frame'")
+
+
+class OmniParamVideo(BaseModel):
+ video_url: str = Field(...)
+ refer_type: str | None = Field(..., description="Can be 'base' or 'feature'")
+ keep_original_sound: str = Field(..., description="'yes' or 'no'")
+
+
+class OmniProFirstLastFrameRequest(BaseModel):
+ model_name: str = Field(..., description="kling-video-o1")
+ image_list: list[OmniParamImage] = Field(..., min_length=1, max_length=7)
+ duration: str = Field(..., description="'5' or '10'")
+ prompt: str = Field(...)
+ mode: str = Field("pro")
+
+
+class OmniProReferences2VideoRequest(BaseModel):
+ model_name: str = Field(..., description="kling-video-o1")
+ aspect_ratio: str | None = Field(..., description="'16:9', '9:16' or '1:1'")
+ image_list: list[OmniParamImage] | None = Field(
+ None, max_length=7, description="Max length 4 when video is present."
+ )
+ video_list: list[OmniParamVideo] | None = Field(None, max_length=1)
+ duration: str | None = Field(..., description="From 3 to 10.")
+ prompt: str = Field(...)
+ mode: str = Field("pro")
+
+
+class TaskStatusVideoResult(BaseModel):
+ duration: str | None = Field(None, description="Total video duration")
+ id: str | None = Field(None, description="Generated video ID")
+ url: str | None = Field(None, description="URL for generated video")
+
+
+class TaskStatusVideoResults(BaseModel):
+ videos: list[TaskStatusVideoResult] | None = Field(None)
+
+
+class TaskStatusVideoResponseData(BaseModel):
+ created_at: int | None = Field(None, description="Task creation time")
+ updated_at: int | None = Field(None, description="Task update time")
+ task_status: str | None = None
+ task_status_msg: str | None = Field(None, description="Additional failure reason. Only for polling endpoint.")
+ task_id: str | None = Field(None, description="Task ID")
+ task_result: TaskStatusVideoResults | None = Field(None)
+
+
+class TaskStatusVideoResponse(BaseModel):
+ code: int | None = Field(None, description="Error code")
+ message: str | None = Field(None, description="Error message")
+ request_id: str | None = Field(None, description="Request ID")
+ data: TaskStatusVideoResponseData | None = Field(None)
diff --git a/comfy_api_nodes/nodes_kling.py b/comfy_api_nodes/nodes_kling.py
index 23a7f55f1..850c44db6 100644
--- a/comfy_api_nodes/nodes_kling.py
+++ b/comfy_api_nodes/nodes_kling.py
@@ -4,13 +4,13 @@ For source of truth on the allowed permutations of request fields, please refere
- [Compatibility Table](https://app.klingai.com/global/dev/document-api/apiReference/model/skillsMap)
"""
-import math
import logging
-
-from typing_extensions import override
+import math
import torch
+from typing_extensions import override
+from comfy_api.latest import IO, ComfyExtension, Input, InputImpl
from comfy_api_nodes.apis import (
KlingCameraControl,
KlingCameraConfig,
@@ -48,23 +48,31 @@ from comfy_api_nodes.apis import (
KlingCharacterEffectModelName,
KlingSingleImageEffectModelName,
)
+from comfy_api_nodes.apis.kling_api import (
+ OmniParamImage,
+ OmniParamVideo,
+ OmniProFirstLastFrameRequest,
+ OmniProReferences2VideoRequest,
+ OmniProText2VideoRequest,
+ TaskStatusVideoResponse,
+)
from comfy_api_nodes.util import (
- validate_image_dimensions,
+ ApiEndpoint,
+ download_url_to_image_tensor,
+ download_url_to_video_output,
+ get_number_of_images,
+ poll_op,
+ sync_op,
+ tensor_to_base64_string,
+ upload_audio_to_comfyapi,
+ upload_images_to_comfyapi,
+ upload_video_to_comfyapi,
validate_image_aspect_ratio,
+ validate_image_dimensions,
+ validate_string,
validate_video_dimensions,
validate_video_duration,
- tensor_to_base64_string,
- validate_string,
- upload_audio_to_comfyapi,
- download_url_to_image_tensor,
- upload_video_to_comfyapi,
- download_url_to_video_output,
- sync_op,
- ApiEndpoint,
- poll_op,
)
-from comfy_api.input_impl import VideoFromFile
-from comfy_api.latest import ComfyExtension, IO, Input
KLING_API_VERSION = "v1"
PATH_TEXT_TO_VIDEO = f"/proxy/kling/{KLING_API_VERSION}/videos/text2video"
@@ -202,6 +210,20 @@ VOICES_CONFIG = {
}
+async def finish_omni_video_task(cls: type[IO.ComfyNode], response: TaskStatusVideoResponse) -> IO.NodeOutput:
+ if response.code:
+ raise RuntimeError(
+ f"Kling request failed. Code: {response.code}, Message: {response.message}, Data: {response.data}"
+ )
+ final_response = await poll_op(
+ cls,
+ ApiEndpoint(path=f"/proxy/kling/v1/videos/omni-video/{response.data.task_id}"),
+ response_model=TaskStatusVideoResponse,
+ status_extractor=lambda r: (r.data.task_status if r.data else None),
+ )
+ return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url))
+
+
def is_valid_camera_control_configs(configs: list[float]) -> bool:
"""Verifies that at least one camera control configuration is non-zero."""
return any(not math.isclose(value, 0.0) for value in configs)
@@ -449,7 +471,7 @@ async def execute_video_effect(
image_1: torch.Tensor,
image_2: torch.Tensor | None = None,
model_mode: KlingVideoGenMode | None = None,
-) -> tuple[VideoFromFile, str, str]:
+) -> tuple[InputImpl.VideoFromFile, str, str]:
if dual_character:
request_input_field = KlingDualCharacterEffectInput(
model_name=model_name,
@@ -736,6 +758,386 @@ class KlingTextToVideoNode(IO.ComfyNode):
)
+class OmniProTextToVideoNode(IO.ComfyNode):
+
+ @classmethod
+ def define_schema(cls) -> IO.Schema:
+ return IO.Schema(
+ node_id="KlingOmniProTextToVideoNode",
+ display_name="Kling Omni Text to Video (Pro)",
+ category="api node/video/Kling",
+ description="Use text prompts to generate videos with the latest Kling model.",
+ inputs=[
+ IO.Combo.Input("model_name", options=["kling-video-o1"]),
+ IO.String.Input(
+ "prompt",
+ multiline=True,
+ tooltip="A text prompt describing the video content. "
+ "This can include both positive and negative descriptions.",
+ ),
+ IO.Combo.Input("aspect_ratio", options=["16:9", "9:16", "1:1"]),
+ IO.Combo.Input("duration", options=[5, 10]),
+ ],
+ outputs=[
+ IO.Video.Output(),
+ ],
+ hidden=[
+ IO.Hidden.auth_token_comfy_org,
+ IO.Hidden.api_key_comfy_org,
+ IO.Hidden.unique_id,
+ ],
+ is_api_node=True,
+ )
+
+ @classmethod
+ async def execute(
+ cls,
+ model_name: str,
+ prompt: str,
+ aspect_ratio: str,
+ duration: int,
+ ) -> IO.NodeOutput:
+ validate_string(prompt, min_length=1, max_length=2500)
+ response = await sync_op(
+ cls,
+ ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"),
+ response_model=TaskStatusVideoResponse,
+ data=OmniProText2VideoRequest(
+ model_name=model_name,
+ prompt=prompt,
+ aspect_ratio=aspect_ratio,
+ duration=str(duration),
+ ),
+ )
+ return await finish_omni_video_task(cls, response)
+
+
+class OmniProFirstLastFrameNode(IO.ComfyNode):
+
+ @classmethod
+ def define_schema(cls) -> IO.Schema:
+ return IO.Schema(
+ node_id="KlingOmniProFirstLastFrameNode",
+ display_name="Kling Omni First-Last-Frame to Video (Pro)",
+ category="api node/video/Kling",
+ description="Use a start frame, an optional end frame, or reference images with the latest Kling model.",
+ inputs=[
+ IO.Combo.Input("model_name", options=["kling-video-o1"]),
+ IO.String.Input(
+ "prompt",
+ multiline=True,
+ tooltip="A text prompt describing the video content. "
+ "This can include both positive and negative descriptions.",
+ ),
+ IO.Combo.Input("duration", options=["5", "10"]),
+ IO.Image.Input("first_frame"),
+ IO.Image.Input(
+ "end_frame",
+ optional=True,
+ tooltip="An optional end frame for the video. "
+ "This cannot be used simultaneously with 'reference_images'.",
+ ),
+ IO.Image.Input(
+ "reference_images",
+ optional=True,
+ tooltip="Up to 6 additional reference images.",
+ ),
+ ],
+ outputs=[
+ IO.Video.Output(),
+ ],
+ hidden=[
+ IO.Hidden.auth_token_comfy_org,
+ IO.Hidden.api_key_comfy_org,
+ IO.Hidden.unique_id,
+ ],
+ is_api_node=True,
+ )
+
+ @classmethod
+ async def execute(
+ cls,
+ model_name: str,
+ prompt: str,
+ duration: int,
+ first_frame: Input.Image,
+ end_frame: Input.Image | None = None,
+ reference_images: Input.Image | None = None,
+ ) -> IO.NodeOutput:
+ validate_string(prompt, min_length=1, max_length=2500)
+ if end_frame is not None and reference_images is not None:
+ raise ValueError("The 'end_frame' input cannot be used simultaneously with 'reference_images'.")
+ validate_image_dimensions(first_frame, min_width=300, min_height=300)
+ validate_image_aspect_ratio(first_frame, (1, 2.5), (2.5, 1))
+ image_list: list[OmniParamImage] = [
+ OmniParamImage(
+ image_url=(await upload_images_to_comfyapi(cls, first_frame, wait_label="Uploading first frame"))[0],
+ type="first_frame",
+ )
+ ]
+ if end_frame is not None:
+ validate_image_dimensions(end_frame, min_width=300, min_height=300)
+ validate_image_aspect_ratio(end_frame, (1, 2.5), (2.5, 1))
+ image_list.append(
+ OmniParamImage(
+ image_url=(await upload_images_to_comfyapi(cls, end_frame, wait_label="Uploading end frame"))[0],
+ type="end_frame",
+ )
+ )
+ if reference_images is not None:
+ if get_number_of_images(reference_images) > 6:
+ raise ValueError("The maximum number of reference images allowed is 6.")
+ for i in reference_images:
+ validate_image_dimensions(i, min_width=300, min_height=300)
+ validate_image_aspect_ratio(i, (1, 2.5), (2.5, 1))
+ for i in await upload_images_to_comfyapi(cls, reference_images, wait_label="Uploading reference frame(s)"):
+ image_list.append(OmniParamImage(image_url=i))
+ response = await sync_op(
+ cls,
+ ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"),
+ response_model=TaskStatusVideoResponse,
+ data=OmniProFirstLastFrameRequest(
+ model_name=model_name,
+ prompt=prompt,
+ duration=str(duration),
+ image_list=image_list,
+ ),
+ )
+ return await finish_omni_video_task(cls, response)
+
+
+class OmniProImageToVideoNode(IO.ComfyNode):
+
+ @classmethod
+ def define_schema(cls) -> IO.Schema:
+ return IO.Schema(
+ node_id="KlingOmniProImageToVideoNode",
+ display_name="Kling Omni Image to Video (Pro)",
+ category="api node/video/Kling",
+ description="Use up to 7 reference images to generate a video with the latest Kling model.",
+ inputs=[
+ IO.Combo.Input("model_name", options=["kling-video-o1"]),
+ IO.String.Input(
+ "prompt",
+ multiline=True,
+ tooltip="A text prompt describing the video content. "
+ "This can include both positive and negative descriptions.",
+ ),
+ IO.Combo.Input("aspect_ratio", options=["16:9", "9:16", "1:1"]),
+ IO.Int.Input("duration", default=3, min=3, max=10, display_mode=IO.NumberDisplay.slider),
+ IO.Image.Input(
+ "reference_images",
+ tooltip="Up to 7 reference images.",
+ ),
+ ],
+ outputs=[
+ IO.Video.Output(),
+ ],
+ hidden=[
+ IO.Hidden.auth_token_comfy_org,
+ IO.Hidden.api_key_comfy_org,
+ IO.Hidden.unique_id,
+ ],
+ is_api_node=True,
+ )
+
+ @classmethod
+ async def execute(
+ cls,
+ model_name: str,
+ prompt: str,
+ aspect_ratio: str,
+ duration: int,
+ reference_images: Input.Image,
+ ) -> IO.NodeOutput:
+ validate_string(prompt, min_length=1, max_length=2500)
+ if get_number_of_images(reference_images) > 7:
+ raise ValueError("The maximum number of reference images is 7.")
+ for i in reference_images:
+ validate_image_dimensions(i, min_width=300, min_height=300)
+ validate_image_aspect_ratio(i, (1, 2.5), (2.5, 1))
+ image_list: list[OmniParamImage] = []
+ for i in await upload_images_to_comfyapi(cls, reference_images, wait_label="Uploading reference image"):
+ image_list.append(OmniParamImage(image_url=i))
+ response = await sync_op(
+ cls,
+ ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"),
+ response_model=TaskStatusVideoResponse,
+ data=OmniProReferences2VideoRequest(
+ model_name=model_name,
+ prompt=prompt,
+ aspect_ratio=aspect_ratio,
+ duration=str(duration),
+ image_list=image_list,
+ ),
+ )
+ return await finish_omni_video_task(cls, response)
+
+
+class OmniProVideoToVideoNode(IO.ComfyNode):
+
+ @classmethod
+ def define_schema(cls) -> IO.Schema:
+ return IO.Schema(
+ node_id="KlingOmniProVideoToVideoNode",
+ display_name="Kling Omni Video to Video (Pro)",
+ category="api node/video/Kling",
+ description="Use a video and up to 4 reference images to generate a video with the latest Kling model.",
+ inputs=[
+ IO.Combo.Input("model_name", options=["kling-video-o1"]),
+ IO.String.Input(
+ "prompt",
+ multiline=True,
+ tooltip="A text prompt describing the video content. "
+ "This can include both positive and negative descriptions.",
+ ),
+ IO.Combo.Input("aspect_ratio", options=["16:9", "9:16", "1:1"]),
+ IO.Int.Input("duration", default=3, min=3, max=10, display_mode=IO.NumberDisplay.slider),
+ IO.Video.Input("reference_video", tooltip="Video to use as a reference."),
+ IO.Boolean.Input("keep_original_sound", default=True),
+ IO.Image.Input(
+ "reference_images",
+ tooltip="Up to 4 additional reference images.",
+ optional=True,
+ ),
+ ],
+ outputs=[
+ IO.Video.Output(),
+ ],
+ hidden=[
+ IO.Hidden.auth_token_comfy_org,
+ IO.Hidden.api_key_comfy_org,
+ IO.Hidden.unique_id,
+ ],
+ is_api_node=True,
+ )
+
+ @classmethod
+ async def execute(
+ cls,
+ model_name: str,
+ prompt: str,
+ aspect_ratio: str,
+ duration: int,
+ reference_video: Input.Video,
+ keep_original_sound: bool,
+ reference_images: Input.Image | None = None,
+ ) -> IO.NodeOutput:
+ validate_string(prompt, min_length=1, max_length=2500)
+ validate_video_duration(reference_video, min_duration=3.0, max_duration=10.05)
+ validate_video_dimensions(reference_video, min_width=720, min_height=720, max_width=2160, max_height=2160)
+ image_list: list[OmniParamImage] = []
+ if reference_images is not None:
+ if get_number_of_images(reference_images) > 4:
+ raise ValueError("The maximum number of reference images allowed with a video input is 4.")
+ for i in reference_images:
+ validate_image_dimensions(i, min_width=300, min_height=300)
+ validate_image_aspect_ratio(i, (1, 2.5), (2.5, 1))
+ for i in await upload_images_to_comfyapi(cls, reference_images, wait_label="Uploading reference image"):
+ image_list.append(OmniParamImage(image_url=i))
+ video_list = [
+ OmniParamVideo(
+ video_url=await upload_video_to_comfyapi(cls, reference_video, wait_label="Uploading reference video"),
+ refer_type="feature",
+ keep_original_sound="yes" if keep_original_sound else "no",
+ )
+ ]
+ response = await sync_op(
+ cls,
+ ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"),
+ response_model=TaskStatusVideoResponse,
+ data=OmniProReferences2VideoRequest(
+ model_name=model_name,
+ prompt=prompt,
+ aspect_ratio=aspect_ratio,
+ duration=str(duration),
+ image_list=image_list if image_list else None,
+ video_list=video_list,
+ ),
+ )
+ return await finish_omni_video_task(cls, response)
+
+
+class OmniProEditVideoNode(IO.ComfyNode):
+
+ @classmethod
+ def define_schema(cls) -> IO.Schema:
+ return IO.Schema(
+ node_id="KlingOmniProEditVideoNode",
+ display_name="Kling Omni Edit Video (Pro)",
+ category="api node/video/Kling",
+ description="Edit an existing video with the latest model from Kling.",
+ inputs=[
+ IO.Combo.Input("model_name", options=["kling-video-o1"]),
+ IO.String.Input(
+ "prompt",
+ multiline=True,
+ tooltip="A text prompt describing the video content. "
+ "This can include both positive and negative descriptions.",
+ ),
+ IO.Video.Input("video", tooltip="Video for editing. The output video length will be the same."),
+ IO.Boolean.Input("keep_original_sound", default=True),
+ IO.Image.Input(
+ "reference_images",
+ tooltip="Up to 4 additional reference images.",
+ optional=True,
+ ),
+ ],
+ outputs=[
+ IO.Video.Output(),
+ ],
+ hidden=[
+ IO.Hidden.auth_token_comfy_org,
+ IO.Hidden.api_key_comfy_org,
+ IO.Hidden.unique_id,
+ ],
+ is_api_node=True,
+ )
+
+ @classmethod
+ async def execute(
+ cls,
+ model_name: str,
+ prompt: str,
+ video: Input.Video,
+ keep_original_sound: bool,
+ reference_images: Input.Image | None = None,
+ ) -> IO.NodeOutput:
+ validate_string(prompt, min_length=1, max_length=2500)
+ validate_video_duration(video, min_duration=3.0, max_duration=10.05)
+ validate_video_dimensions(video, min_width=720, min_height=720, max_width=2160, max_height=2160)
+ image_list: list[OmniParamImage] = []
+ if reference_images is not None:
+ if get_number_of_images(reference_images) > 4:
+ raise ValueError("The maximum number of reference images allowed with a video input is 4.")
+ for i in reference_images:
+ validate_image_dimensions(i, min_width=300, min_height=300)
+ validate_image_aspect_ratio(i, (1, 2.5), (2.5, 1))
+ for i in await upload_images_to_comfyapi(cls, reference_images, wait_label="Uploading reference image"):
+ image_list.append(OmniParamImage(image_url=i))
+ video_list = [
+ OmniParamVideo(
+ video_url=await upload_video_to_comfyapi(cls, video, wait_label="Uploading base video"),
+ refer_type="base",
+ keep_original_sound="yes" if keep_original_sound else "no",
+ )
+ ]
+ response = await sync_op(
+ cls,
+ ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"),
+ response_model=TaskStatusVideoResponse,
+ data=OmniProReferences2VideoRequest(
+ model_name=model_name,
+ prompt=prompt,
+ aspect_ratio=None,
+ duration=None,
+ image_list=image_list if image_list else None,
+ video_list=video_list,
+ ),
+ )
+ return await finish_omni_video_task(cls, response)
+
+
class KlingCameraControlT2VNode(IO.ComfyNode):
"""
Kling Text to Video Camera Control Node. This node is a text to video node, but it supports controlling the camera.
@@ -1162,7 +1564,10 @@ class KlingSingleImageVideoEffectNode(IO.ComfyNode):
category="api node/video/Kling",
description="Achieve different special effects when generating a video based on the effect_scene.",
inputs=[
- IO.Image.Input("image", tooltip=" Reference Image. URL or Base64 encoded string (without data:image prefix). File size cannot exceed 10MB, resolution not less than 300*300px, aspect ratio between 1:2.5 ~ 2.5:1"),
+ IO.Image.Input(
+ "image",
+ tooltip=" Reference Image. URL or Base64 encoded string (without data:image prefix). File size cannot exceed 10MB, resolution not less than 300*300px, aspect ratio between 1:2.5 ~ 2.5:1",
+ ),
IO.Combo.Input(
"effect_scene",
options=[i.value for i in KlingSingleImageEffectsScene],
@@ -1525,6 +1930,11 @@ class KlingExtension(ComfyExtension):
KlingImageGenerationNode,
KlingSingleImageVideoEffectNode,
KlingDualCharacterVideoEffectNode,
+ OmniProTextToVideoNode,
+ OmniProFirstLastFrameNode,
+ OmniProImageToVideoNode,
+ OmniProVideoToVideoNode,
+ OmniProEditVideoNode,
]
diff --git a/comfy_api_nodes/util/upload_helpers.py b/comfy_api_nodes/util/upload_helpers.py
index b9019841f..0532bea9a 100644
--- a/comfy_api_nodes/util/upload_helpers.py
+++ b/comfy_api_nodes/util/upload_helpers.py
@@ -103,6 +103,7 @@ async def upload_video_to_comfyapi(
container: VideoContainer = VideoContainer.MP4,
codec: VideoCodec = VideoCodec.H264,
max_duration: Optional[int] = None,
+ wait_label: str | None = "Uploading",
) -> str:
"""
Uploads a single video to ComfyUI API and returns its download URL.
@@ -127,7 +128,7 @@ async def upload_video_to_comfyapi(
video.save_to(video_bytes_io, format=container, codec=codec)
video_bytes_io.seek(0)
- return await upload_file_to_comfyapi(cls, video_bytes_io, filename, upload_mime_type)
+ return await upload_file_to_comfyapi(cls, video_bytes_io, filename, upload_mime_type, wait_label)
async def upload_file_to_comfyapi(
From 30c259cac8c08ff8d015f9aff3151cb525c9b702 Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Mon, 1 Dec 2025 20:25:35 -0500
Subject: [PATCH 03/81] ComfyUI version v0.3.76
---
comfyui_version.py | 2 +-
pyproject.toml | 2 +-
2 files changed, 2 insertions(+), 2 deletions(-)
diff --git a/comfyui_version.py b/comfyui_version.py
index fa4b4f4b0..4b039356e 100644
--- a/comfyui_version.py
+++ b/comfyui_version.py
@@ -1,3 +1,3 @@
# This file is automatically generated by the build process when version is
# updated in pyproject.toml.
-__version__ = "0.3.75"
+__version__ = "0.3.76"
diff --git a/pyproject.toml b/pyproject.toml
index 9009e65fe..02b94a0ce 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "ComfyUI"
-version = "0.3.75"
+version = "0.3.76"
readme = "README.md"
license = { file = "LICENSE" }
requires-python = ">=3.9"
From 878db3a727c1c6049bc1c4959cdfabc35eaf3d56 Mon Sep 17 00:00:00 2001
From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com>
Date: Mon, 1 Dec 2025 17:56:17 -0800
Subject: [PATCH 04/81] Implement the Ovis image model. (#11030)
---
comfy/ldm/chroma/model.py | 3 +-
comfy/ldm/flux/layers.py | 68 +++++++++++++++++++++--------------
comfy/ldm/flux/model.py | 21 ++++++++---
comfy/model_detection.py | 10 +++++-
comfy/sd.py | 13 +++++--
comfy/text_encoders/llama.py | 31 ++++++++++++++++
comfy/text_encoders/ovis.py | 69 ++++++++++++++++++++++++++++++++++++
nodes.py | 2 +-
8 files changed, 182 insertions(+), 35 deletions(-)
create mode 100644 comfy/text_encoders/ovis.py
diff --git a/comfy/ldm/chroma/model.py b/comfy/ldm/chroma/model.py
index a72f8cc47..2e8ef0687 100644
--- a/comfy/ldm/chroma/model.py
+++ b/comfy/ldm/chroma/model.py
@@ -40,7 +40,8 @@ class ChromaParams:
out_dim: int
hidden_dim: int
n_layers: int
-
+ txt_ids_dims: list
+ vec_in_dim: int
diff --git a/comfy/ldm/flux/layers.py b/comfy/ldm/flux/layers.py
index 2472ab79c..60f2bdae2 100644
--- a/comfy/ldm/flux/layers.py
+++ b/comfy/ldm/flux/layers.py
@@ -57,6 +57,35 @@ class MLPEmbedder(nn.Module):
def forward(self, x: Tensor) -> Tensor:
return self.out_layer(self.silu(self.in_layer(x)))
+class YakMLP(nn.Module):
+ def __init__(self, hidden_size: int, intermediate_size: int, dtype=None, device=None, operations=None):
+ super().__init__()
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.gate_proj = operations.Linear(self.hidden_size, self.intermediate_size, bias=True, dtype=dtype, device=device)
+ self.up_proj = operations.Linear(self.hidden_size, self.intermediate_size, bias=True, dtype=dtype, device=device)
+ self.down_proj = operations.Linear(self.intermediate_size, self.hidden_size, bias=True, dtype=dtype, device=device)
+ self.act_fn = nn.SiLU()
+
+ def forward(self, x: Tensor) -> Tensor:
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+ return down_proj
+
+def build_mlp(hidden_size, mlp_hidden_dim, mlp_silu_act=False, yak_mlp=False, dtype=None, device=None, operations=None):
+ if yak_mlp:
+ return YakMLP(hidden_size, mlp_hidden_dim, dtype=dtype, device=device, operations=operations)
+ if mlp_silu_act:
+ return nn.Sequential(
+ operations.Linear(hidden_size, mlp_hidden_dim * 2, bias=False, dtype=dtype, device=device),
+ SiLUActivation(),
+ operations.Linear(mlp_hidden_dim, hidden_size, bias=False, dtype=dtype, device=device),
+ )
+ else:
+ return nn.Sequential(
+ operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
+ nn.GELU(approximate="tanh"),
+ operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
+ )
class RMSNorm(torch.nn.Module):
def __init__(self, dim: int, dtype=None, device=None, operations=None):
@@ -140,7 +169,7 @@ class SiLUActivation(nn.Module):
class DoubleStreamBlock(nn.Module):
- def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, modulation=True, mlp_silu_act=False, proj_bias=True, dtype=None, device=None, operations=None):
+ def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, modulation=True, mlp_silu_act=False, proj_bias=True, yak_mlp=False, dtype=None, device=None, operations=None):
super().__init__()
mlp_hidden_dim = int(hidden_size * mlp_ratio)
@@ -156,18 +185,7 @@ class DoubleStreamBlock(nn.Module):
self.img_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
- if mlp_silu_act:
- self.img_mlp = nn.Sequential(
- operations.Linear(hidden_size, mlp_hidden_dim * 2, bias=False, dtype=dtype, device=device),
- SiLUActivation(),
- operations.Linear(mlp_hidden_dim, hidden_size, bias=False, dtype=dtype, device=device),
- )
- else:
- self.img_mlp = nn.Sequential(
- operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
- nn.GELU(approximate="tanh"),
- operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
- )
+ self.img_mlp = build_mlp(hidden_size, mlp_hidden_dim, mlp_silu_act=mlp_silu_act, yak_mlp=yak_mlp, dtype=dtype, device=device, operations=operations)
if self.modulation:
self.txt_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device, operations=operations)
@@ -177,18 +195,7 @@ class DoubleStreamBlock(nn.Module):
self.txt_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
- if mlp_silu_act:
- self.txt_mlp = nn.Sequential(
- operations.Linear(hidden_size, mlp_hidden_dim * 2, bias=False, dtype=dtype, device=device),
- SiLUActivation(),
- operations.Linear(mlp_hidden_dim, hidden_size, bias=False, dtype=dtype, device=device),
- )
- else:
- self.txt_mlp = nn.Sequential(
- operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
- nn.GELU(approximate="tanh"),
- operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
- )
+ self.txt_mlp = build_mlp(hidden_size, mlp_hidden_dim, mlp_silu_act=mlp_silu_act, yak_mlp=yak_mlp, dtype=dtype, device=device, operations=operations)
self.flipped_img_txt = flipped_img_txt
@@ -275,6 +282,7 @@ class SingleStreamBlock(nn.Module):
modulation=True,
mlp_silu_act=False,
bias=True,
+ yak_mlp=False,
dtype=None,
device=None,
operations=None
@@ -288,12 +296,17 @@ class SingleStreamBlock(nn.Module):
self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
self.mlp_hidden_dim_first = self.mlp_hidden_dim
+ self.yak_mlp = yak_mlp
if mlp_silu_act:
self.mlp_hidden_dim_first = int(hidden_size * mlp_ratio * 2)
self.mlp_act = SiLUActivation()
else:
self.mlp_act = nn.GELU(approximate="tanh")
+ if self.yak_mlp:
+ self.mlp_hidden_dim_first *= 2
+ self.mlp_act = nn.SiLU()
+
# qkv and mlp_in
self.linear1 = operations.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim_first, bias=bias, dtype=dtype, device=device)
# proj and mlp_out
@@ -325,7 +338,10 @@ class SingleStreamBlock(nn.Module):
attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
del q, k, v
# compute activation in mlp stream, cat again and run second linear layer
- mlp = self.mlp_act(mlp)
+ if self.yak_mlp:
+ mlp = self.mlp_act(mlp[..., self.mlp_hidden_dim_first // 2:]) * mlp[..., :self.mlp_hidden_dim_first // 2]
+ else:
+ mlp = self.mlp_act(mlp)
output = self.linear2(torch.cat((attn, mlp), 2))
x += apply_mod(output, mod.gate, None, modulation_dims)
if x.dtype == torch.float16:
diff --git a/comfy/ldm/flux/model.py b/comfy/ldm/flux/model.py
index d5674dea6..f40c2a7a9 100644
--- a/comfy/ldm/flux/model.py
+++ b/comfy/ldm/flux/model.py
@@ -15,7 +15,8 @@ from .layers import (
MLPEmbedder,
SingleStreamBlock,
timestep_embedding,
- Modulation
+ Modulation,
+ RMSNorm
)
@dataclass
@@ -34,11 +35,14 @@ class FluxParams:
patch_size: int
qkv_bias: bool
guidance_embed: bool
+ txt_ids_dims: list
global_modulation: bool = False
mlp_silu_act: bool = False
ops_bias: bool = True
default_ref_method: str = "offset"
ref_index_scale: float = 1.0
+ yak_mlp: bool = False
+ txt_norm: bool = False
class Flux(nn.Module):
@@ -76,6 +80,11 @@ class Flux(nn.Module):
)
self.txt_in = operations.Linear(params.context_in_dim, self.hidden_size, bias=params.ops_bias, dtype=dtype, device=device)
+ if params.txt_norm:
+ self.txt_norm = RMSNorm(params.context_in_dim, dtype=dtype, device=device, operations=operations)
+ else:
+ self.txt_norm = None
+
self.double_blocks = nn.ModuleList(
[
DoubleStreamBlock(
@@ -86,6 +95,7 @@ class Flux(nn.Module):
modulation=params.global_modulation is False,
mlp_silu_act=params.mlp_silu_act,
proj_bias=params.ops_bias,
+ yak_mlp=params.yak_mlp,
dtype=dtype, device=device, operations=operations
)
for _ in range(params.depth)
@@ -94,7 +104,7 @@ class Flux(nn.Module):
self.single_blocks = nn.ModuleList(
[
- SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, modulation=params.global_modulation is False, mlp_silu_act=params.mlp_silu_act, bias=params.ops_bias, dtype=dtype, device=device, operations=operations)
+ SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, modulation=params.global_modulation is False, mlp_silu_act=params.mlp_silu_act, bias=params.ops_bias, yak_mlp=params.yak_mlp, dtype=dtype, device=device, operations=operations)
for _ in range(params.depth_single_blocks)
]
)
@@ -150,6 +160,8 @@ class Flux(nn.Module):
y = torch.zeros((img.shape[0], self.params.vec_in_dim), device=img.device, dtype=img.dtype)
vec = vec + self.vector_in(y[:, :self.params.vec_in_dim])
+ if self.txt_norm is not None:
+ txt = self.txt_norm(txt)
txt = self.txt_in(txt)
vec_orig = vec
@@ -332,8 +344,9 @@ class Flux(nn.Module):
txt_ids = torch.zeros((bs, context.shape[1], len(self.params.axes_dim)), device=x.device, dtype=torch.float32)
- if len(self.params.axes_dim) == 4: # Flux 2
- txt_ids[:, :, 3] = torch.linspace(0, context.shape[1] - 1, steps=context.shape[1], device=x.device, dtype=torch.float32)
+ if len(self.params.txt_ids_dims) > 0:
+ for i in self.params.txt_ids_dims:
+ txt_ids[:, :, i] = torch.linspace(0, context.shape[1] - 1, steps=context.shape[1], device=x.device, dtype=torch.float32)
out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None))
out = out[:, :img_tokens]
diff --git a/comfy/model_detection.py b/comfy/model_detection.py
index 7afe4a798..7d0517e61 100644
--- a/comfy/model_detection.py
+++ b/comfy/model_detection.py
@@ -208,12 +208,12 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["theta"] = 2000
dit_config["out_channels"] = 128
dit_config["global_modulation"] = True
- dit_config["vec_in_dim"] = None
dit_config["mlp_silu_act"] = True
dit_config["qkv_bias"] = False
dit_config["ops_bias"] = False
dit_config["default_ref_method"] = "index"
dit_config["ref_index_scale"] = 10.0
+ dit_config["txt_ids_dims"] = [3]
patch_size = 1
else:
dit_config["image_model"] = "flux"
@@ -223,6 +223,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["theta"] = 10000
dit_config["out_channels"] = 16
dit_config["qkv_bias"] = True
+ dit_config["txt_ids_dims"] = []
patch_size = 2
dit_config["in_channels"] = 16
@@ -245,6 +246,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
vec_in_key = '{}vector_in.in_layer.weight'.format(key_prefix)
if vec_in_key in state_dict_keys:
dit_config["vec_in_dim"] = state_dict[vec_in_key].shape[1]
+ else:
+ dit_config["vec_in_dim"] = None
dit_config["depth"] = count_blocks(state_dict_keys, '{}double_blocks.'.format(key_prefix) + '{}.')
dit_config["depth_single_blocks"] = count_blocks(state_dict_keys, '{}single_blocks.'.format(key_prefix) + '{}.')
@@ -270,6 +273,11 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["nerf_embedder_dtype"] = torch.float32
else:
dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
+ dit_config["yak_mlp"] = '{}double_blocks.0.img_mlp.gate_proj.weight'.format(key_prefix) in state_dict_keys
+ dit_config["txt_norm"] = "{}txt_norm.scale".format(key_prefix) in state_dict_keys
+ if dit_config["yak_mlp"] and dit_config["txt_norm"]: # Ovis model
+ dit_config["txt_ids_dims"] = [1, 2]
+
return dit_config
if '{}t5_yproj.weight'.format(key_prefix) in state_dict_keys: #Genmo mochi preview
diff --git a/comfy/sd.py b/comfy/sd.py
index 9eeb0c45a..f9e5efab5 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -53,6 +53,7 @@ import comfy.text_encoders.omnigen2
import comfy.text_encoders.qwen_image
import comfy.text_encoders.hunyuan_image
import comfy.text_encoders.z_image
+import comfy.text_encoders.ovis
import comfy.model_patcher
import comfy.lora
@@ -956,6 +957,7 @@ class CLIPType(Enum):
QWEN_IMAGE = 18
HUNYUAN_IMAGE = 19
HUNYUAN_VIDEO_15 = 20
+ OVIS = 21
def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
@@ -987,6 +989,7 @@ class TEModel(Enum):
MISTRAL3_24B = 14
MISTRAL3_24B_PRUNED_FLUX2 = 15
QWEN3_4B = 16
+ QWEN3_2B = 17
def detect_te_model(sd):
@@ -1020,9 +1023,12 @@ def detect_te_model(sd):
if weight.shape[0] == 512:
return TEModel.QWEN25_7B
if "model.layers.0.post_attention_layernorm.weight" in sd:
- if 'model.layers.0.self_attn.q_norm.weight' in sd:
- return TEModel.QWEN3_4B
weight = sd['model.layers.0.post_attention_layernorm.weight']
+ if 'model.layers.0.self_attn.q_norm.weight' in sd:
+ if weight.shape[0] == 2560:
+ return TEModel.QWEN3_4B
+ elif weight.shape[0] == 2048:
+ return TEModel.QWEN3_2B
if weight.shape[0] == 5120:
if "model.layers.39.post_attention_layernorm.weight" in sd:
return TEModel.MISTRAL3_24B
@@ -1150,6 +1156,9 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
elif te_model == TEModel.QWEN3_4B:
clip_target.clip = comfy.text_encoders.z_image.te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.z_image.ZImageTokenizer
+ elif te_model == TEModel.QWEN3_2B:
+ clip_target.clip = comfy.text_encoders.ovis.te(**llama_detect(clip_data))
+ clip_target.tokenizer = comfy.text_encoders.ovis.OvisTokenizer
else:
# clip_l
if clip_type == CLIPType.SD3:
diff --git a/comfy/text_encoders/llama.py b/comfy/text_encoders/llama.py
index cd4b5f76c..0d07ac8c6 100644
--- a/comfy/text_encoders/llama.py
+++ b/comfy/text_encoders/llama.py
@@ -100,6 +100,28 @@ class Qwen3_4BConfig:
rope_scale = None
final_norm: bool = True
+@dataclass
+class Ovis25_2BConfig:
+ vocab_size: int = 151936
+ hidden_size: int = 2048
+ intermediate_size: int = 6144
+ num_hidden_layers: int = 28
+ num_attention_heads: int = 16
+ num_key_value_heads: int = 8
+ max_position_embeddings: int = 40960
+ rms_norm_eps: float = 1e-6
+ rope_theta: float = 1000000.0
+ transformer_type: str = "llama"
+ head_dim = 128
+ rms_norm_add = False
+ mlp_activation = "silu"
+ qkv_bias = False
+ rope_dims = None
+ q_norm = "gemma3"
+ k_norm = "gemma3"
+ rope_scale = None
+ final_norm: bool = True
+
@dataclass
class Qwen25_7BVLI_Config:
vocab_size: int = 152064
@@ -542,6 +564,15 @@ class Qwen3_4B(BaseLlama, torch.nn.Module):
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
self.dtype = dtype
+class Ovis25_2B(BaseLlama, torch.nn.Module):
+ def __init__(self, config_dict, dtype, device, operations):
+ super().__init__()
+ config = Ovis25_2BConfig(**config_dict)
+ self.num_layers = config.num_hidden_layers
+
+ self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
+ self.dtype = dtype
+
class Qwen25_7BVLI(BaseLlama, torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
diff --git a/comfy/text_encoders/ovis.py b/comfy/text_encoders/ovis.py
new file mode 100644
index 000000000..81c9bd51c
--- /dev/null
+++ b/comfy/text_encoders/ovis.py
@@ -0,0 +1,69 @@
+from transformers import Qwen2Tokenizer
+import comfy.text_encoders.llama
+from comfy import sd1_clip
+import os
+import torch
+import numbers
+
+class Qwen3Tokenizer(sd1_clip.SDTokenizer):
+ def __init__(self, embedding_directory=None, tokenizer_data={}):
+ tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "qwen25_tokenizer")
+ super().__init__(tokenizer_path, pad_with_end=False, embedding_size=2048, embedding_key='qwen3_2b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=284, pad_token=151643, tokenizer_data=tokenizer_data)
+
+
+class OvisTokenizer(sd1_clip.SD1Tokenizer):
+ def __init__(self, embedding_directory=None, tokenizer_data={}):
+ super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="qwen3_2b", tokenizer=Qwen3Tokenizer)
+ self.llama_template = "<|im_start|>user\nDescribe the image by detailing the color, quantity, text, shape, size, texture, spatial relationships of the objects and background: {}<|im_end|>\n<|im_start|>assistant\n\n\n\n\n"
+
+ def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, **kwargs):
+ if llama_template is None:
+ llama_text = self.llama_template.format(text)
+ else:
+ llama_text = llama_template.format(text)
+
+ tokens = super().tokenize_with_weights(llama_text, return_word_ids=return_word_ids, disable_weights=True, **kwargs)
+ return tokens
+
+class Ovis25_2BModel(sd1_clip.SDClipModel):
+ def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options={}):
+ super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Ovis25_2B, enable_attention_masks=attention_mask, return_attention_masks=False, zero_out_masked=True, model_options=model_options)
+
+
+class OvisTEModel(sd1_clip.SD1ClipModel):
+ def __init__(self, device="cpu", dtype=None, model_options={}):
+ super().__init__(device=device, dtype=dtype, name="qwen3_2b", clip_model=Ovis25_2BModel, model_options=model_options)
+
+ def encode_token_weights(self, token_weight_pairs, template_end=-1):
+ out, pooled = super().encode_token_weights(token_weight_pairs)
+ tok_pairs = token_weight_pairs["qwen3_2b"][0]
+ count_im_start = 0
+ if template_end == -1:
+ for i, v in enumerate(tok_pairs):
+ elem = v[0]
+ if not torch.is_tensor(elem):
+ if isinstance(elem, numbers.Integral):
+ if elem == 4004 and count_im_start < 1:
+ template_end = i
+ count_im_start += 1
+
+ if out.shape[1] > (template_end + 1):
+ if tok_pairs[template_end + 1][0] == 25:
+ template_end += 1
+
+ out = out[:, template_end:]
+ return out, pooled, {}
+
+
+def te(dtype_llama=None, llama_scaled_fp8=None, llama_quantization_metadata=None):
+ class OvisTEModel_(OvisTEModel):
+ def __init__(self, device="cpu", dtype=None, model_options={}):
+ if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options:
+ model_options = model_options.copy()
+ model_options["scaled_fp8"] = llama_scaled_fp8
+ if dtype_llama is not None:
+ dtype = dtype_llama
+ if llama_quantization_metadata is not None:
+ model_options["quantization_metadata"] = llama_quantization_metadata
+ super().__init__(device=device, dtype=dtype, model_options=model_options)
+ return OvisTEModel_
diff --git a/nodes.py b/nodes.py
index 495dec806..d5e5dc228 100644
--- a/nodes.py
+++ b/nodes.py
@@ -939,7 +939,7 @@ class CLIPLoader:
@classmethod
def INPUT_TYPES(s):
return {"required": { "clip_name": (folder_paths.get_filename_list("text_encoders"), ),
- "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image", "flux2"], ),
+ "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image", "flux2", "ovis"], ),
},
"optional": {
"device": (["default", "cpu"], {"advanced": True}),
From c55dc857d5da5af203caf720ed7056047d382544 Mon Sep 17 00:00:00 2001
From: Christian Byrne
Date: Mon, 1 Dec 2025 17:56:38 -0800
Subject: [PATCH 05/81] bump comfyui-frontend-package to 1.33.10 (#11028)
---
requirements.txt | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/requirements.txt b/requirements.txt
index 045b2ac54..f98848e20 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,4 +1,4 @@
-comfyui-frontend-package==1.32.10
+comfyui-frontend-package==1.33.10
comfyui-workflow-templates==0.7.25
comfyui-embedded-docs==0.3.1
torch
From b4a20acc54b0b94dc05a1bd09dc0b54dd12203f1 Mon Sep 17 00:00:00 2001
From: "Dr.Lt.Data" <128333288+ltdrdata@users.noreply.github.com>
Date: Tue, 2 Dec 2025 12:32:52 +0900
Subject: [PATCH 06/81] feat: Support ComfyUI-Manager for pip version (#7555)
---
comfy/cli_args.py | 7 +++++++
comfy_api/feature_flags.py | 1 +
main.py | 30 ++++++++++++++++++++++++++++++
manager_requirements.txt | 1 +
nodes.py | 9 +++++++++
server.py | 8 +++++++-
6 files changed, 55 insertions(+), 1 deletion(-)
create mode 100644 manager_requirements.txt
diff --git a/comfy/cli_args.py b/comfy/cli_args.py
index 5f0dfaa10..209fc185b 100644
--- a/comfy/cli_args.py
+++ b/comfy/cli_args.py
@@ -121,6 +121,12 @@ upcast.add_argument("--force-upcast-attention", action="store_true", help="Force
upcast.add_argument("--dont-upcast-attention", action="store_true", help="Disable all upcasting of attention. Should be unnecessary except for debugging.")
+parser.add_argument("--enable-manager", action="store_true", help="Enable the ComfyUI-Manager feature.")
+manager_group = parser.add_mutually_exclusive_group()
+manager_group.add_argument("--disable-manager-ui", action="store_true", help="Disables only the ComfyUI-Manager UI and endpoints. Scheduled installations and similar background tasks will still operate.")
+manager_group.add_argument("--enable-manager-legacy-ui", action="store_true", help="Enables the legacy UI of ComfyUI-Manager")
+
+
vram_group = parser.add_mutually_exclusive_group()
vram_group.add_argument("--gpu-only", action="store_true", help="Store and run everything (text encoders/CLIP models, etc... on the GPU).")
vram_group.add_argument("--highvram", action="store_true", help="By default models will be unloaded to CPU memory after being used. This option keeps them in GPU memory.")
@@ -168,6 +174,7 @@ parser.add_argument("--multi-user", action="store_true", help="Enables per-user
parser.add_argument("--verbose", default='INFO', const='DEBUG', nargs="?", choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], help='Set the logging level')
parser.add_argument("--log-stdout", action="store_true", help="Send normal process output to stdout instead of stderr (default).")
+
# The default built-in provider hosted under web/
DEFAULT_VERSION_STRING = "comfyanonymous/ComfyUI@latest"
diff --git a/comfy_api/feature_flags.py b/comfy_api/feature_flags.py
index 0d4389a6e..bfb77eb5f 100644
--- a/comfy_api/feature_flags.py
+++ b/comfy_api/feature_flags.py
@@ -13,6 +13,7 @@ from comfy.cli_args import args
SERVER_FEATURE_FLAGS: Dict[str, Any] = {
"supports_preview_metadata": True,
"max_upload_size": args.max_upload_size * 1024 * 1024, # Convert MB to bytes
+ "extension": {"manager": {"supports_v4": True}},
}
diff --git a/main.py b/main.py
index e1b0f1620..0cd815d9e 100644
--- a/main.py
+++ b/main.py
@@ -15,6 +15,7 @@ from comfy_execution.progress import get_progress_state
from comfy_execution.utils import get_executing_context
from comfy_api import feature_flags
+
if __name__ == "__main__":
#NOTE: These do not do anything on core ComfyUI, they are for custom nodes.
os.environ['HF_HUB_DISABLE_TELEMETRY'] = '1'
@@ -22,6 +23,23 @@ if __name__ == "__main__":
setup_logger(log_level=args.verbose, use_stdout=args.log_stdout)
+
+def handle_comfyui_manager_unavailable():
+ if not args.windows_standalone_build:
+ logging.warning(f"\n\nYou appear to be running comfyui-manager from source, this is not recommended. Please install comfyui-manager using the following command:\ncommand:\n\t{sys.executable} -m pip install --pre comfyui_manager\n")
+ args.enable_manager = False
+
+
+if args.enable_manager:
+ if importlib.util.find_spec("comfyui_manager"):
+ import comfyui_manager
+
+ if not comfyui_manager.__file__ or not comfyui_manager.__file__.endswith('__init__.py'):
+ handle_comfyui_manager_unavailable()
+ else:
+ handle_comfyui_manager_unavailable()
+
+
def apply_custom_paths():
# extra model paths
extra_model_paths_config_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "extra_model_paths.yaml")
@@ -79,6 +97,11 @@ def execute_prestartup_script():
for possible_module in possible_modules:
module_path = os.path.join(custom_node_path, possible_module)
+
+ if args.enable_manager:
+ if comfyui_manager.should_be_disabled(module_path):
+ continue
+
if os.path.isfile(module_path) or module_path.endswith(".disabled") or module_path == "__pycache__":
continue
@@ -101,6 +124,10 @@ def execute_prestartup_script():
logging.info("")
apply_custom_paths()
+
+if args.enable_manager:
+ comfyui_manager.prestartup()
+
execute_prestartup_script()
@@ -323,6 +350,9 @@ def start_comfyui(asyncio_loop=None):
asyncio.set_event_loop(asyncio_loop)
prompt_server = server.PromptServer(asyncio_loop)
+ if args.enable_manager and not args.disable_manager_ui:
+ comfyui_manager.start()
+
hook_breaker_ac10a0.save_functions()
asyncio_loop.run_until_complete(nodes.init_extra_nodes(
init_custom_nodes=(not args.disable_all_custom_nodes) or len(args.whitelist_custom_nodes) > 0,
diff --git a/manager_requirements.txt b/manager_requirements.txt
new file mode 100644
index 000000000..52cc5389c
--- /dev/null
+++ b/manager_requirements.txt
@@ -0,0 +1 @@
+comfyui_manager==4.0.3b3
diff --git a/nodes.py b/nodes.py
index d5e5dc228..4c910a34b 100644
--- a/nodes.py
+++ b/nodes.py
@@ -43,6 +43,9 @@ import folder_paths
import latent_preview
import node_helpers
+if args.enable_manager:
+ import comfyui_manager
+
def before_node_execution():
comfy.model_management.throw_exception_if_processing_interrupted()
@@ -2243,6 +2246,12 @@ async def init_external_custom_nodes():
if args.disable_all_custom_nodes and possible_module not in args.whitelist_custom_nodes:
logging.info(f"Skipping {possible_module} due to disable_all_custom_nodes and whitelist_custom_nodes")
continue
+
+ if args.enable_manager:
+ if comfyui_manager.should_be_disabled(module_path):
+ logging.info(f"Blocked by policy: {module_path}")
+ continue
+
time_before = time.perf_counter()
success = await load_custom_node(module_path, base_node_names, module_parent="custom_nodes")
node_import_times.append((time.perf_counter() - time_before, module_path, success))
diff --git a/server.py b/server.py
index fca5050bd..e3bd056d9 100644
--- a/server.py
+++ b/server.py
@@ -44,6 +44,9 @@ from protocol import BinaryEventTypes
# Import cache control middleware
from middleware.cache_middleware import cache_control
+if args.enable_manager:
+ import comfyui_manager
+
async def send_socket_catch_exception(function, message):
try:
await function(message)
@@ -212,6 +215,9 @@ class PromptServer():
if args.disable_api_nodes:
middlewares.append(create_block_external_middleware())
+ if args.enable_manager:
+ middlewares.append(comfyui_manager.create_middleware())
+
max_upload_size = round(args.max_upload_size * 1024 * 1024)
self.app = web.Application(client_max_size=max_upload_size, middlewares=middlewares)
self.sockets = dict()
@@ -599,7 +605,7 @@ class PromptServer():
system_stats = {
"system": {
- "os": os.name,
+ "os": sys.platform,
"ram_total": ram_total,
"ram_free": ram_free,
"comfyui_version": __version__,
From a17cf1c3871ad582c85c2bb6fddb63ec9c6df0ce Mon Sep 17 00:00:00 2001
From: Yoland Yan <4950057+yoland68@users.noreply.github.com>
Date: Mon, 1 Dec 2025 19:40:44 -0800
Subject: [PATCH 07/81] Add @guill as a code owner (#11031)
---
CODEOWNERS | 1 +
1 file changed, 1 insertion(+)
diff --git a/CODEOWNERS b/CODEOWNERS
index b7aca9b26..51acc4986 100644
--- a/CODEOWNERS
+++ b/CODEOWNERS
@@ -1,3 +1,4 @@
# Admins
* @comfyanonymous
* @kosinkadink
+* @guill
From 44baa0b7f32dd0c2ff0a9898aeb6c7929d855cd3 Mon Sep 17 00:00:00 2001
From: Jedrzej Kosinski
Date: Tue, 2 Dec 2025 11:46:29 -0800
Subject: [PATCH 08/81] Fix CODEOWNERS formatting to have all on the same line,
otherwise only last line applies (#11053)
---
CODEOWNERS | 4 +---
1 file changed, 1 insertion(+), 3 deletions(-)
diff --git a/CODEOWNERS b/CODEOWNERS
index 51acc4986..4d5448636 100644
--- a/CODEOWNERS
+++ b/CODEOWNERS
@@ -1,4 +1,2 @@
# Admins
-* @comfyanonymous
-* @kosinkadink
-* @guill
+* @comfyanonymous @kosinkadink @guill
From 33d6aec3b70bc6f3e5bba26c85bd8f3bb1380d08 Mon Sep 17 00:00:00 2001
From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com>
Date: Tue, 2 Dec 2025 21:50:13 +0200
Subject: [PATCH 09/81] add check for the format arg type in
VideoFromComponents.save_to function (#11046)
* add check for the format var type in VideoFromComponents.save_to function
* convert "format" to VideoContainer enum
---
comfy_api/latest/_input_impl/video_types.py | 2 +-
comfy_extras/nodes_video.py | 4 ++--
2 files changed, 3 insertions(+), 3 deletions(-)
diff --git a/comfy_api/latest/_input_impl/video_types.py b/comfy_api/latest/_input_impl/video_types.py
index 7231bf13c..a4cd3737d 100644
--- a/comfy_api/latest/_input_impl/video_types.py
+++ b/comfy_api/latest/_input_impl/video_types.py
@@ -337,7 +337,7 @@ class VideoFromComponents(VideoInput):
if codec != VideoCodec.AUTO and codec != VideoCodec.H264:
raise ValueError("Only H264 codec is supported for now")
extra_kwargs = {}
- if format != VideoContainer.AUTO:
+ if isinstance(format, VideoContainer) and format != VideoContainer.AUTO:
extra_kwargs["format"] = format.value
with av.open(path, mode='w', options={'movflags': 'use_metadata_tags'}, **extra_kwargs) as output:
# Add metadata before writing any streams
diff --git a/comfy_extras/nodes_video.py b/comfy_extras/nodes_video.py
index 69fabb12e..6cf6e39bf 100644
--- a/comfy_extras/nodes_video.py
+++ b/comfy_extras/nodes_video.py
@@ -88,7 +88,7 @@ class SaveVideo(io.ComfyNode):
)
@classmethod
- def execute(cls, video: VideoInput, filename_prefix, format, codec) -> io.NodeOutput:
+ def execute(cls, video: VideoInput, filename_prefix, format: str, codec) -> io.NodeOutput:
width, height = video.get_dimensions()
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(
filename_prefix,
@@ -108,7 +108,7 @@ class SaveVideo(io.ComfyNode):
file = f"{filename}_{counter:05}_.{VideoContainer.get_extension(format)}"
video.save_to(
os.path.join(full_output_folder, file),
- format=format,
+ format=VideoContainer(format),
codec=codec,
metadata=saved_metadata
)
From daaceac769a1355ab975758ede064317ea7514b4 Mon Sep 17 00:00:00 2001
From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com>
Date: Tue, 2 Dec 2025 14:11:58 -0800
Subject: [PATCH 10/81] Hack to make zimage work in fp16. (#11057)
---
comfy/ldm/lumina/model.py | 18 +++++++++++-------
comfy/supported_models.py | 2 ++
2 files changed, 13 insertions(+), 7 deletions(-)
diff --git a/comfy/ldm/lumina/model.py b/comfy/ldm/lumina/model.py
index 7d7e9112c..070b5da09 100644
--- a/comfy/ldm/lumina/model.py
+++ b/comfy/ldm/lumina/model.py
@@ -22,6 +22,10 @@ def modulate(x, scale):
# Core NextDiT Model #
#############################################################################
+def clamp_fp16(x):
+ if x.dtype == torch.float16:
+ return torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504)
+ return x
class JointAttention(nn.Module):
"""Multi-head attention module."""
@@ -169,7 +173,7 @@ class FeedForward(nn.Module):
# @torch.compile
def _forward_silu_gating(self, x1, x3):
- return F.silu(x1) * x3
+ return clamp_fp16(F.silu(x1) * x3)
def forward(self, x):
return self.w2(self._forward_silu_gating(self.w1(x), self.w3(x)))
@@ -273,27 +277,27 @@ class JointTransformerBlock(nn.Module):
scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(adaln_input).chunk(4, dim=1)
x = x + gate_msa.unsqueeze(1).tanh() * self.attention_norm2(
- self.attention(
+ clamp_fp16(self.attention(
modulate(self.attention_norm1(x), scale_msa),
x_mask,
freqs_cis,
transformer_options=transformer_options,
- )
+ ))
)
x = x + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(
- self.feed_forward(
+ clamp_fp16(self.feed_forward(
modulate(self.ffn_norm1(x), scale_mlp),
- )
+ ))
)
else:
assert adaln_input is None
x = x + self.attention_norm2(
- self.attention(
+ clamp_fp16(self.attention(
self.attention_norm1(x),
x_mask,
freqs_cis,
transformer_options=transformer_options,
- )
+ ))
)
x = x + self.ffn_norm2(
self.feed_forward(
diff --git a/comfy/supported_models.py b/comfy/supported_models.py
index af8120400..afd97160b 100644
--- a/comfy/supported_models.py
+++ b/comfy/supported_models.py
@@ -1027,6 +1027,8 @@ class ZImage(Lumina2):
memory_usage_factor = 1.7
+ supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
+
def clip_target(self, state_dict={}):
pref = self.text_encoder_key_prefix[0]
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen3_4b.transformer.".format(pref))
From 277237ccc1499bac7fcd221a666dfe7a32ac4206 Mon Sep 17 00:00:00 2001
From: rattus <46076784+rattus128@users.noreply.github.com>
Date: Wed, 3 Dec 2025 08:24:19 +1000
Subject: [PATCH 11/81] attention: use flag based OOM fallback (#11038)
Exception ref all local variables for the lifetime of exception
context. Just set a flag and then if to dump the exception before
falling back.
---
comfy/ldm/modules/attention.py | 3 +++
comfy/ldm/modules/diffusionmodules/model.py | 3 +++
2 files changed, 6 insertions(+)
diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py
index 7437e0567..a8800ded0 100644
--- a/comfy/ldm/modules/attention.py
+++ b/comfy/ldm/modules/attention.py
@@ -517,6 +517,7 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
@wrap_attn
def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
+ exception_fallback = False
if skip_reshape:
b, _, _, dim_head = q.shape
tensor_layout = "HND"
@@ -541,6 +542,8 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=
out = sageattn(q, k, v, attn_mask=mask, is_causal=False, tensor_layout=tensor_layout)
except Exception as e:
logging.error("Error running sage attention: {}, using pytorch attention instead.".format(e))
+ exception_fallback = True
+ if exception_fallback:
if tensor_layout == "NHD":
q, k, v = map(
lambda t: t.transpose(1, 2),
diff --git a/comfy/ldm/modules/diffusionmodules/model.py b/comfy/ldm/modules/diffusionmodules/model.py
index 4245eedca..de1e01cc8 100644
--- a/comfy/ldm/modules/diffusionmodules/model.py
+++ b/comfy/ldm/modules/diffusionmodules/model.py
@@ -279,6 +279,7 @@ def pytorch_attention(q, k, v):
orig_shape = q.shape
B = orig_shape[0]
C = orig_shape[1]
+ oom_fallback = False
q, k, v = map(
lambda t: t.view(B, 1, C, -1).transpose(2, 3).contiguous(),
(q, k, v),
@@ -289,6 +290,8 @@ def pytorch_attention(q, k, v):
out = out.transpose(2, 3).reshape(orig_shape)
except model_management.OOM_EXCEPTION:
logging.warning("scaled_dot_product_attention OOMed: switched to slice attention")
+ oom_fallback = True
+ if oom_fallback:
out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(orig_shape)
return out
From b94d394a64dd0af06bca44b96c66549bb463331d Mon Sep 17 00:00:00 2001
From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com>
Date: Tue, 2 Dec 2025 18:38:31 -0800
Subject: [PATCH 12/81] Support Z Image alibaba pai fun controlnets. (#11062)
These are not actual controlnets so put it in the models/model_patches
folder and use the ModelPatchLoader + QwenImageDiffsynthControlnet node to
use it.
---
comfy/ldm/lumina/controlnet.py | 113 ++++++++++++++++++++++++++++++
comfy/ldm/lumina/model.py | 24 ++++---
comfy_extras/nodes_model_patch.py | 101 +++++++++++++++++++++++++-
3 files changed, 229 insertions(+), 9 deletions(-)
create mode 100644 comfy/ldm/lumina/controlnet.py
diff --git a/comfy/ldm/lumina/controlnet.py b/comfy/ldm/lumina/controlnet.py
new file mode 100644
index 000000000..fd7ce3b5c
--- /dev/null
+++ b/comfy/ldm/lumina/controlnet.py
@@ -0,0 +1,113 @@
+import torch
+from torch import nn
+
+from .model import JointTransformerBlock
+
+class ZImageControlTransformerBlock(JointTransformerBlock):
+ def __init__(
+ self,
+ layer_id: int,
+ dim: int,
+ n_heads: int,
+ n_kv_heads: int,
+ multiple_of: int,
+ ffn_dim_multiplier: float,
+ norm_eps: float,
+ qk_norm: bool,
+ modulation=True,
+ block_id=0,
+ operation_settings=None,
+ ):
+ super().__init__(layer_id, dim, n_heads, n_kv_heads, multiple_of, ffn_dim_multiplier, norm_eps, qk_norm, modulation, z_image_modulation=True, operation_settings=operation_settings)
+ self.block_id = block_id
+ if block_id == 0:
+ self.before_proj = operation_settings.get("operations").Linear(self.dim, self.dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+ self.after_proj = operation_settings.get("operations").Linear(self.dim, self.dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+
+ def forward(self, c, x, **kwargs):
+ if self.block_id == 0:
+ c = self.before_proj(c) + x
+ c = super().forward(c, **kwargs)
+ c_skip = self.after_proj(c)
+ return c_skip, c
+
+class ZImage_Control(torch.nn.Module):
+ def __init__(
+ self,
+ dim: int = 3840,
+ n_heads: int = 30,
+ n_kv_heads: int = 30,
+ multiple_of: int = 256,
+ ffn_dim_multiplier: float = (8.0 / 3.0),
+ norm_eps: float = 1e-5,
+ qk_norm: bool = True,
+ dtype=None,
+ device=None,
+ operations=None,
+ **kwargs
+ ):
+ super().__init__()
+ operation_settings = {"operations": operations, "device": device, "dtype": dtype}
+
+ self.additional_in_dim = 0
+ self.control_in_dim = 16
+ n_refiner_layers = 2
+ self.n_control_layers = 6
+ self.control_layers = nn.ModuleList(
+ [
+ ZImageControlTransformerBlock(
+ i,
+ dim,
+ n_heads,
+ n_kv_heads,
+ multiple_of,
+ ffn_dim_multiplier,
+ norm_eps,
+ qk_norm,
+ block_id=i,
+ operation_settings=operation_settings,
+ )
+ for i in range(self.n_control_layers)
+ ]
+ )
+
+ all_x_embedder = {}
+ patch_size = 2
+ f_patch_size = 1
+ x_embedder = operations.Linear(f_patch_size * patch_size * patch_size * self.control_in_dim, dim, bias=True, device=device, dtype=dtype)
+ all_x_embedder[f"{patch_size}-{f_patch_size}"] = x_embedder
+
+ self.control_all_x_embedder = nn.ModuleDict(all_x_embedder)
+ self.control_noise_refiner = nn.ModuleList(
+ [
+ JointTransformerBlock(
+ layer_id,
+ dim,
+ n_heads,
+ n_kv_heads,
+ multiple_of,
+ ffn_dim_multiplier,
+ norm_eps,
+ qk_norm,
+ modulation=True,
+ z_image_modulation=True,
+ operation_settings=operation_settings,
+ )
+ for layer_id in range(n_refiner_layers)
+ ]
+ )
+
+ def forward(self, cap_feats, control_context, x_freqs_cis, adaln_input):
+ patch_size = 2
+ f_patch_size = 1
+ pH = pW = patch_size
+ B, C, H, W = control_context.shape
+ control_context = self.control_all_x_embedder[f"{patch_size}-{f_patch_size}"](control_context.view(B, C, H // pH, pH, W // pW, pW).permute(0, 2, 4, 3, 5, 1).flatten(3).flatten(1, 2))
+
+ x_attn_mask = None
+ for layer in self.control_noise_refiner:
+ control_context = layer(control_context, x_attn_mask, x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input)
+ return control_context
+
+ def forward_control_block(self, layer_id, control_context, x, x_attn_mask, x_freqs_cis, adaln_input):
+ return self.control_layers[layer_id](control_context, x, x_mask=x_attn_mask, freqs_cis=x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input=adaln_input)
diff --git a/comfy/ldm/lumina/model.py b/comfy/ldm/lumina/model.py
index 070b5da09..f1c1a0ec3 100644
--- a/comfy/ldm/lumina/model.py
+++ b/comfy/ldm/lumina/model.py
@@ -568,7 +568,7 @@ class NextDiT(nn.Module):
).execute(x, timesteps, context, num_tokens, attention_mask, **kwargs)
# def forward(self, x, t, cap_feats, cap_mask):
- def _forward(self, x, timesteps, context, num_tokens, attention_mask=None, **kwargs):
+ def _forward(self, x, timesteps, context, num_tokens, attention_mask=None, transformer_options={}, **kwargs):
t = 1.0 - timesteps
cap_feats = context
cap_mask = attention_mask
@@ -585,16 +585,24 @@ class NextDiT(nn.Module):
cap_feats = self.cap_embedder(cap_feats) # (N, L, D) # todo check if able to batchify w.o. redundant compute
+ patches = transformer_options.get("patches", {})
transformer_options = kwargs.get("transformer_options", {})
x_is_tensor = isinstance(x, torch.Tensor)
- x, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, t, num_tokens, transformer_options=transformer_options)
- freqs_cis = freqs_cis.to(x.device)
+ img, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, t, num_tokens, transformer_options=transformer_options)
+ freqs_cis = freqs_cis.to(img.device)
- for layer in self.layers:
- x = layer(x, mask, freqs_cis, adaln_input, transformer_options=transformer_options)
+ for i, layer in enumerate(self.layers):
+ img = layer(img, mask, freqs_cis, adaln_input, transformer_options=transformer_options)
+ if "double_block" in patches:
+ for p in patches["double_block"]:
+ out = p({"img": img[:, cap_size[0]:], "txt": img[:, :cap_size[0]], "pe": freqs_cis[:, cap_size[0]:], "vec": adaln_input, "x": x, "block_index": i, "transformer_options": transformer_options})
+ if "img" in out:
+ img[:, cap_size[0]:] = out["img"]
+ if "txt" in out:
+ img[:, :cap_size[0]] = out["txt"]
- x = self.final_layer(x, adaln_input)
- x = self.unpatchify(x, img_size, cap_size, return_tensor=x_is_tensor)[:,:,:h,:w]
+ img = self.final_layer(img, adaln_input)
+ img = self.unpatchify(img, img_size, cap_size, return_tensor=x_is_tensor)[:, :, :h, :w]
- return -x
+ return -img
diff --git a/comfy_extras/nodes_model_patch.py b/comfy_extras/nodes_model_patch.py
index 783c59b6b..c61810dbf 100644
--- a/comfy_extras/nodes_model_patch.py
+++ b/comfy_extras/nodes_model_patch.py
@@ -6,6 +6,7 @@ import comfy.ops
import comfy.model_management
import comfy.ldm.common_dit
import comfy.latent_formats
+import comfy.ldm.lumina.controlnet
class BlockWiseControlBlock(torch.nn.Module):
@@ -189,6 +190,35 @@ class SigLIPMultiFeatProjModel(torch.nn.Module):
return embedding
+def z_image_convert(sd):
+ replace_keys = {".attention.to_out.0.bias": ".attention.out.bias",
+ ".attention.norm_k.weight": ".attention.k_norm.weight",
+ ".attention.norm_q.weight": ".attention.q_norm.weight",
+ ".attention.to_out.0.weight": ".attention.out.weight"
+ }
+
+ out_sd = {}
+ for k in sorted(sd.keys()):
+ w = sd[k]
+
+ k_out = k
+ if k_out.endswith(".attention.to_k.weight"):
+ cc = [w]
+ continue
+ if k_out.endswith(".attention.to_q.weight"):
+ cc = [w] + cc
+ continue
+ if k_out.endswith(".attention.to_v.weight"):
+ cc = cc + [w]
+ w = torch.cat(cc, dim=0)
+ k_out = k_out.replace(".attention.to_v.weight", ".attention.qkv.weight")
+
+ for r, rr in replace_keys.items():
+ k_out = k_out.replace(r, rr)
+ out_sd[k_out] = w
+
+ return out_sd
+
class ModelPatchLoader:
@classmethod
def INPUT_TYPES(s):
@@ -211,6 +241,9 @@ class ModelPatchLoader:
elif 'feature_embedder.mid_layer_norm.bias' in sd:
sd = comfy.utils.state_dict_prefix_replace(sd, {"feature_embedder.": ""}, filter_keys=True)
model = SigLIPMultiFeatProjModel(device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast)
+ elif 'control_all_x_embedder.2-1.weight' in sd: # alipai z image fun controlnet
+ sd = z_image_convert(sd)
+ model = comfy.ldm.lumina.controlnet.ZImage_Control(device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast)
model.load_state_dict(sd)
model = comfy.model_patcher.ModelPatcher(model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device())
@@ -263,6 +296,69 @@ class DiffSynthCnetPatch:
def models(self):
return [self.model_patch]
+class ZImageControlPatch:
+ def __init__(self, model_patch, vae, image, strength):
+ self.model_patch = model_patch
+ self.vae = vae
+ self.image = image
+ self.strength = strength
+ self.encoded_image = self.encode_latent_cond(image)
+ self.encoded_image_size = (image.shape[1], image.shape[2])
+ self.temp_data = None
+
+ def encode_latent_cond(self, image):
+ latent_image = comfy.latent_formats.Flux().process_in(self.vae.encode(image))
+ return latent_image
+
+ def __call__(self, kwargs):
+ x = kwargs.get("x")
+ img = kwargs.get("img")
+ txt = kwargs.get("txt")
+ pe = kwargs.get("pe")
+ vec = kwargs.get("vec")
+ block_index = kwargs.get("block_index")
+ spacial_compression = self.vae.spacial_compression_encode()
+ if self.encoded_image is None or self.encoded_image_size != (x.shape[-2] * spacial_compression, x.shape[-1] * spacial_compression):
+ image_scaled = comfy.utils.common_upscale(self.image.movedim(-1, 1), x.shape[-1] * spacial_compression, x.shape[-2] * spacial_compression, "area", "center")
+ loaded_models = comfy.model_management.loaded_models(only_currently_used=True)
+ self.encoded_image = self.encode_latent_cond(image_scaled.movedim(1, -1))
+ self.encoded_image_size = (image_scaled.shape[-2], image_scaled.shape[-1])
+ comfy.model_management.load_models_gpu(loaded_models)
+
+ cnet_index = (block_index // 5)
+ cnet_index_float = (block_index / 5)
+
+ kwargs.pop("img") # we do ops in place
+ kwargs.pop("txt")
+
+ cnet_blocks = self.model_patch.model.n_control_layers
+ if cnet_index_float > (cnet_blocks - 1):
+ self.temp_data = None
+ return kwargs
+
+ if self.temp_data is None or self.temp_data[0] > cnet_index:
+ self.temp_data = (-1, (None, self.model_patch.model(txt, self.encoded_image.to(img.dtype), pe, vec)))
+
+ while self.temp_data[0] < cnet_index and (self.temp_data[0] + 1) < cnet_blocks:
+ next_layer = self.temp_data[0] + 1
+ self.temp_data = (next_layer, self.model_patch.model.forward_control_block(next_layer, self.temp_data[1][1], img[:, :self.temp_data[1][1].shape[1]], None, pe, vec))
+
+ if cnet_index_float == self.temp_data[0]:
+ img[:, :self.temp_data[1][0].shape[1]] += (self.temp_data[1][0] * self.strength)
+ if cnet_blocks == self.temp_data[0] + 1:
+ self.temp_data = None
+
+ return kwargs
+
+ def to(self, device_or_dtype):
+ if isinstance(device_or_dtype, torch.device):
+ self.encoded_image = self.encoded_image.to(device_or_dtype)
+ self.temp_data = None
+ return self
+
+ def models(self):
+ return [self.model_patch]
+
class QwenImageDiffsynthControlnet:
@classmethod
def INPUT_TYPES(s):
@@ -289,7 +385,10 @@ class QwenImageDiffsynthControlnet:
mask = mask.unsqueeze(2)
mask = 1.0 - mask
- model_patched.set_model_double_block_patch(DiffSynthCnetPatch(model_patch, vae, image, strength, mask))
+ if isinstance(model_patch.model, comfy.ldm.lumina.controlnet.ZImage_Control):
+ model_patched.set_model_double_block_patch(ZImageControlPatch(model_patch, vae, image, strength))
+ else:
+ model_patched.set_model_double_block_patch(DiffSynthCnetPatch(model_patch, vae, image, strength, mask))
return (model_patched,)
From 3f512f5659cfbb3c53999cde6ff557591740252b Mon Sep 17 00:00:00 2001
From: Jim Heising
Date: Tue, 2 Dec 2025 19:29:27 -0800
Subject: [PATCH 13/81] Added PATCH method to CORS headers (#11066)
Added PATCH http method to access-control-allow-header-methods header because there are now PATCH endpoints exposed in the API.
See https://github.com/comfyanonymous/ComfyUI/blob/277237ccc1499bac7fcd221a666dfe7a32ac4206/api_server/routes/internal/internal_routes.py#L34 for an example of an API endpoint that uses the PATCH method.
---
server.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/server.py b/server.py
index e3bd056d9..ac4f42222 100644
--- a/server.py
+++ b/server.py
@@ -98,7 +98,7 @@ def create_cors_middleware(allowed_origin: str):
response = await handler(request)
response.headers['Access-Control-Allow-Origin'] = allowed_origin
- response.headers['Access-Control-Allow-Methods'] = 'POST, GET, DELETE, PUT, OPTIONS'
+ response.headers['Access-Control-Allow-Methods'] = 'POST, GET, DELETE, PUT, OPTIONS, PATCH'
response.headers['Access-Control-Allow-Headers'] = 'Content-Type, Authorization'
response.headers['Access-Control-Allow-Credentials'] = 'true'
return response
From 73f5649196f472d3719e2e7513e0a9d029cc3e38 Mon Sep 17 00:00:00 2001
From: rattus <46076784+rattus128@users.noreply.github.com>
Date: Wed, 3 Dec 2025 13:49:29 +1000
Subject: [PATCH 14/81] Implement temporal rolling VAE (Major VRAM reductions
in Hunyuan and Kandinsky) (#10995)
* hunyuan upsampler: rework imports
Remove the transitive import of VideoConv3d and Resnet and takes these
from actual implementation source.
* model: remove unused give_pre_end
According to git grep, this is not used now, and was not used in the
initial commit that introduced it (see below).
This semantic is difficult to implement temporal roll VAE for (and would
defeat the purpose). Rather than implement the complex if, just delete
the unused feature.
(venv) rattus@rattus-box2:~/ComfyUI$ git log --oneline
220afe33 (HEAD) Initial commit.
(venv) rattus@rattus-box2:~/ComfyUI$ git grep give_pre
comfy/ldm/modules/diffusionmodules/model.py: resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False,
comfy/ldm/modules/diffusionmodules/model.py: self.give_pre_end = give_pre_end
comfy/ldm/modules/diffusionmodules/model.py: if self.give_pre_end:
(venv) rattus@rattus-box2:~/ComfyUI$ git co origin/master
Previous HEAD position was 220afe33 Initial commit.
HEAD is now at 9d8a8179 Enable async offloading by default on Nvidia. (#10953)
(venv) rattus@rattus-box2:~/ComfyUI$ git grep give_pre
comfy/ldm/modules/diffusionmodules/model.py: resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False,
comfy/ldm/modules/diffusionmodules/model.py: self.give_pre_end = give_pre_end
comfy/ldm/modules/diffusionmodules/model.py: if self.give_pre_end:
* move refiner VAE temporal roller to core
Move the carrying conv op to the common VAE code and give it a better
name. Roll the carry implementation logic for Resnet into the base
class and scrap the Hunyuan specific subclass.
* model: Add temporal roll to main VAE decoder
If there are no attention layers, its a standard resnet and VideoConv3d
is asked for, substitute in the temporal rolloing VAE algorithm. This
reduces VAE usage by the temporal dimension (can be huge VRAM savings).
* model: Add temporal roll to main VAE encoder
If there are no attention layers, its a standard resnet and VideoConv3d
is asked for, substitute in the temporal rolling VAE algorithm. This
reduces VAE usage by the temporal dimension (can be huge VRAM savings).
---
comfy/ldm/hunyuan_video/upsampler.py | 3 +-
comfy/ldm/hunyuan_video/vae_refiner.py | 94 +++------
comfy/ldm/modules/diffusionmodules/model.py | 207 ++++++++++++++------
3 files changed, 174 insertions(+), 130 deletions(-)
diff --git a/comfy/ldm/hunyuan_video/upsampler.py b/comfy/ldm/hunyuan_video/upsampler.py
index 9f5e91a59..85f515f67 100644
--- a/comfy/ldm/hunyuan_video/upsampler.py
+++ b/comfy/ldm/hunyuan_video/upsampler.py
@@ -1,7 +1,8 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
-from comfy.ldm.hunyuan_video.vae_refiner import RMS_norm, ResnetBlock, VideoConv3d
+from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, VideoConv3d
+from comfy.ldm.hunyuan_video.vae_refiner import RMS_norm
import model_management, model_patcher
class SRResidualCausalBlock3D(nn.Module):
diff --git a/comfy/ldm/hunyuan_video/vae_refiner.py b/comfy/ldm/hunyuan_video/vae_refiner.py
index 9f750dcc4..ddf77cd0e 100644
--- a/comfy/ldm/hunyuan_video/vae_refiner.py
+++ b/comfy/ldm/hunyuan_video/vae_refiner.py
@@ -1,42 +1,12 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
-from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, AttnBlock, VideoConv3d, Normalize
+from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, AttnBlock, CarriedConv3d, Normalize, conv_carry_causal_3d, torch_cat_if_needed
import comfy.ops
import comfy.ldm.models.autoencoder
import comfy.model_management
ops = comfy.ops.disable_weight_init
-class NoPadConv3d(nn.Module):
- def __init__(self, n_channels, out_channels, kernel_size, stride=1, dilation=1, padding=0, **kwargs):
- super().__init__()
- self.conv = ops.Conv3d(n_channels, out_channels, kernel_size, stride=stride, dilation=dilation, **kwargs)
-
- def forward(self, x):
- return self.conv(x)
-
-
-def conv_carry_causal_3d(xl, op, conv_carry_in=None, conv_carry_out=None):
-
- x = xl[0]
- xl.clear()
-
- if conv_carry_out is not None:
- to_push = x[:, :, -2:, :, :].clone()
- conv_carry_out.append(to_push)
-
- if isinstance(op, NoPadConv3d):
- if conv_carry_in is None:
- x = torch.nn.functional.pad(x, (1, 1, 1, 1, 2, 0), mode = 'replicate')
- else:
- carry_len = conv_carry_in[0].shape[2]
- x = torch.cat([conv_carry_in.pop(0), x], dim=2)
- x = torch.nn.functional.pad(x, (1, 1, 1, 1, 2 - carry_len, 0), mode = 'replicate')
-
- out = op(x)
-
- return out
-
class RMS_norm(nn.Module):
def __init__(self, dim):
@@ -49,7 +19,7 @@ class RMS_norm(nn.Module):
return F.normalize(x, dim=1) * self.scale * comfy.model_management.cast_to(self.gamma, dtype=x.dtype, device=x.device)
class DnSmpl(nn.Module):
- def __init__(self, ic, oc, tds=True, refiner_vae=True, op=VideoConv3d):
+ def __init__(self, ic, oc, tds, refiner_vae, op):
super().__init__()
fct = 2 * 2 * 2 if tds else 1 * 2 * 2
assert oc % fct == 0
@@ -109,7 +79,7 @@ class DnSmpl(nn.Module):
class UpSmpl(nn.Module):
- def __init__(self, ic, oc, tus=True, refiner_vae=True, op=VideoConv3d):
+ def __init__(self, ic, oc, tus, refiner_vae, op):
super().__init__()
fct = 2 * 2 * 2 if tus else 1 * 2 * 2
self.conv = op(ic, oc * fct, kernel_size=3, stride=1, padding=1)
@@ -163,23 +133,6 @@ class UpSmpl(nn.Module):
return h + x
-class HunyuanRefinerResnetBlock(ResnetBlock):
- def __init__(self, in_channels, out_channels, conv_op=NoPadConv3d, norm_op=RMS_norm):
- super().__init__(in_channels=in_channels, out_channels=out_channels, temb_channels=0, conv_op=conv_op, norm_op=norm_op)
-
- def forward(self, x, conv_carry_in=None, conv_carry_out=None):
- h = x
- h = [ self.swish(self.norm1(x)) ]
- h = conv_carry_causal_3d(h, self.conv1, conv_carry_in=conv_carry_in, conv_carry_out=conv_carry_out)
-
- h = [ self.dropout(self.swish(self.norm2(h))) ]
- h = conv_carry_causal_3d(h, self.conv2, conv_carry_in=conv_carry_in, conv_carry_out=conv_carry_out)
-
- if self.in_channels != self.out_channels:
- x = self.nin_shortcut(x)
-
- return x+h
-
class Encoder(nn.Module):
def __init__(self, in_channels, z_channels, block_out_channels, num_res_blocks,
ffactor_spatial, ffactor_temporal, downsample_match_channel=True, refiner_vae=True, **_):
@@ -191,7 +144,7 @@ class Encoder(nn.Module):
self.refiner_vae = refiner_vae
if self.refiner_vae:
- conv_op = NoPadConv3d
+ conv_op = CarriedConv3d
norm_op = RMS_norm
else:
conv_op = ops.Conv3d
@@ -206,9 +159,10 @@ class Encoder(nn.Module):
for i, tgt in enumerate(block_out_channels):
stage = nn.Module()
- stage.block = nn.ModuleList([HunyuanRefinerResnetBlock(in_channels=ch if j == 0 else tgt,
- out_channels=tgt,
- conv_op=conv_op, norm_op=norm_op)
+ stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt,
+ out_channels=tgt,
+ temb_channels=0,
+ conv_op=conv_op, norm_op=norm_op)
for j in range(num_res_blocks)])
ch = tgt
if i < depth:
@@ -218,9 +172,9 @@ class Encoder(nn.Module):
self.down.append(stage)
self.mid = nn.Module()
- self.mid.block_1 = HunyuanRefinerResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)
+ self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)
self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=norm_op)
- self.mid.block_2 = HunyuanRefinerResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)
+ self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)
self.norm_out = norm_op(ch)
self.conv_out = conv_op(ch, z_channels << 1, 3, 1, 1)
@@ -246,22 +200,20 @@ class Encoder(nn.Module):
conv_carry_out = []
if i == len(x) - 1:
conv_carry_out = None
+
x1 = [ x1 ]
x1 = conv_carry_causal_3d(x1, self.conv_in, conv_carry_in, conv_carry_out)
for stage in self.down:
for blk in stage.block:
- x1 = blk(x1, conv_carry_in, conv_carry_out)
+ x1 = blk(x1, None, conv_carry_in, conv_carry_out)
if hasattr(stage, 'downsample'):
x1 = stage.downsample(x1, conv_carry_in, conv_carry_out)
out.append(x1)
conv_carry_in = conv_carry_out
- if len(out) > 1:
- out = torch.cat(out, dim=2)
- else:
- out = out[0]
+ out = torch_cat_if_needed(out, dim=2)
x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(out)))
del out
@@ -288,7 +240,7 @@ class Decoder(nn.Module):
self.refiner_vae = refiner_vae
if self.refiner_vae:
- conv_op = NoPadConv3d
+ conv_op = CarriedConv3d
norm_op = RMS_norm
else:
conv_op = ops.Conv3d
@@ -298,9 +250,9 @@ class Decoder(nn.Module):
self.conv_in = conv_op(z_channels, ch, kernel_size=3, stride=1, padding=1)
self.mid = nn.Module()
- self.mid.block_1 = HunyuanRefinerResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)
+ self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)
self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=norm_op)
- self.mid.block_2 = HunyuanRefinerResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)
+ self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)
self.up = nn.ModuleList()
depth = (ffactor_spatial >> 1).bit_length()
@@ -308,9 +260,10 @@ class Decoder(nn.Module):
for i, tgt in enumerate(block_out_channels):
stage = nn.Module()
- stage.block = nn.ModuleList([HunyuanRefinerResnetBlock(in_channels=ch if j == 0 else tgt,
- out_channels=tgt,
- conv_op=conv_op, norm_op=norm_op)
+ stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt,
+ out_channels=tgt,
+ temb_channels=0,
+ conv_op=conv_op, norm_op=norm_op)
for j in range(num_res_blocks + 1)])
ch = tgt
if i < depth:
@@ -340,7 +293,7 @@ class Decoder(nn.Module):
conv_carry_out = None
for stage in self.up:
for blk in stage.block:
- x1 = blk(x1, conv_carry_in, conv_carry_out)
+ x1 = blk(x1, None, conv_carry_in, conv_carry_out)
if hasattr(stage, 'upsample'):
x1 = stage.upsample(x1, conv_carry_in, conv_carry_out)
@@ -350,10 +303,7 @@ class Decoder(nn.Module):
conv_carry_in = conv_carry_out
del x
- if len(out) > 1:
- out = torch.cat(out, dim=2)
- else:
- out = out[0]
+ out = torch_cat_if_needed(out, dim=2)
if not self.refiner_vae:
if z.shape[-3] == 1:
diff --git a/comfy/ldm/modules/diffusionmodules/model.py b/comfy/ldm/modules/diffusionmodules/model.py
index de1e01cc8..681a55db5 100644
--- a/comfy/ldm/modules/diffusionmodules/model.py
+++ b/comfy/ldm/modules/diffusionmodules/model.py
@@ -13,6 +13,12 @@ if model_management.xformers_enabled_vae():
import xformers
import xformers.ops
+def torch_cat_if_needed(xl, dim):
+ if len(xl) > 1:
+ return torch.cat(xl, dim)
+ else:
+ return xl[0]
+
def get_timestep_embedding(timesteps, embedding_dim):
"""
This matches the implementation in Denoising Diffusion Probabilistic Models:
@@ -43,6 +49,37 @@ def Normalize(in_channels, num_groups=32):
return ops.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
+class CarriedConv3d(nn.Module):
+ def __init__(self, n_channels, out_channels, kernel_size, stride=1, dilation=1, padding=0, **kwargs):
+ super().__init__()
+ self.conv = ops.Conv3d(n_channels, out_channels, kernel_size, stride=stride, dilation=dilation, **kwargs)
+
+ def forward(self, x):
+ return self.conv(x)
+
+
+def conv_carry_causal_3d(xl, op, conv_carry_in=None, conv_carry_out=None):
+
+ x = xl[0]
+ xl.clear()
+
+ if isinstance(op, CarriedConv3d):
+ if conv_carry_in is None:
+ x = torch.nn.functional.pad(x, (1, 1, 1, 1, 2, 0), mode = 'replicate')
+ else:
+ carry_len = conv_carry_in[0].shape[2]
+ x = torch.nn.functional.pad(x, (1, 1, 1, 1, 2 - carry_len, 0), mode = 'replicate')
+ x = torch.cat([conv_carry_in.pop(0), x], dim=2)
+
+ if conv_carry_out is not None:
+ to_push = x[:, :, -2:, :, :].clone()
+ conv_carry_out.append(to_push)
+
+ out = op(x)
+
+ return out
+
+
class VideoConv3d(nn.Module):
def __init__(self, n_channels, out_channels, kernel_size, stride=1, dilation=1, padding_mode='replicate', padding=1, **kwargs):
super().__init__()
@@ -89,29 +126,24 @@ class Upsample(nn.Module):
stride=1,
padding=1)
- def forward(self, x):
+ def forward(self, x, conv_carry_in=None, conv_carry_out=None):
scale_factor = self.scale_factor
if isinstance(scale_factor, (int, float)):
scale_factor = (scale_factor,) * (x.ndim - 2)
if x.ndim == 5 and scale_factor[0] > 1.0:
- t = x.shape[2]
- if t > 1:
- a, b = x.split((1, t - 1), dim=2)
- del x
- b = interpolate_up(b, scale_factor)
- else:
- a = x
-
- a = interpolate_up(a.squeeze(2), scale_factor=scale_factor[1:]).unsqueeze(2)
- if t > 1:
- x = torch.cat((a, b), dim=2)
- else:
- x = a
+ results = []
+ if conv_carry_in is None:
+ first = x[:, :, :1, :, :]
+ results.append(interpolate_up(first.squeeze(2), scale_factor=scale_factor[1:]).unsqueeze(2))
+ x = x[:, :, 1:, :, :]
+ if x.shape[2] > 0:
+ results.append(interpolate_up(x, scale_factor))
+ x = torch_cat_if_needed(results, dim=2)
else:
x = interpolate_up(x, scale_factor)
if self.with_conv:
- x = self.conv(x)
+ x = conv_carry_causal_3d([x], self.conv, conv_carry_in, conv_carry_out)
return x
@@ -127,17 +159,20 @@ class Downsample(nn.Module):
stride=stride,
padding=0)
- def forward(self, x):
+ def forward(self, x, conv_carry_in=None, conv_carry_out=None):
if self.with_conv:
- if x.ndim == 4:
+ if isinstance(self.conv, CarriedConv3d):
+ x = conv_carry_causal_3d([x], self.conv, conv_carry_in, conv_carry_out)
+ elif x.ndim == 4:
pad = (0, 1, 0, 1)
mode = "constant"
x = torch.nn.functional.pad(x, pad, mode=mode, value=0)
+ x = self.conv(x)
elif x.ndim == 5:
pad = (1, 1, 1, 1, 2, 0)
mode = "replicate"
x = torch.nn.functional.pad(x, pad, mode=mode)
- x = self.conv(x)
+ x = self.conv(x)
else:
x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
return x
@@ -183,23 +218,23 @@ class ResnetBlock(nn.Module):
stride=1,
padding=0)
- def forward(self, x, temb=None):
+ def forward(self, x, temb=None, conv_carry_in=None, conv_carry_out=None):
h = x
h = self.norm1(h)
- h = self.swish(h)
- h = self.conv1(h)
+ h = [ self.swish(h) ]
+ h = conv_carry_causal_3d(h, self.conv1, conv_carry_in=conv_carry_in, conv_carry_out=conv_carry_out)
if temb is not None:
h = h + self.temb_proj(self.swish(temb))[:,:,None,None]
h = self.norm2(h)
h = self.swish(h)
- h = self.dropout(h)
- h = self.conv2(h)
+ h = [ self.dropout(h) ]
+ h = conv_carry_causal_3d(h, self.conv2, conv_carry_in=conv_carry_in, conv_carry_out=conv_carry_out)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
- x = self.conv_shortcut(x)
+ x = conv_carry_causal_3d([x], self.conv_shortcut, conv_carry_in=conv_carry_in, conv_carry_out=conv_carry_out)
else:
x = self.nin_shortcut(x)
@@ -520,9 +555,14 @@ class Encoder(nn.Module):
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.in_channels = in_channels
+ self.carried = False
if conv3d:
- conv_op = VideoConv3d
+ if not attn_resolutions:
+ conv_op = CarriedConv3d
+ self.carried = True
+ else:
+ conv_op = VideoConv3d
mid_attn_conv_op = ops.Conv3d
else:
conv_op = ops.Conv2d
@@ -535,6 +575,7 @@ class Encoder(nn.Module):
stride=1,
padding=1)
+ self.time_compress = 1
curr_res = resolution
in_ch_mult = (1,)+tuple(ch_mult)
self.in_ch_mult = in_ch_mult
@@ -561,10 +602,15 @@ class Encoder(nn.Module):
if time_compress is not None:
if (self.num_resolutions - 1 - i_level) > math.log2(time_compress):
stride = (1, 2, 2)
+ else:
+ self.time_compress *= 2
down.downsample = Downsample(block_in, resamp_with_conv, stride=stride, conv_op=conv_op)
curr_res = curr_res // 2
self.down.append(down)
+ if time_compress is not None:
+ self.time_compress = time_compress
+
# middle
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(in_channels=block_in,
@@ -590,15 +636,42 @@ class Encoder(nn.Module):
def forward(self, x):
# timestep embedding
temb = None
- # downsampling
- h = self.conv_in(x)
- for i_level in range(self.num_resolutions):
- for i_block in range(self.num_res_blocks):
- h = self.down[i_level].block[i_block](h, temb)
- if len(self.down[i_level].attn) > 0:
- h = self.down[i_level].attn[i_block](h)
- if i_level != self.num_resolutions-1:
- h = self.down[i_level].downsample(h)
+
+ if self.carried:
+ xl = [x[:, :, :1, :, :]]
+ if x.shape[2] > self.time_compress:
+ tc = self.time_compress
+ xl += torch.split(x[:, :, 1: 1 + ((x.shape[2] - 1) // tc) * tc, :, :], tc * 2, dim = 2)
+ x = xl
+ else:
+ x = [x]
+ out = []
+
+ conv_carry_in = None
+
+ for i, x1 in enumerate(x):
+ conv_carry_out = []
+ if i == len(x) - 1:
+ conv_carry_out = None
+
+ # downsampling
+ x1 = [ x1 ]
+ h1 = conv_carry_causal_3d(x1, self.conv_in, conv_carry_in, conv_carry_out)
+
+ for i_level in range(self.num_resolutions):
+ for i_block in range(self.num_res_blocks):
+ h1 = self.down[i_level].block[i_block](h1, temb, conv_carry_in, conv_carry_out)
+ if len(self.down[i_level].attn) > 0:
+ assert i == 0 #carried should not happen if attn exists
+ h1 = self.down[i_level].attn[i_block](h1)
+ if i_level != self.num_resolutions-1:
+ h1 = self.down[i_level].downsample(h1, conv_carry_in, conv_carry_out)
+
+ out.append(h1)
+ conv_carry_in = conv_carry_out
+
+ h = torch_cat_if_needed(out, dim=2)
+ del out
# middle
h = self.mid.block_1(h, temb)
@@ -607,15 +680,15 @@ class Encoder(nn.Module):
# end
h = self.norm_out(h)
- h = nonlinearity(h)
- h = self.conv_out(h)
+ h = [ nonlinearity(h) ]
+ h = conv_carry_causal_3d(h, self.conv_out)
return h
class Decoder(nn.Module):
def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
- resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False,
+ resolution, z_channels, tanh_out=False, use_linear_attn=False,
conv_out_op=ops.Conv2d,
resnet_op=ResnetBlock,
attn_op=AttnBlock,
@@ -629,12 +702,18 @@ class Decoder(nn.Module):
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.in_channels = in_channels
- self.give_pre_end = give_pre_end
self.tanh_out = tanh_out
+ self.carried = False
if conv3d:
- conv_op = VideoConv3d
- conv_out_op = VideoConv3d
+ if not attn_resolutions and resnet_op == ResnetBlock:
+ conv_op = CarriedConv3d
+ conv_out_op = CarriedConv3d
+ self.carried = True
+ else:
+ conv_op = VideoConv3d
+ conv_out_op = VideoConv3d
+
mid_attn_conv_op = ops.Conv3d
else:
conv_op = ops.Conv2d
@@ -709,29 +788,43 @@ class Decoder(nn.Module):
temb = None
# z to block_in
- h = self.conv_in(z)
+ h = conv_carry_causal_3d([z], self.conv_in)
# middle
h = self.mid.block_1(h, temb, **kwargs)
h = self.mid.attn_1(h, **kwargs)
h = self.mid.block_2(h, temb, **kwargs)
+ if self.carried:
+ h = torch.split(h, 2, dim=2)
+ else:
+ h = [ h ]
+ out = []
+
+ conv_carry_in = None
+
# upsampling
- for i_level in reversed(range(self.num_resolutions)):
- for i_block in range(self.num_res_blocks+1):
- h = self.up[i_level].block[i_block](h, temb, **kwargs)
- if len(self.up[i_level].attn) > 0:
- h = self.up[i_level].attn[i_block](h, **kwargs)
- if i_level != 0:
- h = self.up[i_level].upsample(h)
+ for i, h1 in enumerate(h):
+ conv_carry_out = []
+ if i == len(h) - 1:
+ conv_carry_out = None
+ for i_level in reversed(range(self.num_resolutions)):
+ for i_block in range(self.num_res_blocks+1):
+ h1 = self.up[i_level].block[i_block](h1, temb, conv_carry_in, conv_carry_out, **kwargs)
+ if len(self.up[i_level].attn) > 0:
+ assert i == 0 #carried should not happen if attn exists
+ h1 = self.up[i_level].attn[i_block](h1, **kwargs)
+ if i_level != 0:
+ h1 = self.up[i_level].upsample(h1, conv_carry_in, conv_carry_out)
- # end
- if self.give_pre_end:
- return h
+ h1 = self.norm_out(h1)
+ h1 = [ nonlinearity(h1) ]
+ h1 = conv_carry_causal_3d(h1, self.conv_out, conv_carry_in, conv_carry_out)
+ if self.tanh_out:
+ h1 = torch.tanh(h1)
+ out.append(h1)
+ conv_carry_in = conv_carry_out
- h = self.norm_out(h)
- h = nonlinearity(h)
- h = self.conv_out(h, **kwargs)
- if self.tanh_out:
- h = torch.tanh(h)
- return h
+ out = torch_cat_if_needed(out, dim=2)
+
+ return out
From c120eee5bacca643062657d2a7efad83c7d4d828 Mon Sep 17 00:00:00 2001
From: Jedrzej Kosinski
Date: Tue, 2 Dec 2025 21:17:13 -0800
Subject: [PATCH 15/81] Add MatchType, DynamicCombo, and Autogrow support to V3
Schema (#10832)
* Added output_matchtypes to generated json for v3, initial backend support for MatchType, created nodes_logic.py and added SwitchNode
* Fixed providing list of allowed_types
* Add workaround in validation.py for V3 Combo outputs not working as Combo inputs
* Make match type receive_type pass validation
* Also add MatchType check to input_type in validation - will likely trigger when connecting to non-lazy stuff
* Make sure this PR only has MatchType stuff
* Initial work on DynamicCombo
* Add get_dynamic function, not yet filled out correctly
* Mark Switch node as Beta
* Make sure other unfinished dynamic types are not accidentally used
* Send DynamicCombo.Option inputs in the same format as normal v1 inputs
* add dynamic combo test node
* Support validation of inputs and outputs
* Add missing input params to DynamicCombo.Input
* Add get_all function to inputs for id validation purposes
* Fix imports for v3 returning everything when doing io/ui/IO/UI instead of what is in __all__ of _io.py and _ui.py
* Modifying behavior of get_dynamic in V3 + serialization so can be used in execution code
* Fix v3 schema validation code after changes
* Refactor hidden_values for v3 in execution.py to be more general v3_data, add helper functions for dynamic behavior, preparing for restructuring dynamic type into object (not finished yet)
* Add nesting of inputs on DynamicCombo during execution
* Work with latest frontend commits
* Fix cringe arrows
* frontend will no longer namespace dynamic inputs widgets so reflect that in code, refactor build_nested_inputs
* Prepare Autogrow support for the love of the game
* satisfy ruff
* Create test nodes for Autogrow to collab with frontend development
* Add nested combo to DCTestNode
* Remove array support from build_nested_inputs, properly handle missing expected values
* Make execution.validate_inputs properly validate required dynamic inputs, renamed dynamic_data to dynamic_paths for clarity
* MatchType does not need any DynamicInput/Output features on backend; will increase compatibility with dynamic types
* Probably need this for ruff check
* Change MatchType to have template be the first and only required param; output id's do nothing right now, so no need
* Fix merge regression with LatentUpscaleModel type not being put in __all__ for _io.py, fix invalid type hint for validate_inputs
* Make Switch node inputs optional, disallow both inputs from being missing, and still work properly with lazy; when one input is missing, use the other no matter what the switch is set to
* Satisfy ruff
* Move MatchType code above the types that inherit from DynamicInput
* Add DynamicSlot type, awaiting frontend support
* Make curr_prefix creation happen in Autogrow, move curr_prefix in DynamicCombo to only be created if input exists in live_inputs
* I was confused, fixing accidentally redundant curr_prefix addition in Autogrow
* Make sure Autogrow inputs are force_input = True when WidgetInput, fix runtime validation by removing original input from expected inputs, fix min/max bounds, change test nodes slightly
* Remove unnecessary id usage in Autogrow test node outputs
* Commented out Switch node + test nodes
* Remove commented out code from Autogrow
* Make TemplatePrefix max more clear, allow max == 1
* Replace all dict[str] with dict[str, Any]
* Renamed add_to_dict_live_inputs to expand_schema_for_dynamic
* Fixed typo in DynamicSlot input code
* note about live_inputs not being present soon in get_v1_info (internal function anyway)
* For now, hide DynamicCombo and Autogrow from public interface
* Removed comment
---
comfy_api/latest/__init__.py | 4 +-
comfy_api/latest/_io.py | 416 ++++++++++++++++++++++++++-------
comfy_api/latest/_io_public.py | 1 +
comfy_api/latest/_ui_public.py | 1 +
comfy_api/v0_0_2/__init__.py | 6 +-
comfy_execution/validation.py | 6 +
comfy_extras/nodes_logic.py | 155 ++++++++++++
execution.py | 40 ++--
nodes.py | 1 +
9 files changed, 525 insertions(+), 105 deletions(-)
create mode 100644 comfy_api/latest/_io_public.py
create mode 100644 comfy_api/latest/_ui_public.py
create mode 100644 comfy_extras/nodes_logic.py
diff --git a/comfy_api/latest/__init__.py b/comfy_api/latest/__init__.py
index 176ae36e0..0fa01d1e7 100644
--- a/comfy_api/latest/__init__.py
+++ b/comfy_api/latest/__init__.py
@@ -8,8 +8,8 @@ from comfy_api.internal.async_to_sync import create_sync_class
from comfy_api.latest._input import ImageInput, AudioInput, MaskInput, LatentInput, VideoInput
from comfy_api.latest._input_impl import VideoFromFile, VideoFromComponents
from comfy_api.latest._util import VideoCodec, VideoContainer, VideoComponents, MESH, VOXEL
-from . import _io as io
-from . import _ui as ui
+from . import _io_public as io
+from . import _ui_public as ui
# from comfy_api.latest._resources import _RESOURCES as resources #noqa: F401
from comfy_execution.utils import get_executing_context
from comfy_execution.progress import get_progress_state, PreviewImageTuple
diff --git a/comfy_api/latest/_io.py b/comfy_api/latest/_io.py
index 79c0722a9..257f07c42 100644
--- a/comfy_api/latest/_io.py
+++ b/comfy_api/latest/_io.py
@@ -4,6 +4,7 @@ import copy
import inspect
from abc import ABC, abstractmethod
from collections import Counter
+from collections.abc import Iterable
from dataclasses import asdict, dataclass
from enum import Enum
from typing import Any, Callable, Literal, TypedDict, TypeVar, TYPE_CHECKING
@@ -150,6 +151,9 @@ class _IO_V3:
def __init__(self):
pass
+ def validate(self):
+ pass
+
@property
def io_type(self):
return self.Parent.io_type
@@ -182,6 +186,9 @@ class Input(_IO_V3):
def get_io_type(self):
return _StringIOType(self.io_type)
+ def get_all(self) -> list[Input]:
+ return [self]
+
class WidgetInput(Input):
'''
Base class for a V3 Input with widget.
@@ -814,13 +821,61 @@ class MultiType:
else:
return super().as_dict()
+@comfytype(io_type="COMFY_MATCHTYPE_V3")
+class MatchType(ComfyTypeIO):
+ class Template:
+ def __init__(self, template_id: str, allowed_types: _ComfyType | list[_ComfyType] = AnyType):
+ self.template_id = template_id
+ # account for syntactic sugar
+ if not isinstance(allowed_types, Iterable):
+ allowed_types = [allowed_types]
+ for t in allowed_types:
+ if not isinstance(t, type):
+ if not isinstance(t, _ComfyType):
+ raise ValueError(f"Allowed types must be a ComfyType or a list of ComfyTypes, got {t.__class__.__name__}")
+ else:
+ if not issubclass(t, _ComfyType):
+ raise ValueError(f"Allowed types must be a ComfyType or a list of ComfyTypes, got {t.__name__}")
+ self.allowed_types = allowed_types
+
+ def as_dict(self):
+ return {
+ "template_id": self.template_id,
+ "allowed_types": ",".join([t.io_type for t in self.allowed_types]),
+ }
+
+ class Input(Input):
+ def __init__(self, id: str, template: MatchType.Template,
+ display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None):
+ super().__init__(id, display_name, optional, tooltip, lazy, extra_dict)
+ self.template = template
+
+ def as_dict(self):
+ return super().as_dict() | prune_dict({
+ "template": self.template.as_dict(),
+ })
+
+ class Output(Output):
+ def __init__(self, template: MatchType.Template, id: str=None, display_name: str=None, tooltip: str=None,
+ is_output_list=False):
+ super().__init__(id, display_name, tooltip, is_output_list)
+ self.template = template
+
+ def as_dict(self):
+ return super().as_dict() | prune_dict({
+ "template": self.template.as_dict(),
+ })
+
class DynamicInput(Input, ABC):
'''
Abstract class for dynamic input registration.
'''
- @abstractmethod
def get_dynamic(self) -> list[Input]:
- ...
+ return []
+
+ def expand_schema_for_dynamic(self, d: dict[str, Any], live_inputs: dict[str, Any], curr_prefix=''):
+ pass
+
class DynamicOutput(Output, ABC):
'''
@@ -830,99 +885,223 @@ class DynamicOutput(Output, ABC):
is_output_list=False):
super().__init__(id, display_name, tooltip, is_output_list)
- @abstractmethod
def get_dynamic(self) -> list[Output]:
- ...
+ return []
@comfytype(io_type="COMFY_AUTOGROW_V3")
-class AutogrowDynamic(ComfyTypeI):
- Type = list[Any]
- class Input(DynamicInput):
- def __init__(self, id: str, template_input: Input, min: int=1, max: int=None,
- display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None):
- super().__init__(id, display_name, optional, tooltip, lazy, extra_dict)
- self.template_input = template_input
- if min is not None:
- assert(min >= 1)
- if max is not None:
- assert(max >= 1)
+class Autogrow(ComfyTypeI):
+ Type = dict[str, Any]
+ _MaxNames = 100 # NOTE: max 100 names for sanity
+
+ class _AutogrowTemplate:
+ def __init__(self, input: Input):
+ # dynamic inputs are not allowed as the template input
+ assert(not isinstance(input, DynamicInput))
+ self.input = copy.copy(input)
+ if isinstance(self.input, WidgetInput):
+ self.input.force_input = True
+ self.names: list[str] = []
+ self.cached_inputs = {}
+
+ def _create_input(self, input: Input, name: str):
+ new_input = copy.copy(self.input)
+ new_input.id = name
+ return new_input
+
+ def _create_cached_inputs(self):
+ for name in self.names:
+ self.cached_inputs[name] = self._create_input(self.input, name)
+
+ def get_all(self) -> list[Input]:
+ return list(self.cached_inputs.values())
+
+ def as_dict(self):
+ return prune_dict({
+ "input": create_input_dict_v1([self.input]),
+ })
+
+ def validate(self):
+ self.input.validate()
+
+ def expand_schema_for_dynamic(self, d: dict[str, Any], live_inputs: dict[str, Any], curr_prefix=''):
+ real_inputs = []
+ for name, input in self.cached_inputs.items():
+ if name in live_inputs:
+ real_inputs.append(input)
+ add_to_input_dict_v1(d, real_inputs, live_inputs, curr_prefix)
+ add_dynamic_id_mapping(d, real_inputs, curr_prefix)
+
+ class TemplatePrefix(_AutogrowTemplate):
+ def __init__(self, input: Input, prefix: str, min: int=1, max: int=10):
+ super().__init__(input)
+ self.prefix = prefix
+ assert(min >= 0)
+ assert(max >= 1)
+ assert(max <= Autogrow._MaxNames)
self.min = min
self.max = max
+ self.names = [f"{self.prefix}{i}" for i in range(self.max)]
+ self._create_cached_inputs()
+
+ def as_dict(self):
+ return super().as_dict() | prune_dict({
+ "prefix": self.prefix,
+ "min": self.min,
+ "max": self.max,
+ })
+
+ class TemplateNames(_AutogrowTemplate):
+ def __init__(self, input: Input, names: list[str], min: int=1):
+ super().__init__(input)
+ self.names = names[:Autogrow._MaxNames]
+ assert(min >= 0)
+ self.min = min
+ self._create_cached_inputs()
+
+ def as_dict(self):
+ return super().as_dict() | prune_dict({
+ "names": self.names,
+ "min": self.min,
+ })
+
+ class Input(DynamicInput):
+ def __init__(self, id: str, template: Autogrow.TemplatePrefix | Autogrow.TemplateNames,
+ display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None):
+ super().__init__(id, display_name, optional, tooltip, lazy, extra_dict)
+ self.template = template
+
+ def as_dict(self):
+ return super().as_dict() | prune_dict({
+ "template": self.template.as_dict(),
+ })
def get_dynamic(self) -> list[Input]:
- curr_count = 1
- new_inputs = []
- for i in range(self.min):
- new_input = copy.copy(self.template_input)
- new_input.id = f"{new_input.id}{curr_count}_${self.id}_ag$"
- if new_input.display_name is not None:
- new_input.display_name = f"{new_input.display_name}{curr_count}"
- new_input.optional = self.optional or new_input.optional
- if isinstance(self.template_input, WidgetInput):
- new_input.force_input = True
- new_inputs.append(new_input)
- curr_count += 1
- # pretend to expand up to max
- for i in range(curr_count-1, self.max):
- new_input = copy.copy(self.template_input)
- new_input.id = f"{new_input.id}{curr_count}_${self.id}_ag$"
- if new_input.display_name is not None:
- new_input.display_name = f"{new_input.display_name}{curr_count}"
- new_input.optional = True
- if isinstance(self.template_input, WidgetInput):
- new_input.force_input = True
- new_inputs.append(new_input)
- curr_count += 1
- return new_inputs
+ return self.template.get_all()
-@comfytype(io_type="COMFY_COMBODYNAMIC_V3")
-class ComboDynamic(ComfyTypeI):
- class Input(DynamicInput):
- def __init__(self, id: str):
- pass
+ def get_all(self) -> list[Input]:
+ return [self] + self.template.get_all()
-@comfytype(io_type="COMFY_MATCHTYPE_V3")
-class MatchType(ComfyTypeIO):
- class Template:
- def __init__(self, template_id: str, allowed_types: _ComfyType | list[_ComfyType]):
- self.template_id = template_id
- self.allowed_types = [allowed_types] if isinstance(allowed_types, _ComfyType) else allowed_types
+ def validate(self):
+ self.template.validate()
+
+ def expand_schema_for_dynamic(self, d: dict[str, Any], live_inputs: dict[str, Any], curr_prefix=''):
+ curr_prefix = f"{curr_prefix}{self.id}."
+ # need to remove self from expected inputs dictionary; replaced by template inputs in frontend
+ for inner_dict in d.values():
+ if self.id in inner_dict:
+ del inner_dict[self.id]
+ self.template.expand_schema_for_dynamic(d, live_inputs, curr_prefix)
+
+@comfytype(io_type="COMFY_DYNAMICCOMBO_V3")
+class DynamicCombo(ComfyTypeI):
+ Type = dict[str, Any]
+
+ class Option:
+ def __init__(self, key: str, inputs: list[Input]):
+ self.key = key
+ self.inputs = inputs
def as_dict(self):
return {
- "template_id": self.template_id,
- "allowed_types": "".join(t.io_type for t in self.allowed_types),
+ "key": self.key,
+ "inputs": create_input_dict_v1(self.inputs),
}
class Input(DynamicInput):
- def __init__(self, id: str, template: MatchType.Template,
+ def __init__(self, id: str, options: list[DynamicCombo.Option],
display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None):
super().__init__(id, display_name, optional, tooltip, lazy, extra_dict)
- self.template = template
+ self.options = options
+
+ def expand_schema_for_dynamic(self, d: dict[str, Any], live_inputs: dict[str, Any], curr_prefix=''):
+ # check if dynamic input's id is in live_inputs
+ if self.id in live_inputs:
+ curr_prefix = f"{curr_prefix}{self.id}."
+ key = live_inputs[self.id]
+ selected_option = None
+ for option in self.options:
+ if option.key == key:
+ selected_option = option
+ break
+ if selected_option is not None:
+ add_to_input_dict_v1(d, selected_option.inputs, live_inputs, curr_prefix)
+ add_dynamic_id_mapping(d, selected_option.inputs, curr_prefix, self)
def get_dynamic(self) -> list[Input]:
- return [self]
+ return [input for option in self.options for input in option.inputs]
+
+ def get_all(self) -> list[Input]:
+ return [self] + [input for option in self.options for input in option.inputs]
def as_dict(self):
return super().as_dict() | prune_dict({
- "template": self.template.as_dict(),
+ "options": [o.as_dict() for o in self.options],
})
- class Output(DynamicOutput):
- def __init__(self, id: str, template: MatchType.Template, display_name: str=None, tooltip: str=None,
- is_output_list=False):
- super().__init__(id, display_name, tooltip, is_output_list)
- self.template = template
+ def validate(self):
+ # make sure all nested inputs are validated
+ for option in self.options:
+ for input in option.inputs:
+ input.validate()
- def get_dynamic(self) -> list[Output]:
- return [self]
+@comfytype(io_type="COMFY_DYNAMICSLOT_V3")
+class DynamicSlot(ComfyTypeI):
+ Type = dict[str, Any]
+
+ class Input(DynamicInput):
+ def __init__(self, slot: Input, inputs: list[Input],
+ display_name: str=None, tooltip: str=None, lazy: bool=None, extra_dict=None):
+ assert(not isinstance(slot, DynamicInput))
+ self.slot = copy.copy(slot)
+ self.slot.display_name = slot.display_name if slot.display_name is not None else display_name
+ optional = True
+ self.slot.tooltip = slot.tooltip if slot.tooltip is not None else tooltip
+ self.slot.lazy = slot.lazy if slot.lazy is not None else lazy
+ self.slot.extra_dict = slot.extra_dict if slot.extra_dict is not None else extra_dict
+ super().__init__(slot.id, self.slot.display_name, optional, self.slot.tooltip, self.slot.lazy, self.slot.extra_dict)
+ self.inputs = inputs
+ self.force_input = None
+ # force widget inputs to have no widgets, otherwise this would be awkward
+ if isinstance(self.slot, WidgetInput):
+ self.force_input = True
+ self.slot.force_input = True
+
+ def expand_schema_for_dynamic(self, d: dict[str, Any], live_inputs: dict[str, Any], curr_prefix=''):
+ if self.id in live_inputs:
+ curr_prefix = f"{curr_prefix}{self.id}."
+ add_to_input_dict_v1(d, self.inputs, live_inputs, curr_prefix)
+ add_dynamic_id_mapping(d, [self.slot] + self.inputs, curr_prefix)
+
+ def get_dynamic(self) -> list[Input]:
+ return [self.slot] + self.inputs
+
+ def get_all(self) -> list[Input]:
+ return [self] + [self.slot] + self.inputs
def as_dict(self):
return super().as_dict() | prune_dict({
- "template": self.template.as_dict(),
+ "slotType": str(self.slot.get_io_type()),
+ "inputs": create_input_dict_v1(self.inputs),
+ "forceInput": self.force_input,
})
+ def validate(self):
+ self.slot.validate()
+ for input in self.inputs:
+ input.validate()
+
+def add_dynamic_id_mapping(d: dict[str, Any], inputs: list[Input], curr_prefix: str, self: DynamicInput=None):
+ dynamic = d.setdefault("dynamic_paths", {})
+ if self is not None:
+ dynamic[self.id] = f"{curr_prefix}{self.id}"
+ for i in inputs:
+ if not isinstance(i, DynamicInput):
+ dynamic[f"{i.id}"] = f"{curr_prefix}{i.id}"
+
+class V3Data(TypedDict):
+ hidden_inputs: dict[str, Any]
+ dynamic_paths: dict[str, Any]
class HiddenHolder:
def __init__(self, unique_id: str, prompt: Any,
@@ -984,6 +1163,7 @@ class NodeInfoV1:
output_is_list: list[bool]=None
output_name: list[str]=None
output_tooltips: list[str]=None
+ output_matchtypes: list[str]=None
name: str=None
display_name: str=None
description: str=None
@@ -1061,7 +1241,11 @@ class Schema:
'''Validate the schema:
- verify ids on inputs and outputs are unique - both internally and in relation to each other
'''
- input_ids = [i.id for i in self.inputs] if self.inputs is not None else []
+ nested_inputs: list[Input] = []
+ if self.inputs is not None:
+ for input in self.inputs:
+ nested_inputs.extend(input.get_all())
+ input_ids = [i.id for i in nested_inputs] if nested_inputs is not None else []
output_ids = [o.id for o in self.outputs] if self.outputs is not None else []
input_set = set(input_ids)
output_set = set(output_ids)
@@ -1077,6 +1261,13 @@ class Schema:
issues.append(f"Ids must be unique between inputs and outputs, but {intersection} are not.")
if len(issues) > 0:
raise ValueError("\n".join(issues))
+ # validate inputs and outputs
+ if self.inputs is not None:
+ for input in self.inputs:
+ input.validate()
+ if self.outputs is not None:
+ for output in self.outputs:
+ output.validate()
def finalize(self):
"""Add hidden based on selected schema options, and give outputs without ids default ids."""
@@ -1102,19 +1293,10 @@ class Schema:
if output.id is None:
output.id = f"_{i}_{output.io_type}_"
- def get_v1_info(self, cls) -> NodeInfoV1:
+ def get_v1_info(self, cls, live_inputs: dict[str, Any]=None) -> NodeInfoV1:
+ # NOTE: live_inputs will not be used anymore very soon and this will be done another way
# get V1 inputs
- input = {
- "required": {}
- }
- if self.inputs:
- for i in self.inputs:
- if isinstance(i, DynamicInput):
- dynamic_inputs = i.get_dynamic()
- for d in dynamic_inputs:
- add_to_dict_v1(d, input)
- else:
- add_to_dict_v1(i, input)
+ input = create_input_dict_v1(self.inputs, live_inputs)
if self.hidden:
for hidden in self.hidden:
input.setdefault("hidden", {})[hidden.name] = (hidden.value,)
@@ -1123,12 +1305,24 @@ class Schema:
output_is_list = []
output_name = []
output_tooltips = []
+ output_matchtypes = []
+ any_matchtypes = False
if self.outputs:
for o in self.outputs:
output.append(o.io_type)
output_is_list.append(o.is_output_list)
output_name.append(o.display_name if o.display_name else o.io_type)
output_tooltips.append(o.tooltip if o.tooltip else None)
+ # special handling for MatchType
+ if isinstance(o, MatchType.Output):
+ output_matchtypes.append(o.template.template_id)
+ any_matchtypes = True
+ else:
+ output_matchtypes.append(None)
+
+ # clear out lists that are all None
+ if not any_matchtypes:
+ output_matchtypes = None
info = NodeInfoV1(
input=input,
@@ -1137,6 +1331,7 @@ class Schema:
output_is_list=output_is_list,
output_name=output_name,
output_tooltips=output_tooltips,
+ output_matchtypes=output_matchtypes,
name=self.node_id,
display_name=self.display_name,
category=self.category,
@@ -1182,16 +1377,57 @@ class Schema:
return info
-def add_to_dict_v1(i: Input, input: dict):
+def create_input_dict_v1(inputs: list[Input], live_inputs: dict[str, Any]=None) -> dict:
+ input = {
+ "required": {}
+ }
+ add_to_input_dict_v1(input, inputs, live_inputs)
+ return input
+
+def add_to_input_dict_v1(d: dict[str, Any], inputs: list[Input], live_inputs: dict[str, Any]=None, curr_prefix=''):
+ for i in inputs:
+ if isinstance(i, DynamicInput):
+ add_to_dict_v1(i, d)
+ if live_inputs is not None:
+ i.expand_schema_for_dynamic(d, live_inputs, curr_prefix)
+ else:
+ add_to_dict_v1(i, d)
+
+def add_to_dict_v1(i: Input, d: dict, dynamic_dict: dict=None):
key = "optional" if i.optional else "required"
as_dict = i.as_dict()
# for v1, we don't want to include the optional key
as_dict.pop("optional", None)
- input.setdefault(key, {})[i.id] = (i.get_io_type(), as_dict)
+ if dynamic_dict is None:
+ value = (i.get_io_type(), as_dict)
+ else:
+ value = (i.get_io_type(), as_dict, dynamic_dict)
+ d.setdefault(key, {})[i.id] = value
def add_to_dict_v3(io: Input | Output, d: dict):
d[io.id] = (io.get_io_type(), io.as_dict())
+def build_nested_inputs(values: dict[str, Any], v3_data: V3Data):
+ paths = v3_data.get("dynamic_paths", None)
+ if paths is None:
+ return values
+ values = values.copy()
+ result = {}
+
+ for key, path in paths.items():
+ parts = path.split(".")
+ current = result
+
+ for i, p in enumerate(parts):
+ is_last = (i == len(parts) - 1)
+
+ if is_last:
+ current[p] = values.pop(key, None)
+ else:
+ current = current.setdefault(p, {})
+
+ values.update(result)
+ return values
class _ComfyNodeBaseInternal(_ComfyNodeInternal):
@@ -1311,12 +1547,12 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal):
@final
@classmethod
- def PREPARE_CLASS_CLONE(cls, hidden_inputs: dict) -> type[ComfyNode]:
+ def PREPARE_CLASS_CLONE(cls, v3_data: V3Data) -> type[ComfyNode]:
"""Creates clone of real node class to prevent monkey-patching."""
c_type: type[ComfyNode] = cls if is_class(cls) else type(cls)
type_clone: type[ComfyNode] = shallow_clone_class(c_type)
# set hidden
- type_clone.hidden = HiddenHolder.from_dict(hidden_inputs)
+ type_clone.hidden = HiddenHolder.from_dict(v3_data["hidden_inputs"])
return type_clone
@final
@@ -1433,14 +1669,18 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal):
@final
@classmethod
- def INPUT_TYPES(cls, include_hidden=True, return_schema=False) -> dict[str, dict] | tuple[dict[str, dict], Schema]:
+ def INPUT_TYPES(cls, include_hidden=True, return_schema=False, live_inputs=None) -> dict[str, dict] | tuple[dict[str, dict], Schema, V3Data]:
schema = cls.FINALIZE_SCHEMA()
- info = schema.get_v1_info(cls)
+ info = schema.get_v1_info(cls, live_inputs)
input = info.input
if not include_hidden:
input.pop("hidden", None)
if return_schema:
- return input, schema
+ v3_data: V3Data = {}
+ dynamic = input.pop("dynamic_paths", None)
+ if dynamic is not None:
+ v3_data["dynamic_paths"] = dynamic
+ return input, schema, v3_data
return input
@final
@@ -1513,7 +1753,7 @@ class ComfyNode(_ComfyNodeBaseInternal):
raise NotImplementedError
@classmethod
- def validate_inputs(cls, **kwargs) -> bool:
+ def validate_inputs(cls, **kwargs) -> bool | str:
"""Optionally, define this function to validate inputs; equivalent to V1's VALIDATE_INPUTS."""
raise NotImplementedError
@@ -1628,6 +1868,7 @@ __all__ = [
"StyleModel",
"Gligen",
"UpscaleModel",
+ "LatentUpscaleModel",
"Audio",
"Video",
"SVG",
@@ -1651,6 +1892,10 @@ __all__ = [
"SEGS",
"AnyType",
"MultiType",
+ # Dynamic Types
+ "MatchType",
+ # "DynamicCombo",
+ # "Autogrow",
# Other classes
"HiddenHolder",
"Hidden",
@@ -1661,4 +1906,5 @@ __all__ = [
"NodeOutput",
"add_to_dict_v1",
"add_to_dict_v3",
+ "V3Data",
]
diff --git a/comfy_api/latest/_io_public.py b/comfy_api/latest/_io_public.py
new file mode 100644
index 000000000..43c7680f3
--- /dev/null
+++ b/comfy_api/latest/_io_public.py
@@ -0,0 +1 @@
+from ._io import * # noqa: F403
diff --git a/comfy_api/latest/_ui_public.py b/comfy_api/latest/_ui_public.py
new file mode 100644
index 000000000..85b11d78b
--- /dev/null
+++ b/comfy_api/latest/_ui_public.py
@@ -0,0 +1 @@
+from ._ui import * # noqa: F403
diff --git a/comfy_api/v0_0_2/__init__.py b/comfy_api/v0_0_2/__init__.py
index de0f95001..c4fa1d971 100644
--- a/comfy_api/v0_0_2/__init__.py
+++ b/comfy_api/v0_0_2/__init__.py
@@ -6,7 +6,7 @@ from comfy_api.latest import (
)
from typing import Type, TYPE_CHECKING
from comfy_api.internal.async_to_sync import create_sync_class
-from comfy_api.latest import io, ui, ComfyExtension #noqa: F401
+from comfy_api.latest import io, ui, IO, UI, ComfyExtension #noqa: F401
class ComfyAPIAdapter_v0_0_2(ComfyAPI_latest):
@@ -42,4 +42,8 @@ __all__ = [
"InputImpl",
"Types",
"ComfyExtension",
+ "io",
+ "IO",
+ "ui",
+ "UI",
]
diff --git a/comfy_execution/validation.py b/comfy_execution/validation.py
index cec105fc9..24c0b4ed7 100644
--- a/comfy_execution/validation.py
+++ b/comfy_execution/validation.py
@@ -1,4 +1,5 @@
from __future__ import annotations
+from comfy_api.latest import IO
def validate_node_input(
@@ -23,6 +24,11 @@ def validate_node_input(
if not received_type != input_type:
return True
+ # If the received type or input_type is a MatchType, we can return True immediately;
+ # validation for this is handled by the frontend
+ if received_type == IO.MatchType.io_type or input_type == IO.MatchType.io_type:
+ return True
+
# Not equal, and not strings
if not isinstance(received_type, str) or not isinstance(input_type, str):
return False
diff --git a/comfy_extras/nodes_logic.py b/comfy_extras/nodes_logic.py
new file mode 100644
index 000000000..95a6ba788
--- /dev/null
+++ b/comfy_extras/nodes_logic.py
@@ -0,0 +1,155 @@
+from typing import TypedDict
+from typing_extensions import override
+from comfy_api.latest import ComfyExtension, io
+from comfy_api.latest import _io
+
+
+
+class SwitchNode(io.ComfyNode):
+ @classmethod
+ def define_schema(cls):
+ template = io.MatchType.Template("switch")
+ return io.Schema(
+ node_id="ComfySwitchNode",
+ display_name="Switch",
+ category="logic",
+ is_experimental=True,
+ inputs=[
+ io.Boolean.Input("switch"),
+ io.MatchType.Input("on_false", template=template, lazy=True, optional=True),
+ io.MatchType.Input("on_true", template=template, lazy=True, optional=True),
+ ],
+ outputs=[
+ io.MatchType.Output(template=template, display_name="output"),
+ ],
+ )
+
+ @classmethod
+ def check_lazy_status(cls, switch, on_false=..., on_true=...):
+ # We use ... instead of None, as None is passed for connected-but-unevaluated inputs.
+ # This trick allows us to ignore the value of the switch and still be able to run execute().
+
+ # One of the inputs may be missing, in which case we need to evaluate the other input
+ if on_false is ...:
+ return ["on_true"]
+ if on_true is ...:
+ return ["on_false"]
+ # Normal lazy switch operation
+ if switch and on_true is None:
+ return ["on_true"]
+ if not switch and on_false is None:
+ return ["on_false"]
+
+ @classmethod
+ def validate_inputs(cls, switch, on_false=..., on_true=...):
+ # This check happens before check_lazy_status(), so we can eliminate the case where
+ # both inputs are missing.
+ if on_false is ... and on_true is ...:
+ return "At least one of on_false or on_true must be connected to Switch node"
+ return True
+
+ @classmethod
+ def execute(cls, switch, on_true=..., on_false=...) -> io.NodeOutput:
+ if on_true is ...:
+ return io.NodeOutput(on_false)
+ if on_false is ...:
+ return io.NodeOutput(on_true)
+ return io.NodeOutput(on_true if switch else on_false)
+
+
+class DCTestNode(io.ComfyNode):
+ class DCValues(TypedDict):
+ combo: str
+ string: str
+ integer: int
+ image: io.Image.Type
+ subcombo: dict[str]
+
+ @classmethod
+ def define_schema(cls):
+ return io.Schema(
+ node_id="DCTestNode",
+ display_name="DCTest",
+ category="logic",
+ is_output_node=True,
+ inputs=[_io.DynamicCombo.Input("combo", options=[
+ _io.DynamicCombo.Option("option1", [io.String.Input("string")]),
+ _io.DynamicCombo.Option("option2", [io.Int.Input("integer")]),
+ _io.DynamicCombo.Option("option3", [io.Image.Input("image")]),
+ _io.DynamicCombo.Option("option4", [
+ _io.DynamicCombo.Input("subcombo", options=[
+ _io.DynamicCombo.Option("opt1", [io.Float.Input("float_x"), io.Float.Input("float_y")]),
+ _io.DynamicCombo.Option("opt2", [io.Mask.Input("mask1", optional=True)]),
+ ])
+ ])]
+ )],
+ outputs=[io.AnyType.Output()],
+ )
+
+ @classmethod
+ def execute(cls, combo: DCValues) -> io.NodeOutput:
+ combo_val = combo["combo"]
+ if combo_val == "option1":
+ return io.NodeOutput(combo["string"])
+ elif combo_val == "option2":
+ return io.NodeOutput(combo["integer"])
+ elif combo_val == "option3":
+ return io.NodeOutput(combo["image"])
+ elif combo_val == "option4":
+ return io.NodeOutput(f"{combo['subcombo']}")
+ else:
+ raise ValueError(f"Invalid combo: {combo_val}")
+
+
+class AutogrowNamesTestNode(io.ComfyNode):
+ @classmethod
+ def define_schema(cls):
+ template = _io.Autogrow.TemplateNames(input=io.Float.Input("float"), names=["a", "b", "c"])
+ return io.Schema(
+ node_id="AutogrowNamesTestNode",
+ display_name="AutogrowNamesTest",
+ category="logic",
+ inputs=[
+ _io.Autogrow.Input("autogrow", template=template)
+ ],
+ outputs=[io.String.Output()],
+ )
+
+ @classmethod
+ def execute(cls, autogrow: _io.Autogrow.Type) -> io.NodeOutput:
+ vals = list(autogrow.values())
+ combined = ",".join([str(x) for x in vals])
+ return io.NodeOutput(combined)
+
+class AutogrowPrefixTestNode(io.ComfyNode):
+ @classmethod
+ def define_schema(cls):
+ template = _io.Autogrow.TemplatePrefix(input=io.Float.Input("float"), prefix="float", min=1, max=10)
+ return io.Schema(
+ node_id="AutogrowPrefixTestNode",
+ display_name="AutogrowPrefixTest",
+ category="logic",
+ inputs=[
+ _io.Autogrow.Input("autogrow", template=template)
+ ],
+ outputs=[io.String.Output()],
+ )
+
+ @classmethod
+ def execute(cls, autogrow: _io.Autogrow.Type) -> io.NodeOutput:
+ vals = list(autogrow.values())
+ combined = ",".join([str(x) for x in vals])
+ return io.NodeOutput(combined)
+
+class LogicExtension(ComfyExtension):
+ @override
+ async def get_node_list(self) -> list[type[io.ComfyNode]]:
+ return [
+ # SwitchNode,
+ # DCTestNode,
+ # AutogrowNamesTestNode,
+ # AutogrowPrefixTestNode,
+ ]
+
+async def comfy_entrypoint() -> LogicExtension:
+ return LogicExtension()
diff --git a/execution.py b/execution.py
index 17c77beab..c2186ac98 100644
--- a/execution.py
+++ b/execution.py
@@ -34,7 +34,7 @@ from comfy_execution.validation import validate_node_input
from comfy_execution.progress import get_progress_state, reset_progress_state, add_progress_handler, WebUIProgressHandler
from comfy_execution.utils import CurrentNodeContext
from comfy_api.internal import _ComfyNodeInternal, _NodeOutputInternal, first_real_override, is_class, make_locked_method_func
-from comfy_api.latest import io
+from comfy_api.latest import io, _io
class ExecutionResult(Enum):
@@ -76,7 +76,7 @@ class IsChangedCache:
return self.is_changed[node_id]
# Intentionally do not use cached outputs here. We only want constants in IS_CHANGED
- input_data_all, _, hidden_inputs = get_input_data(node["inputs"], class_def, node_id, None)
+ input_data_all, _, v3_data = get_input_data(node["inputs"], class_def, node_id, None)
try:
is_changed = await _async_map_node_over_list(self.prompt_id, node_id, class_def, input_data_all, is_changed_name)
is_changed = await resolve_map_node_over_list_results(is_changed)
@@ -146,8 +146,9 @@ SENSITIVE_EXTRA_DATA_KEYS = ("auth_token_comfy_org", "api_key_comfy_org")
def get_input_data(inputs, class_def, unique_id, execution_list=None, dynprompt=None, extra_data={}):
is_v3 = issubclass(class_def, _ComfyNodeInternal)
+ v3_data: io.V3Data = {}
if is_v3:
- valid_inputs, schema = class_def.INPUT_TYPES(include_hidden=False, return_schema=True)
+ valid_inputs, schema, v3_data = class_def.INPUT_TYPES(include_hidden=False, return_schema=True, live_inputs=inputs)
else:
valid_inputs = class_def.INPUT_TYPES()
input_data_all = {}
@@ -207,7 +208,8 @@ def get_input_data(inputs, class_def, unique_id, execution_list=None, dynprompt=
input_data_all[x] = [extra_data.get("auth_token_comfy_org", None)]
if h[x] == "API_KEY_COMFY_ORG":
input_data_all[x] = [extra_data.get("api_key_comfy_org", None)]
- return input_data_all, missing_keys, hidden_inputs_v3
+ v3_data["hidden_inputs"] = hidden_inputs_v3
+ return input_data_all, missing_keys, v3_data
map_node_over_list = None #Don't hook this please
@@ -223,7 +225,7 @@ async def resolve_map_node_over_list_results(results):
raise exc
return [x.result() if isinstance(x, asyncio.Task) else x for x in results]
-async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None, hidden_inputs=None):
+async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None, v3_data=None):
# check if node wants the lists
input_is_list = getattr(obj, "INPUT_IS_LIST", False)
@@ -259,13 +261,16 @@ async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, f
if is_class(obj):
type_obj = obj
obj.VALIDATE_CLASS()
- class_clone = obj.PREPARE_CLASS_CLONE(hidden_inputs)
+ class_clone = obj.PREPARE_CLASS_CLONE(v3_data)
# otherwise, use class instance to populate/reuse some fields
else:
type_obj = type(obj)
type_obj.VALIDATE_CLASS()
- class_clone = type_obj.PREPARE_CLASS_CLONE(hidden_inputs)
+ class_clone = type_obj.PREPARE_CLASS_CLONE(v3_data)
f = make_locked_method_func(type_obj, func, class_clone)
+ # in case of dynamic inputs, restructure inputs to expected nested dict
+ if v3_data is not None:
+ inputs = _io.build_nested_inputs(inputs, v3_data)
# V1
else:
f = getattr(obj, func)
@@ -320,8 +325,8 @@ def merge_result_data(results, obj):
output.append([o[i] for o in results])
return output
-async def get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=None, pre_execute_cb=None, hidden_inputs=None):
- return_values = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, hidden_inputs=hidden_inputs)
+async def get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=None, pre_execute_cb=None, v3_data=None):
+ return_values = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data)
has_pending_task = any(isinstance(r, asyncio.Task) and not r.done() for r in return_values)
if has_pending_task:
return return_values, {}, False, has_pending_task
@@ -460,7 +465,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
has_subgraph = False
else:
get_progress_state().start_progress(unique_id)
- input_data_all, missing_keys, hidden_inputs = get_input_data(inputs, class_def, unique_id, execution_list, dynprompt, extra_data)
+ input_data_all, missing_keys, v3_data = get_input_data(inputs, class_def, unique_id, execution_list, dynprompt, extra_data)
if server.client_id is not None:
server.last_node_id = display_node_id
server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id }, server.client_id)
@@ -475,7 +480,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
else:
lazy_status_present = getattr(obj, "check_lazy_status", None) is not None
if lazy_status_present:
- required_inputs = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, "check_lazy_status", allow_interrupt=True, hidden_inputs=hidden_inputs)
+ required_inputs = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, "check_lazy_status", allow_interrupt=True, v3_data=v3_data)
required_inputs = await resolve_map_node_over_list_results(required_inputs)
required_inputs = set(sum([r for r in required_inputs if isinstance(r,list)], []))
required_inputs = [x for x in required_inputs if isinstance(x,str) and (
@@ -507,7 +512,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
def pre_execute_cb(call_index):
# TODO - How to handle this with async functions without contextvars (which requires Python 3.12)?
GraphBuilder.set_default_prefix(unique_id, call_index, 0)
- output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, hidden_inputs=hidden_inputs)
+ output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data)
if has_pending_tasks:
pending_async_nodes[unique_id] = output_data
unblock = execution_list.add_external_block(unique_id)
@@ -745,18 +750,17 @@ async def validate_inputs(prompt_id, prompt, item, validated):
class_type = prompt[unique_id]['class_type']
obj_class = nodes.NODE_CLASS_MAPPINGS[class_type]
- class_inputs = obj_class.INPUT_TYPES()
- valid_inputs = set(class_inputs.get('required',{})).union(set(class_inputs.get('optional',{})))
-
errors = []
valid = True
validate_function_inputs = []
validate_has_kwargs = False
if issubclass(obj_class, _ComfyNodeInternal):
+ class_inputs, _, _ = obj_class.INPUT_TYPES(include_hidden=False, return_schema=True, live_inputs=inputs)
validate_function_name = "validate_inputs"
validate_function = first_real_override(obj_class, validate_function_name)
else:
+ class_inputs = obj_class.INPUT_TYPES()
validate_function_name = "VALIDATE_INPUTS"
validate_function = getattr(obj_class, validate_function_name, None)
if validate_function is not None:
@@ -765,6 +769,8 @@ async def validate_inputs(prompt_id, prompt, item, validated):
validate_has_kwargs = argspec.varkw is not None
received_types = {}
+ valid_inputs = set(class_inputs.get('required',{})).union(set(class_inputs.get('optional',{})))
+
for x in valid_inputs:
input_type, input_category, extra_info = get_input_info(obj_class, x, class_inputs)
assert extra_info is not None
@@ -935,7 +941,7 @@ async def validate_inputs(prompt_id, prompt, item, validated):
continue
if len(validate_function_inputs) > 0 or validate_has_kwargs:
- input_data_all, _, hidden_inputs = get_input_data(inputs, obj_class, unique_id)
+ input_data_all, _, v3_data = get_input_data(inputs, obj_class, unique_id)
input_filtered = {}
for x in input_data_all:
if x in validate_function_inputs or validate_has_kwargs:
@@ -943,7 +949,7 @@ async def validate_inputs(prompt_id, prompt, item, validated):
if 'input_types' in validate_function_inputs:
input_filtered['input_types'] = [received_types]
- ret = await _async_map_node_over_list(prompt_id, unique_id, obj_class, input_filtered, validate_function_name, hidden_inputs=hidden_inputs)
+ ret = await _async_map_node_over_list(prompt_id, unique_id, obj_class, input_filtered, validate_function_name, v3_data=v3_data)
ret = await resolve_map_node_over_list_results(ret)
for x in input_filtered:
for i, r in enumerate(ret):
diff --git a/nodes.py b/nodes.py
index 4c910a34b..356aa63df 100644
--- a/nodes.py
+++ b/nodes.py
@@ -2355,6 +2355,7 @@ async def init_builtin_extra_nodes():
"nodes_easycache.py",
"nodes_audio_encoder.py",
"nodes_rope.py",
+ "nodes_logic.py",
"nodes_nop.py",
]
From 861817d22d2659099811b56005c9eaea18d64c73 Mon Sep 17 00:00:00 2001
From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com>
Date: Tue, 2 Dec 2025 21:47:51 -0800
Subject: [PATCH 16/81] Fix issue with portable updater. (#11070)
This should fix the problem with the portable updater not working with portables created from a separate branch on the repo.
This does not affect any current portables who were all created on the master branch.
---
.ci/update_windows/update.py | 7 +++++--
1 file changed, 5 insertions(+), 2 deletions(-)
diff --git a/.ci/update_windows/update.py b/.ci/update_windows/update.py
index 51a263203..59ece5130 100755
--- a/.ci/update_windows/update.py
+++ b/.ci/update_windows/update.py
@@ -66,8 +66,10 @@ if branch is None:
try:
ref = repo.lookup_reference('refs/remotes/origin/master')
except:
- print("pulling.") # noqa: T201
- pull(repo)
+ print("fetching.") # noqa: T201
+ for remote in repo.remotes:
+ if remote.name == "origin":
+ remote.fetch()
ref = repo.lookup_reference('refs/remotes/origin/master')
repo.checkout(ref)
branch = repo.lookup_branch('master')
@@ -149,3 +151,4 @@ try:
shutil.copy(stable_update_script, stable_update_script_to)
except:
pass
+
From 519c9411653df99761053c30e101816e0ca3c24b Mon Sep 17 00:00:00 2001
From: rattus <46076784+rattus128@users.noreply.github.com>
Date: Wed, 3 Dec 2025 17:28:45 +1000
Subject: [PATCH 17/81] Prs/lora reservations (reduce massive Lora reservations
especially on Flux2) (#11069)
* mp: only count the offload cost of math once
This was previously bundling the combined weight storage and computation
cost
* ops: put all post async transfer compute on the main stream
Some models have massive weights that need either complex
dequantization or lora patching. Don't do these patchings on the offload
stream, instead do them on the main stream to syncrhonize the
potentially large vram spikes for these compute processes. This avoids
having to assume a worst case scenario of multiple offload streams
all spiking VRAM is parallel with whatever the main stream is doing.
---
comfy/model_patcher.py | 4 ++--
comfy/ops.py | 39 ++++++++++++++++++++++-----------------
2 files changed, 24 insertions(+), 19 deletions(-)
diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py
index 3eac77275..df2d8e827 100644
--- a/comfy/model_patcher.py
+++ b/comfy/model_patcher.py
@@ -704,7 +704,7 @@ class ModelPatcher:
lowvram_weight = False
- potential_offload = max(offload_buffer, module_offload_mem * (comfy.model_management.NUM_STREAMS + 1))
+ potential_offload = max(offload_buffer, module_offload_mem + (comfy.model_management.NUM_STREAMS * module_mem))
lowvram_fits = mem_counter + module_mem + potential_offload < lowvram_model_memory
weight_key = "{}.weight".format(n)
@@ -883,7 +883,7 @@ class ModelPatcher:
break
module_offload_mem, module_mem, n, m, params = unload
- potential_offload = (comfy.model_management.NUM_STREAMS + 1) * module_offload_mem
+ potential_offload = module_offload_mem + (comfy.model_management.NUM_STREAMS * module_mem)
lowvram_possible = hasattr(m, "comfy_cast_weights")
if hasattr(m, "comfy_patched_weights") and m.comfy_patched_weights == True:
diff --git a/comfy/ops.py b/comfy/ops.py
index 61a2f0754..eae434e68 100644
--- a/comfy/ops.py
+++ b/comfy/ops.py
@@ -111,22 +111,24 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
if s.bias is not None:
bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=bias_has_function, stream=offload_stream)
- if bias_has_function:
- with wf_context:
- for f in s.bias_function:
- bias = f(bias)
+ comfy.model_management.sync_stream(device, offload_stream)
+
+ bias_a = bias
+ weight_a = weight
+
+ if s.bias is not None:
+ for f in s.bias_function:
+ bias = f(bias)
if weight_has_function or weight.dtype != dtype:
- with wf_context:
- weight = weight.to(dtype=dtype)
- if isinstance(weight, QuantizedTensor):
- weight = weight.dequantize()
- for f in s.weight_function:
- weight = f(weight)
+ weight = weight.to(dtype=dtype)
+ if isinstance(weight, QuantizedTensor):
+ weight = weight.dequantize()
+ for f in s.weight_function:
+ weight = f(weight)
- comfy.model_management.sync_stream(device, offload_stream)
if offloadable:
- return weight, bias, offload_stream
+ return weight, bias, (offload_stream, weight_a, bias_a)
else:
#Legacy function signature
return weight, bias
@@ -135,13 +137,16 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
def uncast_bias_weight(s, weight, bias, offload_stream):
if offload_stream is None:
return
- if weight is not None:
- device = weight.device
+ os, weight_a, bias_a = offload_stream
+ if os is None:
+ return
+ if weight_a is not None:
+ device = weight_a.device
else:
- if bias is None:
+ if bias_a is None:
return
- device = bias.device
- offload_stream.wait_stream(comfy.model_management.current_stream(device))
+ device = bias_a.device
+ os.wait_stream(comfy.model_management.current_stream(device))
class CastWeightBiasOp:
From 19f2192d69d13445131b72ad1d87167f59b66fc4 Mon Sep 17 00:00:00 2001
From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com>
Date: Wed, 3 Dec 2025 18:37:35 +0200
Subject: [PATCH 18/81] fix(V3-Schema): use empty list defaults for
Schema.inputs/outputs/hidden to avoid None issues (#11083)
---
comfy_api/latest/_io.py | 8 ++++----
1 file changed, 4 insertions(+), 4 deletions(-)
diff --git a/comfy_api/latest/_io.py b/comfy_api/latest/_io.py
index 257f07c42..866c3e0eb 100644
--- a/comfy_api/latest/_io.py
+++ b/comfy_api/latest/_io.py
@@ -5,7 +5,7 @@ import inspect
from abc import ABC, abstractmethod
from collections import Counter
from collections.abc import Iterable
-from dataclasses import asdict, dataclass
+from dataclasses import asdict, dataclass, field
from enum import Enum
from typing import Any, Callable, Literal, TypedDict, TypeVar, TYPE_CHECKING
from typing_extensions import NotRequired, final
@@ -1199,9 +1199,9 @@ class Schema:
"""Display name of node."""
category: str = "sd"
"""The category of the node, as per the "Add Node" menu."""
- inputs: list[Input]=None
- outputs: list[Output]=None
- hidden: list[Hidden]=None
+ inputs: list[Input] = field(default_factory=list)
+ outputs: list[Output] = field(default_factory=list)
+ hidden: list[Hidden] = field(default_factory=list)
description: str=""
"""Node description, shown as a tooltip when hovering over the node."""
is_input_list: bool = False
From 87c104bfc1928f0b018a50f5867f425e10482929 Mon Sep 17 00:00:00 2001
From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com>
Date: Wed, 3 Dec 2025 18:55:44 +0200
Subject: [PATCH 19/81] add support for "@image" reference format in Kling Omni
API nodes (#11082)
---
comfy_api_nodes/apis/kling_api.py | 30 +++++--
comfy_api_nodes/nodes_kling.py | 138 ++++++++++++++++++++++++++++--
2 files changed, 155 insertions(+), 13 deletions(-)
diff --git a/comfy_api_nodes/apis/kling_api.py b/comfy_api_nodes/apis/kling_api.py
index 0a3b447c5..d8949f8ac 100644
--- a/comfy_api_nodes/apis/kling_api.py
+++ b/comfy_api_nodes/apis/kling_api.py
@@ -46,21 +46,41 @@ class TaskStatusVideoResult(BaseModel):
url: str | None = Field(None, description="URL for generated video")
-class TaskStatusVideoResults(BaseModel):
+class TaskStatusImageResult(BaseModel):
+ index: int = Field(..., description="Image Number,0-9")
+ url: str = Field(..., description="URL for generated image")
+
+
+class OmniTaskStatusResults(BaseModel):
videos: list[TaskStatusVideoResult] | None = Field(None)
+ images: list[TaskStatusImageResult] | None = Field(None)
-class TaskStatusVideoResponseData(BaseModel):
+class OmniTaskStatusResponseData(BaseModel):
created_at: int | None = Field(None, description="Task creation time")
updated_at: int | None = Field(None, description="Task update time")
task_status: str | None = None
task_status_msg: str | None = Field(None, description="Additional failure reason. Only for polling endpoint.")
task_id: str | None = Field(None, description="Task ID")
- task_result: TaskStatusVideoResults | None = Field(None)
+ task_result: OmniTaskStatusResults | None = Field(None)
-class TaskStatusVideoResponse(BaseModel):
+class OmniTaskStatusResponse(BaseModel):
code: int | None = Field(None, description="Error code")
message: str | None = Field(None, description="Error message")
request_id: str | None = Field(None, description="Request ID")
- data: TaskStatusVideoResponseData | None = Field(None)
+ data: OmniTaskStatusResponseData | None = Field(None)
+
+
+class OmniImageParamImage(BaseModel):
+ image: str = Field(...)
+
+
+class OmniProImageRequest(BaseModel):
+ model_name: str = Field(..., description="kling-image-o1")
+ resolution: str = Field(..., description="'1k' or '2k'")
+ aspect_ratio: str | None = Field(...)
+ prompt: str = Field(...)
+ mode: str = Field("pro")
+ n: int | None = Field(1, le=9)
+ image_list: list[OmniImageParamImage] | None = Field(..., max_length=10)
diff --git a/comfy_api_nodes/nodes_kling.py b/comfy_api_nodes/nodes_kling.py
index 850c44db6..6c840dc47 100644
--- a/comfy_api_nodes/nodes_kling.py
+++ b/comfy_api_nodes/nodes_kling.py
@@ -6,6 +6,7 @@ For source of truth on the allowed permutations of request fields, please refere
import logging
import math
+import re
import torch
from typing_extensions import override
@@ -49,12 +50,14 @@ from comfy_api_nodes.apis import (
KlingSingleImageEffectModelName,
)
from comfy_api_nodes.apis.kling_api import (
+ OmniImageParamImage,
OmniParamImage,
OmniParamVideo,
OmniProFirstLastFrameRequest,
+ OmniProImageRequest,
OmniProReferences2VideoRequest,
OmniProText2VideoRequest,
- TaskStatusVideoResponse,
+ OmniTaskStatusResponse,
)
from comfy_api_nodes.util import (
ApiEndpoint,
@@ -210,7 +213,36 @@ VOICES_CONFIG = {
}
-async def finish_omni_video_task(cls: type[IO.ComfyNode], response: TaskStatusVideoResponse) -> IO.NodeOutput:
+def normalize_omni_prompt_references(prompt: str) -> str:
+ """
+ Rewrites Kling Omni-style placeholders used in the app, like:
+
+ @image, @image1, @image2, ... @imageN
+ @video, @video1, @video2, ... @videoN
+
+ into the API-compatible form:
+
+ <<>>, <<>>, ...
+ <<>>, <<>>, ...
+
+ This is a UX shim for ComfyUI so users can type the same syntax as in the Kling app.
+ """
+ if not prompt:
+ return prompt
+
+ def _image_repl(match):
+ return f"<<>>"
+
+ def _video_repl(match):
+ return f"<<>>"
+
+ # (? and not @imageFoo
+ prompt = re.sub(r"(?\d*)(?!\w)", _image_repl, prompt)
+ return re.sub(r"(?\d*)(?!\w)", _video_repl, prompt)
+
+
+async def finish_omni_video_task(cls: type[IO.ComfyNode], response: OmniTaskStatusResponse) -> IO.NodeOutput:
if response.code:
raise RuntimeError(
f"Kling request failed. Code: {response.code}, Message: {response.message}, Data: {response.data}"
@@ -218,8 +250,9 @@ async def finish_omni_video_task(cls: type[IO.ComfyNode], response: TaskStatusVi
final_response = await poll_op(
cls,
ApiEndpoint(path=f"/proxy/kling/v1/videos/omni-video/{response.data.task_id}"),
- response_model=TaskStatusVideoResponse,
+ response_model=OmniTaskStatusResponse,
status_extractor=lambda r: (r.data.task_status if r.data else None),
+ max_poll_attempts=160,
)
return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url))
@@ -801,7 +834,7 @@ class OmniProTextToVideoNode(IO.ComfyNode):
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"),
- response_model=TaskStatusVideoResponse,
+ response_model=OmniTaskStatusResponse,
data=OmniProText2VideoRequest(
model_name=model_name,
prompt=prompt,
@@ -864,6 +897,7 @@ class OmniProFirstLastFrameNode(IO.ComfyNode):
end_frame: Input.Image | None = None,
reference_images: Input.Image | None = None,
) -> IO.NodeOutput:
+ prompt = normalize_omni_prompt_references(prompt)
validate_string(prompt, min_length=1, max_length=2500)
if end_frame is not None and reference_images is not None:
raise ValueError("The 'end_frame' input cannot be used simultaneously with 'reference_images'.")
@@ -895,7 +929,7 @@ class OmniProFirstLastFrameNode(IO.ComfyNode):
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"),
- response_model=TaskStatusVideoResponse,
+ response_model=OmniTaskStatusResponse,
data=OmniProFirstLastFrameRequest(
model_name=model_name,
prompt=prompt,
@@ -950,6 +984,7 @@ class OmniProImageToVideoNode(IO.ComfyNode):
duration: int,
reference_images: Input.Image,
) -> IO.NodeOutput:
+ prompt = normalize_omni_prompt_references(prompt)
validate_string(prompt, min_length=1, max_length=2500)
if get_number_of_images(reference_images) > 7:
raise ValueError("The maximum number of reference images is 7.")
@@ -962,7 +997,7 @@ class OmniProImageToVideoNode(IO.ComfyNode):
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"),
- response_model=TaskStatusVideoResponse,
+ response_model=OmniTaskStatusResponse,
data=OmniProReferences2VideoRequest(
model_name=model_name,
prompt=prompt,
@@ -1023,6 +1058,7 @@ class OmniProVideoToVideoNode(IO.ComfyNode):
keep_original_sound: bool,
reference_images: Input.Image | None = None,
) -> IO.NodeOutput:
+ prompt = normalize_omni_prompt_references(prompt)
validate_string(prompt, min_length=1, max_length=2500)
validate_video_duration(reference_video, min_duration=3.0, max_duration=10.05)
validate_video_dimensions(reference_video, min_width=720, min_height=720, max_width=2160, max_height=2160)
@@ -1045,7 +1081,7 @@ class OmniProVideoToVideoNode(IO.ComfyNode):
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"),
- response_model=TaskStatusVideoResponse,
+ response_model=OmniTaskStatusResponse,
data=OmniProReferences2VideoRequest(
model_name=model_name,
prompt=prompt,
@@ -1103,6 +1139,7 @@ class OmniProEditVideoNode(IO.ComfyNode):
keep_original_sound: bool,
reference_images: Input.Image | None = None,
) -> IO.NodeOutput:
+ prompt = normalize_omni_prompt_references(prompt)
validate_string(prompt, min_length=1, max_length=2500)
validate_video_duration(video, min_duration=3.0, max_duration=10.05)
validate_video_dimensions(video, min_width=720, min_height=720, max_width=2160, max_height=2160)
@@ -1125,7 +1162,7 @@ class OmniProEditVideoNode(IO.ComfyNode):
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"),
- response_model=TaskStatusVideoResponse,
+ response_model=OmniTaskStatusResponse,
data=OmniProReferences2VideoRequest(
model_name=model_name,
prompt=prompt,
@@ -1138,6 +1175,90 @@ class OmniProEditVideoNode(IO.ComfyNode):
return await finish_omni_video_task(cls, response)
+class OmniProImageNode(IO.ComfyNode):
+
+ @classmethod
+ def define_schema(cls) -> IO.Schema:
+ return IO.Schema(
+ node_id="KlingOmniProImageNode",
+ display_name="Kling Omni Image (Pro)",
+ category="api node/image/Kling",
+ description="Create or edit images with the latest model from Kling.",
+ inputs=[
+ IO.Combo.Input("model_name", options=["kling-image-o1"]),
+ IO.String.Input(
+ "prompt",
+ multiline=True,
+ tooltip="A text prompt describing the image content. "
+ "This can include both positive and negative descriptions.",
+ ),
+ IO.Combo.Input("resolution", options=["1K", "2K"]),
+ IO.Combo.Input(
+ "aspect_ratio",
+ options=["16:9", "9:16", "1:1", "4:3", "3:4", "3:2", "2:3", "21:9"],
+ ),
+ IO.Image.Input(
+ "reference_images",
+ tooltip="Up to 10 additional reference images.",
+ optional=True,
+ ),
+ ],
+ outputs=[
+ IO.Image.Output(),
+ ],
+ hidden=[
+ IO.Hidden.auth_token_comfy_org,
+ IO.Hidden.api_key_comfy_org,
+ IO.Hidden.unique_id,
+ ],
+ is_api_node=True,
+ )
+
+ @classmethod
+ async def execute(
+ cls,
+ model_name: str,
+ prompt: str,
+ resolution: str,
+ aspect_ratio: str,
+ reference_images: Input.Image | None = None,
+ ) -> IO.NodeOutput:
+ prompt = normalize_omni_prompt_references(prompt)
+ validate_string(prompt, min_length=1, max_length=2500)
+ image_list: list[OmniImageParamImage] = []
+ if reference_images is not None:
+ if get_number_of_images(reference_images) > 10:
+ raise ValueError("The maximum number of reference images is 10.")
+ for i in reference_images:
+ validate_image_dimensions(i, min_width=300, min_height=300)
+ validate_image_aspect_ratio(i, (1, 2.5), (2.5, 1))
+ for i in await upload_images_to_comfyapi(cls, reference_images, wait_label="Uploading reference image"):
+ image_list.append(OmniImageParamImage(image=i))
+ response = await sync_op(
+ cls,
+ ApiEndpoint(path="/proxy/kling/v1/images/omni-image", method="POST"),
+ response_model=OmniTaskStatusResponse,
+ data=OmniProImageRequest(
+ model_name=model_name,
+ prompt=prompt,
+ resolution=resolution.lower(),
+ aspect_ratio=aspect_ratio,
+ image_list=image_list if image_list else None,
+ ),
+ )
+ if response.code:
+ raise RuntimeError(
+ f"Kling request failed. Code: {response.code}, Message: {response.message}, Data: {response.data}"
+ )
+ final_response = await poll_op(
+ cls,
+ ApiEndpoint(path=f"/proxy/kling/v1/images/omni-image/{response.data.task_id}"),
+ response_model=OmniTaskStatusResponse,
+ status_extractor=lambda r: (r.data.task_status if r.data else None),
+ )
+ return IO.NodeOutput(await download_url_to_image_tensor(final_response.data.task_result.images[0].url))
+
+
class KlingCameraControlT2VNode(IO.ComfyNode):
"""
Kling Text to Video Camera Control Node. This node is a text to video node, but it supports controlling the camera.
@@ -1935,6 +2056,7 @@ class KlingExtension(ComfyExtension):
OmniProImageToVideoNode,
OmniProVideoToVideoNode,
OmniProEditVideoNode,
+ # OmniProImageNode, # need support from backend
]
From 440268d3940eb14a01595439bbc05c4aacde9c72 Mon Sep 17 00:00:00 2001
From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com>
Date: Wed, 3 Dec 2025 23:52:31 +0200
Subject: [PATCH 20/81] convert nodes_load_3d.py to V3 schema (#10990)
---
comfy_api/latest/_ui.py | 13 +++-
comfy_extras/nodes_load_3d.py | 127 ++++++++++++++++------------------
2 files changed, 71 insertions(+), 69 deletions(-)
diff --git a/comfy_api/latest/_ui.py b/comfy_api/latest/_ui.py
index b0bbabe2a..6d1bea599 100644
--- a/comfy_api/latest/_ui.py
+++ b/comfy_api/latest/_ui.py
@@ -3,6 +3,7 @@ from __future__ import annotations
import json
import os
import random
+import uuid
from io import BytesIO
from typing import Type
@@ -436,9 +437,19 @@ class PreviewUI3D(_UIOutput):
def __init__(self, model_file, camera_info, **kwargs):
self.model_file = model_file
self.camera_info = camera_info
+ self.bg_image_path = None
+ bg_image = kwargs.get("bg_image", None)
+ if bg_image is not None:
+ img_array = (bg_image[0].cpu().numpy() * 255).astype(np.uint8)
+ img = PILImage.fromarray(img_array)
+ temp_dir = folder_paths.get_temp_directory()
+ filename = f"bg_{uuid.uuid4().hex}.png"
+ bg_image_path = os.path.join(temp_dir, filename)
+ img.save(bg_image_path, compress_level=1)
+ self.bg_image_path = f"temp/{filename}"
def as_dict(self):
- return {"result": [self.model_file, self.camera_info]}
+ return {"result": [self.model_file, self.camera_info, self.bg_image_path]}
class PreviewText(_UIOutput):
diff --git a/comfy_extras/nodes_load_3d.py b/comfy_extras/nodes_load_3d.py
index 54c66ef68..545588ef8 100644
--- a/comfy_extras/nodes_load_3d.py
+++ b/comfy_extras/nodes_load_3d.py
@@ -2,22 +2,18 @@ import nodes
import folder_paths
import os
-from comfy.comfy_types import IO
-from comfy_api.input_impl import VideoFromFile
+from typing_extensions import override
+from comfy_api.latest import IO, ComfyExtension, InputImpl, UI
from pathlib import Path
-from PIL import Image
-import numpy as np
-
-import uuid
def normalize_path(path):
return path.replace('\\', '/')
-class Load3D():
+class Load3D(IO.ComfyNode):
@classmethod
- def INPUT_TYPES(s):
+ def define_schema(cls):
input_dir = os.path.join(folder_paths.get_input_directory(), "3d")
os.makedirs(input_dir, exist_ok=True)
@@ -30,23 +26,29 @@ class Load3D():
for file_path in input_path.rglob("*")
if file_path.suffix.lower() in {'.gltf', '.glb', '.obj', '.fbx', '.stl'}
]
+ return IO.Schema(
+ node_id="Load3D",
+ display_name="Load 3D & Animation",
+ category="3d",
+ is_experimental=True,
+ inputs=[
+ IO.Combo.Input("model_file", options=sorted(files), upload=IO.UploadType.model),
+ IO.Load3D.Input("image"),
+ IO.Int.Input("width", default=1024, min=1, max=4096, step=1),
+ IO.Int.Input("height", default=1024, min=1, max=4096, step=1),
+ ],
+ outputs=[
+ IO.Image.Output(display_name="image"),
+ IO.Mask.Output(display_name="mask"),
+ IO.String.Output(display_name="mesh_path"),
+ IO.Image.Output(display_name="normal"),
+ IO.Load3DCamera.Output(display_name="camera_info"),
+ IO.Video.Output(display_name="recording_video"),
+ ],
+ )
- return {"required": {
- "model_file": (sorted(files), {"file_upload": True}),
- "image": ("LOAD_3D", {}),
- "width": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
- "height": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
- }}
-
- RETURN_TYPES = ("IMAGE", "MASK", "STRING", "IMAGE", "LOAD3D_CAMERA", IO.VIDEO)
- RETURN_NAMES = ("image", "mask", "mesh_path", "normal", "camera_info", "recording_video")
-
- FUNCTION = "process"
- EXPERIMENTAL = True
-
- CATEGORY = "3d"
-
- def process(self, model_file, image, **kwargs):
+ @classmethod
+ def execute(cls, model_file, image, **kwargs) -> IO.NodeOutput:
image_path = folder_paths.get_annotated_filepath(image['image'])
mask_path = folder_paths.get_annotated_filepath(image['mask'])
normal_path = folder_paths.get_annotated_filepath(image['normal'])
@@ -61,58 +63,47 @@ class Load3D():
if image['recording'] != "":
recording_video_path = folder_paths.get_annotated_filepath(image['recording'])
- video = VideoFromFile(recording_video_path)
+ video = InputImpl.VideoFromFile(recording_video_path)
- return output_image, output_mask, model_file, normal_image, image['camera_info'], video
+ return IO.NodeOutput(output_image, output_mask, model_file, normal_image, image['camera_info'], video)
-class Preview3D():
+ process = execute # TODO: remove
+
+
+class Preview3D(IO.ComfyNode):
@classmethod
- def INPUT_TYPES(s):
- return {"required": {
- "model_file": ("STRING", {"default": "", "multiline": False}),
- },
- "optional": {
- "camera_info": ("LOAD3D_CAMERA", {}),
- "bg_image": ("IMAGE", {})
- }}
+ def define_schema(cls):
+ return IO.Schema(
+ node_id="Preview3D",
+ display_name="Preview 3D & Animation",
+ category="3d",
+ is_experimental=True,
+ is_output_node=True,
+ inputs=[
+ IO.String.Input("model_file", default="", multiline=False),
+ IO.Load3DCamera.Input("camera_info", optional=True),
+ IO.Image.Input("bg_image", optional=True),
+ ],
+ outputs=[],
+ )
- OUTPUT_NODE = True
- RETURN_TYPES = ()
-
- CATEGORY = "3d"
-
- FUNCTION = "process"
- EXPERIMENTAL = True
-
- def process(self, model_file, **kwargs):
+ @classmethod
+ def execute(cls, model_file, **kwargs) -> IO.NodeOutput:
camera_info = kwargs.get("camera_info", None)
bg_image = kwargs.get("bg_image", None)
+ return IO.NodeOutput(ui=UI.PreviewUI3D(model_file, camera_info, bg_image=bg_image))
- bg_image_path = None
- if bg_image is not None:
+ process = execute # TODO: remove
- img_array = (bg_image[0].cpu().numpy() * 255).astype(np.uint8)
- img = Image.fromarray(img_array)
- temp_dir = folder_paths.get_temp_directory()
- filename = f"bg_{uuid.uuid4().hex}.png"
- bg_image_path = os.path.join(temp_dir, filename)
- img.save(bg_image_path, compress_level=1)
+class Load3DExtension(ComfyExtension):
+ @override
+ async def get_node_list(self) -> list[type[IO.ComfyNode]]:
+ return [
+ Load3D,
+ Preview3D,
+ ]
- bg_image_path = f"temp/{filename}"
- return {
- "ui": {
- "result": [model_file, camera_info, bg_image_path]
- }
- }
-
-NODE_CLASS_MAPPINGS = {
- "Load3D": Load3D,
- "Preview3D": Preview3D,
-}
-
-NODE_DISPLAY_NAME_MAPPINGS = {
- "Load3D": "Load 3D & Animation",
- "Preview3D": "Preview 3D & Animation",
-}
+async def comfy_entrypoint() -> Load3DExtension:
+ return Load3DExtension()
From dce518c2b4f99634b5fdde1924d9b0bd468fe1ce Mon Sep 17 00:00:00 2001
From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com>
Date: Thu, 4 Dec 2025 03:35:04 +0200
Subject: [PATCH 21/81] convert nodes_audio.py to V3 schema (#10798)
---
comfy_api/latest/_ui.py | 9 +-
comfy_extras/nodes_audio.py | 744 ++++++++++++++++++------------------
2 files changed, 382 insertions(+), 371 deletions(-)
diff --git a/comfy_api/latest/_ui.py b/comfy_api/latest/_ui.py
index 6d1bea599..5a75a3aae 100644
--- a/comfy_api/latest/_ui.py
+++ b/comfy_api/latest/_ui.py
@@ -319,9 +319,10 @@ class AudioSaveHelper:
for key, value in metadata.items():
output_container.metadata[key] = value
+ layout = "mono" if waveform.shape[0] == 1 else "stereo"
# Set up the output stream with appropriate properties
if format == "opus":
- out_stream = output_container.add_stream("libopus", rate=sample_rate)
+ out_stream = output_container.add_stream("libopus", rate=sample_rate, layout=layout)
if quality == "64k":
out_stream.bit_rate = 64000
elif quality == "96k":
@@ -333,7 +334,7 @@ class AudioSaveHelper:
elif quality == "320k":
out_stream.bit_rate = 320000
elif format == "mp3":
- out_stream = output_container.add_stream("libmp3lame", rate=sample_rate)
+ out_stream = output_container.add_stream("libmp3lame", rate=sample_rate, layout=layout)
if quality == "V0":
# TODO i would really love to support V3 and V5 but there doesn't seem to be a way to set the qscale level, the property below is a bool
out_stream.codec_context.qscale = 1
@@ -342,12 +343,12 @@ class AudioSaveHelper:
elif quality == "320k":
out_stream.bit_rate = 320000
else: # format == "flac":
- out_stream = output_container.add_stream("flac", rate=sample_rate)
+ out_stream = output_container.add_stream("flac", rate=sample_rate, layout=layout)
frame = av.AudioFrame.from_ndarray(
waveform.movedim(0, 1).reshape(1, -1).float().numpy(),
format="flt",
- layout="mono" if waveform.shape[0] == 1 else "stereo",
+ layout=layout,
)
frame.sample_rate = sample_rate
frame.pts = 0
diff --git a/comfy_extras/nodes_audio.py b/comfy_extras/nodes_audio.py
index 2ed7e0b22..812301fb7 100644
--- a/comfy_extras/nodes_audio.py
+++ b/comfy_extras/nodes_audio.py
@@ -6,65 +6,80 @@ import torch
import comfy.model_management
import folder_paths
import os
-import io
-import json
-import random
import hashlib
import node_helpers
import logging
-from comfy.cli_args import args
-from comfy.comfy_types import FileLocator
+from typing_extensions import override
+from comfy_api.latest import ComfyExtension, IO, UI
-class EmptyLatentAudio:
- def __init__(self):
- self.device = comfy.model_management.intermediate_device()
+class EmptyLatentAudio(IO.ComfyNode):
+ @classmethod
+ def define_schema(cls):
+ return IO.Schema(
+ node_id="EmptyLatentAudio",
+ display_name="Empty Latent Audio",
+ category="latent/audio",
+ inputs=[
+ IO.Float.Input("seconds", default=47.6, min=1.0, max=1000.0, step=0.1),
+ IO.Int.Input(
+ "batch_size", default=1, min=1, max=4096, tooltip="The number of latent images in the batch."
+ ),
+ ],
+ outputs=[IO.Latent.Output()],
+ )
@classmethod
- def INPUT_TYPES(s):
- return {"required": {"seconds": ("FLOAT", {"default": 47.6, "min": 1.0, "max": 1000.0, "step": 0.1}),
- "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096, "tooltip": "The number of latent images in the batch."}),
- }}
- RETURN_TYPES = ("LATENT",)
- FUNCTION = "generate"
-
- CATEGORY = "latent/audio"
-
- def generate(self, seconds, batch_size):
+ def execute(cls, seconds, batch_size) -> IO.NodeOutput:
length = round((seconds * 44100 / 2048) / 2) * 2
- latent = torch.zeros([batch_size, 64, length], device=self.device)
- return ({"samples":latent, "type": "audio"}, )
+ latent = torch.zeros([batch_size, 64, length], device=comfy.model_management.intermediate_device())
+ return IO.NodeOutput({"samples":latent, "type": "audio"})
-class ConditioningStableAudio:
+ generate = execute # TODO: remove
+
+
+class ConditioningStableAudio(IO.ComfyNode):
@classmethod
- def INPUT_TYPES(s):
- return {"required": {"positive": ("CONDITIONING", ),
- "negative": ("CONDITIONING", ),
- "seconds_start": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1000.0, "step": 0.1}),
- "seconds_total": ("FLOAT", {"default": 47.0, "min": 0.0, "max": 1000.0, "step": 0.1}),
- }}
+ def define_schema(cls):
+ return IO.Schema(
+ node_id="ConditioningStableAudio",
+ category="conditioning",
+ inputs=[
+ IO.Conditioning.Input("positive"),
+ IO.Conditioning.Input("negative"),
+ IO.Float.Input("seconds_start", default=0.0, min=0.0, max=1000.0, step=0.1),
+ IO.Float.Input("seconds_total", default=47.0, min=0.0, max=1000.0, step=0.1),
+ ],
+ outputs=[
+ IO.Conditioning.Output(display_name="positive"),
+ IO.Conditioning.Output(display_name="negative"),
+ ],
+ )
- RETURN_TYPES = ("CONDITIONING","CONDITIONING")
- RETURN_NAMES = ("positive", "negative")
-
- FUNCTION = "append"
-
- CATEGORY = "conditioning"
-
- def append(self, positive, negative, seconds_start, seconds_total):
+ @classmethod
+ def execute(cls, positive, negative, seconds_start, seconds_total) -> IO.NodeOutput:
positive = node_helpers.conditioning_set_values(positive, {"seconds_start": seconds_start, "seconds_total": seconds_total})
negative = node_helpers.conditioning_set_values(negative, {"seconds_start": seconds_start, "seconds_total": seconds_total})
- return (positive, negative)
+ return IO.NodeOutput(positive, negative)
-class VAEEncodeAudio:
+ append = execute # TODO: remove
+
+
+class VAEEncodeAudio(IO.ComfyNode):
@classmethod
- def INPUT_TYPES(s):
- return {"required": { "audio": ("AUDIO", ), "vae": ("VAE", )}}
- RETURN_TYPES = ("LATENT",)
- FUNCTION = "encode"
+ def define_schema(cls):
+ return IO.Schema(
+ node_id="VAEEncodeAudio",
+ display_name="VAE Encode Audio",
+ category="latent/audio",
+ inputs=[
+ IO.Audio.Input("audio"),
+ IO.Vae.Input("vae"),
+ ],
+ outputs=[IO.Latent.Output()],
+ )
- CATEGORY = "latent/audio"
-
- def encode(self, vae, audio):
+ @classmethod
+ def execute(cls, vae, audio) -> IO.NodeOutput:
sample_rate = audio["sample_rate"]
if 44100 != sample_rate:
waveform = torchaudio.functional.resample(audio["waveform"], sample_rate, 44100)
@@ -72,213 +87,134 @@ class VAEEncodeAudio:
waveform = audio["waveform"]
t = vae.encode(waveform.movedim(1, -1))
- return ({"samples":t}, )
+ return IO.NodeOutput({"samples":t})
-class VAEDecodeAudio:
+ encode = execute # TODO: remove
+
+
+class VAEDecodeAudio(IO.ComfyNode):
@classmethod
- def INPUT_TYPES(s):
- return {"required": { "samples": ("LATENT", ), "vae": ("VAE", )}}
- RETURN_TYPES = ("AUDIO",)
- FUNCTION = "decode"
+ def define_schema(cls):
+ return IO.Schema(
+ node_id="VAEDecodeAudio",
+ display_name="VAE Decode Audio",
+ category="latent/audio",
+ inputs=[
+ IO.Latent.Input("samples"),
+ IO.Vae.Input("vae"),
+ ],
+ outputs=[IO.Audio.Output()],
+ )
- CATEGORY = "latent/audio"
-
- def decode(self, vae, samples):
+ @classmethod
+ def execute(cls, vae, samples) -> IO.NodeOutput:
audio = vae.decode(samples["samples"]).movedim(-1, 1)
std = torch.std(audio, dim=[1,2], keepdim=True) * 5.0
std[std < 1.0] = 1.0
audio /= std
- return ({"waveform": audio, "sample_rate": 44100}, )
+ return IO.NodeOutput({"waveform": audio, "sample_rate": 44100})
+
+ decode = execute # TODO: remove
-def save_audio(self, audio, filename_prefix="ComfyUI", format="flac", prompt=None, extra_pnginfo=None, quality="128k"):
-
- filename_prefix += self.prefix_append
- full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
- results: list[FileLocator] = []
-
- # Prepare metadata dictionary
- metadata = {}
- if not args.disable_metadata:
- if prompt is not None:
- metadata["prompt"] = json.dumps(prompt)
- if extra_pnginfo is not None:
- for x in extra_pnginfo:
- metadata[x] = json.dumps(extra_pnginfo[x])
-
- # Opus supported sample rates
- OPUS_RATES = [8000, 12000, 16000, 24000, 48000]
-
- for (batch_number, waveform) in enumerate(audio["waveform"].cpu()):
- filename_with_batch_num = filename.replace("%batch_num%", str(batch_number))
- file = f"{filename_with_batch_num}_{counter:05}_.{format}"
- output_path = os.path.join(full_output_folder, file)
-
- # Use original sample rate initially
- sample_rate = audio["sample_rate"]
-
- # Handle Opus sample rate requirements
- if format == "opus":
- if sample_rate > 48000:
- sample_rate = 48000
- elif sample_rate not in OPUS_RATES:
- # Find the next highest supported rate
- for rate in sorted(OPUS_RATES):
- if rate > sample_rate:
- sample_rate = rate
- break
- if sample_rate not in OPUS_RATES: # Fallback if still not supported
- sample_rate = 48000
-
- # Resample if necessary
- if sample_rate != audio["sample_rate"]:
- waveform = torchaudio.functional.resample(waveform, audio["sample_rate"], sample_rate)
-
- # Create output with specified format
- output_buffer = io.BytesIO()
- output_container = av.open(output_buffer, mode='w', format=format)
-
- # Set metadata on the container
- for key, value in metadata.items():
- output_container.metadata[key] = value
-
- layout = 'mono' if waveform.shape[0] == 1 else 'stereo'
- # Set up the output stream with appropriate properties
- if format == "opus":
- out_stream = output_container.add_stream("libopus", rate=sample_rate, layout=layout)
- if quality == "64k":
- out_stream.bit_rate = 64000
- elif quality == "96k":
- out_stream.bit_rate = 96000
- elif quality == "128k":
- out_stream.bit_rate = 128000
- elif quality == "192k":
- out_stream.bit_rate = 192000
- elif quality == "320k":
- out_stream.bit_rate = 320000
- elif format == "mp3":
- out_stream = output_container.add_stream("libmp3lame", rate=sample_rate, layout=layout)
- if quality == "V0":
- #TODO i would really love to support V3 and V5 but there doesn't seem to be a way to set the qscale level, the property below is a bool
- out_stream.codec_context.qscale = 1
- elif quality == "128k":
- out_stream.bit_rate = 128000
- elif quality == "320k":
- out_stream.bit_rate = 320000
- else: #format == "flac":
- out_stream = output_container.add_stream("flac", rate=sample_rate, layout=layout)
-
- frame = av.AudioFrame.from_ndarray(waveform.movedim(0, 1).reshape(1, -1).float().numpy(), format='flt', layout=layout)
- frame.sample_rate = sample_rate
- frame.pts = 0
- output_container.mux(out_stream.encode(frame))
-
- # Flush encoder
- output_container.mux(out_stream.encode(None))
-
- # Close containers
- output_container.close()
-
- # Write the output to file
- output_buffer.seek(0)
- with open(output_path, 'wb') as f:
- f.write(output_buffer.getbuffer())
-
- results.append({
- "filename": file,
- "subfolder": subfolder,
- "type": self.type
- })
- counter += 1
-
- return { "ui": { "audio": results } }
-
-class SaveAudio:
- def __init__(self):
- self.output_dir = folder_paths.get_output_directory()
- self.type = "output"
- self.prefix_append = ""
+class SaveAudio(IO.ComfyNode):
+ @classmethod
+ def define_schema(cls):
+ return IO.Schema(
+ node_id="SaveAudio",
+ display_name="Save Audio (FLAC)",
+ category="audio",
+ inputs=[
+ IO.Audio.Input("audio"),
+ IO.String.Input("filename_prefix", default="audio/ComfyUI"),
+ ],
+ hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo],
+ is_output_node=True,
+ )
@classmethod
- def INPUT_TYPES(s):
- return {"required": { "audio": ("AUDIO", ),
- "filename_prefix": ("STRING", {"default": "audio/ComfyUI"}),
- },
- "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
- }
+ def execute(cls, audio, filename_prefix="ComfyUI", format="flac") -> IO.NodeOutput:
+ return IO.NodeOutput(
+ ui=UI.AudioSaveHelper.get_save_audio_ui(audio, filename_prefix=filename_prefix, cls=cls, format=format)
+ )
- RETURN_TYPES = ()
- FUNCTION = "save_flac"
+ save_flac = execute # TODO: remove
- OUTPUT_NODE = True
- CATEGORY = "audio"
-
- def save_flac(self, audio, filename_prefix="ComfyUI", format="flac", prompt=None, extra_pnginfo=None):
- return save_audio(self, audio, filename_prefix, format, prompt, extra_pnginfo)
-
-class SaveAudioMP3:
- def __init__(self):
- self.output_dir = folder_paths.get_output_directory()
- self.type = "output"
- self.prefix_append = ""
+class SaveAudioMP3(IO.ComfyNode):
+ @classmethod
+ def define_schema(cls):
+ return IO.Schema(
+ node_id="SaveAudioMP3",
+ display_name="Save Audio (MP3)",
+ category="audio",
+ inputs=[
+ IO.Audio.Input("audio"),
+ IO.String.Input("filename_prefix", default="audio/ComfyUI"),
+ IO.Combo.Input("quality", options=["V0", "128k", "320k"], default="V0"),
+ ],
+ hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo],
+ is_output_node=True,
+ )
@classmethod
- def INPUT_TYPES(s):
- return {"required": { "audio": ("AUDIO", ),
- "filename_prefix": ("STRING", {"default": "audio/ComfyUI"}),
- "quality": (["V0", "128k", "320k"], {"default": "V0"}),
- },
- "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
- }
+ def execute(cls, audio, filename_prefix="ComfyUI", format="mp3", quality="128k") -> IO.NodeOutput:
+ return IO.NodeOutput(
+ ui=UI.AudioSaveHelper.get_save_audio_ui(
+ audio, filename_prefix=filename_prefix, cls=cls, format=format, quality=quality
+ )
+ )
- RETURN_TYPES = ()
- FUNCTION = "save_mp3"
+ save_mp3 = execute # TODO: remove
- OUTPUT_NODE = True
- CATEGORY = "audio"
-
- def save_mp3(self, audio, filename_prefix="ComfyUI", format="mp3", prompt=None, extra_pnginfo=None, quality="128k"):
- return save_audio(self, audio, filename_prefix, format, prompt, extra_pnginfo, quality)
-
-class SaveAudioOpus:
- def __init__(self):
- self.output_dir = folder_paths.get_output_directory()
- self.type = "output"
- self.prefix_append = ""
+class SaveAudioOpus(IO.ComfyNode):
+ @classmethod
+ def define_schema(cls):
+ return IO.Schema(
+ node_id="SaveAudioOpus",
+ display_name="Save Audio (Opus)",
+ category="audio",
+ inputs=[
+ IO.Audio.Input("audio"),
+ IO.String.Input("filename_prefix", default="audio/ComfyUI"),
+ IO.Combo.Input("quality", options=["64k", "96k", "128k", "192k", "320k"], default="128k"),
+ ],
+ hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo],
+ is_output_node=True,
+ )
@classmethod
- def INPUT_TYPES(s):
- return {"required": { "audio": ("AUDIO", ),
- "filename_prefix": ("STRING", {"default": "audio/ComfyUI"}),
- "quality": (["64k", "96k", "128k", "192k", "320k"], {"default": "128k"}),
- },
- "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
- }
+ def execute(cls, audio, filename_prefix="ComfyUI", format="opus", quality="V3") -> IO.NodeOutput:
+ return IO.NodeOutput(
+ ui=UI.AudioSaveHelper.get_save_audio_ui(
+ audio, filename_prefix=filename_prefix, cls=cls, format=format, quality=quality
+ )
+ )
- RETURN_TYPES = ()
- FUNCTION = "save_opus"
+ save_opus = execute # TODO: remove
- OUTPUT_NODE = True
- CATEGORY = "audio"
-
- def save_opus(self, audio, filename_prefix="ComfyUI", format="opus", prompt=None, extra_pnginfo=None, quality="V3"):
- return save_audio(self, audio, filename_prefix, format, prompt, extra_pnginfo, quality)
-
-class PreviewAudio(SaveAudio):
- def __init__(self):
- self.output_dir = folder_paths.get_temp_directory()
- self.type = "temp"
- self.prefix_append = "_temp_" + ''.join(random.choice("abcdefghijklmnopqrstupvxyz") for x in range(5))
+class PreviewAudio(IO.ComfyNode):
+ @classmethod
+ def define_schema(cls):
+ return IO.Schema(
+ node_id="PreviewAudio",
+ display_name="Preview Audio",
+ category="audio",
+ inputs=[
+ IO.Audio.Input("audio"),
+ ],
+ hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo],
+ is_output_node=True,
+ )
@classmethod
- def INPUT_TYPES(s):
- return {"required":
- {"audio": ("AUDIO", ), },
- "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
- }
+ def execute(cls, audio) -> IO.NodeOutput:
+ return IO.NodeOutput(ui=UI.PreviewAudio(audio, cls=cls))
+
+ save_flac = execute # TODO: remove
+
def f32_pcm(wav: torch.Tensor) -> torch.Tensor:
"""Convert audio to float 32 bits PCM format."""
@@ -316,26 +252,30 @@ def load(filepath: str) -> tuple[torch.Tensor, int]:
wav = f32_pcm(wav)
return wav, sr
-class LoadAudio:
+class LoadAudio(IO.ComfyNode):
@classmethod
- def INPUT_TYPES(s):
+ def define_schema(cls):
input_dir = folder_paths.get_input_directory()
files = folder_paths.filter_files_content_types(os.listdir(input_dir), ["audio", "video"])
- return {"required": {"audio": (sorted(files), {"audio_upload": True})}}
+ return IO.Schema(
+ node_id="LoadAudio",
+ display_name="Load Audio",
+ category="audio",
+ inputs=[
+ IO.Combo.Input("audio", upload=IO.UploadType.audio, options=sorted(files)),
+ ],
+ outputs=[IO.Audio.Output()],
+ )
- CATEGORY = "audio"
-
- RETURN_TYPES = ("AUDIO", )
- FUNCTION = "load"
-
- def load(self, audio):
+ @classmethod
+ def execute(cls, audio) -> IO.NodeOutput:
audio_path = folder_paths.get_annotated_filepath(audio)
waveform, sample_rate = load(audio_path)
audio = {"waveform": waveform.unsqueeze(0), "sample_rate": sample_rate}
- return (audio, )
+ return IO.NodeOutput(audio)
@classmethod
- def IS_CHANGED(s, audio):
+ def fingerprint_inputs(cls, audio):
image_path = folder_paths.get_annotated_filepath(audio)
m = hashlib.sha256()
with open(image_path, 'rb') as f:
@@ -343,46 +283,69 @@ class LoadAudio:
return m.digest().hex()
@classmethod
- def VALIDATE_INPUTS(s, audio):
+ def validate_inputs(cls, audio):
if not folder_paths.exists_annotated_filepath(audio):
return "Invalid audio file: {}".format(audio)
return True
-class RecordAudio:
+ load = execute # TODO: remove
+
+
+class RecordAudio(IO.ComfyNode):
@classmethod
- def INPUT_TYPES(s):
- return {"required": {"audio": ("AUDIO_RECORD", {})}}
+ def define_schema(cls):
+ return IO.Schema(
+ node_id="RecordAudio",
+ display_name="Record Audio",
+ category="audio",
+ inputs=[
+ IO.Custom("AUDIO_RECORD").Input("audio"),
+ ],
+ outputs=[IO.Audio.Output()],
+ )
- CATEGORY = "audio"
-
- RETURN_TYPES = ("AUDIO", )
- FUNCTION = "load"
-
- def load(self, audio):
+ @classmethod
+ def execute(cls, audio) -> IO.NodeOutput:
audio_path = folder_paths.get_annotated_filepath(audio)
waveform, sample_rate = load(audio_path)
audio = {"waveform": waveform.unsqueeze(0), "sample_rate": sample_rate}
- return (audio, )
+ return IO.NodeOutput(audio)
+
+ load = execute # TODO: remove
-class TrimAudioDuration:
+class TrimAudioDuration(IO.ComfyNode):
@classmethod
- def INPUT_TYPES(cls):
- return {
- "required": {
- "audio": ("AUDIO",),
- "start_index": ("FLOAT", {"default": 0.0, "min": -0xffffffffffffffff, "max": 0xffffffffffffffff, "step": 0.01, "tooltip": "Start time in seconds, can be negative to count from the end (supports sub-seconds)."}),
- "duration": ("FLOAT", {"default": 60.0, "min": 0.0, "step": 0.01, "tooltip": "Duration in seconds"}),
- },
- }
+ def define_schema(cls):
+ return IO.Schema(
+ node_id="TrimAudioDuration",
+ display_name="Trim Audio Duration",
+ description="Trim audio tensor into chosen time range.",
+ category="audio",
+ inputs=[
+ IO.Audio.Input("audio"),
+ IO.Float.Input(
+ "start_index",
+ default=0.0,
+ min=-0xffffffffffffffff,
+ max=0xffffffffffffffff,
+ step=0.01,
+ tooltip="Start time in seconds, can be negative to count from the end (supports sub-seconds).",
+ ),
+ IO.Float.Input(
+ "duration",
+ default=60.0,
+ min=0.0,
+ step=0.01,
+ tooltip="Duration in seconds",
+ ),
+ ],
+ outputs=[IO.Audio.Output()],
+ )
- FUNCTION = "trim"
- RETURN_TYPES = ("AUDIO",)
- CATEGORY = "audio"
- DESCRIPTION = "Trim audio tensor into chosen time range."
-
- def trim(self, audio, start_index, duration):
+ @classmethod
+ def execute(cls, audio, start_index, duration) -> IO.NodeOutput:
waveform = audio["waveform"]
sample_rate = audio["sample_rate"]
audio_length = waveform.shape[-1]
@@ -399,23 +362,30 @@ class TrimAudioDuration:
if start_frame >= end_frame:
raise ValueError("AudioTrim: Start time must be less than end time and be within the audio length.")
- return ({"waveform": waveform[..., start_frame:end_frame], "sample_rate": sample_rate},)
+ return IO.NodeOutput({"waveform": waveform[..., start_frame:end_frame], "sample_rate": sample_rate})
+
+ trim = execute # TODO: remove
-class SplitAudioChannels:
+class SplitAudioChannels(IO.ComfyNode):
@classmethod
- def INPUT_TYPES(s):
- return {"required": {
- "audio": ("AUDIO",),
- }}
+ def define_schema(cls):
+ return IO.Schema(
+ node_id="SplitAudioChannels",
+ display_name="Split Audio Channels",
+ description="Separates the audio into left and right channels.",
+ category="audio",
+ inputs=[
+ IO.Audio.Input("audio"),
+ ],
+ outputs=[
+ IO.Audio.Output(display_name="left"),
+ IO.Audio.Output(display_name="right"),
+ ],
+ )
- RETURN_TYPES = ("AUDIO", "AUDIO")
- RETURN_NAMES = ("left", "right")
- FUNCTION = "separate"
- CATEGORY = "audio"
- DESCRIPTION = "Separates the audio into left and right channels."
-
- def separate(self, audio):
+ @classmethod
+ def execute(cls, audio) -> IO.NodeOutput:
waveform = audio["waveform"]
sample_rate = audio["sample_rate"]
@@ -425,7 +395,9 @@ class SplitAudioChannels:
left_channel = waveform[..., 0:1, :]
right_channel = waveform[..., 1:2, :]
- return ({"waveform": left_channel, "sample_rate": sample_rate}, {"waveform": right_channel, "sample_rate": sample_rate})
+ return IO.NodeOutput({"waveform": left_channel, "sample_rate": sample_rate}, {"waveform": right_channel, "sample_rate": sample_rate})
+
+ separate = execute # TODO: remove
def match_audio_sample_rates(waveform_1, sample_rate_1, waveform_2, sample_rate_2):
@@ -443,21 +415,29 @@ def match_audio_sample_rates(waveform_1, sample_rate_1, waveform_2, sample_rate_
return waveform_1, waveform_2, output_sample_rate
-class AudioConcat:
+class AudioConcat(IO.ComfyNode):
@classmethod
- def INPUT_TYPES(s):
- return {"required": {
- "audio1": ("AUDIO",),
- "audio2": ("AUDIO",),
- "direction": (['after', 'before'], {"default": 'after', "tooltip": "Whether to append audio2 after or before audio1."}),
- }}
+ def define_schema(cls):
+ return IO.Schema(
+ node_id="AudioConcat",
+ display_name="Audio Concat",
+ description="Concatenates the audio1 to audio2 in the specified direction.",
+ category="audio",
+ inputs=[
+ IO.Audio.Input("audio1"),
+ IO.Audio.Input("audio2"),
+ IO.Combo.Input(
+ "direction",
+ options=['after', 'before'],
+ default="after",
+ tooltip="Whether to append audio2 after or before audio1.",
+ )
+ ],
+ outputs=[IO.Audio.Output()],
+ )
- RETURN_TYPES = ("AUDIO",)
- FUNCTION = "concat"
- CATEGORY = "audio"
- DESCRIPTION = "Concatenates the audio1 to audio2 in the specified direction."
-
- def concat(self, audio1, audio2, direction):
+ @classmethod
+ def execute(cls, audio1, audio2, direction) -> IO.NodeOutput:
waveform_1 = audio1["waveform"]
waveform_2 = audio2["waveform"]
sample_rate_1 = audio1["sample_rate"]
@@ -477,26 +457,33 @@ class AudioConcat:
elif direction == 'before':
concatenated_audio = torch.cat((waveform_2, waveform_1), dim=2)
- return ({"waveform": concatenated_audio, "sample_rate": output_sample_rate},)
+ return IO.NodeOutput({"waveform": concatenated_audio, "sample_rate": output_sample_rate})
+
+ concat = execute # TODO: remove
-class AudioMerge:
+class AudioMerge(IO.ComfyNode):
@classmethod
- def INPUT_TYPES(cls):
- return {
- "required": {
- "audio1": ("AUDIO",),
- "audio2": ("AUDIO",),
- "merge_method": (["add", "mean", "subtract", "multiply"], {"tooltip": "The method used to combine the audio waveforms."}),
- },
- }
+ def define_schema(cls):
+ return IO.Schema(
+ node_id="AudioMerge",
+ display_name="Audio Merge",
+ description="Combine two audio tracks by overlaying their waveforms.",
+ category="audio",
+ inputs=[
+ IO.Audio.Input("audio1"),
+ IO.Audio.Input("audio2"),
+ IO.Combo.Input(
+ "merge_method",
+ options=["add", "mean", "subtract", "multiply"],
+ tooltip="The method used to combine the audio waveforms.",
+ )
+ ],
+ outputs=[IO.Audio.Output()],
+ )
- FUNCTION = "merge"
- RETURN_TYPES = ("AUDIO",)
- CATEGORY = "audio"
- DESCRIPTION = "Combine two audio tracks by overlaying their waveforms."
-
- def merge(self, audio1, audio2, merge_method):
+ @classmethod
+ def execute(cls, audio1, audio2, merge_method) -> IO.NodeOutput:
waveform_1 = audio1["waveform"]
waveform_2 = audio2["waveform"]
sample_rate_1 = audio1["sample_rate"]
@@ -530,85 +517,108 @@ class AudioMerge:
if max_val > 1.0:
waveform = waveform / max_val
- return ({"waveform": waveform, "sample_rate": output_sample_rate},)
+ return IO.NodeOutput({"waveform": waveform, "sample_rate": output_sample_rate})
+
+ merge = execute # TODO: remove
-class AudioAdjustVolume:
+class AudioAdjustVolume(IO.ComfyNode):
@classmethod
- def INPUT_TYPES(s):
- return {"required": {
- "audio": ("AUDIO",),
- "volume": ("INT", {"default": 1.0, "min": -100, "max": 100, "tooltip": "Volume adjustment in decibels (dB). 0 = no change, +6 = double, -6 = half, etc"}),
- }}
+ def define_schema(cls):
+ return IO.Schema(
+ node_id="AudioAdjustVolume",
+ display_name="Audio Adjust Volume",
+ category="audio",
+ inputs=[
+ IO.Audio.Input("audio"),
+ IO.Int.Input(
+ "volume",
+ default=1,
+ min=-100,
+ max=100,
+ tooltip="Volume adjustment in decibels (dB). 0 = no change, +6 = double, -6 = half, etc",
+ )
+ ],
+ outputs=[IO.Audio.Output()],
+ )
- RETURN_TYPES = ("AUDIO",)
- FUNCTION = "adjust_volume"
- CATEGORY = "audio"
-
- def adjust_volume(self, audio, volume):
+ @classmethod
+ def execute(cls, audio, volume) -> IO.NodeOutput:
if volume == 0:
- return (audio,)
+ return IO.NodeOutput(audio)
waveform = audio["waveform"]
sample_rate = audio["sample_rate"]
gain = 10 ** (volume / 20)
waveform = waveform * gain
- return ({"waveform": waveform, "sample_rate": sample_rate},)
+ return IO.NodeOutput({"waveform": waveform, "sample_rate": sample_rate})
+
+ adjust_volume = execute # TODO: remove
-class EmptyAudio:
+class EmptyAudio(IO.ComfyNode):
@classmethod
- def INPUT_TYPES(s):
- return {"required": {
- "duration": ("FLOAT", {"default": 60.0, "min": 0.0, "max": 0xffffffffffffffff, "step": 0.01, "tooltip": "Duration of the empty audio clip in seconds"}),
- "sample_rate": ("INT", {"default": 44100, "tooltip": "Sample rate of the empty audio clip."}),
- "channels": ("INT", {"default": 2, "min": 1, "max": 2, "tooltip": "Number of audio channels (1 for mono, 2 for stereo)."}),
- }}
+ def define_schema(cls):
+ return IO.Schema(
+ node_id="EmptyAudio",
+ display_name="Empty Audio",
+ category="audio",
+ inputs=[
+ IO.Float.Input(
+ "duration",
+ default=60.0,
+ min=0.0,
+ max=0xffffffffffffffff,
+ step=0.01,
+ tooltip="Duration of the empty audio clip in seconds",
+ ),
+ IO.Float.Input(
+ "sample_rate",
+ default=44100,
+ tooltip="Sample rate of the empty audio clip.",
+ ),
+ IO.Float.Input(
+ "channels",
+ default=2,
+ min=1,
+ max=2,
+ tooltip="Number of audio channels (1 for mono, 2 for stereo).",
+ ),
+ ],
+ outputs=[IO.Audio.Output()],
+ )
- RETURN_TYPES = ("AUDIO",)
- FUNCTION = "create_empty_audio"
- CATEGORY = "audio"
-
- def create_empty_audio(self, duration, sample_rate, channels):
+ @classmethod
+ def execute(cls, duration, sample_rate, channels) -> IO.NodeOutput:
num_samples = int(round(duration * sample_rate))
waveform = torch.zeros((1, channels, num_samples), dtype=torch.float32)
- return ({"waveform": waveform, "sample_rate": sample_rate},)
+ return IO.NodeOutput({"waveform": waveform, "sample_rate": sample_rate})
+
+ create_empty_audio = execute # TODO: remove
-NODE_CLASS_MAPPINGS = {
- "EmptyLatentAudio": EmptyLatentAudio,
- "VAEEncodeAudio": VAEEncodeAudio,
- "VAEDecodeAudio": VAEDecodeAudio,
- "SaveAudio": SaveAudio,
- "SaveAudioMP3": SaveAudioMP3,
- "SaveAudioOpus": SaveAudioOpus,
- "LoadAudio": LoadAudio,
- "PreviewAudio": PreviewAudio,
- "ConditioningStableAudio": ConditioningStableAudio,
- "RecordAudio": RecordAudio,
- "TrimAudioDuration": TrimAudioDuration,
- "SplitAudioChannels": SplitAudioChannels,
- "AudioConcat": AudioConcat,
- "AudioMerge": AudioMerge,
- "AudioAdjustVolume": AudioAdjustVolume,
- "EmptyAudio": EmptyAudio,
-}
+class AudioExtension(ComfyExtension):
+ @override
+ async def get_node_list(self) -> list[type[IO.ComfyNode]]:
+ return [
+ EmptyLatentAudio,
+ VAEEncodeAudio,
+ VAEDecodeAudio,
+ SaveAudio,
+ SaveAudioMP3,
+ SaveAudioOpus,
+ LoadAudio,
+ PreviewAudio,
+ ConditioningStableAudio,
+ RecordAudio,
+ TrimAudioDuration,
+ SplitAudioChannels,
+ AudioConcat,
+ AudioMerge,
+ AudioAdjustVolume,
+ EmptyAudio,
+ ]
-NODE_DISPLAY_NAME_MAPPINGS = {
- "EmptyLatentAudio": "Empty Latent Audio",
- "VAEEncodeAudio": "VAE Encode Audio",
- "VAEDecodeAudio": "VAE Decode Audio",
- "PreviewAudio": "Preview Audio",
- "LoadAudio": "Load Audio",
- "SaveAudio": "Save Audio (FLAC)",
- "SaveAudioMP3": "Save Audio (MP3)",
- "SaveAudioOpus": "Save Audio (Opus)",
- "RecordAudio": "Record Audio",
- "TrimAudioDuration": "Trim Audio Duration",
- "SplitAudioChannels": "Split Audio Channels",
- "AudioConcat": "Audio Concat",
- "AudioMerge": "Audio Merge",
- "AudioAdjustVolume": "Audio Adjust Volume",
- "EmptyAudio": "Empty Audio",
-}
+async def comfy_entrypoint() -> AudioExtension:
+ return AudioExtension()
From ecdc8697d53919a9178bf53ef327a110582db8ea Mon Sep 17 00:00:00 2001
From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com>
Date: Wed, 3 Dec 2025 19:49:28 -0800
Subject: [PATCH 22/81] Qwen Image Lora training fix from #11090 (#11094)
---
comfy_extras/nodes_train.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/comfy_extras/nodes_train.py b/comfy_extras/nodes_train.py
index cb24ab709..19b8baaf4 100644
--- a/comfy_extras/nodes_train.py
+++ b/comfy_extras/nodes_train.py
@@ -623,7 +623,7 @@ class TrainLoraNode(io.ComfyNode):
noise = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(seed)
if multi_res:
# use first latent as dummy latent if multi_res
- latents = latents[0].repeat(num_images, 1, 1, 1)
+ latents = latents[0].repeat((num_images,) + ((1,) * (latents[0].ndim - 1)))
guider.sample(
noise.generate_noise({"samples": latents}),
latents,
From ea17add3c62197b10fd0b71d9169d339adc55c47 Mon Sep 17 00:00:00 2001
From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com>
Date: Wed, 3 Dec 2025 20:15:15 -0800
Subject: [PATCH 23/81] Fix case where text encoders where running on the CPU
instead of GPU. (#11095)
---
comfy/sd.py | 2 ++
comfy/sd1_clip.py | 9 ++++++++-
2 files changed, 10 insertions(+), 1 deletion(-)
diff --git a/comfy/sd.py b/comfy/sd.py
index f9e5efab5..734bd2845 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -193,6 +193,7 @@ class CLIP:
self.cond_stage_model.set_clip_options({"projected_pooled": False})
self.load_model()
+ self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device})
all_hooks.reset()
self.patcher.patch_hooks(None)
if show_pbar:
@@ -240,6 +241,7 @@ class CLIP:
self.cond_stage_model.set_clip_options({"projected_pooled": False})
self.load_model()
+ self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device})
o = self.cond_stage_model.encode_token_weights(tokens)
cond, pooled = o[:2]
if return_dict:
diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py
index 0fc9ab3db..503a51843 100644
--- a/comfy/sd1_clip.py
+++ b/comfy/sd1_clip.py
@@ -147,6 +147,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
self.layer_norm_hidden_state = layer_norm_hidden_state
self.return_projected_pooled = return_projected_pooled
self.return_attention_masks = return_attention_masks
+ self.execution_device = None
if layer == "hidden":
assert layer_idx is not None
@@ -163,6 +164,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
def set_clip_options(self, options):
layer_idx = options.get("layer", self.layer_idx)
self.return_projected_pooled = options.get("projected_pooled", self.return_projected_pooled)
+ self.execution_device = options.get("execution_device", self.execution_device)
if isinstance(self.layer, list) or self.layer == "all":
pass
elif layer_idx is None or abs(layer_idx) > self.num_layers:
@@ -175,6 +177,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
self.layer = self.options_default[0]
self.layer_idx = self.options_default[1]
self.return_projected_pooled = self.options_default[2]
+ self.execution_device = None
def process_tokens(self, tokens, device):
end_token = self.special_tokens.get("end", None)
@@ -258,7 +261,11 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
return torch.cat(embeds_out), torch.tensor(attention_masks, device=device, dtype=torch.long), num_tokens, embeds_info
def forward(self, tokens):
- device = self.transformer.get_input_embeddings().weight.device
+ if self.execution_device is None:
+ device = self.transformer.get_input_embeddings().weight.device
+ else:
+ device = self.execution_device
+
embeds, attention_mask, num_tokens, embeds_info = self.process_tokens(tokens, device)
attention_mask_model = None
From 6be85c7920224b45bbc6417e00147815e78c12a9 Mon Sep 17 00:00:00 2001
From: rattus <46076784+rattus128@users.noreply.github.com>
Date: Thu, 4 Dec 2025 14:28:44 +1000
Subject: [PATCH 24/81] mp: use look-ahead actuals for stream offload VRAM
calculation (#11096)
TIL that the WAN TE has a 2GB weight followed by 16MB as the next size
down. This means that team 8GB VRAM would fully offload the TE in async
offload mode as it just multiplied this giant size my the num streams.
Do the more complex logic of summing up the upcoming to-load weight
sizes to avoid triple counting this massive weight.
partial unload does the converse of recording the NS most recent
unloads as they go.
---
comfy/model_patcher.py | 12 +++++++++---
1 file changed, 9 insertions(+), 3 deletions(-)
diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py
index df2d8e827..3dcac3eef 100644
--- a/comfy/model_patcher.py
+++ b/comfy/model_patcher.py
@@ -699,12 +699,12 @@ class ModelPatcher:
offloaded = []
offload_buffer = 0
loading.sort(reverse=True)
- for x in loading:
+ for i, x in enumerate(loading):
module_offload_mem, module_mem, n, m, params = x
lowvram_weight = False
- potential_offload = max(offload_buffer, module_offload_mem + (comfy.model_management.NUM_STREAMS * module_mem))
+ potential_offload = max(offload_buffer, module_offload_mem + sum([ x1[1] for x1 in loading[i+1:i+1+comfy.model_management.NUM_STREAMS]]))
lowvram_fits = mem_counter + module_mem + potential_offload < lowvram_model_memory
weight_key = "{}.weight".format(n)
@@ -876,14 +876,18 @@ class ModelPatcher:
patch_counter = 0
unload_list = self._load_list()
unload_list.sort()
+
offload_buffer = self.model.model_offload_buffer_memory
+ if len(unload_list) > 0:
+ NS = comfy.model_management.NUM_STREAMS
+ offload_weight_factor = [ min(offload_buffer / (NS + 1), unload_list[0][1]) ] * NS
for unload in unload_list:
if memory_to_free + offload_buffer - self.model.model_offload_buffer_memory < memory_freed:
break
module_offload_mem, module_mem, n, m, params = unload
- potential_offload = module_offload_mem + (comfy.model_management.NUM_STREAMS * module_mem)
+ potential_offload = module_offload_mem + sum(offload_weight_factor)
lowvram_possible = hasattr(m, "comfy_cast_weights")
if hasattr(m, "comfy_patched_weights") and m.comfy_patched_weights == True:
@@ -935,6 +939,8 @@ class ModelPatcher:
m.comfy_patched_weights = False
memory_freed += module_mem
offload_buffer = max(offload_buffer, potential_offload)
+ offload_weight_factor.append(module_mem)
+ offload_weight_factor.pop(0)
logging.debug("freed {}".format(n))
for param in params:
From f4bdf5f8302ef10db99644a8672e614ddb29c473 Mon Sep 17 00:00:00 2001
From: rattus <46076784+rattus128@users.noreply.github.com>
Date: Fri, 5 Dec 2025 03:50:04 +1000
Subject: [PATCH 25/81] sd: revise hy VAE VRAM (#11105)
This was recently collapsed down to rolling VAE through temporal. Clamp
The time dimension.
---
comfy/sd.py | 6 ++++--
1 file changed, 4 insertions(+), 2 deletions(-)
diff --git a/comfy/sd.py b/comfy/sd.py
index 734bd2845..fe4dd65f8 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -483,8 +483,10 @@ class VAE:
self.latent_dim = 3
self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.conv.weight"].shape[1]
self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=sd['post_quant_conv.weight'].shape[1])
- self.memory_used_decode = lambda shape, dtype: (1500 * shape[2] * shape[3] * shape[4] * (4 * 8 * 8)) * model_management.dtype_size(dtype)
- self.memory_used_encode = lambda shape, dtype: (900 * max(shape[2], 2) * shape[3] * shape[4]) * model_management.dtype_size(dtype)
+ #This is likely to significantly over-estimate with single image or low frame counts as the
+ #implementation is able to completely skip caching. Rework if used as an image only VAE
+ self.memory_used_decode = lambda shape, dtype: (2800 * min(8, ((shape[2] - 1) * 4) + 1) * shape[3] * shape[4] * (8 * 8)) * model_management.dtype_size(dtype)
+ self.memory_used_encode = lambda shape, dtype: (1400 * min(9, shape[2]) * shape[3] * shape[4]) * model_management.dtype_size(dtype)
self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32]
elif "decoder.unpatcher3d.wavelets" in sd:
self.upscale_ratio = (lambda a: max(0, a * 8 - 7), 8, 8)
From 9bc893c5bbd2838bdd15ebd40e3a3e548ce3e4f0 Mon Sep 17 00:00:00 2001
From: rattus <46076784+rattus128@users.noreply.github.com>
Date: Fri, 5 Dec 2025 03:50:36 +1000
Subject: [PATCH 26/81] sd: bump HY1.5 VAE estimate (#11107)
Im able to push vram above estimate on partial unload. Bump the
estimate. This is experimentally determined with a 720P and 480P
datapoint calibrating for 24GB VRAM total.
---
comfy/sd.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/comfy/sd.py b/comfy/sd.py
index fe4dd65f8..03bdb33d5 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -471,7 +471,7 @@ class VAE:
decoder_config={'target': "comfy.ldm.hunyuan_video.vae_refiner.Decoder", 'params': ddconfig})
self.memory_used_encode = lambda shape, dtype: (1400 * 9 * shape[-2] * shape[-1]) * model_management.dtype_size(dtype)
- self.memory_used_decode = lambda shape, dtype: (2800 * 4 * shape[-2] * shape[-1] * 16 * 16) * model_management.dtype_size(dtype)
+ self.memory_used_decode = lambda shape, dtype: (3600 * 4 * shape[-2] * shape[-1] * 16 * 16) * model_management.dtype_size(dtype)
elif "decoder.conv_in.conv.weight" in sd:
ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
ddconfig["conv3d"] = True
From 3c8456223c5f6a41af7d99219b391c8c58acb552 Mon Sep 17 00:00:00 2001
From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com>
Date: Fri, 5 Dec 2025 00:05:28 +0200
Subject: [PATCH 27/81] [API Nodes]: fixes and refactor (#11104)
* chore(api-nodes): applied ruff's pyupgrade(python3.10) to api-nodes client's to folder
* chore(api-nodes): add validate_video_frame_count function from LTX PR
* chore(api-nodes): replace deprecated V1 imports
* fix(api-nodes): the types returned by the "poll_op" function are now correct.
---
comfy_api_nodes/util/__init__.py | 2 +
comfy_api_nodes/util/_helpers.py | 14 +--
comfy_api_nodes/util/client.py | 145 ++++++++++++-----------
comfy_api_nodes/util/conversions.py | 21 ++--
comfy_api_nodes/util/download_helpers.py | 20 ++--
comfy_api_nodes/util/request_logger.py | 2 -
comfy_api_nodes/util/upload_helpers.py | 16 ++-
comfy_api_nodes/util/validation_utils.py | 61 ++++++----
8 files changed, 146 insertions(+), 135 deletions(-)
diff --git a/comfy_api_nodes/util/__init__.py b/comfy_api_nodes/util/__init__.py
index 80292fb3c..4cc22abfb 100644
--- a/comfy_api_nodes/util/__init__.py
+++ b/comfy_api_nodes/util/__init__.py
@@ -47,6 +47,7 @@ from .validation_utils import (
validate_string,
validate_video_dimensions,
validate_video_duration,
+ validate_video_frame_count,
)
__all__ = [
@@ -94,6 +95,7 @@ __all__ = [
"validate_string",
"validate_video_dimensions",
"validate_video_duration",
+ "validate_video_frame_count",
# Misc functions
"get_fs_object_size",
]
diff --git a/comfy_api_nodes/util/_helpers.py b/comfy_api_nodes/util/_helpers.py
index 328fe5227..491e6b6a8 100644
--- a/comfy_api_nodes/util/_helpers.py
+++ b/comfy_api_nodes/util/_helpers.py
@@ -2,8 +2,8 @@ import asyncio
import contextlib
import os
import time
+from collections.abc import Callable
from io import BytesIO
-from typing import Callable, Optional, Union
from comfy.cli_args import args
from comfy.model_management import processing_interrupted
@@ -35,12 +35,12 @@ def default_base_url() -> str:
async def sleep_with_interrupt(
seconds: float,
- node_cls: Optional[type[IO.ComfyNode]],
- label: Optional[str] = None,
- start_ts: Optional[float] = None,
- estimated_total: Optional[int] = None,
+ node_cls: type[IO.ComfyNode] | None,
+ label: str | None = None,
+ start_ts: float | None = None,
+ estimated_total: int | None = None,
*,
- display_callback: Optional[Callable[[type[IO.ComfyNode], str, int, Optional[int]], None]] = None,
+ display_callback: Callable[[type[IO.ComfyNode], str, int, int | None], None] | None = None,
):
"""
Sleep in 1s slices while:
@@ -65,7 +65,7 @@ def mimetype_to_extension(mime_type: str) -> str:
return mime_type.split("/")[-1].lower()
-def get_fs_object_size(path_or_object: Union[str, BytesIO]) -> int:
+def get_fs_object_size(path_or_object: str | BytesIO) -> int:
if isinstance(path_or_object, str):
return os.path.getsize(path_or_object)
return len(path_or_object.getvalue())
diff --git a/comfy_api_nodes/util/client.py b/comfy_api_nodes/util/client.py
index bf01d7d36..bf37cba5f 100644
--- a/comfy_api_nodes/util/client.py
+++ b/comfy_api_nodes/util/client.py
@@ -4,10 +4,11 @@ import json
import logging
import time
import uuid
+from collections.abc import Callable, Iterable
from dataclasses import dataclass
from enum import Enum
from io import BytesIO
-from typing import Any, Callable, Iterable, Literal, Optional, Type, TypeVar, Union
+from typing import Any, Literal, TypeVar
from urllib.parse import urljoin, urlparse
import aiohttp
@@ -37,8 +38,8 @@ class ApiEndpoint:
path: str,
method: Literal["GET", "POST", "PUT", "DELETE", "PATCH"] = "GET",
*,
- query_params: Optional[dict[str, Any]] = None,
- headers: Optional[dict[str, str]] = None,
+ query_params: dict[str, Any] | None = None,
+ headers: dict[str, str] | None = None,
):
self.path = path
self.method = method
@@ -52,18 +53,18 @@ class _RequestConfig:
endpoint: ApiEndpoint
timeout: float
content_type: str
- data: Optional[dict[str, Any]]
- files: Optional[Union[dict[str, Any], list[tuple[str, Any]]]]
- multipart_parser: Optional[Callable]
+ data: dict[str, Any] | None
+ files: dict[str, Any] | list[tuple[str, Any]] | None
+ multipart_parser: Callable | None
max_retries: int
retry_delay: float
retry_backoff: float
wait_label: str = "Waiting"
monitor_progress: bool = True
- estimated_total: Optional[int] = None
- final_label_on_success: Optional[str] = "Completed"
- progress_origin_ts: Optional[float] = None
- price_extractor: Optional[Callable[[dict[str, Any]], Optional[float]]] = None
+ estimated_total: int | None = None
+ final_label_on_success: str | None = "Completed"
+ progress_origin_ts: float | None = None
+ price_extractor: Callable[[dict[str, Any]], float | None] | None = None
@dataclass
@@ -71,10 +72,10 @@ class _PollUIState:
started: float
status_label: str = "Queued"
is_queued: bool = True
- price: Optional[float] = None
- estimated_duration: Optional[int] = None
+ price: float | None = None
+ estimated_duration: int | None = None
base_processing_elapsed: float = 0.0 # sum of completed active intervals
- active_since: Optional[float] = None # start time of current active interval (None if queued)
+ active_since: float | None = None # start time of current active interval (None if queued)
_RETRY_STATUS = {408, 429, 500, 502, 503, 504}
@@ -87,20 +88,20 @@ async def sync_op(
cls: type[IO.ComfyNode],
endpoint: ApiEndpoint,
*,
- response_model: Type[M],
- price_extractor: Optional[Callable[[M], Optional[float]]] = None,
- data: Optional[BaseModel] = None,
- files: Optional[Union[dict[str, Any], list[tuple[str, Any]]]] = None,
+ response_model: type[M],
+ price_extractor: Callable[[M | Any], float | None] | None = None,
+ data: BaseModel | None = None,
+ files: dict[str, Any] | list[tuple[str, Any]] | None = None,
content_type: str = "application/json",
timeout: float = 3600.0,
- multipart_parser: Optional[Callable] = None,
+ multipart_parser: Callable | None = None,
max_retries: int = 3,
retry_delay: float = 1.0,
retry_backoff: float = 2.0,
wait_label: str = "Waiting for server",
- estimated_duration: Optional[int] = None,
- final_label_on_success: Optional[str] = "Completed",
- progress_origin_ts: Optional[float] = None,
+ estimated_duration: int | None = None,
+ final_label_on_success: str | None = "Completed",
+ progress_origin_ts: float | None = None,
monitor_progress: bool = True,
) -> M:
raw = await sync_op_raw(
@@ -131,22 +132,22 @@ async def poll_op(
cls: type[IO.ComfyNode],
poll_endpoint: ApiEndpoint,
*,
- response_model: Type[M],
- status_extractor: Callable[[M], Optional[Union[str, int]]],
- progress_extractor: Optional[Callable[[M], Optional[int]]] = None,
- price_extractor: Optional[Callable[[M], Optional[float]]] = None,
- completed_statuses: Optional[list[Union[str, int]]] = None,
- failed_statuses: Optional[list[Union[str, int]]] = None,
- queued_statuses: Optional[list[Union[str, int]]] = None,
- data: Optional[BaseModel] = None,
+ response_model: type[M],
+ status_extractor: Callable[[M | Any], str | int | None],
+ progress_extractor: Callable[[M | Any], int | None] | None = None,
+ price_extractor: Callable[[M | Any], float | None] | None = None,
+ completed_statuses: list[str | int] | None = None,
+ failed_statuses: list[str | int] | None = None,
+ queued_statuses: list[str | int] | None = None,
+ data: BaseModel | None = None,
poll_interval: float = 5.0,
max_poll_attempts: int = 120,
timeout_per_poll: float = 120.0,
max_retries_per_poll: int = 3,
retry_delay_per_poll: float = 1.0,
retry_backoff_per_poll: float = 2.0,
- estimated_duration: Optional[int] = None,
- cancel_endpoint: Optional[ApiEndpoint] = None,
+ estimated_duration: int | None = None,
+ cancel_endpoint: ApiEndpoint | None = None,
cancel_timeout: float = 10.0,
) -> M:
raw = await poll_op_raw(
@@ -178,22 +179,22 @@ async def sync_op_raw(
cls: type[IO.ComfyNode],
endpoint: ApiEndpoint,
*,
- price_extractor: Optional[Callable[[dict[str, Any]], Optional[float]]] = None,
- data: Optional[Union[dict[str, Any], BaseModel]] = None,
- files: Optional[Union[dict[str, Any], list[tuple[str, Any]]]] = None,
+ price_extractor: Callable[[dict[str, Any]], float | None] | None = None,
+ data: dict[str, Any] | BaseModel | None = None,
+ files: dict[str, Any] | list[tuple[str, Any]] | None = None,
content_type: str = "application/json",
timeout: float = 3600.0,
- multipart_parser: Optional[Callable] = None,
+ multipart_parser: Callable | None = None,
max_retries: int = 3,
retry_delay: float = 1.0,
retry_backoff: float = 2.0,
wait_label: str = "Waiting for server",
- estimated_duration: Optional[int] = None,
+ estimated_duration: int | None = None,
as_binary: bool = False,
- final_label_on_success: Optional[str] = "Completed",
- progress_origin_ts: Optional[float] = None,
+ final_label_on_success: str | None = "Completed",
+ progress_origin_ts: float | None = None,
monitor_progress: bool = True,
-) -> Union[dict[str, Any], bytes]:
+) -> dict[str, Any] | bytes:
"""
Make a single network request.
- If as_binary=False (default): returns JSON dict (or {'_raw': ''} if non-JSON).
@@ -229,21 +230,21 @@ async def poll_op_raw(
cls: type[IO.ComfyNode],
poll_endpoint: ApiEndpoint,
*,
- status_extractor: Callable[[dict[str, Any]], Optional[Union[str, int]]],
- progress_extractor: Optional[Callable[[dict[str, Any]], Optional[int]]] = None,
- price_extractor: Optional[Callable[[dict[str, Any]], Optional[float]]] = None,
- completed_statuses: Optional[list[Union[str, int]]] = None,
- failed_statuses: Optional[list[Union[str, int]]] = None,
- queued_statuses: Optional[list[Union[str, int]]] = None,
- data: Optional[Union[dict[str, Any], BaseModel]] = None,
+ status_extractor: Callable[[dict[str, Any]], str | int | None],
+ progress_extractor: Callable[[dict[str, Any]], int | None] | None = None,
+ price_extractor: Callable[[dict[str, Any]], float | None] | None = None,
+ completed_statuses: list[str | int] | None = None,
+ failed_statuses: list[str | int] | None = None,
+ queued_statuses: list[str | int] | None = None,
+ data: dict[str, Any] | BaseModel | None = None,
poll_interval: float = 5.0,
max_poll_attempts: int = 120,
timeout_per_poll: float = 120.0,
max_retries_per_poll: int = 3,
retry_delay_per_poll: float = 1.0,
retry_backoff_per_poll: float = 2.0,
- estimated_duration: Optional[int] = None,
- cancel_endpoint: Optional[ApiEndpoint] = None,
+ estimated_duration: int | None = None,
+ cancel_endpoint: ApiEndpoint | None = None,
cancel_timeout: float = 10.0,
) -> dict[str, Any]:
"""
@@ -261,7 +262,7 @@ async def poll_op_raw(
consumed_attempts = 0 # counts only non-queued polls
progress_bar = utils.ProgressBar(100) if progress_extractor else None
- last_progress: Optional[int] = None
+ last_progress: int | None = None
state = _PollUIState(started=started, estimated_duration=estimated_duration)
stop_ticker = asyncio.Event()
@@ -420,10 +421,10 @@ async def poll_op_raw(
def _display_text(
node_cls: type[IO.ComfyNode],
- text: Optional[str],
+ text: str | None,
*,
- status: Optional[Union[str, int]] = None,
- price: Optional[float] = None,
+ status: str | int | None = None,
+ price: float | None = None,
) -> None:
display_lines: list[str] = []
if status:
@@ -440,13 +441,13 @@ def _display_text(
def _display_time_progress(
node_cls: type[IO.ComfyNode],
- status: Optional[Union[str, int]],
+ status: str | int | None,
elapsed_seconds: int,
- estimated_total: Optional[int] = None,
+ estimated_total: int | None = None,
*,
- price: Optional[float] = None,
- is_queued: Optional[bool] = None,
- processing_elapsed_seconds: Optional[int] = None,
+ price: float | None = None,
+ is_queued: bool | None = None,
+ processing_elapsed_seconds: int | None = None,
) -> None:
if estimated_total is not None and estimated_total > 0 and is_queued is False:
pe = processing_elapsed_seconds if processing_elapsed_seconds is not None else elapsed_seconds
@@ -488,7 +489,7 @@ def _unpack_tuple(t: tuple) -> tuple[str, Any, str]:
raise ValueError("files tuple must be (filename, file[, content_type])")
-def _merge_params(endpoint_params: dict[str, Any], method: str, data: Optional[dict[str, Any]]) -> dict[str, Any]:
+def _merge_params(endpoint_params: dict[str, Any], method: str, data: dict[str, Any] | None) -> dict[str, Any]:
params = dict(endpoint_params or {})
if method.upper() == "GET" and data:
for k, v in data.items():
@@ -534,9 +535,9 @@ def _generate_operation_id(method: str, path: str, attempt: int) -> str:
def _snapshot_request_body_for_logging(
content_type: str,
method: str,
- data: Optional[dict[str, Any]],
- files: Optional[Union[dict[str, Any], list[tuple[str, Any]]]],
-) -> Optional[Union[dict[str, Any], str]]:
+ data: dict[str, Any] | None,
+ files: dict[str, Any] | list[tuple[str, Any]] | None,
+) -> dict[str, Any] | str | None:
if method.upper() == "GET":
return None
if content_type == "multipart/form-data":
@@ -586,13 +587,13 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
attempt = 0
delay = cfg.retry_delay
operation_succeeded: bool = False
- final_elapsed_seconds: Optional[int] = None
- extracted_price: Optional[float] = None
+ final_elapsed_seconds: int | None = None
+ extracted_price: float | None = None
while True:
attempt += 1
stop_event = asyncio.Event()
- monitor_task: Optional[asyncio.Task] = None
- sess: Optional[aiohttp.ClientSession] = None
+ monitor_task: asyncio.Task | None = None
+ sess: aiohttp.ClientSession | None = None
operation_id = _generate_operation_id(method, cfg.endpoint.path, attempt)
logging.debug("[DEBUG] HTTP %s %s (attempt %d)", method, url, attempt)
@@ -887,7 +888,7 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
)
-def _validate_or_raise(response_model: Type[M], payload: Any) -> M:
+def _validate_or_raise(response_model: type[M], payload: Any) -> M:
try:
return response_model.model_validate(payload)
except Exception as e:
@@ -902,9 +903,9 @@ def _validate_or_raise(response_model: Type[M], payload: Any) -> M:
def _wrap_model_extractor(
- response_model: Type[M],
- extractor: Optional[Callable[[M], Any]],
-) -> Optional[Callable[[dict[str, Any]], Any]]:
+ response_model: type[M],
+ extractor: Callable[[M], Any] | None,
+) -> Callable[[dict[str, Any]], Any] | None:
"""Wrap a typed extractor so it can be used by the dict-based poller.
Validates the dict into `response_model` before invoking `extractor`.
Uses a small per-wrapper cache keyed by `id(dict)` to avoid re-validating
@@ -929,10 +930,10 @@ def _wrap_model_extractor(
return _wrapped
-def _normalize_statuses(values: Optional[Iterable[Union[str, int]]]) -> set[Union[str, int]]:
+def _normalize_statuses(values: Iterable[str | int] | None) -> set[str | int]:
if not values:
return set()
- out: set[Union[str, int]] = set()
+ out: set[str | int] = set()
for v in values:
nv = _normalize_status_value(v)
if nv is not None:
@@ -940,7 +941,7 @@ def _normalize_statuses(values: Optional[Iterable[Union[str, int]]]) -> set[Unio
return out
-def _normalize_status_value(val: Union[str, int, None]) -> Union[str, int, None]:
+def _normalize_status_value(val: str | int | None) -> str | int | None:
if isinstance(val, str):
return val.strip().lower()
return val
diff --git a/comfy_api_nodes/util/conversions.py b/comfy_api_nodes/util/conversions.py
index 971dc57de..c57457580 100644
--- a/comfy_api_nodes/util/conversions.py
+++ b/comfy_api_nodes/util/conversions.py
@@ -4,7 +4,6 @@ import math
import mimetypes
import uuid
from io import BytesIO
-from typing import Optional
import av
import numpy as np
@@ -12,8 +11,7 @@ import torch
from PIL import Image
from comfy.utils import common_upscale
-from comfy_api.latest import Input, InputImpl
-from comfy_api.util import VideoCodec, VideoContainer
+from comfy_api.latest import Input, InputImpl, Types
from ._helpers import mimetype_to_extension
@@ -57,7 +55,7 @@ def image_tensor_pair_to_batch(image1: torch.Tensor, image2: torch.Tensor) -> to
def tensor_to_bytesio(
image: torch.Tensor,
- name: Optional[str] = None,
+ name: str | None = None,
total_pixels: int = 2048 * 2048,
mime_type: str = "image/png",
) -> BytesIO:
@@ -177,8 +175,8 @@ def audio_to_base64_string(audio: Input.Audio, container_format: str = "mp4", co
def video_to_base64_string(
video: Input.Video,
- container_format: VideoContainer = None,
- codec: VideoCodec = None
+ container_format: Types.VideoContainer | None = None,
+ codec: Types.VideoCodec | None = None,
) -> str:
"""
Converts a video input to a base64 string.
@@ -189,12 +187,11 @@ def video_to_base64_string(
codec: Optional codec to use (defaults to video.codec if available)
"""
video_bytes_io = BytesIO()
-
- # Use provided format/codec if specified, otherwise use video's own if available
- format_to_use = container_format if container_format is not None else getattr(video, 'container', VideoContainer.MP4)
- codec_to_use = codec if codec is not None else getattr(video, 'codec', VideoCodec.H264)
-
- video.save_to(video_bytes_io, format=format_to_use, codec=codec_to_use)
+ video.save_to(
+ video_bytes_io,
+ format=container_format or getattr(video, "container", Types.VideoContainer.MP4),
+ codec=codec or getattr(video, "codec", Types.VideoCodec.H264),
+ )
video_bytes_io.seek(0)
return base64.b64encode(video_bytes_io.getvalue()).decode("utf-8")
diff --git a/comfy_api_nodes/util/download_helpers.py b/comfy_api_nodes/util/download_helpers.py
index 14207dc68..3e0d0352d 100644
--- a/comfy_api_nodes/util/download_helpers.py
+++ b/comfy_api_nodes/util/download_helpers.py
@@ -3,15 +3,15 @@ import contextlib
import uuid
from io import BytesIO
from pathlib import Path
-from typing import IO, Optional, Union
+from typing import IO
from urllib.parse import urljoin, urlparse
import aiohttp
import torch
from aiohttp.client_exceptions import ClientError, ContentTypeError
-from comfy_api.input_impl import VideoFromFile
from comfy_api.latest import IO as COMFY_IO
+from comfy_api.latest import InputImpl
from . import request_logger
from ._helpers import (
@@ -29,9 +29,9 @@ _RETRY_STATUS = {408, 429, 500, 502, 503, 504}
async def download_url_to_bytesio(
url: str,
- dest: Optional[Union[BytesIO, IO[bytes], str, Path]],
+ dest: BytesIO | IO[bytes] | str | Path | None,
*,
- timeout: Optional[float] = None,
+ timeout: float | None = None,
max_retries: int = 5,
retry_delay: float = 1.0,
retry_backoff: float = 2.0,
@@ -71,10 +71,10 @@ async def download_url_to_bytesio(
is_path_sink = isinstance(dest, (str, Path))
fhandle = None
- session: Optional[aiohttp.ClientSession] = None
- stop_evt: Optional[asyncio.Event] = None
- monitor_task: Optional[asyncio.Task] = None
- req_task: Optional[asyncio.Task] = None
+ session: aiohttp.ClientSession | None = None
+ stop_evt: asyncio.Event | None = None
+ monitor_task: asyncio.Task | None = None
+ req_task: asyncio.Task | None = None
try:
with contextlib.suppress(Exception):
@@ -234,11 +234,11 @@ async def download_url_to_video_output(
timeout: float = None,
max_retries: int = 5,
cls: type[COMFY_IO.ComfyNode] = None,
-) -> VideoFromFile:
+) -> InputImpl.VideoFromFile:
"""Downloads a video from a URL and returns a `VIDEO` output."""
result = BytesIO()
await download_url_to_bytesio(video_url, result, timeout=timeout, max_retries=max_retries, cls=cls)
- return VideoFromFile(result)
+ return InputImpl.VideoFromFile(result)
async def download_url_as_bytesio(
diff --git a/comfy_api_nodes/util/request_logger.py b/comfy_api_nodes/util/request_logger.py
index ac52e2eab..e0cb4428d 100644
--- a/comfy_api_nodes/util/request_logger.py
+++ b/comfy_api_nodes/util/request_logger.py
@@ -1,5 +1,3 @@
-from __future__ import annotations
-
import datetime
import hashlib
import json
diff --git a/comfy_api_nodes/util/upload_helpers.py b/comfy_api_nodes/util/upload_helpers.py
index 0532bea9a..b8d33f4d1 100644
--- a/comfy_api_nodes/util/upload_helpers.py
+++ b/comfy_api_nodes/util/upload_helpers.py
@@ -4,15 +4,13 @@ import logging
import time
import uuid
from io import BytesIO
-from typing import Optional
from urllib.parse import urlparse
import aiohttp
import torch
from pydantic import BaseModel, Field
-from comfy_api.latest import IO, Input
-from comfy_api.util import VideoCodec, VideoContainer
+from comfy_api.latest import IO, Input, Types
from . import request_logger
from ._helpers import is_processing_interrupted, sleep_with_interrupt
@@ -32,7 +30,7 @@ from .conversions import (
class UploadRequest(BaseModel):
file_name: str = Field(..., description="Filename to upload")
- content_type: Optional[str] = Field(
+ content_type: str | None = Field(
None,
description="Mime type of the file. For example: image/png, image/jpeg, video/mp4, etc.",
)
@@ -56,7 +54,7 @@ async def upload_images_to_comfyapi(
Uploads images to ComfyUI API and returns download URLs.
To upload multiple images, stack them in the batch dimension first.
"""
- # if batch, try to upload each file if max_images is greater than 0
+ # if batched, try to upload each file if max_images is greater than 0
download_urls: list[str] = []
is_batch = len(image.shape) > 3
batch_len = image.shape[0] if is_batch else 1
@@ -100,9 +98,9 @@ async def upload_video_to_comfyapi(
cls: type[IO.ComfyNode],
video: Input.Video,
*,
- container: VideoContainer = VideoContainer.MP4,
- codec: VideoCodec = VideoCodec.H264,
- max_duration: Optional[int] = None,
+ container: Types.VideoContainer = Types.VideoContainer.MP4,
+ codec: Types.VideoCodec = Types.VideoCodec.H264,
+ max_duration: int | None = None,
wait_label: str | None = "Uploading",
) -> str:
"""
@@ -220,7 +218,7 @@ async def upload_file(
return
monitor_task = asyncio.create_task(_monitor())
- sess: Optional[aiohttp.ClientSession] = None
+ sess: aiohttp.ClientSession | None = None
try:
try:
request_logger.log_request_response(
diff --git a/comfy_api_nodes/util/validation_utils.py b/comfy_api_nodes/util/validation_utils.py
index ec7006aed..f01edea96 100644
--- a/comfy_api_nodes/util/validation_utils.py
+++ b/comfy_api_nodes/util/validation_utils.py
@@ -1,9 +1,7 @@
import logging
-from typing import Optional
import torch
-from comfy_api.input.video_types import VideoInput
from comfy_api.latest import Input
@@ -18,10 +16,10 @@ def get_image_dimensions(image: torch.Tensor) -> tuple[int, int]:
def validate_image_dimensions(
image: torch.Tensor,
- min_width: Optional[int] = None,
- max_width: Optional[int] = None,
- min_height: Optional[int] = None,
- max_height: Optional[int] = None,
+ min_width: int | None = None,
+ max_width: int | None = None,
+ min_height: int | None = None,
+ max_height: int | None = None,
):
height, width = get_image_dimensions(image)
@@ -37,8 +35,8 @@ def validate_image_dimensions(
def validate_image_aspect_ratio(
image: torch.Tensor,
- min_ratio: Optional[tuple[float, float]] = None, # e.g. (1, 4)
- max_ratio: Optional[tuple[float, float]] = None, # e.g. (4, 1)
+ min_ratio: tuple[float, float] | None = None, # e.g. (1, 4)
+ max_ratio: tuple[float, float] | None = None, # e.g. (4, 1)
*,
strict: bool = True, # True -> (min, max); False -> [min, max]
) -> float:
@@ -54,8 +52,8 @@ def validate_image_aspect_ratio(
def validate_images_aspect_ratio_closeness(
first_image: torch.Tensor,
second_image: torch.Tensor,
- min_rel: float, # e.g. 0.8
- max_rel: float, # e.g. 1.25
+ min_rel: float, # e.g. 0.8
+ max_rel: float, # e.g. 1.25
*,
strict: bool = False, # True -> (min, max); False -> [min, max]
) -> float:
@@ -84,8 +82,8 @@ def validate_images_aspect_ratio_closeness(
def validate_aspect_ratio_string(
aspect_ratio: str,
- min_ratio: Optional[tuple[float, float]] = None, # e.g. (1, 4)
- max_ratio: Optional[tuple[float, float]] = None, # e.g. (4, 1)
+ min_ratio: tuple[float, float] | None = None, # e.g. (1, 4)
+ max_ratio: tuple[float, float] | None = None, # e.g. (4, 1)
*,
strict: bool = False, # True -> (min, max); False -> [min, max]
) -> float:
@@ -97,10 +95,10 @@ def validate_aspect_ratio_string(
def validate_video_dimensions(
video: Input.Video,
- min_width: Optional[int] = None,
- max_width: Optional[int] = None,
- min_height: Optional[int] = None,
- max_height: Optional[int] = None,
+ min_width: int | None = None,
+ max_width: int | None = None,
+ min_height: int | None = None,
+ max_height: int | None = None,
):
try:
width, height = video.get_dimensions()
@@ -120,8 +118,8 @@ def validate_video_dimensions(
def validate_video_duration(
video: Input.Video,
- min_duration: Optional[float] = None,
- max_duration: Optional[float] = None,
+ min_duration: float | None = None,
+ max_duration: float | None = None,
):
try:
duration = video.get_duration()
@@ -136,6 +134,23 @@ def validate_video_duration(
raise ValueError(f"Video duration must be at most {max_duration}s, got {duration}s")
+def validate_video_frame_count(
+ video: Input.Video,
+ min_frame_count: int | None = None,
+ max_frame_count: int | None = None,
+):
+ try:
+ frame_count = video.get_frame_count()
+ except Exception as e:
+ logging.error("Error getting frame count of video: %s", e)
+ return
+
+ if min_frame_count is not None and min_frame_count > frame_count:
+ raise ValueError(f"Video frame count must be at least {min_frame_count}, got {frame_count}")
+ if max_frame_count is not None and frame_count > max_frame_count:
+ raise ValueError(f"Video frame count must be at most {max_frame_count}, got {frame_count}")
+
+
def get_number_of_images(images):
if isinstance(images, torch.Tensor):
return images.shape[0] if images.ndim >= 4 else 1
@@ -144,8 +159,8 @@ def get_number_of_images(images):
def validate_audio_duration(
audio: Input.Audio,
- min_duration: Optional[float] = None,
- max_duration: Optional[float] = None,
+ min_duration: float | None = None,
+ max_duration: float | None = None,
) -> None:
sr = int(audio["sample_rate"])
dur = int(audio["waveform"].shape[-1]) / sr
@@ -177,7 +192,7 @@ def validate_string(
)
-def validate_container_format_is_mp4(video: VideoInput) -> None:
+def validate_container_format_is_mp4(video: Input.Video) -> None:
"""Validates video container format is MP4."""
container_format = video.get_container_format()
if container_format not in ["mp4", "mov,mp4,m4a,3gp,3g2,mj2"]:
@@ -194,8 +209,8 @@ def _ratio_from_tuple(r: tuple[float, float]) -> float:
def _assert_ratio_bounds(
ar: float,
*,
- min_ratio: Optional[tuple[float, float]] = None,
- max_ratio: Optional[tuple[float, float]] = None,
+ min_ratio: tuple[float, float] | None = None,
+ max_ratio: tuple[float, float] | None = None,
strict: bool = True,
) -> None:
"""Validate a numeric aspect ratio against optional min/max ratio bounds."""
From 35fa091340c60612dfb71cb6822dc23b99a5dac2 Mon Sep 17 00:00:00 2001
From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com>
Date: Thu, 4 Dec 2025 19:52:09 -0800
Subject: [PATCH 28/81] Forgot to put this in README. (#11112)
---
README.md | 1 +
1 file changed, 1 insertion(+)
diff --git a/README.md b/README.md
index 91fb510e1..ed857df9f 100644
--- a/README.md
+++ b/README.md
@@ -81,6 +81,7 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith
- [Hunyuan Video](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_video/)
- [Wan 2.1](https://comfyanonymous.github.io/ComfyUI_examples/wan/)
- [Wan 2.2](https://comfyanonymous.github.io/ComfyUI_examples/wan22/)
+ - [Hunyuan Video 1.5](https://docs.comfy.org/tutorials/video/hunyuan/hunyuan-video-1-5)
- Audio Models
- [Stable Audio](https://comfyanonymous.github.io/ComfyUI_examples/audio/)
- [ACE Step](https://comfyanonymous.github.io/ComfyUI_examples/audio/)
From 0ec05b1481d12b299bc945dbd407b773cfb66483 Mon Sep 17 00:00:00 2001
From: Jedrzej Kosinski
Date: Fri, 5 Dec 2025 11:05:38 -0800
Subject: [PATCH 29/81] Remove line made unnecessary (and wrong) after
transformer_options was added to NextDiT's _forward definition (#11118)
---
comfy/ldm/lumina/model.py | 1 -
1 file changed, 1 deletion(-)
diff --git a/comfy/ldm/lumina/model.py b/comfy/ldm/lumina/model.py
index f1c1a0ec3..6c24fed9b 100644
--- a/comfy/ldm/lumina/model.py
+++ b/comfy/ldm/lumina/model.py
@@ -586,7 +586,6 @@ class NextDiT(nn.Module):
cap_feats = self.cap_embedder(cap_feats) # (N, L, D) # todo check if able to batchify w.o. redundant compute
patches = transformer_options.get("patches", {})
- transformer_options = kwargs.get("transformer_options", {})
x_is_tensor = isinstance(x, torch.Tensor)
img, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, t, num_tokens, transformer_options=transformer_options)
freqs_cis = freqs_cis.to(img.device)
From 43071e3de3780f984a46549e90935a0bf405e9df Mon Sep 17 00:00:00 2001
From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com>
Date: Fri, 5 Dec 2025 11:35:42 -0800
Subject: [PATCH 30/81] Make old scaled fp8 format use the new mixed quant ops
system. (#11000)
---
comfy/model_base.py | 14 +-
comfy/model_detection.py | 33 +---
comfy/model_patcher.py | 20 +--
comfy/ops.py | 145 +++++++-----------
comfy/quant_ops.py | 30 ++--
comfy/sd.py | 68 ++++++--
comfy/sd1_clip.py | 22 +--
comfy/supported_models_base.py | 3 +-
comfy/text_encoders/cosmos.py | 12 +-
comfy/text_encoders/flux.py | 12 +-
comfy/text_encoders/genmo.py | 6 +-
comfy/text_encoders/hidream.py | 10 +-
comfy/text_encoders/hunyuan_image.py | 12 +-
comfy/text_encoders/hunyuan_video.py | 23 ++-
comfy/text_encoders/lumina2.py | 6 +-
comfy/text_encoders/omnigen2.py | 6 +-
comfy/text_encoders/ovis.py | 5 +-
comfy/text_encoders/pixart_t5.py | 6 +-
comfy/text_encoders/qwen_image.py | 6 +-
comfy/text_encoders/sd3_clip.py | 19 +--
comfy/text_encoders/wan.py | 6 +-
comfy/text_encoders/z_image.py | 5 +-
comfy/utils.py | 66 ++++++++
.../comfy_quant/test_mixed_precision.py | 18 ++-
24 files changed, 278 insertions(+), 275 deletions(-)
diff --git a/comfy/model_base.py b/comfy/model_base.py
index 9b76c285e..3cedd4f31 100644
--- a/comfy/model_base.py
+++ b/comfy/model_base.py
@@ -134,7 +134,7 @@ class BaseModel(torch.nn.Module):
if not unet_config.get("disable_unet_model_creation", False):
if model_config.custom_operations is None:
fp8 = model_config.optimizations.get("fp8", False)
- operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, scaled_fp8=model_config.scaled_fp8, model_config=model_config)
+ operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, model_config=model_config)
else:
operations = model_config.custom_operations
self.diffusion_model = unet_model(**unet_config, device=device, operations=operations)
@@ -329,18 +329,6 @@ class BaseModel(torch.nn.Module):
extra_sds.append(self.model_config.process_clip_vision_state_dict_for_saving(clip_vision_state_dict))
unet_state_dict = self.diffusion_model.state_dict()
-
- if self.model_config.scaled_fp8 is not None:
- unet_state_dict["scaled_fp8"] = torch.tensor([], dtype=self.model_config.scaled_fp8)
-
- # Save mixed precision metadata
- if hasattr(self.model_config, 'layer_quant_config') and self.model_config.layer_quant_config:
- metadata = {
- "format_version": "1.0",
- "layers": self.model_config.layer_quant_config
- }
- unet_state_dict["_quantization_metadata"] = metadata
-
unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict)
if self.model_type == ModelType.V_PREDICTION:
diff --git a/comfy/model_detection.py b/comfy/model_detection.py
index 7d0517e61..fd1907627 100644
--- a/comfy/model_detection.py
+++ b/comfy/model_detection.py
@@ -6,20 +6,6 @@ import math
import logging
import torch
-
-def detect_layer_quantization(metadata):
- quant_key = "_quantization_metadata"
- if metadata is not None and quant_key in metadata:
- quant_metadata = metadata.pop(quant_key)
- quant_metadata = json.loads(quant_metadata)
- if isinstance(quant_metadata, dict) and "layers" in quant_metadata:
- logging.info(f"Found quantization metadata (version {quant_metadata.get('format_version', 'unknown')})")
- return quant_metadata["layers"]
- else:
- raise ValueError("Invalid quantization metadata format")
- return None
-
-
def count_blocks(state_dict_keys, prefix_string):
count = 0
while True:
@@ -767,22 +753,11 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal
if model_config is None and use_base_if_no_match:
model_config = comfy.supported_models_base.BASE(unet_config)
- scaled_fp8_key = "{}scaled_fp8".format(unet_key_prefix)
- if scaled_fp8_key in state_dict:
- scaled_fp8_weight = state_dict.pop(scaled_fp8_key)
- model_config.scaled_fp8 = scaled_fp8_weight.dtype
- if model_config.scaled_fp8 == torch.float32:
- model_config.scaled_fp8 = torch.float8_e4m3fn
- if scaled_fp8_weight.nelement() == 2:
- model_config.optimizations["fp8"] = False
- else:
- model_config.optimizations["fp8"] = True
-
# Detect per-layer quantization (mixed precision)
- layer_quant_config = detect_layer_quantization(metadata)
- if layer_quant_config:
- model_config.layer_quant_config = layer_quant_config
- logging.info(f"Detected mixed precision quantization: {len(layer_quant_config)} layers quantized")
+ quant_config = comfy.utils.detect_layer_quantization(state_dict, unet_key_prefix)
+ if quant_config:
+ model_config.quant_config = quant_config
+ logging.info("Detected mixed precision quantization")
return model_config
diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py
index 3dcac3eef..215784874 100644
--- a/comfy/model_patcher.py
+++ b/comfy/model_patcher.py
@@ -126,27 +126,11 @@ class LowVramPatch:
def __init__(self, key, patches, convert_func=None, set_func=None):
self.key = key
self.patches = patches
- self.convert_func = convert_func
+ self.convert_func = convert_func # TODO: remove
self.set_func = set_func
def __call__(self, weight):
- intermediate_dtype = weight.dtype
- if self.convert_func is not None:
- weight = self.convert_func(weight, inplace=False)
-
- if intermediate_dtype not in [torch.float32, torch.float16, torch.bfloat16]: #intermediate_dtype has to be one that is supported in math ops
- intermediate_dtype = torch.float32
- out = comfy.lora.calculate_weight(self.patches[self.key], weight.to(intermediate_dtype), self.key, intermediate_dtype=intermediate_dtype)
- if self.set_func is None:
- return comfy.float.stochastic_rounding(out, weight.dtype, seed=string_to_seed(self.key))
- else:
- return self.set_func(out, seed=string_to_seed(self.key), return_weight=True)
-
- out = comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=intermediate_dtype)
- if self.set_func is not None:
- return self.set_func(out, seed=string_to_seed(self.key), return_weight=True).to(dtype=intermediate_dtype)
- else:
- return out
+ return comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=weight.dtype)
#The above patch logic may cast up the weight to fp32, and do math. Go with fp32 x 3
LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR = 3
diff --git a/comfy/ops.py b/comfy/ops.py
index eae434e68..dc06709a1 100644
--- a/comfy/ops.py
+++ b/comfy/ops.py
@@ -23,6 +23,7 @@ from comfy.cli_args import args, PerformanceFeature
import comfy.float
import comfy.rmsnorm
import contextlib
+import json
def run_every_op():
if torch.compiler.is_compiling():
@@ -422,22 +423,12 @@ def fp8_linear(self, input):
if input.ndim == 3 or input.ndim == 2:
w, bias, offload_stream = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype, offloadable=True)
+ scale_weight = torch.ones((), device=input.device, dtype=torch.float32)
- scale_weight = self.scale_weight
- scale_input = self.scale_input
- if scale_weight is None:
- scale_weight = torch.ones((), device=input.device, dtype=torch.float32)
- else:
- scale_weight = scale_weight.to(input.device)
-
- if scale_input is None:
- scale_input = torch.ones((), device=input.device, dtype=torch.float32)
- input = torch.clamp(input, min=-448, max=448, out=input)
- layout_params_weight = {'scale': scale_input, 'orig_dtype': input_dtype}
- quantized_input = QuantizedTensor(input.to(dtype).contiguous(), "TensorCoreFP8Layout", layout_params_weight)
- else:
- scale_input = scale_input.to(input.device)
- quantized_input = QuantizedTensor.from_float(input, "TensorCoreFP8Layout", scale=scale_input, dtype=dtype)
+ scale_input = torch.ones((), device=input.device, dtype=torch.float32)
+ input = torch.clamp(input, min=-448, max=448, out=input)
+ layout_params_weight = {'scale': scale_input, 'orig_dtype': input_dtype}
+ quantized_input = QuantizedTensor(input.to(dtype).contiguous(), "TensorCoreFP8Layout", layout_params_weight)
# Wrap weight in QuantizedTensor - this enables unified dispatch
# Call F.linear - __torch_dispatch__ routes to fp8_linear handler in quant_ops.py!
@@ -458,7 +449,7 @@ class fp8_ops(manual_cast):
return None
def forward_comfy_cast_weights(self, input):
- if not self.training:
+ if len(self.weight_function) == 0 and len(self.bias_function) == 0:
try:
out = fp8_linear(self, input)
if out is not None:
@@ -471,59 +462,6 @@ class fp8_ops(manual_cast):
uncast_bias_weight(self, weight, bias, offload_stream)
return x
-def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None):
- logging.info("Using scaled fp8: fp8 matrix mult: {}, scale input: {}".format(fp8_matrix_mult, scale_input))
- class scaled_fp8_op(manual_cast):
- class Linear(manual_cast.Linear):
- def __init__(self, *args, **kwargs):
- if override_dtype is not None:
- kwargs['dtype'] = override_dtype
- super().__init__(*args, **kwargs)
-
- def reset_parameters(self):
- if not hasattr(self, 'scale_weight'):
- self.scale_weight = torch.nn.parameter.Parameter(data=torch.ones((), device=self.weight.device, dtype=torch.float32), requires_grad=False)
-
- if not scale_input:
- self.scale_input = None
-
- if not hasattr(self, 'scale_input'):
- self.scale_input = torch.nn.parameter.Parameter(data=torch.ones((), device=self.weight.device, dtype=torch.float32), requires_grad=False)
- return None
-
- def forward_comfy_cast_weights(self, input):
- if fp8_matrix_mult:
- out = fp8_linear(self, input)
- if out is not None:
- return out
-
- weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
-
- if weight.numel() < input.numel(): #TODO: optimize
- x = torch.nn.functional.linear(input, weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype), bias)
- else:
- x = torch.nn.functional.linear(input * self.scale_weight.to(device=weight.device, dtype=weight.dtype), weight, bias)
- uncast_bias_weight(self, weight, bias, offload_stream)
- return x
-
- def convert_weight(self, weight, inplace=False, **kwargs):
- if inplace:
- weight *= self.scale_weight.to(device=weight.device, dtype=weight.dtype)
- return weight
- else:
- return weight.to(dtype=torch.float32) * self.scale_weight.to(device=weight.device, dtype=torch.float32)
-
- def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False, **kwargs):
- weight = comfy.float.stochastic_rounding(weight / self.scale_weight.to(device=weight.device, dtype=weight.dtype), self.weight.dtype, seed=seed)
- if return_weight:
- return weight
- if inplace_update:
- self.weight.data.copy_(weight)
- else:
- self.weight = torch.nn.Parameter(weight, requires_grad=False)
-
- return scaled_fp8_op
-
CUBLAS_IS_AVAILABLE = False
try:
from cublas_ops import CublasLinear
@@ -550,9 +488,9 @@ if CUBLAS_IS_AVAILABLE:
from .quant_ops import QuantizedTensor, QUANT_ALGOS
-def mixed_precision_ops(layer_quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False):
+def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False):
class MixedPrecisionOps(manual_cast):
- _layer_quant_config = layer_quant_config
+ _quant_config = quant_config
_compute_dtype = compute_dtype
_full_precision_mm = full_precision_mm
@@ -595,27 +533,36 @@ def mixed_precision_ops(layer_quant_config={}, compute_dtype=torch.bfloat16, ful
manually_loaded_keys = [weight_key]
- if layer_name not in MixedPrecisionOps._layer_quant_config:
+ layer_conf = state_dict.pop(f"{prefix}comfy_quant", None)
+ if layer_conf is not None:
+ layer_conf = json.loads(layer_conf.numpy().tobytes())
+
+ if layer_conf is None:
self.weight = torch.nn.Parameter(weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), requires_grad=False)
else:
- quant_format = MixedPrecisionOps._layer_quant_config[layer_name].get("format", None)
- if quant_format is None:
+ self.quant_format = layer_conf.get("format", None)
+ if not self._full_precision_mm:
+ self._full_precision_mm = layer_conf.get("full_precision_matrix_mult", False)
+
+ if self.quant_format is None:
raise ValueError(f"Unknown quantization format for layer {layer_name}")
- qconfig = QUANT_ALGOS[quant_format]
+ qconfig = QUANT_ALGOS[self.quant_format]
self.layout_type = qconfig["comfy_tensor_layout"]
weight_scale_key = f"{prefix}weight_scale"
+ scale = state_dict.pop(weight_scale_key, None)
layout_params = {
- 'scale': state_dict.pop(weight_scale_key, None),
+ 'scale': scale,
'orig_dtype': MixedPrecisionOps._compute_dtype,
'block_size': qconfig.get("group_size", None),
}
- if layout_params['scale'] is not None:
+
+ if scale is not None:
manually_loaded_keys.append(weight_scale_key)
self.weight = torch.nn.Parameter(
- QuantizedTensor(weight.to(device=device), self.layout_type, layout_params),
+ QuantizedTensor(weight.to(device=device, dtype=qconfig.get("storage_t", None)), self.layout_type, layout_params),
requires_grad=False
)
@@ -624,7 +571,7 @@ def mixed_precision_ops(layer_quant_config={}, compute_dtype=torch.bfloat16, ful
_v = state_dict.pop(param_key, None)
if _v is None:
continue
- setattr(self, param_name, torch.nn.Parameter(_v.to(device=device), requires_grad=False))
+ self.register_parameter(param_name, torch.nn.Parameter(_v.to(device=device), requires_grad=False))
manually_loaded_keys.append(param_key)
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
@@ -633,6 +580,16 @@ def mixed_precision_ops(layer_quant_config={}, compute_dtype=torch.bfloat16, ful
if key in missing_keys:
missing_keys.remove(key)
+ def state_dict(self, *args, destination=None, prefix="", **kwargs):
+ sd = super().state_dict(*args, destination=destination, prefix=prefix, **kwargs)
+ if isinstance(self.weight, QuantizedTensor):
+ sd["{}weight_scale".format(prefix)] = self.weight._layout_params['scale']
+ quant_conf = {"format": self.quant_format}
+ if self._full_precision_mm:
+ quant_conf["full_precision_matrix_mult"] = True
+ sd["{}comfy_quant".format(prefix)] = torch.frombuffer(json.dumps(quant_conf).encode('utf-8'), dtype=torch.uint8)
+ return sd
+
def _forward(self, input, weight, bias):
return torch.nn.functional.linear(input, weight, bias)
@@ -648,9 +605,8 @@ def mixed_precision_ops(layer_quant_config={}, compute_dtype=torch.bfloat16, ful
if self._full_precision_mm or self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
return self.forward_comfy_cast_weights(input, *args, **kwargs)
if (getattr(self, 'layout_type', None) is not None and
- getattr(self, 'input_scale', None) is not None and
not isinstance(input, QuantizedTensor)):
- input = QuantizedTensor.from_float(input, self.layout_type, scale=self.input_scale, dtype=self.weight.dtype)
+ input = QuantizedTensor.from_float(input, self.layout_type, scale=getattr(self, 'input_scale', None), dtype=self.weight.dtype)
return self._forward(input, self.weight, self.bias)
def convert_weight(self, weight, inplace=False, **kwargs):
@@ -661,7 +617,7 @@ def mixed_precision_ops(layer_quant_config={}, compute_dtype=torch.bfloat16, ful
def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False, **kwargs):
if getattr(self, 'layout_type', None) is not None:
- weight = QuantizedTensor.from_float(weight, self.layout_type, scale=None, dtype=self.weight.dtype, stochastic_rounding=seed, inplace_ops=True)
+ weight = QuantizedTensor.from_float(weight, self.layout_type, scale="recalculate", dtype=self.weight.dtype, stochastic_rounding=seed, inplace_ops=True)
else:
weight = weight.to(self.weight.dtype)
if return_weight:
@@ -670,17 +626,28 @@ def mixed_precision_ops(layer_quant_config={}, compute_dtype=torch.bfloat16, ful
assert inplace_update is False # TODO: eventually remove the inplace_update stuff
self.weight = torch.nn.Parameter(weight, requires_grad=False)
+ def _apply(self, fn, recurse=True): # This is to get torch.compile + moving weights to another device working
+ if recurse:
+ for module in self.children():
+ module._apply(fn)
+
+ for key, param in self._parameters.items():
+ if param is None:
+ continue
+ self.register_parameter(key, torch.nn.Parameter(fn(param), requires_grad=False))
+ for key, buf in self._buffers.items():
+ if buf is not None:
+ self._buffers[key] = fn(buf)
+ return self
+
return MixedPrecisionOps
-def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None, model_config=None):
+def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, model_config=None):
fp8_compute = comfy.model_management.supports_fp8_compute(load_device) # TODO: if we support more ops this needs to be more granular
- if model_config and hasattr(model_config, 'layer_quant_config') and model_config.layer_quant_config:
- logging.info(f"Using mixed precision operations: {len(model_config.layer_quant_config)} quantized layers")
- return mixed_precision_ops(model_config.layer_quant_config, compute_dtype, full_precision_mm=not fp8_compute)
-
- if scaled_fp8 is not None:
- return scaled_fp8_ops(fp8_matrix_mult=fp8_compute and fp8_optimizations, scale_input=fp8_optimizations, override_dtype=scaled_fp8)
+ if model_config and hasattr(model_config, 'quant_config') and model_config.quant_config:
+ logging.info("Using mixed precision operations")
+ return mixed_precision_ops(model_config.quant_config, compute_dtype, full_precision_mm=not fp8_compute)
if (
fp8_compute and
diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py
index bb1fb860c..571d3f760 100644
--- a/comfy/quant_ops.py
+++ b/comfy/quant_ops.py
@@ -238,6 +238,9 @@ class QuantizedTensor(torch.Tensor):
def is_contiguous(self, *arg, **kwargs):
return self._qdata.is_contiguous(*arg, **kwargs)
+ def storage(self):
+ return self._qdata.storage()
+
# ==============================================================================
# Generic Utilities (Layout-Agnostic Operations)
# ==============================================================================
@@ -249,12 +252,6 @@ def _create_transformed_qtensor(qt, transform_fn):
def _handle_device_transfer(qt, target_device, target_dtype=None, target_layout=None, op_name="to"):
- if target_dtype is not None and target_dtype != qt.dtype:
- logging.warning(
- f"QuantizedTensor: dtype conversion requested to {target_dtype}, "
- f"but not supported for quantized tensors. Ignoring dtype."
- )
-
if target_layout is not None and target_layout != torch.strided:
logging.warning(
f"QuantizedTensor: layout change requested to {target_layout}, "
@@ -274,6 +271,8 @@ def _handle_device_transfer(qt, target_device, target_dtype=None, target_layout=
logging.debug(f"QuantizedTensor.{op_name}: Moving from {current_device} to {target_device}")
new_q_data = qt._qdata.to(device=target_device)
new_params = _move_layout_params_to_device(qt._layout_params, target_device)
+ if target_dtype is not None:
+ new_params["orig_dtype"] = target_dtype
new_qt = QuantizedTensor(new_q_data, qt._layout_type, new_params)
logging.debug(f"QuantizedTensor.{op_name}: Created new tensor on {target_device}")
return new_qt
@@ -339,7 +338,9 @@ def generic_copy_(func, args, kwargs):
# Copy from another quantized tensor
qt_dest._qdata.copy_(src._qdata, non_blocking=non_blocking)
qt_dest._layout_type = src._layout_type
+ orig_dtype = qt_dest._layout_params["orig_dtype"]
_copy_layout_params_inplace(src._layout_params, qt_dest._layout_params, non_blocking=non_blocking)
+ qt_dest._layout_params["orig_dtype"] = orig_dtype
else:
# Copy from regular tensor - just copy raw data
qt_dest._qdata.copy_(src)
@@ -397,17 +398,20 @@ class TensorCoreFP8Layout(QuantizedLayout):
def quantize(cls, tensor, scale=None, dtype=torch.float8_e4m3fn, stochastic_rounding=0, inplace_ops=False):
orig_dtype = tensor.dtype
- if scale is None:
+ if isinstance(scale, str) and scale == "recalculate":
scale = torch.amax(tensor.abs()) / torch.finfo(dtype).max
- if not isinstance(scale, torch.Tensor):
- scale = torch.tensor(scale)
- scale = scale.to(device=tensor.device, dtype=torch.float32)
+ if scale is not None:
+ if not isinstance(scale, torch.Tensor):
+ scale = torch.tensor(scale)
+ scale = scale.to(device=tensor.device, dtype=torch.float32)
- if inplace_ops:
- tensor *= (1.0 / scale).to(tensor.dtype)
+ if inplace_ops:
+ tensor *= (1.0 / scale).to(tensor.dtype)
+ else:
+ tensor = tensor * (1.0 / scale).to(tensor.dtype)
else:
- tensor = tensor * (1.0 / scale).to(tensor.dtype)
+ scale = torch.ones((), device=tensor.device, dtype=torch.float32)
if stochastic_rounding > 0:
tensor = comfy.float.stochastic_rounding(tensor, dtype=dtype, seed=stochastic_rounding)
diff --git a/comfy/sd.py b/comfy/sd.py
index 03bdb33d5..092715d79 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -968,10 +968,8 @@ def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DI
clip_data = []
for p in ckpt_paths:
sd, metadata = comfy.utils.load_torch_file(p, safe_load=True, return_metadata=True)
- if metadata is not None:
- quant_metadata = metadata.get("_quantization_metadata", None)
- if quant_metadata is not None:
- sd["_quantization_metadata"] = quant_metadata
+ if model_options.get("custom_operations", None) is None:
+ sd, metadata = comfy.utils.convert_old_quants(sd, model_prefix="", metadata=metadata)
clip_data.append(sd)
return load_text_encoder_state_dicts(clip_data, embedding_directory=embedding_directory, clip_type=clip_type, model_options=model_options)
@@ -1088,7 +1086,7 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=False, clip_g=True, t5=False)
clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
elif clip_type == CLIPType.HIDREAM:
- clip_target.clip = comfy.text_encoders.hidream.hidream_clip(clip_l=False, clip_g=True, t5=False, llama=False, dtype_t5=None, dtype_llama=None, t5xxl_scaled_fp8=None, llama_scaled_fp8=None)
+ clip_target.clip = comfy.text_encoders.hidream.hidream_clip(clip_l=False, clip_g=True, t5=False, llama=False, dtype_t5=None, dtype_llama=None)
clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer
else:
clip_target.clip = sdxl_clip.SDXLRefinerClipModel
@@ -1112,7 +1110,7 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
elif clip_type == CLIPType.HIDREAM:
clip_target.clip = comfy.text_encoders.hidream.hidream_clip(**t5xxl_detect(clip_data),
- clip_l=False, clip_g=False, t5=True, llama=False, dtype_llama=None, llama_scaled_fp8=None)
+ clip_l=False, clip_g=False, t5=True, llama=False, dtype_llama=None)
clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer
else: #CLIPType.MOCHI
clip_target.clip = comfy.text_encoders.genmo.mochi_te(**t5xxl_detect(clip_data))
@@ -1141,7 +1139,7 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
elif te_model == TEModel.LLAMA3_8:
clip_target.clip = comfy.text_encoders.hidream.hidream_clip(**llama_detect(clip_data),
- clip_l=False, clip_g=False, t5=False, llama=True, dtype_t5=None, t5xxl_scaled_fp8=None)
+ clip_l=False, clip_g=False, t5=False, llama=True, dtype_t5=None)
clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer
elif te_model == TEModel.QWEN25_3B:
clip_target.clip = comfy.text_encoders.omnigen2.te(**llama_detect(clip_data))
@@ -1169,7 +1167,7 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=True, clip_g=False, t5=False)
clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
elif clip_type == CLIPType.HIDREAM:
- clip_target.clip = comfy.text_encoders.hidream.hidream_clip(clip_l=True, clip_g=False, t5=False, llama=False, dtype_t5=None, dtype_llama=None, t5xxl_scaled_fp8=None, llama_scaled_fp8=None)
+ clip_target.clip = comfy.text_encoders.hidream.hidream_clip(clip_l=True, clip_g=False, t5=False, llama=False, dtype_t5=None, dtype_llama=None)
clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer
else:
clip_target.clip = sd1_clip.SD1ClipModel
@@ -1224,8 +1222,6 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
parameters = 0
for c in clip_data:
- if "_quantization_metadata" in c:
- c.pop("_quantization_metadata")
parameters += comfy.utils.calculate_parameters(c)
tokenizer_data, model_options = comfy.text_encoders.long_clipl.model_options_long_clip(c, tokenizer_data, model_options)
@@ -1295,6 +1291,10 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
weight_dtype = comfy.utils.weight_dtype(sd, diffusion_model_prefix)
load_device = model_management.get_torch_device()
+ custom_operations = model_options.get("custom_operations", None)
+ if custom_operations is None:
+ sd, metadata = comfy.utils.convert_old_quants(sd, diffusion_model_prefix, metadata=metadata)
+
model_config = model_detection.model_config_from_unet(sd, diffusion_model_prefix, metadata=metadata)
if model_config is None:
logging.warning("Warning, This is not a checkpoint file, trying to load it as a diffusion model only.")
@@ -1303,18 +1303,22 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
return None
return (diffusion_model, None, VAE(sd={}), None) # The VAE object is there to throw an exception if it's actually used'
-
unet_weight_dtype = list(model_config.supported_inference_dtypes)
- if model_config.scaled_fp8 is not None:
+ if model_config.quant_config is not None:
weight_dtype = None
- model_config.custom_operations = model_options.get("custom_operations", None)
+ if custom_operations is not None:
+ model_config.custom_operations = custom_operations
+
unet_dtype = model_options.get("dtype", model_options.get("weight_dtype", None))
if unet_dtype is None:
unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=unet_weight_dtype, weight_dtype=weight_dtype)
- manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
+ if model_config.quant_config is not None:
+ manual_cast_dtype = model_management.unet_manual_cast(None, load_device, model_config.supported_inference_dtypes)
+ else:
+ manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
if model_config.clip_vision_prefix is not None:
@@ -1332,6 +1336,27 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
vae = VAE(sd=vae_sd, metadata=metadata)
if output_clip:
+ if te_model_options.get("custom_operations", None) is None:
+ scaled_fp8_list = []
+ for k in list(sd.keys()): # Convert scaled fp8 to mixed ops
+ if k.endswith(".scaled_fp8"):
+ scaled_fp8_list.append(k[:-len("scaled_fp8")])
+
+ if len(scaled_fp8_list) > 0:
+ out_sd = {}
+ for k in sd:
+ skip = False
+ for pref in scaled_fp8_list:
+ skip = skip or k.startswith(pref)
+ if not skip:
+ out_sd[k] = sd[k]
+
+ for pref in scaled_fp8_list:
+ quant_sd, qmetadata = comfy.utils.convert_old_quants(sd, pref, metadata={})
+ for k in quant_sd:
+ out_sd[k] = quant_sd[k]
+ sd = out_sd
+
clip_target = model_config.clip_target(state_dict=sd)
if clip_target is not None:
clip_sd = model_config.process_clip_state_dict(sd)
@@ -1394,6 +1419,9 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None):
if len(temp_sd) > 0:
sd = temp_sd
+ custom_operations = model_options.get("custom_operations", None)
+ if custom_operations is None:
+ sd, metadata = comfy.utils.convert_old_quants(sd, "", metadata=metadata)
parameters = comfy.utils.calculate_parameters(sd)
weight_dtype = comfy.utils.weight_dtype(sd)
@@ -1424,7 +1452,7 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None):
offload_device = model_management.unet_offload_device()
unet_weight_dtype = list(model_config.supported_inference_dtypes)
- if model_config.scaled_fp8 is not None:
+ if model_config.quant_config is not None:
weight_dtype = None
if dtype is None:
@@ -1432,12 +1460,15 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None):
else:
unet_dtype = dtype
- if model_config.layer_quant_config is not None:
+ if model_config.quant_config is not None:
manual_cast_dtype = model_management.unet_manual_cast(None, load_device, model_config.supported_inference_dtypes)
else:
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
- model_config.custom_operations = model_options.get("custom_operations", model_config.custom_operations)
+
+ if custom_operations is not None:
+ model_config.custom_operations = custom_operations
+
if model_options.get("fp8_optimizations", False):
model_config.optimizations["fp8"] = True
@@ -1476,6 +1507,9 @@ def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, m
if vae is not None:
vae_sd = vae.get_sd()
+ if metadata is None:
+ metadata = {}
+
model_management.load_models_gpu(load_models, force_patch_weights=True)
clip_vision_sd = clip_vision.get_sd() if clip_vision is not None else None
sd = model.model.state_dict_for_saving(clip_sd, vae_sd, clip_vision_sd)
diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py
index 503a51843..962948dae 100644
--- a/comfy/sd1_clip.py
+++ b/comfy/sd1_clip.py
@@ -107,29 +107,17 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
config[k] = v
operations = model_options.get("custom_operations", None)
- scaled_fp8 = None
- quantization_metadata = model_options.get("quantization_metadata", None)
+ quant_config = model_options.get("quantization_metadata", None)
if operations is None:
- layer_quant_config = None
- if quantization_metadata is not None:
- layer_quant_config = json.loads(quantization_metadata).get("layers", None)
-
- if layer_quant_config is not None:
- operations = comfy.ops.mixed_precision_ops(layer_quant_config, dtype, full_precision_mm=True)
- logging.info(f"Using MixedPrecisionOps for text encoder: {len(layer_quant_config)} quantized layers")
+ if quant_config is not None:
+ operations = comfy.ops.mixed_precision_ops(quant_config, dtype, full_precision_mm=True)
+ logging.info("Using MixedPrecisionOps for text encoder")
else:
- # Fallback to scaled_fp8_ops for backward compatibility
- scaled_fp8 = model_options.get("scaled_fp8", None)
- if scaled_fp8 is not None:
- operations = comfy.ops.scaled_fp8_ops(fp8_matrix_mult=False, override_dtype=scaled_fp8)
- else:
- operations = comfy.ops.manual_cast
+ operations = comfy.ops.manual_cast
self.operations = operations
self.transformer = model_class(config, dtype, device, self.operations)
- if scaled_fp8 is not None:
- self.transformer.scaled_fp8 = torch.nn.Parameter(torch.tensor([], dtype=scaled_fp8))
self.num_layers = self.transformer.num_layers
diff --git a/comfy/supported_models_base.py b/comfy/supported_models_base.py
index e4bd74514..9fd84d329 100644
--- a/comfy/supported_models_base.py
+++ b/comfy/supported_models_base.py
@@ -49,8 +49,7 @@ class BASE:
manual_cast_dtype = None
custom_operations = None
- scaled_fp8 = None
- layer_quant_config = None # Per-layer quantization configuration for mixed precision
+ quant_config = None # quantization configuration for mixed precision
optimizations = {"fp8": False}
@classmethod
diff --git a/comfy/text_encoders/cosmos.py b/comfy/text_encoders/cosmos.py
index a1adb5242..448381fa9 100644
--- a/comfy/text_encoders/cosmos.py
+++ b/comfy/text_encoders/cosmos.py
@@ -7,10 +7,10 @@ from transformers import T5TokenizerFast
class T5XXLModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options={}):
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_old_config_xxl.json")
- t5xxl_scaled_fp8 = model_options.get("t5xxl_scaled_fp8", None)
- if t5xxl_scaled_fp8 is not None:
+ t5xxl_quantization_metadata = model_options.get("t5xxl_quantization_metadata", None)
+ if t5xxl_quantization_metadata is not None:
model_options = model_options.copy()
- model_options["scaled_fp8"] = t5xxl_scaled_fp8
+ model_options["quantization_metadata"] = t5xxl_quantization_metadata
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, zero_out_masked=attention_mask, model_options=model_options)
@@ -30,12 +30,12 @@ class CosmosT5Tokenizer(sd1_clip.SD1Tokenizer):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="t5xxl", tokenizer=T5XXLTokenizer)
-def te(dtype_t5=None, t5xxl_scaled_fp8=None):
+def te(dtype_t5=None, t5_quantization_metadata=None):
class CosmosTEModel_(CosmosT5XXL):
def __init__(self, device="cpu", dtype=None, model_options={}):
- if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options:
+ if t5_quantization_metadata is not None:
model_options = model_options.copy()
- model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8
+ model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata
if dtype is None:
dtype = dtype_t5
super().__init__(device=device, dtype=dtype, model_options=model_options)
diff --git a/comfy/text_encoders/flux.py b/comfy/text_encoders/flux.py
index 99f4812bb..21d93d757 100644
--- a/comfy/text_encoders/flux.py
+++ b/comfy/text_encoders/flux.py
@@ -63,12 +63,12 @@ class FluxClipModel(torch.nn.Module):
else:
return self.t5xxl.load_sd(sd)
-def flux_clip(dtype_t5=None, t5xxl_scaled_fp8=None):
+def flux_clip(dtype_t5=None, t5_quantization_metadata=None):
class FluxClipModel_(FluxClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
- if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options:
+ if t5_quantization_metadata is not None:
model_options = model_options.copy()
- model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8
+ model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata
super().__init__(dtype_t5=dtype_t5, device=device, dtype=dtype, model_options=model_options)
return FluxClipModel_
@@ -159,15 +159,13 @@ class Flux2TEModel(sd1_clip.SD1ClipModel):
out = out.reshape(out.shape[0], out.shape[1], -1)
return out, pooled, extra
-def flux2_te(dtype_llama=None, llama_scaled_fp8=None, llama_quantization_metadata=None, pruned=False):
+def flux2_te(dtype_llama=None, llama_quantization_metadata=None, pruned=False):
class Flux2TEModel_(Flux2TEModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
- if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options:
- model_options = model_options.copy()
- model_options["scaled_fp8"] = llama_scaled_fp8
if dtype_llama is not None:
dtype = dtype_llama
if llama_quantization_metadata is not None:
+ model_options = model_options.copy()
model_options["quantization_metadata"] = llama_quantization_metadata
if pruned:
model_options = model_options.copy()
diff --git a/comfy/text_encoders/genmo.py b/comfy/text_encoders/genmo.py
index 9dcf190a2..5daea8135 100644
--- a/comfy/text_encoders/genmo.py
+++ b/comfy/text_encoders/genmo.py
@@ -26,12 +26,12 @@ class MochiT5Tokenizer(sd1_clip.SD1Tokenizer):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="t5xxl", tokenizer=T5XXLTokenizer)
-def mochi_te(dtype_t5=None, t5xxl_scaled_fp8=None):
+def mochi_te(dtype_t5=None, t5_quantization_metadata=None):
class MochiTEModel_(MochiT5XXL):
def __init__(self, device="cpu", dtype=None, model_options={}):
- if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options:
+ if t5_quantization_metadata is not None:
model_options = model_options.copy()
- model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8
+ model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata
if dtype is None:
dtype = dtype_t5
super().__init__(device=device, dtype=dtype, model_options=model_options)
diff --git a/comfy/text_encoders/hidream.py b/comfy/text_encoders/hidream.py
index dbcf52784..600b34480 100644
--- a/comfy/text_encoders/hidream.py
+++ b/comfy/text_encoders/hidream.py
@@ -142,14 +142,14 @@ class HiDreamTEModel(torch.nn.Module):
return self.llama.load_sd(sd)
-def hidream_clip(clip_l=True, clip_g=True, t5=True, llama=True, dtype_t5=None, dtype_llama=None, t5xxl_scaled_fp8=None, llama_scaled_fp8=None):
+def hidream_clip(clip_l=True, clip_g=True, t5=True, llama=True, dtype_t5=None, dtype_llama=None, t5_quantization_metadata=None, llama_quantization_metadata=None):
class HiDreamTEModel_(HiDreamTEModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
- if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options:
+ if t5_quantization_metadata is not None:
model_options = model_options.copy()
- model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8
- if llama_scaled_fp8 is not None and "llama_scaled_fp8" not in model_options:
+ model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata
+ if llama_quantization_metadata is not None:
model_options = model_options.copy()
- model_options["llama_scaled_fp8"] = llama_scaled_fp8
+ model_options["llama_quantization_metadata"] = llama_quantization_metadata
super().__init__(clip_l=clip_l, clip_g=clip_g, t5=t5, llama=llama, dtype_t5=dtype_t5, dtype_llama=dtype_llama, device=device, dtype=dtype, model_options=model_options)
return HiDreamTEModel_
diff --git a/comfy/text_encoders/hunyuan_image.py b/comfy/text_encoders/hunyuan_image.py
index ff04726e1..cd198036c 100644
--- a/comfy/text_encoders/hunyuan_image.py
+++ b/comfy/text_encoders/hunyuan_image.py
@@ -40,10 +40,10 @@ class HunyuanImageTokenizer(QwenImageTokenizer):
class Qwen25_7BVLIModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="hidden", layer_idx=-3, dtype=None, attention_mask=True, model_options={}):
- llama_scaled_fp8 = model_options.get("qwen_scaled_fp8", None)
- if llama_scaled_fp8 is not None:
+ llama_quantization_metadata = model_options.get("llama_quantization_metadata", None)
+ if llama_quantization_metadata is not None:
model_options = model_options.copy()
- model_options["scaled_fp8"] = llama_scaled_fp8
+ model_options["quantization_metadata"] = llama_quantization_metadata
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Qwen25_7BVLI, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
@@ -91,12 +91,12 @@ class HunyuanImageTEModel(QwenImageTEModel):
else:
return super().load_sd(sd)
-def te(byt5=True, dtype_llama=None, llama_scaled_fp8=None):
+def te(byt5=True, dtype_llama=None, llama_quantization_metadata=None):
class QwenImageTEModel_(HunyuanImageTEModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
- if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options:
+ if llama_quantization_metadata is not None:
model_options = model_options.copy()
- model_options["qwen_scaled_fp8"] = llama_scaled_fp8
+ model_options["llama_quantization_metadata"] = llama_quantization_metadata
if dtype_llama is not None:
dtype = dtype_llama
super().__init__(byt5=byt5, device=device, dtype=dtype, model_options=model_options)
diff --git a/comfy/text_encoders/hunyuan_video.py b/comfy/text_encoders/hunyuan_video.py
index 0110517bb..a9a6c525e 100644
--- a/comfy/text_encoders/hunyuan_video.py
+++ b/comfy/text_encoders/hunyuan_video.py
@@ -6,7 +6,7 @@ from transformers import LlamaTokenizerFast
import torch
import os
import numbers
-
+import comfy.utils
def llama_detect(state_dict, prefix=""):
out = {}
@@ -14,12 +14,9 @@ def llama_detect(state_dict, prefix=""):
if t5_key in state_dict:
out["dtype_llama"] = state_dict[t5_key].dtype
- scaled_fp8_key = "{}scaled_fp8".format(prefix)
- if scaled_fp8_key in state_dict:
- out["llama_scaled_fp8"] = state_dict[scaled_fp8_key].dtype
-
- if "_quantization_metadata" in state_dict:
- out["llama_quantization_metadata"] = state_dict["_quantization_metadata"]
+ quant = comfy.utils.detect_layer_quantization(state_dict, prefix)
+ if quant is not None:
+ out["llama_quantization_metadata"] = quant
return out
@@ -31,10 +28,10 @@ class LLAMA3Tokenizer(sd1_clip.SDTokenizer):
class LLAMAModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="hidden", layer_idx=-3, dtype=None, attention_mask=True, model_options={}, special_tokens={"start": 128000, "pad": 128258}):
- llama_scaled_fp8 = model_options.get("llama_scaled_fp8", None)
- if llama_scaled_fp8 is not None:
+ llama_quantization_metadata = model_options.get("llama_quantization_metadata", None)
+ if llama_quantization_metadata is not None:
model_options = model_options.copy()
- model_options["scaled_fp8"] = llama_scaled_fp8
+ model_options["quantization_metadata"] = llama_quantization_metadata
textmodel_json_config = {}
vocab_size = model_options.get("vocab_size", None)
@@ -161,11 +158,11 @@ class HunyuanVideoClipModel(torch.nn.Module):
return self.llama.load_sd(sd)
-def hunyuan_video_clip(dtype_llama=None, llama_scaled_fp8=None):
+def hunyuan_video_clip(dtype_llama=None, llama_quantization_metadata=None):
class HunyuanVideoClipModel_(HunyuanVideoClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
- if llama_scaled_fp8 is not None and "llama_scaled_fp8" not in model_options:
+ if llama_quantization_metadata is not None:
model_options = model_options.copy()
- model_options["llama_scaled_fp8"] = llama_scaled_fp8
+ model_options["llama_quantization_metadata"] = llama_quantization_metadata
super().__init__(dtype_llama=dtype_llama, device=device, dtype=dtype, model_options=model_options)
return HunyuanVideoClipModel_
diff --git a/comfy/text_encoders/lumina2.py b/comfy/text_encoders/lumina2.py
index fd986e2c1..7a6cfdab2 100644
--- a/comfy/text_encoders/lumina2.py
+++ b/comfy/text_encoders/lumina2.py
@@ -40,7 +40,7 @@ class LuminaModel(sd1_clip.SD1ClipModel):
super().__init__(device=device, dtype=dtype, name=name, clip_model=clip_model, model_options=model_options)
-def te(dtype_llama=None, llama_scaled_fp8=None, model_type="gemma2_2b"):
+def te(dtype_llama=None, llama_quantization_metadata=None, model_type="gemma2_2b"):
if model_type == "gemma2_2b":
model = Gemma2_2BModel
elif model_type == "gemma3_4b":
@@ -48,9 +48,9 @@ def te(dtype_llama=None, llama_scaled_fp8=None, model_type="gemma2_2b"):
class LuminaTEModel_(LuminaModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
- if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options:
+ if llama_quantization_metadata is not None:
model_options = model_options.copy()
- model_options["scaled_fp8"] = llama_scaled_fp8
+ model_options["quantization_metadata"] = llama_quantization_metadata
if dtype_llama is not None:
dtype = dtype_llama
super().__init__(device=device, dtype=dtype, name=model_type, model_options=model_options, clip_model=model)
diff --git a/comfy/text_encoders/omnigen2.py b/comfy/text_encoders/omnigen2.py
index 1a01b2dd4..50aa4121f 100644
--- a/comfy/text_encoders/omnigen2.py
+++ b/comfy/text_encoders/omnigen2.py
@@ -32,12 +32,12 @@ class Omnigen2Model(sd1_clip.SD1ClipModel):
super().__init__(device=device, dtype=dtype, name="qwen25_3b", clip_model=Qwen25_3BModel, model_options=model_options)
-def te(dtype_llama=None, llama_scaled_fp8=None):
+def te(dtype_llama=None, llama_quantization_metadata=None):
class Omnigen2TEModel_(Omnigen2Model):
def __init__(self, device="cpu", dtype=None, model_options={}):
- if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options:
+ if llama_quantization_metadata is not None:
model_options = model_options.copy()
- model_options["scaled_fp8"] = llama_scaled_fp8
+ model_options["quantization_metadata"] = llama_quantization_metadata
if dtype_llama is not None:
dtype = dtype_llama
super().__init__(device=device, dtype=dtype, model_options=model_options)
diff --git a/comfy/text_encoders/ovis.py b/comfy/text_encoders/ovis.py
index 81c9bd51c..5754424d2 100644
--- a/comfy/text_encoders/ovis.py
+++ b/comfy/text_encoders/ovis.py
@@ -55,12 +55,9 @@ class OvisTEModel(sd1_clip.SD1ClipModel):
return out, pooled, {}
-def te(dtype_llama=None, llama_scaled_fp8=None, llama_quantization_metadata=None):
+def te(dtype_llama=None, llama_quantization_metadata=None):
class OvisTEModel_(OvisTEModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
- if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options:
- model_options = model_options.copy()
- model_options["scaled_fp8"] = llama_scaled_fp8
if dtype_llama is not None:
dtype = dtype_llama
if llama_quantization_metadata is not None:
diff --git a/comfy/text_encoders/pixart_t5.py b/comfy/text_encoders/pixart_t5.py
index 5f383de07..e5e5f18be 100644
--- a/comfy/text_encoders/pixart_t5.py
+++ b/comfy/text_encoders/pixart_t5.py
@@ -30,12 +30,12 @@ class PixArtTokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="t5xxl", tokenizer=T5XXLTokenizer)
-def pixart_te(dtype_t5=None, t5xxl_scaled_fp8=None):
+def pixart_te(dtype_t5=None, t5_quantization_metadata=None):
class PixArtTEModel_(PixArtT5XXL):
def __init__(self, device="cpu", dtype=None, model_options={}):
- if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options:
+ if t5_quantization_metadata is not None:
model_options = model_options.copy()
- model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8
+ model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata
if dtype is None:
dtype = dtype_t5
super().__init__(device=device, dtype=dtype, model_options=model_options)
diff --git a/comfy/text_encoders/qwen_image.py b/comfy/text_encoders/qwen_image.py
index c0d32a6ef..5c14dec23 100644
--- a/comfy/text_encoders/qwen_image.py
+++ b/comfy/text_encoders/qwen_image.py
@@ -85,12 +85,12 @@ class QwenImageTEModel(sd1_clip.SD1ClipModel):
return out, pooled, extra
-def te(dtype_llama=None, llama_scaled_fp8=None):
+def te(dtype_llama=None, llama_quantization_metadata=None):
class QwenImageTEModel_(QwenImageTEModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
- if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options:
+ if llama_quantization_metadata is not None:
model_options = model_options.copy()
- model_options["scaled_fp8"] = llama_scaled_fp8
+ model_options["quantization_metadata"] = llama_quantization_metadata
if dtype_llama is not None:
dtype = dtype_llama
super().__init__(device=device, dtype=dtype, model_options=model_options)
diff --git a/comfy/text_encoders/sd3_clip.py b/comfy/text_encoders/sd3_clip.py
index ff5d412db..8b153c72b 100644
--- a/comfy/text_encoders/sd3_clip.py
+++ b/comfy/text_encoders/sd3_clip.py
@@ -6,14 +6,15 @@ import torch
import os
import comfy.model_management
import logging
+import comfy.utils
class T5XXLModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=False, model_options={}):
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_config_xxl.json")
- t5xxl_scaled_fp8 = model_options.get("t5xxl_scaled_fp8", None)
- if t5xxl_scaled_fp8 is not None:
+ t5xxl_quantization_metadata = model_options.get("t5xxl_quantization_metadata", None)
+ if t5xxl_quantization_metadata is not None:
model_options = model_options.copy()
- model_options["scaled_fp8"] = t5xxl_scaled_fp8
+ model_options["quantization_metadata"] = t5xxl_quantization_metadata
model_options = {**model_options, "model_name": "t5xxl"}
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
@@ -25,9 +26,9 @@ def t5_xxl_detect(state_dict, prefix=""):
if t5_key in state_dict:
out["dtype_t5"] = state_dict[t5_key].dtype
- scaled_fp8_key = "{}scaled_fp8".format(prefix)
- if scaled_fp8_key in state_dict:
- out["t5xxl_scaled_fp8"] = state_dict[scaled_fp8_key].dtype
+ quant = comfy.utils.detect_layer_quantization(state_dict, prefix)
+ if quant is not None:
+ out["t5_quantization_metadata"] = quant
return out
@@ -156,11 +157,11 @@ class SD3ClipModel(torch.nn.Module):
else:
return self.t5xxl.load_sd(sd)
-def sd3_clip(clip_l=True, clip_g=True, t5=True, dtype_t5=None, t5xxl_scaled_fp8=None, t5_attention_mask=False):
+def sd3_clip(clip_l=True, clip_g=True, t5=True, dtype_t5=None, t5_quantization_metadata=None, t5_attention_mask=False):
class SD3ClipModel_(SD3ClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
- if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options:
+ if t5_quantization_metadata is not None:
model_options = model_options.copy()
- model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8
+ model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata
super().__init__(clip_l=clip_l, clip_g=clip_g, t5=t5, dtype_t5=dtype_t5, t5_attention_mask=t5_attention_mask, device=device, dtype=dtype, model_options=model_options)
return SD3ClipModel_
diff --git a/comfy/text_encoders/wan.py b/comfy/text_encoders/wan.py
index d50fa4b28..164a57edd 100644
--- a/comfy/text_encoders/wan.py
+++ b/comfy/text_encoders/wan.py
@@ -25,12 +25,12 @@ class WanT5Model(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}, **kwargs):
super().__init__(device=device, dtype=dtype, model_options=model_options, name="umt5xxl", clip_model=UMT5XXlModel, **kwargs)
-def te(dtype_t5=None, t5xxl_scaled_fp8=None):
+def te(dtype_t5=None, t5_quantization_metadata=None):
class WanTEModel(WanT5Model):
def __init__(self, device="cpu", dtype=None, model_options={}):
- if t5xxl_scaled_fp8 is not None and "scaled_fp8" not in model_options:
+ if t5_quantization_metadata is not None:
model_options = model_options.copy()
- model_options["scaled_fp8"] = t5xxl_scaled_fp8
+ model_options["quantization_metadata"] = t5_quantization_metadata
if dtype_t5 is not None:
dtype = dtype_t5
super().__init__(device=device, dtype=dtype, model_options=model_options)
diff --git a/comfy/text_encoders/z_image.py b/comfy/text_encoders/z_image.py
index bb9273b20..19adde0b7 100644
--- a/comfy/text_encoders/z_image.py
+++ b/comfy/text_encoders/z_image.py
@@ -34,12 +34,9 @@ class ZImageTEModel(sd1_clip.SD1ClipModel):
super().__init__(device=device, dtype=dtype, name="qwen3_4b", clip_model=Qwen3_4BModel, model_options=model_options)
-def te(dtype_llama=None, llama_scaled_fp8=None, llama_quantization_metadata=None):
+def te(dtype_llama=None, llama_quantization_metadata=None):
class ZImageTEModel_(ZImageTEModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
- if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options:
- model_options = model_options.copy()
- model_options["scaled_fp8"] = llama_scaled_fp8
if dtype_llama is not None:
dtype = dtype_llama
if llama_quantization_metadata is not None:
diff --git a/comfy/utils.py b/comfy/utils.py
index 37485e497..89846bc95 100644
--- a/comfy/utils.py
+++ b/comfy/utils.py
@@ -29,6 +29,7 @@ import itertools
from torch.nn.functional import interpolate
from einops import rearrange
from comfy.cli_args import args
+import json
MMAP_TORCH_FILES = args.mmap_torch_files
DISABLE_MMAP = args.disable_mmap
@@ -1194,3 +1195,68 @@ def unpack_latents(combined_latent, latent_shapes):
else:
output_tensors = combined_latent
return output_tensors
+
+def detect_layer_quantization(state_dict, prefix):
+ for k in state_dict:
+ if k.startswith(prefix) and k.endswith(".comfy_quant"):
+ logging.info("Found quantization metadata version 1")
+ return {"mixed_ops": True}
+ return None
+
+def convert_old_quants(state_dict, model_prefix="", metadata={}):
+ if metadata is None:
+ metadata = {}
+
+ quant_metadata = None
+ if "_quantization_metadata" not in metadata:
+ scaled_fp8_key = "{}scaled_fp8".format(model_prefix)
+
+ if scaled_fp8_key in state_dict:
+ scaled_fp8_weight = state_dict[scaled_fp8_key]
+ scaled_fp8_dtype = scaled_fp8_weight.dtype
+ if scaled_fp8_dtype == torch.float32:
+ scaled_fp8_dtype = torch.float8_e4m3fn
+
+ if scaled_fp8_weight.nelement() == 2:
+ full_precision_matrix_mult = True
+ else:
+ full_precision_matrix_mult = False
+
+ out_sd = {}
+ layers = {}
+ for k in list(state_dict.keys()):
+ if not k.startswith(model_prefix):
+ out_sd[k] = state_dict[k]
+ continue
+ k_out = k
+ w = state_dict.pop(k)
+ layer = None
+ if k_out.endswith(".scale_weight"):
+ layer = k_out[:-len(".scale_weight")]
+ k_out = "{}.weight_scale".format(layer)
+
+ if layer is not None:
+ layer_conf = {"format": "float8_e4m3fn"} # TODO: check if anyone did some non e4m3fn scaled checkpoints
+ if full_precision_matrix_mult:
+ layer_conf["full_precision_matrix_mult"] = full_precision_matrix_mult
+ layers[layer] = layer_conf
+
+ if k_out.endswith(".scale_input"):
+ layer = k_out[:-len(".scale_input")]
+ k_out = "{}.input_scale".format(layer)
+ if w.item() == 1.0:
+ continue
+
+ out_sd[k_out] = w
+
+ state_dict = out_sd
+ quant_metadata = {"layers": layers}
+ else:
+ quant_metadata = json.loads(metadata["_quantization_metadata"])
+
+ if quant_metadata is not None:
+ layers = quant_metadata["layers"]
+ for k, v in layers.items():
+ state_dict["{}.comfy_quant".format(k)] = torch.frombuffer(json.dumps(v).encode('utf-8'), dtype=torch.uint8)
+
+ return state_dict, metadata
diff --git a/tests-unit/comfy_quant/test_mixed_precision.py b/tests-unit/comfy_quant/test_mixed_precision.py
index 63361309f..3a54941e6 100644
--- a/tests-unit/comfy_quant/test_mixed_precision.py
+++ b/tests-unit/comfy_quant/test_mixed_precision.py
@@ -2,6 +2,7 @@ import unittest
import torch
import sys
import os
+import json
# Add comfy to path
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
@@ -15,6 +16,7 @@ if not has_gpu():
from comfy import ops
from comfy.quant_ops import QuantizedTensor
+import comfy.utils
class SimpleModel(torch.nn.Module):
@@ -94,8 +96,9 @@ class TestMixedPrecisionOps(unittest.TestCase):
"layer3.weight_scale": torch.tensor(1.5, dtype=torch.float32),
}
+ state_dict, _ = comfy.utils.convert_old_quants(state_dict, metadata={"_quantization_metadata": json.dumps({"layers": layer_quant_config})})
# Create model and load state dict (strict=False because custom loading pops keys)
- model = SimpleModel(operations=ops.mixed_precision_ops(layer_quant_config))
+ model = SimpleModel(operations=ops.mixed_precision_ops({}))
model.load_state_dict(state_dict, strict=False)
# Verify weights are wrapped in QuantizedTensor
@@ -115,7 +118,8 @@ class TestMixedPrecisionOps(unittest.TestCase):
# Forward pass
input_tensor = torch.randn(5, 10, dtype=torch.bfloat16)
- output = model(input_tensor)
+ with torch.inference_mode():
+ output = model(input_tensor)
self.assertEqual(output.shape, (5, 40))
@@ -141,7 +145,8 @@ class TestMixedPrecisionOps(unittest.TestCase):
"layer3.bias": torch.randn(40, dtype=torch.bfloat16),
}
- model = SimpleModel(operations=ops.mixed_precision_ops(layer_quant_config))
+ state_dict1, _ = comfy.utils.convert_old_quants(state_dict1, metadata={"_quantization_metadata": json.dumps({"layers": layer_quant_config})})
+ model = SimpleModel(operations=ops.mixed_precision_ops({}))
model.load_state_dict(state_dict1, strict=False)
# Save state dict
@@ -178,7 +183,8 @@ class TestMixedPrecisionOps(unittest.TestCase):
"layer3.bias": torch.randn(40, dtype=torch.bfloat16),
}
- model = SimpleModel(operations=ops.mixed_precision_ops(layer_quant_config))
+ state_dict, _ = comfy.utils.convert_old_quants(state_dict, metadata={"_quantization_metadata": json.dumps({"layers": layer_quant_config})})
+ model = SimpleModel(operations=ops.mixed_precision_ops({}))
model.load_state_dict(state_dict, strict=False)
# Add a weight function (simulating LoRA)
@@ -215,8 +221,10 @@ class TestMixedPrecisionOps(unittest.TestCase):
"layer3.bias": torch.randn(40, dtype=torch.bfloat16),
}
+ state_dict, _ = comfy.utils.convert_old_quants(state_dict, metadata={"_quantization_metadata": json.dumps({"layers": layer_quant_config})})
+
# Load should raise KeyError for unknown format in QUANT_FORMAT_MIXINS
- model = SimpleModel(operations=ops.mixed_precision_ops(layer_quant_config))
+ model = SimpleModel(operations=ops.mixed_precision_ops({}))
with self.assertRaises(KeyError):
model.load_state_dict(state_dict, strict=False)
From 6fd463aec958f02be79a264eafd6c8fe7e52762a Mon Sep 17 00:00:00 2001
From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com>
Date: Fri, 5 Dec 2025 12:33:16 -0800
Subject: [PATCH 31/81] Fix regression when text encoder loaded directly on
GPU. (#11129)
---
comfy/ops.py | 2 ++
comfy/sd.py | 44 ++++++++++++++++++++++++--------------------
2 files changed, 26 insertions(+), 20 deletions(-)
diff --git a/comfy/ops.py b/comfy/ops.py
index dc06709a1..35237c9f7 100644
--- a/comfy/ops.py
+++ b/comfy/ops.py
@@ -552,6 +552,8 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
weight_scale_key = f"{prefix}weight_scale"
scale = state_dict.pop(weight_scale_key, None)
+ if scale is not None:
+ scale = scale.to(device)
layout_params = {
'scale': scale,
'orig_dtype': MixedPrecisionOps._compute_dtype,
diff --git a/comfy/sd.py b/comfy/sd.py
index 092715d79..c350322f8 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -98,7 +98,7 @@ def load_lora_for_models(model, clip, lora, strength_model, strength_clip):
class CLIP:
- def __init__(self, target=None, embedding_directory=None, no_init=False, tokenizer_data={}, parameters=0, model_options={}):
+ def __init__(self, target=None, embedding_directory=None, no_init=False, tokenizer_data={}, parameters=0, state_dict=[], model_options={}):
if no_init:
return
params = target.params.copy()
@@ -129,6 +129,27 @@ class CLIP:
self.patcher.hook_mode = comfy.hooks.EnumHookMode.MinVram
self.patcher.is_clip = True
self.apply_hooks_to_conds = None
+ if len(state_dict) > 0:
+ if isinstance(state_dict, list):
+ for c in state_dict:
+ m, u = self.load_sd(c)
+ if len(m) > 0:
+ logging.warning("clip missing: {}".format(m))
+
+ if len(u) > 0:
+ logging.debug("clip unexpected: {}".format(u))
+ else:
+ m, u = self.load_sd(state_dict, full_model=True)
+ if len(m) > 0:
+ m_filter = list(filter(lambda a: ".logit_scale" not in a and ".transformer.text_projection.weight" not in a, m))
+ if len(m_filter) > 0:
+ logging.warning("clip missing: {}".format(m))
+ else:
+ logging.debug("clip missing: {}".format(m))
+
+ if len(u) > 0:
+ logging.debug("clip unexpected {}:".format(u))
+
if params['device'] == load_device:
model_management.load_models_gpu([self.patcher], force_full_load=True)
self.layer_idx = None
@@ -1225,14 +1246,7 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
parameters += comfy.utils.calculate_parameters(c)
tokenizer_data, model_options = comfy.text_encoders.long_clipl.model_options_long_clip(c, tokenizer_data, model_options)
- clip = CLIP(clip_target, embedding_directory=embedding_directory, parameters=parameters, tokenizer_data=tokenizer_data, model_options=model_options)
- for c in clip_data:
- m, u = clip.load_sd(c)
- if len(m) > 0:
- logging.warning("clip missing: {}".format(m))
-
- if len(u) > 0:
- logging.debug("clip unexpected: {}".format(u))
+ clip = CLIP(clip_target, embedding_directory=embedding_directory, parameters=parameters, tokenizer_data=tokenizer_data, state_dict=clip_data, model_options=model_options)
return clip
def load_gligen(ckpt_path):
@@ -1362,17 +1376,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
clip_sd = model_config.process_clip_state_dict(sd)
if len(clip_sd) > 0:
parameters = comfy.utils.calculate_parameters(clip_sd)
- clip = CLIP(clip_target, embedding_directory=embedding_directory, tokenizer_data=clip_sd, parameters=parameters, model_options=te_model_options)
- m, u = clip.load_sd(clip_sd, full_model=True)
- if len(m) > 0:
- m_filter = list(filter(lambda a: ".logit_scale" not in a and ".transformer.text_projection.weight" not in a, m))
- if len(m_filter) > 0:
- logging.warning("clip missing: {}".format(m))
- else:
- logging.debug("clip missing: {}".format(m))
-
- if len(u) > 0:
- logging.debug("clip unexpected {}:".format(u))
+ clip = CLIP(clip_target, embedding_directory=embedding_directory, tokenizer_data=clip_sd, parameters=parameters, state_dict=clip_sd, model_options=te_model_options)
else:
logging.warning("no CLIP/text encoder weights in checkpoint, the text encoder model will not be loaded.")
From 79d17ba2339aaf4f3422673b3dad24ba4dbd7552 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jukka=20Sepp=C3=A4nen?=
<40791699+kijai@users.noreply.github.com>
Date: Fri, 5 Dec 2025 22:42:46 +0200
Subject: [PATCH 32/81] Context windows fixes and features (#10975)
* Apply cond slice fix
* Add FreeNoise
* Update context_windows.py
* Add option to retain condition by indexes for each window
This allows for example Wan/HunyuanVideo image to video to "work" by using the initial start frame for each window, otherwise windows beyond first will be pure T2V generations.
* Update context_windows.py
* Allow splitting multiple conds into different windows
* Add handling for audio_embed
* whitespace
* Allow freenoise to work on other dims, handle 4D batch timestep
Refactor Freenoise function. And fix batch handling as timesteps seem to be expanded to batch size now.
* Disable experimental options for now
So that the Freenoise and bugfixes can be merged first
---------
Co-authored-by: Jedrzej Kosinski
Co-authored-by: ozbayb <17261091+ozbayb@users.noreply.github.com>
---
comfy/context_windows.py | 104 ++++++++++++++++++++++----
comfy_extras/nodes_context_windows.py | 22 +++++-
2 files changed, 108 insertions(+), 18 deletions(-)
diff --git a/comfy/context_windows.py b/comfy/context_windows.py
index 041f380f9..5c412d1c2 100644
--- a/comfy/context_windows.py
+++ b/comfy/context_windows.py
@@ -51,26 +51,36 @@ class ContextHandlerABC(ABC):
class IndexListContextWindow(ContextWindowABC):
- def __init__(self, index_list: list[int], dim: int=0):
+ def __init__(self, index_list: list[int], dim: int=0, total_frames: int=0):
self.index_list = index_list
self.context_length = len(index_list)
self.dim = dim
+ self.total_frames = total_frames
+ self.center_ratio = (min(index_list) + max(index_list)) / (2 * total_frames)
- def get_tensor(self, full: torch.Tensor, device=None, dim=None) -> torch.Tensor:
+ def get_tensor(self, full: torch.Tensor, device=None, dim=None, retain_index_list=[]) -> torch.Tensor:
if dim is None:
dim = self.dim
if dim == 0 and full.shape[dim] == 1:
return full
- idx = [slice(None)] * dim + [self.index_list]
- return full[idx].to(device)
+ idx = tuple([slice(None)] * dim + [self.index_list])
+ window = full[idx]
+ if retain_index_list:
+ idx = tuple([slice(None)] * dim + [retain_index_list])
+ window[idx] = full[idx]
+ return window.to(device)
def add_window(self, full: torch.Tensor, to_add: torch.Tensor, dim=None) -> torch.Tensor:
if dim is None:
dim = self.dim
- idx = [slice(None)] * dim + [self.index_list]
+ idx = tuple([slice(None)] * dim + [self.index_list])
full[idx] += to_add
return full
+ def get_region_index(self, num_regions: int) -> int:
+ region_idx = int(self.center_ratio * num_regions)
+ return min(max(region_idx, 0), num_regions - 1)
+
class IndexListCallbacks:
EVALUATE_CONTEXT_WINDOWS = "evaluate_context_windows"
@@ -94,7 +104,8 @@ class ContextFuseMethod:
ContextResults = collections.namedtuple("ContextResults", ['window_idx', 'sub_conds_out', 'sub_conds', 'window'])
class IndexListContextHandler(ContextHandlerABC):
- def __init__(self, context_schedule: ContextSchedule, fuse_method: ContextFuseMethod, context_length: int=1, context_overlap: int=0, context_stride: int=1, closed_loop=False, dim=0):
+ def __init__(self, context_schedule: ContextSchedule, fuse_method: ContextFuseMethod, context_length: int=1, context_overlap: int=0, context_stride: int=1,
+ closed_loop: bool=False, dim:int=0, freenoise: bool=False, cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False):
self.context_schedule = context_schedule
self.fuse_method = fuse_method
self.context_length = context_length
@@ -103,13 +114,18 @@ class IndexListContextHandler(ContextHandlerABC):
self.closed_loop = closed_loop
self.dim = dim
self._step = 0
+ self.freenoise = freenoise
+ self.cond_retain_index_list = [int(x.strip()) for x in cond_retain_index_list.split(",")] if cond_retain_index_list else []
+ self.split_conds_to_windows = split_conds_to_windows
self.callbacks = {}
def should_use_context(self, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]) -> bool:
# for now, assume first dim is batch - should have stored on BaseModel in actual implementation
if x_in.size(self.dim) > self.context_length:
- logging.info(f"Using context windows {self.context_length} for {x_in.size(self.dim)} frames.")
+ logging.info(f"Using context windows {self.context_length} with overlap {self.context_overlap} for {x_in.size(self.dim)} frames.")
+ if self.cond_retain_index_list:
+ logging.info(f"Retaining original cond for indexes: {self.cond_retain_index_list}")
return True
return False
@@ -123,6 +139,11 @@ class IndexListContextHandler(ContextHandlerABC):
return None
# reuse or resize cond items to match context requirements
resized_cond = []
+ # if multiple conds, split based on primary region
+ if self.split_conds_to_windows and len(cond_in) > 1:
+ region = window.get_region_index(len(cond_in))
+ logging.info(f"Splitting conds to windows; using region {region} for window {window[0]}-{window[-1]} with center ratio {window.center_ratio:.3f}")
+ cond_in = [cond_in[region]]
# cond object is a list containing a dict - outer list is irrelevant, so just loop through it
for actual_cond in cond_in:
resized_actual_cond = actual_cond.copy()
@@ -146,12 +167,19 @@ class IndexListContextHandler(ContextHandlerABC):
# when in dictionary, look for tensors and CONDCrossAttn [comfy/conds.py] (has cond attr that is a tensor)
for cond_key, cond_value in new_cond_item.items():
if isinstance(cond_value, torch.Tensor):
- if cond_value.ndim < self.dim and cond_value.size(0) == x_in.size(self.dim):
+ if (self.dim < cond_value.ndim and cond_value(self.dim) == x_in.size(self.dim)) or \
+ (cond_value.ndim < self.dim and cond_value.size(0) == x_in.size(self.dim)):
new_cond_item[cond_key] = window.get_tensor(cond_value, device)
+ # Handle audio_embed (temporal dim is 1)
+ elif cond_key == "audio_embed" and hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor):
+ audio_cond = cond_value.cond
+ if audio_cond.ndim > 1 and audio_cond.size(1) == x_in.size(self.dim):
+ new_cond_item[cond_key] = cond_value._copy_with(window.get_tensor(audio_cond, device, dim=1))
# if has cond that is a Tensor, check if needs to be subset
elif hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor):
- if cond_value.cond.ndim < self.dim and cond_value.cond.size(0) == x_in.size(self.dim):
- new_cond_item[cond_key] = cond_value._copy_with(window.get_tensor(cond_value.cond, device))
+ if (self.dim < cond_value.cond.ndim and cond_value.cond.size(self.dim) == x_in.size(self.dim)) or \
+ (cond_value.cond.ndim < self.dim and cond_value.cond.size(0) == x_in.size(self.dim)):
+ new_cond_item[cond_key] = cond_value._copy_with(window.get_tensor(cond_value.cond, device, retain_index_list=self.cond_retain_index_list))
elif cond_key == "num_video_frames": # for SVD
new_cond_item[cond_key] = cond_value._copy_with(cond_value.cond)
new_cond_item[cond_key].cond = window.context_length
@@ -164,7 +192,7 @@ class IndexListContextHandler(ContextHandlerABC):
return resized_cond
def set_step(self, timestep: torch.Tensor, model_options: dict[str]):
- mask = torch.isclose(model_options["transformer_options"]["sample_sigmas"], timestep, rtol=0.0001)
+ mask = torch.isclose(model_options["transformer_options"]["sample_sigmas"], timestep[0], rtol=0.0001)
matches = torch.nonzero(mask)
if torch.numel(matches) == 0:
raise Exception("No sample_sigmas matched current timestep; something went wrong.")
@@ -173,7 +201,7 @@ class IndexListContextHandler(ContextHandlerABC):
def get_context_windows(self, model: BaseModel, x_in: torch.Tensor, model_options: dict[str]) -> list[IndexListContextWindow]:
full_length = x_in.size(self.dim) # TODO: choose dim based on model
context_windows = self.context_schedule.func(full_length, self, model_options)
- context_windows = [IndexListContextWindow(window, dim=self.dim) for window in context_windows]
+ context_windows = [IndexListContextWindow(window, dim=self.dim, total_frames=full_length) for window in context_windows]
return context_windows
def execute(self, calc_cond_batch: Callable, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]):
@@ -250,8 +278,8 @@ class IndexListContextHandler(ContextHandlerABC):
prev_weight = (bias_total / (bias_total + bias))
new_weight = (bias / (bias_total + bias))
# account for dims of tensors
- idx_window = [slice(None)] * self.dim + [idx]
- pos_window = [slice(None)] * self.dim + [pos]
+ idx_window = tuple([slice(None)] * self.dim + [idx])
+ pos_window = tuple([slice(None)] * self.dim + [pos])
# apply new values
conds_final[i][idx_window] = conds_final[i][idx_window] * prev_weight + sub_conds_out[i][pos_window] * new_weight
biases_final[i][idx] = bias_total + bias
@@ -287,6 +315,28 @@ def create_prepare_sampling_wrapper(model: ModelPatcher):
)
+def _sampler_sample_wrapper(executor, guider, sigmas, extra_args, callback, noise, *args, **kwargs):
+ model_options = extra_args.get("model_options", None)
+ if model_options is None:
+ raise Exception("model_options not found in sampler_sample_wrapper; this should never happen, something went wrong.")
+ handler: IndexListContextHandler = model_options.get("context_handler", None)
+ if handler is None:
+ raise Exception("context_handler not found in sampler_sample_wrapper; this should never happen, something went wrong.")
+ if not handler.freenoise:
+ return executor(guider, sigmas, extra_args, callback, noise, *args, **kwargs)
+ noise = apply_freenoise(noise, handler.dim, handler.context_length, handler.context_overlap, extra_args["seed"])
+
+ return executor(guider, sigmas, extra_args, callback, noise, *args, **kwargs)
+
+
+def create_sampler_sample_wrapper(model: ModelPatcher):
+ model.add_wrapper_with_key(
+ comfy.patcher_extension.WrappersMP.SAMPLER_SAMPLE,
+ "ContextWindows_sampler_sample",
+ _sampler_sample_wrapper
+ )
+
+
def match_weights_to_dim(weights: list[float], x_in: torch.Tensor, dim: int, device=None) -> torch.Tensor:
total_dims = len(x_in.shape)
weights_tensor = torch.Tensor(weights).to(device=device)
@@ -538,3 +588,29 @@ def shift_window_to_end(window: list[int], num_frames: int):
for i in range(len(window)):
# 2) add end_delta to each val to slide windows to end
window[i] = window[i] + end_delta
+
+
+# https://github.com/Kosinkadink/ComfyUI-AnimateDiff-Evolved/blob/90fb1331201a4b29488089e4fbffc0d82cc6d0a9/animatediff/sample_settings.py#L465
+def apply_freenoise(noise: torch.Tensor, dim: int, context_length: int, context_overlap: int, seed: int):
+ logging.info("Context windows: Applying FreeNoise")
+ generator = torch.Generator(device='cpu').manual_seed(seed)
+ latent_video_length = noise.shape[dim]
+ delta = context_length - context_overlap
+
+ for start_idx in range(0, latent_video_length - context_length, delta):
+ place_idx = start_idx + context_length
+
+ actual_delta = min(delta, latent_video_length - place_idx)
+ if actual_delta <= 0:
+ break
+
+ list_idx = torch.randperm(actual_delta, generator=generator, device='cpu') + start_idx
+
+ source_slice = [slice(None)] * noise.ndim
+ source_slice[dim] = list_idx
+ target_slice = [slice(None)] * noise.ndim
+ target_slice[dim] = slice(place_idx, place_idx + actual_delta)
+
+ noise[tuple(target_slice)] = noise[tuple(source_slice)]
+
+ return noise
diff --git a/comfy_extras/nodes_context_windows.py b/comfy_extras/nodes_context_windows.py
index 1c3d9e697..3799a9004 100644
--- a/comfy_extras/nodes_context_windows.py
+++ b/comfy_extras/nodes_context_windows.py
@@ -26,6 +26,9 @@ class ContextWindowsManualNode(io.ComfyNode):
io.Boolean.Input("closed_loop", default=False, tooltip="Whether to close the context window loop; only applicable to looped schedules."),
io.Combo.Input("fuse_method", options=comfy.context_windows.ContextFuseMethods.LIST_STATIC, default=comfy.context_windows.ContextFuseMethods.PYRAMID, tooltip="The method to use to fuse the context windows."),
io.Int.Input("dim", min=0, max=5, default=0, tooltip="The dimension to apply the context windows to."),
+ io.Boolean.Input("freenoise", default=False, tooltip="Whether to apply FreeNoise noise shuffling, improves window blending."),
+ #io.String.Input("cond_retain_index_list", default="", tooltip="List of latent indices to retain in the conditioning tensors for each window, for example setting this to '0' will use the initial start image for each window."),
+ #io.Boolean.Input("split_conds_to_windows", default=False, tooltip="Whether to split multiple conditionings (created by ConditionCombine) to each window based on region index."),
],
outputs=[
io.Model.Output(tooltip="The model with context windows applied during sampling."),
@@ -34,7 +37,8 @@ class ContextWindowsManualNode(io.ComfyNode):
)
@classmethod
- def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, context_stride: int, closed_loop: bool, fuse_method: str, dim: int) -> io.Model:
+ def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, context_stride: int, closed_loop: bool, fuse_method: str, dim: int, freenoise: bool,
+ cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False) -> io.Model:
model = model.clone()
model.model_options["context_handler"] = comfy.context_windows.IndexListContextHandler(
context_schedule=comfy.context_windows.get_matching_context_schedule(context_schedule),
@@ -43,9 +47,15 @@ class ContextWindowsManualNode(io.ComfyNode):
context_overlap=context_overlap,
context_stride=context_stride,
closed_loop=closed_loop,
- dim=dim)
+ dim=dim,
+ freenoise=freenoise,
+ cond_retain_index_list=cond_retain_index_list,
+ split_conds_to_windows=split_conds_to_windows
+ )
# make memory usage calculation only take into account the context window latents
comfy.context_windows.create_prepare_sampling_wrapper(model)
+ if freenoise: # no other use for this wrapper at this time
+ comfy.context_windows.create_sampler_sample_wrapper(model)
return io.NodeOutput(model)
class WanContextWindowsManualNode(ContextWindowsManualNode):
@@ -68,14 +78,18 @@ class WanContextWindowsManualNode(ContextWindowsManualNode):
io.Int.Input("context_stride", min=1, default=1, tooltip="The stride of the context window; only applicable to uniform schedules."),
io.Boolean.Input("closed_loop", default=False, tooltip="Whether to close the context window loop; only applicable to looped schedules."),
io.Combo.Input("fuse_method", options=comfy.context_windows.ContextFuseMethods.LIST_STATIC, default=comfy.context_windows.ContextFuseMethods.PYRAMID, tooltip="The method to use to fuse the context windows."),
+ io.Boolean.Input("freenoise", default=False, tooltip="Whether to apply FreeNoise noise shuffling, improves window blending."),
+ #io.String.Input("cond_retain_index_list", default="", tooltip="List of latent indices to retain in the conditioning tensors for each window, for example setting this to '0' will use the initial start image for each window."),
+ #io.Boolean.Input("split_conds_to_windows", default=False, tooltip="Whether to split multiple conditionings (created by ConditionCombine) to each window based on region index."),
]
return schema
@classmethod
- def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, context_stride: int, closed_loop: bool, fuse_method: str) -> io.Model:
+ def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, context_stride: int, closed_loop: bool, fuse_method: str, freenoise: bool,
+ cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False) -> io.Model:
context_length = max(((context_length - 1) // 4) + 1, 1) # at least length 1
context_overlap = max(((context_overlap - 1) // 4) + 1, 0) # at least overlap 0
- return super().execute(model, context_length, context_overlap, context_schedule, context_stride, closed_loop, fuse_method, dim=2)
+ return super().execute(model, context_length, context_overlap, context_schedule, context_stride, closed_loop, fuse_method, dim=2, freenoise=freenoise, cond_retain_index_list=cond_retain_index_list, split_conds_to_windows=split_conds_to_windows)
class ContextWindowsExtension(ComfyExtension):
From 092ee8a5008c8d558b0a72cc7961a31d9cc5400b Mon Sep 17 00:00:00 2001
From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com>
Date: Fri, 5 Dec 2025 15:25:31 -0800
Subject: [PATCH 33/81] Fix some custom nodes. (#11134)
---
comfy/supported_models_base.py | 5 +++++
1 file changed, 5 insertions(+)
diff --git a/comfy/supported_models_base.py b/comfy/supported_models_base.py
index 9fd84d329..0e7a829ba 100644
--- a/comfy/supported_models_base.py
+++ b/comfy/supported_models_base.py
@@ -17,6 +17,7 @@
"""
import torch
+import logging
from . import model_base
from . import utils
from . import latent_formats
@@ -117,3 +118,7 @@ class BASE:
def set_inference_dtype(self, dtype, manual_cast_dtype):
self.unet_config['dtype'] = dtype
self.manual_cast_dtype = manual_cast_dtype
+
+ def __getattr__(self, name):
+ logging.warning("\nWARNING, you accessed {} from the model config object which doesn't exist. Please fix your code.\n".format(name))
+ return None
From bed12674a1d2c4bfdfbdd098686390f807996c90 Mon Sep 17 00:00:00 2001
From: "Dr.Lt.Data" <128333288+ltdrdata@users.noreply.github.com>
Date: Sat, 6 Dec 2025 08:45:38 +0900
Subject: [PATCH 34/81] docs: add ComfyUI-Manager documentation and update to
v4.0.3b4 (#11133)
- Add manager setup instructions and command line options to README
- Document --enable-manager, --enable-manager-legacy-ui, and
--disable-manager-ui flags
- Bump comfyui_manager version from 4.0.3b3 to 4.0.3b4
---
README.md | 26 ++++++++++++++++++++++++++
manager_requirements.txt | 2 +-
2 files changed, 27 insertions(+), 1 deletion(-)
diff --git a/README.md b/README.md
index ed857df9f..bae955b1b 100644
--- a/README.md
+++ b/README.md
@@ -320,6 +320,32 @@ For models compatible with Iluvatar Extension for PyTorch. Here's a step-by-step
1. Install the Iluvatar Corex Toolkit by adhering to the platform-specific instructions on the [Installation](https://support.iluvatar.com/#/DocumentCentre?id=1&nameCenter=2&productId=520117912052801536)
2. Launch ComfyUI by running `python main.py`
+
+## [ComfyUI-Manager](https://github.com/Comfy-Org/ComfyUI-Manager/tree/manager-v4)
+
+**ComfyUI-Manager** is an extension that allows you to easily install, update, and manage custom nodes for ComfyUI.
+
+### Setup
+
+1. Install the manager dependencies:
+ ```bash
+ pip install -r manager_requirements.txt
+ ```
+
+2. Enable the manager with the `--enable-manager` flag when running ComfyUI:
+ ```bash
+ python main.py --enable-manager
+ ```
+
+### Command Line Options
+
+| Flag | Description |
+|------|-------------|
+| `--enable-manager` | Enable ComfyUI-Manager |
+| `--enable-manager-legacy-ui` | Use the legacy manager UI instead of the new UI (requires `--enable-manager`) |
+| `--disable-manager-ui` | Disable the manager UI and endpoints while keeping background features like security checks and scheduled installation completion (requires `--enable-manager`) |
+
+
# Running
```python main.py```
diff --git a/manager_requirements.txt b/manager_requirements.txt
index 52cc5389c..b95cefb74 100644
--- a/manager_requirements.txt
+++ b/manager_requirements.txt
@@ -1 +1 @@
-comfyui_manager==4.0.3b3
+comfyui_manager==4.0.3b4
From fd109325db7126f92c2dfb7e6b25310eded8c1f8 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jukka=20Sepp=C3=A4nen?=
<40791699+kijai@users.noreply.github.com>
Date: Sat, 6 Dec 2025 05:20:22 +0200
Subject: [PATCH 35/81] Kandinsky5 model support (#10988)
* Add Kandinsky5 model support
lite and pro T2V tested to work
* Update kandinsky5.py
* Fix fp8
* Fix fp8_scaled text encoder
* Add transformer_options for attention
* Code cleanup, optimizations, use fp32 for all layers originally at fp32
* ImageToVideo -node
* Fix I2V, add necessary latent post process nodes
* Support text to image model
* Support block replace patches (SLG mostly)
* Support official LoRAs
* Don't scale RoPE for lite model as that just doesn't work...
* Update supported_models.py
* Rever RoPE scaling to simpler one
* Fix typo
* Handle latent dim difference for image model in the VAE instead
* Add node to use different prompts for clip_l and qwen25_7b
* Reduce peak VRAM usage a bit
* Further reduce peak VRAM consumption by chunking ffn
* Update chunking
* Update memory_usage_factor
* Code cleanup, don't force the fp32 layers as it has minimal effect
* Allow for stronger changes with first frames normalization
Default values are too weak for any meaningful changes, these should probably be exposed as advanced node options when that's available.
* Add image model's own chat template, remove unused image2video template
* Remove hard error in ReplaceVideoLatentFrames -node
* Update kandinsky5.py
* Update supported_models.py
* Fix typos in prompt template
They were now fixed in the original repository as well
* Update ReplaceVideoLatentFrames
Add tooltips
Make source optional
Better handle negative index
* Rename NormalizeVideoLatentFrames -node
For bit better clarity what it does
* Fix NormalizeVideoLatentStart node out on non-op
---
comfy/ldm/kandinsky5/model.py | 407 ++++++++++++++++++++++++++++++
comfy/lora.py | 7 +
comfy/model_base.py | 47 ++++
comfy/model_detection.py | 18 ++
comfy/sd.py | 11 +
comfy/supported_models.py | 56 +++-
comfy/text_encoders/kandinsky5.py | 68 +++++
comfy_api/latest/_io.py | 2 +
comfy_extras/nodes_kandinsky5.py | 136 ++++++++++
comfy_extras/nodes_latent.py | 39 ++-
nodes.py | 3 +-
11 files changed, 791 insertions(+), 3 deletions(-)
create mode 100644 comfy/ldm/kandinsky5/model.py
create mode 100644 comfy/text_encoders/kandinsky5.py
create mode 100644 comfy_extras/nodes_kandinsky5.py
diff --git a/comfy/ldm/kandinsky5/model.py b/comfy/ldm/kandinsky5/model.py
new file mode 100644
index 000000000..a653e02fc
--- /dev/null
+++ b/comfy/ldm/kandinsky5/model.py
@@ -0,0 +1,407 @@
+import torch
+from torch import nn
+import math
+
+import comfy.ldm.common_dit
+from comfy.ldm.modules.attention import optimized_attention
+from comfy.ldm.flux.math import apply_rope1
+from comfy.ldm.flux.layers import EmbedND
+
+def attention(q, k, v, heads, transformer_options={}):
+ return optimized_attention(
+ q.transpose(1, 2),
+ k.transpose(1, 2),
+ v.transpose(1, 2),
+ heads=heads,
+ skip_reshape=True,
+ transformer_options=transformer_options
+ )
+
+def apply_scale_shift_norm(norm, x, scale, shift):
+ return torch.addcmul(shift, norm(x), scale + 1.0)
+
+def apply_gate_sum(x, out, gate):
+ return torch.addcmul(x, gate, out)
+
+def get_shift_scale_gate(params):
+ shift, scale, gate = torch.chunk(params, 3, dim=-1)
+ return tuple(x.unsqueeze(1) for x in (shift, scale, gate))
+
+def get_freqs(dim, max_period=10000.0):
+ return torch.exp(-math.log(max_period) * torch.arange(start=0, end=dim, dtype=torch.float32) / dim)
+
+
+class TimeEmbeddings(nn.Module):
+ def __init__(self, model_dim, time_dim, max_period=10000.0, operation_settings=None):
+ super().__init__()
+ assert model_dim % 2 == 0
+ self.model_dim = model_dim
+ self.max_period = max_period
+ self.register_buffer("freqs", get_freqs(model_dim // 2, max_period), persistent=False)
+ operations = operation_settings.get("operations")
+ self.in_layer = operations.Linear(model_dim, time_dim, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+ self.activation = nn.SiLU()
+ self.out_layer = operations.Linear(time_dim, time_dim, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+
+ def forward(self, timestep, dtype):
+ args = torch.outer(timestep, self.freqs.to(device=timestep.device))
+ time_embed = torch.cat([torch.cos(args), torch.sin(args)], dim=-1).to(dtype)
+ time_embed = self.out_layer(self.activation(self.in_layer(time_embed)))
+ return time_embed
+
+
+class TextEmbeddings(nn.Module):
+ def __init__(self, text_dim, model_dim, operation_settings=None):
+ super().__init__()
+ operations = operation_settings.get("operations")
+ self.in_layer = operations.Linear(text_dim, model_dim, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+ self.norm = operations.LayerNorm(model_dim, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+
+ def forward(self, text_embed):
+ text_embed = self.in_layer(text_embed)
+ return self.norm(text_embed).type_as(text_embed)
+
+
+class VisualEmbeddings(nn.Module):
+ def __init__(self, visual_dim, model_dim, patch_size, operation_settings=None):
+ super().__init__()
+ self.patch_size = patch_size
+ operations = operation_settings.get("operations")
+ self.in_layer = operations.Linear(visual_dim, model_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+
+ def forward(self, x):
+ x = x.movedim(1, -1) # B C T H W -> B T H W C
+ B, T, H, W, dim = x.shape
+ pt, ph, pw = self.patch_size
+
+ x = x.view(
+ B,
+ T // pt, pt,
+ H // ph, ph,
+ W // pw, pw,
+ dim,
+ ).permute(0, 1, 3, 5, 2, 4, 6, 7).flatten(4, 7)
+
+ return self.in_layer(x)
+
+
+class Modulation(nn.Module):
+ def __init__(self, time_dim, model_dim, num_params, operation_settings=None):
+ super().__init__()
+ self.activation = nn.SiLU()
+ self.out_layer = operation_settings.get("operations").Linear(time_dim, num_params * model_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+
+ def forward(self, x):
+ return self.out_layer(self.activation(x))
+
+
+class SelfAttention(nn.Module):
+ def __init__(self, num_channels, head_dim, operation_settings=None):
+ super().__init__()
+ assert num_channels % head_dim == 0
+ self.num_heads = num_channels // head_dim
+ self.head_dim = head_dim
+
+ operations = operation_settings.get("operations")
+ self.to_query = operations.Linear(num_channels, num_channels, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+ self.to_key = operations.Linear(num_channels, num_channels, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+ self.to_value = operations.Linear(num_channels, num_channels, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+ self.query_norm = operations.RMSNorm(head_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+ self.key_norm = operations.RMSNorm(head_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+
+ self.out_layer = operations.Linear(num_channels, num_channels, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+ self.num_chunks = 2
+
+ def _compute_qk(self, x, freqs, proj_fn, norm_fn):
+ result = proj_fn(x).view(*x.shape[:-1], self.num_heads, -1)
+ return apply_rope1(norm_fn(result), freqs)
+
+ def _forward(self, x, freqs, transformer_options={}):
+ q = self._compute_qk(x, freqs, self.to_query, self.query_norm)
+ k = self._compute_qk(x, freqs, self.to_key, self.key_norm)
+ v = self.to_value(x).view(*x.shape[:-1], self.num_heads, -1)
+ out = attention(q, k, v, self.num_heads, transformer_options=transformer_options)
+ return self.out_layer(out)
+
+ def _forward_chunked(self, x, freqs, transformer_options={}):
+ def process_chunks(proj_fn, norm_fn):
+ x_chunks = torch.chunk(x, self.num_chunks, dim=1)
+ freqs_chunks = torch.chunk(freqs, self.num_chunks, dim=1)
+ chunks = []
+ for x_chunk, freqs_chunk in zip(x_chunks, freqs_chunks):
+ chunks.append(self._compute_qk(x_chunk, freqs_chunk, proj_fn, norm_fn))
+ return torch.cat(chunks, dim=1)
+
+ q = process_chunks(self.to_query, self.query_norm)
+ k = process_chunks(self.to_key, self.key_norm)
+ v = self.to_value(x).view(*x.shape[:-1], self.num_heads, -1)
+ out = attention(q, k, v, self.num_heads, transformer_options=transformer_options)
+ return self.out_layer(out)
+
+ def forward(self, x, freqs, transformer_options={}):
+ if x.shape[1] > 8192:
+ return self._forward_chunked(x, freqs, transformer_options=transformer_options)
+ else:
+ return self._forward(x, freqs, transformer_options=transformer_options)
+
+
+class CrossAttention(SelfAttention):
+ def get_qkv(self, x, context):
+ q = self.to_query(x).view(*x.shape[:-1], self.num_heads, -1)
+ k = self.to_key(context).view(*context.shape[:-1], self.num_heads, -1)
+ v = self.to_value(context).view(*context.shape[:-1], self.num_heads, -1)
+ return q, k, v
+
+ def forward(self, x, context, transformer_options={}):
+ q, k, v = self.get_qkv(x, context)
+ out = attention(self.query_norm(q), self.key_norm(k), v, self.num_heads, transformer_options=transformer_options)
+ return self.out_layer(out)
+
+
+class FeedForward(nn.Module):
+ def __init__(self, dim, ff_dim, operation_settings=None):
+ super().__init__()
+ operations = operation_settings.get("operations")
+ self.in_layer = operations.Linear(dim, ff_dim, bias=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+ self.activation = nn.GELU()
+ self.out_layer = operations.Linear(ff_dim, dim, bias=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+ self.num_chunks = 4
+
+ def _forward(self, x):
+ return self.out_layer(self.activation(self.in_layer(x)))
+
+ def _forward_chunked(self, x):
+ chunks = torch.chunk(x, self.num_chunks, dim=1)
+ output_chunks = []
+ for chunk in chunks:
+ output_chunks.append(self._forward(chunk))
+ return torch.cat(output_chunks, dim=1)
+
+ def forward(self, x):
+ if x.shape[1] > 8192:
+ return self._forward_chunked(x)
+ else:
+ return self._forward(x)
+
+
+class OutLayer(nn.Module):
+ def __init__(self, model_dim, time_dim, visual_dim, patch_size, operation_settings=None):
+ super().__init__()
+ self.patch_size = patch_size
+ self.modulation = Modulation(time_dim, model_dim, 2, operation_settings=operation_settings)
+ operations = operation_settings.get("operations")
+ self.norm = operations.LayerNorm(model_dim, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+ self.out_layer = operations.Linear(model_dim, math.prod(patch_size) * visual_dim, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+
+ def forward(self, visual_embed, time_embed):
+ B, T, H, W, _ = visual_embed.shape
+ shift, scale = torch.chunk(self.modulation(time_embed), 2, dim=-1)
+ scale = scale[:, None, None, None, :]
+ shift = shift[:, None, None, None, :]
+ visual_embed = apply_scale_shift_norm(self.norm, visual_embed, scale, shift)
+ x = self.out_layer(visual_embed)
+
+ out_dim = x.shape[-1] // (self.patch_size[0] * self.patch_size[1] * self.patch_size[2])
+ x = x.view(
+ B, T, H, W,
+ out_dim,
+ self.patch_size[0], self.patch_size[1], self.patch_size[2]
+ )
+ return x.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(2, 3).flatten(3, 4).flatten(4, 5)
+
+
+class TransformerEncoderBlock(nn.Module):
+ def __init__(self, model_dim, time_dim, ff_dim, head_dim, operation_settings=None):
+ super().__init__()
+ self.text_modulation = Modulation(time_dim, model_dim, 6, operation_settings=operation_settings)
+ operations = operation_settings.get("operations")
+
+ self.self_attention_norm = operations.LayerNorm(model_dim, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+ self.self_attention = SelfAttention(model_dim, head_dim, operation_settings=operation_settings)
+
+ self.feed_forward_norm = operations.LayerNorm(model_dim, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+ self.feed_forward = FeedForward(model_dim, ff_dim, operation_settings=operation_settings)
+
+ def forward(self, x, time_embed, freqs, transformer_options={}):
+ self_attn_params, ff_params = torch.chunk(self.text_modulation(time_embed), 2, dim=-1)
+ shift, scale, gate = get_shift_scale_gate(self_attn_params)
+ out = apply_scale_shift_norm(self.self_attention_norm, x, scale, shift)
+ out = self.self_attention(out, freqs, transformer_options=transformer_options)
+ x = apply_gate_sum(x, out, gate)
+
+ shift, scale, gate = get_shift_scale_gate(ff_params)
+ out = apply_scale_shift_norm(self.feed_forward_norm, x, scale, shift)
+ out = self.feed_forward(out)
+ x = apply_gate_sum(x, out, gate)
+ return x
+
+
+class TransformerDecoderBlock(nn.Module):
+ def __init__(self, model_dim, time_dim, ff_dim, head_dim, operation_settings=None):
+ super().__init__()
+ self.visual_modulation = Modulation(time_dim, model_dim, 9, operation_settings=operation_settings)
+
+ operations = operation_settings.get("operations")
+ self.self_attention_norm = operations.LayerNorm(model_dim, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+ self.self_attention = SelfAttention(model_dim, head_dim, operation_settings=operation_settings)
+
+ self.cross_attention_norm = operations.LayerNorm(model_dim, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+ self.cross_attention = CrossAttention(model_dim, head_dim, operation_settings=operation_settings)
+
+ self.feed_forward_norm = operations.LayerNorm(model_dim, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+ self.feed_forward = FeedForward(model_dim, ff_dim, operation_settings=operation_settings)
+
+ def forward(self, visual_embed, text_embed, time_embed, freqs, transformer_options={}):
+ self_attn_params, cross_attn_params, ff_params = torch.chunk(self.visual_modulation(time_embed), 3, dim=-1)
+ # self attention
+ shift, scale, gate = get_shift_scale_gate(self_attn_params)
+ visual_out = apply_scale_shift_norm(self.self_attention_norm, visual_embed, scale, shift)
+ visual_out = self.self_attention(visual_out, freqs, transformer_options=transformer_options)
+ visual_embed = apply_gate_sum(visual_embed, visual_out, gate)
+ # cross attention
+ shift, scale, gate = get_shift_scale_gate(cross_attn_params)
+ visual_out = apply_scale_shift_norm(self.cross_attention_norm, visual_embed, scale, shift)
+ visual_out = self.cross_attention(visual_out, text_embed, transformer_options=transformer_options)
+ visual_embed = apply_gate_sum(visual_embed, visual_out, gate)
+ # feed forward
+ shift, scale, gate = get_shift_scale_gate(ff_params)
+ visual_out = apply_scale_shift_norm(self.feed_forward_norm, visual_embed, scale, shift)
+ visual_out = self.feed_forward(visual_out)
+ visual_embed = apply_gate_sum(visual_embed, visual_out, gate)
+ return visual_embed
+
+
+class Kandinsky5(nn.Module):
+ def __init__(
+ self,
+ in_visual_dim=16, out_visual_dim=16, in_text_dim=3584, in_text_dim2=768, time_dim=512,
+ model_dim=1792, ff_dim=7168, visual_embed_dim=132, patch_size=(1, 2, 2), num_text_blocks=2, num_visual_blocks=32,
+ axes_dims=(16, 24, 24), rope_scale_factor=(1.0, 2.0, 2.0),
+ dtype=None, device=None, operations=None, **kwargs
+ ):
+ super().__init__()
+ head_dim = sum(axes_dims)
+ self.rope_scale_factor = rope_scale_factor
+ self.in_visual_dim = in_visual_dim
+ self.model_dim = model_dim
+ self.patch_size = patch_size
+ self.visual_embed_dim = visual_embed_dim
+ self.dtype = dtype
+ self.device = device
+ operation_settings = {"operations": operations, "device": device, "dtype": dtype}
+
+ self.time_embeddings = TimeEmbeddings(model_dim, time_dim, operation_settings=operation_settings)
+ self.text_embeddings = TextEmbeddings(in_text_dim, model_dim, operation_settings=operation_settings)
+ self.pooled_text_embeddings = TextEmbeddings(in_text_dim2, time_dim, operation_settings=operation_settings)
+ self.visual_embeddings = VisualEmbeddings(visual_embed_dim, model_dim, patch_size, operation_settings=operation_settings)
+
+ self.text_transformer_blocks = nn.ModuleList(
+ [TransformerEncoderBlock(model_dim, time_dim, ff_dim, head_dim, operation_settings=operation_settings) for _ in range(num_text_blocks)]
+ )
+
+ self.visual_transformer_blocks = nn.ModuleList(
+ [TransformerDecoderBlock(model_dim, time_dim, ff_dim, head_dim, operation_settings=operation_settings) for _ in range(num_visual_blocks)]
+ )
+
+ self.out_layer = OutLayer(model_dim, time_dim, out_visual_dim, patch_size, operation_settings=operation_settings)
+
+ self.rope_embedder_3d = EmbedND(dim=head_dim, theta=10000.0, axes_dim=axes_dims)
+ self.rope_embedder_1d = EmbedND(dim=head_dim, theta=10000.0, axes_dim=[head_dim])
+
+ def rope_encode_1d(self, seq_len, seq_start=0, steps=None, device=None, dtype=None, transformer_options={}):
+ steps = seq_len if steps is None else steps
+ seq_ids = torch.linspace(seq_start, seq_start + (seq_len - 1), steps=steps, device=device, dtype=dtype)
+ seq_ids = seq_ids.reshape(-1, 1).unsqueeze(0) # Shape: (1, steps, 1)
+ freqs = self.rope_embedder_1d(seq_ids).movedim(1, 2)
+ return freqs
+
+ def rope_encode_3d(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None, transformer_options={}):
+
+ patch_size = self.patch_size
+ t_len = ((t + (patch_size[0] // 2)) // patch_size[0])
+ h_len = ((h + (patch_size[1] // 2)) // patch_size[1])
+ w_len = ((w + (patch_size[2] // 2)) // patch_size[2])
+
+ if steps_t is None:
+ steps_t = t_len
+ if steps_h is None:
+ steps_h = h_len
+ if steps_w is None:
+ steps_w = w_len
+
+ h_start = 0
+ w_start = 0
+ rope_options = transformer_options.get("rope_options", None)
+ if rope_options is not None:
+ t_len = (t_len - 1.0) * rope_options.get("scale_t", 1.0) + 1.0
+ h_len = (h_len - 1.0) * rope_options.get("scale_y", 1.0) + 1.0
+ w_len = (w_len - 1.0) * rope_options.get("scale_x", 1.0) + 1.0
+
+ t_start += rope_options.get("shift_t", 0.0)
+ h_start += rope_options.get("shift_y", 0.0)
+ w_start += rope_options.get("shift_x", 0.0)
+ else:
+ rope_scale_factor = self.rope_scale_factor
+ if self.model_dim == 4096: # pro video model uses different rope scaling at higher resolutions
+ if h * w >= 14080:
+ rope_scale_factor = (1.0, 3.16, 3.16)
+
+ t_len = (t_len - 1.0) / rope_scale_factor[0] + 1.0
+ h_len = (h_len - 1.0) / rope_scale_factor[1] + 1.0
+ w_len = (w_len - 1.0) / rope_scale_factor[2] + 1.0
+
+ img_ids = torch.zeros((steps_t, steps_h, steps_w, 3), device=device, dtype=dtype)
+ img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(t_start, t_start + (t_len - 1), steps=steps_t, device=device, dtype=dtype).reshape(-1, 1, 1)
+ img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(h_start, h_start + (h_len - 1), steps=steps_h, device=device, dtype=dtype).reshape(1, -1, 1)
+ img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(w_start, w_start + (w_len - 1), steps=steps_w, device=device, dtype=dtype).reshape(1, 1, -1)
+ img_ids = img_ids.reshape(1, -1, img_ids.shape[-1])
+
+ freqs = self.rope_embedder_3d(img_ids).movedim(1, 2)
+ return freqs
+
+ def forward_orig(self, x, timestep, context, y, freqs, freqs_text, transformer_options={}, **kwargs):
+ patches_replace = transformer_options.get("patches_replace", {})
+ context = self.text_embeddings(context)
+ time_embed = self.time_embeddings(timestep, x.dtype) + self.pooled_text_embeddings(y)
+
+ for block in self.text_transformer_blocks:
+ context = block(context, time_embed, freqs_text, transformer_options=transformer_options)
+
+ visual_embed = self.visual_embeddings(x)
+ visual_shape = visual_embed.shape[:-1]
+ visual_embed = visual_embed.flatten(1, -2)
+
+ blocks_replace = patches_replace.get("dit", {})
+ transformer_options["total_blocks"] = len(self.visual_transformer_blocks)
+ transformer_options["block_type"] = "double"
+ for i, block in enumerate(self.visual_transformer_blocks):
+ transformer_options["block_index"] = i
+ if ("double_block", i) in blocks_replace:
+ def block_wrap(args):
+ return block(x=args["x"], context=args["context"], time_embed=args["time_embed"], freqs=args["freqs"], transformer_options=args.get("transformer_options"))
+ visual_embed = blocks_replace[("double_block", i)]({"x": visual_embed, "context": context, "time_embed": time_embed, "freqs": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap})["x"]
+ else:
+ visual_embed = block(visual_embed, context, time_embed, freqs=freqs, transformer_options=transformer_options)
+
+ visual_embed = visual_embed.reshape(*visual_shape, -1)
+ return self.out_layer(visual_embed, time_embed)
+
+ def _forward(self, x, timestep, context, y, time_dim_replace=None, transformer_options={}, **kwargs):
+ bs, c, t_len, h, w = x.shape
+ x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)
+
+ if time_dim_replace is not None:
+ time_dim_replace = comfy.ldm.common_dit.pad_to_patch_size(time_dim_replace, self.patch_size)
+ x[:, :time_dim_replace.shape[1], :time_dim_replace.shape[2]] = time_dim_replace
+
+ freqs = self.rope_encode_3d(t_len, h, w, device=x.device, dtype=x.dtype, transformer_options=transformer_options)
+ freqs_text = self.rope_encode_1d(context.shape[1], device=x.device, dtype=x.dtype, transformer_options=transformer_options)
+
+ return self.forward_orig(x, timestep, context, y, freqs, freqs_text, transformer_options=transformer_options, **kwargs)
+
+ def forward(self, x, timestep, context, y, time_dim_replace=None, transformer_options={}, **kwargs):
+ return comfy.patcher_extension.WrapperExecutor.new_class_executor(
+ self._forward,
+ self,
+ comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
+ ).execute(x, timestep, context, y, time_dim_replace=time_dim_replace, transformer_options=transformer_options, **kwargs)
diff --git a/comfy/lora.py b/comfy/lora.py
index 3a9077869..e7202ce97 100644
--- a/comfy/lora.py
+++ b/comfy/lora.py
@@ -322,6 +322,13 @@ def model_lora_keys_unet(model, key_map={}):
key_map["diffusion_model.{}".format(key_lora)] = to
key_map["lycoris_{}".format(key_lora.replace(".", "_"))] = to
+ if isinstance(model, comfy.model_base.Kandinsky5):
+ for k in sdk:
+ if k.startswith("diffusion_model.") and k.endswith(".weight"):
+ key_lora = k[len("diffusion_model."):-len(".weight")]
+ key_map["{}".format(key_lora)] = k
+ key_map["transformer.{}".format(key_lora)] = k
+
return key_map
diff --git a/comfy/model_base.py b/comfy/model_base.py
index 3cedd4f31..0be006cc2 100644
--- a/comfy/model_base.py
+++ b/comfy/model_base.py
@@ -47,6 +47,7 @@ import comfy.ldm.chroma_radiance.model
import comfy.ldm.ace.model
import comfy.ldm.omnigen.omnigen2
import comfy.ldm.qwen_image.model
+import comfy.ldm.kandinsky5.model
import comfy.model_management
import comfy.patcher_extension
@@ -1630,3 +1631,49 @@ class HunyuanVideo15_SR_Distilled(HunyuanVideo15):
out = super().extra_conds(**kwargs)
out['disable_time_r'] = comfy.conds.CONDConstant(False)
return out
+
+class Kandinsky5(BaseModel):
+ def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
+ super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.kandinsky5.model.Kandinsky5)
+
+ def encode_adm(self, **kwargs):
+ return kwargs["pooled_output"]
+
+ def concat_cond(self, **kwargs):
+ noise = kwargs.get("noise", None)
+ device = kwargs["device"]
+ image = torch.zeros_like(noise)
+
+ mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
+ if mask is None:
+ mask = torch.zeros_like(noise)[:, :1]
+ else:
+ mask = 1.0 - mask
+ mask = utils.common_upscale(mask.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
+ if mask.shape[-3] < noise.shape[-3]:
+ mask = torch.nn.functional.pad(mask, (0, 0, 0, 0, 0, noise.shape[-3] - mask.shape[-3]), mode='constant', value=0)
+ mask = utils.resize_to_batch_size(mask, noise.shape[0])
+
+ return torch.cat((image, mask), dim=1)
+
+ def extra_conds(self, **kwargs):
+ out = super().extra_conds(**kwargs)
+ attention_mask = kwargs.get("attention_mask", None)
+ if attention_mask is not None:
+ out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
+ cross_attn = kwargs.get("cross_attn", None)
+ if cross_attn is not None:
+ out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
+
+ time_dim_replace = kwargs.get("time_dim_replace", None)
+ if time_dim_replace is not None:
+ out['time_dim_replace'] = comfy.conds.CONDRegular(self.process_latent_in(time_dim_replace))
+
+ return out
+
+class Kandinsky5Image(Kandinsky5):
+ def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
+ super().__init__(model_config, model_type, device=device)
+
+ def concat_cond(self, **kwargs):
+ return None
diff --git a/comfy/model_detection.py b/comfy/model_detection.py
index fd1907627..30b33a486 100644
--- a/comfy/model_detection.py
+++ b/comfy/model_detection.py
@@ -611,6 +611,24 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["num_layers"] = count_blocks(state_dict_keys, '{}transformer_blocks.'.format(key_prefix) + '{}.')
return dit_config
+ if '{}visual_transformer_blocks.0.cross_attention.key_norm.weight'.format(key_prefix) in state_dict_keys: # Kandinsky 5
+ dit_config = {}
+ model_dim = state_dict['{}visual_embeddings.in_layer.bias'.format(key_prefix)].shape[0]
+ dit_config["model_dim"] = model_dim
+ if model_dim in [4096, 2560]: # pro video and lite image
+ dit_config["axes_dims"] = (32, 48, 48)
+ if model_dim == 2560: # lite image
+ dit_config["rope_scale_factor"] = (1.0, 1.0, 1.0)
+ elif model_dim == 1792: # lite video
+ dit_config["axes_dims"] = (16, 24, 24)
+ dit_config["time_dim"] = state_dict['{}time_embeddings.in_layer.bias'.format(key_prefix)].shape[0]
+ dit_config["image_model"] = "kandinsky5"
+ dit_config["ff_dim"] = state_dict['{}visual_transformer_blocks.0.feed_forward.in_layer.weight'.format(key_prefix)].shape[0]
+ dit_config["visual_embed_dim"] = state_dict['{}visual_embeddings.in_layer.weight'.format(key_prefix)].shape[1]
+ dit_config["num_text_blocks"] = count_blocks(state_dict_keys, '{}text_transformer_blocks.'.format(key_prefix) + '{}.')
+ dit_config["num_visual_blocks"] = count_blocks(state_dict_keys, '{}visual_transformer_blocks.'.format(key_prefix) + '{}.')
+ return dit_config
+
if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys:
return None
diff --git a/comfy/sd.py b/comfy/sd.py
index c350322f8..754b1703d 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -54,6 +54,7 @@ import comfy.text_encoders.qwen_image
import comfy.text_encoders.hunyuan_image
import comfy.text_encoders.z_image
import comfy.text_encoders.ovis
+import comfy.text_encoders.kandinsky5
import comfy.model_patcher
import comfy.lora
@@ -766,6 +767,8 @@ class VAE:
self.throw_exception_if_invalid()
pixel_samples = None
do_tile = False
+ if self.latent_dim == 2 and samples_in.ndim == 5:
+ samples_in = samples_in[:, :, 0]
try:
memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
@@ -983,6 +986,8 @@ class CLIPType(Enum):
HUNYUAN_IMAGE = 19
HUNYUAN_VIDEO_15 = 20
OVIS = 21
+ KANDINSKY5 = 22
+ KANDINSKY5_IMAGE = 23
def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
@@ -1231,6 +1236,12 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
elif clip_type == CLIPType.HUNYUAN_VIDEO_15:
clip_target.clip = comfy.text_encoders.hunyuan_image.te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.hunyuan_video.HunyuanVideo15Tokenizer
+ elif clip_type == CLIPType.KANDINSKY5:
+ clip_target.clip = comfy.text_encoders.kandinsky5.te(**llama_detect(clip_data))
+ clip_target.tokenizer = comfy.text_encoders.kandinsky5.Kandinsky5Tokenizer
+ elif clip_type == CLIPType.KANDINSKY5_IMAGE:
+ clip_target.clip = comfy.text_encoders.kandinsky5.te(**llama_detect(clip_data))
+ clip_target.tokenizer = comfy.text_encoders.kandinsky5.Kandinsky5TokenizerImage
else:
clip_target.clip = sdxl_clip.SDXLClipModel
clip_target.tokenizer = sdxl_clip.SDXLTokenizer
diff --git a/comfy/supported_models.py b/comfy/supported_models.py
index afd97160b..91cc4ef08 100644
--- a/comfy/supported_models.py
+++ b/comfy/supported_models.py
@@ -21,6 +21,7 @@ import comfy.text_encoders.ace
import comfy.text_encoders.omnigen2
import comfy.text_encoders.qwen_image
import comfy.text_encoders.hunyuan_image
+import comfy.text_encoders.kandinsky5
import comfy.text_encoders.z_image
from . import supported_models_base
@@ -1474,7 +1475,60 @@ class HunyuanVideo15_SR_Distilled(HunyuanVideo):
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.hunyuan_video.HunyuanVideo15Tokenizer, comfy.text_encoders.hunyuan_image.te(**hunyuan_detect))
-models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage, Flux2]
+class Kandinsky5(supported_models_base.BASE):
+ unet_config = {
+ "image_model": "kandinsky5",
+ }
+
+ sampling_settings = {
+ "shift": 10.0,
+ }
+
+ unet_extra_config = {}
+ latent_format = latent_formats.HunyuanVideo
+
+ memory_usage_factor = 1.1 #TODO
+
+ supported_inference_dtypes = [torch.bfloat16, torch.float32]
+
+ vae_key_prefix = ["vae."]
+ text_encoder_key_prefix = ["text_encoders."]
+
+ def get_model(self, state_dict, prefix="", device=None):
+ out = model_base.Kandinsky5(self, device=device)
+ return out
+
+ def clip_target(self, state_dict={}):
+ pref = self.text_encoder_key_prefix[0]
+ hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref))
+ return supported_models_base.ClipTarget(comfy.text_encoders.kandinsky5.Kandinsky5Tokenizer, comfy.text_encoders.kandinsky5.te(**hunyuan_detect))
+
+
+class Kandinsky5Image(Kandinsky5):
+ unet_config = {
+ "image_model": "kandinsky5",
+ "model_dim": 2560,
+ "visual_embed_dim": 64,
+ }
+
+ sampling_settings = {
+ "shift": 3.0,
+ }
+
+ latent_format = latent_formats.Flux
+ memory_usage_factor = 1.1 #TODO
+
+ def get_model(self, state_dict, prefix="", device=None):
+ out = model_base.Kandinsky5Image(self, device=device)
+ return out
+
+ def clip_target(self, state_dict={}):
+ pref = self.text_encoder_key_prefix[0]
+ hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref))
+ return supported_models_base.ClipTarget(comfy.text_encoders.kandinsky5.Kandinsky5TokenizerImage, comfy.text_encoders.kandinsky5.te(**hunyuan_detect))
+
+
+models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage, Kandinsky5Image, Kandinsky5]
models += [SVD_img2vid]
diff --git a/comfy/text_encoders/kandinsky5.py b/comfy/text_encoders/kandinsky5.py
new file mode 100644
index 000000000..22f991c36
--- /dev/null
+++ b/comfy/text_encoders/kandinsky5.py
@@ -0,0 +1,68 @@
+from comfy import sd1_clip
+from .qwen_image import QwenImageTokenizer, QwenImageTEModel
+from .llama import Qwen25_7BVLI
+
+
+class Kandinsky5Tokenizer(QwenImageTokenizer):
+ def __init__(self, embedding_directory=None, tokenizer_data={}):
+ super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
+ self.llama_template = "<|im_start|>system\nYou are a prompt engineer. Describe the video in detail.\nDescribe how the camera moves or shakes, describe the zoom and view angle, whether it follows the objects.\nDescribe the location of the video, main characters or objects and their action.\nDescribe the dynamism of the video and presented actions.\nName the visual style of the video: whether it is a professional footage, user generated content, some kind of animation, video game or screen content.\nDescribe the visual effects, postprocessing and transitions if they are presented in the video.\nPay attention to the order of key actions shown in the scene.<|im_end|>\n<|im_start|>user\n{}<|im_end|>"
+ self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
+
+ def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
+ out = super().tokenize_with_weights(text, return_word_ids, **kwargs)
+ out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids, **kwargs)
+
+ return out
+
+
+class Kandinsky5TokenizerImage(Kandinsky5Tokenizer):
+ def __init__(self, embedding_directory=None, tokenizer_data={}):
+ super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
+ self.llama_template = "<|im_start|>system\nYou are a promt engineer. Describe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>"
+
+
+class Qwen25_7BVLIModel(sd1_clip.SDClipModel):
+ def __init__(self, device="cpu", layer="hidden", layer_idx=-1, dtype=None, attention_mask=True, model_options={}):
+ llama_scaled_fp8 = model_options.get("qwen_scaled_fp8", None)
+ if llama_scaled_fp8 is not None:
+ model_options = model_options.copy()
+ model_options["scaled_fp8"] = llama_scaled_fp8
+ super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=Qwen25_7BVLI, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
+
+
+class Kandinsky5TEModel(QwenImageTEModel):
+ def __init__(self, device="cpu", dtype=None, model_options={}):
+ super(QwenImageTEModel, self).__init__(device=device, dtype=dtype, name="qwen25_7b", clip_model=Qwen25_7BVLIModel, model_options=model_options)
+ self.clip_l = sd1_clip.SDClipModel(device=device, dtype=dtype, return_projected_pooled=False, model_options=model_options)
+
+ def encode_token_weights(self, token_weight_pairs):
+ cond, p, extra = super().encode_token_weights(token_weight_pairs, template_end=-1)
+ l_out, l_pooled = self.clip_l.encode_token_weights(token_weight_pairs["l"])
+
+ return cond, l_pooled, extra
+
+ def set_clip_options(self, options):
+ super().set_clip_options(options)
+ self.clip_l.set_clip_options(options)
+
+ def reset_clip_options(self):
+ super().reset_clip_options()
+ self.clip_l.reset_clip_options()
+
+ def load_sd(self, sd):
+ if "text_model.encoder.layers.1.mlp.fc1.weight" in sd:
+ return self.clip_l.load_sd(sd)
+ else:
+ return super().load_sd(sd)
+
+def te(dtype_llama=None, llama_scaled_fp8=None):
+ class Kandinsky5TEModel_(Kandinsky5TEModel):
+ def __init__(self, device="cpu", dtype=None, model_options={}):
+ if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options:
+ model_options = model_options.copy()
+ model_options["qwen_scaled_fp8"] = llama_scaled_fp8
+ if dtype_llama is not None:
+ dtype = dtype_llama
+ super().__init__(device=device, dtype=dtype, model_options=model_options)
+ return Kandinsky5TEModel_
diff --git a/comfy_api/latest/_io.py b/comfy_api/latest/_io.py
index 866c3e0eb..d7cbe68cf 100644
--- a/comfy_api/latest/_io.py
+++ b/comfy_api/latest/_io.py
@@ -568,6 +568,8 @@ class Conditioning(ComfyTypeIO):
'''Used by WAN Camera.'''
time_dim_concat: NotRequired[torch.Tensor]
'''Used by WAN Phantom Subject.'''
+ time_dim_replace: NotRequired[torch.Tensor]
+ '''Used by Kandinsky5 I2V.'''
CondList = list[tuple[torch.Tensor, PooledDict]]
Type = CondList
diff --git a/comfy_extras/nodes_kandinsky5.py b/comfy_extras/nodes_kandinsky5.py
new file mode 100644
index 000000000..9cb234be1
--- /dev/null
+++ b/comfy_extras/nodes_kandinsky5.py
@@ -0,0 +1,136 @@
+import nodes
+import node_helpers
+import torch
+import comfy.model_management
+import comfy.utils
+
+from typing_extensions import override
+from comfy_api.latest import ComfyExtension, io
+
+
+class Kandinsky5ImageToVideo(io.ComfyNode):
+ @classmethod
+ def define_schema(cls):
+ return io.Schema(
+ node_id="Kandinsky5ImageToVideo",
+ category="conditioning/video_models",
+ inputs=[
+ io.Conditioning.Input("positive"),
+ io.Conditioning.Input("negative"),
+ io.Vae.Input("vae"),
+ io.Int.Input("width", default=768, min=16, max=nodes.MAX_RESOLUTION, step=16),
+ io.Int.Input("height", default=512, min=16, max=nodes.MAX_RESOLUTION, step=16),
+ io.Int.Input("length", default=121, min=1, max=nodes.MAX_RESOLUTION, step=4),
+ io.Int.Input("batch_size", default=1, min=1, max=4096),
+ io.Image.Input("start_image", optional=True),
+ ],
+ outputs=[
+ io.Conditioning.Output(display_name="positive"),
+ io.Conditioning.Output(display_name="negative"),
+ io.Latent.Output(display_name="latent", tooltip="Empty video latent"),
+ io.Latent.Output(display_name="cond_latent", tooltip="Clean encoded start images, used to replace the noisy start of the model output latents"),
+ ],
+ )
+
+ @classmethod
+ def execute(cls, positive, negative, vae, width, height, length, batch_size, start_image=None) -> io.NodeOutput:
+ latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
+ cond_latent_out = {}
+ if start_image is not None:
+ start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
+ encoded = vae.encode(start_image[:, :, :, :3])
+ cond_latent_out["samples"] = encoded
+
+ mask = torch.ones((1, 1, latent.shape[2], latent.shape[-2], latent.shape[-1]), device=start_image.device, dtype=start_image.dtype)
+ mask[:, :, :((start_image.shape[0] - 1) // 4) + 1] = 0.0
+
+ positive = node_helpers.conditioning_set_values(positive, {"time_dim_replace": encoded, "concat_mask": mask})
+ negative = node_helpers.conditioning_set_values(negative, {"time_dim_replace": encoded, "concat_mask": mask})
+
+ out_latent = {}
+ out_latent["samples"] = latent
+ return io.NodeOutput(positive, negative, out_latent, cond_latent_out)
+
+
+def adaptive_mean_std_normalization(source, reference, clump_mean_low=0.3, clump_mean_high=0.35, clump_std_low=0.35, clump_std_high=0.5):
+ source_mean = source.mean(dim=(1, 3, 4), keepdim=True) # mean over C, H, W
+ source_std = source.std(dim=(1, 3, 4), keepdim=True) # std over C, H, W
+
+ reference_mean = torch.clamp(reference.mean(), source_mean - clump_mean_low, source_mean + clump_mean_high)
+ reference_std = torch.clamp(reference.std(), source_std - clump_std_low, source_std + clump_std_high)
+
+ # normalization
+ normalized = (source - source_mean) / (source_std + 1e-8)
+ normalized = normalized * reference_std + reference_mean
+
+ return normalized
+
+
+class NormalizeVideoLatentStart(io.ComfyNode):
+ @classmethod
+ def define_schema(cls):
+ return io.Schema(
+ node_id="NormalizeVideoLatentStart",
+ category="conditioning/video_models",
+ description="Normalizes the initial frames of a video latent to match the mean and standard deviation of subsequent reference frames. Helps reduce differences between the starting frames and the rest of the video.",
+ inputs=[
+ io.Latent.Input("latent"),
+ io.Int.Input("start_frame_count", default=4, min=1, max=nodes.MAX_RESOLUTION, step=1, tooltip="Number of latent frames to normalize, counted from the start"),
+ io.Int.Input("reference_frame_count", default=5, min=1, max=nodes.MAX_RESOLUTION, step=1, tooltip="Number of latent frames after the start frames to use as reference"),
+ ],
+ outputs=[
+ io.Latent.Output(display_name="latent"),
+ ],
+ )
+
+ @classmethod
+ def execute(cls, latent, start_frame_count, reference_frame_count) -> io.NodeOutput:
+ if latent["samples"].shape[2] <= 1:
+ return io.NodeOutput(latent)
+ s = latent.copy()
+ samples = latent["samples"].clone()
+
+ first_frames = samples[:, :, :start_frame_count]
+ reference_frames_data = samples[:, :, start_frame_count:start_frame_count+min(reference_frame_count, samples.shape[2]-1)]
+ normalized_first_frames = adaptive_mean_std_normalization(first_frames, reference_frames_data)
+
+ samples[:, :, :start_frame_count] = normalized_first_frames
+ s["samples"] = samples
+ return io.NodeOutput(s)
+
+
+class CLIPTextEncodeKandinsky5(io.ComfyNode):
+ @classmethod
+ def define_schema(cls):
+ return io.Schema(
+ node_id="CLIPTextEncodeKandinsky5",
+ category="advanced/conditioning/kandinsky5",
+ inputs=[
+ io.Clip.Input("clip"),
+ io.String.Input("clip_l", multiline=True, dynamic_prompts=True),
+ io.String.Input("qwen25_7b", multiline=True, dynamic_prompts=True),
+ ],
+ outputs=[
+ io.Conditioning.Output(),
+ ],
+ )
+
+ @classmethod
+ def execute(cls, clip, clip_l, qwen25_7b) -> io.NodeOutput:
+ tokens = clip.tokenize(clip_l)
+ tokens["qwen25_7b"] = clip.tokenize(qwen25_7b)["qwen25_7b"]
+
+ return io.NodeOutput(clip.encode_from_tokens_scheduled(tokens))
+
+
+class Kandinsky5Extension(ComfyExtension):
+ @override
+ async def get_node_list(self) -> list[type[io.ComfyNode]]:
+ return [
+ Kandinsky5ImageToVideo,
+ NormalizeVideoLatentStart,
+ CLIPTextEncodeKandinsky5,
+ ]
+
+async def comfy_entrypoint() -> Kandinsky5Extension:
+ return Kandinsky5Extension()
diff --git a/comfy_extras/nodes_latent.py b/comfy_extras/nodes_latent.py
index d2df07ff9..e439b18ef 100644
--- a/comfy_extras/nodes_latent.py
+++ b/comfy_extras/nodes_latent.py
@@ -4,7 +4,7 @@ import torch
import nodes
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
-
+import logging
def reshape_latent_to(target_shape, latent, repeat_batch=True):
if latent.shape[1:] != target_shape[1:]:
@@ -388,6 +388,42 @@ class LatentOperationSharpen(io.ComfyNode):
return luminance * sharpened
return io.NodeOutput(sharpen)
+class ReplaceVideoLatentFrames(io.ComfyNode):
+ @classmethod
+ def define_schema(cls):
+ return io.Schema(
+ node_id="ReplaceVideoLatentFrames",
+ category="latent/batch",
+ inputs=[
+ io.Latent.Input("destination", tooltip="The destination latent where frames will be replaced."),
+ io.Latent.Input("source", optional=True, tooltip="The source latent providing frames to insert into the destination latent. If not provided, the destination latent is returned unchanged."),
+ io.Int.Input("index", default=0, min=-nodes.MAX_RESOLUTION, max=nodes.MAX_RESOLUTION, step=1, tooltip="The starting latent frame index in the destination latent where the source latent frames will be placed. Negative values count from the end."),
+ ],
+ outputs=[
+ io.Latent.Output(),
+ ],
+ )
+
+ @classmethod
+ def execute(cls, destination, index, source=None) -> io.NodeOutput:
+ if source is None:
+ return io.NodeOutput(destination)
+ dest_frames = destination["samples"].shape[2]
+ source_frames = source["samples"].shape[2]
+ if index < 0:
+ index = dest_frames + index
+ if index > dest_frames:
+ logging.warning(f"ReplaceVideoLatentFrames: Index {index} is out of bounds for destination latent frames {dest_frames}.")
+ return io.NodeOutput(destination)
+ if index + source_frames > dest_frames:
+ logging.warning(f"ReplaceVideoLatentFrames: Source latent frames {source_frames} do not fit within destination latent frames {dest_frames} at the specified index {index}.")
+ return io.NodeOutput(destination)
+ s = source.copy()
+ s_source = source["samples"]
+ s_destination = destination["samples"].clone()
+ s_destination[:, :, index:index + s_source.shape[2]] = s_source
+ s["samples"] = s_destination
+ return io.NodeOutput(s)
class LatentExtension(ComfyExtension):
@override
@@ -405,6 +441,7 @@ class LatentExtension(ComfyExtension):
LatentApplyOperationCFG,
LatentOperationTonemapReinhard,
LatentOperationSharpen,
+ ReplaceVideoLatentFrames
]
diff --git a/nodes.py b/nodes.py
index 356aa63df..8d28a725d 100644
--- a/nodes.py
+++ b/nodes.py
@@ -970,7 +970,7 @@ class DualCLIPLoader:
def INPUT_TYPES(s):
return {"required": { "clip_name1": (folder_paths.get_filename_list("text_encoders"), ),
"clip_name2": (folder_paths.get_filename_list("text_encoders"), ),
- "type": (["sdxl", "sd3", "flux", "hunyuan_video", "hidream", "hunyuan_image", "hunyuan_video_15"], ),
+ "type": (["sdxl", "sd3", "flux", "hunyuan_video", "hidream", "hunyuan_image", "hunyuan_video_15", "kandinsky5", "kandinsky5_image"], ),
},
"optional": {
"device": (["default", "cpu"], {"advanced": True}),
@@ -2357,6 +2357,7 @@ async def init_builtin_extra_nodes():
"nodes_rope.py",
"nodes_logic.py",
"nodes_nop.py",
+ "nodes_kandinsky5.py",
]
import_failed = []
From ae676ed105663bb225153c8bca406f00edf738b4 Mon Sep 17 00:00:00 2001
From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com>
Date: Fri, 5 Dec 2025 20:01:19 -0800
Subject: [PATCH 36/81] Fix regression. (#11137)
---
comfy/supported_models.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/comfy/supported_models.py b/comfy/supported_models.py
index 91cc4ef08..383c82c3e 100644
--- a/comfy/supported_models.py
+++ b/comfy/supported_models.py
@@ -1529,6 +1529,6 @@ class Kandinsky5Image(Kandinsky5):
return supported_models_base.ClipTarget(comfy.text_encoders.kandinsky5.Kandinsky5TokenizerImage, comfy.text_encoders.kandinsky5.te(**hunyuan_detect))
-models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage, Kandinsky5Image, Kandinsky5]
+models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5]
models += [SVD_img2vid]
From 117bf3f2bd9235cb5942a1de10a534c9869c7444 Mon Sep 17 00:00:00 2001
From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com>
Date: Sat, 6 Dec 2025 06:22:02 +0200
Subject: [PATCH 37/81] convert nodes_freelunch.py to the V3 schema (#10904)
---
comfy_extras/nodes_freelunch.py | 89 +++++++++++++++++----------
comfy_extras/nodes_model_downscale.py | 5 --
2 files changed, 57 insertions(+), 37 deletions(-)
diff --git a/comfy_extras/nodes_freelunch.py b/comfy_extras/nodes_freelunch.py
index e3ac58447..3429b731e 100644
--- a/comfy_extras/nodes_freelunch.py
+++ b/comfy_extras/nodes_freelunch.py
@@ -2,6 +2,8 @@
import torch
import logging
+from typing_extensions import override
+from comfy_api.latest import ComfyExtension, IO
def Fourier_filter(x, threshold, scale):
# FFT
@@ -22,21 +24,26 @@ def Fourier_filter(x, threshold, scale):
return x_filtered.to(x.dtype)
-class FreeU:
+class FreeU(IO.ComfyNode):
@classmethod
- def INPUT_TYPES(s):
- return {"required": { "model": ("MODEL",),
- "b1": ("FLOAT", {"default": 1.1, "min": 0.0, "max": 10.0, "step": 0.01}),
- "b2": ("FLOAT", {"default": 1.2, "min": 0.0, "max": 10.0, "step": 0.01}),
- "s1": ("FLOAT", {"default": 0.9, "min": 0.0, "max": 10.0, "step": 0.01}),
- "s2": ("FLOAT", {"default": 0.2, "min": 0.0, "max": 10.0, "step": 0.01}),
- }}
- RETURN_TYPES = ("MODEL",)
- FUNCTION = "patch"
+ def define_schema(cls):
+ return IO.Schema(
+ node_id="FreeU",
+ category="model_patches/unet",
+ inputs=[
+ IO.Model.Input("model"),
+ IO.Float.Input("b1", default=1.1, min=0.0, max=10.0, step=0.01),
+ IO.Float.Input("b2", default=1.2, min=0.0, max=10.0, step=0.01),
+ IO.Float.Input("s1", default=0.9, min=0.0, max=10.0, step=0.01),
+ IO.Float.Input("s2", default=0.2, min=0.0, max=10.0, step=0.01),
+ ],
+ outputs=[
+ IO.Model.Output(),
+ ],
+ )
- CATEGORY = "model_patches/unet"
-
- def patch(self, model, b1, b2, s1, s2):
+ @classmethod
+ def execute(cls, model, b1, b2, s1, s2) -> IO.NodeOutput:
model_channels = model.model.model_config.unet_config["model_channels"]
scale_dict = {model_channels * 4: (b1, s1), model_channels * 2: (b2, s2)}
on_cpu_devices = {}
@@ -59,23 +66,31 @@ class FreeU:
m = model.clone()
m.set_model_output_block_patch(output_block_patch)
- return (m, )
+ return IO.NodeOutput(m)
-class FreeU_V2:
+ patch = execute # TODO: remove
+
+
+class FreeU_V2(IO.ComfyNode):
@classmethod
- def INPUT_TYPES(s):
- return {"required": { "model": ("MODEL",),
- "b1": ("FLOAT", {"default": 1.3, "min": 0.0, "max": 10.0, "step": 0.01}),
- "b2": ("FLOAT", {"default": 1.4, "min": 0.0, "max": 10.0, "step": 0.01}),
- "s1": ("FLOAT", {"default": 0.9, "min": 0.0, "max": 10.0, "step": 0.01}),
- "s2": ("FLOAT", {"default": 0.2, "min": 0.0, "max": 10.0, "step": 0.01}),
- }}
- RETURN_TYPES = ("MODEL",)
- FUNCTION = "patch"
+ def define_schema(cls):
+ return IO.Schema(
+ node_id="FreeU_V2",
+ category="model_patches/unet",
+ inputs=[
+ IO.Model.Input("model"),
+ IO.Float.Input("b1", default=1.3, min=0.0, max=10.0, step=0.01),
+ IO.Float.Input("b2", default=1.4, min=0.0, max=10.0, step=0.01),
+ IO.Float.Input("s1", default=0.9, min=0.0, max=10.0, step=0.01),
+ IO.Float.Input("s2", default=0.2, min=0.0, max=10.0, step=0.01),
+ ],
+ outputs=[
+ IO.Model.Output(),
+ ],
+ )
- CATEGORY = "model_patches/unet"
-
- def patch(self, model, b1, b2, s1, s2):
+ @classmethod
+ def execute(cls, model, b1, b2, s1, s2) -> IO.NodeOutput:
model_channels = model.model.model_config.unet_config["model_channels"]
scale_dict = {model_channels * 4: (b1, s1), model_channels * 2: (b2, s2)}
on_cpu_devices = {}
@@ -105,9 +120,19 @@ class FreeU_V2:
m = model.clone()
m.set_model_output_block_patch(output_block_patch)
- return (m, )
+ return IO.NodeOutput(m)
-NODE_CLASS_MAPPINGS = {
- "FreeU": FreeU,
- "FreeU_V2": FreeU_V2,
-}
+ patch = execute # TODO: remove
+
+
+class FreelunchExtension(ComfyExtension):
+ @override
+ async def get_node_list(self) -> list[type[IO.ComfyNode]]:
+ return [
+ FreeU,
+ FreeU_V2,
+ ]
+
+
+async def comfy_entrypoint() -> FreelunchExtension:
+ return FreelunchExtension()
diff --git a/comfy_extras/nodes_model_downscale.py b/comfy_extras/nodes_model_downscale.py
index f7ca9699d..dec2ae841 100644
--- a/comfy_extras/nodes_model_downscale.py
+++ b/comfy_extras/nodes_model_downscale.py
@@ -53,11 +53,6 @@ class PatchModelAddDownscale(io.ComfyNode):
return io.NodeOutput(m)
-NODE_DISPLAY_NAME_MAPPINGS = {
- # Sampling
- "PatchModelAddDownscale": "",
-}
-
class ModelDownscaleExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
From 913f86b72740f84f759786a698108840a09b6498 Mon Sep 17 00:00:00 2001
From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com>
Date: Sat, 6 Dec 2025 06:24:10 +0200
Subject: [PATCH 38/81] [V3] convert nodes_mask.py to V3 schema (#10669)
* convert nodes_mask.py to V3 schema
* set "Preview Mask" as display name for MaskPreview
---
comfy_extras/nodes_mask.py | 508 +++++++++++++++++++------------------
1 file changed, 263 insertions(+), 245 deletions(-)
diff --git a/comfy_extras/nodes_mask.py b/comfy_extras/nodes_mask.py
index a5e405008..290e6f55e 100644
--- a/comfy_extras/nodes_mask.py
+++ b/comfy_extras/nodes_mask.py
@@ -3,11 +3,10 @@ import scipy.ndimage
import torch
import comfy.utils
import node_helpers
-import folder_paths
-import random
+from typing_extensions import override
+from comfy_api.latest import ComfyExtension, IO, UI
import nodes
-from nodes import MAX_RESOLUTION
def composite(destination, source, x, y, mask = None, multiplier = 8, resize_source = False):
source = source.to(destination.device)
@@ -46,202 +45,213 @@ def composite(destination, source, x, y, mask = None, multiplier = 8, resize_sou
destination[..., top:bottom, left:right] = source_portion + destination_portion
return destination
-class LatentCompositeMasked:
+class LatentCompositeMasked(IO.ComfyNode):
@classmethod
- def INPUT_TYPES(s):
- return {
- "required": {
- "destination": ("LATENT",),
- "source": ("LATENT",),
- "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
- "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
- "resize_source": ("BOOLEAN", {"default": False}),
- },
- "optional": {
- "mask": ("MASK",),
- }
- }
- RETURN_TYPES = ("LATENT",)
- FUNCTION = "composite"
+ def define_schema(cls):
+ return IO.Schema(
+ node_id="LatentCompositeMasked",
+ category="latent",
+ inputs=[
+ IO.Latent.Input("destination"),
+ IO.Latent.Input("source"),
+ IO.Int.Input("x", default=0, min=0, max=nodes.MAX_RESOLUTION, step=8),
+ IO.Int.Input("y", default=0, min=0, max=nodes.MAX_RESOLUTION, step=8),
+ IO.Boolean.Input("resize_source", default=False),
+ IO.Mask.Input("mask", optional=True),
+ ],
+ outputs=[IO.Latent.Output()],
+ )
- CATEGORY = "latent"
-
- def composite(self, destination, source, x, y, resize_source, mask = None):
+ @classmethod
+ def execute(cls, destination, source, x, y, resize_source, mask = None) -> IO.NodeOutput:
output = destination.copy()
destination = destination["samples"].clone()
source = source["samples"]
output["samples"] = composite(destination, source, x, y, mask, 8, resize_source)
- return (output,)
+ return IO.NodeOutput(output)
-class ImageCompositeMasked:
+ composite = execute # TODO: remove
+
+
+class ImageCompositeMasked(IO.ComfyNode):
@classmethod
- def INPUT_TYPES(s):
- return {
- "required": {
- "destination": ("IMAGE",),
- "source": ("IMAGE",),
- "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
- "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
- "resize_source": ("BOOLEAN", {"default": False}),
- },
- "optional": {
- "mask": ("MASK",),
- }
- }
- RETURN_TYPES = ("IMAGE",)
- FUNCTION = "composite"
+ def define_schema(cls):
+ return IO.Schema(
+ node_id="ImageCompositeMasked",
+ category="image",
+ inputs=[
+ IO.Image.Input("destination"),
+ IO.Image.Input("source"),
+ IO.Int.Input("x", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1),
+ IO.Int.Input("y", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1),
+ IO.Boolean.Input("resize_source", default=False),
+ IO.Mask.Input("mask", optional=True),
+ ],
+ outputs=[IO.Image.Output()],
+ )
- CATEGORY = "image"
-
- def composite(self, destination, source, x, y, resize_source, mask = None):
+ @classmethod
+ def execute(cls, destination, source, x, y, resize_source, mask = None) -> IO.NodeOutput:
destination, source = node_helpers.image_alpha_fix(destination, source)
destination = destination.clone().movedim(-1, 1)
output = composite(destination, source.movedim(-1, 1), x, y, mask, 1, resize_source).movedim(1, -1)
- return (output,)
+ return IO.NodeOutput(output)
-class MaskToImage:
+ composite = execute # TODO: remove
+
+
+class MaskToImage(IO.ComfyNode):
@classmethod
- def INPUT_TYPES(s):
- return {
- "required": {
- "mask": ("MASK",),
- }
- }
+ def define_schema(cls):
+ return IO.Schema(
+ node_id="MaskToImage",
+ display_name="Convert Mask to Image",
+ category="mask",
+ inputs=[
+ IO.Mask.Input("mask"),
+ ],
+ outputs=[IO.Image.Output()],
+ )
- CATEGORY = "mask"
-
- RETURN_TYPES = ("IMAGE",)
- FUNCTION = "mask_to_image"
-
- def mask_to_image(self, mask):
+ @classmethod
+ def execute(cls, mask) -> IO.NodeOutput:
result = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])).movedim(1, -1).expand(-1, -1, -1, 3)
- return (result,)
+ return IO.NodeOutput(result)
-class ImageToMask:
+ mask_to_image = execute # TODO: remove
+
+
+class ImageToMask(IO.ComfyNode):
@classmethod
- def INPUT_TYPES(s):
- return {
- "required": {
- "image": ("IMAGE",),
- "channel": (["red", "green", "blue", "alpha"],),
- }
- }
+ def define_schema(cls):
+ return IO.Schema(
+ node_id="ImageToMask",
+ display_name="Convert Image to Mask",
+ category="mask",
+ inputs=[
+ IO.Image.Input("image"),
+ IO.Combo.Input("channel", options=["red", "green", "blue", "alpha"]),
+ ],
+ outputs=[IO.Mask.Output()],
+ )
- CATEGORY = "mask"
-
- RETURN_TYPES = ("MASK",)
- FUNCTION = "image_to_mask"
-
- def image_to_mask(self, image, channel):
+ @classmethod
+ def execute(cls, image, channel) -> IO.NodeOutput:
channels = ["red", "green", "blue", "alpha"]
mask = image[:, :, :, channels.index(channel)]
- return (mask,)
+ return IO.NodeOutput(mask)
-class ImageColorToMask:
+ image_to_mask = execute # TODO: remove
+
+
+class ImageColorToMask(IO.ComfyNode):
@classmethod
- def INPUT_TYPES(s):
- return {
- "required": {
- "image": ("IMAGE",),
- "color": ("INT", {"default": 0, "min": 0, "max": 0xFFFFFF, "step": 1, "display": "color"}),
- }
- }
+ def define_schema(cls):
+ return IO.Schema(
+ node_id="ImageColorToMask",
+ category="mask",
+ inputs=[
+ IO.Image.Input("image"),
+ IO.Int.Input("color", default=0, min=0, max=0xFFFFFF, step=1, display_mode=IO.NumberDisplay.number),
+ ],
+ outputs=[IO.Mask.Output()],
+ )
- CATEGORY = "mask"
-
- RETURN_TYPES = ("MASK",)
- FUNCTION = "image_to_mask"
-
- def image_to_mask(self, image, color):
+ @classmethod
+ def execute(cls, image, color) -> IO.NodeOutput:
temp = (torch.clamp(image, 0, 1.0) * 255.0).round().to(torch.int)
temp = torch.bitwise_left_shift(temp[:,:,:,0], 16) + torch.bitwise_left_shift(temp[:,:,:,1], 8) + temp[:,:,:,2]
mask = torch.where(temp == color, 1.0, 0).float()
- return (mask,)
+ return IO.NodeOutput(mask)
-class SolidMask:
+ image_to_mask = execute # TODO: remove
+
+
+class SolidMask(IO.ComfyNode):
@classmethod
- def INPUT_TYPES(cls):
- return {
- "required": {
- "value": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
- "width": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
- "height": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
- }
- }
+ def define_schema(cls):
+ return IO.Schema(
+ node_id="SolidMask",
+ category="mask",
+ inputs=[
+ IO.Float.Input("value", default=1.0, min=0.0, max=1.0, step=0.01),
+ IO.Int.Input("width", default=512, min=1, max=nodes.MAX_RESOLUTION, step=1),
+ IO.Int.Input("height", default=512, min=1, max=nodes.MAX_RESOLUTION, step=1),
+ ],
+ outputs=[IO.Mask.Output()],
+ )
- CATEGORY = "mask"
-
- RETURN_TYPES = ("MASK",)
-
- FUNCTION = "solid"
-
- def solid(self, value, width, height):
+ @classmethod
+ def execute(cls, value, width, height) -> IO.NodeOutput:
out = torch.full((1, height, width), value, dtype=torch.float32, device="cpu")
- return (out,)
+ return IO.NodeOutput(out)
-class InvertMask:
+ solid = execute # TODO: remove
+
+
+class InvertMask(IO.ComfyNode):
@classmethod
- def INPUT_TYPES(cls):
- return {
- "required": {
- "mask": ("MASK",),
- }
- }
+ def define_schema(cls):
+ return IO.Schema(
+ node_id="InvertMask",
+ category="mask",
+ inputs=[
+ IO.Mask.Input("mask"),
+ ],
+ outputs=[IO.Mask.Output()],
+ )
- CATEGORY = "mask"
-
- RETURN_TYPES = ("MASK",)
-
- FUNCTION = "invert"
-
- def invert(self, mask):
+ @classmethod
+ def execute(cls, mask) -> IO.NodeOutput:
out = 1.0 - mask
- return (out,)
+ return IO.NodeOutput(out)
-class CropMask:
+ invert = execute # TODO: remove
+
+
+class CropMask(IO.ComfyNode):
@classmethod
- def INPUT_TYPES(cls):
- return {
- "required": {
- "mask": ("MASK",),
- "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
- "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
- "width": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
- "height": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
- }
- }
+ def define_schema(cls):
+ return IO.Schema(
+ node_id="CropMask",
+ category="mask",
+ inputs=[
+ IO.Mask.Input("mask"),
+ IO.Int.Input("x", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1),
+ IO.Int.Input("y", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1),
+ IO.Int.Input("width", default=512, min=1, max=nodes.MAX_RESOLUTION, step=1),
+ IO.Int.Input("height", default=512, min=1, max=nodes.MAX_RESOLUTION, step=1),
+ ],
+ outputs=[IO.Mask.Output()],
+ )
- CATEGORY = "mask"
-
- RETURN_TYPES = ("MASK",)
-
- FUNCTION = "crop"
-
- def crop(self, mask, x, y, width, height):
+ @classmethod
+ def execute(cls, mask, x, y, width, height) -> IO.NodeOutput:
mask = mask.reshape((-1, mask.shape[-2], mask.shape[-1]))
out = mask[:, y:y + height, x:x + width]
- return (out,)
+ return IO.NodeOutput(out)
-class MaskComposite:
+ crop = execute # TODO: remove
+
+
+class MaskComposite(IO.ComfyNode):
@classmethod
- def INPUT_TYPES(cls):
- return {
- "required": {
- "destination": ("MASK",),
- "source": ("MASK",),
- "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
- "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
- "operation": (["multiply", "add", "subtract", "and", "or", "xor"],),
- }
- }
+ def define_schema(cls):
+ return IO.Schema(
+ node_id="MaskComposite",
+ category="mask",
+ inputs=[
+ IO.Mask.Input("destination"),
+ IO.Mask.Input("source"),
+ IO.Int.Input("x", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1),
+ IO.Int.Input("y", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1),
+ IO.Combo.Input("operation", options=["multiply", "add", "subtract", "and", "or", "xor"]),
+ ],
+ outputs=[IO.Mask.Output()],
+ )
- CATEGORY = "mask"
-
- RETURN_TYPES = ("MASK",)
-
- FUNCTION = "combine"
-
- def combine(self, destination, source, x, y, operation):
+ @classmethod
+ def execute(cls, destination, source, x, y, operation) -> IO.NodeOutput:
output = destination.reshape((-1, destination.shape[-2], destination.shape[-1])).clone()
source = source.reshape((-1, source.shape[-2], source.shape[-1]))
@@ -267,28 +277,29 @@ class MaskComposite:
output = torch.clamp(output, 0.0, 1.0)
- return (output,)
+ return IO.NodeOutput(output)
-class FeatherMask:
+ combine = execute # TODO: remove
+
+
+class FeatherMask(IO.ComfyNode):
@classmethod
- def INPUT_TYPES(cls):
- return {
- "required": {
- "mask": ("MASK",),
- "left": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
- "top": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
- "right": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
- "bottom": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
- }
- }
+ def define_schema(cls):
+ return IO.Schema(
+ node_id="FeatherMask",
+ category="mask",
+ inputs=[
+ IO.Mask.Input("mask"),
+ IO.Int.Input("left", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1),
+ IO.Int.Input("top", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1),
+ IO.Int.Input("right", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1),
+ IO.Int.Input("bottom", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1),
+ ],
+ outputs=[IO.Mask.Output()],
+ )
- CATEGORY = "mask"
-
- RETURN_TYPES = ("MASK",)
-
- FUNCTION = "feather"
-
- def feather(self, mask, left, top, right, bottom):
+ @classmethod
+ def execute(cls, mask, left, top, right, bottom) -> IO.NodeOutput:
output = mask.reshape((-1, mask.shape[-2], mask.shape[-1])).clone()
left = min(left, output.shape[-1])
@@ -312,26 +323,28 @@ class FeatherMask:
feather_rate = (y + 1) / bottom
output[:, -y, :] *= feather_rate
- return (output,)
+ return IO.NodeOutput(output)
-class GrowMask:
+ feather = execute # TODO: remove
+
+
+class GrowMask(IO.ComfyNode):
@classmethod
- def INPUT_TYPES(cls):
- return {
- "required": {
- "mask": ("MASK",),
- "expand": ("INT", {"default": 0, "min": -MAX_RESOLUTION, "max": MAX_RESOLUTION, "step": 1}),
- "tapered_corners": ("BOOLEAN", {"default": True}),
- },
- }
+ def define_schema(cls):
+ return IO.Schema(
+ node_id="GrowMask",
+ display_name="Grow Mask",
+ category="mask",
+ inputs=[
+ IO.Mask.Input("mask"),
+ IO.Int.Input("expand", default=0, min=-nodes.MAX_RESOLUTION, max=nodes.MAX_RESOLUTION, step=1),
+ IO.Boolean.Input("tapered_corners", default=True),
+ ],
+ outputs=[IO.Mask.Output()],
+ )
- CATEGORY = "mask"
-
- RETURN_TYPES = ("MASK",)
-
- FUNCTION = "expand_mask"
-
- def expand_mask(self, mask, expand, tapered_corners):
+ @classmethod
+ def execute(cls, mask, expand, tapered_corners) -> IO.NodeOutput:
c = 0 if tapered_corners else 1
kernel = np.array([[c, 1, c],
[1, 1, 1],
@@ -347,69 +360,74 @@ class GrowMask:
output = scipy.ndimage.grey_dilation(output, footprint=kernel)
output = torch.from_numpy(output)
out.append(output)
- return (torch.stack(out, dim=0),)
+ return IO.NodeOutput(torch.stack(out, dim=0))
-class ThresholdMask:
+ expand_mask = execute # TODO: remove
+
+
+class ThresholdMask(IO.ComfyNode):
@classmethod
- def INPUT_TYPES(s):
- return {
- "required": {
- "mask": ("MASK",),
- "value": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
- }
- }
+ def define_schema(cls):
+ return IO.Schema(
+ node_id="ThresholdMask",
+ category="mask",
+ inputs=[
+ IO.Mask.Input("mask"),
+ IO.Float.Input("value", default=0.5, min=0.0, max=1.0, step=0.01),
+ ],
+ outputs=[IO.Mask.Output()],
+ )
- CATEGORY = "mask"
-
- RETURN_TYPES = ("MASK",)
- FUNCTION = "image_to_mask"
-
- def image_to_mask(self, mask, value):
+ @classmethod
+ def execute(cls, mask, value) -> IO.NodeOutput:
mask = (mask > value).float()
- return (mask,)
+ return IO.NodeOutput(mask)
+
+ image_to_mask = execute # TODO: remove
+
# Mask Preview - original implement from
# https://github.com/cubiq/ComfyUI_essentials/blob/9d9f4bedfc9f0321c19faf71855e228c93bd0dc9/mask.py#L81
# upstream requested in https://github.com/Kosinkadink/rfcs/blob/main/rfcs/0000-corenodes.md#preview-nodes
-class MaskPreview(nodes.SaveImage):
- def __init__(self):
- self.output_dir = folder_paths.get_temp_directory()
- self.type = "temp"
- self.prefix_append = "_temp_" + ''.join(random.choice("abcdefghijklmnopqrstupvxyz") for x in range(5))
- self.compress_level = 4
+class MaskPreview(IO.ComfyNode):
+ @classmethod
+ def define_schema(cls):
+ return IO.Schema(
+ node_id="MaskPreview",
+ display_name="Preview Mask",
+ category="mask",
+ description="Saves the input images to your ComfyUI output directory.",
+ inputs=[
+ IO.Mask.Input("mask"),
+ ],
+ hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo],
+ is_output_node=True,
+ )
@classmethod
- def INPUT_TYPES(s):
- return {
- "required": {"mask": ("MASK",), },
- "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
- }
-
- FUNCTION = "execute"
- CATEGORY = "mask"
-
- def execute(self, mask, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None):
- preview = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])).movedim(1, -1).expand(-1, -1, -1, 3)
- return self.save_images(preview, filename_prefix, prompt, extra_pnginfo)
+ def execute(cls, mask, filename_prefix="ComfyUI") -> IO.NodeOutput:
+ return IO.NodeOutput(ui=UI.PreviewMask(mask))
-NODE_CLASS_MAPPINGS = {
- "LatentCompositeMasked": LatentCompositeMasked,
- "ImageCompositeMasked": ImageCompositeMasked,
- "MaskToImage": MaskToImage,
- "ImageToMask": ImageToMask,
- "ImageColorToMask": ImageColorToMask,
- "SolidMask": SolidMask,
- "InvertMask": InvertMask,
- "CropMask": CropMask,
- "MaskComposite": MaskComposite,
- "FeatherMask": FeatherMask,
- "GrowMask": GrowMask,
- "ThresholdMask": ThresholdMask,
- "MaskPreview": MaskPreview
-}
+class MaskExtension(ComfyExtension):
+ @override
+ async def get_node_list(self) -> list[type[IO.ComfyNode]]:
+ return [
+ LatentCompositeMasked,
+ ImageCompositeMasked,
+ MaskToImage,
+ ImageToMask,
+ ImageColorToMask,
+ SolidMask,
+ InvertMask,
+ CropMask,
+ MaskComposite,
+ FeatherMask,
+ GrowMask,
+ ThresholdMask,
+ MaskPreview,
+ ]
-NODE_DISPLAY_NAME_MAPPINGS = {
- "ImageToMask": "Convert Image to Mask",
- "MaskToImage": "Convert Mask to Image",
-}
+
+async def comfy_entrypoint() -> MaskExtension:
+ return MaskExtension()
From d7a0aef65033bf0fe56e521577a44fac1830b8b3 Mon Sep 17 00:00:00 2001
From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com>
Date: Fri, 5 Dec 2025 21:15:21 -0800
Subject: [PATCH 39/81] Set OCL_SET_SVM_SIZE on AMD. (#11139)
---
cuda_malloc.py | 27 +++++++++++++++++----------
main.py | 3 +++
2 files changed, 20 insertions(+), 10 deletions(-)
diff --git a/cuda_malloc.py b/cuda_malloc.py
index 6520d5123..ee2bc4b69 100644
--- a/cuda_malloc.py
+++ b/cuda_malloc.py
@@ -63,18 +63,22 @@ def cuda_malloc_supported():
return True
+version = ""
+
+try:
+ torch_spec = importlib.util.find_spec("torch")
+ for folder in torch_spec.submodule_search_locations:
+ ver_file = os.path.join(folder, "version.py")
+ if os.path.isfile(ver_file):
+ spec = importlib.util.spec_from_file_location("torch_version_import", ver_file)
+ module = importlib.util.module_from_spec(spec)
+ spec.loader.exec_module(module)
+ version = module.__version__
+except:
+ pass
+
if not args.cuda_malloc:
try:
- version = ""
- torch_spec = importlib.util.find_spec("torch")
- for folder in torch_spec.submodule_search_locations:
- ver_file = os.path.join(folder, "version.py")
- if os.path.isfile(ver_file):
- spec = importlib.util.spec_from_file_location("torch_version_import", ver_file)
- module = importlib.util.module_from_spec(spec)
- spec.loader.exec_module(module)
- version = module.__version__
-
if int(version[0]) >= 2 and "+cu" in version: # enable by default for torch version 2.0 and up only on cuda torch
if PerformanceFeature.AutoTune not in args.fast: # Autotune has issues with cuda malloc
args.cuda_malloc = cuda_malloc_supported()
@@ -90,3 +94,6 @@ if args.cuda_malloc and not args.disable_cuda_malloc:
env_var += ",backend:cudaMallocAsync"
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = env_var
+
+def get_torch_version_noimport():
+ return str(version)
diff --git a/main.py b/main.py
index 0cd815d9e..0d02a087b 100644
--- a/main.py
+++ b/main.py
@@ -167,6 +167,9 @@ if __name__ == "__main__":
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ":4096:8"
import cuda_malloc
+ if "rocm" in cuda_malloc.get_torch_version_noimport():
+ os.environ['OCL_SET_SVM_SIZE'] = '262144' # set at the request of AMD
+
if 'torch' in sys.modules:
logging.warning("WARNING: Potential Error in code: Torch already imported, torch should never be imported before this point.")
From 76f18e955dcbc88ed13d6802194fd897927f93e5 Mon Sep 17 00:00:00 2001
From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com>
Date: Sat, 6 Dec 2025 13:28:08 +0200
Subject: [PATCH 40/81] marked all Pika API nodes a deprecated (#11146)
---
comfy_api_nodes/nodes_pika.py | 7 +++++++
1 file changed, 7 insertions(+)
diff --git a/comfy_api_nodes/nodes_pika.py b/comfy_api_nodes/nodes_pika.py
index 51148211b..acd88c391 100644
--- a/comfy_api_nodes/nodes_pika.py
+++ b/comfy_api_nodes/nodes_pika.py
@@ -92,6 +92,7 @@ class PikaImageToVideo(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
+ is_deprecated=True,
)
@classmethod
@@ -152,6 +153,7 @@ class PikaTextToVideoNode(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
+ is_deprecated=True,
)
@classmethod
@@ -239,6 +241,7 @@ class PikaScenes(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
+ is_deprecated=True,
)
@classmethod
@@ -323,6 +326,7 @@ class PikAdditionsNode(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
+ is_deprecated=True,
)
@classmethod
@@ -399,6 +403,7 @@ class PikaSwapsNode(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
+ is_deprecated=True,
)
@classmethod
@@ -466,6 +471,7 @@ class PikaffectsNode(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
+ is_deprecated=True,
)
@classmethod
@@ -515,6 +521,7 @@ class PikaStartEndFrameNode(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
+ is_deprecated=True,
)
@classmethod
From 7ac7d69d948e75c3a230d1262daab84d75aff895 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jukka=20Sepp=C3=A4nen?=
<40791699+kijai@users.noreply.github.com>
Date: Sat, 6 Dec 2025 20:09:44 +0200
Subject: [PATCH 41/81] Fix EmptyAudio node input types (#11149)
---
comfy_extras/nodes_audio.py | 6 ++++--
1 file changed, 4 insertions(+), 2 deletions(-)
diff --git a/comfy_extras/nodes_audio.py b/comfy_extras/nodes_audio.py
index 812301fb7..c7916443c 100644
--- a/comfy_extras/nodes_audio.py
+++ b/comfy_extras/nodes_audio.py
@@ -573,12 +573,14 @@ class EmptyAudio(IO.ComfyNode):
step=0.01,
tooltip="Duration of the empty audio clip in seconds",
),
- IO.Float.Input(
+ IO.Int.Input(
"sample_rate",
default=44100,
tooltip="Sample rate of the empty audio clip.",
+ min=1,
+ max=192000,
),
- IO.Float.Input(
+ IO.Int.Input(
"channels",
default=2,
min=1,
From 50ca97e7765d9bbdbeec31a75f1f6c747d76948c Mon Sep 17 00:00:00 2001
From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com>
Date: Sat, 6 Dec 2025 15:36:20 -0800
Subject: [PATCH 42/81] Speed up lora compute and lower memory usage by doing
it in fp16. (#11161)
---
comfy/model_management.py | 14 ++++++++++++++
comfy/model_patcher.py | 5 +++--
2 files changed, 17 insertions(+), 2 deletions(-)
diff --git a/comfy/model_management.py b/comfy/model_management.py
index aeddbaefe..40717b1e4 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -1492,6 +1492,20 @@ def extended_fp16_support():
return True
+LORA_COMPUTE_DTYPES = {}
+def lora_compute_dtype(device):
+ dtype = LORA_COMPUTE_DTYPES.get(device, None)
+ if dtype is not None:
+ return dtype
+
+ if should_use_fp16(device):
+ dtype = torch.float16
+ else:
+ dtype = torch.float32
+
+ LORA_COMPUTE_DTYPES[device] = dtype
+ return dtype
+
def soft_empty_cache(force=False):
global cpu_state
if cpu_state == CPUState.MPS:
diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py
index 215784874..4f076a6aa 100644
--- a/comfy/model_patcher.py
+++ b/comfy/model_patcher.py
@@ -614,10 +614,11 @@ class ModelPatcher:
if key not in self.backup:
self.backup[key] = collections.namedtuple('Dimension', ['weight', 'inplace_update'])(weight.to(device=self.offload_device, copy=inplace_update), inplace_update)
+ temp_dtype = comfy.model_management.lora_compute_dtype(device_to)
if device_to is not None:
- temp_weight = comfy.model_management.cast_to_device(weight, device_to, torch.float32, copy=True)
+ temp_weight = comfy.model_management.cast_to_device(weight, device_to, temp_dtype, copy=True)
else:
- temp_weight = weight.to(torch.float32, copy=True)
+ temp_weight = weight.to(temp_dtype, copy=True)
if convert_func is not None:
temp_weight = convert_func(temp_weight, inplace=True)
From 4086acf3c2f0ca3a8861b04f6179fa9f908e3e25 Mon Sep 17 00:00:00 2001
From: rattus <46076784+rattus128@users.noreply.github.com>
Date: Sun, 7 Dec 2025 09:42:09 +1000
Subject: [PATCH 43/81] Fix on-load VRAM OOM (#11144)
slow down the CPU on model load to not run ahead. This fixes a VRAM on
flux 2 load.
I went to try and debug this with the memory trace pickles, which needs
--disable-cuda-malloc which made the bug go away. So I tried this
synchronize and it worked.
The has some very complex interactions with the cuda malloc async and
I dont have solid theory on this one yet.
Still debugging but this gets us over the OOM for the moment.
---
comfy/model_patcher.py | 2 ++
1 file changed, 2 insertions(+)
diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py
index 4f076a6aa..5b1ccb824 100644
--- a/comfy/model_patcher.py
+++ b/comfy/model_patcher.py
@@ -762,6 +762,8 @@ class ModelPatcher:
key = "{}.{}".format(n, param)
self.unpin_weight(key)
self.patch_weight_to_device(key, device_to=device_to)
+ if comfy.model_management.is_device_cuda(device_to):
+ torch.cuda.synchronize()
logging.debug("lowvram: loaded module regularly {} {}".format(n, m))
m.comfy_patched_weights = True
From 329480da5ab32949a411548f821ea60ab3e90dc7 Mon Sep 17 00:00:00 2001
From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com>
Date: Sat, 6 Dec 2025 17:50:10 -0800
Subject: [PATCH 44/81] Fix qwen scaled fp8 not working with kandinsky. Make
basic t2i wf work. (#11162)
---
comfy/ldm/kandinsky5/model.py | 8 +++++++-
comfy/text_encoders/kandinsky5.py | 12 ++++++------
2 files changed, 13 insertions(+), 7 deletions(-)
diff --git a/comfy/ldm/kandinsky5/model.py b/comfy/ldm/kandinsky5/model.py
index a653e02fc..1509de2f8 100644
--- a/comfy/ldm/kandinsky5/model.py
+++ b/comfy/ldm/kandinsky5/model.py
@@ -387,6 +387,9 @@ class Kandinsky5(nn.Module):
return self.out_layer(visual_embed, time_embed)
def _forward(self, x, timestep, context, y, time_dim_replace=None, transformer_options={}, **kwargs):
+ original_dims = x.ndim
+ if original_dims == 4:
+ x = x.unsqueeze(2)
bs, c, t_len, h, w = x.shape
x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)
@@ -397,7 +400,10 @@ class Kandinsky5(nn.Module):
freqs = self.rope_encode_3d(t_len, h, w, device=x.device, dtype=x.dtype, transformer_options=transformer_options)
freqs_text = self.rope_encode_1d(context.shape[1], device=x.device, dtype=x.dtype, transformer_options=transformer_options)
- return self.forward_orig(x, timestep, context, y, freqs, freqs_text, transformer_options=transformer_options, **kwargs)
+ out = self.forward_orig(x, timestep, context, y, freqs, freqs_text, transformer_options=transformer_options, **kwargs)
+ if original_dims == 4:
+ out = out.squeeze(2)
+ return out
def forward(self, x, timestep, context, y, time_dim_replace=None, transformer_options={}, **kwargs):
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
diff --git a/comfy/text_encoders/kandinsky5.py b/comfy/text_encoders/kandinsky5.py
index 22f991c36..be086458c 100644
--- a/comfy/text_encoders/kandinsky5.py
+++ b/comfy/text_encoders/kandinsky5.py
@@ -24,10 +24,10 @@ class Kandinsky5TokenizerImage(Kandinsky5Tokenizer):
class Qwen25_7BVLIModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="hidden", layer_idx=-1, dtype=None, attention_mask=True, model_options={}):
- llama_scaled_fp8 = model_options.get("qwen_scaled_fp8", None)
- if llama_scaled_fp8 is not None:
+ llama_quantization_metadata = model_options.get("llama_quantization_metadata", None)
+ if llama_quantization_metadata is not None:
model_options = model_options.copy()
- model_options["scaled_fp8"] = llama_scaled_fp8
+ model_options["quantization_metadata"] = llama_quantization_metadata
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=Qwen25_7BVLI, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
@@ -56,12 +56,12 @@ class Kandinsky5TEModel(QwenImageTEModel):
else:
return super().load_sd(sd)
-def te(dtype_llama=None, llama_scaled_fp8=None):
+def te(dtype_llama=None, llama_quantization_metadata=None):
class Kandinsky5TEModel_(Kandinsky5TEModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
- if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options:
+ if llama_quantization_metadata is not None:
model_options = model_options.copy()
- model_options["qwen_scaled_fp8"] = llama_scaled_fp8
+ model_options["llama_quantization_metadata"] = llama_quantization_metadata
if dtype_llama is not None:
dtype = dtype_llama
super().__init__(device=device, dtype=dtype, model_options=model_options)
From 56fa7dbe380cb5591c5542f8aa51ce2fc26beedf Mon Sep 17 00:00:00 2001
From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com>
Date: Sun, 7 Dec 2025 04:44:55 -0800
Subject: [PATCH 45/81] Properly load the newbie diffusion model. (#11172)
There is still one of the text encoders missing and I didn't actually test it.
---
comfy/ldm/lumina/model.py | 35 +++++++++++++++++++++++++++++++++++
comfy/model_base.py | 4 ++++
comfy/model_detection.py | 3 +++
3 files changed, 42 insertions(+)
diff --git a/comfy/ldm/lumina/model.py b/comfy/ldm/lumina/model.py
index 6c24fed9b..c47df49ca 100644
--- a/comfy/ldm/lumina/model.py
+++ b/comfy/ldm/lumina/model.py
@@ -377,6 +377,7 @@ class NextDiT(nn.Module):
z_image_modulation=False,
time_scale=1.0,
pad_tokens_multiple=None,
+ clip_text_dim=None,
image_model=None,
device=None,
dtype=None,
@@ -447,6 +448,31 @@ class NextDiT(nn.Module):
),
)
+ self.clip_text_pooled_proj = None
+
+ if clip_text_dim is not None:
+ self.clip_text_dim = clip_text_dim
+ self.clip_text_pooled_proj = nn.Sequential(
+ operation_settings.get("operations").RMSNorm(clip_text_dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")),
+ operation_settings.get("operations").Linear(
+ clip_text_dim,
+ clip_text_dim,
+ bias=True,
+ device=operation_settings.get("device"),
+ dtype=operation_settings.get("dtype"),
+ ),
+ )
+ self.time_text_embed = nn.Sequential(
+ nn.SiLU(),
+ operation_settings.get("operations").Linear(
+ min(dim, 1024) + clip_text_dim,
+ min(dim, 1024),
+ bias=True,
+ device=operation_settings.get("device"),
+ dtype=operation_settings.get("dtype"),
+ ),
+ )
+
self.layers = nn.ModuleList(
[
JointTransformerBlock(
@@ -585,6 +611,15 @@ class NextDiT(nn.Module):
cap_feats = self.cap_embedder(cap_feats) # (N, L, D) # todo check if able to batchify w.o. redundant compute
+ if self.clip_text_pooled_proj is not None:
+ pooled = kwargs.get("clip_text_pooled", None)
+ if pooled is not None:
+ pooled = self.clip_text_pooled_proj(pooled)
+ else:
+ pooled = torch.zeros((1, self.clip_text_dim), device=x.device, dtype=x.dtype)
+
+ adaln_input = self.time_text_embed(torch.cat((t, pooled), dim=-1))
+
patches = transformer_options.get("patches", {})
x_is_tensor = isinstance(x, torch.Tensor)
img, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, t, num_tokens, transformer_options=transformer_options)
diff --git a/comfy/model_base.py b/comfy/model_base.py
index 0be006cc2..6b8a8454d 100644
--- a/comfy/model_base.py
+++ b/comfy/model_base.py
@@ -1110,6 +1110,10 @@ class Lumina2(BaseModel):
if 'num_tokens' not in out:
out['num_tokens'] = comfy.conds.CONDConstant(cross_attn.shape[1])
+ clip_text_pooled = kwargs["pooled_output"] # Newbie
+ if clip_text_pooled is not None:
+ out['clip_text_pooled'] = comfy.conds.CONDRegular(clip_text_pooled)
+
return out
class WAN21(BaseModel):
diff --git a/comfy/model_detection.py b/comfy/model_detection.py
index 30b33a486..74c547427 100644
--- a/comfy/model_detection.py
+++ b/comfy/model_detection.py
@@ -423,6 +423,9 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["axes_lens"] = [300, 512, 512]
dit_config["rope_theta"] = 10000.0
dit_config["ffn_dim_multiplier"] = 4.0
+ ctd_weight = state_dict.get('{}clip_text_pooled_proj.0.weight'.format(key_prefix), None)
+ if ctd_weight is not None:
+ dit_config["clip_text_dim"] = ctd_weight.shape[0]
elif dit_config["dim"] == 3840: # Z image
dit_config["n_heads"] = 30
dit_config["n_kv_heads"] = 30
From ec7f65187d85e22ea23345ce0d919e11768f255e Mon Sep 17 00:00:00 2001
From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com>
Date: Mon, 8 Dec 2025 11:21:41 +0200
Subject: [PATCH 46/81] chore(comfy_api): replace absolute imports with
relative (#11145)
---
comfy_api/latest/__init__.py | 8 ++++----
comfy_api/latest/_input/video_types.py | 2 +-
comfy_api/latest/_input_impl/video_types.py | 4 ++--
comfy_api/latest/_io.py | 2 +-
comfy_api/latest/_ui.py | 2 +-
comfy_api/latest/_util/video_types.py | 2 +-
6 files changed, 10 insertions(+), 10 deletions(-)
diff --git a/comfy_api/latest/__init__.py b/comfy_api/latest/__init__.py
index 0fa01d1e7..35e1ac853 100644
--- a/comfy_api/latest/__init__.py
+++ b/comfy_api/latest/__init__.py
@@ -5,9 +5,9 @@ from typing import Type, TYPE_CHECKING
from comfy_api.internal import ComfyAPIBase
from comfy_api.internal.singleton import ProxiedSingleton
from comfy_api.internal.async_to_sync import create_sync_class
-from comfy_api.latest._input import ImageInput, AudioInput, MaskInput, LatentInput, VideoInput
-from comfy_api.latest._input_impl import VideoFromFile, VideoFromComponents
-from comfy_api.latest._util import VideoCodec, VideoContainer, VideoComponents, MESH, VOXEL
+from ._input import ImageInput, AudioInput, MaskInput, LatentInput, VideoInput
+from ._input_impl import VideoFromFile, VideoFromComponents
+from ._util import VideoCodec, VideoContainer, VideoComponents, MESH, VOXEL
from . import _io_public as io
from . import _ui_public as ui
# from comfy_api.latest._resources import _RESOURCES as resources #noqa: F401
@@ -80,7 +80,7 @@ class ComfyExtension(ABC):
async def on_load(self) -> None:
"""
Called when an extension is loaded.
- This should be used to initialize any global resources neeeded by the extension.
+ This should be used to initialize any global resources needed by the extension.
"""
@abstractmethod
diff --git a/comfy_api/latest/_input/video_types.py b/comfy_api/latest/_input/video_types.py
index 87c81d73a..e634a0311 100644
--- a/comfy_api/latest/_input/video_types.py
+++ b/comfy_api/latest/_input/video_types.py
@@ -4,7 +4,7 @@ from fractions import Fraction
from typing import Optional, Union, IO
import io
import av
-from comfy_api.util import VideoContainer, VideoCodec, VideoComponents
+from .._util import VideoContainer, VideoCodec, VideoComponents
class VideoInput(ABC):
"""
diff --git a/comfy_api/latest/_input_impl/video_types.py b/comfy_api/latest/_input_impl/video_types.py
index a4cd3737d..ea35c6062 100644
--- a/comfy_api/latest/_input_impl/video_types.py
+++ b/comfy_api/latest/_input_impl/video_types.py
@@ -3,14 +3,14 @@ from av.container import InputContainer
from av.subtitles.stream import SubtitleStream
from fractions import Fraction
from typing import Optional
-from comfy_api.latest._input import AudioInput, VideoInput
+from .._input import AudioInput, VideoInput
import av
import io
import json
import numpy as np
import math
import torch
-from comfy_api.latest._util import VideoContainer, VideoCodec, VideoComponents
+from .._util import VideoContainer, VideoCodec, VideoComponents
def container_to_output_format(container_format: str | None) -> str | None:
diff --git a/comfy_api/latest/_io.py b/comfy_api/latest/_io.py
index d7cbe68cf..313a5af20 100644
--- a/comfy_api/latest/_io.py
+++ b/comfy_api/latest/_io.py
@@ -26,7 +26,7 @@ if TYPE_CHECKING:
from comfy_api.input import VideoInput
from comfy_api.internal import (_ComfyNodeInternal, _NodeOutputInternal, classproperty, copy_class, first_real_override, is_class,
prune_dict, shallow_clone_class)
-from comfy_api.latest._resources import Resources, ResourcesLocal
+from ._resources import Resources, ResourcesLocal
from comfy_execution.graph_utils import ExecutionBlocker
from ._util import MESH, VOXEL
diff --git a/comfy_api/latest/_ui.py b/comfy_api/latest/_ui.py
index 5a75a3aae..2babe209a 100644
--- a/comfy_api/latest/_ui.py
+++ b/comfy_api/latest/_ui.py
@@ -22,7 +22,7 @@ import folder_paths
# used for image preview
from comfy.cli_args import args
-from comfy_api.latest._io import ComfyNode, FolderType, Image, _UIOutput
+from ._io import ComfyNode, FolderType, Image, _UIOutput
class SavedResult(dict):
diff --git a/comfy_api/latest/_util/video_types.py b/comfy_api/latest/_util/video_types.py
index c3e3d8e3a..fd3b5a510 100644
--- a/comfy_api/latest/_util/video_types.py
+++ b/comfy_api/latest/_util/video_types.py
@@ -3,7 +3,7 @@ from dataclasses import dataclass
from enum import Enum
from fractions import Fraction
from typing import Optional
-from comfy_api.latest._input import ImageInput, AudioInput
+from .._input import ImageInput, AudioInput
class VideoCodec(str, Enum):
AUTO = "auto"
From 058f084371ef2ed0c456118dfdd3d0bfed17259b Mon Sep 17 00:00:00 2001
From: ComfyUI Wiki
Date: Mon, 8 Dec 2025 17:22:51 +0800
Subject: [PATCH 47/81] Update workflow templates to v0.7.51 (#11150)
* chore: update workflow templates to v0.7.50
* Update template to 0.7.51
---
requirements.txt | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/requirements.txt b/requirements.txt
index f98848e20..12a7c1089 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,5 +1,5 @@
comfyui-frontend-package==1.33.10
-comfyui-workflow-templates==0.7.25
+comfyui-workflow-templates==0.7.51
comfyui-embedded-docs==0.3.1
torch
torchsde
From 85c4b4ae262c2de360891dd23c6504da2f5a6014 Mon Sep 17 00:00:00 2001
From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com>
Date: Mon, 8 Dec 2025 11:27:02 +0200
Subject: [PATCH 48/81] chore: replace imports of deprecated V1 classes
(#11127)
---
comfy_api_nodes/apis/veo_api.py | 2 +-
comfy_api_nodes/nodes_gemini.py | 19 ++++++++++---------
comfy_api_nodes/nodes_ltxv.py | 17 +++++++----------
comfy_api_nodes/nodes_moonvalley.py | 19 ++++++++-----------
comfy_api_nodes/nodes_runway.py | 29 +++++++++++++----------------
comfy_api_nodes/nodes_veo2.py | 12 +++++-------
comfy_extras/nodes_video.py | 27 +++++++++++----------------
7 files changed, 55 insertions(+), 70 deletions(-)
diff --git a/comfy_api_nodes/apis/veo_api.py b/comfy_api_nodes/apis/veo_api.py
index 8328d1aa4..23ca725b7 100644
--- a/comfy_api_nodes/apis/veo_api.py
+++ b/comfy_api_nodes/apis/veo_api.py
@@ -85,7 +85,7 @@ class Response1(BaseModel):
raiMediaFilteredReasons: Optional[list[str]] = Field(
None, description='Reasons why media was filtered by responsible AI policies'
)
- videos: Optional[list[Video]] = None
+ videos: Optional[list[Video]] = Field(None)
class VeoGenVidPollResponse(BaseModel):
diff --git a/comfy_api_nodes/nodes_gemini.py b/comfy_api_nodes/nodes_gemini.py
index 08f7b0f64..0b7422ef7 100644
--- a/comfy_api_nodes/nodes_gemini.py
+++ b/comfy_api_nodes/nodes_gemini.py
@@ -13,8 +13,7 @@ import torch
from typing_extensions import override
import folder_paths
-from comfy_api.latest import IO, ComfyExtension, Input
-from comfy_api.util import VideoCodec, VideoContainer
+from comfy_api.latest import IO, ComfyExtension, Input, Types
from comfy_api_nodes.apis.gemini_api import (
GeminiContent,
GeminiFileData,
@@ -68,7 +67,7 @@ class GeminiImageModel(str, Enum):
async def create_image_parts(
cls: type[IO.ComfyNode],
- images: torch.Tensor,
+ images: Input.Image,
image_limit: int = 0,
) -> list[GeminiPart]:
image_parts: list[GeminiPart] = []
@@ -154,8 +153,8 @@ def get_text_from_response(response: GeminiGenerateContentResponse) -> str:
return "\n".join([part.text for part in parts])
-def get_image_from_response(response: GeminiGenerateContentResponse) -> torch.Tensor:
- image_tensors: list[torch.Tensor] = []
+def get_image_from_response(response: GeminiGenerateContentResponse) -> Input.Image:
+ image_tensors: list[Input.Image] = []
parts = get_parts_by_type(response, "image/png")
for part in parts:
image_data = base64.b64decode(part.inlineData.data)
@@ -293,7 +292,9 @@ class GeminiNode(IO.ComfyNode):
def create_video_parts(cls, video_input: Input.Video) -> list[GeminiPart]:
"""Convert video input to Gemini API compatible parts."""
- base_64_string = video_to_base64_string(video_input, container_format=VideoContainer.MP4, codec=VideoCodec.H264)
+ base_64_string = video_to_base64_string(
+ video_input, container_format=Types.VideoContainer.MP4, codec=Types.VideoCodec.H264
+ )
return [
GeminiPart(
inlineData=GeminiInlineData(
@@ -343,7 +344,7 @@ class GeminiNode(IO.ComfyNode):
prompt: str,
model: str,
seed: int,
- images: torch.Tensor | None = None,
+ images: Input.Image | None = None,
audio: Input.Audio | None = None,
video: Input.Video | None = None,
files: list[GeminiPart] | None = None,
@@ -542,7 +543,7 @@ class GeminiImage(IO.ComfyNode):
prompt: str,
model: str,
seed: int,
- images: torch.Tensor | None = None,
+ images: Input.Image | None = None,
files: list[GeminiPart] | None = None,
aspect_ratio: str = "auto",
response_modalities: str = "IMAGE+TEXT",
@@ -662,7 +663,7 @@ class GeminiImage2(IO.ComfyNode):
aspect_ratio: str,
resolution: str,
response_modalities: str,
- images: torch.Tensor | None = None,
+ images: Input.Image | None = None,
files: list[GeminiPart] | None = None,
) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=True, min_length=1)
diff --git a/comfy_api_nodes/nodes_ltxv.py b/comfy_api_nodes/nodes_ltxv.py
index 0b757a62b..7e61560dc 100644
--- a/comfy_api_nodes/nodes_ltxv.py
+++ b/comfy_api_nodes/nodes_ltxv.py
@@ -1,12 +1,9 @@
from io import BytesIO
-from typing import Optional
-import torch
from pydantic import BaseModel, Field
from typing_extensions import override
-from comfy_api.input_impl import VideoFromFile
-from comfy_api.latest import IO, ComfyExtension
+from comfy_api.latest import IO, ComfyExtension, Input, InputImpl
from comfy_api_nodes.util import (
ApiEndpoint,
get_number_of_images,
@@ -26,9 +23,9 @@ class ExecuteTaskRequest(BaseModel):
model: str = Field(...)
duration: int = Field(...)
resolution: str = Field(...)
- fps: Optional[int] = Field(25)
- generate_audio: Optional[bool] = Field(True)
- image_uri: Optional[str] = Field(None)
+ fps: int | None = Field(25)
+ generate_audio: bool | None = Field(True)
+ image_uri: str | None = Field(None)
class TextToVideoNode(IO.ComfyNode):
@@ -103,7 +100,7 @@ class TextToVideoNode(IO.ComfyNode):
as_binary=True,
max_retries=1,
)
- return IO.NodeOutput(VideoFromFile(BytesIO(response)))
+ return IO.NodeOutput(InputImpl.VideoFromFile(BytesIO(response)))
class ImageToVideoNode(IO.ComfyNode):
@@ -153,7 +150,7 @@ class ImageToVideoNode(IO.ComfyNode):
@classmethod
async def execute(
cls,
- image: torch.Tensor,
+ image: Input.Image,
model: str,
prompt: str,
duration: int,
@@ -183,7 +180,7 @@ class ImageToVideoNode(IO.ComfyNode):
as_binary=True,
max_retries=1,
)
- return IO.NodeOutput(VideoFromFile(BytesIO(response)))
+ return IO.NodeOutput(InputImpl.VideoFromFile(BytesIO(response)))
class LtxvApiExtension(ComfyExtension):
diff --git a/comfy_api_nodes/nodes_moonvalley.py b/comfy_api_nodes/nodes_moonvalley.py
index 7c31d95b3..2771e4790 100644
--- a/comfy_api_nodes/nodes_moonvalley.py
+++ b/comfy_api_nodes/nodes_moonvalley.py
@@ -1,11 +1,8 @@
import logging
-from typing import Optional
-import torch
from typing_extensions import override
-from comfy_api.input import VideoInput
-from comfy_api.latest import IO, ComfyExtension
+from comfy_api.latest import IO, ComfyExtension, Input
from comfy_api_nodes.apis import (
MoonvalleyPromptResponse,
MoonvalleyTextToVideoInferenceParams,
@@ -61,7 +58,7 @@ def validate_task_creation_response(response) -> None:
raise RuntimeError(error_msg)
-def validate_video_to_video_input(video: VideoInput) -> VideoInput:
+def validate_video_to_video_input(video: Input.Video) -> Input.Video:
"""
Validates and processes video input for Moonvalley Video-to-Video generation.
@@ -82,7 +79,7 @@ def validate_video_to_video_input(video: VideoInput) -> VideoInput:
return _validate_and_trim_duration(video)
-def _get_video_dimensions(video: VideoInput) -> tuple[int, int]:
+def _get_video_dimensions(video: Input.Video) -> tuple[int, int]:
"""Extracts video dimensions with error handling."""
try:
return video.get_dimensions()
@@ -106,7 +103,7 @@ def _validate_video_dimensions(width: int, height: int) -> None:
raise ValueError(f"Resolution {width}x{height} not supported. Supported: {supported_list}")
-def _validate_and_trim_duration(video: VideoInput) -> VideoInput:
+def _validate_and_trim_duration(video: Input.Video) -> Input.Video:
"""Validates video duration and trims to 5 seconds if needed."""
duration = video.get_duration()
_validate_minimum_duration(duration)
@@ -119,7 +116,7 @@ def _validate_minimum_duration(duration: float) -> None:
raise ValueError("Input video must be at least 5 seconds long.")
-def _trim_if_too_long(video: VideoInput, duration: float) -> VideoInput:
+def _trim_if_too_long(video: Input.Video, duration: float) -> Input.Video:
"""Trims video to 5 seconds if longer."""
if duration > 5:
return trim_video(video, 5)
@@ -241,7 +238,7 @@ class MoonvalleyImg2VideoNode(IO.ComfyNode):
@classmethod
async def execute(
cls,
- image: torch.Tensor,
+ image: Input.Image,
prompt: str,
negative_prompt: str,
resolution: str,
@@ -362,9 +359,9 @@ class MoonvalleyVideo2VideoNode(IO.ComfyNode):
prompt: str,
negative_prompt: str,
seed: int,
- video: Optional[VideoInput] = None,
+ video: Input.Video | None = None,
control_type: str = "Motion Transfer",
- motion_intensity: Optional[int] = 100,
+ motion_intensity: int | None = 100,
steps=33,
prompt_adherence=4.5,
) -> IO.NodeOutput:
diff --git a/comfy_api_nodes/nodes_runway.py b/comfy_api_nodes/nodes_runway.py
index 2fdafbbfe..3c55039c9 100644
--- a/comfy_api_nodes/nodes_runway.py
+++ b/comfy_api_nodes/nodes_runway.py
@@ -11,12 +11,11 @@ User Guides:
"""
-from typing import Union, Optional
-from typing_extensions import override
from enum import Enum
-import torch
+from typing_extensions import override
+from comfy_api.latest import IO, ComfyExtension, Input, InputImpl
from comfy_api_nodes.apis import (
RunwayImageToVideoRequest,
RunwayImageToVideoResponse,
@@ -44,8 +43,6 @@ from comfy_api_nodes.util import (
sync_op,
poll_op,
)
-from comfy_api.input_impl import VideoFromFile
-from comfy_api.latest import ComfyExtension, IO
PATH_IMAGE_TO_VIDEO = "/proxy/runway/image_to_video"
PATH_TEXT_TO_IMAGE = "/proxy/runway/text_to_image"
@@ -80,7 +77,7 @@ class RunwayGen3aAspectRatio(str, Enum):
field_1280_768 = "1280:768"
-def get_video_url_from_task_status(response: TaskStatusResponse) -> Union[str, None]:
+def get_video_url_from_task_status(response: TaskStatusResponse) -> str | None:
"""Returns the video URL from the task status response if it exists."""
if hasattr(response, "output") and len(response.output) > 0:
return response.output[0]
@@ -89,13 +86,13 @@ def get_video_url_from_task_status(response: TaskStatusResponse) -> Union[str, N
def extract_progress_from_task_status(
response: TaskStatusResponse,
-) -> Union[float, None]:
+) -> float | None:
if hasattr(response, "progress") and response.progress is not None:
return response.progress * 100
return None
-def get_image_url_from_task_status(response: TaskStatusResponse) -> Union[str, None]:
+def get_image_url_from_task_status(response: TaskStatusResponse) -> str | None:
"""Returns the image URL from the task status response if it exists."""
if hasattr(response, "output") and len(response.output) > 0:
return response.output[0]
@@ -103,7 +100,7 @@ def get_image_url_from_task_status(response: TaskStatusResponse) -> Union[str, N
async def get_response(
- cls: type[IO.ComfyNode], task_id: str, estimated_duration: Optional[int] = None
+ cls: type[IO.ComfyNode], task_id: str, estimated_duration: int | None = None
) -> TaskStatusResponse:
"""Poll the task status until it is finished then get the response."""
return await poll_op(
@@ -119,8 +116,8 @@ async def get_response(
async def generate_video(
cls: type[IO.ComfyNode],
request: RunwayImageToVideoRequest,
- estimated_duration: Optional[int] = None,
-) -> VideoFromFile:
+ estimated_duration: int | None = None,
+) -> InputImpl.VideoFromFile:
initial_response = await sync_op(
cls,
endpoint=ApiEndpoint(path=PATH_IMAGE_TO_VIDEO, method="POST"),
@@ -193,7 +190,7 @@ class RunwayImageToVideoNodeGen3a(IO.ComfyNode):
async def execute(
cls,
prompt: str,
- start_frame: torch.Tensor,
+ start_frame: Input.Image,
duration: str,
ratio: str,
seed: int,
@@ -283,7 +280,7 @@ class RunwayImageToVideoNodeGen4(IO.ComfyNode):
async def execute(
cls,
prompt: str,
- start_frame: torch.Tensor,
+ start_frame: Input.Image,
duration: str,
ratio: str,
seed: int,
@@ -381,8 +378,8 @@ class RunwayFirstLastFrameNode(IO.ComfyNode):
async def execute(
cls,
prompt: str,
- start_frame: torch.Tensor,
- end_frame: torch.Tensor,
+ start_frame: Input.Image,
+ end_frame: Input.Image,
duration: str,
ratio: str,
seed: int,
@@ -467,7 +464,7 @@ class RunwayTextToImageNode(IO.ComfyNode):
cls,
prompt: str,
ratio: str,
- reference_image: Optional[torch.Tensor] = None,
+ reference_image: Input.Image | None = None,
) -> IO.NodeOutput:
validate_string(prompt, min_length=1)
diff --git a/comfy_api_nodes/nodes_veo2.py b/comfy_api_nodes/nodes_veo2.py
index a54dc13ab..e165b8380 100644
--- a/comfy_api_nodes/nodes_veo2.py
+++ b/comfy_api_nodes/nodes_veo2.py
@@ -1,11 +1,9 @@
import base64
from io import BytesIO
-import torch
from typing_extensions import override
-from comfy_api.input_impl.video_types import VideoFromFile
-from comfy_api.latest import IO, ComfyExtension
+from comfy_api.latest import IO, ComfyExtension, Input, InputImpl
from comfy_api_nodes.apis.veo_api import (
VeoGenVidPollRequest,
VeoGenVidPollResponse,
@@ -232,7 +230,7 @@ class VeoVideoGenerationNode(IO.ComfyNode):
# Check if video is provided as base64 or URL
if hasattr(video, "bytesBase64Encoded") and video.bytesBase64Encoded:
- return IO.NodeOutput(VideoFromFile(BytesIO(base64.b64decode(video.bytesBase64Encoded))))
+ return IO.NodeOutput(InputImpl.VideoFromFile(BytesIO(base64.b64decode(video.bytesBase64Encoded))))
if hasattr(video, "gcsUri") and video.gcsUri:
return IO.NodeOutput(await download_url_to_video_output(video.gcsUri))
@@ -431,8 +429,8 @@ class Veo3FirstLastFrameNode(IO.ComfyNode):
aspect_ratio: str,
duration: int,
seed: int,
- first_frame: torch.Tensor,
- last_frame: torch.Tensor,
+ first_frame: Input.Image,
+ last_frame: Input.Image,
model: str,
generate_audio: bool,
):
@@ -493,7 +491,7 @@ class Veo3FirstLastFrameNode(IO.ComfyNode):
if response.videos:
video = response.videos[0]
if video.bytesBase64Encoded:
- return IO.NodeOutput(VideoFromFile(BytesIO(base64.b64decode(video.bytesBase64Encoded))))
+ return IO.NodeOutput(InputImpl.VideoFromFile(BytesIO(base64.b64decode(video.bytesBase64Encoded))))
if video.gcsUri:
return IO.NodeOutput(await download_url_to_video_output(video.gcsUri))
raise Exception("Video returned but no data or URL was provided")
diff --git a/comfy_extras/nodes_video.py b/comfy_extras/nodes_video.py
index 6cf6e39bf..c609e03da 100644
--- a/comfy_extras/nodes_video.py
+++ b/comfy_extras/nodes_video.py
@@ -8,10 +8,7 @@ import json
from typing import Optional
from typing_extensions import override
from fractions import Fraction
-from comfy_api.input import AudioInput, ImageInput, VideoInput
-from comfy_api.input_impl import VideoFromComponents, VideoFromFile
-from comfy_api.util import VideoCodec, VideoComponents, VideoContainer
-from comfy_api.latest import ComfyExtension, io, ui
+from comfy_api.latest import ComfyExtension, io, ui, Input, InputImpl, Types
from comfy.cli_args import args
class SaveWEBM(io.ComfyNode):
@@ -28,7 +25,6 @@ class SaveWEBM(io.ComfyNode):
io.Float.Input("fps", default=24.0, min=0.01, max=1000.0, step=0.01),
io.Float.Input("crf", default=32.0, min=0, max=63.0, step=1, tooltip="Higher crf means lower quality with a smaller file size, lower crf means higher quality higher filesize."),
],
- outputs=[],
hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo],
is_output_node=True,
)
@@ -79,16 +75,15 @@ class SaveVideo(io.ComfyNode):
inputs=[
io.Video.Input("video", tooltip="The video to save."),
io.String.Input("filename_prefix", default="video/ComfyUI", tooltip="The prefix for the file to save. This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes."),
- io.Combo.Input("format", options=VideoContainer.as_input(), default="auto", tooltip="The format to save the video as."),
- io.Combo.Input("codec", options=VideoCodec.as_input(), default="auto", tooltip="The codec to use for the video."),
+ io.Combo.Input("format", options=Types.VideoContainer.as_input(), default="auto", tooltip="The format to save the video as."),
+ io.Combo.Input("codec", options=Types.VideoCodec.as_input(), default="auto", tooltip="The codec to use for the video."),
],
- outputs=[],
hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo],
is_output_node=True,
)
@classmethod
- def execute(cls, video: VideoInput, filename_prefix, format: str, codec) -> io.NodeOutput:
+ def execute(cls, video: Input.Video, filename_prefix, format: str, codec) -> io.NodeOutput:
width, height = video.get_dimensions()
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(
filename_prefix,
@@ -105,10 +100,10 @@ class SaveVideo(io.ComfyNode):
metadata["prompt"] = cls.hidden.prompt
if len(metadata) > 0:
saved_metadata = metadata
- file = f"{filename}_{counter:05}_.{VideoContainer.get_extension(format)}"
+ file = f"{filename}_{counter:05}_.{Types.VideoContainer.get_extension(format)}"
video.save_to(
os.path.join(full_output_folder, file),
- format=VideoContainer(format),
+ format=Types.VideoContainer(format),
codec=codec,
metadata=saved_metadata
)
@@ -135,9 +130,9 @@ class CreateVideo(io.ComfyNode):
)
@classmethod
- def execute(cls, images: ImageInput, fps: float, audio: Optional[AudioInput] = None) -> io.NodeOutput:
+ def execute(cls, images: Input.Image, fps: float, audio: Optional[Input.Audio] = None) -> io.NodeOutput:
return io.NodeOutput(
- VideoFromComponents(VideoComponents(images=images, audio=audio, frame_rate=Fraction(fps)))
+ InputImpl.VideoFromComponents(Types.VideoComponents(images=images, audio=audio, frame_rate=Fraction(fps)))
)
class GetVideoComponents(io.ComfyNode):
@@ -159,11 +154,11 @@ class GetVideoComponents(io.ComfyNode):
)
@classmethod
- def execute(cls, video: VideoInput) -> io.NodeOutput:
+ def execute(cls, video: Input.Video) -> io.NodeOutput:
components = video.get_components()
-
return io.NodeOutput(components.images, components.audio, float(components.frame_rate))
+
class LoadVideo(io.ComfyNode):
@classmethod
def define_schema(cls):
@@ -185,7 +180,7 @@ class LoadVideo(io.ComfyNode):
@classmethod
def execute(cls, file) -> io.NodeOutput:
video_path = folder_paths.get_annotated_filepath(file)
- return io.NodeOutput(VideoFromFile(video_path))
+ return io.NodeOutput(InputImpl.VideoFromFile(video_path))
@classmethod
def fingerprint_inputs(s, file):
From c3c6313fc7b24a5811efde7cfe10b7cbbea52663 Mon Sep 17 00:00:00 2001
From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com>
Date: Mon, 8 Dec 2025 11:28:17 +0200
Subject: [PATCH 49/81] Added "system_prompt" input to Gemini nodes (#11177)
---
comfy_api_nodes/apis/gemini_api.py | 10 +-----
comfy_api_nodes/nodes_gemini.py | 52 ++++++++++++++++++++++++++++--
2 files changed, 51 insertions(+), 11 deletions(-)
diff --git a/comfy_api_nodes/apis/gemini_api.py b/comfy_api_nodes/apis/gemini_api.py
index a380ecc86..f8edc38c9 100644
--- a/comfy_api_nodes/apis/gemini_api.py
+++ b/comfy_api_nodes/apis/gemini_api.py
@@ -84,15 +84,7 @@ class GeminiSystemInstructionContent(BaseModel):
description="A list of ordered parts that make up a single message. "
"Different parts may have different IANA MIME types.",
)
- role: GeminiRole = Field(
- ...,
- description="The identity of the entity that creates the message. "
- "The following values are supported: "
- "user: This indicates that the message is sent by a real person, typically a user-generated message. "
- "model: This indicates that the message is generated by the model. "
- "The model value is used to insert messages from model into the conversation during multi-turn conversations. "
- "For non-multi-turn conversations, this field can be left blank or unset.",
- )
+ role: GeminiRole | None = Field(..., description="The role field of systemInstruction may be ignored.")
class GeminiFunctionDeclaration(BaseModel):
diff --git a/comfy_api_nodes/nodes_gemini.py b/comfy_api_nodes/nodes_gemini.py
index 0b7422ef7..ad0f4b4d1 100644
--- a/comfy_api_nodes/nodes_gemini.py
+++ b/comfy_api_nodes/nodes_gemini.py
@@ -26,6 +26,8 @@ from comfy_api_nodes.apis.gemini_api import (
GeminiMimeType,
GeminiPart,
GeminiRole,
+ GeminiSystemInstructionContent,
+ GeminiTextPart,
Modality,
)
from comfy_api_nodes.util import (
@@ -42,6 +44,14 @@ from comfy_api_nodes.util import (
GEMINI_BASE_ENDPOINT = "/proxy/vertexai/gemini"
GEMINI_MAX_INPUT_FILE_SIZE = 20 * 1024 * 1024 # 20 MB
+GEMINI_IMAGE_SYS_PROMPT = (
+ "You are an expert image-generation engine. You must ALWAYS produce an image.\n"
+ "Interpret all user input—regardless of "
+ "format, intent, or abstraction—as literal visual directives for image composition.\n"
+ "If a prompt is conversational or lacks specific visual details, "
+ "you must creatively invent a concrete visual scenario that depicts the concept.\n"
+ "Prioritize generating the visual representation above any text, formatting, or conversational requests."
+)
class GeminiModel(str, Enum):
@@ -276,6 +286,13 @@ class GeminiNode(IO.ComfyNode):
tooltip="Optional file(s) to use as context for the model. "
"Accepts inputs from the Gemini Generate Content Input Files node.",
),
+ IO.String.Input(
+ "system_prompt",
+ multiline=True,
+ default="",
+ optional=True,
+ tooltip="Foundational instructions that dictate an AI's behavior.",
+ ),
],
outputs=[
IO.String.Output(),
@@ -348,6 +365,7 @@ class GeminiNode(IO.ComfyNode):
audio: Input.Audio | None = None,
video: Input.Video | None = None,
files: list[GeminiPart] | None = None,
+ system_prompt: str = "",
) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=False)
@@ -364,7 +382,10 @@ class GeminiNode(IO.ComfyNode):
if files is not None:
parts.extend(files)
- # Create response
+ gemini_system_prompt = None
+ if system_prompt:
+ gemini_system_prompt = GeminiSystemInstructionContent(parts=[GeminiTextPart(text=system_prompt)], role=None)
+
response = await sync_op(
cls,
endpoint=ApiEndpoint(path=f"{GEMINI_BASE_ENDPOINT}/{model}", method="POST"),
@@ -374,7 +395,8 @@ class GeminiNode(IO.ComfyNode):
role=GeminiRole.user,
parts=parts,
)
- ]
+ ],
+ systemInstruction=gemini_system_prompt,
),
response_model=GeminiGenerateContentResponse,
price_extractor=calculate_tokens_price,
@@ -524,6 +546,13 @@ class GeminiImage(IO.ComfyNode):
"'IMAGE+TEXT' to return both the generated image and a text response.",
optional=True,
),
+ IO.String.Input(
+ "system_prompt",
+ multiline=True,
+ default=GEMINI_IMAGE_SYS_PROMPT,
+ optional=True,
+ tooltip="Foundational instructions that dictate an AI's behavior.",
+ ),
],
outputs=[
IO.Image.Output(),
@@ -547,6 +576,7 @@ class GeminiImage(IO.ComfyNode):
files: list[GeminiPart] | None = None,
aspect_ratio: str = "auto",
response_modalities: str = "IMAGE+TEXT",
+ system_prompt: str = "",
) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=True, min_length=1)
parts: list[GeminiPart] = [GeminiPart(text=prompt)]
@@ -560,6 +590,10 @@ class GeminiImage(IO.ComfyNode):
if files is not None:
parts.extend(files)
+ gemini_system_prompt = None
+ if system_prompt:
+ gemini_system_prompt = GeminiSystemInstructionContent(parts=[GeminiTextPart(text=system_prompt)], role=None)
+
response = await sync_op(
cls,
endpoint=ApiEndpoint(path=f"{GEMINI_BASE_ENDPOINT}/{model}", method="POST"),
@@ -571,6 +605,7 @@ class GeminiImage(IO.ComfyNode):
responseModalities=(["IMAGE"] if response_modalities == "IMAGE" else ["TEXT", "IMAGE"]),
imageConfig=None if aspect_ratio == "auto" else image_config,
),
+ systemInstruction=gemini_system_prompt,
),
response_model=GeminiGenerateContentResponse,
price_extractor=calculate_tokens_price,
@@ -641,6 +676,13 @@ class GeminiImage2(IO.ComfyNode):
tooltip="Optional file(s) to use as context for the model. "
"Accepts inputs from the Gemini Generate Content Input Files node.",
),
+ IO.String.Input(
+ "system_prompt",
+ multiline=True,
+ default=GEMINI_IMAGE_SYS_PROMPT,
+ optional=True,
+ tooltip="Foundational instructions that dictate an AI's behavior.",
+ ),
],
outputs=[
IO.Image.Output(),
@@ -665,6 +707,7 @@ class GeminiImage2(IO.ComfyNode):
response_modalities: str,
images: Input.Image | None = None,
files: list[GeminiPart] | None = None,
+ system_prompt: str = "",
) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=True, min_length=1)
@@ -680,6 +723,10 @@ class GeminiImage2(IO.ComfyNode):
if aspect_ratio != "auto":
image_config.aspectRatio = aspect_ratio
+ gemini_system_prompt = None
+ if system_prompt:
+ gemini_system_prompt = GeminiSystemInstructionContent(parts=[GeminiTextPart(text=system_prompt)], role=None)
+
response = await sync_op(
cls,
ApiEndpoint(path=f"{GEMINI_BASE_ENDPOINT}/{model}", method="POST"),
@@ -691,6 +738,7 @@ class GeminiImage2(IO.ComfyNode):
responseModalities=(["IMAGE"] if response_modalities == "IMAGE" else ["TEXT", "IMAGE"]),
imageConfig=image_config,
),
+ systemInstruction=gemini_system_prompt,
),
response_model=GeminiGenerateContentResponse,
price_extractor=calculate_tokens_price,
From fd271dedfde6e192a1f1a025521070876e89e04a Mon Sep 17 00:00:00 2001
From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com>
Date: Mon, 8 Dec 2025 11:33:46 +0200
Subject: [PATCH 50/81] [API Nodes] add support for seedance-1-0-pro-fast model
(#10947)
* feat(api-nodes): add support for seedance-1-0-pro-fast model
* feat(api-nodes): add support for seedream-4.5 model
---
comfy_api_nodes/apis/bytedance_api.py | 144 +++++++++++++++
comfy_api_nodes/nodes_bytedance.py | 255 ++++++--------------------
2 files changed, 196 insertions(+), 203 deletions(-)
create mode 100644 comfy_api_nodes/apis/bytedance_api.py
diff --git a/comfy_api_nodes/apis/bytedance_api.py b/comfy_api_nodes/apis/bytedance_api.py
new file mode 100644
index 000000000..77cd76f9b
--- /dev/null
+++ b/comfy_api_nodes/apis/bytedance_api.py
@@ -0,0 +1,144 @@
+from typing import Literal
+
+from pydantic import BaseModel, Field
+
+
+class Text2ImageTaskCreationRequest(BaseModel):
+ model: str = Field(...)
+ prompt: str = Field(...)
+ response_format: str | None = Field("url")
+ size: str | None = Field(None)
+ seed: int | None = Field(0, ge=0, le=2147483647)
+ guidance_scale: float | None = Field(..., ge=1.0, le=10.0)
+ watermark: bool | None = Field(True)
+
+
+class Image2ImageTaskCreationRequest(BaseModel):
+ model: str = Field(...)
+ prompt: str = Field(...)
+ response_format: str | None = Field("url")
+ image: str = Field(..., description="Base64 encoded string or image URL")
+ size: str | None = Field("adaptive")
+ seed: int | None = Field(..., ge=0, le=2147483647)
+ guidance_scale: float | None = Field(..., ge=1.0, le=10.0)
+ watermark: bool | None = Field(True)
+
+
+class Seedream4Options(BaseModel):
+ max_images: int = Field(15)
+
+
+class Seedream4TaskCreationRequest(BaseModel):
+ model: str = Field(...)
+ prompt: str = Field(...)
+ response_format: str = Field("url")
+ image: list[str] | None = Field(None, description="Image URLs")
+ size: str = Field(...)
+ seed: int = Field(..., ge=0, le=2147483647)
+ sequential_image_generation: str = Field("disabled")
+ sequential_image_generation_options: Seedream4Options = Field(Seedream4Options(max_images=15))
+ watermark: bool = Field(True)
+
+
+class ImageTaskCreationResponse(BaseModel):
+ model: str = Field(...)
+ created: int = Field(..., description="Unix timestamp (in seconds) indicating time when the request was created.")
+ data: list = Field([], description="Contains information about the generated image(s).")
+ error: dict = Field({}, description="Contains `code` and `message` fields in case of error.")
+
+
+class TaskTextContent(BaseModel):
+ type: str = Field("text")
+ text: str = Field(...)
+
+
+class TaskImageContentUrl(BaseModel):
+ url: str = Field(...)
+
+
+class TaskImageContent(BaseModel):
+ type: str = Field("image_url")
+ image_url: TaskImageContentUrl = Field(...)
+ role: Literal["first_frame", "last_frame", "reference_image"] | None = Field(None)
+
+
+class Text2VideoTaskCreationRequest(BaseModel):
+ model: str = Field(...)
+ content: list[TaskTextContent] = Field(..., min_length=1)
+
+
+class Image2VideoTaskCreationRequest(BaseModel):
+ model: str = Field(...)
+ content: list[TaskTextContent | TaskImageContent] = Field(..., min_length=2)
+
+
+class TaskCreationResponse(BaseModel):
+ id: str = Field(...)
+
+
+class TaskStatusError(BaseModel):
+ code: str = Field(...)
+ message: str = Field(...)
+
+
+class TaskStatusResult(BaseModel):
+ video_url: str = Field(...)
+
+
+class TaskStatusResponse(BaseModel):
+ id: str = Field(...)
+ model: str = Field(...)
+ status: Literal["queued", "running", "cancelled", "succeeded", "failed"] = Field(...)
+ error: TaskStatusError | None = Field(None)
+ content: TaskStatusResult | None = Field(None)
+
+
+RECOMMENDED_PRESETS = [
+ ("1024x1024 (1:1)", 1024, 1024),
+ ("864x1152 (3:4)", 864, 1152),
+ ("1152x864 (4:3)", 1152, 864),
+ ("1280x720 (16:9)", 1280, 720),
+ ("720x1280 (9:16)", 720, 1280),
+ ("832x1248 (2:3)", 832, 1248),
+ ("1248x832 (3:2)", 1248, 832),
+ ("1512x648 (21:9)", 1512, 648),
+ ("2048x2048 (1:1)", 2048, 2048),
+ ("Custom", None, None),
+]
+
+RECOMMENDED_PRESETS_SEEDREAM_4 = [
+ ("2048x2048 (1:1)", 2048, 2048),
+ ("2304x1728 (4:3)", 2304, 1728),
+ ("1728x2304 (3:4)", 1728, 2304),
+ ("2560x1440 (16:9)", 2560, 1440),
+ ("1440x2560 (9:16)", 1440, 2560),
+ ("2496x1664 (3:2)", 2496, 1664),
+ ("1664x2496 (2:3)", 1664, 2496),
+ ("3024x1296 (21:9)", 3024, 1296),
+ ("4096x4096 (1:1)", 4096, 4096),
+ ("Custom", None, None),
+]
+
+# The time in this dictionary are given for 10 seconds duration.
+VIDEO_TASKS_EXECUTION_TIME = {
+ "seedance-1-0-lite-t2v-250428": {
+ "480p": 40,
+ "720p": 60,
+ "1080p": 90,
+ },
+ "seedance-1-0-lite-i2v-250428": {
+ "480p": 40,
+ "720p": 60,
+ "1080p": 90,
+ },
+ "seedance-1-0-pro-250528": {
+ "480p": 70,
+ "720p": 85,
+ "1080p": 115,
+ },
+ "seedance-1-0-pro-fast-251015": {
+ "480p": 50,
+ "720p": 65,
+ "1080p": 100,
+ },
+}
diff --git a/comfy_api_nodes/nodes_bytedance.py b/comfy_api_nodes/nodes_bytedance.py
index caced471e..57c0218d0 100644
--- a/comfy_api_nodes/nodes_bytedance.py
+++ b/comfy_api_nodes/nodes_bytedance.py
@@ -1,13 +1,27 @@
import logging
import math
-from enum import Enum
-from typing import Literal, Optional, Union
import torch
-from pydantic import BaseModel, Field
from typing_extensions import override
-from comfy_api.latest import IO, ComfyExtension
+from comfy_api.latest import IO, ComfyExtension, Input
+from comfy_api_nodes.apis.bytedance_api import (
+ RECOMMENDED_PRESETS,
+ RECOMMENDED_PRESETS_SEEDREAM_4,
+ VIDEO_TASKS_EXECUTION_TIME,
+ Image2ImageTaskCreationRequest,
+ Image2VideoTaskCreationRequest,
+ ImageTaskCreationResponse,
+ Seedream4Options,
+ Seedream4TaskCreationRequest,
+ TaskCreationResponse,
+ TaskImageContent,
+ TaskImageContentUrl,
+ TaskStatusResponse,
+ TaskTextContent,
+ Text2ImageTaskCreationRequest,
+ Text2VideoTaskCreationRequest,
+)
from comfy_api_nodes.util import (
ApiEndpoint,
download_url_to_image_tensor,
@@ -29,162 +43,6 @@ BYTEPLUS_TASK_ENDPOINT = "/proxy/byteplus/api/v3/contents/generations/tasks"
BYTEPLUS_TASK_STATUS_ENDPOINT = "/proxy/byteplus/api/v3/contents/generations/tasks" # + /{task_id}
-class Text2ImageModelName(str, Enum):
- seedream_3 = "seedream-3-0-t2i-250415"
-
-
-class Image2ImageModelName(str, Enum):
- seededit_3 = "seededit-3-0-i2i-250628"
-
-
-class Text2VideoModelName(str, Enum):
- seedance_1_pro = "seedance-1-0-pro-250528"
- seedance_1_lite = "seedance-1-0-lite-t2v-250428"
-
-
-class Image2VideoModelName(str, Enum):
- """note(August 31): Pro model only supports FirstFrame: https://docs.byteplus.com/en/docs/ModelArk/1520757"""
-
- seedance_1_pro = "seedance-1-0-pro-250528"
- seedance_1_lite = "seedance-1-0-lite-i2v-250428"
-
-
-class Text2ImageTaskCreationRequest(BaseModel):
- model: Text2ImageModelName = Text2ImageModelName.seedream_3
- prompt: str = Field(...)
- response_format: Optional[str] = Field("url")
- size: Optional[str] = Field(None)
- seed: Optional[int] = Field(0, ge=0, le=2147483647)
- guidance_scale: Optional[float] = Field(..., ge=1.0, le=10.0)
- watermark: Optional[bool] = Field(True)
-
-
-class Image2ImageTaskCreationRequest(BaseModel):
- model: Image2ImageModelName = Image2ImageModelName.seededit_3
- prompt: str = Field(...)
- response_format: Optional[str] = Field("url")
- image: str = Field(..., description="Base64 encoded string or image URL")
- size: Optional[str] = Field("adaptive")
- seed: Optional[int] = Field(..., ge=0, le=2147483647)
- guidance_scale: Optional[float] = Field(..., ge=1.0, le=10.0)
- watermark: Optional[bool] = Field(True)
-
-
-class Seedream4Options(BaseModel):
- max_images: int = Field(15)
-
-
-class Seedream4TaskCreationRequest(BaseModel):
- model: str = Field("seedream-4-0-250828")
- prompt: str = Field(...)
- response_format: str = Field("url")
- image: Optional[list[str]] = Field(None, description="Image URLs")
- size: str = Field(...)
- seed: int = Field(..., ge=0, le=2147483647)
- sequential_image_generation: str = Field("disabled")
- sequential_image_generation_options: Seedream4Options = Field(Seedream4Options(max_images=15))
- watermark: bool = Field(True)
-
-
-class ImageTaskCreationResponse(BaseModel):
- model: str = Field(...)
- created: int = Field(..., description="Unix timestamp (in seconds) indicating time when the request was created.")
- data: list = Field([], description="Contains information about the generated image(s).")
- error: dict = Field({}, description="Contains `code` and `message` fields in case of error.")
-
-
-class TaskTextContent(BaseModel):
- type: str = Field("text")
- text: str = Field(...)
-
-
-class TaskImageContentUrl(BaseModel):
- url: str = Field(...)
-
-
-class TaskImageContent(BaseModel):
- type: str = Field("image_url")
- image_url: TaskImageContentUrl = Field(...)
- role: Optional[Literal["first_frame", "last_frame", "reference_image"]] = Field(None)
-
-
-class Text2VideoTaskCreationRequest(BaseModel):
- model: Text2VideoModelName = Text2VideoModelName.seedance_1_pro
- content: list[TaskTextContent] = Field(..., min_length=1)
-
-
-class Image2VideoTaskCreationRequest(BaseModel):
- model: Image2VideoModelName = Image2VideoModelName.seedance_1_pro
- content: list[Union[TaskTextContent, TaskImageContent]] = Field(..., min_length=2)
-
-
-class TaskCreationResponse(BaseModel):
- id: str = Field(...)
-
-
-class TaskStatusError(BaseModel):
- code: str = Field(...)
- message: str = Field(...)
-
-
-class TaskStatusResult(BaseModel):
- video_url: str = Field(...)
-
-
-class TaskStatusResponse(BaseModel):
- id: str = Field(...)
- model: str = Field(...)
- status: Literal["queued", "running", "cancelled", "succeeded", "failed"] = Field(...)
- error: Optional[TaskStatusError] = Field(None)
- content: Optional[TaskStatusResult] = Field(None)
-
-
-RECOMMENDED_PRESETS = [
- ("1024x1024 (1:1)", 1024, 1024),
- ("864x1152 (3:4)", 864, 1152),
- ("1152x864 (4:3)", 1152, 864),
- ("1280x720 (16:9)", 1280, 720),
- ("720x1280 (9:16)", 720, 1280),
- ("832x1248 (2:3)", 832, 1248),
- ("1248x832 (3:2)", 1248, 832),
- ("1512x648 (21:9)", 1512, 648),
- ("2048x2048 (1:1)", 2048, 2048),
- ("Custom", None, None),
-]
-
-RECOMMENDED_PRESETS_SEEDREAM_4 = [
- ("2048x2048 (1:1)", 2048, 2048),
- ("2304x1728 (4:3)", 2304, 1728),
- ("1728x2304 (3:4)", 1728, 2304),
- ("2560x1440 (16:9)", 2560, 1440),
- ("1440x2560 (9:16)", 1440, 2560),
- ("2496x1664 (3:2)", 2496, 1664),
- ("1664x2496 (2:3)", 1664, 2496),
- ("3024x1296 (21:9)", 3024, 1296),
- ("4096x4096 (1:1)", 4096, 4096),
- ("Custom", None, None),
-]
-
-# The time in this dictionary are given for 10 seconds duration.
-VIDEO_TASKS_EXECUTION_TIME = {
- "seedance-1-0-lite-t2v-250428": {
- "480p": 40,
- "720p": 60,
- "1080p": 90,
- },
- "seedance-1-0-lite-i2v-250428": {
- "480p": 40,
- "720p": 60,
- "1080p": 90,
- },
- "seedance-1-0-pro-250528": {
- "480p": 70,
- "720p": 85,
- "1080p": 115,
- },
-}
-
-
def get_image_url_from_response(response: ImageTaskCreationResponse) -> str:
if response.error:
error_msg = f"ByteDance request failed. Code: {response.error['code']}, message: {response.error['message']}"
@@ -194,13 +52,6 @@ def get_image_url_from_response(response: ImageTaskCreationResponse) -> str:
return response.data[0]["url"]
-def get_video_url_from_task_status(response: TaskStatusResponse) -> Union[str, None]:
- """Returns the video URL from the task status response if it exists."""
- if hasattr(response, "content") and response.content:
- return response.content.video_url
- return None
-
-
class ByteDanceImageNode(IO.ComfyNode):
@classmethod
@@ -211,12 +62,7 @@ class ByteDanceImageNode(IO.ComfyNode):
category="api node/image/ByteDance",
description="Generate images using ByteDance models via api based on prompt",
inputs=[
- IO.Combo.Input(
- "model",
- options=Text2ImageModelName,
- default=Text2ImageModelName.seedream_3,
- tooltip="Model name",
- ),
+ IO.Combo.Input("model", options=["seedream-3-0-t2i-250415"]),
IO.String.Input(
"prompt",
multiline=True,
@@ -335,12 +181,7 @@ class ByteDanceImageEditNode(IO.ComfyNode):
category="api node/image/ByteDance",
description="Edit images using ByteDance models via api based on prompt",
inputs=[
- IO.Combo.Input(
- "model",
- options=Image2ImageModelName,
- default=Image2ImageModelName.seededit_3,
- tooltip="Model name",
- ),
+ IO.Combo.Input("model", options=["seededit-3-0-i2i-250628"]),
IO.Image.Input(
"image",
tooltip="The base image to edit",
@@ -394,7 +235,7 @@ class ByteDanceImageEditNode(IO.ComfyNode):
async def execute(
cls,
model: str,
- image: torch.Tensor,
+ image: Input.Image,
prompt: str,
seed: int,
guidance_scale: float,
@@ -434,7 +275,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
inputs=[
IO.Combo.Input(
"model",
- options=["seedream-4-0-250828"],
+ options=["seedream-4-5-251128", "seedream-4-0-250828"],
tooltip="Model name",
),
IO.String.Input(
@@ -459,7 +300,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
default=2048,
min=1024,
max=4096,
- step=64,
+ step=8,
tooltip="Custom width for image. Value is working only if `size_preset` is set to `Custom`",
optional=True,
),
@@ -468,7 +309,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
default=2048,
min=1024,
max=4096,
- step=64,
+ step=8,
tooltip="Custom height for image. Value is working only if `size_preset` is set to `Custom`",
optional=True,
),
@@ -532,7 +373,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
cls,
model: str,
prompt: str,
- image: torch.Tensor = None,
+ image: Input.Image | None = None,
size_preset: str = RECOMMENDED_PRESETS_SEEDREAM_4[0][0],
width: int = 2048,
height: int = 2048,
@@ -555,6 +396,18 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
raise ValueError(
f"Custom size out of range: {w}x{h}. " "Both width and height must be between 1024 and 4096 pixels."
)
+ out_num_pixels = w * h
+ mp_provided = out_num_pixels / 1_000_000.0
+ if "seedream-4-5" in model and out_num_pixels < 3686400:
+ raise ValueError(
+ f"Minimum image resolution that Seedream 4.5 can generate is 3.68MP, "
+ f"but {mp_provided:.2f}MP provided."
+ )
+ if "seedream-4-0" in model and out_num_pixels < 921600:
+ raise ValueError(
+ f"Minimum image resolution that the selected model can generate is 0.92MP, "
+ f"but {mp_provided:.2f}MP provided."
+ )
n_input_images = get_number_of_images(image) if image is not None else 0
if n_input_images > 10:
raise ValueError(f"Maximum of 10 reference images are supported, but {n_input_images} received.")
@@ -607,9 +460,8 @@ class ByteDanceTextToVideoNode(IO.ComfyNode):
inputs=[
IO.Combo.Input(
"model",
- options=Text2VideoModelName,
- default=Text2VideoModelName.seedance_1_pro,
- tooltip="Model name",
+ options=["seedance-1-0-pro-250528", "seedance-1-0-lite-t2v-250428", "seedance-1-0-pro-fast-251015"],
+ default="seedance-1-0-pro-fast-251015",
),
IO.String.Input(
"prompt",
@@ -714,9 +566,8 @@ class ByteDanceImageToVideoNode(IO.ComfyNode):
inputs=[
IO.Combo.Input(
"model",
- options=Image2VideoModelName,
- default=Image2VideoModelName.seedance_1_pro,
- tooltip="Model name",
+ options=["seedance-1-0-pro-250528", "seedance-1-0-lite-t2v-250428", "seedance-1-0-pro-fast-251015"],
+ default="seedance-1-0-pro-fast-251015",
),
IO.String.Input(
"prompt",
@@ -787,7 +638,7 @@ class ByteDanceImageToVideoNode(IO.ComfyNode):
cls,
model: str,
prompt: str,
- image: torch.Tensor,
+ image: Input.Image,
resolution: str,
aspect_ratio: str,
duration: int,
@@ -833,9 +684,8 @@ class ByteDanceFirstLastFrameNode(IO.ComfyNode):
inputs=[
IO.Combo.Input(
"model",
- options=[model.value for model in Image2VideoModelName],
- default=Image2VideoModelName.seedance_1_lite.value,
- tooltip="Model name",
+ options=["seedance-1-0-pro-250528", "seedance-1-0-lite-i2v-250428"],
+ default="seedance-1-0-lite-i2v-250428",
),
IO.String.Input(
"prompt",
@@ -910,8 +760,8 @@ class ByteDanceFirstLastFrameNode(IO.ComfyNode):
cls,
model: str,
prompt: str,
- first_frame: torch.Tensor,
- last_frame: torch.Tensor,
+ first_frame: Input.Image,
+ last_frame: Input.Image,
resolution: str,
aspect_ratio: str,
duration: int,
@@ -968,9 +818,8 @@ class ByteDanceImageReferenceNode(IO.ComfyNode):
inputs=[
IO.Combo.Input(
"model",
- options=[Image2VideoModelName.seedance_1_lite.value],
- default=Image2VideoModelName.seedance_1_lite.value,
- tooltip="Model name",
+ options=["seedance-1-0-pro-250528", "seedance-1-0-lite-i2v-250428"],
+ default="seedance-1-0-lite-i2v-250428",
),
IO.String.Input(
"prompt",
@@ -1034,7 +883,7 @@ class ByteDanceImageReferenceNode(IO.ComfyNode):
cls,
model: str,
prompt: str,
- images: torch.Tensor,
+ images: Input.Image,
resolution: str,
aspect_ratio: str,
duration: int,
@@ -1069,8 +918,8 @@ class ByteDanceImageReferenceNode(IO.ComfyNode):
async def process_video_task(
cls: type[IO.ComfyNode],
- payload: Union[Text2VideoTaskCreationRequest, Image2VideoTaskCreationRequest],
- estimated_duration: Optional[int],
+ payload: Text2VideoTaskCreationRequest | Image2VideoTaskCreationRequest,
+ estimated_duration: int | None,
) -> IO.NodeOutput:
initial_response = await sync_op(
cls,
@@ -1085,7 +934,7 @@ async def process_video_task(
estimated_duration=estimated_duration,
response_model=TaskStatusResponse,
)
- return IO.NodeOutput(await download_url_to_video_output(get_video_url_from_task_status(response)))
+ return IO.NodeOutput(await download_url_to_video_output(response.content.video_url))
def raise_if_text_params(prompt: str, text_params: list[str]) -> None:
From 8e889c535d1fc407bf27dbf8359eef9580f2ed60 Mon Sep 17 00:00:00 2001
From: dxqb <183307934+dxqb@users.noreply.github.com>
Date: Mon, 8 Dec 2025 21:17:26 +0100
Subject: [PATCH 51/81] Support "transformer." LoRA prefix for Z-Image (#11135)
---
comfy/lora.py | 1 +
1 file changed, 1 insertion(+)
diff --git a/comfy/lora.py b/comfy/lora.py
index e7202ce97..2ed0acb9d 100644
--- a/comfy/lora.py
+++ b/comfy/lora.py
@@ -320,6 +320,7 @@ def model_lora_keys_unet(model, key_map={}):
to = diffusers_keys[k]
key_lora = k[:-len(".weight")]
key_map["diffusion_model.{}".format(key_lora)] = to
+ key_map["transformer.{}".format(key_lora)] = to
key_map["lycoris_{}".format(key_lora.replace(".", "_"))] = to
if isinstance(model, comfy.model_base.Kandinsky5):
From 60ee574748209a17ade1c7524e228be2802d1589 Mon Sep 17 00:00:00 2001
From: rattus <46076784+rattus128@users.noreply.github.com>
Date: Tue, 9 Dec 2025 06:18:06 +1000
Subject: [PATCH 52/81] retune lowVramPatch VRAM accounting (#11173)
In the lowvram case, this now does its math in the model dtype in the
post de-quantization domain. Account for that. The patching was also
put back on the compute stream getting it off-peak so relax the
MATH_FACTOR to only x2 so get out of the worst-case assumption of
everything peaking at once.
---
comfy/model_patcher.py | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py
index 5b1ccb824..8b5edeb52 100644
--- a/comfy/model_patcher.py
+++ b/comfy/model_patcher.py
@@ -132,14 +132,14 @@ class LowVramPatch:
def __call__(self, weight):
return comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=weight.dtype)
-#The above patch logic may cast up the weight to fp32, and do math. Go with fp32 x 3
-LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR = 3
+LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR = 2
def low_vram_patch_estimate_vram(model, key):
weight, set_func, convert_func = get_key_weight(model, key)
if weight is None:
return 0
- return weight.numel() * torch.float32.itemsize * LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR
+ model_dtype = getattr(model, "manual_cast_dtype", torch.float32)
+ return weight.numel() * model_dtype.itemsize * LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR
def get_key_weight(model, key):
set_func = None
From 935493f6c186de8808508713a465d6bda75e5ce4 Mon Sep 17 00:00:00 2001
From: ComfyUI Wiki
Date: Tue, 9 Dec 2025 04:18:53 +0800
Subject: [PATCH 53/81] chore: update workflow templates to v0.7.54 (#11192)
---
requirements.txt | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/requirements.txt b/requirements.txt
index 12a7c1089..4bd4b21c3 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,5 +1,5 @@
comfyui-frontend-package==1.33.10
-comfyui-workflow-templates==0.7.51
+comfyui-workflow-templates==0.7.54
comfyui-embedded-docs==0.3.1
torch
torchsde
From 3b0368aa34182fc7c97de92d59b609c77138def2 Mon Sep 17 00:00:00 2001
From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com>
Date: Mon, 8 Dec 2025 14:38:36 -0800
Subject: [PATCH 54/81] Fix regression. (#11194)
---
comfy/model_patcher.py | 3 +++
1 file changed, 3 insertions(+)
diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py
index 8b5edeb52..a7d24ac13 100644
--- a/comfy/model_patcher.py
+++ b/comfy/model_patcher.py
@@ -139,6 +139,9 @@ def low_vram_patch_estimate_vram(model, key):
if weight is None:
return 0
model_dtype = getattr(model, "manual_cast_dtype", torch.float32)
+ if model_dtype is None:
+ model_dtype = weight.dtype
+
return weight.numel() * model_dtype.itemsize * LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR
def get_key_weight(model, key):
From d50f342c90802830c1178ad9d7f2783dc2821af1 Mon Sep 17 00:00:00 2001
From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com>
Date: Mon, 8 Dec 2025 20:20:04 -0800
Subject: [PATCH 55/81] Fix potential issue. (#11201)
---
comfy/model_patcher.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py
index a7d24ac13..2e8ce2613 100644
--- a/comfy/model_patcher.py
+++ b/comfy/model_patcher.py
@@ -923,7 +923,7 @@ class ModelPatcher:
patch_counter += 1
cast_weight = True
- if cast_weight:
+ if cast_weight and hasattr(m, "comfy_cast_weights"):
m.prev_comfy_cast_weights = m.comfy_cast_weights
m.comfy_cast_weights = True
m.comfy_patched_weights = False
From e136b6dbb0b08341388f5bf9a00b1fca29992eb3 Mon Sep 17 00:00:00 2001
From: rattus <46076784+rattus128@users.noreply.github.com>
Date: Tue, 9 Dec 2025 14:21:31 +1000
Subject: [PATCH 56/81] dequantization offload accounting (fixes Flux2 OOMs -
incl TEs) (#11171)
* make setattr safe for non existent attributes
Handle the case where the attribute doesnt exist by returning a static
sentinel (distinct from None). If the sentinel is passed in as the set
value, del the attr.
* Account for dequantization and type-casts in offload costs
When measuring the cost of offload, identify weights that need a type
change or dequantization and add the size of the conversion result
to the offload cost.
This is mutually exclusive with lowvram patches which already has
a large conservative estimate and wont overlap the dequant cost so\
dont double count.
* Set the compute type on CLIP MPs
So that the loader can know the size of weights for dequant accounting.
---
comfy/model_patcher.py | 19 +++++++++++++------
comfy/sd.py | 2 ++
comfy/utils.py | 9 +++++++--
3 files changed, 22 insertions(+), 8 deletions(-)
diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py
index 2e8ce2613..a486c2723 100644
--- a/comfy/model_patcher.py
+++ b/comfy/model_patcher.py
@@ -35,6 +35,7 @@ import comfy.model_management
import comfy.patcher_extension
import comfy.utils
from comfy.comfy_types import UnetWrapperFunction
+from comfy.quant_ops import QuantizedTensor
from comfy.patcher_extension import CallbacksMP, PatcherInjection, WrappersMP
@@ -665,12 +666,18 @@ class ModelPatcher:
module_mem = comfy.model_management.module_size(m)
module_offload_mem = module_mem
if hasattr(m, "comfy_cast_weights"):
- weight_key = "{}.weight".format(n)
- bias_key = "{}.bias".format(n)
- if weight_key in self.patches:
- module_offload_mem += low_vram_patch_estimate_vram(self.model, weight_key)
- if bias_key in self.patches:
- module_offload_mem += low_vram_patch_estimate_vram(self.model, bias_key)
+ def check_module_offload_mem(key):
+ if key in self.patches:
+ return low_vram_patch_estimate_vram(self.model, key)
+ model_dtype = getattr(self.model, "manual_cast_dtype", None)
+ weight, _, _ = get_key_weight(self.model, key)
+ if model_dtype is None or weight is None:
+ return 0
+ if (weight.dtype != model_dtype or isinstance(weight, QuantizedTensor)):
+ return weight.numel() * model_dtype.itemsize
+ return 0
+ module_offload_mem += check_module_offload_mem("{}.weight".format(n))
+ module_offload_mem += check_module_offload_mem("{}.bias".format(n))
loading.append((module_offload_mem, module_mem, n, m, params))
return loading
diff --git a/comfy/sd.py b/comfy/sd.py
index 754b1703d..a16f2d14f 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -127,6 +127,8 @@ class CLIP:
self.tokenizer = tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
self.patcher = comfy.model_patcher.ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)
+ #Match torch.float32 hardcode upcast in TE implemention
+ self.patcher.set_model_compute_dtype(torch.float32)
self.patcher.hook_mode = comfy.hooks.EnumHookMode.MinVram
self.patcher.is_clip = True
self.apply_hooks_to_conds = None
diff --git a/comfy/utils.py b/comfy/utils.py
index 89846bc95..9dc0d76ac 100644
--- a/comfy/utils.py
+++ b/comfy/utils.py
@@ -803,12 +803,17 @@ def safetensors_header(safetensors_path, max_size=100*1024*1024):
return None
return f.read(length_of_header)
+ATTR_UNSET={}
+
def set_attr(obj, attr, value):
attrs = attr.split(".")
for name in attrs[:-1]:
obj = getattr(obj, name)
- prev = getattr(obj, attrs[-1])
- setattr(obj, attrs[-1], value)
+ prev = getattr(obj, attrs[-1], ATTR_UNSET)
+ if value is ATTR_UNSET:
+ delattr(obj, attrs[-1])
+ else:
+ setattr(obj, attrs[-1], value)
return prev
def set_attr_param(obj, attr, value):
From cabc4d351ff620ece87f18019d98131ebcbdf1aa Mon Sep 17 00:00:00 2001
From: Christian Byrne
Date: Mon, 8 Dec 2025 20:22:02 -0800
Subject: [PATCH 57/81] bump comfyui-frontend-package to 1.33.13 (patch)
(#11200)
---
requirements.txt | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/requirements.txt b/requirements.txt
index 4bd4b21c3..11a7ac245 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,4 +1,4 @@
-comfyui-frontend-package==1.33.10
+comfyui-frontend-package==1.33.13
comfyui-workflow-templates==0.7.54
comfyui-embedded-docs==0.3.1
torch
From b9fb542703085c58f082b4a822329fb6670e8016 Mon Sep 17 00:00:00 2001
From: Lodestone
Date: Tue, 9 Dec 2025 11:33:29 +0700
Subject: [PATCH 58/81] add chroma-radiance-x0 mode (#11197)
---
comfy/ldm/chroma_radiance/model.py | 20 ++++++++++++++++++--
comfy/model_detection.py | 2 ++
2 files changed, 20 insertions(+), 2 deletions(-)
diff --git a/comfy/ldm/chroma_radiance/model.py b/comfy/ldm/chroma_radiance/model.py
index e643b4414..70d173889 100644
--- a/comfy/ldm/chroma_radiance/model.py
+++ b/comfy/ldm/chroma_radiance/model.py
@@ -37,7 +37,7 @@ class ChromaRadianceParams(ChromaParams):
nerf_final_head_type: str
# None means use the same dtype as the model.
nerf_embedder_dtype: Optional[torch.dtype]
-
+ use_x0: bool
class ChromaRadiance(Chroma):
"""
@@ -159,6 +159,9 @@ class ChromaRadiance(Chroma):
self.skip_dit = []
self.lite = False
+ if params.use_x0:
+ self.register_buffer("__x0__", torch.tensor([]))
+
@property
def _nerf_final_layer(self) -> nn.Module:
if self.params.nerf_final_head_type == "linear":
@@ -276,6 +279,12 @@ class ChromaRadiance(Chroma):
params_dict |= overrides
return params.__class__(**params_dict)
+ def _apply_x0_residual(self, predicted, noisy, timesteps):
+
+ # non zero during training to prevent 0 div
+ eps = 0.0
+ return (noisy - predicted) / (timesteps.view(-1,1,1,1) + eps)
+
def _forward(
self,
x: Tensor,
@@ -316,4 +325,11 @@ class ChromaRadiance(Chroma):
transformer_options,
attn_mask=kwargs.get("attention_mask", None),
)
- return self.forward_nerf(img, img_out, params)[:, :, :h, :w]
+
+ out = self.forward_nerf(img, img_out, params)[:, :, :h, :w]
+
+ # If x0 variant → v-pred, just return this instead
+ if hasattr(self, "__x0__"):
+ out = self._apply_x0_residual(out, img, timestep)
+ return out
+
diff --git a/comfy/model_detection.py b/comfy/model_detection.py
index 74c547427..19e6aa954 100644
--- a/comfy/model_detection.py
+++ b/comfy/model_detection.py
@@ -257,6 +257,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["nerf_tile_size"] = 512
dit_config["nerf_final_head_type"] = "conv" if f"{key_prefix}nerf_final_layer_conv.norm.scale" in state_dict_keys else "linear"
dit_config["nerf_embedder_dtype"] = torch.float32
+ if "__x0__" in state_dict_keys: # x0 pred
+ dit_config["use_x0"] = True
else:
dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
dit_config["yak_mlp"] = '{}double_blocks.0.img_mlp.gate_proj.weight'.format(key_prefix) in state_dict_keys
From 9d252f3b70c0e89cbb581e28bb1862593c4e5ceb Mon Sep 17 00:00:00 2001
From: rattus <46076784+rattus128@users.noreply.github.com>
Date: Tue, 9 Dec 2025 15:55:13 +1000
Subject: [PATCH 59/81] ops: delete dead code (#11204)
This became dead code in https://github.com/comfyanonymous/ComfyUI/pull/11069
---
comfy/ops.py | 8 --------
1 file changed, 8 deletions(-)
diff --git a/comfy/ops.py b/comfy/ops.py
index 35237c9f7..6f34d50fc 100644
--- a/comfy/ops.py
+++ b/comfy/ops.py
@@ -22,7 +22,6 @@ import comfy.model_management
from comfy.cli_args import args, PerformanceFeature
import comfy.float
import comfy.rmsnorm
-import contextlib
import json
def run_every_op():
@@ -94,13 +93,6 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
else:
offload_stream = None
- if offload_stream is not None:
- wf_context = offload_stream
- if hasattr(wf_context, "as_context"):
- wf_context = wf_context.as_context(offload_stream)
- else:
- wf_context = contextlib.nullcontext()
-
non_blocking = comfy.model_management.device_supports_non_blocking(device)
weight_has_function = len(s.weight_function) > 0
From e2a800e7ef225260c078ce484c75bb40161d9d94 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jukka=20Sepp=C3=A4nen?=
<40791699+kijai@users.noreply.github.com>
Date: Tue, 9 Dec 2025 23:59:16 +0200
Subject: [PATCH 60/81] Fix for HunyuanVideo1.5 meanflow distil (#11212)
---
comfy/ldm/hunyuan_video/model.py | 3 ++-
comfy/model_detection.py | 2 ++
2 files changed, 4 insertions(+), 1 deletion(-)
diff --git a/comfy/ldm/hunyuan_video/model.py b/comfy/ldm/hunyuan_video/model.py
index 2749c53f5..55ab550f8 100644
--- a/comfy/ldm/hunyuan_video/model.py
+++ b/comfy/ldm/hunyuan_video/model.py
@@ -43,6 +43,7 @@ class HunyuanVideoParams:
meanflow: bool
use_cond_type_embedding: bool
vision_in_dim: int
+ meanflow_sum: bool
class SelfAttentionRef(nn.Module):
@@ -317,7 +318,7 @@ class HunyuanVideo(nn.Module):
timesteps_r = transformer_options['sample_sigmas'][w[0] + 1]
timesteps_r = timesteps_r.unsqueeze(0).to(device=timesteps.device, dtype=timesteps.dtype)
vec_r = self.time_r_in(timestep_embedding(timesteps_r, 256, time_factor=1000.0).to(img.dtype))
- vec = (vec + vec_r) / 2
+ vec = (vec + vec_r) if self.params.meanflow_sum else (vec + vec_r) / 2
if ref_latent is not None:
ref_latent_ids = self.img_ids(ref_latent)
diff --git a/comfy/model_detection.py b/comfy/model_detection.py
index 19e6aa954..1f5d34bdd 100644
--- a/comfy/model_detection.py
+++ b/comfy/model_detection.py
@@ -180,8 +180,10 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["use_cond_type_embedding"] = False
if '{}vision_in.proj.0.weight'.format(key_prefix) in state_dict_keys:
dit_config["vision_in_dim"] = state_dict['{}vision_in.proj.0.weight'.format(key_prefix)].shape[0]
+ dit_config["meanflow_sum"] = True
else:
dit_config["vision_in_dim"] = None
+ dit_config["meanflow_sum"] = False
return dit_config
if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys and ('{}img_in.weight'.format(key_prefix) in state_dict_keys or f"{key_prefix}distilled_guidance_layer.norms.0.scale" in state_dict_keys): #Flux, Chroma or Chroma Radiance (has no img_in.weight)
From 791e30ff5037fa5e7aa4e1396099ea8d6bfb020b Mon Sep 17 00:00:00 2001
From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com>
Date: Tue, 9 Dec 2025 14:03:21 -0800
Subject: [PATCH 61/81] Fix nan issue when quantizing fp16 tensor. (#11213)
---
comfy/quant_ops.py | 5 ++++-
1 file changed, 4 insertions(+), 1 deletion(-)
diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py
index 571d3f760..cd96541d7 100644
--- a/comfy/quant_ops.py
+++ b/comfy/quant_ops.py
@@ -399,7 +399,10 @@ class TensorCoreFP8Layout(QuantizedLayout):
orig_dtype = tensor.dtype
if isinstance(scale, str) and scale == "recalculate":
- scale = torch.amax(tensor.abs()) / torch.finfo(dtype).max
+ scale = torch.amax(tensor.abs()).to(dtype=torch.float32) / torch.finfo(dtype).max
+ if tensor.dtype not in [torch.float32, torch.bfloat16]: # Prevent scale from being too small
+ tensor_info = torch.finfo(tensor.dtype)
+ scale = (1.0 / torch.clamp((1.0 / scale), min=tensor_info.min, max=tensor_info.max))
if scale is not None:
if not isinstance(scale, torch.Tensor):
From fc657f471a29d07696ca16b566000e8e555d67d1 Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Tue, 9 Dec 2025 18:22:09 -0500
Subject: [PATCH 62/81] ComfyUI version v0.4.0
From now on ComfyUI will do version numbers a bit differently, every stable
off the master branch will increment the minor version. Anytime a fix needs
to be backported onto a stable version the patch version will be
incremented.
Example: We release v0.6.0 off the master branch then a day later a bug is
discovered and we decide to backport the fix onto the v0.6.0 stable, this
will be done in a separate branch in the main repository and this new
stable will be tagged v0.6.1
---
comfyui_version.py | 2 +-
pyproject.toml | 2 +-
2 files changed, 2 insertions(+), 2 deletions(-)
diff --git a/comfyui_version.py b/comfyui_version.py
index 4b039356e..2f083edaf 100644
--- a/comfyui_version.py
+++ b/comfyui_version.py
@@ -1,3 +1,3 @@
# This file is automatically generated by the build process when version is
# updated in pyproject.toml.
-__version__ = "0.3.76"
+__version__ = "0.4.0"
diff --git a/pyproject.toml b/pyproject.toml
index 02b94a0ce..e4d3d616a 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "ComfyUI"
-version = "0.3.76"
+version = "0.4.0"
readme = "README.md"
license = { file = "LICENSE" }
requires-python = ">=3.9"
From f668c2e3c99df40561b416cf62b0fd9eec96007a Mon Sep 17 00:00:00 2001
From: Benjamin Lu
Date: Tue, 9 Dec 2025 19:27:07 -0800
Subject: [PATCH 63/81] bump comfyui-frontend-package to 1.34.8 (#11220)
---
requirements.txt | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/requirements.txt b/requirements.txt
index 11a7ac245..9e9b25328 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,4 +1,4 @@
-comfyui-frontend-package==1.33.13
+comfyui-frontend-package==1.34.8
comfyui-workflow-templates==0.7.54
comfyui-embedded-docs==0.3.1
torch
From 36357bbcc3c515e37a742457a2b2ab4b7ccc17a8 Mon Sep 17 00:00:00 2001
From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com>
Date: Wed, 10 Dec 2025 21:55:09 +0200
Subject: [PATCH 64/81] process the NodeV1 dict results correctly (#11237)
---
comfy_api/latest/_io.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/comfy_api/latest/_io.py b/comfy_api/latest/_io.py
index 313a5af20..79217c813 100644
--- a/comfy_api/latest/_io.py
+++ b/comfy_api/latest/_io.py
@@ -1815,7 +1815,7 @@ class NodeOutput(_NodeOutputInternal):
ui = data["ui"]
if "expand" in data:
expand = data["expand"]
- return cls(args=args, ui=ui, expand=expand)
+ return cls(*args, ui=ui, expand=expand)
def __getitem__(self, index) -> Any:
return self.args[index]
From 17c92a9f2843d7b9b727531066be2378b350a6ae Mon Sep 17 00:00:00 2001
From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com>
Date: Wed, 10 Dec 2025 16:59:48 -0800
Subject: [PATCH 65/81] Tweak Z Image memory estimation. (#11254)
---
comfy/supported_models.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/comfy/supported_models.py b/comfy/supported_models.py
index 383c82c3e..dd0f09f32 100644
--- a/comfy/supported_models.py
+++ b/comfy/supported_models.py
@@ -1026,7 +1026,7 @@ class ZImage(Lumina2):
"shift": 3.0,
}
- memory_usage_factor = 1.7
+ memory_usage_factor = 2.0
supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
From 57ddb7fd13d817e7259c2c992a852832b6b0f07a Mon Sep 17 00:00:00 2001
From: Johnpaul Chiwetelu <49923152+Myestery@users.noreply.github.com>
Date: Thu, 11 Dec 2025 03:49:49 +0100
Subject: [PATCH 66/81] Fix: filter hidden files from /internal/files endpoint
(#11191)
---
api_server/routes/internal/internal_routes.py | 7 ++++++-
1 file changed, 6 insertions(+), 1 deletion(-)
diff --git a/api_server/routes/internal/internal_routes.py b/api_server/routes/internal/internal_routes.py
index 613b0f7c7..b224306da 100644
--- a/api_server/routes/internal/internal_routes.py
+++ b/api_server/routes/internal/internal_routes.py
@@ -58,8 +58,13 @@ class InternalRoutes:
return web.json_response({"error": "Invalid directory type"}, status=400)
directory = get_directory_by_type(directory_type)
+
+ def is_visible_file(entry: os.DirEntry) -> bool:
+ """Filter out hidden files (e.g., .DS_Store on macOS)."""
+ return entry.is_file() and not entry.name.startswith('.')
+
sorted_files = sorted(
- (entry for entry in os.scandir(directory) if entry.is_file()),
+ (entry for entry in os.scandir(directory) if is_visible_file(entry)),
key=lambda entry: -entry.stat().st_mtime
)
return web.json_response([entry.name for entry in sorted_files], status=200)
From e711aaf1a75120195c56ebd1f1ce829c6b7b84db Mon Sep 17 00:00:00 2001
From: Farshore <168402472+jiangchengchengark@users.noreply.github.com>
Date: Thu, 11 Dec 2025 11:02:26 +0800
Subject: [PATCH 67/81] =?UTF-8?q?Lower=20VAE=20loading=20requirements?=
=?UTF-8?q?=EF=BC=9ACreate=20a=20new=20branch=20for=20GPU=20memory=20calcu?=
=?UTF-8?q?lations=20in=20qwen-image=20vae=20(#11199)?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
comfy/sd.py | 6 ++++--
1 file changed, 4 insertions(+), 2 deletions(-)
diff --git a/comfy/sd.py b/comfy/sd.py
index a16f2d14f..1cad98aef 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -549,8 +549,10 @@ class VAE:
ddconfig = {"dim": dim, "z_dim": self.latent_channels, "dim_mult": [1, 2, 4, 4], "num_res_blocks": 2, "attn_scales": [], "temperal_downsample": [False, True, True], "dropout": 0.0}
self.first_stage_model = comfy.ldm.wan.vae.WanVAE(**ddconfig)
self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32]
- self.memory_used_encode = lambda shape, dtype: 6000 * shape[3] * shape[4] * model_management.dtype_size(dtype)
- self.memory_used_decode = lambda shape, dtype: 7000 * shape[3] * shape[4] * (8 * 8) * model_management.dtype_size(dtype)
+ self.memory_used_encode = lambda shape, dtype: (1500 if shape[2]<=4 else 6000) * shape[3] * shape[4] * model_management.dtype_size(dtype)
+ self.memory_used_decode = lambda shape, dtype: (2200 if shape[2]<=4 else 7000) * shape[3] * shape[4] * (8*8) * model_management.dtype_size(dtype)
+
+
# Hunyuan 3d v2 2.0 & 2.1
elif "geo_decoder.cross_attn_decoder.ln_1.bias" in sd:
From 93948e3fc598c14082f744fe82fae056b64ff481 Mon Sep 17 00:00:00 2001
From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com>
Date: Thu, 11 Dec 2025 08:11:12 +0200
Subject: [PATCH 68/81] feat(api-nodes): enable Kling Omni O1 node (#11229)
---
comfy_api_nodes/nodes_kling.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/comfy_api_nodes/nodes_kling.py b/comfy_api_nodes/nodes_kling.py
index 6c840dc47..a2cc87d84 100644
--- a/comfy_api_nodes/nodes_kling.py
+++ b/comfy_api_nodes/nodes_kling.py
@@ -2056,7 +2056,7 @@ class KlingExtension(ComfyExtension):
OmniProImageToVideoNode,
OmniProVideoToVideoNode,
OmniProEditVideoNode,
- # OmniProImageNode, # need support from backend
+ OmniProImageNode,
]
From f8321eb57b29a4b34cecd27d5d6365adf5e6e601 Mon Sep 17 00:00:00 2001
From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com>
Date: Wed, 10 Dec 2025 22:30:31 -0800
Subject: [PATCH 69/81] Adjust memory usage factor. (#11257)
---
comfy/supported_models.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/comfy/supported_models.py b/comfy/supported_models.py
index dd0f09f32..ef8c75c09 100644
--- a/comfy/supported_models.py
+++ b/comfy/supported_models.py
@@ -541,7 +541,7 @@ class SD3(supported_models_base.BASE):
unet_extra_config = {}
latent_format = latent_formats.SD3
- memory_usage_factor = 1.2
+ memory_usage_factor = 1.6
text_encoder_key_prefix = ["text_encoders."]
From fdebe182966d1dd9bee3138264937137bd2302d8 Mon Sep 17 00:00:00 2001
From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com>
Date: Thu, 11 Dec 2025 14:09:35 -0800
Subject: [PATCH 70/81] Fix regular chroma radiance (#11276)
---
comfy/model_detection.py | 2 ++
1 file changed, 2 insertions(+)
diff --git a/comfy/model_detection.py b/comfy/model_detection.py
index 1f5d34bdd..94b54b7c2 100644
--- a/comfy/model_detection.py
+++ b/comfy/model_detection.py
@@ -261,6 +261,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["nerf_embedder_dtype"] = torch.float32
if "__x0__" in state_dict_keys: # x0 pred
dit_config["use_x0"] = True
+ else:
+ dit_config["use_x0"] = False
else:
dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
dit_config["yak_mlp"] = '{}double_blocks.0.img_mlp.gate_proj.weight'.format(key_prefix) in state_dict_keys
From ae65433a602470eea271df47af0eb871d146a002 Mon Sep 17 00:00:00 2001
From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com>
Date: Thu, 11 Dec 2025 14:15:00 -0800
Subject: [PATCH 71/81] This only works on radiance. (#11277)
---
comfy/model_detection.py | 8 ++++----
1 file changed, 4 insertions(+), 4 deletions(-)
diff --git a/comfy/model_detection.py b/comfy/model_detection.py
index 94b54b7c2..dd6a703f6 100644
--- a/comfy/model_detection.py
+++ b/comfy/model_detection.py
@@ -259,10 +259,10 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["nerf_tile_size"] = 512
dit_config["nerf_final_head_type"] = "conv" if f"{key_prefix}nerf_final_layer_conv.norm.scale" in state_dict_keys else "linear"
dit_config["nerf_embedder_dtype"] = torch.float32
- if "__x0__" in state_dict_keys: # x0 pred
- dit_config["use_x0"] = True
- else:
- dit_config["use_x0"] = False
+ if "__x0__" in state_dict_keys: # x0 pred
+ dit_config["use_x0"] = True
+ else:
+ dit_config["use_x0"] = False
else:
dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
dit_config["yak_mlp"] = '{}double_blocks.0.img_mlp.gate_proj.weight'.format(key_prefix) in state_dict_keys
From eeb020b9b77e1f3c0c2806bc1e38c7ba9576439e Mon Sep 17 00:00:00 2001
From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com>
Date: Thu, 11 Dec 2025 14:33:09 -0800
Subject: [PATCH 72/81] Better chroma radiance and other models vram
estimation. (#11278)
---
comfy/supported_models.py | 12 ++++++------
1 file changed, 6 insertions(+), 6 deletions(-)
diff --git a/comfy/supported_models.py b/comfy/supported_models.py
index ef8c75c09..834dfcffc 100644
--- a/comfy/supported_models.py
+++ b/comfy/supported_models.py
@@ -965,7 +965,7 @@ class CosmosT2IPredict2(supported_models_base.BASE):
def __init__(self, unet_config):
super().__init__(unet_config)
- self.memory_usage_factor = (unet_config.get("model_channels", 2048) / 2048) * 0.9
+ self.memory_usage_factor = (unet_config.get("model_channels", 2048) / 2048) * 0.95
def get_model(self, state_dict, prefix="", device=None):
out = model_base.CosmosPredict2(self, device=device)
@@ -1289,7 +1289,7 @@ class ChromaRadiance(Chroma):
latent_format = comfy.latent_formats.ChromaRadiance
# Pixel-space model, no spatial compression for model input.
- memory_usage_factor = 0.038
+ memory_usage_factor = 0.044
def get_model(self, state_dict, prefix="", device=None):
return model_base.ChromaRadiance(self, device=device)
@@ -1332,7 +1332,7 @@ class Omnigen2(supported_models_base.BASE):
"shift": 2.6,
}
- memory_usage_factor = 1.65 #TODO
+ memory_usage_factor = 1.95 #TODO
unet_extra_config = {}
latent_format = latent_formats.Flux
@@ -1397,7 +1397,7 @@ class HunyuanImage21(HunyuanVideo):
latent_format = latent_formats.HunyuanImage21
- memory_usage_factor = 7.7
+ memory_usage_factor = 8.7
supported_inference_dtypes = [torch.bfloat16, torch.float32]
@@ -1488,7 +1488,7 @@ class Kandinsky5(supported_models_base.BASE):
unet_extra_config = {}
latent_format = latent_formats.HunyuanVideo
- memory_usage_factor = 1.1 #TODO
+ memory_usage_factor = 1.25 #TODO
supported_inference_dtypes = [torch.bfloat16, torch.float32]
@@ -1517,7 +1517,7 @@ class Kandinsky5Image(Kandinsky5):
}
latent_format = latent_formats.Flux
- memory_usage_factor = 1.1 #TODO
+ memory_usage_factor = 1.25 #TODO
def get_model(self, state_dict, prefix="", device=None):
out = model_base.Kandinsky5Image(self, device=device)
From 338d9ae3bbf24a9a06996cdf1c2f228acc65fd96 Mon Sep 17 00:00:00 2001
From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com>
Date: Thu, 11 Dec 2025 15:56:33 -0800
Subject: [PATCH 73/81] Make portable updater work with repos in unmerged
state. (#11281)
---
.ci/update_windows/update.py | 10 ++++++++++
1 file changed, 10 insertions(+)
diff --git a/.ci/update_windows/update.py b/.ci/update_windows/update.py
index 59ece5130..fe646a6ed 100755
--- a/.ci/update_windows/update.py
+++ b/.ci/update_windows/update.py
@@ -53,6 +53,16 @@ try:
repo.stash(ident)
except KeyError:
print("nothing to stash") # noqa: T201
+except:
+ print("Could not stash, cleaning index and trying again.") # noqa: T201
+ repo.state_cleanup()
+ repo.index.read_tree(repo.head.peel().tree)
+ repo.index.write()
+ try:
+ repo.stash(ident)
+ except KeyError:
+ print("nothing to stash.") # noqa: T201
+
backup_branch_name = 'backup_branch_{}'.format(datetime.today().strftime('%Y-%m-%d_%H_%M_%S'))
print("creating backup branch: {}".format(backup_branch_name)) # noqa: T201
try:
From 982876d59a659adb085be5e236aacc4f2c54c19c Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jukka=20Sepp=C3=A4nen?=
<40791699+kijai@users.noreply.github.com>
Date: Fri, 12 Dec 2025 05:29:34 +0200
Subject: [PATCH 74/81] WanMove support (#11247)
---
comfy_api/latest/_io.py | 8 +
comfy_extras/nodes_wanmove.py | 535 ++++++++++++++++++++++++++++++++++
nodes.py | 1 +
3 files changed, 544 insertions(+)
create mode 100644 comfy_extras/nodes_wanmove.py
diff --git a/comfy_api/latest/_io.py b/comfy_api/latest/_io.py
index 79217c813..2b634d172 100644
--- a/comfy_api/latest/_io.py
+++ b/comfy_api/latest/_io.py
@@ -774,6 +774,13 @@ class AudioEncoder(ComfyTypeIO):
class AudioEncoderOutput(ComfyTypeIO):
Type = Any
+@comfytype(io_type="TRACKS")
+class Tracks(ComfyTypeIO):
+ class TrackDict(TypedDict):
+ track_path: torch.Tensor
+ track_visibility: torch.Tensor
+ Type = TrackDict
+
@comfytype(io_type="COMFY_MULTITYPED_V3")
class MultiType:
Type = Any
@@ -1894,6 +1901,7 @@ __all__ = [
"SEGS",
"AnyType",
"MultiType",
+ "Tracks",
# Dynamic Types
"MatchType",
# "DynamicCombo",
diff --git a/comfy_extras/nodes_wanmove.py b/comfy_extras/nodes_wanmove.py
new file mode 100644
index 000000000..5f39afa46
--- /dev/null
+++ b/comfy_extras/nodes_wanmove.py
@@ -0,0 +1,535 @@
+import nodes
+import node_helpers
+import torch
+import torchvision.transforms.functional as TF
+import comfy.model_management
+import comfy.utils
+import numpy as np
+from typing_extensions import override
+from comfy_api.latest import ComfyExtension, io
+from comfy_extras.nodes_wan import parse_json_tracks
+
+# https://github.com/ali-vilab/Wan-Move/blob/main/wan/modules/trajectory.py
+from PIL import Image, ImageDraw
+
+SKIP_ZERO = False
+
+def get_pos_emb(
+ pos_k: torch.Tensor, # A 1D tensor containing positions for which to generate embeddings.
+ pos_emb_dim: int,
+ theta_func: callable = lambda i, d: torch.pow(10000, torch.mul(2, torch.div(i.to(torch.float32), d))), #Function to compute thetas based on position and embedding dimensions.
+ device: torch.device = torch.device("cpu"),
+ dtype: torch.dtype = torch.float32,
+) -> torch.Tensor: # The position embeddings (batch_size, pos_emb_dim)
+
+ assert pos_emb_dim % 2 == 0, "The dimension of position embeddings must be even."
+ pos_k = pos_k.to(device, dtype)
+ if SKIP_ZERO:
+ pos_k = pos_k + 1
+ batch_size = pos_k.size(0)
+
+ denominator = torch.arange(0, pos_emb_dim // 2, device=device, dtype=dtype)
+ # Expand denominator to match the shape needed for broadcasting
+ denominator_expanded = denominator.view(1, -1).expand(batch_size, -1)
+
+ thetas = theta_func(denominator_expanded, pos_emb_dim)
+
+ # Ensure pos_k is in the correct shape for broadcasting
+ pos_k_expanded = pos_k.view(-1, 1).to(dtype)
+ sin_thetas = torch.sin(torch.div(pos_k_expanded, thetas))
+ cos_thetas = torch.cos(torch.div(pos_k_expanded, thetas))
+
+ # Concatenate sine and cosine embeddings along the last dimension
+ pos_emb = torch.cat([sin_thetas, cos_thetas], dim=-1)
+
+ return pos_emb
+
+def create_pos_embeddings(
+ pred_tracks: torch.Tensor, # the predicted tracks, [T, N, 2]
+ pred_visibility: torch.Tensor, # the predicted visibility [T, N]
+ downsample_ratios: list[int], # the ratios for downsampling time, height, and width
+ height: int, # the height of the feature map
+ width: int, # the width of the feature map
+ track_num: int = -1, # the number of tracks to use
+ t_down_strategy: str = "sample", # the strategy for downsampling time dimension
+):
+ assert t_down_strategy in ["sample", "average"], "Invalid strategy for downsampling time dimension."
+
+ t, n, _ = pred_tracks.shape
+ t_down, h_down, w_down = downsample_ratios
+ track_pos = - torch.ones(n, (t-1) // t_down + 1, 2, dtype=torch.long)
+
+ if track_num == -1:
+ track_num = n
+
+ tracks_idx = torch.randperm(n)[:track_num]
+ tracks = pred_tracks[:, tracks_idx]
+ visibility = pred_visibility[:, tracks_idx]
+
+ for t_idx in range(0, t, t_down):
+ if t_down_strategy == "sample" or t_idx == 0:
+ cur_tracks = tracks[t_idx] # [N, 2]
+ cur_visibility = visibility[t_idx] # [N]
+ else:
+ cur_tracks = tracks[t_idx:t_idx+t_down].mean(dim=0)
+ cur_visibility = torch.any(visibility[t_idx:t_idx+t_down], dim=0)
+
+ for i in range(track_num):
+ if not cur_visibility[i] or cur_tracks[i][0] < 0 or cur_tracks[i][1] < 0 or cur_tracks[i][0] >= width or cur_tracks[i][1] >= height:
+ continue
+ x, y = cur_tracks[i]
+ x, y = int(x // w_down), int(y // h_down)
+ track_pos[i, t_idx // t_down, 0], track_pos[i, t_idx // t_down, 1] = y, x
+
+ return track_pos # the position embeddings, [N, T', 2], 2 = height, width
+
+def replace_feature(
+ vae_feature: torch.Tensor, # [B, C', T', H', W']
+ track_pos: torch.Tensor, # [B, N, T', 2]
+ strength: float = 1.0
+) -> torch.Tensor:
+ b, _, t, h, w = vae_feature.shape
+ assert b == track_pos.shape[0], "Batch size mismatch."
+ n = track_pos.shape[1]
+
+ # Shuffle the trajectory order
+ track_pos = track_pos[:, torch.randperm(n), :, :]
+
+ # Extract coordinates at time steps ≥ 1 and generate a valid mask
+ current_pos = track_pos[:, :, 1:, :] # [B, N, T-1, 2]
+ mask = (current_pos[..., 0] >= 0) & (current_pos[..., 1] >= 0) # [B, N, T-1]
+
+ # Get all valid indices
+ valid_indices = mask.nonzero(as_tuple=False) # [num_valid, 3]
+ num_valid = valid_indices.shape[0]
+
+ if num_valid == 0:
+ return vae_feature
+
+ # Decompose valid indices into each dimension
+ batch_idx = valid_indices[:, 0]
+ track_idx = valid_indices[:, 1]
+ t_rel = valid_indices[:, 2]
+ t_target = t_rel + 1 # Convert to original time step indices
+
+ # Extract target position coordinates
+ h_target = current_pos[batch_idx, track_idx, t_rel, 0].long() # Ensure integer indices
+ w_target = current_pos[batch_idx, track_idx, t_rel, 1].long()
+
+ # Extract source position coordinates (t=0)
+ h_source = track_pos[batch_idx, track_idx, 0, 0].long()
+ w_source = track_pos[batch_idx, track_idx, 0, 1].long()
+
+ # Get source features and assign to target positions
+ src_features = vae_feature[batch_idx, :, 0, h_source, w_source]
+ dst_features = vae_feature[batch_idx, :, t_target, h_target, w_target]
+
+ vae_feature[batch_idx, :, t_target, h_target, w_target] = dst_features + (src_features - dst_features) * strength
+
+
+ return vae_feature
+
+# Visualize functions
+
+def _draw_gradient_polyline_on_overlay(overlay, line_width, points, start_color, opacity=1.0):
+ draw = ImageDraw.Draw(overlay, 'RGBA')
+ points = points[::-1]
+
+ # Compute total length
+ total_length = 0
+ segment_lengths = []
+ for i in range(len(points) - 1):
+ dx = points[i + 1][0] - points[i][0]
+ dy = points[i + 1][1] - points[i][1]
+ length = (dx * dx + dy * dy) ** 0.5
+ segment_lengths.append(length)
+ total_length += length
+
+ if total_length == 0:
+ return
+
+ accumulated_length = 0
+
+ # Draw the gradient polyline
+ for idx, (start_point, end_point) in enumerate(zip(points[:-1], points[1:])):
+ segment_length = segment_lengths[idx]
+ steps = max(int(segment_length), 1)
+
+ for i in range(steps):
+ current_length = accumulated_length + (i / steps) * segment_length
+ ratio = current_length / total_length
+
+ alpha = int(255 * (1 - ratio) * opacity)
+ color = (*start_color, alpha)
+
+ x = int(start_point[0] + (end_point[0] - start_point[0]) * i / steps)
+ y = int(start_point[1] + (end_point[1] - start_point[1]) * i / steps)
+
+ dynamic_line_width = max(int(line_width * (1 - ratio)), 1)
+ draw.line([(x, y), (x + 1, y)], fill=color, width=dynamic_line_width)
+
+ accumulated_length += segment_length
+
+
+def add_weighted(rgb, track):
+ rgb = np.array(rgb) # [H, W, C] "RGB"
+ track = np.array(track) # [H, W, C] "RGBA"
+
+ alpha = track[:, :, 3] / 255.0
+ alpha = np.stack([alpha] * 3, axis=-1)
+ blend_img = track[:, :, :3] * alpha + rgb * (1 - alpha)
+
+ return Image.fromarray(blend_img.astype(np.uint8))
+
+def draw_tracks_on_video(video, tracks, visibility=None, track_frame=24, circle_size=12, opacity=0.5, line_width=16):
+ color_map = [(102, 153, 255), (0, 255, 255), (255, 255, 0), (255, 102, 204), (0, 255, 0)]
+
+ video = video.byte().cpu().numpy() # (81, 480, 832, 3)
+ tracks = tracks[0].long().detach().cpu().numpy()
+ if visibility is not None:
+ visibility = visibility[0].detach().cpu().numpy()
+
+ num_frames, height, width = video.shape[:3]
+ num_tracks = tracks.shape[1]
+ alpha_opacity = int(255 * opacity)
+
+ output_frames = []
+ for t in range(num_frames):
+ frame_rgb = video[t].astype(np.float32)
+
+ # Create a single RGBA overlay for all tracks in this frame
+ overlay = Image.new("RGBA", (width, height), (0, 0, 0, 0))
+ draw_overlay = ImageDraw.Draw(overlay)
+
+ polyline_data = []
+
+ # Draw all circles on a single overlay
+ for n in range(num_tracks):
+ if visibility is not None and visibility[t, n] == 0:
+ continue
+
+ track_coord = tracks[t, n]
+ color = color_map[n % len(color_map)]
+ circle_color = color + (alpha_opacity,)
+
+ draw_overlay.ellipse((track_coord[0] - circle_size, track_coord[1] - circle_size, track_coord[0] + circle_size, track_coord[1] + circle_size),
+ fill=circle_color
+ )
+
+ # Store polyline data for batch processing
+ tracks_coord = tracks[max(t - track_frame, 0):t + 1, n]
+ if len(tracks_coord) > 1:
+ polyline_data.append((tracks_coord, color))
+
+ # Blend circles overlay once
+ overlay_np = np.array(overlay)
+ alpha = overlay_np[:, :, 3:4] / 255.0
+ frame_rgb = overlay_np[:, :, :3] * alpha + frame_rgb * (1 - alpha)
+
+ # Draw all polylines on a single overlay
+ if polyline_data:
+ polyline_overlay = Image.new("RGBA", (width, height), (0, 0, 0, 0))
+ for tracks_coord, color in polyline_data:
+ _draw_gradient_polyline_on_overlay(polyline_overlay, line_width, tracks_coord, color, opacity)
+
+ # Blend polylines overlay once
+ polyline_np = np.array(polyline_overlay)
+ alpha = polyline_np[:, :, 3:4] / 255.0
+ frame_rgb = polyline_np[:, :, :3] * alpha + frame_rgb * (1 - alpha)
+
+ output_frames.append(Image.fromarray(frame_rgb.astype(np.uint8)))
+
+ return output_frames
+
+
+class WanMoveVisualizeTracks(io.ComfyNode):
+ @classmethod
+ def define_schema(cls):
+ return io.Schema(
+ node_id="WanMoveVisualizeTracks",
+ category="conditioning/video_models",
+ inputs=[
+ io.Image.Input("images"),
+ io.Tracks.Input("tracks", optional=True),
+ io.Int.Input("line_resolution", default=24, min=1, max=1024),
+ io.Int.Input("circle_size", default=12, min=1, max=128),
+ io.Float.Input("opacity", default=0.75, min=0.0, max=1.0, step=0.01),
+ io.Int.Input("line_width", default=16, min=1, max=128),
+ ],
+ outputs=[
+ io.Image.Output(),
+ ],
+ )
+
+ @classmethod
+ def execute(cls, images, line_resolution, circle_size, opacity, line_width, tracks=None) -> io.NodeOutput:
+ if tracks is None:
+ return io.NodeOutput(images)
+
+ track_path = tracks["track_path"].unsqueeze(0)
+ track_visibility = tracks["track_visibility"].unsqueeze(0)
+ images_in = images * 255.0
+ if images_in.shape[0] != track_path.shape[1]:
+ repeat_count = track_path.shape[1] // images.shape[0]
+ images_in = images_in.repeat(repeat_count, 1, 1, 1)
+ track_video = draw_tracks_on_video(images_in, track_path, track_visibility, track_frame=line_resolution, circle_size=circle_size, opacity=opacity, line_width=line_width)
+ track_video = torch.stack([TF.to_tensor(frame) for frame in track_video], dim=0).movedim(1, -1).float()
+
+ return io.NodeOutput(track_video.to(comfy.model_management.intermediate_device()))
+
+
+class WanMoveTracksFromCoords(io.ComfyNode):
+ @classmethod
+ def define_schema(cls):
+ return io.Schema(
+ node_id="WanMoveTracksFromCoords",
+ category="conditioning/video_models",
+ inputs=[
+ io.String.Input("track_coords", force_input=True, default="[]", optional=True),
+ io.Mask.Input("track_mask", optional=True),
+ ],
+ outputs=[
+ io.Tracks.Output(),
+ io.Int.Output(display_name="track_length"),
+ ],
+ )
+
+ @classmethod
+ def execute(cls, track_coords, track_mask=None) -> io.NodeOutput:
+ device=comfy.model_management.intermediate_device()
+
+ tracks_data = parse_json_tracks(track_coords)
+ track_length = len(tracks_data[0])
+
+ track_list = [
+ [[track[frame]['x'], track[frame]['y']] for track in tracks_data]
+ for frame in range(len(tracks_data[0]))
+ ]
+ tracks = torch.tensor(track_list, dtype=torch.float32, device=device) # [frames, num_tracks, 2]
+
+ num_tracks = tracks.shape[-2]
+ if track_mask is None:
+ track_visibility = torch.ones((track_length, num_tracks), dtype=torch.bool, device=device)
+ else:
+ track_visibility = (track_mask > 0).any(dim=(1, 2)).unsqueeze(-1)
+
+ out_track_info = {}
+ out_track_info["track_path"] = tracks
+ out_track_info["track_visibility"] = track_visibility
+ return io.NodeOutput(out_track_info, track_length)
+
+
+class GenerateTracks(io.ComfyNode):
+ @classmethod
+ def define_schema(cls):
+ return io.Schema(
+ node_id="GenerateTracks",
+ category="conditioning/video_models",
+ inputs=[
+ io.Int.Input("width", default=832, min=16, max=4096, step=16),
+ io.Int.Input("height", default=480, min=16, max=4096, step=16),
+ io.Float.Input("start_x", default=0.0, min=0.0, max=1.0, step=0.01, tooltip="Normalized X coordinate (0-1) for start position."),
+ io.Float.Input("start_y", default=0.0, min=0.0, max=1.0, step=0.01, tooltip="Normalized Y coordinate (0-1) for start position."),
+ io.Float.Input("end_x", default=1.0, min=0.0, max=1.0, step=0.01, tooltip="Normalized X coordinate (0-1) for end position."),
+ io.Float.Input("end_y", default=1.0, min=0.0, max=1.0, step=0.01, tooltip="Normalized Y coordinate (0-1) for end position."),
+ io.Int.Input("num_frames", default=81, min=1, max=1024),
+ io.Int.Input("num_tracks", default=5, min=1, max=100),
+ io.Float.Input("track_spread", default=0.025, min=0.0, max=1.0, step=0.001, tooltip="Normalized distance between tracks. Tracks are spread perpendicular to the motion direction."),
+ io.Boolean.Input("bezier", default=False, tooltip="Enable Bezier curve path using the mid point as control point."),
+ io.Float.Input("mid_x", default=0.5, min=0.0, max=1.0, step=0.01, tooltip="Normalized X control point for Bezier curve. Only used when 'bezier' is enabled."),
+ io.Float.Input("mid_y", default=0.5, min=0.0, max=1.0, step=0.01, tooltip="Normalized Y control point for Bezier curve. Only used when 'bezier' is enabled."),
+ io.Combo.Input(
+ "interpolation",
+ options=["linear", "ease_in", "ease_out", "ease_in_out", "constant"],
+ tooltip="Controls the timing/speed of movement along the path.",
+ ),
+ io.Mask.Input("track_mask", optional=True, tooltip="Optional mask to indicate visible frames."),
+ ],
+ outputs=[
+ io.Tracks.Output(),
+ io.Int.Output(display_name="track_length"),
+ ],
+ )
+
+ @classmethod
+ def execute(cls, width, height, start_x, start_y, mid_x, mid_y, end_x, end_y, num_frames, num_tracks,
+ track_spread, bezier=False, interpolation="linear", track_mask=None) -> io.NodeOutput:
+ device = comfy.model_management.intermediate_device()
+ track_length = num_frames
+
+ # normalized coordinates to pixel coordinates
+ start_x_px = start_x * width
+ start_y_px = start_y * height
+ mid_x_px = mid_x * width
+ mid_y_px = mid_y * height
+ end_x_px = end_x * width
+ end_y_px = end_y * height
+
+ track_spread_px = track_spread * (width + height) / 2 # Use average of width/height for spread to keep it proportional
+
+ t = torch.linspace(0, 1, num_frames, device=device)
+ if interpolation == "constant": # All points stay at start position
+ interp_values = torch.zeros_like(t)
+ elif interpolation == "linear":
+ interp_values = t
+ elif interpolation == "ease_in":
+ interp_values = t ** 2
+ elif interpolation == "ease_out":
+ interp_values = 1 - (1 - t) ** 2
+ elif interpolation == "ease_in_out":
+ interp_values = t * t * (3 - 2 * t)
+
+ if bezier: # apply interpolation to t for timing control along the bezier path
+ t_interp = interp_values
+ one_minus_t = 1 - t_interp
+ x_positions = one_minus_t ** 2 * start_x_px + 2 * one_minus_t * t_interp * mid_x_px + t_interp ** 2 * end_x_px
+ y_positions = one_minus_t ** 2 * start_y_px + 2 * one_minus_t * t_interp * mid_y_px + t_interp ** 2 * end_y_px
+ tangent_x = 2 * one_minus_t * (mid_x_px - start_x_px) + 2 * t_interp * (end_x_px - mid_x_px)
+ tangent_y = 2 * one_minus_t * (mid_y_px - start_y_px) + 2 * t_interp * (end_y_px - mid_y_px)
+ else: # calculate base x and y positions for each frame (center track)
+ x_positions = start_x_px + (end_x_px - start_x_px) * interp_values
+ y_positions = start_y_px + (end_y_px - start_y_px) * interp_values
+ # For non-bezier, tangent is constant (direction from start to end)
+ tangent_x = torch.full_like(t, end_x_px - start_x_px)
+ tangent_y = torch.full_like(t, end_y_px - start_y_px)
+
+ track_list = []
+ for frame_idx in range(num_frames):
+ # Calculate perpendicular direction at this frame
+ tx = tangent_x[frame_idx].item()
+ ty = tangent_y[frame_idx].item()
+ length = (tx ** 2 + ty ** 2) ** 0.5
+
+ if length > 0: # Perpendicular unit vector (rotate 90 degrees)
+ perp_x = -ty / length
+ perp_y = tx / length
+ else: # If tangent is zero, spread horizontally
+ perp_x = 1.0
+ perp_y = 0.0
+
+ frame_tracks = []
+ for track_idx in range(num_tracks): # center tracks around the main path offset ranges from -(num_tracks-1)/2 to +(num_tracks-1)/2
+ offset = (track_idx - (num_tracks - 1) / 2) * track_spread_px
+ track_x = x_positions[frame_idx].item() + perp_x * offset
+ track_y = y_positions[frame_idx].item() + perp_y * offset
+ frame_tracks.append([track_x, track_y])
+ track_list.append(frame_tracks)
+
+ tracks = torch.tensor(track_list, dtype=torch.float32, device=device) # [frames, num_tracks, 2]
+
+ if track_mask is None:
+ track_visibility = torch.ones((track_length, num_tracks), dtype=torch.bool, device=device)
+ else:
+ track_visibility = (track_mask > 0).any(dim=(1, 2)).unsqueeze(-1)
+
+ out_track_info = {}
+ out_track_info["track_path"] = tracks
+ out_track_info["track_visibility"] = track_visibility
+ return io.NodeOutput(out_track_info, track_length)
+
+
+class WanMoveConcatTrack(io.ComfyNode):
+ @classmethod
+ def define_schema(cls):
+ return io.Schema(
+ node_id="WanMoveConcatTrack",
+ category="conditioning/video_models",
+ inputs=[
+ io.Tracks.Input("tracks_1"),
+ io.Tracks.Input("tracks_2", optional=True),
+ ],
+ outputs=[
+ io.Tracks.Output(),
+ ],
+ )
+
+ @classmethod
+ def execute(cls, tracks_1=None, tracks_2=None) -> io.NodeOutput:
+ if tracks_2 is None:
+ return io.NodeOutput(tracks_1)
+
+ tracks_out = torch.cat([tracks_1["track_path"], tracks_2["track_path"]], dim=1) # Concatenate along the track dimension
+ mask_out = torch.cat([tracks_1["track_visibility"], tracks_2["track_visibility"]], dim=-1)
+
+ out_track_info = {}
+ out_track_info["track_path"] = tracks_out
+ out_track_info["track_visibility"] = mask_out
+ return io.NodeOutput(out_track_info)
+
+
+class WanMoveTrackToVideo(io.ComfyNode):
+ @classmethod
+ def define_schema(cls):
+ return io.Schema(
+ node_id="WanMoveTrackToVideo",
+ category="conditioning/video_models",
+ inputs=[
+ io.Conditioning.Input("positive"),
+ io.Conditioning.Input("negative"),
+ io.Vae.Input("vae"),
+ io.Tracks.Input("tracks", optional=True),
+ io.Float.Input("strength", default=1.0, min=0.0, max=100.0, step=0.01, tooltip="Strength of the track conditioning."),
+ io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16),
+ io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16),
+ io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4),
+ io.Int.Input("batch_size", default=1, min=1, max=4096),
+ io.Image.Input("start_image"),
+ io.ClipVisionOutput.Input("clip_vision_output", optional=True),
+ ],
+ outputs=[
+ io.Conditioning.Output(display_name="positive"),
+ io.Conditioning.Output(display_name="negative"),
+ io.Latent.Output(display_name="latent"),
+ ],
+ )
+
+ @classmethod
+ def execute(cls, positive, negative, vae, width, height, length, batch_size, strength, tracks=None, start_image=None, clip_vision_output=None) -> io.NodeOutput:
+ device=comfy.model_management.intermediate_device()
+ latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=device)
+ if start_image is not None:
+ start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
+ image = torch.ones((length, height, width, start_image.shape[-1]), device=start_image.device, dtype=start_image.dtype) * 0.5
+ image[:start_image.shape[0]] = start_image
+
+ concat_latent_image = vae.encode(image[:, :, :, :3])
+ mask = torch.ones((1, 1, latent.shape[2], concat_latent_image.shape[-2], concat_latent_image.shape[-1]), device=start_image.device, dtype=start_image.dtype)
+ mask[:, :, :((start_image.shape[0] - 1) // 4) + 1] = 0.0
+
+ if tracks is not None and strength > 0.0:
+ tracks_path = tracks["track_path"][:length] # [T, N, 2]
+ num_tracks = tracks_path.shape[-2]
+
+ track_visibility = tracks.get("track_visibility", torch.ones((length, num_tracks), dtype=torch.bool, device=device))
+
+ track_pos = create_pos_embeddings(tracks_path, track_visibility, [4, 8, 8], height, width, track_num=num_tracks)
+ track_pos = comfy.utils.resize_to_batch_size(track_pos.unsqueeze(0), batch_size)
+ concat_latent_image_pos = replace_feature(concat_latent_image, track_pos, strength)
+ else:
+ concat_latent_image_pos = concat_latent_image
+
+ positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent_image_pos, "concat_mask": mask})
+ negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent_image, "concat_mask": mask})
+
+ if clip_vision_output is not None:
+ positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output})
+ negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output})
+
+ out_latent = {}
+ out_latent["samples"] = latent
+ return io.NodeOutput(positive, negative, out_latent)
+
+
+class WanMoveExtension(ComfyExtension):
+ @override
+ async def get_node_list(self) -> list[type[io.ComfyNode]]:
+ return [
+ WanMoveTrackToVideo,
+ WanMoveTracksFromCoords,
+ WanMoveConcatTrack,
+ WanMoveVisualizeTracks,
+ GenerateTracks,
+ ]
+
+async def comfy_entrypoint() -> WanMoveExtension:
+ return WanMoveExtension()
diff --git a/nodes.py b/nodes.py
index 8d28a725d..8678f510a 100644
--- a/nodes.py
+++ b/nodes.py
@@ -2358,6 +2358,7 @@ async def init_builtin_extra_nodes():
"nodes_logic.py",
"nodes_nop.py",
"nodes_kandinsky5.py",
+ "nodes_wanmove.py",
]
import_failed = []
From 5495589db38409353a85b06df7d10f8de2f9c78d Mon Sep 17 00:00:00 2001
From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com>
Date: Thu, 11 Dec 2025 20:32:27 -0800
Subject: [PATCH 75/81] Respect the dtype the op was initialized in for non
quant mixed op. (#11282)
---
comfy/ops.py | 11 ++++++++---
1 file changed, 8 insertions(+), 3 deletions(-)
diff --git a/comfy/ops.py b/comfy/ops.py
index 6f34d50fc..6ae6e791a 100644
--- a/comfy/ops.py
+++ b/comfy/ops.py
@@ -497,8 +497,10 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
) -> None:
super().__init__()
- self.factory_kwargs = {"device": device, "dtype": MixedPrecisionOps._compute_dtype}
- # self.factory_kwargs = {"device": device, "dtype": dtype}
+ if dtype is None:
+ dtype = MixedPrecisionOps._compute_dtype
+
+ self.factory_kwargs = {"device": device, "dtype": dtype}
self.in_features = in_features
self.out_features = out_features
@@ -530,7 +532,10 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
layer_conf = json.loads(layer_conf.numpy().tobytes())
if layer_conf is None:
- self.weight = torch.nn.Parameter(weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), requires_grad=False)
+ dtype = self.factory_kwargs["dtype"]
+ self.weight = torch.nn.Parameter(weight.to(device=device, dtype=dtype), requires_grad=False)
+ if dtype != MixedPrecisionOps._compute_dtype:
+ self.comfy_cast_weights = True
else:
self.quant_format = layer_conf.get("format", None)
if not self._full_precision_mm:
From 908fd7d7496f6de88722263e1e00fcd3d22e584f Mon Sep 17 00:00:00 2001
From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com>
Date: Fri, 12 Dec 2025 10:18:31 +0200
Subject: [PATCH 76/81] feat(api-nodes): new TextToVideoWithAudio and
ImageToVideoWithAudio nodes (#11267)
---
comfy_api_nodes/apis/kling_api.py | 28 ++++-
comfy_api_nodes/nodes_kling.py | 169 ++++++++++++++++++++++++++----
2 files changed, 174 insertions(+), 23 deletions(-)
diff --git a/comfy_api_nodes/apis/kling_api.py b/comfy_api_nodes/apis/kling_api.py
index d8949f8ac..80a758466 100644
--- a/comfy_api_nodes/apis/kling_api.py
+++ b/comfy_api_nodes/apis/kling_api.py
@@ -51,25 +51,25 @@ class TaskStatusImageResult(BaseModel):
url: str = Field(..., description="URL for generated image")
-class OmniTaskStatusResults(BaseModel):
+class TaskStatusResults(BaseModel):
videos: list[TaskStatusVideoResult] | None = Field(None)
images: list[TaskStatusImageResult] | None = Field(None)
-class OmniTaskStatusResponseData(BaseModel):
+class TaskStatusResponseData(BaseModel):
created_at: int | None = Field(None, description="Task creation time")
updated_at: int | None = Field(None, description="Task update time")
task_status: str | None = None
task_status_msg: str | None = Field(None, description="Additional failure reason. Only for polling endpoint.")
task_id: str | None = Field(None, description="Task ID")
- task_result: OmniTaskStatusResults | None = Field(None)
+ task_result: TaskStatusResults | None = Field(None)
-class OmniTaskStatusResponse(BaseModel):
+class TaskStatusResponse(BaseModel):
code: int | None = Field(None, description="Error code")
message: str | None = Field(None, description="Error message")
request_id: str | None = Field(None, description="Request ID")
- data: OmniTaskStatusResponseData | None = Field(None)
+ data: TaskStatusResponseData | None = Field(None)
class OmniImageParamImage(BaseModel):
@@ -84,3 +84,21 @@ class OmniProImageRequest(BaseModel):
mode: str = Field("pro")
n: int | None = Field(1, le=9)
image_list: list[OmniImageParamImage] | None = Field(..., max_length=10)
+
+
+class TextToVideoWithAudioRequest(BaseModel):
+ model_name: str = Field(..., description="kling-v2-6")
+ aspect_ratio: str = Field(..., description="'16:9', '9:16' or '1:1'")
+ duration: str = Field(..., description="'5' or '10'")
+ prompt: str = Field(...)
+ mode: str = Field("pro")
+ sound: str = Field(..., description="'on' or 'off'")
+
+
+class ImageToVideoWithAudioRequest(BaseModel):
+ model_name: str = Field(..., description="kling-v2-6")
+ image: str = Field(...)
+ duration: str = Field(..., description="'5' or '10'")
+ prompt: str = Field(...)
+ mode: str = Field("pro")
+ sound: str = Field(..., description="'on' or 'off'")
diff --git a/comfy_api_nodes/nodes_kling.py b/comfy_api_nodes/nodes_kling.py
index a2cc87d84..e545fe490 100644
--- a/comfy_api_nodes/nodes_kling.py
+++ b/comfy_api_nodes/nodes_kling.py
@@ -50,6 +50,7 @@ from comfy_api_nodes.apis import (
KlingSingleImageEffectModelName,
)
from comfy_api_nodes.apis.kling_api import (
+ ImageToVideoWithAudioRequest,
OmniImageParamImage,
OmniParamImage,
OmniParamVideo,
@@ -57,7 +58,8 @@ from comfy_api_nodes.apis.kling_api import (
OmniProImageRequest,
OmniProReferences2VideoRequest,
OmniProText2VideoRequest,
- OmniTaskStatusResponse,
+ TaskStatusResponse,
+ TextToVideoWithAudioRequest,
)
from comfy_api_nodes.util import (
ApiEndpoint,
@@ -242,7 +244,7 @@ def normalize_omni_prompt_references(prompt: str) -> str:
return re.sub(r"(?\d*)(?!\w)", _video_repl, prompt)
-async def finish_omni_video_task(cls: type[IO.ComfyNode], response: OmniTaskStatusResponse) -> IO.NodeOutput:
+async def finish_omni_video_task(cls: type[IO.ComfyNode], response: TaskStatusResponse) -> IO.NodeOutput:
if response.code:
raise RuntimeError(
f"Kling request failed. Code: {response.code}, Message: {response.message}, Data: {response.data}"
@@ -250,7 +252,7 @@ async def finish_omni_video_task(cls: type[IO.ComfyNode], response: OmniTaskStat
final_response = await poll_op(
cls,
ApiEndpoint(path=f"/proxy/kling/v1/videos/omni-video/{response.data.task_id}"),
- response_model=OmniTaskStatusResponse,
+ response_model=TaskStatusResponse,
status_extractor=lambda r: (r.data.task_status if r.data else None),
max_poll_attempts=160,
)
@@ -483,12 +485,12 @@ async def execute_image2video(
task_id = task_creation_response.data.task_id
final_response = await poll_op(
- cls,
- ApiEndpoint(path=f"{PATH_IMAGE_TO_VIDEO}/{task_id}"),
- response_model=KlingImage2VideoResponse,
- estimated_duration=AVERAGE_DURATION_I2V,
- status_extractor=lambda r: (r.data.task_status.value if r.data and r.data.task_status else None),
- )
+ cls,
+ ApiEndpoint(path=f"{PATH_IMAGE_TO_VIDEO}/{task_id}"),
+ response_model=KlingImage2VideoResponse,
+ estimated_duration=AVERAGE_DURATION_I2V,
+ status_extractor=lambda r: (r.data.task_status.value if r.data and r.data.task_status else None),
+ )
validate_video_result_response(final_response)
video = get_video_from_response(final_response)
@@ -834,7 +836,7 @@ class OmniProTextToVideoNode(IO.ComfyNode):
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"),
- response_model=OmniTaskStatusResponse,
+ response_model=TaskStatusResponse,
data=OmniProText2VideoRequest(
model_name=model_name,
prompt=prompt,
@@ -929,7 +931,7 @@ class OmniProFirstLastFrameNode(IO.ComfyNode):
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"),
- response_model=OmniTaskStatusResponse,
+ response_model=TaskStatusResponse,
data=OmniProFirstLastFrameRequest(
model_name=model_name,
prompt=prompt,
@@ -997,7 +999,7 @@ class OmniProImageToVideoNode(IO.ComfyNode):
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"),
- response_model=OmniTaskStatusResponse,
+ response_model=TaskStatusResponse,
data=OmniProReferences2VideoRequest(
model_name=model_name,
prompt=prompt,
@@ -1081,7 +1083,7 @@ class OmniProVideoToVideoNode(IO.ComfyNode):
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"),
- response_model=OmniTaskStatusResponse,
+ response_model=TaskStatusResponse,
data=OmniProReferences2VideoRequest(
model_name=model_name,
prompt=prompt,
@@ -1162,7 +1164,7 @@ class OmniProEditVideoNode(IO.ComfyNode):
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"),
- response_model=OmniTaskStatusResponse,
+ response_model=TaskStatusResponse,
data=OmniProReferences2VideoRequest(
model_name=model_name,
prompt=prompt,
@@ -1237,7 +1239,7 @@ class OmniProImageNode(IO.ComfyNode):
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/kling/v1/images/omni-image", method="POST"),
- response_model=OmniTaskStatusResponse,
+ response_model=TaskStatusResponse,
data=OmniProImageRequest(
model_name=model_name,
prompt=prompt,
@@ -1253,7 +1255,7 @@ class OmniProImageNode(IO.ComfyNode):
final_response = await poll_op(
cls,
ApiEndpoint(path=f"/proxy/kling/v1/images/omni-image/{response.data.task_id}"),
- response_model=OmniTaskStatusResponse,
+ response_model=TaskStatusResponse,
status_extractor=lambda r: (r.data.task_status if r.data else None),
)
return IO.NodeOutput(await download_url_to_image_tensor(final_response.data.task_result.images[0].url))
@@ -1328,9 +1330,8 @@ class KlingImage2VideoNode(IO.ComfyNode):
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="KlingImage2VideoNode",
- display_name="Kling Image to Video",
+ display_name="Kling Image(First Frame) to Video",
category="api node/video/Kling",
- description="Kling Image to Video Node",
inputs=[
IO.Image.Input("start_frame", tooltip="The reference image used to generate the video."),
IO.String.Input("prompt", multiline=True, tooltip="Positive text prompt"),
@@ -2034,6 +2035,136 @@ class KlingImageGenerationNode(IO.ComfyNode):
return IO.NodeOutput(await image_result_to_node_output(images))
+class TextToVideoWithAudio(IO.ComfyNode):
+
+ @classmethod
+ def define_schema(cls) -> IO.Schema:
+ return IO.Schema(
+ node_id="KlingTextToVideoWithAudio",
+ display_name="Kling Text to Video with Audio",
+ category="api node/video/Kling",
+ inputs=[
+ IO.Combo.Input("model_name", options=["kling-v2-6"]),
+ IO.String.Input("prompt", multiline=True, tooltip="Positive text prompt."),
+ IO.Combo.Input("mode", options=["pro"]),
+ IO.Combo.Input("aspect_ratio", options=["16:9", "9:16", "1:1"]),
+ IO.Combo.Input("duration", options=[5, 10]),
+ IO.Boolean.Input("generate_audio", default=True),
+ ],
+ outputs=[
+ IO.Video.Output(),
+ ],
+ hidden=[
+ IO.Hidden.auth_token_comfy_org,
+ IO.Hidden.api_key_comfy_org,
+ IO.Hidden.unique_id,
+ ],
+ is_api_node=True,
+ )
+
+ @classmethod
+ async def execute(
+ cls,
+ model_name: str,
+ prompt: str,
+ mode: str,
+ aspect_ratio: str,
+ duration: int,
+ generate_audio: bool,
+ ) -> IO.NodeOutput:
+ validate_string(prompt, min_length=1, max_length=2500)
+ response = await sync_op(
+ cls,
+ ApiEndpoint(path="/proxy/kling/v1/videos/text2video", method="POST"),
+ response_model=TaskStatusResponse,
+ data=TextToVideoWithAudioRequest(
+ model_name=model_name,
+ prompt=prompt,
+ mode=mode,
+ aspect_ratio=aspect_ratio,
+ duration=str(duration),
+ sound="on" if generate_audio else "off",
+ ),
+ )
+ if response.code:
+ raise RuntimeError(
+ f"Kling request failed. Code: {response.code}, Message: {response.message}, Data: {response.data}"
+ )
+ final_response = await poll_op(
+ cls,
+ ApiEndpoint(path=f"/proxy/kling/v1/videos/text2video/{response.data.task_id}"),
+ response_model=TaskStatusResponse,
+ status_extractor=lambda r: (r.data.task_status if r.data else None),
+ )
+ return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url))
+
+
+class ImageToVideoWithAudio(IO.ComfyNode):
+
+ @classmethod
+ def define_schema(cls) -> IO.Schema:
+ return IO.Schema(
+ node_id="KlingImageToVideoWithAudio",
+ display_name="Kling Image(First Frame) to Video with Audio",
+ category="api node/video/Kling",
+ inputs=[
+ IO.Combo.Input("model_name", options=["kling-v2-6"]),
+ IO.Image.Input("start_frame"),
+ IO.String.Input("prompt", multiline=True, tooltip="Positive text prompt."),
+ IO.Combo.Input("mode", options=["pro"]),
+ IO.Combo.Input("duration", options=[5, 10]),
+ IO.Boolean.Input("generate_audio", default=True),
+ ],
+ outputs=[
+ IO.Video.Output(),
+ ],
+ hidden=[
+ IO.Hidden.auth_token_comfy_org,
+ IO.Hidden.api_key_comfy_org,
+ IO.Hidden.unique_id,
+ ],
+ is_api_node=True,
+ )
+
+ @classmethod
+ async def execute(
+ cls,
+ model_name: str,
+ start_frame: Input.Image,
+ prompt: str,
+ mode: str,
+ duration: int,
+ generate_audio: bool,
+ ) -> IO.NodeOutput:
+ validate_string(prompt, min_length=1, max_length=2500)
+ validate_image_dimensions(start_frame, min_width=300, min_height=300)
+ validate_image_aspect_ratio(start_frame, (1, 2.5), (2.5, 1))
+ response = await sync_op(
+ cls,
+ ApiEndpoint(path="/proxy/kling/v1/videos/image2video", method="POST"),
+ response_model=TaskStatusResponse,
+ data=ImageToVideoWithAudioRequest(
+ model_name=model_name,
+ image=(await upload_images_to_comfyapi(cls, start_frame))[0],
+ prompt=prompt,
+ mode=mode,
+ duration=str(duration),
+ sound="on" if generate_audio else "off",
+ ),
+ )
+ if response.code:
+ raise RuntimeError(
+ f"Kling request failed. Code: {response.code}, Message: {response.message}, Data: {response.data}"
+ )
+ final_response = await poll_op(
+ cls,
+ ApiEndpoint(path=f"/proxy/kling/v1/videos/image2video/{response.data.task_id}"),
+ response_model=TaskStatusResponse,
+ status_extractor=lambda r: (r.data.task_status if r.data else None),
+ )
+ return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url))
+
+
class KlingExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
@@ -2057,6 +2188,8 @@ class KlingExtension(ComfyExtension):
OmniProVideoToVideoNode,
OmniProEditVideoNode,
OmniProImageNode,
+ TextToVideoWithAudio,
+ ImageToVideoWithAudio,
]
From c5a47a16924e1be96241553a1448b298e57e50a1 Mon Sep 17 00:00:00 2001
From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com>
Date: Fri, 12 Dec 2025 08:49:35 -0800
Subject: [PATCH 77/81] Fix bias dtype issue in mixed ops. (#11293)
---
comfy/ops.py | 14 ++++++++++----
1 file changed, 10 insertions(+), 4 deletions(-)
diff --git a/comfy/ops.py b/comfy/ops.py
index 6ae6e791a..0384c8717 100644
--- a/comfy/ops.py
+++ b/comfy/ops.py
@@ -504,10 +504,7 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
self.in_features = in_features
self.out_features = out_features
- if bias:
- self.bias = torch.nn.Parameter(torch.empty(out_features, **self.factory_kwargs))
- else:
- self.register_parameter("bias", None)
+ self._has_bias = bias
self.tensor_class = None
self._full_precision_mm = MixedPrecisionOps._full_precision_mm
@@ -536,6 +533,10 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
self.weight = torch.nn.Parameter(weight.to(device=device, dtype=dtype), requires_grad=False)
if dtype != MixedPrecisionOps._compute_dtype:
self.comfy_cast_weights = True
+ if self._has_bias:
+ self.bias = torch.nn.Parameter(torch.empty(self.out_features, device=device, dtype=dtype))
+ else:
+ self.register_parameter("bias", None)
else:
self.quant_format = layer_conf.get("format", None)
if not self._full_precision_mm:
@@ -565,6 +566,11 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
requires_grad=False
)
+ if self._has_bias:
+ self.bias = torch.nn.Parameter(torch.empty(self.out_features, device=device, dtype=MixedPrecisionOps._compute_dtype))
+ else:
+ self.register_parameter("bias", None)
+
for param_name in qconfig["parameters"]:
param_key = f"{prefix}{param_name}"
_v = state_dict.pop(param_key, None)
From da2bfb5b0af26c7a1c44ec951dbd0fffe413c793 Mon Sep 17 00:00:00 2001
From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com>
Date: Fri, 12 Dec 2025 22:39:11 -0800
Subject: [PATCH 78/81] Basic implementation of z image fun control union 2.0
(#11304)
The inpaint part is currently missing and will be implemented later.
I think they messed up this model pretty bad. They added some
control_noise_refiner blocks but don't actually use them. There is a typo
in their code so instead of doing control_noise_refiner -> control_layers
it runs the whole control_layers twice.
Unfortunately they trained with this typo so the model works but is kind
of slow and would probably perform a lot better if they corrected their
code and trained it again.
---
comfy/ldm/lumina/controlnet.py | 95 +++++++++++++++++++++++--------
comfy/ldm/lumina/model.py | 16 +++++-
comfy/model_patcher.py | 3 +
comfy_extras/nodes_model_patch.py | 72 +++++++++++++++++------
4 files changed, 142 insertions(+), 44 deletions(-)
diff --git a/comfy/ldm/lumina/controlnet.py b/comfy/ldm/lumina/controlnet.py
index fd7ce3b5c..8e2de7977 100644
--- a/comfy/ldm/lumina/controlnet.py
+++ b/comfy/ldm/lumina/controlnet.py
@@ -41,6 +41,11 @@ class ZImage_Control(torch.nn.Module):
ffn_dim_multiplier: float = (8.0 / 3.0),
norm_eps: float = 1e-5,
qk_norm: bool = True,
+ n_control_layers=6,
+ control_in_dim=16,
+ additional_in_dim=0,
+ broken=False,
+ refiner_control=False,
dtype=None,
device=None,
operations=None,
@@ -49,10 +54,11 @@ class ZImage_Control(torch.nn.Module):
super().__init__()
operation_settings = {"operations": operations, "device": device, "dtype": dtype}
- self.additional_in_dim = 0
- self.control_in_dim = 16
+ self.broken = broken
+ self.additional_in_dim = additional_in_dim
+ self.control_in_dim = control_in_dim
n_refiner_layers = 2
- self.n_control_layers = 6
+ self.n_control_layers = n_control_layers
self.control_layers = nn.ModuleList(
[
ZImageControlTransformerBlock(
@@ -74,28 +80,49 @@ class ZImage_Control(torch.nn.Module):
all_x_embedder = {}
patch_size = 2
f_patch_size = 1
- x_embedder = operations.Linear(f_patch_size * patch_size * patch_size * self.control_in_dim, dim, bias=True, device=device, dtype=dtype)
+ x_embedder = operations.Linear(f_patch_size * patch_size * patch_size * (self.control_in_dim + self.additional_in_dim), dim, bias=True, device=device, dtype=dtype)
all_x_embedder[f"{patch_size}-{f_patch_size}"] = x_embedder
+ self.refiner_control = refiner_control
+
self.control_all_x_embedder = nn.ModuleDict(all_x_embedder)
- self.control_noise_refiner = nn.ModuleList(
- [
- JointTransformerBlock(
- layer_id,
- dim,
- n_heads,
- n_kv_heads,
- multiple_of,
- ffn_dim_multiplier,
- norm_eps,
- qk_norm,
- modulation=True,
- z_image_modulation=True,
- operation_settings=operation_settings,
- )
- for layer_id in range(n_refiner_layers)
- ]
- )
+ if self.refiner_control:
+ self.control_noise_refiner = nn.ModuleList(
+ [
+ ZImageControlTransformerBlock(
+ layer_id,
+ dim,
+ n_heads,
+ n_kv_heads,
+ multiple_of,
+ ffn_dim_multiplier,
+ norm_eps,
+ qk_norm,
+ block_id=layer_id,
+ operation_settings=operation_settings,
+ )
+ for layer_id in range(n_refiner_layers)
+ ]
+ )
+ else:
+ self.control_noise_refiner = nn.ModuleList(
+ [
+ JointTransformerBlock(
+ layer_id,
+ dim,
+ n_heads,
+ n_kv_heads,
+ multiple_of,
+ ffn_dim_multiplier,
+ norm_eps,
+ qk_norm,
+ modulation=True,
+ z_image_modulation=True,
+ operation_settings=operation_settings,
+ )
+ for layer_id in range(n_refiner_layers)
+ ]
+ )
def forward(self, cap_feats, control_context, x_freqs_cis, adaln_input):
patch_size = 2
@@ -105,9 +132,29 @@ class ZImage_Control(torch.nn.Module):
control_context = self.control_all_x_embedder[f"{patch_size}-{f_patch_size}"](control_context.view(B, C, H // pH, pH, W // pW, pW).permute(0, 2, 4, 3, 5, 1).flatten(3).flatten(1, 2))
x_attn_mask = None
- for layer in self.control_noise_refiner:
- control_context = layer(control_context, x_attn_mask, x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input)
+ if not self.refiner_control:
+ for layer in self.control_noise_refiner:
+ control_context = layer(control_context, x_attn_mask, x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input)
+
return control_context
+ def forward_noise_refiner_block(self, layer_id, control_context, x, x_attn_mask, x_freqs_cis, adaln_input):
+ if self.refiner_control:
+ if self.broken:
+ if layer_id == 0:
+ return self.control_layers[layer_id](control_context, x, x_mask=x_attn_mask, freqs_cis=x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input=adaln_input)
+ if layer_id > 0:
+ out = None
+ for i in range(1, len(self.control_layers)):
+ o, control_context = self.control_layers[i](control_context, x, x_mask=x_attn_mask, freqs_cis=x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input=adaln_input)
+ if out is None:
+ out = o
+
+ return (out, control_context)
+ else:
+ return self.control_noise_refiner[layer_id](control_context, x, x_mask=x_attn_mask, freqs_cis=x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input=adaln_input)
+ else:
+ return (None, control_context)
+
def forward_control_block(self, layer_id, control_context, x, x_attn_mask, x_freqs_cis, adaln_input):
return self.control_layers[layer_id](control_context, x, x_mask=x_attn_mask, freqs_cis=x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input=adaln_input)
diff --git a/comfy/ldm/lumina/model.py b/comfy/ldm/lumina/model.py
index c47df49ca..96cb37fa6 100644
--- a/comfy/ldm/lumina/model.py
+++ b/comfy/ldm/lumina/model.py
@@ -536,6 +536,7 @@ class NextDiT(nn.Module):
bsz = len(x)
pH = pW = self.patch_size
device = x[0].device
+ orig_x = x
if self.pad_tokens_multiple is not None:
pad_extra = (-cap_feats.shape[1]) % self.pad_tokens_multiple
@@ -572,13 +573,21 @@ class NextDiT(nn.Module):
freqs_cis = self.rope_embedder(torch.cat((cap_pos_ids, x_pos_ids), dim=1)).movedim(1, 2)
+ patches = transformer_options.get("patches", {})
+
# refine context
for layer in self.context_refiner:
cap_feats = layer(cap_feats, cap_mask, freqs_cis[:, :cap_pos_ids.shape[1]], transformer_options=transformer_options)
padded_img_mask = None
- for layer in self.noise_refiner:
+ x_input = x
+ for i, layer in enumerate(self.noise_refiner):
x = layer(x, padded_img_mask, freqs_cis[:, cap_pos_ids.shape[1]:], t, transformer_options=transformer_options)
+ if "noise_refiner" in patches:
+ for p in patches["noise_refiner"]:
+ out = p({"img": x, "img_input": x_input, "txt": cap_feats, "pe": freqs_cis[:, cap_pos_ids.shape[1]:], "vec": t, "x": orig_x, "block_index": i, "transformer_options": transformer_options, "block_type": "noise_refiner"})
+ if "img" in out:
+ x = out["img"]
padded_full_embed = torch.cat((cap_feats, x), dim=1)
mask = None
@@ -622,14 +631,15 @@ class NextDiT(nn.Module):
patches = transformer_options.get("patches", {})
x_is_tensor = isinstance(x, torch.Tensor)
- img, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, t, num_tokens, transformer_options=transformer_options)
+ img, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, adaln_input, num_tokens, transformer_options=transformer_options)
freqs_cis = freqs_cis.to(img.device)
+ img_input = img
for i, layer in enumerate(self.layers):
img = layer(img, mask, freqs_cis, adaln_input, transformer_options=transformer_options)
if "double_block" in patches:
for p in patches["double_block"]:
- out = p({"img": img[:, cap_size[0]:], "txt": img[:, :cap_size[0]], "pe": freqs_cis[:, cap_size[0]:], "vec": adaln_input, "x": x, "block_index": i, "transformer_options": transformer_options})
+ out = p({"img": img[:, cap_size[0]:], "img_input": img_input[:, cap_size[0]:], "txt": img[:, :cap_size[0]], "pe": freqs_cis[:, cap_size[0]:], "vec": adaln_input, "x": x, "block_index": i, "transformer_options": transformer_options})
if "img" in out:
img[:, cap_size[0]:] = out["img"]
if "txt" in out:
diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py
index a486c2723..93d26c690 100644
--- a/comfy/model_patcher.py
+++ b/comfy/model_patcher.py
@@ -454,6 +454,9 @@ class ModelPatcher:
def set_model_post_input_patch(self, patch):
self.set_model_patch(patch, "post_input")
+ def set_model_noise_refiner_patch(self, patch):
+ self.set_model_patch(patch, "noise_refiner")
+
def set_model_rope_options(self, scale_x, shift_x, scale_y, shift_y, scale_t, shift_t, **kwargs):
rope_options = self.model_options["transformer_options"].get("rope_options", {})
rope_options["scale_x"] = scale_x
diff --git a/comfy_extras/nodes_model_patch.py b/comfy_extras/nodes_model_patch.py
index c61810dbf..ec0e790dc 100644
--- a/comfy_extras/nodes_model_patch.py
+++ b/comfy_extras/nodes_model_patch.py
@@ -243,7 +243,13 @@ class ModelPatchLoader:
model = SigLIPMultiFeatProjModel(device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast)
elif 'control_all_x_embedder.2-1.weight' in sd: # alipai z image fun controlnet
sd = z_image_convert(sd)
- model = comfy.ldm.lumina.controlnet.ZImage_Control(device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast)
+ config = {}
+ if 'control_layers.14.adaLN_modulation.0.weight' in sd:
+ config['n_control_layers'] = 15
+ config['additional_in_dim'] = 17
+ config['refiner_control'] = True
+ config['broken'] = True
+ model = comfy.ldm.lumina.controlnet.ZImage_Control(device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast, **config)
model.load_state_dict(sd)
model = comfy.model_patcher.ModelPatcher(model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device())
@@ -297,56 +303,86 @@ class DiffSynthCnetPatch:
return [self.model_patch]
class ZImageControlPatch:
- def __init__(self, model_patch, vae, image, strength):
+ def __init__(self, model_patch, vae, image, strength, inpaint_image=None, mask=None):
self.model_patch = model_patch
self.vae = vae
self.image = image
+ self.inpaint_image = inpaint_image
+ self.mask = mask
self.strength = strength
self.encoded_image = self.encode_latent_cond(image)
self.encoded_image_size = (image.shape[1], image.shape[2])
self.temp_data = None
- def encode_latent_cond(self, image):
- latent_image = comfy.latent_formats.Flux().process_in(self.vae.encode(image))
- return latent_image
+ def encode_latent_cond(self, control_image, inpaint_image=None):
+ latent_image = comfy.latent_formats.Flux().process_in(self.vae.encode(control_image))
+ if self.model_patch.model.additional_in_dim > 0:
+ if self.mask is None:
+ mask_ = torch.zeros_like(latent_image)[:, :1]
+ else:
+ mask_ = comfy.utils.common_upscale(self.mask.mean(dim=1, keepdim=True), latent_image.shape[-1], latent_image.shape[-2], "bilinear", "none")
+ if inpaint_image is None:
+ inpaint_image = torch.ones_like(control_image) * 0.5
+
+ inpaint_image_latent = comfy.latent_formats.Flux().process_in(self.vae.encode(inpaint_image))
+
+ return torch.cat([latent_image, mask_, inpaint_image_latent], dim=1)
+ else:
+ return latent_image
def __call__(self, kwargs):
x = kwargs.get("x")
img = kwargs.get("img")
+ img_input = kwargs.get("img_input")
txt = kwargs.get("txt")
pe = kwargs.get("pe")
vec = kwargs.get("vec")
block_index = kwargs.get("block_index")
+ block_type = kwargs.get("block_type", "")
spacial_compression = self.vae.spacial_compression_encode()
if self.encoded_image is None or self.encoded_image_size != (x.shape[-2] * spacial_compression, x.shape[-1] * spacial_compression):
image_scaled = comfy.utils.common_upscale(self.image.movedim(-1, 1), x.shape[-1] * spacial_compression, x.shape[-2] * spacial_compression, "area", "center")
+ inpaint_scaled = None
+ if self.inpaint_image is not None:
+ inpaint_scaled = comfy.utils.common_upscale(self.inpaint_image.movedim(-1, 1), x.shape[-1] * spacial_compression, x.shape[-2] * spacial_compression, "area", "center").movedim(1, -1)
loaded_models = comfy.model_management.loaded_models(only_currently_used=True)
- self.encoded_image = self.encode_latent_cond(image_scaled.movedim(1, -1))
+ self.encoded_image = self.encode_latent_cond(image_scaled.movedim(1, -1), inpaint_scaled)
self.encoded_image_size = (image_scaled.shape[-2], image_scaled.shape[-1])
comfy.model_management.load_models_gpu(loaded_models)
- cnet_index = (block_index // 5)
- cnet_index_float = (block_index / 5)
+ cnet_blocks = self.model_patch.model.n_control_layers
+ div = round(30 / cnet_blocks)
+
+ cnet_index = (block_index // div)
+ cnet_index_float = (block_index / div)
kwargs.pop("img") # we do ops in place
kwargs.pop("txt")
- cnet_blocks = self.model_patch.model.n_control_layers
if cnet_index_float > (cnet_blocks - 1):
self.temp_data = None
return kwargs
if self.temp_data is None or self.temp_data[0] > cnet_index:
- self.temp_data = (-1, (None, self.model_patch.model(txt, self.encoded_image.to(img.dtype), pe, vec)))
+ if block_type == "noise_refiner":
+ self.temp_data = (-3, (None, self.model_patch.model(txt, self.encoded_image.to(img.dtype), pe, vec)))
+ else:
+ self.temp_data = (-1, (None, self.model_patch.model(txt, self.encoded_image.to(img.dtype), pe, vec)))
- while self.temp_data[0] < cnet_index and (self.temp_data[0] + 1) < cnet_blocks:
+ if block_type == "noise_refiner":
next_layer = self.temp_data[0] + 1
- self.temp_data = (next_layer, self.model_patch.model.forward_control_block(next_layer, self.temp_data[1][1], img[:, :self.temp_data[1][1].shape[1]], None, pe, vec))
+ self.temp_data = (next_layer, self.model_patch.model.forward_noise_refiner_block(block_index, self.temp_data[1][1], img_input[:, :self.temp_data[1][1].shape[1]], None, pe, vec))
+ if self.temp_data[1][0] is not None:
+ img[:, :self.temp_data[1][0].shape[1]] += (self.temp_data[1][0] * self.strength)
+ else:
+ while self.temp_data[0] < cnet_index and (self.temp_data[0] + 1) < cnet_blocks:
+ next_layer = self.temp_data[0] + 1
+ self.temp_data = (next_layer, self.model_patch.model.forward_control_block(next_layer, self.temp_data[1][1], img_input[:, :self.temp_data[1][1].shape[1]], None, pe, vec))
- if cnet_index_float == self.temp_data[0]:
- img[:, :self.temp_data[1][0].shape[1]] += (self.temp_data[1][0] * self.strength)
- if cnet_blocks == self.temp_data[0] + 1:
- self.temp_data = None
+ if cnet_index_float == self.temp_data[0]:
+ img[:, :self.temp_data[1][0].shape[1]] += (self.temp_data[1][0] * self.strength)
+ if cnet_blocks == self.temp_data[0] + 1:
+ self.temp_data = None
return kwargs
@@ -386,7 +422,9 @@ class QwenImageDiffsynthControlnet:
mask = 1.0 - mask
if isinstance(model_patch.model, comfy.ldm.lumina.controlnet.ZImage_Control):
- model_patched.set_model_double_block_patch(ZImageControlPatch(model_patch, vae, image, strength))
+ patch = ZImageControlPatch(model_patch, vae, image, strength, mask=mask)
+ model_patched.set_model_noise_refiner_patch(patch)
+ model_patched.set_model_double_block_patch(patch)
else:
model_patched.set_model_double_block_patch(DiffSynthCnetPatch(model_patch, vae, image, strength, mask))
return (model_patched,)
From 971cefe7d4ca15c949d5d901a663cb66562a4f10 Mon Sep 17 00:00:00 2001
From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com>
Date: Sat, 13 Dec 2025 15:45:23 -0800
Subject: [PATCH 79/81] Fix pytorch warnings. (#11314)
---
comfy/ops.py | 2 +-
comfy/utils.py | 2 +-
2 files changed, 2 insertions(+), 2 deletions(-)
diff --git a/comfy/ops.py b/comfy/ops.py
index 0384c8717..16889bb82 100644
--- a/comfy/ops.py
+++ b/comfy/ops.py
@@ -592,7 +592,7 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
quant_conf = {"format": self.quant_format}
if self._full_precision_mm:
quant_conf["full_precision_matrix_mult"] = True
- sd["{}comfy_quant".format(prefix)] = torch.frombuffer(json.dumps(quant_conf).encode('utf-8'), dtype=torch.uint8)
+ sd["{}comfy_quant".format(prefix)] = torch.tensor(list(json.dumps(quant_conf).encode('utf-8')), dtype=torch.uint8)
return sd
def _forward(self, input, weight, bias):
diff --git a/comfy/utils.py b/comfy/utils.py
index 9dc0d76ac..3866cda2e 100644
--- a/comfy/utils.py
+++ b/comfy/utils.py
@@ -1262,6 +1262,6 @@ def convert_old_quants(state_dict, model_prefix="", metadata={}):
if quant_metadata is not None:
layers = quant_metadata["layers"]
for k, v in layers.items():
- state_dict["{}.comfy_quant".format(k)] = torch.frombuffer(json.dumps(v).encode('utf-8'), dtype=torch.uint8)
+ state_dict["{}.comfy_quant".format(k)] = torch.tensor(list(json.dumps(v).encode('utf-8')), dtype=torch.uint8)
return state_dict, metadata
From 6592bffc609da4738b111dbffca1f473972f3574 Mon Sep 17 00:00:00 2001
From: chaObserv <154517000+chaObserv@users.noreply.github.com>
Date: Sun, 14 Dec 2025 13:03:29 +0800
Subject: [PATCH 80/81] seeds_2: add phi_2 variant and sampler node (#11309)
* Add phi_2 solver type to seeds_2
* Add sampler node of seeds_2
---
comfy/k_diffusion/sampling.py | 15 ++++++++++++---
comfy_extras/nodes_custom_sampler.py | 26 ++++++++++++++++++++++++++
2 files changed, 38 insertions(+), 3 deletions(-)
diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py
index 0e2cda291..753c66afa 100644
--- a/comfy/k_diffusion/sampling.py
+++ b/comfy/k_diffusion/sampling.py
@@ -1557,10 +1557,13 @@ def sample_er_sde(model, x, sigmas, extra_args=None, callback=None, disable=None
@torch.no_grad()
-def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=0.5):
+def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=0.5, solver_type="phi_1"):
"""SEEDS-2 - Stochastic Explicit Exponential Derivative-free Solvers (VP Data Prediction) stage 2.
arXiv: https://arxiv.org/abs/2305.14267 (NeurIPS 2023)
"""
+ if solver_type not in {"phi_1", "phi_2"}:
+ raise ValueError("solver_type must be 'phi_1' or 'phi_2'")
+
extra_args = {} if extra_args is None else extra_args
seed = extra_args.get("seed", None)
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
@@ -1600,8 +1603,14 @@ def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=Non
denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args)
# Step 2
- denoised_d = torch.lerp(denoised, denoised_2, fac)
- x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * ei_h_phi_1(-h_eta) * denoised_d
+ if solver_type == "phi_1":
+ denoised_d = torch.lerp(denoised, denoised_2, fac)
+ x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * ei_h_phi_1(-h_eta) * denoised_d
+ elif solver_type == "phi_2":
+ b2 = ei_h_phi_2(-h_eta) / r
+ b1 = ei_h_phi_1(-h_eta) - b2
+ x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * (b1 * denoised + b2 * denoised_2)
+
if inject_noise:
segment_factor = (r - 1) * h * eta
sde_noise = sde_noise * segment_factor.exp()
diff --git a/comfy_extras/nodes_custom_sampler.py b/comfy_extras/nodes_custom_sampler.py
index fbb080886..71ea4e9ec 100644
--- a/comfy_extras/nodes_custom_sampler.py
+++ b/comfy_extras/nodes_custom_sampler.py
@@ -659,6 +659,31 @@ class SamplerSASolver(io.ComfyNode):
get_sampler = execute
+class SamplerSEEDS2(io.ComfyNode):
+ @classmethod
+ def define_schema(cls):
+ return io.Schema(
+ node_id="SamplerSEEDS2",
+ category="sampling/custom_sampling/samplers",
+ inputs=[
+ io.Combo.Input("solver_type", options=["phi_1", "phi_2"]),
+ io.Float.Input("eta", default=1.0, min=0.0, max=100.0, step=0.01, round=False, tooltip="Stochastic strength"),
+ io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False, tooltip="SDE noise multiplier"),
+ io.Float.Input("r", default=0.5, min=0.01, max=1.0, step=0.01, round=False, tooltip="Relative step size for the intermediate stage (c2 node)"),
+ ],
+ outputs=[io.Sampler.Output()]
+ )
+
+ @classmethod
+ def execute(cls, solver_type, eta, s_noise, r) -> io.NodeOutput:
+ sampler_name = "seeds_2"
+ sampler = comfy.samplers.ksampler(
+ sampler_name,
+ {"eta": eta, "s_noise": s_noise, "r": r, "solver_type": solver_type},
+ )
+ return io.NodeOutput(sampler)
+
+
class Noise_EmptyNoise:
def __init__(self):
self.seed = 0
@@ -996,6 +1021,7 @@ class CustomSamplersExtension(ComfyExtension):
SamplerDPMAdaptative,
SamplerER_SDE,
SamplerSASolver,
+ SamplerSEEDS2,
SplitSigmas,
SplitSigmasDenoise,
FlipSigmas,
From 5ac3b26a7dedb9b13c681abe8733c54f13353273 Mon Sep 17 00:00:00 2001
From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com>
Date: Sun, 14 Dec 2025 01:02:50 -0800
Subject: [PATCH 81/81] Update warning for old pytorch version. (#11319)
Versions below 2.4 are no longer supported. We will not break support on purpose but will not fix it if we do.
---
comfy/utils.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/comfy/utils.py b/comfy/utils.py
index 3866cda2e..8d4e2b445 100644
--- a/comfy/utils.py
+++ b/comfy/utils.py
@@ -53,7 +53,7 @@ if hasattr(torch.serialization, "add_safe_globals"): # TODO: this was added in
ALWAYS_SAFE_LOAD = True
logging.info("Checkpoint files will always be loaded safely.")
else:
- logging.info("Warning, you are using an old pytorch version and some ckpt/pt files might be loaded unsafely. Upgrading to 2.4 or above is recommended.")
+ logging.warning("Warning, you are using an old pytorch version and some ckpt/pt files might be loaded unsafely. Upgrading to 2.4 or above is recommended as older versions of pytorch are no longer supported.")
def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
if device is None: