From 56fa7dbe380cb5591c5542f8aa51ce2fc26beedf Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sun, 7 Dec 2025 04:44:55 -0800 Subject: [PATCH 001/148] Properly load the newbie diffusion model. (#11172) There is still one of the text encoders missing and I didn't actually test it. --- comfy/ldm/lumina/model.py | 35 +++++++++++++++++++++++++++++++++++ comfy/model_base.py | 4 ++++ comfy/model_detection.py | 3 +++ 3 files changed, 42 insertions(+) diff --git a/comfy/ldm/lumina/model.py b/comfy/ldm/lumina/model.py index 6c24fed9b..c47df49ca 100644 --- a/comfy/ldm/lumina/model.py +++ b/comfy/ldm/lumina/model.py @@ -377,6 +377,7 @@ class NextDiT(nn.Module): z_image_modulation=False, time_scale=1.0, pad_tokens_multiple=None, + clip_text_dim=None, image_model=None, device=None, dtype=None, @@ -447,6 +448,31 @@ class NextDiT(nn.Module): ), ) + self.clip_text_pooled_proj = None + + if clip_text_dim is not None: + self.clip_text_dim = clip_text_dim + self.clip_text_pooled_proj = nn.Sequential( + operation_settings.get("operations").RMSNorm(clip_text_dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")), + operation_settings.get("operations").Linear( + clip_text_dim, + clip_text_dim, + bias=True, + device=operation_settings.get("device"), + dtype=operation_settings.get("dtype"), + ), + ) + self.time_text_embed = nn.Sequential( + nn.SiLU(), + operation_settings.get("operations").Linear( + min(dim, 1024) + clip_text_dim, + min(dim, 1024), + bias=True, + device=operation_settings.get("device"), + dtype=operation_settings.get("dtype"), + ), + ) + self.layers = nn.ModuleList( [ JointTransformerBlock( @@ -585,6 +611,15 @@ class NextDiT(nn.Module): cap_feats = self.cap_embedder(cap_feats) # (N, L, D) # todo check if able to batchify w.o. 
redundant compute + if self.clip_text_pooled_proj is not None: + pooled = kwargs.get("clip_text_pooled", None) + if pooled is not None: + pooled = self.clip_text_pooled_proj(pooled) + else: + pooled = torch.zeros((1, self.clip_text_dim), device=x.device, dtype=x.dtype) + + adaln_input = self.time_text_embed(torch.cat((t, pooled), dim=-1)) + patches = transformer_options.get("patches", {}) x_is_tensor = isinstance(x, torch.Tensor) img, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, t, num_tokens, transformer_options=transformer_options) diff --git a/comfy/model_base.py b/comfy/model_base.py index 0be006cc2..6b8a8454d 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1110,6 +1110,10 @@ class Lumina2(BaseModel): if 'num_tokens' not in out: out['num_tokens'] = comfy.conds.CONDConstant(cross_attn.shape[1]) + clip_text_pooled = kwargs["pooled_output"] # Newbie + if clip_text_pooled is not None: + out['clip_text_pooled'] = comfy.conds.CONDRegular(clip_text_pooled) + return out class WAN21(BaseModel): diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 30b33a486..74c547427 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -423,6 +423,9 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["axes_lens"] = [300, 512, 512] dit_config["rope_theta"] = 10000.0 dit_config["ffn_dim_multiplier"] = 4.0 + ctd_weight = state_dict.get('{}clip_text_pooled_proj.0.weight'.format(key_prefix), None) + if ctd_weight is not None: + dit_config["clip_text_dim"] = ctd_weight.shape[0] elif dit_config["dim"] == 3840: # Z image dit_config["n_heads"] = 30 dit_config["n_kv_heads"] = 30 From ec7f65187d85e22ea23345ce0d919e11768f255e Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Mon, 8 Dec 2025 11:21:41 +0200 Subject: [PATCH 002/148] chore(comfy_api): replace absolute imports with relative (#11145) --- comfy_api/latest/__init__.py | 8 ++++---- comfy_api/latest/_input/video_types.py | 2 +- comfy_api/latest/_input_impl/video_types.py | 4 ++-- comfy_api/latest/_io.py | 2 +- comfy_api/latest/_ui.py | 2 +- comfy_api/latest/_util/video_types.py | 2 +- 6 files changed, 10 insertions(+), 10 deletions(-) diff --git a/comfy_api/latest/__init__.py b/comfy_api/latest/__init__.py index 0fa01d1e7..35e1ac853 100644 --- a/comfy_api/latest/__init__.py +++ b/comfy_api/latest/__init__.py @@ -5,9 +5,9 @@ from typing import Type, TYPE_CHECKING from comfy_api.internal import ComfyAPIBase from comfy_api.internal.singleton import ProxiedSingleton from comfy_api.internal.async_to_sync import create_sync_class -from comfy_api.latest._input import ImageInput, AudioInput, MaskInput, LatentInput, VideoInput -from comfy_api.latest._input_impl import VideoFromFile, VideoFromComponents -from comfy_api.latest._util import VideoCodec, VideoContainer, VideoComponents, MESH, VOXEL +from ._input import ImageInput, AudioInput, MaskInput, LatentInput, VideoInput +from ._input_impl import VideoFromFile, VideoFromComponents +from ._util import VideoCodec, VideoContainer, VideoComponents, MESH, VOXEL from . import _io_public as io from . import _ui_public as ui # from comfy_api.latest._resources import _RESOURCES as resources #noqa: F401 @@ -80,7 +80,7 @@ class ComfyExtension(ABC): async def on_load(self) -> None: """ Called when an extension is loaded. - This should be used to initialize any global resources neeeded by the extension. 
+ This should be used to initialize any global resources needed by the extension. """ @abstractmethod diff --git a/comfy_api/latest/_input/video_types.py b/comfy_api/latest/_input/video_types.py index 87c81d73a..e634a0311 100644 --- a/comfy_api/latest/_input/video_types.py +++ b/comfy_api/latest/_input/video_types.py @@ -4,7 +4,7 @@ from fractions import Fraction from typing import Optional, Union, IO import io import av -from comfy_api.util import VideoContainer, VideoCodec, VideoComponents +from .._util import VideoContainer, VideoCodec, VideoComponents class VideoInput(ABC): """ diff --git a/comfy_api/latest/_input_impl/video_types.py b/comfy_api/latest/_input_impl/video_types.py index a4cd3737d..ea35c6062 100644 --- a/comfy_api/latest/_input_impl/video_types.py +++ b/comfy_api/latest/_input_impl/video_types.py @@ -3,14 +3,14 @@ from av.container import InputContainer from av.subtitles.stream import SubtitleStream from fractions import Fraction from typing import Optional -from comfy_api.latest._input import AudioInput, VideoInput +from .._input import AudioInput, VideoInput import av import io import json import numpy as np import math import torch -from comfy_api.latest._util import VideoContainer, VideoCodec, VideoComponents +from .._util import VideoContainer, VideoCodec, VideoComponents def container_to_output_format(container_format: str | None) -> str | None: diff --git a/comfy_api/latest/_io.py b/comfy_api/latest/_io.py index d7cbe68cf..313a5af20 100644 --- a/comfy_api/latest/_io.py +++ b/comfy_api/latest/_io.py @@ -26,7 +26,7 @@ if TYPE_CHECKING: from comfy_api.input import VideoInput from comfy_api.internal import (_ComfyNodeInternal, _NodeOutputInternal, classproperty, copy_class, first_real_override, is_class, prune_dict, shallow_clone_class) -from comfy_api.latest._resources import Resources, ResourcesLocal +from ._resources import Resources, ResourcesLocal from comfy_execution.graph_utils import ExecutionBlocker from ._util import MESH, VOXEL diff --git a/comfy_api/latest/_ui.py b/comfy_api/latest/_ui.py index 5a75a3aae..2babe209a 100644 --- a/comfy_api/latest/_ui.py +++ b/comfy_api/latest/_ui.py @@ -22,7 +22,7 @@ import folder_paths # used for image preview from comfy.cli_args import args -from comfy_api.latest._io import ComfyNode, FolderType, Image, _UIOutput +from ._io import ComfyNode, FolderType, Image, _UIOutput class SavedResult(dict): diff --git a/comfy_api/latest/_util/video_types.py b/comfy_api/latest/_util/video_types.py index c3e3d8e3a..fd3b5a510 100644 --- a/comfy_api/latest/_util/video_types.py +++ b/comfy_api/latest/_util/video_types.py @@ -3,7 +3,7 @@ from dataclasses import dataclass from enum import Enum from fractions import Fraction from typing import Optional -from comfy_api.latest._input import ImageInput, AudioInput +from .._input import ImageInput, AudioInput class VideoCodec(str, Enum): AUTO = "auto" From 058f084371ef2ed0c456118dfdd3d0bfed17259b Mon Sep 17 00:00:00 2001 From: ComfyUI Wiki Date: Mon, 8 Dec 2025 17:22:51 +0800 Subject: [PATCH 003/148] Update workflow templates to v0.7.51 (#11150) * chore: update workflow templates to v0.7.50 * Update template to 0.7.51 --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index f98848e20..12a7c1089 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.33.10 -comfyui-workflow-templates==0.7.25 +comfyui-workflow-templates==0.7.51 comfyui-embedded-docs==0.3.1 torch torchsde From 
85c4b4ae262c2de360891dd23c6504da2f5a6014 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Mon, 8 Dec 2025 11:27:02 +0200 Subject: [PATCH 004/148] chore: replace imports of deprecated V1 classes (#11127) --- comfy_api_nodes/apis/veo_api.py | 2 +- comfy_api_nodes/nodes_gemini.py | 19 ++++++++++--------- comfy_api_nodes/nodes_ltxv.py | 17 +++++++---------- comfy_api_nodes/nodes_moonvalley.py | 19 ++++++++----------- comfy_api_nodes/nodes_runway.py | 29 +++++++++++++---------------- comfy_api_nodes/nodes_veo2.py | 12 +++++------- comfy_extras/nodes_video.py | 27 +++++++++++---------------- 7 files changed, 55 insertions(+), 70 deletions(-) diff --git a/comfy_api_nodes/apis/veo_api.py b/comfy_api_nodes/apis/veo_api.py index 8328d1aa4..23ca725b7 100644 --- a/comfy_api_nodes/apis/veo_api.py +++ b/comfy_api_nodes/apis/veo_api.py @@ -85,7 +85,7 @@ class Response1(BaseModel): raiMediaFilteredReasons: Optional[list[str]] = Field( None, description='Reasons why media was filtered by responsible AI policies' ) - videos: Optional[list[Video]] = None + videos: Optional[list[Video]] = Field(None) class VeoGenVidPollResponse(BaseModel): diff --git a/comfy_api_nodes/nodes_gemini.py b/comfy_api_nodes/nodes_gemini.py index 08f7b0f64..0b7422ef7 100644 --- a/comfy_api_nodes/nodes_gemini.py +++ b/comfy_api_nodes/nodes_gemini.py @@ -13,8 +13,7 @@ import torch from typing_extensions import override import folder_paths -from comfy_api.latest import IO, ComfyExtension, Input -from comfy_api.util import VideoCodec, VideoContainer +from comfy_api.latest import IO, ComfyExtension, Input, Types from comfy_api_nodes.apis.gemini_api import ( GeminiContent, GeminiFileData, @@ -68,7 +67,7 @@ class GeminiImageModel(str, Enum): async def create_image_parts( cls: type[IO.ComfyNode], - images: torch.Tensor, + images: Input.Image, image_limit: int = 0, ) -> list[GeminiPart]: image_parts: list[GeminiPart] = [] @@ -154,8 +153,8 @@ def get_text_from_response(response: GeminiGenerateContentResponse) -> str: return "\n".join([part.text for part in parts]) -def get_image_from_response(response: GeminiGenerateContentResponse) -> torch.Tensor: - image_tensors: list[torch.Tensor] = [] +def get_image_from_response(response: GeminiGenerateContentResponse) -> Input.Image: + image_tensors: list[Input.Image] = [] parts = get_parts_by_type(response, "image/png") for part in parts: image_data = base64.b64decode(part.inlineData.data) @@ -293,7 +292,9 @@ class GeminiNode(IO.ComfyNode): def create_video_parts(cls, video_input: Input.Video) -> list[GeminiPart]: """Convert video input to Gemini API compatible parts.""" - base_64_string = video_to_base64_string(video_input, container_format=VideoContainer.MP4, codec=VideoCodec.H264) + base_64_string = video_to_base64_string( + video_input, container_format=Types.VideoContainer.MP4, codec=Types.VideoCodec.H264 + ) return [ GeminiPart( inlineData=GeminiInlineData( @@ -343,7 +344,7 @@ class GeminiNode(IO.ComfyNode): prompt: str, model: str, seed: int, - images: torch.Tensor | None = None, + images: Input.Image | None = None, audio: Input.Audio | None = None, video: Input.Video | None = None, files: list[GeminiPart] | None = None, @@ -542,7 +543,7 @@ class GeminiImage(IO.ComfyNode): prompt: str, model: str, seed: int, - images: torch.Tensor | None = None, + images: Input.Image | None = None, files: list[GeminiPart] | None = None, aspect_ratio: str = "auto", response_modalities: str = "IMAGE+TEXT", @@ -662,7 +663,7 @@ class GeminiImage2(IO.ComfyNode): 
aspect_ratio: str, resolution: str, response_modalities: str, - images: torch.Tensor | None = None, + images: Input.Image | None = None, files: list[GeminiPart] | None = None, ) -> IO.NodeOutput: validate_string(prompt, strip_whitespace=True, min_length=1) diff --git a/comfy_api_nodes/nodes_ltxv.py b/comfy_api_nodes/nodes_ltxv.py index 0b757a62b..7e61560dc 100644 --- a/comfy_api_nodes/nodes_ltxv.py +++ b/comfy_api_nodes/nodes_ltxv.py @@ -1,12 +1,9 @@ from io import BytesIO -from typing import Optional -import torch from pydantic import BaseModel, Field from typing_extensions import override -from comfy_api.input_impl import VideoFromFile -from comfy_api.latest import IO, ComfyExtension +from comfy_api.latest import IO, ComfyExtension, Input, InputImpl from comfy_api_nodes.util import ( ApiEndpoint, get_number_of_images, @@ -26,9 +23,9 @@ class ExecuteTaskRequest(BaseModel): model: str = Field(...) duration: int = Field(...) resolution: str = Field(...) - fps: Optional[int] = Field(25) - generate_audio: Optional[bool] = Field(True) - image_uri: Optional[str] = Field(None) + fps: int | None = Field(25) + generate_audio: bool | None = Field(True) + image_uri: str | None = Field(None) class TextToVideoNode(IO.ComfyNode): @@ -103,7 +100,7 @@ class TextToVideoNode(IO.ComfyNode): as_binary=True, max_retries=1, ) - return IO.NodeOutput(VideoFromFile(BytesIO(response))) + return IO.NodeOutput(InputImpl.VideoFromFile(BytesIO(response))) class ImageToVideoNode(IO.ComfyNode): @@ -153,7 +150,7 @@ class ImageToVideoNode(IO.ComfyNode): @classmethod async def execute( cls, - image: torch.Tensor, + image: Input.Image, model: str, prompt: str, duration: int, @@ -183,7 +180,7 @@ class ImageToVideoNode(IO.ComfyNode): as_binary=True, max_retries=1, ) - return IO.NodeOutput(VideoFromFile(BytesIO(response))) + return IO.NodeOutput(InputImpl.VideoFromFile(BytesIO(response))) class LtxvApiExtension(ComfyExtension): diff --git a/comfy_api_nodes/nodes_moonvalley.py b/comfy_api_nodes/nodes_moonvalley.py index 7c31d95b3..2771e4790 100644 --- a/comfy_api_nodes/nodes_moonvalley.py +++ b/comfy_api_nodes/nodes_moonvalley.py @@ -1,11 +1,8 @@ import logging -from typing import Optional -import torch from typing_extensions import override -from comfy_api.input import VideoInput -from comfy_api.latest import IO, ComfyExtension +from comfy_api.latest import IO, ComfyExtension, Input from comfy_api_nodes.apis import ( MoonvalleyPromptResponse, MoonvalleyTextToVideoInferenceParams, @@ -61,7 +58,7 @@ def validate_task_creation_response(response) -> None: raise RuntimeError(error_msg) -def validate_video_to_video_input(video: VideoInput) -> VideoInput: +def validate_video_to_video_input(video: Input.Video) -> Input.Video: """ Validates and processes video input for Moonvalley Video-to-Video generation. @@ -82,7 +79,7 @@ def validate_video_to_video_input(video: VideoInput) -> VideoInput: return _validate_and_trim_duration(video) -def _get_video_dimensions(video: VideoInput) -> tuple[int, int]: +def _get_video_dimensions(video: Input.Video) -> tuple[int, int]: """Extracts video dimensions with error handling.""" try: return video.get_dimensions() @@ -106,7 +103,7 @@ def _validate_video_dimensions(width: int, height: int) -> None: raise ValueError(f"Resolution {width}x{height} not supported. 
Supported: {supported_list}") -def _validate_and_trim_duration(video: VideoInput) -> VideoInput: +def _validate_and_trim_duration(video: Input.Video) -> Input.Video: """Validates video duration and trims to 5 seconds if needed.""" duration = video.get_duration() _validate_minimum_duration(duration) @@ -119,7 +116,7 @@ def _validate_minimum_duration(duration: float) -> None: raise ValueError("Input video must be at least 5 seconds long.") -def _trim_if_too_long(video: VideoInput, duration: float) -> VideoInput: +def _trim_if_too_long(video: Input.Video, duration: float) -> Input.Video: """Trims video to 5 seconds if longer.""" if duration > 5: return trim_video(video, 5) @@ -241,7 +238,7 @@ class MoonvalleyImg2VideoNode(IO.ComfyNode): @classmethod async def execute( cls, - image: torch.Tensor, + image: Input.Image, prompt: str, negative_prompt: str, resolution: str, @@ -362,9 +359,9 @@ class MoonvalleyVideo2VideoNode(IO.ComfyNode): prompt: str, negative_prompt: str, seed: int, - video: Optional[VideoInput] = None, + video: Input.Video | None = None, control_type: str = "Motion Transfer", - motion_intensity: Optional[int] = 100, + motion_intensity: int | None = 100, steps=33, prompt_adherence=4.5, ) -> IO.NodeOutput: diff --git a/comfy_api_nodes/nodes_runway.py b/comfy_api_nodes/nodes_runway.py index 2fdafbbfe..3c55039c9 100644 --- a/comfy_api_nodes/nodes_runway.py +++ b/comfy_api_nodes/nodes_runway.py @@ -11,12 +11,11 @@ User Guides: """ -from typing import Union, Optional -from typing_extensions import override from enum import Enum -import torch +from typing_extensions import override +from comfy_api.latest import IO, ComfyExtension, Input, InputImpl from comfy_api_nodes.apis import ( RunwayImageToVideoRequest, RunwayImageToVideoResponse, @@ -44,8 +43,6 @@ from comfy_api_nodes.util import ( sync_op, poll_op, ) -from comfy_api.input_impl import VideoFromFile -from comfy_api.latest import ComfyExtension, IO PATH_IMAGE_TO_VIDEO = "/proxy/runway/image_to_video" PATH_TEXT_TO_IMAGE = "/proxy/runway/text_to_image" @@ -80,7 +77,7 @@ class RunwayGen3aAspectRatio(str, Enum): field_1280_768 = "1280:768" -def get_video_url_from_task_status(response: TaskStatusResponse) -> Union[str, None]: +def get_video_url_from_task_status(response: TaskStatusResponse) -> str | None: """Returns the video URL from the task status response if it exists.""" if hasattr(response, "output") and len(response.output) > 0: return response.output[0] @@ -89,13 +86,13 @@ def get_video_url_from_task_status(response: TaskStatusResponse) -> Union[str, N def extract_progress_from_task_status( response: TaskStatusResponse, -) -> Union[float, None]: +) -> float | None: if hasattr(response, "progress") and response.progress is not None: return response.progress * 100 return None -def get_image_url_from_task_status(response: TaskStatusResponse) -> Union[str, None]: +def get_image_url_from_task_status(response: TaskStatusResponse) -> str | None: """Returns the image URL from the task status response if it exists.""" if hasattr(response, "output") and len(response.output) > 0: return response.output[0] @@ -103,7 +100,7 @@ def get_image_url_from_task_status(response: TaskStatusResponse) -> Union[str, N async def get_response( - cls: type[IO.ComfyNode], task_id: str, estimated_duration: Optional[int] = None + cls: type[IO.ComfyNode], task_id: str, estimated_duration: int | None = None ) -> TaskStatusResponse: """Poll the task status until it is finished then get the response.""" return await poll_op( @@ -119,8 +116,8 @@ async def 
get_response( async def generate_video( cls: type[IO.ComfyNode], request: RunwayImageToVideoRequest, - estimated_duration: Optional[int] = None, -) -> VideoFromFile: + estimated_duration: int | None = None, +) -> InputImpl.VideoFromFile: initial_response = await sync_op( cls, endpoint=ApiEndpoint(path=PATH_IMAGE_TO_VIDEO, method="POST"), @@ -193,7 +190,7 @@ class RunwayImageToVideoNodeGen3a(IO.ComfyNode): async def execute( cls, prompt: str, - start_frame: torch.Tensor, + start_frame: Input.Image, duration: str, ratio: str, seed: int, @@ -283,7 +280,7 @@ class RunwayImageToVideoNodeGen4(IO.ComfyNode): async def execute( cls, prompt: str, - start_frame: torch.Tensor, + start_frame: Input.Image, duration: str, ratio: str, seed: int, @@ -381,8 +378,8 @@ class RunwayFirstLastFrameNode(IO.ComfyNode): async def execute( cls, prompt: str, - start_frame: torch.Tensor, - end_frame: torch.Tensor, + start_frame: Input.Image, + end_frame: Input.Image, duration: str, ratio: str, seed: int, @@ -467,7 +464,7 @@ class RunwayTextToImageNode(IO.ComfyNode): cls, prompt: str, ratio: str, - reference_image: Optional[torch.Tensor] = None, + reference_image: Input.Image | None = None, ) -> IO.NodeOutput: validate_string(prompt, min_length=1) diff --git a/comfy_api_nodes/nodes_veo2.py b/comfy_api_nodes/nodes_veo2.py index a54dc13ab..e165b8380 100644 --- a/comfy_api_nodes/nodes_veo2.py +++ b/comfy_api_nodes/nodes_veo2.py @@ -1,11 +1,9 @@ import base64 from io import BytesIO -import torch from typing_extensions import override -from comfy_api.input_impl.video_types import VideoFromFile -from comfy_api.latest import IO, ComfyExtension +from comfy_api.latest import IO, ComfyExtension, Input, InputImpl from comfy_api_nodes.apis.veo_api import ( VeoGenVidPollRequest, VeoGenVidPollResponse, @@ -232,7 +230,7 @@ class VeoVideoGenerationNode(IO.ComfyNode): # Check if video is provided as base64 or URL if hasattr(video, "bytesBase64Encoded") and video.bytesBase64Encoded: - return IO.NodeOutput(VideoFromFile(BytesIO(base64.b64decode(video.bytesBase64Encoded)))) + return IO.NodeOutput(InputImpl.VideoFromFile(BytesIO(base64.b64decode(video.bytesBase64Encoded)))) if hasattr(video, "gcsUri") and video.gcsUri: return IO.NodeOutput(await download_url_to_video_output(video.gcsUri)) @@ -431,8 +429,8 @@ class Veo3FirstLastFrameNode(IO.ComfyNode): aspect_ratio: str, duration: int, seed: int, - first_frame: torch.Tensor, - last_frame: torch.Tensor, + first_frame: Input.Image, + last_frame: Input.Image, model: str, generate_audio: bool, ): @@ -493,7 +491,7 @@ class Veo3FirstLastFrameNode(IO.ComfyNode): if response.videos: video = response.videos[0] if video.bytesBase64Encoded: - return IO.NodeOutput(VideoFromFile(BytesIO(base64.b64decode(video.bytesBase64Encoded)))) + return IO.NodeOutput(InputImpl.VideoFromFile(BytesIO(base64.b64decode(video.bytesBase64Encoded)))) if video.gcsUri: return IO.NodeOutput(await download_url_to_video_output(video.gcsUri)) raise Exception("Video returned but no data or URL was provided") diff --git a/comfy_extras/nodes_video.py b/comfy_extras/nodes_video.py index 6cf6e39bf..c609e03da 100644 --- a/comfy_extras/nodes_video.py +++ b/comfy_extras/nodes_video.py @@ -8,10 +8,7 @@ import json from typing import Optional from typing_extensions import override from fractions import Fraction -from comfy_api.input import AudioInput, ImageInput, VideoInput -from comfy_api.input_impl import VideoFromComponents, VideoFromFile -from comfy_api.util import VideoCodec, VideoComponents, VideoContainer -from comfy_api.latest 
import ComfyExtension, io, ui +from comfy_api.latest import ComfyExtension, io, ui, Input, InputImpl, Types from comfy.cli_args import args class SaveWEBM(io.ComfyNode): @@ -28,7 +25,6 @@ class SaveWEBM(io.ComfyNode): io.Float.Input("fps", default=24.0, min=0.01, max=1000.0, step=0.01), io.Float.Input("crf", default=32.0, min=0, max=63.0, step=1, tooltip="Higher crf means lower quality with a smaller file size, lower crf means higher quality higher filesize."), ], - outputs=[], hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo], is_output_node=True, ) @@ -79,16 +75,15 @@ class SaveVideo(io.ComfyNode): inputs=[ io.Video.Input("video", tooltip="The video to save."), io.String.Input("filename_prefix", default="video/ComfyUI", tooltip="The prefix for the file to save. This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes."), - io.Combo.Input("format", options=VideoContainer.as_input(), default="auto", tooltip="The format to save the video as."), - io.Combo.Input("codec", options=VideoCodec.as_input(), default="auto", tooltip="The codec to use for the video."), + io.Combo.Input("format", options=Types.VideoContainer.as_input(), default="auto", tooltip="The format to save the video as."), + io.Combo.Input("codec", options=Types.VideoCodec.as_input(), default="auto", tooltip="The codec to use for the video."), ], - outputs=[], hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo], is_output_node=True, ) @classmethod - def execute(cls, video: VideoInput, filename_prefix, format: str, codec) -> io.NodeOutput: + def execute(cls, video: Input.Video, filename_prefix, format: str, codec) -> io.NodeOutput: width, height = video.get_dimensions() full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path( filename_prefix, @@ -105,10 +100,10 @@ class SaveVideo(io.ComfyNode): metadata["prompt"] = cls.hidden.prompt if len(metadata) > 0: saved_metadata = metadata - file = f"{filename}_{counter:05}_.{VideoContainer.get_extension(format)}" + file = f"{filename}_{counter:05}_.{Types.VideoContainer.get_extension(format)}" video.save_to( os.path.join(full_output_folder, file), - format=VideoContainer(format), + format=Types.VideoContainer(format), codec=codec, metadata=saved_metadata ) @@ -135,9 +130,9 @@ class CreateVideo(io.ComfyNode): ) @classmethod - def execute(cls, images: ImageInput, fps: float, audio: Optional[AudioInput] = None) -> io.NodeOutput: + def execute(cls, images: Input.Image, fps: float, audio: Optional[Input.Audio] = None) -> io.NodeOutput: return io.NodeOutput( - VideoFromComponents(VideoComponents(images=images, audio=audio, frame_rate=Fraction(fps))) + InputImpl.VideoFromComponents(Types.VideoComponents(images=images, audio=audio, frame_rate=Fraction(fps))) ) class GetVideoComponents(io.ComfyNode): @@ -159,11 +154,11 @@ class GetVideoComponents(io.ComfyNode): ) @classmethod - def execute(cls, video: VideoInput) -> io.NodeOutput: + def execute(cls, video: Input.Video) -> io.NodeOutput: components = video.get_components() - return io.NodeOutput(components.images, components.audio, float(components.frame_rate)) + class LoadVideo(io.ComfyNode): @classmethod def define_schema(cls): @@ -185,7 +180,7 @@ class LoadVideo(io.ComfyNode): @classmethod def execute(cls, file) -> io.NodeOutput: video_path = folder_paths.get_annotated_filepath(file) - return io.NodeOutput(VideoFromFile(video_path)) + return io.NodeOutput(InputImpl.VideoFromFile(video_path)) @classmethod def fingerprint_inputs(s, 
file): From c3c6313fc7b24a5811efde7cfe10b7cbbea52663 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Mon, 8 Dec 2025 11:28:17 +0200 Subject: [PATCH 005/148] Added "system_prompt" input to Gemini nodes (#11177) --- comfy_api_nodes/apis/gemini_api.py | 10 +----- comfy_api_nodes/nodes_gemini.py | 52 ++++++++++++++++++++++++++++-- 2 files changed, 51 insertions(+), 11 deletions(-) diff --git a/comfy_api_nodes/apis/gemini_api.py b/comfy_api_nodes/apis/gemini_api.py index a380ecc86..f8edc38c9 100644 --- a/comfy_api_nodes/apis/gemini_api.py +++ b/comfy_api_nodes/apis/gemini_api.py @@ -84,15 +84,7 @@ class GeminiSystemInstructionContent(BaseModel): description="A list of ordered parts that make up a single message. " "Different parts may have different IANA MIME types.", ) - role: GeminiRole = Field( - ..., - description="The identity of the entity that creates the message. " - "The following values are supported: " - "user: This indicates that the message is sent by a real person, typically a user-generated message. " - "model: This indicates that the message is generated by the model. " - "The model value is used to insert messages from model into the conversation during multi-turn conversations. " - "For non-multi-turn conversations, this field can be left blank or unset.", - ) + role: GeminiRole | None = Field(..., description="The role field of systemInstruction may be ignored.") class GeminiFunctionDeclaration(BaseModel): diff --git a/comfy_api_nodes/nodes_gemini.py b/comfy_api_nodes/nodes_gemini.py index 0b7422ef7..ad0f4b4d1 100644 --- a/comfy_api_nodes/nodes_gemini.py +++ b/comfy_api_nodes/nodes_gemini.py @@ -26,6 +26,8 @@ from comfy_api_nodes.apis.gemini_api import ( GeminiMimeType, GeminiPart, GeminiRole, + GeminiSystemInstructionContent, + GeminiTextPart, Modality, ) from comfy_api_nodes.util import ( @@ -42,6 +44,14 @@ from comfy_api_nodes.util import ( GEMINI_BASE_ENDPOINT = "/proxy/vertexai/gemini" GEMINI_MAX_INPUT_FILE_SIZE = 20 * 1024 * 1024 # 20 MB +GEMINI_IMAGE_SYS_PROMPT = ( + "You are an expert image-generation engine. You must ALWAYS produce an image.\n" + "Interpret all user input—regardless of " + "format, intent, or abstraction—as literal visual directives for image composition.\n" + "If a prompt is conversational or lacks specific visual details, " + "you must creatively invent a concrete visual scenario that depicts the concept.\n" + "Prioritize generating the visual representation above any text, formatting, or conversational requests." +) class GeminiModel(str, Enum): @@ -276,6 +286,13 @@ class GeminiNode(IO.ComfyNode): tooltip="Optional file(s) to use as context for the model. 
" "Accepts inputs from the Gemini Generate Content Input Files node.", ), + IO.String.Input( + "system_prompt", + multiline=True, + default="", + optional=True, + tooltip="Foundational instructions that dictate an AI's behavior.", + ), ], outputs=[ IO.String.Output(), @@ -348,6 +365,7 @@ class GeminiNode(IO.ComfyNode): audio: Input.Audio | None = None, video: Input.Video | None = None, files: list[GeminiPart] | None = None, + system_prompt: str = "", ) -> IO.NodeOutput: validate_string(prompt, strip_whitespace=False) @@ -364,7 +382,10 @@ class GeminiNode(IO.ComfyNode): if files is not None: parts.extend(files) - # Create response + gemini_system_prompt = None + if system_prompt: + gemini_system_prompt = GeminiSystemInstructionContent(parts=[GeminiTextPart(text=system_prompt)], role=None) + response = await sync_op( cls, endpoint=ApiEndpoint(path=f"{GEMINI_BASE_ENDPOINT}/{model}", method="POST"), @@ -374,7 +395,8 @@ class GeminiNode(IO.ComfyNode): role=GeminiRole.user, parts=parts, ) - ] + ], + systemInstruction=gemini_system_prompt, ), response_model=GeminiGenerateContentResponse, price_extractor=calculate_tokens_price, @@ -524,6 +546,13 @@ class GeminiImage(IO.ComfyNode): "'IMAGE+TEXT' to return both the generated image and a text response.", optional=True, ), + IO.String.Input( + "system_prompt", + multiline=True, + default=GEMINI_IMAGE_SYS_PROMPT, + optional=True, + tooltip="Foundational instructions that dictate an AI's behavior.", + ), ], outputs=[ IO.Image.Output(), @@ -547,6 +576,7 @@ class GeminiImage(IO.ComfyNode): files: list[GeminiPart] | None = None, aspect_ratio: str = "auto", response_modalities: str = "IMAGE+TEXT", + system_prompt: str = "", ) -> IO.NodeOutput: validate_string(prompt, strip_whitespace=True, min_length=1) parts: list[GeminiPart] = [GeminiPart(text=prompt)] @@ -560,6 +590,10 @@ class GeminiImage(IO.ComfyNode): if files is not None: parts.extend(files) + gemini_system_prompt = None + if system_prompt: + gemini_system_prompt = GeminiSystemInstructionContent(parts=[GeminiTextPart(text=system_prompt)], role=None) + response = await sync_op( cls, endpoint=ApiEndpoint(path=f"{GEMINI_BASE_ENDPOINT}/{model}", method="POST"), @@ -571,6 +605,7 @@ class GeminiImage(IO.ComfyNode): responseModalities=(["IMAGE"] if response_modalities == "IMAGE" else ["TEXT", "IMAGE"]), imageConfig=None if aspect_ratio == "auto" else image_config, ), + systemInstruction=gemini_system_prompt, ), response_model=GeminiGenerateContentResponse, price_extractor=calculate_tokens_price, @@ -641,6 +676,13 @@ class GeminiImage2(IO.ComfyNode): tooltip="Optional file(s) to use as context for the model. 
" "Accepts inputs from the Gemini Generate Content Input Files node.", ), + IO.String.Input( + "system_prompt", + multiline=True, + default=GEMINI_IMAGE_SYS_PROMPT, + optional=True, + tooltip="Foundational instructions that dictate an AI's behavior.", + ), ], outputs=[ IO.Image.Output(), @@ -665,6 +707,7 @@ class GeminiImage2(IO.ComfyNode): response_modalities: str, images: Input.Image | None = None, files: list[GeminiPart] | None = None, + system_prompt: str = "", ) -> IO.NodeOutput: validate_string(prompt, strip_whitespace=True, min_length=1) @@ -680,6 +723,10 @@ class GeminiImage2(IO.ComfyNode): if aspect_ratio != "auto": image_config.aspectRatio = aspect_ratio + gemini_system_prompt = None + if system_prompt: + gemini_system_prompt = GeminiSystemInstructionContent(parts=[GeminiTextPart(text=system_prompt)], role=None) + response = await sync_op( cls, ApiEndpoint(path=f"{GEMINI_BASE_ENDPOINT}/{model}", method="POST"), @@ -691,6 +738,7 @@ class GeminiImage2(IO.ComfyNode): responseModalities=(["IMAGE"] if response_modalities == "IMAGE" else ["TEXT", "IMAGE"]), imageConfig=image_config, ), + systemInstruction=gemini_system_prompt, ), response_model=GeminiGenerateContentResponse, price_extractor=calculate_tokens_price, From fd271dedfde6e192a1f1a025521070876e89e04a Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Mon, 8 Dec 2025 11:33:46 +0200 Subject: [PATCH 006/148] [API Nodes] add support for seedance-1-0-pro-fast model (#10947) * feat(api-nodes): add support for seedance-1-0-pro-fast model * feat(api-nodes): add support for seedream-4.5 model --- comfy_api_nodes/apis/bytedance_api.py | 144 +++++++++++++++ comfy_api_nodes/nodes_bytedance.py | 255 ++++++-------------------- 2 files changed, 196 insertions(+), 203 deletions(-) create mode 100644 comfy_api_nodes/apis/bytedance_api.py diff --git a/comfy_api_nodes/apis/bytedance_api.py b/comfy_api_nodes/apis/bytedance_api.py new file mode 100644 index 000000000..77cd76f9b --- /dev/null +++ b/comfy_api_nodes/apis/bytedance_api.py @@ -0,0 +1,144 @@ +from typing import Literal + +from pydantic import BaseModel, Field + + +class Text2ImageTaskCreationRequest(BaseModel): + model: str = Field(...) + prompt: str = Field(...) + response_format: str | None = Field("url") + size: str | None = Field(None) + seed: int | None = Field(0, ge=0, le=2147483647) + guidance_scale: float | None = Field(..., ge=1.0, le=10.0) + watermark: bool | None = Field(True) + + +class Image2ImageTaskCreationRequest(BaseModel): + model: str = Field(...) + prompt: str = Field(...) + response_format: str | None = Field("url") + image: str = Field(..., description="Base64 encoded string or image URL") + size: str | None = Field("adaptive") + seed: int | None = Field(..., ge=0, le=2147483647) + guidance_scale: float | None = Field(..., ge=1.0, le=10.0) + watermark: bool | None = Field(True) + + +class Seedream4Options(BaseModel): + max_images: int = Field(15) + + +class Seedream4TaskCreationRequest(BaseModel): + model: str = Field(...) + prompt: str = Field(...) + response_format: str = Field("url") + image: list[str] | None = Field(None, description="Image URLs") + size: str = Field(...) + seed: int = Field(..., ge=0, le=2147483647) + sequential_image_generation: str = Field("disabled") + sequential_image_generation_options: Seedream4Options = Field(Seedream4Options(max_images=15)) + watermark: bool = Field(True) + + +class ImageTaskCreationResponse(BaseModel): + model: str = Field(...) 
+ created: int = Field(..., description="Unix timestamp (in seconds) indicating time when the request was created.") + data: list = Field([], description="Contains information about the generated image(s).") + error: dict = Field({}, description="Contains `code` and `message` fields in case of error.") + + +class TaskTextContent(BaseModel): + type: str = Field("text") + text: str = Field(...) + + +class TaskImageContentUrl(BaseModel): + url: str = Field(...) + + +class TaskImageContent(BaseModel): + type: str = Field("image_url") + image_url: TaskImageContentUrl = Field(...) + role: Literal["first_frame", "last_frame", "reference_image"] | None = Field(None) + + +class Text2VideoTaskCreationRequest(BaseModel): + model: str = Field(...) + content: list[TaskTextContent] = Field(..., min_length=1) + + +class Image2VideoTaskCreationRequest(BaseModel): + model: str = Field(...) + content: list[TaskTextContent | TaskImageContent] = Field(..., min_length=2) + + +class TaskCreationResponse(BaseModel): + id: str = Field(...) + + +class TaskStatusError(BaseModel): + code: str = Field(...) + message: str = Field(...) + + +class TaskStatusResult(BaseModel): + video_url: str = Field(...) + + +class TaskStatusResponse(BaseModel): + id: str = Field(...) + model: str = Field(...) + status: Literal["queued", "running", "cancelled", "succeeded", "failed"] = Field(...) + error: TaskStatusError | None = Field(None) + content: TaskStatusResult | None = Field(None) + + +RECOMMENDED_PRESETS = [ + ("1024x1024 (1:1)", 1024, 1024), + ("864x1152 (3:4)", 864, 1152), + ("1152x864 (4:3)", 1152, 864), + ("1280x720 (16:9)", 1280, 720), + ("720x1280 (9:16)", 720, 1280), + ("832x1248 (2:3)", 832, 1248), + ("1248x832 (3:2)", 1248, 832), + ("1512x648 (21:9)", 1512, 648), + ("2048x2048 (1:1)", 2048, 2048), + ("Custom", None, None), +] + +RECOMMENDED_PRESETS_SEEDREAM_4 = [ + ("2048x2048 (1:1)", 2048, 2048), + ("2304x1728 (4:3)", 2304, 1728), + ("1728x2304 (3:4)", 1728, 2304), + ("2560x1440 (16:9)", 2560, 1440), + ("1440x2560 (9:16)", 1440, 2560), + ("2496x1664 (3:2)", 2496, 1664), + ("1664x2496 (2:3)", 1664, 2496), + ("3024x1296 (21:9)", 3024, 1296), + ("4096x4096 (1:1)", 4096, 4096), + ("Custom", None, None), +] + +# The time in this dictionary are given for 10 seconds duration. 
+VIDEO_TASKS_EXECUTION_TIME = { + "seedance-1-0-lite-t2v-250428": { + "480p": 40, + "720p": 60, + "1080p": 90, + }, + "seedance-1-0-lite-i2v-250428": { + "480p": 40, + "720p": 60, + "1080p": 90, + }, + "seedance-1-0-pro-250528": { + "480p": 70, + "720p": 85, + "1080p": 115, + }, + "seedance-1-0-pro-fast-251015": { + "480p": 50, + "720p": 65, + "1080p": 100, + }, +} diff --git a/comfy_api_nodes/nodes_bytedance.py b/comfy_api_nodes/nodes_bytedance.py index caced471e..57c0218d0 100644 --- a/comfy_api_nodes/nodes_bytedance.py +++ b/comfy_api_nodes/nodes_bytedance.py @@ -1,13 +1,27 @@ import logging import math -from enum import Enum -from typing import Literal, Optional, Union import torch -from pydantic import BaseModel, Field from typing_extensions import override -from comfy_api.latest import IO, ComfyExtension +from comfy_api.latest import IO, ComfyExtension, Input +from comfy_api_nodes.apis.bytedance_api import ( + RECOMMENDED_PRESETS, + RECOMMENDED_PRESETS_SEEDREAM_4, + VIDEO_TASKS_EXECUTION_TIME, + Image2ImageTaskCreationRequest, + Image2VideoTaskCreationRequest, + ImageTaskCreationResponse, + Seedream4Options, + Seedream4TaskCreationRequest, + TaskCreationResponse, + TaskImageContent, + TaskImageContentUrl, + TaskStatusResponse, + TaskTextContent, + Text2ImageTaskCreationRequest, + Text2VideoTaskCreationRequest, +) from comfy_api_nodes.util import ( ApiEndpoint, download_url_to_image_tensor, @@ -29,162 +43,6 @@ BYTEPLUS_TASK_ENDPOINT = "/proxy/byteplus/api/v3/contents/generations/tasks" BYTEPLUS_TASK_STATUS_ENDPOINT = "/proxy/byteplus/api/v3/contents/generations/tasks" # + /{task_id} -class Text2ImageModelName(str, Enum): - seedream_3 = "seedream-3-0-t2i-250415" - - -class Image2ImageModelName(str, Enum): - seededit_3 = "seededit-3-0-i2i-250628" - - -class Text2VideoModelName(str, Enum): - seedance_1_pro = "seedance-1-0-pro-250528" - seedance_1_lite = "seedance-1-0-lite-t2v-250428" - - -class Image2VideoModelName(str, Enum): - """note(August 31): Pro model only supports FirstFrame: https://docs.byteplus.com/en/docs/ModelArk/1520757""" - - seedance_1_pro = "seedance-1-0-pro-250528" - seedance_1_lite = "seedance-1-0-lite-i2v-250428" - - -class Text2ImageTaskCreationRequest(BaseModel): - model: Text2ImageModelName = Text2ImageModelName.seedream_3 - prompt: str = Field(...) - response_format: Optional[str] = Field("url") - size: Optional[str] = Field(None) - seed: Optional[int] = Field(0, ge=0, le=2147483647) - guidance_scale: Optional[float] = Field(..., ge=1.0, le=10.0) - watermark: Optional[bool] = Field(True) - - -class Image2ImageTaskCreationRequest(BaseModel): - model: Image2ImageModelName = Image2ImageModelName.seededit_3 - prompt: str = Field(...) - response_format: Optional[str] = Field("url") - image: str = Field(..., description="Base64 encoded string or image URL") - size: Optional[str] = Field("adaptive") - seed: Optional[int] = Field(..., ge=0, le=2147483647) - guidance_scale: Optional[float] = Field(..., ge=1.0, le=10.0) - watermark: Optional[bool] = Field(True) - - -class Seedream4Options(BaseModel): - max_images: int = Field(15) - - -class Seedream4TaskCreationRequest(BaseModel): - model: str = Field("seedream-4-0-250828") - prompt: str = Field(...) - response_format: str = Field("url") - image: Optional[list[str]] = Field(None, description="Image URLs") - size: str = Field(...) 
- seed: int = Field(..., ge=0, le=2147483647) - sequential_image_generation: str = Field("disabled") - sequential_image_generation_options: Seedream4Options = Field(Seedream4Options(max_images=15)) - watermark: bool = Field(True) - - -class ImageTaskCreationResponse(BaseModel): - model: str = Field(...) - created: int = Field(..., description="Unix timestamp (in seconds) indicating time when the request was created.") - data: list = Field([], description="Contains information about the generated image(s).") - error: dict = Field({}, description="Contains `code` and `message` fields in case of error.") - - -class TaskTextContent(BaseModel): - type: str = Field("text") - text: str = Field(...) - - -class TaskImageContentUrl(BaseModel): - url: str = Field(...) - - -class TaskImageContent(BaseModel): - type: str = Field("image_url") - image_url: TaskImageContentUrl = Field(...) - role: Optional[Literal["first_frame", "last_frame", "reference_image"]] = Field(None) - - -class Text2VideoTaskCreationRequest(BaseModel): - model: Text2VideoModelName = Text2VideoModelName.seedance_1_pro - content: list[TaskTextContent] = Field(..., min_length=1) - - -class Image2VideoTaskCreationRequest(BaseModel): - model: Image2VideoModelName = Image2VideoModelName.seedance_1_pro - content: list[Union[TaskTextContent, TaskImageContent]] = Field(..., min_length=2) - - -class TaskCreationResponse(BaseModel): - id: str = Field(...) - - -class TaskStatusError(BaseModel): - code: str = Field(...) - message: str = Field(...) - - -class TaskStatusResult(BaseModel): - video_url: str = Field(...) - - -class TaskStatusResponse(BaseModel): - id: str = Field(...) - model: str = Field(...) - status: Literal["queued", "running", "cancelled", "succeeded", "failed"] = Field(...) - error: Optional[TaskStatusError] = Field(None) - content: Optional[TaskStatusResult] = Field(None) - - -RECOMMENDED_PRESETS = [ - ("1024x1024 (1:1)", 1024, 1024), - ("864x1152 (3:4)", 864, 1152), - ("1152x864 (4:3)", 1152, 864), - ("1280x720 (16:9)", 1280, 720), - ("720x1280 (9:16)", 720, 1280), - ("832x1248 (2:3)", 832, 1248), - ("1248x832 (3:2)", 1248, 832), - ("1512x648 (21:9)", 1512, 648), - ("2048x2048 (1:1)", 2048, 2048), - ("Custom", None, None), -] - -RECOMMENDED_PRESETS_SEEDREAM_4 = [ - ("2048x2048 (1:1)", 2048, 2048), - ("2304x1728 (4:3)", 2304, 1728), - ("1728x2304 (3:4)", 1728, 2304), - ("2560x1440 (16:9)", 2560, 1440), - ("1440x2560 (9:16)", 1440, 2560), - ("2496x1664 (3:2)", 2496, 1664), - ("1664x2496 (2:3)", 1664, 2496), - ("3024x1296 (21:9)", 3024, 1296), - ("4096x4096 (1:1)", 4096, 4096), - ("Custom", None, None), -] - -# The time in this dictionary are given for 10 seconds duration. -VIDEO_TASKS_EXECUTION_TIME = { - "seedance-1-0-lite-t2v-250428": { - "480p": 40, - "720p": 60, - "1080p": 90, - }, - "seedance-1-0-lite-i2v-250428": { - "480p": 40, - "720p": 60, - "1080p": 90, - }, - "seedance-1-0-pro-250528": { - "480p": 70, - "720p": 85, - "1080p": 115, - }, -} - - def get_image_url_from_response(response: ImageTaskCreationResponse) -> str: if response.error: error_msg = f"ByteDance request failed. 
Code: {response.error['code']}, message: {response.error['message']}" @@ -194,13 +52,6 @@ def get_image_url_from_response(response: ImageTaskCreationResponse) -> str: return response.data[0]["url"] -def get_video_url_from_task_status(response: TaskStatusResponse) -> Union[str, None]: - """Returns the video URL from the task status response if it exists.""" - if hasattr(response, "content") and response.content: - return response.content.video_url - return None - - class ByteDanceImageNode(IO.ComfyNode): @classmethod @@ -211,12 +62,7 @@ class ByteDanceImageNode(IO.ComfyNode): category="api node/image/ByteDance", description="Generate images using ByteDance models via api based on prompt", inputs=[ - IO.Combo.Input( - "model", - options=Text2ImageModelName, - default=Text2ImageModelName.seedream_3, - tooltip="Model name", - ), + IO.Combo.Input("model", options=["seedream-3-0-t2i-250415"]), IO.String.Input( "prompt", multiline=True, @@ -335,12 +181,7 @@ class ByteDanceImageEditNode(IO.ComfyNode): category="api node/image/ByteDance", description="Edit images using ByteDance models via api based on prompt", inputs=[ - IO.Combo.Input( - "model", - options=Image2ImageModelName, - default=Image2ImageModelName.seededit_3, - tooltip="Model name", - ), + IO.Combo.Input("model", options=["seededit-3-0-i2i-250628"]), IO.Image.Input( "image", tooltip="The base image to edit", @@ -394,7 +235,7 @@ class ByteDanceImageEditNode(IO.ComfyNode): async def execute( cls, model: str, - image: torch.Tensor, + image: Input.Image, prompt: str, seed: int, guidance_scale: float, @@ -434,7 +275,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode): inputs=[ IO.Combo.Input( "model", - options=["seedream-4-0-250828"], + options=["seedream-4-5-251128", "seedream-4-0-250828"], tooltip="Model name", ), IO.String.Input( @@ -459,7 +300,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode): default=2048, min=1024, max=4096, - step=64, + step=8, tooltip="Custom width for image. Value is working only if `size_preset` is set to `Custom`", optional=True, ), @@ -468,7 +309,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode): default=2048, min=1024, max=4096, - step=64, + step=8, tooltip="Custom height for image. Value is working only if `size_preset` is set to `Custom`", optional=True, ), @@ -532,7 +373,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode): cls, model: str, prompt: str, - image: torch.Tensor = None, + image: Input.Image | None = None, size_preset: str = RECOMMENDED_PRESETS_SEEDREAM_4[0][0], width: int = 2048, height: int = 2048, @@ -555,6 +396,18 @@ class ByteDanceSeedreamNode(IO.ComfyNode): raise ValueError( f"Custom size out of range: {w}x{h}. " "Both width and height must be between 1024 and 4096 pixels." ) + out_num_pixels = w * h + mp_provided = out_num_pixels / 1_000_000.0 + if "seedream-4-5" in model and out_num_pixels < 3686400: + raise ValueError( + f"Minimum image resolution that Seedream 4.5 can generate is 3.68MP, " + f"but {mp_provided:.2f}MP provided." + ) + if "seedream-4-0" in model and out_num_pixels < 921600: + raise ValueError( + f"Minimum image resolution that the selected model can generate is 0.92MP, " + f"but {mp_provided:.2f}MP provided." 
+ ) n_input_images = get_number_of_images(image) if image is not None else 0 if n_input_images > 10: raise ValueError(f"Maximum of 10 reference images are supported, but {n_input_images} received.") @@ -607,9 +460,8 @@ class ByteDanceTextToVideoNode(IO.ComfyNode): inputs=[ IO.Combo.Input( "model", - options=Text2VideoModelName, - default=Text2VideoModelName.seedance_1_pro, - tooltip="Model name", + options=["seedance-1-0-pro-250528", "seedance-1-0-lite-t2v-250428", "seedance-1-0-pro-fast-251015"], + default="seedance-1-0-pro-fast-251015", ), IO.String.Input( "prompt", @@ -714,9 +566,8 @@ class ByteDanceImageToVideoNode(IO.ComfyNode): inputs=[ IO.Combo.Input( "model", - options=Image2VideoModelName, - default=Image2VideoModelName.seedance_1_pro, - tooltip="Model name", + options=["seedance-1-0-pro-250528", "seedance-1-0-lite-t2v-250428", "seedance-1-0-pro-fast-251015"], + default="seedance-1-0-pro-fast-251015", ), IO.String.Input( "prompt", @@ -787,7 +638,7 @@ class ByteDanceImageToVideoNode(IO.ComfyNode): cls, model: str, prompt: str, - image: torch.Tensor, + image: Input.Image, resolution: str, aspect_ratio: str, duration: int, @@ -833,9 +684,8 @@ class ByteDanceFirstLastFrameNode(IO.ComfyNode): inputs=[ IO.Combo.Input( "model", - options=[model.value for model in Image2VideoModelName], - default=Image2VideoModelName.seedance_1_lite.value, - tooltip="Model name", + options=["seedance-1-0-pro-250528", "seedance-1-0-lite-i2v-250428"], + default="seedance-1-0-lite-i2v-250428", ), IO.String.Input( "prompt", @@ -910,8 +760,8 @@ class ByteDanceFirstLastFrameNode(IO.ComfyNode): cls, model: str, prompt: str, - first_frame: torch.Tensor, - last_frame: torch.Tensor, + first_frame: Input.Image, + last_frame: Input.Image, resolution: str, aspect_ratio: str, duration: int, @@ -968,9 +818,8 @@ class ByteDanceImageReferenceNode(IO.ComfyNode): inputs=[ IO.Combo.Input( "model", - options=[Image2VideoModelName.seedance_1_lite.value], - default=Image2VideoModelName.seedance_1_lite.value, - tooltip="Model name", + options=["seedance-1-0-pro-250528", "seedance-1-0-lite-i2v-250428"], + default="seedance-1-0-lite-i2v-250428", ), IO.String.Input( "prompt", @@ -1034,7 +883,7 @@ class ByteDanceImageReferenceNode(IO.ComfyNode): cls, model: str, prompt: str, - images: torch.Tensor, + images: Input.Image, resolution: str, aspect_ratio: str, duration: int, @@ -1069,8 +918,8 @@ class ByteDanceImageReferenceNode(IO.ComfyNode): async def process_video_task( cls: type[IO.ComfyNode], - payload: Union[Text2VideoTaskCreationRequest, Image2VideoTaskCreationRequest], - estimated_duration: Optional[int], + payload: Text2VideoTaskCreationRequest | Image2VideoTaskCreationRequest, + estimated_duration: int | None, ) -> IO.NodeOutput: initial_response = await sync_op( cls, @@ -1085,7 +934,7 @@ async def process_video_task( estimated_duration=estimated_duration, response_model=TaskStatusResponse, ) - return IO.NodeOutput(await download_url_to_video_output(get_video_url_from_task_status(response))) + return IO.NodeOutput(await download_url_to_video_output(response.content.video_url)) def raise_if_text_params(prompt: str, text_params: list[str]) -> None: From 8e889c535d1fc407bf27dbf8359eef9580f2ed60 Mon Sep 17 00:00:00 2001 From: dxqb <183307934+dxqb@users.noreply.github.com> Date: Mon, 8 Dec 2025 21:17:26 +0100 Subject: [PATCH 007/148] Support "transformer." 
LoRA prefix for Z-Image (#11135) --- comfy/lora.py | 1 + 1 file changed, 1 insertion(+) diff --git a/comfy/lora.py b/comfy/lora.py index e7202ce97..2ed0acb9d 100644 --- a/comfy/lora.py +++ b/comfy/lora.py @@ -320,6 +320,7 @@ def model_lora_keys_unet(model, key_map={}): to = diffusers_keys[k] key_lora = k[:-len(".weight")] key_map["diffusion_model.{}".format(key_lora)] = to + key_map["transformer.{}".format(key_lora)] = to key_map["lycoris_{}".format(key_lora.replace(".", "_"))] = to if isinstance(model, comfy.model_base.Kandinsky5): From 60ee574748209a17ade1c7524e228be2802d1589 Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Tue, 9 Dec 2025 06:18:06 +1000 Subject: [PATCH 008/148] retune lowVramPatch VRAM accounting (#11173) In the lowvram case, this now does its math in the model dtype in the post de-quantization domain. Account for that. The patching was also put back on the compute stream getting it off-peak so relax the MATH_FACTOR to only x2 so get out of the worst-case assumption of everything peaking at once. --- comfy/model_patcher.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 5b1ccb824..8b5edeb52 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -132,14 +132,14 @@ class LowVramPatch: def __call__(self, weight): return comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=weight.dtype) -#The above patch logic may cast up the weight to fp32, and do math. Go with fp32 x 3 -LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR = 3 +LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR = 2 def low_vram_patch_estimate_vram(model, key): weight, set_func, convert_func = get_key_weight(model, key) if weight is None: return 0 - return weight.numel() * torch.float32.itemsize * LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR + model_dtype = getattr(model, "manual_cast_dtype", torch.float32) + return weight.numel() * model_dtype.itemsize * LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR def get_key_weight(model, key): set_func = None From 935493f6c186de8808508713a465d6bda75e5ce4 Mon Sep 17 00:00:00 2001 From: ComfyUI Wiki Date: Tue, 9 Dec 2025 04:18:53 +0800 Subject: [PATCH 009/148] chore: update workflow templates to v0.7.54 (#11192) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 12a7c1089..4bd4b21c3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.33.10 -comfyui-workflow-templates==0.7.51 +comfyui-workflow-templates==0.7.54 comfyui-embedded-docs==0.3.1 torch torchsde From 3b0368aa34182fc7c97de92d59b609c77138def2 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 8 Dec 2025 14:38:36 -0800 Subject: [PATCH 010/148] Fix regression. 
(#11194) --- comfy/model_patcher.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 8b5edeb52..a7d24ac13 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -139,6 +139,9 @@ def low_vram_patch_estimate_vram(model, key): if weight is None: return 0 model_dtype = getattr(model, "manual_cast_dtype", torch.float32) + if model_dtype is None: + model_dtype = weight.dtype + return weight.numel() * model_dtype.itemsize * LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR def get_key_weight(model, key): From d50f342c90802830c1178ad9d7f2783dc2821af1 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 8 Dec 2025 20:20:04 -0800 Subject: [PATCH 011/148] Fix potential issue. (#11201) --- comfy/model_patcher.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index a7d24ac13..2e8ce2613 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -923,7 +923,7 @@ class ModelPatcher: patch_counter += 1 cast_weight = True - if cast_weight: + if cast_weight and hasattr(m, "comfy_cast_weights"): m.prev_comfy_cast_weights = m.comfy_cast_weights m.comfy_cast_weights = True m.comfy_patched_weights = False From e136b6dbb0b08341388f5bf9a00b1fca29992eb3 Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Tue, 9 Dec 2025 14:21:31 +1000 Subject: [PATCH 012/148] dequantization offload accounting (fixes Flux2 OOMs - incl TEs) (#11171) * make setattr safe for non existent attributes Handle the case where the attribute doesnt exist by returning a static sentinel (distinct from None). If the sentinel is passed in as the set value, del the attr. * Account for dequantization and type-casts in offload costs When measuring the cost of offload, identify weights that need a type change or dequantization and add the size of the conversion result to the offload cost. This is mutually exclusive with lowvram patches which already has a large conservative estimate and wont overlap the dequant cost so\ dont double count. * Set the compute type on CLIP MPs So that the loader can know the size of weights for dequant accounting. 
--- comfy/model_patcher.py | 19 +++++++++++++------ comfy/sd.py | 2 ++ comfy/utils.py | 9 +++++++-- 3 files changed, 22 insertions(+), 8 deletions(-) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 2e8ce2613..a486c2723 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -35,6 +35,7 @@ import comfy.model_management import comfy.patcher_extension import comfy.utils from comfy.comfy_types import UnetWrapperFunction +from comfy.quant_ops import QuantizedTensor from comfy.patcher_extension import CallbacksMP, PatcherInjection, WrappersMP @@ -665,12 +666,18 @@ class ModelPatcher: module_mem = comfy.model_management.module_size(m) module_offload_mem = module_mem if hasattr(m, "comfy_cast_weights"): - weight_key = "{}.weight".format(n) - bias_key = "{}.bias".format(n) - if weight_key in self.patches: - module_offload_mem += low_vram_patch_estimate_vram(self.model, weight_key) - if bias_key in self.patches: - module_offload_mem += low_vram_patch_estimate_vram(self.model, bias_key) + def check_module_offload_mem(key): + if key in self.patches: + return low_vram_patch_estimate_vram(self.model, key) + model_dtype = getattr(self.model, "manual_cast_dtype", None) + weight, _, _ = get_key_weight(self.model, key) + if model_dtype is None or weight is None: + return 0 + if (weight.dtype != model_dtype or isinstance(weight, QuantizedTensor)): + return weight.numel() * model_dtype.itemsize + return 0 + module_offload_mem += check_module_offload_mem("{}.weight".format(n)) + module_offload_mem += check_module_offload_mem("{}.bias".format(n)) loading.append((module_offload_mem, module_mem, n, m, params)) return loading diff --git a/comfy/sd.py b/comfy/sd.py index 754b1703d..a16f2d14f 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -127,6 +127,8 @@ class CLIP: self.tokenizer = tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data) self.patcher = comfy.model_patcher.ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device) + #Match torch.float32 hardcode upcast in TE implemention + self.patcher.set_model_compute_dtype(torch.float32) self.patcher.hook_mode = comfy.hooks.EnumHookMode.MinVram self.patcher.is_clip = True self.apply_hooks_to_conds = None diff --git a/comfy/utils.py b/comfy/utils.py index 89846bc95..9dc0d76ac 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -803,12 +803,17 @@ def safetensors_header(safetensors_path, max_size=100*1024*1024): return None return f.read(length_of_header) +ATTR_UNSET={} + def set_attr(obj, attr, value): attrs = attr.split(".") for name in attrs[:-1]: obj = getattr(obj, name) - prev = getattr(obj, attrs[-1]) - setattr(obj, attrs[-1], value) + prev = getattr(obj, attrs[-1], ATTR_UNSET) + if value is ATTR_UNSET: + delattr(obj, attrs[-1]) + else: + setattr(obj, attrs[-1], value) return prev def set_attr_param(obj, attr, value): From cabc4d351ff620ece87f18019d98131ebcbdf1aa Mon Sep 17 00:00:00 2001 From: Christian Byrne Date: Mon, 8 Dec 2025 20:22:02 -0800 Subject: [PATCH 013/148] bump comfyui-frontend-package to 1.33.13 (patch) (#11200) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 4bd4b21c3..11a7ac245 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -comfyui-frontend-package==1.33.10 +comfyui-frontend-package==1.33.13 comfyui-workflow-templates==0.7.54 comfyui-embedded-docs==0.3.1 torch From b9fb542703085c58f082b4a822329fb6670e8016 Mon Sep 17 00:00:00 2001 From: Lodestone Date: 
Tue, 9 Dec 2025 11:33:29 +0700 Subject: [PATCH 014/148] add chroma-radiance-x0 mode (#11197) --- comfy/ldm/chroma_radiance/model.py | 20 ++++++++++++++++++-- comfy/model_detection.py | 2 ++ 2 files changed, 20 insertions(+), 2 deletions(-) diff --git a/comfy/ldm/chroma_radiance/model.py b/comfy/ldm/chroma_radiance/model.py index e643b4414..70d173889 100644 --- a/comfy/ldm/chroma_radiance/model.py +++ b/comfy/ldm/chroma_radiance/model.py @@ -37,7 +37,7 @@ class ChromaRadianceParams(ChromaParams): nerf_final_head_type: str # None means use the same dtype as the model. nerf_embedder_dtype: Optional[torch.dtype] - + use_x0: bool class ChromaRadiance(Chroma): """ @@ -159,6 +159,9 @@ class ChromaRadiance(Chroma): self.skip_dit = [] self.lite = False + if params.use_x0: + self.register_buffer("__x0__", torch.tensor([])) + @property def _nerf_final_layer(self) -> nn.Module: if self.params.nerf_final_head_type == "linear": @@ -276,6 +279,12 @@ class ChromaRadiance(Chroma): params_dict |= overrides return params.__class__(**params_dict) + def _apply_x0_residual(self, predicted, noisy, timesteps): + + # non zero during training to prevent 0 div + eps = 0.0 + return (noisy - predicted) / (timesteps.view(-1,1,1,1) + eps) + def _forward( self, x: Tensor, @@ -316,4 +325,11 @@ class ChromaRadiance(Chroma): transformer_options, attn_mask=kwargs.get("attention_mask", None), ) - return self.forward_nerf(img, img_out, params)[:, :, :h, :w] + + out = self.forward_nerf(img, img_out, params)[:, :, :h, :w] + + # If x0 variant → v-pred, just return this instead + if hasattr(self, "__x0__"): + out = self._apply_x0_residual(out, img, timestep) + return out + diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 74c547427..19e6aa954 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -257,6 +257,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["nerf_tile_size"] = 512 dit_config["nerf_final_head_type"] = "conv" if f"{key_prefix}nerf_final_layer_conv.norm.scale" in state_dict_keys else "linear" dit_config["nerf_embedder_dtype"] = torch.float32 + if "__x0__" in state_dict_keys: # x0 pred + dit_config["use_x0"] = True else: dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys dit_config["yak_mlp"] = '{}double_blocks.0.img_mlp.gate_proj.weight'.format(key_prefix) in state_dict_keys From 9d252f3b70c0e89cbb581e28bb1862593c4e5ceb Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Tue, 9 Dec 2025 15:55:13 +1000 Subject: [PATCH 015/148] ops: delete dead code (#11204) This became dead code in https://github.com/comfyanonymous/ComfyUI/pull/11069 --- comfy/ops.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/comfy/ops.py b/comfy/ops.py index 35237c9f7..6f34d50fc 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -22,7 +22,6 @@ import comfy.model_management from comfy.cli_args import args, PerformanceFeature import comfy.float import comfy.rmsnorm -import contextlib import json def run_every_op(): @@ -94,13 +93,6 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of else: offload_stream = None - if offload_stream is not None: - wf_context = offload_stream - if hasattr(wf_context, "as_context"): - wf_context = wf_context.as_context(offload_stream) - else: - wf_context = contextlib.nullcontext() - non_blocking = comfy.model_management.device_supports_non_blocking(device) weight_has_function = len(s.weight_function) > 0 From 
e2a800e7ef225260c078ce484c75bb40161d9d94 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jukka=20Sepp=C3=A4nen?= <40791699+kijai@users.noreply.github.com> Date: Tue, 9 Dec 2025 23:59:16 +0200 Subject: [PATCH 016/148] Fix for HunyuanVideo1.5 meanflow distil (#11212) --- comfy/ldm/hunyuan_video/model.py | 3 ++- comfy/model_detection.py | 2 ++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/comfy/ldm/hunyuan_video/model.py b/comfy/ldm/hunyuan_video/model.py index 2749c53f5..55ab550f8 100644 --- a/comfy/ldm/hunyuan_video/model.py +++ b/comfy/ldm/hunyuan_video/model.py @@ -43,6 +43,7 @@ class HunyuanVideoParams: meanflow: bool use_cond_type_embedding: bool vision_in_dim: int + meanflow_sum: bool class SelfAttentionRef(nn.Module): @@ -317,7 +318,7 @@ class HunyuanVideo(nn.Module): timesteps_r = transformer_options['sample_sigmas'][w[0] + 1] timesteps_r = timesteps_r.unsqueeze(0).to(device=timesteps.device, dtype=timesteps.dtype) vec_r = self.time_r_in(timestep_embedding(timesteps_r, 256, time_factor=1000.0).to(img.dtype)) - vec = (vec + vec_r) / 2 + vec = (vec + vec_r) if self.params.meanflow_sum else (vec + vec_r) / 2 if ref_latent is not None: ref_latent_ids = self.img_ids(ref_latent) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 19e6aa954..1f5d34bdd 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -180,8 +180,10 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["use_cond_type_embedding"] = False if '{}vision_in.proj.0.weight'.format(key_prefix) in state_dict_keys: dit_config["vision_in_dim"] = state_dict['{}vision_in.proj.0.weight'.format(key_prefix)].shape[0] + dit_config["meanflow_sum"] = True else: dit_config["vision_in_dim"] = None + dit_config["meanflow_sum"] = False return dit_config if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys and ('{}img_in.weight'.format(key_prefix) in state_dict_keys or f"{key_prefix}distilled_guidance_layer.norms.0.scale" in state_dict_keys): #Flux, Chroma or Chroma Radiance (has no img_in.weight) From 791e30ff5037fa5e7aa4e1396099ea8d6bfb020b Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 9 Dec 2025 14:03:21 -0800 Subject: [PATCH 017/148] Fix nan issue when quantizing fp16 tensor. (#11213) --- comfy/quant_ops.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py index 571d3f760..cd96541d7 100644 --- a/comfy/quant_ops.py +++ b/comfy/quant_ops.py @@ -399,7 +399,10 @@ class TensorCoreFP8Layout(QuantizedLayout): orig_dtype = tensor.dtype if isinstance(scale, str) and scale == "recalculate": - scale = torch.amax(tensor.abs()) / torch.finfo(dtype).max + scale = torch.amax(tensor.abs()).to(dtype=torch.float32) / torch.finfo(dtype).max + if tensor.dtype not in [torch.float32, torch.bfloat16]: # Prevent scale from being too small + tensor_info = torch.finfo(tensor.dtype) + scale = (1.0 / torch.clamp((1.0 / scale), min=tensor_info.min, max=tensor_info.max)) if scale is not None: if not isinstance(scale, torch.Tensor): From fc657f471a29d07696ca16b566000e8e555d67d1 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 9 Dec 2025 18:22:09 -0500 Subject: [PATCH 018/148] ComfyUI version v0.4.0 From now on ComfyUI will do version numbers a bit differently, every stable off the master branch will increment the minor version. Anytime a fix needs to be backported onto a stable version the patch version will be incremented. 
Example: We release v0.6.0 off the master branch then a day later a bug is discovered and we decide to backport the fix onto the v0.6.0 stable, this will be done in a separate branch in the main repository and this new stable will be tagged v0.6.1 --- comfyui_version.py | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/comfyui_version.py b/comfyui_version.py index 4b039356e..2f083edaf 100644 --- a/comfyui_version.py +++ b/comfyui_version.py @@ -1,3 +1,3 @@ # This file is automatically generated by the build process when version is # updated in pyproject.toml. -__version__ = "0.3.76" +__version__ = "0.4.0" diff --git a/pyproject.toml b/pyproject.toml index 02b94a0ce..e4d3d616a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ComfyUI" -version = "0.3.76" +version = "0.4.0" readme = "README.md" license = { file = "LICENSE" } requires-python = ">=3.9" From f668c2e3c99df40561b416cf62b0fd9eec96007a Mon Sep 17 00:00:00 2001 From: Benjamin Lu Date: Tue, 9 Dec 2025 19:27:07 -0800 Subject: [PATCH 019/148] bump comfyui-frontend-package to 1.34.8 (#11220) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 11a7ac245..9e9b25328 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -comfyui-frontend-package==1.33.13 +comfyui-frontend-package==1.34.8 comfyui-workflow-templates==0.7.54 comfyui-embedded-docs==0.3.1 torch From 36357bbcc3c515e37a742457a2b2ab4b7ccc17a8 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Wed, 10 Dec 2025 21:55:09 +0200 Subject: [PATCH 020/148] process the NodeV1 dict results correctly (#11237) --- comfy_api/latest/_io.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy_api/latest/_io.py b/comfy_api/latest/_io.py index 313a5af20..79217c813 100644 --- a/comfy_api/latest/_io.py +++ b/comfy_api/latest/_io.py @@ -1815,7 +1815,7 @@ class NodeOutput(_NodeOutputInternal): ui = data["ui"] if "expand" in data: expand = data["expand"] - return cls(args=args, ui=ui, expand=expand) + return cls(*args, ui=ui, expand=expand) def __getitem__(self, index) -> Any: return self.args[index] From 17c92a9f2843d7b9b727531066be2378b350a6ae Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 10 Dec 2025 16:59:48 -0800 Subject: [PATCH 021/148] Tweak Z Image memory estimation. 
(#11254) --- comfy/supported_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 383c82c3e..dd0f09f32 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -1026,7 +1026,7 @@ class ZImage(Lumina2): "shift": 3.0, } - memory_usage_factor = 1.7 + memory_usage_factor = 2.0 supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32] From 57ddb7fd13d817e7259c2c992a852832b6b0f07a Mon Sep 17 00:00:00 2001 From: Johnpaul Chiwetelu <49923152+Myestery@users.noreply.github.com> Date: Thu, 11 Dec 2025 03:49:49 +0100 Subject: [PATCH 022/148] Fix: filter hidden files from /internal/files endpoint (#11191) --- api_server/routes/internal/internal_routes.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/api_server/routes/internal/internal_routes.py b/api_server/routes/internal/internal_routes.py index 613b0f7c7..b224306da 100644 --- a/api_server/routes/internal/internal_routes.py +++ b/api_server/routes/internal/internal_routes.py @@ -58,8 +58,13 @@ class InternalRoutes: return web.json_response({"error": "Invalid directory type"}, status=400) directory = get_directory_by_type(directory_type) + + def is_visible_file(entry: os.DirEntry) -> bool: + """Filter out hidden files (e.g., .DS_Store on macOS).""" + return entry.is_file() and not entry.name.startswith('.') + sorted_files = sorted( - (entry for entry in os.scandir(directory) if entry.is_file()), + (entry for entry in os.scandir(directory) if is_visible_file(entry)), key=lambda entry: -entry.stat().st_mtime ) return web.json_response([entry.name for entry in sorted_files], status=200) From e711aaf1a75120195c56ebd1f1ce829c6b7b84db Mon Sep 17 00:00:00 2001 From: Farshore <168402472+jiangchengchengark@users.noreply.github.com> Date: Thu, 11 Dec 2025 11:02:26 +0800 Subject: [PATCH 023/148] =?UTF-8?q?Lower=20VAE=20loading=20requirements?= =?UTF-8?q?=EF=BC=9ACreate=20a=20new=20branch=20for=20GPU=20memory=20calcu?= =?UTF-8?q?lations=20in=20qwen-image=20vae=20(#11199)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- comfy/sd.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/comfy/sd.py b/comfy/sd.py index a16f2d14f..1cad98aef 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -549,8 +549,10 @@ class VAE: ddconfig = {"dim": dim, "z_dim": self.latent_channels, "dim_mult": [1, 2, 4, 4], "num_res_blocks": 2, "attn_scales": [], "temperal_downsample": [False, True, True], "dropout": 0.0} self.first_stage_model = comfy.ldm.wan.vae.WanVAE(**ddconfig) self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32] - self.memory_used_encode = lambda shape, dtype: 6000 * shape[3] * shape[4] * model_management.dtype_size(dtype) - self.memory_used_decode = lambda shape, dtype: 7000 * shape[3] * shape[4] * (8 * 8) * model_management.dtype_size(dtype) + self.memory_used_encode = lambda shape, dtype: (1500 if shape[2]<=4 else 6000) * shape[3] * shape[4] * model_management.dtype_size(dtype) + self.memory_used_decode = lambda shape, dtype: (2200 if shape[2]<=4 else 7000) * shape[3] * shape[4] * (8*8) * model_management.dtype_size(dtype) + + # Hunyuan 3d v2 2.0 & 2.1 elif "geo_decoder.cross_attn_decoder.ln_1.bias" in sd: From 93948e3fc598c14082f744fe82fae056b64ff481 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Thu, 11 Dec 2025 08:11:12 +0200 Subject: [PATCH 024/148] feat(api-nodes): enable Kling Omni O1 node 
(#11229) --- comfy_api_nodes/nodes_kling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy_api_nodes/nodes_kling.py b/comfy_api_nodes/nodes_kling.py index 6c840dc47..a2cc87d84 100644 --- a/comfy_api_nodes/nodes_kling.py +++ b/comfy_api_nodes/nodes_kling.py @@ -2056,7 +2056,7 @@ class KlingExtension(ComfyExtension): OmniProImageToVideoNode, OmniProVideoToVideoNode, OmniProEditVideoNode, - # OmniProImageNode, # need support from backend + OmniProImageNode, ] From f8321eb57b29a4b34cecd27d5d6365adf5e6e601 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 10 Dec 2025 22:30:31 -0800 Subject: [PATCH 025/148] Adjust memory usage factor. (#11257) --- comfy/supported_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/supported_models.py b/comfy/supported_models.py index dd0f09f32..ef8c75c09 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -541,7 +541,7 @@ class SD3(supported_models_base.BASE): unet_extra_config = {} latent_format = latent_formats.SD3 - memory_usage_factor = 1.2 + memory_usage_factor = 1.6 text_encoder_key_prefix = ["text_encoders."] From fdebe182966d1dd9bee3138264937137bd2302d8 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Thu, 11 Dec 2025 14:09:35 -0800 Subject: [PATCH 026/148] Fix regular chroma radiance (#11276) --- comfy/model_detection.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 1f5d34bdd..94b54b7c2 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -261,6 +261,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["nerf_embedder_dtype"] = torch.float32 if "__x0__" in state_dict_keys: # x0 pred dit_config["use_x0"] = True + else: + dit_config["use_x0"] = False else: dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys dit_config["yak_mlp"] = '{}double_blocks.0.img_mlp.gate_proj.weight'.format(key_prefix) in state_dict_keys From ae65433a602470eea271df47af0eb871d146a002 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Thu, 11 Dec 2025 14:15:00 -0800 Subject: [PATCH 027/148] This only works on radiance. 
(#11277) --- comfy/model_detection.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 94b54b7c2..dd6a703f6 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -259,10 +259,10 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["nerf_tile_size"] = 512 dit_config["nerf_final_head_type"] = "conv" if f"{key_prefix}nerf_final_layer_conv.norm.scale" in state_dict_keys else "linear" dit_config["nerf_embedder_dtype"] = torch.float32 - if "__x0__" in state_dict_keys: # x0 pred - dit_config["use_x0"] = True - else: - dit_config["use_x0"] = False + if "__x0__" in state_dict_keys: # x0 pred + dit_config["use_x0"] = True + else: + dit_config["use_x0"] = False else: dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys dit_config["yak_mlp"] = '{}double_blocks.0.img_mlp.gate_proj.weight'.format(key_prefix) in state_dict_keys From eeb020b9b77e1f3c0c2806bc1e38c7ba9576439e Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Thu, 11 Dec 2025 14:33:09 -0800 Subject: [PATCH 028/148] Better chroma radiance and other models vram estimation. (#11278) --- comfy/supported_models.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/comfy/supported_models.py b/comfy/supported_models.py index ef8c75c09..834dfcffc 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -965,7 +965,7 @@ class CosmosT2IPredict2(supported_models_base.BASE): def __init__(self, unet_config): super().__init__(unet_config) - self.memory_usage_factor = (unet_config.get("model_channels", 2048) / 2048) * 0.9 + self.memory_usage_factor = (unet_config.get("model_channels", 2048) / 2048) * 0.95 def get_model(self, state_dict, prefix="", device=None): out = model_base.CosmosPredict2(self, device=device) @@ -1289,7 +1289,7 @@ class ChromaRadiance(Chroma): latent_format = comfy.latent_formats.ChromaRadiance # Pixel-space model, no spatial compression for model input. - memory_usage_factor = 0.038 + memory_usage_factor = 0.044 def get_model(self, state_dict, prefix="", device=None): return model_base.ChromaRadiance(self, device=device) @@ -1332,7 +1332,7 @@ class Omnigen2(supported_models_base.BASE): "shift": 2.6, } - memory_usage_factor = 1.65 #TODO + memory_usage_factor = 1.95 #TODO unet_extra_config = {} latent_format = latent_formats.Flux @@ -1397,7 +1397,7 @@ class HunyuanImage21(HunyuanVideo): latent_format = latent_formats.HunyuanImage21 - memory_usage_factor = 7.7 + memory_usage_factor = 8.7 supported_inference_dtypes = [torch.bfloat16, torch.float32] @@ -1488,7 +1488,7 @@ class Kandinsky5(supported_models_base.BASE): unet_extra_config = {} latent_format = latent_formats.HunyuanVideo - memory_usage_factor = 1.1 #TODO + memory_usage_factor = 1.25 #TODO supported_inference_dtypes = [torch.bfloat16, torch.float32] @@ -1517,7 +1517,7 @@ class Kandinsky5Image(Kandinsky5): } latent_format = latent_formats.Flux - memory_usage_factor = 1.1 #TODO + memory_usage_factor = 1.25 #TODO def get_model(self, state_dict, prefix="", device=None): out = model_base.Kandinsky5Image(self, device=device) From 338d9ae3bbf24a9a06996cdf1c2f228acc65fd96 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Thu, 11 Dec 2025 15:56:33 -0800 Subject: [PATCH 029/148] Make portable updater work with repos in unmerged state. 
(#11281) --- .ci/update_windows/update.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/.ci/update_windows/update.py b/.ci/update_windows/update.py index 59ece5130..fe646a6ed 100755 --- a/.ci/update_windows/update.py +++ b/.ci/update_windows/update.py @@ -53,6 +53,16 @@ try: repo.stash(ident) except KeyError: print("nothing to stash") # noqa: T201 +except: + print("Could not stash, cleaning index and trying again.") # noqa: T201 + repo.state_cleanup() + repo.index.read_tree(repo.head.peel().tree) + repo.index.write() + try: + repo.stash(ident) + except KeyError: + print("nothing to stash.") # noqa: T201 + backup_branch_name = 'backup_branch_{}'.format(datetime.today().strftime('%Y-%m-%d_%H_%M_%S')) print("creating backup branch: {}".format(backup_branch_name)) # noqa: T201 try: From 982876d59a659adb085be5e236aacc4f2c54c19c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jukka=20Sepp=C3=A4nen?= <40791699+kijai@users.noreply.github.com> Date: Fri, 12 Dec 2025 05:29:34 +0200 Subject: [PATCH 030/148] WanMove support (#11247) --- comfy_api/latest/_io.py | 8 + comfy_extras/nodes_wanmove.py | 535 ++++++++++++++++++++++++++++++++++ nodes.py | 1 + 3 files changed, 544 insertions(+) create mode 100644 comfy_extras/nodes_wanmove.py diff --git a/comfy_api/latest/_io.py b/comfy_api/latest/_io.py index 79217c813..2b634d172 100644 --- a/comfy_api/latest/_io.py +++ b/comfy_api/latest/_io.py @@ -774,6 +774,13 @@ class AudioEncoder(ComfyTypeIO): class AudioEncoderOutput(ComfyTypeIO): Type = Any +@comfytype(io_type="TRACKS") +class Tracks(ComfyTypeIO): + class TrackDict(TypedDict): + track_path: torch.Tensor + track_visibility: torch.Tensor + Type = TrackDict + @comfytype(io_type="COMFY_MULTITYPED_V3") class MultiType: Type = Any @@ -1894,6 +1901,7 @@ __all__ = [ "SEGS", "AnyType", "MultiType", + "Tracks", # Dynamic Types "MatchType", # "DynamicCombo", diff --git a/comfy_extras/nodes_wanmove.py b/comfy_extras/nodes_wanmove.py new file mode 100644 index 000000000..5f39afa46 --- /dev/null +++ b/comfy_extras/nodes_wanmove.py @@ -0,0 +1,535 @@ +import nodes +import node_helpers +import torch +import torchvision.transforms.functional as TF +import comfy.model_management +import comfy.utils +import numpy as np +from typing_extensions import override +from comfy_api.latest import ComfyExtension, io +from comfy_extras.nodes_wan import parse_json_tracks + +# https://github.com/ali-vilab/Wan-Move/blob/main/wan/modules/trajectory.py +from PIL import Image, ImageDraw + +SKIP_ZERO = False + +def get_pos_emb( + pos_k: torch.Tensor, # A 1D tensor containing positions for which to generate embeddings. + pos_emb_dim: int, + theta_func: callable = lambda i, d: torch.pow(10000, torch.mul(2, torch.div(i.to(torch.float32), d))), #Function to compute thetas based on position and embedding dimensions. + device: torch.device = torch.device("cpu"), + dtype: torch.dtype = torch.float32, +) -> torch.Tensor: # The position embeddings (batch_size, pos_emb_dim) + + assert pos_emb_dim % 2 == 0, "The dimension of position embeddings must be even." 
+ pos_k = pos_k.to(device, dtype) + if SKIP_ZERO: + pos_k = pos_k + 1 + batch_size = pos_k.size(0) + + denominator = torch.arange(0, pos_emb_dim // 2, device=device, dtype=dtype) + # Expand denominator to match the shape needed for broadcasting + denominator_expanded = denominator.view(1, -1).expand(batch_size, -1) + + thetas = theta_func(denominator_expanded, pos_emb_dim) + + # Ensure pos_k is in the correct shape for broadcasting + pos_k_expanded = pos_k.view(-1, 1).to(dtype) + sin_thetas = torch.sin(torch.div(pos_k_expanded, thetas)) + cos_thetas = torch.cos(torch.div(pos_k_expanded, thetas)) + + # Concatenate sine and cosine embeddings along the last dimension + pos_emb = torch.cat([sin_thetas, cos_thetas], dim=-1) + + return pos_emb + +def create_pos_embeddings( + pred_tracks: torch.Tensor, # the predicted tracks, [T, N, 2] + pred_visibility: torch.Tensor, # the predicted visibility [T, N] + downsample_ratios: list[int], # the ratios for downsampling time, height, and width + height: int, # the height of the feature map + width: int, # the width of the feature map + track_num: int = -1, # the number of tracks to use + t_down_strategy: str = "sample", # the strategy for downsampling time dimension +): + assert t_down_strategy in ["sample", "average"], "Invalid strategy for downsampling time dimension." + + t, n, _ = pred_tracks.shape + t_down, h_down, w_down = downsample_ratios + track_pos = - torch.ones(n, (t-1) // t_down + 1, 2, dtype=torch.long) + + if track_num == -1: + track_num = n + + tracks_idx = torch.randperm(n)[:track_num] + tracks = pred_tracks[:, tracks_idx] + visibility = pred_visibility[:, tracks_idx] + + for t_idx in range(0, t, t_down): + if t_down_strategy == "sample" or t_idx == 0: + cur_tracks = tracks[t_idx] # [N, 2] + cur_visibility = visibility[t_idx] # [N] + else: + cur_tracks = tracks[t_idx:t_idx+t_down].mean(dim=0) + cur_visibility = torch.any(visibility[t_idx:t_idx+t_down], dim=0) + + for i in range(track_num): + if not cur_visibility[i] or cur_tracks[i][0] < 0 or cur_tracks[i][1] < 0 or cur_tracks[i][0] >= width or cur_tracks[i][1] >= height: + continue + x, y = cur_tracks[i] + x, y = int(x // w_down), int(y // h_down) + track_pos[i, t_idx // t_down, 0], track_pos[i, t_idx // t_down, 1] = y, x + + return track_pos # the position embeddings, [N, T', 2], 2 = height, width + +def replace_feature( + vae_feature: torch.Tensor, # [B, C', T', H', W'] + track_pos: torch.Tensor, # [B, N, T', 2] + strength: float = 1.0 +) -> torch.Tensor: + b, _, t, h, w = vae_feature.shape + assert b == track_pos.shape[0], "Batch size mismatch." 
+ n = track_pos.shape[1] + + # Shuffle the trajectory order + track_pos = track_pos[:, torch.randperm(n), :, :] + + # Extract coordinates at time steps ≥ 1 and generate a valid mask + current_pos = track_pos[:, :, 1:, :] # [B, N, T-1, 2] + mask = (current_pos[..., 0] >= 0) & (current_pos[..., 1] >= 0) # [B, N, T-1] + + # Get all valid indices + valid_indices = mask.nonzero(as_tuple=False) # [num_valid, 3] + num_valid = valid_indices.shape[0] + + if num_valid == 0: + return vae_feature + + # Decompose valid indices into each dimension + batch_idx = valid_indices[:, 0] + track_idx = valid_indices[:, 1] + t_rel = valid_indices[:, 2] + t_target = t_rel + 1 # Convert to original time step indices + + # Extract target position coordinates + h_target = current_pos[batch_idx, track_idx, t_rel, 0].long() # Ensure integer indices + w_target = current_pos[batch_idx, track_idx, t_rel, 1].long() + + # Extract source position coordinates (t=0) + h_source = track_pos[batch_idx, track_idx, 0, 0].long() + w_source = track_pos[batch_idx, track_idx, 0, 1].long() + + # Get source features and assign to target positions + src_features = vae_feature[batch_idx, :, 0, h_source, w_source] + dst_features = vae_feature[batch_idx, :, t_target, h_target, w_target] + + vae_feature[batch_idx, :, t_target, h_target, w_target] = dst_features + (src_features - dst_features) * strength + + + return vae_feature + +# Visualize functions + +def _draw_gradient_polyline_on_overlay(overlay, line_width, points, start_color, opacity=1.0): + draw = ImageDraw.Draw(overlay, 'RGBA') + points = points[::-1] + + # Compute total length + total_length = 0 + segment_lengths = [] + for i in range(len(points) - 1): + dx = points[i + 1][0] - points[i][0] + dy = points[i + 1][1] - points[i][1] + length = (dx * dx + dy * dy) ** 0.5 + segment_lengths.append(length) + total_length += length + + if total_length == 0: + return + + accumulated_length = 0 + + # Draw the gradient polyline + for idx, (start_point, end_point) in enumerate(zip(points[:-1], points[1:])): + segment_length = segment_lengths[idx] + steps = max(int(segment_length), 1) + + for i in range(steps): + current_length = accumulated_length + (i / steps) * segment_length + ratio = current_length / total_length + + alpha = int(255 * (1 - ratio) * opacity) + color = (*start_color, alpha) + + x = int(start_point[0] + (end_point[0] - start_point[0]) * i / steps) + y = int(start_point[1] + (end_point[1] - start_point[1]) * i / steps) + + dynamic_line_width = max(int(line_width * (1 - ratio)), 1) + draw.line([(x, y), (x + 1, y)], fill=color, width=dynamic_line_width) + + accumulated_length += segment_length + + +def add_weighted(rgb, track): + rgb = np.array(rgb) # [H, W, C] "RGB" + track = np.array(track) # [H, W, C] "RGBA" + + alpha = track[:, :, 3] / 255.0 + alpha = np.stack([alpha] * 3, axis=-1) + blend_img = track[:, :, :3] * alpha + rgb * (1 - alpha) + + return Image.fromarray(blend_img.astype(np.uint8)) + +def draw_tracks_on_video(video, tracks, visibility=None, track_frame=24, circle_size=12, opacity=0.5, line_width=16): + color_map = [(102, 153, 255), (0, 255, 255), (255, 255, 0), (255, 102, 204), (0, 255, 0)] + + video = video.byte().cpu().numpy() # (81, 480, 832, 3) + tracks = tracks[0].long().detach().cpu().numpy() + if visibility is not None: + visibility = visibility[0].detach().cpu().numpy() + + num_frames, height, width = video.shape[:3] + num_tracks = tracks.shape[1] + alpha_opacity = int(255 * opacity) + + output_frames = [] + for t in range(num_frames): + frame_rgb = 
video[t].astype(np.float32) + + # Create a single RGBA overlay for all tracks in this frame + overlay = Image.new("RGBA", (width, height), (0, 0, 0, 0)) + draw_overlay = ImageDraw.Draw(overlay) + + polyline_data = [] + + # Draw all circles on a single overlay + for n in range(num_tracks): + if visibility is not None and visibility[t, n] == 0: + continue + + track_coord = tracks[t, n] + color = color_map[n % len(color_map)] + circle_color = color + (alpha_opacity,) + + draw_overlay.ellipse((track_coord[0] - circle_size, track_coord[1] - circle_size, track_coord[0] + circle_size, track_coord[1] + circle_size), + fill=circle_color + ) + + # Store polyline data for batch processing + tracks_coord = tracks[max(t - track_frame, 0):t + 1, n] + if len(tracks_coord) > 1: + polyline_data.append((tracks_coord, color)) + + # Blend circles overlay once + overlay_np = np.array(overlay) + alpha = overlay_np[:, :, 3:4] / 255.0 + frame_rgb = overlay_np[:, :, :3] * alpha + frame_rgb * (1 - alpha) + + # Draw all polylines on a single overlay + if polyline_data: + polyline_overlay = Image.new("RGBA", (width, height), (0, 0, 0, 0)) + for tracks_coord, color in polyline_data: + _draw_gradient_polyline_on_overlay(polyline_overlay, line_width, tracks_coord, color, opacity) + + # Blend polylines overlay once + polyline_np = np.array(polyline_overlay) + alpha = polyline_np[:, :, 3:4] / 255.0 + frame_rgb = polyline_np[:, :, :3] * alpha + frame_rgb * (1 - alpha) + + output_frames.append(Image.fromarray(frame_rgb.astype(np.uint8))) + + return output_frames + + +class WanMoveVisualizeTracks(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="WanMoveVisualizeTracks", + category="conditioning/video_models", + inputs=[ + io.Image.Input("images"), + io.Tracks.Input("tracks", optional=True), + io.Int.Input("line_resolution", default=24, min=1, max=1024), + io.Int.Input("circle_size", default=12, min=1, max=128), + io.Float.Input("opacity", default=0.75, min=0.0, max=1.0, step=0.01), + io.Int.Input("line_width", default=16, min=1, max=128), + ], + outputs=[ + io.Image.Output(), + ], + ) + + @classmethod + def execute(cls, images, line_resolution, circle_size, opacity, line_width, tracks=None) -> io.NodeOutput: + if tracks is None: + return io.NodeOutput(images) + + track_path = tracks["track_path"].unsqueeze(0) + track_visibility = tracks["track_visibility"].unsqueeze(0) + images_in = images * 255.0 + if images_in.shape[0] != track_path.shape[1]: + repeat_count = track_path.shape[1] // images.shape[0] + images_in = images_in.repeat(repeat_count, 1, 1, 1) + track_video = draw_tracks_on_video(images_in, track_path, track_visibility, track_frame=line_resolution, circle_size=circle_size, opacity=opacity, line_width=line_width) + track_video = torch.stack([TF.to_tensor(frame) for frame in track_video], dim=0).movedim(1, -1).float() + + return io.NodeOutput(track_video.to(comfy.model_management.intermediate_device())) + + +class WanMoveTracksFromCoords(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="WanMoveTracksFromCoords", + category="conditioning/video_models", + inputs=[ + io.String.Input("track_coords", force_input=True, default="[]", optional=True), + io.Mask.Input("track_mask", optional=True), + ], + outputs=[ + io.Tracks.Output(), + io.Int.Output(display_name="track_length"), + ], + ) + + @classmethod + def execute(cls, track_coords, track_mask=None) -> io.NodeOutput: + device=comfy.model_management.intermediate_device() + + tracks_data = 
parse_json_tracks(track_coords) + track_length = len(tracks_data[0]) + + track_list = [ + [[track[frame]['x'], track[frame]['y']] for track in tracks_data] + for frame in range(len(tracks_data[0])) + ] + tracks = torch.tensor(track_list, dtype=torch.float32, device=device) # [frames, num_tracks, 2] + + num_tracks = tracks.shape[-2] + if track_mask is None: + track_visibility = torch.ones((track_length, num_tracks), dtype=torch.bool, device=device) + else: + track_visibility = (track_mask > 0).any(dim=(1, 2)).unsqueeze(-1) + + out_track_info = {} + out_track_info["track_path"] = tracks + out_track_info["track_visibility"] = track_visibility + return io.NodeOutput(out_track_info, track_length) + + +class GenerateTracks(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="GenerateTracks", + category="conditioning/video_models", + inputs=[ + io.Int.Input("width", default=832, min=16, max=4096, step=16), + io.Int.Input("height", default=480, min=16, max=4096, step=16), + io.Float.Input("start_x", default=0.0, min=0.0, max=1.0, step=0.01, tooltip="Normalized X coordinate (0-1) for start position."), + io.Float.Input("start_y", default=0.0, min=0.0, max=1.0, step=0.01, tooltip="Normalized Y coordinate (0-1) for start position."), + io.Float.Input("end_x", default=1.0, min=0.0, max=1.0, step=0.01, tooltip="Normalized X coordinate (0-1) for end position."), + io.Float.Input("end_y", default=1.0, min=0.0, max=1.0, step=0.01, tooltip="Normalized Y coordinate (0-1) for end position."), + io.Int.Input("num_frames", default=81, min=1, max=1024), + io.Int.Input("num_tracks", default=5, min=1, max=100), + io.Float.Input("track_spread", default=0.025, min=0.0, max=1.0, step=0.001, tooltip="Normalized distance between tracks. Tracks are spread perpendicular to the motion direction."), + io.Boolean.Input("bezier", default=False, tooltip="Enable Bezier curve path using the mid point as control point."), + io.Float.Input("mid_x", default=0.5, min=0.0, max=1.0, step=0.01, tooltip="Normalized X control point for Bezier curve. Only used when 'bezier' is enabled."), + io.Float.Input("mid_y", default=0.5, min=0.0, max=1.0, step=0.01, tooltip="Normalized Y control point for Bezier curve. 
Only used when 'bezier' is enabled."), + io.Combo.Input( + "interpolation", + options=["linear", "ease_in", "ease_out", "ease_in_out", "constant"], + tooltip="Controls the timing/speed of movement along the path.", + ), + io.Mask.Input("track_mask", optional=True, tooltip="Optional mask to indicate visible frames."), + ], + outputs=[ + io.Tracks.Output(), + io.Int.Output(display_name="track_length"), + ], + ) + + @classmethod + def execute(cls, width, height, start_x, start_y, mid_x, mid_y, end_x, end_y, num_frames, num_tracks, + track_spread, bezier=False, interpolation="linear", track_mask=None) -> io.NodeOutput: + device = comfy.model_management.intermediate_device() + track_length = num_frames + + # normalized coordinates to pixel coordinates + start_x_px = start_x * width + start_y_px = start_y * height + mid_x_px = mid_x * width + mid_y_px = mid_y * height + end_x_px = end_x * width + end_y_px = end_y * height + + track_spread_px = track_spread * (width + height) / 2 # Use average of width/height for spread to keep it proportional + + t = torch.linspace(0, 1, num_frames, device=device) + if interpolation == "constant": # All points stay at start position + interp_values = torch.zeros_like(t) + elif interpolation == "linear": + interp_values = t + elif interpolation == "ease_in": + interp_values = t ** 2 + elif interpolation == "ease_out": + interp_values = 1 - (1 - t) ** 2 + elif interpolation == "ease_in_out": + interp_values = t * t * (3 - 2 * t) + + if bezier: # apply interpolation to t for timing control along the bezier path + t_interp = interp_values + one_minus_t = 1 - t_interp + x_positions = one_minus_t ** 2 * start_x_px + 2 * one_minus_t * t_interp * mid_x_px + t_interp ** 2 * end_x_px + y_positions = one_minus_t ** 2 * start_y_px + 2 * one_minus_t * t_interp * mid_y_px + t_interp ** 2 * end_y_px + tangent_x = 2 * one_minus_t * (mid_x_px - start_x_px) + 2 * t_interp * (end_x_px - mid_x_px) + tangent_y = 2 * one_minus_t * (mid_y_px - start_y_px) + 2 * t_interp * (end_y_px - mid_y_px) + else: # calculate base x and y positions for each frame (center track) + x_positions = start_x_px + (end_x_px - start_x_px) * interp_values + y_positions = start_y_px + (end_y_px - start_y_px) * interp_values + # For non-bezier, tangent is constant (direction from start to end) + tangent_x = torch.full_like(t, end_x_px - start_x_px) + tangent_y = torch.full_like(t, end_y_px - start_y_px) + + track_list = [] + for frame_idx in range(num_frames): + # Calculate perpendicular direction at this frame + tx = tangent_x[frame_idx].item() + ty = tangent_y[frame_idx].item() + length = (tx ** 2 + ty ** 2) ** 0.5 + + if length > 0: # Perpendicular unit vector (rotate 90 degrees) + perp_x = -ty / length + perp_y = tx / length + else: # If tangent is zero, spread horizontally + perp_x = 1.0 + perp_y = 0.0 + + frame_tracks = [] + for track_idx in range(num_tracks): # center tracks around the main path offset ranges from -(num_tracks-1)/2 to +(num_tracks-1)/2 + offset = (track_idx - (num_tracks - 1) / 2) * track_spread_px + track_x = x_positions[frame_idx].item() + perp_x * offset + track_y = y_positions[frame_idx].item() + perp_y * offset + frame_tracks.append([track_x, track_y]) + track_list.append(frame_tracks) + + tracks = torch.tensor(track_list, dtype=torch.float32, device=device) # [frames, num_tracks, 2] + + if track_mask is None: + track_visibility = torch.ones((track_length, num_tracks), dtype=torch.bool, device=device) + else: + track_visibility = (track_mask > 0).any(dim=(1, 2)).unsqueeze(-1) + + 
out_track_info = {} + out_track_info["track_path"] = tracks + out_track_info["track_visibility"] = track_visibility + return io.NodeOutput(out_track_info, track_length) + + +class WanMoveConcatTrack(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="WanMoveConcatTrack", + category="conditioning/video_models", + inputs=[ + io.Tracks.Input("tracks_1"), + io.Tracks.Input("tracks_2", optional=True), + ], + outputs=[ + io.Tracks.Output(), + ], + ) + + @classmethod + def execute(cls, tracks_1=None, tracks_2=None) -> io.NodeOutput: + if tracks_2 is None: + return io.NodeOutput(tracks_1) + + tracks_out = torch.cat([tracks_1["track_path"], tracks_2["track_path"]], dim=1) # Concatenate along the track dimension + mask_out = torch.cat([tracks_1["track_visibility"], tracks_2["track_visibility"]], dim=-1) + + out_track_info = {} + out_track_info["track_path"] = tracks_out + out_track_info["track_visibility"] = mask_out + return io.NodeOutput(out_track_info) + + +class WanMoveTrackToVideo(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="WanMoveTrackToVideo", + category="conditioning/video_models", + inputs=[ + io.Conditioning.Input("positive"), + io.Conditioning.Input("negative"), + io.Vae.Input("vae"), + io.Tracks.Input("tracks", optional=True), + io.Float.Input("strength", default=1.0, min=0.0, max=100.0, step=0.01, tooltip="Strength of the track conditioning."), + io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4), + io.Int.Input("batch_size", default=1, min=1, max=4096), + io.Image.Input("start_image"), + io.ClipVisionOutput.Input("clip_vision_output", optional=True), + ], + outputs=[ + io.Conditioning.Output(display_name="positive"), + io.Conditioning.Output(display_name="negative"), + io.Latent.Output(display_name="latent"), + ], + ) + + @classmethod + def execute(cls, positive, negative, vae, width, height, length, batch_size, strength, tracks=None, start_image=None, clip_vision_output=None) -> io.NodeOutput: + device=comfy.model_management.intermediate_device() + latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=device) + if start_image is not None: + start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) + image = torch.ones((length, height, width, start_image.shape[-1]), device=start_image.device, dtype=start_image.dtype) * 0.5 + image[:start_image.shape[0]] = start_image + + concat_latent_image = vae.encode(image[:, :, :, :3]) + mask = torch.ones((1, 1, latent.shape[2], concat_latent_image.shape[-2], concat_latent_image.shape[-1]), device=start_image.device, dtype=start_image.dtype) + mask[:, :, :((start_image.shape[0] - 1) // 4) + 1] = 0.0 + + if tracks is not None and strength > 0.0: + tracks_path = tracks["track_path"][:length] # [T, N, 2] + num_tracks = tracks_path.shape[-2] + + track_visibility = tracks.get("track_visibility", torch.ones((length, num_tracks), dtype=torch.bool, device=device)) + + track_pos = create_pos_embeddings(tracks_path, track_visibility, [4, 8, 8], height, width, track_num=num_tracks) + track_pos = comfy.utils.resize_to_batch_size(track_pos.unsqueeze(0), batch_size) + concat_latent_image_pos = replace_feature(concat_latent_image, track_pos, strength) + else: + concat_latent_image_pos = 
concat_latent_image + + positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent_image_pos, "concat_mask": mask}) + negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent_image, "concat_mask": mask}) + + if clip_vision_output is not None: + positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output}) + negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output}) + + out_latent = {} + out_latent["samples"] = latent + return io.NodeOutput(positive, negative, out_latent) + + +class WanMoveExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + WanMoveTrackToVideo, + WanMoveTracksFromCoords, + WanMoveConcatTrack, + WanMoveVisualizeTracks, + GenerateTracks, + ] + +async def comfy_entrypoint() -> WanMoveExtension: + return WanMoveExtension() diff --git a/nodes.py b/nodes.py index 8d28a725d..8678f510a 100644 --- a/nodes.py +++ b/nodes.py @@ -2358,6 +2358,7 @@ async def init_builtin_extra_nodes(): "nodes_logic.py", "nodes_nop.py", "nodes_kandinsky5.py", + "nodes_wanmove.py", ] import_failed = [] From 5495589db38409353a85b06df7d10f8de2f9c78d Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Thu, 11 Dec 2025 20:32:27 -0800 Subject: [PATCH 031/148] Respect the dtype the op was initialized in for non quant mixed op. (#11282) --- comfy/ops.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/comfy/ops.py b/comfy/ops.py index 6f34d50fc..6ae6e791a 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -497,8 +497,10 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec ) -> None: super().__init__() - self.factory_kwargs = {"device": device, "dtype": MixedPrecisionOps._compute_dtype} - # self.factory_kwargs = {"device": device, "dtype": dtype} + if dtype is None: + dtype = MixedPrecisionOps._compute_dtype + + self.factory_kwargs = {"device": device, "dtype": dtype} self.in_features = in_features self.out_features = out_features @@ -530,7 +532,10 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec layer_conf = json.loads(layer_conf.numpy().tobytes()) if layer_conf is None: - self.weight = torch.nn.Parameter(weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), requires_grad=False) + dtype = self.factory_kwargs["dtype"] + self.weight = torch.nn.Parameter(weight.to(device=device, dtype=dtype), requires_grad=False) + if dtype != MixedPrecisionOps._compute_dtype: + self.comfy_cast_weights = True else: self.quant_format = layer_conf.get("format", None) if not self._full_precision_mm: From 908fd7d7496f6de88722263e1e00fcd3d22e584f Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Fri, 12 Dec 2025 10:18:31 +0200 Subject: [PATCH 032/148] feat(api-nodes): new TextToVideoWithAudio and ImageToVideoWithAudio nodes (#11267) --- comfy_api_nodes/apis/kling_api.py | 28 ++++- comfy_api_nodes/nodes_kling.py | 169 ++++++++++++++++++++++++++---- 2 files changed, 174 insertions(+), 23 deletions(-) diff --git a/comfy_api_nodes/apis/kling_api.py b/comfy_api_nodes/apis/kling_api.py index d8949f8ac..80a758466 100644 --- a/comfy_api_nodes/apis/kling_api.py +++ b/comfy_api_nodes/apis/kling_api.py @@ -51,25 +51,25 @@ class TaskStatusImageResult(BaseModel): url: str = Field(..., description="URL for generated image") -class 
OmniTaskStatusResults(BaseModel): +class TaskStatusResults(BaseModel): videos: list[TaskStatusVideoResult] | None = Field(None) images: list[TaskStatusImageResult] | None = Field(None) -class OmniTaskStatusResponseData(BaseModel): +class TaskStatusResponseData(BaseModel): created_at: int | None = Field(None, description="Task creation time") updated_at: int | None = Field(None, description="Task update time") task_status: str | None = None task_status_msg: str | None = Field(None, description="Additional failure reason. Only for polling endpoint.") task_id: str | None = Field(None, description="Task ID") - task_result: OmniTaskStatusResults | None = Field(None) + task_result: TaskStatusResults | None = Field(None) -class OmniTaskStatusResponse(BaseModel): +class TaskStatusResponse(BaseModel): code: int | None = Field(None, description="Error code") message: str | None = Field(None, description="Error message") request_id: str | None = Field(None, description="Request ID") - data: OmniTaskStatusResponseData | None = Field(None) + data: TaskStatusResponseData | None = Field(None) class OmniImageParamImage(BaseModel): @@ -84,3 +84,21 @@ class OmniProImageRequest(BaseModel): mode: str = Field("pro") n: int | None = Field(1, le=9) image_list: list[OmniImageParamImage] | None = Field(..., max_length=10) + + +class TextToVideoWithAudioRequest(BaseModel): + model_name: str = Field(..., description="kling-v2-6") + aspect_ratio: str = Field(..., description="'16:9', '9:16' or '1:1'") + duration: str = Field(..., description="'5' or '10'") + prompt: str = Field(...) + mode: str = Field("pro") + sound: str = Field(..., description="'on' or 'off'") + + +class ImageToVideoWithAudioRequest(BaseModel): + model_name: str = Field(..., description="kling-v2-6") + image: str = Field(...) + duration: str = Field(..., description="'5' or '10'") + prompt: str = Field(...) + mode: str = Field("pro") + sound: str = Field(..., description="'on' or 'off'") diff --git a/comfy_api_nodes/nodes_kling.py b/comfy_api_nodes/nodes_kling.py index a2cc87d84..e545fe490 100644 --- a/comfy_api_nodes/nodes_kling.py +++ b/comfy_api_nodes/nodes_kling.py @@ -50,6 +50,7 @@ from comfy_api_nodes.apis import ( KlingSingleImageEffectModelName, ) from comfy_api_nodes.apis.kling_api import ( + ImageToVideoWithAudioRequest, OmniImageParamImage, OmniParamImage, OmniParamVideo, @@ -57,7 +58,8 @@ from comfy_api_nodes.apis.kling_api import ( OmniProImageRequest, OmniProReferences2VideoRequest, OmniProText2VideoRequest, - OmniTaskStatusResponse, + TaskStatusResponse, + TextToVideoWithAudioRequest, ) from comfy_api_nodes.util import ( ApiEndpoint, @@ -242,7 +244,7 @@ def normalize_omni_prompt_references(prompt: str) -> str: return re.sub(r"(?\d*)(?!\w)", _video_repl, prompt) -async def finish_omni_video_task(cls: type[IO.ComfyNode], response: OmniTaskStatusResponse) -> IO.NodeOutput: +async def finish_omni_video_task(cls: type[IO.ComfyNode], response: TaskStatusResponse) -> IO.NodeOutput: if response.code: raise RuntimeError( f"Kling request failed. 
Code: {response.code}, Message: {response.message}, Data: {response.data}" @@ -250,7 +252,7 @@ async def finish_omni_video_task(cls: type[IO.ComfyNode], response: OmniTaskStat final_response = await poll_op( cls, ApiEndpoint(path=f"/proxy/kling/v1/videos/omni-video/{response.data.task_id}"), - response_model=OmniTaskStatusResponse, + response_model=TaskStatusResponse, status_extractor=lambda r: (r.data.task_status if r.data else None), max_poll_attempts=160, ) @@ -483,12 +485,12 @@ async def execute_image2video( task_id = task_creation_response.data.task_id final_response = await poll_op( - cls, - ApiEndpoint(path=f"{PATH_IMAGE_TO_VIDEO}/{task_id}"), - response_model=KlingImage2VideoResponse, - estimated_duration=AVERAGE_DURATION_I2V, - status_extractor=lambda r: (r.data.task_status.value if r.data and r.data.task_status else None), - ) + cls, + ApiEndpoint(path=f"{PATH_IMAGE_TO_VIDEO}/{task_id}"), + response_model=KlingImage2VideoResponse, + estimated_duration=AVERAGE_DURATION_I2V, + status_extractor=lambda r: (r.data.task_status.value if r.data and r.data.task_status else None), + ) validate_video_result_response(final_response) video = get_video_from_response(final_response) @@ -834,7 +836,7 @@ class OmniProTextToVideoNode(IO.ComfyNode): response = await sync_op( cls, ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"), - response_model=OmniTaskStatusResponse, + response_model=TaskStatusResponse, data=OmniProText2VideoRequest( model_name=model_name, prompt=prompt, @@ -929,7 +931,7 @@ class OmniProFirstLastFrameNode(IO.ComfyNode): response = await sync_op( cls, ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"), - response_model=OmniTaskStatusResponse, + response_model=TaskStatusResponse, data=OmniProFirstLastFrameRequest( model_name=model_name, prompt=prompt, @@ -997,7 +999,7 @@ class OmniProImageToVideoNode(IO.ComfyNode): response = await sync_op( cls, ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"), - response_model=OmniTaskStatusResponse, + response_model=TaskStatusResponse, data=OmniProReferences2VideoRequest( model_name=model_name, prompt=prompt, @@ -1081,7 +1083,7 @@ class OmniProVideoToVideoNode(IO.ComfyNode): response = await sync_op( cls, ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"), - response_model=OmniTaskStatusResponse, + response_model=TaskStatusResponse, data=OmniProReferences2VideoRequest( model_name=model_name, prompt=prompt, @@ -1162,7 +1164,7 @@ class OmniProEditVideoNode(IO.ComfyNode): response = await sync_op( cls, ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"), - response_model=OmniTaskStatusResponse, + response_model=TaskStatusResponse, data=OmniProReferences2VideoRequest( model_name=model_name, prompt=prompt, @@ -1237,7 +1239,7 @@ class OmniProImageNode(IO.ComfyNode): response = await sync_op( cls, ApiEndpoint(path="/proxy/kling/v1/images/omni-image", method="POST"), - response_model=OmniTaskStatusResponse, + response_model=TaskStatusResponse, data=OmniProImageRequest( model_name=model_name, prompt=prompt, @@ -1253,7 +1255,7 @@ class OmniProImageNode(IO.ComfyNode): final_response = await poll_op( cls, ApiEndpoint(path=f"/proxy/kling/v1/images/omni-image/{response.data.task_id}"), - response_model=OmniTaskStatusResponse, + response_model=TaskStatusResponse, status_extractor=lambda r: (r.data.task_status if r.data else None), ) return IO.NodeOutput(await download_url_to_image_tensor(final_response.data.task_result.images[0].url)) @@ -1328,9 +1330,8 @@ class 
KlingImage2VideoNode(IO.ComfyNode): def define_schema(cls) -> IO.Schema: return IO.Schema( node_id="KlingImage2VideoNode", - display_name="Kling Image to Video", + display_name="Kling Image(First Frame) to Video", category="api node/video/Kling", - description="Kling Image to Video Node", inputs=[ IO.Image.Input("start_frame", tooltip="The reference image used to generate the video."), IO.String.Input("prompt", multiline=True, tooltip="Positive text prompt"), @@ -2034,6 +2035,136 @@ class KlingImageGenerationNode(IO.ComfyNode): return IO.NodeOutput(await image_result_to_node_output(images)) +class TextToVideoWithAudio(IO.ComfyNode): + + @classmethod + def define_schema(cls) -> IO.Schema: + return IO.Schema( + node_id="KlingTextToVideoWithAudio", + display_name="Kling Text to Video with Audio", + category="api node/video/Kling", + inputs=[ + IO.Combo.Input("model_name", options=["kling-v2-6"]), + IO.String.Input("prompt", multiline=True, tooltip="Positive text prompt."), + IO.Combo.Input("mode", options=["pro"]), + IO.Combo.Input("aspect_ratio", options=["16:9", "9:16", "1:1"]), + IO.Combo.Input("duration", options=[5, 10]), + IO.Boolean.Input("generate_audio", default=True), + ], + outputs=[ + IO.Video.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, + model_name: str, + prompt: str, + mode: str, + aspect_ratio: str, + duration: int, + generate_audio: bool, + ) -> IO.NodeOutput: + validate_string(prompt, min_length=1, max_length=2500) + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/kling/v1/videos/text2video", method="POST"), + response_model=TaskStatusResponse, + data=TextToVideoWithAudioRequest( + model_name=model_name, + prompt=prompt, + mode=mode, + aspect_ratio=aspect_ratio, + duration=str(duration), + sound="on" if generate_audio else "off", + ), + ) + if response.code: + raise RuntimeError( + f"Kling request failed. 
Code: {response.code}, Message: {response.message}, Data: {response.data}" + ) + final_response = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/kling/v1/videos/text2video/{response.data.task_id}"), + response_model=TaskStatusResponse, + status_extractor=lambda r: (r.data.task_status if r.data else None), + ) + return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url)) + + +class ImageToVideoWithAudio(IO.ComfyNode): + + @classmethod + def define_schema(cls) -> IO.Schema: + return IO.Schema( + node_id="KlingImageToVideoWithAudio", + display_name="Kling Image(First Frame) to Video with Audio", + category="api node/video/Kling", + inputs=[ + IO.Combo.Input("model_name", options=["kling-v2-6"]), + IO.Image.Input("start_frame"), + IO.String.Input("prompt", multiline=True, tooltip="Positive text prompt."), + IO.Combo.Input("mode", options=["pro"]), + IO.Combo.Input("duration", options=[5, 10]), + IO.Boolean.Input("generate_audio", default=True), + ], + outputs=[ + IO.Video.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, + model_name: str, + start_frame: Input.Image, + prompt: str, + mode: str, + duration: int, + generate_audio: bool, + ) -> IO.NodeOutput: + validate_string(prompt, min_length=1, max_length=2500) + validate_image_dimensions(start_frame, min_width=300, min_height=300) + validate_image_aspect_ratio(start_frame, (1, 2.5), (2.5, 1)) + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/kling/v1/videos/image2video", method="POST"), + response_model=TaskStatusResponse, + data=ImageToVideoWithAudioRequest( + model_name=model_name, + image=(await upload_images_to_comfyapi(cls, start_frame))[0], + prompt=prompt, + mode=mode, + duration=str(duration), + sound="on" if generate_audio else "off", + ), + ) + if response.code: + raise RuntimeError( + f"Kling request failed. Code: {response.code}, Message: {response.message}, Data: {response.data}" + ) + final_response = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/kling/v1/videos/image2video/{response.data.task_id}"), + response_model=TaskStatusResponse, + status_extractor=lambda r: (r.data.task_status if r.data else None), + ) + return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url)) + + class KlingExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[IO.ComfyNode]]: @@ -2057,6 +2188,8 @@ class KlingExtension(ComfyExtension): OmniProVideoToVideoNode, OmniProEditVideoNode, OmniProImageNode, + TextToVideoWithAudio, + ImageToVideoWithAudio, ] From c5a47a16924e1be96241553a1448b298e57e50a1 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Fri, 12 Dec 2025 08:49:35 -0800 Subject: [PATCH 033/148] Fix bias dtype issue in mixed ops. 
(#11293) --- comfy/ops.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/comfy/ops.py b/comfy/ops.py index 6ae6e791a..0384c8717 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -504,10 +504,7 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec self.in_features = in_features self.out_features = out_features - if bias: - self.bias = torch.nn.Parameter(torch.empty(out_features, **self.factory_kwargs)) - else: - self.register_parameter("bias", None) + self._has_bias = bias self.tensor_class = None self._full_precision_mm = MixedPrecisionOps._full_precision_mm @@ -536,6 +533,10 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec self.weight = torch.nn.Parameter(weight.to(device=device, dtype=dtype), requires_grad=False) if dtype != MixedPrecisionOps._compute_dtype: self.comfy_cast_weights = True + if self._has_bias: + self.bias = torch.nn.Parameter(torch.empty(self.out_features, device=device, dtype=dtype)) + else: + self.register_parameter("bias", None) else: self.quant_format = layer_conf.get("format", None) if not self._full_precision_mm: @@ -565,6 +566,11 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec requires_grad=False ) + if self._has_bias: + self.bias = torch.nn.Parameter(torch.empty(self.out_features, device=device, dtype=MixedPrecisionOps._compute_dtype)) + else: + self.register_parameter("bias", None) + for param_name in qconfig["parameters"]: param_key = f"{prefix}{param_name}" _v = state_dict.pop(param_key, None) From da2bfb5b0af26c7a1c44ec951dbd0fffe413c793 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Fri, 12 Dec 2025 22:39:11 -0800 Subject: [PATCH 034/148] Basic implementation of z image fun control union 2.0 (#11304) The inpaint part is currently missing and will be implemented later. I think they messed up this model pretty bad. They added some control_noise_refiner blocks but don't actually use them. There is a typo in their code so instead of doing control_noise_refiner -> control_layers it runs the whole control_layers twice. Unfortunately they trained with this typo so the model works but is kind of slow and would probably perform a lot better if they corrected their code and trained it again. 
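As a rough sketch of the behaviour described above (not the exact upstream code): the new broken flag makes the two noise-refiner steps replay control_layers instead of control_noise_refiner. The function below is a condensed paraphrase of the forward_noise_refiner_block method added in the diff; the function name, the explicit list arguments and the ctx dict standing in for the mask/rope/adaln keywords are simplifications of mine.

    # Condensed sketch of ZImage_Control.forward_noise_refiner_block (see diff below).
    # In "broken" mode the refiner steps never touch control_noise_refiner: step 0
    # replays control_layers[0] and step 1 replays control_layers[1:], so the control
    # stack has already run once before forward_control_block runs it again during
    # the main layers - i.e. control_layers effectively executes twice, as trained.
    def run_noise_refiner_step(layer_id, control_context, x, ctx,
                               control_layers, control_noise_refiner,
                               refiner_control=True, broken=True):
        if not refiner_control:
            return None, control_context
        if not broken:
            # fixed behaviour: use the dedicated refiner blocks
            return control_noise_refiner[layer_id](control_context, x, **ctx)
        if layer_id == 0:
            return control_layers[0](control_context, x, **ctx)
        out = None
        for block in control_layers[1:]:
            o, control_context = block(control_context, x, **ctx)
            out = o if out is None else out  # keep the first block's output, as in the patch
        return out, control_context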
--- comfy/ldm/lumina/controlnet.py | 95 +++++++++++++++++++++++-------- comfy/ldm/lumina/model.py | 16 +++++- comfy/model_patcher.py | 3 + comfy_extras/nodes_model_patch.py | 72 +++++++++++++++++------ 4 files changed, 142 insertions(+), 44 deletions(-) diff --git a/comfy/ldm/lumina/controlnet.py b/comfy/ldm/lumina/controlnet.py index fd7ce3b5c..8e2de7977 100644 --- a/comfy/ldm/lumina/controlnet.py +++ b/comfy/ldm/lumina/controlnet.py @@ -41,6 +41,11 @@ class ZImage_Control(torch.nn.Module): ffn_dim_multiplier: float = (8.0 / 3.0), norm_eps: float = 1e-5, qk_norm: bool = True, + n_control_layers=6, + control_in_dim=16, + additional_in_dim=0, + broken=False, + refiner_control=False, dtype=None, device=None, operations=None, @@ -49,10 +54,11 @@ class ZImage_Control(torch.nn.Module): super().__init__() operation_settings = {"operations": operations, "device": device, "dtype": dtype} - self.additional_in_dim = 0 - self.control_in_dim = 16 + self.broken = broken + self.additional_in_dim = additional_in_dim + self.control_in_dim = control_in_dim n_refiner_layers = 2 - self.n_control_layers = 6 + self.n_control_layers = n_control_layers self.control_layers = nn.ModuleList( [ ZImageControlTransformerBlock( @@ -74,28 +80,49 @@ class ZImage_Control(torch.nn.Module): all_x_embedder = {} patch_size = 2 f_patch_size = 1 - x_embedder = operations.Linear(f_patch_size * patch_size * patch_size * self.control_in_dim, dim, bias=True, device=device, dtype=dtype) + x_embedder = operations.Linear(f_patch_size * patch_size * patch_size * (self.control_in_dim + self.additional_in_dim), dim, bias=True, device=device, dtype=dtype) all_x_embedder[f"{patch_size}-{f_patch_size}"] = x_embedder + self.refiner_control = refiner_control + self.control_all_x_embedder = nn.ModuleDict(all_x_embedder) - self.control_noise_refiner = nn.ModuleList( - [ - JointTransformerBlock( - layer_id, - dim, - n_heads, - n_kv_heads, - multiple_of, - ffn_dim_multiplier, - norm_eps, - qk_norm, - modulation=True, - z_image_modulation=True, - operation_settings=operation_settings, - ) - for layer_id in range(n_refiner_layers) - ] - ) + if self.refiner_control: + self.control_noise_refiner = nn.ModuleList( + [ + ZImageControlTransformerBlock( + layer_id, + dim, + n_heads, + n_kv_heads, + multiple_of, + ffn_dim_multiplier, + norm_eps, + qk_norm, + block_id=layer_id, + operation_settings=operation_settings, + ) + for layer_id in range(n_refiner_layers) + ] + ) + else: + self.control_noise_refiner = nn.ModuleList( + [ + JointTransformerBlock( + layer_id, + dim, + n_heads, + n_kv_heads, + multiple_of, + ffn_dim_multiplier, + norm_eps, + qk_norm, + modulation=True, + z_image_modulation=True, + operation_settings=operation_settings, + ) + for layer_id in range(n_refiner_layers) + ] + ) def forward(self, cap_feats, control_context, x_freqs_cis, adaln_input): patch_size = 2 @@ -105,9 +132,29 @@ class ZImage_Control(torch.nn.Module): control_context = self.control_all_x_embedder[f"{patch_size}-{f_patch_size}"](control_context.view(B, C, H // pH, pH, W // pW, pW).permute(0, 2, 4, 3, 5, 1).flatten(3).flatten(1, 2)) x_attn_mask = None - for layer in self.control_noise_refiner: - control_context = layer(control_context, x_attn_mask, x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input) + if not self.refiner_control: + for layer in self.control_noise_refiner: + control_context = layer(control_context, x_attn_mask, x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input) + return control_context + def 
forward_noise_refiner_block(self, layer_id, control_context, x, x_attn_mask, x_freqs_cis, adaln_input): + if self.refiner_control: + if self.broken: + if layer_id == 0: + return self.control_layers[layer_id](control_context, x, x_mask=x_attn_mask, freqs_cis=x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input=adaln_input) + if layer_id > 0: + out = None + for i in range(1, len(self.control_layers)): + o, control_context = self.control_layers[i](control_context, x, x_mask=x_attn_mask, freqs_cis=x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input=adaln_input) + if out is None: + out = o + + return (out, control_context) + else: + return self.control_noise_refiner[layer_id](control_context, x, x_mask=x_attn_mask, freqs_cis=x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input=adaln_input) + else: + return (None, control_context) + def forward_control_block(self, layer_id, control_context, x, x_attn_mask, x_freqs_cis, adaln_input): return self.control_layers[layer_id](control_context, x, x_mask=x_attn_mask, freqs_cis=x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input=adaln_input) diff --git a/comfy/ldm/lumina/model.py b/comfy/ldm/lumina/model.py index c47df49ca..96cb37fa6 100644 --- a/comfy/ldm/lumina/model.py +++ b/comfy/ldm/lumina/model.py @@ -536,6 +536,7 @@ class NextDiT(nn.Module): bsz = len(x) pH = pW = self.patch_size device = x[0].device + orig_x = x if self.pad_tokens_multiple is not None: pad_extra = (-cap_feats.shape[1]) % self.pad_tokens_multiple @@ -572,13 +573,21 @@ class NextDiT(nn.Module): freqs_cis = self.rope_embedder(torch.cat((cap_pos_ids, x_pos_ids), dim=1)).movedim(1, 2) + patches = transformer_options.get("patches", {}) + # refine context for layer in self.context_refiner: cap_feats = layer(cap_feats, cap_mask, freqs_cis[:, :cap_pos_ids.shape[1]], transformer_options=transformer_options) padded_img_mask = None - for layer in self.noise_refiner: + x_input = x + for i, layer in enumerate(self.noise_refiner): x = layer(x, padded_img_mask, freqs_cis[:, cap_pos_ids.shape[1]:], t, transformer_options=transformer_options) + if "noise_refiner" in patches: + for p in patches["noise_refiner"]: + out = p({"img": x, "img_input": x_input, "txt": cap_feats, "pe": freqs_cis[:, cap_pos_ids.shape[1]:], "vec": t, "x": orig_x, "block_index": i, "transformer_options": transformer_options, "block_type": "noise_refiner"}) + if "img" in out: + x = out["img"] padded_full_embed = torch.cat((cap_feats, x), dim=1) mask = None @@ -622,14 +631,15 @@ class NextDiT(nn.Module): patches = transformer_options.get("patches", {}) x_is_tensor = isinstance(x, torch.Tensor) - img, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, t, num_tokens, transformer_options=transformer_options) + img, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, adaln_input, num_tokens, transformer_options=transformer_options) freqs_cis = freqs_cis.to(img.device) + img_input = img for i, layer in enumerate(self.layers): img = layer(img, mask, freqs_cis, adaln_input, transformer_options=transformer_options) if "double_block" in patches: for p in patches["double_block"]: - out = p({"img": img[:, cap_size[0]:], "txt": img[:, :cap_size[0]], "pe": freqs_cis[:, cap_size[0]:], "vec": adaln_input, "x": x, "block_index": i, "transformer_options": transformer_options}) + out = p({"img": img[:, cap_size[0]:], "img_input": img_input[:, cap_size[0]:], "txt": img[:, 
:cap_size[0]], "pe": freqs_cis[:, cap_size[0]:], "vec": adaln_input, "x": x, "block_index": i, "transformer_options": transformer_options}) if "img" in out: img[:, cap_size[0]:] = out["img"] if "txt" in out: diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index a486c2723..93d26c690 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -454,6 +454,9 @@ class ModelPatcher: def set_model_post_input_patch(self, patch): self.set_model_patch(patch, "post_input") + def set_model_noise_refiner_patch(self, patch): + self.set_model_patch(patch, "noise_refiner") + def set_model_rope_options(self, scale_x, shift_x, scale_y, shift_y, scale_t, shift_t, **kwargs): rope_options = self.model_options["transformer_options"].get("rope_options", {}) rope_options["scale_x"] = scale_x diff --git a/comfy_extras/nodes_model_patch.py b/comfy_extras/nodes_model_patch.py index c61810dbf..ec0e790dc 100644 --- a/comfy_extras/nodes_model_patch.py +++ b/comfy_extras/nodes_model_patch.py @@ -243,7 +243,13 @@ class ModelPatchLoader: model = SigLIPMultiFeatProjModel(device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast) elif 'control_all_x_embedder.2-1.weight' in sd: # alipai z image fun controlnet sd = z_image_convert(sd) - model = comfy.ldm.lumina.controlnet.ZImage_Control(device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast) + config = {} + if 'control_layers.14.adaLN_modulation.0.weight' in sd: + config['n_control_layers'] = 15 + config['additional_in_dim'] = 17 + config['refiner_control'] = True + config['broken'] = True + model = comfy.ldm.lumina.controlnet.ZImage_Control(device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast, **config) model.load_state_dict(sd) model = comfy.model_patcher.ModelPatcher(model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device()) @@ -297,56 +303,86 @@ class DiffSynthCnetPatch: return [self.model_patch] class ZImageControlPatch: - def __init__(self, model_patch, vae, image, strength): + def __init__(self, model_patch, vae, image, strength, inpaint_image=None, mask=None): self.model_patch = model_patch self.vae = vae self.image = image + self.inpaint_image = inpaint_image + self.mask = mask self.strength = strength self.encoded_image = self.encode_latent_cond(image) self.encoded_image_size = (image.shape[1], image.shape[2]) self.temp_data = None - def encode_latent_cond(self, image): - latent_image = comfy.latent_formats.Flux().process_in(self.vae.encode(image)) - return latent_image + def encode_latent_cond(self, control_image, inpaint_image=None): + latent_image = comfy.latent_formats.Flux().process_in(self.vae.encode(control_image)) + if self.model_patch.model.additional_in_dim > 0: + if self.mask is None: + mask_ = torch.zeros_like(latent_image)[:, :1] + else: + mask_ = comfy.utils.common_upscale(self.mask.mean(dim=1, keepdim=True), latent_image.shape[-1], latent_image.shape[-2], "bilinear", "none") + if inpaint_image is None: + inpaint_image = torch.ones_like(control_image) * 0.5 + + inpaint_image_latent = comfy.latent_formats.Flux().process_in(self.vae.encode(inpaint_image)) + + return torch.cat([latent_image, mask_, inpaint_image_latent], dim=1) + else: + return latent_image def __call__(self, kwargs): x = kwargs.get("x") img = kwargs.get("img") + img_input = kwargs.get("img_input") txt = kwargs.get("txt") pe = kwargs.get("pe") vec = kwargs.get("vec") 
block_index = kwargs.get("block_index") + block_type = kwargs.get("block_type", "") spacial_compression = self.vae.spacial_compression_encode() if self.encoded_image is None or self.encoded_image_size != (x.shape[-2] * spacial_compression, x.shape[-1] * spacial_compression): image_scaled = comfy.utils.common_upscale(self.image.movedim(-1, 1), x.shape[-1] * spacial_compression, x.shape[-2] * spacial_compression, "area", "center") + inpaint_scaled = None + if self.inpaint_image is not None: + inpaint_scaled = comfy.utils.common_upscale(self.inpaint_image.movedim(-1, 1), x.shape[-1] * spacial_compression, x.shape[-2] * spacial_compression, "area", "center").movedim(1, -1) loaded_models = comfy.model_management.loaded_models(only_currently_used=True) - self.encoded_image = self.encode_latent_cond(image_scaled.movedim(1, -1)) + self.encoded_image = self.encode_latent_cond(image_scaled.movedim(1, -1), inpaint_scaled) self.encoded_image_size = (image_scaled.shape[-2], image_scaled.shape[-1]) comfy.model_management.load_models_gpu(loaded_models) - cnet_index = (block_index // 5) - cnet_index_float = (block_index / 5) + cnet_blocks = self.model_patch.model.n_control_layers + div = round(30 / cnet_blocks) + + cnet_index = (block_index // div) + cnet_index_float = (block_index / div) kwargs.pop("img") # we do ops in place kwargs.pop("txt") - cnet_blocks = self.model_patch.model.n_control_layers if cnet_index_float > (cnet_blocks - 1): self.temp_data = None return kwargs if self.temp_data is None or self.temp_data[0] > cnet_index: - self.temp_data = (-1, (None, self.model_patch.model(txt, self.encoded_image.to(img.dtype), pe, vec))) + if block_type == "noise_refiner": + self.temp_data = (-3, (None, self.model_patch.model(txt, self.encoded_image.to(img.dtype), pe, vec))) + else: + self.temp_data = (-1, (None, self.model_patch.model(txt, self.encoded_image.to(img.dtype), pe, vec))) - while self.temp_data[0] < cnet_index and (self.temp_data[0] + 1) < cnet_blocks: + if block_type == "noise_refiner": next_layer = self.temp_data[0] + 1 - self.temp_data = (next_layer, self.model_patch.model.forward_control_block(next_layer, self.temp_data[1][1], img[:, :self.temp_data[1][1].shape[1]], None, pe, vec)) + self.temp_data = (next_layer, self.model_patch.model.forward_noise_refiner_block(block_index, self.temp_data[1][1], img_input[:, :self.temp_data[1][1].shape[1]], None, pe, vec)) + if self.temp_data[1][0] is not None: + img[:, :self.temp_data[1][0].shape[1]] += (self.temp_data[1][0] * self.strength) + else: + while self.temp_data[0] < cnet_index and (self.temp_data[0] + 1) < cnet_blocks: + next_layer = self.temp_data[0] + 1 + self.temp_data = (next_layer, self.model_patch.model.forward_control_block(next_layer, self.temp_data[1][1], img_input[:, :self.temp_data[1][1].shape[1]], None, pe, vec)) - if cnet_index_float == self.temp_data[0]: - img[:, :self.temp_data[1][0].shape[1]] += (self.temp_data[1][0] * self.strength) - if cnet_blocks == self.temp_data[0] + 1: - self.temp_data = None + if cnet_index_float == self.temp_data[0]: + img[:, :self.temp_data[1][0].shape[1]] += (self.temp_data[1][0] * self.strength) + if cnet_blocks == self.temp_data[0] + 1: + self.temp_data = None return kwargs @@ -386,7 +422,9 @@ class QwenImageDiffsynthControlnet: mask = 1.0 - mask if isinstance(model_patch.model, comfy.ldm.lumina.controlnet.ZImage_Control): - model_patched.set_model_double_block_patch(ZImageControlPatch(model_patch, vae, image, strength)) + patch = ZImageControlPatch(model_patch, vae, image, strength, mask=mask) + 
model_patched.set_model_noise_refiner_patch(patch) + model_patched.set_model_double_block_patch(patch) else: model_patched.set_model_double_block_patch(DiffSynthCnetPatch(model_patch, vae, image, strength, mask)) return (model_patched,) From 971cefe7d4ca15c949d5d901a663cb66562a4f10 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sat, 13 Dec 2025 15:45:23 -0800 Subject: [PATCH 035/148] Fix pytorch warnings. (#11314) --- comfy/ops.py | 2 +- comfy/utils.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/comfy/ops.py b/comfy/ops.py index 0384c8717..16889bb82 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -592,7 +592,7 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec quant_conf = {"format": self.quant_format} if self._full_precision_mm: quant_conf["full_precision_matrix_mult"] = True - sd["{}comfy_quant".format(prefix)] = torch.frombuffer(json.dumps(quant_conf).encode('utf-8'), dtype=torch.uint8) + sd["{}comfy_quant".format(prefix)] = torch.tensor(list(json.dumps(quant_conf).encode('utf-8')), dtype=torch.uint8) return sd def _forward(self, input, weight, bias): diff --git a/comfy/utils.py b/comfy/utils.py index 9dc0d76ac..3866cda2e 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -1262,6 +1262,6 @@ def convert_old_quants(state_dict, model_prefix="", metadata={}): if quant_metadata is not None: layers = quant_metadata["layers"] for k, v in layers.items(): - state_dict["{}.comfy_quant".format(k)] = torch.frombuffer(json.dumps(v).encode('utf-8'), dtype=torch.uint8) + state_dict["{}.comfy_quant".format(k)] = torch.tensor(list(json.dumps(v).encode('utf-8')), dtype=torch.uint8) return state_dict, metadata From 6592bffc609da4738b111dbffca1f473972f3574 Mon Sep 17 00:00:00 2001 From: chaObserv <154517000+chaObserv@users.noreply.github.com> Date: Sun, 14 Dec 2025 13:03:29 +0800 Subject: [PATCH 036/148] seeds_2: add phi_2 variant and sampler node (#11309) * Add phi_2 solver type to seeds_2 * Add sampler node of seeds_2 --- comfy/k_diffusion/sampling.py | 15 ++++++++++++--- comfy_extras/nodes_custom_sampler.py | 26 ++++++++++++++++++++++++++ 2 files changed, 38 insertions(+), 3 deletions(-) diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py index 0e2cda291..753c66afa 100644 --- a/comfy/k_diffusion/sampling.py +++ b/comfy/k_diffusion/sampling.py @@ -1557,10 +1557,13 @@ def sample_er_sde(model, x, sigmas, extra_args=None, callback=None, disable=None @torch.no_grad() -def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=0.5): +def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=0.5, solver_type="phi_1"): """SEEDS-2 - Stochastic Explicit Exponential Derivative-free Solvers (VP Data Prediction) stage 2. 
arXiv: https://arxiv.org/abs/2305.14267 (NeurIPS 2023) """ + if solver_type not in {"phi_1", "phi_2"}: + raise ValueError("solver_type must be 'phi_1' or 'phi_2'") + extra_args = {} if extra_args is None else extra_args seed = extra_args.get("seed", None) noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler @@ -1600,8 +1603,14 @@ def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=Non denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args) # Step 2 - denoised_d = torch.lerp(denoised, denoised_2, fac) - x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * ei_h_phi_1(-h_eta) * denoised_d + if solver_type == "phi_1": + denoised_d = torch.lerp(denoised, denoised_2, fac) + x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * ei_h_phi_1(-h_eta) * denoised_d + elif solver_type == "phi_2": + b2 = ei_h_phi_2(-h_eta) / r + b1 = ei_h_phi_1(-h_eta) - b2 + x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * (b1 * denoised + b2 * denoised_2) + if inject_noise: segment_factor = (r - 1) * h * eta sde_noise = sde_noise * segment_factor.exp() diff --git a/comfy_extras/nodes_custom_sampler.py b/comfy_extras/nodes_custom_sampler.py index fbb080886..71ea4e9ec 100644 --- a/comfy_extras/nodes_custom_sampler.py +++ b/comfy_extras/nodes_custom_sampler.py @@ -659,6 +659,31 @@ class SamplerSASolver(io.ComfyNode): get_sampler = execute +class SamplerSEEDS2(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SamplerSEEDS2", + category="sampling/custom_sampling/samplers", + inputs=[ + io.Combo.Input("solver_type", options=["phi_1", "phi_2"]), + io.Float.Input("eta", default=1.0, min=0.0, max=100.0, step=0.01, round=False, tooltip="Stochastic strength"), + io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False, tooltip="SDE noise multiplier"), + io.Float.Input("r", default=0.5, min=0.01, max=1.0, step=0.01, round=False, tooltip="Relative step size for the intermediate stage (c2 node)"), + ], + outputs=[io.Sampler.Output()] + ) + + @classmethod + def execute(cls, solver_type, eta, s_noise, r) -> io.NodeOutput: + sampler_name = "seeds_2" + sampler = comfy.samplers.ksampler( + sampler_name, + {"eta": eta, "s_noise": s_noise, "r": r, "solver_type": solver_type}, + ) + return io.NodeOutput(sampler) + + class Noise_EmptyNoise: def __init__(self): self.seed = 0 @@ -996,6 +1021,7 @@ class CustomSamplersExtension(ComfyExtension): SamplerDPMAdaptative, SamplerER_SDE, SamplerSASolver, + SamplerSEEDS2, SplitSigmas, SplitSigmasDenoise, FlipSigmas, From 5ac3b26a7dedb9b13c681abe8733c54f13353273 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sun, 14 Dec 2025 01:02:50 -0800 Subject: [PATCH 037/148] Update warning for old pytorch version. (#11319) Versions below 2.4 are no longer supported. We will not break support on purpose but will not fix it if we do. --- comfy/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/utils.py b/comfy/utils.py index 3866cda2e..8d4e2b445 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -53,7 +53,7 @@ if hasattr(torch.serialization, "add_safe_globals"): # TODO: this was added in ALWAYS_SAFE_LOAD = True logging.info("Checkpoint files will always be loaded safely.") else: - logging.info("Warning, you are using an old pytorch version and some ckpt/pt files might be loaded unsafely. 
Upgrading to 2.4 or above is recommended.") + logging.warning("Warning, you are using an old pytorch version and some ckpt/pt files might be loaded unsafely. Upgrading to 2.4 or above is recommended as older versions of pytorch are no longer supported.") def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False): if device is None: From a5e85017d8574cb99024d320f7a53a77a9e6aa5a Mon Sep 17 00:00:00 2001 From: "Dr.Lt.Data" <128333288+ltdrdata@users.noreply.github.com> Date: Tue, 16 Dec 2025 04:24:01 +0900 Subject: [PATCH 038/148] bump manager requirments to the 4.0.3b5 (#11324) --- manager_requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/manager_requirements.txt b/manager_requirements.txt index b95cefb74..5ef0d3a1d 100644 --- a/manager_requirements.txt +++ b/manager_requirements.txt @@ -1 +1 @@ -comfyui_manager==4.0.3b4 +comfyui_manager==4.0.3b5 From 51347f9fb8a8e60d3add049c6f241822c84c8a87 Mon Sep 17 00:00:00 2001 From: ComfyUI Wiki Date: Tue, 16 Dec 2025 05:28:55 +0800 Subject: [PATCH 039/148] chore: update workflow templates to v0.7.59 (#11337) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 9e9b25328..117260515 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.34.8 -comfyui-workflow-templates==0.7.54 +comfyui-workflow-templates==0.7.59 comfyui-embedded-docs==0.3.1 torch torchsde From 5cb1e0c9a0439f1f95a0b372474bd4845e38009c Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 15 Dec 2025 13:49:29 -0800 Subject: [PATCH 040/148] Disable guards on transformer_options when torch.compile (#11317) --- comfy_extras/nodes_torch_compile.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/comfy_extras/nodes_torch_compile.py b/comfy_extras/nodes_torch_compile.py index adbeece2f..c43e8ad63 100644 --- a/comfy_extras/nodes_torch_compile.py +++ b/comfy_extras/nodes_torch_compile.py @@ -2,6 +2,8 @@ from typing_extensions import override from comfy_api.latest import ComfyExtension, io from comfy_api.torch_helpers import set_torch_compile_wrapper +def skip_torch_compile_dict(guard_entries): + return [("transformer_options" not in entry.name) for entry in guard_entries] class TorchCompileModel(io.ComfyNode): @classmethod @@ -23,7 +25,7 @@ class TorchCompileModel(io.ComfyNode): @classmethod def execute(cls, model, backend) -> io.NodeOutput: m = model.clone() - set_torch_compile_wrapper(model=m, backend=backend) + set_torch_compile_wrapper(model=m, backend=backend, options={"guard_filter_fn": skip_torch_compile_dict}) return io.NodeOutput(m) From af91eb6c9931d0a2c99cf8a6d4974a6abf9a09fa Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Tue, 16 Dec 2025 01:30:24 +0200 Subject: [PATCH 041/148] api-nodes: drop Kling v1 model (#11307) --- comfy_api_nodes/nodes_kling.py | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/comfy_api_nodes/nodes_kling.py b/comfy_api_nodes/nodes_kling.py index e545fe490..1a6364fa0 100644 --- a/comfy_api_nodes/nodes_kling.py +++ b/comfy_api_nodes/nodes_kling.py @@ -105,10 +105,6 @@ AVERAGE_DURATION_VIDEO_EXTEND = 320 MODE_TEXT2VIDEO = { - "standard mode / 5s duration / kling-v1": ("std", "5", "kling-v1"), - "standard mode / 10s duration / kling-v1": ("std", "10", "kling-v1"), - "pro mode / 5s duration / kling-v1": ("pro", "5", "kling-v1"), - "pro mode / 10s 
duration / kling-v1": ("pro", "10", "kling-v1"), "standard mode / 5s duration / kling-v1-6": ("std", "5", "kling-v1-6"), "standard mode / 10s duration / kling-v1-6": ("std", "10", "kling-v1-6"), "pro mode / 5s duration / kling-v2-master": ("pro", "5", "kling-v2-master"), @@ -129,8 +125,6 @@ See: [Kling API Docs Capability Map](https://app.klingai.com/global/dev/document MODE_START_END_FRAME = { - "standard mode / 5s duration / kling-v1": ("std", "5", "kling-v1"), - "pro mode / 5s duration / kling-v1": ("pro", "5", "kling-v1"), "pro mode / 5s duration / kling-v1-5": ("pro", "5", "kling-v1-5"), "pro mode / 10s duration / kling-v1-5": ("pro", "10", "kling-v1-5"), "pro mode / 5s duration / kling-v1-6": ("pro", "5", "kling-v1-6"), @@ -754,7 +748,7 @@ class KlingTextToVideoNode(IO.ComfyNode): IO.Combo.Input( "mode", options=modes, - default=modes[4], + default=modes[8], tooltip="The configuration to use for the video generation following the format: mode / duration / model_name.", ), ], @@ -1489,7 +1483,7 @@ class KlingStartEndFrameNode(IO.ComfyNode): IO.Combo.Input( "mode", options=modes, - default=modes[8], + default=modes[6], tooltip="The configuration to use for the video generation following the format: mode / duration / model_name.", ), ], @@ -1952,7 +1946,7 @@ class KlingImageGenerationNode(IO.ComfyNode): IO.Combo.Input( "model_name", options=[i.value for i in KlingImageGenModelName], - default="kling-v1", + default="kling-v2", ), IO.Combo.Input( "aspect_ratio", From 33c7f1179d4a961e4ca1dd78188c5134e0ee8e8c Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Tue, 16 Dec 2025 01:32:29 +0200 Subject: [PATCH 042/148] drop Pika API nodes (#11306) --- comfy_api_nodes/apis/pika_api.py | 100 ------ comfy_api_nodes/nodes_pika.py | 575 ------------------------------- nodes.py | 1 - 3 files changed, 676 deletions(-) delete mode 100644 comfy_api_nodes/apis/pika_api.py delete mode 100644 comfy_api_nodes/nodes_pika.py diff --git a/comfy_api_nodes/apis/pika_api.py b/comfy_api_nodes/apis/pika_api.py deleted file mode 100644 index 232558cd7..000000000 --- a/comfy_api_nodes/apis/pika_api.py +++ /dev/null @@ -1,100 +0,0 @@ -from typing import Optional -from enum import Enum -from pydantic import BaseModel, Field - - -class Pikaffect(str, Enum): - Cake_ify = "Cake-ify" - Crumble = "Crumble" - Crush = "Crush" - Decapitate = "Decapitate" - Deflate = "Deflate" - Dissolve = "Dissolve" - Explode = "Explode" - Eye_pop = "Eye-pop" - Inflate = "Inflate" - Levitate = "Levitate" - Melt = "Melt" - Peel = "Peel" - Poke = "Poke" - Squish = "Squish" - Ta_da = "Ta-da" - Tear = "Tear" - - -class PikaBodyGenerate22C2vGenerate22PikascenesPost(BaseModel): - aspectRatio: Optional[float] = Field(None, description='Aspect ratio (width / height)') - duration: Optional[int] = Field(5) - ingredientsMode: str = Field(...) - negativePrompt: Optional[str] = Field(None) - promptText: Optional[str] = Field(None) - resolution: Optional[str] = Field('1080p') - seed: Optional[int] = Field(None) - - -class PikaGenerateResponse(BaseModel): - video_id: str = Field(...) 
- - -class PikaBodyGenerate22I2vGenerate22I2vPost(BaseModel): - duration: Optional[int] = 5 - negativePrompt: Optional[str] = Field(None) - promptText: Optional[str] = Field(None) - resolution: Optional[str] = '1080p' - seed: Optional[int] = Field(None) - - -class PikaBodyGenerate22KeyframeGenerate22PikaframesPost(BaseModel): - duration: Optional[int] = Field(None, ge=5, le=10) - negativePrompt: Optional[str] = Field(None) - promptText: str = Field(...) - resolution: Optional[str] = '1080p' - seed: Optional[int] = Field(None) - - -class PikaBodyGenerate22T2vGenerate22T2vPost(BaseModel): - aspectRatio: Optional[float] = Field( - 1.7777777777777777, - description='Aspect ratio (width / height)', - ge=0.4, - le=2.5, - ) - duration: Optional[int] = 5 - negativePrompt: Optional[str] = Field(None) - promptText: str = Field(...) - resolution: Optional[str] = '1080p' - seed: Optional[int] = Field(None) - - -class PikaBodyGeneratePikadditionsGeneratePikadditionsPost(BaseModel): - negativePrompt: Optional[str] = Field(None) - promptText: Optional[str] = Field(None) - seed: Optional[int] = Field(None) - - -class PikaBodyGeneratePikaffectsGeneratePikaffectsPost(BaseModel): - negativePrompt: Optional[str] = Field(None) - pikaffect: Optional[str] = None - promptText: Optional[str] = Field(None) - seed: Optional[int] = Field(None) - - -class PikaBodyGeneratePikaswapsGeneratePikaswapsPost(BaseModel): - negativePrompt: Optional[str] = Field(None) - promptText: Optional[str] = Field(None) - seed: Optional[int] = Field(None) - modifyRegionRoi: Optional[str] = Field(None) - - -class PikaStatusEnum(str, Enum): - queued = "queued" - started = "started" - finished = "finished" - failed = "failed" - - -class PikaVideoResponse(BaseModel): - id: str = Field(...) - progress: Optional[int] = Field(None) - status: PikaStatusEnum - url: Optional[str] = Field(None) diff --git a/comfy_api_nodes/nodes_pika.py b/comfy_api_nodes/nodes_pika.py deleted file mode 100644 index acd88c391..000000000 --- a/comfy_api_nodes/nodes_pika.py +++ /dev/null @@ -1,575 +0,0 @@ -""" -Pika x ComfyUI API Nodes - -Pika API docs: https://pika-827374fb.mintlify.app/api-reference -""" -from __future__ import annotations - -from io import BytesIO -import logging -from typing import Optional - -import torch - -from typing_extensions import override -from comfy_api.latest import ComfyExtension, IO -from comfy_api.input_impl.video_types import VideoCodec, VideoContainer, VideoInput -from comfy_api_nodes.apis import pika_api as pika_defs -from comfy_api_nodes.util import ( - validate_string, - download_url_to_video_output, - tensor_to_bytesio, - ApiEndpoint, - sync_op, - poll_op, -) - - -PATH_PIKADDITIONS = "/proxy/pika/generate/pikadditions" -PATH_PIKASWAPS = "/proxy/pika/generate/pikaswaps" -PATH_PIKAFFECTS = "/proxy/pika/generate/pikaffects" - -PIKA_API_VERSION = "2.2" -PATH_TEXT_TO_VIDEO = f"/proxy/pika/generate/{PIKA_API_VERSION}/t2v" -PATH_IMAGE_TO_VIDEO = f"/proxy/pika/generate/{PIKA_API_VERSION}/i2v" -PATH_PIKAFRAMES = f"/proxy/pika/generate/{PIKA_API_VERSION}/pikaframes" -PATH_PIKASCENES = f"/proxy/pika/generate/{PIKA_API_VERSION}/pikascenes" - -PATH_VIDEO_GET = "/proxy/pika/videos" - - -async def execute_task( - task_id: str, - cls: type[IO.ComfyNode], -) -> IO.NodeOutput: - final_response: pika_defs.PikaVideoResponse = await poll_op( - cls, - ApiEndpoint(path=f"{PATH_VIDEO_GET}/{task_id}"), - response_model=pika_defs.PikaVideoResponse, - status_extractor=lambda response: (response.status.value if response.status else None), - 
progress_extractor=lambda response: (response.progress if hasattr(response, "progress") else None), - estimated_duration=60, - max_poll_attempts=240, - ) - if not final_response.url: - error_msg = f"Pika task {task_id} succeeded but no video data found in response:\n{final_response}" - logging.error(error_msg) - raise Exception(error_msg) - video_url = final_response.url - logging.info("Pika task %s succeeded. Video URL: %s", task_id, video_url) - return IO.NodeOutput(await download_url_to_video_output(video_url)) - - -def get_base_inputs_types() -> list[IO.Input]: - """Get the base required inputs types common to all Pika nodes.""" - return [ - IO.String.Input("prompt_text", multiline=True), - IO.String.Input("negative_prompt", multiline=True), - IO.Int.Input("seed", min=0, max=0xFFFFFFFF, control_after_generate=True), - IO.Combo.Input("resolution", options=["1080p", "720p"], default="1080p"), - IO.Combo.Input("duration", options=[5, 10], default=5), - ] - - -class PikaImageToVideo(IO.ComfyNode): - """Pika 2.2 Image to Video Node.""" - - @classmethod - def define_schema(cls) -> IO.Schema: - return IO.Schema( - node_id="PikaImageToVideoNode2_2", - display_name="Pika Image to Video", - description="Sends an image and prompt to the Pika API v2.2 to generate a video.", - category="api node/video/Pika", - inputs=[ - IO.Image.Input("image", tooltip="The image to convert to video"), - *get_base_inputs_types(), - ], - outputs=[IO.Video.Output()], - hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, - IO.Hidden.unique_id, - ], - is_api_node=True, - is_deprecated=True, - ) - - @classmethod - async def execute( - cls, - image: torch.Tensor, - prompt_text: str, - negative_prompt: str, - seed: int, - resolution: str, - duration: int, - ) -> IO.NodeOutput: - image_bytes_io = tensor_to_bytesio(image) - pika_files = {"image": ("image.png", image_bytes_io, "image/png")} - pika_request_data = pika_defs.PikaBodyGenerate22I2vGenerate22I2vPost( - promptText=prompt_text, - negativePrompt=negative_prompt, - seed=seed, - resolution=resolution, - duration=duration, - ) - initial_operation = await sync_op( - cls, - ApiEndpoint(path=PATH_IMAGE_TO_VIDEO, method="POST"), - response_model=pika_defs.PikaGenerateResponse, - data=pika_request_data, - files=pika_files, - content_type="multipart/form-data", - ) - return await execute_task(initial_operation.video_id, cls) - - -class PikaTextToVideoNode(IO.ComfyNode): - """Pika Text2Video v2.2 Node.""" - - @classmethod - def define_schema(cls) -> IO.Schema: - return IO.Schema( - node_id="PikaTextToVideoNode2_2", - display_name="Pika Text to Video", - description="Sends a text prompt to the Pika API v2.2 to generate a video.", - category="api node/video/Pika", - inputs=[ - *get_base_inputs_types(), - IO.Float.Input( - "aspect_ratio", - step=0.001, - min=0.4, - max=2.5, - default=1.7777777777777777, - tooltip="Aspect ratio (width / height)", - ) - ], - outputs=[IO.Video.Output()], - hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, - IO.Hidden.unique_id, - ], - is_api_node=True, - is_deprecated=True, - ) - - @classmethod - async def execute( - cls, - prompt_text: str, - negative_prompt: str, - seed: int, - resolution: str, - duration: int, - aspect_ratio: float, - ) -> IO.NodeOutput: - initial_operation = await sync_op( - cls, - ApiEndpoint(path=PATH_TEXT_TO_VIDEO, method="POST"), - response_model=pika_defs.PikaGenerateResponse, - data=pika_defs.PikaBodyGenerate22T2vGenerate22T2vPost( - promptText=prompt_text, - 
negativePrompt=negative_prompt, - seed=seed, - resolution=resolution, - duration=duration, - aspectRatio=aspect_ratio, - ), - content_type="application/x-www-form-urlencoded", - ) - return await execute_task(initial_operation.video_id, cls) - - -class PikaScenes(IO.ComfyNode): - """PikaScenes v2.2 Node.""" - - @classmethod - def define_schema(cls) -> IO.Schema: - return IO.Schema( - node_id="PikaScenesV2_2", - display_name="Pika Scenes (Video Image Composition)", - description="Combine your images to create a video with the objects in them. Upload multiple images as ingredients and generate a high-quality video that incorporates all of them.", - category="api node/video/Pika", - inputs=[ - *get_base_inputs_types(), - IO.Combo.Input( - "ingredients_mode", - options=["creative", "precise"], - default="creative", - ), - IO.Float.Input( - "aspect_ratio", - step=0.001, - min=0.4, - max=2.5, - default=1.7777777777777777, - tooltip="Aspect ratio (width / height)", - ), - IO.Image.Input( - "image_ingredient_1", - optional=True, - tooltip="Image that will be used as ingredient to create a video.", - ), - IO.Image.Input( - "image_ingredient_2", - optional=True, - tooltip="Image that will be used as ingredient to create a video.", - ), - IO.Image.Input( - "image_ingredient_3", - optional=True, - tooltip="Image that will be used as ingredient to create a video.", - ), - IO.Image.Input( - "image_ingredient_4", - optional=True, - tooltip="Image that will be used as ingredient to create a video.", - ), - IO.Image.Input( - "image_ingredient_5", - optional=True, - tooltip="Image that will be used as ingredient to create a video.", - ), - ], - outputs=[IO.Video.Output()], - hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, - IO.Hidden.unique_id, - ], - is_api_node=True, - is_deprecated=True, - ) - - @classmethod - async def execute( - cls, - prompt_text: str, - negative_prompt: str, - seed: int, - resolution: str, - duration: int, - ingredients_mode: str, - aspect_ratio: float, - image_ingredient_1: Optional[torch.Tensor] = None, - image_ingredient_2: Optional[torch.Tensor] = None, - image_ingredient_3: Optional[torch.Tensor] = None, - image_ingredient_4: Optional[torch.Tensor] = None, - image_ingredient_5: Optional[torch.Tensor] = None, - ) -> IO.NodeOutput: - all_image_bytes_io = [] - for image in [ - image_ingredient_1, - image_ingredient_2, - image_ingredient_3, - image_ingredient_4, - image_ingredient_5, - ]: - if image is not None: - all_image_bytes_io.append(tensor_to_bytesio(image)) - - pika_files = [ - ("images", (f"image_{i}.png", image_bytes_io, "image/png")) - for i, image_bytes_io in enumerate(all_image_bytes_io) - ] - - pika_request_data = pika_defs.PikaBodyGenerate22C2vGenerate22PikascenesPost( - ingredientsMode=ingredients_mode, - promptText=prompt_text, - negativePrompt=negative_prompt, - seed=seed, - resolution=resolution, - duration=duration, - aspectRatio=aspect_ratio, - ) - initial_operation = await sync_op( - cls, - ApiEndpoint(path=PATH_PIKASCENES, method="POST"), - response_model=pika_defs.PikaGenerateResponse, - data=pika_request_data, - files=pika_files, - content_type="multipart/form-data", - ) - - return await execute_task(initial_operation.video_id, cls) - - -class PikAdditionsNode(IO.ComfyNode): - """Pika Pikadditions Node. 
Add an image into a video.""" - - @classmethod - def define_schema(cls) -> IO.Schema: - return IO.Schema( - node_id="Pikadditions", - display_name="Pikadditions (Video Object Insertion)", - description="Add any object or image into your video. Upload a video and specify what you'd like to add to create a seamlessly integrated result.", - category="api node/video/Pika", - inputs=[ - IO.Video.Input("video", tooltip="The video to add an image to."), - IO.Image.Input("image", tooltip="The image to add to the video."), - IO.String.Input("prompt_text", multiline=True), - IO.String.Input("negative_prompt", multiline=True), - IO.Int.Input( - "seed", - min=0, - max=0xFFFFFFFF, - control_after_generate=True, - ), - ], - outputs=[IO.Video.Output()], - hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, - IO.Hidden.unique_id, - ], - is_api_node=True, - is_deprecated=True, - ) - - @classmethod - async def execute( - cls, - video: VideoInput, - image: torch.Tensor, - prompt_text: str, - negative_prompt: str, - seed: int, - ) -> IO.NodeOutput: - video_bytes_io = BytesIO() - video.save_to(video_bytes_io, format=VideoContainer.MP4, codec=VideoCodec.H264) - video_bytes_io.seek(0) - - image_bytes_io = tensor_to_bytesio(image) - pika_files = { - "video": ("video.mp4", video_bytes_io, "video/mp4"), - "image": ("image.png", image_bytes_io, "image/png"), - } - pika_request_data = pika_defs.PikaBodyGeneratePikadditionsGeneratePikadditionsPost( - promptText=prompt_text, - negativePrompt=negative_prompt, - seed=seed, - ) - initial_operation = await sync_op( - cls, - ApiEndpoint(path=PATH_PIKADDITIONS, method="POST"), - response_model=pika_defs.PikaGenerateResponse, - data=pika_request_data, - files=pika_files, - content_type="multipart/form-data", - ) - - return await execute_task(initial_operation.video_id, cls) - - -class PikaSwapsNode(IO.ComfyNode): - """Pika Pikaswaps Node.""" - - @classmethod - def define_schema(cls) -> IO.Schema: - return IO.Schema( - node_id="Pikaswaps", - display_name="Pika Swaps (Video Object Replacement)", - description="Swap out any object or region of your video with a new image or object. 
Define areas to replace either with a mask or coordinates.", - category="api node/video/Pika", - inputs=[ - IO.Video.Input("video", tooltip="The video to swap an object in."), - IO.Image.Input( - "image", - tooltip="The image used to replace the masked object in the video.", - optional=True, - ), - IO.Mask.Input( - "mask", - tooltip="Use the mask to define areas in the video to replace.", - optional=True, - ), - IO.String.Input("prompt_text", multiline=True, optional=True), - IO.String.Input("negative_prompt", multiline=True, optional=True), - IO.Int.Input("seed", min=0, max=0xFFFFFFFF, control_after_generate=True, optional=True), - IO.String.Input( - "region_to_modify", - multiline=True, - optional=True, - tooltip="Plaintext description of the object / region to modify.", - ), - ], - outputs=[IO.Video.Output()], - hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, - IO.Hidden.unique_id, - ], - is_api_node=True, - is_deprecated=True, - ) - - @classmethod - async def execute( - cls, - video: VideoInput, - image: Optional[torch.Tensor] = None, - mask: Optional[torch.Tensor] = None, - prompt_text: str = "", - negative_prompt: str = "", - seed: int = 0, - region_to_modify: str = "", - ) -> IO.NodeOutput: - video_bytes_io = BytesIO() - video.save_to(video_bytes_io, format=VideoContainer.MP4, codec=VideoCodec.H264) - video_bytes_io.seek(0) - pika_files = { - "video": ("video.mp4", video_bytes_io, "video/mp4"), - } - if mask is not None: - pika_files["modifyRegionMask"] = ("mask.png", tensor_to_bytesio(mask), "image/png") - if image is not None: - pika_files["image"] = ("image.png", tensor_to_bytesio(image), "image/png") - - pika_request_data = pika_defs.PikaBodyGeneratePikaswapsGeneratePikaswapsPost( - promptText=prompt_text, - negativePrompt=negative_prompt, - seed=seed, - modifyRegionRoi=region_to_modify if region_to_modify else None, - ) - initial_operation = await sync_op( - cls, - ApiEndpoint(path=PATH_PIKASWAPS, method="POST"), - response_model=pika_defs.PikaGenerateResponse, - data=pika_request_data, - files=pika_files, - content_type="multipart/form-data", - ) - return await execute_task(initial_operation.video_id, cls) - - -class PikaffectsNode(IO.ComfyNode): - """Pika Pikaffects Node.""" - - @classmethod - def define_schema(cls) -> IO.Schema: - return IO.Schema( - node_id="Pikaffects", - display_name="Pikaffects (Video Effects)", - description="Generate a video with a specific Pikaffect. 
Supported Pikaffects: Cake-ify, Crumble, Crush, Decapitate, Deflate, Dissolve, Explode, Eye-pop, Inflate, Levitate, Melt, Peel, Poke, Squish, Ta-da, Tear", - category="api node/video/Pika", - inputs=[ - IO.Image.Input("image", tooltip="The reference image to apply the Pikaffect to."), - IO.Combo.Input( - "pikaffect", options=pika_defs.Pikaffect, default="Cake-ify" - ), - IO.String.Input("prompt_text", multiline=True), - IO.String.Input("negative_prompt", multiline=True), - IO.Int.Input("seed", min=0, max=0xFFFFFFFF, control_after_generate=True), - ], - outputs=[IO.Video.Output()], - hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, - IO.Hidden.unique_id, - ], - is_api_node=True, - is_deprecated=True, - ) - - @classmethod - async def execute( - cls, - image: torch.Tensor, - pikaffect: str, - prompt_text: str, - negative_prompt: str, - seed: int, - ) -> IO.NodeOutput: - initial_operation = await sync_op( - cls, - ApiEndpoint(path=PATH_PIKAFFECTS, method="POST"), - response_model=pika_defs.PikaGenerateResponse, - data=pika_defs.PikaBodyGeneratePikaffectsGeneratePikaffectsPost( - pikaffect=pikaffect, - promptText=prompt_text, - negativePrompt=negative_prompt, - seed=seed, - ), - files={"image": ("image.png", tensor_to_bytesio(image), "image/png")}, - content_type="multipart/form-data", - ) - return await execute_task(initial_operation.video_id, cls) - - -class PikaStartEndFrameNode(IO.ComfyNode): - """PikaFrames v2.2 Node.""" - - @classmethod - def define_schema(cls) -> IO.Schema: - return IO.Schema( - node_id="PikaStartEndFrameNode2_2", - display_name="Pika Start and End Frame to Video", - description="Generate a video by combining your first and last frame. Upload two images to define the start and end points, and let the AI create a smooth transition between them.", - category="api node/video/Pika", - inputs=[ - IO.Image.Input("image_start", tooltip="The first image to combine."), - IO.Image.Input("image_end", tooltip="The last image to combine."), - *get_base_inputs_types(), - ], - outputs=[IO.Video.Output()], - hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, - IO.Hidden.unique_id, - ], - is_api_node=True, - is_deprecated=True, - ) - - @classmethod - async def execute( - cls, - image_start: torch.Tensor, - image_end: torch.Tensor, - prompt_text: str, - negative_prompt: str, - seed: int, - resolution: str, - duration: int, - ) -> IO.NodeOutput: - validate_string(prompt_text, field_name="prompt_text", min_length=1) - pika_files = [ - ("keyFrames", ("image_start.png", tensor_to_bytesio(image_start), "image/png")), - ("keyFrames", ("image_end.png", tensor_to_bytesio(image_end), "image/png")), - ] - initial_operation = await sync_op( - cls, - ApiEndpoint(path=PATH_PIKAFRAMES, method="POST"), - response_model=pika_defs.PikaGenerateResponse, - data=pika_defs.PikaBodyGenerate22KeyframeGenerate22PikaframesPost( - promptText=prompt_text, - negativePrompt=negative_prompt, - seed=seed, - resolution=resolution, - duration=duration, - ), - files=pika_files, - content_type="multipart/form-data", - ) - return await execute_task(initial_operation.video_id, cls) - - -class PikaApiNodesExtension(ComfyExtension): - @override - async def get_node_list(self) -> list[type[IO.ComfyNode]]: - return [ - PikaImageToVideo, - PikaTextToVideoNode, - PikaScenes, - PikAdditionsNode, - PikaSwapsNode, - PikaffectsNode, - PikaStartEndFrameNode, - ] - - -async def comfy_entrypoint() -> PikaApiNodesExtension: - return PikaApiNodesExtension() diff --git a/nodes.py b/nodes.py 
index 8678f510a..3fa543294 100644 --- a/nodes.py +++ b/nodes.py @@ -2384,7 +2384,6 @@ async def init_builtin_api_nodes(): "nodes_recraft.py", "nodes_pixverse.py", "nodes_stability.py", - "nodes_pika.py", "nodes_runway.py", "nodes_sora.py", "nodes_topaz.py", From dbd330454ada04609c69fda2ae7c002d7ea05f67 Mon Sep 17 00:00:00 2001 From: "Dr.Lt.Data" <128333288+ltdrdata@users.noreply.github.com> Date: Tue, 16 Dec 2025 08:57:39 +0900 Subject: [PATCH 043/148] feat(preview): add per-queue live preview method override (#11261) - Add set_preview_method() to override live preview method per queue item - Read extra_data.preview_method from /prompt request - Support values: taesd, latent2rgb, none, auto, default - "default" or unset uses server's CLI --preview-method setting - Add 44 tests (37 unit + 7 E2E) --- comfy/cli_args.py | 7 + execution.py | 3 + latent_preview.py | 10 + .../preview_method_override_test.py | 352 +++++++++++++++++ tests/execution/test_preview_method.py | 358 ++++++++++++++++++ 5 files changed, 730 insertions(+) create mode 100644 tests-unit/execution_test/preview_method_override_test.py create mode 100644 tests/execution/test_preview_method.py diff --git a/comfy/cli_args.py b/comfy/cli_args.py index 209fc185b..dae9a895d 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -97,6 +97,13 @@ class LatentPreviewMethod(enum.Enum): Latent2RGB = "latent2rgb" TAESD = "taesd" + @classmethod + def from_string(cls, value: str): + for member in cls: + if member.value == value: + return member + return None + parser.add_argument("--preview-method", type=LatentPreviewMethod, default=LatentPreviewMethod.NoPreviews, help="Default preview method for sampler nodes.", action=EnumAction) parser.add_argument("--preview-size", type=int, default=512, help="Sets the maximum preview size for sampler nodes.") diff --git a/execution.py b/execution.py index c2186ac98..0c239efd7 100644 --- a/execution.py +++ b/execution.py @@ -13,6 +13,7 @@ import asyncio import torch import comfy.model_management +from latent_preview import set_preview_method import nodes from comfy_execution.caching import ( BasicCache, @@ -669,6 +670,8 @@ class PromptExecutor: asyncio.run(self.execute_async(prompt, prompt_id, extra_data, execute_outputs)) async def execute_async(self, prompt, prompt_id, extra_data={}, execute_outputs=[]): + set_preview_method(extra_data.get("preview_method")) + nodes.interrupt_processing(False) if "client_id" in extra_data: diff --git a/latent_preview.py b/latent_preview.py index 66bded4b9..d52e3f7a1 100644 --- a/latent_preview.py +++ b/latent_preview.py @@ -8,6 +8,8 @@ import folder_paths import comfy.utils import logging +default_preview_method = args.preview_method + MAX_PREVIEW_RESOLUTION = args.preview_size VIDEO_TAES = ["taehv", "lighttaew2_2", "lighttaew2_1", "lighttaehy1_5"] @@ -125,3 +127,11 @@ def prepare_callback(model, steps, x0_output_dict=None): pbar.update_absolute(step + 1, total_steps, preview_bytes) return callback +def set_preview_method(override: str = None): + if override and override != "default": + method = LatentPreviewMethod.from_string(override) + if method is not None: + args.preview_method = method + return + args.preview_method = default_preview_method + diff --git a/tests-unit/execution_test/preview_method_override_test.py b/tests-unit/execution_test/preview_method_override_test.py new file mode 100644 index 000000000..79432d610 --- /dev/null +++ b/tests-unit/execution_test/preview_method_override_test.py @@ -0,0 +1,352 @@ +""" +Unit tests for Queue-specific Preview 
Method Override feature. + +Tests the preview method override functionality: +- LatentPreviewMethod.from_string() method +- set_preview_method() function in latent_preview.py +- default_preview_method variable +- Integration with args.preview_method +""" +import pytest +from comfy.cli_args import args, LatentPreviewMethod +from latent_preview import set_preview_method, default_preview_method + + +class TestLatentPreviewMethodFromString: + """Test LatentPreviewMethod.from_string() classmethod.""" + + @pytest.mark.parametrize("value,expected", [ + ("auto", LatentPreviewMethod.Auto), + ("latent2rgb", LatentPreviewMethod.Latent2RGB), + ("taesd", LatentPreviewMethod.TAESD), + ("none", LatentPreviewMethod.NoPreviews), + ]) + def test_valid_values_return_enum(self, value, expected): + """Valid string values should return corresponding enum.""" + assert LatentPreviewMethod.from_string(value) == expected + + @pytest.mark.parametrize("invalid", [ + "invalid", + "TAESD", # Case sensitive + "AUTO", # Case sensitive + "Latent2RGB", # Case sensitive + "latent", + "", + "default", # default is special, not a method + ]) + def test_invalid_values_return_none(self, invalid): + """Invalid string values should return None.""" + assert LatentPreviewMethod.from_string(invalid) is None + + +class TestLatentPreviewMethodEnumValues: + """Test LatentPreviewMethod enum has expected values.""" + + def test_enum_values(self): + """Verify enum values match expected strings.""" + assert LatentPreviewMethod.NoPreviews.value == "none" + assert LatentPreviewMethod.Auto.value == "auto" + assert LatentPreviewMethod.Latent2RGB.value == "latent2rgb" + assert LatentPreviewMethod.TAESD.value == "taesd" + + def test_enum_count(self): + """Verify exactly 4 preview methods exist.""" + assert len(LatentPreviewMethod) == 4 + + +class TestSetPreviewMethod: + """Test set_preview_method() function from latent_preview.py.""" + + def setup_method(self): + """Store original value before each test.""" + self.original = args.preview_method + + def teardown_method(self): + """Restore original value after each test.""" + args.preview_method = self.original + + def test_override_with_taesd(self): + """'taesd' should set args.preview_method to TAESD.""" + set_preview_method("taesd") + assert args.preview_method == LatentPreviewMethod.TAESD + + def test_override_with_latent2rgb(self): + """'latent2rgb' should set args.preview_method to Latent2RGB.""" + set_preview_method("latent2rgb") + assert args.preview_method == LatentPreviewMethod.Latent2RGB + + def test_override_with_auto(self): + """'auto' should set args.preview_method to Auto.""" + set_preview_method("auto") + assert args.preview_method == LatentPreviewMethod.Auto + + def test_override_with_none_value(self): + """'none' should set args.preview_method to NoPreviews.""" + set_preview_method("none") + assert args.preview_method == LatentPreviewMethod.NoPreviews + + def test_default_restores_original(self): + """'default' should restore to default_preview_method.""" + # First override to something else + set_preview_method("taesd") + assert args.preview_method == LatentPreviewMethod.TAESD + + # Then use 'default' to restore + set_preview_method("default") + assert args.preview_method == default_preview_method + + def test_none_param_restores_original(self): + """None parameter should restore to default_preview_method.""" + # First override to something else + set_preview_method("taesd") + assert args.preview_method == LatentPreviewMethod.TAESD + + # Then use None to restore + 
set_preview_method(None) + assert args.preview_method == default_preview_method + + def test_empty_string_restores_original(self): + """Empty string should restore to default_preview_method.""" + set_preview_method("taesd") + set_preview_method("") + assert args.preview_method == default_preview_method + + def test_invalid_value_restores_original(self): + """Invalid value should restore to default_preview_method.""" + set_preview_method("taesd") + set_preview_method("invalid_method") + assert args.preview_method == default_preview_method + + def test_case_sensitive_invalid_restores(self): + """Case-mismatched values should restore to default.""" + set_preview_method("taesd") + set_preview_method("TAESD") # Wrong case + assert args.preview_method == default_preview_method + + +class TestDefaultPreviewMethod: + """Test default_preview_method module variable.""" + + def test_default_is_not_none(self): + """default_preview_method should not be None.""" + assert default_preview_method is not None + + def test_default_is_enum_member(self): + """default_preview_method should be a LatentPreviewMethod enum.""" + assert isinstance(default_preview_method, LatentPreviewMethod) + + def test_default_matches_args_initial(self): + """default_preview_method should match CLI default or user setting.""" + # This tests that default_preview_method was captured at module load + # After set_preview_method(None), args should equal default + original = args.preview_method + set_preview_method("taesd") + set_preview_method(None) + assert args.preview_method == default_preview_method + args.preview_method = original + + +class TestArgsPreviewMethodModification: + """Test args.preview_method can be modified correctly.""" + + def setup_method(self): + """Store original value before each test.""" + self.original = args.preview_method + + def teardown_method(self): + """Restore original value after each test.""" + args.preview_method = self.original + + def test_args_accepts_all_enum_values(self): + """args.preview_method should accept all LatentPreviewMethod values.""" + for method in LatentPreviewMethod: + args.preview_method = method + assert args.preview_method == method + + def test_args_modification_and_restoration(self): + """args.preview_method should be modifiable and restorable.""" + original = args.preview_method + + args.preview_method = LatentPreviewMethod.TAESD + assert args.preview_method == LatentPreviewMethod.TAESD + + args.preview_method = original + assert args.preview_method == original + + +class TestExecutionFlow: + """Test the execution flow pattern used in execution.py.""" + + def setup_method(self): + """Store original value before each test.""" + self.original = args.preview_method + + def teardown_method(self): + """Restore original value after each test.""" + args.preview_method = self.original + + def test_sequential_executions_with_different_methods(self): + """Simulate multiple queue executions with different preview methods.""" + # Execution 1: taesd + set_preview_method("taesd") + assert args.preview_method == LatentPreviewMethod.TAESD + + # Execution 2: none + set_preview_method("none") + assert args.preview_method == LatentPreviewMethod.NoPreviews + + # Execution 3: default (restore) + set_preview_method("default") + assert args.preview_method == default_preview_method + + # Execution 4: auto + set_preview_method("auto") + assert args.preview_method == LatentPreviewMethod.Auto + + # Execution 5: no override (None) + set_preview_method(None) + assert args.preview_method == 
default_preview_method + + def test_override_then_default_pattern(self): + """Test the pattern: override -> execute -> next call restores.""" + # First execution with override + set_preview_method("latent2rgb") + assert args.preview_method == LatentPreviewMethod.Latent2RGB + + # Second execution without override restores default + set_preview_method(None) + assert args.preview_method == default_preview_method + + def test_extra_data_simulation(self): + """Simulate extra_data.get('preview_method') patterns.""" + # Simulate: extra_data = {"preview_method": "taesd"} + extra_data = {"preview_method": "taesd"} + set_preview_method(extra_data.get("preview_method")) + assert args.preview_method == LatentPreviewMethod.TAESD + + # Simulate: extra_data = {} + extra_data = {} + set_preview_method(extra_data.get("preview_method")) + assert args.preview_method == default_preview_method + + # Simulate: extra_data = {"preview_method": "default"} + extra_data = {"preview_method": "default"} + set_preview_method(extra_data.get("preview_method")) + assert args.preview_method == default_preview_method + + +class TestRealWorldScenarios: + """Tests using real-world prompt data patterns.""" + + def setup_method(self): + """Store original value before each test.""" + self.original = args.preview_method + + def teardown_method(self): + """Restore original value after each test.""" + args.preview_method = self.original + + def test_captured_prompt_without_preview_method(self): + """ + Test with captured prompt that has no preview_method. + Based on: tests-unit/execution_test/fixtures/default_prompt.json + """ + # Real captured extra_data structure (preview_method absent) + extra_data = { + "extra_pnginfo": {"workflow": {}}, + "client_id": "271314f0dabd48e5aaa488ed7a4ceb0d", + "create_time": 1765416558179 + } + + set_preview_method(extra_data.get("preview_method")) + assert args.preview_method == default_preview_method + + def test_captured_prompt_with_preview_method_taesd(self): + """Test captured prompt with preview_method: taesd.""" + extra_data = { + "extra_pnginfo": {"workflow": {}}, + "client_id": "271314f0dabd48e5aaa488ed7a4ceb0d", + "preview_method": "taesd" + } + + set_preview_method(extra_data.get("preview_method")) + assert args.preview_method == LatentPreviewMethod.TAESD + + def test_captured_prompt_with_preview_method_none(self): + """Test captured prompt with preview_method: none (disable preview).""" + extra_data = { + "extra_pnginfo": {"workflow": {}}, + "client_id": "test-client", + "preview_method": "none" + } + + set_preview_method(extra_data.get("preview_method")) + assert args.preview_method == LatentPreviewMethod.NoPreviews + + def test_captured_prompt_with_preview_method_latent2rgb(self): + """Test captured prompt with preview_method: latent2rgb.""" + extra_data = { + "extra_pnginfo": {"workflow": {}}, + "client_id": "test-client", + "preview_method": "latent2rgb" + } + + set_preview_method(extra_data.get("preview_method")) + assert args.preview_method == LatentPreviewMethod.Latent2RGB + + def test_captured_prompt_with_preview_method_auto(self): + """Test captured prompt with preview_method: auto.""" + extra_data = { + "extra_pnginfo": {"workflow": {}}, + "client_id": "test-client", + "preview_method": "auto" + } + + set_preview_method(extra_data.get("preview_method")) + assert args.preview_method == LatentPreviewMethod.Auto + + def test_captured_prompt_with_preview_method_default(self): + """Test captured prompt with preview_method: default (use CLI setting).""" + # First set to something 
else + set_preview_method("taesd") + assert args.preview_method == LatentPreviewMethod.TAESD + + # Then simulate a prompt with "default" + extra_data = { + "extra_pnginfo": {"workflow": {}}, + "client_id": "test-client", + "preview_method": "default" + } + + set_preview_method(extra_data.get("preview_method")) + assert args.preview_method == default_preview_method + + def test_sequential_queue_with_different_preview_methods(self): + """ + Simulate real queue scenario: multiple prompts with different settings. + This tests the actual usage pattern in ComfyUI. + """ + # Queue 1: User wants TAESD preview + extra_data_1 = {"client_id": "client-1", "preview_method": "taesd"} + set_preview_method(extra_data_1.get("preview_method")) + assert args.preview_method == LatentPreviewMethod.TAESD + + # Queue 2: User wants no preview (faster execution) + extra_data_2 = {"client_id": "client-2", "preview_method": "none"} + set_preview_method(extra_data_2.get("preview_method")) + assert args.preview_method == LatentPreviewMethod.NoPreviews + + # Queue 3: User doesn't specify (use server default) + extra_data_3 = {"client_id": "client-3"} + set_preview_method(extra_data_3.get("preview_method")) + assert args.preview_method == default_preview_method + + # Queue 4: User explicitly wants default + extra_data_4 = {"client_id": "client-4", "preview_method": "default"} + set_preview_method(extra_data_4.get("preview_method")) + assert args.preview_method == default_preview_method + + # Queue 5: User wants latent2rgb + extra_data_5 = {"client_id": "client-5", "preview_method": "latent2rgb"} + set_preview_method(extra_data_5.get("preview_method")) + assert args.preview_method == LatentPreviewMethod.Latent2RGB diff --git a/tests/execution/test_preview_method.py b/tests/execution/test_preview_method.py new file mode 100644 index 000000000..c3037553b --- /dev/null +++ b/tests/execution/test_preview_method.py @@ -0,0 +1,358 @@ +""" +E2E tests for Queue-specific Preview Method Override feature. + +Tests actual execution with different preview_method values. +Requires a running ComfyUI server with models. + +Usage: + COMFYUI_SERVER=http://localhost:8988 pytest test_preview_method_e2e.py -v -m preview_method + +Note: + These tests execute actual image generation and wait for completion. + Tests verify preview image transmission based on preview_method setting. 
+""" +import os +import json +import pytest +import uuid +import time +import random +import websocket +import urllib.request +from pathlib import Path + + +# Server configuration +SERVER_URL = os.environ.get("COMFYUI_SERVER", "http://localhost:8988") +SERVER_HOST = SERVER_URL.replace("http://", "").replace("https://", "") + +# Use existing inference graph fixture +GRAPH_FILE = Path(__file__).parent.parent / "inference" / "graphs" / "default_graph_sdxl1_0.json" + + +def is_server_running() -> bool: + """Check if ComfyUI server is running.""" + try: + request = urllib.request.Request(f"{SERVER_URL}/system_stats") + with urllib.request.urlopen(request, timeout=2.0): + return True + except Exception: + return False + + +def prepare_graph_for_test(graph: dict, steps: int = 5) -> dict: + """Prepare graph for testing: randomize seeds and reduce steps.""" + adapted = json.loads(json.dumps(graph)) # Deep copy + for node_id, node in adapted.items(): + inputs = node.get("inputs", {}) + # Handle both "seed" and "noise_seed" (used by KSamplerAdvanced) + if "seed" in inputs: + inputs["seed"] = random.randint(0, 2**32 - 1) + if "noise_seed" in inputs: + inputs["noise_seed"] = random.randint(0, 2**32 - 1) + # Reduce steps for faster testing (default 20 -> 5) + if "steps" in inputs: + inputs["steps"] = steps + return adapted + + +# Alias for backward compatibility +randomize_seed = prepare_graph_for_test + + +class PreviewMethodClient: + """Client for testing preview_method with WebSocket execution tracking.""" + + def __init__(self, server_address: str): + self.server_address = server_address + self.client_id = str(uuid.uuid4()) + self.ws = None + + def connect(self): + """Connect to WebSocket.""" + self.ws = websocket.WebSocket() + self.ws.settimeout(120) # 2 minute timeout for sampling + self.ws.connect(f"ws://{self.server_address}/ws?clientId={self.client_id}") + + def close(self): + """Close WebSocket connection.""" + if self.ws: + self.ws.close() + + def queue_prompt(self, prompt: dict, extra_data: dict = None) -> dict: + """Queue a prompt and return response with prompt_id.""" + data = { + "prompt": prompt, + "client_id": self.client_id, + "extra_data": extra_data or {} + } + req = urllib.request.Request( + f"http://{self.server_address}/prompt", + data=json.dumps(data).encode("utf-8"), + headers={"Content-Type": "application/json"} + ) + return json.loads(urllib.request.urlopen(req).read()) + + def wait_for_execution(self, prompt_id: str, timeout: float = 120.0) -> dict: + """ + Wait for execution to complete via WebSocket. 
+ + Returns: + dict with keys: completed, error, preview_count, execution_time + """ + result = { + "completed": False, + "error": None, + "preview_count": 0, + "execution_time": 0.0 + } + + start_time = time.time() + self.ws.settimeout(timeout) + + try: + while True: + out = self.ws.recv() + elapsed = time.time() - start_time + + if isinstance(out, str): + message = json.loads(out) + msg_type = message.get("type") + data = message.get("data", {}) + + if data.get("prompt_id") != prompt_id: + continue + + if msg_type == "executing": + if data.get("node") is None: + # Execution complete + result["completed"] = True + result["execution_time"] = elapsed + break + + elif msg_type == "execution_error": + result["error"] = data + result["execution_time"] = elapsed + break + + elif msg_type == "progress": + # Progress update during sampling + pass + + elif isinstance(out, bytes): + # Binary data = preview image + result["preview_count"] += 1 + + except websocket.WebSocketTimeoutException: + result["error"] = "Timeout waiting for execution" + result["execution_time"] = time.time() - start_time + + return result + + +def load_graph() -> dict: + """Load the SDXL graph fixture with randomized seed.""" + with open(GRAPH_FILE) as f: + graph = json.load(f) + return randomize_seed(graph) # Avoid caching + + +# Skip all tests if server is not running +pytestmark = [ + pytest.mark.skipif( + not is_server_running(), + reason=f"ComfyUI server not running at {SERVER_URL}" + ), + pytest.mark.preview_method, + pytest.mark.execution, +] + + +@pytest.fixture +def client(): + """Create and connect a test client.""" + c = PreviewMethodClient(SERVER_HOST) + c.connect() + yield c + c.close() + + +@pytest.fixture +def graph(): + """Load the test graph.""" + return load_graph() + + +class TestPreviewMethodExecution: + """Test actual execution with different preview methods.""" + + def test_execution_with_latent2rgb(self, client, graph): + """ + Execute with preview_method=latent2rgb. + Should complete and potentially receive preview images. + """ + extra_data = {"preview_method": "latent2rgb"} + + response = client.queue_prompt(graph, extra_data) + assert "prompt_id" in response + + result = client.wait_for_execution(response["prompt_id"]) + + # Should complete (may error if model missing, but that's separate) + assert result["completed"] or result["error"] is not None + # Execution should take some time (sampling) + if result["completed"]: + assert result["execution_time"] > 0.5, "Execution too fast - likely didn't run" + # latent2rgb should produce previews + print(f"latent2rgb: {result['preview_count']} previews in {result['execution_time']:.2f}s") # noqa: T201 + + def test_execution_with_taesd(self, client, graph): + """ + Execute with preview_method=taesd. + TAESD provides higher quality previews. + """ + extra_data = {"preview_method": "taesd"} + + response = client.queue_prompt(graph, extra_data) + assert "prompt_id" in response + + result = client.wait_for_execution(response["prompt_id"]) + + assert result["completed"] or result["error"] is not None + if result["completed"]: + assert result["execution_time"] > 0.5 + # taesd should also produce previews + print(f"taesd: {result['preview_count']} previews in {result['execution_time']:.2f}s") # noqa: T201 + + def test_execution_with_none_preview(self, client, graph): + """ + Execute with preview_method=none. + No preview images should be generated. 
+ """ + extra_data = {"preview_method": "none"} + + response = client.queue_prompt(graph, extra_data) + assert "prompt_id" in response + + result = client.wait_for_execution(response["prompt_id"]) + + assert result["completed"] or result["error"] is not None + if result["completed"]: + # With "none", should receive no preview images + assert result["preview_count"] == 0, \ + f"Expected no previews with 'none', got {result['preview_count']}" + print(f"none: {result['preview_count']} previews in {result['execution_time']:.2f}s") # noqa: T201 + + def test_execution_with_default(self, client, graph): + """ + Execute with preview_method=default. + Should use server's CLI default setting. + """ + extra_data = {"preview_method": "default"} + + response = client.queue_prompt(graph, extra_data) + assert "prompt_id" in response + + result = client.wait_for_execution(response["prompt_id"]) + + assert result["completed"] or result["error"] is not None + if result["completed"]: + print(f"default: {result['preview_count']} previews in {result['execution_time']:.2f}s") # noqa: T201 + + def test_execution_without_preview_method(self, client, graph): + """ + Execute without preview_method in extra_data. + Should use server's default preview method. + """ + extra_data = {} # No preview_method + + response = client.queue_prompt(graph, extra_data) + assert "prompt_id" in response + + result = client.wait_for_execution(response["prompt_id"]) + + assert result["completed"] or result["error"] is not None + if result["completed"]: + print(f"(no override): {result['preview_count']} previews in {result['execution_time']:.2f}s") # noqa: T201 + + +class TestPreviewMethodComparison: + """Compare preview behavior between different methods.""" + + def test_none_vs_latent2rgb_preview_count(self, client, graph): + """ + Compare preview counts: 'none' should have 0, others should have >0. + This is the key verification that preview_method actually works. 
+ """ + results = {} + + # Run with none (randomize seed to avoid caching) + graph_none = randomize_seed(graph) + extra_data_none = {"preview_method": "none"} + response = client.queue_prompt(graph_none, extra_data_none) + results["none"] = client.wait_for_execution(response["prompt_id"]) + + # Run with latent2rgb (randomize seed again) + graph_rgb = randomize_seed(graph) + extra_data_rgb = {"preview_method": "latent2rgb"} + response = client.queue_prompt(graph_rgb, extra_data_rgb) + results["latent2rgb"] = client.wait_for_execution(response["prompt_id"]) + + # Verify both completed + assert results["none"]["completed"], f"'none' execution failed: {results['none']['error']}" + assert results["latent2rgb"]["completed"], f"'latent2rgb' execution failed: {results['latent2rgb']['error']}" + + # Key assertion: 'none' should have 0 previews + assert results["none"]["preview_count"] == 0, \ + f"'none' should have 0 previews, got {results['none']['preview_count']}" + + # 'latent2rgb' should have at least 1 preview (depends on steps) + assert results["latent2rgb"]["preview_count"] > 0, \ + f"'latent2rgb' should have >0 previews, got {results['latent2rgb']['preview_count']}" + + print("\nPreview count comparison:") # noqa: T201 + print(f" none: {results['none']['preview_count']} previews") # noqa: T201 + print(f" latent2rgb: {results['latent2rgb']['preview_count']} previews") # noqa: T201 + + +class TestPreviewMethodSequential: + """Test sequential execution with different preview methods.""" + + def test_sequential_different_methods(self, client, graph): + """ + Execute multiple prompts sequentially with different preview methods. + Each should complete independently with correct preview behavior. + """ + methods = ["latent2rgb", "none", "default"] + results = [] + + for method in methods: + # Randomize seed for each execution to avoid caching + graph_run = randomize_seed(graph) + extra_data = {"preview_method": method} + response = client.queue_prompt(graph_run, extra_data) + + result = client.wait_for_execution(response["prompt_id"]) + results.append({ + "method": method, + "completed": result["completed"], + "preview_count": result["preview_count"], + "execution_time": result["execution_time"], + "error": result["error"] + }) + + # All should complete or have clear errors + for r in results: + assert r["completed"] or r["error"] is not None, \ + f"Method {r['method']} neither completed nor errored" + + # "none" should have zero previews if completed + none_result = next(r for r in results if r["method"] == "none") + if none_result["completed"]: + assert none_result["preview_count"] == 0, \ + f"'none' should have 0 previews, got {none_result['preview_count']}" + + print("\nSequential execution results:") # noqa: T201 + for r in results: + status = "✓" if r["completed"] else f"✗ ({r['error']})" + print(f" {r['method']}: {status}, {r['preview_count']} previews, {r['execution_time']:.2f}s") # noqa: T201 From 43e0d4e3ccfe8b5eac81bcee6f912f661849aafb Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Tue, 16 Dec 2025 02:01:10 +0200 Subject: [PATCH 044/148] comfy_api: remove usage of "Type","List" and "Dict" types (#11238) --- comfy_api/feature_flags.py | 10 +++++----- comfy_api/internal/api_registry.py | 10 +++++----- comfy_api/internal/async_to_sync.py | 14 ++++++------- comfy_api/internal/singleton.py | 6 +++--- comfy_api/latest/__init__.py | 4 ++-- comfy_api/latest/_input/basic_types.py | 4 ++-- comfy_api/latest/_ui.py | 27 +++++++++++++------------- 
comfy_api/version_list.py | 3 +-- 8 files changed, 38 insertions(+), 40 deletions(-) diff --git a/comfy_api/feature_flags.py b/comfy_api/feature_flags.py index bfb77eb5f..de167f037 100644 --- a/comfy_api/feature_flags.py +++ b/comfy_api/feature_flags.py @@ -5,12 +5,12 @@ This module handles capability negotiation between frontend and backend, allowing graceful protocol evolution while maintaining backward compatibility. """ -from typing import Any, Dict +from typing import Any from comfy.cli_args import args # Default server capabilities -SERVER_FEATURE_FLAGS: Dict[str, Any] = { +SERVER_FEATURE_FLAGS: dict[str, Any] = { "supports_preview_metadata": True, "max_upload_size": args.max_upload_size * 1024 * 1024, # Convert MB to bytes "extension": {"manager": {"supports_v4": True}}, @@ -18,7 +18,7 @@ SERVER_FEATURE_FLAGS: Dict[str, Any] = { def get_connection_feature( - sockets_metadata: Dict[str, Dict[str, Any]], + sockets_metadata: dict[str, dict[str, Any]], sid: str, feature_name: str, default: Any = False @@ -42,7 +42,7 @@ def get_connection_feature( def supports_feature( - sockets_metadata: Dict[str, Dict[str, Any]], + sockets_metadata: dict[str, dict[str, Any]], sid: str, feature_name: str ) -> bool: @@ -60,7 +60,7 @@ def supports_feature( return get_connection_feature(sockets_metadata, sid, feature_name, False) is True -def get_server_features() -> Dict[str, Any]: +def get_server_features() -> dict[str, Any]: """ Get the server's feature flags. diff --git a/comfy_api/internal/api_registry.py b/comfy_api/internal/api_registry.py index 7e3375cf6..2b1cb016a 100644 --- a/comfy_api/internal/api_registry.py +++ b/comfy_api/internal/api_registry.py @@ -1,4 +1,4 @@ -from typing import Type, List, NamedTuple +from typing import NamedTuple from comfy_api.internal.singleton import ProxiedSingleton from packaging import version as packaging_version @@ -10,7 +10,7 @@ class ComfyAPIBase(ProxiedSingleton): class ComfyAPIWithVersion(NamedTuple): version: str - api_class: Type[ComfyAPIBase] + api_class: type[ComfyAPIBase] def parse_version(version_str: str) -> packaging_version.Version: @@ -23,16 +23,16 @@ def parse_version(version_str: str) -> packaging_version.Version: return packaging_version.parse(version_str) -registered_versions: List[ComfyAPIWithVersion] = [] +registered_versions: list[ComfyAPIWithVersion] = [] -def register_versions(versions: List[ComfyAPIWithVersion]): +def register_versions(versions: list[ComfyAPIWithVersion]): versions.sort(key=lambda x: parse_version(x.version)) global registered_versions registered_versions = versions -def get_all_versions() -> List[ComfyAPIWithVersion]: +def get_all_versions() -> list[ComfyAPIWithVersion]: """ Returns a list of all registered ComfyAPI versions. """ diff --git a/comfy_api/internal/async_to_sync.py b/comfy_api/internal/async_to_sync.py index 257ade82e..c9b0576e1 100644 --- a/comfy_api/internal/async_to_sync.py +++ b/comfy_api/internal/async_to_sync.py @@ -8,7 +8,7 @@ import os import textwrap import threading from enum import Enum -from typing import Optional, Type, get_origin, get_args, get_type_hints +from typing import Optional, get_origin, get_args, get_type_hints class TypeTracker: @@ -193,7 +193,7 @@ class AsyncToSyncConverter: return result_container["result"] @classmethod - def create_sync_class(cls, async_class: Type, thread_pool_size=10) -> Type: + def create_sync_class(cls, async_class: type, thread_pool_size=10) -> type: """ Creates a new class with synchronous versions of all async methods. 
@@ -563,7 +563,7 @@ class AsyncToSyncConverter: @classmethod def _generate_imports( - cls, async_class: Type, type_tracker: TypeTracker + cls, async_class: type, type_tracker: TypeTracker ) -> list[str]: """Generate import statements for the stub file.""" imports = [] @@ -628,7 +628,7 @@ class AsyncToSyncConverter: return imports @classmethod - def _get_class_attributes(cls, async_class: Type) -> list[tuple[str, Type]]: + def _get_class_attributes(cls, async_class: type) -> list[tuple[str, type]]: """Extract class attributes that are classes themselves.""" class_attributes = [] @@ -654,7 +654,7 @@ class AsyncToSyncConverter: def _generate_inner_class_stub( cls, name: str, - attr: Type, + attr: type, indent: str = " ", type_tracker: Optional[TypeTracker] = None, ) -> list[str]: @@ -782,7 +782,7 @@ class AsyncToSyncConverter: return processed @classmethod - def generate_stub_file(cls, async_class: Type, sync_class: Type) -> None: + def generate_stub_file(cls, async_class: type, sync_class: type) -> None: """ Generate a .pyi stub file for the sync class to help IDEs with type checking. """ @@ -988,7 +988,7 @@ class AsyncToSyncConverter: logging.error(traceback.format_exc()) -def create_sync_class(async_class: Type, thread_pool_size=10) -> Type: +def create_sync_class(async_class: type, thread_pool_size=10) -> type: """ Creates a sync version of an async class diff --git a/comfy_api/internal/singleton.py b/comfy_api/internal/singleton.py index 75f16f98e..d89380262 100644 --- a/comfy_api/internal/singleton.py +++ b/comfy_api/internal/singleton.py @@ -1,4 +1,4 @@ -from typing import Type, TypeVar +from typing import TypeVar class SingletonMetaclass(type): T = TypeVar("T", bound="SingletonMetaclass") @@ -11,13 +11,13 @@ class SingletonMetaclass(type): ) return cls._instances[cls] - def inject_instance(cls: Type[T], instance: T) -> None: + def inject_instance(cls: type[T], instance: T) -> None: assert cls not in SingletonMetaclass._instances, ( "Cannot inject instance after first instantiation" ) SingletonMetaclass._instances[cls] = instance - def get_instance(cls: Type[T], *args, **kwargs) -> T: + def get_instance(cls: type[T], *args, **kwargs) -> T: """ Gets the singleton instance of the class, creating it if it doesn't exist. 
""" diff --git a/comfy_api/latest/__init__.py b/comfy_api/latest/__init__.py index 35e1ac853..fab63c7df 100644 --- a/comfy_api/latest/__init__.py +++ b/comfy_api/latest/__init__.py @@ -1,7 +1,7 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import Type, TYPE_CHECKING +from typing import TYPE_CHECKING from comfy_api.internal import ComfyAPIBase from comfy_api.internal.singleton import ProxiedSingleton from comfy_api.internal.async_to_sync import create_sync_class @@ -113,7 +113,7 @@ ComfyAPI = ComfyAPI_latest if TYPE_CHECKING: import comfy_api.latest.generated.ComfyAPISyncStub # type: ignore - ComfyAPISync: Type[comfy_api.latest.generated.ComfyAPISyncStub.ComfyAPISyncStub] + ComfyAPISync: type[comfy_api.latest.generated.ComfyAPISyncStub.ComfyAPISyncStub] ComfyAPISync = create_sync_class(ComfyAPI_latest) # create new aliases for io and ui diff --git a/comfy_api/latest/_input/basic_types.py b/comfy_api/latest/_input/basic_types.py index 245c6cbb1..d73deabd2 100644 --- a/comfy_api/latest/_input/basic_types.py +++ b/comfy_api/latest/_input/basic_types.py @@ -1,5 +1,5 @@ import torch -from typing import TypedDict, List, Optional +from typing import TypedDict, Optional ImageInput = torch.Tensor """ @@ -39,4 +39,4 @@ class LatentInput(TypedDict): Optional noise mask tensor in the same format as samples. """ - batch_index: Optional[List[int]] + batch_index: Optional[list[int]] diff --git a/comfy_api/latest/_ui.py b/comfy_api/latest/_ui.py index 2babe209a..e238cdf3c 100644 --- a/comfy_api/latest/_ui.py +++ b/comfy_api/latest/_ui.py @@ -5,7 +5,6 @@ import os import random import uuid from io import BytesIO -from typing import Type import av import numpy as np @@ -83,7 +82,7 @@ class ImageSaveHelper: return PILImage.fromarray(np.clip(255.0 * image_tensor.cpu().numpy(), 0, 255).astype(np.uint8)) @staticmethod - def _create_png_metadata(cls: Type[ComfyNode] | None) -> PngInfo | None: + def _create_png_metadata(cls: type[ComfyNode] | None) -> PngInfo | None: """Creates a PngInfo object with prompt and extra_pnginfo.""" if args.disable_metadata or cls is None or not cls.hidden: return None @@ -96,7 +95,7 @@ class ImageSaveHelper: return metadata @staticmethod - def _create_animated_png_metadata(cls: Type[ComfyNode] | None) -> PngInfo | None: + def _create_animated_png_metadata(cls: type[ComfyNode] | None) -> PngInfo | None: """Creates a PngInfo object with prompt and extra_pnginfo for animated PNGs (APNG).""" if args.disable_metadata or cls is None or not cls.hidden: return None @@ -121,7 +120,7 @@ class ImageSaveHelper: return metadata @staticmethod - def _create_webp_metadata(pil_image: PILImage.Image, cls: Type[ComfyNode] | None) -> PILImage.Exif: + def _create_webp_metadata(pil_image: PILImage.Image, cls: type[ComfyNode] | None) -> PILImage.Exif: """Creates EXIF metadata bytes for WebP images.""" exif_data = pil_image.getexif() if args.disable_metadata or cls is None or cls.hidden is None: @@ -137,7 +136,7 @@ class ImageSaveHelper: @staticmethod def save_images( - images, filename_prefix: str, folder_type: FolderType, cls: Type[ComfyNode] | None, compress_level = 4, + images, filename_prefix: str, folder_type: FolderType, cls: type[ComfyNode] | None, compress_level = 4, ) -> list[SavedResult]: """Saves a batch of images as individual PNG files.""" full_output_folder, filename, counter, subfolder, _ = folder_paths.get_save_image_path( @@ -155,7 +154,7 @@ class ImageSaveHelper: return results @staticmethod - def get_save_images_ui(images, filename_prefix: str, 
cls: Type[ComfyNode] | None, compress_level=4) -> SavedImages: + def get_save_images_ui(images, filename_prefix: str, cls: type[ComfyNode] | None, compress_level=4) -> SavedImages: """Saves a batch of images and returns a UI object for the node output.""" return SavedImages( ImageSaveHelper.save_images( @@ -169,7 +168,7 @@ class ImageSaveHelper: @staticmethod def save_animated_png( - images, filename_prefix: str, folder_type: FolderType, cls: Type[ComfyNode] | None, fps: float, compress_level: int + images, filename_prefix: str, folder_type: FolderType, cls: type[ComfyNode] | None, fps: float, compress_level: int ) -> SavedResult: """Saves a batch of images as a single animated PNG.""" full_output_folder, filename, counter, subfolder, _ = folder_paths.get_save_image_path( @@ -191,7 +190,7 @@ class ImageSaveHelper: @staticmethod def get_save_animated_png_ui( - images, filename_prefix: str, cls: Type[ComfyNode] | None, fps: float, compress_level: int + images, filename_prefix: str, cls: type[ComfyNode] | None, fps: float, compress_level: int ) -> SavedImages: """Saves an animated PNG and returns a UI object for the node output.""" result = ImageSaveHelper.save_animated_png( @@ -209,7 +208,7 @@ class ImageSaveHelper: images, filename_prefix: str, folder_type: FolderType, - cls: Type[ComfyNode] | None, + cls: type[ComfyNode] | None, fps: float, lossless: bool, quality: int, @@ -238,7 +237,7 @@ class ImageSaveHelper: def get_save_animated_webp_ui( images, filename_prefix: str, - cls: Type[ComfyNode] | None, + cls: type[ComfyNode] | None, fps: float, lossless: bool, quality: int, @@ -267,7 +266,7 @@ class AudioSaveHelper: audio: dict, filename_prefix: str, folder_type: FolderType, - cls: Type[ComfyNode] | None, + cls: type[ComfyNode] | None, format: str = "flac", quality: str = "128k", ) -> list[SavedResult]: @@ -372,7 +371,7 @@ class AudioSaveHelper: @staticmethod def get_save_audio_ui( - audio, filename_prefix: str, cls: Type[ComfyNode] | None, format: str = "flac", quality: str = "128k", + audio, filename_prefix: str, cls: type[ComfyNode] | None, format: str = "flac", quality: str = "128k", ) -> SavedAudios: """Save and instantly wrap for UI.""" return SavedAudios( @@ -388,7 +387,7 @@ class AudioSaveHelper: class PreviewImage(_UIOutput): - def __init__(self, image: Image.Type, animated: bool = False, cls: Type[ComfyNode] = None, **kwargs): + def __init__(self, image: Image.Type, animated: bool = False, cls: type[ComfyNode] = None, **kwargs): self.values = ImageSaveHelper.save_images( image, filename_prefix="ComfyUI_temp_" + ''.join(random.choice("abcdefghijklmnopqrstupvxyz") for _ in range(5)), @@ -412,7 +411,7 @@ class PreviewMask(PreviewImage): class PreviewAudio(_UIOutput): - def __init__(self, audio: dict, cls: Type[ComfyNode] = None, **kwargs): + def __init__(self, audio: dict, cls: type[ComfyNode] = None, **kwargs): self.values = AudioSaveHelper.save_audio( audio, filename_prefix="ComfyUI_temp_" + "".join(random.choice("abcdefghijklmnopqrstuvwxyz") for _ in range(5)), diff --git a/comfy_api/version_list.py b/comfy_api/version_list.py index 7cb1871d5..be6e1db66 100644 --- a/comfy_api/version_list.py +++ b/comfy_api/version_list.py @@ -2,9 +2,8 @@ from comfy_api.latest import ComfyAPI_latest from comfy_api.v0_0_2 import ComfyAPIAdapter_v0_0_2 from comfy_api.v0_0_1 import ComfyAPIAdapter_v0_0_1 from comfy_api.internal import ComfyAPIBase -from typing import List, Type -supported_versions: List[Type[ComfyAPIBase]] = [ +supported_versions: list[type[ComfyAPIBase]] = [ ComfyAPI_latest, 
ComfyAPIAdapter_v0_0_2, ComfyAPIAdapter_v0_0_1, From 77b2f7c228a0db6643bb7f29be4db0bff6799db2 Mon Sep 17 00:00:00 2001 From: drozbay <17261091+drozbay@users.noreply.github.com> Date: Mon, 15 Dec 2025 17:06:32 -0700 Subject: [PATCH 045/148] Add context windows callback for custom cond handling (#11208) Co-authored-by: ozbayb <17261091+ozbayb@users.noreply.github.com> --- comfy/context_windows.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/comfy/context_windows.py b/comfy/context_windows.py index 5c412d1c2..2979b3ca1 100644 --- a/comfy/context_windows.py +++ b/comfy/context_windows.py @@ -87,6 +87,7 @@ class IndexListCallbacks: COMBINE_CONTEXT_WINDOW_RESULTS = "combine_context_window_results" EXECUTE_START = "execute_start" EXECUTE_CLEANUP = "execute_cleanup" + RESIZE_COND_ITEM = "resize_cond_item" def init_callbacks(self): return {} @@ -166,6 +167,18 @@ class IndexListContextHandler(ContextHandlerABC): new_cond_item = cond_item.copy() # when in dictionary, look for tensors and CONDCrossAttn [comfy/conds.py] (has cond attr that is a tensor) for cond_key, cond_value in new_cond_item.items(): + # Allow callbacks to handle custom conditioning items + handled = False + for callback in comfy.patcher_extension.get_all_callbacks( + IndexListCallbacks.RESIZE_COND_ITEM, self.callbacks + ): + result = callback(cond_key, cond_value, window, x_in, device, new_cond_item) + if result is not None: + new_cond_item[cond_key] = result + handled = True + break + if handled: + continue if isinstance(cond_value, torch.Tensor): if (self.dim < cond_value.ndim and cond_value(self.dim) == x_in.size(self.dim)) or \ (cond_value.ndim < self.dim and cond_value.size(0) == x_in.size(self.dim)): From 70541d4e7769c6c40eae6594e677355eacd181fe Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 15 Dec 2025 16:20:34 -0800 Subject: [PATCH 046/148] Support the new qwen edit 2511 reference method. (#11340) index_timestep_zero can be selected in the FluxKontextMultiReferenceLatentMethod now with the display name set to the more generic "Edit Model Reference Method" node. 
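Rough sketch of what index_timestep_zero does, for anyone reading the diff below (an
illustration of the gating split only, not the exact model code): reference latents keep
their own index in the positional ids, as with the existing "index" method, but their
modulation is driven by a zeroed copy of the timestep, so the per-token gate is applied
in two halves, split at the first reference token.

    # minimal, self-contained sketch of the split-gate idea (assumed shapes, not model internals)
    import torch

    def apply_gate(x, y, gates, split):
        # tokens before `split` (the noised image) use the normal gate,
        # tokens from `split` onward (the reference latents) use the timestep-zero gate
        return y + torch.cat((x[:, :split] * gates[0], x[:, split:] * gates[1]), dim=1)

    x = torch.randn(1, 6, 8)                 # 4 image tokens followed by 2 reference tokens
    y = torch.zeros_like(x)
    gates = (torch.ones(1, 1, 8) * 0.5, torch.zeros(1, 1, 8))
    out = apply_gate(x, y, gates, split=4)   # reference tokens contribute nothing in this toy case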
--- comfy/ldm/qwen_image/model.py | 47 +++++++++++++++++++++++++++++------ comfy_extras/nodes_flux.py | 3 ++- 2 files changed, 41 insertions(+), 9 deletions(-) diff --git a/comfy/ldm/qwen_image/model.py b/comfy/ldm/qwen_image/model.py index 8c75670cd..96590088b 100644 --- a/comfy/ldm/qwen_image/model.py +++ b/comfy/ldm/qwen_image/model.py @@ -218,9 +218,24 @@ class QwenImageTransformerBlock(nn.Module): operations=operations, ) - def _modulate(self, x: torch.Tensor, mod_params: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + def _apply_gate(self, x, y, gate, timestep_zero_index=None): + if timestep_zero_index is not None: + return y + torch.cat((x[:, :timestep_zero_index] * gate[0], x[:, timestep_zero_index:] * gate[1]), dim=1) + else: + return torch.addcmul(y, gate, x) + + def _modulate(self, x: torch.Tensor, mod_params: torch.Tensor, timestep_zero_index=None) -> Tuple[torch.Tensor, torch.Tensor]: shift, scale, gate = torch.chunk(mod_params, 3, dim=-1) - return torch.addcmul(shift.unsqueeze(1), x, 1 + scale.unsqueeze(1)), gate.unsqueeze(1) + if timestep_zero_index is not None: + actual_batch = shift.size(0) // 2 + shift, shift_0 = shift[:actual_batch], shift[actual_batch:] + scale, scale_0 = scale[:actual_batch], scale[actual_batch:] + gate, gate_0 = gate[:actual_batch], gate[actual_batch:] + reg = torch.addcmul(shift.unsqueeze(1), x[:, :timestep_zero_index], 1 + scale.unsqueeze(1)) + zero = torch.addcmul(shift_0.unsqueeze(1), x[:, timestep_zero_index:], 1 + scale_0.unsqueeze(1)) + return torch.cat((reg, zero), dim=1), (gate.unsqueeze(1), gate_0.unsqueeze(1)) + else: + return torch.addcmul(shift.unsqueeze(1), x, 1 + scale.unsqueeze(1)), gate.unsqueeze(1) def forward( self, @@ -229,14 +244,19 @@ class QwenImageTransformerBlock(nn.Module): encoder_hidden_states_mask: torch.Tensor, temb: torch.Tensor, image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + timestep_zero_index=None, transformer_options={}, ) -> Tuple[torch.Tensor, torch.Tensor]: img_mod_params = self.img_mod(temb) + + if timestep_zero_index is not None: + temb = temb.chunk(2, dim=0)[0] + txt_mod_params = self.txt_mod(temb) img_mod1, img_mod2 = img_mod_params.chunk(2, dim=-1) txt_mod1, txt_mod2 = txt_mod_params.chunk(2, dim=-1) - img_modulated, img_gate1 = self._modulate(self.img_norm1(hidden_states), img_mod1) + img_modulated, img_gate1 = self._modulate(self.img_norm1(hidden_states), img_mod1, timestep_zero_index) del img_mod1 txt_modulated, txt_gate1 = self._modulate(self.txt_norm1(encoder_hidden_states), txt_mod1) del txt_mod1 @@ -251,15 +271,15 @@ class QwenImageTransformerBlock(nn.Module): del img_modulated del txt_modulated - hidden_states = hidden_states + img_gate1 * img_attn_output + hidden_states = self._apply_gate(img_attn_output, hidden_states, img_gate1, timestep_zero_index) encoder_hidden_states = encoder_hidden_states + txt_gate1 * txt_attn_output del img_attn_output del txt_attn_output del img_gate1 del txt_gate1 - img_modulated2, img_gate2 = self._modulate(self.img_norm2(hidden_states), img_mod2) - hidden_states = torch.addcmul(hidden_states, img_gate2, self.img_mlp(img_modulated2)) + img_modulated2, img_gate2 = self._modulate(self.img_norm2(hidden_states), img_mod2, timestep_zero_index) + hidden_states = self._apply_gate(self.img_mlp(img_modulated2), hidden_states, img_gate2, timestep_zero_index) txt_modulated2, txt_gate2 = self._modulate(self.txt_norm2(encoder_hidden_states), txt_mod2) encoder_hidden_states = torch.addcmul(encoder_hidden_states, txt_gate2, self.txt_mlp(txt_modulated2)) 
@@ -391,11 +411,14 @@ class QwenImageTransformer2DModel(nn.Module): hidden_states, img_ids, orig_shape = self.process_img(x) num_embeds = hidden_states.shape[1] + timestep_zero_index = None if ref_latents is not None: h = 0 w = 0 index = 0 - index_ref_method = kwargs.get("ref_latents_method", "index") == "index" + ref_method = kwargs.get("ref_latents_method", "index") + index_ref_method = (ref_method == "index") or (ref_method == "index_timestep_zero") + timestep_zero = ref_method == "index_timestep_zero" for ref in ref_latents: if index_ref_method: index += 1 @@ -415,6 +438,10 @@ class QwenImageTransformer2DModel(nn.Module): kontext, kontext_ids, _ = self.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset) hidden_states = torch.cat([hidden_states, kontext], dim=1) img_ids = torch.cat([img_ids, kontext_ids], dim=1) + if timestep_zero: + if index > 0: + timestep = torch.cat([timestep, timestep * 0], dim=0) + timestep_zero_index = num_embeds txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2, ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) // 2)) txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3) @@ -446,7 +473,7 @@ class QwenImageTransformer2DModel(nn.Module): if ("double_block", i) in blocks_replace: def block_wrap(args): out = {} - out["txt"], out["img"] = block(hidden_states=args["img"], encoder_hidden_states=args["txt"], encoder_hidden_states_mask=encoder_hidden_states_mask, temb=args["vec"], image_rotary_emb=args["pe"], transformer_options=args["transformer_options"]) + out["txt"], out["img"] = block(hidden_states=args["img"], encoder_hidden_states=args["txt"], encoder_hidden_states_mask=encoder_hidden_states_mask, temb=args["vec"], image_rotary_emb=args["pe"], timestep_zero_index=timestep_zero_index, transformer_options=args["transformer_options"]) return out out = blocks_replace[("double_block", i)]({"img": hidden_states, "txt": encoder_hidden_states, "vec": temb, "pe": image_rotary_emb, "transformer_options": transformer_options}, {"original_block": block_wrap}) hidden_states = out["img"] @@ -458,6 +485,7 @@ class QwenImageTransformer2DModel(nn.Module): encoder_hidden_states_mask=encoder_hidden_states_mask, temb=temb, image_rotary_emb=image_rotary_emb, + timestep_zero_index=timestep_zero_index, transformer_options=transformer_options, ) @@ -474,6 +502,9 @@ class QwenImageTransformer2DModel(nn.Module): if add is not None: hidden_states[:, :add.shape[1]] += add + if timestep_zero_index is not None: + temb = temb.chunk(2, dim=0)[0] + hidden_states = self.norm_out(hidden_states, temb) hidden_states = self.proj_out(hidden_states) diff --git a/comfy_extras/nodes_flux.py b/comfy_extras/nodes_flux.py index d9c4bba81..12c8ed3e6 100644 --- a/comfy_extras/nodes_flux.py +++ b/comfy_extras/nodes_flux.py @@ -154,12 +154,13 @@ class FluxKontextMultiReferenceLatentMethod(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="FluxKontextMultiReferenceLatentMethod", + display_name="Edit Model Reference Method", category="advanced/conditioning/flux", inputs=[ io.Conditioning.Input("conditioning"), io.Combo.Input( "reference_latents_method", - options=["offset", "index", "uxo/uno"], + options=["offset", "index", "uxo/uno", "index_timestep_zero"], ), ], outputs=[ From d02d0e5744f2e06fc40834d3c5bb387de4532007 Mon Sep 17 00:00:00 2001 From: seed93 Date: Tue, 16 Dec 2025 09:38:46 +0800 Subject: [PATCH 047/148] [add] tripo3.0 (#10663) * [add] tripo3.0 * [tripo] 
change paramter order * change order --------- Co-authored-by: liangd --- comfy_api_nodes/apis/tripo_api.py | 46 ++++++++++++++--- comfy_api_nodes/nodes_tripo.py | 86 ++++++++++++++++++++++++++++++- 2 files changed, 122 insertions(+), 10 deletions(-) diff --git a/comfy_api_nodes/apis/tripo_api.py b/comfy_api_nodes/apis/tripo_api.py index 713260e2a..ffaaa7dc1 100644 --- a/comfy_api_nodes/apis/tripo_api.py +++ b/comfy_api_nodes/apis/tripo_api.py @@ -5,11 +5,17 @@ from typing import Optional, List, Dict, Any, Union from pydantic import BaseModel, Field, RootModel class TripoModelVersion(str, Enum): + v3_0_20250812 = 'v3.0-20250812' v2_5_20250123 = 'v2.5-20250123' v2_0_20240919 = 'v2.0-20240919' v1_4_20240625 = 'v1.4-20240625' +class TripoGeometryQuality(str, Enum): + standard = 'standard' + detailed = 'detailed' + + class TripoTextureQuality(str, Enum): standard = 'standard' detailed = 'detailed' @@ -61,14 +67,20 @@ class TripoSpec(str, Enum): class TripoAnimation(str, Enum): IDLE = "preset:idle" WALK = "preset:walk" + RUN = "preset:run" + DIVE = "preset:dive" CLIMB = "preset:climb" JUMP = "preset:jump" - RUN = "preset:run" SLASH = "preset:slash" SHOOT = "preset:shoot" HURT = "preset:hurt" FALL = "preset:fall" TURN = "preset:turn" + QUADRUPED_WALK = "preset:quadruped:walk" + HEXAPOD_WALK = "preset:hexapod:walk" + OCTOPOD_WALK = "preset:octopod:walk" + SERPENTINE_MARCH = "preset:serpentine:march" + AQUATIC_MARCH = "preset:aquatic:march" class TripoStylizeStyle(str, Enum): LEGO = "lego" @@ -105,6 +117,11 @@ class TripoTaskStatus(str, Enum): BANNED = "banned" EXPIRED = "expired" +class TripoFbxPreset(str, Enum): + BLENDER = "blender" + MIXAMO = "mixamo" + _3DSMAX = "3dsmax" + class TripoFileTokenReference(BaseModel): type: Optional[str] = Field(None, description='The type of the reference') file_token: str @@ -142,6 +159,7 @@ class TripoTextToModelRequest(BaseModel): model_seed: Optional[int] = Field(None, description='The seed for the model') texture_seed: Optional[int] = Field(None, description='The seed for the texture') texture_quality: Optional[TripoTextureQuality] = TripoTextureQuality.standard + geometry_quality: Optional[TripoGeometryQuality] = TripoGeometryQuality.standard style: Optional[TripoStyle] = None auto_size: Optional[bool] = Field(False, description='Whether to auto-size the model') quad: Optional[bool] = Field(False, description='Whether to apply quad to the generated model') @@ -156,6 +174,7 @@ class TripoImageToModelRequest(BaseModel): model_seed: Optional[int] = Field(None, description='The seed for the model') texture_seed: Optional[int] = Field(None, description='The seed for the texture') texture_quality: Optional[TripoTextureQuality] = TripoTextureQuality.standard + geometry_quality: Optional[TripoGeometryQuality] = TripoGeometryQuality.standard texture_alignment: Optional[TripoTextureAlignment] = Field(TripoTextureAlignment.ORIGINAL_IMAGE, description='The texture alignment method') style: Optional[TripoStyle] = Field(None, description='The style to apply to the generated model') auto_size: Optional[bool] = Field(False, description='Whether to auto-size the model') @@ -173,6 +192,7 @@ class TripoMultiviewToModelRequest(BaseModel): model_seed: Optional[int] = Field(None, description='The seed for the model') texture_seed: Optional[int] = Field(None, description='The seed for the texture') texture_quality: Optional[TripoTextureQuality] = TripoTextureQuality.standard + geometry_quality: Optional[TripoGeometryQuality] = TripoGeometryQuality.standard texture_alignment: 
Optional[TripoTextureAlignment] = TripoTextureAlignment.ORIGINAL_IMAGE auto_size: Optional[bool] = Field(False, description='Whether to auto-size the model') orientation: Optional[TripoOrientation] = Field(TripoOrientation.DEFAULT, description='The orientation for the model') @@ -219,14 +239,24 @@ class TripoConvertModelRequest(BaseModel): type: TripoTaskType = Field(TripoTaskType.CONVERT_MODEL, description='Type of task') format: TripoConvertFormat = Field(..., description='The format to convert to') original_model_task_id: str = Field(..., description='The task ID of the original model') - quad: Optional[bool] = Field(False, description='Whether to apply quad to the model') - force_symmetry: Optional[bool] = Field(False, description='Whether to force symmetry') - face_limit: Optional[int] = Field(10000, description='The number of faces to limit the conversion to') - flatten_bottom: Optional[bool] = Field(False, description='Whether to flatten the bottom of the model') - flatten_bottom_threshold: Optional[float] = Field(0.01, description='The threshold for flattening the bottom') - texture_size: Optional[int] = Field(4096, description='The size of the texture') + quad: Optional[bool] = Field(None, description='Whether to apply quad to the model') + force_symmetry: Optional[bool] = Field(None, description='Whether to force symmetry') + face_limit: Optional[int] = Field(None, description='The number of faces to limit the conversion to') + flatten_bottom: Optional[bool] = Field(None, description='Whether to flatten the bottom of the model') + flatten_bottom_threshold: Optional[float] = Field(None, description='The threshold for flattening the bottom') + texture_size: Optional[int] = Field(None, description='The size of the texture') texture_format: Optional[TripoTextureFormat] = Field(TripoTextureFormat.JPEG, description='The format of the texture') - pivot_to_center_bottom: Optional[bool] = Field(False, description='Whether to pivot to the center bottom') + pivot_to_center_bottom: Optional[bool] = Field(None, description='Whether to pivot to the center bottom') + scale_factor: Optional[float] = Field(None, description='The scale factor for the model') + with_animation: Optional[bool] = Field(None, description='Whether to include animations') + pack_uv: Optional[bool] = Field(None, description='Whether to pack the UVs') + bake: Optional[bool] = Field(None, description='Whether to bake the model') + part_names: Optional[List[str]] = Field(None, description='The names of the parts to include') + fbx_preset: Optional[TripoFbxPreset] = Field(None, description='The preset for the FBX export') + export_vertex_colors: Optional[bool] = Field(None, description='Whether to export the vertex colors') + export_orientation: Optional[TripoOrientation] = Field(None, description='The orientation for the export') + animate_in_place: Optional[bool] = Field(None, description='Whether to animate in place') + class TripoTaskRequest(RootModel): root: Union[ diff --git a/comfy_api_nodes/nodes_tripo.py b/comfy_api_nodes/nodes_tripo.py index 697100ff2..bd3c24fb3 100644 --- a/comfy_api_nodes/nodes_tripo.py +++ b/comfy_api_nodes/nodes_tripo.py @@ -102,8 +102,9 @@ class TripoTextToModelNode(IO.ComfyNode): IO.Int.Input("model_seed", default=42, optional=True), IO.Int.Input("texture_seed", default=42, optional=True), IO.Combo.Input("texture_quality", default="standard", options=["standard", "detailed"], optional=True), - IO.Int.Input("face_limit", default=-1, min=-1, max=500000, optional=True), + 
IO.Int.Input("face_limit", default=-1, min=-1, max=2000000, optional=True), IO.Boolean.Input("quad", default=False, optional=True), + IO.Combo.Input("geometry_quality", default="standard", options=["standard", "detailed"], optional=True), ], outputs=[ IO.String.Output(display_name="model_file"), @@ -131,6 +132,7 @@ class TripoTextToModelNode(IO.ComfyNode): model_seed: Optional[int] = None, texture_seed: Optional[int] = None, texture_quality: Optional[str] = None, + geometry_quality: Optional[str] = None, face_limit: Optional[int] = None, quad: Optional[bool] = None, ) -> IO.NodeOutput: @@ -154,6 +156,7 @@ class TripoTextToModelNode(IO.ComfyNode): texture_seed=texture_seed, texture_quality=texture_quality, face_limit=face_limit, + geometry_quality=geometry_quality, auto_size=True, quad=quad, ), @@ -194,6 +197,7 @@ class TripoImageToModelNode(IO.ComfyNode): ), IO.Int.Input("face_limit", default=-1, min=-1, max=500000, optional=True), IO.Boolean.Input("quad", default=False, optional=True), + IO.Combo.Input("geometry_quality", default="standard", options=["standard", "detailed"], optional=True), ], outputs=[ IO.String.Output(display_name="model_file"), @@ -220,6 +224,7 @@ class TripoImageToModelNode(IO.ComfyNode): orientation=None, texture_seed: Optional[int] = None, texture_quality: Optional[str] = None, + geometry_quality: Optional[str] = None, texture_alignment: Optional[str] = None, face_limit: Optional[int] = None, quad: Optional[bool] = None, @@ -246,6 +251,7 @@ class TripoImageToModelNode(IO.ComfyNode): pbr=pbr, model_seed=model_seed, orientation=orientation, + geometry_quality=geometry_quality, texture_alignment=texture_alignment, texture_seed=texture_seed, texture_quality=texture_quality, @@ -295,6 +301,7 @@ class TripoMultiviewToModelNode(IO.ComfyNode): ), IO.Int.Input("face_limit", default=-1, min=-1, max=500000, optional=True), IO.Boolean.Input("quad", default=False, optional=True), + IO.Combo.Input("geometry_quality", default="standard", options=["standard", "detailed"], optional=True), ], outputs=[ IO.String.Output(display_name="model_file"), @@ -323,6 +330,7 @@ class TripoMultiviewToModelNode(IO.ComfyNode): model_seed: Optional[int] = None, texture_seed: Optional[int] = None, texture_quality: Optional[str] = None, + geometry_quality: Optional[str] = None, texture_alignment: Optional[str] = None, face_limit: Optional[int] = None, quad: Optional[bool] = None, @@ -359,6 +367,7 @@ class TripoMultiviewToModelNode(IO.ComfyNode): model_seed=model_seed, texture_seed=texture_seed, texture_quality=texture_quality, + geometry_quality=geometry_quality, texture_alignment=texture_alignment, face_limit=face_limit, quad=quad, @@ -508,6 +517,8 @@ class TripoRetargetNode(IO.ComfyNode): options=[ "preset:idle", "preset:walk", + "preset:run", + "preset:dive", "preset:climb", "preset:jump", "preset:slash", @@ -515,6 +526,11 @@ class TripoRetargetNode(IO.ComfyNode): "preset:hurt", "preset:fall", "preset:turn", + "preset:quadruped:walk", + "preset:hexapod:walk", + "preset:octopod:walk", + "preset:serpentine:march", + "preset:aquatic:march" ], ), ], @@ -563,7 +579,7 @@ class TripoConversionNode(IO.ComfyNode): "face_limit", default=-1, min=-1, - max=500000, + max=2000000, optional=True, ), IO.Int.Input( @@ -579,6 +595,40 @@ class TripoConversionNode(IO.ComfyNode): default="JPEG", optional=True, ), + IO.Boolean.Input("force_symmetry", default=False, optional=True), + IO.Boolean.Input("flatten_bottom", default=False, optional=True), + IO.Float.Input( + "flatten_bottom_threshold", + default=0.0, + min=0.0, 
+ max=1.0, + optional=True, + ), + IO.Boolean.Input("pivot_to_center_bottom", default=False, optional=True), + IO.Float.Input( + "scale_factor", + default=1.0, + min=0.0, + optional=True, + ), + IO.Boolean.Input("with_animation", default=False, optional=True), + IO.Boolean.Input("pack_uv", default=False, optional=True), + IO.Boolean.Input("bake", default=False, optional=True), + IO.String.Input("part_names", default="", optional=True), # comma-separated list + IO.Combo.Input( + "fbx_preset", + options=["blender", "mixamo", "3dsmax"], + default="blender", + optional=True, + ), + IO.Boolean.Input("export_vertex_colors", default=False, optional=True), + IO.Combo.Input( + "export_orientation", + options=["align_image", "default"], + default="default", + optional=True, + ), + IO.Boolean.Input("animate_in_place", default=False, optional=True), ], outputs=[], hidden=[ @@ -604,12 +654,31 @@ class TripoConversionNode(IO.ComfyNode): original_model_task_id, format: str, quad: bool, + force_symmetry: bool, face_limit: int, + flatten_bottom: bool, + flatten_bottom_threshold: float, texture_size: int, texture_format: str, + pivot_to_center_bottom: bool, + scale_factor: float, + with_animation: bool, + pack_uv: bool, + bake: bool, + part_names: str, + fbx_preset: str, + export_vertex_colors: bool, + export_orientation: str, + animate_in_place: bool, ) -> IO.NodeOutput: if not original_model_task_id: raise RuntimeError("original_model_task_id is required") + + # Parse part_names from comma-separated string to list + part_names_list = None + if part_names and part_names.strip(): + part_names_list = [name.strip() for name in part_names.split(',') if name.strip()] + response = await sync_op( cls, endpoint=ApiEndpoint(path="/proxy/tripo/v2/openapi/task", method="POST"), @@ -618,9 +687,22 @@ class TripoConversionNode(IO.ComfyNode): original_model_task_id=original_model_task_id, format=format, quad=quad if quad else None, + force_symmetry=force_symmetry if force_symmetry else None, face_limit=face_limit if face_limit != -1 else None, + flatten_bottom=flatten_bottom if flatten_bottom else None, + flatten_bottom_threshold=flatten_bottom_threshold if flatten_bottom_threshold != 0.0 else None, texture_size=texture_size if texture_size != 4096 else None, texture_format=texture_format if texture_format != "JPEG" else None, + pivot_to_center_bottom=pivot_to_center_bottom if pivot_to_center_bottom else None, + scale_factor=scale_factor if scale_factor != 1.0 else None, + with_animation=with_animation if with_animation else None, + pack_uv=pack_uv if pack_uv else None, + bake=bake if bake else None, + part_names=part_names_list, + fbx_preset=fbx_preset if fbx_preset != "blender" else None, + export_vertex_colors=export_vertex_colors if export_vertex_colors else None, + export_orientation=export_orientation if export_orientation != "default" else None, + animate_in_place=animate_in_place if animate_in_place else None, ), ) return await poll_until_finished(cls, response, average_duration=30) From 41bcf0619db87d443d468c9ddad4454bdbc1b084 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 15 Dec 2025 17:51:06 -0800 Subject: [PATCH 048/148] Add code to detect if a z image fun controlnet is broken or not. 
(#11341) --- comfy_extras/nodes_model_patch.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/comfy_extras/nodes_model_patch.py b/comfy_extras/nodes_model_patch.py index ec0e790dc..fdd5d0d3f 100644 --- a/comfy_extras/nodes_model_patch.py +++ b/comfy_extras/nodes_model_patch.py @@ -248,7 +248,10 @@ class ModelPatchLoader: config['n_control_layers'] = 15 config['additional_in_dim'] = 17 config['refiner_control'] = True - config['broken'] = True + ref_weight = sd.get("control_noise_refiner.0.after_proj.weight", None) + if ref_weight is not None: + if torch.count_nonzero(ref_weight) == 0: + config['broken'] = True model = comfy.ldm.lumina.controlnet.ZImage_Control(device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast, **config) model.load_state_dict(sd) From fc4af8606880be0374cf1f1f45bc5730e6d22bf5 Mon Sep 17 00:00:00 2001 From: Haoming <73768377+Haoming02@users.noreply.github.com> Date: Tue, 16 Dec 2025 09:57:28 +0800 Subject: [PATCH 049/148] [BlockInfo] Lumina (#11227) * block info * device * Make tensor int again --------- Co-authored-by: Jedrzej Kosinski --- comfy/ldm/lumina/model.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/comfy/ldm/lumina/model.py b/comfy/ldm/lumina/model.py index 96cb37fa6..5628e2ba3 100644 --- a/comfy/ldm/lumina/model.py +++ b/comfy/ldm/lumina/model.py @@ -634,8 +634,11 @@ class NextDiT(nn.Module): img, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, adaln_input, num_tokens, transformer_options=transformer_options) freqs_cis = freqs_cis.to(img.device) + transformer_options["total_blocks"] = len(self.layers) + transformer_options["block_type"] = "double" img_input = img for i, layer in enumerate(self.layers): + transformer_options["block_index"] = i img = layer(img, mask, freqs_cis, adaln_input, transformer_options=transformer_options) if "double_block" in patches: for p in patches["double_block"]: From ea2c117bc3c9d3b38d68e651905ed0d6dd682f92 Mon Sep 17 00:00:00 2001 From: Haoming <73768377+Haoming02@users.noreply.github.com> Date: Tue, 16 Dec 2025 09:59:16 +0800 Subject: [PATCH 050/148] [BlockInfo] Wan (#10845) * block info * animate * tensor * device * revert --- comfy/ldm/wan/model.py | 21 ++++++++++++++++++--- comfy/ldm/wan/model_animate.py | 3 +++ 2 files changed, 21 insertions(+), 3 deletions(-) diff --git a/comfy/ldm/wan/model.py b/comfy/ldm/wan/model.py index a9d5e10d9..4216ce831 100644 --- a/comfy/ldm/wan/model.py +++ b/comfy/ldm/wan/model.py @@ -568,7 +568,10 @@ class WanModel(torch.nn.Module): patches_replace = transformer_options.get("patches_replace", {}) blocks_replace = patches_replace.get("dit", {}) + transformer_options["total_blocks"] = len(self.blocks) + transformer_options["block_type"] = "double" for i, block in enumerate(self.blocks): + transformer_options["block_index"] = i if ("double_block", i) in blocks_replace: def block_wrap(args): out = {} @@ -763,7 +766,10 @@ class VaceWanModel(WanModel): patches_replace = transformer_options.get("patches_replace", {}) blocks_replace = patches_replace.get("dit", {}) + transformer_options["total_blocks"] = len(self.blocks) + transformer_options["block_type"] = "double" for i, block in enumerate(self.blocks): + transformer_options["block_index"] = i if ("double_block", i) in blocks_replace: def block_wrap(args): out = {} @@ -862,7 +868,10 @@ class CameraWanModel(WanModel): patches_replace = transformer_options.get("patches_replace", {}) blocks_replace = patches_replace.get("dit", {}) + 
transformer_options["total_blocks"] = len(self.blocks) + transformer_options["block_type"] = "double" for i, block in enumerate(self.blocks): + transformer_options["block_index"] = i if ("double_block", i) in blocks_replace: def block_wrap(args): out = {} @@ -1326,16 +1335,19 @@ class WanModel_S2V(WanModel): patches_replace = transformer_options.get("patches_replace", {}) blocks_replace = patches_replace.get("dit", {}) + transformer_options["total_blocks"] = len(self.blocks) + transformer_options["block_type"] = "double" for i, block in enumerate(self.blocks): + transformer_options["block_index"] = i if ("double_block", i) in blocks_replace: def block_wrap(args): out = {} - out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"]) + out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], transformer_options=args["transformer_options"]) return out - out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap}) + out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap}) x = out["img"] else: - x = block(x, e=e0, freqs=freqs, context=context) + x = block(x, e=e0, freqs=freqs, context=context, transformer_options=transformer_options) if audio_emb is not None: x = self.audio_injector(x, i, audio_emb, audio_emb_global, seq_len) # head @@ -1574,7 +1586,10 @@ class HumoWanModel(WanModel): patches_replace = transformer_options.get("patches_replace", {}) blocks_replace = patches_replace.get("dit", {}) + transformer_options["total_blocks"] = len(self.blocks) + transformer_options["block_type"] = "double" for i, block in enumerate(self.blocks): + transformer_options["block_index"] = i if ("double_block", i) in blocks_replace: def block_wrap(args): out = {} diff --git a/comfy/ldm/wan/model_animate.py b/comfy/ldm/wan/model_animate.py index 7c87835d4..84d7adec4 100644 --- a/comfy/ldm/wan/model_animate.py +++ b/comfy/ldm/wan/model_animate.py @@ -523,7 +523,10 @@ class AnimateWanModel(WanModel): patches_replace = transformer_options.get("patches_replace", {}) blocks_replace = patches_replace.get("dit", {}) + transformer_options["total_blocks"] = len(self.blocks) + transformer_options["block_type"] = "double" for i, block in enumerate(self.blocks): + transformer_options["block_index"] = i if ("double_block", i) in blocks_replace: def block_wrap(args): out = {} From 683569de5527379d9a095af88a9e1349fb7e46b5 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 15 Dec 2025 19:33:27 -0800 Subject: [PATCH 051/148] Only enable fp16 on ZImage on newer pytorch. (#11344) --- comfy/supported_models.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 834dfcffc..1888f35ba 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -28,6 +28,7 @@ from . import supported_models_base from . import latent_formats from . 
import diffusers_convert +import comfy.model_management class SD15(supported_models_base.BASE): unet_config = { @@ -1028,7 +1029,13 @@ class ZImage(Lumina2): memory_usage_factor = 2.0 - supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32] + supported_inference_dtypes = [torch.bfloat16, torch.float32] + + def __init__(self, unet_config): + super().__init__(unet_config) + if comfy.model_management.extended_fp16_support(): + self.supported_inference_dtypes = self.supported_inference_dtypes.copy() + self.supported_inference_dtypes.insert(1, torch.float16) def clip_target(self, state_dict={}): pref = self.text_encoder_key_prefix[0] From 3d082c32065e0653490b9a4ae45dd33b6c7bffb7 Mon Sep 17 00:00:00 2001 From: Christian Byrne Date: Mon, 15 Dec 2025 20:35:37 -0800 Subject: [PATCH 052/148] bump comfyui-frontend-package to 1.34.9 (patch) (#11342) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 117260515..9b9e61683 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -comfyui-frontend-package==1.34.8 +comfyui-frontend-package==1.34.9 comfyui-workflow-templates==0.7.59 comfyui-embedded-docs==0.3.1 torch From 645ee1881e739b3013eeb26dbb335280bfbf443e Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 15 Dec 2025 20:38:12 -0800 Subject: [PATCH 053/148] Inpainting for z image fun control. Use the ZImageFunControlnet node. (#11346) image -> control image ex: pose inpaint_image -> image for inpainting mask -> inpaint mask --- comfy_extras/nodes_model_patch.py | 77 ++++++++++++++++++++++++------- 1 file changed, 61 insertions(+), 16 deletions(-) diff --git a/comfy_extras/nodes_model_patch.py b/comfy_extras/nodes_model_patch.py index fdd5d0d3f..2a0cfcf18 100644 --- a/comfy_extras/nodes_model_patch.py +++ b/comfy_extras/nodes_model_patch.py @@ -313,22 +313,46 @@ class ZImageControlPatch: self.inpaint_image = inpaint_image self.mask = mask self.strength = strength - self.encoded_image = self.encode_latent_cond(image) - self.encoded_image_size = (image.shape[1], image.shape[2]) + self.is_inpaint = self.model_patch.model.additional_in_dim > 0 + + skip_encoding = False + if self.image is not None and self.inpaint_image is not None: + if self.image.shape != self.inpaint_image.shape: + skip_encoding = True + + if skip_encoding: + self.encoded_image = None + else: + self.encoded_image = self.encode_latent_cond(self.image, self.inpaint_image) + if self.image is None: + self.encoded_image_size = (self.inpaint_image.shape[1], self.inpaint_image.shape[2]) + else: + self.encoded_image_size = (self.image.shape[1], self.image.shape[2]) self.temp_data = None - def encode_latent_cond(self, control_image, inpaint_image=None): - latent_image = comfy.latent_formats.Flux().process_in(self.vae.encode(control_image)) - if self.model_patch.model.additional_in_dim > 0: - if self.mask is None: - mask_ = torch.zeros_like(latent_image)[:, :1] - else: - mask_ = comfy.utils.common_upscale(self.mask.mean(dim=1, keepdim=True), latent_image.shape[-1], latent_image.shape[-2], "bilinear", "none") + def encode_latent_cond(self, control_image=None, inpaint_image=None): + latent_image = None + if control_image is not None: + latent_image = comfy.latent_formats.Flux().process_in(self.vae.encode(control_image)) + + if self.is_inpaint: if inpaint_image is None: inpaint_image = torch.ones_like(control_image) * 0.5 + if self.mask is not None: + mask_inpaint = 
comfy.utils.common_upscale(self.mask.view(self.mask.shape[0], -1, self.mask.shape[-2], self.mask.shape[-1]).mean(dim=1, keepdim=True), inpaint_image.shape[-2], inpaint_image.shape[-3], "bilinear", "center") + inpaint_image = ((inpaint_image - 0.5) * mask_inpaint.movedim(1, -1).round()) + 0.5 + inpaint_image_latent = comfy.latent_formats.Flux().process_in(self.vae.encode(inpaint_image)) + if self.mask is None: + mask_ = torch.zeros_like(inpaint_image_latent)[:, :1] + else: + mask_ = comfy.utils.common_upscale(self.mask.view(self.mask.shape[0], -1, self.mask.shape[-2], self.mask.shape[-1]).mean(dim=1, keepdim=True), inpaint_image_latent.shape[-1], inpaint_image_latent.shape[-2], "nearest", "center") + + if latent_image is None: + latent_image = comfy.latent_formats.Flux().process_in(self.vae.encode(torch.ones_like(inpaint_image) * 0.5)) + return torch.cat([latent_image, mask_, inpaint_image_latent], dim=1) else: return latent_image @@ -344,13 +368,18 @@ class ZImageControlPatch: block_type = kwargs.get("block_type", "") spacial_compression = self.vae.spacial_compression_encode() if self.encoded_image is None or self.encoded_image_size != (x.shape[-2] * spacial_compression, x.shape[-1] * spacial_compression): - image_scaled = comfy.utils.common_upscale(self.image.movedim(-1, 1), x.shape[-1] * spacial_compression, x.shape[-2] * spacial_compression, "area", "center") + image_scaled = None + if self.image is not None: + image_scaled = comfy.utils.common_upscale(self.image.movedim(-1, 1), x.shape[-1] * spacial_compression, x.shape[-2] * spacial_compression, "area", "center").movedim(1, -1) + self.encoded_image_size = (image_scaled.shape[-3], image_scaled.shape[-2]) + inpaint_scaled = None if self.inpaint_image is not None: inpaint_scaled = comfy.utils.common_upscale(self.inpaint_image.movedim(-1, 1), x.shape[-1] * spacial_compression, x.shape[-2] * spacial_compression, "area", "center").movedim(1, -1) + self.encoded_image_size = (inpaint_scaled.shape[-3], inpaint_scaled.shape[-2]) + loaded_models = comfy.model_management.loaded_models(only_currently_used=True) - self.encoded_image = self.encode_latent_cond(image_scaled.movedim(1, -1), inpaint_scaled) - self.encoded_image_size = (image_scaled.shape[-2], image_scaled.shape[-1]) + self.encoded_image = self.encode_latent_cond(image_scaled, inpaint_scaled) comfy.model_management.load_models_gpu(loaded_models) cnet_blocks = self.model_patch.model.n_control_layers @@ -391,7 +420,8 @@ class ZImageControlPatch: def to(self, device_or_dtype): if isinstance(device_or_dtype, torch.device): - self.encoded_image = self.encoded_image.to(device_or_dtype) + if self.encoded_image is not None: + self.encoded_image = self.encoded_image.to(device_or_dtype) self.temp_data = None return self @@ -414,9 +444,12 @@ class QwenImageDiffsynthControlnet: CATEGORY = "advanced/loaders/qwen" - def diffsynth_controlnet(self, model, model_patch, vae, image, strength, mask=None): + def diffsynth_controlnet(self, model, model_patch, vae, image=None, strength=1.0, inpaint_image=None, mask=None): model_patched = model.clone() - image = image[:, :, :, :3] + if image is not None: + image = image[:, :, :, :3] + if inpaint_image is not None: + inpaint_image = inpaint_image[:, :, :, :3] if mask is not None: if mask.ndim == 3: mask = mask.unsqueeze(1) @@ -425,13 +458,24 @@ class QwenImageDiffsynthControlnet: mask = 1.0 - mask if isinstance(model_patch.model, comfy.ldm.lumina.controlnet.ZImage_Control): - patch = ZImageControlPatch(model_patch, vae, image, strength, mask=mask) + patch = 
ZImageControlPatch(model_patch, vae, image, strength, inpaint_image=inpaint_image, mask=mask) model_patched.set_model_noise_refiner_patch(patch) model_patched.set_model_double_block_patch(patch) else: model_patched.set_model_double_block_patch(DiffSynthCnetPatch(model_patch, vae, image, strength, mask)) return (model_patched,) +class ZImageFunControlnet(QwenImageDiffsynthControlnet): + @classmethod + def INPUT_TYPES(s): + return {"required": { "model": ("MODEL",), + "model_patch": ("MODEL_PATCH",), + "vae": ("VAE",), + "strength": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}), + }, + "optional": {"image": ("IMAGE",), "inpaint_image": ("IMAGE",), "mask": ("MASK",)}} + + CATEGORY = "advanced/loaders/zimage" class UsoStyleProjectorPatch: def __init__(self, model_patch, encoded_image): @@ -479,5 +523,6 @@ class USOStyleReference: NODE_CLASS_MAPPINGS = { "ModelPatchLoader": ModelPatchLoader, "QwenImageDiffsynthControlnet": QwenImageDiffsynthControlnet, + "ZImageFunControlnet": ZImageFunControlnet, "USOStyleReference": USOStyleReference, } From bc606d7d645f9edfcac7cca3558210d3ee391d94 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 15 Dec 2025 22:26:55 -0800 Subject: [PATCH 054/148] Add a way to set the default ref method in the qwen image code. (#11349) --- comfy/ldm/qwen_image/model.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/comfy/ldm/qwen_image/model.py b/comfy/ldm/qwen_image/model.py index 96590088b..8481f7711 100644 --- a/comfy/ldm/qwen_image/model.py +++ b/comfy/ldm/qwen_image/model.py @@ -322,6 +322,7 @@ class QwenImageTransformer2DModel(nn.Module): pooled_projection_dim: int = 768, guidance_embeds: bool = False, axes_dims_rope: Tuple[int, int, int] = (16, 56, 56), + default_ref_method="index", image_model=None, final_layer=True, dtype=None, @@ -334,6 +335,7 @@ class QwenImageTransformer2DModel(nn.Module): self.in_channels = in_channels self.out_channels = out_channels or in_channels self.inner_dim = num_attention_heads * attention_head_dim + self.default_ref_method = default_ref_method self.pe_embedder = EmbedND(dim=attention_head_dim, theta=10000, axes_dim=list(axes_dims_rope)) @@ -416,7 +418,7 @@ class QwenImageTransformer2DModel(nn.Module): h = 0 w = 0 index = 0 - ref_method = kwargs.get("ref_latents_method", "index") + ref_method = kwargs.get("ref_latents_method", self.default_ref_method) index_ref_method = (ref_method == "index") or (ref_method == "index_timestep_zero") timestep_zero = ref_method == "index_timestep_zero" for ref in ref_latents: From 9304e47351be8d178a093b30bcaf5d72c3a2baf5 Mon Sep 17 00:00:00 2001 From: Benjamin Lu Date: Mon, 15 Dec 2025 23:24:18 -0800 Subject: [PATCH 055/148] Update workflows for new release process (#11064) * Update release workflows for branch process * Adjust branch order in workflow triggers * Revert changes in test workflows --- .github/workflows/test-ci.yml | 1 + .github/workflows/test-execution.yml | 4 ++-- .github/workflows/test-launch.yml | 4 ++-- .github/workflows/test-unit.yml | 4 ++-- .github/workflows/update-version.yml | 1 + 5 files changed, 8 insertions(+), 6 deletions(-) diff --git a/.github/workflows/test-ci.yml b/.github/workflows/test-ci.yml index 1660ec8e3..adfc5dd32 100644 --- a/.github/workflows/test-ci.yml +++ b/.github/workflows/test-ci.yml @@ -5,6 +5,7 @@ on: push: branches: - master + - release/** paths-ignore: - 'app/**' - 'input/**' diff --git a/.github/workflows/test-execution.yml 
b/.github/workflows/test-execution.yml index 00ef07ebf..9012633d8 100644 --- a/.github/workflows/test-execution.yml +++ b/.github/workflows/test-execution.yml @@ -2,9 +2,9 @@ name: Execution Tests on: push: - branches: [ main, master ] + branches: [ main, master, release/** ] pull_request: - branches: [ main, master ] + branches: [ main, master, release/** ] jobs: test: diff --git a/.github/workflows/test-launch.yml b/.github/workflows/test-launch.yml index 1735fd83b..fd70aff23 100644 --- a/.github/workflows/test-launch.yml +++ b/.github/workflows/test-launch.yml @@ -2,9 +2,9 @@ name: Test server launches without errors on: push: - branches: [ main, master ] + branches: [ main, master, release/** ] pull_request: - branches: [ main, master ] + branches: [ main, master, release/** ] jobs: test: diff --git a/.github/workflows/test-unit.yml b/.github/workflows/test-unit.yml index 00caf5b8a..d05179cd3 100644 --- a/.github/workflows/test-unit.yml +++ b/.github/workflows/test-unit.yml @@ -2,9 +2,9 @@ name: Unit Tests on: push: - branches: [ main, master ] + branches: [ main, master, release/** ] pull_request: - branches: [ main, master ] + branches: [ main, master, release/** ] jobs: test: diff --git a/.github/workflows/update-version.yml b/.github/workflows/update-version.yml index d9d488974..c2343cc39 100644 --- a/.github/workflows/update-version.yml +++ b/.github/workflows/update-version.yml @@ -6,6 +6,7 @@ on: - "pyproject.toml" branches: - master + - release/** jobs: update-version: From 65e2103b09f66e45438445fb0e99709ae7639869 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Tue, 16 Dec 2025 23:51:48 +0200 Subject: [PATCH 056/148] feat(api-nodes): add Wan2.6 model to video nodes (#11357) --- comfy_api_nodes/nodes_wan.py | 162 ++++++++++++++++++++--------------- 1 file changed, 95 insertions(+), 67 deletions(-) diff --git a/comfy_api_nodes/nodes_wan.py b/comfy_api_nodes/nodes_wan.py index 2aab3c2ff..17b680e13 100644 --- a/comfy_api_nodes/nodes_wan.py +++ b/comfy_api_nodes/nodes_wan.py @@ -1,7 +1,5 @@ import re -from typing import Optional -import torch from pydantic import BaseModel, Field from typing_extensions import override @@ -21,26 +19,26 @@ from comfy_api_nodes.util import ( class Text2ImageInputField(BaseModel): prompt: str = Field(...) - negative_prompt: Optional[str] = Field(None) + negative_prompt: str | None = Field(None) class Image2ImageInputField(BaseModel): prompt: str = Field(...) - negative_prompt: Optional[str] = Field(None) + negative_prompt: str | None = Field(None) images: list[str] = Field(..., min_length=1, max_length=2) class Text2VideoInputField(BaseModel): prompt: str = Field(...) - negative_prompt: Optional[str] = Field(None) - audio_url: Optional[str] = Field(None) + negative_prompt: str | None = Field(None) + audio_url: str | None = Field(None) class Image2VideoInputField(BaseModel): prompt: str = Field(...) - negative_prompt: Optional[str] = Field(None) + negative_prompt: str | None = Field(None) img_url: str = Field(...) 
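# Illustrative sketch only: the pydantic models in this hunk define the JSON bodies the
# Wan nodes serialize for task creation. A minimal stand-alone mirror of
# Image2VideoInputField (the _DemoImage2Video name is made up; assumes pydantic v2 and
# Python 3.10+ for the `str | None` spelling):
from pydantic import BaseModel, Field

class _DemoImage2Video(BaseModel):
    prompt: str = Field(...)                   # required
    negative_prompt: str | None = Field(None)  # optional, dropped from the dump below
    img_url: str = Field(...)                  # required (URL or data URI)
    audio_url: str | None = Field(None)        # optional

body = _DemoImage2Video(prompt="a cat on a skateboard", img_url="data:image/png;base64,...")
print(body.model_dump(exclude_none=True))
# {'prompt': 'a cat on a skateboard', 'img_url': 'data:image/png;base64,...'}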
- audio_url: Optional[str] = Field(None) + audio_url: str | None = Field(None) class Txt2ImageParametersField(BaseModel): @@ -52,7 +50,7 @@ class Txt2ImageParametersField(BaseModel): class Image2ImageParametersField(BaseModel): - size: Optional[str] = Field(None) + size: str | None = Field(None) n: int = Field(1, description="Number of images to generate.") # we support only value=1 seed: int = Field(..., ge=0, le=2147483647) watermark: bool = Field(True) @@ -61,19 +59,21 @@ class Image2ImageParametersField(BaseModel): class Text2VideoParametersField(BaseModel): size: str = Field(...) seed: int = Field(..., ge=0, le=2147483647) - duration: int = Field(5, ge=5, le=10) + duration: int = Field(5, ge=5, le=15) prompt_extend: bool = Field(True) watermark: bool = Field(True) - audio: bool = Field(False, description="Should be audio generated automatically") + audio: bool = Field(False, description="Whether to generate audio automatically.") + shot_type: str = Field("single") class Image2VideoParametersField(BaseModel): resolution: str = Field(...) seed: int = Field(..., ge=0, le=2147483647) - duration: int = Field(5, ge=5, le=10) + duration: int = Field(5, ge=5, le=15) prompt_extend: bool = Field(True) watermark: bool = Field(True) - audio: bool = Field(False, description="Should be audio generated automatically") + audio: bool = Field(False, description="Whether to generate audio automatically.") + shot_type: str = Field("single") class Text2ImageTaskCreationRequest(BaseModel): @@ -106,39 +106,39 @@ class TaskCreationOutputField(BaseModel): class TaskCreationResponse(BaseModel): - output: Optional[TaskCreationOutputField] = Field(None) + output: TaskCreationOutputField | None = Field(None) request_id: str = Field(...) - code: Optional[str] = Field(None, description="The error code of the failed request.") - message: Optional[str] = Field(None, description="Details of the failed request.") + code: str | None = Field(None, description="Error code for the failed request.") + message: str | None = Field(None, description="Details about the failed request.") class TaskResult(BaseModel): - url: Optional[str] = Field(None) - code: Optional[str] = Field(None) - message: Optional[str] = Field(None) + url: str | None = Field(None) + code: str | None = Field(None) + message: str | None = Field(None) class ImageTaskStatusOutputField(TaskCreationOutputField): task_id: str = Field(...) task_status: str = Field(...) - results: Optional[list[TaskResult]] = Field(None) + results: list[TaskResult] | None = Field(None) class VideoTaskStatusOutputField(TaskCreationOutputField): task_id: str = Field(...) task_status: str = Field(...) - video_url: Optional[str] = Field(None) - code: Optional[str] = Field(None) - message: Optional[str] = Field(None) + video_url: str | None = Field(None) + code: str | None = Field(None) + message: str | None = Field(None) class ImageTaskStatusResponse(BaseModel): - output: Optional[ImageTaskStatusOutputField] = Field(None) + output: ImageTaskStatusOutputField | None = Field(None) request_id: str = Field(...) class VideoTaskStatusResponse(BaseModel): - output: Optional[VideoTaskStatusOutputField] = Field(None) + output: VideoTaskStatusOutputField | None = Field(None) request_id: str = Field(...) 
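# Illustrative sketch only: the request/response models above back a two-step flow that
# the nodes below drive through the sync_op/poll_op helpers: create a task, read
# output.task_id, then poll /proxy/wan/api/v1/tasks/{task_id} until the status is
# terminal. Below is a generic stand-alone version of that pattern; create_task,
# fetch_status and the status strings are placeholders, not names from this codebase.
import time

def wait_for_task(create_task, fetch_status, interval=5.0, timeout=600.0):
    task_id = create_task()               # e.g. POST the task-creation request body
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        status = fetch_status(task_id)    # e.g. GET .../tasks/{task_id}
        state = status.get("task_status")
        if state == "SUCCEEDED":          # assumed terminal states
            return status.get("video_url")
        if state in ("FAILED", "CANCELED"):
            raise RuntimeError(status.get("message") or f"task {task_id} failed")
        time.sleep(interval)
    raise TimeoutError(f"task {task_id} did not finish within {timeout} seconds")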
@@ -152,7 +152,7 @@ class WanTextToImageApi(IO.ComfyNode): node_id="WanTextToImageApi", display_name="Wan Text to Image", category="api node/image/Wan", - description="Generates image based on text prompt.", + description="Generates an image based on a text prompt.", inputs=[ IO.Combo.Input( "model", @@ -164,13 +164,13 @@ class WanTextToImageApi(IO.ComfyNode): "prompt", multiline=True, default="", - tooltip="Prompt used to describe the elements and visual features, supports English/Chinese.", + tooltip="Prompt describing the elements and visual features. Supports English and Chinese.", ), IO.String.Input( "negative_prompt", multiline=True, default="", - tooltip="Negative text prompt to guide what to avoid.", + tooltip="Negative prompt describing what to avoid.", optional=True, ), IO.Int.Input( @@ -209,7 +209,7 @@ class WanTextToImageApi(IO.ComfyNode): IO.Boolean.Input( "watermark", default=True, - tooltip='Whether to add an "AI generated" watermark to the result.', + tooltip="Whether to add an AI-generated watermark to the result.", optional=True, ), ], @@ -252,7 +252,7 @@ class WanTextToImageApi(IO.ComfyNode): ), ) if not initial_response.output: - raise Exception(f"Unknown error occurred: {initial_response.code} - {initial_response.message}") + raise Exception(f"An unknown error occurred: {initial_response.code} - {initial_response.message}") response = await poll_op( cls, ApiEndpoint(path=f"/proxy/wan/api/v1/tasks/{initial_response.output.task_id}"), @@ -272,7 +272,7 @@ class WanImageToImageApi(IO.ComfyNode): display_name="Wan Image to Image", category="api node/image/Wan", description="Generates an image from one or two input images and a text prompt. " - "The output image is currently fixed at 1.6 MP; its aspect ratio matches the input image(s).", + "The output image is currently fixed at 1.6 MP, and its aspect ratio matches the input image(s).", inputs=[ IO.Combo.Input( "model", @@ -282,19 +282,19 @@ class WanImageToImageApi(IO.ComfyNode): ), IO.Image.Input( "image", - tooltip="Single-image editing or multi-image fusion, maximum 2 images.", + tooltip="Single-image editing or multi-image fusion. Maximum 2 images.", ), IO.String.Input( "prompt", multiline=True, default="", - tooltip="Prompt used to describe the elements and visual features, supports English/Chinese.", + tooltip="Prompt describing the elements and visual features. 
Supports English and Chinese.", ), IO.String.Input( "negative_prompt", multiline=True, default="", - tooltip="Negative text prompt to guide what to avoid.", + tooltip="Negative prompt describing what to avoid.", optional=True, ), # redo this later as an optional combo of recommended resolutions @@ -328,7 +328,7 @@ class WanImageToImageApi(IO.ComfyNode): IO.Boolean.Input( "watermark", default=True, - tooltip='Whether to add an "AI generated" watermark to the result.', + tooltip="Whether to add an AI-generated watermark to the result.", optional=True, ), ], @@ -347,7 +347,7 @@ class WanImageToImageApi(IO.ComfyNode): async def execute( cls, model: str, - image: torch.Tensor, + image: Input.Image, prompt: str, negative_prompt: str = "", # width: int = 1024, @@ -357,7 +357,7 @@ class WanImageToImageApi(IO.ComfyNode): ): n_images = get_number_of_images(image) if n_images not in (1, 2): - raise ValueError(f"Expected 1 or 2 input images, got {n_images}.") + raise ValueError(f"Expected 1 or 2 input images, but got {n_images}.") images = [] for i in image: images.append("data:image/png;base64," + tensor_to_base64_string(i, total_pixels=4096 * 4096)) @@ -376,7 +376,7 @@ class WanImageToImageApi(IO.ComfyNode): ), ) if not initial_response.output: - raise Exception(f"Unknown error occurred: {initial_response.code} - {initial_response.message}") + raise Exception(f"An unknown error occurred: {initial_response.code} - {initial_response.message}") response = await poll_op( cls, ApiEndpoint(path=f"/proxy/wan/api/v1/tasks/{initial_response.output.task_id}"), @@ -395,25 +395,25 @@ class WanTextToVideoApi(IO.ComfyNode): node_id="WanTextToVideoApi", display_name="Wan Text to Video", category="api node/video/Wan", - description="Generates video based on text prompt.", + description="Generates a video based on a text prompt.", inputs=[ IO.Combo.Input( "model", - options=["wan2.5-t2v-preview"], - default="wan2.5-t2v-preview", + options=["wan2.5-t2v-preview", "wan2.6-t2v"], + default="wan2.6-t2v", tooltip="Model to use.", ), IO.String.Input( "prompt", multiline=True, default="", - tooltip="Prompt used to describe the elements and visual features, supports English/Chinese.", + tooltip="Prompt describing the elements and visual features. 
Supports English and Chinese.", ), IO.String.Input( "negative_prompt", multiline=True, default="", - tooltip="Negative text prompt to guide what to avoid.", + tooltip="Negative prompt describing what to avoid.", optional=True, ), IO.Combo.Input( @@ -433,23 +433,23 @@ class WanTextToVideoApi(IO.ComfyNode): "1080p: 4:3 (1632x1248)", "1080p: 3:4 (1248x1632)", ], - default="480p: 1:1 (624x624)", + default="720p: 1:1 (960x960)", optional=True, ), IO.Int.Input( "duration", default=5, min=5, - max=10, + max=15, step=5, display_mode=IO.NumberDisplay.number, - tooltip="Available durations: 5 and 10 seconds", + tooltip="A 15-second duration is available only for the Wan 2.6 model.", optional=True, ), IO.Audio.Input( "audio", optional=True, - tooltip="Audio must contain a clear, loud voice, without extraneous noise, background music.", + tooltip="Audio must contain a clear, loud voice, without extraneous noise or background music.", ), IO.Int.Input( "seed", @@ -466,7 +466,7 @@ class WanTextToVideoApi(IO.ComfyNode): "generate_audio", default=False, optional=True, - tooltip="If there is no audio input, generate audio automatically.", + tooltip="If no audio input is provided, generate audio automatically.", ), IO.Boolean.Input( "prompt_extend", @@ -477,7 +477,15 @@ class WanTextToVideoApi(IO.ComfyNode): IO.Boolean.Input( "watermark", default=True, - tooltip='Whether to add an "AI generated" watermark to the result.', + tooltip="Whether to add an AI-generated watermark to the result.", + optional=True, + ), + IO.Combo.Input( + "shot_type", + options=["single", "multi"], + tooltip="Specifies the shot type for the generated video, that is, whether the video is a " + "single continuous shot or multiple shots with cuts. " + "This parameter takes effect only when prompt_extend is True.", optional=True, ), ], @@ -498,14 +506,19 @@ class WanTextToVideoApi(IO.ComfyNode): model: str, prompt: str, negative_prompt: str = "", - size: str = "480p: 1:1 (624x624)", + size: str = "720p: 1:1 (960x960)", duration: int = 5, - audio: Optional[Input.Audio] = None, + audio: Input.Audio | None = None, seed: int = 0, generate_audio: bool = False, prompt_extend: bool = True, watermark: bool = True, + shot_type: str = "single", ): + if "480p" in size and model == "wan2.6-t2v": + raise ValueError("The Wan 2.6 model does not support 480p.") + if duration == 15 and model == "wan2.5-t2v-preview": + raise ValueError("A 15-second duration is supported only by the Wan 2.6 model.") width, height = RES_IN_PARENS.search(size).groups() audio_url = None if audio is not None: @@ -526,11 +539,12 @@ class WanTextToVideoApi(IO.ComfyNode): audio=generate_audio, prompt_extend=prompt_extend, watermark=watermark, + shot_type=shot_type, ), ), ) if not initial_response.output: - raise Exception(f"Unknown error occurred: {initial_response.code} - {initial_response.message}") + raise Exception(f"An unknown error occurred: {initial_response.code} - {initial_response.message}") response = await poll_op( cls, ApiEndpoint(path=f"/proxy/wan/api/v1/tasks/{initial_response.output.task_id}"), @@ -549,12 +563,12 @@ class WanImageToVideoApi(IO.ComfyNode): node_id="WanImageToVideoApi", display_name="Wan Image to Video", category="api node/video/Wan", - description="Generates video based on the first frame and text prompt.", + description="Generates a video from the first frame and a text prompt.", inputs=[ IO.Combo.Input( "model", - options=["wan2.5-i2v-preview"], - default="wan2.5-i2v-preview", + options=["wan2.5-i2v-preview", "wan2.6-i2v"], + 
default="wan2.6-i2v", tooltip="Model to use.", ), IO.Image.Input( @@ -564,13 +578,13 @@ class WanImageToVideoApi(IO.ComfyNode): "prompt", multiline=True, default="", - tooltip="Prompt used to describe the elements and visual features, supports English/Chinese.", + tooltip="Prompt describing the elements and visual features. Supports English and Chinese.", ), IO.String.Input( "negative_prompt", multiline=True, default="", - tooltip="Negative text prompt to guide what to avoid.", + tooltip="Negative prompt describing what to avoid.", optional=True, ), IO.Combo.Input( @@ -580,23 +594,23 @@ class WanImageToVideoApi(IO.ComfyNode): "720P", "1080P", ], - default="480P", + default="720P", optional=True, ), IO.Int.Input( "duration", default=5, min=5, - max=10, + max=15, step=5, display_mode=IO.NumberDisplay.number, - tooltip="Available durations: 5 and 10 seconds", + tooltip="Duration 15 available only for WAN2.6 model.", optional=True, ), IO.Audio.Input( "audio", optional=True, - tooltip="Audio must contain a clear, loud voice, without extraneous noise, background music.", + tooltip="Audio must contain a clear, loud voice, without extraneous noise or background music.", ), IO.Int.Input( "seed", @@ -613,7 +627,7 @@ class WanImageToVideoApi(IO.ComfyNode): "generate_audio", default=False, optional=True, - tooltip="If there is no audio input, generate audio automatically.", + tooltip="If no audio input is provided, generate audio automatically.", ), IO.Boolean.Input( "prompt_extend", @@ -624,7 +638,15 @@ class WanImageToVideoApi(IO.ComfyNode): IO.Boolean.Input( "watermark", default=True, - tooltip='Whether to add an "AI generated" watermark to the result.', + tooltip="Whether to add an AI-generated watermark to the result.", + optional=True, + ), + IO.Combo.Input( + "shot_type", + options=["single", "multi"], + tooltip="Specifies the shot type for the generated video, that is, whether the video is a " + "single continuous shot or multiple shots with cuts. 
" + "This parameter takes effect only when prompt_extend is True.", optional=True, ), ], @@ -643,19 +665,24 @@ class WanImageToVideoApi(IO.ComfyNode): async def execute( cls, model: str, - image: torch.Tensor, + image: Input.Image, prompt: str, negative_prompt: str = "", - resolution: str = "480P", + resolution: str = "720P", duration: int = 5, - audio: Optional[Input.Audio] = None, + audio: Input.Audio | None = None, seed: int = 0, generate_audio: bool = False, prompt_extend: bool = True, watermark: bool = True, + shot_type: str = "single", ): if get_number_of_images(image) != 1: raise ValueError("Exactly one input image is required.") + if "480P" in resolution and model == "wan2.6-i2v": + raise ValueError("The Wan 2.6 model does not support 480P.") + if duration == 15 and model == "wan2.5-i2v-preview": + raise ValueError("A 15-second duration is supported only by the Wan 2.6 model.") image_url = "data:image/png;base64," + tensor_to_base64_string(image, total_pixels=2000 * 2000) audio_url = None if audio is not None: @@ -677,11 +704,12 @@ class WanImageToVideoApi(IO.ComfyNode): audio=generate_audio, prompt_extend=prompt_extend, watermark=watermark, + shot_type=shot_type, ), ), ) if not initial_response.output: - raise Exception(f"Unknown error occurred: {initial_response.code} - {initial_response.message}") + raise Exception(f"An unknown error occurred: {initial_response.code} - {initial_response.message}") response = await poll_op( cls, ApiEndpoint(path=f"/proxy/wan/api/v1/tasks/{initial_response.output.task_id}"), From ffdd53b327f7ebd48cf81a1c8b06d846cf354a66 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 16 Dec 2025 14:03:17 -0800 Subject: [PATCH 057/148] Check state dict key to auto enable the index_timestep_zero ref method. 
(#11362) --- comfy/ldm/qwen_image/model.py | 3 +++ comfy/model_detection.py | 4 +++- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/comfy/ldm/qwen_image/model.py b/comfy/ldm/qwen_image/model.py index 8481f7711..902af30ed 100644 --- a/comfy/ldm/qwen_image/model.py +++ b/comfy/ldm/qwen_image/model.py @@ -363,6 +363,9 @@ class QwenImageTransformer2DModel(nn.Module): for _ in range(num_layers) ]) + if self.default_ref_method == "index_timestep_zero": + self.register_buffer("__index_timestep_zero__", torch.tensor([])) + if final_layer: self.norm_out = LastLayer(self.inner_dim, self.inner_dim, dtype=dtype, device=device, operations=operations) self.proj_out = operations.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True, dtype=dtype, device=device) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index dd6a703f6..7148c77fd 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -259,7 +259,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["nerf_tile_size"] = 512 dit_config["nerf_final_head_type"] = "conv" if f"{key_prefix}nerf_final_layer_conv.norm.scale" in state_dict_keys else "linear" dit_config["nerf_embedder_dtype"] = torch.float32 - if "__x0__" in state_dict_keys: # x0 pred + if "{}__x0__".format(key_prefix) in state_dict_keys: # x0 pred dit_config["use_x0"] = True else: dit_config["use_x0"] = False @@ -618,6 +618,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["image_model"] = "qwen_image" dit_config["in_channels"] = state_dict['{}img_in.weight'.format(key_prefix)].shape[1] dit_config["num_layers"] = count_blocks(state_dict_keys, '{}transformer_blocks.'.format(key_prefix) + '{}.') + if "{}__index_timestep_zero__".format(key_prefix) in state_dict_keys: # 2511 + dit_config["default_ref_method"] = "index_timestep_zero" return dit_config if '{}visual_transformer_blocks.0.cross_attention.key_norm.weight'.format(key_prefix) in state_dict_keys: # Kandinsky 5 From 827bb1512b17e349238e69b2d4f463390a5b0d14 Mon Sep 17 00:00:00 2001 From: chaObserv <154517000+chaObserv@users.noreply.github.com> Date: Wed, 17 Dec 2025 12:35:43 +0800 Subject: [PATCH 058/148] Add exp_heun_2_x0 sampler series (#11360) --- comfy/k_diffusion/sampling.py | 11 +++++++++++ comfy/samplers.py | 2 +- comfy_extras/nodes_custom_sampler.py | 11 ++++++++++- 3 files changed, 22 insertions(+), 2 deletions(-) diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py index 753c66afa..c004b3b47 100644 --- a/comfy/k_diffusion/sampling.py +++ b/comfy/k_diffusion/sampling.py @@ -1618,6 +1618,17 @@ def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=Non x = x + sde_noise * sigmas[i + 1] * s_noise return x +@torch.no_grad() +def sample_exp_heun_2_x0(model, x, sigmas, extra_args=None, callback=None, disable=None, solver_type="phi_2"): + """Deterministic exponential Heun second order method in data prediction (x0) and logSNR time.""" + return sample_seeds_2(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=0.0, s_noise=0.0, noise_sampler=None, r=1.0, solver_type=solver_type) + + +@torch.no_grad() +def sample_exp_heun_2_x0_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type="phi_2"): + """Stochastic exponential Heun second order method in data prediction (x0) and logSNR time.""" + return sample_seeds_2(model, x, sigmas, extra_args=extra_args, callback=callback, 
disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, r=1.0, solver_type=solver_type) + @torch.no_grad() def sample_seeds_3(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r_1=1./3, r_2=2./3): diff --git a/comfy/samplers.py b/comfy/samplers.py index fa4640842..8340d376c 100755 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -720,7 +720,7 @@ class Sampler: sigma = float(sigmas[0]) return math.isclose(max_sigma, sigma, rel_tol=1e-05) or sigma > max_sigma -KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_cfg_pp", "heun", "heunpp2","dpm_2", "dpm_2_ancestral", +KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_cfg_pp", "heun", "heunpp2", "exp_heun_2_x0", "exp_heun_2_x0_sde", "dpm_2", "dpm_2_ancestral", "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_2s_ancestral_cfg_pp", "dpmpp_sde", "dpmpp_sde_gpu", "dpmpp_2m", "dpmpp_2m_cfg_pp", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_2m_sde_heun", "dpmpp_2m_sde_heun_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm", "ipndm", "ipndm_v", "deis", "res_multistep", "res_multistep_cfg_pp", "res_multistep_ancestral", "res_multistep_ancestral_cfg_pp", diff --git a/comfy_extras/nodes_custom_sampler.py b/comfy_extras/nodes_custom_sampler.py index 71ea4e9ec..7ee4caac1 100644 --- a/comfy_extras/nodes_custom_sampler.py +++ b/comfy_extras/nodes_custom_sampler.py @@ -671,7 +671,16 @@ class SamplerSEEDS2(io.ComfyNode): io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False, tooltip="SDE noise multiplier"), io.Float.Input("r", default=0.5, min=0.01, max=1.0, step=0.01, round=False, tooltip="Relative step size for the intermediate stage (c2 node)"), ], - outputs=[io.Sampler.Output()] + outputs=[io.Sampler.Output()], + description=( + "This sampler node can represent multiple samplers:\n\n" + "seeds_2\n" + "- default setting\n\n" + "exp_heun_2_x0\n" + "- solver_type=phi_2, r=1.0, eta=0.0\n\n" + "exp_heun_2_x0_sde\n" + "- solver_type=phi_2, r=1.0, eta=1.0, s_noise=1.0" + ) ) @classmethod From 3a5f239cb622d7d8b1706d0b63c469dfef2eaf73 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 17 Dec 2025 03:46:11 -0500 Subject: [PATCH 059/148] ComfyUI v0.5.0 --- comfyui_version.py | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/comfyui_version.py b/comfyui_version.py index 2f083edaf..5edf270e7 100644 --- a/comfyui_version.py +++ b/comfyui_version.py @@ -1,3 +1,3 @@ # This file is automatically generated by the build process when version is # updated in pyproject.toml. 
-__version__ = "0.4.0" +__version__ = "0.5.0" diff --git a/pyproject.toml b/pyproject.toml index e4d3d616a..c402f278c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ComfyUI" -version = "0.4.0" +version = "0.5.0" readme = "README.md" license = { file = "LICENSE" } requires-python = ">=3.9" From 887143854bb2ae1e0f975e4461f376844a1628c8 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Wed, 17 Dec 2025 19:43:41 +0200 Subject: [PATCH 060/148] feat(api-nodes): add GPT-Image-1.5 (#11368) --- comfy_api_nodes/apis/openai_api.py | 52 +++++++ comfy_api_nodes/nodes_openai.py | 209 +++++++++++++++------------- comfy_api_nodes/util/conversions.py | 2 +- 3 files changed, 168 insertions(+), 95 deletions(-) create mode 100644 comfy_api_nodes/apis/openai_api.py diff --git a/comfy_api_nodes/apis/openai_api.py b/comfy_api_nodes/apis/openai_api.py new file mode 100644 index 000000000..ae5bb2673 --- /dev/null +++ b/comfy_api_nodes/apis/openai_api.py @@ -0,0 +1,52 @@ +from pydantic import BaseModel, Field + + +class Datum2(BaseModel): + b64_json: str | None = Field(None, description="Base64 encoded image data") + revised_prompt: str | None = Field(None, description="Revised prompt") + url: str | None = Field(None, description="URL of the image") + + +class InputTokensDetails(BaseModel): + image_tokens: int | None = None + text_tokens: int | None = None + + +class Usage(BaseModel): + input_tokens: int | None = None + input_tokens_details: InputTokensDetails | None = None + output_tokens: int | None = None + total_tokens: int | None = None + + +class OpenAIImageGenerationResponse(BaseModel): + data: list[Datum2] | None = None + usage: Usage | None = None + + +class OpenAIImageEditRequest(BaseModel): + background: str | None = Field(None, description="Background transparency") + model: str = Field(...) + moderation: str | None = Field(None) + n: int | None = Field(None, description="The number of images to generate") + output_compression: int | None = Field(None, description="Compression level for JPEG or WebP (0-100)") + output_format: str | None = Field(None) + prompt: str = Field(...) + quality: str | None = Field(None, description="Size of the image (e.g., 1024x1024, 1536x1024, auto)") + size: str | None = Field(None, description="Size of the output image") + + +class OpenAIImageGenerationRequest(BaseModel): + background: str | None = Field(None, description="Background transparency") + model: str | None = Field(None) + moderation: str | None = Field(None) + n: int | None = Field( + None, + description="The number of images to generate.", + ) + output_compression: int | None = Field(None, description="Compression level for JPEG or WebP (0-100)") + output_format: str | None = Field(None) + prompt: str = Field(...) 
+ quality: str | None = Field(None, description="The quality of the generated image") + size: str | None = Field(None, description="Size of the image (e.g., 1024x1024, 1536x1024, auto)") + style: str | None = Field(None, description="Style of the image (only for dall-e-3)") diff --git a/comfy_api_nodes/nodes_openai.py b/comfy_api_nodes/nodes_openai.py index c8da5464b..a6205a34f 100644 --- a/comfy_api_nodes/nodes_openai.py +++ b/comfy_api_nodes/nodes_openai.py @@ -1,46 +1,45 @@ -from io import BytesIO +import base64 import os from enum import Enum -from inspect import cleandoc +from io import BytesIO + import numpy as np import torch from PIL import Image -import folder_paths -import base64 -from comfy_api.latest import IO, ComfyExtension from typing_extensions import override - +import folder_paths +from comfy_api.latest import IO, ComfyExtension, Input from comfy_api_nodes.apis import ( - OpenAIImageGenerationRequest, - OpenAIImageEditRequest, - OpenAIImageGenerationResponse, - OpenAICreateResponse, - OpenAIResponse, CreateModelResponseProperties, - Item, - OutputContent, - InputImageContent, Detail, - InputTextContent, - InputMessage, - InputMessageContentList, InputContent, InputFileContent, + InputImageContent, + InputMessage, + InputMessageContentList, + InputTextContent, + Item, + OpenAICreateResponse, + OpenAIResponse, + OutputContent, +) +from comfy_api_nodes.apis.openai_api import ( + OpenAIImageEditRequest, + OpenAIImageGenerationRequest, + OpenAIImageGenerationResponse, ) - from comfy_api_nodes.util import ( - downscale_image_tensor, - download_url_to_bytesio, - validate_string, - tensor_to_base64_string, ApiEndpoint, - sync_op, + download_url_to_bytesio, + downscale_image_tensor, poll_op, + sync_op, + tensor_to_base64_string, text_filepath_to_data_uri, + validate_string, ) - RESPONSES_ENDPOINT = "/proxy/openai/v1/responses" STARTING_POINT_ID_PATTERN = r"" @@ -98,9 +97,6 @@ async def validate_and_cast_response(response, timeout: int = None) -> torch.Ten class OpenAIDalle2(IO.ComfyNode): - """ - Generates images synchronously via OpenAI's DALL·E 2 endpoint. - """ @classmethod def define_schema(cls): @@ -108,7 +104,7 @@ class OpenAIDalle2(IO.ComfyNode): node_id="OpenAIDalle2", display_name="OpenAI DALL·E 2", category="api node/image/OpenAI", - description=cleandoc(cls.__doc__ or ""), + description="Generates images synchronously via OpenAI's DALL·E 2 endpoint.", inputs=[ IO.String.Input( "prompt", @@ -234,9 +230,6 @@ class OpenAIDalle2(IO.ComfyNode): class OpenAIDalle3(IO.ComfyNode): - """ - Generates images synchronously via OpenAI's DALL·E 3 endpoint. 
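# Worked example (hypothetical token counts) for the calculate_tokens_price_image_1 /
# calculate_tokens_price_image_1_5 helpers added earlier in this patch: both convert
# per-million-token rates into a price (presumably USD) from the usage block of the
# response. The _price helper below is a stand-alone restatement for illustration only.
def _price(input_tokens: int, output_tokens: int, in_rate: float, out_rate: float) -> float:
    return (input_tokens * in_rate + output_tokens * out_rate) / 1_000_000.0

print(_price(1000, 4000, 10.0, 40.0))  # gpt-image-1   -> 0.17
print(_price(1000, 4000, 8.0, 32.0))   # gpt-image-1.5 -> 0.136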
- """ @classmethod def define_schema(cls): @@ -244,7 +237,7 @@ class OpenAIDalle3(IO.ComfyNode): node_id="OpenAIDalle3", display_name="OpenAI DALL·E 3", category="api node/image/OpenAI", - description=cleandoc(cls.__doc__ or ""), + description="Generates images synchronously via OpenAI's DALL·E 3 endpoint.", inputs=[ IO.String.Input( "prompt", @@ -326,10 +319,16 @@ class OpenAIDalle3(IO.ComfyNode): return IO.NodeOutput(await validate_and_cast_response(response)) +def calculate_tokens_price_image_1(response: OpenAIImageGenerationResponse) -> float | None: + # https://platform.openai.com/docs/pricing + return ((response.usage.input_tokens * 10.0) + (response.usage.output_tokens * 40.0)) / 1_000_000.0 + + +def calculate_tokens_price_image_1_5(response: OpenAIImageGenerationResponse) -> float | None: + return ((response.usage.input_tokens * 8.0) + (response.usage.output_tokens * 32.0)) / 1_000_000.0 + + class OpenAIGPTImage1(IO.ComfyNode): - """ - Generates images synchronously via OpenAI's GPT Image 1 endpoint. - """ @classmethod def define_schema(cls): @@ -337,13 +336,13 @@ class OpenAIGPTImage1(IO.ComfyNode): node_id="OpenAIGPTImage1", display_name="OpenAI GPT Image 1", category="api node/image/OpenAI", - description=cleandoc(cls.__doc__ or ""), + description="Generates images synchronously via OpenAI's GPT Image 1 endpoint.", inputs=[ IO.String.Input( "prompt", default="", multiline=True, - tooltip="Text prompt for GPT Image 1", + tooltip="Text prompt for GPT Image", ), IO.Int.Input( "seed", @@ -365,8 +364,8 @@ class OpenAIGPTImage1(IO.ComfyNode): ), IO.Combo.Input( "background", - default="opaque", - options=["opaque", "transparent"], + default="auto", + options=["auto", "opaque", "transparent"], tooltip="Return image with or without background", optional=True, ), @@ -397,6 +396,11 @@ class OpenAIGPTImage1(IO.ComfyNode): tooltip="Optional mask for inpainting (white areas will be replaced)", optional=True, ), + IO.Combo.Input( + "model", + options=["gpt-image-1", "gpt-image-1.5"], + optional=True, + ), ], outputs=[ IO.Image.Output(), @@ -412,32 +416,34 @@ class OpenAIGPTImage1(IO.ComfyNode): @classmethod async def execute( cls, - prompt, - seed=0, - quality="low", - background="opaque", - image=None, - mask=None, - n=1, - size="1024x1024", + prompt: str, + seed: int = 0, + quality: str = "low", + background: str = "opaque", + image: Input.Image | None = None, + mask: Input.Image | None = None, + n: int = 1, + size: str = "1024x1024", + model: str = "gpt-image-1", ) -> IO.NodeOutput: validate_string(prompt, strip_whitespace=False) - model = "gpt-image-1" - path = "/proxy/openai/images/generations" - content_type = "application/json" - request_class = OpenAIImageGenerationRequest - files = [] + + if mask is not None and image is None: + raise ValueError("Cannot use a mask without an input image") + + if model == "gpt-image-1": + price_extractor = calculate_tokens_price_image_1 + elif model == "gpt-image-1.5": + price_extractor = calculate_tokens_price_image_1_5 + else: + raise ValueError(f"Unknown model: {model}") if image is not None: - path = "/proxy/openai/images/edits" - request_class = OpenAIImageEditRequest - content_type = "multipart/form-data" - + files = [] batch_size = image.shape[0] - for i in range(batch_size): - single_image = image[i : i + 1] - scaled_image = downscale_image_tensor(single_image).squeeze() + single_image = image[i: i + 1] + scaled_image = downscale_image_tensor(single_image, total_pixels=2048*2048).squeeze() image_np = (scaled_image.numpy() * 
255).astype(np.uint8) img = Image.fromarray(image_np) @@ -450,44 +456,59 @@ class OpenAIGPTImage1(IO.ComfyNode): else: files.append(("image[]", (f"image_{i}.png", img_byte_arr, "image/png"))) - if mask is not None: - if image is None: - raise Exception("Cannot use a mask without an input image") - if image.shape[0] != 1: - raise Exception("Cannot use a mask with multiple image") - if mask.shape[1:] != image.shape[1:-1]: - raise Exception("Mask and Image must be the same size") - batch, height, width = mask.shape - rgba_mask = torch.zeros(height, width, 4, device="cpu") - rgba_mask[:, :, 3] = 1 - mask.squeeze().cpu() + if mask is not None: + if image.shape[0] != 1: + raise Exception("Cannot use a mask with multiple image") + if mask.shape[1:] != image.shape[1:-1]: + raise Exception("Mask and Image must be the same size") + _, height, width = mask.shape + rgba_mask = torch.zeros(height, width, 4, device="cpu") + rgba_mask[:, :, 3] = 1 - mask.squeeze().cpu() - scaled_mask = downscale_image_tensor(rgba_mask.unsqueeze(0)).squeeze() + scaled_mask = downscale_image_tensor(rgba_mask.unsqueeze(0), total_pixels=2048*2048).squeeze() - mask_np = (scaled_mask.numpy() * 255).astype(np.uint8) - mask_img = Image.fromarray(mask_np) - mask_img_byte_arr = BytesIO() - mask_img.save(mask_img_byte_arr, format="PNG") - mask_img_byte_arr.seek(0) - files.append(("mask", ("mask.png", mask_img_byte_arr, "image/png"))) - - # Build the operation - response = await sync_op( - cls, - ApiEndpoint(path=path, method="POST"), - response_model=OpenAIImageGenerationResponse, - data=request_class( - model=model, - prompt=prompt, - quality=quality, - background=background, - n=n, - seed=seed, - size=size, - ), - files=files if files else None, - content_type=content_type, - ) + mask_np = (scaled_mask.numpy() * 255).astype(np.uint8) + mask_img = Image.fromarray(mask_np) + mask_img_byte_arr = BytesIO() + mask_img.save(mask_img_byte_arr, format="PNG") + mask_img_byte_arr.seek(0) + files.append(("mask", ("mask.png", mask_img_byte_arr, "image/png"))) + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/openai/images/edits", method="POST"), + response_model=OpenAIImageGenerationResponse, + data=OpenAIImageEditRequest( + model=model, + prompt=prompt, + quality=quality, + background=background, + n=n, + seed=seed, + size=size, + moderation="low", + ), + content_type="multipart/form-data", + files=files, + price_extractor=price_extractor, + ) + else: + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/openai/images/generations", method="POST"), + response_model=OpenAIImageGenerationResponse, + data=OpenAIImageGenerationRequest( + model=model, + prompt=prompt, + quality=quality, + background=background, + n=n, + seed=seed, + size=size, + moderation="low", + ), + price_extractor=price_extractor, + ) return IO.NodeOutput(await validate_and_cast_response(response)) diff --git a/comfy_api_nodes/util/conversions.py b/comfy_api_nodes/util/conversions.py index c57457580..d64239c86 100644 --- a/comfy_api_nodes/util/conversions.py +++ b/comfy_api_nodes/util/conversions.py @@ -129,7 +129,7 @@ def pil_to_bytesio(img: Image.Image, mime_type: str = "image/png") -> BytesIO: return img_byte_arr -def downscale_image_tensor(image, total_pixels=1536 * 1024) -> torch.Tensor: +def downscale_image_tensor(image: torch.Tensor, total_pixels: int = 1536 * 1024) -> torch.Tensor: """Downscale input image tensor to roughly the specified total pixels.""" samples = image.movedim(-1, 1) total = int(total_pixels) From 
c08f97f34407a1bc6cc8d1447d6c12893399acba Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Wed, 17 Dec 2025 20:24:25 +0200 Subject: [PATCH 061/148] fix regression in V3 nodes processing (#11375) --- comfy_api/latest/_io.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/comfy_api/latest/_io.py b/comfy_api/latest/_io.py index 2b634d172..4b14e5ded 100644 --- a/comfy_api/latest/_io.py +++ b/comfy_api/latest/_io.py @@ -1556,12 +1556,12 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal): @final @classmethod - def PREPARE_CLASS_CLONE(cls, v3_data: V3Data) -> type[ComfyNode]: + def PREPARE_CLASS_CLONE(cls, v3_data: V3Data | None) -> type[ComfyNode]: """Creates clone of real node class to prevent monkey-patching.""" c_type: type[ComfyNode] = cls if is_class(cls) else type(cls) type_clone: type[ComfyNode] = shallow_clone_class(c_type) # set hidden - type_clone.hidden = HiddenHolder.from_dict(v3_data["hidden_inputs"]) + type_clone.hidden = HiddenHolder.from_dict(v3_data["hidden_inputs"] if v3_data else None) return type_clone @final From 5d9ad0c6bf177095aea5026cd872b1faf873669b Mon Sep 17 00:00:00 2001 From: chaObserv <154517000+chaObserv@users.noreply.github.com> Date: Thu, 18 Dec 2025 02:57:40 +0800 Subject: [PATCH 062/148] Fix the last step with non-zero sigma in sa_solver (#11380) --- comfy/k_diffusion/sampling.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py index c004b3b47..1ba9edad7 100644 --- a/comfy/k_diffusion/sampling.py +++ b/comfy/k_diffusion/sampling.py @@ -1776,7 +1776,7 @@ def sample_sa_solver(model, x, sigmas, extra_args=None, callback=None, disable=F # Predictor if sigmas[i + 1] == 0: # Denoising step - x = denoised + x_pred = denoised else: tau_t = tau_func(sigmas[i + 1]) curr_lambdas = lambdas[i - predictor_order_used + 1:i + 1] @@ -1797,7 +1797,7 @@ def sample_sa_solver(model, x, sigmas, extra_args=None, callback=None, disable=F if tau_t > 0 and s_noise > 0: noise = noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * tau_t ** 2 * h).expm1().neg().sqrt() * s_noise x_pred = x_pred + noise - return x + return x_pred @torch.no_grad() From 16d85ea13342cebc8349a95236c94bde5ac3cb2a Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 17 Dec 2025 16:43:18 -0800 Subject: [PATCH 063/148] Better handle torch being imported by prestartup nodes. 
(#11383) --- main.py | 66 ++++++++++++++++++++++++++++----------------------------- 1 file changed, 32 insertions(+), 34 deletions(-) diff --git a/main.py b/main.py index 0d02a087b..0e07a95da 100644 --- a/main.py +++ b/main.py @@ -23,6 +23,38 @@ if __name__ == "__main__": setup_logger(log_level=args.verbose, use_stdout=args.log_stdout) +if os.name == "nt": + os.environ['MIMALLOC_PURGE_DELAY'] = '0' + +if __name__ == "__main__": + os.environ['TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL'] = '1' + if args.default_device is not None: + default_dev = args.default_device + devices = list(range(32)) + devices.remove(default_dev) + devices.insert(0, default_dev) + devices = ','.join(map(str, devices)) + os.environ['CUDA_VISIBLE_DEVICES'] = str(devices) + os.environ['HIP_VISIBLE_DEVICES'] = str(devices) + + if args.cuda_device is not None: + os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda_device) + os.environ['HIP_VISIBLE_DEVICES'] = str(args.cuda_device) + os.environ["ASCEND_RT_VISIBLE_DEVICES"] = str(args.cuda_device) + logging.info("Set cuda device to: {}".format(args.cuda_device)) + + if args.oneapi_device_selector is not None: + os.environ['ONEAPI_DEVICE_SELECTOR'] = args.oneapi_device_selector + logging.info("Set oneapi device selector to: {}".format(args.oneapi_device_selector)) + + if args.deterministic: + if 'CUBLAS_WORKSPACE_CONFIG' not in os.environ: + os.environ['CUBLAS_WORKSPACE_CONFIG'] = ":4096:8" + + import cuda_malloc + if "rocm" in cuda_malloc.get_torch_version_noimport(): + os.environ['OCL_SET_SVM_SIZE'] = '262144' # set at the request of AMD + def handle_comfyui_manager_unavailable(): if not args.windows_standalone_build: @@ -137,40 +169,6 @@ import shutil import threading import gc - -if os.name == "nt": - os.environ['MIMALLOC_PURGE_DELAY'] = '0' - -if __name__ == "__main__": - os.environ['TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL'] = '1' - if args.default_device is not None: - default_dev = args.default_device - devices = list(range(32)) - devices.remove(default_dev) - devices.insert(0, default_dev) - devices = ','.join(map(str, devices)) - os.environ['CUDA_VISIBLE_DEVICES'] = str(devices) - os.environ['HIP_VISIBLE_DEVICES'] = str(devices) - - if args.cuda_device is not None: - os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda_device) - os.environ['HIP_VISIBLE_DEVICES'] = str(args.cuda_device) - os.environ["ASCEND_RT_VISIBLE_DEVICES"] = str(args.cuda_device) - logging.info("Set cuda device to: {}".format(args.cuda_device)) - - if args.oneapi_device_selector is not None: - os.environ['ONEAPI_DEVICE_SELECTOR'] = args.oneapi_device_selector - logging.info("Set oneapi device selector to: {}".format(args.oneapi_device_selector)) - - if args.deterministic: - if 'CUBLAS_WORKSPACE_CONFIG' not in os.environ: - os.environ['CUBLAS_WORKSPACE_CONFIG'] = ":4096:8" - - import cuda_malloc - if "rocm" in cuda_malloc.get_torch_version_noimport(): - os.environ['OCL_SET_SVM_SIZE'] = '262144' # set at the request of AMD - - if 'torch' in sys.modules: logging.warning("WARNING: Potential Error in code: Torch already imported, torch should never be imported before this point.") From ba6080bbab070934ea6e870c5fc30dbf702eb445 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 17 Dec 2025 21:04:50 -0500 Subject: [PATCH 064/148] ComfyUI v0.5.1 --- comfyui_version.py | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/comfyui_version.py b/comfyui_version.py index 5edf270e7..b45309198 100644 --- a/comfyui_version.py +++ b/comfyui_version.py @@ -1,3 +1,3 @@ # This 
file is automatically generated by the build process when version is # updated in pyproject.toml. -__version__ = "0.5.0" +__version__ = "0.5.1" diff --git a/pyproject.toml b/pyproject.toml index c402f278c..3a6960811 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ComfyUI" -version = "0.5.0" +version = "0.5.1" readme = "README.md" license = { file = "LICENSE" } requires-python = ">=3.9" From 86dbb89fc95f0cb652ae5d6cb923f641a58e295d Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Thu, 18 Dec 2025 11:15:27 +0800 Subject: [PATCH 065/148] Resolution bucketing and Trainer implementation refactoring (#11117) --- comfy/sampler_helpers.py | 9 +- comfy_extras/nodes_dataset.py | 96 ++- comfy_extras/nodes_post_processing.py | 11 +- comfy_extras/nodes_train.py | 854 +++++++++++++++++++------- 4 files changed, 738 insertions(+), 232 deletions(-) diff --git a/comfy/sampler_helpers.py b/comfy/sampler_helpers.py index e46971afb..e158e8a84 100644 --- a/comfy/sampler_helpers.py +++ b/comfy/sampler_helpers.py @@ -122,20 +122,21 @@ def estimate_memory(model, noise_shape, conds): minimum_memory_required = model.model.memory_required([noise_shape[0]] + list(noise_shape[1:]), cond_shapes=cond_shapes_min) return memory_required, minimum_memory_required -def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None): +def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, skip_load_model=False): executor = comfy.patcher_extension.WrapperExecutor.new_executor( _prepare_sampling, comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.PREPARE_SAMPLING, model_options, is_model_options=True) ) - return executor.execute(model, noise_shape, conds, model_options=model_options) + return executor.execute(model, noise_shape, conds, model_options=model_options, skip_load_model=skip_load_model) -def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None): +def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, skip_load_model=False): real_model: BaseModel = None models, inference_memory = get_additional_models(conds, model.model_dtype()) models += get_additional_models_from_model_options(model_options) models += model.get_nested_additional_models() # TODO: does this require inference_memory update? 
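# Note on the new skip_load_model flag introduced in this hunk: when it is True, the
# main model is left out of the load_models_gpu call just below, so only the additional
# models gathered above are loaded here. The TrainGuider added later in this patch passes
# skip_load_model=True because the LoRA training node manages loading of the patched
# model itself; the call shape used there is:
#     prepare_sampling(model_patcher, noise.shape, conds, model_options, skip_load_model=True)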
memory_required, minimum_memory_required = estimate_memory(model, noise_shape, conds) - comfy.model_management.load_models_gpu([model] + models, memory_required=memory_required + inference_memory, minimum_memory_required=minimum_memory_required + inference_memory) + models_list = [model] if not skip_load_model else [] + comfy.model_management.load_models_gpu(models_list + models, memory_required=memory_required + inference_memory, minimum_memory_required=minimum_memory_required + inference_memory) real_model = model.model return real_model, conds, models diff --git a/comfy_extras/nodes_dataset.py b/comfy_extras/nodes_dataset.py index 4789d7d53..513aecf3a 100644 --- a/comfy_extras/nodes_dataset.py +++ b/comfy_extras/nodes_dataset.py @@ -1125,6 +1125,99 @@ class MergeTextListsNode(TextProcessingNode): # ========== Training Dataset Nodes ========== +class ResolutionBucket(io.ComfyNode): + """Bucket latents and conditions by resolution for efficient batch training.""" + + @classmethod + def define_schema(cls): + return io.Schema( + node_id="ResolutionBucket", + display_name="Resolution Bucket", + category="dataset", + is_experimental=True, + is_input_list=True, + inputs=[ + io.Latent.Input( + "latents", + tooltip="List of latent dicts to bucket by resolution.", + ), + io.Conditioning.Input( + "conditioning", + tooltip="List of conditioning lists (must match latents length).", + ), + ], + outputs=[ + io.Latent.Output( + display_name="latents", + is_output_list=True, + tooltip="List of batched latent dicts, one per resolution bucket.", + ), + io.Conditioning.Output( + display_name="conditioning", + is_output_list=True, + tooltip="List of condition lists, one per resolution bucket.", + ), + ], + ) + + @classmethod + def execute(cls, latents, conditioning): + # latents: list[{"samples": tensor}] where tensor is (B, C, H, W), typically B=1 + # conditioning: list[list[cond]] + + # Validate lengths match + if len(latents) != len(conditioning): + raise ValueError( + f"Number of latents ({len(latents)}) does not match number of conditions ({len(conditioning)})." 
+ ) + + # Flatten latents and conditions to individual samples + flat_latents = [] # list of (C, H, W) tensors + flat_conditions = [] # list of condition lists + + for latent_dict, cond in zip(latents, conditioning): + samples = latent_dict["samples"] # (B, C, H, W) + batch_size = samples.shape[0] + + # cond is a list of conditions with length == batch_size + for i in range(batch_size): + flat_latents.append(samples[i]) # (C, H, W) + flat_conditions.append(cond[i]) # single condition + + # Group by resolution (H, W) + buckets = {} # (H, W) -> {"latents": list, "conditions": list} + + for latent, cond in zip(flat_latents, flat_conditions): + # latent shape is (..., H, W) (B, C, H, W) or (B, T, C, H ,W) + h, w = latent.shape[-2], latent.shape[-1] + key = (h, w) + + if key not in buckets: + buckets[key] = {"latents": [], "conditions": []} + + buckets[key]["latents"].append(latent) + buckets[key]["conditions"].append(cond) + + # Convert buckets to output format + output_latents = [] # list[{"samples": tensor}] where tensor is (Bi, ..., H, W) + output_conditions = [] # list[list[cond]] where each inner list has Bi conditions + + for (h, w), bucket_data in buckets.items(): + # Stack latents into batch: list of (..., H, W) -> (Bi, ..., H, W) + stacked_latents = torch.stack(bucket_data["latents"], dim=0) + output_latents.append({"samples": stacked_latents}) + + # Conditions stay as list of condition lists + output_conditions.append(bucket_data["conditions"]) + + logging.info( + f"Resolution bucket ({h}x{w}): {len(bucket_data['latents'])} samples" + ) + + logging.info(f"Created {len(buckets)} resolution buckets from {len(flat_latents)} samples") + return io.NodeOutput(output_latents, output_conditions) + + class MakeTrainingDataset(io.ComfyNode): """Encode images with VAE and texts with CLIP to create a training dataset.""" @@ -1373,7 +1466,7 @@ class LoadTrainingDataset(io.ComfyNode): shard_path = os.path.join(dataset_dir, shard_file) with open(shard_path, "rb") as f: - shard_data = torch.load(f, weights_only=True) + shard_data = torch.load(f) all_latents.extend(shard_data["latents"]) all_conditioning.extend(shard_data["conditioning"]) @@ -1425,6 +1518,7 @@ class DatasetExtension(ComfyExtension): MakeTrainingDataset, SaveTrainingDataset, LoadTrainingDataset, + ResolutionBucket, ] diff --git a/comfy_extras/nodes_post_processing.py b/comfy_extras/nodes_post_processing.py index 34c388a5a..ca2cdeb50 100644 --- a/comfy_extras/nodes_post_processing.py +++ b/comfy_extras/nodes_post_processing.py @@ -221,6 +221,7 @@ class ImageScaleToTotalPixels(io.ComfyNode): io.Image.Input("image"), io.Combo.Input("upscale_method", options=cls.upscale_methods), io.Float.Input("megapixels", default=1.0, min=0.01, max=16.0, step=0.01), + io.Int.Input("resolution_steps", default=1, min=1, max=256), ], outputs=[ io.Image.Output(), @@ -228,15 +229,15 @@ class ImageScaleToTotalPixels(io.ComfyNode): ) @classmethod - def execute(cls, image, upscale_method, megapixels) -> io.NodeOutput: + def execute(cls, image, upscale_method, megapixels, resolution_steps) -> io.NodeOutput: samples = image.movedim(-1,1) - total = int(megapixels * 1024 * 1024) + total = megapixels * 1024 * 1024 scale_by = math.sqrt(total / (samples.shape[3] * samples.shape[2])) - width = round(samples.shape[3] * scale_by) - height = round(samples.shape[2] * scale_by) + width = round(samples.shape[3] * scale_by / resolution_steps) * resolution_steps + height = round(samples.shape[2] * scale_by / resolution_steps) * resolution_steps - s = 
comfy.utils.common_upscale(samples, width, height, upscale_method, "disabled") + s = comfy.utils.common_upscale(samples, int(width), int(height), upscale_method, "disabled") s = s.movedim(1,-1) return io.NodeOutput(s) diff --git a/comfy_extras/nodes_train.py b/comfy_extras/nodes_train.py index 19b8baaf4..88bc8c8e8 100644 --- a/comfy_extras/nodes_train.py +++ b/comfy_extras/nodes_train.py @@ -10,6 +10,7 @@ from PIL import Image, ImageDraw, ImageFont from typing_extensions import override import comfy.samplers +import comfy.sampler_helpers import comfy.sd import comfy.utils import comfy.model_management @@ -21,6 +22,68 @@ from comfy_api.latest import ComfyExtension, io, ui from comfy.utils import ProgressBar +class TrainGuider(comfy_extras.nodes_custom_sampler.Guider_Basic): + """ + CFGGuider with modifications for training specific logic + """ + def outer_sample( + self, + noise, + latent_image, + sampler, + sigmas, + denoise_mask=None, + callback=None, + disable_pbar=False, + seed=None, + latent_shapes=None, + ): + self.inner_model, self.conds, self.loaded_models = ( + comfy.sampler_helpers.prepare_sampling( + self.model_patcher, + noise.shape, + self.conds, + self.model_options, + skip_load_model=True, # skip load model as we manage it in TrainLoraNode.execute() + ) + ) + device = self.model_patcher.load_device + + if denoise_mask is not None: + denoise_mask = comfy.sampler_helpers.prepare_mask( + denoise_mask, noise.shape, device + ) + + noise = noise.to(device) + latent_image = latent_image.to(device) + sigmas = sigmas.to(device) + comfy.samplers.cast_to_load_options( + self.model_options, device=device, dtype=self.model_patcher.model_dtype() + ) + + try: + self.model_patcher.pre_run() + output = self.inner_sample( + noise, + latent_image, + device, + sampler, + sigmas, + denoise_mask, + callback, + disable_pbar, + seed, + latent_shapes=latent_shapes, + ) + finally: + self.model_patcher.cleanup() + + comfy.sampler_helpers.cleanup_models(self.conds, self.loaded_models) + del self.inner_model + del self.loaded_models + return output + + def make_batch_extra_option_dict(d, indicies, full_size=None): new_dict = {} for k, v in d.items(): @@ -65,6 +128,7 @@ class TrainSampler(comfy.samplers.Sampler): seed=0, training_dtype=torch.bfloat16, real_dataset=None, + bucket_latents=None, ): self.loss_fn = loss_fn self.optimizer = optimizer @@ -75,6 +139,28 @@ class TrainSampler(comfy.samplers.Sampler): self.seed = seed self.training_dtype = training_dtype self.real_dataset: list[torch.Tensor] | None = real_dataset + # Bucket mode data + self.bucket_latents: list[torch.Tensor] | None = ( + bucket_latents # list of (Bi, C, Hi, Wi) + ) + # Precompute bucket offsets and weights for sampling + if bucket_latents is not None: + self._init_bucket_data(bucket_latents) + else: + self.bucket_offsets = None + self.bucket_weights = None + self.num_images = None + + def _init_bucket_data(self, bucket_latents): + """Initialize bucket offsets and weights for sampling.""" + self.bucket_offsets = [0] + bucket_sizes = [] + for lat in bucket_latents: + bucket_sizes.append(lat.shape[0]) + self.bucket_offsets.append(self.bucket_offsets[-1] + lat.shape[0]) + self.num_images = self.bucket_offsets[-1] + # Weights for sampling buckets proportional to their size + self.bucket_weights = torch.tensor(bucket_sizes, dtype=torch.float32) def fwd_bwd( self, @@ -115,6 +201,108 @@ class TrainSampler(comfy.samplers.Sampler): bwd_loss.backward() return loss + def _generate_batch_sigmas(self, model_wrap, batch_size, device): + 
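The bucket bookkeeping introduced above reduces to two pieces of state: cumulative offsets, so per-bucket indices can be mapped back into the flattened conditioning list, and size-proportional weights for choosing which bucket to draw from. A standalone sketch of that bookkeeping plus one weighted draw (dummy tensors, batch size of 4 assumed):

import torch

bucket_latents = [torch.randn(6, 4, 64, 64), torch.randn(2, 4, 96, 64)]

offsets, sizes = [0], []
for lat in bucket_latents:
    sizes.append(lat.shape[0])
    offsets.append(offsets[-1] + lat.shape[0])
weights = torch.tensor(sizes, dtype=torch.float32)        # bigger buckets are drawn more often

bucket_idx = torch.multinomial(weights, 1).item()          # pick a bucket proportional to its size
bucket = bucket_latents[bucket_idx]
batch_size = min(4, bucket.shape[0])                       # never ask for more than the bucket holds
rel = torch.randperm(bucket.shape[0])[:batch_size]
absolute = [offsets[bucket_idx] + i for i in rel.tolist()] # indices into the flattened cond list
batch = bucket[rel]                                        # (batch_size, C, H, W), uniform resolution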
"""Generate random sigma values for a batch.""" + batch_sigmas = [ + model_wrap.inner_model.model_sampling.percent_to_sigma( + torch.rand((1,)).item() + ) + for _ in range(batch_size) + ] + return torch.tensor(batch_sigmas).to(device) + + def _train_step_bucket_mode(self, model_wrap, cond, extra_args, noisegen, latent_image, pbar): + """Execute one training step in bucket mode.""" + # Sample bucket (weighted by size), then sample batch from bucket + bucket_idx = torch.multinomial(self.bucket_weights, 1).item() + bucket_latent = self.bucket_latents[bucket_idx] # (Bi, C, Hi, Wi) + bucket_size = bucket_latent.shape[0] + bucket_offset = self.bucket_offsets[bucket_idx] + + # Sample indices from this bucket (use all if bucket_size < batch_size) + actual_batch_size = min(self.batch_size, bucket_size) + relative_indices = torch.randperm(bucket_size)[:actual_batch_size].tolist() + # Convert to absolute indices for fwd_bwd (cond is flattened, use absolute index) + absolute_indices = [bucket_offset + idx for idx in relative_indices] + + batch_latent = bucket_latent[relative_indices].to(latent_image) # (actual_batch_size, C, H, W) + batch_noise = noisegen.generate_noise({"samples": batch_latent}).to( + batch_latent.device + ) + batch_sigmas = self._generate_batch_sigmas(model_wrap, actual_batch_size, batch_latent.device) + + loss = self.fwd_bwd( + model_wrap, + batch_sigmas, + batch_noise, + batch_latent, + cond, # Use flattened cond with absolute indices + absolute_indices, + extra_args, + self.num_images, + bwd=True, + ) + if self.loss_callback: + self.loss_callback(loss.item()) + pbar.set_postfix({"loss": f"{loss.item():.4f}", "bucket": bucket_idx}) + + def _train_step_standard_mode(self, model_wrap, cond, extra_args, noisegen, latent_image, dataset_size, pbar): + """Execute one training step in standard (non-bucket, non-multi-res) mode.""" + indicies = torch.randperm(dataset_size)[: self.batch_size].tolist() + batch_latent = torch.stack([latent_image[i] for i in indicies]) + batch_noise = noisegen.generate_noise({"samples": batch_latent}).to( + batch_latent.device + ) + batch_sigmas = self._generate_batch_sigmas(model_wrap, min(self.batch_size, dataset_size), batch_latent.device) + + loss = self.fwd_bwd( + model_wrap, + batch_sigmas, + batch_noise, + batch_latent, + cond, + indicies, + extra_args, + dataset_size, + bwd=True, + ) + if self.loss_callback: + self.loss_callback(loss.item()) + pbar.set_postfix({"loss": f"{loss.item():.4f}"}) + + def _train_step_multires_mode(self, model_wrap, cond, extra_args, noisegen, latent_image, dataset_size, pbar): + """Execute one training step in multi-resolution mode (real_dataset is set).""" + indicies = torch.randperm(dataset_size)[: self.batch_size].tolist() + total_loss = 0 + for index in indicies: + single_latent = self.real_dataset[index].to(latent_image) + batch_noise = noisegen.generate_noise( + {"samples": single_latent} + ).to(single_latent.device) + batch_sigmas = ( + model_wrap.inner_model.model_sampling.percent_to_sigma( + torch.rand((1,)).item() + ) + ) + batch_sigmas = torch.tensor([batch_sigmas]).to(single_latent.device) + loss = self.fwd_bwd( + model_wrap, + batch_sigmas, + batch_noise, + single_latent, + cond, + [index], + extra_args, + dataset_size, + bwd=False, + ) + total_loss += loss + total_loss = total_loss / self.grad_acc / len(indicies) + total_loss.backward() + if self.loss_callback: + self.loss_callback(total_loss.item()) + pbar.set_postfix({"loss": f"{total_loss.item():.4f}"}) + def sample( self, model_wrap, @@ -142,70 +330,18 @@ 
class TrainSampler(comfy.samplers.Sampler): noisegen = comfy_extras.nodes_custom_sampler.Noise_RandomNoise( self.seed + i * 1000 ) - indicies = torch.randperm(dataset_size)[: self.batch_size].tolist() - if self.real_dataset is None: - batch_latent = torch.stack([latent_image[i] for i in indicies]) - batch_noise = noisegen.generate_noise({"samples": batch_latent}).to( - batch_latent.device - ) - batch_sigmas = [ - model_wrap.inner_model.model_sampling.percent_to_sigma( - torch.rand((1,)).item() - ) - for _ in range(min(self.batch_size, dataset_size)) - ] - batch_sigmas = torch.tensor(batch_sigmas).to(batch_latent.device) - - loss = self.fwd_bwd( - model_wrap, - batch_sigmas, - batch_noise, - batch_latent, - cond, - indicies, - extra_args, - dataset_size, - bwd=True, - ) - if self.loss_callback: - self.loss_callback(loss.item()) - pbar.set_postfix({"loss": f"{loss.item():.4f}"}) + if self.bucket_latents is not None: + self._train_step_bucket_mode(model_wrap, cond, extra_args, noisegen, latent_image, pbar) + elif self.real_dataset is None: + self._train_step_standard_mode(model_wrap, cond, extra_args, noisegen, latent_image, dataset_size, pbar) else: - total_loss = 0 - for index in indicies: - single_latent = self.real_dataset[index].to(latent_image) - batch_noise = noisegen.generate_noise( - {"samples": single_latent} - ).to(single_latent.device) - batch_sigmas = ( - model_wrap.inner_model.model_sampling.percent_to_sigma( - torch.rand((1,)).item() - ) - ) - batch_sigmas = torch.tensor([batch_sigmas]).to(single_latent.device) - loss = self.fwd_bwd( - model_wrap, - batch_sigmas, - batch_noise, - single_latent, - cond, - [index], - extra_args, - dataset_size, - bwd=False, - ) - total_loss += loss - total_loss = total_loss / self.grad_acc / len(indicies) - total_loss.backward() - if self.loss_callback: - self.loss_callback(total_loss.item()) - pbar.set_postfix({"loss": f"{total_loss.item():.4f}"}) + self._train_step_multires_mode(model_wrap, cond, extra_args, noisegen, latent_image, dataset_size, pbar) if (i + 1) % self.grad_acc == 0: self.optimizer.step() self.optimizer.zero_grad() - ui_pbar.update(1) + ui_pbar.update(1) torch.cuda.empty_cache() return torch.zeros_like(latent_image) @@ -283,6 +419,364 @@ def unpatch(m): del m.org_forward +def _process_latents_bucket_mode(latents): + """Process latents for bucket mode training. + + Args: + latents: list[{"samples": tensor}] where each tensor is (Bi, C, Hi, Wi) + + Returns: + list of latent tensors + """ + bucket_latents = [] + for latent_dict in latents: + bucket_latents.append(latent_dict["samples"]) # (Bi, C, Hi, Wi) + return bucket_latents + + +def _process_latents_standard_mode(latents): + """Process latents for standard (non-bucket) mode training. + + Args: + latents: list of latent dicts or single latent dict + + Returns: + Processed latents (tensor or list of tensors) + """ + if len(latents) == 1: + return latents[0]["samples"] # Single latent dict + + latent_list = [] + for latent in latents: + latent = latent["samples"] + bs = latent.shape[0] + if bs != 1: + for sub_latent in latent: + latent_list.append(sub_latent[None]) + else: + latent_list.append(latent) + return latent_list + + +def _process_conditioning(positive): + """Process conditioning - either single list or list of lists. 
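After the dispatch refactor, the piece all three modes share is the gradient-accumulation step at the bottom of sample(): gradients accumulate for grad_acc micro-steps before the optimizer steps and clears them. The pattern in isolation (toy model and loss, purely illustrative):

import torch

model = torch.nn.Linear(8, 8)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
grad_acc, total_steps = 4, 12

for i in range(total_steps):
    x = torch.randn(2, 8)
    loss = (model(x) - x).pow(2).mean() / grad_acc   # scale so accumulated grads average out
    loss.backward()                                  # grads add up across micro-steps
    if (i + 1) % grad_acc == 0:                      # same condition as in sample() above
        optimizer.step()
        optimizer.zero_grad()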
+ + Args: + positive: list of conditioning + + Returns: + Flattened conditioning list + """ + if len(positive) == 1: + return positive[0] # Single conditioning list + + # Multiple conditioning lists - flatten + flat_positive = [] + for cond in positive: + if isinstance(cond, list): + flat_positive.extend(cond) + else: + flat_positive.append(cond) + return flat_positive + + +def _prepare_latents_and_count(latents, dtype, bucket_mode): + """Convert latents to dtype and compute image counts. + + Args: + latents: Latents (tensor, list of tensors, or bucket list) + dtype: Target dtype + bucket_mode: Whether bucket mode is enabled + + Returns: + tuple: (processed_latents, num_images, multi_res) + """ + if bucket_mode: + # In bucket mode, latents is list of tensors (Bi, C, Hi, Wi) + latents = [t.to(dtype) for t in latents] + num_buckets = len(latents) + num_images = sum(t.shape[0] for t in latents) + multi_res = False # Not using multi_res path in bucket mode + + logging.info(f"Bucket mode: {num_buckets} buckets, {num_images} total samples") + for i, lat in enumerate(latents): + logging.info(f" Bucket {i}: shape {lat.shape}") + return latents, num_images, multi_res + + # Non-bucket mode + if isinstance(latents, list): + all_shapes = set() + latents = [t.to(dtype) for t in latents] + for latent in latents: + all_shapes.add(latent.shape) + logging.info(f"Latent shapes: {all_shapes}") + if len(all_shapes) > 1: + multi_res = True + else: + multi_res = False + latents = torch.cat(latents, dim=0) + num_images = len(latents) + elif isinstance(latents, torch.Tensor): + latents = latents.to(dtype) + num_images = latents.shape[0] + multi_res = False + else: + logging.error(f"Invalid latents type: {type(latents)}") + num_images = 0 + multi_res = False + + return latents, num_images, multi_res + + +def _validate_and_expand_conditioning(positive, num_images, bucket_mode): + """Validate conditioning count matches image count, expand if needed. + + Args: + positive: Conditioning list + num_images: Number of images + bucket_mode: Whether bucket mode is enabled + + Returns: + Validated/expanded conditioning list + + Raises: + ValueError: If conditioning count doesn't match image count + """ + if bucket_mode: + return positive # Skip validation in bucket mode + + logging.info(f"Total Images: {num_images}, Total Captions: {len(positive)}") + if len(positive) == 1 and num_images > 1: + return positive * num_images + elif len(positive) != num_images: + raise ValueError( + f"Number of positive conditions ({len(positive)}) does not match number of images ({num_images})." + ) + return positive + + +def _load_existing_lora(existing_lora): + """Load existing LoRA weights if provided. + + Args: + existing_lora: LoRA filename or "[None]" + + Returns: + tuple: (existing_weights dict, existing_steps int) + """ + if existing_lora == "[None]": + return {}, 0 + + lora_path = folder_paths.get_full_path_or_raise("loras", existing_lora) + # Extract steps from filename like "trained_lora_10_steps_20250225_203716" + existing_steps = int(existing_lora.split("_steps_")[0].split("_")[-1]) + existing_weights = {} + if lora_path: + existing_weights = comfy.utils.load_torch_file(lora_path) + return existing_weights, existing_steps + + +def _create_weight_adapter( + module, module_name, existing_weights, algorithm, lora_dtype, rank +): + """Create a weight adapter for a module with weight. 
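_load_existing_lora recovers the previously trained step count purely from the filename convention noted in the comment above ("..._<steps>_steps_<timestamp>"). The parsing in isolation, on an illustrative filename:

existing_lora = "trained_lora_10_steps_20250225_203716.safetensors"  # example name only
existing_steps = int(existing_lora.split("_steps_")[0].split("_")[-1])
print(existing_steps)  # -> 10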
+ + Args: + module: The module to create adapter for + module_name: Name of the module + existing_weights: Dict of existing LoRA weights + algorithm: Algorithm name for new adapters + lora_dtype: dtype for LoRA weights + rank: Rank for new LoRA adapters + + Returns: + tuple: (train_adapter, lora_params dict) + """ + key = f"{module_name}.weight" + shape = module.weight.shape + lora_params = {} + + if len(shape) >= 2: + alpha = float(existing_weights.get(f"{key}.alpha", 1.0)) + dora_scale = existing_weights.get(f"{key}.dora_scale", None) + + # Try to load existing adapter + existing_adapter = None + for adapter_cls in adapters: + existing_adapter = adapter_cls.load( + module_name, existing_weights, alpha, dora_scale + ) + if existing_adapter is not None: + break + + if existing_adapter is None: + adapter_cls = adapter_maps[algorithm] + + if existing_adapter is not None: + train_adapter = existing_adapter.to_train().to(lora_dtype) + else: + # Use LoRA with alpha=1.0 by default + train_adapter = adapter_cls.create_train( + module.weight, rank=rank, alpha=1.0 + ).to(lora_dtype) + + for name, parameter in train_adapter.named_parameters(): + lora_params[f"{module_name}.{name}"] = parameter + + return train_adapter.train().requires_grad_(True), lora_params + else: + # 1D weight - use BiasDiff + diff = torch.nn.Parameter( + torch.zeros(module.weight.shape, dtype=lora_dtype, requires_grad=True) + ) + diff_module = BiasDiff(diff).train().requires_grad_(True) + lora_params[f"{module_name}.diff"] = diff + return diff_module, lora_params + + +def _create_bias_adapter(module, module_name, lora_dtype): + """Create a bias adapter for a module with bias. + + Args: + module: The module with bias + module_name: Name of the module + lora_dtype: dtype for LoRA weights + + Returns: + tuple: (bias_module, lora_params dict) + """ + bias = torch.nn.Parameter( + torch.zeros(module.bias.shape, dtype=lora_dtype, requires_grad=True) + ) + bias_module = BiasDiff(bias).train().requires_grad_(True) + lora_params = {f"{module_name}.diff_b": bias} + return bias_module, lora_params + + +def _setup_lora_adapters(mp, existing_weights, algorithm, lora_dtype, rank): + """Setup all LoRA adapters on the model. + + Args: + mp: Model patcher + existing_weights: Dict of existing LoRA weights + algorithm: Algorithm name for new adapters + lora_dtype: dtype for LoRA weights + rank: Rank for new LoRA adapters + + Returns: + tuple: (lora_sd dict, all_weight_adapters list) + """ + lora_sd = {} + all_weight_adapters = [] + + for n, m in mp.model.named_modules(): + if hasattr(m, "weight_function"): + if m.weight is not None: + adapter, params = _create_weight_adapter( + m, n, existing_weights, algorithm, lora_dtype, rank + ) + lora_sd.update(params) + key = f"{n}.weight" + mp.add_weight_wrapper(key, adapter) + all_weight_adapters.append(adapter) + + if hasattr(m, "bias") and m.bias is not None: + bias_adapter, bias_params = _create_bias_adapter(m, n, lora_dtype) + lora_sd.update(bias_params) + key = f"{n}.bias" + mp.add_weight_wrapper(key, bias_adapter) + all_weight_adapters.append(bias_adapter) + + return lora_sd, all_weight_adapters + + +def _create_optimizer(optimizer_name, parameters, learning_rate): + """Create optimizer based on name. 
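The adapter classes referenced here (adapters, adapter_maps, create_train, BiasDiff) are defined elsewhere in ComfyUI, so their internals are not part of this diff. Conceptually, create_train wraps a frozen 2D weight with a trainable low-rank delta; the rough shape of that idea is sketched below. This is not comfy's weight-adapter implementation, only an illustration of what rank and alpha control:

import torch

class LowRankDelta(torch.nn.Module):
    """Illustrative LoRA-style delta: effective weight = frozen base + (alpha / rank) * up @ down."""
    def __init__(self, weight, rank=8, alpha=1.0):
        super().__init__()
        out_dim, in_dim = weight.shape
        self.down = torch.nn.Parameter(torch.randn(rank, in_dim) * 0.01)
        self.up = torch.nn.Parameter(torch.zeros(out_dim, rank))  # zero init: no change at step 0
        self.alpha, self.rank = alpha, rank

    def forward(self, base_weight):
        return base_weight + (self.alpha / self.rank) * (self.up @ self.down)

base = torch.randn(16, 32)            # frozen base weight (requires_grad stays False)
adapter = LowRankDelta(base, rank=4)
effective = adapter(base)             # only adapter.up / adapter.down receive gradients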
+ + Args: + optimizer_name: Name of optimizer ("Adam", "AdamW", "SGD", "RMSprop") + parameters: Parameters to optimize + learning_rate: Learning rate + + Returns: + Optimizer instance + """ + if optimizer_name == "Adam": + return torch.optim.Adam(parameters, lr=learning_rate) + elif optimizer_name == "AdamW": + return torch.optim.AdamW(parameters, lr=learning_rate) + elif optimizer_name == "SGD": + return torch.optim.SGD(parameters, lr=learning_rate) + elif optimizer_name == "RMSprop": + return torch.optim.RMSprop(parameters, lr=learning_rate) + + +def _create_loss_function(loss_function_name): + """Create loss function based on name. + + Args: + loss_function_name: Name of loss function ("MSE", "L1", "Huber", "SmoothL1") + + Returns: + Loss function instance + """ + if loss_function_name == "MSE": + return torch.nn.MSELoss() + elif loss_function_name == "L1": + return torch.nn.L1Loss() + elif loss_function_name == "Huber": + return torch.nn.HuberLoss() + elif loss_function_name == "SmoothL1": + return torch.nn.SmoothL1Loss() + + +def _run_training_loop( + guider, train_sampler, latents, num_images, seed, bucket_mode, multi_res +): + """Execute the training loop. + + Args: + guider: The guider object + train_sampler: The training sampler + latents: Latent tensors + num_images: Number of images + seed: Random seed + bucket_mode: Whether bucket mode is enabled + multi_res: Whether multi-resolution mode is enabled + """ + sigmas = torch.tensor(range(num_images)) + noise = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(seed) + + if bucket_mode: + # Use first bucket's first latent as dummy for guider + dummy_latent = latents[0][:1].repeat(num_images, 1, 1, 1) + guider.sample( + noise.generate_noise({"samples": dummy_latent}), + dummy_latent, + train_sampler, + sigmas, + seed=noise.seed, + ) + elif multi_res: + # use first latent as dummy latent if multi_res + latents = latents[0].repeat(num_images, 1, 1, 1) + guider.sample( + noise.generate_noise({"samples": latents}), + latents, + train_sampler, + sigmas, + seed=noise.seed, + ) + else: + guider.sample( + noise.generate_noise({"samples": latents}), + latents, + train_sampler, + sigmas, + seed=noise.seed, + ) + + class TrainLoraNode(io.ComfyNode): @classmethod def define_schema(cls): @@ -385,6 +879,11 @@ class TrainLoraNode(io.ComfyNode): default="[None]", tooltip="The existing LoRA to append to. Set to None for new LoRA.", ), + io.Boolean.Input( + "bucket_mode", + default=False, + tooltip="Enable resolution bucket mode. 
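One thing worth noting about _create_optimizer and _create_loss_function as written: an unrecognized name falls through every branch and the helper implicitly returns None, which only surfaces later when the optimizer or criterion is used. In practice the combo inputs restrict the values, so this is harmless; a table-driven variant that fails fast is sketched below as a possible alternative, not as the current behavior:

import torch

_OPTIMIZERS = {
    "Adam": torch.optim.Adam,
    "AdamW": torch.optim.AdamW,
    "SGD": torch.optim.SGD,
    "RMSprop": torch.optim.RMSprop,
}

def create_optimizer(name, parameters, learning_rate):
    try:
        return _OPTIMIZERS[name](parameters, lr=learning_rate)
    except KeyError:
        raise ValueError(f"Unknown optimizer: {name!r}") from None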
When enabled, expects pre-bucketed latents from ResolutionBucket node.", + ), ], outputs=[ io.Model.Output( @@ -419,6 +918,7 @@ class TrainLoraNode(io.ComfyNode): algorithm, gradient_checkpointing, existing_lora, + bucket_mode, ): # Extract scalars from lists (due to is_input_list=True) model = model[0] @@ -427,215 +927,125 @@ class TrainLoraNode(io.ComfyNode): grad_accumulation_steps = grad_accumulation_steps[0] learning_rate = learning_rate[0] rank = rank[0] - optimizer = optimizer[0] - loss_function = loss_function[0] + optimizer_name = optimizer[0] + loss_function_name = loss_function[0] seed = seed[0] training_dtype = training_dtype[0] lora_dtype = lora_dtype[0] algorithm = algorithm[0] gradient_checkpointing = gradient_checkpointing[0] existing_lora = existing_lora[0] + bucket_mode = bucket_mode[0] - # Handle latents - either single dict or list of dicts - if len(latents) == 1: - latents = latents[0]["samples"] # Single latent dict + # Process latents based on mode + if bucket_mode: + latents = _process_latents_bucket_mode(latents) else: - latent_list = [] - for latent in latents: - latent = latent["samples"] - bs = latent.shape[0] - if bs != 1: - for sub_latent in latent: - latent_list.append(sub_latent[None]) - else: - latent_list.append(latent) - latents = latent_list + latents = _process_latents_standard_mode(latents) - # Handle conditioning - either single list or list of lists - if len(positive) == 1: - positive = positive[0] # Single conditioning list - else: - # Multiple conditioning lists - flatten - flat_positive = [] - for cond in positive: - if isinstance(cond, list): - flat_positive.extend(cond) - else: - flat_positive.append(cond) - positive = flat_positive + # Process conditioning + positive = _process_conditioning(positive) + # Setup model and dtype mp = model.clone() dtype = node_helpers.string_to_torch_dtype(training_dtype) lora_dtype = node_helpers.string_to_torch_dtype(lora_dtype) mp.set_model_compute_dtype(dtype) - # latents here can be list of different size latent or one large batch - if isinstance(latents, list): - all_shapes = set() - latents = [t.to(dtype) for t in latents] - for latent in latents: - all_shapes.add(latent.shape) - logging.info(f"Latent shapes: {all_shapes}") - if len(all_shapes) > 1: - multi_res = True - else: - multi_res = False - latents = torch.cat(latents, dim=0) - num_images = len(latents) - elif isinstance(latents, torch.Tensor): - latents = latents.to(dtype) - num_images = latents.shape[0] - else: - logging.error(f"Invalid latents type: {type(latents)}") + # Prepare latents and compute counts + latents, num_images, multi_res = _prepare_latents_and_count( + latents, dtype, bucket_mode + ) - logging.info(f"Total Images: {num_images}, Total Captions: {len(positive)}") - if len(positive) == 1 and num_images > 1: - positive = positive * num_images - elif len(positive) != num_images: - raise ValueError( - f"Number of positive conditions ({len(positive)}) does not match number of images ({num_images})." 
- ) + # Validate and expand conditioning + positive = _validate_and_expand_conditioning(positive, num_images, bucket_mode) with torch.inference_mode(False): - lora_sd = {} - generator = torch.Generator() - generator.manual_seed(seed) + # Setup models for training + mp.model.requires_grad_(False) # Load existing LoRA weights if provided - existing_weights = {} - existing_steps = 0 - if existing_lora != "[None]": - lora_path = folder_paths.get_full_path_or_raise("loras", existing_lora) - # Extract steps from filename like "trained_lora_10_steps_20250225_203716" - existing_steps = int(existing_lora.split("_steps_")[0].split("_")[-1]) - if lora_path: - existing_weights = comfy.utils.load_torch_file(lora_path) + existing_weights, existing_steps = _load_existing_lora(existing_lora) - all_weight_adapters = [] - for n, m in mp.model.named_modules(): - if hasattr(m, "weight_function"): - if m.weight is not None: - key = "{}.weight".format(n) - shape = m.weight.shape - if len(shape) >= 2: - alpha = float(existing_weights.get(f"{key}.alpha", 1.0)) - dora_scale = existing_weights.get(f"{key}.dora_scale", None) - for adapter_cls in adapters: - existing_adapter = adapter_cls.load( - n, existing_weights, alpha, dora_scale - ) - if existing_adapter is not None: - break - else: - existing_adapter = None - adapter_cls = adapter_maps[algorithm] + # Setup LoRA adapters + lora_sd, all_weight_adapters = _setup_lora_adapters( + mp, existing_weights, algorithm, lora_dtype, rank + ) - if existing_adapter is not None: - train_adapter = existing_adapter.to_train().to( - lora_dtype - ) - else: - # Use LoRA with alpha=1.0 by default - train_adapter = adapter_cls.create_train( - m.weight, rank=rank, alpha=1.0 - ).to(lora_dtype) - for name, parameter in train_adapter.named_parameters(): - lora_sd[f"{n}.{name}"] = parameter + # Create optimizer and loss function + optimizer = _create_optimizer( + optimizer_name, lora_sd.values(), learning_rate + ) + criterion = _create_loss_function(loss_function_name) - mp.add_weight_wrapper(key, train_adapter) - all_weight_adapters.append(train_adapter) - else: - diff = torch.nn.Parameter( - torch.zeros( - m.weight.shape, dtype=lora_dtype, requires_grad=True - ) - ) - diff_module = BiasDiff(diff) - mp.add_weight_wrapper(key, BiasDiff(diff)) - all_weight_adapters.append(diff_module) - lora_sd["{}.diff".format(n)] = diff - if hasattr(m, "bias") and m.bias is not None: - key = "{}.bias".format(n) - bias = torch.nn.Parameter( - torch.zeros( - m.bias.shape, dtype=lora_dtype, requires_grad=True - ) - ) - bias_module = BiasDiff(bias) - lora_sd["{}.diff_b".format(n)] = bias - mp.add_weight_wrapper(key, BiasDiff(bias)) - all_weight_adapters.append(bias_module) - - if optimizer == "Adam": - optimizer = torch.optim.Adam(lora_sd.values(), lr=learning_rate) - elif optimizer == "AdamW": - optimizer = torch.optim.AdamW(lora_sd.values(), lr=learning_rate) - elif optimizer == "SGD": - optimizer = torch.optim.SGD(lora_sd.values(), lr=learning_rate) - elif optimizer == "RMSprop": - optimizer = torch.optim.RMSprop(lora_sd.values(), lr=learning_rate) - - # Setup loss function based on selection - if loss_function == "MSE": - criterion = torch.nn.MSELoss() - elif loss_function == "L1": - criterion = torch.nn.L1Loss() - elif loss_function == "Huber": - criterion = torch.nn.HuberLoss() - elif loss_function == "SmoothL1": - criterion = torch.nn.SmoothL1Loss() - - # setup models + # Setup gradient checkpointing if gradient_checkpointing: for m in find_all_highest_child_module_with_forward( 
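The gradient_checkpointing branch relies on patch()/unpatch() helpers defined earlier in nodes_train.py and only partially visible in this diff (unpatch restores org_forward). The usual shape of that kind of forward monkey-patch is sketched below under the assumption that it routes forward through torch.utils.checkpoint; the real helpers may differ in detail:

import torch
import torch.utils.checkpoint

def patch(m):
    # Assumed shape only: remember the original forward, then recompute activations in backward.
    if not hasattr(m, "org_forward"):
        m.org_forward = m.forward
        m.forward = lambda *args, **kwargs: torch.utils.checkpoint.checkpoint(
            m.org_forward, *args, use_reentrant=False, **kwargs
        )

def unpatch(m):
    if hasattr(m, "org_forward"):
        m.forward = m.org_forward
        del m.org_forward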
mp.model.diffusion_model ): patch(m) - mp.model.requires_grad_(False) + + torch.cuda.empty_cache() + # With force_full_load=False we should be able to have offloading + # But for offloading in training we need custom AutoGrad hooks for fwd/bwd comfy.model_management.load_models_gpu( [mp], memory_required=1e20, force_full_load=True ) + torch.cuda.empty_cache() - # Setup sampler and guider like in test script + # Setup loss tracking loss_map = {"loss": []} def loss_callback(loss): loss_map["loss"].append(loss) - train_sampler = TrainSampler( - criterion, - optimizer, - loss_callback=loss_callback, - batch_size=batch_size, - grad_acc=grad_accumulation_steps, - total_steps=steps * grad_accumulation_steps, - seed=seed, - training_dtype=dtype, - real_dataset=latents if multi_res else None, - ) - guider = comfy_extras.nodes_custom_sampler.Guider_Basic(mp) - guider.set_conds(positive) # Set conditioning from input + # Create sampler + if bucket_mode: + train_sampler = TrainSampler( + criterion, + optimizer, + loss_callback=loss_callback, + batch_size=batch_size, + grad_acc=grad_accumulation_steps, + total_steps=steps * grad_accumulation_steps, + seed=seed, + training_dtype=dtype, + bucket_latents=latents, + ) + else: + train_sampler = TrainSampler( + criterion, + optimizer, + loss_callback=loss_callback, + batch_size=batch_size, + grad_acc=grad_accumulation_steps, + total_steps=steps * grad_accumulation_steps, + seed=seed, + training_dtype=dtype, + real_dataset=latents if multi_res else None, + ) - # Training loop + # Setup guider + guider = TrainGuider(mp) + guider.set_conds(positive) + + # Run training loop try: - # Generate dummy sigmas and noise - sigmas = torch.tensor(range(num_images)) - noise = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(seed) - if multi_res: - # use first latent as dummy latent if multi_res - latents = latents[0].repeat((num_images,) + ((1,) * (latents[0].ndim - 1))) - guider.sample( - noise.generate_noise({"samples": latents}), - latents, + _run_training_loop( + guider, train_sampler, - sigmas, - seed=noise.seed, + latents, + num_images, + seed, + bucket_mode, + multi_res, ) finally: for m in mp.model.modules(): unpatch(m) del train_sampler, optimizer + # Finalize adapters for adapter in all_weight_adapters: adapter.requires_grad_(False) @@ -645,7 +1055,7 @@ class TrainLoraNode(io.ComfyNode): return io.NodeOutput(mp, lora_sd, loss_map, steps + existing_steps) -class LoraModelLoader(io.ComfyNode): +class LoraModelLoader(io.ComfyNode):# @classmethod def define_schema(cls): return io.Schema( From bf7dc63bd6acdedca67598856e05080d90eeec90 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 17 Dec 2025 20:29:32 -0800 Subject: [PATCH 066/148] skip_load_model -> force_full_load (#11390) This should be a bit more clear and less prone to potential breakage if the logic of the load models changes a bit. 
--- comfy/sampler_helpers.py | 9 ++++----- comfy_extras/nodes_train.py | 2 +- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/comfy/sampler_helpers.py b/comfy/sampler_helpers.py index e158e8a84..9134e6d71 100644 --- a/comfy/sampler_helpers.py +++ b/comfy/sampler_helpers.py @@ -122,21 +122,20 @@ def estimate_memory(model, noise_shape, conds): minimum_memory_required = model.model.memory_required([noise_shape[0]] + list(noise_shape[1:]), cond_shapes=cond_shapes_min) return memory_required, minimum_memory_required -def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, skip_load_model=False): +def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, force_full_load=False): executor = comfy.patcher_extension.WrapperExecutor.new_executor( _prepare_sampling, comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.PREPARE_SAMPLING, model_options, is_model_options=True) ) - return executor.execute(model, noise_shape, conds, model_options=model_options, skip_load_model=skip_load_model) + return executor.execute(model, noise_shape, conds, model_options=model_options, force_full_load=force_full_load) -def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, skip_load_model=False): +def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, force_full_load=False): real_model: BaseModel = None models, inference_memory = get_additional_models(conds, model.model_dtype()) models += get_additional_models_from_model_options(model_options) models += model.get_nested_additional_models() # TODO: does this require inference_memory update? memory_required, minimum_memory_required = estimate_memory(model, noise_shape, conds) - models_list = [model] if not skip_load_model else [] - comfy.model_management.load_models_gpu(models_list + models, memory_required=memory_required + inference_memory, minimum_memory_required=minimum_memory_required + inference_memory) + comfy.model_management.load_models_gpu([model] + models, memory_required=memory_required + inference_memory, minimum_memory_required=minimum_memory_required + inference_memory, force_full_load=force_full_load) real_model = model.model return real_model, conds, models diff --git a/comfy_extras/nodes_train.py b/comfy_extras/nodes_train.py index 88bc8c8e8..364804205 100644 --- a/comfy_extras/nodes_train.py +++ b/comfy_extras/nodes_train.py @@ -44,7 +44,7 @@ class TrainGuider(comfy_extras.nodes_custom_sampler.Guider_Basic): noise.shape, self.conds, self.model_options, - skip_load_model=True, # skip load model as we manage it in TrainLoraNode.execute() + force_full_load=True, # mirror behavior in TrainLoraNode.execute() to keep model loaded ) ) device = self.model_patcher.load_device From 1ca89b810e921efce95fb4d254a8c6c93180450b Mon Sep 17 00:00:00 2001 From: ric-yu Date: Wed, 17 Dec 2025 21:44:31 -0800 Subject: [PATCH 067/148] Add unified jobs API with /api/jobs endpoints (#11054) * feat: create a /jobs api to return queue and history jobs * update unused vars * include priority * create jobs helper file * fix ruff * update how we set error message * include execution error in both responses * rename error -> failed, fix output shape * re-use queue and history functions * set workflow id * allow srot by exec duration * fix tests * send priority and remove error msg * use ws messages to get start and end times * revert main.py fully * refactor: move all /jobs business logic to jobs.py * fix failing test * remove some tests * fix 
non dict nodes * address comments * filter by workflow id and remove null fields * add clearer typing - remove get("..") or .. * refactor query params to top get_job(s) doc, add remove_sensitive_from_queue * add brief comment explaining why we skip animated * comment that format field is for frontend backward compatibility * fix whitespace --------- Co-authored-by: Jedrzej Kosinski Co-authored-by: guill --- comfy_execution/jobs.py | 291 ++++++++++++++++++++++++ server.py | 135 ++++++++++- tests/execution/test_execution.py | 134 +++++++++++ tests/execution/test_jobs.py | 361 ++++++++++++++++++++++++++++++ 4 files changed, 918 insertions(+), 3 deletions(-) create mode 100644 comfy_execution/jobs.py create mode 100644 tests/execution/test_jobs.py diff --git a/comfy_execution/jobs.py b/comfy_execution/jobs.py new file mode 100644 index 000000000..59fb49357 --- /dev/null +++ b/comfy_execution/jobs.py @@ -0,0 +1,291 @@ +""" +Job utilities for the /api/jobs endpoint. +Provides normalization and helper functions for job status tracking. +""" + +from typing import Optional + +from comfy_api.internal import prune_dict + + +class JobStatus: + """Job status constants.""" + PENDING = 'pending' + IN_PROGRESS = 'in_progress' + COMPLETED = 'completed' + FAILED = 'failed' + + ALL = [PENDING, IN_PROGRESS, COMPLETED, FAILED] + + +# Media types that can be previewed in the frontend +PREVIEWABLE_MEDIA_TYPES = frozenset({'images', 'video', 'audio'}) + +# 3D file extensions for preview fallback (no dedicated media_type exists) +THREE_D_EXTENSIONS = frozenset({'.obj', '.fbx', '.gltf', '.glb'}) + + +def _extract_job_metadata(extra_data: dict) -> tuple[Optional[int], Optional[str]]: + """Extract create_time and workflow_id from extra_data. + + Returns: + tuple: (create_time, workflow_id) + """ + create_time = extra_data.get('create_time') + extra_pnginfo = extra_data.get('extra_pnginfo', {}) + workflow_id = extra_pnginfo.get('workflow', {}).get('id') + return create_time, workflow_id + + +def is_previewable(media_type: str, item: dict) -> bool: + """ + Check if an output item is previewable. + Matches frontend logic in ComfyUI_frontend/src/stores/queueStore.ts + Maintains backwards compatibility with existing logic. + + Priority: + 1. media_type is 'images', 'video', or 'audio' + 2. format field starts with 'video/' or 'audio/' + 3. filename has a 3D extension (.obj, .fbx, .gltf, .glb) + """ + if media_type in PREVIEWABLE_MEDIA_TYPES: + return True + + # Check format field (MIME type). + # Maintains backwards compatibility with how custom node outputs are handled in the frontend. + fmt = item.get('format', '') + if fmt and (fmt.startswith('video/') or fmt.startswith('audio/')): + return True + + # Check for 3D files by extension + filename = item.get('filename', '').lower() + if any(filename.endswith(ext) for ext in THREE_D_EXTENSIONS): + return True + + return False + + +def normalize_queue_item(item: tuple, status: str) -> dict: + """Convert queue item tuple to unified job dict. + + Expects item with sensitive data already removed (5 elements). + """ + priority, prompt_id, _, extra_data, _ = item + create_time, workflow_id = _extract_job_metadata(extra_data) + + return prune_dict({ + 'id': prompt_id, + 'status': status, + 'priority': priority, + 'create_time': create_time, + 'outputs_count': 0, + 'workflow_id': workflow_id, + }) + + +def normalize_history_item(prompt_id: str, history_item: dict, include_outputs: bool = False) -> dict: + """Convert history item dict to unified job dict. 
+ + History items have sensitive data already removed (prompt tuple has 5 elements). + """ + prompt_tuple = history_item['prompt'] + priority, _, prompt, extra_data, _ = prompt_tuple + create_time, workflow_id = _extract_job_metadata(extra_data) + + status_info = history_item.get('status', {}) + status_str = status_info.get('status_str') if status_info else None + if status_str == 'success': + status = JobStatus.COMPLETED + elif status_str == 'error': + status = JobStatus.FAILED + else: + status = JobStatus.COMPLETED + + outputs = history_item.get('outputs', {}) + outputs_count, preview_output = get_outputs_summary(outputs) + + execution_error = None + execution_start_time = None + execution_end_time = None + if status_info: + messages = status_info.get('messages', []) + for entry in messages: + if isinstance(entry, (list, tuple)) and len(entry) >= 2: + event_name, event_data = entry[0], entry[1] + if isinstance(event_data, dict): + if event_name == 'execution_start': + execution_start_time = event_data.get('timestamp') + elif event_name in ('execution_success', 'execution_error', 'execution_interrupted'): + execution_end_time = event_data.get('timestamp') + if event_name == 'execution_error': + execution_error = event_data + + job = prune_dict({ + 'id': prompt_id, + 'status': status, + 'priority': priority, + 'create_time': create_time, + 'execution_start_time': execution_start_time, + 'execution_end_time': execution_end_time, + 'execution_error': execution_error, + 'outputs_count': outputs_count, + 'preview_output': preview_output, + 'workflow_id': workflow_id, + }) + + if include_outputs: + job['outputs'] = outputs + job['execution_status'] = status_info + job['workflow'] = { + 'prompt': prompt, + 'extra_data': extra_data, + } + + return job + + +def get_outputs_summary(outputs: dict) -> tuple[int, Optional[dict]]: + """ + Count outputs and find preview in a single pass. + Returns (outputs_count, preview_output). + + Preview priority (matching frontend): + 1. type="output" with previewable media + 2. Any previewable media + """ + count = 0 + preview_output = None + fallback_preview = None + + for node_id, node_outputs in outputs.items(): + if not isinstance(node_outputs, dict): + continue + for media_type, items in node_outputs.items(): + # 'animated' is a boolean flag, not actual output items + if media_type == 'animated' or not isinstance(items, list): + continue + + for item in items: + if not isinstance(item, dict): + continue + count += 1 + + if preview_output is None and is_previewable(media_type, item): + enriched = { + **item, + 'nodeId': node_id, + 'mediaType': media_type + } + if item.get('type') == 'output': + preview_output = enriched + elif fallback_preview is None: + fallback_preview = enriched + + return count, preview_output or fallback_preview + + +def apply_sorting(jobs: list[dict], sort_by: str, sort_order: str) -> list[dict]: + """Sort jobs list by specified field and order.""" + reverse = (sort_order == 'desc') + + if sort_by == 'execution_duration': + def get_sort_key(job): + start = job.get('execution_start_time', 0) + end = job.get('execution_end_time', 0) + return end - start if end and start else 0 + else: + def get_sort_key(job): + return job.get('create_time', 0) + + return sorted(jobs, key=get_sort_key, reverse=reverse) + + +def get_job(prompt_id: str, running: list, queued: list, history: dict) -> Optional[dict]: + """ + Get a single job by prompt_id from history or queue. 
+ + Args: + prompt_id: The prompt ID to look up + running: List of currently running queue items + queued: List of pending queue items + history: Dict of history items keyed by prompt_id + + Returns: + Job dict with full details, or None if not found + """ + if prompt_id in history: + return normalize_history_item(prompt_id, history[prompt_id], include_outputs=True) + + for item in running: + if item[1] == prompt_id: + return normalize_queue_item(item, JobStatus.IN_PROGRESS) + + for item in queued: + if item[1] == prompt_id: + return normalize_queue_item(item, JobStatus.PENDING) + + return None + + +def get_all_jobs( + running: list, + queued: list, + history: dict, + status_filter: Optional[list[str]] = None, + workflow_id: Optional[str] = None, + sort_by: str = "created_at", + sort_order: str = "desc", + limit: Optional[int] = None, + offset: int = 0 +) -> tuple[list[dict], int]: + """ + Get all jobs (running, pending, completed) with filtering and sorting. + + Args: + running: List of currently running queue items + queued: List of pending queue items + history: Dict of history items keyed by prompt_id + status_filter: List of statuses to include (from JobStatus.ALL) + workflow_id: Filter by workflow ID + sort_by: Field to sort by ('created_at', 'execution_duration') + sort_order: 'asc' or 'desc' + limit: Maximum number of items to return + offset: Number of items to skip + + Returns: + tuple: (jobs_list, total_count) + """ + jobs = [] + + if status_filter is None: + status_filter = JobStatus.ALL + + if JobStatus.IN_PROGRESS in status_filter: + for item in running: + jobs.append(normalize_queue_item(item, JobStatus.IN_PROGRESS)) + + if JobStatus.PENDING in status_filter: + for item in queued: + jobs.append(normalize_queue_item(item, JobStatus.PENDING)) + + include_completed = JobStatus.COMPLETED in status_filter + include_failed = JobStatus.FAILED in status_filter + if include_completed or include_failed: + for prompt_id, history_item in history.items(): + is_failed = history_item.get('status', {}).get('status_str') == 'error' + if (is_failed and include_failed) or (not is_failed and include_completed): + jobs.append(normalize_history_item(prompt_id, history_item)) + + if workflow_id: + jobs = [j for j in jobs if j.get('workflow_id') == workflow_id] + + jobs = apply_sorting(jobs, sort_by, sort_order) + + total_count = len(jobs) + + if offset > 0: + jobs = jobs[offset:] + if limit is not None: + jobs = jobs[:limit] + + return (jobs, total_count) diff --git a/server.py b/server.py index ac4f42222..c27f8be7d 100644 --- a/server.py +++ b/server.py @@ -7,6 +7,7 @@ import time import nodes import folder_paths import execution +from comfy_execution.jobs import JobStatus, get_job, get_all_jobs import uuid import urllib import json @@ -47,6 +48,12 @@ from middleware.cache_middleware import cache_control if args.enable_manager: import comfyui_manager + +def _remove_sensitive_from_queue(queue: list) -> list: + """Remove sensitive data (index 5) from queue item tuples.""" + return [item[:5] for item in queue] + + async def send_socket_catch_exception(function, message): try: await function(message) @@ -694,6 +701,129 @@ class PromptServer(): out[node_class] = node_info(node_class) return web.json_response(out) + @routes.get("/api/jobs") + async def get_jobs(request): + """List all jobs with filtering, sorting, and pagination. 
+ + Query parameters: + status: Filter by status (comma-separated): pending, in_progress, completed, failed + workflow_id: Filter by workflow ID + sort_by: Sort field: created_at (default), execution_duration + sort_order: Sort direction: asc, desc (default) + limit: Max items to return (positive integer) + offset: Items to skip (non-negative integer, default 0) + """ + query = request.rel_url.query + + status_param = query.get('status') + workflow_id = query.get('workflow_id') + sort_by = query.get('sort_by', 'created_at').lower() + sort_order = query.get('sort_order', 'desc').lower() + + status_filter = None + if status_param: + status_filter = [s.strip().lower() for s in status_param.split(',') if s.strip()] + invalid_statuses = [s for s in status_filter if s not in JobStatus.ALL] + if invalid_statuses: + return web.json_response( + {"error": f"Invalid status value(s): {', '.join(invalid_statuses)}. Valid values: {', '.join(JobStatus.ALL)}"}, + status=400 + ) + + if sort_by not in {'created_at', 'execution_duration'}: + return web.json_response( + {"error": "sort_by must be 'created_at' or 'execution_duration'"}, + status=400 + ) + + if sort_order not in {'asc', 'desc'}: + return web.json_response( + {"error": "sort_order must be 'asc' or 'desc'"}, + status=400 + ) + + limit = None + + # If limit is provided, validate that it is a positive integer, else continue without a limit + if 'limit' in query: + try: + limit = int(query.get('limit')) + if limit <= 0: + return web.json_response( + {"error": "limit must be a positive integer"}, + status=400 + ) + except (ValueError, TypeError): + return web.json_response( + {"error": "limit must be an integer"}, + status=400 + ) + + offset = 0 + if 'offset' in query: + try: + offset = int(query.get('offset')) + if offset < 0: + offset = 0 + except (ValueError, TypeError): + return web.json_response( + {"error": "offset must be an integer"}, + status=400 + ) + + running, queued = self.prompt_queue.get_current_queue_volatile() + history = self.prompt_queue.get_history() + + running = _remove_sensitive_from_queue(running) + queued = _remove_sensitive_from_queue(queued) + + jobs, total = get_all_jobs( + running, queued, history, + status_filter=status_filter, + workflow_id=workflow_id, + sort_by=sort_by, + sort_order=sort_order, + limit=limit, + offset=offset + ) + + has_more = (offset + len(jobs)) < total + + return web.json_response({ + 'jobs': jobs, + 'pagination': { + 'offset': offset, + 'limit': limit, + 'total': total, + 'has_more': has_more + } + }) + + @routes.get("/api/jobs/{job_id}") + async def get_job_by_id(request): + """Get a single job by ID.""" + job_id = request.match_info.get("job_id", None) + if not job_id: + return web.json_response( + {"error": "job_id is required"}, + status=400 + ) + + running, queued = self.prompt_queue.get_current_queue_volatile() + history = self.prompt_queue.get_history(prompt_id=job_id) + + running = _remove_sensitive_from_queue(running) + queued = _remove_sensitive_from_queue(queued) + + job = get_job(job_id, running, queued, history) + if job is None: + return web.json_response( + {"error": "Job not found"}, + status=404 + ) + + return web.json_response(job) + @routes.get("/history") async def get_history(request): max_items = request.rel_url.query.get("max_items", None) @@ -717,9 +847,8 @@ class PromptServer(): async def get_queue(request): queue_info = {} current_queue = self.prompt_queue.get_current_queue_volatile() - remove_sensitive = lambda queue: [x[:5] for x in queue] - queue_info['queue_running'] 
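Taken together, the two routes above give a read-only job listing keyed by the same prompt IDs used by /queue and /history. A minimal client-side query using only the documented parameters (the server address and filter values below are illustrative, with 8188 as the usual default port):

import json
import urllib.parse
import urllib.request

params = urllib.parse.urlencode({
    "status": "completed,failed",        # comma-separated JobStatus values
    "sort_by": "execution_duration",
    "sort_order": "desc",
    "limit": 20,
    "offset": 0,
})
with urllib.request.urlopen(f"http://127.0.0.1:8188/api/jobs?{params}") as resp:
    data = json.loads(resp.read())

for job in data["jobs"]:                  # response envelope: {"jobs": [...], "pagination": {...}}
    print(job["id"], job["status"], job.get("outputs_count"))
print(data["pagination"]["has_more"])     # True when offset + returned count < total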
= remove_sensitive(current_queue[0]) - queue_info['queue_pending'] = remove_sensitive(current_queue[1]) + queue_info['queue_running'] = _remove_sensitive_from_queue(current_queue[0]) + queue_info['queue_pending'] = _remove_sensitive_from_queue(current_queue[1]) return web.json_response(queue_info) @routes.post("/prompt") diff --git a/tests/execution/test_execution.py b/tests/execution/test_execution.py index ace0d2279..f73ca7e3c 100644 --- a/tests/execution/test_execution.py +++ b/tests/execution/test_execution.py @@ -99,6 +99,37 @@ class ComfyClient: with urllib.request.urlopen(url) as response: return json.loads(response.read()) + def get_jobs(self, status=None, limit=None, offset=None, sort_by=None, sort_order=None): + url = "http://{}/api/jobs".format(self.server_address) + params = {} + if status is not None: + params["status"] = status + if limit is not None: + params["limit"] = limit + if offset is not None: + params["offset"] = offset + if sort_by is not None: + params["sort_by"] = sort_by + if sort_order is not None: + params["sort_order"] = sort_order + + if params: + url_values = urllib.parse.urlencode(params) + url = "{}?{}".format(url, url_values) + + with urllib.request.urlopen(url) as response: + return json.loads(response.read()) + + def get_job(self, job_id): + url = "http://{}/api/jobs/{}".format(self.server_address, job_id) + try: + with urllib.request.urlopen(url) as response: + return json.loads(response.read()) + except urllib.error.HTTPError as e: + if e.code == 404: + return None + raise + def set_test_name(self, name): self.test_name = name @@ -877,3 +908,106 @@ class TestExecution: result = client.get_all_history(max_items=5, offset=len(all_history) - 1) assert len(result) <= 1, "Should return at most 1 item when offset is near end" + + # Jobs API tests + def test_jobs_api_job_structure( + self, client: ComfyClient, builder: GraphBuilder + ): + """Test that job objects have required fields""" + self._create_history_item(client, builder) + + jobs_response = client.get_jobs(status="completed", limit=1) + assert len(jobs_response["jobs"]) > 0, "Should have at least one job" + + job = jobs_response["jobs"][0] + assert "id" in job, "Job should have id" + assert "status" in job, "Job should have status" + assert "create_time" in job, "Job should have create_time" + assert "outputs_count" in job, "Job should have outputs_count" + assert "preview_output" in job, "Job should have preview_output" + + def test_jobs_api_preview_output_structure( + self, client: ComfyClient, builder: GraphBuilder + ): + """Test that preview_output has correct structure""" + self._create_history_item(client, builder) + + jobs_response = client.get_jobs(status="completed", limit=1) + job = jobs_response["jobs"][0] + + if job["preview_output"] is not None: + preview = job["preview_output"] + assert "filename" in preview, "Preview should have filename" + assert "nodeId" in preview, "Preview should have nodeId" + assert "mediaType" in preview, "Preview should have mediaType" + + def test_jobs_api_pagination( + self, client: ComfyClient, builder: GraphBuilder + ): + """Test jobs API pagination""" + for _ in range(5): + self._create_history_item(client, builder) + + first_page = client.get_jobs(limit=2, offset=0) + second_page = client.get_jobs(limit=2, offset=2) + + assert len(first_page["jobs"]) <= 2, "First page should have at most 2 jobs" + assert len(second_page["jobs"]) <= 2, "Second page should have at most 2 jobs" + + first_ids = {j["id"] for j in first_page["jobs"]} + second_ids = {j["id"] 
for j in second_page["jobs"]} + assert first_ids.isdisjoint(second_ids), "Pages should have different jobs" + + def test_jobs_api_sorting( + self, client: ComfyClient, builder: GraphBuilder + ): + """Test jobs API sorting""" + for _ in range(3): + self._create_history_item(client, builder) + + desc_jobs = client.get_jobs(sort_order="desc") + asc_jobs = client.get_jobs(sort_order="asc") + + if len(desc_jobs["jobs"]) >= 2: + desc_times = [j["create_time"] for j in desc_jobs["jobs"] if j["create_time"]] + asc_times = [j["create_time"] for j in asc_jobs["jobs"] if j["create_time"]] + if len(desc_times) >= 2: + assert desc_times == sorted(desc_times, reverse=True), "Desc should be newest first" + if len(asc_times) >= 2: + assert asc_times == sorted(asc_times), "Asc should be oldest first" + + def test_jobs_api_status_filter( + self, client: ComfyClient, builder: GraphBuilder + ): + """Test jobs API status filtering""" + self._create_history_item(client, builder) + + completed_jobs = client.get_jobs(status="completed") + assert len(completed_jobs["jobs"]) > 0, "Should have completed jobs from history" + + for job in completed_jobs["jobs"]: + assert job["status"] == "completed", "Should only return completed jobs" + + # Pending jobs are transient - just verify filter doesn't error + pending_jobs = client.get_jobs(status="pending") + for job in pending_jobs["jobs"]: + assert job["status"] == "pending", "Should only return pending jobs" + + def test_get_job_by_id( + self, client: ComfyClient, builder: GraphBuilder + ): + """Test getting a single job by ID""" + result = self._create_history_item(client, builder) + prompt_id = result.get_prompt_id() + + job = client.get_job(prompt_id) + assert job is not None, "Should find the job" + assert job["id"] == prompt_id, "Job ID should match" + assert "outputs" in job, "Single job should include outputs" + + def test_get_job_not_found( + self, client: ComfyClient, builder: GraphBuilder + ): + """Test getting a non-existent job returns 404""" + job = client.get_job("nonexistent-job-id") + assert job is None, "Non-existent job should return None" diff --git a/tests/execution/test_jobs.py b/tests/execution/test_jobs.py new file mode 100644 index 000000000..918c8080a --- /dev/null +++ b/tests/execution/test_jobs.py @@ -0,0 +1,361 @@ +"""Unit tests for comfy_execution/jobs.py""" + +from comfy_execution.jobs import ( + JobStatus, + is_previewable, + normalize_queue_item, + normalize_history_item, + get_outputs_summary, + apply_sorting, +) + + +class TestJobStatus: + """Test JobStatus constants.""" + + def test_status_values(self): + """Status constants should have expected string values.""" + assert JobStatus.PENDING == 'pending' + assert JobStatus.IN_PROGRESS == 'in_progress' + assert JobStatus.COMPLETED == 'completed' + assert JobStatus.FAILED == 'failed' + + def test_all_contains_all_statuses(self): + """ALL should contain all status values.""" + assert JobStatus.PENDING in JobStatus.ALL + assert JobStatus.IN_PROGRESS in JobStatus.ALL + assert JobStatus.COMPLETED in JobStatus.ALL + assert JobStatus.FAILED in JobStatus.ALL + assert len(JobStatus.ALL) == 4 + + +class TestIsPreviewable: + """Unit tests for is_previewable()""" + + def test_previewable_media_types(self): + """Images, video, audio media types should be previewable.""" + for media_type in ['images', 'video', 'audio']: + assert is_previewable(media_type, {}) is True + + def test_non_previewable_media_types(self): + """Other media types should not be previewable.""" + for media_type in ['latents', 
'text', 'metadata', 'files']: + assert is_previewable(media_type, {}) is False + + def test_3d_extensions_previewable(self): + """3D file extensions should be previewable regardless of media_type.""" + for ext in ['.obj', '.fbx', '.gltf', '.glb']: + item = {'filename': f'model{ext}'} + assert is_previewable('files', item) is True + + def test_3d_extensions_case_insensitive(self): + """3D extension check should be case insensitive.""" + item = {'filename': 'MODEL.GLB'} + assert is_previewable('files', item) is True + + def test_video_format_previewable(self): + """Items with video/ format should be previewable.""" + item = {'format': 'video/mp4'} + assert is_previewable('files', item) is True + + def test_audio_format_previewable(self): + """Items with audio/ format should be previewable.""" + item = {'format': 'audio/wav'} + assert is_previewable('files', item) is True + + def test_other_format_not_previewable(self): + """Items with other format should not be previewable.""" + item = {'format': 'application/json'} + assert is_previewable('files', item) is False + + +class TestGetOutputsSummary: + """Unit tests for get_outputs_summary()""" + + def test_empty_outputs(self): + """Empty outputs should return 0 count and None preview.""" + count, preview = get_outputs_summary({}) + assert count == 0 + assert preview is None + + def test_counts_across_multiple_nodes(self): + """Outputs from multiple nodes should all be counted.""" + outputs = { + 'node1': {'images': [{'filename': 'a.png', 'type': 'output'}]}, + 'node2': {'images': [{'filename': 'b.png', 'type': 'output'}]}, + 'node3': {'images': [ + {'filename': 'c.png', 'type': 'output'}, + {'filename': 'd.png', 'type': 'output'} + ]} + } + count, preview = get_outputs_summary(outputs) + assert count == 4 + + def test_skips_animated_key_and_non_list_values(self): + """The 'animated' key and non-list values should be skipped.""" + outputs = { + 'node1': { + 'images': [{'filename': 'test.png', 'type': 'output'}], + 'animated': [True], # Should skip due to key name + 'metadata': 'string', # Should skip due to non-list + 'count': 42 # Should skip due to non-list + } + } + count, preview = get_outputs_summary(outputs) + assert count == 1 + + def test_preview_prefers_type_output(self): + """Items with type='output' should be preferred for preview.""" + outputs = { + 'node1': { + 'images': [ + {'filename': 'temp.png', 'type': 'temp'}, + {'filename': 'output.png', 'type': 'output'} + ] + } + } + count, preview = get_outputs_summary(outputs) + assert count == 2 + assert preview['filename'] == 'output.png' + + def test_preview_fallback_when_no_output_type(self): + """If no type='output', should use first previewable.""" + outputs = { + 'node1': { + 'images': [ + {'filename': 'temp1.png', 'type': 'temp'}, + {'filename': 'temp2.png', 'type': 'temp'} + ] + } + } + count, preview = get_outputs_summary(outputs) + assert preview['filename'] == 'temp1.png' + + def test_non_previewable_media_types_counted_but_no_preview(self): + """Non-previewable media types should be counted but not used as preview.""" + outputs = { + 'node1': { + 'latents': [ + {'filename': 'latent1.safetensors'}, + {'filename': 'latent2.safetensors'} + ] + } + } + count, preview = get_outputs_summary(outputs) + assert count == 2 + assert preview is None + + def test_previewable_media_types(self): + """Images, video, and audio media types should be previewable.""" + for media_type in ['images', 'video', 'audio']: + outputs = { + 'node1': { + media_type: [{'filename': 'test.file', 'type': 
'output'}] + } + } + count, preview = get_outputs_summary(outputs) + assert preview is not None, f"{media_type} should be previewable" + + def test_3d_files_previewable(self): + """3D file extensions should be previewable.""" + for ext in ['.obj', '.fbx', '.gltf', '.glb']: + outputs = { + 'node1': { + 'files': [{'filename': f'model{ext}', 'type': 'output'}] + } + } + count, preview = get_outputs_summary(outputs) + assert preview is not None, f"3D file {ext} should be previewable" + + def test_format_mime_type_previewable(self): + """Files with video/ or audio/ format should be previewable.""" + for fmt in ['video/x-custom', 'audio/x-custom']: + outputs = { + 'node1': { + 'files': [{'filename': 'file.custom', 'format': fmt, 'type': 'output'}] + } + } + count, preview = get_outputs_summary(outputs) + assert preview is not None, f"Format {fmt} should be previewable" + + def test_preview_enriched_with_node_metadata(self): + """Preview should include nodeId, mediaType, and original fields.""" + outputs = { + 'node123': { + 'images': [{'filename': 'test.png', 'type': 'output', 'subfolder': 'outputs'}] + } + } + count, preview = get_outputs_summary(outputs) + assert preview['nodeId'] == 'node123' + assert preview['mediaType'] == 'images' + assert preview['subfolder'] == 'outputs' + + +class TestApplySorting: + """Unit tests for apply_sorting()""" + + def test_sort_by_create_time_desc(self): + """Default sort by create_time descending.""" + jobs = [ + {'id': 'a', 'create_time': 100}, + {'id': 'b', 'create_time': 300}, + {'id': 'c', 'create_time': 200}, + ] + result = apply_sorting(jobs, 'created_at', 'desc') + assert [j['id'] for j in result] == ['b', 'c', 'a'] + + def test_sort_by_create_time_asc(self): + """Sort by create_time ascending.""" + jobs = [ + {'id': 'a', 'create_time': 100}, + {'id': 'b', 'create_time': 300}, + {'id': 'c', 'create_time': 200}, + ] + result = apply_sorting(jobs, 'created_at', 'asc') + assert [j['id'] for j in result] == ['a', 'c', 'b'] + + def test_sort_by_execution_duration(self): + """Sort by execution_duration should order by duration.""" + jobs = [ + {'id': 'a', 'create_time': 100, 'execution_start_time': 100, 'execution_end_time': 5100}, # 5s + {'id': 'b', 'create_time': 300, 'execution_start_time': 300, 'execution_end_time': 1300}, # 1s + {'id': 'c', 'create_time': 200, 'execution_start_time': 200, 'execution_end_time': 3200}, # 3s + ] + result = apply_sorting(jobs, 'execution_duration', 'desc') + assert [j['id'] for j in result] == ['a', 'c', 'b'] + + def test_sort_with_none_values(self): + """Jobs with None values should sort as 0.""" + jobs = [ + {'id': 'a', 'create_time': 100, 'execution_start_time': 100, 'execution_end_time': 5100}, + {'id': 'b', 'create_time': 300, 'execution_start_time': None, 'execution_end_time': None}, + {'id': 'c', 'create_time': 200, 'execution_start_time': 200, 'execution_end_time': 3200}, + ] + result = apply_sorting(jobs, 'execution_duration', 'asc') + assert result[0]['id'] == 'b' # None treated as 0, comes first + + +class TestNormalizeQueueItem: + """Unit tests for normalize_queue_item()""" + + def test_basic_normalization(self): + """Queue item should be normalized to job dict.""" + item = ( + 10, # priority/number + 'prompt-123', # prompt_id + {'nodes': {}}, # prompt + { + 'create_time': 1234567890, + 'extra_pnginfo': {'workflow': {'id': 'workflow-abc'}} + }, # extra_data + ['node1'], # outputs_to_execute + ) + job = normalize_queue_item(item, JobStatus.PENDING) + + assert job['id'] == 'prompt-123' + assert job['status'] == 
'pending' + assert job['priority'] == 10 + assert job['create_time'] == 1234567890 + assert 'execution_start_time' not in job + assert 'execution_end_time' not in job + assert 'execution_error' not in job + assert 'preview_output' not in job + assert job['outputs_count'] == 0 + assert job['workflow_id'] == 'workflow-abc' + + +class TestNormalizeHistoryItem: + """Unit tests for normalize_history_item()""" + + def test_completed_job(self): + """Completed history item should have correct status and times from messages.""" + history_item = { + 'prompt': ( + 5, # priority + 'prompt-456', + {'nodes': {}}, + { + 'create_time': 1234567890000, + 'extra_pnginfo': {'workflow': {'id': 'workflow-xyz'}} + }, + ['node1'], + ), + 'status': { + 'status_str': 'success', + 'completed': True, + 'messages': [ + ('execution_start', {'prompt_id': 'prompt-456', 'timestamp': 1234567890500}), + ('execution_success', {'prompt_id': 'prompt-456', 'timestamp': 1234567893000}), + ] + }, + 'outputs': {}, + } + job = normalize_history_item('prompt-456', history_item) + + assert job['id'] == 'prompt-456' + assert job['status'] == 'completed' + assert job['priority'] == 5 + assert job['execution_start_time'] == 1234567890500 + assert job['execution_end_time'] == 1234567893000 + assert job['workflow_id'] == 'workflow-xyz' + + def test_failed_job(self): + """Failed history item should have failed status and error from messages.""" + history_item = { + 'prompt': ( + 5, + 'prompt-789', + {'nodes': {}}, + {'create_time': 1234567890000}, + ['node1'], + ), + 'status': { + 'status_str': 'error', + 'completed': False, + 'messages': [ + ('execution_start', {'prompt_id': 'prompt-789', 'timestamp': 1234567890500}), + ('execution_error', { + 'prompt_id': 'prompt-789', + 'node_id': '5', + 'node_type': 'KSampler', + 'exception_message': 'CUDA out of memory', + 'exception_type': 'RuntimeError', + 'traceback': ['Traceback...', 'RuntimeError: CUDA out of memory'], + 'timestamp': 1234567891000, + }) + ] + }, + 'outputs': {}, + } + + job = normalize_history_item('prompt-789', history_item) + assert job['status'] == 'failed' + assert job['execution_start_time'] == 1234567890500 + assert job['execution_end_time'] == 1234567891000 + assert job['execution_error']['node_id'] == '5' + assert job['execution_error']['node_type'] == 'KSampler' + assert job['execution_error']['exception_message'] == 'CUDA out of memory' + + def test_include_outputs(self): + """When include_outputs=True, should include full output data.""" + history_item = { + 'prompt': ( + 5, + 'prompt-123', + {'nodes': {'1': {}}}, + {'create_time': 1234567890, 'client_id': 'abc'}, + ['node1'], + ), + 'status': {'status_str': 'success', 'completed': True, 'messages': []}, + 'outputs': {'node1': {'images': [{'filename': 'test.png'}]}}, + } + job = normalize_history_item('prompt-123', history_item, include_outputs=True) + + assert 'outputs' in job + assert 'workflow' in job + assert 'execution_status' in job + assert job['outputs'] == {'node1': {'images': [{'filename': 'test.png'}]}} + assert job['workflow'] == { + 'prompt': {'nodes': {'1': {}}}, + 'extra_data': {'create_time': 1234567890, 'client_id': 'abc'}, + } From e8ebbe668e82ab0f3c0842afa79d255329eb76ac Mon Sep 17 00:00:00 2001 From: ComfyUI Wiki Date: Fri, 19 Dec 2025 06:09:29 +0800 Subject: [PATCH 068/148] chore: update workflow templates to v0.7.60 (#11403) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 9b9e61683..54696395f 100644 --- 
a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.34.9 -comfyui-workflow-templates==0.7.59 +comfyui-workflow-templates==0.7.60 comfyui-embedded-docs==0.3.1 torch torchsde From e4fb3a3572c94d8f2ef73ddd18d2a6966ed5a1e5 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Thu, 18 Dec 2025 14:45:33 -0800 Subject: [PATCH 069/148] Support loading Wan/Qwen VAEs with different in/out channels. (#11405) --- comfy/ldm/wan/vae.py | 11 +++++++---- comfy/sd.py | 3 ++- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/comfy/ldm/wan/vae.py b/comfy/ldm/wan/vae.py index ccbb25822..08315f1a8 100644 --- a/comfy/ldm/wan/vae.py +++ b/comfy/ldm/wan/vae.py @@ -227,6 +227,7 @@ class Encoder3d(nn.Module): def __init__(self, dim=128, z_dim=4, + input_channels=3, dim_mult=[1, 2, 4, 4], num_res_blocks=2, attn_scales=[], @@ -245,7 +246,7 @@ class Encoder3d(nn.Module): scale = 1.0 # init block - self.conv1 = CausalConv3d(3, dims[0], 3, padding=1) + self.conv1 = CausalConv3d(input_channels, dims[0], 3, padding=1) # downsample blocks downsamples = [] @@ -331,6 +332,7 @@ class Decoder3d(nn.Module): def __init__(self, dim=128, z_dim=4, + output_channels=3, dim_mult=[1, 2, 4, 4], num_res_blocks=2, attn_scales=[], @@ -378,7 +380,7 @@ class Decoder3d(nn.Module): # output blocks self.head = nn.Sequential( RMS_norm(out_dim, images=False), nn.SiLU(), - CausalConv3d(out_dim, 3, 3, padding=1)) + CausalConv3d(out_dim, output_channels, 3, padding=1)) def forward(self, x, feat_cache=None, feat_idx=[0]): ## conv1 @@ -449,6 +451,7 @@ class WanVAE(nn.Module): num_res_blocks=2, attn_scales=[], temperal_downsample=[True, True, False], + image_channels=3, dropout=0.0): super().__init__() self.dim = dim @@ -460,11 +463,11 @@ class WanVAE(nn.Module): self.temperal_upsample = temperal_downsample[::-1] # modules - self.encoder = Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks, + self.encoder = Encoder3d(dim, z_dim * 2, image_channels, dim_mult, num_res_blocks, attn_scales, self.temperal_downsample, dropout) self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1) self.conv2 = CausalConv3d(z_dim, z_dim, 1) - self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks, + self.decoder = Decoder3d(dim, z_dim, image_channels, dim_mult, num_res_blocks, attn_scales, self.temperal_upsample, dropout) def encode(self, x): diff --git a/comfy/sd.py b/comfy/sd.py index 1cad98aef..f95c78892 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -546,7 +546,8 @@ class VAE: self.downscale_index_formula = (4, 8, 8) self.latent_dim = 3 self.latent_channels = 16 - ddconfig = {"dim": dim, "z_dim": self.latent_channels, "dim_mult": [1, 2, 4, 4], "num_res_blocks": 2, "attn_scales": [], "temperal_downsample": [False, True, True], "dropout": 0.0} + self.output_channels = sd["encoder.conv1.weight"].shape[1] + ddconfig = {"dim": dim, "z_dim": self.latent_channels, "dim_mult": [1, 2, 4, 4], "num_res_blocks": 2, "attn_scales": [], "temperal_downsample": [False, True, True], "image_channels": self.output_channels, "dropout": 0.0} self.first_stage_model = comfy.ldm.wan.vae.WanVAE(**ddconfig) self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32] self.memory_used_encode = lambda shape, dtype: (1500 if shape[2]<=4 else 6000) * shape[3] * shape[4] * model_management.dtype_size(dtype) From 6a2678ac65ff690e24771a4c64ce96f7a9824fa4 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Thu, 18 Dec 2025 15:22:38 -0800 Subject: 
[PATCH 070/148] Trim/pad channels in VAE code. (#11406) --- comfy/sd.py | 33 ++++++++++++++++++++++++--------- nodes.py | 4 ++-- 2 files changed, 26 insertions(+), 11 deletions(-) diff --git a/comfy/sd.py b/comfy/sd.py index f95c78892..c2a9728f3 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -321,6 +321,7 @@ class VAE: self.latent_channels = 4 self.latent_dim = 2 self.output_channels = 3 + self.pad_channel_value = None self.process_input = lambda image: image * 2.0 - 1.0 self.process_output = lambda image: torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0) self.working_dtypes = [torch.bfloat16, torch.float32] @@ -435,6 +436,7 @@ class VAE: self.memory_used_decode = lambda shape, dtype: (1000 * shape[2] * 2048) * model_management.dtype_size(dtype) self.latent_channels = 64 self.output_channels = 2 + self.pad_channel_value = "replicate" self.upscale_ratio = 2048 self.downscale_ratio = 2048 self.latent_dim = 1 @@ -547,6 +549,7 @@ class VAE: self.latent_dim = 3 self.latent_channels = 16 self.output_channels = sd["encoder.conv1.weight"].shape[1] + self.pad_channel_value = 1.0 ddconfig = {"dim": dim, "z_dim": self.latent_channels, "dim_mult": [1, 2, 4, 4], "num_res_blocks": 2, "attn_scales": [], "temperal_downsample": [False, True, True], "image_channels": self.output_channels, "dropout": 0.0} self.first_stage_model = comfy.ldm.wan.vae.WanVAE(**ddconfig) self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32] @@ -583,6 +586,7 @@ class VAE: self.memory_used_decode = lambda shape, dtype: (shape[2] * shape[3] * 87000) * model_management.dtype_size(dtype) self.latent_channels = 8 self.output_channels = 2 + self.pad_channel_value = "replicate" self.upscale_ratio = 4096 self.downscale_ratio = 4096 self.latent_dim = 2 @@ -691,17 +695,28 @@ class VAE: raise RuntimeError("ERROR: VAE is invalid: None\n\nIf the VAE is from a checkpoint loader node your checkpoint does not contain a valid VAE.") def vae_encode_crop_pixels(self, pixels): - if not self.crop_input: - return pixels + if self.crop_input: + downscale_ratio = self.spacial_compression_encode() - downscale_ratio = self.spacial_compression_encode() + dims = pixels.shape[1:-1] + for d in range(len(dims)): + x = (dims[d] // downscale_ratio) * downscale_ratio + x_offset = (dims[d] % downscale_ratio) // 2 + if x != dims[d]: + pixels = pixels.narrow(d + 1, x_offset, x) - dims = pixels.shape[1:-1] - for d in range(len(dims)): - x = (dims[d] // downscale_ratio) * downscale_ratio - x_offset = (dims[d] % downscale_ratio) // 2 - if x != dims[d]: - pixels = pixels.narrow(d + 1, x_offset, x) + if pixels.shape[-1] > self.output_channels: + pixels = pixels[..., :self.output_channels] + elif pixels.shape[-1] < self.output_channels: + if self.pad_channel_value is not None: + if isinstance(self.pad_channel_value, str): + mode = self.pad_channel_value + value = None + else: + mode = "constant" + value = self.pad_channel_value + + pixels = torch.nn.functional.pad(pixels, (0, self.output_channels - pixels.shape[-1]), mode=mode, value=value) return pixels def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap = 16): diff --git a/nodes.py b/nodes.py index 3fa543294..b13ceb578 100644 --- a/nodes.py +++ b/nodes.py @@ -343,7 +343,7 @@ class VAEEncode: CATEGORY = "latent" def encode(self, vae, pixels): - t = vae.encode(pixels[:,:,:,:3]) + t = vae.encode(pixels) return ({"samples":t}, ) class VAEEncodeTiled: @@ -361,7 +361,7 @@ class VAEEncodeTiled: CATEGORY = "_for_testing" def encode(self, vae, pixels, tile_size, overlap, temporal_size=64, 
temporal_overlap=8): - t = vae.encode_tiled(pixels[:,:,:,:3], tile_x=tile_size, tile_y=tile_size, overlap=overlap, tile_t=temporal_size, overlap_t=temporal_overlap) + t = vae.encode_tiled(pixels, tile_x=tile_size, tile_y=tile_size, overlap=overlap, tile_t=temporal_size, overlap_t=temporal_overlap) return ({"samples": t}, ) class VAEEncodeForInpaint: From 28eaab608bc34c4e3b1886b1bddbb429453249d8 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Thu, 18 Dec 2025 17:21:14 -0800 Subject: [PATCH 071/148] Diffusion model part of Qwen Image Layered. (#11408) Only thing missing after this is some nodes to make using it easier. --- comfy/ldm/qwen_image/model.py | 63 ++++++++++++++++++++++------------- comfy/model_detection.py | 3 ++ 2 files changed, 42 insertions(+), 24 deletions(-) diff --git a/comfy/ldm/qwen_image/model.py b/comfy/ldm/qwen_image/model.py index 902af30ed..00c597535 100644 --- a/comfy/ldm/qwen_image/model.py +++ b/comfy/ldm/qwen_image/model.py @@ -61,7 +61,7 @@ def apply_rotary_emb(x, freqs_cis): class QwenTimestepProjEmbeddings(nn.Module): - def __init__(self, embedding_dim, pooled_projection_dim, dtype=None, device=None, operations=None): + def __init__(self, embedding_dim, pooled_projection_dim, use_additional_t_cond=False, dtype=None, device=None, operations=None): super().__init__() self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0, scale=1000) self.timestep_embedder = TimestepEmbedding( @@ -72,9 +72,19 @@ class QwenTimestepProjEmbeddings(nn.Module): operations=operations ) - def forward(self, timestep, hidden_states): + self.use_additional_t_cond = use_additional_t_cond + if self.use_additional_t_cond: + self.addition_t_embedding = operations.Embedding(2, embedding_dim, device=device, dtype=dtype) + + def forward(self, timestep, hidden_states, addition_t_cond=None): timesteps_proj = self.time_proj(timestep) timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_states.dtype)) + + if self.use_additional_t_cond: + if addition_t_cond is None: + addition_t_cond = torch.zeros((timesteps_emb.shape[0]), device=timesteps_emb.device, dtype=torch.long) + timesteps_emb += self.addition_t_embedding(addition_t_cond, out_dtype=timesteps_emb.dtype) + return timesteps_emb @@ -320,11 +330,11 @@ class QwenImageTransformer2DModel(nn.Module): num_attention_heads: int = 24, joint_attention_dim: int = 3584, pooled_projection_dim: int = 768, - guidance_embeds: bool = False, axes_dims_rope: Tuple[int, int, int] = (16, 56, 56), default_ref_method="index", image_model=None, final_layer=True, + use_additional_t_cond=False, dtype=None, device=None, operations=None, @@ -342,6 +352,7 @@ class QwenImageTransformer2DModel(nn.Module): self.time_text_embed = QwenTimestepProjEmbeddings( embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim, + use_additional_t_cond=use_additional_t_cond, dtype=dtype, device=device, operations=operations @@ -375,27 +386,33 @@ class QwenImageTransformer2DModel(nn.Module): patch_size = self.patch_size hidden_states = comfy.ldm.common_dit.pad_to_patch_size(x, (1, self.patch_size, self.patch_size)) orig_shape = hidden_states.shape - hidden_states = hidden_states.view(orig_shape[0], orig_shape[1], orig_shape[-2] // 2, 2, orig_shape[-1] // 2, 2) - hidden_states = hidden_states.permute(0, 2, 4, 1, 3, 5) - hidden_states = hidden_states.reshape(orig_shape[0], (orig_shape[-2] // 2) * (orig_shape[-1] // 2), orig_shape[1] * 4) + hidden_states = 
hidden_states.view(orig_shape[0], orig_shape[1], orig_shape[-3], orig_shape[-2] // 2, 2, orig_shape[-1] // 2, 2) + hidden_states = hidden_states.permute(0, 2, 3, 5, 1, 4, 6) + hidden_states = hidden_states.reshape(orig_shape[0], orig_shape[-3] * (orig_shape[-2] // 2) * (orig_shape[-1] // 2), orig_shape[1] * 4) + t_len = t h_len = ((h + (patch_size // 2)) // patch_size) w_len = ((w + (patch_size // 2)) // patch_size) h_offset = ((h_offset + (patch_size // 2)) // patch_size) w_offset = ((w_offset + (patch_size // 2)) // patch_size) - img_ids = torch.zeros((h_len, w_len, 3), device=x.device) - img_ids[:, :, 0] = img_ids[:, :, 1] + index - img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1) - (h_len // 2) - img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0) - (w_len // 2) - return hidden_states, repeat(img_ids, "h w c -> b (h w) c", b=bs), orig_shape + img_ids = torch.zeros((t_len, h_len, w_len, 3), device=x.device) - def forward(self, x, timestep, context, attention_mask=None, guidance=None, ref_latents=None, transformer_options={}, **kwargs): + if t_len > 1: + img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(0, t_len - 1, steps=t_len, device=x.device, dtype=x.dtype).unsqueeze(1).unsqueeze(1) + else: + img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + index + + img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1).unsqueeze(0) - (h_len // 2) + img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0).unsqueeze(0) - (w_len // 2) + return hidden_states, repeat(img_ids, "t h w c -> b (t h w) c", b=bs), orig_shape + + def forward(self, x, timestep, context, attention_mask=None, ref_latents=None, additional_t_cond=None, transformer_options={}, **kwargs): return comfy.patcher_extension.WrapperExecutor.new_class_executor( self._forward, self, comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options) - ).execute(x, timestep, context, attention_mask, guidance, ref_latents, transformer_options, **kwargs) + ).execute(x, timestep, context, attention_mask, ref_latents, additional_t_cond, transformer_options, **kwargs) def _forward( self, @@ -403,8 +420,8 @@ class QwenImageTransformer2DModel(nn.Module): timesteps, context, attention_mask=None, - guidance: torch.Tensor = None, ref_latents=None, + additional_t_cond=None, transformer_options={}, control=None, **kwargs @@ -423,12 +440,17 @@ class QwenImageTransformer2DModel(nn.Module): index = 0 ref_method = kwargs.get("ref_latents_method", self.default_ref_method) index_ref_method = (ref_method == "index") or (ref_method == "index_timestep_zero") + negative_ref_method = ref_method == "negative_index" timestep_zero = ref_method == "index_timestep_zero" for ref in ref_latents: if index_ref_method: index += 1 h_offset = 0 w_offset = 0 + elif negative_ref_method: + index -= 1 + h_offset = 0 + w_offset = 0 else: index = 1 h_offset = 0 @@ -458,14 +480,7 @@ class QwenImageTransformer2DModel(nn.Module): encoder_hidden_states = self.txt_norm(encoder_hidden_states) encoder_hidden_states = self.txt_in(encoder_hidden_states) - if guidance is not None: - guidance = guidance * 1000 - - temb = ( - self.time_text_embed(timestep, hidden_states) - if 
guidance is None - else self.time_text_embed(timestep, guidance, hidden_states) - ) + temb = self.time_text_embed(timestep, hidden_states, additional_t_cond) patches_replace = transformer_options.get("patches_replace", {}) patches = transformer_options.get("patches", {}) @@ -513,6 +528,6 @@ class QwenImageTransformer2DModel(nn.Module): hidden_states = self.norm_out(hidden_states, temb) hidden_states = self.proj_out(hidden_states) - hidden_states = hidden_states[:, :num_embeds].view(orig_shape[0], orig_shape[-2] // 2, orig_shape[-1] // 2, orig_shape[1], 2, 2) - hidden_states = hidden_states.permute(0, 3, 1, 4, 2, 5) + hidden_states = hidden_states[:, :num_embeds].view(orig_shape[0], orig_shape[-3], orig_shape[-2] // 2, orig_shape[-1] // 2, orig_shape[1], 2, 2) + hidden_states = hidden_states.permute(0, 4, 1, 2, 5, 3, 6) return hidden_states.reshape(orig_shape)[:, :, :, :x.shape[-2], :x.shape[-1]] diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 7148c77fd..84fd409fd 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -620,6 +620,9 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["num_layers"] = count_blocks(state_dict_keys, '{}transformer_blocks.'.format(key_prefix) + '{}.') if "{}__index_timestep_zero__".format(key_prefix) in state_dict_keys: # 2511 dit_config["default_ref_method"] = "index_timestep_zero" + if "{}time_text_embed.addition_t_embedding.weight".format(key_prefix) in state_dict_keys: # Layered + dit_config["use_additional_t_cond"] = True + dit_config["default_ref_method"] = "negative_index" return dit_config if '{}visual_transformer_blocks.0.cross_attention.key_norm.weight'.format(key_prefix) in state_dict_keys: # Kandinsky 5 From 894802b0f9c3a247f5609db89ec3be24eac7fd2b Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Thu, 18 Dec 2025 19:21:40 -0800 Subject: [PATCH 072/148] Add LatentCutToBatch node. 
(#11411) --- comfy_extras/nodes_latent.py | 43 ++++++++++++++++++++++++++++++++++++ 1 file changed, 43 insertions(+) diff --git a/comfy_extras/nodes_latent.py b/comfy_extras/nodes_latent.py index e439b18ef..2815c5ffc 100644 --- a/comfy_extras/nodes_latent.py +++ b/comfy_extras/nodes_latent.py @@ -5,6 +5,7 @@ import nodes from typing_extensions import override from comfy_api.latest import ComfyExtension, io import logging +import math def reshape_latent_to(target_shape, latent, repeat_batch=True): if latent.shape[1:] != target_shape[1:]: @@ -207,6 +208,47 @@ class LatentCut(io.ComfyNode): samples_out["samples"] = torch.narrow(s1, dim, index, amount) return io.NodeOutput(samples_out) +class LatentCutToBatch(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="LatentCutToBatch", + category="latent/advanced", + inputs=[ + io.Latent.Input("samples"), + io.Combo.Input("dim", options=["t", "x", "y"]), + io.Int.Input("slice_size", default=1, min=1, max=nodes.MAX_RESOLUTION, step=1), + ], + outputs=[ + io.Latent.Output(), + ], + ) + + @classmethod + def execute(cls, samples, dim, slice_size) -> io.NodeOutput: + samples_out = samples.copy() + + s1 = samples["samples"] + + if "x" in dim: + dim = s1.ndim - 1 + elif "y" in dim: + dim = s1.ndim - 2 + elif "t" in dim: + dim = s1.ndim - 3 + + if dim < 2: + return io.NodeOutput(samples) + + s = s1.movedim(dim, 1) + if s.shape[1] < slice_size: + slice_size = s.shape[1] + elif s.shape[1] % slice_size != 0: + s = s[:, :math.floor(s.shape[1] / slice_size) * slice_size] + new_shape = [-1, slice_size] + list(s.shape[2:]) + samples_out["samples"] = s.reshape(new_shape).movedim(1, dim) + return io.NodeOutput(samples_out) + class LatentBatch(io.ComfyNode): @classmethod def define_schema(cls): @@ -435,6 +477,7 @@ class LatentExtension(ComfyExtension): LatentInterpolate, LatentConcat, LatentCut, + LatentCutToBatch, LatentBatch, LatentBatchSeedBehavior, LatentApplyOperation, From 5b4d0664c87dc62a8361fe292b0bdac42681aef8 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Fri, 19 Dec 2025 20:02:49 +0200 Subject: [PATCH 073/148] add Flux2MaxImage API Node (#11420) --- comfy_api_nodes/nodes_bfl.py | 68 ++++++++++++++++-------------------- 1 file changed, 31 insertions(+), 37 deletions(-) diff --git a/comfy_api_nodes/nodes_bfl.py b/comfy_api_nodes/nodes_bfl.py index 8826dea0c..ce077d6b3 100644 --- a/comfy_api_nodes/nodes_bfl.py +++ b/comfy_api_nodes/nodes_bfl.py @@ -1,10 +1,8 @@ -from inspect import cleandoc - import torch from pydantic import BaseModel from typing_extensions import override -from comfy_api.latest import IO, ComfyExtension +from comfy_api.latest import IO, ComfyExtension, Input from comfy_api_nodes.apis.bfl_api import ( BFLFluxExpandImageRequest, BFLFluxFillImageRequest, @@ -28,7 +26,7 @@ from comfy_api_nodes.util import ( ) -def convert_mask_to_image(mask: torch.Tensor): +def convert_mask_to_image(mask: Input.Image): """ Make mask have the expected amount of dims (4) and channels (3) to be recognized as an image. """ @@ -38,9 +36,6 @@ def convert_mask_to_image(mask: torch.Tensor): class FluxProUltraImageNode(IO.ComfyNode): - """ - Generates images using Flux Pro 1.1 Ultra via api based on prompt and resolution. 
- """ @classmethod def define_schema(cls) -> IO.Schema: @@ -48,7 +43,7 @@ class FluxProUltraImageNode(IO.ComfyNode): node_id="FluxProUltraImageNode", display_name="Flux 1.1 [pro] Ultra Image", category="api node/image/BFL", - description=cleandoc(cls.__doc__ or ""), + description="Generates images using Flux Pro 1.1 Ultra via api based on prompt and resolution.", inputs=[ IO.String.Input( "prompt", @@ -117,7 +112,7 @@ class FluxProUltraImageNode(IO.ComfyNode): prompt_upsampling: bool = False, raw: bool = False, seed: int = 0, - image_prompt: torch.Tensor | None = None, + image_prompt: Input.Image | None = None, image_prompt_strength: float = 0.1, ) -> IO.NodeOutput: if image_prompt is None: @@ -155,9 +150,6 @@ class FluxProUltraImageNode(IO.ComfyNode): class FluxKontextProImageNode(IO.ComfyNode): - """ - Edits images using Flux.1 Kontext [pro] via api based on prompt and aspect ratio. - """ @classmethod def define_schema(cls) -> IO.Schema: @@ -165,7 +157,7 @@ class FluxKontextProImageNode(IO.ComfyNode): node_id=cls.NODE_ID, display_name=cls.DISPLAY_NAME, category="api node/image/BFL", - description=cleandoc(cls.__doc__ or ""), + description="Edits images using Flux.1 Kontext [pro] via api based on prompt and aspect ratio.", inputs=[ IO.String.Input( "prompt", @@ -231,7 +223,7 @@ class FluxKontextProImageNode(IO.ComfyNode): aspect_ratio: str, guidance: float, steps: int, - input_image: torch.Tensor | None = None, + input_image: Input.Image | None = None, seed=0, prompt_upsampling=False, ) -> IO.NodeOutput: @@ -271,20 +263,14 @@ class FluxKontextProImageNode(IO.ComfyNode): class FluxKontextMaxImageNode(FluxKontextProImageNode): - """ - Edits images using Flux.1 Kontext [max] via api based on prompt and aspect ratio. - """ - DESCRIPTION = cleandoc(__doc__ or "") + DESCRIPTION = "Edits images using Flux.1 Kontext [max] via api based on prompt and aspect ratio." BFL_PATH = "/proxy/bfl/flux-kontext-max/generate" NODE_ID = "FluxKontextMaxImageNode" DISPLAY_NAME = "Flux.1 Kontext [max] Image" class FluxProExpandNode(IO.ComfyNode): - """ - Outpaints image based on prompt. - """ @classmethod def define_schema(cls) -> IO.Schema: @@ -292,7 +278,7 @@ class FluxProExpandNode(IO.ComfyNode): node_id="FluxProExpandNode", display_name="Flux.1 Expand Image", category="api node/image/BFL", - description=cleandoc(cls.__doc__ or ""), + description="Outpaints image based on prompt.", inputs=[ IO.Image.Input("image"), IO.String.Input( @@ -371,7 +357,7 @@ class FluxProExpandNode(IO.ComfyNode): @classmethod async def execute( cls, - image: torch.Tensor, + image: Input.Image, prompt: str, prompt_upsampling: bool, top: int, @@ -418,9 +404,6 @@ class FluxProExpandNode(IO.ComfyNode): class FluxProFillNode(IO.ComfyNode): - """ - Inpaints image based on mask and prompt. 
- """ @classmethod def define_schema(cls) -> IO.Schema: @@ -428,7 +411,7 @@ class FluxProFillNode(IO.ComfyNode): node_id="FluxProFillNode", display_name="Flux.1 Fill Image", category="api node/image/BFL", - description=cleandoc(cls.__doc__ or ""), + description="Inpaints image based on mask and prompt.", inputs=[ IO.Image.Input("image"), IO.Mask.Input("mask"), @@ -480,8 +463,8 @@ class FluxProFillNode(IO.ComfyNode): @classmethod async def execute( cls, - image: torch.Tensor, - mask: torch.Tensor, + image: Input.Image, + mask: Input.Image, prompt: str, prompt_upsampling: bool, steps: int, @@ -525,11 +508,15 @@ class FluxProFillNode(IO.ComfyNode): class Flux2ProImageNode(IO.ComfyNode): + NODE_ID = "Flux2ProImageNode" + DISPLAY_NAME = "Flux.2 [pro] Image" + API_ENDPOINT = "/proxy/bfl/flux-2-pro/generate" + @classmethod def define_schema(cls) -> IO.Schema: return IO.Schema( - node_id="Flux2ProImageNode", - display_name="Flux.2 [pro] Image", + node_id=cls.NODE_ID, + display_name=cls.DISPLAY_NAME, category="api node/image/BFL", description="Generates images synchronously based on prompt and resolution.", inputs=[ @@ -563,12 +550,11 @@ class Flux2ProImageNode(IO.ComfyNode): ), IO.Boolean.Input( "prompt_upsampling", - default=False, + default=True, tooltip="Whether to perform upsampling on the prompt. " - "If active, automatically modifies the prompt for more creative generation, " - "but results are nondeterministic (same seed will not produce exactly the same result).", + "If active, automatically modifies the prompt for more creative generation.", ), - IO.Image.Input("images", optional=True, tooltip="Up to 4 images to be used as references."), + IO.Image.Input("images", optional=True, tooltip="Up to 9 images to be used as references."), ], outputs=[IO.Image.Output()], hidden=[ @@ -587,7 +573,7 @@ class Flux2ProImageNode(IO.ComfyNode): height: int, seed: int, prompt_upsampling: bool, - images: torch.Tensor | None = None, + images: Input.Image | None = None, ) -> IO.NodeOutput: reference_images = {} if images is not None: @@ -598,7 +584,7 @@ class Flux2ProImageNode(IO.ComfyNode): reference_images[key_name] = tensor_to_base64_string(images[image_index], total_pixels=2048 * 2048) initial_response = await sync_op( cls, - ApiEndpoint(path="/proxy/bfl/flux-2-pro/generate", method="POST"), + ApiEndpoint(path=cls.API_ENDPOINT, method="POST"), response_model=BFLFluxProGenerateResponse, data=Flux2ProGenerateRequest( prompt=prompt, @@ -632,6 +618,13 @@ class Flux2ProImageNode(IO.ComfyNode): return IO.NodeOutput(await download_url_to_image_tensor(response.result["sample"])) +class Flux2MaxImageNode(Flux2ProImageNode): + + NODE_ID = "Flux2MaxImageNode" + DISPLAY_NAME = "Flux.2 [max] Image" + API_ENDPOINT = "/proxy/bfl/flux-2-max/generate" + + class BFLExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[IO.ComfyNode]]: @@ -642,6 +635,7 @@ class BFLExtension(ComfyExtension): FluxProExpandNode, FluxProFillNode, Flux2ProImageNode, + Flux2MaxImageNode, ] From 8376ff6831b145eadc3339e1901ffe02386ab86a Mon Sep 17 00:00:00 2001 From: "Dr.Lt.Data" <128333288+ltdrdata@users.noreply.github.com> Date: Sat, 20 Dec 2025 03:41:56 +0900 Subject: [PATCH 074/148] bump comfyui_manager version to the 4.0.3b7 (#11422) --- manager_requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/manager_requirements.txt b/manager_requirements.txt index 5ef0d3a1d..2300f0c70 100644 --- a/manager_requirements.txt +++ b/manager_requirements.txt @@ -1 +1 @@ -comfyui_manager==4.0.3b5 
+comfyui_manager==4.0.3b7 From cc4ddba1b68abdc64ef5a701fd0571fcf2faf98d Mon Sep 17 00:00:00 2001 From: BradPepersAMD Date: Fri, 19 Dec 2025 15:01:50 -0700 Subject: [PATCH 075/148] Allow enabling use of MIOpen by setting COMFYUI_ENABLE_MIOPEN=1 as an env var (#11366) --- comfy/model_management.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 40717b1e4..1889ab0ac 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -26,6 +26,7 @@ import importlib import platform import weakref import gc +import os class VRAMState(Enum): DISABLED = 0 #No vram present: no need to move models to vram @@ -333,13 +334,15 @@ except: SUPPORT_FP8_OPS = args.supports_fp8_compute AMD_RDNA2_AND_OLDER_ARCH = ["gfx1030", "gfx1031", "gfx1010", "gfx1011", "gfx1012", "gfx906", "gfx900", "gfx803"] +AMD_ENABLE_MIOPEN_ENV = 'COMFYUI_ENABLE_MIOPEN' try: if is_amd(): arch = torch.cuda.get_device_properties(get_torch_device()).gcnArchName if not (any((a in arch) for a in AMD_RDNA2_AND_OLDER_ARCH)): - torch.backends.cudnn.enabled = False # Seems to improve things a lot on AMD - logging.info("Set: torch.backends.cudnn.enabled = False for better AMD performance.") + if os.getenv(AMD_ENABLE_MIOPEN_ENV) != '1': + torch.backends.cudnn.enabled = False # Seems to improve things a lot on AMD + logging.info("Set: torch.backends.cudnn.enabled = False for better AMD performance.") try: rocm_version = tuple(map(int, str(torch.version.hip).split(".")[:2])) From 809ce687493db84f6743639adf9b600753b6188e Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Fri, 19 Dec 2025 16:59:25 -0800 Subject: [PATCH 076/148] Support nested tensor denoise masks. (#11431) --- comfy/samplers.py | 21 ++++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/comfy/samplers.py b/comfy/samplers.py index 8340d376c..1989ef107 100755 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -984,9 +984,6 @@ class CFGGuider: self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds, self.model_options) device = self.model_patcher.load_device - if denoise_mask is not None: - denoise_mask = comfy.sampler_helpers.prepare_mask(denoise_mask, noise.shape, device) - noise = noise.to(device) latent_image = latent_image.to(device) sigmas = sigmas.to(device) @@ -1013,6 +1010,24 @@ class CFGGuider: else: latent_shapes = [latent_image.shape] + if denoise_mask is not None: + if denoise_mask.is_nested: + denoise_masks = denoise_mask.unbind() + denoise_masks = denoise_masks[:len(latent_shapes)] + else: + denoise_masks = [denoise_mask] + + for i in range(len(denoise_masks), len(latent_shapes)): + denoise_masks.append(torch.ones(latent_shapes[i])) + + for i in range(len(denoise_masks)): + denoise_masks[i] = comfy.sampler_helpers.prepare_mask(denoise_masks[i], latent_shapes[i], self.model_patcher.load_device) + + if len(denoise_masks) > 1: + denoise_mask, _ = comfy.utils.pack_latents(denoise_masks) + else: + denoise_mask = denoise_masks[0] + self.conds = {} for k in self.original_conds: self.conds[k] = list(map(lambda a: a.copy(), self.original_conds[k])) From 514c24d756997c3131c57aa21578a09429096eca Mon Sep 17 00:00:00 2001 From: drozbay <17261091+drozbay@users.noreply.github.com> Date: Fri, 19 Dec 2025 21:22:45 -0700 Subject: [PATCH 077/148] Fix error from logging line (#11423) Co-authored-by: ozbayb 
<17261091+ozbayb@users.noreply.github.com> --- comfy/context_windows.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/context_windows.py b/comfy/context_windows.py index 2979b3ca1..1e0f86026 100644 --- a/comfy/context_windows.py +++ b/comfy/context_windows.py @@ -143,7 +143,7 @@ class IndexListContextHandler(ContextHandlerABC): # if multiple conds, split based on primary region if self.split_conds_to_windows and len(cond_in) > 1: region = window.get_region_index(len(cond_in)) - logging.info(f"Splitting conds to windows; using region {region} for window {window[0]}-{window[-1]} with center ratio {window.center_ratio:.3f}") + logging.info(f"Splitting conds to windows; using region {region} for window {window.index_list[0]}-{window.index_list[-1]} with center ratio {window.center_ratio:.3f}") cond_in = [cond_in[region]] # cond object is a list containing a dict - outer list is irrelevant, so just loop through it for actual_cond in cond_in: From 0aa7fa464efc4ecc35a145048c06d325c75fbf2b Mon Sep 17 00:00:00 2001 From: woctordho Date: Sat, 20 Dec 2025 13:16:46 +0800 Subject: [PATCH 078/148] Implement sliding attention in Gemma3 (#11409) --- comfy/text_encoders/llama.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/comfy/text_encoders/llama.py b/comfy/text_encoders/llama.py index 0d07ac8c6..ed29e014d 100644 --- a/comfy/text_encoders/llama.py +++ b/comfy/text_encoders/llama.py @@ -3,7 +3,6 @@ import torch.nn as nn from dataclasses import dataclass from typing import Optional, Any import math -import logging from comfy.ldm.modules.attention import optimized_attention_for_device import comfy.model_management @@ -177,7 +176,7 @@ class Gemma3_4B_Config: num_key_value_heads: int = 4 max_position_embeddings: int = 131072 rms_norm_eps: float = 1e-6 - rope_theta = [10000.0, 1000000.0] + rope_theta = [1000000.0, 10000.0] transformer_type: str = "gemma3" head_dim = 256 rms_norm_add = True @@ -186,8 +185,8 @@ class Gemma3_4B_Config: rope_dims = None q_norm = "gemma3" k_norm = "gemma3" - sliding_attention = [False, False, False, False, False, 1024] - rope_scale = [1.0, 8.0] + sliding_attention = [1024, 1024, 1024, 1024, 1024, False] + rope_scale = [8.0, 1.0] final_norm: bool = True class RMSNorm(nn.Module): @@ -370,7 +369,7 @@ class TransformerBlockGemma2(nn.Module): self.pre_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype) self.post_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype) - if config.sliding_attention is not None: # TODO: implement. 
(Not that necessary since models are trained on less than 1024 tokens) + if config.sliding_attention is not None: self.sliding_attention = config.sliding_attention[index % len(config.sliding_attention)] else: self.sliding_attention = False @@ -387,7 +386,12 @@ class TransformerBlockGemma2(nn.Module): if self.transformer_type == 'gemma3': if self.sliding_attention: if x.shape[1] > self.sliding_attention: - logging.warning("Warning: sliding attention not implemented, results may be incorrect") + sliding_mask = torch.full((x.shape[1], x.shape[1]), float("-inf"), device=x.device, dtype=x.dtype) + sliding_mask.tril_(diagonal=-self.sliding_attention) + if attention_mask is not None: + attention_mask = attention_mask + sliding_mask + else: + attention_mask = sliding_mask freqs_cis = freqs_cis[1] else: freqs_cis = freqs_cis[0] From 3ab9748903a8ee51f62ae8d3eeebc1f98847f4bd Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Fri, 19 Dec 2025 21:19:47 -0800 Subject: [PATCH 079/148] Disable prompt weights on newbie te. (#11434) --- comfy/sd1_clip.py | 6 ++++-- comfy/text_encoders/lumina2.py | 2 +- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index 962948dae..c512ca5d0 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -466,7 +466,7 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=No return embed_out class SDTokenizer: - def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, has_end_token=True, pad_to_max_length=True, min_length=None, pad_token=None, end_token=None, min_padding=None, pad_left=False, tokenizer_data={}, tokenizer_args={}): + def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, has_end_token=True, pad_to_max_length=True, min_length=None, pad_token=None, end_token=None, min_padding=None, pad_left=False, disable_weights=False, tokenizer_data={}, tokenizer_args={}): if tokenizer_path is None: tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer") self.tokenizer = tokenizer_class.from_pretrained(tokenizer_path, **tokenizer_args) @@ -513,6 +513,8 @@ class SDTokenizer: self.embedding_size = embedding_size self.embedding_key = embedding_key + self.disable_weights = disable_weights + def _try_get_embedding(self, embedding_name:str): ''' Takes a potential embedding name and tries to retrieve it. 
@@ -547,7 +549,7 @@ class SDTokenizer:
             min_padding = tokenizer_options.get("{}_min_padding".format(self.embedding_key), self.min_padding)
 
         text = escape_important(text)
-        if kwargs.get("disable_weights", False):
+        if kwargs.get("disable_weights", self.disable_weights):
             parsed_weights = [(text, 1.0)]
         else:
             parsed_weights = token_weights(text, 1.0)
diff --git a/comfy/text_encoders/lumina2.py b/comfy/text_encoders/lumina2.py
index 7a6cfdab2..f82883ba1 100644
--- a/comfy/text_encoders/lumina2.py
+++ b/comfy/text_encoders/lumina2.py
@@ -14,7 +14,7 @@ class Gemma2BTokenizer(sd1_clip.SDTokenizer):
 class Gemma3_4BTokenizer(sd1_clip.SDTokenizer):
     def __init__(self, embedding_directory=None, tokenizer_data={}):
         tokenizer = tokenizer_data.get("spiece_model", None)
-        super().__init__(tokenizer, pad_with_end=False, embedding_size=2560, embedding_key='gemma3_4b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_args={"add_bos": True, "add_eos": False}, tokenizer_data=tokenizer_data)
+        super().__init__(tokenizer, pad_with_end=False, embedding_size=2560, embedding_key='gemma3_4b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_args={"add_bos": True, "add_eos": False}, disable_weights=True, tokenizer_data=tokenizer_data)
 
     def state_dict(self):
         return {"spiece_model": self.tokenizer.serialize_model()}

From 767ee30f217e72797df6b018417234bf8b3f7b69 Mon Sep 17 00:00:00 2001
From: rattus <46076784+rattus128@users.noreply.github.com>
Date: Fri, 19 Dec 2025 21:22:17 -0800
Subject: [PATCH 080/148] ZImageFunControlNet: Fix mask concatenation in --gpu-only (#11421)

This operation trades in latents which, under --gpu-only, may not be on the GPU.
The two VAE results will follow the --gpu-only defined behaviour, so follow
the inpaint image device when calculating the mask in this path.
---
 comfy_extras/nodes_model_patch.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/comfy_extras/nodes_model_patch.py b/comfy_extras/nodes_model_patch.py
index 2a0cfcf18..1355b3c93 100644
--- a/comfy_extras/nodes_model_patch.py
+++ b/comfy_extras/nodes_model_patch.py
@@ -348,7 +348,7 @@ class ZImageControlPatch:
         if self.mask is None:
             mask_ = torch.zeros_like(inpaint_image_latent)[:, :1]
         else:
-            mask_ = comfy.utils.common_upscale(self.mask.view(self.mask.shape[0], -1, self.mask.shape[-2], self.mask.shape[-1]).mean(dim=1, keepdim=True), inpaint_image_latent.shape[-1], inpaint_image_latent.shape[-2], "nearest", "center")
+            mask_ = comfy.utils.common_upscale(self.mask.view(self.mask.shape[0], -1, self.mask.shape[-2], self.mask.shape[-1]).mean(dim=1, keepdim=True).to(device=inpaint_image_latent.device), inpaint_image_latent.shape[-1], inpaint_image_latent.shape[-2], "nearest", "center")
 
         if latent_image is None:
             latent_image = comfy.latent_formats.Flux().process_in(self.vae.encode(torch.ones_like(inpaint_image) * 0.5))

From 31e961736a476851e2579d5d9202ed4177a71720 Mon Sep 17 00:00:00 2001
From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com>
Date: Fri, 19 Dec 2025 21:23:51 -0800
Subject: [PATCH 081/148] Fix issue with batches and newbie. 
(#11435) --- comfy/ldm/lumina/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/ldm/lumina/model.py b/comfy/ldm/lumina/model.py index 5628e2ba3..e80b1c138 100644 --- a/comfy/ldm/lumina/model.py +++ b/comfy/ldm/lumina/model.py @@ -625,7 +625,7 @@ class NextDiT(nn.Module): if pooled is not None: pooled = self.clip_text_pooled_proj(pooled) else: - pooled = torch.zeros((1, self.clip_text_dim), device=x.device, dtype=x.dtype) + pooled = torch.zeros((x.shape[0], self.clip_text_dim), device=x.device, dtype=x.dtype) adaln_input = self.time_text_embed(torch.cat((t, pooled), dim=-1)) From 4c432c11ed6f83466b8ff02569872925753a3c44 Mon Sep 17 00:00:00 2001 From: woctordho Date: Sat, 20 Dec 2025 13:57:22 +0800 Subject: [PATCH 082/148] Implement Jina CLIP v2 and NewBie dual CLIP (#11415) * Implement Jina CLIP v2 * Support quantized Gemma in NewBie dual CLIP --- comfy/model_base.py | 2 +- comfy/model_detection.py | 3 +- comfy/sd.py | 20 +++ comfy/text_encoders/jina_clip_2.py | 219 +++++++++++++++++++++++++++++ comfy/text_encoders/newbie.py | 62 ++++++++ nodes.py | 4 +- 6 files changed, 306 insertions(+), 4 deletions(-) create mode 100644 comfy/text_encoders/jina_clip_2.py create mode 100644 comfy/text_encoders/newbie.py diff --git a/comfy/model_base.py b/comfy/model_base.py index 6b8a8454d..c4f3c0639 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1110,7 +1110,7 @@ class Lumina2(BaseModel): if 'num_tokens' not in out: out['num_tokens'] = comfy.conds.CONDConstant(cross_attn.shape[1]) - clip_text_pooled = kwargs["pooled_output"] # Newbie + clip_text_pooled = kwargs.get("pooled_output", None) # NewBie if clip_text_pooled is not None: out['clip_text_pooled'] = comfy.conds.CONDRegular(clip_text_pooled) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 84fd409fd..539e296ed 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -430,8 +430,9 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["rope_theta"] = 10000.0 dit_config["ffn_dim_multiplier"] = 4.0 ctd_weight = state_dict.get('{}clip_text_pooled_proj.0.weight'.format(key_prefix), None) - if ctd_weight is not None: + if ctd_weight is not None: # NewBie dit_config["clip_text_dim"] = ctd_weight.shape[0] + # NewBie also sets axes_lens = [1024, 512, 512] but it's not used in ComfyUI elif dit_config["dim"] == 3840: # Z image dit_config["n_heads"] = 30 dit_config["n_kv_heads"] = 30 diff --git a/comfy/sd.py b/comfy/sd.py index c2a9728f3..7de7dd9c6 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -55,6 +55,8 @@ import comfy.text_encoders.hunyuan_image import comfy.text_encoders.z_image import comfy.text_encoders.ovis import comfy.text_encoders.kandinsky5 +import comfy.text_encoders.jina_clip_2 +import comfy.text_encoders.newbie import comfy.model_patcher import comfy.lora @@ -1008,6 +1010,7 @@ class CLIPType(Enum): OVIS = 21 KANDINSKY5 = 22 KANDINSKY5_IMAGE = 23 + NEWBIE = 24 def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}): @@ -1038,6 +1041,7 @@ class TEModel(Enum): MISTRAL3_24B_PRUNED_FLUX2 = 15 QWEN3_4B = 16 QWEN3_2B = 17 + JINA_CLIP_2 = 18 def detect_te_model(sd): @@ -1047,6 +1051,8 @@ def detect_te_model(sd): return TEModel.CLIP_H if "text_model.encoder.layers.0.mlp.fc1.weight" in sd: return TEModel.CLIP_L + if "model.encoder.layers.0.mixer.Wqkv.weight" in sd: + return TEModel.JINA_CLIP_2 if "encoder.block.23.layer.1.DenseReluDense.wi_1.weight" in sd: weight = 
sd["encoder.block.23.layer.1.DenseReluDense.wi_1.weight"] if weight.shape[-1] == 4096: @@ -1207,6 +1213,9 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip elif te_model == TEModel.QWEN3_2B: clip_target.clip = comfy.text_encoders.ovis.te(**llama_detect(clip_data)) clip_target.tokenizer = comfy.text_encoders.ovis.OvisTokenizer + elif te_model == TEModel.JINA_CLIP_2: + clip_target.clip = comfy.text_encoders.jina_clip_2.JinaClip2TextModelWrapper + clip_target.tokenizer = comfy.text_encoders.jina_clip_2.JinaClip2TokenizerWrapper else: # clip_l if clip_type == CLIPType.SD3: @@ -1262,6 +1271,17 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip elif clip_type == CLIPType.KANDINSKY5_IMAGE: clip_target.clip = comfy.text_encoders.kandinsky5.te(**llama_detect(clip_data)) clip_target.tokenizer = comfy.text_encoders.kandinsky5.Kandinsky5TokenizerImage + elif clip_type == CLIPType.NEWBIE: + clip_target.clip = comfy.text_encoders.newbie.te(**llama_detect(clip_data)) + clip_target.tokenizer = comfy.text_encoders.newbie.NewBieTokenizer + if "model.layers.0.self_attn.q_norm.weight" in clip_data[0]: + clip_data_gemma = clip_data[0] + clip_data_jina = clip_data[1] + else: + clip_data_gemma = clip_data[1] + clip_data_jina = clip_data[0] + tokenizer_data["gemma_spiece_model"] = clip_data_gemma.get("spiece_model", None) + tokenizer_data["jina_spiece_model"] = clip_data_jina.get("spiece_model", None) else: clip_target.clip = sdxl_clip.SDXLClipModel clip_target.tokenizer = sdxl_clip.SDXLTokenizer diff --git a/comfy/text_encoders/jina_clip_2.py b/comfy/text_encoders/jina_clip_2.py new file mode 100644 index 000000000..0cffb6d16 --- /dev/null +++ b/comfy/text_encoders/jina_clip_2.py @@ -0,0 +1,219 @@ +# Jina CLIP v2 and Jina Embeddings v3 both use their modified XLM-RoBERTa architecture. 
Reference implementation: +# Jina CLIP v2 (both text and vision): https://huggingface.co/jinaai/jina-clip-implementation/blob/39e6a55ae971b59bea6e44675d237c99762e7ee2/modeling_clip.py +# Jina XLM-RoBERTa (text only): http://huggingface.co/jinaai/xlm-roberta-flash-implementation/blob/2b6bc3f30750b3a9648fe9b63448c09920efe9be/modeling_xlm_roberta.py + +from dataclasses import dataclass + +import torch +from torch import nn as nn +from torch.nn import functional as F + +import comfy.model_management +import comfy.ops +from comfy import sd1_clip +from .spiece_tokenizer import SPieceTokenizer + +class JinaClip2Tokenizer(sd1_clip.SDTokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + tokenizer = tokenizer_data.get("spiece_model", None) + # The official NewBie uses max_length=8000, but Jina Embeddings v3 actually supports 8192 + super().__init__(tokenizer, pad_with_end=False, embedding_size=1024, embedding_key='jina_clip_2', tokenizer_class=SPieceTokenizer, has_start_token=True, has_end_token=True, pad_to_max_length=False, max_length=8192, min_length=1, pad_token=1, end_token=2, tokenizer_args={"add_bos": True, "add_eos": True}, tokenizer_data=tokenizer_data) + + def state_dict(self): + return {"spiece_model": self.tokenizer.serialize_model()} + +class JinaClip2TokenizerWrapper(sd1_clip.SD1Tokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, tokenizer=JinaClip2Tokenizer, name="jina_clip_2") + +# https://huggingface.co/jinaai/jina-embeddings-v3/blob/343dbf534c76fe845f304fa5c2d1fd87e1e78918/config.json +@dataclass +class XLMRobertaConfig: + vocab_size: int = 250002 + type_vocab_size: int = 1 + hidden_size: int = 1024 + num_hidden_layers: int = 24 + num_attention_heads: int = 16 + rotary_emb_base: float = 20000.0 + intermediate_size: int = 4096 + hidden_act: str = "gelu" + hidden_dropout_prob: float = 0.1 + attention_probs_dropout_prob: float = 0.1 + layer_norm_eps: float = 1e-05 + bos_token_id: int = 0 + eos_token_id: int = 2 + pad_token_id: int = 1 + +class XLMRobertaEmbeddings(nn.Module): + def __init__(self, config, device=None, dtype=None, ops=None): + super().__init__() + embed_dim = config.hidden_size + self.word_embeddings = ops.Embedding(config.vocab_size, embed_dim, padding_idx=config.pad_token_id, device=device, dtype=dtype) + self.token_type_embeddings = ops.Embedding(config.type_vocab_size, embed_dim, device=device, dtype=dtype) + + def forward(self, input_ids=None, embeddings=None): + if input_ids is not None and embeddings is None: + embeddings = self.word_embeddings(input_ids) + + if embeddings is not None: + token_type_ids = torch.zeros(embeddings.shape[1], device=embeddings.device, dtype=torch.int32) + token_type_embeddings = self.token_type_embeddings(token_type_ids) + embeddings = embeddings + token_type_embeddings + return embeddings + +class RotaryEmbedding(nn.Module): + def __init__(self, dim, base, device=None): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached = None + self._sin_cached = None + + def _update_cos_sin_cache(self, seqlen, device=None, dtype=None): + if seqlen > self._seq_len_cached or self._cos_cached is None or self._cos_cached.device != device or self._cos_cached.dtype != dtype: + self._seq_len_cached = seqlen + t = torch.arange(seqlen, 
device=device, dtype=torch.float32) + freqs = torch.outer(t, self.inv_freq.to(device=t.device)) + emb = torch.cat((freqs, freqs), dim=-1) + self._cos_cached = emb.cos().to(dtype) + self._sin_cached = emb.sin().to(dtype) + + def forward(self, q, k): + batch, seqlen, heads, head_dim = q.shape + self._update_cos_sin_cache(seqlen, device=q.device, dtype=q.dtype) + + cos = self._cos_cached[:seqlen].view(1, seqlen, 1, head_dim) + sin = self._sin_cached[:seqlen].view(1, seqlen, 1, head_dim) + + def rotate_half(x): + size = x.shape[-1] // 2 + x1, x2 = x[..., :size], x[..., size:] + return torch.cat((-x2, x1), dim=-1) + + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + +class MHA(nn.Module): + def __init__(self, config, device=None, dtype=None, ops=None): + super().__init__() + embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = embed_dim // config.num_attention_heads + + self.rotary_emb = RotaryEmbedding(self.head_dim, config.rotary_emb_base, device=device) + self.Wqkv = ops.Linear(embed_dim, 3 * embed_dim, device=device, dtype=dtype) + self.out_proj = ops.Linear(embed_dim, embed_dim, device=device, dtype=dtype) + + def forward(self, x, mask=None, optimized_attention=None): + qkv = self.Wqkv(x) + batch_size, seq_len, _ = qkv.shape + qkv = qkv.view(batch_size, seq_len, 3, self.num_heads, self.head_dim) + q, k, v = qkv.unbind(2) + + q, k = self.rotary_emb(q, k) + + # NHD -> HND + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + + out = optimized_attention(q, k, v, heads=self.num_heads, mask=mask, skip_reshape=True) + return self.out_proj(out) + +class MLP(nn.Module): + def __init__(self, config, device=None, dtype=None, ops=None): + super().__init__() + self.fc1 = ops.Linear(config.hidden_size, config.intermediate_size, device=device, dtype=dtype) + self.activation = F.gelu + self.fc2 = ops.Linear(config.intermediate_size, config.hidden_size, device=device, dtype=dtype) + + def forward(self, x): + x = self.fc1(x) + x = self.activation(x) + x = self.fc2(x) + return x + +class Block(nn.Module): + def __init__(self, config, device=None, dtype=None, ops=None): + super().__init__() + self.mixer = MHA(config, device=device, dtype=dtype, ops=ops) + self.dropout1 = nn.Dropout(config.hidden_dropout_prob) + self.norm1 = ops.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, device=device, dtype=dtype) + self.mlp = MLP(config, device=device, dtype=dtype, ops=ops) + self.dropout2 = nn.Dropout(config.hidden_dropout_prob) + self.norm2 = ops.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, device=device, dtype=dtype) + + def forward(self, hidden_states, mask=None, optimized_attention=None): + mixer_out = self.mixer(hidden_states, mask=mask, optimized_attention=optimized_attention) + hidden_states = self.norm1(self.dropout1(mixer_out) + hidden_states) + mlp_out = self.mlp(hidden_states) + hidden_states = self.norm2(self.dropout2(mlp_out) + hidden_states) + return hidden_states + +class XLMRobertaEncoder(nn.Module): + def __init__(self, config, device=None, dtype=None, ops=None): + super().__init__() + self.layers = nn.ModuleList([Block(config, device=device, dtype=dtype, ops=ops) for _ in range(config.num_hidden_layers)]) + + def forward(self, hidden_states, attention_mask=None): + optimized_attention = comfy.ldm.modules.attention.optimized_attention_for_device(hidden_states.device, mask=attention_mask is not None, small_input=True) + for layer in self.layers: + 
hidden_states = layer(hidden_states, mask=attention_mask, optimized_attention=optimized_attention) + return hidden_states + +class XLMRobertaModel_(nn.Module): + def __init__(self, config, device=None, dtype=None, ops=None): + super().__init__() + self.embeddings = XLMRobertaEmbeddings(config, device=device, dtype=dtype, ops=ops) + self.emb_ln = ops.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, device=device, dtype=dtype) + self.emb_drop = nn.Dropout(config.hidden_dropout_prob) + self.encoder = XLMRobertaEncoder(config, device=device, dtype=dtype, ops=ops) + + def forward(self, input_ids, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, embeds_info=[]): + x = self.embeddings(input_ids=input_ids, embeddings=embeds) + x = self.emb_ln(x) + x = self.emb_drop(x) + + mask = None + if attention_mask is not None: + mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, 1, attention_mask.shape[-1])) + mask = mask.masked_fill(mask.to(torch.bool), -torch.finfo(x.dtype).max) + + sequence_output = self.encoder(x, attention_mask=mask) + + # Mean pool, see https://huggingface.co/jinaai/jina-clip-implementation/blob/39e6a55ae971b59bea6e44675d237c99762e7ee2/hf_model.py + pooled_output = None + if attention_mask is None: + pooled_output = sequence_output.mean(dim=1) + else: + attention_mask = attention_mask.to(sequence_output.dtype) + pooled_output = (sequence_output * attention_mask.unsqueeze(-1)).sum(dim=1) / attention_mask.sum(dim=-1, keepdim=True) + + # Intermediate output is not yet implemented, use None for placeholder + return sequence_output, None, pooled_output + +class XLMRobertaModel(nn.Module): + def __init__(self, config_dict, dtype, device, operations): + super().__init__() + self.config = XLMRobertaConfig(**config_dict) + self.model = XLMRobertaModel_(self.config, device=device, dtype=dtype, ops=operations) + self.num_layers = self.config.num_hidden_layers + + def get_input_embeddings(self): + return self.model.embeddings.word_embeddings + + def set_input_embeddings(self, embeddings): + self.model.embeddings.word_embeddings = embeddings + + def forward(self, *args, **kwargs): + return self.model(*args, **kwargs) + +class JinaClip2TextModel(sd1_clip.SDClipModel): + def __init__(self, device="cpu", dtype=None, model_options={}): + super().__init__(device=device, dtype=dtype, textmodel_json_config={}, model_class=XLMRobertaModel, special_tokens={"start": 0, "end": 2, "pad": 1}, enable_attention_masks=True, return_attention_masks=True, model_options=model_options) + +class JinaClip2TextModelWrapper(sd1_clip.SD1ClipModel): + def __init__(self, device="cpu", dtype=None, model_options={}): + super().__init__(device=device, dtype=dtype, clip_model=JinaClip2TextModel, name="jina_clip_2", model_options=model_options) diff --git a/comfy/text_encoders/newbie.py b/comfy/text_encoders/newbie.py new file mode 100644 index 000000000..31904462b --- /dev/null +++ b/comfy/text_encoders/newbie.py @@ -0,0 +1,62 @@ +import torch + +import comfy.model_management +import comfy.text_encoders.jina_clip_2 +import comfy.text_encoders.lumina2 + +class NewBieTokenizer: + def __init__(self, embedding_directory=None, tokenizer_data={}): + self.gemma = comfy.text_encoders.lumina2.Gemma3_4BTokenizer(embedding_directory=embedding_directory, tokenizer_data={"spiece_model": tokenizer_data["gemma_spiece_model"]}) + self.jina = comfy.text_encoders.jina_clip_2.JinaClip2Tokenizer(embedding_directory=embedding_directory, 
tokenizer_data={"spiece_model": tokenizer_data["jina_spiece_model"]}) + + def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs): + out = {} + out["gemma"] = self.gemma.tokenize_with_weights(text, return_word_ids, **kwargs) + out["jina"] = self.jina.tokenize_with_weights(text, return_word_ids, **kwargs) + return out + + def untokenize(self, token_weight_pair): + raise NotImplementedError + + def state_dict(self): + return {} + +class NewBieTEModel(torch.nn.Module): + def __init__(self, dtype_gemma=None, device="cpu", dtype=None, model_options={}): + super().__init__() + dtype_gemma = comfy.model_management.pick_weight_dtype(dtype_gemma, dtype, device) + self.gemma = comfy.text_encoders.lumina2.Gemma3_4BModel(device=device, dtype=dtype_gemma, model_options=model_options) + self.jina = comfy.text_encoders.jina_clip_2.JinaClip2TextModel(device=device, dtype=dtype, model_options=model_options) + self.dtypes = {dtype, dtype_gemma} + + def set_clip_options(self, options): + self.gemma.set_clip_options(options) + self.jina.set_clip_options(options) + + def reset_clip_options(self): + self.gemma.reset_clip_options() + self.jina.reset_clip_options() + + def encode_token_weights(self, token_weight_pairs): + token_weight_pairs_gemma = token_weight_pairs["gemma"] + token_weight_pairs_jina = token_weight_pairs["jina"] + + gemma_out, gemma_pooled, gemma_extra = self.gemma.encode_token_weights(token_weight_pairs_gemma) + jina_out, jina_pooled, jina_extra = self.jina.encode_token_weights(token_weight_pairs_jina) + + return gemma_out, jina_pooled, gemma_extra + + def load_sd(self, sd): + if "model.layers.0.self_attn.q_norm.weight" in sd: + return self.gemma.load_sd(sd) + else: + return self.jina.load_sd(sd) + +def te(dtype_llama=None, llama_quantization_metadata=None): + class NewBieTEModel_(NewBieTEModel): + def __init__(self, device="cpu", dtype=None, model_options={}): + if llama_quantization_metadata is not None: + model_options = model_options.copy() + model_options["quantization_metadata"] = llama_quantization_metadata + super().__init__(dtype_gemma=dtype_llama, device=device, dtype=dtype, model_options=model_options) + return NewBieTEModel_ diff --git a/nodes.py b/nodes.py index b13ceb578..7d83ecb21 100644 --- a/nodes.py +++ b/nodes.py @@ -970,7 +970,7 @@ class DualCLIPLoader: def INPUT_TYPES(s): return {"required": { "clip_name1": (folder_paths.get_filename_list("text_encoders"), ), "clip_name2": (folder_paths.get_filename_list("text_encoders"), ), - "type": (["sdxl", "sd3", "flux", "hunyuan_video", "hidream", "hunyuan_image", "hunyuan_video_15", "kandinsky5", "kandinsky5_image"], ), + "type": (["sdxl", "sd3", "flux", "hunyuan_video", "hidream", "hunyuan_image", "hunyuan_video_15", "kandinsky5", "kandinsky5_image", "newbie"], ), }, "optional": { "device": (["default", "cpu"], {"advanced": True}), @@ -980,7 +980,7 @@ class DualCLIPLoader: CATEGORY = "advanced/loaders" - DESCRIPTION = "[Recipes]\n\nsdxl: clip-l, clip-g\nsd3: clip-l, clip-g / clip-l, t5 / clip-g, t5\nflux: clip-l, t5\nhidream: at least one of t5 or llama, recommended t5 and llama\nhunyuan_image: qwen2.5vl 7b and byt5 small" + DESCRIPTION = "[Recipes]\n\nsdxl: clip-l, clip-g\nsd3: clip-l, clip-g / clip-l, t5 / clip-g, t5\nflux: clip-l, t5\nhidream: at least one of t5 or llama, recommended t5 and llama\nhunyuan_image: qwen2.5vl 7b and byt5 small\nnewbie: gemma-3-4b-it, jina clip v2" def load_clip(self, clip_name1, clip_name2, type, device="default"): clip_type = getattr(comfy.sd.CLIPType, type.upper(), 
comfy.sd.CLIPType.STABLE_DIFFUSION) From fb478f679a2998c4f2e955bcb895cc4c55f119a4 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Fri, 19 Dec 2025 22:02:43 -0800 Subject: [PATCH 083/148] Only apply gemma quant config to gemma model for newbie. (#11436) --- comfy/text_encoders/lumina2.py | 5 +++++ comfy/text_encoders/newbie.py | 2 +- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/comfy/text_encoders/lumina2.py b/comfy/text_encoders/lumina2.py index f82883ba1..b29a7cc87 100644 --- a/comfy/text_encoders/lumina2.py +++ b/comfy/text_encoders/lumina2.py @@ -33,6 +33,11 @@ class Gemma2_2BModel(sd1_clip.SDClipModel): class Gemma3_4BModel(sd1_clip.SDClipModel): def __init__(self, device="cpu", layer="hidden", layer_idx=-2, dtype=None, attention_mask=True, model_options={}): + llama_quantization_metadata = model_options.get("llama_quantization_metadata", None) + if llama_quantization_metadata is not None: + model_options = model_options.copy() + model_options["quantization_metadata"] = llama_quantization_metadata + super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma3_4B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options) class LuminaModel(sd1_clip.SD1ClipModel): diff --git a/comfy/text_encoders/newbie.py b/comfy/text_encoders/newbie.py index 31904462b..db2324576 100644 --- a/comfy/text_encoders/newbie.py +++ b/comfy/text_encoders/newbie.py @@ -57,6 +57,6 @@ def te(dtype_llama=None, llama_quantization_metadata=None): def __init__(self, device="cpu", dtype=None, model_options={}): if llama_quantization_metadata is not None: model_options = model_options.copy() - model_options["quantization_metadata"] = llama_quantization_metadata + model_options["llama_quantization_metadata"] = llama_quantization_metadata super().__init__(dtype_gemma=dtype_llama, device=device, dtype=dtype, model_options=model_options) return NewBieTEModel_ From 0899012ad60db23cbc5990d164fbd22195bafcb2 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Sat, 20 Dec 2025 08:24:37 +0200 Subject: [PATCH 084/148] chore(api-nodes): by default set Watermark generation to False (#11437) --- comfy_api_nodes/apis/bytedance_api.py | 6 +++--- comfy_api_nodes/nodes_bytedance.py | 16 ++++++++-------- comfy_api_nodes/nodes_wan.py | 24 ++++++++++++------------ 3 files changed, 23 insertions(+), 23 deletions(-) diff --git a/comfy_api_nodes/apis/bytedance_api.py b/comfy_api_nodes/apis/bytedance_api.py index 77cd76f9b..b8c2f618b 100644 --- a/comfy_api_nodes/apis/bytedance_api.py +++ b/comfy_api_nodes/apis/bytedance_api.py @@ -10,7 +10,7 @@ class Text2ImageTaskCreationRequest(BaseModel): size: str | None = Field(None) seed: int | None = Field(0, ge=0, le=2147483647) guidance_scale: float | None = Field(..., ge=1.0, le=10.0) - watermark: bool | None = Field(True) + watermark: bool | None = Field(False) class Image2ImageTaskCreationRequest(BaseModel): @@ -21,7 +21,7 @@ class Image2ImageTaskCreationRequest(BaseModel): size: str | None = Field("adaptive") seed: int | None = Field(..., ge=0, le=2147483647) guidance_scale: float | None = Field(..., ge=1.0, le=10.0) - watermark: bool | None = Field(True) + watermark: bool | None = Field(False) class Seedream4Options(BaseModel): @@ -37,7 +37,7 @@ class 
Seedream4TaskCreationRequest(BaseModel): seed: int = Field(..., ge=0, le=2147483647) sequential_image_generation: str = Field("disabled") sequential_image_generation_options: Seedream4Options = Field(Seedream4Options(max_images=15)) - watermark: bool = Field(True) + watermark: bool = Field(False) class ImageTaskCreationResponse(BaseModel): diff --git a/comfy_api_nodes/nodes_bytedance.py b/comfy_api_nodes/nodes_bytedance.py index 57c0218d0..636cc1265 100644 --- a/comfy_api_nodes/nodes_bytedance.py +++ b/comfy_api_nodes/nodes_bytedance.py @@ -112,7 +112,7 @@ class ByteDanceImageNode(IO.ComfyNode): ), IO.Boolean.Input( "watermark", - default=True, + default=False, tooltip='Whether to add an "AI generated" watermark to the image', optional=True, ), @@ -215,7 +215,7 @@ class ByteDanceImageEditNode(IO.ComfyNode): ), IO.Boolean.Input( "watermark", - default=True, + default=False, tooltip='Whether to add an "AI generated" watermark to the image', optional=True, ), @@ -346,7 +346,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode): ), IO.Boolean.Input( "watermark", - default=True, + default=False, tooltip='Whether to add an "AI generated" watermark to the image.', optional=True, ), @@ -380,7 +380,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode): sequential_image_generation: str = "disabled", max_images: int = 1, seed: int = 0, - watermark: bool = True, + watermark: bool = False, fail_on_partial: bool = True, ) -> IO.NodeOutput: validate_string(prompt, strip_whitespace=True, min_length=1) @@ -507,7 +507,7 @@ class ByteDanceTextToVideoNode(IO.ComfyNode): ), IO.Boolean.Input( "watermark", - default=True, + default=False, tooltip='Whether to add an "AI generated" watermark to the video.', optional=True, ), @@ -617,7 +617,7 @@ class ByteDanceImageToVideoNode(IO.ComfyNode): ), IO.Boolean.Input( "watermark", - default=True, + default=False, tooltip='Whether to add an "AI generated" watermark to the video.', optional=True, ), @@ -739,7 +739,7 @@ class ByteDanceFirstLastFrameNode(IO.ComfyNode): ), IO.Boolean.Input( "watermark", - default=True, + default=False, tooltip='Whether to add an "AI generated" watermark to the video.', optional=True, ), @@ -862,7 +862,7 @@ class ByteDanceImageReferenceNode(IO.ComfyNode): ), IO.Boolean.Input( "watermark", - default=True, + default=False, tooltip='Whether to add an "AI generated" watermark to the video.', optional=True, ), diff --git a/comfy_api_nodes/nodes_wan.py b/comfy_api_nodes/nodes_wan.py index 17b680e13..1675fd863 100644 --- a/comfy_api_nodes/nodes_wan.py +++ b/comfy_api_nodes/nodes_wan.py @@ -46,14 +46,14 @@ class Txt2ImageParametersField(BaseModel): n: int = Field(1, description="Number of images to generate.") # we support only value=1 seed: int = Field(..., ge=0, le=2147483647) prompt_extend: bool = Field(True) - watermark: bool = Field(True) + watermark: bool = Field(False) class Image2ImageParametersField(BaseModel): size: str | None = Field(None) n: int = Field(1, description="Number of images to generate.") # we support only value=1 seed: int = Field(..., ge=0, le=2147483647) - watermark: bool = Field(True) + watermark: bool = Field(False) class Text2VideoParametersField(BaseModel): @@ -61,7 +61,7 @@ class Text2VideoParametersField(BaseModel): seed: int = Field(..., ge=0, le=2147483647) duration: int = Field(5, ge=5, le=15) prompt_extend: bool = Field(True) - watermark: bool = Field(True) + watermark: bool = Field(False) audio: bool = Field(False, description="Whether to generate audio automatically.") shot_type: str = Field("single") @@ -71,7 +71,7 @@ class 
Image2VideoParametersField(BaseModel): seed: int = Field(..., ge=0, le=2147483647) duration: int = Field(5, ge=5, le=15) prompt_extend: bool = Field(True) - watermark: bool = Field(True) + watermark: bool = Field(False) audio: bool = Field(False, description="Whether to generate audio automatically.") shot_type: str = Field("single") @@ -208,7 +208,7 @@ class WanTextToImageApi(IO.ComfyNode): ), IO.Boolean.Input( "watermark", - default=True, + default=False, tooltip="Whether to add an AI-generated watermark to the result.", optional=True, ), @@ -234,7 +234,7 @@ class WanTextToImageApi(IO.ComfyNode): height: int = 1024, seed: int = 0, prompt_extend: bool = True, - watermark: bool = True, + watermark: bool = False, ): initial_response = await sync_op( cls, @@ -327,7 +327,7 @@ class WanImageToImageApi(IO.ComfyNode): ), IO.Boolean.Input( "watermark", - default=True, + default=False, tooltip="Whether to add an AI-generated watermark to the result.", optional=True, ), @@ -353,7 +353,7 @@ class WanImageToImageApi(IO.ComfyNode): # width: int = 1024, # height: int = 1024, seed: int = 0, - watermark: bool = True, + watermark: bool = False, ): n_images = get_number_of_images(image) if n_images not in (1, 2): @@ -476,7 +476,7 @@ class WanTextToVideoApi(IO.ComfyNode): ), IO.Boolean.Input( "watermark", - default=True, + default=False, tooltip="Whether to add an AI-generated watermark to the result.", optional=True, ), @@ -512,7 +512,7 @@ class WanTextToVideoApi(IO.ComfyNode): seed: int = 0, generate_audio: bool = False, prompt_extend: bool = True, - watermark: bool = True, + watermark: bool = False, shot_type: str = "single", ): if "480p" in size and model == "wan2.6-t2v": @@ -637,7 +637,7 @@ class WanImageToVideoApi(IO.ComfyNode): ), IO.Boolean.Input( "watermark", - default=True, + default=False, tooltip="Whether to add an AI-generated watermark to the result.", optional=True, ), @@ -674,7 +674,7 @@ class WanImageToVideoApi(IO.ComfyNode): seed: int = 0, generate_audio: bool = False, prompt_extend: bool = True, - watermark: bool = True, + watermark: bool = False, shot_type: str = "single", ): if get_number_of_images(image) != 1: From bbb11e26081977474eec72ce36d12ec778b5a9ea Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Sat, 20 Dec 2025 18:48:28 +0200 Subject: [PATCH 085/148] fix(api-nodes): Topaz 4k video upscaling (#11438) --- comfy_api_nodes/nodes_topaz.py | 35 +++++++++++++++++++++++++++------- 1 file changed, 28 insertions(+), 7 deletions(-) diff --git a/comfy_api_nodes/nodes_topaz.py b/comfy_api_nodes/nodes_topaz.py index f522756e5..b04575ad8 100644 --- a/comfy_api_nodes/nodes_topaz.py +++ b/comfy_api_nodes/nodes_topaz.py @@ -23,10 +23,6 @@ UPSCALER_MODELS_MAP = { "Starlight (Astra) Fast": "slf-1", "Starlight (Astra) Creative": "slc-1", } -UPSCALER_VALUES_MAP = { - "FullHD (1080p)": 1920, - "4K (2160p)": 3840, -} class TopazImageEnhance(IO.ComfyNode): @@ -214,7 +210,7 @@ class TopazVideoEnhance(IO.ComfyNode): IO.Video.Input("video"), IO.Boolean.Input("upscaler_enabled", default=True), IO.Combo.Input("upscaler_model", options=list(UPSCALER_MODELS_MAP.keys())), - IO.Combo.Input("upscaler_resolution", options=list(UPSCALER_VALUES_MAP.keys())), + IO.Combo.Input("upscaler_resolution", options=["FullHD (1080p)", "4K (2160p)"]), IO.Combo.Input( "upscaler_creativity", options=["low", "middle", "high"], @@ -306,8 +302,33 @@ class TopazVideoEnhance(IO.ComfyNode): target_frame_rate = src_frame_rate filters = [] if upscaler_enabled: - target_width = 
UPSCALER_VALUES_MAP[upscaler_resolution] - target_height = UPSCALER_VALUES_MAP[upscaler_resolution] + if "1080p" in upscaler_resolution: + target_pixel_p = 1080 + max_long_side = 1920 + else: + target_pixel_p = 2160 + max_long_side = 3840 + ar = src_width / src_height + if src_width >= src_height: + # Landscape or Square; Attempt to set height to target (e.g., 2160), calculate width + target_height = target_pixel_p + target_width = int(target_height * ar) + # Check if width exceeds standard bounds (for ultra-wide e.g., 21:9 ARs) + if target_width > max_long_side: + target_width = max_long_side + target_height = int(target_width / ar) + else: + # Portrait; Attempt to set width to target (e.g., 2160), calculate height + target_width = target_pixel_p + target_height = int(target_width / ar) + # Check if height exceeds standard bounds + if target_height > max_long_side: + target_height = max_long_side + target_width = int(target_height * ar) + if target_width % 2 != 0: + target_width += 1 + if target_height % 2 != 0: + target_height += 1 filters.append( topaz_api.VideoEnhancementFilter( model=UPSCALER_MODELS_MAP[upscaler_model], From 807538fe6c66bca8c91edbad14414fb4e109cbde Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sat, 20 Dec 2025 17:02:02 -0800 Subject: [PATCH 086/148] Core release process. (#11447) --- README.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/README.md b/README.md index bae955b1b..b0f62695b 100644 --- a/README.md +++ b/README.md @@ -119,6 +119,9 @@ ComfyUI follows a weekly release cycle targeting Monday but this regularly chang 1. **[ComfyUI Core](https://github.com/comfyanonymous/ComfyUI)** - Releases a new stable version (e.g., v0.7.0) roughly every week. + - Starting from v0.4.0 patch versions will be used for fixes backported onto the current stable release. + - Minor versions will be used for releases off the master branch. + - Patch versions may still be used for releases on the master branch in cases where a backport would not make sense. - Commits outside of the stable release tags may be very unstable and break many custom nodes. - Serves as the foundation for the desktop release From 91bf6b6aa3d5134c1569375a34ff483d3e32e03f Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sun, 21 Dec 2025 16:59:40 -0800 Subject: [PATCH 087/148] Add node to create empty latents for qwen image layered model. 
(#11460) --- comfy_extras/nodes_qwen.py | 29 ++++++++++++++++++++++++++++- 1 file changed, 28 insertions(+), 1 deletion(-) diff --git a/comfy_extras/nodes_qwen.py b/comfy_extras/nodes_qwen.py index 525239ae5..fde8fac9a 100644 --- a/comfy_extras/nodes_qwen.py +++ b/comfy_extras/nodes_qwen.py @@ -3,7 +3,9 @@ import comfy.utils import math from typing_extensions import override from comfy_api.latest import ComfyExtension, io - +import comfy.model_management +import torch +import nodes class TextEncodeQwenImageEdit(io.ComfyNode): @classmethod @@ -104,12 +106,37 @@ class TextEncodeQwenImageEditPlus(io.ComfyNode): return io.NodeOutput(conditioning) +class EmptyQwenImageLayeredLatentImage(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="EmptyQwenImageLayeredLatentImage", + display_name="Empty Qwen Image Layered Latent", + category="latent/qwen", + inputs=[ + io.Int.Input("width", default=640, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("height", default=640, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("layers", default=3, min=0, max=nodes.MAX_RESOLUTION, step=1), + io.Int.Input("batch_size", default=1, min=1, max=4096), + ], + outputs=[ + io.Latent.Output(), + ], + ) + + @classmethod + def execute(cls, width, height, layers, batch_size=1) -> io.NodeOutput: + latent = torch.zeros([batch_size, 16, layers + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) + return io.NodeOutput({"samples": latent}) + + class QwenExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[io.ComfyNode]]: return [ TextEncodeQwenImageEdit, TextEncodeQwenImageEditPlus, + EmptyQwenImageLayeredLatentImage, ] From c176b214cc768d41892add4d4f51c5c5627cbf7b Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Mon, 22 Dec 2025 08:44:49 +0200 Subject: [PATCH 088/148] extend possible duration range for Kling O1 StartEndFrame node (#11451) --- comfy_api_nodes/nodes_kling.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/comfy_api_nodes/nodes_kling.py b/comfy_api_nodes/nodes_kling.py index 1a6364fa0..5294b10d4 100644 --- a/comfy_api_nodes/nodes_kling.py +++ b/comfy_api_nodes/nodes_kling.py @@ -858,7 +858,7 @@ class OmniProFirstLastFrameNode(IO.ComfyNode): tooltip="A text prompt describing the video content. " "This can include both positive and negative descriptions.", ), - IO.Combo.Input("duration", options=["5", "10"]), + IO.Int.Input("duration", default=5, min=3, max=10, display_mode=IO.NumberDisplay.slider), IO.Image.Input("first_frame"), IO.Image.Input( "end_frame", @@ -897,6 +897,10 @@ class OmniProFirstLastFrameNode(IO.ComfyNode): validate_string(prompt, min_length=1, max_length=2500) if end_frame is not None and reference_images is not None: raise ValueError("The 'end_frame' input cannot be used simultaneously with 'reference_images'.") + if duration not in (5, 10) and end_frame is None and reference_images is None: + raise ValueError( + "Duration is only supported for 5 or 10 seconds if there is no end frame or reference images." 
+ ) validate_image_dimensions(first_frame, min_width=300, min_height=300) validate_image_aspect_ratio(first_frame, (1, 2.5), (2.5, 1)) image_list: list[OmniParamImage] = [ From eb0e10aec449eed2bbcda82ae5b56070e61ed86f Mon Sep 17 00:00:00 2001 From: ComfyUI Wiki Date: Tue, 23 Dec 2025 05:02:41 +0800 Subject: [PATCH 089/148] Update workflow templates to v0.7.62 (#11467) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 54696395f..b41dbe1d7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.34.9 -comfyui-workflow-templates==0.7.60 +comfyui-workflow-templates==0.7.62 comfyui-embedded-docs==0.3.1 torch torchsde From 33aa808713f7c36cd9476c53b8b67c745e9bc107 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 22 Dec 2025 13:43:24 -0800 Subject: [PATCH 090/148] Make denoised output on custom sampler nodes work with nested tensors. (#11471) --- comfy_extras/nodes_custom_sampler.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/comfy_extras/nodes_custom_sampler.py b/comfy_extras/nodes_custom_sampler.py index 7ee4caac1..993889d9d 100644 --- a/comfy_extras/nodes_custom_sampler.py +++ b/comfy_extras/nodes_custom_sampler.py @@ -760,8 +760,12 @@ class SamplerCustom(io.ComfyNode): out = latent.copy() out["samples"] = samples if "x0" in x0_output: + x0_out = model.model.process_latent_out(x0_output["x0"].cpu()) + if samples.is_nested: + latent_shapes = [x.shape for x in samples.unbind()] + x0_out = comfy.nested_tensor.NestedTensor(comfy.utils.unpack_latents(x0_out, latent_shapes)) out_denoised = latent.copy() - out_denoised["samples"] = model.model.process_latent_out(x0_output["x0"].cpu()) + out_denoised["samples"] = x0_out else: out_denoised = out return io.NodeOutput(out, out_denoised) @@ -948,8 +952,12 @@ class SamplerCustomAdvanced(io.ComfyNode): out = latent.copy() out["samples"] = samples if "x0" in x0_output: + x0_out = guider.model_patcher.model.process_latent_out(x0_output["x0"].cpu()) + if samples.is_nested: + latent_shapes = [x.shape for x in samples.unbind()] + x0_out = comfy.nested_tensor.NestedTensor(comfy.utils.unpack_latents(x0_out, latent_shapes)) out_denoised = latent.copy() - out_denoised["samples"] = guider.model_patcher.model.process_latent_out(x0_output["x0"].cpu()) + out_denoised["samples"] = x0_out else: out_denoised = out return io.NodeOutput(out, out_denoised) From f4f44bb8073d02597aca61193fec6143292a0b88 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Tue, 23 Dec 2025 22:10:27 +0200 Subject: [PATCH 091/148] api-nodes: use new custom endpoint for Nano Banana (#11311) --- comfy_api_nodes/apis/gemini_api.py | 1 + comfy_api_nodes/nodes_gemini.py | 24 +++++++++++++++--------- 2 files changed, 16 insertions(+), 9 deletions(-) diff --git a/comfy_api_nodes/apis/gemini_api.py b/comfy_api_nodes/apis/gemini_api.py index f8edc38c9..d81337dae 100644 --- a/comfy_api_nodes/apis/gemini_api.py +++ b/comfy_api_nodes/apis/gemini_api.py @@ -133,6 +133,7 @@ class GeminiImageGenerateContentRequest(BaseModel): systemInstruction: GeminiSystemInstructionContent | None = Field(None) tools: list[GeminiTool] | None = Field(None) videoMetadata: GeminiVideoMetadata | None = Field(None) + uploadImagesToStorage: bool = Field(True) class GeminiGenerateContentRequest(BaseModel): diff --git a/comfy_api_nodes/nodes_gemini.py 
b/comfy_api_nodes/nodes_gemini.py index ad0f4b4d1..e8ed7e797 100644 --- a/comfy_api_nodes/nodes_gemini.py +++ b/comfy_api_nodes/nodes_gemini.py @@ -34,6 +34,7 @@ from comfy_api_nodes.util import ( ApiEndpoint, audio_to_base64_string, bytesio_to_image_tensor, + download_url_to_image_tensor, get_number_of_images, sync_op, tensor_to_base64_string, @@ -141,9 +142,11 @@ def get_parts_by_type(response: GeminiGenerateContentResponse, part_type: Litera ) parts = [] for part in response.candidates[0].content.parts: - if part_type == "text" and hasattr(part, "text") and part.text: + if part_type == "text" and part.text: parts.append(part) - elif hasattr(part, "inlineData") and part.inlineData and part.inlineData.mimeType == part_type: + elif part.inlineData and part.inlineData.mimeType == part_type: + parts.append(part) + elif part.fileData and part.fileData.mimeType == part_type: parts.append(part) # Skip parts that don't match the requested type return parts @@ -163,12 +166,15 @@ def get_text_from_response(response: GeminiGenerateContentResponse) -> str: return "\n".join([part.text for part in parts]) -def get_image_from_response(response: GeminiGenerateContentResponse) -> Input.Image: +async def get_image_from_response(response: GeminiGenerateContentResponse) -> Input.Image: image_tensors: list[Input.Image] = [] parts = get_parts_by_type(response, "image/png") for part in parts: - image_data = base64.b64decode(part.inlineData.data) - returned_image = bytesio_to_image_tensor(BytesIO(image_data)) + if part.inlineData: + image_data = base64.b64decode(part.inlineData.data) + returned_image = bytesio_to_image_tensor(BytesIO(image_data)) + else: + returned_image = await download_url_to_image_tensor(part.fileData.fileUri) image_tensors.append(returned_image) if len(image_tensors) == 0: return torch.zeros((1, 1024, 1024, 4)) @@ -596,7 +602,7 @@ class GeminiImage(IO.ComfyNode): response = await sync_op( cls, - endpoint=ApiEndpoint(path=f"{GEMINI_BASE_ENDPOINT}/{model}", method="POST"), + ApiEndpoint(path=f"/proxy/vertexai/gemini/{model}", method="POST"), data=GeminiImageGenerateContentRequest( contents=[ GeminiContent(role=GeminiRole.user, parts=parts), @@ -610,7 +616,7 @@ class GeminiImage(IO.ComfyNode): response_model=GeminiGenerateContentResponse, price_extractor=calculate_tokens_price, ) - return IO.NodeOutput(get_image_from_response(response), get_text_from_response(response)) + return IO.NodeOutput(await get_image_from_response(response), get_text_from_response(response)) class GeminiImage2(IO.ComfyNode): @@ -729,7 +735,7 @@ class GeminiImage2(IO.ComfyNode): response = await sync_op( cls, - ApiEndpoint(path=f"{GEMINI_BASE_ENDPOINT}/{model}", method="POST"), + ApiEndpoint(path=f"/proxy/vertexai/gemini/{model}", method="POST"), data=GeminiImageGenerateContentRequest( contents=[ GeminiContent(role=GeminiRole.user, parts=parts), @@ -743,7 +749,7 @@ class GeminiImage2(IO.ComfyNode): response_model=GeminiGenerateContentResponse, price_extractor=calculate_tokens_price, ) - return IO.NodeOutput(get_image_from_response(response), get_text_from_response(response)) + return IO.NodeOutput(await get_image_from_response(response), get_text_from_response(response)) class GeminiExtension(ComfyExtension): From 22ff1bbfcb532a294b200a90270b772a339d334e Mon Sep 17 00:00:00 2001 From: ComfyUI Wiki Date: Wed, 24 Dec 2025 09:48:45 +0800 Subject: [PATCH 092/148] chore: update workflow templates to v0.7.63 (#11482) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt 
b/requirements.txt index b41dbe1d7..59ac599c1 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.34.9 -comfyui-workflow-templates==0.7.62 +comfyui-workflow-templates==0.7.63 comfyui-embedded-docs==0.3.1 torch torchsde From e4c61d75555036fa28b6bb34e5fd67b007c9f391 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 23 Dec 2025 20:50:02 -0500 Subject: [PATCH 093/148] ComfyUI v0.6.0 --- comfyui_version.py | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/comfyui_version.py b/comfyui_version.py index b45309198..1f28e2407 100644 --- a/comfyui_version.py +++ b/comfyui_version.py @@ -1,3 +1,3 @@ # This file is automatically generated by the build process when version is # updated in pyproject.toml. -__version__ = "0.5.1" +__version__ = "0.6.0" diff --git a/pyproject.toml b/pyproject.toml index 3a6960811..35a268bd1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ComfyUI" -version = "0.5.1" +version = "0.6.0" readme = "README.md" license = { file = "LICENSE" } requires-python = ">=3.9" From 650e716dda0a966a083f0efe299f3e83336f920e Mon Sep 17 00:00:00 2001 From: Comfy Org PR Bot Date: Wed, 24 Dec 2025 14:29:41 +0900 Subject: [PATCH 094/148] Bump comfyui-frontend-package to 1.35.9 (#11470) Co-authored-by: github-actions[bot] --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 59ac599c1..84b1882aa 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -comfyui-frontend-package==1.34.9 +comfyui-frontend-package==1.35.9 comfyui-workflow-templates==0.7.63 comfyui-embedded-docs==0.3.1 torch From 4f067b07fb33cc1b61d91aec73ca968ba7d9c29a Mon Sep 17 00:00:00 2001 From: ComfyUI Wiki Date: Thu, 25 Dec 2025 07:54:21 +0800 Subject: [PATCH 095/148] chore: update workflow templates to v0.7.64 (#11496) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 84b1882aa..8b670b813 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.35.9 -comfyui-workflow-templates==0.7.63 +comfyui-workflow-templates==0.7.64 comfyui-embedded-docs==0.3.1 torch torchsde From 532e2850794c7b497174a0a42ac0cb1fe5b62499 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 24 Dec 2025 16:09:37 -0800 Subject: [PATCH 096/148] Add a ManualSigmas node. (#11499) Can be used to manually set the sigmas for a model. This node accepts a list of integer and floating point numbers separated with any non numeric character. 
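For illustration, the parsing is just a regex scan for numeric tokens, so any non-numeric characters (commas, spaces, semicolons, newlines) act as separators. A minimal standalone sketch of the same behaviour follows; the helper name and the sample values are only illustrative and not part of the node itself:

    import re
    import torch

    def parse_sigmas(text: str) -> torch.Tensor:
        # Pull out every integer/float token; everything else is treated as a separator.
        values = [float(v) for v in re.findall(r"[-+]?(?:\d*\.*\d+)", text)]
        return torch.FloatTensor(values)

    print(parse_sigmas("1.0, 0.75; 0.5 0.25 0"))
    # tensor([1.0000, 0.7500, 0.5000, 0.2500, 0.0000])

The node exposes the resulting tensor as a SIGMAS output, so it can be wired straight into the custom sampling nodes defined in the same file.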
--- comfy_extras/nodes_custom_sampler.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/comfy_extras/nodes_custom_sampler.py b/comfy_extras/nodes_custom_sampler.py index 993889d9d..f19adf4b9 100644 --- a/comfy_extras/nodes_custom_sampler.py +++ b/comfy_extras/nodes_custom_sampler.py @@ -9,6 +9,7 @@ import comfy.utils import node_helpers from typing_extensions import override from comfy_api.latest import ComfyExtension, io +import re class BasicScheduler(io.ComfyNode): @@ -1013,6 +1014,25 @@ class AddNoise(io.ComfyNode): add_noise = execute +class ManualSigmas(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="ManualSigmas", + category="_for_testing/custom_sampling", + is_experimental=True, + inputs=[ + io.String.Input("sigmas", default="1, 0.5", multiline=False) + ], + outputs=[io.Sigmas.Output()] + ) + + @classmethod + def execute(cls, sigmas) -> io.NodeOutput: + sigmas = re.findall(r"[-+]?(?:\d*\.*\d+)", sigmas) + sigmas = [float(i) for i in sigmas] + sigmas = torch.FloatTensor(sigmas) + return io.NodeOutput(sigmas) class CustomSamplersExtension(ComfyExtension): @override @@ -1052,6 +1072,7 @@ class CustomSamplersExtension(ComfyExtension): DisableNoise, AddNoise, SamplerCustomAdvanced, + ManualSigmas, ] From d9a76cf66e3fc6b0047692a07bc1d24f20e16e20 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Thu, 25 Dec 2025 20:46:51 -0800 Subject: [PATCH 097/148] Specify in readme that we only support pytorch 2.4 and up. (#11512) --- README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/README.md b/README.md index b0f62695b..6d09758c0 100644 --- a/README.md +++ b/README.md @@ -212,6 +212,8 @@ Python 3.14 works but you may encounter issues with the torch compile node. The Python 3.13 is very well supported. If you have trouble with some custom node dependencies on 3.13 you can try 3.12 +torch 2.4 and above is supported but some features might only work on newer versions. We generally recommend using the latest major version of pytorch unless it is less than 2 weeks old. + ### Instructions: Git clone this repo. From 16fb6849d296259fd2bf106a6f894650d9a12072 Mon Sep 17 00:00:00 2001 From: "Dr.Lt.Data" <128333288+ltdrdata@users.noreply.github.com> Date: Sat, 27 Dec 2025 08:55:59 +0900 Subject: [PATCH 098/148] bump comfyui_manager version to the 4.0.4 (#11521) --- manager_requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/manager_requirements.txt b/manager_requirements.txt index 2300f0c70..6585b0c19 100644 --- a/manager_requirements.txt +++ b/manager_requirements.txt @@ -1 +1 @@ -comfyui_manager==4.0.3b7 +comfyui_manager==4.0.4 From 1e4e342f54386ea4179b273c24b37bd8cbde8f37 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Fri, 26 Dec 2025 19:03:01 -0800 Subject: [PATCH 099/148] Fix noise with ancestral samplers when inferencing on cpu. 
(#11528) --- comfy/k_diffusion/sampling.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py index 1ba9edad7..0949dee44 100644 --- a/comfy/k_diffusion/sampling.py +++ b/comfy/k_diffusion/sampling.py @@ -74,6 +74,9 @@ def get_ancestral_step(sigma_from, sigma_to, eta=1.): def default_noise_sampler(x, seed=None): if seed is not None: + if x.device == torch.device("cpu"): + seed += 1 + generator = torch.Generator(device=x.device) generator.manual_seed(seed) else: From 865568b7fc5fd2a5f626b22a40c363b0a5f0b399 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Sat, 27 Dec 2025 05:16:21 +0200 Subject: [PATCH 100/148] feat(api-nodes): add Kling Motion Control node (#11493) --- comfy_api_nodes/apis/kling_api.py | 9 ++++ comfy_api_nodes/nodes_kling.py | 87 +++++++++++++++++++++++++++++++ 2 files changed, 96 insertions(+) diff --git a/comfy_api_nodes/apis/kling_api.py b/comfy_api_nodes/apis/kling_api.py index 80a758466..bf54ede3e 100644 --- a/comfy_api_nodes/apis/kling_api.py +++ b/comfy_api_nodes/apis/kling_api.py @@ -102,3 +102,12 @@ class ImageToVideoWithAudioRequest(BaseModel): prompt: str = Field(...) mode: str = Field("pro") sound: str = Field(..., description="'on' or 'off'") + + +class MotionControlRequest(BaseModel): + prompt: str = Field(...) + image_url: str = Field(...) + video_url: str = Field(...) + keep_original_sound: str = Field(...) + character_orientation: str = Field(...) + mode: str = Field(..., description="'pro' or 'std'") diff --git a/comfy_api_nodes/nodes_kling.py b/comfy_api_nodes/nodes_kling.py index 5294b10d4..58259e029 100644 --- a/comfy_api_nodes/nodes_kling.py +++ b/comfy_api_nodes/nodes_kling.py @@ -51,6 +51,7 @@ from comfy_api_nodes.apis import ( ) from comfy_api_nodes.apis.kling_api import ( ImageToVideoWithAudioRequest, + MotionControlRequest, OmniImageParamImage, OmniParamImage, OmniParamVideo, @@ -2163,6 +2164,91 @@ class ImageToVideoWithAudio(IO.ComfyNode): return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url)) +class MotionControl(IO.ComfyNode): + + @classmethod + def define_schema(cls) -> IO.Schema: + return IO.Schema( + node_id="KlingMotionControl", + display_name="Kling Motion Control", + category="api node/video/Kling", + inputs=[ + IO.String.Input("prompt", multiline=True), + IO.Image.Input("reference_image"), + IO.Video.Input( + "reference_video", + tooltip="Motion reference video used to drive movement/expression.\n" + "Duration limits depend on character_orientation:\n" + " - image: 3–10s (max 10s)\n" + " - video: 3–30s (max 30s)", + ), + IO.Boolean.Input("keep_original_sound", default=True), + IO.Combo.Input( + "character_orientation", + options=["video", "image"], + tooltip="Controls where the character's facing/orientation comes from.\n" + "video: movements, expressions, camera moves, and orientation " + "follow the motion reference video (other details via prompt).\n" + "image: movements and expressions still follow the motion reference video, " + "but the character orientation matches the reference image (camera/other details via prompt).", + ), + IO.Combo.Input("mode", options=["pro", "std"]), + ], + outputs=[ + IO.Video.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, + prompt: str, + reference_image: Input.Image, + reference_video: Input.Video, + 
keep_original_sound: bool, + character_orientation: str, + mode: str, + ) -> IO.NodeOutput: + validate_string(prompt, max_length=2500) + validate_image_dimensions(reference_image, min_width=340, min_height=340) + validate_image_aspect_ratio(reference_image, (1, 2.5), (2.5, 1)) + if character_orientation == "image": + validate_video_duration(reference_video, min_duration=3, max_duration=10) + else: + validate_video_duration(reference_video, min_duration=3, max_duration=30) + validate_video_dimensions(reference_video, min_width=340, min_height=340, max_width=3850, max_height=3850) + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/kling/v1/videos/motion-control", method="POST"), + response_model=TaskStatusResponse, + data=MotionControlRequest( + prompt=prompt, + image_url=(await upload_images_to_comfyapi(cls, reference_image))[0], + video_url=await upload_video_to_comfyapi(cls, reference_video), + keep_original_sound="yes" if keep_original_sound else "no", + character_orientation=character_orientation, + mode=mode, + ), + ) + if response.code: + raise RuntimeError( + f"Kling request failed. Code: {response.code}, Message: {response.message}, Data: {response.data}" + ) + final_response = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/kling/v1/videos/motion-control/{response.data.task_id}"), + response_model=TaskStatusResponse, + status_extractor=lambda r: (r.data.task_status if r.data else None), + ) + return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url)) + + class KlingExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[IO.ComfyNode]]: @@ -2188,6 +2274,7 @@ class KlingExtension(ComfyExtension): OmniProImageNode, TextToVideoWithAudio, ImageToVideoWithAudio, + MotionControl, ] From eff4ea0b625e851d641b8f6532ff7afe2df16b9d Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Sat, 27 Dec 2025 05:39:02 +0200 Subject: [PATCH 101/148] [V3] converted nodes_images.py to V3 schema (#11206) * converted nodes_images.py to V3 schema * fix test --- comfy_api/latest/_io.py | 5 +- comfy_api/latest/_util/__init__.py | 2 + comfy_api/latest/_util/image_types.py | 18 + comfy_extras/nodes_images.py | 683 +++++++++--------- .../comfy_extras_test/image_stitch_test.py | 2 +- 5 files changed, 351 insertions(+), 359 deletions(-) create mode 100644 comfy_api/latest/_util/image_types.py diff --git a/comfy_api/latest/_io.py b/comfy_api/latest/_io.py index 4b14e5ded..ba0b95498 100644 --- a/comfy_api/latest/_io.py +++ b/comfy_api/latest/_io.py @@ -28,9 +28,8 @@ from comfy_api.internal import (_ComfyNodeInternal, _NodeOutputInternal, classpr prune_dict, shallow_clone_class) from ._resources import Resources, ResourcesLocal from comfy_execution.graph_utils import ExecutionBlocker -from ._util import MESH, VOXEL +from ._util import MESH, VOXEL, SVG as _SVG -# from comfy_extras.nodes_images import SVG as SVG_ # NOTE: needs to be moved before can be imported due to circular reference class FolderType(str, Enum): input = "input" @@ -656,7 +655,7 @@ class Video(ComfyTypeIO): @comfytype(io_type="SVG") class SVG(ComfyTypeIO): - Type = Any # TODO: SVG class is defined in comfy_extras/nodes_images.py, causing circular reference; should be moved to somewhere else before referenced directly in v3 + Type = _SVG @comfytype(io_type="LORA_MODEL") class LoraModel(ComfyTypeIO): diff --git a/comfy_api/latest/_util/__init__.py b/comfy_api/latest/_util/__init__.py index fc5431dda..6313eb01b 100644 --- 
a/comfy_api/latest/_util/__init__.py +++ b/comfy_api/latest/_util/__init__.py @@ -1,5 +1,6 @@ from .video_types import VideoContainer, VideoCodec, VideoComponents from .geometry_types import VOXEL, MESH +from .image_types import SVG __all__ = [ # Utility Types @@ -8,4 +9,5 @@ __all__ = [ "VideoComponents", "VOXEL", "MESH", + "SVG", ] diff --git a/comfy_api/latest/_util/image_types.py b/comfy_api/latest/_util/image_types.py new file mode 100644 index 000000000..f031ed426 --- /dev/null +++ b/comfy_api/latest/_util/image_types.py @@ -0,0 +1,18 @@ +from io import BytesIO + + +class SVG: + """Stores SVG representations via a list of BytesIO objects.""" + + def __init__(self, data: list[BytesIO]): + self.data = data + + def combine(self, other: 'SVG') -> 'SVG': + return SVG(self.data + other.data) + + @staticmethod + def combine_all(svgs: list['SVG']) -> 'SVG': + all_svgs_list: list[BytesIO] = [] + for svg_item in svgs: + all_svgs_list.extend(svg_item.data) + return SVG(all_svgs_list) diff --git a/comfy_extras/nodes_images.py b/comfy_extras/nodes_images.py index 392aea32c..ce21caade 100644 --- a/comfy_extras/nodes_images.py +++ b/comfy_extras/nodes_images.py @@ -2,280 +2,231 @@ from __future__ import annotations import nodes import folder_paths -from comfy.cli_args import args -from PIL import Image -from PIL.PngImagePlugin import PngInfo - -import numpy as np import json import os import re -from io import BytesIO -from inspect import cleandoc import torch import comfy.utils -from comfy.comfy_types import FileLocator, IO from server import PromptServer +from comfy_api.latest import ComfyExtension, IO, UI +from typing_extensions import override + +SVG = IO.SVG.Type # TODO: temporary solution for backward compatibility, will be removed later. MAX_RESOLUTION = nodes.MAX_RESOLUTION -class ImageCrop: +class ImageCrop(IO.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "image": ("IMAGE",), - "width": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}), - "height": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}), - "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}), - "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}), - }} - RETURN_TYPES = ("IMAGE",) - FUNCTION = "crop" + def define_schema(cls): + return IO.Schema( + node_id="ImageCrop", + display_name="Image Crop", + category="image/transform", + inputs=[ + IO.Image.Input("image"), + IO.Int.Input("width", default=512, min=1, max=nodes.MAX_RESOLUTION, step=1), + IO.Int.Input("height", default=512, min=1, max=nodes.MAX_RESOLUTION, step=1), + IO.Int.Input("x", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1), + IO.Int.Input("y", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1), + ], + outputs=[IO.Image.Output()], + ) - CATEGORY = "image/transform" - - def crop(self, image, width, height, x, y): + @classmethod + def execute(cls, image, width, height, x, y) -> IO.NodeOutput: x = min(x, image.shape[2] - 1) y = min(y, image.shape[1] - 1) to_x = width + x to_y = height + y img = image[:,y:to_y, x:to_x, :] - return (img,) + return IO.NodeOutput(img) -class RepeatImageBatch: + crop = execute # TODO: remove + + +class RepeatImageBatch(IO.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "image": ("IMAGE",), - "amount": ("INT", {"default": 1, "min": 1, "max": 4096}), - }} - RETURN_TYPES = ("IMAGE",) - FUNCTION = "repeat" + def define_schema(cls): + return IO.Schema( + node_id="RepeatImageBatch", + category="image/batch", + 
inputs=[ + IO.Image.Input("image"), + IO.Int.Input("amount", default=1, min=1, max=4096), + ], + outputs=[IO.Image.Output()], + ) - CATEGORY = "image/batch" - - def repeat(self, image, amount): + @classmethod + def execute(cls, image, amount) -> IO.NodeOutput: s = image.repeat((amount, 1,1,1)) - return (s,) + return IO.NodeOutput(s) -class ImageFromBatch: + repeat = execute # TODO: remove + + +class ImageFromBatch(IO.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "image": ("IMAGE",), - "batch_index": ("INT", {"default": 0, "min": 0, "max": 4095}), - "length": ("INT", {"default": 1, "min": 1, "max": 4096}), - }} - RETURN_TYPES = ("IMAGE",) - FUNCTION = "frombatch" + def define_schema(cls): + return IO.Schema( + node_id="ImageFromBatch", + category="image/batch", + inputs=[ + IO.Image.Input("image"), + IO.Int.Input("batch_index", default=0, min=0, max=4095), + IO.Int.Input("length", default=1, min=1, max=4096), + ], + outputs=[IO.Image.Output()], + ) - CATEGORY = "image/batch" - - def frombatch(self, image, batch_index, length): + @classmethod + def execute(cls, image, batch_index, length) -> IO.NodeOutput: s_in = image batch_index = min(s_in.shape[0] - 1, batch_index) length = min(s_in.shape[0] - batch_index, length) s = s_in[batch_index:batch_index + length].clone() - return (s,) + return IO.NodeOutput(s) + + frombatch = execute # TODO: remove -class ImageAddNoise: +class ImageAddNoise(IO.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "image": ("IMAGE",), - "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff, "control_after_generate": True, "tooltip": "The random seed used for creating the noise."}), - "strength": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}), - }} - RETURN_TYPES = ("IMAGE",) - FUNCTION = "repeat" + def define_schema(cls): + return IO.Schema( + node_id="ImageAddNoise", + category="image", + inputs=[ + IO.Image.Input("image"), + IO.Int.Input( + "seed", + default=0, + min=0, + max=0xFFFFFFFFFFFFFFFF, + control_after_generate=True, + tooltip="The random seed used for creating the noise.", + ), + IO.Float.Input("strength", default=0.5, min=0.0, max=1.0, step=0.01), + ], + outputs=[IO.Image.Output()], + ) - CATEGORY = "image" - - def repeat(self, image, seed, strength): + @classmethod + def execute(cls, image, seed, strength) -> IO.NodeOutput: generator = torch.manual_seed(seed) s = torch.clip((image + strength * torch.randn(image.size(), generator=generator, device="cpu").to(image)), min=0.0, max=1.0) - return (s,) + return IO.NodeOutput(s) -class SaveAnimatedWEBP: - def __init__(self): - self.output_dir = folder_paths.get_output_directory() - self.type = "output" - self.prefix_append = "" + repeat = execute # TODO: remove - methods = {"default": 4, "fastest": 0, "slowest": 6} - @classmethod - def INPUT_TYPES(s): - return {"required": - {"images": ("IMAGE", ), - "filename_prefix": ("STRING", {"default": "ComfyUI"}), - "fps": ("FLOAT", {"default": 6.0, "min": 0.01, "max": 1000.0, "step": 0.01}), - "lossless": ("BOOLEAN", {"default": True}), - "quality": ("INT", {"default": 80, "min": 0, "max": 100}), - "method": (list(s.methods.keys()),), - # "num_frames": ("INT", {"default": 0, "min": 0, "max": 8192}), - }, - "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}, - } - RETURN_TYPES = () - FUNCTION = "save_images" - - OUTPUT_NODE = True - - CATEGORY = "image/animation" - - def save_images(self, images, fps, filename_prefix, lossless, quality, method, num_frames=0, prompt=None, 
extra_pnginfo=None): - method = self.methods.get(method) - filename_prefix += self.prefix_append - full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0]) - results: list[FileLocator] = [] - pil_images = [] - for image in images: - i = 255. * image.cpu().numpy() - img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8)) - pil_images.append(img) - - metadata = pil_images[0].getexif() - if not args.disable_metadata: - if prompt is not None: - metadata[0x0110] = "prompt:{}".format(json.dumps(prompt)) - if extra_pnginfo is not None: - inital_exif = 0x010f - for x in extra_pnginfo: - metadata[inital_exif] = "{}:{}".format(x, json.dumps(extra_pnginfo[x])) - inital_exif -= 1 - - if num_frames == 0: - num_frames = len(pil_images) - - c = len(pil_images) - for i in range(0, c, num_frames): - file = f"{filename}_{counter:05}_.webp" - pil_images[i].save(os.path.join(full_output_folder, file), save_all=True, duration=int(1000.0/fps), append_images=pil_images[i + 1:i + num_frames], exif=metadata, lossless=lossless, quality=quality, method=method) - results.append({ - "filename": file, - "subfolder": subfolder, - "type": self.type - }) - counter += 1 - - animated = num_frames != 1 - return { "ui": { "images": results, "animated": (animated,) } } - -class SaveAnimatedPNG: - def __init__(self): - self.output_dir = folder_paths.get_output_directory() - self.type = "output" - self.prefix_append = "" +class SaveAnimatedWEBP(IO.ComfyNode): + COMPRESS_METHODS = {"default": 4, "fastest": 0, "slowest": 6} @classmethod - def INPUT_TYPES(s): - return {"required": - {"images": ("IMAGE", ), - "filename_prefix": ("STRING", {"default": "ComfyUI"}), - "fps": ("FLOAT", {"default": 6.0, "min": 0.01, "max": 1000.0, "step": 0.01}), - "compress_level": ("INT", {"default": 4, "min": 0, "max": 9}) - }, - "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}, - } + def define_schema(cls): + return IO.Schema( + node_id="SaveAnimatedWEBP", + category="image/animation", + inputs=[ + IO.Image.Input("images"), + IO.String.Input("filename_prefix", default="ComfyUI"), + IO.Float.Input("fps", default=6.0, min=0.01, max=1000.0, step=0.01), + IO.Boolean.Input("lossless", default=True), + IO.Int.Input("quality", default=80, min=0, max=100), + IO.Combo.Input("method", options=list(cls.COMPRESS_METHODS.keys())), + # "num_frames": ("INT", {"default": 0, "min": 0, "max": 8192}), + ], + hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo], + is_output_node=True, + ) - RETURN_TYPES = () - FUNCTION = "save_images" + @classmethod + def execute(cls, images, fps, filename_prefix, lossless, quality, method, num_frames=0) -> IO.NodeOutput: + return IO.NodeOutput( + ui=UI.ImageSaveHelper.get_save_animated_webp_ui( + images=images, + filename_prefix=filename_prefix, + cls=cls, + fps=fps, + lossless=lossless, + quality=quality, + method=cls.COMPRESS_METHODS.get(method) + ) + ) - OUTPUT_NODE = True - - CATEGORY = "image/animation" - - def save_images(self, images, fps, compress_level, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None): - filename_prefix += self.prefix_append - full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0]) - results = list() - pil_images = [] - for image in images: - i = 255. 
* image.cpu().numpy() - img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8)) - pil_images.append(img) - - metadata = None - if not args.disable_metadata: - metadata = PngInfo() - if prompt is not None: - metadata.add(b"comf", "prompt".encode("latin-1", "strict") + b"\0" + json.dumps(prompt).encode("latin-1", "strict"), after_idat=True) - if extra_pnginfo is not None: - for x in extra_pnginfo: - metadata.add(b"comf", x.encode("latin-1", "strict") + b"\0" + json.dumps(extra_pnginfo[x]).encode("latin-1", "strict"), after_idat=True) - - file = f"{filename}_{counter:05}_.png" - pil_images[0].save(os.path.join(full_output_folder, file), pnginfo=metadata, compress_level=compress_level, save_all=True, duration=int(1000.0/fps), append_images=pil_images[1:]) - results.append({ - "filename": file, - "subfolder": subfolder, - "type": self.type - }) - - return { "ui": { "images": results, "animated": (True,)} } - -class SVG: - """ - Stores SVG representations via a list of BytesIO objects. - """ - def __init__(self, data: list[BytesIO]): - self.data = data - - def combine(self, other: 'SVG') -> 'SVG': - return SVG(self.data + other.data) - - @staticmethod - def combine_all(svgs: list['SVG']) -> 'SVG': - all_svgs_list: list[BytesIO] = [] - for svg_item in svgs: - all_svgs_list.extend(svg_item.data) - return SVG(all_svgs_list) + save_images = execute # TODO: remove -class ImageStitch: +class SaveAnimatedPNG(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="SaveAnimatedPNG", + category="image/animation", + inputs=[ + IO.Image.Input("images"), + IO.String.Input("filename_prefix", default="ComfyUI"), + IO.Float.Input("fps", default=6.0, min=0.01, max=1000.0, step=0.01), + IO.Int.Input("compress_level", default=4, min=0, max=9), + ], + hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo], + is_output_node=True, + ) + + @classmethod + def execute(cls, images, fps, compress_level, filename_prefix="ComfyUI") -> IO.NodeOutput: + return IO.NodeOutput( + ui=UI.ImageSaveHelper.get_save_animated_png_ui( + images=images, + filename_prefix=filename_prefix, + cls=cls, + fps=fps, + compress_level=compress_level, + ) + ) + + save_images = execute # TODO: remove + + +class ImageStitch(IO.ComfyNode): """Upstreamed from https://github.com/kijai/ComfyUI-KJNodes""" @classmethod - def INPUT_TYPES(s): - return { - "required": { - "image1": ("IMAGE",), - "direction": (["right", "down", "left", "up"], {"default": "right"}), - "match_image_size": ("BOOLEAN", {"default": True}), - "spacing_width": ( - "INT", - {"default": 0, "min": 0, "max": 1024, "step": 2}, - ), - "spacing_color": ( - ["white", "black", "red", "green", "blue"], - {"default": "white"}, - ), - }, - "optional": { - "image2": ("IMAGE",), - }, - } + def define_schema(cls): + return IO.Schema( + node_id="ImageStitch", + display_name="Image Stitch", + description="Stitches image2 to image1 in the specified direction.\n" + "If image2 is not provided, returns image1 unchanged.\n" + "Optional spacing can be added between images.", + category="image/transform", + inputs=[ + IO.Image.Input("image1"), + IO.Combo.Input("direction", options=["right", "down", "left", "up"], default="right"), + IO.Boolean.Input("match_image_size", default=True), + IO.Int.Input("spacing_width", default=0, min=0, max=1024, step=2), + IO.Combo.Input("spacing_color", options=["white", "black", "red", "green", "blue"], default="white"), + IO.Image.Input("image2", optional=True), + ], + outputs=[IO.Image.Output()], + ) - RETURN_TYPES = ("IMAGE",) - FUNCTION = 
"stitch" - CATEGORY = "image/transform" - DESCRIPTION = """ -Stitches image2 to image1 in the specified direction. -If image2 is not provided, returns image1 unchanged. -Optional spacing can be added between images. -""" - - def stitch( - self, + @classmethod + def execute( + cls, image1, direction, match_image_size, spacing_width, spacing_color, image2=None, - ): + ) -> IO.NodeOutput: if image2 is None: - return (image1,) + return IO.NodeOutput(image1) # Handle batch size differences if image1.shape[0] != image2.shape[0]: @@ -412,36 +363,30 @@ Optional spacing can be added between images. images.insert(1, spacing) concat_dim = 2 if direction in ["left", "right"] else 1 - return (torch.cat(images, dim=concat_dim),) + return IO.NodeOutput(torch.cat(images, dim=concat_dim)) + + stitch = execute # TODO: remove + + +class ResizeAndPadImage(IO.ComfyNode): -class ResizeAndPadImage: @classmethod - def INPUT_TYPES(cls): - return { - "required": { - "image": ("IMAGE",), - "target_width": ("INT", { - "default": 512, - "min": 1, - "max": MAX_RESOLUTION, - "step": 1 - }), - "target_height": ("INT", { - "default": 512, - "min": 1, - "max": MAX_RESOLUTION, - "step": 1 - }), - "padding_color": (["white", "black"],), - "interpolation": (["area", "bicubic", "nearest-exact", "bilinear", "lanczos"],), - } - } + def define_schema(cls): + return IO.Schema( + node_id="ResizeAndPadImage", + category="image/transform", + inputs=[ + IO.Image.Input("image"), + IO.Int.Input("target_width", default=512, min=1, max=nodes.MAX_RESOLUTION, step=1), + IO.Int.Input("target_height", default=512, min=1, max=nodes.MAX_RESOLUTION, step=1), + IO.Combo.Input("padding_color", options=["white", "black"]), + IO.Combo.Input("interpolation", options=["area", "bicubic", "nearest-exact", "bilinear", "lanczos"]), + ], + outputs=[IO.Image.Output()], + ) - RETURN_TYPES = ("IMAGE",) - FUNCTION = "resize_and_pad" - CATEGORY = "image/transform" - - def resize_and_pad(self, image, target_width, target_height, padding_color, interpolation): + @classmethod + def execute(cls, image, target_width, target_height, padding_color, interpolation) -> IO.NodeOutput: batch_size, orig_height, orig_width, channels = image.shape scale_w = target_width / orig_width @@ -469,52 +414,47 @@ class ResizeAndPadImage: padded[:, :, y_offset:y_offset + new_height, x_offset:x_offset + new_width] = resized output = padded.permute(0, 2, 3, 1) - return (output,) + return IO.NodeOutput(output) -class SaveSVGNode: - """ - Save SVG files on disk. - """ + resize_and_pad = execute # TODO: remove - def __init__(self): - self.output_dir = folder_paths.get_output_directory() - self.type = "output" - self.prefix_append = "" - RETURN_TYPES = () - DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value - FUNCTION = "save_svg" - CATEGORY = "image/save" # Changed - OUTPUT_NODE = True +class SaveSVGNode(IO.ComfyNode): @classmethod - def INPUT_TYPES(s): - return { - "required": { - "svg": ("SVG",), # Changed - "filename_prefix": ("STRING", {"default": "svg/ComfyUI", "tooltip": "The prefix for the file to save. 
This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes."}) - }, - "hidden": { - "prompt": "PROMPT", - "extra_pnginfo": "EXTRA_PNGINFO" - } - } + def define_schema(cls): + return IO.Schema( + node_id="SaveSVGNode", + description="Save SVG files on disk.", + category="image/save", + inputs=[ + IO.SVG.Input("svg"), + IO.String.Input( + "filename_prefix", + default="svg/ComfyUI", + tooltip="The prefix for the file to save. This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes.", + ), + ], + hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo], + is_output_node=True, + ) - def save_svg(self, svg: SVG, filename_prefix="svg/ComfyUI", prompt=None, extra_pnginfo=None): - filename_prefix += self.prefix_append - full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir) - results = list() + @classmethod + def execute(cls, svg: IO.SVG.Type, filename_prefix="svg/ComfyUI") -> IO.NodeOutput: + full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, folder_paths.get_output_directory()) + results: list[UI.SavedResult] = [] # Prepare metadata JSON metadata_dict = {} - if prompt is not None: - metadata_dict["prompt"] = prompt - if extra_pnginfo is not None: - metadata_dict.update(extra_pnginfo) + if cls.hidden.prompt is not None: + metadata_dict["prompt"] = cls.hidden.prompt + if cls.hidden.extra_pnginfo is not None: + metadata_dict.update(cls.hidden.extra_pnginfo) # Convert metadata to JSON string metadata_json = json.dumps(metadata_dict, indent=2) if metadata_dict else None + for batch_number, svg_bytes in enumerate(svg.data): filename_with_batch_num = filename.replace("%batch_num%", str(batch_number)) file = f"{filename_with_batch_num}_{counter:05}_.svg" @@ -544,57 +484,64 @@ class SaveSVGNode: with open(os.path.join(full_output_folder, file), 'wb') as svg_file: svg_file.write(svg_content.encode('utf-8')) - results.append({ - "filename": file, - "subfolder": subfolder, - "type": self.type - }) + results.append(UI.SavedResult(filename=file, subfolder=subfolder, type=IO.FolderType.output)) counter += 1 - return { "ui": { "images": results } } + return IO.NodeOutput(ui={"images": results}) -class GetImageSize: + save_svg = execute # TODO: remove + + +class GetImageSize(IO.ComfyNode): @classmethod - def INPUT_TYPES(s): - return { - "required": { - "image": (IO.IMAGE,), - }, - "hidden": { - "unique_id": "UNIQUE_ID", - } - } + def define_schema(cls): + return IO.Schema( + node_id="GetImageSize", + display_name="Get Image Size", + description="Returns width and height of the image, and passes it through unchanged.", + category="image", + inputs=[ + IO.Image.Input("image"), + ], + outputs=[ + IO.Int.Output(display_name="width"), + IO.Int.Output(display_name="height"), + IO.Int.Output(display_name="batch_size"), + ], + hidden=[IO.Hidden.unique_id], + ) - RETURN_TYPES = (IO.INT, IO.INT, IO.INT) - RETURN_NAMES = ("width", "height", "batch_size") - FUNCTION = "get_size" - - CATEGORY = "image" - DESCRIPTION = """Returns width and height of the image, and passes it through unchanged.""" - - def get_size(self, image, unique_id=None) -> tuple[int, int]: + @classmethod + def execute(cls, image) -> IO.NodeOutput: height = image.shape[1] width = image.shape[2] batch_size = image.shape[0] # Send progress text to display size on the node - if unique_id: 
- PromptServer.instance.send_progress_text(f"width: {width}, height: {height}\n batch size: {batch_size}", unique_id) + if cls.hidden.unique_id: + PromptServer.instance.send_progress_text(f"width: {width}, height: {height}\n batch size: {batch_size}", cls.hidden.unique_id) - return width, height, batch_size + return IO.NodeOutput(width, height, batch_size) + + get_size = execute # TODO: remove + + +class ImageRotate(IO.ComfyNode): -class ImageRotate: @classmethod - def INPUT_TYPES(s): - return {"required": { "image": (IO.IMAGE,), - "rotation": (["none", "90 degrees", "180 degrees", "270 degrees"],), - }} - RETURN_TYPES = (IO.IMAGE,) - FUNCTION = "rotate" + def define_schema(cls): + return IO.Schema( + node_id="ImageRotate", + category="image/transform", + inputs=[ + IO.Image.Input("image"), + IO.Combo.Input("rotation", options=["none", "90 degrees", "180 degrees", "270 degrees"]), + ], + outputs=[IO.Image.Output()], + ) - CATEGORY = "image/transform" - - def rotate(self, image, rotation): + @classmethod + def execute(cls, image, rotation) -> IO.NodeOutput: rotate_by = 0 if rotation.startswith("90"): rotate_by = 1 @@ -604,41 +551,57 @@ class ImageRotate: rotate_by = 3 image = torch.rot90(image, k=rotate_by, dims=[2, 1]) - return (image,) + return IO.NodeOutput(image) + + rotate = execute # TODO: remove + + +class ImageFlip(IO.ComfyNode): -class ImageFlip: @classmethod - def INPUT_TYPES(s): - return {"required": { "image": (IO.IMAGE,), - "flip_method": (["x-axis: vertically", "y-axis: horizontally"],), - }} - RETURN_TYPES = (IO.IMAGE,) - FUNCTION = "flip" + def define_schema(cls): + return IO.Schema( + node_id="ImageFlip", + category="image/transform", + inputs=[ + IO.Image.Input("image"), + IO.Combo.Input("flip_method", options=["x-axis: vertically", "y-axis: horizontally"]), + ], + outputs=[IO.Image.Output()], + ) - CATEGORY = "image/transform" - - def flip(self, image, flip_method): + @classmethod + def execute(cls, image, flip_method) -> IO.NodeOutput: if flip_method.startswith("x"): image = torch.flip(image, dims=[1]) elif flip_method.startswith("y"): image = torch.flip(image, dims=[2]) - return (image,) + return IO.NodeOutput(image) -class ImageScaleToMaxDimension: - upscale_methods = ["area", "lanczos", "bilinear", "nearest-exact", "bilinear", "bicubic"] + flip = execute # TODO: remove + + +class ImageScaleToMaxDimension(IO.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": {"image": ("IMAGE",), - "upscale_method": (s.upscale_methods,), - "largest_size": ("INT", {"default": 512, "min": 0, "max": MAX_RESOLUTION, "step": 1})}} - RETURN_TYPES = ("IMAGE",) - FUNCTION = "upscale" + def define_schema(cls): + return IO.Schema( + node_id="ImageScaleToMaxDimension", + category="image/upscaling", + inputs=[ + IO.Image.Input("image"), + IO.Combo.Input( + "upscale_method", + options=["area", "lanczos", "bilinear", "nearest-exact", "bilinear", "bicubic"], + ), + IO.Int.Input("largest_size", default=512, min=0, max=MAX_RESOLUTION, step=1), + ], + outputs=[IO.Image.Output()], + ) - CATEGORY = "image/upscaling" - - def upscale(self, image, upscale_method, largest_size): + @classmethod + def execute(cls, image, upscale_method, largest_size) -> IO.NodeOutput: height = image.shape[1] width = image.shape[2] @@ -655,20 +618,30 @@ class ImageScaleToMaxDimension: samples = image.movedim(-1, 1) s = comfy.utils.common_upscale(samples, width, height, upscale_method, "disabled") s = s.movedim(1, -1) - return (s,) + return IO.NodeOutput(s) -NODE_CLASS_MAPPINGS = { - "ImageCrop": ImageCrop, - 
"RepeatImageBatch": RepeatImageBatch, - "ImageFromBatch": ImageFromBatch, - "ImageAddNoise": ImageAddNoise, - "SaveAnimatedWEBP": SaveAnimatedWEBP, - "SaveAnimatedPNG": SaveAnimatedPNG, - "SaveSVGNode": SaveSVGNode, - "ImageStitch": ImageStitch, - "ResizeAndPadImage": ResizeAndPadImage, - "GetImageSize": GetImageSize, - "ImageRotate": ImageRotate, - "ImageFlip": ImageFlip, - "ImageScaleToMaxDimension": ImageScaleToMaxDimension, -} + upscale = execute # TODO: remove + + +class ImagesExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[IO.ComfyNode]]: + return [ + ImageCrop, + RepeatImageBatch, + ImageFromBatch, + ImageAddNoise, + SaveAnimatedWEBP, + SaveAnimatedPNG, + SaveSVGNode, + ImageStitch, + ResizeAndPadImage, + GetImageSize, + ImageRotate, + ImageFlip, + ImageScaleToMaxDimension, + ] + + +async def comfy_entrypoint() -> ImagesExtension: + return ImagesExtension() diff --git a/tests-unit/comfy_extras_test/image_stitch_test.py b/tests-unit/comfy_extras_test/image_stitch_test.py index b5a0f022c..5c6a15ac4 100644 --- a/tests-unit/comfy_extras_test/image_stitch_test.py +++ b/tests-unit/comfy_extras_test/image_stitch_test.py @@ -25,7 +25,7 @@ class TestImageStitch: result = node.stitch(image1, "right", True, 0, "white", image2=None) - assert len(result) == 1 + assert len(result.result) == 1 assert torch.equal(result[0], image1) def test_basic_horizontal_stitch_right(self): From 0d2e4bdd44f61b198588c5db99bebfd5cdec286b Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Sat, 27 Dec 2025 05:55:30 +0200 Subject: [PATCH 102/148] fix(api-nodes-gemini): always force enhance_prompt to be True (#11503) --- comfy_api_nodes/nodes_veo2.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/comfy_api_nodes/nodes_veo2.py b/comfy_api_nodes/nodes_veo2.py index e165b8380..13a6bfd91 100644 --- a/comfy_api_nodes/nodes_veo2.py +++ b/comfy_api_nodes/nodes_veo2.py @@ -168,6 +168,8 @@ class VeoVideoGenerationNode(IO.ComfyNode): # Only add generateAudio for Veo 3 models if model.find("veo-2.0") == -1: parameters["generateAudio"] = generate_audio + # force "enhance_prompt" to True for Veo3 models + parameters["enhancePrompt"] = True initial_response = await sync_op( cls, @@ -291,7 +293,7 @@ class Veo3VideoGenerationNode(VeoVideoGenerationNode): IO.Boolean.Input( "enhance_prompt", default=True, - tooltip="Whether to enhance the prompt with AI assistance", + tooltip="This parameter is deprecated and ignored.", optional=True, ), IO.Combo.Input( From 36deef2c57eacb5d847bd709c5f3068190630612 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Sat, 27 Dec 2025 05:56:52 +0200 Subject: [PATCH 103/148] chore(api-nodes): switch to credits instead of $ (#11489) --- comfy_api_nodes/util/client.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/comfy_api_nodes/util/client.py b/comfy_api_nodes/util/client.py index bf37cba5f..f372ec7b5 100644 --- a/comfy_api_nodes/util/client.py +++ b/comfy_api_nodes/util/client.py @@ -430,9 +430,9 @@ def _display_text( if status: display_lines.append(f"Status: {status.capitalize() if isinstance(status, str) else status}") if price is not None: - p = f"{float(price):,.4f}".rstrip("0").rstrip(".") + p = f"{float(price) * 211:,.1f}".rstrip("0").rstrip(".") if p != "0": - display_lines.append(f"Price: ${p}") + display_lines.append(f"Price: {p} credits") if text is not None: display_lines.append(text) if display_lines: From 
2943093a5310fc96aa010a3c68c04f7c16f58a9e Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sat, 27 Dec 2025 15:54:15 -0800 Subject: [PATCH 104/148] Enable async offload by default for AMD. (#11534) --- comfy/model_management.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 1889ab0ac..e5554e225 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -1019,8 +1019,8 @@ NUM_STREAMS = 0 if args.async_offload is not None: NUM_STREAMS = args.async_offload else: - # Enable by default on Nvidia - if is_nvidia(): + # Enable by default on Nvidia and AMD + if is_nvidia() or is_amd(): NUM_STREAMS = 2 if args.disable_async_offload: From 8fd07170f1b0a7eaaf5a62020cd1926dd3b5092c Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sun, 28 Dec 2025 19:07:25 -0800 Subject: [PATCH 105/148] Comment out unused norm_final in lumina/z image model. (#11545) --- comfy/ldm/lumina/model.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/comfy/ldm/lumina/model.py b/comfy/ldm/lumina/model.py index e80b1c138..afbab2ac7 100644 --- a/comfy/ldm/lumina/model.py +++ b/comfy/ldm/lumina/model.py @@ -491,7 +491,8 @@ class NextDiT(nn.Module): for layer_id in range(n_layers) ] ) - self.norm_final = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + # This norm final is in the lumina 2.0 code but isn't actually used for anything. + # self.norm_final = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) self.final_layer = FinalLayer(dim, patch_size, self.out_channels, z_image_modulation=z_image_modulation, operation_settings=operation_settings) if self.pad_tokens_multiple is not None: From 9ca7e143afe6f09734c9aefcc85f491c5c0dc6e0 Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Mon, 29 Dec 2025 15:19:34 -0800 Subject: [PATCH 106/148] mm: discard async errors from pinning failures (#10738) Pretty much every error cudaHostRegister can throw also queues the same error on the async GPU queue. This was fixed for repinning error case, but there is the bad mmap and just enomem cases that are harder to detect. Do some dummy GPU work to clean the error state. --- comfy/model_management.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/comfy/model_management.py b/comfy/model_management.py index e5554e225..9fcb699bc 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -1126,6 +1126,16 @@ if not args.disable_pinned_memory: PINNING_ALLOWED_TYPES = set(["Parameter", "QuantizedTensor"]) +def discard_cuda_async_error(): + try: + a = torch.tensor([1], dtype=torch.uint8, device=get_torch_device()) + b = torch.tensor([1], dtype=torch.uint8, device=get_torch_device()) + _ = a + b + torch.cuda.synchronize() + except torch.AcceleratorError: + #Dump it! 
We already know about it from the synchronous return + pass + def pin_memory(tensor): global TOTAL_PINNED_MEMORY if MAX_PINNED_MEMORY <= 0: @@ -1158,6 +1168,8 @@ def pin_memory(tensor): PINNED_MEMORY[ptr] = size TOTAL_PINNED_MEMORY += size return True + else: + discard_cuda_async_error() return False @@ -1186,6 +1198,8 @@ def unpin_memory(tensor): if len(PINNED_MEMORY) == 0: TOTAL_PINNED_MEMORY = 0 return True + else: + discard_cuda_async_error() return False From 0e6221cc79a3f3cbf0e15a8321bfe75fcffbe667 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 29 Dec 2025 15:26:42 -0800 Subject: [PATCH 107/148] Add some warnings for pin and unpin errors. (#11561) --- comfy/model_management.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/comfy/model_management.py b/comfy/model_management.py index 9fcb699bc..87baedd73 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -1169,6 +1169,7 @@ def pin_memory(tensor): TOTAL_PINNED_MEMORY += size return True else: + logging.warning("Pin error.") discard_cuda_async_error() return False @@ -1199,6 +1200,7 @@ def unpin_memory(tensor): TOTAL_PINNED_MEMORY = 0 return True else: + logging.warning("Unpin error.") discard_cuda_async_error() return False From d7111e426a48127a97922227b03d31391eb4eba2 Mon Sep 17 00:00:00 2001 From: Tavi Halperin Date: Tue, 30 Dec 2025 03:07:29 +0200 Subject: [PATCH 108/148] ResizeByLongerSide: support video (#11555) (cherry picked from commit 98c6840aa4e5fd5407ba9ab113d209011e474bf6) --- comfy_extras/nodes_dataset.py | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/comfy_extras/nodes_dataset.py b/comfy_extras/nodes_dataset.py index 513aecf3a..5ef851bd0 100644 --- a/comfy_extras/nodes_dataset.py +++ b/comfy_extras/nodes_dataset.py @@ -667,16 +667,19 @@ class ResizeImagesByLongerEdgeNode(ImageProcessingNode): @classmethod def _process(cls, image, longer_edge): - img = tensor_to_pil(image) - w, h = img.size - if w > h: - new_w = longer_edge - new_h = int(h * (longer_edge / w)) - else: - new_h = longer_edge - new_w = int(w * (longer_edge / h)) - img = img.resize((new_w, new_h), Image.Resampling.LANCZOS) - return pil_to_tensor(img) + resized_images = [] + for image_i in image: + img = tensor_to_pil(image_i) + w, h = img.size + if w > h: + new_w = longer_edge + new_h = int(h * (longer_edge / w)) + else: + new_h = longer_edge + new_w = int(w * (longer_edge / h)) + img = img.resize((new_w, new_h), Image.Resampling.LANCZOS) + resized_images.append(pil_to_tensor(img)) + return torch.cat(resized_images, dim=0) class CenterCropImagesNode(ImageProcessingNode): From 25a1bfab4e19b541c2bd6f253a3b83886fb660a1 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Tue, 30 Dec 2025 18:33:34 +0200 Subject: [PATCH 109/148] chore(api-nodes-bytedance): mark "seededit" as deprecated, adjust display name of Seedream (#11490) --- comfy_api_nodes/nodes_bytedance.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/comfy_api_nodes/nodes_bytedance.py b/comfy_api_nodes/nodes_bytedance.py index 636cc1265..d4a2cfae6 100644 --- a/comfy_api_nodes/nodes_bytedance.py +++ b/comfy_api_nodes/nodes_bytedance.py @@ -229,6 +229,7 @@ class ByteDanceImageEditNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + is_deprecated=True, ) @classmethod @@ -269,7 +270,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode): def define_schema(cls): return IO.Schema( 
node_id="ByteDanceSeedreamNode", - display_name="ByteDance Seedream 4", + display_name="ByteDance Seedream 4.5", category="api node/image/ByteDance", description="Unified text-to-image generation and precise single-sentence editing at up to 4K resolution.", inputs=[ From 178bdc5e14ec0a55e401c509719e33773cc9b565 Mon Sep 17 00:00:00 2001 From: drozbay <17261091+drozbay@users.noreply.github.com> Date: Tue, 30 Dec 2025 15:40:42 -0700 Subject: [PATCH 110/148] Add handling for vace_context in context windows (#11386) Co-authored-by: ozbayb <17261091+ozbayb@users.noreply.github.com> --- comfy/context_windows.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/comfy/context_windows.py b/comfy/context_windows.py index 1e0f86026..2f82d51da 100644 --- a/comfy/context_windows.py +++ b/comfy/context_windows.py @@ -188,6 +188,12 @@ class IndexListContextHandler(ContextHandlerABC): audio_cond = cond_value.cond if audio_cond.ndim > 1 and audio_cond.size(1) == x_in.size(self.dim): new_cond_item[cond_key] = cond_value._copy_with(window.get_tensor(audio_cond, device, dim=1)) + # Handle vace_context (temporal dim is 3) + elif cond_key == "vace_context" and hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor): + vace_cond = cond_value.cond + if vace_cond.ndim >= 4 and vace_cond.size(3) == x_in.size(self.dim): + sliced_vace = window.get_tensor(vace_cond, device, dim=3, retain_index_list=self.cond_retain_index_list) + new_cond_item[cond_key] = cond_value._copy_with(sliced_vace) # if has cond that is a Tensor, check if needs to be subset elif hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor): if (self.dim < cond_value.cond.ndim and cond_value.cond.size(self.dim) == x_in.size(self.dim)) or \ From f59f71cf34067d46713f6243312f7f0b360d061f Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 30 Dec 2025 22:41:22 -0500 Subject: [PATCH 111/148] ComfyUI version v0.7.0 --- comfyui_version.py | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/comfyui_version.py b/comfyui_version.py index 1f28e2407..1ed60fe5c 100644 --- a/comfyui_version.py +++ b/comfyui_version.py @@ -1,3 +1,3 @@ # This file is automatically generated by the build process when version is # updated in pyproject.toml. -__version__ = "0.6.0" +__version__ = "0.7.0" diff --git a/pyproject.toml b/pyproject.toml index 35a268bd1..bc1467941 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ComfyUI" -version = "0.6.0" +version = "0.7.0" readme = "README.md" license = { file = "LICENSE" } requires-python = ">=3.9" From 0357ed7ec4a1bbfe3832874ad6cfc1ca3db1bc0d Mon Sep 17 00:00:00 2001 From: mengqin Date: Tue, 30 Dec 2025 17:53:52 -1000 Subject: [PATCH 112/148] Add support for sage attention 3 in comfyui, enable via new cli arg (#11026) * Add support for sage attention 3 in comfyui, enable via new cli arg --use-sage-attiention3 * Fix some bugs found in PR review. The N dimension at which Sage Attention 3 takes effect is reduced to 1024 (although the improvement is not significant at this scale). * Remove the Sage Attention3 switch, but retain the attention function registration. 
* Fix a ruff check issue in attention.py --- comfy/ldm/modules/attention.py | 96 ++++++++++++++++++++++++++++++++++ 1 file changed, 96 insertions(+) diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index a8800ded0..ccf690945 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -30,6 +30,13 @@ except ImportError as e: raise e exit(-1) +SAGE_ATTENTION3_IS_AVAILABLE = False +try: + from sageattn3 import sageattn3_blackwell + SAGE_ATTENTION3_IS_AVAILABLE = True +except ImportError: + pass + FLASH_ATTENTION_IS_AVAILABLE = False try: from flash_attn import flash_attn_func @@ -563,6 +570,93 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape= out = out.reshape(b, -1, heads * dim_head) return out +@wrap_attn +def attention3_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs): + exception_fallback = False + if (q.device.type != "cuda" or + q.dtype not in (torch.float16, torch.bfloat16) or + mask is not None): + return attention_pytorch( + q, k, v, heads, + mask=mask, + attn_precision=attn_precision, + skip_reshape=skip_reshape, + skip_output_reshape=skip_output_reshape, + **kwargs + ) + + if skip_reshape: + B, H, L, D = q.shape + if H != heads: + return attention_pytorch( + q, k, v, heads, + mask=mask, + attn_precision=attn_precision, + skip_reshape=True, + skip_output_reshape=skip_output_reshape, + **kwargs + ) + q_s, k_s, v_s = q, k, v + N = q.shape[2] + dim_head = D + else: + B, N, inner_dim = q.shape + if inner_dim % heads != 0: + return attention_pytorch( + q, k, v, heads, + mask=mask, + attn_precision=attn_precision, + skip_reshape=False, + skip_output_reshape=skip_output_reshape, + **kwargs + ) + dim_head = inner_dim // heads + + if dim_head >= 256 or N <= 1024: + return attention_pytorch( + q, k, v, heads, + mask=mask, + attn_precision=attn_precision, + skip_reshape=skip_reshape, + skip_output_reshape=skip_output_reshape, + **kwargs + ) + + if not skip_reshape: + q_s, k_s, v_s = map( + lambda t: t.view(B, -1, heads, dim_head).permute(0, 2, 1, 3).contiguous(), + (q, k, v), + ) + B, H, L, D = q_s.shape + + try: + out = sageattn3_blackwell(q_s, k_s, v_s, is_causal=False) + except Exception as e: + exception_fallback = True + logging.error("Error running SageAttention3: %s, falling back to pytorch attention.", e) + + if exception_fallback: + if not skip_reshape: + del q_s, k_s, v_s + return attention_pytorch( + q, k, v, heads, + mask=mask, + attn_precision=attn_precision, + skip_reshape=False, + skip_output_reshape=skip_output_reshape, + **kwargs + ) + + if skip_reshape: + if not skip_output_reshape: + out = out.permute(0, 2, 1, 3).reshape(B, L, H * D) + else: + if skip_output_reshape: + pass + else: + out = out.permute(0, 2, 1, 3).reshape(B, L, H * D) + + return out try: @torch.library.custom_op("flash_attention::flash_attn", mutates_args=()) @@ -650,6 +744,8 @@ optimized_attention_masked = optimized_attention # register core-supported attention functions if SAGE_ATTENTION_IS_AVAILABLE: register_attention_function("sage", attention_sage) +if SAGE_ATTENTION3_IS_AVAILABLE: + register_attention_function("sage3", attention3_sage) if FLASH_ATTENTION_IS_AVAILABLE: register_attention_function("flash", attention_flash) if model_management.xformers_enabled(): From 0be8a76c933026011098d41e61cc6e544739e427 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Tue, 30 Dec 2025 20:09:55 -0800 Subject: [PATCH 113/148] V3 Improvements + DynamicCombo + Autogrow 
exposed in public API (#11345) * Support Combo outputs in a more sane way * Remove test validate_inputs function on test node * Make curr_prefix be a list of strings instead of string for easier parsing as keys get added to dynamic types * Start to account for id prefixes from frontend, need to fix bug with nested dynamics * Ensure inputs/outputs/hidden are lists in schema finalize function, remove no longer needed 'is not None' checks * Add raw_link and extra_dict to all relevant Inputs * Make nested DynamicCombos work properly with prefixed keys on latest frontend; breaks old Autogrow, but is pretty much ready for upcoming Autogrow keys * Replace ... usage with a MISSING sentinel for clarity in nodes_logic.py * Added CustomCombo node in backend to reflect frontend node * Prepare Autogrow's expand_schema_for_dynamic to work with upcoming frontend changes * Prepare for look up table for dynamic input stuff * More progress towards dynamic input lookup function stuff * Finished converting _expand_schema_for_dynamic to be done via lookup instead of OOP to guarantee working with process isolation, did refactoring to remove old implementation + cleaning INPUT_TYPES definition including v3 hidden definition * Change order of functions * Removed some unneeded functions after dynamic refactor * Make MatchType's output default displayname "MATCHTYPE" * Fix DynamicSlot get_all * Removed redundant code - dynamic stuff no longer happens in OOP way * Natively support AnyType (*) without __ne__ hacks * Remove stray code that made it in * Remove expand_schema_for_dynamic left over on DynamicInput class * get_dynamic() on DynamicInput/Output was not doing anything anymore, so removed it * Make validate_inputs validate combo input correctly * Temporarily comment out conversion to 'new' (9 month old) COMBO format in get_input_info * Remove refrences to resources feature scrapped from V3 * Expose DynamicCombo in public API * satisfy ruff after some code got commented out * Make missing input error prettier for dynamic types * Created a Switch2 node as a side-by-side test, will likely go with Switch2 as the initial switch node * Figured out Switch situation * Pass in v3_data in IsChangedCache.get function's fingerprint_inputs, add a from_v3_data helper method to HiddenHolder * Switch order of Switch and Soft Switch nodes in file * Temp test node for MatchType * Fix missing v3_data for v1 nodes in validation * For now, remove chacking duplicate id's for dynamic types * Add Resize Image/Mask node that thanks to MatchType+DynamicCombo is 16-nodes-in-1 * Made DynamicCombo references in DCTestNode use public interface * Add an AnyTypeTestNode * Make lazy status for specific inputs on DynamicInputs work by having the values of the dictionary for check_lazy_status be a tuple, where the second element is the key of the input that can be returned * Comment out test logic nodes * Make primitive float's step make more sense * Add (and leave commented out) some potential logic nodes * Change default crop option to "center" on Resize Image/Mask node * Changed copy.copy(d) to d.copy() * Autogrow is available in stable frontend, so exposing it in public API * Use outputs id as display_name if no display_name present, remove v3 outputs id restriction that made them have to have unique IDs from the inputs * Enable Custom Combo node as stable frontend now supports it * Make id properly act like display_name on outputs * Add Batch Images/Masks/Latents node * Comment out Batch Images/Masks/Latents node for now, as Autogrow has a 
bug with MatchType where top connection is disconnected upon refresh * Removed code for a couple test nodes in nodes_logic.py * Add Batch Images, Batch Masks, and Batch Latents nodes with Autogrow, deprecate old Batch Images + LatentBatch nodes --- comfy_api/latest/__init__.py | 1 - comfy_api/latest/_io.py | 370 ++++++++++++++------------ comfy_api/latest/_resources.py | 72 ----- comfy_execution/graph.py | 5 + comfy_execution/validation.py | 14 + comfy_extras/nodes_latent.py | 1 + comfy_extras/nodes_logic.py | 149 +++++++++-- comfy_extras/nodes_post_processing.py | 356 +++++++++++++++++++++++++ comfy_extras/nodes_primitive.py | 2 +- execution.py | 45 ++-- nodes.py | 1 + 11 files changed, 742 insertions(+), 274 deletions(-) delete mode 100644 comfy_api/latest/_resources.py diff --git a/comfy_api/latest/__init__.py b/comfy_api/latest/__init__.py index fab63c7df..b0fa14ff6 100644 --- a/comfy_api/latest/__init__.py +++ b/comfy_api/latest/__init__.py @@ -10,7 +10,6 @@ from ._input_impl import VideoFromFile, VideoFromComponents from ._util import VideoCodec, VideoContainer, VideoComponents, MESH, VOXEL from . import _io_public as io from . import _ui_public as ui -# from comfy_api.latest._resources import _RESOURCES as resources #noqa: F401 from comfy_execution.utils import get_executing_context from comfy_execution.progress import get_progress_state, PreviewImageTuple from PIL import Image diff --git a/comfy_api/latest/_io.py b/comfy_api/latest/_io.py index ba0b95498..764fa8b2b 100644 --- a/comfy_api/latest/_io.py +++ b/comfy_api/latest/_io.py @@ -26,7 +26,6 @@ if TYPE_CHECKING: from comfy_api.input import VideoInput from comfy_api.internal import (_ComfyNodeInternal, _NodeOutputInternal, classproperty, copy_class, first_real_override, is_class, prune_dict, shallow_clone_class) -from ._resources import Resources, ResourcesLocal from comfy_execution.graph_utils import ExecutionBlocker from ._util import MESH, VOXEL, SVG as _SVG @@ -76,16 +75,6 @@ class NumberDisplay(str, Enum): slider = "slider" -class _StringIOType(str): - def __ne__(self, value: object) -> bool: - if self == "*" or value == "*": - return False - if not isinstance(value, str): - return True - a = frozenset(self.split(",")) - b = frozenset(value.split(",")) - return not (b.issubset(a) or a.issubset(b)) - class _ComfyType(ABC): Type = Any io_type: str = None @@ -125,8 +114,7 @@ def comfytype(io_type: str, **kwargs): new_cls.__module__ = cls.__module__ new_cls.__doc__ = cls.__doc__ # assign ComfyType attributes, if needed - # NOTE: use __ne__ trick for io_type (see node_typing.IO.__ne__ for details) - new_cls.io_type = _StringIOType(io_type) + new_cls.io_type = io_type if hasattr(new_cls, "Input") and new_cls.Input is not None: new_cls.Input.Parent = new_cls if hasattr(new_cls, "Output") and new_cls.Output is not None: @@ -165,7 +153,7 @@ class Input(_IO_V3): ''' Base class for a V3 Input. 
''' - def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None): + def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None, raw_link: bool=None): super().__init__() self.id = id self.display_name = display_name @@ -173,6 +161,7 @@ class Input(_IO_V3): self.tooltip = tooltip self.lazy = lazy self.extra_dict = extra_dict if extra_dict is not None else {} + self.rawLink = raw_link def as_dict(self): return prune_dict({ @@ -180,10 +169,11 @@ class Input(_IO_V3): "optional": self.optional, "tooltip": self.tooltip, "lazy": self.lazy, + "rawLink": self.rawLink, }) | prune_dict(self.extra_dict) def get_io_type(self): - return _StringIOType(self.io_type) + return self.io_type def get_all(self) -> list[Input]: return [self] @@ -194,8 +184,8 @@ class WidgetInput(Input): ''' def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, default: Any=None, - socketless: bool=None, widget_type: str=None, force_input: bool=None, extra_dict=None): - super().__init__(id, display_name, optional, tooltip, lazy, extra_dict) + socketless: bool=None, widget_type: str=None, force_input: bool=None, extra_dict=None, raw_link: bool=None): + super().__init__(id, display_name, optional, tooltip, lazy, extra_dict, raw_link) self.default = default self.socketless = socketless self.widget_type = widget_type @@ -217,13 +207,14 @@ class Output(_IO_V3): def __init__(self, id: str=None, display_name: str=None, tooltip: str=None, is_output_list=False): self.id = id - self.display_name = display_name + self.display_name = display_name if display_name else id self.tooltip = tooltip self.is_output_list = is_output_list def as_dict(self): + display_name = self.display_name if self.display_name else self.id return prune_dict({ - "display_name": self.display_name, + "display_name": display_name, "tooltip": self.tooltip, "is_output_list": self.is_output_list, }) @@ -251,8 +242,8 @@ class Boolean(ComfyTypeIO): '''Boolean input.''' def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, default: bool=None, label_on: str=None, label_off: str=None, - socketless: bool=None, force_input: bool=None): - super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input) + socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None): + super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link) self.label_on = label_on self.label_off = label_off self.default: bool @@ -271,8 +262,8 @@ class Int(ComfyTypeIO): '''Integer input.''' def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, default: int=None, min: int=None, max: int=None, step: int=None, control_after_generate: bool=None, - display_mode: NumberDisplay=None, socketless: bool=None, force_input: bool=None): - super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input) + display_mode: NumberDisplay=None, socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None): + super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link) self.min = min self.max = max self.step = step @@ -297,8 +288,8 @@ class Float(ComfyTypeIO): '''Float input.''' def __init__(self, id: str, display_name: str=None, 
optional=False, tooltip: str=None, lazy: bool=None, default: float=None, min: float=None, max: float=None, step: float=None, round: float=None, - display_mode: NumberDisplay=None, socketless: bool=None, force_input: bool=None): - super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input) + display_mode: NumberDisplay=None, socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None): + super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link) self.min = min self.max = max self.step = step @@ -323,8 +314,8 @@ class String(ComfyTypeIO): '''String input.''' def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, multiline=False, placeholder: str=None, default: str=None, dynamic_prompts: bool=None, - socketless: bool=None, force_input: bool=None): - super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input) + socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None): + super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link) self.multiline = multiline self.placeholder = placeholder self.dynamic_prompts = dynamic_prompts @@ -357,12 +348,14 @@ class Combo(ComfyTypeIO): image_folder: FolderType=None, remote: RemoteOptions=None, socketless: bool=None, + extra_dict=None, + raw_link: bool=None, ): if isinstance(options, type) and issubclass(options, Enum): options = [v.value for v in options] if isinstance(default, Enum): default = default.value - super().__init__(id, display_name, optional, tooltip, lazy, default, socketless) + super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, None, extra_dict, raw_link) self.multiselect = False self.options = options self.control_after_generate = control_after_generate @@ -386,10 +379,6 @@ class Combo(ComfyTypeIO): super().__init__(id, display_name, tooltip, is_output_list) self.options = options if options is not None else [] - @property - def io_type(self): - return self.options - @comfytype(io_type="COMBO") class MultiCombo(ComfyTypeI): '''Multiselect Combo input (dropdown for selecting potentially more than one value).''' @@ -398,8 +387,8 @@ class MultiCombo(ComfyTypeI): class Input(Combo.Input): def __init__(self, id: str, options: list[str], display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, default: list[str]=None, placeholder: str=None, chip: bool=None, control_after_generate: bool=None, - socketless: bool=None): - super().__init__(id, options, display_name, optional, tooltip, lazy, default, control_after_generate, socketless=socketless) + socketless: bool=None, extra_dict=None, raw_link: bool=None): + super().__init__(id, options, display_name, optional, tooltip, lazy, default, control_after_generate, socketless=socketless, extra_dict=extra_dict, raw_link=raw_link) self.multiselect = True self.placeholder = placeholder self.chip = chip @@ -432,9 +421,9 @@ class Webcam(ComfyTypeIO): Type = str def __init__( self, id: str, display_name: str=None, optional=False, - tooltip: str=None, lazy: bool=None, default: str=None, socketless: bool=None + tooltip: str=None, lazy: bool=None, default: str=None, socketless: bool=None, extra_dict=None, raw_link: bool=None ): - super().__init__(id, display_name, optional, tooltip, lazy, default, socketless) + super().__init__(id, display_name, optional, 
tooltip, lazy, default, socketless, None, None, extra_dict, raw_link) @comfytype(io_type="MASK") @@ -787,7 +776,7 @@ class MultiType: ''' Input that permits more than one input type; if `id` is an instance of `ComfyType.Input`, then that input will be used to create a widget (if applicable) with overridden values. ''' - def __init__(self, id: str | Input, types: list[type[_ComfyType] | _ComfyType], display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None): + def __init__(self, id: str | Input, types: list[type[_ComfyType] | _ComfyType], display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None, raw_link: bool=None): # if id is an Input, then use that Input with overridden values self.input_override = None if isinstance(id, Input): @@ -800,7 +789,7 @@ class MultiType: # if is a widget input, make sure widget_type is set appropriately if isinstance(self.input_override, WidgetInput): self.input_override.widget_type = self.input_override.get_io_type() - super().__init__(id, display_name, optional, tooltip, lazy, extra_dict) + super().__init__(id, display_name, optional, tooltip, lazy, extra_dict, raw_link) self._io_types = types @property @@ -854,8 +843,8 @@ class MatchType(ComfyTypeIO): class Input(Input): def __init__(self, id: str, template: MatchType.Template, - display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None): - super().__init__(id, display_name, optional, tooltip, lazy, extra_dict) + display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None, raw_link: bool=None): + super().__init__(id, display_name, optional, tooltip, lazy, extra_dict, raw_link) self.template = template def as_dict(self): @@ -866,6 +855,8 @@ class MatchType(ComfyTypeIO): class Output(Output): def __init__(self, template: MatchType.Template, id: str=None, display_name: str=None, tooltip: str=None, is_output_list=False): + if not id and not display_name: + display_name = "MATCHTYPE" super().__init__(id, display_name, tooltip, is_output_list) self.template = template @@ -878,24 +869,30 @@ class DynamicInput(Input, ABC): ''' Abstract class for dynamic input registration. ''' - def get_dynamic(self) -> list[Input]: - return [] - - def expand_schema_for_dynamic(self, d: dict[str, Any], live_inputs: dict[str, Any], curr_prefix=''): - pass + pass class DynamicOutput(Output, ABC): ''' Abstract class for dynamic output registration. 
''' - def __init__(self, id: str=None, display_name: str=None, tooltip: str=None, - is_output_list=False): - super().__init__(id, display_name, tooltip, is_output_list) + pass - def get_dynamic(self) -> list[Output]: - return [] +def handle_prefix(prefix_list: list[str] | None, id: str | None = None) -> list[str]: + if prefix_list is None: + prefix_list = [] + if id is not None: + prefix_list = prefix_list + [id] + return prefix_list + +def finalize_prefix(prefix_list: list[str] | None, id: str | None = None) -> str: + assert not (prefix_list is None and id is None) + if prefix_list is None: + return id + elif id is not None: + prefix_list = prefix_list + [id] + return ".".join(prefix_list) @comfytype(io_type="COMFY_AUTOGROW_V3") class Autogrow(ComfyTypeI): @@ -932,14 +929,6 @@ class Autogrow(ComfyTypeI): def validate(self): self.input.validate() - def expand_schema_for_dynamic(self, d: dict[str, Any], live_inputs: dict[str, Any], curr_prefix=''): - real_inputs = [] - for name, input in self.cached_inputs.items(): - if name in live_inputs: - real_inputs.append(input) - add_to_input_dict_v1(d, real_inputs, live_inputs, curr_prefix) - add_dynamic_id_mapping(d, real_inputs, curr_prefix) - class TemplatePrefix(_AutogrowTemplate): def __init__(self, input: Input, prefix: str, min: int=1, max: int=10): super().__init__(input) @@ -984,22 +973,45 @@ class Autogrow(ComfyTypeI): "template": self.template.as_dict(), }) - def get_dynamic(self) -> list[Input]: - return self.template.get_all() - def get_all(self) -> list[Input]: return [self] + self.template.get_all() def validate(self): self.template.validate() - def expand_schema_for_dynamic(self, d: dict[str, Any], live_inputs: dict[str, Any], curr_prefix=''): - curr_prefix = f"{curr_prefix}{self.id}." - # need to remove self from expected inputs dictionary; replaced by template inputs in frontend - for inner_dict in d.values(): - if self.id in inner_dict: - del inner_dict[self.id] - self.template.expand_schema_for_dynamic(d, live_inputs, curr_prefix) + @staticmethod + def _expand_schema_for_dynamic(out_dict: dict[str, Any], live_inputs: dict[str, Any], value: tuple[str, dict[str, Any]], input_type: str, curr_prefix: list[str] | None): + # NOTE: purposely do not include self in out_dict; instead use only the template inputs + # need to figure out names based on template type + is_names = ("names" in value[1]["template"]) + is_prefix = ("prefix" in value[1]["template"]) + input = value[1]["template"]["input"] + if is_names: + min = value[1]["template"]["min"] + names = value[1]["template"]["names"] + max = len(names) + elif is_prefix: + prefix = value[1]["template"]["prefix"] + min = value[1]["template"]["min"] + max = value[1]["template"]["max"] + names = [f"{prefix}{i}" for i in range(max)] + # need to create a new input based on the contents of input + template_input = None + for _, dict_input in input.items(): + # for now, get just the first value from dict_input + template_input = list(dict_input.values())[0] + new_dict = {} + for i, name in enumerate(names): + expected_id = finalize_prefix(curr_prefix, name) + if expected_id in live_inputs: + # required + if i < min: + type_dict = new_dict.setdefault("required", {}) + # optional + else: + type_dict = new_dict.setdefault("optional", {}) + type_dict[name] = template_input + parse_class_inputs(out_dict, live_inputs, new_dict, curr_prefix) @comfytype(io_type="COMFY_DYNAMICCOMBO_V3") class DynamicCombo(ComfyTypeI): @@ -1022,23 +1034,6 @@ class DynamicCombo(ComfyTypeI): super().__init__(id, 
display_name, optional, tooltip, lazy, extra_dict) self.options = options - def expand_schema_for_dynamic(self, d: dict[str, Any], live_inputs: dict[str, Any], curr_prefix=''): - # check if dynamic input's id is in live_inputs - if self.id in live_inputs: - curr_prefix = f"{curr_prefix}{self.id}." - key = live_inputs[self.id] - selected_option = None - for option in self.options: - if option.key == key: - selected_option = option - break - if selected_option is not None: - add_to_input_dict_v1(d, selected_option.inputs, live_inputs, curr_prefix) - add_dynamic_id_mapping(d, selected_option.inputs, curr_prefix, self) - - def get_dynamic(self) -> list[Input]: - return [input for option in self.options for input in option.inputs] - def get_all(self) -> list[Input]: return [self] + [input for option in self.options for input in option.inputs] @@ -1053,6 +1048,24 @@ class DynamicCombo(ComfyTypeI): for input in option.inputs: input.validate() + @staticmethod + def _expand_schema_for_dynamic(out_dict: dict[str, Any], live_inputs: dict[str, Any], value: tuple[str, dict[str, Any]], input_type: str, curr_prefix: list[str] | None): + finalized_id = finalize_prefix(curr_prefix) + if finalized_id in live_inputs: + key = live_inputs[finalized_id] + selected_option = None + # get options from dict + options: list[dict[str, str | dict[str, Any]]] = value[1]["options"] + for option in options: + if option["key"] == key: + selected_option = option + break + if selected_option is not None: + parse_class_inputs(out_dict, live_inputs, selected_option["inputs"], curr_prefix) + # add self to inputs + out_dict[input_type][finalized_id] = value + out_dict["dynamic_paths"][finalized_id] = finalize_prefix(curr_prefix, curr_prefix[-1]) + @comfytype(io_type="COMFY_DYNAMICSLOT_V3") class DynamicSlot(ComfyTypeI): Type = dict[str, Any] @@ -1075,17 +1088,8 @@ class DynamicSlot(ComfyTypeI): self.force_input = True self.slot.force_input = True - def expand_schema_for_dynamic(self, d: dict[str, Any], live_inputs: dict[str, Any], curr_prefix=''): - if self.id in live_inputs: - curr_prefix = f"{curr_prefix}{self.id}." 
- add_to_input_dict_v1(d, self.inputs, live_inputs, curr_prefix) - add_dynamic_id_mapping(d, [self.slot] + self.inputs, curr_prefix) - - def get_dynamic(self) -> list[Input]: - return [self.slot] + self.inputs - def get_all(self) -> list[Input]: - return [self] + [self.slot] + self.inputs + return [self.slot] + self.inputs def as_dict(self): return super().as_dict() | prune_dict({ @@ -1099,17 +1103,41 @@ class DynamicSlot(ComfyTypeI): for input in self.inputs: input.validate() -def add_dynamic_id_mapping(d: dict[str, Any], inputs: list[Input], curr_prefix: str, self: DynamicInput=None): - dynamic = d.setdefault("dynamic_paths", {}) - if self is not None: - dynamic[self.id] = f"{curr_prefix}{self.id}" - for i in inputs: - if not isinstance(i, DynamicInput): - dynamic[f"{i.id}"] = f"{curr_prefix}{i.id}" + @staticmethod + def _expand_schema_for_dynamic(out_dict: dict[str, Any], live_inputs: dict[str, Any], value: tuple[str, dict[str, Any]], input_type: str, curr_prefix: list[str] | None): + finalized_id = finalize_prefix(curr_prefix) + if finalized_id in live_inputs: + inputs = value[1]["inputs"] + parse_class_inputs(out_dict, live_inputs, inputs, curr_prefix) + # add self to inputs + out_dict[input_type][finalized_id] = value + out_dict["dynamic_paths"][finalized_id] = finalize_prefix(curr_prefix, curr_prefix[-1]) + +DYNAMIC_INPUT_LOOKUP: dict[str, Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]] = {} +def register_dynamic_input_func(io_type: str, func: Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]): + DYNAMIC_INPUT_LOOKUP[io_type] = func + +def get_dynamic_input_func(io_type: str) -> Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]: + return DYNAMIC_INPUT_LOOKUP[io_type] + +def setup_dynamic_input_funcs(): + # Autogrow.Input + register_dynamic_input_func(Autogrow.io_type, Autogrow._expand_schema_for_dynamic) + # DynamicCombo.Input + register_dynamic_input_func(DynamicCombo.io_type, DynamicCombo._expand_schema_for_dynamic) + # DynamicSlot.Input + register_dynamic_input_func(DynamicSlot.io_type, DynamicSlot._expand_schema_for_dynamic) + +if len(DYNAMIC_INPUT_LOOKUP) == 0: + setup_dynamic_input_funcs() class V3Data(TypedDict): hidden_inputs: dict[str, Any] + 'Dictionary where the keys are the hidden input ids and the values are the values of the hidden inputs.' dynamic_paths: dict[str, Any] + 'Dictionary where the keys are the input ids and the values dictate how to turn the inputs into a nested dictionary.' + create_dynamic_tuple: bool + 'When True, the value of the dynamic input will be in the format (value, path_key).' class HiddenHolder: def __init__(self, unique_id: str, prompt: Any, @@ -1145,6 +1173,10 @@ class HiddenHolder: api_key_comfy_org=d.get(Hidden.api_key_comfy_org, None), ) + @classmethod + def from_v3_data(cls, v3_data: V3Data | None) -> HiddenHolder: + return cls.from_dict(v3_data["hidden_inputs"] if v3_data else None) + class Hidden(str, Enum): ''' Enumerator for requesting hidden variables in nodes. 
@@ -1250,61 +1282,56 @@ class Schema: - verify ids on inputs and outputs are unique - both internally and in relation to each other ''' nested_inputs: list[Input] = [] - if self.inputs is not None: - for input in self.inputs: + for input in self.inputs: + if not isinstance(input, DynamicInput): nested_inputs.extend(input.get_all()) - input_ids = [i.id for i in nested_inputs] if nested_inputs is not None else [] - output_ids = [o.id for o in self.outputs] if self.outputs is not None else [] + input_ids = [i.id for i in nested_inputs] + output_ids = [o.id for o in self.outputs] input_set = set(input_ids) output_set = set(output_ids) - issues = [] + issues: list[str] = [] # verify ids are unique per list if len(input_set) != len(input_ids): issues.append(f"Input ids must be unique, but {[item for item, count in Counter(input_ids).items() if count > 1]} are not.") if len(output_set) != len(output_ids): issues.append(f"Output ids must be unique, but {[item for item, count in Counter(output_ids).items() if count > 1]} are not.") - # verify ids are unique between lists - intersection = input_set & output_set - if len(intersection) > 0: - issues.append(f"Ids must be unique between inputs and outputs, but {intersection} are not.") if len(issues) > 0: raise ValueError("\n".join(issues)) # validate inputs and outputs - if self.inputs is not None: - for input in self.inputs: - input.validate() - if self.outputs is not None: - for output in self.outputs: - output.validate() + for input in self.inputs: + input.validate() + for output in self.outputs: + output.validate() def finalize(self): """Add hidden based on selected schema options, and give outputs without ids default ids.""" + # ensure inputs, outputs, and hidden are lists + if self.inputs is None: + self.inputs = [] + if self.outputs is None: + self.outputs = [] + if self.hidden is None: + self.hidden = [] # if is an api_node, will need key-related hidden if self.is_api_node: - if self.hidden is None: - self.hidden = [] if Hidden.auth_token_comfy_org not in self.hidden: self.hidden.append(Hidden.auth_token_comfy_org) if Hidden.api_key_comfy_org not in self.hidden: self.hidden.append(Hidden.api_key_comfy_org) # if is an output_node, will need prompt and extra_pnginfo if self.is_output_node: - if self.hidden is None: - self.hidden = [] if Hidden.prompt not in self.hidden: self.hidden.append(Hidden.prompt) if Hidden.extra_pnginfo not in self.hidden: self.hidden.append(Hidden.extra_pnginfo) # give outputs without ids default ids - if self.outputs is not None: - for i, output in enumerate(self.outputs): - if output.id is None: - output.id = f"_{i}_{output.io_type}_" + for i, output in enumerate(self.outputs): + if output.id is None: + output.id = f"_{i}_{output.io_type}_" - def get_v1_info(self, cls, live_inputs: dict[str, Any]=None) -> NodeInfoV1: - # NOTE: live_inputs will not be used anymore very soon and this will be done another way + def get_v1_info(self, cls) -> NodeInfoV1: # get V1 inputs - input = create_input_dict_v1(self.inputs, live_inputs) + input = create_input_dict_v1(self.inputs) if self.hidden: for hidden in self.hidden: input.setdefault("hidden", {})[hidden.name] = (hidden.value,) @@ -1384,33 +1411,54 @@ class Schema: ) return info +def get_finalized_class_inputs(d: dict[str, Any], live_inputs: dict[str, Any], include_hidden=False) -> tuple[dict[str, Any], V3Data]: + out_dict = { + "required": {}, + "optional": {}, + "dynamic_paths": {}, + } + d = d.copy() + # ignore hidden for parsing + hidden = d.pop("hidden", None) + 
parse_class_inputs(out_dict, live_inputs, d) + if hidden is not None and include_hidden: + out_dict["hidden"] = hidden + v3_data = {} + dynamic_paths = out_dict.pop("dynamic_paths", None) + if dynamic_paths is not None: + v3_data["dynamic_paths"] = dynamic_paths + return out_dict, hidden, v3_data -def create_input_dict_v1(inputs: list[Input], live_inputs: dict[str, Any]=None) -> dict: +def parse_class_inputs(out_dict: dict[str, Any], live_inputs: dict[str, Any], curr_dict: dict[str, Any], curr_prefix: list[str] | None=None) -> None: + for input_type, inner_d in curr_dict.items(): + for id, value in inner_d.items(): + io_type = value[0] + if io_type in DYNAMIC_INPUT_LOOKUP: + # dynamic inputs need to be handled with lookup functions + dynamic_input_func = get_dynamic_input_func(io_type) + new_prefix = handle_prefix(curr_prefix, id) + dynamic_input_func(out_dict, live_inputs, value, input_type, new_prefix) + else: + # non-dynamic inputs get directly transferred + finalized_id = finalize_prefix(curr_prefix, id) + out_dict[input_type][finalized_id] = value + if curr_prefix: + out_dict["dynamic_paths"][finalized_id] = finalized_id + +def create_input_dict_v1(inputs: list[Input]) -> dict: input = { "required": {} } - add_to_input_dict_v1(input, inputs, live_inputs) + for i in inputs: + add_to_dict_v1(i, input) return input -def add_to_input_dict_v1(d: dict[str, Any], inputs: list[Input], live_inputs: dict[str, Any]=None, curr_prefix=''): - for i in inputs: - if isinstance(i, DynamicInput): - add_to_dict_v1(i, d) - if live_inputs is not None: - i.expand_schema_for_dynamic(d, live_inputs, curr_prefix) - else: - add_to_dict_v1(i, d) - -def add_to_dict_v1(i: Input, d: dict, dynamic_dict: dict=None): +def add_to_dict_v1(i: Input, d: dict): key = "optional" if i.optional else "required" as_dict = i.as_dict() # for v1, we don't want to include the optional key as_dict.pop("optional", None) - if dynamic_dict is None: - value = (i.get_io_type(), as_dict) - else: - value = (i.get_io_type(), as_dict, dynamic_dict) - d.setdefault(key, {})[i.id] = value + d.setdefault(key, {})[i.id] = (i.get_io_type(), as_dict) def add_to_dict_v3(io: Input | Output, d: dict): d[io.id] = (io.get_io_type(), io.as_dict()) @@ -1422,6 +1470,8 @@ def build_nested_inputs(values: dict[str, Any], v3_data: V3Data): values = values.copy() result = {} + create_tuple = v3_data.get("create_dynamic_tuple", False) + for key, path in paths.items(): parts = path.split(".") current = result @@ -1430,7 +1480,10 @@ def build_nested_inputs(values: dict[str, Any], v3_data: V3Data): is_last = (i == len(parts) - 1) if is_last: - current[p] = values.pop(key, None) + value = values.pop(key, None) + if create_tuple: + value = (value, key) + current[p] = value else: current = current.setdefault(p, {}) @@ -1445,7 +1498,6 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal): SCHEMA = None # filled in during execution - resources: Resources = None hidden: HiddenHolder = None @classmethod @@ -1492,7 +1544,6 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal): return [name for name in kwargs if kwargs[name] is None] def __init__(self): - self.local_resources: ResourcesLocal = None self.__class__.VALIDATE_CLASS() @classmethod @@ -1560,7 +1611,7 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal): c_type: type[ComfyNode] = cls if is_class(cls) else type(cls) type_clone: type[ComfyNode] = shallow_clone_class(c_type) # set hidden - type_clone.hidden = HiddenHolder.from_dict(v3_data["hidden_inputs"] if v3_data else None) + type_clone.hidden = 
HiddenHolder.from_v3_data(v3_data) return type_clone @final @@ -1677,19 +1728,10 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal): @final @classmethod - def INPUT_TYPES(cls, include_hidden=True, return_schema=False, live_inputs=None) -> dict[str, dict] | tuple[dict[str, dict], Schema, V3Data]: + def INPUT_TYPES(cls) -> dict[str, dict]: schema = cls.FINALIZE_SCHEMA() - info = schema.get_v1_info(cls, live_inputs) - input = info.input - if not include_hidden: - input.pop("hidden", None) - if return_schema: - v3_data: V3Data = {} - dynamic = input.pop("dynamic_paths", None) - if dynamic is not None: - v3_data["dynamic_paths"] = dynamic - return input, schema, v3_data - return input + info = schema.get_v1_info(cls) + return info.input @final @classmethod @@ -1808,7 +1850,7 @@ class NodeOutput(_NodeOutputInternal): return self.args if len(self.args) > 0 else None @classmethod - def from_dict(cls, data: dict[str, Any]) -> "NodeOutput": + def from_dict(cls, data: dict[str, Any]) -> NodeOutput: args = () ui = None expand = None @@ -1903,8 +1945,8 @@ __all__ = [ "Tracks", # Dynamic Types "MatchType", - # "DynamicCombo", - # "Autogrow", + "DynamicCombo", + "Autogrow", # Other classes "HiddenHolder", "Hidden", diff --git a/comfy_api/latest/_resources.py b/comfy_api/latest/_resources.py deleted file mode 100644 index a6bdda972..000000000 --- a/comfy_api/latest/_resources.py +++ /dev/null @@ -1,72 +0,0 @@ -from __future__ import annotations -import comfy.utils -import folder_paths -import logging -from abc import ABC, abstractmethod -from typing import Any -import torch - -class ResourceKey(ABC): - Type = Any - def __init__(self): - ... - -class TorchDictFolderFilename(ResourceKey): - '''Key for requesting a torch file via file_name from a folder category.''' - Type = dict[str, torch.Tensor] - def __init__(self, folder_name: str, file_name: str): - self.folder_name = folder_name - self.file_name = file_name - - def __hash__(self): - return hash((self.folder_name, self.file_name)) - - def __eq__(self, other: object) -> bool: - if not isinstance(other, TorchDictFolderFilename): - return False - return self.folder_name == other.folder_name and self.file_name == other.file_name - - def __str__(self): - return f"{self.folder_name} -> {self.file_name}" - -class Resources(ABC): - def __init__(self): - ... - - @abstractmethod - def get(self, key: ResourceKey, default: Any=...) -> Any: - pass - -class ResourcesLocal(Resources): - def __init__(self): - super().__init__() - self.local_resources: dict[ResourceKey, Any] = {} - - def get(self, key: ResourceKey, default: Any=...) 
-> Any: - cached = self.local_resources.get(key, None) - if cached is not None: - logging.info(f"Using cached resource '{key}'") - return cached - logging.info(f"Loading resource '{key}'") - to_return = None - if isinstance(key, TorchDictFolderFilename): - if default is ...: - to_return = comfy.utils.load_torch_file(folder_paths.get_full_path_or_raise(key.folder_name, key.file_name), safe_load=True) - else: - full_path = folder_paths.get_full_path(key.folder_name, key.file_name) - if full_path is not None: - to_return = comfy.utils.load_torch_file(full_path, safe_load=True) - - if to_return is not None: - self.local_resources[key] = to_return - return to_return - if default is not ...: - return default - raise Exception(f"Unsupported resource key type: {type(key)}") - - -class _RESOURCES: - ResourceKey = ResourceKey - TorchDictFolderFilename = TorchDictFolderFilename - Resources = Resources - ResourcesLocal = ResourcesLocal diff --git a/comfy_execution/graph.py b/comfy_execution/graph.py index 0d811e354..8fc5846b7 100644 --- a/comfy_execution/graph.py +++ b/comfy_execution/graph.py @@ -97,6 +97,11 @@ def get_input_info( extra_info = input_info[1] else: extra_info = {} + # if input_type is a list, it is a Combo defined in outdated format; convert it. + # NOTE: uncomment this when we are confident old format going away won't cause too much trouble. + # if isinstance(input_type, list): + # extra_info["options"] = input_type + # input_type = IO.Combo.io_type return input_type, input_category, extra_info class TopologicalSort: diff --git a/comfy_execution/validation.py b/comfy_execution/validation.py index 24c0b4ed7..e73624bd1 100644 --- a/comfy_execution/validation.py +++ b/comfy_execution/validation.py @@ -21,14 +21,24 @@ def validate_node_input( """ # If the types are exactly the same, we can return immediately # Use pre-union behaviour: inverse of `__ne__` + # NOTE: this lets legacy '*' Any types work that override the __ne__ method of the str class. if not received_type != input_type: return True + # If one of the types is '*', we can return True immediately; this is the 'Any' type. + if received_type == IO.AnyType.io_type or input_type == IO.AnyType.io_type: + return True + # If the received type or input_type is a MatchType, we can return True immediately; # validation for this is handled by the frontend if received_type == IO.MatchType.io_type or input_type == IO.MatchType.io_type: return True + # This accounts for some custom nodes that output lists of options as the type; + # if we ever want to break them on purpose, this can be removed + if isinstance(received_type, list) and input_type == IO.Combo.io_type: + return True + # Not equal, and not strings if not isinstance(received_type, str) or not isinstance(input_type, str): return False @@ -37,6 +47,10 @@ def validate_node_input( received_types = set(t.strip() for t in received_type.split(",")) input_types = set(t.strip() for t in input_type.split(",")) + # If any of the types is '*', we can return True immediately; this is the 'Any' type. 
+ if IO.AnyType.io_type in received_types or IO.AnyType.io_type in input_types: + return True + if strict: # In strict mode, all received types must be in the input types return received_types.issubset(input_types) diff --git a/comfy_extras/nodes_latent.py b/comfy_extras/nodes_latent.py index 2815c5ffc..9ba1c4ba8 100644 --- a/comfy_extras/nodes_latent.py +++ b/comfy_extras/nodes_latent.py @@ -255,6 +255,7 @@ class LatentBatch(io.ComfyNode): return io.Schema( node_id="LatentBatch", category="latent/batch", + is_deprecated=True, inputs=[ io.Latent.Input("samples1"), io.Latent.Input("samples2"), diff --git a/comfy_extras/nodes_logic.py b/comfy_extras/nodes_logic.py index 95a6ba788..eb888316a 100644 --- a/comfy_extras/nodes_logic.py +++ b/comfy_extras/nodes_logic.py @@ -1,8 +1,11 @@ +from __future__ import annotations from typing import TypedDict from typing_extensions import override from comfy_api.latest import ComfyExtension, io from comfy_api.latest import _io +# sentinel for missing inputs +MISSING = object() class SwitchNode(io.ComfyNode): @@ -14,6 +17,37 @@ class SwitchNode(io.ComfyNode): display_name="Switch", category="logic", is_experimental=True, + inputs=[ + io.Boolean.Input("switch"), + io.MatchType.Input("on_false", template=template, lazy=True), + io.MatchType.Input("on_true", template=template, lazy=True), + ], + outputs=[ + io.MatchType.Output(template=template, display_name="output"), + ], + ) + + @classmethod + def check_lazy_status(cls, switch, on_false=None, on_true=None): + if switch and on_true is None: + return ["on_true"] + if not switch and on_false is None: + return ["on_false"] + + @classmethod + def execute(cls, switch, on_true, on_false) -> io.NodeOutput: + return io.NodeOutput(on_true if switch else on_false) + + +class SoftSwitchNode(io.ComfyNode): + @classmethod + def define_schema(cls): + template = io.MatchType.Template("switch") + return io.Schema( + node_id="ComfySoftSwitchNode", + display_name="Soft Switch", + category="logic", + is_experimental=True, inputs=[ io.Boolean.Input("switch"), io.MatchType.Input("on_false", template=template, lazy=True, optional=True), @@ -25,14 +59,14 @@ class SwitchNode(io.ComfyNode): ) @classmethod - def check_lazy_status(cls, switch, on_false=..., on_true=...): - # We use ... instead of None, as None is passed for connected-but-unevaluated inputs. + def check_lazy_status(cls, switch, on_false=MISSING, on_true=MISSING): + # We use MISSING instead of None, as None is passed for connected-but-unevaluated inputs. # This trick allows us to ignore the value of the switch and still be able to run execute(). # One of the inputs may be missing, in which case we need to evaluate the other input - if on_false is ...: + if on_false is MISSING: return ["on_true"] - if on_true is ...: + if on_true is MISSING: return ["on_false"] # Normal lazy switch operation if switch and on_true is None: @@ -41,22 +75,50 @@ class SwitchNode(io.ComfyNode): return ["on_false"] @classmethod - def validate_inputs(cls, switch, on_false=..., on_true=...): + def validate_inputs(cls, switch, on_false=MISSING, on_true=MISSING): # This check happens before check_lazy_status(), so we can eliminate the case where # both inputs are missing. - if on_false is ... and on_true is ...: + if on_false is MISSING and on_true is MISSING: return "At least one of on_false or on_true must be connected to Switch node" return True @classmethod - def execute(cls, switch, on_true=..., on_false=...) 
-> io.NodeOutput: - if on_true is ...: + def execute(cls, switch, on_true=MISSING, on_false=MISSING) -> io.NodeOutput: + if on_true is MISSING: return io.NodeOutput(on_false) - if on_false is ...: + if on_false is MISSING: return io.NodeOutput(on_true) return io.NodeOutput(on_true if switch else on_false) +class CustomComboNode(io.ComfyNode): + """ + Frontend node that allows user to write their own options for a combo. + This is here to make sure the node has a backend-representation to avoid some annoyances. + """ + @classmethod + def define_schema(cls): + return io.Schema( + node_id="CustomCombo", + display_name="Custom Combo", + category="utils", + is_experimental=True, + inputs=[io.Combo.Input("choice", options=[])], + outputs=[io.String.Output()] + ) + + @classmethod + def validate_inputs(cls, choice: io.Combo.Type) -> bool: + # NOTE: DO NOT DO THIS unless you want to skip validation entirely on the node's inputs. + # I am doing that here because the widgets (besides the combo dropdown) on this node are fully frontend defined. + # I need to skip checking that the chosen combo option is in the options list, since those are defined by the user. + return True + + @classmethod + def execute(cls, choice: io.Combo.Type) -> io.NodeOutput: + return io.NodeOutput(choice) + + class DCTestNode(io.ComfyNode): class DCValues(TypedDict): combo: str @@ -72,14 +134,14 @@ class DCTestNode(io.ComfyNode): display_name="DCTest", category="logic", is_output_node=True, - inputs=[_io.DynamicCombo.Input("combo", options=[ - _io.DynamicCombo.Option("option1", [io.String.Input("string")]), - _io.DynamicCombo.Option("option2", [io.Int.Input("integer")]), - _io.DynamicCombo.Option("option3", [io.Image.Input("image")]), - _io.DynamicCombo.Option("option4", [ - _io.DynamicCombo.Input("subcombo", options=[ - _io.DynamicCombo.Option("opt1", [io.Float.Input("float_x"), io.Float.Input("float_y")]), - _io.DynamicCombo.Option("opt2", [io.Mask.Input("mask1", optional=True)]), + inputs=[io.DynamicCombo.Input("combo", options=[ + io.DynamicCombo.Option("option1", [io.String.Input("string")]), + io.DynamicCombo.Option("option2", [io.Int.Input("integer")]), + io.DynamicCombo.Option("option3", [io.Image.Input("image")]), + io.DynamicCombo.Option("option4", [ + io.DynamicCombo.Input("subcombo", options=[ + io.DynamicCombo.Option("opt1", [io.Float.Input("float_x"), io.Float.Input("float_y")]), + io.DynamicCombo.Option("opt2", [io.Mask.Input("mask1", optional=True)]), ]) ])] )], @@ -141,14 +203,65 @@ class AutogrowPrefixTestNode(io.ComfyNode): combined = ",".join([str(x) for x in vals]) return io.NodeOutput(combined) +class ComboOutputTestNode(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="ComboOptionTestNode", + display_name="ComboOptionTest", + category="logic", + inputs=[io.Combo.Input("combo", options=["option1", "option2", "option3"]), + io.Combo.Input("combo2", options=["option4", "option5", "option6"])], + outputs=[io.Combo.Output(), io.Combo.Output()], + ) + + @classmethod + def execute(cls, combo: io.Combo.Type, combo2: io.Combo.Type) -> io.NodeOutput: + return io.NodeOutput(combo, combo2) + +class ConvertStringToComboNode(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="ConvertStringToComboNode", + display_name="Convert String to Combo", + category="logic", + inputs=[io.String.Input("string")], + outputs=[io.Combo.Output()], + ) + + @classmethod + def execute(cls, string: str) -> io.NodeOutput: + return io.NodeOutput(string) + +class 
InvertBooleanNode(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="InvertBooleanNode", + display_name="Invert Boolean", + category="logic", + inputs=[io.Boolean.Input("boolean")], + outputs=[io.Boolean.Output()], + ) + + @classmethod + def execute(cls, boolean: bool) -> io.NodeOutput: + return io.NodeOutput(not boolean) + class LogicExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[io.ComfyNode]]: return [ - # SwitchNode, + SwitchNode, + CustomComboNode, + # SoftSwitchNode, + # ConvertStringToComboNode, # DCTestNode, # AutogrowNamesTestNode, # AutogrowPrefixTestNode, + # ComboOutputTestNode, + # InvertBooleanNode, ] async def comfy_entrypoint() -> LogicExtension: diff --git a/comfy_extras/nodes_post_processing.py b/comfy_extras/nodes_post_processing.py index ca2cdeb50..01afa13a1 100644 --- a/comfy_extras/nodes_post_processing.py +++ b/comfy_extras/nodes_post_processing.py @@ -4,11 +4,15 @@ import torch import torch.nn.functional as F from PIL import Image import math +from enum import Enum +from typing import TypedDict, Literal import comfy.utils import comfy.model_management +from comfy_extras.nodes_latent import reshape_latent_to import node_helpers from comfy_api.latest import ComfyExtension, io +from nodes import MAX_RESOLUTION class Blend(io.ComfyNode): @classmethod @@ -241,6 +245,353 @@ class ImageScaleToTotalPixels(io.ComfyNode): s = s.movedim(1,-1) return io.NodeOutput(s) +class ResizeType(str, Enum): + SCALE_BY = "scale by multiplier" + SCALE_DIMENSIONS = "scale dimensions" + SCALE_LONGER_DIMENSION = "scale longer dimension" + SCALE_SHORTER_DIMENSION = "scale shorter dimension" + SCALE_WIDTH = "scale width" + SCALE_HEIGHT = "scale height" + SCALE_TOTAL_PIXELS = "scale total pixels" + MATCH_SIZE = "match size" + +def is_image(input: torch.Tensor) -> bool: + # images have 4 dimensions: [batch, height, width, channels] + # masks have 3 dimensions: [batch, height, width] + return len(input.shape) == 4 + +def init_image_mask_input(input: torch.Tensor, is_type_image: bool) -> torch.Tensor: + if is_type_image: + input = input.movedim(-1, 1) + else: + input = input.unsqueeze(1) + return input + +def finalize_image_mask_input(input: torch.Tensor, is_type_image: bool) -> torch.Tensor: + if is_type_image: + input = input.movedim(1, -1) + else: + input = input.squeeze(1) + return input + +def scale_by(input: torch.Tensor, multiplier: float, scale_method: str) -> torch.Tensor: + is_type_image = is_image(input) + input = init_image_mask_input(input, is_type_image) + width = round(input.shape[-1] * multiplier) + height = round(input.shape[-2] * multiplier) + + input = comfy.utils.common_upscale(input, width, height, scale_method, "disabled") + input = finalize_image_mask_input(input, is_type_image) + return input + +def scale_dimensions(input: torch.Tensor, width: int, height: int, scale_method: str, crop: str="disabled") -> torch.Tensor: + if width == 0 and height == 0: + return input + is_type_image = is_image(input) + input = init_image_mask_input(input, is_type_image) + + if width == 0: + width = max(1, round(input.shape[-1] * height / input.shape[-2])) + elif height == 0: + height = max(1, round(input.shape[-2] * width / input.shape[-1])) + + input = comfy.utils.common_upscale(input, width, height, scale_method, crop) + input = finalize_image_mask_input(input, is_type_image) + return input + +def scale_longer_dimension(input: torch.Tensor, longer_size: int, scale_method: str) -> torch.Tensor: + is_type_image = 
is_image(input) + input = init_image_mask_input(input, is_type_image) + width = input.shape[-1] + height = input.shape[-2] + + if height > width: + width = round((width / height) * longer_size) + height = longer_size + elif width > height: + height = round((height / width) * longer_size) + width = longer_size + else: + height = longer_size + width = longer_size + + input = comfy.utils.common_upscale(input, width, height, scale_method, "disabled") + input = finalize_image_mask_input(input, is_type_image) + return input + +def scale_shorter_dimension(input: torch.Tensor, shorter_size: int, scale_method: str) -> torch.Tensor: + is_type_image = is_image(input) + input = init_image_mask_input(input, is_type_image) + width = input.shape[-1] + height = input.shape[-2] + + if height < width: + width = round((width / height) * shorter_size) + height = shorter_size + elif width > height: + height = round((height / width) * shorter_size) + width = shorter_size + else: + height = shorter_size + width = shorter_size + + input = comfy.utils.common_upscale(input, width, height, scale_method, "disabled") + input = finalize_image_mask_input(input, is_type_image) + return input + +def scale_total_pixels(input: torch.Tensor, megapixels: float, scale_method: str) -> torch.Tensor: + is_type_image = is_image(input) + input = init_image_mask_input(input, is_type_image) + total = int(megapixels * 1024 * 1024) + + scale_by = math.sqrt(total / (input.shape[-1] * input.shape[-2])) + width = round(input.shape[-1] * scale_by) + height = round(input.shape[-2] * scale_by) + + input = comfy.utils.common_upscale(input, width, height, scale_method, "disabled") + input = finalize_image_mask_input(input, is_type_image) + return input + +def scale_match_size(input: torch.Tensor, match: torch.Tensor, scale_method: str, crop: str) -> torch.Tensor: + is_type_image = is_image(input) + input = init_image_mask_input(input, is_type_image) + match = init_image_mask_input(match, is_image(match)) + + width = match.shape[-1] + height = match.shape[-2] + input = comfy.utils.common_upscale(input, width, height, scale_method, crop) + input = finalize_image_mask_input(input, is_type_image) + return input + +class ResizeImageMaskNode(io.ComfyNode): + + scale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "lanczos"] + crop_methods = ["disabled", "center"] + + class ResizeTypedDict(TypedDict): + resize_type: ResizeType + scale_method: Literal["nearest-exact", "bilinear", "area", "bicubic", "lanczos"] + crop: Literal["disabled", "center"] + multiplier: float + width: int + height: int + longer_size: int + shorter_size: int + megapixels: float + + @classmethod + def define_schema(cls): + template = io.MatchType.Template("input_type", [io.Image, io.Mask]) + crop_combo = io.Combo.Input("crop", options=cls.crop_methods, default="center") + return io.Schema( + node_id="ResizeImageMaskNode", + display_name="Resize Image/Mask", + category="transform", + inputs=[ + io.MatchType.Input("input", template=template), + io.DynamicCombo.Input("resize_type", options=[ + io.DynamicCombo.Option(ResizeType.SCALE_BY, [ + io.Float.Input("multiplier", default=1.00, min=0.01, max=8.0, step=0.01), + ]), + io.DynamicCombo.Option(ResizeType.SCALE_DIMENSIONS, [ + io.Int.Input("width", default=512, min=0, max=MAX_RESOLUTION, step=1), + io.Int.Input("height", default=512, min=0, max=MAX_RESOLUTION, step=1), + crop_combo, + ]), + io.DynamicCombo.Option(ResizeType.SCALE_LONGER_DIMENSION, [ + io.Int.Input("longer_size", default=512, min=0, max=MAX_RESOLUTION, 
step=1), + ]), + io.DynamicCombo.Option(ResizeType.SCALE_SHORTER_DIMENSION, [ + io.Int.Input("shorter_size", default=512, min=0, max=MAX_RESOLUTION, step=1), + ]), + io.DynamicCombo.Option(ResizeType.SCALE_WIDTH, [ + io.Int.Input("width", default=512, min=0, max=MAX_RESOLUTION, step=1), + ]), + io.DynamicCombo.Option(ResizeType.SCALE_HEIGHT, [ + io.Int.Input("height", default=512, min=0, max=MAX_RESOLUTION, step=1), + ]), + io.DynamicCombo.Option(ResizeType.SCALE_TOTAL_PIXELS, [ + io.Float.Input("megapixels", default=1.0, min=0.01, max=16.0, step=0.01), + ]), + io.DynamicCombo.Option(ResizeType.MATCH_SIZE, [ + io.MultiType.Input("match", [io.Image, io.Mask]), + crop_combo, + ]), + ]), + io.Combo.Input("scale_method", options=cls.scale_methods, default="area"), + ], + outputs=[io.MatchType.Output(template=template, display_name="resized")] + ) + + @classmethod + def execute(cls, input: io.Image.Type | io.Mask.Type, scale_method: io.Combo.Type, resize_type: ResizeTypedDict) -> io.NodeOutput: + selected_type = resize_type["resize_type"] + if selected_type == ResizeType.SCALE_BY: + return io.NodeOutput(scale_by(input, resize_type["multiplier"], scale_method)) + elif selected_type == ResizeType.SCALE_DIMENSIONS: + return io.NodeOutput(scale_dimensions(input, resize_type["width"], resize_type["height"], scale_method, resize_type["crop"])) + elif selected_type == ResizeType.SCALE_LONGER_DIMENSION: + return io.NodeOutput(scale_longer_dimension(input, resize_type["longer_size"], scale_method)) + elif selected_type == ResizeType.SCALE_SHORTER_DIMENSION: + return io.NodeOutput(scale_shorter_dimension(input, resize_type["shorter_size"], scale_method)) + elif selected_type == ResizeType.SCALE_WIDTH: + return io.NodeOutput(scale_dimensions(input, resize_type["width"], 0, scale_method)) + elif selected_type == ResizeType.SCALE_HEIGHT: + return io.NodeOutput(scale_dimensions(input, 0, resize_type["height"], scale_method)) + elif selected_type == ResizeType.SCALE_TOTAL_PIXELS: + return io.NodeOutput(scale_total_pixels(input, resize_type["megapixels"], scale_method)) + elif selected_type == ResizeType.MATCH_SIZE: + return io.NodeOutput(scale_match_size(input, resize_type["match"], scale_method, resize_type["crop"])) + raise ValueError(f"Unsupported resize type: {selected_type}") + +def batch_images(images: list[torch.Tensor]) -> torch.Tensor | None: + if len(images) == 0: + return None + # first, get the max channels count + max_channels = max(image.shape[-1] for image in images) + # then, pad all images to have the same channels count + padded_images: list[torch.Tensor] = [] + for image in images: + if image.shape[-1] < max_channels: + padded_images.append(torch.nn.functional.pad(image, (0,1), mode='constant', value=1.0)) + else: + padded_images.append(image) + # resize all images to be the same size as the first image + resized_images: list[torch.Tensor] = [] + first_image_shape = padded_images[0].shape + for image in padded_images: + if image.shape[1:] != first_image_shape[1:]: + resized_images.append(comfy.utils.common_upscale(image.movedim(-1,1), first_image_shape[2], first_image_shape[1], "bilinear", "center").movedim(1,-1)) + else: + resized_images.append(image) + # batch the images in the format [b, h, w, c] + return torch.cat(resized_images, dim=0) + +def batch_masks(masks: list[torch.Tensor]) -> torch.Tensor | None: + if len(masks) == 0: + return None + # resize all masks to be the same size as the first mask + resized_masks: list[torch.Tensor] = [] + first_mask_shape = masks[0].shape + for mask 
in masks: + if mask.shape[1:] != first_mask_shape[1:]: + mask = init_image_mask_input(mask, is_type_image=False) + mask = comfy.utils.common_upscale(mask, first_mask_shape[2], first_mask_shape[1], "bilinear", "center") + resized_masks.append(finalize_image_mask_input(mask, is_type_image=False)) + else: + resized_masks.append(mask) + # batch the masks in the format [b, h, w] + return torch.cat(resized_masks, dim=0) + +def batch_latents(latents: list[dict[str, torch.Tensor]]) -> dict[str, torch.Tensor] | None: + if len(latents) == 0: + return None + samples_out = latents[0].copy() + samples_out["batch_index"] = [] + first_samples = latents[0]["samples"] + tensors: list[torch.Tensor] = [] + for latent in latents: + # first, deal with latent tensors + tensors.append(reshape_latent_to(first_samples.shape, latent["samples"], repeat_batch=False)) + # next, deal with batch_index + samples_out["batch_index"].extend(latent.get("batch_index", [x for x in range(0, latent["samples"].shape[0])])) + samples_out["samples"] = torch.cat(tensors, dim=0) + return samples_out + +class BatchImagesNode(io.ComfyNode): + @classmethod + def define_schema(cls): + autogrow_template = io.Autogrow.TemplatePrefix(io.Image.Input("image"), prefix="image", min=2, max=50) + return io.Schema( + node_id="BatchImagesNode", + display_name="Batch Images", + category="image", + inputs=[ + io.Autogrow.Input("images", template=autogrow_template) + ], + outputs=[ + io.Image.Output() + ] + ) + + @classmethod + def execute(cls, images: io.Autogrow.Type) -> io.NodeOutput: + return io.NodeOutput(batch_images(list(images.values()))) + +class BatchMasksNode(io.ComfyNode): + @classmethod + def define_schema(cls): + autogrow_template = io.Autogrow.TemplatePrefix(io.Mask.Input("mask"), prefix="mask", min=2, max=50) + return io.Schema( + node_id="BatchMasksNode", + display_name="Batch Masks", + category="mask", + inputs=[ + io.Autogrow.Input("masks", template=autogrow_template) + ], + outputs=[ + io.Mask.Output() + ] + ) + + @classmethod + def execute(cls, masks: io.Autogrow.Type) -> io.NodeOutput: + return io.NodeOutput(batch_masks(list(masks.values()))) + +class BatchLatentsNode(io.ComfyNode): + @classmethod + def define_schema(cls): + autogrow_template = io.Autogrow.TemplatePrefix(io.Latent.Input("latent"), prefix="latent", min=2, max=50) + return io.Schema( + node_id="BatchLatentsNode", + display_name="Batch Latents", + category="latent", + inputs=[ + io.Autogrow.Input("latents", template=autogrow_template) + ], + outputs=[ + io.Latent.Output() + ] + ) + + @classmethod + def execute(cls, latents: io.Autogrow.Type) -> io.NodeOutput: + return io.NodeOutput(batch_latents(list(latents.values()))) + +class BatchImagesMasksLatentsNode(io.ComfyNode): + @classmethod + def define_schema(cls): + matchtype_template = io.MatchType.Template("input", allowed_types=[io.Image, io.Mask, io.Latent]) + autogrow_template = io.Autogrow.TemplatePrefix( + io.MatchType.Input("input", matchtype_template), + prefix="input", min=1, max=50) + return io.Schema( + node_id="BatchImagesMasksLatentsNode", + display_name="Batch Images/Masks/Latents", + category="util", + inputs=[ + io.Autogrow.Input("inputs", template=autogrow_template) + ], + outputs=[ + io.MatchType.Output(id=None, template=matchtype_template) + ] + ) + + @classmethod + def execute(cls, inputs: io.Autogrow.Type) -> io.NodeOutput: + batched = None + values = list(inputs.values()) + # latents + if isinstance(values[0], dict): + batched = batch_latents(values) + # images + elif is_image(values[0]): + 
batched = batch_images(values) + # masks + else: + batched = batch_masks(values) + return io.NodeOutput(batched) + class PostProcessingExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[io.ComfyNode]]: @@ -250,6 +601,11 @@ class PostProcessingExtension(ComfyExtension): Quantize, Sharpen, ImageScaleToTotalPixels, + ResizeImageMaskNode, + BatchImagesNode, + BatchMasksNode, + BatchLatentsNode, + # BatchImagesMasksLatentsNode, ] async def comfy_entrypoint() -> PostProcessingExtension: diff --git a/comfy_extras/nodes_primitive.py b/comfy_extras/nodes_primitive.py index 5a1aeba80..937321800 100644 --- a/comfy_extras/nodes_primitive.py +++ b/comfy_extras/nodes_primitive.py @@ -66,7 +66,7 @@ class Float(io.ComfyNode): display_name="Float", category="utils/primitive", inputs=[ - io.Float.Input("value", min=-sys.maxsize, max=sys.maxsize), + io.Float.Input("value", min=-sys.maxsize, max=sys.maxsize, step=0.1), ], outputs=[io.Float.Output()], ) diff --git a/execution.py b/execution.py index 0c239efd7..38159b1f4 100644 --- a/execution.py +++ b/execution.py @@ -79,7 +79,7 @@ class IsChangedCache: # Intentionally do not use cached outputs here. We only want constants in IS_CHANGED input_data_all, _, v3_data = get_input_data(node["inputs"], class_def, node_id, None) try: - is_changed = await _async_map_node_over_list(self.prompt_id, node_id, class_def, input_data_all, is_changed_name) + is_changed = await _async_map_node_over_list(self.prompt_id, node_id, class_def, input_data_all, is_changed_name, v3_data=v3_data) is_changed = await resolve_map_node_over_list_results(is_changed) node["is_changed"] = [None if isinstance(x, ExecutionBlocker) else x for x in is_changed] except Exception as e: @@ -148,13 +148,12 @@ SENSITIVE_EXTRA_DATA_KEYS = ("auth_token_comfy_org", "api_key_comfy_org") def get_input_data(inputs, class_def, unique_id, execution_list=None, dynprompt=None, extra_data={}): is_v3 = issubclass(class_def, _ComfyNodeInternal) v3_data: io.V3Data = {} + hidden_inputs_v3 = {} + valid_inputs = class_def.INPUT_TYPES() if is_v3: - valid_inputs, schema, v3_data = class_def.INPUT_TYPES(include_hidden=False, return_schema=True, live_inputs=inputs) - else: - valid_inputs = class_def.INPUT_TYPES() + valid_inputs, hidden, v3_data = _io.get_finalized_class_inputs(valid_inputs, inputs) input_data_all = {} missing_keys = {} - hidden_inputs_v3 = {} for x in inputs: input_data = inputs[x] _, input_category, input_info = get_input_info(class_def, x, valid_inputs) @@ -180,18 +179,18 @@ def get_input_data(inputs, class_def, unique_id, execution_list=None, dynprompt= input_data_all[x] = [input_data] if is_v3: - if schema.hidden: - if io.Hidden.prompt in schema.hidden: + if hidden is not None: + if io.Hidden.prompt.name in hidden: hidden_inputs_v3[io.Hidden.prompt] = dynprompt.get_original_prompt() if dynprompt is not None else {} - if io.Hidden.dynprompt in schema.hidden: + if io.Hidden.dynprompt.name in hidden: hidden_inputs_v3[io.Hidden.dynprompt] = dynprompt - if io.Hidden.extra_pnginfo in schema.hidden: + if io.Hidden.extra_pnginfo.name in hidden: hidden_inputs_v3[io.Hidden.extra_pnginfo] = extra_data.get('extra_pnginfo', None) - if io.Hidden.unique_id in schema.hidden: + if io.Hidden.unique_id.name in hidden: hidden_inputs_v3[io.Hidden.unique_id] = unique_id - if io.Hidden.auth_token_comfy_org in schema.hidden: + if io.Hidden.auth_token_comfy_org.name in hidden: hidden_inputs_v3[io.Hidden.auth_token_comfy_org] = extra_data.get("auth_token_comfy_org", None) - if 
io.Hidden.api_key_comfy_org in schema.hidden: + if io.Hidden.api_key_comfy_org.name in hidden: hidden_inputs_v3[io.Hidden.api_key_comfy_org] = extra_data.get("api_key_comfy_org", None) else: if "hidden" in valid_inputs: @@ -258,7 +257,7 @@ async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, f pre_execute_cb(index) # V3 if isinstance(obj, _ComfyNodeInternal) or (is_class(obj) and issubclass(obj, _ComfyNodeInternal)): - # if is just a class, then assign no resources or state, just create clone + # if is just a class, then assign no state, just create clone if is_class(obj): type_obj = obj obj.VALIDATE_CLASS() @@ -481,7 +480,10 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, else: lazy_status_present = getattr(obj, "check_lazy_status", None) is not None if lazy_status_present: - required_inputs = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, "check_lazy_status", allow_interrupt=True, v3_data=v3_data) + # for check_lazy_status, the returned data should include the original key of the input + v3_data_lazy = v3_data.copy() + v3_data_lazy["create_dynamic_tuple"] = True + required_inputs = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, "check_lazy_status", allow_interrupt=True, v3_data=v3_data_lazy) required_inputs = await resolve_map_node_over_list_results(required_inputs) required_inputs = set(sum([r for r in required_inputs if isinstance(r,list)], [])) required_inputs = [x for x in required_inputs if isinstance(x,str) and ( @@ -756,10 +758,13 @@ async def validate_inputs(prompt_id, prompt, item, validated): errors = [] valid = True + v3_data = None validate_function_inputs = [] validate_has_kwargs = False if issubclass(obj_class, _ComfyNodeInternal): - class_inputs, _, _ = obj_class.INPUT_TYPES(include_hidden=False, return_schema=True, live_inputs=inputs) + obj_class: _io._ComfyNodeBaseInternal + class_inputs = obj_class.INPUT_TYPES() + class_inputs, _, v3_data = _io.get_finalized_class_inputs(class_inputs, inputs) validate_function_name = "validate_inputs" validate_function = first_real_override(obj_class, validate_function_name) else: @@ -779,10 +784,11 @@ async def validate_inputs(prompt_id, prompt, item, validated): assert extra_info is not None if x not in inputs: if input_category == "required": + details = f"{x}" if not v3_data else x.split(".")[-1] error = { "type": "required_input_missing", "message": "Required input is missing", - "details": f"{x}", + "details": details, "extra_info": { "input_name": x } @@ -916,8 +922,11 @@ async def validate_inputs(prompt_id, prompt, item, validated): errors.append(error) continue - if isinstance(input_type, list): - combo_options = input_type + if isinstance(input_type, list) or input_type == io.Combo.io_type: + if input_type == io.Combo.io_type: + combo_options = extra_info.get("options", []) + else: + combo_options = input_type if val not in combo_options: input_config = info list_info = "" diff --git a/nodes.py b/nodes.py index 7d83ecb21..d9e4ebd91 100644 --- a/nodes.py +++ b/nodes.py @@ -1863,6 +1863,7 @@ class ImageBatch: FUNCTION = "batch" CATEGORY = "image" + DEPRECATED = True def batch(self, image1, image2): if image1.shape[-1] != image2.shape[-1]: From 6ca3d5c011bc15737131eb665939ae0a39a74254 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Wed, 31 Dec 2025 06:12:38 +0200 Subject: [PATCH 114/148] fix(api-nodes-vidu): preserve percent-encoding for signed URLs (#11564) --- 
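Usage sketch for the new to_aiohttp_url helper (illustrative only, not part of the diff; the presigned URL below is a made-up example, while the helper and module path are the ones added in comfy_api_nodes/util/_helpers.py):

    from comfy_api_nodes.util._helpers import to_aiohttp_url

    # A presigned URL whose query already contains valid %HH escapes and no bare '%'.
    signed = "https://bucket.s3.amazonaws.com/object?X-Amz-Signature=ab%2Fcd"
    url = to_aiohttp_url(signed)
    # The string looks pre-encoded, so the helper returns yarl.URL(signed, encoded=True)
    # and aiohttp sends the query byte-for-byte instead of re-quoting the %2F escape.
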
comfy_api_nodes/util/_helpers.py | 20 ++++++++++++++++++++ comfy_api_nodes/util/download_helpers.py | 3 ++- 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/comfy_api_nodes/util/_helpers.py b/comfy_api_nodes/util/_helpers.py index 491e6b6a8..648defe3d 100644 --- a/comfy_api_nodes/util/_helpers.py +++ b/comfy_api_nodes/util/_helpers.py @@ -1,16 +1,22 @@ import asyncio import contextlib import os +import re import time from collections.abc import Callable from io import BytesIO +from yarl import URL + from comfy.cli_args import args from comfy.model_management import processing_interrupted from comfy_api.latest import IO from .common_exceptions import ProcessingInterrupted +_HAS_PCT_ESC = re.compile(r"%[0-9A-Fa-f]{2}") # any % followed by 2 hex digits +_HAS_BAD_PCT = re.compile(r"%(?![0-9A-Fa-f]{2})") # any % not followed by 2 hex digits + def is_processing_interrupted() -> bool: """Return True if user/runtime requested interruption.""" @@ -69,3 +75,17 @@ def get_fs_object_size(path_or_object: str | BytesIO) -> int: if isinstance(path_or_object, str): return os.path.getsize(path_or_object) return len(path_or_object.getvalue()) + + +def to_aiohttp_url(url: str) -> URL: + """If `url` appears to be already percent-encoded (contains at least one valid %HH + escape and no malformed '%' sequences) and contains no raw whitespace/control + characters preserve the original encoding byte-for-byte (important for signed/presigned URLs). + Otherwise, return `URL(url)` and allow yarl to normalize/quote as needed.""" + if any(c.isspace() for c in url) or any(ord(c) < 0x20 for c in url): + # Avoid encoded=True if URL contains raw whitespace/control chars + return URL(url) + if _HAS_PCT_ESC.search(url) and not _HAS_BAD_PCT.search(url): + # Preserve encoding only if it appears pre-encoded AND has no invalid % sequences + return URL(url, encoded=True) + return URL(url) diff --git a/comfy_api_nodes/util/download_helpers.py b/comfy_api_nodes/util/download_helpers.py index 3e0d0352d..4668d14a9 100644 --- a/comfy_api_nodes/util/download_helpers.py +++ b/comfy_api_nodes/util/download_helpers.py @@ -19,6 +19,7 @@ from ._helpers import ( get_auth_header, is_processing_interrupted, sleep_with_interrupt, + to_aiohttp_url, ) from .client import _diagnose_connectivity from .common_exceptions import ApiServerError, LocalNetworkError, ProcessingInterrupted @@ -94,7 +95,7 @@ async def download_url_to_bytesio( monitor_task = asyncio.create_task(_monitor()) - req_task = asyncio.create_task(session.get(url, headers=headers)) + req_task = asyncio.create_task(session.get(to_aiohttp_url(url), headers=headers)) done, pending = await asyncio.wait({req_task, monitor_task}, return_when=asyncio.FIRST_COMPLETED) if monitor_task in done and req_task in pending: From 236b9e211d5093b33acbe1918f56a6bfb4a5cf17 Mon Sep 17 00:00:00 2001 From: ComfyUI Wiki Date: Thu, 1 Jan 2026 05:38:39 +0800 Subject: [PATCH 115/148] chore: update workflow templates to v0.7.65 (#11579) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 8b670b813..3a05799eb 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.35.9 -comfyui-workflow-templates==0.7.64 +comfyui-workflow-templates==0.7.65 comfyui-embedded-docs==0.3.1 torch torchsde From d622a618749b603531b753cef286a6051dd85565 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 31 Dec 2025 14:38:36 -0800 Subject: [PATCH 
116/148] Refactor: move clip_preprocess to comfy.clip_model (#11586) --- comfy/clip_model.py | 19 +++++++++++++++++++ comfy/clip_vision.py | 22 ++-------------------- 2 files changed, 21 insertions(+), 20 deletions(-) diff --git a/comfy/clip_model.py b/comfy/clip_model.py index 7c0cadab5..e88872728 100644 --- a/comfy/clip_model.py +++ b/comfy/clip_model.py @@ -2,6 +2,25 @@ import torch from comfy.ldm.modules.attention import optimized_attention_for_device import comfy.ops +def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711], crop=True): + image = image[:, :, :, :3] if image.shape[3] > 3 else image + mean = torch.tensor(mean, device=image.device, dtype=image.dtype) + std = torch.tensor(std, device=image.device, dtype=image.dtype) + image = image.movedim(-1, 1) + if not (image.shape[2] == size and image.shape[3] == size): + if crop: + scale = (size / min(image.shape[2], image.shape[3])) + scale_size = (round(scale * image.shape[2]), round(scale * image.shape[3])) + else: + scale_size = (size, size) + + image = torch.nn.functional.interpolate(image, size=scale_size, mode="bicubic", antialias=True) + h = (image.shape[2] - size)//2 + w = (image.shape[3] - size)//2 + image = image[:,:,h:h+size,w:w+size] + image = torch.clip((255. * image), 0, 255).round() / 255.0 + return (image - mean.view([3,1,1])) / std.view([3,1,1]) + class CLIPAttention(torch.nn.Module): def __init__(self, embed_dim, heads, dtype, device, operations): super().__init__() diff --git a/comfy/clip_vision.py b/comfy/clip_vision.py index 447b1ce4a..d5fc53497 100644 --- a/comfy/clip_vision.py +++ b/comfy/clip_vision.py @@ -1,6 +1,5 @@ from .utils import load_torch_file, transformers_convert, state_dict_prefix_replace import os -import torch import json import logging @@ -17,24 +16,7 @@ class Output: def __setitem__(self, key, item): setattr(self, key, item) -def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711], crop=True): - image = image[:, :, :, :3] if image.shape[3] > 3 else image - mean = torch.tensor(mean, device=image.device, dtype=image.dtype) - std = torch.tensor(std, device=image.device, dtype=image.dtype) - image = image.movedim(-1, 1) - if not (image.shape[2] == size and image.shape[3] == size): - if crop: - scale = (size / min(image.shape[2], image.shape[3])) - scale_size = (round(scale * image.shape[2]), round(scale * image.shape[3])) - else: - scale_size = (size, size) - - image = torch.nn.functional.interpolate(image, size=scale_size, mode="bicubic", antialias=True) - h = (image.shape[2] - size)//2 - w = (image.shape[3] - size)//2 - image = image[:,:,h:h+size,w:w+size] - image = torch.clip((255. 
* image), 0, 255).round() / 255.0 - return (image - mean.view([3,1,1])) / std.view([3,1,1]) +clip_preprocess = comfy.clip_model.clip_preprocess # Prevent some stuff from breaking, TODO: remove eventually IMAGE_ENCODERS = { "clip_vision_model": comfy.clip_model.CLIPVisionModelProjection, @@ -73,7 +55,7 @@ class ClipVisionModel(): def encode_image(self, image, crop=True): comfy.model_management.load_model_gpu(self.patcher) - pixel_values = clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std, crop=crop).float() + pixel_values = comfy.clip_model.clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std, crop=crop).float() out = self.model(pixel_values=pixel_values, intermediate_output='all' if self.return_all_hidden_states else -2) outputs = Output() From 1bdc9a947f578733f81c9ae894a5acd5809c7a66 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 31 Dec 2025 16:29:55 -0800 Subject: [PATCH 117/148] Remove duplicate import of model_management (#11587) --- comfy/text_encoders/llama.py | 1 - 1 file changed, 1 deletion(-) diff --git a/comfy/text_encoders/llama.py b/comfy/text_encoders/llama.py index ed29e014d..faa4e1de8 100644 --- a/comfy/text_encoders/llama.py +++ b/comfy/text_encoders/llama.py @@ -8,7 +8,6 @@ from comfy.ldm.modules.attention import optimized_attention_for_device import comfy.model_management import comfy.ldm.common_dit -import comfy.model_management from . import qwen_vl @dataclass From 65cfcf5b1bb0d0618fef7bee08ee64397be5c434 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Thu, 1 Jan 2026 19:06:14 -0800 Subject: [PATCH 118/148] New Year ruff cleanup. 
(#11595) --- app/model_manager.py | 4 ++-- comfy/hooks.py | 3 ++- comfy/ldm/chroma_radiance/model.py | 2 +- comfy/ldm/hunyuan_video/upsampler.py | 3 ++- comfy/ldm/modules/diffusionmodules/model.py | 6 ++++-- comfy/ldm/modules/ema.py | 4 ++-- comfy/ldm/util.py | 2 +- comfy/taesd/taehv.py | 6 ++++-- comfy_execution/graph.py | 6 +++--- comfy_extras/nodes_apg.py | 3 ++- comfy_extras/nodes_wan.py | 2 +- nodes.py | 6 ++++-- pyproject.toml | 4 ++++ server.py | 6 +++--- 14 files changed, 35 insertions(+), 22 deletions(-) diff --git a/app/model_manager.py b/app/model_manager.py index ab36bca74..f124d1117 100644 --- a/app/model_manager.py +++ b/app/model_manager.py @@ -44,7 +44,7 @@ class ModelFileManager: @routes.get("/experiment/models/{folder}") async def get_all_models(request): folder = request.match_info.get("folder", None) - if not folder in folder_paths.folder_names_and_paths: + if folder not in folder_paths.folder_names_and_paths: return web.Response(status=404) files = self.get_model_file_list(folder) return web.json_response(files) @@ -55,7 +55,7 @@ class ModelFileManager: path_index = int(request.match_info.get("path_index", None)) filename = request.match_info.get("filename", None) - if not folder_name in folder_paths.folder_names_and_paths: + if folder_name not in folder_paths.folder_names_and_paths: return web.Response(status=404) folders = folder_paths.folder_names_and_paths[folder_name] diff --git a/comfy/hooks.py b/comfy/hooks.py index 9d0731072..1a76c7ba4 100644 --- a/comfy/hooks.py +++ b/comfy/hooks.py @@ -527,7 +527,8 @@ class HookKeyframeGroup: if self._current_keyframe.get_effective_guarantee_steps(max_sigma) > 0: break # if eval_c is outside the percent range, stop looking further - else: break + else: + break # update steps current context is used self._current_used_steps += 1 # update current timestep this was performed on diff --git a/comfy/ldm/chroma_radiance/model.py b/comfy/ldm/chroma_radiance/model.py index 70d173889..4fb56165e 100644 --- a/comfy/ldm/chroma_radiance/model.py +++ b/comfy/ldm/chroma_radiance/model.py @@ -270,7 +270,7 @@ class ChromaRadiance(Chroma): bad_keys = tuple( k for k, v in overrides.items() - if type(v) != type(getattr(params, k)) and (v is not None or k not in nullable_keys) + if not isinstance(v, type(getattr(params, k))) and (v is not None or k not in nullable_keys) ) if bad_keys: e = f"Invalid value(s) in transformer_options chroma_radiance_options: {', '.join(bad_keys)}" diff --git a/comfy/ldm/hunyuan_video/upsampler.py b/comfy/ldm/hunyuan_video/upsampler.py index 85f515f67..d9e76922f 100644 --- a/comfy/ldm/hunyuan_video/upsampler.py +++ b/comfy/ldm/hunyuan_video/upsampler.py @@ -3,7 +3,8 @@ import torch.nn as nn import torch.nn.functional as F from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, VideoConv3d from comfy.ldm.hunyuan_video.vae_refiner import RMS_norm -import model_management, model_patcher +import model_management +import model_patcher class SRResidualCausalBlock3D(nn.Module): def __init__(self, channels: int): diff --git a/comfy/ldm/modules/diffusionmodules/model.py b/comfy/ldm/modules/diffusionmodules/model.py index 681a55db5..1ae3ef034 100644 --- a/comfy/ldm/modules/diffusionmodules/model.py +++ b/comfy/ldm/modules/diffusionmodules/model.py @@ -394,7 +394,8 @@ class Model(nn.Module): attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, resolution, use_timestep=True, use_linear_attn=False, attn_type="vanilla"): super().__init__() - if use_linear_attn: attn_type = "linear" + if use_linear_attn: + 
attn_type = "linear" self.ch = ch self.temb_ch = self.ch*4 self.num_resolutions = len(ch_mult) @@ -548,7 +549,8 @@ class Encoder(nn.Module): conv3d=False, time_compress=None, **ignore_kwargs): super().__init__() - if use_linear_attn: attn_type = "linear" + if use_linear_attn: + attn_type = "linear" self.ch = ch self.temb_ch = 0 self.num_resolutions = len(ch_mult) diff --git a/comfy/ldm/modules/ema.py b/comfy/ldm/modules/ema.py index bded25019..96ee6e895 100644 --- a/comfy/ldm/modules/ema.py +++ b/comfy/ldm/modules/ema.py @@ -45,7 +45,7 @@ class LitEma(nn.Module): shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key])) else: - assert not key in self.m_name2s_name + assert key not in self.m_name2s_name def copy_to(self, model): m_param = dict(model.named_parameters()) @@ -54,7 +54,7 @@ class LitEma(nn.Module): if m_param[key].requires_grad: m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) else: - assert not key in self.m_name2s_name + assert key not in self.m_name2s_name def store(self, parameters): """ diff --git a/comfy/ldm/util.py b/comfy/ldm/util.py index 30b4b4721..304936ff4 100644 --- a/comfy/ldm/util.py +++ b/comfy/ldm/util.py @@ -71,7 +71,7 @@ def count_params(model, verbose=False): def instantiate_from_config(config): - if not "target" in config: + if "target" not in config: if config == '__is_first_stage__': return None elif config == "__is_unconditional__": diff --git a/comfy/taesd/taehv.py b/comfy/taesd/taehv.py index 3dfe1e4d4..0e5f9a378 100644 --- a/comfy/taesd/taehv.py +++ b/comfy/taesd/taehv.py @@ -154,7 +154,8 @@ class TAEHV(nn.Module): self._show_progress_bar = value def encode(self, x, **kwargs): - if self.patch_size > 1: x = F.pixel_unshuffle(x, self.patch_size) + if self.patch_size > 1: + x = F.pixel_unshuffle(x, self.patch_size) x = x.movedim(2, 1) # [B, C, T, H, W] -> [B, T, C, H, W] if x.shape[1] % 4 != 0: # pad at end to multiple of 4 @@ -167,5 +168,6 @@ class TAEHV(nn.Module): def decode(self, x, **kwargs): x = self.process_in(x).movedim(2, 1) # [B, C, T, H, W] -> [B, T, C, H, W] x = apply_model_with_memblocks(self.decoder, x, self.parallel, self.show_progress_bar) - if self.patch_size > 1: x = F.pixel_shuffle(x, self.patch_size) + if self.patch_size > 1: + x = F.pixel_shuffle(x, self.patch_size) return x[:, self.frames_to_trim:].movedim(2, 1) diff --git a/comfy_execution/graph.py b/comfy_execution/graph.py index 8fc5846b7..9d170b16e 100644 --- a/comfy_execution/graph.py +++ b/comfy_execution/graph.py @@ -207,15 +207,15 @@ class ExecutionList(TopologicalSort): return self.output_cache.get(node_id) is not None def cache_link(self, from_node_id, to_node_id): - if not to_node_id in self.execution_cache: + if to_node_id not in self.execution_cache: self.execution_cache[to_node_id] = {} self.execution_cache[to_node_id][from_node_id] = self.output_cache.get(from_node_id) - if not from_node_id in self.execution_cache_listeners: + if from_node_id not in self.execution_cache_listeners: self.execution_cache_listeners[from_node_id] = set() self.execution_cache_listeners[from_node_id].add(to_node_id) def get_cache(self, from_node_id, to_node_id): - if not to_node_id in self.execution_cache: + if to_node_id not in self.execution_cache: return None value = self.execution_cache[to_node_id].get(from_node_id) if value is None: diff --git a/comfy_extras/nodes_apg.py b/comfy_extras/nodes_apg.py index f27ae7da8..b9df2dcc9 100644 --- a/comfy_extras/nodes_apg.py +++ 
b/comfy_extras/nodes_apg.py @@ -55,7 +55,8 @@ class APG(io.ComfyNode): def pre_cfg_function(args): nonlocal running_avg, prev_sigma - if len(args["conds_out"]) == 1: return args["conds_out"] + if len(args["conds_out"]) == 1: + return args["conds_out"] cond = args["conds_out"][0] uncond = args["conds_out"][1] diff --git a/comfy_extras/nodes_wan.py b/comfy_extras/nodes_wan.py index b0bd471bf..d32aad98e 100644 --- a/comfy_extras/nodes_wan.py +++ b/comfy_extras/nodes_wan.py @@ -817,7 +817,7 @@ def get_sample_indices(original_fps, if required_duration > total_frames / original_fps: raise ValueError("required_duration must be less than video length") - if not fixed_start is None and fixed_start >= 0: + if fixed_start is not None and fixed_start >= 0: start_frame = fixed_start else: max_start = total_frames - required_origin_frames diff --git a/nodes.py b/nodes.py index d9e4ebd91..eae2f0086 100644 --- a/nodes.py +++ b/nodes.py @@ -2242,8 +2242,10 @@ async def init_external_custom_nodes(): for possible_module in possible_modules: module_path = os.path.join(custom_node_path, possible_module) - if os.path.isfile(module_path) and os.path.splitext(module_path)[1] != ".py": continue - if module_path.endswith(".disabled"): continue + if os.path.isfile(module_path) and os.path.splitext(module_path)[1] != ".py": + continue + if module_path.endswith(".disabled"): + continue if args.disable_all_custom_nodes and possible_module not in args.whitelist_custom_nodes: logging.info(f"Skipping {possible_module} due to disable_all_custom_nodes and whitelist_custom_nodes") continue diff --git a/pyproject.toml b/pyproject.toml index bc1467941..60378de1e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,12 +15,16 @@ lint.select = [ "N805", # invalid-first-argument-name-for-method "S307", # suspicious-eval-usage "S102", # exec + "E", "T", # print-usage "W", # The "F" series in Ruff stands for "Pyflakes" rules, which catch various Python syntax errors and undefined names. # See all rules here: https://docs.astral.sh/ruff/rules/#pyflakes-f "F", ] + +lint.ignore = ["E501", "E722", "E731", "E712", "E402", "E741"] + exclude = ["*.ipynb", "**/generated/*.pyi"] [tool.pylint] diff --git a/server.py b/server.py index c27f8be7d..70c8b5e3b 100644 --- a/server.py +++ b/server.py @@ -324,7 +324,7 @@ class PromptServer(): @routes.get("/models/{folder}") async def get_models(request): folder = request.match_info.get("folder", None) - if not folder in folder_paths.folder_names_and_paths: + if folder not in folder_paths.folder_names_and_paths: return web.Response(status=404) files = folder_paths.get_filename_list(folder) return web.json_response(files) @@ -579,7 +579,7 @@ class PromptServer(): folder_name = request.match_info.get("folder_name", None) if folder_name is None: return web.Response(status=404) - if not "filename" in request.rel_url.query: + if "filename" not in request.rel_url.query: return web.Response(status=404) filename = request.rel_url.query["filename"] @@ -593,7 +593,7 @@ class PromptServer(): if out is None: return web.Response(status=404) dt = json.loads(out) - if not "__metadata__" in dt: + if "__metadata__" not in dt: return web.Response(status=404) return web.json_response(dt["__metadata__"]) From 9e5f677746463228e35ac6a08f308d758ed620d5 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Fri, 2 Jan 2026 10:35:34 +0200 Subject: [PATCH 119/148] Ignore all frames except the first one for MPO format. 
(#11569) --- nodes.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/nodes.py b/nodes.py index eae2f0086..662907ae6 100644 --- a/nodes.py +++ b/nodes.py @@ -1663,8 +1663,6 @@ class LoadImage: output_masks = [] w, h = None, None - excluded_formats = ['MPO'] - for i in ImageSequence.Iterator(img): i = node_helpers.pillow(ImageOps.exif_transpose, i) @@ -1692,7 +1690,10 @@ class LoadImage: output_images.append(image) output_masks.append(mask.unsqueeze(0)) - if len(output_images) > 1 and img.format not in excluded_formats: + if img.format == "MPO": + break # ignore all frames except the first one for MPO format + + if len(output_images) > 1: output_image = torch.cat(output_images, dim=0) output_mask = torch.cat(output_masks, dim=0) else: From 303b1735f8785c0d1f947af965567850ca413f61 Mon Sep 17 00:00:00 2001 From: throttlekitty Date: Fri, 2 Jan 2026 01:37:37 -0700 Subject: [PATCH 120/148] Give Mahiro CFG a more appropriate display name (#11580) --- comfy_extras/nodes_mahiro.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy_extras/nodes_mahiro.py b/comfy_extras/nodes_mahiro.py index 07b3353f4..6459ca8c1 100644 --- a/comfy_extras/nodes_mahiro.py +++ b/comfy_extras/nodes_mahiro.py @@ -10,7 +10,7 @@ class Mahiro(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="Mahiro", - display_name="Mahiro is so cute that she deserves a better guidance function!! (。・ω・。)", + display_name="Mahiro CFG", category="_for_testing", description="Modify the guidance to scale more on the 'direction' of the positive prompt rather than the difference between the negative prompt.", inputs=[ From f2fda021ab179ba31d9175698b82474a5dd14359 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Fri, 2 Jan 2026 13:18:43 +0200 Subject: [PATCH 121/148] Tripo3D: pass face_limit parameter only when it differs from default (#11601) --- comfy_api_nodes/nodes_tripo.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/comfy_api_nodes/nodes_tripo.py b/comfy_api_nodes/nodes_tripo.py index bd3c24fb3..e72f8e96a 100644 --- a/comfy_api_nodes/nodes_tripo.py +++ b/comfy_api_nodes/nodes_tripo.py @@ -155,7 +155,7 @@ class TripoTextToModelNode(IO.ComfyNode): model_seed=model_seed, texture_seed=texture_seed, texture_quality=texture_quality, - face_limit=face_limit, + face_limit=face_limit if face_limit != -1 else None, geometry_quality=geometry_quality, auto_size=True, quad=quad, @@ -255,7 +255,7 @@ class TripoImageToModelNode(IO.ComfyNode): texture_alignment=texture_alignment, texture_seed=texture_seed, texture_quality=texture_quality, - face_limit=face_limit, + face_limit=face_limit if face_limit != -1 else None, auto_size=True, quad=quad, ), @@ -369,7 +369,7 @@ class TripoMultiviewToModelNode(IO.ComfyNode): texture_quality=texture_quality, geometry_quality=geometry_quality, texture_alignment=texture_alignment, - face_limit=face_limit, + face_limit=face_limit if face_limit != -1 else None, quad=quad, ), ) From 9a552df898ec57f066784cc1f7c475644099b3c1 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Fri, 2 Jan 2026 17:28:10 -0800 Subject: [PATCH 122/148] Remove leftover scaled_fp8 key. 
(#11603) --- comfy/utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/comfy/utils.py b/comfy/utils.py index 8d4e2b445..e4162d7ac 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -1230,6 +1230,8 @@ def convert_old_quants(state_dict, model_prefix="", metadata={}): out_sd = {} layers = {} for k in list(state_dict.keys()): + if k == scaled_fp8_key: + continue if not k.startswith(model_prefix): out_sd[k] = state_dict[k] continue From 53e762a3af9502ebe61a60eb2d39d783fe8d012b Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sat, 3 Jan 2026 19:28:38 -0800 Subject: [PATCH 123/148] Print memory summary on OOM to help with debugging. (#11613) --- comfy/model_management.py | 4 ++++ execution.py | 1 + 2 files changed, 5 insertions(+) diff --git a/comfy/model_management.py b/comfy/model_management.py index 87baedd73..2501cecb7 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -1542,6 +1542,10 @@ def soft_empty_cache(force=False): def unload_all_models(): free_memory(1e30, get_torch_device()) +def debug_memory_summary(): + if is_amd() or is_nvidia(): + return torch.cuda.memory.memory_summary() + return "" #TODO: might be cleaner to put this somewhere else import threading diff --git a/execution.py b/execution.py index 38159b1f4..648f204ec 100644 --- a/execution.py +++ b/execution.py @@ -601,6 +601,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, if isinstance(ex, comfy.model_management.OOM_EXCEPTION): tips = "This error means you ran out of memory on your GPU.\n\nTIPS: If the workflow worked before you might have accidentally set the batch_size to a large number." + logging.info("Memory summary: {}".format(comfy.model_management.debug_memory_summary())) logging.error("Got an OOM, unloading all loaded models.") comfy.model_management.unload_all_models() From acbf08cd60fade74b2e9e5009fa0dcad9538356b Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Sun, 4 Jan 2026 09:05:02 +0200 Subject: [PATCH 124/148] feat(api-nodes): add support for 720p resolution for Kling Omni nodes (#11604) --- comfy_api_nodes/nodes_kling.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/comfy_api_nodes/nodes_kling.py b/comfy_api_nodes/nodes_kling.py index 58259e029..9c707a339 100644 --- a/comfy_api_nodes/nodes_kling.py +++ b/comfy_api_nodes/nodes_kling.py @@ -807,6 +807,7 @@ class OmniProTextToVideoNode(IO.ComfyNode): ), IO.Combo.Input("aspect_ratio", options=["16:9", "9:16", "1:1"]), IO.Combo.Input("duration", options=[5, 10]), + IO.Combo.Input("resolution", options=["1080p", "720p"], optional=True), ], outputs=[ IO.Video.Output(), @@ -826,6 +827,7 @@ class OmniProTextToVideoNode(IO.ComfyNode): prompt: str, aspect_ratio: str, duration: int, + resolution: str = "1080p", ) -> IO.NodeOutput: validate_string(prompt, min_length=1, max_length=2500) response = await sync_op( @@ -837,6 +839,7 @@ class OmniProTextToVideoNode(IO.ComfyNode): prompt=prompt, aspect_ratio=aspect_ratio, duration=str(duration), + mode="pro" if resolution == "1080p" else "std", ), ) return await finish_omni_video_task(cls, response) @@ -872,6 +875,7 @@ class OmniProFirstLastFrameNode(IO.ComfyNode): optional=True, tooltip="Up to 6 additional reference images.", ), + IO.Combo.Input("resolution", options=["1080p", "720p"], optional=True), ], outputs=[ IO.Video.Output(), @@ -893,6 +897,7 @@ class OmniProFirstLastFrameNode(IO.ComfyNode): first_frame: Input.Image, end_frame: 
Input.Image | None = None, reference_images: Input.Image | None = None, + resolution: str = "1080p", ) -> IO.NodeOutput: prompt = normalize_omni_prompt_references(prompt) validate_string(prompt, min_length=1, max_length=2500) @@ -936,6 +941,7 @@ class OmniProFirstLastFrameNode(IO.ComfyNode): prompt=prompt, duration=str(duration), image_list=image_list, + mode="pro" if resolution == "1080p" else "std", ), ) return await finish_omni_video_task(cls, response) @@ -964,6 +970,7 @@ class OmniProImageToVideoNode(IO.ComfyNode): "reference_images", tooltip="Up to 7 reference images.", ), + IO.Combo.Input("resolution", options=["1080p", "720p"], optional=True), ], outputs=[ IO.Video.Output(), @@ -984,6 +991,7 @@ class OmniProImageToVideoNode(IO.ComfyNode): aspect_ratio: str, duration: int, reference_images: Input.Image, + resolution: str = "1080p", ) -> IO.NodeOutput: prompt = normalize_omni_prompt_references(prompt) validate_string(prompt, min_length=1, max_length=2500) @@ -1005,6 +1013,7 @@ class OmniProImageToVideoNode(IO.ComfyNode): aspect_ratio=aspect_ratio, duration=str(duration), image_list=image_list, + mode="pro" if resolution == "1080p" else "std", ), ) return await finish_omni_video_task(cls, response) @@ -1036,6 +1045,7 @@ class OmniProVideoToVideoNode(IO.ComfyNode): tooltip="Up to 4 additional reference images.", optional=True, ), + IO.Combo.Input("resolution", options=["1080p", "720p"], optional=True), ], outputs=[ IO.Video.Output(), @@ -1058,6 +1068,7 @@ class OmniProVideoToVideoNode(IO.ComfyNode): reference_video: Input.Video, keep_original_sound: bool, reference_images: Input.Image | None = None, + resolution: str = "1080p", ) -> IO.NodeOutput: prompt = normalize_omni_prompt_references(prompt) validate_string(prompt, min_length=1, max_length=2500) @@ -1090,6 +1101,7 @@ class OmniProVideoToVideoNode(IO.ComfyNode): duration=str(duration), image_list=image_list if image_list else None, video_list=video_list, + mode="pro" if resolution == "1080p" else "std", ), ) return await finish_omni_video_task(cls, response) @@ -1119,6 +1131,7 @@ class OmniProEditVideoNode(IO.ComfyNode): tooltip="Up to 4 additional reference images.", optional=True, ), + IO.Combo.Input("resolution", options=["1080p", "720p"], optional=True), ], outputs=[ IO.Video.Output(), @@ -1139,6 +1152,7 @@ class OmniProEditVideoNode(IO.ComfyNode): video: Input.Video, keep_original_sound: bool, reference_images: Input.Image | None = None, + resolution: str = "1080p", ) -> IO.NodeOutput: prompt = normalize_omni_prompt_references(prompt) validate_string(prompt, min_length=1, max_length=2500) @@ -1171,6 +1185,7 @@ class OmniProEditVideoNode(IO.ComfyNode): duration=None, image_list=image_list if image_list else None, video_list=video_list, + mode="pro" if resolution == "1080p" else "std", ), ) return await finish_omni_video_task(cls, response) From 38d049382533c6662d815b08ca3395e96cca9f57 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sun, 4 Jan 2026 16:13:50 -0800 Subject: [PATCH 125/148] Fix case where upscale model wouldn't be moved to cpu. 
(#11633) --- comfy_extras/nodes_upscale_model.py | 24 +++++++++++++----------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/comfy_extras/nodes_upscale_model.py b/comfy_extras/nodes_upscale_model.py index 4d62b87be..ed587851c 100644 --- a/comfy_extras/nodes_upscale_model.py +++ b/comfy_extras/nodes_upscale_model.py @@ -78,18 +78,20 @@ class ImageUpscaleWithModel(io.ComfyNode): overlap = 32 oom = True - while oom: - try: - steps = in_img.shape[0] * comfy.utils.get_tiled_scale_steps(in_img.shape[3], in_img.shape[2], tile_x=tile, tile_y=tile, overlap=overlap) - pbar = comfy.utils.ProgressBar(steps) - s = comfy.utils.tiled_scale(in_img, lambda a: upscale_model(a), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.scale, pbar=pbar) - oom = False - except model_management.OOM_EXCEPTION as e: - tile //= 2 - if tile < 128: - raise e + try: + while oom: + try: + steps = in_img.shape[0] * comfy.utils.get_tiled_scale_steps(in_img.shape[3], in_img.shape[2], tile_x=tile, tile_y=tile, overlap=overlap) + pbar = comfy.utils.ProgressBar(steps) + s = comfy.utils.tiled_scale(in_img, lambda a: upscale_model(a), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.scale, pbar=pbar) + oom = False + except model_management.OOM_EXCEPTION as e: + tile //= 2 + if tile < 128: + raise e + finally: + upscale_model.to("cpu") - upscale_model.to("cpu") s = torch.clamp(s.movedim(-3,-1), min=0, max=1.0) return io.NodeOutput(s) From f2b002372b71cf0671a4cf1fa539e1c386d727e4 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sun, 4 Jan 2026 22:58:59 -0800 Subject: [PATCH 126/148] Support the LTXV 2 model. (#11632) --- comfy/latent_formats.py | 3 + comfy/ldm/lightricks/av_model.py | 837 ++++++++++++++++ comfy/ldm/lightricks/embeddings_connector.py | 305 ++++++ comfy/ldm/lightricks/latent_upsampler.py | 292 ++++++ comfy/ldm/lightricks/model.py | 715 +++++++++++--- comfy/ldm/lightricks/symmetric_patchifier.py | 87 +- comfy/ldm/lightricks/vae/audio_vae.py | 286 ++++++ .../vae/causal_audio_autoencoder.py | 909 ++++++++++++++++++ comfy/ldm/lightricks/vocoders/vocoder.py | 213 ++++ comfy/model_base.py | 57 +- comfy/model_detection.py | 2 +- comfy/sd.py | 9 +- comfy/supported_models.py | 17 +- comfy/text_encoders/llama.py | 79 ++ comfy/text_encoders/lt.py | 111 +++ comfy/utils.py | 2 +- comfy_extras/nodes_audio.py | 2 +- comfy_extras/nodes_hunyuan.py | 15 +- comfy_extras/nodes_lt.py | 188 +++- comfy_extras/nodes_lt_audio.py | 183 ++++ comfy_extras/nodes_lt_upsampler.py | 75 ++ nodes.py | 10 +- pyproject.toml | 2 +- 23 files changed, 4214 insertions(+), 185 deletions(-) create mode 100644 comfy/ldm/lightricks/av_model.py create mode 100644 comfy/ldm/lightricks/embeddings_connector.py create mode 100644 comfy/ldm/lightricks/latent_upsampler.py create mode 100644 comfy/ldm/lightricks/vae/audio_vae.py create mode 100644 comfy/ldm/lightricks/vae/causal_audio_autoencoder.py create mode 100644 comfy/ldm/lightricks/vocoders/vocoder.py create mode 100644 comfy_extras/nodes_lt_audio.py create mode 100644 comfy_extras/nodes_lt_upsampler.py diff --git a/comfy/latent_formats.py b/comfy/latent_formats.py index f1ca0151e..9bbe30b53 100644 --- a/comfy/latent_formats.py +++ b/comfy/latent_formats.py @@ -407,6 +407,9 @@ class LTXV(LatentFormat): self.latent_rgb_factors_bias = [-0.0571, -0.1657, -0.2512] +class LTXAV(LTXV): + pass + class HunyuanVideo(LatentFormat): latent_channels = 16 latent_dimensions = 3 diff --git 
a/comfy/ldm/lightricks/av_model.py b/comfy/ldm/lightricks/av_model.py new file mode 100644 index 000000000..759535501 --- /dev/null +++ b/comfy/ldm/lightricks/av_model.py @@ -0,0 +1,837 @@ +from typing import Tuple +import torch +import torch.nn as nn +from comfy.ldm.lightricks.model import ( + CrossAttention, + FeedForward, + AdaLayerNormSingle, + PixArtAlphaTextProjection, + LTXVModel, +) +from comfy.ldm.lightricks.symmetric_patchifier import AudioPatchifier +import comfy.ldm.common_dit + +class BasicAVTransformerBlock(nn.Module): + def __init__( + self, + v_dim, + a_dim, + v_heads, + a_heads, + vd_head, + ad_head, + v_context_dim=None, + a_context_dim=None, + attn_precision=None, + dtype=None, + device=None, + operations=None, + ): + super().__init__() + + self.attn_precision = attn_precision + + self.attn1 = CrossAttention( + query_dim=v_dim, + heads=v_heads, + dim_head=vd_head, + context_dim=None, + attn_precision=self.attn_precision, + dtype=dtype, + device=device, + operations=operations, + ) + self.audio_attn1 = CrossAttention( + query_dim=a_dim, + heads=a_heads, + dim_head=ad_head, + context_dim=None, + attn_precision=self.attn_precision, + dtype=dtype, + device=device, + operations=operations, + ) + + self.attn2 = CrossAttention( + query_dim=v_dim, + context_dim=v_context_dim, + heads=v_heads, + dim_head=vd_head, + attn_precision=self.attn_precision, + dtype=dtype, + device=device, + operations=operations, + ) + self.audio_attn2 = CrossAttention( + query_dim=a_dim, + context_dim=a_context_dim, + heads=a_heads, + dim_head=ad_head, + attn_precision=self.attn_precision, + dtype=dtype, + device=device, + operations=operations, + ) + + # Q: Video, K,V: Audio + self.audio_to_video_attn = CrossAttention( + query_dim=v_dim, + context_dim=a_dim, + heads=a_heads, + dim_head=ad_head, + attn_precision=self.attn_precision, + dtype=dtype, + device=device, + operations=operations, + ) + + # Q: Audio, K,V: Video + self.video_to_audio_attn = CrossAttention( + query_dim=a_dim, + context_dim=v_dim, + heads=a_heads, + dim_head=ad_head, + attn_precision=self.attn_precision, + dtype=dtype, + device=device, + operations=operations, + ) + + self.ff = FeedForward( + v_dim, dim_out=v_dim, glu=True, dtype=dtype, device=device, operations=operations + ) + self.audio_ff = FeedForward( + a_dim, dim_out=a_dim, glu=True, dtype=dtype, device=device, operations=operations + ) + + self.scale_shift_table = nn.Parameter(torch.empty(6, v_dim, device=device, dtype=dtype)) + self.audio_scale_shift_table = nn.Parameter( + torch.empty(6, a_dim, device=device, dtype=dtype) + ) + + self.scale_shift_table_a2v_ca_audio = nn.Parameter( + torch.empty(5, a_dim, device=device, dtype=dtype) + ) + self.scale_shift_table_a2v_ca_video = nn.Parameter( + torch.empty(5, v_dim, device=device, dtype=dtype) + ) + + def get_ada_values( + self, scale_shift_table: torch.Tensor, batch_size: int, timestep: torch.Tensor, indices: slice = slice(None, None) + ): + num_ada_params = scale_shift_table.shape[0] + + ada_values = ( + scale_shift_table[indices].unsqueeze(0).unsqueeze(0).to(device=timestep.device, dtype=timestep.dtype) + + timestep.reshape(batch_size, timestep.shape[1], num_ada_params, -1)[:, :, indices, :] + ).unbind(dim=2) + return ada_values + + def get_av_ca_ada_values( + self, + scale_shift_table: torch.Tensor, + batch_size: int, + scale_shift_timestep: torch.Tensor, + gate_timestep: torch.Tensor, + num_scale_shift_values: int = 4, + ): + scale_shift_ada_values = self.get_ada_values( + scale_shift_table[:num_scale_shift_values, :], 
+ batch_size, + scale_shift_timestep, + ) + gate_ada_values = self.get_ada_values( + scale_shift_table[num_scale_shift_values:, :], + batch_size, + gate_timestep, + ) + + scale_shift_chunks = [t.squeeze(2) for t in scale_shift_ada_values] + gate_ada_values = [t.squeeze(2) for t in gate_ada_values] + + return (*scale_shift_chunks, *gate_ada_values) + + def forward( + self, + x: Tuple[torch.Tensor, torch.Tensor], + v_context=None, + a_context=None, + attention_mask=None, + v_timestep=None, + a_timestep=None, + v_pe=None, + a_pe=None, + v_cross_pe=None, + a_cross_pe=None, + v_cross_scale_shift_timestep=None, + a_cross_scale_shift_timestep=None, + v_cross_gate_timestep=None, + a_cross_gate_timestep=None, + transformer_options=None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + run_vx = transformer_options.get("run_vx", True) + run_ax = transformer_options.get("run_ax", True) + + vx, ax = x + run_ax = run_ax and ax.numel() > 0 + run_a2v = run_vx and transformer_options.get("a2v_cross_attn", True) and ax.numel() > 0 + run_v2a = run_ax and transformer_options.get("v2a_cross_attn", True) + + if run_vx: + vshift_msa, vscale_msa, vgate_msa = ( + self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(0, 3)) + ) + + norm_vx = comfy.ldm.common_dit.rms_norm(vx) * (1 + vscale_msa) + vshift_msa + vx += self.attn1(norm_vx, pe=v_pe, transformer_options=transformer_options) * vgate_msa + vx += self.attn2( + comfy.ldm.common_dit.rms_norm(vx), + context=v_context, + mask=attention_mask, + transformer_options=transformer_options, + ) + + del vshift_msa, vscale_msa, vgate_msa + + if run_ax: + ashift_msa, ascale_msa, agate_msa = ( + self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(0, 3)) + ) + + norm_ax = comfy.ldm.common_dit.rms_norm(ax) * (1 + ascale_msa) + ashift_msa + ax += ( + self.audio_attn1(norm_ax, pe=a_pe, transformer_options=transformer_options) + * agate_msa + ) + ax += self.audio_attn2( + comfy.ldm.common_dit.rms_norm(ax), + context=a_context, + mask=attention_mask, + transformer_options=transformer_options, + ) + + del ashift_msa, ascale_msa, agate_msa + + # Audio - Video cross attention. 
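+        # Both directions reuse the RMS-normed hidden states computed below; each applies its own
+        # scale/shift from the a2v cross-attention tables and is gated independently, so the a2v
+        # and v2a paths can be skipped separately via transformer_options.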
+ if run_a2v or run_v2a: + # norm3 + vx_norm3 = comfy.ldm.common_dit.rms_norm(vx) + ax_norm3 = comfy.ldm.common_dit.rms_norm(ax) + + ( + scale_ca_audio_hidden_states_a2v, + shift_ca_audio_hidden_states_a2v, + scale_ca_audio_hidden_states_v2a, + shift_ca_audio_hidden_states_v2a, + gate_out_v2a, + ) = self.get_av_ca_ada_values( + self.scale_shift_table_a2v_ca_audio, + ax.shape[0], + a_cross_scale_shift_timestep, + a_cross_gate_timestep, + ) + + ( + scale_ca_video_hidden_states_a2v, + shift_ca_video_hidden_states_a2v, + scale_ca_video_hidden_states_v2a, + shift_ca_video_hidden_states_v2a, + gate_out_a2v, + ) = self.get_av_ca_ada_values( + self.scale_shift_table_a2v_ca_video, + vx.shape[0], + v_cross_scale_shift_timestep, + v_cross_gate_timestep, + ) + + if run_a2v: + vx_scaled = ( + vx_norm3 * (1 + scale_ca_video_hidden_states_a2v) + + shift_ca_video_hidden_states_a2v + ) + ax_scaled = ( + ax_norm3 * (1 + scale_ca_audio_hidden_states_a2v) + + shift_ca_audio_hidden_states_a2v + ) + vx += ( + self.audio_to_video_attn( + vx_scaled, + context=ax_scaled, + pe=v_cross_pe, + k_pe=a_cross_pe, + transformer_options=transformer_options, + ) + * gate_out_a2v + ) + + del gate_out_a2v + del scale_ca_video_hidden_states_a2v,\ + shift_ca_video_hidden_states_a2v,\ + scale_ca_audio_hidden_states_a2v,\ + shift_ca_audio_hidden_states_a2v,\ + + if run_v2a: + ax_scaled = ( + ax_norm3 * (1 + scale_ca_audio_hidden_states_v2a) + + shift_ca_audio_hidden_states_v2a + ) + vx_scaled = ( + vx_norm3 * (1 + scale_ca_video_hidden_states_v2a) + + shift_ca_video_hidden_states_v2a + ) + ax += ( + self.video_to_audio_attn( + ax_scaled, + context=vx_scaled, + pe=a_cross_pe, + k_pe=v_cross_pe, + transformer_options=transformer_options, + ) + * gate_out_v2a + ) + + del gate_out_v2a + del scale_ca_video_hidden_states_v2a,\ + shift_ca_video_hidden_states_v2a,\ + scale_ca_audio_hidden_states_v2a,\ + shift_ca_audio_hidden_states_v2a + + if run_vx: + vshift_mlp, vscale_mlp, vgate_mlp = ( + self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(3, None)) + ) + + vx_scaled = comfy.ldm.common_dit.rms_norm(vx) * (1 + vscale_mlp) + vshift_mlp + vx += self.ff(vx_scaled) * vgate_mlp + del vshift_mlp, vscale_mlp, vgate_mlp + + if run_ax: + ashift_mlp, ascale_mlp, agate_mlp = ( + self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(3, None)) + ) + + ax_scaled = comfy.ldm.common_dit.rms_norm(ax) * (1 + ascale_mlp) + ashift_mlp + ax += self.audio_ff(ax_scaled) * agate_mlp + + del ashift_mlp, ascale_mlp, agate_mlp + + + return vx, ax + + +class LTXAVModel(LTXVModel): + """LTXAV model for audio-video generation.""" + + def __init__( + self, + in_channels=128, + audio_in_channels=128, + cross_attention_dim=4096, + audio_cross_attention_dim=2048, + attention_head_dim=128, + audio_attention_head_dim=64, + num_attention_heads=32, + audio_num_attention_heads=32, + caption_channels=3840, + num_layers=48, + positional_embedding_theta=10000.0, + positional_embedding_max_pos=[20, 2048, 2048], + audio_positional_embedding_max_pos=[20], + causal_temporal_positioning=False, + vae_scale_factors=(8, 32, 32), + use_middle_indices_grid=False, + timestep_scale_multiplier=1000.0, + av_ca_timestep_scale_multiplier=1.0, + dtype=None, + device=None, + operations=None, + **kwargs, + ): + # Store audio-specific parameters + self.audio_in_channels = audio_in_channels + self.audio_cross_attention_dim = audio_cross_attention_dim + self.audio_attention_head_dim = audio_attention_head_dim + self.audio_num_attention_heads = 
audio_num_attention_heads + self.audio_positional_embedding_max_pos = audio_positional_embedding_max_pos + + # Calculate audio dimensions + self.audio_inner_dim = audio_num_attention_heads * audio_attention_head_dim + self.audio_out_channels = audio_in_channels + + # Audio-specific constants + self.num_audio_channels = 8 + self.audio_frequency_bins = 16 + + self.av_ca_timestep_scale_multiplier = av_ca_timestep_scale_multiplier + + super().__init__( + in_channels=in_channels, + cross_attention_dim=cross_attention_dim, + attention_head_dim=attention_head_dim, + num_attention_heads=num_attention_heads, + caption_channels=caption_channels, + num_layers=num_layers, + positional_embedding_theta=positional_embedding_theta, + positional_embedding_max_pos=positional_embedding_max_pos, + causal_temporal_positioning=causal_temporal_positioning, + vae_scale_factors=vae_scale_factors, + use_middle_indices_grid=use_middle_indices_grid, + timestep_scale_multiplier=timestep_scale_multiplier, + dtype=dtype, + device=device, + operations=operations, + **kwargs, + ) + + def _init_model_components(self, device, dtype, **kwargs): + """Initialize LTXAV-specific components.""" + # Audio-specific projections + self.audio_patchify_proj = self.operations.Linear( + self.audio_in_channels, self.audio_inner_dim, bias=True, dtype=dtype, device=device + ) + + # Audio-specific AdaLN + self.audio_adaln_single = AdaLayerNormSingle( + self.audio_inner_dim, + use_additional_conditions=False, + dtype=dtype, + device=device, + operations=self.operations, + ) + + num_scale_shift_values = 4 + self.av_ca_video_scale_shift_adaln_single = AdaLayerNormSingle( + self.inner_dim, + use_additional_conditions=False, + embedding_coefficient=num_scale_shift_values, + dtype=dtype, + device=device, + operations=self.operations, + ) + self.av_ca_a2v_gate_adaln_single = AdaLayerNormSingle( + self.inner_dim, + use_additional_conditions=False, + embedding_coefficient=1, + dtype=dtype, + device=device, + operations=self.operations, + ) + self.av_ca_audio_scale_shift_adaln_single = AdaLayerNormSingle( + self.audio_inner_dim, + use_additional_conditions=False, + embedding_coefficient=num_scale_shift_values, + dtype=dtype, + device=device, + operations=self.operations, + ) + self.av_ca_v2a_gate_adaln_single = AdaLayerNormSingle( + self.audio_inner_dim, + use_additional_conditions=False, + embedding_coefficient=1, + dtype=dtype, + device=device, + operations=self.operations, + ) + + # Audio caption projection + self.audio_caption_projection = PixArtAlphaTextProjection( + in_features=self.caption_channels, + hidden_size=self.audio_inner_dim, + dtype=dtype, + device=device, + operations=self.operations, + ) + + def _init_transformer_blocks(self, device, dtype, **kwargs): + """Initialize transformer blocks for LTXAV.""" + self.transformer_blocks = nn.ModuleList( + [ + BasicAVTransformerBlock( + v_dim=self.inner_dim, + a_dim=self.audio_inner_dim, + v_heads=self.num_attention_heads, + a_heads=self.audio_num_attention_heads, + vd_head=self.attention_head_dim, + ad_head=self.audio_attention_head_dim, + v_context_dim=self.cross_attention_dim, + a_context_dim=self.audio_cross_attention_dim, + dtype=dtype, + device=device, + operations=self.operations, + ) + for _ in range(self.num_layers) + ] + ) + + def _init_output_components(self, device, dtype): + """Initialize output components for LTXAV.""" + # Video output components + super()._init_output_components(device, dtype) + # Audio output components + self.audio_scale_shift_table = nn.Parameter( + 
torch.empty(2, self.audio_inner_dim, dtype=dtype, device=device) + ) + self.audio_norm_out = self.operations.LayerNorm( + self.audio_inner_dim, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device + ) + self.audio_proj_out = self.operations.Linear( + self.audio_inner_dim, self.audio_out_channels, dtype=dtype, device=device + ) + self.a_patchifier = AudioPatchifier(1, start_end=True) + + def separate_audio_and_video_latents(self, x, audio_length): + """Separate audio and video latents from combined input.""" + # vx = x[:, : self.in_channels] + # ax = x[:, self.in_channels :] + # + # ax = ax.reshape(ax.shape[0], -1) + # ax = ax[:, : audio_length * self.num_audio_channels * self.audio_frequency_bins] + # + # ax = ax.reshape( + # ax.shape[0], self.num_audio_channels, audio_length, self.audio_frequency_bins + # ) + + vx = x[0] + ax = x[1] if len(x) > 1 else torch.zeros( + (vx.shape[0], self.num_audio_channels, 0, self.audio_frequency_bins), + device=vx.device, dtype=vx.dtype + ) + return vx, ax + + def recombine_audio_and_video_latents(self, vx, ax, target_shape=None): + if ax.numel() == 0: + return vx + else: + return [vx, ax] + """Recombine audio and video latents for output.""" + # if ax.device != vx.device or ax.dtype != vx.dtype: + # logging.warning("Audio and video latents are on different devices or dtypes.") + # ax = ax.to(device=vx.device, dtype=vx.dtype) + # logging.warning(f"Audio audio latent moved to device: {ax.device}, dtype: {ax.dtype}") + # + # ax = ax.reshape(ax.shape[0], -1) + # # pad to f x h x w of the video latents + # divisor = vx.shape[-1] * vx.shape[-2] * vx.shape[-3] + # if target_shape is None: + # repetitions = math.ceil(ax.shape[-1] / divisor) + # else: + # repetitions = target_shape[1] - vx.shape[1] + # padded_len = repetitions * divisor + # ax = F.pad(ax, (0, padded_len - ax.shape[-1])) + # ax = ax.reshape(ax.shape[0], -1, vx.shape[-3], vx.shape[-2], vx.shape[-1]) + # return torch.cat([vx, ax], dim=1) + + def _process_input(self, x, keyframe_idxs, denoise_mask, **kwargs): + """Process input for LTXAV - separate audio and video, then patchify.""" + audio_length = kwargs.get("audio_length", 0) + # Separate audio and video latents + vx, ax = self.separate_audio_and_video_latents(x, audio_length) + [vx, v_pixel_coords, additional_args] = super()._process_input( + vx, keyframe_idxs, denoise_mask, **kwargs + ) + + ax, a_latent_coords = self.a_patchifier.patchify(ax) + ax = self.audio_patchify_proj(ax) + + # additional_args.update({"av_orig_shape": list(x.shape)}) + return [vx, ax], [v_pixel_coords, a_latent_coords], additional_args + + def _prepare_timestep(self, timestep, batch_size, hidden_dtype, **kwargs): + """Prepare timestep embeddings.""" + # TODO: some code reuse is needed here. 
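+        # Video timesteps drive the base AdaLN path below; when an audio timestep is supplied, the
+        # audio AdaLN and the four av cross-attention AdaLN heads (audio/video scale-shift plus the
+        # a2v/v2a gates) are evaluated as well.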
+ grid_mask = kwargs.get("grid_mask", None) + if grid_mask is not None: + timestep = timestep[:, grid_mask] + + timestep = timestep * self.timestep_scale_multiplier + v_timestep, v_embedded_timestep = self.adaln_single( + timestep.flatten(), + {"resolution": None, "aspect_ratio": None}, + batch_size=batch_size, + hidden_dtype=hidden_dtype, + ) + + # Second dimension is 1 or number of tokens (if timestep_per_token) + v_timestep = v_timestep.view(batch_size, -1, v_timestep.shape[-1]) + v_embedded_timestep = v_embedded_timestep.view( + batch_size, -1, v_embedded_timestep.shape[-1] + ) + + # Prepare audio timestep + a_timestep = kwargs.get("a_timestep") + if a_timestep is not None: + a_timestep = a_timestep * self.timestep_scale_multiplier + av_ca_factor = self.av_ca_timestep_scale_multiplier / self.timestep_scale_multiplier + + av_ca_audio_scale_shift_timestep, _ = self.av_ca_audio_scale_shift_adaln_single( + a_timestep.flatten(), + {"resolution": None, "aspect_ratio": None}, + batch_size=batch_size, + hidden_dtype=hidden_dtype, + ) + av_ca_video_scale_shift_timestep, _ = self.av_ca_video_scale_shift_adaln_single( + timestep.flatten(), + {"resolution": None, "aspect_ratio": None}, + batch_size=batch_size, + hidden_dtype=hidden_dtype, + ) + av_ca_a2v_gate_noise_timestep, _ = self.av_ca_a2v_gate_adaln_single( + timestep.flatten() * av_ca_factor, + {"resolution": None, "aspect_ratio": None}, + batch_size=batch_size, + hidden_dtype=hidden_dtype, + ) + av_ca_v2a_gate_noise_timestep, _ = self.av_ca_v2a_gate_adaln_single( + a_timestep.flatten() * av_ca_factor, + {"resolution": None, "aspect_ratio": None}, + batch_size=batch_size, + hidden_dtype=hidden_dtype, + ) + + a_timestep, a_embedded_timestep = self.audio_adaln_single( + a_timestep.flatten(), + {"resolution": None, "aspect_ratio": None}, + batch_size=batch_size, + hidden_dtype=hidden_dtype, + ) + a_timestep = a_timestep.view(batch_size, -1, a_timestep.shape[-1]) + a_embedded_timestep = a_embedded_timestep.view( + batch_size, -1, a_embedded_timestep.shape[-1] + ) + cross_av_timestep_ss = [ + av_ca_audio_scale_shift_timestep, + av_ca_video_scale_shift_timestep, + av_ca_a2v_gate_noise_timestep, + av_ca_v2a_gate_noise_timestep, + ] + cross_av_timestep_ss = list( + [t.view(batch_size, -1, t.shape[-1]) for t in cross_av_timestep_ss] + ) + else: + a_timestep = timestep + a_embedded_timestep = kwargs.get("embedded_timestep") + cross_av_timestep_ss = [] + + return [v_timestep, a_timestep, cross_av_timestep_ss], [ + v_embedded_timestep, + a_embedded_timestep, + ] + + def _prepare_context(self, context, batch_size, x, attention_mask=None): + vx = x[0] + ax = x[1] + v_context, a_context = torch.split( + context, int(context.shape[-1] / 2), len(context.shape) - 1 + ) + + v_context, attention_mask = super()._prepare_context( + v_context, batch_size, vx, attention_mask + ) + if self.audio_caption_projection is not None: + a_context = self.audio_caption_projection(a_context) + a_context = a_context.view(batch_size, -1, ax.shape[-1]) + + return [v_context, a_context], attention_mask + + def _prepare_positional_embeddings(self, pixel_coords, frame_rate, x_dtype): + v_pixel_coords = pixel_coords[0] + v_pe = super()._prepare_positional_embeddings(v_pixel_coords, frame_rate, x_dtype) + + a_latent_coords = pixel_coords[1] + a_pe = self._precompute_freqs_cis( + a_latent_coords, + dim=self.audio_inner_dim, + out_dtype=x_dtype, + max_pos=self.audio_positional_embedding_max_pos, + use_middle_indices_grid=self.use_middle_indices_grid, + 
num_attention_heads=self.audio_num_attention_heads, + ) + + # calculate positional embeddings for the middle of the token duration, to use in av cross attention layers. + max_pos = max( + self.positional_embedding_max_pos[0], self.audio_positional_embedding_max_pos[0] + ) + v_pixel_coords = v_pixel_coords.to(torch.float32) + v_pixel_coords[:, 0] = v_pixel_coords[:, 0] * (1.0 / frame_rate) + av_cross_video_freq_cis = self._precompute_freqs_cis( + v_pixel_coords[:, 0:1, :], + dim=self.audio_cross_attention_dim, + out_dtype=x_dtype, + max_pos=[max_pos], + use_middle_indices_grid=True, + num_attention_heads=self.audio_num_attention_heads, + ) + av_cross_audio_freq_cis = self._precompute_freqs_cis( + a_latent_coords[:, 0:1, :], + dim=self.audio_cross_attention_dim, + out_dtype=x_dtype, + max_pos=[max_pos], + use_middle_indices_grid=True, + num_attention_heads=self.audio_num_attention_heads, + ) + + return [(v_pe, av_cross_video_freq_cis), (a_pe, av_cross_audio_freq_cis)] + + def _process_transformer_blocks( + self, x, context, attention_mask, timestep, pe, transformer_options={}, **kwargs + ): + vx = x[0] + ax = x[1] + v_context = context[0] + a_context = context[1] + v_timestep = timestep[0] + a_timestep = timestep[1] + v_pe, av_cross_video_freq_cis = pe[0] + a_pe, av_cross_audio_freq_cis = pe[1] + + ( + av_ca_audio_scale_shift_timestep, + av_ca_video_scale_shift_timestep, + av_ca_a2v_gate_noise_timestep, + av_ca_v2a_gate_noise_timestep, + ) = timestep[2] + + """Process transformer blocks for LTXAV.""" + patches_replace = transformer_options.get("patches_replace", {}) + blocks_replace = patches_replace.get("dit", {}) + + # Process transformer blocks + for i, block in enumerate(self.transformer_blocks): + if ("double_block", i) in blocks_replace: + + def block_wrap(args): + out = {} + out["img"] = block( + args["img"], + v_context=args["v_context"], + a_context=args["a_context"], + attention_mask=args["attention_mask"], + v_timestep=args["v_timestep"], + a_timestep=args["a_timestep"], + v_pe=args["v_pe"], + a_pe=args["a_pe"], + v_cross_pe=args["v_cross_pe"], + a_cross_pe=args["a_cross_pe"], + v_cross_scale_shift_timestep=args["v_cross_scale_shift_timestep"], + a_cross_scale_shift_timestep=args["a_cross_scale_shift_timestep"], + v_cross_gate_timestep=args["v_cross_gate_timestep"], + a_cross_gate_timestep=args["a_cross_gate_timestep"], + transformer_options=args["transformer_options"], + ) + return out + + out = blocks_replace[("double_block", i)]( + { + "img": (vx, ax), + "v_context": v_context, + "a_context": a_context, + "attention_mask": attention_mask, + "v_timestep": v_timestep, + "a_timestep": a_timestep, + "v_pe": v_pe, + "a_pe": a_pe, + "v_cross_pe": av_cross_video_freq_cis, + "a_cross_pe": av_cross_audio_freq_cis, + "v_cross_scale_shift_timestep": av_ca_video_scale_shift_timestep, + "a_cross_scale_shift_timestep": av_ca_audio_scale_shift_timestep, + "v_cross_gate_timestep": av_ca_a2v_gate_noise_timestep, + "a_cross_gate_timestep": av_ca_v2a_gate_noise_timestep, + "transformer_options": transformer_options, + }, + {"original_block": block_wrap}, + ) + vx, ax = out["img"] + else: + vx, ax = block( + (vx, ax), + v_context=v_context, + a_context=a_context, + attention_mask=attention_mask, + v_timestep=v_timestep, + a_timestep=a_timestep, + v_pe=v_pe, + a_pe=a_pe, + v_cross_pe=av_cross_video_freq_cis, + a_cross_pe=av_cross_audio_freq_cis, + v_cross_scale_shift_timestep=av_ca_video_scale_shift_timestep, + a_cross_scale_shift_timestep=av_ca_audio_scale_shift_timestep, + 
v_cross_gate_timestep=av_ca_a2v_gate_noise_timestep, + a_cross_gate_timestep=av_ca_v2a_gate_noise_timestep, + transformer_options=transformer_options, + ) + + return [vx, ax] + + def _process_output(self, x, embedded_timestep, keyframe_idxs, **kwargs): + vx = x[0] + ax = x[1] + v_embedded_timestep = embedded_timestep[0] + a_embedded_timestep = embedded_timestep[1] + vx = super()._process_output(vx, v_embedded_timestep, keyframe_idxs, **kwargs) + + # Process audio output + a_scale_shift_values = ( + self.audio_scale_shift_table[None, None].to(device=a_embedded_timestep.device, dtype=a_embedded_timestep.dtype) + + a_embedded_timestep[:, :, None] + ) + a_shift, a_scale = a_scale_shift_values[:, :, 0], a_scale_shift_values[:, :, 1] + + ax = self.audio_norm_out(ax) + ax = ax * (1 + a_scale) + a_shift + ax = self.audio_proj_out(ax) + + # Unpatchify audio + ax = self.a_patchifier.unpatchify( + ax, channels=self.num_audio_channels, freq=self.audio_frequency_bins + ) + + # Recombine audio and video + original_shape = kwargs.get("av_orig_shape") + return self.recombine_audio_and_video_latents(vx, ax, original_shape) + + def forward( + self, + x, + timestep, + context, + attention_mask=None, + frame_rate=25, + transformer_options={}, + keyframe_idxs=None, + **kwargs, + ): + """ + Forward pass for LTXAV model. + + Args: + x: Combined audio-video input tensor + timestep: Tuple of (video_timestep, audio_timestep) or single timestep + context: Context tensor (e.g., text embeddings) + attention_mask: Attention mask tensor + frame_rate: Frame rate for temporal processing + transformer_options: Additional options for transformer blocks + keyframe_idxs: Keyframe indices for temporal processing + **kwargs: Additional keyword arguments including audio_length + + Returns: + Combined audio-video output tensor + """ + # Handle timestep format + if isinstance(timestep, (tuple, list)) and len(timestep) == 2: + v_timestep, a_timestep = timestep + kwargs["a_timestep"] = a_timestep + timestep = v_timestep + else: + kwargs["a_timestep"] = timestep + + # Call parent forward method + return super().forward( + x, + timestep, + context, + attention_mask, + frame_rate, + transformer_options, + keyframe_idxs, + **kwargs, + ) diff --git a/comfy/ldm/lightricks/embeddings_connector.py b/comfy/ldm/lightricks/embeddings_connector.py new file mode 100644 index 000000000..f7a43f3c3 --- /dev/null +++ b/comfy/ldm/lightricks/embeddings_connector.py @@ -0,0 +1,305 @@ +import math +from typing import Optional + +import comfy.ldm.common_dit +import torch +from comfy.ldm.lightricks.model import ( + CrossAttention, + FeedForward, + generate_freq_grid_np, + interleaved_freqs_cis, + split_freqs_cis, +) +from torch import nn + + +class BasicTransformerBlock1D(nn.Module): + r""" + A basic Transformer block. + + Parameters: + + dim (`int`): The number of channels in the input and output. + num_attention_heads (`int`): The number of heads to use for multi-head attention. + attention_head_dim (`int`): The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + attention_bias (: + obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter. + upcast_attention (`bool`, *optional*): + Whether to upcast the attention computation to float32. This is useful for mixed precision training. 
+ norm_elementwise_affine (`bool`, *optional*, defaults to `True`): + Whether to use learnable elementwise affine parameters for normalization. + standardization_norm (`str`, *optional*, defaults to `"layer_norm"`): The type of pre-normalization to use. Can be `"layer_norm"` or `"rms_norm"`. + norm_eps (`float`, *optional*, defaults to 1e-5): Epsilon value for normalization layers. + qk_norm (`str`, *optional*, defaults to None): + Set to 'layer_norm' or `rms_norm` to perform query and key normalization. + final_dropout (`bool` *optional*, defaults to False): + Whether to apply a final dropout after the last feed-forward layer. + ff_inner_dim (`int`, *optional*): Dimension of the inner feed-forward layer. If not provided, defaults to `dim * 4`. + ff_bias (`bool`, *optional*, defaults to `True`): Whether to use bias in the feed-forward layer. + attention_out_bias (`bool`, *optional*, defaults to `True`): Whether to use bias in the attention output layer. + use_rope (`bool`, *optional*, defaults to `False`): Whether to use Rotary Position Embeddings (RoPE). + ffn_dim_mult (`int`, *optional*, defaults to 4): Multiplier for the inner dimension of the feed-forward layer. + """ + + def __init__( + self, + dim, + n_heads, + d_head, + context_dim=None, + attn_precision=None, + dtype=None, + device=None, + operations=None, + ): + super().__init__() + + # Define 3 blocks. Each block has its own normalization layer. + # 1. Self-Attn + self.attn1 = CrossAttention( + query_dim=dim, + heads=n_heads, + dim_head=d_head, + context_dim=None, + dtype=dtype, + device=device, + operations=operations, + ) + + # 3. Feed-forward + self.ff = FeedForward( + dim, + dim_out=dim, + glu=True, + dtype=dtype, + device=device, + operations=operations, + ) + + def forward(self, hidden_states, attention_mask=None, pe=None) -> torch.FloatTensor: + + # Notice that normalization is always applied before the real computation in the following blocks. + + # 1. Normalization Before Self-Attention + norm_hidden_states = comfy.ldm.common_dit.rms_norm(hidden_states) + + norm_hidden_states = norm_hidden_states.squeeze(1) + + # 2. Self-Attention + attn_output = self.attn1(norm_hidden_states, mask=attention_mask, pe=pe) + + hidden_states = attn_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = hidden_states.squeeze(1) + + # 3. Normalization before Feed-Forward + norm_hidden_states = comfy.ldm.common_dit.rms_norm(hidden_states) + + # 4. 
Feed-forward + ff_output = self.ff(norm_hidden_states) + + hidden_states = ff_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = hidden_states.squeeze(1) + + return hidden_states + + +class Embeddings1DConnector(nn.Module): + _supports_gradient_checkpointing = True + + def __init__( + self, + in_channels=128, + cross_attention_dim=2048, + attention_head_dim=128, + num_attention_heads=30, + num_layers=2, + positional_embedding_theta=10000.0, + positional_embedding_max_pos=[4096], + causal_temporal_positioning=False, + num_learnable_registers: Optional[int] = 128, + dtype=None, + device=None, + operations=None, + split_rope=False, + double_precision_rope=False, + **kwargs, + ): + super().__init__() + self.dtype = dtype + self.out_channels = in_channels + self.num_attention_heads = num_attention_heads + self.inner_dim = num_attention_heads * attention_head_dim + self.causal_temporal_positioning = causal_temporal_positioning + self.positional_embedding_theta = positional_embedding_theta + self.positional_embedding_max_pos = positional_embedding_max_pos + self.split_rope = split_rope + self.double_precision_rope = double_precision_rope + self.transformer_1d_blocks = nn.ModuleList( + [ + BasicTransformerBlock1D( + self.inner_dim, + num_attention_heads, + attention_head_dim, + context_dim=cross_attention_dim, + dtype=dtype, + device=device, + operations=operations, + ) + for _ in range(num_layers) + ] + ) + + inner_dim = num_attention_heads * attention_head_dim + self.num_learnable_registers = num_learnable_registers + if self.num_learnable_registers: + self.learnable_registers = nn.Parameter( + torch.rand( + self.num_learnable_registers, inner_dim, dtype=dtype, device=device + ) + * 2.0 + - 1.0 + ) + + def get_fractional_positions(self, indices_grid): + fractional_positions = torch.stack( + [ + indices_grid[:, i] / self.positional_embedding_max_pos[i] + for i in range(1) + ], + dim=-1, + ) + return fractional_positions + + def precompute_freqs(self, indices_grid, spacing): + source_dtype = indices_grid.dtype + dtype = ( + torch.float32 + if source_dtype in (torch.bfloat16, torch.float16) + else source_dtype + ) + + fractional_positions = self.get_fractional_positions(indices_grid) + indices = ( + generate_freq_grid_np( + self.positional_embedding_theta, + indices_grid.shape[1], + self.inner_dim, + ) + if self.double_precision_rope + else self.generate_freq_grid(spacing, dtype, fractional_positions.device) + ).to(device=fractional_positions.device) + + if spacing == "exp_2": + freqs = ( + (indices * fractional_positions.unsqueeze(-1)) + .transpose(-1, -2) + .flatten(2) + ) + else: + freqs = ( + (indices * (fractional_positions.unsqueeze(-1) * 2 - 1)) + .transpose(-1, -2) + .flatten(2) + ) + return freqs + + def generate_freq_grid(self, spacing, dtype, device): + dim = self.inner_dim + theta = self.positional_embedding_theta + n_pos_dims = 1 + n_elem = 2 * n_pos_dims # 2 for cos and sin e.g. 
x 3 = 6 + start = 1 + end = theta + + if spacing == "exp": + indices = theta ** (torch.arange(0, dim, n_elem, device="cpu", dtype=torch.float32) / (dim - n_elem)) + indices = indices.to(dtype=dtype, device=device) + elif spacing == "exp_2": + indices = 1.0 / theta ** (torch.arange(0, dim, n_elem, device=device) / dim) + indices = indices.to(dtype=dtype) + elif spacing == "linear": + indices = torch.linspace( + start, end, dim // n_elem, device=device, dtype=dtype + ) + elif spacing == "sqrt": + indices = torch.linspace( + start**2, end**2, dim // n_elem, device=device, dtype=dtype + ).sqrt() + + indices = indices * math.pi / 2 + + return indices + + def precompute_freqs_cis(self, indices_grid, spacing="exp"): + dim = self.inner_dim + n_elem = 2 # 2 because of cos and sin + freqs = self.precompute_freqs(indices_grid, spacing) + if self.split_rope: + expected_freqs = dim // 2 + current_freqs = freqs.shape[-1] + pad_size = expected_freqs - current_freqs + cos_freq, sin_freq = split_freqs_cis( + freqs, pad_size, self.num_attention_heads + ) + else: + cos_freq, sin_freq = interleaved_freqs_cis(freqs, dim % n_elem) + return cos_freq.to(self.dtype), sin_freq.to(self.dtype), self.split_rope + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + ): + """ + The [`Transformer2DModel`] forward method. + + Args: + hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous): + Input `hidden_states`. + indices_grid (`torch.LongTensor` of shape `(batch size, 3, num latent pixels)`): + attention_mask ( `torch.Tensor`, *optional*): + An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask + is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large + negative values to the attention scores corresponding to "discard" tokens. + Returns: + If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a + `tuple` where the first element is the sample tensor. + """ + # 1. Input + + if self.num_learnable_registers: + num_registers_duplications = math.ceil( + max(1024, hidden_states.shape[1]) / self.num_learnable_registers + ) + learnable_registers = torch.tile( + self.learnable_registers, (num_registers_duplications, 1) + ) + + hidden_states = torch.cat((hidden_states, learnable_registers[hidden_states.shape[1]:].unsqueeze(0).repeat(hidden_states.shape[0], 1, 1)), dim=1) + + if attention_mask is not None: + attention_mask = torch.zeros([1, 1, 1, hidden_states.shape[1]], dtype=attention_mask.dtype, device=attention_mask.device) + + indices_grid = torch.arange( + hidden_states.shape[1], dtype=torch.float32, device=hidden_states.device + ) + indices_grid = indices_grid[None, None, :] + freqs_cis = self.precompute_freqs_cis(indices_grid) + + # 2. Blocks + for block_idx, block in enumerate(self.transformer_1d_blocks): + hidden_states = block( + hidden_states, attention_mask=attention_mask, pe=freqs_cis + ) + + # 3. 
Output + # if self.output_scale is not None: + # hidden_states = hidden_states / self.output_scale + + hidden_states = comfy.ldm.common_dit.rms_norm(hidden_states) + + return hidden_states, attention_mask diff --git a/comfy/ldm/lightricks/latent_upsampler.py b/comfy/ldm/lightricks/latent_upsampler.py new file mode 100644 index 000000000..78ed7653f --- /dev/null +++ b/comfy/ldm/lightricks/latent_upsampler.py @@ -0,0 +1,292 @@ +from typing import Optional, Tuple +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange + + +def _rational_for_scale(scale: float) -> Tuple[int, int]: + mapping = {0.75: (3, 4), 1.5: (3, 2), 2.0: (2, 1), 4.0: (4, 1)} + if float(scale) not in mapping: + raise ValueError( + f"Unsupported spatial_scale {scale}. Choose from {list(mapping.keys())}" + ) + return mapping[float(scale)] + + +class PixelShuffleND(nn.Module): + def __init__(self, dims, upscale_factors=(2, 2, 2)): + super().__init__() + assert dims in [1, 2, 3], "dims must be 1, 2, or 3" + self.dims = dims + self.upscale_factors = upscale_factors + + def forward(self, x): + if self.dims == 3: + return rearrange( + x, + "b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)", + p1=self.upscale_factors[0], + p2=self.upscale_factors[1], + p3=self.upscale_factors[2], + ) + elif self.dims == 2: + return rearrange( + x, + "b (c p1 p2) h w -> b c (h p1) (w p2)", + p1=self.upscale_factors[0], + p2=self.upscale_factors[1], + ) + elif self.dims == 1: + return rearrange( + x, + "b (c p1) f h w -> b c (f p1) h w", + p1=self.upscale_factors[0], + ) + + +class BlurDownsample(nn.Module): + """ + Anti-aliased spatial downsampling by integer stride using a fixed separable binomial kernel. + Applies only on H,W. Works for dims=2 or dims=3 (per-frame). + """ + + def __init__(self, dims: int, stride: int): + super().__init__() + assert dims in (2, 3) + assert stride >= 1 and isinstance(stride, int) + self.dims = dims + self.stride = stride + + # 5x5 separable binomial kernel [1,4,6,4,1] (outer product), normalized + k = torch.tensor([1.0, 4.0, 6.0, 4.0, 1.0]) + k2d = k[:, None] @ k[None, :] + k2d = (k2d / k2d.sum()).float() # shape (5,5) + self.register_buffer("kernel", k2d[None, None, :, :]) # (1,1,5,5) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.stride == 1: + return x + + def _apply_2d(x2d: torch.Tensor) -> torch.Tensor: + # x2d: (B, C, H, W) + B, C, H, W = x2d.shape + weight = self.kernel.expand(C, 1, 5, 5) # depthwise + x2d = F.conv2d( + x2d, weight=weight, bias=None, stride=self.stride, padding=2, groups=C + ) + return x2d + + if self.dims == 2: + return _apply_2d(x) + else: + # dims == 3: apply per-frame on H,W + b, c, f, h, w = x.shape + x = rearrange(x, "b c f h w -> (b f) c h w") + x = _apply_2d(x) + h2, w2 = x.shape[-2:] + x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f, h=h2, w=w2) + return x + + +class SpatialRationalResampler(nn.Module): + """ + Fully-learned rational spatial scaling: up by 'num' via PixelShuffle, then anti-aliased + downsample by 'den' using fixed blur + stride. Operates on H,W only. + + For dims==3, work per-frame for spatial scaling (temporal axis untouched). 
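+    For example, scale=1.5 maps to (num, den) = (3, 2): a 3x PixelShuffle upsample of H,W followed
+    by an anti-aliased stride-2 blur downsample, giving a net 1.5x spatial scale.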
+ """ + + def __init__(self, mid_channels: int, scale: float): + super().__init__() + self.scale = float(scale) + self.num, self.den = _rational_for_scale(self.scale) + self.conv = nn.Conv2d( + mid_channels, (self.num**2) * mid_channels, kernel_size=3, padding=1 + ) + self.pixel_shuffle = PixelShuffleND(2, upscale_factors=(self.num, self.num)) + self.blur_down = BlurDownsample(dims=2, stride=self.den) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + b, c, f, h, w = x.shape + x = rearrange(x, "b c f h w -> (b f) c h w") + x = self.conv(x) + x = self.pixel_shuffle(x) + x = self.blur_down(x) + x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f) + return x + + +class ResBlock(nn.Module): + def __init__( + self, channels: int, mid_channels: Optional[int] = None, dims: int = 3 + ): + super().__init__() + if mid_channels is None: + mid_channels = channels + + Conv = nn.Conv2d if dims == 2 else nn.Conv3d + + self.conv1 = Conv(channels, mid_channels, kernel_size=3, padding=1) + self.norm1 = nn.GroupNorm(32, mid_channels) + self.conv2 = Conv(mid_channels, channels, kernel_size=3, padding=1) + self.norm2 = nn.GroupNorm(32, channels) + self.activation = nn.SiLU() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + residual = x + x = self.conv1(x) + x = self.norm1(x) + x = self.activation(x) + x = self.conv2(x) + x = self.norm2(x) + x = self.activation(x + residual) + return x + + +class LatentUpsampler(nn.Module): + """ + Model to spatially upsample VAE latents. + + Args: + in_channels (`int`): Number of channels in the input latent + mid_channels (`int`): Number of channels in the middle layers + num_blocks_per_stage (`int`): Number of ResBlocks to use in each stage (pre/post upsampling) + dims (`int`): Number of dimensions for convolutions (2 or 3) + spatial_upsample (`bool`): Whether to spatially upsample the latent + temporal_upsample (`bool`): Whether to temporally upsample the latent + """ + + def __init__( + self, + in_channels: int = 128, + mid_channels: int = 512, + num_blocks_per_stage: int = 4, + dims: int = 3, + spatial_upsample: bool = True, + temporal_upsample: bool = False, + spatial_scale: float = 2.0, + rational_resampler: bool = False, + ): + super().__init__() + + self.in_channels = in_channels + self.mid_channels = mid_channels + self.num_blocks_per_stage = num_blocks_per_stage + self.dims = dims + self.spatial_upsample = spatial_upsample + self.temporal_upsample = temporal_upsample + self.spatial_scale = float(spatial_scale) + self.rational_resampler = rational_resampler + + Conv = nn.Conv2d if dims == 2 else nn.Conv3d + + self.initial_conv = Conv(in_channels, mid_channels, kernel_size=3, padding=1) + self.initial_norm = nn.GroupNorm(32, mid_channels) + self.initial_activation = nn.SiLU() + + self.res_blocks = nn.ModuleList( + [ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)] + ) + + if spatial_upsample and temporal_upsample: + self.upsampler = nn.Sequential( + nn.Conv3d(mid_channels, 8 * mid_channels, kernel_size=3, padding=1), + PixelShuffleND(3), + ) + elif spatial_upsample: + if rational_resampler: + self.upsampler = SpatialRationalResampler( + mid_channels=mid_channels, scale=self.spatial_scale + ) + else: + self.upsampler = nn.Sequential( + nn.Conv2d(mid_channels, 4 * mid_channels, kernel_size=3, padding=1), + PixelShuffleND(2), + ) + elif temporal_upsample: + self.upsampler = nn.Sequential( + nn.Conv3d(mid_channels, 2 * mid_channels, kernel_size=3, padding=1), + PixelShuffleND(1), + ) + else: + raise ValueError( + "Either 
spatial_upsample or temporal_upsample must be True" + ) + + self.post_upsample_res_blocks = nn.ModuleList( + [ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)] + ) + + self.final_conv = Conv(mid_channels, in_channels, kernel_size=3, padding=1) + + def forward(self, latent: torch.Tensor) -> torch.Tensor: + b, c, f, h, w = latent.shape + + if self.dims == 2: + x = rearrange(latent, "b c f h w -> (b f) c h w") + x = self.initial_conv(x) + x = self.initial_norm(x) + x = self.initial_activation(x) + + for block in self.res_blocks: + x = block(x) + + x = self.upsampler(x) + + for block in self.post_upsample_res_blocks: + x = block(x) + + x = self.final_conv(x) + x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f) + else: + x = self.initial_conv(latent) + x = self.initial_norm(x) + x = self.initial_activation(x) + + for block in self.res_blocks: + x = block(x) + + if self.temporal_upsample: + x = self.upsampler(x) + x = x[:, :, 1:, :, :] + else: + if isinstance(self.upsampler, SpatialRationalResampler): + x = self.upsampler(x) + else: + x = rearrange(x, "b c f h w -> (b f) c h w") + x = self.upsampler(x) + x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f) + + for block in self.post_upsample_res_blocks: + x = block(x) + + x = self.final_conv(x) + + return x + + @classmethod + def from_config(cls, config): + return cls( + in_channels=config.get("in_channels", 4), + mid_channels=config.get("mid_channels", 128), + num_blocks_per_stage=config.get("num_blocks_per_stage", 4), + dims=config.get("dims", 2), + spatial_upsample=config.get("spatial_upsample", True), + temporal_upsample=config.get("temporal_upsample", False), + spatial_scale=config.get("spatial_scale", 2.0), + rational_resampler=config.get("rational_resampler", False), + ) + + def config(self): + return { + "_class_name": "LatentUpsampler", + "in_channels": self.in_channels, + "mid_channels": self.mid_channels, + "num_blocks_per_stage": self.num_blocks_per_stage, + "dims": self.dims, + "spatial_upsample": self.spatial_upsample, + "temporal_upsample": self.temporal_upsample, + "spatial_scale": self.spatial_scale, + "rational_resampler": self.rational_resampler, + } diff --git a/comfy/ldm/lightricks/model.py b/comfy/ldm/lightricks/model.py index 593f7940f..d61e19d6e 100644 --- a/comfy/ldm/lightricks/model.py +++ b/comfy/ldm/lightricks/model.py @@ -1,13 +1,47 @@ +from abc import ABC, abstractmethod +from enum import Enum +import functools +import math +from typing import Dict, Optional, Tuple + +from einops import rearrange +import numpy as np import torch from torch import nn import comfy.patcher_extension import comfy.ldm.modules.attention import comfy.ldm.common_dit -import math -from typing import Dict, Optional, Tuple from .symmetric_patchifier import SymmetricPatchifier, latent_to_pixel_coords -from comfy.ldm.flux.math import apply_rope1 + +def _log_base(x, base): + return np.log(x) / np.log(base) + +class LTXRopeType(str, Enum): + INTERLEAVED = "interleaved" + SPLIT = "split" + + KEY = "rope_type" + + @classmethod + def from_dict(cls, kwargs, default=None): + if default is None: + default = cls.INTERLEAVED + return cls(kwargs.get(cls.KEY, default)) + + +class LTXFrequenciesPrecision(str, Enum): + FLOAT32 = "float32" + FLOAT64 = "float64" + + KEY = "frequencies_precision" + + @classmethod + def from_dict(cls, kwargs, default=None): + if default is None: + default = cls.FLOAT32 + return cls(kwargs.get(cls.KEY, default)) + def get_timestep_embedding( timesteps: torch.Tensor, @@ -39,9 +73,7 @@ def 
get_timestep_embedding( assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array" half_dim = embedding_dim // 2 - exponent = -math.log(max_period) * torch.arange( - start=0, end=half_dim, dtype=torch.float32, device=timesteps.device - ) + exponent = -math.log(max_period) * torch.arange(start=0, end=half_dim, dtype=torch.float32, device=timesteps.device) exponent = exponent / (half_dim - downscale_freq_shift) emb = torch.exp(exponent) @@ -73,7 +105,9 @@ class TimestepEmbedding(nn.Module): post_act_fn: Optional[str] = None, cond_proj_dim=None, sample_proj_bias=True, - dtype=None, device=None, operations=None, + dtype=None, + device=None, + operations=None, ): super().__init__() @@ -90,7 +124,9 @@ class TimestepEmbedding(nn.Module): time_embed_dim_out = out_dim else: time_embed_dim_out = time_embed_dim - self.linear_2 = operations.Linear(time_embed_dim, time_embed_dim_out, sample_proj_bias, dtype=dtype, device=device) + self.linear_2 = operations.Linear( + time_embed_dim, time_embed_dim_out, sample_proj_bias, dtype=dtype, device=device + ) if post_act_fn is None: self.post_act = None @@ -139,12 +175,22 @@ class PixArtAlphaCombinedTimestepSizeEmbeddings(nn.Module): https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L164C9-L168C29 """ - def __init__(self, embedding_dim, size_emb_dim, use_additional_conditions: bool = False, dtype=None, device=None, operations=None): + def __init__( + self, + embedding_dim, + size_emb_dim, + use_additional_conditions: bool = False, + dtype=None, + device=None, + operations=None, + ): super().__init__() self.outdim = size_emb_dim self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) - self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim, dtype=dtype, device=device, operations=operations) + self.timestep_embedder = TimestepEmbedding( + in_channels=256, time_embed_dim=embedding_dim, dtype=dtype, device=device, operations=operations + ) def forward(self, timestep, resolution, aspect_ratio, batch_size, hidden_dtype): timesteps_proj = self.time_proj(timestep) @@ -163,15 +209,22 @@ class AdaLayerNormSingle(nn.Module): use_additional_conditions (`bool`): To use additional conditions for normalization or not. 
""" - def __init__(self, embedding_dim: int, use_additional_conditions: bool = False, dtype=None, device=None, operations=None): + def __init__( + self, embedding_dim: int, embedding_coefficient: int = 6, use_additional_conditions: bool = False, dtype=None, device=None, operations=None + ): super().__init__() self.emb = PixArtAlphaCombinedTimestepSizeEmbeddings( - embedding_dim, size_emb_dim=embedding_dim // 3, use_additional_conditions=use_additional_conditions, dtype=dtype, device=device, operations=operations + embedding_dim, + size_emb_dim=embedding_dim // 3, + use_additional_conditions=use_additional_conditions, + dtype=dtype, + device=device, + operations=operations, ) self.silu = nn.SiLU() - self.linear = operations.Linear(embedding_dim, 6 * embedding_dim, bias=True, dtype=dtype, device=device) + self.linear = operations.Linear(embedding_dim, embedding_coefficient * embedding_dim, bias=True, dtype=dtype, device=device) def forward( self, @@ -185,6 +238,7 @@ class AdaLayerNormSingle(nn.Module): embedded_timestep = self.emb(timestep, **added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_dtype) return self.linear(self.silu(embedded_timestep)), embedded_timestep + class PixArtAlphaTextProjection(nn.Module): """ Projects caption embeddings. Also handles dropout for classifier-free guidance. @@ -192,18 +246,24 @@ class PixArtAlphaTextProjection(nn.Module): Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py """ - def __init__(self, in_features, hidden_size, out_features=None, act_fn="gelu_tanh", dtype=None, device=None, operations=None): + def __init__( + self, in_features, hidden_size, out_features=None, act_fn="gelu_tanh", dtype=None, device=None, operations=None + ): super().__init__() if out_features is None: out_features = hidden_size - self.linear_1 = operations.Linear(in_features=in_features, out_features=hidden_size, bias=True, dtype=dtype, device=device) + self.linear_1 = operations.Linear( + in_features=in_features, out_features=hidden_size, bias=True, dtype=dtype, device=device + ) if act_fn == "gelu_tanh": self.act_1 = nn.GELU(approximate="tanh") elif act_fn == "silu": self.act_1 = nn.SiLU() else: raise ValueError(f"Unknown activation function: {act_fn}") - self.linear_2 = operations.Linear(in_features=hidden_size, out_features=out_features, bias=True, dtype=dtype, device=device) + self.linear_2 = operations.Linear( + in_features=hidden_size, out_features=out_features, bias=True, dtype=dtype, device=device + ) def forward(self, caption): hidden_states = self.linear_1(caption) @@ -222,23 +282,68 @@ class GELU_approx(nn.Module): class FeedForward(nn.Module): - def __init__(self, dim, dim_out, mult=4, glu=False, dropout=0., dtype=None, device=None, operations=None): + def __init__(self, dim, dim_out, mult=4, glu=False, dropout=0.0, dtype=None, device=None, operations=None): super().__init__() inner_dim = int(dim * mult) project_in = GELU_approx(dim, inner_dim, dtype=dtype, device=device, operations=operations) self.net = nn.Sequential( - project_in, - nn.Dropout(dropout), - operations.Linear(inner_dim, dim_out, dtype=dtype, device=device) + project_in, nn.Dropout(dropout), operations.Linear(inner_dim, dim_out, dtype=dtype, device=device) ) def forward(self, x): return self.net(x) +def apply_rotary_emb(input_tensor, freqs_cis): + cos_freqs, sin_freqs = freqs_cis[0], freqs_cis[1] + split_pe = freqs_cis[2] if len(freqs_cis) > 2 else False + return ( + apply_split_rotary_emb(input_tensor, cos_freqs, sin_freqs) + if 
split_pe else + apply_interleaved_rotary_emb(input_tensor, cos_freqs, sin_freqs) + ) + +def apply_interleaved_rotary_emb(input_tensor, cos_freqs, sin_freqs): # TODO: remove duplicate funcs and pick the best/fastest one + t_dup = rearrange(input_tensor, "... (d r) -> ... d r", r=2) + t1, t2 = t_dup.unbind(dim=-1) + t_dup = torch.stack((-t2, t1), dim=-1) + input_tensor_rot = rearrange(t_dup, "... d r -> ... (d r)") + + out = input_tensor * cos_freqs + input_tensor_rot * sin_freqs + + return out + +def apply_split_rotary_emb(input_tensor, cos, sin): + needs_reshape = False + if input_tensor.ndim != 4 and cos.ndim == 4: + B, H, T, _ = cos.shape + input_tensor = input_tensor.reshape(B, T, H, -1).swapaxes(1, 2) + needs_reshape = True + split_input = rearrange(input_tensor, "... (d r) -> ... d r", d=2) + first_half_input = split_input[..., :1, :] + second_half_input = split_input[..., 1:, :] + output = split_input * cos.unsqueeze(-2) + first_half_output = output[..., :1, :] + second_half_output = output[..., 1:, :] + first_half_output.addcmul_(-sin.unsqueeze(-2), second_half_input) + second_half_output.addcmul_(sin.unsqueeze(-2), first_half_input) + output = rearrange(output, "... d r -> ... (d r)") + return output.swapaxes(1, 2).reshape(B, T, -1) if needs_reshape else output + class CrossAttention(nn.Module): - def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., attn_precision=None, dtype=None, device=None, operations=None): + def __init__( + self, + query_dim, + context_dim=None, + heads=8, + dim_head=64, + dropout=0.0, + attn_precision=None, + dtype=None, + device=None, + operations=None, + ): super().__init__() inner_dim = dim_head * heads context_dim = query_dim if context_dim is None else context_dim @@ -254,9 +359,11 @@ class CrossAttention(nn.Module): self.to_k = operations.Linear(context_dim, inner_dim, bias=True, dtype=dtype, device=device) self.to_v = operations.Linear(context_dim, inner_dim, bias=True, dtype=dtype, device=device) - self.to_out = nn.Sequential(operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout)) + self.to_out = nn.Sequential( + operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout) + ) - def forward(self, x, context=None, mask=None, pe=None, transformer_options={}): + def forward(self, x, context=None, mask=None, pe=None, k_pe=None, transformer_options={}): q = self.to_q(x) context = x if context is None else context k = self.to_k(context) @@ -266,8 +373,8 @@ class CrossAttention(nn.Module): k = self.k_norm(k) if pe is not None: - q = apply_rope1(q.unsqueeze(1), pe).squeeze(1) - k = apply_rope1(k.unsqueeze(1), pe).squeeze(1) + q = apply_rotary_emb(q, pe) + k = apply_rotary_emb(k, pe if k_pe is None else k_pe) if mask is None: out = comfy.ldm.modules.attention.optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision, transformer_options=transformer_options) @@ -277,14 +384,34 @@ class CrossAttention(nn.Module): class BasicTransformerBlock(nn.Module): - def __init__(self, dim, n_heads, d_head, context_dim=None, attn_precision=None, dtype=None, device=None, operations=None): + def __init__( + self, dim, n_heads, d_head, context_dim=None, attn_precision=None, dtype=None, device=None, operations=None + ): super().__init__() self.attn_precision = attn_precision - self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, context_dim=None, attn_precision=self.attn_precision, dtype=dtype, device=device, operations=operations) + self.attn1 = 
CrossAttention( + query_dim=dim, + heads=n_heads, + dim_head=d_head, + context_dim=None, + attn_precision=self.attn_precision, + dtype=dtype, + device=device, + operations=operations, + ) self.ff = FeedForward(dim, dim_out=dim, glu=True, dtype=dtype, device=device, operations=operations) - self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim, heads=n_heads, dim_head=d_head, attn_precision=self.attn_precision, dtype=dtype, device=device, operations=operations) + self.attn2 = CrossAttention( + query_dim=dim, + context_dim=context_dim, + heads=n_heads, + dim_head=d_head, + attn_precision=self.attn_precision, + dtype=dtype, + device=device, + operations=operations, + ) self.scale_shift_table = nn.Parameter(torch.empty(6, dim, device=device, dtype=dtype)) @@ -306,116 +433,446 @@ class BasicTransformerBlock(nn.Module): return x def get_fractional_positions(indices_grid, max_pos): + n_pos_dims = indices_grid.shape[1] + assert n_pos_dims == len(max_pos), f'Number of position dimensions ({n_pos_dims}) must match max_pos length ({len(max_pos)})' fractional_positions = torch.stack( - [ - indices_grid[:, i] / max_pos[i] - for i in range(3) - ], - dim=-1, + [indices_grid[:, i] / max_pos[i] for i in range(n_pos_dims)], + axis=-1, ) return fractional_positions -def precompute_freqs_cis(indices_grid, dim, out_dtype, theta=10000.0, max_pos=[20, 2048, 2048]): - dtype = torch.float32 - device = indices_grid.device +@functools.lru_cache(maxsize=5) +def generate_freq_grid_np(positional_embedding_theta, positional_embedding_max_pos_count, inner_dim, _ = None): + theta = positional_embedding_theta + start = 1 + end = theta + + n_elem = 2 * positional_embedding_max_pos_count + pow_indices = np.power( + theta, + np.linspace( + _log_base(start, theta), + _log_base(end, theta), + inner_dim // n_elem, + dtype=np.float64, + ), + ) + return torch.tensor(pow_indices * math.pi / 2, dtype=torch.float32) + +def generate_freq_grid_pytorch(positional_embedding_theta, positional_embedding_max_pos_count, inner_dim, device): + theta = positional_embedding_theta + start = 1 + end = theta + n_elem = 2 * positional_embedding_max_pos_count + + indices = theta ** ( + torch.linspace( + math.log(start, theta), + math.log(end, theta), + inner_dim // n_elem, + device=device, + dtype=torch.float32, + ) + ) + indices = indices.to(dtype=torch.float32) + + indices = indices * math.pi / 2 + + return indices + +def generate_freqs(indices, indices_grid, max_pos, use_middle_indices_grid): + if use_middle_indices_grid: + assert(len(indices_grid.shape) == 4 and indices_grid.shape[-1] ==2) + indices_grid_start, indices_grid_end = indices_grid[..., 0], indices_grid[..., 1] + indices_grid = (indices_grid_start + indices_grid_end) / 2.0 + elif len(indices_grid.shape) == 4: + indices_grid = indices_grid[..., 0] # Get fractional positions and compute frequency indices fractional_positions = get_fractional_positions(indices_grid, max_pos) - indices = theta ** torch.linspace(0, 1, dim // 6, device=device, dtype=dtype) * math.pi / 2 + indices = indices.to(device=fractional_positions.device) - # Compute frequencies and apply cos/sin - freqs = (indices * (fractional_positions.unsqueeze(-1) * 2 - 1)).transpose(-1, -2).flatten(2) - cos_vals = freqs.cos().repeat_interleave(2, dim=-1) - sin_vals = freqs.sin().repeat_interleave(2, dim=-1) + freqs = ( + (indices * (fractional_positions.unsqueeze(-1) * 2 - 1)) + .transpose(-1, -2) + .flatten(2) + ) + return freqs - # Pad if dim is not divisible by 6 - if dim % 6 != 0: - padding_size = dim % 6 - 
cos_vals = torch.cat([torch.ones_like(cos_vals[:, :, :padding_size]), cos_vals], dim=-1) - sin_vals = torch.cat([torch.zeros_like(sin_vals[:, :, :padding_size]), sin_vals], dim=-1) +def interleaved_freqs_cis(freqs, pad_size): + cos_freq = freqs.cos().repeat_interleave(2, dim=-1) + sin_freq = freqs.sin().repeat_interleave(2, dim=-1) + if pad_size != 0: + cos_padding = torch.ones_like(cos_freq[:, :, : pad_size]) + sin_padding = torch.zeros_like(cos_freq[:, :, : pad_size]) + cos_freq = torch.cat([cos_padding, cos_freq], dim=-1) + sin_freq = torch.cat([sin_padding, sin_freq], dim=-1) + return cos_freq, sin_freq - # Reshape and extract one value per pair (since repeat_interleave duplicates each value) - cos_vals = cos_vals.reshape(*cos_vals.shape[:2], -1, 2)[..., 0].to(out_dtype) # [B, N, dim//2] - sin_vals = sin_vals.reshape(*sin_vals.shape[:2], -1, 2)[..., 0].to(out_dtype) # [B, N, dim//2] +def split_freqs_cis(freqs, pad_size, num_attention_heads): + cos_freq = freqs.cos() + sin_freq = freqs.sin() - # Build rotation matrix [[cos, -sin], [sin, cos]] and add heads dimension - freqs_cis = torch.stack([ - torch.stack([cos_vals, -sin_vals], dim=-1), - torch.stack([sin_vals, cos_vals], dim=-1) - ], dim=-2).unsqueeze(1) # [B, 1, N, dim//2, 2, 2] + if pad_size != 0: + cos_padding = torch.ones_like(cos_freq[:, :, :pad_size]) + sin_padding = torch.zeros_like(sin_freq[:, :, :pad_size]) - return freqs_cis + cos_freq = torch.concatenate([cos_padding, cos_freq], axis=-1) + sin_freq = torch.concatenate([sin_padding, sin_freq], axis=-1) + # Reshape freqs to be compatible with multi-head attention + B , T, half_HD = cos_freq.shape -class LTXVModel(torch.nn.Module): - def __init__(self, - in_channels=128, - cross_attention_dim=2048, - attention_head_dim=64, - num_attention_heads=32, + cos_freq = cos_freq.reshape(B, T, num_attention_heads, half_HD // num_attention_heads) + sin_freq = sin_freq.reshape(B, T, num_attention_heads, half_HD // num_attention_heads) - caption_channels=4096, - num_layers=28, + cos_freq = torch.swapaxes(cos_freq, 1, 2) # (B,H,T,D//2) + sin_freq = torch.swapaxes(sin_freq, 1, 2) # (B,H,T,D//2) + return cos_freq, sin_freq +class LTXBaseModel(torch.nn.Module, ABC): + """ + Abstract base class for LTX models (Lightricks Transformer models). - positional_embedding_theta=10000.0, - positional_embedding_max_pos=[20, 2048, 2048], - causal_temporal_positioning=False, - vae_scale_factors=(8, 32, 32), - dtype=None, device=None, operations=None, **kwargs): + This class defines the common interface and shared functionality for all LTX models, + including LTXV (video) and LTXAV (audio-video) variants. 
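A hypothetical subclass only has to fill in the abstract hooks below; construction and the forward pipeline are handled by this base class. The following skeleton is an editor's sketch with placeholder bodies, not the real LTXV or LTXAV logic:

import torch

class MyLTXVariant(LTXBaseModel):
    # Editor's sketch of the template-method contract; bodies are placeholders only.
    def _init_model_components(self, device, dtype, **kwargs):
        pass  # variant-specific embedders / projections

    def _init_transformer_blocks(self, device, dtype, **kwargs):
        self.transformer_blocks = torch.nn.ModuleList()

    def _init_output_components(self, device, dtype):
        pass  # output norm / projection layers

    def _process_input(self, x, keyframe_idxs, denoise_mask, **kwargs):
        pixel_coords = None  # real variants return patch coordinates here
        return x, pixel_coords, {}

    def _process_transformer_blocks(self, x, context, attention_mask, timestep, pe, **kwargs):
        return x  # run the transformer blocks over the tokens

    def _process_output(self, x, embedded_timestep, keyframe_idxs, **kwargs):
        return x  # modulate, project, unpatchify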
+ """ + + def __init__( + self, + in_channels: int, + cross_attention_dim: int, + attention_head_dim: int, + num_attention_heads: int, + caption_channels: int, + num_layers: int, + positional_embedding_theta: float = 10000.0, + positional_embedding_max_pos: list = [20, 2048, 2048], + causal_temporal_positioning: bool = False, + vae_scale_factors: tuple = (8, 32, 32), + use_middle_indices_grid=False, + timestep_scale_multiplier = 1000.0, + dtype=None, + device=None, + operations=None, + **kwargs, + ): super().__init__() self.generator = None self.vae_scale_factors = vae_scale_factors + self.use_middle_indices_grid = use_middle_indices_grid self.dtype = dtype - self.out_channels = in_channels - self.inner_dim = num_attention_heads * attention_head_dim + self.in_channels = in_channels + self.cross_attention_dim = cross_attention_dim + self.attention_head_dim = attention_head_dim + self.num_attention_heads = num_attention_heads + self.caption_channels = caption_channels + self.num_layers = num_layers + self.positional_embedding_theta = positional_embedding_theta + self.positional_embedding_max_pos = positional_embedding_max_pos + self.split_positional_embedding = LTXRopeType.from_dict(kwargs) + self.freq_grid_generator = ( + generate_freq_grid_np if LTXFrequenciesPrecision.from_dict(kwargs) == LTXFrequenciesPrecision.FLOAT64 + else generate_freq_grid_pytorch + ) self.causal_temporal_positioning = causal_temporal_positioning + self.operations = operations + self.timestep_scale_multiplier = timestep_scale_multiplier - self.patchify_proj = operations.Linear(in_channels, self.inner_dim, bias=True, dtype=dtype, device=device) + # Common dimensions + self.inner_dim = num_attention_heads * attention_head_dim + self.out_channels = in_channels + + # Initialize common components + self._init_common_components(device, dtype) + + # Initialize model-specific components + self._init_model_components(device, dtype, **kwargs) + + # Initialize transformer blocks + self._init_transformer_blocks(device, dtype, **kwargs) + + # Initialize output components + self._init_output_components(device, dtype) + + def _init_common_components(self, device, dtype): + """Initialize components common to all LTX models + - patchify_proj: Linear projection for patchifying input + - adaln_single: AdaLN layer for timestep embedding + - caption_projection: Linear projection for caption embedding + """ + self.patchify_proj = self.operations.Linear( + self.in_channels, self.inner_dim, bias=True, dtype=dtype, device=device + ) self.adaln_single = AdaLayerNormSingle( - self.inner_dim, use_additional_conditions=False, dtype=dtype, device=device, operations=operations + self.inner_dim, use_additional_conditions=False, dtype=dtype, device=device, operations=self.operations ) - # self.adaln_single.linear = operations.Linear(self.inner_dim, 4 * self.inner_dim, bias=True, dtype=dtype, device=device) - self.caption_projection = PixArtAlphaTextProjection( - in_features=caption_channels, hidden_size=self.inner_dim, dtype=dtype, device=device, operations=operations + in_features=self.caption_channels, + hidden_size=self.inner_dim, + dtype=dtype, + device=device, + operations=self.operations, ) + @abstractmethod + def _init_model_components(self, device, dtype, **kwargs): + """Initialize model-specific components. Must be implemented by subclasses.""" + pass + + @abstractmethod + def _init_transformer_blocks(self, device, dtype, **kwargs): + """Initialize transformer blocks. 
Must be implemented by subclasses.""" + pass + + @abstractmethod + def _init_output_components(self, device, dtype): + """Initialize output components. Must be implemented by subclasses.""" + pass + + @abstractmethod + def _process_input(self, x, keyframe_idxs, denoise_mask, **kwargs): + """Process input data. Must be implemented by subclasses.""" + pass + + @abstractmethod + def _process_transformer_blocks(self, x, context, attention_mask, timestep, pe, **kwargs): + """Process transformer blocks. Must be implemented by subclasses.""" + pass + + @abstractmethod + def _process_output(self, x, embedded_timestep, keyframe_idxs, **kwargs): + """Process output data. Must be implemented by subclasses.""" + pass + + def _prepare_timestep(self, timestep, batch_size, hidden_dtype, **kwargs): + """Prepare timestep embeddings.""" + grid_mask = kwargs.get("grid_mask", None) + if grid_mask is not None: + timestep = timestep[:, grid_mask] + + timestep = timestep * self.timestep_scale_multiplier + timestep, embedded_timestep = self.adaln_single( + timestep.flatten(), + {"resolution": None, "aspect_ratio": None}, + batch_size=batch_size, + hidden_dtype=hidden_dtype, + ) + + # Second dimension is 1 or number of tokens (if timestep_per_token) + timestep = timestep.view(batch_size, -1, timestep.shape[-1]) + embedded_timestep = embedded_timestep.view(batch_size, -1, embedded_timestep.shape[-1]) + + return timestep, embedded_timestep + + def _prepare_context(self, context, batch_size, x, attention_mask=None): + """Prepare context for transformer blocks.""" + if self.caption_projection is not None: + context = self.caption_projection(context) + context = context.view(batch_size, -1, x.shape[-1]) + + return context, attention_mask + + def _precompute_freqs_cis( + self, + indices_grid, + dim, + out_dtype, + theta=10000.0, + max_pos=[20, 2048, 2048], + use_middle_indices_grid=False, + num_attention_heads=32, + ): + split_mode = self.split_positional_embedding == LTXRopeType.SPLIT + indices = self.freq_grid_generator(theta, indices_grid.shape[1], dim, indices_grid.device) + freqs = generate_freqs(indices, indices_grid, max_pos, use_middle_indices_grid) + + if split_mode: + expected_freqs = dim // 2 + current_freqs = freqs.shape[-1] + pad_size = expected_freqs - current_freqs + cos_freq, sin_freq = split_freqs_cis(freqs, pad_size, num_attention_heads) + else: + # 2 because of cos and sin by 3 for (t, x, y), 1 for temporal only + n_elem = 2 * indices_grid.shape[1] + cos_freq, sin_freq = interleaved_freqs_cis(freqs, dim % n_elem) + return cos_freq.to(out_dtype), sin_freq.to(out_dtype), split_mode + + def _prepare_positional_embeddings(self, pixel_coords, frame_rate, x_dtype): + """Prepare positional embeddings.""" + fractional_coords = pixel_coords.to(torch.float32) + fractional_coords[:, 0] = fractional_coords[:, 0] * (1.0 / frame_rate) + pe = self._precompute_freqs_cis( + fractional_coords, + dim=self.inner_dim, + out_dtype=x_dtype, + max_pos=self.positional_embedding_max_pos, + use_middle_indices_grid=self.use_middle_indices_grid, + num_attention_heads=self.num_attention_heads, + ) + return pe + + def _prepare_attention_mask(self, attention_mask, x_dtype): + """Prepare attention mask.""" + if attention_mask is not None and not torch.is_floating_point(attention_mask): + attention_mask = (attention_mask - 1).to(x_dtype).reshape( + (attention_mask.shape[0], 1, -1, attention_mask.shape[-1]) + ) * torch.finfo(x_dtype).max + return attention_mask + + def forward( + self, x, timestep, context, attention_mask, 
frame_rate=25, transformer_options={}, keyframe_idxs=None, denoise_mask=None, **kwargs + ): + """ + Forward pass for LTX models. + + Args: + x: Input tensor + timestep: Timestep tensor + context: Context tensor (e.g., text embeddings) + attention_mask: Attention mask tensor + frame_rate: Frame rate for temporal processing + transformer_options: Additional options for transformer blocks + keyframe_idxs: Keyframe indices for temporal processing + **kwargs: Additional keyword arguments + + Returns: + Processed output tensor + """ + return comfy.patcher_extension.WrapperExecutor.new_class_executor( + self._forward, + self, + comfy.patcher_extension.get_all_wrappers( + comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options + ), + ).execute(x, timestep, context, attention_mask, frame_rate, transformer_options, keyframe_idxs, denoise_mask=denoise_mask, **kwargs) + + def _forward( + self, x, timestep, context, attention_mask, frame_rate=25, transformer_options={}, keyframe_idxs=None, denoise_mask=None, **kwargs + ): + """ + Internal forward pass for LTX models. + + Args: + x: Input tensor + timestep: Timestep tensor + context: Context tensor (e.g., text embeddings) + attention_mask: Attention mask tensor + frame_rate: Frame rate for temporal processing + transformer_options: Additional options for transformer blocks + keyframe_idxs: Keyframe indices for temporal processing + **kwargs: Additional keyword arguments + + Returns: + Processed output tensor + """ + if isinstance(x, list): + input_dtype = x[0].dtype + batch_size = x[0].shape[0] + else: + input_dtype = x.dtype + batch_size = x.shape[0] + # Process input + merged_args = {**transformer_options, **kwargs} + x, pixel_coords, additional_args = self._process_input(x, keyframe_idxs, denoise_mask, **merged_args) + merged_args.update(additional_args) + + # Prepare timestep and context + timestep, embedded_timestep = self._prepare_timestep(timestep, batch_size, input_dtype, **merged_args) + context, attention_mask = self._prepare_context(context, batch_size, x, attention_mask) + + # Prepare attention mask and positional embeddings + attention_mask = self._prepare_attention_mask(attention_mask, input_dtype) + pe = self._prepare_positional_embeddings(pixel_coords, frame_rate, input_dtype) + + # Process transformer blocks + x = self._process_transformer_blocks( + x, context, attention_mask, timestep, pe, transformer_options=transformer_options, **merged_args + ) + + # Process output + x = self._process_output(x, embedded_timestep, keyframe_idxs, **merged_args) + return x + + +class LTXVModel(LTXBaseModel): + """LTXV model for video generation.""" + + def __init__( + self, + in_channels=128, + cross_attention_dim=2048, + attention_head_dim=64, + num_attention_heads=32, + caption_channels=4096, + num_layers=28, + positional_embedding_theta=10000.0, + positional_embedding_max_pos=[20, 2048, 2048], + causal_temporal_positioning=False, + vae_scale_factors=(8, 32, 32), + use_middle_indices_grid=False, + timestep_scale_multiplier = 1000.0, + dtype=None, + device=None, + operations=None, + **kwargs, + ): + super().__init__( + in_channels=in_channels, + cross_attention_dim=cross_attention_dim, + attention_head_dim=attention_head_dim, + num_attention_heads=num_attention_heads, + caption_channels=caption_channels, + num_layers=num_layers, + positional_embedding_theta=positional_embedding_theta, + positional_embedding_max_pos=positional_embedding_max_pos, + causal_temporal_positioning=causal_temporal_positioning, + 
vae_scale_factors=vae_scale_factors, + use_middle_indices_grid=use_middle_indices_grid, + timestep_scale_multiplier=timestep_scale_multiplier, + dtype=dtype, + device=device, + operations=operations, + **kwargs, + ) + + def _init_model_components(self, device, dtype, **kwargs): + """Initialize LTXV-specific components.""" + # No additional components needed for LTXV beyond base class + pass + + def _init_transformer_blocks(self, device, dtype, **kwargs): + """Initialize transformer blocks for LTXV.""" self.transformer_blocks = nn.ModuleList( [ BasicTransformerBlock( self.inner_dim, - num_attention_heads, - attention_head_dim, - context_dim=cross_attention_dim, - # attn_precision=attn_precision, - dtype=dtype, device=device, operations=operations + self.num_attention_heads, + self.attention_head_dim, + context_dim=self.cross_attention_dim, + dtype=dtype, + device=device, + operations=self.operations, ) - for d in range(num_layers) + for _ in range(self.num_layers) ] ) + def _init_output_components(self, device, dtype): + """Initialize output components for LTXV.""" self.scale_shift_table = nn.Parameter(torch.empty(2, self.inner_dim, dtype=dtype, device=device)) - self.norm_out = operations.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) - self.proj_out = operations.Linear(self.inner_dim, self.out_channels, dtype=dtype, device=device) - - self.patchifier = SymmetricPatchifier(1) - - def forward(self, x, timestep, context, attention_mask, frame_rate=25, transformer_options={}, keyframe_idxs=None, **kwargs): - return comfy.patcher_extension.WrapperExecutor.new_class_executor( - self._forward, - self, - comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options) - ).execute(x, timestep, context, attention_mask, frame_rate, transformer_options, keyframe_idxs, **kwargs) - - def _forward(self, x, timestep, context, attention_mask, frame_rate=25, transformer_options={}, keyframe_idxs=None, **kwargs): - patches_replace = transformer_options.get("patches_replace", {}) - - orig_shape = list(x.shape) + self.norm_out = self.operations.LayerNorm( + self.inner_dim, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device + ) + self.proj_out = self.operations.Linear(self.inner_dim, self.out_channels, dtype=dtype, device=device) + self.patchifier = SymmetricPatchifier(1, start_end=True) + def _process_input(self, x, keyframe_idxs, denoise_mask, **kwargs): + """Process input for LTXV.""" + additional_args = {"orig_shape": list(x.shape)} x, latent_coords = self.patchifier.patchify(x) pixel_coords = latent_to_pixel_coords( latent_coords=latent_coords, @@ -423,44 +880,30 @@ class LTXVModel(torch.nn.Module): causal_fix=self.causal_temporal_positioning, ) + grid_mask = None if keyframe_idxs is not None: - pixel_coords[:, :, -keyframe_idxs.shape[2]:] = keyframe_idxs + additional_args.update({ "orig_patchified_shape": list(x.shape)}) + denoise_mask = self.patchifier.patchify(denoise_mask)[0] + grid_mask = ~torch.any(denoise_mask < 0, dim=-1)[0] + additional_args.update({"grid_mask": grid_mask}) + x = x[:, grid_mask, :] + pixel_coords = pixel_coords[:, :, grid_mask, ...] 
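When keyframe conditioning is active, tokens whose denoise mask marks them as fixed are dropped before the transformer blocks, and `_process_output` later scatters the processed tokens back into the full patch grid. A standalone editor's sketch of that filter-then-scatter pattern (shapes and values are illustrative):

import torch

tokens = torch.randn(1, 6, 4)                      # (batch, num_patches, dim)
grid_mask = torch.tensor([True, True, False, True, False, True])

kept = tokens[:, grid_mask, :]                     # only the tokens that are actually denoised
processed = kept * 2.0                             # stand-in for the transformer blocks

restored = torch.zeros_like(tokens)
restored[:, grid_mask, :] = processed              # masked-out positions stay zero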
- fractional_coords = pixel_coords.to(torch.float32) - fractional_coords[:, 0] = fractional_coords[:, 0] * (1.0 / frame_rate) + kf_grid_mask = grid_mask[-keyframe_idxs.shape[2]:] + keyframe_idxs = keyframe_idxs[..., kf_grid_mask, :] + pixel_coords[:, :, -keyframe_idxs.shape[2]:, :] = keyframe_idxs x = self.patchify_proj(x) - timestep = timestep * 1000.0 - - if attention_mask is not None and not torch.is_floating_point(attention_mask): - attention_mask = (attention_mask - 1).to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])) * torch.finfo(x.dtype).max - - pe = precompute_freqs_cis(fractional_coords, dim=self.inner_dim, out_dtype=x.dtype) - - batch_size = x.shape[0] - timestep, embedded_timestep = self.adaln_single( - timestep.flatten(), - {"resolution": None, "aspect_ratio": None}, - batch_size=batch_size, - hidden_dtype=x.dtype, - ) - # Second dimension is 1 or number of tokens (if timestep_per_token) - timestep = timestep.view(batch_size, -1, timestep.shape[-1]) - embedded_timestep = embedded_timestep.view( - batch_size, -1, embedded_timestep.shape[-1] - ) - - # 2. Blocks - if self.caption_projection is not None: - batch_size = x.shape[0] - context = self.caption_projection(context) - context = context.view( - batch_size, -1, x.shape[-1] - ) + return x, pixel_coords, additional_args + def _process_transformer_blocks(self, x, context, attention_mask, timestep, pe, transformer_options={}, **kwargs): + """Process transformer blocks for LTXV.""" + patches_replace = transformer_options.get("patches_replace", {}) blocks_replace = patches_replace.get("dit", {}) + for i, block in enumerate(self.transformer_blocks): if ("double_block", i) in blocks_replace: + def block_wrap(args): out = {} out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"], transformer_options=args["transformer_options"]) @@ -478,16 +921,28 @@ class LTXVModel(torch.nn.Module): transformer_options=transformer_options, ) - # 3. Output + return x + + def _process_output(self, x, embedded_timestep, keyframe_idxs, **kwargs): + """Process output for LTXV.""" + # Apply scale-shift modulation scale_shift_values = ( self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + embedded_timestep[:, :, None] ) shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1] + x = self.norm_out(x) - # Modulation - x = torch.addcmul(x, x, scale).add_(shift) + x = x * (1 + scale) + shift x = self.proj_out(x) + if keyframe_idxs is not None: + grid_mask = kwargs["grid_mask"] + orig_patchified_shape = kwargs["orig_patchified_shape"] + full_x = torch.zeros(orig_patchified_shape, dtype=x.dtype, device=x.device) + full_x[:, grid_mask, :] = x + x = full_x + # Unpatchify to restore original dimensions + orig_shape = kwargs["orig_shape"] x = self.patchifier.unpatchify( latents=x, output_height=orig_shape[3], diff --git a/comfy/ldm/lightricks/symmetric_patchifier.py b/comfy/ldm/lightricks/symmetric_patchifier.py index 4b9972b9f..8f9a41186 100644 --- a/comfy/ldm/lightricks/symmetric_patchifier.py +++ b/comfy/ldm/lightricks/symmetric_patchifier.py @@ -21,20 +21,23 @@ def latent_to_pixel_coords( Returns: Tensor: A tensor of pixel coordinates corresponding to the input latent coordinates. 
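Example (editor's note, values illustrative): with `scale_factors=(8, 32, 32)` and `causal_fix=True`, latent frame indices [0, 1, 2] first scale to pixel frames [0, 8, 16]; the causal fix then shifts the temporal axis by `1 - 8` and clamps at zero, giving [0, 1, 9], while the spatial axes are simply multiplied by 32:

import torch

latent_frames = torch.tensor([0, 1, 2])
pixel_frames = latent_frames * 8                      # temporal scale factor -> tensor([0, 8, 16])
pixel_frames = (pixel_frames + 1 - 8).clamp(min=0)    # causal fix          -> tensor([0, 1, 9])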
""" + shape = [1] * latent_coords.ndim + shape[1] = -1 pixel_coords = ( latent_coords - * torch.tensor(scale_factors, device=latent_coords.device)[None, :, None] + * torch.tensor(scale_factors, device=latent_coords.device).view(*shape) ) if causal_fix: # Fix temporal scale for first frame to 1 due to causality - pixel_coords[:, 0] = (pixel_coords[:, 0] + 1 - scale_factors[0]).clamp(min=0) + pixel_coords[:, 0, ...] = (pixel_coords[:, 0, ...] + 1 - scale_factors[0]).clamp(min=0) return pixel_coords class Patchifier(ABC): - def __init__(self, patch_size: int): + def __init__(self, patch_size: int, start_end: bool=False): super().__init__() self._patch_size = (1, patch_size, patch_size) + self.start_end = start_end @abstractmethod def patchify( @@ -71,11 +74,23 @@ class Patchifier(ABC): torch.arange(0, latent_width, self._patch_size[2], device=device), indexing="ij", ) - latent_sample_coords = torch.stack(latent_sample_coords, dim=0) - latent_coords = latent_sample_coords.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1) - latent_coords = rearrange( - latent_coords, "b c f h w -> b c (f h w)", b=batch_size + latent_sample_coords_start = torch.stack(latent_sample_coords, dim=0) + delta = torch.tensor(self._patch_size, device=latent_sample_coords_start.device, dtype=latent_sample_coords_start.dtype)[:, None, None, None] + latent_sample_coords_end = latent_sample_coords_start + delta + + latent_sample_coords_start = latent_sample_coords_start.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1) + latent_sample_coords_start = rearrange( + latent_sample_coords_start, "b c f h w -> b c (f h w)", b=batch_size ) + if self.start_end: + latent_sample_coords_end = latent_sample_coords_end.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1) + latent_sample_coords_end = rearrange( + latent_sample_coords_end, "b c f h w -> b c (f h w)", b=batch_size + ) + + latent_coords = torch.stack((latent_sample_coords_start, latent_sample_coords_end), dim=-1) + else: + latent_coords = latent_sample_coords_start return latent_coords @@ -115,3 +130,61 @@ class SymmetricPatchifier(Patchifier): q=self._patch_size[2], ) return latents + + +class AudioPatchifier(Patchifier): + def __init__(self, patch_size: int, + sample_rate=16000, + hop_length=160, + audio_latent_downsample_factor=4, + is_causal=True, + start_end=False, + shift = 0 + ): + super().__init__(patch_size, start_end=start_end) + self.hop_length = hop_length + self.sample_rate = sample_rate + self.audio_latent_downsample_factor = audio_latent_downsample_factor + self.is_causal = is_causal + self.shift = shift + + def copy_with_shift(self, shift): + return AudioPatchifier( + self.patch_size, self.sample_rate, self.hop_length, self.audio_latent_downsample_factor, + self.is_causal, self.start_end, shift + ) + + def _get_audio_latent_time_in_sec(self, start_latent, end_latent: int, dtype: torch.dtype, device=torch.device): + audio_latent_frame = torch.arange(start_latent, end_latent, dtype=dtype, device=device) + audio_mel_frame = audio_latent_frame * self.audio_latent_downsample_factor + if self.is_causal: + audio_mel_frame = (audio_mel_frame + 1 - self.audio_latent_downsample_factor).clip(min=0) + return audio_mel_frame * self.hop_length / self.sample_rate + + + def patchify(self, audio_latents: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + # audio_latents: (batch, channels, time, freq) + b, _, t, _ = audio_latents.shape + audio_latents = rearrange( + audio_latents, + "b c t f -> b t (c f)", + ) + + audio_latents_start_timings = self._get_audio_latent_time_in_sec(self.shift, t + 
self.shift, torch.float32, audio_latents.device) + audio_latents_start_timings = audio_latents_start_timings.unsqueeze(0).expand(b, -1).unsqueeze(1) + + if self.start_end: + audio_latents_end_timings = self._get_audio_latent_time_in_sec(self.shift + 1, t + self.shift + 1, torch.float32, audio_latents.device) + audio_latents_end_timings = audio_latents_end_timings.unsqueeze(0).expand(b, -1).unsqueeze(1) + + audio_latents_timings = torch.stack([audio_latents_start_timings, audio_latents_end_timings], dim=-1) + else: + audio_latents_timings = audio_latents_start_timings + return audio_latents, audio_latents_timings + + def unpatchify(self, audio_latents: torch.Tensor, channels: int, freq: int) -> torch.Tensor: + # audio_latents: (batch, time, freq * channels) + audio_latents = rearrange( + audio_latents, "b t (c f) -> b c t f", c=channels, f=freq + ) + return audio_latents diff --git a/comfy/ldm/lightricks/vae/audio_vae.py b/comfy/ldm/lightricks/vae/audio_vae.py new file mode 100644 index 000000000..a9111d3bd --- /dev/null +++ b/comfy/ldm/lightricks/vae/audio_vae.py @@ -0,0 +1,286 @@ +import json +from dataclasses import dataclass +import math +import torch +import torchaudio + +import comfy.model_management +import comfy.model_patcher +import comfy.utils as utils +from comfy.ldm.mmaudio.vae.distributions import DiagonalGaussianDistribution +from comfy.ldm.lightricks.symmetric_patchifier import AudioPatchifier +from comfy.ldm.lightricks.vae.causal_audio_autoencoder import ( + CausalityAxis, + CausalAudioAutoencoder, +) +from comfy.ldm.lightricks.vocoders.vocoder import Vocoder + +LATENT_DOWNSAMPLE_FACTOR = 4 + + +@dataclass(frozen=True) +class AudioVAEComponentConfig: + """Container for model component configuration extracted from metadata.""" + + autoencoder: dict + vocoder: dict + + @classmethod + def from_metadata(cls, metadata: dict) -> "AudioVAEComponentConfig": + assert metadata is not None and "config" in metadata, "Metadata is required for audio VAE" + + raw_config = metadata["config"] + if isinstance(raw_config, str): + parsed_config = json.loads(raw_config) + else: + parsed_config = raw_config + + audio_config = parsed_config.get("audio_vae") + vocoder_config = parsed_config.get("vocoder") + + assert audio_config is not None, "Audio VAE config is required for audio VAE" + assert vocoder_config is not None, "Vocoder config is required for audio VAE" + + return cls(autoencoder=audio_config, vocoder=vocoder_config) + + +class ModelDeviceManager: + """Manages device placement and GPU residency for the composed model.""" + + def __init__(self, module: torch.nn.Module): + load_device = comfy.model_management.get_torch_device() + offload_device = comfy.model_management.vae_offload_device() + self.patcher = comfy.model_patcher.ModelPatcher(module, load_device, offload_device) + + def ensure_model_loaded(self) -> None: + comfy.model_management.free_memory( + self.patcher.model_size(), + self.patcher.load_device, + ) + comfy.model_management.load_model_gpu(self.patcher) + + def move_to_load_device(self, tensor: torch.Tensor) -> torch.Tensor: + return tensor.to(self.patcher.load_device) + + @property + def load_device(self): + return self.patcher.load_device + + +class AudioLatentNormalizer: + """Applies per-channel statistics in patch space and restores original layout.""" + + def __init__(self, patchfier: AudioPatchifier, statistics_processor: torch.nn.Module): + self.patchifier = patchfier + self.statistics = statistics_processor + + def normalize(self, latents: torch.Tensor) -> 
torch.Tensor: + channels = latents.shape[1] + freq = latents.shape[3] + patched, _ = self.patchifier.patchify(latents) + normalized = self.statistics.normalize(patched) + return self.patchifier.unpatchify(normalized, channels=channels, freq=freq) + + def denormalize(self, latents: torch.Tensor) -> torch.Tensor: + channels = latents.shape[1] + freq = latents.shape[3] + patched, _ = self.patchifier.patchify(latents) + denormalized = self.statistics.un_normalize(patched) + return self.patchifier.unpatchify(denormalized, channels=channels, freq=freq) + + +class AudioPreprocessor: + """Prepares raw waveforms for the autoencoder by matching training conditions.""" + + def __init__(self, target_sample_rate: int, mel_bins: int, mel_hop_length: int, n_fft: int): + self.target_sample_rate = target_sample_rate + self.mel_bins = mel_bins + self.mel_hop_length = mel_hop_length + self.n_fft = n_fft + + def resample(self, waveform: torch.Tensor, source_rate: int) -> torch.Tensor: + if source_rate == self.target_sample_rate: + return waveform + return torchaudio.functional.resample(waveform, source_rate, self.target_sample_rate) + + @staticmethod + def normalize_amplitude( + waveform: torch.Tensor, max_amplitude: float = 0.5, eps: float = 1e-5 + ) -> torch.Tensor: + waveform = waveform - waveform.mean(dim=2, keepdim=True) + peak = torch.max(torch.abs(waveform)) + eps + scale = peak.clamp(max=max_amplitude) / peak + return waveform * scale + + def waveform_to_mel( + self, waveform: torch.Tensor, waveform_sample_rate: int, device + ) -> torch.Tensor: + waveform = self.resample(waveform, waveform_sample_rate) + waveform = self.normalize_amplitude(waveform) + + mel_transform = torchaudio.transforms.MelSpectrogram( + sample_rate=self.target_sample_rate, + n_fft=self.n_fft, + win_length=self.n_fft, + hop_length=self.mel_hop_length, + f_min=0.0, + f_max=self.target_sample_rate / 2.0, + n_mels=self.mel_bins, + window_fn=torch.hann_window, + center=True, + pad_mode="reflect", + power=1.0, + mel_scale="slaney", + norm="slaney", + ).to(device) + + mel = mel_transform(waveform) + mel = torch.log(torch.clamp(mel, min=1e-5)) + return mel.permute(0, 1, 3, 2).contiguous() + + +class AudioVAE(torch.nn.Module): + """High-level Audio VAE wrapper exposing encode and decode entry points.""" + + def __init__(self, state_dict: dict, metadata: dict): + super().__init__() + + component_config = AudioVAEComponentConfig.from_metadata(metadata) + + vae_sd = utils.state_dict_prefix_replace(state_dict, {"audio_vae.": ""}, filter_keys=True) + vocoder_sd = utils.state_dict_prefix_replace(state_dict, {"vocoder.": ""}, filter_keys=True) + + self.autoencoder = CausalAudioAutoencoder(config=component_config.autoencoder) + self.vocoder = Vocoder(config=component_config.vocoder) + + self.autoencoder.load_state_dict(vae_sd, strict=False) + self.vocoder.load_state_dict(vocoder_sd, strict=False) + + autoencoder_config = self.autoencoder.get_config() + self.normalizer = AudioLatentNormalizer( + AudioPatchifier( + patch_size=1, + audio_latent_downsample_factor=LATENT_DOWNSAMPLE_FACTOR, + sample_rate=autoencoder_config["sampling_rate"], + hop_length=autoencoder_config["mel_hop_length"], + is_causal=autoencoder_config["is_causal"], + ), + self.autoencoder.per_channel_statistics, + ) + + self.preprocessor = AudioPreprocessor( + target_sample_rate=autoencoder_config["sampling_rate"], + mel_bins=autoencoder_config["mel_bins"], + mel_hop_length=autoencoder_config["mel_hop_length"], + n_fft=autoencoder_config["n_fft"], + ) + + self.device_manager = 
ModelDeviceManager(self) + + def encode(self, audio: dict) -> torch.Tensor: + """Encode a waveform dictionary into normalized latent tensors.""" + + waveform = audio["waveform"] + waveform_sample_rate = audio["sample_rate"] + input_device = waveform.device + # Ensure that Audio VAE is loaded on the correct device. + self.device_manager.ensure_model_loaded() + + waveform = self.device_manager.move_to_load_device(waveform) + expected_channels = self.autoencoder.encoder.in_channels + if waveform.shape[1] != expected_channels: + raise ValueError( + f"Input audio must have {expected_channels} channels, got {waveform.shape[1]}" + ) + + mel_spec = self.preprocessor.waveform_to_mel( + waveform, waveform_sample_rate, device=self.device_manager.load_device + ) + + latents = self.autoencoder.encode(mel_spec) + posterior = DiagonalGaussianDistribution(latents) + latent_mode = posterior.mode() + + normalized = self.normalizer.normalize(latent_mode) + return normalized.to(input_device) + + def decode(self, latents: torch.Tensor) -> torch.Tensor: + """Decode normalized latent tensors into an audio waveform.""" + original_shape = latents.shape + + # Ensure that Audio VAE is loaded on the correct device. + self.device_manager.ensure_model_loaded() + + latents = self.device_manager.move_to_load_device(latents) + latents = self.normalizer.denormalize(latents) + + target_shape = self.target_shape_from_latents(original_shape) + mel_spec = self.autoencoder.decode(latents, target_shape=target_shape) + + waveform = self.run_vocoder(mel_spec) + return self.device_manager.move_to_load_device(waveform) + + def target_shape_from_latents(self, latents_shape): + batch, _, time, _ = latents_shape + target_length = time * LATENT_DOWNSAMPLE_FACTOR + if self.autoencoder.causality_axis != CausalityAxis.NONE: + target_length -= LATENT_DOWNSAMPLE_FACTOR - 1 + return ( + batch, + self.autoencoder.decoder.out_ch, + target_length, + self.autoencoder.mel_bins, + ) + + def num_of_latents_from_frames(self, frames_number: int, frame_rate: int) -> int: + return math.ceil((float(frames_number) / frame_rate) * self.latents_per_second) + + def run_vocoder(self, mel_spec: torch.Tensor) -> torch.Tensor: + audio_channels = self.autoencoder.decoder.out_ch + vocoder_input = mel_spec.transpose(2, 3) + + if audio_channels == 1: + vocoder_input = vocoder_input.squeeze(1) + elif audio_channels != 2: + raise ValueError(f"Unsupported audio_channels: {audio_channels}") + + return self.vocoder(vocoder_input) + + @property + def sample_rate(self) -> int: + return int(self.autoencoder.sampling_rate) + + @property + def mel_hop_length(self) -> int: + return int(self.autoencoder.mel_hop_length) + + @property + def mel_bins(self) -> int: + return int(self.autoencoder.mel_bins) + + @property + def latent_channels(self) -> int: + return int(self.autoencoder.decoder.z_channels) + + @property + def latent_frequency_bins(self) -> int: + return int(self.mel_bins // LATENT_DOWNSAMPLE_FACTOR) + + @property + def latents_per_second(self) -> float: + return self.sample_rate / self.mel_hop_length / LATENT_DOWNSAMPLE_FACTOR + + @property + def output_sample_rate(self) -> int: + output_rate = getattr(self.vocoder, "output_sample_rate", None) + if output_rate is not None: + return int(output_rate) + upsample_factor = getattr(self.vocoder, "upsample_factor", None) + if upsample_factor is None: + raise AttributeError( + "Vocoder is missing upsample_factor; cannot infer output sample rate" + ) + return int(self.sample_rate * upsample_factor / self.mel_hop_length) + + def 
memory_required(self, input_shape): + return self.device_manager.patcher.model_size() diff --git a/comfy/ldm/lightricks/vae/causal_audio_autoencoder.py b/comfy/ldm/lightricks/vae/causal_audio_autoencoder.py new file mode 100644 index 000000000..f12b9bb53 --- /dev/null +++ b/comfy/ldm/lightricks/vae/causal_audio_autoencoder.py @@ -0,0 +1,909 @@ +from __future__ import annotations +import torch +from torch import nn +from torch.nn import functional as F +from typing import Optional +from enum import Enum +from .pixel_norm import PixelNorm +import comfy.ops +import logging + +ops = comfy.ops.disable_weight_init + + +class StringConvertibleEnum(Enum): + """ + Base enum class that provides string-to-enum conversion functionality. + + This mixin adds a str_to_enum() class method that handles conversion from + strings, None, or existing enum instances with case-insensitive matching. + """ + + @classmethod + def str_to_enum(cls, value): + """ + Convert a string, enum instance, or None to the appropriate enum member. + + Args: + value: Can be an enum instance of this class, a string, or None + + Returns: + Enum member of this class + + Raises: + ValueError: If the value cannot be converted to a valid enum member + """ + # Already an enum instance of this class + if isinstance(value, cls): + return value + + # None maps to NONE member if it exists + if value is None: + if hasattr(cls, "NONE"): + return cls.NONE + raise ValueError(f"{cls.__name__} does not have a NONE member to map None to") + + # String conversion (case-insensitive) + if isinstance(value, str): + value_lower = value.lower() + + # Try to match against enum values + for member in cls: + # Handle members with None values + if member.value is None: + if value_lower == "none": + return member + # Handle members with string values + elif isinstance(member.value, str) and member.value.lower() == value_lower: + return member + + # Build helpful error message with valid values + valid_values = [] + for member in cls: + if member.value is None: + valid_values.append("none") + elif isinstance(member.value, str): + valid_values.append(member.value) + + raise ValueError(f"Invalid {cls.__name__} string: '{value}'. " f"Valid values are: {valid_values}") + + raise ValueError( + f"Cannot convert type {type(value).__name__} to {cls.__name__} enum. " + f"Expected string, None, or {cls.__name__} instance." + ) + + +class AttentionType(StringConvertibleEnum): + """Enum for specifying the attention mechanism type.""" + + VANILLA = "vanilla" + LINEAR = "linear" + NONE = "none" + + +class CausalityAxis(StringConvertibleEnum): + """Enum for specifying the causality axis in causal convolutions.""" + + NONE = None + WIDTH = "width" + HEIGHT = "height" + WIDTH_COMPATIBILITY = "width-compatibility" + + +def Normalize(in_channels, *, num_groups=32, normtype="group"): + if normtype == "group": + return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True) + elif normtype == "pixel": + return PixelNorm(dim=1, eps=1e-6) + else: + raise ValueError(f"Invalid normalization type: {normtype}") + + +class CausalConv2d(nn.Module): + """ + A causal 2D convolution. + + This layer ensures that the output at time `t` only depends on inputs + at time `t` and earlier. It achieves this by applying asymmetric padding + to the time dimension (width) before the convolution. 
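Worked example (editor's note): for a 3x3 kernel with dilation 1, pad_h = pad_w = 2; causality along the width axis pads as (2, 0, 1, 1), i.e. everything before the current column, while causality along the height axis pads as (1, 1, 2, 0). The check below is a minimal standalone sketch, not part of this module:

import torch
import torch.nn.functional as F

x = torch.randn(1, 1, 5, 5)                 # (batch, channels, H, W)
pad_h = pad_w = (3 - 1) * 1                 # (kernel_size - 1) * dilation

width_causal = F.pad(x, (pad_w, 0, pad_h // 2, pad_h - pad_h // 2))    # F.pad format: (left, right, top, bottom)
height_causal = F.pad(x, (pad_w // 2, pad_w - pad_w // 2, pad_h, 0))
assert width_causal.shape == height_causal.shape == (1, 1, 7, 7)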
+ """ + + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + dilation=1, + groups=1, + bias=True, + causality_axis: CausalityAxis = CausalityAxis.HEIGHT, + ): + super().__init__() + + self.causality_axis = causality_axis + + # Ensure kernel_size and dilation are tuples + kernel_size = nn.modules.utils._pair(kernel_size) + dilation = nn.modules.utils._pair(dilation) + + # Calculate padding dimensions + pad_h = (kernel_size[0] - 1) * dilation[0] + pad_w = (kernel_size[1] - 1) * dilation[1] + + # The padding tuple for F.pad is (pad_left, pad_right, pad_top, pad_bottom) + match self.causality_axis: + case CausalityAxis.NONE: + self.padding = (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2) + case CausalityAxis.WIDTH | CausalityAxis.WIDTH_COMPATIBILITY: + self.padding = (pad_w, 0, pad_h // 2, pad_h - pad_h // 2) + case CausalityAxis.HEIGHT: + self.padding = (pad_w // 2, pad_w - pad_w // 2, pad_h, 0) + case _: + raise ValueError(f"Invalid causality_axis: {causality_axis}") + + # The internal convolution layer uses no padding, as we handle it manually + self.conv = ops.Conv2d( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=0, + dilation=dilation, + groups=groups, + bias=bias, + ) + + def forward(self, x): + # Apply causal padding before convolution + x = F.pad(x, self.padding) + return self.conv(x) + + +def make_conv2d( + in_channels, + out_channels, + kernel_size, + stride=1, + padding=None, + dilation=1, + groups=1, + bias=True, + causality_axis: Optional[CausalityAxis] = None, +): + """ + Create a 2D convolution layer that can be either causal or non-causal. + + Args: + in_channels: Number of input channels + out_channels: Number of output channels + kernel_size: Size of the convolution kernel + stride: Convolution stride + padding: Padding (if None, will be calculated based on causal flag) + dilation: Dilation rate + groups: Number of groups for grouped convolution + bias: Whether to use bias + causality_axis: Dimension along which to apply causality. + + Returns: + Either a regular Conv2d or CausalConv2d layer + """ + if causality_axis is not None: + # For causal convolution, padding is handled internally by CausalConv2d + return CausalConv2d(in_channels, out_channels, kernel_size, stride, dilation, groups, bias, causality_axis) + else: + # For non-causal convolution, use symmetric padding if not specified + if padding is None: + if isinstance(kernel_size, int): + padding = kernel_size // 2 + else: + padding = tuple(k // 2 for k in kernel_size) + return ops.Conv2d( + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + groups, + bias, + ) + + +class Upsample(nn.Module): + def __init__(self, in_channels, with_conv, causality_axis: CausalityAxis = CausalityAxis.HEIGHT): + super().__init__() + self.with_conv = with_conv + self.causality_axis = causality_axis + if self.with_conv: + self.conv = make_conv2d(in_channels, in_channels, kernel_size=3, stride=1, causality_axis=causality_axis) + + def forward(self, x): + x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + if self.with_conv: + x = self.conv(x) + # Drop FIRST element in the causal axis to undo encoder's padding, while keeping the length 1 + 2 * n. + # For example, if the input is [0, 1, 2], after interpolation, the output is [0, 0, 1, 1, 2, 2]. 
+ # The causal convolution will pad the first element as [-, -, 0, 0, 1, 1, 2, 2], + # So the output elements rely on the following windows: + # 0: [-,-,0] + # 1: [-,0,0] + # 2: [0,0,1] + # 3: [0,1,1] + # 4: [1,1,2] + # 5: [1,2,2] + # Notice that the first and second elements in the output rely only on the first element in the input, + # while all other elements rely on two elements in the input. + # So we can drop the first element to undo the padding (rather than the last element). + # This is a no-op for non-causal convolutions. + match self.causality_axis: + case CausalityAxis.NONE: + pass # x remains unchanged + case CausalityAxis.HEIGHT: + x = x[:, :, 1:, :] + case CausalityAxis.WIDTH: + x = x[:, :, :, 1:] + case CausalityAxis.WIDTH_COMPATIBILITY: + pass # x remains unchanged + case _: + raise ValueError(f"Invalid causality_axis: {self.causality_axis}") + + return x + + +class Downsample(nn.Module): + """ + A downsampling layer that can use either a strided convolution + or average pooling. Supports standard and causal padding for the + convolutional mode. + """ + + def __init__(self, in_channels, with_conv, causality_axis: CausalityAxis = CausalityAxis.WIDTH): + super().__init__() + self.with_conv = with_conv + self.causality_axis = causality_axis + + if self.causality_axis != CausalityAxis.NONE and not self.with_conv: + raise ValueError("causality is only supported when `with_conv=True`.") + + if self.with_conv: + # Do time downsampling here + # no asymmetric padding in torch conv, must do it ourselves + self.conv = ops.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0) + + def forward(self, x): + if self.with_conv: + # (pad_left, pad_right, pad_top, pad_bottom) + match self.causality_axis: + case CausalityAxis.NONE: + pad = (0, 1, 0, 1) + case CausalityAxis.WIDTH: + pad = (2, 0, 0, 1) + case CausalityAxis.HEIGHT: + pad = (0, 1, 2, 0) + case CausalityAxis.WIDTH_COMPATIBILITY: + pad = (1, 0, 0, 1) + case _: + raise ValueError(f"Invalid causality_axis: {self.causality_axis}") + + x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + else: + # This branch is only taken if with_conv=False, which implies causality_axis is NONE. 
+ x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) + + return x + + +class ResnetBlock(nn.Module): + def __init__( + self, + *, + in_channels, + out_channels=None, + conv_shortcut=False, + dropout, + temb_channels=512, + norm_type="group", + causality_axis: CausalityAxis = CausalityAxis.HEIGHT, + ): + super().__init__() + self.causality_axis = causality_axis + + if self.causality_axis != CausalityAxis.NONE and norm_type == "group": + raise ValueError("Causal ResnetBlock with GroupNorm is not supported.") + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = Normalize(in_channels, normtype=norm_type) + self.non_linearity = nn.SiLU() + self.conv1 = make_conv2d(in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis) + if temb_channels > 0: + self.temb_proj = ops.Linear(temb_channels, out_channels) + self.norm2 = Normalize(out_channels, normtype=norm_type) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = make_conv2d(out_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = make_conv2d( + in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis + ) + else: + self.nin_shortcut = make_conv2d( + in_channels, out_channels, kernel_size=1, stride=1, causality_axis=causality_axis + ) + + def forward(self, x, temb): + h = x + h = self.norm1(h) + h = self.non_linearity(h) + h = self.conv1(h) + + if temb is not None: + h = h + self.temb_proj(self.non_linearity(temb))[:, :, None, None] + + h = self.norm2(h) + h = self.non_linearity(h) + h = self.dropout(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + x = self.conv_shortcut(x) + else: + x = self.nin_shortcut(x) + + return x + h + + +class AttnBlock(nn.Module): + def __init__(self, in_channels, norm_type="group"): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels, normtype=norm_type) + self.q = ops.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.k = ops.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.v = ops.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.proj_out = ops.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b, c, h, w = q.shape + q = q.reshape(b, c, h * w).contiguous() + q = q.permute(0, 2, 1).contiguous() # b,hw,c + k = k.reshape(b, c, h * w).contiguous() # b,c,hw + w_ = torch.bmm(q, k).contiguous() # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + w_ = w_ * (int(c) ** (-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = v.reshape(b, c, h * w).contiguous() + w_ = w_.permute(0, 2, 1).contiguous() # b,hw,hw (first hw of k, second of q) + h_ = torch.bmm(v, w_).contiguous() # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + h_ = h_.reshape(b, c, h, w).contiguous() + + h_ = self.proj_out(h_) + + return x + h_ + + +def make_attn(in_channels, attn_type="vanilla", norm_type="group"): + # Convert string to enum if needed + attn_type = AttentionType.str_to_enum(attn_type) + + if attn_type != AttentionType.NONE: + logging.info(f"making attention of type 
'{attn_type.value}' with {in_channels} in_channels") + else: + logging.info(f"making identity attention with {in_channels} in_channels") + + match attn_type: + case AttentionType.VANILLA: + return AttnBlock(in_channels, norm_type=norm_type) + case AttentionType.NONE: + return nn.Identity(in_channels) + case AttentionType.LINEAR: + raise NotImplementedError(f"Attention type {attn_type.value} is not supported yet.") + case _: + raise ValueError(f"Unknown attention type: {attn_type}") + + +class Encoder(nn.Module): + def __init__( + self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + z_channels, + double_z=True, + attn_type="vanilla", + mid_block_add_attention=True, + norm_type="group", + causality_axis=CausalityAxis.WIDTH.value, + **ignore_kwargs, + ): + super().__init__() + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.z_channels = z_channels + self.double_z = double_z + self.norm_type = norm_type + # Convert string to enum if needed (for config loading) + causality_axis = CausalityAxis.str_to_enum(causality_axis) + self.attn_type = AttentionType.str_to_enum(attn_type) + + # downsampling + self.conv_in = make_conv2d( + in_channels, + self.ch, + kernel_size=3, + stride=1, + causality_axis=causality_axis, + ) + + self.non_linearity = nn.SiLU() + + curr_res = resolution + in_ch_mult = (1,) + tuple(ch_mult) + self.in_ch_mult = in_ch_mult + self.down = nn.ModuleList() + + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] + + for _ in range(self.num_res_blocks): + block.append( + ResnetBlock( + in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + norm_type=self.norm_type, + causality_axis=causality_axis, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=self.attn_type, norm_type=self.norm_type)) + + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions - 1: + down.downsample = Downsample(block_in, resamp_with_conv, causality_axis=causality_axis) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + norm_type=self.norm_type, + causality_axis=causality_axis, + ) + if mid_block_add_attention: + self.mid.attn_1 = make_attn(block_in, attn_type=self.attn_type, norm_type=self.norm_type) + else: + self.mid.attn_1 = nn.Identity() + self.mid.block_2 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + norm_type=self.norm_type, + causality_axis=causality_axis, + ) + + # end + self.norm_out = Normalize(block_in, normtype=self.norm_type) + self.conv_out = make_conv2d( + block_in, + 2 * z_channels if double_z else z_channels, + kernel_size=3, + stride=1, + causality_axis=causality_axis, + ) + + def forward(self, x): + """ + Forward pass through the encoder. 
+ + Args: + x: Input tensor of shape [batch, channels, time, n_mels] + + Returns: + Encoded latent representation + """ + feature_maps = [self.conv_in(x)] + + # Process each resolution level (from high to low resolution) + for resolution_level in range(self.num_resolutions): + # Apply residual blocks at current resolution level + for block_idx in range(self.num_res_blocks): + # Apply ResNet block with optional timestep embedding + current_features = self.down[resolution_level].block[block_idx](feature_maps[-1], temb=None) + + # Apply attention if configured for this resolution level + if len(self.down[resolution_level].attn) > 0: + current_features = self.down[resolution_level].attn[block_idx](current_features) + + # Store processed features + feature_maps.append(current_features) + + # Downsample spatial dimensions (except at the final resolution level) + if resolution_level != self.num_resolutions - 1: + downsampled_features = self.down[resolution_level].downsample(feature_maps[-1]) + feature_maps.append(downsampled_features) + + # === MIDDLE PROCESSING PHASE === + # Take the lowest resolution features for middle processing + bottleneck_features = feature_maps[-1] + + # Apply first middle ResNet block + bottleneck_features = self.mid.block_1(bottleneck_features, temb=None) + + # Apply middle attention block + bottleneck_features = self.mid.attn_1(bottleneck_features) + + # Apply second middle ResNet block + bottleneck_features = self.mid.block_2(bottleneck_features, temb=None) + + # === OUTPUT PHASE === + # Normalize the bottleneck features + output_features = self.norm_out(bottleneck_features) + + # Apply non-linearity (SiLU activation) + output_features = self.non_linearity(output_features) + + # Final convolution to produce latent representation + # [batch, channels, time, n_mels] -> [batch, 2 * z_channels if double_z else z_channels, time, n_mels] + return self.conv_out(output_features) + + +class Decoder(nn.Module): + def __init__( + self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + z_channels, + give_pre_end=False, + tanh_out=False, + attn_type="vanilla", + mid_block_add_attention=True, + norm_type="group", + causality_axis=CausalityAxis.WIDTH.value, + **ignorekwargs, + ): + super().__init__() + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.out_ch = out_ch + self.give_pre_end = give_pre_end + self.tanh_out = tanh_out + self.norm_type = norm_type + self.z_channels = z_channels + # Convert string to enum if needed (for config loading) + causality_axis = CausalityAxis.str_to_enum(causality_axis) + self.attn_type = AttentionType.str_to_enum(attn_type) + + # compute block_in and curr_res at lowest res + block_in = ch * ch_mult[self.num_resolutions - 1] + curr_res = resolution // 2 ** (self.num_resolutions - 1) + self.z_shape = (1, z_channels, curr_res, curr_res) + + # z to block_in + self.conv_in = make_conv2d(z_channels, block_in, kernel_size=3, stride=1, causality_axis=causality_axis) + + self.non_linearity = nn.SiLU() + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + norm_type=self.norm_type, + causality_axis=causality_axis, + ) + if mid_block_add_attention: + self.mid.attn_1 = make_attn(block_in, attn_type=self.attn_type, 
norm_type=self.norm_type) + else: + self.mid.attn_1 = nn.Identity() + self.mid.block_2 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + norm_type=self.norm_type, + causality_axis=causality_axis, + ) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch * ch_mult[i_level] + for _ in range(self.num_res_blocks + 1): + block.append( + ResnetBlock( + in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + norm_type=self.norm_type, + causality_axis=causality_axis, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=self.attn_type, norm_type=self.norm_type)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv, causality_axis=causality_axis) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in, normtype=self.norm_type) + self.conv_out = make_conv2d(block_in, out_ch, kernel_size=3, stride=1, causality_axis=causality_axis) + + def _adjust_output_shape(self, decoded_output, target_shape): + """ + Adjust output shape to match target dimensions for variable-length audio. + + This function handles the common case where decoded audio spectrograms need to be + resized to match a specific target shape. + + Args: + decoded_output: Tensor of shape (batch, channels, time, frequency) + target_shape: Target shape tuple (batch, channels, time, frequency) + + Returns: + Tensor adjusted to match target_shape exactly + """ + # Current output shape: (batch, channels, time, frequency) + _, _, current_time, current_freq = decoded_output.shape + _, target_channels, target_time, target_freq = target_shape + + # Step 1: Crop first to avoid exceeding target dimensions + decoded_output = decoded_output[ + :, :target_channels, : min(current_time, target_time), : min(current_freq, target_freq) + ] + + # Step 2: Calculate padding needed for time and frequency dimensions + time_padding_needed = target_time - decoded_output.shape[2] + freq_padding_needed = target_freq - decoded_output.shape[3] + + # Step 3: Apply padding if needed + if time_padding_needed > 0 or freq_padding_needed > 0: + # PyTorch padding format: (pad_left, pad_right, pad_top, pad_bottom) + # For audio: pad_left/right = frequency, pad_top/bottom = time + padding = ( + 0, + max(freq_padding_needed, 0), # frequency padding (left, right) + 0, + max(time_padding_needed, 0), # time padding (top, bottom) + ) + decoded_output = F.pad(decoded_output, padding) + + # Step 4: Final safety crop to ensure exact target shape + decoded_output = decoded_output[:, :target_channels, :target_time, :target_freq] + + return decoded_output + + def get_config(self): + return { + "ch": self.ch, + "out_ch": self.out_ch, + "ch_mult": self.ch_mult, + "num_res_blocks": self.num_res_blocks, + "in_channels": self.in_channels, + "resolution": self.resolution, + "z_channels": self.z_channels, + } + + def forward(self, latent_features, target_shape=None): + """ + Decode latent features back to audio spectrograms. 
+ + Args: + latent_features: Encoded latent representation of shape (batch, channels, height, width) + target_shape: Optional target output shape (batch, channels, time, frequency) + If provided, output will be cropped/padded to match this shape + + Returns: + Reconstructed audio spectrogram of shape (batch, channels, time, frequency) + """ + assert target_shape is not None, "Target shape is required for CausalAudioAutoencoder Decoder" + + # Transform latent features to decoder's internal feature dimension + hidden_features = self.conv_in(latent_features) + + # Middle processing + hidden_features = self.mid.block_1(hidden_features, temb=None) + hidden_features = self.mid.attn_1(hidden_features) + hidden_features = self.mid.block_2(hidden_features, temb=None) + + # Upsampling + # Progressively increase spatial resolution from lowest to highest + for resolution_level in reversed(range(self.num_resolutions)): + # Apply residual blocks at current resolution level + for block_index in range(self.num_res_blocks + 1): + hidden_features = self.up[resolution_level].block[block_index](hidden_features, temb=None) + + if len(self.up[resolution_level].attn) > 0: + hidden_features = self.up[resolution_level].attn[block_index](hidden_features) + + if resolution_level != 0: + hidden_features = self.up[resolution_level].upsample(hidden_features) + + # Output + if self.give_pre_end: + # Return intermediate features before final processing (for debugging/analysis) + decoded_output = hidden_features + else: + # Standard output path: normalize, activate, and convert to output channels + # Final normalization layer + hidden_features = self.norm_out(hidden_features) + + # Apply SiLU (Swish) activation function + hidden_features = self.non_linearity(hidden_features) + + # Final convolution to map to output channels (typically 2 for stereo audio) + decoded_output = self.conv_out(hidden_features) + + # Optional tanh activation to bound output values to [-1, 1] range + if self.tanh_out: + decoded_output = torch.tanh(decoded_output) + + # Adjust shape for audio data + if target_shape is not None: + decoded_output = self._adjust_output_shape(decoded_output, target_shape) + + return decoded_output + + +class processor(nn.Module): + def __init__(self): + super().__init__() + self.register_buffer("std-of-means", torch.empty(128)) + self.register_buffer("mean-of-means", torch.empty(128)) + + def un_normalize(self, x): + return (x * self.get_buffer("std-of-means").to(x)) + self.get_buffer("mean-of-means").to(x) + + def normalize(self, x): + return (x - self.get_buffer("mean-of-means").to(x)) / self.get_buffer("std-of-means").to(x) + + +class CausalAudioAutoencoder(nn.Module): + def __init__(self, config=None): + super().__init__() + + if config is None: + config = self._guess_config() + + # Extract encoder and decoder configs from the new format + model_config = config.get("model", {}).get("params", {}) + variables_config = config.get("variables", {}) + + self.sampling_rate = variables_config.get( + "sampling_rate", + model_config.get("sampling_rate", config.get("sampling_rate", 16000)), + ) + encoder_config = model_config.get("encoder", model_config.get("ddconfig", {})) + decoder_config = model_config.get("decoder", encoder_config) + + # Load mel spectrogram parameters + self.mel_bins = encoder_config.get("mel_bins", 64) + self.mel_hop_length = model_config.get("preprocessing", {}).get("stft", {}).get("hop_length", 160) + self.n_fft = model_config.get("preprocessing", {}).get("stft", {}).get("filter_length", 1024) + + # 
Store causality configuration at VAE level (not just in encoder internals) + causality_axis_value = encoder_config.get("causality_axis", CausalityAxis.WIDTH.value) + self.causality_axis = CausalityAxis.str_to_enum(causality_axis_value) + self.is_causal = self.causality_axis == CausalityAxis.HEIGHT + + self.encoder = Encoder(**encoder_config) + self.decoder = Decoder(**decoder_config) + + self.per_channel_statistics = processor() + + def _guess_config(self): + encoder_config = { + # Required parameters - based on ltx-video-av-1679000 model metadata + "ch": 128, + "out_ch": 8, + "ch_mult": [1, 2, 4], # Based on metadata: [1, 2, 4] not [1, 2, 4, 8] + "num_res_blocks": 2, + "attn_resolutions": [], # Based on metadata: empty list, no attention + "dropout": 0.0, + "resamp_with_conv": True, + "in_channels": 2, # stereo + "resolution": 256, + "z_channels": 8, + "double_z": True, + "attn_type": "vanilla", + "mid_block_add_attention": False, # Based on metadata: false + "norm_type": "pixel", + "causality_axis": "height", # Based on metadata + "mel_bins": 64, # Based on metadata: mel_bins = 64 + } + + decoder_config = { + # Inherits encoder config, can override specific params + **encoder_config, + "out_ch": 2, # Stereo audio output (2 channels) + "give_pre_end": False, + "tanh_out": False, + } + + config = { + "_class_name": "CausalAudioAutoencoder", + "sampling_rate": 16000, + "model": { + "params": { + "encoder": encoder_config, + "decoder": decoder_config, + } + }, + } + + return config + + def get_config(self): + return { + "sampling_rate": self.sampling_rate, + "mel_bins": self.mel_bins, + "mel_hop_length": self.mel_hop_length, + "n_fft": self.n_fft, + "causality_axis": self.causality_axis.value, + "is_causal": self.is_causal, + } + + def encode(self, x): + return self.encoder(x) + + def decode(self, x, target_shape=None): + return self.decoder(x, target_shape=target_shape) diff --git a/comfy/ldm/lightricks/vocoders/vocoder.py b/comfy/ldm/lightricks/vocoders/vocoder.py new file mode 100644 index 000000000..b1f15f2c5 --- /dev/null +++ b/comfy/ldm/lightricks/vocoders/vocoder.py @@ -0,0 +1,213 @@ +import torch +import torch.nn.functional as F +import torch.nn as nn +import comfy.ops +import numpy as np + +ops = comfy.ops.disable_weight_init + +LRELU_SLOPE = 0.1 + +def get_padding(kernel_size, dilation=1): + return int((kernel_size * dilation - dilation) / 2) + + +class ResBlock1(torch.nn.Module): + def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)): + super(ResBlock1, self).__init__() + self.convs1 = nn.ModuleList( + [ + ops.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]), + ), + ops.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]), + ), + ops.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[2], + padding=get_padding(kernel_size, dilation[2]), + ), + ] + ) + + self.convs2 = nn.ModuleList( + [ + ops.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ), + ops.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ), + ops.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ), + ] + ) + + def forward(self, x): + for c1, c2 in zip(self.convs1, self.convs2): + xt = F.leaky_relu(x, LRELU_SLOPE) + xt = c1(xt) + xt = F.leaky_relu(xt, LRELU_SLOPE) + xt = c2(xt) + 
x = xt + x + return x + + +class ResBlock2(torch.nn.Module): + def __init__(self, channels, kernel_size=3, dilation=(1, 3)): + super(ResBlock2, self).__init__() + self.convs = nn.ModuleList( + [ + ops.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]), + ), + ops.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]), + ), + ] + ) + + def forward(self, x): + for c in self.convs: + xt = F.leaky_relu(x, LRELU_SLOPE) + xt = c(xt) + x = xt + x + return x + + +class Vocoder(torch.nn.Module): + """ + Vocoder model for synthesizing audio from spectrograms, based on: https://github.com/jik876/hifi-gan. + + """ + + def __init__(self, config=None): + super(Vocoder, self).__init__() + + if config is None: + config = self.get_default_config() + + resblock_kernel_sizes = config.get("resblock_kernel_sizes", [3, 7, 11]) + upsample_rates = config.get("upsample_rates", [6, 5, 2, 2, 2]) + upsample_kernel_sizes = config.get("upsample_kernel_sizes", [16, 15, 8, 4, 4]) + resblock_dilation_sizes = config.get("resblock_dilation_sizes", [[1, 3, 5], [1, 3, 5], [1, 3, 5]]) + upsample_initial_channel = config.get("upsample_initial_channel", 1024) + stereo = config.get("stereo", True) + resblock = config.get("resblock", "1") + + self.output_sample_rate = config.get("output_sample_rate") + self.num_kernels = len(resblock_kernel_sizes) + self.num_upsamples = len(upsample_rates) + in_channels = 128 if stereo else 64 + self.conv_pre = ops.Conv1d(in_channels, upsample_initial_channel, 7, 1, padding=3) + resblock_class = ResBlock1 if resblock == "1" else ResBlock2 + + self.ups = nn.ModuleList() + for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)): + self.ups.append( + ops.ConvTranspose1d( + upsample_initial_channel // (2**i), + upsample_initial_channel // (2 ** (i + 1)), + k, + u, + padding=(k - u) // 2, + ) + ) + + self.resblocks = nn.ModuleList() + for i in range(len(self.ups)): + ch = upsample_initial_channel // (2 ** (i + 1)) + for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)): + self.resblocks.append(resblock_class(ch, k, d)) + + out_channels = 2 if stereo else 1 + self.conv_post = ops.Conv1d(ch, out_channels, 7, 1, padding=3) + + self.upsample_factor = np.prod([self.ups[i].stride[0] for i in range(len(self.ups))]) + + def get_default_config(self): + """Generate default configuration for the vocoder.""" + + config = { + "resblock_kernel_sizes": [3, 7, 11], + "upsample_rates": [6, 5, 2, 2, 2], + "upsample_kernel_sizes": [16, 15, 8, 4, 4], + "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]], + "upsample_initial_channel": 1024, + "stereo": True, + "resblock": "1", + } + + return config + + def forward(self, x): + """ + Forward pass of the vocoder. + + Args: + x: Input spectrogram tensor. 
Can be: + - 3D: (batch_size, channels, time_steps) for mono + - 4D: (batch_size, 2, channels, time_steps) for stereo + + Returns: + Audio tensor of shape (batch_size, out_channels, audio_length) + """ + if x.dim() == 4: # stereo + assert x.shape[1] == 2, "Input must have 2 channels for stereo" + x = torch.cat((x[:, 0, :, :], x[:, 1, :, :]), dim=1) + x = self.conv_pre(x) + for i in range(self.num_upsamples): + x = F.leaky_relu(x, LRELU_SLOPE) + x = self.ups[i](x) + xs = None + for j in range(self.num_kernels): + if xs is None: + xs = self.resblocks[i * self.num_kernels + j](x) + else: + xs += self.resblocks[i * self.num_kernels + j](x) + x = xs / self.num_kernels + x = F.leaky_relu(x) + x = self.conv_post(x) + x = torch.tanh(x) + + return x diff --git a/comfy/model_base.py b/comfy/model_base.py index c4f3c0639..49efd700b 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -20,6 +20,7 @@ import comfy.ldm.hunyuan3dv2_1 import comfy.ldm.hunyuan3dv2_1.hunyuandit import torch import logging +import comfy.ldm.lightricks.av_model from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep from comfy.ldm.cascade.stage_c import StageC from comfy.ldm.cascade.stage_b import StageB @@ -946,7 +947,7 @@ class GenmoMochi(BaseModel): class LTXV(BaseModel): def __init__(self, model_config, model_type=ModelType.FLUX, device=None): - super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.lightricks.model.LTXVModel) #TODO + super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.lightricks.model.LTXVModel) def extra_conds(self, **kwargs): out = super().extra_conds(**kwargs) @@ -977,6 +978,60 @@ class LTXV(BaseModel): def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs): return latent_image +class LTXAV(BaseModel): + def __init__(self, model_config, model_type=ModelType.FLUX, device=None): + super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.lightricks.av_model.LTXAVModel) #TODO + + def extra_conds(self, **kwargs): + out = super().extra_conds(**kwargs) + attention_mask = kwargs.get("attention_mask", None) + if attention_mask is not None: + out['attention_mask'] = comfy.conds.CONDRegular(attention_mask) + cross_attn = kwargs.get("cross_attn", None) + if cross_attn is not None: + out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) + + out['frame_rate'] = comfy.conds.CONDConstant(kwargs.get("frame_rate", 25)) + + denoise_mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None)) + + audio_denoise_mask = None + if denoise_mask is not None and "latent_shapes" in kwargs: + denoise_mask = utils.unpack_latents(denoise_mask, kwargs["latent_shapes"]) + if len(denoise_mask) > 1: + audio_denoise_mask = denoise_mask[1] + denoise_mask = denoise_mask[0] + + if denoise_mask is not None: + out["denoise_mask"] = comfy.conds.CONDRegular(denoise_mask) + + if audio_denoise_mask is not None: + out["audio_denoise_mask"] = comfy.conds.CONDRegular(audio_denoise_mask) + + keyframe_idxs = kwargs.get("keyframe_idxs", None) + if keyframe_idxs is not None: + out['keyframe_idxs'] = comfy.conds.CONDRegular(keyframe_idxs) + + latent_shapes = kwargs.get("latent_shapes", None) + if latent_shapes is not None: + out['latent_shapes'] = comfy.conds.CONDConstant(latent_shapes) + + return out + + def process_timestep(self, timestep, x, denoise_mask=None, audio_denoise_mask=None, **kwargs): + v_timestep = timestep + a_timestep = timestep + + if denoise_mask is not None: + v_timestep = 
self.diffusion_model.patchifier.patchify(((denoise_mask) * timestep.view([timestep.shape[0]] + [1] * (denoise_mask.ndim - 1)))[:, :1])[0] + if audio_denoise_mask is not None: + a_timestep = self.diffusion_model.a_patchifier.patchify(((audio_denoise_mask) * timestep.view([timestep.shape[0]] + [1] * (audio_denoise_mask.ndim - 1)))[:, :1, :, :1])[0] + + return v_timestep, a_timestep + + def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs): + return latent_image + class HunyuanVideo(BaseModel): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan_video.model.HunyuanVideo) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 539e296ed..0853b3aec 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -305,7 +305,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): if '{}adaln_single.emb.timestep_embedder.linear_1.bias'.format(key_prefix) in state_dict_keys: #Lightricks ltxv dit_config = {} - dit_config["image_model"] = "ltxv" + dit_config["image_model"] = "ltxav" if f'{key_prefix}audio_adaln_single.linear.weight' in state_dict_keys else "ltxv" dit_config["num_layers"] = count_blocks(state_dict_keys, '{}transformer_blocks.'.format(key_prefix) + '{}.') shape = state_dict['{}transformer_blocks.0.attn2.to_k.weight'.format(key_prefix)].shape dit_config["attention_head_dim"] = shape[0] // 32 diff --git a/comfy/sd.py b/comfy/sd.py index 7de7dd9c6..32157e18b 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -1041,7 +1041,8 @@ class TEModel(Enum): MISTRAL3_24B_PRUNED_FLUX2 = 15 QWEN3_4B = 16 QWEN3_2B = 17 - JINA_CLIP_2 = 18 + GEMMA_3_12B = 18 + JINA_CLIP_2 = 19 def detect_te_model(sd): @@ -1067,6 +1068,8 @@ def detect_te_model(sd): return TEModel.BYT5_SMALL_GLYPH return TEModel.T5_BASE if 'model.layers.0.post_feedforward_layernorm.weight' in sd: + if 'model.layers.47.self_attn.q_norm.weight' in sd: + return TEModel.GEMMA_3_12B if 'model.layers.0.self_attn.q_norm.weight' in sd: return TEModel.GEMMA_3_4B return TEModel.GEMMA_2_2B @@ -1271,6 +1274,10 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip elif clip_type == CLIPType.KANDINSKY5_IMAGE: clip_target.clip = comfy.text_encoders.kandinsky5.te(**llama_detect(clip_data)) clip_target.tokenizer = comfy.text_encoders.kandinsky5.Kandinsky5TokenizerImage + elif clip_type == CLIPType.LTXV: + clip_target.clip = comfy.text_encoders.lt.ltxav_te(**llama_detect(clip_data)) + clip_target.tokenizer = comfy.text_encoders.lt.LTXAVGemmaTokenizer + tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None) elif clip_type == CLIPType.NEWBIE: clip_target.clip = comfy.text_encoders.newbie.te(**llama_detect(clip_data)) clip_target.tokenizer = comfy.text_encoders.newbie.NewBieTokenizer diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 1888f35ba..ee9a79001 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -836,6 +836,21 @@ class LTXV(supported_models_base.BASE): t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref)) return supported_models_base.ClipTarget(comfy.text_encoders.lt.LTXVT5Tokenizer, comfy.text_encoders.lt.ltxv_te(**t5_detect)) +class LTXAV(LTXV): + unet_config = { + "image_model": "ltxav", + } + + latent_format = latent_formats.LTXAV + + def __init__(self, unet_config): + super().__init__(unet_config) + self.memory_usage_factor = 0.055 # TODO + + def 
get_model(self, state_dict, prefix="", device=None): + out = model_base.LTXAV(self, device=device) + return out + class HunyuanVideo(supported_models_base.BASE): unet_config = { "image_model": "hunyuan_video", @@ -1536,6 +1551,6 @@ class Kandinsky5Image(Kandinsky5): return supported_models_base.ClipTarget(comfy.text_encoders.kandinsky5.Kandinsky5TokenizerImage, comfy.text_encoders.kandinsky5.te(**hunyuan_detect)) -models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5] +models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5] models += [SVD_img2vid] diff --git a/comfy/text_encoders/llama.py b/comfy/text_encoders/llama.py index faa4e1de8..76731576b 100644 --- a/comfy/text_encoders/llama.py +++ b/comfy/text_encoders/llama.py @@ -7,6 +7,7 @@ import math from comfy.ldm.modules.attention import optimized_attention_for_device import comfy.model_management import comfy.ldm.common_dit +import comfy.clip_model from . 
import qwen_vl @@ -188,6 +189,31 @@ class Gemma3_4B_Config: rope_scale = [8.0, 1.0] final_norm: bool = True +@dataclass +class Gemma3_12B_Config: + vocab_size: int = 262208 + hidden_size: int = 3840 + intermediate_size: int = 15360 + num_hidden_layers: int = 48 + num_attention_heads: int = 16 + num_key_value_heads: int = 8 + max_position_embeddings: int = 131072 + rms_norm_eps: float = 1e-6 + rope_theta = [1000000.0, 10000.0] + transformer_type: str = "gemma3" + head_dim = 256 + rms_norm_add = True + mlp_activation = "gelu_pytorch_tanh" + qkv_bias = False + rope_dims = None + q_norm = "gemma3" + k_norm = "gemma3" + sliding_attention = [1024, 1024, 1024, 1024, 1024, False] + rope_scale = [8.0, 1.0] + final_norm: bool = True + vision_config = {"num_channels": 3, "hidden_act": "gelu_pytorch_tanh", "hidden_size": 1152, "image_size": 896, "intermediate_size": 4304, "model_type": "siglip_vision_model", "num_attention_heads": 16, "num_hidden_layers": 27, "patch_size": 14} + mm_tokens_per_image = 256 + class RMSNorm(nn.Module): def __init__(self, dim: int, eps: float = 1e-5, add=False, device=None, dtype=None): super().__init__() @@ -520,6 +546,41 @@ class Llama2_(nn.Module): return x, intermediate + +class Gemma3MultiModalProjector(torch.nn.Module): + def __init__(self, config, dtype, device, operations): + super().__init__() + + self.mm_input_projection_weight = nn.Parameter( + torch.empty(config.vision_config["hidden_size"], config.hidden_size, device=device, dtype=dtype) + ) + + self.mm_soft_emb_norm = RMSNorm(config.vision_config["hidden_size"], eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype) + + self.patches_per_image = int(config.vision_config["image_size"] // config.vision_config["patch_size"]) + self.tokens_per_side = int(config.mm_tokens_per_image**0.5) + self.kernel_size = self.patches_per_image // self.tokens_per_side + self.avg_pool = nn.AvgPool2d(kernel_size=self.kernel_size, stride=self.kernel_size) + + def forward(self, vision_outputs: torch.Tensor): + batch_size, _, seq_length = vision_outputs.shape + + reshaped_vision_outputs = vision_outputs.transpose(1, 2) + reshaped_vision_outputs = reshaped_vision_outputs.reshape( + batch_size, seq_length, self.patches_per_image, self.patches_per_image + ) + reshaped_vision_outputs = reshaped_vision_outputs.contiguous() + + pooled_vision_outputs = self.avg_pool(reshaped_vision_outputs) + pooled_vision_outputs = pooled_vision_outputs.flatten(2) + pooled_vision_outputs = pooled_vision_outputs.transpose(1, 2) + + normed_vision_outputs = self.mm_soft_emb_norm(pooled_vision_outputs) + + projected_vision_outputs = torch.matmul(normed_vision_outputs, comfy.model_management.cast_to_device(self.mm_input_projection_weight, device=normed_vision_outputs.device, dtype=normed_vision_outputs.dtype)) + return projected_vision_outputs.type_as(vision_outputs) + + class BaseLlama: def get_input_embeddings(self): return self.model.embed_tokens @@ -636,3 +697,21 @@ class Gemma3_4B(BaseLlama, torch.nn.Module): self.model = Llama2_(config, device=device, dtype=dtype, ops=operations) self.dtype = dtype + +class Gemma3_12B(BaseLlama, torch.nn.Module): + def __init__(self, config_dict, dtype, device, operations): + super().__init__() + config = Gemma3_12B_Config(**config_dict) + self.num_layers = config.num_hidden_layers + + self.model = Llama2_(config, device=device, dtype=dtype, ops=operations) + self.multi_modal_projector = Gemma3MultiModalProjector(config, dtype, device, operations) + self.vision_model = 
comfy.clip_model.CLIPVision(config.vision_config, dtype, device, operations) + self.dtype = dtype + self.image_size = config.vision_config["image_size"] + + def preprocess_embed(self, embed, device): + if embed["type"] == "image": + image = comfy.clip_model.clip_preprocess(embed["data"], size=self.image_size, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], crop=True) + return self.multi_modal_projector(self.vision_model(image.to(device, dtype=torch.float32))[0]), None + return None, None diff --git a/comfy/text_encoders/lt.py b/comfy/text_encoders/lt.py index 48ea67e67..2c2d453e8 100644 --- a/comfy/text_encoders/lt.py +++ b/comfy/text_encoders/lt.py @@ -1,7 +1,11 @@ from comfy import sd1_clip import os from transformers import T5TokenizerFast +from .spiece_tokenizer import SPieceTokenizer import comfy.text_encoders.genmo +from comfy.ldm.lightricks.embeddings_connector import Embeddings1DConnector +import torch +import comfy.utils class T5XXLTokenizer(sd1_clip.SDTokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}): @@ -16,3 +20,110 @@ class LTXVT5Tokenizer(sd1_clip.SD1Tokenizer): def ltxv_te(*args, **kwargs): return comfy.text_encoders.genmo.mochi_te(*args, **kwargs) + + +class Gemma3_12BTokenizer(sd1_clip.SDTokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + tokenizer = tokenizer_data.get("spiece_model", None) + super().__init__(tokenizer, pad_with_end=False, embedding_size=3840, embedding_key='gemma3_12b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_args={"add_bos": True, "add_eos": False}, tokenizer_data=tokenizer_data) + + def state_dict(self): + return {"spiece_model": self.tokenizer.serialize_model()} + +class LTXAVGemmaTokenizer(sd1_clip.SD1Tokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="gemma3_12b", tokenizer=Gemma3_12BTokenizer) + +class Gemma3_12BModel(sd1_clip.SDClipModel): + def __init__(self, device="cpu", layer="all", layer_idx=None, dtype=None, attention_mask=True, model_options={}): + llama_scaled_fp8 = model_options.get("gemma_scaled_fp8", None) + if llama_scaled_fp8 is not None: + model_options = model_options.copy() + model_options["scaled_fp8"] = llama_scaled_fp8 + + super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma3_12B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options) + + def tokenize_with_weights(self, text, return_word_ids=False, llama_template="{}", image_embeds=None, **kwargs): + text = llama_template.format(text) + text_tokens = super().tokenize_with_weights(text, return_word_ids) + embed_count = 0 + for k in text_tokens: + tt = text_tokens[k] + for r in tt: + for i in range(len(r)): + if r[i][0] == 262144: + if image_embeds is not None and embed_count < image_embeds.shape[0]: + r[i] = ({"type": "embedding", "data": image_embeds[embed_count], "original_type": "image"},) + r[i][1:] + embed_count += 1 + return text_tokens + +class LTXAVTEModel(torch.nn.Module): + def __init__(self, dtype_llama=None, device="cpu", dtype=None, model_options={}): + super().__init__() + self.dtypes = set() + self.dtypes.add(dtype) + + self.gemma3_12b = Gemma3_12BModel(device=device, dtype=dtype_llama, 
model_options=model_options, layer="all", layer_idx=None) + self.dtypes.add(dtype_llama) + + operations = self.gemma3_12b.operations # TODO + self.text_embedding_projection = operations.Linear(3840 * 49, 3840, bias=False, dtype=dtype, device=device) + + self.audio_embeddings_connector = Embeddings1DConnector( + split_rope=True, + double_precision_rope=True, + dtype=dtype, + device=device, + operations=operations, + ) + + self.video_embeddings_connector = Embeddings1DConnector( + split_rope=True, + double_precision_rope=True, + dtype=dtype, + device=device, + operations=operations, + ) + + def set_clip_options(self, options): + self.gemma3_12b.set_clip_options(options) + + def reset_clip_options(self): + self.gemma3_12b.reset_clip_options() + + def encode_token_weights(self, token_weight_pairs): + token_weight_pairs = token_weight_pairs["gemma3_12b"] + + out, pooled, extra = self.gemma3_12b.encode_token_weights(token_weight_pairs) + out_device = out.device + out = out.movedim(1, -1).to(self.text_embedding_projection.weight.device) + out = 8.0 * (out - out.mean(dim=(1, 2), keepdim=True)) / (out.amax(dim=(1, 2), keepdim=True) - out.amin(dim=(1, 2), keepdim=True) + 1e-6) + out = out.reshape((out.shape[0], out.shape[1], -1)) + out = self.text_embedding_projection(out) + out_vid = self.video_embeddings_connector(out)[0] + out_audio = self.audio_embeddings_connector(out)[0] + out = torch.concat((out_vid, out_audio), dim=-1) + + return out.to(out_device), pooled + + def load_sd(self, sd): + if "model.layers.47.self_attn.q_norm.weight" in sd: + return self.gemma3_12b.load_sd(sd) + else: + sdo = comfy.utils.state_dict_prefix_replace(sd, {"text_embedding_projection.aggregate_embed.weight": "text_embedding_projection.weight", "model.diffusion_model.video_embeddings_connector.": "video_embeddings_connector.", "model.diffusion_model.audio_embeddings_connector.": "audio_embeddings_connector."}, filter_keys=True) + if len(sdo) == 0: + sdo = sd + + return self.load_state_dict(sdo, strict=False) + + +def ltxav_te(dtype_llama=None, llama_scaled_fp8=None): + class LTXAVTEModel_(LTXAVTEModel): + def __init__(self, device="cpu", dtype=None, model_options={}): + if llama_scaled_fp8 is not None and "llama_scaled_fp8" not in model_options: + model_options = model_options.copy() + model_options["llama_scaled_fp8"] = llama_scaled_fp8 + if dtype_llama is not None: + dtype = dtype_llama + super().__init__(dtype_llama=dtype_llama, device=device, dtype=dtype, model_options=model_options) + return LTXAVTEModel_ diff --git a/comfy/utils.py b/comfy/utils.py index e4162d7ac..ffa98c9b1 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -1198,7 +1198,7 @@ def unpack_latents(combined_latent, latent_shapes): combined_latent = combined_latent[:, :, cut:] output_tensors.append(tens.reshape([tens.shape[0]] + list(shape)[1:])) else: - output_tensors = combined_latent + output_tensors = [combined_latent] return output_tensors def detect_layer_quantization(state_dict, prefix): diff --git a/comfy_extras/nodes_audio.py b/comfy_extras/nodes_audio.py index c7916443c..94ad5e8a8 100644 --- a/comfy_extras/nodes_audio.py +++ b/comfy_extras/nodes_audio.py @@ -112,7 +112,7 @@ class VAEDecodeAudio(IO.ComfyNode): std = torch.std(audio, dim=[1,2], keepdim=True) * 5.0 std[std < 1.0] = 1.0 audio /= std - return IO.NodeOutput({"waveform": audio, "sample_rate": 44100}) + return IO.NodeOutput({"waveform": audio, "sample_rate": 44100 if "sample_rate" not in samples else samples["sample_rate"]}) decode = execute # TODO: remove diff --git 
a/comfy_extras/nodes_hunyuan.py b/comfy_extras/nodes_hunyuan.py index 32be182f1..ceff657d3 100644 --- a/comfy_extras/nodes_hunyuan.py +++ b/comfy_extras/nodes_hunyuan.py @@ -5,7 +5,9 @@ import comfy.model_management from typing_extensions import override from comfy_api.latest import ComfyExtension, io from comfy.ldm.hunyuan_video.upsampler import HunyuanVideo15SRModel +from comfy.ldm.lightricks.latent_upsampler import LatentUpsampler import folder_paths +import json class CLIPTextEncodeHunyuanDiT(io.ComfyNode): @classmethod @@ -186,7 +188,7 @@ class LatentUpscaleModelLoader(io.ComfyNode): @classmethod def execute(cls, model_name) -> io.NodeOutput: model_path = folder_paths.get_full_path_or_raise("latent_upscale_models", model_name) - sd = comfy.utils.load_torch_file(model_path, safe_load=True) + sd, metadata = comfy.utils.load_torch_file(model_path, safe_load=True, return_metadata=True) if "blocks.0.block.0.conv.weight" in sd: config = { @@ -197,6 +199,8 @@ class LatentUpscaleModelLoader(io.ComfyNode): "global_residual": False, } model_type = "720p" + model = HunyuanVideo15SRModel(model_type, config) + model.load_sd(sd) elif "up.0.block.0.conv1.conv.weight" in sd: sd = {key.replace("nin_shortcut", "nin_shortcut.conv", 1): value for key, value in sd.items()} config = { @@ -205,9 +209,12 @@ class LatentUpscaleModelLoader(io.ComfyNode): "block_out_channels": tuple(sd[f"up.{i}.block.0.conv1.conv.weight"].shape[0] for i in range(len([k for k in sd.keys() if k.startswith("up.") and k.endswith(".block.0.conv1.conv.weight")]))), } model_type = "1080p" - - model = HunyuanVideo15SRModel(model_type, config) - model.load_sd(sd) + model = HunyuanVideo15SRModel(model_type, config) + model.load_sd(sd) + elif "post_upsample_res_blocks.0.conv2.bias" in sd: + config = json.loads(metadata["config"]) + model = LatentUpsampler.from_config(config).to(dtype=comfy.model_management.vae_dtype(allowed_dtypes=[torch.bfloat16, torch.float32])) + model.load_state_dict(sd) return io.NodeOutput(model) diff --git a/comfy_extras/nodes_lt.py b/comfy_extras/nodes_lt.py index 50da5f4eb..b91a22309 100644 --- a/comfy_extras/nodes_lt.py +++ b/comfy_extras/nodes_lt.py @@ -81,6 +81,59 @@ class LTXVImgToVideo(io.ComfyNode): generate = execute # TODO: remove +class LTXVImgToVideoInplace(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="LTXVImgToVideoInplace", + category="conditioning/video_models", + inputs=[ + io.Vae.Input("vae"), + io.Image.Input("image"), + io.Latent.Input("latent"), + io.Float.Input("strength", default=1.0, min=0.0, max=1.0), + io.Boolean.Input("bypass", default=False, tooltip="Bypass the conditioning.") + ], + outputs=[ + io.Latent.Output(display_name="latent"), + ], + ) + + @classmethod + def execute(cls, vae, image, latent, strength, bypass=False) -> io.NodeOutput: + if bypass: + return (latent,) + + samples = latent["samples"] + _, height_scale_factor, width_scale_factor = ( + vae.downscale_index_formula + ) + + batch, _, latent_frames, latent_height, latent_width = samples.shape + width = latent_width * width_scale_factor + height = latent_height * height_scale_factor + + if image.shape[1] != height or image.shape[2] != width: + pixels = comfy.utils.common_upscale(image.movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) + else: + pixels = image + encode_pixels = pixels[:, :, :, :3] + t = vae.encode(encode_pixels) + + samples[:, :, :t.shape[2]] = t + + conditioning_latent_frames_mask = torch.ones( + (batch, 1, latent_frames, 1, 1), + dtype=torch.float32, 
+ device=samples.device, + ) + conditioning_latent_frames_mask[:, :, :t.shape[2]] = 1.0 - strength + + return io.NodeOutput({"samples": samples, "noise_mask": conditioning_latent_frames_mask}) + + generate = execute # TODO: remove + + def conditioning_get_any_value(conditioning, key, default=None): for t in conditioning: if key in t[1]: @@ -106,12 +159,12 @@ def get_keyframe_idxs(cond): keyframe_idxs = conditioning_get_any_value(cond, "keyframe_idxs", None) if keyframe_idxs is None: return None, 0 - num_keyframes = torch.unique(keyframe_idxs[:, 0]).shape[0] + # keyframe_idxs contains start/end positions (last dimension), checking for unqiue values only for start + num_keyframes = torch.unique(keyframe_idxs[:, 0, :, 0]).shape[0] return keyframe_idxs, num_keyframes class LTXVAddGuide(io.ComfyNode): - NUM_PREFIX_FRAMES = 2 - PATCHIFIER = SymmetricPatchifier(1) + PATCHIFIER = SymmetricPatchifier(1, start_end=True) @classmethod def define_schema(cls): @@ -182,26 +235,35 @@ class LTXVAddGuide(io.ComfyNode): return node_helpers.conditioning_set_values(cond, {"keyframe_idxs": keyframe_idxs}) @classmethod - def append_keyframe(cls, positive, negative, frame_idx, latent_image, noise_mask, guiding_latent, strength, scale_factors): - _, latent_idx = cls.get_latent_index( - cond=positive, - latent_length=latent_image.shape[2], - guide_length=guiding_latent.shape[2], - frame_idx=frame_idx, - scale_factors=scale_factors, - ) - noise_mask[:, :, latent_idx:latent_idx + guiding_latent.shape[2]] = 1.0 + def append_keyframe(cls, positive, negative, frame_idx, latent_image, noise_mask, guiding_latent, strength, scale_factors, guide_mask=None, in_channels=128): + if latent_image.shape[1] != in_channels or guiding_latent.shape[1] != in_channels: + raise ValueError("Adding guide to a combined AV latent is not supported.") positive = cls.add_keyframe_index(positive, frame_idx, guiding_latent, scale_factors) negative = cls.add_keyframe_index(negative, frame_idx, guiding_latent, scale_factors) - mask = torch.full( - (noise_mask.shape[0], 1, guiding_latent.shape[2], noise_mask.shape[3], noise_mask.shape[4]), - 1.0 - strength, - dtype=noise_mask.dtype, - device=noise_mask.device, - ) + if guide_mask is not None: + target_h = max(noise_mask.shape[3], guide_mask.shape[3]) + target_w = max(noise_mask.shape[4], guide_mask.shape[4]) + if noise_mask.shape[3] == 1 or noise_mask.shape[4] == 1: + noise_mask = noise_mask.expand(-1, -1, -1, target_h, target_w) + + if guide_mask.shape[3] == 1 or guide_mask.shape[4] == 1: + guide_mask = guide_mask.expand(-1, -1, -1, target_h, target_w) + mask = guide_mask - strength + else: + mask = torch.full( + (noise_mask.shape[0], 1, guiding_latent.shape[2], noise_mask.shape[3], noise_mask.shape[4]), + 1.0 - strength, + dtype=noise_mask.dtype, + device=noise_mask.device, + ) + # This solves audio video combined latent case where latent_image has audio latent concatenated + # in channel dimension with video latent. The solution is to pad guiding latent accordingly. 
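A quick aside on the padding mentioned in the comment above: the only goal is to make the guiding latent's channel count match the combined AV latent before the frame-dimension torch.cat. A standalone sketch with invented channel counts (128 video + 64 audio), independent of the node code that follows:

    import torch
    import torch.nn.functional as F

    latent_image = torch.zeros(1, 192, 8, 16, 16)     # combined AV latent, channel counts invented
    guiding_latent = torch.zeros(1, 128, 2, 16, 16)   # video-only guide latent

    if latent_image.shape[1] > guiding_latent.shape[1]:
        pad_len = latent_image.shape[1] - guiding_latent.shape[1]
        # F.pad works from the last dimension backwards, so for a 5D (N, C, T, H, W)
        # tensor the final (0, pad_len) pair pads the channel dimension.
        guiding_latent = F.pad(guiding_latent, pad=(0, 0, 0, 0, 0, 0, 0, pad_len), value=0)

    # Channel counts now match, so the frame-dimension concat in the node is valid.
    combined = torch.cat([latent_image, guiding_latent], dim=2)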
+ if latent_image.shape[1] > guiding_latent.shape[1]: + pad_len = latent_image.shape[1] - guiding_latent.shape[1] + guiding_latent = torch.nn.functional.pad(guiding_latent, pad=(0, 0, 0, 0, 0, 0, 0, pad_len), value=0) latent_image = torch.cat([latent_image, guiding_latent], dim=2) noise_mask = torch.cat([noise_mask, mask], dim=2) return positive, negative, latent_image, noise_mask @@ -238,33 +300,17 @@ class LTXVAddGuide(io.ComfyNode): frame_idx, latent_idx = cls.get_latent_index(positive, latent_length, len(image), frame_idx, scale_factors) assert latent_idx + t.shape[2] <= latent_length, "Conditioning frames exceed the length of the latent sequence." - num_prefix_frames = min(cls.NUM_PREFIX_FRAMES, t.shape[2]) - positive, negative, latent_image, noise_mask = cls.append_keyframe( positive, negative, frame_idx, latent_image, noise_mask, - t[:, :, :num_prefix_frames], + t, strength, scale_factors, ) - latent_idx += num_prefix_frames - - t = t[:, :, num_prefix_frames:] - if t.shape[2] == 0: - return io.NodeOutput(positive, negative, {"samples": latent_image, "noise_mask": noise_mask}) - - latent_image, noise_mask = cls.replace_latent_frames( - latent_image, - noise_mask, - t, - latent_idx, - strength, - ) - return io.NodeOutput(positive, negative, {"samples": latent_image, "noise_mask": noise_mask}) generate = execute # TODO: remove @@ -507,18 +553,90 @@ class LTXVPreprocess(io.ComfyNode): preprocess = execute # TODO: remove + +import comfy.nested_tensor +class LTXVConcatAVLatent(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="LTXVConcatAVLatent", + category="latent/video/ltxv", + inputs=[ + io.Latent.Input("video_latent"), + io.Latent.Input("audio_latent"), + ], + outputs=[ + io.Latent.Output(display_name="latent"), + ], + ) + + @classmethod + def execute(cls, video_latent, audio_latent) -> io.NodeOutput: + output = {} + output.update(video_latent) + output.update(audio_latent) + video_noise_mask = video_latent.get("noise_mask", None) + audio_noise_mask = audio_latent.get("noise_mask", None) + + if video_noise_mask is not None or audio_noise_mask is not None: + if video_noise_mask is None: + video_noise_mask = torch.ones_like(video_latent["samples"]) + if audio_noise_mask is None: + audio_noise_mask = torch.ones_like(audio_latent["samples"]) + output["noise_mask"] = comfy.nested_tensor.NestedTensor((video_noise_mask, audio_noise_mask)) + + output["samples"] = comfy.nested_tensor.NestedTensor((video_latent["samples"], audio_latent["samples"])) + + return io.NodeOutput(output) + + +class LTXVSeparateAVLatent(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="LTXVSeparateAVLatent", + category="latent/video/ltxv", + description="LTXV Separate AV Latent", + inputs=[ + io.Latent.Input("av_latent"), + ], + outputs=[ + io.Latent.Output(display_name="video_latent"), + io.Latent.Output(display_name="audio_latent"), + ], + ) + + @classmethod + def execute(cls, av_latent) -> io.NodeOutput: + latents = av_latent["samples"].unbind() + video_latent = av_latent.copy() + video_latent["samples"] = latents[0] + audio_latent = av_latent.copy() + audio_latent["samples"] = latents[1] + if "noise_mask" in av_latent: + masks = av_latent["noise_mask"] + if masks is not None: + masks = masks.unbind() + video_latent["noise_mask"] = masks[0] + audio_latent["noise_mask"] = masks[1] + return io.NodeOutput(video_latent, audio_latent) + + class LtxvExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[io.ComfyNode]]: 
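One note on how the two AV-latent nodes defined above fit together: LTXVConcatAVLatent packs the video and audio latents (and, when present, their noise masks) into a single nested tensor so they can travel through sampling as one LATENT, and LTXVSeparateAVLatent simply unbinds them again in the same order. A rough sketch of that round trip, assuming only the NestedTensor constructor and unbind() behaviour used above (the shapes are invented):

    import torch
    import comfy.nested_tensor

    video = torch.randn(1, 128, 8, 16, 16)   # invented video latent shape
    audio = torch.randn(1, 8, 32, 16)        # invented audio latent shape

    combined = comfy.nested_tensor.NestedTensor((video, audio))

    # Downstream consumers recover the parts in the order they were packed,
    # which is why the separate node treats index 0 as video and index 1 as audio.
    video_out, audio_out = combined.unbind()
    assert video_out.shape == video.shape and audio_out.shape == audio.shape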
return [ EmptyLTXVLatentVideo, LTXVImgToVideo, + LTXVImgToVideoInplace, ModelSamplingLTXV, LTXVConditioning, LTXVScheduler, LTXVAddGuide, LTXVPreprocess, LTXVCropGuides, + LTXVConcatAVLatent, + LTXVSeparateAVLatent, ] diff --git a/comfy_extras/nodes_lt_audio.py b/comfy_extras/nodes_lt_audio.py new file mode 100644 index 000000000..b0b7000ef --- /dev/null +++ b/comfy_extras/nodes_lt_audio.py @@ -0,0 +1,183 @@ +import folder_paths +import comfy.utils +import comfy.model_management +import torch + +from comfy.ldm.lightricks.vae.audio_vae import AudioVAE +from comfy_api.latest import ComfyExtension, io + + +class LTXVAudioVAELoader(io.ComfyNode): + @classmethod + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="LTXVAudioVAELoader", + display_name="LTXV Audio VAE Loader", + category="audio", + inputs=[ + io.Combo.Input( + "ckpt_name", + options=folder_paths.get_filename_list("checkpoints"), + tooltip="Audio VAE checkpoint to load.", + ) + ], + outputs=[io.Vae.Output(display_name="Audio VAE")], + ) + + @classmethod + def execute(cls, ckpt_name: str) -> io.NodeOutput: + ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name) + sd, metadata = comfy.utils.load_torch_file(ckpt_path, return_metadata=True) + return io.NodeOutput(AudioVAE(sd, metadata)) + + +class LTXVAudioVAEEncode(io.ComfyNode): + @classmethod + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="LTXVAudioVAEEncode", + display_name="LTXV Audio VAE Encode", + category="audio", + inputs=[ + io.Audio.Input("audio", tooltip="The audio to be encoded."), + io.Vae.Input( + id="audio_vae", + display_name="Audio VAE", + tooltip="The Audio VAE model to use for encoding.", + ), + ], + outputs=[io.Latent.Output(display_name="Audio Latent")], + ) + + @classmethod + def execute(cls, audio, audio_vae: AudioVAE) -> io.NodeOutput: + audio_latents = audio_vae.encode(audio) + return io.NodeOutput( + { + "samples": audio_latents, + "sample_rate": int(audio_vae.sample_rate), + "type": "audio", + } + ) + + +class LTXVAudioVAEDecode(io.ComfyNode): + @classmethod + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="LTXVAudioVAEDecode", + display_name="LTXV Audio VAE Decode", + category="audio", + inputs=[ + io.Latent.Input("samples", tooltip="The latent to be decoded."), + io.Vae.Input( + id="audio_vae", + display_name="Audio VAE", + tooltip="The Audio VAE model used for decoding the latent.", + ), + ], + outputs=[io.Audio.Output(display_name="Audio")], + ) + + @classmethod + def execute(cls, samples, audio_vae: AudioVAE) -> io.NodeOutput: + audio_latent = samples["samples"] + if audio_latent.is_nested: + audio_latent = audio_latent.unbind()[-1] + audio = audio_vae.decode(audio_latent).to(audio_latent.device) + output_audio_sample_rate = audio_vae.output_sample_rate + return io.NodeOutput( + { + "waveform": audio, + "sample_rate": int(output_audio_sample_rate), + } + ) + + +class LTXVEmptyLatentAudio(io.ComfyNode): + @classmethod + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="LTXVEmptyLatentAudio", + display_name="LTXV Empty Latent Audio", + category="latent/audio", + inputs=[ + io.Int.Input( + "frames_number", + default=97, + min=1, + max=1000, + step=1, + display_mode=io.NumberDisplay.number, + tooltip="Number of frames.", + ), + io.Int.Input( + "frame_rate", + default=25, + min=1, + max=1000, + step=1, + display_mode=io.NumberDisplay.number, + tooltip="Number of frames per second.", + ), + io.Int.Input( + "batch_size", + default=1, + min=1, + max=4096, + 
display_mode=io.NumberDisplay.number, + tooltip="The number of latent audio samples in the batch.", + ), + io.Vae.Input( + id="audio_vae", + display_name="Audio VAE", + tooltip="The Audio VAE model to get configuration from.", + ), + ], + outputs=[io.Latent.Output(display_name="Latent")], + ) + + @classmethod + def execute( + cls, + frames_number: int, + frame_rate: int, + batch_size: int, + audio_vae: AudioVAE, + ) -> io.NodeOutput: + """Generate empty audio latents matching the reference pipeline structure.""" + + assert audio_vae is not None, "Audio VAE model is required" + + z_channels = audio_vae.latent_channels + audio_freq = audio_vae.latent_frequency_bins + sampling_rate = int(audio_vae.sample_rate) + + num_audio_latents = audio_vae.num_of_latents_from_frames(frames_number, frame_rate) + + audio_latents = torch.zeros( + (batch_size, z_channels, num_audio_latents, audio_freq), + device=comfy.model_management.intermediate_device(), + ) + + return io.NodeOutput( + { + "samples": audio_latents, + "sample_rate": sampling_rate, + "type": "audio", + } + ) + + +class LTXVAudioExtension(ComfyExtension): + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + LTXVAudioVAELoader, + LTXVAudioVAEEncode, + LTXVAudioVAEDecode, + LTXVEmptyLatentAudio, + ] + + +async def comfy_entrypoint() -> ComfyExtension: + return LTXVAudioExtension() diff --git a/comfy_extras/nodes_lt_upsampler.py b/comfy_extras/nodes_lt_upsampler.py new file mode 100644 index 000000000..f99ba13fb --- /dev/null +++ b/comfy_extras/nodes_lt_upsampler.py @@ -0,0 +1,75 @@ +from comfy import model_management +import math + +class LTXVLatentUpsampler: + """ + Upsamples a video latent by a factor of 2. + """ + + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "samples": ("LATENT",), + "upscale_model": ("LATENT_UPSCALE_MODEL",), + "vae": ("VAE",), + } + } + + RETURN_TYPES = ("LATENT",) + FUNCTION = "upsample_latent" + CATEGORY = "latent/video" + EXPERIMENTAL = True + + def upsample_latent( + self, + samples: dict, + upscale_model, + vae, + ) -> tuple: + """ + Upsample the input latent using the provided model. + + Args: + samples (dict): Input latent samples + upscale_model (LatentUpsampler): Loaded upscale model + vae: VAE model for normalization + auto_tiling (bool): Whether to automatically tile the input for processing + + Returns: + tuple: Tuple containing the upsampled latent + """ + device = model_management.get_torch_device() + memory_required = model_management.module_size(upscale_model) + + model_dtype = next(upscale_model.parameters()).dtype + latents = samples["samples"] + input_dtype = latents.dtype + + memory_required += math.prod(latents.shape) * 3000.0 # TODO: more accurate + model_management.free_memory(memory_required, device) + + try: + upscale_model.to(device) # TODO: use the comfy model management system. 
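For a sense of scale of the memory heuristic above (module_size of the upsampler plus roughly 3000 bytes per latent element, a factor the code itself flags as a TODO), here is a back-of-the-envelope calculation with invented numbers; nothing below comes from the actual model:

    import math

    latent_shape = (1, 128, 8, 32, 32)       # invented LTXV-style latent shape
    model_bytes = 500 * 1024 * 1024          # pretend module_size(upscale_model) reported ~500 MiB
    activation_bytes = math.prod(latent_shape) * 3000.0

    total_gib = (model_bytes + activation_bytes) / (1024 ** 3)
    print(f"free_memory would be asked for roughly {total_gib:.1f} GiB")   # ~3.4 GiB here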
+ + latents = latents.to(dtype=model_dtype, device=device) + + """Upsample latents without tiling.""" + latents = vae.first_stage_model.per_channel_statistics.un_normalize(latents) + upsampled_latents = upscale_model(latents) + finally: + upscale_model.cpu() + + upsampled_latents = vae.first_stage_model.per_channel_statistics.normalize( + upsampled_latents + ) + upsampled_latents = upsampled_latents.to(dtype=input_dtype, device=model_management.intermediate_device()) + return_dict = samples.copy() + return_dict["samples"] = upsampled_latents + return_dict.pop("noise_mask", None) + return (return_dict,) + + +NODE_CLASS_MAPPINGS = { + "LTXVLatentUpsampler": LTXVLatentUpsampler, +} diff --git a/nodes.py b/nodes.py index 662907ae6..56b74ebe3 100644 --- a/nodes.py +++ b/nodes.py @@ -295,7 +295,11 @@ class VAEDecode: DESCRIPTION = "Decodes latent images back into pixel space images." def decode(self, vae, samples): - images = vae.decode(samples["samples"]) + latent = samples["samples"] + if latent.is_nested: + latent = latent.unbind()[0] + + images = vae.decode(latent) if len(images.shape) == 5: #Combine batches images = images.reshape(-1, images.shape[-3], images.shape[-2], images.shape[-1]) return (images, ) @@ -970,7 +974,7 @@ class DualCLIPLoader: def INPUT_TYPES(s): return {"required": { "clip_name1": (folder_paths.get_filename_list("text_encoders"), ), "clip_name2": (folder_paths.get_filename_list("text_encoders"), ), - "type": (["sdxl", "sd3", "flux", "hunyuan_video", "hidream", "hunyuan_image", "hunyuan_video_15", "kandinsky5", "kandinsky5_image", "newbie"], ), + "type": (["sdxl", "sd3", "flux", "hunyuan_video", "hidream", "hunyuan_image", "hunyuan_video_15", "kandinsky5", "kandinsky5_image", "ltxv", "newbie"], ), }, "optional": { "device": (["default", "cpu"], {"advanced": True}), @@ -2331,6 +2335,8 @@ async def init_builtin_extra_nodes(): "nodes_mochi.py", "nodes_slg.py", "nodes_mahiro.py", + "nodes_lt_upsampler.py", + "nodes_lt_audio.py", "nodes_lt.py", "nodes_hooks.py", "nodes_load_3d.py", diff --git a/pyproject.toml b/pyproject.toml index 60378de1e..a7d159be9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,7 +3,7 @@ name = "ComfyUI" version = "0.7.0" readme = "README.md" license = { file = "LICENSE" } -requires-python = ">=3.9" +requires-python = ">=3.10" [project.urls] homepage = "https://www.comfy.org/" From d1b9822f741843c64b2cbd8e1bcdd49794b182ce Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sun, 4 Jan 2026 23:27:31 -0800 Subject: [PATCH 127/148] Add LTXAVTextEncoderLoader node. 
(#11634) --- comfy_extras/nodes_lt_audio.py | 33 +++++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/comfy_extras/nodes_lt_audio.py b/comfy_extras/nodes_lt_audio.py index b0b7000ef..2d3d103b4 100644 --- a/comfy_extras/nodes_lt_audio.py +++ b/comfy_extras/nodes_lt_audio.py @@ -169,6 +169,38 @@ class LTXVEmptyLatentAudio(io.ComfyNode): ) +class LTXAVTextEncoderLoader(io.ComfyNode): + @classmethod + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="LTXAVTextEncoderLoader", + display_name="LTXV Audio Text Encoder Loader", + category="advanced/loaders", + description="[Recipes]\n\nltxav: gemma 3 12B", + inputs=[ + io.Combo.Input( + "text_encoder", + options=folder_paths.get_filename_list("text_encoders"), + ), + io.Combo.Input( + "ckpt_name", + options=folder_paths.get_filename_list("checkpoints"), + ) + ], + outputs=[io.Clip.Output(display_name="Audio VAE")], + ) + + @classmethod + def execute(cls, text_encoder, ckpt_name, device="default"): + clip_type = comfy.sd.CLIPType.LTXV + + clip_path1 = folder_paths.get_full_path_or_raise("text_encoders", text_encoder) + clip_path2 = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name) + + clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type) + return io.NodeOutput(clip) + + class LTXVAudioExtension(ComfyExtension): async def get_node_list(self) -> list[type[io.ComfyNode]]: return [ @@ -176,6 +208,7 @@ class LTXVAudioExtension(ComfyExtension): LTXVAudioVAEEncode, LTXVAudioVAEDecode, LTXVEmptyLatentAudio, + LTXAVTextEncoderLoader, ] From d157c3299d6f9e1b57981bdb4931f1d7129e4e8d Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 5 Jan 2026 00:48:31 -0800 Subject: [PATCH 128/148] Refactor module_size function. (#11637) --- comfy/model_management.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 2501cecb7..7f5a8aee9 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -456,7 +456,7 @@ def module_size(module): sd = module.state_dict() for k in sd: t = sd[k] - module_mem += t.nelement() * t.element_size() + module_mem += t.nbytes return module_mem class LoadedModel: From 4f3f9e72a9d0c15d00c0c362b8e90f1db5af6cfb Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 5 Jan 2026 02:41:23 -0800 Subject: [PATCH 129/148] Fix name. (#11638) --- comfy_extras/nodes_lt_audio.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy_extras/nodes_lt_audio.py b/comfy_extras/nodes_lt_audio.py index 2d3d103b4..26b0160d2 100644 --- a/comfy_extras/nodes_lt_audio.py +++ b/comfy_extras/nodes_lt_audio.py @@ -187,7 +187,7 @@ class LTXAVTextEncoderLoader(io.ComfyNode): options=folder_paths.get_filename_list("checkpoints"), ) ], - outputs=[io.Clip.Output(display_name="Audio VAE")], + outputs=[io.Clip.Output()], ) @classmethod From 6da00dd899e3ee6f2a0a8163b080a9f373395025 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 5 Jan 2026 18:48:58 -0800 Subject: [PATCH 130/148] Initial ops changes to use comfy_kitchen: Initial nvfp4 checkpoint support. 
(#11635) --------- Co-authored-by: Jedrzej Kosinski --- .github/workflows/test-build.yml | 2 +- .github/workflows/test-launch.yml | 4 +- comfy/model_management.py | 4 +- comfy/ops.py | 164 +++-- comfy/quant_ops.py | 641 +++--------------- requirements.txt | 1 + .../comfy_quant/test_mixed_precision.py | 12 +- tests-unit/comfy_quant/test_quant_registry.py | 190 ------ 8 files changed, 223 insertions(+), 795 deletions(-) delete mode 100644 tests-unit/comfy_quant/test_quant_registry.py diff --git a/.github/workflows/test-build.yml b/.github/workflows/test-build.yml index 419873ad8..9160242e9 100644 --- a/.github/workflows/test-build.yml +++ b/.github/workflows/test-build.yml @@ -18,7 +18,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"] + python-version: ["3.10", "3.11", "3.12", "3.13", "3.14"] steps: - uses: actions/checkout@v4 - name: Set up Python ${{ matrix.python-version }} diff --git a/.github/workflows/test-launch.yml b/.github/workflows/test-launch.yml index fd70aff23..ef0d3f123 100644 --- a/.github/workflows/test-launch.yml +++ b/.github/workflows/test-launch.yml @@ -32,7 +32,9 @@ jobs: working-directory: ComfyUI - name: Check for unhandled exceptions in server log run: | - if grep -qE "Exception|Error" console_output.log; then + grep -v "Found comfy_kitchen backend triton: {'available': False, 'disabled': True, 'unavailable_reason': \"ImportError: No module named 'triton'\", 'capabilities': \[\]}" console_output.log | grep -v "Found comfy_kitchen backend triton: {'available': False, 'disabled': False, 'unavailable_reason': \"ImportError: No module named 'triton'\", 'capabilities': \[\]}" > console_output_filtered.log + cat console_output_filtered.log + if grep -qE "Exception|Error" console_output_filtered.log; then echo "Unhandled exception/error found in server log." 
exit 1 fi diff --git a/comfy/model_management.py b/comfy/model_management.py index 7f5a8aee9..22f4de044 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -1156,7 +1156,7 @@ def pin_memory(tensor): if not tensor.is_contiguous(): return False - size = tensor.numel() * tensor.element_size() + size = tensor.nbytes if (TOTAL_PINNED_MEMORY + size) > MAX_PINNED_MEMORY: return False @@ -1183,7 +1183,7 @@ def unpin_memory(tensor): return False ptr = tensor.data_ptr() - size = tensor.numel() * tensor.element_size() + size = tensor.nbytes size_stored = PINNED_MEMORY.get(ptr, None) if size_stored is None: diff --git a/comfy/ops.py b/comfy/ops.py index 16889bb82..f5e1e9230 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -79,7 +79,7 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of if input is not None: if dtype is None: if isinstance(input, QuantizedTensor): - dtype = input._layout_params["orig_dtype"] + dtype = input.params.orig_dtype else: dtype = input.dtype if bias_dtype is None: @@ -412,26 +412,34 @@ def fp8_linear(self, input): return None input_dtype = input.dtype + input_shape = input.shape + tensor_3d = input.ndim == 3 - if input.ndim == 3 or input.ndim == 2: - w, bias, offload_stream = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype, offloadable=True) - scale_weight = torch.ones((), device=input.device, dtype=torch.float32) + if tensor_3d: + input = input.reshape(-1, input_shape[2]) - scale_input = torch.ones((), device=input.device, dtype=torch.float32) - input = torch.clamp(input, min=-448, max=448, out=input) - layout_params_weight = {'scale': scale_input, 'orig_dtype': input_dtype} - quantized_input = QuantizedTensor(input.to(dtype).contiguous(), "TensorCoreFP8Layout", layout_params_weight) + if input.ndim != 2: + return None + w, bias, offload_stream = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype, offloadable=True) + scale_weight = torch.ones((), device=input.device, dtype=torch.float32) - # Wrap weight in QuantizedTensor - this enables unified dispatch - # Call F.linear - __torch_dispatch__ routes to fp8_linear handler in quant_ops.py! - layout_params_weight = {'scale': scale_weight, 'orig_dtype': input_dtype} - quantized_weight = QuantizedTensor(w, "TensorCoreFP8Layout", layout_params_weight) - o = torch.nn.functional.linear(quantized_input, quantized_weight, bias) + scale_input = torch.ones((), device=input.device, dtype=torch.float32) + input = torch.clamp(input, min=-448, max=448, out=input) + input_fp8 = input.to(dtype).contiguous() + layout_params_input = TensorCoreFP8Layout.Params(scale=scale_input, orig_dtype=input_dtype, orig_shape=tuple(input_fp8.shape)) + quantized_input = QuantizedTensor(input_fp8, TensorCoreFP8Layout, layout_params_input) - uncast_bias_weight(self, w, bias, offload_stream) - return o + # Wrap weight in QuantizedTensor - this enables unified dispatch + # Call F.linear - __torch_dispatch__ routes to fp8_linear handler in quant_ops.py! 
+ layout_params_weight = TensorCoreFP8Layout.Params(scale=scale_weight, orig_dtype=input_dtype, orig_shape=tuple(w.shape)) + quantized_weight = QuantizedTensor(w, TensorCoreFP8Layout, layout_params_weight) + o = torch.nn.functional.linear(quantized_input, quantized_weight, bias) - return None + uncast_bias_weight(self, w, bias, offload_stream) + if tensor_3d: + o = o.reshape((input_shape[0], input_shape[1], w.shape[0])) + + return o class fp8_ops(manual_cast): class Linear(manual_cast.Linear): @@ -477,7 +485,12 @@ if CUBLAS_IS_AVAILABLE: # ============================================================================== # Mixed Precision Operations # ============================================================================== -from .quant_ops import QuantizedTensor, QUANT_ALGOS +from .quant_ops import ( + QuantizedTensor, + QUANT_ALGOS, + TensorCoreFP8Layout, + get_layout_class, +) def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False): @@ -497,14 +510,15 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec ) -> None: super().__init__() - if dtype is None: - dtype = MixedPrecisionOps._compute_dtype - - self.factory_kwargs = {"device": device, "dtype": dtype} + self.factory_kwargs = {"device": device, "dtype": MixedPrecisionOps._compute_dtype} + # self.factory_kwargs = {"device": device, "dtype": dtype} self.in_features = in_features self.out_features = out_features - self._has_bias = bias + if bias: + self.bias = torch.nn.Parameter(torch.empty(out_features, **self.factory_kwargs)) + else: + self.register_parameter("bias", None) self.tensor_class = None self._full_precision_mm = MixedPrecisionOps._full_precision_mm @@ -512,6 +526,16 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec def reset_parameters(self): return None + def _load_scale_param(self, state_dict, prefix, param_name, device, manually_loaded_keys, dtype=None): + key = f"{prefix}{param_name}" + value = state_dict.pop(key, None) + if value is not None: + value = value.to(device=device) + if dtype is not None: + value = value.view(dtype=dtype) + manually_loaded_keys.append(key) + return value + def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): @@ -529,14 +553,7 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec layer_conf = json.loads(layer_conf.numpy().tobytes()) if layer_conf is None: - dtype = self.factory_kwargs["dtype"] - self.weight = torch.nn.Parameter(weight.to(device=device, dtype=dtype), requires_grad=False) - if dtype != MixedPrecisionOps._compute_dtype: - self.comfy_cast_weights = True - if self._has_bias: - self.bias = torch.nn.Parameter(torch.empty(self.out_features, device=device, dtype=dtype)) - else: - self.register_parameter("bias", None) + self.weight = torch.nn.Parameter(weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), requires_grad=False) else: self.quant_format = layer_conf.get("format", None) if not self._full_precision_mm: @@ -547,31 +564,46 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec qconfig = QUANT_ALGOS[self.quant_format] self.layout_type = qconfig["comfy_tensor_layout"] + layout_cls = get_layout_class(self.layout_type) - weight_scale_key = f"{prefix}weight_scale" - scale = state_dict.pop(weight_scale_key, None) - if scale is not None: - scale = scale.to(device) - layout_params = { - 'scale': scale, - 'orig_dtype': 
MixedPrecisionOps._compute_dtype, - 'block_size': qconfig.get("group_size", None), - } + # Load format-specific parameters + if self.quant_format in ["float8_e4m3fn", "float8_e5m2"]: + # FP8: single tensor scale + scale = self._load_scale_param(state_dict, prefix, "weight_scale", device, manually_loaded_keys) - if scale is not None: - manually_loaded_keys.append(weight_scale_key) + params = layout_cls.Params( + scale=scale, + orig_dtype=MixedPrecisionOps._compute_dtype, + orig_shape=(self.out_features, self.in_features), + ) + + elif self.quant_format == "nvfp4": + # NVFP4: tensor_scale (weight_scale_2) + block_scale (weight_scale) + tensor_scale = self._load_scale_param(state_dict, prefix, "weight_scale_2", device, manually_loaded_keys) + block_scale = self._load_scale_param(state_dict, prefix, "weight_scale", device, manually_loaded_keys, + dtype=torch.float8_e4m3fn) + + if tensor_scale is None or block_scale is None: + raise ValueError(f"Missing NVFP4 scales for layer {layer_name}") + + params = layout_cls.Params( + scale=tensor_scale, + block_scale=block_scale, + orig_dtype=MixedPrecisionOps._compute_dtype, + orig_shape=(self.out_features, self.in_features), + ) + else: + raise ValueError(f"Unsupported quantization format: {self.quant_format}") self.weight = torch.nn.Parameter( - QuantizedTensor(weight.to(device=device, dtype=qconfig.get("storage_t", None)), self.layout_type, layout_params), + QuantizedTensor(weight.to(device=device, dtype=qconfig["storage_t"]), self.layout_type, params), requires_grad=False ) - if self._has_bias: - self.bias = torch.nn.Parameter(torch.empty(self.out_features, device=device, dtype=MixedPrecisionOps._compute_dtype)) - else: - self.register_parameter("bias", None) - for param_name in qconfig["parameters"]: + if param_name in {"weight_scale", "weight_scale_2"}: + continue # Already handled above + param_key = f"{prefix}{param_name}" _v = state_dict.pop(param_key, None) if _v is None: @@ -588,7 +620,15 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec def state_dict(self, *args, destination=None, prefix="", **kwargs): sd = super().state_dict(*args, destination=destination, prefix=prefix, **kwargs) if isinstance(self.weight, QuantizedTensor): - sd["{}weight_scale".format(prefix)] = self.weight._layout_params['scale'] + layout_cls = self.weight._layout_cls + + # Check if it's any FP8 variant (E4M3 or E5M2) + if layout_cls in ("TensorCoreFP8E4M3Layout", "TensorCoreFP8E5M2Layout", "TensorCoreFP8Layout"): + sd["{}weight_scale".format(prefix)] = self.weight._params.scale + elif layout_cls == "TensorCoreNVFP4Layout": + sd["{}weight_scale_2".format(prefix)] = self.weight._params.scale + sd["{}weight_scale".format(prefix)] = self.weight._params.block_scale + quant_conf = {"format": self.quant_format} if self._full_precision_mm: quant_conf["full_precision_matrix_mult"] = True @@ -607,12 +647,33 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec def forward(self, input, *args, **kwargs): run_every_op() + input_shape = input.shape + tensor_3d = input.ndim == 3 + if self._full_precision_mm or self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0: return self.forward_comfy_cast_weights(input, *args, **kwargs) + if (getattr(self, 'layout_type', None) is not None and not isinstance(input, QuantizedTensor)): - input = QuantizedTensor.from_float(input, self.layout_type, scale=getattr(self, 'input_scale', None), dtype=self.weight.dtype) - return self._forward(input, 
self.weight, self.bias) + + # Reshape 3D tensors to 2D for quantization (needed for NVFP4 and others) + if tensor_3d: + input = input.reshape(-1, input_shape[2]) + + if input.ndim != 2: + # Fall back to comfy_cast_weights for non-2D tensors + return self.forward_comfy_cast_weights(input.reshape(input_shape), *args, **kwargs) + + # dtype is now implicit in the layout class + input = QuantizedTensor.from_float(input, self.layout_type, scale=getattr(self, 'input_scale', None)) + + output = self._forward(input, self.weight, self.bias) + + # Reshape output back to 3D if input was 3D + if tensor_3d: + output = output.reshape((input_shape[0], input_shape[1], self.weight.shape[0])) + + return output def convert_weight(self, weight, inplace=False, **kwargs): if isinstance(weight, QuantizedTensor): @@ -622,7 +683,8 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False, **kwargs): if getattr(self, 'layout_type', None) is not None: - weight = QuantizedTensor.from_float(weight, self.layout_type, scale="recalculate", dtype=self.weight.dtype, stochastic_rounding=seed, inplace_ops=True) + # dtype is now implicit in the layout class + weight = QuantizedTensor.from_float(weight, self.layout_type, scale="recalculate", stochastic_rounding=seed, inplace_ops=True) else: weight = weight.to(self.weight.dtype) if return_weight: diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py index cd96541d7..cd737726f 100644 --- a/comfy/quant_ops.py +++ b/comfy/quant_ops.py @@ -1,580 +1,133 @@ import torch import logging -from typing import Tuple, Dict + +try: + import comfy_kitchen as ck + from comfy_kitchen.tensor import ( + QuantizedTensor, + QuantizedLayout, + TensorCoreFP8Layout as _CKFp8Layout, + TensorCoreNVFP4Layout, # Direct import, no wrapper needed + register_layout_op, + register_layout_class, + get_layout_class, + ) + _CK_AVAILABLE = True + ck.registry.disable("triton") + for k, v in ck.list_backends().items(): + logging.info(f"Found comfy_kitchen backend {k}: {v}") +except ImportError as e: + logging.error(f"Failed to import comfy_kitchen, Error: {e}, fp8 and fp4 support will not be available.") + _CK_AVAILABLE = False + + class QuantizedTensor: + pass + + class _CKFp8Layout: + pass + + class TensorCoreNVFP4Layout: + pass + + def register_layout_class(name, cls): + pass + + def get_layout_class(name): + return None + import comfy.float -_LAYOUT_REGISTRY = {} -_GENERIC_UTILS = {} - - -def register_layout_op(torch_op, layout_type): - """ - Decorator to register a layout-specific operation handler. - Args: - torch_op: PyTorch operation (e.g., torch.ops.aten.linear.default) - layout_type: Layout class (e.g., TensorCoreFP8Layout) - Example: - @register_layout_op(torch.ops.aten.linear.default, TensorCoreFP8Layout) - def fp8_linear(func, args, kwargs): - # FP8-specific linear implementation - ... - """ - def decorator(handler_func): - if torch_op not in _LAYOUT_REGISTRY: - _LAYOUT_REGISTRY[torch_op] = {} - _LAYOUT_REGISTRY[torch_op][layout_type] = handler_func - return handler_func - return decorator - - -def register_generic_util(torch_op): - """ - Decorator to register a generic utility that works for all layouts. - Args: - torch_op: PyTorch operation (e.g., torch.ops.aten.detach.default) - - Example: - @register_generic_util(torch.ops.aten.detach.default) - def generic_detach(func, args, kwargs): - # Works for any layout - ... 
- """ - def decorator(handler_func): - _GENERIC_UTILS[torch_op] = handler_func - return handler_func - return decorator - - -def _get_layout_from_args(args): - for arg in args: - if isinstance(arg, QuantizedTensor): - return arg._layout_type - elif isinstance(arg, (list, tuple)): - for item in arg: - if isinstance(item, QuantizedTensor): - return item._layout_type - return None - - -def _move_layout_params_to_device(params, device): - new_params = {} - for k, v in params.items(): - if isinstance(v, torch.Tensor): - new_params[k] = v.to(device=device) - else: - new_params[k] = v - return new_params - - -def _copy_layout_params(params): - new_params = {} - for k, v in params.items(): - if isinstance(v, torch.Tensor): - new_params[k] = v.clone() - else: - new_params[k] = v - return new_params - -def _copy_layout_params_inplace(src, dst, non_blocking=False): - for k, v in src.items(): - if isinstance(v, torch.Tensor): - dst[k].copy_(v, non_blocking=non_blocking) - else: - dst[k] = v - -class QuantizedLayout: - """ - Base class for quantization layouts. - - A layout encapsulates the format-specific logic for quantization/dequantization - and provides a uniform interface for extracting raw tensors needed for computation. - - New quantization formats should subclass this and implement the required methods. - """ - @classmethod - def quantize(cls, tensor, **kwargs) -> Tuple[torch.Tensor, Dict]: - raise NotImplementedError(f"{cls.__name__} must implement quantize()") - - @staticmethod - def dequantize(qdata, **layout_params) -> torch.Tensor: - raise NotImplementedError("TensorLayout must implement dequantize()") - - @classmethod - def get_plain_tensors(cls, qtensor) -> torch.Tensor: - raise NotImplementedError(f"{cls.__name__} must implement get_plain_tensors()") - - -class QuantizedTensor(torch.Tensor): - """ - Universal quantized tensor that works with any layout. - - This tensor subclass uses a pluggable layout system to support multiple - quantization formats (FP8, INT4, INT8, etc.) without code duplication. - - The layout_type determines format-specific behavior, while common operations - (detach, clone, to) are handled generically. - - Attributes: - _qdata: The quantized tensor data - _layout_type: Layout class (e.g., TensorCoreFP8Layout) - _layout_params: Dict with layout-specific params (scale, zero_point, etc.) - """ - - @staticmethod - def __new__(cls, qdata, layout_type, layout_params): - """ - Create a quantized tensor. - - Args: - qdata: The quantized data tensor - layout_type: Layout class (subclass of QuantizedLayout) - layout_params: Dict with layout-specific parameters - """ - return torch.Tensor._make_wrapper_subclass(cls, qdata.shape, device=qdata.device, dtype=qdata.dtype, requires_grad=False) - - def __init__(self, qdata, layout_type, layout_params): - self._qdata = qdata - self._layout_type = layout_type - self._layout_params = layout_params - - def __repr__(self): - layout_name = self._layout_type - param_str = ", ".join(f"{k}={v}" for k, v in list(self._layout_params.items())[:2]) - return f"QuantizedTensor(shape={self.shape}, layout={layout_name}, {param_str})" - - @property - def layout_type(self): - return self._layout_type - - def __tensor_flatten__(self): - """ - Tensor flattening protocol for proper device movement. 
- """ - inner_tensors = ["_qdata"] - ctx = { - "layout_type": self._layout_type, - } - - tensor_params = {} - non_tensor_params = {} - for k, v in self._layout_params.items(): - if isinstance(v, torch.Tensor): - tensor_params[k] = v - else: - non_tensor_params[k] = v - - ctx["tensor_param_keys"] = list(tensor_params.keys()) - ctx["non_tensor_params"] = non_tensor_params - - for k, v in tensor_params.items(): - attr_name = f"_layout_param_{k}" - object.__setattr__(self, attr_name, v) - inner_tensors.append(attr_name) - - return inner_tensors, ctx - - @staticmethod - def __tensor_unflatten__(inner_tensors, ctx, outer_size, outer_stride): - """ - Tensor unflattening protocol for proper device movement. - Reconstructs the QuantizedTensor after device movement. - """ - layout_type = ctx["layout_type"] - layout_params = dict(ctx["non_tensor_params"]) - - for key in ctx["tensor_param_keys"]: - attr_name = f"_layout_param_{key}" - layout_params[key] = inner_tensors[attr_name] - - return QuantizedTensor(inner_tensors["_qdata"], layout_type, layout_params) - - @classmethod - def from_float(cls, tensor, layout_type, **quantize_kwargs) -> 'QuantizedTensor': - qdata, layout_params = LAYOUTS[layout_type].quantize(tensor, **quantize_kwargs) - return cls(qdata, layout_type, layout_params) - - def dequantize(self) -> torch.Tensor: - return LAYOUTS[self._layout_type].dequantize(self._qdata, **self._layout_params) - - @classmethod - def __torch_dispatch__(cls, func, types, args=(), kwargs=None): - kwargs = kwargs or {} - - # Step 1: Check generic utilities first (detach, clone, to, etc.) - if func in _GENERIC_UTILS: - return _GENERIC_UTILS[func](func, args, kwargs) - - # Step 2: Check layout-specific handlers (linear, matmul, etc.) - layout_type = _get_layout_from_args(args) - if layout_type and func in _LAYOUT_REGISTRY: - handler = _LAYOUT_REGISTRY[func].get(layout_type) - if handler: - return handler(func, args, kwargs) - - # Step 3: Fallback to dequantization - if isinstance(args[0] if args else None, QuantizedTensor): - logging.info(f"QuantizedTensor: Unhandled operation {func}, falling back to dequantization. 
kwargs={kwargs}") - return cls._dequant_and_fallback(func, args, kwargs) - - @classmethod - def _dequant_and_fallback(cls, func, args, kwargs): - def dequant_arg(arg): - if isinstance(arg, QuantizedTensor): - return arg.dequantize() - elif isinstance(arg, (list, tuple)): - return type(arg)(dequant_arg(a) for a in arg) - return arg - - new_args = dequant_arg(args) - new_kwargs = dequant_arg(kwargs) - return func(*new_args, **new_kwargs) - - def data_ptr(self): - return self._qdata.data_ptr() - - def is_pinned(self): - return self._qdata.is_pinned() - - def is_contiguous(self, *arg, **kwargs): - return self._qdata.is_contiguous(*arg, **kwargs) - - def storage(self): - return self._qdata.storage() - # ============================================================================== -# Generic Utilities (Layout-Agnostic Operations) +# FP8 Layouts with Comfy-Specific Extensions # ============================================================================== -def _create_transformed_qtensor(qt, transform_fn): - new_data = transform_fn(qt._qdata) - new_params = _copy_layout_params(qt._layout_params) - return QuantizedTensor(new_data, qt._layout_type, new_params) +class _TensorCoreFP8LayoutBase(_CKFp8Layout): + FP8_DTYPE = None # Must be overridden in subclass - -def _handle_device_transfer(qt, target_device, target_dtype=None, target_layout=None, op_name="to"): - if target_layout is not None and target_layout != torch.strided: - logging.warning( - f"QuantizedTensor: layout change requested to {target_layout}, " - f"but not supported. Ignoring layout." - ) - - # Handle device transfer - current_device = qt._qdata.device - if target_device is not None: - # Normalize device for comparison - if isinstance(target_device, str): - target_device = torch.device(target_device) - if isinstance(current_device, str): - current_device = torch.device(current_device) - - if target_device != current_device: - logging.debug(f"QuantizedTensor.{op_name}: Moving from {current_device} to {target_device}") - new_q_data = qt._qdata.to(device=target_device) - new_params = _move_layout_params_to_device(qt._layout_params, target_device) - if target_dtype is not None: - new_params["orig_dtype"] = target_dtype - new_qt = QuantizedTensor(new_q_data, qt._layout_type, new_params) - logging.debug(f"QuantizedTensor.{op_name}: Created new tensor on {target_device}") - return new_qt - - logging.debug(f"QuantizedTensor.{op_name}: No device change needed, returning original") - return qt - - -@register_generic_util(torch.ops.aten.detach.default) -def generic_detach(func, args, kwargs): - """Detach operation - creates a detached copy of the quantized tensor.""" - qt = args[0] - if isinstance(qt, QuantizedTensor): - return _create_transformed_qtensor(qt, lambda x: x.detach()) - return func(*args, **kwargs) - - -@register_generic_util(torch.ops.aten.clone.default) -def generic_clone(func, args, kwargs): - """Clone operation - creates a deep copy of the quantized tensor.""" - qt = args[0] - if isinstance(qt, QuantizedTensor): - return _create_transformed_qtensor(qt, lambda x: x.clone()) - return func(*args, **kwargs) - - -@register_generic_util(torch.ops.aten._to_copy.default) -def generic_to_copy(func, args, kwargs): - """Device/dtype transfer operation - handles .to(device) calls.""" - qt = args[0] - if isinstance(qt, QuantizedTensor): - return _handle_device_transfer( - qt, - target_device=kwargs.get('device', None), - target_dtype=kwargs.get('dtype', None), - op_name="_to_copy" - ) - return func(*args, **kwargs) - - 
-@register_generic_util(torch.ops.aten.to.dtype_layout) -def generic_to_dtype_layout(func, args, kwargs): - """Handle .to(device) calls using the dtype_layout variant.""" - qt = args[0] - if isinstance(qt, QuantizedTensor): - return _handle_device_transfer( - qt, - target_device=kwargs.get('device', None), - target_dtype=kwargs.get('dtype', None), - target_layout=kwargs.get('layout', None), - op_name="to" - ) - return func(*args, **kwargs) - - -@register_generic_util(torch.ops.aten.copy_.default) -def generic_copy_(func, args, kwargs): - qt_dest = args[0] - src = args[1] - non_blocking = args[2] if len(args) > 2 else False - if isinstance(qt_dest, QuantizedTensor): - if isinstance(src, QuantizedTensor): - # Copy from another quantized tensor - qt_dest._qdata.copy_(src._qdata, non_blocking=non_blocking) - qt_dest._layout_type = src._layout_type - orig_dtype = qt_dest._layout_params["orig_dtype"] - _copy_layout_params_inplace(src._layout_params, qt_dest._layout_params, non_blocking=non_blocking) - qt_dest._layout_params["orig_dtype"] = orig_dtype - else: - # Copy from regular tensor - just copy raw data - qt_dest._qdata.copy_(src) - return qt_dest - return func(*args, **kwargs) - - -@register_generic_util(torch.ops.aten.to.dtype) -def generic_to_dtype(func, args, kwargs): - """Handle .to(dtype) calls - dtype conversion only.""" - src = args[0] - if isinstance(src, QuantizedTensor): - # For dtype-only conversion, just change the orig_dtype, no real cast is needed - target_dtype = args[1] if len(args) > 1 else kwargs.get('dtype') - src._layout_params["orig_dtype"] = target_dtype - return src - return func(*args, **kwargs) - - -@register_generic_util(torch.ops.aten._has_compatible_shallow_copy_type.default) -def generic_has_compatible_shallow_copy_type(func, args, kwargs): - return True - - -@register_generic_util(torch.ops.aten.empty_like.default) -def generic_empty_like(func, args, kwargs): - """Empty_like operation - creates an empty tensor with the same quantized structure.""" - qt = args[0] - if isinstance(qt, QuantizedTensor): - # Create empty tensor with same shape and dtype as the quantized data - hp_dtype = kwargs.pop('dtype', qt._layout_params["orig_dtype"]) - new_qdata = torch.empty_like(qt._qdata, **kwargs) - - # Handle device transfer for layout params - target_device = kwargs.get('device', new_qdata.device) - new_params = _move_layout_params_to_device(qt._layout_params, target_device) - - # Update orig_dtype if dtype is specified - new_params['orig_dtype'] = hp_dtype - - return QuantizedTensor(new_qdata, qt._layout_type, new_params) - return func(*args, **kwargs) - -# ============================================================================== -# FP8 Layout + Operation Handlers -# ============================================================================== -class TensorCoreFP8Layout(QuantizedLayout): - """ - Storage format: - - qdata: FP8 tensor (torch.float8_e4m3fn or torch.float8_e5m2) - - scale: Scalar tensor (float32) for dequantization - - orig_dtype: Original dtype before quantization (for casting back) - """ @classmethod - def quantize(cls, tensor, scale=None, dtype=torch.float8_e4m3fn, stochastic_rounding=0, inplace_ops=False): + def quantize(cls, tensor, scale=None, stochastic_rounding=0, inplace_ops=False): + if cls.FP8_DTYPE is None: + raise NotImplementedError(f"{cls.__name__} must define FP8_DTYPE") + orig_dtype = tensor.dtype + orig_shape = tuple(tensor.shape) if isinstance(scale, str) and scale == "recalculate": - scale = 
torch.amax(tensor.abs()).to(dtype=torch.float32) / torch.finfo(dtype).max + scale = torch.amax(tensor.abs()).to(dtype=torch.float32) / torch.finfo(cls.FP8_DTYPE).max if tensor.dtype not in [torch.float32, torch.bfloat16]: # Prevent scale from being too small tensor_info = torch.finfo(tensor.dtype) scale = (1.0 / torch.clamp((1.0 / scale), min=tensor_info.min, max=tensor_info.max)) - if scale is not None: - if not isinstance(scale, torch.Tensor): - scale = torch.tensor(scale) - scale = scale.to(device=tensor.device, dtype=torch.float32) + if scale is None: + scale = torch.ones((), device=tensor.device, dtype=torch.float32) + if not isinstance(scale, torch.Tensor): + scale = torch.tensor(scale, device=tensor.device, dtype=torch.float32) + if stochastic_rounding > 0: if inplace_ops: tensor *= (1.0 / scale).to(tensor.dtype) else: tensor = tensor * (1.0 / scale).to(tensor.dtype) + qdata = comfy.float.stochastic_rounding(tensor, dtype=cls.FP8_DTYPE, seed=stochastic_rounding) else: - scale = torch.ones((), device=tensor.device, dtype=torch.float32) + qdata = ck.quantize_per_tensor_fp8(tensor, scale, cls.FP8_DTYPE) - if stochastic_rounding > 0: - tensor = comfy.float.stochastic_rounding(tensor, dtype=dtype, seed=stochastic_rounding) - else: - lp_amax = torch.finfo(dtype).max - torch.clamp(tensor, min=-lp_amax, max=lp_amax, out=tensor) - tensor = tensor.to(dtype, memory_format=torch.contiguous_format) + params = cls.Params(scale=scale.float(), orig_dtype=orig_dtype, orig_shape=orig_shape) + return qdata, params - layout_params = { - 'scale': scale, - 'orig_dtype': orig_dtype - } - return tensor, layout_params - @staticmethod - def dequantize(qdata, scale, orig_dtype, **kwargs): - plain_tensor = torch.ops.aten._to_copy.default(qdata, dtype=orig_dtype) - plain_tensor.mul_(scale) - return plain_tensor +class TensorCoreFP8E4M3Layout(_TensorCoreFP8LayoutBase): + FP8_DTYPE = torch.float8_e4m3fn - @classmethod - def get_plain_tensors(cls, qtensor): - return qtensor._qdata, qtensor._layout_params['scale'] + +class TensorCoreFP8E5M2Layout(_TensorCoreFP8LayoutBase): + FP8_DTYPE = torch.float8_e5m2 + + +# Backward compatibility alias - default to E4M3 +TensorCoreFP8Layout = TensorCoreFP8E4M3Layout + + +# ============================================================================== +# Registry +# ============================================================================== + +register_layout_class("TensorCoreFP8Layout", TensorCoreFP8Layout) +register_layout_class("TensorCoreFP8E4M3Layout", TensorCoreFP8E4M3Layout) +register_layout_class("TensorCoreFP8E5M2Layout", TensorCoreFP8E5M2Layout) +register_layout_class("TensorCoreNVFP4Layout", TensorCoreNVFP4Layout) QUANT_ALGOS = { "float8_e4m3fn": { "storage_t": torch.float8_e4m3fn, "parameters": {"weight_scale", "input_scale"}, - "comfy_tensor_layout": "TensorCoreFP8Layout", + "comfy_tensor_layout": "TensorCoreFP8E4M3Layout", + }, + "float8_e5m2": { + "storage_t": torch.float8_e5m2, + "parameters": {"weight_scale", "input_scale"}, + "comfy_tensor_layout": "TensorCoreFP8E5M2Layout", + }, + "nvfp4": { + "storage_t": torch.uint8, + "parameters": {"weight_scale", "weight_scale_2", "input_scale"}, + "comfy_tensor_layout": "TensorCoreNVFP4Layout", + "group_size": 16, }, } -LAYOUTS = { - "TensorCoreFP8Layout": TensorCoreFP8Layout, -} +# ============================================================================== +# Re-exports for backward compatibility +# ============================================================================== 
-@register_layout_op(torch.ops.aten.linear.default, "TensorCoreFP8Layout") -def fp8_linear(func, args, kwargs): - input_tensor = args[0] - weight = args[1] - bias = args[2] if len(args) > 2 else None - - if isinstance(input_tensor, QuantizedTensor) and isinstance(weight, QuantizedTensor): - plain_input, scale_a = TensorCoreFP8Layout.get_plain_tensors(input_tensor) - plain_weight, scale_b = TensorCoreFP8Layout.get_plain_tensors(weight) - - out_dtype = kwargs.get("out_dtype") - if out_dtype is None: - out_dtype = input_tensor._layout_params['orig_dtype'] - - weight_t = plain_weight.t() - - tensor_2d = False - if len(plain_input.shape) == 2: - tensor_2d = True - plain_input = plain_input.unsqueeze(1) - - input_shape = plain_input.shape - if len(input_shape) != 3: - return None - - try: - output = torch._scaled_mm( - plain_input.reshape(-1, input_shape[2]).contiguous(), - weight_t, - bias=bias, - scale_a=scale_a, - scale_b=scale_b, - out_dtype=out_dtype, - ) - - if isinstance(output, tuple): # TODO: remove when we drop support for torch 2.4 - output = output[0] - - if not tensor_2d: - output = output.reshape((-1, input_shape[1], weight.shape[0])) - - if output.dtype in [torch.float8_e4m3fn, torch.float8_e5m2]: - output_scale = scale_a * scale_b - output_params = { - 'scale': output_scale, - 'orig_dtype': input_tensor._layout_params['orig_dtype'] - } - return QuantizedTensor(output, "TensorCoreFP8Layout", output_params) - else: - return output - - except Exception as e: - raise RuntimeError(f"FP8 _scaled_mm failed, falling back to dequantization: {e}") - - # Case 2: DQ Fallback - if isinstance(weight, QuantizedTensor): - weight = weight.dequantize() - if isinstance(input_tensor, QuantizedTensor): - input_tensor = input_tensor.dequantize() - - return torch.nn.functional.linear(input_tensor, weight, bias) - -def fp8_mm_(input_tensor, weight, bias=None, out_dtype=None): - if out_dtype is None: - out_dtype = input_tensor._layout_params['orig_dtype'] - - plain_input, scale_a = TensorCoreFP8Layout.get_plain_tensors(input_tensor) - plain_weight, scale_b = TensorCoreFP8Layout.get_plain_tensors(weight) - - output = torch._scaled_mm( - plain_input.contiguous(), - plain_weight, - bias=bias, - scale_a=scale_a, - scale_b=scale_b, - out_dtype=out_dtype, - ) - - if isinstance(output, tuple): # TODO: remove when we drop support for torch 2.4 - output = output[0] - return output - -@register_layout_op(torch.ops.aten.addmm.default, "TensorCoreFP8Layout") -def fp8_addmm(func, args, kwargs): - input_tensor = args[1] - weight = args[2] - bias = args[0] - - if isinstance(input_tensor, QuantizedTensor) and isinstance(weight, QuantizedTensor): - return fp8_mm_(input_tensor, weight, bias=bias, out_dtype=kwargs.get("out_dtype", None)) - - a = list(args) - if isinstance(args[0], QuantizedTensor): - a[0] = args[0].dequantize() - if isinstance(args[1], QuantizedTensor): - a[1] = args[1].dequantize() - if isinstance(args[2], QuantizedTensor): - a[2] = args[2].dequantize() - - return func(*a, **kwargs) - -@register_layout_op(torch.ops.aten.mm.default, "TensorCoreFP8Layout") -def fp8_mm(func, args, kwargs): - input_tensor = args[0] - weight = args[1] - - if isinstance(input_tensor, QuantizedTensor) and isinstance(weight, QuantizedTensor): - return fp8_mm_(input_tensor, weight, bias=None, out_dtype=kwargs.get("out_dtype", None)) - - a = list(args) - if isinstance(args[0], QuantizedTensor): - a[0] = args[0].dequantize() - if isinstance(args[1], QuantizedTensor): - a[1] = args[1].dequantize() - return func(*a, **kwargs) - 
-@register_layout_op(torch.ops.aten.view.default, "TensorCoreFP8Layout") -@register_layout_op(torch.ops.aten.t.default, "TensorCoreFP8Layout") -def fp8_func(func, args, kwargs): - input_tensor = args[0] - if isinstance(input_tensor, QuantizedTensor): - plain_input, scale_a = TensorCoreFP8Layout.get_plain_tensors(input_tensor) - ar = list(args) - ar[0] = plain_input - return QuantizedTensor(func(*ar, **kwargs), "TensorCoreFP8Layout", input_tensor._layout_params) - return func(*args, **kwargs) +__all__ = [ + "QuantizedTensor", + "QuantizedLayout", + "TensorCoreFP8Layout", + "TensorCoreFP8E4M3Layout", + "TensorCoreFP8E5M2Layout", + "TensorCoreNVFP4Layout", + "QUANT_ALGOS", + "register_layout_op", +] diff --git a/requirements.txt b/requirements.txt index 3a05799eb..0ee152032 100644 --- a/requirements.txt +++ b/requirements.txt @@ -21,6 +21,7 @@ psutil alembic SQLAlchemy av>=14.2.0 +comfy-kitchen>=0.2.0 #non essential dependencies: kornia>=0.7.1 diff --git a/tests-unit/comfy_quant/test_mixed_precision.py b/tests-unit/comfy_quant/test_mixed_precision.py index 3a54941e6..7b2eac940 100644 --- a/tests-unit/comfy_quant/test_mixed_precision.py +++ b/tests-unit/comfy_quant/test_mixed_precision.py @@ -103,18 +103,18 @@ class TestMixedPrecisionOps(unittest.TestCase): # Verify weights are wrapped in QuantizedTensor self.assertIsInstance(model.layer1.weight, QuantizedTensor) - self.assertEqual(model.layer1.weight._layout_type, "TensorCoreFP8Layout") + self.assertEqual(model.layer1.weight._layout_cls, "TensorCoreFP8E4M3Layout") # Layer 2 should NOT be quantized self.assertNotIsInstance(model.layer2.weight, QuantizedTensor) # Layer 3 should be quantized self.assertIsInstance(model.layer3.weight, QuantizedTensor) - self.assertEqual(model.layer3.weight._layout_type, "TensorCoreFP8Layout") + self.assertEqual(model.layer3.weight._layout_cls, "TensorCoreFP8E4M3Layout") # Verify scales were loaded - self.assertEqual(model.layer1.weight._layout_params['scale'].item(), 2.0) - self.assertEqual(model.layer3.weight._layout_params['scale'].item(), 1.5) + self.assertEqual(model.layer1.weight._params.scale.item(), 2.0) + self.assertEqual(model.layer3.weight._params.scale.item(), 1.5) # Forward pass input_tensor = torch.randn(5, 10, dtype=torch.bfloat16) @@ -154,8 +154,8 @@ class TestMixedPrecisionOps(unittest.TestCase): # Verify layer1.weight is a QuantizedTensor with scale preserved self.assertIsInstance(state_dict2["layer1.weight"], QuantizedTensor) - self.assertEqual(state_dict2["layer1.weight"]._layout_params['scale'].item(), 3.0) - self.assertEqual(state_dict2["layer1.weight"]._layout_type, "TensorCoreFP8Layout") + self.assertEqual(state_dict2["layer1.weight"]._params.scale.item(), 3.0) + self.assertEqual(state_dict2["layer1.weight"]._layout_cls, "TensorCoreFP8E4M3Layout") # Verify non-quantized layers are standard tensors self.assertNotIsInstance(state_dict2["layer2.weight"], QuantizedTensor) diff --git a/tests-unit/comfy_quant/test_quant_registry.py b/tests-unit/comfy_quant/test_quant_registry.py deleted file mode 100644 index 9cb54ede8..000000000 --- a/tests-unit/comfy_quant/test_quant_registry.py +++ /dev/null @@ -1,190 +0,0 @@ -import unittest -import torch -import sys -import os - -# Add comfy to path -sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..")) - -def has_gpu(): - return torch.cuda.is_available() - -from comfy.cli_args import args -if not has_gpu(): - args.cpu = True - -from comfy.quant_ops import QuantizedTensor, TensorCoreFP8Layout - - -class 
TestQuantizedTensor(unittest.TestCase): - """Test the QuantizedTensor subclass with FP8 layout""" - - def test_creation(self): - """Test creating a QuantizedTensor with TensorCoreFP8Layout""" - fp8_data = torch.randn(256, 128, dtype=torch.float32).to(torch.float8_e4m3fn) - scale = torch.tensor(2.0) - layout_params = {'scale': scale, 'orig_dtype': torch.bfloat16} - - qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params) - - self.assertIsInstance(qt, QuantizedTensor) - self.assertEqual(qt.shape, (256, 128)) - self.assertEqual(qt.dtype, torch.float8_e4m3fn) - self.assertEqual(qt._layout_params['scale'], scale) - self.assertEqual(qt._layout_params['orig_dtype'], torch.bfloat16) - self.assertEqual(qt._layout_type, "TensorCoreFP8Layout") - - def test_dequantize(self): - """Test explicit dequantization""" - - fp8_data = torch.ones(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn) - scale = torch.tensor(3.0) - layout_params = {'scale': scale, 'orig_dtype': torch.float32} - - qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params) - dequantized = qt.dequantize() - - self.assertEqual(dequantized.dtype, torch.float32) - self.assertTrue(torch.allclose(dequantized, torch.ones(10, 20) * 3.0, rtol=0.1)) - - def test_from_float(self): - """Test creating QuantizedTensor from float tensor""" - float_tensor = torch.randn(64, 32, dtype=torch.float32) - scale = torch.tensor(1.5) - - qt = QuantizedTensor.from_float( - float_tensor, - "TensorCoreFP8Layout", - scale=scale, - dtype=torch.float8_e4m3fn - ) - - self.assertIsInstance(qt, QuantizedTensor) - self.assertEqual(qt.dtype, torch.float8_e4m3fn) - self.assertEqual(qt.shape, (64, 32)) - - # Verify dequantization gives approximately original values - dequantized = qt.dequantize() - mean_rel_error = ((dequantized - float_tensor).abs() / (float_tensor.abs() + 1e-6)).mean() - self.assertLess(mean_rel_error, 0.1) - - -class TestGenericUtilities(unittest.TestCase): - """Test generic utility operations""" - - def test_detach(self): - """Test detach operation on quantized tensor""" - fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn) - scale = torch.tensor(1.5) - layout_params = {'scale': scale, 'orig_dtype': torch.float32} - qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params) - - # Detach should return a new QuantizedTensor - qt_detached = qt.detach() - - self.assertIsInstance(qt_detached, QuantizedTensor) - self.assertEqual(qt_detached.shape, qt.shape) - self.assertEqual(qt_detached._layout_type, "TensorCoreFP8Layout") - - def test_clone(self): - """Test clone operation on quantized tensor""" - fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn) - scale = torch.tensor(1.5) - layout_params = {'scale': scale, 'orig_dtype': torch.float32} - qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params) - - # Clone should return a new QuantizedTensor - qt_cloned = qt.clone() - - self.assertIsInstance(qt_cloned, QuantizedTensor) - self.assertEqual(qt_cloned.shape, qt.shape) - self.assertEqual(qt_cloned._layout_type, "TensorCoreFP8Layout") - - # Verify it's a deep copy - self.assertIsNot(qt_cloned._qdata, qt._qdata) - - @unittest.skipUnless(has_gpu(), "GPU not available") - def test_to_device(self): - """Test device transfer""" - fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn) - scale = torch.tensor(1.5) - layout_params = {'scale': scale, 'orig_dtype': torch.float32} - qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params) - - # 
Moving to same device should work (CPU to CPU) - qt_cpu = qt.to('cpu') - - self.assertIsInstance(qt_cpu, QuantizedTensor) - self.assertEqual(qt_cpu.device.type, 'cpu') - self.assertEqual(qt_cpu._layout_params['scale'].device.type, 'cpu') - - -class TestTensorCoreFP8Layout(unittest.TestCase): - """Test the TensorCoreFP8Layout implementation""" - - def test_quantize(self): - """Test quantization method""" - float_tensor = torch.randn(32, 64, dtype=torch.float32) - scale = torch.tensor(1.5) - - qdata, layout_params = TensorCoreFP8Layout.quantize( - float_tensor, - scale=scale, - dtype=torch.float8_e4m3fn - ) - - self.assertEqual(qdata.dtype, torch.float8_e4m3fn) - self.assertEqual(qdata.shape, float_tensor.shape) - self.assertIn('scale', layout_params) - self.assertIn('orig_dtype', layout_params) - self.assertEqual(layout_params['orig_dtype'], torch.float32) - - def test_dequantize(self): - """Test dequantization method""" - float_tensor = torch.ones(10, 20, dtype=torch.float32) * 3.0 - scale = torch.tensor(1.0) - - qdata, layout_params = TensorCoreFP8Layout.quantize( - float_tensor, - scale=scale, - dtype=torch.float8_e4m3fn - ) - - dequantized = TensorCoreFP8Layout.dequantize(qdata, **layout_params) - - # Should approximately match original - self.assertTrue(torch.allclose(dequantized, float_tensor, rtol=0.1, atol=0.1)) - - -class TestFallbackMechanism(unittest.TestCase): - """Test fallback for unsupported operations""" - - def test_unsupported_op_dequantizes(self): - """Test that unsupported operations fall back to dequantization""" - # Set seed for reproducibility - torch.manual_seed(42) - - # Create quantized tensor - a_fp32 = torch.randn(10, 20, dtype=torch.float32) - scale = torch.tensor(1.0) - a_q = QuantizedTensor.from_float( - a_fp32, - "TensorCoreFP8Layout", - scale=scale, - dtype=torch.float8_e4m3fn - ) - - # Call an operation that doesn't have a registered handler - # For example, torch.abs - result = torch.abs(a_q) - - # Should work via fallback (dequantize → abs → return) - self.assertNotIsInstance(result, QuantizedTensor) - expected = torch.abs(a_fp32) - # FP8 introduces quantization error, so use loose tolerance - mean_error = (result - expected).abs().mean() - self.assertLess(mean_error, 0.05, f"Mean error {mean_error:.4f} is too large") - - -if __name__ == "__main__": - unittest.main() From 6ef85c49151cf8c4d6bf5e7ccfc566b8d0681cbd Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 5 Jan 2026 19:50:35 -0800 Subject: [PATCH 131/148] Use rope functions from comfy kitchen. 
(#11647) --- comfy/ldm/flux/math.py | 23 +++++++++++++++-------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/comfy/ldm/flux/math.py b/comfy/ldm/flux/math.py index 6a22df8bc..f9597de5b 100644 --- a/comfy/ldm/flux/math.py +++ b/comfy/ldm/flux/math.py @@ -4,6 +4,7 @@ from torch import Tensor from comfy.ldm.modules.attention import optimized_attention import comfy.model_management +import logging def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None, transformer_options={}) -> Tensor: @@ -13,7 +14,6 @@ def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None, transforme x = optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask, transformer_options=transformer_options) return x - def rope(pos: Tensor, dim: int, theta: int) -> Tensor: assert dim % 2 == 0 if comfy.model_management.is_device_mps(pos.device) or comfy.model_management.is_intel_xpu() or comfy.model_management.is_directml_enabled(): @@ -28,13 +28,20 @@ def rope(pos: Tensor, dim: int, theta: int) -> Tensor: out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2) return out.to(dtype=torch.float32, device=pos.device) -def apply_rope1(x: Tensor, freqs_cis: Tensor): - x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2) - x_out = freqs_cis[..., 0] * x_[..., 0] - x_out.addcmul_(freqs_cis[..., 1], x_[..., 1]) +try: + import comfy.quant_ops + apply_rope = comfy.quant_ops.ck.apply_rope + apply_rope1 = comfy.quant_ops.ck.apply_rope1 +except: + logging.warning("No comfy kitchen, using old apply_rope functions.") + def apply_rope1(x: Tensor, freqs_cis: Tensor): + x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2) - return x_out.reshape(*x.shape).type_as(x) + x_out = freqs_cis[..., 0] * x_[..., 0] + x_out.addcmul_(freqs_cis[..., 1], x_[..., 1]) -def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor): - return apply_rope1(xq, freqs_cis), apply_rope1(xk, freqs_cis) + return x_out.reshape(*x.shape).type_as(x) + + def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor): + return apply_rope1(xq, freqs_cis), apply_rope1(xk, freqs_cis) From 161800241117fae7af90e0c938d0cf8cb2f2ddb1 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 5 Jan 2026 20:07:39 -0800 Subject: [PATCH 132/148] Revert "Use rope functions from comfy kitchen. (#11647)" (#11648) This reverts commit 6ef85c49151cf8c4d6bf5e7ccfc566b8d0681cbd. 
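For context, the change being reverted here (and re-landed later in this series) is a guarded import: use the comfy_kitchen rope kernel when it resolves, otherwise keep the pure-PyTorch path already in flux/math.py. A minimal sketch of that fallback pattern, assuming only the comfy.quant_ops.ck.apply_rope1 attribute shown in the diffs; the helper name _apply_rope1_fallback is illustrative, not part of any patch:

    import logging
    from torch import Tensor

    def _apply_rope1_fallback(x: Tensor, freqs_cis: Tensor) -> Tensor:
        # Pure-PyTorch path, same math as the implementation kept in comfy/ldm/flux/math.py.
        x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2)
        x_out = freqs_cis[..., 0] * x_[..., 0]
        x_out.addcmul_(freqs_cis[..., 1], x_[..., 1])
        return x_out.reshape(*x.shape).type_as(x)

    try:
        import comfy.quant_ops
        apply_rope1 = comfy.quant_ops.ck.apply_rope1  # comfy_kitchen kernel, present only if the import in quant_ops succeeded
    except Exception:
        logging.warning("No comfy kitchen, using old apply_rope functions.")
        apply_rope1 = _apply_rope1_fallback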
--- comfy/ldm/flux/math.py | 23 ++++++++--------------- 1 file changed, 8 insertions(+), 15 deletions(-) diff --git a/comfy/ldm/flux/math.py b/comfy/ldm/flux/math.py index f9597de5b..6a22df8bc 100644 --- a/comfy/ldm/flux/math.py +++ b/comfy/ldm/flux/math.py @@ -4,7 +4,6 @@ from torch import Tensor from comfy.ldm.modules.attention import optimized_attention import comfy.model_management -import logging def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None, transformer_options={}) -> Tensor: @@ -14,6 +13,7 @@ def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None, transforme x = optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask, transformer_options=transformer_options) return x + def rope(pos: Tensor, dim: int, theta: int) -> Tensor: assert dim % 2 == 0 if comfy.model_management.is_device_mps(pos.device) or comfy.model_management.is_intel_xpu() or comfy.model_management.is_directml_enabled(): @@ -28,20 +28,13 @@ def rope(pos: Tensor, dim: int, theta: int) -> Tensor: out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2) return out.to(dtype=torch.float32, device=pos.device) +def apply_rope1(x: Tensor, freqs_cis: Tensor): + x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2) -try: - import comfy.quant_ops - apply_rope = comfy.quant_ops.ck.apply_rope - apply_rope1 = comfy.quant_ops.ck.apply_rope1 -except: - logging.warning("No comfy kitchen, using old apply_rope functions.") - def apply_rope1(x: Tensor, freqs_cis: Tensor): - x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2) + x_out = freqs_cis[..., 0] * x_[..., 0] + x_out.addcmul_(freqs_cis[..., 1], x_[..., 1]) - x_out = freqs_cis[..., 0] * x_[..., 0] - x_out.addcmul_(freqs_cis[..., 1], x_[..., 1]) + return x_out.reshape(*x.shape).type_as(x) - return x_out.reshape(*x.shape).type_as(x) - - def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor): - return apply_rope1(xq, freqs_cis), apply_rope1(xk, freqs_cis) +def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor): + return apply_rope1(xq, freqs_cis), apply_rope1(xk, freqs_cis) From e14f3b661069971163ddc56036b0f486933b9162 Mon Sep 17 00:00:00 2001 From: ComfyUI Wiki Date: Tue, 6 Jan 2026 14:37:11 +0800 Subject: [PATCH 133/148] chore: update workflow templates to v0.7.66 (#11652) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 0ee152032..9c9c0e29e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.35.9 -comfyui-workflow-templates==0.7.65 +comfyui-workflow-templates==0.7.66 comfyui-embedded-docs==0.3.1 torch torchsde From 96e0d0924e027248733bc6e0b8102dcdc8acde33 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 6 Jan 2026 11:43:24 -0800 Subject: [PATCH 134/148] Add helpful message to portable. 
(#11671) --- .../advanced/run_nvidia_gpu_disable_api_nodes.bat | 2 +- .ci/windows_nvidia_base_files/run_nvidia_gpu.bat | 2 +- .../run_nvidia_gpu_fast_fp16_accumulation.bat | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.ci/windows_nvidia_base_files/advanced/run_nvidia_gpu_disable_api_nodes.bat b/.ci/windows_nvidia_base_files/advanced/run_nvidia_gpu_disable_api_nodes.bat index ed00583b6..4501ef9a1 100644 --- a/.ci/windows_nvidia_base_files/advanced/run_nvidia_gpu_disable_api_nodes.bat +++ b/.ci/windows_nvidia_base_files/advanced/run_nvidia_gpu_disable_api_nodes.bat @@ -1,3 +1,3 @@ ..\python_embeded\python.exe -s ..\ComfyUI\main.py --windows-standalone-build --disable-api-nodes -echo If you see this and ComfyUI did not start try updating your Nvidia Drivers to the latest. +echo If you see this and ComfyUI did not start try updating your Nvidia Drivers to the latest. If you get a c10.dll error you need to install vc redist that you can find: https://aka.ms/vc14/vc_redist.x64.exe pause diff --git a/.ci/windows_nvidia_base_files/run_nvidia_gpu.bat b/.ci/windows_nvidia_base_files/run_nvidia_gpu.bat index 4898a424f..6487ac7ce 100755 --- a/.ci/windows_nvidia_base_files/run_nvidia_gpu.bat +++ b/.ci/windows_nvidia_base_files/run_nvidia_gpu.bat @@ -1,3 +1,3 @@ .\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build -echo If you see this and ComfyUI did not start try updating your Nvidia Drivers to the latest. +echo If you see this and ComfyUI did not start try updating your Nvidia Drivers to the latest. If you get a c10.dll error you need to install vc redist that you can find: https://aka.ms/vc14/vc_redist.x64.exe pause diff --git a/.ci/windows_nvidia_base_files/run_nvidia_gpu_fast_fp16_accumulation.bat b/.ci/windows_nvidia_base_files/run_nvidia_gpu_fast_fp16_accumulation.bat index 32611e4af..01c5bb33b 100644 --- a/.ci/windows_nvidia_base_files/run_nvidia_gpu_fast_fp16_accumulation.bat +++ b/.ci/windows_nvidia_base_files/run_nvidia_gpu_fast_fp16_accumulation.bat @@ -1,3 +1,3 @@ .\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build --fast fp16_accumulation -echo If you see this and ComfyUI did not start try updating your Nvidia Drivers to the latest. +echo If you see this and ComfyUI did not start try updating your Nvidia Drivers to the latest. If you get a c10.dll error you need to install vc redist that you can find: https://aka.ms/vc14/vc_redist.x64.exe pause From 6ffc159bdd56d1ad73e954081def6a7f163e7a7f Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 6 Jan 2026 12:53:43 -0800 Subject: [PATCH 135/148] Update comfy-kitchen version to 0.2.1 (#11672) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 9c9c0e29e..22cb50e2d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -21,7 +21,7 @@ psutil alembic SQLAlchemy av>=14.2.0 -comfy-kitchen>=0.2.0 +comfy-kitchen>=0.2.1 #non essential dependencies: kornia>=0.7.1 From c3c3e93c5bb3034175c17ef8beeb8fe8626c66ab Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 6 Jan 2026 13:57:50 -0800 Subject: [PATCH 136/148] Use rope functions from comfy kitchen. 
(#11674) --- comfy/ldm/flux/math.py | 23 +++++++++++++++-------- requirements.txt | 2 +- 2 files changed, 16 insertions(+), 9 deletions(-) diff --git a/comfy/ldm/flux/math.py b/comfy/ldm/flux/math.py index 6a22df8bc..f9597de5b 100644 --- a/comfy/ldm/flux/math.py +++ b/comfy/ldm/flux/math.py @@ -4,6 +4,7 @@ from torch import Tensor from comfy.ldm.modules.attention import optimized_attention import comfy.model_management +import logging def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None, transformer_options={}) -> Tensor: @@ -13,7 +14,6 @@ def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None, transforme x = optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask, transformer_options=transformer_options) return x - def rope(pos: Tensor, dim: int, theta: int) -> Tensor: assert dim % 2 == 0 if comfy.model_management.is_device_mps(pos.device) or comfy.model_management.is_intel_xpu() or comfy.model_management.is_directml_enabled(): @@ -28,13 +28,20 @@ def rope(pos: Tensor, dim: int, theta: int) -> Tensor: out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2) return out.to(dtype=torch.float32, device=pos.device) -def apply_rope1(x: Tensor, freqs_cis: Tensor): - x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2) - x_out = freqs_cis[..., 0] * x_[..., 0] - x_out.addcmul_(freqs_cis[..., 1], x_[..., 1]) +try: + import comfy.quant_ops + apply_rope = comfy.quant_ops.ck.apply_rope + apply_rope1 = comfy.quant_ops.ck.apply_rope1 +except: + logging.warning("No comfy kitchen, using old apply_rope functions.") + def apply_rope1(x: Tensor, freqs_cis: Tensor): + x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2) - return x_out.reshape(*x.shape).type_as(x) + x_out = freqs_cis[..., 0] * x_[..., 0] + x_out.addcmul_(freqs_cis[..., 1], x_[..., 1]) -def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor): - return apply_rope1(xq, freqs_cis), apply_rope1(xk, freqs_cis) + return x_out.reshape(*x.shape).type_as(x) + + def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor): + return apply_rope1(xq, freqs_cis), apply_rope1(xk, freqs_cis) diff --git a/requirements.txt b/requirements.txt index 22cb50e2d..7798cb179 100644 --- a/requirements.txt +++ b/requirements.txt @@ -21,7 +21,7 @@ psutil alembic SQLAlchemy av>=14.2.0 -comfy-kitchen>=0.2.1 +comfy-kitchen>=0.2.2 #non essential dependencies: kornia>=0.7.1 From c3566c0d765200068d26d0888f035504a50012f2 Mon Sep 17 00:00:00 2001 From: ComfyUI Wiki Date: Wed, 7 Jan 2026 06:28:29 +0800 Subject: [PATCH 137/148] chore: update workflow templates to v0.7.67 (#11667) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 7798cb179..caad0026a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.35.9 -comfyui-workflow-templates==0.7.66 +comfyui-workflow-templates==0.7.67 comfyui-embedded-docs==0.3.1 torch torchsde From 023cf13721cac256c323e2226319b766d07b1f36 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 6 Jan 2026 14:33:03 -0800 Subject: [PATCH 138/148] Fix lowvram issue with ltxv2 text encoder. 
(#11675) --- comfy/ldm/lightricks/embeddings_connector.py | 2 +- comfy/text_encoders/lt.py | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/comfy/ldm/lightricks/embeddings_connector.py b/comfy/ldm/lightricks/embeddings_connector.py index f7a43f3c3..06f5ada89 100644 --- a/comfy/ldm/lightricks/embeddings_connector.py +++ b/comfy/ldm/lightricks/embeddings_connector.py @@ -276,7 +276,7 @@ class Embeddings1DConnector(nn.Module): max(1024, hidden_states.shape[1]) / self.num_learnable_registers ) learnable_registers = torch.tile( - self.learnable_registers, (num_registers_duplications, 1) + self.learnable_registers.to(hidden_states), (num_registers_duplications, 1) ) hidden_states = torch.cat((hidden_states, learnable_registers[hidden_states.shape[1]:].unsqueeze(0).repeat(hidden_states.shape[0], 1, 1)), dim=1) diff --git a/comfy/text_encoders/lt.py b/comfy/text_encoders/lt.py index 2c2d453e8..e5964e42b 100644 --- a/comfy/text_encoders/lt.py +++ b/comfy/text_encoders/lt.py @@ -86,17 +86,19 @@ class LTXAVTEModel(torch.nn.Module): ) def set_clip_options(self, options): + self.execution_device = options.get("execution_device", self.execution_device) self.gemma3_12b.set_clip_options(options) def reset_clip_options(self): self.gemma3_12b.reset_clip_options() + self.execution_device = None def encode_token_weights(self, token_weight_pairs): token_weight_pairs = token_weight_pairs["gemma3_12b"] out, pooled, extra = self.gemma3_12b.encode_token_weights(token_weight_pairs) out_device = out.device - out = out.movedim(1, -1).to(self.text_embedding_projection.weight.device) + out = out.movedim(1, -1).to(self.execution_device) out = 8.0 * (out - out.mean(dim=(1, 2), keepdim=True)) / (out.amax(dim=(1, 2), keepdim=True) - out.amin(dim=(1, 2), keepdim=True) + 1e-6) out = out.reshape((out.shape[0], out.shape[1], -1)) out = self.text_embedding_projection(out) From 6e9ee55cdd9e0eca6b5144063575b983f3311762 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 6 Jan 2026 14:41:27 -0800 Subject: [PATCH 139/148] Disable ltxav previews. (#11676) --- comfy/latent_formats.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/comfy/latent_formats.py b/comfy/latent_formats.py index 9bbe30b53..cb4f52ce1 100644 --- a/comfy/latent_formats.py +++ b/comfy/latent_formats.py @@ -408,7 +408,9 @@ class LTXV(LatentFormat): self.latent_rgb_factors_bias = [-0.0571, -0.1657, -0.2512] class LTXAV(LTXV): - pass + def __init__(self): + self.latent_rgb_factors = None + self.latent_rgb_factors_bias = None class HunyuanVideo(LatentFormat): latent_channels = 16 From 2c03884f5fb7fa213161dfe1e9a09a8e8c4b6062 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 6 Jan 2026 15:07:26 -0800 Subject: [PATCH 140/148] Skip fp4 matrix mult on devices that don't support it. 
(#11677) --- comfy/model_management.py | 10 ++++++++++ comfy/ops.py | 21 +++++++++++++++++---- 2 files changed, 27 insertions(+), 4 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 22f4de044..928282092 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -1504,6 +1504,16 @@ def supports_fp8_compute(device=None): return True +def supports_nvfp4_compute(device=None): + if not is_nvidia(): + return False + + props = torch.cuda.get_device_properties(device) + if props.major < 10: + return False + + return True + def extended_fp16_support(): # TODO: check why some models work with fp16 on newer torch versions but not on older if torch_version_numeric < (2, 7): diff --git a/comfy/ops.py b/comfy/ops.py index f5e1e9230..8f9fdce36 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -493,11 +493,12 @@ from .quant_ops import ( ) -def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False): +def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False, disabled=[]): class MixedPrecisionOps(manual_cast): _quant_config = quant_config _compute_dtype = compute_dtype _full_precision_mm = full_precision_mm + _disabled = disabled class Linear(torch.nn.Module, CastWeightBiasOp): def __init__( @@ -522,6 +523,7 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec self.tensor_class = None self._full_precision_mm = MixedPrecisionOps._full_precision_mm + self._full_precision_mm_config = False def reset_parameters(self): return None @@ -556,8 +558,12 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec self.weight = torch.nn.Parameter(weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), requires_grad=False) else: self.quant_format = layer_conf.get("format", None) + self._full_precision_mm_config = layer_conf.get("full_precision_matrix_mult", False) if not self._full_precision_mm: - self._full_precision_mm = layer_conf.get("full_precision_matrix_mult", False) + self._full_precision_mm = self._full_precision_mm_config + + if self.quant_format in MixedPrecisionOps._disabled: + self._full_precision_mm = True if self.quant_format is None: raise ValueError(f"Unknown quantization format for layer {layer_name}") @@ -630,7 +636,7 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec sd["{}weight_scale".format(prefix)] = self.weight._params.block_scale quant_conf = {"format": self.quant_format} - if self._full_precision_mm: + if self._full_precision_mm_config: quant_conf["full_precision_matrix_mult"] = True sd["{}comfy_quant".format(prefix)] = torch.tensor(list(json.dumps(quant_conf).encode('utf-8')), dtype=torch.uint8) return sd @@ -711,10 +717,17 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, model_config=None): fp8_compute = comfy.model_management.supports_fp8_compute(load_device) # TODO: if we support more ops this needs to be more granular + nvfp4_compute = comfy.model_management.supports_nvfp4_compute(load_device) if model_config and hasattr(model_config, 'quant_config') and model_config.quant_config: logging.info("Using mixed precision operations") - return mixed_precision_ops(model_config.quant_config, compute_dtype, full_precision_mm=not fp8_compute) + disabled = set() + if not nvfp4_compute: + disabled.add("nvfp4") + if not fp8_compute: + 
disabled.add("float8_e4m3fn") + disabled.add("float8_e5m2") + return mixed_precision_ops(model_config.quant_config, compute_dtype, disabled=disabled) if ( fp8_compute and From edee33f55ea27a1931475d3ea788fd6e9a81677b Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 6 Jan 2026 19:13:43 -0800 Subject: [PATCH 141/148] Disable comfy kitchen cuda if pytorch cuda less than 13 (#11681) --- comfy/quant_ops.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py index cd737726f..5a17bc6f5 100644 --- a/comfy/quant_ops.py +++ b/comfy/quant_ops.py @@ -13,6 +13,13 @@ try: get_layout_class, ) _CK_AVAILABLE = True + if torch.version.cuda is None: + ck.registry.disable("cuda") + else: + cuda_version = tuple(map(int, str(torch.version.cuda).split('.'))) + if cuda_version < (13,): + ck.registry.disable("cuda") + ck.registry.disable("triton") for k, v in ck.list_backends().items(): logging.info(f"Found comfy_kitchen backend {k}: {v}") From c5cfb34c07048350f472a9a4f1ccbf75a56ed38f Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 6 Jan 2026 20:51:45 -0800 Subject: [PATCH 142/148] Update comfy-kitchen version to 0.2.3 (#11685) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index caad0026a..bc8346bcf 100644 --- a/requirements.txt +++ b/requirements.txt @@ -21,7 +21,7 @@ psutil alembic SQLAlchemy av>=14.2.0 -comfy-kitchen>=0.2.2 +comfy-kitchen>=0.2.3 #non essential dependencies: kornia>=0.7.1 From ce0000c4f2a7dba12324585dddb784b43e3cd3d0 Mon Sep 17 00:00:00 2001 From: Yoland Yan <4950057+yoland68@users.noreply.github.com> Date: Tue, 6 Jan 2026 21:57:31 -0800 Subject: [PATCH 143/148] Force sequential execution in CI test jobs (#11687) Added max-parallel setting to enforce sequential execution in test jobs. --- .github/workflows/test-ci.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/workflows/test-ci.yml b/.github/workflows/test-ci.yml index adfc5dd32..63df2dc3a 100644 --- a/.github/workflows/test-ci.yml +++ b/.github/workflows/test-ci.yml @@ -20,6 +20,7 @@ jobs: test-stable: strategy: fail-fast: false + max-parallel: 1 # This forces sequential execution matrix: # os: [macos, linux, windows] # os: [macos, linux] @@ -74,6 +75,7 @@ jobs: test-unix-nightly: strategy: fail-fast: false + max-parallel: 1 # This forces sequential execution matrix: # os: [macos, linux] os: [linux] From 79e94544bd7ec0cc7a4e5e6167907e7d781c4b76 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Wed, 7 Jan 2026 08:04:50 +0200 Subject: [PATCH 144/148] feat(api-nodes): add WAN2.6 ReferenceToVideo (#11644) --- comfy_api_nodes/nodes_wan.py | 160 +++++++++++++++++++++++++ comfy_api_nodes/util/upload_helpers.py | 2 +- 2 files changed, 161 insertions(+), 1 deletion(-) diff --git a/comfy_api_nodes/nodes_wan.py b/comfy_api_nodes/nodes_wan.py index 1675fd863..3e04786a9 100644 --- a/comfy_api_nodes/nodes_wan.py +++ b/comfy_api_nodes/nodes_wan.py @@ -13,7 +13,9 @@ from comfy_api_nodes.util import ( poll_op, sync_op, tensor_to_base64_string, + upload_video_to_comfyapi, validate_audio_duration, + validate_video_duration, ) @@ -41,6 +43,12 @@ class Image2VideoInputField(BaseModel): audio_url: str | None = Field(None) +class Reference2VideoInputField(BaseModel): + prompt: str = Field(...) 
+ negative_prompt: str | None = Field(None) + reference_video_urls: list[str] = Field(...) + + class Txt2ImageParametersField(BaseModel): size: str = Field(...) n: int = Field(1, description="Number of images to generate.") # we support only value=1 @@ -76,6 +84,14 @@ class Image2VideoParametersField(BaseModel): shot_type: str = Field("single") +class Reference2VideoParametersField(BaseModel): + size: str = Field(...) + duration: int = Field(5, ge=5, le=15) + shot_type: str = Field("single") + seed: int = Field(..., ge=0, le=2147483647) + watermark: bool = Field(False) + + class Text2ImageTaskCreationRequest(BaseModel): model: str = Field(...) input: Text2ImageInputField = Field(...) @@ -100,6 +116,12 @@ class Image2VideoTaskCreationRequest(BaseModel): parameters: Image2VideoParametersField = Field(...) +class Reference2VideoTaskCreationRequest(BaseModel): + model: str = Field(...) + input: Reference2VideoInputField = Field(...) + parameters: Reference2VideoParametersField = Field(...) + + class TaskCreationOutputField(BaseModel): task_id: str = Field(...) task_status: str = Field(...) @@ -721,6 +743,143 @@ class WanImageToVideoApi(IO.ComfyNode): return IO.NodeOutput(await download_url_to_video_output(response.output.video_url)) +class WanReferenceVideoApi(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="WanReferenceVideoApi", + display_name="Wan Reference to Video", + category="api node/video/Wan", + description="Use the character and voice from input videos, combined with a prompt, " + "to generate a new video that maintains character consistency.", + inputs=[ + IO.Combo.Input("model", options=["wan2.6-r2v"]), + IO.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Prompt describing the elements and visual features. Supports English and Chinese. 
" + "Use identifiers such as `character1` and `character2` to refer to the reference characters.", + ), + IO.String.Input( + "negative_prompt", + multiline=True, + default="", + tooltip="Negative prompt describing what to avoid.", + ), + IO.Autogrow.Input( + "reference_videos", + template=IO.Autogrow.TemplateNames( + IO.Video.Input("reference_video"), + names=["character1", "character2", "character3"], + min=1, + ), + ), + IO.Combo.Input( + "size", + options=[ + "720p: 1:1 (960x960)", + "720p: 16:9 (1280x720)", + "720p: 9:16 (720x1280)", + "720p: 4:3 (1088x832)", + "720p: 3:4 (832x1088)", + "1080p: 1:1 (1440x1440)", + "1080p: 16:9 (1920x1080)", + "1080p: 9:16 (1080x1920)", + "1080p: 4:3 (1632x1248)", + "1080p: 3:4 (1248x1632)", + ], + ), + IO.Int.Input( + "duration", + default=5, + min=5, + max=10, + step=5, + display_mode=IO.NumberDisplay.slider, + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + step=1, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + ), + IO.Combo.Input( + "shot_type", + options=["single", "multi"], + tooltip="Specifies the shot type for the generated video, that is, whether the video is a " + "single continuous shot or multiple shots with cuts.", + ), + IO.Boolean.Input( + "watermark", + default=False, + tooltip="Whether to add an AI-generated watermark to the result.", + ), + ], + outputs=[ + IO.Video.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, + model: str, + prompt: str, + negative_prompt: str, + reference_videos: IO.Autogrow.Type, + size: str, + duration: int, + seed: int, + shot_type: str, + watermark: bool, + ): + reference_video_urls = [] + for i in reference_videos: + validate_video_duration(reference_videos[i], min_duration=2, max_duration=30) + for i in reference_videos: + reference_video_urls.append(await upload_video_to_comfyapi(cls, reference_videos[i])) + width, height = RES_IN_PARENS.search(size).groups() + initial_response = await sync_op( + cls, + ApiEndpoint(path="/proxy/wan/api/v1/services/aigc/video-generation/video-synthesis", method="POST"), + response_model=TaskCreationResponse, + data=Reference2VideoTaskCreationRequest( + model=model, + input=Reference2VideoInputField( + prompt=prompt, negative_prompt=negative_prompt, reference_video_urls=reference_video_urls + ), + parameters=Reference2VideoParametersField( + size=f"{width}*{height}", + duration=duration, + shot_type=shot_type, + watermark=watermark, + seed=seed, + ), + ), + ) + if not initial_response.output: + raise Exception(f"An unknown error occurred: {initial_response.code} - {initial_response.message}") + response = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/wan/api/v1/tasks/{initial_response.output.task_id}"), + response_model=VideoTaskStatusResponse, + status_extractor=lambda x: x.output.task_status, + poll_interval=6, + max_poll_attempts=280, + ) + return IO.NodeOutput(await download_url_to_video_output(response.output.video_url)) + + class WanApiExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[IO.ComfyNode]]: @@ -729,6 +888,7 @@ class WanApiExtension(ComfyExtension): WanImageToImageApi, WanTextToVideoApi, WanImageToVideoApi, + WanReferenceVideoApi, ] diff --git a/comfy_api_nodes/util/upload_helpers.py b/comfy_api_nodes/util/upload_helpers.py index b8d33f4d1..f1ed7fe9c 100644 --- a/comfy_api_nodes/util/upload_helpers.py +++ 
b/comfy_api_nodes/util/upload_helpers.py @@ -119,7 +119,7 @@ async def upload_video_to_comfyapi( raise ValueError(f"Could not verify video duration from source: {e}") from e upload_mime_type = f"video/{container.value.lower()}" - filename = f"uploaded_video.{container.value.lower()}" + filename = f"{uuid.uuid4()}.{container.value.lower()}" # Convert VideoInput to BytesIO using specified container/codec video_bytes_io = BytesIO() From b7d7cc1d496afe3c82279eec74c4d47399aab8ea Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 6 Jan 2026 22:39:06 -0800 Subject: [PATCH 145/148] Fix fp8 fast issue. (#11688) --- comfy/ops.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/comfy/ops.py b/comfy/ops.py index 8f9fdce36..cd536e22d 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -427,12 +427,12 @@ def fp8_linear(self, input): input = torch.clamp(input, min=-448, max=448, out=input) input_fp8 = input.to(dtype).contiguous() layout_params_input = TensorCoreFP8Layout.Params(scale=scale_input, orig_dtype=input_dtype, orig_shape=tuple(input_fp8.shape)) - quantized_input = QuantizedTensor(input_fp8, TensorCoreFP8Layout, layout_params_input) + quantized_input = QuantizedTensor(input_fp8, "TensorCoreFP8Layout", layout_params_input) # Wrap weight in QuantizedTensor - this enables unified dispatch # Call F.linear - __torch_dispatch__ routes to fp8_linear handler in quant_ops.py! layout_params_weight = TensorCoreFP8Layout.Params(scale=scale_weight, orig_dtype=input_dtype, orig_shape=tuple(w.shape)) - quantized_weight = QuantizedTensor(w, TensorCoreFP8Layout, layout_params_weight) + quantized_weight = QuantizedTensor(w, "TensorCoreFP8Layout", layout_params_weight) o = torch.nn.functional.linear(quantized_input, quantized_weight, bias) uncast_bias_weight(self, w, bias, offload_stream) From fc0cb10bcbee6e73ed3caf34c27f7bde4559a07f Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 7 Jan 2026 04:07:31 -0500 Subject: [PATCH 146/148] ComfyUI v0.8.0 --- comfyui_version.py | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/comfyui_version.py b/comfyui_version.py index 1ed60fe5c..750673f08 100644 --- a/comfyui_version.py +++ b/comfyui_version.py @@ -1,3 +1,3 @@ # This file is automatically generated by the build process when version is # updated in pyproject.toml. -__version__ = "0.7.0" +__version__ = "0.8.0" diff --git a/pyproject.toml b/pyproject.toml index a7d159be9..951c2c978 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ComfyUI" -version = "0.7.0" +version = "0.8.0" readme = "README.md" license = { file = "LICENSE" } requires-python = ">=3.10" From c0c9720d77774ed2c87981da87189fe1c14a57fa Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 7 Jan 2026 01:48:28 -0800 Subject: [PATCH 147/148] Fix stable release workflow not pulling latest comfy kitchen. 
(#11695) --- .github/workflows/stable-release.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/stable-release.yml b/.github/workflows/stable-release.yml index 28484a9d1..f501b7b31 100644 --- a/.github/workflows/stable-release.yml +++ b/.github/workflows/stable-release.yml @@ -117,7 +117,7 @@ jobs: ./python.exe get-pip.py ./python.exe -s -m pip install ../${{ inputs.cache_tag }}_python_deps/* - grep comfyui ../ComfyUI/requirements.txt > ./requirements_comfyui.txt + grep comfy ../ComfyUI/requirements.txt > ./requirements_comfyui.txt ./python.exe -s -m pip install -r requirements_comfyui.txt rm requirements_comfyui.txt From 3cd7b32f1b7e7e90395cefe7d9f9b1f89276d8ce Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 7 Jan 2026 02:15:14 -0800 Subject: [PATCH 148/148] Support gemma 12B with quant weights. (#11696) --- comfy/text_encoders/lt.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/comfy/text_encoders/lt.py b/comfy/text_encoders/lt.py index e5964e42b..130ebaeae 100644 --- a/comfy/text_encoders/lt.py +++ b/comfy/text_encoders/lt.py @@ -36,10 +36,10 @@ class LTXAVGemmaTokenizer(sd1_clip.SD1Tokenizer): class Gemma3_12BModel(sd1_clip.SDClipModel): def __init__(self, device="cpu", layer="all", layer_idx=None, dtype=None, attention_mask=True, model_options={}): - llama_scaled_fp8 = model_options.get("gemma_scaled_fp8", None) - if llama_scaled_fp8 is not None: + llama_quantization_metadata = model_options.get("llama_quantization_metadata", None) + if llama_quantization_metadata is not None: model_options = model_options.copy() - model_options["scaled_fp8"] = llama_scaled_fp8 + model_options["quantization_metadata"] = llama_quantization_metadata super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma3_12B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options) @@ -119,12 +119,12 @@ class LTXAVTEModel(torch.nn.Module): return self.load_state_dict(sdo, strict=False) -def ltxav_te(dtype_llama=None, llama_scaled_fp8=None): +def ltxav_te(dtype_llama=None, llama_quantization_metadata=None): class LTXAVTEModel_(LTXAVTEModel): def __init__(self, device="cpu", dtype=None, model_options={}): - if llama_scaled_fp8 is not None and "llama_scaled_fp8" not in model_options: + if llama_quantization_metadata is not None: model_options = model_options.copy() - model_options["llama_scaled_fp8"] = llama_scaled_fp8 + model_options["llama_quantization_metadata"] = llama_quantization_metadata if dtype_llama is not None: dtype = dtype_llama super().__init__(dtype_llama=dtype_llama, device=device, dtype=dtype, model_options=model_options)
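
The final patch above (PATCH 148/148) replaces the scaled-fp8-specific option for the Gemma 12B text encoder with a generic quantization-metadata dict threaded through model_options. As a rough illustration of that pass-through pattern only (build_text_encoder and InnerEncoder below are made-up stand-ins for ltxav_te and Gemma3_12BModel, not ComfyUI API), a minimal sketch might look like this:

# Minimal sketch of the model_options pass-through seen in PATCH 148/148.
# InnerEncoder / build_text_encoder are hypothetical names used only for
# illustration; the real code lives in comfy/text_encoders/lt.py.

class InnerEncoder:
    def __init__(self, model_options=None):
        model_options = model_options or {}
        # The inner model only understands the generic key.
        self.quantization_metadata = model_options.get("quantization_metadata")


def build_text_encoder(llama_quantization_metadata=None, model_options=None):
    # Copy before mutating so the caller's dict is left untouched, mirroring
    # (in simplified form) the `model_options = model_options.copy()` plus
    # re-keying that the patch performs before handing the dict down.
    model_options = dict(model_options or {})
    if llama_quantization_metadata is not None:
        model_options["quantization_metadata"] = llama_quantization_metadata
    return InnerEncoder(model_options=model_options)


if __name__ == "__main__":
    enc = build_text_encoder(llama_quantization_metadata={"format": "nvfp4"})
    print(enc.quantization_metadata)  # -> {'format': 'nvfp4'}

The apparent point of the copy-then-rekey step is that outer wrappers can carry an encoder-specific key (here "llama_quantization_metadata") while the shared text-encoder machinery underneath only ever sees the generic "quantization_metadata" key.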