From 5b4d0664c87dc62a8361fe292b0bdac42681aef8 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Fri, 19 Dec 2025 20:02:49 +0200 Subject: [PATCH 001/181] add Flux2MaxImage API Node (#11420) --- comfy_api_nodes/nodes_bfl.py | 68 ++++++++++++++++-------------------- 1 file changed, 31 insertions(+), 37 deletions(-) diff --git a/comfy_api_nodes/nodes_bfl.py b/comfy_api_nodes/nodes_bfl.py index 8826dea0c..ce077d6b3 100644 --- a/comfy_api_nodes/nodes_bfl.py +++ b/comfy_api_nodes/nodes_bfl.py @@ -1,10 +1,8 @@ -from inspect import cleandoc - import torch from pydantic import BaseModel from typing_extensions import override -from comfy_api.latest import IO, ComfyExtension +from comfy_api.latest import IO, ComfyExtension, Input from comfy_api_nodes.apis.bfl_api import ( BFLFluxExpandImageRequest, BFLFluxFillImageRequest, @@ -28,7 +26,7 @@ from comfy_api_nodes.util import ( ) -def convert_mask_to_image(mask: torch.Tensor): +def convert_mask_to_image(mask: Input.Image): """ Make mask have the expected amount of dims (4) and channels (3) to be recognized as an image. """ @@ -38,9 +36,6 @@ def convert_mask_to_image(mask: torch.Tensor): class FluxProUltraImageNode(IO.ComfyNode): - """ - Generates images using Flux Pro 1.1 Ultra via api based on prompt and resolution. - """ @classmethod def define_schema(cls) -> IO.Schema: @@ -48,7 +43,7 @@ class FluxProUltraImageNode(IO.ComfyNode): node_id="FluxProUltraImageNode", display_name="Flux 1.1 [pro] Ultra Image", category="api node/image/BFL", - description=cleandoc(cls.__doc__ or ""), + description="Generates images using Flux Pro 1.1 Ultra via api based on prompt and resolution.", inputs=[ IO.String.Input( "prompt", @@ -117,7 +112,7 @@ class FluxProUltraImageNode(IO.ComfyNode): prompt_upsampling: bool = False, raw: bool = False, seed: int = 0, - image_prompt: torch.Tensor | None = None, + image_prompt: Input.Image | None = None, image_prompt_strength: float = 0.1, ) -> IO.NodeOutput: if image_prompt is None: @@ -155,9 +150,6 @@ class FluxProUltraImageNode(IO.ComfyNode): class FluxKontextProImageNode(IO.ComfyNode): - """ - Edits images using Flux.1 Kontext [pro] via api based on prompt and aspect ratio. - """ @classmethod def define_schema(cls) -> IO.Schema: @@ -165,7 +157,7 @@ class FluxKontextProImageNode(IO.ComfyNode): node_id=cls.NODE_ID, display_name=cls.DISPLAY_NAME, category="api node/image/BFL", - description=cleandoc(cls.__doc__ or ""), + description="Edits images using Flux.1 Kontext [pro] via api based on prompt and aspect ratio.", inputs=[ IO.String.Input( "prompt", @@ -231,7 +223,7 @@ class FluxKontextProImageNode(IO.ComfyNode): aspect_ratio: str, guidance: float, steps: int, - input_image: torch.Tensor | None = None, + input_image: Input.Image | None = None, seed=0, prompt_upsampling=False, ) -> IO.NodeOutput: @@ -271,20 +263,14 @@ class FluxKontextProImageNode(IO.ComfyNode): class FluxKontextMaxImageNode(FluxKontextProImageNode): - """ - Edits images using Flux.1 Kontext [max] via api based on prompt and aspect ratio. - """ - DESCRIPTION = cleandoc(__doc__ or "") + DESCRIPTION = "Edits images using Flux.1 Kontext [max] via api based on prompt and aspect ratio." BFL_PATH = "/proxy/bfl/flux-kontext-max/generate" NODE_ID = "FluxKontextMaxImageNode" DISPLAY_NAME = "Flux.1 Kontext [max] Image" class FluxProExpandNode(IO.ComfyNode): - """ - Outpaints image based on prompt. 
- """ @classmethod def define_schema(cls) -> IO.Schema: @@ -292,7 +278,7 @@ class FluxProExpandNode(IO.ComfyNode): node_id="FluxProExpandNode", display_name="Flux.1 Expand Image", category="api node/image/BFL", - description=cleandoc(cls.__doc__ or ""), + description="Outpaints image based on prompt.", inputs=[ IO.Image.Input("image"), IO.String.Input( @@ -371,7 +357,7 @@ class FluxProExpandNode(IO.ComfyNode): @classmethod async def execute( cls, - image: torch.Tensor, + image: Input.Image, prompt: str, prompt_upsampling: bool, top: int, @@ -418,9 +404,6 @@ class FluxProExpandNode(IO.ComfyNode): class FluxProFillNode(IO.ComfyNode): - """ - Inpaints image based on mask and prompt. - """ @classmethod def define_schema(cls) -> IO.Schema: @@ -428,7 +411,7 @@ class FluxProFillNode(IO.ComfyNode): node_id="FluxProFillNode", display_name="Flux.1 Fill Image", category="api node/image/BFL", - description=cleandoc(cls.__doc__ or ""), + description="Inpaints image based on mask and prompt.", inputs=[ IO.Image.Input("image"), IO.Mask.Input("mask"), @@ -480,8 +463,8 @@ class FluxProFillNode(IO.ComfyNode): @classmethod async def execute( cls, - image: torch.Tensor, - mask: torch.Tensor, + image: Input.Image, + mask: Input.Image, prompt: str, prompt_upsampling: bool, steps: int, @@ -525,11 +508,15 @@ class FluxProFillNode(IO.ComfyNode): class Flux2ProImageNode(IO.ComfyNode): + NODE_ID = "Flux2ProImageNode" + DISPLAY_NAME = "Flux.2 [pro] Image" + API_ENDPOINT = "/proxy/bfl/flux-2-pro/generate" + @classmethod def define_schema(cls) -> IO.Schema: return IO.Schema( - node_id="Flux2ProImageNode", - display_name="Flux.2 [pro] Image", + node_id=cls.NODE_ID, + display_name=cls.DISPLAY_NAME, category="api node/image/BFL", description="Generates images synchronously based on prompt and resolution.", inputs=[ @@ -563,12 +550,11 @@ class Flux2ProImageNode(IO.ComfyNode): ), IO.Boolean.Input( "prompt_upsampling", - default=False, + default=True, tooltip="Whether to perform upsampling on the prompt. 
" - "If active, automatically modifies the prompt for more creative generation, " - "but results are nondeterministic (same seed will not produce exactly the same result).", + "If active, automatically modifies the prompt for more creative generation.", ), - IO.Image.Input("images", optional=True, tooltip="Up to 4 images to be used as references."), + IO.Image.Input("images", optional=True, tooltip="Up to 9 images to be used as references."), ], outputs=[IO.Image.Output()], hidden=[ @@ -587,7 +573,7 @@ class Flux2ProImageNode(IO.ComfyNode): height: int, seed: int, prompt_upsampling: bool, - images: torch.Tensor | None = None, + images: Input.Image | None = None, ) -> IO.NodeOutput: reference_images = {} if images is not None: @@ -598,7 +584,7 @@ class Flux2ProImageNode(IO.ComfyNode): reference_images[key_name] = tensor_to_base64_string(images[image_index], total_pixels=2048 * 2048) initial_response = await sync_op( cls, - ApiEndpoint(path="/proxy/bfl/flux-2-pro/generate", method="POST"), + ApiEndpoint(path=cls.API_ENDPOINT, method="POST"), response_model=BFLFluxProGenerateResponse, data=Flux2ProGenerateRequest( prompt=prompt, @@ -632,6 +618,13 @@ class Flux2ProImageNode(IO.ComfyNode): return IO.NodeOutput(await download_url_to_image_tensor(response.result["sample"])) +class Flux2MaxImageNode(Flux2ProImageNode): + + NODE_ID = "Flux2MaxImageNode" + DISPLAY_NAME = "Flux.2 [max] Image" + API_ENDPOINT = "/proxy/bfl/flux-2-max/generate" + + class BFLExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[IO.ComfyNode]]: @@ -642,6 +635,7 @@ class BFLExtension(ComfyExtension): FluxProExpandNode, FluxProFillNode, Flux2ProImageNode, + Flux2MaxImageNode, ] From 8376ff6831b145eadc3339e1901ffe02386ab86a Mon Sep 17 00:00:00 2001 From: "Dr.Lt.Data" <128333288+ltdrdata@users.noreply.github.com> Date: Sat, 20 Dec 2025 03:41:56 +0900 Subject: [PATCH 002/181] bump comfyui_manager version to the 4.0.3b7 (#11422) --- manager_requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/manager_requirements.txt b/manager_requirements.txt index 5ef0d3a1d..2300f0c70 100644 --- a/manager_requirements.txt +++ b/manager_requirements.txt @@ -1 +1 @@ -comfyui_manager==4.0.3b5 +comfyui_manager==4.0.3b7 From cc4ddba1b68abdc64ef5a701fd0571fcf2faf98d Mon Sep 17 00:00:00 2001 From: BradPepersAMD Date: Fri, 19 Dec 2025 15:01:50 -0700 Subject: [PATCH 003/181] Allow enabling use of MIOpen by setting COMFYUI_ENABLE_MIOPEN=1 as an env var (#11366) --- comfy/model_management.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 40717b1e4..1889ab0ac 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -26,6 +26,7 @@ import importlib import platform import weakref import gc +import os class VRAMState(Enum): DISABLED = 0 #No vram present: no need to move models to vram @@ -333,13 +334,15 @@ except: SUPPORT_FP8_OPS = args.supports_fp8_compute AMD_RDNA2_AND_OLDER_ARCH = ["gfx1030", "gfx1031", "gfx1010", "gfx1011", "gfx1012", "gfx906", "gfx900", "gfx803"] +AMD_ENABLE_MIOPEN_ENV = 'COMFYUI_ENABLE_MIOPEN' try: if is_amd(): arch = torch.cuda.get_device_properties(get_torch_device()).gcnArchName if not (any((a in arch) for a in AMD_RDNA2_AND_OLDER_ARCH)): - torch.backends.cudnn.enabled = False # Seems to improve things a lot on AMD - logging.info("Set: torch.backends.cudnn.enabled = False for better AMD performance.") + if os.getenv(AMD_ENABLE_MIOPEN_ENV) != '1': + torch.backends.cudnn.enabled 
= False # Seems to improve things a lot on AMD + logging.info("Set: torch.backends.cudnn.enabled = False for better AMD performance.") try: rocm_version = tuple(map(int, str(torch.version.hip).split(".")[:2])) From 809ce687493db84f6743639adf9b600753b6188e Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Fri, 19 Dec 2025 16:59:25 -0800 Subject: [PATCH 004/181] Support nested tensor denoise masks. (#11431) --- comfy/samplers.py | 21 ++++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/comfy/samplers.py b/comfy/samplers.py index 8340d376c..1989ef107 100755 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -984,9 +984,6 @@ class CFGGuider: self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds, self.model_options) device = self.model_patcher.load_device - if denoise_mask is not None: - denoise_mask = comfy.sampler_helpers.prepare_mask(denoise_mask, noise.shape, device) - noise = noise.to(device) latent_image = latent_image.to(device) sigmas = sigmas.to(device) @@ -1013,6 +1010,24 @@ class CFGGuider: else: latent_shapes = [latent_image.shape] + if denoise_mask is not None: + if denoise_mask.is_nested: + denoise_masks = denoise_mask.unbind() + denoise_masks = denoise_masks[:len(latent_shapes)] + else: + denoise_masks = [denoise_mask] + + for i in range(len(denoise_masks), len(latent_shapes)): + denoise_masks.append(torch.ones(latent_shapes[i])) + + for i in range(len(denoise_masks)): + denoise_masks[i] = comfy.sampler_helpers.prepare_mask(denoise_masks[i], latent_shapes[i], self.model_patcher.load_device) + + if len(denoise_masks) > 1: + denoise_mask, _ = comfy.utils.pack_latents(denoise_masks) + else: + denoise_mask = denoise_masks[0] + self.conds = {} for k in self.original_conds: self.conds[k] = list(map(lambda a: a.copy(), self.original_conds[k])) From 514c24d756997c3131c57aa21578a09429096eca Mon Sep 17 00:00:00 2001 From: drozbay <17261091+drozbay@users.noreply.github.com> Date: Fri, 19 Dec 2025 21:22:45 -0700 Subject: [PATCH 005/181] Fix error from logging line (#11423) Co-authored-by: ozbayb <17261091+ozbayb@users.noreply.github.com> --- comfy/context_windows.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/context_windows.py b/comfy/context_windows.py index 2979b3ca1..1e0f86026 100644 --- a/comfy/context_windows.py +++ b/comfy/context_windows.py @@ -143,7 +143,7 @@ class IndexListContextHandler(ContextHandlerABC): # if multiple conds, split based on primary region if self.split_conds_to_windows and len(cond_in) > 1: region = window.get_region_index(len(cond_in)) - logging.info(f"Splitting conds to windows; using region {region} for window {window[0]}-{window[-1]} with center ratio {window.center_ratio:.3f}") + logging.info(f"Splitting conds to windows; using region {region} for window {window.index_list[0]}-{window.index_list[-1]} with center ratio {window.center_ratio:.3f}") cond_in = [cond_in[region]] # cond object is a list containing a dict - outer list is irrelevant, so just loop through it for actual_cond in cond_in: From 0aa7fa464efc4ecc35a145048c06d325c75fbf2b Mon Sep 17 00:00:00 2001 From: woctordho Date: Sat, 20 Dec 2025 13:16:46 +0800 Subject: [PATCH 006/181] Implement sliding attention in Gemma3 (#11409) --- comfy/text_encoders/llama.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/comfy/text_encoders/llama.py b/comfy/text_encoders/llama.py index 
0d07ac8c6..ed29e014d 100644 --- a/comfy/text_encoders/llama.py +++ b/comfy/text_encoders/llama.py @@ -3,7 +3,6 @@ import torch.nn as nn from dataclasses import dataclass from typing import Optional, Any import math -import logging from comfy.ldm.modules.attention import optimized_attention_for_device import comfy.model_management @@ -177,7 +176,7 @@ class Gemma3_4B_Config: num_key_value_heads: int = 4 max_position_embeddings: int = 131072 rms_norm_eps: float = 1e-6 - rope_theta = [10000.0, 1000000.0] + rope_theta = [1000000.0, 10000.0] transformer_type: str = "gemma3" head_dim = 256 rms_norm_add = True @@ -186,8 +185,8 @@ class Gemma3_4B_Config: rope_dims = None q_norm = "gemma3" k_norm = "gemma3" - sliding_attention = [False, False, False, False, False, 1024] - rope_scale = [1.0, 8.0] + sliding_attention = [1024, 1024, 1024, 1024, 1024, False] + rope_scale = [8.0, 1.0] final_norm: bool = True class RMSNorm(nn.Module): @@ -370,7 +369,7 @@ class TransformerBlockGemma2(nn.Module): self.pre_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype) self.post_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype) - if config.sliding_attention is not None: # TODO: implement. (Not that necessary since models are trained on less than 1024 tokens) + if config.sliding_attention is not None: self.sliding_attention = config.sliding_attention[index % len(config.sliding_attention)] else: self.sliding_attention = False @@ -387,7 +386,12 @@ class TransformerBlockGemma2(nn.Module): if self.transformer_type == 'gemma3': if self.sliding_attention: if x.shape[1] > self.sliding_attention: - logging.warning("Warning: sliding attention not implemented, results may be incorrect") + sliding_mask = torch.full((x.shape[1], x.shape[1]), float("-inf"), device=x.device, dtype=x.dtype) + sliding_mask.tril_(diagonal=-self.sliding_attention) + if attention_mask is not None: + attention_mask = attention_mask + sliding_mask + else: + attention_mask = sliding_mask freqs_cis = freqs_cis[1] else: freqs_cis = freqs_cis[0] From 3ab9748903a8ee51f62ae8d3eeebc1f98847f4bd Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Fri, 19 Dec 2025 21:19:47 -0800 Subject: [PATCH 007/181] Disable prompt weights on newbie te. 
(#11434) --- comfy/sd1_clip.py | 6 ++++-- comfy/text_encoders/lumina2.py | 2 +- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index 962948dae..c512ca5d0 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -466,7 +466,7 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=No return embed_out class SDTokenizer: - def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, has_end_token=True, pad_to_max_length=True, min_length=None, pad_token=None, end_token=None, min_padding=None, pad_left=False, tokenizer_data={}, tokenizer_args={}): + def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, has_end_token=True, pad_to_max_length=True, min_length=None, pad_token=None, end_token=None, min_padding=None, pad_left=False, disable_weights=False, tokenizer_data={}, tokenizer_args={}): if tokenizer_path is None: tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer") self.tokenizer = tokenizer_class.from_pretrained(tokenizer_path, **tokenizer_args) @@ -513,6 +513,8 @@ class SDTokenizer: self.embedding_size = embedding_size self.embedding_key = embedding_key + self.disable_weights = disable_weights + def _try_get_embedding(self, embedding_name:str): ''' Takes a potential embedding name and tries to retrieve it. @@ -547,7 +549,7 @@ class SDTokenizer: min_padding = tokenizer_options.get("{}_min_padding".format(self.embedding_key), self.min_padding) text = escape_important(text) - if kwargs.get("disable_weights", False): + if kwargs.get("disable_weights", self.disable_weights): parsed_weights = [(text, 1.0)] else: parsed_weights = token_weights(text, 1.0) diff --git a/comfy/text_encoders/lumina2.py b/comfy/text_encoders/lumina2.py index 7a6cfdab2..f82883ba1 100644 --- a/comfy/text_encoders/lumina2.py +++ b/comfy/text_encoders/lumina2.py @@ -14,7 +14,7 @@ class Gemma2BTokenizer(sd1_clip.SDTokenizer): class Gemma3_4BTokenizer(sd1_clip.SDTokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}): tokenizer = tokenizer_data.get("spiece_model", None) - super().__init__(tokenizer, pad_with_end=False, embedding_size=2560, embedding_key='gemma3_4b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_args={"add_bos": True, "add_eos": False}, tokenizer_data=tokenizer_data) + super().__init__(tokenizer, pad_with_end=False, embedding_size=2560, embedding_key='gemma3_4b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_args={"add_bos": True, "add_eos": False}, disable_weights=True, tokenizer_data=tokenizer_data) def state_dict(self): return {"spiece_model": self.tokenizer.serialize_model()} From 767ee30f217e72797df6b018417234bf8b3f7b69 Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Fri, 19 Dec 2025 21:22:17 -0800 Subject: [PATCH 008/181] ZImageFunControlNet: Fix mask concatenation in --gpu-only (#11421) This operation trades in latents which in --gpu-only may be out of the GPU The two VAE results will follow the --gpu-only defined behaviour so follow the inpaint image device when calculating the mask in this path. 
--- comfy_extras/nodes_model_patch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy_extras/nodes_model_patch.py b/comfy_extras/nodes_model_patch.py index 2a0cfcf18..1355b3c93 100644 --- a/comfy_extras/nodes_model_patch.py +++ b/comfy_extras/nodes_model_patch.py @@ -348,7 +348,7 @@ class ZImageControlPatch: if self.mask is None: mask_ = torch.zeros_like(inpaint_image_latent)[:, :1] else: - mask_ = comfy.utils.common_upscale(self.mask.view(self.mask.shape[0], -1, self.mask.shape[-2], self.mask.shape[-1]).mean(dim=1, keepdim=True), inpaint_image_latent.shape[-1], inpaint_image_latent.shape[-2], "nearest", "center") + mask_ = comfy.utils.common_upscale(self.mask.view(self.mask.shape[0], -1, self.mask.shape[-2], self.mask.shape[-1]).mean(dim=1, keepdim=True).to(device=inpaint_image_latent.device), inpaint_image_latent.shape[-1], inpaint_image_latent.shape[-2], "nearest", "center") if latent_image is None: latent_image = comfy.latent_formats.Flux().process_in(self.vae.encode(torch.ones_like(inpaint_image) * 0.5)) From 31e961736a476851e2579d5d9202ed4177a71720 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Fri, 19 Dec 2025 21:23:51 -0800 Subject: [PATCH 009/181] Fix issue with batches and newbie. (#11435) --- comfy/ldm/lumina/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/ldm/lumina/model.py b/comfy/ldm/lumina/model.py index 5628e2ba3..e80b1c138 100644 --- a/comfy/ldm/lumina/model.py +++ b/comfy/ldm/lumina/model.py @@ -625,7 +625,7 @@ class NextDiT(nn.Module): if pooled is not None: pooled = self.clip_text_pooled_proj(pooled) else: - pooled = torch.zeros((1, self.clip_text_dim), device=x.device, dtype=x.dtype) + pooled = torch.zeros((x.shape[0], self.clip_text_dim), device=x.device, dtype=x.dtype) adaln_input = self.time_text_embed(torch.cat((t, pooled), dim=-1)) From 4c432c11ed6f83466b8ff02569872925753a3c44 Mon Sep 17 00:00:00 2001 From: woctordho Date: Sat, 20 Dec 2025 13:57:22 +0800 Subject: [PATCH 010/181] Implement Jina CLIP v2 and NewBie dual CLIP (#11415) * Implement Jina CLIP v2 * Support quantized Gemma in NewBie dual CLIP --- comfy/model_base.py | 2 +- comfy/model_detection.py | 3 +- comfy/sd.py | 20 +++ comfy/text_encoders/jina_clip_2.py | 219 +++++++++++++++++++++++++++++ comfy/text_encoders/newbie.py | 62 ++++++++ nodes.py | 4 +- 6 files changed, 306 insertions(+), 4 deletions(-) create mode 100644 comfy/text_encoders/jina_clip_2.py create mode 100644 comfy/text_encoders/newbie.py diff --git a/comfy/model_base.py b/comfy/model_base.py index 6b8a8454d..c4f3c0639 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1110,7 +1110,7 @@ class Lumina2(BaseModel): if 'num_tokens' not in out: out['num_tokens'] = comfy.conds.CONDConstant(cross_attn.shape[1]) - clip_text_pooled = kwargs["pooled_output"] # Newbie + clip_text_pooled = kwargs.get("pooled_output", None) # NewBie if clip_text_pooled is not None: out['clip_text_pooled'] = comfy.conds.CONDRegular(clip_text_pooled) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 84fd409fd..539e296ed 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -430,8 +430,9 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["rope_theta"] = 10000.0 dit_config["ffn_dim_multiplier"] = 4.0 ctd_weight = state_dict.get('{}clip_text_pooled_proj.0.weight'.format(key_prefix), None) - if ctd_weight is not None: + if ctd_weight is not None: # NewBie dit_config["clip_text_dim"] = 
ctd_weight.shape[0] + # NewBie also sets axes_lens = [1024, 512, 512] but it's not used in ComfyUI elif dit_config["dim"] == 3840: # Z image dit_config["n_heads"] = 30 dit_config["n_kv_heads"] = 30 diff --git a/comfy/sd.py b/comfy/sd.py index c2a9728f3..7de7dd9c6 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -55,6 +55,8 @@ import comfy.text_encoders.hunyuan_image import comfy.text_encoders.z_image import comfy.text_encoders.ovis import comfy.text_encoders.kandinsky5 +import comfy.text_encoders.jina_clip_2 +import comfy.text_encoders.newbie import comfy.model_patcher import comfy.lora @@ -1008,6 +1010,7 @@ class CLIPType(Enum): OVIS = 21 KANDINSKY5 = 22 KANDINSKY5_IMAGE = 23 + NEWBIE = 24 def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}): @@ -1038,6 +1041,7 @@ class TEModel(Enum): MISTRAL3_24B_PRUNED_FLUX2 = 15 QWEN3_4B = 16 QWEN3_2B = 17 + JINA_CLIP_2 = 18 def detect_te_model(sd): @@ -1047,6 +1051,8 @@ def detect_te_model(sd): return TEModel.CLIP_H if "text_model.encoder.layers.0.mlp.fc1.weight" in sd: return TEModel.CLIP_L + if "model.encoder.layers.0.mixer.Wqkv.weight" in sd: + return TEModel.JINA_CLIP_2 if "encoder.block.23.layer.1.DenseReluDense.wi_1.weight" in sd: weight = sd["encoder.block.23.layer.1.DenseReluDense.wi_1.weight"] if weight.shape[-1] == 4096: @@ -1207,6 +1213,9 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip elif te_model == TEModel.QWEN3_2B: clip_target.clip = comfy.text_encoders.ovis.te(**llama_detect(clip_data)) clip_target.tokenizer = comfy.text_encoders.ovis.OvisTokenizer + elif te_model == TEModel.JINA_CLIP_2: + clip_target.clip = comfy.text_encoders.jina_clip_2.JinaClip2TextModelWrapper + clip_target.tokenizer = comfy.text_encoders.jina_clip_2.JinaClip2TokenizerWrapper else: # clip_l if clip_type == CLIPType.SD3: @@ -1262,6 +1271,17 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip elif clip_type == CLIPType.KANDINSKY5_IMAGE: clip_target.clip = comfy.text_encoders.kandinsky5.te(**llama_detect(clip_data)) clip_target.tokenizer = comfy.text_encoders.kandinsky5.Kandinsky5TokenizerImage + elif clip_type == CLIPType.NEWBIE: + clip_target.clip = comfy.text_encoders.newbie.te(**llama_detect(clip_data)) + clip_target.tokenizer = comfy.text_encoders.newbie.NewBieTokenizer + if "model.layers.0.self_attn.q_norm.weight" in clip_data[0]: + clip_data_gemma = clip_data[0] + clip_data_jina = clip_data[1] + else: + clip_data_gemma = clip_data[1] + clip_data_jina = clip_data[0] + tokenizer_data["gemma_spiece_model"] = clip_data_gemma.get("spiece_model", None) + tokenizer_data["jina_spiece_model"] = clip_data_jina.get("spiece_model", None) else: clip_target.clip = sdxl_clip.SDXLClipModel clip_target.tokenizer = sdxl_clip.SDXLTokenizer diff --git a/comfy/text_encoders/jina_clip_2.py b/comfy/text_encoders/jina_clip_2.py new file mode 100644 index 000000000..0cffb6d16 --- /dev/null +++ b/comfy/text_encoders/jina_clip_2.py @@ -0,0 +1,219 @@ +# Jina CLIP v2 and Jina Embeddings v3 both use their modified XLM-RoBERTa architecture. 
Reference implementation: +# Jina CLIP v2 (both text and vision): https://huggingface.co/jinaai/jina-clip-implementation/blob/39e6a55ae971b59bea6e44675d237c99762e7ee2/modeling_clip.py +# Jina XLM-RoBERTa (text only): http://huggingface.co/jinaai/xlm-roberta-flash-implementation/blob/2b6bc3f30750b3a9648fe9b63448c09920efe9be/modeling_xlm_roberta.py + +from dataclasses import dataclass + +import torch +from torch import nn as nn +from torch.nn import functional as F + +import comfy.model_management +import comfy.ops +from comfy import sd1_clip +from .spiece_tokenizer import SPieceTokenizer + +class JinaClip2Tokenizer(sd1_clip.SDTokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + tokenizer = tokenizer_data.get("spiece_model", None) + # The official NewBie uses max_length=8000, but Jina Embeddings v3 actually supports 8192 + super().__init__(tokenizer, pad_with_end=False, embedding_size=1024, embedding_key='jina_clip_2', tokenizer_class=SPieceTokenizer, has_start_token=True, has_end_token=True, pad_to_max_length=False, max_length=8192, min_length=1, pad_token=1, end_token=2, tokenizer_args={"add_bos": True, "add_eos": True}, tokenizer_data=tokenizer_data) + + def state_dict(self): + return {"spiece_model": self.tokenizer.serialize_model()} + +class JinaClip2TokenizerWrapper(sd1_clip.SD1Tokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, tokenizer=JinaClip2Tokenizer, name="jina_clip_2") + +# https://huggingface.co/jinaai/jina-embeddings-v3/blob/343dbf534c76fe845f304fa5c2d1fd87e1e78918/config.json +@dataclass +class XLMRobertaConfig: + vocab_size: int = 250002 + type_vocab_size: int = 1 + hidden_size: int = 1024 + num_hidden_layers: int = 24 + num_attention_heads: int = 16 + rotary_emb_base: float = 20000.0 + intermediate_size: int = 4096 + hidden_act: str = "gelu" + hidden_dropout_prob: float = 0.1 + attention_probs_dropout_prob: float = 0.1 + layer_norm_eps: float = 1e-05 + bos_token_id: int = 0 + eos_token_id: int = 2 + pad_token_id: int = 1 + +class XLMRobertaEmbeddings(nn.Module): + def __init__(self, config, device=None, dtype=None, ops=None): + super().__init__() + embed_dim = config.hidden_size + self.word_embeddings = ops.Embedding(config.vocab_size, embed_dim, padding_idx=config.pad_token_id, device=device, dtype=dtype) + self.token_type_embeddings = ops.Embedding(config.type_vocab_size, embed_dim, device=device, dtype=dtype) + + def forward(self, input_ids=None, embeddings=None): + if input_ids is not None and embeddings is None: + embeddings = self.word_embeddings(input_ids) + + if embeddings is not None: + token_type_ids = torch.zeros(embeddings.shape[1], device=embeddings.device, dtype=torch.int32) + token_type_embeddings = self.token_type_embeddings(token_type_ids) + embeddings = embeddings + token_type_embeddings + return embeddings + +class RotaryEmbedding(nn.Module): + def __init__(self, dim, base, device=None): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached = None + self._sin_cached = None + + def _update_cos_sin_cache(self, seqlen, device=None, dtype=None): + if seqlen > self._seq_len_cached or self._cos_cached is None or self._cos_cached.device != device or self._cos_cached.dtype != dtype: + self._seq_len_cached = seqlen + t = torch.arange(seqlen, 
device=device, dtype=torch.float32) + freqs = torch.outer(t, self.inv_freq.to(device=t.device)) + emb = torch.cat((freqs, freqs), dim=-1) + self._cos_cached = emb.cos().to(dtype) + self._sin_cached = emb.sin().to(dtype) + + def forward(self, q, k): + batch, seqlen, heads, head_dim = q.shape + self._update_cos_sin_cache(seqlen, device=q.device, dtype=q.dtype) + + cos = self._cos_cached[:seqlen].view(1, seqlen, 1, head_dim) + sin = self._sin_cached[:seqlen].view(1, seqlen, 1, head_dim) + + def rotate_half(x): + size = x.shape[-1] // 2 + x1, x2 = x[..., :size], x[..., size:] + return torch.cat((-x2, x1), dim=-1) + + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + +class MHA(nn.Module): + def __init__(self, config, device=None, dtype=None, ops=None): + super().__init__() + embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = embed_dim // config.num_attention_heads + + self.rotary_emb = RotaryEmbedding(self.head_dim, config.rotary_emb_base, device=device) + self.Wqkv = ops.Linear(embed_dim, 3 * embed_dim, device=device, dtype=dtype) + self.out_proj = ops.Linear(embed_dim, embed_dim, device=device, dtype=dtype) + + def forward(self, x, mask=None, optimized_attention=None): + qkv = self.Wqkv(x) + batch_size, seq_len, _ = qkv.shape + qkv = qkv.view(batch_size, seq_len, 3, self.num_heads, self.head_dim) + q, k, v = qkv.unbind(2) + + q, k = self.rotary_emb(q, k) + + # NHD -> HND + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + + out = optimized_attention(q, k, v, heads=self.num_heads, mask=mask, skip_reshape=True) + return self.out_proj(out) + +class MLP(nn.Module): + def __init__(self, config, device=None, dtype=None, ops=None): + super().__init__() + self.fc1 = ops.Linear(config.hidden_size, config.intermediate_size, device=device, dtype=dtype) + self.activation = F.gelu + self.fc2 = ops.Linear(config.intermediate_size, config.hidden_size, device=device, dtype=dtype) + + def forward(self, x): + x = self.fc1(x) + x = self.activation(x) + x = self.fc2(x) + return x + +class Block(nn.Module): + def __init__(self, config, device=None, dtype=None, ops=None): + super().__init__() + self.mixer = MHA(config, device=device, dtype=dtype, ops=ops) + self.dropout1 = nn.Dropout(config.hidden_dropout_prob) + self.norm1 = ops.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, device=device, dtype=dtype) + self.mlp = MLP(config, device=device, dtype=dtype, ops=ops) + self.dropout2 = nn.Dropout(config.hidden_dropout_prob) + self.norm2 = ops.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, device=device, dtype=dtype) + + def forward(self, hidden_states, mask=None, optimized_attention=None): + mixer_out = self.mixer(hidden_states, mask=mask, optimized_attention=optimized_attention) + hidden_states = self.norm1(self.dropout1(mixer_out) + hidden_states) + mlp_out = self.mlp(hidden_states) + hidden_states = self.norm2(self.dropout2(mlp_out) + hidden_states) + return hidden_states + +class XLMRobertaEncoder(nn.Module): + def __init__(self, config, device=None, dtype=None, ops=None): + super().__init__() + self.layers = nn.ModuleList([Block(config, device=device, dtype=dtype, ops=ops) for _ in range(config.num_hidden_layers)]) + + def forward(self, hidden_states, attention_mask=None): + optimized_attention = comfy.ldm.modules.attention.optimized_attention_for_device(hidden_states.device, mask=attention_mask is not None, small_input=True) + for layer in self.layers: + 
hidden_states = layer(hidden_states, mask=attention_mask, optimized_attention=optimized_attention) + return hidden_states + +class XLMRobertaModel_(nn.Module): + def __init__(self, config, device=None, dtype=None, ops=None): + super().__init__() + self.embeddings = XLMRobertaEmbeddings(config, device=device, dtype=dtype, ops=ops) + self.emb_ln = ops.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, device=device, dtype=dtype) + self.emb_drop = nn.Dropout(config.hidden_dropout_prob) + self.encoder = XLMRobertaEncoder(config, device=device, dtype=dtype, ops=ops) + + def forward(self, input_ids, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, embeds_info=[]): + x = self.embeddings(input_ids=input_ids, embeddings=embeds) + x = self.emb_ln(x) + x = self.emb_drop(x) + + mask = None + if attention_mask is not None: + mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, 1, attention_mask.shape[-1])) + mask = mask.masked_fill(mask.to(torch.bool), -torch.finfo(x.dtype).max) + + sequence_output = self.encoder(x, attention_mask=mask) + + # Mean pool, see https://huggingface.co/jinaai/jina-clip-implementation/blob/39e6a55ae971b59bea6e44675d237c99762e7ee2/hf_model.py + pooled_output = None + if attention_mask is None: + pooled_output = sequence_output.mean(dim=1) + else: + attention_mask = attention_mask.to(sequence_output.dtype) + pooled_output = (sequence_output * attention_mask.unsqueeze(-1)).sum(dim=1) / attention_mask.sum(dim=-1, keepdim=True) + + # Intermediate output is not yet implemented, use None for placeholder + return sequence_output, None, pooled_output + +class XLMRobertaModel(nn.Module): + def __init__(self, config_dict, dtype, device, operations): + super().__init__() + self.config = XLMRobertaConfig(**config_dict) + self.model = XLMRobertaModel_(self.config, device=device, dtype=dtype, ops=operations) + self.num_layers = self.config.num_hidden_layers + + def get_input_embeddings(self): + return self.model.embeddings.word_embeddings + + def set_input_embeddings(self, embeddings): + self.model.embeddings.word_embeddings = embeddings + + def forward(self, *args, **kwargs): + return self.model(*args, **kwargs) + +class JinaClip2TextModel(sd1_clip.SDClipModel): + def __init__(self, device="cpu", dtype=None, model_options={}): + super().__init__(device=device, dtype=dtype, textmodel_json_config={}, model_class=XLMRobertaModel, special_tokens={"start": 0, "end": 2, "pad": 1}, enable_attention_masks=True, return_attention_masks=True, model_options=model_options) + +class JinaClip2TextModelWrapper(sd1_clip.SD1ClipModel): + def __init__(self, device="cpu", dtype=None, model_options={}): + super().__init__(device=device, dtype=dtype, clip_model=JinaClip2TextModel, name="jina_clip_2", model_options=model_options) diff --git a/comfy/text_encoders/newbie.py b/comfy/text_encoders/newbie.py new file mode 100644 index 000000000..31904462b --- /dev/null +++ b/comfy/text_encoders/newbie.py @@ -0,0 +1,62 @@ +import torch + +import comfy.model_management +import comfy.text_encoders.jina_clip_2 +import comfy.text_encoders.lumina2 + +class NewBieTokenizer: + def __init__(self, embedding_directory=None, tokenizer_data={}): + self.gemma = comfy.text_encoders.lumina2.Gemma3_4BTokenizer(embedding_directory=embedding_directory, tokenizer_data={"spiece_model": tokenizer_data["gemma_spiece_model"]}) + self.jina = comfy.text_encoders.jina_clip_2.JinaClip2Tokenizer(embedding_directory=embedding_directory, 
tokenizer_data={"spiece_model": tokenizer_data["jina_spiece_model"]}) + + def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs): + out = {} + out["gemma"] = self.gemma.tokenize_with_weights(text, return_word_ids, **kwargs) + out["jina"] = self.jina.tokenize_with_weights(text, return_word_ids, **kwargs) + return out + + def untokenize(self, token_weight_pair): + raise NotImplementedError + + def state_dict(self): + return {} + +class NewBieTEModel(torch.nn.Module): + def __init__(self, dtype_gemma=None, device="cpu", dtype=None, model_options={}): + super().__init__() + dtype_gemma = comfy.model_management.pick_weight_dtype(dtype_gemma, dtype, device) + self.gemma = comfy.text_encoders.lumina2.Gemma3_4BModel(device=device, dtype=dtype_gemma, model_options=model_options) + self.jina = comfy.text_encoders.jina_clip_2.JinaClip2TextModel(device=device, dtype=dtype, model_options=model_options) + self.dtypes = {dtype, dtype_gemma} + + def set_clip_options(self, options): + self.gemma.set_clip_options(options) + self.jina.set_clip_options(options) + + def reset_clip_options(self): + self.gemma.reset_clip_options() + self.jina.reset_clip_options() + + def encode_token_weights(self, token_weight_pairs): + token_weight_pairs_gemma = token_weight_pairs["gemma"] + token_weight_pairs_jina = token_weight_pairs["jina"] + + gemma_out, gemma_pooled, gemma_extra = self.gemma.encode_token_weights(token_weight_pairs_gemma) + jina_out, jina_pooled, jina_extra = self.jina.encode_token_weights(token_weight_pairs_jina) + + return gemma_out, jina_pooled, gemma_extra + + def load_sd(self, sd): + if "model.layers.0.self_attn.q_norm.weight" in sd: + return self.gemma.load_sd(sd) + else: + return self.jina.load_sd(sd) + +def te(dtype_llama=None, llama_quantization_metadata=None): + class NewBieTEModel_(NewBieTEModel): + def __init__(self, device="cpu", dtype=None, model_options={}): + if llama_quantization_metadata is not None: + model_options = model_options.copy() + model_options["quantization_metadata"] = llama_quantization_metadata + super().__init__(dtype_gemma=dtype_llama, device=device, dtype=dtype, model_options=model_options) + return NewBieTEModel_ diff --git a/nodes.py b/nodes.py index b13ceb578..7d83ecb21 100644 --- a/nodes.py +++ b/nodes.py @@ -970,7 +970,7 @@ class DualCLIPLoader: def INPUT_TYPES(s): return {"required": { "clip_name1": (folder_paths.get_filename_list("text_encoders"), ), "clip_name2": (folder_paths.get_filename_list("text_encoders"), ), - "type": (["sdxl", "sd3", "flux", "hunyuan_video", "hidream", "hunyuan_image", "hunyuan_video_15", "kandinsky5", "kandinsky5_image"], ), + "type": (["sdxl", "sd3", "flux", "hunyuan_video", "hidream", "hunyuan_image", "hunyuan_video_15", "kandinsky5", "kandinsky5_image", "newbie"], ), }, "optional": { "device": (["default", "cpu"], {"advanced": True}), @@ -980,7 +980,7 @@ class DualCLIPLoader: CATEGORY = "advanced/loaders" - DESCRIPTION = "[Recipes]\n\nsdxl: clip-l, clip-g\nsd3: clip-l, clip-g / clip-l, t5 / clip-g, t5\nflux: clip-l, t5\nhidream: at least one of t5 or llama, recommended t5 and llama\nhunyuan_image: qwen2.5vl 7b and byt5 small" + DESCRIPTION = "[Recipes]\n\nsdxl: clip-l, clip-g\nsd3: clip-l, clip-g / clip-l, t5 / clip-g, t5\nflux: clip-l, t5\nhidream: at least one of t5 or llama, recommended t5 and llama\nhunyuan_image: qwen2.5vl 7b and byt5 small\nnewbie: gemma-3-4b-it, jina clip v2" def load_clip(self, clip_name1, clip_name2, type, device="default"): clip_type = getattr(comfy.sd.CLIPType, type.upper(), 
comfy.sd.CLIPType.STABLE_DIFFUSION) From fb478f679a2998c4f2e955bcb895cc4c55f119a4 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Fri, 19 Dec 2025 22:02:43 -0800 Subject: [PATCH 011/181] Only apply gemma quant config to gemma model for newbie. (#11436) --- comfy/text_encoders/lumina2.py | 5 +++++ comfy/text_encoders/newbie.py | 2 +- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/comfy/text_encoders/lumina2.py b/comfy/text_encoders/lumina2.py index f82883ba1..b29a7cc87 100644 --- a/comfy/text_encoders/lumina2.py +++ b/comfy/text_encoders/lumina2.py @@ -33,6 +33,11 @@ class Gemma2_2BModel(sd1_clip.SDClipModel): class Gemma3_4BModel(sd1_clip.SDClipModel): def __init__(self, device="cpu", layer="hidden", layer_idx=-2, dtype=None, attention_mask=True, model_options={}): + llama_quantization_metadata = model_options.get("llama_quantization_metadata", None) + if llama_quantization_metadata is not None: + model_options = model_options.copy() + model_options["quantization_metadata"] = llama_quantization_metadata + super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma3_4B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options) class LuminaModel(sd1_clip.SD1ClipModel): diff --git a/comfy/text_encoders/newbie.py b/comfy/text_encoders/newbie.py index 31904462b..db2324576 100644 --- a/comfy/text_encoders/newbie.py +++ b/comfy/text_encoders/newbie.py @@ -57,6 +57,6 @@ def te(dtype_llama=None, llama_quantization_metadata=None): def __init__(self, device="cpu", dtype=None, model_options={}): if llama_quantization_metadata is not None: model_options = model_options.copy() - model_options["quantization_metadata"] = llama_quantization_metadata + model_options["llama_quantization_metadata"] = llama_quantization_metadata super().__init__(dtype_gemma=dtype_llama, device=device, dtype=dtype, model_options=model_options) return NewBieTEModel_ From 0899012ad60db23cbc5990d164fbd22195bafcb2 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Sat, 20 Dec 2025 08:24:37 +0200 Subject: [PATCH 012/181] chore(api-nodes): by default set Watermark generation to False (#11437) --- comfy_api_nodes/apis/bytedance_api.py | 6 +++--- comfy_api_nodes/nodes_bytedance.py | 16 ++++++++-------- comfy_api_nodes/nodes_wan.py | 24 ++++++++++++------------ 3 files changed, 23 insertions(+), 23 deletions(-) diff --git a/comfy_api_nodes/apis/bytedance_api.py b/comfy_api_nodes/apis/bytedance_api.py index 77cd76f9b..b8c2f618b 100644 --- a/comfy_api_nodes/apis/bytedance_api.py +++ b/comfy_api_nodes/apis/bytedance_api.py @@ -10,7 +10,7 @@ class Text2ImageTaskCreationRequest(BaseModel): size: str | None = Field(None) seed: int | None = Field(0, ge=0, le=2147483647) guidance_scale: float | None = Field(..., ge=1.0, le=10.0) - watermark: bool | None = Field(True) + watermark: bool | None = Field(False) class Image2ImageTaskCreationRequest(BaseModel): @@ -21,7 +21,7 @@ class Image2ImageTaskCreationRequest(BaseModel): size: str | None = Field("adaptive") seed: int | None = Field(..., ge=0, le=2147483647) guidance_scale: float | None = Field(..., ge=1.0, le=10.0) - watermark: bool | None = Field(True) + watermark: bool | None = Field(False) class Seedream4Options(BaseModel): @@ -37,7 +37,7 @@ class 
Seedream4TaskCreationRequest(BaseModel): seed: int = Field(..., ge=0, le=2147483647) sequential_image_generation: str = Field("disabled") sequential_image_generation_options: Seedream4Options = Field(Seedream4Options(max_images=15)) - watermark: bool = Field(True) + watermark: bool = Field(False) class ImageTaskCreationResponse(BaseModel): diff --git a/comfy_api_nodes/nodes_bytedance.py b/comfy_api_nodes/nodes_bytedance.py index 57c0218d0..636cc1265 100644 --- a/comfy_api_nodes/nodes_bytedance.py +++ b/comfy_api_nodes/nodes_bytedance.py @@ -112,7 +112,7 @@ class ByteDanceImageNode(IO.ComfyNode): ), IO.Boolean.Input( "watermark", - default=True, + default=False, tooltip='Whether to add an "AI generated" watermark to the image', optional=True, ), @@ -215,7 +215,7 @@ class ByteDanceImageEditNode(IO.ComfyNode): ), IO.Boolean.Input( "watermark", - default=True, + default=False, tooltip='Whether to add an "AI generated" watermark to the image', optional=True, ), @@ -346,7 +346,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode): ), IO.Boolean.Input( "watermark", - default=True, + default=False, tooltip='Whether to add an "AI generated" watermark to the image.', optional=True, ), @@ -380,7 +380,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode): sequential_image_generation: str = "disabled", max_images: int = 1, seed: int = 0, - watermark: bool = True, + watermark: bool = False, fail_on_partial: bool = True, ) -> IO.NodeOutput: validate_string(prompt, strip_whitespace=True, min_length=1) @@ -507,7 +507,7 @@ class ByteDanceTextToVideoNode(IO.ComfyNode): ), IO.Boolean.Input( "watermark", - default=True, + default=False, tooltip='Whether to add an "AI generated" watermark to the video.', optional=True, ), @@ -617,7 +617,7 @@ class ByteDanceImageToVideoNode(IO.ComfyNode): ), IO.Boolean.Input( "watermark", - default=True, + default=False, tooltip='Whether to add an "AI generated" watermark to the video.', optional=True, ), @@ -739,7 +739,7 @@ class ByteDanceFirstLastFrameNode(IO.ComfyNode): ), IO.Boolean.Input( "watermark", - default=True, + default=False, tooltip='Whether to add an "AI generated" watermark to the video.', optional=True, ), @@ -862,7 +862,7 @@ class ByteDanceImageReferenceNode(IO.ComfyNode): ), IO.Boolean.Input( "watermark", - default=True, + default=False, tooltip='Whether to add an "AI generated" watermark to the video.', optional=True, ), diff --git a/comfy_api_nodes/nodes_wan.py b/comfy_api_nodes/nodes_wan.py index 17b680e13..1675fd863 100644 --- a/comfy_api_nodes/nodes_wan.py +++ b/comfy_api_nodes/nodes_wan.py @@ -46,14 +46,14 @@ class Txt2ImageParametersField(BaseModel): n: int = Field(1, description="Number of images to generate.") # we support only value=1 seed: int = Field(..., ge=0, le=2147483647) prompt_extend: bool = Field(True) - watermark: bool = Field(True) + watermark: bool = Field(False) class Image2ImageParametersField(BaseModel): size: str | None = Field(None) n: int = Field(1, description="Number of images to generate.") # we support only value=1 seed: int = Field(..., ge=0, le=2147483647) - watermark: bool = Field(True) + watermark: bool = Field(False) class Text2VideoParametersField(BaseModel): @@ -61,7 +61,7 @@ class Text2VideoParametersField(BaseModel): seed: int = Field(..., ge=0, le=2147483647) duration: int = Field(5, ge=5, le=15) prompt_extend: bool = Field(True) - watermark: bool = Field(True) + watermark: bool = Field(False) audio: bool = Field(False, description="Whether to generate audio automatically.") shot_type: str = Field("single") @@ -71,7 +71,7 @@ class 
Image2VideoParametersField(BaseModel): seed: int = Field(..., ge=0, le=2147483647) duration: int = Field(5, ge=5, le=15) prompt_extend: bool = Field(True) - watermark: bool = Field(True) + watermark: bool = Field(False) audio: bool = Field(False, description="Whether to generate audio automatically.") shot_type: str = Field("single") @@ -208,7 +208,7 @@ class WanTextToImageApi(IO.ComfyNode): ), IO.Boolean.Input( "watermark", - default=True, + default=False, tooltip="Whether to add an AI-generated watermark to the result.", optional=True, ), @@ -234,7 +234,7 @@ class WanTextToImageApi(IO.ComfyNode): height: int = 1024, seed: int = 0, prompt_extend: bool = True, - watermark: bool = True, + watermark: bool = False, ): initial_response = await sync_op( cls, @@ -327,7 +327,7 @@ class WanImageToImageApi(IO.ComfyNode): ), IO.Boolean.Input( "watermark", - default=True, + default=False, tooltip="Whether to add an AI-generated watermark to the result.", optional=True, ), @@ -353,7 +353,7 @@ class WanImageToImageApi(IO.ComfyNode): # width: int = 1024, # height: int = 1024, seed: int = 0, - watermark: bool = True, + watermark: bool = False, ): n_images = get_number_of_images(image) if n_images not in (1, 2): @@ -476,7 +476,7 @@ class WanTextToVideoApi(IO.ComfyNode): ), IO.Boolean.Input( "watermark", - default=True, + default=False, tooltip="Whether to add an AI-generated watermark to the result.", optional=True, ), @@ -512,7 +512,7 @@ class WanTextToVideoApi(IO.ComfyNode): seed: int = 0, generate_audio: bool = False, prompt_extend: bool = True, - watermark: bool = True, + watermark: bool = False, shot_type: str = "single", ): if "480p" in size and model == "wan2.6-t2v": @@ -637,7 +637,7 @@ class WanImageToVideoApi(IO.ComfyNode): ), IO.Boolean.Input( "watermark", - default=True, + default=False, tooltip="Whether to add an AI-generated watermark to the result.", optional=True, ), @@ -674,7 +674,7 @@ class WanImageToVideoApi(IO.ComfyNode): seed: int = 0, generate_audio: bool = False, prompt_extend: bool = True, - watermark: bool = True, + watermark: bool = False, shot_type: str = "single", ): if get_number_of_images(image) != 1: From bbb11e26081977474eec72ce36d12ec778b5a9ea Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Sat, 20 Dec 2025 18:48:28 +0200 Subject: [PATCH 013/181] fix(api-nodes): Topaz 4k video upscaling (#11438) --- comfy_api_nodes/nodes_topaz.py | 35 +++++++++++++++++++++++++++------- 1 file changed, 28 insertions(+), 7 deletions(-) diff --git a/comfy_api_nodes/nodes_topaz.py b/comfy_api_nodes/nodes_topaz.py index f522756e5..b04575ad8 100644 --- a/comfy_api_nodes/nodes_topaz.py +++ b/comfy_api_nodes/nodes_topaz.py @@ -23,10 +23,6 @@ UPSCALER_MODELS_MAP = { "Starlight (Astra) Fast": "slf-1", "Starlight (Astra) Creative": "slc-1", } -UPSCALER_VALUES_MAP = { - "FullHD (1080p)": 1920, - "4K (2160p)": 3840, -} class TopazImageEnhance(IO.ComfyNode): @@ -214,7 +210,7 @@ class TopazVideoEnhance(IO.ComfyNode): IO.Video.Input("video"), IO.Boolean.Input("upscaler_enabled", default=True), IO.Combo.Input("upscaler_model", options=list(UPSCALER_MODELS_MAP.keys())), - IO.Combo.Input("upscaler_resolution", options=list(UPSCALER_VALUES_MAP.keys())), + IO.Combo.Input("upscaler_resolution", options=["FullHD (1080p)", "4K (2160p)"]), IO.Combo.Input( "upscaler_creativity", options=["low", "middle", "high"], @@ -306,8 +302,33 @@ class TopazVideoEnhance(IO.ComfyNode): target_frame_rate = src_frame_rate filters = [] if upscaler_enabled: - target_width = 
UPSCALER_VALUES_MAP[upscaler_resolution] - target_height = UPSCALER_VALUES_MAP[upscaler_resolution] + if "1080p" in upscaler_resolution: + target_pixel_p = 1080 + max_long_side = 1920 + else: + target_pixel_p = 2160 + max_long_side = 3840 + ar = src_width / src_height + if src_width >= src_height: + # Landscape or Square; Attempt to set height to target (e.g., 2160), calculate width + target_height = target_pixel_p + target_width = int(target_height * ar) + # Check if width exceeds standard bounds (for ultra-wide e.g., 21:9 ARs) + if target_width > max_long_side: + target_width = max_long_side + target_height = int(target_width / ar) + else: + # Portrait; Attempt to set width to target (e.g., 2160), calculate height + target_width = target_pixel_p + target_height = int(target_width / ar) + # Check if height exceeds standard bounds + if target_height > max_long_side: + target_height = max_long_side + target_width = int(target_height * ar) + if target_width % 2 != 0: + target_width += 1 + if target_height % 2 != 0: + target_height += 1 filters.append( topaz_api.VideoEnhancementFilter( model=UPSCALER_MODELS_MAP[upscaler_model], From 807538fe6c66bca8c91edbad14414fb4e109cbde Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sat, 20 Dec 2025 17:02:02 -0800 Subject: [PATCH 014/181] Core release process. (#11447) --- README.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/README.md b/README.md index bae955b1b..b0f62695b 100644 --- a/README.md +++ b/README.md @@ -119,6 +119,9 @@ ComfyUI follows a weekly release cycle targeting Monday but this regularly chang 1. **[ComfyUI Core](https://github.com/comfyanonymous/ComfyUI)** - Releases a new stable version (e.g., v0.7.0) roughly every week. + - Starting from v0.4.0 patch versions will be used for fixes backported onto the current stable release. + - Minor versions will be used for releases off the master branch. + - Patch versions may still be used for releases on the master branch in cases where a backport would not make sense. - Commits outside of the stable release tags may be very unstable and break many custom nodes. - Serves as the foundation for the desktop release From 91bf6b6aa3d5134c1569375a34ff483d3e32e03f Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sun, 21 Dec 2025 16:59:40 -0800 Subject: [PATCH 015/181] Add node to create empty latents for qwen image layered model. 
(#11460) --- comfy_extras/nodes_qwen.py | 29 ++++++++++++++++++++++++++++- 1 file changed, 28 insertions(+), 1 deletion(-) diff --git a/comfy_extras/nodes_qwen.py b/comfy_extras/nodes_qwen.py index 525239ae5..fde8fac9a 100644 --- a/comfy_extras/nodes_qwen.py +++ b/comfy_extras/nodes_qwen.py @@ -3,7 +3,9 @@ import comfy.utils import math from typing_extensions import override from comfy_api.latest import ComfyExtension, io - +import comfy.model_management +import torch +import nodes class TextEncodeQwenImageEdit(io.ComfyNode): @classmethod @@ -104,12 +106,37 @@ class TextEncodeQwenImageEditPlus(io.ComfyNode): return io.NodeOutput(conditioning) +class EmptyQwenImageLayeredLatentImage(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="EmptyQwenImageLayeredLatentImage", + display_name="Empty Qwen Image Layered Latent", + category="latent/qwen", + inputs=[ + io.Int.Input("width", default=640, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("height", default=640, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("layers", default=3, min=0, max=nodes.MAX_RESOLUTION, step=1), + io.Int.Input("batch_size", default=1, min=1, max=4096), + ], + outputs=[ + io.Latent.Output(), + ], + ) + + @classmethod + def execute(cls, width, height, layers, batch_size=1) -> io.NodeOutput: + latent = torch.zeros([batch_size, 16, layers + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) + return io.NodeOutput({"samples": latent}) + + class QwenExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[io.ComfyNode]]: return [ TextEncodeQwenImageEdit, TextEncodeQwenImageEditPlus, + EmptyQwenImageLayeredLatentImage, ] From c176b214cc768d41892add4d4f51c5c5627cbf7b Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Mon, 22 Dec 2025 08:44:49 +0200 Subject: [PATCH 016/181] extend possible duration range for Kling O1 StartEndFrame node (#11451) --- comfy_api_nodes/nodes_kling.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/comfy_api_nodes/nodes_kling.py b/comfy_api_nodes/nodes_kling.py index 1a6364fa0..5294b10d4 100644 --- a/comfy_api_nodes/nodes_kling.py +++ b/comfy_api_nodes/nodes_kling.py @@ -858,7 +858,7 @@ class OmniProFirstLastFrameNode(IO.ComfyNode): tooltip="A text prompt describing the video content. " "This can include both positive and negative descriptions.", ), - IO.Combo.Input("duration", options=["5", "10"]), + IO.Int.Input("duration", default=5, min=3, max=10, display_mode=IO.NumberDisplay.slider), IO.Image.Input("first_frame"), IO.Image.Input( "end_frame", @@ -897,6 +897,10 @@ class OmniProFirstLastFrameNode(IO.ComfyNode): validate_string(prompt, min_length=1, max_length=2500) if end_frame is not None and reference_images is not None: raise ValueError("The 'end_frame' input cannot be used simultaneously with 'reference_images'.") + if duration not in (5, 10) and end_frame is None and reference_images is None: + raise ValueError( + "Duration is only supported for 5 or 10 seconds if there is no end frame or reference images." 
+ ) validate_image_dimensions(first_frame, min_width=300, min_height=300) validate_image_aspect_ratio(first_frame, (1, 2.5), (2.5, 1)) image_list: list[OmniParamImage] = [ From eb0e10aec449eed2bbcda82ae5b56070e61ed86f Mon Sep 17 00:00:00 2001 From: ComfyUI Wiki Date: Tue, 23 Dec 2025 05:02:41 +0800 Subject: [PATCH 017/181] Update workflow templates to v0.7.62 (#11467) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 54696395f..b41dbe1d7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.34.9 -comfyui-workflow-templates==0.7.60 +comfyui-workflow-templates==0.7.62 comfyui-embedded-docs==0.3.1 torch torchsde From 33aa808713f7c36cd9476c53b8b67c745e9bc107 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 22 Dec 2025 13:43:24 -0800 Subject: [PATCH 018/181] Make denoised output on custom sampler nodes work with nested tensors. (#11471) --- comfy_extras/nodes_custom_sampler.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/comfy_extras/nodes_custom_sampler.py b/comfy_extras/nodes_custom_sampler.py index 7ee4caac1..993889d9d 100644 --- a/comfy_extras/nodes_custom_sampler.py +++ b/comfy_extras/nodes_custom_sampler.py @@ -760,8 +760,12 @@ class SamplerCustom(io.ComfyNode): out = latent.copy() out["samples"] = samples if "x0" in x0_output: + x0_out = model.model.process_latent_out(x0_output["x0"].cpu()) + if samples.is_nested: + latent_shapes = [x.shape for x in samples.unbind()] + x0_out = comfy.nested_tensor.NestedTensor(comfy.utils.unpack_latents(x0_out, latent_shapes)) out_denoised = latent.copy() - out_denoised["samples"] = model.model.process_latent_out(x0_output["x0"].cpu()) + out_denoised["samples"] = x0_out else: out_denoised = out return io.NodeOutput(out, out_denoised) @@ -948,8 +952,12 @@ class SamplerCustomAdvanced(io.ComfyNode): out = latent.copy() out["samples"] = samples if "x0" in x0_output: + x0_out = guider.model_patcher.model.process_latent_out(x0_output["x0"].cpu()) + if samples.is_nested: + latent_shapes = [x.shape for x in samples.unbind()] + x0_out = comfy.nested_tensor.NestedTensor(comfy.utils.unpack_latents(x0_out, latent_shapes)) out_denoised = latent.copy() - out_denoised["samples"] = guider.model_patcher.model.process_latent_out(x0_output["x0"].cpu()) + out_denoised["samples"] = x0_out else: out_denoised = out return io.NodeOutput(out, out_denoised) From f4f44bb8073d02597aca61193fec6143292a0b88 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Tue, 23 Dec 2025 22:10:27 +0200 Subject: [PATCH 019/181] api-nodes: use new custom endpoint for Nano Banana (#11311) --- comfy_api_nodes/apis/gemini_api.py | 1 + comfy_api_nodes/nodes_gemini.py | 24 +++++++++++++++--------- 2 files changed, 16 insertions(+), 9 deletions(-) diff --git a/comfy_api_nodes/apis/gemini_api.py b/comfy_api_nodes/apis/gemini_api.py index f8edc38c9..d81337dae 100644 --- a/comfy_api_nodes/apis/gemini_api.py +++ b/comfy_api_nodes/apis/gemini_api.py @@ -133,6 +133,7 @@ class GeminiImageGenerateContentRequest(BaseModel): systemInstruction: GeminiSystemInstructionContent | None = Field(None) tools: list[GeminiTool] | None = Field(None) videoMetadata: GeminiVideoMetadata | None = Field(None) + uploadImagesToStorage: bool = Field(True) class GeminiGenerateContentRequest(BaseModel): diff --git a/comfy_api_nodes/nodes_gemini.py 
b/comfy_api_nodes/nodes_gemini.py index ad0f4b4d1..e8ed7e797 100644 --- a/comfy_api_nodes/nodes_gemini.py +++ b/comfy_api_nodes/nodes_gemini.py @@ -34,6 +34,7 @@ from comfy_api_nodes.util import ( ApiEndpoint, audio_to_base64_string, bytesio_to_image_tensor, + download_url_to_image_tensor, get_number_of_images, sync_op, tensor_to_base64_string, @@ -141,9 +142,11 @@ def get_parts_by_type(response: GeminiGenerateContentResponse, part_type: Litera ) parts = [] for part in response.candidates[0].content.parts: - if part_type == "text" and hasattr(part, "text") and part.text: + if part_type == "text" and part.text: parts.append(part) - elif hasattr(part, "inlineData") and part.inlineData and part.inlineData.mimeType == part_type: + elif part.inlineData and part.inlineData.mimeType == part_type: + parts.append(part) + elif part.fileData and part.fileData.mimeType == part_type: parts.append(part) # Skip parts that don't match the requested type return parts @@ -163,12 +166,15 @@ def get_text_from_response(response: GeminiGenerateContentResponse) -> str: return "\n".join([part.text for part in parts]) -def get_image_from_response(response: GeminiGenerateContentResponse) -> Input.Image: +async def get_image_from_response(response: GeminiGenerateContentResponse) -> Input.Image: image_tensors: list[Input.Image] = [] parts = get_parts_by_type(response, "image/png") for part in parts: - image_data = base64.b64decode(part.inlineData.data) - returned_image = bytesio_to_image_tensor(BytesIO(image_data)) + if part.inlineData: + image_data = base64.b64decode(part.inlineData.data) + returned_image = bytesio_to_image_tensor(BytesIO(image_data)) + else: + returned_image = await download_url_to_image_tensor(part.fileData.fileUri) image_tensors.append(returned_image) if len(image_tensors) == 0: return torch.zeros((1, 1024, 1024, 4)) @@ -596,7 +602,7 @@ class GeminiImage(IO.ComfyNode): response = await sync_op( cls, - endpoint=ApiEndpoint(path=f"{GEMINI_BASE_ENDPOINT}/{model}", method="POST"), + ApiEndpoint(path=f"/proxy/vertexai/gemini/{model}", method="POST"), data=GeminiImageGenerateContentRequest( contents=[ GeminiContent(role=GeminiRole.user, parts=parts), @@ -610,7 +616,7 @@ class GeminiImage(IO.ComfyNode): response_model=GeminiGenerateContentResponse, price_extractor=calculate_tokens_price, ) - return IO.NodeOutput(get_image_from_response(response), get_text_from_response(response)) + return IO.NodeOutput(await get_image_from_response(response), get_text_from_response(response)) class GeminiImage2(IO.ComfyNode): @@ -729,7 +735,7 @@ class GeminiImage2(IO.ComfyNode): response = await sync_op( cls, - ApiEndpoint(path=f"{GEMINI_BASE_ENDPOINT}/{model}", method="POST"), + ApiEndpoint(path=f"/proxy/vertexai/gemini/{model}", method="POST"), data=GeminiImageGenerateContentRequest( contents=[ GeminiContent(role=GeminiRole.user, parts=parts), @@ -743,7 +749,7 @@ class GeminiImage2(IO.ComfyNode): response_model=GeminiGenerateContentResponse, price_extractor=calculate_tokens_price, ) - return IO.NodeOutput(get_image_from_response(response), get_text_from_response(response)) + return IO.NodeOutput(await get_image_from_response(response), get_text_from_response(response)) class GeminiExtension(ComfyExtension): From 22ff1bbfcb532a294b200a90270b772a339d334e Mon Sep 17 00:00:00 2001 From: ComfyUI Wiki Date: Wed, 24 Dec 2025 09:48:45 +0800 Subject: [PATCH 020/181] chore: update workflow templates to v0.7.63 (#11482) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt 
b/requirements.txt index b41dbe1d7..59ac599c1 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.34.9 -comfyui-workflow-templates==0.7.62 +comfyui-workflow-templates==0.7.63 comfyui-embedded-docs==0.3.1 torch torchsde From e4c61d75555036fa28b6bb34e5fd67b007c9f391 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 23 Dec 2025 20:50:02 -0500 Subject: [PATCH 021/181] ComfyUI v0.6.0 --- comfyui_version.py | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/comfyui_version.py b/comfyui_version.py index b45309198..1f28e2407 100644 --- a/comfyui_version.py +++ b/comfyui_version.py @@ -1,3 +1,3 @@ # This file is automatically generated by the build process when version is # updated in pyproject.toml. -__version__ = "0.5.1" +__version__ = "0.6.0" diff --git a/pyproject.toml b/pyproject.toml index 3a6960811..35a268bd1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ComfyUI" -version = "0.5.1" +version = "0.6.0" readme = "README.md" license = { file = "LICENSE" } requires-python = ">=3.9" From 650e716dda0a966a083f0efe299f3e83336f920e Mon Sep 17 00:00:00 2001 From: Comfy Org PR Bot Date: Wed, 24 Dec 2025 14:29:41 +0900 Subject: [PATCH 022/181] Bump comfyui-frontend-package to 1.35.9 (#11470) Co-authored-by: github-actions[bot] --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 59ac599c1..84b1882aa 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -comfyui-frontend-package==1.34.9 +comfyui-frontend-package==1.35.9 comfyui-workflow-templates==0.7.63 comfyui-embedded-docs==0.3.1 torch From 4f067b07fb33cc1b61d91aec73ca968ba7d9c29a Mon Sep 17 00:00:00 2001 From: ComfyUI Wiki Date: Thu, 25 Dec 2025 07:54:21 +0800 Subject: [PATCH 023/181] chore: update workflow templates to v0.7.64 (#11496) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 84b1882aa..8b670b813 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.35.9 -comfyui-workflow-templates==0.7.63 +comfyui-workflow-templates==0.7.64 comfyui-embedded-docs==0.3.1 torch torchsde From 532e2850794c7b497174a0a42ac0cb1fe5b62499 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 24 Dec 2025 16:09:37 -0800 Subject: [PATCH 024/181] Add a ManualSigmas node. (#11499) Can be used to manually set the sigmas for a model. This node accepts a list of integer and floating point numbers separated with any non numeric character. 
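A minimal sketch of the parsing behaviour described above (the helper name is made up for illustration; the regex is the one used by the node in the diff below): every signed integer or decimal in the string is extracted, and everything else simply acts as a separator, so commas, spaces or semicolons all work.

    import re
    import torch

    def parse_manual_sigmas(text: str) -> torch.Tensor:
        # Pull out every signed integer/decimal; any non-numeric character separates values.
        values = [float(v) for v in re.findall(r"[-+]?(?:\d*\.*\d+)", text)]
        return torch.FloatTensor(values)

    parse_manual_sigmas("1, 0.5; 0.25")  # -> tensor([1.0000, 0.5000, 0.2500])

The resulting 1-D tensor is what the node returns as its SIGMAS output.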
--- comfy_extras/nodes_custom_sampler.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/comfy_extras/nodes_custom_sampler.py b/comfy_extras/nodes_custom_sampler.py index 993889d9d..f19adf4b9 100644 --- a/comfy_extras/nodes_custom_sampler.py +++ b/comfy_extras/nodes_custom_sampler.py @@ -9,6 +9,7 @@ import comfy.utils import node_helpers from typing_extensions import override from comfy_api.latest import ComfyExtension, io +import re class BasicScheduler(io.ComfyNode): @@ -1013,6 +1014,25 @@ class AddNoise(io.ComfyNode): add_noise = execute +class ManualSigmas(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="ManualSigmas", + category="_for_testing/custom_sampling", + is_experimental=True, + inputs=[ + io.String.Input("sigmas", default="1, 0.5", multiline=False) + ], + outputs=[io.Sigmas.Output()] + ) + + @classmethod + def execute(cls, sigmas) -> io.NodeOutput: + sigmas = re.findall(r"[-+]?(?:\d*\.*\d+)", sigmas) + sigmas = [float(i) for i in sigmas] + sigmas = torch.FloatTensor(sigmas) + return io.NodeOutput(sigmas) class CustomSamplersExtension(ComfyExtension): @override @@ -1052,6 +1072,7 @@ class CustomSamplersExtension(ComfyExtension): DisableNoise, AddNoise, SamplerCustomAdvanced, + ManualSigmas, ] From d9a76cf66e3fc6b0047692a07bc1d24f20e16e20 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Thu, 25 Dec 2025 20:46:51 -0800 Subject: [PATCH 025/181] Specify in readme that we only support pytorch 2.4 and up. (#11512) --- README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/README.md b/README.md index b0f62695b..6d09758c0 100644 --- a/README.md +++ b/README.md @@ -212,6 +212,8 @@ Python 3.14 works but you may encounter issues with the torch compile node. The Python 3.13 is very well supported. If you have trouble with some custom node dependencies on 3.13 you can try 3.12 +torch 2.4 and above is supported but some features might only work on newer versions. We generally recommend using the latest major version of pytorch unless it is less than 2 weeks old. + ### Instructions: Git clone this repo. From 16fb6849d296259fd2bf106a6f894650d9a12072 Mon Sep 17 00:00:00 2001 From: "Dr.Lt.Data" <128333288+ltdrdata@users.noreply.github.com> Date: Sat, 27 Dec 2025 08:55:59 +0900 Subject: [PATCH 026/181] bump comfyui_manager version to the 4.0.4 (#11521) --- manager_requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/manager_requirements.txt b/manager_requirements.txt index 2300f0c70..6585b0c19 100644 --- a/manager_requirements.txt +++ b/manager_requirements.txt @@ -1 +1 @@ -comfyui_manager==4.0.3b7 +comfyui_manager==4.0.4 From 1e4e342f54386ea4179b273c24b37bd8cbde8f37 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Fri, 26 Dec 2025 19:03:01 -0800 Subject: [PATCH 027/181] Fix noise with ancestral samplers when inferencing on cpu. 
(#11528) --- comfy/k_diffusion/sampling.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py index 1ba9edad7..0949dee44 100644 --- a/comfy/k_diffusion/sampling.py +++ b/comfy/k_diffusion/sampling.py @@ -74,6 +74,9 @@ def get_ancestral_step(sigma_from, sigma_to, eta=1.): def default_noise_sampler(x, seed=None): if seed is not None: + if x.device == torch.device("cpu"): + seed += 1 + generator = torch.Generator(device=x.device) generator.manual_seed(seed) else: From 865568b7fc5fd2a5f626b22a40c363b0a5f0b399 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Sat, 27 Dec 2025 05:16:21 +0200 Subject: [PATCH 028/181] feat(api-nodes): add Kling Motion Control node (#11493) --- comfy_api_nodes/apis/kling_api.py | 9 ++++ comfy_api_nodes/nodes_kling.py | 87 +++++++++++++++++++++++++++++++ 2 files changed, 96 insertions(+) diff --git a/comfy_api_nodes/apis/kling_api.py b/comfy_api_nodes/apis/kling_api.py index 80a758466..bf54ede3e 100644 --- a/comfy_api_nodes/apis/kling_api.py +++ b/comfy_api_nodes/apis/kling_api.py @@ -102,3 +102,12 @@ class ImageToVideoWithAudioRequest(BaseModel): prompt: str = Field(...) mode: str = Field("pro") sound: str = Field(..., description="'on' or 'off'") + + +class MotionControlRequest(BaseModel): + prompt: str = Field(...) + image_url: str = Field(...) + video_url: str = Field(...) + keep_original_sound: str = Field(...) + character_orientation: str = Field(...) + mode: str = Field(..., description="'pro' or 'std'") diff --git a/comfy_api_nodes/nodes_kling.py b/comfy_api_nodes/nodes_kling.py index 5294b10d4..58259e029 100644 --- a/comfy_api_nodes/nodes_kling.py +++ b/comfy_api_nodes/nodes_kling.py @@ -51,6 +51,7 @@ from comfy_api_nodes.apis import ( ) from comfy_api_nodes.apis.kling_api import ( ImageToVideoWithAudioRequest, + MotionControlRequest, OmniImageParamImage, OmniParamImage, OmniParamVideo, @@ -2163,6 +2164,91 @@ class ImageToVideoWithAudio(IO.ComfyNode): return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url)) +class MotionControl(IO.ComfyNode): + + @classmethod + def define_schema(cls) -> IO.Schema: + return IO.Schema( + node_id="KlingMotionControl", + display_name="Kling Motion Control", + category="api node/video/Kling", + inputs=[ + IO.String.Input("prompt", multiline=True), + IO.Image.Input("reference_image"), + IO.Video.Input( + "reference_video", + tooltip="Motion reference video used to drive movement/expression.\n" + "Duration limits depend on character_orientation:\n" + " - image: 3–10s (max 10s)\n" + " - video: 3–30s (max 30s)", + ), + IO.Boolean.Input("keep_original_sound", default=True), + IO.Combo.Input( + "character_orientation", + options=["video", "image"], + tooltip="Controls where the character's facing/orientation comes from.\n" + "video: movements, expressions, camera moves, and orientation " + "follow the motion reference video (other details via prompt).\n" + "image: movements and expressions still follow the motion reference video, " + "but the character orientation matches the reference image (camera/other details via prompt).", + ), + IO.Combo.Input("mode", options=["pro", "std"]), + ], + outputs=[ + IO.Video.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, + prompt: str, + reference_image: Input.Image, + reference_video: Input.Video, + 
keep_original_sound: bool, + character_orientation: str, + mode: str, + ) -> IO.NodeOutput: + validate_string(prompt, max_length=2500) + validate_image_dimensions(reference_image, min_width=340, min_height=340) + validate_image_aspect_ratio(reference_image, (1, 2.5), (2.5, 1)) + if character_orientation == "image": + validate_video_duration(reference_video, min_duration=3, max_duration=10) + else: + validate_video_duration(reference_video, min_duration=3, max_duration=30) + validate_video_dimensions(reference_video, min_width=340, min_height=340, max_width=3850, max_height=3850) + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/kling/v1/videos/motion-control", method="POST"), + response_model=TaskStatusResponse, + data=MotionControlRequest( + prompt=prompt, + image_url=(await upload_images_to_comfyapi(cls, reference_image))[0], + video_url=await upload_video_to_comfyapi(cls, reference_video), + keep_original_sound="yes" if keep_original_sound else "no", + character_orientation=character_orientation, + mode=mode, + ), + ) + if response.code: + raise RuntimeError( + f"Kling request failed. Code: {response.code}, Message: {response.message}, Data: {response.data}" + ) + final_response = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/kling/v1/videos/motion-control/{response.data.task_id}"), + response_model=TaskStatusResponse, + status_extractor=lambda r: (r.data.task_status if r.data else None), + ) + return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url)) + + class KlingExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[IO.ComfyNode]]: @@ -2188,6 +2274,7 @@ class KlingExtension(ComfyExtension): OmniProImageNode, TextToVideoWithAudio, ImageToVideoWithAudio, + MotionControl, ] From eff4ea0b625e851d641b8f6532ff7afe2df16b9d Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Sat, 27 Dec 2025 05:39:02 +0200 Subject: [PATCH 029/181] [V3] converted nodes_images.py to V3 schema (#11206) * converted nodes_images.py to V3 schema * fix test --- comfy_api/latest/_io.py | 5 +- comfy_api/latest/_util/__init__.py | 2 + comfy_api/latest/_util/image_types.py | 18 + comfy_extras/nodes_images.py | 683 +++++++++--------- .../comfy_extras_test/image_stitch_test.py | 2 +- 5 files changed, 351 insertions(+), 359 deletions(-) create mode 100644 comfy_api/latest/_util/image_types.py diff --git a/comfy_api/latest/_io.py b/comfy_api/latest/_io.py index 4b14e5ded..ba0b95498 100644 --- a/comfy_api/latest/_io.py +++ b/comfy_api/latest/_io.py @@ -28,9 +28,8 @@ from comfy_api.internal import (_ComfyNodeInternal, _NodeOutputInternal, classpr prune_dict, shallow_clone_class) from ._resources import Resources, ResourcesLocal from comfy_execution.graph_utils import ExecutionBlocker -from ._util import MESH, VOXEL +from ._util import MESH, VOXEL, SVG as _SVG -# from comfy_extras.nodes_images import SVG as SVG_ # NOTE: needs to be moved before can be imported due to circular reference class FolderType(str, Enum): input = "input" @@ -656,7 +655,7 @@ class Video(ComfyTypeIO): @comfytype(io_type="SVG") class SVG(ComfyTypeIO): - Type = Any # TODO: SVG class is defined in comfy_extras/nodes_images.py, causing circular reference; should be moved to somewhere else before referenced directly in v3 + Type = _SVG @comfytype(io_type="LORA_MODEL") class LoraModel(ComfyTypeIO): diff --git a/comfy_api/latest/_util/__init__.py b/comfy_api/latest/_util/__init__.py index fc5431dda..6313eb01b 100644 --- 
a/comfy_api/latest/_util/__init__.py +++ b/comfy_api/latest/_util/__init__.py @@ -1,5 +1,6 @@ from .video_types import VideoContainer, VideoCodec, VideoComponents from .geometry_types import VOXEL, MESH +from .image_types import SVG __all__ = [ # Utility Types @@ -8,4 +9,5 @@ __all__ = [ "VideoComponents", "VOXEL", "MESH", + "SVG", ] diff --git a/comfy_api/latest/_util/image_types.py b/comfy_api/latest/_util/image_types.py new file mode 100644 index 000000000..f031ed426 --- /dev/null +++ b/comfy_api/latest/_util/image_types.py @@ -0,0 +1,18 @@ +from io import BytesIO + + +class SVG: + """Stores SVG representations via a list of BytesIO objects.""" + + def __init__(self, data: list[BytesIO]): + self.data = data + + def combine(self, other: 'SVG') -> 'SVG': + return SVG(self.data + other.data) + + @staticmethod + def combine_all(svgs: list['SVG']) -> 'SVG': + all_svgs_list: list[BytesIO] = [] + for svg_item in svgs: + all_svgs_list.extend(svg_item.data) + return SVG(all_svgs_list) diff --git a/comfy_extras/nodes_images.py b/comfy_extras/nodes_images.py index 392aea32c..ce21caade 100644 --- a/comfy_extras/nodes_images.py +++ b/comfy_extras/nodes_images.py @@ -2,280 +2,231 @@ from __future__ import annotations import nodes import folder_paths -from comfy.cli_args import args -from PIL import Image -from PIL.PngImagePlugin import PngInfo - -import numpy as np import json import os import re -from io import BytesIO -from inspect import cleandoc import torch import comfy.utils -from comfy.comfy_types import FileLocator, IO from server import PromptServer +from comfy_api.latest import ComfyExtension, IO, UI +from typing_extensions import override + +SVG = IO.SVG.Type # TODO: temporary solution for backward compatibility, will be removed later. MAX_RESOLUTION = nodes.MAX_RESOLUTION -class ImageCrop: +class ImageCrop(IO.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "image": ("IMAGE",), - "width": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}), - "height": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}), - "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}), - "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}), - }} - RETURN_TYPES = ("IMAGE",) - FUNCTION = "crop" + def define_schema(cls): + return IO.Schema( + node_id="ImageCrop", + display_name="Image Crop", + category="image/transform", + inputs=[ + IO.Image.Input("image"), + IO.Int.Input("width", default=512, min=1, max=nodes.MAX_RESOLUTION, step=1), + IO.Int.Input("height", default=512, min=1, max=nodes.MAX_RESOLUTION, step=1), + IO.Int.Input("x", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1), + IO.Int.Input("y", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1), + ], + outputs=[IO.Image.Output()], + ) - CATEGORY = "image/transform" - - def crop(self, image, width, height, x, y): + @classmethod + def execute(cls, image, width, height, x, y) -> IO.NodeOutput: x = min(x, image.shape[2] - 1) y = min(y, image.shape[1] - 1) to_x = width + x to_y = height + y img = image[:,y:to_y, x:to_x, :] - return (img,) + return IO.NodeOutput(img) -class RepeatImageBatch: + crop = execute # TODO: remove + + +class RepeatImageBatch(IO.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "image": ("IMAGE",), - "amount": ("INT", {"default": 1, "min": 1, "max": 4096}), - }} - RETURN_TYPES = ("IMAGE",) - FUNCTION = "repeat" + def define_schema(cls): + return IO.Schema( + node_id="RepeatImageBatch", + category="image/batch", + 
inputs=[ + IO.Image.Input("image"), + IO.Int.Input("amount", default=1, min=1, max=4096), + ], + outputs=[IO.Image.Output()], + ) - CATEGORY = "image/batch" - - def repeat(self, image, amount): + @classmethod + def execute(cls, image, amount) -> IO.NodeOutput: s = image.repeat((amount, 1,1,1)) - return (s,) + return IO.NodeOutput(s) -class ImageFromBatch: + repeat = execute # TODO: remove + + +class ImageFromBatch(IO.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "image": ("IMAGE",), - "batch_index": ("INT", {"default": 0, "min": 0, "max": 4095}), - "length": ("INT", {"default": 1, "min": 1, "max": 4096}), - }} - RETURN_TYPES = ("IMAGE",) - FUNCTION = "frombatch" + def define_schema(cls): + return IO.Schema( + node_id="ImageFromBatch", + category="image/batch", + inputs=[ + IO.Image.Input("image"), + IO.Int.Input("batch_index", default=0, min=0, max=4095), + IO.Int.Input("length", default=1, min=1, max=4096), + ], + outputs=[IO.Image.Output()], + ) - CATEGORY = "image/batch" - - def frombatch(self, image, batch_index, length): + @classmethod + def execute(cls, image, batch_index, length) -> IO.NodeOutput: s_in = image batch_index = min(s_in.shape[0] - 1, batch_index) length = min(s_in.shape[0] - batch_index, length) s = s_in[batch_index:batch_index + length].clone() - return (s,) + return IO.NodeOutput(s) + + frombatch = execute # TODO: remove -class ImageAddNoise: +class ImageAddNoise(IO.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "image": ("IMAGE",), - "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff, "control_after_generate": True, "tooltip": "The random seed used for creating the noise."}), - "strength": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}), - }} - RETURN_TYPES = ("IMAGE",) - FUNCTION = "repeat" + def define_schema(cls): + return IO.Schema( + node_id="ImageAddNoise", + category="image", + inputs=[ + IO.Image.Input("image"), + IO.Int.Input( + "seed", + default=0, + min=0, + max=0xFFFFFFFFFFFFFFFF, + control_after_generate=True, + tooltip="The random seed used for creating the noise.", + ), + IO.Float.Input("strength", default=0.5, min=0.0, max=1.0, step=0.01), + ], + outputs=[IO.Image.Output()], + ) - CATEGORY = "image" - - def repeat(self, image, seed, strength): + @classmethod + def execute(cls, image, seed, strength) -> IO.NodeOutput: generator = torch.manual_seed(seed) s = torch.clip((image + strength * torch.randn(image.size(), generator=generator, device="cpu").to(image)), min=0.0, max=1.0) - return (s,) + return IO.NodeOutput(s) -class SaveAnimatedWEBP: - def __init__(self): - self.output_dir = folder_paths.get_output_directory() - self.type = "output" - self.prefix_append = "" + repeat = execute # TODO: remove - methods = {"default": 4, "fastest": 0, "slowest": 6} - @classmethod - def INPUT_TYPES(s): - return {"required": - {"images": ("IMAGE", ), - "filename_prefix": ("STRING", {"default": "ComfyUI"}), - "fps": ("FLOAT", {"default": 6.0, "min": 0.01, "max": 1000.0, "step": 0.01}), - "lossless": ("BOOLEAN", {"default": True}), - "quality": ("INT", {"default": 80, "min": 0, "max": 100}), - "method": (list(s.methods.keys()),), - # "num_frames": ("INT", {"default": 0, "min": 0, "max": 8192}), - }, - "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}, - } - RETURN_TYPES = () - FUNCTION = "save_images" - - OUTPUT_NODE = True - - CATEGORY = "image/animation" - - def save_images(self, images, fps, filename_prefix, lossless, quality, method, num_frames=0, prompt=None, 
extra_pnginfo=None): - method = self.methods.get(method) - filename_prefix += self.prefix_append - full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0]) - results: list[FileLocator] = [] - pil_images = [] - for image in images: - i = 255. * image.cpu().numpy() - img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8)) - pil_images.append(img) - - metadata = pil_images[0].getexif() - if not args.disable_metadata: - if prompt is not None: - metadata[0x0110] = "prompt:{}".format(json.dumps(prompt)) - if extra_pnginfo is not None: - inital_exif = 0x010f - for x in extra_pnginfo: - metadata[inital_exif] = "{}:{}".format(x, json.dumps(extra_pnginfo[x])) - inital_exif -= 1 - - if num_frames == 0: - num_frames = len(pil_images) - - c = len(pil_images) - for i in range(0, c, num_frames): - file = f"{filename}_{counter:05}_.webp" - pil_images[i].save(os.path.join(full_output_folder, file), save_all=True, duration=int(1000.0/fps), append_images=pil_images[i + 1:i + num_frames], exif=metadata, lossless=lossless, quality=quality, method=method) - results.append({ - "filename": file, - "subfolder": subfolder, - "type": self.type - }) - counter += 1 - - animated = num_frames != 1 - return { "ui": { "images": results, "animated": (animated,) } } - -class SaveAnimatedPNG: - def __init__(self): - self.output_dir = folder_paths.get_output_directory() - self.type = "output" - self.prefix_append = "" +class SaveAnimatedWEBP(IO.ComfyNode): + COMPRESS_METHODS = {"default": 4, "fastest": 0, "slowest": 6} @classmethod - def INPUT_TYPES(s): - return {"required": - {"images": ("IMAGE", ), - "filename_prefix": ("STRING", {"default": "ComfyUI"}), - "fps": ("FLOAT", {"default": 6.0, "min": 0.01, "max": 1000.0, "step": 0.01}), - "compress_level": ("INT", {"default": 4, "min": 0, "max": 9}) - }, - "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}, - } + def define_schema(cls): + return IO.Schema( + node_id="SaveAnimatedWEBP", + category="image/animation", + inputs=[ + IO.Image.Input("images"), + IO.String.Input("filename_prefix", default="ComfyUI"), + IO.Float.Input("fps", default=6.0, min=0.01, max=1000.0, step=0.01), + IO.Boolean.Input("lossless", default=True), + IO.Int.Input("quality", default=80, min=0, max=100), + IO.Combo.Input("method", options=list(cls.COMPRESS_METHODS.keys())), + # "num_frames": ("INT", {"default": 0, "min": 0, "max": 8192}), + ], + hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo], + is_output_node=True, + ) - RETURN_TYPES = () - FUNCTION = "save_images" + @classmethod + def execute(cls, images, fps, filename_prefix, lossless, quality, method, num_frames=0) -> IO.NodeOutput: + return IO.NodeOutput( + ui=UI.ImageSaveHelper.get_save_animated_webp_ui( + images=images, + filename_prefix=filename_prefix, + cls=cls, + fps=fps, + lossless=lossless, + quality=quality, + method=cls.COMPRESS_METHODS.get(method) + ) + ) - OUTPUT_NODE = True - - CATEGORY = "image/animation" - - def save_images(self, images, fps, compress_level, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None): - filename_prefix += self.prefix_append - full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0]) - results = list() - pil_images = [] - for image in images: - i = 255. 
* image.cpu().numpy() - img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8)) - pil_images.append(img) - - metadata = None - if not args.disable_metadata: - metadata = PngInfo() - if prompt is not None: - metadata.add(b"comf", "prompt".encode("latin-1", "strict") + b"\0" + json.dumps(prompt).encode("latin-1", "strict"), after_idat=True) - if extra_pnginfo is not None: - for x in extra_pnginfo: - metadata.add(b"comf", x.encode("latin-1", "strict") + b"\0" + json.dumps(extra_pnginfo[x]).encode("latin-1", "strict"), after_idat=True) - - file = f"{filename}_{counter:05}_.png" - pil_images[0].save(os.path.join(full_output_folder, file), pnginfo=metadata, compress_level=compress_level, save_all=True, duration=int(1000.0/fps), append_images=pil_images[1:]) - results.append({ - "filename": file, - "subfolder": subfolder, - "type": self.type - }) - - return { "ui": { "images": results, "animated": (True,)} } - -class SVG: - """ - Stores SVG representations via a list of BytesIO objects. - """ - def __init__(self, data: list[BytesIO]): - self.data = data - - def combine(self, other: 'SVG') -> 'SVG': - return SVG(self.data + other.data) - - @staticmethod - def combine_all(svgs: list['SVG']) -> 'SVG': - all_svgs_list: list[BytesIO] = [] - for svg_item in svgs: - all_svgs_list.extend(svg_item.data) - return SVG(all_svgs_list) + save_images = execute # TODO: remove -class ImageStitch: +class SaveAnimatedPNG(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="SaveAnimatedPNG", + category="image/animation", + inputs=[ + IO.Image.Input("images"), + IO.String.Input("filename_prefix", default="ComfyUI"), + IO.Float.Input("fps", default=6.0, min=0.01, max=1000.0, step=0.01), + IO.Int.Input("compress_level", default=4, min=0, max=9), + ], + hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo], + is_output_node=True, + ) + + @classmethod + def execute(cls, images, fps, compress_level, filename_prefix="ComfyUI") -> IO.NodeOutput: + return IO.NodeOutput( + ui=UI.ImageSaveHelper.get_save_animated_png_ui( + images=images, + filename_prefix=filename_prefix, + cls=cls, + fps=fps, + compress_level=compress_level, + ) + ) + + save_images = execute # TODO: remove + + +class ImageStitch(IO.ComfyNode): """Upstreamed from https://github.com/kijai/ComfyUI-KJNodes""" @classmethod - def INPUT_TYPES(s): - return { - "required": { - "image1": ("IMAGE",), - "direction": (["right", "down", "left", "up"], {"default": "right"}), - "match_image_size": ("BOOLEAN", {"default": True}), - "spacing_width": ( - "INT", - {"default": 0, "min": 0, "max": 1024, "step": 2}, - ), - "spacing_color": ( - ["white", "black", "red", "green", "blue"], - {"default": "white"}, - ), - }, - "optional": { - "image2": ("IMAGE",), - }, - } + def define_schema(cls): + return IO.Schema( + node_id="ImageStitch", + display_name="Image Stitch", + description="Stitches image2 to image1 in the specified direction.\n" + "If image2 is not provided, returns image1 unchanged.\n" + "Optional spacing can be added between images.", + category="image/transform", + inputs=[ + IO.Image.Input("image1"), + IO.Combo.Input("direction", options=["right", "down", "left", "up"], default="right"), + IO.Boolean.Input("match_image_size", default=True), + IO.Int.Input("spacing_width", default=0, min=0, max=1024, step=2), + IO.Combo.Input("spacing_color", options=["white", "black", "red", "green", "blue"], default="white"), + IO.Image.Input("image2", optional=True), + ], + outputs=[IO.Image.Output()], + ) - RETURN_TYPES = ("IMAGE",) - FUNCTION = 
"stitch" - CATEGORY = "image/transform" - DESCRIPTION = """ -Stitches image2 to image1 in the specified direction. -If image2 is not provided, returns image1 unchanged. -Optional spacing can be added between images. -""" - - def stitch( - self, + @classmethod + def execute( + cls, image1, direction, match_image_size, spacing_width, spacing_color, image2=None, - ): + ) -> IO.NodeOutput: if image2 is None: - return (image1,) + return IO.NodeOutput(image1) # Handle batch size differences if image1.shape[0] != image2.shape[0]: @@ -412,36 +363,30 @@ Optional spacing can be added between images. images.insert(1, spacing) concat_dim = 2 if direction in ["left", "right"] else 1 - return (torch.cat(images, dim=concat_dim),) + return IO.NodeOutput(torch.cat(images, dim=concat_dim)) + + stitch = execute # TODO: remove + + +class ResizeAndPadImage(IO.ComfyNode): -class ResizeAndPadImage: @classmethod - def INPUT_TYPES(cls): - return { - "required": { - "image": ("IMAGE",), - "target_width": ("INT", { - "default": 512, - "min": 1, - "max": MAX_RESOLUTION, - "step": 1 - }), - "target_height": ("INT", { - "default": 512, - "min": 1, - "max": MAX_RESOLUTION, - "step": 1 - }), - "padding_color": (["white", "black"],), - "interpolation": (["area", "bicubic", "nearest-exact", "bilinear", "lanczos"],), - } - } + def define_schema(cls): + return IO.Schema( + node_id="ResizeAndPadImage", + category="image/transform", + inputs=[ + IO.Image.Input("image"), + IO.Int.Input("target_width", default=512, min=1, max=nodes.MAX_RESOLUTION, step=1), + IO.Int.Input("target_height", default=512, min=1, max=nodes.MAX_RESOLUTION, step=1), + IO.Combo.Input("padding_color", options=["white", "black"]), + IO.Combo.Input("interpolation", options=["area", "bicubic", "nearest-exact", "bilinear", "lanczos"]), + ], + outputs=[IO.Image.Output()], + ) - RETURN_TYPES = ("IMAGE",) - FUNCTION = "resize_and_pad" - CATEGORY = "image/transform" - - def resize_and_pad(self, image, target_width, target_height, padding_color, interpolation): + @classmethod + def execute(cls, image, target_width, target_height, padding_color, interpolation) -> IO.NodeOutput: batch_size, orig_height, orig_width, channels = image.shape scale_w = target_width / orig_width @@ -469,52 +414,47 @@ class ResizeAndPadImage: padded[:, :, y_offset:y_offset + new_height, x_offset:x_offset + new_width] = resized output = padded.permute(0, 2, 3, 1) - return (output,) + return IO.NodeOutput(output) -class SaveSVGNode: - """ - Save SVG files on disk. - """ + resize_and_pad = execute # TODO: remove - def __init__(self): - self.output_dir = folder_paths.get_output_directory() - self.type = "output" - self.prefix_append = "" - RETURN_TYPES = () - DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value - FUNCTION = "save_svg" - CATEGORY = "image/save" # Changed - OUTPUT_NODE = True +class SaveSVGNode(IO.ComfyNode): @classmethod - def INPUT_TYPES(s): - return { - "required": { - "svg": ("SVG",), # Changed - "filename_prefix": ("STRING", {"default": "svg/ComfyUI", "tooltip": "The prefix for the file to save. 
This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes."}) - }, - "hidden": { - "prompt": "PROMPT", - "extra_pnginfo": "EXTRA_PNGINFO" - } - } + def define_schema(cls): + return IO.Schema( + node_id="SaveSVGNode", + description="Save SVG files on disk.", + category="image/save", + inputs=[ + IO.SVG.Input("svg"), + IO.String.Input( + "filename_prefix", + default="svg/ComfyUI", + tooltip="The prefix for the file to save. This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes.", + ), + ], + hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo], + is_output_node=True, + ) - def save_svg(self, svg: SVG, filename_prefix="svg/ComfyUI", prompt=None, extra_pnginfo=None): - filename_prefix += self.prefix_append - full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir) - results = list() + @classmethod + def execute(cls, svg: IO.SVG.Type, filename_prefix="svg/ComfyUI") -> IO.NodeOutput: + full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, folder_paths.get_output_directory()) + results: list[UI.SavedResult] = [] # Prepare metadata JSON metadata_dict = {} - if prompt is not None: - metadata_dict["prompt"] = prompt - if extra_pnginfo is not None: - metadata_dict.update(extra_pnginfo) + if cls.hidden.prompt is not None: + metadata_dict["prompt"] = cls.hidden.prompt + if cls.hidden.extra_pnginfo is not None: + metadata_dict.update(cls.hidden.extra_pnginfo) # Convert metadata to JSON string metadata_json = json.dumps(metadata_dict, indent=2) if metadata_dict else None + for batch_number, svg_bytes in enumerate(svg.data): filename_with_batch_num = filename.replace("%batch_num%", str(batch_number)) file = f"{filename_with_batch_num}_{counter:05}_.svg" @@ -544,57 +484,64 @@ class SaveSVGNode: with open(os.path.join(full_output_folder, file), 'wb') as svg_file: svg_file.write(svg_content.encode('utf-8')) - results.append({ - "filename": file, - "subfolder": subfolder, - "type": self.type - }) + results.append(UI.SavedResult(filename=file, subfolder=subfolder, type=IO.FolderType.output)) counter += 1 - return { "ui": { "images": results } } + return IO.NodeOutput(ui={"images": results}) -class GetImageSize: + save_svg = execute # TODO: remove + + +class GetImageSize(IO.ComfyNode): @classmethod - def INPUT_TYPES(s): - return { - "required": { - "image": (IO.IMAGE,), - }, - "hidden": { - "unique_id": "UNIQUE_ID", - } - } + def define_schema(cls): + return IO.Schema( + node_id="GetImageSize", + display_name="Get Image Size", + description="Returns width and height of the image, and passes it through unchanged.", + category="image", + inputs=[ + IO.Image.Input("image"), + ], + outputs=[ + IO.Int.Output(display_name="width"), + IO.Int.Output(display_name="height"), + IO.Int.Output(display_name="batch_size"), + ], + hidden=[IO.Hidden.unique_id], + ) - RETURN_TYPES = (IO.INT, IO.INT, IO.INT) - RETURN_NAMES = ("width", "height", "batch_size") - FUNCTION = "get_size" - - CATEGORY = "image" - DESCRIPTION = """Returns width and height of the image, and passes it through unchanged.""" - - def get_size(self, image, unique_id=None) -> tuple[int, int]: + @classmethod + def execute(cls, image) -> IO.NodeOutput: height = image.shape[1] width = image.shape[2] batch_size = image.shape[0] # Send progress text to display size on the node - if unique_id: 
- PromptServer.instance.send_progress_text(f"width: {width}, height: {height}\n batch size: {batch_size}", unique_id) + if cls.hidden.unique_id: + PromptServer.instance.send_progress_text(f"width: {width}, height: {height}\n batch size: {batch_size}", cls.hidden.unique_id) - return width, height, batch_size + return IO.NodeOutput(width, height, batch_size) + + get_size = execute # TODO: remove + + +class ImageRotate(IO.ComfyNode): -class ImageRotate: @classmethod - def INPUT_TYPES(s): - return {"required": { "image": (IO.IMAGE,), - "rotation": (["none", "90 degrees", "180 degrees", "270 degrees"],), - }} - RETURN_TYPES = (IO.IMAGE,) - FUNCTION = "rotate" + def define_schema(cls): + return IO.Schema( + node_id="ImageRotate", + category="image/transform", + inputs=[ + IO.Image.Input("image"), + IO.Combo.Input("rotation", options=["none", "90 degrees", "180 degrees", "270 degrees"]), + ], + outputs=[IO.Image.Output()], + ) - CATEGORY = "image/transform" - - def rotate(self, image, rotation): + @classmethod + def execute(cls, image, rotation) -> IO.NodeOutput: rotate_by = 0 if rotation.startswith("90"): rotate_by = 1 @@ -604,41 +551,57 @@ class ImageRotate: rotate_by = 3 image = torch.rot90(image, k=rotate_by, dims=[2, 1]) - return (image,) + return IO.NodeOutput(image) + + rotate = execute # TODO: remove + + +class ImageFlip(IO.ComfyNode): -class ImageFlip: @classmethod - def INPUT_TYPES(s): - return {"required": { "image": (IO.IMAGE,), - "flip_method": (["x-axis: vertically", "y-axis: horizontally"],), - }} - RETURN_TYPES = (IO.IMAGE,) - FUNCTION = "flip" + def define_schema(cls): + return IO.Schema( + node_id="ImageFlip", + category="image/transform", + inputs=[ + IO.Image.Input("image"), + IO.Combo.Input("flip_method", options=["x-axis: vertically", "y-axis: horizontally"]), + ], + outputs=[IO.Image.Output()], + ) - CATEGORY = "image/transform" - - def flip(self, image, flip_method): + @classmethod + def execute(cls, image, flip_method) -> IO.NodeOutput: if flip_method.startswith("x"): image = torch.flip(image, dims=[1]) elif flip_method.startswith("y"): image = torch.flip(image, dims=[2]) - return (image,) + return IO.NodeOutput(image) -class ImageScaleToMaxDimension: - upscale_methods = ["area", "lanczos", "bilinear", "nearest-exact", "bilinear", "bicubic"] + flip = execute # TODO: remove + + +class ImageScaleToMaxDimension(IO.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": {"image": ("IMAGE",), - "upscale_method": (s.upscale_methods,), - "largest_size": ("INT", {"default": 512, "min": 0, "max": MAX_RESOLUTION, "step": 1})}} - RETURN_TYPES = ("IMAGE",) - FUNCTION = "upscale" + def define_schema(cls): + return IO.Schema( + node_id="ImageScaleToMaxDimension", + category="image/upscaling", + inputs=[ + IO.Image.Input("image"), + IO.Combo.Input( + "upscale_method", + options=["area", "lanczos", "bilinear", "nearest-exact", "bilinear", "bicubic"], + ), + IO.Int.Input("largest_size", default=512, min=0, max=MAX_RESOLUTION, step=1), + ], + outputs=[IO.Image.Output()], + ) - CATEGORY = "image/upscaling" - - def upscale(self, image, upscale_method, largest_size): + @classmethod + def execute(cls, image, upscale_method, largest_size) -> IO.NodeOutput: height = image.shape[1] width = image.shape[2] @@ -655,20 +618,30 @@ class ImageScaleToMaxDimension: samples = image.movedim(-1, 1) s = comfy.utils.common_upscale(samples, width, height, upscale_method, "disabled") s = s.movedim(1, -1) - return (s,) + return IO.NodeOutput(s) -NODE_CLASS_MAPPINGS = { - "ImageCrop": ImageCrop, - 
"RepeatImageBatch": RepeatImageBatch, - "ImageFromBatch": ImageFromBatch, - "ImageAddNoise": ImageAddNoise, - "SaveAnimatedWEBP": SaveAnimatedWEBP, - "SaveAnimatedPNG": SaveAnimatedPNG, - "SaveSVGNode": SaveSVGNode, - "ImageStitch": ImageStitch, - "ResizeAndPadImage": ResizeAndPadImage, - "GetImageSize": GetImageSize, - "ImageRotate": ImageRotate, - "ImageFlip": ImageFlip, - "ImageScaleToMaxDimension": ImageScaleToMaxDimension, -} + upscale = execute # TODO: remove + + +class ImagesExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[IO.ComfyNode]]: + return [ + ImageCrop, + RepeatImageBatch, + ImageFromBatch, + ImageAddNoise, + SaveAnimatedWEBP, + SaveAnimatedPNG, + SaveSVGNode, + ImageStitch, + ResizeAndPadImage, + GetImageSize, + ImageRotate, + ImageFlip, + ImageScaleToMaxDimension, + ] + + +async def comfy_entrypoint() -> ImagesExtension: + return ImagesExtension() diff --git a/tests-unit/comfy_extras_test/image_stitch_test.py b/tests-unit/comfy_extras_test/image_stitch_test.py index b5a0f022c..5c6a15ac4 100644 --- a/tests-unit/comfy_extras_test/image_stitch_test.py +++ b/tests-unit/comfy_extras_test/image_stitch_test.py @@ -25,7 +25,7 @@ class TestImageStitch: result = node.stitch(image1, "right", True, 0, "white", image2=None) - assert len(result) == 1 + assert len(result.result) == 1 assert torch.equal(result[0], image1) def test_basic_horizontal_stitch_right(self): From 0d2e4bdd44f61b198588c5db99bebfd5cdec286b Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Sat, 27 Dec 2025 05:55:30 +0200 Subject: [PATCH 030/181] fix(api-nodes-gemini): always force enhance_prompt to be True (#11503) --- comfy_api_nodes/nodes_veo2.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/comfy_api_nodes/nodes_veo2.py b/comfy_api_nodes/nodes_veo2.py index e165b8380..13a6bfd91 100644 --- a/comfy_api_nodes/nodes_veo2.py +++ b/comfy_api_nodes/nodes_veo2.py @@ -168,6 +168,8 @@ class VeoVideoGenerationNode(IO.ComfyNode): # Only add generateAudio for Veo 3 models if model.find("veo-2.0") == -1: parameters["generateAudio"] = generate_audio + # force "enhance_prompt" to True for Veo3 models + parameters["enhancePrompt"] = True initial_response = await sync_op( cls, @@ -291,7 +293,7 @@ class Veo3VideoGenerationNode(VeoVideoGenerationNode): IO.Boolean.Input( "enhance_prompt", default=True, - tooltip="Whether to enhance the prompt with AI assistance", + tooltip="This parameter is deprecated and ignored.", optional=True, ), IO.Combo.Input( From 36deef2c57eacb5d847bd709c5f3068190630612 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Sat, 27 Dec 2025 05:56:52 +0200 Subject: [PATCH 031/181] chore(api-nodes): switch to credits instead of $ (#11489) --- comfy_api_nodes/util/client.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/comfy_api_nodes/util/client.py b/comfy_api_nodes/util/client.py index bf37cba5f..f372ec7b5 100644 --- a/comfy_api_nodes/util/client.py +++ b/comfy_api_nodes/util/client.py @@ -430,9 +430,9 @@ def _display_text( if status: display_lines.append(f"Status: {status.capitalize() if isinstance(status, str) else status}") if price is not None: - p = f"{float(price):,.4f}".rstrip("0").rstrip(".") + p = f"{float(price) * 211:,.1f}".rstrip("0").rstrip(".") if p != "0": - display_lines.append(f"Price: ${p}") + display_lines.append(f"Price: {p} credits") if text is not None: display_lines.append(text) if display_lines: From 
2943093a5310fc96aa010a3c68c04f7c16f58a9e Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sat, 27 Dec 2025 15:54:15 -0800 Subject: [PATCH 032/181] Enable async offload by default for AMD. (#11534) --- comfy/model_management.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 1889ab0ac..e5554e225 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -1019,8 +1019,8 @@ NUM_STREAMS = 0 if args.async_offload is not None: NUM_STREAMS = args.async_offload else: - # Enable by default on Nvidia - if is_nvidia(): + # Enable by default on Nvidia and AMD + if is_nvidia() or is_amd(): NUM_STREAMS = 2 if args.disable_async_offload: From 8fd07170f1b0a7eaaf5a62020cd1926dd3b5092c Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sun, 28 Dec 2025 19:07:25 -0800 Subject: [PATCH 033/181] Comment out unused norm_final in lumina/z image model. (#11545) --- comfy/ldm/lumina/model.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/comfy/ldm/lumina/model.py b/comfy/ldm/lumina/model.py index e80b1c138..afbab2ac7 100644 --- a/comfy/ldm/lumina/model.py +++ b/comfy/ldm/lumina/model.py @@ -491,7 +491,8 @@ class NextDiT(nn.Module): for layer_id in range(n_layers) ] ) - self.norm_final = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + # This norm final is in the lumina 2.0 code but isn't actually used for anything. + # self.norm_final = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) self.final_layer = FinalLayer(dim, patch_size, self.out_channels, z_image_modulation=z_image_modulation, operation_settings=operation_settings) if self.pad_tokens_multiple is not None: From 9ca7e143afe6f09734c9aefcc85f491c5c0dc6e0 Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Mon, 29 Dec 2025 15:19:34 -0800 Subject: [PATCH 034/181] mm: discard async errors from pinning failures (#10738) Pretty much every error cudaHostRegister can throw also queues the same error on the async GPU queue. This was fixed for repinning error case, but there is the bad mmap and just enomem cases that are harder to detect. Do some dummy GPU work to clean the error state. --- comfy/model_management.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/comfy/model_management.py b/comfy/model_management.py index e5554e225..9fcb699bc 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -1126,6 +1126,16 @@ if not args.disable_pinned_memory: PINNING_ALLOWED_TYPES = set(["Parameter", "QuantizedTensor"]) +def discard_cuda_async_error(): + try: + a = torch.tensor([1], dtype=torch.uint8, device=get_torch_device()) + b = torch.tensor([1], dtype=torch.uint8, device=get_torch_device()) + _ = a + b + torch.cuda.synchronize() + except torch.AcceleratorError: + #Dump it! 
We already know about it from the synchronous return + pass + def pin_memory(tensor): global TOTAL_PINNED_MEMORY if MAX_PINNED_MEMORY <= 0: @@ -1158,6 +1168,8 @@ def pin_memory(tensor): PINNED_MEMORY[ptr] = size TOTAL_PINNED_MEMORY += size return True + else: + discard_cuda_async_error() return False @@ -1186,6 +1198,8 @@ def unpin_memory(tensor): if len(PINNED_MEMORY) == 0: TOTAL_PINNED_MEMORY = 0 return True + else: + discard_cuda_async_error() return False From 0e6221cc79a3f3cbf0e15a8321bfe75fcffbe667 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 29 Dec 2025 15:26:42 -0800 Subject: [PATCH 035/181] Add some warnings for pin and unpin errors. (#11561) --- comfy/model_management.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/comfy/model_management.py b/comfy/model_management.py index 9fcb699bc..87baedd73 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -1169,6 +1169,7 @@ def pin_memory(tensor): TOTAL_PINNED_MEMORY += size return True else: + logging.warning("Pin error.") discard_cuda_async_error() return False @@ -1199,6 +1200,7 @@ def unpin_memory(tensor): TOTAL_PINNED_MEMORY = 0 return True else: + logging.warning("Unpin error.") discard_cuda_async_error() return False From d7111e426a48127a97922227b03d31391eb4eba2 Mon Sep 17 00:00:00 2001 From: Tavi Halperin Date: Tue, 30 Dec 2025 03:07:29 +0200 Subject: [PATCH 036/181] ResizeByLongerSide: support video (#11555) (cherry picked from commit 98c6840aa4e5fd5407ba9ab113d209011e474bf6) --- comfy_extras/nodes_dataset.py | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/comfy_extras/nodes_dataset.py b/comfy_extras/nodes_dataset.py index 513aecf3a..5ef851bd0 100644 --- a/comfy_extras/nodes_dataset.py +++ b/comfy_extras/nodes_dataset.py @@ -667,16 +667,19 @@ class ResizeImagesByLongerEdgeNode(ImageProcessingNode): @classmethod def _process(cls, image, longer_edge): - img = tensor_to_pil(image) - w, h = img.size - if w > h: - new_w = longer_edge - new_h = int(h * (longer_edge / w)) - else: - new_h = longer_edge - new_w = int(w * (longer_edge / h)) - img = img.resize((new_w, new_h), Image.Resampling.LANCZOS) - return pil_to_tensor(img) + resized_images = [] + for image_i in image: + img = tensor_to_pil(image_i) + w, h = img.size + if w > h: + new_w = longer_edge + new_h = int(h * (longer_edge / w)) + else: + new_h = longer_edge + new_w = int(w * (longer_edge / h)) + img = img.resize((new_w, new_h), Image.Resampling.LANCZOS) + resized_images.append(pil_to_tensor(img)) + return torch.cat(resized_images, dim=0) class CenterCropImagesNode(ImageProcessingNode): From 25a1bfab4e19b541c2bd6f253a3b83886fb660a1 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Tue, 30 Dec 2025 18:33:34 +0200 Subject: [PATCH 037/181] chore(api-nodes-bytedance): mark "seededit" as deprecated, adjust display name of Seedream (#11490) --- comfy_api_nodes/nodes_bytedance.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/comfy_api_nodes/nodes_bytedance.py b/comfy_api_nodes/nodes_bytedance.py index 636cc1265..d4a2cfae6 100644 --- a/comfy_api_nodes/nodes_bytedance.py +++ b/comfy_api_nodes/nodes_bytedance.py @@ -229,6 +229,7 @@ class ByteDanceImageEditNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + is_deprecated=True, ) @classmethod @@ -269,7 +270,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode): def define_schema(cls): return IO.Schema( 
node_id="ByteDanceSeedreamNode", - display_name="ByteDance Seedream 4", + display_name="ByteDance Seedream 4.5", category="api node/image/ByteDance", description="Unified text-to-image generation and precise single-sentence editing at up to 4K resolution.", inputs=[ From 178bdc5e14ec0a55e401c509719e33773cc9b565 Mon Sep 17 00:00:00 2001 From: drozbay <17261091+drozbay@users.noreply.github.com> Date: Tue, 30 Dec 2025 15:40:42 -0700 Subject: [PATCH 038/181] Add handling for vace_context in context windows (#11386) Co-authored-by: ozbayb <17261091+ozbayb@users.noreply.github.com> --- comfy/context_windows.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/comfy/context_windows.py b/comfy/context_windows.py index 1e0f86026..2f82d51da 100644 --- a/comfy/context_windows.py +++ b/comfy/context_windows.py @@ -188,6 +188,12 @@ class IndexListContextHandler(ContextHandlerABC): audio_cond = cond_value.cond if audio_cond.ndim > 1 and audio_cond.size(1) == x_in.size(self.dim): new_cond_item[cond_key] = cond_value._copy_with(window.get_tensor(audio_cond, device, dim=1)) + # Handle vace_context (temporal dim is 3) + elif cond_key == "vace_context" and hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor): + vace_cond = cond_value.cond + if vace_cond.ndim >= 4 and vace_cond.size(3) == x_in.size(self.dim): + sliced_vace = window.get_tensor(vace_cond, device, dim=3, retain_index_list=self.cond_retain_index_list) + new_cond_item[cond_key] = cond_value._copy_with(sliced_vace) # if has cond that is a Tensor, check if needs to be subset elif hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor): if (self.dim < cond_value.cond.ndim and cond_value.cond.size(self.dim) == x_in.size(self.dim)) or \ From f59f71cf34067d46713f6243312f7f0b360d061f Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 30 Dec 2025 22:41:22 -0500 Subject: [PATCH 039/181] ComfyUI version v0.7.0 --- comfyui_version.py | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/comfyui_version.py b/comfyui_version.py index 1f28e2407..1ed60fe5c 100644 --- a/comfyui_version.py +++ b/comfyui_version.py @@ -1,3 +1,3 @@ # This file is automatically generated by the build process when version is # updated in pyproject.toml. -__version__ = "0.6.0" +__version__ = "0.7.0" diff --git a/pyproject.toml b/pyproject.toml index 35a268bd1..bc1467941 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ComfyUI" -version = "0.6.0" +version = "0.7.0" readme = "README.md" license = { file = "LICENSE" } requires-python = ">=3.9" From 0357ed7ec4a1bbfe3832874ad6cfc1ca3db1bc0d Mon Sep 17 00:00:00 2001 From: mengqin Date: Tue, 30 Dec 2025 17:53:52 -1000 Subject: [PATCH 040/181] Add support for sage attention 3 in comfyui, enable via new cli arg (#11026) * Add support for sage attention 3 in comfyui, enable via new cli arg --use-sage-attiention3 * Fix some bugs found in PR review. The N dimension at which Sage Attention 3 takes effect is reduced to 1024 (although the improvement is not significant at this scale). * Remove the Sage Attention3 switch, but retain the attention function registration. 
* Fix a ruff check issue in attention.py --- comfy/ldm/modules/attention.py | 96 ++++++++++++++++++++++++++++++++++ 1 file changed, 96 insertions(+) diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index a8800ded0..ccf690945 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -30,6 +30,13 @@ except ImportError as e: raise e exit(-1) +SAGE_ATTENTION3_IS_AVAILABLE = False +try: + from sageattn3 import sageattn3_blackwell + SAGE_ATTENTION3_IS_AVAILABLE = True +except ImportError: + pass + FLASH_ATTENTION_IS_AVAILABLE = False try: from flash_attn import flash_attn_func @@ -563,6 +570,93 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape= out = out.reshape(b, -1, heads * dim_head) return out +@wrap_attn +def attention3_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs): + exception_fallback = False + if (q.device.type != "cuda" or + q.dtype not in (torch.float16, torch.bfloat16) or + mask is not None): + return attention_pytorch( + q, k, v, heads, + mask=mask, + attn_precision=attn_precision, + skip_reshape=skip_reshape, + skip_output_reshape=skip_output_reshape, + **kwargs + ) + + if skip_reshape: + B, H, L, D = q.shape + if H != heads: + return attention_pytorch( + q, k, v, heads, + mask=mask, + attn_precision=attn_precision, + skip_reshape=True, + skip_output_reshape=skip_output_reshape, + **kwargs + ) + q_s, k_s, v_s = q, k, v + N = q.shape[2] + dim_head = D + else: + B, N, inner_dim = q.shape + if inner_dim % heads != 0: + return attention_pytorch( + q, k, v, heads, + mask=mask, + attn_precision=attn_precision, + skip_reshape=False, + skip_output_reshape=skip_output_reshape, + **kwargs + ) + dim_head = inner_dim // heads + + if dim_head >= 256 or N <= 1024: + return attention_pytorch( + q, k, v, heads, + mask=mask, + attn_precision=attn_precision, + skip_reshape=skip_reshape, + skip_output_reshape=skip_output_reshape, + **kwargs + ) + + if not skip_reshape: + q_s, k_s, v_s = map( + lambda t: t.view(B, -1, heads, dim_head).permute(0, 2, 1, 3).contiguous(), + (q, k, v), + ) + B, H, L, D = q_s.shape + + try: + out = sageattn3_blackwell(q_s, k_s, v_s, is_causal=False) + except Exception as e: + exception_fallback = True + logging.error("Error running SageAttention3: %s, falling back to pytorch attention.", e) + + if exception_fallback: + if not skip_reshape: + del q_s, k_s, v_s + return attention_pytorch( + q, k, v, heads, + mask=mask, + attn_precision=attn_precision, + skip_reshape=False, + skip_output_reshape=skip_output_reshape, + **kwargs + ) + + if skip_reshape: + if not skip_output_reshape: + out = out.permute(0, 2, 1, 3).reshape(B, L, H * D) + else: + if skip_output_reshape: + pass + else: + out = out.permute(0, 2, 1, 3).reshape(B, L, H * D) + + return out try: @torch.library.custom_op("flash_attention::flash_attn", mutates_args=()) @@ -650,6 +744,8 @@ optimized_attention_masked = optimized_attention # register core-supported attention functions if SAGE_ATTENTION_IS_AVAILABLE: register_attention_function("sage", attention_sage) +if SAGE_ATTENTION3_IS_AVAILABLE: + register_attention_function("sage3", attention3_sage) if FLASH_ATTENTION_IS_AVAILABLE: register_attention_function("flash", attention_flash) if model_management.xformers_enabled(): From 0be8a76c933026011098d41e61cc6e544739e427 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Tue, 30 Dec 2025 20:09:55 -0800 Subject: [PATCH 041/181] V3 Improvements + DynamicCombo + Autogrow 
exposed in public API (#11345)

* Support Combo outputs in a more sane way
* Remove test validate_inputs function on test node
* Make curr_prefix be a list of strings instead of string for easier parsing as keys get added to dynamic types
* Start to account for id prefixes from frontend, need to fix bug with nested dynamics
* Ensure inputs/outputs/hidden are lists in schema finalize function, remove no longer needed 'is not None' checks
* Add raw_link and extra_dict to all relevant Inputs
* Make nested DynamicCombos work properly with prefixed keys on latest frontend; breaks old Autogrow, but is pretty much ready for upcoming Autogrow keys
* Replace ... usage with a MISSING sentinel for clarity in nodes_logic.py
* Added CustomCombo node in backend to reflect frontend node
* Prepare Autogrow's expand_schema_for_dynamic to work with upcoming frontend changes
* Prepare for lookup table for dynamic input stuff
* More progress towards dynamic input lookup function stuff
* Finished converting _expand_schema_for_dynamic to be done via lookup instead of OOP to guarantee working with process isolation, did refactoring to remove old implementation + cleaning INPUT_TYPES definition including v3 hidden definition
* Change order of functions
* Removed some unneeded functions after dynamic refactor
* Make MatchType's output default displayname "MATCHTYPE"
* Fix DynamicSlot get_all
* Removed redundant code - dynamic stuff no longer happens in OOP way
* Natively support AnyType (*) without __ne__ hacks
* Remove stray code that made it in
* Remove expand_schema_for_dynamic left over on DynamicInput class
* get_dynamic() on DynamicInput/Output was not doing anything anymore, so removed it
* Make validate_inputs validate combo input correctly
* Temporarily comment out conversion to 'new' (9 month old) COMBO format in get_input_info
* Remove references to resources feature scrapped from V3
* Expose DynamicCombo in public API
* satisfy ruff after some code got commented out
* Make missing input error prettier for dynamic types
* Created a Switch2 node as a side-by-side test, will likely go with Switch2 as the initial switch node
* Figured out Switch situation
* Pass in v3_data in IsChangedCache.get function's fingerprint_inputs, add a from_v3_data helper method to HiddenHolder
* Switch order of Switch and Soft Switch nodes in file
* Temp test node for MatchType
* Fix missing v3_data for v1 nodes in validation
* For now, remove checking duplicate ids for dynamic types
* Add Resize Image/Mask node that thanks to MatchType+DynamicCombo is 16-nodes-in-1
* Made DynamicCombo references in DCTestNode use public interface
* Add an AnyTypeTestNode
* Make lazy status for specific inputs on DynamicInputs work by having the values of the dictionary for check_lazy_status be a tuple, where the second element is the key of the input that can be returned
* Comment out test logic nodes
* Make primitive float's step make more sense
* Add (and leave commented out) some potential logic nodes
* Change default crop option to "center" on Resize Image/Mask node
* Changed copy.copy(d) to d.copy()
* Autogrow is available in stable frontend, so exposing it in public API
* Use outputs id as display_name if no display_name present, remove v3 outputs id restriction that made them have to have unique IDs from the inputs
* Enable Custom Combo node as stable frontend now supports it
* Make id properly act like display_name on outputs
* Add Batch Images/Masks/Latents node
* Comment out Batch Images/Masks/Latents node for now, as Autogrow has a
bug with MatchType where top connection is disconnected upon refresh * Removed code for a couple test nodes in nodes_logic.py * Add Batch Images, Batch Masks, and Batch Latents nodes with Autogrow, deprecate old Batch Images + LatentBatch nodes --- comfy_api/latest/__init__.py | 1 - comfy_api/latest/_io.py | 370 ++++++++++++++------------ comfy_api/latest/_resources.py | 72 ----- comfy_execution/graph.py | 5 + comfy_execution/validation.py | 14 + comfy_extras/nodes_latent.py | 1 + comfy_extras/nodes_logic.py | 149 +++++++++-- comfy_extras/nodes_post_processing.py | 356 +++++++++++++++++++++++++ comfy_extras/nodes_primitive.py | 2 +- execution.py | 45 ++-- nodes.py | 1 + 11 files changed, 742 insertions(+), 274 deletions(-) delete mode 100644 comfy_api/latest/_resources.py diff --git a/comfy_api/latest/__init__.py b/comfy_api/latest/__init__.py index fab63c7df..b0fa14ff6 100644 --- a/comfy_api/latest/__init__.py +++ b/comfy_api/latest/__init__.py @@ -10,7 +10,6 @@ from ._input_impl import VideoFromFile, VideoFromComponents from ._util import VideoCodec, VideoContainer, VideoComponents, MESH, VOXEL from . import _io_public as io from . import _ui_public as ui -# from comfy_api.latest._resources import _RESOURCES as resources #noqa: F401 from comfy_execution.utils import get_executing_context from comfy_execution.progress import get_progress_state, PreviewImageTuple from PIL import Image diff --git a/comfy_api/latest/_io.py b/comfy_api/latest/_io.py index ba0b95498..764fa8b2b 100644 --- a/comfy_api/latest/_io.py +++ b/comfy_api/latest/_io.py @@ -26,7 +26,6 @@ if TYPE_CHECKING: from comfy_api.input import VideoInput from comfy_api.internal import (_ComfyNodeInternal, _NodeOutputInternal, classproperty, copy_class, first_real_override, is_class, prune_dict, shallow_clone_class) -from ._resources import Resources, ResourcesLocal from comfy_execution.graph_utils import ExecutionBlocker from ._util import MESH, VOXEL, SVG as _SVG @@ -76,16 +75,6 @@ class NumberDisplay(str, Enum): slider = "slider" -class _StringIOType(str): - def __ne__(self, value: object) -> bool: - if self == "*" or value == "*": - return False - if not isinstance(value, str): - return True - a = frozenset(self.split(",")) - b = frozenset(value.split(",")) - return not (b.issubset(a) or a.issubset(b)) - class _ComfyType(ABC): Type = Any io_type: str = None @@ -125,8 +114,7 @@ def comfytype(io_type: str, **kwargs): new_cls.__module__ = cls.__module__ new_cls.__doc__ = cls.__doc__ # assign ComfyType attributes, if needed - # NOTE: use __ne__ trick for io_type (see node_typing.IO.__ne__ for details) - new_cls.io_type = _StringIOType(io_type) + new_cls.io_type = io_type if hasattr(new_cls, "Input") and new_cls.Input is not None: new_cls.Input.Parent = new_cls if hasattr(new_cls, "Output") and new_cls.Output is not None: @@ -165,7 +153,7 @@ class Input(_IO_V3): ''' Base class for a V3 Input. 
''' - def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None): + def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None, raw_link: bool=None): super().__init__() self.id = id self.display_name = display_name @@ -173,6 +161,7 @@ class Input(_IO_V3): self.tooltip = tooltip self.lazy = lazy self.extra_dict = extra_dict if extra_dict is not None else {} + self.rawLink = raw_link def as_dict(self): return prune_dict({ @@ -180,10 +169,11 @@ class Input(_IO_V3): "optional": self.optional, "tooltip": self.tooltip, "lazy": self.lazy, + "rawLink": self.rawLink, }) | prune_dict(self.extra_dict) def get_io_type(self): - return _StringIOType(self.io_type) + return self.io_type def get_all(self) -> list[Input]: return [self] @@ -194,8 +184,8 @@ class WidgetInput(Input): ''' def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, default: Any=None, - socketless: bool=None, widget_type: str=None, force_input: bool=None, extra_dict=None): - super().__init__(id, display_name, optional, tooltip, lazy, extra_dict) + socketless: bool=None, widget_type: str=None, force_input: bool=None, extra_dict=None, raw_link: bool=None): + super().__init__(id, display_name, optional, tooltip, lazy, extra_dict, raw_link) self.default = default self.socketless = socketless self.widget_type = widget_type @@ -217,13 +207,14 @@ class Output(_IO_V3): def __init__(self, id: str=None, display_name: str=None, tooltip: str=None, is_output_list=False): self.id = id - self.display_name = display_name + self.display_name = display_name if display_name else id self.tooltip = tooltip self.is_output_list = is_output_list def as_dict(self): + display_name = self.display_name if self.display_name else self.id return prune_dict({ - "display_name": self.display_name, + "display_name": display_name, "tooltip": self.tooltip, "is_output_list": self.is_output_list, }) @@ -251,8 +242,8 @@ class Boolean(ComfyTypeIO): '''Boolean input.''' def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, default: bool=None, label_on: str=None, label_off: str=None, - socketless: bool=None, force_input: bool=None): - super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input) + socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None): + super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link) self.label_on = label_on self.label_off = label_off self.default: bool @@ -271,8 +262,8 @@ class Int(ComfyTypeIO): '''Integer input.''' def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, default: int=None, min: int=None, max: int=None, step: int=None, control_after_generate: bool=None, - display_mode: NumberDisplay=None, socketless: bool=None, force_input: bool=None): - super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input) + display_mode: NumberDisplay=None, socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None): + super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link) self.min = min self.max = max self.step = step @@ -297,8 +288,8 @@ class Float(ComfyTypeIO): '''Float input.''' def __init__(self, id: str, display_name: str=None, 
optional=False, tooltip: str=None, lazy: bool=None, default: float=None, min: float=None, max: float=None, step: float=None, round: float=None, - display_mode: NumberDisplay=None, socketless: bool=None, force_input: bool=None): - super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input) + display_mode: NumberDisplay=None, socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None): + super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link) self.min = min self.max = max self.step = step @@ -323,8 +314,8 @@ class String(ComfyTypeIO): '''String input.''' def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, multiline=False, placeholder: str=None, default: str=None, dynamic_prompts: bool=None, - socketless: bool=None, force_input: bool=None): - super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input) + socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None): + super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link) self.multiline = multiline self.placeholder = placeholder self.dynamic_prompts = dynamic_prompts @@ -357,12 +348,14 @@ class Combo(ComfyTypeIO): image_folder: FolderType=None, remote: RemoteOptions=None, socketless: bool=None, + extra_dict=None, + raw_link: bool=None, ): if isinstance(options, type) and issubclass(options, Enum): options = [v.value for v in options] if isinstance(default, Enum): default = default.value - super().__init__(id, display_name, optional, tooltip, lazy, default, socketless) + super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, None, extra_dict, raw_link) self.multiselect = False self.options = options self.control_after_generate = control_after_generate @@ -386,10 +379,6 @@ class Combo(ComfyTypeIO): super().__init__(id, display_name, tooltip, is_output_list) self.options = options if options is not None else [] - @property - def io_type(self): - return self.options - @comfytype(io_type="COMBO") class MultiCombo(ComfyTypeI): '''Multiselect Combo input (dropdown for selecting potentially more than one value).''' @@ -398,8 +387,8 @@ class MultiCombo(ComfyTypeI): class Input(Combo.Input): def __init__(self, id: str, options: list[str], display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, default: list[str]=None, placeholder: str=None, chip: bool=None, control_after_generate: bool=None, - socketless: bool=None): - super().__init__(id, options, display_name, optional, tooltip, lazy, default, control_after_generate, socketless=socketless) + socketless: bool=None, extra_dict=None, raw_link: bool=None): + super().__init__(id, options, display_name, optional, tooltip, lazy, default, control_after_generate, socketless=socketless, extra_dict=extra_dict, raw_link=raw_link) self.multiselect = True self.placeholder = placeholder self.chip = chip @@ -432,9 +421,9 @@ class Webcam(ComfyTypeIO): Type = str def __init__( self, id: str, display_name: str=None, optional=False, - tooltip: str=None, lazy: bool=None, default: str=None, socketless: bool=None + tooltip: str=None, lazy: bool=None, default: str=None, socketless: bool=None, extra_dict=None, raw_link: bool=None ): - super().__init__(id, display_name, optional, tooltip, lazy, default, socketless) + super().__init__(id, display_name, optional, 
tooltip, lazy, default, socketless, None, None, extra_dict, raw_link) @comfytype(io_type="MASK") @@ -787,7 +776,7 @@ class MultiType: ''' Input that permits more than one input type; if `id` is an instance of `ComfyType.Input`, then that input will be used to create a widget (if applicable) with overridden values. ''' - def __init__(self, id: str | Input, types: list[type[_ComfyType] | _ComfyType], display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None): + def __init__(self, id: str | Input, types: list[type[_ComfyType] | _ComfyType], display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None, raw_link: bool=None): # if id is an Input, then use that Input with overridden values self.input_override = None if isinstance(id, Input): @@ -800,7 +789,7 @@ class MultiType: # if is a widget input, make sure widget_type is set appropriately if isinstance(self.input_override, WidgetInput): self.input_override.widget_type = self.input_override.get_io_type() - super().__init__(id, display_name, optional, tooltip, lazy, extra_dict) + super().__init__(id, display_name, optional, tooltip, lazy, extra_dict, raw_link) self._io_types = types @property @@ -854,8 +843,8 @@ class MatchType(ComfyTypeIO): class Input(Input): def __init__(self, id: str, template: MatchType.Template, - display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None): - super().__init__(id, display_name, optional, tooltip, lazy, extra_dict) + display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None, raw_link: bool=None): + super().__init__(id, display_name, optional, tooltip, lazy, extra_dict, raw_link) self.template = template def as_dict(self): @@ -866,6 +855,8 @@ class MatchType(ComfyTypeIO): class Output(Output): def __init__(self, template: MatchType.Template, id: str=None, display_name: str=None, tooltip: str=None, is_output_list=False): + if not id and not display_name: + display_name = "MATCHTYPE" super().__init__(id, display_name, tooltip, is_output_list) self.template = template @@ -878,24 +869,30 @@ class DynamicInput(Input, ABC): ''' Abstract class for dynamic input registration. ''' - def get_dynamic(self) -> list[Input]: - return [] - - def expand_schema_for_dynamic(self, d: dict[str, Any], live_inputs: dict[str, Any], curr_prefix=''): - pass + pass class DynamicOutput(Output, ABC): ''' Abstract class for dynamic output registration. 
''' - def __init__(self, id: str=None, display_name: str=None, tooltip: str=None, - is_output_list=False): - super().__init__(id, display_name, tooltip, is_output_list) + pass - def get_dynamic(self) -> list[Output]: - return [] +def handle_prefix(prefix_list: list[str] | None, id: str | None = None) -> list[str]: + if prefix_list is None: + prefix_list = [] + if id is not None: + prefix_list = prefix_list + [id] + return prefix_list + +def finalize_prefix(prefix_list: list[str] | None, id: str | None = None) -> str: + assert not (prefix_list is None and id is None) + if prefix_list is None: + return id + elif id is not None: + prefix_list = prefix_list + [id] + return ".".join(prefix_list) @comfytype(io_type="COMFY_AUTOGROW_V3") class Autogrow(ComfyTypeI): @@ -932,14 +929,6 @@ class Autogrow(ComfyTypeI): def validate(self): self.input.validate() - def expand_schema_for_dynamic(self, d: dict[str, Any], live_inputs: dict[str, Any], curr_prefix=''): - real_inputs = [] - for name, input in self.cached_inputs.items(): - if name in live_inputs: - real_inputs.append(input) - add_to_input_dict_v1(d, real_inputs, live_inputs, curr_prefix) - add_dynamic_id_mapping(d, real_inputs, curr_prefix) - class TemplatePrefix(_AutogrowTemplate): def __init__(self, input: Input, prefix: str, min: int=1, max: int=10): super().__init__(input) @@ -984,22 +973,45 @@ class Autogrow(ComfyTypeI): "template": self.template.as_dict(), }) - def get_dynamic(self) -> list[Input]: - return self.template.get_all() - def get_all(self) -> list[Input]: return [self] + self.template.get_all() def validate(self): self.template.validate() - def expand_schema_for_dynamic(self, d: dict[str, Any], live_inputs: dict[str, Any], curr_prefix=''): - curr_prefix = f"{curr_prefix}{self.id}." - # need to remove self from expected inputs dictionary; replaced by template inputs in frontend - for inner_dict in d.values(): - if self.id in inner_dict: - del inner_dict[self.id] - self.template.expand_schema_for_dynamic(d, live_inputs, curr_prefix) + @staticmethod + def _expand_schema_for_dynamic(out_dict: dict[str, Any], live_inputs: dict[str, Any], value: tuple[str, dict[str, Any]], input_type: str, curr_prefix: list[str] | None): + # NOTE: purposely do not include self in out_dict; instead use only the template inputs + # need to figure out names based on template type + is_names = ("names" in value[1]["template"]) + is_prefix = ("prefix" in value[1]["template"]) + input = value[1]["template"]["input"] + if is_names: + min = value[1]["template"]["min"] + names = value[1]["template"]["names"] + max = len(names) + elif is_prefix: + prefix = value[1]["template"]["prefix"] + min = value[1]["template"]["min"] + max = value[1]["template"]["max"] + names = [f"{prefix}{i}" for i in range(max)] + # need to create a new input based on the contents of input + template_input = None + for _, dict_input in input.items(): + # for now, get just the first value from dict_input + template_input = list(dict_input.values())[0] + new_dict = {} + for i, name in enumerate(names): + expected_id = finalize_prefix(curr_prefix, name) + if expected_id in live_inputs: + # required + if i < min: + type_dict = new_dict.setdefault("required", {}) + # optional + else: + type_dict = new_dict.setdefault("optional", {}) + type_dict[name] = template_input + parse_class_inputs(out_dict, live_inputs, new_dict, curr_prefix) @comfytype(io_type="COMFY_DYNAMICCOMBO_V3") class DynamicCombo(ComfyTypeI): @@ -1022,23 +1034,6 @@ class DynamicCombo(ComfyTypeI): super().__init__(id, 
display_name, optional, tooltip, lazy, extra_dict) self.options = options - def expand_schema_for_dynamic(self, d: dict[str, Any], live_inputs: dict[str, Any], curr_prefix=''): - # check if dynamic input's id is in live_inputs - if self.id in live_inputs: - curr_prefix = f"{curr_prefix}{self.id}." - key = live_inputs[self.id] - selected_option = None - for option in self.options: - if option.key == key: - selected_option = option - break - if selected_option is not None: - add_to_input_dict_v1(d, selected_option.inputs, live_inputs, curr_prefix) - add_dynamic_id_mapping(d, selected_option.inputs, curr_prefix, self) - - def get_dynamic(self) -> list[Input]: - return [input for option in self.options for input in option.inputs] - def get_all(self) -> list[Input]: return [self] + [input for option in self.options for input in option.inputs] @@ -1053,6 +1048,24 @@ class DynamicCombo(ComfyTypeI): for input in option.inputs: input.validate() + @staticmethod + def _expand_schema_for_dynamic(out_dict: dict[str, Any], live_inputs: dict[str, Any], value: tuple[str, dict[str, Any]], input_type: str, curr_prefix: list[str] | None): + finalized_id = finalize_prefix(curr_prefix) + if finalized_id in live_inputs: + key = live_inputs[finalized_id] + selected_option = None + # get options from dict + options: list[dict[str, str | dict[str, Any]]] = value[1]["options"] + for option in options: + if option["key"] == key: + selected_option = option + break + if selected_option is not None: + parse_class_inputs(out_dict, live_inputs, selected_option["inputs"], curr_prefix) + # add self to inputs + out_dict[input_type][finalized_id] = value + out_dict["dynamic_paths"][finalized_id] = finalize_prefix(curr_prefix, curr_prefix[-1]) + @comfytype(io_type="COMFY_DYNAMICSLOT_V3") class DynamicSlot(ComfyTypeI): Type = dict[str, Any] @@ -1075,17 +1088,8 @@ class DynamicSlot(ComfyTypeI): self.force_input = True self.slot.force_input = True - def expand_schema_for_dynamic(self, d: dict[str, Any], live_inputs: dict[str, Any], curr_prefix=''): - if self.id in live_inputs: - curr_prefix = f"{curr_prefix}{self.id}." 
- add_to_input_dict_v1(d, self.inputs, live_inputs, curr_prefix) - add_dynamic_id_mapping(d, [self.slot] + self.inputs, curr_prefix) - - def get_dynamic(self) -> list[Input]: - return [self.slot] + self.inputs - def get_all(self) -> list[Input]: - return [self] + [self.slot] + self.inputs + return [self.slot] + self.inputs def as_dict(self): return super().as_dict() | prune_dict({ @@ -1099,17 +1103,41 @@ class DynamicSlot(ComfyTypeI): for input in self.inputs: input.validate() -def add_dynamic_id_mapping(d: dict[str, Any], inputs: list[Input], curr_prefix: str, self: DynamicInput=None): - dynamic = d.setdefault("dynamic_paths", {}) - if self is not None: - dynamic[self.id] = f"{curr_prefix}{self.id}" - for i in inputs: - if not isinstance(i, DynamicInput): - dynamic[f"{i.id}"] = f"{curr_prefix}{i.id}" + @staticmethod + def _expand_schema_for_dynamic(out_dict: dict[str, Any], live_inputs: dict[str, Any], value: tuple[str, dict[str, Any]], input_type: str, curr_prefix: list[str] | None): + finalized_id = finalize_prefix(curr_prefix) + if finalized_id in live_inputs: + inputs = value[1]["inputs"] + parse_class_inputs(out_dict, live_inputs, inputs, curr_prefix) + # add self to inputs + out_dict[input_type][finalized_id] = value + out_dict["dynamic_paths"][finalized_id] = finalize_prefix(curr_prefix, curr_prefix[-1]) + +DYNAMIC_INPUT_LOOKUP: dict[str, Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]] = {} +def register_dynamic_input_func(io_type: str, func: Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]): + DYNAMIC_INPUT_LOOKUP[io_type] = func + +def get_dynamic_input_func(io_type: str) -> Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]: + return DYNAMIC_INPUT_LOOKUP[io_type] + +def setup_dynamic_input_funcs(): + # Autogrow.Input + register_dynamic_input_func(Autogrow.io_type, Autogrow._expand_schema_for_dynamic) + # DynamicCombo.Input + register_dynamic_input_func(DynamicCombo.io_type, DynamicCombo._expand_schema_for_dynamic) + # DynamicSlot.Input + register_dynamic_input_func(DynamicSlot.io_type, DynamicSlot._expand_schema_for_dynamic) + +if len(DYNAMIC_INPUT_LOOKUP) == 0: + setup_dynamic_input_funcs() class V3Data(TypedDict): hidden_inputs: dict[str, Any] + 'Dictionary where the keys are the hidden input ids and the values are the values of the hidden inputs.' dynamic_paths: dict[str, Any] + 'Dictionary where the keys are the input ids and the values dictate how to turn the inputs into a nested dictionary.' + create_dynamic_tuple: bool + 'When True, the value of the dynamic input will be in the format (value, path_key).' class HiddenHolder: def __init__(self, unique_id: str, prompt: Any, @@ -1145,6 +1173,10 @@ class HiddenHolder: api_key_comfy_org=d.get(Hidden.api_key_comfy_org, None), ) + @classmethod + def from_v3_data(cls, v3_data: V3Data | None) -> HiddenHolder: + return cls.from_dict(v3_data["hidden_inputs"] if v3_data else None) + class Hidden(str, Enum): ''' Enumerator for requesting hidden variables in nodes. 
@@ -1250,61 +1282,56 @@ class Schema: - verify ids on inputs and outputs are unique - both internally and in relation to each other ''' nested_inputs: list[Input] = [] - if self.inputs is not None: - for input in self.inputs: + for input in self.inputs: + if not isinstance(input, DynamicInput): nested_inputs.extend(input.get_all()) - input_ids = [i.id for i in nested_inputs] if nested_inputs is not None else [] - output_ids = [o.id for o in self.outputs] if self.outputs is not None else [] + input_ids = [i.id for i in nested_inputs] + output_ids = [o.id for o in self.outputs] input_set = set(input_ids) output_set = set(output_ids) - issues = [] + issues: list[str] = [] # verify ids are unique per list if len(input_set) != len(input_ids): issues.append(f"Input ids must be unique, but {[item for item, count in Counter(input_ids).items() if count > 1]} are not.") if len(output_set) != len(output_ids): issues.append(f"Output ids must be unique, but {[item for item, count in Counter(output_ids).items() if count > 1]} are not.") - # verify ids are unique between lists - intersection = input_set & output_set - if len(intersection) > 0: - issues.append(f"Ids must be unique between inputs and outputs, but {intersection} are not.") if len(issues) > 0: raise ValueError("\n".join(issues)) # validate inputs and outputs - if self.inputs is not None: - for input in self.inputs: - input.validate() - if self.outputs is not None: - for output in self.outputs: - output.validate() + for input in self.inputs: + input.validate() + for output in self.outputs: + output.validate() def finalize(self): """Add hidden based on selected schema options, and give outputs without ids default ids.""" + # ensure inputs, outputs, and hidden are lists + if self.inputs is None: + self.inputs = [] + if self.outputs is None: + self.outputs = [] + if self.hidden is None: + self.hidden = [] # if is an api_node, will need key-related hidden if self.is_api_node: - if self.hidden is None: - self.hidden = [] if Hidden.auth_token_comfy_org not in self.hidden: self.hidden.append(Hidden.auth_token_comfy_org) if Hidden.api_key_comfy_org not in self.hidden: self.hidden.append(Hidden.api_key_comfy_org) # if is an output_node, will need prompt and extra_pnginfo if self.is_output_node: - if self.hidden is None: - self.hidden = [] if Hidden.prompt not in self.hidden: self.hidden.append(Hidden.prompt) if Hidden.extra_pnginfo not in self.hidden: self.hidden.append(Hidden.extra_pnginfo) # give outputs without ids default ids - if self.outputs is not None: - for i, output in enumerate(self.outputs): - if output.id is None: - output.id = f"_{i}_{output.io_type}_" + for i, output in enumerate(self.outputs): + if output.id is None: + output.id = f"_{i}_{output.io_type}_" - def get_v1_info(self, cls, live_inputs: dict[str, Any]=None) -> NodeInfoV1: - # NOTE: live_inputs will not be used anymore very soon and this will be done another way + def get_v1_info(self, cls) -> NodeInfoV1: # get V1 inputs - input = create_input_dict_v1(self.inputs, live_inputs) + input = create_input_dict_v1(self.inputs) if self.hidden: for hidden in self.hidden: input.setdefault("hidden", {})[hidden.name] = (hidden.value,) @@ -1384,33 +1411,54 @@ class Schema: ) return info +def get_finalized_class_inputs(d: dict[str, Any], live_inputs: dict[str, Any], include_hidden=False) -> tuple[dict[str, Any], V3Data]: + out_dict = { + "required": {}, + "optional": {}, + "dynamic_paths": {}, + } + d = d.copy() + # ignore hidden for parsing + hidden = d.pop("hidden", None) + 
parse_class_inputs(out_dict, live_inputs, d) + if hidden is not None and include_hidden: + out_dict["hidden"] = hidden + v3_data = {} + dynamic_paths = out_dict.pop("dynamic_paths", None) + if dynamic_paths is not None: + v3_data["dynamic_paths"] = dynamic_paths + return out_dict, hidden, v3_data -def create_input_dict_v1(inputs: list[Input], live_inputs: dict[str, Any]=None) -> dict: +def parse_class_inputs(out_dict: dict[str, Any], live_inputs: dict[str, Any], curr_dict: dict[str, Any], curr_prefix: list[str] | None=None) -> None: + for input_type, inner_d in curr_dict.items(): + for id, value in inner_d.items(): + io_type = value[0] + if io_type in DYNAMIC_INPUT_LOOKUP: + # dynamic inputs need to be handled with lookup functions + dynamic_input_func = get_dynamic_input_func(io_type) + new_prefix = handle_prefix(curr_prefix, id) + dynamic_input_func(out_dict, live_inputs, value, input_type, new_prefix) + else: + # non-dynamic inputs get directly transferred + finalized_id = finalize_prefix(curr_prefix, id) + out_dict[input_type][finalized_id] = value + if curr_prefix: + out_dict["dynamic_paths"][finalized_id] = finalized_id + +def create_input_dict_v1(inputs: list[Input]) -> dict: input = { "required": {} } - add_to_input_dict_v1(input, inputs, live_inputs) + for i in inputs: + add_to_dict_v1(i, input) return input -def add_to_input_dict_v1(d: dict[str, Any], inputs: list[Input], live_inputs: dict[str, Any]=None, curr_prefix=''): - for i in inputs: - if isinstance(i, DynamicInput): - add_to_dict_v1(i, d) - if live_inputs is not None: - i.expand_schema_for_dynamic(d, live_inputs, curr_prefix) - else: - add_to_dict_v1(i, d) - -def add_to_dict_v1(i: Input, d: dict, dynamic_dict: dict=None): +def add_to_dict_v1(i: Input, d: dict): key = "optional" if i.optional else "required" as_dict = i.as_dict() # for v1, we don't want to include the optional key as_dict.pop("optional", None) - if dynamic_dict is None: - value = (i.get_io_type(), as_dict) - else: - value = (i.get_io_type(), as_dict, dynamic_dict) - d.setdefault(key, {})[i.id] = value + d.setdefault(key, {})[i.id] = (i.get_io_type(), as_dict) def add_to_dict_v3(io: Input | Output, d: dict): d[io.id] = (io.get_io_type(), io.as_dict()) @@ -1422,6 +1470,8 @@ def build_nested_inputs(values: dict[str, Any], v3_data: V3Data): values = values.copy() result = {} + create_tuple = v3_data.get("create_dynamic_tuple", False) + for key, path in paths.items(): parts = path.split(".") current = result @@ -1430,7 +1480,10 @@ def build_nested_inputs(values: dict[str, Any], v3_data: V3Data): is_last = (i == len(parts) - 1) if is_last: - current[p] = values.pop(key, None) + value = values.pop(key, None) + if create_tuple: + value = (value, key) + current[p] = value else: current = current.setdefault(p, {}) @@ -1445,7 +1498,6 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal): SCHEMA = None # filled in during execution - resources: Resources = None hidden: HiddenHolder = None @classmethod @@ -1492,7 +1544,6 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal): return [name for name in kwargs if kwargs[name] is None] def __init__(self): - self.local_resources: ResourcesLocal = None self.__class__.VALIDATE_CLASS() @classmethod @@ -1560,7 +1611,7 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal): c_type: type[ComfyNode] = cls if is_class(cls) else type(cls) type_clone: type[ComfyNode] = shallow_clone_class(c_type) # set hidden - type_clone.hidden = HiddenHolder.from_dict(v3_data["hidden_inputs"] if v3_data else None) + type_clone.hidden = 
HiddenHolder.from_v3_data(v3_data) return type_clone @final @@ -1677,19 +1728,10 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal): @final @classmethod - def INPUT_TYPES(cls, include_hidden=True, return_schema=False, live_inputs=None) -> dict[str, dict] | tuple[dict[str, dict], Schema, V3Data]: + def INPUT_TYPES(cls) -> dict[str, dict]: schema = cls.FINALIZE_SCHEMA() - info = schema.get_v1_info(cls, live_inputs) - input = info.input - if not include_hidden: - input.pop("hidden", None) - if return_schema: - v3_data: V3Data = {} - dynamic = input.pop("dynamic_paths", None) - if dynamic is not None: - v3_data["dynamic_paths"] = dynamic - return input, schema, v3_data - return input + info = schema.get_v1_info(cls) + return info.input @final @classmethod @@ -1808,7 +1850,7 @@ class NodeOutput(_NodeOutputInternal): return self.args if len(self.args) > 0 else None @classmethod - def from_dict(cls, data: dict[str, Any]) -> "NodeOutput": + def from_dict(cls, data: dict[str, Any]) -> NodeOutput: args = () ui = None expand = None @@ -1903,8 +1945,8 @@ __all__ = [ "Tracks", # Dynamic Types "MatchType", - # "DynamicCombo", - # "Autogrow", + "DynamicCombo", + "Autogrow", # Other classes "HiddenHolder", "Hidden", diff --git a/comfy_api/latest/_resources.py b/comfy_api/latest/_resources.py deleted file mode 100644 index a6bdda972..000000000 --- a/comfy_api/latest/_resources.py +++ /dev/null @@ -1,72 +0,0 @@ -from __future__ import annotations -import comfy.utils -import folder_paths -import logging -from abc import ABC, abstractmethod -from typing import Any -import torch - -class ResourceKey(ABC): - Type = Any - def __init__(self): - ... - -class TorchDictFolderFilename(ResourceKey): - '''Key for requesting a torch file via file_name from a folder category.''' - Type = dict[str, torch.Tensor] - def __init__(self, folder_name: str, file_name: str): - self.folder_name = folder_name - self.file_name = file_name - - def __hash__(self): - return hash((self.folder_name, self.file_name)) - - def __eq__(self, other: object) -> bool: - if not isinstance(other, TorchDictFolderFilename): - return False - return self.folder_name == other.folder_name and self.file_name == other.file_name - - def __str__(self): - return f"{self.folder_name} -> {self.file_name}" - -class Resources(ABC): - def __init__(self): - ... - - @abstractmethod - def get(self, key: ResourceKey, default: Any=...) -> Any: - pass - -class ResourcesLocal(Resources): - def __init__(self): - super().__init__() - self.local_resources: dict[ResourceKey, Any] = {} - - def get(self, key: ResourceKey, default: Any=...) 
-> Any: - cached = self.local_resources.get(key, None) - if cached is not None: - logging.info(f"Using cached resource '{key}'") - return cached - logging.info(f"Loading resource '{key}'") - to_return = None - if isinstance(key, TorchDictFolderFilename): - if default is ...: - to_return = comfy.utils.load_torch_file(folder_paths.get_full_path_or_raise(key.folder_name, key.file_name), safe_load=True) - else: - full_path = folder_paths.get_full_path(key.folder_name, key.file_name) - if full_path is not None: - to_return = comfy.utils.load_torch_file(full_path, safe_load=True) - - if to_return is not None: - self.local_resources[key] = to_return - return to_return - if default is not ...: - return default - raise Exception(f"Unsupported resource key type: {type(key)}") - - -class _RESOURCES: - ResourceKey = ResourceKey - TorchDictFolderFilename = TorchDictFolderFilename - Resources = Resources - ResourcesLocal = ResourcesLocal diff --git a/comfy_execution/graph.py b/comfy_execution/graph.py index 0d811e354..8fc5846b7 100644 --- a/comfy_execution/graph.py +++ b/comfy_execution/graph.py @@ -97,6 +97,11 @@ def get_input_info( extra_info = input_info[1] else: extra_info = {} + # if input_type is a list, it is a Combo defined in outdated format; convert it. + # NOTE: uncomment this when we are confident old format going away won't cause too much trouble. + # if isinstance(input_type, list): + # extra_info["options"] = input_type + # input_type = IO.Combo.io_type return input_type, input_category, extra_info class TopologicalSort: diff --git a/comfy_execution/validation.py b/comfy_execution/validation.py index 24c0b4ed7..e73624bd1 100644 --- a/comfy_execution/validation.py +++ b/comfy_execution/validation.py @@ -21,14 +21,24 @@ def validate_node_input( """ # If the types are exactly the same, we can return immediately # Use pre-union behaviour: inverse of `__ne__` + # NOTE: this lets legacy '*' Any types work that override the __ne__ method of the str class. if not received_type != input_type: return True + # If one of the types is '*', we can return True immediately; this is the 'Any' type. + if received_type == IO.AnyType.io_type or input_type == IO.AnyType.io_type: + return True + # If the received type or input_type is a MatchType, we can return True immediately; # validation for this is handled by the frontend if received_type == IO.MatchType.io_type or input_type == IO.MatchType.io_type: return True + # This accounts for some custom nodes that output lists of options as the type; + # if we ever want to break them on purpose, this can be removed + if isinstance(received_type, list) and input_type == IO.Combo.io_type: + return True + # Not equal, and not strings if not isinstance(received_type, str) or not isinstance(input_type, str): return False @@ -37,6 +47,10 @@ def validate_node_input( received_types = set(t.strip() for t in received_type.split(",")) input_types = set(t.strip() for t in input_type.split(",")) + # If any of the types is '*', we can return True immediately; this is the 'Any' type. 
+ if IO.AnyType.io_type in received_types or IO.AnyType.io_type in input_types: + return True + if strict: # In strict mode, all received types must be in the input types return received_types.issubset(input_types) diff --git a/comfy_extras/nodes_latent.py b/comfy_extras/nodes_latent.py index 2815c5ffc..9ba1c4ba8 100644 --- a/comfy_extras/nodes_latent.py +++ b/comfy_extras/nodes_latent.py @@ -255,6 +255,7 @@ class LatentBatch(io.ComfyNode): return io.Schema( node_id="LatentBatch", category="latent/batch", + is_deprecated=True, inputs=[ io.Latent.Input("samples1"), io.Latent.Input("samples2"), diff --git a/comfy_extras/nodes_logic.py b/comfy_extras/nodes_logic.py index 95a6ba788..eb888316a 100644 --- a/comfy_extras/nodes_logic.py +++ b/comfy_extras/nodes_logic.py @@ -1,8 +1,11 @@ +from __future__ import annotations from typing import TypedDict from typing_extensions import override from comfy_api.latest import ComfyExtension, io from comfy_api.latest import _io +# sentinel for missing inputs +MISSING = object() class SwitchNode(io.ComfyNode): @@ -14,6 +17,37 @@ class SwitchNode(io.ComfyNode): display_name="Switch", category="logic", is_experimental=True, + inputs=[ + io.Boolean.Input("switch"), + io.MatchType.Input("on_false", template=template, lazy=True), + io.MatchType.Input("on_true", template=template, lazy=True), + ], + outputs=[ + io.MatchType.Output(template=template, display_name="output"), + ], + ) + + @classmethod + def check_lazy_status(cls, switch, on_false=None, on_true=None): + if switch and on_true is None: + return ["on_true"] + if not switch and on_false is None: + return ["on_false"] + + @classmethod + def execute(cls, switch, on_true, on_false) -> io.NodeOutput: + return io.NodeOutput(on_true if switch else on_false) + + +class SoftSwitchNode(io.ComfyNode): + @classmethod + def define_schema(cls): + template = io.MatchType.Template("switch") + return io.Schema( + node_id="ComfySoftSwitchNode", + display_name="Soft Switch", + category="logic", + is_experimental=True, inputs=[ io.Boolean.Input("switch"), io.MatchType.Input("on_false", template=template, lazy=True, optional=True), @@ -25,14 +59,14 @@ class SwitchNode(io.ComfyNode): ) @classmethod - def check_lazy_status(cls, switch, on_false=..., on_true=...): - # We use ... instead of None, as None is passed for connected-but-unevaluated inputs. + def check_lazy_status(cls, switch, on_false=MISSING, on_true=MISSING): + # We use MISSING instead of None, as None is passed for connected-but-unevaluated inputs. # This trick allows us to ignore the value of the switch and still be able to run execute(). # One of the inputs may be missing, in which case we need to evaluate the other input - if on_false is ...: + if on_false is MISSING: return ["on_true"] - if on_true is ...: + if on_true is MISSING: return ["on_false"] # Normal lazy switch operation if switch and on_true is None: @@ -41,22 +75,50 @@ class SwitchNode(io.ComfyNode): return ["on_false"] @classmethod - def validate_inputs(cls, switch, on_false=..., on_true=...): + def validate_inputs(cls, switch, on_false=MISSING, on_true=MISSING): # This check happens before check_lazy_status(), so we can eliminate the case where # both inputs are missing. - if on_false is ... and on_true is ...: + if on_false is MISSING and on_true is MISSING: return "At least one of on_false or on_true must be connected to Switch node" return True @classmethod - def execute(cls, switch, on_true=..., on_false=...) 
-> io.NodeOutput: - if on_true is ...: + def execute(cls, switch, on_true=MISSING, on_false=MISSING) -> io.NodeOutput: + if on_true is MISSING: return io.NodeOutput(on_false) - if on_false is ...: + if on_false is MISSING: return io.NodeOutput(on_true) return io.NodeOutput(on_true if switch else on_false) +class CustomComboNode(io.ComfyNode): + """ + Frontend node that allows user to write their own options for a combo. + This is here to make sure the node has a backend-representation to avoid some annoyances. + """ + @classmethod + def define_schema(cls): + return io.Schema( + node_id="CustomCombo", + display_name="Custom Combo", + category="utils", + is_experimental=True, + inputs=[io.Combo.Input("choice", options=[])], + outputs=[io.String.Output()] + ) + + @classmethod + def validate_inputs(cls, choice: io.Combo.Type) -> bool: + # NOTE: DO NOT DO THIS unless you want to skip validation entirely on the node's inputs. + # I am doing that here because the widgets (besides the combo dropdown) on this node are fully frontend defined. + # I need to skip checking that the chosen combo option is in the options list, since those are defined by the user. + return True + + @classmethod + def execute(cls, choice: io.Combo.Type) -> io.NodeOutput: + return io.NodeOutput(choice) + + class DCTestNode(io.ComfyNode): class DCValues(TypedDict): combo: str @@ -72,14 +134,14 @@ class DCTestNode(io.ComfyNode): display_name="DCTest", category="logic", is_output_node=True, - inputs=[_io.DynamicCombo.Input("combo", options=[ - _io.DynamicCombo.Option("option1", [io.String.Input("string")]), - _io.DynamicCombo.Option("option2", [io.Int.Input("integer")]), - _io.DynamicCombo.Option("option3", [io.Image.Input("image")]), - _io.DynamicCombo.Option("option4", [ - _io.DynamicCombo.Input("subcombo", options=[ - _io.DynamicCombo.Option("opt1", [io.Float.Input("float_x"), io.Float.Input("float_y")]), - _io.DynamicCombo.Option("opt2", [io.Mask.Input("mask1", optional=True)]), + inputs=[io.DynamicCombo.Input("combo", options=[ + io.DynamicCombo.Option("option1", [io.String.Input("string")]), + io.DynamicCombo.Option("option2", [io.Int.Input("integer")]), + io.DynamicCombo.Option("option3", [io.Image.Input("image")]), + io.DynamicCombo.Option("option4", [ + io.DynamicCombo.Input("subcombo", options=[ + io.DynamicCombo.Option("opt1", [io.Float.Input("float_x"), io.Float.Input("float_y")]), + io.DynamicCombo.Option("opt2", [io.Mask.Input("mask1", optional=True)]), ]) ])] )], @@ -141,14 +203,65 @@ class AutogrowPrefixTestNode(io.ComfyNode): combined = ",".join([str(x) for x in vals]) return io.NodeOutput(combined) +class ComboOutputTestNode(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="ComboOptionTestNode", + display_name="ComboOptionTest", + category="logic", + inputs=[io.Combo.Input("combo", options=["option1", "option2", "option3"]), + io.Combo.Input("combo2", options=["option4", "option5", "option6"])], + outputs=[io.Combo.Output(), io.Combo.Output()], + ) + + @classmethod + def execute(cls, combo: io.Combo.Type, combo2: io.Combo.Type) -> io.NodeOutput: + return io.NodeOutput(combo, combo2) + +class ConvertStringToComboNode(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="ConvertStringToComboNode", + display_name="Convert String to Combo", + category="logic", + inputs=[io.String.Input("string")], + outputs=[io.Combo.Output()], + ) + + @classmethod + def execute(cls, string: str) -> io.NodeOutput: + return io.NodeOutput(string) + +class 
InvertBooleanNode(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="InvertBooleanNode", + display_name="Invert Boolean", + category="logic", + inputs=[io.Boolean.Input("boolean")], + outputs=[io.Boolean.Output()], + ) + + @classmethod + def execute(cls, boolean: bool) -> io.NodeOutput: + return io.NodeOutput(not boolean) + class LogicExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[io.ComfyNode]]: return [ - # SwitchNode, + SwitchNode, + CustomComboNode, + # SoftSwitchNode, + # ConvertStringToComboNode, # DCTestNode, # AutogrowNamesTestNode, # AutogrowPrefixTestNode, + # ComboOutputTestNode, + # InvertBooleanNode, ] async def comfy_entrypoint() -> LogicExtension: diff --git a/comfy_extras/nodes_post_processing.py b/comfy_extras/nodes_post_processing.py index ca2cdeb50..01afa13a1 100644 --- a/comfy_extras/nodes_post_processing.py +++ b/comfy_extras/nodes_post_processing.py @@ -4,11 +4,15 @@ import torch import torch.nn.functional as F from PIL import Image import math +from enum import Enum +from typing import TypedDict, Literal import comfy.utils import comfy.model_management +from comfy_extras.nodes_latent import reshape_latent_to import node_helpers from comfy_api.latest import ComfyExtension, io +from nodes import MAX_RESOLUTION class Blend(io.ComfyNode): @classmethod @@ -241,6 +245,353 @@ class ImageScaleToTotalPixels(io.ComfyNode): s = s.movedim(1,-1) return io.NodeOutput(s) +class ResizeType(str, Enum): + SCALE_BY = "scale by multiplier" + SCALE_DIMENSIONS = "scale dimensions" + SCALE_LONGER_DIMENSION = "scale longer dimension" + SCALE_SHORTER_DIMENSION = "scale shorter dimension" + SCALE_WIDTH = "scale width" + SCALE_HEIGHT = "scale height" + SCALE_TOTAL_PIXELS = "scale total pixels" + MATCH_SIZE = "match size" + +def is_image(input: torch.Tensor) -> bool: + # images have 4 dimensions: [batch, height, width, channels] + # masks have 3 dimensions: [batch, height, width] + return len(input.shape) == 4 + +def init_image_mask_input(input: torch.Tensor, is_type_image: bool) -> torch.Tensor: + if is_type_image: + input = input.movedim(-1, 1) + else: + input = input.unsqueeze(1) + return input + +def finalize_image_mask_input(input: torch.Tensor, is_type_image: bool) -> torch.Tensor: + if is_type_image: + input = input.movedim(1, -1) + else: + input = input.squeeze(1) + return input + +def scale_by(input: torch.Tensor, multiplier: float, scale_method: str) -> torch.Tensor: + is_type_image = is_image(input) + input = init_image_mask_input(input, is_type_image) + width = round(input.shape[-1] * multiplier) + height = round(input.shape[-2] * multiplier) + + input = comfy.utils.common_upscale(input, width, height, scale_method, "disabled") + input = finalize_image_mask_input(input, is_type_image) + return input + +def scale_dimensions(input: torch.Tensor, width: int, height: int, scale_method: str, crop: str="disabled") -> torch.Tensor: + if width == 0 and height == 0: + return input + is_type_image = is_image(input) + input = init_image_mask_input(input, is_type_image) + + if width == 0: + width = max(1, round(input.shape[-1] * height / input.shape[-2])) + elif height == 0: + height = max(1, round(input.shape[-2] * width / input.shape[-1])) + + input = comfy.utils.common_upscale(input, width, height, scale_method, crop) + input = finalize_image_mask_input(input, is_type_image) + return input + +def scale_longer_dimension(input: torch.Tensor, longer_size: int, scale_method: str) -> torch.Tensor: + is_type_image = 
is_image(input) + input = init_image_mask_input(input, is_type_image) + width = input.shape[-1] + height = input.shape[-2] + + if height > width: + width = round((width / height) * longer_size) + height = longer_size + elif width > height: + height = round((height / width) * longer_size) + width = longer_size + else: + height = longer_size + width = longer_size + + input = comfy.utils.common_upscale(input, width, height, scale_method, "disabled") + input = finalize_image_mask_input(input, is_type_image) + return input + +def scale_shorter_dimension(input: torch.Tensor, shorter_size: int, scale_method: str) -> torch.Tensor: + is_type_image = is_image(input) + input = init_image_mask_input(input, is_type_image) + width = input.shape[-1] + height = input.shape[-2] + + if height < width: + width = round((width / height) * shorter_size) + height = shorter_size + elif width > height: + height = round((height / width) * shorter_size) + width = shorter_size + else: + height = shorter_size + width = shorter_size + + input = comfy.utils.common_upscale(input, width, height, scale_method, "disabled") + input = finalize_image_mask_input(input, is_type_image) + return input + +def scale_total_pixels(input: torch.Tensor, megapixels: float, scale_method: str) -> torch.Tensor: + is_type_image = is_image(input) + input = init_image_mask_input(input, is_type_image) + total = int(megapixels * 1024 * 1024) + + scale_by = math.sqrt(total / (input.shape[-1] * input.shape[-2])) + width = round(input.shape[-1] * scale_by) + height = round(input.shape[-2] * scale_by) + + input = comfy.utils.common_upscale(input, width, height, scale_method, "disabled") + input = finalize_image_mask_input(input, is_type_image) + return input + +def scale_match_size(input: torch.Tensor, match: torch.Tensor, scale_method: str, crop: str) -> torch.Tensor: + is_type_image = is_image(input) + input = init_image_mask_input(input, is_type_image) + match = init_image_mask_input(match, is_image(match)) + + width = match.shape[-1] + height = match.shape[-2] + input = comfy.utils.common_upscale(input, width, height, scale_method, crop) + input = finalize_image_mask_input(input, is_type_image) + return input + +class ResizeImageMaskNode(io.ComfyNode): + + scale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "lanczos"] + crop_methods = ["disabled", "center"] + + class ResizeTypedDict(TypedDict): + resize_type: ResizeType + scale_method: Literal["nearest-exact", "bilinear", "area", "bicubic", "lanczos"] + crop: Literal["disabled", "center"] + multiplier: float + width: int + height: int + longer_size: int + shorter_size: int + megapixels: float + + @classmethod + def define_schema(cls): + template = io.MatchType.Template("input_type", [io.Image, io.Mask]) + crop_combo = io.Combo.Input("crop", options=cls.crop_methods, default="center") + return io.Schema( + node_id="ResizeImageMaskNode", + display_name="Resize Image/Mask", + category="transform", + inputs=[ + io.MatchType.Input("input", template=template), + io.DynamicCombo.Input("resize_type", options=[ + io.DynamicCombo.Option(ResizeType.SCALE_BY, [ + io.Float.Input("multiplier", default=1.00, min=0.01, max=8.0, step=0.01), + ]), + io.DynamicCombo.Option(ResizeType.SCALE_DIMENSIONS, [ + io.Int.Input("width", default=512, min=0, max=MAX_RESOLUTION, step=1), + io.Int.Input("height", default=512, min=0, max=MAX_RESOLUTION, step=1), + crop_combo, + ]), + io.DynamicCombo.Option(ResizeType.SCALE_LONGER_DIMENSION, [ + io.Int.Input("longer_size", default=512, min=0, max=MAX_RESOLUTION, 
step=1), + ]), + io.DynamicCombo.Option(ResizeType.SCALE_SHORTER_DIMENSION, [ + io.Int.Input("shorter_size", default=512, min=0, max=MAX_RESOLUTION, step=1), + ]), + io.DynamicCombo.Option(ResizeType.SCALE_WIDTH, [ + io.Int.Input("width", default=512, min=0, max=MAX_RESOLUTION, step=1), + ]), + io.DynamicCombo.Option(ResizeType.SCALE_HEIGHT, [ + io.Int.Input("height", default=512, min=0, max=MAX_RESOLUTION, step=1), + ]), + io.DynamicCombo.Option(ResizeType.SCALE_TOTAL_PIXELS, [ + io.Float.Input("megapixels", default=1.0, min=0.01, max=16.0, step=0.01), + ]), + io.DynamicCombo.Option(ResizeType.MATCH_SIZE, [ + io.MultiType.Input("match", [io.Image, io.Mask]), + crop_combo, + ]), + ]), + io.Combo.Input("scale_method", options=cls.scale_methods, default="area"), + ], + outputs=[io.MatchType.Output(template=template, display_name="resized")] + ) + + @classmethod + def execute(cls, input: io.Image.Type | io.Mask.Type, scale_method: io.Combo.Type, resize_type: ResizeTypedDict) -> io.NodeOutput: + selected_type = resize_type["resize_type"] + if selected_type == ResizeType.SCALE_BY: + return io.NodeOutput(scale_by(input, resize_type["multiplier"], scale_method)) + elif selected_type == ResizeType.SCALE_DIMENSIONS: + return io.NodeOutput(scale_dimensions(input, resize_type["width"], resize_type["height"], scale_method, resize_type["crop"])) + elif selected_type == ResizeType.SCALE_LONGER_DIMENSION: + return io.NodeOutput(scale_longer_dimension(input, resize_type["longer_size"], scale_method)) + elif selected_type == ResizeType.SCALE_SHORTER_DIMENSION: + return io.NodeOutput(scale_shorter_dimension(input, resize_type["shorter_size"], scale_method)) + elif selected_type == ResizeType.SCALE_WIDTH: + return io.NodeOutput(scale_dimensions(input, resize_type["width"], 0, scale_method)) + elif selected_type == ResizeType.SCALE_HEIGHT: + return io.NodeOutput(scale_dimensions(input, 0, resize_type["height"], scale_method)) + elif selected_type == ResizeType.SCALE_TOTAL_PIXELS: + return io.NodeOutput(scale_total_pixels(input, resize_type["megapixels"], scale_method)) + elif selected_type == ResizeType.MATCH_SIZE: + return io.NodeOutput(scale_match_size(input, resize_type["match"], scale_method, resize_type["crop"])) + raise ValueError(f"Unsupported resize type: {selected_type}") + +def batch_images(images: list[torch.Tensor]) -> torch.Tensor | None: + if len(images) == 0: + return None + # first, get the max channels count + max_channels = max(image.shape[-1] for image in images) + # then, pad all images to have the same channels count + padded_images: list[torch.Tensor] = [] + for image in images: + if image.shape[-1] < max_channels: + padded_images.append(torch.nn.functional.pad(image, (0,1), mode='constant', value=1.0)) + else: + padded_images.append(image) + # resize all images to be the same size as the first image + resized_images: list[torch.Tensor] = [] + first_image_shape = padded_images[0].shape + for image in padded_images: + if image.shape[1:] != first_image_shape[1:]: + resized_images.append(comfy.utils.common_upscale(image.movedim(-1,1), first_image_shape[2], first_image_shape[1], "bilinear", "center").movedim(1,-1)) + else: + resized_images.append(image) + # batch the images in the format [b, h, w, c] + return torch.cat(resized_images, dim=0) + +def batch_masks(masks: list[torch.Tensor]) -> torch.Tensor | None: + if len(masks) == 0: + return None + # resize all masks to be the same size as the first mask + resized_masks: list[torch.Tensor] = [] + first_mask_shape = masks[0].shape + for mask 
in masks: + if mask.shape[1:] != first_mask_shape[1:]: + mask = init_image_mask_input(mask, is_type_image=False) + mask = comfy.utils.common_upscale(mask, first_mask_shape[2], first_mask_shape[1], "bilinear", "center") + resized_masks.append(finalize_image_mask_input(mask, is_type_image=False)) + else: + resized_masks.append(mask) + # batch the masks in the format [b, h, w] + return torch.cat(resized_masks, dim=0) + +def batch_latents(latents: list[dict[str, torch.Tensor]]) -> dict[str, torch.Tensor] | None: + if len(latents) == 0: + return None + samples_out = latents[0].copy() + samples_out["batch_index"] = [] + first_samples = latents[0]["samples"] + tensors: list[torch.Tensor] = [] + for latent in latents: + # first, deal with latent tensors + tensors.append(reshape_latent_to(first_samples.shape, latent["samples"], repeat_batch=False)) + # next, deal with batch_index + samples_out["batch_index"].extend(latent.get("batch_index", [x for x in range(0, latent["samples"].shape[0])])) + samples_out["samples"] = torch.cat(tensors, dim=0) + return samples_out + +class BatchImagesNode(io.ComfyNode): + @classmethod + def define_schema(cls): + autogrow_template = io.Autogrow.TemplatePrefix(io.Image.Input("image"), prefix="image", min=2, max=50) + return io.Schema( + node_id="BatchImagesNode", + display_name="Batch Images", + category="image", + inputs=[ + io.Autogrow.Input("images", template=autogrow_template) + ], + outputs=[ + io.Image.Output() + ] + ) + + @classmethod + def execute(cls, images: io.Autogrow.Type) -> io.NodeOutput: + return io.NodeOutput(batch_images(list(images.values()))) + +class BatchMasksNode(io.ComfyNode): + @classmethod + def define_schema(cls): + autogrow_template = io.Autogrow.TemplatePrefix(io.Mask.Input("mask"), prefix="mask", min=2, max=50) + return io.Schema( + node_id="BatchMasksNode", + display_name="Batch Masks", + category="mask", + inputs=[ + io.Autogrow.Input("masks", template=autogrow_template) + ], + outputs=[ + io.Mask.Output() + ] + ) + + @classmethod + def execute(cls, masks: io.Autogrow.Type) -> io.NodeOutput: + return io.NodeOutput(batch_masks(list(masks.values()))) + +class BatchLatentsNode(io.ComfyNode): + @classmethod + def define_schema(cls): + autogrow_template = io.Autogrow.TemplatePrefix(io.Latent.Input("latent"), prefix="latent", min=2, max=50) + return io.Schema( + node_id="BatchLatentsNode", + display_name="Batch Latents", + category="latent", + inputs=[ + io.Autogrow.Input("latents", template=autogrow_template) + ], + outputs=[ + io.Latent.Output() + ] + ) + + @classmethod + def execute(cls, latents: io.Autogrow.Type) -> io.NodeOutput: + return io.NodeOutput(batch_latents(list(latents.values()))) + +class BatchImagesMasksLatentsNode(io.ComfyNode): + @classmethod + def define_schema(cls): + matchtype_template = io.MatchType.Template("input", allowed_types=[io.Image, io.Mask, io.Latent]) + autogrow_template = io.Autogrow.TemplatePrefix( + io.MatchType.Input("input", matchtype_template), + prefix="input", min=1, max=50) + return io.Schema( + node_id="BatchImagesMasksLatentsNode", + display_name="Batch Images/Masks/Latents", + category="util", + inputs=[ + io.Autogrow.Input("inputs", template=autogrow_template) + ], + outputs=[ + io.MatchType.Output(id=None, template=matchtype_template) + ] + ) + + @classmethod + def execute(cls, inputs: io.Autogrow.Type) -> io.NodeOutput: + batched = None + values = list(inputs.values()) + # latents + if isinstance(values[0], dict): + batched = batch_latents(values) + # images + elif is_image(values[0]): + 
batched = batch_images(values) + # masks + else: + batched = batch_masks(values) + return io.NodeOutput(batched) + class PostProcessingExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[io.ComfyNode]]: @@ -250,6 +601,11 @@ class PostProcessingExtension(ComfyExtension): Quantize, Sharpen, ImageScaleToTotalPixels, + ResizeImageMaskNode, + BatchImagesNode, + BatchMasksNode, + BatchLatentsNode, + # BatchImagesMasksLatentsNode, ] async def comfy_entrypoint() -> PostProcessingExtension: diff --git a/comfy_extras/nodes_primitive.py b/comfy_extras/nodes_primitive.py index 5a1aeba80..937321800 100644 --- a/comfy_extras/nodes_primitive.py +++ b/comfy_extras/nodes_primitive.py @@ -66,7 +66,7 @@ class Float(io.ComfyNode): display_name="Float", category="utils/primitive", inputs=[ - io.Float.Input("value", min=-sys.maxsize, max=sys.maxsize), + io.Float.Input("value", min=-sys.maxsize, max=sys.maxsize, step=0.1), ], outputs=[io.Float.Output()], ) diff --git a/execution.py b/execution.py index 0c239efd7..38159b1f4 100644 --- a/execution.py +++ b/execution.py @@ -79,7 +79,7 @@ class IsChangedCache: # Intentionally do not use cached outputs here. We only want constants in IS_CHANGED input_data_all, _, v3_data = get_input_data(node["inputs"], class_def, node_id, None) try: - is_changed = await _async_map_node_over_list(self.prompt_id, node_id, class_def, input_data_all, is_changed_name) + is_changed = await _async_map_node_over_list(self.prompt_id, node_id, class_def, input_data_all, is_changed_name, v3_data=v3_data) is_changed = await resolve_map_node_over_list_results(is_changed) node["is_changed"] = [None if isinstance(x, ExecutionBlocker) else x for x in is_changed] except Exception as e: @@ -148,13 +148,12 @@ SENSITIVE_EXTRA_DATA_KEYS = ("auth_token_comfy_org", "api_key_comfy_org") def get_input_data(inputs, class_def, unique_id, execution_list=None, dynprompt=None, extra_data={}): is_v3 = issubclass(class_def, _ComfyNodeInternal) v3_data: io.V3Data = {} + hidden_inputs_v3 = {} + valid_inputs = class_def.INPUT_TYPES() if is_v3: - valid_inputs, schema, v3_data = class_def.INPUT_TYPES(include_hidden=False, return_schema=True, live_inputs=inputs) - else: - valid_inputs = class_def.INPUT_TYPES() + valid_inputs, hidden, v3_data = _io.get_finalized_class_inputs(valid_inputs, inputs) input_data_all = {} missing_keys = {} - hidden_inputs_v3 = {} for x in inputs: input_data = inputs[x] _, input_category, input_info = get_input_info(class_def, x, valid_inputs) @@ -180,18 +179,18 @@ def get_input_data(inputs, class_def, unique_id, execution_list=None, dynprompt= input_data_all[x] = [input_data] if is_v3: - if schema.hidden: - if io.Hidden.prompt in schema.hidden: + if hidden is not None: + if io.Hidden.prompt.name in hidden: hidden_inputs_v3[io.Hidden.prompt] = dynprompt.get_original_prompt() if dynprompt is not None else {} - if io.Hidden.dynprompt in schema.hidden: + if io.Hidden.dynprompt.name in hidden: hidden_inputs_v3[io.Hidden.dynprompt] = dynprompt - if io.Hidden.extra_pnginfo in schema.hidden: + if io.Hidden.extra_pnginfo.name in hidden: hidden_inputs_v3[io.Hidden.extra_pnginfo] = extra_data.get('extra_pnginfo', None) - if io.Hidden.unique_id in schema.hidden: + if io.Hidden.unique_id.name in hidden: hidden_inputs_v3[io.Hidden.unique_id] = unique_id - if io.Hidden.auth_token_comfy_org in schema.hidden: + if io.Hidden.auth_token_comfy_org.name in hidden: hidden_inputs_v3[io.Hidden.auth_token_comfy_org] = extra_data.get("auth_token_comfy_org", None) - if 
io.Hidden.api_key_comfy_org in schema.hidden: + if io.Hidden.api_key_comfy_org.name in hidden: hidden_inputs_v3[io.Hidden.api_key_comfy_org] = extra_data.get("api_key_comfy_org", None) else: if "hidden" in valid_inputs: @@ -258,7 +257,7 @@ async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, f pre_execute_cb(index) # V3 if isinstance(obj, _ComfyNodeInternal) or (is_class(obj) and issubclass(obj, _ComfyNodeInternal)): - # if is just a class, then assign no resources or state, just create clone + # if is just a class, then assign no state, just create clone if is_class(obj): type_obj = obj obj.VALIDATE_CLASS() @@ -481,7 +480,10 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, else: lazy_status_present = getattr(obj, "check_lazy_status", None) is not None if lazy_status_present: - required_inputs = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, "check_lazy_status", allow_interrupt=True, v3_data=v3_data) + # for check_lazy_status, the returned data should include the original key of the input + v3_data_lazy = v3_data.copy() + v3_data_lazy["create_dynamic_tuple"] = True + required_inputs = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, "check_lazy_status", allow_interrupt=True, v3_data=v3_data_lazy) required_inputs = await resolve_map_node_over_list_results(required_inputs) required_inputs = set(sum([r for r in required_inputs if isinstance(r,list)], [])) required_inputs = [x for x in required_inputs if isinstance(x,str) and ( @@ -756,10 +758,13 @@ async def validate_inputs(prompt_id, prompt, item, validated): errors = [] valid = True + v3_data = None validate_function_inputs = [] validate_has_kwargs = False if issubclass(obj_class, _ComfyNodeInternal): - class_inputs, _, _ = obj_class.INPUT_TYPES(include_hidden=False, return_schema=True, live_inputs=inputs) + obj_class: _io._ComfyNodeBaseInternal + class_inputs = obj_class.INPUT_TYPES() + class_inputs, _, v3_data = _io.get_finalized_class_inputs(class_inputs, inputs) validate_function_name = "validate_inputs" validate_function = first_real_override(obj_class, validate_function_name) else: @@ -779,10 +784,11 @@ async def validate_inputs(prompt_id, prompt, item, validated): assert extra_info is not None if x not in inputs: if input_category == "required": + details = f"{x}" if not v3_data else x.split(".")[-1] error = { "type": "required_input_missing", "message": "Required input is missing", - "details": f"{x}", + "details": details, "extra_info": { "input_name": x } @@ -916,8 +922,11 @@ async def validate_inputs(prompt_id, prompt, item, validated): errors.append(error) continue - if isinstance(input_type, list): - combo_options = input_type + if isinstance(input_type, list) or input_type == io.Combo.io_type: + if input_type == io.Combo.io_type: + combo_options = extra_info.get("options", []) + else: + combo_options = input_type if val not in combo_options: input_config = info list_info = "" diff --git a/nodes.py b/nodes.py index 7d83ecb21..d9e4ebd91 100644 --- a/nodes.py +++ b/nodes.py @@ -1863,6 +1863,7 @@ class ImageBatch: FUNCTION = "batch" CATEGORY = "image" + DEPRECATED = True def batch(self, image1, image2): if image1.shape[-1] != image2.shape[-1]: From 6ca3d5c011bc15737131eb665939ae0a39a74254 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Wed, 31 Dec 2025 06:12:38 +0200 Subject: [PATCH 042/181] fix(api-nodes-vidu): preserve percent-encoding for signed URLs (#11564) --- 
comfy_api_nodes/util/_helpers.py | 20 ++++++++++++++++++++ comfy_api_nodes/util/download_helpers.py | 3 ++- 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/comfy_api_nodes/util/_helpers.py b/comfy_api_nodes/util/_helpers.py index 491e6b6a8..648defe3d 100644 --- a/comfy_api_nodes/util/_helpers.py +++ b/comfy_api_nodes/util/_helpers.py @@ -1,16 +1,22 @@ import asyncio import contextlib import os +import re import time from collections.abc import Callable from io import BytesIO +from yarl import URL + from comfy.cli_args import args from comfy.model_management import processing_interrupted from comfy_api.latest import IO from .common_exceptions import ProcessingInterrupted +_HAS_PCT_ESC = re.compile(r"%[0-9A-Fa-f]{2}") # any % followed by 2 hex digits +_HAS_BAD_PCT = re.compile(r"%(?![0-9A-Fa-f]{2})") # any % not followed by 2 hex digits + def is_processing_interrupted() -> bool: """Return True if user/runtime requested interruption.""" @@ -69,3 +75,17 @@ def get_fs_object_size(path_or_object: str | BytesIO) -> int: if isinstance(path_or_object, str): return os.path.getsize(path_or_object) return len(path_or_object.getvalue()) + + +def to_aiohttp_url(url: str) -> URL: + """If `url` appears to be already percent-encoded (contains at least one valid %HH + escape and no malformed '%' sequences) and contains no raw whitespace/control + characters preserve the original encoding byte-for-byte (important for signed/presigned URLs). + Otherwise, return `URL(url)` and allow yarl to normalize/quote as needed.""" + if any(c.isspace() for c in url) or any(ord(c) < 0x20 for c in url): + # Avoid encoded=True if URL contains raw whitespace/control chars + return URL(url) + if _HAS_PCT_ESC.search(url) and not _HAS_BAD_PCT.search(url): + # Preserve encoding only if it appears pre-encoded AND has no invalid % sequences + return URL(url, encoded=True) + return URL(url) diff --git a/comfy_api_nodes/util/download_helpers.py b/comfy_api_nodes/util/download_helpers.py index 3e0d0352d..4668d14a9 100644 --- a/comfy_api_nodes/util/download_helpers.py +++ b/comfy_api_nodes/util/download_helpers.py @@ -19,6 +19,7 @@ from ._helpers import ( get_auth_header, is_processing_interrupted, sleep_with_interrupt, + to_aiohttp_url, ) from .client import _diagnose_connectivity from .common_exceptions import ApiServerError, LocalNetworkError, ProcessingInterrupted @@ -94,7 +95,7 @@ async def download_url_to_bytesio( monitor_task = asyncio.create_task(_monitor()) - req_task = asyncio.create_task(session.get(url, headers=headers)) + req_task = asyncio.create_task(session.get(to_aiohttp_url(url), headers=headers)) done, pending = await asyncio.wait({req_task, monitor_task}, return_when=asyncio.FIRST_COMPLETED) if monitor_task in done and req_task in pending: From 236b9e211d5093b33acbe1918f56a6bfb4a5cf17 Mon Sep 17 00:00:00 2001 From: ComfyUI Wiki Date: Thu, 1 Jan 2026 05:38:39 +0800 Subject: [PATCH 043/181] chore: update workflow templates to v0.7.65 (#11579) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 8b670b813..3a05799eb 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.35.9 -comfyui-workflow-templates==0.7.64 +comfyui-workflow-templates==0.7.65 comfyui-embedded-docs==0.3.1 torch torchsde From d622a618749b603531b753cef286a6051dd85565 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 31 Dec 2025 14:38:36 -0800 Subject: [PATCH 
044/181] Refactor: move clip_preprocess to comfy.clip_model (#11586) --- comfy/clip_model.py | 19 +++++++++++++++++++ comfy/clip_vision.py | 22 ++-------------------- 2 files changed, 21 insertions(+), 20 deletions(-) diff --git a/comfy/clip_model.py b/comfy/clip_model.py index 7c0cadab5..e88872728 100644 --- a/comfy/clip_model.py +++ b/comfy/clip_model.py @@ -2,6 +2,25 @@ import torch from comfy.ldm.modules.attention import optimized_attention_for_device import comfy.ops +def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711], crop=True): + image = image[:, :, :, :3] if image.shape[3] > 3 else image + mean = torch.tensor(mean, device=image.device, dtype=image.dtype) + std = torch.tensor(std, device=image.device, dtype=image.dtype) + image = image.movedim(-1, 1) + if not (image.shape[2] == size and image.shape[3] == size): + if crop: + scale = (size / min(image.shape[2], image.shape[3])) + scale_size = (round(scale * image.shape[2]), round(scale * image.shape[3])) + else: + scale_size = (size, size) + + image = torch.nn.functional.interpolate(image, size=scale_size, mode="bicubic", antialias=True) + h = (image.shape[2] - size)//2 + w = (image.shape[3] - size)//2 + image = image[:,:,h:h+size,w:w+size] + image = torch.clip((255. * image), 0, 255).round() / 255.0 + return (image - mean.view([3,1,1])) / std.view([3,1,1]) + class CLIPAttention(torch.nn.Module): def __init__(self, embed_dim, heads, dtype, device, operations): super().__init__() diff --git a/comfy/clip_vision.py b/comfy/clip_vision.py index 447b1ce4a..d5fc53497 100644 --- a/comfy/clip_vision.py +++ b/comfy/clip_vision.py @@ -1,6 +1,5 @@ from .utils import load_torch_file, transformers_convert, state_dict_prefix_replace import os -import torch import json import logging @@ -17,24 +16,7 @@ class Output: def __setitem__(self, key, item): setattr(self, key, item) -def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711], crop=True): - image = image[:, :, :, :3] if image.shape[3] > 3 else image - mean = torch.tensor(mean, device=image.device, dtype=image.dtype) - std = torch.tensor(std, device=image.device, dtype=image.dtype) - image = image.movedim(-1, 1) - if not (image.shape[2] == size and image.shape[3] == size): - if crop: - scale = (size / min(image.shape[2], image.shape[3])) - scale_size = (round(scale * image.shape[2]), round(scale * image.shape[3])) - else: - scale_size = (size, size) - - image = torch.nn.functional.interpolate(image, size=scale_size, mode="bicubic", antialias=True) - h = (image.shape[2] - size)//2 - w = (image.shape[3] - size)//2 - image = image[:,:,h:h+size,w:w+size] - image = torch.clip((255. 
* image), 0, 255).round() / 255.0 - return (image - mean.view([3,1,1])) / std.view([3,1,1]) +clip_preprocess = comfy.clip_model.clip_preprocess # Prevent some stuff from breaking, TODO: remove eventually IMAGE_ENCODERS = { "clip_vision_model": comfy.clip_model.CLIPVisionModelProjection, @@ -73,7 +55,7 @@ class ClipVisionModel(): def encode_image(self, image, crop=True): comfy.model_management.load_model_gpu(self.patcher) - pixel_values = clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std, crop=crop).float() + pixel_values = comfy.clip_model.clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std, crop=crop).float() out = self.model(pixel_values=pixel_values, intermediate_output='all' if self.return_all_hidden_states else -2) outputs = Output() From 1bdc9a947f578733f81c9ae894a5acd5809c7a66 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 31 Dec 2025 16:29:55 -0800 Subject: [PATCH 045/181] Remove duplicate import of model_management (#11587) --- comfy/text_encoders/llama.py | 1 - 1 file changed, 1 deletion(-) diff --git a/comfy/text_encoders/llama.py b/comfy/text_encoders/llama.py index ed29e014d..faa4e1de8 100644 --- a/comfy/text_encoders/llama.py +++ b/comfy/text_encoders/llama.py @@ -8,7 +8,6 @@ from comfy.ldm.modules.attention import optimized_attention_for_device import comfy.model_management import comfy.ldm.common_dit -import comfy.model_management from . import qwen_vl @dataclass From 65cfcf5b1bb0d0618fef7bee08ee64397be5c434 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Thu, 1 Jan 2026 19:06:14 -0800 Subject: [PATCH 046/181] New Year ruff cleanup. 
(#11595) --- app/model_manager.py | 4 ++-- comfy/hooks.py | 3 ++- comfy/ldm/chroma_radiance/model.py | 2 +- comfy/ldm/hunyuan_video/upsampler.py | 3 ++- comfy/ldm/modules/diffusionmodules/model.py | 6 ++++-- comfy/ldm/modules/ema.py | 4 ++-- comfy/ldm/util.py | 2 +- comfy/taesd/taehv.py | 6 ++++-- comfy_execution/graph.py | 6 +++--- comfy_extras/nodes_apg.py | 3 ++- comfy_extras/nodes_wan.py | 2 +- nodes.py | 6 ++++-- pyproject.toml | 4 ++++ server.py | 6 +++--- 14 files changed, 35 insertions(+), 22 deletions(-) diff --git a/app/model_manager.py b/app/model_manager.py index ab36bca74..f124d1117 100644 --- a/app/model_manager.py +++ b/app/model_manager.py @@ -44,7 +44,7 @@ class ModelFileManager: @routes.get("/experiment/models/{folder}") async def get_all_models(request): folder = request.match_info.get("folder", None) - if not folder in folder_paths.folder_names_and_paths: + if folder not in folder_paths.folder_names_and_paths: return web.Response(status=404) files = self.get_model_file_list(folder) return web.json_response(files) @@ -55,7 +55,7 @@ class ModelFileManager: path_index = int(request.match_info.get("path_index", None)) filename = request.match_info.get("filename", None) - if not folder_name in folder_paths.folder_names_and_paths: + if folder_name not in folder_paths.folder_names_and_paths: return web.Response(status=404) folders = folder_paths.folder_names_and_paths[folder_name] diff --git a/comfy/hooks.py b/comfy/hooks.py index 9d0731072..1a76c7ba4 100644 --- a/comfy/hooks.py +++ b/comfy/hooks.py @@ -527,7 +527,8 @@ class HookKeyframeGroup: if self._current_keyframe.get_effective_guarantee_steps(max_sigma) > 0: break # if eval_c is outside the percent range, stop looking further - else: break + else: + break # update steps current context is used self._current_used_steps += 1 # update current timestep this was performed on diff --git a/comfy/ldm/chroma_radiance/model.py b/comfy/ldm/chroma_radiance/model.py index 70d173889..4fb56165e 100644 --- a/comfy/ldm/chroma_radiance/model.py +++ b/comfy/ldm/chroma_radiance/model.py @@ -270,7 +270,7 @@ class ChromaRadiance(Chroma): bad_keys = tuple( k for k, v in overrides.items() - if type(v) != type(getattr(params, k)) and (v is not None or k not in nullable_keys) + if not isinstance(v, type(getattr(params, k))) and (v is not None or k not in nullable_keys) ) if bad_keys: e = f"Invalid value(s) in transformer_options chroma_radiance_options: {', '.join(bad_keys)}" diff --git a/comfy/ldm/hunyuan_video/upsampler.py b/comfy/ldm/hunyuan_video/upsampler.py index 85f515f67..d9e76922f 100644 --- a/comfy/ldm/hunyuan_video/upsampler.py +++ b/comfy/ldm/hunyuan_video/upsampler.py @@ -3,7 +3,8 @@ import torch.nn as nn import torch.nn.functional as F from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, VideoConv3d from comfy.ldm.hunyuan_video.vae_refiner import RMS_norm -import model_management, model_patcher +import model_management +import model_patcher class SRResidualCausalBlock3D(nn.Module): def __init__(self, channels: int): diff --git a/comfy/ldm/modules/diffusionmodules/model.py b/comfy/ldm/modules/diffusionmodules/model.py index 681a55db5..1ae3ef034 100644 --- a/comfy/ldm/modules/diffusionmodules/model.py +++ b/comfy/ldm/modules/diffusionmodules/model.py @@ -394,7 +394,8 @@ class Model(nn.Module): attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, resolution, use_timestep=True, use_linear_attn=False, attn_type="vanilla"): super().__init__() - if use_linear_attn: attn_type = "linear" + if use_linear_attn: + 
attn_type = "linear" self.ch = ch self.temb_ch = self.ch*4 self.num_resolutions = len(ch_mult) @@ -548,7 +549,8 @@ class Encoder(nn.Module): conv3d=False, time_compress=None, **ignore_kwargs): super().__init__() - if use_linear_attn: attn_type = "linear" + if use_linear_attn: + attn_type = "linear" self.ch = ch self.temb_ch = 0 self.num_resolutions = len(ch_mult) diff --git a/comfy/ldm/modules/ema.py b/comfy/ldm/modules/ema.py index bded25019..96ee6e895 100644 --- a/comfy/ldm/modules/ema.py +++ b/comfy/ldm/modules/ema.py @@ -45,7 +45,7 @@ class LitEma(nn.Module): shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key])) else: - assert not key in self.m_name2s_name + assert key not in self.m_name2s_name def copy_to(self, model): m_param = dict(model.named_parameters()) @@ -54,7 +54,7 @@ class LitEma(nn.Module): if m_param[key].requires_grad: m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) else: - assert not key in self.m_name2s_name + assert key not in self.m_name2s_name def store(self, parameters): """ diff --git a/comfy/ldm/util.py b/comfy/ldm/util.py index 30b4b4721..304936ff4 100644 --- a/comfy/ldm/util.py +++ b/comfy/ldm/util.py @@ -71,7 +71,7 @@ def count_params(model, verbose=False): def instantiate_from_config(config): - if not "target" in config: + if "target" not in config: if config == '__is_first_stage__': return None elif config == "__is_unconditional__": diff --git a/comfy/taesd/taehv.py b/comfy/taesd/taehv.py index 3dfe1e4d4..0e5f9a378 100644 --- a/comfy/taesd/taehv.py +++ b/comfy/taesd/taehv.py @@ -154,7 +154,8 @@ class TAEHV(nn.Module): self._show_progress_bar = value def encode(self, x, **kwargs): - if self.patch_size > 1: x = F.pixel_unshuffle(x, self.patch_size) + if self.patch_size > 1: + x = F.pixel_unshuffle(x, self.patch_size) x = x.movedim(2, 1) # [B, C, T, H, W] -> [B, T, C, H, W] if x.shape[1] % 4 != 0: # pad at end to multiple of 4 @@ -167,5 +168,6 @@ class TAEHV(nn.Module): def decode(self, x, **kwargs): x = self.process_in(x).movedim(2, 1) # [B, C, T, H, W] -> [B, T, C, H, W] x = apply_model_with_memblocks(self.decoder, x, self.parallel, self.show_progress_bar) - if self.patch_size > 1: x = F.pixel_shuffle(x, self.patch_size) + if self.patch_size > 1: + x = F.pixel_shuffle(x, self.patch_size) return x[:, self.frames_to_trim:].movedim(2, 1) diff --git a/comfy_execution/graph.py b/comfy_execution/graph.py index 8fc5846b7..9d170b16e 100644 --- a/comfy_execution/graph.py +++ b/comfy_execution/graph.py @@ -207,15 +207,15 @@ class ExecutionList(TopologicalSort): return self.output_cache.get(node_id) is not None def cache_link(self, from_node_id, to_node_id): - if not to_node_id in self.execution_cache: + if to_node_id not in self.execution_cache: self.execution_cache[to_node_id] = {} self.execution_cache[to_node_id][from_node_id] = self.output_cache.get(from_node_id) - if not from_node_id in self.execution_cache_listeners: + if from_node_id not in self.execution_cache_listeners: self.execution_cache_listeners[from_node_id] = set() self.execution_cache_listeners[from_node_id].add(to_node_id) def get_cache(self, from_node_id, to_node_id): - if not to_node_id in self.execution_cache: + if to_node_id not in self.execution_cache: return None value = self.execution_cache[to_node_id].get(from_node_id) if value is None: diff --git a/comfy_extras/nodes_apg.py b/comfy_extras/nodes_apg.py index f27ae7da8..b9df2dcc9 100644 --- a/comfy_extras/nodes_apg.py +++ 
b/comfy_extras/nodes_apg.py @@ -55,7 +55,8 @@ class APG(io.ComfyNode): def pre_cfg_function(args): nonlocal running_avg, prev_sigma - if len(args["conds_out"]) == 1: return args["conds_out"] + if len(args["conds_out"]) == 1: + return args["conds_out"] cond = args["conds_out"][0] uncond = args["conds_out"][1] diff --git a/comfy_extras/nodes_wan.py b/comfy_extras/nodes_wan.py index b0bd471bf..d32aad98e 100644 --- a/comfy_extras/nodes_wan.py +++ b/comfy_extras/nodes_wan.py @@ -817,7 +817,7 @@ def get_sample_indices(original_fps, if required_duration > total_frames / original_fps: raise ValueError("required_duration must be less than video length") - if not fixed_start is None and fixed_start >= 0: + if fixed_start is not None and fixed_start >= 0: start_frame = fixed_start else: max_start = total_frames - required_origin_frames diff --git a/nodes.py b/nodes.py index d9e4ebd91..eae2f0086 100644 --- a/nodes.py +++ b/nodes.py @@ -2242,8 +2242,10 @@ async def init_external_custom_nodes(): for possible_module in possible_modules: module_path = os.path.join(custom_node_path, possible_module) - if os.path.isfile(module_path) and os.path.splitext(module_path)[1] != ".py": continue - if module_path.endswith(".disabled"): continue + if os.path.isfile(module_path) and os.path.splitext(module_path)[1] != ".py": + continue + if module_path.endswith(".disabled"): + continue if args.disable_all_custom_nodes and possible_module not in args.whitelist_custom_nodes: logging.info(f"Skipping {possible_module} due to disable_all_custom_nodes and whitelist_custom_nodes") continue diff --git a/pyproject.toml b/pyproject.toml index bc1467941..60378de1e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,12 +15,16 @@ lint.select = [ "N805", # invalid-first-argument-name-for-method "S307", # suspicious-eval-usage "S102", # exec + "E", "T", # print-usage "W", # The "F" series in Ruff stands for "Pyflakes" rules, which catch various Python syntax errors and undefined names. # See all rules here: https://docs.astral.sh/ruff/rules/#pyflakes-f "F", ] + +lint.ignore = ["E501", "E722", "E731", "E712", "E402", "E741"] + exclude = ["*.ipynb", "**/generated/*.pyi"] [tool.pylint] diff --git a/server.py b/server.py index c27f8be7d..70c8b5e3b 100644 --- a/server.py +++ b/server.py @@ -324,7 +324,7 @@ class PromptServer(): @routes.get("/models/{folder}") async def get_models(request): folder = request.match_info.get("folder", None) - if not folder in folder_paths.folder_names_and_paths: + if folder not in folder_paths.folder_names_and_paths: return web.Response(status=404) files = folder_paths.get_filename_list(folder) return web.json_response(files) @@ -579,7 +579,7 @@ class PromptServer(): folder_name = request.match_info.get("folder_name", None) if folder_name is None: return web.Response(status=404) - if not "filename" in request.rel_url.query: + if "filename" not in request.rel_url.query: return web.Response(status=404) filename = request.rel_url.query["filename"] @@ -593,7 +593,7 @@ class PromptServer(): if out is None: return web.Response(status=404) dt = json.loads(out) - if not "__metadata__" in dt: + if "__metadata__" not in dt: return web.Response(status=404) return web.json_response(dt["__metadata__"]) From 9e5f677746463228e35ac6a08f308d758ed620d5 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Fri, 2 Jan 2026 10:35:34 +0200 Subject: [PATCH 047/181] Ignore all frames except the first one for MPO format. 
(#11569) --- nodes.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/nodes.py b/nodes.py index eae2f0086..662907ae6 100644 --- a/nodes.py +++ b/nodes.py @@ -1663,8 +1663,6 @@ class LoadImage: output_masks = [] w, h = None, None - excluded_formats = ['MPO'] - for i in ImageSequence.Iterator(img): i = node_helpers.pillow(ImageOps.exif_transpose, i) @@ -1692,7 +1690,10 @@ class LoadImage: output_images.append(image) output_masks.append(mask.unsqueeze(0)) - if len(output_images) > 1 and img.format not in excluded_formats: + if img.format == "MPO": + break # ignore all frames except the first one for MPO format + + if len(output_images) > 1: output_image = torch.cat(output_images, dim=0) output_mask = torch.cat(output_masks, dim=0) else: From 303b1735f8785c0d1f947af965567850ca413f61 Mon Sep 17 00:00:00 2001 From: throttlekitty Date: Fri, 2 Jan 2026 01:37:37 -0700 Subject: [PATCH 048/181] Give Mahiro CFG a more appropriate display name (#11580) --- comfy_extras/nodes_mahiro.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy_extras/nodes_mahiro.py b/comfy_extras/nodes_mahiro.py index 07b3353f4..6459ca8c1 100644 --- a/comfy_extras/nodes_mahiro.py +++ b/comfy_extras/nodes_mahiro.py @@ -10,7 +10,7 @@ class Mahiro(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="Mahiro", - display_name="Mahiro is so cute that she deserves a better guidance function!! (。・ω・。)", + display_name="Mahiro CFG", category="_for_testing", description="Modify the guidance to scale more on the 'direction' of the positive prompt rather than the difference between the negative prompt.", inputs=[ From f2fda021ab179ba31d9175698b82474a5dd14359 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Fri, 2 Jan 2026 13:18:43 +0200 Subject: [PATCH 049/181] Tripo3D: pass face_limit parameter only when it differs from default (#11601) --- comfy_api_nodes/nodes_tripo.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/comfy_api_nodes/nodes_tripo.py b/comfy_api_nodes/nodes_tripo.py index bd3c24fb3..e72f8e96a 100644 --- a/comfy_api_nodes/nodes_tripo.py +++ b/comfy_api_nodes/nodes_tripo.py @@ -155,7 +155,7 @@ class TripoTextToModelNode(IO.ComfyNode): model_seed=model_seed, texture_seed=texture_seed, texture_quality=texture_quality, - face_limit=face_limit, + face_limit=face_limit if face_limit != -1 else None, geometry_quality=geometry_quality, auto_size=True, quad=quad, @@ -255,7 +255,7 @@ class TripoImageToModelNode(IO.ComfyNode): texture_alignment=texture_alignment, texture_seed=texture_seed, texture_quality=texture_quality, - face_limit=face_limit, + face_limit=face_limit if face_limit != -1 else None, auto_size=True, quad=quad, ), @@ -369,7 +369,7 @@ class TripoMultiviewToModelNode(IO.ComfyNode): texture_quality=texture_quality, geometry_quality=geometry_quality, texture_alignment=texture_alignment, - face_limit=face_limit, + face_limit=face_limit if face_limit != -1 else None, quad=quad, ), ) From 9a552df898ec57f066784cc1f7c475644099b3c1 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Fri, 2 Jan 2026 17:28:10 -0800 Subject: [PATCH 050/181] Remove leftover scaled_fp8 key. 
(#11603) --- comfy/utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/comfy/utils.py b/comfy/utils.py index 8d4e2b445..e4162d7ac 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -1230,6 +1230,8 @@ def convert_old_quants(state_dict, model_prefix="", metadata={}): out_sd = {} layers = {} for k in list(state_dict.keys()): + if k == scaled_fp8_key: + continue if not k.startswith(model_prefix): out_sd[k] = state_dict[k] continue From 53e762a3af9502ebe61a60eb2d39d783fe8d012b Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sat, 3 Jan 2026 19:28:38 -0800 Subject: [PATCH 051/181] Print memory summary on OOM to help with debugging. (#11613) --- comfy/model_management.py | 4 ++++ execution.py | 1 + 2 files changed, 5 insertions(+) diff --git a/comfy/model_management.py b/comfy/model_management.py index 87baedd73..2501cecb7 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -1542,6 +1542,10 @@ def soft_empty_cache(force=False): def unload_all_models(): free_memory(1e30, get_torch_device()) +def debug_memory_summary(): + if is_amd() or is_nvidia(): + return torch.cuda.memory.memory_summary() + return "" #TODO: might be cleaner to put this somewhere else import threading diff --git a/execution.py b/execution.py index 38159b1f4..648f204ec 100644 --- a/execution.py +++ b/execution.py @@ -601,6 +601,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, if isinstance(ex, comfy.model_management.OOM_EXCEPTION): tips = "This error means you ran out of memory on your GPU.\n\nTIPS: If the workflow worked before you might have accidentally set the batch_size to a large number." + logging.info("Memory summary: {}".format(comfy.model_management.debug_memory_summary())) logging.error("Got an OOM, unloading all loaded models.") comfy.model_management.unload_all_models() From acbf08cd60fade74b2e9e5009fa0dcad9538356b Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Sun, 4 Jan 2026 09:05:02 +0200 Subject: [PATCH 052/181] feat(api-nodes): add support for 720p resolution for Kling Omni nodes (#11604) --- comfy_api_nodes/nodes_kling.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/comfy_api_nodes/nodes_kling.py b/comfy_api_nodes/nodes_kling.py index 58259e029..9c707a339 100644 --- a/comfy_api_nodes/nodes_kling.py +++ b/comfy_api_nodes/nodes_kling.py @@ -807,6 +807,7 @@ class OmniProTextToVideoNode(IO.ComfyNode): ), IO.Combo.Input("aspect_ratio", options=["16:9", "9:16", "1:1"]), IO.Combo.Input("duration", options=[5, 10]), + IO.Combo.Input("resolution", options=["1080p", "720p"], optional=True), ], outputs=[ IO.Video.Output(), @@ -826,6 +827,7 @@ class OmniProTextToVideoNode(IO.ComfyNode): prompt: str, aspect_ratio: str, duration: int, + resolution: str = "1080p", ) -> IO.NodeOutput: validate_string(prompt, min_length=1, max_length=2500) response = await sync_op( @@ -837,6 +839,7 @@ class OmniProTextToVideoNode(IO.ComfyNode): prompt=prompt, aspect_ratio=aspect_ratio, duration=str(duration), + mode="pro" if resolution == "1080p" else "std", ), ) return await finish_omni_video_task(cls, response) @@ -872,6 +875,7 @@ class OmniProFirstLastFrameNode(IO.ComfyNode): optional=True, tooltip="Up to 6 additional reference images.", ), + IO.Combo.Input("resolution", options=["1080p", "720p"], optional=True), ], outputs=[ IO.Video.Output(), @@ -893,6 +897,7 @@ class OmniProFirstLastFrameNode(IO.ComfyNode): first_frame: Input.Image, end_frame: 
Input.Image | None = None, reference_images: Input.Image | None = None, + resolution: str = "1080p", ) -> IO.NodeOutput: prompt = normalize_omni_prompt_references(prompt) validate_string(prompt, min_length=1, max_length=2500) @@ -936,6 +941,7 @@ class OmniProFirstLastFrameNode(IO.ComfyNode): prompt=prompt, duration=str(duration), image_list=image_list, + mode="pro" if resolution == "1080p" else "std", ), ) return await finish_omni_video_task(cls, response) @@ -964,6 +970,7 @@ class OmniProImageToVideoNode(IO.ComfyNode): "reference_images", tooltip="Up to 7 reference images.", ), + IO.Combo.Input("resolution", options=["1080p", "720p"], optional=True), ], outputs=[ IO.Video.Output(), @@ -984,6 +991,7 @@ class OmniProImageToVideoNode(IO.ComfyNode): aspect_ratio: str, duration: int, reference_images: Input.Image, + resolution: str = "1080p", ) -> IO.NodeOutput: prompt = normalize_omni_prompt_references(prompt) validate_string(prompt, min_length=1, max_length=2500) @@ -1005,6 +1013,7 @@ class OmniProImageToVideoNode(IO.ComfyNode): aspect_ratio=aspect_ratio, duration=str(duration), image_list=image_list, + mode="pro" if resolution == "1080p" else "std", ), ) return await finish_omni_video_task(cls, response) @@ -1036,6 +1045,7 @@ class OmniProVideoToVideoNode(IO.ComfyNode): tooltip="Up to 4 additional reference images.", optional=True, ), + IO.Combo.Input("resolution", options=["1080p", "720p"], optional=True), ], outputs=[ IO.Video.Output(), @@ -1058,6 +1068,7 @@ class OmniProVideoToVideoNode(IO.ComfyNode): reference_video: Input.Video, keep_original_sound: bool, reference_images: Input.Image | None = None, + resolution: str = "1080p", ) -> IO.NodeOutput: prompt = normalize_omni_prompt_references(prompt) validate_string(prompt, min_length=1, max_length=2500) @@ -1090,6 +1101,7 @@ class OmniProVideoToVideoNode(IO.ComfyNode): duration=str(duration), image_list=image_list if image_list else None, video_list=video_list, + mode="pro" if resolution == "1080p" else "std", ), ) return await finish_omni_video_task(cls, response) @@ -1119,6 +1131,7 @@ class OmniProEditVideoNode(IO.ComfyNode): tooltip="Up to 4 additional reference images.", optional=True, ), + IO.Combo.Input("resolution", options=["1080p", "720p"], optional=True), ], outputs=[ IO.Video.Output(), @@ -1139,6 +1152,7 @@ class OmniProEditVideoNode(IO.ComfyNode): video: Input.Video, keep_original_sound: bool, reference_images: Input.Image | None = None, + resolution: str = "1080p", ) -> IO.NodeOutput: prompt = normalize_omni_prompt_references(prompt) validate_string(prompt, min_length=1, max_length=2500) @@ -1171,6 +1185,7 @@ class OmniProEditVideoNode(IO.ComfyNode): duration=None, image_list=image_list if image_list else None, video_list=video_list, + mode="pro" if resolution == "1080p" else "std", ), ) return await finish_omni_video_task(cls, response) From 38d049382533c6662d815b08ca3395e96cca9f57 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sun, 4 Jan 2026 16:13:50 -0800 Subject: [PATCH 053/181] Fix case where upscale model wouldn't be moved to cpu. 
(#11633) --- comfy_extras/nodes_upscale_model.py | 24 +++++++++++++----------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/comfy_extras/nodes_upscale_model.py b/comfy_extras/nodes_upscale_model.py index 4d62b87be..ed587851c 100644 --- a/comfy_extras/nodes_upscale_model.py +++ b/comfy_extras/nodes_upscale_model.py @@ -78,18 +78,20 @@ class ImageUpscaleWithModel(io.ComfyNode): overlap = 32 oom = True - while oom: - try: - steps = in_img.shape[0] * comfy.utils.get_tiled_scale_steps(in_img.shape[3], in_img.shape[2], tile_x=tile, tile_y=tile, overlap=overlap) - pbar = comfy.utils.ProgressBar(steps) - s = comfy.utils.tiled_scale(in_img, lambda a: upscale_model(a), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.scale, pbar=pbar) - oom = False - except model_management.OOM_EXCEPTION as e: - tile //= 2 - if tile < 128: - raise e + try: + while oom: + try: + steps = in_img.shape[0] * comfy.utils.get_tiled_scale_steps(in_img.shape[3], in_img.shape[2], tile_x=tile, tile_y=tile, overlap=overlap) + pbar = comfy.utils.ProgressBar(steps) + s = comfy.utils.tiled_scale(in_img, lambda a: upscale_model(a), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.scale, pbar=pbar) + oom = False + except model_management.OOM_EXCEPTION as e: + tile //= 2 + if tile < 128: + raise e + finally: + upscale_model.to("cpu") - upscale_model.to("cpu") s = torch.clamp(s.movedim(-3,-1), min=0, max=1.0) return io.NodeOutput(s) From f2b002372b71cf0671a4cf1fa539e1c386d727e4 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sun, 4 Jan 2026 22:58:59 -0800 Subject: [PATCH 054/181] Support the LTXV 2 model. (#11632) --- comfy/latent_formats.py | 3 + comfy/ldm/lightricks/av_model.py | 837 ++++++++++++++++ comfy/ldm/lightricks/embeddings_connector.py | 305 ++++++ comfy/ldm/lightricks/latent_upsampler.py | 292 ++++++ comfy/ldm/lightricks/model.py | 715 +++++++++++--- comfy/ldm/lightricks/symmetric_patchifier.py | 87 +- comfy/ldm/lightricks/vae/audio_vae.py | 286 ++++++ .../vae/causal_audio_autoencoder.py | 909 ++++++++++++++++++ comfy/ldm/lightricks/vocoders/vocoder.py | 213 ++++ comfy/model_base.py | 57 +- comfy/model_detection.py | 2 +- comfy/sd.py | 9 +- comfy/supported_models.py | 17 +- comfy/text_encoders/llama.py | 79 ++ comfy/text_encoders/lt.py | 111 +++ comfy/utils.py | 2 +- comfy_extras/nodes_audio.py | 2 +- comfy_extras/nodes_hunyuan.py | 15 +- comfy_extras/nodes_lt.py | 188 +++- comfy_extras/nodes_lt_audio.py | 183 ++++ comfy_extras/nodes_lt_upsampler.py | 75 ++ nodes.py | 10 +- pyproject.toml | 2 +- 23 files changed, 4214 insertions(+), 185 deletions(-) create mode 100644 comfy/ldm/lightricks/av_model.py create mode 100644 comfy/ldm/lightricks/embeddings_connector.py create mode 100644 comfy/ldm/lightricks/latent_upsampler.py create mode 100644 comfy/ldm/lightricks/vae/audio_vae.py create mode 100644 comfy/ldm/lightricks/vae/causal_audio_autoencoder.py create mode 100644 comfy/ldm/lightricks/vocoders/vocoder.py create mode 100644 comfy_extras/nodes_lt_audio.py create mode 100644 comfy_extras/nodes_lt_upsampler.py diff --git a/comfy/latent_formats.py b/comfy/latent_formats.py index f1ca0151e..9bbe30b53 100644 --- a/comfy/latent_formats.py +++ b/comfy/latent_formats.py @@ -407,6 +407,9 @@ class LTXV(LatentFormat): self.latent_rgb_factors_bias = [-0.0571, -0.1657, -0.2512] +class LTXAV(LTXV): + pass + class HunyuanVideo(LatentFormat): latent_channels = 16 latent_dimensions = 3 diff --git 
a/comfy/ldm/lightricks/av_model.py b/comfy/ldm/lightricks/av_model.py new file mode 100644 index 000000000..759535501 --- /dev/null +++ b/comfy/ldm/lightricks/av_model.py @@ -0,0 +1,837 @@ +from typing import Tuple +import torch +import torch.nn as nn +from comfy.ldm.lightricks.model import ( + CrossAttention, + FeedForward, + AdaLayerNormSingle, + PixArtAlphaTextProjection, + LTXVModel, +) +from comfy.ldm.lightricks.symmetric_patchifier import AudioPatchifier +import comfy.ldm.common_dit + +class BasicAVTransformerBlock(nn.Module): + def __init__( + self, + v_dim, + a_dim, + v_heads, + a_heads, + vd_head, + ad_head, + v_context_dim=None, + a_context_dim=None, + attn_precision=None, + dtype=None, + device=None, + operations=None, + ): + super().__init__() + + self.attn_precision = attn_precision + + self.attn1 = CrossAttention( + query_dim=v_dim, + heads=v_heads, + dim_head=vd_head, + context_dim=None, + attn_precision=self.attn_precision, + dtype=dtype, + device=device, + operations=operations, + ) + self.audio_attn1 = CrossAttention( + query_dim=a_dim, + heads=a_heads, + dim_head=ad_head, + context_dim=None, + attn_precision=self.attn_precision, + dtype=dtype, + device=device, + operations=operations, + ) + + self.attn2 = CrossAttention( + query_dim=v_dim, + context_dim=v_context_dim, + heads=v_heads, + dim_head=vd_head, + attn_precision=self.attn_precision, + dtype=dtype, + device=device, + operations=operations, + ) + self.audio_attn2 = CrossAttention( + query_dim=a_dim, + context_dim=a_context_dim, + heads=a_heads, + dim_head=ad_head, + attn_precision=self.attn_precision, + dtype=dtype, + device=device, + operations=operations, + ) + + # Q: Video, K,V: Audio + self.audio_to_video_attn = CrossAttention( + query_dim=v_dim, + context_dim=a_dim, + heads=a_heads, + dim_head=ad_head, + attn_precision=self.attn_precision, + dtype=dtype, + device=device, + operations=operations, + ) + + # Q: Audio, K,V: Video + self.video_to_audio_attn = CrossAttention( + query_dim=a_dim, + context_dim=v_dim, + heads=a_heads, + dim_head=ad_head, + attn_precision=self.attn_precision, + dtype=dtype, + device=device, + operations=operations, + ) + + self.ff = FeedForward( + v_dim, dim_out=v_dim, glu=True, dtype=dtype, device=device, operations=operations + ) + self.audio_ff = FeedForward( + a_dim, dim_out=a_dim, glu=True, dtype=dtype, device=device, operations=operations + ) + + self.scale_shift_table = nn.Parameter(torch.empty(6, v_dim, device=device, dtype=dtype)) + self.audio_scale_shift_table = nn.Parameter( + torch.empty(6, a_dim, device=device, dtype=dtype) + ) + + self.scale_shift_table_a2v_ca_audio = nn.Parameter( + torch.empty(5, a_dim, device=device, dtype=dtype) + ) + self.scale_shift_table_a2v_ca_video = nn.Parameter( + torch.empty(5, v_dim, device=device, dtype=dtype) + ) + + def get_ada_values( + self, scale_shift_table: torch.Tensor, batch_size: int, timestep: torch.Tensor, indices: slice = slice(None, None) + ): + num_ada_params = scale_shift_table.shape[0] + + ada_values = ( + scale_shift_table[indices].unsqueeze(0).unsqueeze(0).to(device=timestep.device, dtype=timestep.dtype) + + timestep.reshape(batch_size, timestep.shape[1], num_ada_params, -1)[:, :, indices, :] + ).unbind(dim=2) + return ada_values + + def get_av_ca_ada_values( + self, + scale_shift_table: torch.Tensor, + batch_size: int, + scale_shift_timestep: torch.Tensor, + gate_timestep: torch.Tensor, + num_scale_shift_values: int = 4, + ): + scale_shift_ada_values = self.get_ada_values( + scale_shift_table[:num_scale_shift_values, :], 
+ batch_size, + scale_shift_timestep, + ) + gate_ada_values = self.get_ada_values( + scale_shift_table[num_scale_shift_values:, :], + batch_size, + gate_timestep, + ) + + scale_shift_chunks = [t.squeeze(2) for t in scale_shift_ada_values] + gate_ada_values = [t.squeeze(2) for t in gate_ada_values] + + return (*scale_shift_chunks, *gate_ada_values) + + def forward( + self, + x: Tuple[torch.Tensor, torch.Tensor], + v_context=None, + a_context=None, + attention_mask=None, + v_timestep=None, + a_timestep=None, + v_pe=None, + a_pe=None, + v_cross_pe=None, + a_cross_pe=None, + v_cross_scale_shift_timestep=None, + a_cross_scale_shift_timestep=None, + v_cross_gate_timestep=None, + a_cross_gate_timestep=None, + transformer_options=None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + run_vx = transformer_options.get("run_vx", True) + run_ax = transformer_options.get("run_ax", True) + + vx, ax = x + run_ax = run_ax and ax.numel() > 0 + run_a2v = run_vx and transformer_options.get("a2v_cross_attn", True) and ax.numel() > 0 + run_v2a = run_ax and transformer_options.get("v2a_cross_attn", True) + + if run_vx: + vshift_msa, vscale_msa, vgate_msa = ( + self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(0, 3)) + ) + + norm_vx = comfy.ldm.common_dit.rms_norm(vx) * (1 + vscale_msa) + vshift_msa + vx += self.attn1(norm_vx, pe=v_pe, transformer_options=transformer_options) * vgate_msa + vx += self.attn2( + comfy.ldm.common_dit.rms_norm(vx), + context=v_context, + mask=attention_mask, + transformer_options=transformer_options, + ) + + del vshift_msa, vscale_msa, vgate_msa + + if run_ax: + ashift_msa, ascale_msa, agate_msa = ( + self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(0, 3)) + ) + + norm_ax = comfy.ldm.common_dit.rms_norm(ax) * (1 + ascale_msa) + ashift_msa + ax += ( + self.audio_attn1(norm_ax, pe=a_pe, transformer_options=transformer_options) + * agate_msa + ) + ax += self.audio_attn2( + comfy.ldm.common_dit.rms_norm(ax), + context=a_context, + mask=attention_mask, + transformer_options=transformer_options, + ) + + del ashift_msa, ascale_msa, agate_msa + + # Audio - Video cross attention. 
+ if run_a2v or run_v2a: + # norm3 + vx_norm3 = comfy.ldm.common_dit.rms_norm(vx) + ax_norm3 = comfy.ldm.common_dit.rms_norm(ax) + + ( + scale_ca_audio_hidden_states_a2v, + shift_ca_audio_hidden_states_a2v, + scale_ca_audio_hidden_states_v2a, + shift_ca_audio_hidden_states_v2a, + gate_out_v2a, + ) = self.get_av_ca_ada_values( + self.scale_shift_table_a2v_ca_audio, + ax.shape[0], + a_cross_scale_shift_timestep, + a_cross_gate_timestep, + ) + + ( + scale_ca_video_hidden_states_a2v, + shift_ca_video_hidden_states_a2v, + scale_ca_video_hidden_states_v2a, + shift_ca_video_hidden_states_v2a, + gate_out_a2v, + ) = self.get_av_ca_ada_values( + self.scale_shift_table_a2v_ca_video, + vx.shape[0], + v_cross_scale_shift_timestep, + v_cross_gate_timestep, + ) + + if run_a2v: + vx_scaled = ( + vx_norm3 * (1 + scale_ca_video_hidden_states_a2v) + + shift_ca_video_hidden_states_a2v + ) + ax_scaled = ( + ax_norm3 * (1 + scale_ca_audio_hidden_states_a2v) + + shift_ca_audio_hidden_states_a2v + ) + vx += ( + self.audio_to_video_attn( + vx_scaled, + context=ax_scaled, + pe=v_cross_pe, + k_pe=a_cross_pe, + transformer_options=transformer_options, + ) + * gate_out_a2v + ) + + del gate_out_a2v + del scale_ca_video_hidden_states_a2v,\ + shift_ca_video_hidden_states_a2v,\ + scale_ca_audio_hidden_states_a2v,\ + shift_ca_audio_hidden_states_a2v,\ + + if run_v2a: + ax_scaled = ( + ax_norm3 * (1 + scale_ca_audio_hidden_states_v2a) + + shift_ca_audio_hidden_states_v2a + ) + vx_scaled = ( + vx_norm3 * (1 + scale_ca_video_hidden_states_v2a) + + shift_ca_video_hidden_states_v2a + ) + ax += ( + self.video_to_audio_attn( + ax_scaled, + context=vx_scaled, + pe=a_cross_pe, + k_pe=v_cross_pe, + transformer_options=transformer_options, + ) + * gate_out_v2a + ) + + del gate_out_v2a + del scale_ca_video_hidden_states_v2a,\ + shift_ca_video_hidden_states_v2a,\ + scale_ca_audio_hidden_states_v2a,\ + shift_ca_audio_hidden_states_v2a + + if run_vx: + vshift_mlp, vscale_mlp, vgate_mlp = ( + self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(3, None)) + ) + + vx_scaled = comfy.ldm.common_dit.rms_norm(vx) * (1 + vscale_mlp) + vshift_mlp + vx += self.ff(vx_scaled) * vgate_mlp + del vshift_mlp, vscale_mlp, vgate_mlp + + if run_ax: + ashift_mlp, ascale_mlp, agate_mlp = ( + self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(3, None)) + ) + + ax_scaled = comfy.ldm.common_dit.rms_norm(ax) * (1 + ascale_mlp) + ashift_mlp + ax += self.audio_ff(ax_scaled) * agate_mlp + + del ashift_mlp, ascale_mlp, agate_mlp + + + return vx, ax + + +class LTXAVModel(LTXVModel): + """LTXAV model for audio-video generation.""" + + def __init__( + self, + in_channels=128, + audio_in_channels=128, + cross_attention_dim=4096, + audio_cross_attention_dim=2048, + attention_head_dim=128, + audio_attention_head_dim=64, + num_attention_heads=32, + audio_num_attention_heads=32, + caption_channels=3840, + num_layers=48, + positional_embedding_theta=10000.0, + positional_embedding_max_pos=[20, 2048, 2048], + audio_positional_embedding_max_pos=[20], + causal_temporal_positioning=False, + vae_scale_factors=(8, 32, 32), + use_middle_indices_grid=False, + timestep_scale_multiplier=1000.0, + av_ca_timestep_scale_multiplier=1.0, + dtype=None, + device=None, + operations=None, + **kwargs, + ): + # Store audio-specific parameters + self.audio_in_channels = audio_in_channels + self.audio_cross_attention_dim = audio_cross_attention_dim + self.audio_attention_head_dim = audio_attention_head_dim + self.audio_num_attention_heads = 
audio_num_attention_heads + self.audio_positional_embedding_max_pos = audio_positional_embedding_max_pos + + # Calculate audio dimensions + self.audio_inner_dim = audio_num_attention_heads * audio_attention_head_dim + self.audio_out_channels = audio_in_channels + + # Audio-specific constants + self.num_audio_channels = 8 + self.audio_frequency_bins = 16 + + self.av_ca_timestep_scale_multiplier = av_ca_timestep_scale_multiplier + + super().__init__( + in_channels=in_channels, + cross_attention_dim=cross_attention_dim, + attention_head_dim=attention_head_dim, + num_attention_heads=num_attention_heads, + caption_channels=caption_channels, + num_layers=num_layers, + positional_embedding_theta=positional_embedding_theta, + positional_embedding_max_pos=positional_embedding_max_pos, + causal_temporal_positioning=causal_temporal_positioning, + vae_scale_factors=vae_scale_factors, + use_middle_indices_grid=use_middle_indices_grid, + timestep_scale_multiplier=timestep_scale_multiplier, + dtype=dtype, + device=device, + operations=operations, + **kwargs, + ) + + def _init_model_components(self, device, dtype, **kwargs): + """Initialize LTXAV-specific components.""" + # Audio-specific projections + self.audio_patchify_proj = self.operations.Linear( + self.audio_in_channels, self.audio_inner_dim, bias=True, dtype=dtype, device=device + ) + + # Audio-specific AdaLN + self.audio_adaln_single = AdaLayerNormSingle( + self.audio_inner_dim, + use_additional_conditions=False, + dtype=dtype, + device=device, + operations=self.operations, + ) + + num_scale_shift_values = 4 + self.av_ca_video_scale_shift_adaln_single = AdaLayerNormSingle( + self.inner_dim, + use_additional_conditions=False, + embedding_coefficient=num_scale_shift_values, + dtype=dtype, + device=device, + operations=self.operations, + ) + self.av_ca_a2v_gate_adaln_single = AdaLayerNormSingle( + self.inner_dim, + use_additional_conditions=False, + embedding_coefficient=1, + dtype=dtype, + device=device, + operations=self.operations, + ) + self.av_ca_audio_scale_shift_adaln_single = AdaLayerNormSingle( + self.audio_inner_dim, + use_additional_conditions=False, + embedding_coefficient=num_scale_shift_values, + dtype=dtype, + device=device, + operations=self.operations, + ) + self.av_ca_v2a_gate_adaln_single = AdaLayerNormSingle( + self.audio_inner_dim, + use_additional_conditions=False, + embedding_coefficient=1, + dtype=dtype, + device=device, + operations=self.operations, + ) + + # Audio caption projection + self.audio_caption_projection = PixArtAlphaTextProjection( + in_features=self.caption_channels, + hidden_size=self.audio_inner_dim, + dtype=dtype, + device=device, + operations=self.operations, + ) + + def _init_transformer_blocks(self, device, dtype, **kwargs): + """Initialize transformer blocks for LTXAV.""" + self.transformer_blocks = nn.ModuleList( + [ + BasicAVTransformerBlock( + v_dim=self.inner_dim, + a_dim=self.audio_inner_dim, + v_heads=self.num_attention_heads, + a_heads=self.audio_num_attention_heads, + vd_head=self.attention_head_dim, + ad_head=self.audio_attention_head_dim, + v_context_dim=self.cross_attention_dim, + a_context_dim=self.audio_cross_attention_dim, + dtype=dtype, + device=device, + operations=self.operations, + ) + for _ in range(self.num_layers) + ] + ) + + def _init_output_components(self, device, dtype): + """Initialize output components for LTXAV.""" + # Video output components + super()._init_output_components(device, dtype) + # Audio output components + self.audio_scale_shift_table = nn.Parameter( + 
torch.empty(2, self.audio_inner_dim, dtype=dtype, device=device) + ) + self.audio_norm_out = self.operations.LayerNorm( + self.audio_inner_dim, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device + ) + self.audio_proj_out = self.operations.Linear( + self.audio_inner_dim, self.audio_out_channels, dtype=dtype, device=device + ) + self.a_patchifier = AudioPatchifier(1, start_end=True) + + def separate_audio_and_video_latents(self, x, audio_length): + """Separate audio and video latents from combined input.""" + # vx = x[:, : self.in_channels] + # ax = x[:, self.in_channels :] + # + # ax = ax.reshape(ax.shape[0], -1) + # ax = ax[:, : audio_length * self.num_audio_channels * self.audio_frequency_bins] + # + # ax = ax.reshape( + # ax.shape[0], self.num_audio_channels, audio_length, self.audio_frequency_bins + # ) + + vx = x[0] + ax = x[1] if len(x) > 1 else torch.zeros( + (vx.shape[0], self.num_audio_channels, 0, self.audio_frequency_bins), + device=vx.device, dtype=vx.dtype + ) + return vx, ax + + def recombine_audio_and_video_latents(self, vx, ax, target_shape=None): + if ax.numel() == 0: + return vx + else: + return [vx, ax] + """Recombine audio and video latents for output.""" + # if ax.device != vx.device or ax.dtype != vx.dtype: + # logging.warning("Audio and video latents are on different devices or dtypes.") + # ax = ax.to(device=vx.device, dtype=vx.dtype) + # logging.warning(f"Audio audio latent moved to device: {ax.device}, dtype: {ax.dtype}") + # + # ax = ax.reshape(ax.shape[0], -1) + # # pad to f x h x w of the video latents + # divisor = vx.shape[-1] * vx.shape[-2] * vx.shape[-3] + # if target_shape is None: + # repetitions = math.ceil(ax.shape[-1] / divisor) + # else: + # repetitions = target_shape[1] - vx.shape[1] + # padded_len = repetitions * divisor + # ax = F.pad(ax, (0, padded_len - ax.shape[-1])) + # ax = ax.reshape(ax.shape[0], -1, vx.shape[-3], vx.shape[-2], vx.shape[-1]) + # return torch.cat([vx, ax], dim=1) + + def _process_input(self, x, keyframe_idxs, denoise_mask, **kwargs): + """Process input for LTXAV - separate audio and video, then patchify.""" + audio_length = kwargs.get("audio_length", 0) + # Separate audio and video latents + vx, ax = self.separate_audio_and_video_latents(x, audio_length) + [vx, v_pixel_coords, additional_args] = super()._process_input( + vx, keyframe_idxs, denoise_mask, **kwargs + ) + + ax, a_latent_coords = self.a_patchifier.patchify(ax) + ax = self.audio_patchify_proj(ax) + + # additional_args.update({"av_orig_shape": list(x.shape)}) + return [vx, ax], [v_pixel_coords, a_latent_coords], additional_args + + def _prepare_timestep(self, timestep, batch_size, hidden_dtype, **kwargs): + """Prepare timestep embeddings.""" + # TODO: some code reuse is needed here. 
+ grid_mask = kwargs.get("grid_mask", None) + if grid_mask is not None: + timestep = timestep[:, grid_mask] + + timestep = timestep * self.timestep_scale_multiplier + v_timestep, v_embedded_timestep = self.adaln_single( + timestep.flatten(), + {"resolution": None, "aspect_ratio": None}, + batch_size=batch_size, + hidden_dtype=hidden_dtype, + ) + + # Second dimension is 1 or number of tokens (if timestep_per_token) + v_timestep = v_timestep.view(batch_size, -1, v_timestep.shape[-1]) + v_embedded_timestep = v_embedded_timestep.view( + batch_size, -1, v_embedded_timestep.shape[-1] + ) + + # Prepare audio timestep + a_timestep = kwargs.get("a_timestep") + if a_timestep is not None: + a_timestep = a_timestep * self.timestep_scale_multiplier + av_ca_factor = self.av_ca_timestep_scale_multiplier / self.timestep_scale_multiplier + + av_ca_audio_scale_shift_timestep, _ = self.av_ca_audio_scale_shift_adaln_single( + a_timestep.flatten(), + {"resolution": None, "aspect_ratio": None}, + batch_size=batch_size, + hidden_dtype=hidden_dtype, + ) + av_ca_video_scale_shift_timestep, _ = self.av_ca_video_scale_shift_adaln_single( + timestep.flatten(), + {"resolution": None, "aspect_ratio": None}, + batch_size=batch_size, + hidden_dtype=hidden_dtype, + ) + av_ca_a2v_gate_noise_timestep, _ = self.av_ca_a2v_gate_adaln_single( + timestep.flatten() * av_ca_factor, + {"resolution": None, "aspect_ratio": None}, + batch_size=batch_size, + hidden_dtype=hidden_dtype, + ) + av_ca_v2a_gate_noise_timestep, _ = self.av_ca_v2a_gate_adaln_single( + a_timestep.flatten() * av_ca_factor, + {"resolution": None, "aspect_ratio": None}, + batch_size=batch_size, + hidden_dtype=hidden_dtype, + ) + + a_timestep, a_embedded_timestep = self.audio_adaln_single( + a_timestep.flatten(), + {"resolution": None, "aspect_ratio": None}, + batch_size=batch_size, + hidden_dtype=hidden_dtype, + ) + a_timestep = a_timestep.view(batch_size, -1, a_timestep.shape[-1]) + a_embedded_timestep = a_embedded_timestep.view( + batch_size, -1, a_embedded_timestep.shape[-1] + ) + cross_av_timestep_ss = [ + av_ca_audio_scale_shift_timestep, + av_ca_video_scale_shift_timestep, + av_ca_a2v_gate_noise_timestep, + av_ca_v2a_gate_noise_timestep, + ] + cross_av_timestep_ss = list( + [t.view(batch_size, -1, t.shape[-1]) for t in cross_av_timestep_ss] + ) + else: + a_timestep = timestep + a_embedded_timestep = kwargs.get("embedded_timestep") + cross_av_timestep_ss = [] + + return [v_timestep, a_timestep, cross_av_timestep_ss], [ + v_embedded_timestep, + a_embedded_timestep, + ] + + def _prepare_context(self, context, batch_size, x, attention_mask=None): + vx = x[0] + ax = x[1] + v_context, a_context = torch.split( + context, int(context.shape[-1] / 2), len(context.shape) - 1 + ) + + v_context, attention_mask = super()._prepare_context( + v_context, batch_size, vx, attention_mask + ) + if self.audio_caption_projection is not None: + a_context = self.audio_caption_projection(a_context) + a_context = a_context.view(batch_size, -1, ax.shape[-1]) + + return [v_context, a_context], attention_mask + + def _prepare_positional_embeddings(self, pixel_coords, frame_rate, x_dtype): + v_pixel_coords = pixel_coords[0] + v_pe = super()._prepare_positional_embeddings(v_pixel_coords, frame_rate, x_dtype) + + a_latent_coords = pixel_coords[1] + a_pe = self._precompute_freqs_cis( + a_latent_coords, + dim=self.audio_inner_dim, + out_dtype=x_dtype, + max_pos=self.audio_positional_embedding_max_pos, + use_middle_indices_grid=self.use_middle_indices_grid, + 
num_attention_heads=self.audio_num_attention_heads, + ) + + # calculate positional embeddings for the middle of the token duration, to use in av cross attention layers. + max_pos = max( + self.positional_embedding_max_pos[0], self.audio_positional_embedding_max_pos[0] + ) + v_pixel_coords = v_pixel_coords.to(torch.float32) + v_pixel_coords[:, 0] = v_pixel_coords[:, 0] * (1.0 / frame_rate) + av_cross_video_freq_cis = self._precompute_freqs_cis( + v_pixel_coords[:, 0:1, :], + dim=self.audio_cross_attention_dim, + out_dtype=x_dtype, + max_pos=[max_pos], + use_middle_indices_grid=True, + num_attention_heads=self.audio_num_attention_heads, + ) + av_cross_audio_freq_cis = self._precompute_freqs_cis( + a_latent_coords[:, 0:1, :], + dim=self.audio_cross_attention_dim, + out_dtype=x_dtype, + max_pos=[max_pos], + use_middle_indices_grid=True, + num_attention_heads=self.audio_num_attention_heads, + ) + + return [(v_pe, av_cross_video_freq_cis), (a_pe, av_cross_audio_freq_cis)] + + def _process_transformer_blocks( + self, x, context, attention_mask, timestep, pe, transformer_options={}, **kwargs + ): + vx = x[0] + ax = x[1] + v_context = context[0] + a_context = context[1] + v_timestep = timestep[0] + a_timestep = timestep[1] + v_pe, av_cross_video_freq_cis = pe[0] + a_pe, av_cross_audio_freq_cis = pe[1] + + ( + av_ca_audio_scale_shift_timestep, + av_ca_video_scale_shift_timestep, + av_ca_a2v_gate_noise_timestep, + av_ca_v2a_gate_noise_timestep, + ) = timestep[2] + + """Process transformer blocks for LTXAV.""" + patches_replace = transformer_options.get("patches_replace", {}) + blocks_replace = patches_replace.get("dit", {}) + + # Process transformer blocks + for i, block in enumerate(self.transformer_blocks): + if ("double_block", i) in blocks_replace: + + def block_wrap(args): + out = {} + out["img"] = block( + args["img"], + v_context=args["v_context"], + a_context=args["a_context"], + attention_mask=args["attention_mask"], + v_timestep=args["v_timestep"], + a_timestep=args["a_timestep"], + v_pe=args["v_pe"], + a_pe=args["a_pe"], + v_cross_pe=args["v_cross_pe"], + a_cross_pe=args["a_cross_pe"], + v_cross_scale_shift_timestep=args["v_cross_scale_shift_timestep"], + a_cross_scale_shift_timestep=args["a_cross_scale_shift_timestep"], + v_cross_gate_timestep=args["v_cross_gate_timestep"], + a_cross_gate_timestep=args["a_cross_gate_timestep"], + transformer_options=args["transformer_options"], + ) + return out + + out = blocks_replace[("double_block", i)]( + { + "img": (vx, ax), + "v_context": v_context, + "a_context": a_context, + "attention_mask": attention_mask, + "v_timestep": v_timestep, + "a_timestep": a_timestep, + "v_pe": v_pe, + "a_pe": a_pe, + "v_cross_pe": av_cross_video_freq_cis, + "a_cross_pe": av_cross_audio_freq_cis, + "v_cross_scale_shift_timestep": av_ca_video_scale_shift_timestep, + "a_cross_scale_shift_timestep": av_ca_audio_scale_shift_timestep, + "v_cross_gate_timestep": av_ca_a2v_gate_noise_timestep, + "a_cross_gate_timestep": av_ca_v2a_gate_noise_timestep, + "transformer_options": transformer_options, + }, + {"original_block": block_wrap}, + ) + vx, ax = out["img"] + else: + vx, ax = block( + (vx, ax), + v_context=v_context, + a_context=a_context, + attention_mask=attention_mask, + v_timestep=v_timestep, + a_timestep=a_timestep, + v_pe=v_pe, + a_pe=a_pe, + v_cross_pe=av_cross_video_freq_cis, + a_cross_pe=av_cross_audio_freq_cis, + v_cross_scale_shift_timestep=av_ca_video_scale_shift_timestep, + a_cross_scale_shift_timestep=av_ca_audio_scale_shift_timestep, + 
v_cross_gate_timestep=av_ca_a2v_gate_noise_timestep, + a_cross_gate_timestep=av_ca_v2a_gate_noise_timestep, + transformer_options=transformer_options, + ) + + return [vx, ax] + + def _process_output(self, x, embedded_timestep, keyframe_idxs, **kwargs): + vx = x[0] + ax = x[1] + v_embedded_timestep = embedded_timestep[0] + a_embedded_timestep = embedded_timestep[1] + vx = super()._process_output(vx, v_embedded_timestep, keyframe_idxs, **kwargs) + + # Process audio output + a_scale_shift_values = ( + self.audio_scale_shift_table[None, None].to(device=a_embedded_timestep.device, dtype=a_embedded_timestep.dtype) + + a_embedded_timestep[:, :, None] + ) + a_shift, a_scale = a_scale_shift_values[:, :, 0], a_scale_shift_values[:, :, 1] + + ax = self.audio_norm_out(ax) + ax = ax * (1 + a_scale) + a_shift + ax = self.audio_proj_out(ax) + + # Unpatchify audio + ax = self.a_patchifier.unpatchify( + ax, channels=self.num_audio_channels, freq=self.audio_frequency_bins + ) + + # Recombine audio and video + original_shape = kwargs.get("av_orig_shape") + return self.recombine_audio_and_video_latents(vx, ax, original_shape) + + def forward( + self, + x, + timestep, + context, + attention_mask=None, + frame_rate=25, + transformer_options={}, + keyframe_idxs=None, + **kwargs, + ): + """ + Forward pass for LTXAV model. + + Args: + x: Combined audio-video input tensor + timestep: Tuple of (video_timestep, audio_timestep) or single timestep + context: Context tensor (e.g., text embeddings) + attention_mask: Attention mask tensor + frame_rate: Frame rate for temporal processing + transformer_options: Additional options for transformer blocks + keyframe_idxs: Keyframe indices for temporal processing + **kwargs: Additional keyword arguments including audio_length + + Returns: + Combined audio-video output tensor + """ + # Handle timestep format + if isinstance(timestep, (tuple, list)) and len(timestep) == 2: + v_timestep, a_timestep = timestep + kwargs["a_timestep"] = a_timestep + timestep = v_timestep + else: + kwargs["a_timestep"] = timestep + + # Call parent forward method + return super().forward( + x, + timestep, + context, + attention_mask, + frame_rate, + transformer_options, + keyframe_idxs, + **kwargs, + ) diff --git a/comfy/ldm/lightricks/embeddings_connector.py b/comfy/ldm/lightricks/embeddings_connector.py new file mode 100644 index 000000000..f7a43f3c3 --- /dev/null +++ b/comfy/ldm/lightricks/embeddings_connector.py @@ -0,0 +1,305 @@ +import math +from typing import Optional + +import comfy.ldm.common_dit +import torch +from comfy.ldm.lightricks.model import ( + CrossAttention, + FeedForward, + generate_freq_grid_np, + interleaved_freqs_cis, + split_freqs_cis, +) +from torch import nn + + +class BasicTransformerBlock1D(nn.Module): + r""" + A basic Transformer block. + + Parameters: + + dim (`int`): The number of channels in the input and output. + num_attention_heads (`int`): The number of heads to use for multi-head attention. + attention_head_dim (`int`): The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + attention_bias (: + obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter. + upcast_attention (`bool`, *optional*): + Whether to upcast the attention computation to float32. This is useful for mixed precision training. 
+ norm_elementwise_affine (`bool`, *optional*, defaults to `True`): + Whether to use learnable elementwise affine parameters for normalization. + standardization_norm (`str`, *optional*, defaults to `"layer_norm"`): The type of pre-normalization to use. Can be `"layer_norm"` or `"rms_norm"`. + norm_eps (`float`, *optional*, defaults to 1e-5): Epsilon value for normalization layers. + qk_norm (`str`, *optional*, defaults to None): + Set to 'layer_norm' or `rms_norm` to perform query and key normalization. + final_dropout (`bool` *optional*, defaults to False): + Whether to apply a final dropout after the last feed-forward layer. + ff_inner_dim (`int`, *optional*): Dimension of the inner feed-forward layer. If not provided, defaults to `dim * 4`. + ff_bias (`bool`, *optional*, defaults to `True`): Whether to use bias in the feed-forward layer. + attention_out_bias (`bool`, *optional*, defaults to `True`): Whether to use bias in the attention output layer. + use_rope (`bool`, *optional*, defaults to `False`): Whether to use Rotary Position Embeddings (RoPE). + ffn_dim_mult (`int`, *optional*, defaults to 4): Multiplier for the inner dimension of the feed-forward layer. + """ + + def __init__( + self, + dim, + n_heads, + d_head, + context_dim=None, + attn_precision=None, + dtype=None, + device=None, + operations=None, + ): + super().__init__() + + # Define 3 blocks. Each block has its own normalization layer. + # 1. Self-Attn + self.attn1 = CrossAttention( + query_dim=dim, + heads=n_heads, + dim_head=d_head, + context_dim=None, + dtype=dtype, + device=device, + operations=operations, + ) + + # 3. Feed-forward + self.ff = FeedForward( + dim, + dim_out=dim, + glu=True, + dtype=dtype, + device=device, + operations=operations, + ) + + def forward(self, hidden_states, attention_mask=None, pe=None) -> torch.FloatTensor: + + # Notice that normalization is always applied before the real computation in the following blocks. + + # 1. Normalization Before Self-Attention + norm_hidden_states = comfy.ldm.common_dit.rms_norm(hidden_states) + + norm_hidden_states = norm_hidden_states.squeeze(1) + + # 2. Self-Attention + attn_output = self.attn1(norm_hidden_states, mask=attention_mask, pe=pe) + + hidden_states = attn_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = hidden_states.squeeze(1) + + # 3. Normalization before Feed-Forward + norm_hidden_states = comfy.ldm.common_dit.rms_norm(hidden_states) + + # 4. 
Feed-forward + ff_output = self.ff(norm_hidden_states) + + hidden_states = ff_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = hidden_states.squeeze(1) + + return hidden_states + + +class Embeddings1DConnector(nn.Module): + _supports_gradient_checkpointing = True + + def __init__( + self, + in_channels=128, + cross_attention_dim=2048, + attention_head_dim=128, + num_attention_heads=30, + num_layers=2, + positional_embedding_theta=10000.0, + positional_embedding_max_pos=[4096], + causal_temporal_positioning=False, + num_learnable_registers: Optional[int] = 128, + dtype=None, + device=None, + operations=None, + split_rope=False, + double_precision_rope=False, + **kwargs, + ): + super().__init__() + self.dtype = dtype + self.out_channels = in_channels + self.num_attention_heads = num_attention_heads + self.inner_dim = num_attention_heads * attention_head_dim + self.causal_temporal_positioning = causal_temporal_positioning + self.positional_embedding_theta = positional_embedding_theta + self.positional_embedding_max_pos = positional_embedding_max_pos + self.split_rope = split_rope + self.double_precision_rope = double_precision_rope + self.transformer_1d_blocks = nn.ModuleList( + [ + BasicTransformerBlock1D( + self.inner_dim, + num_attention_heads, + attention_head_dim, + context_dim=cross_attention_dim, + dtype=dtype, + device=device, + operations=operations, + ) + for _ in range(num_layers) + ] + ) + + inner_dim = num_attention_heads * attention_head_dim + self.num_learnable_registers = num_learnable_registers + if self.num_learnable_registers: + self.learnable_registers = nn.Parameter( + torch.rand( + self.num_learnable_registers, inner_dim, dtype=dtype, device=device + ) + * 2.0 + - 1.0 + ) + + def get_fractional_positions(self, indices_grid): + fractional_positions = torch.stack( + [ + indices_grid[:, i] / self.positional_embedding_max_pos[i] + for i in range(1) + ], + dim=-1, + ) + return fractional_positions + + def precompute_freqs(self, indices_grid, spacing): + source_dtype = indices_grid.dtype + dtype = ( + torch.float32 + if source_dtype in (torch.bfloat16, torch.float16) + else source_dtype + ) + + fractional_positions = self.get_fractional_positions(indices_grid) + indices = ( + generate_freq_grid_np( + self.positional_embedding_theta, + indices_grid.shape[1], + self.inner_dim, + ) + if self.double_precision_rope + else self.generate_freq_grid(spacing, dtype, fractional_positions.device) + ).to(device=fractional_positions.device) + + if spacing == "exp_2": + freqs = ( + (indices * fractional_positions.unsqueeze(-1)) + .transpose(-1, -2) + .flatten(2) + ) + else: + freqs = ( + (indices * (fractional_positions.unsqueeze(-1) * 2 - 1)) + .transpose(-1, -2) + .flatten(2) + ) + return freqs + + def generate_freq_grid(self, spacing, dtype, device): + dim = self.inner_dim + theta = self.positional_embedding_theta + n_pos_dims = 1 + n_elem = 2 * n_pos_dims # 2 for cos and sin e.g. 
x 3 = 6
+        start = 1
+        end = theta
+
+        if spacing == "exp":
+            indices = theta ** (torch.arange(0, dim, n_elem, device="cpu", dtype=torch.float32) / (dim - n_elem))
+            indices = indices.to(dtype=dtype, device=device)
+        elif spacing == "exp_2":
+            indices = 1.0 / theta ** (torch.arange(0, dim, n_elem, device=device) / dim)
+            indices = indices.to(dtype=dtype)
+        elif spacing == "linear":
+            indices = torch.linspace(
+                start, end, dim // n_elem, device=device, dtype=dtype
+            )
+        elif spacing == "sqrt":
+            indices = torch.linspace(
+                start**2, end**2, dim // n_elem, device=device, dtype=dtype
+            ).sqrt()
+
+        indices = indices * math.pi / 2
+
+        return indices
+
+    def precompute_freqs_cis(self, indices_grid, spacing="exp"):
+        dim = self.inner_dim
+        n_elem = 2  # 2 because of cos and sin
+        freqs = self.precompute_freqs(indices_grid, spacing)
+        if self.split_rope:
+            expected_freqs = dim // 2
+            current_freqs = freqs.shape[-1]
+            pad_size = expected_freqs - current_freqs
+            cos_freq, sin_freq = split_freqs_cis(
+                freqs, pad_size, self.num_attention_heads
+            )
+        else:
+            cos_freq, sin_freq = interleaved_freqs_cis(freqs, dim % n_elem)
+        return cos_freq.to(self.dtype), sin_freq.to(self.dtype), self.split_rope
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+    ):
+        """
+        Forward pass of the embeddings connector.
+
+        Args:
+            hidden_states (`torch.FloatTensor` of shape `(batch size, sequence length, inner_dim)`):
+                Input `hidden_states`.
+            attention_mask (`torch.Tensor`, *optional*):
+                An attention mask of shape `(batch, key_tokens)`. If `1` the token is kept, otherwise it is
+                discarded.
+        Returns:
+            A `(hidden_states, attention_mask)` tuple after the learnable registers are appended and the
+            1D transformer blocks are applied.
+        """
+        # 1. Input
+
+        if self.num_learnable_registers:
+            num_registers_duplications = math.ceil(
+                max(1024, hidden_states.shape[1]) / self.num_learnable_registers
+            )
+            learnable_registers = torch.tile(
+                self.learnable_registers, (num_registers_duplications, 1)
+            )
+
+            hidden_states = torch.cat((hidden_states, learnable_registers[hidden_states.shape[1]:].unsqueeze(0).repeat(hidden_states.shape[0], 1, 1)), dim=1)
+
+        if attention_mask is not None:
+            attention_mask = torch.zeros([1, 1, 1, hidden_states.shape[1]], dtype=attention_mask.dtype, device=attention_mask.device)
+
+        indices_grid = torch.arange(
+            hidden_states.shape[1], dtype=torch.float32, device=hidden_states.device
+        )
+        indices_grid = indices_grid[None, None, :]
+        freqs_cis = self.precompute_freqs_cis(indices_grid)
+
+        # 2. Blocks
+        for block_idx, block in enumerate(self.transformer_1d_blocks):
+            hidden_states = block(
+                hidden_states, attention_mask=attention_mask, pe=freqs_cis
+            )
+
+        # 3.
Output + # if self.output_scale is not None: + # hidden_states = hidden_states / self.output_scale + + hidden_states = comfy.ldm.common_dit.rms_norm(hidden_states) + + return hidden_states, attention_mask diff --git a/comfy/ldm/lightricks/latent_upsampler.py b/comfy/ldm/lightricks/latent_upsampler.py new file mode 100644 index 000000000..78ed7653f --- /dev/null +++ b/comfy/ldm/lightricks/latent_upsampler.py @@ -0,0 +1,292 @@ +from typing import Optional, Tuple +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange + + +def _rational_for_scale(scale: float) -> Tuple[int, int]: + mapping = {0.75: (3, 4), 1.5: (3, 2), 2.0: (2, 1), 4.0: (4, 1)} + if float(scale) not in mapping: + raise ValueError( + f"Unsupported spatial_scale {scale}. Choose from {list(mapping.keys())}" + ) + return mapping[float(scale)] + + +class PixelShuffleND(nn.Module): + def __init__(self, dims, upscale_factors=(2, 2, 2)): + super().__init__() + assert dims in [1, 2, 3], "dims must be 1, 2, or 3" + self.dims = dims + self.upscale_factors = upscale_factors + + def forward(self, x): + if self.dims == 3: + return rearrange( + x, + "b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)", + p1=self.upscale_factors[0], + p2=self.upscale_factors[1], + p3=self.upscale_factors[2], + ) + elif self.dims == 2: + return rearrange( + x, + "b (c p1 p2) h w -> b c (h p1) (w p2)", + p1=self.upscale_factors[0], + p2=self.upscale_factors[1], + ) + elif self.dims == 1: + return rearrange( + x, + "b (c p1) f h w -> b c (f p1) h w", + p1=self.upscale_factors[0], + ) + + +class BlurDownsample(nn.Module): + """ + Anti-aliased spatial downsampling by integer stride using a fixed separable binomial kernel. + Applies only on H,W. Works for dims=2 or dims=3 (per-frame). + """ + + def __init__(self, dims: int, stride: int): + super().__init__() + assert dims in (2, 3) + assert stride >= 1 and isinstance(stride, int) + self.dims = dims + self.stride = stride + + # 5x5 separable binomial kernel [1,4,6,4,1] (outer product), normalized + k = torch.tensor([1.0, 4.0, 6.0, 4.0, 1.0]) + k2d = k[:, None] @ k[None, :] + k2d = (k2d / k2d.sum()).float() # shape (5,5) + self.register_buffer("kernel", k2d[None, None, :, :]) # (1,1,5,5) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.stride == 1: + return x + + def _apply_2d(x2d: torch.Tensor) -> torch.Tensor: + # x2d: (B, C, H, W) + B, C, H, W = x2d.shape + weight = self.kernel.expand(C, 1, 5, 5) # depthwise + x2d = F.conv2d( + x2d, weight=weight, bias=None, stride=self.stride, padding=2, groups=C + ) + return x2d + + if self.dims == 2: + return _apply_2d(x) + else: + # dims == 3: apply per-frame on H,W + b, c, f, h, w = x.shape + x = rearrange(x, "b c f h w -> (b f) c h w") + x = _apply_2d(x) + h2, w2 = x.shape[-2:] + x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f, h=h2, w=w2) + return x + + +class SpatialRationalResampler(nn.Module): + """ + Fully-learned rational spatial scaling: up by 'num' via PixelShuffle, then anti-aliased + downsample by 'den' using fixed blur + stride. Operates on H,W only. + + For dims==3, work per-frame for spatial scaling (temporal axis untouched). 
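+
+    Supported factors map to (num, den) pairs: 0.75 -> (3, 4), 1.5 -> (3, 2), 2.0 -> (2, 1), 4.0 -> (4, 1)
+    (see `_rational_for_scale`).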
+ """ + + def __init__(self, mid_channels: int, scale: float): + super().__init__() + self.scale = float(scale) + self.num, self.den = _rational_for_scale(self.scale) + self.conv = nn.Conv2d( + mid_channels, (self.num**2) * mid_channels, kernel_size=3, padding=1 + ) + self.pixel_shuffle = PixelShuffleND(2, upscale_factors=(self.num, self.num)) + self.blur_down = BlurDownsample(dims=2, stride=self.den) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + b, c, f, h, w = x.shape + x = rearrange(x, "b c f h w -> (b f) c h w") + x = self.conv(x) + x = self.pixel_shuffle(x) + x = self.blur_down(x) + x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f) + return x + + +class ResBlock(nn.Module): + def __init__( + self, channels: int, mid_channels: Optional[int] = None, dims: int = 3 + ): + super().__init__() + if mid_channels is None: + mid_channels = channels + + Conv = nn.Conv2d if dims == 2 else nn.Conv3d + + self.conv1 = Conv(channels, mid_channels, kernel_size=3, padding=1) + self.norm1 = nn.GroupNorm(32, mid_channels) + self.conv2 = Conv(mid_channels, channels, kernel_size=3, padding=1) + self.norm2 = nn.GroupNorm(32, channels) + self.activation = nn.SiLU() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + residual = x + x = self.conv1(x) + x = self.norm1(x) + x = self.activation(x) + x = self.conv2(x) + x = self.norm2(x) + x = self.activation(x + residual) + return x + + +class LatentUpsampler(nn.Module): + """ + Model to spatially upsample VAE latents. + + Args: + in_channels (`int`): Number of channels in the input latent + mid_channels (`int`): Number of channels in the middle layers + num_blocks_per_stage (`int`): Number of ResBlocks to use in each stage (pre/post upsampling) + dims (`int`): Number of dimensions for convolutions (2 or 3) + spatial_upsample (`bool`): Whether to spatially upsample the latent + temporal_upsample (`bool`): Whether to temporally upsample the latent + """ + + def __init__( + self, + in_channels: int = 128, + mid_channels: int = 512, + num_blocks_per_stage: int = 4, + dims: int = 3, + spatial_upsample: bool = True, + temporal_upsample: bool = False, + spatial_scale: float = 2.0, + rational_resampler: bool = False, + ): + super().__init__() + + self.in_channels = in_channels + self.mid_channels = mid_channels + self.num_blocks_per_stage = num_blocks_per_stage + self.dims = dims + self.spatial_upsample = spatial_upsample + self.temporal_upsample = temporal_upsample + self.spatial_scale = float(spatial_scale) + self.rational_resampler = rational_resampler + + Conv = nn.Conv2d if dims == 2 else nn.Conv3d + + self.initial_conv = Conv(in_channels, mid_channels, kernel_size=3, padding=1) + self.initial_norm = nn.GroupNorm(32, mid_channels) + self.initial_activation = nn.SiLU() + + self.res_blocks = nn.ModuleList( + [ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)] + ) + + if spatial_upsample and temporal_upsample: + self.upsampler = nn.Sequential( + nn.Conv3d(mid_channels, 8 * mid_channels, kernel_size=3, padding=1), + PixelShuffleND(3), + ) + elif spatial_upsample: + if rational_resampler: + self.upsampler = SpatialRationalResampler( + mid_channels=mid_channels, scale=self.spatial_scale + ) + else: + self.upsampler = nn.Sequential( + nn.Conv2d(mid_channels, 4 * mid_channels, kernel_size=3, padding=1), + PixelShuffleND(2), + ) + elif temporal_upsample: + self.upsampler = nn.Sequential( + nn.Conv3d(mid_channels, 2 * mid_channels, kernel_size=3, padding=1), + PixelShuffleND(1), + ) + else: + raise ValueError( + "Either 
spatial_upsample or temporal_upsample must be True" + ) + + self.post_upsample_res_blocks = nn.ModuleList( + [ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)] + ) + + self.final_conv = Conv(mid_channels, in_channels, kernel_size=3, padding=1) + + def forward(self, latent: torch.Tensor) -> torch.Tensor: + b, c, f, h, w = latent.shape + + if self.dims == 2: + x = rearrange(latent, "b c f h w -> (b f) c h w") + x = self.initial_conv(x) + x = self.initial_norm(x) + x = self.initial_activation(x) + + for block in self.res_blocks: + x = block(x) + + x = self.upsampler(x) + + for block in self.post_upsample_res_blocks: + x = block(x) + + x = self.final_conv(x) + x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f) + else: + x = self.initial_conv(latent) + x = self.initial_norm(x) + x = self.initial_activation(x) + + for block in self.res_blocks: + x = block(x) + + if self.temporal_upsample: + x = self.upsampler(x) + x = x[:, :, 1:, :, :] + else: + if isinstance(self.upsampler, SpatialRationalResampler): + x = self.upsampler(x) + else: + x = rearrange(x, "b c f h w -> (b f) c h w") + x = self.upsampler(x) + x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f) + + for block in self.post_upsample_res_blocks: + x = block(x) + + x = self.final_conv(x) + + return x + + @classmethod + def from_config(cls, config): + return cls( + in_channels=config.get("in_channels", 4), + mid_channels=config.get("mid_channels", 128), + num_blocks_per_stage=config.get("num_blocks_per_stage", 4), + dims=config.get("dims", 2), + spatial_upsample=config.get("spatial_upsample", True), + temporal_upsample=config.get("temporal_upsample", False), + spatial_scale=config.get("spatial_scale", 2.0), + rational_resampler=config.get("rational_resampler", False), + ) + + def config(self): + return { + "_class_name": "LatentUpsampler", + "in_channels": self.in_channels, + "mid_channels": self.mid_channels, + "num_blocks_per_stage": self.num_blocks_per_stage, + "dims": self.dims, + "spatial_upsample": self.spatial_upsample, + "temporal_upsample": self.temporal_upsample, + "spatial_scale": self.spatial_scale, + "rational_resampler": self.rational_resampler, + } diff --git a/comfy/ldm/lightricks/model.py b/comfy/ldm/lightricks/model.py index 593f7940f..d61e19d6e 100644 --- a/comfy/ldm/lightricks/model.py +++ b/comfy/ldm/lightricks/model.py @@ -1,13 +1,47 @@ +from abc import ABC, abstractmethod +from enum import Enum +import functools +import math +from typing import Dict, Optional, Tuple + +from einops import rearrange +import numpy as np import torch from torch import nn import comfy.patcher_extension import comfy.ldm.modules.attention import comfy.ldm.common_dit -import math -from typing import Dict, Optional, Tuple from .symmetric_patchifier import SymmetricPatchifier, latent_to_pixel_coords -from comfy.ldm.flux.math import apply_rope1 + +def _log_base(x, base): + return np.log(x) / np.log(base) + +class LTXRopeType(str, Enum): + INTERLEAVED = "interleaved" + SPLIT = "split" + + KEY = "rope_type" + + @classmethod + def from_dict(cls, kwargs, default=None): + if default is None: + default = cls.INTERLEAVED + return cls(kwargs.get(cls.KEY, default)) + + +class LTXFrequenciesPrecision(str, Enum): + FLOAT32 = "float32" + FLOAT64 = "float64" + + KEY = "frequencies_precision" + + @classmethod + def from_dict(cls, kwargs, default=None): + if default is None: + default = cls.FLOAT32 + return cls(kwargs.get(cls.KEY, default)) + def get_timestep_embedding( timesteps: torch.Tensor, @@ -39,9 +73,7 @@ def 
get_timestep_embedding( assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array" half_dim = embedding_dim // 2 - exponent = -math.log(max_period) * torch.arange( - start=0, end=half_dim, dtype=torch.float32, device=timesteps.device - ) + exponent = -math.log(max_period) * torch.arange(start=0, end=half_dim, dtype=torch.float32, device=timesteps.device) exponent = exponent / (half_dim - downscale_freq_shift) emb = torch.exp(exponent) @@ -73,7 +105,9 @@ class TimestepEmbedding(nn.Module): post_act_fn: Optional[str] = None, cond_proj_dim=None, sample_proj_bias=True, - dtype=None, device=None, operations=None, + dtype=None, + device=None, + operations=None, ): super().__init__() @@ -90,7 +124,9 @@ class TimestepEmbedding(nn.Module): time_embed_dim_out = out_dim else: time_embed_dim_out = time_embed_dim - self.linear_2 = operations.Linear(time_embed_dim, time_embed_dim_out, sample_proj_bias, dtype=dtype, device=device) + self.linear_2 = operations.Linear( + time_embed_dim, time_embed_dim_out, sample_proj_bias, dtype=dtype, device=device + ) if post_act_fn is None: self.post_act = None @@ -139,12 +175,22 @@ class PixArtAlphaCombinedTimestepSizeEmbeddings(nn.Module): https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L164C9-L168C29 """ - def __init__(self, embedding_dim, size_emb_dim, use_additional_conditions: bool = False, dtype=None, device=None, operations=None): + def __init__( + self, + embedding_dim, + size_emb_dim, + use_additional_conditions: bool = False, + dtype=None, + device=None, + operations=None, + ): super().__init__() self.outdim = size_emb_dim self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) - self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim, dtype=dtype, device=device, operations=operations) + self.timestep_embedder = TimestepEmbedding( + in_channels=256, time_embed_dim=embedding_dim, dtype=dtype, device=device, operations=operations + ) def forward(self, timestep, resolution, aspect_ratio, batch_size, hidden_dtype): timesteps_proj = self.time_proj(timestep) @@ -163,15 +209,22 @@ class AdaLayerNormSingle(nn.Module): use_additional_conditions (`bool`): To use additional conditions for normalization or not. 
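+        embedding_coefficient (`int`, *optional*, defaults to `6`): Width multiplier for the modulation
+            projection; the output linear produces `embedding_coefficient * embedding_dim` features.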
""" - def __init__(self, embedding_dim: int, use_additional_conditions: bool = False, dtype=None, device=None, operations=None): + def __init__( + self, embedding_dim: int, embedding_coefficient: int = 6, use_additional_conditions: bool = False, dtype=None, device=None, operations=None + ): super().__init__() self.emb = PixArtAlphaCombinedTimestepSizeEmbeddings( - embedding_dim, size_emb_dim=embedding_dim // 3, use_additional_conditions=use_additional_conditions, dtype=dtype, device=device, operations=operations + embedding_dim, + size_emb_dim=embedding_dim // 3, + use_additional_conditions=use_additional_conditions, + dtype=dtype, + device=device, + operations=operations, ) self.silu = nn.SiLU() - self.linear = operations.Linear(embedding_dim, 6 * embedding_dim, bias=True, dtype=dtype, device=device) + self.linear = operations.Linear(embedding_dim, embedding_coefficient * embedding_dim, bias=True, dtype=dtype, device=device) def forward( self, @@ -185,6 +238,7 @@ class AdaLayerNormSingle(nn.Module): embedded_timestep = self.emb(timestep, **added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_dtype) return self.linear(self.silu(embedded_timestep)), embedded_timestep + class PixArtAlphaTextProjection(nn.Module): """ Projects caption embeddings. Also handles dropout for classifier-free guidance. @@ -192,18 +246,24 @@ class PixArtAlphaTextProjection(nn.Module): Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py """ - def __init__(self, in_features, hidden_size, out_features=None, act_fn="gelu_tanh", dtype=None, device=None, operations=None): + def __init__( + self, in_features, hidden_size, out_features=None, act_fn="gelu_tanh", dtype=None, device=None, operations=None + ): super().__init__() if out_features is None: out_features = hidden_size - self.linear_1 = operations.Linear(in_features=in_features, out_features=hidden_size, bias=True, dtype=dtype, device=device) + self.linear_1 = operations.Linear( + in_features=in_features, out_features=hidden_size, bias=True, dtype=dtype, device=device + ) if act_fn == "gelu_tanh": self.act_1 = nn.GELU(approximate="tanh") elif act_fn == "silu": self.act_1 = nn.SiLU() else: raise ValueError(f"Unknown activation function: {act_fn}") - self.linear_2 = operations.Linear(in_features=hidden_size, out_features=out_features, bias=True, dtype=dtype, device=device) + self.linear_2 = operations.Linear( + in_features=hidden_size, out_features=out_features, bias=True, dtype=dtype, device=device + ) def forward(self, caption): hidden_states = self.linear_1(caption) @@ -222,23 +282,68 @@ class GELU_approx(nn.Module): class FeedForward(nn.Module): - def __init__(self, dim, dim_out, mult=4, glu=False, dropout=0., dtype=None, device=None, operations=None): + def __init__(self, dim, dim_out, mult=4, glu=False, dropout=0.0, dtype=None, device=None, operations=None): super().__init__() inner_dim = int(dim * mult) project_in = GELU_approx(dim, inner_dim, dtype=dtype, device=device, operations=operations) self.net = nn.Sequential( - project_in, - nn.Dropout(dropout), - operations.Linear(inner_dim, dim_out, dtype=dtype, device=device) + project_in, nn.Dropout(dropout), operations.Linear(inner_dim, dim_out, dtype=dtype, device=device) ) def forward(self, x): return self.net(x) +def apply_rotary_emb(input_tensor, freqs_cis): + cos_freqs, sin_freqs = freqs_cis[0], freqs_cis[1] + split_pe = freqs_cis[2] if len(freqs_cis) > 2 else False + return ( + apply_split_rotary_emb(input_tensor, cos_freqs, sin_freqs) + if 
split_pe else + apply_interleaved_rotary_emb(input_tensor, cos_freqs, sin_freqs) + ) + +def apply_interleaved_rotary_emb(input_tensor, cos_freqs, sin_freqs): # TODO: remove duplicate funcs and pick the best/fastest one + t_dup = rearrange(input_tensor, "... (d r) -> ... d r", r=2) + t1, t2 = t_dup.unbind(dim=-1) + t_dup = torch.stack((-t2, t1), dim=-1) + input_tensor_rot = rearrange(t_dup, "... d r -> ... (d r)") + + out = input_tensor * cos_freqs + input_tensor_rot * sin_freqs + + return out + +def apply_split_rotary_emb(input_tensor, cos, sin): + needs_reshape = False + if input_tensor.ndim != 4 and cos.ndim == 4: + B, H, T, _ = cos.shape + input_tensor = input_tensor.reshape(B, T, H, -1).swapaxes(1, 2) + needs_reshape = True + split_input = rearrange(input_tensor, "... (d r) -> ... d r", d=2) + first_half_input = split_input[..., :1, :] + second_half_input = split_input[..., 1:, :] + output = split_input * cos.unsqueeze(-2) + first_half_output = output[..., :1, :] + second_half_output = output[..., 1:, :] + first_half_output.addcmul_(-sin.unsqueeze(-2), second_half_input) + second_half_output.addcmul_(sin.unsqueeze(-2), first_half_input) + output = rearrange(output, "... d r -> ... (d r)") + return output.swapaxes(1, 2).reshape(B, T, -1) if needs_reshape else output + class CrossAttention(nn.Module): - def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., attn_precision=None, dtype=None, device=None, operations=None): + def __init__( + self, + query_dim, + context_dim=None, + heads=8, + dim_head=64, + dropout=0.0, + attn_precision=None, + dtype=None, + device=None, + operations=None, + ): super().__init__() inner_dim = dim_head * heads context_dim = query_dim if context_dim is None else context_dim @@ -254,9 +359,11 @@ class CrossAttention(nn.Module): self.to_k = operations.Linear(context_dim, inner_dim, bias=True, dtype=dtype, device=device) self.to_v = operations.Linear(context_dim, inner_dim, bias=True, dtype=dtype, device=device) - self.to_out = nn.Sequential(operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout)) + self.to_out = nn.Sequential( + operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout) + ) - def forward(self, x, context=None, mask=None, pe=None, transformer_options={}): + def forward(self, x, context=None, mask=None, pe=None, k_pe=None, transformer_options={}): q = self.to_q(x) context = x if context is None else context k = self.to_k(context) @@ -266,8 +373,8 @@ class CrossAttention(nn.Module): k = self.k_norm(k) if pe is not None: - q = apply_rope1(q.unsqueeze(1), pe).squeeze(1) - k = apply_rope1(k.unsqueeze(1), pe).squeeze(1) + q = apply_rotary_emb(q, pe) + k = apply_rotary_emb(k, pe if k_pe is None else k_pe) if mask is None: out = comfy.ldm.modules.attention.optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision, transformer_options=transformer_options) @@ -277,14 +384,34 @@ class CrossAttention(nn.Module): class BasicTransformerBlock(nn.Module): - def __init__(self, dim, n_heads, d_head, context_dim=None, attn_precision=None, dtype=None, device=None, operations=None): + def __init__( + self, dim, n_heads, d_head, context_dim=None, attn_precision=None, dtype=None, device=None, operations=None + ): super().__init__() self.attn_precision = attn_precision - self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, context_dim=None, attn_precision=self.attn_precision, dtype=dtype, device=device, operations=operations) + self.attn1 = 
CrossAttention( + query_dim=dim, + heads=n_heads, + dim_head=d_head, + context_dim=None, + attn_precision=self.attn_precision, + dtype=dtype, + device=device, + operations=operations, + ) self.ff = FeedForward(dim, dim_out=dim, glu=True, dtype=dtype, device=device, operations=operations) - self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim, heads=n_heads, dim_head=d_head, attn_precision=self.attn_precision, dtype=dtype, device=device, operations=operations) + self.attn2 = CrossAttention( + query_dim=dim, + context_dim=context_dim, + heads=n_heads, + dim_head=d_head, + attn_precision=self.attn_precision, + dtype=dtype, + device=device, + operations=operations, + ) self.scale_shift_table = nn.Parameter(torch.empty(6, dim, device=device, dtype=dtype)) @@ -306,116 +433,446 @@ class BasicTransformerBlock(nn.Module): return x def get_fractional_positions(indices_grid, max_pos): + n_pos_dims = indices_grid.shape[1] + assert n_pos_dims == len(max_pos), f'Number of position dimensions ({n_pos_dims}) must match max_pos length ({len(max_pos)})' fractional_positions = torch.stack( - [ - indices_grid[:, i] / max_pos[i] - for i in range(3) - ], - dim=-1, + [indices_grid[:, i] / max_pos[i] for i in range(n_pos_dims)], + axis=-1, ) return fractional_positions -def precompute_freqs_cis(indices_grid, dim, out_dtype, theta=10000.0, max_pos=[20, 2048, 2048]): - dtype = torch.float32 - device = indices_grid.device +@functools.lru_cache(maxsize=5) +def generate_freq_grid_np(positional_embedding_theta, positional_embedding_max_pos_count, inner_dim, _ = None): + theta = positional_embedding_theta + start = 1 + end = theta + + n_elem = 2 * positional_embedding_max_pos_count + pow_indices = np.power( + theta, + np.linspace( + _log_base(start, theta), + _log_base(end, theta), + inner_dim // n_elem, + dtype=np.float64, + ), + ) + return torch.tensor(pow_indices * math.pi / 2, dtype=torch.float32) + +def generate_freq_grid_pytorch(positional_embedding_theta, positional_embedding_max_pos_count, inner_dim, device): + theta = positional_embedding_theta + start = 1 + end = theta + n_elem = 2 * positional_embedding_max_pos_count + + indices = theta ** ( + torch.linspace( + math.log(start, theta), + math.log(end, theta), + inner_dim // n_elem, + device=device, + dtype=torch.float32, + ) + ) + indices = indices.to(dtype=torch.float32) + + indices = indices * math.pi / 2 + + return indices + +def generate_freqs(indices, indices_grid, max_pos, use_middle_indices_grid): + if use_middle_indices_grid: + assert(len(indices_grid.shape) == 4 and indices_grid.shape[-1] ==2) + indices_grid_start, indices_grid_end = indices_grid[..., 0], indices_grid[..., 1] + indices_grid = (indices_grid_start + indices_grid_end) / 2.0 + elif len(indices_grid.shape) == 4: + indices_grid = indices_grid[..., 0] # Get fractional positions and compute frequency indices fractional_positions = get_fractional_positions(indices_grid, max_pos) - indices = theta ** torch.linspace(0, 1, dim // 6, device=device, dtype=dtype) * math.pi / 2 + indices = indices.to(device=fractional_positions.device) - # Compute frequencies and apply cos/sin - freqs = (indices * (fractional_positions.unsqueeze(-1) * 2 - 1)).transpose(-1, -2).flatten(2) - cos_vals = freqs.cos().repeat_interleave(2, dim=-1) - sin_vals = freqs.sin().repeat_interleave(2, dim=-1) + freqs = ( + (indices * (fractional_positions.unsqueeze(-1) * 2 - 1)) + .transpose(-1, -2) + .flatten(2) + ) + return freqs - # Pad if dim is not divisible by 6 - if dim % 6 != 0: - padding_size = dim % 6 - 
cos_vals = torch.cat([torch.ones_like(cos_vals[:, :, :padding_size]), cos_vals], dim=-1) - sin_vals = torch.cat([torch.zeros_like(sin_vals[:, :, :padding_size]), sin_vals], dim=-1) +def interleaved_freqs_cis(freqs, pad_size): + cos_freq = freqs.cos().repeat_interleave(2, dim=-1) + sin_freq = freqs.sin().repeat_interleave(2, dim=-1) + if pad_size != 0: + cos_padding = torch.ones_like(cos_freq[:, :, : pad_size]) + sin_padding = torch.zeros_like(cos_freq[:, :, : pad_size]) + cos_freq = torch.cat([cos_padding, cos_freq], dim=-1) + sin_freq = torch.cat([sin_padding, sin_freq], dim=-1) + return cos_freq, sin_freq - # Reshape and extract one value per pair (since repeat_interleave duplicates each value) - cos_vals = cos_vals.reshape(*cos_vals.shape[:2], -1, 2)[..., 0].to(out_dtype) # [B, N, dim//2] - sin_vals = sin_vals.reshape(*sin_vals.shape[:2], -1, 2)[..., 0].to(out_dtype) # [B, N, dim//2] +def split_freqs_cis(freqs, pad_size, num_attention_heads): + cos_freq = freqs.cos() + sin_freq = freqs.sin() - # Build rotation matrix [[cos, -sin], [sin, cos]] and add heads dimension - freqs_cis = torch.stack([ - torch.stack([cos_vals, -sin_vals], dim=-1), - torch.stack([sin_vals, cos_vals], dim=-1) - ], dim=-2).unsqueeze(1) # [B, 1, N, dim//2, 2, 2] + if pad_size != 0: + cos_padding = torch.ones_like(cos_freq[:, :, :pad_size]) + sin_padding = torch.zeros_like(sin_freq[:, :, :pad_size]) - return freqs_cis + cos_freq = torch.concatenate([cos_padding, cos_freq], axis=-1) + sin_freq = torch.concatenate([sin_padding, sin_freq], axis=-1) + # Reshape freqs to be compatible with multi-head attention + B , T, half_HD = cos_freq.shape -class LTXVModel(torch.nn.Module): - def __init__(self, - in_channels=128, - cross_attention_dim=2048, - attention_head_dim=64, - num_attention_heads=32, + cos_freq = cos_freq.reshape(B, T, num_attention_heads, half_HD // num_attention_heads) + sin_freq = sin_freq.reshape(B, T, num_attention_heads, half_HD // num_attention_heads) - caption_channels=4096, - num_layers=28, + cos_freq = torch.swapaxes(cos_freq, 1, 2) # (B,H,T,D//2) + sin_freq = torch.swapaxes(sin_freq, 1, 2) # (B,H,T,D//2) + return cos_freq, sin_freq +class LTXBaseModel(torch.nn.Module, ABC): + """ + Abstract base class for LTX models (Lightricks Transformer models). - positional_embedding_theta=10000.0, - positional_embedding_max_pos=[20, 2048, 2048], - causal_temporal_positioning=False, - vae_scale_factors=(8, 32, 32), - dtype=None, device=None, operations=None, **kwargs): + This class defines the common interface and shared functionality for all LTX models, + including LTXV (video) and LTXAV (audio-video) variants. 
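+
+    Subclasses supply the model-specific pieces (`_init_model_components`, `_init_transformer_blocks`,
+    `_init_output_components`, `_process_input`, `_process_transformer_blocks`, `_process_output`), while this
+    base class owns the shared patchify projection, AdaLN timestep embedding, caption projection and the
+    common forward flow.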
+ """ + + def __init__( + self, + in_channels: int, + cross_attention_dim: int, + attention_head_dim: int, + num_attention_heads: int, + caption_channels: int, + num_layers: int, + positional_embedding_theta: float = 10000.0, + positional_embedding_max_pos: list = [20, 2048, 2048], + causal_temporal_positioning: bool = False, + vae_scale_factors: tuple = (8, 32, 32), + use_middle_indices_grid=False, + timestep_scale_multiplier = 1000.0, + dtype=None, + device=None, + operations=None, + **kwargs, + ): super().__init__() self.generator = None self.vae_scale_factors = vae_scale_factors + self.use_middle_indices_grid = use_middle_indices_grid self.dtype = dtype - self.out_channels = in_channels - self.inner_dim = num_attention_heads * attention_head_dim + self.in_channels = in_channels + self.cross_attention_dim = cross_attention_dim + self.attention_head_dim = attention_head_dim + self.num_attention_heads = num_attention_heads + self.caption_channels = caption_channels + self.num_layers = num_layers + self.positional_embedding_theta = positional_embedding_theta + self.positional_embedding_max_pos = positional_embedding_max_pos + self.split_positional_embedding = LTXRopeType.from_dict(kwargs) + self.freq_grid_generator = ( + generate_freq_grid_np if LTXFrequenciesPrecision.from_dict(kwargs) == LTXFrequenciesPrecision.FLOAT64 + else generate_freq_grid_pytorch + ) self.causal_temporal_positioning = causal_temporal_positioning + self.operations = operations + self.timestep_scale_multiplier = timestep_scale_multiplier - self.patchify_proj = operations.Linear(in_channels, self.inner_dim, bias=True, dtype=dtype, device=device) + # Common dimensions + self.inner_dim = num_attention_heads * attention_head_dim + self.out_channels = in_channels + + # Initialize common components + self._init_common_components(device, dtype) + + # Initialize model-specific components + self._init_model_components(device, dtype, **kwargs) + + # Initialize transformer blocks + self._init_transformer_blocks(device, dtype, **kwargs) + + # Initialize output components + self._init_output_components(device, dtype) + + def _init_common_components(self, device, dtype): + """Initialize components common to all LTX models + - patchify_proj: Linear projection for patchifying input + - adaln_single: AdaLN layer for timestep embedding + - caption_projection: Linear projection for caption embedding + """ + self.patchify_proj = self.operations.Linear( + self.in_channels, self.inner_dim, bias=True, dtype=dtype, device=device + ) self.adaln_single = AdaLayerNormSingle( - self.inner_dim, use_additional_conditions=False, dtype=dtype, device=device, operations=operations + self.inner_dim, use_additional_conditions=False, dtype=dtype, device=device, operations=self.operations ) - # self.adaln_single.linear = operations.Linear(self.inner_dim, 4 * self.inner_dim, bias=True, dtype=dtype, device=device) - self.caption_projection = PixArtAlphaTextProjection( - in_features=caption_channels, hidden_size=self.inner_dim, dtype=dtype, device=device, operations=operations + in_features=self.caption_channels, + hidden_size=self.inner_dim, + dtype=dtype, + device=device, + operations=self.operations, ) + @abstractmethod + def _init_model_components(self, device, dtype, **kwargs): + """Initialize model-specific components. Must be implemented by subclasses.""" + pass + + @abstractmethod + def _init_transformer_blocks(self, device, dtype, **kwargs): + """Initialize transformer blocks. 
Must be implemented by subclasses.""" + pass + + @abstractmethod + def _init_output_components(self, device, dtype): + """Initialize output components. Must be implemented by subclasses.""" + pass + + @abstractmethod + def _process_input(self, x, keyframe_idxs, denoise_mask, **kwargs): + """Process input data. Must be implemented by subclasses.""" + pass + + @abstractmethod + def _process_transformer_blocks(self, x, context, attention_mask, timestep, pe, **kwargs): + """Process transformer blocks. Must be implemented by subclasses.""" + pass + + @abstractmethod + def _process_output(self, x, embedded_timestep, keyframe_idxs, **kwargs): + """Process output data. Must be implemented by subclasses.""" + pass + + def _prepare_timestep(self, timestep, batch_size, hidden_dtype, **kwargs): + """Prepare timestep embeddings.""" + grid_mask = kwargs.get("grid_mask", None) + if grid_mask is not None: + timestep = timestep[:, grid_mask] + + timestep = timestep * self.timestep_scale_multiplier + timestep, embedded_timestep = self.adaln_single( + timestep.flatten(), + {"resolution": None, "aspect_ratio": None}, + batch_size=batch_size, + hidden_dtype=hidden_dtype, + ) + + # Second dimension is 1 or number of tokens (if timestep_per_token) + timestep = timestep.view(batch_size, -1, timestep.shape[-1]) + embedded_timestep = embedded_timestep.view(batch_size, -1, embedded_timestep.shape[-1]) + + return timestep, embedded_timestep + + def _prepare_context(self, context, batch_size, x, attention_mask=None): + """Prepare context for transformer blocks.""" + if self.caption_projection is not None: + context = self.caption_projection(context) + context = context.view(batch_size, -1, x.shape[-1]) + + return context, attention_mask + + def _precompute_freqs_cis( + self, + indices_grid, + dim, + out_dtype, + theta=10000.0, + max_pos=[20, 2048, 2048], + use_middle_indices_grid=False, + num_attention_heads=32, + ): + split_mode = self.split_positional_embedding == LTXRopeType.SPLIT + indices = self.freq_grid_generator(theta, indices_grid.shape[1], dim, indices_grid.device) + freqs = generate_freqs(indices, indices_grid, max_pos, use_middle_indices_grid) + + if split_mode: + expected_freqs = dim // 2 + current_freqs = freqs.shape[-1] + pad_size = expected_freqs - current_freqs + cos_freq, sin_freq = split_freqs_cis(freqs, pad_size, num_attention_heads) + else: + # 2 because of cos and sin by 3 for (t, x, y), 1 for temporal only + n_elem = 2 * indices_grid.shape[1] + cos_freq, sin_freq = interleaved_freqs_cis(freqs, dim % n_elem) + return cos_freq.to(out_dtype), sin_freq.to(out_dtype), split_mode + + def _prepare_positional_embeddings(self, pixel_coords, frame_rate, x_dtype): + """Prepare positional embeddings.""" + fractional_coords = pixel_coords.to(torch.float32) + fractional_coords[:, 0] = fractional_coords[:, 0] * (1.0 / frame_rate) + pe = self._precompute_freqs_cis( + fractional_coords, + dim=self.inner_dim, + out_dtype=x_dtype, + max_pos=self.positional_embedding_max_pos, + use_middle_indices_grid=self.use_middle_indices_grid, + num_attention_heads=self.num_attention_heads, + ) + return pe + + def _prepare_attention_mask(self, attention_mask, x_dtype): + """Prepare attention mask.""" + if attention_mask is not None and not torch.is_floating_point(attention_mask): + attention_mask = (attention_mask - 1).to(x_dtype).reshape( + (attention_mask.shape[0], 1, -1, attention_mask.shape[-1]) + ) * torch.finfo(x_dtype).max + return attention_mask + + def forward( + self, x, timestep, context, attention_mask, 
frame_rate=25, transformer_options={}, keyframe_idxs=None, denoise_mask=None, **kwargs + ): + """ + Forward pass for LTX models. + + Args: + x: Input tensor + timestep: Timestep tensor + context: Context tensor (e.g., text embeddings) + attention_mask: Attention mask tensor + frame_rate: Frame rate for temporal processing + transformer_options: Additional options for transformer blocks + keyframe_idxs: Keyframe indices for temporal processing + **kwargs: Additional keyword arguments + + Returns: + Processed output tensor + """ + return comfy.patcher_extension.WrapperExecutor.new_class_executor( + self._forward, + self, + comfy.patcher_extension.get_all_wrappers( + comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options + ), + ).execute(x, timestep, context, attention_mask, frame_rate, transformer_options, keyframe_idxs, denoise_mask=denoise_mask, **kwargs) + + def _forward( + self, x, timestep, context, attention_mask, frame_rate=25, transformer_options={}, keyframe_idxs=None, denoise_mask=None, **kwargs + ): + """ + Internal forward pass for LTX models. + + Args: + x: Input tensor + timestep: Timestep tensor + context: Context tensor (e.g., text embeddings) + attention_mask: Attention mask tensor + frame_rate: Frame rate for temporal processing + transformer_options: Additional options for transformer blocks + keyframe_idxs: Keyframe indices for temporal processing + **kwargs: Additional keyword arguments + + Returns: + Processed output tensor + """ + if isinstance(x, list): + input_dtype = x[0].dtype + batch_size = x[0].shape[0] + else: + input_dtype = x.dtype + batch_size = x.shape[0] + # Process input + merged_args = {**transformer_options, **kwargs} + x, pixel_coords, additional_args = self._process_input(x, keyframe_idxs, denoise_mask, **merged_args) + merged_args.update(additional_args) + + # Prepare timestep and context + timestep, embedded_timestep = self._prepare_timestep(timestep, batch_size, input_dtype, **merged_args) + context, attention_mask = self._prepare_context(context, batch_size, x, attention_mask) + + # Prepare attention mask and positional embeddings + attention_mask = self._prepare_attention_mask(attention_mask, input_dtype) + pe = self._prepare_positional_embeddings(pixel_coords, frame_rate, input_dtype) + + # Process transformer blocks + x = self._process_transformer_blocks( + x, context, attention_mask, timestep, pe, transformer_options=transformer_options, **merged_args + ) + + # Process output + x = self._process_output(x, embedded_timestep, keyframe_idxs, **merged_args) + return x + + +class LTXVModel(LTXBaseModel): + """LTXV model for video generation.""" + + def __init__( + self, + in_channels=128, + cross_attention_dim=2048, + attention_head_dim=64, + num_attention_heads=32, + caption_channels=4096, + num_layers=28, + positional_embedding_theta=10000.0, + positional_embedding_max_pos=[20, 2048, 2048], + causal_temporal_positioning=False, + vae_scale_factors=(8, 32, 32), + use_middle_indices_grid=False, + timestep_scale_multiplier = 1000.0, + dtype=None, + device=None, + operations=None, + **kwargs, + ): + super().__init__( + in_channels=in_channels, + cross_attention_dim=cross_attention_dim, + attention_head_dim=attention_head_dim, + num_attention_heads=num_attention_heads, + caption_channels=caption_channels, + num_layers=num_layers, + positional_embedding_theta=positional_embedding_theta, + positional_embedding_max_pos=positional_embedding_max_pos, + causal_temporal_positioning=causal_temporal_positioning, + 
vae_scale_factors=vae_scale_factors, + use_middle_indices_grid=use_middle_indices_grid, + timestep_scale_multiplier=timestep_scale_multiplier, + dtype=dtype, + device=device, + operations=operations, + **kwargs, + ) + + def _init_model_components(self, device, dtype, **kwargs): + """Initialize LTXV-specific components.""" + # No additional components needed for LTXV beyond base class + pass + + def _init_transformer_blocks(self, device, dtype, **kwargs): + """Initialize transformer blocks for LTXV.""" self.transformer_blocks = nn.ModuleList( [ BasicTransformerBlock( self.inner_dim, - num_attention_heads, - attention_head_dim, - context_dim=cross_attention_dim, - # attn_precision=attn_precision, - dtype=dtype, device=device, operations=operations + self.num_attention_heads, + self.attention_head_dim, + context_dim=self.cross_attention_dim, + dtype=dtype, + device=device, + operations=self.operations, ) - for d in range(num_layers) + for _ in range(self.num_layers) ] ) + def _init_output_components(self, device, dtype): + """Initialize output components for LTXV.""" self.scale_shift_table = nn.Parameter(torch.empty(2, self.inner_dim, dtype=dtype, device=device)) - self.norm_out = operations.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) - self.proj_out = operations.Linear(self.inner_dim, self.out_channels, dtype=dtype, device=device) - - self.patchifier = SymmetricPatchifier(1) - - def forward(self, x, timestep, context, attention_mask, frame_rate=25, transformer_options={}, keyframe_idxs=None, **kwargs): - return comfy.patcher_extension.WrapperExecutor.new_class_executor( - self._forward, - self, - comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options) - ).execute(x, timestep, context, attention_mask, frame_rate, transformer_options, keyframe_idxs, **kwargs) - - def _forward(self, x, timestep, context, attention_mask, frame_rate=25, transformer_options={}, keyframe_idxs=None, **kwargs): - patches_replace = transformer_options.get("patches_replace", {}) - - orig_shape = list(x.shape) + self.norm_out = self.operations.LayerNorm( + self.inner_dim, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device + ) + self.proj_out = self.operations.Linear(self.inner_dim, self.out_channels, dtype=dtype, device=device) + self.patchifier = SymmetricPatchifier(1, start_end=True) + def _process_input(self, x, keyframe_idxs, denoise_mask, **kwargs): + """Process input for LTXV.""" + additional_args = {"orig_shape": list(x.shape)} x, latent_coords = self.patchifier.patchify(x) pixel_coords = latent_to_pixel_coords( latent_coords=latent_coords, @@ -423,44 +880,30 @@ class LTXVModel(torch.nn.Module): causal_fix=self.causal_temporal_positioning, ) + grid_mask = None if keyframe_idxs is not None: - pixel_coords[:, :, -keyframe_idxs.shape[2]:] = keyframe_idxs + additional_args.update({ "orig_patchified_shape": list(x.shape)}) + denoise_mask = self.patchifier.patchify(denoise_mask)[0] + grid_mask = ~torch.any(denoise_mask < 0, dim=-1)[0] + additional_args.update({"grid_mask": grid_mask}) + x = x[:, grid_mask, :] + pixel_coords = pixel_coords[:, :, grid_mask, ...] 
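+            # Keyframe tokens sit at the tail of the sequence; filter their coordinates with the same
+            # grid mask before writing them back into pixel_coords below.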
- fractional_coords = pixel_coords.to(torch.float32) - fractional_coords[:, 0] = fractional_coords[:, 0] * (1.0 / frame_rate) + kf_grid_mask = grid_mask[-keyframe_idxs.shape[2]:] + keyframe_idxs = keyframe_idxs[..., kf_grid_mask, :] + pixel_coords[:, :, -keyframe_idxs.shape[2]:, :] = keyframe_idxs x = self.patchify_proj(x) - timestep = timestep * 1000.0 - - if attention_mask is not None and not torch.is_floating_point(attention_mask): - attention_mask = (attention_mask - 1).to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])) * torch.finfo(x.dtype).max - - pe = precompute_freqs_cis(fractional_coords, dim=self.inner_dim, out_dtype=x.dtype) - - batch_size = x.shape[0] - timestep, embedded_timestep = self.adaln_single( - timestep.flatten(), - {"resolution": None, "aspect_ratio": None}, - batch_size=batch_size, - hidden_dtype=x.dtype, - ) - # Second dimension is 1 or number of tokens (if timestep_per_token) - timestep = timestep.view(batch_size, -1, timestep.shape[-1]) - embedded_timestep = embedded_timestep.view( - batch_size, -1, embedded_timestep.shape[-1] - ) - - # 2. Blocks - if self.caption_projection is not None: - batch_size = x.shape[0] - context = self.caption_projection(context) - context = context.view( - batch_size, -1, x.shape[-1] - ) + return x, pixel_coords, additional_args + def _process_transformer_blocks(self, x, context, attention_mask, timestep, pe, transformer_options={}, **kwargs): + """Process transformer blocks for LTXV.""" + patches_replace = transformer_options.get("patches_replace", {}) blocks_replace = patches_replace.get("dit", {}) + for i, block in enumerate(self.transformer_blocks): if ("double_block", i) in blocks_replace: + def block_wrap(args): out = {} out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"], transformer_options=args["transformer_options"]) @@ -478,16 +921,28 @@ class LTXVModel(torch.nn.Module): transformer_options=transformer_options, ) - # 3. Output + return x + + def _process_output(self, x, embedded_timestep, keyframe_idxs, **kwargs): + """Process output for LTXV.""" + # Apply scale-shift modulation scale_shift_values = ( self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + embedded_timestep[:, :, None] ) shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1] + x = self.norm_out(x) - # Modulation - x = torch.addcmul(x, x, scale).add_(shift) + x = x * (1 + scale) + shift x = self.proj_out(x) + if keyframe_idxs is not None: + grid_mask = kwargs["grid_mask"] + orig_patchified_shape = kwargs["orig_patchified_shape"] + full_x = torch.zeros(orig_patchified_shape, dtype=x.dtype, device=x.device) + full_x[:, grid_mask, :] = x + x = full_x + # Unpatchify to restore original dimensions + orig_shape = kwargs["orig_shape"] x = self.patchifier.unpatchify( latents=x, output_height=orig_shape[3], diff --git a/comfy/ldm/lightricks/symmetric_patchifier.py b/comfy/ldm/lightricks/symmetric_patchifier.py index 4b9972b9f..8f9a41186 100644 --- a/comfy/ldm/lightricks/symmetric_patchifier.py +++ b/comfy/ldm/lightricks/symmetric_patchifier.py @@ -21,20 +21,23 @@ def latent_to_pixel_coords( Returns: Tensor: A tensor of pixel coordinates corresponding to the input latent coordinates. 
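+
+    Latent coordinates may carry an extra trailing dimension (e.g. start/end pairs produced by a
+    patchifier constructed with `start_end=True`); the scaling broadcasts over it unchanged.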
""" + shape = [1] * latent_coords.ndim + shape[1] = -1 pixel_coords = ( latent_coords - * torch.tensor(scale_factors, device=latent_coords.device)[None, :, None] + * torch.tensor(scale_factors, device=latent_coords.device).view(*shape) ) if causal_fix: # Fix temporal scale for first frame to 1 due to causality - pixel_coords[:, 0] = (pixel_coords[:, 0] + 1 - scale_factors[0]).clamp(min=0) + pixel_coords[:, 0, ...] = (pixel_coords[:, 0, ...] + 1 - scale_factors[0]).clamp(min=0) return pixel_coords class Patchifier(ABC): - def __init__(self, patch_size: int): + def __init__(self, patch_size: int, start_end: bool=False): super().__init__() self._patch_size = (1, patch_size, patch_size) + self.start_end = start_end @abstractmethod def patchify( @@ -71,11 +74,23 @@ class Patchifier(ABC): torch.arange(0, latent_width, self._patch_size[2], device=device), indexing="ij", ) - latent_sample_coords = torch.stack(latent_sample_coords, dim=0) - latent_coords = latent_sample_coords.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1) - latent_coords = rearrange( - latent_coords, "b c f h w -> b c (f h w)", b=batch_size + latent_sample_coords_start = torch.stack(latent_sample_coords, dim=0) + delta = torch.tensor(self._patch_size, device=latent_sample_coords_start.device, dtype=latent_sample_coords_start.dtype)[:, None, None, None] + latent_sample_coords_end = latent_sample_coords_start + delta + + latent_sample_coords_start = latent_sample_coords_start.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1) + latent_sample_coords_start = rearrange( + latent_sample_coords_start, "b c f h w -> b c (f h w)", b=batch_size ) + if self.start_end: + latent_sample_coords_end = latent_sample_coords_end.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1) + latent_sample_coords_end = rearrange( + latent_sample_coords_end, "b c f h w -> b c (f h w)", b=batch_size + ) + + latent_coords = torch.stack((latent_sample_coords_start, latent_sample_coords_end), dim=-1) + else: + latent_coords = latent_sample_coords_start return latent_coords @@ -115,3 +130,61 @@ class SymmetricPatchifier(Patchifier): q=self._patch_size[2], ) return latents + + +class AudioPatchifier(Patchifier): + def __init__(self, patch_size: int, + sample_rate=16000, + hop_length=160, + audio_latent_downsample_factor=4, + is_causal=True, + start_end=False, + shift = 0 + ): + super().__init__(patch_size, start_end=start_end) + self.hop_length = hop_length + self.sample_rate = sample_rate + self.audio_latent_downsample_factor = audio_latent_downsample_factor + self.is_causal = is_causal + self.shift = shift + + def copy_with_shift(self, shift): + return AudioPatchifier( + self.patch_size, self.sample_rate, self.hop_length, self.audio_latent_downsample_factor, + self.is_causal, self.start_end, shift + ) + + def _get_audio_latent_time_in_sec(self, start_latent, end_latent: int, dtype: torch.dtype, device=torch.device): + audio_latent_frame = torch.arange(start_latent, end_latent, dtype=dtype, device=device) + audio_mel_frame = audio_latent_frame * self.audio_latent_downsample_factor + if self.is_causal: + audio_mel_frame = (audio_mel_frame + 1 - self.audio_latent_downsample_factor).clip(min=0) + return audio_mel_frame * self.hop_length / self.sample_rate + + + def patchify(self, audio_latents: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + # audio_latents: (batch, channels, time, freq) + b, _, t, _ = audio_latents.shape + audio_latents = rearrange( + audio_latents, + "b c t f -> b t (c f)", + ) + + audio_latents_start_timings = self._get_audio_latent_time_in_sec(self.shift, t + 
self.shift, torch.float32, audio_latents.device) + audio_latents_start_timings = audio_latents_start_timings.unsqueeze(0).expand(b, -1).unsqueeze(1) + + if self.start_end: + audio_latents_end_timings = self._get_audio_latent_time_in_sec(self.shift + 1, t + self.shift + 1, torch.float32, audio_latents.device) + audio_latents_end_timings = audio_latents_end_timings.unsqueeze(0).expand(b, -1).unsqueeze(1) + + audio_latents_timings = torch.stack([audio_latents_start_timings, audio_latents_end_timings], dim=-1) + else: + audio_latents_timings = audio_latents_start_timings + return audio_latents, audio_latents_timings + + def unpatchify(self, audio_latents: torch.Tensor, channels: int, freq: int) -> torch.Tensor: + # audio_latents: (batch, time, freq * channels) + audio_latents = rearrange( + audio_latents, "b t (c f) -> b c t f", c=channels, f=freq + ) + return audio_latents diff --git a/comfy/ldm/lightricks/vae/audio_vae.py b/comfy/ldm/lightricks/vae/audio_vae.py new file mode 100644 index 000000000..a9111d3bd --- /dev/null +++ b/comfy/ldm/lightricks/vae/audio_vae.py @@ -0,0 +1,286 @@ +import json +from dataclasses import dataclass +import math +import torch +import torchaudio + +import comfy.model_management +import comfy.model_patcher +import comfy.utils as utils +from comfy.ldm.mmaudio.vae.distributions import DiagonalGaussianDistribution +from comfy.ldm.lightricks.symmetric_patchifier import AudioPatchifier +from comfy.ldm.lightricks.vae.causal_audio_autoencoder import ( + CausalityAxis, + CausalAudioAutoencoder, +) +from comfy.ldm.lightricks.vocoders.vocoder import Vocoder + +LATENT_DOWNSAMPLE_FACTOR = 4 + + +@dataclass(frozen=True) +class AudioVAEComponentConfig: + """Container for model component configuration extracted from metadata.""" + + autoencoder: dict + vocoder: dict + + @classmethod + def from_metadata(cls, metadata: dict) -> "AudioVAEComponentConfig": + assert metadata is not None and "config" in metadata, "Metadata is required for audio VAE" + + raw_config = metadata["config"] + if isinstance(raw_config, str): + parsed_config = json.loads(raw_config) + else: + parsed_config = raw_config + + audio_config = parsed_config.get("audio_vae") + vocoder_config = parsed_config.get("vocoder") + + assert audio_config is not None, "Audio VAE config is required for audio VAE" + assert vocoder_config is not None, "Vocoder config is required for audio VAE" + + return cls(autoencoder=audio_config, vocoder=vocoder_config) + + +class ModelDeviceManager: + """Manages device placement and GPU residency for the composed model.""" + + def __init__(self, module: torch.nn.Module): + load_device = comfy.model_management.get_torch_device() + offload_device = comfy.model_management.vae_offload_device() + self.patcher = comfy.model_patcher.ModelPatcher(module, load_device, offload_device) + + def ensure_model_loaded(self) -> None: + comfy.model_management.free_memory( + self.patcher.model_size(), + self.patcher.load_device, + ) + comfy.model_management.load_model_gpu(self.patcher) + + def move_to_load_device(self, tensor: torch.Tensor) -> torch.Tensor: + return tensor.to(self.patcher.load_device) + + @property + def load_device(self): + return self.patcher.load_device + + +class AudioLatentNormalizer: + """Applies per-channel statistics in patch space and restores original layout.""" + + def __init__(self, patchfier: AudioPatchifier, statistics_processor: torch.nn.Module): + self.patchifier = patchfier + self.statistics = statistics_processor + + def normalize(self, latents: torch.Tensor) -> 
torch.Tensor: + channels = latents.shape[1] + freq = latents.shape[3] + patched, _ = self.patchifier.patchify(latents) + normalized = self.statistics.normalize(patched) + return self.patchifier.unpatchify(normalized, channels=channels, freq=freq) + + def denormalize(self, latents: torch.Tensor) -> torch.Tensor: + channels = latents.shape[1] + freq = latents.shape[3] + patched, _ = self.patchifier.patchify(latents) + denormalized = self.statistics.un_normalize(patched) + return self.patchifier.unpatchify(denormalized, channels=channels, freq=freq) + + +class AudioPreprocessor: + """Prepares raw waveforms for the autoencoder by matching training conditions.""" + + def __init__(self, target_sample_rate: int, mel_bins: int, mel_hop_length: int, n_fft: int): + self.target_sample_rate = target_sample_rate + self.mel_bins = mel_bins + self.mel_hop_length = mel_hop_length + self.n_fft = n_fft + + def resample(self, waveform: torch.Tensor, source_rate: int) -> torch.Tensor: + if source_rate == self.target_sample_rate: + return waveform + return torchaudio.functional.resample(waveform, source_rate, self.target_sample_rate) + + @staticmethod + def normalize_amplitude( + waveform: torch.Tensor, max_amplitude: float = 0.5, eps: float = 1e-5 + ) -> torch.Tensor: + waveform = waveform - waveform.mean(dim=2, keepdim=True) + peak = torch.max(torch.abs(waveform)) + eps + scale = peak.clamp(max=max_amplitude) / peak + return waveform * scale + + def waveform_to_mel( + self, waveform: torch.Tensor, waveform_sample_rate: int, device + ) -> torch.Tensor: + waveform = self.resample(waveform, waveform_sample_rate) + waveform = self.normalize_amplitude(waveform) + + mel_transform = torchaudio.transforms.MelSpectrogram( + sample_rate=self.target_sample_rate, + n_fft=self.n_fft, + win_length=self.n_fft, + hop_length=self.mel_hop_length, + f_min=0.0, + f_max=self.target_sample_rate / 2.0, + n_mels=self.mel_bins, + window_fn=torch.hann_window, + center=True, + pad_mode="reflect", + power=1.0, + mel_scale="slaney", + norm="slaney", + ).to(device) + + mel = mel_transform(waveform) + mel = torch.log(torch.clamp(mel, min=1e-5)) + return mel.permute(0, 1, 3, 2).contiguous() + + +class AudioVAE(torch.nn.Module): + """High-level Audio VAE wrapper exposing encode and decode entry points.""" + + def __init__(self, state_dict: dict, metadata: dict): + super().__init__() + + component_config = AudioVAEComponentConfig.from_metadata(metadata) + + vae_sd = utils.state_dict_prefix_replace(state_dict, {"audio_vae.": ""}, filter_keys=True) + vocoder_sd = utils.state_dict_prefix_replace(state_dict, {"vocoder.": ""}, filter_keys=True) + + self.autoencoder = CausalAudioAutoencoder(config=component_config.autoencoder) + self.vocoder = Vocoder(config=component_config.vocoder) + + self.autoencoder.load_state_dict(vae_sd, strict=False) + self.vocoder.load_state_dict(vocoder_sd, strict=False) + + autoencoder_config = self.autoencoder.get_config() + self.normalizer = AudioLatentNormalizer( + AudioPatchifier( + patch_size=1, + audio_latent_downsample_factor=LATENT_DOWNSAMPLE_FACTOR, + sample_rate=autoencoder_config["sampling_rate"], + hop_length=autoencoder_config["mel_hop_length"], + is_causal=autoencoder_config["is_causal"], + ), + self.autoencoder.per_channel_statistics, + ) + + self.preprocessor = AudioPreprocessor( + target_sample_rate=autoencoder_config["sampling_rate"], + mel_bins=autoencoder_config["mel_bins"], + mel_hop_length=autoencoder_config["mel_hop_length"], + n_fft=autoencoder_config["n_fft"], + ) + + self.device_manager = 
ModelDeviceManager(self) + + def encode(self, audio: dict) -> torch.Tensor: + """Encode a waveform dictionary into normalized latent tensors.""" + + waveform = audio["waveform"] + waveform_sample_rate = audio["sample_rate"] + input_device = waveform.device + # Ensure that Audio VAE is loaded on the correct device. + self.device_manager.ensure_model_loaded() + + waveform = self.device_manager.move_to_load_device(waveform) + expected_channels = self.autoencoder.encoder.in_channels + if waveform.shape[1] != expected_channels: + raise ValueError( + f"Input audio must have {expected_channels} channels, got {waveform.shape[1]}" + ) + + mel_spec = self.preprocessor.waveform_to_mel( + waveform, waveform_sample_rate, device=self.device_manager.load_device + ) + + latents = self.autoencoder.encode(mel_spec) + posterior = DiagonalGaussianDistribution(latents) + latent_mode = posterior.mode() + + normalized = self.normalizer.normalize(latent_mode) + return normalized.to(input_device) + + def decode(self, latents: torch.Tensor) -> torch.Tensor: + """Decode normalized latent tensors into an audio waveform.""" + original_shape = latents.shape + + # Ensure that Audio VAE is loaded on the correct device. + self.device_manager.ensure_model_loaded() + + latents = self.device_manager.move_to_load_device(latents) + latents = self.normalizer.denormalize(latents) + + target_shape = self.target_shape_from_latents(original_shape) + mel_spec = self.autoencoder.decode(latents, target_shape=target_shape) + + waveform = self.run_vocoder(mel_spec) + return self.device_manager.move_to_load_device(waveform) + + def target_shape_from_latents(self, latents_shape): + batch, _, time, _ = latents_shape + target_length = time * LATENT_DOWNSAMPLE_FACTOR + if self.autoencoder.causality_axis != CausalityAxis.NONE: + target_length -= LATENT_DOWNSAMPLE_FACTOR - 1 + return ( + batch, + self.autoencoder.decoder.out_ch, + target_length, + self.autoencoder.mel_bins, + ) + + def num_of_latents_from_frames(self, frames_number: int, frame_rate: int) -> int: + return math.ceil((float(frames_number) / frame_rate) * self.latents_per_second) + + def run_vocoder(self, mel_spec: torch.Tensor) -> torch.Tensor: + audio_channels = self.autoencoder.decoder.out_ch + vocoder_input = mel_spec.transpose(2, 3) + + if audio_channels == 1: + vocoder_input = vocoder_input.squeeze(1) + elif audio_channels != 2: + raise ValueError(f"Unsupported audio_channels: {audio_channels}") + + return self.vocoder(vocoder_input) + + @property + def sample_rate(self) -> int: + return int(self.autoencoder.sampling_rate) + + @property + def mel_hop_length(self) -> int: + return int(self.autoencoder.mel_hop_length) + + @property + def mel_bins(self) -> int: + return int(self.autoencoder.mel_bins) + + @property + def latent_channels(self) -> int: + return int(self.autoencoder.decoder.z_channels) + + @property + def latent_frequency_bins(self) -> int: + return int(self.mel_bins // LATENT_DOWNSAMPLE_FACTOR) + + @property + def latents_per_second(self) -> float: + return self.sample_rate / self.mel_hop_length / LATENT_DOWNSAMPLE_FACTOR + + @property + def output_sample_rate(self) -> int: + output_rate = getattr(self.vocoder, "output_sample_rate", None) + if output_rate is not None: + return int(output_rate) + upsample_factor = getattr(self.vocoder, "upsample_factor", None) + if upsample_factor is None: + raise AttributeError( + "Vocoder is missing upsample_factor; cannot infer output sample rate" + ) + return int(self.sample_rate * upsample_factor / self.mel_hop_length) + + def 
memory_required(self, input_shape): + return self.device_manager.patcher.model_size() diff --git a/comfy/ldm/lightricks/vae/causal_audio_autoencoder.py b/comfy/ldm/lightricks/vae/causal_audio_autoencoder.py new file mode 100644 index 000000000..f12b9bb53 --- /dev/null +++ b/comfy/ldm/lightricks/vae/causal_audio_autoencoder.py @@ -0,0 +1,909 @@ +from __future__ import annotations +import torch +from torch import nn +from torch.nn import functional as F +from typing import Optional +from enum import Enum +from .pixel_norm import PixelNorm +import comfy.ops +import logging + +ops = comfy.ops.disable_weight_init + + +class StringConvertibleEnum(Enum): + """ + Base enum class that provides string-to-enum conversion functionality. + + This mixin adds a str_to_enum() class method that handles conversion from + strings, None, or existing enum instances with case-insensitive matching. + """ + + @classmethod + def str_to_enum(cls, value): + """ + Convert a string, enum instance, or None to the appropriate enum member. + + Args: + value: Can be an enum instance of this class, a string, or None + + Returns: + Enum member of this class + + Raises: + ValueError: If the value cannot be converted to a valid enum member + """ + # Already an enum instance of this class + if isinstance(value, cls): + return value + + # None maps to NONE member if it exists + if value is None: + if hasattr(cls, "NONE"): + return cls.NONE + raise ValueError(f"{cls.__name__} does not have a NONE member to map None to") + + # String conversion (case-insensitive) + if isinstance(value, str): + value_lower = value.lower() + + # Try to match against enum values + for member in cls: + # Handle members with None values + if member.value is None: + if value_lower == "none": + return member + # Handle members with string values + elif isinstance(member.value, str) and member.value.lower() == value_lower: + return member + + # Build helpful error message with valid values + valid_values = [] + for member in cls: + if member.value is None: + valid_values.append("none") + elif isinstance(member.value, str): + valid_values.append(member.value) + + raise ValueError(f"Invalid {cls.__name__} string: '{value}'. " f"Valid values are: {valid_values}") + + raise ValueError( + f"Cannot convert type {type(value).__name__} to {cls.__name__} enum. " + f"Expected string, None, or {cls.__name__} instance." + ) + + +class AttentionType(StringConvertibleEnum): + """Enum for specifying the attention mechanism type.""" + + VANILLA = "vanilla" + LINEAR = "linear" + NONE = "none" + + +class CausalityAxis(StringConvertibleEnum): + """Enum for specifying the causality axis in causal convolutions.""" + + NONE = None + WIDTH = "width" + HEIGHT = "height" + WIDTH_COMPATIBILITY = "width-compatibility" + + +def Normalize(in_channels, *, num_groups=32, normtype="group"): + if normtype == "group": + return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True) + elif normtype == "pixel": + return PixelNorm(dim=1, eps=1e-6) + else: + raise ValueError(f"Invalid normalization type: {normtype}") + + +class CausalConv2d(nn.Module): + """ + A causal 2D convolution. + + This layer ensures that the output at time `t` only depends on inputs + at time `t` and earlier. It achieves this by applying asymmetric padding + to the time dimension (width) before the convolution. 
+ """ + + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + dilation=1, + groups=1, + bias=True, + causality_axis: CausalityAxis = CausalityAxis.HEIGHT, + ): + super().__init__() + + self.causality_axis = causality_axis + + # Ensure kernel_size and dilation are tuples + kernel_size = nn.modules.utils._pair(kernel_size) + dilation = nn.modules.utils._pair(dilation) + + # Calculate padding dimensions + pad_h = (kernel_size[0] - 1) * dilation[0] + pad_w = (kernel_size[1] - 1) * dilation[1] + + # The padding tuple for F.pad is (pad_left, pad_right, pad_top, pad_bottom) + match self.causality_axis: + case CausalityAxis.NONE: + self.padding = (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2) + case CausalityAxis.WIDTH | CausalityAxis.WIDTH_COMPATIBILITY: + self.padding = (pad_w, 0, pad_h // 2, pad_h - pad_h // 2) + case CausalityAxis.HEIGHT: + self.padding = (pad_w // 2, pad_w - pad_w // 2, pad_h, 0) + case _: + raise ValueError(f"Invalid causality_axis: {causality_axis}") + + # The internal convolution layer uses no padding, as we handle it manually + self.conv = ops.Conv2d( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=0, + dilation=dilation, + groups=groups, + bias=bias, + ) + + def forward(self, x): + # Apply causal padding before convolution + x = F.pad(x, self.padding) + return self.conv(x) + + +def make_conv2d( + in_channels, + out_channels, + kernel_size, + stride=1, + padding=None, + dilation=1, + groups=1, + bias=True, + causality_axis: Optional[CausalityAxis] = None, +): + """ + Create a 2D convolution layer that can be either causal or non-causal. + + Args: + in_channels: Number of input channels + out_channels: Number of output channels + kernel_size: Size of the convolution kernel + stride: Convolution stride + padding: Padding (if None, will be calculated based on causal flag) + dilation: Dilation rate + groups: Number of groups for grouped convolution + bias: Whether to use bias + causality_axis: Dimension along which to apply causality. + + Returns: + Either a regular Conv2d or CausalConv2d layer + """ + if causality_axis is not None: + # For causal convolution, padding is handled internally by CausalConv2d + return CausalConv2d(in_channels, out_channels, kernel_size, stride, dilation, groups, bias, causality_axis) + else: + # For non-causal convolution, use symmetric padding if not specified + if padding is None: + if isinstance(kernel_size, int): + padding = kernel_size // 2 + else: + padding = tuple(k // 2 for k in kernel_size) + return ops.Conv2d( + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + groups, + bias, + ) + + +class Upsample(nn.Module): + def __init__(self, in_channels, with_conv, causality_axis: CausalityAxis = CausalityAxis.HEIGHT): + super().__init__() + self.with_conv = with_conv + self.causality_axis = causality_axis + if self.with_conv: + self.conv = make_conv2d(in_channels, in_channels, kernel_size=3, stride=1, causality_axis=causality_axis) + + def forward(self, x): + x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + if self.with_conv: + x = self.conv(x) + # Drop FIRST element in the causal axis to undo encoder's padding, while keeping the length 1 + 2 * n. + # For example, if the input is [0, 1, 2], after interpolation, the output is [0, 0, 1, 1, 2, 2]. 
+ # The causal convolution will pad the first element as [-, -, 0, 0, 1, 1, 2, 2], + # So the output elements rely on the following windows: + # 0: [-,-,0] + # 1: [-,0,0] + # 2: [0,0,1] + # 3: [0,1,1] + # 4: [1,1,2] + # 5: [1,2,2] + # Notice that the first and second elements in the output rely only on the first element in the input, + # while all other elements rely on two elements in the input. + # So we can drop the first element to undo the padding (rather than the last element). + # This is a no-op for non-causal convolutions. + match self.causality_axis: + case CausalityAxis.NONE: + pass # x remains unchanged + case CausalityAxis.HEIGHT: + x = x[:, :, 1:, :] + case CausalityAxis.WIDTH: + x = x[:, :, :, 1:] + case CausalityAxis.WIDTH_COMPATIBILITY: + pass # x remains unchanged + case _: + raise ValueError(f"Invalid causality_axis: {self.causality_axis}") + + return x + + +class Downsample(nn.Module): + """ + A downsampling layer that can use either a strided convolution + or average pooling. Supports standard and causal padding for the + convolutional mode. + """ + + def __init__(self, in_channels, with_conv, causality_axis: CausalityAxis = CausalityAxis.WIDTH): + super().__init__() + self.with_conv = with_conv + self.causality_axis = causality_axis + + if self.causality_axis != CausalityAxis.NONE and not self.with_conv: + raise ValueError("causality is only supported when `with_conv=True`.") + + if self.with_conv: + # Do time downsampling here + # no asymmetric padding in torch conv, must do it ourselves + self.conv = ops.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0) + + def forward(self, x): + if self.with_conv: + # (pad_left, pad_right, pad_top, pad_bottom) + match self.causality_axis: + case CausalityAxis.NONE: + pad = (0, 1, 0, 1) + case CausalityAxis.WIDTH: + pad = (2, 0, 0, 1) + case CausalityAxis.HEIGHT: + pad = (0, 1, 2, 0) + case CausalityAxis.WIDTH_COMPATIBILITY: + pad = (1, 0, 0, 1) + case _: + raise ValueError(f"Invalid causality_axis: {self.causality_axis}") + + x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + else: + # This branch is only taken if with_conv=False, which implies causality_axis is NONE. 
+ x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) + + return x + + +class ResnetBlock(nn.Module): + def __init__( + self, + *, + in_channels, + out_channels=None, + conv_shortcut=False, + dropout, + temb_channels=512, + norm_type="group", + causality_axis: CausalityAxis = CausalityAxis.HEIGHT, + ): + super().__init__() + self.causality_axis = causality_axis + + if self.causality_axis != CausalityAxis.NONE and norm_type == "group": + raise ValueError("Causal ResnetBlock with GroupNorm is not supported.") + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = Normalize(in_channels, normtype=norm_type) + self.non_linearity = nn.SiLU() + self.conv1 = make_conv2d(in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis) + if temb_channels > 0: + self.temb_proj = ops.Linear(temb_channels, out_channels) + self.norm2 = Normalize(out_channels, normtype=norm_type) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = make_conv2d(out_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = make_conv2d( + in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis + ) + else: + self.nin_shortcut = make_conv2d( + in_channels, out_channels, kernel_size=1, stride=1, causality_axis=causality_axis + ) + + def forward(self, x, temb): + h = x + h = self.norm1(h) + h = self.non_linearity(h) + h = self.conv1(h) + + if temb is not None: + h = h + self.temb_proj(self.non_linearity(temb))[:, :, None, None] + + h = self.norm2(h) + h = self.non_linearity(h) + h = self.dropout(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + x = self.conv_shortcut(x) + else: + x = self.nin_shortcut(x) + + return x + h + + +class AttnBlock(nn.Module): + def __init__(self, in_channels, norm_type="group"): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels, normtype=norm_type) + self.q = ops.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.k = ops.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.v = ops.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.proj_out = ops.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b, c, h, w = q.shape + q = q.reshape(b, c, h * w).contiguous() + q = q.permute(0, 2, 1).contiguous() # b,hw,c + k = k.reshape(b, c, h * w).contiguous() # b,c,hw + w_ = torch.bmm(q, k).contiguous() # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + w_ = w_ * (int(c) ** (-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = v.reshape(b, c, h * w).contiguous() + w_ = w_.permute(0, 2, 1).contiguous() # b,hw,hw (first hw of k, second of q) + h_ = torch.bmm(v, w_).contiguous() # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + h_ = h_.reshape(b, c, h, w).contiguous() + + h_ = self.proj_out(h_) + + return x + h_ + + +def make_attn(in_channels, attn_type="vanilla", norm_type="group"): + # Convert string to enum if needed + attn_type = AttentionType.str_to_enum(attn_type) + + if attn_type != AttentionType.NONE: + logging.info(f"making attention of type 
'{attn_type.value}' with {in_channels} in_channels") + else: + logging.info(f"making identity attention with {in_channels} in_channels") + + match attn_type: + case AttentionType.VANILLA: + return AttnBlock(in_channels, norm_type=norm_type) + case AttentionType.NONE: + return nn.Identity(in_channels) + case AttentionType.LINEAR: + raise NotImplementedError(f"Attention type {attn_type.value} is not supported yet.") + case _: + raise ValueError(f"Unknown attention type: {attn_type}") + + +class Encoder(nn.Module): + def __init__( + self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + z_channels, + double_z=True, + attn_type="vanilla", + mid_block_add_attention=True, + norm_type="group", + causality_axis=CausalityAxis.WIDTH.value, + **ignore_kwargs, + ): + super().__init__() + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.z_channels = z_channels + self.double_z = double_z + self.norm_type = norm_type + # Convert string to enum if needed (for config loading) + causality_axis = CausalityAxis.str_to_enum(causality_axis) + self.attn_type = AttentionType.str_to_enum(attn_type) + + # downsampling + self.conv_in = make_conv2d( + in_channels, + self.ch, + kernel_size=3, + stride=1, + causality_axis=causality_axis, + ) + + self.non_linearity = nn.SiLU() + + curr_res = resolution + in_ch_mult = (1,) + tuple(ch_mult) + self.in_ch_mult = in_ch_mult + self.down = nn.ModuleList() + + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] + + for _ in range(self.num_res_blocks): + block.append( + ResnetBlock( + in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + norm_type=self.norm_type, + causality_axis=causality_axis, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=self.attn_type, norm_type=self.norm_type)) + + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions - 1: + down.downsample = Downsample(block_in, resamp_with_conv, causality_axis=causality_axis) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + norm_type=self.norm_type, + causality_axis=causality_axis, + ) + if mid_block_add_attention: + self.mid.attn_1 = make_attn(block_in, attn_type=self.attn_type, norm_type=self.norm_type) + else: + self.mid.attn_1 = nn.Identity() + self.mid.block_2 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + norm_type=self.norm_type, + causality_axis=causality_axis, + ) + + # end + self.norm_out = Normalize(block_in, normtype=self.norm_type) + self.conv_out = make_conv2d( + block_in, + 2 * z_channels if double_z else z_channels, + kernel_size=3, + stride=1, + causality_axis=causality_axis, + ) + + def forward(self, x): + """ + Forward pass through the encoder. 
+ + Args: + x: Input tensor of shape [batch, channels, time, n_mels] + + Returns: + Encoded latent representation + """ + feature_maps = [self.conv_in(x)] + + # Process each resolution level (from high to low resolution) + for resolution_level in range(self.num_resolutions): + # Apply residual blocks at current resolution level + for block_idx in range(self.num_res_blocks): + # Apply ResNet block with optional timestep embedding + current_features = self.down[resolution_level].block[block_idx](feature_maps[-1], temb=None) + + # Apply attention if configured for this resolution level + if len(self.down[resolution_level].attn) > 0: + current_features = self.down[resolution_level].attn[block_idx](current_features) + + # Store processed features + feature_maps.append(current_features) + + # Downsample spatial dimensions (except at the final resolution level) + if resolution_level != self.num_resolutions - 1: + downsampled_features = self.down[resolution_level].downsample(feature_maps[-1]) + feature_maps.append(downsampled_features) + + # === MIDDLE PROCESSING PHASE === + # Take the lowest resolution features for middle processing + bottleneck_features = feature_maps[-1] + + # Apply first middle ResNet block + bottleneck_features = self.mid.block_1(bottleneck_features, temb=None) + + # Apply middle attention block + bottleneck_features = self.mid.attn_1(bottleneck_features) + + # Apply second middle ResNet block + bottleneck_features = self.mid.block_2(bottleneck_features, temb=None) + + # === OUTPUT PHASE === + # Normalize the bottleneck features + output_features = self.norm_out(bottleneck_features) + + # Apply non-linearity (SiLU activation) + output_features = self.non_linearity(output_features) + + # Final convolution to produce latent representation + # [batch, channels, time, n_mels] -> [batch, 2 * z_channels if double_z else z_channels, time, n_mels] + return self.conv_out(output_features) + + +class Decoder(nn.Module): + def __init__( + self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + z_channels, + give_pre_end=False, + tanh_out=False, + attn_type="vanilla", + mid_block_add_attention=True, + norm_type="group", + causality_axis=CausalityAxis.WIDTH.value, + **ignorekwargs, + ): + super().__init__() + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.out_ch = out_ch + self.give_pre_end = give_pre_end + self.tanh_out = tanh_out + self.norm_type = norm_type + self.z_channels = z_channels + # Convert string to enum if needed (for config loading) + causality_axis = CausalityAxis.str_to_enum(causality_axis) + self.attn_type = AttentionType.str_to_enum(attn_type) + + # compute block_in and curr_res at lowest res + block_in = ch * ch_mult[self.num_resolutions - 1] + curr_res = resolution // 2 ** (self.num_resolutions - 1) + self.z_shape = (1, z_channels, curr_res, curr_res) + + # z to block_in + self.conv_in = make_conv2d(z_channels, block_in, kernel_size=3, stride=1, causality_axis=causality_axis) + + self.non_linearity = nn.SiLU() + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + norm_type=self.norm_type, + causality_axis=causality_axis, + ) + if mid_block_add_attention: + self.mid.attn_1 = make_attn(block_in, attn_type=self.attn_type, 
norm_type=self.norm_type) + else: + self.mid.attn_1 = nn.Identity() + self.mid.block_2 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + norm_type=self.norm_type, + causality_axis=causality_axis, + ) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch * ch_mult[i_level] + for _ in range(self.num_res_blocks + 1): + block.append( + ResnetBlock( + in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + norm_type=self.norm_type, + causality_axis=causality_axis, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=self.attn_type, norm_type=self.norm_type)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv, causality_axis=causality_axis) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in, normtype=self.norm_type) + self.conv_out = make_conv2d(block_in, out_ch, kernel_size=3, stride=1, causality_axis=causality_axis) + + def _adjust_output_shape(self, decoded_output, target_shape): + """ + Adjust output shape to match target dimensions for variable-length audio. + + This function handles the common case where decoded audio spectrograms need to be + resized to match a specific target shape. + + Args: + decoded_output: Tensor of shape (batch, channels, time, frequency) + target_shape: Target shape tuple (batch, channels, time, frequency) + + Returns: + Tensor adjusted to match target_shape exactly + """ + # Current output shape: (batch, channels, time, frequency) + _, _, current_time, current_freq = decoded_output.shape + _, target_channels, target_time, target_freq = target_shape + + # Step 1: Crop first to avoid exceeding target dimensions + decoded_output = decoded_output[ + :, :target_channels, : min(current_time, target_time), : min(current_freq, target_freq) + ] + + # Step 2: Calculate padding needed for time and frequency dimensions + time_padding_needed = target_time - decoded_output.shape[2] + freq_padding_needed = target_freq - decoded_output.shape[3] + + # Step 3: Apply padding if needed + if time_padding_needed > 0 or freq_padding_needed > 0: + # PyTorch padding format: (pad_left, pad_right, pad_top, pad_bottom) + # For audio: pad_left/right = frequency, pad_top/bottom = time + padding = ( + 0, + max(freq_padding_needed, 0), # frequency padding (left, right) + 0, + max(time_padding_needed, 0), # time padding (top, bottom) + ) + decoded_output = F.pad(decoded_output, padding) + + # Step 4: Final safety crop to ensure exact target shape + decoded_output = decoded_output[:, :target_channels, :target_time, :target_freq] + + return decoded_output + + def get_config(self): + return { + "ch": self.ch, + "out_ch": self.out_ch, + "ch_mult": self.ch_mult, + "num_res_blocks": self.num_res_blocks, + "in_channels": self.in_channels, + "resolution": self.resolution, + "z_channels": self.z_channels, + } + + def forward(self, latent_features, target_shape=None): + """ + Decode latent features back to audio spectrograms. 
+ + Args: + latent_features: Encoded latent representation of shape (batch, channels, height, width) + target_shape: Optional target output shape (batch, channels, time, frequency) + If provided, output will be cropped/padded to match this shape + + Returns: + Reconstructed audio spectrogram of shape (batch, channels, time, frequency) + """ + assert target_shape is not None, "Target shape is required for CausalAudioAutoencoder Decoder" + + # Transform latent features to decoder's internal feature dimension + hidden_features = self.conv_in(latent_features) + + # Middle processing + hidden_features = self.mid.block_1(hidden_features, temb=None) + hidden_features = self.mid.attn_1(hidden_features) + hidden_features = self.mid.block_2(hidden_features, temb=None) + + # Upsampling + # Progressively increase spatial resolution from lowest to highest + for resolution_level in reversed(range(self.num_resolutions)): + # Apply residual blocks at current resolution level + for block_index in range(self.num_res_blocks + 1): + hidden_features = self.up[resolution_level].block[block_index](hidden_features, temb=None) + + if len(self.up[resolution_level].attn) > 0: + hidden_features = self.up[resolution_level].attn[block_index](hidden_features) + + if resolution_level != 0: + hidden_features = self.up[resolution_level].upsample(hidden_features) + + # Output + if self.give_pre_end: + # Return intermediate features before final processing (for debugging/analysis) + decoded_output = hidden_features + else: + # Standard output path: normalize, activate, and convert to output channels + # Final normalization layer + hidden_features = self.norm_out(hidden_features) + + # Apply SiLU (Swish) activation function + hidden_features = self.non_linearity(hidden_features) + + # Final convolution to map to output channels (typically 2 for stereo audio) + decoded_output = self.conv_out(hidden_features) + + # Optional tanh activation to bound output values to [-1, 1] range + if self.tanh_out: + decoded_output = torch.tanh(decoded_output) + + # Adjust shape for audio data + if target_shape is not None: + decoded_output = self._adjust_output_shape(decoded_output, target_shape) + + return decoded_output + + +class processor(nn.Module): + def __init__(self): + super().__init__() + self.register_buffer("std-of-means", torch.empty(128)) + self.register_buffer("mean-of-means", torch.empty(128)) + + def un_normalize(self, x): + return (x * self.get_buffer("std-of-means").to(x)) + self.get_buffer("mean-of-means").to(x) + + def normalize(self, x): + return (x - self.get_buffer("mean-of-means").to(x)) / self.get_buffer("std-of-means").to(x) + + +class CausalAudioAutoencoder(nn.Module): + def __init__(self, config=None): + super().__init__() + + if config is None: + config = self._guess_config() + + # Extract encoder and decoder configs from the new format + model_config = config.get("model", {}).get("params", {}) + variables_config = config.get("variables", {}) + + self.sampling_rate = variables_config.get( + "sampling_rate", + model_config.get("sampling_rate", config.get("sampling_rate", 16000)), + ) + encoder_config = model_config.get("encoder", model_config.get("ddconfig", {})) + decoder_config = model_config.get("decoder", encoder_config) + + # Load mel spectrogram parameters + self.mel_bins = encoder_config.get("mel_bins", 64) + self.mel_hop_length = model_config.get("preprocessing", {}).get("stft", {}).get("hop_length", 160) + self.n_fft = model_config.get("preprocessing", {}).get("stft", {}).get("filter_length", 1024) + + # 
Store causality configuration at VAE level (not just in encoder internals) + causality_axis_value = encoder_config.get("causality_axis", CausalityAxis.WIDTH.value) + self.causality_axis = CausalityAxis.str_to_enum(causality_axis_value) + self.is_causal = self.causality_axis == CausalityAxis.HEIGHT + + self.encoder = Encoder(**encoder_config) + self.decoder = Decoder(**decoder_config) + + self.per_channel_statistics = processor() + + def _guess_config(self): + encoder_config = { + # Required parameters - based on ltx-video-av-1679000 model metadata + "ch": 128, + "out_ch": 8, + "ch_mult": [1, 2, 4], # Based on metadata: [1, 2, 4] not [1, 2, 4, 8] + "num_res_blocks": 2, + "attn_resolutions": [], # Based on metadata: empty list, no attention + "dropout": 0.0, + "resamp_with_conv": True, + "in_channels": 2, # stereo + "resolution": 256, + "z_channels": 8, + "double_z": True, + "attn_type": "vanilla", + "mid_block_add_attention": False, # Based on metadata: false + "norm_type": "pixel", + "causality_axis": "height", # Based on metadata + "mel_bins": 64, # Based on metadata: mel_bins = 64 + } + + decoder_config = { + # Inherits encoder config, can override specific params + **encoder_config, + "out_ch": 2, # Stereo audio output (2 channels) + "give_pre_end": False, + "tanh_out": False, + } + + config = { + "_class_name": "CausalAudioAutoencoder", + "sampling_rate": 16000, + "model": { + "params": { + "encoder": encoder_config, + "decoder": decoder_config, + } + }, + } + + return config + + def get_config(self): + return { + "sampling_rate": self.sampling_rate, + "mel_bins": self.mel_bins, + "mel_hop_length": self.mel_hop_length, + "n_fft": self.n_fft, + "causality_axis": self.causality_axis.value, + "is_causal": self.is_causal, + } + + def encode(self, x): + return self.encoder(x) + + def decode(self, x, target_shape=None): + return self.decoder(x, target_shape=target_shape) diff --git a/comfy/ldm/lightricks/vocoders/vocoder.py b/comfy/ldm/lightricks/vocoders/vocoder.py new file mode 100644 index 000000000..b1f15f2c5 --- /dev/null +++ b/comfy/ldm/lightricks/vocoders/vocoder.py @@ -0,0 +1,213 @@ +import torch +import torch.nn.functional as F +import torch.nn as nn +import comfy.ops +import numpy as np + +ops = comfy.ops.disable_weight_init + +LRELU_SLOPE = 0.1 + +def get_padding(kernel_size, dilation=1): + return int((kernel_size * dilation - dilation) / 2) + + +class ResBlock1(torch.nn.Module): + def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)): + super(ResBlock1, self).__init__() + self.convs1 = nn.ModuleList( + [ + ops.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]), + ), + ops.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]), + ), + ops.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[2], + padding=get_padding(kernel_size, dilation[2]), + ), + ] + ) + + self.convs2 = nn.ModuleList( + [ + ops.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ), + ops.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ), + ops.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ), + ] + ) + + def forward(self, x): + for c1, c2 in zip(self.convs1, self.convs2): + xt = F.leaky_relu(x, LRELU_SLOPE) + xt = c1(xt) + xt = F.leaky_relu(xt, LRELU_SLOPE) + xt = c2(xt) + 
x = xt + x + return x + + +class ResBlock2(torch.nn.Module): + def __init__(self, channels, kernel_size=3, dilation=(1, 3)): + super(ResBlock2, self).__init__() + self.convs = nn.ModuleList( + [ + ops.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]), + ), + ops.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]), + ), + ] + ) + + def forward(self, x): + for c in self.convs: + xt = F.leaky_relu(x, LRELU_SLOPE) + xt = c(xt) + x = xt + x + return x + + +class Vocoder(torch.nn.Module): + """ + Vocoder model for synthesizing audio from spectrograms, based on: https://github.com/jik876/hifi-gan. + + """ + + def __init__(self, config=None): + super(Vocoder, self).__init__() + + if config is None: + config = self.get_default_config() + + resblock_kernel_sizes = config.get("resblock_kernel_sizes", [3, 7, 11]) + upsample_rates = config.get("upsample_rates", [6, 5, 2, 2, 2]) + upsample_kernel_sizes = config.get("upsample_kernel_sizes", [16, 15, 8, 4, 4]) + resblock_dilation_sizes = config.get("resblock_dilation_sizes", [[1, 3, 5], [1, 3, 5], [1, 3, 5]]) + upsample_initial_channel = config.get("upsample_initial_channel", 1024) + stereo = config.get("stereo", True) + resblock = config.get("resblock", "1") + + self.output_sample_rate = config.get("output_sample_rate") + self.num_kernels = len(resblock_kernel_sizes) + self.num_upsamples = len(upsample_rates) + in_channels = 128 if stereo else 64 + self.conv_pre = ops.Conv1d(in_channels, upsample_initial_channel, 7, 1, padding=3) + resblock_class = ResBlock1 if resblock == "1" else ResBlock2 + + self.ups = nn.ModuleList() + for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)): + self.ups.append( + ops.ConvTranspose1d( + upsample_initial_channel // (2**i), + upsample_initial_channel // (2 ** (i + 1)), + k, + u, + padding=(k - u) // 2, + ) + ) + + self.resblocks = nn.ModuleList() + for i in range(len(self.ups)): + ch = upsample_initial_channel // (2 ** (i + 1)) + for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)): + self.resblocks.append(resblock_class(ch, k, d)) + + out_channels = 2 if stereo else 1 + self.conv_post = ops.Conv1d(ch, out_channels, 7, 1, padding=3) + + self.upsample_factor = np.prod([self.ups[i].stride[0] for i in range(len(self.ups))]) + + def get_default_config(self): + """Generate default configuration for the vocoder.""" + + config = { + "resblock_kernel_sizes": [3, 7, 11], + "upsample_rates": [6, 5, 2, 2, 2], + "upsample_kernel_sizes": [16, 15, 8, 4, 4], + "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]], + "upsample_initial_channel": 1024, + "stereo": True, + "resblock": "1", + } + + return config + + def forward(self, x): + """ + Forward pass of the vocoder. + + Args: + x: Input spectrogram tensor. 
Can be: + - 3D: (batch_size, channels, time_steps) for mono + - 4D: (batch_size, 2, channels, time_steps) for stereo + + Returns: + Audio tensor of shape (batch_size, out_channels, audio_length) + """ + if x.dim() == 4: # stereo + assert x.shape[1] == 2, "Input must have 2 channels for stereo" + x = torch.cat((x[:, 0, :, :], x[:, 1, :, :]), dim=1) + x = self.conv_pre(x) + for i in range(self.num_upsamples): + x = F.leaky_relu(x, LRELU_SLOPE) + x = self.ups[i](x) + xs = None + for j in range(self.num_kernels): + if xs is None: + xs = self.resblocks[i * self.num_kernels + j](x) + else: + xs += self.resblocks[i * self.num_kernels + j](x) + x = xs / self.num_kernels + x = F.leaky_relu(x) + x = self.conv_post(x) + x = torch.tanh(x) + + return x diff --git a/comfy/model_base.py b/comfy/model_base.py index c4f3c0639..49efd700b 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -20,6 +20,7 @@ import comfy.ldm.hunyuan3dv2_1 import comfy.ldm.hunyuan3dv2_1.hunyuandit import torch import logging +import comfy.ldm.lightricks.av_model from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep from comfy.ldm.cascade.stage_c import StageC from comfy.ldm.cascade.stage_b import StageB @@ -946,7 +947,7 @@ class GenmoMochi(BaseModel): class LTXV(BaseModel): def __init__(self, model_config, model_type=ModelType.FLUX, device=None): - super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.lightricks.model.LTXVModel) #TODO + super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.lightricks.model.LTXVModel) def extra_conds(self, **kwargs): out = super().extra_conds(**kwargs) @@ -977,6 +978,60 @@ class LTXV(BaseModel): def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs): return latent_image +class LTXAV(BaseModel): + def __init__(self, model_config, model_type=ModelType.FLUX, device=None): + super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.lightricks.av_model.LTXAVModel) #TODO + + def extra_conds(self, **kwargs): + out = super().extra_conds(**kwargs) + attention_mask = kwargs.get("attention_mask", None) + if attention_mask is not None: + out['attention_mask'] = comfy.conds.CONDRegular(attention_mask) + cross_attn = kwargs.get("cross_attn", None) + if cross_attn is not None: + out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) + + out['frame_rate'] = comfy.conds.CONDConstant(kwargs.get("frame_rate", 25)) + + denoise_mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None)) + + audio_denoise_mask = None + if denoise_mask is not None and "latent_shapes" in kwargs: + denoise_mask = utils.unpack_latents(denoise_mask, kwargs["latent_shapes"]) + if len(denoise_mask) > 1: + audio_denoise_mask = denoise_mask[1] + denoise_mask = denoise_mask[0] + + if denoise_mask is not None: + out["denoise_mask"] = comfy.conds.CONDRegular(denoise_mask) + + if audio_denoise_mask is not None: + out["audio_denoise_mask"] = comfy.conds.CONDRegular(audio_denoise_mask) + + keyframe_idxs = kwargs.get("keyframe_idxs", None) + if keyframe_idxs is not None: + out['keyframe_idxs'] = comfy.conds.CONDRegular(keyframe_idxs) + + latent_shapes = kwargs.get("latent_shapes", None) + if latent_shapes is not None: + out['latent_shapes'] = comfy.conds.CONDConstant(latent_shapes) + + return out + + def process_timestep(self, timestep, x, denoise_mask=None, audio_denoise_mask=None, **kwargs): + v_timestep = timestep + a_timestep = timestep + + if denoise_mask is not None: + v_timestep = 
self.diffusion_model.patchifier.patchify(((denoise_mask) * timestep.view([timestep.shape[0]] + [1] * (denoise_mask.ndim - 1)))[:, :1])[0] + if audio_denoise_mask is not None: + a_timestep = self.diffusion_model.a_patchifier.patchify(((audio_denoise_mask) * timestep.view([timestep.shape[0]] + [1] * (audio_denoise_mask.ndim - 1)))[:, :1, :, :1])[0] + + return v_timestep, a_timestep + + def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs): + return latent_image + class HunyuanVideo(BaseModel): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan_video.model.HunyuanVideo) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 539e296ed..0853b3aec 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -305,7 +305,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): if '{}adaln_single.emb.timestep_embedder.linear_1.bias'.format(key_prefix) in state_dict_keys: #Lightricks ltxv dit_config = {} - dit_config["image_model"] = "ltxv" + dit_config["image_model"] = "ltxav" if f'{key_prefix}audio_adaln_single.linear.weight' in state_dict_keys else "ltxv" dit_config["num_layers"] = count_blocks(state_dict_keys, '{}transformer_blocks.'.format(key_prefix) + '{}.') shape = state_dict['{}transformer_blocks.0.attn2.to_k.weight'.format(key_prefix)].shape dit_config["attention_head_dim"] = shape[0] // 32 diff --git a/comfy/sd.py b/comfy/sd.py index 7de7dd9c6..32157e18b 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -1041,7 +1041,8 @@ class TEModel(Enum): MISTRAL3_24B_PRUNED_FLUX2 = 15 QWEN3_4B = 16 QWEN3_2B = 17 - JINA_CLIP_2 = 18 + GEMMA_3_12B = 18 + JINA_CLIP_2 = 19 def detect_te_model(sd): @@ -1067,6 +1068,8 @@ def detect_te_model(sd): return TEModel.BYT5_SMALL_GLYPH return TEModel.T5_BASE if 'model.layers.0.post_feedforward_layernorm.weight' in sd: + if 'model.layers.47.self_attn.q_norm.weight' in sd: + return TEModel.GEMMA_3_12B if 'model.layers.0.self_attn.q_norm.weight' in sd: return TEModel.GEMMA_3_4B return TEModel.GEMMA_2_2B @@ -1271,6 +1274,10 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip elif clip_type == CLIPType.KANDINSKY5_IMAGE: clip_target.clip = comfy.text_encoders.kandinsky5.te(**llama_detect(clip_data)) clip_target.tokenizer = comfy.text_encoders.kandinsky5.Kandinsky5TokenizerImage + elif clip_type == CLIPType.LTXV: + clip_target.clip = comfy.text_encoders.lt.ltxav_te(**llama_detect(clip_data)) + clip_target.tokenizer = comfy.text_encoders.lt.LTXAVGemmaTokenizer + tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None) elif clip_type == CLIPType.NEWBIE: clip_target.clip = comfy.text_encoders.newbie.te(**llama_detect(clip_data)) clip_target.tokenizer = comfy.text_encoders.newbie.NewBieTokenizer diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 1888f35ba..ee9a79001 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -836,6 +836,21 @@ class LTXV(supported_models_base.BASE): t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref)) return supported_models_base.ClipTarget(comfy.text_encoders.lt.LTXVT5Tokenizer, comfy.text_encoders.lt.ltxv_te(**t5_detect)) +class LTXAV(LTXV): + unet_config = { + "image_model": "ltxav", + } + + latent_format = latent_formats.LTXAV + + def __init__(self, unet_config): + super().__init__(unet_config) + self.memory_usage_factor = 0.055 # TODO + + def 
get_model(self, state_dict, prefix="", device=None): + out = model_base.LTXAV(self, device=device) + return out + class HunyuanVideo(supported_models_base.BASE): unet_config = { "image_model": "hunyuan_video", @@ -1536,6 +1551,6 @@ class Kandinsky5Image(Kandinsky5): return supported_models_base.ClipTarget(comfy.text_encoders.kandinsky5.Kandinsky5TokenizerImage, comfy.text_encoders.kandinsky5.te(**hunyuan_detect)) -models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5] +models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5] models += [SVD_img2vid] diff --git a/comfy/text_encoders/llama.py b/comfy/text_encoders/llama.py index faa4e1de8..76731576b 100644 --- a/comfy/text_encoders/llama.py +++ b/comfy/text_encoders/llama.py @@ -7,6 +7,7 @@ import math from comfy.ldm.modules.attention import optimized_attention_for_device import comfy.model_management import comfy.ldm.common_dit +import comfy.clip_model from . 
import qwen_vl @@ -188,6 +189,31 @@ class Gemma3_4B_Config: rope_scale = [8.0, 1.0] final_norm: bool = True +@dataclass +class Gemma3_12B_Config: + vocab_size: int = 262208 + hidden_size: int = 3840 + intermediate_size: int = 15360 + num_hidden_layers: int = 48 + num_attention_heads: int = 16 + num_key_value_heads: int = 8 + max_position_embeddings: int = 131072 + rms_norm_eps: float = 1e-6 + rope_theta = [1000000.0, 10000.0] + transformer_type: str = "gemma3" + head_dim = 256 + rms_norm_add = True + mlp_activation = "gelu_pytorch_tanh" + qkv_bias = False + rope_dims = None + q_norm = "gemma3" + k_norm = "gemma3" + sliding_attention = [1024, 1024, 1024, 1024, 1024, False] + rope_scale = [8.0, 1.0] + final_norm: bool = True + vision_config = {"num_channels": 3, "hidden_act": "gelu_pytorch_tanh", "hidden_size": 1152, "image_size": 896, "intermediate_size": 4304, "model_type": "siglip_vision_model", "num_attention_heads": 16, "num_hidden_layers": 27, "patch_size": 14} + mm_tokens_per_image = 256 + class RMSNorm(nn.Module): def __init__(self, dim: int, eps: float = 1e-5, add=False, device=None, dtype=None): super().__init__() @@ -520,6 +546,41 @@ class Llama2_(nn.Module): return x, intermediate + +class Gemma3MultiModalProjector(torch.nn.Module): + def __init__(self, config, dtype, device, operations): + super().__init__() + + self.mm_input_projection_weight = nn.Parameter( + torch.empty(config.vision_config["hidden_size"], config.hidden_size, device=device, dtype=dtype) + ) + + self.mm_soft_emb_norm = RMSNorm(config.vision_config["hidden_size"], eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype) + + self.patches_per_image = int(config.vision_config["image_size"] // config.vision_config["patch_size"]) + self.tokens_per_side = int(config.mm_tokens_per_image**0.5) + self.kernel_size = self.patches_per_image // self.tokens_per_side + self.avg_pool = nn.AvgPool2d(kernel_size=self.kernel_size, stride=self.kernel_size) + + def forward(self, vision_outputs: torch.Tensor): + batch_size, _, seq_length = vision_outputs.shape + + reshaped_vision_outputs = vision_outputs.transpose(1, 2) + reshaped_vision_outputs = reshaped_vision_outputs.reshape( + batch_size, seq_length, self.patches_per_image, self.patches_per_image + ) + reshaped_vision_outputs = reshaped_vision_outputs.contiguous() + + pooled_vision_outputs = self.avg_pool(reshaped_vision_outputs) + pooled_vision_outputs = pooled_vision_outputs.flatten(2) + pooled_vision_outputs = pooled_vision_outputs.transpose(1, 2) + + normed_vision_outputs = self.mm_soft_emb_norm(pooled_vision_outputs) + + projected_vision_outputs = torch.matmul(normed_vision_outputs, comfy.model_management.cast_to_device(self.mm_input_projection_weight, device=normed_vision_outputs.device, dtype=normed_vision_outputs.dtype)) + return projected_vision_outputs.type_as(vision_outputs) + + class BaseLlama: def get_input_embeddings(self): return self.model.embed_tokens @@ -636,3 +697,21 @@ class Gemma3_4B(BaseLlama, torch.nn.Module): self.model = Llama2_(config, device=device, dtype=dtype, ops=operations) self.dtype = dtype + +class Gemma3_12B(BaseLlama, torch.nn.Module): + def __init__(self, config_dict, dtype, device, operations): + super().__init__() + config = Gemma3_12B_Config(**config_dict) + self.num_layers = config.num_hidden_layers + + self.model = Llama2_(config, device=device, dtype=dtype, ops=operations) + self.multi_modal_projector = Gemma3MultiModalProjector(config, dtype, device, operations) + self.vision_model = 
comfy.clip_model.CLIPVision(config.vision_config, dtype, device, operations) + self.dtype = dtype + self.image_size = config.vision_config["image_size"] + + def preprocess_embed(self, embed, device): + if embed["type"] == "image": + image = comfy.clip_model.clip_preprocess(embed["data"], size=self.image_size, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], crop=True) + return self.multi_modal_projector(self.vision_model(image.to(device, dtype=torch.float32))[0]), None + return None, None diff --git a/comfy/text_encoders/lt.py b/comfy/text_encoders/lt.py index 48ea67e67..2c2d453e8 100644 --- a/comfy/text_encoders/lt.py +++ b/comfy/text_encoders/lt.py @@ -1,7 +1,11 @@ from comfy import sd1_clip import os from transformers import T5TokenizerFast +from .spiece_tokenizer import SPieceTokenizer import comfy.text_encoders.genmo +from comfy.ldm.lightricks.embeddings_connector import Embeddings1DConnector +import torch +import comfy.utils class T5XXLTokenizer(sd1_clip.SDTokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}): @@ -16,3 +20,110 @@ class LTXVT5Tokenizer(sd1_clip.SD1Tokenizer): def ltxv_te(*args, **kwargs): return comfy.text_encoders.genmo.mochi_te(*args, **kwargs) + + +class Gemma3_12BTokenizer(sd1_clip.SDTokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + tokenizer = tokenizer_data.get("spiece_model", None) + super().__init__(tokenizer, pad_with_end=False, embedding_size=3840, embedding_key='gemma3_12b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_args={"add_bos": True, "add_eos": False}, tokenizer_data=tokenizer_data) + + def state_dict(self): + return {"spiece_model": self.tokenizer.serialize_model()} + +class LTXAVGemmaTokenizer(sd1_clip.SD1Tokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="gemma3_12b", tokenizer=Gemma3_12BTokenizer) + +class Gemma3_12BModel(sd1_clip.SDClipModel): + def __init__(self, device="cpu", layer="all", layer_idx=None, dtype=None, attention_mask=True, model_options={}): + llama_scaled_fp8 = model_options.get("gemma_scaled_fp8", None) + if llama_scaled_fp8 is not None: + model_options = model_options.copy() + model_options["scaled_fp8"] = llama_scaled_fp8 + + super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma3_12B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options) + + def tokenize_with_weights(self, text, return_word_ids=False, llama_template="{}", image_embeds=None, **kwargs): + text = llama_template.format(text) + text_tokens = super().tokenize_with_weights(text, return_word_ids) + embed_count = 0 + for k in text_tokens: + tt = text_tokens[k] + for r in tt: + for i in range(len(r)): + if r[i][0] == 262144: + if image_embeds is not None and embed_count < image_embeds.shape[0]: + r[i] = ({"type": "embedding", "data": image_embeds[embed_count], "original_type": "image"},) + r[i][1:] + embed_count += 1 + return text_tokens + +class LTXAVTEModel(torch.nn.Module): + def __init__(self, dtype_llama=None, device="cpu", dtype=None, model_options={}): + super().__init__() + self.dtypes = set() + self.dtypes.add(dtype) + + self.gemma3_12b = Gemma3_12BModel(device=device, dtype=dtype_llama, 
model_options=model_options, layer="all", layer_idx=None) + self.dtypes.add(dtype_llama) + + operations = self.gemma3_12b.operations # TODO + self.text_embedding_projection = operations.Linear(3840 * 49, 3840, bias=False, dtype=dtype, device=device) + + self.audio_embeddings_connector = Embeddings1DConnector( + split_rope=True, + double_precision_rope=True, + dtype=dtype, + device=device, + operations=operations, + ) + + self.video_embeddings_connector = Embeddings1DConnector( + split_rope=True, + double_precision_rope=True, + dtype=dtype, + device=device, + operations=operations, + ) + + def set_clip_options(self, options): + self.gemma3_12b.set_clip_options(options) + + def reset_clip_options(self): + self.gemma3_12b.reset_clip_options() + + def encode_token_weights(self, token_weight_pairs): + token_weight_pairs = token_weight_pairs["gemma3_12b"] + + out, pooled, extra = self.gemma3_12b.encode_token_weights(token_weight_pairs) + out_device = out.device + out = out.movedim(1, -1).to(self.text_embedding_projection.weight.device) + out = 8.0 * (out - out.mean(dim=(1, 2), keepdim=True)) / (out.amax(dim=(1, 2), keepdim=True) - out.amin(dim=(1, 2), keepdim=True) + 1e-6) + out = out.reshape((out.shape[0], out.shape[1], -1)) + out = self.text_embedding_projection(out) + out_vid = self.video_embeddings_connector(out)[0] + out_audio = self.audio_embeddings_connector(out)[0] + out = torch.concat((out_vid, out_audio), dim=-1) + + return out.to(out_device), pooled + + def load_sd(self, sd): + if "model.layers.47.self_attn.q_norm.weight" in sd: + return self.gemma3_12b.load_sd(sd) + else: + sdo = comfy.utils.state_dict_prefix_replace(sd, {"text_embedding_projection.aggregate_embed.weight": "text_embedding_projection.weight", "model.diffusion_model.video_embeddings_connector.": "video_embeddings_connector.", "model.diffusion_model.audio_embeddings_connector.": "audio_embeddings_connector."}, filter_keys=True) + if len(sdo) == 0: + sdo = sd + + return self.load_state_dict(sdo, strict=False) + + +def ltxav_te(dtype_llama=None, llama_scaled_fp8=None): + class LTXAVTEModel_(LTXAVTEModel): + def __init__(self, device="cpu", dtype=None, model_options={}): + if llama_scaled_fp8 is not None and "llama_scaled_fp8" not in model_options: + model_options = model_options.copy() + model_options["llama_scaled_fp8"] = llama_scaled_fp8 + if dtype_llama is not None: + dtype = dtype_llama + super().__init__(dtype_llama=dtype_llama, device=device, dtype=dtype, model_options=model_options) + return LTXAVTEModel_ diff --git a/comfy/utils.py b/comfy/utils.py index e4162d7ac..ffa98c9b1 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -1198,7 +1198,7 @@ def unpack_latents(combined_latent, latent_shapes): combined_latent = combined_latent[:, :, cut:] output_tensors.append(tens.reshape([tens.shape[0]] + list(shape)[1:])) else: - output_tensors = combined_latent + output_tensors = [combined_latent] return output_tensors def detect_layer_quantization(state_dict, prefix): diff --git a/comfy_extras/nodes_audio.py b/comfy_extras/nodes_audio.py index c7916443c..94ad5e8a8 100644 --- a/comfy_extras/nodes_audio.py +++ b/comfy_extras/nodes_audio.py @@ -112,7 +112,7 @@ class VAEDecodeAudio(IO.ComfyNode): std = torch.std(audio, dim=[1,2], keepdim=True) * 5.0 std[std < 1.0] = 1.0 audio /= std - return IO.NodeOutput({"waveform": audio, "sample_rate": 44100}) + return IO.NodeOutput({"waveform": audio, "sample_rate": 44100 if "sample_rate" not in samples else samples["sample_rate"]}) decode = execute # TODO: remove diff --git 
a/comfy_extras/nodes_hunyuan.py b/comfy_extras/nodes_hunyuan.py index 32be182f1..ceff657d3 100644 --- a/comfy_extras/nodes_hunyuan.py +++ b/comfy_extras/nodes_hunyuan.py @@ -5,7 +5,9 @@ import comfy.model_management from typing_extensions import override from comfy_api.latest import ComfyExtension, io from comfy.ldm.hunyuan_video.upsampler import HunyuanVideo15SRModel +from comfy.ldm.lightricks.latent_upsampler import LatentUpsampler import folder_paths +import json class CLIPTextEncodeHunyuanDiT(io.ComfyNode): @classmethod @@ -186,7 +188,7 @@ class LatentUpscaleModelLoader(io.ComfyNode): @classmethod def execute(cls, model_name) -> io.NodeOutput: model_path = folder_paths.get_full_path_or_raise("latent_upscale_models", model_name) - sd = comfy.utils.load_torch_file(model_path, safe_load=True) + sd, metadata = comfy.utils.load_torch_file(model_path, safe_load=True, return_metadata=True) if "blocks.0.block.0.conv.weight" in sd: config = { @@ -197,6 +199,8 @@ class LatentUpscaleModelLoader(io.ComfyNode): "global_residual": False, } model_type = "720p" + model = HunyuanVideo15SRModel(model_type, config) + model.load_sd(sd) elif "up.0.block.0.conv1.conv.weight" in sd: sd = {key.replace("nin_shortcut", "nin_shortcut.conv", 1): value for key, value in sd.items()} config = { @@ -205,9 +209,12 @@ class LatentUpscaleModelLoader(io.ComfyNode): "block_out_channels": tuple(sd[f"up.{i}.block.0.conv1.conv.weight"].shape[0] for i in range(len([k for k in sd.keys() if k.startswith("up.") and k.endswith(".block.0.conv1.conv.weight")]))), } model_type = "1080p" - - model = HunyuanVideo15SRModel(model_type, config) - model.load_sd(sd) + model = HunyuanVideo15SRModel(model_type, config) + model.load_sd(sd) + elif "post_upsample_res_blocks.0.conv2.bias" in sd: + config = json.loads(metadata["config"]) + model = LatentUpsampler.from_config(config).to(dtype=comfy.model_management.vae_dtype(allowed_dtypes=[torch.bfloat16, torch.float32])) + model.load_state_dict(sd) return io.NodeOutput(model) diff --git a/comfy_extras/nodes_lt.py b/comfy_extras/nodes_lt.py index 50da5f4eb..b91a22309 100644 --- a/comfy_extras/nodes_lt.py +++ b/comfy_extras/nodes_lt.py @@ -81,6 +81,59 @@ class LTXVImgToVideo(io.ComfyNode): generate = execute # TODO: remove +class LTXVImgToVideoInplace(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="LTXVImgToVideoInplace", + category="conditioning/video_models", + inputs=[ + io.Vae.Input("vae"), + io.Image.Input("image"), + io.Latent.Input("latent"), + io.Float.Input("strength", default=1.0, min=0.0, max=1.0), + io.Boolean.Input("bypass", default=False, tooltip="Bypass the conditioning.") + ], + outputs=[ + io.Latent.Output(display_name="latent"), + ], + ) + + @classmethod + def execute(cls, vae, image, latent, strength, bypass=False) -> io.NodeOutput: + if bypass: + return (latent,) + + samples = latent["samples"] + _, height_scale_factor, width_scale_factor = ( + vae.downscale_index_formula + ) + + batch, _, latent_frames, latent_height, latent_width = samples.shape + width = latent_width * width_scale_factor + height = latent_height * height_scale_factor + + if image.shape[1] != height or image.shape[2] != width: + pixels = comfy.utils.common_upscale(image.movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) + else: + pixels = image + encode_pixels = pixels[:, :, :, :3] + t = vae.encode(encode_pixels) + + samples[:, :, :t.shape[2]] = t + + conditioning_latent_frames_mask = torch.ones( + (batch, 1, latent_frames, 1, 1), + dtype=torch.float32, 
+ device=samples.device, + ) + conditioning_latent_frames_mask[:, :, :t.shape[2]] = 1.0 - strength + + return io.NodeOutput({"samples": samples, "noise_mask": conditioning_latent_frames_mask}) + + generate = execute # TODO: remove + + def conditioning_get_any_value(conditioning, key, default=None): for t in conditioning: if key in t[1]: @@ -106,12 +159,12 @@ def get_keyframe_idxs(cond): keyframe_idxs = conditioning_get_any_value(cond, "keyframe_idxs", None) if keyframe_idxs is None: return None, 0 - num_keyframes = torch.unique(keyframe_idxs[:, 0]).shape[0] + # keyframe_idxs contains start/end positions (last dimension), checking for unique values only for start + num_keyframes = torch.unique(keyframe_idxs[:, 0, :, 0]).shape[0] return keyframe_idxs, num_keyframes class LTXVAddGuide(io.ComfyNode): - NUM_PREFIX_FRAMES = 2 - PATCHIFIER = SymmetricPatchifier(1) + PATCHIFIER = SymmetricPatchifier(1, start_end=True) @classmethod def define_schema(cls): @@ -182,26 +235,35 @@ class LTXVAddGuide(io.ComfyNode): return node_helpers.conditioning_set_values(cond, {"keyframe_idxs": keyframe_idxs}) @classmethod - def append_keyframe(cls, positive, negative, frame_idx, latent_image, noise_mask, guiding_latent, strength, scale_factors): - _, latent_idx = cls.get_latent_index( - cond=positive, - latent_length=latent_image.shape[2], - guide_length=guiding_latent.shape[2], - frame_idx=frame_idx, - scale_factors=scale_factors, - ) - noise_mask[:, :, latent_idx:latent_idx + guiding_latent.shape[2]] = 1.0 + def append_keyframe(cls, positive, negative, frame_idx, latent_image, noise_mask, guiding_latent, strength, scale_factors, guide_mask=None, in_channels=128): + if latent_image.shape[1] != in_channels or guiding_latent.shape[1] != in_channels: + raise ValueError("Adding guide to a combined AV latent is not supported.") positive = cls.add_keyframe_index(positive, frame_idx, guiding_latent, scale_factors) negative = cls.add_keyframe_index(negative, frame_idx, guiding_latent, scale_factors) - mask = torch.full( - (noise_mask.shape[0], 1, guiding_latent.shape[2], noise_mask.shape[3], noise_mask.shape[4]), - 1.0 - strength, - dtype=noise_mask.dtype, - device=noise_mask.device, - ) + if guide_mask is not None: + target_h = max(noise_mask.shape[3], guide_mask.shape[3]) + target_w = max(noise_mask.shape[4], guide_mask.shape[4]) + if noise_mask.shape[3] == 1 or noise_mask.shape[4] == 1: + noise_mask = noise_mask.expand(-1, -1, -1, target_h, target_w) + + if guide_mask.shape[3] == 1 or guide_mask.shape[4] == 1: + guide_mask = guide_mask.expand(-1, -1, -1, target_h, target_w) + mask = guide_mask - strength + else: + mask = torch.full( + (noise_mask.shape[0], 1, guiding_latent.shape[2], noise_mask.shape[3], noise_mask.shape[4]), + 1.0 - strength, + dtype=noise_mask.dtype, + device=noise_mask.device, + ) + # This solves the audio video combined latent case where latent_image has audio latent concatenated + # in channel dimension with video latent. The solution is to pad guiding latent accordingly.
+ if latent_image.shape[1] > guiding_latent.shape[1]: + pad_len = latent_image.shape[1] - guiding_latent.shape[1] + guiding_latent = torch.nn.functional.pad(guiding_latent, pad=(0, 0, 0, 0, 0, 0, 0, pad_len), value=0) latent_image = torch.cat([latent_image, guiding_latent], dim=2) noise_mask = torch.cat([noise_mask, mask], dim=2) return positive, negative, latent_image, noise_mask @@ -238,33 +300,17 @@ class LTXVAddGuide(io.ComfyNode): frame_idx, latent_idx = cls.get_latent_index(positive, latent_length, len(image), frame_idx, scale_factors) assert latent_idx + t.shape[2] <= latent_length, "Conditioning frames exceed the length of the latent sequence." - num_prefix_frames = min(cls.NUM_PREFIX_FRAMES, t.shape[2]) - positive, negative, latent_image, noise_mask = cls.append_keyframe( positive, negative, frame_idx, latent_image, noise_mask, - t[:, :, :num_prefix_frames], + t, strength, scale_factors, ) - latent_idx += num_prefix_frames - - t = t[:, :, num_prefix_frames:] - if t.shape[2] == 0: - return io.NodeOutput(positive, negative, {"samples": latent_image, "noise_mask": noise_mask}) - - latent_image, noise_mask = cls.replace_latent_frames( - latent_image, - noise_mask, - t, - latent_idx, - strength, - ) - return io.NodeOutput(positive, negative, {"samples": latent_image, "noise_mask": noise_mask}) generate = execute # TODO: remove @@ -507,18 +553,90 @@ class LTXVPreprocess(io.ComfyNode): preprocess = execute # TODO: remove + +import comfy.nested_tensor +class LTXVConcatAVLatent(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="LTXVConcatAVLatent", + category="latent/video/ltxv", + inputs=[ + io.Latent.Input("video_latent"), + io.Latent.Input("audio_latent"), + ], + outputs=[ + io.Latent.Output(display_name="latent"), + ], + ) + + @classmethod + def execute(cls, video_latent, audio_latent) -> io.NodeOutput: + output = {} + output.update(video_latent) + output.update(audio_latent) + video_noise_mask = video_latent.get("noise_mask", None) + audio_noise_mask = audio_latent.get("noise_mask", None) + + if video_noise_mask is not None or audio_noise_mask is not None: + if video_noise_mask is None: + video_noise_mask = torch.ones_like(video_latent["samples"]) + if audio_noise_mask is None: + audio_noise_mask = torch.ones_like(audio_latent["samples"]) + output["noise_mask"] = comfy.nested_tensor.NestedTensor((video_noise_mask, audio_noise_mask)) + + output["samples"] = comfy.nested_tensor.NestedTensor((video_latent["samples"], audio_latent["samples"])) + + return io.NodeOutput(output) + + +class LTXVSeparateAVLatent(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="LTXVSeparateAVLatent", + category="latent/video/ltxv", + description="LTXV Separate AV Latent", + inputs=[ + io.Latent.Input("av_latent"), + ], + outputs=[ + io.Latent.Output(display_name="video_latent"), + io.Latent.Output(display_name="audio_latent"), + ], + ) + + @classmethod + def execute(cls, av_latent) -> io.NodeOutput: + latents = av_latent["samples"].unbind() + video_latent = av_latent.copy() + video_latent["samples"] = latents[0] + audio_latent = av_latent.copy() + audio_latent["samples"] = latents[1] + if "noise_mask" in av_latent: + masks = av_latent["noise_mask"] + if masks is not None: + masks = masks.unbind() + video_latent["noise_mask"] = masks[0] + audio_latent["noise_mask"] = masks[1] + return io.NodeOutput(video_latent, audio_latent) + + class LtxvExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[io.ComfyNode]]: 
return [ EmptyLTXVLatentVideo, LTXVImgToVideo, + LTXVImgToVideoInplace, ModelSamplingLTXV, LTXVConditioning, LTXVScheduler, LTXVAddGuide, LTXVPreprocess, LTXVCropGuides, + LTXVConcatAVLatent, + LTXVSeparateAVLatent, ] diff --git a/comfy_extras/nodes_lt_audio.py b/comfy_extras/nodes_lt_audio.py new file mode 100644 index 000000000..b0b7000ef --- /dev/null +++ b/comfy_extras/nodes_lt_audio.py @@ -0,0 +1,183 @@ +import folder_paths +import comfy.utils +import comfy.model_management +import torch + +from comfy.ldm.lightricks.vae.audio_vae import AudioVAE +from comfy_api.latest import ComfyExtension, io + + +class LTXVAudioVAELoader(io.ComfyNode): + @classmethod + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="LTXVAudioVAELoader", + display_name="LTXV Audio VAE Loader", + category="audio", + inputs=[ + io.Combo.Input( + "ckpt_name", + options=folder_paths.get_filename_list("checkpoints"), + tooltip="Audio VAE checkpoint to load.", + ) + ], + outputs=[io.Vae.Output(display_name="Audio VAE")], + ) + + @classmethod + def execute(cls, ckpt_name: str) -> io.NodeOutput: + ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name) + sd, metadata = comfy.utils.load_torch_file(ckpt_path, return_metadata=True) + return io.NodeOutput(AudioVAE(sd, metadata)) + + +class LTXVAudioVAEEncode(io.ComfyNode): + @classmethod + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="LTXVAudioVAEEncode", + display_name="LTXV Audio VAE Encode", + category="audio", + inputs=[ + io.Audio.Input("audio", tooltip="The audio to be encoded."), + io.Vae.Input( + id="audio_vae", + display_name="Audio VAE", + tooltip="The Audio VAE model to use for encoding.", + ), + ], + outputs=[io.Latent.Output(display_name="Audio Latent")], + ) + + @classmethod + def execute(cls, audio, audio_vae: AudioVAE) -> io.NodeOutput: + audio_latents = audio_vae.encode(audio) + return io.NodeOutput( + { + "samples": audio_latents, + "sample_rate": int(audio_vae.sample_rate), + "type": "audio", + } + ) + + +class LTXVAudioVAEDecode(io.ComfyNode): + @classmethod + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="LTXVAudioVAEDecode", + display_name="LTXV Audio VAE Decode", + category="audio", + inputs=[ + io.Latent.Input("samples", tooltip="The latent to be decoded."), + io.Vae.Input( + id="audio_vae", + display_name="Audio VAE", + tooltip="The Audio VAE model used for decoding the latent.", + ), + ], + outputs=[io.Audio.Output(display_name="Audio")], + ) + + @classmethod + def execute(cls, samples, audio_vae: AudioVAE) -> io.NodeOutput: + audio_latent = samples["samples"] + if audio_latent.is_nested: + audio_latent = audio_latent.unbind()[-1] + audio = audio_vae.decode(audio_latent).to(audio_latent.device) + output_audio_sample_rate = audio_vae.output_sample_rate + return io.NodeOutput( + { + "waveform": audio, + "sample_rate": int(output_audio_sample_rate), + } + ) + + +class LTXVEmptyLatentAudio(io.ComfyNode): + @classmethod + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="LTXVEmptyLatentAudio", + display_name="LTXV Empty Latent Audio", + category="latent/audio", + inputs=[ + io.Int.Input( + "frames_number", + default=97, + min=1, + max=1000, + step=1, + display_mode=io.NumberDisplay.number, + tooltip="Number of frames.", + ), + io.Int.Input( + "frame_rate", + default=25, + min=1, + max=1000, + step=1, + display_mode=io.NumberDisplay.number, + tooltip="Number of frames per second.", + ), + io.Int.Input( + "batch_size", + default=1, + min=1, + max=4096, + 
display_mode=io.NumberDisplay.number, + tooltip="The number of latent audio samples in the batch.", + ), + io.Vae.Input( + id="audio_vae", + display_name="Audio VAE", + tooltip="The Audio VAE model to get configuration from.", + ), + ], + outputs=[io.Latent.Output(display_name="Latent")], + ) + + @classmethod + def execute( + cls, + frames_number: int, + frame_rate: int, + batch_size: int, + audio_vae: AudioVAE, + ) -> io.NodeOutput: + """Generate empty audio latents matching the reference pipeline structure.""" + + assert audio_vae is not None, "Audio VAE model is required" + + z_channels = audio_vae.latent_channels + audio_freq = audio_vae.latent_frequency_bins + sampling_rate = int(audio_vae.sample_rate) + + num_audio_latents = audio_vae.num_of_latents_from_frames(frames_number, frame_rate) + + audio_latents = torch.zeros( + (batch_size, z_channels, num_audio_latents, audio_freq), + device=comfy.model_management.intermediate_device(), + ) + + return io.NodeOutput( + { + "samples": audio_latents, + "sample_rate": sampling_rate, + "type": "audio", + } + ) + + +class LTXVAudioExtension(ComfyExtension): + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + LTXVAudioVAELoader, + LTXVAudioVAEEncode, + LTXVAudioVAEDecode, + LTXVEmptyLatentAudio, + ] + + +async def comfy_entrypoint() -> ComfyExtension: + return LTXVAudioExtension() diff --git a/comfy_extras/nodes_lt_upsampler.py b/comfy_extras/nodes_lt_upsampler.py new file mode 100644 index 000000000..f99ba13fb --- /dev/null +++ b/comfy_extras/nodes_lt_upsampler.py @@ -0,0 +1,75 @@ +from comfy import model_management +import math + +class LTXVLatentUpsampler: + """ + Upsamples a video latent by a factor of 2. + """ + + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "samples": ("LATENT",), + "upscale_model": ("LATENT_UPSCALE_MODEL",), + "vae": ("VAE",), + } + } + + RETURN_TYPES = ("LATENT",) + FUNCTION = "upsample_latent" + CATEGORY = "latent/video" + EXPERIMENTAL = True + + def upsample_latent( + self, + samples: dict, + upscale_model, + vae, + ) -> tuple: + """ + Upsample the input latent using the provided model. + + Args: + samples (dict): Input latent samples + upscale_model (LatentUpsampler): Loaded upscale model + vae: VAE model for normalization + auto_tiling (bool): Whether to automatically tile the input for processing + + Returns: + tuple: Tuple containing the upsampled latent + """ + device = model_management.get_torch_device() + memory_required = model_management.module_size(upscale_model) + + model_dtype = next(upscale_model.parameters()).dtype + latents = samples["samples"] + input_dtype = latents.dtype + + memory_required += math.prod(latents.shape) * 3000.0 # TODO: more accurate + model_management.free_memory(memory_required, device) + + try: + upscale_model.to(device) # TODO: use the comfy model management system. 
+ + latents = latents.to(dtype=model_dtype, device=device) + + """Upsample latents without tiling.""" + latents = vae.first_stage_model.per_channel_statistics.un_normalize(latents) + upsampled_latents = upscale_model(latents) + finally: + upscale_model.cpu() + + upsampled_latents = vae.first_stage_model.per_channel_statistics.normalize( + upsampled_latents + ) + upsampled_latents = upsampled_latents.to(dtype=input_dtype, device=model_management.intermediate_device()) + return_dict = samples.copy() + return_dict["samples"] = upsampled_latents + return_dict.pop("noise_mask", None) + return (return_dict,) + + +NODE_CLASS_MAPPINGS = { + "LTXVLatentUpsampler": LTXVLatentUpsampler, +} diff --git a/nodes.py b/nodes.py index 662907ae6..56b74ebe3 100644 --- a/nodes.py +++ b/nodes.py @@ -295,7 +295,11 @@ class VAEDecode: DESCRIPTION = "Decodes latent images back into pixel space images." def decode(self, vae, samples): - images = vae.decode(samples["samples"]) + latent = samples["samples"] + if latent.is_nested: + latent = latent.unbind()[0] + + images = vae.decode(latent) if len(images.shape) == 5: #Combine batches images = images.reshape(-1, images.shape[-3], images.shape[-2], images.shape[-1]) return (images, ) @@ -970,7 +974,7 @@ class DualCLIPLoader: def INPUT_TYPES(s): return {"required": { "clip_name1": (folder_paths.get_filename_list("text_encoders"), ), "clip_name2": (folder_paths.get_filename_list("text_encoders"), ), - "type": (["sdxl", "sd3", "flux", "hunyuan_video", "hidream", "hunyuan_image", "hunyuan_video_15", "kandinsky5", "kandinsky5_image", "newbie"], ), + "type": (["sdxl", "sd3", "flux", "hunyuan_video", "hidream", "hunyuan_image", "hunyuan_video_15", "kandinsky5", "kandinsky5_image", "ltxv", "newbie"], ), }, "optional": { "device": (["default", "cpu"], {"advanced": True}), @@ -2331,6 +2335,8 @@ async def init_builtin_extra_nodes(): "nodes_mochi.py", "nodes_slg.py", "nodes_mahiro.py", + "nodes_lt_upsampler.py", + "nodes_lt_audio.py", "nodes_lt.py", "nodes_hooks.py", "nodes_load_3d.py", diff --git a/pyproject.toml b/pyproject.toml index 60378de1e..a7d159be9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,7 +3,7 @@ name = "ComfyUI" version = "0.7.0" readme = "README.md" license = { file = "LICENSE" } -requires-python = ">=3.9" +requires-python = ">=3.10" [project.urls] homepage = "https://www.comfy.org/" From d1b9822f741843c64b2cbd8e1bcdd49794b182ce Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sun, 4 Jan 2026 23:27:31 -0800 Subject: [PATCH 055/181] Add LTXAVTextEncoderLoader node. 
(#11634) --- comfy_extras/nodes_lt_audio.py | 33 +++++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/comfy_extras/nodes_lt_audio.py b/comfy_extras/nodes_lt_audio.py index b0b7000ef..2d3d103b4 100644 --- a/comfy_extras/nodes_lt_audio.py +++ b/comfy_extras/nodes_lt_audio.py @@ -169,6 +169,38 @@ class LTXVEmptyLatentAudio(io.ComfyNode): ) +class LTXAVTextEncoderLoader(io.ComfyNode): + @classmethod + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="LTXAVTextEncoderLoader", + display_name="LTXV Audio Text Encoder Loader", + category="advanced/loaders", + description="[Recipes]\n\nltxav: gemma 3 12B", + inputs=[ + io.Combo.Input( + "text_encoder", + options=folder_paths.get_filename_list("text_encoders"), + ), + io.Combo.Input( + "ckpt_name", + options=folder_paths.get_filename_list("checkpoints"), + ) + ], + outputs=[io.Clip.Output(display_name="Audio VAE")], + ) + + @classmethod + def execute(cls, text_encoder, ckpt_name, device="default"): + clip_type = comfy.sd.CLIPType.LTXV + + clip_path1 = folder_paths.get_full_path_or_raise("text_encoders", text_encoder) + clip_path2 = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name) + + clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type) + return io.NodeOutput(clip) + + class LTXVAudioExtension(ComfyExtension): async def get_node_list(self) -> list[type[io.ComfyNode]]: return [ @@ -176,6 +208,7 @@ class LTXVAudioExtension(ComfyExtension): LTXVAudioVAEEncode, LTXVAudioVAEDecode, LTXVEmptyLatentAudio, + LTXAVTextEncoderLoader, ] From d157c3299d6f9e1b57981bdb4931f1d7129e4e8d Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 5 Jan 2026 00:48:31 -0800 Subject: [PATCH 056/181] Refactor module_size function. (#11637) --- comfy/model_management.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 2501cecb7..7f5a8aee9 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -456,7 +456,7 @@ def module_size(module): sd = module.state_dict() for k in sd: t = sd[k] - module_mem += t.nelement() * t.element_size() + module_mem += t.nbytes return module_mem class LoadedModel: From 4f3f9e72a9d0c15d00c0c362b8e90f1db5af6cfb Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 5 Jan 2026 02:41:23 -0800 Subject: [PATCH 057/181] Fix name. (#11638) --- comfy_extras/nodes_lt_audio.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy_extras/nodes_lt_audio.py b/comfy_extras/nodes_lt_audio.py index 2d3d103b4..26b0160d2 100644 --- a/comfy_extras/nodes_lt_audio.py +++ b/comfy_extras/nodes_lt_audio.py @@ -187,7 +187,7 @@ class LTXAVTextEncoderLoader(io.ComfyNode): options=folder_paths.get_filename_list("checkpoints"), ) ], - outputs=[io.Clip.Output(display_name="Audio VAE")], + outputs=[io.Clip.Output()], ) @classmethod From 6da00dd899e3ee6f2a0a8163b080a9f373395025 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 5 Jan 2026 18:48:58 -0800 Subject: [PATCH 058/181] Initial ops changes to use comfy_kitchen: Initial nvfp4 checkpoint support. 
(#11635) --------- Co-authored-by: Jedrzej Kosinski --- .github/workflows/test-build.yml | 2 +- .github/workflows/test-launch.yml | 4 +- comfy/model_management.py | 4 +- comfy/ops.py | 164 +++-- comfy/quant_ops.py | 641 +++--------------- requirements.txt | 1 + .../comfy_quant/test_mixed_precision.py | 12 +- tests-unit/comfy_quant/test_quant_registry.py | 190 ------ 8 files changed, 223 insertions(+), 795 deletions(-) delete mode 100644 tests-unit/comfy_quant/test_quant_registry.py diff --git a/.github/workflows/test-build.yml b/.github/workflows/test-build.yml index 419873ad8..9160242e9 100644 --- a/.github/workflows/test-build.yml +++ b/.github/workflows/test-build.yml @@ -18,7 +18,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"] + python-version: ["3.10", "3.11", "3.12", "3.13", "3.14"] steps: - uses: actions/checkout@v4 - name: Set up Python ${{ matrix.python-version }} diff --git a/.github/workflows/test-launch.yml b/.github/workflows/test-launch.yml index fd70aff23..ef0d3f123 100644 --- a/.github/workflows/test-launch.yml +++ b/.github/workflows/test-launch.yml @@ -32,7 +32,9 @@ jobs: working-directory: ComfyUI - name: Check for unhandled exceptions in server log run: | - if grep -qE "Exception|Error" console_output.log; then + grep -v "Found comfy_kitchen backend triton: {'available': False, 'disabled': True, 'unavailable_reason': \"ImportError: No module named 'triton'\", 'capabilities': \[\]}" console_output.log | grep -v "Found comfy_kitchen backend triton: {'available': False, 'disabled': False, 'unavailable_reason': \"ImportError: No module named 'triton'\", 'capabilities': \[\]}" > console_output_filtered.log + cat console_output_filtered.log + if grep -qE "Exception|Error" console_output_filtered.log; then echo "Unhandled exception/error found in server log." 
exit 1 fi diff --git a/comfy/model_management.py b/comfy/model_management.py index 7f5a8aee9..22f4de044 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -1156,7 +1156,7 @@ def pin_memory(tensor): if not tensor.is_contiguous(): return False - size = tensor.numel() * tensor.element_size() + size = tensor.nbytes if (TOTAL_PINNED_MEMORY + size) > MAX_PINNED_MEMORY: return False @@ -1183,7 +1183,7 @@ def unpin_memory(tensor): return False ptr = tensor.data_ptr() - size = tensor.numel() * tensor.element_size() + size = tensor.nbytes size_stored = PINNED_MEMORY.get(ptr, None) if size_stored is None: diff --git a/comfy/ops.py b/comfy/ops.py index 16889bb82..f5e1e9230 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -79,7 +79,7 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of if input is not None: if dtype is None: if isinstance(input, QuantizedTensor): - dtype = input._layout_params["orig_dtype"] + dtype = input.params.orig_dtype else: dtype = input.dtype if bias_dtype is None: @@ -412,26 +412,34 @@ def fp8_linear(self, input): return None input_dtype = input.dtype + input_shape = input.shape + tensor_3d = input.ndim == 3 - if input.ndim == 3 or input.ndim == 2: - w, bias, offload_stream = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype, offloadable=True) - scale_weight = torch.ones((), device=input.device, dtype=torch.float32) + if tensor_3d: + input = input.reshape(-1, input_shape[2]) - scale_input = torch.ones((), device=input.device, dtype=torch.float32) - input = torch.clamp(input, min=-448, max=448, out=input) - layout_params_weight = {'scale': scale_input, 'orig_dtype': input_dtype} - quantized_input = QuantizedTensor(input.to(dtype).contiguous(), "TensorCoreFP8Layout", layout_params_weight) + if input.ndim != 2: + return None + w, bias, offload_stream = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype, offloadable=True) + scale_weight = torch.ones((), device=input.device, dtype=torch.float32) - # Wrap weight in QuantizedTensor - this enables unified dispatch - # Call F.linear - __torch_dispatch__ routes to fp8_linear handler in quant_ops.py! - layout_params_weight = {'scale': scale_weight, 'orig_dtype': input_dtype} - quantized_weight = QuantizedTensor(w, "TensorCoreFP8Layout", layout_params_weight) - o = torch.nn.functional.linear(quantized_input, quantized_weight, bias) + scale_input = torch.ones((), device=input.device, dtype=torch.float32) + input = torch.clamp(input, min=-448, max=448, out=input) + input_fp8 = input.to(dtype).contiguous() + layout_params_input = TensorCoreFP8Layout.Params(scale=scale_input, orig_dtype=input_dtype, orig_shape=tuple(input_fp8.shape)) + quantized_input = QuantizedTensor(input_fp8, TensorCoreFP8Layout, layout_params_input) - uncast_bias_weight(self, w, bias, offload_stream) - return o + # Wrap weight in QuantizedTensor - this enables unified dispatch + # Call F.linear - __torch_dispatch__ routes to fp8_linear handler in quant_ops.py! 
+ layout_params_weight = TensorCoreFP8Layout.Params(scale=scale_weight, orig_dtype=input_dtype, orig_shape=tuple(w.shape)) + quantized_weight = QuantizedTensor(w, TensorCoreFP8Layout, layout_params_weight) + o = torch.nn.functional.linear(quantized_input, quantized_weight, bias) - return None + uncast_bias_weight(self, w, bias, offload_stream) + if tensor_3d: + o = o.reshape((input_shape[0], input_shape[1], w.shape[0])) + + return o class fp8_ops(manual_cast): class Linear(manual_cast.Linear): @@ -477,7 +485,12 @@ if CUBLAS_IS_AVAILABLE: # ============================================================================== # Mixed Precision Operations # ============================================================================== -from .quant_ops import QuantizedTensor, QUANT_ALGOS +from .quant_ops import ( + QuantizedTensor, + QUANT_ALGOS, + TensorCoreFP8Layout, + get_layout_class, +) def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False): @@ -497,14 +510,15 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec ) -> None: super().__init__() - if dtype is None: - dtype = MixedPrecisionOps._compute_dtype - - self.factory_kwargs = {"device": device, "dtype": dtype} + self.factory_kwargs = {"device": device, "dtype": MixedPrecisionOps._compute_dtype} + # self.factory_kwargs = {"device": device, "dtype": dtype} self.in_features = in_features self.out_features = out_features - self._has_bias = bias + if bias: + self.bias = torch.nn.Parameter(torch.empty(out_features, **self.factory_kwargs)) + else: + self.register_parameter("bias", None) self.tensor_class = None self._full_precision_mm = MixedPrecisionOps._full_precision_mm @@ -512,6 +526,16 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec def reset_parameters(self): return None + def _load_scale_param(self, state_dict, prefix, param_name, device, manually_loaded_keys, dtype=None): + key = f"{prefix}{param_name}" + value = state_dict.pop(key, None) + if value is not None: + value = value.to(device=device) + if dtype is not None: + value = value.view(dtype=dtype) + manually_loaded_keys.append(key) + return value + def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): @@ -529,14 +553,7 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec layer_conf = json.loads(layer_conf.numpy().tobytes()) if layer_conf is None: - dtype = self.factory_kwargs["dtype"] - self.weight = torch.nn.Parameter(weight.to(device=device, dtype=dtype), requires_grad=False) - if dtype != MixedPrecisionOps._compute_dtype: - self.comfy_cast_weights = True - if self._has_bias: - self.bias = torch.nn.Parameter(torch.empty(self.out_features, device=device, dtype=dtype)) - else: - self.register_parameter("bias", None) + self.weight = torch.nn.Parameter(weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), requires_grad=False) else: self.quant_format = layer_conf.get("format", None) if not self._full_precision_mm: @@ -547,31 +564,46 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec qconfig = QUANT_ALGOS[self.quant_format] self.layout_type = qconfig["comfy_tensor_layout"] + layout_cls = get_layout_class(self.layout_type) - weight_scale_key = f"{prefix}weight_scale" - scale = state_dict.pop(weight_scale_key, None) - if scale is not None: - scale = scale.to(device) - layout_params = { - 'scale': scale, - 'orig_dtype': 
MixedPrecisionOps._compute_dtype, - 'block_size': qconfig.get("group_size", None), - } + # Load format-specific parameters + if self.quant_format in ["float8_e4m3fn", "float8_e5m2"]: + # FP8: single tensor scale + scale = self._load_scale_param(state_dict, prefix, "weight_scale", device, manually_loaded_keys) - if scale is not None: - manually_loaded_keys.append(weight_scale_key) + params = layout_cls.Params( + scale=scale, + orig_dtype=MixedPrecisionOps._compute_dtype, + orig_shape=(self.out_features, self.in_features), + ) + + elif self.quant_format == "nvfp4": + # NVFP4: tensor_scale (weight_scale_2) + block_scale (weight_scale) + tensor_scale = self._load_scale_param(state_dict, prefix, "weight_scale_2", device, manually_loaded_keys) + block_scale = self._load_scale_param(state_dict, prefix, "weight_scale", device, manually_loaded_keys, + dtype=torch.float8_e4m3fn) + + if tensor_scale is None or block_scale is None: + raise ValueError(f"Missing NVFP4 scales for layer {layer_name}") + + params = layout_cls.Params( + scale=tensor_scale, + block_scale=block_scale, + orig_dtype=MixedPrecisionOps._compute_dtype, + orig_shape=(self.out_features, self.in_features), + ) + else: + raise ValueError(f"Unsupported quantization format: {self.quant_format}") self.weight = torch.nn.Parameter( - QuantizedTensor(weight.to(device=device, dtype=qconfig.get("storage_t", None)), self.layout_type, layout_params), + QuantizedTensor(weight.to(device=device, dtype=qconfig["storage_t"]), self.layout_type, params), requires_grad=False ) - if self._has_bias: - self.bias = torch.nn.Parameter(torch.empty(self.out_features, device=device, dtype=MixedPrecisionOps._compute_dtype)) - else: - self.register_parameter("bias", None) - for param_name in qconfig["parameters"]: + if param_name in {"weight_scale", "weight_scale_2"}: + continue # Already handled above + param_key = f"{prefix}{param_name}" _v = state_dict.pop(param_key, None) if _v is None: @@ -588,7 +620,15 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec def state_dict(self, *args, destination=None, prefix="", **kwargs): sd = super().state_dict(*args, destination=destination, prefix=prefix, **kwargs) if isinstance(self.weight, QuantizedTensor): - sd["{}weight_scale".format(prefix)] = self.weight._layout_params['scale'] + layout_cls = self.weight._layout_cls + + # Check if it's any FP8 variant (E4M3 or E5M2) + if layout_cls in ("TensorCoreFP8E4M3Layout", "TensorCoreFP8E5M2Layout", "TensorCoreFP8Layout"): + sd["{}weight_scale".format(prefix)] = self.weight._params.scale + elif layout_cls == "TensorCoreNVFP4Layout": + sd["{}weight_scale_2".format(prefix)] = self.weight._params.scale + sd["{}weight_scale".format(prefix)] = self.weight._params.block_scale + quant_conf = {"format": self.quant_format} if self._full_precision_mm: quant_conf["full_precision_matrix_mult"] = True @@ -607,12 +647,33 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec def forward(self, input, *args, **kwargs): run_every_op() + input_shape = input.shape + tensor_3d = input.ndim == 3 + if self._full_precision_mm or self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0: return self.forward_comfy_cast_weights(input, *args, **kwargs) + if (getattr(self, 'layout_type', None) is not None and not isinstance(input, QuantizedTensor)): - input = QuantizedTensor.from_float(input, self.layout_type, scale=getattr(self, 'input_scale', None), dtype=self.weight.dtype) - return self._forward(input, 
self.weight, self.bias) + + # Reshape 3D tensors to 2D for quantization (needed for NVFP4 and others) + if tensor_3d: + input = input.reshape(-1, input_shape[2]) + + if input.ndim != 2: + # Fall back to comfy_cast_weights for non-2D tensors + return self.forward_comfy_cast_weights(input.reshape(input_shape), *args, **kwargs) + + # dtype is now implicit in the layout class + input = QuantizedTensor.from_float(input, self.layout_type, scale=getattr(self, 'input_scale', None)) + + output = self._forward(input, self.weight, self.bias) + + # Reshape output back to 3D if input was 3D + if tensor_3d: + output = output.reshape((input_shape[0], input_shape[1], self.weight.shape[0])) + + return output def convert_weight(self, weight, inplace=False, **kwargs): if isinstance(weight, QuantizedTensor): @@ -622,7 +683,8 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False, **kwargs): if getattr(self, 'layout_type', None) is not None: - weight = QuantizedTensor.from_float(weight, self.layout_type, scale="recalculate", dtype=self.weight.dtype, stochastic_rounding=seed, inplace_ops=True) + # dtype is now implicit in the layout class + weight = QuantizedTensor.from_float(weight, self.layout_type, scale="recalculate", stochastic_rounding=seed, inplace_ops=True) else: weight = weight.to(self.weight.dtype) if return_weight: diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py index cd96541d7..cd737726f 100644 --- a/comfy/quant_ops.py +++ b/comfy/quant_ops.py @@ -1,580 +1,133 @@ import torch import logging -from typing import Tuple, Dict + +try: + import comfy_kitchen as ck + from comfy_kitchen.tensor import ( + QuantizedTensor, + QuantizedLayout, + TensorCoreFP8Layout as _CKFp8Layout, + TensorCoreNVFP4Layout, # Direct import, no wrapper needed + register_layout_op, + register_layout_class, + get_layout_class, + ) + _CK_AVAILABLE = True + ck.registry.disable("triton") + for k, v in ck.list_backends().items(): + logging.info(f"Found comfy_kitchen backend {k}: {v}") +except ImportError as e: + logging.error(f"Failed to import comfy_kitchen, Error: {e}, fp8 and fp4 support will not be available.") + _CK_AVAILABLE = False + + class QuantizedTensor: + pass + + class _CKFp8Layout: + pass + + class TensorCoreNVFP4Layout: + pass + + def register_layout_class(name, cls): + pass + + def get_layout_class(name): + return None + import comfy.float -_LAYOUT_REGISTRY = {} -_GENERIC_UTILS = {} - - -def register_layout_op(torch_op, layout_type): - """ - Decorator to register a layout-specific operation handler. - Args: - torch_op: PyTorch operation (e.g., torch.ops.aten.linear.default) - layout_type: Layout class (e.g., TensorCoreFP8Layout) - Example: - @register_layout_op(torch.ops.aten.linear.default, TensorCoreFP8Layout) - def fp8_linear(func, args, kwargs): - # FP8-specific linear implementation - ... - """ - def decorator(handler_func): - if torch_op not in _LAYOUT_REGISTRY: - _LAYOUT_REGISTRY[torch_op] = {} - _LAYOUT_REGISTRY[torch_op][layout_type] = handler_func - return handler_func - return decorator - - -def register_generic_util(torch_op): - """ - Decorator to register a generic utility that works for all layouts. - Args: - torch_op: PyTorch operation (e.g., torch.ops.aten.detach.default) - - Example: - @register_generic_util(torch.ops.aten.detach.default) - def generic_detach(func, args, kwargs): - # Works for any layout - ... 
- """ - def decorator(handler_func): - _GENERIC_UTILS[torch_op] = handler_func - return handler_func - return decorator - - -def _get_layout_from_args(args): - for arg in args: - if isinstance(arg, QuantizedTensor): - return arg._layout_type - elif isinstance(arg, (list, tuple)): - for item in arg: - if isinstance(item, QuantizedTensor): - return item._layout_type - return None - - -def _move_layout_params_to_device(params, device): - new_params = {} - for k, v in params.items(): - if isinstance(v, torch.Tensor): - new_params[k] = v.to(device=device) - else: - new_params[k] = v - return new_params - - -def _copy_layout_params(params): - new_params = {} - for k, v in params.items(): - if isinstance(v, torch.Tensor): - new_params[k] = v.clone() - else: - new_params[k] = v - return new_params - -def _copy_layout_params_inplace(src, dst, non_blocking=False): - for k, v in src.items(): - if isinstance(v, torch.Tensor): - dst[k].copy_(v, non_blocking=non_blocking) - else: - dst[k] = v - -class QuantizedLayout: - """ - Base class for quantization layouts. - - A layout encapsulates the format-specific logic for quantization/dequantization - and provides a uniform interface for extracting raw tensors needed for computation. - - New quantization formats should subclass this and implement the required methods. - """ - @classmethod - def quantize(cls, tensor, **kwargs) -> Tuple[torch.Tensor, Dict]: - raise NotImplementedError(f"{cls.__name__} must implement quantize()") - - @staticmethod - def dequantize(qdata, **layout_params) -> torch.Tensor: - raise NotImplementedError("TensorLayout must implement dequantize()") - - @classmethod - def get_plain_tensors(cls, qtensor) -> torch.Tensor: - raise NotImplementedError(f"{cls.__name__} must implement get_plain_tensors()") - - -class QuantizedTensor(torch.Tensor): - """ - Universal quantized tensor that works with any layout. - - This tensor subclass uses a pluggable layout system to support multiple - quantization formats (FP8, INT4, INT8, etc.) without code duplication. - - The layout_type determines format-specific behavior, while common operations - (detach, clone, to) are handled generically. - - Attributes: - _qdata: The quantized tensor data - _layout_type: Layout class (e.g., TensorCoreFP8Layout) - _layout_params: Dict with layout-specific params (scale, zero_point, etc.) - """ - - @staticmethod - def __new__(cls, qdata, layout_type, layout_params): - """ - Create a quantized tensor. - - Args: - qdata: The quantized data tensor - layout_type: Layout class (subclass of QuantizedLayout) - layout_params: Dict with layout-specific parameters - """ - return torch.Tensor._make_wrapper_subclass(cls, qdata.shape, device=qdata.device, dtype=qdata.dtype, requires_grad=False) - - def __init__(self, qdata, layout_type, layout_params): - self._qdata = qdata - self._layout_type = layout_type - self._layout_params = layout_params - - def __repr__(self): - layout_name = self._layout_type - param_str = ", ".join(f"{k}={v}" for k, v in list(self._layout_params.items())[:2]) - return f"QuantizedTensor(shape={self.shape}, layout={layout_name}, {param_str})" - - @property - def layout_type(self): - return self._layout_type - - def __tensor_flatten__(self): - """ - Tensor flattening protocol for proper device movement. 
- """ - inner_tensors = ["_qdata"] - ctx = { - "layout_type": self._layout_type, - } - - tensor_params = {} - non_tensor_params = {} - for k, v in self._layout_params.items(): - if isinstance(v, torch.Tensor): - tensor_params[k] = v - else: - non_tensor_params[k] = v - - ctx["tensor_param_keys"] = list(tensor_params.keys()) - ctx["non_tensor_params"] = non_tensor_params - - for k, v in tensor_params.items(): - attr_name = f"_layout_param_{k}" - object.__setattr__(self, attr_name, v) - inner_tensors.append(attr_name) - - return inner_tensors, ctx - - @staticmethod - def __tensor_unflatten__(inner_tensors, ctx, outer_size, outer_stride): - """ - Tensor unflattening protocol for proper device movement. - Reconstructs the QuantizedTensor after device movement. - """ - layout_type = ctx["layout_type"] - layout_params = dict(ctx["non_tensor_params"]) - - for key in ctx["tensor_param_keys"]: - attr_name = f"_layout_param_{key}" - layout_params[key] = inner_tensors[attr_name] - - return QuantizedTensor(inner_tensors["_qdata"], layout_type, layout_params) - - @classmethod - def from_float(cls, tensor, layout_type, **quantize_kwargs) -> 'QuantizedTensor': - qdata, layout_params = LAYOUTS[layout_type].quantize(tensor, **quantize_kwargs) - return cls(qdata, layout_type, layout_params) - - def dequantize(self) -> torch.Tensor: - return LAYOUTS[self._layout_type].dequantize(self._qdata, **self._layout_params) - - @classmethod - def __torch_dispatch__(cls, func, types, args=(), kwargs=None): - kwargs = kwargs or {} - - # Step 1: Check generic utilities first (detach, clone, to, etc.) - if func in _GENERIC_UTILS: - return _GENERIC_UTILS[func](func, args, kwargs) - - # Step 2: Check layout-specific handlers (linear, matmul, etc.) - layout_type = _get_layout_from_args(args) - if layout_type and func in _LAYOUT_REGISTRY: - handler = _LAYOUT_REGISTRY[func].get(layout_type) - if handler: - return handler(func, args, kwargs) - - # Step 3: Fallback to dequantization - if isinstance(args[0] if args else None, QuantizedTensor): - logging.info(f"QuantizedTensor: Unhandled operation {func}, falling back to dequantization. 
kwargs={kwargs}") - return cls._dequant_and_fallback(func, args, kwargs) - - @classmethod - def _dequant_and_fallback(cls, func, args, kwargs): - def dequant_arg(arg): - if isinstance(arg, QuantizedTensor): - return arg.dequantize() - elif isinstance(arg, (list, tuple)): - return type(arg)(dequant_arg(a) for a in arg) - return arg - - new_args = dequant_arg(args) - new_kwargs = dequant_arg(kwargs) - return func(*new_args, **new_kwargs) - - def data_ptr(self): - return self._qdata.data_ptr() - - def is_pinned(self): - return self._qdata.is_pinned() - - def is_contiguous(self, *arg, **kwargs): - return self._qdata.is_contiguous(*arg, **kwargs) - - def storage(self): - return self._qdata.storage() - # ============================================================================== -# Generic Utilities (Layout-Agnostic Operations) +# FP8 Layouts with Comfy-Specific Extensions # ============================================================================== -def _create_transformed_qtensor(qt, transform_fn): - new_data = transform_fn(qt._qdata) - new_params = _copy_layout_params(qt._layout_params) - return QuantizedTensor(new_data, qt._layout_type, new_params) +class _TensorCoreFP8LayoutBase(_CKFp8Layout): + FP8_DTYPE = None # Must be overridden in subclass - -def _handle_device_transfer(qt, target_device, target_dtype=None, target_layout=None, op_name="to"): - if target_layout is not None and target_layout != torch.strided: - logging.warning( - f"QuantizedTensor: layout change requested to {target_layout}, " - f"but not supported. Ignoring layout." - ) - - # Handle device transfer - current_device = qt._qdata.device - if target_device is not None: - # Normalize device for comparison - if isinstance(target_device, str): - target_device = torch.device(target_device) - if isinstance(current_device, str): - current_device = torch.device(current_device) - - if target_device != current_device: - logging.debug(f"QuantizedTensor.{op_name}: Moving from {current_device} to {target_device}") - new_q_data = qt._qdata.to(device=target_device) - new_params = _move_layout_params_to_device(qt._layout_params, target_device) - if target_dtype is not None: - new_params["orig_dtype"] = target_dtype - new_qt = QuantizedTensor(new_q_data, qt._layout_type, new_params) - logging.debug(f"QuantizedTensor.{op_name}: Created new tensor on {target_device}") - return new_qt - - logging.debug(f"QuantizedTensor.{op_name}: No device change needed, returning original") - return qt - - -@register_generic_util(torch.ops.aten.detach.default) -def generic_detach(func, args, kwargs): - """Detach operation - creates a detached copy of the quantized tensor.""" - qt = args[0] - if isinstance(qt, QuantizedTensor): - return _create_transformed_qtensor(qt, lambda x: x.detach()) - return func(*args, **kwargs) - - -@register_generic_util(torch.ops.aten.clone.default) -def generic_clone(func, args, kwargs): - """Clone operation - creates a deep copy of the quantized tensor.""" - qt = args[0] - if isinstance(qt, QuantizedTensor): - return _create_transformed_qtensor(qt, lambda x: x.clone()) - return func(*args, **kwargs) - - -@register_generic_util(torch.ops.aten._to_copy.default) -def generic_to_copy(func, args, kwargs): - """Device/dtype transfer operation - handles .to(device) calls.""" - qt = args[0] - if isinstance(qt, QuantizedTensor): - return _handle_device_transfer( - qt, - target_device=kwargs.get('device', None), - target_dtype=kwargs.get('dtype', None), - op_name="_to_copy" - ) - return func(*args, **kwargs) - - 
-@register_generic_util(torch.ops.aten.to.dtype_layout) -def generic_to_dtype_layout(func, args, kwargs): - """Handle .to(device) calls using the dtype_layout variant.""" - qt = args[0] - if isinstance(qt, QuantizedTensor): - return _handle_device_transfer( - qt, - target_device=kwargs.get('device', None), - target_dtype=kwargs.get('dtype', None), - target_layout=kwargs.get('layout', None), - op_name="to" - ) - return func(*args, **kwargs) - - -@register_generic_util(torch.ops.aten.copy_.default) -def generic_copy_(func, args, kwargs): - qt_dest = args[0] - src = args[1] - non_blocking = args[2] if len(args) > 2 else False - if isinstance(qt_dest, QuantizedTensor): - if isinstance(src, QuantizedTensor): - # Copy from another quantized tensor - qt_dest._qdata.copy_(src._qdata, non_blocking=non_blocking) - qt_dest._layout_type = src._layout_type - orig_dtype = qt_dest._layout_params["orig_dtype"] - _copy_layout_params_inplace(src._layout_params, qt_dest._layout_params, non_blocking=non_blocking) - qt_dest._layout_params["orig_dtype"] = orig_dtype - else: - # Copy from regular tensor - just copy raw data - qt_dest._qdata.copy_(src) - return qt_dest - return func(*args, **kwargs) - - -@register_generic_util(torch.ops.aten.to.dtype) -def generic_to_dtype(func, args, kwargs): - """Handle .to(dtype) calls - dtype conversion only.""" - src = args[0] - if isinstance(src, QuantizedTensor): - # For dtype-only conversion, just change the orig_dtype, no real cast is needed - target_dtype = args[1] if len(args) > 1 else kwargs.get('dtype') - src._layout_params["orig_dtype"] = target_dtype - return src - return func(*args, **kwargs) - - -@register_generic_util(torch.ops.aten._has_compatible_shallow_copy_type.default) -def generic_has_compatible_shallow_copy_type(func, args, kwargs): - return True - - -@register_generic_util(torch.ops.aten.empty_like.default) -def generic_empty_like(func, args, kwargs): - """Empty_like operation - creates an empty tensor with the same quantized structure.""" - qt = args[0] - if isinstance(qt, QuantizedTensor): - # Create empty tensor with same shape and dtype as the quantized data - hp_dtype = kwargs.pop('dtype', qt._layout_params["orig_dtype"]) - new_qdata = torch.empty_like(qt._qdata, **kwargs) - - # Handle device transfer for layout params - target_device = kwargs.get('device', new_qdata.device) - new_params = _move_layout_params_to_device(qt._layout_params, target_device) - - # Update orig_dtype if dtype is specified - new_params['orig_dtype'] = hp_dtype - - return QuantizedTensor(new_qdata, qt._layout_type, new_params) - return func(*args, **kwargs) - -# ============================================================================== -# FP8 Layout + Operation Handlers -# ============================================================================== -class TensorCoreFP8Layout(QuantizedLayout): - """ - Storage format: - - qdata: FP8 tensor (torch.float8_e4m3fn or torch.float8_e5m2) - - scale: Scalar tensor (float32) for dequantization - - orig_dtype: Original dtype before quantization (for casting back) - """ @classmethod - def quantize(cls, tensor, scale=None, dtype=torch.float8_e4m3fn, stochastic_rounding=0, inplace_ops=False): + def quantize(cls, tensor, scale=None, stochastic_rounding=0, inplace_ops=False): + if cls.FP8_DTYPE is None: + raise NotImplementedError(f"{cls.__name__} must define FP8_DTYPE") + orig_dtype = tensor.dtype + orig_shape = tuple(tensor.shape) if isinstance(scale, str) and scale == "recalculate": - scale = 
torch.amax(tensor.abs()).to(dtype=torch.float32) / torch.finfo(dtype).max + scale = torch.amax(tensor.abs()).to(dtype=torch.float32) / torch.finfo(cls.FP8_DTYPE).max if tensor.dtype not in [torch.float32, torch.bfloat16]: # Prevent scale from being too small tensor_info = torch.finfo(tensor.dtype) scale = (1.0 / torch.clamp((1.0 / scale), min=tensor_info.min, max=tensor_info.max)) - if scale is not None: - if not isinstance(scale, torch.Tensor): - scale = torch.tensor(scale) - scale = scale.to(device=tensor.device, dtype=torch.float32) + if scale is None: + scale = torch.ones((), device=tensor.device, dtype=torch.float32) + if not isinstance(scale, torch.Tensor): + scale = torch.tensor(scale, device=tensor.device, dtype=torch.float32) + if stochastic_rounding > 0: if inplace_ops: tensor *= (1.0 / scale).to(tensor.dtype) else: tensor = tensor * (1.0 / scale).to(tensor.dtype) + qdata = comfy.float.stochastic_rounding(tensor, dtype=cls.FP8_DTYPE, seed=stochastic_rounding) else: - scale = torch.ones((), device=tensor.device, dtype=torch.float32) + qdata = ck.quantize_per_tensor_fp8(tensor, scale, cls.FP8_DTYPE) - if stochastic_rounding > 0: - tensor = comfy.float.stochastic_rounding(tensor, dtype=dtype, seed=stochastic_rounding) - else: - lp_amax = torch.finfo(dtype).max - torch.clamp(tensor, min=-lp_amax, max=lp_amax, out=tensor) - tensor = tensor.to(dtype, memory_format=torch.contiguous_format) + params = cls.Params(scale=scale.float(), orig_dtype=orig_dtype, orig_shape=orig_shape) + return qdata, params - layout_params = { - 'scale': scale, - 'orig_dtype': orig_dtype - } - return tensor, layout_params - @staticmethod - def dequantize(qdata, scale, orig_dtype, **kwargs): - plain_tensor = torch.ops.aten._to_copy.default(qdata, dtype=orig_dtype) - plain_tensor.mul_(scale) - return plain_tensor +class TensorCoreFP8E4M3Layout(_TensorCoreFP8LayoutBase): + FP8_DTYPE = torch.float8_e4m3fn - @classmethod - def get_plain_tensors(cls, qtensor): - return qtensor._qdata, qtensor._layout_params['scale'] + +class TensorCoreFP8E5M2Layout(_TensorCoreFP8LayoutBase): + FP8_DTYPE = torch.float8_e5m2 + + +# Backward compatibility alias - default to E4M3 +TensorCoreFP8Layout = TensorCoreFP8E4M3Layout + + +# ============================================================================== +# Registry +# ============================================================================== + +register_layout_class("TensorCoreFP8Layout", TensorCoreFP8Layout) +register_layout_class("TensorCoreFP8E4M3Layout", TensorCoreFP8E4M3Layout) +register_layout_class("TensorCoreFP8E5M2Layout", TensorCoreFP8E5M2Layout) +register_layout_class("TensorCoreNVFP4Layout", TensorCoreNVFP4Layout) QUANT_ALGOS = { "float8_e4m3fn": { "storage_t": torch.float8_e4m3fn, "parameters": {"weight_scale", "input_scale"}, - "comfy_tensor_layout": "TensorCoreFP8Layout", + "comfy_tensor_layout": "TensorCoreFP8E4M3Layout", + }, + "float8_e5m2": { + "storage_t": torch.float8_e5m2, + "parameters": {"weight_scale", "input_scale"}, + "comfy_tensor_layout": "TensorCoreFP8E5M2Layout", + }, + "nvfp4": { + "storage_t": torch.uint8, + "parameters": {"weight_scale", "weight_scale_2", "input_scale"}, + "comfy_tensor_layout": "TensorCoreNVFP4Layout", + "group_size": 16, }, } -LAYOUTS = { - "TensorCoreFP8Layout": TensorCoreFP8Layout, -} +# ============================================================================== +# Re-exports for backward compatibility +# ============================================================================== 
-@register_layout_op(torch.ops.aten.linear.default, "TensorCoreFP8Layout") -def fp8_linear(func, args, kwargs): - input_tensor = args[0] - weight = args[1] - bias = args[2] if len(args) > 2 else None - - if isinstance(input_tensor, QuantizedTensor) and isinstance(weight, QuantizedTensor): - plain_input, scale_a = TensorCoreFP8Layout.get_plain_tensors(input_tensor) - plain_weight, scale_b = TensorCoreFP8Layout.get_plain_tensors(weight) - - out_dtype = kwargs.get("out_dtype") - if out_dtype is None: - out_dtype = input_tensor._layout_params['orig_dtype'] - - weight_t = plain_weight.t() - - tensor_2d = False - if len(plain_input.shape) == 2: - tensor_2d = True - plain_input = plain_input.unsqueeze(1) - - input_shape = plain_input.shape - if len(input_shape) != 3: - return None - - try: - output = torch._scaled_mm( - plain_input.reshape(-1, input_shape[2]).contiguous(), - weight_t, - bias=bias, - scale_a=scale_a, - scale_b=scale_b, - out_dtype=out_dtype, - ) - - if isinstance(output, tuple): # TODO: remove when we drop support for torch 2.4 - output = output[0] - - if not tensor_2d: - output = output.reshape((-1, input_shape[1], weight.shape[0])) - - if output.dtype in [torch.float8_e4m3fn, torch.float8_e5m2]: - output_scale = scale_a * scale_b - output_params = { - 'scale': output_scale, - 'orig_dtype': input_tensor._layout_params['orig_dtype'] - } - return QuantizedTensor(output, "TensorCoreFP8Layout", output_params) - else: - return output - - except Exception as e: - raise RuntimeError(f"FP8 _scaled_mm failed, falling back to dequantization: {e}") - - # Case 2: DQ Fallback - if isinstance(weight, QuantizedTensor): - weight = weight.dequantize() - if isinstance(input_tensor, QuantizedTensor): - input_tensor = input_tensor.dequantize() - - return torch.nn.functional.linear(input_tensor, weight, bias) - -def fp8_mm_(input_tensor, weight, bias=None, out_dtype=None): - if out_dtype is None: - out_dtype = input_tensor._layout_params['orig_dtype'] - - plain_input, scale_a = TensorCoreFP8Layout.get_plain_tensors(input_tensor) - plain_weight, scale_b = TensorCoreFP8Layout.get_plain_tensors(weight) - - output = torch._scaled_mm( - plain_input.contiguous(), - plain_weight, - bias=bias, - scale_a=scale_a, - scale_b=scale_b, - out_dtype=out_dtype, - ) - - if isinstance(output, tuple): # TODO: remove when we drop support for torch 2.4 - output = output[0] - return output - -@register_layout_op(torch.ops.aten.addmm.default, "TensorCoreFP8Layout") -def fp8_addmm(func, args, kwargs): - input_tensor = args[1] - weight = args[2] - bias = args[0] - - if isinstance(input_tensor, QuantizedTensor) and isinstance(weight, QuantizedTensor): - return fp8_mm_(input_tensor, weight, bias=bias, out_dtype=kwargs.get("out_dtype", None)) - - a = list(args) - if isinstance(args[0], QuantizedTensor): - a[0] = args[0].dequantize() - if isinstance(args[1], QuantizedTensor): - a[1] = args[1].dequantize() - if isinstance(args[2], QuantizedTensor): - a[2] = args[2].dequantize() - - return func(*a, **kwargs) - -@register_layout_op(torch.ops.aten.mm.default, "TensorCoreFP8Layout") -def fp8_mm(func, args, kwargs): - input_tensor = args[0] - weight = args[1] - - if isinstance(input_tensor, QuantizedTensor) and isinstance(weight, QuantizedTensor): - return fp8_mm_(input_tensor, weight, bias=None, out_dtype=kwargs.get("out_dtype", None)) - - a = list(args) - if isinstance(args[0], QuantizedTensor): - a[0] = args[0].dequantize() - if isinstance(args[1], QuantizedTensor): - a[1] = args[1].dequantize() - return func(*a, **kwargs) - 
-@register_layout_op(torch.ops.aten.view.default, "TensorCoreFP8Layout") -@register_layout_op(torch.ops.aten.t.default, "TensorCoreFP8Layout") -def fp8_func(func, args, kwargs): - input_tensor = args[0] - if isinstance(input_tensor, QuantizedTensor): - plain_input, scale_a = TensorCoreFP8Layout.get_plain_tensors(input_tensor) - ar = list(args) - ar[0] = plain_input - return QuantizedTensor(func(*ar, **kwargs), "TensorCoreFP8Layout", input_tensor._layout_params) - return func(*args, **kwargs) +__all__ = [ + "QuantizedTensor", + "QuantizedLayout", + "TensorCoreFP8Layout", + "TensorCoreFP8E4M3Layout", + "TensorCoreFP8E5M2Layout", + "TensorCoreNVFP4Layout", + "QUANT_ALGOS", + "register_layout_op", +] diff --git a/requirements.txt b/requirements.txt index 3a05799eb..0ee152032 100644 --- a/requirements.txt +++ b/requirements.txt @@ -21,6 +21,7 @@ psutil alembic SQLAlchemy av>=14.2.0 +comfy-kitchen>=0.2.0 #non essential dependencies: kornia>=0.7.1 diff --git a/tests-unit/comfy_quant/test_mixed_precision.py b/tests-unit/comfy_quant/test_mixed_precision.py index 3a54941e6..7b2eac940 100644 --- a/tests-unit/comfy_quant/test_mixed_precision.py +++ b/tests-unit/comfy_quant/test_mixed_precision.py @@ -103,18 +103,18 @@ class TestMixedPrecisionOps(unittest.TestCase): # Verify weights are wrapped in QuantizedTensor self.assertIsInstance(model.layer1.weight, QuantizedTensor) - self.assertEqual(model.layer1.weight._layout_type, "TensorCoreFP8Layout") + self.assertEqual(model.layer1.weight._layout_cls, "TensorCoreFP8E4M3Layout") # Layer 2 should NOT be quantized self.assertNotIsInstance(model.layer2.weight, QuantizedTensor) # Layer 3 should be quantized self.assertIsInstance(model.layer3.weight, QuantizedTensor) - self.assertEqual(model.layer3.weight._layout_type, "TensorCoreFP8Layout") + self.assertEqual(model.layer3.weight._layout_cls, "TensorCoreFP8E4M3Layout") # Verify scales were loaded - self.assertEqual(model.layer1.weight._layout_params['scale'].item(), 2.0) - self.assertEqual(model.layer3.weight._layout_params['scale'].item(), 1.5) + self.assertEqual(model.layer1.weight._params.scale.item(), 2.0) + self.assertEqual(model.layer3.weight._params.scale.item(), 1.5) # Forward pass input_tensor = torch.randn(5, 10, dtype=torch.bfloat16) @@ -154,8 +154,8 @@ class TestMixedPrecisionOps(unittest.TestCase): # Verify layer1.weight is a QuantizedTensor with scale preserved self.assertIsInstance(state_dict2["layer1.weight"], QuantizedTensor) - self.assertEqual(state_dict2["layer1.weight"]._layout_params['scale'].item(), 3.0) - self.assertEqual(state_dict2["layer1.weight"]._layout_type, "TensorCoreFP8Layout") + self.assertEqual(state_dict2["layer1.weight"]._params.scale.item(), 3.0) + self.assertEqual(state_dict2["layer1.weight"]._layout_cls, "TensorCoreFP8E4M3Layout") # Verify non-quantized layers are standard tensors self.assertNotIsInstance(state_dict2["layer2.weight"], QuantizedTensor) diff --git a/tests-unit/comfy_quant/test_quant_registry.py b/tests-unit/comfy_quant/test_quant_registry.py deleted file mode 100644 index 9cb54ede8..000000000 --- a/tests-unit/comfy_quant/test_quant_registry.py +++ /dev/null @@ -1,190 +0,0 @@ -import unittest -import torch -import sys -import os - -# Add comfy to path -sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..")) - -def has_gpu(): - return torch.cuda.is_available() - -from comfy.cli_args import args -if not has_gpu(): - args.cpu = True - -from comfy.quant_ops import QuantizedTensor, TensorCoreFP8Layout - - -class 
TestQuantizedTensor(unittest.TestCase): - """Test the QuantizedTensor subclass with FP8 layout""" - - def test_creation(self): - """Test creating a QuantizedTensor with TensorCoreFP8Layout""" - fp8_data = torch.randn(256, 128, dtype=torch.float32).to(torch.float8_e4m3fn) - scale = torch.tensor(2.0) - layout_params = {'scale': scale, 'orig_dtype': torch.bfloat16} - - qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params) - - self.assertIsInstance(qt, QuantizedTensor) - self.assertEqual(qt.shape, (256, 128)) - self.assertEqual(qt.dtype, torch.float8_e4m3fn) - self.assertEqual(qt._layout_params['scale'], scale) - self.assertEqual(qt._layout_params['orig_dtype'], torch.bfloat16) - self.assertEqual(qt._layout_type, "TensorCoreFP8Layout") - - def test_dequantize(self): - """Test explicit dequantization""" - - fp8_data = torch.ones(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn) - scale = torch.tensor(3.0) - layout_params = {'scale': scale, 'orig_dtype': torch.float32} - - qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params) - dequantized = qt.dequantize() - - self.assertEqual(dequantized.dtype, torch.float32) - self.assertTrue(torch.allclose(dequantized, torch.ones(10, 20) * 3.0, rtol=0.1)) - - def test_from_float(self): - """Test creating QuantizedTensor from float tensor""" - float_tensor = torch.randn(64, 32, dtype=torch.float32) - scale = torch.tensor(1.5) - - qt = QuantizedTensor.from_float( - float_tensor, - "TensorCoreFP8Layout", - scale=scale, - dtype=torch.float8_e4m3fn - ) - - self.assertIsInstance(qt, QuantizedTensor) - self.assertEqual(qt.dtype, torch.float8_e4m3fn) - self.assertEqual(qt.shape, (64, 32)) - - # Verify dequantization gives approximately original values - dequantized = qt.dequantize() - mean_rel_error = ((dequantized - float_tensor).abs() / (float_tensor.abs() + 1e-6)).mean() - self.assertLess(mean_rel_error, 0.1) - - -class TestGenericUtilities(unittest.TestCase): - """Test generic utility operations""" - - def test_detach(self): - """Test detach operation on quantized tensor""" - fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn) - scale = torch.tensor(1.5) - layout_params = {'scale': scale, 'orig_dtype': torch.float32} - qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params) - - # Detach should return a new QuantizedTensor - qt_detached = qt.detach() - - self.assertIsInstance(qt_detached, QuantizedTensor) - self.assertEqual(qt_detached.shape, qt.shape) - self.assertEqual(qt_detached._layout_type, "TensorCoreFP8Layout") - - def test_clone(self): - """Test clone operation on quantized tensor""" - fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn) - scale = torch.tensor(1.5) - layout_params = {'scale': scale, 'orig_dtype': torch.float32} - qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params) - - # Clone should return a new QuantizedTensor - qt_cloned = qt.clone() - - self.assertIsInstance(qt_cloned, QuantizedTensor) - self.assertEqual(qt_cloned.shape, qt.shape) - self.assertEqual(qt_cloned._layout_type, "TensorCoreFP8Layout") - - # Verify it's a deep copy - self.assertIsNot(qt_cloned._qdata, qt._qdata) - - @unittest.skipUnless(has_gpu(), "GPU not available") - def test_to_device(self): - """Test device transfer""" - fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn) - scale = torch.tensor(1.5) - layout_params = {'scale': scale, 'orig_dtype': torch.float32} - qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params) - - # 
Moving to same device should work (CPU to CPU) - qt_cpu = qt.to('cpu') - - self.assertIsInstance(qt_cpu, QuantizedTensor) - self.assertEqual(qt_cpu.device.type, 'cpu') - self.assertEqual(qt_cpu._layout_params['scale'].device.type, 'cpu') - - -class TestTensorCoreFP8Layout(unittest.TestCase): - """Test the TensorCoreFP8Layout implementation""" - - def test_quantize(self): - """Test quantization method""" - float_tensor = torch.randn(32, 64, dtype=torch.float32) - scale = torch.tensor(1.5) - - qdata, layout_params = TensorCoreFP8Layout.quantize( - float_tensor, - scale=scale, - dtype=torch.float8_e4m3fn - ) - - self.assertEqual(qdata.dtype, torch.float8_e4m3fn) - self.assertEqual(qdata.shape, float_tensor.shape) - self.assertIn('scale', layout_params) - self.assertIn('orig_dtype', layout_params) - self.assertEqual(layout_params['orig_dtype'], torch.float32) - - def test_dequantize(self): - """Test dequantization method""" - float_tensor = torch.ones(10, 20, dtype=torch.float32) * 3.0 - scale = torch.tensor(1.0) - - qdata, layout_params = TensorCoreFP8Layout.quantize( - float_tensor, - scale=scale, - dtype=torch.float8_e4m3fn - ) - - dequantized = TensorCoreFP8Layout.dequantize(qdata, **layout_params) - - # Should approximately match original - self.assertTrue(torch.allclose(dequantized, float_tensor, rtol=0.1, atol=0.1)) - - -class TestFallbackMechanism(unittest.TestCase): - """Test fallback for unsupported operations""" - - def test_unsupported_op_dequantizes(self): - """Test that unsupported operations fall back to dequantization""" - # Set seed for reproducibility - torch.manual_seed(42) - - # Create quantized tensor - a_fp32 = torch.randn(10, 20, dtype=torch.float32) - scale = torch.tensor(1.0) - a_q = QuantizedTensor.from_float( - a_fp32, - "TensorCoreFP8Layout", - scale=scale, - dtype=torch.float8_e4m3fn - ) - - # Call an operation that doesn't have a registered handler - # For example, torch.abs - result = torch.abs(a_q) - - # Should work via fallback (dequantize → abs → return) - self.assertNotIsInstance(result, QuantizedTensor) - expected = torch.abs(a_fp32) - # FP8 introduces quantization error, so use loose tolerance - mean_error = (result - expected).abs().mean() - self.assertLess(mean_error, 0.05, f"Mean error {mean_error:.4f} is too large") - - -if __name__ == "__main__": - unittest.main() From 6ef85c49151cf8c4d6bf5e7ccfc566b8d0681cbd Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 5 Jan 2026 19:50:35 -0800 Subject: [PATCH 059/181] Use rope functions from comfy kitchen. 
(#11647) --- comfy/ldm/flux/math.py | 23 +++++++++++++++-------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/comfy/ldm/flux/math.py b/comfy/ldm/flux/math.py index 6a22df8bc..f9597de5b 100644 --- a/comfy/ldm/flux/math.py +++ b/comfy/ldm/flux/math.py @@ -4,6 +4,7 @@ from torch import Tensor from comfy.ldm.modules.attention import optimized_attention import comfy.model_management +import logging def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None, transformer_options={}) -> Tensor: @@ -13,7 +14,6 @@ def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None, transforme x = optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask, transformer_options=transformer_options) return x - def rope(pos: Tensor, dim: int, theta: int) -> Tensor: assert dim % 2 == 0 if comfy.model_management.is_device_mps(pos.device) or comfy.model_management.is_intel_xpu() or comfy.model_management.is_directml_enabled(): @@ -28,13 +28,20 @@ def rope(pos: Tensor, dim: int, theta: int) -> Tensor: out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2) return out.to(dtype=torch.float32, device=pos.device) -def apply_rope1(x: Tensor, freqs_cis: Tensor): - x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2) - x_out = freqs_cis[..., 0] * x_[..., 0] - x_out.addcmul_(freqs_cis[..., 1], x_[..., 1]) +try: + import comfy.quant_ops + apply_rope = comfy.quant_ops.ck.apply_rope + apply_rope1 = comfy.quant_ops.ck.apply_rope1 +except: + logging.warning("No comfy kitchen, using old apply_rope functions.") + def apply_rope1(x: Tensor, freqs_cis: Tensor): + x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2) - return x_out.reshape(*x.shape).type_as(x) + x_out = freqs_cis[..., 0] * x_[..., 0] + x_out.addcmul_(freqs_cis[..., 1], x_[..., 1]) -def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor): - return apply_rope1(xq, freqs_cis), apply_rope1(xk, freqs_cis) + return x_out.reshape(*x.shape).type_as(x) + + def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor): + return apply_rope1(xq, freqs_cis), apply_rope1(xk, freqs_cis) From 161800241117fae7af90e0c938d0cf8cb2f2ddb1 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 5 Jan 2026 20:07:39 -0800 Subject: [PATCH 060/181] Revert "Use rope functions from comfy kitchen. (#11647)" (#11648) This reverts commit 6ef85c49151cf8c4d6bf5e7ccfc566b8d0681cbd. 
--- comfy/ldm/flux/math.py | 23 ++++++++--------------- 1 file changed, 8 insertions(+), 15 deletions(-) diff --git a/comfy/ldm/flux/math.py b/comfy/ldm/flux/math.py index f9597de5b..6a22df8bc 100644 --- a/comfy/ldm/flux/math.py +++ b/comfy/ldm/flux/math.py @@ -4,7 +4,6 @@ from torch import Tensor from comfy.ldm.modules.attention import optimized_attention import comfy.model_management -import logging def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None, transformer_options={}) -> Tensor: @@ -14,6 +13,7 @@ def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None, transforme x = optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask, transformer_options=transformer_options) return x + def rope(pos: Tensor, dim: int, theta: int) -> Tensor: assert dim % 2 == 0 if comfy.model_management.is_device_mps(pos.device) or comfy.model_management.is_intel_xpu() or comfy.model_management.is_directml_enabled(): @@ -28,20 +28,13 @@ def rope(pos: Tensor, dim: int, theta: int) -> Tensor: out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2) return out.to(dtype=torch.float32, device=pos.device) +def apply_rope1(x: Tensor, freqs_cis: Tensor): + x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2) -try: - import comfy.quant_ops - apply_rope = comfy.quant_ops.ck.apply_rope - apply_rope1 = comfy.quant_ops.ck.apply_rope1 -except: - logging.warning("No comfy kitchen, using old apply_rope functions.") - def apply_rope1(x: Tensor, freqs_cis: Tensor): - x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2) + x_out = freqs_cis[..., 0] * x_[..., 0] + x_out.addcmul_(freqs_cis[..., 1], x_[..., 1]) - x_out = freqs_cis[..., 0] * x_[..., 0] - x_out.addcmul_(freqs_cis[..., 1], x_[..., 1]) + return x_out.reshape(*x.shape).type_as(x) - return x_out.reshape(*x.shape).type_as(x) - - def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor): - return apply_rope1(xq, freqs_cis), apply_rope1(xk, freqs_cis) +def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor): + return apply_rope1(xq, freqs_cis), apply_rope1(xk, freqs_cis) From e14f3b661069971163ddc56036b0f486933b9162 Mon Sep 17 00:00:00 2001 From: ComfyUI Wiki Date: Tue, 6 Jan 2026 14:37:11 +0800 Subject: [PATCH 061/181] chore: update workflow templates to v0.7.66 (#11652) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 0ee152032..9c9c0e29e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.35.9 -comfyui-workflow-templates==0.7.65 +comfyui-workflow-templates==0.7.66 comfyui-embedded-docs==0.3.1 torch torchsde From 96e0d0924e027248733bc6e0b8102dcdc8acde33 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 6 Jan 2026 11:43:24 -0800 Subject: [PATCH 062/181] Add helpful message to portable. 
(#11671) --- .../advanced/run_nvidia_gpu_disable_api_nodes.bat | 2 +- .ci/windows_nvidia_base_files/run_nvidia_gpu.bat | 2 +- .../run_nvidia_gpu_fast_fp16_accumulation.bat | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.ci/windows_nvidia_base_files/advanced/run_nvidia_gpu_disable_api_nodes.bat b/.ci/windows_nvidia_base_files/advanced/run_nvidia_gpu_disable_api_nodes.bat index ed00583b6..4501ef9a1 100644 --- a/.ci/windows_nvidia_base_files/advanced/run_nvidia_gpu_disable_api_nodes.bat +++ b/.ci/windows_nvidia_base_files/advanced/run_nvidia_gpu_disable_api_nodes.bat @@ -1,3 +1,3 @@ ..\python_embeded\python.exe -s ..\ComfyUI\main.py --windows-standalone-build --disable-api-nodes -echo If you see this and ComfyUI did not start try updating your Nvidia Drivers to the latest. +echo If you see this and ComfyUI did not start try updating your Nvidia Drivers to the latest. If you get a c10.dll error you need to install vc redist that you can find: https://aka.ms/vc14/vc_redist.x64.exe pause diff --git a/.ci/windows_nvidia_base_files/run_nvidia_gpu.bat b/.ci/windows_nvidia_base_files/run_nvidia_gpu.bat index 4898a424f..6487ac7ce 100755 --- a/.ci/windows_nvidia_base_files/run_nvidia_gpu.bat +++ b/.ci/windows_nvidia_base_files/run_nvidia_gpu.bat @@ -1,3 +1,3 @@ .\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build -echo If you see this and ComfyUI did not start try updating your Nvidia Drivers to the latest. +echo If you see this and ComfyUI did not start try updating your Nvidia Drivers to the latest. If you get a c10.dll error you need to install vc redist that you can find: https://aka.ms/vc14/vc_redist.x64.exe pause diff --git a/.ci/windows_nvidia_base_files/run_nvidia_gpu_fast_fp16_accumulation.bat b/.ci/windows_nvidia_base_files/run_nvidia_gpu_fast_fp16_accumulation.bat index 32611e4af..01c5bb33b 100644 --- a/.ci/windows_nvidia_base_files/run_nvidia_gpu_fast_fp16_accumulation.bat +++ b/.ci/windows_nvidia_base_files/run_nvidia_gpu_fast_fp16_accumulation.bat @@ -1,3 +1,3 @@ .\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build --fast fp16_accumulation -echo If you see this and ComfyUI did not start try updating your Nvidia Drivers to the latest. +echo If you see this and ComfyUI did not start try updating your Nvidia Drivers to the latest. If you get a c10.dll error you need to install vc redist that you can find: https://aka.ms/vc14/vc_redist.x64.exe pause From 6ffc159bdd56d1ad73e954081def6a7f163e7a7f Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 6 Jan 2026 12:53:43 -0800 Subject: [PATCH 063/181] Update comfy-kitchen version to 0.2.1 (#11672) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 9c9c0e29e..22cb50e2d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -21,7 +21,7 @@ psutil alembic SQLAlchemy av>=14.2.0 -comfy-kitchen>=0.2.0 +comfy-kitchen>=0.2.1 #non essential dependencies: kornia>=0.7.1 From c3c3e93c5bb3034175c17ef8beeb8fe8626c66ab Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 6 Jan 2026 13:57:50 -0800 Subject: [PATCH 064/181] Use rope functions from comfy kitchen. 
(#11674) --- comfy/ldm/flux/math.py | 23 +++++++++++++++-------- requirements.txt | 2 +- 2 files changed, 16 insertions(+), 9 deletions(-) diff --git a/comfy/ldm/flux/math.py b/comfy/ldm/flux/math.py index 6a22df8bc..f9597de5b 100644 --- a/comfy/ldm/flux/math.py +++ b/comfy/ldm/flux/math.py @@ -4,6 +4,7 @@ from torch import Tensor from comfy.ldm.modules.attention import optimized_attention import comfy.model_management +import logging def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None, transformer_options={}) -> Tensor: @@ -13,7 +14,6 @@ def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None, transforme x = optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask, transformer_options=transformer_options) return x - def rope(pos: Tensor, dim: int, theta: int) -> Tensor: assert dim % 2 == 0 if comfy.model_management.is_device_mps(pos.device) or comfy.model_management.is_intel_xpu() or comfy.model_management.is_directml_enabled(): @@ -28,13 +28,20 @@ def rope(pos: Tensor, dim: int, theta: int) -> Tensor: out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2) return out.to(dtype=torch.float32, device=pos.device) -def apply_rope1(x: Tensor, freqs_cis: Tensor): - x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2) - x_out = freqs_cis[..., 0] * x_[..., 0] - x_out.addcmul_(freqs_cis[..., 1], x_[..., 1]) +try: + import comfy.quant_ops + apply_rope = comfy.quant_ops.ck.apply_rope + apply_rope1 = comfy.quant_ops.ck.apply_rope1 +except: + logging.warning("No comfy kitchen, using old apply_rope functions.") + def apply_rope1(x: Tensor, freqs_cis: Tensor): + x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2) - return x_out.reshape(*x.shape).type_as(x) + x_out = freqs_cis[..., 0] * x_[..., 0] + x_out.addcmul_(freqs_cis[..., 1], x_[..., 1]) -def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor): - return apply_rope1(xq, freqs_cis), apply_rope1(xk, freqs_cis) + return x_out.reshape(*x.shape).type_as(x) + + def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor): + return apply_rope1(xq, freqs_cis), apply_rope1(xk, freqs_cis) diff --git a/requirements.txt b/requirements.txt index 22cb50e2d..7798cb179 100644 --- a/requirements.txt +++ b/requirements.txt @@ -21,7 +21,7 @@ psutil alembic SQLAlchemy av>=14.2.0 -comfy-kitchen>=0.2.1 +comfy-kitchen>=0.2.2 #non essential dependencies: kornia>=0.7.1 From c3566c0d765200068d26d0888f035504a50012f2 Mon Sep 17 00:00:00 2001 From: ComfyUI Wiki Date: Wed, 7 Jan 2026 06:28:29 +0800 Subject: [PATCH 065/181] chore: update workflow templates to v0.7.67 (#11667) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 7798cb179..caad0026a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.35.9 -comfyui-workflow-templates==0.7.66 +comfyui-workflow-templates==0.7.67 comfyui-embedded-docs==0.3.1 torch torchsde From 023cf13721cac256c323e2226319b766d07b1f36 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 6 Jan 2026 14:33:03 -0800 Subject: [PATCH 066/181] Fix lowvram issue with ltxv2 text encoder. 
(#11675) --- comfy/ldm/lightricks/embeddings_connector.py | 2 +- comfy/text_encoders/lt.py | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/comfy/ldm/lightricks/embeddings_connector.py b/comfy/ldm/lightricks/embeddings_connector.py index f7a43f3c3..06f5ada89 100644 --- a/comfy/ldm/lightricks/embeddings_connector.py +++ b/comfy/ldm/lightricks/embeddings_connector.py @@ -276,7 +276,7 @@ class Embeddings1DConnector(nn.Module): max(1024, hidden_states.shape[1]) / self.num_learnable_registers ) learnable_registers = torch.tile( - self.learnable_registers, (num_registers_duplications, 1) + self.learnable_registers.to(hidden_states), (num_registers_duplications, 1) ) hidden_states = torch.cat((hidden_states, learnable_registers[hidden_states.shape[1]:].unsqueeze(0).repeat(hidden_states.shape[0], 1, 1)), dim=1) diff --git a/comfy/text_encoders/lt.py b/comfy/text_encoders/lt.py index 2c2d453e8..e5964e42b 100644 --- a/comfy/text_encoders/lt.py +++ b/comfy/text_encoders/lt.py @@ -86,17 +86,19 @@ class LTXAVTEModel(torch.nn.Module): ) def set_clip_options(self, options): + self.execution_device = options.get("execution_device", self.execution_device) self.gemma3_12b.set_clip_options(options) def reset_clip_options(self): self.gemma3_12b.reset_clip_options() + self.execution_device = None def encode_token_weights(self, token_weight_pairs): token_weight_pairs = token_weight_pairs["gemma3_12b"] out, pooled, extra = self.gemma3_12b.encode_token_weights(token_weight_pairs) out_device = out.device - out = out.movedim(1, -1).to(self.text_embedding_projection.weight.device) + out = out.movedim(1, -1).to(self.execution_device) out = 8.0 * (out - out.mean(dim=(1, 2), keepdim=True)) / (out.amax(dim=(1, 2), keepdim=True) - out.amin(dim=(1, 2), keepdim=True) + 1e-6) out = out.reshape((out.shape[0], out.shape[1], -1)) out = self.text_embedding_projection(out) From 6e9ee55cdd9e0eca6b5144063575b983f3311762 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 6 Jan 2026 14:41:27 -0800 Subject: [PATCH 067/181] Disable ltxav previews. (#11676) --- comfy/latent_formats.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/comfy/latent_formats.py b/comfy/latent_formats.py index 9bbe30b53..cb4f52ce1 100644 --- a/comfy/latent_formats.py +++ b/comfy/latent_formats.py @@ -408,7 +408,9 @@ class LTXV(LatentFormat): self.latent_rgb_factors_bias = [-0.0571, -0.1657, -0.2512] class LTXAV(LTXV): - pass + def __init__(self): + self.latent_rgb_factors = None + self.latent_rgb_factors_bias = None class HunyuanVideo(LatentFormat): latent_channels = 16 From 2c03884f5fb7fa213161dfe1e9a09a8e8c4b6062 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 6 Jan 2026 15:07:26 -0800 Subject: [PATCH 068/181] Skip fp4 matrix mult on devices that don't support it. 
(#11677) --- comfy/model_management.py | 10 ++++++++++ comfy/ops.py | 21 +++++++++++++++++---- 2 files changed, 27 insertions(+), 4 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 22f4de044..928282092 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -1504,6 +1504,16 @@ def supports_fp8_compute(device=None): return True +def supports_nvfp4_compute(device=None): + if not is_nvidia(): + return False + + props = torch.cuda.get_device_properties(device) + if props.major < 10: + return False + + return True + def extended_fp16_support(): # TODO: check why some models work with fp16 on newer torch versions but not on older if torch_version_numeric < (2, 7): diff --git a/comfy/ops.py b/comfy/ops.py index f5e1e9230..8f9fdce36 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -493,11 +493,12 @@ from .quant_ops import ( ) -def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False): +def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False, disabled=[]): class MixedPrecisionOps(manual_cast): _quant_config = quant_config _compute_dtype = compute_dtype _full_precision_mm = full_precision_mm + _disabled = disabled class Linear(torch.nn.Module, CastWeightBiasOp): def __init__( @@ -522,6 +523,7 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec self.tensor_class = None self._full_precision_mm = MixedPrecisionOps._full_precision_mm + self._full_precision_mm_config = False def reset_parameters(self): return None @@ -556,8 +558,12 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec self.weight = torch.nn.Parameter(weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), requires_grad=False) else: self.quant_format = layer_conf.get("format", None) + self._full_precision_mm_config = layer_conf.get("full_precision_matrix_mult", False) if not self._full_precision_mm: - self._full_precision_mm = layer_conf.get("full_precision_matrix_mult", False) + self._full_precision_mm = self._full_precision_mm_config + + if self.quant_format in MixedPrecisionOps._disabled: + self._full_precision_mm = True if self.quant_format is None: raise ValueError(f"Unknown quantization format for layer {layer_name}") @@ -630,7 +636,7 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec sd["{}weight_scale".format(prefix)] = self.weight._params.block_scale quant_conf = {"format": self.quant_format} - if self._full_precision_mm: + if self._full_precision_mm_config: quant_conf["full_precision_matrix_mult"] = True sd["{}comfy_quant".format(prefix)] = torch.tensor(list(json.dumps(quant_conf).encode('utf-8')), dtype=torch.uint8) return sd @@ -711,10 +717,17 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, model_config=None): fp8_compute = comfy.model_management.supports_fp8_compute(load_device) # TODO: if we support more ops this needs to be more granular + nvfp4_compute = comfy.model_management.supports_nvfp4_compute(load_device) if model_config and hasattr(model_config, 'quant_config') and model_config.quant_config: logging.info("Using mixed precision operations") - return mixed_precision_ops(model_config.quant_config, compute_dtype, full_precision_mm=not fp8_compute) + disabled = set() + if not nvfp4_compute: + disabled.add("nvfp4") + if not fp8_compute: + 
disabled.add("float8_e4m3fn") + disabled.add("float8_e5m2") + return mixed_precision_ops(model_config.quant_config, compute_dtype, disabled=disabled) if ( fp8_compute and From edee33f55ea27a1931475d3ea788fd6e9a81677b Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 6 Jan 2026 19:13:43 -0800 Subject: [PATCH 069/181] Disable comfy kitchen cuda if pytorch cuda less than 13 (#11681) --- comfy/quant_ops.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py index cd737726f..5a17bc6f5 100644 --- a/comfy/quant_ops.py +++ b/comfy/quant_ops.py @@ -13,6 +13,13 @@ try: get_layout_class, ) _CK_AVAILABLE = True + if torch.version.cuda is None: + ck.registry.disable("cuda") + else: + cuda_version = tuple(map(int, str(torch.version.cuda).split('.'))) + if cuda_version < (13,): + ck.registry.disable("cuda") + ck.registry.disable("triton") for k, v in ck.list_backends().items(): logging.info(f"Found comfy_kitchen backend {k}: {v}") From c5cfb34c07048350f472a9a4f1ccbf75a56ed38f Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 6 Jan 2026 20:51:45 -0800 Subject: [PATCH 070/181] Update comfy-kitchen version to 0.2.3 (#11685) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index caad0026a..bc8346bcf 100644 --- a/requirements.txt +++ b/requirements.txt @@ -21,7 +21,7 @@ psutil alembic SQLAlchemy av>=14.2.0 -comfy-kitchen>=0.2.2 +comfy-kitchen>=0.2.3 #non essential dependencies: kornia>=0.7.1 From ce0000c4f2a7dba12324585dddb784b43e3cd3d0 Mon Sep 17 00:00:00 2001 From: Yoland Yan <4950057+yoland68@users.noreply.github.com> Date: Tue, 6 Jan 2026 21:57:31 -0800 Subject: [PATCH 071/181] Force sequential execution in CI test jobs (#11687) Added max-parallel setting to enforce sequential execution in test jobs. --- .github/workflows/test-ci.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/workflows/test-ci.yml b/.github/workflows/test-ci.yml index adfc5dd32..63df2dc3a 100644 --- a/.github/workflows/test-ci.yml +++ b/.github/workflows/test-ci.yml @@ -20,6 +20,7 @@ jobs: test-stable: strategy: fail-fast: false + max-parallel: 1 # This forces sequential execution matrix: # os: [macos, linux, windows] # os: [macos, linux] @@ -74,6 +75,7 @@ jobs: test-unix-nightly: strategy: fail-fast: false + max-parallel: 1 # This forces sequential execution matrix: # os: [macos, linux] os: [linux] From 79e94544bd7ec0cc7a4e5e6167907e7d781c4b76 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Wed, 7 Jan 2026 08:04:50 +0200 Subject: [PATCH 072/181] feat(api-nodes): add WAN2.6 ReferenceToVideo (#11644) --- comfy_api_nodes/nodes_wan.py | 160 +++++++++++++++++++++++++ comfy_api_nodes/util/upload_helpers.py | 2 +- 2 files changed, 161 insertions(+), 1 deletion(-) diff --git a/comfy_api_nodes/nodes_wan.py b/comfy_api_nodes/nodes_wan.py index 1675fd863..3e04786a9 100644 --- a/comfy_api_nodes/nodes_wan.py +++ b/comfy_api_nodes/nodes_wan.py @@ -13,7 +13,9 @@ from comfy_api_nodes.util import ( poll_op, sync_op, tensor_to_base64_string, + upload_video_to_comfyapi, validate_audio_duration, + validate_video_duration, ) @@ -41,6 +43,12 @@ class Image2VideoInputField(BaseModel): audio_url: str | None = Field(None) +class Reference2VideoInputField(BaseModel): + prompt: str = Field(...) 
+ negative_prompt: str | None = Field(None) + reference_video_urls: list[str] = Field(...) + + class Txt2ImageParametersField(BaseModel): size: str = Field(...) n: int = Field(1, description="Number of images to generate.") # we support only value=1 @@ -76,6 +84,14 @@ class Image2VideoParametersField(BaseModel): shot_type: str = Field("single") +class Reference2VideoParametersField(BaseModel): + size: str = Field(...) + duration: int = Field(5, ge=5, le=15) + shot_type: str = Field("single") + seed: int = Field(..., ge=0, le=2147483647) + watermark: bool = Field(False) + + class Text2ImageTaskCreationRequest(BaseModel): model: str = Field(...) input: Text2ImageInputField = Field(...) @@ -100,6 +116,12 @@ class Image2VideoTaskCreationRequest(BaseModel): parameters: Image2VideoParametersField = Field(...) +class Reference2VideoTaskCreationRequest(BaseModel): + model: str = Field(...) + input: Reference2VideoInputField = Field(...) + parameters: Reference2VideoParametersField = Field(...) + + class TaskCreationOutputField(BaseModel): task_id: str = Field(...) task_status: str = Field(...) @@ -721,6 +743,143 @@ class WanImageToVideoApi(IO.ComfyNode): return IO.NodeOutput(await download_url_to_video_output(response.output.video_url)) +class WanReferenceVideoApi(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="WanReferenceVideoApi", + display_name="Wan Reference to Video", + category="api node/video/Wan", + description="Use the character and voice from input videos, combined with a prompt, " + "to generate a new video that maintains character consistency.", + inputs=[ + IO.Combo.Input("model", options=["wan2.6-r2v"]), + IO.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Prompt describing the elements and visual features. Supports English and Chinese. 
" + "Use identifiers such as `character1` and `character2` to refer to the reference characters.", + ), + IO.String.Input( + "negative_prompt", + multiline=True, + default="", + tooltip="Negative prompt describing what to avoid.", + ), + IO.Autogrow.Input( + "reference_videos", + template=IO.Autogrow.TemplateNames( + IO.Video.Input("reference_video"), + names=["character1", "character2", "character3"], + min=1, + ), + ), + IO.Combo.Input( + "size", + options=[ + "720p: 1:1 (960x960)", + "720p: 16:9 (1280x720)", + "720p: 9:16 (720x1280)", + "720p: 4:3 (1088x832)", + "720p: 3:4 (832x1088)", + "1080p: 1:1 (1440x1440)", + "1080p: 16:9 (1920x1080)", + "1080p: 9:16 (1080x1920)", + "1080p: 4:3 (1632x1248)", + "1080p: 3:4 (1248x1632)", + ], + ), + IO.Int.Input( + "duration", + default=5, + min=5, + max=10, + step=5, + display_mode=IO.NumberDisplay.slider, + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + step=1, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + ), + IO.Combo.Input( + "shot_type", + options=["single", "multi"], + tooltip="Specifies the shot type for the generated video, that is, whether the video is a " + "single continuous shot or multiple shots with cuts.", + ), + IO.Boolean.Input( + "watermark", + default=False, + tooltip="Whether to add an AI-generated watermark to the result.", + ), + ], + outputs=[ + IO.Video.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, + model: str, + prompt: str, + negative_prompt: str, + reference_videos: IO.Autogrow.Type, + size: str, + duration: int, + seed: int, + shot_type: str, + watermark: bool, + ): + reference_video_urls = [] + for i in reference_videos: + validate_video_duration(reference_videos[i], min_duration=2, max_duration=30) + for i in reference_videos: + reference_video_urls.append(await upload_video_to_comfyapi(cls, reference_videos[i])) + width, height = RES_IN_PARENS.search(size).groups() + initial_response = await sync_op( + cls, + ApiEndpoint(path="/proxy/wan/api/v1/services/aigc/video-generation/video-synthesis", method="POST"), + response_model=TaskCreationResponse, + data=Reference2VideoTaskCreationRequest( + model=model, + input=Reference2VideoInputField( + prompt=prompt, negative_prompt=negative_prompt, reference_video_urls=reference_video_urls + ), + parameters=Reference2VideoParametersField( + size=f"{width}*{height}", + duration=duration, + shot_type=shot_type, + watermark=watermark, + seed=seed, + ), + ), + ) + if not initial_response.output: + raise Exception(f"An unknown error occurred: {initial_response.code} - {initial_response.message}") + response = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/wan/api/v1/tasks/{initial_response.output.task_id}"), + response_model=VideoTaskStatusResponse, + status_extractor=lambda x: x.output.task_status, + poll_interval=6, + max_poll_attempts=280, + ) + return IO.NodeOutput(await download_url_to_video_output(response.output.video_url)) + + class WanApiExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[IO.ComfyNode]]: @@ -729,6 +888,7 @@ class WanApiExtension(ComfyExtension): WanImageToImageApi, WanTextToVideoApi, WanImageToVideoApi, + WanReferenceVideoApi, ] diff --git a/comfy_api_nodes/util/upload_helpers.py b/comfy_api_nodes/util/upload_helpers.py index b8d33f4d1..f1ed7fe9c 100644 --- a/comfy_api_nodes/util/upload_helpers.py +++ 
b/comfy_api_nodes/util/upload_helpers.py @@ -119,7 +119,7 @@ async def upload_video_to_comfyapi( raise ValueError(f"Could not verify video duration from source: {e}") from e upload_mime_type = f"video/{container.value.lower()}" - filename = f"uploaded_video.{container.value.lower()}" + filename = f"{uuid.uuid4()}.{container.value.lower()}" # Convert VideoInput to BytesIO using specified container/codec video_bytes_io = BytesIO() From b7d7cc1d496afe3c82279eec74c4d47399aab8ea Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 6 Jan 2026 22:39:06 -0800 Subject: [PATCH 073/181] Fix fp8 fast issue. (#11688) --- comfy/ops.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/comfy/ops.py b/comfy/ops.py index 8f9fdce36..cd536e22d 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -427,12 +427,12 @@ def fp8_linear(self, input): input = torch.clamp(input, min=-448, max=448, out=input) input_fp8 = input.to(dtype).contiguous() layout_params_input = TensorCoreFP8Layout.Params(scale=scale_input, orig_dtype=input_dtype, orig_shape=tuple(input_fp8.shape)) - quantized_input = QuantizedTensor(input_fp8, TensorCoreFP8Layout, layout_params_input) + quantized_input = QuantizedTensor(input_fp8, "TensorCoreFP8Layout", layout_params_input) # Wrap weight in QuantizedTensor - this enables unified dispatch # Call F.linear - __torch_dispatch__ routes to fp8_linear handler in quant_ops.py! layout_params_weight = TensorCoreFP8Layout.Params(scale=scale_weight, orig_dtype=input_dtype, orig_shape=tuple(w.shape)) - quantized_weight = QuantizedTensor(w, TensorCoreFP8Layout, layout_params_weight) + quantized_weight = QuantizedTensor(w, "TensorCoreFP8Layout", layout_params_weight) o = torch.nn.functional.linear(quantized_input, quantized_weight, bias) uncast_bias_weight(self, w, bias, offload_stream) From fc0cb10bcbee6e73ed3caf34c27f7bde4559a07f Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 7 Jan 2026 04:07:31 -0500 Subject: [PATCH 074/181] ComfyUI v0.8.0 --- comfyui_version.py | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/comfyui_version.py b/comfyui_version.py index 1ed60fe5c..750673f08 100644 --- a/comfyui_version.py +++ b/comfyui_version.py @@ -1,3 +1,3 @@ # This file is automatically generated by the build process when version is # updated in pyproject.toml. -__version__ = "0.7.0" +__version__ = "0.8.0" diff --git a/pyproject.toml b/pyproject.toml index a7d159be9..951c2c978 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ComfyUI" -version = "0.7.0" +version = "0.8.0" readme = "README.md" license = { file = "LICENSE" } requires-python = ">=3.10" From c0c9720d77774ed2c87981da87189fe1c14a57fa Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 7 Jan 2026 01:48:28 -0800 Subject: [PATCH 075/181] Fix stable release workflow not pulling latest comfy kitchen. 
(#11695) --- .github/workflows/stable-release.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/stable-release.yml b/.github/workflows/stable-release.yml index 28484a9d1..f501b7b31 100644 --- a/.github/workflows/stable-release.yml +++ b/.github/workflows/stable-release.yml @@ -117,7 +117,7 @@ jobs: ./python.exe get-pip.py ./python.exe -s -m pip install ../${{ inputs.cache_tag }}_python_deps/* - grep comfyui ../ComfyUI/requirements.txt > ./requirements_comfyui.txt + grep comfy ../ComfyUI/requirements.txt > ./requirements_comfyui.txt ./python.exe -s -m pip install -r requirements_comfyui.txt rm requirements_comfyui.txt From 3cd7b32f1b7e7e90395cefe7d9f9b1f89276d8ce Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 7 Jan 2026 02:15:14 -0800 Subject: [PATCH 076/181] Support gemma 12B with quant weights. (#11696) --- comfy/text_encoders/lt.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/comfy/text_encoders/lt.py b/comfy/text_encoders/lt.py index e5964e42b..130ebaeae 100644 --- a/comfy/text_encoders/lt.py +++ b/comfy/text_encoders/lt.py @@ -36,10 +36,10 @@ class LTXAVGemmaTokenizer(sd1_clip.SD1Tokenizer): class Gemma3_12BModel(sd1_clip.SDClipModel): def __init__(self, device="cpu", layer="all", layer_idx=None, dtype=None, attention_mask=True, model_options={}): - llama_scaled_fp8 = model_options.get("gemma_scaled_fp8", None) - if llama_scaled_fp8 is not None: + llama_quantization_metadata = model_options.get("llama_quantization_metadata", None) + if llama_quantization_metadata is not None: model_options = model_options.copy() - model_options["scaled_fp8"] = llama_scaled_fp8 + model_options["quantization_metadata"] = llama_quantization_metadata super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma3_12B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options) @@ -119,12 +119,12 @@ class LTXAVTEModel(torch.nn.Module): return self.load_state_dict(sdo, strict=False) -def ltxav_te(dtype_llama=None, llama_scaled_fp8=None): +def ltxav_te(dtype_llama=None, llama_quantization_metadata=None): class LTXAVTEModel_(LTXAVTEModel): def __init__(self, device="cpu", dtype=None, model_options={}): - if llama_scaled_fp8 is not None and "llama_scaled_fp8" not in model_options: + if llama_quantization_metadata is not None: model_options = model_options.copy() - model_options["llama_scaled_fp8"] = llama_scaled_fp8 + model_options["llama_quantization_metadata"] = llama_quantization_metadata if dtype_llama is not None: dtype = dtype_llama super().__init__(dtype_llama=dtype_llama, device=device, dtype=dtype, model_options=model_options) From 48e5ea1dfd23a9cb5d118d7af661b026d66743bc Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Wed, 7 Jan 2026 15:39:20 -0800 Subject: [PATCH 077/181] model_patcher: Remove confusing load stat (#11710) If the loader passes 1e32 as the usable memory size, it means force the full load. This happens with CPU loads and a few other misc cases. Removing the confusing number and just leave the other details. 
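As a quick illustration of the sentinel convention described above (an editor's sketch, not part of the patch: the 1e32 threshold and the "MB usable" wording are taken from the model_patcher.py hunk that follows, while the helper name and sample values are made up):

    # Sketch only: a usable-memory budget of >= 1e32 is the "force a full
    # load" sentinel, so the figure is only worth printing when a real
    # budget was passed in.
    def usable_stat(lowvram_model_memory: float) -> str:
        if lowvram_model_memory < 1e32:
            return "{:.2f} MB usable,".format(lowvram_model_memory / (1024 * 1024))
        return ""  # sentinel: full load was forced, no meaningful figure

    print(usable_stat(8 * 1024 * 1024 * 1024))  # "8192.00 MB usable," for a real budget
    print(usable_stat(1e32))                    # "" when a full load was forced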
--- comfy/model_patcher.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 93d26c690..4528814ad 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -790,11 +790,12 @@ class ModelPatcher: for param in params: self.pin_weight_to_device("{}.{}".format(n, param)) + usable_stat = "{:.2f} MB usable,".format(lowvram_model_memory / (1024 * 1024)) if lowvram_model_memory < 1e32 else "" if lowvram_counter > 0: - logging.info("loaded partially; {:.2f} MB usable, {:.2f} MB loaded, {:.2f} MB offloaded, {:.2f} MB buffer reserved, lowvram patches: {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), lowvram_mem_counter / (1024 * 1024), offload_buffer / (1024 * 1024), patch_counter)) + logging.info("loaded partially; {} {:.2f} MB loaded, {:.2f} MB offloaded, {:.2f} MB buffer reserved, lowvram patches: {}".format(usable_stat, mem_counter / (1024 * 1024), lowvram_mem_counter / (1024 * 1024), offload_buffer / (1024 * 1024), patch_counter)) self.model.model_lowvram = True else: - logging.info("loaded completely; {:.2f} MB usable, {:.2f} MB loaded, full load: {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), full_load)) + logging.info("loaded completely; {} {:.2f} MB loaded, full load: {}".format(usable_stat, mem_counter / (1024 * 1024), full_load)) self.model.model_lowvram = False if full_load: self.model.to(device_to) From 1c705f7bfb0fb59f6213dfb85ec5d5dc2ce4300e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jukka=20Sepp=C3=A4nen?= <40791699+kijai@users.noreply.github.com> Date: Thu, 8 Jan 2026 01:39:59 +0200 Subject: [PATCH 078/181] Add device selection for LTXAVTextEncoderLoader (#11700) --- comfy_extras/nodes_lt_audio.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/comfy_extras/nodes_lt_audio.py b/comfy_extras/nodes_lt_audio.py index 26b0160d2..1966fd1bf 100644 --- a/comfy_extras/nodes_lt_audio.py +++ b/comfy_extras/nodes_lt_audio.py @@ -185,6 +185,10 @@ class LTXAVTextEncoderLoader(io.ComfyNode): io.Combo.Input( "ckpt_name", options=folder_paths.get_filename_list("checkpoints"), + ), + io.Combo.Input( + "device", + options=["default", "cpu"], ) ], outputs=[io.Clip.Output()], @@ -197,7 +201,11 @@ class LTXAVTextEncoderLoader(io.ComfyNode): clip_path1 = folder_paths.get_full_path_or_raise("text_encoders", text_encoder) clip_path2 = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name) - clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type) + model_options = {} + if device == "cpu": + model_options["load_device"] = model_options["offload_device"] = torch.device("cpu") + + clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type, model_options=model_options) return io.NodeOutput(clip) From 34751fe9f9ade0c715768202c19211dc0c72e760 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 7 Jan 2026 16:12:15 -0800 Subject: [PATCH 079/181] Lower ltxv text encoder vram use. 
(#11713) --- comfy/text_encoders/lt.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/comfy/text_encoders/lt.py b/comfy/text_encoders/lt.py index 130ebaeae..dc0694e0e 100644 --- a/comfy/text_encoders/lt.py +++ b/comfy/text_encoders/lt.py @@ -98,10 +98,13 @@ class LTXAVTEModel(torch.nn.Module): out, pooled, extra = self.gemma3_12b.encode_token_weights(token_weight_pairs) out_device = out.device + if comfy.model_management.should_use_bf16(self.execution_device): + out = out.to(device=self.execution_device, dtype=torch.bfloat16) out = out.movedim(1, -1).to(self.execution_device) out = 8.0 * (out - out.mean(dim=(1, 2), keepdim=True)) / (out.amax(dim=(1, 2), keepdim=True) - out.amin(dim=(1, 2), keepdim=True) + 1e-6) out = out.reshape((out.shape[0], out.shape[1], -1)) out = self.text_embedding_projection(out) + out = out.float() out_vid = self.video_embeddings_connector(out)[0] out_audio = self.audio_embeddings_connector(out)[0] out = torch.concat((out_vid, out_audio), dim=-1) From 007b87e7ac29e55ce0ad2c436f5ae68f3a078080 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 7 Jan 2026 16:48:47 -0800 Subject: [PATCH 080/181] Bump required comfy-kitchen version. (#11714) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index bc8346bcf..13e95afa0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -21,7 +21,7 @@ psutil alembic SQLAlchemy av>=14.2.0 -comfy-kitchen>=0.2.3 +comfy-kitchen>=0.2.5 #non essential dependencies: kornia>=0.7.1 From 3cd19e99c10a25cf6e6b51b82e3c16c501733b8c Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 7 Jan 2026 17:04:56 -0800 Subject: [PATCH 081/181] Increase ltxav mem estimation by a bit. (#11715) --- comfy/supported_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/supported_models.py b/comfy/supported_models.py index ee9a79001..d44c0bc37 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -845,7 +845,7 @@ class LTXAV(LTXV): def __init__(self, unet_config): super().__init__(unet_config) - self.memory_usage_factor = 0.055 # TODO + self.memory_usage_factor = 0.061 # TODO def get_model(self, state_dict, prefix="", device=None): out = model_base.LTXAV(self, device=device) From 25bc1b5b57d61930d6ab60d8cf7e9241d26e4fe9 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 7 Jan 2026 17:11:22 -0800 Subject: [PATCH 082/181] Add memory estimation function to ltxav text encoder. 
(#11716)

---
 comfy/sd.py               | 11 +++++++----
 comfy/text_encoders/lt.py |  8 ++++++++
 2 files changed, 15 insertions(+), 4 deletions(-)

diff --git a/comfy/sd.py b/comfy/sd.py
index 32157e18b..efde3839c 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -218,7 +218,7 @@ class CLIP:
         if unprojected:
             self.cond_stage_model.set_clip_options({"projected_pooled": False})
 
-        self.load_model()
+        self.load_model(tokens)
         self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device})
         all_hooks.reset()
         self.patcher.patch_hooks(None)
@@ -266,7 +266,7 @@ class CLIP:
         if return_pooled == "unprojected":
             self.cond_stage_model.set_clip_options({"projected_pooled": False})
 
-        self.load_model()
+        self.load_model(tokens)
         self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device})
         o = self.cond_stage_model.encode_token_weights(tokens)
         cond, pooled = o[:2]
@@ -299,8 +299,11 @@ class CLIP:
                 sd_clip[k] = sd_tokenizer[k]
         return sd_clip
 
-    def load_model(self):
-        model_management.load_model_gpu(self.patcher)
+    def load_model(self, tokens={}):
+        memory_used = 0
+        if hasattr(self.cond_stage_model, "memory_estimation_function"):
+            memory_used = self.cond_stage_model.memory_estimation_function(tokens, device=self.patcher.load_device)
+        model_management.load_models_gpu([self.patcher], memory_required=memory_used)
         return self.patcher
 
     def get_key_patches(self):
diff --git a/comfy/text_encoders/lt.py b/comfy/text_encoders/lt.py
index dc0694e0e..776e25e97 100644
--- a/comfy/text_encoders/lt.py
+++ b/comfy/text_encoders/lt.py
@@ -121,6 +121,14 @@ class LTXAVTEModel(torch.nn.Module):
 
         return self.load_state_dict(sdo, strict=False)
 
+    def memory_estimation_function(self, token_weight_pairs, device=None):
+        constant = 6.0
+        if comfy.model_management.should_use_bf16(device):
+            constant /= 2.0
+
+        token_weight_pairs = token_weight_pairs.get("gemma3_12b", [])
+        num_tokens = sum(map(lambda a: len(a), token_weight_pairs))
+        return num_tokens * constant * 1024 * 1024
 
 def ltxav_te(dtype_llama=None, llama_quantization_metadata=None):
     class LTXAVTEModel_(LTXAVTEModel):

From b6c79a648a013f477f514f61580d1a06220b15eb Mon Sep 17 00:00:00 2001
From: rattus <46076784+rattus128@users.noreply.github.com>
Date: Wed, 7 Jan 2026 18:01:16 -0800
Subject: [PATCH 083/181] ops: Fix offloading with FP8MM performance (#11697)

This logic was checking comfy_cast_weights and going straight to the
forward_comfy_cast_weights implementation without attempting to
downscale the input to fp8 when comfy_cast_weights is set. The main
reason comfy_cast_weights would be set is async offload, which is not a
good reason to nix FP8MM.

So instead, AND together the underlying exclusions for FP8MM, which are:

* having a weight_function (usually LowVramPatch)
* force_cast_weights (compute dtype override)
* the weight is not Quantized
* the input is already quantized
* the model or layer has MM explicitly disabled.

If you get past all of those exclusions, quantize the input tensor.
Then hand the new input, quantized or not, off to
forward_comfy_cast_weights to handle it. If the weight is offloaded but
the input is quantized you will get an offloaded MM8.
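Read as one combined guard, the exclusion list above amounts to something like the sketch below (editor's illustration, not the literal patch; the attribute names layout_type, _full_precision_mm, comfy_force_cast_weights, weight_function and bias_function, and the QuantizedTensor import, mirror the comfy/ops.py and comfy/quant_ops.py code in this series, while the helper name is made up):

    # Sketch only: the FP8MM exclusions ANDed into a single check.
    from comfy.quant_ops import QuantizedTensor

    def can_quantize_input(layer, input_tensor) -> bool:
        return (
            getattr(layer, "layout_type", None) is not None             # weight is actually quantized
            and not isinstance(input_tensor, QuantizedTensor)           # input not already quantized
            and not layer._full_precision_mm                            # MM not disabled for this model/layer
            and not getattr(layer, "comfy_force_cast_weights", False)   # no compute dtype override
            and len(layer.weight_function) == 0                         # no pending weight patches (e.g. LowVramPatch)
            and len(layer.bias_function) == 0
        )

    # When the guard passes, the input is quantized and handed to
    # forward_comfy_cast_weights(), so an offloaded weight can still take the
    # FP8 matmul path.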
--- comfy/model_patcher.py | 1 + comfy/ops.py | 30 +++++++++++++++--------------- 2 files changed, 16 insertions(+), 15 deletions(-) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 4528814ad..f6b80a40f 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -718,6 +718,7 @@ class ModelPatcher: continue cast_weight = self.force_cast_weights + m.comfy_force_cast_weights = self.force_cast_weights if lowvram_weight: if hasattr(m, "comfy_cast_weights"): m.weight_function = [] diff --git a/comfy/ops.py b/comfy/ops.py index cd536e22d..8156c42ff 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -654,29 +654,29 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec run_every_op() input_shape = input.shape - tensor_3d = input.ndim == 3 - - if self._full_precision_mm or self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0: - return self.forward_comfy_cast_weights(input, *args, **kwargs) + reshaped_3d = False if (getattr(self, 'layout_type', None) is not None and - not isinstance(input, QuantizedTensor)): + not isinstance(input, QuantizedTensor) and not self._full_precision_mm and + not getattr(self, 'comfy_force_cast_weights', False) and + len(self.weight_function) == 0 and len(self.bias_function) == 0): # Reshape 3D tensors to 2D for quantization (needed for NVFP4 and others) - if tensor_3d: - input = input.reshape(-1, input_shape[2]) + input_reshaped = input.reshape(-1, input_shape[2]) if input.ndim == 3 else input - if input.ndim != 2: - # Fall back to comfy_cast_weights for non-2D tensors - return self.forward_comfy_cast_weights(input.reshape(input_shape), *args, **kwargs) + # Fall back to non-quantized for non-2D tensors + if input_reshaped.ndim == 2: + reshaped_3d = input.ndim == 3 + # dtype is now implicit in the layout class + scale = getattr(self, 'input_scale', None) + if scale is not None: + scale = comfy.model_management.cast_to_device(scale, input.device, None) + input = QuantizedTensor.from_float(input_reshaped, self.layout_type, scale=scale) - # dtype is now implicit in the layout class - input = QuantizedTensor.from_float(input, self.layout_type, scale=getattr(self, 'input_scale', None)) - - output = self._forward(input, self.weight, self.bias) + output = self.forward_comfy_cast_weights(input) # Reshape output back to 3D if input was 3D - if tensor_3d: + if reshaped_3d: output = output.reshape((input_shape[0], input_shape[1], self.weight.shape[0])) return output From 21e842508733809354a7b04944b2995ed1169370 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 7 Jan 2026 18:07:26 -0800 Subject: [PATCH 084/181] Add warning for old pytorch. 
(#11718) --- comfy/quant_ops.py | 1 + 1 file changed, 1 insertion(+) diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py index 5a17bc6f5..8324be42a 100644 --- a/comfy/quant_ops.py +++ b/comfy/quant_ops.py @@ -19,6 +19,7 @@ try: cuda_version = tuple(map(int, str(torch.version.cuda).split('.'))) if cuda_version < (13,): ck.registry.disable("cuda") + logging.warning("WARNING: You need pytorch with cu130 or higher to use optimized CUDA operations.") ck.registry.disable("triton") for k, v in ck.list_backends().items(): From fcd9a236b091bd4e77b177134ddfcf7d7dbd71fd Mon Sep 17 00:00:00 2001 From: ComfyUI Wiki Date: Thu, 8 Jan 2026 10:22:23 +0800 Subject: [PATCH 085/181] Update template to 0.7.69 (#11719) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 13e95afa0..49567ad61 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.35.9 -comfyui-workflow-templates==0.7.67 +comfyui-workflow-templates==0.7.69 comfyui-embedded-docs==0.3.1 torch torchsde From ac12f77bed7bbbaf20289533bf7c0bff275e4a41 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 7 Jan 2026 22:10:08 -0500 Subject: [PATCH 086/181] ComfyUI version v0.8.1 --- comfyui_version.py | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/comfyui_version.py b/comfyui_version.py index 750673f08..4eb6070fe 100644 --- a/comfyui_version.py +++ b/comfyui_version.py @@ -1,3 +1,3 @@ # This file is automatically generated by the build process when version is # updated in pyproject.toml. -__version__ = "0.8.0" +__version__ = "0.8.1" diff --git a/pyproject.toml b/pyproject.toml index 951c2c978..0037abd6c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ComfyUI" -version = "0.8.0" +version = "0.8.1" readme = "README.md" license = { file = "LICENSE" } requires-python = ">=3.10" From 50d6e1caf401bf72dca1e9df7e194e722e1bd98b Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 7 Jan 2026 20:07:05 -0800 Subject: [PATCH 087/181] Tweak ltxv vae mem estimation. 
(#11722) --- comfy/sd.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/comfy/sd.py b/comfy/sd.py index efde3839c..5a7221620 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -479,8 +479,8 @@ class VAE: self.first_stage_model = comfy.ldm.lightricks.vae.causal_video_autoencoder.VideoVAE(version=version, config=vae_config) self.latent_channels = 128 self.latent_dim = 3 - self.memory_used_decode = lambda shape, dtype: (900 * shape[2] * shape[3] * shape[4] * (8 * 8 * 8)) * model_management.dtype_size(dtype) - self.memory_used_encode = lambda shape, dtype: (70 * max(shape[2], 7) * shape[3] * shape[4]) * model_management.dtype_size(dtype) + self.memory_used_decode = lambda shape, dtype: (1200 * shape[2] * shape[3] * shape[4] * (8 * 8 * 8)) * model_management.dtype_size(dtype) + self.memory_used_encode = lambda shape, dtype: (80 * max(shape[2], 7) * shape[3] * shape[4]) * model_management.dtype_size(dtype) self.upscale_ratio = (lambda a: max(0, a * 8 - 7), 32, 32) self.upscale_index_formula = (8, 32, 32) self.downscale_ratio = (lambda a: max(0, math.floor((a + 7) / 8)), 32, 32) From 2e9d51680a90bca9cc375ba7767f7bf3ed27d563 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 7 Jan 2026 23:50:02 -0500 Subject: [PATCH 088/181] ComfyUI version v0.8.2 --- comfyui_version.py | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/comfyui_version.py b/comfyui_version.py index 4eb6070fe..df82ed4fc 100644 --- a/comfyui_version.py +++ b/comfyui_version.py @@ -1,3 +1,3 @@ # This file is automatically generated by the build process when version is # updated in pyproject.toml. -__version__ = "0.8.1" +__version__ = "0.8.2" diff --git a/pyproject.toml b/pyproject.toml index 0037abd6c..49f1a03fd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ComfyUI" -version = "0.8.1" +version = "0.8.2" readme = "README.md" license = { file = "LICENSE" } requires-python = ">=3.10" From a60b7b86c54ea1498e9c5a5c3d6018c0714654d9 Mon Sep 17 00:00:00 2001 From: Yoland Yan <4950057+yoland68@users.noreply.github.com> Date: Wed, 7 Jan 2026 21:41:57 -0800 Subject: [PATCH 089/181] Revert "Force sequential execution in CI test jobs (#11687)" (#11725) This reverts commit ce0000c4f2a7dba12324585dddb784b43e3cd3d0. 
--- .github/workflows/test-ci.yml | 2 -- 1 file changed, 2 deletions(-) diff --git a/.github/workflows/test-ci.yml b/.github/workflows/test-ci.yml index 63df2dc3a..adfc5dd32 100644 --- a/.github/workflows/test-ci.yml +++ b/.github/workflows/test-ci.yml @@ -20,7 +20,6 @@ jobs: test-stable: strategy: fail-fast: false - max-parallel: 1 # This forces sequential execution matrix: # os: [macos, linux, windows] # os: [macos, linux] @@ -75,7 +74,6 @@ jobs: test-unix-nightly: strategy: fail-fast: false - max-parallel: 1 # This forces sequential execution matrix: # os: [macos, linux] os: [linux] From 5943fbf457d78becbb924a74780e0efc68505a17 Mon Sep 17 00:00:00 2001 From: "Dr.Lt.Data" <128333288+ltdrdata@users.noreply.github.com> Date: Fri, 9 Jan 2026 01:15:42 +0900 Subject: [PATCH 090/181] bump comfyui_manager version to the 4.0.5 (#11732) --- manager_requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/manager_requirements.txt b/manager_requirements.txt index 6585b0c19..bea6d4927 100644 --- a/manager_requirements.txt +++ b/manager_requirements.txt @@ -1 +1 @@ -comfyui_manager==4.0.4 +comfyui_manager==4.0.5 From 0f11869d55c7a459371b8114b1345e55a0274723 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Thu, 8 Jan 2026 14:16:58 -0800 Subject: [PATCH 091/181] Better detection if AMD torch compiled with efficient attention. (#11745) --- comfy/model_management.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 928282092..e5de4a5b5 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -22,7 +22,6 @@ from enum import Enum from comfy.cli_args import args, PerformanceFeature import torch import sys -import importlib import platform import weakref import gc @@ -349,10 +348,22 @@ try: except: rocm_version = (6, -1) + def aotriton_supported(gpu_arch): + path = torch.__path__[0] + path = os.path.join(os.path.join(path, "lib"), "aotriton.images") + gfx = set(map(lambda a: a[4:], filter(lambda a: a.startswith("amd-gfx"), os.listdir(path)))) + if gpu_arch in gfx: + return True + if "{}x".format(gpu_arch[:-1]) in gfx: + return True + if "{}xx".format(gpu_arch[:-2]) in gfx: + return True + return False + logging.info("AMD arch: {}".format(arch)) logging.info("ROCm version: {}".format(rocm_version)) if args.use_split_cross_attention == False and args.use_quad_cross_attention == False: - if importlib.util.find_spec('triton') is not None: # AMD efficient attention implementation depends on triton. TODO: better way of detecting if it's compiled in or not. + if aotriton_supported(arch): # AMD efficient attention implementation depends on aotriton. if torch_version_numeric >= (2, 7): # works on 2.6 but doesn't actually seem to improve much if any((a in arch) for a in ["gfx90a", "gfx942", "gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches, TODO: gfx950 ENABLE_PYTORCH_ATTENTION = True From 1a206564487d672561d83ce3eb007517bf018995 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Thu, 8 Jan 2026 14:23:59 -0800 Subject: [PATCH 092/181] Fix import issue. 
(#11746) --- comfy/ldm/hunyuan_video/upsampler.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/comfy/ldm/hunyuan_video/upsampler.py b/comfy/ldm/hunyuan_video/upsampler.py index d9e76922f..51b6d1da8 100644 --- a/comfy/ldm/hunyuan_video/upsampler.py +++ b/comfy/ldm/hunyuan_video/upsampler.py @@ -3,8 +3,8 @@ import torch.nn as nn import torch.nn.functional as F from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, VideoConv3d from comfy.ldm.hunyuan_video.vae_refiner import RMS_norm -import model_management -import model_patcher +import comfy.model_management +import comfy.model_patcher class SRResidualCausalBlock3D(nn.Module): def __init__(self, channels: int): @@ -103,13 +103,13 @@ UPSAMPLERS = { class HunyuanVideo15SRModel(): def __init__(self, model_type, config): - self.load_device = model_management.vae_device() - offload_device = model_management.vae_offload_device() - self.dtype = model_management.vae_dtype(self.load_device) + self.load_device = comfy.model_management.vae_device() + offload_device = comfy.model_management.vae_offload_device() + self.dtype = comfy.model_management.vae_dtype(self.load_device) self.model_class = UPSAMPLERS.get(model_type) self.model = self.model_class(**config).eval() - self.patcher = model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device) + self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device) def load_sd(self, sd): return self.model.load_state_dict(sd, strict=True) @@ -118,5 +118,5 @@ class HunyuanVideo15SRModel(): return self.model.state_dict() def resample_latent(self, latent): - model_management.load_model_gpu(self.patcher) + comfy.model_management.load_model_gpu(self.patcher) return self.model(latent.to(self.load_device)) From 027042db6811c875562296f0a6b797c89d59e426 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jukka=20Sepp=C3=A4nen?= <40791699+kijai@users.noreply.github.com> Date: Fri, 9 Jan 2026 05:14:06 +0200 Subject: [PATCH 093/181] Add node: JoinAudioChannels (#11728) --- comfy_extras/nodes_audio.py | 53 +++++++++++++++++++++++++++++++++++++ 1 file changed, 53 insertions(+) diff --git a/comfy_extras/nodes_audio.py b/comfy_extras/nodes_audio.py index 94ad5e8a8..15b3aa401 100644 --- a/comfy_extras/nodes_audio.py +++ b/comfy_extras/nodes_audio.py @@ -399,6 +399,58 @@ class SplitAudioChannels(IO.ComfyNode): separate = execute # TODO: remove +class JoinAudioChannels(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="JoinAudioChannels", + display_name="Join Audio Channels", + description="Joins left and right mono audio channels into a stereo audio.", + category="audio", + inputs=[ + IO.Audio.Input("audio_left"), + IO.Audio.Input("audio_right"), + ], + outputs=[ + IO.Audio.Output(display_name="audio"), + ], + ) + + @classmethod + def execute(cls, audio_left, audio_right) -> IO.NodeOutput: + waveform_left = audio_left["waveform"] + sample_rate_left = audio_left["sample_rate"] + waveform_right = audio_right["waveform"] + sample_rate_right = audio_right["sample_rate"] + + if waveform_left.shape[1] != 1 or waveform_right.shape[1] != 1: + raise ValueError("AudioJoin: Both input audios must be mono.") + + # Handle different sample rates by resampling to the higher rate + waveform_left, waveform_right, output_sample_rate = match_audio_sample_rates( + waveform_left, sample_rate_left, waveform_right, sample_rate_right + ) + + # Handle different lengths by trimming to the 
shorter length + length_left = waveform_left.shape[-1] + length_right = waveform_right.shape[-1] + + if length_left != length_right: + min_length = min(length_left, length_right) + if length_left > min_length: + logging.info(f"JoinAudioChannels: Trimming left channel from {length_left} to {min_length} samples.") + waveform_left = waveform_left[..., :min_length] + if length_right > min_length: + logging.info(f"JoinAudioChannels: Trimming right channel from {length_right} to {min_length} samples.") + waveform_right = waveform_right[..., :min_length] + + # Join the channels into stereo + left_channel = waveform_left[..., 0:1, :] + right_channel = waveform_right[..., 0:1, :] + stereo_waveform = torch.cat([left_channel, right_channel], dim=1) + + return IO.NodeOutput({"waveform": stereo_waveform, "sample_rate": output_sample_rate}) + def match_audio_sample_rates(waveform_1, sample_rate_1, waveform_2, sample_rate_2): if sample_rate_1 != sample_rate_2: @@ -616,6 +668,7 @@ class AudioExtension(ComfyExtension): RecordAudio, TrimAudioDuration, SplitAudioChannels, + JoinAudioChannels, AudioConcat, AudioMerge, AudioAdjustVolume, From b48d6a83d4f7012a1b6f6f41e66b0ac3f3253b8a Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Thu, 8 Jan 2026 19:15:50 -0800 Subject: [PATCH 094/181] Fix csp error in frontend when forcing offline. (#11749) --- server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server.py b/server.py index 70c8b5e3b..4db3347cb 100644 --- a/server.py +++ b/server.py @@ -184,7 +184,7 @@ def create_block_external_middleware(): else: response = await handler(request) - response.headers['Content-Security-Policy'] = "default-src 'self'; script-src 'self' 'unsafe-inline' 'unsafe-eval' blob:; style-src 'self' 'unsafe-inline'; img-src 'self' data: blob:; font-src 'self'; connect-src 'self'; frame-src 'self'; object-src 'self';" + response.headers['Content-Security-Policy'] = "default-src 'self'; script-src 'self' 'unsafe-inline' 'unsafe-eval' blob:; style-src 'self' 'unsafe-inline'; img-src 'self' data: blob:; font-src 'self'; connect-src 'self' data:; frame-src 'self'; object-src 'self';" return response return block_external_middleware From 114fc73685129bf4e8ddced432247fe67dc6fbff Mon Sep 17 00:00:00 2001 From: Comfy Org PR Bot Date: Fri, 9 Jan 2026 12:16:15 +0900 Subject: [PATCH 095/181] Bump comfyui-frontend-package to 1.36.13 (#11645) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 49567ad61..7686a5f8a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -comfyui-frontend-package==1.35.9 +comfyui-frontend-package==1.36.13 comfyui-workflow-templates==0.7.69 comfyui-embedded-docs==0.3.1 torch From 1dc3da631423b776669a6a9128bb1aeaf5592c55 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Thu, 8 Jan 2026 19:21:51 -0800 Subject: [PATCH 096/181] Add most basic Asset support for models (#11315) * Brought over minimal elements from PR 10045 to reproduce seed_assets and register_assets_system without adding anything to the DB or server routes yet, for now making everything sync (can introduce async once everything is cleaned up and brought over) * Added db script to insert assets stuff, cleaned up some code; assets (models) now get added/rescanned * Added support for 5 http endpoints for assets * Replaced Optional with | None in schemas_in.py and schemas_out.py * Remove two routes that will not be relevant yet in this PR: HEAD 
/api/assets/hash/ and PUT /api/assets//preview * Remove some functions the two deleted endpoints were using * Don't show assets scan message upon calling /object_info endpoint * removed unsued import to satisfy ruff * Simplified hashing function tpye hint and _hash_file_obj * Satisfied ruff --- alembic_db/versions/0001_assets.py | 174 +++++++++++++++++++ app/assets/api/routes.py | 102 +++++++++++ app/assets/api/schemas_in.py | 94 ++++++++++ app/assets/api/schemas_out.py | 60 +++++++ app/assets/database/bulk_ops.py | 188 ++++++++++++++++++++ app/assets/database/models.py | 233 +++++++++++++++++++++++++ app/assets/database/queries.py | 267 +++++++++++++++++++++++++++++ app/assets/database/tags.py | 62 +++++++ app/assets/hashing.py | 75 ++++++++ app/assets/helpers.py | 216 +++++++++++++++++++++++ app/assets/manager.py | 123 +++++++++++++ app/assets/scanner.py | 229 +++++++++++++++++++++++++ app/database/models.py | 25 ++- comfy/cli_args.py | 1 + main.py | 3 + server.py | 4 + 16 files changed, 1847 insertions(+), 9 deletions(-) create mode 100644 alembic_db/versions/0001_assets.py create mode 100644 app/assets/api/routes.py create mode 100644 app/assets/api/schemas_in.py create mode 100644 app/assets/api/schemas_out.py create mode 100644 app/assets/database/bulk_ops.py create mode 100644 app/assets/database/models.py create mode 100644 app/assets/database/queries.py create mode 100644 app/assets/database/tags.py create mode 100644 app/assets/hashing.py create mode 100644 app/assets/helpers.py create mode 100644 app/assets/manager.py create mode 100644 app/assets/scanner.py diff --git a/alembic_db/versions/0001_assets.py b/alembic_db/versions/0001_assets.py new file mode 100644 index 000000000..1e10b94dc --- /dev/null +++ b/alembic_db/versions/0001_assets.py @@ -0,0 +1,174 @@ +""" +Initial assets schema +Revision ID: 0001_assets +Revises: None +Create Date: 2025-12-10 00:00:00 +""" + +from alembic import op +import sqlalchemy as sa + +revision = "0001_assets" +down_revision = None +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ASSETS: content identity + op.create_table( + "assets", + sa.Column("id", sa.String(length=36), primary_key=True), + sa.Column("hash", sa.String(length=256), nullable=True), + sa.Column("size_bytes", sa.BigInteger(), nullable=False, server_default="0"), + sa.Column("mime_type", sa.String(length=255), nullable=True), + sa.Column("created_at", sa.DateTime(timezone=False), nullable=False), + sa.CheckConstraint("size_bytes >= 0", name="ck_assets_size_nonneg"), + ) + op.create_index("uq_assets_hash", "assets", ["hash"], unique=True) + op.create_index("ix_assets_mime_type", "assets", ["mime_type"]) + + # ASSETS_INFO: user-visible references + op.create_table( + "assets_info", + sa.Column("id", sa.String(length=36), primary_key=True), + sa.Column("owner_id", sa.String(length=128), nullable=False, server_default=""), + sa.Column("name", sa.String(length=512), nullable=False), + sa.Column("asset_id", sa.String(length=36), sa.ForeignKey("assets.id", ondelete="RESTRICT"), nullable=False), + sa.Column("preview_id", sa.String(length=36), sa.ForeignKey("assets.id", ondelete="SET NULL"), nullable=True), + sa.Column("user_metadata", sa.JSON(), nullable=True), + sa.Column("created_at", sa.DateTime(timezone=False), nullable=False), + sa.Column("updated_at", sa.DateTime(timezone=False), nullable=False), + sa.Column("last_access_time", sa.DateTime(timezone=False), nullable=False), + sa.UniqueConstraint("asset_id", "owner_id", "name", 
name="uq_assets_info_asset_owner_name"), + ) + op.create_index("ix_assets_info_owner_id", "assets_info", ["owner_id"]) + op.create_index("ix_assets_info_asset_id", "assets_info", ["asset_id"]) + op.create_index("ix_assets_info_name", "assets_info", ["name"]) + op.create_index("ix_assets_info_created_at", "assets_info", ["created_at"]) + op.create_index("ix_assets_info_last_access_time", "assets_info", ["last_access_time"]) + op.create_index("ix_assets_info_owner_name", "assets_info", ["owner_id", "name"]) + + # TAGS: normalized tag vocabulary + op.create_table( + "tags", + sa.Column("name", sa.String(length=512), primary_key=True), + sa.Column("tag_type", sa.String(length=32), nullable=False, server_default="user"), + sa.CheckConstraint("name = lower(name)", name="ck_tags_lowercase"), + ) + op.create_index("ix_tags_tag_type", "tags", ["tag_type"]) + + # ASSET_INFO_TAGS: many-to-many for tags on AssetInfo + op.create_table( + "asset_info_tags", + sa.Column("asset_info_id", sa.String(length=36), sa.ForeignKey("assets_info.id", ondelete="CASCADE"), nullable=False), + sa.Column("tag_name", sa.String(length=512), sa.ForeignKey("tags.name", ondelete="RESTRICT"), nullable=False), + sa.Column("origin", sa.String(length=32), nullable=False, server_default="manual"), + sa.Column("added_at", sa.DateTime(timezone=False), nullable=False), + sa.PrimaryKeyConstraint("asset_info_id", "tag_name", name="pk_asset_info_tags"), + ) + op.create_index("ix_asset_info_tags_tag_name", "asset_info_tags", ["tag_name"]) + op.create_index("ix_asset_info_tags_asset_info_id", "asset_info_tags", ["asset_info_id"]) + + # ASSET_CACHE_STATE: N:1 local cache rows per Asset + op.create_table( + "asset_cache_state", + sa.Column("id", sa.Integer(), primary_key=True, autoincrement=True), + sa.Column("asset_id", sa.String(length=36), sa.ForeignKey("assets.id", ondelete="CASCADE"), nullable=False), + sa.Column("file_path", sa.Text(), nullable=False), # absolute local path to cached file + sa.Column("mtime_ns", sa.BigInteger(), nullable=True), + sa.Column("needs_verify", sa.Boolean(), nullable=False, server_default=sa.text("false")), + sa.CheckConstraint("(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_acs_mtime_nonneg"), + sa.UniqueConstraint("file_path", name="uq_asset_cache_state_file_path"), + ) + op.create_index("ix_asset_cache_state_file_path", "asset_cache_state", ["file_path"]) + op.create_index("ix_asset_cache_state_asset_id", "asset_cache_state", ["asset_id"]) + + # ASSET_INFO_META: typed KV projection of user_metadata for filtering/sorting + op.create_table( + "asset_info_meta", + sa.Column("asset_info_id", sa.String(length=36), sa.ForeignKey("assets_info.id", ondelete="CASCADE"), nullable=False), + sa.Column("key", sa.String(length=256), nullable=False), + sa.Column("ordinal", sa.Integer(), nullable=False, server_default="0"), + sa.Column("val_str", sa.String(length=2048), nullable=True), + sa.Column("val_num", sa.Numeric(38, 10), nullable=True), + sa.Column("val_bool", sa.Boolean(), nullable=True), + sa.Column("val_json", sa.JSON(), nullable=True), + sa.PrimaryKeyConstraint("asset_info_id", "key", "ordinal", name="pk_asset_info_meta"), + ) + op.create_index("ix_asset_info_meta_key", "asset_info_meta", ["key"]) + op.create_index("ix_asset_info_meta_key_val_str", "asset_info_meta", ["key", "val_str"]) + op.create_index("ix_asset_info_meta_key_val_num", "asset_info_meta", ["key", "val_num"]) + op.create_index("ix_asset_info_meta_key_val_bool", "asset_info_meta", ["key", "val_bool"]) + + # Tags vocabulary + tags_table = 
sa.table( + "tags", + sa.column("name", sa.String(length=512)), + sa.column("tag_type", sa.String()), + ) + op.bulk_insert( + tags_table, + [ + {"name": "models", "tag_type": "system"}, + {"name": "input", "tag_type": "system"}, + {"name": "output", "tag_type": "system"}, + + {"name": "configs", "tag_type": "system"}, + {"name": "checkpoints", "tag_type": "system"}, + {"name": "loras", "tag_type": "system"}, + {"name": "vae", "tag_type": "system"}, + {"name": "text_encoders", "tag_type": "system"}, + {"name": "diffusion_models", "tag_type": "system"}, + {"name": "clip_vision", "tag_type": "system"}, + {"name": "style_models", "tag_type": "system"}, + {"name": "embeddings", "tag_type": "system"}, + {"name": "diffusers", "tag_type": "system"}, + {"name": "vae_approx", "tag_type": "system"}, + {"name": "controlnet", "tag_type": "system"}, + {"name": "gligen", "tag_type": "system"}, + {"name": "upscale_models", "tag_type": "system"}, + {"name": "hypernetworks", "tag_type": "system"}, + {"name": "photomaker", "tag_type": "system"}, + {"name": "classifiers", "tag_type": "system"}, + + {"name": "encoder", "tag_type": "system"}, + {"name": "decoder", "tag_type": "system"}, + + {"name": "missing", "tag_type": "system"}, + {"name": "rescan", "tag_type": "system"}, + ], + ) + + +def downgrade() -> None: + op.drop_index("ix_asset_info_meta_key_val_bool", table_name="asset_info_meta") + op.drop_index("ix_asset_info_meta_key_val_num", table_name="asset_info_meta") + op.drop_index("ix_asset_info_meta_key_val_str", table_name="asset_info_meta") + op.drop_index("ix_asset_info_meta_key", table_name="asset_info_meta") + op.drop_table("asset_info_meta") + + op.drop_index("ix_asset_cache_state_asset_id", table_name="asset_cache_state") + op.drop_index("ix_asset_cache_state_file_path", table_name="asset_cache_state") + op.drop_constraint("uq_asset_cache_state_file_path", table_name="asset_cache_state") + op.drop_table("asset_cache_state") + + op.drop_index("ix_asset_info_tags_asset_info_id", table_name="asset_info_tags") + op.drop_index("ix_asset_info_tags_tag_name", table_name="asset_info_tags") + op.drop_table("asset_info_tags") + + op.drop_index("ix_tags_tag_type", table_name="tags") + op.drop_table("tags") + + op.drop_constraint("uq_assets_info_asset_owner_name", table_name="assets_info") + op.drop_index("ix_assets_info_owner_name", table_name="assets_info") + op.drop_index("ix_assets_info_last_access_time", table_name="assets_info") + op.drop_index("ix_assets_info_created_at", table_name="assets_info") + op.drop_index("ix_assets_info_name", table_name="assets_info") + op.drop_index("ix_assets_info_asset_id", table_name="assets_info") + op.drop_index("ix_assets_info_owner_id", table_name="assets_info") + op.drop_table("assets_info") + + op.drop_index("uq_assets_hash", table_name="assets") + op.drop_index("ix_assets_mime_type", table_name="assets") + op.drop_table("assets") diff --git a/app/assets/api/routes.py b/app/assets/api/routes.py new file mode 100644 index 000000000..30e87a898 --- /dev/null +++ b/app/assets/api/routes.py @@ -0,0 +1,102 @@ +import logging +import uuid +from aiohttp import web + +from pydantic import ValidationError + +import app.assets.manager as manager +from app import user_manager +from app.assets.api import schemas_in +from app.assets.helpers import get_query_dict + +ROUTES = web.RouteTableDef() +USER_MANAGER: user_manager.UserManager | None = None + +# UUID regex (canonical hyphenated form, case-insensitive) +UUID_RE = 
r"[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}" + +def register_assets_system(app: web.Application, user_manager_instance: user_manager.UserManager) -> None: + global USER_MANAGER + USER_MANAGER = user_manager_instance + app.add_routes(ROUTES) + +def _error_response(status: int, code: str, message: str, details: dict | None = None) -> web.Response: + return web.json_response({"error": {"code": code, "message": message, "details": details or {}}}, status=status) + + +def _validation_error_response(code: str, ve: ValidationError) -> web.Response: + return _error_response(400, code, "Validation failed.", {"errors": ve.json()}) + + +@ROUTES.get("/api/assets") +async def list_assets(request: web.Request) -> web.Response: + """ + GET request to list assets. + """ + query_dict = get_query_dict(request) + try: + q = schemas_in.ListAssetsQuery.model_validate(query_dict) + except ValidationError as ve: + return _validation_error_response("INVALID_QUERY", ve) + + payload = manager.list_assets( + include_tags=q.include_tags, + exclude_tags=q.exclude_tags, + name_contains=q.name_contains, + metadata_filter=q.metadata_filter, + limit=q.limit, + offset=q.offset, + sort=q.sort, + order=q.order, + owner_id=USER_MANAGER.get_request_user_id(request), + ) + return web.json_response(payload.model_dump(mode="json")) + + +@ROUTES.get(f"/api/assets/{{id:{UUID_RE}}}") +async def get_asset(request: web.Request) -> web.Response: + """ + GET request to get an asset's info as JSON. + """ + asset_info_id = str(uuid.UUID(request.match_info["id"])) + try: + result = manager.get_asset( + asset_info_id=asset_info_id, + owner_id=USER_MANAGER.get_request_user_id(request), + ) + except ValueError as e: + return _error_response(404, "ASSET_NOT_FOUND", str(e), {"id": asset_info_id}) + except Exception: + logging.exception( + "get_asset failed for asset_info_id=%s, owner_id=%s", + asset_info_id, + USER_MANAGER.get_request_user_id(request), + ) + return _error_response(500, "INTERNAL", "Unexpected server error.") + return web.json_response(result.model_dump(mode="json"), status=200) + + +@ROUTES.get("/api/tags") +async def get_tags(request: web.Request) -> web.Response: + """ + GET request to list all tags based on query parameters. 
+ """ + query_map = dict(request.rel_url.query) + + try: + query = schemas_in.TagsListQuery.model_validate(query_map) + except ValidationError as e: + return web.json_response( + {"error": {"code": "INVALID_QUERY", "message": "Invalid query parameters", "details": e.errors()}}, + status=400, + ) + + result = manager.list_tags( + prefix=query.prefix, + limit=query.limit, + offset=query.offset, + order=query.order, + include_zero=query.include_zero, + owner_id=USER_MANAGER.get_request_user_id(request), + ) + return web.json_response(result.model_dump(mode="json")) diff --git a/app/assets/api/schemas_in.py b/app/assets/api/schemas_in.py new file mode 100644 index 000000000..200b41aef --- /dev/null +++ b/app/assets/api/schemas_in.py @@ -0,0 +1,94 @@ +import json +import uuid +from typing import Any, Literal + +from pydantic import ( + BaseModel, + ConfigDict, + Field, + conint, + field_validator, +) + + +class ListAssetsQuery(BaseModel): + include_tags: list[str] = Field(default_factory=list) + exclude_tags: list[str] = Field(default_factory=list) + name_contains: str | None = None + + # Accept either a JSON string (query param) or a dict + metadata_filter: dict[str, Any] | None = None + + limit: conint(ge=1, le=500) = 20 + offset: conint(ge=0) = 0 + + sort: Literal["name", "created_at", "updated_at", "size", "last_access_time"] = "created_at" + order: Literal["asc", "desc"] = "desc" + + @field_validator("include_tags", "exclude_tags", mode="before") + @classmethod + def _split_csv_tags(cls, v): + # Accept "a,b,c" or ["a","b"] (we are liberal in what we accept) + if v is None: + return [] + if isinstance(v, str): + return [t.strip() for t in v.split(",") if t.strip()] + if isinstance(v, list): + out: list[str] = [] + for item in v: + if isinstance(item, str): + out.extend([t.strip() for t in item.split(",") if t.strip()]) + return out + return v + + @field_validator("metadata_filter", mode="before") + @classmethod + def _parse_metadata_json(cls, v): + if v is None or isinstance(v, dict): + return v + if isinstance(v, str) and v.strip(): + try: + parsed = json.loads(v) + except Exception as e: + raise ValueError(f"metadata_filter must be JSON: {e}") from e + if not isinstance(parsed, dict): + raise ValueError("metadata_filter must be a JSON object") + return parsed + return None + + +class TagsListQuery(BaseModel): + model_config = ConfigDict(extra="ignore", str_strip_whitespace=True) + + prefix: str | None = Field(None, min_length=1, max_length=256) + limit: int = Field(100, ge=1, le=1000) + offset: int = Field(0, ge=0, le=10_000_000) + order: Literal["count_desc", "name_asc"] = "count_desc" + include_zero: bool = True + + @field_validator("prefix") + @classmethod + def normalize_prefix(cls, v: str | None) -> str | None: + if v is None: + return v + v = v.strip() + return v.lower() or None + + +class SetPreviewBody(BaseModel): + """Set or clear the preview for an AssetInfo. 
Provide an Asset.id or null.""" + preview_id: str | None = None + + @field_validator("preview_id", mode="before") + @classmethod + def _norm_uuid(cls, v): + if v is None: + return None + s = str(v).strip() + if not s: + return None + try: + uuid.UUID(s) + except Exception: + raise ValueError("preview_id must be a UUID") + return s diff --git a/app/assets/api/schemas_out.py b/app/assets/api/schemas_out.py new file mode 100644 index 000000000..9f8184f20 --- /dev/null +++ b/app/assets/api/schemas_out.py @@ -0,0 +1,60 @@ +from datetime import datetime +from typing import Any + +from pydantic import BaseModel, ConfigDict, Field, field_serializer + + +class AssetSummary(BaseModel): + id: str + name: str + asset_hash: str | None = None + size: int | None = None + mime_type: str | None = None + tags: list[str] = Field(default_factory=list) + preview_url: str | None = None + created_at: datetime | None = None + updated_at: datetime | None = None + last_access_time: datetime | None = None + + model_config = ConfigDict(from_attributes=True) + + @field_serializer("created_at", "updated_at", "last_access_time") + def _ser_dt(self, v: datetime | None, _info): + return v.isoformat() if v else None + + +class AssetsList(BaseModel): + assets: list[AssetSummary] + total: int + has_more: bool + + +class AssetDetail(BaseModel): + id: str + name: str + asset_hash: str | None = None + size: int | None = None + mime_type: str | None = None + tags: list[str] = Field(default_factory=list) + user_metadata: dict[str, Any] = Field(default_factory=dict) + preview_id: str | None = None + created_at: datetime | None = None + last_access_time: datetime | None = None + + model_config = ConfigDict(from_attributes=True) + + @field_serializer("created_at", "last_access_time") + def _ser_dt(self, v: datetime | None, _info): + return v.isoformat() if v else None + + +class TagUsage(BaseModel): + name: str + count: int + type: str + + +class TagsList(BaseModel): + tags: list[TagUsage] = Field(default_factory=list) + total: int + has_more: bool diff --git a/app/assets/database/bulk_ops.py b/app/assets/database/bulk_ops.py new file mode 100644 index 000000000..9352cd65d --- /dev/null +++ b/app/assets/database/bulk_ops.py @@ -0,0 +1,188 @@ +import os +import uuid +import sqlalchemy +from typing import Iterable +from sqlalchemy.orm import Session +from sqlalchemy.dialects import sqlite + +from app.assets.helpers import utcnow +from app.assets.database.models import Asset, AssetCacheState, AssetInfo, AssetInfoTag, AssetInfoMeta + +MAX_BIND_PARAMS = 800 + +def _chunk_rows(rows: list[dict], cols_per_row: int, max_bind_params: int) -> Iterable[list[dict]]: + if not rows: + return [] + rows_per_stmt = max(1, max_bind_params // max(1, cols_per_row)) + for i in range(0, len(rows), rows_per_stmt): + yield rows[i:i + rows_per_stmt] + +def _iter_chunks(seq, n: int): + for i in range(0, len(seq), n): + yield seq[i:i + n] + +def _rows_per_stmt(cols: int) -> int: + return max(1, MAX_BIND_PARAMS // max(1, cols)) + + +def seed_from_paths_batch( + session: Session, + *, + specs: list[dict], + owner_id: str = "", +) -> dict: + """Each spec is a dict with keys: + - abs_path: str + - size_bytes: int + - mtime_ns: int + - info_name: str + - tags: list[str] + - fname: Optional[str] + """ + if not specs: + return {"inserted_infos": 0, "won_states": 0, "lost_states": 0} + + now = utcnow() + asset_rows: list[dict] = [] + state_rows: list[dict] = [] + path_to_asset: dict[str, str] = {} + asset_to_info: dict[str, dict] = {} # asset_id -> prepared info row + 
path_list: list[str] = [] + + for sp in specs: + ap = os.path.abspath(sp["abs_path"]) + aid = str(uuid.uuid4()) + iid = str(uuid.uuid4()) + path_list.append(ap) + path_to_asset[ap] = aid + + asset_rows.append( + { + "id": aid, + "hash": None, + "size_bytes": sp["size_bytes"], + "mime_type": None, + "created_at": now, + } + ) + state_rows.append( + { + "asset_id": aid, + "file_path": ap, + "mtime_ns": sp["mtime_ns"], + } + ) + asset_to_info[aid] = { + "id": iid, + "owner_id": owner_id, + "name": sp["info_name"], + "asset_id": aid, + "preview_id": None, + "user_metadata": {"filename": sp["fname"]} if sp["fname"] else None, + "created_at": now, + "updated_at": now, + "last_access_time": now, + "_tags": sp["tags"], + "_filename": sp["fname"], + } + + # insert all seed Assets (hash=NULL) + ins_asset = sqlite.insert(Asset) + for chunk in _iter_chunks(asset_rows, _rows_per_stmt(5)): + session.execute(ins_asset, chunk) + + # try to claim AssetCacheState (file_path) + winners_by_path: set[str] = set() + ins_state = ( + sqlite.insert(AssetCacheState) + .on_conflict_do_nothing(index_elements=[AssetCacheState.file_path]) + .returning(AssetCacheState.file_path) + ) + for chunk in _iter_chunks(state_rows, _rows_per_stmt(3)): + winners_by_path.update((session.execute(ins_state, chunk)).scalars().all()) + + all_paths_set = set(path_list) + losers_by_path = all_paths_set - winners_by_path + lost_assets = [path_to_asset[p] for p in losers_by_path] + if lost_assets: # losers get their Asset removed + for id_chunk in _iter_chunks(lost_assets, MAX_BIND_PARAMS): + session.execute(sqlalchemy.delete(Asset).where(Asset.id.in_(id_chunk))) + + if not winners_by_path: + return {"inserted_infos": 0, "won_states": 0, "lost_states": len(losers_by_path)} + + # insert AssetInfo only for winners + winner_info_rows = [asset_to_info[path_to_asset[p]] for p in winners_by_path] + ins_info = ( + sqlite.insert(AssetInfo) + .on_conflict_do_nothing(index_elements=[AssetInfo.asset_id, AssetInfo.owner_id, AssetInfo.name]) + .returning(AssetInfo.id) + ) + + inserted_info_ids: set[str] = set() + for chunk in _iter_chunks(winner_info_rows, _rows_per_stmt(9)): + inserted_info_ids.update((session.execute(ins_info, chunk)).scalars().all()) + + # build and insert tag + meta rows for the AssetInfo + tag_rows: list[dict] = [] + meta_rows: list[dict] = [] + if inserted_info_ids: + for row in winner_info_rows: + iid = row["id"] + if iid not in inserted_info_ids: + continue + for t in row["_tags"]: + tag_rows.append({ + "asset_info_id": iid, + "tag_name": t, + "origin": "automatic", + "added_at": now, + }) + if row["_filename"]: + meta_rows.append( + { + "asset_info_id": iid, + "key": "filename", + "ordinal": 0, + "val_str": row["_filename"], + "val_num": None, + "val_bool": None, + "val_json": None, + } + ) + + bulk_insert_tags_and_meta(session, tag_rows=tag_rows, meta_rows=meta_rows, max_bind_params=MAX_BIND_PARAMS) + return { + "inserted_infos": len(inserted_info_ids), + "won_states": len(winners_by_path), + "lost_states": len(losers_by_path), + } + + +def bulk_insert_tags_and_meta( + session: Session, + *, + tag_rows: list[dict], + meta_rows: list[dict], + max_bind_params: int, +) -> None: + """Batch insert into asset_info_tags and asset_info_meta with ON CONFLICT DO NOTHING. 
+ - tag_rows keys: asset_info_id, tag_name, origin, added_at + - meta_rows keys: asset_info_id, key, ordinal, val_str, val_num, val_bool, val_json + """ + if tag_rows: + ins_links = ( + sqlite.insert(AssetInfoTag) + .on_conflict_do_nothing(index_elements=[AssetInfoTag.asset_info_id, AssetInfoTag.tag_name]) + ) + for chunk in _chunk_rows(tag_rows, cols_per_row=4, max_bind_params=max_bind_params): + session.execute(ins_links, chunk) + if meta_rows: + ins_meta = ( + sqlite.insert(AssetInfoMeta) + .on_conflict_do_nothing( + index_elements=[AssetInfoMeta.asset_info_id, AssetInfoMeta.key, AssetInfoMeta.ordinal] + ) + ) + for chunk in _chunk_rows(meta_rows, cols_per_row=7, max_bind_params=max_bind_params): + session.execute(ins_meta, chunk) diff --git a/app/assets/database/models.py b/app/assets/database/models.py new file mode 100644 index 000000000..3cd28f68b --- /dev/null +++ b/app/assets/database/models.py @@ -0,0 +1,233 @@ +from __future__ import annotations + +import uuid +from datetime import datetime + +from typing import Any +from sqlalchemy import ( + JSON, + BigInteger, + Boolean, + CheckConstraint, + DateTime, + ForeignKey, + Index, + Integer, + Numeric, + String, + Text, + UniqueConstraint, +) +from sqlalchemy.orm import Mapped, foreign, mapped_column, relationship + +from app.assets.helpers import utcnow +from app.database.models import to_dict, Base + + +class Asset(Base): + __tablename__ = "assets" + + id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: str(uuid.uuid4())) + hash: Mapped[str | None] = mapped_column(String(256), nullable=True) + size_bytes: Mapped[int] = mapped_column(BigInteger, nullable=False, default=0) + mime_type: Mapped[str | None] = mapped_column(String(255)) + created_at: Mapped[datetime] = mapped_column( + DateTime(timezone=False), nullable=False, default=utcnow + ) + + infos: Mapped[list[AssetInfo]] = relationship( + "AssetInfo", + back_populates="asset", + primaryjoin=lambda: Asset.id == foreign(AssetInfo.asset_id), + foreign_keys=lambda: [AssetInfo.asset_id], + cascade="all,delete-orphan", + passive_deletes=True, + ) + + preview_of: Mapped[list[AssetInfo]] = relationship( + "AssetInfo", + back_populates="preview_asset", + primaryjoin=lambda: Asset.id == foreign(AssetInfo.preview_id), + foreign_keys=lambda: [AssetInfo.preview_id], + viewonly=True, + ) + + cache_states: Mapped[list[AssetCacheState]] = relationship( + back_populates="asset", + cascade="all, delete-orphan", + passive_deletes=True, + ) + + __table_args__ = ( + Index("uq_assets_hash", "hash", unique=True), + Index("ix_assets_mime_type", "mime_type"), + CheckConstraint("size_bytes >= 0", name="ck_assets_size_nonneg"), + ) + + def to_dict(self, include_none: bool = False) -> dict[str, Any]: + return to_dict(self, include_none=include_none) + + def __repr__(self) -> str: + return f"" + + +class AssetCacheState(Base): + __tablename__ = "asset_cache_state" + + id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) + asset_id: Mapped[str] = mapped_column(String(36), ForeignKey("assets.id", ondelete="CASCADE"), nullable=False) + file_path: Mapped[str] = mapped_column(Text, nullable=False) + mtime_ns: Mapped[int | None] = mapped_column(BigInteger, nullable=True) + needs_verify: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False) + + asset: Mapped[Asset] = relationship(back_populates="cache_states") + + __table_args__ = ( + Index("ix_asset_cache_state_file_path", "file_path"), + Index("ix_asset_cache_state_asset_id", "asset_id"), + 
CheckConstraint("(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_acs_mtime_nonneg"), + UniqueConstraint("file_path", name="uq_asset_cache_state_file_path"), + ) + + def to_dict(self, include_none: bool = False) -> dict[str, Any]: + return to_dict(self, include_none=include_none) + + def __repr__(self) -> str: + return f"" + + +class AssetInfo(Base): + __tablename__ = "assets_info" + + id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: str(uuid.uuid4())) + owner_id: Mapped[str] = mapped_column(String(128), nullable=False, default="") + name: Mapped[str] = mapped_column(String(512), nullable=False) + asset_id: Mapped[str] = mapped_column(String(36), ForeignKey("assets.id", ondelete="RESTRICT"), nullable=False) + preview_id: Mapped[str | None] = mapped_column(String(36), ForeignKey("assets.id", ondelete="SET NULL")) + user_metadata: Mapped[dict[str, Any] | None] = mapped_column(JSON(none_as_null=True)) + created_at: Mapped[datetime] = mapped_column(DateTime(timezone=False), nullable=False, default=utcnow) + updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=False), nullable=False, default=utcnow) + last_access_time: Mapped[datetime] = mapped_column(DateTime(timezone=False), nullable=False, default=utcnow) + + asset: Mapped[Asset] = relationship( + "Asset", + back_populates="infos", + foreign_keys=[asset_id], + lazy="selectin", + ) + preview_asset: Mapped[Asset | None] = relationship( + "Asset", + back_populates="preview_of", + foreign_keys=[preview_id], + ) + + metadata_entries: Mapped[list[AssetInfoMeta]] = relationship( + back_populates="asset_info", + cascade="all,delete-orphan", + passive_deletes=True, + ) + + tag_links: Mapped[list[AssetInfoTag]] = relationship( + back_populates="asset_info", + cascade="all,delete-orphan", + passive_deletes=True, + overlaps="tags,asset_infos", + ) + + tags: Mapped[list[Tag]] = relationship( + secondary="asset_info_tags", + back_populates="asset_infos", + lazy="selectin", + viewonly=True, + overlaps="tag_links,asset_info_links,asset_infos,tag", + ) + + __table_args__ = ( + UniqueConstraint("asset_id", "owner_id", "name", name="uq_assets_info_asset_owner_name"), + Index("ix_assets_info_owner_name", "owner_id", "name"), + Index("ix_assets_info_owner_id", "owner_id"), + Index("ix_assets_info_asset_id", "asset_id"), + Index("ix_assets_info_name", "name"), + Index("ix_assets_info_created_at", "created_at"), + Index("ix_assets_info_last_access_time", "last_access_time"), + ) + + def to_dict(self, include_none: bool = False) -> dict[str, Any]: + data = to_dict(self, include_none=include_none) + data["tags"] = [t.name for t in self.tags] + return data + + def __repr__(self) -> str: + return f"" + + +class AssetInfoMeta(Base): + __tablename__ = "asset_info_meta" + + asset_info_id: Mapped[str] = mapped_column( + String(36), ForeignKey("assets_info.id", ondelete="CASCADE"), primary_key=True + ) + key: Mapped[str] = mapped_column(String(256), primary_key=True) + ordinal: Mapped[int] = mapped_column(Integer, primary_key=True, default=0) + + val_str: Mapped[str | None] = mapped_column(String(2048), nullable=True) + val_num: Mapped[float | None] = mapped_column(Numeric(38, 10), nullable=True) + val_bool: Mapped[bool | None] = mapped_column(Boolean, nullable=True) + val_json: Mapped[Any | None] = mapped_column(JSON(none_as_null=True), nullable=True) + + asset_info: Mapped[AssetInfo] = relationship(back_populates="metadata_entries") + + __table_args__ = ( + Index("ix_asset_info_meta_key", "key"), + 
Index("ix_asset_info_meta_key_val_str", "key", "val_str"), + Index("ix_asset_info_meta_key_val_num", "key", "val_num"), + Index("ix_asset_info_meta_key_val_bool", "key", "val_bool"), + ) + + +class AssetInfoTag(Base): + __tablename__ = "asset_info_tags" + + asset_info_id: Mapped[str] = mapped_column( + String(36), ForeignKey("assets_info.id", ondelete="CASCADE"), primary_key=True + ) + tag_name: Mapped[str] = mapped_column( + String(512), ForeignKey("tags.name", ondelete="RESTRICT"), primary_key=True + ) + origin: Mapped[str] = mapped_column(String(32), nullable=False, default="manual") + added_at: Mapped[datetime] = mapped_column( + DateTime(timezone=False), nullable=False, default=utcnow + ) + + asset_info: Mapped[AssetInfo] = relationship(back_populates="tag_links") + tag: Mapped[Tag] = relationship(back_populates="asset_info_links") + + __table_args__ = ( + Index("ix_asset_info_tags_tag_name", "tag_name"), + Index("ix_asset_info_tags_asset_info_id", "asset_info_id"), + ) + + +class Tag(Base): + __tablename__ = "tags" + + name: Mapped[str] = mapped_column(String(512), primary_key=True) + tag_type: Mapped[str] = mapped_column(String(32), nullable=False, default="user") + + asset_info_links: Mapped[list[AssetInfoTag]] = relationship( + back_populates="tag", + overlaps="asset_infos,tags", + ) + asset_infos: Mapped[list[AssetInfo]] = relationship( + secondary="asset_info_tags", + back_populates="tags", + viewonly=True, + overlaps="asset_info_links,tag_links,tags,asset_info", + ) + + __table_args__ = ( + Index("ix_tags_tag_type", "tag_type"), + ) + + def __repr__(self) -> str: + return f"" diff --git a/app/assets/database/queries.py b/app/assets/database/queries.py new file mode 100644 index 000000000..0824c0c2f --- /dev/null +++ b/app/assets/database/queries.py @@ -0,0 +1,267 @@ +import sqlalchemy as sa +from collections import defaultdict +from sqlalchemy import select, exists, func +from sqlalchemy.orm import Session, contains_eager, noload +from app.assets.database.models import Asset, AssetInfo, AssetInfoMeta, AssetInfoTag, Tag +from app.assets.helpers import escape_like_prefix, normalize_tags +from typing import Sequence + + +def visible_owner_clause(owner_id: str) -> sa.sql.ClauseElement: + """Build owner visibility predicate for reads. 
Owner-less rows are visible to everyone.""" + owner_id = (owner_id or "").strip() + if owner_id == "": + return AssetInfo.owner_id == "" + return AssetInfo.owner_id.in_(["", owner_id]) + + +def apply_tag_filters( + stmt: sa.sql.Select, + include_tags: Sequence[str] | None = None, + exclude_tags: Sequence[str] | None = None, +) -> sa.sql.Select: + """include_tags: every tag must be present; exclude_tags: none may be present.""" + include_tags = normalize_tags(include_tags) + exclude_tags = normalize_tags(exclude_tags) + + if include_tags: + for tag_name in include_tags: + stmt = stmt.where( + exists().where( + (AssetInfoTag.asset_info_id == AssetInfo.id) + & (AssetInfoTag.tag_name == tag_name) + ) + ) + + if exclude_tags: + stmt = stmt.where( + ~exists().where( + (AssetInfoTag.asset_info_id == AssetInfo.id) + & (AssetInfoTag.tag_name.in_(exclude_tags)) + ) + ) + return stmt + +def apply_metadata_filter( + stmt: sa.sql.Select, + metadata_filter: dict | None = None, +) -> sa.sql.Select: + """Apply filters using asset_info_meta projection table.""" + if not metadata_filter: + return stmt + + def _exists_for_pred(key: str, *preds) -> sa.sql.ClauseElement: + return sa.exists().where( + AssetInfoMeta.asset_info_id == AssetInfo.id, + AssetInfoMeta.key == key, + *preds, + ) + + def _exists_clause_for_value(key: str, value) -> sa.sql.ClauseElement: + if value is None: + no_row_for_key = sa.not_( + sa.exists().where( + AssetInfoMeta.asset_info_id == AssetInfo.id, + AssetInfoMeta.key == key, + ) + ) + null_row = _exists_for_pred( + key, + AssetInfoMeta.val_json.is_(None), + AssetInfoMeta.val_str.is_(None), + AssetInfoMeta.val_num.is_(None), + AssetInfoMeta.val_bool.is_(None), + ) + return sa.or_(no_row_for_key, null_row) + + if isinstance(value, bool): + return _exists_for_pred(key, AssetInfoMeta.val_bool == bool(value)) + if isinstance(value, (int, float)): + from decimal import Decimal + num = value if isinstance(value, Decimal) else Decimal(str(value)) + return _exists_for_pred(key, AssetInfoMeta.val_num == num) + if isinstance(value, str): + return _exists_for_pred(key, AssetInfoMeta.val_str == value) + return _exists_for_pred(key, AssetInfoMeta.val_json == value) + + for k, v in metadata_filter.items(): + if isinstance(v, list): + ors = [_exists_clause_for_value(k, elem) for elem in v] + if ors: + stmt = stmt.where(sa.or_(*ors)) + else: + stmt = stmt.where(_exists_clause_for_value(k, v)) + return stmt + + +def asset_exists_by_hash(session: Session, asset_hash: str) -> bool: + """ + Check if an asset with a given hash exists in database. 
+ """ + row = ( + session.execute( + select(sa.literal(True)).select_from(Asset).where(Asset.hash == asset_hash).limit(1) + ) + ).first() + return row is not None + +def get_asset_info_by_id(session: Session, asset_info_id: str) -> AssetInfo | None: + return session.get(AssetInfo, asset_info_id) + +def list_asset_infos_page( + session: Session, + owner_id: str = "", + include_tags: Sequence[str] | None = None, + exclude_tags: Sequence[str] | None = None, + name_contains: str | None = None, + metadata_filter: dict | None = None, + limit: int = 20, + offset: int = 0, + sort: str = "created_at", + order: str = "desc", +) -> tuple[list[AssetInfo], dict[str, list[str]], int]: + base = ( + select(AssetInfo) + .join(Asset, Asset.id == AssetInfo.asset_id) + .options(contains_eager(AssetInfo.asset), noload(AssetInfo.tags)) + .where(visible_owner_clause(owner_id)) + ) + + if name_contains: + escaped, esc = escape_like_prefix(name_contains) + base = base.where(AssetInfo.name.ilike(f"%{escaped}%", escape=esc)) + + base = apply_tag_filters(base, include_tags, exclude_tags) + base = apply_metadata_filter(base, metadata_filter) + + sort = (sort or "created_at").lower() + order = (order or "desc").lower() + sort_map = { + "name": AssetInfo.name, + "created_at": AssetInfo.created_at, + "updated_at": AssetInfo.updated_at, + "last_access_time": AssetInfo.last_access_time, + "size": Asset.size_bytes, + } + sort_col = sort_map.get(sort, AssetInfo.created_at) + sort_exp = sort_col.desc() if order == "desc" else sort_col.asc() + + base = base.order_by(sort_exp).limit(limit).offset(offset) + + count_stmt = ( + select(sa.func.count()) + .select_from(AssetInfo) + .join(Asset, Asset.id == AssetInfo.asset_id) + .where(visible_owner_clause(owner_id)) + ) + if name_contains: + escaped, esc = escape_like_prefix(name_contains) + count_stmt = count_stmt.where(AssetInfo.name.ilike(f"%{escaped}%", escape=esc)) + count_stmt = apply_tag_filters(count_stmt, include_tags, exclude_tags) + count_stmt = apply_metadata_filter(count_stmt, metadata_filter) + + total = int((session.execute(count_stmt)).scalar_one() or 0) + + infos = (session.execute(base)).unique().scalars().all() + + id_list: list[str] = [i.id for i in infos] + tag_map: dict[str, list[str]] = defaultdict(list) + if id_list: + rows = session.execute( + select(AssetInfoTag.asset_info_id, Tag.name) + .join(Tag, Tag.name == AssetInfoTag.tag_name) + .where(AssetInfoTag.asset_info_id.in_(id_list)) + ) + for aid, tag_name in rows.all(): + tag_map[aid].append(tag_name) + + return infos, tag_map, total + +def fetch_asset_info_asset_and_tags( + session: Session, + asset_info_id: str, + owner_id: str = "", +) -> tuple[AssetInfo, Asset, list[str]] | None: + stmt = ( + select(AssetInfo, Asset, Tag.name) + .join(Asset, Asset.id == AssetInfo.asset_id) + .join(AssetInfoTag, AssetInfoTag.asset_info_id == AssetInfo.id, isouter=True) + .join(Tag, Tag.name == AssetInfoTag.tag_name, isouter=True) + .where( + AssetInfo.id == asset_info_id, + visible_owner_clause(owner_id), + ) + .options(noload(AssetInfo.tags)) + .order_by(Tag.name.asc()) + ) + + rows = (session.execute(stmt)).all() + if not rows: + return None + + first_info, first_asset, _ = rows[0] + tags: list[str] = [] + seen: set[str] = set() + for _info, _asset, tag_name in rows: + if tag_name and tag_name not in seen: + seen.add(tag_name) + tags.append(tag_name) + return first_info, first_asset, tags + +def list_tags_with_usage( + session: Session, + prefix: str | None = None, + limit: int = 100, + offset: int = 0, + include_zero: 
bool = True, + order: str = "count_desc", + owner_id: str = "", +) -> tuple[list[tuple[str, str, int]], int]: + counts_sq = ( + select( + AssetInfoTag.tag_name.label("tag_name"), + func.count(AssetInfoTag.asset_info_id).label("cnt"), + ) + .select_from(AssetInfoTag) + .join(AssetInfo, AssetInfo.id == AssetInfoTag.asset_info_id) + .where(visible_owner_clause(owner_id)) + .group_by(AssetInfoTag.tag_name) + .subquery() + ) + + q = ( + select( + Tag.name, + Tag.tag_type, + func.coalesce(counts_sq.c.cnt, 0).label("count"), + ) + .select_from(Tag) + .join(counts_sq, counts_sq.c.tag_name == Tag.name, isouter=True) + ) + + if prefix: + escaped, esc = escape_like_prefix(prefix.strip().lower()) + q = q.where(Tag.name.like(escaped + "%", escape=esc)) + + if not include_zero: + q = q.where(func.coalesce(counts_sq.c.cnt, 0) > 0) + + if order == "name_asc": + q = q.order_by(Tag.name.asc()) + else: + q = q.order_by(func.coalesce(counts_sq.c.cnt, 0).desc(), Tag.name.asc()) + + total_q = select(func.count()).select_from(Tag) + if prefix: + escaped, esc = escape_like_prefix(prefix.strip().lower()) + total_q = total_q.where(Tag.name.like(escaped + "%", escape=esc)) + if not include_zero: + total_q = total_q.where( + Tag.name.in_(select(AssetInfoTag.tag_name).group_by(AssetInfoTag.tag_name)) + ) + + rows = (session.execute(q.limit(limit).offset(offset))).all() + total = (session.execute(total_q)).scalar_one() + + rows_norm = [(name, ttype, int(count or 0)) for (name, ttype, count) in rows] + return rows_norm, int(total or 0) diff --git a/app/assets/database/tags.py b/app/assets/database/tags.py new file mode 100644 index 000000000..3ab6497c2 --- /dev/null +++ b/app/assets/database/tags.py @@ -0,0 +1,62 @@ +from typing import Iterable + +import sqlalchemy +from sqlalchemy.orm import Session +from sqlalchemy.dialects import sqlite + +from app.assets.helpers import normalize_tags, utcnow +from app.assets.database.models import Tag, AssetInfoTag, AssetInfo + + +def ensure_tags_exist(session: Session, names: Iterable[str], tag_type: str = "user") -> None: + wanted = normalize_tags(list(names)) + if not wanted: + return + rows = [{"name": n, "tag_type": tag_type} for n in list(dict.fromkeys(wanted))] + ins = ( + sqlite.insert(Tag) + .values(rows) + .on_conflict_do_nothing(index_elements=[Tag.name]) + ) + return session.execute(ins) + +def add_missing_tag_for_asset_id( + session: Session, + *, + asset_id: str, + origin: str = "automatic", +) -> None: + select_rows = ( + sqlalchemy.select( + AssetInfo.id.label("asset_info_id"), + sqlalchemy.literal("missing").label("tag_name"), + sqlalchemy.literal(origin).label("origin"), + sqlalchemy.literal(utcnow()).label("added_at"), + ) + .where(AssetInfo.asset_id == asset_id) + .where( + sqlalchemy.not_( + sqlalchemy.exists().where((AssetInfoTag.asset_info_id == AssetInfo.id) & (AssetInfoTag.tag_name == "missing")) + ) + ) + ) + session.execute( + sqlite.insert(AssetInfoTag) + .from_select( + ["asset_info_id", "tag_name", "origin", "added_at"], + select_rows, + ) + .on_conflict_do_nothing(index_elements=[AssetInfoTag.asset_info_id, AssetInfoTag.tag_name]) + ) + +def remove_missing_tag_for_asset_id( + session: Session, + *, + asset_id: str, +) -> None: + session.execute( + sqlalchemy.delete(AssetInfoTag).where( + AssetInfoTag.asset_info_id.in_(sqlalchemy.select(AssetInfo.id).where(AssetInfo.asset_id == asset_id)), + AssetInfoTag.tag_name == "missing", + ) + ) diff --git a/app/assets/hashing.py b/app/assets/hashing.py new file mode 100644 index 000000000..4b72084b9 --- /dev/null 
+++ b/app/assets/hashing.py @@ -0,0 +1,75 @@ +from blake3 import blake3 +from typing import IO +import os +import asyncio + + +DEFAULT_CHUNK = 8 * 1024 * 1024 # 8MB + +# NOTE: this allows hashing different representations of a file-like object +def blake3_hash( + fp: str | IO[bytes], + chunk_size: int = DEFAULT_CHUNK, +) -> str: + """ + Returns a BLAKE3 hex digest for ``fp``, which may be: + - a filename (str/bytes) or PathLike + - an open binary file object + If ``fp`` is a file object, it must be opened in **binary** mode and support + ``read``, ``seek``, and ``tell``. The function will seek to the start before + reading and will attempt to restore the original position afterward. + """ + # duck typing to check if input is a file-like object + if hasattr(fp, "read"): + return _hash_file_obj(fp, chunk_size) + + with open(os.fspath(fp), "rb") as f: + return _hash_file_obj(f, chunk_size) + + +async def blake3_hash_async( + fp: str | IO[bytes], + chunk_size: int = DEFAULT_CHUNK, +) -> str: + """Async wrapper for ``blake3_hash``. + Uses a worker thread so the event loop remains responsive. + """ + # File-like objects are hashed in a worker thread; paths are opened inside the worker to keep I/O off the loop. + if hasattr(fp, "read"): + return await asyncio.to_thread(blake3_hash, fp, chunk_size) + + def _worker() -> str: + with open(os.fspath(fp), "rb") as f: + return _hash_file_obj(f, chunk_size) + + return await asyncio.to_thread(_worker) + + +def _hash_file_obj(file_obj: IO, chunk_size: int = DEFAULT_CHUNK) -> str: + """ + Hash an already-open binary file object by streaming in chunks. + - Seeks to the beginning before reading (if supported). + - Restores the original position afterward (if tell/seek are supported). + """ + if chunk_size <= 0: + chunk_size = DEFAULT_CHUNK + + # in case file object is already open and not at the beginning, track so can be restored after hashing + orig_pos = file_obj.tell() + + try: + # seek to the beginning before reading + if orig_pos != 0: + file_obj.seek(0) + + h = blake3() + while True: + chunk = file_obj.read(chunk_size) + if not chunk: + break + h.update(chunk) + return h.hexdigest() + finally: + # restore original position in file object, if needed + if orig_pos != 0: + file_obj.seek(orig_pos) diff --git a/app/assets/helpers.py b/app/assets/helpers.py new file mode 100644 index 000000000..6755d0e56 --- /dev/null +++ b/app/assets/helpers.py @@ -0,0 +1,216 @@ +import contextlib +import os +from aiohttp import web +from datetime import datetime, timezone +from pathlib import Path +from typing import Literal, Any + +import folder_paths + + +RootType = Literal["models", "input", "output"] +ALLOWED_ROOTS: tuple[RootType, ...] = ("models", "input", "output") + +def get_query_dict(request: web.Request) -> dict[str, Any]: + """ + Gets a dictionary of query parameters from the request. + + 'request.query' is a MultiMapping[str], which needs to be converted to a dictionary to be validated by Pydantic.
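+    Parameters that appear once are returned as plain strings; parameters repeated in the query string are returned as lists of strings.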
+ """ + query_dict = { + key: request.query.getall(key) if len(request.query.getall(key)) > 1 else request.query.get(key) + for key in request.query.keys() + } + return query_dict + +def list_tree(base_dir: str) -> list[str]: + out: list[str] = [] + base_abs = os.path.abspath(base_dir) + if not os.path.isdir(base_abs): + return out + for dirpath, _subdirs, filenames in os.walk(base_abs, topdown=True, followlinks=False): + for name in filenames: + out.append(os.path.abspath(os.path.join(dirpath, name))) + return out + +def prefixes_for_root(root: RootType) -> list[str]: + if root == "models": + bases: list[str] = [] + for _bucket, paths in get_comfy_models_folders(): + bases.extend(paths) + return [os.path.abspath(p) for p in bases] + if root == "input": + return [os.path.abspath(folder_paths.get_input_directory())] + if root == "output": + return [os.path.abspath(folder_paths.get_output_directory())] + return [] + +def escape_like_prefix(s: str, escape: str = "!") -> tuple[str, str]: + """Escapes %, _ and the escape char itself in a LIKE prefix. + Returns (escaped_prefix, escape_char). Caller should append '%' and pass escape=escape_char to .like(). + """ + s = s.replace(escape, escape + escape) # escape the escape char first + s = s.replace("%", escape + "%").replace("_", escape + "_") # escape LIKE wildcards + return s, escape + +def fast_asset_file_check( + *, + mtime_db: int | None, + size_db: int | None, + stat_result: os.stat_result, +) -> bool: + if mtime_db is None: + return False + actual_mtime_ns = getattr(stat_result, "st_mtime_ns", int(stat_result.st_mtime * 1_000_000_000)) + if int(mtime_db) != int(actual_mtime_ns): + return False + sz = int(size_db or 0) + if sz > 0: + return int(stat_result.st_size) == sz + return True + +def utcnow() -> datetime: + """Naive UTC timestamp (no tzinfo). We always treat DB datetimes as UTC.""" + return datetime.now(timezone.utc).replace(tzinfo=None) + +def get_comfy_models_folders() -> list[tuple[str, list[str]]]: + """Build a list of (folder_name, base_paths[]) categories that are configured for model locations. + + We trust `folder_paths.folder_names_and_paths` and include a category if + *any* of its base paths lies under the Comfy `models_dir`. + """ + targets: list[tuple[str, list[str]]] = [] + models_root = os.path.abspath(folder_paths.models_dir) + for name, (paths, _exts) in folder_paths.folder_names_and_paths.items(): + if any(os.path.abspath(p).startswith(models_root + os.sep) for p in paths): + targets.append((name, paths)) + return targets + +def compute_relative_filename(file_path: str) -> str | None: + """ + Return the model's path relative to the last well-known folder (the model category), + using forward slashes, eg: + /.../models/checkpoints/flux/123/flux.safetensors -> "flux/123/flux.safetensors" + /.../models/text_encoders/clip_g.safetensors -> "clip_g.safetensors" + + For non-model paths, returns None. + NOTE: this is a temporary helper, used only for initializing metadata["filename"] field. 
+ """ + try: + root_category, rel_path = get_relative_to_root_category_path_of_asset(file_path) + except ValueError: + return None + + p = Path(rel_path) + parts = [seg for seg in p.parts if seg not in (".", "..", p.anchor)] + if not parts: + return None + + if root_category == "models": + # parts[0] is the category ("checkpoints", "vae", etc) – drop it + inside = parts[1:] if len(parts) > 1 else [parts[0]] + return "/".join(inside) + return "/".join(parts) # input/output: keep all parts + + +def get_relative_to_root_category_path_of_asset(file_path: str) -> tuple[Literal["input", "output", "models"], str]: + """Given an absolute or relative file path, determine which root category the path belongs to: + - 'input' if the file resides under `folder_paths.get_input_directory()` + - 'output' if the file resides under `folder_paths.get_output_directory()` + - 'models' if the file resides under any base path of categories returned by `get_comfy_models_folders()` + + Returns: + (root_category, relative_path_inside_that_root) + For 'models', the relative path is prefixed with the category name: + e.g. ('models', 'vae/test/sub/ae.safetensors') + + Raises: + ValueError: if the path does not belong to input, output, or configured model bases. + """ + fp_abs = os.path.abspath(file_path) + + def _is_within(child: str, parent: str) -> bool: + try: + return os.path.commonpath([child, parent]) == parent + except Exception: + return False + + def _rel(child: str, parent: str) -> str: + return os.path.relpath(os.path.join(os.sep, os.path.relpath(child, parent)), os.sep) + + # 1) input + input_base = os.path.abspath(folder_paths.get_input_directory()) + if _is_within(fp_abs, input_base): + return "input", _rel(fp_abs, input_base) + + # 2) output + output_base = os.path.abspath(folder_paths.get_output_directory()) + if _is_within(fp_abs, output_base): + return "output", _rel(fp_abs, output_base) + + # 3) models (check deepest matching base to avoid ambiguity) + best: tuple[int, str, str] | None = None # (base_len, bucket, rel_inside_bucket) + for bucket, bases in get_comfy_models_folders(): + for b in bases: + base_abs = os.path.abspath(b) + if not _is_within(fp_abs, base_abs): + continue + cand = (len(base_abs), bucket, _rel(fp_abs, base_abs)) + if best is None or cand[0] > best[0]: + best = cand + + if best is not None: + _, bucket, rel_inside = best + combined = os.path.join(bucket, rel_inside) + return "models", os.path.relpath(os.path.join(os.sep, combined), os.sep) + + raise ValueError(f"Path is not within input, output, or configured model bases: {file_path}") + +def get_name_and_tags_from_asset_path(file_path: str) -> tuple[str, list[str]]: + """Return a tuple (name, tags) derived from a filesystem path. + + Semantics: + - Root category is determined by `get_relative_to_root_category_path_of_asset`. + - The returned `name` is the base filename with extension from the relative path. + - The returned `tags` are: + [root_category] + parent folders of the relative path (in order) + For 'models', this means: + file '/.../ModelsDir/vae/test_tag/ae.safetensors' + -> root_category='models', some_path='vae/test_tag/ae.safetensors' + -> name='ae.safetensors', tags=['models', 'vae', 'test_tag'] + + Raises: + ValueError: if the path does not belong to input, output, or configured model bases. 
+ """ + root_category, some_path = get_relative_to_root_category_path_of_asset(file_path) + p = Path(some_path) + parent_parts = [part for part in p.parent.parts if part not in (".", "..", p.anchor)] + return p.name, list(dict.fromkeys(normalize_tags([root_category, *parent_parts]))) + +def normalize_tags(tags: list[str] | None) -> list[str]: + """ + Normalize a list of tags by: + - Stripping whitespace and converting to lowercase. + - Removing duplicates. + """ + return [t.strip().lower() for t in (tags or []) if (t or "").strip()] + +def collect_models_files() -> list[str]: + out: list[str] = [] + for folder_name, bases in get_comfy_models_folders(): + rel_files = folder_paths.get_filename_list(folder_name) or [] + for rel_path in rel_files: + abs_path = folder_paths.get_full_path(folder_name, rel_path) + if not abs_path: + continue + abs_path = os.path.abspath(abs_path) + allowed = False + for b in bases: + base_abs = os.path.abspath(b) + with contextlib.suppress(Exception): + if os.path.commonpath([abs_path, base_abs]) == base_abs: + allowed = True + break + if allowed: + out.append(abs_path) + return out diff --git a/app/assets/manager.py b/app/assets/manager.py new file mode 100644 index 000000000..6425e7aa2 --- /dev/null +++ b/app/assets/manager.py @@ -0,0 +1,123 @@ +from typing import Sequence + +from app.database.db import create_session +from app.assets.api import schemas_out +from app.assets.database.queries import ( + asset_exists_by_hash, + fetch_asset_info_asset_and_tags, + list_asset_infos_page, + list_tags_with_usage, +) + + +def _safe_sort_field(requested: str | None) -> str: + if not requested: + return "created_at" + v = requested.lower() + if v in {"name", "created_at", "updated_at", "size", "last_access_time"}: + return v + return "created_at" + + +def asset_exists(asset_hash: str) -> bool: + with create_session() as session: + return asset_exists_by_hash(session, asset_hash=asset_hash) + +def list_assets( + include_tags: Sequence[str] | None = None, + exclude_tags: Sequence[str] | None = None, + name_contains: str | None = None, + metadata_filter: dict | None = None, + limit: int = 20, + offset: int = 0, + sort: str = "created_at", + order: str = "desc", + owner_id: str = "", +) -> schemas_out.AssetsList: + sort = _safe_sort_field(sort) + order = "desc" if (order or "desc").lower() not in {"asc", "desc"} else order.lower() + + with create_session() as session: + infos, tag_map, total = list_asset_infos_page( + session, + owner_id=owner_id, + include_tags=include_tags, + exclude_tags=exclude_tags, + name_contains=name_contains, + metadata_filter=metadata_filter, + limit=limit, + offset=offset, + sort=sort, + order=order, + ) + + summaries: list[schemas_out.AssetSummary] = [] + for info in infos: + asset = info.asset + tags = tag_map.get(info.id, []) + summaries.append( + schemas_out.AssetSummary( + id=info.id, + name=info.name, + asset_hash=asset.hash if asset else None, + size=int(asset.size_bytes) if asset else None, + mime_type=asset.mime_type if asset else None, + tags=tags, + preview_url=f"/api/assets/{info.id}/content", + created_at=info.created_at, + updated_at=info.updated_at, + last_access_time=info.last_access_time, + ) + ) + + return schemas_out.AssetsList( + assets=summaries, + total=total, + has_more=(offset + len(summaries)) < total, + ) + +def get_asset(asset_info_id: str, owner_id: str = "") -> schemas_out.AssetDetail: + with create_session() as session: + res = fetch_asset_info_asset_and_tags(session, asset_info_id=asset_info_id, owner_id=owner_id) + 
if not res: + raise ValueError(f"AssetInfo {asset_info_id} not found") + info, asset, tag_names = res + preview_id = info.preview_id + + return schemas_out.AssetDetail( + id=info.id, + name=info.name, + asset_hash=asset.hash if asset else None, + size=int(asset.size_bytes) if asset and asset.size_bytes is not None else None, + mime_type=asset.mime_type if asset else None, + tags=tag_names, + user_metadata=info.user_metadata or {}, + preview_id=preview_id, + created_at=info.created_at, + last_access_time=info.last_access_time, + ) + +def list_tags( + prefix: str | None = None, + limit: int = 100, + offset: int = 0, + order: str = "count_desc", + include_zero: bool = True, + owner_id: str = "", +) -> schemas_out.TagsList: + limit = max(1, min(1000, limit)) + offset = max(0, offset) + + with create_session() as session: + rows, total = list_tags_with_usage( + session, + prefix=prefix, + limit=limit, + offset=offset, + include_zero=include_zero, + order=order, + owner_id=owner_id, + ) + + tags = [schemas_out.TagUsage(name=name, count=count, type=tag_type) for (name, tag_type, count) in rows] + return schemas_out.TagsList(tags=tags, total=total, has_more=(offset + len(tags)) < total) diff --git a/app/assets/scanner.py b/app/assets/scanner.py new file mode 100644 index 000000000..a16e41d94 --- /dev/null +++ b/app/assets/scanner.py @@ -0,0 +1,229 @@ +import contextlib +import time +import logging +import os +import sqlalchemy + +import folder_paths +from app.database.db import create_session, dependencies_available +from app.assets.helpers import ( + collect_models_files, compute_relative_filename, fast_asset_file_check, get_name_and_tags_from_asset_path, + list_tree,prefixes_for_root, escape_like_prefix, + RootType +) +from app.assets.database.tags import add_missing_tag_for_asset_id, ensure_tags_exist, remove_missing_tag_for_asset_id +from app.assets.database.bulk_ops import seed_from_paths_batch +from app.assets.database.models import Asset, AssetCacheState, AssetInfo + + +def seed_assets(roots: tuple[RootType, ...], enable_logging: bool = False) -> None: + """ + Scan the given roots and seed the assets into the database. 
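+    Paths that already have an up-to-date entry (found by the fast DB consistency pass) are skipped, empty files are ignored, and the remaining files are inserted in one batch via seed_from_paths_batch.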
+ """ + if not dependencies_available(): + if enable_logging: + logging.warning("Database dependencies not available, skipping assets scan") + return + t_start = time.perf_counter() + created = 0 + skipped_existing = 0 + paths: list[str] = [] + try: + existing_paths: set[str] = set() + for r in roots: + try: + survivors: set[str] = _fast_db_consistency_pass(r, collect_existing_paths=True, update_missing_tags=True) + if survivors: + existing_paths.update(survivors) + except Exception as e: + logging.exception("fast DB scan failed for %s: %s", r, e) + + if "models" in roots: + paths.extend(collect_models_files()) + if "input" in roots: + paths.extend(list_tree(folder_paths.get_input_directory())) + if "output" in roots: + paths.extend(list_tree(folder_paths.get_output_directory())) + + specs: list[dict] = [] + tag_pool: set[str] = set() + for p in paths: + abs_p = os.path.abspath(p) + if abs_p in existing_paths: + skipped_existing += 1 + continue + try: + stat_p = os.stat(abs_p, follow_symlinks=False) + except OSError: + continue + # skip empty files + if not stat_p.st_size: + continue + name, tags = get_name_and_tags_from_asset_path(abs_p) + specs.append( + { + "abs_path": abs_p, + "size_bytes": stat_p.st_size, + "mtime_ns": getattr(stat_p, "st_mtime_ns", int(stat_p.st_mtime * 1_000_000_000)), + "info_name": name, + "tags": tags, + "fname": compute_relative_filename(abs_p), + } + ) + for t in tags: + tag_pool.add(t) + # if no file specs, nothing to do + if not specs: + return + with create_session() as sess: + if tag_pool: + ensure_tags_exist(sess, tag_pool, tag_type="user") + + result = seed_from_paths_batch(sess, specs=specs, owner_id="") + created += result["inserted_infos"] + sess.commit() + finally: + if enable_logging: + logging.info( + "Assets scan(roots=%s) completed in %.3fs (created=%d, skipped_existing=%d, total_seen=%d)", + roots, + time.perf_counter() - t_start, + created, + skipped_existing, + len(paths), + ) + + +def _fast_db_consistency_pass( + root: RootType, + *, + collect_existing_paths: bool = False, + update_missing_tags: bool = False, +) -> set[str] | None: + """Fast DB+FS pass for a root: + - Toggle needs_verify per state using fast check + - For hashed assets with at least one fast-ok state in this root: delete stale missing states + - For seed assets with all states missing: delete Asset and its AssetInfos + - Optionally add/remove 'missing' tags based on fast-ok in this root + - Optionally return surviving absolute paths + """ + prefixes = prefixes_for_root(root) + if not prefixes: + return set() if collect_existing_paths else None + + conds = [] + for p in prefixes: + base = os.path.abspath(p) + if not base.endswith(os.sep): + base += os.sep + escaped, esc = escape_like_prefix(base) + conds.append(AssetCacheState.file_path.like(escaped + "%", escape=esc)) + + with create_session() as sess: + rows = ( + sess.execute( + sqlalchemy.select( + AssetCacheState.id, + AssetCacheState.file_path, + AssetCacheState.mtime_ns, + AssetCacheState.needs_verify, + AssetCacheState.asset_id, + Asset.hash, + Asset.size_bytes, + ) + .join(Asset, Asset.id == AssetCacheState.asset_id) + .where(sqlalchemy.or_(*conds)) + .order_by(AssetCacheState.asset_id.asc(), AssetCacheState.id.asc()) + ) + ).all() + + by_asset: dict[str, dict] = {} + for sid, fp, mtime_db, needs_verify, aid, a_hash, a_size in rows: + acc = by_asset.get(aid) + if acc is None: + acc = {"hash": a_hash, "size_db": int(a_size or 0), "states": []} + by_asset[aid] = acc + + fast_ok = False + try: + exists = True + fast_ok = 
fast_asset_file_check( + mtime_db=mtime_db, + size_db=acc["size_db"], + stat_result=os.stat(fp, follow_symlinks=True), + ) + except FileNotFoundError: + exists = False + except OSError: + exists = False + + acc["states"].append({ + "sid": sid, + "fp": fp, + "exists": exists, + "fast_ok": fast_ok, + "needs_verify": bool(needs_verify), + }) + + to_set_verify: list[int] = [] + to_clear_verify: list[int] = [] + stale_state_ids: list[int] = [] + survivors: set[str] = set() + + for aid, acc in by_asset.items(): + a_hash = acc["hash"] + states = acc["states"] + any_fast_ok = any(s["fast_ok"] for s in states) + all_missing = all(not s["exists"] for s in states) + + for s in states: + if not s["exists"]: + continue + if s["fast_ok"] and s["needs_verify"]: + to_clear_verify.append(s["sid"]) + if not s["fast_ok"] and not s["needs_verify"]: + to_set_verify.append(s["sid"]) + + if a_hash is None: + if states and all_missing: # remove seed Asset completely, if no valid AssetCache exists + sess.execute(sqlalchemy.delete(AssetInfo).where(AssetInfo.asset_id == aid)) + asset = sess.get(Asset, aid) + if asset: + sess.delete(asset) + else: + for s in states: + if s["exists"]: + survivors.add(os.path.abspath(s["fp"])) + continue + + if any_fast_ok: # if Asset has at least one valid AssetCache record, remove any invalid AssetCache records + for s in states: + if not s["exists"]: + stale_state_ids.append(s["sid"]) + if update_missing_tags: + with contextlib.suppress(Exception): + remove_missing_tag_for_asset_id(sess, asset_id=aid) + elif update_missing_tags: + with contextlib.suppress(Exception): + add_missing_tag_for_asset_id(sess, asset_id=aid, origin="automatic") + + for s in states: + if s["exists"]: + survivors.add(os.path.abspath(s["fp"])) + + if stale_state_ids: + sess.execute(sqlalchemy.delete(AssetCacheState).where(AssetCacheState.id.in_(stale_state_ids))) + if to_set_verify: + sess.execute( + sqlalchemy.update(AssetCacheState) + .where(AssetCacheState.id.in_(to_set_verify)) + .values(needs_verify=True) + ) + if to_clear_verify: + sess.execute( + sqlalchemy.update(AssetCacheState) + .where(AssetCacheState.id.in_(to_clear_verify)) + .values(needs_verify=False) + ) + sess.commit() + return survivors if collect_existing_paths else None diff --git a/app/database/models.py b/app/database/models.py index 6facfb8f2..e7572677a 100644 --- a/app/database/models.py +++ b/app/database/models.py @@ -1,14 +1,21 @@ -from sqlalchemy.orm import declarative_base +from typing import Any +from datetime import datetime +from sqlalchemy.orm import DeclarativeBase -Base = declarative_base() +class Base(DeclarativeBase): + pass - -def to_dict(obj): +def to_dict(obj: Any, include_none: bool = False) -> dict[str, Any]: fields = obj.__table__.columns.keys() - return { - field: (val.to_dict() if hasattr(val, "to_dict") else val) - for field in fields - if (val := getattr(obj, field)) - } + out: dict[str, Any] = {} + for field in fields: + val = getattr(obj, field) + if val is None and not include_none: + continue + if isinstance(val, datetime): + out[field] = val.isoformat() + else: + out[field] = val + return out # TODO: Define models here diff --git a/comfy/cli_args.py b/comfy/cli_args.py index dae9a895d..1716c3de7 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -231,6 +231,7 @@ database_default_path = os.path.abspath( os.path.join(os.path.dirname(__file__), "..", "user", "comfyui.db") ) parser.add_argument("--database-url", type=str, default=f"sqlite:///{database_default_path}", help="Specify the database URL, e.g. 
for an in-memory database you can use 'sqlite:///:memory:'.") +parser.add_argument("--disable-assets-autoscan", action="store_true", help="Disable asset scanning on startup for database synchronization.") if comfy.options.args_parsing: args = parser.parse_args() diff --git a/main.py b/main.py index 0e07a95da..37b06c1fa 100644 --- a/main.py +++ b/main.py @@ -7,6 +7,7 @@ import folder_paths import time from comfy.cli_args import args from app.logger import setup_logger +from app.assets.scanner import seed_assets import itertools import utils.extra_config import logging @@ -324,6 +325,8 @@ def setup_database(): from app.database.db import init_db, dependencies_available if dependencies_available(): init_db() + if not args.disable_assets_autoscan: + seed_assets(["models"], enable_logging=True) except Exception as e: logging.error(f"Failed to initialize database. Please ensure you have installed the latest requirements. If the error persists, please report this as in future the database will be required: {e}") diff --git a/server.py b/server.py index 4db3347cb..da2baefd4 100644 --- a/server.py +++ b/server.py @@ -33,6 +33,8 @@ import node_helpers from comfyui_version import __version__ from app.frontend_management import FrontendManager, parse_version from comfy_api.internal import _ComfyNodeInternal +from app.assets.scanner import seed_assets +from app.assets.api.routes import register_assets_system from app.user_manager import UserManager from app.model_manager import ModelFileManager @@ -235,6 +237,7 @@ class PromptServer(): else args.front_end_root ) logging.info(f"[Prompt Server] web root: {self.web_root}") + register_assets_system(self.app, self.user_manager) routes = web.RouteTableDef() self.routes = routes self.last_node_id = None @@ -683,6 +686,7 @@ class PromptServer(): @routes.get("/object_info") async def get_object_info(request): + seed_assets(["models"]) with folder_paths.cache_helper: out = {} for x in nodes.NODE_CLASS_MAPPINGS: From 6207f86c18d2cf2d70ab059987b62d4b38466e77 Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Thu, 8 Jan 2026 20:34:48 -0800 Subject: [PATCH 097/181] Fix VAEEncodeForInpaint to support WAN VAE tuple downscale_ratio (#11572) Use vae.spacial_compression_encode() instead of directly accessing downscale_ratio to handle both standard VAEs (int) and WAN VAEs (tuple). Addresses reviewer feedback on PR #11259. 
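For illustration, a minimal sketch of the failure mode this avoids (the tuple value below is made up, not taken from the WAN code):

    # standard VAE:   vae.downscale_ratio == 8          -> pixels.shape[1] // 8 works
    # WAN-style VAE:  vae.downscale_ratio == (4, 8, 8)  -> floor-division by a tuple raises TypeError
    # fix:            downscale_ratio = vae.spacial_compression_encode()  # single spatial factor in both cases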
Co-authored-by: ChrisFab16 --- nodes.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/nodes.py b/nodes.py index 56b74ebe3..1aa391f4a 100644 --- a/nodes.py +++ b/nodes.py @@ -378,14 +378,15 @@ class VAEEncodeForInpaint: CATEGORY = "latent/inpaint" def encode(self, vae, pixels, mask, grow_mask_by=6): - x = (pixels.shape[1] // vae.downscale_ratio) * vae.downscale_ratio - y = (pixels.shape[2] // vae.downscale_ratio) * vae.downscale_ratio + downscale_ratio = vae.spacial_compression_encode() + x = (pixels.shape[1] // downscale_ratio) * downscale_ratio + y = (pixels.shape[2] // downscale_ratio) * downscale_ratio mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(pixels.shape[1], pixels.shape[2]), mode="bilinear") pixels = pixels.clone() if pixels.shape[1] != x or pixels.shape[2] != y: - x_offset = (pixels.shape[1] % vae.downscale_ratio) // 2 - y_offset = (pixels.shape[2] % vae.downscale_ratio) // 2 + x_offset = (pixels.shape[1] % downscale_ratio) // 2 + y_offset = (pixels.shape[2] % downscale_ratio) // 2 pixels = pixels[:,x_offset:x + x_offset, y_offset:y + y_offset,:] mask = mask[:,:,x_offset:x + x_offset, y_offset:y + y_offset] From 4609fcd26081156eef921bd9f43726f670ee6f51 Mon Sep 17 00:00:00 2001 From: Terry Jia Date: Fri, 9 Jan 2026 00:31:19 -0500 Subject: [PATCH 098/181] add node - image compare (#11343) --- comfy_api/latest/_io.py | 13 +++++++ comfy_extras/nodes_image_compare.py | 53 +++++++++++++++++++++++++++++ nodes.py | 1 + 3 files changed, 67 insertions(+) create mode 100644 comfy_extras/nodes_image_compare.py diff --git a/comfy_api/latest/_io.py b/comfy_api/latest/_io.py index 764fa8b2b..50143ff53 100644 --- a/comfy_api/latest/_io.py +++ b/comfy_api/latest/_io.py @@ -1113,6 +1113,18 @@ class DynamicSlot(ComfyTypeI): out_dict[input_type][finalized_id] = value out_dict["dynamic_paths"][finalized_id] = finalize_prefix(curr_prefix, curr_prefix[-1]) +@comfytype(io_type="IMAGECOMPARE") +class ImageCompare(ComfyTypeI): + Type = dict + + class Input(WidgetInput): + def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, + socketless: bool=True): + super().__init__(id, display_name, optional, tooltip, None, None, socketless) + + def as_dict(self): + return super().as_dict() + DYNAMIC_INPUT_LOOKUP: dict[str, Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]] = {} def register_dynamic_input_func(io_type: str, func: Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]): DYNAMIC_INPUT_LOOKUP[io_type] = func @@ -1958,4 +1970,5 @@ __all__ = [ "add_to_dict_v1", "add_to_dict_v3", "V3Data", + "ImageCompare", ] diff --git a/comfy_extras/nodes_image_compare.py b/comfy_extras/nodes_image_compare.py new file mode 100644 index 000000000..8e9f809e6 --- /dev/null +++ b/comfy_extras/nodes_image_compare.py @@ -0,0 +1,53 @@ +import nodes + +from typing_extensions import override +from comfy_api.latest import IO, ComfyExtension + + +class ImageCompare(IO.ComfyNode): + """Compares two images with a slider interface.""" + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="ImageCompare", + display_name="Image Compare", + description="Compares two images side by side with a slider.", + category="image", + is_experimental=True, + is_output_node=True, + inputs=[ + IO.Image.Input("image_a", optional=True), + IO.Image.Input("image_b", optional=True), + IO.ImageCompare.Input("compare_view"), + ], + 
outputs=[], + ) + + @classmethod + def execute(cls, image_a=None, image_b=None, compare_view=None) -> IO.NodeOutput: + result = {"a_images": [], "b_images": []} + + preview_node = nodes.PreviewImage() + + if image_a is not None and len(image_a) > 0: + saved = preview_node.save_images(image_a, "comfy.compare.a") + result["a_images"] = saved["ui"]["images"] + + if image_b is not None and len(image_b) > 0: + saved = preview_node.save_images(image_b, "comfy.compare.b") + result["b_images"] = saved["ui"]["images"] + + return IO.NodeOutput(ui=result) + + +class ImageCompareExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[IO.ComfyNode]]: + return [ + ImageCompare, + ] + + +async def comfy_entrypoint() -> ImageCompareExtension: + return ImageCompareExtension() diff --git a/nodes.py b/nodes.py index 1aa391f4a..5a9d42d4a 100644 --- a/nodes.py +++ b/nodes.py @@ -2370,6 +2370,7 @@ async def init_builtin_extra_nodes(): "nodes_nop.py", "nodes_kandinsky5.py", "nodes_wanmove.py", + "nodes_image_compare.py", ] import_failed = [] From 04c49a29b493f3f9037b83cec45f6369b5c4816b Mon Sep 17 00:00:00 2001 From: ric-yu Date: Thu, 8 Jan 2026 21:57:36 -0800 Subject: [PATCH 099/181] feat: add cancelled filter to /jobs (#11680) --- comfy_execution/jobs.py | 31 +++++++++++++++++------------ tests/execution/test_jobs.py | 38 +++++++++++++++++++++++++++++++++++- 2 files changed, 55 insertions(+), 14 deletions(-) diff --git a/comfy_execution/jobs.py b/comfy_execution/jobs.py index 59fb49357..97fd803b8 100644 --- a/comfy_execution/jobs.py +++ b/comfy_execution/jobs.py @@ -14,8 +14,9 @@ class JobStatus: IN_PROGRESS = 'in_progress' COMPLETED = 'completed' FAILED = 'failed' + CANCELLED = 'cancelled' - ALL = [PENDING, IN_PROGRESS, COMPLETED, FAILED] + ALL = [PENDING, IN_PROGRESS, COMPLETED, FAILED, CANCELLED] # Media types that can be previewed in the frontend @@ -94,12 +95,6 @@ def normalize_history_item(prompt_id: str, history_item: dict, include_outputs: status_info = history_item.get('status', {}) status_str = status_info.get('status_str') if status_info else None - if status_str == 'success': - status = JobStatus.COMPLETED - elif status_str == 'error': - status = JobStatus.FAILED - else: - status = JobStatus.COMPLETED outputs = history_item.get('outputs', {}) outputs_count, preview_output = get_outputs_summary(outputs) @@ -107,6 +102,7 @@ def normalize_history_item(prompt_id: str, history_item: dict, include_outputs: execution_error = None execution_start_time = None execution_end_time = None + was_interrupted = False if status_info: messages = status_info.get('messages', []) for entry in messages: @@ -119,6 +115,15 @@ def normalize_history_item(prompt_id: str, history_item: dict, include_outputs: execution_end_time = event_data.get('timestamp') if event_name == 'execution_error': execution_error = event_data + elif event_name == 'execution_interrupted': + was_interrupted = True + + if status_str == 'success': + status = JobStatus.COMPLETED + elif status_str == 'error': + status = JobStatus.CANCELLED if was_interrupted else JobStatus.FAILED + else: + status = JobStatus.COMPLETED job = prune_dict({ 'id': prompt_id, @@ -268,13 +273,13 @@ def get_all_jobs( for item in queued: jobs.append(normalize_queue_item(item, JobStatus.PENDING)) - include_completed = JobStatus.COMPLETED in status_filter - include_failed = JobStatus.FAILED in status_filter - if include_completed or include_failed: + history_statuses = {JobStatus.COMPLETED, JobStatus.FAILED, JobStatus.CANCELLED} + 
requested_history_statuses = history_statuses & set(status_filter) + if requested_history_statuses: for prompt_id, history_item in history.items(): - is_failed = history_item.get('status', {}).get('status_str') == 'error' - if (is_failed and include_failed) or (not is_failed and include_completed): - jobs.append(normalize_history_item(prompt_id, history_item)) + job = normalize_history_item(prompt_id, history_item) + if job.get('status') in requested_history_statuses: + jobs.append(job) if workflow_id: jobs = [j for j in jobs if j.get('workflow_id') == workflow_id] diff --git a/tests/execution/test_jobs.py b/tests/execution/test_jobs.py index 918c8080a..4d2f9ed36 100644 --- a/tests/execution/test_jobs.py +++ b/tests/execution/test_jobs.py @@ -19,6 +19,7 @@ class TestJobStatus: assert JobStatus.IN_PROGRESS == 'in_progress' assert JobStatus.COMPLETED == 'completed' assert JobStatus.FAILED == 'failed' + assert JobStatus.CANCELLED == 'cancelled' def test_all_contains_all_statuses(self): """ALL should contain all status values.""" @@ -26,7 +27,8 @@ class TestJobStatus: assert JobStatus.IN_PROGRESS in JobStatus.ALL assert JobStatus.COMPLETED in JobStatus.ALL assert JobStatus.FAILED in JobStatus.ALL - assert len(JobStatus.ALL) == 4 + assert JobStatus.CANCELLED in JobStatus.ALL + assert len(JobStatus.ALL) == 5 class TestIsPreviewable: @@ -336,6 +338,40 @@ class TestNormalizeHistoryItem: assert job['execution_error']['node_type'] == 'KSampler' assert job['execution_error']['exception_message'] == 'CUDA out of memory' + def test_cancelled_job(self): + """Cancelled/interrupted history item should have cancelled status.""" + history_item = { + 'prompt': ( + 5, + 'prompt-cancelled', + {'nodes': {}}, + {'create_time': 1234567890000}, + ['node1'], + ), + 'status': { + 'status_str': 'error', + 'completed': False, + 'messages': [ + ('execution_start', {'prompt_id': 'prompt-cancelled', 'timestamp': 1234567890500}), + ('execution_interrupted', { + 'prompt_id': 'prompt-cancelled', + 'node_id': '5', + 'node_type': 'KSampler', + 'executed': ['1', '2', '3'], + 'timestamp': 1234567891000, + }) + ] + }, + 'outputs': {}, + } + + job = normalize_history_item('prompt-cancelled', history_item) + assert job['status'] == 'cancelled' + assert job['execution_start_time'] == 1234567890500 + assert job['execution_end_time'] == 1234567891000 + # Cancelled jobs should not have execution_error set + assert 'execution_error' not in job + def test_include_outputs(self): """When include_outputs=True, should include full output data.""" history_item = { From ec0a832acb25fbe53bd4fc25d286a9ee442a3bcf Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Thu, 8 Jan 2026 22:49:12 -0800 Subject: [PATCH 100/181] Add workaround for hacky nodepack(s) that edit folder_names_and_paths to have values with tuples of more than 2. Other things could potentially break with those nodepack(s), so I will hunt for the guilty nodepack(s) now. 
(#11755) --- app/assets/helpers.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/app/assets/helpers.py b/app/assets/helpers.py index 6755d0e56..08b465b5a 100644 --- a/app/assets/helpers.py +++ b/app/assets/helpers.py @@ -81,7 +81,8 @@ def get_comfy_models_folders() -> list[tuple[str, list[str]]]: """ targets: list[tuple[str, list[str]]] = [] models_root = os.path.abspath(folder_paths.models_dir) - for name, (paths, _exts) in folder_paths.folder_names_and_paths.items(): + for name, values in folder_paths.folder_names_and_paths.items(): + paths, _exts = values[0], values[1] # NOTE: this prevents nodepacks that hackily edit folder_... from breaking ComfyUI if any(os.path.abspath(p).startswith(models_root + os.sep) for p in paths): targets.append((name, paths)) return targets From bd0e6825e84606e0706bbb5645e9ea1f4ad8154d Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Fri, 9 Jan 2026 11:21:06 -0800 Subject: [PATCH 101/181] Be less strict when loading mixed ops weights. (#11769) --- comfy/ops.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/comfy/ops.py b/comfy/ops.py index 8156c42ff..1cf22f0cc 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -546,7 +546,8 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec weight_key = f"{prefix}weight" weight = state_dict.pop(weight_key, None) if weight is None: - raise ValueError(f"Missing weight for layer {layer_name}") + logging.warning(f"Missing weight for layer {layer_name}") + return manually_loaded_keys = [weight_key] From 4484b93d615059012d3a5ce91d1dbbb0cbaa2d76 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Fri, 9 Jan 2026 22:25:56 +0200 Subject: [PATCH 102/181] fix(api-nodes): do not downscale the input image for Topaz Enhance (#11768) --- comfy_api_nodes/nodes_topaz.py | 7 ++++--- comfy_api_nodes/util/upload_helpers.py | 3 ++- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/comfy_api_nodes/nodes_topaz.py b/comfy_api_nodes/nodes_topaz.py index b04575ad8..9dc5f45bc 100644 --- a/comfy_api_nodes/nodes_topaz.py +++ b/comfy_api_nodes/nodes_topaz.py @@ -2,7 +2,6 @@ import builtins from io import BytesIO import aiohttp -import torch from typing_extensions import override from comfy_api.latest import IO, ComfyExtension, Input @@ -138,7 +137,7 @@ class TopazImageEnhance(IO.ComfyNode): async def execute( cls, model: str, - image: torch.Tensor, + image: Input.Image, prompt: str = "", subject_detection: str = "All", face_enhancement: bool = True, @@ -153,7 +152,9 @@ class TopazImageEnhance(IO.ComfyNode): ) -> IO.NodeOutput: if get_number_of_images(image) != 1: raise ValueError("Only one input image is supported.") - download_url = await upload_images_to_comfyapi(cls, image, max_images=1, mime_type="image/png") + download_url = await upload_images_to_comfyapi( + cls, image, max_images=1, mime_type="image/png", total_pixels=4096*4096 + ) initial_response = await sync_op( cls, ApiEndpoint(path="/proxy/topaz/image/v1/enhance-gen/async", method="POST"), diff --git a/comfy_api_nodes/util/upload_helpers.py b/comfy_api_nodes/util/upload_helpers.py index f1ed7fe9c..2535a0884 100644 --- a/comfy_api_nodes/util/upload_helpers.py +++ b/comfy_api_nodes/util/upload_helpers.py @@ -49,6 +49,7 @@ async def upload_images_to_comfyapi( mime_type: str | None = None, wait_label: str | None = "Uploading", show_batch_index: bool = True, + total_pixels: int = 2048 * 2048, ) -> list[str]: """ 
Uploads images to ComfyUI API and returns download URLs. @@ -63,7 +64,7 @@ async def upload_images_to_comfyapi( for idx in range(num_to_upload): tensor = image[idx] if is_batch else image - img_io = tensor_to_bytesio(tensor, mime_type=mime_type) + img_io = tensor_to_bytesio(tensor, total_pixels=total_pixels, mime_type=mime_type) effective_label = wait_label if wait_label and show_batch_index and num_to_upload > 1: From 393d2880ddc6e57c0fa3b878bb50fa2901bd57e6 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Fri, 9 Jan 2026 22:59:38 +0200 Subject: [PATCH 103/181] feat(api-nodes): added nodes for Vidu2 (#11760) --- comfy_api_nodes/apis/vidu.py | 41 +++ comfy_api_nodes/nodes_vidu.py | 588 +++++++++++++++++++++++++--------- 2 files changed, 482 insertions(+), 147 deletions(-) create mode 100644 comfy_api_nodes/apis/vidu.py diff --git a/comfy_api_nodes/apis/vidu.py b/comfy_api_nodes/apis/vidu.py new file mode 100644 index 000000000..a9bb6f7ce --- /dev/null +++ b/comfy_api_nodes/apis/vidu.py @@ -0,0 +1,41 @@ +from pydantic import BaseModel, Field + + +class SubjectReference(BaseModel): + id: str = Field(...) + images: list[str] = Field(...) + + +class TaskCreationRequest(BaseModel): + model: str = Field(...) + prompt: str = Field(..., max_length=2000) + duration: int = Field(...) + seed: int = Field(..., ge=0, le=2147483647) + aspect_ratio: str | None = Field(None) + resolution: str | None = Field(None) + movement_amplitude: str | None = Field(None) + images: list[str] | None = Field(None, description="Base64 encoded string or image URL") + subjects: list[SubjectReference] | None = Field(None) + bgm: bool | None = Field(None) + audio: bool | None = Field(None) + + +class TaskCreationResponse(BaseModel): + task_id: str = Field(...) + state: str = Field(...) + created_at: str = Field(...) + code: int | None = Field(None, description="Error code") + + +class TaskResult(BaseModel): + id: str = Field(..., description="Creation id") + url: str = Field(..., description="The URL of the generated results, valid for one hour") + cover_url: str = Field(..., description="The cover URL of the generated results, valid for one hour") + + +class TaskStatusResponse(BaseModel): + state: str = Field(...) 
+ err_code: str | None = Field(None) + progress: float | None = Field(None) + credits: int | None = Field(None) + creations: list[TaskResult] = Field(..., description="Generated results") diff --git a/comfy_api_nodes/nodes_vidu.py b/comfy_api_nodes/nodes_vidu.py index 7a679f0d9..9d94ae7ad 100644 --- a/comfy_api_nodes/nodes_vidu.py +++ b/comfy_api_nodes/nodes_vidu.py @@ -1,12 +1,13 @@ -import logging -from enum import Enum -from typing import Literal, Optional, TypeVar - -import torch -from pydantic import BaseModel, Field from typing_extensions import override -from comfy_api.latest import IO, ComfyExtension +from comfy_api.latest import IO, ComfyExtension, Input +from comfy_api_nodes.apis.vidu import ( + SubjectReference, + TaskCreationRequest, + TaskCreationResponse, + TaskResult, + TaskStatusResponse, +) from comfy_api_nodes.util import ( ApiEndpoint, download_url_to_video_output, @@ -17,6 +18,7 @@ from comfy_api_nodes.util import ( validate_image_aspect_ratio, validate_image_dimensions, validate_images_aspect_ratio_closeness, + validate_string, ) VIDU_TEXT_TO_VIDEO = "/proxy/vidu/text2video" @@ -25,98 +27,33 @@ VIDU_REFERENCE_VIDEO = "/proxy/vidu/reference2video" VIDU_START_END_VIDEO = "/proxy/vidu/start-end2video" VIDU_GET_GENERATION_STATUS = "/proxy/vidu/tasks/%s/creations" -R = TypeVar("R") - - -class VideoModelName(str, Enum): - vidu_q1 = "viduq1" - - -class AspectRatio(str, Enum): - r_16_9 = "16:9" - r_9_16 = "9:16" - r_1_1 = "1:1" - - -class Resolution(str, Enum): - r_1080p = "1080p" - - -class MovementAmplitude(str, Enum): - auto = "auto" - small = "small" - medium = "medium" - large = "large" - - -class TaskCreationRequest(BaseModel): - model: VideoModelName = VideoModelName.vidu_q1 - prompt: Optional[str] = Field(None, max_length=1500) - duration: Optional[Literal[5]] = 5 - seed: Optional[int] = Field(0, ge=0, le=2147483647) - aspect_ratio: Optional[AspectRatio] = AspectRatio.r_16_9 - resolution: Optional[Resolution] = Resolution.r_1080p - movement_amplitude: Optional[MovementAmplitude] = MovementAmplitude.auto - images: Optional[list[str]] = Field(None, description="Base64 encoded string or image URL") - - -class TaskCreationResponse(BaseModel): - task_id: str = Field(...) - state: str = Field(...) - created_at: str = Field(...) - code: Optional[int] = Field(None, description="Error code") - - -class TaskResult(BaseModel): - id: str = Field(..., description="Creation id") - url: str = Field(..., description="The URL of the generated results, valid for one hour") - cover_url: str = Field(..., description="The cover URL of the generated results, valid for one hour") - - -class TaskStatusResponse(BaseModel): - state: str = Field(...) - err_code: Optional[str] = Field(None) - creations: list[TaskResult] = Field(..., description="Generated results") - - -def get_video_url_from_response(response) -> Optional[str]: - if response.creations: - return response.creations[0].url - return None - - -def get_video_from_response(response) -> TaskResult: - if not response.creations: - error_msg = f"Vidu request does not contain results. State: {response.state}, Error Code: {response.err_code}" - logging.info(error_msg) - raise RuntimeError(error_msg) - logging.info("Vidu task %s succeeded. 
Video URL: %s", response.creations[0].id, response.creations[0].url) - return response.creations[0] - async def execute_task( cls: type[IO.ComfyNode], vidu_endpoint: str, payload: TaskCreationRequest, - estimated_duration: int, -) -> R: - response = await sync_op( +) -> list[TaskResult]: + task_creation_response = await sync_op( cls, endpoint=ApiEndpoint(path=vidu_endpoint, method="POST"), response_model=TaskCreationResponse, data=payload, ) - if response.state == "failed": - error_msg = f"Vidu request failed. Code: {response.code}" - logging.error(error_msg) - raise RuntimeError(error_msg) - return await poll_op( + if task_creation_response.state == "failed": + raise RuntimeError(f"Vidu request failed. Code: {task_creation_response.code}") + response = await poll_op( cls, - ApiEndpoint(path=VIDU_GET_GENERATION_STATUS % response.task_id), + ApiEndpoint(path=VIDU_GET_GENERATION_STATUS % task_creation_response.task_id), response_model=TaskStatusResponse, status_extractor=lambda r: r.state, - estimated_duration=estimated_duration, + progress_extractor=lambda r: r.progress, + max_poll_attempts=320, ) + if not response.creations: + raise RuntimeError( + f"Vidu request does not contain results. State: {response.state}, Error Code: {response.err_code}" + ) + return response.creations class ViduTextToVideoNode(IO.ComfyNode): @@ -127,14 +64,9 @@ class ViduTextToVideoNode(IO.ComfyNode): node_id="ViduTextToVideoNode", display_name="Vidu Text To Video Generation", category="api node/video/Vidu", - description="Generate video from text prompt", + description="Generate video from a text prompt", inputs=[ - IO.Combo.Input( - "model", - options=VideoModelName, - default=VideoModelName.vidu_q1, - tooltip="Model name", - ), + IO.Combo.Input("model", options=["viduq1"], tooltip="Model name"), IO.String.Input( "prompt", multiline=True, @@ -163,22 +95,19 @@ class ViduTextToVideoNode(IO.ComfyNode): ), IO.Combo.Input( "aspect_ratio", - options=AspectRatio, - default=AspectRatio.r_16_9, + options=["16:9", "9:16", "1:1"], tooltip="The aspect ratio of the output video", optional=True, ), IO.Combo.Input( "resolution", - options=Resolution, - default=Resolution.r_1080p, + options=["1080p"], tooltip="Supported values may vary by model & duration", optional=True, ), IO.Combo.Input( "movement_amplitude", - options=MovementAmplitude, - default=MovementAmplitude.auto, + options=["auto", "small", "medium", "large"], tooltip="The movement amplitude of objects in the frame", optional=True, ), @@ -208,7 +137,7 @@ class ViduTextToVideoNode(IO.ComfyNode): if not prompt: raise ValueError("The prompt field is required and cannot be empty.") payload = TaskCreationRequest( - model_name=model, + model=model, prompt=prompt, duration=duration, seed=seed, @@ -216,8 +145,8 @@ class ViduTextToVideoNode(IO.ComfyNode): resolution=resolution, movement_amplitude=movement_amplitude, ) - results = await execute_task(cls, VIDU_TEXT_TO_VIDEO, payload, 320) - return IO.NodeOutput(await download_url_to_video_output(get_video_from_response(results).url)) + results = await execute_task(cls, VIDU_TEXT_TO_VIDEO, payload) + return IO.NodeOutput(await download_url_to_video_output(results[0].url)) class ViduImageToVideoNode(IO.ComfyNode): @@ -230,12 +159,7 @@ class ViduImageToVideoNode(IO.ComfyNode): category="api node/video/Vidu", description="Generate video from image and optional prompt", inputs=[ - IO.Combo.Input( - "model", - options=VideoModelName, - default=VideoModelName.vidu_q1, - tooltip="Model name", - ), + IO.Combo.Input("model", 
options=["viduq1"], tooltip="Model name"), IO.Image.Input( "image", tooltip="An image to be used as the start frame of the generated video", @@ -270,15 +194,13 @@ class ViduImageToVideoNode(IO.ComfyNode): ), IO.Combo.Input( "resolution", - options=Resolution, - default=Resolution.r_1080p, + options=["1080p"], tooltip="Supported values may vary by model & duration", optional=True, ), IO.Combo.Input( "movement_amplitude", - options=MovementAmplitude, - default=MovementAmplitude.auto.value, + options=["auto", "small", "medium", "large"], tooltip="The movement amplitude of objects in the frame", optional=True, ), @@ -298,7 +220,7 @@ class ViduImageToVideoNode(IO.ComfyNode): async def execute( cls, model: str, - image: torch.Tensor, + image: Input.Image, prompt: str, duration: int, seed: int, @@ -309,7 +231,7 @@ class ViduImageToVideoNode(IO.ComfyNode): raise ValueError("Only one input image is allowed.") validate_image_aspect_ratio(image, (1, 4), (4, 1)) payload = TaskCreationRequest( - model_name=model, + model=model, prompt=prompt, duration=duration, seed=seed, @@ -322,8 +244,8 @@ class ViduImageToVideoNode(IO.ComfyNode): max_images=1, mime_type="image/png", ) - results = await execute_task(cls, VIDU_IMAGE_TO_VIDEO, payload, 120) - return IO.NodeOutput(await download_url_to_video_output(get_video_from_response(results).url)) + results = await execute_task(cls, VIDU_IMAGE_TO_VIDEO, payload) + return IO.NodeOutput(await download_url_to_video_output(results[0].url)) class ViduReferenceVideoNode(IO.ComfyNode): @@ -334,14 +256,9 @@ class ViduReferenceVideoNode(IO.ComfyNode): node_id="ViduReferenceVideoNode", display_name="Vidu Reference To Video Generation", category="api node/video/Vidu", - description="Generate video from multiple images and prompt", + description="Generate video from multiple images and a prompt", inputs=[ - IO.Combo.Input( - "model", - options=VideoModelName, - default=VideoModelName.vidu_q1, - tooltip="Model name", - ), + IO.Combo.Input("model", options=["viduq1"], tooltip="Model name"), IO.Image.Input( "images", tooltip="Images to use as references to generate a video with consistent subjects (max 7 images).", @@ -374,22 +291,19 @@ class ViduReferenceVideoNode(IO.ComfyNode): ), IO.Combo.Input( "aspect_ratio", - options=AspectRatio, - default=AspectRatio.r_16_9, + options=["16:9", "9:16", "1:1"], tooltip="The aspect ratio of the output video", optional=True, ), IO.Combo.Input( "resolution", - options=[model.value for model in Resolution], - default=Resolution.r_1080p.value, + options=["1080p"], tooltip="Supported values may vary by model & duration", optional=True, ), IO.Combo.Input( "movement_amplitude", - options=[model.value for model in MovementAmplitude], - default=MovementAmplitude.auto.value, + options=["auto", "small", "medium", "large"], tooltip="The movement amplitude of objects in the frame", optional=True, ), @@ -409,7 +323,7 @@ class ViduReferenceVideoNode(IO.ComfyNode): async def execute( cls, model: str, - images: torch.Tensor, + images: Input.Image, prompt: str, duration: int, seed: int, @@ -426,7 +340,7 @@ class ViduReferenceVideoNode(IO.ComfyNode): validate_image_aspect_ratio(image, (1, 4), (4, 1)) validate_image_dimensions(image, min_width=128, min_height=128) payload = TaskCreationRequest( - model_name=model, + model=model, prompt=prompt, duration=duration, seed=seed, @@ -440,8 +354,8 @@ class ViduReferenceVideoNode(IO.ComfyNode): max_images=7, mime_type="image/png", ) - results = await execute_task(cls, VIDU_REFERENCE_VIDEO, payload, 120) - return 
IO.NodeOutput(await download_url_to_video_output(get_video_from_response(results).url)) + results = await execute_task(cls, VIDU_REFERENCE_VIDEO, payload) + return IO.NodeOutput(await download_url_to_video_output(results[0].url)) class ViduStartEndToVideoNode(IO.ComfyNode): @@ -454,12 +368,7 @@ class ViduStartEndToVideoNode(IO.ComfyNode): category="api node/video/Vidu", description="Generate a video from start and end frames and a prompt", inputs=[ - IO.Combo.Input( - "model", - options=[model.value for model in VideoModelName], - default=VideoModelName.vidu_q1.value, - tooltip="Model name", - ), + IO.Combo.Input("model", options=["viduq1"], tooltip="Model name"), IO.Image.Input( "first_frame", tooltip="Start frame", @@ -497,15 +406,13 @@ class ViduStartEndToVideoNode(IO.ComfyNode): ), IO.Combo.Input( "resolution", - options=[model.value for model in Resolution], - default=Resolution.r_1080p.value, + options=["1080p"], tooltip="Supported values may vary by model & duration", optional=True, ), IO.Combo.Input( "movement_amplitude", - options=[model.value for model in MovementAmplitude], - default=MovementAmplitude.auto.value, + options=["auto", "small", "medium", "large"], tooltip="The movement amplitude of objects in the frame", optional=True, ), @@ -525,8 +432,8 @@ class ViduStartEndToVideoNode(IO.ComfyNode): async def execute( cls, model: str, - first_frame: torch.Tensor, - end_frame: torch.Tensor, + first_frame: Input.Image, + end_frame: Input.Image, prompt: str, duration: int, seed: int, @@ -535,7 +442,7 @@ class ViduStartEndToVideoNode(IO.ComfyNode): ) -> IO.NodeOutput: validate_images_aspect_ratio_closeness(first_frame, end_frame, min_rel=0.8, max_rel=1.25, strict=False) payload = TaskCreationRequest( - model_name=model, + model=model, prompt=prompt, duration=duration, seed=seed, @@ -546,8 +453,391 @@ class ViduStartEndToVideoNode(IO.ComfyNode): (await upload_images_to_comfyapi(cls, frame, max_images=1, mime_type="image/png"))[0] for frame in (first_frame, end_frame) ] - results = await execute_task(cls, VIDU_START_END_VIDEO, payload, 96) - return IO.NodeOutput(await download_url_to_video_output(get_video_from_response(results).url)) + results = await execute_task(cls, VIDU_START_END_VIDEO, payload) + return IO.NodeOutput(await download_url_to_video_output(results[0].url)) + + +class Vidu2TextToVideoNode(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="Vidu2TextToVideoNode", + display_name="Vidu2 Text-to-Video Generation", + category="api node/video/Vidu", + description="Generate video from a text prompt", + inputs=[ + IO.Combo.Input("model", options=["viduq2"]), + IO.String.Input( + "prompt", + multiline=True, + tooltip="A textual description for video generation, with a maximum length of 2000 characters.", + ), + IO.Int.Input( + "duration", + default=5, + min=1, + max=10, + step=1, + display_mode=IO.NumberDisplay.slider, + ), + IO.Int.Input( + "seed", + default=1, + min=0, + max=2147483647, + step=1, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + ), + IO.Combo.Input("aspect_ratio", options=["16:9", "9:16", "3:4", "4:3", "1:1"]), + IO.Combo.Input("resolution", options=["720p", "1080p"]), + IO.Boolean.Input( + "background_music", + default=False, + tooltip="Whether to add background music to the generated video.", + ), + ], + outputs=[ + IO.Video.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) + + @classmethod + async def 
execute( + cls, + model: str, + prompt: str, + duration: int, + seed: int, + aspect_ratio: str, + resolution: str, + background_music: bool, + ) -> IO.NodeOutput: + validate_string(prompt, min_length=1, max_length=2000) + results = await execute_task( + cls, + VIDU_TEXT_TO_VIDEO, + TaskCreationRequest( + model=model, + prompt=prompt, + duration=duration, + seed=seed, + aspect_ratio=aspect_ratio, + resolution=resolution, + bgm=background_music, + ), + ) + return IO.NodeOutput(await download_url_to_video_output(results[0].url)) + + +class Vidu2ImageToVideoNode(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="Vidu2ImageToVideoNode", + display_name="Vidu2 Image-to-Video Generation", + category="api node/video/Vidu", + description="Generate a video from an image and an optional prompt.", + inputs=[ + IO.Combo.Input("model", options=["viduq2-pro-fast", "viduq2-pro", "viduq2-turbo"]), + IO.Image.Input( + "image", + tooltip="An image to be used as the start frame of the generated video.", + ), + IO.String.Input( + "prompt", + multiline=True, + default="", + tooltip="An optional text prompt for video generation (max 2000 characters).", + ), + IO.Int.Input( + "duration", + default=5, + min=1, + max=10, + step=1, + display_mode=IO.NumberDisplay.slider, + ), + IO.Int.Input( + "seed", + default=1, + min=0, + max=2147483647, + step=1, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + ), + IO.Combo.Input( + "resolution", + options=["720p", "1080p"], + ), + IO.Combo.Input( + "movement_amplitude", + options=["auto", "small", "medium", "large"], + tooltip="The movement amplitude of objects in the frame.", + ), + ], + outputs=[ + IO.Video.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, + model: str, + image: Input.Image, + prompt: str, + duration: int, + seed: int, + resolution: str, + movement_amplitude: str, + ) -> IO.NodeOutput: + if get_number_of_images(image) > 1: + raise ValueError("Only one input image is allowed.") + validate_image_aspect_ratio(image, (1, 4), (4, 1)) + validate_string(prompt, max_length=2000) + results = await execute_task( + cls, + VIDU_IMAGE_TO_VIDEO, + TaskCreationRequest( + model=model, + prompt=prompt, + duration=duration, + seed=seed, + resolution=resolution, + movement_amplitude=movement_amplitude, + images=await upload_images_to_comfyapi( + cls, + image, + max_images=1, + mime_type="image/png", + ), + ), + ) + return IO.NodeOutput(await download_url_to_video_output(results[0].url)) + + +class Vidu2ReferenceVideoNode(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="Vidu2ReferenceVideoNode", + display_name="Vidu2 Reference-to-Video Generation", + category="api node/video/Vidu", + description="Generate a video from multiple reference images and a prompt.", + inputs=[ + IO.Combo.Input("model", options=["viduq2"]), + IO.Autogrow.Input( + "subjects", + template=IO.Autogrow.TemplateNames( + IO.Image.Input("reference_images"), + names=["subject1", "subject2", "subject3"], + min=1, + ), + tooltip="For each subject, provide up to 3 reference images (7 images total across all subjects). 
" + "Reference them in prompts via @subject{subject_id}.", + ), + IO.String.Input( + "prompt", + multiline=True, + tooltip="When enabled, the video will include generated speech and background music " + "based on the prompt.", + ), + IO.Boolean.Input( + "audio", + default=False, + tooltip="When enabled video will contain generated speech and background music based on the prompt.", + ), + IO.Int.Input( + "duration", + default=5, + min=1, + max=10, + step=1, + display_mode=IO.NumberDisplay.slider, + ), + IO.Int.Input( + "seed", + default=1, + min=0, + max=2147483647, + step=1, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + ), + IO.Combo.Input("aspect_ratio", options=["16:9", "9:16", "4:3", "3:4", "1:1"]), + IO.Combo.Input("resolution", options=["720p"]), + IO.Combo.Input( + "movement_amplitude", + options=["auto", "small", "medium", "large"], + tooltip="The movement amplitude of objects in the frame.", + ), + ], + outputs=[ + IO.Video.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, + model: str, + subjects: IO.Autogrow.Type, + prompt: str, + audio: bool, + duration: int, + seed: int, + aspect_ratio: str, + resolution: str, + movement_amplitude: str, + ) -> IO.NodeOutput: + validate_string(prompt, min_length=1, max_length=2000) + total_images = 0 + for i in subjects: + if get_number_of_images(subjects[i]) > 3: + raise ValueError("Maximum number of images per subject is 3.") + for im in subjects[i]: + total_images += 1 + validate_image_aspect_ratio(im, (1, 4), (4, 1)) + validate_image_dimensions(im, min_width=128, min_height=128) + if total_images > 7: + raise ValueError("Too many reference images; the maximum allowed is 7.") + subjects_param: list[SubjectReference] = [] + for i in subjects: + subjects_param.append( + SubjectReference( + id=i, + images=await upload_images_to_comfyapi( + cls, + subjects[i], + max_images=3, + mime_type="image/png", + wait_label=f"Uploading reference images for {i}", + ), + ), + ) + payload = TaskCreationRequest( + model=model, + prompt=prompt, + audio=audio, + duration=duration, + seed=seed, + aspect_ratio=aspect_ratio, + resolution=resolution, + movement_amplitude=movement_amplitude, + subjects=subjects_param, + ) + results = await execute_task(cls, VIDU_REFERENCE_VIDEO, payload) + return IO.NodeOutput(await download_url_to_video_output(results[0].url)) + + +class Vidu2StartEndToVideoNode(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="Vidu2StartEndToVideoNode", + display_name="Vidu2 Start/End Frame-to-Video Generation", + category="api node/video/Vidu", + description="Generate a video from a start frame, an end frame, and a prompt.", + inputs=[ + IO.Combo.Input("model", options=["viduq2-pro-fast", "viduq2-pro", "viduq2-turbo"]), + IO.Image.Input("first_frame"), + IO.Image.Input("end_frame"), + IO.String.Input( + "prompt", + multiline=True, + tooltip="Prompt description (max 2000 characters).", + ), + IO.Int.Input( + "duration", + default=5, + min=2, + max=8, + step=1, + display_mode=IO.NumberDisplay.slider, + ), + IO.Int.Input( + "seed", + default=1, + min=0, + max=2147483647, + step=1, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + ), + IO.Combo.Input("resolution", options=["720p", "1080p"]), + IO.Combo.Input( + "movement_amplitude", + options=["auto", "small", "medium", "large"], + tooltip="The movement amplitude of objects in 
the frame.", + ), + ], + outputs=[ + IO.Video.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, + model: str, + first_frame: Input.Image, + end_frame: Input.Image, + prompt: str, + duration: int, + seed: int, + resolution: str, + movement_amplitude: str, + ) -> IO.NodeOutput: + validate_string(prompt, max_length=2000) + if get_number_of_images(first_frame) > 1: + raise ValueError("Only one input image is allowed for `first_frame`.") + if get_number_of_images(end_frame) > 1: + raise ValueError("Only one input image is allowed for `end_frame`.") + validate_images_aspect_ratio_closeness(first_frame, end_frame, min_rel=0.8, max_rel=1.25, strict=False) + payload = TaskCreationRequest( + model=model, + prompt=prompt, + duration=duration, + seed=seed, + resolution=resolution, + movement_amplitude=movement_amplitude, + images=[ + (await upload_images_to_comfyapi(cls, frame, max_images=1, mime_type="image/png"))[0] + for frame in (first_frame, end_frame) + ], + ) + results = await execute_task(cls, VIDU_START_END_VIDEO, payload) + return IO.NodeOutput(await download_url_to_video_output(results[0].url)) class ViduExtension(ComfyExtension): @@ -558,6 +848,10 @@ class ViduExtension(ComfyExtension): ViduImageToVideoNode, ViduReferenceVideoNode, ViduStartEndToVideoNode, + Vidu2TextToVideoNode, + Vidu2ImageToVideoNode, + Vidu2ReferenceVideoNode, + Vidu2StartEndToVideoNode, ] From 153bc524bf9db76d723289f6420f418f63966972 Mon Sep 17 00:00:00 2001 From: ComfyUI Wiki Date: Sat, 10 Jan 2026 14:29:30 +0800 Subject: [PATCH 104/181] chore: update embedded docs to v0.4.0 (#11776) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 7686a5f8a..6c1cd86d2 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ comfyui-frontend-package==1.36.13 comfyui-workflow-templates==0.7.69 -comfyui-embedded-docs==0.3.1 +comfyui-embedded-docs==0.4.0 torch torchsde torchvision From dc202a2e51bf7a6cd00e606b2d2941bc223f2ad2 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Fri, 9 Jan 2026 23:03:57 -0800 Subject: [PATCH 105/181] Properly save mixed ops. 
(#11772) --- comfy/ops.py | 26 ++++++++++++------- .../comfy_quant/test_mixed_precision.py | 6 ++--- 2 files changed, 20 insertions(+), 12 deletions(-) diff --git a/comfy/ops.py b/comfy/ops.py index 1cf22f0cc..9c0b54ff4 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -625,21 +625,29 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec missing_keys.remove(key) def state_dict(self, *args, destination=None, prefix="", **kwargs): - sd = super().state_dict(*args, destination=destination, prefix=prefix, **kwargs) - if isinstance(self.weight, QuantizedTensor): - layout_cls = self.weight._layout_cls + if destination is not None: + sd = destination + else: + sd = {} - # Check if it's any FP8 variant (E4M3 or E5M2) - if layout_cls in ("TensorCoreFP8E4M3Layout", "TensorCoreFP8E5M2Layout", "TensorCoreFP8Layout"): - sd["{}weight_scale".format(prefix)] = self.weight._params.scale - elif layout_cls == "TensorCoreNVFP4Layout": - sd["{}weight_scale_2".format(prefix)] = self.weight._params.scale - sd["{}weight_scale".format(prefix)] = self.weight._params.block_scale + if self.bias is not None: + sd["{}bias".format(prefix)] = self.bias + + if isinstance(self.weight, QuantizedTensor): + sd_out = self.weight.state_dict("{}weight".format(prefix)) + for k in sd_out: + sd[k] = sd_out[k] quant_conf = {"format": self.quant_format} if self._full_precision_mm_config: quant_conf["full_precision_matrix_mult"] = True sd["{}comfy_quant".format(prefix)] = torch.tensor(list(json.dumps(quant_conf).encode('utf-8')), dtype=torch.uint8) + + input_scale = getattr(self, 'input_scale', None) + if input_scale is not None: + sd["{}input_scale".format(prefix)] = input_scale + else: + sd["{}weight".format(prefix)] = self.weight return sd def _forward(self, input, weight, bias): diff --git a/tests-unit/comfy_quant/test_mixed_precision.py b/tests-unit/comfy_quant/test_mixed_precision.py index 7b2eac940..7c740491d 100644 --- a/tests-unit/comfy_quant/test_mixed_precision.py +++ b/tests-unit/comfy_quant/test_mixed_precision.py @@ -153,9 +153,9 @@ class TestMixedPrecisionOps(unittest.TestCase): state_dict2 = model.state_dict() # Verify layer1.weight is a QuantizedTensor with scale preserved - self.assertIsInstance(state_dict2["layer1.weight"], QuantizedTensor) - self.assertEqual(state_dict2["layer1.weight"]._params.scale.item(), 3.0) - self.assertEqual(state_dict2["layer1.weight"]._layout_cls, "TensorCoreFP8E4M3Layout") + self.assertTrue(torch.equal(state_dict2["layer1.weight"].view(torch.uint8), fp8_weight.view(torch.uint8))) + self.assertEqual(state_dict2["layer1.weight_scale"].item(), 3.0) + self.assertEqual(model.layer1.weight._layout_cls, "TensorCoreFP8E4M3Layout") # Verify non-quantized layers are standard tensors self.assertNotIsInstance(state_dict2["layer2.weight"], QuantizedTensor) From 6e4b1f9d00306fe14d7ca5adf2c7468d631b23d5 Mon Sep 17 00:00:00 2001 From: DELUXA Date: Sat, 10 Jan 2026 23:51:05 +0200 Subject: [PATCH 106/181] pythorch_attn_by_def_on_gfx1200 (#11793) --- comfy/model_management.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index e5de4a5b5..9d39be7b2 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -368,7 +368,7 @@ try: if any((a in arch) for a in ["gfx90a", "gfx942", "gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches, TODO: gfx950 ENABLE_PYTORCH_ATTENTION = True if rocm_version >= (7, 0): - if any((a in arch) for a in ["gfx1201"]): + if any((a in arch) for a in ["gfx1200", 
"gfx1201"]): ENABLE_PYTORCH_ATTENTION = True if torch_version_numeric >= (2, 7) and rocm_version >= (6, 4): if any((a in arch) for a in ["gfx1200", "gfx1201", "gfx950"]): # TODO: more arches, "gfx942" gives error on pytorch nightly 2.10 1013 rocm7.0 From cd912963f17c9ae00ec12e1869293edb78720831 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sat, 10 Jan 2026 14:31:31 -0800 Subject: [PATCH 107/181] Fix issue with t5 text encoder in fp4. (#11794) --- comfy/model_detection.py | 2 ++ comfy/sd.py | 4 ++-- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 0853b3aec..aff5a50b9 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -237,6 +237,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): else: dit_config["vec_in_dim"] = None + dit_config["num_heads"] = dit_config["hidden_size"] // sum(dit_config["axes_dim"]) + dit_config["depth"] = count_blocks(state_dict_keys, '{}double_blocks.'.format(key_prefix) + '{}.') dit_config["depth_single_blocks"] = count_blocks(state_dict_keys, '{}single_blocks.'.format(key_prefix) + '{}.') if '{}distilled_guidance_layer.0.norms.0.scale'.format(key_prefix) in state_dict_keys or '{}distilled_guidance_layer.norms.0.scale'.format(key_prefix) in state_dict_keys: #Chroma diff --git a/comfy/sd.py b/comfy/sd.py index 5a7221620..b689c0dfc 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -1059,9 +1059,9 @@ def detect_te_model(sd): return TEModel.JINA_CLIP_2 if "encoder.block.23.layer.1.DenseReluDense.wi_1.weight" in sd: weight = sd["encoder.block.23.layer.1.DenseReluDense.wi_1.weight"] - if weight.shape[-1] == 4096: + if weight.shape[0] == 10240: return TEModel.T5_XXL - elif weight.shape[-1] == 2048: + elif weight.shape[0] == 5120: return TEModel.T5_XL if 'encoder.block.23.layer.1.DenseReluDense.wi.weight' in sd: return TEModel.T5_XXL_OLD From 2f642d5d9b48ad7cad13bbdd5f8adcf506f565a7 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sat, 10 Jan 2026 14:40:42 -0800 Subject: [PATCH 108/181] Fix chroma fp8 te being treated as fp16. 
(#11795) --- comfy/text_encoders/cosmos.py | 2 +- comfy/text_encoders/genmo.py | 2 +- comfy/text_encoders/pixart_t5.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/comfy/text_encoders/cosmos.py b/comfy/text_encoders/cosmos.py index 448381fa9..f4b40ac68 100644 --- a/comfy/text_encoders/cosmos.py +++ b/comfy/text_encoders/cosmos.py @@ -36,7 +36,7 @@ def te(dtype_t5=None, t5_quantization_metadata=None): if t5_quantization_metadata is not None: model_options = model_options.copy() model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata - if dtype is None: + if dtype_t5 is not None: dtype = dtype_t5 super().__init__(device=device, dtype=dtype, model_options=model_options) return CosmosTEModel_ diff --git a/comfy/text_encoders/genmo.py b/comfy/text_encoders/genmo.py index 5daea8135..2d7a3fbce 100644 --- a/comfy/text_encoders/genmo.py +++ b/comfy/text_encoders/genmo.py @@ -32,7 +32,7 @@ def mochi_te(dtype_t5=None, t5_quantization_metadata=None): if t5_quantization_metadata is not None: model_options = model_options.copy() model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata - if dtype is None: + if dtype_t5 is not None: dtype = dtype_t5 super().__init__(device=device, dtype=dtype, model_options=model_options) return MochiTEModel_ diff --git a/comfy/text_encoders/pixart_t5.py b/comfy/text_encoders/pixart_t5.py index e5e5f18be..51c6e50c7 100644 --- a/comfy/text_encoders/pixart_t5.py +++ b/comfy/text_encoders/pixart_t5.py @@ -36,7 +36,7 @@ def pixart_te(dtype_t5=None, t5_quantization_metadata=None): if t5_quantization_metadata is not None: model_options = model_options.copy() model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata - if dtype is None: + if dtype_t5 is not None: dtype = dtype_t5 super().__init__(device=device, dtype=dtype, model_options=model_options) return PixArtTEModel_ From 5cd1113236b0fb032a51bf9d63ba196a2510b0d4 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Sun, 11 Jan 2026 13:07:11 +0200 Subject: [PATCH 109/181] fix(api-nodes): use a unique name for uploading audio files (#11778) --- comfy_api_nodes/nodes_kling.py | 2 +- comfy_api_nodes/util/conversions.py | 4 ++-- comfy_api_nodes/util/upload_helpers.py | 3 +-- 3 files changed, 4 insertions(+), 5 deletions(-) diff --git a/comfy_api_nodes/nodes_kling.py b/comfy_api_nodes/nodes_kling.py index 9c707a339..01d9c34f5 100644 --- a/comfy_api_nodes/nodes_kling.py +++ b/comfy_api_nodes/nodes_kling.py @@ -567,7 +567,7 @@ async def execute_lipsync( # Upload the audio file to Comfy API and get download URL if audio: audio_url = await upload_audio_to_comfyapi( - cls, audio, container_format="mp3", codec_name="libmp3lame", mime_type="audio/mpeg", filename="output.mp3" + cls, audio, container_format="mp3", codec_name="libmp3lame", mime_type="audio/mpeg" ) logging.info("Uploaded audio to Comfy API. 
URL: %s", audio_url) else: diff --git a/comfy_api_nodes/util/conversions.py b/comfy_api_nodes/util/conversions.py index d64239c86..99c302a2a 100644 --- a/comfy_api_nodes/util/conversions.py +++ b/comfy_api_nodes/util/conversions.py @@ -55,7 +55,7 @@ def image_tensor_pair_to_batch(image1: torch.Tensor, image2: torch.Tensor) -> to def tensor_to_bytesio( image: torch.Tensor, - name: str | None = None, + *, total_pixels: int = 2048 * 2048, mime_type: str = "image/png", ) -> BytesIO: @@ -75,7 +75,7 @@ def tensor_to_bytesio( pil_image = tensor_to_pil(image, total_pixels=total_pixels) img_binary = pil_to_bytesio(pil_image, mime_type=mime_type) - img_binary.name = f"{name if name else uuid.uuid4()}.{mimetype_to_extension(mime_type)}" + img_binary.name = f"{uuid.uuid4()}.{mimetype_to_extension(mime_type)}" return img_binary diff --git a/comfy_api_nodes/util/upload_helpers.py b/comfy_api_nodes/util/upload_helpers.py index 2535a0884..cea0d1203 100644 --- a/comfy_api_nodes/util/upload_helpers.py +++ b/comfy_api_nodes/util/upload_helpers.py @@ -82,7 +82,6 @@ async def upload_audio_to_comfyapi( container_format: str = "mp4", codec_name: str = "aac", mime_type: str = "audio/mp4", - filename: str = "uploaded_audio.mp4", ) -> str: """ Uploads a single audio input to ComfyUI API and returns its download URL. @@ -92,7 +91,7 @@ async def upload_audio_to_comfyapi( waveform: torch.Tensor = audio["waveform"] audio_data_np = audio_tensor_to_contiguous_ndarray(waveform) audio_bytes_io = audio_ndarray_to_bytesio(audio_data_np, sample_rate, container_format, codec_name) - return await upload_file_to_comfyapi(cls, audio_bytes_io, filename, mime_type) + return await upload_file_to_comfyapi(cls, audio_bytes_io, f"{uuid.uuid4()}.{container_format}", mime_type) async def upload_video_to_comfyapi( From c6238047ee1ffd87eade7c3ab5a8e53c11d4ce39 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sun, 11 Jan 2026 18:11:53 -0800 Subject: [PATCH 110/181] Put more details about portable in readme. (#11816) --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 6d09758c0..e25f3cda7 100644 --- a/README.md +++ b/README.md @@ -183,7 +183,7 @@ Simply download, extract with [7-Zip](https://7-zip.org) or with the windows exp If you have trouble extracting it, right click the file -> properties -> unblock -Update your Nvidia drivers if it doesn't start. +The portable above currently comes with python 3.13 and pytorch cuda 13.0. Update your Nvidia drivers if it doesn't start. #### Alternative Downloads: @@ -212,7 +212,7 @@ Python 3.14 works but you may encounter issues with the torch compile node. The Python 3.13 is very well supported. If you have trouble with some custom node dependencies on 3.13 you can try 3.12 -torch 2.4 and above is supported but some features might only work on newer versions. We generally recommend using the latest major version of pytorch unless it is less than 2 weeks old. +torch 2.4 and above is supported but some features might only work on newer versions. We generally recommend using the latest major version of pytorch with the latest cuda version unless it is less than 2 weeks old. ### Instructions: From a3b5d4996abcd906c7c99f15b69fde051afcb4be Mon Sep 17 00:00:00 2001 From: kelseyee <971704395@qq.com> Date: Tue, 13 Jan 2026 04:38:46 +0800 Subject: [PATCH 111/181] Support ModelScope-Trainer DiffSynth lora for Z Image. 
(#11805) --- comfy/lora.py | 1 + 1 file changed, 1 insertion(+) diff --git a/comfy/lora.py b/comfy/lora.py index 2ed0acb9d..e8246bd66 100644 --- a/comfy/lora.py +++ b/comfy/lora.py @@ -322,6 +322,7 @@ def model_lora_keys_unet(model, key_map={}): key_map["diffusion_model.{}".format(key_lora)] = to key_map["transformer.{}".format(key_lora)] = to key_map["lycoris_{}".format(key_lora.replace(".", "_"))] = to + key_map[key_lora] = to if isinstance(model, comfy.model_base.Kandinsky5): for k in sdk: From c881a1d6897d8fee84559a8e3e49b9116efdb959 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 12 Jan 2026 14:05:54 -0800 Subject: [PATCH 112/181] Support the siglip 2 naflex model as a clip vision model. (#11831) Not useful yet. --- comfy/clip_model.py | 62 +++++++++++++++++++++- comfy/clip_vision.py | 25 ++++++--- comfy/clip_vision_siglip2_base_naflex.json | 14 +++++ 3 files changed, 91 insertions(+), 10 deletions(-) create mode 100644 comfy/clip_vision_siglip2_base_naflex.json diff --git a/comfy/clip_model.py b/comfy/clip_model.py index e88872728..d7d3f994c 100644 --- a/comfy/clip_model.py +++ b/comfy/clip_model.py @@ -1,6 +1,7 @@ import torch from comfy.ldm.modules.attention import optimized_attention_for_device import comfy.ops +import math def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711], crop=True): image = image[:, :, :, :3] if image.shape[3] > 3 else image @@ -21,6 +22,39 @@ def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], s image = torch.clip((255. * image), 0, 255).round() / 255.0 return (image - mean.view([3,1,1])) / std.view([3,1,1]) +def siglip2_flex_calc_resolution(oh, ow, patch_size, max_num_patches, eps=1e-5): + def scale_dim(size, scale): + scaled = math.ceil(size * scale / patch_size) * patch_size + return max(patch_size, int(scaled)) + + # Binary search for optimal scale + lo, hi = eps / 10, 100.0 + while hi - lo >= eps: + mid = (lo + hi) / 2 + h, w = scale_dim(oh, mid), scale_dim(ow, mid) + if (h // patch_size) * (w // patch_size) <= max_num_patches: + lo = mid + else: + hi = mid + + return scale_dim(oh, lo), scale_dim(ow, lo) + +def siglip2_preprocess(image, size, patch_size, num_patches, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], crop=True): + if size > 0: + return clip_preprocess(image, size=size, mean=mean, std=std, crop=crop) + + image = image[:, :, :, :3] if image.shape[3] > 3 else image + mean = torch.tensor(mean, device=image.device, dtype=image.dtype) + std = torch.tensor(std, device=image.device, dtype=image.dtype) + image = image.movedim(-1, 1) + + b, c, h, w = image.shape + h, w = siglip2_flex_calc_resolution(h, w, patch_size, num_patches) + + image = torch.nn.functional.interpolate(image, size=(h, w), mode="bilinear", antialias=True) + image = torch.clip((255. 
* image), 0, 255).round() / 255.0 + return (image - mean.view([3, 1, 1])) / std.view([3, 1, 1]) + class CLIPAttention(torch.nn.Module): def __init__(self, embed_dim, heads, dtype, device, operations): super().__init__() @@ -175,6 +209,27 @@ class CLIPTextModel(torch.nn.Module): out = self.text_projection(x[2]) return (x[0], x[1], out, x[2]) +def siglip2_pos_embed(embed_weight, embeds, orig_shape): + embed_weight_len = round(embed_weight.shape[0] ** 0.5) + embed_weight = comfy.ops.cast_to_input(embed_weight, embeds).movedim(1, 0).reshape(1, -1, embed_weight_len, embed_weight_len) + embed_weight = torch.nn.functional.interpolate(embed_weight, size=orig_shape, mode="bilinear", align_corners=False, antialias=True) + embed_weight = embed_weight.reshape(-1, embed_weight.shape[-2] * embed_weight.shape[-1]).movedim(0, 1) + return embeds + embed_weight + +class Siglip2Embeddings(torch.nn.Module): + def __init__(self, embed_dim, num_channels=3, patch_size=14, image_size=224, model_type="", num_patches=None, dtype=None, device=None, operations=None): + super().__init__() + self.patch_embedding = operations.Linear(num_channels * patch_size * patch_size, embed_dim, dtype=dtype, device=device) + self.position_embedding = operations.Embedding(num_patches, embed_dim, dtype=dtype, device=device) + self.patch_size = patch_size + + def forward(self, pixel_values): + b, c, h, w = pixel_values.shape + img = pixel_values.movedim(1, -1).reshape(b, h // self.patch_size, self.patch_size, w // self.patch_size, self.patch_size, c) + img = img.permute(0, 1, 3, 2, 4, 5) + img = img.reshape(b, img.shape[1] * img.shape[2], -1) + img = self.patch_embedding(img) + return siglip2_pos_embed(self.position_embedding.weight, img, (h // self.patch_size, w // self.patch_size)) class CLIPVisionEmbeddings(torch.nn.Module): def __init__(self, embed_dim, num_channels=3, patch_size=14, image_size=224, model_type="", dtype=None, device=None, operations=None): @@ -218,8 +273,11 @@ class CLIPVision(torch.nn.Module): intermediate_activation = config_dict["hidden_act"] model_type = config_dict["model_type"] - self.embeddings = CLIPVisionEmbeddings(embed_dim, config_dict["num_channels"], config_dict["patch_size"], config_dict["image_size"], model_type=model_type, dtype=dtype, device=device, operations=operations) - if model_type == "siglip_vision_model": + if model_type in ["siglip2_vision_model"]: + self.embeddings = Siglip2Embeddings(embed_dim, config_dict["num_channels"], config_dict["patch_size"], config_dict["image_size"], model_type=model_type, num_patches=config_dict.get("num_patches", None), dtype=dtype, device=device, operations=operations) + else: + self.embeddings = CLIPVisionEmbeddings(embed_dim, config_dict["num_channels"], config_dict["patch_size"], config_dict["image_size"], model_type=model_type, dtype=dtype, device=device, operations=operations) + if model_type in ["siglip_vision_model", "siglip2_vision_model"]: self.pre_layrnorm = lambda a: a self.output_layernorm = True else: diff --git a/comfy/clip_vision.py b/comfy/clip_vision.py index d5fc53497..66f2a9d9c 100644 --- a/comfy/clip_vision.py +++ b/comfy/clip_vision.py @@ -21,6 +21,7 @@ clip_preprocess = comfy.clip_model.clip_preprocess # Prevent some stuff from br IMAGE_ENCODERS = { "clip_vision_model": comfy.clip_model.CLIPVisionModelProjection, "siglip_vision_model": comfy.clip_model.CLIPVisionModelProjection, + "siglip2_vision_model": comfy.clip_model.CLIPVisionModelProjection, "dinov2": comfy.image_encoders.dino2.Dinov2Model, } @@ -32,9 +33,10 @@ class 
ClipVisionModel(): self.image_size = config.get("image_size", 224) self.image_mean = config.get("image_mean", [0.48145466, 0.4578275, 0.40821073]) self.image_std = config.get("image_std", [0.26862954, 0.26130258, 0.27577711]) - model_type = config.get("model_type", "clip_vision_model") - model_class = IMAGE_ENCODERS.get(model_type) - if model_type == "siglip_vision_model": + self.model_type = config.get("model_type", "clip_vision_model") + self.config = config.copy() + model_class = IMAGE_ENCODERS.get(self.model_type) + if self.model_type == "siglip_vision_model": self.return_all_hidden_states = True else: self.return_all_hidden_states = False @@ -55,7 +57,10 @@ class ClipVisionModel(): def encode_image(self, image, crop=True): comfy.model_management.load_model_gpu(self.patcher) - pixel_values = comfy.clip_model.clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std, crop=crop).float() + if self.model_type == "siglip2_vision_model": + pixel_values = comfy.clip_model.siglip2_preprocess(image.to(self.load_device), size=self.image_size, patch_size=self.config.get("patch_size", 16), num_patches=self.config.get("num_patches", 256), mean=self.image_mean, std=self.image_std, crop=crop).float() + else: + pixel_values = comfy.clip_model.clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std, crop=crop).float() out = self.model(pixel_values=pixel_values, intermediate_output='all' if self.return_all_hidden_states else -2) outputs = Output() @@ -107,10 +112,14 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False): elif "vision_model.encoder.layers.22.layer_norm1.weight" in sd: embed_shape = sd["vision_model.embeddings.position_embedding.weight"].shape[0] if sd["vision_model.encoder.layers.0.layer_norm1.weight"].shape[0] == 1152: - if embed_shape == 729: - json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_siglip_384.json") - elif embed_shape == 1024: - json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_siglip_512.json") + patch_embedding_shape = sd["vision_model.embeddings.patch_embedding.weight"].shape + if len(patch_embedding_shape) == 2: + json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_siglip2_base_naflex.json") + else: + if embed_shape == 729: + json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_siglip_384.json") + elif embed_shape == 1024: + json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_siglip_512.json") elif embed_shape == 577: if "multi_modal_projector.linear_1.bias" in sd: json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl_336_llava.json") diff --git a/comfy/clip_vision_siglip2_base_naflex.json b/comfy/clip_vision_siglip2_base_naflex.json new file mode 100644 index 000000000..6f6b99bd6 --- /dev/null +++ b/comfy/clip_vision_siglip2_base_naflex.json @@ -0,0 +1,14 @@ +{ + "num_channels": 3, + "hidden_act": "gelu_pytorch_tanh", + "hidden_size": 1152, + "image_size": -1, + "intermediate_size": 4304, + "model_type": "siglip2_vision_model", + "num_attention_heads": 16, + "num_hidden_layers": 27, + "patch_size": 16, + "num_patches": 256, + "image_mean": [0.5, 0.5, 0.5], + "image_std": [0.5, 0.5, 0.5] +} From fd5c0755af18530ff1225f5946c7a647b9694032 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jukka=20Sepp=C3=A4nen?= <40791699+kijai@users.noreply.github.com> Date: Tue, 13 
Jan 2026 00:28:59 +0200 Subject: [PATCH 113/181] Reduce LTX2 VRAM use by more efficient timestep embed handling (#11829) --- comfy/ldm/lightricks/av_model.py | 136 ++++++++++++++++++++++++------- 1 file changed, 106 insertions(+), 30 deletions(-) diff --git a/comfy/ldm/lightricks/av_model.py b/comfy/ldm/lightricks/av_model.py index 759535501..c12ace241 100644 --- a/comfy/ldm/lightricks/av_model.py +++ b/comfy/ldm/lightricks/av_model.py @@ -11,6 +11,69 @@ from comfy.ldm.lightricks.model import ( from comfy.ldm.lightricks.symmetric_patchifier import AudioPatchifier import comfy.ldm.common_dit +class CompressedTimestep: + """Store video timestep embeddings in compressed form using per-frame indexing.""" + __slots__ = ('data', 'batch_size', 'num_frames', 'patches_per_frame', 'feature_dim') + + def __init__(self, tensor: torch.Tensor, patches_per_frame: int): + """ + tensor: [batch_size, num_tokens, feature_dim] tensor where num_tokens = num_frames * patches_per_frame + patches_per_frame: Number of spatial patches per frame (height * width in latent space) + """ + self.batch_size, num_tokens, self.feature_dim = tensor.shape + + # Check if compression is valid (num_tokens must be divisible by patches_per_frame) + if num_tokens % patches_per_frame == 0 and num_tokens >= patches_per_frame: + self.patches_per_frame = patches_per_frame + self.num_frames = num_tokens // patches_per_frame + + # Reshape to [batch, frames, patches_per_frame, feature_dim] and store one value per frame + # All patches in a frame are identical, so we only keep the first one + reshaped = tensor.view(self.batch_size, self.num_frames, patches_per_frame, self.feature_dim) + self.data = reshaped[:, :, 0, :].contiguous() # [batch, frames, feature_dim] + else: + # Not divisible or too small - store directly without compression + self.patches_per_frame = 1 + self.num_frames = num_tokens + self.data = tensor + + def expand(self): + """Expand back to original tensor.""" + if self.patches_per_frame == 1: + return self.data + + # [batch, frames, feature_dim] -> [batch, frames, patches_per_frame, feature_dim] -> [batch, tokens, feature_dim] + expanded = self.data.unsqueeze(2).expand(self.batch_size, self.num_frames, self.patches_per_frame, self.feature_dim) + return expanded.reshape(self.batch_size, -1, self.feature_dim) + + def expand_for_computation(self, scale_shift_table: torch.Tensor, batch_size: int, indices: slice = slice(None, None)): + """Compute ada values on compressed per-frame data, then expand spatially.""" + num_ada_params = scale_shift_table.shape[0] + + # No compression - compute directly + if self.patches_per_frame == 1: + num_tokens = self.data.shape[1] + dim_per_param = self.feature_dim // num_ada_params + reshaped = self.data.reshape(batch_size, num_tokens, num_ada_params, dim_per_param)[:, :, indices, :] + table_values = scale_shift_table[indices].unsqueeze(0).unsqueeze(0).to(device=self.data.device, dtype=self.data.dtype) + ada_values = (table_values + reshaped).unbind(dim=2) + return ada_values + + # Compressed: compute on per-frame data then expand spatially + # Reshape: [batch, frames, feature_dim] -> [batch, frames, num_ada_params, dim_per_param] + frame_reshaped = self.data.reshape(batch_size, self.num_frames, num_ada_params, -1)[:, :, indices, :] + table_values = scale_shift_table[indices].unsqueeze(0).unsqueeze(0).to( + device=self.data.device, dtype=self.data.dtype + ) + frame_ada = (table_values + frame_reshaped).unbind(dim=2) + + # Expand each ada parameter spatially: [batch, frames, dim] -> [batch, 
frames, patches, dim] -> [batch, tokens, dim] + return tuple( + frame_val.unsqueeze(2).expand(batch_size, self.num_frames, self.patches_per_frame, -1) + .reshape(batch_size, -1, frame_val.shape[-1]) + for frame_val in frame_ada + ) + class BasicAVTransformerBlock(nn.Module): def __init__( self, @@ -119,6 +182,9 @@ class BasicAVTransformerBlock(nn.Module): def get_ada_values( self, scale_shift_table: torch.Tensor, batch_size: int, timestep: torch.Tensor, indices: slice = slice(None, None) ): + if isinstance(timestep, CompressedTimestep): + return timestep.expand_for_computation(scale_shift_table, batch_size, indices) + num_ada_params = scale_shift_table.shape[0] ada_values = ( @@ -146,10 +212,7 @@ class BasicAVTransformerBlock(nn.Module): gate_timestep, ) - scale_shift_chunks = [t.squeeze(2) for t in scale_shift_ada_values] - gate_ada_values = [t.squeeze(2) for t in gate_ada_values] - - return (*scale_shift_chunks, *gate_ada_values) + return (*scale_shift_ada_values, *gate_ada_values) def forward( self, @@ -543,72 +606,80 @@ class LTXAVModel(LTXVModel): if grid_mask is not None: timestep = timestep[:, grid_mask] - timestep = timestep * self.timestep_scale_multiplier + timestep_scaled = timestep * self.timestep_scale_multiplier + v_timestep, v_embedded_timestep = self.adaln_single( - timestep.flatten(), + timestep_scaled.flatten(), {"resolution": None, "aspect_ratio": None}, batch_size=batch_size, hidden_dtype=hidden_dtype, ) - # Second dimension is 1 or number of tokens (if timestep_per_token) - v_timestep = v_timestep.view(batch_size, -1, v_timestep.shape[-1]) - v_embedded_timestep = v_embedded_timestep.view( - batch_size, -1, v_embedded_timestep.shape[-1] - ) + # Calculate patches_per_frame from orig_shape: [batch, channels, frames, height, width] + # Video tokens are arranged as (frames * height * width), so patches_per_frame = height * width + orig_shape = kwargs.get("orig_shape") + v_patches_per_frame = None + if orig_shape is not None and len(orig_shape) == 5: + # orig_shape[3] = height, orig_shape[4] = width (in latent space) + v_patches_per_frame = orig_shape[3] * orig_shape[4] + + # Reshape to [batch_size, num_tokens, dim] and compress for storage + v_timestep = CompressedTimestep(v_timestep.view(batch_size, -1, v_timestep.shape[-1]), v_patches_per_frame) + v_embedded_timestep = CompressedTimestep(v_embedded_timestep.view(batch_size, -1, v_embedded_timestep.shape[-1]), v_patches_per_frame) # Prepare audio timestep a_timestep = kwargs.get("a_timestep") if a_timestep is not None: - a_timestep = a_timestep * self.timestep_scale_multiplier + a_timestep_scaled = a_timestep * self.timestep_scale_multiplier + a_timestep_flat = a_timestep_scaled.flatten() + timestep_flat = timestep_scaled.flatten() av_ca_factor = self.av_ca_timestep_scale_multiplier / self.timestep_scale_multiplier + # Cross-attention timesteps - compress these too av_ca_audio_scale_shift_timestep, _ = self.av_ca_audio_scale_shift_adaln_single( - a_timestep.flatten(), + a_timestep_flat, {"resolution": None, "aspect_ratio": None}, batch_size=batch_size, hidden_dtype=hidden_dtype, ) av_ca_video_scale_shift_timestep, _ = self.av_ca_video_scale_shift_adaln_single( - timestep.flatten(), + timestep_flat, {"resolution": None, "aspect_ratio": None}, batch_size=batch_size, hidden_dtype=hidden_dtype, ) av_ca_a2v_gate_noise_timestep, _ = self.av_ca_a2v_gate_adaln_single( - timestep.flatten() * av_ca_factor, + timestep_flat * av_ca_factor, {"resolution": None, "aspect_ratio": None}, batch_size=batch_size, hidden_dtype=hidden_dtype, ) 
av_ca_v2a_gate_noise_timestep, _ = self.av_ca_v2a_gate_adaln_single( - a_timestep.flatten() * av_ca_factor, + a_timestep_flat * av_ca_factor, {"resolution": None, "aspect_ratio": None}, batch_size=batch_size, hidden_dtype=hidden_dtype, ) + # Compress cross-attention timesteps (only video side, audio is too small to benefit) + cross_av_timestep_ss = [ + av_ca_audio_scale_shift_timestep.view(batch_size, -1, av_ca_audio_scale_shift_timestep.shape[-1]), + CompressedTimestep(av_ca_video_scale_shift_timestep.view(batch_size, -1, av_ca_video_scale_shift_timestep.shape[-1]), v_patches_per_frame), # video - compressed + CompressedTimestep(av_ca_a2v_gate_noise_timestep.view(batch_size, -1, av_ca_a2v_gate_noise_timestep.shape[-1]), v_patches_per_frame), # video - compressed + av_ca_v2a_gate_noise_timestep.view(batch_size, -1, av_ca_v2a_gate_noise_timestep.shape[-1]), + ] + a_timestep, a_embedded_timestep = self.audio_adaln_single( - a_timestep.flatten(), + a_timestep_flat, {"resolution": None, "aspect_ratio": None}, batch_size=batch_size, hidden_dtype=hidden_dtype, ) + # Audio timesteps a_timestep = a_timestep.view(batch_size, -1, a_timestep.shape[-1]) - a_embedded_timestep = a_embedded_timestep.view( - batch_size, -1, a_embedded_timestep.shape[-1] - ) - cross_av_timestep_ss = [ - av_ca_audio_scale_shift_timestep, - av_ca_video_scale_shift_timestep, - av_ca_a2v_gate_noise_timestep, - av_ca_v2a_gate_noise_timestep, - ] - cross_av_timestep_ss = list( - [t.view(batch_size, -1, t.shape[-1]) for t in cross_av_timestep_ss] - ) + a_embedded_timestep = a_embedded_timestep.view(batch_size, -1, a_embedded_timestep.shape[-1]) else: - a_timestep = timestep + a_timestep = timestep_scaled a_embedded_timestep = kwargs.get("embedded_timestep") cross_av_timestep_ss = [] @@ -767,6 +838,11 @@ class LTXAVModel(LTXVModel): ax = x[1] v_embedded_timestep = embedded_timestep[0] a_embedded_timestep = embedded_timestep[1] + + # Expand compressed video timestep if needed + if isinstance(v_embedded_timestep, CompressedTimestep): + v_embedded_timestep = v_embedded_timestep.expand() + vx = super()._process_output(vx, v_embedded_timestep, keyframe_idxs, **kwargs) # Process audio output From c2b65e2fceea821276c143ad52478552633922bf Mon Sep 17 00:00:00 2001 From: ComfyUI Wiki Date: Tue, 13 Jan 2026 06:29:25 +0800 Subject: [PATCH 114/181] Update workflow templates to v0.8.0 (#11828) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 6c1cd86d2..890070d5d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.36.13 -comfyui-workflow-templates==0.7.69 +comfyui-workflow-templates==0.8.0 comfyui-embedded-docs==0.4.0 torch torchsde From ecaeeb990d7a5c3820b6f2373d04335d051d6b47 Mon Sep 17 00:00:00 2001 From: ComfyUI Wiki Date: Tue, 13 Jan 2026 11:18:01 +0800 Subject: [PATCH 115/181] chore: update workflow templates to v0.8.4 (#11835) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 890070d5d..077c8930a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.36.13 -comfyui-workflow-templates==0.8.0 +comfyui-workflow-templates==0.8.4 comfyui-embedded-docs==0.4.0 torch torchsde From b3c0e4de57bfd27e3dd94bd9723bb4c714668a09 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 12 Jan 2026 19:33:54 -0800 Subject: [PATCH 116/181] Make loras work on nvfp4 
models. (#11837) The initial applying is a bit slow but will probably be sped up in the future. --- comfy/float.py | 113 +++++++++++++++++++++++++++++++++++++++++++++ comfy/ops.py | 2 +- comfy/quant_ops.py | 37 ++++++++++++++- requirements.txt | 2 +- 4 files changed, 150 insertions(+), 4 deletions(-) diff --git a/comfy/float.py b/comfy/float.py index 521316fd2..e638b1ff7 100644 --- a/comfy/float.py +++ b/comfy/float.py @@ -65,3 +65,116 @@ def stochastic_rounding(value, dtype, seed=0): return output return value.to(dtype=dtype) + + +# TODO: improve this? +def stochastic_float_to_fp4_e2m1(x, generator): + sign = torch.signbit(x).to(torch.uint8) + x_abs = x.abs() + + exp = torch.floor(torch.log2(x_abs) + 1.0).clamp(0, 3) + x += (torch.rand(x.size(), dtype=x.dtype, layout=x.layout, device=x.device, generator=generator) - 0.5) * (2 ** (exp - 2.0)) * 1.25 + + x_abs = x.abs() + exp = torch.floor(torch.log2(x_abs) + 1.1925).clamp(0, 3) + + mantissa = torch.where( + exp > 0, + (x_abs / (2.0 ** (exp - 1)) - 1.0) * 2.0, + (x_abs * 2.0) + ).round().to(torch.uint8) + + fp4 = (sign << 3) | (exp.to(torch.uint8) << 1) | mantissa + + fp4_flat = fp4.view(-1) + packed = (fp4_flat[0::2] << 4) | fp4_flat[1::2] + return packed.reshape(list(x.shape)[:-1] + [-1]) + + +def to_blocked(input_matrix, flatten: bool = True) -> torch.Tensor: + """ + Rearrange a large matrix by breaking it into blocks and applying the rearrangement pattern. + See: + https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout + + Args: + input_matrix: Input tensor of shape (H, W) + Returns: + Rearranged tensor of shape (32*ceil_div(H,128), 16*ceil_div(W,4)) + """ + + def ceil_div(a, b): + return (a + b - 1) // b + + rows, cols = input_matrix.shape + n_row_blocks = ceil_div(rows, 128) + n_col_blocks = ceil_div(cols, 4) + + # Calculate the padded shape + padded_rows = n_row_blocks * 128 + padded_cols = n_col_blocks * 4 + + padded = input_matrix + if (rows, cols) != (padded_rows, padded_cols): + padded = torch.zeros( + (padded_rows, padded_cols), + device=input_matrix.device, + dtype=input_matrix.dtype, + ) + padded[:rows, :cols] = input_matrix + + # Rearrange the blocks + blocks = padded.view(n_row_blocks, 128, n_col_blocks, 4).permute(0, 2, 1, 3) + rearranged = blocks.reshape(-1, 4, 32, 4).transpose(1, 2).reshape(-1, 32, 16) + if flatten: + return rearranged.flatten() + + return rearranged.reshape(padded_rows, padded_cols) + + +def stochastic_round_quantize_nvfp4(x, per_tensor_scale, pad_16x, seed=0): + F4_E2M1_MAX = 6.0 + F8_E4M3_MAX = 448.0 + + def roundup(x: int, multiple: int) -> int: + """Round up x to the nearest multiple.""" + return ((x + multiple - 1) // multiple) * multiple + + orig_shape = x.shape + + # Handle padding + if pad_16x: + rows, cols = x.shape + padded_rows = roundup(rows, 16) + padded_cols = roundup(cols, 16) + if padded_rows != rows or padded_cols != cols: + x = torch.nn.functional.pad(x, (0, padded_cols - cols, 0, padded_rows - rows)) + # Note: We update orig_shape because the output tensor logic below assumes x.shape matches + # what we want to produce. If we pad here, we want the padded output. 
+ orig_shape = x.shape + + block_size = 16 + + x = x.reshape(orig_shape[0], -1, block_size) + max_abs = torch.amax(torch.abs(x), dim=-1) + block_scale = max_abs / F4_E2M1_MAX + scaled_block_scales = block_scale / per_tensor_scale.to(block_scale.dtype) + scaled_block_scales_fp8 = torch.clamp(scaled_block_scales, max=F8_E4M3_MAX).to(torch.float8_e4m3fn) + total_scale = per_tensor_scale.to(x.dtype) * scaled_block_scales_fp8.to(x.dtype) + + # Handle zero blocks (from padding): avoid 0/0 NaN + zero_scale_mask = (total_scale == 0) + total_scale_safe = torch.where(zero_scale_mask, torch.ones_like(total_scale), total_scale) + + x = x / total_scale_safe.unsqueeze(-1) + + generator = torch.Generator(device=x.device) + generator.manual_seed(seed) + + x = torch.where(zero_scale_mask.unsqueeze(-1), torch.zeros_like(x), x) + + x = x.view(orig_shape) + data_lp = stochastic_float_to_fp4_e2m1(x, generator=generator) + + blocked_scales = to_blocked(scaled_block_scales_fp8, flatten=False) + return data_lp, blocked_scales diff --git a/comfy/ops.py b/comfy/ops.py index 9c0b54ff4..415c39e92 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -699,7 +699,7 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False, **kwargs): if getattr(self, 'layout_type', None) is not None: # dtype is now implicit in the layout class - weight = QuantizedTensor.from_float(weight, self.layout_type, scale="recalculate", stochastic_rounding=seed, inplace_ops=True) + weight = QuantizedTensor.from_float(weight, self.layout_type, scale="recalculate", stochastic_rounding=seed, inplace_ops=True).to(self.weight.dtype) else: weight = weight.to(self.weight.dtype) if return_weight: diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py index 8324be42a..7a61203c3 100644 --- a/comfy/quant_ops.py +++ b/comfy/quant_ops.py @@ -7,7 +7,7 @@ try: QuantizedTensor, QuantizedLayout, TensorCoreFP8Layout as _CKFp8Layout, - TensorCoreNVFP4Layout, # Direct import, no wrapper needed + TensorCoreNVFP4Layout as _CKNvfp4Layout, register_layout_op, register_layout_class, get_layout_class, @@ -34,7 +34,7 @@ except ImportError as e: class _CKFp8Layout: pass - class TensorCoreNVFP4Layout: + class _CKNvfp4Layout: pass def register_layout_class(name, cls): @@ -84,6 +84,39 @@ class _TensorCoreFP8LayoutBase(_CKFp8Layout): return qdata, params +class TensorCoreNVFP4Layout(_CKNvfp4Layout): + @classmethod + def quantize(cls, tensor, scale=None, stochastic_rounding=0, inplace_ops=False): + if tensor.dim() != 2: + raise ValueError(f"NVFP4 requires 2D tensor, got {tensor.dim()}D") + + orig_dtype = tensor.dtype + orig_shape = tuple(tensor.shape) + + if scale is None or (isinstance(scale, str) and scale == "recalculate"): + scale = torch.amax(tensor.abs()) / (ck.float_utils.F8_E4M3_MAX * ck.float_utils.F4_E2M1_MAX) + + if not isinstance(scale, torch.Tensor): + scale = torch.tensor(scale) + scale = scale.to(device=tensor.device, dtype=torch.float32) + + padded_shape = cls.get_padded_shape(orig_shape) + needs_padding = padded_shape != orig_shape + + if stochastic_rounding > 0: + qdata, block_scale = comfy.float.stochastic_round_quantize_nvfp4(tensor, scale, pad_16x=needs_padding, seed=stochastic_rounding) + else: + qdata, block_scale = ck.quantize_nvfp4(tensor, scale, pad_16x=needs_padding) + + params = cls.Params( + scale=scale, + orig_dtype=orig_dtype, + orig_shape=orig_shape, + block_scale=block_scale, + ) + return qdata, params + + class 
TensorCoreFP8E4M3Layout(_TensorCoreFP8LayoutBase): FP8_DTYPE = torch.float8_e4m3fn diff --git a/requirements.txt b/requirements.txt index 077c8930a..43737056e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -21,7 +21,7 @@ psutil alembic SQLAlchemy av>=14.2.0 -comfy-kitchen>=0.2.5 +comfy-kitchen>=0.2.6 #non essential dependencies: kornia>=0.7.1 From 117e7a5853cc34b6a012e06bb3efcc79ab314184 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 12 Jan 2026 21:01:52 -0800 Subject: [PATCH 117/181] Refactor to try to lower mem usage. (#11840) --- comfy/float.py | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/comfy/float.py b/comfy/float.py index e638b1ff7..c806af76b 100644 --- a/comfy/float.py +++ b/comfy/float.py @@ -69,26 +69,31 @@ def stochastic_rounding(value, dtype, seed=0): # TODO: improve this? def stochastic_float_to_fp4_e2m1(x, generator): + orig_shape = x.shape sign = torch.signbit(x).to(torch.uint8) - x_abs = x.abs() - exp = torch.floor(torch.log2(x_abs) + 1.0).clamp(0, 3) + exp = torch.floor(torch.log2(x.abs()) + 1.0).clamp(0, 3) x += (torch.rand(x.size(), dtype=x.dtype, layout=x.layout, device=x.device, generator=generator) - 0.5) * (2 ** (exp - 2.0)) * 1.25 - x_abs = x.abs() - exp = torch.floor(torch.log2(x_abs) + 1.1925).clamp(0, 3) + x = x.abs() + exp = torch.floor(torch.log2(x) + 1.1925).clamp(0, 3) mantissa = torch.where( exp > 0, - (x_abs / (2.0 ** (exp - 1)) - 1.0) * 2.0, - (x_abs * 2.0) + (x / (2.0 ** (exp - 1)) - 1.0) * 2.0, + (x * 2.0), + out=x ).round().to(torch.uint8) + del x - fp4 = (sign << 3) | (exp.to(torch.uint8) << 1) | mantissa + exp = exp.to(torch.uint8) + + fp4 = (sign << 3) | (exp << 1) | mantissa + del sign, exp, mantissa fp4_flat = fp4.view(-1) packed = (fp4_flat[0::2] << 4) | fp4_flat[1::2] - return packed.reshape(list(x.shape)[:-1] + [-1]) + return packed.reshape(list(orig_shape)[:-1] + [-1]) def to_blocked(input_matrix, flatten: bool = True) -> torch.Tensor: From acd0e536533cdf038bbaa32730cd12a438cc3a60 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Mon, 12 Jan 2026 21:15:24 -0800 Subject: [PATCH 118/181] Make bulk_ops not use .returning to be compatible with python 3.10 and 3.11 sqlalchemy (#11839) --- app/assets/database/bulk_ops.py | 30 +++++++++++++++++++++++------- 1 file changed, 23 insertions(+), 7 deletions(-) diff --git a/app/assets/database/bulk_ops.py b/app/assets/database/bulk_ops.py index 9352cd65d..c7b75290a 100644 --- a/app/assets/database/bulk_ops.py +++ b/app/assets/database/bulk_ops.py @@ -92,14 +92,23 @@ def seed_from_paths_batch( session.execute(ins_asset, chunk) # try to claim AssetCacheState (file_path) - winners_by_path: set[str] = set() + # Insert with ON CONFLICT DO NOTHING, then query to find which paths were actually inserted ins_state = ( sqlite.insert(AssetCacheState) .on_conflict_do_nothing(index_elements=[AssetCacheState.file_path]) - .returning(AssetCacheState.file_path) ) for chunk in _iter_chunks(state_rows, _rows_per_stmt(3)): - winners_by_path.update((session.execute(ins_state, chunk)).scalars().all()) + session.execute(ins_state, chunk) + + # Query to find which of our paths won (were actually inserted) + winners_by_path: set[str] = set() + for chunk in _iter_chunks(path_list, MAX_BIND_PARAMS): + result = session.execute( + sqlalchemy.select(AssetCacheState.file_path) + .where(AssetCacheState.file_path.in_(chunk)) + .where(AssetCacheState.asset_id.in_([path_to_asset[p] for p in chunk])) + ) + 
winners_by_path.update(result.scalars().all()) all_paths_set = set(path_list) losers_by_path = all_paths_set - winners_by_path @@ -112,16 +121,23 @@ def seed_from_paths_batch( return {"inserted_infos": 0, "won_states": 0, "lost_states": len(losers_by_path)} # insert AssetInfo only for winners + # Insert with ON CONFLICT DO NOTHING, then query to find which were actually inserted winner_info_rows = [asset_to_info[path_to_asset[p]] for p in winners_by_path] ins_info = ( sqlite.insert(AssetInfo) .on_conflict_do_nothing(index_elements=[AssetInfo.asset_id, AssetInfo.owner_id, AssetInfo.name]) - .returning(AssetInfo.id) ) - - inserted_info_ids: set[str] = set() for chunk in _iter_chunks(winner_info_rows, _rows_per_stmt(9)): - inserted_info_ids.update((session.execute(ins_info, chunk)).scalars().all()) + session.execute(ins_info, chunk) + + # Query to find which info rows were actually inserted (by matching our generated IDs) + all_info_ids = [row["id"] for row in winner_info_rows] + inserted_info_ids: set[str] = set() + for chunk in _iter_chunks(all_info_ids, MAX_BIND_PARAMS): + result = session.execute( + sqlalchemy.select(AssetInfo.id).where(AssetInfo.id.in_(chunk)) + ) + inserted_info_ids.update(result.scalars().all()) # build and insert tag + meta rows for the AssetInfo tag_rows: list[dict] = [] From 8af13b439bddaddb6d3b4f7c50b6391e88a10c66 Mon Sep 17 00:00:00 2001 From: Christian Byrne Date: Mon, 12 Jan 2026 22:22:25 -0800 Subject: [PATCH 119/181] Update requirements.txt (#11841) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 43737056e..8650d28ec 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -comfyui-frontend-package==1.36.13 +comfyui-frontend-package==1.36.14 comfyui-workflow-templates==0.8.4 comfyui-embedded-docs==0.4.0 torch From db9e6edfa1604be0b1f738b2f67495b46cee5a8c Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 13 Jan 2026 01:23:31 -0500 Subject: [PATCH 120/181] ComfyUI v0.9.0 --- comfyui_version.py | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/comfyui_version.py b/comfyui_version.py index df82ed4fc..09def2c70 100644 --- a/comfyui_version.py +++ b/comfyui_version.py @@ -1,3 +1,3 @@ # This file is automatically generated by the build process when version is # updated in pyproject.toml. -__version__ = "0.8.2" +__version__ = "0.9.0" diff --git a/pyproject.toml b/pyproject.toml index 49f1a03fd..17aac8c3f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ComfyUI" -version = "0.8.2" +version = "0.9.0" readme = "README.md" license = { file = "LICENSE" } requires-python = ">=3.10" From 1dcbd9efaf16c74a3aff2770f46de5b4aaaf927e Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 12 Jan 2026 22:42:07 -0800 Subject: [PATCH 121/181] Bump ltxav mem estimation a bit. 
(#11842) --- comfy/supported_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/supported_models.py b/comfy/supported_models.py index d44c0bc37..1bf54f13f 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -845,7 +845,7 @@ class LTXAV(LTXV): def __init__(self, unet_config): super().__init__(unet_config) - self.memory_usage_factor = 0.061 # TODO + self.memory_usage_factor = 0.077 # TODO def get_model(self, state_dict, prefix="", device=None): out = model_base.LTXAV(self, device=device) From 5ac13725331c1dfdf7aab977d4588b0a06a3debd Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 13 Jan 2026 01:44:06 -0500 Subject: [PATCH 122/181] ComfyUI v0.9.1 --- comfyui_version.py | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/comfyui_version.py b/comfyui_version.py index 09def2c70..0c9871e35 100644 --- a/comfyui_version.py +++ b/comfyui_version.py @@ -1,3 +1,3 @@ # This file is automatically generated by the build process when version is # updated in pyproject.toml. -__version__ = "0.9.0" +__version__ = "0.9.1" diff --git a/pyproject.toml b/pyproject.toml index 17aac8c3f..dc52218b4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ComfyUI" -version = "0.9.0" +version = "0.9.1" readme = "README.md" license = { file = "LICENSE" } requires-python = ">=3.10" From c543ad81c382c8450d2c8de62644c197c3c7416d Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Tue, 13 Jan 2026 18:30:13 +0200 Subject: [PATCH 123/181] fix(api-nodes-gemini): raise exception when no candidates due to safety block (#11848) --- comfy_api_nodes/nodes_gemini.py | 28 +++++++++++++++++++--------- 1 file changed, 19 insertions(+), 9 deletions(-) diff --git a/comfy_api_nodes/nodes_gemini.py b/comfy_api_nodes/nodes_gemini.py index e8ed7e797..35bbf0d2f 100644 --- a/comfy_api_nodes/nodes_gemini.py +++ b/comfy_api_nodes/nodes_gemini.py @@ -130,7 +130,7 @@ def get_parts_by_type(response: GeminiGenerateContentResponse, part_type: Litera Returns: List of response parts matching the requested type. """ - if response.candidates is None: + if not response.candidates: if response.promptFeedback and response.promptFeedback.blockReason: feedback = response.promptFeedback raise ValueError( @@ -141,14 +141,24 @@ def get_parts_by_type(response: GeminiGenerateContentResponse, part_type: Litera "try changing it to `IMAGE+TEXT` to view the model's reasoning and understand why image generation failed." 
) parts = [] - for part in response.candidates[0].content.parts: - if part_type == "text" and part.text: - parts.append(part) - elif part.inlineData and part.inlineData.mimeType == part_type: - parts.append(part) - elif part.fileData and part.fileData.mimeType == part_type: - parts.append(part) - # Skip parts that don't match the requested type + blocked_reasons = [] + for candidate in response.candidates: + if candidate.finishReason and candidate.finishReason.upper() == "IMAGE_PROHIBITED_CONTENT": + blocked_reasons.append(candidate.finishReason) + continue + if candidate.content is None or candidate.content.parts is None: + continue + for part in candidate.content.parts: + if part_type == "text" and part.text: + parts.append(part) + elif part.inlineData and part.inlineData.mimeType == part_type: + parts.append(part) + elif part.fileData and part.fileData.mimeType == part_type: + parts.append(part) + + if not parts and blocked_reasons: + raise ValueError(f"Gemini API blocked the request. Reasons: {blocked_reasons}") + return parts From d9dc02a7d602a1918b9dabfc91890e6689f6f16d Mon Sep 17 00:00:00 2001 From: Acly Date: Tue, 13 Jan 2026 21:03:53 +0100 Subject: [PATCH 124/181] Support "lite" version of alibaba-pai Z-Image Controlnet (#11849) * reduced number of control layers (3) compared to full model --- comfy_extras/nodes_model_patch.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/comfy_extras/nodes_model_patch.py b/comfy_extras/nodes_model_patch.py index 1355b3c93..f66d28fc9 100644 --- a/comfy_extras/nodes_model_patch.py +++ b/comfy_extras/nodes_model_patch.py @@ -244,6 +244,10 @@ class ModelPatchLoader: elif 'control_all_x_embedder.2-1.weight' in sd: # alipai z image fun controlnet sd = z_image_convert(sd) config = {} + if 'control_layers.4.adaLN_modulation.0.weight' not in sd: + config['n_control_layers'] = 3 + config['additional_in_dim'] = 17 + config['refiner_control'] = True if 'control_layers.14.adaLN_modulation.0.weight' in sd: config['n_control_layers'] = 15 config['additional_in_dim'] = 17 From e4b4fb34798a4710f670c81ae905ec24d58b6373 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jukka=20Sepp=C3=A4nen?= <40791699+kijai@users.noreply.github.com> Date: Wed, 14 Jan 2026 00:37:21 +0200 Subject: [PATCH 125/181] Load metadata on VAELoader (#11846) Needed to load the proper LTX2 VAE if separated from checkpoint --- nodes.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/nodes.py b/nodes.py index 5a9d42d4a..90c5f2a6e 100644 --- a/nodes.py +++ b/nodes.py @@ -798,8 +798,8 @@ class VAELoader: vae_path = folder_paths.get_full_path_or_raise("vae_approx", vae_name) else: vae_path = folder_paths.get_full_path_or_raise("vae", vae_name) - sd = comfy.utils.load_torch_file(vae_path) - vae = comfy.sd.VAE(sd=sd) + sd, metadata = comfy.utils.load_torch_file(vae_path, return_metadata=True) + vae = comfy.sd.VAE(sd=sd, metadata=metadata) vae.throw_exception_if_invalid() return (vae,) From 79f6bb5e4fca0c2fbd5d09511a65449ca69332a8 Mon Sep 17 00:00:00 2001 From: ric-yu Date: Tue, 13 Jan 2026 16:14:40 -0800 Subject: [PATCH 126/181] add blueprints dir for built-in blueprints (#11853) --- app/subgraph_manager.py | 80 +++++++++++++++++++++------------- blueprints/put_blueprints_here | 0 2 files changed, 50 insertions(+), 30 deletions(-) create mode 100644 blueprints/put_blueprints_here diff --git a/app/subgraph_manager.py b/app/subgraph_manager.py index dbe404541..6a8f586a4 100644 --- a/app/subgraph_manager.py +++ b/app/subgraph_manager.py @@ -10,6 +10,7 @@ import hashlib class Source: 
custom_node = "custom_node" + templates = "templates" class SubgraphEntry(TypedDict): source: str @@ -38,6 +39,18 @@ class CustomNodeSubgraphEntryInfo(TypedDict): class SubgraphManager: def __init__(self): self.cached_custom_node_subgraphs: dict[SubgraphEntry] | None = None + self.cached_blueprint_subgraphs: dict[SubgraphEntry] | None = None + + def _create_entry(self, file: str, source: str, node_pack: str) -> tuple[str, SubgraphEntry]: + """Create a subgraph entry from a file path. Expects normalized path (forward slashes).""" + entry_id = hashlib.sha256(f"{source}{file}".encode()).hexdigest() + entry: SubgraphEntry = { + "source": source, + "name": os.path.splitext(os.path.basename(file))[0], + "path": file, + "info": {"node_pack": node_pack}, + } + return entry_id, entry async def load_entry_data(self, entry: SubgraphEntry): with open(entry['path'], 'r') as f: @@ -60,53 +73,60 @@ class SubgraphManager: return entries async def get_custom_node_subgraphs(self, loadedModules, force_reload=False): - # if not forced to reload and cached, return cache + """Load subgraphs from custom nodes.""" if not force_reload and self.cached_custom_node_subgraphs is not None: return self.cached_custom_node_subgraphs - # Load subgraphs from custom nodes - subfolder = "subgraphs" - subgraphs_dict: dict[SubgraphEntry] = {} + subgraphs_dict: dict[SubgraphEntry] = {} for folder in folder_paths.get_folder_paths("custom_nodes"): - pattern = os.path.join(folder, f"*/{subfolder}/*.json") - matched_files = glob.glob(pattern) - for file in matched_files: - # replace backslashes with forward slashes + pattern = os.path.join(folder, "*/subgraphs/*.json") + for file in glob.glob(pattern): file = file.replace('\\', '/') - info: CustomNodeSubgraphEntryInfo = { - "node_pack": "custom_nodes." + file.split('/')[-3] - } - source = Source.custom_node - # hash source + path to make sure id will be as unique as possible, but - # reproducible across backend reloads - id = hashlib.sha256(f"{source}{file}".encode()).hexdigest() - entry: SubgraphEntry = { - "source": Source.custom_node, - "name": os.path.splitext(os.path.basename(file))[0], - "path": file, - "info": info, - } - subgraphs_dict[id] = entry + node_pack = "custom_nodes." 
+ file.split('/')[-3] + entry_id, entry = self._create_entry(file, Source.custom_node, node_pack) + subgraphs_dict[entry_id] = entry + self.cached_custom_node_subgraphs = subgraphs_dict return subgraphs_dict - async def get_custom_node_subgraph(self, id: str, loadedModules): - subgraphs = await self.get_custom_node_subgraphs(loadedModules) - entry: SubgraphEntry = subgraphs.get(id, None) - if entry is not None and entry.get('data', None) is None: + async def get_blueprint_subgraphs(self, force_reload=False): + """Load subgraphs from the blueprints directory.""" + if not force_reload and self.cached_blueprint_subgraphs is not None: + return self.cached_blueprint_subgraphs + + subgraphs_dict: dict[SubgraphEntry] = {} + blueprints_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'blueprints') + + if os.path.exists(blueprints_dir): + for file in glob.glob(os.path.join(blueprints_dir, "*.json")): + file = file.replace('\\', '/') + entry_id, entry = self._create_entry(file, Source.templates, "comfyui") + subgraphs_dict[entry_id] = entry + + self.cached_blueprint_subgraphs = subgraphs_dict + return subgraphs_dict + + async def get_all_subgraphs(self, loadedModules, force_reload=False): + """Get all subgraphs from all sources (custom nodes and blueprints).""" + custom_node_subgraphs = await self.get_custom_node_subgraphs(loadedModules, force_reload) + blueprint_subgraphs = await self.get_blueprint_subgraphs(force_reload) + return {**custom_node_subgraphs, **blueprint_subgraphs} + + async def get_subgraph(self, id: str, loadedModules): + """Get a specific subgraph by ID from any source.""" + entry = (await self.get_all_subgraphs(loadedModules)).get(id) + if entry is not None and entry.get('data') is None: await self.load_entry_data(entry) return entry def add_routes(self, routes, loadedModules): @routes.get("/global_subgraphs") async def get_global_subgraphs(request): - subgraphs_dict = await self.get_custom_node_subgraphs(loadedModules) - # NOTE: we may want to include other sources of global subgraphs such as templates in the future; - # that's the reasoning for the current implementation + subgraphs_dict = await self.get_all_subgraphs(loadedModules) return web.json_response(await self.sanitize_entries(subgraphs_dict, remove_data=True)) @routes.get("/global_subgraphs/{id}") async def get_global_subgraph(request): id = request.match_info.get("id", None) - subgraph = await self.get_custom_node_subgraph(id, loadedModules) + subgraph = await self.get_subgraph(id, loadedModules) return web.json_response(await self.sanitize_entry(subgraph)) diff --git a/blueprints/put_blueprints_here b/blueprints/put_blueprints_here new file mode 100644 index 000000000..e69de29bb From 1419047fdbdf26b2311950c041a86fd998a2acbd Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Wed, 14 Jan 2026 02:18:28 +0200 Subject: [PATCH 127/181] [Api Nodes]: Improve Price Badge Declarations (#11582) * api nodes: price badges moved to nodes code * added price badges for 4 more node-packs * added price badges for 10 more node-packs * added new price badges for Omni STD mode * add support for autogrow groups * use full names for "widgets", "inputs" and "groups" * add strict typing for JSONata rules * add price badge for WanReferenceVideoApi node * add support for DynamicCombo * sync price badges changes (https://github.com/Comfy-Org/ComfyUI_frontend/pull/7900) * sync badges for Vidu2 nodes * fixed incorrect price for RecraftCrispUpscaleNode * fixed incorrect price badges for 
LTXV nodes * fixed price badge for MinimaxHailuoVideoNode * fixed price badges for PixVerse nodes --- comfy_api/latest/_io.py | 79 +++++++++- comfy_api_nodes/nodes_bfl.py | 44 ++++++ comfy_api_nodes/nodes_bytedance.py | 66 +++++++++ comfy_api_nodes/nodes_gemini.py | 40 ++++++ comfy_api_nodes/nodes_ideogram.py | 42 +++++- comfy_api_nodes/nodes_kling.py | 215 ++++++++++++++++++++++++++++ comfy_api_nodes/nodes_ltxv.py | 18 +++ comfy_api_nodes/nodes_luma.py | 76 ++++++++++ comfy_api_nodes/nodes_minimax.py | 20 +++ comfy_api_nodes/nodes_moonvalley.py | 12 ++ comfy_api_nodes/nodes_openai.py | 127 ++++++++++++++++ comfy_api_nodes/nodes_pixverse.py | 30 ++++ comfy_api_nodes/nodes_recraft.py | 32 +++++ comfy_api_nodes/nodes_rodin.py | 12 ++ comfy_api_nodes/nodes_runway.py | 15 ++ comfy_api_nodes/nodes_sora.py | 18 +++ comfy_api_nodes/nodes_stability.py | 31 ++++ comfy_api_nodes/nodes_tripo.py | 164 +++++++++++++++++++++ comfy_api_nodes/nodes_veo2.py | 42 ++++++ comfy_api_nodes/nodes_vidu.py | 100 +++++++++++++ comfy_api_nodes/nodes_wan.py | 43 ++++++ 21 files changed, 1221 insertions(+), 5 deletions(-) diff --git a/comfy_api/latest/_io.py b/comfy_api/latest/_io.py index 50143ff53..e6a0d1821 100644 --- a/comfy_api/latest/_io.py +++ b/comfy_api/latest/_io.py @@ -1225,6 +1225,7 @@ class NodeInfoV1: deprecated: bool=None experimental: bool=None api_node: bool=None + price_badge: dict | None = None @dataclass class NodeInfoV3: @@ -1234,11 +1235,77 @@ class NodeInfoV3: name: str=None display_name: str=None description: str=None + python_module: Any = None category: str=None output_node: bool=None deprecated: bool=None experimental: bool=None api_node: bool=None + price_badge: dict | None = None + + +@dataclass +class PriceBadgeDepends: + widgets: list[str] = field(default_factory=list) + inputs: list[str] = field(default_factory=list) + input_groups: list[str] = field(default_factory=list) + + def validate(self) -> None: + if not isinstance(self.widgets, list) or any(not isinstance(x, str) for x in self.widgets): + raise ValueError("PriceBadgeDepends.widgets must be a list[str].") + if not isinstance(self.inputs, list) or any(not isinstance(x, str) for x in self.inputs): + raise ValueError("PriceBadgeDepends.inputs must be a list[str].") + if not isinstance(self.input_groups, list) or any(not isinstance(x, str) for x in self.input_groups): + raise ValueError("PriceBadgeDepends.input_groups must be a list[str].") + + def as_dict(self, schema_inputs: list["Input"]) -> dict[str, Any]: + # Build lookup: widget_id -> io_type + input_types: dict[str, str] = {} + for inp in schema_inputs: + all_inputs = inp.get_all() + input_types[inp.id] = inp.get_io_type() # First input is always the parent itself + for nested_inp in all_inputs[1:]: + # For DynamicCombo/DynamicSlot, nested inputs are prefixed with parent ID + # to match frontend naming convention (e.g., "should_texture.enable_pbr") + prefixed_id = f"{inp.id}.{nested_inp.id}" + input_types[prefixed_id] = nested_inp.get_io_type() + + # Enrich widgets with type information, raising error for unknown widgets + widgets_data: list[dict[str, str]] = [] + for w in self.widgets: + if w not in input_types: + raise ValueError( + f"PriceBadge depends_on.widgets references unknown widget '{w}'. 
" + f"Available widgets: {list(input_types.keys())}" + ) + widgets_data.append({"name": w, "type": input_types[w]}) + + return { + "widgets": widgets_data, + "inputs": self.inputs, + "input_groups": self.input_groups, + } + + +@dataclass +class PriceBadge: + expr: str + depends_on: PriceBadgeDepends = field(default_factory=PriceBadgeDepends) + engine: str = field(default="jsonata") + + def validate(self) -> None: + if self.engine != "jsonata": + raise ValueError(f"Unsupported PriceBadge.engine '{self.engine}'. Only 'jsonata' is supported.") + if not isinstance(self.expr, str) or not self.expr.strip(): + raise ValueError("PriceBadge.expr must be a non-empty string.") + self.depends_on.validate() + + def as_dict(self, schema_inputs: list["Input"]) -> dict[str, Any]: + return { + "engine": self.engine, + "depends_on": self.depends_on.as_dict(schema_inputs), + "expr": self.expr, + } @dataclass @@ -1284,6 +1351,8 @@ class Schema: """Flags a node as experimental, informing users that it may change or not work as expected.""" is_api_node: bool=False """Flags a node as an API node. See: https://docs.comfy.org/tutorials/api-nodes/overview.""" + price_badge: PriceBadge | None = None + """Optional client-evaluated pricing badge declaration for this node.""" not_idempotent: bool=False """Flags a node as not idempotent; when True, the node will run and not reuse the cached outputs when identical inputs are provided on a different node in the graph.""" enable_expand: bool=False @@ -1314,6 +1383,8 @@ class Schema: input.validate() for output in self.outputs: output.validate() + if self.price_badge is not None: + self.price_badge.validate() def finalize(self): """Add hidden based on selected schema options, and give outputs without ids default ids.""" @@ -1387,7 +1458,8 @@ class Schema: deprecated=self.is_deprecated, experimental=self.is_experimental, api_node=self.is_api_node, - python_module=getattr(cls, "RELATIVE_PYTHON_MODULE", "nodes") + python_module=getattr(cls, "RELATIVE_PYTHON_MODULE", "nodes"), + price_badge=self.price_badge.as_dict(self.inputs) if self.price_badge is not None else None, ) return info @@ -1419,7 +1491,8 @@ class Schema: deprecated=self.is_deprecated, experimental=self.is_experimental, api_node=self.is_api_node, - python_module=getattr(cls, "RELATIVE_PYTHON_MODULE", "nodes") + python_module=getattr(cls, "RELATIVE_PYTHON_MODULE", "nodes"), + price_badge=self.price_badge.as_dict(self.inputs) if self.price_badge is not None else None, ) return info @@ -1971,4 +2044,6 @@ __all__ = [ "add_to_dict_v3", "V3Data", "ImageCompare", + "PriceBadgeDepends", + "PriceBadge", ] diff --git a/comfy_api_nodes/nodes_bfl.py b/comfy_api_nodes/nodes_bfl.py index ce077d6b3..76021ef7f 100644 --- a/comfy_api_nodes/nodes_bfl.py +++ b/comfy_api_nodes/nodes_bfl.py @@ -97,6 +97,9 @@ class FluxProUltraImageNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.06}""", + ), ) @classmethod @@ -352,6 +355,9 @@ class FluxProExpandNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.05}""", + ), ) @classmethod @@ -458,6 +464,9 @@ class FluxProFillNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.05}""", + ), ) @classmethod @@ -511,6 +520,21 @@ class Flux2ProImageNode(IO.ComfyNode): NODE_ID = "Flux2ProImageNode" DISPLAY_NAME = "Flux.2 [pro] Image" API_ENDPOINT = "/proxy/bfl/flux-2-pro/generate" + 
PRICE_BADGE_EXPR = """ + ( + $MP := 1024 * 1024; + $outMP := $max([1, $floor(((widgets.width * widgets.height) + $MP - 1) / $MP)]); + $outputCost := 0.03 + 0.015 * ($outMP - 1); + inputs.images.connected + ? { + "type":"range_usd", + "min_usd": $outputCost + 0.015, + "max_usd": $outputCost + 0.12, + "format": { "approximate": true } + } + : {"type":"usd","usd": $outputCost} + ) + """ @classmethod def define_schema(cls) -> IO.Schema: @@ -563,6 +587,10 @@ class Flux2ProImageNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["width", "height"], inputs=["images"]), + expr=cls.PRICE_BADGE_EXPR, + ), ) @classmethod @@ -623,6 +651,22 @@ class Flux2MaxImageNode(Flux2ProImageNode): NODE_ID = "Flux2MaxImageNode" DISPLAY_NAME = "Flux.2 [max] Image" API_ENDPOINT = "/proxy/bfl/flux-2-max/generate" + PRICE_BADGE_EXPR = """ + ( + $MP := 1024 * 1024; + $outMP := $max([1, $floor(((widgets.width * widgets.height) + $MP - 1) / $MP)]); + $outputCost := 0.07 + 0.03 * ($outMP - 1); + + inputs.images.connected + ? { + "type":"range_usd", + "min_usd": $outputCost + 0.03, + "max_usd": $outputCost + 0.24, + "format": { "approximate": true } + } + : {"type":"usd","usd": $outputCost} + ) + """ class BFLExtension(ComfyExtension): diff --git a/comfy_api_nodes/nodes_bytedance.py b/comfy_api_nodes/nodes_bytedance.py index d4a2cfae6..f09a4a0ed 100644 --- a/comfy_api_nodes/nodes_bytedance.py +++ b/comfy_api_nodes/nodes_bytedance.py @@ -126,6 +126,9 @@ class ByteDanceImageNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.03}""", + ), ) @classmethod @@ -367,6 +370,19 @@ class ByteDanceSeedreamNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["model"]), + expr=""" + ( + $price := $contains(widgets.model, "seedream-4-5-251128") ? 0.04 : 0.03; + { + "type":"usd", + "usd": $price, + "format": { "suffix":" x images/Run", "approximate": true } + } + ) + """, + ), ) @classmethod @@ -522,6 +538,7 @@ class ByteDanceTextToVideoNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=PRICE_BADGE_VIDEO, ) @classmethod @@ -632,6 +649,7 @@ class ByteDanceImageToVideoNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=PRICE_BADGE_VIDEO, ) @classmethod @@ -754,6 +772,7 @@ class ByteDanceFirstLastFrameNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=PRICE_BADGE_VIDEO, ) @classmethod @@ -877,6 +896,7 @@ class ByteDanceImageReferenceNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=PRICE_BADGE_VIDEO, ) @classmethod @@ -946,6 +966,52 @@ def raise_if_text_params(prompt: str, text_params: list[str]) -> None: ) +PRICE_BADGE_VIDEO = IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["model", "duration", "resolution"]), + expr=""" + ( + $priceByModel := { + "seedance-1-0-pro": { + "480p":[0.23,0.24], + "720p":[0.51,0.56], + "1080p":[1.18,1.22] + }, + "seedance-1-0-pro-fast": { + "480p":[0.09,0.1], + "720p":[0.21,0.23], + "1080p":[0.47,0.49] + }, + "seedance-1-0-lite": { + "480p":[0.17,0.18], + "720p":[0.37,0.41], + "1080p":[0.85,0.88] + } + }; + $model := widgets.model; + $modelKey := + $contains($model, "seedance-1-0-pro-fast") ? "seedance-1-0-pro-fast" : + $contains($model, "seedance-1-0-pro") ? 
"seedance-1-0-pro" : + "seedance-1-0-lite"; + $resolution := widgets.resolution; + $resKey := + $contains($resolution, "1080") ? "1080p" : + $contains($resolution, "720") ? "720p" : + "480p"; + $modelPrices := $lookup($priceByModel, $modelKey); + $baseRange := $lookup($modelPrices, $resKey); + $min10s := $baseRange[0]; + $max10s := $baseRange[1]; + $scale := widgets.duration / 10; + $minCost := $min10s * $scale; + $maxCost := $max10s * $scale; + ($minCost = $maxCost) + ? {"type":"usd","usd": $minCost} + : {"type":"range_usd","min_usd": $minCost, "max_usd": $maxCost} + ) + """, +) + + class ByteDanceExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[IO.ComfyNode]]: diff --git a/comfy_api_nodes/nodes_gemini.py b/comfy_api_nodes/nodes_gemini.py index 35bbf0d2f..a2daea50a 100644 --- a/comfy_api_nodes/nodes_gemini.py +++ b/comfy_api_nodes/nodes_gemini.py @@ -319,6 +319,30 @@ class GeminiNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["model"]), + expr=""" + ( + $m := widgets.model; + $contains($m, "gemini-2.5-flash") ? { + "type": "list_usd", + "usd": [0.0003, 0.0025], + "format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens"} + } + : $contains($m, "gemini-2.5-pro") ? { + "type": "list_usd", + "usd": [0.00125, 0.01], + "format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" } + } + : $contains($m, "gemini-3-pro-preview") ? { + "type": "list_usd", + "usd": [0.002, 0.012], + "format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" } + } + : {"type":"text", "text":"Token-based"} + ) + """, + ), ) @classmethod @@ -580,6 +604,9 @@ class GeminiImage(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.039,"format":{"suffix":"/Image (1K)","approximate":true}}""", + ), ) @classmethod @@ -710,6 +737,19 @@ class GeminiImage2(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["resolution"]), + expr=""" + ( + $r := widgets.resolution; + ($contains($r,"1k") or $contains($r,"2k")) + ? {"type":"usd","usd":0.134,"format":{"suffix":"/Image","approximate":true}} + : $contains($r,"4k") + ? {"type":"usd","usd":0.24,"format":{"suffix":"/Image","approximate":true}} + : {"type":"text","text":"Token-based"} + ) + """, + ), ) @classmethod diff --git a/comfy_api_nodes/nodes_ideogram.py b/comfy_api_nodes/nodes_ideogram.py index 48f94e612..827b3523a 100644 --- a/comfy_api_nodes/nodes_ideogram.py +++ b/comfy_api_nodes/nodes_ideogram.py @@ -236,7 +236,6 @@ class IdeogramV1(IO.ComfyNode): display_name="Ideogram V1", category="api node/image/Ideogram", description="Generates images using the Ideogram V1 model.", - is_api_node=True, inputs=[ IO.String.Input( "prompt", @@ -298,6 +297,17 @@ class IdeogramV1(IO.ComfyNode): IO.Hidden.api_key_comfy_org, IO.Hidden.unique_id, ], + is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["num_images", "turbo"]), + expr=""" + ( + $n := widgets.num_images; + $base := (widgets.turbo = true) ? 
0.0286 : 0.0858; + {"type":"usd","usd": $round($base * $n, 2)} + ) + """, + ), ) @classmethod @@ -351,7 +361,6 @@ class IdeogramV2(IO.ComfyNode): display_name="Ideogram V2", category="api node/image/Ideogram", description="Generates images using the Ideogram V2 model.", - is_api_node=True, inputs=[ IO.String.Input( "prompt", @@ -436,6 +445,17 @@ class IdeogramV2(IO.ComfyNode): IO.Hidden.api_key_comfy_org, IO.Hidden.unique_id, ], + is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["num_images", "turbo"]), + expr=""" + ( + $n := widgets.num_images; + $base := (widgets.turbo = true) ? 0.0715 : 0.1144; + {"type":"usd","usd": $round($base * $n, 2)} + ) + """, + ), ) @classmethod @@ -506,7 +526,6 @@ class IdeogramV3(IO.ComfyNode): category="api node/image/Ideogram", description="Generates images using the Ideogram V3 model. " "Supports both regular image generation from text prompts and image editing with mask.", - is_api_node=True, inputs=[ IO.String.Input( "prompt", @@ -591,6 +610,23 @@ class IdeogramV3(IO.ComfyNode): IO.Hidden.api_key_comfy_org, IO.Hidden.unique_id, ], + is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["rendering_speed", "num_images"], inputs=["character_image"]), + expr=""" + ( + $n := widgets.num_images; + $speed := widgets.rendering_speed; + $hasChar := inputs.character_image.connected; + $base := + $contains($speed,"quality") ? ($hasChar ? 0.286 : 0.1287) : + $contains($speed,"default") ? ($hasChar ? 0.2145 : 0.0858) : + $contains($speed,"turbo") ? ($hasChar ? 0.143 : 0.0429) : + 0.0858; + {"type":"usd","usd": $round($base * $n, 2)} + ) + """, + ), ) @classmethod diff --git a/comfy_api_nodes/nodes_kling.py b/comfy_api_nodes/nodes_kling.py index 01d9c34f5..05dde88b1 100644 --- a/comfy_api_nodes/nodes_kling.py +++ b/comfy_api_nodes/nodes_kling.py @@ -764,6 +764,33 @@ class KlingTextToVideoNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["mode"]), + expr=""" + ( + $m := widgets.mode; + $contains($m,"v2-5-turbo") + ? ($contains($m,"10") ? {"type":"usd","usd":0.7} : {"type":"usd","usd":0.35}) + : $contains($m,"v2-1-master") + ? ($contains($m,"10s") ? {"type":"usd","usd":2.8} : {"type":"usd","usd":1.4}) + : $contains($m,"v2-master") + ? ($contains($m,"10s") ? {"type":"usd","usd":2.8} : {"type":"usd","usd":1.4}) + : $contains($m,"v1-6") + ? ( + $contains($m,"pro") + ? ($contains($m,"10s") ? {"type":"usd","usd":0.98} : {"type":"usd","usd":0.49}) + : ($contains($m,"10s") ? {"type":"usd","usd":0.56} : {"type":"usd","usd":0.28}) + ) + : $contains($m,"v1") + ? ( + $contains($m,"pro") + ? ($contains($m,"10s") ? {"type":"usd","usd":0.98} : {"type":"usd","usd":0.49}) + : ($contains($m,"10s") ? {"type":"usd","usd":0.28} : {"type":"usd","usd":0.14}) + ) + : {"type":"usd","usd":0.14} + ) + """, + ), ) @classmethod @@ -818,6 +845,16 @@ class OmniProTextToVideoNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["duration", "resolution"]), + expr=""" + ( + $mode := (widgets.resolution = "720p") ? 
"std" : "pro"; + $rates := {"std": 0.084, "pro": 0.112}; + {"type":"usd","usd": $lookup($rates, $mode) * widgets.duration} + ) + """, + ), ) @classmethod @@ -886,6 +923,16 @@ class OmniProFirstLastFrameNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["duration", "resolution"]), + expr=""" + ( + $mode := (widgets.resolution = "720p") ? "std" : "pro"; + $rates := {"std": 0.084, "pro": 0.112}; + {"type":"usd","usd": $lookup($rates, $mode) * widgets.duration} + ) + """, + ), ) @classmethod @@ -981,6 +1028,16 @@ class OmniProImageToVideoNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["duration", "resolution"]), + expr=""" + ( + $mode := (widgets.resolution = "720p") ? "std" : "pro"; + $rates := {"std": 0.084, "pro": 0.112}; + {"type":"usd","usd": $lookup($rates, $mode) * widgets.duration} + ) + """, + ), ) @classmethod @@ -1056,6 +1113,16 @@ class OmniProVideoToVideoNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["duration", "resolution"]), + expr=""" + ( + $mode := (widgets.resolution = "720p") ? "std" : "pro"; + $rates := {"std": 0.126, "pro": 0.168}; + {"type":"usd","usd": $lookup($rates, $mode) * widgets.duration} + ) + """, + ), ) @classmethod @@ -1142,6 +1209,16 @@ class OmniProEditVideoNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["resolution"]), + expr=""" + ( + $mode := (widgets.resolution = "720p") ? "std" : "pro"; + $rates := {"std": 0.126, "pro": 0.168}; + {"type":"usd","usd": $lookup($rates, $mode), "format":{"suffix":"/second"}} + ) + """, + ), ) @classmethod @@ -1228,6 +1305,9 @@ class OmniProImageNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.028}""", + ), ) @classmethod @@ -1313,6 +1393,9 @@ class KlingCameraControlT2VNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.14}""", + ), ) @classmethod @@ -1375,6 +1458,33 @@ class KlingImage2VideoNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["mode", "model_name", "duration"]), + expr=""" + ( + $mode := widgets.mode; + $model := widgets.model_name; + $dur := widgets.duration; + $contains($model,"v2-5-turbo") + ? ($contains($dur,"10") ? {"type":"usd","usd":0.7} : {"type":"usd","usd":0.35}) + : ($contains($model,"v2-1-master") or $contains($model,"v2-master")) + ? ($contains($dur,"10") ? {"type":"usd","usd":2.8} : {"type":"usd","usd":1.4}) + : ($contains($model,"v2-1") or $contains($model,"v1-6") or $contains($model,"v1-5")) + ? ( + $contains($mode,"pro") + ? ($contains($dur,"10") ? {"type":"usd","usd":0.98} : {"type":"usd","usd":0.49}) + : ($contains($dur,"10") ? {"type":"usd","usd":0.56} : {"type":"usd","usd":0.28}) + ) + : $contains($model,"v1") + ? ( + $contains($mode,"pro") + ? ($contains($dur,"10") ? {"type":"usd","usd":0.98} : {"type":"usd","usd":0.49}) + : ($contains($dur,"10") ? 
{"type":"usd","usd":0.28} : {"type":"usd","usd":0.14}) + ) + : {"type":"usd","usd":0.14} + ) + """, + ), ) @classmethod @@ -1448,6 +1558,9 @@ class KlingCameraControlI2VNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.49}""", + ), ) @classmethod @@ -1518,6 +1631,33 @@ class KlingStartEndFrameNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["mode"]), + expr=""" + ( + $m := widgets.mode; + $contains($m,"v2-5-turbo") + ? ($contains($m,"10") ? {"type":"usd","usd":0.7} : {"type":"usd","usd":0.35}) + : $contains($m,"v2-1") + ? ($contains($m,"10s") ? {"type":"usd","usd":0.98} : {"type":"usd","usd":0.49}) + : $contains($m,"v2-master") + ? ($contains($m,"10s") ? {"type":"usd","usd":2.8} : {"type":"usd","usd":1.4}) + : $contains($m,"v1-6") + ? ( + $contains($m,"pro") + ? ($contains($m,"10s") ? {"type":"usd","usd":0.98} : {"type":"usd","usd":0.49}) + : ($contains($m,"10s") ? {"type":"usd","usd":0.56} : {"type":"usd","usd":0.28}) + ) + : $contains($m,"v1") + ? ( + $contains($m,"pro") + ? ($contains($m,"10s") ? {"type":"usd","usd":0.98} : {"type":"usd","usd":0.49}) + : ($contains($m,"10s") ? {"type":"usd","usd":0.28} : {"type":"usd","usd":0.14}) + ) + : {"type":"usd","usd":0.14} + ) + """, + ), ) @classmethod @@ -1583,6 +1723,9 @@ class KlingVideoExtendNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.28}""", + ), ) @classmethod @@ -1664,6 +1807,29 @@ class KlingDualCharacterVideoEffectNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["mode", "model_name", "duration"]), + expr=""" + ( + $mode := widgets.mode; + $model := widgets.model_name; + $dur := widgets.duration; + ($contains($model,"v1-6") or $contains($model,"v1-5")) + ? ( + $contains($mode,"pro") + ? ($contains($dur,"10") ? {"type":"usd","usd":0.98} : {"type":"usd","usd":0.49}) + : ($contains($dur,"10") ? {"type":"usd","usd":0.56} : {"type":"usd","usd":0.28}) + ) + : $contains($model,"v1") + ? ( + $contains($mode,"pro") + ? ($contains($dur,"10") ? {"type":"usd","usd":0.98} : {"type":"usd","usd":0.49}) + : ($contains($dur,"10") ? {"type":"usd","usd":0.28} : {"type":"usd","usd":0.14}) + ) + : {"type":"usd","usd":0.14} + ) + """, + ), ) @classmethod @@ -1728,6 +1894,16 @@ class KlingSingleImageVideoEffectNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["effect_scene"]), + expr=""" + ( + ($contains(widgets.effect_scene,"dizzydizzy") or $contains(widgets.effect_scene,"bloombloom")) + ? 
{"type":"usd","usd":0.49} + : {"type":"usd","usd":0.28} + ) + """, + ), ) @classmethod @@ -1782,6 +1958,9 @@ class KlingLipSyncAudioToVideoNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.1,"format":{"approximate":true}}""", + ), ) @classmethod @@ -1842,6 +2021,9 @@ class KlingLipSyncTextToVideoNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.1,"format":{"approximate":true}}""", + ), ) @classmethod @@ -1892,6 +2074,9 @@ class KlingVirtualTryOnNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.7}""", + ), ) @classmethod @@ -1991,6 +2176,19 @@ class KlingImageGenerationNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["model_name", "n"], inputs=["image"]), + expr=""" + ( + $m := widgets.model_name; + $base := + $contains($m,"kling-v1-5") + ? (inputs.image.connected ? 0.028 : 0.014) + : ($contains($m,"kling-v1") ? 0.0035 : 0.014); + {"type":"usd","usd": $base * widgets.n} + ) + """, + ), ) @classmethod @@ -2074,6 +2272,10 @@ class TextToVideoWithAudio(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["duration", "generate_audio"]), + expr="""{"type":"usd","usd": 0.07 * widgets.duration * (widgets.generate_audio ? 2 : 1)}""", + ), ) @classmethod @@ -2138,6 +2340,10 @@ class ImageToVideoWithAudio(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["duration", "generate_audio"]), + expr="""{"type":"usd","usd": 0.07 * widgets.duration * (widgets.generate_audio ? 
2 : 1)}""", + ), ) @classmethod @@ -2218,6 +2424,15 @@ class MotionControl(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["mode"]), + expr=""" + ( + $prices := {"std": 0.07, "pro": 0.112}; + {"type":"usd","usd": $lookup($prices, widgets.mode), "format":{"suffix":"/second"}} + ) + """, + ), ) @classmethod diff --git a/comfy_api_nodes/nodes_ltxv.py b/comfy_api_nodes/nodes_ltxv.py index 7e61560dc..c6424af92 100644 --- a/comfy_api_nodes/nodes_ltxv.py +++ b/comfy_api_nodes/nodes_ltxv.py @@ -28,6 +28,22 @@ class ExecuteTaskRequest(BaseModel): image_uri: str | None = Field(None) +PRICE_BADGE = IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["model", "duration", "resolution"]), + expr=""" + ( + $prices := { + "ltx-2 (pro)": {"1920x1080":0.06,"2560x1440":0.12,"3840x2160":0.24}, + "ltx-2 (fast)": {"1920x1080":0.04,"2560x1440":0.08,"3840x2160":0.16} + }; + $modelPrices := $lookup($prices, $lowercase(widgets.model)); + $pps := $lookup($modelPrices, widgets.resolution); + {"type":"usd","usd": $pps * widgets.duration} + ) + """, +) + + class TextToVideoNode(IO.ComfyNode): @classmethod def define_schema(cls): @@ -69,6 +85,7 @@ class TextToVideoNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=PRICE_BADGE, ) @classmethod @@ -145,6 +162,7 @@ class ImageToVideoNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=PRICE_BADGE, ) @classmethod diff --git a/comfy_api_nodes/nodes_luma.py b/comfy_api_nodes/nodes_luma.py index 894f2b08c..95cb442e5 100644 --- a/comfy_api_nodes/nodes_luma.py +++ b/comfy_api_nodes/nodes_luma.py @@ -189,6 +189,19 @@ class LumaImageGenerationNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["model"]), + expr=""" + ( + $m := widgets.model; + $contains($m,"photon-flash-1") + ? {"type":"usd","usd":0.0027} + : $contains($m,"photon-1") + ? {"type":"usd","usd":0.0104} + : {"type":"usd","usd":0.0246} + ) + """, + ), ) @classmethod @@ -303,6 +316,19 @@ class LumaImageModifyNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["model"]), + expr=""" + ( + $m := widgets.model; + $contains($m,"photon-flash-1") + ? {"type":"usd","usd":0.0027} + : $contains($m,"photon-1") + ? {"type":"usd","usd":0.0104} + : {"type":"usd","usd":0.0246} + ) + """, + ), ) @classmethod @@ -395,6 +421,7 @@ class LumaTextToVideoGenerationNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=PRICE_BADGE_VIDEO, ) @classmethod @@ -505,6 +532,8 @@ class LumaImageToVideoGenerationNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=PRICE_BADGE_VIDEO, + ) @classmethod @@ -568,6 +597,53 @@ class LumaImageToVideoGenerationNode(IO.ComfyNode): return LumaKeyframes(frame0=frame0, frame1=frame1) +PRICE_BADGE_VIDEO = IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["model", "resolution", "duration"]), + expr=""" + ( + $p := { + "ray-flash-2": { + "5s": {"4k":3.13,"1080p":0.79,"720p":0.34,"540p":0.2}, + "9s": {"4k":5.65,"1080p":1.42,"720p":0.61,"540p":0.36} + }, + "ray-2": { + "5s": {"4k":9.11,"1080p":2.27,"720p":1.02,"540p":0.57}, + "9s": {"4k":16.4,"1080p":4.1,"720p":1.83,"540p":1.03} + } + }; + + $m := widgets.model; + $d := widgets.duration; + $r := widgets.resolution; + + $modelKey := + $contains($m,"ray-flash-2") ? "ray-flash-2" : + $contains($m,"ray-2") ? 
"ray-2" : + $contains($m,"ray-1-6") ? "ray-1-6" : + "other"; + + $durKey := $contains($d,"5s") ? "5s" : $contains($d,"9s") ? "9s" : ""; + $resKey := + $contains($r,"4k") ? "4k" : + $contains($r,"1080p") ? "1080p" : + $contains($r,"720p") ? "720p" : + $contains($r,"540p") ? "540p" : ""; + + $modelPrices := $lookup($p, $modelKey); + $durPrices := $lookup($modelPrices, $durKey); + $v := $lookup($durPrices, $resKey); + + $price := + ($modelKey = "ray-1-6") ? 0.5 : + ($modelKey = "other") ? 0.79 : + ($exists($v) ? $v : 0.79); + + {"type":"usd","usd": $price} + ) + """, +) + + class LumaExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[IO.ComfyNode]]: diff --git a/comfy_api_nodes/nodes_minimax.py b/comfy_api_nodes/nodes_minimax.py index 05cbb700f..43a15d50d 100644 --- a/comfy_api_nodes/nodes_minimax.py +++ b/comfy_api_nodes/nodes_minimax.py @@ -134,6 +134,9 @@ class MinimaxTextToVideoNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.43}""", + ), ) @classmethod @@ -197,6 +200,9 @@ class MinimaxImageToVideoNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.43}""", + ), ) @classmethod @@ -340,6 +346,20 @@ class MinimaxHailuoVideoNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["resolution", "duration"]), + expr=""" + ( + $prices := { + "768p": {"6": 0.28, "10": 0.56}, + "1080p": {"6": 0.49} + }; + $resPrices := $lookup($prices, $lowercase(widgets.resolution)); + $price := $lookup($resPrices, $string(widgets.duration)); + {"type":"usd","usd": $price ? $price : 0.43} + ) + """, + ), ) @classmethod diff --git a/comfy_api_nodes/nodes_moonvalley.py b/comfy_api_nodes/nodes_moonvalley.py index 2771e4790..769b171b7 100644 --- a/comfy_api_nodes/nodes_moonvalley.py +++ b/comfy_api_nodes/nodes_moonvalley.py @@ -233,6 +233,10 @@ class MoonvalleyImg2VideoNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(), + expr="""{"type":"usd","usd": 1.5}""", + ), ) @classmethod @@ -351,6 +355,10 @@ class MoonvalleyVideo2VideoNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(), + expr="""{"type":"usd","usd": 2.25}""", + ), ) @classmethod @@ -471,6 +479,10 @@ class MoonvalleyTxt2VideoNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(), + expr="""{"type":"usd","usd": 1.5}""", + ), ) @classmethod diff --git a/comfy_api_nodes/nodes_openai.py b/comfy_api_nodes/nodes_openai.py index a6205a34f..2f144c5c3 100644 --- a/comfy_api_nodes/nodes_openai.py +++ b/comfy_api_nodes/nodes_openai.py @@ -160,6 +160,23 @@ class OpenAIDalle2(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["size", "n"]), + expr=""" + ( + $size := widgets.size; + $nRaw := widgets.n; + $n := ($nRaw != null and $nRaw != 0) ? $nRaw : 1; + + $base := + $contains($size, "256x256") ? 0.016 : + $contains($size, "512x512") ? 
0.018 : + 0.02; + + {"type":"usd","usd": $round($base * $n, 3)} + ) + """, + ), ) @classmethod @@ -287,6 +304,25 @@ class OpenAIDalle3(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["size", "quality"]), + expr=""" + ( + $size := widgets.size; + $q := widgets.quality; + $hd := $contains($q, "hd"); + + $price := + $contains($size, "1024x1024") + ? ($hd ? 0.08 : 0.04) + : (($contains($size, "1792x1024") or $contains($size, "1024x1792")) + ? ($hd ? 0.12 : 0.08) + : 0.04); + + {"type":"usd","usd": $price} + ) + """, + ), ) @classmethod @@ -411,6 +447,28 @@ class OpenAIGPTImage1(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["quality", "n"]), + expr=""" + ( + $ranges := { + "low": [0.011, 0.02], + "medium": [0.046, 0.07], + "high": [0.167, 0.3] + }; + $range := $lookup($ranges, widgets.quality); + $n := widgets.n; + ($n = 1) + ? {"type":"range_usd","min_usd": $range[0], "max_usd": $range[1]} + : { + "type":"range_usd", + "min_usd": $range[0], + "max_usd": $range[1], + "format": { "suffix": " x " & $string($n) & "/Run" } + } + ) + """, + ), ) @classmethod @@ -566,6 +624,75 @@ class OpenAIChatNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["model"]), + expr=""" + ( + $m := widgets.model; + $contains($m, "o4-mini") ? { + "type": "list_usd", + "usd": [0.0011, 0.0044], + "format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" } + } + : $contains($m, "o1-pro") ? { + "type": "list_usd", + "usd": [0.15, 0.6], + "format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" } + } + : $contains($m, "o1") ? { + "type": "list_usd", + "usd": [0.015, 0.06], + "format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" } + } + : $contains($m, "o3-mini") ? { + "type": "list_usd", + "usd": [0.0011, 0.0044], + "format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" } + } + : $contains($m, "o3") ? { + "type": "list_usd", + "usd": [0.01, 0.04], + "format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" } + } + : $contains($m, "gpt-4o") ? { + "type": "list_usd", + "usd": [0.0025, 0.01], + "format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" } + } + : $contains($m, "gpt-4.1-nano") ? { + "type": "list_usd", + "usd": [0.0001, 0.0004], + "format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" } + } + : $contains($m, "gpt-4.1-mini") ? { + "type": "list_usd", + "usd": [0.0004, 0.0016], + "format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" } + } + : $contains($m, "gpt-4.1") ? { + "type": "list_usd", + "usd": [0.002, 0.008], + "format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" } + } + : $contains($m, "gpt-5-nano") ? { + "type": "list_usd", + "usd": [0.00005, 0.0004], + "format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" } + } + : $contains($m, "gpt-5-mini") ? { + "type": "list_usd", + "usd": [0.00025, 0.002], + "format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" } + } + : $contains($m, "gpt-5") ? 
{ + "type": "list_usd", + "usd": [0.00125, 0.01], + "format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" } + } + : {"type": "text", "text": "Token-based"} + ) + """, + ), ) @classmethod diff --git a/comfy_api_nodes/nodes_pixverse.py b/comfy_api_nodes/nodes_pixverse.py index 6e1686af0..86ddb3ab9 100644 --- a/comfy_api_nodes/nodes_pixverse.py +++ b/comfy_api_nodes/nodes_pixverse.py @@ -128,6 +128,7 @@ class PixverseTextToVideoNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=PRICE_BADGE_VIDEO, ) @classmethod @@ -242,6 +243,7 @@ class PixverseImageToVideoNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=PRICE_BADGE_VIDEO, ) @classmethod @@ -355,6 +357,7 @@ class PixverseTransitionVideoNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=PRICE_BADGE_VIDEO, ) @classmethod @@ -416,6 +419,33 @@ class PixverseTransitionVideoNode(IO.ComfyNode): return IO.NodeOutput(await download_url_to_video_output(response_poll.Resp.url)) +PRICE_BADGE_VIDEO = IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["duration_seconds", "quality", "motion_mode"]), + expr=""" + ( + $prices := { + "5": { + "1080p": {"normal": 1.2, "fast": 1.2}, + "720p": {"normal": 0.6, "fast": 1.2}, + "540p": {"normal": 0.45, "fast": 0.9}, + "360p": {"normal": 0.45, "fast": 0.9} + }, + "8": { + "1080p": {"normal": 1.2, "fast": 1.2}, + "720p": {"normal": 1.2, "fast": 1.2}, + "540p": {"normal": 0.9, "fast": 1.2}, + "360p": {"normal": 0.9, "fast": 1.2} + } + }; + $durPrices := $lookup($prices, $string(widgets.duration_seconds)); + $qualityPrices := $lookup($durPrices, widgets.quality); + $price := $lookup($qualityPrices, widgets.motion_mode); + {"type":"usd","usd": $price ? $price : 0.9} + ) + """, +) + + class PixVerseExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[IO.ComfyNode]]: diff --git a/comfy_api_nodes/nodes_recraft.py b/comfy_api_nodes/nodes_recraft.py index e3440b946..05dc151ad 100644 --- a/comfy_api_nodes/nodes_recraft.py +++ b/comfy_api_nodes/nodes_recraft.py @@ -378,6 +378,10 @@ class RecraftTextToImageNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["n"]), + expr="""{"type":"usd","usd": $round(0.04 * widgets.n, 2)}""", + ), ) @classmethod @@ -490,6 +494,10 @@ class RecraftImageToImageNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["n"]), + expr="""{"type":"usd","usd": $round(0.04 * widgets.n, 2)}""", + ), ) @classmethod @@ -591,6 +599,10 @@ class RecraftImageInpaintingNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["n"]), + expr="""{"type":"usd","usd": $round(0.04 * widgets.n, 2)}""", + ), ) @classmethod @@ -692,6 +704,10 @@ class RecraftTextToVectorNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["n"]), + expr="""{"type":"usd","usd": $round(0.08 * widgets.n, 2)}""", + ), ) @classmethod @@ -759,6 +775,10 @@ class RecraftVectorizeImageNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(), + expr="""{"type":"usd","usd": 0.01}""", + ), ) @classmethod @@ -817,6 +837,9 @@ class RecraftReplaceBackgroundNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + 
price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.04}""", + ), ) @classmethod @@ -883,6 +906,9 @@ class RecraftRemoveBackgroundNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.01}""", + ), ) @classmethod @@ -929,6 +955,9 @@ class RecraftCrispUpscaleNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.004}""", + ), ) @classmethod @@ -972,6 +1001,9 @@ class RecraftCreativeUpscaleNode(RecraftCrispUpscaleNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.25}""", + ), ) diff --git a/comfy_api_nodes/nodes_rodin.py b/comfy_api_nodes/nodes_rodin.py index e60e7a6d6..b4420cb93 100644 --- a/comfy_api_nodes/nodes_rodin.py +++ b/comfy_api_nodes/nodes_rodin.py @@ -241,6 +241,9 @@ class Rodin3D_Regular(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.4}""", + ), ) @classmethod @@ -294,6 +297,9 @@ class Rodin3D_Detail(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.4}""", + ), ) @classmethod @@ -347,6 +353,9 @@ class Rodin3D_Smooth(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.4}""", + ), ) @classmethod @@ -406,6 +415,9 @@ class Rodin3D_Sketch(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.4}""", + ), ) @classmethod diff --git a/comfy_api_nodes/nodes_runway.py b/comfy_api_nodes/nodes_runway.py index 3c55039c9..d19fdb365 100644 --- a/comfy_api_nodes/nodes_runway.py +++ b/comfy_api_nodes/nodes_runway.py @@ -184,6 +184,10 @@ class RunwayImageToVideoNodeGen3a(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["duration"]), + expr="""{"type":"usd","usd": 0.0715 * widgets.duration}""", + ), ) @classmethod @@ -274,6 +278,10 @@ class RunwayImageToVideoNodeGen4(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["duration"]), + expr="""{"type":"usd","usd": 0.0715 * widgets.duration}""", + ), ) @classmethod @@ -372,6 +380,10 @@ class RunwayFirstLastFrameNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["duration"]), + expr="""{"type":"usd","usd": 0.0715 * widgets.duration}""", + ), ) @classmethod @@ -457,6 +469,9 @@ class RunwayTextToImageNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.11}""", + ), ) @classmethod diff --git a/comfy_api_nodes/nodes_sora.py b/comfy_api_nodes/nodes_sora.py index 92b225d40..87e663845 100644 --- a/comfy_api_nodes/nodes_sora.py +++ b/comfy_api_nodes/nodes_sora.py @@ -89,6 +89,24 @@ class OpenAIVideoSora2(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["model", "size", "duration"]), + expr=""" + ( + $m := widgets.model; + $size := widgets.size; + $dur := widgets.duration; + $isPro := $contains($m, "sora-2-pro"); + $isSora2 := $contains($m, "sora-2"); + $isProSize := ($size = "1024x1792" or $size = "1792x1024"); + $perSec := + $isPro ? ($isProSize ? 0.5 : 0.3) : + $isSora2 ? 0.1 : + ($isProSize ? 
0.5 : 0.1); + {"type":"usd","usd": $round($perSec * $dur, 2)} + ) + """, + ), ) @classmethod diff --git a/comfy_api_nodes/nodes_stability.py b/comfy_api_nodes/nodes_stability.py index bb7ceed78..5c48c1f1e 100644 --- a/comfy_api_nodes/nodes_stability.py +++ b/comfy_api_nodes/nodes_stability.py @@ -127,6 +127,9 @@ class StabilityStableImageUltraNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.08}""", + ), ) @classmethod @@ -264,6 +267,16 @@ class StabilityStableImageSD_3_5Node(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["model"]), + expr=""" + ( + $contains(widgets.model,"large") + ? {"type":"usd","usd":0.065} + : {"type":"usd","usd":0.035} + ) + """, + ), ) @classmethod @@ -382,6 +395,9 @@ class StabilityUpscaleConservativeNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.25}""", + ), ) @classmethod @@ -486,6 +502,9 @@ class StabilityUpscaleCreativeNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.25}""", + ), ) @classmethod @@ -566,6 +585,9 @@ class StabilityUpscaleFastNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.01}""", + ), ) @classmethod @@ -648,6 +670,9 @@ class StabilityTextToAudio(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.2}""", + ), ) @classmethod @@ -732,6 +757,9 @@ class StabilityAudioToAudio(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.2}""", + ), ) @classmethod @@ -828,6 +856,9 @@ class StabilityAudioInpaint(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.2}""", + ), ) @classmethod diff --git a/comfy_api_nodes/nodes_tripo.py b/comfy_api_nodes/nodes_tripo.py index e72f8e96a..aa790143d 100644 --- a/comfy_api_nodes/nodes_tripo.py +++ b/comfy_api_nodes/nodes_tripo.py @@ -117,6 +117,38 @@ class TripoTextToModelNode(IO.ComfyNode): ], is_api_node=True, is_output_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends( + widgets=[ + "model_version", + "style", + "texture", + "pbr", + "quad", + "texture_quality", + "geometry_quality", + ], + ), + expr=""" + ( + $isV14 := $contains(widgets.model_version,"v1.4"); + $style := widgets.style; + $hasStyle := ($style != "" and $style != "none"); + $withTexture := widgets.texture or widgets.pbr; + $isHdTexture := (widgets.texture_quality = "detailed"); + $isDetailedGeometry := (widgets.geometry_quality = "detailed"); + $baseCredits := + $isV14 ? 20 : ($withTexture ? 20 : 10); + $credits := + $baseCredits + + ($hasStyle ? 5 : 0) + + (widgets.quad ? 5 : 0) + + ($isHdTexture ? 10 : 0) + + ($isDetailedGeometry ? 
20 : 0); + {"type":"usd","usd": $round($credits * 0.01, 2)} + ) + """, + ), ) @classmethod @@ -210,6 +242,38 @@ class TripoImageToModelNode(IO.ComfyNode): ], is_api_node=True, is_output_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends( + widgets=[ + "model_version", + "style", + "texture", + "pbr", + "quad", + "texture_quality", + "geometry_quality", + ], + ), + expr=""" + ( + $isV14 := $contains(widgets.model_version,"v1.4"); + $style := widgets.style; + $hasStyle := ($style != "" and $style != "none"); + $withTexture := widgets.texture or widgets.pbr; + $isHdTexture := (widgets.texture_quality = "detailed"); + $isDetailedGeometry := (widgets.geometry_quality = "detailed"); + $baseCredits := + $isV14 ? 30 : ($withTexture ? 30 : 20); + $credits := + $baseCredits + + ($hasStyle ? 5 : 0) + + (widgets.quad ? 5 : 0) + + ($isHdTexture ? 10 : 0) + + ($isDetailedGeometry ? 20 : 0); + {"type":"usd","usd": $round($credits * 0.01, 2)} + ) + """, + ), ) @classmethod @@ -314,6 +378,34 @@ class TripoMultiviewToModelNode(IO.ComfyNode): ], is_api_node=True, is_output_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends( + widgets=[ + "model_version", + "texture", + "pbr", + "quad", + "texture_quality", + "geometry_quality", + ], + ), + expr=""" + ( + $isV14 := $contains(widgets.model_version,"v1.4"); + $withTexture := widgets.texture or widgets.pbr; + $isHdTexture := (widgets.texture_quality = "detailed"); + $isDetailedGeometry := (widgets.geometry_quality = "detailed"); + $baseCredits := + $isV14 ? 30 : ($withTexture ? 30 : 20); + $credits := + $baseCredits + + (widgets.quad ? 5 : 0) + + ($isHdTexture ? 10 : 0) + + ($isDetailedGeometry ? 20 : 0); + {"type":"usd","usd": $round($credits * 0.01, 2)} + ) + """, + ), ) @classmethod @@ -405,6 +497,15 @@ class TripoTextureNode(IO.ComfyNode): ], is_api_node=True, is_output_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["texture_quality"]), + expr=""" + ( + $tq := widgets.texture_quality; + {"type":"usd","usd": ($contains($tq,"detailed") ? 0.2 : 0.1)} + ) + """, + ), ) @classmethod @@ -456,6 +557,9 @@ class TripoRefineNode(IO.ComfyNode): ], is_api_node=True, is_output_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.3}""", + ), ) @classmethod @@ -489,6 +593,9 @@ class TripoRigNode(IO.ComfyNode): ], is_api_node=True, is_output_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.25}""", + ), ) @classmethod @@ -545,6 +652,9 @@ class TripoRetargetNode(IO.ComfyNode): ], is_api_node=True, is_output_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.1}""", + ), ) @classmethod @@ -638,6 +748,60 @@ class TripoConversionNode(IO.ComfyNode): ], is_api_node=True, is_output_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends( + widgets=[ + "quad", + "face_limit", + "texture_size", + "texture_format", + "force_symmetry", + "flatten_bottom", + "flatten_bottom_threshold", + "pivot_to_center_bottom", + "scale_factor", + "with_animation", + "pack_uv", + "bake", + "part_names", + "fbx_preset", + "export_vertex_colors", + "export_orientation", + "animate_in_place", + ], + ), + expr=""" + ( + $face := (widgets.face_limit != null) ? widgets.face_limit : -1; + $texSize := (widgets.texture_size != null) ? widgets.texture_size : 4096; + $flatThresh := (widgets.flatten_bottom_threshold != null) ? widgets.flatten_bottom_threshold : 0; + $scale := (widgets.scale_factor != null) ? 
widgets.scale_factor : 1; + $texFmt := (widgets.texture_format != "" ? widgets.texture_format : "jpeg"); + $part := widgets.part_names; + $fbx := (widgets.fbx_preset != "" ? widgets.fbx_preset : "blender"); + $orient := (widgets.export_orientation != "" ? widgets.export_orientation : "default"); + $advanced := + widgets.quad or + widgets.force_symmetry or + widgets.flatten_bottom or + widgets.pivot_to_center_bottom or + widgets.with_animation or + widgets.pack_uv or + widgets.bake or + widgets.export_vertex_colors or + widgets.animate_in_place or + ($face != -1) or + ($texSize != 4096) or + ($flatThresh != 0) or + ($scale != 1) or + ($texFmt != "jpeg") or + ($part != "") or + ($fbx != "blender") or + ($orient != "default"); + {"type":"usd","usd": ($advanced ? 0.1 : 0.05)} + ) + """, + ), ) @classmethod diff --git a/comfy_api_nodes/nodes_veo2.py b/comfy_api_nodes/nodes_veo2.py index 13a6bfd91..c14d6ad68 100644 --- a/comfy_api_nodes/nodes_veo2.py +++ b/comfy_api_nodes/nodes_veo2.py @@ -122,6 +122,10 @@ class VeoVideoGenerationNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["duration_seconds"]), + expr="""{"type":"usd","usd": 0.5 * widgets.duration_seconds}""", + ), ) @classmethod @@ -347,6 +351,20 @@ class Veo3VideoGenerationNode(VeoVideoGenerationNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["model", "generate_audio"]), + expr=""" + ( + $m := widgets.model; + $a := widgets.generate_audio; + ($contains($m,"veo-3.0-fast-generate-001") or $contains($m,"veo-3.1-fast-generate")) + ? {"type":"usd","usd": ($a ? 1.2 : 0.8)} + : ($contains($m,"veo-3.0-generate-001") or $contains($m,"veo-3.1-generate")) + ? {"type":"usd","usd": ($a ? 3.2 : 1.6)} + : {"type":"range_usd","min_usd":0.8,"max_usd":3.2} + ) + """, + ), ) @@ -420,6 +438,30 @@ class Veo3FirstLastFrameNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["model", "generate_audio", "duration"]), + expr=""" + ( + $prices := { + "veo-3.1-fast-generate": { "audio": 0.15, "no_audio": 0.10 }, + "veo-3.1-generate": { "audio": 0.40, "no_audio": 0.20 } + }; + $m := widgets.model; + $ga := (widgets.generate_audio = "true"); + $seconds := widgets.duration; + $modelKey := + $contains($m, "veo-3.1-fast-generate") ? "veo-3.1-fast-generate" : + $contains($m, "veo-3.1-generate") ? "veo-3.1-generate" : + ""; + $audioKey := $ga ? "audio" : "no_audio"; + $modelPrices := $lookup($prices, $modelKey); + $pps := $lookup($modelPrices, $audioKey); + ($pps != null) + ? 
{"type":"usd","usd": $pps * $seconds} + : {"type":"range_usd","min_usd": 0.4, "max_usd": 3.2} + ) + """, + ), ) @classmethod diff --git a/comfy_api_nodes/nodes_vidu.py b/comfy_api_nodes/nodes_vidu.py index 9d94ae7ad..8edb02f39 100644 --- a/comfy_api_nodes/nodes_vidu.py +++ b/comfy_api_nodes/nodes_vidu.py @@ -121,6 +121,9 @@ class ViduTextToVideoNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.4}""", + ), ) @classmethod @@ -214,6 +217,9 @@ class ViduImageToVideoNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.4}""", + ), ) @classmethod @@ -317,6 +323,9 @@ class ViduReferenceVideoNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.4}""", + ), ) @classmethod @@ -426,6 +435,9 @@ class ViduStartEndToVideoNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.4}""", + ), ) @classmethod @@ -507,6 +519,17 @@ class Vidu2TextToVideoNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["duration", "resolution"]), + expr=""" + ( + $is1080 := widgets.resolution = "1080p"; + $base := $is1080 ? 0.1 : 0.075; + $perSec := $is1080 ? 0.05 : 0.025; + {"type":"usd","usd": $base + $perSec * (widgets.duration - 1)} + ) + """, + ), ) @classmethod @@ -594,6 +617,39 @@ class Vidu2ImageToVideoNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["model", "duration", "resolution"]), + expr=""" + ( + $m := widgets.model; + $d := widgets.duration; + $is1080 := widgets.resolution = "1080p"; + $contains($m, "pro-fast") + ? ( + $base := $is1080 ? 0.08 : 0.04; + $perSec := $is1080 ? 0.02 : 0.01; + {"type":"usd","usd": $base + $perSec * ($d - 1)} + ) + : $contains($m, "pro") + ? ( + $base := $is1080 ? 0.275 : 0.075; + $perSec := $is1080 ? 0.075 : 0.05; + {"type":"usd","usd": $base + $perSec * ($d - 1)} + ) + : $contains($m, "turbo") + ? ( + $is1080 + ? {"type":"usd","usd": 0.175 + 0.05 * ($d - 1)} + : ( + $d <= 1 ? {"type":"usd","usd": 0.04} + : $d <= 2 ? {"type":"usd","usd": 0.05} + : {"type":"usd","usd": 0.05 + 0.05 * ($d - 2)} + ) + ) + : {"type":"usd","usd": 0.04} + ) + """, + ), ) @classmethod @@ -698,6 +754,18 @@ class Vidu2ReferenceVideoNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["audio", "duration", "resolution"]), + expr=""" + ( + $is1080 := widgets.resolution = "1080p"; + $base := $is1080 ? 0.375 : 0.125; + $perSec := $is1080 ? 0.05 : 0.025; + $audioCost := widgets.audio = true ? 0.075 : 0; + {"type":"usd","usd": $base + $perSec * (widgets.duration - 1) + $audioCost} + ) + """, + ), ) @classmethod @@ -804,6 +872,38 @@ class Vidu2StartEndToVideoNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["model", "duration", "resolution"]), + expr=""" + ( + $m := widgets.model; + $d := widgets.duration; + $is1080 := widgets.resolution = "1080p"; + $contains($m, "pro-fast") + ? ( + $base := $is1080 ? 0.08 : 0.04; + $perSec := $is1080 ? 0.02 : 0.01; + {"type":"usd","usd": $base + $perSec * ($d - 1)} + ) + : $contains($m, "pro") + ? ( + $base := $is1080 ? 0.275 : 0.075; + $perSec := $is1080 ? 
0.075 : 0.05; + {"type":"usd","usd": $base + $perSec * ($d - 1)} + ) + : $contains($m, "turbo") + ? ( + $is1080 + ? {"type":"usd","usd": 0.175 + 0.05 * ($d - 1)} + : ( + $d <= 2 ? {"type":"usd","usd": 0.05} + : {"type":"usd","usd": 0.05 + 0.05 * ($d - 2)} + ) + ) + : {"type":"usd","usd": 0.04} + ) + """, + ), ) @classmethod diff --git a/comfy_api_nodes/nodes_wan.py b/comfy_api_nodes/nodes_wan.py index 3e04786a9..a1355d4f1 100644 --- a/comfy_api_nodes/nodes_wan.py +++ b/comfy_api_nodes/nodes_wan.py @@ -244,6 +244,9 @@ class WanTextToImageApi(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.03}""", + ), ) @classmethod @@ -363,6 +366,9 @@ class WanImageToImageApi(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.03}""", + ), ) @classmethod @@ -520,6 +526,17 @@ class WanTextToVideoApi(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["duration", "size"]), + expr=""" + ( + $ppsTable := { "480p": 0.05, "720p": 0.1, "1080p": 0.15 }; + $resKey := $substringBefore(widgets.size, ":"); + $pps := $lookup($ppsTable, $resKey); + { "type": "usd", "usd": $round($pps * widgets.duration, 2) } + ) + """, + ), ) @classmethod @@ -681,6 +698,16 @@ class WanImageToVideoApi(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["duration", "resolution"]), + expr=""" + ( + $ppsTable := { "480p": 0.05, "720p": 0.1, "1080p": 0.15 }; + $pps := $lookup($ppsTable, widgets.resolution); + { "type": "usd", "usd": $round($pps * widgets.duration, 2) } + ) + """, + ), ) @classmethod @@ -828,6 +855,22 @@ class WanReferenceVideoApi(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["size", "duration"]), + expr=""" + ( + $rate := $contains(widgets.size, "1080p") ? 0.15 : 0.10; + $inputMin := 2 * $rate; + $inputMax := 5 * $rate; + $outputPrice := widgets.duration * $rate; + { + "type": "range_usd", + "min_usd": $inputMin + $outputPrice, + "max_usd": $inputMax + $outputPrice + } + ) + """, + ), ) @classmethod From 15b312de7a74a836fa45b989a7697895b01e0cbf Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 13 Jan 2026 16:23:58 -0800 Subject: [PATCH 128/181] Optimize nvfp4 lora applying. 
(#11854) --- comfy/float.py | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/comfy/float.py b/comfy/float.py index c806af76b..1a6070bff 100644 --- a/comfy/float.py +++ b/comfy/float.py @@ -165,20 +165,12 @@ def stochastic_round_quantize_nvfp4(x, per_tensor_scale, pad_16x, seed=0): block_scale = max_abs / F4_E2M1_MAX scaled_block_scales = block_scale / per_tensor_scale.to(block_scale.dtype) scaled_block_scales_fp8 = torch.clamp(scaled_block_scales, max=F8_E4M3_MAX).to(torch.float8_e4m3fn) - total_scale = per_tensor_scale.to(x.dtype) * scaled_block_scales_fp8.to(x.dtype) - - # Handle zero blocks (from padding): avoid 0/0 NaN - zero_scale_mask = (total_scale == 0) - total_scale_safe = torch.where(zero_scale_mask, torch.ones_like(total_scale), total_scale) - - x = x / total_scale_safe.unsqueeze(-1) + x /= (per_tensor_scale.to(x.dtype) * scaled_block_scales_fp8.to(x.dtype)).unsqueeze(-1) generator = torch.Generator(device=x.device) generator.manual_seed(seed) - x = torch.where(zero_scale_mask.unsqueeze(-1), torch.zeros_like(x), x) - - x = x.view(orig_shape) + x = x.view(orig_shape).nan_to_num() data_lp = stochastic_float_to_fp4_e2m1(x, generator=generator) blocked_scales = to_blocked(scaled_block_scales_fp8, flatten=False) From eff2b9d412932aa7d49e6302cdf6e7cf24808b6f Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 13 Jan 2026 16:37:19 -0800 Subject: [PATCH 129/181] Optimize nvfp4 lora applying. (#11856) --- comfy/float.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/comfy/float.py b/comfy/float.py index 1a6070bff..8c303bea0 100644 --- a/comfy/float.py +++ b/comfy/float.py @@ -161,10 +161,7 @@ def stochastic_round_quantize_nvfp4(x, per_tensor_scale, pad_16x, seed=0): block_size = 16 x = x.reshape(orig_shape[0], -1, block_size) - max_abs = torch.amax(torch.abs(x), dim=-1) - block_scale = max_abs / F4_E2M1_MAX - scaled_block_scales = block_scale / per_tensor_scale.to(block_scale.dtype) - scaled_block_scales_fp8 = torch.clamp(scaled_block_scales, max=F8_E4M3_MAX).to(torch.float8_e4m3fn) + scaled_block_scales_fp8 = torch.clamp(((torch.amax(torch.abs(x), dim=-1)) / F4_E2M1_MAX) / per_tensor_scale.to(x.dtype), max=F8_E4M3_MAX).to(torch.float8_e4m3fn) x /= (per_tensor_scale.to(x.dtype) * scaled_block_scales_fp8.to(x.dtype)).unsqueeze(-1) generator = torch.Generator(device=x.device) @@ -172,6 +169,5 @@ def stochastic_round_quantize_nvfp4(x, per_tensor_scale, pad_16x, seed=0): x = x.view(orig_shape).nan_to_num() data_lp = stochastic_float_to_fp4_e2m1(x, generator=generator) - blocked_scales = to_blocked(scaled_block_scales_fp8, flatten=False) return data_lp, blocked_scales From 469dd9c16ad88765ffe4e7bfa57dd80faafbaddf Mon Sep 17 00:00:00 2001 From: nomadoor <124905471+nomadoor@users.noreply.github.com> Date: Wed, 14 Jan 2026 09:48:10 +0900 Subject: [PATCH 130/181] Adds crop to multiple mode to ResizeImageMaskNode. 
(#11838) * Add crop-to-multiple resize mode * Make scale-to-multiple shape handling explicit --- comfy_extras/nodes_post_processing.py | 44 +++++++++++++++++++++++++++ 1 file changed, 44 insertions(+) diff --git a/comfy_extras/nodes_post_processing.py b/comfy_extras/nodes_post_processing.py index 01afa13a1..0433bbda2 100644 --- a/comfy_extras/nodes_post_processing.py +++ b/comfy_extras/nodes_post_processing.py @@ -254,6 +254,7 @@ class ResizeType(str, Enum): SCALE_HEIGHT = "scale height" SCALE_TOTAL_PIXELS = "scale total pixels" MATCH_SIZE = "match size" + SCALE_TO_MULTIPLE = "scale to multiple" def is_image(input: torch.Tensor) -> bool: # images have 4 dimensions: [batch, height, width, channels] @@ -363,6 +364,43 @@ def scale_match_size(input: torch.Tensor, match: torch.Tensor, scale_method: str input = finalize_image_mask_input(input, is_type_image) return input +def scale_to_multiple_cover(input: torch.Tensor, multiple: int, scale_method: str) -> torch.Tensor: + if multiple <= 1: + return input + is_type_image = is_image(input) + if is_type_image: + _, height, width, _ = input.shape + else: + _, height, width = input.shape + target_w = (width // multiple) * multiple + target_h = (height // multiple) * multiple + if target_w == 0 or target_h == 0: + return input + if target_w == width and target_h == height: + return input + s_w = target_w / width + s_h = target_h / height + if s_w >= s_h: + scaled_w = target_w + scaled_h = int(math.ceil(height * s_w)) + if scaled_h < target_h: + scaled_h = target_h + else: + scaled_h = target_h + scaled_w = int(math.ceil(width * s_h)) + if scaled_w < target_w: + scaled_w = target_w + input = init_image_mask_input(input, is_type_image) + input = comfy.utils.common_upscale(input, scaled_w, scaled_h, scale_method, "disabled") + input = finalize_image_mask_input(input, is_type_image) + x0 = (scaled_w - target_w) // 2 + y0 = (scaled_h - target_h) // 2 + x1 = x0 + target_w + y1 = y0 + target_h + if is_type_image: + return input[:, y0:y1, x0:x1, :] + return input[:, y0:y1, x0:x1] + class ResizeImageMaskNode(io.ComfyNode): scale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "lanczos"] @@ -378,6 +416,7 @@ class ResizeImageMaskNode(io.ComfyNode): longer_size: int shorter_size: int megapixels: float + multiple: int @classmethod def define_schema(cls): @@ -417,6 +456,9 @@ class ResizeImageMaskNode(io.ComfyNode): io.MultiType.Input("match", [io.Image, io.Mask]), crop_combo, ]), + io.DynamicCombo.Option(ResizeType.SCALE_TO_MULTIPLE, [ + io.Int.Input("multiple", default=8, min=1, max=MAX_RESOLUTION, step=1), + ]), ]), io.Combo.Input("scale_method", options=cls.scale_methods, default="area"), ], @@ -442,6 +484,8 @@ class ResizeImageMaskNode(io.ComfyNode): return io.NodeOutput(scale_total_pixels(input, resize_type["megapixels"], scale_method)) elif selected_type == ResizeType.MATCH_SIZE: return io.NodeOutput(scale_match_size(input, resize_type["match"], scale_method, resize_type["crop"])) + elif selected_type == ResizeType.SCALE_TO_MULTIPLE: + return io.NodeOutput(scale_to_multiple_cover(input, resize_type["multiple"], scale_method)) raise ValueError(f"Unsupported resize type: {selected_type}") def batch_images(images: list[torch.Tensor]) -> torch.Tensor | None: From 7eb959ce934da914523b455f9a6e7e0662690325 Mon Sep 17 00:00:00 2001 From: Christian Byrne Date: Tue, 13 Jan 2026 18:03:16 -0800 Subject: [PATCH 131/181] fix: update ComfyUI repo reference to Comfy-Org/ComfyUI (#11858) --- .github/workflows/test-launch.yml | 2 +- 1 file changed, 1 insertion(+), 1 
deletion(-) diff --git a/.github/workflows/test-launch.yml b/.github/workflows/test-launch.yml index ef0d3f123..581c0474b 100644 --- a/.github/workflows/test-launch.yml +++ b/.github/workflows/test-launch.yml @@ -13,7 +13,7 @@ jobs: - name: Checkout ComfyUI uses: actions/checkout@v4 with: - repository: "comfyanonymous/ComfyUI" + repository: "Comfy-Org/ComfyUI" path: "ComfyUI" - uses: actions/setup-python@v4 with: From c9196f355ef5832daf55c4bbe8c6279dec509331 Mon Sep 17 00:00:00 2001 From: nomadoor <124905471+nomadoor@users.noreply.github.com> Date: Wed, 14 Jan 2026 11:25:09 +0900 Subject: [PATCH 132/181] Fix scale_shorter_dimension portrait check (#11862) --- comfy_extras/nodes_post_processing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy_extras/nodes_post_processing.py b/comfy_extras/nodes_post_processing.py index 0433bbda2..2e559c35c 100644 --- a/comfy_extras/nodes_post_processing.py +++ b/comfy_extras/nodes_post_processing.py @@ -329,7 +329,7 @@ def scale_shorter_dimension(input: torch.Tensor, shorter_size: int, scale_method if height < width: width = round((width / height) * shorter_size) height = shorter_size - elif width > height: + elif width < height: height = round((height / width) * shorter_size) width = shorter_size else: From ac4d8ea9b32f56410860dccdb30ae50a1029d6fd Mon Sep 17 00:00:00 2001 From: Johnpaul Chiwetelu <49923152+Myestery@users.noreply.github.com> Date: Wed, 14 Jan 2026 04:39:22 +0100 Subject: [PATCH 133/181] feat: add CI container version bump automation (#11692) * feat: add CI container version bump automation Adds a workflow that triggers on releases to create PRs in the comfyui-ci-container repo, updating the ComfyUI version in the Dockerfile. Supports both release events and manual workflow dispatch for testing. * feat: add CI container version bump automation Adds a workflow that triggers on releases to create PRs in the comfyui-ci-container repo, updating the ComfyUI version in the Dockerfile. Supports both release events and manual workflow dispatch for testing. * ci: update CI container repository owner * refactor: rename `update-ci-container.yaml` workflow to `update-ci-container.yml` * Remove post-merge instructions from the CI container update workflow. 
--- .github/workflows/update-ci-container.yml | 59 +++++++++++++++++++++++ 1 file changed, 59 insertions(+) create mode 100644 .github/workflows/update-ci-container.yml diff --git a/.github/workflows/update-ci-container.yml b/.github/workflows/update-ci-container.yml new file mode 100644 index 000000000..f7972e056 --- /dev/null +++ b/.github/workflows/update-ci-container.yml @@ -0,0 +1,59 @@ +name: "CI: Update CI Container" + +on: + release: + types: [published] + workflow_dispatch: + inputs: + version: + description: 'ComfyUI version (e.g., v0.7.0)' + required: true + type: string + +jobs: + update-ci-container: + runs-on: ubuntu-latest + # Skip pre-releases unless manually triggered + if: github.event_name == 'workflow_dispatch' || !github.event.release.prerelease + steps: + - name: Get version + id: version + run: | + if [ "${{ github.event_name }}" = "release" ]; then + VERSION="${{ github.event.release.tag_name }}" + else + VERSION="${{ inputs.version }}" + fi + echo "version=$VERSION" >> $GITHUB_OUTPUT + + - name: Checkout comfyui-ci-container + uses: actions/checkout@v4 + with: + repository: comfy-org/comfyui-ci-container + token: ${{ secrets.CI_CONTAINER_PAT }} + + - name: Check current version + id: current + run: | + CURRENT=$(grep -oP 'ARG COMFYUI_VERSION=\K.*' Dockerfile || echo "unknown") + echo "current_version=$CURRENT" >> $GITHUB_OUTPUT + + - name: Update Dockerfile + run: | + VERSION="${{ steps.version.outputs.version }}" + sed -i "s/^ARG COMFYUI_VERSION=.*/ARG COMFYUI_VERSION=${VERSION}/" Dockerfile + + - name: Create Pull Request + id: create-pr + uses: peter-evans/create-pull-request@v7 + with: + token: ${{ secrets.CI_CONTAINER_PAT }} + branch: automation/comfyui-${{ steps.version.outputs.version }} + title: "chore: bump ComfyUI to ${{ steps.version.outputs.version }}" + body: | + Updates ComfyUI version from `${{ steps.current.outputs.current_version }}` to `${{ steps.version.outputs.version }}` + + **Triggered by:** ${{ github.event_name == 'release' && format('[Release {0}]({1})', github.event.release.tag_name, github.event.release.html_url) || 'Manual workflow dispatch' }} + + labels: automation + commit-message: "chore: bump ComfyUI to ${{ steps.version.outputs.version }}" From 712cca36a13db93a9fa1fde9b7b5f9a5b961209a Mon Sep 17 00:00:00 2001 From: Silver <65376327+silveroxides@users.noreply.github.com> Date: Wed, 14 Jan 2026 04:41:44 +0100 Subject: [PATCH 134/181] feat: throttle ProgressBar updates to reduce WebSocket flooding (#11504) --- comfy/utils.py | 31 ++++++++++++++++++++++++++++++- 1 file changed, 30 insertions(+), 1 deletion(-) diff --git a/comfy/utils.py b/comfy/utils.py index ffa98c9b1..fac13f128 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -30,6 +30,7 @@ from torch.nn.functional import interpolate from einops import rearrange from comfy.cli_args import args import json +import time MMAP_TORCH_FILES = args.mmap_torch_files DISABLE_MMAP = args.disable_mmap @@ -1097,6 +1098,10 @@ def set_progress_bar_global_hook(function): global PROGRESS_BAR_HOOK PROGRESS_BAR_HOOK = function +# Throttle settings for progress bar updates to reduce WebSocket flooding +PROGRESS_THROTTLE_MIN_INTERVAL = 0.1 # 100ms minimum between updates +PROGRESS_THROTTLE_MIN_PERCENT = 0.5 # 0.5% minimum progress change + class ProgressBar: def __init__(self, total, node_id=None): global PROGRESS_BAR_HOOK @@ -1104,6 +1109,8 @@ class ProgressBar: self.current = 0 self.hook = PROGRESS_BAR_HOOK self.node_id = node_id + self._last_update_time = 0.0 + self._last_sent_value = -1 def 
update_absolute(self, value, total=None, preview=None): if total is not None: @@ -1112,7 +1119,29 @@ class ProgressBar: value = self.total self.current = value if self.hook is not None: - self.hook(self.current, self.total, preview, node_id=self.node_id) + current_time = time.perf_counter() + is_first = (self._last_sent_value < 0) + is_final = (value >= self.total) + has_preview = (preview is not None) + + # Always send immediately for previews, first update, or final update + if has_preview or is_first or is_final: + self.hook(self.current, self.total, preview, node_id=self.node_id) + self._last_update_time = current_time + self._last_sent_value = value + return + + # Apply throttling for regular progress updates + if self.total > 0: + percent_changed = ((value - max(0, self._last_sent_value)) / self.total) * 100 + else: + percent_changed = 100 + time_elapsed = current_time - self._last_update_time + + if time_elapsed >= PROGRESS_THROTTLE_MIN_INTERVAL and percent_changed >= PROGRESS_THROTTLE_MIN_PERCENT: + self.hook(self.current, self.total, preview, node_id=self.node_id) + self._last_update_time = current_time + self._last_sent_value = value def update(self, value): self.update_absolute(self.current + value) From 6165c38cb58c40b15ade879b80051b6c9148587f Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 13 Jan 2026 21:49:38 -0800 Subject: [PATCH 135/181] Optimize nvfp4 lora applying. (#11866) This changes results a bit but it also speeds up things a lot. --- comfy/float.py | 56 ++++++++++++++++++++++++++++++++------- comfy/quant_ops.py | 2 +- comfy/supported_models.py | 2 +- 3 files changed, 49 insertions(+), 11 deletions(-) diff --git a/comfy/float.py b/comfy/float.py index 8c303bea0..88c47cd80 100644 --- a/comfy/float.py +++ b/comfy/float.py @@ -137,10 +137,44 @@ def to_blocked(input_matrix, flatten: bool = True) -> torch.Tensor: return rearranged.reshape(padded_rows, padded_cols) -def stochastic_round_quantize_nvfp4(x, per_tensor_scale, pad_16x, seed=0): +def stochastic_round_quantize_nvfp4_block(x, per_tensor_scale, generator): F4_E2M1_MAX = 6.0 F8_E4M3_MAX = 448.0 + orig_shape = x.shape + + block_size = 16 + + x = x.reshape(orig_shape[0], -1, block_size) + scaled_block_scales_fp8 = torch.clamp(((torch.amax(torch.abs(x), dim=-1)) / F4_E2M1_MAX) / per_tensor_scale.to(x.dtype), max=F8_E4M3_MAX).to(torch.float8_e4m3fn) + x = x / (per_tensor_scale.to(x.dtype) * scaled_block_scales_fp8.to(x.dtype)).unsqueeze(-1) + + x = x.view(orig_shape).nan_to_num() + data_lp = stochastic_float_to_fp4_e2m1(x, generator=generator) + return data_lp, scaled_block_scales_fp8 + + +def stochastic_round_quantize_nvfp4(x, per_tensor_scale, pad_16x, seed=0): + def roundup(x: int, multiple: int) -> int: + """Round up x to the nearest multiple.""" + return ((x + multiple - 1) // multiple) * multiple + + generator = torch.Generator(device=x.device) + generator.manual_seed(seed) + + # Handle padding + if pad_16x: + rows, cols = x.shape + padded_rows = roundup(rows, 16) + padded_cols = roundup(cols, 16) + if padded_rows != rows or padded_cols != cols: + x = torch.nn.functional.pad(x, (0, padded_cols - cols, 0, padded_rows - rows)) + + x, blocked_scaled = stochastic_round_quantize_nvfp4_block(x, per_tensor_scale, generator) + return x, to_blocked(blocked_scaled, flatten=False) + + +def stochastic_round_quantize_nvfp4_by_block(x, per_tensor_scale, pad_16x, seed=0, block_size=4096 * 4096): def roundup(x: int, multiple: int) -> int: """Round up x to the nearest 
multiple.""" return ((x + multiple - 1) // multiple) * multiple @@ -158,16 +192,20 @@ def stochastic_round_quantize_nvfp4(x, per_tensor_scale, pad_16x, seed=0): # what we want to produce. If we pad here, we want the padded output. orig_shape = x.shape - block_size = 16 + orig_shape = list(orig_shape) - x = x.reshape(orig_shape[0], -1, block_size) - scaled_block_scales_fp8 = torch.clamp(((torch.amax(torch.abs(x), dim=-1)) / F4_E2M1_MAX) / per_tensor_scale.to(x.dtype), max=F8_E4M3_MAX).to(torch.float8_e4m3fn) - x /= (per_tensor_scale.to(x.dtype) * scaled_block_scales_fp8.to(x.dtype)).unsqueeze(-1) + output_fp4 = torch.empty(orig_shape[:-1] + [orig_shape[-1] // 2], dtype=torch.uint8, device=x.device) + output_block = torch.empty(orig_shape[:-1] + [orig_shape[-1] // 16], dtype=torch.float8_e4m3fn, device=x.device) generator = torch.Generator(device=x.device) generator.manual_seed(seed) - x = x.view(orig_shape).nan_to_num() - data_lp = stochastic_float_to_fp4_e2m1(x, generator=generator) - blocked_scales = to_blocked(scaled_block_scales_fp8, flatten=False) - return data_lp, blocked_scales + num_slices = max(1, (x.numel() / block_size)) + slice_size = max(1, (round(x.shape[0] / num_slices))) + + for i in range(0, x.shape[0], slice_size): + fp4, block = stochastic_round_quantize_nvfp4_block(x[i: i + slice_size], per_tensor_scale, generator=generator) + output_fp4[i:i + slice_size].copy_(fp4) + output_block[i:i + slice_size].copy_(block) + + return output_fp4, to_blocked(output_block, flatten=False) diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py index 7a61203c3..15a4f457b 100644 --- a/comfy/quant_ops.py +++ b/comfy/quant_ops.py @@ -104,7 +104,7 @@ class TensorCoreNVFP4Layout(_CKNvfp4Layout): needs_padding = padded_shape != orig_shape if stochastic_rounding > 0: - qdata, block_scale = comfy.float.stochastic_round_quantize_nvfp4(tensor, scale, pad_16x=needs_padding, seed=stochastic_rounding) + qdata, block_scale = comfy.float.stochastic_round_quantize_nvfp4_by_block(tensor, scale, pad_16x=needs_padding, seed=stochastic_rounding) else: qdata, block_scale = ck.quantize_nvfp4(tensor, scale, pad_16x=needs_padding) diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 1bf54f13f..2c4c6b8fc 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -1042,7 +1042,7 @@ class ZImage(Lumina2): "shift": 3.0, } - memory_usage_factor = 2.0 + memory_usage_factor = 2.8 supported_inference_dtypes = [torch.bfloat16, torch.float32] From d1504404662dfce6e401422701c2a7e24057b1b5 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 14 Jan 2026 10:54:50 -0800 Subject: [PATCH 136/181] Fix VAELoader (#11880) --- nodes.py | 1 + 1 file changed, 1 insertion(+) diff --git a/nodes.py b/nodes.py index 90c5f2a6e..aa8572446 100644 --- a/nodes.py +++ b/nodes.py @@ -788,6 +788,7 @@ class VAELoader: #TODO: scale factor? 
def load_vae(self, vae_name): + metadata = None if vae_name == "pixel_space": sd = {} sd["pixel_space_vae"] = torch.tensor(1.0) From 07f2462eae7fa2daa34971dd1b15fd525686e958 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Wed, 14 Jan 2026 21:25:38 +0200 Subject: [PATCH 137/181] feat(api-nodes): add Meshy 3D nodes (#11843) * feat(api-nodes): add Meshy 3D nodes * rebased, added JSONata price badges --- comfy_api_nodes/apis/meshy.py | 160 +++++ comfy_api_nodes/nodes_meshy.py | 790 +++++++++++++++++++++++++ comfy_api_nodes/util/upload_helpers.py | 23 +- nodes.py | 1 + 4 files changed, 969 insertions(+), 5 deletions(-) create mode 100644 comfy_api_nodes/apis/meshy.py create mode 100644 comfy_api_nodes/nodes_meshy.py diff --git a/comfy_api_nodes/apis/meshy.py b/comfy_api_nodes/apis/meshy.py new file mode 100644 index 000000000..be46d0d58 --- /dev/null +++ b/comfy_api_nodes/apis/meshy.py @@ -0,0 +1,160 @@ +from typing import TypedDict + +from pydantic import BaseModel, Field + +from comfy_api.latest import Input + + +class InputShouldRemesh(TypedDict): + should_remesh: str + topology: str + target_polycount: int + + +class InputShouldTexture(TypedDict): + should_texture: str + enable_pbr: bool + texture_prompt: str + texture_image: Input.Image | None + + +class MeshyTaskResponse(BaseModel): + result: str = Field(...) + + +class MeshyTextToModelRequest(BaseModel): + mode: str = Field("preview") + prompt: str = Field(..., max_length=600) + art_style: str = Field(..., description="'realistic' or 'sculpture'") + ai_model: str = Field(...) + topology: str | None = Field(..., description="'quad' or 'triangle'") + target_polycount: int | None = Field(..., ge=100, le=300000) + should_remesh: bool = Field( + True, + description="False returns the original mesh, ignoring topology and polycount.", + ) + symmetry_mode: str = Field(..., description="'auto', 'off' or 'on'") + pose_mode: str = Field(...) + seed: int = Field(...) + moderation: bool = Field(False) + + +class MeshyRefineTask(BaseModel): + mode: str = Field("refine") + preview_task_id: str = Field(...) + enable_pbr: bool | None = Field(...) + texture_prompt: str | None = Field(...) + texture_image_url: str | None = Field(...) + ai_model: str = Field(...) + moderation: bool = Field(False) + + +class MeshyImageToModelRequest(BaseModel): + image_url: str = Field(...) + ai_model: str = Field(...) + topology: str | None = Field(..., description="'quad' or 'triangle'") + target_polycount: int | None = Field(..., ge=100, le=300000) + symmetry_mode: str = Field(..., description="'auto', 'off' or 'on'") + should_remesh: bool = Field( + True, + description="False returns the original mesh, ignoring topology and polycount.", + ) + should_texture: bool = Field(...) + enable_pbr: bool | None = Field(...) + pose_mode: str = Field(...) + texture_prompt: str | None = Field(None, max_length=600) + texture_image_url: str | None = Field(None) + seed: int = Field(...) + moderation: bool = Field(False) + + +class MeshyMultiImageToModelRequest(BaseModel): + image_urls: list[str] = Field(...) + ai_model: str = Field(...) + topology: str | None = Field(..., description="'quad' or 'triangle'") + target_polycount: int | None = Field(..., ge=100, le=300000) + symmetry_mode: str = Field(..., description="'auto', 'off' or 'on'") + should_remesh: bool = Field( + True, + description="False returns the original mesh, ignoring topology and polycount.", + ) + should_texture: bool = Field(...) + enable_pbr: bool | None = Field(...) 
+ pose_mode: str = Field(...) + texture_prompt: str | None = Field(None, max_length=600) + texture_image_url: str | None = Field(None) + seed: int = Field(...) + moderation: bool = Field(False) + + +class MeshyRiggingRequest(BaseModel): + input_task_id: str = Field(...) + height_meters: float = Field(...) + texture_image_url: str | None = Field(...) + + +class MeshyAnimationRequest(BaseModel): + rig_task_id: str = Field(...) + action_id: int = Field(...) + + +class MeshyTextureRequest(BaseModel): + input_task_id: str = Field(...) + ai_model: str = Field(...) + enable_original_uv: bool = Field(...) + enable_pbr: bool = Field(...) + text_style_prompt: str | None = Field(...) + image_style_url: str | None = Field(...) + + +class MeshyModelsUrls(BaseModel): + glb: str = Field("") + + +class MeshyRiggedModelsUrls(BaseModel): + rigged_character_glb_url: str = Field("") + + +class MeshyAnimatedModelsUrls(BaseModel): + animation_glb_url: str = Field("") + + +class MeshyResultTextureUrls(BaseModel): + base_color: str = Field(...) + metallic: str | None = Field(None) + normal: str | None = Field(None) + roughness: str | None = Field(None) + + +class MeshyTaskError(BaseModel): + message: str | None = Field(None) + + +class MeshyModelResult(BaseModel): + id: str = Field(...) + type: str = Field(...) + model_urls: MeshyModelsUrls = Field(MeshyModelsUrls()) + thumbnail_url: str = Field(...) + video_url: str | None = Field(None) + status: str = Field(...) + progress: int = Field(0) + texture_urls: list[MeshyResultTextureUrls] | None = Field([]) + task_error: MeshyTaskError | None = Field(None) + + +class MeshyRiggedResult(BaseModel): + id: str = Field(...) + type: str = Field(...) + status: str = Field(...) + progress: int = Field(0) + result: MeshyRiggedModelsUrls = Field(MeshyRiggedModelsUrls()) + task_error: MeshyTaskError | None = Field(None) + + +class MeshyAnimationResult(BaseModel): + id: str = Field(...) + type: str = Field(...) + status: str = Field(...) 
+ progress: int = Field(0) + result: MeshyAnimatedModelsUrls = Field(MeshyAnimatedModelsUrls()) + task_error: MeshyTaskError | None = Field(None) diff --git a/comfy_api_nodes/nodes_meshy.py b/comfy_api_nodes/nodes_meshy.py new file mode 100644 index 000000000..740607983 --- /dev/null +++ b/comfy_api_nodes/nodes_meshy.py @@ -0,0 +1,790 @@ +import os + +from typing_extensions import override + +from comfy_api.latest import IO, ComfyExtension, Input +from comfy_api_nodes.apis.meshy import ( + InputShouldRemesh, + InputShouldTexture, + MeshyAnimationRequest, + MeshyAnimationResult, + MeshyImageToModelRequest, + MeshyModelResult, + MeshyMultiImageToModelRequest, + MeshyRefineTask, + MeshyRiggedResult, + MeshyRiggingRequest, + MeshyTaskResponse, + MeshyTextToModelRequest, + MeshyTextureRequest, +) +from comfy_api_nodes.util import ( + ApiEndpoint, + download_url_to_bytesio, + poll_op, + sync_op, + upload_images_to_comfyapi, + validate_string, +) +from folder_paths import get_output_directory + + +class MeshyTextToModelNode(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="MeshyTextToModelNode", + display_name="Meshy: Text to Model", + category="api node/3d/Meshy", + inputs=[ + IO.Combo.Input("model", options=["latest"]), + IO.String.Input("prompt", multiline=True, default=""), + IO.Combo.Input("style", options=["realistic", "sculpture"]), + IO.DynamicCombo.Input( + "should_remesh", + options=[ + IO.DynamicCombo.Option( + "true", + [ + IO.Combo.Input("topology", options=["triangle", "quad"]), + IO.Int.Input( + "target_polycount", + default=300000, + min=100, + max=300000, + display_mode=IO.NumberDisplay.number, + ), + ], + ), + IO.DynamicCombo.Option("false", []), + ], + tooltip="When set to false, returns an unprocessed triangular mesh.", + ), + IO.Combo.Input("symmetry_mode", options=["auto", "on", "off"]), + IO.Combo.Input( + "pose_mode", + options=["", "A-pose", "T-pose"], + tooltip="Specify the pose mode for the generated model.", + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + tooltip="Seed controls whether the node should re-run; " + "results are non-deterministic regardless of seed.", + ), + ], + outputs=[ + IO.String.Output(display_name="model_file"), + IO.Custom("MESHY_TASK_ID").Output(display_name="meshy_task_id"), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + is_output_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.8}""", + ), + ) + + @classmethod + async def execute( + cls, + model: str, + prompt: str, + style: str, + should_remesh: InputShouldRemesh, + symmetry_mode: str, + pose_mode: str, + seed: int, + ) -> IO.NodeOutput: + validate_string(prompt, field_name="prompt", min_length=1, max_length=600) + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/meshy/openapi/v2/text-to-3d", method="POST"), + response_model=MeshyTaskResponse, + data=MeshyTextToModelRequest( + prompt=prompt, + art_style=style, + ai_model=model, + topology=should_remesh.get("topology", None), + target_polycount=should_remesh.get("target_polycount", None), + should_remesh=should_remesh["should_remesh"] == "true", + symmetry_mode=symmetry_mode, + pose_mode=pose_mode.lower(), + seed=seed, + ), + ) + result = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/meshy/openapi/v2/text-to-3d/{response.result}"), + response_model=MeshyModelResult, + status_extractor=lambda 
r: r.status, + progress_extractor=lambda r: r.progress, + ) + model_file = f"meshy_model_{response.result}.glb" + await download_url_to_bytesio(result.model_urls.glb, os.path.join(get_output_directory(), model_file)) + return IO.NodeOutput(model_file, response.result) + + +class MeshyRefineNode(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="MeshyRefineNode", + display_name="Meshy: Refine Draft Model", + category="api node/3d/Meshy", + description="Refine a previously created draft model.", + inputs=[ + IO.Combo.Input("model", options=["latest"]), + IO.Custom("MESHY_TASK_ID").Input("meshy_task_id"), + IO.Boolean.Input( + "enable_pbr", + default=False, + tooltip="Generate PBR Maps (metallic, roughness, normal) in addition to the base color. " + "Note: this should be set to false when using Sculpture style, " + "as Sculpture style generates its own set of PBR maps.", + ), + IO.String.Input( + "texture_prompt", + default="", + multiline=True, + tooltip="Provide a text prompt to guide the texturing process. " + "Maximum 600 characters. Cannot be used at the same time as 'texture_image'.", + ), + IO.Image.Input( + "texture_image", + tooltip="Only one of 'texture_image' or 'texture_prompt' may be used at the same time.", + optional=True, + ), + ], + outputs=[ + IO.String.Output(display_name="model_file"), + IO.Custom("MESHY_TASK_ID").Output(display_name="meshy_task_id"), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + is_output_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.4}""", + ), + ) + + @classmethod + async def execute( + cls, + model: str, + meshy_task_id: str, + enable_pbr: bool, + texture_prompt: str, + texture_image: Input.Image | None = None, + ) -> IO.NodeOutput: + if texture_prompt and texture_image is not None: + raise ValueError("texture_prompt and texture_image cannot be used at the same time") + texture_image_url = None + if texture_prompt: + validate_string(texture_prompt, field_name="texture_prompt", max_length=600) + if texture_image is not None: + texture_image_url = (await upload_images_to_comfyapi(cls, texture_image, wait_label="Uploading texture"))[0] + response = await sync_op( + cls, + endpoint=ApiEndpoint(path="/proxy/meshy/openapi/v2/text-to-3d", method="POST"), + response_model=MeshyTaskResponse, + data=MeshyRefineTask( + preview_task_id=meshy_task_id, + enable_pbr=enable_pbr, + texture_prompt=texture_prompt if texture_prompt else None, + texture_image_url=texture_image_url, + ai_model=model, + ), + ) + result = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/meshy/openapi/v2/text-to-3d/{response.result}"), + response_model=MeshyModelResult, + status_extractor=lambda r: r.status, + progress_extractor=lambda r: r.progress, + ) + model_file = f"meshy_model_{response.result}.glb" + await download_url_to_bytesio(result.model_urls.glb, os.path.join(get_output_directory(), model_file)) + return IO.NodeOutput(model_file, response.result) + + +class MeshyImageToModelNode(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="MeshyImageToModelNode", + display_name="Meshy: Image to Model", + category="api node/3d/Meshy", + inputs=[ + IO.Combo.Input("model", options=["latest"]), + IO.Image.Input("image"), + IO.DynamicCombo.Input( + "should_remesh", + options=[ + IO.DynamicCombo.Option( + "true", + [ + IO.Combo.Input("topology", options=["triangle", "quad"]), + IO.Int.Input( + "target_polycount", 
+ default=300000, + min=100, + max=300000, + display_mode=IO.NumberDisplay.number, + ), + ], + ), + IO.DynamicCombo.Option("false", []), + ], + tooltip="When set to false, returns an unprocessed triangular mesh.", + ), + IO.Combo.Input("symmetry_mode", options=["auto", "on", "off"]), + IO.DynamicCombo.Input( + "should_texture", + options=[ + IO.DynamicCombo.Option( + "true", + [ + IO.Boolean.Input( + "enable_pbr", + default=False, + tooltip="Generate PBR Maps (metallic, roughness, normal) " + "in addition to the base color.", + ), + IO.String.Input( + "texture_prompt", + default="", + multiline=True, + tooltip="Provide a text prompt to guide the texturing process. " + "Maximum 600 characters. Cannot be used at the same time as 'texture_image'.", + ), + IO.Image.Input( + "texture_image", + tooltip="Only one of 'texture_image' or 'texture_prompt' " + "may be used at the same time.", + optional=True, + ), + ], + ), + IO.DynamicCombo.Option("false", []), + ], + tooltip="Determines whether textures are generated. " + "Setting it to false skips the texture phase and returns a mesh without textures.", + ), + IO.Combo.Input( + "pose_mode", + options=["", "A-pose", "T-pose"], + tooltip="Specify the pose mode for the generated model.", + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + tooltip="Seed controls whether the node should re-run; " + "results are non-deterministic regardless of seed.", + ), + ], + outputs=[ + IO.String.Output(display_name="model_file"), + IO.Custom("MESHY_TASK_ID").Output(display_name="meshy_task_id"), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + is_output_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["should_texture"]), + expr=""" + ( + $prices := {"true": 1.2, "false": 0.8}; + {"type":"usd","usd": $lookup($prices, widgets.should_texture)} + ) + """, + ), + ) + + @classmethod + async def execute( + cls, + model: str, + image: Input.Image, + should_remesh: InputShouldRemesh, + symmetry_mode: str, + should_texture: InputShouldTexture, + pose_mode: str, + seed: int, + ) -> IO.NodeOutput: + texture = should_texture["should_texture"] == "true" + texture_image_url = texture_prompt = None + if texture: + if should_texture["texture_prompt"] and should_texture["texture_image"] is not None: + raise ValueError("texture_prompt and texture_image cannot be used at the same time") + if should_texture["texture_prompt"]: + validate_string(should_texture["texture_prompt"], field_name="texture_prompt", max_length=600) + texture_prompt = should_texture["texture_prompt"] + if should_texture["texture_image"] is not None: + texture_image_url = ( + await upload_images_to_comfyapi( + cls, should_texture["texture_image"], wait_label="Uploading texture" + ) + )[0] + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/meshy/openapi/v1/image-to-3d", method="POST"), + response_model=MeshyTaskResponse, + data=MeshyImageToModelRequest( + image_url=(await upload_images_to_comfyapi(cls, image, wait_label="Uploading base image"))[0], + ai_model=model, + topology=should_remesh.get("topology", None), + target_polycount=should_remesh.get("target_polycount", None), + symmetry_mode=symmetry_mode, + should_remesh=should_remesh["should_remesh"] == "true", + should_texture=texture, + enable_pbr=should_texture.get("enable_pbr", None), + pose_mode=pose_mode.lower(), + 
texture_prompt=texture_prompt, + texture_image_url=texture_image_url, + seed=seed, + ), + ) + result = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/meshy/openapi/v1/image-to-3d/{response.result}"), + response_model=MeshyModelResult, + status_extractor=lambda r: r.status, + progress_extractor=lambda r: r.progress, + ) + model_file = f"meshy_model_{response.result}.glb" + await download_url_to_bytesio(result.model_urls.glb, os.path.join(get_output_directory(), model_file)) + return IO.NodeOutput(model_file, response.result) + + +class MeshyMultiImageToModelNode(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="MeshyMultiImageToModelNode", + display_name="Meshy: Multi-Image to Model", + category="api node/3d/Meshy", + inputs=[ + IO.Combo.Input("model", options=["latest"]), + IO.Autogrow.Input( + "images", + template=IO.Autogrow.TemplatePrefix(IO.Image.Input("image"), prefix="image", min=2, max=4), + ), + IO.DynamicCombo.Input( + "should_remesh", + options=[ + IO.DynamicCombo.Option( + "true", + [ + IO.Combo.Input("topology", options=["triangle", "quad"]), + IO.Int.Input( + "target_polycount", + default=300000, + min=100, + max=300000, + display_mode=IO.NumberDisplay.number, + ), + ], + ), + IO.DynamicCombo.Option("false", []), + ], + tooltip="When set to false, returns an unprocessed triangular mesh.", + ), + IO.Combo.Input("symmetry_mode", options=["auto", "on", "off"]), + IO.DynamicCombo.Input( + "should_texture", + options=[ + IO.DynamicCombo.Option( + "true", + [ + IO.Boolean.Input( + "enable_pbr", + default=False, + tooltip="Generate PBR Maps (metallic, roughness, normal) " + "in addition to the base color.", + ), + IO.String.Input( + "texture_prompt", + default="", + multiline=True, + tooltip="Provide a text prompt to guide the texturing process. " + "Maximum 600 characters. Cannot be used at the same time as 'texture_image'.", + ), + IO.Image.Input( + "texture_image", + tooltip="Only one of 'texture_image' or 'texture_prompt' " + "may be used at the same time.", + optional=True, + ), + ], + ), + IO.DynamicCombo.Option("false", []), + ], + tooltip="Determines whether textures are generated. 
" + "Setting it to false skips the texture phase and returns a mesh without textures.", + ), + IO.Combo.Input( + "pose_mode", + options=["", "A-pose", "T-pose"], + tooltip="Specify the pose mode for the generated model.", + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + tooltip="Seed controls whether the node should re-run; " + "results are non-deterministic regardless of seed.", + ), + ], + outputs=[ + IO.String.Output(display_name="model_file"), + IO.Custom("MESHY_TASK_ID").Output(display_name="meshy_task_id"), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + is_output_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["should_texture"]), + expr=""" + ( + $prices := {"true": 0.6, "false": 0.2}; + {"type":"usd","usd": $lookup($prices, widgets.should_texture)} + ) + """, + ), + ) + + @classmethod + async def execute( + cls, + model: str, + images: IO.Autogrow.Type, + should_remesh: InputShouldRemesh, + symmetry_mode: str, + should_texture: InputShouldTexture, + pose_mode: str, + seed: int, + ) -> IO.NodeOutput: + texture = should_texture["should_texture"] == "true" + texture_image_url = texture_prompt = None + if texture: + if should_texture["texture_prompt"] and should_texture["texture_image"] is not None: + raise ValueError("texture_prompt and texture_image cannot be used at the same time") + if should_texture["texture_prompt"]: + validate_string(should_texture["texture_prompt"], field_name="texture_prompt", max_length=600) + texture_prompt = should_texture["texture_prompt"] + if should_texture["texture_image"] is not None: + texture_image_url = ( + await upload_images_to_comfyapi( + cls, should_texture["texture_image"], wait_label="Uploading texture" + ) + )[0] + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/meshy/openapi/v1/multi-image-to-3d", method="POST"), + response_model=MeshyTaskResponse, + data=MeshyMultiImageToModelRequest( + image_urls=await upload_images_to_comfyapi( + cls, list(images.values()), wait_label="Uploading base images" + ), + ai_model=model, + topology=should_remesh.get("topology", None), + target_polycount=should_remesh.get("target_polycount", None), + symmetry_mode=symmetry_mode, + should_remesh=should_remesh["should_remesh"] == "true", + should_texture=texture, + enable_pbr=should_texture.get("enable_pbr", None), + pose_mode=pose_mode.lower(), + texture_prompt=texture_prompt, + texture_image_url=texture_image_url, + seed=seed, + ), + ) + result = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/meshy/openapi/v1/multi-image-to-3d/{response.result}"), + response_model=MeshyModelResult, + status_extractor=lambda r: r.status, + progress_extractor=lambda r: r.progress, + ) + model_file = f"meshy_model_{response.result}.glb" + await download_url_to_bytesio(result.model_urls.glb, os.path.join(get_output_directory(), model_file)) + return IO.NodeOutput(model_file, response.result) + + +class MeshyRigModelNode(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="MeshyRigModelNode", + display_name="Meshy: Rig Model", + category="api node/3d/Meshy", + description="Provides a rigged character in standard formats. 
" + "Auto-rigging is currently not suitable for untextured meshes, non-humanoid assets, " + "or humanoid assets with unclear limb and body structure.", + inputs=[ + IO.Custom("MESHY_TASK_ID").Input("meshy_task_id"), + IO.Float.Input( + "height_meters", + min=0.1, + max=15.0, + default=1.7, + tooltip="The approximate height of the character model in meters. " + "This aids in scaling and rigging accuracy.", + ), + IO.Image.Input( + "texture_image", + tooltip="The model's UV-unwrapped base color texture image.", + optional=True, + ), + ], + outputs=[ + IO.String.Output(display_name="model_file"), + IO.Custom("MESHY_RIGGED_TASK_ID").Output(display_name="rig_task_id"), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + is_output_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.2}""", + ), + ) + + @classmethod + async def execute( + cls, + meshy_task_id: str, + height_meters: float, + texture_image: Input.Image | None = None, + ) -> IO.NodeOutput: + texture_image_url = None + if texture_image is not None: + texture_image_url = (await upload_images_to_comfyapi(cls, texture_image, wait_label="Uploading texture"))[0] + response = await sync_op( + cls, + endpoint=ApiEndpoint(path="/proxy/meshy/openapi/v1/rigging", method="POST"), + response_model=MeshyTaskResponse, + data=MeshyRiggingRequest( + input_task_id=meshy_task_id, + height_meters=height_meters, + texture_image_url=texture_image_url, + ), + ) + result = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/meshy/openapi/v1/rigging/{response.result}"), + response_model=MeshyRiggedResult, + status_extractor=lambda r: r.status, + progress_extractor=lambda r: r.progress, + ) + model_file = f"meshy_model_{response.result}.glb" + await download_url_to_bytesio( + result.result.rigged_character_glb_url, os.path.join(get_output_directory(), model_file) + ) + return IO.NodeOutput(model_file, response.result) + + +class MeshyAnimateModelNode(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="MeshyAnimateModelNode", + display_name="Meshy: Animate Model", + category="api node/3d/Meshy", + description="Apply a specific animation action to a previously rigged character.", + inputs=[ + IO.Custom("MESHY_RIGGED_TASK_ID").Input("rig_task_id"), + IO.Int.Input( + "action_id", + default=0, + min=0, + max=696, + tooltip="Visit https://docs.meshy.ai/en/api/animation-library for a list of available values.", + ), + ], + outputs=[ + IO.String.Output(display_name="model_file"), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + is_output_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.12}""", + ), + ) + + @classmethod + async def execute( + cls, + rig_task_id: str, + action_id: int, + ) -> IO.NodeOutput: + response = await sync_op( + cls, + endpoint=ApiEndpoint(path="/proxy/meshy/openapi/v1/animations", method="POST"), + response_model=MeshyTaskResponse, + data=MeshyAnimationRequest( + rig_task_id=rig_task_id, + action_id=action_id, + ), + ) + result = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/meshy/openapi/v1/animations/{response.result}"), + response_model=MeshyAnimationResult, + status_extractor=lambda r: r.status, + progress_extractor=lambda r: r.progress, + ) + model_file = f"meshy_model_{response.result}.glb" + await download_url_to_bytesio(result.result.animation_glb_url, os.path.join(get_output_directory(), 
model_file)) + return IO.NodeOutput(model_file, response.result) + + +class MeshyTextureNode(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="MeshyTextureNode", + display_name="Meshy: Texture Model", + category="api node/3d/Meshy", + inputs=[ + IO.Combo.Input("model", options=["latest"]), + IO.Custom("MESHY_TASK_ID").Input("meshy_task_id"), + IO.Boolean.Input( + "enable_original_uv", + default=True, + tooltip="Use the original UV of the model instead of generating new UVs. " + "When enabled, Meshy preserves existing textures from the uploaded model. " + "If the model has no original UV, the quality of the output might not be as good.", + ), + IO.Boolean.Input("pbr", default=False), + IO.String.Input( + "text_style_prompt", + default="", + multiline=True, + tooltip="Describe your desired texture style of the object using text. Maximum 600 characters." + "Maximum 600 characters. Cannot be used at the same time as 'image_style'.", + ), + IO.Image.Input( + "image_style", + optional=True, + tooltip="A 2d image to guide the texturing process. " + "Can not be used at the same time with 'text_style_prompt'.", + ), + ], + outputs=[ + IO.String.Output(display_name="model_file"), + IO.Custom("MODEL_TASK_ID").Output(display_name="meshy_task_id"), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + is_output_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.4}""", + ), + ) + + @classmethod + async def execute( + cls, + model: str, + meshy_task_id: str, + enable_original_uv: bool, + pbr: bool, + text_style_prompt: str, + image_style: Input.Image | None = None, + ) -> IO.NodeOutput: + if text_style_prompt and image_style is not None: + raise ValueError("text_style_prompt and image_style cannot be used at the same time") + if not text_style_prompt and image_style is None: + raise ValueError("Either text_style_prompt or image_style is required") + image_style_url = None + if image_style is not None: + image_style_url = (await upload_images_to_comfyapi(cls, image_style, wait_label="Uploading style"))[0] + response = await sync_op( + cls, + endpoint=ApiEndpoint(path="/proxy/meshy/openapi/v1/retexture", method="POST"), + response_model=MeshyTaskResponse, + data=MeshyTextureRequest( + input_task_id=meshy_task_id, + ai_model=model, + enable_original_uv=enable_original_uv, + enable_pbr=pbr, + text_style_prompt=text_style_prompt if text_style_prompt else None, + image_style_url=image_style_url, + ), + ) + result = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/meshy/openapi/v1/retexture/{response.result}"), + response_model=MeshyModelResult, + status_extractor=lambda r: r.status, + progress_extractor=lambda r: r.progress, + ) + model_file = f"meshy_model_{response.result}.glb" + await download_url_to_bytesio(result.model_urls.glb, os.path.join(get_output_directory(), model_file)) + return IO.NodeOutput(model_file, response.result) + + +class MeshyExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[IO.ComfyNode]]: + return [ + MeshyTextToModelNode, + MeshyRefineNode, + MeshyImageToModelNode, + MeshyMultiImageToModelNode, + MeshyRigModelNode, + MeshyAnimateModelNode, + MeshyTextureNode, + ] + + +async def comfy_entrypoint() -> MeshyExtension: + return MeshyExtension() diff --git a/comfy_api_nodes/util/upload_helpers.py b/comfy_api_nodes/util/upload_helpers.py index cea0d1203..2794be35c 100644 --- a/comfy_api_nodes/util/upload_helpers.py +++ 
b/comfy_api_nodes/util/upload_helpers.py @@ -43,7 +43,7 @@ class UploadResponse(BaseModel): async def upload_images_to_comfyapi( cls: type[IO.ComfyNode], - image: torch.Tensor, + image: torch.Tensor | list[torch.Tensor], *, max_images: int = 8, mime_type: str | None = None, @@ -55,15 +55,28 @@ async def upload_images_to_comfyapi( Uploads images to ComfyUI API and returns download URLs. To upload multiple images, stack them in the batch dimension first. """ + tensors: list[torch.Tensor] = [] + if isinstance(image, list): + for img in image: + is_batch = len(img.shape) > 3 + if is_batch: + tensors.extend(img[i] for i in range(img.shape[0])) + else: + tensors.append(img) + else: + is_batch = len(image.shape) > 3 + if is_batch: + tensors.extend(image[i] for i in range(image.shape[0])) + else: + tensors.append(image) + # if batched, try to upload each file if max_images is greater than 0 download_urls: list[str] = [] - is_batch = len(image.shape) > 3 - batch_len = image.shape[0] if is_batch else 1 - num_to_upload = min(batch_len, max_images) + num_to_upload = min(len(tensors), max_images) batch_start_ts = time.monotonic() for idx in range(num_to_upload): - tensor = image[idx] if is_batch else image + tensor = tensors[idx] img_io = tensor_to_bytesio(tensor, total_pixels=total_pixels, mime_type=mime_type) effective_label = wait_label diff --git a/nodes.py b/nodes.py index aa8572446..f19d5fd1c 100644 --- a/nodes.py +++ b/nodes.py @@ -2401,6 +2401,7 @@ async def init_builtin_api_nodes(): "nodes_sora.py", "nodes_topaz.py", "nodes_tripo.py", + "nodes_meshy.py", "nodes_moonvalley.py", "nodes_rodin.py", "nodes_gemini.py", From 80441eb15e807aa280fb462cbb43d14191344ba4 Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Wed, 14 Jan 2026 14:53:16 -0800 Subject: [PATCH 138/181] utils: fix lanczos grayscale upscaling (#11873) --- comfy/utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/comfy/utils.py b/comfy/utils.py index fac13f128..2e33a4258 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -929,7 +929,9 @@ def bislerp(samples, width, height): return result.to(orig_dtype) def lanczos(samples, width, height): - images = [Image.fromarray(np.clip(255. * image.movedim(0, -1).cpu().numpy(), 0, 255).astype(np.uint8)) for image in samples] + #the below API is strict and expects grayscale to be squeezed + samples = samples.squeeze(1) if samples.shape[1] == 1 else samples.movedim(1, -1) + images = [Image.fromarray(np.clip(255. 
* image.cpu().numpy(), 0, 255).astype(np.uint8)) for image in samples] images = [image.resize((width, height), resample=Image.Resampling.LANCZOS) for image in images] images = [torch.from_numpy(np.array(image).astype(np.float32) / 255.0).movedim(-1, 0) for image in images] result = torch.stack(images) From be518db5a7daa6010fb1c312c0832b9833a71d10 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jukka=20Sepp=C3=A4nen?= <40791699+kijai@users.noreply.github.com> Date: Thu, 15 Jan 2026 00:54:04 +0200 Subject: [PATCH 139/181] Remove extraneous clip missing warnings when loading LTX2 embeddings_connector weights (#11874) --- comfy/text_encoders/lt.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/comfy/text_encoders/lt.py b/comfy/text_encoders/lt.py index 776e25e97..c33c77db7 100644 --- a/comfy/text_encoders/lt.py +++ b/comfy/text_encoders/lt.py @@ -118,8 +118,9 @@ class LTXAVTEModel(torch.nn.Module): sdo = comfy.utils.state_dict_prefix_replace(sd, {"text_embedding_projection.aggregate_embed.weight": "text_embedding_projection.weight", "model.diffusion_model.video_embeddings_connector.": "video_embeddings_connector.", "model.diffusion_model.audio_embeddings_connector.": "audio_embeddings_connector."}, filter_keys=True) if len(sdo) == 0: sdo = sd - - return self.load_state_dict(sdo, strict=False) + missing, unexpected = self.load_state_dict(sdo, strict=False) + missing = [k for k in missing if not k.startswith("gemma3_12b.")] # filter out keys that belong to the main gemma model + return (missing, unexpected) def memory_estimation_function(self, token_weight_pairs, device=None): constant = 6.0 From 3b832231bb81024d80bbe31b7d7e51e07b633beb Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Thu, 15 Jan 2026 07:33:15 -0800 Subject: [PATCH 140/181] Flux2 Klein support. 
(#11890) --- comfy/sd.py | 15 +++++++-- comfy/text_encoders/flux.py | 59 +++++++++++++++++++++++++++++++++++- comfy/text_encoders/llama.py | 31 +++++++++++++++++++ 3 files changed, 102 insertions(+), 3 deletions(-) diff --git a/comfy/sd.py b/comfy/sd.py index b689c0dfc..77700dfd3 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -1014,6 +1014,7 @@ class CLIPType(Enum): KANDINSKY5 = 22 KANDINSKY5_IMAGE = 23 NEWBIE = 24 + FLUX2 = 25 def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}): @@ -1046,6 +1047,7 @@ class TEModel(Enum): QWEN3_2B = 17 GEMMA_3_12B = 18 JINA_CLIP_2 = 19 + QWEN3_8B = 20 def detect_te_model(sd): @@ -1089,6 +1091,8 @@ def detect_te_model(sd): return TEModel.QWEN3_4B elif weight.shape[0] == 2048: return TEModel.QWEN3_2B + elif weight.shape[0] == 4096: + return TEModel.QWEN3_8B if weight.shape[0] == 5120: if "model.layers.39.post_attention_layernorm.weight" in sd: return TEModel.MISTRAL3_24B @@ -1214,11 +1218,18 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip clip_target.tokenizer = comfy.text_encoders.flux.Flux2Tokenizer tokenizer_data["tekken_model"] = clip_data[0].get("tekken_model", None) elif te_model == TEModel.QWEN3_4B: - clip_target.clip = comfy.text_encoders.z_image.te(**llama_detect(clip_data)) - clip_target.tokenizer = comfy.text_encoders.z_image.ZImageTokenizer + if clip_type == CLIPType.FLUX or clip_type == CLIPType.FLUX2: + clip_target.clip = comfy.text_encoders.flux.klein_te(**llama_detect(clip_data), model_type="qwen3_4b") + clip_target.tokenizer = comfy.text_encoders.flux.KleinTokenizer + else: + clip_target.clip = comfy.text_encoders.z_image.te(**llama_detect(clip_data)) + clip_target.tokenizer = comfy.text_encoders.z_image.ZImageTokenizer elif te_model == TEModel.QWEN3_2B: clip_target.clip = comfy.text_encoders.ovis.te(**llama_detect(clip_data)) clip_target.tokenizer = comfy.text_encoders.ovis.OvisTokenizer + elif te_model == TEModel.QWEN3_8B: + clip_target.clip = comfy.text_encoders.flux.klein_te(**llama_detect(clip_data), model_type="qwen3_8b") + clip_target.tokenizer = comfy.text_encoders.flux.KleinTokenizer8B elif te_model == TEModel.JINA_CLIP_2: clip_target.clip = comfy.text_encoders.jina_clip_2.JinaClip2TextModelWrapper clip_target.tokenizer = comfy.text_encoders.jina_clip_2.JinaClip2TokenizerWrapper diff --git a/comfy/text_encoders/flux.py b/comfy/text_encoders/flux.py index 21d93d757..4075afca4 100644 --- a/comfy/text_encoders/flux.py +++ b/comfy/text_encoders/flux.py @@ -3,7 +3,7 @@ import comfy.text_encoders.t5 import comfy.text_encoders.sd3_clip import comfy.text_encoders.llama import comfy.model_management -from transformers import T5TokenizerFast, LlamaTokenizerFast +from transformers import T5TokenizerFast, LlamaTokenizerFast, Qwen2Tokenizer import torch import os import json @@ -172,3 +172,60 @@ def flux2_te(dtype_llama=None, llama_quantization_metadata=None, pruned=False): model_options["num_layers"] = 30 super().__init__(device=device, dtype=dtype, model_options=model_options) return Flux2TEModel_ + +class Qwen3Tokenizer(sd1_clip.SDTokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "qwen25_tokenizer") + super().__init__(tokenizer_path, pad_with_end=False, embedding_size=2560, embedding_key='qwen3_4b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=512, 
pad_token=151643, tokenizer_data=tokenizer_data) + +class Qwen3Tokenizer8B(sd1_clip.SDTokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "qwen25_tokenizer") + super().__init__(tokenizer_path, pad_with_end=False, embedding_size=4096, embedding_key='qwen3_8b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=512, pad_token=151643, tokenizer_data=tokenizer_data) + +class KleinTokenizer(sd1_clip.SD1Tokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}, name="qwen3_4b"): + if name == "qwen3_4b": + tokenizer = Qwen3Tokenizer + elif name == "qwen3_8b": + tokenizer = Qwen3Tokenizer8B + + super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name=name, tokenizer=tokenizer) + self.llama_template = "<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n\n\n\n\n" + + def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, **kwargs): + if llama_template is None: + llama_text = self.llama_template.format(text) + else: + llama_text = llama_template.format(text) + + tokens = super().tokenize_with_weights(llama_text, return_word_ids=return_word_ids, disable_weights=True, **kwargs) + return tokens + +class KleinTokenizer8B(KleinTokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}, name="qwen3_8b"): + super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name=name) + +class Qwen3_4BModel(sd1_clip.SDClipModel): + def __init__(self, device="cpu", layer=[9, 18, 27], layer_idx=None, dtype=None, attention_mask=True, model_options={}): + super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Qwen3_4B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options) + +class Qwen3_8BModel(sd1_clip.SDClipModel): + def __init__(self, device="cpu", layer=[9, 18, 27], layer_idx=None, dtype=None, attention_mask=True, model_options={}): + super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Qwen3_8B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options) + +def klein_te(dtype_llama=None, llama_quantization_metadata=None, model_type="qwen3_4b"): + if model_type == "qwen3_4b": + model = Qwen3_4BModel + elif model_type == "qwen3_8b": + model = Qwen3_8BModel + + class Flux2TEModel_(Flux2TEModel): + def __init__(self, device="cpu", dtype=None, model_options={}): + if llama_quantization_metadata is not None: + model_options = model_options.copy() + model_options["quantization_metadata"] = llama_quantization_metadata + if dtype_llama is not None: + dtype = dtype_llama + super().__init__(device=device, dtype=dtype, name=model_type, model_options=model_options, clip_model=model) + return Flux2TEModel_ diff --git a/comfy/text_encoders/llama.py b/comfy/text_encoders/llama.py index 76731576b..331a30f61 100644 --- a/comfy/text_encoders/llama.py +++ b/comfy/text_encoders/llama.py @@ -99,6 +99,28 @@ class Qwen3_4BConfig: rope_scale = None final_norm: bool = True +@dataclass +class Qwen3_8BConfig: + vocab_size: int = 
151936 + hidden_size: int = 4096 + intermediate_size: int = 12288 + num_hidden_layers: int = 36 + num_attention_heads: int = 32 + num_key_value_heads: int = 8 + max_position_embeddings: int = 40960 + rms_norm_eps: float = 1e-6 + rope_theta: float = 1000000.0 + transformer_type: str = "llama" + head_dim = 128 + rms_norm_add = False + mlp_activation = "silu" + qkv_bias = False + rope_dims = None + q_norm = "gemma3" + k_norm = "gemma3" + rope_scale = None + final_norm: bool = True + @dataclass class Ovis25_2BConfig: vocab_size: int = 151936 @@ -628,6 +650,15 @@ class Qwen3_4B(BaseLlama, torch.nn.Module): self.model = Llama2_(config, device=device, dtype=dtype, ops=operations) self.dtype = dtype +class Qwen3_8B(BaseLlama, torch.nn.Module): + def __init__(self, config_dict, dtype, device, operations): + super().__init__() + config = Qwen3_8BConfig(**config_dict) + self.num_layers = config.num_hidden_layers + + self.model = Llama2_(config, device=device, dtype=dtype, ops=operations) + self.dtype = dtype + class Ovis25_2B(BaseLlama, torch.nn.Module): def __init__(self, config_dict, dtype, device, operations): super().__init__() From 8f40b43e0204d5b9780f3e9618e140e929e80594 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 15 Jan 2026 10:57:35 -0500 Subject: [PATCH 141/181] ComfyUI v0.9.2 --- comfyui_version.py | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/comfyui_version.py b/comfyui_version.py index 0c9871e35..dbb57b4e5 100644 --- a/comfyui_version.py +++ b/comfyui_version.py @@ -1,3 +1,3 @@ # This file is automatically generated by the build process when version is # updated in pyproject.toml. -__version__ = "0.9.1" +__version__ = "0.9.2" diff --git a/pyproject.toml b/pyproject.toml index dc52218b4..9ea73da05 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ComfyUI" -version = "0.9.1" +version = "0.9.2" readme = "README.md" license = { file = "LICENSE" } requires-python = ">=3.10" From 12918a5f789d11c7d3c9d9f732891337740fe96f Mon Sep 17 00:00:00 2001 From: ComfyUI Wiki Date: Fri, 16 Jan 2026 03:08:21 +0800 Subject: [PATCH 142/181] chore: update workflow templates to v0.8.7 (#11896) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 8650d28ec..624cd067b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.36.14 -comfyui-workflow-templates==0.8.4 +comfyui-workflow-templates==0.8.7 comfyui-embedded-docs==0.4.0 torch torchsde From 6125b3a5e7215bf01874e402525552a7f5657a41 Mon Sep 17 00:00:00 2001 From: ComfyUI Wiki Date: Fri, 16 Jan 2026 05:12:13 +0800 Subject: [PATCH 143/181] Update workflow templates to v0.8.10 (#11899) * chore: update workflow templates to v0.8.9 * Update requirements.txt --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 624cd067b..996701550 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.36.14 -comfyui-workflow-templates==0.8.7 +comfyui-workflow-templates==0.8.10 comfyui-embedded-docs==0.4.0 torch torchsde From 4c816d5c698dafaa31f8fc2c08ab1d81f9bc3239 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Thu, 15 Jan 2026 17:06:40 -0800 Subject: [PATCH 144/181] Adjust memory usage factor calculation for flux2 klein. 
(#11900) --- comfy/supported_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 2c4c6b8fc..c8a7f6efb 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -763,7 +763,7 @@ class Flux2(Flux): def __init__(self, unet_config): super().__init__(unet_config) - self.memory_usage_factor = self.memory_usage_factor * (2.0 * 2.0) * 2.36 + self.memory_usage_factor = self.memory_usage_factor * (2.0 * 2.0) * (unet_config['hidden_size'] / 2604) def get_model(self, state_dict, prefix="", device=None): out = model_base.Flux2(self, device=device) From 732b707397922dbbec5ed04ecca3c773c878c64e Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Thu, 15 Jan 2026 20:15:15 -0800 Subject: [PATCH 145/181] Added try-except around seed_assets call in get_object_info with a logging statement (#11901) --- server.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/server.py b/server.py index da2baefd4..04a577488 100644 --- a/server.py +++ b/server.py @@ -686,7 +686,10 @@ class PromptServer(): @routes.get("/object_info") async def get_object_info(request): - seed_assets(["models"]) + try: + seed_assets(["models"]) + except Exception as e: + logging.error(f"Failed to seed assets: {e}") with folder_paths.cache_helper: out = {} for x in nodes.NODE_CLASS_MAPPINGS: From 9125613b53fc6af219d5a3db1d5b202ccc3f41b3 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Fri, 16 Jan 2026 08:09:07 +0200 Subject: [PATCH 146/181] feat(api-nodes): extend ByteDance nodes with seedance-1-5-pro model (#11871) --- comfy_api_nodes/apis/bytedance_api.py | 7 ++ comfy_api_nodes/nodes_bytedance.py | 104 +++++++++++++++++++++++--- 2 files changed, 101 insertions(+), 10 deletions(-) diff --git a/comfy_api_nodes/apis/bytedance_api.py b/comfy_api_nodes/apis/bytedance_api.py index b8c2f618b..400648cca 100644 --- a/comfy_api_nodes/apis/bytedance_api.py +++ b/comfy_api_nodes/apis/bytedance_api.py @@ -65,11 +65,13 @@ class TaskImageContent(BaseModel): class Text2VideoTaskCreationRequest(BaseModel): model: str = Field(...) content: list[TaskTextContent] = Field(..., min_length=1) + generate_audio: bool | None = Field(...) class Image2VideoTaskCreationRequest(BaseModel): model: str = Field(...) content: list[TaskTextContent | TaskImageContent] = Field(..., min_length=2) + generate_audio: bool | None = Field(...) 
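# A minimal usage sketch, assuming the request models added above: the field names
# (model, content, generate_audio) and the seedance-1-5-pro-only audio behaviour are
# taken from this patch; the example prompt and values are made up for illustration.
from comfy_api_nodes.apis.bytedance_api import TaskTextContent, Text2VideoTaskCreationRequest

payload = Text2VideoTaskCreationRequest(
    model="seedance-1-5-pro-251215",
    content=[TaskTextContent(text="a slow pan across a rainy neon street")],
    # Only seedance-1-5-pro honours this flag; the nodes pass None for every other model,
    # and the PRICE_BADGE_VIDEO expression applies a 2x cost multiplier when it is enabled.
    generate_audio=True,
)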
class TaskCreationResponse(BaseModel): @@ -141,4 +143,9 @@ VIDEO_TASKS_EXECUTION_TIME = { "720p": 65, "1080p": 100, }, + "seedance-1-5-pro-251215": { + "480p": 80, + "720p": 100, + "1080p": 150, + }, } diff --git a/comfy_api_nodes/nodes_bytedance.py b/comfy_api_nodes/nodes_bytedance.py index f09a4a0ed..9cb1ca004 100644 --- a/comfy_api_nodes/nodes_bytedance.py +++ b/comfy_api_nodes/nodes_bytedance.py @@ -477,7 +477,12 @@ class ByteDanceTextToVideoNode(IO.ComfyNode): inputs=[ IO.Combo.Input( "model", - options=["seedance-1-0-pro-250528", "seedance-1-0-lite-t2v-250428", "seedance-1-0-pro-fast-251015"], + options=[ + "seedance-1-5-pro-251215", + "seedance-1-0-pro-250528", + "seedance-1-0-lite-t2v-250428", + "seedance-1-0-pro-fast-251015", + ], default="seedance-1-0-pro-fast-251015", ), IO.String.Input( @@ -528,6 +533,12 @@ class ByteDanceTextToVideoNode(IO.ComfyNode): tooltip='Whether to add an "AI generated" watermark to the video.', optional=True, ), + IO.Boolean.Input( + "generate_audio", + default=False, + tooltip="This parameter is ignored for any model except seedance-1-5-pro.", + optional=True, + ), ], outputs=[ IO.Video.Output(), @@ -552,7 +563,10 @@ class ByteDanceTextToVideoNode(IO.ComfyNode): seed: int, camera_fixed: bool, watermark: bool, + generate_audio: bool = False, ) -> IO.NodeOutput: + if model == "seedance-1-5-pro-251215" and duration < 4: + raise ValueError("Minimum supported duration for Seedance 1.5 Pro is 4 seconds.") validate_string(prompt, strip_whitespace=True, min_length=1) raise_if_text_params(prompt, ["resolution", "ratio", "duration", "seed", "camerafixed", "watermark"]) @@ -567,7 +581,11 @@ class ByteDanceTextToVideoNode(IO.ComfyNode): ) return await process_video_task( cls, - payload=Text2VideoTaskCreationRequest(model=model, content=[TaskTextContent(text=prompt)]), + payload=Text2VideoTaskCreationRequest( + model=model, + content=[TaskTextContent(text=prompt)], + generate_audio=generate_audio if model == "seedance-1-5-pro-251215" else None, + ), estimated_duration=max(1, math.ceil(VIDEO_TASKS_EXECUTION_TIME[model][resolution] * (duration / 10.0))), ) @@ -584,7 +602,12 @@ class ByteDanceImageToVideoNode(IO.ComfyNode): inputs=[ IO.Combo.Input( "model", - options=["seedance-1-0-pro-250528", "seedance-1-0-lite-t2v-250428", "seedance-1-0-pro-fast-251015"], + options=[ + "seedance-1-5-pro-251215", + "seedance-1-0-pro-250528", + "seedance-1-0-lite-i2v-250428", + "seedance-1-0-pro-fast-251015", + ], default="seedance-1-0-pro-fast-251015", ), IO.String.Input( @@ -639,6 +662,12 @@ class ByteDanceImageToVideoNode(IO.ComfyNode): tooltip='Whether to add an "AI generated" watermark to the video.', optional=True, ), + IO.Boolean.Input( + "generate_audio", + default=False, + tooltip="This parameter is ignored for any model except seedance-1-5-pro.", + optional=True, + ), ], outputs=[ IO.Video.Output(), @@ -664,7 +693,10 @@ class ByteDanceImageToVideoNode(IO.ComfyNode): seed: int, camera_fixed: bool, watermark: bool, + generate_audio: bool = False, ) -> IO.NodeOutput: + if model == "seedance-1-5-pro-251215" and duration < 4: + raise ValueError("Minimum supported duration for Seedance 1.5 Pro is 4 seconds.") validate_string(prompt, strip_whitespace=True, min_length=1) raise_if_text_params(prompt, ["resolution", "ratio", "duration", "seed", "camerafixed", "watermark"]) validate_image_dimensions(image, min_width=300, min_height=300, max_width=6000, max_height=6000) @@ -686,6 +718,7 @@ class ByteDanceImageToVideoNode(IO.ComfyNode): payload=Image2VideoTaskCreationRequest( 
model=model, content=[TaskTextContent(text=prompt), TaskImageContent(image_url=TaskImageContentUrl(url=image_url))], + generate_audio=generate_audio if model == "seedance-1-5-pro-251215" else None, ), estimated_duration=max(1, math.ceil(VIDEO_TASKS_EXECUTION_TIME[model][resolution] * (duration / 10.0))), ) @@ -703,7 +736,7 @@ class ByteDanceFirstLastFrameNode(IO.ComfyNode): inputs=[ IO.Combo.Input( "model", - options=["seedance-1-0-pro-250528", "seedance-1-0-lite-i2v-250428"], + options=["seedance-1-5-pro-251215", "seedance-1-0-pro-250528", "seedance-1-0-lite-i2v-250428"], default="seedance-1-0-lite-i2v-250428", ), IO.String.Input( @@ -762,6 +795,12 @@ class ByteDanceFirstLastFrameNode(IO.ComfyNode): tooltip='Whether to add an "AI generated" watermark to the video.', optional=True, ), + IO.Boolean.Input( + "generate_audio", + default=False, + tooltip="This parameter is ignored for any model except seedance-1-5-pro.", + optional=True, + ), ], outputs=[ IO.Video.Output(), @@ -788,7 +827,10 @@ class ByteDanceFirstLastFrameNode(IO.ComfyNode): seed: int, camera_fixed: bool, watermark: bool, + generate_audio: bool = False, ) -> IO.NodeOutput: + if model == "seedance-1-5-pro-251215" and duration < 4: + raise ValueError("Minimum supported duration for Seedance 1.5 Pro is 4 seconds.") validate_string(prompt, strip_whitespace=True, min_length=1) raise_if_text_params(prompt, ["resolution", "ratio", "duration", "seed", "camerafixed", "watermark"]) for i in (first_frame, last_frame): @@ -821,6 +863,7 @@ class ByteDanceFirstLastFrameNode(IO.ComfyNode): TaskImageContent(image_url=TaskImageContentUrl(url=str(download_urls[0])), role="first_frame"), TaskImageContent(image_url=TaskImageContentUrl(url=str(download_urls[1])), role="last_frame"), ], + generate_audio=generate_audio if model == "seedance-1-5-pro-251215" else None, ), estimated_duration=max(1, math.ceil(VIDEO_TASKS_EXECUTION_TIME[model][resolution] * (duration / 10.0))), ) @@ -896,7 +939,41 @@ class ByteDanceImageReferenceNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, - price_badge=PRICE_BADGE_VIDEO, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["model", "duration", "resolution"]), + expr=""" + ( + $priceByModel := { + "seedance-1-0-pro": { + "480p":[0.23,0.24], + "720p":[0.51,0.56] + }, + "seedance-1-0-lite": { + "480p":[0.17,0.18], + "720p":[0.37,0.41] + } + }; + $model := widgets.model; + $modelKey := + $contains($model, "seedance-1-0-pro") ? "seedance-1-0-pro" : + "seedance-1-0-lite"; + $resolution := widgets.resolution; + $resKey := + $contains($resolution, "720") ? "720p" : + "480p"; + $modelPrices := $lookup($priceByModel, $modelKey); + $baseRange := $lookup($modelPrices, $resKey); + $min10s := $baseRange[0]; + $max10s := $baseRange[1]; + $scale := widgets.duration / 10; + $minCost := $min10s * $scale; + $maxCost := $max10s * $scale; + ($minCost = $maxCost) + ? 
{"type":"usd","usd": $minCost} + : {"type":"range_usd","min_usd": $minCost, "max_usd": $maxCost} + ) + """, + ), ) @classmethod @@ -967,10 +1044,15 @@ def raise_if_text_params(prompt: str, text_params: list[str]) -> None: PRICE_BADGE_VIDEO = IO.PriceBadge( - depends_on=IO.PriceBadgeDepends(widgets=["model", "duration", "resolution"]), + depends_on=IO.PriceBadgeDepends(widgets=["model", "duration", "resolution", "generate_audio"]), expr=""" ( $priceByModel := { + "seedance-1-5-pro": { + "480p":[0.12,0.12], + "720p":[0.26,0.26], + "1080p":[0.58,0.59] + }, "seedance-1-0-pro": { "480p":[0.23,0.24], "720p":[0.51,0.56], @@ -989,6 +1071,7 @@ PRICE_BADGE_VIDEO = IO.PriceBadge( }; $model := widgets.model; $modelKey := + $contains($model, "seedance-1-5-pro") ? "seedance-1-5-pro" : $contains($model, "seedance-1-0-pro-fast") ? "seedance-1-0-pro-fast" : $contains($model, "seedance-1-0-pro") ? "seedance-1-0-pro" : "seedance-1-0-lite"; @@ -1002,11 +1085,12 @@ PRICE_BADGE_VIDEO = IO.PriceBadge( $min10s := $baseRange[0]; $max10s := $baseRange[1]; $scale := widgets.duration / 10; - $minCost := $min10s * $scale; - $maxCost := $max10s * $scale; + $audioMultiplier := ($modelKey = "seedance-1-5-pro" and widgets.generate_audio) ? 2 : 1; + $minCost := $min10s * $scale * $audioMultiplier; + $maxCost := $max10s * $scale * $audioMultiplier; ($minCost = $maxCost) - ? {"type":"usd","usd": $minCost} - : {"type":"range_usd","min_usd": $minCost, "max_usd": $maxCost} + ? {"type":"usd","usd": $minCost, "format": { "approximate": true }} + : {"type":"range_usd","min_usd": $minCost, "max_usd": $maxCost, "format": { "approximate": true }} ) """, ) From 0c6b36c6ac1c34515cdf28f777a63074cd6d563d Mon Sep 17 00:00:00 2001 From: ComfyUI Wiki Date: Sat, 17 Jan 2026 06:22:50 +0800 Subject: [PATCH 147/181] chore: update workflow templates to v0.8.11 (#11918) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 996701550..3876274f9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.36.14 -comfyui-workflow-templates==0.8.10 +comfyui-workflow-templates==0.8.11 comfyui-embedded-docs==0.4.0 torch torchsde From 7ac999bf3069b06648a749212f59237080a75591 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Fri, 16 Jan 2026 20:02:28 -0800 Subject: [PATCH 148/181] Add image sizes to clip vision outputs. 
(#11923)
---
 comfy/clip_vision.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/comfy/clip_vision.py b/comfy/clip_vision.py
index 66f2a9d9c..b28bf636c 100644
--- a/comfy/clip_vision.py
+++ b/comfy/clip_vision.py
@@ -66,6 +66,7 @@ class ClipVisionModel():
         outputs = Output()
         outputs["last_hidden_state"] = out[0].to(comfy.model_management.intermediate_device())
         outputs["image_embeds"] = out[2].to(comfy.model_management.intermediate_device())
+        outputs["image_sizes"] = [pixel_values.shape[1:]] * pixel_values.shape[0]
         if self.return_all_hidden_states:
             all_hs = out[1].to(comfy.model_management.intermediate_device())
             outputs["penultimate_hidden_states"] = all_hs[:, -2]

From 00c775950aec5c563f532c8db08dae5e6adc24eb Mon Sep 17 00:00:00 2001
From: Alex Butler
Date: Sun, 18 Jan 2026 01:18:04 +0000
Subject: [PATCH 149/181] Update readme rdna3 nightly url (#11937)

---
 README.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/README.md b/README.md
index e25f3cda7..123cc9472 100644
--- a/README.md
+++ b/README.md
@@ -240,7 +240,7 @@ These have less hardware support than the builds above but they work on windows.
 
 RDNA 3 (RX 7000 series):
 
-```pip install --pre torch torchvision torchaudio --index-url https://rocm.nightlies.amd.com/v2/gfx110X-dgpu/```
+```pip install --pre torch torchvision torchaudio --index-url https://rocm.nightlies.amd.com/v2/gfx110X-all/```
 
 RDNA 3.5 (Strix halo/Ryzen AI Max+ 365):
 

From 0fd10ffa09588e0fc7f576ab7d0c93e97ad5fbb0 Mon Sep 17 00:00:00 2001
From: Theephop <144770658+TheephopWS@users.noreply.github.com>
Date: Sun, 18 Jan 2026 09:18:24 +0800
Subject: [PATCH 150/181] fix: use .cpu() for waveform conversion in AudioFrame creation (#11787)

---
 comfy_api/latest/_input_impl/video_types.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/comfy_api/latest/_input_impl/video_types.py b/comfy_api/latest/_input_impl/video_types.py
index ea35c6062..1405d0b81 100644
--- a/comfy_api/latest/_input_impl/video_types.py
+++ b/comfy_api/latest/_input_impl/video_types.py
@@ -374,7 +374,7 @@ class VideoFromComponents(VideoInput):
             if audio_stream and self.__components.audio:
                 waveform = self.__components.audio['waveform']
                 waveform = waveform[:, :, :math.ceil((audio_sample_rate / frame_rate) * self.__components.images.shape[0])]
-                frame = av.AudioFrame.from_ndarray(waveform.movedim(2, 1).reshape(1, -1).float().numpy(), format='flt', layout='mono' if waveform.shape[1] == 1 else 'stereo')
+                frame = av.AudioFrame.from_ndarray(waveform.movedim(2, 1).reshape(1, -1).float().cpu().numpy(), format='flt', layout='mono' if waveform.shape[1] == 1 else 'stereo')
                 frame.sample_rate = audio_sample_rate
                 frame.pts = 0
                 output.mux(audio_stream.encode(frame))

From 190c4416cce3b3b97b628935e001d796d565bfc9 Mon Sep 17 00:00:00 2001
From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com>
Date: Sat, 17 Jan 2026 18:20:35 -0800
Subject: [PATCH 151/181] Bump comfy-kitchen dependency to version 0.2.7 (#11941)

---
 requirements.txt | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/requirements.txt b/requirements.txt
index 3876274f9..622256973 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -21,7 +21,7 @@ psutil
 alembic
 SQLAlchemy
 av>=14.2.0
-comfy-kitchen>=0.2.6
+comfy-kitchen>=0.2.7
 
 #non essential dependencies:
 kornia>=0.7.1

From ac26065e6125871e2a742db6960f183fa037a75d Mon Sep 17 00:00:00 2001
From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com>
Date: Sun, 18 Jan 2026 04:52:45 +0200
Subject: [PATCH 152/181] chore(api-nodes): remove
non-used; extract model to separate files (#11927) * chore(api-nodes): remove non-used; extract model to separate files * chore(api-nodes): remove non-needed prefix in filenames --- comfy_api_nodes/README.md | 65 ---- comfy_api_nodes/apis/{bfl_api.py => bfl.py} | 0 .../apis/{bytedance_api.py => bytedance.py} | 0 .../apis/{gemini_api.py => gemini.py} | 0 comfy_api_nodes/apis/ideogram.py | 292 +++++++++++++++++ .../apis/{kling_api.py => kling.py} | 0 comfy_api_nodes/apis/{luma_api.py => luma.py} | 0 .../apis/{minimax_api.py => minimax.py} | 0 comfy_api_nodes/apis/moonvalley.py | 152 +++++++++ comfy_api_nodes/apis/openai.py | 170 ++++++++++ comfy_api_nodes/apis/openai_api.py | 52 --- .../apis/{pixverse_api.py => pixverse.py} | 0 .../apis/{recraft_api.py => recraft.py} | 0 .../apis/{rodin_api.py => rodin.py} | 0 comfy_api_nodes/apis/runway.py | 127 ++++++++ .../apis/{stability_api.py => stability.py} | 0 .../apis/{topaz_api.py => topaz.py} | 4 +- .../apis/{tripo_api.py => tripo.py} | 0 comfy_api_nodes/apis/{veo_api.py => veo.py} | 0 comfy_api_nodes/mapper_utils.py | 116 ------- comfy_api_nodes/nodes_bfl.py | 2 +- comfy_api_nodes/nodes_bytedance.py | 2 +- comfy_api_nodes/nodes_gemini.py | 2 +- comfy_api_nodes/nodes_ideogram.py | 2 +- comfy_api_nodes/nodes_kling.py | 2 +- comfy_api_nodes/nodes_luma.py | 2 +- comfy_api_nodes/nodes_minimax.py | 2 +- comfy_api_nodes/nodes_moonvalley.py | 2 +- comfy_api_nodes/nodes_openai.py | 86 ++--- comfy_api_nodes/nodes_pixverse.py | 2 +- comfy_api_nodes/nodes_recraft.py | 2 +- comfy_api_nodes/nodes_rodin.py | 2 +- comfy_api_nodes/nodes_runway.py | 2 +- comfy_api_nodes/nodes_stability.py | 2 +- comfy_api_nodes/nodes_topaz.py | 55 ++-- comfy_api_nodes/nodes_tripo.py | 2 +- comfy_api_nodes/nodes_veo2.py | 2 +- comfy_api_nodes/redocly-dev.yaml | 10 - comfy_api_nodes/redocly.yaml | 10 - .../comfy_api_nodes_test/mapper_utils_test.py | 297 ------------------ 40 files changed, 825 insertions(+), 641 deletions(-) delete mode 100644 comfy_api_nodes/README.md rename comfy_api_nodes/apis/{bfl_api.py => bfl.py} (100%) rename comfy_api_nodes/apis/{bytedance_api.py => bytedance.py} (100%) rename comfy_api_nodes/apis/{gemini_api.py => gemini.py} (100%) create mode 100644 comfy_api_nodes/apis/ideogram.py rename comfy_api_nodes/apis/{kling_api.py => kling.py} (100%) rename comfy_api_nodes/apis/{luma_api.py => luma.py} (100%) rename comfy_api_nodes/apis/{minimax_api.py => minimax.py} (100%) create mode 100644 comfy_api_nodes/apis/moonvalley.py create mode 100644 comfy_api_nodes/apis/openai.py delete mode 100644 comfy_api_nodes/apis/openai_api.py rename comfy_api_nodes/apis/{pixverse_api.py => pixverse.py} (100%) rename comfy_api_nodes/apis/{recraft_api.py => recraft.py} (100%) rename comfy_api_nodes/apis/{rodin_api.py => rodin.py} (100%) create mode 100644 comfy_api_nodes/apis/runway.py rename comfy_api_nodes/apis/{stability_api.py => stability.py} (100%) rename comfy_api_nodes/apis/{topaz_api.py => topaz.py} (97%) rename comfy_api_nodes/apis/{tripo_api.py => tripo.py} (100%) rename comfy_api_nodes/apis/{veo_api.py => veo.py} (100%) delete mode 100644 comfy_api_nodes/mapper_utils.py delete mode 100644 comfy_api_nodes/redocly-dev.yaml delete mode 100644 comfy_api_nodes/redocly.yaml delete mode 100644 tests-unit/comfy_api_nodes_test/mapper_utils_test.py diff --git a/comfy_api_nodes/README.md b/comfy_api_nodes/README.md deleted file mode 100644 index f56d6c860..000000000 --- a/comfy_api_nodes/README.md +++ /dev/null @@ -1,65 +0,0 @@ -# ComfyUI API Nodes - -## Introduction - -Below 
are a collection of nodes that work by calling external APIs. More information available in our [docs](https://docs.comfy.org/tutorials/api-nodes/overview). - -## Development - -While developing, you should be testing against the Staging environment. To test against staging: - -**Install ComfyUI_frontend** - -Follow the instructions [here](https://github.com/Comfy-Org/ComfyUI_frontend) to start the frontend server. By default, it will connect to Staging authentication. - -> **Hint:** If you use --front-end-version argument for ComfyUI, it will use production authentication. - -```bash -python run main.py --comfy-api-base https://stagingapi.comfy.org -``` - -To authenticate to staging, please login and then ask one of Comfy Org team to whitelist you for access to staging. - -API stubs are generated through automatic codegen tools from OpenAPI definitions. Since the Comfy Org OpenAPI definition contains many things from the Comfy Registry as well, we use redocly/cli to filter out only the paths relevant for API nodes. - -### Redocly Instructions - -**Tip** -When developing locally, use the `redocly-dev.yaml` file to generate pydantic models. This lets you use stubs for APIs that are not marked `Released` yet. - -Before your API node PR merges, make sure to add the `Released` tag to the `openapi.yaml` file and test in staging. - -```bash -# Download the OpenAPI file from staging server. -curl -o openapi.yaml https://stagingapi.comfy.org/openapi - -# Filter out unneeded API definitions. -npm install -g @redocly/cli -redocly bundle openapi.yaml --output filtered-openapi.yaml --config comfy_api_nodes/redocly-dev.yaml --remove-unused-components - -# Generate the pydantic datamodels for validation. -datamodel-codegen --use-subclass-enum --field-constraints --strict-types bytes --input filtered-openapi.yaml --output comfy_api_nodes/apis/__init__.py --output-model-type pydantic_v2.BaseModel - -``` - - -# Merging to Master - -Before merging to comfyanonymous/ComfyUI master, follow these steps: - -1. Add the "Released" tag to the ComfyUI OpenAPI yaml file for each endpoint you are using in the nodes. -1. Make sure the ComfyUI API is deployed to prod with your changes. -1. Run the code generation again with `redocly.yaml` and the production OpenAPI yaml file. - -```bash -# Download the OpenAPI file from prod server. -curl -o openapi.yaml https://api.comfy.org/openapi - -# Filter out unneeded API definitions. -npm install -g @redocly/cli -redocly bundle openapi.yaml --output filtered-openapi.yaml --config comfy_api_nodes/redocly.yaml --remove-unused-components - -# Generate the pydantic datamodels for validation. 
-datamodel-codegen --use-subclass-enum --field-constraints --strict-types bytes --input filtered-openapi.yaml --output comfy_api_nodes/apis/__init__.py --output-model-type pydantic_v2.BaseModel - -``` diff --git a/comfy_api_nodes/apis/bfl_api.py b/comfy_api_nodes/apis/bfl.py similarity index 100% rename from comfy_api_nodes/apis/bfl_api.py rename to comfy_api_nodes/apis/bfl.py diff --git a/comfy_api_nodes/apis/bytedance_api.py b/comfy_api_nodes/apis/bytedance.py similarity index 100% rename from comfy_api_nodes/apis/bytedance_api.py rename to comfy_api_nodes/apis/bytedance.py diff --git a/comfy_api_nodes/apis/gemini_api.py b/comfy_api_nodes/apis/gemini.py similarity index 100% rename from comfy_api_nodes/apis/gemini_api.py rename to comfy_api_nodes/apis/gemini.py diff --git a/comfy_api_nodes/apis/ideogram.py b/comfy_api_nodes/apis/ideogram.py new file mode 100644 index 000000000..737e18e3b --- /dev/null +++ b/comfy_api_nodes/apis/ideogram.py @@ -0,0 +1,292 @@ +from enum import Enum +from typing import Optional, List, Dict, Any, Union +from datetime import datetime + +from pydantic import BaseModel, Field, RootModel, StrictBytes + + +class IdeogramColorPalette1(BaseModel): + name: str = Field(..., description='Name of the preset color palette') + + +class Member(BaseModel): + color: Optional[str] = Field( + None, description='Hexadecimal color code', pattern='^#[0-9A-Fa-f]{6}$' + ) + weight: Optional[float] = Field( + None, description='Optional weight for the color (0-1)', ge=0.0, le=1.0 + ) + + +class IdeogramColorPalette2(BaseModel): + members: List[Member] = Field( + ..., description='Array of color definitions with optional weights' + ) + + +class IdeogramColorPalette( + RootModel[Union[IdeogramColorPalette1, IdeogramColorPalette2]] +): + root: Union[IdeogramColorPalette1, IdeogramColorPalette2] = Field( + ..., + description='A color palette specification that can either use a preset name or explicit color definitions with weights', + ) + + +class ImageRequest(BaseModel): + aspect_ratio: Optional[str] = Field( + None, + description="Optional. The aspect ratio (e.g., 'ASPECT_16_9', 'ASPECT_1_1'). Cannot be used with resolution. Defaults to 'ASPECT_1_1' if unspecified.", + ) + color_palette: Optional[Dict[str, Any]] = Field( + None, description='Optional. Color palette object. Only for V_2, V_2_TURBO.' + ) + magic_prompt_option: Optional[str] = Field( + None, description="Optional. MagicPrompt usage ('AUTO', 'ON', 'OFF')." + ) + model: str = Field(..., description="The model used (e.g., 'V_2', 'V_2A_TURBO')") + negative_prompt: Optional[str] = Field( + None, + description='Optional. Description of what to exclude. Only for V_1, V_1_TURBO, V_2, V_2_TURBO.', + ) + num_images: Optional[int] = Field( + 1, + description='Optional. Number of images to generate (1-8). Defaults to 1.', + ge=1, + le=8, + ) + prompt: str = Field( + ..., description='Required. The prompt to use to generate the image.' + ) + resolution: Optional[str] = Field( + None, + description="Optional. Resolution (e.g., 'RESOLUTION_1024_1024'). Only for model V_2. Cannot be used with aspect_ratio.", + ) + seed: Optional[int] = Field( + None, + description='Optional. A number between 0 and 2147483647.', + ge=0, + le=2147483647, + ) + style_type: Optional[str] = Field( + None, + description="Optional. Style type ('AUTO', 'GENERAL', 'REALISTIC', 'DESIGN', 'RENDER_3D', 'ANIME'). 
Only for models V_2 and above.", + ) + + +class IdeogramGenerateRequest(BaseModel): + image_request: ImageRequest = Field( + ..., description='The image generation request parameters.' + ) + + +class Datum(BaseModel): + is_image_safe: Optional[bool] = Field( + None, description='Indicates whether the image is considered safe.' + ) + prompt: Optional[str] = Field( + None, description='The prompt used to generate this image.' + ) + resolution: Optional[str] = Field( + None, description="The resolution of the generated image (e.g., '1024x1024')." + ) + seed: Optional[int] = Field( + None, description='The seed value used for this generation.' + ) + style_type: Optional[str] = Field( + None, + description="The style type used for generation (e.g., 'REALISTIC', 'ANIME').", + ) + url: Optional[str] = Field(None, description='URL to the generated image.') + + +class IdeogramGenerateResponse(BaseModel): + created: Optional[datetime] = Field( + None, description='Timestamp when the generation was created.' + ) + data: Optional[List[Datum]] = Field( + None, description='Array of generated image information.' + ) + + +class StyleCode(RootModel[str]): + root: str = Field(..., pattern='^[0-9A-Fa-f]{8}$') + + +class Datum1(BaseModel): + is_image_safe: Optional[bool] = None + prompt: Optional[str] = None + resolution: Optional[str] = None + seed: Optional[int] = None + style_type: Optional[str] = None + url: Optional[str] = None + + +class IdeogramV3IdeogramResponse(BaseModel): + created: Optional[datetime] = None + data: Optional[List[Datum1]] = None + + +class RenderingSpeed1(str, Enum): + TURBO = 'TURBO' + DEFAULT = 'DEFAULT' + QUALITY = 'QUALITY' + + +class IdeogramV3ReframeRequest(BaseModel): + color_palette: Optional[Dict[str, Any]] = None + image: Optional[StrictBytes] = None + num_images: Optional[int] = Field(None, ge=1, le=8) + rendering_speed: Optional[RenderingSpeed1] = None + resolution: str + seed: Optional[int] = Field(None, ge=0, le=2147483647) + style_codes: Optional[List[str]] = None + style_reference_images: Optional[List[StrictBytes]] = None + + +class MagicPrompt(str, Enum): + AUTO = 'AUTO' + ON = 'ON' + OFF = 'OFF' + + +class StyleType(str, Enum): + AUTO = 'AUTO' + GENERAL = 'GENERAL' + REALISTIC = 'REALISTIC' + DESIGN = 'DESIGN' + + +class IdeogramV3RemixRequest(BaseModel): + aspect_ratio: Optional[str] = None + color_palette: Optional[Dict[str, Any]] = None + image: Optional[StrictBytes] = None + image_weight: Optional[int] = Field(50, ge=1, le=100) + magic_prompt: Optional[MagicPrompt] = None + negative_prompt: Optional[str] = None + num_images: Optional[int] = Field(None, ge=1, le=8) + prompt: str + rendering_speed: Optional[RenderingSpeed1] = None + resolution: Optional[str] = None + seed: Optional[int] = Field(None, ge=0, le=2147483647) + style_codes: Optional[List[str]] = None + style_reference_images: Optional[List[StrictBytes]] = None + style_type: Optional[StyleType] = None + + +class IdeogramV3ReplaceBackgroundRequest(BaseModel): + color_palette: Optional[Dict[str, Any]] = None + image: Optional[StrictBytes] = None + magic_prompt: Optional[MagicPrompt] = None + num_images: Optional[int] = Field(None, ge=1, le=8) + prompt: str + rendering_speed: Optional[RenderingSpeed1] = None + seed: Optional[int] = Field(None, ge=0, le=2147483647) + style_codes: Optional[List[str]] = None + style_reference_images: Optional[List[StrictBytes]] = None + + +class ColorPalette(BaseModel): + name: str = Field(..., description='Name of the color palette', examples=['PASTEL']) + + +class 
MagicPrompt2(str, Enum): + ON = 'ON' + OFF = 'OFF' + + +class StyleType1(str, Enum): + AUTO = 'AUTO' + GENERAL = 'GENERAL' + REALISTIC = 'REALISTIC' + DESIGN = 'DESIGN' + FICTION = 'FICTION' + + +class RenderingSpeed(str, Enum): + DEFAULT = 'DEFAULT' + TURBO = 'TURBO' + QUALITY = 'QUALITY' + + +class IdeogramV3EditRequest(BaseModel): + color_palette: Optional[IdeogramColorPalette] = None + image: Optional[StrictBytes] = Field( + None, + description='The image being edited (max size 10MB); only JPEG, WebP and PNG formats are supported at this time.', + ) + magic_prompt: Optional[str] = Field( + None, + description='Determine if MagicPrompt should be used in generating the request or not.', + ) + mask: Optional[StrictBytes] = Field( + None, + description='A black and white image of the same size as the image being edited (max size 10MB). Black regions in the mask should match up with the regions of the image that you would like to edit; only JPEG, WebP and PNG formats are supported at this time.', + ) + num_images: Optional[int] = Field( + None, description='The number of images to generate.' + ) + prompt: str = Field( + ..., description='The prompt used to describe the edited result.' + ) + rendering_speed: RenderingSpeed + seed: Optional[int] = Field( + None, description='Random seed. Set for reproducible generation.' + ) + style_codes: Optional[List[StyleCode]] = Field( + None, + description='A list of 8 character hexadecimal codes representing the style of the image. Cannot be used in conjunction with style_reference_images or style_type.', + ) + style_reference_images: Optional[List[StrictBytes]] = Field( + None, + description='A set of images to use as style references (maximum total size 10MB across all style references). The images should be in JPEG, PNG or WebP format.', + ) + character_reference_images: Optional[List[str]] = Field( + None, + description='Generations with character reference are subject to the character reference pricing. A set of images to use as character references (maximum total size 10MB across all character references), currently only supports 1 character reference image. The images should be in JPEG, PNG or WebP format.' + ) + character_reference_images_mask: Optional[List[str]] = Field( + None, + description='Optional masks for character reference images. When provided, must match the number of character_reference_images. Each mask should be a grayscale image of the same dimensions as the corresponding character reference image. The images should be in JPEG, PNG or WebP format.' 
+ ) + + +class IdeogramV3Request(BaseModel): + aspect_ratio: Optional[str] = Field( + None, description='Aspect ratio in format WxH', examples=['1x3'] + ) + color_palette: Optional[ColorPalette] = None + magic_prompt: Optional[MagicPrompt2] = Field( + None, description='Whether to enable magic prompt enhancement' + ) + negative_prompt: Optional[str] = Field( + None, description='Text prompt specifying what to avoid in the generation' + ) + num_images: Optional[int] = Field( + None, description='Number of images to generate', ge=1 + ) + prompt: str = Field(..., description='The text prompt for image generation') + rendering_speed: RenderingSpeed + resolution: Optional[str] = Field( + None, description='Image resolution in format WxH', examples=['1280x800'] + ) + seed: Optional[int] = Field( + None, description='Seed value for reproducible generation' + ) + style_codes: Optional[List[StyleCode]] = Field( + None, description='Array of style codes in hexadecimal format' + ) + style_reference_images: Optional[List[str]] = Field( + None, description='Array of reference image URLs or identifiers' + ) + style_type: Optional[StyleType1] = Field( + None, description='The type of style to apply' + ) + character_reference_images: Optional[List[str]] = Field( + None, + description='Generations with character reference are subject to the character reference pricing. A set of images to use as character references (maximum total size 10MB across all character references), currently only supports 1 character reference image. The images should be in JPEG, PNG or WebP format.' + ) + character_reference_images_mask: Optional[List[str]] = Field( + None, + description='Optional masks for character reference images. When provided, must match the number of character_reference_images. Each mask should be a grayscale image of the same dimensions as the corresponding character reference image. The images should be in JPEG, PNG or WebP format.' 
+ ) diff --git a/comfy_api_nodes/apis/kling_api.py b/comfy_api_nodes/apis/kling.py similarity index 100% rename from comfy_api_nodes/apis/kling_api.py rename to comfy_api_nodes/apis/kling.py diff --git a/comfy_api_nodes/apis/luma_api.py b/comfy_api_nodes/apis/luma.py similarity index 100% rename from comfy_api_nodes/apis/luma_api.py rename to comfy_api_nodes/apis/luma.py diff --git a/comfy_api_nodes/apis/minimax_api.py b/comfy_api_nodes/apis/minimax.py similarity index 100% rename from comfy_api_nodes/apis/minimax_api.py rename to comfy_api_nodes/apis/minimax.py diff --git a/comfy_api_nodes/apis/moonvalley.py b/comfy_api_nodes/apis/moonvalley.py new file mode 100644 index 000000000..7ec7a4ade --- /dev/null +++ b/comfy_api_nodes/apis/moonvalley.py @@ -0,0 +1,152 @@ +from enum import Enum +from typing import Optional, Dict, Any + +from pydantic import BaseModel, Field, StrictBytes + + +class MoonvalleyPromptResponse(BaseModel): + error: Optional[Dict[str, Any]] = None + frame_conditioning: Optional[Dict[str, Any]] = None + id: Optional[str] = None + inference_params: Optional[Dict[str, Any]] = None + meta: Optional[Dict[str, Any]] = None + model_params: Optional[Dict[str, Any]] = None + output_url: Optional[str] = None + prompt_text: Optional[str] = None + status: Optional[str] = None + + +class MoonvalleyTextToVideoInferenceParams(BaseModel): + add_quality_guidance: Optional[bool] = Field( + True, description='Whether to add quality guidance' + ) + caching_coefficient: Optional[float] = Field( + 0.3, description='Caching coefficient for optimization' + ) + caching_cooldown: Optional[int] = Field( + 3, description='Number of caching cooldown steps' + ) + caching_warmup: Optional[int] = Field( + 3, description='Number of caching warmup steps' + ) + clip_value: Optional[float] = Field( + 3, description='CLIP value for generation control' + ) + conditioning_frame_index: Optional[int] = Field( + 0, description='Index of the conditioning frame' + ) + cooldown_steps: Optional[int] = Field( + 75, description='Number of cooldown steps (calculated based on num_frames)' + ) + fps: Optional[int] = Field( + 24, description='Frames per second of the generated video' + ) + guidance_scale: Optional[float] = Field( + 10, description='Guidance scale for generation control' + ) + height: Optional[int] = Field( + 1080, description='Height of the generated video in pixels' + ) + negative_prompt: Optional[str] = Field(None, description='Negative prompt text') + num_frames: Optional[int] = Field(64, description='Number of frames to generate') + seed: Optional[int] = Field( + None, description='Random seed for generation (default: random)' + ) + shift_value: Optional[float] = Field( + 3, description='Shift value for generation control' + ) + steps: Optional[int] = Field(80, description='Number of denoising steps') + use_guidance_schedule: Optional[bool] = Field( + True, description='Whether to use guidance scheduling' + ) + use_negative_prompts: Optional[bool] = Field( + False, description='Whether to use negative prompts' + ) + use_timestep_transform: Optional[bool] = Field( + True, description='Whether to use timestep transformation' + ) + warmup_steps: Optional[int] = Field( + 0, description='Number of warmup steps (calculated based on num_frames)' + ) + width: Optional[int] = Field( + 1920, description='Width of the generated video in pixels' + ) + + +class MoonvalleyTextToVideoRequest(BaseModel): + image_url: Optional[str] = None + inference_params: Optional[MoonvalleyTextToVideoInferenceParams] = None + 
prompt_text: Optional[str] = None + webhook_url: Optional[str] = None + + +class MoonvalleyUploadFileRequest(BaseModel): + file: Optional[StrictBytes] = None + + +class MoonvalleyUploadFileResponse(BaseModel): + access_url: Optional[str] = None + + +class MoonvalleyVideoToVideoInferenceParams(BaseModel): + add_quality_guidance: Optional[bool] = Field( + True, description='Whether to add quality guidance' + ) + caching_coefficient: Optional[float] = Field( + 0.3, description='Caching coefficient for optimization' + ) + caching_cooldown: Optional[int] = Field( + 3, description='Number of caching cooldown steps' + ) + caching_warmup: Optional[int] = Field( + 3, description='Number of caching warmup steps' + ) + clip_value: Optional[float] = Field( + 3, description='CLIP value for generation control' + ) + conditioning_frame_index: Optional[int] = Field( + 0, description='Index of the conditioning frame' + ) + cooldown_steps: Optional[int] = Field( + 36, description='Number of cooldown steps (calculated based on num_frames)' + ) + guidance_scale: Optional[float] = Field( + 15, description='Guidance scale for generation control' + ) + negative_prompt: Optional[str] = Field(None, description='Negative prompt text') + seed: Optional[int] = Field( + None, description='Random seed for generation (default: random)' + ) + shift_value: Optional[float] = Field( + 3, description='Shift value for generation control' + ) + steps: Optional[int] = Field(80, description='Number of denoising steps') + use_guidance_schedule: Optional[bool] = Field( + True, description='Whether to use guidance scheduling' + ) + use_negative_prompts: Optional[bool] = Field( + False, description='Whether to use negative prompts' + ) + use_timestep_transform: Optional[bool] = Field( + True, description='Whether to use timestep transformation' + ) + warmup_steps: Optional[int] = Field( + 24, description='Number of warmup steps (calculated based on num_frames)' + ) + + +class ControlType(str, Enum): + motion_control = 'motion_control' + pose_control = 'pose_control' + + +class MoonvalleyVideoToVideoRequest(BaseModel): + control_type: ControlType = Field( + ..., description='Supported types for video control' + ) + inference_params: Optional[MoonvalleyVideoToVideoInferenceParams] = None + prompt_text: str = Field(..., description='Describes the video to generate') + video_url: str = Field(..., description='Url to control video') + webhook_url: Optional[str] = Field( + None, description='Optional webhook URL for notifications' + ) diff --git a/comfy_api_nodes/apis/openai.py b/comfy_api_nodes/apis/openai.py new file mode 100644 index 000000000..b85ef252b --- /dev/null +++ b/comfy_api_nodes/apis/openai.py @@ -0,0 +1,170 @@ +from pydantic import BaseModel, Field + + +class Datum2(BaseModel): + b64_json: str | None = Field(None, description="Base64 encoded image data") + revised_prompt: str | None = Field(None, description="Revised prompt") + url: str | None = Field(None, description="URL of the image") + + +class InputTokensDetails(BaseModel): + image_tokens: int | None = Field(None) + text_tokens: int | None = Field(None) + + +class Usage(BaseModel): + input_tokens: int | None = Field(None) + input_tokens_details: InputTokensDetails | None = Field(None) + output_tokens: int | None = Field(None) + total_tokens: int | None = Field(None) + + +class OpenAIImageGenerationResponse(BaseModel): + data: list[Datum2] | None = Field(None) + usage: Usage | None = Field(None) + + +class OpenAIImageEditRequest(BaseModel): + background: str | None = 
Field(None, description="Background transparency") + model: str = Field(...) + moderation: str | None = Field(None) + n: int | None = Field(None, description="The number of images to generate") + output_compression: int | None = Field(None, description="Compression level for JPEG or WebP (0-100)") + output_format: str | None = Field(None) + prompt: str = Field(...) + quality: str | None = Field(None, description="Size of the image (e.g., 1024x1024, 1536x1024, auto)") + size: str | None = Field(None, description="Size of the output image") + + +class OpenAIImageGenerationRequest(BaseModel): + background: str | None = Field(None, description="Background transparency") + model: str | None = Field(None) + moderation: str | None = Field(None) + n: int | None = Field( + None, + description="The number of images to generate.", + ) + output_compression: int | None = Field(None, description="Compression level for JPEG or WebP (0-100)") + output_format: str | None = Field(None) + prompt: str = Field(...) + quality: str | None = Field(None, description="The quality of the generated image") + size: str | None = Field(None, description="Size of the image (e.g., 1024x1024, 1536x1024, auto)") + style: str | None = Field(None, description="Style of the image (only for dall-e-3)") + + +class ModelResponseProperties(BaseModel): + instructions: str | None = Field(None) + max_output_tokens: int | None = Field(None) + model: str | None = Field(None) + temperature: float | None = Field(1, description="Controls randomness in the response", ge=0.0, le=2.0) + top_p: float | None = Field( + 1, + description="Controls diversity of the response via nucleus sampling", + ge=0.0, + le=1.0, + ) + truncation: str | None = Field("disabled", description="Allowed values: 'auto' or 'disabled'") + + +class ResponseProperties(BaseModel): + instructions: str | None = Field(None) + max_output_tokens: int | None = Field(None) + model: str | None = Field(None) + previous_response_id: str | None = Field(None) + truncation: str | None = Field("disabled", description="Allowed values: 'auto' or 'disabled'") + + +class ResponseError(BaseModel): + code: str = Field(...) + message: str = Field(...) + + +class OutputTokensDetails(BaseModel): + reasoning_tokens: int = Field(..., description="The number of reasoning tokens.") + + +class CachedTokensDetails(BaseModel): + cached_tokens: int = Field( + ..., + description="The number of tokens that were retrieved from the cache.", + ) + + +class ResponseUsage(BaseModel): + input_tokens: int = Field(..., description="The number of input tokens.") + input_tokens_details: CachedTokensDetails = Field(...) + output_tokens: int = Field(..., description="The number of output tokens.") + output_tokens_details: OutputTokensDetails = Field(...) 
+ total_tokens: int = Field(..., description="The total number of tokens used.") + + +class InputTextContent(BaseModel): + text: str = Field(..., description="The text input to the model.") + type: str = Field("input_text") + + +class OutputContent(BaseModel): + type: str = Field(..., description="The type of output content") + text: str | None = Field(None, description="The text content") + data: str | None = Field(None, description="Base64-encoded audio data") + transcript: str | None = Field(None, description="Transcript of the audio") + + +class OutputMessage(BaseModel): + type: str = Field(..., description="The type of output item") + content: list[OutputContent] | None = Field(None, description="The content of the message") + role: str | None = Field(None, description="The role of the message") + + +class OpenAIResponse(ModelResponseProperties, ResponseProperties): + created_at: float | None = Field( + None, + description="Unix timestamp (in seconds) of when this Response was created.", + ) + error: ResponseError | None = Field(None) + id: str | None = Field(None, description="Unique identifier for this Response.") + object: str | None = Field(None, description="The object type of this resource - always set to `response`.") + output: list[OutputMessage] | None = Field(None) + parallel_tool_calls: bool | None = Field(True) + status: str | None = Field( + None, + description="One of `completed`, `failed`, `in_progress`, or `incomplete`.", + ) + usage: ResponseUsage | None = Field(None) + + +class InputImageContent(BaseModel): + detail: str = Field(..., description="One of `high`, `low`, or `auto`. Defaults to `auto`.") + file_id: str | None = Field(None) + image_url: str | None = Field(None) + type: str = Field(..., description="The type of the input item. Always `input_image`.") + + +class InputFileContent(BaseModel): + file_data: str | None = Field(None) + file_id: str | None = Field(None) + filename: str | None = Field(None, description="The name of the file to be sent to the model.") + type: str = Field(..., description="The type of the input item. Always `input_file`.") + + +class InputMessage(BaseModel): + content: list[InputTextContent | InputImageContent | InputFileContent] = Field( + ..., + description="A list of one or many input items to the model, containing different content types.", + ) + role: str | None = Field(None) + type: str | None = Field(None) + + +class OpenAICreateResponse(ModelResponseProperties, ResponseProperties): + include: str | None = Field(None) + input: list[InputMessage] = Field(...) + parallel_tool_calls: bool | None = Field( + True, description="Whether to allow the model to run tool calls in parallel." 
+ ) + store: bool | None = Field( + True, + description="Whether to store the generated model response for later retrieval via API.", + ) + stream: bool | None = Field(False) + usage: ResponseUsage | None = Field(None) diff --git a/comfy_api_nodes/apis/openai_api.py b/comfy_api_nodes/apis/openai_api.py deleted file mode 100644 index ae5bb2673..000000000 --- a/comfy_api_nodes/apis/openai_api.py +++ /dev/null @@ -1,52 +0,0 @@ -from pydantic import BaseModel, Field - - -class Datum2(BaseModel): - b64_json: str | None = Field(None, description="Base64 encoded image data") - revised_prompt: str | None = Field(None, description="Revised prompt") - url: str | None = Field(None, description="URL of the image") - - -class InputTokensDetails(BaseModel): - image_tokens: int | None = None - text_tokens: int | None = None - - -class Usage(BaseModel): - input_tokens: int | None = None - input_tokens_details: InputTokensDetails | None = None - output_tokens: int | None = None - total_tokens: int | None = None - - -class OpenAIImageGenerationResponse(BaseModel): - data: list[Datum2] | None = None - usage: Usage | None = None - - -class OpenAIImageEditRequest(BaseModel): - background: str | None = Field(None, description="Background transparency") - model: str = Field(...) - moderation: str | None = Field(None) - n: int | None = Field(None, description="The number of images to generate") - output_compression: int | None = Field(None, description="Compression level for JPEG or WebP (0-100)") - output_format: str | None = Field(None) - prompt: str = Field(...) - quality: str | None = Field(None, description="Size of the image (e.g., 1024x1024, 1536x1024, auto)") - size: str | None = Field(None, description="Size of the output image") - - -class OpenAIImageGenerationRequest(BaseModel): - background: str | None = Field(None, description="Background transparency") - model: str | None = Field(None) - moderation: str | None = Field(None) - n: int | None = Field( - None, - description="The number of images to generate.", - ) - output_compression: int | None = Field(None, description="Compression level for JPEG or WebP (0-100)") - output_format: str | None = Field(None) - prompt: str = Field(...) 
- quality: str | None = Field(None, description="The quality of the generated image") - size: str | None = Field(None, description="Size of the image (e.g., 1024x1024, 1536x1024, auto)") - style: str | None = Field(None, description="Style of the image (only for dall-e-3)") diff --git a/comfy_api_nodes/apis/pixverse_api.py b/comfy_api_nodes/apis/pixverse.py similarity index 100% rename from comfy_api_nodes/apis/pixverse_api.py rename to comfy_api_nodes/apis/pixverse.py diff --git a/comfy_api_nodes/apis/recraft_api.py b/comfy_api_nodes/apis/recraft.py similarity index 100% rename from comfy_api_nodes/apis/recraft_api.py rename to comfy_api_nodes/apis/recraft.py diff --git a/comfy_api_nodes/apis/rodin_api.py b/comfy_api_nodes/apis/rodin.py similarity index 100% rename from comfy_api_nodes/apis/rodin_api.py rename to comfy_api_nodes/apis/rodin.py diff --git a/comfy_api_nodes/apis/runway.py b/comfy_api_nodes/apis/runway.py new file mode 100644 index 000000000..df6f2b845 --- /dev/null +++ b/comfy_api_nodes/apis/runway.py @@ -0,0 +1,127 @@ +from enum import Enum +from typing import Optional, List, Union +from datetime import datetime + +from pydantic import BaseModel, Field, RootModel + + +class RunwayAspectRatioEnum(str, Enum): + field_1280_720 = '1280:720' + field_720_1280 = '720:1280' + field_1104_832 = '1104:832' + field_832_1104 = '832:1104' + field_960_960 = '960:960' + field_1584_672 = '1584:672' + field_1280_768 = '1280:768' + field_768_1280 = '768:1280' + + +class Position(str, Enum): + first = 'first' + last = 'last' + + +class RunwayPromptImageDetailedObject(BaseModel): + position: Position = Field( + ..., + description="The position of the image in the output video. 'last' is currently supported for gen3a_turbo only.", + ) + uri: str = Field( + ..., description='A HTTPS URL or data URI containing an encoded image.' + ) + + +class RunwayPromptImageObject( + RootModel[Union[str, List[RunwayPromptImageDetailedObject]]] +): + root: Union[str, List[RunwayPromptImageDetailedObject]] = Field( + ..., + description='Image(s) to use for the video generation. Can be a single URI or an array of image objects with positions.', + ) + + +class RunwayModelEnum(str, Enum): + gen4_turbo = 'gen4_turbo' + gen3a_turbo = 'gen3a_turbo' + + +class RunwayDurationEnum(int, Enum): + integer_5 = 5 + integer_10 = 10 + + +class RunwayImageToVideoRequest(BaseModel): + duration: RunwayDurationEnum + model: RunwayModelEnum + promptImage: RunwayPromptImageObject + promptText: Optional[str] = Field( + None, description='Text prompt for the generation', max_length=1000 + ) + ratio: RunwayAspectRatioEnum + seed: int = Field( + ..., description='Random seed for generation', ge=0, le=4294967295 + ) + + +class RunwayImageToVideoResponse(BaseModel): + id: Optional[str] = Field(None, description='Task ID') + + +class RunwayTaskStatusEnum(str, Enum): + SUCCEEDED = 'SUCCEEDED' + RUNNING = 'RUNNING' + FAILED = 'FAILED' + PENDING = 'PENDING' + CANCELLED = 'CANCELLED' + THROTTLED = 'THROTTLED' + + +class RunwayTaskStatusResponse(BaseModel): + createdAt: datetime = Field(..., description='Task creation timestamp') + id: str = Field(..., description='Task ID') + output: Optional[List[str]] = Field(None, description='Array of output video URLs') + progress: Optional[float] = Field( + None, + description='Float value between 0 and 1 representing the progress of the task. 
Only available if status is RUNNING.', + ge=0.0, + le=1.0, + ) + status: RunwayTaskStatusEnum + + +class Model4(str, Enum): + gen4_image = 'gen4_image' + + +class ReferenceImage(BaseModel): + uri: Optional[str] = Field( + None, description='A HTTPS URL or data URI containing an encoded image' + ) + + +class RunwayTextToImageAspectRatioEnum(str, Enum): + field_1920_1080 = '1920:1080' + field_1080_1920 = '1080:1920' + field_1024_1024 = '1024:1024' + field_1360_768 = '1360:768' + field_1080_1080 = '1080:1080' + field_1168_880 = '1168:880' + field_1440_1080 = '1440:1080' + field_1080_1440 = '1080:1440' + field_1808_768 = '1808:768' + field_2112_912 = '2112:912' + + +class RunwayTextToImageRequest(BaseModel): + model: Model4 = Field(..., description='Model to use for generation') + promptText: str = Field( + ..., description='Text prompt for the image generation', max_length=1000 + ) + ratio: RunwayTextToImageAspectRatioEnum + referenceImages: Optional[List[ReferenceImage]] = Field( + None, description='Array of reference images to guide the generation' + ) + + +class RunwayTextToImageResponse(BaseModel): + id: Optional[str] = Field(None, description='Task ID') diff --git a/comfy_api_nodes/apis/stability_api.py b/comfy_api_nodes/apis/stability.py similarity index 100% rename from comfy_api_nodes/apis/stability_api.py rename to comfy_api_nodes/apis/stability.py diff --git a/comfy_api_nodes/apis/topaz_api.py b/comfy_api_nodes/apis/topaz.py similarity index 97% rename from comfy_api_nodes/apis/topaz_api.py rename to comfy_api_nodes/apis/topaz.py index 4d9e62e72..a9e6235a7 100644 --- a/comfy_api_nodes/apis/topaz_api.py +++ b/comfy_api_nodes/apis/topaz.py @@ -41,7 +41,7 @@ class Resolution(BaseModel): height: int = Field(...) -class CreateCreateVideoRequestSource(BaseModel): +class CreateVideoRequestSource(BaseModel): container: str = Field(...) size: int = Field(..., description="Size of the video file in bytes") duration: int = Field(..., description="Duration of the video file in seconds") @@ -89,7 +89,7 @@ class Overrides(BaseModel): class CreateVideoRequest(BaseModel): - source: CreateCreateVideoRequestSource = Field(...) + source: CreateVideoRequestSource = Field(...) filters: list[Union[VideoFrameInterpolationFilter, VideoEnhancementFilter]] = Field(...) output: OutputInformationVideo = Field(...) 
overrides: Overrides = Field(Overrides(isPaidDiffusion=True)) diff --git a/comfy_api_nodes/apis/tripo_api.py b/comfy_api_nodes/apis/tripo.py similarity index 100% rename from comfy_api_nodes/apis/tripo_api.py rename to comfy_api_nodes/apis/tripo.py diff --git a/comfy_api_nodes/apis/veo_api.py b/comfy_api_nodes/apis/veo.py similarity index 100% rename from comfy_api_nodes/apis/veo_api.py rename to comfy_api_nodes/apis/veo.py diff --git a/comfy_api_nodes/mapper_utils.py b/comfy_api_nodes/mapper_utils.py deleted file mode 100644 index 6fab8f4bb..000000000 --- a/comfy_api_nodes/mapper_utils.py +++ /dev/null @@ -1,116 +0,0 @@ -from enum import Enum - -from pydantic.fields import FieldInfo -from pydantic import BaseModel -from pydantic_core import PydanticUndefined - -from comfy.comfy_types.node_typing import IO, InputTypeOptions - -NodeInput = tuple[IO, InputTypeOptions] - - -def _create_base_config(field_info: FieldInfo) -> InputTypeOptions: - config = {} - if hasattr(field_info, "default") and field_info.default is not PydanticUndefined: - config["default"] = field_info.default - if hasattr(field_info, "description") and field_info.description is not None: - config["tooltip"] = field_info.description - return config - - -def _get_number_constraints_config(field_info: FieldInfo) -> dict: - config = {} - if hasattr(field_info, "metadata"): - metadata = field_info.metadata - for constraint in metadata: - if hasattr(constraint, "ge"): - config["min"] = constraint.ge - if hasattr(constraint, "le"): - config["max"] = constraint.le - if hasattr(constraint, "multiple_of"): - config["step"] = constraint.multiple_of - return config - - -def _model_field_to_image_input(field_info: FieldInfo, **kwargs) -> NodeInput: - return IO.IMAGE, { - **_create_base_config(field_info), - **kwargs, - } - - -def _model_field_to_string_input(field_info: FieldInfo, **kwargs) -> NodeInput: - return IO.STRING, { - **_create_base_config(field_info), - **kwargs, - } - - -def _model_field_to_float_input(field_info: FieldInfo, **kwargs) -> NodeInput: - return IO.FLOAT, { - **_create_base_config(field_info), - **_get_number_constraints_config(field_info), - **kwargs, - } - - -def _model_field_to_int_input(field_info: FieldInfo, **kwargs) -> NodeInput: - return IO.INT, { - **_create_base_config(field_info), - **_get_number_constraints_config(field_info), - **kwargs, - } - - -def _model_field_to_combo_input( - field_info: FieldInfo, enum_type: type[Enum] = None, **kwargs -) -> NodeInput: - combo_config = {} - if enum_type is not None: - combo_config["options"] = [option.value for option in enum_type] - combo_config = { - **combo_config, - **_create_base_config(field_info), - **kwargs, - } - return IO.COMBO, combo_config - - -def model_field_to_node_input( - input_type: IO, base_model: type[BaseModel], field_name: str, **kwargs -) -> NodeInput: - """ - Maps a field from a Pydantic model to a Comfy node input. - - Args: - input_type: The type of the input. - base_model: The Pydantic model to map the field from. - field_name: The name of the field to map. - **kwargs: Additional key/values to include in the input options. - - Note: - For combo inputs, pass an `Enum` to the `enum_type` keyword argument to populate the options automatically. 
- - Example: - >>> model_field_to_node_input(IO.STRING, MyModel, "my_field", multiline=True) - >>> model_field_to_node_input(IO.COMBO, MyModel, "my_field", enum_type=MyEnum) - >>> model_field_to_node_input(IO.FLOAT, MyModel, "my_field", slider=True) - """ - field_info: FieldInfo = base_model.model_fields[field_name] - result: NodeInput - - if input_type == IO.IMAGE: - result = _model_field_to_image_input(field_info, **kwargs) - elif input_type == IO.STRING: - result = _model_field_to_string_input(field_info, **kwargs) - elif input_type == IO.FLOAT: - result = _model_field_to_float_input(field_info, **kwargs) - elif input_type == IO.INT: - result = _model_field_to_int_input(field_info, **kwargs) - elif input_type == IO.COMBO: - result = _model_field_to_combo_input(field_info, **kwargs) - else: - message = f"Invalid input type: {input_type}" - raise ValueError(message) - - return result diff --git a/comfy_api_nodes/nodes_bfl.py b/comfy_api_nodes/nodes_bfl.py index 76021ef7f..61c3b4503 100644 --- a/comfy_api_nodes/nodes_bfl.py +++ b/comfy_api_nodes/nodes_bfl.py @@ -3,7 +3,7 @@ from pydantic import BaseModel from typing_extensions import override from comfy_api.latest import IO, ComfyExtension, Input -from comfy_api_nodes.apis.bfl_api import ( +from comfy_api_nodes.apis.bfl import ( BFLFluxExpandImageRequest, BFLFluxFillImageRequest, BFLFluxKontextProGenerateRequest, diff --git a/comfy_api_nodes/nodes_bytedance.py b/comfy_api_nodes/nodes_bytedance.py index 9cb1ca004..486801150 100644 --- a/comfy_api_nodes/nodes_bytedance.py +++ b/comfy_api_nodes/nodes_bytedance.py @@ -5,7 +5,7 @@ import torch from typing_extensions import override from comfy_api.latest import IO, ComfyExtension, Input -from comfy_api_nodes.apis.bytedance_api import ( +from comfy_api_nodes.apis.bytedance import ( RECOMMENDED_PRESETS, RECOMMENDED_PRESETS_SEEDREAM_4, VIDEO_TASKS_EXECUTION_TIME, diff --git a/comfy_api_nodes/nodes_gemini.py b/comfy_api_nodes/nodes_gemini.py index a2daea50a..3b31caa7b 100644 --- a/comfy_api_nodes/nodes_gemini.py +++ b/comfy_api_nodes/nodes_gemini.py @@ -14,7 +14,7 @@ from typing_extensions import override import folder_paths from comfy_api.latest import IO, ComfyExtension, Input, Types -from comfy_api_nodes.apis.gemini_api import ( +from comfy_api_nodes.apis.gemini import ( GeminiContent, GeminiFileData, GeminiGenerateContentRequest, diff --git a/comfy_api_nodes/nodes_ideogram.py b/comfy_api_nodes/nodes_ideogram.py index 827b3523a..feaf7a858 100644 --- a/comfy_api_nodes/nodes_ideogram.py +++ b/comfy_api_nodes/nodes_ideogram.py @@ -4,7 +4,7 @@ from comfy_api.latest import IO, ComfyExtension from PIL import Image import numpy as np import torch -from comfy_api_nodes.apis import ( +from comfy_api_nodes.apis.ideogram import ( IdeogramGenerateRequest, IdeogramGenerateResponse, ImageRequest, diff --git a/comfy_api_nodes/nodes_kling.py b/comfy_api_nodes/nodes_kling.py index 05dde88b1..3ec71530b 100644 --- a/comfy_api_nodes/nodes_kling.py +++ b/comfy_api_nodes/nodes_kling.py @@ -49,7 +49,7 @@ from comfy_api_nodes.apis import ( KlingCharacterEffectModelName, KlingSingleImageEffectModelName, ) -from comfy_api_nodes.apis.kling_api import ( +from comfy_api_nodes.apis.kling import ( ImageToVideoWithAudioRequest, MotionControlRequest, OmniImageParamImage, diff --git a/comfy_api_nodes/nodes_luma.py b/comfy_api_nodes/nodes_luma.py index 95cb442e5..9ed6cd299 100644 --- a/comfy_api_nodes/nodes_luma.py +++ b/comfy_api_nodes/nodes_luma.py @@ -4,7 +4,7 @@ import torch from typing_extensions import override from 
comfy_api.latest import IO, ComfyExtension -from comfy_api_nodes.apis.luma_api import ( +from comfy_api_nodes.apis.luma import ( LumaAspectRatio, LumaCharacterRef, LumaConceptChain, diff --git a/comfy_api_nodes/nodes_minimax.py b/comfy_api_nodes/nodes_minimax.py index 43a15d50d..b5d0b461f 100644 --- a/comfy_api_nodes/nodes_minimax.py +++ b/comfy_api_nodes/nodes_minimax.py @@ -4,7 +4,7 @@ import torch from typing_extensions import override from comfy_api.latest import IO, ComfyExtension -from comfy_api_nodes.apis.minimax_api import ( +from comfy_api_nodes.apis.minimax import ( MinimaxFileRetrieveResponse, MiniMaxModel, MinimaxTaskResultResponse, diff --git a/comfy_api_nodes/nodes_moonvalley.py b/comfy_api_nodes/nodes_moonvalley.py index 769b171b7..08315fa2b 100644 --- a/comfy_api_nodes/nodes_moonvalley.py +++ b/comfy_api_nodes/nodes_moonvalley.py @@ -3,7 +3,7 @@ import logging from typing_extensions import override from comfy_api.latest import IO, ComfyExtension, Input -from comfy_api_nodes.apis import ( +from comfy_api_nodes.apis.moonvalley import ( MoonvalleyPromptResponse, MoonvalleyTextToVideoInferenceParams, MoonvalleyTextToVideoRequest, diff --git a/comfy_api_nodes/nodes_openai.py b/comfy_api_nodes/nodes_openai.py index 2f144c5c3..a12acc06b 100644 --- a/comfy_api_nodes/nodes_openai.py +++ b/comfy_api_nodes/nodes_openai.py @@ -10,24 +10,18 @@ from typing_extensions import override import folder_paths from comfy_api.latest import IO, ComfyExtension, Input -from comfy_api_nodes.apis import ( - CreateModelResponseProperties, - Detail, - InputContent, +from comfy_api_nodes.apis.openai import ( InputFileContent, InputImageContent, InputMessage, - InputMessageContentList, InputTextContent, - Item, + ModelResponseProperties, OpenAICreateResponse, - OpenAIResponse, - OutputContent, -) -from comfy_api_nodes.apis.openai_api import ( OpenAIImageEditRequest, OpenAIImageGenerationRequest, OpenAIImageGenerationResponse, + OpenAIResponse, + OutputContent, ) from comfy_api_nodes.util import ( ApiEndpoint, @@ -266,7 +260,7 @@ class OpenAIDalle3(IO.ComfyNode): "seed", default=0, min=0, - max=2 ** 31 - 1, + max=2**31 - 1, step=1, display_mode=IO.NumberDisplay.number, control_after_generate=True, @@ -384,7 +378,7 @@ class OpenAIGPTImage1(IO.ComfyNode): "seed", default=0, min=0, - max=2 ** 31 - 1, + max=2**31 - 1, step=1, display_mode=IO.NumberDisplay.number, control_after_generate=True, @@ -500,8 +494,8 @@ class OpenAIGPTImage1(IO.ComfyNode): files = [] batch_size = image.shape[0] for i in range(batch_size): - single_image = image[i: i + 1] - scaled_image = downscale_image_tensor(single_image, total_pixels=2048*2048).squeeze() + single_image = image[i : i + 1] + scaled_image = downscale_image_tensor(single_image, total_pixels=2048 * 2048).squeeze() image_np = (scaled_image.numpy() * 255).astype(np.uint8) img = Image.fromarray(image_np) @@ -523,7 +517,7 @@ class OpenAIGPTImage1(IO.ComfyNode): rgba_mask = torch.zeros(height, width, 4, device="cpu") rgba_mask[:, :, 3] = 1 - mask.squeeze().cpu() - scaled_mask = downscale_image_tensor(rgba_mask.unsqueeze(0), total_pixels=2048*2048).squeeze() + scaled_mask = downscale_image_tensor(rgba_mask.unsqueeze(0), total_pixels=2048 * 2048).squeeze() mask_np = (scaled_mask.numpy() * 255).astype(np.uint8) mask_img = Image.fromarray(mask_np) @@ -696,29 +690,23 @@ class OpenAIChatNode(IO.ComfyNode): ) @classmethod - def get_message_content_from_response( - cls, response: OpenAIResponse - ) -> list[OutputContent]: + def get_message_content_from_response(cls, response: 
OpenAIResponse) -> list[OutputContent]: """Extract message content from the API response.""" for output in response.output: - if output.root.type == "message": - return output.root.content + if output.type == "message": + return output.content raise TypeError("No output message found in response") @classmethod - def get_text_from_message_content( - cls, message_content: list[OutputContent] - ) -> str: + def get_text_from_message_content(cls, message_content: list[OutputContent]) -> str: """Extract text content from message content.""" for content_item in message_content: - if content_item.root.type == "output_text": - return str(content_item.root.text) + if content_item.type == "output_text": + return str(content_item.text) return "No text output found in response" @classmethod - def tensor_to_input_image_content( - cls, image: torch.Tensor, detail_level: Detail = "auto" - ) -> InputImageContent: + def tensor_to_input_image_content(cls, image: torch.Tensor, detail_level: str = "auto") -> InputImageContent: """Convert a tensor to an input image content object.""" return InputImageContent( detail=detail_level, @@ -732,9 +720,9 @@ class OpenAIChatNode(IO.ComfyNode): prompt: str, image: torch.Tensor | None = None, files: list[InputFileContent] | None = None, - ) -> InputMessageContentList: + ) -> list[InputTextContent | InputImageContent | InputFileContent]: """Create a list of input message contents from prompt and optional image.""" - content_list: list[InputContent | InputTextContent | InputImageContent | InputFileContent] = [ + content_list: list[InputTextContent | InputImageContent | InputFileContent] = [ InputTextContent(text=prompt, type="input_text"), ] if image is not None: @@ -746,13 +734,9 @@ class OpenAIChatNode(IO.ComfyNode): type="input_image", ) ) - if files is not None: content_list.extend(files) - - return InputMessageContentList( - root=content_list, - ) + return content_list @classmethod async def execute( @@ -762,7 +746,7 @@ class OpenAIChatNode(IO.ComfyNode): model: SupportedOpenAIModel = SupportedOpenAIModel.gpt_5.value, images: torch.Tensor | None = None, files: list[InputFileContent] | None = None, - advanced_options: CreateModelResponseProperties | None = None, + advanced_options: ModelResponseProperties | None = None, ) -> IO.NodeOutput: validate_string(prompt, strip_whitespace=False) @@ -773,36 +757,28 @@ class OpenAIChatNode(IO.ComfyNode): response_model=OpenAIResponse, data=OpenAICreateResponse( input=[ - Item( - root=InputMessage( - content=cls.create_input_message_contents( - prompt, images, files - ), - role="user", - ) + InputMessage( + content=cls.create_input_message_contents(prompt, images, files), + role="user", ), ], store=True, stream=False, model=model, previous_response_id=None, - **( - advanced_options.model_dump(exclude_none=True) - if advanced_options - else {} - ), + **(advanced_options.model_dump(exclude_none=True) if advanced_options else {}), ), ) response_id = create_response.id # Get result output result_response = await poll_op( - cls, - ApiEndpoint(path=f"{RESPONSES_ENDPOINT}/{response_id}"), - response_model=OpenAIResponse, - status_extractor=lambda response: response.status, - completed_statuses=["incomplete", "completed"] - ) + cls, + ApiEndpoint(path=f"{RESPONSES_ENDPOINT}/{response_id}"), + response_model=OpenAIResponse, + status_extractor=lambda response: response.status, + completed_statuses=["incomplete", "completed"], + ) return IO.NodeOutput(cls.get_text_from_message_content(cls.get_message_content_from_response(result_response))) 
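# A minimal sketch of how the flattened models imported above from
# comfy_api_nodes.apis.openai fit together, assuming OpenAICreateResponse and
# InputMessage accept the same keyword arguments used in the execute() hunk;
# the function name and the prompt value here are illustrative only.
from comfy_api_nodes.apis.openai import (
    InputMessage,
    InputTextContent,
    OpenAICreateResponse,
)

def build_minimal_chat_request(prompt: str, model: str) -> OpenAICreateResponse:
    # Content is now a plain list of typed parts; the old Item / InputMessageContentList
    # wrappers from comfy_api_nodes.apis are no longer needed.
    content = [InputTextContent(text=prompt, type="input_text")]
    return OpenAICreateResponse(
        input=[InputMessage(content=content, role="user")],
        store=True,
        stream=False,
        model=model,  # e.g. SupportedOpenAIModel.gpt_5.value
        previous_response_id=None,
    )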
@@ -923,7 +899,7 @@ class OpenAIChatConfig(IO.ComfyNode): remove depending on model choice. """ return IO.NodeOutput( - CreateModelResponseProperties( + ModelResponseProperties( instructions=instructions, truncation=truncation, max_output_tokens=max_output_tokens, diff --git a/comfy_api_nodes/nodes_pixverse.py b/comfy_api_nodes/nodes_pixverse.py index 86ddb3ab9..e17a24ae7 100644 --- a/comfy_api_nodes/nodes_pixverse.py +++ b/comfy_api_nodes/nodes_pixverse.py @@ -1,7 +1,7 @@ import torch from typing_extensions import override from comfy_api.latest import IO, ComfyExtension -from comfy_api_nodes.apis.pixverse_api import ( +from comfy_api_nodes.apis.pixverse import ( PixverseTextVideoRequest, PixverseImageVideoRequest, PixverseTransitionVideoRequest, diff --git a/comfy_api_nodes/nodes_recraft.py b/comfy_api_nodes/nodes_recraft.py index 05dc151ad..c01bcaece 100644 --- a/comfy_api_nodes/nodes_recraft.py +++ b/comfy_api_nodes/nodes_recraft.py @@ -8,7 +8,7 @@ from typing_extensions import override from comfy.utils import ProgressBar from comfy_api.latest import IO, ComfyExtension -from comfy_api_nodes.apis.recraft_api import ( +from comfy_api_nodes.apis.recraft import ( RecraftColor, RecraftColorChain, RecraftControls, diff --git a/comfy_api_nodes/nodes_rodin.py b/comfy_api_nodes/nodes_rodin.py index b4420cb93..3ffdc8b90 100644 --- a/comfy_api_nodes/nodes_rodin.py +++ b/comfy_api_nodes/nodes_rodin.py @@ -14,7 +14,7 @@ from typing import Optional from io import BytesIO from typing_extensions import override from PIL import Image -from comfy_api_nodes.apis.rodin_api import ( +from comfy_api_nodes.apis.rodin import ( Rodin3DGenerateRequest, Rodin3DGenerateResponse, Rodin3DCheckStatusRequest, diff --git a/comfy_api_nodes/nodes_runway.py b/comfy_api_nodes/nodes_runway.py index d19fdb365..573170ba2 100644 --- a/comfy_api_nodes/nodes_runway.py +++ b/comfy_api_nodes/nodes_runway.py @@ -16,7 +16,7 @@ from enum import Enum from typing_extensions import override from comfy_api.latest import IO, ComfyExtension, Input, InputImpl -from comfy_api_nodes.apis import ( +from comfy_api_nodes.apis.runway import ( RunwayImageToVideoRequest, RunwayImageToVideoResponse, RunwayTaskStatusResponse as TaskStatusResponse, diff --git a/comfy_api_nodes/nodes_stability.py b/comfy_api_nodes/nodes_stability.py index 5c48c1f1e..5665109cf 100644 --- a/comfy_api_nodes/nodes_stability.py +++ b/comfy_api_nodes/nodes_stability.py @@ -3,7 +3,7 @@ from typing import Optional from typing_extensions import override from comfy_api.latest import ComfyExtension, Input, IO -from comfy_api_nodes.apis.stability_api import ( +from comfy_api_nodes.apis.stability import ( StabilityUpscaleConservativeRequest, StabilityUpscaleCreativeRequest, StabilityAsyncResponse, diff --git a/comfy_api_nodes/nodes_topaz.py b/comfy_api_nodes/nodes_topaz.py index 9dc5f45bc..c052e7656 100644 --- a/comfy_api_nodes/nodes_topaz.py +++ b/comfy_api_nodes/nodes_topaz.py @@ -5,7 +5,24 @@ import aiohttp from typing_extensions import override from comfy_api.latest import IO, ComfyExtension, Input -from comfy_api_nodes.apis import topaz_api +from comfy_api_nodes.apis.topaz import ( + CreateVideoRequest, + CreateVideoRequestSource, + CreateVideoResponse, + ImageAsyncTaskResponse, + ImageDownloadResponse, + ImageEnhanceRequest, + ImageStatusResponse, + OutputInformationVideo, + Resolution, + VideoAcceptResponse, + VideoCompleteUploadRequest, + VideoCompleteUploadRequestPart, + VideoCompleteUploadResponse, + VideoEnhancementFilter, + VideoFrameInterpolationFilter, + 
VideoStatusResponse, +) from comfy_api_nodes.util import ( ApiEndpoint, download_url_to_image_tensor, @@ -153,13 +170,13 @@ class TopazImageEnhance(IO.ComfyNode): if get_number_of_images(image) != 1: raise ValueError("Only one input image is supported.") download_url = await upload_images_to_comfyapi( - cls, image, max_images=1, mime_type="image/png", total_pixels=4096*4096 + cls, image, max_images=1, mime_type="image/png", total_pixels=4096 * 4096 ) initial_response = await sync_op( cls, ApiEndpoint(path="/proxy/topaz/image/v1/enhance-gen/async", method="POST"), - response_model=topaz_api.ImageAsyncTaskResponse, - data=topaz_api.ImageEnhanceRequest( + response_model=ImageAsyncTaskResponse, + data=ImageEnhanceRequest( model=model, prompt=prompt, subject_detection=subject_detection, @@ -181,7 +198,7 @@ class TopazImageEnhance(IO.ComfyNode): await poll_op( cls, poll_endpoint=ApiEndpoint(path=f"/proxy/topaz/image/v1/status/{initial_response.process_id}"), - response_model=topaz_api.ImageStatusResponse, + response_model=ImageStatusResponse, status_extractor=lambda x: x.status, progress_extractor=lambda x: getattr(x, "progress", 0), price_extractor=lambda x: x.credits * 0.08, @@ -193,7 +210,7 @@ class TopazImageEnhance(IO.ComfyNode): results = await sync_op( cls, ApiEndpoint(path=f"/proxy/topaz/image/v1/download/{initial_response.process_id}"), - response_model=topaz_api.ImageDownloadResponse, + response_model=ImageDownloadResponse, monitor_progress=False, ) return IO.NodeOutput(await download_url_to_image_tensor(results.download_url)) @@ -331,7 +348,7 @@ class TopazVideoEnhance(IO.ComfyNode): if target_height % 2 != 0: target_height += 1 filters.append( - topaz_api.VideoEnhancementFilter( + VideoEnhancementFilter( model=UPSCALER_MODELS_MAP[upscaler_model], creativity=(upscaler_creativity if UPSCALER_MODELS_MAP[upscaler_model] == "slc-1" else None), isOptimizedMode=(True if UPSCALER_MODELS_MAP[upscaler_model] == "slc-1" else None), @@ -340,7 +357,7 @@ class TopazVideoEnhance(IO.ComfyNode): if interpolation_enabled: target_frame_rate = interpolation_frame_rate filters.append( - topaz_api.VideoFrameInterpolationFilter( + VideoFrameInterpolationFilter( model=interpolation_model, slowmo=interpolation_slowmo, fps=interpolation_frame_rate, @@ -351,19 +368,19 @@ class TopazVideoEnhance(IO.ComfyNode): initial_res = await sync_op( cls, ApiEndpoint(path="/proxy/topaz/video/", method="POST"), - response_model=topaz_api.CreateVideoResponse, - data=topaz_api.CreateVideoRequest( - source=topaz_api.CreateCreateVideoRequestSource( + response_model=CreateVideoResponse, + data=CreateVideoRequest( + source=CreateVideoRequestSource( container="mp4", size=get_fs_object_size(src_video_stream), duration=int(duration_sec), frameCount=video.get_frame_count(), frameRate=src_frame_rate, - resolution=topaz_api.Resolution(width=src_width, height=src_height), + resolution=Resolution(width=src_width, height=src_height), ), filters=filters, - output=topaz_api.OutputInformationVideo( - resolution=topaz_api.Resolution(width=target_width, height=target_height), + output=OutputInformationVideo( + resolution=Resolution(width=target_width, height=target_height), frameRate=target_frame_rate, audioCodec="AAC", audioTransfer="Copy", @@ -379,7 +396,7 @@ class TopazVideoEnhance(IO.ComfyNode): path=f"/proxy/topaz/video/{initial_res.requestId}/accept", method="PATCH", ), - response_model=topaz_api.VideoAcceptResponse, + response_model=VideoAcceptResponse, wait_label="Preparing upload", final_label_on_success="Upload started", ) @@ -402,10 
+419,10 @@ class TopazVideoEnhance(IO.ComfyNode): path=f"/proxy/topaz/video/{initial_res.requestId}/complete-upload", method="PATCH", ), - response_model=topaz_api.VideoCompleteUploadResponse, - data=topaz_api.VideoCompleteUploadRequest( + response_model=VideoCompleteUploadResponse, + data=VideoCompleteUploadRequest( uploadResults=[ - topaz_api.VideoCompleteUploadRequestPart( + VideoCompleteUploadRequestPart( partNum=1, eTag=upload_etag, ), @@ -417,7 +434,7 @@ class TopazVideoEnhance(IO.ComfyNode): final_response = await poll_op( cls, ApiEndpoint(path=f"/proxy/topaz/video/{initial_res.requestId}/status"), - response_model=topaz_api.VideoStatusResponse, + response_model=VideoStatusResponse, status_extractor=lambda x: x.status, progress_extractor=lambda x: getattr(x, "progress", 0), price_extractor=lambda x: (x.estimates.cost[0] * 0.08 if x.estimates and x.estimates.cost[0] else None), diff --git a/comfy_api_nodes/nodes_tripo.py b/comfy_api_nodes/nodes_tripo.py index aa790143d..5abf27b4d 100644 --- a/comfy_api_nodes/nodes_tripo.py +++ b/comfy_api_nodes/nodes_tripo.py @@ -5,7 +5,7 @@ import torch from typing_extensions import override from comfy_api.latest import IO, ComfyExtension -from comfy_api_nodes.apis.tripo_api import ( +from comfy_api_nodes.apis.tripo import ( TripoAnimateRetargetRequest, TripoAnimateRigRequest, TripoConvertModelRequest, diff --git a/comfy_api_nodes/nodes_veo2.py b/comfy_api_nodes/nodes_veo2.py index c14d6ad68..2a202fc3b 100644 --- a/comfy_api_nodes/nodes_veo2.py +++ b/comfy_api_nodes/nodes_veo2.py @@ -4,7 +4,7 @@ from io import BytesIO from typing_extensions import override from comfy_api.latest import IO, ComfyExtension, Input, InputImpl -from comfy_api_nodes.apis.veo_api import ( +from comfy_api_nodes.apis.veo import ( VeoGenVidPollRequest, VeoGenVidPollResponse, VeoGenVidRequest, diff --git a/comfy_api_nodes/redocly-dev.yaml b/comfy_api_nodes/redocly-dev.yaml deleted file mode 100644 index d9e3cab70..000000000 --- a/comfy_api_nodes/redocly-dev.yaml +++ /dev/null @@ -1,10 +0,0 @@ -# This file is used to filter the Comfy Org OpenAPI spec for schemas related to API Nodes. -# This is used for development purposes to generate stubs for unreleased API endpoints. -apis: - filter: - root: openapi.yaml - decorators: - filter-in: - property: tags - value: ['API Nodes'] - matchStrategy: all diff --git a/comfy_api_nodes/redocly.yaml b/comfy_api_nodes/redocly.yaml deleted file mode 100644 index d102345b1..000000000 --- a/comfy_api_nodes/redocly.yaml +++ /dev/null @@ -1,10 +0,0 @@ -# This file is used to filter the Comfy Org OpenAPI spec for schemas related to API Nodes. 
- -apis: - filter: - root: openapi.yaml - decorators: - filter-in: - property: tags - value: ['API Nodes', 'Released'] - matchStrategy: all diff --git a/tests-unit/comfy_api_nodes_test/mapper_utils_test.py b/tests-unit/comfy_api_nodes_test/mapper_utils_test.py deleted file mode 100644 index 69488f691..000000000 --- a/tests-unit/comfy_api_nodes_test/mapper_utils_test.py +++ /dev/null @@ -1,297 +0,0 @@ -from typing import Optional -from enum import Enum - -from pydantic import BaseModel, Field - -from comfy.comfy_types.node_typing import IO -from comfy_api_nodes.mapper_utils import model_field_to_node_input - - -def test_model_field_to_float_input(): - """Tests mapping a float field with constraints.""" - - class ModelWithFloatField(BaseModel): - cfg_scale: Optional[float] = Field( - default=0.5, - description="Flexibility in video generation", - ge=0.0, - le=1.0, - multiple_of=0.001, - ) - - expected_output = ( - IO.FLOAT, - { - "default": 0.5, - "tooltip": "Flexibility in video generation", - "min": 0.0, - "max": 1.0, - "step": 0.001, - }, - ) - - actual_output = model_field_to_node_input( - IO.FLOAT, ModelWithFloatField, "cfg_scale" - ) - - assert actual_output[0] == expected_output[0] - assert actual_output[1] == expected_output[1] - - -def test_model_field_to_float_input_no_constraints(): - """Tests mapping a float field with no constraints.""" - - class ModelWithFloatField(BaseModel): - cfg_scale: Optional[float] = Field(default=0.5) - - expected_output = ( - IO.FLOAT, - { - "default": 0.5, - }, - ) - - actual_output = model_field_to_node_input( - IO.FLOAT, ModelWithFloatField, "cfg_scale" - ) - - assert actual_output[0] == expected_output[0] - assert actual_output[1] == expected_output[1] - - -def test_model_field_to_int_input(): - """Tests mapping an int field with constraints.""" - - class ModelWithIntField(BaseModel): - num_frames: Optional[int] = Field( - default=10, - description="Number of frames to generate", - ge=1, - le=100, - multiple_of=1, - ) - - expected_output = ( - IO.INT, - { - "default": 10, - "tooltip": "Number of frames to generate", - "min": 1, - "max": 100, - "step": 1, - }, - ) - - actual_output = model_field_to_node_input(IO.INT, ModelWithIntField, "num_frames") - - assert actual_output[0] == expected_output[0] - assert actual_output[1] == expected_output[1] - - -def test_model_field_to_string_input(): - """Tests mapping a string field.""" - - class ModelWithStringField(BaseModel): - prompt: Optional[str] = Field( - default="A beautiful sunset over a calm ocean", - description="A prompt for the video generation", - ) - - expected_output = ( - IO.STRING, - { - "default": "A beautiful sunset over a calm ocean", - "tooltip": "A prompt for the video generation", - }, - ) - - actual_output = model_field_to_node_input(IO.STRING, ModelWithStringField, "prompt") - - assert actual_output[0] == expected_output[0] - assert actual_output[1] == expected_output[1] - - -def test_model_field_to_string_input_multiline(): - """Tests mapping a string field.""" - - class ModelWithStringField(BaseModel): - prompt: Optional[str] = Field( - default="A beautiful sunset over a calm ocean", - description="A prompt for the video generation", - ) - - expected_output = ( - IO.STRING, - { - "default": "A beautiful sunset over a calm ocean", - "tooltip": "A prompt for the video generation", - "multiline": True, - }, - ) - - actual_output = model_field_to_node_input( - IO.STRING, ModelWithStringField, "prompt", multiline=True - ) - - assert actual_output[0] == expected_output[0] - assert 
actual_output[1] == expected_output[1] - - -def test_model_field_to_combo_input(): - """Tests mapping a combo field.""" - - class MockEnum(str, Enum): - option_1 = "option 1" - option_2 = "option 2" - option_3 = "option 3" - - class ModelWithComboField(BaseModel): - model_name: Optional[MockEnum] = Field("option 1", description="Model Name") - - expected_output = ( - IO.COMBO, - { - "options": ["option 1", "option 2", "option 3"], - "default": "option 1", - "tooltip": "Model Name", - }, - ) - - actual_output = model_field_to_node_input( - IO.COMBO, ModelWithComboField, "model_name", enum_type=MockEnum - ) - - assert actual_output[0] == expected_output[0] - assert actual_output[1] == expected_output[1] - - -def test_model_field_to_combo_input_no_options(): - """Tests mapping a combo field with no options.""" - - class ModelWithComboField(BaseModel): - model_name: Optional[str] = Field(description="Model Name") - - expected_output = ( - IO.COMBO, - { - "tooltip": "Model Name", - }, - ) - - actual_output = model_field_to_node_input( - IO.COMBO, ModelWithComboField, "model_name" - ) - - assert actual_output[0] == expected_output[0] - assert actual_output[1] == expected_output[1] - - -def test_model_field_to_image_input(): - """Tests mapping an image field.""" - - class ModelWithImageField(BaseModel): - image: Optional[str] = Field( - default=None, - description="An image for the video generation", - ) - - expected_output = ( - IO.IMAGE, - { - "default": None, - "tooltip": "An image for the video generation", - }, - ) - - actual_output = model_field_to_node_input(IO.IMAGE, ModelWithImageField, "image") - - assert actual_output[0] == expected_output[0] - assert actual_output[1] == expected_output[1] - - -def test_model_field_to_node_input_no_description(): - """Tests mapping a field with no description.""" - - class ModelWithNoDescriptionField(BaseModel): - field: Optional[str] = Field(default="default value") - - expected_output = ( - IO.STRING, - { - "default": "default value", - }, - ) - - actual_output = model_field_to_node_input( - IO.STRING, ModelWithNoDescriptionField, "field" - ) - - assert actual_output[0] == expected_output[0] - assert actual_output[1] == expected_output[1] - - -def test_model_field_to_node_input_no_default(): - """Tests mapping a field with no default.""" - - class ModelWithNoDefaultField(BaseModel): - field: Optional[str] = Field(description="A field with no default") - - expected_output = ( - IO.STRING, - { - "tooltip": "A field with no default", - }, - ) - - actual_output = model_field_to_node_input( - IO.STRING, ModelWithNoDefaultField, "field" - ) - - assert actual_output[0] == expected_output[0] - assert actual_output[1] == expected_output[1] - - -def test_model_field_to_node_input_no_metadata(): - """Tests mapping a field with no metadata or properties defined on the schema.""" - - class ModelWithNoMetadataField(BaseModel): - field: Optional[str] = Field() - - expected_output = ( - IO.STRING, - {}, - ) - - actual_output = model_field_to_node_input( - IO.STRING, ModelWithNoMetadataField, "field" - ) - - assert actual_output[0] == expected_output[0] - assert actual_output[1] == expected_output[1] - - -def test_model_field_to_node_input_default_is_none(): - """ - Tests mapping a field with a default of `None`. - I.e., the default field should be included as the schema explicitly sets it to `None`. 
- """ - - class ModelWithNoneDefaultField(BaseModel): - field: Optional[str] = Field( - default=None, description="A field with a default of None" - ) - - expected_output = ( - IO.STRING, - { - "default": None, - "tooltip": "A field with a default of None", - }, - ) - - actual_output = model_field_to_node_input( - IO.STRING, ModelWithNoneDefaultField, "field" - ) - - assert actual_output[0] == expected_output[0] - assert actual_output[1] == expected_output[1] From f7ca41ff6226eecbf6c9ee475c1de714cb8f04e9 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Sun, 18 Jan 2026 04:57:57 +0200 Subject: [PATCH 153/181] chore(api-nodes): remove check for pyav>=14.2 in code (it was added to requirements.txt long ago) (#11934) --- comfy_api_nodes/canary.py | 10 ---------- nodes.py | 3 --- 2 files changed, 13 deletions(-) delete mode 100644 comfy_api_nodes/canary.py diff --git a/comfy_api_nodes/canary.py b/comfy_api_nodes/canary.py deleted file mode 100644 index 4df7590b6..000000000 --- a/comfy_api_nodes/canary.py +++ /dev/null @@ -1,10 +0,0 @@ -import av - -ver = av.__version__.split(".") -if int(ver[0]) < 14: - raise Exception("INSTALL NEW VERSION OF PYAV TO USE API NODES.") - -if int(ver[0]) == 14 and int(ver[1]) < 2: - raise Exception("INSTALL NEW VERSION OF PYAV TO USE API NODES.") - -NODE_CLASS_MAPPINGS = {} diff --git a/nodes.py b/nodes.py index f19d5fd1c..8b5279b36 100644 --- a/nodes.py +++ b/nodes.py @@ -2409,9 +2409,6 @@ async def init_builtin_api_nodes(): "nodes_wan.py", ] - if not await load_custom_node(os.path.join(api_nodes_dir, "canary.py"), module_parent="comfy_api_nodes"): - return api_nodes_files - import_failed = [] for node_file in api_nodes_files: if not await load_custom_node(os.path.join(api_nodes_dir, node_file), module_parent="comfy_api_nodes"): From a498556d0dcde3d7a7c19e1f5c733c8c2a2ffb10 Mon Sep 17 00:00:00 2001 From: Christian Byrne Date: Sat, 17 Jan 2026 19:06:03 -0800 Subject: [PATCH 154/181] feat: add advanced parameter to Input classes for advanced widgets support (#11939) Add 'advanced' boolean parameter to Input and WidgetInput base classes and propagate to all typed Input subclasses (Boolean, Int, Float, String, Combo, MultiCombo, Webcam, MultiType, MatchType, ImageCompare). When set to True, the frontend will hide these inputs by default in a collapsible 'Advanced Inputs' section in the right side panel, reducing visual clutter for power-user options. This enables nodes to expose advanced configuration options (like encoding parameters, quality settings, etc.) without overwhelming typical users. Frontend support: ComfyUI_frontend PR #7812 --- comfy_api/latest/_io.py | 47 ++++++++++++++++++++++------------------- 1 file changed, 25 insertions(+), 22 deletions(-) diff --git a/comfy_api/latest/_io.py b/comfy_api/latest/_io.py index e6a0d1821..c30d92aaa 100644 --- a/comfy_api/latest/_io.py +++ b/comfy_api/latest/_io.py @@ -153,7 +153,7 @@ class Input(_IO_V3): ''' Base class for a V3 Input. 
''' - def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None, raw_link: bool=None): + def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None, raw_link: bool=None, advanced: bool=None): super().__init__() self.id = id self.display_name = display_name @@ -162,6 +162,7 @@ class Input(_IO_V3): self.lazy = lazy self.extra_dict = extra_dict if extra_dict is not None else {} self.rawLink = raw_link + self.advanced = advanced def as_dict(self): return prune_dict({ @@ -170,6 +171,7 @@ class Input(_IO_V3): "tooltip": self.tooltip, "lazy": self.lazy, "rawLink": self.rawLink, + "advanced": self.advanced, }) | prune_dict(self.extra_dict) def get_io_type(self): @@ -184,8 +186,8 @@ class WidgetInput(Input): ''' def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, default: Any=None, - socketless: bool=None, widget_type: str=None, force_input: bool=None, extra_dict=None, raw_link: bool=None): - super().__init__(id, display_name, optional, tooltip, lazy, extra_dict, raw_link) + socketless: bool=None, widget_type: str=None, force_input: bool=None, extra_dict=None, raw_link: bool=None, advanced: bool=None): + super().__init__(id, display_name, optional, tooltip, lazy, extra_dict, raw_link, advanced) self.default = default self.socketless = socketless self.widget_type = widget_type @@ -242,8 +244,8 @@ class Boolean(ComfyTypeIO): '''Boolean input.''' def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, default: bool=None, label_on: str=None, label_off: str=None, - socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None): - super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link) + socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None, advanced: bool=None): + super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link, advanced) self.label_on = label_on self.label_off = label_off self.default: bool @@ -262,8 +264,8 @@ class Int(ComfyTypeIO): '''Integer input.''' def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, default: int=None, min: int=None, max: int=None, step: int=None, control_after_generate: bool=None, - display_mode: NumberDisplay=None, socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None): - super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link) + display_mode: NumberDisplay=None, socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None, advanced: bool=None): + super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link, advanced) self.min = min self.max = max self.step = step @@ -288,8 +290,8 @@ class Float(ComfyTypeIO): '''Float input.''' def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, default: float=None, min: float=None, max: float=None, step: float=None, round: float=None, - display_mode: NumberDisplay=None, socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None): - super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, 
extra_dict, raw_link) + display_mode: NumberDisplay=None, socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None, advanced: bool=None): + super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link, advanced) self.min = min self.max = max self.step = step @@ -314,8 +316,8 @@ class String(ComfyTypeIO): '''String input.''' def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, multiline=False, placeholder: str=None, default: str=None, dynamic_prompts: bool=None, - socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None): - super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link) + socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None, advanced: bool=None): + super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link, advanced) self.multiline = multiline self.placeholder = placeholder self.dynamic_prompts = dynamic_prompts @@ -350,12 +352,13 @@ class Combo(ComfyTypeIO): socketless: bool=None, extra_dict=None, raw_link: bool=None, + advanced: bool=None, ): if isinstance(options, type) and issubclass(options, Enum): options = [v.value for v in options] if isinstance(default, Enum): default = default.value - super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, None, extra_dict, raw_link) + super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, None, extra_dict, raw_link, advanced) self.multiselect = False self.options = options self.control_after_generate = control_after_generate @@ -387,8 +390,8 @@ class MultiCombo(ComfyTypeI): class Input(Combo.Input): def __init__(self, id: str, options: list[str], display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, default: list[str]=None, placeholder: str=None, chip: bool=None, control_after_generate: bool=None, - socketless: bool=None, extra_dict=None, raw_link: bool=None): - super().__init__(id, options, display_name, optional, tooltip, lazy, default, control_after_generate, socketless=socketless, extra_dict=extra_dict, raw_link=raw_link) + socketless: bool=None, extra_dict=None, raw_link: bool=None, advanced: bool=None): + super().__init__(id, options, display_name, optional, tooltip, lazy, default, control_after_generate, socketless=socketless, extra_dict=extra_dict, raw_link=raw_link, advanced=advanced) self.multiselect = True self.placeholder = placeholder self.chip = chip @@ -421,9 +424,9 @@ class Webcam(ComfyTypeIO): Type = str def __init__( self, id: str, display_name: str=None, optional=False, - tooltip: str=None, lazy: bool=None, default: str=None, socketless: bool=None, extra_dict=None, raw_link: bool=None + tooltip: str=None, lazy: bool=None, default: str=None, socketless: bool=None, extra_dict=None, raw_link: bool=None, advanced: bool=None ): - super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, None, extra_dict, raw_link) + super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, None, extra_dict, raw_link, advanced) @comfytype(io_type="MASK") @@ -776,7 +779,7 @@ class MultiType: ''' Input that permits more than one input type; if `id` is an instance of `ComfyType.Input`, then that input will be used to create a widget (if applicable) with overridden values. 
''' - def __init__(self, id: str | Input, types: list[type[_ComfyType] | _ComfyType], display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None, raw_link: bool=None): + def __init__(self, id: str | Input, types: list[type[_ComfyType] | _ComfyType], display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None, raw_link: bool=None, advanced: bool=None): # if id is an Input, then use that Input with overridden values self.input_override = None if isinstance(id, Input): @@ -789,7 +792,7 @@ class MultiType: # if is a widget input, make sure widget_type is set appropriately if isinstance(self.input_override, WidgetInput): self.input_override.widget_type = self.input_override.get_io_type() - super().__init__(id, display_name, optional, tooltip, lazy, extra_dict, raw_link) + super().__init__(id, display_name, optional, tooltip, lazy, extra_dict, raw_link, advanced) self._io_types = types @property @@ -843,8 +846,8 @@ class MatchType(ComfyTypeIO): class Input(Input): def __init__(self, id: str, template: MatchType.Template, - display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None, raw_link: bool=None): - super().__init__(id, display_name, optional, tooltip, lazy, extra_dict, raw_link) + display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None, raw_link: bool=None, advanced: bool=None): + super().__init__(id, display_name, optional, tooltip, lazy, extra_dict, raw_link, advanced) self.template = template def as_dict(self): @@ -1119,8 +1122,8 @@ class ImageCompare(ComfyTypeI): class Input(WidgetInput): def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, - socketless: bool=True): - super().__init__(id, display_name, optional, tooltip, None, None, socketless) + socketless: bool=True, advanced: bool=None): + super().__init__(id, display_name, optional, tooltip, None, None, socketless, None, None, None, None, advanced) def as_dict(self): return super().as_dict() From 034fac70549dd9c35b155b80a3ff627ad07b1015 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Sun, 18 Jan 2026 08:40:39 +0200 Subject: [PATCH 155/181] chore(api-nodes): auto-discover all nodes_*.py files to avoid merge conflicts when adding new API nodes (#11943) --- nodes.py | 30 ++++-------------------------- 1 file changed, 4 insertions(+), 26 deletions(-) diff --git a/nodes.py b/nodes.py index 8b5279b36..cba8eacc2 100644 --- a/nodes.py +++ b/nodes.py @@ -5,6 +5,7 @@ import torch import os import sys import json +import glob import hashlib import inspect import traceback @@ -2384,35 +2385,12 @@ async def init_builtin_extra_nodes(): async def init_builtin_api_nodes(): api_nodes_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_api_nodes") - api_nodes_files = [ - "nodes_ideogram.py", - "nodes_openai.py", - "nodes_minimax.py", - "nodes_veo2.py", - "nodes_kling.py", - "nodes_bfl.py", - "nodes_bytedance.py", - "nodes_ltxv.py", - "nodes_luma.py", - "nodes_recraft.py", - "nodes_pixverse.py", - "nodes_stability.py", - "nodes_runway.py", - "nodes_sora.py", - "nodes_topaz.py", - "nodes_tripo.py", - "nodes_meshy.py", - "nodes_moonvalley.py", - "nodes_rodin.py", - "nodes_gemini.py", - "nodes_vidu.py", - "nodes_wan.py", - ] + api_nodes_files = sorted(glob.glob(os.path.join(api_nodes_dir, "nodes_*.py"))) import_failed = [] for node_file in api_nodes_files: - if not await load_custom_node(os.path.join(api_nodes_dir, 
node_file), module_parent="comfy_api_nodes"): - import_failed.append(node_file) + if not await load_custom_node(node_file, module_parent="comfy_api_nodes"): + import_failed.append(os.path.basename(node_file)) return import_failed From 1a72bf20469dee31ad156f819c14f0172cbad222 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sun, 18 Jan 2026 19:53:43 -0800 Subject: [PATCH 156/181] Readme update. (#11957) --- README.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 123cc9472..c56e05d07 100644 --- a/README.md +++ b/README.md @@ -108,7 +108,7 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith - [LCM models and Loras](https://comfyanonymous.github.io/ComfyUI_examples/lcm/) - Latent previews with [TAESD](#how-to-show-high-quality-previews) - Works fully offline: core will never download anything unless you want to. -- Optional API nodes to use paid models from external providers through the online [Comfy API](https://docs.comfy.org/tutorials/api-nodes/overview). +- Optional API nodes to use paid models from external providers through the online [Comfy API](https://docs.comfy.org/tutorials/api-nodes/overview) disable with: `--disable-api-nodes` - [Config file](extra_model_paths.yaml.example) to set the search paths for models. Workflow examples can be found on the [Examples page](https://comfyanonymous.github.io/ComfyUI_examples/) @@ -212,7 +212,7 @@ Python 3.14 works but you may encounter issues with the torch compile node. The Python 3.13 is very well supported. If you have trouble with some custom node dependencies on 3.13 you can try 3.12 -torch 2.4 and above is supported but some features might only work on newer versions. We generally recommend using the latest major version of pytorch with the latest cuda version unless it is less than 2 weeks old. +torch 2.4 and above is supported but some features and optimizations might only work on newer versions. We generally recommend using the latest major version of pytorch with the latest cuda version unless it is less than 2 weeks old. 
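If you are unsure which pytorch build you already have, a quick optional check (plain Python, nothing ComfyUI-specific) is:

```python -c "import torch; print(torch.__version__, torch.version.cuda)"```

This prints the installed torch version and the CUDA version it was built against (`None` on CPU or ROCm builds).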
### Instructions: @@ -229,7 +229,7 @@ AMD users can install rocm and pytorch with pip if you don't have it already ins ```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.4``` -This is the command to install the nightly with ROCm 7.0 which might have some performance improvements: +This is the command to install the nightly with ROCm 7.1 which might have some performance improvements: ```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm7.1``` From 866a4619db2db56c77a86e5fc9968a2454928627 Mon Sep 17 00:00:00 2001 From: ComfyUI Wiki Date: Tue, 20 Jan 2026 06:21:35 +0800 Subject: [PATCH 157/181] chore: update workflow templates to v0.8.14 (#11974) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 622256973..312c7c137 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.36.14 -comfyui-workflow-templates==0.8.11 +comfyui-workflow-templates==0.8.14 comfyui-embedded-docs==0.4.0 torch torchsde From b931b37e30bb19b6e13ad8623e193ccdaf671a23 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Mon, 19 Jan 2026 16:47:14 -0800 Subject: [PATCH 158/181] feat(api-nodes): add Bria Edit node (#11978) Co-authored-by: Alexander Piskun --- comfy_api_nodes/apis/bria.py | 61 +++++++++ comfy_api_nodes/nodes_bria.py | 198 ++++++++++++++++++++++++++++ comfy_api_nodes/util/__init__.py | 2 + comfy_api_nodes/util/conversions.py | 6 + 4 files changed, 267 insertions(+) create mode 100644 comfy_api_nodes/apis/bria.py create mode 100644 comfy_api_nodes/nodes_bria.py diff --git a/comfy_api_nodes/apis/bria.py b/comfy_api_nodes/apis/bria.py new file mode 100644 index 000000000..9119cacc6 --- /dev/null +++ b/comfy_api_nodes/apis/bria.py @@ -0,0 +1,61 @@ +from typing import TypedDict + +from pydantic import BaseModel, Field + + +class InputModerationSettings(TypedDict): + prompt_content_moderation: bool + visual_input_moderation: bool + visual_output_moderation: bool + + +class BriaEditImageRequest(BaseModel): + instruction: str | None = Field(...) + structured_instruction: str | None = Field( + ..., + description="Use this instead of instruction for precise, programmatic control.", + ) + images: list[str] = Field( + ..., + description="Required. Publicly available URL or Base64-encoded. Must contain exactly one item.", + ) + mask: str | None = Field( + None, + description="Mask image (black and white). Black areas will be preserved, white areas will be edited. " + "If omitted, the edit applies to the entire image. " + "The input image and the the input mask must be of the same size.", + ) + negative_prompt: str | None = Field(None) + guidance_scale: float = Field(...) + model_version: str = Field(...) + steps_num: int = Field(...) + seed: int = Field(...) + ip_signal: bool = Field( + False, + description="If true, returns a warning for potential IP content in the instruction.", + ) + prompt_content_moderation: bool = Field( + False, description="If true, returns 422 on instruction moderation failure." + ) + visual_input_content_moderation: bool = Field( + False, description="If true, returns 422 on images or mask moderation failure." + ) + visual_output_content_moderation: bool = Field( + False, description="If true, returns 422 on visual output moderation failure." + ) + + +class BriaStatusResponse(BaseModel): + request_id: str = Field(...) + status_url: str = Field(...) 
+ warning: str | None = Field(None) + + +class BriaResult(BaseModel): + structured_prompt: str = Field(...) + image_url: str = Field(...) + + +class BriaResponse(BaseModel): + status: str = Field(...) + result: BriaResult | None = Field(None) diff --git a/comfy_api_nodes/nodes_bria.py b/comfy_api_nodes/nodes_bria.py new file mode 100644 index 000000000..72a3055a7 --- /dev/null +++ b/comfy_api_nodes/nodes_bria.py @@ -0,0 +1,198 @@ +from typing_extensions import override + +from comfy_api.latest import IO, ComfyExtension, Input +from comfy_api_nodes.apis.bria import ( + BriaEditImageRequest, + BriaResponse, + BriaStatusResponse, + InputModerationSettings, +) +from comfy_api_nodes.util import ( + ApiEndpoint, + convert_mask_to_image, + download_url_to_image_tensor, + get_number_of_images, + poll_op, + sync_op, + upload_images_to_comfyapi, +) + + +class BriaImageEditNode(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="BriaImageEditNode", + display_name="Bria Image Edit", + category="api node/image/Bria", + description="Edit images using Bria latest model", + inputs=[ + IO.Combo.Input("model", options=["FIBO"]), + IO.Image.Input("image"), + IO.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Instruction to edit image", + ), + IO.String.Input("negative_prompt", multiline=True, default=""), + IO.String.Input( + "structured_prompt", + multiline=True, + default="", + tooltip="A string containing the structured edit prompt in JSON format. " + "Use this instead of usual prompt for precise, programmatic control.", + ), + IO.Int.Input( + "seed", + default=1, + min=1, + max=2147483647, + step=1, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + ), + IO.Float.Input( + "guidance_scale", + default=3, + min=3, + max=5, + step=0.01, + display_mode=IO.NumberDisplay.number, + tooltip="Higher value makes the image follow the prompt more closely.", + ), + IO.Int.Input( + "steps", + default=50, + min=20, + max=50, + step=1, + display_mode=IO.NumberDisplay.number, + ), + IO.DynamicCombo.Input( + "moderation", + options=[ + IO.DynamicCombo.Option( + "true", + [ + IO.Boolean.Input( + "prompt_content_moderation", default=False + ), + IO.Boolean.Input( + "visual_input_moderation", default=False + ), + IO.Boolean.Input( + "visual_output_moderation", default=True + ), + ], + ), + IO.DynamicCombo.Option("false", []), + ], + tooltip="Moderation settings", + ), + IO.Mask.Input( + "mask", + tooltip="If omitted, the edit applies to the entire image.", + optional=True, + ), + ], + outputs=[ + IO.Image.Output(), + IO.String.Output(display_name="structured_prompt"), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.04}""", + ), + ) + + @classmethod + async def execute( + cls, + model: str, + image: Input.Image, + prompt: str, + negative_prompt: str, + structured_prompt: str, + seed: int, + guidance_scale: float, + steps: int, + moderation: InputModerationSettings, + mask: Input.Image | None = None, + ) -> IO.NodeOutput: + if not prompt and not structured_prompt: + raise ValueError( + "One of prompt or structured_prompt is required to be non-empty." 
+ ) + if get_number_of_images(image) != 1: + raise ValueError("Exactly one input image is required.") + mask_url = None + if mask is not None: + mask_url = ( + await upload_images_to_comfyapi( + cls, + convert_mask_to_image(mask), + max_images=1, + mime_type="image/png", + wait_label="Uploading mask", + ) + )[0] + response = await sync_op( + cls, + ApiEndpoint(path="proxy/bria/v2/image/edit", method="POST"), + data=BriaEditImageRequest( + instruction=prompt if prompt else None, + structured_instruction=structured_prompt if structured_prompt else None, + images=await upload_images_to_comfyapi( + cls, + image, + max_images=1, + mime_type="image/png", + wait_label="Uploading image", + ), + mask=mask_url, + negative_prompt=negative_prompt if negative_prompt else None, + guidance_scale=guidance_scale, + seed=seed, + model_version=model, + steps_num=steps, + prompt_content_moderation=moderation.get( + "prompt_content_moderation", False + ), + visual_input_content_moderation=moderation.get( + "visual_input_moderation", False + ), + visual_output_content_moderation=moderation.get( + "visual_output_moderation", False + ), + ), + response_model=BriaStatusResponse, + ) + response = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/bria/v2/status/{response.request_id}"), + status_extractor=lambda r: r.status, + response_model=BriaResponse, + ) + return IO.NodeOutput( + await download_url_to_image_tensor(response.result.image_url), + response.result.structured_prompt, + ) + + +class BriaExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[IO.ComfyNode]]: + return [ + BriaImageEditNode, + ] + + +async def comfy_entrypoint() -> BriaExtension: + return BriaExtension() diff --git a/comfy_api_nodes/util/__init__.py b/comfy_api_nodes/util/__init__.py index 4cc22abfb..364976000 100644 --- a/comfy_api_nodes/util/__init__.py +++ b/comfy_api_nodes/util/__init__.py @@ -11,6 +11,7 @@ from .conversions import ( audio_input_to_mp3, audio_to_base64_string, bytesio_to_image_tensor, + convert_mask_to_image, downscale_image_tensor, image_tensor_pair_to_batch, pil_to_bytesio, @@ -72,6 +73,7 @@ __all__ = [ "audio_input_to_mp3", "audio_to_base64_string", "bytesio_to_image_tensor", + "convert_mask_to_image", "downscale_image_tensor", "image_tensor_pair_to_batch", "pil_to_bytesio", diff --git a/comfy_api_nodes/util/conversions.py b/comfy_api_nodes/util/conversions.py index 99c302a2a..546741b7b 100644 --- a/comfy_api_nodes/util/conversions.py +++ b/comfy_api_nodes/util/conversions.py @@ -451,6 +451,12 @@ def resize_mask_to_image( return mask +def convert_mask_to_image(mask: Input.Image) -> torch.Tensor: + """Make mask have the expected amount of dims (4) and channels (3) to be recognized as an image.""" + mask = mask.unsqueeze(-1) + return torch.cat([mask] * 3, dim=-1) + + def text_filepath_to_base64_string(filepath: str) -> str: """Converts a text file to a base64 string.""" with open(filepath, "rb") as f: From 7458e20465a0efcf91eafc0c65d1929ab7b2238d Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Mon, 19 Jan 2026 16:58:30 -0800 Subject: [PATCH 159/181] Make Autogrow validation work properly (#11977) * In-progress autogrow validation fixes - properly looks at required/optional inputs, now working on the edge case that all inputs are optional and nothing is plugged in (should just be an empty dictionary passed into node) * Allow autogrow to work with all inputs being optional * Revert accidentally pushed changes to nodes_logic.py --- comfy_api/latest/_io.py | 53 
++++++++++++++++++++++++++++++++++------- 1 file changed, 44 insertions(+), 9 deletions(-) diff --git a/comfy_api/latest/_io.py b/comfy_api/latest/_io.py index c30d92aaa..4969d3506 100644 --- a/comfy_api/latest/_io.py +++ b/comfy_api/latest/_io.py @@ -1000,20 +1000,38 @@ class Autogrow(ComfyTypeI): names = [f"{prefix}{i}" for i in range(max)] # need to create a new input based on the contents of input template_input = None - for _, dict_input in input.items(): - # for now, get just the first value from dict_input + template_required = True + for _input_type, dict_input in input.items(): + # for now, get just the first value from dict_input; if not required, min can be ignored + if len(dict_input) == 0: + continue template_input = list(dict_input.values())[0] + template_required = _input_type == "required" + break + if template_input is None: + raise Exception("template_input could not be determined from required or optional; this should never happen.") new_dict = {} + new_dict_added_to = False + # first, add possible inputs into out_dict for i, name in enumerate(names): expected_id = finalize_prefix(curr_prefix, name) + # required + if i < min and template_required: + out_dict["required"][expected_id] = template_input + type_dict = new_dict.setdefault("required", {}) + # optional + else: + out_dict["optional"][expected_id] = template_input + type_dict = new_dict.setdefault("optional", {}) if expected_id in live_inputs: - # required - if i < min: - type_dict = new_dict.setdefault("required", {}) - # optional - else: - type_dict = new_dict.setdefault("optional", {}) + # NOTE: prefix gets added in parse_class_inputs type_dict[name] = template_input + new_dict_added_to = True + # account for the edge case that all inputs are optional and no values are received + if not new_dict_added_to: + finalized_prefix = finalize_prefix(curr_prefix) + out_dict["dynamic_paths"][finalized_prefix] = finalized_prefix + out_dict["dynamic_paths_default_value"][finalized_prefix] = DynamicPathsDefaultValue.EMPTY_DICT parse_class_inputs(out_dict, live_inputs, new_dict, curr_prefix) @comfytype(io_type="COMFY_DYNAMICCOMBO_V3") @@ -1151,6 +1169,8 @@ class V3Data(TypedDict): 'Dictionary where the keys are the hidden input ids and the values are the values of the hidden inputs.' dynamic_paths: dict[str, Any] 'Dictionary where the keys are the input ids and the values dictate how to turn the inputs into a nested dictionary.' + dynamic_paths_default_value: dict[str, Any] + 'Dictionary where the keys are the input ids and the values are a string from DynamicPathsDefaultValue for the inputs if value is None.' create_dynamic_tuple: bool 'When True, the value of the dynamic input will be in the format (value, path_key).' 
@@ -1504,6 +1524,7 @@ def get_finalized_class_inputs(d: dict[str, Any], live_inputs: dict[str, Any], i "required": {}, "optional": {}, "dynamic_paths": {}, + "dynamic_paths_default_value": {}, } d = d.copy() # ignore hidden for parsing @@ -1513,8 +1534,12 @@ def get_finalized_class_inputs(d: dict[str, Any], live_inputs: dict[str, Any], i out_dict["hidden"] = hidden v3_data = {} dynamic_paths = out_dict.pop("dynamic_paths", None) - if dynamic_paths is not None: + if dynamic_paths is not None and len(dynamic_paths) > 0: v3_data["dynamic_paths"] = dynamic_paths + # this list is used for autogrow, in the case all inputs are optional and no values are passed + dynamic_paths_default_value = out_dict.pop("dynamic_paths_default_value", None) + if dynamic_paths_default_value is not None and len(dynamic_paths_default_value) > 0: + v3_data["dynamic_paths_default_value"] = dynamic_paths_default_value return out_dict, hidden, v3_data def parse_class_inputs(out_dict: dict[str, Any], live_inputs: dict[str, Any], curr_dict: dict[str, Any], curr_prefix: list[str] | None=None) -> None: @@ -1551,11 +1576,16 @@ def add_to_dict_v1(i: Input, d: dict): def add_to_dict_v3(io: Input | Output, d: dict): d[io.id] = (io.get_io_type(), io.as_dict()) +class DynamicPathsDefaultValue: + EMPTY_DICT = "empty_dict" + def build_nested_inputs(values: dict[str, Any], v3_data: V3Data): paths = v3_data.get("dynamic_paths", None) + default_value_dict = v3_data.get("dynamic_paths_default_value", {}) if paths is None: return values values = values.copy() + result = {} create_tuple = v3_data.get("create_dynamic_tuple", False) @@ -1569,6 +1599,11 @@ def build_nested_inputs(values: dict[str, Any], v3_data: V3Data): if is_last: value = values.pop(key, None) + if value is None: + # see if a default value was provided for this key + default_option = default_value_dict.get(key, None) + if default_option == DynamicPathsDefaultValue.EMPTY_DICT: + value = {} if create_tuple: value = (value, key) current[p] = value From e0eacb06883c1f7ddf8af249cd461d7c2ebcbaae Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 19 Jan 2026 19:00:36 -0800 Subject: [PATCH 160/181] Simpler way to implement the #11980 loras. 
(#11981) --- comfy/utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/comfy/utils.py b/comfy/utils.py index 2e33a4258..5e79fb449 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -639,6 +639,8 @@ def flux_to_diffusers(mmdit_config, output_prefix=""): "proj_out.bias": "linear2.bias", "attn.norm_q.weight": "norm.query_norm.scale", "attn.norm_k.weight": "norm.key_norm.scale", + "attn.to_qkv_mlp_proj.weight": "linear1.weight", # Flux 2 + "attn.to_out.weight": "linear2.weight", # Flux 2 } for k in block_map: From 0da5a0fe58ae940726a61b94698e303fb39d73c1 Mon Sep 17 00:00:00 2001 From: rkfg Date: Tue, 20 Jan 2026 06:12:02 +0300 Subject: [PATCH 161/181] Convert mono audio to fake stereo for LTXV VAE encoding (#11965) --- comfy/ldm/lightricks/vae/audio_vae.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/comfy/ldm/lightricks/vae/audio_vae.py b/comfy/ldm/lightricks/vae/audio_vae.py index a9111d3bd..29d9e6c29 100644 --- a/comfy/ldm/lightricks/vae/audio_vae.py +++ b/comfy/ldm/lightricks/vae/audio_vae.py @@ -189,9 +189,12 @@ class AudioVAE(torch.nn.Module): waveform = self.device_manager.move_to_load_device(waveform) expected_channels = self.autoencoder.encoder.in_channels if waveform.shape[1] != expected_channels: - raise ValueError( - f"Input audio must have {expected_channels} channels, got {waveform.shape[1]}" - ) + if waveform.shape[1] == 1: + waveform = waveform.expand(-1, expected_channels, *waveform.shape[2:]) + else: + raise ValueError( + f"Input audio must have {expected_channels} channels, got {waveform.shape[1]}" + ) mel_spec = self.preprocessor.waveform_to_mel( waveform, waveform_sample_rate, device=self.device_manager.load_device From 70c91b8248e08492cf16bfebdc83579b801a6ee0 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 19 Jan 2026 19:32:40 -0800 Subject: [PATCH 162/181] Fix #11963 (#11982) --- comfy/text_encoders/ovis.py | 1 + comfy/text_encoders/z_image.py | 1 + 2 files changed, 2 insertions(+) diff --git a/comfy/text_encoders/ovis.py b/comfy/text_encoders/ovis.py index 5754424d2..2cc0867c3 100644 --- a/comfy/text_encoders/ovis.py +++ b/comfy/text_encoders/ovis.py @@ -61,6 +61,7 @@ def te(dtype_llama=None, llama_quantization_metadata=None): if dtype_llama is not None: dtype = dtype_llama if llama_quantization_metadata is not None: + model_options = model_options.copy() model_options["quantization_metadata"] = llama_quantization_metadata super().__init__(device=device, dtype=dtype, model_options=model_options) return OvisTEModel_ diff --git a/comfy/text_encoders/z_image.py b/comfy/text_encoders/z_image.py index 19adde0b7..ad41bfb1e 100644 --- a/comfy/text_encoders/z_image.py +++ b/comfy/text_encoders/z_image.py @@ -40,6 +40,7 @@ def te(dtype_llama=None, llama_quantization_metadata=None): if dtype_llama is not None: dtype = dtype_llama if llama_quantization_metadata is not None: + model_options = model_options.copy() model_options["quantization_metadata"] = llama_quantization_metadata super().__init__(device=device, dtype=dtype, model_options=model_options) return ZImageTEModel_ From 9d273d3ab1fb1d2c8b34de4d54cabe50a5a3e5bc Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 19 Jan 2026 22:40:18 -0500 Subject: [PATCH 163/181] ComfyUI v0.10.0 --- comfyui_version.py | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/comfyui_version.py b/comfyui_version.py index dbb57b4e5..952d413db 100644 --- a/comfyui_version.py +++ b/comfyui_version.py @@ 
-1,3 +1,3 @@ # This file is automatically generated by the build process when version is # updated in pyproject.toml. -__version__ = "0.9.2" +__version__ = "0.10.0" diff --git a/pyproject.toml b/pyproject.toml index 9ea73da05..120b6c751 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ComfyUI" -version = "0.9.2" +version = "0.10.0" readme = "README.md" license = { file = "LICENSE" } requires-python = ">=3.10" From 2108167f9f70cfd4874945b31a916680f959a6d7 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 19 Jan 2026 20:17:38 -0800 Subject: [PATCH 164/181] Support zimage omni base model. (#11979) --- comfy/ldm/lumina/model.py | 317 ++++++++++++++++++++++++++++------- comfy/model_base.py | 30 ++++ comfy/model_detection.py | 3 + comfy_extras/nodes_zimage.py | 88 ++++++++++ nodes.py | 1 + 5 files changed, 381 insertions(+), 58 deletions(-) create mode 100644 comfy_extras/nodes_zimage.py diff --git a/comfy/ldm/lumina/model.py b/comfy/ldm/lumina/model.py index afbab2ac7..139f879a1 100644 --- a/comfy/ldm/lumina/model.py +++ b/comfy/ldm/lumina/model.py @@ -13,10 +13,53 @@ from comfy.ldm.modules.attention import optimized_attention_masked from comfy.ldm.flux.layers import EmbedND from comfy.ldm.flux.math import apply_rope import comfy.patcher_extension +import comfy.utils -def modulate(x, scale): - return x * (1 + scale.unsqueeze(1)) +def invert_slices(slices, length): + sorted_slices = sorted(slices) + result = [] + current = 0 + + for start, end in sorted_slices: + if current < start: + result.append((current, start)) + current = max(current, end) + + if current < length: + result.append((current, length)) + + return result + + +def modulate(x, scale, timestep_zero_index=None): + if timestep_zero_index is None: + return x * (1 + scale.unsqueeze(1)) + else: + scale = (1 + scale.unsqueeze(1)) + actual_batch = scale.size(0) // 2 + slices = timestep_zero_index + invert = invert_slices(timestep_zero_index, x.shape[1]) + for s in slices: + x[:, s[0]:s[1]] *= scale[actual_batch:] + for s in invert: + x[:, s[0]:s[1]] *= scale[:actual_batch] + return x + + +def apply_gate(gate, x, timestep_zero_index=None): + if timestep_zero_index is None: + return gate * x + else: + actual_batch = gate.size(0) // 2 + + slices = timestep_zero_index + invert = invert_slices(timestep_zero_index, x.shape[1]) + for s in slices: + x[:, s[0]:s[1]] *= gate[actual_batch:] + for s in invert: + x[:, s[0]:s[1]] *= gate[:actual_batch] + return x ############################################################################# # Core NextDiT Model # @@ -258,6 +301,7 @@ class JointTransformerBlock(nn.Module): x_mask: torch.Tensor, freqs_cis: torch.Tensor, adaln_input: Optional[torch.Tensor]=None, + timestep_zero_index=None, transformer_options={}, ): """ @@ -276,18 +320,18 @@ class JointTransformerBlock(nn.Module): assert adaln_input is not None scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(adaln_input).chunk(4, dim=1) - x = x + gate_msa.unsqueeze(1).tanh() * self.attention_norm2( + x = x + apply_gate(gate_msa.unsqueeze(1).tanh(), self.attention_norm2( clamp_fp16(self.attention( - modulate(self.attention_norm1(x), scale_msa), + modulate(self.attention_norm1(x), scale_msa, timestep_zero_index=timestep_zero_index), x_mask, freqs_cis, transformer_options=transformer_options, - )) + ))), timestep_zero_index=timestep_zero_index ) - x = x + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2( + x = x + 
apply_gate(gate_mlp.unsqueeze(1).tanh(), self.ffn_norm2( clamp_fp16(self.feed_forward( - modulate(self.ffn_norm1(x), scale_mlp), - )) + modulate(self.ffn_norm1(x), scale_mlp, timestep_zero_index=timestep_zero_index), + ))), timestep_zero_index=timestep_zero_index ) else: assert adaln_input is None @@ -345,13 +389,37 @@ class FinalLayer(nn.Module): ), ) - def forward(self, x, c): + def forward(self, x, c, timestep_zero_index=None): scale = self.adaLN_modulation(c) - x = modulate(self.norm_final(x), scale) + x = modulate(self.norm_final(x), scale, timestep_zero_index=timestep_zero_index) x = self.linear(x) return x +def pad_zimage(feats, pad_token, pad_tokens_multiple): + pad_extra = (-feats.shape[1]) % pad_tokens_multiple + return torch.cat((feats, pad_token.to(device=feats.device, dtype=feats.dtype, copy=True).unsqueeze(0).repeat(feats.shape[0], pad_extra, 1)), dim=1), pad_extra + + +def pos_ids_x(start_t, H_tokens, W_tokens, batch_size, device, transformer_options={}): + rope_options = transformer_options.get("rope_options", None) + h_scale = 1.0 + w_scale = 1.0 + h_start = 0 + w_start = 0 + if rope_options is not None: + h_scale = rope_options.get("scale_y", 1.0) + w_scale = rope_options.get("scale_x", 1.0) + + h_start = rope_options.get("shift_y", 0.0) + w_start = rope_options.get("shift_x", 0.0) + x_pos_ids = torch.zeros((batch_size, H_tokens * W_tokens, 3), dtype=torch.float32, device=device) + x_pos_ids[:, :, 0] = start_t + x_pos_ids[:, :, 1] = (torch.arange(H_tokens, dtype=torch.float32, device=device) * h_scale + h_start).view(-1, 1).repeat(1, W_tokens).flatten() + x_pos_ids[:, :, 2] = (torch.arange(W_tokens, dtype=torch.float32, device=device) * w_scale + w_start).view(1, -1).repeat(H_tokens, 1).flatten() + return x_pos_ids + + class NextDiT(nn.Module): """ Diffusion model with a Transformer backbone. @@ -378,6 +446,7 @@ class NextDiT(nn.Module): time_scale=1.0, pad_tokens_multiple=None, clip_text_dim=None, + siglip_feat_dim=None, image_model=None, device=None, dtype=None, @@ -491,6 +560,41 @@ class NextDiT(nn.Module): for layer_id in range(n_layers) ] ) + + if siglip_feat_dim is not None: + self.siglip_embedder = nn.Sequential( + operation_settings.get("operations").RMSNorm(siglip_feat_dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")), + operation_settings.get("operations").Linear( + siglip_feat_dim, + dim, + bias=True, + device=operation_settings.get("device"), + dtype=operation_settings.get("dtype"), + ), + ) + self.siglip_refiner = nn.ModuleList( + [ + JointTransformerBlock( + layer_id, + dim, + n_heads, + n_kv_heads, + multiple_of, + ffn_dim_multiplier, + norm_eps, + qk_norm, + modulation=False, + operation_settings=operation_settings, + ) + for layer_id in range(n_refiner_layers) + ] + ) + self.siglip_pad_token = nn.Parameter(torch.empty((1, dim), device=device, dtype=dtype)) + else: + self.siglip_embedder = None + self.siglip_refiner = None + self.siglip_pad_token = None + # This norm final is in the lumina 2.0 code but isn't actually used for anything. 
# self.norm_final = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) self.final_layer = FinalLayer(dim, patch_size, self.out_channels, z_image_modulation=z_image_modulation, operation_settings=operation_settings) @@ -531,70 +635,166 @@ class NextDiT(nn.Module): imgs = torch.stack(imgs, dim=0) return imgs - def patchify_and_embed( - self, x: List[torch.Tensor] | torch.Tensor, cap_feats: torch.Tensor, cap_mask: torch.Tensor, t: torch.Tensor, num_tokens, transformer_options={} - ) -> Tuple[torch.Tensor, torch.Tensor, List[Tuple[int, int]], List[int], torch.Tensor]: - bsz = len(x) - pH = pW = self.patch_size - device = x[0].device - orig_x = x - - if self.pad_tokens_multiple is not None: - pad_extra = (-cap_feats.shape[1]) % self.pad_tokens_multiple - cap_feats = torch.cat((cap_feats, self.cap_pad_token.to(device=cap_feats.device, dtype=cap_feats.dtype, copy=True).unsqueeze(0).repeat(cap_feats.shape[0], pad_extra, 1)), dim=1) + def embed_cap(self, cap_feats=None, offset=0, bsz=1, device=None, dtype=None): + if cap_feats is not None: + cap_feats = self.cap_embedder(cap_feats) + cap_feats_len = cap_feats.shape[1] + if self.pad_tokens_multiple is not None: + cap_feats, _ = pad_zimage(cap_feats, self.cap_pad_token, self.pad_tokens_multiple) + else: + cap_feats_len = 0 + cap_feats = self.cap_pad_token.to(device=device, dtype=dtype, copy=True).unsqueeze(0).repeat(bsz, self.pad_tokens_multiple, 1) cap_pos_ids = torch.zeros(bsz, cap_feats.shape[1], 3, dtype=torch.float32, device=device) - cap_pos_ids[:, :, 0] = torch.arange(cap_feats.shape[1], dtype=torch.float32, device=device) + 1.0 + cap_pos_ids[:, :, 0] = torch.arange(cap_feats.shape[1], dtype=torch.float32, device=device) + 1.0 + offset + embeds = (cap_feats,) + freqs_cis = (self.rope_embedder(cap_pos_ids).movedim(1, 2),) + return embeds, freqs_cis, cap_feats_len + + def embed_all(self, x, cap_feats=None, siglip_feats=None, offset=0, omni=False, transformer_options={}): + bsz = 1 + pH = pW = self.patch_size + device = x.device + embeds, freqs_cis, cap_feats_len = self.embed_cap(cap_feats, offset=offset, bsz=bsz, device=device, dtype=x.dtype) + + if not omni: + cap_feats_len = embeds[0].shape[1] + offset + embeds += (None,) + freqs_cis += (None,) + else: + cap_feats_len += offset + if siglip_feats is not None: + b, h, w, c = siglip_feats.shape + siglip_feats = siglip_feats.permute(0, 3, 1, 2).reshape(b, h * w, c) + siglip_feats = self.siglip_embedder(siglip_feats) + siglip_pos_ids = torch.zeros((bsz, siglip_feats.shape[1], 3), dtype=torch.float32, device=device) + siglip_pos_ids[:, :, 0] = cap_feats_len + 2 + siglip_pos_ids[:, :, 1] = (torch.linspace(0, h * 8 - 1, steps=h, dtype=torch.float32, device=device).floor()).view(-1, 1).repeat(1, w).flatten() + siglip_pos_ids[:, :, 2] = (torch.linspace(0, w * 8 - 1, steps=w, dtype=torch.float32, device=device).floor()).view(1, -1).repeat(h, 1).flatten() + if self.siglip_pad_token is not None: + siglip_feats, pad_extra = pad_zimage(siglip_feats, self.siglip_pad_token, self.pad_tokens_multiple) # TODO: double check + siglip_pos_ids = torch.nn.functional.pad(siglip_pos_ids, (0, 0, 0, pad_extra)) + else: + siglip_feats = self.siglip_pad_token.to(device=device, dtype=x.dtype, copy=True).unsqueeze(0).repeat(bsz, self.pad_tokens_multiple, 1) + siglip_pos_ids = torch.zeros((bsz, siglip_feats.shape[1], 3), dtype=torch.float32, device=device) + + if siglip_feats is None: + embeds += (None,) + 
freqs_cis += (None,) + else: + embeds += (siglip_feats,) + freqs_cis += (self.rope_embedder(siglip_pos_ids).movedim(1, 2),) B, C, H, W = x.shape x = self.x_embedder(x.view(B, C, H // pH, pH, W // pW, pW).permute(0, 2, 4, 3, 5, 1).flatten(3).flatten(1, 2)) - - rope_options = transformer_options.get("rope_options", None) - h_scale = 1.0 - w_scale = 1.0 - h_start = 0 - w_start = 0 - if rope_options is not None: - h_scale = rope_options.get("scale_y", 1.0) - w_scale = rope_options.get("scale_x", 1.0) - - h_start = rope_options.get("shift_y", 0.0) - w_start = rope_options.get("shift_x", 0.0) - - H_tokens, W_tokens = H // pH, W // pW - x_pos_ids = torch.zeros((bsz, x.shape[1], 3), dtype=torch.float32, device=device) - x_pos_ids[:, :, 0] = cap_feats.shape[1] + 1 - x_pos_ids[:, :, 1] = (torch.arange(H_tokens, dtype=torch.float32, device=device) * h_scale + h_start).view(-1, 1).repeat(1, W_tokens).flatten() - x_pos_ids[:, :, 2] = (torch.arange(W_tokens, dtype=torch.float32, device=device) * w_scale + w_start).view(1, -1).repeat(H_tokens, 1).flatten() - + x_pos_ids = pos_ids_x(cap_feats_len + 1, H // pH, W // pW, bsz, device, transformer_options=transformer_options) if self.pad_tokens_multiple is not None: - pad_extra = (-x.shape[1]) % self.pad_tokens_multiple - x = torch.cat((x, self.x_pad_token.to(device=x.device, dtype=x.dtype, copy=True).unsqueeze(0).repeat(x.shape[0], pad_extra, 1)), dim=1) + x, pad_extra = pad_zimage(x, self.x_pad_token, self.pad_tokens_multiple) x_pos_ids = torch.nn.functional.pad(x_pos_ids, (0, 0, 0, pad_extra)) - freqs_cis = self.rope_embedder(torch.cat((cap_pos_ids, x_pos_ids), dim=1)).movedim(1, 2) + embeds += (x,) + freqs_cis += (self.rope_embedder(x_pos_ids).movedim(1, 2),) + return embeds, freqs_cis, cap_feats_len + len(freqs_cis) - 1 + + + def patchify_and_embed( + self, x: torch.Tensor, cap_feats: torch.Tensor, cap_mask: torch.Tensor, t: torch.Tensor, num_tokens, ref_latents=[], ref_contexts=[], siglip_feats=[], transformer_options={} + ) -> Tuple[torch.Tensor, torch.Tensor, List[Tuple[int, int]], List[int], torch.Tensor]: + bsz = x.shape[0] + cap_mask = None # TODO? 
+ main_siglip = None + orig_x = x + + embeds = ([], [], []) + freqs_cis = ([], [], []) + leftover_cap = [] + + start_t = 0 + omni = len(ref_latents) > 0 + if omni: + for i, ref in enumerate(ref_latents): + if i < len(ref_contexts): + ref_con = ref_contexts[i] + else: + ref_con = None + if i < len(siglip_feats): + sig_feat = siglip_feats[i] + else: + sig_feat = None + + out = self.embed_all(ref, ref_con, sig_feat, offset=start_t, omni=omni, transformer_options=transformer_options) + for i, e in enumerate(out[0]): + embeds[i].append(comfy.utils.repeat_to_batch_size(e, bsz)) + freqs_cis[i].append(out[1][i]) + start_t = out[2] + leftover_cap = ref_contexts[len(ref_latents):] + + H, W = x.shape[-2], x.shape[-1] + img_sizes = [(H, W)] * bsz + out = self.embed_all(x, cap_feats, main_siglip, offset=start_t, omni=omni, transformer_options=transformer_options) + img_len = out[0][-1].shape[1] + cap_len = out[0][0].shape[1] + for i, e in enumerate(out[0]): + if e is not None: + e = comfy.utils.repeat_to_batch_size(e, bsz) + embeds[i].append(e) + freqs_cis[i].append(out[1][i]) + start_t = out[2] + + for cap in leftover_cap: + out = self.embed_cap(cap, offset=start_t, bsz=bsz, device=x.device, dtype=x.dtype) + cap_len += out[0][0].shape[1] + embeds[0].append(comfy.utils.repeat_to_batch_size(out[0][0], bsz)) + freqs_cis[0].append(out[1][0]) + start_t += out[2] patches = transformer_options.get("patches", {}) # refine context + cap_feats = torch.cat(embeds[0], dim=1) + cap_freqs_cis = torch.cat(freqs_cis[0], dim=1) for layer in self.context_refiner: - cap_feats = layer(cap_feats, cap_mask, freqs_cis[:, :cap_pos_ids.shape[1]], transformer_options=transformer_options) + cap_feats = layer(cap_feats, cap_mask, cap_freqs_cis, transformer_options=transformer_options) + + feats = (cap_feats,) + fc = (cap_freqs_cis,) + + if omni: + siglip_mask = None + siglip_feats_combined = torch.cat(embeds[1], dim=1) + siglip_feats_freqs_cis = torch.cat(freqs_cis[1], dim=1) + if self.siglip_refiner is not None: + for layer in self.siglip_refiner: + siglip_feats_combined = layer(siglip_feats_combined, siglip_mask, siglip_feats_freqs_cis, transformer_options=transformer_options) + feats += (siglip_feats_combined,) + fc += (siglip_feats_freqs_cis,) padded_img_mask = None + x = torch.cat(embeds[-1], dim=1) + fc_x = torch.cat(freqs_cis[-1], dim=1) + if omni: + timestep_zero_index = [(x.shape[1] - img_len, x.shape[1])] + else: + timestep_zero_index = None + x_input = x for i, layer in enumerate(self.noise_refiner): - x = layer(x, padded_img_mask, freqs_cis[:, cap_pos_ids.shape[1]:], t, transformer_options=transformer_options) + x = layer(x, padded_img_mask, fc_x, t, timestep_zero_index=timestep_zero_index, transformer_options=transformer_options) if "noise_refiner" in patches: for p in patches["noise_refiner"]: - out = p({"img": x, "img_input": x_input, "txt": cap_feats, "pe": freqs_cis[:, cap_pos_ids.shape[1]:], "vec": t, "x": orig_x, "block_index": i, "transformer_options": transformer_options, "block_type": "noise_refiner"}) + out = p({"img": x, "img_input": x_input, "txt": cap_feats, "pe": fc_x, "vec": t, "x": orig_x, "block_index": i, "transformer_options": transformer_options, "block_type": "noise_refiner"}) if "img" in out: x = out["img"] - padded_full_embed = torch.cat((cap_feats, x), dim=1) + padded_full_embed = torch.cat(feats + (x,), dim=1) + if timestep_zero_index is not None: + ind = padded_full_embed.shape[1] - x.shape[1] + timestep_zero_index = [(ind + x.shape[1] - img_len, ind + x.shape[1])] + 
timestep_zero_index.append((feats[0].shape[1] - cap_len, feats[0].shape[1])) + mask = None - img_sizes = [(H, W)] * bsz - l_effective_cap_len = [cap_feats.shape[1]] * bsz - return padded_full_embed, mask, img_sizes, l_effective_cap_len, freqs_cis + l_effective_cap_len = [padded_full_embed.shape[1] - img_len] * bsz + return padded_full_embed, mask, img_sizes, l_effective_cap_len, torch.cat(fc + (fc_x,), dim=1), timestep_zero_index def forward(self, x, timesteps, context, num_tokens, attention_mask=None, **kwargs): return comfy.patcher_extension.WrapperExecutor.new_class_executor( @@ -604,7 +804,11 @@ class NextDiT(nn.Module): ).execute(x, timesteps, context, num_tokens, attention_mask, **kwargs) # def forward(self, x, t, cap_feats, cap_mask): - def _forward(self, x, timesteps, context, num_tokens, attention_mask=None, transformer_options={}, **kwargs): + def _forward(self, x, timesteps, context, num_tokens, attention_mask=None, ref_latents=[], ref_contexts=[], siglip_feats=[], transformer_options={}, **kwargs): + omni = len(ref_latents) > 0 + if omni: + timesteps = torch.cat([timesteps * 0, timesteps], dim=0) + t = 1.0 - timesteps cap_feats = context cap_mask = attention_mask @@ -619,8 +823,6 @@ class NextDiT(nn.Module): t = self.t_embedder(t * self.time_scale, dtype=x.dtype) # (N, D) adaln_input = t - cap_feats = self.cap_embedder(cap_feats) # (N, L, D) # todo check if able to batchify w.o. redundant compute - if self.clip_text_pooled_proj is not None: pooled = kwargs.get("clip_text_pooled", None) if pooled is not None: @@ -632,7 +834,7 @@ class NextDiT(nn.Module): patches = transformer_options.get("patches", {}) x_is_tensor = isinstance(x, torch.Tensor) - img, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, adaln_input, num_tokens, transformer_options=transformer_options) + img, mask, img_size, cap_size, freqs_cis, timestep_zero_index = self.patchify_and_embed(x, cap_feats, cap_mask, adaln_input, num_tokens, ref_latents=ref_latents, ref_contexts=ref_contexts, siglip_feats=siglip_feats, transformer_options=transformer_options) freqs_cis = freqs_cis.to(img.device) transformer_options["total_blocks"] = len(self.layers) @@ -640,7 +842,7 @@ class NextDiT(nn.Module): img_input = img for i, layer in enumerate(self.layers): transformer_options["block_index"] = i - img = layer(img, mask, freqs_cis, adaln_input, transformer_options=transformer_options) + img = layer(img, mask, freqs_cis, adaln_input, timestep_zero_index=timestep_zero_index, transformer_options=transformer_options) if "double_block" in patches: for p in patches["double_block"]: out = p({"img": img[:, cap_size[0]:], "img_input": img_input[:, cap_size[0]:], "txt": img[:, :cap_size[0]], "pe": freqs_cis[:, cap_size[0]:], "vec": adaln_input, "x": x, "block_index": i, "transformer_options": transformer_options}) @@ -649,8 +851,7 @@ class NextDiT(nn.Module): if "txt" in out: img[:, :cap_size[0]] = out["txt"] - img = self.final_layer(img, adaln_input) + img = self.final_layer(img, adaln_input, timestep_zero_index=timestep_zero_index) img = self.unpatchify(img, img_size, cap_size, return_tensor=x_is_tensor)[:, :, :h, :w] - return -img diff --git a/comfy/model_base.py b/comfy/model_base.py index 49efd700b..28ba2643e 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1150,6 +1150,7 @@ class CosmosPredict2(BaseModel): class Lumina2(BaseModel): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): super().__init__(model_config, model_type, device=device, 
unet_model=comfy.ldm.lumina.model.NextDiT) + self.memory_usage_factor_conds = ("ref_latents",) def extra_conds(self, **kwargs): out = super().extra_conds(**kwargs) @@ -1169,6 +1170,35 @@ class Lumina2(BaseModel): if clip_text_pooled is not None: out['clip_text_pooled'] = comfy.conds.CONDRegular(clip_text_pooled) + clip_vision_outputs = kwargs.get("clip_vision_outputs", list(map(lambda a: a.get("clip_vision_output"), kwargs.get("unclip_conditioning", [{}])))) # Z Image omni + if clip_vision_outputs is not None and len(clip_vision_outputs) > 0: + sigfeats = [] + for clip_vision_output in clip_vision_outputs: + if clip_vision_output is not None: + image_size = clip_vision_output.image_sizes[0] + shape = clip_vision_output.last_hidden_state.shape + sigfeats.append(clip_vision_output.last_hidden_state.reshape(shape[0], image_size[1] // 16, image_size[2] // 16, shape[-1])) + if len(sigfeats) > 0: + out['siglip_feats'] = comfy.conds.CONDList(sigfeats) + + ref_latents = kwargs.get("reference_latents", None) + if ref_latents is not None: + latents = [] + for lat in ref_latents: + latents.append(self.process_latent_in(lat)) + out['ref_latents'] = comfy.conds.CONDList(latents) + + ref_contexts = kwargs.get("reference_latents_text_embeds", None) + if ref_contexts is not None: + out['ref_contexts'] = comfy.conds.CONDList(ref_contexts) + + return out + + def extra_conds_shapes(self, **kwargs): + out = {} + ref_latents = kwargs.get("reference_latents", None) + if ref_latents is not None: + out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()[2:]), ref_latents))]) return out class WAN21(BaseModel): diff --git a/comfy/model_detection.py b/comfy/model_detection.py index aff5a50b9..42884f797 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -446,6 +446,9 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["time_scale"] = 1000.0 if '{}cap_pad_token'.format(key_prefix) in state_dict_keys: dit_config["pad_tokens_multiple"] = 32 + sig_weight = state_dict.get('{}siglip_embedder.0.weight'.format(key_prefix), None) + if sig_weight is not None: + dit_config["siglip_feat_dim"] = sig_weight.shape[0] return dit_config diff --git a/comfy_extras/nodes_zimage.py b/comfy_extras/nodes_zimage.py new file mode 100644 index 000000000..2ee3c43b1 --- /dev/null +++ b/comfy_extras/nodes_zimage.py @@ -0,0 +1,88 @@ +import node_helpers +from typing_extensions import override +from comfy_api.latest import ComfyExtension, io +import math +import comfy.utils + + +class TextEncodeZImageOmni(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="TextEncodeZImageOmni", + category="advanced/conditioning", + is_experimental=True, + inputs=[ + io.Clip.Input("clip"), + io.ClipVision.Input("image_encoder", optional=True), + io.String.Input("prompt", multiline=True, dynamic_prompts=True), + io.Boolean.Input("auto_resize_images", default=True), + io.Vae.Input("vae", optional=True), + io.Image.Input("image1", optional=True), + io.Image.Input("image2", optional=True), + io.Image.Input("image3", optional=True), + ], + outputs=[ + io.Conditioning.Output(), + ], + ) + + @classmethod + def execute(cls, clip, prompt, image_encoder=None, auto_resize_images=True, vae=None, image1=None, image2=None, image3=None) -> io.NodeOutput: + ref_latents = [] + images = list(filter(lambda a: a is not None, [image1, image2, image3])) + + prompt_list = [] + template = None + if len(images) > 0: + prompt_list = ["<|im_start|>user\n<|vision_start|>"] + prompt_list += 
["<|vision_end|><|vision_start|>"] * (len(images) - 1) + prompt_list += ["<|vision_end|><|im_end|>"] + template = "<|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n<|vision_start|>" + + encoded_images = [] + + for i, image in enumerate(images): + if image_encoder is not None: + encoded_images.append(image_encoder.encode_image(image)) + + if vae is not None: + if auto_resize_images: + samples = image.movedim(-1, 1) + total = int(1024 * 1024) + scale_by = math.sqrt(total / (samples.shape[3] * samples.shape[2])) + width = round(samples.shape[3] * scale_by / 8.0) * 8 + height = round(samples.shape[2] * scale_by / 8.0) * 8 + + image = comfy.utils.common_upscale(samples, width, height, "area", "disabled").movedim(1, -1) + ref_latents.append(vae.encode(image)) + + tokens = clip.tokenize(prompt, llama_template=template) + conditioning = clip.encode_from_tokens_scheduled(tokens) + + extra_text_embeds = [] + for p in prompt_list: + tokens = clip.tokenize(p, llama_template="{}") + text_embeds = clip.encode_from_tokens_scheduled(tokens) + extra_text_embeds.append(text_embeds[0][0]) + + if len(ref_latents) > 0: + conditioning = node_helpers.conditioning_set_values(conditioning, {"reference_latents": ref_latents}, append=True) + if len(encoded_images) > 0: + conditioning = node_helpers.conditioning_set_values(conditioning, {"clip_vision_outputs": encoded_images}, append=True) + if len(extra_text_embeds) > 0: + conditioning = node_helpers.conditioning_set_values(conditioning, {"reference_latents_text_embeds": extra_text_embeds}, append=True) + + return io.NodeOutput(conditioning) + + +class ZImageExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + TextEncodeZImageOmni, + ] + + +async def comfy_entrypoint() -> ZImageExtension: + return ZImageExtension() diff --git a/nodes.py b/nodes.py index cba8eacc2..ea5d6e525 100644 --- a/nodes.py +++ b/nodes.py @@ -2373,6 +2373,7 @@ async def init_builtin_extra_nodes(): "nodes_kandinsky5.py", "nodes_wanmove.py", "nodes_image_compare.py", + "nodes_zimage.py", ] import_failed = [] From 0fc3b6e3a6f1d8fdffca3a51cb4d10a06f4e079d Mon Sep 17 00:00:00 2001 From: ComfyUI Wiki Date: Tue, 20 Jan 2026 12:17:56 +0800 Subject: [PATCH 165/181] chore: update workflow templates to v0.8.15 (#11984) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 312c7c137..35543525d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.36.14 -comfyui-workflow-templates==0.8.14 +comfyui-workflow-templates==0.8.15 comfyui-embedded-docs==0.4.0 torch torchsde From 4edb87aa50190139a38a2ccd6b6ee35ba9df4da1 Mon Sep 17 00:00:00 2001 From: Comfy Org PR Bot Date: Tue, 20 Jan 2026 13:57:50 +0900 Subject: [PATCH 166/181] Bump comfyui-frontend-package to 1.37.11 (#11976) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 35543525d..ec89dccd2 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -comfyui-frontend-package==1.36.14 +comfyui-frontend-package==1.37.11 comfyui-workflow-templates==0.8.15 comfyui-embedded-docs==0.4.0 torch From 8ccc0c94fa0d8e43fffe7190e6a36551a53df54a Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 19 Jan 2026 21:32:00 -0800 Subject: [PATCH 167/181] Make omni stuff work on regular z image for easier testing. 
(#11985) --- comfy/ldm/lumina/model.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/comfy/ldm/lumina/model.py b/comfy/ldm/lumina/model.py index 139f879a1..b114d9e31 100644 --- a/comfy/ldm/lumina/model.py +++ b/comfy/ldm/lumina/model.py @@ -657,7 +657,7 @@ class NextDiT(nn.Module): device = x.device embeds, freqs_cis, cap_feats_len = self.embed_cap(cap_feats, offset=offset, bsz=bsz, device=device, dtype=x.dtype) - if not omni: + if (not omni) or self.siglip_embedder is None: cap_feats_len = embeds[0].shape[1] + offset embeds += (None,) freqs_cis += (None,) @@ -675,8 +675,9 @@ class NextDiT(nn.Module): siglip_feats, pad_extra = pad_zimage(siglip_feats, self.siglip_pad_token, self.pad_tokens_multiple) # TODO: double check siglip_pos_ids = torch.nn.functional.pad(siglip_pos_ids, (0, 0, 0, pad_extra)) else: - siglip_feats = self.siglip_pad_token.to(device=device, dtype=x.dtype, copy=True).unsqueeze(0).repeat(bsz, self.pad_tokens_multiple, 1) - siglip_pos_ids = torch.zeros((bsz, siglip_feats.shape[1], 3), dtype=torch.float32, device=device) + if self.siglip_pad_token is not None: + siglip_feats = self.siglip_pad_token.to(device=device, dtype=x.dtype, copy=True).unsqueeze(0).repeat(bsz, self.pad_tokens_multiple, 1) + siglip_pos_ids = torch.zeros((bsz, siglip_feats.shape[1], 3), dtype=torch.float32, device=device) if siglip_feats is None: embeds += (None,) @@ -724,8 +725,9 @@ class NextDiT(nn.Module): out = self.embed_all(ref, ref_con, sig_feat, offset=start_t, omni=omni, transformer_options=transformer_options) for i, e in enumerate(out[0]): - embeds[i].append(comfy.utils.repeat_to_batch_size(e, bsz)) - freqs_cis[i].append(out[1][i]) + if e is not None: + embeds[i].append(comfy.utils.repeat_to_batch_size(e, bsz)) + freqs_cis[i].append(out[1][i]) start_t = out[2] leftover_cap = ref_contexts[len(ref_latents):] @@ -759,7 +761,7 @@ class NextDiT(nn.Module): feats = (cap_feats,) fc = (cap_freqs_cis,) - if omni: + if omni and len(embeds[1]) > 0: siglip_mask = None siglip_feats_combined = torch.cat(embeds[1], dim=1) siglip_feats_freqs_cis = torch.cat(freqs_cis[1], dim=1) From ddc541ffdae0fe626de5a33192001f31c6ab93c6 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Tue, 20 Jan 2026 23:05:40 +0200 Subject: [PATCH 168/181] feat(api-nodes): add WaveSpeed nodes (#11945) --- comfy_api_nodes/apis/wavespeed.py | 35 ++++++ comfy_api_nodes/nodes_wavespeed.py | 178 +++++++++++++++++++++++++++++ 2 files changed, 213 insertions(+) create mode 100644 comfy_api_nodes/apis/wavespeed.py create mode 100644 comfy_api_nodes/nodes_wavespeed.py diff --git a/comfy_api_nodes/apis/wavespeed.py b/comfy_api_nodes/apis/wavespeed.py new file mode 100644 index 000000000..07a7bfa5d --- /dev/null +++ b/comfy_api_nodes/apis/wavespeed.py @@ -0,0 +1,35 @@ +from pydantic import BaseModel, Field + + +class SeedVR2ImageRequest(BaseModel): + image: str = Field(...) + target_resolution: str = Field(...) + output_format: str = Field("png") + enable_sync_mode: bool = Field(False) + + +class FlashVSRRequest(BaseModel): + target_resolution: str = Field(...) + video: str = Field(...) + duration: float = Field(...) + + +class TaskCreatedDataResponse(BaseModel): + id: str = Field(...) + + +class TaskCreatedResponse(BaseModel): + code: int = Field(...) + message: str = Field(...) + data: TaskCreatedDataResponse | None = Field(None) + + +class TaskResultDataResponse(BaseModel): + status: str = Field(...) 
+ outputs: list[str] = Field([]) + + +class TaskResultResponse(BaseModel): + code: int = Field(...) + message: str = Field(...) + data: TaskResultDataResponse | None = Field(None) diff --git a/comfy_api_nodes/nodes_wavespeed.py b/comfy_api_nodes/nodes_wavespeed.py new file mode 100644 index 000000000..c59fafd3b --- /dev/null +++ b/comfy_api_nodes/nodes_wavespeed.py @@ -0,0 +1,178 @@ +from typing_extensions import override + +from comfy_api.latest import IO, ComfyExtension, Input +from comfy_api_nodes.apis.wavespeed import ( + FlashVSRRequest, + TaskCreatedResponse, + TaskResultResponse, + SeedVR2ImageRequest, +) +from comfy_api_nodes.util import ( + ApiEndpoint, + download_url_to_video_output, + poll_op, + sync_op, + upload_video_to_comfyapi, + validate_container_format_is_mp4, + validate_video_duration, + upload_images_to_comfyapi, + get_number_of_images, + download_url_to_image_tensor, +) + + +class WavespeedFlashVSRNode(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="WavespeedFlashVSRNode", + display_name="FlashVSR Video Upscale", + category="api node/video/WaveSpeed", + description="Fast, high-quality video upscaler that " + "boosts resolution and restores clarity for low-resolution or blurry footage.", + inputs=[ + IO.Video.Input("video"), + IO.Combo.Input("target_resolution", options=["720p", "1080p", "2K", "4K"]), + ], + outputs=[ + IO.Video.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["target_resolution"]), + expr=""" + ( + $price_for_1sec := {"720p": 0.012, "1080p": 0.018, "2k": 0.024, "4k": 0.032}; + { + "type":"usd", + "usd": $lookup($price_for_1sec, widgets.target_resolution), + "format":{"suffix": "/second", "approximate": true} + } + ) + """, + ), + ) + + @classmethod + async def execute( + cls, + video: Input.Video, + target_resolution: str, + ) -> IO.NodeOutput: + validate_container_format_is_mp4(video) + validate_video_duration(video, min_duration=5, max_duration=60 * 10) + initial_res = await sync_op( + cls, + ApiEndpoint(path="/proxy/wavespeed/api/v3/wavespeed-ai/flashvsr", method="POST"), + response_model=TaskCreatedResponse, + data=FlashVSRRequest( + target_resolution=target_resolution.lower(), + video=await upload_video_to_comfyapi(cls, video), + duration=video.get_duration(), + ), + ) + if initial_res.code != 200: + raise ValueError(f"Task creation fails with code={initial_res.code} and message={initial_res.message}") + final_response = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/wavespeed/api/v3/predictions/{initial_res.data.id}/result"), + response_model=TaskResultResponse, + status_extractor=lambda x: "failed" if x.data is None else x.data.status, + poll_interval=10.0, + max_poll_attempts=480, + ) + if final_response.code != 200: + raise ValueError( + f"Task processing failed with code={final_response.code} and message={final_response.message}" + ) + return IO.NodeOutput(await download_url_to_video_output(final_response.data.outputs[0])) + + +class WavespeedImageUpscaleNode(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="WavespeedImageUpscaleNode", + display_name="WaveSpeed Image Upscale", + category="api node/image/WaveSpeed", + description="Boost image resolution and quality, upscaling photos to 4K or 8K for sharp, detailed results.", + inputs=[ + IO.Combo.Input("model", options=["SeedVR2", "Ultimate"]), + 
IO.Image.Input("image"),
+                IO.Combo.Input("target_resolution", options=["2K", "4K", "8K"]),
+            ],
+            outputs=[
+                IO.Image.Output(),
+            ],
+            hidden=[
+                IO.Hidden.auth_token_comfy_org,
+                IO.Hidden.api_key_comfy_org,
+                IO.Hidden.unique_id,
+            ],
+            is_api_node=True,
+            price_badge=IO.PriceBadge(
+                depends_on=IO.PriceBadgeDepends(widgets=["model"]),
+                expr="""
+                (
+                    $prices := {"seedvr2": 0.01, "ultimate": 0.06};
+                    {"type":"usd", "usd": $lookup($prices, widgets.model)}
+                )
+                """,
+            ),
+        )
+
+    @classmethod
+    async def execute(
+        cls,
+        model: str,
+        image: Input.Image,
+        target_resolution: str,
+    ) -> IO.NodeOutput:
+        if get_number_of_images(image) != 1:
+            raise ValueError("Exactly one input image is required.")
+        if model == "SeedVR2":
+            model_path = "seedvr2/image"
+        else:
+            model_path = "ultimate-image-upscaler"
+        initial_res = await sync_op(
+            cls,
+            ApiEndpoint(path=f"/proxy/wavespeed/api/v3/wavespeed-ai/{model_path}", method="POST"),
+            response_model=TaskCreatedResponse,
+            data=SeedVR2ImageRequest(
+                target_resolution=target_resolution.lower(),
+                image=(await upload_images_to_comfyapi(cls, image, max_images=1))[0],
+            ),
+        )
+        if initial_res.code != 200:
+            raise ValueError(f"Task creation fails with code={initial_res.code} and message={initial_res.message}")
+        final_response = await poll_op(
+            cls,
+            ApiEndpoint(path=f"/proxy/wavespeed/api/v3/predictions/{initial_res.data.id}/result"),
+            response_model=TaskResultResponse,
+            status_extractor=lambda x: "failed" if x.data is None else x.data.status,
+            poll_interval=10.0,
+            max_poll_attempts=480,
+        )
+        if final_response.code != 200:
+            raise ValueError(
+                f"Task processing failed with code={final_response.code} and message={final_response.message}"
+            )
+        return IO.NodeOutput(await download_url_to_image_tensor(final_response.data.outputs[0]))
+
+
+class WavespeedExtension(ComfyExtension):
+    @override
+    async def get_node_list(self) -> list[type[IO.ComfyNode]]:
+        return [
+            WavespeedFlashVSRNode,
+            WavespeedImageUpscaleNode,
+        ]
+
+
+async def comfy_entrypoint() -> WavespeedExtension:
+    return WavespeedExtension()

From 965d0ed509ce46a3328c342aee23a234ba6e4f88 Mon Sep 17 00:00:00 2001
From: Ivan Zorin
Date: Wed, 21 Jan 2026 01:44:28 +0200
Subject: [PATCH 169/181] fix: remove normalization of audio in LTX Mel spectrogram creation (#11990)

For LTX Audio VAE, remove normalization of audio during MEL spectrogram creation. This aligns inference with training and prevents loud audio from being attenuated.
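For reference, the `normalize_amplitude` helper removed in the diff below rescaled each clip so its global peak never exceeded `max_amplitude` (0.5 by default), which is why loud input came out quieter after encoding. A minimal standalone sketch of that removed behaviour (the example values are illustrative only):

```python
import torch

def normalize_amplitude(waveform: torch.Tensor, max_amplitude: float = 0.5, eps: float = 1e-5) -> torch.Tensor:
    # Same logic as the removed staticmethod: remove the DC offset, then rescale so
    # the global peak never exceeds max_amplitude.
    waveform = waveform - waveform.mean(dim=2, keepdim=True)
    peak = torch.max(torch.abs(waveform)) + eps
    scale = peak.clamp(max=max_amplitude) / peak  # < 1 whenever peak > max_amplitude
    return waveform * scale

loud = 0.9 * torch.sin(torch.linspace(0, 600, 48000)).reshape(1, 1, -1)  # peak ~= 0.9
print(normalize_amplitude(loud).abs().max())  # ~= 0.5, i.e. loud audio was attenuated
```

Clips whose peak is already at or below 0.5 keep their amplitude (scale = 1), so only loud material was affected.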
--- comfy/ldm/lightricks/vae/audio_vae.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/comfy/ldm/lightricks/vae/audio_vae.py b/comfy/ldm/lightricks/vae/audio_vae.py index 29d9e6c29..55a074661 100644 --- a/comfy/ldm/lightricks/vae/audio_vae.py +++ b/comfy/ldm/lightricks/vae/audio_vae.py @@ -103,20 +103,10 @@ class AudioPreprocessor: return waveform return torchaudio.functional.resample(waveform, source_rate, self.target_sample_rate) - @staticmethod - def normalize_amplitude( - waveform: torch.Tensor, max_amplitude: float = 0.5, eps: float = 1e-5 - ) -> torch.Tensor: - waveform = waveform - waveform.mean(dim=2, keepdim=True) - peak = torch.max(torch.abs(waveform)) + eps - scale = peak.clamp(max=max_amplitude) / peak - return waveform * scale - def waveform_to_mel( self, waveform: torch.Tensor, waveform_sample_rate: int, device ) -> torch.Tensor: waveform = self.resample(waveform, waveform_sample_rate) - waveform = self.normalize_amplitude(waveform) mel_transform = torchaudio.transforms.MelSpectrogram( sample_rate=self.target_sample_rate, From c4a14df9a35336dbfff096683c5015ce726c269d Mon Sep 17 00:00:00 2001 From: Mylo <36931363+gitmylo@users.noreply.github.com> Date: Wed, 21 Jan 2026 00:46:11 +0100 Subject: [PATCH 170/181] Dynamically detect chroma radiance patch size (#11991) --- comfy/model_detection.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 42884f797..dad206a2f 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -253,7 +253,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["image_model"] = "chroma_radiance" dit_config["in_channels"] = 3 dit_config["out_channels"] = 3 - dit_config["patch_size"] = 16 + dit_config["patch_size"] = state_dict.get('{}img_in_patch.weight'.format(key_prefix)).size(dim=-1) dit_config["nerf_hidden_size"] = 64 dit_config["nerf_mlp_ratio"] = 4 dit_config["nerf_depth"] = 4 From e755268e7b7843695f52b87595afcb09c1e9fd87 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 20 Jan 2026 20:08:31 -0800 Subject: [PATCH 171/181] Config for Qwen 3 0.6B model. 
(#11998) --- comfy/text_encoders/llama.py | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/comfy/text_encoders/llama.py b/comfy/text_encoders/llama.py index 331a30f61..3080a3e09 100644 --- a/comfy/text_encoders/llama.py +++ b/comfy/text_encoders/llama.py @@ -77,6 +77,28 @@ class Qwen25_3BConfig: rope_scale = None final_norm: bool = True +@dataclass +class Qwen3_06BConfig: + vocab_size: int = 151936 + hidden_size: int = 1024 + intermediate_size: int = 3072 + num_hidden_layers: int = 28 + num_attention_heads: int = 16 + num_key_value_heads: int = 8 + max_position_embeddings: int = 32768 + rms_norm_eps: float = 1e-6 + rope_theta: float = 1000000.0 + transformer_type: str = "llama" + head_dim = 128 + rms_norm_add = False + mlp_activation = "silu" + qkv_bias = False + rope_dims = None + q_norm = "gemma3" + k_norm = "gemma3" + rope_scale = None + final_norm: bool = True + @dataclass class Qwen3_4BConfig: vocab_size: int = 151936 @@ -641,6 +663,15 @@ class Qwen25_3B(BaseLlama, torch.nn.Module): self.model = Llama2_(config, device=device, dtype=dtype, ops=operations) self.dtype = dtype +class Qwen3_06B(BaseLlama, torch.nn.Module): + def __init__(self, config_dict, dtype, device, operations): + super().__init__() + config = Qwen3_06BConfig(**config_dict) + self.num_layers = config.num_hidden_layers + + self.model = Llama2_(config, device=device, dtype=dtype, ops=operations) + self.dtype = dtype + class Qwen3_4B(BaseLlama, torch.nn.Module): def __init__(self, config_dict, dtype, device, operations): super().__init__() From 0fc15700be9b555f351034942b5bd7243bdf6bcc Mon Sep 17 00:00:00 2001 From: Markury Date: Tue, 20 Jan 2026 23:18:33 -0500 Subject: [PATCH 172/181] Add LyCoris LoKr MLP layer support for Flux2 (#11997) --- comfy/utils.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/comfy/utils.py b/comfy/utils.py index 5e79fb449..d97d753e6 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -611,6 +611,14 @@ def flux_to_diffusers(mmdit_config, output_prefix=""): "ff_context.net.0.proj.bias": "txt_mlp.0.bias", "ff_context.net.2.weight": "txt_mlp.2.weight", "ff_context.net.2.bias": "txt_mlp.2.bias", + "ff.linear_in.weight": "img_mlp.0.weight", # LyCoris LoKr + "ff.linear_in.bias": "img_mlp.0.bias", + "ff.linear_out.weight": "img_mlp.2.weight", + "ff.linear_out.bias": "img_mlp.2.bias", + "ff_context.linear_in.weight": "txt_mlp.0.weight", + "ff_context.linear_in.bias": "txt_mlp.0.bias", + "ff_context.linear_out.weight": "txt_mlp.2.weight", + "ff_context.linear_out.bias": "txt_mlp.2.bias", "attn.norm_q.weight": "img_attn.norm.query_norm.scale", "attn.norm_k.weight": "img_attn.norm.key_norm.scale", "attn.norm_added_q.weight": "txt_attn.norm.query_norm.scale", From 451af7015435df22e6313ae79f25fe2ef336a96d Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Wed, 21 Jan 2026 14:03:45 +0200 Subject: [PATCH 173/181] fix(api-nodes-Vidu): allow passing up to 7 subjects in Vidu Reference node (#12002) --- comfy_api_nodes/nodes_vidu.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/comfy_api_nodes/nodes_vidu.py b/comfy_api_nodes/nodes_vidu.py index 8edb02f39..b9114c4bb 100644 --- a/comfy_api_nodes/nodes_vidu.py +++ b/comfy_api_nodes/nodes_vidu.py @@ -703,7 +703,7 @@ class Vidu2ReferenceVideoNode(IO.ComfyNode): "subjects", template=IO.Autogrow.TemplateNames( IO.Image.Input("reference_images"), - names=["subject1", "subject2", "subject3"], + names=["subject1", "subject2", "subject3", "subject4", 
"subject5", "subject6", "subject7"], min=1, ), tooltip="For each subject, provide up to 3 reference images (7 images total across all subjects). " @@ -738,7 +738,7 @@ class Vidu2ReferenceVideoNode(IO.ComfyNode): control_after_generate=True, ), IO.Combo.Input("aspect_ratio", options=["16:9", "9:16", "4:3", "3:4", "1:1"]), - IO.Combo.Input("resolution", options=["720p"]), + IO.Combo.Input("resolution", options=["720p", "1080p"]), IO.Combo.Input( "movement_amplitude", options=["auto", "small", "medium", "large"], From bdeac8897e522b9637a6a427fdc8a50a6abd6b20 Mon Sep 17 00:00:00 2001 From: Christian Byrne Date: Wed, 21 Jan 2026 15:36:02 -0800 Subject: [PATCH 174/181] feat: Add search_aliases field to node schema (#12010) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * feat: Add search_aliases field to node schema Adds `search_aliases` field to improve node discoverability. Users can define alternative search terms for nodes (e.g., "text concat" → StringConcatenate). Changes: - Add `search_aliases: list[str]` to V3 Schema - Add `SEARCH_ALIASES` support for V1 nodes - Include field in `/object_info` response - Add aliases to high-priority core nodes V1 usage: ```python class MyNode: SEARCH_ALIASES = ["alt name", "synonym"] ``` V3 usage: ```python io.Schema( node_id="MyNode", search_aliases=["alt name", "synonym"], ... ) ``` ## Related PRs - Frontend: Comfy-Org/ComfyUI_frontend#XXXX (draft - merge after this) - Docs: Comfy-Org/docs#XXXX (draft - merge after stable) * Propagate search_aliases through V3 Schema.get_v1_info to NodeInfoV1 --- comfy_api/latest/_io.py | 4 ++++ comfy_extras/nodes_post_processing.py | 1 + comfy_extras/nodes_preview_any.py | 1 + comfy_extras/nodes_string.py | 1 + comfy_extras/nodes_upscale_model.py | 1 + nodes.py | 15 +++++++++++++++ server.py | 2 ++ 7 files changed, 25 insertions(+) diff --git a/comfy_api/latest/_io.py b/comfy_api/latest/_io.py index 4969d3506..a60020ca8 100644 --- a/comfy_api/latest/_io.py +++ b/comfy_api/latest/_io.py @@ -1249,6 +1249,7 @@ class NodeInfoV1: experimental: bool=None api_node: bool=None price_badge: dict | None = None + search_aliases: list[str]=None @dataclass class NodeInfoV3: @@ -1346,6 +1347,8 @@ class Schema: hidden: list[Hidden] = field(default_factory=list) description: str="" """Node description, shown as a tooltip when hovering over the node.""" + search_aliases: list[str] = field(default_factory=list) + """Alternative names for search. Useful for synonyms, abbreviations, or old names after renaming.""" is_input_list: bool = False """A flag indicating if this node implements the additional code necessary to deal with OUTPUT_IS_LIST nodes. 
@@ -1483,6 +1486,7 @@ class Schema: api_node=self.is_api_node, python_module=getattr(cls, "RELATIVE_PYTHON_MODULE", "nodes"), price_badge=self.price_badge.as_dict(self.inputs) if self.price_badge is not None else None, + search_aliases=self.search_aliases if self.search_aliases else None, ) return info diff --git a/comfy_extras/nodes_post_processing.py b/comfy_extras/nodes_post_processing.py index 2e559c35c..6011275d6 100644 --- a/comfy_extras/nodes_post_processing.py +++ b/comfy_extras/nodes_post_processing.py @@ -550,6 +550,7 @@ class BatchImagesNode(io.ComfyNode): node_id="BatchImagesNode", display_name="Batch Images", category="image", + search_aliases=["batch", "image batch", "batch images", "combine images", "merge images", "stack images"], inputs=[ io.Autogrow.Input("images", template=autogrow_template) ], diff --git a/comfy_extras/nodes_preview_any.py b/comfy_extras/nodes_preview_any.py index 139b07c93..91502ebf2 100644 --- a/comfy_extras/nodes_preview_any.py +++ b/comfy_extras/nodes_preview_any.py @@ -16,6 +16,7 @@ class PreviewAny(): OUTPUT_NODE = True CATEGORY = "utils" + SEARCH_ALIASES = ["preview", "show", "display", "view", "show text", "display text", "preview text", "show output", "inspect", "debug"] def main(self, source=None): value = 'None' diff --git a/comfy_extras/nodes_string.py b/comfy_extras/nodes_string.py index 571d89f62..a2d5f0d94 100644 --- a/comfy_extras/nodes_string.py +++ b/comfy_extras/nodes_string.py @@ -11,6 +11,7 @@ class StringConcatenate(io.ComfyNode): node_id="StringConcatenate", display_name="Concatenate", category="utils/string", + search_aliases=["text concat", "join text", "merge text", "combine strings", "concat", "concatenate", "append text", "combine text", "string"], inputs=[ io.String.Input("string_a", multiline=True), io.String.Input("string_b", multiline=True), diff --git a/comfy_extras/nodes_upscale_model.py b/comfy_extras/nodes_upscale_model.py index ed587851c..97b9e948d 100644 --- a/comfy_extras/nodes_upscale_model.py +++ b/comfy_extras/nodes_upscale_model.py @@ -53,6 +53,7 @@ class ImageUpscaleWithModel(io.ComfyNode): node_id="ImageUpscaleWithModel", display_name="Upscale Image (using Model)", category="image/upscaling", + search_aliases=["upscale", "upscaler", "upsc", "enlarge image", "super resolution", "hires", "superres", "increase resolution"], inputs=[ io.UpscaleModel.Input("upscale_model"), io.Image.Input("image"), diff --git a/nodes.py b/nodes.py index ea5d6e525..67b61dcfe 100644 --- a/nodes.py +++ b/nodes.py @@ -70,6 +70,7 @@ class CLIPTextEncode(ComfyNodeABC): CATEGORY = "conditioning" DESCRIPTION = "Encodes a text prompt using a CLIP model into an embedding that can be used to guide the diffusion model towards generating specific images." + SEARCH_ALIASES = ["text", "prompt", "text prompt", "positive prompt", "negative prompt", "encode text", "text encoder", "encode prompt"] def encode(self, clip, text): if clip is None: @@ -86,6 +87,7 @@ class ConditioningCombine: FUNCTION = "combine" CATEGORY = "conditioning" + SEARCH_ALIASES = ["combine", "merge conditioning", "combine prompts", "merge prompts", "mix prompts", "add prompt"] def combine(self, conditioning_1, conditioning_2): return (conditioning_1 + conditioning_2, ) @@ -294,6 +296,7 @@ class VAEDecode: CATEGORY = "latent" DESCRIPTION = "Decodes latent images back into pixel space images." 
+ SEARCH_ALIASES = ["decode", "decode latent", "latent to image", "render latent"] def decode(self, vae, samples): latent = samples["samples"] @@ -346,6 +349,7 @@ class VAEEncode: FUNCTION = "encode" CATEGORY = "latent" + SEARCH_ALIASES = ["encode", "encode image", "image to latent"] def encode(self, vae, pixels): t = vae.encode(pixels) @@ -581,6 +585,7 @@ class CheckpointLoaderSimple: CATEGORY = "loaders" DESCRIPTION = "Loads a diffusion model checkpoint, diffusion models are used to denoise latents." + SEARCH_ALIASES = ["load model", "checkpoint", "model loader", "load checkpoint", "ckpt", "model"] def load_checkpoint(self, ckpt_name): ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name) @@ -667,6 +672,7 @@ class LoraLoader: CATEGORY = "loaders" DESCRIPTION = "LoRAs are used to modify diffusion and CLIP models, altering the way in which latents are denoised such as applying styles. Multiple LoRA nodes can be linked together." + SEARCH_ALIASES = ["lora", "load lora", "apply lora", "lora loader", "lora model"] def load_lora(self, model, clip, lora_name, strength_model, strength_clip): if strength_model == 0 and strength_clip == 0: @@ -814,6 +820,7 @@ class ControlNetLoader: FUNCTION = "load_controlnet" CATEGORY = "loaders" + SEARCH_ALIASES = ["controlnet", "control net", "cn", "load controlnet", "controlnet loader"] def load_controlnet(self, control_net_name): controlnet_path = folder_paths.get_full_path_or_raise("controlnet", control_net_name) @@ -890,6 +897,7 @@ class ControlNetApplyAdvanced: FUNCTION = "apply_controlnet" CATEGORY = "conditioning/controlnet" + SEARCH_ALIASES = ["controlnet", "apply controlnet", "use controlnet", "control net"] def apply_controlnet(self, positive, negative, control_net, image, strength, start_percent, end_percent, vae=None, extra_concat=[]): if strength == 0: @@ -1200,6 +1208,7 @@ class EmptyLatentImage: CATEGORY = "latent" DESCRIPTION = "Create a new batch of empty latent images to be denoised via sampling." + SEARCH_ALIASES = ["empty", "empty latent", "new latent", "create latent", "blank latent", "blank"] def generate(self, width, height, batch_size=1): latent = torch.zeros([batch_size, 4, height // 8, width // 8], device=self.device) @@ -1540,6 +1549,7 @@ class KSampler: CATEGORY = "sampling" DESCRIPTION = "Uses the provided model, positive and negative conditioning to denoise the latent image." + SEARCH_ALIASES = ["sampler", "sample", "generate", "denoise", "diffuse", "txt2img", "img2img"] def sample(self, model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0): return common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=denoise) @@ -1604,6 +1614,7 @@ class SaveImage: CATEGORY = "image" DESCRIPTION = "Saves the input images to your ComfyUI output directory." 
+ SEARCH_ALIASES = ["save", "save image", "export image", "output image", "write image", "download"] def save_images(self, images, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None): filename_prefix += self.prefix_append @@ -1640,6 +1651,8 @@ class PreviewImage(SaveImage): self.prefix_append = "_temp_" + ''.join(random.choice("abcdefghijklmnopqrstupvxyz") for x in range(5)) self.compress_level = 1 + SEARCH_ALIASES = ["preview", "preview image", "show image", "view image", "display image", "image viewer"] + @classmethod def INPUT_TYPES(s): return {"required": @@ -1658,6 +1671,7 @@ class LoadImage: } CATEGORY = "image" + SEARCH_ALIASES = ["load image", "open image", "import image", "image input", "upload image", "read image", "image loader"] RETURN_TYPES = ("IMAGE", "MASK") FUNCTION = "load_image" @@ -1810,6 +1824,7 @@ class ImageScale: FUNCTION = "upscale" CATEGORY = "image/upscaling" + SEARCH_ALIASES = ["resize", "resize image", "scale image", "image resize", "zoom", "zoom in", "change size"] def upscale(self, image, upscale_method, width, height, crop): if width == 0 and height == 0: diff --git a/server.py b/server.py index 04a577488..1888745b7 100644 --- a/server.py +++ b/server.py @@ -682,6 +682,8 @@ class PromptServer(): if hasattr(obj_class, 'API_NODE'): info['api_node'] = obj_class.API_NODE + + info['search_aliases'] = getattr(obj_class, 'SEARCH_ALIASES', []) return info @routes.get("/object_info") From abe2ec26a61ff670b9c0e71e4821c873368c8728 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 21 Jan 2026 16:44:28 -0800 Subject: [PATCH 175/181] Support the Anima model. (#12012) --- comfy/ldm/anima/model.py | 202 +++++++++++++++++++++++++++++++++++ comfy/model_base.py | 22 ++++ comfy/model_detection.py | 2 + comfy/sd.py | 7 ++ comfy/supported_models.py | 33 +++++- comfy/text_encoders/anima.py | 61 +++++++++++ 6 files changed, 326 insertions(+), 1 deletion(-) create mode 100644 comfy/ldm/anima/model.py create mode 100644 comfy/text_encoders/anima.py diff --git a/comfy/ldm/anima/model.py b/comfy/ldm/anima/model.py new file mode 100644 index 000000000..2e6ed58fa --- /dev/null +++ b/comfy/ldm/anima/model.py @@ -0,0 +1,202 @@ +from comfy.ldm.cosmos.predict2 import MiniTrainDIT +import torch +from torch import nn +import torch.nn.functional as F + + +def rotate_half(x): + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(x, cos, sin, unsqueeze_dim=1): + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + x_embed = (x * cos) + (rotate_half(x) * sin) + return x_embed + + +class RotaryEmbedding(nn.Module): + def __init__(self, head_dim): + super().__init__() + self.rope_theta = 10000 + inv_freq = 1.0 / (self.rope_theta ** (torch.arange(0, head_dim, 2, dtype=torch.int64).to(dtype=torch.float) / head_dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + @torch.no_grad() + def forward(self, x, position_ids): + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) + position_ids_expanded = position_ids[:, None, :].float() + + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): # Force float32 + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = 
emb.sin() + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +class Attention(nn.Module): + def __init__(self, query_dim, context_dim, n_heads, head_dim, device=None, dtype=None, operations=None): + super().__init__() + + inner_dim = head_dim * n_heads + self.n_heads = n_heads + self.head_dim = head_dim + self.query_dim = query_dim + self.context_dim = context_dim + + self.q_proj = operations.Linear(query_dim, inner_dim, bias=False, device=device, dtype=dtype) + self.q_norm = operations.RMSNorm(self.head_dim, eps=1e-6, device=device, dtype=dtype) + + self.k_proj = operations.Linear(context_dim, inner_dim, bias=False, device=device, dtype=dtype) + self.k_norm = operations.RMSNorm(self.head_dim, eps=1e-6, device=device, dtype=dtype) + + self.v_proj = operations.Linear(context_dim, inner_dim, bias=False, device=device, dtype=dtype) + + self.o_proj = operations.Linear(inner_dim, query_dim, bias=False, device=device, dtype=dtype) + + def forward(self, x, mask=None, context=None, position_embeddings=None, position_embeddings_context=None): + context = x if context is None else context + input_shape = x.shape[:-1] + q_shape = (*input_shape, self.n_heads, self.head_dim) + context_shape = context.shape[:-1] + kv_shape = (*context_shape, self.n_heads, self.head_dim) + + query_states = self.q_norm(self.q_proj(x).view(q_shape)).transpose(1, 2) + key_states = self.k_norm(self.k_proj(context).view(kv_shape)).transpose(1, 2) + value_states = self.v_proj(context).view(kv_shape).transpose(1, 2) + + if position_embeddings is not None: + assert position_embeddings_context is not None + cos, sin = position_embeddings + query_states = apply_rotary_pos_emb(query_states, cos, sin) + cos, sin = position_embeddings_context + key_states = apply_rotary_pos_emb(key_states, cos, sin) + + attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states, attn_mask=mask) + + attn_output = attn_output.transpose(1, 2).reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output + + def init_weights(self): + torch.nn.init.zeros_(self.o_proj.weight) + + +class TransformerBlock(nn.Module): + def __init__(self, source_dim, model_dim, num_heads=16, mlp_ratio=4.0, use_self_attn=False, layer_norm=False, device=None, dtype=None, operations=None): + super().__init__() + self.use_self_attn = use_self_attn + + if self.use_self_attn: + self.norm_self_attn = operations.LayerNorm(model_dim, device=device, dtype=dtype) if layer_norm else operations.RMSNorm(model_dim, eps=1e-6, device=device, dtype=dtype) + self.self_attn = Attention( + query_dim=model_dim, + context_dim=model_dim, + n_heads=num_heads, + head_dim=model_dim//num_heads, + device=device, + dtype=dtype, + operations=operations, + ) + + self.norm_cross_attn = operations.LayerNorm(model_dim, device=device, dtype=dtype) if layer_norm else operations.RMSNorm(model_dim, eps=1e-6, device=device, dtype=dtype) + self.cross_attn = Attention( + query_dim=model_dim, + context_dim=source_dim, + n_heads=num_heads, + head_dim=model_dim//num_heads, + device=device, + dtype=dtype, + operations=operations, + ) + + self.norm_mlp = operations.LayerNorm(model_dim, device=device, dtype=dtype) if layer_norm else operations.RMSNorm(model_dim, eps=1e-6, device=device, dtype=dtype) + self.mlp = nn.Sequential( + operations.Linear(model_dim, int(model_dim * mlp_ratio), device=device, dtype=dtype), + nn.GELU(), + operations.Linear(int(model_dim * mlp_ratio), model_dim, device=device, dtype=dtype) + ) + + def forward(self, x, context, 
target_attention_mask=None, source_attention_mask=None, position_embeddings=None, position_embeddings_context=None): + if self.use_self_attn: + normed = self.norm_self_attn(x) + attn_out = self.self_attn(normed, mask=target_attention_mask, position_embeddings=position_embeddings, position_embeddings_context=position_embeddings) + x = x + attn_out + + normed = self.norm_cross_attn(x) + attn_out = self.cross_attn(normed, mask=source_attention_mask, context=context, position_embeddings=position_embeddings, position_embeddings_context=position_embeddings_context) + x = x + attn_out + + x = x + self.mlp(self.norm_mlp(x)) + return x + + def init_weights(self): + torch.nn.init.zeros_(self.mlp[2].weight) + self.cross_attn.init_weights() + + +class LLMAdapter(nn.Module): + def __init__( + self, + source_dim=1024, + target_dim=1024, + model_dim=1024, + num_layers=6, + num_heads=16, + use_self_attn=True, + layer_norm=False, + device=None, + dtype=None, + operations=None, + ): + super().__init__() + + self.embed = operations.Embedding(32128, target_dim, device=device, dtype=dtype) + if model_dim != target_dim: + self.in_proj = operations.Linear(target_dim, model_dim, device=device, dtype=dtype) + else: + self.in_proj = nn.Identity() + self.rotary_emb = RotaryEmbedding(model_dim//num_heads) + self.blocks = nn.ModuleList([ + TransformerBlock(source_dim, model_dim, num_heads=num_heads, use_self_attn=use_self_attn, layer_norm=layer_norm, device=device, dtype=dtype, operations=operations) for _ in range(num_layers) + ]) + self.out_proj = operations.Linear(model_dim, target_dim, device=device, dtype=dtype) + self.norm = operations.RMSNorm(target_dim, eps=1e-6, device=device, dtype=dtype) + + def forward(self, source_hidden_states, target_input_ids, target_attention_mask=None, source_attention_mask=None): + if target_attention_mask is not None: + target_attention_mask = target_attention_mask.to(torch.bool) + if target_attention_mask.ndim == 2: + target_attention_mask = target_attention_mask.unsqueeze(1).unsqueeze(1) + + if source_attention_mask is not None: + source_attention_mask = source_attention_mask.to(torch.bool) + if source_attention_mask.ndim == 2: + source_attention_mask = source_attention_mask.unsqueeze(1).unsqueeze(1) + + x = self.in_proj(self.embed(target_input_ids)) + context = source_hidden_states + position_ids = torch.arange(x.shape[1], device=x.device).unsqueeze(0) + position_ids_context = torch.arange(context.shape[1], device=x.device).unsqueeze(0) + position_embeddings = self.rotary_emb(x, position_ids) + position_embeddings_context = self.rotary_emb(x, position_ids_context) + for block in self.blocks: + x = block(x, context, target_attention_mask=target_attention_mask, source_attention_mask=source_attention_mask, position_embeddings=position_embeddings, position_embeddings_context=position_embeddings_context) + return self.norm(self.out_proj(x)) + + +class Anima(MiniTrainDIT): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.llm_adapter = LLMAdapter(device=kwargs.get("device"), dtype=kwargs.get("dtype"), operations=kwargs.get("operations")) + + def preprocess_text_embeds(self, text_embeds, text_ids): + if text_ids is not None: + return self.llm_adapter(text_embeds, text_ids) + else: + return text_embeds diff --git a/comfy/model_base.py b/comfy/model_base.py index 28ba2643e..1d57562cc 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -49,6 +49,7 @@ import comfy.ldm.ace.model import comfy.ldm.omnigen.omnigen2 import comfy.ldm.qwen_image.model 
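The LLMAdapter above applies 1D rotary embeddings to both the re-embedded T5 token stream and the Qwen3 context before cross-attention. Rotary embedding is a pure rotation of channel pairs, so it leaves vector norms unchanged; a quick standalone check of that property with plain torch, using the same rotate_half recipe as the adapter (head_dim and sequence length are arbitrary illustration values):

import torch

def rotate_half(x):
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

head_dim, seq = 8, 5
inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))
freqs = torch.outer(torch.arange(seq).float(), inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)          # [seq, head_dim], matches RotaryEmbedding above
cos, sin = emb.cos(), emb.sin()

q = torch.randn(1, 1, seq, head_dim)
q_rot = q * cos + rotate_half(q) * sin           # same formula as apply_rotary_pos_emb
assert torch.allclose(q.norm(dim=-1), q_rot.norm(dim=-1), atol=1e-5)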
import comfy.ldm.kandinsky5.model +import comfy.ldm.anima.model import comfy.model_management import comfy.patcher_extension @@ -1147,6 +1148,27 @@ class CosmosPredict2(BaseModel): sigma = (sigma / (sigma + 1)) return latent_image / (1.0 - sigma) +class Anima(BaseModel): + def __init__(self, model_config, model_type=ModelType.FLOW, device=None): + super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.anima.model.Anima) + + def extra_conds(self, **kwargs): + out = super().extra_conds(**kwargs) + cross_attn = kwargs.get("cross_attn", None) + t5xxl_ids = kwargs.get("t5xxl_ids", None) + t5xxl_weights = kwargs.get("t5xxl_weights", None) + device = kwargs["device"] + if cross_attn is not None: + if t5xxl_ids is not None: + cross_attn = self.diffusion_model.preprocess_text_embeds(cross_attn.to(device=device, dtype=self.get_dtype()), t5xxl_ids.unsqueeze(0).to(device=device)) + if t5xxl_weights is not None: + cross_attn *= t5xxl_weights.unsqueeze(0).unsqueeze(-1).to(cross_attn) + + if cross_attn.shape[1] < 512: + cross_attn = torch.nn.functional.pad(cross_attn, (0, 0, 0, 512 - cross_attn.shape[1])) + out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) + return out + class Lumina2(BaseModel): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.lumina.model.NextDiT) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index dad206a2f..b29a033cc 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -550,6 +550,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): if '{}blocks.0.mlp.layer1.weight'.format(key_prefix) in state_dict_keys: # Cosmos predict2 dit_config = {} dit_config["image_model"] = "cosmos_predict2" + if "{}llm_adapter.blocks.0.cross_attn.q_proj.weight".format(key_prefix) in state_dict_keys: + dit_config["image_model"] = "anima" dit_config["max_img_h"] = 240 dit_config["max_img_w"] = 240 dit_config["max_frames"] = 128 diff --git a/comfy/sd.py b/comfy/sd.py index 77700dfd3..f7f6a44a0 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -57,6 +57,7 @@ import comfy.text_encoders.ovis import comfy.text_encoders.kandinsky5 import comfy.text_encoders.jina_clip_2 import comfy.text_encoders.newbie +import comfy.text_encoders.anima import comfy.model_patcher import comfy.lora @@ -1048,6 +1049,7 @@ class TEModel(Enum): GEMMA_3_12B = 18 JINA_CLIP_2 = 19 QWEN3_8B = 20 + QWEN3_06B = 21 def detect_te_model(sd): @@ -1093,6 +1095,8 @@ def detect_te_model(sd): return TEModel.QWEN3_2B elif weight.shape[0] == 4096: return TEModel.QWEN3_8B + elif weight.shape[0] == 1024: + return TEModel.QWEN3_06B if weight.shape[0] == 5120: if "model.layers.39.post_attention_layernorm.weight" in sd: return TEModel.MISTRAL3_24B @@ -1233,6 +1237,9 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip elif te_model == TEModel.JINA_CLIP_2: clip_target.clip = comfy.text_encoders.jina_clip_2.JinaClip2TextModelWrapper clip_target.tokenizer = comfy.text_encoders.jina_clip_2.JinaClip2TokenizerWrapper + elif te_model == TEModel.QWEN3_06B: + clip_target.clip = comfy.text_encoders.anima.te(**llama_detect(clip_data)) + clip_target.tokenizer = comfy.text_encoders.anima.AnimaTokenizer else: # clip_l if clip_type == CLIPType.SD3: diff --git a/comfy/supported_models.py b/comfy/supported_models.py index c8a7f6efb..70abebf46 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -23,6 +23,7 @@ import 
comfy.text_encoders.qwen_image import comfy.text_encoders.hunyuan_image import comfy.text_encoders.kandinsky5 import comfy.text_encoders.z_image +import comfy.text_encoders.anima from . import supported_models_base from . import latent_formats @@ -992,6 +993,36 @@ class CosmosT2IPredict2(supported_models_base.BASE): t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref)) return supported_models_base.ClipTarget(comfy.text_encoders.cosmos.CosmosT5Tokenizer, comfy.text_encoders.cosmos.te(**t5_detect)) +class Anima(supported_models_base.BASE): + unet_config = { + "image_model": "anima", + } + + sampling_settings = { + "multiplier": 1.0, + "shift": 3.0, + } + + unet_extra_config = {} + latent_format = latent_formats.Wan21 + + memory_usage_factor = 1.0 + + supported_inference_dtypes = [torch.bfloat16, torch.float32] + + def __init__(self, unet_config): + super().__init__(unet_config) + self.memory_usage_factor = (unet_config.get("model_channels", 2048) / 2048) * 0.95 + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.Anima(self, device=device) + return out + + def clip_target(self, state_dict={}): + pref = self.text_encoder_key_prefix[0] + detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen3_06b.transformer.".format(pref)) + return supported_models_base.ClipTarget(comfy.text_encoders.anima.AnimaTokenizer, comfy.text_encoders.anima.te(**detect)) + class CosmosI2VPredict2(CosmosT2IPredict2): unet_config = { "image_model": "cosmos_predict2", @@ -1551,6 +1582,6 @@ class Kandinsky5Image(Kandinsky5): return supported_models_base.ClipTarget(comfy.text_encoders.kandinsky5.Kandinsky5TokenizerImage, comfy.text_encoders.kandinsky5.te(**hunyuan_detect)) -models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5] +models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima] models += [SVD_img2vid] diff --git 
a/comfy/text_encoders/anima.py b/comfy/text_encoders/anima.py new file mode 100644 index 000000000..41f95bcb6 --- /dev/null +++ b/comfy/text_encoders/anima.py @@ -0,0 +1,61 @@ +from transformers import Qwen2Tokenizer, T5TokenizerFast +import comfy.text_encoders.llama +from comfy import sd1_clip +import os +import torch + + +class Qwen3Tokenizer(sd1_clip.SDTokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "qwen25_tokenizer") + super().__init__(tokenizer_path, pad_with_end=False, embedding_size=1024, embedding_key='qwen3_06b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, pad_token=151643, tokenizer_data=tokenizer_data) + +class T5XXLTokenizer(sd1_clip.SDTokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer") + super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_data=tokenizer_data) + +class AnimaTokenizer: + def __init__(self, embedding_directory=None, tokenizer_data={}): + self.qwen3_06b = Qwen3Tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data) + self.t5xxl = T5XXLTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data) + + def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs): + out = {} + qwen_ids = self.qwen3_06b.tokenize_with_weights(text, return_word_ids, **kwargs) + out["qwen3_06b"] = [[(token, 1.0) for token, _ in inner_list] for inner_list in qwen_ids] # Set weights to 1.0 + out["t5xxl"] = self.t5xxl.tokenize_with_weights(text, return_word_ids, **kwargs) + return out + + def untokenize(self, token_weight_pair): + return self.t5xxl.untokenize(token_weight_pair) + + def state_dict(self): + return {} + + +class Qwen3_06BModel(sd1_clip.SDClipModel): + def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options={}): + super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Qwen3_06B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options) + + +class AnimaTEModel(sd1_clip.SD1ClipModel): + def __init__(self, device="cpu", dtype=None, model_options={}): + super().__init__(device=device, dtype=dtype, name="qwen3_06b", clip_model=Qwen3_06BModel, model_options=model_options) + + def encode_token_weights(self, token_weight_pairs): + out = super().encode_token_weights(token_weight_pairs) + out[2]["t5xxl_ids"] = torch.tensor(list(map(lambda a: a[0], token_weight_pairs["t5xxl"][0])), dtype=torch.int) + out[2]["t5xxl_weights"] = torch.tensor(list(map(lambda a: a[1], token_weight_pairs["t5xxl"][0]))) + return out + +def te(dtype_llama=None, llama_quantization_metadata=None): + class AnimaTEModel_(AnimaTEModel): + def __init__(self, device="cpu", dtype=None, model_options={}): + if dtype_llama is not None: + dtype = dtype_llama + if llama_quantization_metadata is not None: + model_options = model_options.copy() + model_options["quantization_metadata"] = 
llama_quantization_metadata + super().__init__(device=device, dtype=dtype, model_options=model_options) + return AnimaTEModel_ From f09904720dc8b56bc6823ebdaf5de69465448e46 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Wed, 21 Jan 2026 20:01:35 -0800 Subject: [PATCH 176/181] Fix for edge case of EasyCache when conditionings change during a sampling run (like with timestep scheduling) (#12020) --- comfy_extras/nodes_easycache.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/comfy_extras/nodes_easycache.py b/comfy_extras/nodes_easycache.py index 11b23ffdb..90d730df6 100644 --- a/comfy_extras/nodes_easycache.py +++ b/comfy_extras/nodes_easycache.py @@ -29,8 +29,10 @@ def easycache_forward_wrapper(executor, *args, **kwargs): do_easycache = easycache.should_do_easycache(sigmas) if do_easycache: easycache.check_metadata(x) + # if there isn't a cache diff for current conds, we cannot skip this step + can_apply_cache_diff = easycache.can_apply_cache_diff(uuids) # if first cond marked this step for skipping, skip it and use appropriate cached values - if easycache.skip_current_step: + if easycache.skip_current_step and can_apply_cache_diff: if easycache.verbose: logging.info(f"EasyCache [verbose] - was marked to skip this step by {easycache.first_cond_uuid}. Present uuids: {uuids}") return easycache.apply_cache_diff(x, uuids) @@ -44,7 +46,7 @@ def easycache_forward_wrapper(executor, *args, **kwargs): if easycache.has_output_prev_norm() and easycache.has_relative_transformation_rate(): approx_output_change_rate = (easycache.relative_transformation_rate * input_change) / easycache.output_prev_norm easycache.cumulative_change_rate += approx_output_change_rate - if easycache.cumulative_change_rate < easycache.reuse_threshold: + if easycache.cumulative_change_rate < easycache.reuse_threshold and can_apply_cache_diff: if easycache.verbose: logging.info(f"EasyCache [verbose] - skipping step; cumulative_change_rate: {easycache.cumulative_change_rate}, reuse_threshold: {easycache.reuse_threshold}") # other conds should also skip this step, and instead use their cached values @@ -240,6 +242,9 @@ class EasyCacheHolder: return to_return.clone() return to_return + def can_apply_cache_diff(self, uuids: list[UUID]) -> bool: + return all(uuid in self.uuid_cache_diffs for uuid in uuids) + def apply_cache_diff(self, x: torch.Tensor, uuids: list[UUID]): if self.first_cond_uuid in uuids: self.total_steps_skipped += 1 From 3365ad18a5e0c86b23c6272e5adcedd333fc45cb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jukka=20Sepp=C3=A4nen?= <40791699+kijai@users.noreply.github.com> Date: Thu, 22 Jan 2026 06:03:51 +0200 Subject: [PATCH 177/181] Support LTX2 tiny vae (taeltx_2) (#11929) --- comfy/sd.py | 5 ++--- comfy/taesd/taehv.py | 53 ++++++++++++++++++++++++++++---------------- latent_preview.py | 2 +- nodes.py | 2 +- 4 files changed, 38 insertions(+), 24 deletions(-) diff --git a/comfy/sd.py b/comfy/sd.py index f7f6a44a0..ce7e6bcff 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -636,14 +636,13 @@ class VAE: self.upscale_index_formula = (4, 16, 16) self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 16, 16) self.downscale_index_formula = (4, 16, 16) - if self.latent_channels == 48: # Wan 2.2 + if self.latent_channels in [48, 128]: # Wan 2.2 and LTX2 self.first_stage_model = comfy.taesd.taehv.TAEHV(latent_channels=self.latent_channels, latent_format=None) # taehv doesn't need scaling - self.process_input = lambda image: (_ for _ in ()).throw(NotImplementedError("This light tae 
doesn't support encoding currently")) + self.process_input = self.process_output = lambda image: image self.process_output = lambda image: image self.memory_used_decode = lambda shape, dtype: (1800 * (max(1, (shape[-3] ** 0.7 * 0.1)) * shape[-2] * shape[-1] * 16 * 16) * model_management.dtype_size(dtype)) elif self.latent_channels == 32 and sd["decoder.22.bias"].shape[0] == 12: # lighttae_hv15 self.first_stage_model = comfy.taesd.taehv.TAEHV(latent_channels=self.latent_channels, latent_format=comfy.latent_formats.HunyuanVideo15) - self.process_input = lambda image: (_ for _ in ()).throw(NotImplementedError("This light tae doesn't support encoding currently")) self.memory_used_decode = lambda shape, dtype: (1200 * (max(1, (shape[-3] ** 0.7 * 0.05)) * shape[-2] * shape[-1] * 32 * 32) * model_management.dtype_size(dtype)) else: if sd["decoder.1.weight"].dtype == torch.float16: # taehv currently only available in float16, so assume it's not lighttaew2_1 as otherwise state dicts are identical diff --git a/comfy/taesd/taehv.py b/comfy/taesd/taehv.py index 0e5f9a378..6c06ce19d 100644 --- a/comfy/taesd/taehv.py +++ b/comfy/taesd/taehv.py @@ -112,7 +112,8 @@ def apply_model_with_memblocks(model, x, parallel, show_progress_bar): class TAEHV(nn.Module): - def __init__(self, latent_channels, parallel=False, decoder_time_upscale=(True, True), decoder_space_upscale=(True, True, True), latent_format=None, show_progress_bar=True): + def __init__(self, latent_channels, parallel=False, encoder_time_downscale=(True, True, False), decoder_time_upscale=(False, True, True), decoder_space_upscale=(True, True, True), + latent_format=None, show_progress_bar=False): super().__init__() self.image_channels = 3 self.patch_size = 1 @@ -124,6 +125,9 @@ class TAEHV(nn.Module): self.process_out = latent_format().process_out if latent_format is not None else (lambda x: x) if self.latent_channels in [48, 32]: # Wan 2.2 and HunyuanVideo1.5 self.patch_size = 2 + elif self.latent_channels == 128: # LTX2 + self.patch_size, self.latent_channels, encoder_time_downscale, decoder_time_upscale = 4, 128, (True, True, True), (True, True, True) + if self.latent_channels == 32: # HunyuanVideo1.5 act_func = nn.LeakyReLU(0.2, inplace=True) else: # HunyuanVideo, Wan 2.1 @@ -131,41 +135,52 @@ class TAEHV(nn.Module): self.encoder = nn.Sequential( conv(self.image_channels*self.patch_size**2, 64), act_func, - TPool(64, 2), conv(64, 64, stride=2, bias=False), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), - TPool(64, 2), conv(64, 64, stride=2, bias=False), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), - TPool(64, 1), conv(64, 64, stride=2, bias=False), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), + TPool(64, 2 if encoder_time_downscale[0] else 1), conv(64, 64, stride=2, bias=False), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), + TPool(64, 2 if encoder_time_downscale[1] else 1), conv(64, 64, stride=2, bias=False), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), + TPool(64, 2 if encoder_time_downscale[2] else 1), conv(64, 64, stride=2, bias=False), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), conv(64, self.latent_channels), ) n_f = [256, 128, 64, 64] - self.frames_to_trim = 2**sum(decoder_time_upscale) - 1 + self.decoder = nn.Sequential( Clamp(), conv(self.latent_channels, n_f[0]), act_func, - MemBlock(n_f[0], n_f[0], 
act_func), MemBlock(n_f[0], n_f[0], act_func), MemBlock(n_f[0], n_f[0], act_func), nn.Upsample(scale_factor=2 if decoder_space_upscale[0] else 1), TGrow(n_f[0], 1), conv(n_f[0], n_f[1], bias=False), - MemBlock(n_f[1], n_f[1], act_func), MemBlock(n_f[1], n_f[1], act_func), MemBlock(n_f[1], n_f[1], act_func), nn.Upsample(scale_factor=2 if decoder_space_upscale[1] else 1), TGrow(n_f[1], 2 if decoder_time_upscale[0] else 1), conv(n_f[1], n_f[2], bias=False), - MemBlock(n_f[2], n_f[2], act_func), MemBlock(n_f[2], n_f[2], act_func), MemBlock(n_f[2], n_f[2], act_func), nn.Upsample(scale_factor=2 if decoder_space_upscale[2] else 1), TGrow(n_f[2], 2 if decoder_time_upscale[1] else 1), conv(n_f[2], n_f[3], bias=False), + MemBlock(n_f[0], n_f[0], act_func), MemBlock(n_f[0], n_f[0], act_func), MemBlock(n_f[0], n_f[0], act_func), nn.Upsample(scale_factor=2 if decoder_space_upscale[0] else 1), TGrow(n_f[0], 2 if decoder_time_upscale[0] else 1), conv(n_f[0], n_f[1], bias=False), + MemBlock(n_f[1], n_f[1], act_func), MemBlock(n_f[1], n_f[1], act_func), MemBlock(n_f[1], n_f[1], act_func), nn.Upsample(scale_factor=2 if decoder_space_upscale[1] else 1), TGrow(n_f[1], 2 if decoder_time_upscale[1] else 1), conv(n_f[1], n_f[2], bias=False), + MemBlock(n_f[2], n_f[2], act_func), MemBlock(n_f[2], n_f[2], act_func), MemBlock(n_f[2], n_f[2], act_func), nn.Upsample(scale_factor=2 if decoder_space_upscale[2] else 1), TGrow(n_f[2], 2 if decoder_time_upscale[2] else 1), conv(n_f[2], n_f[3], bias=False), act_func, conv(n_f[3], self.image_channels*self.patch_size**2), ) - @property - def show_progress_bar(self): - return self._show_progress_bar - @show_progress_bar.setter - def show_progress_bar(self, value): - self._show_progress_bar = value + self.t_downscale = 2**sum(t.stride == 2 for t in self.encoder if isinstance(t, TPool)) + self.t_upscale = 2**sum(t.stride == 2 for t in self.decoder if isinstance(t, TGrow)) + self.frames_to_trim = self.t_upscale - 1 + self._show_progress_bar = show_progress_bar + + @property + def show_progress_bar(self): + return self._show_progress_bar + + @show_progress_bar.setter + def show_progress_bar(self, value): + self._show_progress_bar = value def encode(self, x, **kwargs): - if self.patch_size > 1: - x = F.pixel_unshuffle(x, self.patch_size) x = x.movedim(2, 1) # [B, C, T, H, W] -> [B, T, C, H, W] - if x.shape[1] % 4 != 0: - # pad at end to multiple of 4 - n_pad = 4 - x.shape[1] % 4 + if self.patch_size > 1: + B, T, C, H, W = x.shape + x = x.reshape(B * T, C, H, W) + x = F.pixel_unshuffle(x, self.patch_size) + x = x.reshape(B, T, C * self.patch_size ** 2, H // self.patch_size, W // self.patch_size) + if x.shape[1] % self.t_downscale != 0: + # pad at end to multiple of t_downscale + n_pad = self.t_downscale - x.shape[1] % self.t_downscale padding = x[:, -1:].repeat_interleave(n_pad, dim=1) x = torch.cat([x, padding], 1) x = apply_model_with_memblocks(self.encoder, x, self.parallel, self.show_progress_bar).movedim(2, 1) return self.process_out(x) def decode(self, x, **kwargs): + x = x.unsqueeze(0) if x.ndim == 4 else x # [T, C, H, W] -> [1, T, C, H, W] + x = x.movedim(1, 2) if x.shape[1] != self.latent_channels else x # [B, T, C, H, W] or [B, C, T, H, W] x = self.process_in(x).movedim(2, 1) # [B, C, T, H, W] -> [B, T, C, H, W] x = apply_model_with_memblocks(self.decoder, x, self.parallel, self.show_progress_bar) if self.patch_size > 1: diff --git a/latent_preview.py b/latent_preview.py index d52e3f7a1..a9d777661 100644 --- a/latent_preview.py +++ b/latent_preview.py @@ -11,7 +11,7 @@ 
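The reworked TAEHV.encode above runs pixel_unshuffle per frame before the conv stack; with LTX2's patch_size of 4, each 3-channel frame becomes 48 channels at a quarter of the spatial resolution, and the frame count is padded up to a multiple of t_downscale. A quick shape check of that patching step with plain torch (tensor sizes are illustrative):

import torch
import torch.nn.functional as F

B, T, C, H, W = 1, 8, 3, 64, 64
patch = 4
x = torch.randn(B * T, C, H, W)
x = F.pixel_unshuffle(x, patch)   # fold each patch x patch spatial block into channels
assert x.shape == (B * T, C * patch ** 2, H // patch, W // patch)   # (8, 48, 16, 16)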
import logging default_preview_method = args.preview_method MAX_PREVIEW_RESOLUTION = args.preview_size -VIDEO_TAES = ["taehv", "lighttaew2_2", "lighttaew2_1", "lighttaehy1_5"] +VIDEO_TAES = ["taehv", "lighttaew2_2", "lighttaew2_1", "lighttaehy1_5", "taeltx_2"] def preview_to_image(latent_image, do_scale=True): if do_scale: diff --git a/nodes.py b/nodes.py index 67b61dcfe..8864fda60 100644 --- a/nodes.py +++ b/nodes.py @@ -707,7 +707,7 @@ class LoraLoaderModelOnly(LoraLoader): return (self.load_lora(model, None, lora_name, strength_model, 0)[0],) class VAELoader: - video_taes = ["taehv", "lighttaew2_2", "lighttaew2_1", "lighttaehy1_5"] + video_taes = ["taehv", "lighttaew2_2", "lighttaew2_1", "lighttaehy1_5", "taeltx_2"] image_taes = ["taesd", "taesdxl", "taesd3", "taef1"] @staticmethod def vae_list(s): From 245f6139b65899112d11ff294d36a820f2d69496 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jukka=20Sepp=C3=A4nen?= <40791699+kijai@users.noreply.github.com> Date: Thu, 22 Jan 2026 06:05:06 +0200 Subject: [PATCH 178/181] More targeted embedding_connector loading for LTX2 text encoder (#11992) Reduces errors --- comfy/text_encoders/lt.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/comfy/text_encoders/lt.py b/comfy/text_encoders/lt.py index c33c77db7..e49161964 100644 --- a/comfy/text_encoders/lt.py +++ b/comfy/text_encoders/lt.py @@ -118,9 +118,18 @@ class LTXAVTEModel(torch.nn.Module): sdo = comfy.utils.state_dict_prefix_replace(sd, {"text_embedding_projection.aggregate_embed.weight": "text_embedding_projection.weight", "model.diffusion_model.video_embeddings_connector.": "video_embeddings_connector.", "model.diffusion_model.audio_embeddings_connector.": "audio_embeddings_connector."}, filter_keys=True) if len(sdo) == 0: sdo = sd - missing, unexpected = self.load_state_dict(sdo, strict=False) - missing = [k for k in missing if not k.startswith("gemma3_12b.")] # filter out keys that belong to the main gemma model - return (missing, unexpected) + + missing_all = [] + unexpected_all = [] + + for prefix, component in [("text_embedding_projection.", self.text_embedding_projection), ("video_embeddings_connector.", self.video_embeddings_connector), ("audio_embeddings_connector.", self.audio_embeddings_connector)]: + component_sd = {k.replace(prefix, ""): v for k, v in sdo.items() if k.startswith(prefix)} + if component_sd: + missing, unexpected = component.load_state_dict(component_sd, strict=False) + missing_all.extend([f"{prefix}{k}" for k in missing]) + unexpected_all.extend([f"{prefix}{k}" for k in unexpected]) + + return (missing_all, unexpected_all) def memory_estimation_function(self, token_weight_pairs, device=None): constant = 6.0 From 16b9aabd52c3b81b365fbf562bbcc4528111ef6b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jukka=20Sepp=C3=A4nen?= <40791699+kijai@users.noreply.github.com> Date: Thu, 22 Jan 2026 06:09:48 +0200 Subject: [PATCH 179/181] Support Multi/InfiniteTalk (#10179) * re-init * Update model_multitalk.py * whitespace... 
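The SingleStreamMultiAttention patch added below separates two speakers by mapping each token's reference-attention statistics onto disjoint rotary position ranges (class_interval=4 and class_range=24, so speaker one lands in [0, 4), speaker two in [20, 24), and background tokens sit at 12). A toy run of the normalize_and_scale helper it relies on, reproduced here from the diff below so it can be executed standalone, with invented attention values:

import torch

def normalize_and_scale(column, source_range, target_range, epsilon=1e-8):
    source_min, source_max = source_range
    new_min, new_max = target_range
    normalized = (column - source_min) / (source_max - source_min + epsilon)
    return normalized * (new_max - new_min) + new_min

attn = torch.tensor([0.1, 0.9, 0.5])
print(normalize_and_scale(attn, (attn.min(), attn.max()), (0, 4)))    # ~[0.0, 4.0, 2.0]
print(normalize_and_scale(attn, (attn.min(), attn.max()), (20, 24)))  # ~[20.0, 24.0, 22.0]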
* Update model_multitalk.py * remove print * this is redundant * remove import * Restore preview functionality * Move block_idx to transformer_options * Remove LoopingSamplerCustomAdvanced * Remove looping functionality, keep extension functionality * Update model_multitalk.py * Handle ref_attn_mask with separate patch to avoid having to always return q and k from self_attn * Chunk attention map calculation for multiple speakers to reduce peak VRAM usage * Update model_multitalk.py * Add ModelPatch type back * Fix for latest upstream * Use DynamicCombo for cleaner node Basically just so that single_speaker mode hides mask inputs and 2nd audio input * Update nodes_wan.py --- comfy/ldm/wan/model.py | 17 +- comfy/ldm/wan/model_multitalk.py | 500 ++++++++++++++++++++++++++++++ comfy_api/latest/_io.py | 3 +- comfy_extras/nodes_model_patch.py | 41 +++ comfy_extras/nodes_wan.py | 169 +++++++++- 5 files changed, 727 insertions(+), 3 deletions(-) create mode 100644 comfy/ldm/wan/model_multitalk.py diff --git a/comfy/ldm/wan/model.py b/comfy/ldm/wan/model.py index 4216ce831..ea123acb4 100644 --- a/comfy/ldm/wan/model.py +++ b/comfy/ldm/wan/model.py @@ -62,6 +62,8 @@ class WanSelfAttention(nn.Module): x(Tensor): Shape [B, L, num_heads, C / num_heads] freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2] """ + patches = transformer_options.get("patches", {}) + b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim def qkv_fn_q(x): @@ -86,6 +88,10 @@ class WanSelfAttention(nn.Module): transformer_options=transformer_options, ) + if "attn1_patch" in patches: + for p in patches["attn1_patch"]: + x = p({"x": x, "q": q, "k": k, "transformer_options": transformer_options}) + x = self.o(x) return x @@ -225,6 +231,8 @@ class WanAttentionBlock(nn.Module): """ # assert e.dtype == torch.float32 + patches = transformer_options.get("patches", {}) + if e.ndim < 4: e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device) + e).chunk(6, dim=1) else: @@ -242,6 +250,11 @@ class WanAttentionBlock(nn.Module): # cross-attention & ffn x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len, transformer_options=transformer_options) + + if "attn2_patch" in patches: + for p in patches["attn2_patch"]: + x = p({"x": x, "transformer_options": transformer_options}) + y = self.ffn(torch.addcmul(repeat_e(e[3], x), self.norm2(x), 1 + repeat_e(e[4], x))) x = torch.addcmul(x, y, repeat_e(e[5], x)) return x @@ -488,7 +501,7 @@ class WanModel(torch.nn.Module): self.blocks = nn.ModuleList([ wan_attn_block_class(cross_attn_type, dim, ffn_dim, num_heads, window_size, qk_norm, cross_attn_norm, eps, operation_settings=operation_settings) - for _ in range(num_layers) + for i in range(num_layers) ]) # head @@ -541,6 +554,7 @@ class WanModel(torch.nn.Module): # embeddings x = self.patch_embedding(x.float()).to(x.dtype) grid_sizes = x.shape[2:] + transformer_options["grid_sizes"] = grid_sizes x = x.flatten(2).transpose(1, 2) # time embeddings @@ -738,6 +752,7 @@ class VaceWanModel(WanModel): # embeddings x = self.patch_embedding(x.float()).to(x.dtype) grid_sizes = x.shape[2:] + transformer_options["grid_sizes"] = grid_sizes x = x.flatten(2).transpose(1, 2) # time embeddings diff --git a/comfy/ldm/wan/model_multitalk.py b/comfy/ldm/wan/model_multitalk.py new file mode 100644 index 000000000..c9dd98c4d --- /dev/null +++ b/comfy/ldm/wan/model_multitalk.py @@ -0,0 +1,500 @@ +import torch +from einops import rearrange, repeat +import comfy +from comfy.ldm.modules.attention import 
optimized_attention + + +def calculate_x_ref_attn_map(visual_q, ref_k, ref_target_masks, split_num=8): + scale = 1.0 / visual_q.shape[-1] ** 0.5 + visual_q = visual_q.transpose(1, 2) * scale + + B, H, x_seqlens, K = visual_q.shape + + x_ref_attn_maps = [] + for class_idx, ref_target_mask in enumerate(ref_target_masks): + ref_target_mask = ref_target_mask.view(1, 1, 1, -1) + + x_ref_attnmap = torch.zeros(B, H, x_seqlens, device=visual_q.device, dtype=visual_q.dtype) + chunk_size = min(max(x_seqlens // split_num, 1), x_seqlens) + + for i in range(0, x_seqlens, chunk_size): + end_i = min(i + chunk_size, x_seqlens) + + attn_chunk = visual_q[:, :, i:end_i] @ ref_k.permute(0, 2, 3, 1) # B, H, chunk, ref_seqlens + + # Apply softmax + attn_max = attn_chunk.max(dim=-1, keepdim=True).values + attn_chunk = (attn_chunk - attn_max).exp() + attn_sum = attn_chunk.sum(dim=-1, keepdim=True) + attn_chunk = attn_chunk / (attn_sum + 1e-8) + + # Apply mask and sum + masked_attn = attn_chunk * ref_target_mask + x_ref_attnmap[:, :, i:end_i] = masked_attn.sum(-1) / (ref_target_mask.sum() + 1e-8) + + del attn_chunk, masked_attn + + # Average across heads + x_ref_attnmap = x_ref_attnmap.mean(dim=1) # B, x_seqlens + x_ref_attn_maps.append(x_ref_attnmap) + + del visual_q, ref_k + + return torch.cat(x_ref_attn_maps, dim=0) + +def get_attn_map_with_target(visual_q, ref_k, shape, ref_target_masks=None, split_num=2): + """Args: + query (torch.tensor): B M H K + key (torch.tensor): B M H K + shape (tuple): (N_t, N_h, N_w) + ref_target_masks: [B, N_h * N_w] + """ + + N_t, N_h, N_w = shape + + x_seqlens = N_h * N_w + ref_k = ref_k[:, :x_seqlens] + _, seq_lens, heads, _ = visual_q.shape + class_num, _ = ref_target_masks.shape + x_ref_attn_maps = torch.zeros(class_num, seq_lens).to(visual_q) + + split_chunk = heads // split_num + + for i in range(split_num): + x_ref_attn_maps_perhead = calculate_x_ref_attn_map( + visual_q[:, :, i*split_chunk:(i+1)*split_chunk, :], + ref_k[:, :, i*split_chunk:(i+1)*split_chunk, :], + ref_target_masks + ) + x_ref_attn_maps += x_ref_attn_maps_perhead + + return x_ref_attn_maps / split_num + + +def normalize_and_scale(column, source_range, target_range, epsilon=1e-8): + source_min, source_max = source_range + new_min, new_max = target_range + normalized = (column - source_min) / (source_max - source_min + epsilon) + scaled = normalized * (new_max - new_min) + new_min + return scaled + + +def rotate_half(x): + x = rearrange(x, "... (d r) -> ... d r", r=2) + x1, x2 = x.unbind(dim=-1) + x = torch.stack((-x2, x1), dim=-1) + return rearrange(x, "... d r -> ... 
(d r)") + + +def get_audio_embeds(encoded_audio, audio_start, audio_end): + audio_embs = [] + human_num = len(encoded_audio) + audio_frames = encoded_audio[0].shape[0] + + indices = (torch.arange(4 + 1) - 2) * 1 + + for human_idx in range(human_num): + if audio_end > audio_frames: # in case of not enough audio for current window, pad with first audio frame as that's most likely silence + pad_len = audio_end - audio_frames + pad_shape = list(encoded_audio[human_idx].shape) + pad_shape[0] = pad_len + pad_tensor = encoded_audio[human_idx][:1].repeat(pad_len, *([1] * (encoded_audio[human_idx].dim() - 1))) + encoded_audio_in = torch.cat([encoded_audio[human_idx], pad_tensor], dim=0) + else: + encoded_audio_in = encoded_audio[human_idx] + center_indices = torch.arange(audio_start, audio_end, 1).unsqueeze(1) + indices.unsqueeze(0) + center_indices = torch.clamp(center_indices, min=0, max=encoded_audio_in.shape[0] - 1) + audio_emb = encoded_audio_in[center_indices].unsqueeze(0) + audio_embs.append(audio_emb) + + return torch.cat(audio_embs, dim=0) + + +def project_audio_features(audio_proj, encoded_audio, audio_start, audio_end): + audio_embs = get_audio_embeds(encoded_audio, audio_start, audio_end) + + first_frame_audio_emb_s = audio_embs[:, :1, ...] + latter_frame_audio_emb = audio_embs[:, 1:, ...] + latter_frame_audio_emb = rearrange(latter_frame_audio_emb, "b (n_t n) w s c -> b n_t n w s c", n=4) + + middle_index = audio_proj.seq_len // 2 + + latter_first_frame_audio_emb = latter_frame_audio_emb[:, :, :1, :middle_index+1, ...] + latter_first_frame_audio_emb = rearrange(latter_first_frame_audio_emb, "b n_t n w s c -> b n_t (n w) s c") + latter_last_frame_audio_emb = latter_frame_audio_emb[:, :, -1:, middle_index:, ...] + latter_last_frame_audio_emb = rearrange(latter_last_frame_audio_emb, "b n_t n w s c -> b n_t (n w) s c") + latter_middle_frame_audio_emb = latter_frame_audio_emb[:, :, 1:-1, middle_index:middle_index+1, ...] + latter_middle_frame_audio_emb = rearrange(latter_middle_frame_audio_emb, "b n_t n w s c -> b n_t (n w) s c") + latter_frame_audio_emb_s = torch.cat([latter_first_frame_audio_emb, latter_middle_frame_audio_emb, latter_last_frame_audio_emb], dim=2) + + audio_emb = audio_proj(first_frame_audio_emb_s, latter_frame_audio_emb_s) + audio_emb = torch.cat(audio_emb.split(1), dim=2) + + return audio_emb + + +class RotaryPositionalEmbedding1D(torch.nn.Module): + def __init__(self, + head_dim, + ): + super().__init__() + self.head_dim = head_dim + self.base = 10000 + + def precompute_freqs_cis_1d(self, pos_indices): + freqs = 1.0 / (self.base ** (torch.arange(0, self.head_dim, 2)[: (self.head_dim // 2)].float() / self.head_dim)) + freqs = freqs.to(pos_indices.device) + freqs = torch.einsum("..., f -> ... f", pos_indices.float(), freqs) + freqs = repeat(freqs, "... n -> ... 
(n r)", r=2) + return freqs + + def forward(self, x, pos_indices): + freqs_cis = self.precompute_freqs_cis_1d(pos_indices) + + x_ = x.float() + + freqs_cis = freqs_cis.float().to(x.device) + cos, sin = freqs_cis.cos(), freqs_cis.sin() + cos, sin = rearrange(cos, 'n d -> 1 1 n d'), rearrange(sin, 'n d -> 1 1 n d') + x_ = (x_ * cos) + (rotate_half(x_) * sin) + + return x_.type_as(x) + +class SingleStreamAttention(torch.nn.Module): + def __init__( + self, + dim: int, + encoder_hidden_states_dim: int, + num_heads: int, + qkv_bias: bool, + device=None, dtype=None, operations=None + ) -> None: + super().__init__() + self.dim = dim + self.encoder_hidden_states_dim = encoder_hidden_states_dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + + self.q_linear = operations.Linear(dim, dim, bias=qkv_bias, device=device, dtype=dtype) + self.proj = operations.Linear(dim, dim, device=device, dtype=dtype) + self.kv_linear = operations.Linear(encoder_hidden_states_dim, dim * 2, bias=qkv_bias, device=device, dtype=dtype) + + def forward(self, x: torch.Tensor, encoder_hidden_states: torch.Tensor, shape=None) -> torch.Tensor: + N_t, N_h, N_w = shape + + expected_tokens = N_t * N_h * N_w + actual_tokens = x.shape[1] + x_extra = None + + if actual_tokens != expected_tokens: + x_extra = x[:, -N_h * N_w:, :] + x = x[:, :-N_h * N_w, :] + N_t = N_t - 1 + + B = x.shape[0] + S = N_h * N_w + x = x.view(B * N_t, S, self.dim) + + # get q for hidden_state + q = self.q_linear(x).view(B * N_t, S, self.num_heads, self.head_dim) + + # get kv from encoder_hidden_states # shape: (B, N, num_heads, head_dim) + kv = self.kv_linear(encoder_hidden_states) + encoder_k, encoder_v = kv.view(B * N_t, encoder_hidden_states.shape[1], 2, self.num_heads, self.head_dim).unbind(2) + + #print("q.shape", q.shape) #torch.Size([21, 1024, 40, 128]) + x = optimized_attention( + q.transpose(1, 2), + encoder_k.transpose(1, 2), + encoder_v.transpose(1, 2), + heads=self.num_heads, skip_reshape=True, skip_output_reshape=True).transpose(1, 2) + + # linear transform + x = self.proj(x.reshape(B * N_t, S, self.dim)) + x = x.view(B, N_t * S, self.dim) + + if x_extra is not None: + x = torch.cat([x, torch.zeros_like(x_extra)], dim=1) + + return x + +class SingleStreamMultiAttention(SingleStreamAttention): + def __init__( + self, + dim: int, + encoder_hidden_states_dim: int, + num_heads: int, + qkv_bias: bool, + class_range: int = 24, + class_interval: int = 4, + device=None, dtype=None, operations=None + ) -> None: + super().__init__( + dim=dim, + encoder_hidden_states_dim=encoder_hidden_states_dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + device=device, + dtype=dtype, + operations=operations + ) + + # Rotary-embedding layout parameters + self.class_interval = class_interval + self.class_range = class_range + self.max_humans = self.class_range // self.class_interval + + # Constant bucket used for background tokens + self.rope_bak = int(self.class_range // 2) + + self.rope_1d = RotaryPositionalEmbedding1D(self.head_dim) + + def forward( + self, + x: torch.Tensor, + encoder_hidden_states: torch.Tensor, + shape=None, + x_ref_attn_map=None + ) -> torch.Tensor: + encoder_hidden_states = encoder_hidden_states.squeeze(0).to(x.device) + human_num = x_ref_attn_map.shape[0] if x_ref_attn_map is not None else 1 + # Single-speaker fall-through + if human_num <= 1: + return super().forward(x, encoder_hidden_states, shape) + + N_t, N_h, N_w = shape + + x_extra = None + if x.shape[0] * N_t != encoder_hidden_states.shape[0]: + x_extra = x[:, -N_h * 
N_w:, :] + x = x[:, :-N_h * N_w, :] + N_t = N_t - 1 + x = rearrange(x, "B (N_t S) C -> (B N_t) S C", N_t=N_t) + + # Query projection + B, N, C = x.shape + q = self.q_linear(x) + q = q.view(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3) + + # Use `class_range` logic for 2 speakers + rope_h1 = (0, self.class_interval) + rope_h2 = (self.class_range - self.class_interval, self.class_range) + rope_bak = int(self.class_range // 2) + + # Normalize and scale attention maps for each speaker + max_values = x_ref_attn_map.max(1).values[:, None, None] + min_values = x_ref_attn_map.min(1).values[:, None, None] + max_min_values = torch.cat([max_values, min_values], dim=2) + + human1_max_value, human1_min_value = max_min_values[0, :, 0].max(), max_min_values[0, :, 1].min() + human2_max_value, human2_min_value = max_min_values[1, :, 0].max(), max_min_values[1, :, 1].min() + + human1 = normalize_and_scale(x_ref_attn_map[0], (human1_min_value, human1_max_value), rope_h1) + human2 = normalize_and_scale(x_ref_attn_map[1], (human2_min_value, human2_max_value), rope_h2) + back = torch.full((x_ref_attn_map.size(1),), rope_bak, dtype=human1.dtype, device=human1.device) + + # Token-wise speaker dominance + max_indices = x_ref_attn_map.argmax(dim=0) + normalized_map = torch.stack([human1, human2, back], dim=1) + normalized_pos = normalized_map[torch.arange(x_ref_attn_map.size(1)), max_indices] + + # Apply rotary to Q + q = rearrange(q, "(B N_t) H S C -> B H (N_t S) C", N_t=N_t) + q = self.rope_1d(q, normalized_pos) + q = rearrange(q, "B H (N_t S) C -> (B N_t) H S C", N_t=N_t) + + # Keys / Values + _, N_a, _ = encoder_hidden_states.shape + encoder_kv = self.kv_linear(encoder_hidden_states) + encoder_kv = encoder_kv.view(B, N_a, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) + encoder_k, encoder_v = encoder_kv.unbind(0) + + # Rotary for keys – assign centre of each speaker bucket to its context tokens + per_frame = torch.zeros(N_a, dtype=encoder_k.dtype, device=encoder_k.device) + per_frame[: per_frame.size(0) // 2] = (rope_h1[0] + rope_h1[1]) / 2 + per_frame[per_frame.size(0) // 2 :] = (rope_h2[0] + rope_h2[1]) / 2 + encoder_pos = torch.cat([per_frame] * N_t, dim=0) + + encoder_k = rearrange(encoder_k, "(B N_t) H S C -> B H (N_t S) C", N_t=N_t) + encoder_k = self.rope_1d(encoder_k, encoder_pos) + encoder_k = rearrange(encoder_k, "B H (N_t S) C -> (B N_t) H S C", N_t=N_t) + + # Final attention + q = rearrange(q, "B H M K -> B M H K") + encoder_k = rearrange(encoder_k, "B H M K -> B M H K") + encoder_v = rearrange(encoder_v, "B H M K -> B M H K") + + x = optimized_attention( + q.transpose(1, 2), + encoder_k.transpose(1, 2), + encoder_v.transpose(1, 2), + heads=self.num_heads, skip_reshape=True, skip_output_reshape=True).transpose(1, 2) + + # Linear projection + x = x.reshape(B, N, C) + x = self.proj(x) + + # Restore original layout + x = rearrange(x, "(B N_t) S C -> B (N_t S) C", N_t=N_t) + if x_extra is not None: + x = torch.cat([x, torch.zeros_like(x_extra)], dim=1) + + return x + + +class MultiTalkAudioProjModel(torch.nn.Module): + def __init__( + self, + seq_len: int = 5, + seq_len_vf: int = 12, + blocks: int = 12, + channels: int = 768, + intermediate_dim: int = 512, + out_dim: int = 768, + context_tokens: int = 32, + device=None, dtype=None, operations=None + ): + super().__init__() + + self.seq_len = seq_len + self.blocks = blocks + self.channels = channels + self.input_dim = seq_len * blocks * channels + self.input_dim_vf = seq_len_vf * blocks * channels + self.intermediate_dim = 
intermediate_dim + self.context_tokens = context_tokens + self.out_dim = out_dim + + # define multiple linear layers + self.proj1 = operations.Linear(self.input_dim, intermediate_dim, device=device, dtype=dtype) + self.proj1_vf = operations.Linear(self.input_dim_vf, intermediate_dim, device=device, dtype=dtype) + self.proj2 = operations.Linear(intermediate_dim, intermediate_dim, device=device, dtype=dtype) + self.proj3 = operations.Linear(intermediate_dim, context_tokens * out_dim, device=device, dtype=dtype) + self.norm = operations.LayerNorm(out_dim, device=device, dtype=dtype) + + def forward(self, audio_embeds, audio_embeds_vf): + video_length = audio_embeds.shape[1] + audio_embeds_vf.shape[1] + B, _, _, S, C = audio_embeds.shape + + # process audio of first frame + audio_embeds = rearrange(audio_embeds, "bz f w b c -> (bz f) w b c") + batch_size, window_size, blocks, channels = audio_embeds.shape + audio_embeds = audio_embeds.view(batch_size, window_size * blocks * channels) + + # process audio of latter frame + audio_embeds_vf = rearrange(audio_embeds_vf, "bz f w b c -> (bz f) w b c") + batch_size_vf, window_size_vf, blocks_vf, channels_vf = audio_embeds_vf.shape + audio_embeds_vf = audio_embeds_vf.view(batch_size_vf, window_size_vf * blocks_vf * channels_vf) + + # first projection + audio_embeds = torch.relu(self.proj1(audio_embeds)) + audio_embeds_vf = torch.relu(self.proj1_vf(audio_embeds_vf)) + audio_embeds = rearrange(audio_embeds, "(bz f) c -> bz f c", bz=B) + audio_embeds_vf = rearrange(audio_embeds_vf, "(bz f) c -> bz f c", bz=B) + audio_embeds_c = torch.concat([audio_embeds, audio_embeds_vf], dim=1) + batch_size_c, N_t, C_a = audio_embeds_c.shape + audio_embeds_c = audio_embeds_c.view(batch_size_c*N_t, C_a) + + # second projection + audio_embeds_c = torch.relu(self.proj2(audio_embeds_c)) + + context_tokens = self.proj3(audio_embeds_c).reshape(batch_size_c*N_t, self.context_tokens, self.out_dim) + + # normalization and reshape + context_tokens = self.norm(context_tokens) + context_tokens = rearrange(context_tokens, "(bz f) m c -> bz f m c", f=video_length) + + return context_tokens + + +class WanMultiTalkAttentionBlock(torch.nn.Module): + def __init__(self, in_dim=5120, out_dim=768, device=None, dtype=None, operations=None): + super().__init__() + self.audio_cross_attn = SingleStreamMultiAttention(in_dim, out_dim, num_heads=40, qkv_bias=True, device=device, dtype=dtype, operations=operations) + self.norm_x = operations.LayerNorm(in_dim, device=device, dtype=dtype, elementwise_affine=True) + + +class MultiTalkGetAttnMapPatch: + def __init__(self, ref_target_masks=None): + self.ref_target_masks = ref_target_masks + + def __call__(self, kwargs): + transformer_options = kwargs.get("transformer_options", {}) + x = kwargs["x"] + + if self.ref_target_masks is not None: + x_ref_attn_map = get_attn_map_with_target(kwargs["q"], kwargs["k"], transformer_options["grid_sizes"], ref_target_masks=self.ref_target_masks.to(x.device)) + transformer_options["x_ref_attn_map"] = x_ref_attn_map + return x + + +class MultiTalkCrossAttnPatch: + def __init__(self, model_patch, audio_scale=1.0, ref_target_masks=None): + self.model_patch = model_patch + self.audio_scale = audio_scale + self.ref_target_masks = ref_target_masks + + def __call__(self, kwargs): + transformer_options = kwargs.get("transformer_options", {}) + block_idx = transformer_options.get("block_index", None) + x = kwargs["x"] + if block_idx is None: + return torch.zeros_like(x) + + audio_embeds = 
transformer_options.get("audio_embeds") + x_ref_attn_map = transformer_options.pop("x_ref_attn_map", None) + + norm_x = self.model_patch.model.blocks[block_idx].norm_x(x) + x_audio = self.model_patch.model.blocks[block_idx].audio_cross_attn( + norm_x, audio_embeds.to(x.dtype), + shape=transformer_options["grid_sizes"], + x_ref_attn_map=x_ref_attn_map + ) + x = x + x_audio * self.audio_scale + return x + + def models(self): + return [self.model_patch] + +class MultiTalkApplyModelWrapper: + def __init__(self, init_latents): + self.init_latents = init_latents + + def __call__(self, executor, x, *args, **kwargs): + x[:, :, :self.init_latents.shape[2]] = self.init_latents.to(x) + samples = executor(x, *args, **kwargs) + return samples + + +class InfiniteTalkOuterSampleWrapper: + def __init__(self, motion_frames_latent, model_patch, is_extend=False): + self.motion_frames_latent = motion_frames_latent + self.model_patch = model_patch + self.is_extend = is_extend + + def __call__(self, executor, *args, **kwargs): + model_patcher = executor.class_obj.model_patcher + model_options = executor.class_obj.model_options + process_latent_in = model_patcher.model.process_latent_in + + # for InfiniteTalk, model input first latent(s) need to always be replaced on every step + if self.motion_frames_latent is not None: + wrappers = model_options["transformer_options"]["wrappers"] + w = wrappers.setdefault(comfy.patcher_extension.WrappersMP.APPLY_MODEL, {}) + w["MultiTalk_apply_model"] = [MultiTalkApplyModelWrapper(process_latent_in(self.motion_frames_latent))] + + # run the sampling process + result = executor(*args, **kwargs) + + # insert motion frames before decoding + if self.is_extend: + overlap = self.motion_frames_latent.shape[2] + result = torch.cat([self.motion_frames_latent.to(result), result[:, :, overlap:]], dim=2) + + return result + + def to(self, device_or_dtype): + if isinstance(device_or_dtype, torch.device): + if self.motion_frames_latent is not None: + self.motion_frames_latent = self.motion_frames_latent.to(device_or_dtype) + return self diff --git a/comfy_api/latest/_io.py b/comfy_api/latest/_io.py index a60020ca8..2ec8d6e4b 100644 --- a/comfy_api/latest/_io.py +++ b/comfy_api/latest/_io.py @@ -754,7 +754,7 @@ class AnyType(ComfyTypeIO): Type = Any @comfytype(io_type="MODEL_PATCH") -class MODEL_PATCH(ComfyTypeIO): +class ModelPatch(ComfyTypeIO): Type = Any @comfytype(io_type="AUDIO_ENCODER") @@ -2038,6 +2038,7 @@ __all__ = [ "ControlNet", "Vae", "Model", + "ModelPatch", "ClipVision", "ClipVisionOutput", "AudioEncoder", diff --git a/comfy_extras/nodes_model_patch.py b/comfy_extras/nodes_model_patch.py index f66d28fc9..82c4754a3 100644 --- a/comfy_extras/nodes_model_patch.py +++ b/comfy_extras/nodes_model_patch.py @@ -7,6 +7,7 @@ import comfy.model_management import comfy.ldm.common_dit import comfy.latent_formats import comfy.ldm.lumina.controlnet +from comfy.ldm.wan.model_multitalk import WanMultiTalkAttentionBlock, MultiTalkAudioProjModel class BlockWiseControlBlock(torch.nn.Module): @@ -257,6 +258,14 @@ class ModelPatchLoader: if torch.count_nonzero(ref_weight) == 0: config['broken'] = True model = comfy.ldm.lumina.controlnet.ZImage_Control(device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast, **config) + elif "audio_proj.proj1.weight" in sd: + model = MultiTalkModelPatch( + audio_window=5, context_tokens=32, vae_scale=4, + in_dim=sd["blocks.0.audio_cross_attn.proj.weight"].shape[0], + intermediate_dim=sd["audio_proj.proj1.weight"].shape[0], + 
out_dim=sd["audio_proj.norm.weight"].shape[0], + device=comfy.model_management.unet_offload_device(), + operations=comfy.ops.manual_cast) model.load_state_dict(sd) model = comfy.model_patcher.ModelPatcher(model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device()) @@ -524,6 +533,38 @@ class USOStyleReference: return (model_patched,) +class MultiTalkModelPatch(torch.nn.Module): + def __init__( + self, + audio_window: int = 5, + intermediate_dim: int = 512, + in_dim: int = 5120, + out_dim: int = 768, + context_tokens: int = 32, + vae_scale: int = 4, + num_layers: int = 40, + + device=None, dtype=None, operations=None + ): + super().__init__() + self.audio_proj = MultiTalkAudioProjModel( + seq_len=audio_window, + seq_len_vf=audio_window+vae_scale-1, + intermediate_dim=intermediate_dim, + out_dim=out_dim, + context_tokens=context_tokens, + device=device, + dtype=dtype, + operations=operations + ) + self.blocks = torch.nn.ModuleList( + [ + WanMultiTalkAttentionBlock(in_dim, out_dim, device=device, dtype=dtype, operations=operations) + for _ in range(num_layers) + ] + ) + + NODE_CLASS_MAPPINGS = { "ModelPatchLoader": ModelPatchLoader, "QwenImageDiffsynthControlnet": QwenImageDiffsynthControlnet, diff --git a/comfy_extras/nodes_wan.py b/comfy_extras/nodes_wan.py index d32aad98e..90deb0077 100644 --- a/comfy_extras/nodes_wan.py +++ b/comfy_extras/nodes_wan.py @@ -8,9 +8,10 @@ import comfy.latent_formats import comfy.clip_vision import json import numpy as np -from typing import Tuple +from typing import Tuple, TypedDict from typing_extensions import override from comfy_api.latest import ComfyExtension, io +import logging class WanImageToVideo(io.ComfyNode): @classmethod @@ -1288,6 +1289,171 @@ class Wan22ImageToVideoLatent(io.ComfyNode): return io.NodeOutput(out_latent) +from comfy.ldm.wan.model_multitalk import InfiniteTalkOuterSampleWrapper, MultiTalkCrossAttnPatch, MultiTalkGetAttnMapPatch, project_audio_features +class WanInfiniteTalkToVideo(io.ComfyNode): + class DCValues(TypedDict): + mode: str + audio_encoder_output_2: io.AudioEncoderOutput.Type + mask: io.Mask.Type + + @classmethod + def define_schema(cls): + return io.Schema( + node_id="WanInfiniteTalkToVideo", + category="conditioning/video_models", + inputs=[ + io.DynamicCombo.Input("mode", options=[ + io.DynamicCombo.Option("single_speaker", []), + io.DynamicCombo.Option("two_speakers", [ + io.AudioEncoderOutput.Input("audio_encoder_output_2", optional=True), + io.Mask.Input("mask_1", optional=True, tooltip="Mask for the first speaker, required if using two audio inputs."), + io.Mask.Input("mask_2", optional=True, tooltip="Mask for the second speaker, required if using two audio inputs."), + ]), + ]), + io.Model.Input("model"), + io.ModelPatch.Input("model_patch"), + io.Conditioning.Input("positive"), + io.Conditioning.Input("negative"), + io.Vae.Input("vae"), + io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4), + io.ClipVisionOutput.Input("clip_vision_output", optional=True), + io.Image.Input("start_image", optional=True), + io.AudioEncoderOutput.Input("audio_encoder_output_1"), + io.Int.Input("motion_frame_count", default=9, min=1, max=33, step=1, tooltip="Number of previous frames to use as motion context."), + io.Float.Input("audio_scale", default=1.0, min=-10.0, max=10.0, step=0.01), 
+ io.Image.Input("previous_frames", optional=True), + ], + outputs=[ + io.Model.Output(display_name="model"), + io.Conditioning.Output(display_name="positive"), + io.Conditioning.Output(display_name="negative"), + io.Latent.Output(display_name="latent"), + io.Int.Output(display_name="trim_image"), + ], + ) + + @classmethod + def execute(cls, mode: DCValues, model, model_patch, positive, negative, vae, width, height, length, audio_encoder_output_1, motion_frame_count, + start_image=None, previous_frames=None, audio_scale=None, clip_vision_output=None, audio_encoder_output_2=None, mask_1=None, mask_2=None) -> io.NodeOutput: + + if previous_frames is not None and previous_frames.shape[0] < motion_frame_count: + raise ValueError("Not enough previous frames provided.") + + if mode["mode"] == "two_speakers": + audio_encoder_output_2 = mode["audio_encoder_output_2"] + mask_1 = mode["mask_1"] + mask_2 = mode["mask_2"] + + if audio_encoder_output_2 is not None: + if mask_1 is None or mask_2 is None: + raise ValueError("Masks must be provided if two audio encoder outputs are used.") + + ref_masks = None + if mask_1 is not None and mask_2 is not None: + if audio_encoder_output_2 is None: + raise ValueError("Second audio encoder output must be provided if two masks are used.") + ref_masks = torch.cat([mask_1, mask_2]) + + latent = torch.zeros([1, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) + if start_image is not None: + start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) + image = torch.ones((length, height, width, start_image.shape[-1]), device=start_image.device, dtype=start_image.dtype) * 0.5 + image[:start_image.shape[0]] = start_image + + concat_latent_image = vae.encode(image[:, :, :, :3]) + concat_mask = torch.ones((1, 1, latent.shape[2], concat_latent_image.shape[-2], concat_latent_image.shape[-1]), device=start_image.device, dtype=start_image.dtype) + concat_mask[:, :, :((start_image.shape[0] - 1) // 4) + 1] = 0.0 + + positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent_image, "concat_mask": concat_mask}) + negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent_image, "concat_mask": concat_mask}) + + if clip_vision_output is not None: + positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output}) + negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output}) + + model_patched = model.clone() + + encoded_audio_list = [] + seq_lengths = [] + + for audio_encoder_output in [audio_encoder_output_1, audio_encoder_output_2]: + if audio_encoder_output is None: + continue + all_layers = audio_encoder_output["encoded_audio_all_layers"] + encoded_audio = torch.stack(all_layers, dim=0).squeeze(1)[1:] # shape: [num_layers, T, 512] + encoded_audio = linear_interpolation(encoded_audio, input_fps=50, output_fps=25).movedim(0, 1) # shape: [T, num_layers, 512] + encoded_audio_list.append(encoded_audio) + seq_lengths.append(encoded_audio.shape[0]) + + # Pad / combine depending on multi_audio_type + multi_audio_type = "add" + if len(encoded_audio_list) > 1: + if multi_audio_type == "para": + max_len = max(seq_lengths) + padded = [] + for emb in encoded_audio_list: + if emb.shape[0] < max_len: + pad = torch.zeros(max_len - emb.shape[0], *emb.shape[1:], dtype=emb.dtype) + emb = torch.cat([emb, pad], dim=0) + 
padded.append(emb) + encoded_audio_list = padded + elif multi_audio_type == "add": + total_len = sum(seq_lengths) + full_list = [] + offset = 0 + for emb, seq_len in zip(encoded_audio_list, seq_lengths): + full = torch.zeros(total_len, *emb.shape[1:], dtype=emb.dtype) + full[offset:offset+seq_len] = emb + full_list.append(full) + offset += seq_len + encoded_audio_list = full_list + + token_ref_target_masks = None + if ref_masks is not None: + token_ref_target_masks = torch.nn.functional.interpolate( + ref_masks.unsqueeze(0), size=(latent.shape[-2] // 2, latent.shape[-1] // 2), mode='nearest')[0] + token_ref_target_masks = (token_ref_target_masks > 0).view(token_ref_target_masks.shape[0], -1) + + # when extending from previous frames + if previous_frames is not None: + motion_frames = comfy.utils.common_upscale(previous_frames[-motion_frame_count:].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) + frame_offset = previous_frames.shape[0] - motion_frame_count + + audio_start = frame_offset + audio_end = audio_start + length + logging.info(f"InfiniteTalk: Processing audio frames {audio_start} - {audio_end}") + + motion_frames_latent = vae.encode(motion_frames[:, :, :, :3]) + trim_image = motion_frame_count + else: + audio_start = trim_image = 0 + audio_end = length + motion_frames_latent = concat_latent_image[:, :, :1] + + audio_embed = project_audio_features(model_patch.model.audio_proj, encoded_audio_list, audio_start, audio_end).to(model_patched.model_dtype()) + model_patched.model_options["transformer_options"]["audio_embeds"] = audio_embed + + # add outer sample wrapper + model_patched.add_wrapper_with_key( + comfy.patcher_extension.WrappersMP.OUTER_SAMPLE, + "infinite_talk_outer_sample", + InfiniteTalkOuterSampleWrapper( + motion_frames_latent, + model_patch, + is_extend=previous_frames is not None, + )) + # add cross-attention patch + model_patched.set_model_patch(MultiTalkCrossAttnPatch(model_patch, audio_scale), "attn2_patch") + if token_ref_target_masks is not None: + model_patched.set_model_patch(MultiTalkGetAttnMapPatch(token_ref_target_masks), "attn1_patch") + + out_latent = {} + out_latent["samples"] = latent + return io.NodeOutput(model_patched, positive, negative, out_latent, trim_image) + + class WanExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[io.ComfyNode]]: @@ -1307,6 +1473,7 @@ class WanExtension(ComfyExtension): WanHuMoImageToVideo, WanAnimateToVideo, Wan22ImageToVideoLatent, + WanInfiniteTalkToVideo, ] async def comfy_entrypoint() -> WanExtension: From 72f6be1690868af852a624084a949b785fc056ea Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Thu, 22 Jan 2026 09:42:04 +0200 Subject: [PATCH 180/181] chore(api-nodes): rename BriaImage and OpenAIGImage nodes (#12022) --- comfy_api_nodes/nodes_bria.py | 2 +- comfy_api_nodes/nodes_openai.py | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/comfy_api_nodes/nodes_bria.py b/comfy_api_nodes/nodes_bria.py index 72a3055a7..d3a52bc1b 100644 --- a/comfy_api_nodes/nodes_bria.py +++ b/comfy_api_nodes/nodes_bria.py @@ -24,7 +24,7 @@ class BriaImageEditNode(IO.ComfyNode): def define_schema(cls): return IO.Schema( node_id="BriaImageEditNode", - display_name="Bria Image Edit", + display_name="Bria FIBO Image Edit", category="api node/image/Bria", description="Edit images using Bria latest model", inputs=[ diff --git a/comfy_api_nodes/nodes_openai.py b/comfy_api_nodes/nodes_openai.py index a12acc06b..f05aaab7b 100644 
--- a/comfy_api_nodes/nodes_openai.py +++ b/comfy_api_nodes/nodes_openai.py @@ -364,9 +364,9 @@ class OpenAIGPTImage1(IO.ComfyNode): def define_schema(cls): return IO.Schema( node_id="OpenAIGPTImage1", - display_name="OpenAI GPT Image 1", + display_name="OpenAI GPT Image 1.5", category="api node/image/OpenAI", - description="Generates images synchronously via OpenAI's GPT Image 1 endpoint.", + description="Generates images synchronously via OpenAI's GPT Image endpoint.", inputs=[ IO.String.Input( "prompt", @@ -429,6 +429,7 @@ class OpenAIGPTImage1(IO.ComfyNode): IO.Combo.Input( "model", options=["gpt-image-1", "gpt-image-1.5"], + default="gpt-image-1.5", optional=True, ), ], From 8490eedadfc0ab00cb131bab681059163c2ebbcd Mon Sep 17 00:00:00 2001 From: Terry Jia Date: Thu, 22 Jan 2026 12:46:56 -0500 Subject: [PATCH 181/181] add ply & 3dgs format in 3d node (#11474) --- comfy_extras/nodes_load_3d.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy_extras/nodes_load_3d.py b/comfy_extras/nodes_load_3d.py index 545588ef8..a16b8c8f3 100644 --- a/comfy_extras/nodes_load_3d.py +++ b/comfy_extras/nodes_load_3d.py @@ -24,7 +24,7 @@ class Load3D(IO.ComfyNode): files = [ normalize_path(str(file_path.relative_to(base_path))) for file_path in input_path.rglob("*") - if file_path.suffix.lower() in {'.gltf', '.glb', '.obj', '.fbx', '.stl'} + if file_path.suffix.lower() in {'.gltf', '.glb', '.obj', '.fbx', '.stl', '.spz', '.splat', '.ply', '.ksplat'} ] return IO.Schema( node_id="Load3D",