From 96e0d0924e027248733bc6e0b8102dcdc8acde33 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 6 Jan 2026 11:43:24 -0800 Subject: [PATCH 001/119] Add helpful message to portable. (#11671) --- .../advanced/run_nvidia_gpu_disable_api_nodes.bat | 2 +- .ci/windows_nvidia_base_files/run_nvidia_gpu.bat | 2 +- .../run_nvidia_gpu_fast_fp16_accumulation.bat | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.ci/windows_nvidia_base_files/advanced/run_nvidia_gpu_disable_api_nodes.bat b/.ci/windows_nvidia_base_files/advanced/run_nvidia_gpu_disable_api_nodes.bat index ed00583b6..4501ef9a1 100644 --- a/.ci/windows_nvidia_base_files/advanced/run_nvidia_gpu_disable_api_nodes.bat +++ b/.ci/windows_nvidia_base_files/advanced/run_nvidia_gpu_disable_api_nodes.bat @@ -1,3 +1,3 @@ ..\python_embeded\python.exe -s ..\ComfyUI\main.py --windows-standalone-build --disable-api-nodes -echo If you see this and ComfyUI did not start try updating your Nvidia Drivers to the latest. +echo If you see this and ComfyUI did not start try updating your Nvidia Drivers to the latest. If you get a c10.dll error you need to install vc redist that you can find: https://aka.ms/vc14/vc_redist.x64.exe pause diff --git a/.ci/windows_nvidia_base_files/run_nvidia_gpu.bat b/.ci/windows_nvidia_base_files/run_nvidia_gpu.bat index 4898a424f..6487ac7ce 100755 --- a/.ci/windows_nvidia_base_files/run_nvidia_gpu.bat +++ b/.ci/windows_nvidia_base_files/run_nvidia_gpu.bat @@ -1,3 +1,3 @@ .\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build -echo If you see this and ComfyUI did not start try updating your Nvidia Drivers to the latest. +echo If you see this and ComfyUI did not start try updating your Nvidia Drivers to the latest. If you get a c10.dll error you need to install vc redist that you can find: https://aka.ms/vc14/vc_redist.x64.exe pause diff --git a/.ci/windows_nvidia_base_files/run_nvidia_gpu_fast_fp16_accumulation.bat b/.ci/windows_nvidia_base_files/run_nvidia_gpu_fast_fp16_accumulation.bat index 32611e4af..01c5bb33b 100644 --- a/.ci/windows_nvidia_base_files/run_nvidia_gpu_fast_fp16_accumulation.bat +++ b/.ci/windows_nvidia_base_files/run_nvidia_gpu_fast_fp16_accumulation.bat @@ -1,3 +1,3 @@ .\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build --fast fp16_accumulation -echo If you see this and ComfyUI did not start try updating your Nvidia Drivers to the latest. +echo If you see this and ComfyUI did not start try updating your Nvidia Drivers to the latest. If you get a c10.dll error you need to install vc redist that you can find: https://aka.ms/vc14/vc_redist.x64.exe pause From 6ffc159bdd56d1ad73e954081def6a7f163e7a7f Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 6 Jan 2026 12:53:43 -0800 Subject: [PATCH 002/119] Update comfy-kitchen version to 0.2.1 (#11672) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 9c9c0e29e..22cb50e2d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -21,7 +21,7 @@ psutil alembic SQLAlchemy av>=14.2.0 -comfy-kitchen>=0.2.0 +comfy-kitchen>=0.2.1 #non essential dependencies: kornia>=0.7.1 From c3c3e93c5bb3034175c17ef8beeb8fe8626c66ab Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 6 Jan 2026 13:57:50 -0800 Subject: [PATCH 003/119] Use rope functions from comfy kitchen. (#11674) --- comfy/ldm/flux/math.py | 23 +++++++++++++++-------- requirements.txt | 2 +- 2 files changed, 16 insertions(+), 9 deletions(-) diff --git a/comfy/ldm/flux/math.py b/comfy/ldm/flux/math.py index 6a22df8bc..f9597de5b 100644 --- a/comfy/ldm/flux/math.py +++ b/comfy/ldm/flux/math.py @@ -4,6 +4,7 @@ from torch import Tensor from comfy.ldm.modules.attention import optimized_attention import comfy.model_management +import logging def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None, transformer_options={}) -> Tensor: @@ -13,7 +14,6 @@ def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None, transforme x = optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask, transformer_options=transformer_options) return x - def rope(pos: Tensor, dim: int, theta: int) -> Tensor: assert dim % 2 == 0 if comfy.model_management.is_device_mps(pos.device) or comfy.model_management.is_intel_xpu() or comfy.model_management.is_directml_enabled(): @@ -28,13 +28,20 @@ def rope(pos: Tensor, dim: int, theta: int) -> Tensor: out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2) return out.to(dtype=torch.float32, device=pos.device) -def apply_rope1(x: Tensor, freqs_cis: Tensor): - x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2) - x_out = freqs_cis[..., 0] * x_[..., 0] - x_out.addcmul_(freqs_cis[..., 1], x_[..., 1]) +try: + import comfy.quant_ops + apply_rope = comfy.quant_ops.ck.apply_rope + apply_rope1 = comfy.quant_ops.ck.apply_rope1 +except: + logging.warning("No comfy kitchen, using old apply_rope functions.") + def apply_rope1(x: Tensor, freqs_cis: Tensor): + x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2) - return x_out.reshape(*x.shape).type_as(x) + x_out = freqs_cis[..., 0] * x_[..., 0] + x_out.addcmul_(freqs_cis[..., 1], x_[..., 1]) -def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor): - return apply_rope1(xq, freqs_cis), apply_rope1(xk, freqs_cis) + return x_out.reshape(*x.shape).type_as(x) + + def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor): + return apply_rope1(xq, freqs_cis), apply_rope1(xk, freqs_cis) diff --git a/requirements.txt b/requirements.txt index 22cb50e2d..7798cb179 100644 --- a/requirements.txt +++ b/requirements.txt @@ -21,7 +21,7 @@ psutil alembic SQLAlchemy av>=14.2.0 -comfy-kitchen>=0.2.1 +comfy-kitchen>=0.2.2 #non essential dependencies: kornia>=0.7.1 From c3566c0d765200068d26d0888f035504a50012f2 Mon Sep 17 00:00:00 2001 From: ComfyUI Wiki Date: Wed, 7 Jan 2026 06:28:29 +0800 Subject: [PATCH 004/119] chore: update workflow templates to v0.7.67 (#11667) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 7798cb179..caad0026a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.35.9 -comfyui-workflow-templates==0.7.66 +comfyui-workflow-templates==0.7.67 comfyui-embedded-docs==0.3.1 torch torchsde From 023cf13721cac256c323e2226319b766d07b1f36 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 6 Jan 2026 14:33:03 -0800 Subject: [PATCH 005/119] Fix lowvram issue with ltxv2 text encoder. (#11675) --- comfy/ldm/lightricks/embeddings_connector.py | 2 +- comfy/text_encoders/lt.py | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/comfy/ldm/lightricks/embeddings_connector.py b/comfy/ldm/lightricks/embeddings_connector.py index f7a43f3c3..06f5ada89 100644 --- a/comfy/ldm/lightricks/embeddings_connector.py +++ b/comfy/ldm/lightricks/embeddings_connector.py @@ -276,7 +276,7 @@ class Embeddings1DConnector(nn.Module): max(1024, hidden_states.shape[1]) / self.num_learnable_registers ) learnable_registers = torch.tile( - self.learnable_registers, (num_registers_duplications, 1) + self.learnable_registers.to(hidden_states), (num_registers_duplications, 1) ) hidden_states = torch.cat((hidden_states, learnable_registers[hidden_states.shape[1]:].unsqueeze(0).repeat(hidden_states.shape[0], 1, 1)), dim=1) diff --git a/comfy/text_encoders/lt.py b/comfy/text_encoders/lt.py index 2c2d453e8..e5964e42b 100644 --- a/comfy/text_encoders/lt.py +++ b/comfy/text_encoders/lt.py @@ -86,17 +86,19 @@ class LTXAVTEModel(torch.nn.Module): ) def set_clip_options(self, options): + self.execution_device = options.get("execution_device", self.execution_device) self.gemma3_12b.set_clip_options(options) def reset_clip_options(self): self.gemma3_12b.reset_clip_options() + self.execution_device = None def encode_token_weights(self, token_weight_pairs): token_weight_pairs = token_weight_pairs["gemma3_12b"] out, pooled, extra = self.gemma3_12b.encode_token_weights(token_weight_pairs) out_device = out.device - out = out.movedim(1, -1).to(self.text_embedding_projection.weight.device) + out = out.movedim(1, -1).to(self.execution_device) out = 8.0 * (out - out.mean(dim=(1, 2), keepdim=True)) / (out.amax(dim=(1, 2), keepdim=True) - out.amin(dim=(1, 2), keepdim=True) + 1e-6) out = out.reshape((out.shape[0], out.shape[1], -1)) out = self.text_embedding_projection(out) From 6e9ee55cdd9e0eca6b5144063575b983f3311762 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 6 Jan 2026 14:41:27 -0800 Subject: [PATCH 006/119] Disable ltxav previews. (#11676) --- comfy/latent_formats.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/comfy/latent_formats.py b/comfy/latent_formats.py index 9bbe30b53..cb4f52ce1 100644 --- a/comfy/latent_formats.py +++ b/comfy/latent_formats.py @@ -408,7 +408,9 @@ class LTXV(LatentFormat): self.latent_rgb_factors_bias = [-0.0571, -0.1657, -0.2512] class LTXAV(LTXV): - pass + def __init__(self): + self.latent_rgb_factors = None + self.latent_rgb_factors_bias = None class HunyuanVideo(LatentFormat): latent_channels = 16 From 2c03884f5fb7fa213161dfe1e9a09a8e8c4b6062 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 6 Jan 2026 15:07:26 -0800 Subject: [PATCH 007/119] Skip fp4 matrix mult on devices that don't support it. (#11677) --- comfy/model_management.py | 10 ++++++++++ comfy/ops.py | 21 +++++++++++++++++---- 2 files changed, 27 insertions(+), 4 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 22f4de044..928282092 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -1504,6 +1504,16 @@ def supports_fp8_compute(device=None): return True +def supports_nvfp4_compute(device=None): + if not is_nvidia(): + return False + + props = torch.cuda.get_device_properties(device) + if props.major < 10: + return False + + return True + def extended_fp16_support(): # TODO: check why some models work with fp16 on newer torch versions but not on older if torch_version_numeric < (2, 7): diff --git a/comfy/ops.py b/comfy/ops.py index f5e1e9230..8f9fdce36 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -493,11 +493,12 @@ from .quant_ops import ( ) -def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False): +def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False, disabled=[]): class MixedPrecisionOps(manual_cast): _quant_config = quant_config _compute_dtype = compute_dtype _full_precision_mm = full_precision_mm + _disabled = disabled class Linear(torch.nn.Module, CastWeightBiasOp): def __init__( @@ -522,6 +523,7 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec self.tensor_class = None self._full_precision_mm = MixedPrecisionOps._full_precision_mm + self._full_precision_mm_config = False def reset_parameters(self): return None @@ -556,8 +558,12 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec self.weight = torch.nn.Parameter(weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), requires_grad=False) else: self.quant_format = layer_conf.get("format", None) + self._full_precision_mm_config = layer_conf.get("full_precision_matrix_mult", False) if not self._full_precision_mm: - self._full_precision_mm = layer_conf.get("full_precision_matrix_mult", False) + self._full_precision_mm = self._full_precision_mm_config + + if self.quant_format in MixedPrecisionOps._disabled: + self._full_precision_mm = True if self.quant_format is None: raise ValueError(f"Unknown quantization format for layer {layer_name}") @@ -630,7 +636,7 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec sd["{}weight_scale".format(prefix)] = self.weight._params.block_scale quant_conf = {"format": self.quant_format} - if self._full_precision_mm: + if self._full_precision_mm_config: quant_conf["full_precision_matrix_mult"] = True sd["{}comfy_quant".format(prefix)] = torch.tensor(list(json.dumps(quant_conf).encode('utf-8')), dtype=torch.uint8) return sd @@ -711,10 +717,17 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, model_config=None): fp8_compute = comfy.model_management.supports_fp8_compute(load_device) # TODO: if we support more ops this needs to be more granular + nvfp4_compute = comfy.model_management.supports_nvfp4_compute(load_device) if model_config and hasattr(model_config, 'quant_config') and model_config.quant_config: logging.info("Using mixed precision operations") - return mixed_precision_ops(model_config.quant_config, compute_dtype, full_precision_mm=not fp8_compute) + disabled = set() + if not nvfp4_compute: + disabled.add("nvfp4") + if not fp8_compute: + disabled.add("float8_e4m3fn") + disabled.add("float8_e5m2") + return mixed_precision_ops(model_config.quant_config, compute_dtype, disabled=disabled) if ( fp8_compute and From edee33f55ea27a1931475d3ea788fd6e9a81677b Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 6 Jan 2026 19:13:43 -0800 Subject: [PATCH 008/119] Disable comfy kitchen cuda if pytorch cuda less than 13 (#11681) --- comfy/quant_ops.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py index cd737726f..5a17bc6f5 100644 --- a/comfy/quant_ops.py +++ b/comfy/quant_ops.py @@ -13,6 +13,13 @@ try: get_layout_class, ) _CK_AVAILABLE = True + if torch.version.cuda is None: + ck.registry.disable("cuda") + else: + cuda_version = tuple(map(int, str(torch.version.cuda).split('.'))) + if cuda_version < (13,): + ck.registry.disable("cuda") + ck.registry.disable("triton") for k, v in ck.list_backends().items(): logging.info(f"Found comfy_kitchen backend {k}: {v}") From c5cfb34c07048350f472a9a4f1ccbf75a56ed38f Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 6 Jan 2026 20:51:45 -0800 Subject: [PATCH 009/119] Update comfy-kitchen version to 0.2.3 (#11685) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index caad0026a..bc8346bcf 100644 --- a/requirements.txt +++ b/requirements.txt @@ -21,7 +21,7 @@ psutil alembic SQLAlchemy av>=14.2.0 -comfy-kitchen>=0.2.2 +comfy-kitchen>=0.2.3 #non essential dependencies: kornia>=0.7.1 From ce0000c4f2a7dba12324585dddb784b43e3cd3d0 Mon Sep 17 00:00:00 2001 From: Yoland Yan <4950057+yoland68@users.noreply.github.com> Date: Tue, 6 Jan 2026 21:57:31 -0800 Subject: [PATCH 010/119] Force sequential execution in CI test jobs (#11687) Added max-parallel setting to enforce sequential execution in test jobs. --- .github/workflows/test-ci.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/workflows/test-ci.yml b/.github/workflows/test-ci.yml index adfc5dd32..63df2dc3a 100644 --- a/.github/workflows/test-ci.yml +++ b/.github/workflows/test-ci.yml @@ -20,6 +20,7 @@ jobs: test-stable: strategy: fail-fast: false + max-parallel: 1 # This forces sequential execution matrix: # os: [macos, linux, windows] # os: [macos, linux] @@ -74,6 +75,7 @@ jobs: test-unix-nightly: strategy: fail-fast: false + max-parallel: 1 # This forces sequential execution matrix: # os: [macos, linux] os: [linux] From 79e94544bd7ec0cc7a4e5e6167907e7d781c4b76 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Wed, 7 Jan 2026 08:04:50 +0200 Subject: [PATCH 011/119] feat(api-nodes): add WAN2.6 ReferenceToVideo (#11644) --- comfy_api_nodes/nodes_wan.py | 160 +++++++++++++++++++++++++ comfy_api_nodes/util/upload_helpers.py | 2 +- 2 files changed, 161 insertions(+), 1 deletion(-) diff --git a/comfy_api_nodes/nodes_wan.py b/comfy_api_nodes/nodes_wan.py index 1675fd863..3e04786a9 100644 --- a/comfy_api_nodes/nodes_wan.py +++ b/comfy_api_nodes/nodes_wan.py @@ -13,7 +13,9 @@ from comfy_api_nodes.util import ( poll_op, sync_op, tensor_to_base64_string, + upload_video_to_comfyapi, validate_audio_duration, + validate_video_duration, ) @@ -41,6 +43,12 @@ class Image2VideoInputField(BaseModel): audio_url: str | None = Field(None) +class Reference2VideoInputField(BaseModel): + prompt: str = Field(...) + negative_prompt: str | None = Field(None) + reference_video_urls: list[str] = Field(...) + + class Txt2ImageParametersField(BaseModel): size: str = Field(...) n: int = Field(1, description="Number of images to generate.") # we support only value=1 @@ -76,6 +84,14 @@ class Image2VideoParametersField(BaseModel): shot_type: str = Field("single") +class Reference2VideoParametersField(BaseModel): + size: str = Field(...) + duration: int = Field(5, ge=5, le=15) + shot_type: str = Field("single") + seed: int = Field(..., ge=0, le=2147483647) + watermark: bool = Field(False) + + class Text2ImageTaskCreationRequest(BaseModel): model: str = Field(...) input: Text2ImageInputField = Field(...) @@ -100,6 +116,12 @@ class Image2VideoTaskCreationRequest(BaseModel): parameters: Image2VideoParametersField = Field(...) +class Reference2VideoTaskCreationRequest(BaseModel): + model: str = Field(...) + input: Reference2VideoInputField = Field(...) + parameters: Reference2VideoParametersField = Field(...) + + class TaskCreationOutputField(BaseModel): task_id: str = Field(...) task_status: str = Field(...) @@ -721,6 +743,143 @@ class WanImageToVideoApi(IO.ComfyNode): return IO.NodeOutput(await download_url_to_video_output(response.output.video_url)) +class WanReferenceVideoApi(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="WanReferenceVideoApi", + display_name="Wan Reference to Video", + category="api node/video/Wan", + description="Use the character and voice from input videos, combined with a prompt, " + "to generate a new video that maintains character consistency.", + inputs=[ + IO.Combo.Input("model", options=["wan2.6-r2v"]), + IO.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Prompt describing the elements and visual features. Supports English and Chinese. " + "Use identifiers such as `character1` and `character2` to refer to the reference characters.", + ), + IO.String.Input( + "negative_prompt", + multiline=True, + default="", + tooltip="Negative prompt describing what to avoid.", + ), + IO.Autogrow.Input( + "reference_videos", + template=IO.Autogrow.TemplateNames( + IO.Video.Input("reference_video"), + names=["character1", "character2", "character3"], + min=1, + ), + ), + IO.Combo.Input( + "size", + options=[ + "720p: 1:1 (960x960)", + "720p: 16:9 (1280x720)", + "720p: 9:16 (720x1280)", + "720p: 4:3 (1088x832)", + "720p: 3:4 (832x1088)", + "1080p: 1:1 (1440x1440)", + "1080p: 16:9 (1920x1080)", + "1080p: 9:16 (1080x1920)", + "1080p: 4:3 (1632x1248)", + "1080p: 3:4 (1248x1632)", + ], + ), + IO.Int.Input( + "duration", + default=5, + min=5, + max=10, + step=5, + display_mode=IO.NumberDisplay.slider, + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + step=1, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + ), + IO.Combo.Input( + "shot_type", + options=["single", "multi"], + tooltip="Specifies the shot type for the generated video, that is, whether the video is a " + "single continuous shot or multiple shots with cuts.", + ), + IO.Boolean.Input( + "watermark", + default=False, + tooltip="Whether to add an AI-generated watermark to the result.", + ), + ], + outputs=[ + IO.Video.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, + model: str, + prompt: str, + negative_prompt: str, + reference_videos: IO.Autogrow.Type, + size: str, + duration: int, + seed: int, + shot_type: str, + watermark: bool, + ): + reference_video_urls = [] + for i in reference_videos: + validate_video_duration(reference_videos[i], min_duration=2, max_duration=30) + for i in reference_videos: + reference_video_urls.append(await upload_video_to_comfyapi(cls, reference_videos[i])) + width, height = RES_IN_PARENS.search(size).groups() + initial_response = await sync_op( + cls, + ApiEndpoint(path="/proxy/wan/api/v1/services/aigc/video-generation/video-synthesis", method="POST"), + response_model=TaskCreationResponse, + data=Reference2VideoTaskCreationRequest( + model=model, + input=Reference2VideoInputField( + prompt=prompt, negative_prompt=negative_prompt, reference_video_urls=reference_video_urls + ), + parameters=Reference2VideoParametersField( + size=f"{width}*{height}", + duration=duration, + shot_type=shot_type, + watermark=watermark, + seed=seed, + ), + ), + ) + if not initial_response.output: + raise Exception(f"An unknown error occurred: {initial_response.code} - {initial_response.message}") + response = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/wan/api/v1/tasks/{initial_response.output.task_id}"), + response_model=VideoTaskStatusResponse, + status_extractor=lambda x: x.output.task_status, + poll_interval=6, + max_poll_attempts=280, + ) + return IO.NodeOutput(await download_url_to_video_output(response.output.video_url)) + + class WanApiExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[IO.ComfyNode]]: @@ -729,6 +888,7 @@ class WanApiExtension(ComfyExtension): WanImageToImageApi, WanTextToVideoApi, WanImageToVideoApi, + WanReferenceVideoApi, ] diff --git a/comfy_api_nodes/util/upload_helpers.py b/comfy_api_nodes/util/upload_helpers.py index b8d33f4d1..f1ed7fe9c 100644 --- a/comfy_api_nodes/util/upload_helpers.py +++ b/comfy_api_nodes/util/upload_helpers.py @@ -119,7 +119,7 @@ async def upload_video_to_comfyapi( raise ValueError(f"Could not verify video duration from source: {e}") from e upload_mime_type = f"video/{container.value.lower()}" - filename = f"uploaded_video.{container.value.lower()}" + filename = f"{uuid.uuid4()}.{container.value.lower()}" # Convert VideoInput to BytesIO using specified container/codec video_bytes_io = BytesIO() From b7d7cc1d496afe3c82279eec74c4d47399aab8ea Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 6 Jan 2026 22:39:06 -0800 Subject: [PATCH 012/119] Fix fp8 fast issue. (#11688) --- comfy/ops.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/comfy/ops.py b/comfy/ops.py index 8f9fdce36..cd536e22d 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -427,12 +427,12 @@ def fp8_linear(self, input): input = torch.clamp(input, min=-448, max=448, out=input) input_fp8 = input.to(dtype).contiguous() layout_params_input = TensorCoreFP8Layout.Params(scale=scale_input, orig_dtype=input_dtype, orig_shape=tuple(input_fp8.shape)) - quantized_input = QuantizedTensor(input_fp8, TensorCoreFP8Layout, layout_params_input) + quantized_input = QuantizedTensor(input_fp8, "TensorCoreFP8Layout", layout_params_input) # Wrap weight in QuantizedTensor - this enables unified dispatch # Call F.linear - __torch_dispatch__ routes to fp8_linear handler in quant_ops.py! layout_params_weight = TensorCoreFP8Layout.Params(scale=scale_weight, orig_dtype=input_dtype, orig_shape=tuple(w.shape)) - quantized_weight = QuantizedTensor(w, TensorCoreFP8Layout, layout_params_weight) + quantized_weight = QuantizedTensor(w, "TensorCoreFP8Layout", layout_params_weight) o = torch.nn.functional.linear(quantized_input, quantized_weight, bias) uncast_bias_weight(self, w, bias, offload_stream) From fc0cb10bcbee6e73ed3caf34c27f7bde4559a07f Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 7 Jan 2026 04:07:31 -0500 Subject: [PATCH 013/119] ComfyUI v0.8.0 --- comfyui_version.py | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/comfyui_version.py b/comfyui_version.py index 1ed60fe5c..750673f08 100644 --- a/comfyui_version.py +++ b/comfyui_version.py @@ -1,3 +1,3 @@ # This file is automatically generated by the build process when version is # updated in pyproject.toml. -__version__ = "0.7.0" +__version__ = "0.8.0" diff --git a/pyproject.toml b/pyproject.toml index a7d159be9..951c2c978 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ComfyUI" -version = "0.7.0" +version = "0.8.0" readme = "README.md" license = { file = "LICENSE" } requires-python = ">=3.10" From c0c9720d77774ed2c87981da87189fe1c14a57fa Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 7 Jan 2026 01:48:28 -0800 Subject: [PATCH 014/119] Fix stable release workflow not pulling latest comfy kitchen. (#11695) --- .github/workflows/stable-release.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/stable-release.yml b/.github/workflows/stable-release.yml index 28484a9d1..f501b7b31 100644 --- a/.github/workflows/stable-release.yml +++ b/.github/workflows/stable-release.yml @@ -117,7 +117,7 @@ jobs: ./python.exe get-pip.py ./python.exe -s -m pip install ../${{ inputs.cache_tag }}_python_deps/* - grep comfyui ../ComfyUI/requirements.txt > ./requirements_comfyui.txt + grep comfy ../ComfyUI/requirements.txt > ./requirements_comfyui.txt ./python.exe -s -m pip install -r requirements_comfyui.txt rm requirements_comfyui.txt From 3cd7b32f1b7e7e90395cefe7d9f9b1f89276d8ce Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 7 Jan 2026 02:15:14 -0800 Subject: [PATCH 015/119] Support gemma 12B with quant weights. (#11696) --- comfy/text_encoders/lt.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/comfy/text_encoders/lt.py b/comfy/text_encoders/lt.py index e5964e42b..130ebaeae 100644 --- a/comfy/text_encoders/lt.py +++ b/comfy/text_encoders/lt.py @@ -36,10 +36,10 @@ class LTXAVGemmaTokenizer(sd1_clip.SD1Tokenizer): class Gemma3_12BModel(sd1_clip.SDClipModel): def __init__(self, device="cpu", layer="all", layer_idx=None, dtype=None, attention_mask=True, model_options={}): - llama_scaled_fp8 = model_options.get("gemma_scaled_fp8", None) - if llama_scaled_fp8 is not None: + llama_quantization_metadata = model_options.get("llama_quantization_metadata", None) + if llama_quantization_metadata is not None: model_options = model_options.copy() - model_options["scaled_fp8"] = llama_scaled_fp8 + model_options["quantization_metadata"] = llama_quantization_metadata super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma3_12B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options) @@ -119,12 +119,12 @@ class LTXAVTEModel(torch.nn.Module): return self.load_state_dict(sdo, strict=False) -def ltxav_te(dtype_llama=None, llama_scaled_fp8=None): +def ltxav_te(dtype_llama=None, llama_quantization_metadata=None): class LTXAVTEModel_(LTXAVTEModel): def __init__(self, device="cpu", dtype=None, model_options={}): - if llama_scaled_fp8 is not None and "llama_scaled_fp8" not in model_options: + if llama_quantization_metadata is not None: model_options = model_options.copy() - model_options["llama_scaled_fp8"] = llama_scaled_fp8 + model_options["llama_quantization_metadata"] = llama_quantization_metadata if dtype_llama is not None: dtype = dtype_llama super().__init__(dtype_llama=dtype_llama, device=device, dtype=dtype, model_options=model_options) From 48e5ea1dfd23a9cb5d118d7af661b026d66743bc Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Wed, 7 Jan 2026 15:39:20 -0800 Subject: [PATCH 016/119] model_patcher: Remove confusing load stat (#11710) If the loader passes 1e32 as the usable memory size, it means force the full load. This happens with CPU loads and a few other misc cases. Removing the confusing number and just leave the other details. --- comfy/model_patcher.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 93d26c690..4528814ad 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -790,11 +790,12 @@ class ModelPatcher: for param in params: self.pin_weight_to_device("{}.{}".format(n, param)) + usable_stat = "{:.2f} MB usable,".format(lowvram_model_memory / (1024 * 1024)) if lowvram_model_memory < 1e32 else "" if lowvram_counter > 0: - logging.info("loaded partially; {:.2f} MB usable, {:.2f} MB loaded, {:.2f} MB offloaded, {:.2f} MB buffer reserved, lowvram patches: {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), lowvram_mem_counter / (1024 * 1024), offload_buffer / (1024 * 1024), patch_counter)) + logging.info("loaded partially; {} {:.2f} MB loaded, {:.2f} MB offloaded, {:.2f} MB buffer reserved, lowvram patches: {}".format(usable_stat, mem_counter / (1024 * 1024), lowvram_mem_counter / (1024 * 1024), offload_buffer / (1024 * 1024), patch_counter)) self.model.model_lowvram = True else: - logging.info("loaded completely; {:.2f} MB usable, {:.2f} MB loaded, full load: {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), full_load)) + logging.info("loaded completely; {} {:.2f} MB loaded, full load: {}".format(usable_stat, mem_counter / (1024 * 1024), full_load)) self.model.model_lowvram = False if full_load: self.model.to(device_to) From 1c705f7bfb0fb59f6213dfb85ec5d5dc2ce4300e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jukka=20Sepp=C3=A4nen?= <40791699+kijai@users.noreply.github.com> Date: Thu, 8 Jan 2026 01:39:59 +0200 Subject: [PATCH 017/119] Add device selection for LTXAVTextEncoderLoader (#11700) --- comfy_extras/nodes_lt_audio.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/comfy_extras/nodes_lt_audio.py b/comfy_extras/nodes_lt_audio.py index 26b0160d2..1966fd1bf 100644 --- a/comfy_extras/nodes_lt_audio.py +++ b/comfy_extras/nodes_lt_audio.py @@ -185,6 +185,10 @@ class LTXAVTextEncoderLoader(io.ComfyNode): io.Combo.Input( "ckpt_name", options=folder_paths.get_filename_list("checkpoints"), + ), + io.Combo.Input( + "device", + options=["default", "cpu"], ) ], outputs=[io.Clip.Output()], @@ -197,7 +201,11 @@ class LTXAVTextEncoderLoader(io.ComfyNode): clip_path1 = folder_paths.get_full_path_or_raise("text_encoders", text_encoder) clip_path2 = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name) - clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type) + model_options = {} + if device == "cpu": + model_options["load_device"] = model_options["offload_device"] = torch.device("cpu") + + clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type, model_options=model_options) return io.NodeOutput(clip) From 34751fe9f9ade0c715768202c19211dc0c72e760 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 7 Jan 2026 16:12:15 -0800 Subject: [PATCH 018/119] Lower ltxv text encoder vram use. (#11713) --- comfy/text_encoders/lt.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/comfy/text_encoders/lt.py b/comfy/text_encoders/lt.py index 130ebaeae..dc0694e0e 100644 --- a/comfy/text_encoders/lt.py +++ b/comfy/text_encoders/lt.py @@ -98,10 +98,13 @@ class LTXAVTEModel(torch.nn.Module): out, pooled, extra = self.gemma3_12b.encode_token_weights(token_weight_pairs) out_device = out.device + if comfy.model_management.should_use_bf16(self.execution_device): + out = out.to(device=self.execution_device, dtype=torch.bfloat16) out = out.movedim(1, -1).to(self.execution_device) out = 8.0 * (out - out.mean(dim=(1, 2), keepdim=True)) / (out.amax(dim=(1, 2), keepdim=True) - out.amin(dim=(1, 2), keepdim=True) + 1e-6) out = out.reshape((out.shape[0], out.shape[1], -1)) out = self.text_embedding_projection(out) + out = out.float() out_vid = self.video_embeddings_connector(out)[0] out_audio = self.audio_embeddings_connector(out)[0] out = torch.concat((out_vid, out_audio), dim=-1) From 007b87e7ac29e55ce0ad2c436f5ae68f3a078080 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 7 Jan 2026 16:48:47 -0800 Subject: [PATCH 019/119] Bump required comfy-kitchen version. (#11714) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index bc8346bcf..13e95afa0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -21,7 +21,7 @@ psutil alembic SQLAlchemy av>=14.2.0 -comfy-kitchen>=0.2.3 +comfy-kitchen>=0.2.5 #non essential dependencies: kornia>=0.7.1 From 3cd19e99c10a25cf6e6b51b82e3c16c501733b8c Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 7 Jan 2026 17:04:56 -0800 Subject: [PATCH 020/119] Increase ltxav mem estimation by a bit. (#11715) --- comfy/supported_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/supported_models.py b/comfy/supported_models.py index ee9a79001..d44c0bc37 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -845,7 +845,7 @@ class LTXAV(LTXV): def __init__(self, unet_config): super().__init__(unet_config) - self.memory_usage_factor = 0.055 # TODO + self.memory_usage_factor = 0.061 # TODO def get_model(self, state_dict, prefix="", device=None): out = model_base.LTXAV(self, device=device) From 25bc1b5b57d61930d6ab60d8cf7e9241d26e4fe9 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 7 Jan 2026 17:11:22 -0800 Subject: [PATCH 021/119] Add memory estimation function to ltxav text encoder. (#11716) --- comfy/sd.py | 11 +++++++---- comfy/text_encoders/lt.py | 8 ++++++++ 2 files changed, 15 insertions(+), 4 deletions(-) diff --git a/comfy/sd.py b/comfy/sd.py index 32157e18b..efde3839c 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -218,7 +218,7 @@ class CLIP: if unprojected: self.cond_stage_model.set_clip_options({"projected_pooled": False}) - self.load_model() + self.load_model(tokens) self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device}) all_hooks.reset() self.patcher.patch_hooks(None) @@ -266,7 +266,7 @@ class CLIP: if return_pooled == "unprojected": self.cond_stage_model.set_clip_options({"projected_pooled": False}) - self.load_model() + self.load_model(tokens) self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device}) o = self.cond_stage_model.encode_token_weights(tokens) cond, pooled = o[:2] @@ -299,8 +299,11 @@ class CLIP: sd_clip[k] = sd_tokenizer[k] return sd_clip - def load_model(self): - model_management.load_model_gpu(self.patcher) + def load_model(self, tokens={}): + memory_used = 0 + if hasattr(self.cond_stage_model, "memory_estimation_function"): + memory_used = self.cond_stage_model.memory_estimation_function(tokens, device=self.patcher.load_device) + model_management.load_models_gpu([self.patcher], memory_required=memory_used) return self.patcher def get_key_patches(self): diff --git a/comfy/text_encoders/lt.py b/comfy/text_encoders/lt.py index dc0694e0e..776e25e97 100644 --- a/comfy/text_encoders/lt.py +++ b/comfy/text_encoders/lt.py @@ -121,6 +121,14 @@ class LTXAVTEModel(torch.nn.Module): return self.load_state_dict(sdo, strict=False) + def memory_estimation_function(self, token_weight_pairs, device=None): + constant = 6.0 + if comfy.model_management.should_use_bf16(device): + constant /= 2.0 + + token_weight_pairs = token_weight_pairs.get("gemma3_12b", []) + num_tokens = sum(map(lambda a: len(a), token_weight_pairs)) + return num_tokens * constant * 1024 * 1024 def ltxav_te(dtype_llama=None, llama_quantization_metadata=None): class LTXAVTEModel_(LTXAVTEModel): From b6c79a648a013f477f514f61580d1a06220b15eb Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Wed, 7 Jan 2026 18:01:16 -0800 Subject: [PATCH 022/119] ops: Fix offloading with FP8MM performance (#11697) This logic was checking comfy_cast_weights, and going straight to to the forward_comfy_cast_weights implementation without attempting to downscale input to fp8 in the event comfy_cast_weights is set. The main reason comfy_cast_weights would be set would be for async offload, which is not a good reason to nix FP8MM. So instead, and together the underlying exclusions for FP8MM which are: * having a weight_function (usually LowVramPatch) * force_cast_weights (compute dtype override) * the weight is not Quantized * the input is already quantized * the model or layer has MM explictily disabled. If you get past all of those exclusions, quantize the input tensor. Then hand the new input, quantized or not off to forward_comfy_cast_weights to handle it. If the weight is offloaded but input is quantized you will get an offloaded MM8. --- comfy/model_patcher.py | 1 + comfy/ops.py | 30 +++++++++++++++--------------- 2 files changed, 16 insertions(+), 15 deletions(-) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 4528814ad..f6b80a40f 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -718,6 +718,7 @@ class ModelPatcher: continue cast_weight = self.force_cast_weights + m.comfy_force_cast_weights = self.force_cast_weights if lowvram_weight: if hasattr(m, "comfy_cast_weights"): m.weight_function = [] diff --git a/comfy/ops.py b/comfy/ops.py index cd536e22d..8156c42ff 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -654,29 +654,29 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec run_every_op() input_shape = input.shape - tensor_3d = input.ndim == 3 - - if self._full_precision_mm or self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0: - return self.forward_comfy_cast_weights(input, *args, **kwargs) + reshaped_3d = False if (getattr(self, 'layout_type', None) is not None and - not isinstance(input, QuantizedTensor)): + not isinstance(input, QuantizedTensor) and not self._full_precision_mm and + not getattr(self, 'comfy_force_cast_weights', False) and + len(self.weight_function) == 0 and len(self.bias_function) == 0): # Reshape 3D tensors to 2D for quantization (needed for NVFP4 and others) - if tensor_3d: - input = input.reshape(-1, input_shape[2]) + input_reshaped = input.reshape(-1, input_shape[2]) if input.ndim == 3 else input - if input.ndim != 2: - # Fall back to comfy_cast_weights for non-2D tensors - return self.forward_comfy_cast_weights(input.reshape(input_shape), *args, **kwargs) + # Fall back to non-quantized for non-2D tensors + if input_reshaped.ndim == 2: + reshaped_3d = input.ndim == 3 + # dtype is now implicit in the layout class + scale = getattr(self, 'input_scale', None) + if scale is not None: + scale = comfy.model_management.cast_to_device(scale, input.device, None) + input = QuantizedTensor.from_float(input_reshaped, self.layout_type, scale=scale) - # dtype is now implicit in the layout class - input = QuantizedTensor.from_float(input, self.layout_type, scale=getattr(self, 'input_scale', None)) - - output = self._forward(input, self.weight, self.bias) + output = self.forward_comfy_cast_weights(input) # Reshape output back to 3D if input was 3D - if tensor_3d: + if reshaped_3d: output = output.reshape((input_shape[0], input_shape[1], self.weight.shape[0])) return output From 21e842508733809354a7b04944b2995ed1169370 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 7 Jan 2026 18:07:26 -0800 Subject: [PATCH 023/119] Add warning for old pytorch. (#11718) --- comfy/quant_ops.py | 1 + 1 file changed, 1 insertion(+) diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py index 5a17bc6f5..8324be42a 100644 --- a/comfy/quant_ops.py +++ b/comfy/quant_ops.py @@ -19,6 +19,7 @@ try: cuda_version = tuple(map(int, str(torch.version.cuda).split('.'))) if cuda_version < (13,): ck.registry.disable("cuda") + logging.warning("WARNING: You need pytorch with cu130 or higher to use optimized CUDA operations.") ck.registry.disable("triton") for k, v in ck.list_backends().items(): From fcd9a236b091bd4e77b177134ddfcf7d7dbd71fd Mon Sep 17 00:00:00 2001 From: ComfyUI Wiki Date: Thu, 8 Jan 2026 10:22:23 +0800 Subject: [PATCH 024/119] Update template to 0.7.69 (#11719) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 13e95afa0..49567ad61 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.35.9 -comfyui-workflow-templates==0.7.67 +comfyui-workflow-templates==0.7.69 comfyui-embedded-docs==0.3.1 torch torchsde From ac12f77bed7bbbaf20289533bf7c0bff275e4a41 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 7 Jan 2026 22:10:08 -0500 Subject: [PATCH 025/119] ComfyUI version v0.8.1 --- comfyui_version.py | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/comfyui_version.py b/comfyui_version.py index 750673f08..4eb6070fe 100644 --- a/comfyui_version.py +++ b/comfyui_version.py @@ -1,3 +1,3 @@ # This file is automatically generated by the build process when version is # updated in pyproject.toml. -__version__ = "0.8.0" +__version__ = "0.8.1" diff --git a/pyproject.toml b/pyproject.toml index 951c2c978..0037abd6c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ComfyUI" -version = "0.8.0" +version = "0.8.1" readme = "README.md" license = { file = "LICENSE" } requires-python = ">=3.10" From 50d6e1caf401bf72dca1e9df7e194e722e1bd98b Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 7 Jan 2026 20:07:05 -0800 Subject: [PATCH 026/119] Tweak ltxv vae mem estimation. (#11722) --- comfy/sd.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/comfy/sd.py b/comfy/sd.py index efde3839c..5a7221620 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -479,8 +479,8 @@ class VAE: self.first_stage_model = comfy.ldm.lightricks.vae.causal_video_autoencoder.VideoVAE(version=version, config=vae_config) self.latent_channels = 128 self.latent_dim = 3 - self.memory_used_decode = lambda shape, dtype: (900 * shape[2] * shape[3] * shape[4] * (8 * 8 * 8)) * model_management.dtype_size(dtype) - self.memory_used_encode = lambda shape, dtype: (70 * max(shape[2], 7) * shape[3] * shape[4]) * model_management.dtype_size(dtype) + self.memory_used_decode = lambda shape, dtype: (1200 * shape[2] * shape[3] * shape[4] * (8 * 8 * 8)) * model_management.dtype_size(dtype) + self.memory_used_encode = lambda shape, dtype: (80 * max(shape[2], 7) * shape[3] * shape[4]) * model_management.dtype_size(dtype) self.upscale_ratio = (lambda a: max(0, a * 8 - 7), 32, 32) self.upscale_index_formula = (8, 32, 32) self.downscale_ratio = (lambda a: max(0, math.floor((a + 7) / 8)), 32, 32) From 2e9d51680a90bca9cc375ba7767f7bf3ed27d563 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 7 Jan 2026 23:50:02 -0500 Subject: [PATCH 027/119] ComfyUI version v0.8.2 --- comfyui_version.py | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/comfyui_version.py b/comfyui_version.py index 4eb6070fe..df82ed4fc 100644 --- a/comfyui_version.py +++ b/comfyui_version.py @@ -1,3 +1,3 @@ # This file is automatically generated by the build process when version is # updated in pyproject.toml. -__version__ = "0.8.1" +__version__ = "0.8.2" diff --git a/pyproject.toml b/pyproject.toml index 0037abd6c..49f1a03fd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ComfyUI" -version = "0.8.1" +version = "0.8.2" readme = "README.md" license = { file = "LICENSE" } requires-python = ">=3.10" From a60b7b86c54ea1498e9c5a5c3d6018c0714654d9 Mon Sep 17 00:00:00 2001 From: Yoland Yan <4950057+yoland68@users.noreply.github.com> Date: Wed, 7 Jan 2026 21:41:57 -0800 Subject: [PATCH 028/119] Revert "Force sequential execution in CI test jobs (#11687)" (#11725) This reverts commit ce0000c4f2a7dba12324585dddb784b43e3cd3d0. --- .github/workflows/test-ci.yml | 2 -- 1 file changed, 2 deletions(-) diff --git a/.github/workflows/test-ci.yml b/.github/workflows/test-ci.yml index 63df2dc3a..adfc5dd32 100644 --- a/.github/workflows/test-ci.yml +++ b/.github/workflows/test-ci.yml @@ -20,7 +20,6 @@ jobs: test-stable: strategy: fail-fast: false - max-parallel: 1 # This forces sequential execution matrix: # os: [macos, linux, windows] # os: [macos, linux] @@ -75,7 +74,6 @@ jobs: test-unix-nightly: strategy: fail-fast: false - max-parallel: 1 # This forces sequential execution matrix: # os: [macos, linux] os: [linux] From 5943fbf457d78becbb924a74780e0efc68505a17 Mon Sep 17 00:00:00 2001 From: "Dr.Lt.Data" <128333288+ltdrdata@users.noreply.github.com> Date: Fri, 9 Jan 2026 01:15:42 +0900 Subject: [PATCH 029/119] bump comfyui_manager version to the 4.0.5 (#11732) --- manager_requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/manager_requirements.txt b/manager_requirements.txt index 6585b0c19..bea6d4927 100644 --- a/manager_requirements.txt +++ b/manager_requirements.txt @@ -1 +1 @@ -comfyui_manager==4.0.4 +comfyui_manager==4.0.5 From 0f11869d55c7a459371b8114b1345e55a0274723 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Thu, 8 Jan 2026 14:16:58 -0800 Subject: [PATCH 030/119] Better detection if AMD torch compiled with efficient attention. (#11745) --- comfy/model_management.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 928282092..e5de4a5b5 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -22,7 +22,6 @@ from enum import Enum from comfy.cli_args import args, PerformanceFeature import torch import sys -import importlib import platform import weakref import gc @@ -349,10 +348,22 @@ try: except: rocm_version = (6, -1) + def aotriton_supported(gpu_arch): + path = torch.__path__[0] + path = os.path.join(os.path.join(path, "lib"), "aotriton.images") + gfx = set(map(lambda a: a[4:], filter(lambda a: a.startswith("amd-gfx"), os.listdir(path)))) + if gpu_arch in gfx: + return True + if "{}x".format(gpu_arch[:-1]) in gfx: + return True + if "{}xx".format(gpu_arch[:-2]) in gfx: + return True + return False + logging.info("AMD arch: {}".format(arch)) logging.info("ROCm version: {}".format(rocm_version)) if args.use_split_cross_attention == False and args.use_quad_cross_attention == False: - if importlib.util.find_spec('triton') is not None: # AMD efficient attention implementation depends on triton. TODO: better way of detecting if it's compiled in or not. + if aotriton_supported(arch): # AMD efficient attention implementation depends on aotriton. if torch_version_numeric >= (2, 7): # works on 2.6 but doesn't actually seem to improve much if any((a in arch) for a in ["gfx90a", "gfx942", "gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches, TODO: gfx950 ENABLE_PYTORCH_ATTENTION = True From 1a206564487d672561d83ce3eb007517bf018995 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Thu, 8 Jan 2026 14:23:59 -0800 Subject: [PATCH 031/119] Fix import issue. (#11746) --- comfy/ldm/hunyuan_video/upsampler.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/comfy/ldm/hunyuan_video/upsampler.py b/comfy/ldm/hunyuan_video/upsampler.py index d9e76922f..51b6d1da8 100644 --- a/comfy/ldm/hunyuan_video/upsampler.py +++ b/comfy/ldm/hunyuan_video/upsampler.py @@ -3,8 +3,8 @@ import torch.nn as nn import torch.nn.functional as F from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, VideoConv3d from comfy.ldm.hunyuan_video.vae_refiner import RMS_norm -import model_management -import model_patcher +import comfy.model_management +import comfy.model_patcher class SRResidualCausalBlock3D(nn.Module): def __init__(self, channels: int): @@ -103,13 +103,13 @@ UPSAMPLERS = { class HunyuanVideo15SRModel(): def __init__(self, model_type, config): - self.load_device = model_management.vae_device() - offload_device = model_management.vae_offload_device() - self.dtype = model_management.vae_dtype(self.load_device) + self.load_device = comfy.model_management.vae_device() + offload_device = comfy.model_management.vae_offload_device() + self.dtype = comfy.model_management.vae_dtype(self.load_device) self.model_class = UPSAMPLERS.get(model_type) self.model = self.model_class(**config).eval() - self.patcher = model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device) + self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device) def load_sd(self, sd): return self.model.load_state_dict(sd, strict=True) @@ -118,5 +118,5 @@ class HunyuanVideo15SRModel(): return self.model.state_dict() def resample_latent(self, latent): - model_management.load_model_gpu(self.patcher) + comfy.model_management.load_model_gpu(self.patcher) return self.model(latent.to(self.load_device)) From 027042db6811c875562296f0a6b797c89d59e426 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jukka=20Sepp=C3=A4nen?= <40791699+kijai@users.noreply.github.com> Date: Fri, 9 Jan 2026 05:14:06 +0200 Subject: [PATCH 032/119] Add node: JoinAudioChannels (#11728) --- comfy_extras/nodes_audio.py | 53 +++++++++++++++++++++++++++++++++++++ 1 file changed, 53 insertions(+) diff --git a/comfy_extras/nodes_audio.py b/comfy_extras/nodes_audio.py index 94ad5e8a8..15b3aa401 100644 --- a/comfy_extras/nodes_audio.py +++ b/comfy_extras/nodes_audio.py @@ -399,6 +399,58 @@ class SplitAudioChannels(IO.ComfyNode): separate = execute # TODO: remove +class JoinAudioChannels(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="JoinAudioChannels", + display_name="Join Audio Channels", + description="Joins left and right mono audio channels into a stereo audio.", + category="audio", + inputs=[ + IO.Audio.Input("audio_left"), + IO.Audio.Input("audio_right"), + ], + outputs=[ + IO.Audio.Output(display_name="audio"), + ], + ) + + @classmethod + def execute(cls, audio_left, audio_right) -> IO.NodeOutput: + waveform_left = audio_left["waveform"] + sample_rate_left = audio_left["sample_rate"] + waveform_right = audio_right["waveform"] + sample_rate_right = audio_right["sample_rate"] + + if waveform_left.shape[1] != 1 or waveform_right.shape[1] != 1: + raise ValueError("AudioJoin: Both input audios must be mono.") + + # Handle different sample rates by resampling to the higher rate + waveform_left, waveform_right, output_sample_rate = match_audio_sample_rates( + waveform_left, sample_rate_left, waveform_right, sample_rate_right + ) + + # Handle different lengths by trimming to the shorter length + length_left = waveform_left.shape[-1] + length_right = waveform_right.shape[-1] + + if length_left != length_right: + min_length = min(length_left, length_right) + if length_left > min_length: + logging.info(f"JoinAudioChannels: Trimming left channel from {length_left} to {min_length} samples.") + waveform_left = waveform_left[..., :min_length] + if length_right > min_length: + logging.info(f"JoinAudioChannels: Trimming right channel from {length_right} to {min_length} samples.") + waveform_right = waveform_right[..., :min_length] + + # Join the channels into stereo + left_channel = waveform_left[..., 0:1, :] + right_channel = waveform_right[..., 0:1, :] + stereo_waveform = torch.cat([left_channel, right_channel], dim=1) + + return IO.NodeOutput({"waveform": stereo_waveform, "sample_rate": output_sample_rate}) + def match_audio_sample_rates(waveform_1, sample_rate_1, waveform_2, sample_rate_2): if sample_rate_1 != sample_rate_2: @@ -616,6 +668,7 @@ class AudioExtension(ComfyExtension): RecordAudio, TrimAudioDuration, SplitAudioChannels, + JoinAudioChannels, AudioConcat, AudioMerge, AudioAdjustVolume, From b48d6a83d4f7012a1b6f6f41e66b0ac3f3253b8a Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Thu, 8 Jan 2026 19:15:50 -0800 Subject: [PATCH 033/119] Fix csp error in frontend when forcing offline. (#11749) --- server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server.py b/server.py index 70c8b5e3b..4db3347cb 100644 --- a/server.py +++ b/server.py @@ -184,7 +184,7 @@ def create_block_external_middleware(): else: response = await handler(request) - response.headers['Content-Security-Policy'] = "default-src 'self'; script-src 'self' 'unsafe-inline' 'unsafe-eval' blob:; style-src 'self' 'unsafe-inline'; img-src 'self' data: blob:; font-src 'self'; connect-src 'self'; frame-src 'self'; object-src 'self';" + response.headers['Content-Security-Policy'] = "default-src 'self'; script-src 'self' 'unsafe-inline' 'unsafe-eval' blob:; style-src 'self' 'unsafe-inline'; img-src 'self' data: blob:; font-src 'self'; connect-src 'self' data:; frame-src 'self'; object-src 'self';" return response return block_external_middleware From 114fc73685129bf4e8ddced432247fe67dc6fbff Mon Sep 17 00:00:00 2001 From: Comfy Org PR Bot Date: Fri, 9 Jan 2026 12:16:15 +0900 Subject: [PATCH 034/119] Bump comfyui-frontend-package to 1.36.13 (#11645) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 49567ad61..7686a5f8a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -comfyui-frontend-package==1.35.9 +comfyui-frontend-package==1.36.13 comfyui-workflow-templates==0.7.69 comfyui-embedded-docs==0.3.1 torch From 1dc3da631423b776669a6a9128bb1aeaf5592c55 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Thu, 8 Jan 2026 19:21:51 -0800 Subject: [PATCH 035/119] Add most basic Asset support for models (#11315) * Brought over minimal elements from PR 10045 to reproduce seed_assets and register_assets_system without adding anything to the DB or server routes yet, for now making everything sync (can introduce async once everything is cleaned up and brought over) * Added db script to insert assets stuff, cleaned up some code; assets (models) now get added/rescanned * Added support for 5 http endpoints for assets * Replaced Optional with | None in schemas_in.py and schemas_out.py * Remove two routes that will not be relevant yet in this PR: HEAD /api/assets/hash/ and PUT /api/assets//preview * Remove some functions the two deleted endpoints were using * Don't show assets scan message upon calling /object_info endpoint * removed unsued import to satisfy ruff * Simplified hashing function tpye hint and _hash_file_obj * Satisfied ruff --- alembic_db/versions/0001_assets.py | 174 +++++++++++++++++++ app/assets/api/routes.py | 102 +++++++++++ app/assets/api/schemas_in.py | 94 ++++++++++ app/assets/api/schemas_out.py | 60 +++++++ app/assets/database/bulk_ops.py | 188 ++++++++++++++++++++ app/assets/database/models.py | 233 +++++++++++++++++++++++++ app/assets/database/queries.py | 267 +++++++++++++++++++++++++++++ app/assets/database/tags.py | 62 +++++++ app/assets/hashing.py | 75 ++++++++ app/assets/helpers.py | 216 +++++++++++++++++++++++ app/assets/manager.py | 123 +++++++++++++ app/assets/scanner.py | 229 +++++++++++++++++++++++++ app/database/models.py | 25 ++- comfy/cli_args.py | 1 + main.py | 3 + server.py | 4 + 16 files changed, 1847 insertions(+), 9 deletions(-) create mode 100644 alembic_db/versions/0001_assets.py create mode 100644 app/assets/api/routes.py create mode 100644 app/assets/api/schemas_in.py create mode 100644 app/assets/api/schemas_out.py create mode 100644 app/assets/database/bulk_ops.py create mode 100644 app/assets/database/models.py create mode 100644 app/assets/database/queries.py create mode 100644 app/assets/database/tags.py create mode 100644 app/assets/hashing.py create mode 100644 app/assets/helpers.py create mode 100644 app/assets/manager.py create mode 100644 app/assets/scanner.py diff --git a/alembic_db/versions/0001_assets.py b/alembic_db/versions/0001_assets.py new file mode 100644 index 000000000..1e10b94dc --- /dev/null +++ b/alembic_db/versions/0001_assets.py @@ -0,0 +1,174 @@ +""" +Initial assets schema +Revision ID: 0001_assets +Revises: None +Create Date: 2025-12-10 00:00:00 +""" + +from alembic import op +import sqlalchemy as sa + +revision = "0001_assets" +down_revision = None +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ASSETS: content identity + op.create_table( + "assets", + sa.Column("id", sa.String(length=36), primary_key=True), + sa.Column("hash", sa.String(length=256), nullable=True), + sa.Column("size_bytes", sa.BigInteger(), nullable=False, server_default="0"), + sa.Column("mime_type", sa.String(length=255), nullable=True), + sa.Column("created_at", sa.DateTime(timezone=False), nullable=False), + sa.CheckConstraint("size_bytes >= 0", name="ck_assets_size_nonneg"), + ) + op.create_index("uq_assets_hash", "assets", ["hash"], unique=True) + op.create_index("ix_assets_mime_type", "assets", ["mime_type"]) + + # ASSETS_INFO: user-visible references + op.create_table( + "assets_info", + sa.Column("id", sa.String(length=36), primary_key=True), + sa.Column("owner_id", sa.String(length=128), nullable=False, server_default=""), + sa.Column("name", sa.String(length=512), nullable=False), + sa.Column("asset_id", sa.String(length=36), sa.ForeignKey("assets.id", ondelete="RESTRICT"), nullable=False), + sa.Column("preview_id", sa.String(length=36), sa.ForeignKey("assets.id", ondelete="SET NULL"), nullable=True), + sa.Column("user_metadata", sa.JSON(), nullable=True), + sa.Column("created_at", sa.DateTime(timezone=False), nullable=False), + sa.Column("updated_at", sa.DateTime(timezone=False), nullable=False), + sa.Column("last_access_time", sa.DateTime(timezone=False), nullable=False), + sa.UniqueConstraint("asset_id", "owner_id", "name", name="uq_assets_info_asset_owner_name"), + ) + op.create_index("ix_assets_info_owner_id", "assets_info", ["owner_id"]) + op.create_index("ix_assets_info_asset_id", "assets_info", ["asset_id"]) + op.create_index("ix_assets_info_name", "assets_info", ["name"]) + op.create_index("ix_assets_info_created_at", "assets_info", ["created_at"]) + op.create_index("ix_assets_info_last_access_time", "assets_info", ["last_access_time"]) + op.create_index("ix_assets_info_owner_name", "assets_info", ["owner_id", "name"]) + + # TAGS: normalized tag vocabulary + op.create_table( + "tags", + sa.Column("name", sa.String(length=512), primary_key=True), + sa.Column("tag_type", sa.String(length=32), nullable=False, server_default="user"), + sa.CheckConstraint("name = lower(name)", name="ck_tags_lowercase"), + ) + op.create_index("ix_tags_tag_type", "tags", ["tag_type"]) + + # ASSET_INFO_TAGS: many-to-many for tags on AssetInfo + op.create_table( + "asset_info_tags", + sa.Column("asset_info_id", sa.String(length=36), sa.ForeignKey("assets_info.id", ondelete="CASCADE"), nullable=False), + sa.Column("tag_name", sa.String(length=512), sa.ForeignKey("tags.name", ondelete="RESTRICT"), nullable=False), + sa.Column("origin", sa.String(length=32), nullable=False, server_default="manual"), + sa.Column("added_at", sa.DateTime(timezone=False), nullable=False), + sa.PrimaryKeyConstraint("asset_info_id", "tag_name", name="pk_asset_info_tags"), + ) + op.create_index("ix_asset_info_tags_tag_name", "asset_info_tags", ["tag_name"]) + op.create_index("ix_asset_info_tags_asset_info_id", "asset_info_tags", ["asset_info_id"]) + + # ASSET_CACHE_STATE: N:1 local cache rows per Asset + op.create_table( + "asset_cache_state", + sa.Column("id", sa.Integer(), primary_key=True, autoincrement=True), + sa.Column("asset_id", sa.String(length=36), sa.ForeignKey("assets.id", ondelete="CASCADE"), nullable=False), + sa.Column("file_path", sa.Text(), nullable=False), # absolute local path to cached file + sa.Column("mtime_ns", sa.BigInteger(), nullable=True), + sa.Column("needs_verify", sa.Boolean(), nullable=False, server_default=sa.text("false")), + sa.CheckConstraint("(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_acs_mtime_nonneg"), + sa.UniqueConstraint("file_path", name="uq_asset_cache_state_file_path"), + ) + op.create_index("ix_asset_cache_state_file_path", "asset_cache_state", ["file_path"]) + op.create_index("ix_asset_cache_state_asset_id", "asset_cache_state", ["asset_id"]) + + # ASSET_INFO_META: typed KV projection of user_metadata for filtering/sorting + op.create_table( + "asset_info_meta", + sa.Column("asset_info_id", sa.String(length=36), sa.ForeignKey("assets_info.id", ondelete="CASCADE"), nullable=False), + sa.Column("key", sa.String(length=256), nullable=False), + sa.Column("ordinal", sa.Integer(), nullable=False, server_default="0"), + sa.Column("val_str", sa.String(length=2048), nullable=True), + sa.Column("val_num", sa.Numeric(38, 10), nullable=True), + sa.Column("val_bool", sa.Boolean(), nullable=True), + sa.Column("val_json", sa.JSON(), nullable=True), + sa.PrimaryKeyConstraint("asset_info_id", "key", "ordinal", name="pk_asset_info_meta"), + ) + op.create_index("ix_asset_info_meta_key", "asset_info_meta", ["key"]) + op.create_index("ix_asset_info_meta_key_val_str", "asset_info_meta", ["key", "val_str"]) + op.create_index("ix_asset_info_meta_key_val_num", "asset_info_meta", ["key", "val_num"]) + op.create_index("ix_asset_info_meta_key_val_bool", "asset_info_meta", ["key", "val_bool"]) + + # Tags vocabulary + tags_table = sa.table( + "tags", + sa.column("name", sa.String(length=512)), + sa.column("tag_type", sa.String()), + ) + op.bulk_insert( + tags_table, + [ + {"name": "models", "tag_type": "system"}, + {"name": "input", "tag_type": "system"}, + {"name": "output", "tag_type": "system"}, + + {"name": "configs", "tag_type": "system"}, + {"name": "checkpoints", "tag_type": "system"}, + {"name": "loras", "tag_type": "system"}, + {"name": "vae", "tag_type": "system"}, + {"name": "text_encoders", "tag_type": "system"}, + {"name": "diffusion_models", "tag_type": "system"}, + {"name": "clip_vision", "tag_type": "system"}, + {"name": "style_models", "tag_type": "system"}, + {"name": "embeddings", "tag_type": "system"}, + {"name": "diffusers", "tag_type": "system"}, + {"name": "vae_approx", "tag_type": "system"}, + {"name": "controlnet", "tag_type": "system"}, + {"name": "gligen", "tag_type": "system"}, + {"name": "upscale_models", "tag_type": "system"}, + {"name": "hypernetworks", "tag_type": "system"}, + {"name": "photomaker", "tag_type": "system"}, + {"name": "classifiers", "tag_type": "system"}, + + {"name": "encoder", "tag_type": "system"}, + {"name": "decoder", "tag_type": "system"}, + + {"name": "missing", "tag_type": "system"}, + {"name": "rescan", "tag_type": "system"}, + ], + ) + + +def downgrade() -> None: + op.drop_index("ix_asset_info_meta_key_val_bool", table_name="asset_info_meta") + op.drop_index("ix_asset_info_meta_key_val_num", table_name="asset_info_meta") + op.drop_index("ix_asset_info_meta_key_val_str", table_name="asset_info_meta") + op.drop_index("ix_asset_info_meta_key", table_name="asset_info_meta") + op.drop_table("asset_info_meta") + + op.drop_index("ix_asset_cache_state_asset_id", table_name="asset_cache_state") + op.drop_index("ix_asset_cache_state_file_path", table_name="asset_cache_state") + op.drop_constraint("uq_asset_cache_state_file_path", table_name="asset_cache_state") + op.drop_table("asset_cache_state") + + op.drop_index("ix_asset_info_tags_asset_info_id", table_name="asset_info_tags") + op.drop_index("ix_asset_info_tags_tag_name", table_name="asset_info_tags") + op.drop_table("asset_info_tags") + + op.drop_index("ix_tags_tag_type", table_name="tags") + op.drop_table("tags") + + op.drop_constraint("uq_assets_info_asset_owner_name", table_name="assets_info") + op.drop_index("ix_assets_info_owner_name", table_name="assets_info") + op.drop_index("ix_assets_info_last_access_time", table_name="assets_info") + op.drop_index("ix_assets_info_created_at", table_name="assets_info") + op.drop_index("ix_assets_info_name", table_name="assets_info") + op.drop_index("ix_assets_info_asset_id", table_name="assets_info") + op.drop_index("ix_assets_info_owner_id", table_name="assets_info") + op.drop_table("assets_info") + + op.drop_index("uq_assets_hash", table_name="assets") + op.drop_index("ix_assets_mime_type", table_name="assets") + op.drop_table("assets") diff --git a/app/assets/api/routes.py b/app/assets/api/routes.py new file mode 100644 index 000000000..30e87a898 --- /dev/null +++ b/app/assets/api/routes.py @@ -0,0 +1,102 @@ +import logging +import uuid +from aiohttp import web + +from pydantic import ValidationError + +import app.assets.manager as manager +from app import user_manager +from app.assets.api import schemas_in +from app.assets.helpers import get_query_dict + +ROUTES = web.RouteTableDef() +USER_MANAGER: user_manager.UserManager | None = None + +# UUID regex (canonical hyphenated form, case-insensitive) +UUID_RE = r"[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}" + +def register_assets_system(app: web.Application, user_manager_instance: user_manager.UserManager) -> None: + global USER_MANAGER + USER_MANAGER = user_manager_instance + app.add_routes(ROUTES) + +def _error_response(status: int, code: str, message: str, details: dict | None = None) -> web.Response: + return web.json_response({"error": {"code": code, "message": message, "details": details or {}}}, status=status) + + +def _validation_error_response(code: str, ve: ValidationError) -> web.Response: + return _error_response(400, code, "Validation failed.", {"errors": ve.json()}) + + +@ROUTES.get("/api/assets") +async def list_assets(request: web.Request) -> web.Response: + """ + GET request to list assets. + """ + query_dict = get_query_dict(request) + try: + q = schemas_in.ListAssetsQuery.model_validate(query_dict) + except ValidationError as ve: + return _validation_error_response("INVALID_QUERY", ve) + + payload = manager.list_assets( + include_tags=q.include_tags, + exclude_tags=q.exclude_tags, + name_contains=q.name_contains, + metadata_filter=q.metadata_filter, + limit=q.limit, + offset=q.offset, + sort=q.sort, + order=q.order, + owner_id=USER_MANAGER.get_request_user_id(request), + ) + return web.json_response(payload.model_dump(mode="json")) + + +@ROUTES.get(f"/api/assets/{{id:{UUID_RE}}}") +async def get_asset(request: web.Request) -> web.Response: + """ + GET request to get an asset's info as JSON. + """ + asset_info_id = str(uuid.UUID(request.match_info["id"])) + try: + result = manager.get_asset( + asset_info_id=asset_info_id, + owner_id=USER_MANAGER.get_request_user_id(request), + ) + except ValueError as e: + return _error_response(404, "ASSET_NOT_FOUND", str(e), {"id": asset_info_id}) + except Exception: + logging.exception( + "get_asset failed for asset_info_id=%s, owner_id=%s", + asset_info_id, + USER_MANAGER.get_request_user_id(request), + ) + return _error_response(500, "INTERNAL", "Unexpected server error.") + return web.json_response(result.model_dump(mode="json"), status=200) + + +@ROUTES.get("/api/tags") +async def get_tags(request: web.Request) -> web.Response: + """ + GET request to list all tags based on query parameters. + """ + query_map = dict(request.rel_url.query) + + try: + query = schemas_in.TagsListQuery.model_validate(query_map) + except ValidationError as e: + return web.json_response( + {"error": {"code": "INVALID_QUERY", "message": "Invalid query parameters", "details": e.errors()}}, + status=400, + ) + + result = manager.list_tags( + prefix=query.prefix, + limit=query.limit, + offset=query.offset, + order=query.order, + include_zero=query.include_zero, + owner_id=USER_MANAGER.get_request_user_id(request), + ) + return web.json_response(result.model_dump(mode="json")) diff --git a/app/assets/api/schemas_in.py b/app/assets/api/schemas_in.py new file mode 100644 index 000000000..200b41aef --- /dev/null +++ b/app/assets/api/schemas_in.py @@ -0,0 +1,94 @@ +import json +import uuid +from typing import Any, Literal + +from pydantic import ( + BaseModel, + ConfigDict, + Field, + conint, + field_validator, +) + + +class ListAssetsQuery(BaseModel): + include_tags: list[str] = Field(default_factory=list) + exclude_tags: list[str] = Field(default_factory=list) + name_contains: str | None = None + + # Accept either a JSON string (query param) or a dict + metadata_filter: dict[str, Any] | None = None + + limit: conint(ge=1, le=500) = 20 + offset: conint(ge=0) = 0 + + sort: Literal["name", "created_at", "updated_at", "size", "last_access_time"] = "created_at" + order: Literal["asc", "desc"] = "desc" + + @field_validator("include_tags", "exclude_tags", mode="before") + @classmethod + def _split_csv_tags(cls, v): + # Accept "a,b,c" or ["a","b"] (we are liberal in what we accept) + if v is None: + return [] + if isinstance(v, str): + return [t.strip() for t in v.split(",") if t.strip()] + if isinstance(v, list): + out: list[str] = [] + for item in v: + if isinstance(item, str): + out.extend([t.strip() for t in item.split(",") if t.strip()]) + return out + return v + + @field_validator("metadata_filter", mode="before") + @classmethod + def _parse_metadata_json(cls, v): + if v is None or isinstance(v, dict): + return v + if isinstance(v, str) and v.strip(): + try: + parsed = json.loads(v) + except Exception as e: + raise ValueError(f"metadata_filter must be JSON: {e}") from e + if not isinstance(parsed, dict): + raise ValueError("metadata_filter must be a JSON object") + return parsed + return None + + +class TagsListQuery(BaseModel): + model_config = ConfigDict(extra="ignore", str_strip_whitespace=True) + + prefix: str | None = Field(None, min_length=1, max_length=256) + limit: int = Field(100, ge=1, le=1000) + offset: int = Field(0, ge=0, le=10_000_000) + order: Literal["count_desc", "name_asc"] = "count_desc" + include_zero: bool = True + + @field_validator("prefix") + @classmethod + def normalize_prefix(cls, v: str | None) -> str | None: + if v is None: + return v + v = v.strip() + return v.lower() or None + + +class SetPreviewBody(BaseModel): + """Set or clear the preview for an AssetInfo. Provide an Asset.id or null.""" + preview_id: str | None = None + + @field_validator("preview_id", mode="before") + @classmethod + def _norm_uuid(cls, v): + if v is None: + return None + s = str(v).strip() + if not s: + return None + try: + uuid.UUID(s) + except Exception: + raise ValueError("preview_id must be a UUID") + return s diff --git a/app/assets/api/schemas_out.py b/app/assets/api/schemas_out.py new file mode 100644 index 000000000..9f8184f20 --- /dev/null +++ b/app/assets/api/schemas_out.py @@ -0,0 +1,60 @@ +from datetime import datetime +from typing import Any + +from pydantic import BaseModel, ConfigDict, Field, field_serializer + + +class AssetSummary(BaseModel): + id: str + name: str + asset_hash: str | None = None + size: int | None = None + mime_type: str | None = None + tags: list[str] = Field(default_factory=list) + preview_url: str | None = None + created_at: datetime | None = None + updated_at: datetime | None = None + last_access_time: datetime | None = None + + model_config = ConfigDict(from_attributes=True) + + @field_serializer("created_at", "updated_at", "last_access_time") + def _ser_dt(self, v: datetime | None, _info): + return v.isoformat() if v else None + + +class AssetsList(BaseModel): + assets: list[AssetSummary] + total: int + has_more: bool + + +class AssetDetail(BaseModel): + id: str + name: str + asset_hash: str | None = None + size: int | None = None + mime_type: str | None = None + tags: list[str] = Field(default_factory=list) + user_metadata: dict[str, Any] = Field(default_factory=dict) + preview_id: str | None = None + created_at: datetime | None = None + last_access_time: datetime | None = None + + model_config = ConfigDict(from_attributes=True) + + @field_serializer("created_at", "last_access_time") + def _ser_dt(self, v: datetime | None, _info): + return v.isoformat() if v else None + + +class TagUsage(BaseModel): + name: str + count: int + type: str + + +class TagsList(BaseModel): + tags: list[TagUsage] = Field(default_factory=list) + total: int + has_more: bool diff --git a/app/assets/database/bulk_ops.py b/app/assets/database/bulk_ops.py new file mode 100644 index 000000000..9352cd65d --- /dev/null +++ b/app/assets/database/bulk_ops.py @@ -0,0 +1,188 @@ +import os +import uuid +import sqlalchemy +from typing import Iterable +from sqlalchemy.orm import Session +from sqlalchemy.dialects import sqlite + +from app.assets.helpers import utcnow +from app.assets.database.models import Asset, AssetCacheState, AssetInfo, AssetInfoTag, AssetInfoMeta + +MAX_BIND_PARAMS = 800 + +def _chunk_rows(rows: list[dict], cols_per_row: int, max_bind_params: int) -> Iterable[list[dict]]: + if not rows: + return [] + rows_per_stmt = max(1, max_bind_params // max(1, cols_per_row)) + for i in range(0, len(rows), rows_per_stmt): + yield rows[i:i + rows_per_stmt] + +def _iter_chunks(seq, n: int): + for i in range(0, len(seq), n): + yield seq[i:i + n] + +def _rows_per_stmt(cols: int) -> int: + return max(1, MAX_BIND_PARAMS // max(1, cols)) + + +def seed_from_paths_batch( + session: Session, + *, + specs: list[dict], + owner_id: str = "", +) -> dict: + """Each spec is a dict with keys: + - abs_path: str + - size_bytes: int + - mtime_ns: int + - info_name: str + - tags: list[str] + - fname: Optional[str] + """ + if not specs: + return {"inserted_infos": 0, "won_states": 0, "lost_states": 0} + + now = utcnow() + asset_rows: list[dict] = [] + state_rows: list[dict] = [] + path_to_asset: dict[str, str] = {} + asset_to_info: dict[str, dict] = {} # asset_id -> prepared info row + path_list: list[str] = [] + + for sp in specs: + ap = os.path.abspath(sp["abs_path"]) + aid = str(uuid.uuid4()) + iid = str(uuid.uuid4()) + path_list.append(ap) + path_to_asset[ap] = aid + + asset_rows.append( + { + "id": aid, + "hash": None, + "size_bytes": sp["size_bytes"], + "mime_type": None, + "created_at": now, + } + ) + state_rows.append( + { + "asset_id": aid, + "file_path": ap, + "mtime_ns": sp["mtime_ns"], + } + ) + asset_to_info[aid] = { + "id": iid, + "owner_id": owner_id, + "name": sp["info_name"], + "asset_id": aid, + "preview_id": None, + "user_metadata": {"filename": sp["fname"]} if sp["fname"] else None, + "created_at": now, + "updated_at": now, + "last_access_time": now, + "_tags": sp["tags"], + "_filename": sp["fname"], + } + + # insert all seed Assets (hash=NULL) + ins_asset = sqlite.insert(Asset) + for chunk in _iter_chunks(asset_rows, _rows_per_stmt(5)): + session.execute(ins_asset, chunk) + + # try to claim AssetCacheState (file_path) + winners_by_path: set[str] = set() + ins_state = ( + sqlite.insert(AssetCacheState) + .on_conflict_do_nothing(index_elements=[AssetCacheState.file_path]) + .returning(AssetCacheState.file_path) + ) + for chunk in _iter_chunks(state_rows, _rows_per_stmt(3)): + winners_by_path.update((session.execute(ins_state, chunk)).scalars().all()) + + all_paths_set = set(path_list) + losers_by_path = all_paths_set - winners_by_path + lost_assets = [path_to_asset[p] for p in losers_by_path] + if lost_assets: # losers get their Asset removed + for id_chunk in _iter_chunks(lost_assets, MAX_BIND_PARAMS): + session.execute(sqlalchemy.delete(Asset).where(Asset.id.in_(id_chunk))) + + if not winners_by_path: + return {"inserted_infos": 0, "won_states": 0, "lost_states": len(losers_by_path)} + + # insert AssetInfo only for winners + winner_info_rows = [asset_to_info[path_to_asset[p]] for p in winners_by_path] + ins_info = ( + sqlite.insert(AssetInfo) + .on_conflict_do_nothing(index_elements=[AssetInfo.asset_id, AssetInfo.owner_id, AssetInfo.name]) + .returning(AssetInfo.id) + ) + + inserted_info_ids: set[str] = set() + for chunk in _iter_chunks(winner_info_rows, _rows_per_stmt(9)): + inserted_info_ids.update((session.execute(ins_info, chunk)).scalars().all()) + + # build and insert tag + meta rows for the AssetInfo + tag_rows: list[dict] = [] + meta_rows: list[dict] = [] + if inserted_info_ids: + for row in winner_info_rows: + iid = row["id"] + if iid not in inserted_info_ids: + continue + for t in row["_tags"]: + tag_rows.append({ + "asset_info_id": iid, + "tag_name": t, + "origin": "automatic", + "added_at": now, + }) + if row["_filename"]: + meta_rows.append( + { + "asset_info_id": iid, + "key": "filename", + "ordinal": 0, + "val_str": row["_filename"], + "val_num": None, + "val_bool": None, + "val_json": None, + } + ) + + bulk_insert_tags_and_meta(session, tag_rows=tag_rows, meta_rows=meta_rows, max_bind_params=MAX_BIND_PARAMS) + return { + "inserted_infos": len(inserted_info_ids), + "won_states": len(winners_by_path), + "lost_states": len(losers_by_path), + } + + +def bulk_insert_tags_and_meta( + session: Session, + *, + tag_rows: list[dict], + meta_rows: list[dict], + max_bind_params: int, +) -> None: + """Batch insert into asset_info_tags and asset_info_meta with ON CONFLICT DO NOTHING. + - tag_rows keys: asset_info_id, tag_name, origin, added_at + - meta_rows keys: asset_info_id, key, ordinal, val_str, val_num, val_bool, val_json + """ + if tag_rows: + ins_links = ( + sqlite.insert(AssetInfoTag) + .on_conflict_do_nothing(index_elements=[AssetInfoTag.asset_info_id, AssetInfoTag.tag_name]) + ) + for chunk in _chunk_rows(tag_rows, cols_per_row=4, max_bind_params=max_bind_params): + session.execute(ins_links, chunk) + if meta_rows: + ins_meta = ( + sqlite.insert(AssetInfoMeta) + .on_conflict_do_nothing( + index_elements=[AssetInfoMeta.asset_info_id, AssetInfoMeta.key, AssetInfoMeta.ordinal] + ) + ) + for chunk in _chunk_rows(meta_rows, cols_per_row=7, max_bind_params=max_bind_params): + session.execute(ins_meta, chunk) diff --git a/app/assets/database/models.py b/app/assets/database/models.py new file mode 100644 index 000000000..3cd28f68b --- /dev/null +++ b/app/assets/database/models.py @@ -0,0 +1,233 @@ +from __future__ import annotations + +import uuid +from datetime import datetime + +from typing import Any +from sqlalchemy import ( + JSON, + BigInteger, + Boolean, + CheckConstraint, + DateTime, + ForeignKey, + Index, + Integer, + Numeric, + String, + Text, + UniqueConstraint, +) +from sqlalchemy.orm import Mapped, foreign, mapped_column, relationship + +from app.assets.helpers import utcnow +from app.database.models import to_dict, Base + + +class Asset(Base): + __tablename__ = "assets" + + id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: str(uuid.uuid4())) + hash: Mapped[str | None] = mapped_column(String(256), nullable=True) + size_bytes: Mapped[int] = mapped_column(BigInteger, nullable=False, default=0) + mime_type: Mapped[str | None] = mapped_column(String(255)) + created_at: Mapped[datetime] = mapped_column( + DateTime(timezone=False), nullable=False, default=utcnow + ) + + infos: Mapped[list[AssetInfo]] = relationship( + "AssetInfo", + back_populates="asset", + primaryjoin=lambda: Asset.id == foreign(AssetInfo.asset_id), + foreign_keys=lambda: [AssetInfo.asset_id], + cascade="all,delete-orphan", + passive_deletes=True, + ) + + preview_of: Mapped[list[AssetInfo]] = relationship( + "AssetInfo", + back_populates="preview_asset", + primaryjoin=lambda: Asset.id == foreign(AssetInfo.preview_id), + foreign_keys=lambda: [AssetInfo.preview_id], + viewonly=True, + ) + + cache_states: Mapped[list[AssetCacheState]] = relationship( + back_populates="asset", + cascade="all, delete-orphan", + passive_deletes=True, + ) + + __table_args__ = ( + Index("uq_assets_hash", "hash", unique=True), + Index("ix_assets_mime_type", "mime_type"), + CheckConstraint("size_bytes >= 0", name="ck_assets_size_nonneg"), + ) + + def to_dict(self, include_none: bool = False) -> dict[str, Any]: + return to_dict(self, include_none=include_none) + + def __repr__(self) -> str: + return f"" + + +class AssetCacheState(Base): + __tablename__ = "asset_cache_state" + + id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) + asset_id: Mapped[str] = mapped_column(String(36), ForeignKey("assets.id", ondelete="CASCADE"), nullable=False) + file_path: Mapped[str] = mapped_column(Text, nullable=False) + mtime_ns: Mapped[int | None] = mapped_column(BigInteger, nullable=True) + needs_verify: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False) + + asset: Mapped[Asset] = relationship(back_populates="cache_states") + + __table_args__ = ( + Index("ix_asset_cache_state_file_path", "file_path"), + Index("ix_asset_cache_state_asset_id", "asset_id"), + CheckConstraint("(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_acs_mtime_nonneg"), + UniqueConstraint("file_path", name="uq_asset_cache_state_file_path"), + ) + + def to_dict(self, include_none: bool = False) -> dict[str, Any]: + return to_dict(self, include_none=include_none) + + def __repr__(self) -> str: + return f"" + + +class AssetInfo(Base): + __tablename__ = "assets_info" + + id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: str(uuid.uuid4())) + owner_id: Mapped[str] = mapped_column(String(128), nullable=False, default="") + name: Mapped[str] = mapped_column(String(512), nullable=False) + asset_id: Mapped[str] = mapped_column(String(36), ForeignKey("assets.id", ondelete="RESTRICT"), nullable=False) + preview_id: Mapped[str | None] = mapped_column(String(36), ForeignKey("assets.id", ondelete="SET NULL")) + user_metadata: Mapped[dict[str, Any] | None] = mapped_column(JSON(none_as_null=True)) + created_at: Mapped[datetime] = mapped_column(DateTime(timezone=False), nullable=False, default=utcnow) + updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=False), nullable=False, default=utcnow) + last_access_time: Mapped[datetime] = mapped_column(DateTime(timezone=False), nullable=False, default=utcnow) + + asset: Mapped[Asset] = relationship( + "Asset", + back_populates="infos", + foreign_keys=[asset_id], + lazy="selectin", + ) + preview_asset: Mapped[Asset | None] = relationship( + "Asset", + back_populates="preview_of", + foreign_keys=[preview_id], + ) + + metadata_entries: Mapped[list[AssetInfoMeta]] = relationship( + back_populates="asset_info", + cascade="all,delete-orphan", + passive_deletes=True, + ) + + tag_links: Mapped[list[AssetInfoTag]] = relationship( + back_populates="asset_info", + cascade="all,delete-orphan", + passive_deletes=True, + overlaps="tags,asset_infos", + ) + + tags: Mapped[list[Tag]] = relationship( + secondary="asset_info_tags", + back_populates="asset_infos", + lazy="selectin", + viewonly=True, + overlaps="tag_links,asset_info_links,asset_infos,tag", + ) + + __table_args__ = ( + UniqueConstraint("asset_id", "owner_id", "name", name="uq_assets_info_asset_owner_name"), + Index("ix_assets_info_owner_name", "owner_id", "name"), + Index("ix_assets_info_owner_id", "owner_id"), + Index("ix_assets_info_asset_id", "asset_id"), + Index("ix_assets_info_name", "name"), + Index("ix_assets_info_created_at", "created_at"), + Index("ix_assets_info_last_access_time", "last_access_time"), + ) + + def to_dict(self, include_none: bool = False) -> dict[str, Any]: + data = to_dict(self, include_none=include_none) + data["tags"] = [t.name for t in self.tags] + return data + + def __repr__(self) -> str: + return f"" + + +class AssetInfoMeta(Base): + __tablename__ = "asset_info_meta" + + asset_info_id: Mapped[str] = mapped_column( + String(36), ForeignKey("assets_info.id", ondelete="CASCADE"), primary_key=True + ) + key: Mapped[str] = mapped_column(String(256), primary_key=True) + ordinal: Mapped[int] = mapped_column(Integer, primary_key=True, default=0) + + val_str: Mapped[str | None] = mapped_column(String(2048), nullable=True) + val_num: Mapped[float | None] = mapped_column(Numeric(38, 10), nullable=True) + val_bool: Mapped[bool | None] = mapped_column(Boolean, nullable=True) + val_json: Mapped[Any | None] = mapped_column(JSON(none_as_null=True), nullable=True) + + asset_info: Mapped[AssetInfo] = relationship(back_populates="metadata_entries") + + __table_args__ = ( + Index("ix_asset_info_meta_key", "key"), + Index("ix_asset_info_meta_key_val_str", "key", "val_str"), + Index("ix_asset_info_meta_key_val_num", "key", "val_num"), + Index("ix_asset_info_meta_key_val_bool", "key", "val_bool"), + ) + + +class AssetInfoTag(Base): + __tablename__ = "asset_info_tags" + + asset_info_id: Mapped[str] = mapped_column( + String(36), ForeignKey("assets_info.id", ondelete="CASCADE"), primary_key=True + ) + tag_name: Mapped[str] = mapped_column( + String(512), ForeignKey("tags.name", ondelete="RESTRICT"), primary_key=True + ) + origin: Mapped[str] = mapped_column(String(32), nullable=False, default="manual") + added_at: Mapped[datetime] = mapped_column( + DateTime(timezone=False), nullable=False, default=utcnow + ) + + asset_info: Mapped[AssetInfo] = relationship(back_populates="tag_links") + tag: Mapped[Tag] = relationship(back_populates="asset_info_links") + + __table_args__ = ( + Index("ix_asset_info_tags_tag_name", "tag_name"), + Index("ix_asset_info_tags_asset_info_id", "asset_info_id"), + ) + + +class Tag(Base): + __tablename__ = "tags" + + name: Mapped[str] = mapped_column(String(512), primary_key=True) + tag_type: Mapped[str] = mapped_column(String(32), nullable=False, default="user") + + asset_info_links: Mapped[list[AssetInfoTag]] = relationship( + back_populates="tag", + overlaps="asset_infos,tags", + ) + asset_infos: Mapped[list[AssetInfo]] = relationship( + secondary="asset_info_tags", + back_populates="tags", + viewonly=True, + overlaps="asset_info_links,tag_links,tags,asset_info", + ) + + __table_args__ = ( + Index("ix_tags_tag_type", "tag_type"), + ) + + def __repr__(self) -> str: + return f"" diff --git a/app/assets/database/queries.py b/app/assets/database/queries.py new file mode 100644 index 000000000..0824c0c2f --- /dev/null +++ b/app/assets/database/queries.py @@ -0,0 +1,267 @@ +import sqlalchemy as sa +from collections import defaultdict +from sqlalchemy import select, exists, func +from sqlalchemy.orm import Session, contains_eager, noload +from app.assets.database.models import Asset, AssetInfo, AssetInfoMeta, AssetInfoTag, Tag +from app.assets.helpers import escape_like_prefix, normalize_tags +from typing import Sequence + + +def visible_owner_clause(owner_id: str) -> sa.sql.ClauseElement: + """Build owner visibility predicate for reads. Owner-less rows are visible to everyone.""" + owner_id = (owner_id or "").strip() + if owner_id == "": + return AssetInfo.owner_id == "" + return AssetInfo.owner_id.in_(["", owner_id]) + + +def apply_tag_filters( + stmt: sa.sql.Select, + include_tags: Sequence[str] | None = None, + exclude_tags: Sequence[str] | None = None, +) -> sa.sql.Select: + """include_tags: every tag must be present; exclude_tags: none may be present.""" + include_tags = normalize_tags(include_tags) + exclude_tags = normalize_tags(exclude_tags) + + if include_tags: + for tag_name in include_tags: + stmt = stmt.where( + exists().where( + (AssetInfoTag.asset_info_id == AssetInfo.id) + & (AssetInfoTag.tag_name == tag_name) + ) + ) + + if exclude_tags: + stmt = stmt.where( + ~exists().where( + (AssetInfoTag.asset_info_id == AssetInfo.id) + & (AssetInfoTag.tag_name.in_(exclude_tags)) + ) + ) + return stmt + +def apply_metadata_filter( + stmt: sa.sql.Select, + metadata_filter: dict | None = None, +) -> sa.sql.Select: + """Apply filters using asset_info_meta projection table.""" + if not metadata_filter: + return stmt + + def _exists_for_pred(key: str, *preds) -> sa.sql.ClauseElement: + return sa.exists().where( + AssetInfoMeta.asset_info_id == AssetInfo.id, + AssetInfoMeta.key == key, + *preds, + ) + + def _exists_clause_for_value(key: str, value) -> sa.sql.ClauseElement: + if value is None: + no_row_for_key = sa.not_( + sa.exists().where( + AssetInfoMeta.asset_info_id == AssetInfo.id, + AssetInfoMeta.key == key, + ) + ) + null_row = _exists_for_pred( + key, + AssetInfoMeta.val_json.is_(None), + AssetInfoMeta.val_str.is_(None), + AssetInfoMeta.val_num.is_(None), + AssetInfoMeta.val_bool.is_(None), + ) + return sa.or_(no_row_for_key, null_row) + + if isinstance(value, bool): + return _exists_for_pred(key, AssetInfoMeta.val_bool == bool(value)) + if isinstance(value, (int, float)): + from decimal import Decimal + num = value if isinstance(value, Decimal) else Decimal(str(value)) + return _exists_for_pred(key, AssetInfoMeta.val_num == num) + if isinstance(value, str): + return _exists_for_pred(key, AssetInfoMeta.val_str == value) + return _exists_for_pred(key, AssetInfoMeta.val_json == value) + + for k, v in metadata_filter.items(): + if isinstance(v, list): + ors = [_exists_clause_for_value(k, elem) for elem in v] + if ors: + stmt = stmt.where(sa.or_(*ors)) + else: + stmt = stmt.where(_exists_clause_for_value(k, v)) + return stmt + + +def asset_exists_by_hash(session: Session, asset_hash: str) -> bool: + """ + Check if an asset with a given hash exists in database. + """ + row = ( + session.execute( + select(sa.literal(True)).select_from(Asset).where(Asset.hash == asset_hash).limit(1) + ) + ).first() + return row is not None + +def get_asset_info_by_id(session: Session, asset_info_id: str) -> AssetInfo | None: + return session.get(AssetInfo, asset_info_id) + +def list_asset_infos_page( + session: Session, + owner_id: str = "", + include_tags: Sequence[str] | None = None, + exclude_tags: Sequence[str] | None = None, + name_contains: str | None = None, + metadata_filter: dict | None = None, + limit: int = 20, + offset: int = 0, + sort: str = "created_at", + order: str = "desc", +) -> tuple[list[AssetInfo], dict[str, list[str]], int]: + base = ( + select(AssetInfo) + .join(Asset, Asset.id == AssetInfo.asset_id) + .options(contains_eager(AssetInfo.asset), noload(AssetInfo.tags)) + .where(visible_owner_clause(owner_id)) + ) + + if name_contains: + escaped, esc = escape_like_prefix(name_contains) + base = base.where(AssetInfo.name.ilike(f"%{escaped}%", escape=esc)) + + base = apply_tag_filters(base, include_tags, exclude_tags) + base = apply_metadata_filter(base, metadata_filter) + + sort = (sort or "created_at").lower() + order = (order or "desc").lower() + sort_map = { + "name": AssetInfo.name, + "created_at": AssetInfo.created_at, + "updated_at": AssetInfo.updated_at, + "last_access_time": AssetInfo.last_access_time, + "size": Asset.size_bytes, + } + sort_col = sort_map.get(sort, AssetInfo.created_at) + sort_exp = sort_col.desc() if order == "desc" else sort_col.asc() + + base = base.order_by(sort_exp).limit(limit).offset(offset) + + count_stmt = ( + select(sa.func.count()) + .select_from(AssetInfo) + .join(Asset, Asset.id == AssetInfo.asset_id) + .where(visible_owner_clause(owner_id)) + ) + if name_contains: + escaped, esc = escape_like_prefix(name_contains) + count_stmt = count_stmt.where(AssetInfo.name.ilike(f"%{escaped}%", escape=esc)) + count_stmt = apply_tag_filters(count_stmt, include_tags, exclude_tags) + count_stmt = apply_metadata_filter(count_stmt, metadata_filter) + + total = int((session.execute(count_stmt)).scalar_one() or 0) + + infos = (session.execute(base)).unique().scalars().all() + + id_list: list[str] = [i.id for i in infos] + tag_map: dict[str, list[str]] = defaultdict(list) + if id_list: + rows = session.execute( + select(AssetInfoTag.asset_info_id, Tag.name) + .join(Tag, Tag.name == AssetInfoTag.tag_name) + .where(AssetInfoTag.asset_info_id.in_(id_list)) + ) + for aid, tag_name in rows.all(): + tag_map[aid].append(tag_name) + + return infos, tag_map, total + +def fetch_asset_info_asset_and_tags( + session: Session, + asset_info_id: str, + owner_id: str = "", +) -> tuple[AssetInfo, Asset, list[str]] | None: + stmt = ( + select(AssetInfo, Asset, Tag.name) + .join(Asset, Asset.id == AssetInfo.asset_id) + .join(AssetInfoTag, AssetInfoTag.asset_info_id == AssetInfo.id, isouter=True) + .join(Tag, Tag.name == AssetInfoTag.tag_name, isouter=True) + .where( + AssetInfo.id == asset_info_id, + visible_owner_clause(owner_id), + ) + .options(noload(AssetInfo.tags)) + .order_by(Tag.name.asc()) + ) + + rows = (session.execute(stmt)).all() + if not rows: + return None + + first_info, first_asset, _ = rows[0] + tags: list[str] = [] + seen: set[str] = set() + for _info, _asset, tag_name in rows: + if tag_name and tag_name not in seen: + seen.add(tag_name) + tags.append(tag_name) + return first_info, first_asset, tags + +def list_tags_with_usage( + session: Session, + prefix: str | None = None, + limit: int = 100, + offset: int = 0, + include_zero: bool = True, + order: str = "count_desc", + owner_id: str = "", +) -> tuple[list[tuple[str, str, int]], int]: + counts_sq = ( + select( + AssetInfoTag.tag_name.label("tag_name"), + func.count(AssetInfoTag.asset_info_id).label("cnt"), + ) + .select_from(AssetInfoTag) + .join(AssetInfo, AssetInfo.id == AssetInfoTag.asset_info_id) + .where(visible_owner_clause(owner_id)) + .group_by(AssetInfoTag.tag_name) + .subquery() + ) + + q = ( + select( + Tag.name, + Tag.tag_type, + func.coalesce(counts_sq.c.cnt, 0).label("count"), + ) + .select_from(Tag) + .join(counts_sq, counts_sq.c.tag_name == Tag.name, isouter=True) + ) + + if prefix: + escaped, esc = escape_like_prefix(prefix.strip().lower()) + q = q.where(Tag.name.like(escaped + "%", escape=esc)) + + if not include_zero: + q = q.where(func.coalesce(counts_sq.c.cnt, 0) > 0) + + if order == "name_asc": + q = q.order_by(Tag.name.asc()) + else: + q = q.order_by(func.coalesce(counts_sq.c.cnt, 0).desc(), Tag.name.asc()) + + total_q = select(func.count()).select_from(Tag) + if prefix: + escaped, esc = escape_like_prefix(prefix.strip().lower()) + total_q = total_q.where(Tag.name.like(escaped + "%", escape=esc)) + if not include_zero: + total_q = total_q.where( + Tag.name.in_(select(AssetInfoTag.tag_name).group_by(AssetInfoTag.tag_name)) + ) + + rows = (session.execute(q.limit(limit).offset(offset))).all() + total = (session.execute(total_q)).scalar_one() + + rows_norm = [(name, ttype, int(count or 0)) for (name, ttype, count) in rows] + return rows_norm, int(total or 0) diff --git a/app/assets/database/tags.py b/app/assets/database/tags.py new file mode 100644 index 000000000..3ab6497c2 --- /dev/null +++ b/app/assets/database/tags.py @@ -0,0 +1,62 @@ +from typing import Iterable + +import sqlalchemy +from sqlalchemy.orm import Session +from sqlalchemy.dialects import sqlite + +from app.assets.helpers import normalize_tags, utcnow +from app.assets.database.models import Tag, AssetInfoTag, AssetInfo + + +def ensure_tags_exist(session: Session, names: Iterable[str], tag_type: str = "user") -> None: + wanted = normalize_tags(list(names)) + if not wanted: + return + rows = [{"name": n, "tag_type": tag_type} for n in list(dict.fromkeys(wanted))] + ins = ( + sqlite.insert(Tag) + .values(rows) + .on_conflict_do_nothing(index_elements=[Tag.name]) + ) + return session.execute(ins) + +def add_missing_tag_for_asset_id( + session: Session, + *, + asset_id: str, + origin: str = "automatic", +) -> None: + select_rows = ( + sqlalchemy.select( + AssetInfo.id.label("asset_info_id"), + sqlalchemy.literal("missing").label("tag_name"), + sqlalchemy.literal(origin).label("origin"), + sqlalchemy.literal(utcnow()).label("added_at"), + ) + .where(AssetInfo.asset_id == asset_id) + .where( + sqlalchemy.not_( + sqlalchemy.exists().where((AssetInfoTag.asset_info_id == AssetInfo.id) & (AssetInfoTag.tag_name == "missing")) + ) + ) + ) + session.execute( + sqlite.insert(AssetInfoTag) + .from_select( + ["asset_info_id", "tag_name", "origin", "added_at"], + select_rows, + ) + .on_conflict_do_nothing(index_elements=[AssetInfoTag.asset_info_id, AssetInfoTag.tag_name]) + ) + +def remove_missing_tag_for_asset_id( + session: Session, + *, + asset_id: str, +) -> None: + session.execute( + sqlalchemy.delete(AssetInfoTag).where( + AssetInfoTag.asset_info_id.in_(sqlalchemy.select(AssetInfo.id).where(AssetInfo.asset_id == asset_id)), + AssetInfoTag.tag_name == "missing", + ) + ) diff --git a/app/assets/hashing.py b/app/assets/hashing.py new file mode 100644 index 000000000..4b72084b9 --- /dev/null +++ b/app/assets/hashing.py @@ -0,0 +1,75 @@ +from blake3 import blake3 +from typing import IO +import os +import asyncio + + +DEFAULT_CHUNK = 8 * 1024 *1024 # 8MB + +# NOTE: this allows hashing different representations of a file-like object +def blake3_hash( + fp: str | IO[bytes], + chunk_size: int = DEFAULT_CHUNK, +) -> str: + """ + Returns a BLAKE3 hex digest for ``fp``, which may be: + - a filename (str/bytes) or PathLike + - an open binary file object + If ``fp`` is a file object, it must be opened in **binary** mode and support + ``read``, ``seek``, and ``tell``. The function will seek to the start before + reading and will attempt to restore the original position afterward. + """ + # duck typing to check if input is a file-like object + if hasattr(fp, "read"): + return _hash_file_obj(fp, chunk_size) + + with open(os.fspath(fp), "rb") as f: + return _hash_file_obj(f, chunk_size) + + +async def blake3_hash_async( + fp: str | IO[bytes], + chunk_size: int = DEFAULT_CHUNK, +) -> str: + """Async wrapper for ``blake3_hash_sync``. + Uses a worker thread so the event loop remains responsive. + """ + # If it is a path, open inside the worker thread to keep I/O off the loop. + if hasattr(fp, "read"): + return await asyncio.to_thread(blake3_hash, fp, chunk_size) + + def _worker() -> str: + with open(os.fspath(fp), "rb") as f: + return _hash_file_obj(f, chunk_size) + + return await asyncio.to_thread(_worker) + + +def _hash_file_obj(file_obj: IO, chunk_size: int = DEFAULT_CHUNK) -> str: + """ + Hash an already-open binary file object by streaming in chunks. + - Seeks to the beginning before reading (if supported). + - Restores the original position afterward (if tell/seek are supported). + """ + if chunk_size <= 0: + chunk_size = DEFAULT_CHUNK + + # in case file object is already open and not at the beginning, track so can be restored after hashing + orig_pos = file_obj.tell() + + try: + # seek to the beginning before reading + if orig_pos != 0: + file_obj.seek(0) + + h = blake3() + while True: + chunk = file_obj.read(chunk_size) + if not chunk: + break + h.update(chunk) + return h.hexdigest() + finally: + # restore original position in file object, if needed + if orig_pos != 0: + file_obj.seek(orig_pos) diff --git a/app/assets/helpers.py b/app/assets/helpers.py new file mode 100644 index 000000000..6755d0e56 --- /dev/null +++ b/app/assets/helpers.py @@ -0,0 +1,216 @@ +import contextlib +import os +from aiohttp import web +from datetime import datetime, timezone +from pathlib import Path +from typing import Literal, Any + +import folder_paths + + +RootType = Literal["models", "input", "output"] +ALLOWED_ROOTS: tuple[RootType, ...] = ("models", "input", "output") + +def get_query_dict(request: web.Request) -> dict[str, Any]: + """ + Gets a dictionary of query parameters from the request. + + 'request.query' is a MultiMapping[str], needs to be converted to a dictionary to be validated by Pydantic. + """ + query_dict = { + key: request.query.getall(key) if len(request.query.getall(key)) > 1 else request.query.get(key) + for key in request.query.keys() + } + return query_dict + +def list_tree(base_dir: str) -> list[str]: + out: list[str] = [] + base_abs = os.path.abspath(base_dir) + if not os.path.isdir(base_abs): + return out + for dirpath, _subdirs, filenames in os.walk(base_abs, topdown=True, followlinks=False): + for name in filenames: + out.append(os.path.abspath(os.path.join(dirpath, name))) + return out + +def prefixes_for_root(root: RootType) -> list[str]: + if root == "models": + bases: list[str] = [] + for _bucket, paths in get_comfy_models_folders(): + bases.extend(paths) + return [os.path.abspath(p) for p in bases] + if root == "input": + return [os.path.abspath(folder_paths.get_input_directory())] + if root == "output": + return [os.path.abspath(folder_paths.get_output_directory())] + return [] + +def escape_like_prefix(s: str, escape: str = "!") -> tuple[str, str]: + """Escapes %, _ and the escape char itself in a LIKE prefix. + Returns (escaped_prefix, escape_char). Caller should append '%' and pass escape=escape_char to .like(). + """ + s = s.replace(escape, escape + escape) # escape the escape char first + s = s.replace("%", escape + "%").replace("_", escape + "_") # escape LIKE wildcards + return s, escape + +def fast_asset_file_check( + *, + mtime_db: int | None, + size_db: int | None, + stat_result: os.stat_result, +) -> bool: + if mtime_db is None: + return False + actual_mtime_ns = getattr(stat_result, "st_mtime_ns", int(stat_result.st_mtime * 1_000_000_000)) + if int(mtime_db) != int(actual_mtime_ns): + return False + sz = int(size_db or 0) + if sz > 0: + return int(stat_result.st_size) == sz + return True + +def utcnow() -> datetime: + """Naive UTC timestamp (no tzinfo). We always treat DB datetimes as UTC.""" + return datetime.now(timezone.utc).replace(tzinfo=None) + +def get_comfy_models_folders() -> list[tuple[str, list[str]]]: + """Build a list of (folder_name, base_paths[]) categories that are configured for model locations. + + We trust `folder_paths.folder_names_and_paths` and include a category if + *any* of its base paths lies under the Comfy `models_dir`. + """ + targets: list[tuple[str, list[str]]] = [] + models_root = os.path.abspath(folder_paths.models_dir) + for name, (paths, _exts) in folder_paths.folder_names_and_paths.items(): + if any(os.path.abspath(p).startswith(models_root + os.sep) for p in paths): + targets.append((name, paths)) + return targets + +def compute_relative_filename(file_path: str) -> str | None: + """ + Return the model's path relative to the last well-known folder (the model category), + using forward slashes, eg: + /.../models/checkpoints/flux/123/flux.safetensors -> "flux/123/flux.safetensors" + /.../models/text_encoders/clip_g.safetensors -> "clip_g.safetensors" + + For non-model paths, returns None. + NOTE: this is a temporary helper, used only for initializing metadata["filename"] field. + """ + try: + root_category, rel_path = get_relative_to_root_category_path_of_asset(file_path) + except ValueError: + return None + + p = Path(rel_path) + parts = [seg for seg in p.parts if seg not in (".", "..", p.anchor)] + if not parts: + return None + + if root_category == "models": + # parts[0] is the category ("checkpoints", "vae", etc) – drop it + inside = parts[1:] if len(parts) > 1 else [parts[0]] + return "/".join(inside) + return "/".join(parts) # input/output: keep all parts + + +def get_relative_to_root_category_path_of_asset(file_path: str) -> tuple[Literal["input", "output", "models"], str]: + """Given an absolute or relative file path, determine which root category the path belongs to: + - 'input' if the file resides under `folder_paths.get_input_directory()` + - 'output' if the file resides under `folder_paths.get_output_directory()` + - 'models' if the file resides under any base path of categories returned by `get_comfy_models_folders()` + + Returns: + (root_category, relative_path_inside_that_root) + For 'models', the relative path is prefixed with the category name: + e.g. ('models', 'vae/test/sub/ae.safetensors') + + Raises: + ValueError: if the path does not belong to input, output, or configured model bases. + """ + fp_abs = os.path.abspath(file_path) + + def _is_within(child: str, parent: str) -> bool: + try: + return os.path.commonpath([child, parent]) == parent + except Exception: + return False + + def _rel(child: str, parent: str) -> str: + return os.path.relpath(os.path.join(os.sep, os.path.relpath(child, parent)), os.sep) + + # 1) input + input_base = os.path.abspath(folder_paths.get_input_directory()) + if _is_within(fp_abs, input_base): + return "input", _rel(fp_abs, input_base) + + # 2) output + output_base = os.path.abspath(folder_paths.get_output_directory()) + if _is_within(fp_abs, output_base): + return "output", _rel(fp_abs, output_base) + + # 3) models (check deepest matching base to avoid ambiguity) + best: tuple[int, str, str] | None = None # (base_len, bucket, rel_inside_bucket) + for bucket, bases in get_comfy_models_folders(): + for b in bases: + base_abs = os.path.abspath(b) + if not _is_within(fp_abs, base_abs): + continue + cand = (len(base_abs), bucket, _rel(fp_abs, base_abs)) + if best is None or cand[0] > best[0]: + best = cand + + if best is not None: + _, bucket, rel_inside = best + combined = os.path.join(bucket, rel_inside) + return "models", os.path.relpath(os.path.join(os.sep, combined), os.sep) + + raise ValueError(f"Path is not within input, output, or configured model bases: {file_path}") + +def get_name_and_tags_from_asset_path(file_path: str) -> tuple[str, list[str]]: + """Return a tuple (name, tags) derived from a filesystem path. + + Semantics: + - Root category is determined by `get_relative_to_root_category_path_of_asset`. + - The returned `name` is the base filename with extension from the relative path. + - The returned `tags` are: + [root_category] + parent folders of the relative path (in order) + For 'models', this means: + file '/.../ModelsDir/vae/test_tag/ae.safetensors' + -> root_category='models', some_path='vae/test_tag/ae.safetensors' + -> name='ae.safetensors', tags=['models', 'vae', 'test_tag'] + + Raises: + ValueError: if the path does not belong to input, output, or configured model bases. + """ + root_category, some_path = get_relative_to_root_category_path_of_asset(file_path) + p = Path(some_path) + parent_parts = [part for part in p.parent.parts if part not in (".", "..", p.anchor)] + return p.name, list(dict.fromkeys(normalize_tags([root_category, *parent_parts]))) + +def normalize_tags(tags: list[str] | None) -> list[str]: + """ + Normalize a list of tags by: + - Stripping whitespace and converting to lowercase. + - Removing duplicates. + """ + return [t.strip().lower() for t in (tags or []) if (t or "").strip()] + +def collect_models_files() -> list[str]: + out: list[str] = [] + for folder_name, bases in get_comfy_models_folders(): + rel_files = folder_paths.get_filename_list(folder_name) or [] + for rel_path in rel_files: + abs_path = folder_paths.get_full_path(folder_name, rel_path) + if not abs_path: + continue + abs_path = os.path.abspath(abs_path) + allowed = False + for b in bases: + base_abs = os.path.abspath(b) + with contextlib.suppress(Exception): + if os.path.commonpath([abs_path, base_abs]) == base_abs: + allowed = True + break + if allowed: + out.append(abs_path) + return out diff --git a/app/assets/manager.py b/app/assets/manager.py new file mode 100644 index 000000000..6425e7aa2 --- /dev/null +++ b/app/assets/manager.py @@ -0,0 +1,123 @@ +from typing import Sequence + +from app.database.db import create_session +from app.assets.api import schemas_out +from app.assets.database.queries import ( + asset_exists_by_hash, + fetch_asset_info_asset_and_tags, + list_asset_infos_page, + list_tags_with_usage, +) + + +def _safe_sort_field(requested: str | None) -> str: + if not requested: + return "created_at" + v = requested.lower() + if v in {"name", "created_at", "updated_at", "size", "last_access_time"}: + return v + return "created_at" + + +def asset_exists(asset_hash: str) -> bool: + with create_session() as session: + return asset_exists_by_hash(session, asset_hash=asset_hash) + +def list_assets( + include_tags: Sequence[str] | None = None, + exclude_tags: Sequence[str] | None = None, + name_contains: str | None = None, + metadata_filter: dict | None = None, + limit: int = 20, + offset: int = 0, + sort: str = "created_at", + order: str = "desc", + owner_id: str = "", +) -> schemas_out.AssetsList: + sort = _safe_sort_field(sort) + order = "desc" if (order or "desc").lower() not in {"asc", "desc"} else order.lower() + + with create_session() as session: + infos, tag_map, total = list_asset_infos_page( + session, + owner_id=owner_id, + include_tags=include_tags, + exclude_tags=exclude_tags, + name_contains=name_contains, + metadata_filter=metadata_filter, + limit=limit, + offset=offset, + sort=sort, + order=order, + ) + + summaries: list[schemas_out.AssetSummary] = [] + for info in infos: + asset = info.asset + tags = tag_map.get(info.id, []) + summaries.append( + schemas_out.AssetSummary( + id=info.id, + name=info.name, + asset_hash=asset.hash if asset else None, + size=int(asset.size_bytes) if asset else None, + mime_type=asset.mime_type if asset else None, + tags=tags, + preview_url=f"/api/assets/{info.id}/content", + created_at=info.created_at, + updated_at=info.updated_at, + last_access_time=info.last_access_time, + ) + ) + + return schemas_out.AssetsList( + assets=summaries, + total=total, + has_more=(offset + len(summaries)) < total, + ) + +def get_asset(asset_info_id: str, owner_id: str = "") -> schemas_out.AssetDetail: + with create_session() as session: + res = fetch_asset_info_asset_and_tags(session, asset_info_id=asset_info_id, owner_id=owner_id) + if not res: + raise ValueError(f"AssetInfo {asset_info_id} not found") + info, asset, tag_names = res + preview_id = info.preview_id + + return schemas_out.AssetDetail( + id=info.id, + name=info.name, + asset_hash=asset.hash if asset else None, + size=int(asset.size_bytes) if asset and asset.size_bytes is not None else None, + mime_type=asset.mime_type if asset else None, + tags=tag_names, + user_metadata=info.user_metadata or {}, + preview_id=preview_id, + created_at=info.created_at, + last_access_time=info.last_access_time, + ) + +def list_tags( + prefix: str | None = None, + limit: int = 100, + offset: int = 0, + order: str = "count_desc", + include_zero: bool = True, + owner_id: str = "", +) -> schemas_out.TagsList: + limit = max(1, min(1000, limit)) + offset = max(0, offset) + + with create_session() as session: + rows, total = list_tags_with_usage( + session, + prefix=prefix, + limit=limit, + offset=offset, + include_zero=include_zero, + order=order, + owner_id=owner_id, + ) + + tags = [schemas_out.TagUsage(name=name, count=count, type=tag_type) for (name, tag_type, count) in rows] + return schemas_out.TagsList(tags=tags, total=total, has_more=(offset + len(tags)) < total) diff --git a/app/assets/scanner.py b/app/assets/scanner.py new file mode 100644 index 000000000..a16e41d94 --- /dev/null +++ b/app/assets/scanner.py @@ -0,0 +1,229 @@ +import contextlib +import time +import logging +import os +import sqlalchemy + +import folder_paths +from app.database.db import create_session, dependencies_available +from app.assets.helpers import ( + collect_models_files, compute_relative_filename, fast_asset_file_check, get_name_and_tags_from_asset_path, + list_tree,prefixes_for_root, escape_like_prefix, + RootType +) +from app.assets.database.tags import add_missing_tag_for_asset_id, ensure_tags_exist, remove_missing_tag_for_asset_id +from app.assets.database.bulk_ops import seed_from_paths_batch +from app.assets.database.models import Asset, AssetCacheState, AssetInfo + + +def seed_assets(roots: tuple[RootType, ...], enable_logging: bool = False) -> None: + """ + Scan the given roots and seed the assets into the database. + """ + if not dependencies_available(): + if enable_logging: + logging.warning("Database dependencies not available, skipping assets scan") + return + t_start = time.perf_counter() + created = 0 + skipped_existing = 0 + paths: list[str] = [] + try: + existing_paths: set[str] = set() + for r in roots: + try: + survivors: set[str] = _fast_db_consistency_pass(r, collect_existing_paths=True, update_missing_tags=True) + if survivors: + existing_paths.update(survivors) + except Exception as e: + logging.exception("fast DB scan failed for %s: %s", r, e) + + if "models" in roots: + paths.extend(collect_models_files()) + if "input" in roots: + paths.extend(list_tree(folder_paths.get_input_directory())) + if "output" in roots: + paths.extend(list_tree(folder_paths.get_output_directory())) + + specs: list[dict] = [] + tag_pool: set[str] = set() + for p in paths: + abs_p = os.path.abspath(p) + if abs_p in existing_paths: + skipped_existing += 1 + continue + try: + stat_p = os.stat(abs_p, follow_symlinks=False) + except OSError: + continue + # skip empty files + if not stat_p.st_size: + continue + name, tags = get_name_and_tags_from_asset_path(abs_p) + specs.append( + { + "abs_path": abs_p, + "size_bytes": stat_p.st_size, + "mtime_ns": getattr(stat_p, "st_mtime_ns", int(stat_p.st_mtime * 1_000_000_000)), + "info_name": name, + "tags": tags, + "fname": compute_relative_filename(abs_p), + } + ) + for t in tags: + tag_pool.add(t) + # if no file specs, nothing to do + if not specs: + return + with create_session() as sess: + if tag_pool: + ensure_tags_exist(sess, tag_pool, tag_type="user") + + result = seed_from_paths_batch(sess, specs=specs, owner_id="") + created += result["inserted_infos"] + sess.commit() + finally: + if enable_logging: + logging.info( + "Assets scan(roots=%s) completed in %.3fs (created=%d, skipped_existing=%d, total_seen=%d)", + roots, + time.perf_counter() - t_start, + created, + skipped_existing, + len(paths), + ) + + +def _fast_db_consistency_pass( + root: RootType, + *, + collect_existing_paths: bool = False, + update_missing_tags: bool = False, +) -> set[str] | None: + """Fast DB+FS pass for a root: + - Toggle needs_verify per state using fast check + - For hashed assets with at least one fast-ok state in this root: delete stale missing states + - For seed assets with all states missing: delete Asset and its AssetInfos + - Optionally add/remove 'missing' tags based on fast-ok in this root + - Optionally return surviving absolute paths + """ + prefixes = prefixes_for_root(root) + if not prefixes: + return set() if collect_existing_paths else None + + conds = [] + for p in prefixes: + base = os.path.abspath(p) + if not base.endswith(os.sep): + base += os.sep + escaped, esc = escape_like_prefix(base) + conds.append(AssetCacheState.file_path.like(escaped + "%", escape=esc)) + + with create_session() as sess: + rows = ( + sess.execute( + sqlalchemy.select( + AssetCacheState.id, + AssetCacheState.file_path, + AssetCacheState.mtime_ns, + AssetCacheState.needs_verify, + AssetCacheState.asset_id, + Asset.hash, + Asset.size_bytes, + ) + .join(Asset, Asset.id == AssetCacheState.asset_id) + .where(sqlalchemy.or_(*conds)) + .order_by(AssetCacheState.asset_id.asc(), AssetCacheState.id.asc()) + ) + ).all() + + by_asset: dict[str, dict] = {} + for sid, fp, mtime_db, needs_verify, aid, a_hash, a_size in rows: + acc = by_asset.get(aid) + if acc is None: + acc = {"hash": a_hash, "size_db": int(a_size or 0), "states": []} + by_asset[aid] = acc + + fast_ok = False + try: + exists = True + fast_ok = fast_asset_file_check( + mtime_db=mtime_db, + size_db=acc["size_db"], + stat_result=os.stat(fp, follow_symlinks=True), + ) + except FileNotFoundError: + exists = False + except OSError: + exists = False + + acc["states"].append({ + "sid": sid, + "fp": fp, + "exists": exists, + "fast_ok": fast_ok, + "needs_verify": bool(needs_verify), + }) + + to_set_verify: list[int] = [] + to_clear_verify: list[int] = [] + stale_state_ids: list[int] = [] + survivors: set[str] = set() + + for aid, acc in by_asset.items(): + a_hash = acc["hash"] + states = acc["states"] + any_fast_ok = any(s["fast_ok"] for s in states) + all_missing = all(not s["exists"] for s in states) + + for s in states: + if not s["exists"]: + continue + if s["fast_ok"] and s["needs_verify"]: + to_clear_verify.append(s["sid"]) + if not s["fast_ok"] and not s["needs_verify"]: + to_set_verify.append(s["sid"]) + + if a_hash is None: + if states and all_missing: # remove seed Asset completely, if no valid AssetCache exists + sess.execute(sqlalchemy.delete(AssetInfo).where(AssetInfo.asset_id == aid)) + asset = sess.get(Asset, aid) + if asset: + sess.delete(asset) + else: + for s in states: + if s["exists"]: + survivors.add(os.path.abspath(s["fp"])) + continue + + if any_fast_ok: # if Asset has at least one valid AssetCache record, remove any invalid AssetCache records + for s in states: + if not s["exists"]: + stale_state_ids.append(s["sid"]) + if update_missing_tags: + with contextlib.suppress(Exception): + remove_missing_tag_for_asset_id(sess, asset_id=aid) + elif update_missing_tags: + with contextlib.suppress(Exception): + add_missing_tag_for_asset_id(sess, asset_id=aid, origin="automatic") + + for s in states: + if s["exists"]: + survivors.add(os.path.abspath(s["fp"])) + + if stale_state_ids: + sess.execute(sqlalchemy.delete(AssetCacheState).where(AssetCacheState.id.in_(stale_state_ids))) + if to_set_verify: + sess.execute( + sqlalchemy.update(AssetCacheState) + .where(AssetCacheState.id.in_(to_set_verify)) + .values(needs_verify=True) + ) + if to_clear_verify: + sess.execute( + sqlalchemy.update(AssetCacheState) + .where(AssetCacheState.id.in_(to_clear_verify)) + .values(needs_verify=False) + ) + sess.commit() + return survivors if collect_existing_paths else None diff --git a/app/database/models.py b/app/database/models.py index 6facfb8f2..e7572677a 100644 --- a/app/database/models.py +++ b/app/database/models.py @@ -1,14 +1,21 @@ -from sqlalchemy.orm import declarative_base +from typing import Any +from datetime import datetime +from sqlalchemy.orm import DeclarativeBase -Base = declarative_base() +class Base(DeclarativeBase): + pass - -def to_dict(obj): +def to_dict(obj: Any, include_none: bool = False) -> dict[str, Any]: fields = obj.__table__.columns.keys() - return { - field: (val.to_dict() if hasattr(val, "to_dict") else val) - for field in fields - if (val := getattr(obj, field)) - } + out: dict[str, Any] = {} + for field in fields: + val = getattr(obj, field) + if val is None and not include_none: + continue + if isinstance(val, datetime): + out[field] = val.isoformat() + else: + out[field] = val + return out # TODO: Define models here diff --git a/comfy/cli_args.py b/comfy/cli_args.py index dae9a895d..1716c3de7 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -231,6 +231,7 @@ database_default_path = os.path.abspath( os.path.join(os.path.dirname(__file__), "..", "user", "comfyui.db") ) parser.add_argument("--database-url", type=str, default=f"sqlite:///{database_default_path}", help="Specify the database URL, e.g. for an in-memory database you can use 'sqlite:///:memory:'.") +parser.add_argument("--disable-assets-autoscan", action="store_true", help="Disable asset scanning on startup for database synchronization.") if comfy.options.args_parsing: args = parser.parse_args() diff --git a/main.py b/main.py index 0e07a95da..37b06c1fa 100644 --- a/main.py +++ b/main.py @@ -7,6 +7,7 @@ import folder_paths import time from comfy.cli_args import args from app.logger import setup_logger +from app.assets.scanner import seed_assets import itertools import utils.extra_config import logging @@ -324,6 +325,8 @@ def setup_database(): from app.database.db import init_db, dependencies_available if dependencies_available(): init_db() + if not args.disable_assets_autoscan: + seed_assets(["models"], enable_logging=True) except Exception as e: logging.error(f"Failed to initialize database. Please ensure you have installed the latest requirements. If the error persists, please report this as in future the database will be required: {e}") diff --git a/server.py b/server.py index 4db3347cb..da2baefd4 100644 --- a/server.py +++ b/server.py @@ -33,6 +33,8 @@ import node_helpers from comfyui_version import __version__ from app.frontend_management import FrontendManager, parse_version from comfy_api.internal import _ComfyNodeInternal +from app.assets.scanner import seed_assets +from app.assets.api.routes import register_assets_system from app.user_manager import UserManager from app.model_manager import ModelFileManager @@ -235,6 +237,7 @@ class PromptServer(): else args.front_end_root ) logging.info(f"[Prompt Server] web root: {self.web_root}") + register_assets_system(self.app, self.user_manager) routes = web.RouteTableDef() self.routes = routes self.last_node_id = None @@ -683,6 +686,7 @@ class PromptServer(): @routes.get("/object_info") async def get_object_info(request): + seed_assets(["models"]) with folder_paths.cache_helper: out = {} for x in nodes.NODE_CLASS_MAPPINGS: From 6207f86c18d2cf2d70ab059987b62d4b38466e77 Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Thu, 8 Jan 2026 20:34:48 -0800 Subject: [PATCH 036/119] Fix VAEEncodeForInpaint to support WAN VAE tuple downscale_ratio (#11572) Use vae.spacial_compression_encode() instead of directly accessing downscale_ratio to handle both standard VAEs (int) and WAN VAEs (tuple). Addresses reviewer feedback on PR #11259. Co-authored-by: ChrisFab16 --- nodes.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/nodes.py b/nodes.py index 56b74ebe3..1aa391f4a 100644 --- a/nodes.py +++ b/nodes.py @@ -378,14 +378,15 @@ class VAEEncodeForInpaint: CATEGORY = "latent/inpaint" def encode(self, vae, pixels, mask, grow_mask_by=6): - x = (pixels.shape[1] // vae.downscale_ratio) * vae.downscale_ratio - y = (pixels.shape[2] // vae.downscale_ratio) * vae.downscale_ratio + downscale_ratio = vae.spacial_compression_encode() + x = (pixels.shape[1] // downscale_ratio) * downscale_ratio + y = (pixels.shape[2] // downscale_ratio) * downscale_ratio mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(pixels.shape[1], pixels.shape[2]), mode="bilinear") pixels = pixels.clone() if pixels.shape[1] != x or pixels.shape[2] != y: - x_offset = (pixels.shape[1] % vae.downscale_ratio) // 2 - y_offset = (pixels.shape[2] % vae.downscale_ratio) // 2 + x_offset = (pixels.shape[1] % downscale_ratio) // 2 + y_offset = (pixels.shape[2] % downscale_ratio) // 2 pixels = pixels[:,x_offset:x + x_offset, y_offset:y + y_offset,:] mask = mask[:,:,x_offset:x + x_offset, y_offset:y + y_offset] From 4609fcd26081156eef921bd9f43726f670ee6f51 Mon Sep 17 00:00:00 2001 From: Terry Jia Date: Fri, 9 Jan 2026 00:31:19 -0500 Subject: [PATCH 037/119] add node - image compare (#11343) --- comfy_api/latest/_io.py | 13 +++++++ comfy_extras/nodes_image_compare.py | 53 +++++++++++++++++++++++++++++ nodes.py | 1 + 3 files changed, 67 insertions(+) create mode 100644 comfy_extras/nodes_image_compare.py diff --git a/comfy_api/latest/_io.py b/comfy_api/latest/_io.py index 764fa8b2b..50143ff53 100644 --- a/comfy_api/latest/_io.py +++ b/comfy_api/latest/_io.py @@ -1113,6 +1113,18 @@ class DynamicSlot(ComfyTypeI): out_dict[input_type][finalized_id] = value out_dict["dynamic_paths"][finalized_id] = finalize_prefix(curr_prefix, curr_prefix[-1]) +@comfytype(io_type="IMAGECOMPARE") +class ImageCompare(ComfyTypeI): + Type = dict + + class Input(WidgetInput): + def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, + socketless: bool=True): + super().__init__(id, display_name, optional, tooltip, None, None, socketless) + + def as_dict(self): + return super().as_dict() + DYNAMIC_INPUT_LOOKUP: dict[str, Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]] = {} def register_dynamic_input_func(io_type: str, func: Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]): DYNAMIC_INPUT_LOOKUP[io_type] = func @@ -1958,4 +1970,5 @@ __all__ = [ "add_to_dict_v1", "add_to_dict_v3", "V3Data", + "ImageCompare", ] diff --git a/comfy_extras/nodes_image_compare.py b/comfy_extras/nodes_image_compare.py new file mode 100644 index 000000000..8e9f809e6 --- /dev/null +++ b/comfy_extras/nodes_image_compare.py @@ -0,0 +1,53 @@ +import nodes + +from typing_extensions import override +from comfy_api.latest import IO, ComfyExtension + + +class ImageCompare(IO.ComfyNode): + """Compares two images with a slider interface.""" + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="ImageCompare", + display_name="Image Compare", + description="Compares two images side by side with a slider.", + category="image", + is_experimental=True, + is_output_node=True, + inputs=[ + IO.Image.Input("image_a", optional=True), + IO.Image.Input("image_b", optional=True), + IO.ImageCompare.Input("compare_view"), + ], + outputs=[], + ) + + @classmethod + def execute(cls, image_a=None, image_b=None, compare_view=None) -> IO.NodeOutput: + result = {"a_images": [], "b_images": []} + + preview_node = nodes.PreviewImage() + + if image_a is not None and len(image_a) > 0: + saved = preview_node.save_images(image_a, "comfy.compare.a") + result["a_images"] = saved["ui"]["images"] + + if image_b is not None and len(image_b) > 0: + saved = preview_node.save_images(image_b, "comfy.compare.b") + result["b_images"] = saved["ui"]["images"] + + return IO.NodeOutput(ui=result) + + +class ImageCompareExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[IO.ComfyNode]]: + return [ + ImageCompare, + ] + + +async def comfy_entrypoint() -> ImageCompareExtension: + return ImageCompareExtension() diff --git a/nodes.py b/nodes.py index 1aa391f4a..5a9d42d4a 100644 --- a/nodes.py +++ b/nodes.py @@ -2370,6 +2370,7 @@ async def init_builtin_extra_nodes(): "nodes_nop.py", "nodes_kandinsky5.py", "nodes_wanmove.py", + "nodes_image_compare.py", ] import_failed = [] From 04c49a29b493f3f9037b83cec45f6369b5c4816b Mon Sep 17 00:00:00 2001 From: ric-yu Date: Thu, 8 Jan 2026 21:57:36 -0800 Subject: [PATCH 038/119] feat: add cancelled filter to /jobs (#11680) --- comfy_execution/jobs.py | 31 +++++++++++++++++------------ tests/execution/test_jobs.py | 38 +++++++++++++++++++++++++++++++++++- 2 files changed, 55 insertions(+), 14 deletions(-) diff --git a/comfy_execution/jobs.py b/comfy_execution/jobs.py index 59fb49357..97fd803b8 100644 --- a/comfy_execution/jobs.py +++ b/comfy_execution/jobs.py @@ -14,8 +14,9 @@ class JobStatus: IN_PROGRESS = 'in_progress' COMPLETED = 'completed' FAILED = 'failed' + CANCELLED = 'cancelled' - ALL = [PENDING, IN_PROGRESS, COMPLETED, FAILED] + ALL = [PENDING, IN_PROGRESS, COMPLETED, FAILED, CANCELLED] # Media types that can be previewed in the frontend @@ -94,12 +95,6 @@ def normalize_history_item(prompt_id: str, history_item: dict, include_outputs: status_info = history_item.get('status', {}) status_str = status_info.get('status_str') if status_info else None - if status_str == 'success': - status = JobStatus.COMPLETED - elif status_str == 'error': - status = JobStatus.FAILED - else: - status = JobStatus.COMPLETED outputs = history_item.get('outputs', {}) outputs_count, preview_output = get_outputs_summary(outputs) @@ -107,6 +102,7 @@ def normalize_history_item(prompt_id: str, history_item: dict, include_outputs: execution_error = None execution_start_time = None execution_end_time = None + was_interrupted = False if status_info: messages = status_info.get('messages', []) for entry in messages: @@ -119,6 +115,15 @@ def normalize_history_item(prompt_id: str, history_item: dict, include_outputs: execution_end_time = event_data.get('timestamp') if event_name == 'execution_error': execution_error = event_data + elif event_name == 'execution_interrupted': + was_interrupted = True + + if status_str == 'success': + status = JobStatus.COMPLETED + elif status_str == 'error': + status = JobStatus.CANCELLED if was_interrupted else JobStatus.FAILED + else: + status = JobStatus.COMPLETED job = prune_dict({ 'id': prompt_id, @@ -268,13 +273,13 @@ def get_all_jobs( for item in queued: jobs.append(normalize_queue_item(item, JobStatus.PENDING)) - include_completed = JobStatus.COMPLETED in status_filter - include_failed = JobStatus.FAILED in status_filter - if include_completed or include_failed: + history_statuses = {JobStatus.COMPLETED, JobStatus.FAILED, JobStatus.CANCELLED} + requested_history_statuses = history_statuses & set(status_filter) + if requested_history_statuses: for prompt_id, history_item in history.items(): - is_failed = history_item.get('status', {}).get('status_str') == 'error' - if (is_failed and include_failed) or (not is_failed and include_completed): - jobs.append(normalize_history_item(prompt_id, history_item)) + job = normalize_history_item(prompt_id, history_item) + if job.get('status') in requested_history_statuses: + jobs.append(job) if workflow_id: jobs = [j for j in jobs if j.get('workflow_id') == workflow_id] diff --git a/tests/execution/test_jobs.py b/tests/execution/test_jobs.py index 918c8080a..4d2f9ed36 100644 --- a/tests/execution/test_jobs.py +++ b/tests/execution/test_jobs.py @@ -19,6 +19,7 @@ class TestJobStatus: assert JobStatus.IN_PROGRESS == 'in_progress' assert JobStatus.COMPLETED == 'completed' assert JobStatus.FAILED == 'failed' + assert JobStatus.CANCELLED == 'cancelled' def test_all_contains_all_statuses(self): """ALL should contain all status values.""" @@ -26,7 +27,8 @@ class TestJobStatus: assert JobStatus.IN_PROGRESS in JobStatus.ALL assert JobStatus.COMPLETED in JobStatus.ALL assert JobStatus.FAILED in JobStatus.ALL - assert len(JobStatus.ALL) == 4 + assert JobStatus.CANCELLED in JobStatus.ALL + assert len(JobStatus.ALL) == 5 class TestIsPreviewable: @@ -336,6 +338,40 @@ class TestNormalizeHistoryItem: assert job['execution_error']['node_type'] == 'KSampler' assert job['execution_error']['exception_message'] == 'CUDA out of memory' + def test_cancelled_job(self): + """Cancelled/interrupted history item should have cancelled status.""" + history_item = { + 'prompt': ( + 5, + 'prompt-cancelled', + {'nodes': {}}, + {'create_time': 1234567890000}, + ['node1'], + ), + 'status': { + 'status_str': 'error', + 'completed': False, + 'messages': [ + ('execution_start', {'prompt_id': 'prompt-cancelled', 'timestamp': 1234567890500}), + ('execution_interrupted', { + 'prompt_id': 'prompt-cancelled', + 'node_id': '5', + 'node_type': 'KSampler', + 'executed': ['1', '2', '3'], + 'timestamp': 1234567891000, + }) + ] + }, + 'outputs': {}, + } + + job = normalize_history_item('prompt-cancelled', history_item) + assert job['status'] == 'cancelled' + assert job['execution_start_time'] == 1234567890500 + assert job['execution_end_time'] == 1234567891000 + # Cancelled jobs should not have execution_error set + assert 'execution_error' not in job + def test_include_outputs(self): """When include_outputs=True, should include full output data.""" history_item = { From ec0a832acb25fbe53bd4fc25d286a9ee442a3bcf Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Thu, 8 Jan 2026 22:49:12 -0800 Subject: [PATCH 039/119] Add workaround for hacky nodepack(s) that edit folder_names_and_paths to have values with tuples of more than 2. Other things could potentially break with those nodepack(s), so I will hunt for the guilty nodepack(s) now. (#11755) --- app/assets/helpers.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/app/assets/helpers.py b/app/assets/helpers.py index 6755d0e56..08b465b5a 100644 --- a/app/assets/helpers.py +++ b/app/assets/helpers.py @@ -81,7 +81,8 @@ def get_comfy_models_folders() -> list[tuple[str, list[str]]]: """ targets: list[tuple[str, list[str]]] = [] models_root = os.path.abspath(folder_paths.models_dir) - for name, (paths, _exts) in folder_paths.folder_names_and_paths.items(): + for name, values in folder_paths.folder_names_and_paths.items(): + paths, _exts = values[0], values[1] # NOTE: this prevents nodepacks that hackily edit folder_... from breaking ComfyUI if any(os.path.abspath(p).startswith(models_root + os.sep) for p in paths): targets.append((name, paths)) return targets From bd0e6825e84606e0706bbb5645e9ea1f4ad8154d Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Fri, 9 Jan 2026 11:21:06 -0800 Subject: [PATCH 040/119] Be less strict when loading mixed ops weights. (#11769) --- comfy/ops.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/comfy/ops.py b/comfy/ops.py index 8156c42ff..1cf22f0cc 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -546,7 +546,8 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec weight_key = f"{prefix}weight" weight = state_dict.pop(weight_key, None) if weight is None: - raise ValueError(f"Missing weight for layer {layer_name}") + logging.warning(f"Missing weight for layer {layer_name}") + return manually_loaded_keys = [weight_key] From 4484b93d615059012d3a5ce91d1dbbb0cbaa2d76 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Fri, 9 Jan 2026 22:25:56 +0200 Subject: [PATCH 041/119] fix(api-nodes): do not downscale the input image for Topaz Enhance (#11768) --- comfy_api_nodes/nodes_topaz.py | 7 ++++--- comfy_api_nodes/util/upload_helpers.py | 3 ++- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/comfy_api_nodes/nodes_topaz.py b/comfy_api_nodes/nodes_topaz.py index b04575ad8..9dc5f45bc 100644 --- a/comfy_api_nodes/nodes_topaz.py +++ b/comfy_api_nodes/nodes_topaz.py @@ -2,7 +2,6 @@ import builtins from io import BytesIO import aiohttp -import torch from typing_extensions import override from comfy_api.latest import IO, ComfyExtension, Input @@ -138,7 +137,7 @@ class TopazImageEnhance(IO.ComfyNode): async def execute( cls, model: str, - image: torch.Tensor, + image: Input.Image, prompt: str = "", subject_detection: str = "All", face_enhancement: bool = True, @@ -153,7 +152,9 @@ class TopazImageEnhance(IO.ComfyNode): ) -> IO.NodeOutput: if get_number_of_images(image) != 1: raise ValueError("Only one input image is supported.") - download_url = await upload_images_to_comfyapi(cls, image, max_images=1, mime_type="image/png") + download_url = await upload_images_to_comfyapi( + cls, image, max_images=1, mime_type="image/png", total_pixels=4096*4096 + ) initial_response = await sync_op( cls, ApiEndpoint(path="/proxy/topaz/image/v1/enhance-gen/async", method="POST"), diff --git a/comfy_api_nodes/util/upload_helpers.py b/comfy_api_nodes/util/upload_helpers.py index f1ed7fe9c..2535a0884 100644 --- a/comfy_api_nodes/util/upload_helpers.py +++ b/comfy_api_nodes/util/upload_helpers.py @@ -49,6 +49,7 @@ async def upload_images_to_comfyapi( mime_type: str | None = None, wait_label: str | None = "Uploading", show_batch_index: bool = True, + total_pixels: int = 2048 * 2048, ) -> list[str]: """ Uploads images to ComfyUI API and returns download URLs. @@ -63,7 +64,7 @@ async def upload_images_to_comfyapi( for idx in range(num_to_upload): tensor = image[idx] if is_batch else image - img_io = tensor_to_bytesio(tensor, mime_type=mime_type) + img_io = tensor_to_bytesio(tensor, total_pixels=total_pixels, mime_type=mime_type) effective_label = wait_label if wait_label and show_batch_index and num_to_upload > 1: From 393d2880ddc6e57c0fa3b878bb50fa2901bd57e6 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Fri, 9 Jan 2026 22:59:38 +0200 Subject: [PATCH 042/119] feat(api-nodes): added nodes for Vidu2 (#11760) --- comfy_api_nodes/apis/vidu.py | 41 +++ comfy_api_nodes/nodes_vidu.py | 588 +++++++++++++++++++++++++--------- 2 files changed, 482 insertions(+), 147 deletions(-) create mode 100644 comfy_api_nodes/apis/vidu.py diff --git a/comfy_api_nodes/apis/vidu.py b/comfy_api_nodes/apis/vidu.py new file mode 100644 index 000000000..a9bb6f7ce --- /dev/null +++ b/comfy_api_nodes/apis/vidu.py @@ -0,0 +1,41 @@ +from pydantic import BaseModel, Field + + +class SubjectReference(BaseModel): + id: str = Field(...) + images: list[str] = Field(...) + + +class TaskCreationRequest(BaseModel): + model: str = Field(...) + prompt: str = Field(..., max_length=2000) + duration: int = Field(...) + seed: int = Field(..., ge=0, le=2147483647) + aspect_ratio: str | None = Field(None) + resolution: str | None = Field(None) + movement_amplitude: str | None = Field(None) + images: list[str] | None = Field(None, description="Base64 encoded string or image URL") + subjects: list[SubjectReference] | None = Field(None) + bgm: bool | None = Field(None) + audio: bool | None = Field(None) + + +class TaskCreationResponse(BaseModel): + task_id: str = Field(...) + state: str = Field(...) + created_at: str = Field(...) + code: int | None = Field(None, description="Error code") + + +class TaskResult(BaseModel): + id: str = Field(..., description="Creation id") + url: str = Field(..., description="The URL of the generated results, valid for one hour") + cover_url: str = Field(..., description="The cover URL of the generated results, valid for one hour") + + +class TaskStatusResponse(BaseModel): + state: str = Field(...) + err_code: str | None = Field(None) + progress: float | None = Field(None) + credits: int | None = Field(None) + creations: list[TaskResult] = Field(..., description="Generated results") diff --git a/comfy_api_nodes/nodes_vidu.py b/comfy_api_nodes/nodes_vidu.py index 7a679f0d9..9d94ae7ad 100644 --- a/comfy_api_nodes/nodes_vidu.py +++ b/comfy_api_nodes/nodes_vidu.py @@ -1,12 +1,13 @@ -import logging -from enum import Enum -from typing import Literal, Optional, TypeVar - -import torch -from pydantic import BaseModel, Field from typing_extensions import override -from comfy_api.latest import IO, ComfyExtension +from comfy_api.latest import IO, ComfyExtension, Input +from comfy_api_nodes.apis.vidu import ( + SubjectReference, + TaskCreationRequest, + TaskCreationResponse, + TaskResult, + TaskStatusResponse, +) from comfy_api_nodes.util import ( ApiEndpoint, download_url_to_video_output, @@ -17,6 +18,7 @@ from comfy_api_nodes.util import ( validate_image_aspect_ratio, validate_image_dimensions, validate_images_aspect_ratio_closeness, + validate_string, ) VIDU_TEXT_TO_VIDEO = "/proxy/vidu/text2video" @@ -25,98 +27,33 @@ VIDU_REFERENCE_VIDEO = "/proxy/vidu/reference2video" VIDU_START_END_VIDEO = "/proxy/vidu/start-end2video" VIDU_GET_GENERATION_STATUS = "/proxy/vidu/tasks/%s/creations" -R = TypeVar("R") - - -class VideoModelName(str, Enum): - vidu_q1 = "viduq1" - - -class AspectRatio(str, Enum): - r_16_9 = "16:9" - r_9_16 = "9:16" - r_1_1 = "1:1" - - -class Resolution(str, Enum): - r_1080p = "1080p" - - -class MovementAmplitude(str, Enum): - auto = "auto" - small = "small" - medium = "medium" - large = "large" - - -class TaskCreationRequest(BaseModel): - model: VideoModelName = VideoModelName.vidu_q1 - prompt: Optional[str] = Field(None, max_length=1500) - duration: Optional[Literal[5]] = 5 - seed: Optional[int] = Field(0, ge=0, le=2147483647) - aspect_ratio: Optional[AspectRatio] = AspectRatio.r_16_9 - resolution: Optional[Resolution] = Resolution.r_1080p - movement_amplitude: Optional[MovementAmplitude] = MovementAmplitude.auto - images: Optional[list[str]] = Field(None, description="Base64 encoded string or image URL") - - -class TaskCreationResponse(BaseModel): - task_id: str = Field(...) - state: str = Field(...) - created_at: str = Field(...) - code: Optional[int] = Field(None, description="Error code") - - -class TaskResult(BaseModel): - id: str = Field(..., description="Creation id") - url: str = Field(..., description="The URL of the generated results, valid for one hour") - cover_url: str = Field(..., description="The cover URL of the generated results, valid for one hour") - - -class TaskStatusResponse(BaseModel): - state: str = Field(...) - err_code: Optional[str] = Field(None) - creations: list[TaskResult] = Field(..., description="Generated results") - - -def get_video_url_from_response(response) -> Optional[str]: - if response.creations: - return response.creations[0].url - return None - - -def get_video_from_response(response) -> TaskResult: - if not response.creations: - error_msg = f"Vidu request does not contain results. State: {response.state}, Error Code: {response.err_code}" - logging.info(error_msg) - raise RuntimeError(error_msg) - logging.info("Vidu task %s succeeded. Video URL: %s", response.creations[0].id, response.creations[0].url) - return response.creations[0] - async def execute_task( cls: type[IO.ComfyNode], vidu_endpoint: str, payload: TaskCreationRequest, - estimated_duration: int, -) -> R: - response = await sync_op( +) -> list[TaskResult]: + task_creation_response = await sync_op( cls, endpoint=ApiEndpoint(path=vidu_endpoint, method="POST"), response_model=TaskCreationResponse, data=payload, ) - if response.state == "failed": - error_msg = f"Vidu request failed. Code: {response.code}" - logging.error(error_msg) - raise RuntimeError(error_msg) - return await poll_op( + if task_creation_response.state == "failed": + raise RuntimeError(f"Vidu request failed. Code: {task_creation_response.code}") + response = await poll_op( cls, - ApiEndpoint(path=VIDU_GET_GENERATION_STATUS % response.task_id), + ApiEndpoint(path=VIDU_GET_GENERATION_STATUS % task_creation_response.task_id), response_model=TaskStatusResponse, status_extractor=lambda r: r.state, - estimated_duration=estimated_duration, + progress_extractor=lambda r: r.progress, + max_poll_attempts=320, ) + if not response.creations: + raise RuntimeError( + f"Vidu request does not contain results. State: {response.state}, Error Code: {response.err_code}" + ) + return response.creations class ViduTextToVideoNode(IO.ComfyNode): @@ -127,14 +64,9 @@ class ViduTextToVideoNode(IO.ComfyNode): node_id="ViduTextToVideoNode", display_name="Vidu Text To Video Generation", category="api node/video/Vidu", - description="Generate video from text prompt", + description="Generate video from a text prompt", inputs=[ - IO.Combo.Input( - "model", - options=VideoModelName, - default=VideoModelName.vidu_q1, - tooltip="Model name", - ), + IO.Combo.Input("model", options=["viduq1"], tooltip="Model name"), IO.String.Input( "prompt", multiline=True, @@ -163,22 +95,19 @@ class ViduTextToVideoNode(IO.ComfyNode): ), IO.Combo.Input( "aspect_ratio", - options=AspectRatio, - default=AspectRatio.r_16_9, + options=["16:9", "9:16", "1:1"], tooltip="The aspect ratio of the output video", optional=True, ), IO.Combo.Input( "resolution", - options=Resolution, - default=Resolution.r_1080p, + options=["1080p"], tooltip="Supported values may vary by model & duration", optional=True, ), IO.Combo.Input( "movement_amplitude", - options=MovementAmplitude, - default=MovementAmplitude.auto, + options=["auto", "small", "medium", "large"], tooltip="The movement amplitude of objects in the frame", optional=True, ), @@ -208,7 +137,7 @@ class ViduTextToVideoNode(IO.ComfyNode): if not prompt: raise ValueError("The prompt field is required and cannot be empty.") payload = TaskCreationRequest( - model_name=model, + model=model, prompt=prompt, duration=duration, seed=seed, @@ -216,8 +145,8 @@ class ViduTextToVideoNode(IO.ComfyNode): resolution=resolution, movement_amplitude=movement_amplitude, ) - results = await execute_task(cls, VIDU_TEXT_TO_VIDEO, payload, 320) - return IO.NodeOutput(await download_url_to_video_output(get_video_from_response(results).url)) + results = await execute_task(cls, VIDU_TEXT_TO_VIDEO, payload) + return IO.NodeOutput(await download_url_to_video_output(results[0].url)) class ViduImageToVideoNode(IO.ComfyNode): @@ -230,12 +159,7 @@ class ViduImageToVideoNode(IO.ComfyNode): category="api node/video/Vidu", description="Generate video from image and optional prompt", inputs=[ - IO.Combo.Input( - "model", - options=VideoModelName, - default=VideoModelName.vidu_q1, - tooltip="Model name", - ), + IO.Combo.Input("model", options=["viduq1"], tooltip="Model name"), IO.Image.Input( "image", tooltip="An image to be used as the start frame of the generated video", @@ -270,15 +194,13 @@ class ViduImageToVideoNode(IO.ComfyNode): ), IO.Combo.Input( "resolution", - options=Resolution, - default=Resolution.r_1080p, + options=["1080p"], tooltip="Supported values may vary by model & duration", optional=True, ), IO.Combo.Input( "movement_amplitude", - options=MovementAmplitude, - default=MovementAmplitude.auto.value, + options=["auto", "small", "medium", "large"], tooltip="The movement amplitude of objects in the frame", optional=True, ), @@ -298,7 +220,7 @@ class ViduImageToVideoNode(IO.ComfyNode): async def execute( cls, model: str, - image: torch.Tensor, + image: Input.Image, prompt: str, duration: int, seed: int, @@ -309,7 +231,7 @@ class ViduImageToVideoNode(IO.ComfyNode): raise ValueError("Only one input image is allowed.") validate_image_aspect_ratio(image, (1, 4), (4, 1)) payload = TaskCreationRequest( - model_name=model, + model=model, prompt=prompt, duration=duration, seed=seed, @@ -322,8 +244,8 @@ class ViduImageToVideoNode(IO.ComfyNode): max_images=1, mime_type="image/png", ) - results = await execute_task(cls, VIDU_IMAGE_TO_VIDEO, payload, 120) - return IO.NodeOutput(await download_url_to_video_output(get_video_from_response(results).url)) + results = await execute_task(cls, VIDU_IMAGE_TO_VIDEO, payload) + return IO.NodeOutput(await download_url_to_video_output(results[0].url)) class ViduReferenceVideoNode(IO.ComfyNode): @@ -334,14 +256,9 @@ class ViduReferenceVideoNode(IO.ComfyNode): node_id="ViduReferenceVideoNode", display_name="Vidu Reference To Video Generation", category="api node/video/Vidu", - description="Generate video from multiple images and prompt", + description="Generate video from multiple images and a prompt", inputs=[ - IO.Combo.Input( - "model", - options=VideoModelName, - default=VideoModelName.vidu_q1, - tooltip="Model name", - ), + IO.Combo.Input("model", options=["viduq1"], tooltip="Model name"), IO.Image.Input( "images", tooltip="Images to use as references to generate a video with consistent subjects (max 7 images).", @@ -374,22 +291,19 @@ class ViduReferenceVideoNode(IO.ComfyNode): ), IO.Combo.Input( "aspect_ratio", - options=AspectRatio, - default=AspectRatio.r_16_9, + options=["16:9", "9:16", "1:1"], tooltip="The aspect ratio of the output video", optional=True, ), IO.Combo.Input( "resolution", - options=[model.value for model in Resolution], - default=Resolution.r_1080p.value, + options=["1080p"], tooltip="Supported values may vary by model & duration", optional=True, ), IO.Combo.Input( "movement_amplitude", - options=[model.value for model in MovementAmplitude], - default=MovementAmplitude.auto.value, + options=["auto", "small", "medium", "large"], tooltip="The movement amplitude of objects in the frame", optional=True, ), @@ -409,7 +323,7 @@ class ViduReferenceVideoNode(IO.ComfyNode): async def execute( cls, model: str, - images: torch.Tensor, + images: Input.Image, prompt: str, duration: int, seed: int, @@ -426,7 +340,7 @@ class ViduReferenceVideoNode(IO.ComfyNode): validate_image_aspect_ratio(image, (1, 4), (4, 1)) validate_image_dimensions(image, min_width=128, min_height=128) payload = TaskCreationRequest( - model_name=model, + model=model, prompt=prompt, duration=duration, seed=seed, @@ -440,8 +354,8 @@ class ViduReferenceVideoNode(IO.ComfyNode): max_images=7, mime_type="image/png", ) - results = await execute_task(cls, VIDU_REFERENCE_VIDEO, payload, 120) - return IO.NodeOutput(await download_url_to_video_output(get_video_from_response(results).url)) + results = await execute_task(cls, VIDU_REFERENCE_VIDEO, payload) + return IO.NodeOutput(await download_url_to_video_output(results[0].url)) class ViduStartEndToVideoNode(IO.ComfyNode): @@ -454,12 +368,7 @@ class ViduStartEndToVideoNode(IO.ComfyNode): category="api node/video/Vidu", description="Generate a video from start and end frames and a prompt", inputs=[ - IO.Combo.Input( - "model", - options=[model.value for model in VideoModelName], - default=VideoModelName.vidu_q1.value, - tooltip="Model name", - ), + IO.Combo.Input("model", options=["viduq1"], tooltip="Model name"), IO.Image.Input( "first_frame", tooltip="Start frame", @@ -497,15 +406,13 @@ class ViduStartEndToVideoNode(IO.ComfyNode): ), IO.Combo.Input( "resolution", - options=[model.value for model in Resolution], - default=Resolution.r_1080p.value, + options=["1080p"], tooltip="Supported values may vary by model & duration", optional=True, ), IO.Combo.Input( "movement_amplitude", - options=[model.value for model in MovementAmplitude], - default=MovementAmplitude.auto.value, + options=["auto", "small", "medium", "large"], tooltip="The movement amplitude of objects in the frame", optional=True, ), @@ -525,8 +432,8 @@ class ViduStartEndToVideoNode(IO.ComfyNode): async def execute( cls, model: str, - first_frame: torch.Tensor, - end_frame: torch.Tensor, + first_frame: Input.Image, + end_frame: Input.Image, prompt: str, duration: int, seed: int, @@ -535,7 +442,7 @@ class ViduStartEndToVideoNode(IO.ComfyNode): ) -> IO.NodeOutput: validate_images_aspect_ratio_closeness(first_frame, end_frame, min_rel=0.8, max_rel=1.25, strict=False) payload = TaskCreationRequest( - model_name=model, + model=model, prompt=prompt, duration=duration, seed=seed, @@ -546,8 +453,391 @@ class ViduStartEndToVideoNode(IO.ComfyNode): (await upload_images_to_comfyapi(cls, frame, max_images=1, mime_type="image/png"))[0] for frame in (first_frame, end_frame) ] - results = await execute_task(cls, VIDU_START_END_VIDEO, payload, 96) - return IO.NodeOutput(await download_url_to_video_output(get_video_from_response(results).url)) + results = await execute_task(cls, VIDU_START_END_VIDEO, payload) + return IO.NodeOutput(await download_url_to_video_output(results[0].url)) + + +class Vidu2TextToVideoNode(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="Vidu2TextToVideoNode", + display_name="Vidu2 Text-to-Video Generation", + category="api node/video/Vidu", + description="Generate video from a text prompt", + inputs=[ + IO.Combo.Input("model", options=["viduq2"]), + IO.String.Input( + "prompt", + multiline=True, + tooltip="A textual description for video generation, with a maximum length of 2000 characters.", + ), + IO.Int.Input( + "duration", + default=5, + min=1, + max=10, + step=1, + display_mode=IO.NumberDisplay.slider, + ), + IO.Int.Input( + "seed", + default=1, + min=0, + max=2147483647, + step=1, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + ), + IO.Combo.Input("aspect_ratio", options=["16:9", "9:16", "3:4", "4:3", "1:1"]), + IO.Combo.Input("resolution", options=["720p", "1080p"]), + IO.Boolean.Input( + "background_music", + default=False, + tooltip="Whether to add background music to the generated video.", + ), + ], + outputs=[ + IO.Video.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, + model: str, + prompt: str, + duration: int, + seed: int, + aspect_ratio: str, + resolution: str, + background_music: bool, + ) -> IO.NodeOutput: + validate_string(prompt, min_length=1, max_length=2000) + results = await execute_task( + cls, + VIDU_TEXT_TO_VIDEO, + TaskCreationRequest( + model=model, + prompt=prompt, + duration=duration, + seed=seed, + aspect_ratio=aspect_ratio, + resolution=resolution, + bgm=background_music, + ), + ) + return IO.NodeOutput(await download_url_to_video_output(results[0].url)) + + +class Vidu2ImageToVideoNode(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="Vidu2ImageToVideoNode", + display_name="Vidu2 Image-to-Video Generation", + category="api node/video/Vidu", + description="Generate a video from an image and an optional prompt.", + inputs=[ + IO.Combo.Input("model", options=["viduq2-pro-fast", "viduq2-pro", "viduq2-turbo"]), + IO.Image.Input( + "image", + tooltip="An image to be used as the start frame of the generated video.", + ), + IO.String.Input( + "prompt", + multiline=True, + default="", + tooltip="An optional text prompt for video generation (max 2000 characters).", + ), + IO.Int.Input( + "duration", + default=5, + min=1, + max=10, + step=1, + display_mode=IO.NumberDisplay.slider, + ), + IO.Int.Input( + "seed", + default=1, + min=0, + max=2147483647, + step=1, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + ), + IO.Combo.Input( + "resolution", + options=["720p", "1080p"], + ), + IO.Combo.Input( + "movement_amplitude", + options=["auto", "small", "medium", "large"], + tooltip="The movement amplitude of objects in the frame.", + ), + ], + outputs=[ + IO.Video.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, + model: str, + image: Input.Image, + prompt: str, + duration: int, + seed: int, + resolution: str, + movement_amplitude: str, + ) -> IO.NodeOutput: + if get_number_of_images(image) > 1: + raise ValueError("Only one input image is allowed.") + validate_image_aspect_ratio(image, (1, 4), (4, 1)) + validate_string(prompt, max_length=2000) + results = await execute_task( + cls, + VIDU_IMAGE_TO_VIDEO, + TaskCreationRequest( + model=model, + prompt=prompt, + duration=duration, + seed=seed, + resolution=resolution, + movement_amplitude=movement_amplitude, + images=await upload_images_to_comfyapi( + cls, + image, + max_images=1, + mime_type="image/png", + ), + ), + ) + return IO.NodeOutput(await download_url_to_video_output(results[0].url)) + + +class Vidu2ReferenceVideoNode(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="Vidu2ReferenceVideoNode", + display_name="Vidu2 Reference-to-Video Generation", + category="api node/video/Vidu", + description="Generate a video from multiple reference images and a prompt.", + inputs=[ + IO.Combo.Input("model", options=["viduq2"]), + IO.Autogrow.Input( + "subjects", + template=IO.Autogrow.TemplateNames( + IO.Image.Input("reference_images"), + names=["subject1", "subject2", "subject3"], + min=1, + ), + tooltip="For each subject, provide up to 3 reference images (7 images total across all subjects). " + "Reference them in prompts via @subject{subject_id}.", + ), + IO.String.Input( + "prompt", + multiline=True, + tooltip="When enabled, the video will include generated speech and background music " + "based on the prompt.", + ), + IO.Boolean.Input( + "audio", + default=False, + tooltip="When enabled video will contain generated speech and background music based on the prompt.", + ), + IO.Int.Input( + "duration", + default=5, + min=1, + max=10, + step=1, + display_mode=IO.NumberDisplay.slider, + ), + IO.Int.Input( + "seed", + default=1, + min=0, + max=2147483647, + step=1, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + ), + IO.Combo.Input("aspect_ratio", options=["16:9", "9:16", "4:3", "3:4", "1:1"]), + IO.Combo.Input("resolution", options=["720p"]), + IO.Combo.Input( + "movement_amplitude", + options=["auto", "small", "medium", "large"], + tooltip="The movement amplitude of objects in the frame.", + ), + ], + outputs=[ + IO.Video.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, + model: str, + subjects: IO.Autogrow.Type, + prompt: str, + audio: bool, + duration: int, + seed: int, + aspect_ratio: str, + resolution: str, + movement_amplitude: str, + ) -> IO.NodeOutput: + validate_string(prompt, min_length=1, max_length=2000) + total_images = 0 + for i in subjects: + if get_number_of_images(subjects[i]) > 3: + raise ValueError("Maximum number of images per subject is 3.") + for im in subjects[i]: + total_images += 1 + validate_image_aspect_ratio(im, (1, 4), (4, 1)) + validate_image_dimensions(im, min_width=128, min_height=128) + if total_images > 7: + raise ValueError("Too many reference images; the maximum allowed is 7.") + subjects_param: list[SubjectReference] = [] + for i in subjects: + subjects_param.append( + SubjectReference( + id=i, + images=await upload_images_to_comfyapi( + cls, + subjects[i], + max_images=3, + mime_type="image/png", + wait_label=f"Uploading reference images for {i}", + ), + ), + ) + payload = TaskCreationRequest( + model=model, + prompt=prompt, + audio=audio, + duration=duration, + seed=seed, + aspect_ratio=aspect_ratio, + resolution=resolution, + movement_amplitude=movement_amplitude, + subjects=subjects_param, + ) + results = await execute_task(cls, VIDU_REFERENCE_VIDEO, payload) + return IO.NodeOutput(await download_url_to_video_output(results[0].url)) + + +class Vidu2StartEndToVideoNode(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="Vidu2StartEndToVideoNode", + display_name="Vidu2 Start/End Frame-to-Video Generation", + category="api node/video/Vidu", + description="Generate a video from a start frame, an end frame, and a prompt.", + inputs=[ + IO.Combo.Input("model", options=["viduq2-pro-fast", "viduq2-pro", "viduq2-turbo"]), + IO.Image.Input("first_frame"), + IO.Image.Input("end_frame"), + IO.String.Input( + "prompt", + multiline=True, + tooltip="Prompt description (max 2000 characters).", + ), + IO.Int.Input( + "duration", + default=5, + min=2, + max=8, + step=1, + display_mode=IO.NumberDisplay.slider, + ), + IO.Int.Input( + "seed", + default=1, + min=0, + max=2147483647, + step=1, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + ), + IO.Combo.Input("resolution", options=["720p", "1080p"]), + IO.Combo.Input( + "movement_amplitude", + options=["auto", "small", "medium", "large"], + tooltip="The movement amplitude of objects in the frame.", + ), + ], + outputs=[ + IO.Video.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, + model: str, + first_frame: Input.Image, + end_frame: Input.Image, + prompt: str, + duration: int, + seed: int, + resolution: str, + movement_amplitude: str, + ) -> IO.NodeOutput: + validate_string(prompt, max_length=2000) + if get_number_of_images(first_frame) > 1: + raise ValueError("Only one input image is allowed for `first_frame`.") + if get_number_of_images(end_frame) > 1: + raise ValueError("Only one input image is allowed for `end_frame`.") + validate_images_aspect_ratio_closeness(first_frame, end_frame, min_rel=0.8, max_rel=1.25, strict=False) + payload = TaskCreationRequest( + model=model, + prompt=prompt, + duration=duration, + seed=seed, + resolution=resolution, + movement_amplitude=movement_amplitude, + images=[ + (await upload_images_to_comfyapi(cls, frame, max_images=1, mime_type="image/png"))[0] + for frame in (first_frame, end_frame) + ], + ) + results = await execute_task(cls, VIDU_START_END_VIDEO, payload) + return IO.NodeOutput(await download_url_to_video_output(results[0].url)) class ViduExtension(ComfyExtension): @@ -558,6 +848,10 @@ class ViduExtension(ComfyExtension): ViduImageToVideoNode, ViduReferenceVideoNode, ViduStartEndToVideoNode, + Vidu2TextToVideoNode, + Vidu2ImageToVideoNode, + Vidu2ReferenceVideoNode, + Vidu2StartEndToVideoNode, ] From 153bc524bf9db76d723289f6420f418f63966972 Mon Sep 17 00:00:00 2001 From: ComfyUI Wiki Date: Sat, 10 Jan 2026 14:29:30 +0800 Subject: [PATCH 043/119] chore: update embedded docs to v0.4.0 (#11776) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 7686a5f8a..6c1cd86d2 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ comfyui-frontend-package==1.36.13 comfyui-workflow-templates==0.7.69 -comfyui-embedded-docs==0.3.1 +comfyui-embedded-docs==0.4.0 torch torchsde torchvision From dc202a2e51bf7a6cd00e606b2d2941bc223f2ad2 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Fri, 9 Jan 2026 23:03:57 -0800 Subject: [PATCH 044/119] Properly save mixed ops. (#11772) --- comfy/ops.py | 26 ++++++++++++------- .../comfy_quant/test_mixed_precision.py | 6 ++--- 2 files changed, 20 insertions(+), 12 deletions(-) diff --git a/comfy/ops.py b/comfy/ops.py index 1cf22f0cc..9c0b54ff4 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -625,21 +625,29 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec missing_keys.remove(key) def state_dict(self, *args, destination=None, prefix="", **kwargs): - sd = super().state_dict(*args, destination=destination, prefix=prefix, **kwargs) - if isinstance(self.weight, QuantizedTensor): - layout_cls = self.weight._layout_cls + if destination is not None: + sd = destination + else: + sd = {} - # Check if it's any FP8 variant (E4M3 or E5M2) - if layout_cls in ("TensorCoreFP8E4M3Layout", "TensorCoreFP8E5M2Layout", "TensorCoreFP8Layout"): - sd["{}weight_scale".format(prefix)] = self.weight._params.scale - elif layout_cls == "TensorCoreNVFP4Layout": - sd["{}weight_scale_2".format(prefix)] = self.weight._params.scale - sd["{}weight_scale".format(prefix)] = self.weight._params.block_scale + if self.bias is not None: + sd["{}bias".format(prefix)] = self.bias + + if isinstance(self.weight, QuantizedTensor): + sd_out = self.weight.state_dict("{}weight".format(prefix)) + for k in sd_out: + sd[k] = sd_out[k] quant_conf = {"format": self.quant_format} if self._full_precision_mm_config: quant_conf["full_precision_matrix_mult"] = True sd["{}comfy_quant".format(prefix)] = torch.tensor(list(json.dumps(quant_conf).encode('utf-8')), dtype=torch.uint8) + + input_scale = getattr(self, 'input_scale', None) + if input_scale is not None: + sd["{}input_scale".format(prefix)] = input_scale + else: + sd["{}weight".format(prefix)] = self.weight return sd def _forward(self, input, weight, bias): diff --git a/tests-unit/comfy_quant/test_mixed_precision.py b/tests-unit/comfy_quant/test_mixed_precision.py index 7b2eac940..7c740491d 100644 --- a/tests-unit/comfy_quant/test_mixed_precision.py +++ b/tests-unit/comfy_quant/test_mixed_precision.py @@ -153,9 +153,9 @@ class TestMixedPrecisionOps(unittest.TestCase): state_dict2 = model.state_dict() # Verify layer1.weight is a QuantizedTensor with scale preserved - self.assertIsInstance(state_dict2["layer1.weight"], QuantizedTensor) - self.assertEqual(state_dict2["layer1.weight"]._params.scale.item(), 3.0) - self.assertEqual(state_dict2["layer1.weight"]._layout_cls, "TensorCoreFP8E4M3Layout") + self.assertTrue(torch.equal(state_dict2["layer1.weight"].view(torch.uint8), fp8_weight.view(torch.uint8))) + self.assertEqual(state_dict2["layer1.weight_scale"].item(), 3.0) + self.assertEqual(model.layer1.weight._layout_cls, "TensorCoreFP8E4M3Layout") # Verify non-quantized layers are standard tensors self.assertNotIsInstance(state_dict2["layer2.weight"], QuantizedTensor) From 6e4b1f9d00306fe14d7ca5adf2c7468d631b23d5 Mon Sep 17 00:00:00 2001 From: DELUXA Date: Sat, 10 Jan 2026 23:51:05 +0200 Subject: [PATCH 045/119] pythorch_attn_by_def_on_gfx1200 (#11793) --- comfy/model_management.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index e5de4a5b5..9d39be7b2 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -368,7 +368,7 @@ try: if any((a in arch) for a in ["gfx90a", "gfx942", "gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches, TODO: gfx950 ENABLE_PYTORCH_ATTENTION = True if rocm_version >= (7, 0): - if any((a in arch) for a in ["gfx1201"]): + if any((a in arch) for a in ["gfx1200", "gfx1201"]): ENABLE_PYTORCH_ATTENTION = True if torch_version_numeric >= (2, 7) and rocm_version >= (6, 4): if any((a in arch) for a in ["gfx1200", "gfx1201", "gfx950"]): # TODO: more arches, "gfx942" gives error on pytorch nightly 2.10 1013 rocm7.0 From cd912963f17c9ae00ec12e1869293edb78720831 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sat, 10 Jan 2026 14:31:31 -0800 Subject: [PATCH 046/119] Fix issue with t5 text encoder in fp4. (#11794) --- comfy/model_detection.py | 2 ++ comfy/sd.py | 4 ++-- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 0853b3aec..aff5a50b9 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -237,6 +237,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): else: dit_config["vec_in_dim"] = None + dit_config["num_heads"] = dit_config["hidden_size"] // sum(dit_config["axes_dim"]) + dit_config["depth"] = count_blocks(state_dict_keys, '{}double_blocks.'.format(key_prefix) + '{}.') dit_config["depth_single_blocks"] = count_blocks(state_dict_keys, '{}single_blocks.'.format(key_prefix) + '{}.') if '{}distilled_guidance_layer.0.norms.0.scale'.format(key_prefix) in state_dict_keys or '{}distilled_guidance_layer.norms.0.scale'.format(key_prefix) in state_dict_keys: #Chroma diff --git a/comfy/sd.py b/comfy/sd.py index 5a7221620..b689c0dfc 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -1059,9 +1059,9 @@ def detect_te_model(sd): return TEModel.JINA_CLIP_2 if "encoder.block.23.layer.1.DenseReluDense.wi_1.weight" in sd: weight = sd["encoder.block.23.layer.1.DenseReluDense.wi_1.weight"] - if weight.shape[-1] == 4096: + if weight.shape[0] == 10240: return TEModel.T5_XXL - elif weight.shape[-1] == 2048: + elif weight.shape[0] == 5120: return TEModel.T5_XL if 'encoder.block.23.layer.1.DenseReluDense.wi.weight' in sd: return TEModel.T5_XXL_OLD From 2f642d5d9b48ad7cad13bbdd5f8adcf506f565a7 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sat, 10 Jan 2026 14:40:42 -0800 Subject: [PATCH 047/119] Fix chroma fp8 te being treated as fp16. (#11795) --- comfy/text_encoders/cosmos.py | 2 +- comfy/text_encoders/genmo.py | 2 +- comfy/text_encoders/pixart_t5.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/comfy/text_encoders/cosmos.py b/comfy/text_encoders/cosmos.py index 448381fa9..f4b40ac68 100644 --- a/comfy/text_encoders/cosmos.py +++ b/comfy/text_encoders/cosmos.py @@ -36,7 +36,7 @@ def te(dtype_t5=None, t5_quantization_metadata=None): if t5_quantization_metadata is not None: model_options = model_options.copy() model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata - if dtype is None: + if dtype_t5 is not None: dtype = dtype_t5 super().__init__(device=device, dtype=dtype, model_options=model_options) return CosmosTEModel_ diff --git a/comfy/text_encoders/genmo.py b/comfy/text_encoders/genmo.py index 5daea8135..2d7a3fbce 100644 --- a/comfy/text_encoders/genmo.py +++ b/comfy/text_encoders/genmo.py @@ -32,7 +32,7 @@ def mochi_te(dtype_t5=None, t5_quantization_metadata=None): if t5_quantization_metadata is not None: model_options = model_options.copy() model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata - if dtype is None: + if dtype_t5 is not None: dtype = dtype_t5 super().__init__(device=device, dtype=dtype, model_options=model_options) return MochiTEModel_ diff --git a/comfy/text_encoders/pixart_t5.py b/comfy/text_encoders/pixart_t5.py index e5e5f18be..51c6e50c7 100644 --- a/comfy/text_encoders/pixart_t5.py +++ b/comfy/text_encoders/pixart_t5.py @@ -36,7 +36,7 @@ def pixart_te(dtype_t5=None, t5_quantization_metadata=None): if t5_quantization_metadata is not None: model_options = model_options.copy() model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata - if dtype is None: + if dtype_t5 is not None: dtype = dtype_t5 super().__init__(device=device, dtype=dtype, model_options=model_options) return PixArtTEModel_ From 5cd1113236b0fb032a51bf9d63ba196a2510b0d4 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Sun, 11 Jan 2026 13:07:11 +0200 Subject: [PATCH 048/119] fix(api-nodes): use a unique name for uploading audio files (#11778) --- comfy_api_nodes/nodes_kling.py | 2 +- comfy_api_nodes/util/conversions.py | 4 ++-- comfy_api_nodes/util/upload_helpers.py | 3 +-- 3 files changed, 4 insertions(+), 5 deletions(-) diff --git a/comfy_api_nodes/nodes_kling.py b/comfy_api_nodes/nodes_kling.py index 9c707a339..01d9c34f5 100644 --- a/comfy_api_nodes/nodes_kling.py +++ b/comfy_api_nodes/nodes_kling.py @@ -567,7 +567,7 @@ async def execute_lipsync( # Upload the audio file to Comfy API and get download URL if audio: audio_url = await upload_audio_to_comfyapi( - cls, audio, container_format="mp3", codec_name="libmp3lame", mime_type="audio/mpeg", filename="output.mp3" + cls, audio, container_format="mp3", codec_name="libmp3lame", mime_type="audio/mpeg" ) logging.info("Uploaded audio to Comfy API. URL: %s", audio_url) else: diff --git a/comfy_api_nodes/util/conversions.py b/comfy_api_nodes/util/conversions.py index d64239c86..99c302a2a 100644 --- a/comfy_api_nodes/util/conversions.py +++ b/comfy_api_nodes/util/conversions.py @@ -55,7 +55,7 @@ def image_tensor_pair_to_batch(image1: torch.Tensor, image2: torch.Tensor) -> to def tensor_to_bytesio( image: torch.Tensor, - name: str | None = None, + *, total_pixels: int = 2048 * 2048, mime_type: str = "image/png", ) -> BytesIO: @@ -75,7 +75,7 @@ def tensor_to_bytesio( pil_image = tensor_to_pil(image, total_pixels=total_pixels) img_binary = pil_to_bytesio(pil_image, mime_type=mime_type) - img_binary.name = f"{name if name else uuid.uuid4()}.{mimetype_to_extension(mime_type)}" + img_binary.name = f"{uuid.uuid4()}.{mimetype_to_extension(mime_type)}" return img_binary diff --git a/comfy_api_nodes/util/upload_helpers.py b/comfy_api_nodes/util/upload_helpers.py index 2535a0884..cea0d1203 100644 --- a/comfy_api_nodes/util/upload_helpers.py +++ b/comfy_api_nodes/util/upload_helpers.py @@ -82,7 +82,6 @@ async def upload_audio_to_comfyapi( container_format: str = "mp4", codec_name: str = "aac", mime_type: str = "audio/mp4", - filename: str = "uploaded_audio.mp4", ) -> str: """ Uploads a single audio input to ComfyUI API and returns its download URL. @@ -92,7 +91,7 @@ async def upload_audio_to_comfyapi( waveform: torch.Tensor = audio["waveform"] audio_data_np = audio_tensor_to_contiguous_ndarray(waveform) audio_bytes_io = audio_ndarray_to_bytesio(audio_data_np, sample_rate, container_format, codec_name) - return await upload_file_to_comfyapi(cls, audio_bytes_io, filename, mime_type) + return await upload_file_to_comfyapi(cls, audio_bytes_io, f"{uuid.uuid4()}.{container_format}", mime_type) async def upload_video_to_comfyapi( From c6238047ee1ffd87eade7c3ab5a8e53c11d4ce39 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sun, 11 Jan 2026 18:11:53 -0800 Subject: [PATCH 049/119] Put more details about portable in readme. (#11816) --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 6d09758c0..e25f3cda7 100644 --- a/README.md +++ b/README.md @@ -183,7 +183,7 @@ Simply download, extract with [7-Zip](https://7-zip.org) or with the windows exp If you have trouble extracting it, right click the file -> properties -> unblock -Update your Nvidia drivers if it doesn't start. +The portable above currently comes with python 3.13 and pytorch cuda 13.0. Update your Nvidia drivers if it doesn't start. #### Alternative Downloads: @@ -212,7 +212,7 @@ Python 3.14 works but you may encounter issues with the torch compile node. The Python 3.13 is very well supported. If you have trouble with some custom node dependencies on 3.13 you can try 3.12 -torch 2.4 and above is supported but some features might only work on newer versions. We generally recommend using the latest major version of pytorch unless it is less than 2 weeks old. +torch 2.4 and above is supported but some features might only work on newer versions. We generally recommend using the latest major version of pytorch with the latest cuda version unless it is less than 2 weeks old. ### Instructions: From a3b5d4996abcd906c7c99f15b69fde051afcb4be Mon Sep 17 00:00:00 2001 From: kelseyee <971704395@qq.com> Date: Tue, 13 Jan 2026 04:38:46 +0800 Subject: [PATCH 050/119] Support ModelScope-Trainer DiffSynth lora for Z Image. (#11805) --- comfy/lora.py | 1 + 1 file changed, 1 insertion(+) diff --git a/comfy/lora.py b/comfy/lora.py index 2ed0acb9d..e8246bd66 100644 --- a/comfy/lora.py +++ b/comfy/lora.py @@ -322,6 +322,7 @@ def model_lora_keys_unet(model, key_map={}): key_map["diffusion_model.{}".format(key_lora)] = to key_map["transformer.{}".format(key_lora)] = to key_map["lycoris_{}".format(key_lora.replace(".", "_"))] = to + key_map[key_lora] = to if isinstance(model, comfy.model_base.Kandinsky5): for k in sdk: From c881a1d6897d8fee84559a8e3e49b9116efdb959 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 12 Jan 2026 14:05:54 -0800 Subject: [PATCH 051/119] Support the siglip 2 naflex model as a clip vision model. (#11831) Not useful yet. --- comfy/clip_model.py | 62 +++++++++++++++++++++- comfy/clip_vision.py | 25 ++++++--- comfy/clip_vision_siglip2_base_naflex.json | 14 +++++ 3 files changed, 91 insertions(+), 10 deletions(-) create mode 100644 comfy/clip_vision_siglip2_base_naflex.json diff --git a/comfy/clip_model.py b/comfy/clip_model.py index e88872728..d7d3f994c 100644 --- a/comfy/clip_model.py +++ b/comfy/clip_model.py @@ -1,6 +1,7 @@ import torch from comfy.ldm.modules.attention import optimized_attention_for_device import comfy.ops +import math def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711], crop=True): image = image[:, :, :, :3] if image.shape[3] > 3 else image @@ -21,6 +22,39 @@ def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], s image = torch.clip((255. * image), 0, 255).round() / 255.0 return (image - mean.view([3,1,1])) / std.view([3,1,1]) +def siglip2_flex_calc_resolution(oh, ow, patch_size, max_num_patches, eps=1e-5): + def scale_dim(size, scale): + scaled = math.ceil(size * scale / patch_size) * patch_size + return max(patch_size, int(scaled)) + + # Binary search for optimal scale + lo, hi = eps / 10, 100.0 + while hi - lo >= eps: + mid = (lo + hi) / 2 + h, w = scale_dim(oh, mid), scale_dim(ow, mid) + if (h // patch_size) * (w // patch_size) <= max_num_patches: + lo = mid + else: + hi = mid + + return scale_dim(oh, lo), scale_dim(ow, lo) + +def siglip2_preprocess(image, size, patch_size, num_patches, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], crop=True): + if size > 0: + return clip_preprocess(image, size=size, mean=mean, std=std, crop=crop) + + image = image[:, :, :, :3] if image.shape[3] > 3 else image + mean = torch.tensor(mean, device=image.device, dtype=image.dtype) + std = torch.tensor(std, device=image.device, dtype=image.dtype) + image = image.movedim(-1, 1) + + b, c, h, w = image.shape + h, w = siglip2_flex_calc_resolution(h, w, patch_size, num_patches) + + image = torch.nn.functional.interpolate(image, size=(h, w), mode="bilinear", antialias=True) + image = torch.clip((255. * image), 0, 255).round() / 255.0 + return (image - mean.view([3, 1, 1])) / std.view([3, 1, 1]) + class CLIPAttention(torch.nn.Module): def __init__(self, embed_dim, heads, dtype, device, operations): super().__init__() @@ -175,6 +209,27 @@ class CLIPTextModel(torch.nn.Module): out = self.text_projection(x[2]) return (x[0], x[1], out, x[2]) +def siglip2_pos_embed(embed_weight, embeds, orig_shape): + embed_weight_len = round(embed_weight.shape[0] ** 0.5) + embed_weight = comfy.ops.cast_to_input(embed_weight, embeds).movedim(1, 0).reshape(1, -1, embed_weight_len, embed_weight_len) + embed_weight = torch.nn.functional.interpolate(embed_weight, size=orig_shape, mode="bilinear", align_corners=False, antialias=True) + embed_weight = embed_weight.reshape(-1, embed_weight.shape[-2] * embed_weight.shape[-1]).movedim(0, 1) + return embeds + embed_weight + +class Siglip2Embeddings(torch.nn.Module): + def __init__(self, embed_dim, num_channels=3, patch_size=14, image_size=224, model_type="", num_patches=None, dtype=None, device=None, operations=None): + super().__init__() + self.patch_embedding = operations.Linear(num_channels * patch_size * patch_size, embed_dim, dtype=dtype, device=device) + self.position_embedding = operations.Embedding(num_patches, embed_dim, dtype=dtype, device=device) + self.patch_size = patch_size + + def forward(self, pixel_values): + b, c, h, w = pixel_values.shape + img = pixel_values.movedim(1, -1).reshape(b, h // self.patch_size, self.patch_size, w // self.patch_size, self.patch_size, c) + img = img.permute(0, 1, 3, 2, 4, 5) + img = img.reshape(b, img.shape[1] * img.shape[2], -1) + img = self.patch_embedding(img) + return siglip2_pos_embed(self.position_embedding.weight, img, (h // self.patch_size, w // self.patch_size)) class CLIPVisionEmbeddings(torch.nn.Module): def __init__(self, embed_dim, num_channels=3, patch_size=14, image_size=224, model_type="", dtype=None, device=None, operations=None): @@ -218,8 +273,11 @@ class CLIPVision(torch.nn.Module): intermediate_activation = config_dict["hidden_act"] model_type = config_dict["model_type"] - self.embeddings = CLIPVisionEmbeddings(embed_dim, config_dict["num_channels"], config_dict["patch_size"], config_dict["image_size"], model_type=model_type, dtype=dtype, device=device, operations=operations) - if model_type == "siglip_vision_model": + if model_type in ["siglip2_vision_model"]: + self.embeddings = Siglip2Embeddings(embed_dim, config_dict["num_channels"], config_dict["patch_size"], config_dict["image_size"], model_type=model_type, num_patches=config_dict.get("num_patches", None), dtype=dtype, device=device, operations=operations) + else: + self.embeddings = CLIPVisionEmbeddings(embed_dim, config_dict["num_channels"], config_dict["patch_size"], config_dict["image_size"], model_type=model_type, dtype=dtype, device=device, operations=operations) + if model_type in ["siglip_vision_model", "siglip2_vision_model"]: self.pre_layrnorm = lambda a: a self.output_layernorm = True else: diff --git a/comfy/clip_vision.py b/comfy/clip_vision.py index d5fc53497..66f2a9d9c 100644 --- a/comfy/clip_vision.py +++ b/comfy/clip_vision.py @@ -21,6 +21,7 @@ clip_preprocess = comfy.clip_model.clip_preprocess # Prevent some stuff from br IMAGE_ENCODERS = { "clip_vision_model": comfy.clip_model.CLIPVisionModelProjection, "siglip_vision_model": comfy.clip_model.CLIPVisionModelProjection, + "siglip2_vision_model": comfy.clip_model.CLIPVisionModelProjection, "dinov2": comfy.image_encoders.dino2.Dinov2Model, } @@ -32,9 +33,10 @@ class ClipVisionModel(): self.image_size = config.get("image_size", 224) self.image_mean = config.get("image_mean", [0.48145466, 0.4578275, 0.40821073]) self.image_std = config.get("image_std", [0.26862954, 0.26130258, 0.27577711]) - model_type = config.get("model_type", "clip_vision_model") - model_class = IMAGE_ENCODERS.get(model_type) - if model_type == "siglip_vision_model": + self.model_type = config.get("model_type", "clip_vision_model") + self.config = config.copy() + model_class = IMAGE_ENCODERS.get(self.model_type) + if self.model_type == "siglip_vision_model": self.return_all_hidden_states = True else: self.return_all_hidden_states = False @@ -55,7 +57,10 @@ class ClipVisionModel(): def encode_image(self, image, crop=True): comfy.model_management.load_model_gpu(self.patcher) - pixel_values = comfy.clip_model.clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std, crop=crop).float() + if self.model_type == "siglip2_vision_model": + pixel_values = comfy.clip_model.siglip2_preprocess(image.to(self.load_device), size=self.image_size, patch_size=self.config.get("patch_size", 16), num_patches=self.config.get("num_patches", 256), mean=self.image_mean, std=self.image_std, crop=crop).float() + else: + pixel_values = comfy.clip_model.clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std, crop=crop).float() out = self.model(pixel_values=pixel_values, intermediate_output='all' if self.return_all_hidden_states else -2) outputs = Output() @@ -107,10 +112,14 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False): elif "vision_model.encoder.layers.22.layer_norm1.weight" in sd: embed_shape = sd["vision_model.embeddings.position_embedding.weight"].shape[0] if sd["vision_model.encoder.layers.0.layer_norm1.weight"].shape[0] == 1152: - if embed_shape == 729: - json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_siglip_384.json") - elif embed_shape == 1024: - json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_siglip_512.json") + patch_embedding_shape = sd["vision_model.embeddings.patch_embedding.weight"].shape + if len(patch_embedding_shape) == 2: + json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_siglip2_base_naflex.json") + else: + if embed_shape == 729: + json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_siglip_384.json") + elif embed_shape == 1024: + json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_siglip_512.json") elif embed_shape == 577: if "multi_modal_projector.linear_1.bias" in sd: json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl_336_llava.json") diff --git a/comfy/clip_vision_siglip2_base_naflex.json b/comfy/clip_vision_siglip2_base_naflex.json new file mode 100644 index 000000000..6f6b99bd6 --- /dev/null +++ b/comfy/clip_vision_siglip2_base_naflex.json @@ -0,0 +1,14 @@ +{ + "num_channels": 3, + "hidden_act": "gelu_pytorch_tanh", + "hidden_size": 1152, + "image_size": -1, + "intermediate_size": 4304, + "model_type": "siglip2_vision_model", + "num_attention_heads": 16, + "num_hidden_layers": 27, + "patch_size": 16, + "num_patches": 256, + "image_mean": [0.5, 0.5, 0.5], + "image_std": [0.5, 0.5, 0.5] +} From fd5c0755af18530ff1225f5946c7a647b9694032 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jukka=20Sepp=C3=A4nen?= <40791699+kijai@users.noreply.github.com> Date: Tue, 13 Jan 2026 00:28:59 +0200 Subject: [PATCH 052/119] Reduce LTX2 VRAM use by more efficient timestep embed handling (#11829) --- comfy/ldm/lightricks/av_model.py | 136 ++++++++++++++++++++++++------- 1 file changed, 106 insertions(+), 30 deletions(-) diff --git a/comfy/ldm/lightricks/av_model.py b/comfy/ldm/lightricks/av_model.py index 759535501..c12ace241 100644 --- a/comfy/ldm/lightricks/av_model.py +++ b/comfy/ldm/lightricks/av_model.py @@ -11,6 +11,69 @@ from comfy.ldm.lightricks.model import ( from comfy.ldm.lightricks.symmetric_patchifier import AudioPatchifier import comfy.ldm.common_dit +class CompressedTimestep: + """Store video timestep embeddings in compressed form using per-frame indexing.""" + __slots__ = ('data', 'batch_size', 'num_frames', 'patches_per_frame', 'feature_dim') + + def __init__(self, tensor: torch.Tensor, patches_per_frame: int): + """ + tensor: [batch_size, num_tokens, feature_dim] tensor where num_tokens = num_frames * patches_per_frame + patches_per_frame: Number of spatial patches per frame (height * width in latent space) + """ + self.batch_size, num_tokens, self.feature_dim = tensor.shape + + # Check if compression is valid (num_tokens must be divisible by patches_per_frame) + if num_tokens % patches_per_frame == 0 and num_tokens >= patches_per_frame: + self.patches_per_frame = patches_per_frame + self.num_frames = num_tokens // patches_per_frame + + # Reshape to [batch, frames, patches_per_frame, feature_dim] and store one value per frame + # All patches in a frame are identical, so we only keep the first one + reshaped = tensor.view(self.batch_size, self.num_frames, patches_per_frame, self.feature_dim) + self.data = reshaped[:, :, 0, :].contiguous() # [batch, frames, feature_dim] + else: + # Not divisible or too small - store directly without compression + self.patches_per_frame = 1 + self.num_frames = num_tokens + self.data = tensor + + def expand(self): + """Expand back to original tensor.""" + if self.patches_per_frame == 1: + return self.data + + # [batch, frames, feature_dim] -> [batch, frames, patches_per_frame, feature_dim] -> [batch, tokens, feature_dim] + expanded = self.data.unsqueeze(2).expand(self.batch_size, self.num_frames, self.patches_per_frame, self.feature_dim) + return expanded.reshape(self.batch_size, -1, self.feature_dim) + + def expand_for_computation(self, scale_shift_table: torch.Tensor, batch_size: int, indices: slice = slice(None, None)): + """Compute ada values on compressed per-frame data, then expand spatially.""" + num_ada_params = scale_shift_table.shape[0] + + # No compression - compute directly + if self.patches_per_frame == 1: + num_tokens = self.data.shape[1] + dim_per_param = self.feature_dim // num_ada_params + reshaped = self.data.reshape(batch_size, num_tokens, num_ada_params, dim_per_param)[:, :, indices, :] + table_values = scale_shift_table[indices].unsqueeze(0).unsqueeze(0).to(device=self.data.device, dtype=self.data.dtype) + ada_values = (table_values + reshaped).unbind(dim=2) + return ada_values + + # Compressed: compute on per-frame data then expand spatially + # Reshape: [batch, frames, feature_dim] -> [batch, frames, num_ada_params, dim_per_param] + frame_reshaped = self.data.reshape(batch_size, self.num_frames, num_ada_params, -1)[:, :, indices, :] + table_values = scale_shift_table[indices].unsqueeze(0).unsqueeze(0).to( + device=self.data.device, dtype=self.data.dtype + ) + frame_ada = (table_values + frame_reshaped).unbind(dim=2) + + # Expand each ada parameter spatially: [batch, frames, dim] -> [batch, frames, patches, dim] -> [batch, tokens, dim] + return tuple( + frame_val.unsqueeze(2).expand(batch_size, self.num_frames, self.patches_per_frame, -1) + .reshape(batch_size, -1, frame_val.shape[-1]) + for frame_val in frame_ada + ) + class BasicAVTransformerBlock(nn.Module): def __init__( self, @@ -119,6 +182,9 @@ class BasicAVTransformerBlock(nn.Module): def get_ada_values( self, scale_shift_table: torch.Tensor, batch_size: int, timestep: torch.Tensor, indices: slice = slice(None, None) ): + if isinstance(timestep, CompressedTimestep): + return timestep.expand_for_computation(scale_shift_table, batch_size, indices) + num_ada_params = scale_shift_table.shape[0] ada_values = ( @@ -146,10 +212,7 @@ class BasicAVTransformerBlock(nn.Module): gate_timestep, ) - scale_shift_chunks = [t.squeeze(2) for t in scale_shift_ada_values] - gate_ada_values = [t.squeeze(2) for t in gate_ada_values] - - return (*scale_shift_chunks, *gate_ada_values) + return (*scale_shift_ada_values, *gate_ada_values) def forward( self, @@ -543,72 +606,80 @@ class LTXAVModel(LTXVModel): if grid_mask is not None: timestep = timestep[:, grid_mask] - timestep = timestep * self.timestep_scale_multiplier + timestep_scaled = timestep * self.timestep_scale_multiplier + v_timestep, v_embedded_timestep = self.adaln_single( - timestep.flatten(), + timestep_scaled.flatten(), {"resolution": None, "aspect_ratio": None}, batch_size=batch_size, hidden_dtype=hidden_dtype, ) - # Second dimension is 1 or number of tokens (if timestep_per_token) - v_timestep = v_timestep.view(batch_size, -1, v_timestep.shape[-1]) - v_embedded_timestep = v_embedded_timestep.view( - batch_size, -1, v_embedded_timestep.shape[-1] - ) + # Calculate patches_per_frame from orig_shape: [batch, channels, frames, height, width] + # Video tokens are arranged as (frames * height * width), so patches_per_frame = height * width + orig_shape = kwargs.get("orig_shape") + v_patches_per_frame = None + if orig_shape is not None and len(orig_shape) == 5: + # orig_shape[3] = height, orig_shape[4] = width (in latent space) + v_patches_per_frame = orig_shape[3] * orig_shape[4] + + # Reshape to [batch_size, num_tokens, dim] and compress for storage + v_timestep = CompressedTimestep(v_timestep.view(batch_size, -1, v_timestep.shape[-1]), v_patches_per_frame) + v_embedded_timestep = CompressedTimestep(v_embedded_timestep.view(batch_size, -1, v_embedded_timestep.shape[-1]), v_patches_per_frame) # Prepare audio timestep a_timestep = kwargs.get("a_timestep") if a_timestep is not None: - a_timestep = a_timestep * self.timestep_scale_multiplier + a_timestep_scaled = a_timestep * self.timestep_scale_multiplier + a_timestep_flat = a_timestep_scaled.flatten() + timestep_flat = timestep_scaled.flatten() av_ca_factor = self.av_ca_timestep_scale_multiplier / self.timestep_scale_multiplier + # Cross-attention timesteps - compress these too av_ca_audio_scale_shift_timestep, _ = self.av_ca_audio_scale_shift_adaln_single( - a_timestep.flatten(), + a_timestep_flat, {"resolution": None, "aspect_ratio": None}, batch_size=batch_size, hidden_dtype=hidden_dtype, ) av_ca_video_scale_shift_timestep, _ = self.av_ca_video_scale_shift_adaln_single( - timestep.flatten(), + timestep_flat, {"resolution": None, "aspect_ratio": None}, batch_size=batch_size, hidden_dtype=hidden_dtype, ) av_ca_a2v_gate_noise_timestep, _ = self.av_ca_a2v_gate_adaln_single( - timestep.flatten() * av_ca_factor, + timestep_flat * av_ca_factor, {"resolution": None, "aspect_ratio": None}, batch_size=batch_size, hidden_dtype=hidden_dtype, ) av_ca_v2a_gate_noise_timestep, _ = self.av_ca_v2a_gate_adaln_single( - a_timestep.flatten() * av_ca_factor, + a_timestep_flat * av_ca_factor, {"resolution": None, "aspect_ratio": None}, batch_size=batch_size, hidden_dtype=hidden_dtype, ) + # Compress cross-attention timesteps (only video side, audio is too small to benefit) + cross_av_timestep_ss = [ + av_ca_audio_scale_shift_timestep.view(batch_size, -1, av_ca_audio_scale_shift_timestep.shape[-1]), + CompressedTimestep(av_ca_video_scale_shift_timestep.view(batch_size, -1, av_ca_video_scale_shift_timestep.shape[-1]), v_patches_per_frame), # video - compressed + CompressedTimestep(av_ca_a2v_gate_noise_timestep.view(batch_size, -1, av_ca_a2v_gate_noise_timestep.shape[-1]), v_patches_per_frame), # video - compressed + av_ca_v2a_gate_noise_timestep.view(batch_size, -1, av_ca_v2a_gate_noise_timestep.shape[-1]), + ] + a_timestep, a_embedded_timestep = self.audio_adaln_single( - a_timestep.flatten(), + a_timestep_flat, {"resolution": None, "aspect_ratio": None}, batch_size=batch_size, hidden_dtype=hidden_dtype, ) + # Audio timesteps a_timestep = a_timestep.view(batch_size, -1, a_timestep.shape[-1]) - a_embedded_timestep = a_embedded_timestep.view( - batch_size, -1, a_embedded_timestep.shape[-1] - ) - cross_av_timestep_ss = [ - av_ca_audio_scale_shift_timestep, - av_ca_video_scale_shift_timestep, - av_ca_a2v_gate_noise_timestep, - av_ca_v2a_gate_noise_timestep, - ] - cross_av_timestep_ss = list( - [t.view(batch_size, -1, t.shape[-1]) for t in cross_av_timestep_ss] - ) + a_embedded_timestep = a_embedded_timestep.view(batch_size, -1, a_embedded_timestep.shape[-1]) else: - a_timestep = timestep + a_timestep = timestep_scaled a_embedded_timestep = kwargs.get("embedded_timestep") cross_av_timestep_ss = [] @@ -767,6 +838,11 @@ class LTXAVModel(LTXVModel): ax = x[1] v_embedded_timestep = embedded_timestep[0] a_embedded_timestep = embedded_timestep[1] + + # Expand compressed video timestep if needed + if isinstance(v_embedded_timestep, CompressedTimestep): + v_embedded_timestep = v_embedded_timestep.expand() + vx = super()._process_output(vx, v_embedded_timestep, keyframe_idxs, **kwargs) # Process audio output From c2b65e2fceea821276c143ad52478552633922bf Mon Sep 17 00:00:00 2001 From: ComfyUI Wiki Date: Tue, 13 Jan 2026 06:29:25 +0800 Subject: [PATCH 053/119] Update workflow templates to v0.8.0 (#11828) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 6c1cd86d2..890070d5d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.36.13 -comfyui-workflow-templates==0.7.69 +comfyui-workflow-templates==0.8.0 comfyui-embedded-docs==0.4.0 torch torchsde From ecaeeb990d7a5c3820b6f2373d04335d051d6b47 Mon Sep 17 00:00:00 2001 From: ComfyUI Wiki Date: Tue, 13 Jan 2026 11:18:01 +0800 Subject: [PATCH 054/119] chore: update workflow templates to v0.8.4 (#11835) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 890070d5d..077c8930a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.36.13 -comfyui-workflow-templates==0.8.0 +comfyui-workflow-templates==0.8.4 comfyui-embedded-docs==0.4.0 torch torchsde From b3c0e4de57bfd27e3dd94bd9723bb4c714668a09 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 12 Jan 2026 19:33:54 -0800 Subject: [PATCH 055/119] Make loras work on nvfp4 models. (#11837) The initial applying is a bit slow but will probably be sped up in the future. --- comfy/float.py | 113 +++++++++++++++++++++++++++++++++++++++++++++ comfy/ops.py | 2 +- comfy/quant_ops.py | 37 ++++++++++++++- requirements.txt | 2 +- 4 files changed, 150 insertions(+), 4 deletions(-) diff --git a/comfy/float.py b/comfy/float.py index 521316fd2..e638b1ff7 100644 --- a/comfy/float.py +++ b/comfy/float.py @@ -65,3 +65,116 @@ def stochastic_rounding(value, dtype, seed=0): return output return value.to(dtype=dtype) + + +# TODO: improve this? +def stochastic_float_to_fp4_e2m1(x, generator): + sign = torch.signbit(x).to(torch.uint8) + x_abs = x.abs() + + exp = torch.floor(torch.log2(x_abs) + 1.0).clamp(0, 3) + x += (torch.rand(x.size(), dtype=x.dtype, layout=x.layout, device=x.device, generator=generator) - 0.5) * (2 ** (exp - 2.0)) * 1.25 + + x_abs = x.abs() + exp = torch.floor(torch.log2(x_abs) + 1.1925).clamp(0, 3) + + mantissa = torch.where( + exp > 0, + (x_abs / (2.0 ** (exp - 1)) - 1.0) * 2.0, + (x_abs * 2.0) + ).round().to(torch.uint8) + + fp4 = (sign << 3) | (exp.to(torch.uint8) << 1) | mantissa + + fp4_flat = fp4.view(-1) + packed = (fp4_flat[0::2] << 4) | fp4_flat[1::2] + return packed.reshape(list(x.shape)[:-1] + [-1]) + + +def to_blocked(input_matrix, flatten: bool = True) -> torch.Tensor: + """ + Rearrange a large matrix by breaking it into blocks and applying the rearrangement pattern. + See: + https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout + + Args: + input_matrix: Input tensor of shape (H, W) + Returns: + Rearranged tensor of shape (32*ceil_div(H,128), 16*ceil_div(W,4)) + """ + + def ceil_div(a, b): + return (a + b - 1) // b + + rows, cols = input_matrix.shape + n_row_blocks = ceil_div(rows, 128) + n_col_blocks = ceil_div(cols, 4) + + # Calculate the padded shape + padded_rows = n_row_blocks * 128 + padded_cols = n_col_blocks * 4 + + padded = input_matrix + if (rows, cols) != (padded_rows, padded_cols): + padded = torch.zeros( + (padded_rows, padded_cols), + device=input_matrix.device, + dtype=input_matrix.dtype, + ) + padded[:rows, :cols] = input_matrix + + # Rearrange the blocks + blocks = padded.view(n_row_blocks, 128, n_col_blocks, 4).permute(0, 2, 1, 3) + rearranged = blocks.reshape(-1, 4, 32, 4).transpose(1, 2).reshape(-1, 32, 16) + if flatten: + return rearranged.flatten() + + return rearranged.reshape(padded_rows, padded_cols) + + +def stochastic_round_quantize_nvfp4(x, per_tensor_scale, pad_16x, seed=0): + F4_E2M1_MAX = 6.0 + F8_E4M3_MAX = 448.0 + + def roundup(x: int, multiple: int) -> int: + """Round up x to the nearest multiple.""" + return ((x + multiple - 1) // multiple) * multiple + + orig_shape = x.shape + + # Handle padding + if pad_16x: + rows, cols = x.shape + padded_rows = roundup(rows, 16) + padded_cols = roundup(cols, 16) + if padded_rows != rows or padded_cols != cols: + x = torch.nn.functional.pad(x, (0, padded_cols - cols, 0, padded_rows - rows)) + # Note: We update orig_shape because the output tensor logic below assumes x.shape matches + # what we want to produce. If we pad here, we want the padded output. + orig_shape = x.shape + + block_size = 16 + + x = x.reshape(orig_shape[0], -1, block_size) + max_abs = torch.amax(torch.abs(x), dim=-1) + block_scale = max_abs / F4_E2M1_MAX + scaled_block_scales = block_scale / per_tensor_scale.to(block_scale.dtype) + scaled_block_scales_fp8 = torch.clamp(scaled_block_scales, max=F8_E4M3_MAX).to(torch.float8_e4m3fn) + total_scale = per_tensor_scale.to(x.dtype) * scaled_block_scales_fp8.to(x.dtype) + + # Handle zero blocks (from padding): avoid 0/0 NaN + zero_scale_mask = (total_scale == 0) + total_scale_safe = torch.where(zero_scale_mask, torch.ones_like(total_scale), total_scale) + + x = x / total_scale_safe.unsqueeze(-1) + + generator = torch.Generator(device=x.device) + generator.manual_seed(seed) + + x = torch.where(zero_scale_mask.unsqueeze(-1), torch.zeros_like(x), x) + + x = x.view(orig_shape) + data_lp = stochastic_float_to_fp4_e2m1(x, generator=generator) + + blocked_scales = to_blocked(scaled_block_scales_fp8, flatten=False) + return data_lp, blocked_scales diff --git a/comfy/ops.py b/comfy/ops.py index 9c0b54ff4..415c39e92 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -699,7 +699,7 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False, **kwargs): if getattr(self, 'layout_type', None) is not None: # dtype is now implicit in the layout class - weight = QuantizedTensor.from_float(weight, self.layout_type, scale="recalculate", stochastic_rounding=seed, inplace_ops=True) + weight = QuantizedTensor.from_float(weight, self.layout_type, scale="recalculate", stochastic_rounding=seed, inplace_ops=True).to(self.weight.dtype) else: weight = weight.to(self.weight.dtype) if return_weight: diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py index 8324be42a..7a61203c3 100644 --- a/comfy/quant_ops.py +++ b/comfy/quant_ops.py @@ -7,7 +7,7 @@ try: QuantizedTensor, QuantizedLayout, TensorCoreFP8Layout as _CKFp8Layout, - TensorCoreNVFP4Layout, # Direct import, no wrapper needed + TensorCoreNVFP4Layout as _CKNvfp4Layout, register_layout_op, register_layout_class, get_layout_class, @@ -34,7 +34,7 @@ except ImportError as e: class _CKFp8Layout: pass - class TensorCoreNVFP4Layout: + class _CKNvfp4Layout: pass def register_layout_class(name, cls): @@ -84,6 +84,39 @@ class _TensorCoreFP8LayoutBase(_CKFp8Layout): return qdata, params +class TensorCoreNVFP4Layout(_CKNvfp4Layout): + @classmethod + def quantize(cls, tensor, scale=None, stochastic_rounding=0, inplace_ops=False): + if tensor.dim() != 2: + raise ValueError(f"NVFP4 requires 2D tensor, got {tensor.dim()}D") + + orig_dtype = tensor.dtype + orig_shape = tuple(tensor.shape) + + if scale is None or (isinstance(scale, str) and scale == "recalculate"): + scale = torch.amax(tensor.abs()) / (ck.float_utils.F8_E4M3_MAX * ck.float_utils.F4_E2M1_MAX) + + if not isinstance(scale, torch.Tensor): + scale = torch.tensor(scale) + scale = scale.to(device=tensor.device, dtype=torch.float32) + + padded_shape = cls.get_padded_shape(orig_shape) + needs_padding = padded_shape != orig_shape + + if stochastic_rounding > 0: + qdata, block_scale = comfy.float.stochastic_round_quantize_nvfp4(tensor, scale, pad_16x=needs_padding, seed=stochastic_rounding) + else: + qdata, block_scale = ck.quantize_nvfp4(tensor, scale, pad_16x=needs_padding) + + params = cls.Params( + scale=scale, + orig_dtype=orig_dtype, + orig_shape=orig_shape, + block_scale=block_scale, + ) + return qdata, params + + class TensorCoreFP8E4M3Layout(_TensorCoreFP8LayoutBase): FP8_DTYPE = torch.float8_e4m3fn diff --git a/requirements.txt b/requirements.txt index 077c8930a..43737056e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -21,7 +21,7 @@ psutil alembic SQLAlchemy av>=14.2.0 -comfy-kitchen>=0.2.5 +comfy-kitchen>=0.2.6 #non essential dependencies: kornia>=0.7.1 From 117e7a5853cc34b6a012e06bb3efcc79ab314184 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 12 Jan 2026 21:01:52 -0800 Subject: [PATCH 056/119] Refactor to try to lower mem usage. (#11840) --- comfy/float.py | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/comfy/float.py b/comfy/float.py index e638b1ff7..c806af76b 100644 --- a/comfy/float.py +++ b/comfy/float.py @@ -69,26 +69,31 @@ def stochastic_rounding(value, dtype, seed=0): # TODO: improve this? def stochastic_float_to_fp4_e2m1(x, generator): + orig_shape = x.shape sign = torch.signbit(x).to(torch.uint8) - x_abs = x.abs() - exp = torch.floor(torch.log2(x_abs) + 1.0).clamp(0, 3) + exp = torch.floor(torch.log2(x.abs()) + 1.0).clamp(0, 3) x += (torch.rand(x.size(), dtype=x.dtype, layout=x.layout, device=x.device, generator=generator) - 0.5) * (2 ** (exp - 2.0)) * 1.25 - x_abs = x.abs() - exp = torch.floor(torch.log2(x_abs) + 1.1925).clamp(0, 3) + x = x.abs() + exp = torch.floor(torch.log2(x) + 1.1925).clamp(0, 3) mantissa = torch.where( exp > 0, - (x_abs / (2.0 ** (exp - 1)) - 1.0) * 2.0, - (x_abs * 2.0) + (x / (2.0 ** (exp - 1)) - 1.0) * 2.0, + (x * 2.0), + out=x ).round().to(torch.uint8) + del x - fp4 = (sign << 3) | (exp.to(torch.uint8) << 1) | mantissa + exp = exp.to(torch.uint8) + + fp4 = (sign << 3) | (exp << 1) | mantissa + del sign, exp, mantissa fp4_flat = fp4.view(-1) packed = (fp4_flat[0::2] << 4) | fp4_flat[1::2] - return packed.reshape(list(x.shape)[:-1] + [-1]) + return packed.reshape(list(orig_shape)[:-1] + [-1]) def to_blocked(input_matrix, flatten: bool = True) -> torch.Tensor: From acd0e536533cdf038bbaa32730cd12a438cc3a60 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Mon, 12 Jan 2026 21:15:24 -0800 Subject: [PATCH 057/119] Make bulk_ops not use .returning to be compatible with python 3.10 and 3.11 sqlalchemy (#11839) --- app/assets/database/bulk_ops.py | 30 +++++++++++++++++++++++------- 1 file changed, 23 insertions(+), 7 deletions(-) diff --git a/app/assets/database/bulk_ops.py b/app/assets/database/bulk_ops.py index 9352cd65d..c7b75290a 100644 --- a/app/assets/database/bulk_ops.py +++ b/app/assets/database/bulk_ops.py @@ -92,14 +92,23 @@ def seed_from_paths_batch( session.execute(ins_asset, chunk) # try to claim AssetCacheState (file_path) - winners_by_path: set[str] = set() + # Insert with ON CONFLICT DO NOTHING, then query to find which paths were actually inserted ins_state = ( sqlite.insert(AssetCacheState) .on_conflict_do_nothing(index_elements=[AssetCacheState.file_path]) - .returning(AssetCacheState.file_path) ) for chunk in _iter_chunks(state_rows, _rows_per_stmt(3)): - winners_by_path.update((session.execute(ins_state, chunk)).scalars().all()) + session.execute(ins_state, chunk) + + # Query to find which of our paths won (were actually inserted) + winners_by_path: set[str] = set() + for chunk in _iter_chunks(path_list, MAX_BIND_PARAMS): + result = session.execute( + sqlalchemy.select(AssetCacheState.file_path) + .where(AssetCacheState.file_path.in_(chunk)) + .where(AssetCacheState.asset_id.in_([path_to_asset[p] for p in chunk])) + ) + winners_by_path.update(result.scalars().all()) all_paths_set = set(path_list) losers_by_path = all_paths_set - winners_by_path @@ -112,16 +121,23 @@ def seed_from_paths_batch( return {"inserted_infos": 0, "won_states": 0, "lost_states": len(losers_by_path)} # insert AssetInfo only for winners + # Insert with ON CONFLICT DO NOTHING, then query to find which were actually inserted winner_info_rows = [asset_to_info[path_to_asset[p]] for p in winners_by_path] ins_info = ( sqlite.insert(AssetInfo) .on_conflict_do_nothing(index_elements=[AssetInfo.asset_id, AssetInfo.owner_id, AssetInfo.name]) - .returning(AssetInfo.id) ) - - inserted_info_ids: set[str] = set() for chunk in _iter_chunks(winner_info_rows, _rows_per_stmt(9)): - inserted_info_ids.update((session.execute(ins_info, chunk)).scalars().all()) + session.execute(ins_info, chunk) + + # Query to find which info rows were actually inserted (by matching our generated IDs) + all_info_ids = [row["id"] for row in winner_info_rows] + inserted_info_ids: set[str] = set() + for chunk in _iter_chunks(all_info_ids, MAX_BIND_PARAMS): + result = session.execute( + sqlalchemy.select(AssetInfo.id).where(AssetInfo.id.in_(chunk)) + ) + inserted_info_ids.update(result.scalars().all()) # build and insert tag + meta rows for the AssetInfo tag_rows: list[dict] = [] From 8af13b439bddaddb6d3b4f7c50b6391e88a10c66 Mon Sep 17 00:00:00 2001 From: Christian Byrne Date: Mon, 12 Jan 2026 22:22:25 -0800 Subject: [PATCH 058/119] Update requirements.txt (#11841) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 43737056e..8650d28ec 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -comfyui-frontend-package==1.36.13 +comfyui-frontend-package==1.36.14 comfyui-workflow-templates==0.8.4 comfyui-embedded-docs==0.4.0 torch From db9e6edfa1604be0b1f738b2f67495b46cee5a8c Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 13 Jan 2026 01:23:31 -0500 Subject: [PATCH 059/119] ComfyUI v0.9.0 --- comfyui_version.py | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/comfyui_version.py b/comfyui_version.py index df82ed4fc..09def2c70 100644 --- a/comfyui_version.py +++ b/comfyui_version.py @@ -1,3 +1,3 @@ # This file is automatically generated by the build process when version is # updated in pyproject.toml. -__version__ = "0.8.2" +__version__ = "0.9.0" diff --git a/pyproject.toml b/pyproject.toml index 49f1a03fd..17aac8c3f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ComfyUI" -version = "0.8.2" +version = "0.9.0" readme = "README.md" license = { file = "LICENSE" } requires-python = ">=3.10" From 1dcbd9efaf16c74a3aff2770f46de5b4aaaf927e Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 12 Jan 2026 22:42:07 -0800 Subject: [PATCH 060/119] Bump ltxav mem estimation a bit. (#11842) --- comfy/supported_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/supported_models.py b/comfy/supported_models.py index d44c0bc37..1bf54f13f 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -845,7 +845,7 @@ class LTXAV(LTXV): def __init__(self, unet_config): super().__init__(unet_config) - self.memory_usage_factor = 0.061 # TODO + self.memory_usage_factor = 0.077 # TODO def get_model(self, state_dict, prefix="", device=None): out = model_base.LTXAV(self, device=device) From 5ac13725331c1dfdf7aab977d4588b0a06a3debd Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 13 Jan 2026 01:44:06 -0500 Subject: [PATCH 061/119] ComfyUI v0.9.1 --- comfyui_version.py | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/comfyui_version.py b/comfyui_version.py index 09def2c70..0c9871e35 100644 --- a/comfyui_version.py +++ b/comfyui_version.py @@ -1,3 +1,3 @@ # This file is automatically generated by the build process when version is # updated in pyproject.toml. -__version__ = "0.9.0" +__version__ = "0.9.1" diff --git a/pyproject.toml b/pyproject.toml index 17aac8c3f..dc52218b4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ComfyUI" -version = "0.9.0" +version = "0.9.1" readme = "README.md" license = { file = "LICENSE" } requires-python = ">=3.10" From c543ad81c382c8450d2c8de62644c197c3c7416d Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Tue, 13 Jan 2026 18:30:13 +0200 Subject: [PATCH 062/119] fix(api-nodes-gemini): raise exception when no candidates due to safety block (#11848) --- comfy_api_nodes/nodes_gemini.py | 28 +++++++++++++++++++--------- 1 file changed, 19 insertions(+), 9 deletions(-) diff --git a/comfy_api_nodes/nodes_gemini.py b/comfy_api_nodes/nodes_gemini.py index e8ed7e797..35bbf0d2f 100644 --- a/comfy_api_nodes/nodes_gemini.py +++ b/comfy_api_nodes/nodes_gemini.py @@ -130,7 +130,7 @@ def get_parts_by_type(response: GeminiGenerateContentResponse, part_type: Litera Returns: List of response parts matching the requested type. """ - if response.candidates is None: + if not response.candidates: if response.promptFeedback and response.promptFeedback.blockReason: feedback = response.promptFeedback raise ValueError( @@ -141,14 +141,24 @@ def get_parts_by_type(response: GeminiGenerateContentResponse, part_type: Litera "try changing it to `IMAGE+TEXT` to view the model's reasoning and understand why image generation failed." ) parts = [] - for part in response.candidates[0].content.parts: - if part_type == "text" and part.text: - parts.append(part) - elif part.inlineData and part.inlineData.mimeType == part_type: - parts.append(part) - elif part.fileData and part.fileData.mimeType == part_type: - parts.append(part) - # Skip parts that don't match the requested type + blocked_reasons = [] + for candidate in response.candidates: + if candidate.finishReason and candidate.finishReason.upper() == "IMAGE_PROHIBITED_CONTENT": + blocked_reasons.append(candidate.finishReason) + continue + if candidate.content is None or candidate.content.parts is None: + continue + for part in candidate.content.parts: + if part_type == "text" and part.text: + parts.append(part) + elif part.inlineData and part.inlineData.mimeType == part_type: + parts.append(part) + elif part.fileData and part.fileData.mimeType == part_type: + parts.append(part) + + if not parts and blocked_reasons: + raise ValueError(f"Gemini API blocked the request. Reasons: {blocked_reasons}") + return parts From d9dc02a7d602a1918b9dabfc91890e6689f6f16d Mon Sep 17 00:00:00 2001 From: Acly Date: Tue, 13 Jan 2026 21:03:53 +0100 Subject: [PATCH 063/119] Support "lite" version of alibaba-pai Z-Image Controlnet (#11849) * reduced number of control layers (3) compared to full model --- comfy_extras/nodes_model_patch.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/comfy_extras/nodes_model_patch.py b/comfy_extras/nodes_model_patch.py index 1355b3c93..f66d28fc9 100644 --- a/comfy_extras/nodes_model_patch.py +++ b/comfy_extras/nodes_model_patch.py @@ -244,6 +244,10 @@ class ModelPatchLoader: elif 'control_all_x_embedder.2-1.weight' in sd: # alipai z image fun controlnet sd = z_image_convert(sd) config = {} + if 'control_layers.4.adaLN_modulation.0.weight' not in sd: + config['n_control_layers'] = 3 + config['additional_in_dim'] = 17 + config['refiner_control'] = True if 'control_layers.14.adaLN_modulation.0.weight' in sd: config['n_control_layers'] = 15 config['additional_in_dim'] = 17 From e4b4fb34798a4710f670c81ae905ec24d58b6373 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jukka=20Sepp=C3=A4nen?= <40791699+kijai@users.noreply.github.com> Date: Wed, 14 Jan 2026 00:37:21 +0200 Subject: [PATCH 064/119] Load metadata on VAELoader (#11846) Needed to load the proper LTX2 VAE if separated from checkpoint --- nodes.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/nodes.py b/nodes.py index 5a9d42d4a..90c5f2a6e 100644 --- a/nodes.py +++ b/nodes.py @@ -798,8 +798,8 @@ class VAELoader: vae_path = folder_paths.get_full_path_or_raise("vae_approx", vae_name) else: vae_path = folder_paths.get_full_path_or_raise("vae", vae_name) - sd = comfy.utils.load_torch_file(vae_path) - vae = comfy.sd.VAE(sd=sd) + sd, metadata = comfy.utils.load_torch_file(vae_path, return_metadata=True) + vae = comfy.sd.VAE(sd=sd, metadata=metadata) vae.throw_exception_if_invalid() return (vae,) From 79f6bb5e4fca0c2fbd5d09511a65449ca69332a8 Mon Sep 17 00:00:00 2001 From: ric-yu Date: Tue, 13 Jan 2026 16:14:40 -0800 Subject: [PATCH 065/119] add blueprints dir for built-in blueprints (#11853) --- app/subgraph_manager.py | 80 +++++++++++++++++++++------------- blueprints/put_blueprints_here | 0 2 files changed, 50 insertions(+), 30 deletions(-) create mode 100644 blueprints/put_blueprints_here diff --git a/app/subgraph_manager.py b/app/subgraph_manager.py index dbe404541..6a8f586a4 100644 --- a/app/subgraph_manager.py +++ b/app/subgraph_manager.py @@ -10,6 +10,7 @@ import hashlib class Source: custom_node = "custom_node" + templates = "templates" class SubgraphEntry(TypedDict): source: str @@ -38,6 +39,18 @@ class CustomNodeSubgraphEntryInfo(TypedDict): class SubgraphManager: def __init__(self): self.cached_custom_node_subgraphs: dict[SubgraphEntry] | None = None + self.cached_blueprint_subgraphs: dict[SubgraphEntry] | None = None + + def _create_entry(self, file: str, source: str, node_pack: str) -> tuple[str, SubgraphEntry]: + """Create a subgraph entry from a file path. Expects normalized path (forward slashes).""" + entry_id = hashlib.sha256(f"{source}{file}".encode()).hexdigest() + entry: SubgraphEntry = { + "source": source, + "name": os.path.splitext(os.path.basename(file))[0], + "path": file, + "info": {"node_pack": node_pack}, + } + return entry_id, entry async def load_entry_data(self, entry: SubgraphEntry): with open(entry['path'], 'r') as f: @@ -60,53 +73,60 @@ class SubgraphManager: return entries async def get_custom_node_subgraphs(self, loadedModules, force_reload=False): - # if not forced to reload and cached, return cache + """Load subgraphs from custom nodes.""" if not force_reload and self.cached_custom_node_subgraphs is not None: return self.cached_custom_node_subgraphs - # Load subgraphs from custom nodes - subfolder = "subgraphs" - subgraphs_dict: dict[SubgraphEntry] = {} + subgraphs_dict: dict[SubgraphEntry] = {} for folder in folder_paths.get_folder_paths("custom_nodes"): - pattern = os.path.join(folder, f"*/{subfolder}/*.json") - matched_files = glob.glob(pattern) - for file in matched_files: - # replace backslashes with forward slashes + pattern = os.path.join(folder, "*/subgraphs/*.json") + for file in glob.glob(pattern): file = file.replace('\\', '/') - info: CustomNodeSubgraphEntryInfo = { - "node_pack": "custom_nodes." + file.split('/')[-3] - } - source = Source.custom_node - # hash source + path to make sure id will be as unique as possible, but - # reproducible across backend reloads - id = hashlib.sha256(f"{source}{file}".encode()).hexdigest() - entry: SubgraphEntry = { - "source": Source.custom_node, - "name": os.path.splitext(os.path.basename(file))[0], - "path": file, - "info": info, - } - subgraphs_dict[id] = entry + node_pack = "custom_nodes." + file.split('/')[-3] + entry_id, entry = self._create_entry(file, Source.custom_node, node_pack) + subgraphs_dict[entry_id] = entry + self.cached_custom_node_subgraphs = subgraphs_dict return subgraphs_dict - async def get_custom_node_subgraph(self, id: str, loadedModules): - subgraphs = await self.get_custom_node_subgraphs(loadedModules) - entry: SubgraphEntry = subgraphs.get(id, None) - if entry is not None and entry.get('data', None) is None: + async def get_blueprint_subgraphs(self, force_reload=False): + """Load subgraphs from the blueprints directory.""" + if not force_reload and self.cached_blueprint_subgraphs is not None: + return self.cached_blueprint_subgraphs + + subgraphs_dict: dict[SubgraphEntry] = {} + blueprints_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'blueprints') + + if os.path.exists(blueprints_dir): + for file in glob.glob(os.path.join(blueprints_dir, "*.json")): + file = file.replace('\\', '/') + entry_id, entry = self._create_entry(file, Source.templates, "comfyui") + subgraphs_dict[entry_id] = entry + + self.cached_blueprint_subgraphs = subgraphs_dict + return subgraphs_dict + + async def get_all_subgraphs(self, loadedModules, force_reload=False): + """Get all subgraphs from all sources (custom nodes and blueprints).""" + custom_node_subgraphs = await self.get_custom_node_subgraphs(loadedModules, force_reload) + blueprint_subgraphs = await self.get_blueprint_subgraphs(force_reload) + return {**custom_node_subgraphs, **blueprint_subgraphs} + + async def get_subgraph(self, id: str, loadedModules): + """Get a specific subgraph by ID from any source.""" + entry = (await self.get_all_subgraphs(loadedModules)).get(id) + if entry is not None and entry.get('data') is None: await self.load_entry_data(entry) return entry def add_routes(self, routes, loadedModules): @routes.get("/global_subgraphs") async def get_global_subgraphs(request): - subgraphs_dict = await self.get_custom_node_subgraphs(loadedModules) - # NOTE: we may want to include other sources of global subgraphs such as templates in the future; - # that's the reasoning for the current implementation + subgraphs_dict = await self.get_all_subgraphs(loadedModules) return web.json_response(await self.sanitize_entries(subgraphs_dict, remove_data=True)) @routes.get("/global_subgraphs/{id}") async def get_global_subgraph(request): id = request.match_info.get("id", None) - subgraph = await self.get_custom_node_subgraph(id, loadedModules) + subgraph = await self.get_subgraph(id, loadedModules) return web.json_response(await self.sanitize_entry(subgraph)) diff --git a/blueprints/put_blueprints_here b/blueprints/put_blueprints_here new file mode 100644 index 000000000..e69de29bb From 1419047fdbdf26b2311950c041a86fd998a2acbd Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Wed, 14 Jan 2026 02:18:28 +0200 Subject: [PATCH 066/119] [Api Nodes]: Improve Price Badge Declarations (#11582) * api nodes: price badges moved to nodes code * added price badges for 4 more node-packs * added price badges for 10 more node-packs * added new price badges for Omni STD mode * add support for autogrow groups * use full names for "widgets", "inputs" and "groups" * add strict typing for JSONata rules * add price badge for WanReferenceVideoApi node * add support for DynamicCombo * sync price badges changes (https://github.com/Comfy-Org/ComfyUI_frontend/pull/7900) * sync badges for Vidu2 nodes * fixed incorrect price for RecraftCrispUpscaleNode * fixed incorrect price badges for LTXV nodes * fixed price badge for MinimaxHailuoVideoNode * fixed price badges for PixVerse nodes --- comfy_api/latest/_io.py | 79 +++++++++- comfy_api_nodes/nodes_bfl.py | 44 ++++++ comfy_api_nodes/nodes_bytedance.py | 66 +++++++++ comfy_api_nodes/nodes_gemini.py | 40 ++++++ comfy_api_nodes/nodes_ideogram.py | 42 +++++- comfy_api_nodes/nodes_kling.py | 215 ++++++++++++++++++++++++++++ comfy_api_nodes/nodes_ltxv.py | 18 +++ comfy_api_nodes/nodes_luma.py | 76 ++++++++++ comfy_api_nodes/nodes_minimax.py | 20 +++ comfy_api_nodes/nodes_moonvalley.py | 12 ++ comfy_api_nodes/nodes_openai.py | 127 ++++++++++++++++ comfy_api_nodes/nodes_pixverse.py | 30 ++++ comfy_api_nodes/nodes_recraft.py | 32 +++++ comfy_api_nodes/nodes_rodin.py | 12 ++ comfy_api_nodes/nodes_runway.py | 15 ++ comfy_api_nodes/nodes_sora.py | 18 +++ comfy_api_nodes/nodes_stability.py | 31 ++++ comfy_api_nodes/nodes_tripo.py | 164 +++++++++++++++++++++ comfy_api_nodes/nodes_veo2.py | 42 ++++++ comfy_api_nodes/nodes_vidu.py | 100 +++++++++++++ comfy_api_nodes/nodes_wan.py | 43 ++++++ 21 files changed, 1221 insertions(+), 5 deletions(-) diff --git a/comfy_api/latest/_io.py b/comfy_api/latest/_io.py index 50143ff53..e6a0d1821 100644 --- a/comfy_api/latest/_io.py +++ b/comfy_api/latest/_io.py @@ -1225,6 +1225,7 @@ class NodeInfoV1: deprecated: bool=None experimental: bool=None api_node: bool=None + price_badge: dict | None = None @dataclass class NodeInfoV3: @@ -1234,11 +1235,77 @@ class NodeInfoV3: name: str=None display_name: str=None description: str=None + python_module: Any = None category: str=None output_node: bool=None deprecated: bool=None experimental: bool=None api_node: bool=None + price_badge: dict | None = None + + +@dataclass +class PriceBadgeDepends: + widgets: list[str] = field(default_factory=list) + inputs: list[str] = field(default_factory=list) + input_groups: list[str] = field(default_factory=list) + + def validate(self) -> None: + if not isinstance(self.widgets, list) or any(not isinstance(x, str) for x in self.widgets): + raise ValueError("PriceBadgeDepends.widgets must be a list[str].") + if not isinstance(self.inputs, list) or any(not isinstance(x, str) for x in self.inputs): + raise ValueError("PriceBadgeDepends.inputs must be a list[str].") + if not isinstance(self.input_groups, list) or any(not isinstance(x, str) for x in self.input_groups): + raise ValueError("PriceBadgeDepends.input_groups must be a list[str].") + + def as_dict(self, schema_inputs: list["Input"]) -> dict[str, Any]: + # Build lookup: widget_id -> io_type + input_types: dict[str, str] = {} + for inp in schema_inputs: + all_inputs = inp.get_all() + input_types[inp.id] = inp.get_io_type() # First input is always the parent itself + for nested_inp in all_inputs[1:]: + # For DynamicCombo/DynamicSlot, nested inputs are prefixed with parent ID + # to match frontend naming convention (e.g., "should_texture.enable_pbr") + prefixed_id = f"{inp.id}.{nested_inp.id}" + input_types[prefixed_id] = nested_inp.get_io_type() + + # Enrich widgets with type information, raising error for unknown widgets + widgets_data: list[dict[str, str]] = [] + for w in self.widgets: + if w not in input_types: + raise ValueError( + f"PriceBadge depends_on.widgets references unknown widget '{w}'. " + f"Available widgets: {list(input_types.keys())}" + ) + widgets_data.append({"name": w, "type": input_types[w]}) + + return { + "widgets": widgets_data, + "inputs": self.inputs, + "input_groups": self.input_groups, + } + + +@dataclass +class PriceBadge: + expr: str + depends_on: PriceBadgeDepends = field(default_factory=PriceBadgeDepends) + engine: str = field(default="jsonata") + + def validate(self) -> None: + if self.engine != "jsonata": + raise ValueError(f"Unsupported PriceBadge.engine '{self.engine}'. Only 'jsonata' is supported.") + if not isinstance(self.expr, str) or not self.expr.strip(): + raise ValueError("PriceBadge.expr must be a non-empty string.") + self.depends_on.validate() + + def as_dict(self, schema_inputs: list["Input"]) -> dict[str, Any]: + return { + "engine": self.engine, + "depends_on": self.depends_on.as_dict(schema_inputs), + "expr": self.expr, + } @dataclass @@ -1284,6 +1351,8 @@ class Schema: """Flags a node as experimental, informing users that it may change or not work as expected.""" is_api_node: bool=False """Flags a node as an API node. See: https://docs.comfy.org/tutorials/api-nodes/overview.""" + price_badge: PriceBadge | None = None + """Optional client-evaluated pricing badge declaration for this node.""" not_idempotent: bool=False """Flags a node as not idempotent; when True, the node will run and not reuse the cached outputs when identical inputs are provided on a different node in the graph.""" enable_expand: bool=False @@ -1314,6 +1383,8 @@ class Schema: input.validate() for output in self.outputs: output.validate() + if self.price_badge is not None: + self.price_badge.validate() def finalize(self): """Add hidden based on selected schema options, and give outputs without ids default ids.""" @@ -1387,7 +1458,8 @@ class Schema: deprecated=self.is_deprecated, experimental=self.is_experimental, api_node=self.is_api_node, - python_module=getattr(cls, "RELATIVE_PYTHON_MODULE", "nodes") + python_module=getattr(cls, "RELATIVE_PYTHON_MODULE", "nodes"), + price_badge=self.price_badge.as_dict(self.inputs) if self.price_badge is not None else None, ) return info @@ -1419,7 +1491,8 @@ class Schema: deprecated=self.is_deprecated, experimental=self.is_experimental, api_node=self.is_api_node, - python_module=getattr(cls, "RELATIVE_PYTHON_MODULE", "nodes") + python_module=getattr(cls, "RELATIVE_PYTHON_MODULE", "nodes"), + price_badge=self.price_badge.as_dict(self.inputs) if self.price_badge is not None else None, ) return info @@ -1971,4 +2044,6 @@ __all__ = [ "add_to_dict_v3", "V3Data", "ImageCompare", + "PriceBadgeDepends", + "PriceBadge", ] diff --git a/comfy_api_nodes/nodes_bfl.py b/comfy_api_nodes/nodes_bfl.py index ce077d6b3..76021ef7f 100644 --- a/comfy_api_nodes/nodes_bfl.py +++ b/comfy_api_nodes/nodes_bfl.py @@ -97,6 +97,9 @@ class FluxProUltraImageNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.06}""", + ), ) @classmethod @@ -352,6 +355,9 @@ class FluxProExpandNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.05}""", + ), ) @classmethod @@ -458,6 +464,9 @@ class FluxProFillNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.05}""", + ), ) @classmethod @@ -511,6 +520,21 @@ class Flux2ProImageNode(IO.ComfyNode): NODE_ID = "Flux2ProImageNode" DISPLAY_NAME = "Flux.2 [pro] Image" API_ENDPOINT = "/proxy/bfl/flux-2-pro/generate" + PRICE_BADGE_EXPR = """ + ( + $MP := 1024 * 1024; + $outMP := $max([1, $floor(((widgets.width * widgets.height) + $MP - 1) / $MP)]); + $outputCost := 0.03 + 0.015 * ($outMP - 1); + inputs.images.connected + ? { + "type":"range_usd", + "min_usd": $outputCost + 0.015, + "max_usd": $outputCost + 0.12, + "format": { "approximate": true } + } + : {"type":"usd","usd": $outputCost} + ) + """ @classmethod def define_schema(cls) -> IO.Schema: @@ -563,6 +587,10 @@ class Flux2ProImageNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["width", "height"], inputs=["images"]), + expr=cls.PRICE_BADGE_EXPR, + ), ) @classmethod @@ -623,6 +651,22 @@ class Flux2MaxImageNode(Flux2ProImageNode): NODE_ID = "Flux2MaxImageNode" DISPLAY_NAME = "Flux.2 [max] Image" API_ENDPOINT = "/proxy/bfl/flux-2-max/generate" + PRICE_BADGE_EXPR = """ + ( + $MP := 1024 * 1024; + $outMP := $max([1, $floor(((widgets.width * widgets.height) + $MP - 1) / $MP)]); + $outputCost := 0.07 + 0.03 * ($outMP - 1); + + inputs.images.connected + ? { + "type":"range_usd", + "min_usd": $outputCost + 0.03, + "max_usd": $outputCost + 0.24, + "format": { "approximate": true } + } + : {"type":"usd","usd": $outputCost} + ) + """ class BFLExtension(ComfyExtension): diff --git a/comfy_api_nodes/nodes_bytedance.py b/comfy_api_nodes/nodes_bytedance.py index d4a2cfae6..f09a4a0ed 100644 --- a/comfy_api_nodes/nodes_bytedance.py +++ b/comfy_api_nodes/nodes_bytedance.py @@ -126,6 +126,9 @@ class ByteDanceImageNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.03}""", + ), ) @classmethod @@ -367,6 +370,19 @@ class ByteDanceSeedreamNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["model"]), + expr=""" + ( + $price := $contains(widgets.model, "seedream-4-5-251128") ? 0.04 : 0.03; + { + "type":"usd", + "usd": $price, + "format": { "suffix":" x images/Run", "approximate": true } + } + ) + """, + ), ) @classmethod @@ -522,6 +538,7 @@ class ByteDanceTextToVideoNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=PRICE_BADGE_VIDEO, ) @classmethod @@ -632,6 +649,7 @@ class ByteDanceImageToVideoNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=PRICE_BADGE_VIDEO, ) @classmethod @@ -754,6 +772,7 @@ class ByteDanceFirstLastFrameNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=PRICE_BADGE_VIDEO, ) @classmethod @@ -877,6 +896,7 @@ class ByteDanceImageReferenceNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=PRICE_BADGE_VIDEO, ) @classmethod @@ -946,6 +966,52 @@ def raise_if_text_params(prompt: str, text_params: list[str]) -> None: ) +PRICE_BADGE_VIDEO = IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["model", "duration", "resolution"]), + expr=""" + ( + $priceByModel := { + "seedance-1-0-pro": { + "480p":[0.23,0.24], + "720p":[0.51,0.56], + "1080p":[1.18,1.22] + }, + "seedance-1-0-pro-fast": { + "480p":[0.09,0.1], + "720p":[0.21,0.23], + "1080p":[0.47,0.49] + }, + "seedance-1-0-lite": { + "480p":[0.17,0.18], + "720p":[0.37,0.41], + "1080p":[0.85,0.88] + } + }; + $model := widgets.model; + $modelKey := + $contains($model, "seedance-1-0-pro-fast") ? "seedance-1-0-pro-fast" : + $contains($model, "seedance-1-0-pro") ? "seedance-1-0-pro" : + "seedance-1-0-lite"; + $resolution := widgets.resolution; + $resKey := + $contains($resolution, "1080") ? "1080p" : + $contains($resolution, "720") ? "720p" : + "480p"; + $modelPrices := $lookup($priceByModel, $modelKey); + $baseRange := $lookup($modelPrices, $resKey); + $min10s := $baseRange[0]; + $max10s := $baseRange[1]; + $scale := widgets.duration / 10; + $minCost := $min10s * $scale; + $maxCost := $max10s * $scale; + ($minCost = $maxCost) + ? {"type":"usd","usd": $minCost} + : {"type":"range_usd","min_usd": $minCost, "max_usd": $maxCost} + ) + """, +) + + class ByteDanceExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[IO.ComfyNode]]: diff --git a/comfy_api_nodes/nodes_gemini.py b/comfy_api_nodes/nodes_gemini.py index 35bbf0d2f..a2daea50a 100644 --- a/comfy_api_nodes/nodes_gemini.py +++ b/comfy_api_nodes/nodes_gemini.py @@ -319,6 +319,30 @@ class GeminiNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["model"]), + expr=""" + ( + $m := widgets.model; + $contains($m, "gemini-2.5-flash") ? { + "type": "list_usd", + "usd": [0.0003, 0.0025], + "format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens"} + } + : $contains($m, "gemini-2.5-pro") ? { + "type": "list_usd", + "usd": [0.00125, 0.01], + "format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" } + } + : $contains($m, "gemini-3-pro-preview") ? { + "type": "list_usd", + "usd": [0.002, 0.012], + "format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" } + } + : {"type":"text", "text":"Token-based"} + ) + """, + ), ) @classmethod @@ -580,6 +604,9 @@ class GeminiImage(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.039,"format":{"suffix":"/Image (1K)","approximate":true}}""", + ), ) @classmethod @@ -710,6 +737,19 @@ class GeminiImage2(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["resolution"]), + expr=""" + ( + $r := widgets.resolution; + ($contains($r,"1k") or $contains($r,"2k")) + ? {"type":"usd","usd":0.134,"format":{"suffix":"/Image","approximate":true}} + : $contains($r,"4k") + ? {"type":"usd","usd":0.24,"format":{"suffix":"/Image","approximate":true}} + : {"type":"text","text":"Token-based"} + ) + """, + ), ) @classmethod diff --git a/comfy_api_nodes/nodes_ideogram.py b/comfy_api_nodes/nodes_ideogram.py index 48f94e612..827b3523a 100644 --- a/comfy_api_nodes/nodes_ideogram.py +++ b/comfy_api_nodes/nodes_ideogram.py @@ -236,7 +236,6 @@ class IdeogramV1(IO.ComfyNode): display_name="Ideogram V1", category="api node/image/Ideogram", description="Generates images using the Ideogram V1 model.", - is_api_node=True, inputs=[ IO.String.Input( "prompt", @@ -298,6 +297,17 @@ class IdeogramV1(IO.ComfyNode): IO.Hidden.api_key_comfy_org, IO.Hidden.unique_id, ], + is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["num_images", "turbo"]), + expr=""" + ( + $n := widgets.num_images; + $base := (widgets.turbo = true) ? 0.0286 : 0.0858; + {"type":"usd","usd": $round($base * $n, 2)} + ) + """, + ), ) @classmethod @@ -351,7 +361,6 @@ class IdeogramV2(IO.ComfyNode): display_name="Ideogram V2", category="api node/image/Ideogram", description="Generates images using the Ideogram V2 model.", - is_api_node=True, inputs=[ IO.String.Input( "prompt", @@ -436,6 +445,17 @@ class IdeogramV2(IO.ComfyNode): IO.Hidden.api_key_comfy_org, IO.Hidden.unique_id, ], + is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["num_images", "turbo"]), + expr=""" + ( + $n := widgets.num_images; + $base := (widgets.turbo = true) ? 0.0715 : 0.1144; + {"type":"usd","usd": $round($base * $n, 2)} + ) + """, + ), ) @classmethod @@ -506,7 +526,6 @@ class IdeogramV3(IO.ComfyNode): category="api node/image/Ideogram", description="Generates images using the Ideogram V3 model. " "Supports both regular image generation from text prompts and image editing with mask.", - is_api_node=True, inputs=[ IO.String.Input( "prompt", @@ -591,6 +610,23 @@ class IdeogramV3(IO.ComfyNode): IO.Hidden.api_key_comfy_org, IO.Hidden.unique_id, ], + is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["rendering_speed", "num_images"], inputs=["character_image"]), + expr=""" + ( + $n := widgets.num_images; + $speed := widgets.rendering_speed; + $hasChar := inputs.character_image.connected; + $base := + $contains($speed,"quality") ? ($hasChar ? 0.286 : 0.1287) : + $contains($speed,"default") ? ($hasChar ? 0.2145 : 0.0858) : + $contains($speed,"turbo") ? ($hasChar ? 0.143 : 0.0429) : + 0.0858; + {"type":"usd","usd": $round($base * $n, 2)} + ) + """, + ), ) @classmethod diff --git a/comfy_api_nodes/nodes_kling.py b/comfy_api_nodes/nodes_kling.py index 01d9c34f5..05dde88b1 100644 --- a/comfy_api_nodes/nodes_kling.py +++ b/comfy_api_nodes/nodes_kling.py @@ -764,6 +764,33 @@ class KlingTextToVideoNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["mode"]), + expr=""" + ( + $m := widgets.mode; + $contains($m,"v2-5-turbo") + ? ($contains($m,"10") ? {"type":"usd","usd":0.7} : {"type":"usd","usd":0.35}) + : $contains($m,"v2-1-master") + ? ($contains($m,"10s") ? {"type":"usd","usd":2.8} : {"type":"usd","usd":1.4}) + : $contains($m,"v2-master") + ? ($contains($m,"10s") ? {"type":"usd","usd":2.8} : {"type":"usd","usd":1.4}) + : $contains($m,"v1-6") + ? ( + $contains($m,"pro") + ? ($contains($m,"10s") ? {"type":"usd","usd":0.98} : {"type":"usd","usd":0.49}) + : ($contains($m,"10s") ? {"type":"usd","usd":0.56} : {"type":"usd","usd":0.28}) + ) + : $contains($m,"v1") + ? ( + $contains($m,"pro") + ? ($contains($m,"10s") ? {"type":"usd","usd":0.98} : {"type":"usd","usd":0.49}) + : ($contains($m,"10s") ? {"type":"usd","usd":0.28} : {"type":"usd","usd":0.14}) + ) + : {"type":"usd","usd":0.14} + ) + """, + ), ) @classmethod @@ -818,6 +845,16 @@ class OmniProTextToVideoNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["duration", "resolution"]), + expr=""" + ( + $mode := (widgets.resolution = "720p") ? "std" : "pro"; + $rates := {"std": 0.084, "pro": 0.112}; + {"type":"usd","usd": $lookup($rates, $mode) * widgets.duration} + ) + """, + ), ) @classmethod @@ -886,6 +923,16 @@ class OmniProFirstLastFrameNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["duration", "resolution"]), + expr=""" + ( + $mode := (widgets.resolution = "720p") ? "std" : "pro"; + $rates := {"std": 0.084, "pro": 0.112}; + {"type":"usd","usd": $lookup($rates, $mode) * widgets.duration} + ) + """, + ), ) @classmethod @@ -981,6 +1028,16 @@ class OmniProImageToVideoNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["duration", "resolution"]), + expr=""" + ( + $mode := (widgets.resolution = "720p") ? "std" : "pro"; + $rates := {"std": 0.084, "pro": 0.112}; + {"type":"usd","usd": $lookup($rates, $mode) * widgets.duration} + ) + """, + ), ) @classmethod @@ -1056,6 +1113,16 @@ class OmniProVideoToVideoNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["duration", "resolution"]), + expr=""" + ( + $mode := (widgets.resolution = "720p") ? "std" : "pro"; + $rates := {"std": 0.126, "pro": 0.168}; + {"type":"usd","usd": $lookup($rates, $mode) * widgets.duration} + ) + """, + ), ) @classmethod @@ -1142,6 +1209,16 @@ class OmniProEditVideoNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["resolution"]), + expr=""" + ( + $mode := (widgets.resolution = "720p") ? "std" : "pro"; + $rates := {"std": 0.126, "pro": 0.168}; + {"type":"usd","usd": $lookup($rates, $mode), "format":{"suffix":"/second"}} + ) + """, + ), ) @classmethod @@ -1228,6 +1305,9 @@ class OmniProImageNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.028}""", + ), ) @classmethod @@ -1313,6 +1393,9 @@ class KlingCameraControlT2VNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.14}""", + ), ) @classmethod @@ -1375,6 +1458,33 @@ class KlingImage2VideoNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["mode", "model_name", "duration"]), + expr=""" + ( + $mode := widgets.mode; + $model := widgets.model_name; + $dur := widgets.duration; + $contains($model,"v2-5-turbo") + ? ($contains($dur,"10") ? {"type":"usd","usd":0.7} : {"type":"usd","usd":0.35}) + : ($contains($model,"v2-1-master") or $contains($model,"v2-master")) + ? ($contains($dur,"10") ? {"type":"usd","usd":2.8} : {"type":"usd","usd":1.4}) + : ($contains($model,"v2-1") or $contains($model,"v1-6") or $contains($model,"v1-5")) + ? ( + $contains($mode,"pro") + ? ($contains($dur,"10") ? {"type":"usd","usd":0.98} : {"type":"usd","usd":0.49}) + : ($contains($dur,"10") ? {"type":"usd","usd":0.56} : {"type":"usd","usd":0.28}) + ) + : $contains($model,"v1") + ? ( + $contains($mode,"pro") + ? ($contains($dur,"10") ? {"type":"usd","usd":0.98} : {"type":"usd","usd":0.49}) + : ($contains($dur,"10") ? {"type":"usd","usd":0.28} : {"type":"usd","usd":0.14}) + ) + : {"type":"usd","usd":0.14} + ) + """, + ), ) @classmethod @@ -1448,6 +1558,9 @@ class KlingCameraControlI2VNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.49}""", + ), ) @classmethod @@ -1518,6 +1631,33 @@ class KlingStartEndFrameNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["mode"]), + expr=""" + ( + $m := widgets.mode; + $contains($m,"v2-5-turbo") + ? ($contains($m,"10") ? {"type":"usd","usd":0.7} : {"type":"usd","usd":0.35}) + : $contains($m,"v2-1") + ? ($contains($m,"10s") ? {"type":"usd","usd":0.98} : {"type":"usd","usd":0.49}) + : $contains($m,"v2-master") + ? ($contains($m,"10s") ? {"type":"usd","usd":2.8} : {"type":"usd","usd":1.4}) + : $contains($m,"v1-6") + ? ( + $contains($m,"pro") + ? ($contains($m,"10s") ? {"type":"usd","usd":0.98} : {"type":"usd","usd":0.49}) + : ($contains($m,"10s") ? {"type":"usd","usd":0.56} : {"type":"usd","usd":0.28}) + ) + : $contains($m,"v1") + ? ( + $contains($m,"pro") + ? ($contains($m,"10s") ? {"type":"usd","usd":0.98} : {"type":"usd","usd":0.49}) + : ($contains($m,"10s") ? {"type":"usd","usd":0.28} : {"type":"usd","usd":0.14}) + ) + : {"type":"usd","usd":0.14} + ) + """, + ), ) @classmethod @@ -1583,6 +1723,9 @@ class KlingVideoExtendNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.28}""", + ), ) @classmethod @@ -1664,6 +1807,29 @@ class KlingDualCharacterVideoEffectNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["mode", "model_name", "duration"]), + expr=""" + ( + $mode := widgets.mode; + $model := widgets.model_name; + $dur := widgets.duration; + ($contains($model,"v1-6") or $contains($model,"v1-5")) + ? ( + $contains($mode,"pro") + ? ($contains($dur,"10") ? {"type":"usd","usd":0.98} : {"type":"usd","usd":0.49}) + : ($contains($dur,"10") ? {"type":"usd","usd":0.56} : {"type":"usd","usd":0.28}) + ) + : $contains($model,"v1") + ? ( + $contains($mode,"pro") + ? ($contains($dur,"10") ? {"type":"usd","usd":0.98} : {"type":"usd","usd":0.49}) + : ($contains($dur,"10") ? {"type":"usd","usd":0.28} : {"type":"usd","usd":0.14}) + ) + : {"type":"usd","usd":0.14} + ) + """, + ), ) @classmethod @@ -1728,6 +1894,16 @@ class KlingSingleImageVideoEffectNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["effect_scene"]), + expr=""" + ( + ($contains(widgets.effect_scene,"dizzydizzy") or $contains(widgets.effect_scene,"bloombloom")) + ? {"type":"usd","usd":0.49} + : {"type":"usd","usd":0.28} + ) + """, + ), ) @classmethod @@ -1782,6 +1958,9 @@ class KlingLipSyncAudioToVideoNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.1,"format":{"approximate":true}}""", + ), ) @classmethod @@ -1842,6 +2021,9 @@ class KlingLipSyncTextToVideoNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.1,"format":{"approximate":true}}""", + ), ) @classmethod @@ -1892,6 +2074,9 @@ class KlingVirtualTryOnNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.7}""", + ), ) @classmethod @@ -1991,6 +2176,19 @@ class KlingImageGenerationNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["model_name", "n"], inputs=["image"]), + expr=""" + ( + $m := widgets.model_name; + $base := + $contains($m,"kling-v1-5") + ? (inputs.image.connected ? 0.028 : 0.014) + : ($contains($m,"kling-v1") ? 0.0035 : 0.014); + {"type":"usd","usd": $base * widgets.n} + ) + """, + ), ) @classmethod @@ -2074,6 +2272,10 @@ class TextToVideoWithAudio(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["duration", "generate_audio"]), + expr="""{"type":"usd","usd": 0.07 * widgets.duration * (widgets.generate_audio ? 2 : 1)}""", + ), ) @classmethod @@ -2138,6 +2340,10 @@ class ImageToVideoWithAudio(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["duration", "generate_audio"]), + expr="""{"type":"usd","usd": 0.07 * widgets.duration * (widgets.generate_audio ? 2 : 1)}""", + ), ) @classmethod @@ -2218,6 +2424,15 @@ class MotionControl(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["mode"]), + expr=""" + ( + $prices := {"std": 0.07, "pro": 0.112}; + {"type":"usd","usd": $lookup($prices, widgets.mode), "format":{"suffix":"/second"}} + ) + """, + ), ) @classmethod diff --git a/comfy_api_nodes/nodes_ltxv.py b/comfy_api_nodes/nodes_ltxv.py index 7e61560dc..c6424af92 100644 --- a/comfy_api_nodes/nodes_ltxv.py +++ b/comfy_api_nodes/nodes_ltxv.py @@ -28,6 +28,22 @@ class ExecuteTaskRequest(BaseModel): image_uri: str | None = Field(None) +PRICE_BADGE = IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["model", "duration", "resolution"]), + expr=""" + ( + $prices := { + "ltx-2 (pro)": {"1920x1080":0.06,"2560x1440":0.12,"3840x2160":0.24}, + "ltx-2 (fast)": {"1920x1080":0.04,"2560x1440":0.08,"3840x2160":0.16} + }; + $modelPrices := $lookup($prices, $lowercase(widgets.model)); + $pps := $lookup($modelPrices, widgets.resolution); + {"type":"usd","usd": $pps * widgets.duration} + ) + """, +) + + class TextToVideoNode(IO.ComfyNode): @classmethod def define_schema(cls): @@ -69,6 +85,7 @@ class TextToVideoNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=PRICE_BADGE, ) @classmethod @@ -145,6 +162,7 @@ class ImageToVideoNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=PRICE_BADGE, ) @classmethod diff --git a/comfy_api_nodes/nodes_luma.py b/comfy_api_nodes/nodes_luma.py index 894f2b08c..95cb442e5 100644 --- a/comfy_api_nodes/nodes_luma.py +++ b/comfy_api_nodes/nodes_luma.py @@ -189,6 +189,19 @@ class LumaImageGenerationNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["model"]), + expr=""" + ( + $m := widgets.model; + $contains($m,"photon-flash-1") + ? {"type":"usd","usd":0.0027} + : $contains($m,"photon-1") + ? {"type":"usd","usd":0.0104} + : {"type":"usd","usd":0.0246} + ) + """, + ), ) @classmethod @@ -303,6 +316,19 @@ class LumaImageModifyNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["model"]), + expr=""" + ( + $m := widgets.model; + $contains($m,"photon-flash-1") + ? {"type":"usd","usd":0.0027} + : $contains($m,"photon-1") + ? {"type":"usd","usd":0.0104} + : {"type":"usd","usd":0.0246} + ) + """, + ), ) @classmethod @@ -395,6 +421,7 @@ class LumaTextToVideoGenerationNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=PRICE_BADGE_VIDEO, ) @classmethod @@ -505,6 +532,8 @@ class LumaImageToVideoGenerationNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=PRICE_BADGE_VIDEO, + ) @classmethod @@ -568,6 +597,53 @@ class LumaImageToVideoGenerationNode(IO.ComfyNode): return LumaKeyframes(frame0=frame0, frame1=frame1) +PRICE_BADGE_VIDEO = IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["model", "resolution", "duration"]), + expr=""" + ( + $p := { + "ray-flash-2": { + "5s": {"4k":3.13,"1080p":0.79,"720p":0.34,"540p":0.2}, + "9s": {"4k":5.65,"1080p":1.42,"720p":0.61,"540p":0.36} + }, + "ray-2": { + "5s": {"4k":9.11,"1080p":2.27,"720p":1.02,"540p":0.57}, + "9s": {"4k":16.4,"1080p":4.1,"720p":1.83,"540p":1.03} + } + }; + + $m := widgets.model; + $d := widgets.duration; + $r := widgets.resolution; + + $modelKey := + $contains($m,"ray-flash-2") ? "ray-flash-2" : + $contains($m,"ray-2") ? "ray-2" : + $contains($m,"ray-1-6") ? "ray-1-6" : + "other"; + + $durKey := $contains($d,"5s") ? "5s" : $contains($d,"9s") ? "9s" : ""; + $resKey := + $contains($r,"4k") ? "4k" : + $contains($r,"1080p") ? "1080p" : + $contains($r,"720p") ? "720p" : + $contains($r,"540p") ? "540p" : ""; + + $modelPrices := $lookup($p, $modelKey); + $durPrices := $lookup($modelPrices, $durKey); + $v := $lookup($durPrices, $resKey); + + $price := + ($modelKey = "ray-1-6") ? 0.5 : + ($modelKey = "other") ? 0.79 : + ($exists($v) ? $v : 0.79); + + {"type":"usd","usd": $price} + ) + """, +) + + class LumaExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[IO.ComfyNode]]: diff --git a/comfy_api_nodes/nodes_minimax.py b/comfy_api_nodes/nodes_minimax.py index 05cbb700f..43a15d50d 100644 --- a/comfy_api_nodes/nodes_minimax.py +++ b/comfy_api_nodes/nodes_minimax.py @@ -134,6 +134,9 @@ class MinimaxTextToVideoNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.43}""", + ), ) @classmethod @@ -197,6 +200,9 @@ class MinimaxImageToVideoNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.43}""", + ), ) @classmethod @@ -340,6 +346,20 @@ class MinimaxHailuoVideoNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["resolution", "duration"]), + expr=""" + ( + $prices := { + "768p": {"6": 0.28, "10": 0.56}, + "1080p": {"6": 0.49} + }; + $resPrices := $lookup($prices, $lowercase(widgets.resolution)); + $price := $lookup($resPrices, $string(widgets.duration)); + {"type":"usd","usd": $price ? $price : 0.43} + ) + """, + ), ) @classmethod diff --git a/comfy_api_nodes/nodes_moonvalley.py b/comfy_api_nodes/nodes_moonvalley.py index 2771e4790..769b171b7 100644 --- a/comfy_api_nodes/nodes_moonvalley.py +++ b/comfy_api_nodes/nodes_moonvalley.py @@ -233,6 +233,10 @@ class MoonvalleyImg2VideoNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(), + expr="""{"type":"usd","usd": 1.5}""", + ), ) @classmethod @@ -351,6 +355,10 @@ class MoonvalleyVideo2VideoNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(), + expr="""{"type":"usd","usd": 2.25}""", + ), ) @classmethod @@ -471,6 +479,10 @@ class MoonvalleyTxt2VideoNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(), + expr="""{"type":"usd","usd": 1.5}""", + ), ) @classmethod diff --git a/comfy_api_nodes/nodes_openai.py b/comfy_api_nodes/nodes_openai.py index a6205a34f..2f144c5c3 100644 --- a/comfy_api_nodes/nodes_openai.py +++ b/comfy_api_nodes/nodes_openai.py @@ -160,6 +160,23 @@ class OpenAIDalle2(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["size", "n"]), + expr=""" + ( + $size := widgets.size; + $nRaw := widgets.n; + $n := ($nRaw != null and $nRaw != 0) ? $nRaw : 1; + + $base := + $contains($size, "256x256") ? 0.016 : + $contains($size, "512x512") ? 0.018 : + 0.02; + + {"type":"usd","usd": $round($base * $n, 3)} + ) + """, + ), ) @classmethod @@ -287,6 +304,25 @@ class OpenAIDalle3(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["size", "quality"]), + expr=""" + ( + $size := widgets.size; + $q := widgets.quality; + $hd := $contains($q, "hd"); + + $price := + $contains($size, "1024x1024") + ? ($hd ? 0.08 : 0.04) + : (($contains($size, "1792x1024") or $contains($size, "1024x1792")) + ? ($hd ? 0.12 : 0.08) + : 0.04); + + {"type":"usd","usd": $price} + ) + """, + ), ) @classmethod @@ -411,6 +447,28 @@ class OpenAIGPTImage1(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["quality", "n"]), + expr=""" + ( + $ranges := { + "low": [0.011, 0.02], + "medium": [0.046, 0.07], + "high": [0.167, 0.3] + }; + $range := $lookup($ranges, widgets.quality); + $n := widgets.n; + ($n = 1) + ? {"type":"range_usd","min_usd": $range[0], "max_usd": $range[1]} + : { + "type":"range_usd", + "min_usd": $range[0], + "max_usd": $range[1], + "format": { "suffix": " x " & $string($n) & "/Run" } + } + ) + """, + ), ) @classmethod @@ -566,6 +624,75 @@ class OpenAIChatNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["model"]), + expr=""" + ( + $m := widgets.model; + $contains($m, "o4-mini") ? { + "type": "list_usd", + "usd": [0.0011, 0.0044], + "format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" } + } + : $contains($m, "o1-pro") ? { + "type": "list_usd", + "usd": [0.15, 0.6], + "format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" } + } + : $contains($m, "o1") ? { + "type": "list_usd", + "usd": [0.015, 0.06], + "format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" } + } + : $contains($m, "o3-mini") ? { + "type": "list_usd", + "usd": [0.0011, 0.0044], + "format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" } + } + : $contains($m, "o3") ? { + "type": "list_usd", + "usd": [0.01, 0.04], + "format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" } + } + : $contains($m, "gpt-4o") ? { + "type": "list_usd", + "usd": [0.0025, 0.01], + "format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" } + } + : $contains($m, "gpt-4.1-nano") ? { + "type": "list_usd", + "usd": [0.0001, 0.0004], + "format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" } + } + : $contains($m, "gpt-4.1-mini") ? { + "type": "list_usd", + "usd": [0.0004, 0.0016], + "format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" } + } + : $contains($m, "gpt-4.1") ? { + "type": "list_usd", + "usd": [0.002, 0.008], + "format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" } + } + : $contains($m, "gpt-5-nano") ? { + "type": "list_usd", + "usd": [0.00005, 0.0004], + "format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" } + } + : $contains($m, "gpt-5-mini") ? { + "type": "list_usd", + "usd": [0.00025, 0.002], + "format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" } + } + : $contains($m, "gpt-5") ? { + "type": "list_usd", + "usd": [0.00125, 0.01], + "format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" } + } + : {"type": "text", "text": "Token-based"} + ) + """, + ), ) @classmethod diff --git a/comfy_api_nodes/nodes_pixverse.py b/comfy_api_nodes/nodes_pixverse.py index 6e1686af0..86ddb3ab9 100644 --- a/comfy_api_nodes/nodes_pixverse.py +++ b/comfy_api_nodes/nodes_pixverse.py @@ -128,6 +128,7 @@ class PixverseTextToVideoNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=PRICE_BADGE_VIDEO, ) @classmethod @@ -242,6 +243,7 @@ class PixverseImageToVideoNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=PRICE_BADGE_VIDEO, ) @classmethod @@ -355,6 +357,7 @@ class PixverseTransitionVideoNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=PRICE_BADGE_VIDEO, ) @classmethod @@ -416,6 +419,33 @@ class PixverseTransitionVideoNode(IO.ComfyNode): return IO.NodeOutput(await download_url_to_video_output(response_poll.Resp.url)) +PRICE_BADGE_VIDEO = IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["duration_seconds", "quality", "motion_mode"]), + expr=""" + ( + $prices := { + "5": { + "1080p": {"normal": 1.2, "fast": 1.2}, + "720p": {"normal": 0.6, "fast": 1.2}, + "540p": {"normal": 0.45, "fast": 0.9}, + "360p": {"normal": 0.45, "fast": 0.9} + }, + "8": { + "1080p": {"normal": 1.2, "fast": 1.2}, + "720p": {"normal": 1.2, "fast": 1.2}, + "540p": {"normal": 0.9, "fast": 1.2}, + "360p": {"normal": 0.9, "fast": 1.2} + } + }; + $durPrices := $lookup($prices, $string(widgets.duration_seconds)); + $qualityPrices := $lookup($durPrices, widgets.quality); + $price := $lookup($qualityPrices, widgets.motion_mode); + {"type":"usd","usd": $price ? $price : 0.9} + ) + """, +) + + class PixVerseExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[IO.ComfyNode]]: diff --git a/comfy_api_nodes/nodes_recraft.py b/comfy_api_nodes/nodes_recraft.py index e3440b946..05dc151ad 100644 --- a/comfy_api_nodes/nodes_recraft.py +++ b/comfy_api_nodes/nodes_recraft.py @@ -378,6 +378,10 @@ class RecraftTextToImageNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["n"]), + expr="""{"type":"usd","usd": $round(0.04 * widgets.n, 2)}""", + ), ) @classmethod @@ -490,6 +494,10 @@ class RecraftImageToImageNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["n"]), + expr="""{"type":"usd","usd": $round(0.04 * widgets.n, 2)}""", + ), ) @classmethod @@ -591,6 +599,10 @@ class RecraftImageInpaintingNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["n"]), + expr="""{"type":"usd","usd": $round(0.04 * widgets.n, 2)}""", + ), ) @classmethod @@ -692,6 +704,10 @@ class RecraftTextToVectorNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["n"]), + expr="""{"type":"usd","usd": $round(0.08 * widgets.n, 2)}""", + ), ) @classmethod @@ -759,6 +775,10 @@ class RecraftVectorizeImageNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(), + expr="""{"type":"usd","usd": 0.01}""", + ), ) @classmethod @@ -817,6 +837,9 @@ class RecraftReplaceBackgroundNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.04}""", + ), ) @classmethod @@ -883,6 +906,9 @@ class RecraftRemoveBackgroundNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.01}""", + ), ) @classmethod @@ -929,6 +955,9 @@ class RecraftCrispUpscaleNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.004}""", + ), ) @classmethod @@ -972,6 +1001,9 @@ class RecraftCreativeUpscaleNode(RecraftCrispUpscaleNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.25}""", + ), ) diff --git a/comfy_api_nodes/nodes_rodin.py b/comfy_api_nodes/nodes_rodin.py index e60e7a6d6..b4420cb93 100644 --- a/comfy_api_nodes/nodes_rodin.py +++ b/comfy_api_nodes/nodes_rodin.py @@ -241,6 +241,9 @@ class Rodin3D_Regular(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.4}""", + ), ) @classmethod @@ -294,6 +297,9 @@ class Rodin3D_Detail(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.4}""", + ), ) @classmethod @@ -347,6 +353,9 @@ class Rodin3D_Smooth(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.4}""", + ), ) @classmethod @@ -406,6 +415,9 @@ class Rodin3D_Sketch(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.4}""", + ), ) @classmethod diff --git a/comfy_api_nodes/nodes_runway.py b/comfy_api_nodes/nodes_runway.py index 3c55039c9..d19fdb365 100644 --- a/comfy_api_nodes/nodes_runway.py +++ b/comfy_api_nodes/nodes_runway.py @@ -184,6 +184,10 @@ class RunwayImageToVideoNodeGen3a(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["duration"]), + expr="""{"type":"usd","usd": 0.0715 * widgets.duration}""", + ), ) @classmethod @@ -274,6 +278,10 @@ class RunwayImageToVideoNodeGen4(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["duration"]), + expr="""{"type":"usd","usd": 0.0715 * widgets.duration}""", + ), ) @classmethod @@ -372,6 +380,10 @@ class RunwayFirstLastFrameNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["duration"]), + expr="""{"type":"usd","usd": 0.0715 * widgets.duration}""", + ), ) @classmethod @@ -457,6 +469,9 @@ class RunwayTextToImageNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.11}""", + ), ) @classmethod diff --git a/comfy_api_nodes/nodes_sora.py b/comfy_api_nodes/nodes_sora.py index 92b225d40..87e663845 100644 --- a/comfy_api_nodes/nodes_sora.py +++ b/comfy_api_nodes/nodes_sora.py @@ -89,6 +89,24 @@ class OpenAIVideoSora2(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["model", "size", "duration"]), + expr=""" + ( + $m := widgets.model; + $size := widgets.size; + $dur := widgets.duration; + $isPro := $contains($m, "sora-2-pro"); + $isSora2 := $contains($m, "sora-2"); + $isProSize := ($size = "1024x1792" or $size = "1792x1024"); + $perSec := + $isPro ? ($isProSize ? 0.5 : 0.3) : + $isSora2 ? 0.1 : + ($isProSize ? 0.5 : 0.1); + {"type":"usd","usd": $round($perSec * $dur, 2)} + ) + """, + ), ) @classmethod diff --git a/comfy_api_nodes/nodes_stability.py b/comfy_api_nodes/nodes_stability.py index bb7ceed78..5c48c1f1e 100644 --- a/comfy_api_nodes/nodes_stability.py +++ b/comfy_api_nodes/nodes_stability.py @@ -127,6 +127,9 @@ class StabilityStableImageUltraNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.08}""", + ), ) @classmethod @@ -264,6 +267,16 @@ class StabilityStableImageSD_3_5Node(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["model"]), + expr=""" + ( + $contains(widgets.model,"large") + ? {"type":"usd","usd":0.065} + : {"type":"usd","usd":0.035} + ) + """, + ), ) @classmethod @@ -382,6 +395,9 @@ class StabilityUpscaleConservativeNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.25}""", + ), ) @classmethod @@ -486,6 +502,9 @@ class StabilityUpscaleCreativeNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.25}""", + ), ) @classmethod @@ -566,6 +585,9 @@ class StabilityUpscaleFastNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.01}""", + ), ) @classmethod @@ -648,6 +670,9 @@ class StabilityTextToAudio(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.2}""", + ), ) @classmethod @@ -732,6 +757,9 @@ class StabilityAudioToAudio(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.2}""", + ), ) @classmethod @@ -828,6 +856,9 @@ class StabilityAudioInpaint(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.2}""", + ), ) @classmethod diff --git a/comfy_api_nodes/nodes_tripo.py b/comfy_api_nodes/nodes_tripo.py index e72f8e96a..aa790143d 100644 --- a/comfy_api_nodes/nodes_tripo.py +++ b/comfy_api_nodes/nodes_tripo.py @@ -117,6 +117,38 @@ class TripoTextToModelNode(IO.ComfyNode): ], is_api_node=True, is_output_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends( + widgets=[ + "model_version", + "style", + "texture", + "pbr", + "quad", + "texture_quality", + "geometry_quality", + ], + ), + expr=""" + ( + $isV14 := $contains(widgets.model_version,"v1.4"); + $style := widgets.style; + $hasStyle := ($style != "" and $style != "none"); + $withTexture := widgets.texture or widgets.pbr; + $isHdTexture := (widgets.texture_quality = "detailed"); + $isDetailedGeometry := (widgets.geometry_quality = "detailed"); + $baseCredits := + $isV14 ? 20 : ($withTexture ? 20 : 10); + $credits := + $baseCredits + + ($hasStyle ? 5 : 0) + + (widgets.quad ? 5 : 0) + + ($isHdTexture ? 10 : 0) + + ($isDetailedGeometry ? 20 : 0); + {"type":"usd","usd": $round($credits * 0.01, 2)} + ) + """, + ), ) @classmethod @@ -210,6 +242,38 @@ class TripoImageToModelNode(IO.ComfyNode): ], is_api_node=True, is_output_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends( + widgets=[ + "model_version", + "style", + "texture", + "pbr", + "quad", + "texture_quality", + "geometry_quality", + ], + ), + expr=""" + ( + $isV14 := $contains(widgets.model_version,"v1.4"); + $style := widgets.style; + $hasStyle := ($style != "" and $style != "none"); + $withTexture := widgets.texture or widgets.pbr; + $isHdTexture := (widgets.texture_quality = "detailed"); + $isDetailedGeometry := (widgets.geometry_quality = "detailed"); + $baseCredits := + $isV14 ? 30 : ($withTexture ? 30 : 20); + $credits := + $baseCredits + + ($hasStyle ? 5 : 0) + + (widgets.quad ? 5 : 0) + + ($isHdTexture ? 10 : 0) + + ($isDetailedGeometry ? 20 : 0); + {"type":"usd","usd": $round($credits * 0.01, 2)} + ) + """, + ), ) @classmethod @@ -314,6 +378,34 @@ class TripoMultiviewToModelNode(IO.ComfyNode): ], is_api_node=True, is_output_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends( + widgets=[ + "model_version", + "texture", + "pbr", + "quad", + "texture_quality", + "geometry_quality", + ], + ), + expr=""" + ( + $isV14 := $contains(widgets.model_version,"v1.4"); + $withTexture := widgets.texture or widgets.pbr; + $isHdTexture := (widgets.texture_quality = "detailed"); + $isDetailedGeometry := (widgets.geometry_quality = "detailed"); + $baseCredits := + $isV14 ? 30 : ($withTexture ? 30 : 20); + $credits := + $baseCredits + + (widgets.quad ? 5 : 0) + + ($isHdTexture ? 10 : 0) + + ($isDetailedGeometry ? 20 : 0); + {"type":"usd","usd": $round($credits * 0.01, 2)} + ) + """, + ), ) @classmethod @@ -405,6 +497,15 @@ class TripoTextureNode(IO.ComfyNode): ], is_api_node=True, is_output_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["texture_quality"]), + expr=""" + ( + $tq := widgets.texture_quality; + {"type":"usd","usd": ($contains($tq,"detailed") ? 0.2 : 0.1)} + ) + """, + ), ) @classmethod @@ -456,6 +557,9 @@ class TripoRefineNode(IO.ComfyNode): ], is_api_node=True, is_output_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.3}""", + ), ) @classmethod @@ -489,6 +593,9 @@ class TripoRigNode(IO.ComfyNode): ], is_api_node=True, is_output_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.25}""", + ), ) @classmethod @@ -545,6 +652,9 @@ class TripoRetargetNode(IO.ComfyNode): ], is_api_node=True, is_output_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.1}""", + ), ) @classmethod @@ -638,6 +748,60 @@ class TripoConversionNode(IO.ComfyNode): ], is_api_node=True, is_output_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends( + widgets=[ + "quad", + "face_limit", + "texture_size", + "texture_format", + "force_symmetry", + "flatten_bottom", + "flatten_bottom_threshold", + "pivot_to_center_bottom", + "scale_factor", + "with_animation", + "pack_uv", + "bake", + "part_names", + "fbx_preset", + "export_vertex_colors", + "export_orientation", + "animate_in_place", + ], + ), + expr=""" + ( + $face := (widgets.face_limit != null) ? widgets.face_limit : -1; + $texSize := (widgets.texture_size != null) ? widgets.texture_size : 4096; + $flatThresh := (widgets.flatten_bottom_threshold != null) ? widgets.flatten_bottom_threshold : 0; + $scale := (widgets.scale_factor != null) ? widgets.scale_factor : 1; + $texFmt := (widgets.texture_format != "" ? widgets.texture_format : "jpeg"); + $part := widgets.part_names; + $fbx := (widgets.fbx_preset != "" ? widgets.fbx_preset : "blender"); + $orient := (widgets.export_orientation != "" ? widgets.export_orientation : "default"); + $advanced := + widgets.quad or + widgets.force_symmetry or + widgets.flatten_bottom or + widgets.pivot_to_center_bottom or + widgets.with_animation or + widgets.pack_uv or + widgets.bake or + widgets.export_vertex_colors or + widgets.animate_in_place or + ($face != -1) or + ($texSize != 4096) or + ($flatThresh != 0) or + ($scale != 1) or + ($texFmt != "jpeg") or + ($part != "") or + ($fbx != "blender") or + ($orient != "default"); + {"type":"usd","usd": ($advanced ? 0.1 : 0.05)} + ) + """, + ), ) @classmethod diff --git a/comfy_api_nodes/nodes_veo2.py b/comfy_api_nodes/nodes_veo2.py index 13a6bfd91..c14d6ad68 100644 --- a/comfy_api_nodes/nodes_veo2.py +++ b/comfy_api_nodes/nodes_veo2.py @@ -122,6 +122,10 @@ class VeoVideoGenerationNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["duration_seconds"]), + expr="""{"type":"usd","usd": 0.5 * widgets.duration_seconds}""", + ), ) @classmethod @@ -347,6 +351,20 @@ class Veo3VideoGenerationNode(VeoVideoGenerationNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["model", "generate_audio"]), + expr=""" + ( + $m := widgets.model; + $a := widgets.generate_audio; + ($contains($m,"veo-3.0-fast-generate-001") or $contains($m,"veo-3.1-fast-generate")) + ? {"type":"usd","usd": ($a ? 1.2 : 0.8)} + : ($contains($m,"veo-3.0-generate-001") or $contains($m,"veo-3.1-generate")) + ? {"type":"usd","usd": ($a ? 3.2 : 1.6)} + : {"type":"range_usd","min_usd":0.8,"max_usd":3.2} + ) + """, + ), ) @@ -420,6 +438,30 @@ class Veo3FirstLastFrameNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["model", "generate_audio", "duration"]), + expr=""" + ( + $prices := { + "veo-3.1-fast-generate": { "audio": 0.15, "no_audio": 0.10 }, + "veo-3.1-generate": { "audio": 0.40, "no_audio": 0.20 } + }; + $m := widgets.model; + $ga := (widgets.generate_audio = "true"); + $seconds := widgets.duration; + $modelKey := + $contains($m, "veo-3.1-fast-generate") ? "veo-3.1-fast-generate" : + $contains($m, "veo-3.1-generate") ? "veo-3.1-generate" : + ""; + $audioKey := $ga ? "audio" : "no_audio"; + $modelPrices := $lookup($prices, $modelKey); + $pps := $lookup($modelPrices, $audioKey); + ($pps != null) + ? {"type":"usd","usd": $pps * $seconds} + : {"type":"range_usd","min_usd": 0.4, "max_usd": 3.2} + ) + """, + ), ) @classmethod diff --git a/comfy_api_nodes/nodes_vidu.py b/comfy_api_nodes/nodes_vidu.py index 9d94ae7ad..8edb02f39 100644 --- a/comfy_api_nodes/nodes_vidu.py +++ b/comfy_api_nodes/nodes_vidu.py @@ -121,6 +121,9 @@ class ViduTextToVideoNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.4}""", + ), ) @classmethod @@ -214,6 +217,9 @@ class ViduImageToVideoNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.4}""", + ), ) @classmethod @@ -317,6 +323,9 @@ class ViduReferenceVideoNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.4}""", + ), ) @classmethod @@ -426,6 +435,9 @@ class ViduStartEndToVideoNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.4}""", + ), ) @classmethod @@ -507,6 +519,17 @@ class Vidu2TextToVideoNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["duration", "resolution"]), + expr=""" + ( + $is1080 := widgets.resolution = "1080p"; + $base := $is1080 ? 0.1 : 0.075; + $perSec := $is1080 ? 0.05 : 0.025; + {"type":"usd","usd": $base + $perSec * (widgets.duration - 1)} + ) + """, + ), ) @classmethod @@ -594,6 +617,39 @@ class Vidu2ImageToVideoNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["model", "duration", "resolution"]), + expr=""" + ( + $m := widgets.model; + $d := widgets.duration; + $is1080 := widgets.resolution = "1080p"; + $contains($m, "pro-fast") + ? ( + $base := $is1080 ? 0.08 : 0.04; + $perSec := $is1080 ? 0.02 : 0.01; + {"type":"usd","usd": $base + $perSec * ($d - 1)} + ) + : $contains($m, "pro") + ? ( + $base := $is1080 ? 0.275 : 0.075; + $perSec := $is1080 ? 0.075 : 0.05; + {"type":"usd","usd": $base + $perSec * ($d - 1)} + ) + : $contains($m, "turbo") + ? ( + $is1080 + ? {"type":"usd","usd": 0.175 + 0.05 * ($d - 1)} + : ( + $d <= 1 ? {"type":"usd","usd": 0.04} + : $d <= 2 ? {"type":"usd","usd": 0.05} + : {"type":"usd","usd": 0.05 + 0.05 * ($d - 2)} + ) + ) + : {"type":"usd","usd": 0.04} + ) + """, + ), ) @classmethod @@ -698,6 +754,18 @@ class Vidu2ReferenceVideoNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["audio", "duration", "resolution"]), + expr=""" + ( + $is1080 := widgets.resolution = "1080p"; + $base := $is1080 ? 0.375 : 0.125; + $perSec := $is1080 ? 0.05 : 0.025; + $audioCost := widgets.audio = true ? 0.075 : 0; + {"type":"usd","usd": $base + $perSec * (widgets.duration - 1) + $audioCost} + ) + """, + ), ) @classmethod @@ -804,6 +872,38 @@ class Vidu2StartEndToVideoNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["model", "duration", "resolution"]), + expr=""" + ( + $m := widgets.model; + $d := widgets.duration; + $is1080 := widgets.resolution = "1080p"; + $contains($m, "pro-fast") + ? ( + $base := $is1080 ? 0.08 : 0.04; + $perSec := $is1080 ? 0.02 : 0.01; + {"type":"usd","usd": $base + $perSec * ($d - 1)} + ) + : $contains($m, "pro") + ? ( + $base := $is1080 ? 0.275 : 0.075; + $perSec := $is1080 ? 0.075 : 0.05; + {"type":"usd","usd": $base + $perSec * ($d - 1)} + ) + : $contains($m, "turbo") + ? ( + $is1080 + ? {"type":"usd","usd": 0.175 + 0.05 * ($d - 1)} + : ( + $d <= 2 ? {"type":"usd","usd": 0.05} + : {"type":"usd","usd": 0.05 + 0.05 * ($d - 2)} + ) + ) + : {"type":"usd","usd": 0.04} + ) + """, + ), ) @classmethod diff --git a/comfy_api_nodes/nodes_wan.py b/comfy_api_nodes/nodes_wan.py index 3e04786a9..a1355d4f1 100644 --- a/comfy_api_nodes/nodes_wan.py +++ b/comfy_api_nodes/nodes_wan.py @@ -244,6 +244,9 @@ class WanTextToImageApi(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.03}""", + ), ) @classmethod @@ -363,6 +366,9 @@ class WanImageToImageApi(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.03}""", + ), ) @classmethod @@ -520,6 +526,17 @@ class WanTextToVideoApi(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["duration", "size"]), + expr=""" + ( + $ppsTable := { "480p": 0.05, "720p": 0.1, "1080p": 0.15 }; + $resKey := $substringBefore(widgets.size, ":"); + $pps := $lookup($ppsTable, $resKey); + { "type": "usd", "usd": $round($pps * widgets.duration, 2) } + ) + """, + ), ) @classmethod @@ -681,6 +698,16 @@ class WanImageToVideoApi(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["duration", "resolution"]), + expr=""" + ( + $ppsTable := { "480p": 0.05, "720p": 0.1, "1080p": 0.15 }; + $pps := $lookup($ppsTable, widgets.resolution); + { "type": "usd", "usd": $round($pps * widgets.duration, 2) } + ) + """, + ), ) @classmethod @@ -828,6 +855,22 @@ class WanReferenceVideoApi(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["size", "duration"]), + expr=""" + ( + $rate := $contains(widgets.size, "1080p") ? 0.15 : 0.10; + $inputMin := 2 * $rate; + $inputMax := 5 * $rate; + $outputPrice := widgets.duration * $rate; + { + "type": "range_usd", + "min_usd": $inputMin + $outputPrice, + "max_usd": $inputMax + $outputPrice + } + ) + """, + ), ) @classmethod From 15b312de7a74a836fa45b989a7697895b01e0cbf Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 13 Jan 2026 16:23:58 -0800 Subject: [PATCH 067/119] Optimize nvfp4 lora applying. (#11854) --- comfy/float.py | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/comfy/float.py b/comfy/float.py index c806af76b..1a6070bff 100644 --- a/comfy/float.py +++ b/comfy/float.py @@ -165,20 +165,12 @@ def stochastic_round_quantize_nvfp4(x, per_tensor_scale, pad_16x, seed=0): block_scale = max_abs / F4_E2M1_MAX scaled_block_scales = block_scale / per_tensor_scale.to(block_scale.dtype) scaled_block_scales_fp8 = torch.clamp(scaled_block_scales, max=F8_E4M3_MAX).to(torch.float8_e4m3fn) - total_scale = per_tensor_scale.to(x.dtype) * scaled_block_scales_fp8.to(x.dtype) - - # Handle zero blocks (from padding): avoid 0/0 NaN - zero_scale_mask = (total_scale == 0) - total_scale_safe = torch.where(zero_scale_mask, torch.ones_like(total_scale), total_scale) - - x = x / total_scale_safe.unsqueeze(-1) + x /= (per_tensor_scale.to(x.dtype) * scaled_block_scales_fp8.to(x.dtype)).unsqueeze(-1) generator = torch.Generator(device=x.device) generator.manual_seed(seed) - x = torch.where(zero_scale_mask.unsqueeze(-1), torch.zeros_like(x), x) - - x = x.view(orig_shape) + x = x.view(orig_shape).nan_to_num() data_lp = stochastic_float_to_fp4_e2m1(x, generator=generator) blocked_scales = to_blocked(scaled_block_scales_fp8, flatten=False) From eff2b9d412932aa7d49e6302cdf6e7cf24808b6f Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 13 Jan 2026 16:37:19 -0800 Subject: [PATCH 068/119] Optimize nvfp4 lora applying. (#11856) --- comfy/float.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/comfy/float.py b/comfy/float.py index 1a6070bff..8c303bea0 100644 --- a/comfy/float.py +++ b/comfy/float.py @@ -161,10 +161,7 @@ def stochastic_round_quantize_nvfp4(x, per_tensor_scale, pad_16x, seed=0): block_size = 16 x = x.reshape(orig_shape[0], -1, block_size) - max_abs = torch.amax(torch.abs(x), dim=-1) - block_scale = max_abs / F4_E2M1_MAX - scaled_block_scales = block_scale / per_tensor_scale.to(block_scale.dtype) - scaled_block_scales_fp8 = torch.clamp(scaled_block_scales, max=F8_E4M3_MAX).to(torch.float8_e4m3fn) + scaled_block_scales_fp8 = torch.clamp(((torch.amax(torch.abs(x), dim=-1)) / F4_E2M1_MAX) / per_tensor_scale.to(x.dtype), max=F8_E4M3_MAX).to(torch.float8_e4m3fn) x /= (per_tensor_scale.to(x.dtype) * scaled_block_scales_fp8.to(x.dtype)).unsqueeze(-1) generator = torch.Generator(device=x.device) @@ -172,6 +169,5 @@ def stochastic_round_quantize_nvfp4(x, per_tensor_scale, pad_16x, seed=0): x = x.view(orig_shape).nan_to_num() data_lp = stochastic_float_to_fp4_e2m1(x, generator=generator) - blocked_scales = to_blocked(scaled_block_scales_fp8, flatten=False) return data_lp, blocked_scales From 469dd9c16ad88765ffe4e7bfa57dd80faafbaddf Mon Sep 17 00:00:00 2001 From: nomadoor <124905471+nomadoor@users.noreply.github.com> Date: Wed, 14 Jan 2026 09:48:10 +0900 Subject: [PATCH 069/119] Adds crop to multiple mode to ResizeImageMaskNode. (#11838) * Add crop-to-multiple resize mode * Make scale-to-multiple shape handling explicit --- comfy_extras/nodes_post_processing.py | 44 +++++++++++++++++++++++++++ 1 file changed, 44 insertions(+) diff --git a/comfy_extras/nodes_post_processing.py b/comfy_extras/nodes_post_processing.py index 01afa13a1..0433bbda2 100644 --- a/comfy_extras/nodes_post_processing.py +++ b/comfy_extras/nodes_post_processing.py @@ -254,6 +254,7 @@ class ResizeType(str, Enum): SCALE_HEIGHT = "scale height" SCALE_TOTAL_PIXELS = "scale total pixels" MATCH_SIZE = "match size" + SCALE_TO_MULTIPLE = "scale to multiple" def is_image(input: torch.Tensor) -> bool: # images have 4 dimensions: [batch, height, width, channels] @@ -363,6 +364,43 @@ def scale_match_size(input: torch.Tensor, match: torch.Tensor, scale_method: str input = finalize_image_mask_input(input, is_type_image) return input +def scale_to_multiple_cover(input: torch.Tensor, multiple: int, scale_method: str) -> torch.Tensor: + if multiple <= 1: + return input + is_type_image = is_image(input) + if is_type_image: + _, height, width, _ = input.shape + else: + _, height, width = input.shape + target_w = (width // multiple) * multiple + target_h = (height // multiple) * multiple + if target_w == 0 or target_h == 0: + return input + if target_w == width and target_h == height: + return input + s_w = target_w / width + s_h = target_h / height + if s_w >= s_h: + scaled_w = target_w + scaled_h = int(math.ceil(height * s_w)) + if scaled_h < target_h: + scaled_h = target_h + else: + scaled_h = target_h + scaled_w = int(math.ceil(width * s_h)) + if scaled_w < target_w: + scaled_w = target_w + input = init_image_mask_input(input, is_type_image) + input = comfy.utils.common_upscale(input, scaled_w, scaled_h, scale_method, "disabled") + input = finalize_image_mask_input(input, is_type_image) + x0 = (scaled_w - target_w) // 2 + y0 = (scaled_h - target_h) // 2 + x1 = x0 + target_w + y1 = y0 + target_h + if is_type_image: + return input[:, y0:y1, x0:x1, :] + return input[:, y0:y1, x0:x1] + class ResizeImageMaskNode(io.ComfyNode): scale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "lanczos"] @@ -378,6 +416,7 @@ class ResizeImageMaskNode(io.ComfyNode): longer_size: int shorter_size: int megapixels: float + multiple: int @classmethod def define_schema(cls): @@ -417,6 +456,9 @@ class ResizeImageMaskNode(io.ComfyNode): io.MultiType.Input("match", [io.Image, io.Mask]), crop_combo, ]), + io.DynamicCombo.Option(ResizeType.SCALE_TO_MULTIPLE, [ + io.Int.Input("multiple", default=8, min=1, max=MAX_RESOLUTION, step=1), + ]), ]), io.Combo.Input("scale_method", options=cls.scale_methods, default="area"), ], @@ -442,6 +484,8 @@ class ResizeImageMaskNode(io.ComfyNode): return io.NodeOutput(scale_total_pixels(input, resize_type["megapixels"], scale_method)) elif selected_type == ResizeType.MATCH_SIZE: return io.NodeOutput(scale_match_size(input, resize_type["match"], scale_method, resize_type["crop"])) + elif selected_type == ResizeType.SCALE_TO_MULTIPLE: + return io.NodeOutput(scale_to_multiple_cover(input, resize_type["multiple"], scale_method)) raise ValueError(f"Unsupported resize type: {selected_type}") def batch_images(images: list[torch.Tensor]) -> torch.Tensor | None: From 7eb959ce934da914523b455f9a6e7e0662690325 Mon Sep 17 00:00:00 2001 From: Christian Byrne Date: Tue, 13 Jan 2026 18:03:16 -0800 Subject: [PATCH 070/119] fix: update ComfyUI repo reference to Comfy-Org/ComfyUI (#11858) --- .github/workflows/test-launch.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test-launch.yml b/.github/workflows/test-launch.yml index ef0d3f123..581c0474b 100644 --- a/.github/workflows/test-launch.yml +++ b/.github/workflows/test-launch.yml @@ -13,7 +13,7 @@ jobs: - name: Checkout ComfyUI uses: actions/checkout@v4 with: - repository: "comfyanonymous/ComfyUI" + repository: "Comfy-Org/ComfyUI" path: "ComfyUI" - uses: actions/setup-python@v4 with: From c9196f355ef5832daf55c4bbe8c6279dec509331 Mon Sep 17 00:00:00 2001 From: nomadoor <124905471+nomadoor@users.noreply.github.com> Date: Wed, 14 Jan 2026 11:25:09 +0900 Subject: [PATCH 071/119] Fix scale_shorter_dimension portrait check (#11862) --- comfy_extras/nodes_post_processing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy_extras/nodes_post_processing.py b/comfy_extras/nodes_post_processing.py index 0433bbda2..2e559c35c 100644 --- a/comfy_extras/nodes_post_processing.py +++ b/comfy_extras/nodes_post_processing.py @@ -329,7 +329,7 @@ def scale_shorter_dimension(input: torch.Tensor, shorter_size: int, scale_method if height < width: width = round((width / height) * shorter_size) height = shorter_size - elif width > height: + elif width < height: height = round((height / width) * shorter_size) width = shorter_size else: From ac4d8ea9b32f56410860dccdb30ae50a1029d6fd Mon Sep 17 00:00:00 2001 From: Johnpaul Chiwetelu <49923152+Myestery@users.noreply.github.com> Date: Wed, 14 Jan 2026 04:39:22 +0100 Subject: [PATCH 072/119] feat: add CI container version bump automation (#11692) * feat: add CI container version bump automation Adds a workflow that triggers on releases to create PRs in the comfyui-ci-container repo, updating the ComfyUI version in the Dockerfile. Supports both release events and manual workflow dispatch for testing. * feat: add CI container version bump automation Adds a workflow that triggers on releases to create PRs in the comfyui-ci-container repo, updating the ComfyUI version in the Dockerfile. Supports both release events and manual workflow dispatch for testing. * ci: update CI container repository owner * refactor: rename `update-ci-container.yaml` workflow to `update-ci-container.yml` * Remove post-merge instructions from the CI container update workflow. --- .github/workflows/update-ci-container.yml | 59 +++++++++++++++++++++++ 1 file changed, 59 insertions(+) create mode 100644 .github/workflows/update-ci-container.yml diff --git a/.github/workflows/update-ci-container.yml b/.github/workflows/update-ci-container.yml new file mode 100644 index 000000000..f7972e056 --- /dev/null +++ b/.github/workflows/update-ci-container.yml @@ -0,0 +1,59 @@ +name: "CI: Update CI Container" + +on: + release: + types: [published] + workflow_dispatch: + inputs: + version: + description: 'ComfyUI version (e.g., v0.7.0)' + required: true + type: string + +jobs: + update-ci-container: + runs-on: ubuntu-latest + # Skip pre-releases unless manually triggered + if: github.event_name == 'workflow_dispatch' || !github.event.release.prerelease + steps: + - name: Get version + id: version + run: | + if [ "${{ github.event_name }}" = "release" ]; then + VERSION="${{ github.event.release.tag_name }}" + else + VERSION="${{ inputs.version }}" + fi + echo "version=$VERSION" >> $GITHUB_OUTPUT + + - name: Checkout comfyui-ci-container + uses: actions/checkout@v4 + with: + repository: comfy-org/comfyui-ci-container + token: ${{ secrets.CI_CONTAINER_PAT }} + + - name: Check current version + id: current + run: | + CURRENT=$(grep -oP 'ARG COMFYUI_VERSION=\K.*' Dockerfile || echo "unknown") + echo "current_version=$CURRENT" >> $GITHUB_OUTPUT + + - name: Update Dockerfile + run: | + VERSION="${{ steps.version.outputs.version }}" + sed -i "s/^ARG COMFYUI_VERSION=.*/ARG COMFYUI_VERSION=${VERSION}/" Dockerfile + + - name: Create Pull Request + id: create-pr + uses: peter-evans/create-pull-request@v7 + with: + token: ${{ secrets.CI_CONTAINER_PAT }} + branch: automation/comfyui-${{ steps.version.outputs.version }} + title: "chore: bump ComfyUI to ${{ steps.version.outputs.version }}" + body: | + Updates ComfyUI version from `${{ steps.current.outputs.current_version }}` to `${{ steps.version.outputs.version }}` + + **Triggered by:** ${{ github.event_name == 'release' && format('[Release {0}]({1})', github.event.release.tag_name, github.event.release.html_url) || 'Manual workflow dispatch' }} + + labels: automation + commit-message: "chore: bump ComfyUI to ${{ steps.version.outputs.version }}" From 712cca36a13db93a9fa1fde9b7b5f9a5b961209a Mon Sep 17 00:00:00 2001 From: Silver <65376327+silveroxides@users.noreply.github.com> Date: Wed, 14 Jan 2026 04:41:44 +0100 Subject: [PATCH 073/119] feat: throttle ProgressBar updates to reduce WebSocket flooding (#11504) --- comfy/utils.py | 31 ++++++++++++++++++++++++++++++- 1 file changed, 30 insertions(+), 1 deletion(-) diff --git a/comfy/utils.py b/comfy/utils.py index ffa98c9b1..fac13f128 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -30,6 +30,7 @@ from torch.nn.functional import interpolate from einops import rearrange from comfy.cli_args import args import json +import time MMAP_TORCH_FILES = args.mmap_torch_files DISABLE_MMAP = args.disable_mmap @@ -1097,6 +1098,10 @@ def set_progress_bar_global_hook(function): global PROGRESS_BAR_HOOK PROGRESS_BAR_HOOK = function +# Throttle settings for progress bar updates to reduce WebSocket flooding +PROGRESS_THROTTLE_MIN_INTERVAL = 0.1 # 100ms minimum between updates +PROGRESS_THROTTLE_MIN_PERCENT = 0.5 # 0.5% minimum progress change + class ProgressBar: def __init__(self, total, node_id=None): global PROGRESS_BAR_HOOK @@ -1104,6 +1109,8 @@ class ProgressBar: self.current = 0 self.hook = PROGRESS_BAR_HOOK self.node_id = node_id + self._last_update_time = 0.0 + self._last_sent_value = -1 def update_absolute(self, value, total=None, preview=None): if total is not None: @@ -1112,7 +1119,29 @@ class ProgressBar: value = self.total self.current = value if self.hook is not None: - self.hook(self.current, self.total, preview, node_id=self.node_id) + current_time = time.perf_counter() + is_first = (self._last_sent_value < 0) + is_final = (value >= self.total) + has_preview = (preview is not None) + + # Always send immediately for previews, first update, or final update + if has_preview or is_first or is_final: + self.hook(self.current, self.total, preview, node_id=self.node_id) + self._last_update_time = current_time + self._last_sent_value = value + return + + # Apply throttling for regular progress updates + if self.total > 0: + percent_changed = ((value - max(0, self._last_sent_value)) / self.total) * 100 + else: + percent_changed = 100 + time_elapsed = current_time - self._last_update_time + + if time_elapsed >= PROGRESS_THROTTLE_MIN_INTERVAL and percent_changed >= PROGRESS_THROTTLE_MIN_PERCENT: + self.hook(self.current, self.total, preview, node_id=self.node_id) + self._last_update_time = current_time + self._last_sent_value = value def update(self, value): self.update_absolute(self.current + value) From 6165c38cb58c40b15ade879b80051b6c9148587f Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 13 Jan 2026 21:49:38 -0800 Subject: [PATCH 074/119] Optimize nvfp4 lora applying. (#11866) This changes results a bit but it also speeds up things a lot. --- comfy/float.py | 56 ++++++++++++++++++++++++++++++++------- comfy/quant_ops.py | 2 +- comfy/supported_models.py | 2 +- 3 files changed, 49 insertions(+), 11 deletions(-) diff --git a/comfy/float.py b/comfy/float.py index 8c303bea0..88c47cd80 100644 --- a/comfy/float.py +++ b/comfy/float.py @@ -137,10 +137,44 @@ def to_blocked(input_matrix, flatten: bool = True) -> torch.Tensor: return rearranged.reshape(padded_rows, padded_cols) -def stochastic_round_quantize_nvfp4(x, per_tensor_scale, pad_16x, seed=0): +def stochastic_round_quantize_nvfp4_block(x, per_tensor_scale, generator): F4_E2M1_MAX = 6.0 F8_E4M3_MAX = 448.0 + orig_shape = x.shape + + block_size = 16 + + x = x.reshape(orig_shape[0], -1, block_size) + scaled_block_scales_fp8 = torch.clamp(((torch.amax(torch.abs(x), dim=-1)) / F4_E2M1_MAX) / per_tensor_scale.to(x.dtype), max=F8_E4M3_MAX).to(torch.float8_e4m3fn) + x = x / (per_tensor_scale.to(x.dtype) * scaled_block_scales_fp8.to(x.dtype)).unsqueeze(-1) + + x = x.view(orig_shape).nan_to_num() + data_lp = stochastic_float_to_fp4_e2m1(x, generator=generator) + return data_lp, scaled_block_scales_fp8 + + +def stochastic_round_quantize_nvfp4(x, per_tensor_scale, pad_16x, seed=0): + def roundup(x: int, multiple: int) -> int: + """Round up x to the nearest multiple.""" + return ((x + multiple - 1) // multiple) * multiple + + generator = torch.Generator(device=x.device) + generator.manual_seed(seed) + + # Handle padding + if pad_16x: + rows, cols = x.shape + padded_rows = roundup(rows, 16) + padded_cols = roundup(cols, 16) + if padded_rows != rows or padded_cols != cols: + x = torch.nn.functional.pad(x, (0, padded_cols - cols, 0, padded_rows - rows)) + + x, blocked_scaled = stochastic_round_quantize_nvfp4_block(x, per_tensor_scale, generator) + return x, to_blocked(blocked_scaled, flatten=False) + + +def stochastic_round_quantize_nvfp4_by_block(x, per_tensor_scale, pad_16x, seed=0, block_size=4096 * 4096): def roundup(x: int, multiple: int) -> int: """Round up x to the nearest multiple.""" return ((x + multiple - 1) // multiple) * multiple @@ -158,16 +192,20 @@ def stochastic_round_quantize_nvfp4(x, per_tensor_scale, pad_16x, seed=0): # what we want to produce. If we pad here, we want the padded output. orig_shape = x.shape - block_size = 16 + orig_shape = list(orig_shape) - x = x.reshape(orig_shape[0], -1, block_size) - scaled_block_scales_fp8 = torch.clamp(((torch.amax(torch.abs(x), dim=-1)) / F4_E2M1_MAX) / per_tensor_scale.to(x.dtype), max=F8_E4M3_MAX).to(torch.float8_e4m3fn) - x /= (per_tensor_scale.to(x.dtype) * scaled_block_scales_fp8.to(x.dtype)).unsqueeze(-1) + output_fp4 = torch.empty(orig_shape[:-1] + [orig_shape[-1] // 2], dtype=torch.uint8, device=x.device) + output_block = torch.empty(orig_shape[:-1] + [orig_shape[-1] // 16], dtype=torch.float8_e4m3fn, device=x.device) generator = torch.Generator(device=x.device) generator.manual_seed(seed) - x = x.view(orig_shape).nan_to_num() - data_lp = stochastic_float_to_fp4_e2m1(x, generator=generator) - blocked_scales = to_blocked(scaled_block_scales_fp8, flatten=False) - return data_lp, blocked_scales + num_slices = max(1, (x.numel() / block_size)) + slice_size = max(1, (round(x.shape[0] / num_slices))) + + for i in range(0, x.shape[0], slice_size): + fp4, block = stochastic_round_quantize_nvfp4_block(x[i: i + slice_size], per_tensor_scale, generator=generator) + output_fp4[i:i + slice_size].copy_(fp4) + output_block[i:i + slice_size].copy_(block) + + return output_fp4, to_blocked(output_block, flatten=False) diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py index 7a61203c3..15a4f457b 100644 --- a/comfy/quant_ops.py +++ b/comfy/quant_ops.py @@ -104,7 +104,7 @@ class TensorCoreNVFP4Layout(_CKNvfp4Layout): needs_padding = padded_shape != orig_shape if stochastic_rounding > 0: - qdata, block_scale = comfy.float.stochastic_round_quantize_nvfp4(tensor, scale, pad_16x=needs_padding, seed=stochastic_rounding) + qdata, block_scale = comfy.float.stochastic_round_quantize_nvfp4_by_block(tensor, scale, pad_16x=needs_padding, seed=stochastic_rounding) else: qdata, block_scale = ck.quantize_nvfp4(tensor, scale, pad_16x=needs_padding) diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 1bf54f13f..2c4c6b8fc 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -1042,7 +1042,7 @@ class ZImage(Lumina2): "shift": 3.0, } - memory_usage_factor = 2.0 + memory_usage_factor = 2.8 supported_inference_dtypes = [torch.bfloat16, torch.float32] From d1504404662dfce6e401422701c2a7e24057b1b5 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 14 Jan 2026 10:54:50 -0800 Subject: [PATCH 075/119] Fix VAELoader (#11880) --- nodes.py | 1 + 1 file changed, 1 insertion(+) diff --git a/nodes.py b/nodes.py index 90c5f2a6e..aa8572446 100644 --- a/nodes.py +++ b/nodes.py @@ -788,6 +788,7 @@ class VAELoader: #TODO: scale factor? def load_vae(self, vae_name): + metadata = None if vae_name == "pixel_space": sd = {} sd["pixel_space_vae"] = torch.tensor(1.0) From 07f2462eae7fa2daa34971dd1b15fd525686e958 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Wed, 14 Jan 2026 21:25:38 +0200 Subject: [PATCH 076/119] feat(api-nodes): add Meshy 3D nodes (#11843) * feat(api-nodes): add Meshy 3D nodes * rebased, added JSONata price badges --- comfy_api_nodes/apis/meshy.py | 160 +++++ comfy_api_nodes/nodes_meshy.py | 790 +++++++++++++++++++++++++ comfy_api_nodes/util/upload_helpers.py | 23 +- nodes.py | 1 + 4 files changed, 969 insertions(+), 5 deletions(-) create mode 100644 comfy_api_nodes/apis/meshy.py create mode 100644 comfy_api_nodes/nodes_meshy.py diff --git a/comfy_api_nodes/apis/meshy.py b/comfy_api_nodes/apis/meshy.py new file mode 100644 index 000000000..be46d0d58 --- /dev/null +++ b/comfy_api_nodes/apis/meshy.py @@ -0,0 +1,160 @@ +from typing import TypedDict + +from pydantic import BaseModel, Field + +from comfy_api.latest import Input + + +class InputShouldRemesh(TypedDict): + should_remesh: str + topology: str + target_polycount: int + + +class InputShouldTexture(TypedDict): + should_texture: str + enable_pbr: bool + texture_prompt: str + texture_image: Input.Image | None + + +class MeshyTaskResponse(BaseModel): + result: str = Field(...) + + +class MeshyTextToModelRequest(BaseModel): + mode: str = Field("preview") + prompt: str = Field(..., max_length=600) + art_style: str = Field(..., description="'realistic' or 'sculpture'") + ai_model: str = Field(...) + topology: str | None = Field(..., description="'quad' or 'triangle'") + target_polycount: int | None = Field(..., ge=100, le=300000) + should_remesh: bool = Field( + True, + description="False returns the original mesh, ignoring topology and polycount.", + ) + symmetry_mode: str = Field(..., description="'auto', 'off' or 'on'") + pose_mode: str = Field(...) + seed: int = Field(...) + moderation: bool = Field(False) + + +class MeshyRefineTask(BaseModel): + mode: str = Field("refine") + preview_task_id: str = Field(...) + enable_pbr: bool | None = Field(...) + texture_prompt: str | None = Field(...) + texture_image_url: str | None = Field(...) + ai_model: str = Field(...) + moderation: bool = Field(False) + + +class MeshyImageToModelRequest(BaseModel): + image_url: str = Field(...) + ai_model: str = Field(...) + topology: str | None = Field(..., description="'quad' or 'triangle'") + target_polycount: int | None = Field(..., ge=100, le=300000) + symmetry_mode: str = Field(..., description="'auto', 'off' or 'on'") + should_remesh: bool = Field( + True, + description="False returns the original mesh, ignoring topology and polycount.", + ) + should_texture: bool = Field(...) + enable_pbr: bool | None = Field(...) + pose_mode: str = Field(...) + texture_prompt: str | None = Field(None, max_length=600) + texture_image_url: str | None = Field(None) + seed: int = Field(...) + moderation: bool = Field(False) + + +class MeshyMultiImageToModelRequest(BaseModel): + image_urls: list[str] = Field(...) + ai_model: str = Field(...) + topology: str | None = Field(..., description="'quad' or 'triangle'") + target_polycount: int | None = Field(..., ge=100, le=300000) + symmetry_mode: str = Field(..., description="'auto', 'off' or 'on'") + should_remesh: bool = Field( + True, + description="False returns the original mesh, ignoring topology and polycount.", + ) + should_texture: bool = Field(...) + enable_pbr: bool | None = Field(...) + pose_mode: str = Field(...) + texture_prompt: str | None = Field(None, max_length=600) + texture_image_url: str | None = Field(None) + seed: int = Field(...) + moderation: bool = Field(False) + + +class MeshyRiggingRequest(BaseModel): + input_task_id: str = Field(...) + height_meters: float = Field(...) + texture_image_url: str | None = Field(...) + + +class MeshyAnimationRequest(BaseModel): + rig_task_id: str = Field(...) + action_id: int = Field(...) + + +class MeshyTextureRequest(BaseModel): + input_task_id: str = Field(...) + ai_model: str = Field(...) + enable_original_uv: bool = Field(...) + enable_pbr: bool = Field(...) + text_style_prompt: str | None = Field(...) + image_style_url: str | None = Field(...) + + +class MeshyModelsUrls(BaseModel): + glb: str = Field("") + + +class MeshyRiggedModelsUrls(BaseModel): + rigged_character_glb_url: str = Field("") + + +class MeshyAnimatedModelsUrls(BaseModel): + animation_glb_url: str = Field("") + + +class MeshyResultTextureUrls(BaseModel): + base_color: str = Field(...) + metallic: str | None = Field(None) + normal: str | None = Field(None) + roughness: str | None = Field(None) + + +class MeshyTaskError(BaseModel): + message: str | None = Field(None) + + +class MeshyModelResult(BaseModel): + id: str = Field(...) + type: str = Field(...) + model_urls: MeshyModelsUrls = Field(MeshyModelsUrls()) + thumbnail_url: str = Field(...) + video_url: str | None = Field(None) + status: str = Field(...) + progress: int = Field(0) + texture_urls: list[MeshyResultTextureUrls] | None = Field([]) + task_error: MeshyTaskError | None = Field(None) + + +class MeshyRiggedResult(BaseModel): + id: str = Field(...) + type: str = Field(...) + status: str = Field(...) + progress: int = Field(0) + result: MeshyRiggedModelsUrls = Field(MeshyRiggedModelsUrls()) + task_error: MeshyTaskError | None = Field(None) + + +class MeshyAnimationResult(BaseModel): + id: str = Field(...) + type: str = Field(...) + status: str = Field(...) + progress: int = Field(0) + result: MeshyAnimatedModelsUrls = Field(MeshyAnimatedModelsUrls()) + task_error: MeshyTaskError | None = Field(None) diff --git a/comfy_api_nodes/nodes_meshy.py b/comfy_api_nodes/nodes_meshy.py new file mode 100644 index 000000000..740607983 --- /dev/null +++ b/comfy_api_nodes/nodes_meshy.py @@ -0,0 +1,790 @@ +import os + +from typing_extensions import override + +from comfy_api.latest import IO, ComfyExtension, Input +from comfy_api_nodes.apis.meshy import ( + InputShouldRemesh, + InputShouldTexture, + MeshyAnimationRequest, + MeshyAnimationResult, + MeshyImageToModelRequest, + MeshyModelResult, + MeshyMultiImageToModelRequest, + MeshyRefineTask, + MeshyRiggedResult, + MeshyRiggingRequest, + MeshyTaskResponse, + MeshyTextToModelRequest, + MeshyTextureRequest, +) +from comfy_api_nodes.util import ( + ApiEndpoint, + download_url_to_bytesio, + poll_op, + sync_op, + upload_images_to_comfyapi, + validate_string, +) +from folder_paths import get_output_directory + + +class MeshyTextToModelNode(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="MeshyTextToModelNode", + display_name="Meshy: Text to Model", + category="api node/3d/Meshy", + inputs=[ + IO.Combo.Input("model", options=["latest"]), + IO.String.Input("prompt", multiline=True, default=""), + IO.Combo.Input("style", options=["realistic", "sculpture"]), + IO.DynamicCombo.Input( + "should_remesh", + options=[ + IO.DynamicCombo.Option( + "true", + [ + IO.Combo.Input("topology", options=["triangle", "quad"]), + IO.Int.Input( + "target_polycount", + default=300000, + min=100, + max=300000, + display_mode=IO.NumberDisplay.number, + ), + ], + ), + IO.DynamicCombo.Option("false", []), + ], + tooltip="When set to false, returns an unprocessed triangular mesh.", + ), + IO.Combo.Input("symmetry_mode", options=["auto", "on", "off"]), + IO.Combo.Input( + "pose_mode", + options=["", "A-pose", "T-pose"], + tooltip="Specify the pose mode for the generated model.", + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + tooltip="Seed controls whether the node should re-run; " + "results are non-deterministic regardless of seed.", + ), + ], + outputs=[ + IO.String.Output(display_name="model_file"), + IO.Custom("MESHY_TASK_ID").Output(display_name="meshy_task_id"), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + is_output_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.8}""", + ), + ) + + @classmethod + async def execute( + cls, + model: str, + prompt: str, + style: str, + should_remesh: InputShouldRemesh, + symmetry_mode: str, + pose_mode: str, + seed: int, + ) -> IO.NodeOutput: + validate_string(prompt, field_name="prompt", min_length=1, max_length=600) + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/meshy/openapi/v2/text-to-3d", method="POST"), + response_model=MeshyTaskResponse, + data=MeshyTextToModelRequest( + prompt=prompt, + art_style=style, + ai_model=model, + topology=should_remesh.get("topology", None), + target_polycount=should_remesh.get("target_polycount", None), + should_remesh=should_remesh["should_remesh"] == "true", + symmetry_mode=symmetry_mode, + pose_mode=pose_mode.lower(), + seed=seed, + ), + ) + result = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/meshy/openapi/v2/text-to-3d/{response.result}"), + response_model=MeshyModelResult, + status_extractor=lambda r: r.status, + progress_extractor=lambda r: r.progress, + ) + model_file = f"meshy_model_{response.result}.glb" + await download_url_to_bytesio(result.model_urls.glb, os.path.join(get_output_directory(), model_file)) + return IO.NodeOutput(model_file, response.result) + + +class MeshyRefineNode(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="MeshyRefineNode", + display_name="Meshy: Refine Draft Model", + category="api node/3d/Meshy", + description="Refine a previously created draft model.", + inputs=[ + IO.Combo.Input("model", options=["latest"]), + IO.Custom("MESHY_TASK_ID").Input("meshy_task_id"), + IO.Boolean.Input( + "enable_pbr", + default=False, + tooltip="Generate PBR Maps (metallic, roughness, normal) in addition to the base color. " + "Note: this should be set to false when using Sculpture style, " + "as Sculpture style generates its own set of PBR maps.", + ), + IO.String.Input( + "texture_prompt", + default="", + multiline=True, + tooltip="Provide a text prompt to guide the texturing process. " + "Maximum 600 characters. Cannot be used at the same time as 'texture_image'.", + ), + IO.Image.Input( + "texture_image", + tooltip="Only one of 'texture_image' or 'texture_prompt' may be used at the same time.", + optional=True, + ), + ], + outputs=[ + IO.String.Output(display_name="model_file"), + IO.Custom("MESHY_TASK_ID").Output(display_name="meshy_task_id"), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + is_output_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.4}""", + ), + ) + + @classmethod + async def execute( + cls, + model: str, + meshy_task_id: str, + enable_pbr: bool, + texture_prompt: str, + texture_image: Input.Image | None = None, + ) -> IO.NodeOutput: + if texture_prompt and texture_image is not None: + raise ValueError("texture_prompt and texture_image cannot be used at the same time") + texture_image_url = None + if texture_prompt: + validate_string(texture_prompt, field_name="texture_prompt", max_length=600) + if texture_image is not None: + texture_image_url = (await upload_images_to_comfyapi(cls, texture_image, wait_label="Uploading texture"))[0] + response = await sync_op( + cls, + endpoint=ApiEndpoint(path="/proxy/meshy/openapi/v2/text-to-3d", method="POST"), + response_model=MeshyTaskResponse, + data=MeshyRefineTask( + preview_task_id=meshy_task_id, + enable_pbr=enable_pbr, + texture_prompt=texture_prompt if texture_prompt else None, + texture_image_url=texture_image_url, + ai_model=model, + ), + ) + result = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/meshy/openapi/v2/text-to-3d/{response.result}"), + response_model=MeshyModelResult, + status_extractor=lambda r: r.status, + progress_extractor=lambda r: r.progress, + ) + model_file = f"meshy_model_{response.result}.glb" + await download_url_to_bytesio(result.model_urls.glb, os.path.join(get_output_directory(), model_file)) + return IO.NodeOutput(model_file, response.result) + + +class MeshyImageToModelNode(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="MeshyImageToModelNode", + display_name="Meshy: Image to Model", + category="api node/3d/Meshy", + inputs=[ + IO.Combo.Input("model", options=["latest"]), + IO.Image.Input("image"), + IO.DynamicCombo.Input( + "should_remesh", + options=[ + IO.DynamicCombo.Option( + "true", + [ + IO.Combo.Input("topology", options=["triangle", "quad"]), + IO.Int.Input( + "target_polycount", + default=300000, + min=100, + max=300000, + display_mode=IO.NumberDisplay.number, + ), + ], + ), + IO.DynamicCombo.Option("false", []), + ], + tooltip="When set to false, returns an unprocessed triangular mesh.", + ), + IO.Combo.Input("symmetry_mode", options=["auto", "on", "off"]), + IO.DynamicCombo.Input( + "should_texture", + options=[ + IO.DynamicCombo.Option( + "true", + [ + IO.Boolean.Input( + "enable_pbr", + default=False, + tooltip="Generate PBR Maps (metallic, roughness, normal) " + "in addition to the base color.", + ), + IO.String.Input( + "texture_prompt", + default="", + multiline=True, + tooltip="Provide a text prompt to guide the texturing process. " + "Maximum 600 characters. Cannot be used at the same time as 'texture_image'.", + ), + IO.Image.Input( + "texture_image", + tooltip="Only one of 'texture_image' or 'texture_prompt' " + "may be used at the same time.", + optional=True, + ), + ], + ), + IO.DynamicCombo.Option("false", []), + ], + tooltip="Determines whether textures are generated. " + "Setting it to false skips the texture phase and returns a mesh without textures.", + ), + IO.Combo.Input( + "pose_mode", + options=["", "A-pose", "T-pose"], + tooltip="Specify the pose mode for the generated model.", + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + tooltip="Seed controls whether the node should re-run; " + "results are non-deterministic regardless of seed.", + ), + ], + outputs=[ + IO.String.Output(display_name="model_file"), + IO.Custom("MESHY_TASK_ID").Output(display_name="meshy_task_id"), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + is_output_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["should_texture"]), + expr=""" + ( + $prices := {"true": 1.2, "false": 0.8}; + {"type":"usd","usd": $lookup($prices, widgets.should_texture)} + ) + """, + ), + ) + + @classmethod + async def execute( + cls, + model: str, + image: Input.Image, + should_remesh: InputShouldRemesh, + symmetry_mode: str, + should_texture: InputShouldTexture, + pose_mode: str, + seed: int, + ) -> IO.NodeOutput: + texture = should_texture["should_texture"] == "true" + texture_image_url = texture_prompt = None + if texture: + if should_texture["texture_prompt"] and should_texture["texture_image"] is not None: + raise ValueError("texture_prompt and texture_image cannot be used at the same time") + if should_texture["texture_prompt"]: + validate_string(should_texture["texture_prompt"], field_name="texture_prompt", max_length=600) + texture_prompt = should_texture["texture_prompt"] + if should_texture["texture_image"] is not None: + texture_image_url = ( + await upload_images_to_comfyapi( + cls, should_texture["texture_image"], wait_label="Uploading texture" + ) + )[0] + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/meshy/openapi/v1/image-to-3d", method="POST"), + response_model=MeshyTaskResponse, + data=MeshyImageToModelRequest( + image_url=(await upload_images_to_comfyapi(cls, image, wait_label="Uploading base image"))[0], + ai_model=model, + topology=should_remesh.get("topology", None), + target_polycount=should_remesh.get("target_polycount", None), + symmetry_mode=symmetry_mode, + should_remesh=should_remesh["should_remesh"] == "true", + should_texture=texture, + enable_pbr=should_texture.get("enable_pbr", None), + pose_mode=pose_mode.lower(), + texture_prompt=texture_prompt, + texture_image_url=texture_image_url, + seed=seed, + ), + ) + result = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/meshy/openapi/v1/image-to-3d/{response.result}"), + response_model=MeshyModelResult, + status_extractor=lambda r: r.status, + progress_extractor=lambda r: r.progress, + ) + model_file = f"meshy_model_{response.result}.glb" + await download_url_to_bytesio(result.model_urls.glb, os.path.join(get_output_directory(), model_file)) + return IO.NodeOutput(model_file, response.result) + + +class MeshyMultiImageToModelNode(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="MeshyMultiImageToModelNode", + display_name="Meshy: Multi-Image to Model", + category="api node/3d/Meshy", + inputs=[ + IO.Combo.Input("model", options=["latest"]), + IO.Autogrow.Input( + "images", + template=IO.Autogrow.TemplatePrefix(IO.Image.Input("image"), prefix="image", min=2, max=4), + ), + IO.DynamicCombo.Input( + "should_remesh", + options=[ + IO.DynamicCombo.Option( + "true", + [ + IO.Combo.Input("topology", options=["triangle", "quad"]), + IO.Int.Input( + "target_polycount", + default=300000, + min=100, + max=300000, + display_mode=IO.NumberDisplay.number, + ), + ], + ), + IO.DynamicCombo.Option("false", []), + ], + tooltip="When set to false, returns an unprocessed triangular mesh.", + ), + IO.Combo.Input("symmetry_mode", options=["auto", "on", "off"]), + IO.DynamicCombo.Input( + "should_texture", + options=[ + IO.DynamicCombo.Option( + "true", + [ + IO.Boolean.Input( + "enable_pbr", + default=False, + tooltip="Generate PBR Maps (metallic, roughness, normal) " + "in addition to the base color.", + ), + IO.String.Input( + "texture_prompt", + default="", + multiline=True, + tooltip="Provide a text prompt to guide the texturing process. " + "Maximum 600 characters. Cannot be used at the same time as 'texture_image'.", + ), + IO.Image.Input( + "texture_image", + tooltip="Only one of 'texture_image' or 'texture_prompt' " + "may be used at the same time.", + optional=True, + ), + ], + ), + IO.DynamicCombo.Option("false", []), + ], + tooltip="Determines whether textures are generated. " + "Setting it to false skips the texture phase and returns a mesh without textures.", + ), + IO.Combo.Input( + "pose_mode", + options=["", "A-pose", "T-pose"], + tooltip="Specify the pose mode for the generated model.", + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + tooltip="Seed controls whether the node should re-run; " + "results are non-deterministic regardless of seed.", + ), + ], + outputs=[ + IO.String.Output(display_name="model_file"), + IO.Custom("MESHY_TASK_ID").Output(display_name="meshy_task_id"), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + is_output_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["should_texture"]), + expr=""" + ( + $prices := {"true": 0.6, "false": 0.2}; + {"type":"usd","usd": $lookup($prices, widgets.should_texture)} + ) + """, + ), + ) + + @classmethod + async def execute( + cls, + model: str, + images: IO.Autogrow.Type, + should_remesh: InputShouldRemesh, + symmetry_mode: str, + should_texture: InputShouldTexture, + pose_mode: str, + seed: int, + ) -> IO.NodeOutput: + texture = should_texture["should_texture"] == "true" + texture_image_url = texture_prompt = None + if texture: + if should_texture["texture_prompt"] and should_texture["texture_image"] is not None: + raise ValueError("texture_prompt and texture_image cannot be used at the same time") + if should_texture["texture_prompt"]: + validate_string(should_texture["texture_prompt"], field_name="texture_prompt", max_length=600) + texture_prompt = should_texture["texture_prompt"] + if should_texture["texture_image"] is not None: + texture_image_url = ( + await upload_images_to_comfyapi( + cls, should_texture["texture_image"], wait_label="Uploading texture" + ) + )[0] + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/meshy/openapi/v1/multi-image-to-3d", method="POST"), + response_model=MeshyTaskResponse, + data=MeshyMultiImageToModelRequest( + image_urls=await upload_images_to_comfyapi( + cls, list(images.values()), wait_label="Uploading base images" + ), + ai_model=model, + topology=should_remesh.get("topology", None), + target_polycount=should_remesh.get("target_polycount", None), + symmetry_mode=symmetry_mode, + should_remesh=should_remesh["should_remesh"] == "true", + should_texture=texture, + enable_pbr=should_texture.get("enable_pbr", None), + pose_mode=pose_mode.lower(), + texture_prompt=texture_prompt, + texture_image_url=texture_image_url, + seed=seed, + ), + ) + result = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/meshy/openapi/v1/multi-image-to-3d/{response.result}"), + response_model=MeshyModelResult, + status_extractor=lambda r: r.status, + progress_extractor=lambda r: r.progress, + ) + model_file = f"meshy_model_{response.result}.glb" + await download_url_to_bytesio(result.model_urls.glb, os.path.join(get_output_directory(), model_file)) + return IO.NodeOutput(model_file, response.result) + + +class MeshyRigModelNode(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="MeshyRigModelNode", + display_name="Meshy: Rig Model", + category="api node/3d/Meshy", + description="Provides a rigged character in standard formats. " + "Auto-rigging is currently not suitable for untextured meshes, non-humanoid assets, " + "or humanoid assets with unclear limb and body structure.", + inputs=[ + IO.Custom("MESHY_TASK_ID").Input("meshy_task_id"), + IO.Float.Input( + "height_meters", + min=0.1, + max=15.0, + default=1.7, + tooltip="The approximate height of the character model in meters. " + "This aids in scaling and rigging accuracy.", + ), + IO.Image.Input( + "texture_image", + tooltip="The model's UV-unwrapped base color texture image.", + optional=True, + ), + ], + outputs=[ + IO.String.Output(display_name="model_file"), + IO.Custom("MESHY_RIGGED_TASK_ID").Output(display_name="rig_task_id"), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + is_output_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.2}""", + ), + ) + + @classmethod + async def execute( + cls, + meshy_task_id: str, + height_meters: float, + texture_image: Input.Image | None = None, + ) -> IO.NodeOutput: + texture_image_url = None + if texture_image is not None: + texture_image_url = (await upload_images_to_comfyapi(cls, texture_image, wait_label="Uploading texture"))[0] + response = await sync_op( + cls, + endpoint=ApiEndpoint(path="/proxy/meshy/openapi/v1/rigging", method="POST"), + response_model=MeshyTaskResponse, + data=MeshyRiggingRequest( + input_task_id=meshy_task_id, + height_meters=height_meters, + texture_image_url=texture_image_url, + ), + ) + result = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/meshy/openapi/v1/rigging/{response.result}"), + response_model=MeshyRiggedResult, + status_extractor=lambda r: r.status, + progress_extractor=lambda r: r.progress, + ) + model_file = f"meshy_model_{response.result}.glb" + await download_url_to_bytesio( + result.result.rigged_character_glb_url, os.path.join(get_output_directory(), model_file) + ) + return IO.NodeOutput(model_file, response.result) + + +class MeshyAnimateModelNode(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="MeshyAnimateModelNode", + display_name="Meshy: Animate Model", + category="api node/3d/Meshy", + description="Apply a specific animation action to a previously rigged character.", + inputs=[ + IO.Custom("MESHY_RIGGED_TASK_ID").Input("rig_task_id"), + IO.Int.Input( + "action_id", + default=0, + min=0, + max=696, + tooltip="Visit https://docs.meshy.ai/en/api/animation-library for a list of available values.", + ), + ], + outputs=[ + IO.String.Output(display_name="model_file"), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + is_output_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.12}""", + ), + ) + + @classmethod + async def execute( + cls, + rig_task_id: str, + action_id: int, + ) -> IO.NodeOutput: + response = await sync_op( + cls, + endpoint=ApiEndpoint(path="/proxy/meshy/openapi/v1/animations", method="POST"), + response_model=MeshyTaskResponse, + data=MeshyAnimationRequest( + rig_task_id=rig_task_id, + action_id=action_id, + ), + ) + result = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/meshy/openapi/v1/animations/{response.result}"), + response_model=MeshyAnimationResult, + status_extractor=lambda r: r.status, + progress_extractor=lambda r: r.progress, + ) + model_file = f"meshy_model_{response.result}.glb" + await download_url_to_bytesio(result.result.animation_glb_url, os.path.join(get_output_directory(), model_file)) + return IO.NodeOutput(model_file, response.result) + + +class MeshyTextureNode(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="MeshyTextureNode", + display_name="Meshy: Texture Model", + category="api node/3d/Meshy", + inputs=[ + IO.Combo.Input("model", options=["latest"]), + IO.Custom("MESHY_TASK_ID").Input("meshy_task_id"), + IO.Boolean.Input( + "enable_original_uv", + default=True, + tooltip="Use the original UV of the model instead of generating new UVs. " + "When enabled, Meshy preserves existing textures from the uploaded model. " + "If the model has no original UV, the quality of the output might not be as good.", + ), + IO.Boolean.Input("pbr", default=False), + IO.String.Input( + "text_style_prompt", + default="", + multiline=True, + tooltip="Describe your desired texture style of the object using text. Maximum 600 characters." + "Maximum 600 characters. Cannot be used at the same time as 'image_style'.", + ), + IO.Image.Input( + "image_style", + optional=True, + tooltip="A 2d image to guide the texturing process. " + "Can not be used at the same time with 'text_style_prompt'.", + ), + ], + outputs=[ + IO.String.Output(display_name="model_file"), + IO.Custom("MODEL_TASK_ID").Output(display_name="meshy_task_id"), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + is_output_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.4}""", + ), + ) + + @classmethod + async def execute( + cls, + model: str, + meshy_task_id: str, + enable_original_uv: bool, + pbr: bool, + text_style_prompt: str, + image_style: Input.Image | None = None, + ) -> IO.NodeOutput: + if text_style_prompt and image_style is not None: + raise ValueError("text_style_prompt and image_style cannot be used at the same time") + if not text_style_prompt and image_style is None: + raise ValueError("Either text_style_prompt or image_style is required") + image_style_url = None + if image_style is not None: + image_style_url = (await upload_images_to_comfyapi(cls, image_style, wait_label="Uploading style"))[0] + response = await sync_op( + cls, + endpoint=ApiEndpoint(path="/proxy/meshy/openapi/v1/retexture", method="POST"), + response_model=MeshyTaskResponse, + data=MeshyTextureRequest( + input_task_id=meshy_task_id, + ai_model=model, + enable_original_uv=enable_original_uv, + enable_pbr=pbr, + text_style_prompt=text_style_prompt if text_style_prompt else None, + image_style_url=image_style_url, + ), + ) + result = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/meshy/openapi/v1/retexture/{response.result}"), + response_model=MeshyModelResult, + status_extractor=lambda r: r.status, + progress_extractor=lambda r: r.progress, + ) + model_file = f"meshy_model_{response.result}.glb" + await download_url_to_bytesio(result.model_urls.glb, os.path.join(get_output_directory(), model_file)) + return IO.NodeOutput(model_file, response.result) + + +class MeshyExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[IO.ComfyNode]]: + return [ + MeshyTextToModelNode, + MeshyRefineNode, + MeshyImageToModelNode, + MeshyMultiImageToModelNode, + MeshyRigModelNode, + MeshyAnimateModelNode, + MeshyTextureNode, + ] + + +async def comfy_entrypoint() -> MeshyExtension: + return MeshyExtension() diff --git a/comfy_api_nodes/util/upload_helpers.py b/comfy_api_nodes/util/upload_helpers.py index cea0d1203..2794be35c 100644 --- a/comfy_api_nodes/util/upload_helpers.py +++ b/comfy_api_nodes/util/upload_helpers.py @@ -43,7 +43,7 @@ class UploadResponse(BaseModel): async def upload_images_to_comfyapi( cls: type[IO.ComfyNode], - image: torch.Tensor, + image: torch.Tensor | list[torch.Tensor], *, max_images: int = 8, mime_type: str | None = None, @@ -55,15 +55,28 @@ async def upload_images_to_comfyapi( Uploads images to ComfyUI API and returns download URLs. To upload multiple images, stack them in the batch dimension first. """ + tensors: list[torch.Tensor] = [] + if isinstance(image, list): + for img in image: + is_batch = len(img.shape) > 3 + if is_batch: + tensors.extend(img[i] for i in range(img.shape[0])) + else: + tensors.append(img) + else: + is_batch = len(image.shape) > 3 + if is_batch: + tensors.extend(image[i] for i in range(image.shape[0])) + else: + tensors.append(image) + # if batched, try to upload each file if max_images is greater than 0 download_urls: list[str] = [] - is_batch = len(image.shape) > 3 - batch_len = image.shape[0] if is_batch else 1 - num_to_upload = min(batch_len, max_images) + num_to_upload = min(len(tensors), max_images) batch_start_ts = time.monotonic() for idx in range(num_to_upload): - tensor = image[idx] if is_batch else image + tensor = tensors[idx] img_io = tensor_to_bytesio(tensor, total_pixels=total_pixels, mime_type=mime_type) effective_label = wait_label diff --git a/nodes.py b/nodes.py index aa8572446..f19d5fd1c 100644 --- a/nodes.py +++ b/nodes.py @@ -2401,6 +2401,7 @@ async def init_builtin_api_nodes(): "nodes_sora.py", "nodes_topaz.py", "nodes_tripo.py", + "nodes_meshy.py", "nodes_moonvalley.py", "nodes_rodin.py", "nodes_gemini.py", From 80441eb15e807aa280fb462cbb43d14191344ba4 Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Wed, 14 Jan 2026 14:53:16 -0800 Subject: [PATCH 077/119] utils: fix lanczos grayscale upscaling (#11873) --- comfy/utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/comfy/utils.py b/comfy/utils.py index fac13f128..2e33a4258 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -929,7 +929,9 @@ def bislerp(samples, width, height): return result.to(orig_dtype) def lanczos(samples, width, height): - images = [Image.fromarray(np.clip(255. * image.movedim(0, -1).cpu().numpy(), 0, 255).astype(np.uint8)) for image in samples] + #the below API is strict and expects grayscale to be squeezed + samples = samples.squeeze(1) if samples.shape[1] == 1 else samples.movedim(1, -1) + images = [Image.fromarray(np.clip(255. * image.cpu().numpy(), 0, 255).astype(np.uint8)) for image in samples] images = [image.resize((width, height), resample=Image.Resampling.LANCZOS) for image in images] images = [torch.from_numpy(np.array(image).astype(np.float32) / 255.0).movedim(-1, 0) for image in images] result = torch.stack(images) From be518db5a7daa6010fb1c312c0832b9833a71d10 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jukka=20Sepp=C3=A4nen?= <40791699+kijai@users.noreply.github.com> Date: Thu, 15 Jan 2026 00:54:04 +0200 Subject: [PATCH 078/119] Remove extraneous clip missing warnings when loading LTX2 embeddings_connector weights (#11874) --- comfy/text_encoders/lt.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/comfy/text_encoders/lt.py b/comfy/text_encoders/lt.py index 776e25e97..c33c77db7 100644 --- a/comfy/text_encoders/lt.py +++ b/comfy/text_encoders/lt.py @@ -118,8 +118,9 @@ class LTXAVTEModel(torch.nn.Module): sdo = comfy.utils.state_dict_prefix_replace(sd, {"text_embedding_projection.aggregate_embed.weight": "text_embedding_projection.weight", "model.diffusion_model.video_embeddings_connector.": "video_embeddings_connector.", "model.diffusion_model.audio_embeddings_connector.": "audio_embeddings_connector."}, filter_keys=True) if len(sdo) == 0: sdo = sd - - return self.load_state_dict(sdo, strict=False) + missing, unexpected = self.load_state_dict(sdo, strict=False) + missing = [k for k in missing if not k.startswith("gemma3_12b.")] # filter out keys that belong to the main gemma model + return (missing, unexpected) def memory_estimation_function(self, token_weight_pairs, device=None): constant = 6.0 From 3b832231bb81024d80bbe31b7d7e51e07b633beb Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Thu, 15 Jan 2026 07:33:15 -0800 Subject: [PATCH 079/119] Flux2 Klein support. (#11890) --- comfy/sd.py | 15 +++++++-- comfy/text_encoders/flux.py | 59 +++++++++++++++++++++++++++++++++++- comfy/text_encoders/llama.py | 31 +++++++++++++++++++ 3 files changed, 102 insertions(+), 3 deletions(-) diff --git a/comfy/sd.py b/comfy/sd.py index b689c0dfc..77700dfd3 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -1014,6 +1014,7 @@ class CLIPType(Enum): KANDINSKY5 = 22 KANDINSKY5_IMAGE = 23 NEWBIE = 24 + FLUX2 = 25 def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}): @@ -1046,6 +1047,7 @@ class TEModel(Enum): QWEN3_2B = 17 GEMMA_3_12B = 18 JINA_CLIP_2 = 19 + QWEN3_8B = 20 def detect_te_model(sd): @@ -1089,6 +1091,8 @@ def detect_te_model(sd): return TEModel.QWEN3_4B elif weight.shape[0] == 2048: return TEModel.QWEN3_2B + elif weight.shape[0] == 4096: + return TEModel.QWEN3_8B if weight.shape[0] == 5120: if "model.layers.39.post_attention_layernorm.weight" in sd: return TEModel.MISTRAL3_24B @@ -1214,11 +1218,18 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip clip_target.tokenizer = comfy.text_encoders.flux.Flux2Tokenizer tokenizer_data["tekken_model"] = clip_data[0].get("tekken_model", None) elif te_model == TEModel.QWEN3_4B: - clip_target.clip = comfy.text_encoders.z_image.te(**llama_detect(clip_data)) - clip_target.tokenizer = comfy.text_encoders.z_image.ZImageTokenizer + if clip_type == CLIPType.FLUX or clip_type == CLIPType.FLUX2: + clip_target.clip = comfy.text_encoders.flux.klein_te(**llama_detect(clip_data), model_type="qwen3_4b") + clip_target.tokenizer = comfy.text_encoders.flux.KleinTokenizer + else: + clip_target.clip = comfy.text_encoders.z_image.te(**llama_detect(clip_data)) + clip_target.tokenizer = comfy.text_encoders.z_image.ZImageTokenizer elif te_model == TEModel.QWEN3_2B: clip_target.clip = comfy.text_encoders.ovis.te(**llama_detect(clip_data)) clip_target.tokenizer = comfy.text_encoders.ovis.OvisTokenizer + elif te_model == TEModel.QWEN3_8B: + clip_target.clip = comfy.text_encoders.flux.klein_te(**llama_detect(clip_data), model_type="qwen3_8b") + clip_target.tokenizer = comfy.text_encoders.flux.KleinTokenizer8B elif te_model == TEModel.JINA_CLIP_2: clip_target.clip = comfy.text_encoders.jina_clip_2.JinaClip2TextModelWrapper clip_target.tokenizer = comfy.text_encoders.jina_clip_2.JinaClip2TokenizerWrapper diff --git a/comfy/text_encoders/flux.py b/comfy/text_encoders/flux.py index 21d93d757..4075afca4 100644 --- a/comfy/text_encoders/flux.py +++ b/comfy/text_encoders/flux.py @@ -3,7 +3,7 @@ import comfy.text_encoders.t5 import comfy.text_encoders.sd3_clip import comfy.text_encoders.llama import comfy.model_management -from transformers import T5TokenizerFast, LlamaTokenizerFast +from transformers import T5TokenizerFast, LlamaTokenizerFast, Qwen2Tokenizer import torch import os import json @@ -172,3 +172,60 @@ def flux2_te(dtype_llama=None, llama_quantization_metadata=None, pruned=False): model_options["num_layers"] = 30 super().__init__(device=device, dtype=dtype, model_options=model_options) return Flux2TEModel_ + +class Qwen3Tokenizer(sd1_clip.SDTokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "qwen25_tokenizer") + super().__init__(tokenizer_path, pad_with_end=False, embedding_size=2560, embedding_key='qwen3_4b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=512, pad_token=151643, tokenizer_data=tokenizer_data) + +class Qwen3Tokenizer8B(sd1_clip.SDTokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "qwen25_tokenizer") + super().__init__(tokenizer_path, pad_with_end=False, embedding_size=4096, embedding_key='qwen3_8b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=512, pad_token=151643, tokenizer_data=tokenizer_data) + +class KleinTokenizer(sd1_clip.SD1Tokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}, name="qwen3_4b"): + if name == "qwen3_4b": + tokenizer = Qwen3Tokenizer + elif name == "qwen3_8b": + tokenizer = Qwen3Tokenizer8B + + super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name=name, tokenizer=tokenizer) + self.llama_template = "<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n\n\n\n\n" + + def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, **kwargs): + if llama_template is None: + llama_text = self.llama_template.format(text) + else: + llama_text = llama_template.format(text) + + tokens = super().tokenize_with_weights(llama_text, return_word_ids=return_word_ids, disable_weights=True, **kwargs) + return tokens + +class KleinTokenizer8B(KleinTokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}, name="qwen3_8b"): + super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name=name) + +class Qwen3_4BModel(sd1_clip.SDClipModel): + def __init__(self, device="cpu", layer=[9, 18, 27], layer_idx=None, dtype=None, attention_mask=True, model_options={}): + super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Qwen3_4B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options) + +class Qwen3_8BModel(sd1_clip.SDClipModel): + def __init__(self, device="cpu", layer=[9, 18, 27], layer_idx=None, dtype=None, attention_mask=True, model_options={}): + super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Qwen3_8B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options) + +def klein_te(dtype_llama=None, llama_quantization_metadata=None, model_type="qwen3_4b"): + if model_type == "qwen3_4b": + model = Qwen3_4BModel + elif model_type == "qwen3_8b": + model = Qwen3_8BModel + + class Flux2TEModel_(Flux2TEModel): + def __init__(self, device="cpu", dtype=None, model_options={}): + if llama_quantization_metadata is not None: + model_options = model_options.copy() + model_options["quantization_metadata"] = llama_quantization_metadata + if dtype_llama is not None: + dtype = dtype_llama + super().__init__(device=device, dtype=dtype, name=model_type, model_options=model_options, clip_model=model) + return Flux2TEModel_ diff --git a/comfy/text_encoders/llama.py b/comfy/text_encoders/llama.py index 76731576b..331a30f61 100644 --- a/comfy/text_encoders/llama.py +++ b/comfy/text_encoders/llama.py @@ -99,6 +99,28 @@ class Qwen3_4BConfig: rope_scale = None final_norm: bool = True +@dataclass +class Qwen3_8BConfig: + vocab_size: int = 151936 + hidden_size: int = 4096 + intermediate_size: int = 12288 + num_hidden_layers: int = 36 + num_attention_heads: int = 32 + num_key_value_heads: int = 8 + max_position_embeddings: int = 40960 + rms_norm_eps: float = 1e-6 + rope_theta: float = 1000000.0 + transformer_type: str = "llama" + head_dim = 128 + rms_norm_add = False + mlp_activation = "silu" + qkv_bias = False + rope_dims = None + q_norm = "gemma3" + k_norm = "gemma3" + rope_scale = None + final_norm: bool = True + @dataclass class Ovis25_2BConfig: vocab_size: int = 151936 @@ -628,6 +650,15 @@ class Qwen3_4B(BaseLlama, torch.nn.Module): self.model = Llama2_(config, device=device, dtype=dtype, ops=operations) self.dtype = dtype +class Qwen3_8B(BaseLlama, torch.nn.Module): + def __init__(self, config_dict, dtype, device, operations): + super().__init__() + config = Qwen3_8BConfig(**config_dict) + self.num_layers = config.num_hidden_layers + + self.model = Llama2_(config, device=device, dtype=dtype, ops=operations) + self.dtype = dtype + class Ovis25_2B(BaseLlama, torch.nn.Module): def __init__(self, config_dict, dtype, device, operations): super().__init__() From 8f40b43e0204d5b9780f3e9618e140e929e80594 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 15 Jan 2026 10:57:35 -0500 Subject: [PATCH 080/119] ComfyUI v0.9.2 --- comfyui_version.py | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/comfyui_version.py b/comfyui_version.py index 0c9871e35..dbb57b4e5 100644 --- a/comfyui_version.py +++ b/comfyui_version.py @@ -1,3 +1,3 @@ # This file is automatically generated by the build process when version is # updated in pyproject.toml. -__version__ = "0.9.1" +__version__ = "0.9.2" diff --git a/pyproject.toml b/pyproject.toml index dc52218b4..9ea73da05 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ComfyUI" -version = "0.9.1" +version = "0.9.2" readme = "README.md" license = { file = "LICENSE" } requires-python = ">=3.10" From 12918a5f789d11c7d3c9d9f732891337740fe96f Mon Sep 17 00:00:00 2001 From: ComfyUI Wiki Date: Fri, 16 Jan 2026 03:08:21 +0800 Subject: [PATCH 081/119] chore: update workflow templates to v0.8.7 (#11896) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 8650d28ec..624cd067b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.36.14 -comfyui-workflow-templates==0.8.4 +comfyui-workflow-templates==0.8.7 comfyui-embedded-docs==0.4.0 torch torchsde From 6125b3a5e7215bf01874e402525552a7f5657a41 Mon Sep 17 00:00:00 2001 From: ComfyUI Wiki Date: Fri, 16 Jan 2026 05:12:13 +0800 Subject: [PATCH 082/119] Update workflow templates to v0.8.10 (#11899) * chore: update workflow templates to v0.8.9 * Update requirements.txt --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 624cd067b..996701550 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.36.14 -comfyui-workflow-templates==0.8.7 +comfyui-workflow-templates==0.8.10 comfyui-embedded-docs==0.4.0 torch torchsde From 4c816d5c698dafaa31f8fc2c08ab1d81f9bc3239 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Thu, 15 Jan 2026 17:06:40 -0800 Subject: [PATCH 083/119] Adjust memory usage factor calculation for flux2 klein. (#11900) --- comfy/supported_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 2c4c6b8fc..c8a7f6efb 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -763,7 +763,7 @@ class Flux2(Flux): def __init__(self, unet_config): super().__init__(unet_config) - self.memory_usage_factor = self.memory_usage_factor * (2.0 * 2.0) * 2.36 + self.memory_usage_factor = self.memory_usage_factor * (2.0 * 2.0) * (unet_config['hidden_size'] / 2604) def get_model(self, state_dict, prefix="", device=None): out = model_base.Flux2(self, device=device) From 732b707397922dbbec5ed04ecca3c773c878c64e Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Thu, 15 Jan 2026 20:15:15 -0800 Subject: [PATCH 084/119] Added try-except around seed_assets call in get_object_info with a logging statement (#11901) --- server.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/server.py b/server.py index da2baefd4..04a577488 100644 --- a/server.py +++ b/server.py @@ -686,7 +686,10 @@ class PromptServer(): @routes.get("/object_info") async def get_object_info(request): - seed_assets(["models"]) + try: + seed_assets(["models"]) + except Exception as e: + logging.error(f"Failed to seed assets: {e}") with folder_paths.cache_helper: out = {} for x in nodes.NODE_CLASS_MAPPINGS: From 9125613b53fc6af219d5a3db1d5b202ccc3f41b3 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Fri, 16 Jan 2026 08:09:07 +0200 Subject: [PATCH 085/119] feat(api-nodes): extend ByteDance nodes with seedance-1-5-pro model (#11871) --- comfy_api_nodes/apis/bytedance_api.py | 7 ++ comfy_api_nodes/nodes_bytedance.py | 104 +++++++++++++++++++++++--- 2 files changed, 101 insertions(+), 10 deletions(-) diff --git a/comfy_api_nodes/apis/bytedance_api.py b/comfy_api_nodes/apis/bytedance_api.py index b8c2f618b..400648cca 100644 --- a/comfy_api_nodes/apis/bytedance_api.py +++ b/comfy_api_nodes/apis/bytedance_api.py @@ -65,11 +65,13 @@ class TaskImageContent(BaseModel): class Text2VideoTaskCreationRequest(BaseModel): model: str = Field(...) content: list[TaskTextContent] = Field(..., min_length=1) + generate_audio: bool | None = Field(...) class Image2VideoTaskCreationRequest(BaseModel): model: str = Field(...) content: list[TaskTextContent | TaskImageContent] = Field(..., min_length=2) + generate_audio: bool | None = Field(...) class TaskCreationResponse(BaseModel): @@ -141,4 +143,9 @@ VIDEO_TASKS_EXECUTION_TIME = { "720p": 65, "1080p": 100, }, + "seedance-1-5-pro-251215": { + "480p": 80, + "720p": 100, + "1080p": 150, + }, } diff --git a/comfy_api_nodes/nodes_bytedance.py b/comfy_api_nodes/nodes_bytedance.py index f09a4a0ed..9cb1ca004 100644 --- a/comfy_api_nodes/nodes_bytedance.py +++ b/comfy_api_nodes/nodes_bytedance.py @@ -477,7 +477,12 @@ class ByteDanceTextToVideoNode(IO.ComfyNode): inputs=[ IO.Combo.Input( "model", - options=["seedance-1-0-pro-250528", "seedance-1-0-lite-t2v-250428", "seedance-1-0-pro-fast-251015"], + options=[ + "seedance-1-5-pro-251215", + "seedance-1-0-pro-250528", + "seedance-1-0-lite-t2v-250428", + "seedance-1-0-pro-fast-251015", + ], default="seedance-1-0-pro-fast-251015", ), IO.String.Input( @@ -528,6 +533,12 @@ class ByteDanceTextToVideoNode(IO.ComfyNode): tooltip='Whether to add an "AI generated" watermark to the video.', optional=True, ), + IO.Boolean.Input( + "generate_audio", + default=False, + tooltip="This parameter is ignored for any model except seedance-1-5-pro.", + optional=True, + ), ], outputs=[ IO.Video.Output(), @@ -552,7 +563,10 @@ class ByteDanceTextToVideoNode(IO.ComfyNode): seed: int, camera_fixed: bool, watermark: bool, + generate_audio: bool = False, ) -> IO.NodeOutput: + if model == "seedance-1-5-pro-251215" and duration < 4: + raise ValueError("Minimum supported duration for Seedance 1.5 Pro is 4 seconds.") validate_string(prompt, strip_whitespace=True, min_length=1) raise_if_text_params(prompt, ["resolution", "ratio", "duration", "seed", "camerafixed", "watermark"]) @@ -567,7 +581,11 @@ class ByteDanceTextToVideoNode(IO.ComfyNode): ) return await process_video_task( cls, - payload=Text2VideoTaskCreationRequest(model=model, content=[TaskTextContent(text=prompt)]), + payload=Text2VideoTaskCreationRequest( + model=model, + content=[TaskTextContent(text=prompt)], + generate_audio=generate_audio if model == "seedance-1-5-pro-251215" else None, + ), estimated_duration=max(1, math.ceil(VIDEO_TASKS_EXECUTION_TIME[model][resolution] * (duration / 10.0))), ) @@ -584,7 +602,12 @@ class ByteDanceImageToVideoNode(IO.ComfyNode): inputs=[ IO.Combo.Input( "model", - options=["seedance-1-0-pro-250528", "seedance-1-0-lite-t2v-250428", "seedance-1-0-pro-fast-251015"], + options=[ + "seedance-1-5-pro-251215", + "seedance-1-0-pro-250528", + "seedance-1-0-lite-i2v-250428", + "seedance-1-0-pro-fast-251015", + ], default="seedance-1-0-pro-fast-251015", ), IO.String.Input( @@ -639,6 +662,12 @@ class ByteDanceImageToVideoNode(IO.ComfyNode): tooltip='Whether to add an "AI generated" watermark to the video.', optional=True, ), + IO.Boolean.Input( + "generate_audio", + default=False, + tooltip="This parameter is ignored for any model except seedance-1-5-pro.", + optional=True, + ), ], outputs=[ IO.Video.Output(), @@ -664,7 +693,10 @@ class ByteDanceImageToVideoNode(IO.ComfyNode): seed: int, camera_fixed: bool, watermark: bool, + generate_audio: bool = False, ) -> IO.NodeOutput: + if model == "seedance-1-5-pro-251215" and duration < 4: + raise ValueError("Minimum supported duration for Seedance 1.5 Pro is 4 seconds.") validate_string(prompt, strip_whitespace=True, min_length=1) raise_if_text_params(prompt, ["resolution", "ratio", "duration", "seed", "camerafixed", "watermark"]) validate_image_dimensions(image, min_width=300, min_height=300, max_width=6000, max_height=6000) @@ -686,6 +718,7 @@ class ByteDanceImageToVideoNode(IO.ComfyNode): payload=Image2VideoTaskCreationRequest( model=model, content=[TaskTextContent(text=prompt), TaskImageContent(image_url=TaskImageContentUrl(url=image_url))], + generate_audio=generate_audio if model == "seedance-1-5-pro-251215" else None, ), estimated_duration=max(1, math.ceil(VIDEO_TASKS_EXECUTION_TIME[model][resolution] * (duration / 10.0))), ) @@ -703,7 +736,7 @@ class ByteDanceFirstLastFrameNode(IO.ComfyNode): inputs=[ IO.Combo.Input( "model", - options=["seedance-1-0-pro-250528", "seedance-1-0-lite-i2v-250428"], + options=["seedance-1-5-pro-251215", "seedance-1-0-pro-250528", "seedance-1-0-lite-i2v-250428"], default="seedance-1-0-lite-i2v-250428", ), IO.String.Input( @@ -762,6 +795,12 @@ class ByteDanceFirstLastFrameNode(IO.ComfyNode): tooltip='Whether to add an "AI generated" watermark to the video.', optional=True, ), + IO.Boolean.Input( + "generate_audio", + default=False, + tooltip="This parameter is ignored for any model except seedance-1-5-pro.", + optional=True, + ), ], outputs=[ IO.Video.Output(), @@ -788,7 +827,10 @@ class ByteDanceFirstLastFrameNode(IO.ComfyNode): seed: int, camera_fixed: bool, watermark: bool, + generate_audio: bool = False, ) -> IO.NodeOutput: + if model == "seedance-1-5-pro-251215" and duration < 4: + raise ValueError("Minimum supported duration for Seedance 1.5 Pro is 4 seconds.") validate_string(prompt, strip_whitespace=True, min_length=1) raise_if_text_params(prompt, ["resolution", "ratio", "duration", "seed", "camerafixed", "watermark"]) for i in (first_frame, last_frame): @@ -821,6 +863,7 @@ class ByteDanceFirstLastFrameNode(IO.ComfyNode): TaskImageContent(image_url=TaskImageContentUrl(url=str(download_urls[0])), role="first_frame"), TaskImageContent(image_url=TaskImageContentUrl(url=str(download_urls[1])), role="last_frame"), ], + generate_audio=generate_audio if model == "seedance-1-5-pro-251215" else None, ), estimated_duration=max(1, math.ceil(VIDEO_TASKS_EXECUTION_TIME[model][resolution] * (duration / 10.0))), ) @@ -896,7 +939,41 @@ class ByteDanceImageReferenceNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, - price_badge=PRICE_BADGE_VIDEO, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["model", "duration", "resolution"]), + expr=""" + ( + $priceByModel := { + "seedance-1-0-pro": { + "480p":[0.23,0.24], + "720p":[0.51,0.56] + }, + "seedance-1-0-lite": { + "480p":[0.17,0.18], + "720p":[0.37,0.41] + } + }; + $model := widgets.model; + $modelKey := + $contains($model, "seedance-1-0-pro") ? "seedance-1-0-pro" : + "seedance-1-0-lite"; + $resolution := widgets.resolution; + $resKey := + $contains($resolution, "720") ? "720p" : + "480p"; + $modelPrices := $lookup($priceByModel, $modelKey); + $baseRange := $lookup($modelPrices, $resKey); + $min10s := $baseRange[0]; + $max10s := $baseRange[1]; + $scale := widgets.duration / 10; + $minCost := $min10s * $scale; + $maxCost := $max10s * $scale; + ($minCost = $maxCost) + ? {"type":"usd","usd": $minCost} + : {"type":"range_usd","min_usd": $minCost, "max_usd": $maxCost} + ) + """, + ), ) @classmethod @@ -967,10 +1044,15 @@ def raise_if_text_params(prompt: str, text_params: list[str]) -> None: PRICE_BADGE_VIDEO = IO.PriceBadge( - depends_on=IO.PriceBadgeDepends(widgets=["model", "duration", "resolution"]), + depends_on=IO.PriceBadgeDepends(widgets=["model", "duration", "resolution", "generate_audio"]), expr=""" ( $priceByModel := { + "seedance-1-5-pro": { + "480p":[0.12,0.12], + "720p":[0.26,0.26], + "1080p":[0.58,0.59] + }, "seedance-1-0-pro": { "480p":[0.23,0.24], "720p":[0.51,0.56], @@ -989,6 +1071,7 @@ PRICE_BADGE_VIDEO = IO.PriceBadge( }; $model := widgets.model; $modelKey := + $contains($model, "seedance-1-5-pro") ? "seedance-1-5-pro" : $contains($model, "seedance-1-0-pro-fast") ? "seedance-1-0-pro-fast" : $contains($model, "seedance-1-0-pro") ? "seedance-1-0-pro" : "seedance-1-0-lite"; @@ -1002,11 +1085,12 @@ PRICE_BADGE_VIDEO = IO.PriceBadge( $min10s := $baseRange[0]; $max10s := $baseRange[1]; $scale := widgets.duration / 10; - $minCost := $min10s * $scale; - $maxCost := $max10s * $scale; + $audioMultiplier := ($modelKey = "seedance-1-5-pro" and widgets.generate_audio) ? 2 : 1; + $minCost := $min10s * $scale * $audioMultiplier; + $maxCost := $max10s * $scale * $audioMultiplier; ($minCost = $maxCost) - ? {"type":"usd","usd": $minCost} - : {"type":"range_usd","min_usd": $minCost, "max_usd": $maxCost} + ? {"type":"usd","usd": $minCost, "format": { "approximate": true }} + : {"type":"range_usd","min_usd": $minCost, "max_usd": $maxCost, "format": { "approximate": true }} ) """, ) From 0c6b36c6ac1c34515cdf28f777a63074cd6d563d Mon Sep 17 00:00:00 2001 From: ComfyUI Wiki Date: Sat, 17 Jan 2026 06:22:50 +0800 Subject: [PATCH 086/119] chore: update workflow templates to v0.8.11 (#11918) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 996701550..3876274f9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.36.14 -comfyui-workflow-templates==0.8.10 +comfyui-workflow-templates==0.8.11 comfyui-embedded-docs==0.4.0 torch torchsde From 7ac999bf3069b06648a749212f59237080a75591 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Fri, 16 Jan 2026 20:02:28 -0800 Subject: [PATCH 087/119] Add image sizes to clip vision outputs. (#11923) --- comfy/clip_vision.py | 1 + 1 file changed, 1 insertion(+) diff --git a/comfy/clip_vision.py b/comfy/clip_vision.py index 66f2a9d9c..b28bf636c 100644 --- a/comfy/clip_vision.py +++ b/comfy/clip_vision.py @@ -66,6 +66,7 @@ class ClipVisionModel(): outputs = Output() outputs["last_hidden_state"] = out[0].to(comfy.model_management.intermediate_device()) outputs["image_embeds"] = out[2].to(comfy.model_management.intermediate_device()) + outputs["image_sizes"] = [pixel_values.shape[1:]] * pixel_values.shape[0] if self.return_all_hidden_states: all_hs = out[1].to(comfy.model_management.intermediate_device()) outputs["penultimate_hidden_states"] = all_hs[:, -2] From 00c775950aec5c563f532c8db08dae5e6adc24eb Mon Sep 17 00:00:00 2001 From: Alex Butler Date: Sun, 18 Jan 2026 01:18:04 +0000 Subject: [PATCH 088/119] Update readme rdna3 nightly url (#11937) --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index e25f3cda7..123cc9472 100644 --- a/README.md +++ b/README.md @@ -240,7 +240,7 @@ These have less hardware support than the builds above but they work on windows. RDNA 3 (RX 7000 series): -```pip install --pre torch torchvision torchaudio --index-url https://rocm.nightlies.amd.com/v2/gfx110X-dgpu/``` +```pip install --pre torch torchvision torchaudio --index-url https://rocm.nightlies.amd.com/v2/gfx110X-all/``` RDNA 3.5 (Strix halo/Ryzen AI Max+ 365): From 0fd10ffa09588e0fc7f576ab7d0c93e97ad5fbb0 Mon Sep 17 00:00:00 2001 From: Theephop <144770658+TheephopWS@users.noreply.github.com> Date: Sun, 18 Jan 2026 09:18:24 +0800 Subject: [PATCH 089/119] fix: use .cpu() for waveform conversion in AudioFrame creation (#11787) --- comfy_api/latest/_input_impl/video_types.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy_api/latest/_input_impl/video_types.py b/comfy_api/latest/_input_impl/video_types.py index ea35c6062..1405d0b81 100644 --- a/comfy_api/latest/_input_impl/video_types.py +++ b/comfy_api/latest/_input_impl/video_types.py @@ -374,7 +374,7 @@ class VideoFromComponents(VideoInput): if audio_stream and self.__components.audio: waveform = self.__components.audio['waveform'] waveform = waveform[:, :, :math.ceil((audio_sample_rate / frame_rate) * self.__components.images.shape[0])] - frame = av.AudioFrame.from_ndarray(waveform.movedim(2, 1).reshape(1, -1).float().numpy(), format='flt', layout='mono' if waveform.shape[1] == 1 else 'stereo') + frame = av.AudioFrame.from_ndarray(waveform.movedim(2, 1).reshape(1, -1).float().cpu().numpy(), format='flt', layout='mono' if waveform.shape[1] == 1 else 'stereo') frame.sample_rate = audio_sample_rate frame.pts = 0 output.mux(audio_stream.encode(frame)) From 190c4416cce3b3b97b628935e001d796d565bfc9 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sat, 17 Jan 2026 18:20:35 -0800 Subject: [PATCH 090/119] Bump comfy-kitchen dependency to version 0.2.7 (#11941) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 3876274f9..622256973 100644 --- a/requirements.txt +++ b/requirements.txt @@ -21,7 +21,7 @@ psutil alembic SQLAlchemy av>=14.2.0 -comfy-kitchen>=0.2.6 +comfy-kitchen>=0.2.7 #non essential dependencies: kornia>=0.7.1 From ac26065e6125871e2a742db6960f183fa037a75d Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Sun, 18 Jan 2026 04:52:45 +0200 Subject: [PATCH 091/119] chore(api-nodes): remove non-used; extract model to separate files (#11927) * chore(api-nodes): remove non-used; extract model to separate files * chore(api-nodes): remove non-needed prefix in filenames --- comfy_api_nodes/README.md | 65 ---- comfy_api_nodes/apis/{bfl_api.py => bfl.py} | 0 .../apis/{bytedance_api.py => bytedance.py} | 0 .../apis/{gemini_api.py => gemini.py} | 0 comfy_api_nodes/apis/ideogram.py | 292 +++++++++++++++++ .../apis/{kling_api.py => kling.py} | 0 comfy_api_nodes/apis/{luma_api.py => luma.py} | 0 .../apis/{minimax_api.py => minimax.py} | 0 comfy_api_nodes/apis/moonvalley.py | 152 +++++++++ comfy_api_nodes/apis/openai.py | 170 ++++++++++ comfy_api_nodes/apis/openai_api.py | 52 --- .../apis/{pixverse_api.py => pixverse.py} | 0 .../apis/{recraft_api.py => recraft.py} | 0 .../apis/{rodin_api.py => rodin.py} | 0 comfy_api_nodes/apis/runway.py | 127 ++++++++ .../apis/{stability_api.py => stability.py} | 0 .../apis/{topaz_api.py => topaz.py} | 4 +- .../apis/{tripo_api.py => tripo.py} | 0 comfy_api_nodes/apis/{veo_api.py => veo.py} | 0 comfy_api_nodes/mapper_utils.py | 116 ------- comfy_api_nodes/nodes_bfl.py | 2 +- comfy_api_nodes/nodes_bytedance.py | 2 +- comfy_api_nodes/nodes_gemini.py | 2 +- comfy_api_nodes/nodes_ideogram.py | 2 +- comfy_api_nodes/nodes_kling.py | 2 +- comfy_api_nodes/nodes_luma.py | 2 +- comfy_api_nodes/nodes_minimax.py | 2 +- comfy_api_nodes/nodes_moonvalley.py | 2 +- comfy_api_nodes/nodes_openai.py | 86 ++--- comfy_api_nodes/nodes_pixverse.py | 2 +- comfy_api_nodes/nodes_recraft.py | 2 +- comfy_api_nodes/nodes_rodin.py | 2 +- comfy_api_nodes/nodes_runway.py | 2 +- comfy_api_nodes/nodes_stability.py | 2 +- comfy_api_nodes/nodes_topaz.py | 55 ++-- comfy_api_nodes/nodes_tripo.py | 2 +- comfy_api_nodes/nodes_veo2.py | 2 +- comfy_api_nodes/redocly-dev.yaml | 10 - comfy_api_nodes/redocly.yaml | 10 - .../comfy_api_nodes_test/mapper_utils_test.py | 297 ------------------ 40 files changed, 825 insertions(+), 641 deletions(-) delete mode 100644 comfy_api_nodes/README.md rename comfy_api_nodes/apis/{bfl_api.py => bfl.py} (100%) rename comfy_api_nodes/apis/{bytedance_api.py => bytedance.py} (100%) rename comfy_api_nodes/apis/{gemini_api.py => gemini.py} (100%) create mode 100644 comfy_api_nodes/apis/ideogram.py rename comfy_api_nodes/apis/{kling_api.py => kling.py} (100%) rename comfy_api_nodes/apis/{luma_api.py => luma.py} (100%) rename comfy_api_nodes/apis/{minimax_api.py => minimax.py} (100%) create mode 100644 comfy_api_nodes/apis/moonvalley.py create mode 100644 comfy_api_nodes/apis/openai.py delete mode 100644 comfy_api_nodes/apis/openai_api.py rename comfy_api_nodes/apis/{pixverse_api.py => pixverse.py} (100%) rename comfy_api_nodes/apis/{recraft_api.py => recraft.py} (100%) rename comfy_api_nodes/apis/{rodin_api.py => rodin.py} (100%) create mode 100644 comfy_api_nodes/apis/runway.py rename comfy_api_nodes/apis/{stability_api.py => stability.py} (100%) rename comfy_api_nodes/apis/{topaz_api.py => topaz.py} (97%) rename comfy_api_nodes/apis/{tripo_api.py => tripo.py} (100%) rename comfy_api_nodes/apis/{veo_api.py => veo.py} (100%) delete mode 100644 comfy_api_nodes/mapper_utils.py delete mode 100644 comfy_api_nodes/redocly-dev.yaml delete mode 100644 comfy_api_nodes/redocly.yaml delete mode 100644 tests-unit/comfy_api_nodes_test/mapper_utils_test.py diff --git a/comfy_api_nodes/README.md b/comfy_api_nodes/README.md deleted file mode 100644 index f56d6c860..000000000 --- a/comfy_api_nodes/README.md +++ /dev/null @@ -1,65 +0,0 @@ -# ComfyUI API Nodes - -## Introduction - -Below are a collection of nodes that work by calling external APIs. More information available in our [docs](https://docs.comfy.org/tutorials/api-nodes/overview). - -## Development - -While developing, you should be testing against the Staging environment. To test against staging: - -**Install ComfyUI_frontend** - -Follow the instructions [here](https://github.com/Comfy-Org/ComfyUI_frontend) to start the frontend server. By default, it will connect to Staging authentication. - -> **Hint:** If you use --front-end-version argument for ComfyUI, it will use production authentication. - -```bash -python run main.py --comfy-api-base https://stagingapi.comfy.org -``` - -To authenticate to staging, please login and then ask one of Comfy Org team to whitelist you for access to staging. - -API stubs are generated through automatic codegen tools from OpenAPI definitions. Since the Comfy Org OpenAPI definition contains many things from the Comfy Registry as well, we use redocly/cli to filter out only the paths relevant for API nodes. - -### Redocly Instructions - -**Tip** -When developing locally, use the `redocly-dev.yaml` file to generate pydantic models. This lets you use stubs for APIs that are not marked `Released` yet. - -Before your API node PR merges, make sure to add the `Released` tag to the `openapi.yaml` file and test in staging. - -```bash -# Download the OpenAPI file from staging server. -curl -o openapi.yaml https://stagingapi.comfy.org/openapi - -# Filter out unneeded API definitions. -npm install -g @redocly/cli -redocly bundle openapi.yaml --output filtered-openapi.yaml --config comfy_api_nodes/redocly-dev.yaml --remove-unused-components - -# Generate the pydantic datamodels for validation. -datamodel-codegen --use-subclass-enum --field-constraints --strict-types bytes --input filtered-openapi.yaml --output comfy_api_nodes/apis/__init__.py --output-model-type pydantic_v2.BaseModel - -``` - - -# Merging to Master - -Before merging to comfyanonymous/ComfyUI master, follow these steps: - -1. Add the "Released" tag to the ComfyUI OpenAPI yaml file for each endpoint you are using in the nodes. -1. Make sure the ComfyUI API is deployed to prod with your changes. -1. Run the code generation again with `redocly.yaml` and the production OpenAPI yaml file. - -```bash -# Download the OpenAPI file from prod server. -curl -o openapi.yaml https://api.comfy.org/openapi - -# Filter out unneeded API definitions. -npm install -g @redocly/cli -redocly bundle openapi.yaml --output filtered-openapi.yaml --config comfy_api_nodes/redocly.yaml --remove-unused-components - -# Generate the pydantic datamodels for validation. -datamodel-codegen --use-subclass-enum --field-constraints --strict-types bytes --input filtered-openapi.yaml --output comfy_api_nodes/apis/__init__.py --output-model-type pydantic_v2.BaseModel - -``` diff --git a/comfy_api_nodes/apis/bfl_api.py b/comfy_api_nodes/apis/bfl.py similarity index 100% rename from comfy_api_nodes/apis/bfl_api.py rename to comfy_api_nodes/apis/bfl.py diff --git a/comfy_api_nodes/apis/bytedance_api.py b/comfy_api_nodes/apis/bytedance.py similarity index 100% rename from comfy_api_nodes/apis/bytedance_api.py rename to comfy_api_nodes/apis/bytedance.py diff --git a/comfy_api_nodes/apis/gemini_api.py b/comfy_api_nodes/apis/gemini.py similarity index 100% rename from comfy_api_nodes/apis/gemini_api.py rename to comfy_api_nodes/apis/gemini.py diff --git a/comfy_api_nodes/apis/ideogram.py b/comfy_api_nodes/apis/ideogram.py new file mode 100644 index 000000000..737e18e3b --- /dev/null +++ b/comfy_api_nodes/apis/ideogram.py @@ -0,0 +1,292 @@ +from enum import Enum +from typing import Optional, List, Dict, Any, Union +from datetime import datetime + +from pydantic import BaseModel, Field, RootModel, StrictBytes + + +class IdeogramColorPalette1(BaseModel): + name: str = Field(..., description='Name of the preset color palette') + + +class Member(BaseModel): + color: Optional[str] = Field( + None, description='Hexadecimal color code', pattern='^#[0-9A-Fa-f]{6}$' + ) + weight: Optional[float] = Field( + None, description='Optional weight for the color (0-1)', ge=0.0, le=1.0 + ) + + +class IdeogramColorPalette2(BaseModel): + members: List[Member] = Field( + ..., description='Array of color definitions with optional weights' + ) + + +class IdeogramColorPalette( + RootModel[Union[IdeogramColorPalette1, IdeogramColorPalette2]] +): + root: Union[IdeogramColorPalette1, IdeogramColorPalette2] = Field( + ..., + description='A color palette specification that can either use a preset name or explicit color definitions with weights', + ) + + +class ImageRequest(BaseModel): + aspect_ratio: Optional[str] = Field( + None, + description="Optional. The aspect ratio (e.g., 'ASPECT_16_9', 'ASPECT_1_1'). Cannot be used with resolution. Defaults to 'ASPECT_1_1' if unspecified.", + ) + color_palette: Optional[Dict[str, Any]] = Field( + None, description='Optional. Color palette object. Only for V_2, V_2_TURBO.' + ) + magic_prompt_option: Optional[str] = Field( + None, description="Optional. MagicPrompt usage ('AUTO', 'ON', 'OFF')." + ) + model: str = Field(..., description="The model used (e.g., 'V_2', 'V_2A_TURBO')") + negative_prompt: Optional[str] = Field( + None, + description='Optional. Description of what to exclude. Only for V_1, V_1_TURBO, V_2, V_2_TURBO.', + ) + num_images: Optional[int] = Field( + 1, + description='Optional. Number of images to generate (1-8). Defaults to 1.', + ge=1, + le=8, + ) + prompt: str = Field( + ..., description='Required. The prompt to use to generate the image.' + ) + resolution: Optional[str] = Field( + None, + description="Optional. Resolution (e.g., 'RESOLUTION_1024_1024'). Only for model V_2. Cannot be used with aspect_ratio.", + ) + seed: Optional[int] = Field( + None, + description='Optional. A number between 0 and 2147483647.', + ge=0, + le=2147483647, + ) + style_type: Optional[str] = Field( + None, + description="Optional. Style type ('AUTO', 'GENERAL', 'REALISTIC', 'DESIGN', 'RENDER_3D', 'ANIME'). Only for models V_2 and above.", + ) + + +class IdeogramGenerateRequest(BaseModel): + image_request: ImageRequest = Field( + ..., description='The image generation request parameters.' + ) + + +class Datum(BaseModel): + is_image_safe: Optional[bool] = Field( + None, description='Indicates whether the image is considered safe.' + ) + prompt: Optional[str] = Field( + None, description='The prompt used to generate this image.' + ) + resolution: Optional[str] = Field( + None, description="The resolution of the generated image (e.g., '1024x1024')." + ) + seed: Optional[int] = Field( + None, description='The seed value used for this generation.' + ) + style_type: Optional[str] = Field( + None, + description="The style type used for generation (e.g., 'REALISTIC', 'ANIME').", + ) + url: Optional[str] = Field(None, description='URL to the generated image.') + + +class IdeogramGenerateResponse(BaseModel): + created: Optional[datetime] = Field( + None, description='Timestamp when the generation was created.' + ) + data: Optional[List[Datum]] = Field( + None, description='Array of generated image information.' + ) + + +class StyleCode(RootModel[str]): + root: str = Field(..., pattern='^[0-9A-Fa-f]{8}$') + + +class Datum1(BaseModel): + is_image_safe: Optional[bool] = None + prompt: Optional[str] = None + resolution: Optional[str] = None + seed: Optional[int] = None + style_type: Optional[str] = None + url: Optional[str] = None + + +class IdeogramV3IdeogramResponse(BaseModel): + created: Optional[datetime] = None + data: Optional[List[Datum1]] = None + + +class RenderingSpeed1(str, Enum): + TURBO = 'TURBO' + DEFAULT = 'DEFAULT' + QUALITY = 'QUALITY' + + +class IdeogramV3ReframeRequest(BaseModel): + color_palette: Optional[Dict[str, Any]] = None + image: Optional[StrictBytes] = None + num_images: Optional[int] = Field(None, ge=1, le=8) + rendering_speed: Optional[RenderingSpeed1] = None + resolution: str + seed: Optional[int] = Field(None, ge=0, le=2147483647) + style_codes: Optional[List[str]] = None + style_reference_images: Optional[List[StrictBytes]] = None + + +class MagicPrompt(str, Enum): + AUTO = 'AUTO' + ON = 'ON' + OFF = 'OFF' + + +class StyleType(str, Enum): + AUTO = 'AUTO' + GENERAL = 'GENERAL' + REALISTIC = 'REALISTIC' + DESIGN = 'DESIGN' + + +class IdeogramV3RemixRequest(BaseModel): + aspect_ratio: Optional[str] = None + color_palette: Optional[Dict[str, Any]] = None + image: Optional[StrictBytes] = None + image_weight: Optional[int] = Field(50, ge=1, le=100) + magic_prompt: Optional[MagicPrompt] = None + negative_prompt: Optional[str] = None + num_images: Optional[int] = Field(None, ge=1, le=8) + prompt: str + rendering_speed: Optional[RenderingSpeed1] = None + resolution: Optional[str] = None + seed: Optional[int] = Field(None, ge=0, le=2147483647) + style_codes: Optional[List[str]] = None + style_reference_images: Optional[List[StrictBytes]] = None + style_type: Optional[StyleType] = None + + +class IdeogramV3ReplaceBackgroundRequest(BaseModel): + color_palette: Optional[Dict[str, Any]] = None + image: Optional[StrictBytes] = None + magic_prompt: Optional[MagicPrompt] = None + num_images: Optional[int] = Field(None, ge=1, le=8) + prompt: str + rendering_speed: Optional[RenderingSpeed1] = None + seed: Optional[int] = Field(None, ge=0, le=2147483647) + style_codes: Optional[List[str]] = None + style_reference_images: Optional[List[StrictBytes]] = None + + +class ColorPalette(BaseModel): + name: str = Field(..., description='Name of the color palette', examples=['PASTEL']) + + +class MagicPrompt2(str, Enum): + ON = 'ON' + OFF = 'OFF' + + +class StyleType1(str, Enum): + AUTO = 'AUTO' + GENERAL = 'GENERAL' + REALISTIC = 'REALISTIC' + DESIGN = 'DESIGN' + FICTION = 'FICTION' + + +class RenderingSpeed(str, Enum): + DEFAULT = 'DEFAULT' + TURBO = 'TURBO' + QUALITY = 'QUALITY' + + +class IdeogramV3EditRequest(BaseModel): + color_palette: Optional[IdeogramColorPalette] = None + image: Optional[StrictBytes] = Field( + None, + description='The image being edited (max size 10MB); only JPEG, WebP and PNG formats are supported at this time.', + ) + magic_prompt: Optional[str] = Field( + None, + description='Determine if MagicPrompt should be used in generating the request or not.', + ) + mask: Optional[StrictBytes] = Field( + None, + description='A black and white image of the same size as the image being edited (max size 10MB). Black regions in the mask should match up with the regions of the image that you would like to edit; only JPEG, WebP and PNG formats are supported at this time.', + ) + num_images: Optional[int] = Field( + None, description='The number of images to generate.' + ) + prompt: str = Field( + ..., description='The prompt used to describe the edited result.' + ) + rendering_speed: RenderingSpeed + seed: Optional[int] = Field( + None, description='Random seed. Set for reproducible generation.' + ) + style_codes: Optional[List[StyleCode]] = Field( + None, + description='A list of 8 character hexadecimal codes representing the style of the image. Cannot be used in conjunction with style_reference_images or style_type.', + ) + style_reference_images: Optional[List[StrictBytes]] = Field( + None, + description='A set of images to use as style references (maximum total size 10MB across all style references). The images should be in JPEG, PNG or WebP format.', + ) + character_reference_images: Optional[List[str]] = Field( + None, + description='Generations with character reference are subject to the character reference pricing. A set of images to use as character references (maximum total size 10MB across all character references), currently only supports 1 character reference image. The images should be in JPEG, PNG or WebP format.' + ) + character_reference_images_mask: Optional[List[str]] = Field( + None, + description='Optional masks for character reference images. When provided, must match the number of character_reference_images. Each mask should be a grayscale image of the same dimensions as the corresponding character reference image. The images should be in JPEG, PNG or WebP format.' + ) + + +class IdeogramV3Request(BaseModel): + aspect_ratio: Optional[str] = Field( + None, description='Aspect ratio in format WxH', examples=['1x3'] + ) + color_palette: Optional[ColorPalette] = None + magic_prompt: Optional[MagicPrompt2] = Field( + None, description='Whether to enable magic prompt enhancement' + ) + negative_prompt: Optional[str] = Field( + None, description='Text prompt specifying what to avoid in the generation' + ) + num_images: Optional[int] = Field( + None, description='Number of images to generate', ge=1 + ) + prompt: str = Field(..., description='The text prompt for image generation') + rendering_speed: RenderingSpeed + resolution: Optional[str] = Field( + None, description='Image resolution in format WxH', examples=['1280x800'] + ) + seed: Optional[int] = Field( + None, description='Seed value for reproducible generation' + ) + style_codes: Optional[List[StyleCode]] = Field( + None, description='Array of style codes in hexadecimal format' + ) + style_reference_images: Optional[List[str]] = Field( + None, description='Array of reference image URLs or identifiers' + ) + style_type: Optional[StyleType1] = Field( + None, description='The type of style to apply' + ) + character_reference_images: Optional[List[str]] = Field( + None, + description='Generations with character reference are subject to the character reference pricing. A set of images to use as character references (maximum total size 10MB across all character references), currently only supports 1 character reference image. The images should be in JPEG, PNG or WebP format.' + ) + character_reference_images_mask: Optional[List[str]] = Field( + None, + description='Optional masks for character reference images. When provided, must match the number of character_reference_images. Each mask should be a grayscale image of the same dimensions as the corresponding character reference image. The images should be in JPEG, PNG or WebP format.' + ) diff --git a/comfy_api_nodes/apis/kling_api.py b/comfy_api_nodes/apis/kling.py similarity index 100% rename from comfy_api_nodes/apis/kling_api.py rename to comfy_api_nodes/apis/kling.py diff --git a/comfy_api_nodes/apis/luma_api.py b/comfy_api_nodes/apis/luma.py similarity index 100% rename from comfy_api_nodes/apis/luma_api.py rename to comfy_api_nodes/apis/luma.py diff --git a/comfy_api_nodes/apis/minimax_api.py b/comfy_api_nodes/apis/minimax.py similarity index 100% rename from comfy_api_nodes/apis/minimax_api.py rename to comfy_api_nodes/apis/minimax.py diff --git a/comfy_api_nodes/apis/moonvalley.py b/comfy_api_nodes/apis/moonvalley.py new file mode 100644 index 000000000..7ec7a4ade --- /dev/null +++ b/comfy_api_nodes/apis/moonvalley.py @@ -0,0 +1,152 @@ +from enum import Enum +from typing import Optional, Dict, Any + +from pydantic import BaseModel, Field, StrictBytes + + +class MoonvalleyPromptResponse(BaseModel): + error: Optional[Dict[str, Any]] = None + frame_conditioning: Optional[Dict[str, Any]] = None + id: Optional[str] = None + inference_params: Optional[Dict[str, Any]] = None + meta: Optional[Dict[str, Any]] = None + model_params: Optional[Dict[str, Any]] = None + output_url: Optional[str] = None + prompt_text: Optional[str] = None + status: Optional[str] = None + + +class MoonvalleyTextToVideoInferenceParams(BaseModel): + add_quality_guidance: Optional[bool] = Field( + True, description='Whether to add quality guidance' + ) + caching_coefficient: Optional[float] = Field( + 0.3, description='Caching coefficient for optimization' + ) + caching_cooldown: Optional[int] = Field( + 3, description='Number of caching cooldown steps' + ) + caching_warmup: Optional[int] = Field( + 3, description='Number of caching warmup steps' + ) + clip_value: Optional[float] = Field( + 3, description='CLIP value for generation control' + ) + conditioning_frame_index: Optional[int] = Field( + 0, description='Index of the conditioning frame' + ) + cooldown_steps: Optional[int] = Field( + 75, description='Number of cooldown steps (calculated based on num_frames)' + ) + fps: Optional[int] = Field( + 24, description='Frames per second of the generated video' + ) + guidance_scale: Optional[float] = Field( + 10, description='Guidance scale for generation control' + ) + height: Optional[int] = Field( + 1080, description='Height of the generated video in pixels' + ) + negative_prompt: Optional[str] = Field(None, description='Negative prompt text') + num_frames: Optional[int] = Field(64, description='Number of frames to generate') + seed: Optional[int] = Field( + None, description='Random seed for generation (default: random)' + ) + shift_value: Optional[float] = Field( + 3, description='Shift value for generation control' + ) + steps: Optional[int] = Field(80, description='Number of denoising steps') + use_guidance_schedule: Optional[bool] = Field( + True, description='Whether to use guidance scheduling' + ) + use_negative_prompts: Optional[bool] = Field( + False, description='Whether to use negative prompts' + ) + use_timestep_transform: Optional[bool] = Field( + True, description='Whether to use timestep transformation' + ) + warmup_steps: Optional[int] = Field( + 0, description='Number of warmup steps (calculated based on num_frames)' + ) + width: Optional[int] = Field( + 1920, description='Width of the generated video in pixels' + ) + + +class MoonvalleyTextToVideoRequest(BaseModel): + image_url: Optional[str] = None + inference_params: Optional[MoonvalleyTextToVideoInferenceParams] = None + prompt_text: Optional[str] = None + webhook_url: Optional[str] = None + + +class MoonvalleyUploadFileRequest(BaseModel): + file: Optional[StrictBytes] = None + + +class MoonvalleyUploadFileResponse(BaseModel): + access_url: Optional[str] = None + + +class MoonvalleyVideoToVideoInferenceParams(BaseModel): + add_quality_guidance: Optional[bool] = Field( + True, description='Whether to add quality guidance' + ) + caching_coefficient: Optional[float] = Field( + 0.3, description='Caching coefficient for optimization' + ) + caching_cooldown: Optional[int] = Field( + 3, description='Number of caching cooldown steps' + ) + caching_warmup: Optional[int] = Field( + 3, description='Number of caching warmup steps' + ) + clip_value: Optional[float] = Field( + 3, description='CLIP value for generation control' + ) + conditioning_frame_index: Optional[int] = Field( + 0, description='Index of the conditioning frame' + ) + cooldown_steps: Optional[int] = Field( + 36, description='Number of cooldown steps (calculated based on num_frames)' + ) + guidance_scale: Optional[float] = Field( + 15, description='Guidance scale for generation control' + ) + negative_prompt: Optional[str] = Field(None, description='Negative prompt text') + seed: Optional[int] = Field( + None, description='Random seed for generation (default: random)' + ) + shift_value: Optional[float] = Field( + 3, description='Shift value for generation control' + ) + steps: Optional[int] = Field(80, description='Number of denoising steps') + use_guidance_schedule: Optional[bool] = Field( + True, description='Whether to use guidance scheduling' + ) + use_negative_prompts: Optional[bool] = Field( + False, description='Whether to use negative prompts' + ) + use_timestep_transform: Optional[bool] = Field( + True, description='Whether to use timestep transformation' + ) + warmup_steps: Optional[int] = Field( + 24, description='Number of warmup steps (calculated based on num_frames)' + ) + + +class ControlType(str, Enum): + motion_control = 'motion_control' + pose_control = 'pose_control' + + +class MoonvalleyVideoToVideoRequest(BaseModel): + control_type: ControlType = Field( + ..., description='Supported types for video control' + ) + inference_params: Optional[MoonvalleyVideoToVideoInferenceParams] = None + prompt_text: str = Field(..., description='Describes the video to generate') + video_url: str = Field(..., description='Url to control video') + webhook_url: Optional[str] = Field( + None, description='Optional webhook URL for notifications' + ) diff --git a/comfy_api_nodes/apis/openai.py b/comfy_api_nodes/apis/openai.py new file mode 100644 index 000000000..b85ef252b --- /dev/null +++ b/comfy_api_nodes/apis/openai.py @@ -0,0 +1,170 @@ +from pydantic import BaseModel, Field + + +class Datum2(BaseModel): + b64_json: str | None = Field(None, description="Base64 encoded image data") + revised_prompt: str | None = Field(None, description="Revised prompt") + url: str | None = Field(None, description="URL of the image") + + +class InputTokensDetails(BaseModel): + image_tokens: int | None = Field(None) + text_tokens: int | None = Field(None) + + +class Usage(BaseModel): + input_tokens: int | None = Field(None) + input_tokens_details: InputTokensDetails | None = Field(None) + output_tokens: int | None = Field(None) + total_tokens: int | None = Field(None) + + +class OpenAIImageGenerationResponse(BaseModel): + data: list[Datum2] | None = Field(None) + usage: Usage | None = Field(None) + + +class OpenAIImageEditRequest(BaseModel): + background: str | None = Field(None, description="Background transparency") + model: str = Field(...) + moderation: str | None = Field(None) + n: int | None = Field(None, description="The number of images to generate") + output_compression: int | None = Field(None, description="Compression level for JPEG or WebP (0-100)") + output_format: str | None = Field(None) + prompt: str = Field(...) + quality: str | None = Field(None, description="Size of the image (e.g., 1024x1024, 1536x1024, auto)") + size: str | None = Field(None, description="Size of the output image") + + +class OpenAIImageGenerationRequest(BaseModel): + background: str | None = Field(None, description="Background transparency") + model: str | None = Field(None) + moderation: str | None = Field(None) + n: int | None = Field( + None, + description="The number of images to generate.", + ) + output_compression: int | None = Field(None, description="Compression level for JPEG or WebP (0-100)") + output_format: str | None = Field(None) + prompt: str = Field(...) + quality: str | None = Field(None, description="The quality of the generated image") + size: str | None = Field(None, description="Size of the image (e.g., 1024x1024, 1536x1024, auto)") + style: str | None = Field(None, description="Style of the image (only for dall-e-3)") + + +class ModelResponseProperties(BaseModel): + instructions: str | None = Field(None) + max_output_tokens: int | None = Field(None) + model: str | None = Field(None) + temperature: float | None = Field(1, description="Controls randomness in the response", ge=0.0, le=2.0) + top_p: float | None = Field( + 1, + description="Controls diversity of the response via nucleus sampling", + ge=0.0, + le=1.0, + ) + truncation: str | None = Field("disabled", description="Allowed values: 'auto' or 'disabled'") + + +class ResponseProperties(BaseModel): + instructions: str | None = Field(None) + max_output_tokens: int | None = Field(None) + model: str | None = Field(None) + previous_response_id: str | None = Field(None) + truncation: str | None = Field("disabled", description="Allowed values: 'auto' or 'disabled'") + + +class ResponseError(BaseModel): + code: str = Field(...) + message: str = Field(...) + + +class OutputTokensDetails(BaseModel): + reasoning_tokens: int = Field(..., description="The number of reasoning tokens.") + + +class CachedTokensDetails(BaseModel): + cached_tokens: int = Field( + ..., + description="The number of tokens that were retrieved from the cache.", + ) + + +class ResponseUsage(BaseModel): + input_tokens: int = Field(..., description="The number of input tokens.") + input_tokens_details: CachedTokensDetails = Field(...) + output_tokens: int = Field(..., description="The number of output tokens.") + output_tokens_details: OutputTokensDetails = Field(...) + total_tokens: int = Field(..., description="The total number of tokens used.") + + +class InputTextContent(BaseModel): + text: str = Field(..., description="The text input to the model.") + type: str = Field("input_text") + + +class OutputContent(BaseModel): + type: str = Field(..., description="The type of output content") + text: str | None = Field(None, description="The text content") + data: str | None = Field(None, description="Base64-encoded audio data") + transcript: str | None = Field(None, description="Transcript of the audio") + + +class OutputMessage(BaseModel): + type: str = Field(..., description="The type of output item") + content: list[OutputContent] | None = Field(None, description="The content of the message") + role: str | None = Field(None, description="The role of the message") + + +class OpenAIResponse(ModelResponseProperties, ResponseProperties): + created_at: float | None = Field( + None, + description="Unix timestamp (in seconds) of when this Response was created.", + ) + error: ResponseError | None = Field(None) + id: str | None = Field(None, description="Unique identifier for this Response.") + object: str | None = Field(None, description="The object type of this resource - always set to `response`.") + output: list[OutputMessage] | None = Field(None) + parallel_tool_calls: bool | None = Field(True) + status: str | None = Field( + None, + description="One of `completed`, `failed`, `in_progress`, or `incomplete`.", + ) + usage: ResponseUsage | None = Field(None) + + +class InputImageContent(BaseModel): + detail: str = Field(..., description="One of `high`, `low`, or `auto`. Defaults to `auto`.") + file_id: str | None = Field(None) + image_url: str | None = Field(None) + type: str = Field(..., description="The type of the input item. Always `input_image`.") + + +class InputFileContent(BaseModel): + file_data: str | None = Field(None) + file_id: str | None = Field(None) + filename: str | None = Field(None, description="The name of the file to be sent to the model.") + type: str = Field(..., description="The type of the input item. Always `input_file`.") + + +class InputMessage(BaseModel): + content: list[InputTextContent | InputImageContent | InputFileContent] = Field( + ..., + description="A list of one or many input items to the model, containing different content types.", + ) + role: str | None = Field(None) + type: str | None = Field(None) + + +class OpenAICreateResponse(ModelResponseProperties, ResponseProperties): + include: str | None = Field(None) + input: list[InputMessage] = Field(...) + parallel_tool_calls: bool | None = Field( + True, description="Whether to allow the model to run tool calls in parallel." + ) + store: bool | None = Field( + True, + description="Whether to store the generated model response for later retrieval via API.", + ) + stream: bool | None = Field(False) + usage: ResponseUsage | None = Field(None) diff --git a/comfy_api_nodes/apis/openai_api.py b/comfy_api_nodes/apis/openai_api.py deleted file mode 100644 index ae5bb2673..000000000 --- a/comfy_api_nodes/apis/openai_api.py +++ /dev/null @@ -1,52 +0,0 @@ -from pydantic import BaseModel, Field - - -class Datum2(BaseModel): - b64_json: str | None = Field(None, description="Base64 encoded image data") - revised_prompt: str | None = Field(None, description="Revised prompt") - url: str | None = Field(None, description="URL of the image") - - -class InputTokensDetails(BaseModel): - image_tokens: int | None = None - text_tokens: int | None = None - - -class Usage(BaseModel): - input_tokens: int | None = None - input_tokens_details: InputTokensDetails | None = None - output_tokens: int | None = None - total_tokens: int | None = None - - -class OpenAIImageGenerationResponse(BaseModel): - data: list[Datum2] | None = None - usage: Usage | None = None - - -class OpenAIImageEditRequest(BaseModel): - background: str | None = Field(None, description="Background transparency") - model: str = Field(...) - moderation: str | None = Field(None) - n: int | None = Field(None, description="The number of images to generate") - output_compression: int | None = Field(None, description="Compression level for JPEG or WebP (0-100)") - output_format: str | None = Field(None) - prompt: str = Field(...) - quality: str | None = Field(None, description="Size of the image (e.g., 1024x1024, 1536x1024, auto)") - size: str | None = Field(None, description="Size of the output image") - - -class OpenAIImageGenerationRequest(BaseModel): - background: str | None = Field(None, description="Background transparency") - model: str | None = Field(None) - moderation: str | None = Field(None) - n: int | None = Field( - None, - description="The number of images to generate.", - ) - output_compression: int | None = Field(None, description="Compression level for JPEG or WebP (0-100)") - output_format: str | None = Field(None) - prompt: str = Field(...) - quality: str | None = Field(None, description="The quality of the generated image") - size: str | None = Field(None, description="Size of the image (e.g., 1024x1024, 1536x1024, auto)") - style: str | None = Field(None, description="Style of the image (only for dall-e-3)") diff --git a/comfy_api_nodes/apis/pixverse_api.py b/comfy_api_nodes/apis/pixverse.py similarity index 100% rename from comfy_api_nodes/apis/pixverse_api.py rename to comfy_api_nodes/apis/pixverse.py diff --git a/comfy_api_nodes/apis/recraft_api.py b/comfy_api_nodes/apis/recraft.py similarity index 100% rename from comfy_api_nodes/apis/recraft_api.py rename to comfy_api_nodes/apis/recraft.py diff --git a/comfy_api_nodes/apis/rodin_api.py b/comfy_api_nodes/apis/rodin.py similarity index 100% rename from comfy_api_nodes/apis/rodin_api.py rename to comfy_api_nodes/apis/rodin.py diff --git a/comfy_api_nodes/apis/runway.py b/comfy_api_nodes/apis/runway.py new file mode 100644 index 000000000..df6f2b845 --- /dev/null +++ b/comfy_api_nodes/apis/runway.py @@ -0,0 +1,127 @@ +from enum import Enum +from typing import Optional, List, Union +from datetime import datetime + +from pydantic import BaseModel, Field, RootModel + + +class RunwayAspectRatioEnum(str, Enum): + field_1280_720 = '1280:720' + field_720_1280 = '720:1280' + field_1104_832 = '1104:832' + field_832_1104 = '832:1104' + field_960_960 = '960:960' + field_1584_672 = '1584:672' + field_1280_768 = '1280:768' + field_768_1280 = '768:1280' + + +class Position(str, Enum): + first = 'first' + last = 'last' + + +class RunwayPromptImageDetailedObject(BaseModel): + position: Position = Field( + ..., + description="The position of the image in the output video. 'last' is currently supported for gen3a_turbo only.", + ) + uri: str = Field( + ..., description='A HTTPS URL or data URI containing an encoded image.' + ) + + +class RunwayPromptImageObject( + RootModel[Union[str, List[RunwayPromptImageDetailedObject]]] +): + root: Union[str, List[RunwayPromptImageDetailedObject]] = Field( + ..., + description='Image(s) to use for the video generation. Can be a single URI or an array of image objects with positions.', + ) + + +class RunwayModelEnum(str, Enum): + gen4_turbo = 'gen4_turbo' + gen3a_turbo = 'gen3a_turbo' + + +class RunwayDurationEnum(int, Enum): + integer_5 = 5 + integer_10 = 10 + + +class RunwayImageToVideoRequest(BaseModel): + duration: RunwayDurationEnum + model: RunwayModelEnum + promptImage: RunwayPromptImageObject + promptText: Optional[str] = Field( + None, description='Text prompt for the generation', max_length=1000 + ) + ratio: RunwayAspectRatioEnum + seed: int = Field( + ..., description='Random seed for generation', ge=0, le=4294967295 + ) + + +class RunwayImageToVideoResponse(BaseModel): + id: Optional[str] = Field(None, description='Task ID') + + +class RunwayTaskStatusEnum(str, Enum): + SUCCEEDED = 'SUCCEEDED' + RUNNING = 'RUNNING' + FAILED = 'FAILED' + PENDING = 'PENDING' + CANCELLED = 'CANCELLED' + THROTTLED = 'THROTTLED' + + +class RunwayTaskStatusResponse(BaseModel): + createdAt: datetime = Field(..., description='Task creation timestamp') + id: str = Field(..., description='Task ID') + output: Optional[List[str]] = Field(None, description='Array of output video URLs') + progress: Optional[float] = Field( + None, + description='Float value between 0 and 1 representing the progress of the task. Only available if status is RUNNING.', + ge=0.0, + le=1.0, + ) + status: RunwayTaskStatusEnum + + +class Model4(str, Enum): + gen4_image = 'gen4_image' + + +class ReferenceImage(BaseModel): + uri: Optional[str] = Field( + None, description='A HTTPS URL or data URI containing an encoded image' + ) + + +class RunwayTextToImageAspectRatioEnum(str, Enum): + field_1920_1080 = '1920:1080' + field_1080_1920 = '1080:1920' + field_1024_1024 = '1024:1024' + field_1360_768 = '1360:768' + field_1080_1080 = '1080:1080' + field_1168_880 = '1168:880' + field_1440_1080 = '1440:1080' + field_1080_1440 = '1080:1440' + field_1808_768 = '1808:768' + field_2112_912 = '2112:912' + + +class RunwayTextToImageRequest(BaseModel): + model: Model4 = Field(..., description='Model to use for generation') + promptText: str = Field( + ..., description='Text prompt for the image generation', max_length=1000 + ) + ratio: RunwayTextToImageAspectRatioEnum + referenceImages: Optional[List[ReferenceImage]] = Field( + None, description='Array of reference images to guide the generation' + ) + + +class RunwayTextToImageResponse(BaseModel): + id: Optional[str] = Field(None, description='Task ID') diff --git a/comfy_api_nodes/apis/stability_api.py b/comfy_api_nodes/apis/stability.py similarity index 100% rename from comfy_api_nodes/apis/stability_api.py rename to comfy_api_nodes/apis/stability.py diff --git a/comfy_api_nodes/apis/topaz_api.py b/comfy_api_nodes/apis/topaz.py similarity index 97% rename from comfy_api_nodes/apis/topaz_api.py rename to comfy_api_nodes/apis/topaz.py index 4d9e62e72..a9e6235a7 100644 --- a/comfy_api_nodes/apis/topaz_api.py +++ b/comfy_api_nodes/apis/topaz.py @@ -41,7 +41,7 @@ class Resolution(BaseModel): height: int = Field(...) -class CreateCreateVideoRequestSource(BaseModel): +class CreateVideoRequestSource(BaseModel): container: str = Field(...) size: int = Field(..., description="Size of the video file in bytes") duration: int = Field(..., description="Duration of the video file in seconds") @@ -89,7 +89,7 @@ class Overrides(BaseModel): class CreateVideoRequest(BaseModel): - source: CreateCreateVideoRequestSource = Field(...) + source: CreateVideoRequestSource = Field(...) filters: list[Union[VideoFrameInterpolationFilter, VideoEnhancementFilter]] = Field(...) output: OutputInformationVideo = Field(...) overrides: Overrides = Field(Overrides(isPaidDiffusion=True)) diff --git a/comfy_api_nodes/apis/tripo_api.py b/comfy_api_nodes/apis/tripo.py similarity index 100% rename from comfy_api_nodes/apis/tripo_api.py rename to comfy_api_nodes/apis/tripo.py diff --git a/comfy_api_nodes/apis/veo_api.py b/comfy_api_nodes/apis/veo.py similarity index 100% rename from comfy_api_nodes/apis/veo_api.py rename to comfy_api_nodes/apis/veo.py diff --git a/comfy_api_nodes/mapper_utils.py b/comfy_api_nodes/mapper_utils.py deleted file mode 100644 index 6fab8f4bb..000000000 --- a/comfy_api_nodes/mapper_utils.py +++ /dev/null @@ -1,116 +0,0 @@ -from enum import Enum - -from pydantic.fields import FieldInfo -from pydantic import BaseModel -from pydantic_core import PydanticUndefined - -from comfy.comfy_types.node_typing import IO, InputTypeOptions - -NodeInput = tuple[IO, InputTypeOptions] - - -def _create_base_config(field_info: FieldInfo) -> InputTypeOptions: - config = {} - if hasattr(field_info, "default") and field_info.default is not PydanticUndefined: - config["default"] = field_info.default - if hasattr(field_info, "description") and field_info.description is not None: - config["tooltip"] = field_info.description - return config - - -def _get_number_constraints_config(field_info: FieldInfo) -> dict: - config = {} - if hasattr(field_info, "metadata"): - metadata = field_info.metadata - for constraint in metadata: - if hasattr(constraint, "ge"): - config["min"] = constraint.ge - if hasattr(constraint, "le"): - config["max"] = constraint.le - if hasattr(constraint, "multiple_of"): - config["step"] = constraint.multiple_of - return config - - -def _model_field_to_image_input(field_info: FieldInfo, **kwargs) -> NodeInput: - return IO.IMAGE, { - **_create_base_config(field_info), - **kwargs, - } - - -def _model_field_to_string_input(field_info: FieldInfo, **kwargs) -> NodeInput: - return IO.STRING, { - **_create_base_config(field_info), - **kwargs, - } - - -def _model_field_to_float_input(field_info: FieldInfo, **kwargs) -> NodeInput: - return IO.FLOAT, { - **_create_base_config(field_info), - **_get_number_constraints_config(field_info), - **kwargs, - } - - -def _model_field_to_int_input(field_info: FieldInfo, **kwargs) -> NodeInput: - return IO.INT, { - **_create_base_config(field_info), - **_get_number_constraints_config(field_info), - **kwargs, - } - - -def _model_field_to_combo_input( - field_info: FieldInfo, enum_type: type[Enum] = None, **kwargs -) -> NodeInput: - combo_config = {} - if enum_type is not None: - combo_config["options"] = [option.value for option in enum_type] - combo_config = { - **combo_config, - **_create_base_config(field_info), - **kwargs, - } - return IO.COMBO, combo_config - - -def model_field_to_node_input( - input_type: IO, base_model: type[BaseModel], field_name: str, **kwargs -) -> NodeInput: - """ - Maps a field from a Pydantic model to a Comfy node input. - - Args: - input_type: The type of the input. - base_model: The Pydantic model to map the field from. - field_name: The name of the field to map. - **kwargs: Additional key/values to include in the input options. - - Note: - For combo inputs, pass an `Enum` to the `enum_type` keyword argument to populate the options automatically. - - Example: - >>> model_field_to_node_input(IO.STRING, MyModel, "my_field", multiline=True) - >>> model_field_to_node_input(IO.COMBO, MyModel, "my_field", enum_type=MyEnum) - >>> model_field_to_node_input(IO.FLOAT, MyModel, "my_field", slider=True) - """ - field_info: FieldInfo = base_model.model_fields[field_name] - result: NodeInput - - if input_type == IO.IMAGE: - result = _model_field_to_image_input(field_info, **kwargs) - elif input_type == IO.STRING: - result = _model_field_to_string_input(field_info, **kwargs) - elif input_type == IO.FLOAT: - result = _model_field_to_float_input(field_info, **kwargs) - elif input_type == IO.INT: - result = _model_field_to_int_input(field_info, **kwargs) - elif input_type == IO.COMBO: - result = _model_field_to_combo_input(field_info, **kwargs) - else: - message = f"Invalid input type: {input_type}" - raise ValueError(message) - - return result diff --git a/comfy_api_nodes/nodes_bfl.py b/comfy_api_nodes/nodes_bfl.py index 76021ef7f..61c3b4503 100644 --- a/comfy_api_nodes/nodes_bfl.py +++ b/comfy_api_nodes/nodes_bfl.py @@ -3,7 +3,7 @@ from pydantic import BaseModel from typing_extensions import override from comfy_api.latest import IO, ComfyExtension, Input -from comfy_api_nodes.apis.bfl_api import ( +from comfy_api_nodes.apis.bfl import ( BFLFluxExpandImageRequest, BFLFluxFillImageRequest, BFLFluxKontextProGenerateRequest, diff --git a/comfy_api_nodes/nodes_bytedance.py b/comfy_api_nodes/nodes_bytedance.py index 9cb1ca004..486801150 100644 --- a/comfy_api_nodes/nodes_bytedance.py +++ b/comfy_api_nodes/nodes_bytedance.py @@ -5,7 +5,7 @@ import torch from typing_extensions import override from comfy_api.latest import IO, ComfyExtension, Input -from comfy_api_nodes.apis.bytedance_api import ( +from comfy_api_nodes.apis.bytedance import ( RECOMMENDED_PRESETS, RECOMMENDED_PRESETS_SEEDREAM_4, VIDEO_TASKS_EXECUTION_TIME, diff --git a/comfy_api_nodes/nodes_gemini.py b/comfy_api_nodes/nodes_gemini.py index a2daea50a..3b31caa7b 100644 --- a/comfy_api_nodes/nodes_gemini.py +++ b/comfy_api_nodes/nodes_gemini.py @@ -14,7 +14,7 @@ from typing_extensions import override import folder_paths from comfy_api.latest import IO, ComfyExtension, Input, Types -from comfy_api_nodes.apis.gemini_api import ( +from comfy_api_nodes.apis.gemini import ( GeminiContent, GeminiFileData, GeminiGenerateContentRequest, diff --git a/comfy_api_nodes/nodes_ideogram.py b/comfy_api_nodes/nodes_ideogram.py index 827b3523a..feaf7a858 100644 --- a/comfy_api_nodes/nodes_ideogram.py +++ b/comfy_api_nodes/nodes_ideogram.py @@ -4,7 +4,7 @@ from comfy_api.latest import IO, ComfyExtension from PIL import Image import numpy as np import torch -from comfy_api_nodes.apis import ( +from comfy_api_nodes.apis.ideogram import ( IdeogramGenerateRequest, IdeogramGenerateResponse, ImageRequest, diff --git a/comfy_api_nodes/nodes_kling.py b/comfy_api_nodes/nodes_kling.py index 05dde88b1..3ec71530b 100644 --- a/comfy_api_nodes/nodes_kling.py +++ b/comfy_api_nodes/nodes_kling.py @@ -49,7 +49,7 @@ from comfy_api_nodes.apis import ( KlingCharacterEffectModelName, KlingSingleImageEffectModelName, ) -from comfy_api_nodes.apis.kling_api import ( +from comfy_api_nodes.apis.kling import ( ImageToVideoWithAudioRequest, MotionControlRequest, OmniImageParamImage, diff --git a/comfy_api_nodes/nodes_luma.py b/comfy_api_nodes/nodes_luma.py index 95cb442e5..9ed6cd299 100644 --- a/comfy_api_nodes/nodes_luma.py +++ b/comfy_api_nodes/nodes_luma.py @@ -4,7 +4,7 @@ import torch from typing_extensions import override from comfy_api.latest import IO, ComfyExtension -from comfy_api_nodes.apis.luma_api import ( +from comfy_api_nodes.apis.luma import ( LumaAspectRatio, LumaCharacterRef, LumaConceptChain, diff --git a/comfy_api_nodes/nodes_minimax.py b/comfy_api_nodes/nodes_minimax.py index 43a15d50d..b5d0b461f 100644 --- a/comfy_api_nodes/nodes_minimax.py +++ b/comfy_api_nodes/nodes_minimax.py @@ -4,7 +4,7 @@ import torch from typing_extensions import override from comfy_api.latest import IO, ComfyExtension -from comfy_api_nodes.apis.minimax_api import ( +from comfy_api_nodes.apis.minimax import ( MinimaxFileRetrieveResponse, MiniMaxModel, MinimaxTaskResultResponse, diff --git a/comfy_api_nodes/nodes_moonvalley.py b/comfy_api_nodes/nodes_moonvalley.py index 769b171b7..08315fa2b 100644 --- a/comfy_api_nodes/nodes_moonvalley.py +++ b/comfy_api_nodes/nodes_moonvalley.py @@ -3,7 +3,7 @@ import logging from typing_extensions import override from comfy_api.latest import IO, ComfyExtension, Input -from comfy_api_nodes.apis import ( +from comfy_api_nodes.apis.moonvalley import ( MoonvalleyPromptResponse, MoonvalleyTextToVideoInferenceParams, MoonvalleyTextToVideoRequest, diff --git a/comfy_api_nodes/nodes_openai.py b/comfy_api_nodes/nodes_openai.py index 2f144c5c3..a12acc06b 100644 --- a/comfy_api_nodes/nodes_openai.py +++ b/comfy_api_nodes/nodes_openai.py @@ -10,24 +10,18 @@ from typing_extensions import override import folder_paths from comfy_api.latest import IO, ComfyExtension, Input -from comfy_api_nodes.apis import ( - CreateModelResponseProperties, - Detail, - InputContent, +from comfy_api_nodes.apis.openai import ( InputFileContent, InputImageContent, InputMessage, - InputMessageContentList, InputTextContent, - Item, + ModelResponseProperties, OpenAICreateResponse, - OpenAIResponse, - OutputContent, -) -from comfy_api_nodes.apis.openai_api import ( OpenAIImageEditRequest, OpenAIImageGenerationRequest, OpenAIImageGenerationResponse, + OpenAIResponse, + OutputContent, ) from comfy_api_nodes.util import ( ApiEndpoint, @@ -266,7 +260,7 @@ class OpenAIDalle3(IO.ComfyNode): "seed", default=0, min=0, - max=2 ** 31 - 1, + max=2**31 - 1, step=1, display_mode=IO.NumberDisplay.number, control_after_generate=True, @@ -384,7 +378,7 @@ class OpenAIGPTImage1(IO.ComfyNode): "seed", default=0, min=0, - max=2 ** 31 - 1, + max=2**31 - 1, step=1, display_mode=IO.NumberDisplay.number, control_after_generate=True, @@ -500,8 +494,8 @@ class OpenAIGPTImage1(IO.ComfyNode): files = [] batch_size = image.shape[0] for i in range(batch_size): - single_image = image[i: i + 1] - scaled_image = downscale_image_tensor(single_image, total_pixels=2048*2048).squeeze() + single_image = image[i : i + 1] + scaled_image = downscale_image_tensor(single_image, total_pixels=2048 * 2048).squeeze() image_np = (scaled_image.numpy() * 255).astype(np.uint8) img = Image.fromarray(image_np) @@ -523,7 +517,7 @@ class OpenAIGPTImage1(IO.ComfyNode): rgba_mask = torch.zeros(height, width, 4, device="cpu") rgba_mask[:, :, 3] = 1 - mask.squeeze().cpu() - scaled_mask = downscale_image_tensor(rgba_mask.unsqueeze(0), total_pixels=2048*2048).squeeze() + scaled_mask = downscale_image_tensor(rgba_mask.unsqueeze(0), total_pixels=2048 * 2048).squeeze() mask_np = (scaled_mask.numpy() * 255).astype(np.uint8) mask_img = Image.fromarray(mask_np) @@ -696,29 +690,23 @@ class OpenAIChatNode(IO.ComfyNode): ) @classmethod - def get_message_content_from_response( - cls, response: OpenAIResponse - ) -> list[OutputContent]: + def get_message_content_from_response(cls, response: OpenAIResponse) -> list[OutputContent]: """Extract message content from the API response.""" for output in response.output: - if output.root.type == "message": - return output.root.content + if output.type == "message": + return output.content raise TypeError("No output message found in response") @classmethod - def get_text_from_message_content( - cls, message_content: list[OutputContent] - ) -> str: + def get_text_from_message_content(cls, message_content: list[OutputContent]) -> str: """Extract text content from message content.""" for content_item in message_content: - if content_item.root.type == "output_text": - return str(content_item.root.text) + if content_item.type == "output_text": + return str(content_item.text) return "No text output found in response" @classmethod - def tensor_to_input_image_content( - cls, image: torch.Tensor, detail_level: Detail = "auto" - ) -> InputImageContent: + def tensor_to_input_image_content(cls, image: torch.Tensor, detail_level: str = "auto") -> InputImageContent: """Convert a tensor to an input image content object.""" return InputImageContent( detail=detail_level, @@ -732,9 +720,9 @@ class OpenAIChatNode(IO.ComfyNode): prompt: str, image: torch.Tensor | None = None, files: list[InputFileContent] | None = None, - ) -> InputMessageContentList: + ) -> list[InputTextContent | InputImageContent | InputFileContent]: """Create a list of input message contents from prompt and optional image.""" - content_list: list[InputContent | InputTextContent | InputImageContent | InputFileContent] = [ + content_list: list[InputTextContent | InputImageContent | InputFileContent] = [ InputTextContent(text=prompt, type="input_text"), ] if image is not None: @@ -746,13 +734,9 @@ class OpenAIChatNode(IO.ComfyNode): type="input_image", ) ) - if files is not None: content_list.extend(files) - - return InputMessageContentList( - root=content_list, - ) + return content_list @classmethod async def execute( @@ -762,7 +746,7 @@ class OpenAIChatNode(IO.ComfyNode): model: SupportedOpenAIModel = SupportedOpenAIModel.gpt_5.value, images: torch.Tensor | None = None, files: list[InputFileContent] | None = None, - advanced_options: CreateModelResponseProperties | None = None, + advanced_options: ModelResponseProperties | None = None, ) -> IO.NodeOutput: validate_string(prompt, strip_whitespace=False) @@ -773,36 +757,28 @@ class OpenAIChatNode(IO.ComfyNode): response_model=OpenAIResponse, data=OpenAICreateResponse( input=[ - Item( - root=InputMessage( - content=cls.create_input_message_contents( - prompt, images, files - ), - role="user", - ) + InputMessage( + content=cls.create_input_message_contents(prompt, images, files), + role="user", ), ], store=True, stream=False, model=model, previous_response_id=None, - **( - advanced_options.model_dump(exclude_none=True) - if advanced_options - else {} - ), + **(advanced_options.model_dump(exclude_none=True) if advanced_options else {}), ), ) response_id = create_response.id # Get result output result_response = await poll_op( - cls, - ApiEndpoint(path=f"{RESPONSES_ENDPOINT}/{response_id}"), - response_model=OpenAIResponse, - status_extractor=lambda response: response.status, - completed_statuses=["incomplete", "completed"] - ) + cls, + ApiEndpoint(path=f"{RESPONSES_ENDPOINT}/{response_id}"), + response_model=OpenAIResponse, + status_extractor=lambda response: response.status, + completed_statuses=["incomplete", "completed"], + ) return IO.NodeOutput(cls.get_text_from_message_content(cls.get_message_content_from_response(result_response))) @@ -923,7 +899,7 @@ class OpenAIChatConfig(IO.ComfyNode): remove depending on model choice. """ return IO.NodeOutput( - CreateModelResponseProperties( + ModelResponseProperties( instructions=instructions, truncation=truncation, max_output_tokens=max_output_tokens, diff --git a/comfy_api_nodes/nodes_pixverse.py b/comfy_api_nodes/nodes_pixverse.py index 86ddb3ab9..e17a24ae7 100644 --- a/comfy_api_nodes/nodes_pixverse.py +++ b/comfy_api_nodes/nodes_pixverse.py @@ -1,7 +1,7 @@ import torch from typing_extensions import override from comfy_api.latest import IO, ComfyExtension -from comfy_api_nodes.apis.pixverse_api import ( +from comfy_api_nodes.apis.pixverse import ( PixverseTextVideoRequest, PixverseImageVideoRequest, PixverseTransitionVideoRequest, diff --git a/comfy_api_nodes/nodes_recraft.py b/comfy_api_nodes/nodes_recraft.py index 05dc151ad..c01bcaece 100644 --- a/comfy_api_nodes/nodes_recraft.py +++ b/comfy_api_nodes/nodes_recraft.py @@ -8,7 +8,7 @@ from typing_extensions import override from comfy.utils import ProgressBar from comfy_api.latest import IO, ComfyExtension -from comfy_api_nodes.apis.recraft_api import ( +from comfy_api_nodes.apis.recraft import ( RecraftColor, RecraftColorChain, RecraftControls, diff --git a/comfy_api_nodes/nodes_rodin.py b/comfy_api_nodes/nodes_rodin.py index b4420cb93..3ffdc8b90 100644 --- a/comfy_api_nodes/nodes_rodin.py +++ b/comfy_api_nodes/nodes_rodin.py @@ -14,7 +14,7 @@ from typing import Optional from io import BytesIO from typing_extensions import override from PIL import Image -from comfy_api_nodes.apis.rodin_api import ( +from comfy_api_nodes.apis.rodin import ( Rodin3DGenerateRequest, Rodin3DGenerateResponse, Rodin3DCheckStatusRequest, diff --git a/comfy_api_nodes/nodes_runway.py b/comfy_api_nodes/nodes_runway.py index d19fdb365..573170ba2 100644 --- a/comfy_api_nodes/nodes_runway.py +++ b/comfy_api_nodes/nodes_runway.py @@ -16,7 +16,7 @@ from enum import Enum from typing_extensions import override from comfy_api.latest import IO, ComfyExtension, Input, InputImpl -from comfy_api_nodes.apis import ( +from comfy_api_nodes.apis.runway import ( RunwayImageToVideoRequest, RunwayImageToVideoResponse, RunwayTaskStatusResponse as TaskStatusResponse, diff --git a/comfy_api_nodes/nodes_stability.py b/comfy_api_nodes/nodes_stability.py index 5c48c1f1e..5665109cf 100644 --- a/comfy_api_nodes/nodes_stability.py +++ b/comfy_api_nodes/nodes_stability.py @@ -3,7 +3,7 @@ from typing import Optional from typing_extensions import override from comfy_api.latest import ComfyExtension, Input, IO -from comfy_api_nodes.apis.stability_api import ( +from comfy_api_nodes.apis.stability import ( StabilityUpscaleConservativeRequest, StabilityUpscaleCreativeRequest, StabilityAsyncResponse, diff --git a/comfy_api_nodes/nodes_topaz.py b/comfy_api_nodes/nodes_topaz.py index 9dc5f45bc..c052e7656 100644 --- a/comfy_api_nodes/nodes_topaz.py +++ b/comfy_api_nodes/nodes_topaz.py @@ -5,7 +5,24 @@ import aiohttp from typing_extensions import override from comfy_api.latest import IO, ComfyExtension, Input -from comfy_api_nodes.apis import topaz_api +from comfy_api_nodes.apis.topaz import ( + CreateVideoRequest, + CreateVideoRequestSource, + CreateVideoResponse, + ImageAsyncTaskResponse, + ImageDownloadResponse, + ImageEnhanceRequest, + ImageStatusResponse, + OutputInformationVideo, + Resolution, + VideoAcceptResponse, + VideoCompleteUploadRequest, + VideoCompleteUploadRequestPart, + VideoCompleteUploadResponse, + VideoEnhancementFilter, + VideoFrameInterpolationFilter, + VideoStatusResponse, +) from comfy_api_nodes.util import ( ApiEndpoint, download_url_to_image_tensor, @@ -153,13 +170,13 @@ class TopazImageEnhance(IO.ComfyNode): if get_number_of_images(image) != 1: raise ValueError("Only one input image is supported.") download_url = await upload_images_to_comfyapi( - cls, image, max_images=1, mime_type="image/png", total_pixels=4096*4096 + cls, image, max_images=1, mime_type="image/png", total_pixels=4096 * 4096 ) initial_response = await sync_op( cls, ApiEndpoint(path="/proxy/topaz/image/v1/enhance-gen/async", method="POST"), - response_model=topaz_api.ImageAsyncTaskResponse, - data=topaz_api.ImageEnhanceRequest( + response_model=ImageAsyncTaskResponse, + data=ImageEnhanceRequest( model=model, prompt=prompt, subject_detection=subject_detection, @@ -181,7 +198,7 @@ class TopazImageEnhance(IO.ComfyNode): await poll_op( cls, poll_endpoint=ApiEndpoint(path=f"/proxy/topaz/image/v1/status/{initial_response.process_id}"), - response_model=topaz_api.ImageStatusResponse, + response_model=ImageStatusResponse, status_extractor=lambda x: x.status, progress_extractor=lambda x: getattr(x, "progress", 0), price_extractor=lambda x: x.credits * 0.08, @@ -193,7 +210,7 @@ class TopazImageEnhance(IO.ComfyNode): results = await sync_op( cls, ApiEndpoint(path=f"/proxy/topaz/image/v1/download/{initial_response.process_id}"), - response_model=topaz_api.ImageDownloadResponse, + response_model=ImageDownloadResponse, monitor_progress=False, ) return IO.NodeOutput(await download_url_to_image_tensor(results.download_url)) @@ -331,7 +348,7 @@ class TopazVideoEnhance(IO.ComfyNode): if target_height % 2 != 0: target_height += 1 filters.append( - topaz_api.VideoEnhancementFilter( + VideoEnhancementFilter( model=UPSCALER_MODELS_MAP[upscaler_model], creativity=(upscaler_creativity if UPSCALER_MODELS_MAP[upscaler_model] == "slc-1" else None), isOptimizedMode=(True if UPSCALER_MODELS_MAP[upscaler_model] == "slc-1" else None), @@ -340,7 +357,7 @@ class TopazVideoEnhance(IO.ComfyNode): if interpolation_enabled: target_frame_rate = interpolation_frame_rate filters.append( - topaz_api.VideoFrameInterpolationFilter( + VideoFrameInterpolationFilter( model=interpolation_model, slowmo=interpolation_slowmo, fps=interpolation_frame_rate, @@ -351,19 +368,19 @@ class TopazVideoEnhance(IO.ComfyNode): initial_res = await sync_op( cls, ApiEndpoint(path="/proxy/topaz/video/", method="POST"), - response_model=topaz_api.CreateVideoResponse, - data=topaz_api.CreateVideoRequest( - source=topaz_api.CreateCreateVideoRequestSource( + response_model=CreateVideoResponse, + data=CreateVideoRequest( + source=CreateVideoRequestSource( container="mp4", size=get_fs_object_size(src_video_stream), duration=int(duration_sec), frameCount=video.get_frame_count(), frameRate=src_frame_rate, - resolution=topaz_api.Resolution(width=src_width, height=src_height), + resolution=Resolution(width=src_width, height=src_height), ), filters=filters, - output=topaz_api.OutputInformationVideo( - resolution=topaz_api.Resolution(width=target_width, height=target_height), + output=OutputInformationVideo( + resolution=Resolution(width=target_width, height=target_height), frameRate=target_frame_rate, audioCodec="AAC", audioTransfer="Copy", @@ -379,7 +396,7 @@ class TopazVideoEnhance(IO.ComfyNode): path=f"/proxy/topaz/video/{initial_res.requestId}/accept", method="PATCH", ), - response_model=topaz_api.VideoAcceptResponse, + response_model=VideoAcceptResponse, wait_label="Preparing upload", final_label_on_success="Upload started", ) @@ -402,10 +419,10 @@ class TopazVideoEnhance(IO.ComfyNode): path=f"/proxy/topaz/video/{initial_res.requestId}/complete-upload", method="PATCH", ), - response_model=topaz_api.VideoCompleteUploadResponse, - data=topaz_api.VideoCompleteUploadRequest( + response_model=VideoCompleteUploadResponse, + data=VideoCompleteUploadRequest( uploadResults=[ - topaz_api.VideoCompleteUploadRequestPart( + VideoCompleteUploadRequestPart( partNum=1, eTag=upload_etag, ), @@ -417,7 +434,7 @@ class TopazVideoEnhance(IO.ComfyNode): final_response = await poll_op( cls, ApiEndpoint(path=f"/proxy/topaz/video/{initial_res.requestId}/status"), - response_model=topaz_api.VideoStatusResponse, + response_model=VideoStatusResponse, status_extractor=lambda x: x.status, progress_extractor=lambda x: getattr(x, "progress", 0), price_extractor=lambda x: (x.estimates.cost[0] * 0.08 if x.estimates and x.estimates.cost[0] else None), diff --git a/comfy_api_nodes/nodes_tripo.py b/comfy_api_nodes/nodes_tripo.py index aa790143d..5abf27b4d 100644 --- a/comfy_api_nodes/nodes_tripo.py +++ b/comfy_api_nodes/nodes_tripo.py @@ -5,7 +5,7 @@ import torch from typing_extensions import override from comfy_api.latest import IO, ComfyExtension -from comfy_api_nodes.apis.tripo_api import ( +from comfy_api_nodes.apis.tripo import ( TripoAnimateRetargetRequest, TripoAnimateRigRequest, TripoConvertModelRequest, diff --git a/comfy_api_nodes/nodes_veo2.py b/comfy_api_nodes/nodes_veo2.py index c14d6ad68..2a202fc3b 100644 --- a/comfy_api_nodes/nodes_veo2.py +++ b/comfy_api_nodes/nodes_veo2.py @@ -4,7 +4,7 @@ from io import BytesIO from typing_extensions import override from comfy_api.latest import IO, ComfyExtension, Input, InputImpl -from comfy_api_nodes.apis.veo_api import ( +from comfy_api_nodes.apis.veo import ( VeoGenVidPollRequest, VeoGenVidPollResponse, VeoGenVidRequest, diff --git a/comfy_api_nodes/redocly-dev.yaml b/comfy_api_nodes/redocly-dev.yaml deleted file mode 100644 index d9e3cab70..000000000 --- a/comfy_api_nodes/redocly-dev.yaml +++ /dev/null @@ -1,10 +0,0 @@ -# This file is used to filter the Comfy Org OpenAPI spec for schemas related to API Nodes. -# This is used for development purposes to generate stubs for unreleased API endpoints. -apis: - filter: - root: openapi.yaml - decorators: - filter-in: - property: tags - value: ['API Nodes'] - matchStrategy: all diff --git a/comfy_api_nodes/redocly.yaml b/comfy_api_nodes/redocly.yaml deleted file mode 100644 index d102345b1..000000000 --- a/comfy_api_nodes/redocly.yaml +++ /dev/null @@ -1,10 +0,0 @@ -# This file is used to filter the Comfy Org OpenAPI spec for schemas related to API Nodes. - -apis: - filter: - root: openapi.yaml - decorators: - filter-in: - property: tags - value: ['API Nodes', 'Released'] - matchStrategy: all diff --git a/tests-unit/comfy_api_nodes_test/mapper_utils_test.py b/tests-unit/comfy_api_nodes_test/mapper_utils_test.py deleted file mode 100644 index 69488f691..000000000 --- a/tests-unit/comfy_api_nodes_test/mapper_utils_test.py +++ /dev/null @@ -1,297 +0,0 @@ -from typing import Optional -from enum import Enum - -from pydantic import BaseModel, Field - -from comfy.comfy_types.node_typing import IO -from comfy_api_nodes.mapper_utils import model_field_to_node_input - - -def test_model_field_to_float_input(): - """Tests mapping a float field with constraints.""" - - class ModelWithFloatField(BaseModel): - cfg_scale: Optional[float] = Field( - default=0.5, - description="Flexibility in video generation", - ge=0.0, - le=1.0, - multiple_of=0.001, - ) - - expected_output = ( - IO.FLOAT, - { - "default": 0.5, - "tooltip": "Flexibility in video generation", - "min": 0.0, - "max": 1.0, - "step": 0.001, - }, - ) - - actual_output = model_field_to_node_input( - IO.FLOAT, ModelWithFloatField, "cfg_scale" - ) - - assert actual_output[0] == expected_output[0] - assert actual_output[1] == expected_output[1] - - -def test_model_field_to_float_input_no_constraints(): - """Tests mapping a float field with no constraints.""" - - class ModelWithFloatField(BaseModel): - cfg_scale: Optional[float] = Field(default=0.5) - - expected_output = ( - IO.FLOAT, - { - "default": 0.5, - }, - ) - - actual_output = model_field_to_node_input( - IO.FLOAT, ModelWithFloatField, "cfg_scale" - ) - - assert actual_output[0] == expected_output[0] - assert actual_output[1] == expected_output[1] - - -def test_model_field_to_int_input(): - """Tests mapping an int field with constraints.""" - - class ModelWithIntField(BaseModel): - num_frames: Optional[int] = Field( - default=10, - description="Number of frames to generate", - ge=1, - le=100, - multiple_of=1, - ) - - expected_output = ( - IO.INT, - { - "default": 10, - "tooltip": "Number of frames to generate", - "min": 1, - "max": 100, - "step": 1, - }, - ) - - actual_output = model_field_to_node_input(IO.INT, ModelWithIntField, "num_frames") - - assert actual_output[0] == expected_output[0] - assert actual_output[1] == expected_output[1] - - -def test_model_field_to_string_input(): - """Tests mapping a string field.""" - - class ModelWithStringField(BaseModel): - prompt: Optional[str] = Field( - default="A beautiful sunset over a calm ocean", - description="A prompt for the video generation", - ) - - expected_output = ( - IO.STRING, - { - "default": "A beautiful sunset over a calm ocean", - "tooltip": "A prompt for the video generation", - }, - ) - - actual_output = model_field_to_node_input(IO.STRING, ModelWithStringField, "prompt") - - assert actual_output[0] == expected_output[0] - assert actual_output[1] == expected_output[1] - - -def test_model_field_to_string_input_multiline(): - """Tests mapping a string field.""" - - class ModelWithStringField(BaseModel): - prompt: Optional[str] = Field( - default="A beautiful sunset over a calm ocean", - description="A prompt for the video generation", - ) - - expected_output = ( - IO.STRING, - { - "default": "A beautiful sunset over a calm ocean", - "tooltip": "A prompt for the video generation", - "multiline": True, - }, - ) - - actual_output = model_field_to_node_input( - IO.STRING, ModelWithStringField, "prompt", multiline=True - ) - - assert actual_output[0] == expected_output[0] - assert actual_output[1] == expected_output[1] - - -def test_model_field_to_combo_input(): - """Tests mapping a combo field.""" - - class MockEnum(str, Enum): - option_1 = "option 1" - option_2 = "option 2" - option_3 = "option 3" - - class ModelWithComboField(BaseModel): - model_name: Optional[MockEnum] = Field("option 1", description="Model Name") - - expected_output = ( - IO.COMBO, - { - "options": ["option 1", "option 2", "option 3"], - "default": "option 1", - "tooltip": "Model Name", - }, - ) - - actual_output = model_field_to_node_input( - IO.COMBO, ModelWithComboField, "model_name", enum_type=MockEnum - ) - - assert actual_output[0] == expected_output[0] - assert actual_output[1] == expected_output[1] - - -def test_model_field_to_combo_input_no_options(): - """Tests mapping a combo field with no options.""" - - class ModelWithComboField(BaseModel): - model_name: Optional[str] = Field(description="Model Name") - - expected_output = ( - IO.COMBO, - { - "tooltip": "Model Name", - }, - ) - - actual_output = model_field_to_node_input( - IO.COMBO, ModelWithComboField, "model_name" - ) - - assert actual_output[0] == expected_output[0] - assert actual_output[1] == expected_output[1] - - -def test_model_field_to_image_input(): - """Tests mapping an image field.""" - - class ModelWithImageField(BaseModel): - image: Optional[str] = Field( - default=None, - description="An image for the video generation", - ) - - expected_output = ( - IO.IMAGE, - { - "default": None, - "tooltip": "An image for the video generation", - }, - ) - - actual_output = model_field_to_node_input(IO.IMAGE, ModelWithImageField, "image") - - assert actual_output[0] == expected_output[0] - assert actual_output[1] == expected_output[1] - - -def test_model_field_to_node_input_no_description(): - """Tests mapping a field with no description.""" - - class ModelWithNoDescriptionField(BaseModel): - field: Optional[str] = Field(default="default value") - - expected_output = ( - IO.STRING, - { - "default": "default value", - }, - ) - - actual_output = model_field_to_node_input( - IO.STRING, ModelWithNoDescriptionField, "field" - ) - - assert actual_output[0] == expected_output[0] - assert actual_output[1] == expected_output[1] - - -def test_model_field_to_node_input_no_default(): - """Tests mapping a field with no default.""" - - class ModelWithNoDefaultField(BaseModel): - field: Optional[str] = Field(description="A field with no default") - - expected_output = ( - IO.STRING, - { - "tooltip": "A field with no default", - }, - ) - - actual_output = model_field_to_node_input( - IO.STRING, ModelWithNoDefaultField, "field" - ) - - assert actual_output[0] == expected_output[0] - assert actual_output[1] == expected_output[1] - - -def test_model_field_to_node_input_no_metadata(): - """Tests mapping a field with no metadata or properties defined on the schema.""" - - class ModelWithNoMetadataField(BaseModel): - field: Optional[str] = Field() - - expected_output = ( - IO.STRING, - {}, - ) - - actual_output = model_field_to_node_input( - IO.STRING, ModelWithNoMetadataField, "field" - ) - - assert actual_output[0] == expected_output[0] - assert actual_output[1] == expected_output[1] - - -def test_model_field_to_node_input_default_is_none(): - """ - Tests mapping a field with a default of `None`. - I.e., the default field should be included as the schema explicitly sets it to `None`. - """ - - class ModelWithNoneDefaultField(BaseModel): - field: Optional[str] = Field( - default=None, description="A field with a default of None" - ) - - expected_output = ( - IO.STRING, - { - "default": None, - "tooltip": "A field with a default of None", - }, - ) - - actual_output = model_field_to_node_input( - IO.STRING, ModelWithNoneDefaultField, "field" - ) - - assert actual_output[0] == expected_output[0] - assert actual_output[1] == expected_output[1] From f7ca41ff6226eecbf6c9ee475c1de714cb8f04e9 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Sun, 18 Jan 2026 04:57:57 +0200 Subject: [PATCH 092/119] chore(api-nodes): remove check for pyav>=14.2 in code (it was added to requirements.txt long ago) (#11934) --- comfy_api_nodes/canary.py | 10 ---------- nodes.py | 3 --- 2 files changed, 13 deletions(-) delete mode 100644 comfy_api_nodes/canary.py diff --git a/comfy_api_nodes/canary.py b/comfy_api_nodes/canary.py deleted file mode 100644 index 4df7590b6..000000000 --- a/comfy_api_nodes/canary.py +++ /dev/null @@ -1,10 +0,0 @@ -import av - -ver = av.__version__.split(".") -if int(ver[0]) < 14: - raise Exception("INSTALL NEW VERSION OF PYAV TO USE API NODES.") - -if int(ver[0]) == 14 and int(ver[1]) < 2: - raise Exception("INSTALL NEW VERSION OF PYAV TO USE API NODES.") - -NODE_CLASS_MAPPINGS = {} diff --git a/nodes.py b/nodes.py index f19d5fd1c..8b5279b36 100644 --- a/nodes.py +++ b/nodes.py @@ -2409,9 +2409,6 @@ async def init_builtin_api_nodes(): "nodes_wan.py", ] - if not await load_custom_node(os.path.join(api_nodes_dir, "canary.py"), module_parent="comfy_api_nodes"): - return api_nodes_files - import_failed = [] for node_file in api_nodes_files: if not await load_custom_node(os.path.join(api_nodes_dir, node_file), module_parent="comfy_api_nodes"): From a498556d0dcde3d7a7c19e1f5c733c8c2a2ffb10 Mon Sep 17 00:00:00 2001 From: Christian Byrne Date: Sat, 17 Jan 2026 19:06:03 -0800 Subject: [PATCH 093/119] feat: add advanced parameter to Input classes for advanced widgets support (#11939) Add 'advanced' boolean parameter to Input and WidgetInput base classes and propagate to all typed Input subclasses (Boolean, Int, Float, String, Combo, MultiCombo, Webcam, MultiType, MatchType, ImageCompare). When set to True, the frontend will hide these inputs by default in a collapsible 'Advanced Inputs' section in the right side panel, reducing visual clutter for power-user options. This enables nodes to expose advanced configuration options (like encoding parameters, quality settings, etc.) without overwhelming typical users. Frontend support: ComfyUI_frontend PR #7812 --- comfy_api/latest/_io.py | 47 ++++++++++++++++++++++------------------- 1 file changed, 25 insertions(+), 22 deletions(-) diff --git a/comfy_api/latest/_io.py b/comfy_api/latest/_io.py index e6a0d1821..c30d92aaa 100644 --- a/comfy_api/latest/_io.py +++ b/comfy_api/latest/_io.py @@ -153,7 +153,7 @@ class Input(_IO_V3): ''' Base class for a V3 Input. ''' - def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None, raw_link: bool=None): + def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None, raw_link: bool=None, advanced: bool=None): super().__init__() self.id = id self.display_name = display_name @@ -162,6 +162,7 @@ class Input(_IO_V3): self.lazy = lazy self.extra_dict = extra_dict if extra_dict is not None else {} self.rawLink = raw_link + self.advanced = advanced def as_dict(self): return prune_dict({ @@ -170,6 +171,7 @@ class Input(_IO_V3): "tooltip": self.tooltip, "lazy": self.lazy, "rawLink": self.rawLink, + "advanced": self.advanced, }) | prune_dict(self.extra_dict) def get_io_type(self): @@ -184,8 +186,8 @@ class WidgetInput(Input): ''' def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, default: Any=None, - socketless: bool=None, widget_type: str=None, force_input: bool=None, extra_dict=None, raw_link: bool=None): - super().__init__(id, display_name, optional, tooltip, lazy, extra_dict, raw_link) + socketless: bool=None, widget_type: str=None, force_input: bool=None, extra_dict=None, raw_link: bool=None, advanced: bool=None): + super().__init__(id, display_name, optional, tooltip, lazy, extra_dict, raw_link, advanced) self.default = default self.socketless = socketless self.widget_type = widget_type @@ -242,8 +244,8 @@ class Boolean(ComfyTypeIO): '''Boolean input.''' def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, default: bool=None, label_on: str=None, label_off: str=None, - socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None): - super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link) + socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None, advanced: bool=None): + super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link, advanced) self.label_on = label_on self.label_off = label_off self.default: bool @@ -262,8 +264,8 @@ class Int(ComfyTypeIO): '''Integer input.''' def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, default: int=None, min: int=None, max: int=None, step: int=None, control_after_generate: bool=None, - display_mode: NumberDisplay=None, socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None): - super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link) + display_mode: NumberDisplay=None, socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None, advanced: bool=None): + super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link, advanced) self.min = min self.max = max self.step = step @@ -288,8 +290,8 @@ class Float(ComfyTypeIO): '''Float input.''' def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, default: float=None, min: float=None, max: float=None, step: float=None, round: float=None, - display_mode: NumberDisplay=None, socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None): - super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link) + display_mode: NumberDisplay=None, socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None, advanced: bool=None): + super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link, advanced) self.min = min self.max = max self.step = step @@ -314,8 +316,8 @@ class String(ComfyTypeIO): '''String input.''' def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, multiline=False, placeholder: str=None, default: str=None, dynamic_prompts: bool=None, - socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None): - super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link) + socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None, advanced: bool=None): + super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link, advanced) self.multiline = multiline self.placeholder = placeholder self.dynamic_prompts = dynamic_prompts @@ -350,12 +352,13 @@ class Combo(ComfyTypeIO): socketless: bool=None, extra_dict=None, raw_link: bool=None, + advanced: bool=None, ): if isinstance(options, type) and issubclass(options, Enum): options = [v.value for v in options] if isinstance(default, Enum): default = default.value - super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, None, extra_dict, raw_link) + super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, None, extra_dict, raw_link, advanced) self.multiselect = False self.options = options self.control_after_generate = control_after_generate @@ -387,8 +390,8 @@ class MultiCombo(ComfyTypeI): class Input(Combo.Input): def __init__(self, id: str, options: list[str], display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, default: list[str]=None, placeholder: str=None, chip: bool=None, control_after_generate: bool=None, - socketless: bool=None, extra_dict=None, raw_link: bool=None): - super().__init__(id, options, display_name, optional, tooltip, lazy, default, control_after_generate, socketless=socketless, extra_dict=extra_dict, raw_link=raw_link) + socketless: bool=None, extra_dict=None, raw_link: bool=None, advanced: bool=None): + super().__init__(id, options, display_name, optional, tooltip, lazy, default, control_after_generate, socketless=socketless, extra_dict=extra_dict, raw_link=raw_link, advanced=advanced) self.multiselect = True self.placeholder = placeholder self.chip = chip @@ -421,9 +424,9 @@ class Webcam(ComfyTypeIO): Type = str def __init__( self, id: str, display_name: str=None, optional=False, - tooltip: str=None, lazy: bool=None, default: str=None, socketless: bool=None, extra_dict=None, raw_link: bool=None + tooltip: str=None, lazy: bool=None, default: str=None, socketless: bool=None, extra_dict=None, raw_link: bool=None, advanced: bool=None ): - super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, None, extra_dict, raw_link) + super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, None, extra_dict, raw_link, advanced) @comfytype(io_type="MASK") @@ -776,7 +779,7 @@ class MultiType: ''' Input that permits more than one input type; if `id` is an instance of `ComfyType.Input`, then that input will be used to create a widget (if applicable) with overridden values. ''' - def __init__(self, id: str | Input, types: list[type[_ComfyType] | _ComfyType], display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None, raw_link: bool=None): + def __init__(self, id: str | Input, types: list[type[_ComfyType] | _ComfyType], display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None, raw_link: bool=None, advanced: bool=None): # if id is an Input, then use that Input with overridden values self.input_override = None if isinstance(id, Input): @@ -789,7 +792,7 @@ class MultiType: # if is a widget input, make sure widget_type is set appropriately if isinstance(self.input_override, WidgetInput): self.input_override.widget_type = self.input_override.get_io_type() - super().__init__(id, display_name, optional, tooltip, lazy, extra_dict, raw_link) + super().__init__(id, display_name, optional, tooltip, lazy, extra_dict, raw_link, advanced) self._io_types = types @property @@ -843,8 +846,8 @@ class MatchType(ComfyTypeIO): class Input(Input): def __init__(self, id: str, template: MatchType.Template, - display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None, raw_link: bool=None): - super().__init__(id, display_name, optional, tooltip, lazy, extra_dict, raw_link) + display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None, raw_link: bool=None, advanced: bool=None): + super().__init__(id, display_name, optional, tooltip, lazy, extra_dict, raw_link, advanced) self.template = template def as_dict(self): @@ -1119,8 +1122,8 @@ class ImageCompare(ComfyTypeI): class Input(WidgetInput): def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, - socketless: bool=True): - super().__init__(id, display_name, optional, tooltip, None, None, socketless) + socketless: bool=True, advanced: bool=None): + super().__init__(id, display_name, optional, tooltip, None, None, socketless, None, None, None, None, advanced) def as_dict(self): return super().as_dict() From 034fac70549dd9c35b155b80a3ff627ad07b1015 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Sun, 18 Jan 2026 08:40:39 +0200 Subject: [PATCH 094/119] chore(api-nodes): auto-discover all nodes_*.py files to avoid merge conflicts when adding new API nodes (#11943) --- nodes.py | 30 ++++-------------------------- 1 file changed, 4 insertions(+), 26 deletions(-) diff --git a/nodes.py b/nodes.py index 8b5279b36..cba8eacc2 100644 --- a/nodes.py +++ b/nodes.py @@ -5,6 +5,7 @@ import torch import os import sys import json +import glob import hashlib import inspect import traceback @@ -2384,35 +2385,12 @@ async def init_builtin_extra_nodes(): async def init_builtin_api_nodes(): api_nodes_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_api_nodes") - api_nodes_files = [ - "nodes_ideogram.py", - "nodes_openai.py", - "nodes_minimax.py", - "nodes_veo2.py", - "nodes_kling.py", - "nodes_bfl.py", - "nodes_bytedance.py", - "nodes_ltxv.py", - "nodes_luma.py", - "nodes_recraft.py", - "nodes_pixverse.py", - "nodes_stability.py", - "nodes_runway.py", - "nodes_sora.py", - "nodes_topaz.py", - "nodes_tripo.py", - "nodes_meshy.py", - "nodes_moonvalley.py", - "nodes_rodin.py", - "nodes_gemini.py", - "nodes_vidu.py", - "nodes_wan.py", - ] + api_nodes_files = sorted(glob.glob(os.path.join(api_nodes_dir, "nodes_*.py"))) import_failed = [] for node_file in api_nodes_files: - if not await load_custom_node(os.path.join(api_nodes_dir, node_file), module_parent="comfy_api_nodes"): - import_failed.append(node_file) + if not await load_custom_node(node_file, module_parent="comfy_api_nodes"): + import_failed.append(os.path.basename(node_file)) return import_failed From 1a72bf20469dee31ad156f819c14f0172cbad222 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sun, 18 Jan 2026 19:53:43 -0800 Subject: [PATCH 095/119] Readme update. (#11957) --- README.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 123cc9472..c56e05d07 100644 --- a/README.md +++ b/README.md @@ -108,7 +108,7 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith - [LCM models and Loras](https://comfyanonymous.github.io/ComfyUI_examples/lcm/) - Latent previews with [TAESD](#how-to-show-high-quality-previews) - Works fully offline: core will never download anything unless you want to. -- Optional API nodes to use paid models from external providers through the online [Comfy API](https://docs.comfy.org/tutorials/api-nodes/overview). +- Optional API nodes to use paid models from external providers through the online [Comfy API](https://docs.comfy.org/tutorials/api-nodes/overview) disable with: `--disable-api-nodes` - [Config file](extra_model_paths.yaml.example) to set the search paths for models. Workflow examples can be found on the [Examples page](https://comfyanonymous.github.io/ComfyUI_examples/) @@ -212,7 +212,7 @@ Python 3.14 works but you may encounter issues with the torch compile node. The Python 3.13 is very well supported. If you have trouble with some custom node dependencies on 3.13 you can try 3.12 -torch 2.4 and above is supported but some features might only work on newer versions. We generally recommend using the latest major version of pytorch with the latest cuda version unless it is less than 2 weeks old. +torch 2.4 and above is supported but some features and optimizations might only work on newer versions. We generally recommend using the latest major version of pytorch with the latest cuda version unless it is less than 2 weeks old. ### Instructions: @@ -229,7 +229,7 @@ AMD users can install rocm and pytorch with pip if you don't have it already ins ```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.4``` -This is the command to install the nightly with ROCm 7.0 which might have some performance improvements: +This is the command to install the nightly with ROCm 7.1 which might have some performance improvements: ```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm7.1``` From 866a4619db2db56c77a86e5fc9968a2454928627 Mon Sep 17 00:00:00 2001 From: ComfyUI Wiki Date: Tue, 20 Jan 2026 06:21:35 +0800 Subject: [PATCH 096/119] chore: update workflow templates to v0.8.14 (#11974) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 622256973..312c7c137 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.36.14 -comfyui-workflow-templates==0.8.11 +comfyui-workflow-templates==0.8.14 comfyui-embedded-docs==0.4.0 torch torchsde From b931b37e30bb19b6e13ad8623e193ccdaf671a23 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Mon, 19 Jan 2026 16:47:14 -0800 Subject: [PATCH 097/119] feat(api-nodes): add Bria Edit node (#11978) Co-authored-by: Alexander Piskun --- comfy_api_nodes/apis/bria.py | 61 +++++++++ comfy_api_nodes/nodes_bria.py | 198 ++++++++++++++++++++++++++++ comfy_api_nodes/util/__init__.py | 2 + comfy_api_nodes/util/conversions.py | 6 + 4 files changed, 267 insertions(+) create mode 100644 comfy_api_nodes/apis/bria.py create mode 100644 comfy_api_nodes/nodes_bria.py diff --git a/comfy_api_nodes/apis/bria.py b/comfy_api_nodes/apis/bria.py new file mode 100644 index 000000000..9119cacc6 --- /dev/null +++ b/comfy_api_nodes/apis/bria.py @@ -0,0 +1,61 @@ +from typing import TypedDict + +from pydantic import BaseModel, Field + + +class InputModerationSettings(TypedDict): + prompt_content_moderation: bool + visual_input_moderation: bool + visual_output_moderation: bool + + +class BriaEditImageRequest(BaseModel): + instruction: str | None = Field(...) + structured_instruction: str | None = Field( + ..., + description="Use this instead of instruction for precise, programmatic control.", + ) + images: list[str] = Field( + ..., + description="Required. Publicly available URL or Base64-encoded. Must contain exactly one item.", + ) + mask: str | None = Field( + None, + description="Mask image (black and white). Black areas will be preserved, white areas will be edited. " + "If omitted, the edit applies to the entire image. " + "The input image and the the input mask must be of the same size.", + ) + negative_prompt: str | None = Field(None) + guidance_scale: float = Field(...) + model_version: str = Field(...) + steps_num: int = Field(...) + seed: int = Field(...) + ip_signal: bool = Field( + False, + description="If true, returns a warning for potential IP content in the instruction.", + ) + prompt_content_moderation: bool = Field( + False, description="If true, returns 422 on instruction moderation failure." + ) + visual_input_content_moderation: bool = Field( + False, description="If true, returns 422 on images or mask moderation failure." + ) + visual_output_content_moderation: bool = Field( + False, description="If true, returns 422 on visual output moderation failure." + ) + + +class BriaStatusResponse(BaseModel): + request_id: str = Field(...) + status_url: str = Field(...) + warning: str | None = Field(None) + + +class BriaResult(BaseModel): + structured_prompt: str = Field(...) + image_url: str = Field(...) + + +class BriaResponse(BaseModel): + status: str = Field(...) + result: BriaResult | None = Field(None) diff --git a/comfy_api_nodes/nodes_bria.py b/comfy_api_nodes/nodes_bria.py new file mode 100644 index 000000000..72a3055a7 --- /dev/null +++ b/comfy_api_nodes/nodes_bria.py @@ -0,0 +1,198 @@ +from typing_extensions import override + +from comfy_api.latest import IO, ComfyExtension, Input +from comfy_api_nodes.apis.bria import ( + BriaEditImageRequest, + BriaResponse, + BriaStatusResponse, + InputModerationSettings, +) +from comfy_api_nodes.util import ( + ApiEndpoint, + convert_mask_to_image, + download_url_to_image_tensor, + get_number_of_images, + poll_op, + sync_op, + upload_images_to_comfyapi, +) + + +class BriaImageEditNode(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="BriaImageEditNode", + display_name="Bria Image Edit", + category="api node/image/Bria", + description="Edit images using Bria latest model", + inputs=[ + IO.Combo.Input("model", options=["FIBO"]), + IO.Image.Input("image"), + IO.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Instruction to edit image", + ), + IO.String.Input("negative_prompt", multiline=True, default=""), + IO.String.Input( + "structured_prompt", + multiline=True, + default="", + tooltip="A string containing the structured edit prompt in JSON format. " + "Use this instead of usual prompt for precise, programmatic control.", + ), + IO.Int.Input( + "seed", + default=1, + min=1, + max=2147483647, + step=1, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + ), + IO.Float.Input( + "guidance_scale", + default=3, + min=3, + max=5, + step=0.01, + display_mode=IO.NumberDisplay.number, + tooltip="Higher value makes the image follow the prompt more closely.", + ), + IO.Int.Input( + "steps", + default=50, + min=20, + max=50, + step=1, + display_mode=IO.NumberDisplay.number, + ), + IO.DynamicCombo.Input( + "moderation", + options=[ + IO.DynamicCombo.Option( + "true", + [ + IO.Boolean.Input( + "prompt_content_moderation", default=False + ), + IO.Boolean.Input( + "visual_input_moderation", default=False + ), + IO.Boolean.Input( + "visual_output_moderation", default=True + ), + ], + ), + IO.DynamicCombo.Option("false", []), + ], + tooltip="Moderation settings", + ), + IO.Mask.Input( + "mask", + tooltip="If omitted, the edit applies to the entire image.", + optional=True, + ), + ], + outputs=[ + IO.Image.Output(), + IO.String.Output(display_name="structured_prompt"), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.04}""", + ), + ) + + @classmethod + async def execute( + cls, + model: str, + image: Input.Image, + prompt: str, + negative_prompt: str, + structured_prompt: str, + seed: int, + guidance_scale: float, + steps: int, + moderation: InputModerationSettings, + mask: Input.Image | None = None, + ) -> IO.NodeOutput: + if not prompt and not structured_prompt: + raise ValueError( + "One of prompt or structured_prompt is required to be non-empty." + ) + if get_number_of_images(image) != 1: + raise ValueError("Exactly one input image is required.") + mask_url = None + if mask is not None: + mask_url = ( + await upload_images_to_comfyapi( + cls, + convert_mask_to_image(mask), + max_images=1, + mime_type="image/png", + wait_label="Uploading mask", + ) + )[0] + response = await sync_op( + cls, + ApiEndpoint(path="proxy/bria/v2/image/edit", method="POST"), + data=BriaEditImageRequest( + instruction=prompt if prompt else None, + structured_instruction=structured_prompt if structured_prompt else None, + images=await upload_images_to_comfyapi( + cls, + image, + max_images=1, + mime_type="image/png", + wait_label="Uploading image", + ), + mask=mask_url, + negative_prompt=negative_prompt if negative_prompt else None, + guidance_scale=guidance_scale, + seed=seed, + model_version=model, + steps_num=steps, + prompt_content_moderation=moderation.get( + "prompt_content_moderation", False + ), + visual_input_content_moderation=moderation.get( + "visual_input_moderation", False + ), + visual_output_content_moderation=moderation.get( + "visual_output_moderation", False + ), + ), + response_model=BriaStatusResponse, + ) + response = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/bria/v2/status/{response.request_id}"), + status_extractor=lambda r: r.status, + response_model=BriaResponse, + ) + return IO.NodeOutput( + await download_url_to_image_tensor(response.result.image_url), + response.result.structured_prompt, + ) + + +class BriaExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[IO.ComfyNode]]: + return [ + BriaImageEditNode, + ] + + +async def comfy_entrypoint() -> BriaExtension: + return BriaExtension() diff --git a/comfy_api_nodes/util/__init__.py b/comfy_api_nodes/util/__init__.py index 4cc22abfb..364976000 100644 --- a/comfy_api_nodes/util/__init__.py +++ b/comfy_api_nodes/util/__init__.py @@ -11,6 +11,7 @@ from .conversions import ( audio_input_to_mp3, audio_to_base64_string, bytesio_to_image_tensor, + convert_mask_to_image, downscale_image_tensor, image_tensor_pair_to_batch, pil_to_bytesio, @@ -72,6 +73,7 @@ __all__ = [ "audio_input_to_mp3", "audio_to_base64_string", "bytesio_to_image_tensor", + "convert_mask_to_image", "downscale_image_tensor", "image_tensor_pair_to_batch", "pil_to_bytesio", diff --git a/comfy_api_nodes/util/conversions.py b/comfy_api_nodes/util/conversions.py index 99c302a2a..546741b7b 100644 --- a/comfy_api_nodes/util/conversions.py +++ b/comfy_api_nodes/util/conversions.py @@ -451,6 +451,12 @@ def resize_mask_to_image( return mask +def convert_mask_to_image(mask: Input.Image) -> torch.Tensor: + """Make mask have the expected amount of dims (4) and channels (3) to be recognized as an image.""" + mask = mask.unsqueeze(-1) + return torch.cat([mask] * 3, dim=-1) + + def text_filepath_to_base64_string(filepath: str) -> str: """Converts a text file to a base64 string.""" with open(filepath, "rb") as f: From 7458e20465a0efcf91eafc0c65d1929ab7b2238d Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Mon, 19 Jan 2026 16:58:30 -0800 Subject: [PATCH 098/119] Make Autogrow validation work properly (#11977) * In-progress autogrow validation fixes - properly looks at required/optional inputs, now working on the edge case that all inputs are optional and nothing is plugged in (should just be an empty dictionary passed into node) * Allow autogrow to work with all inputs being optional * Revert accidentally pushed changes to nodes_logic.py --- comfy_api/latest/_io.py | 53 ++++++++++++++++++++++++++++++++++------- 1 file changed, 44 insertions(+), 9 deletions(-) diff --git a/comfy_api/latest/_io.py b/comfy_api/latest/_io.py index c30d92aaa..4969d3506 100644 --- a/comfy_api/latest/_io.py +++ b/comfy_api/latest/_io.py @@ -1000,20 +1000,38 @@ class Autogrow(ComfyTypeI): names = [f"{prefix}{i}" for i in range(max)] # need to create a new input based on the contents of input template_input = None - for _, dict_input in input.items(): - # for now, get just the first value from dict_input + template_required = True + for _input_type, dict_input in input.items(): + # for now, get just the first value from dict_input; if not required, min can be ignored + if len(dict_input) == 0: + continue template_input = list(dict_input.values())[0] + template_required = _input_type == "required" + break + if template_input is None: + raise Exception("template_input could not be determined from required or optional; this should never happen.") new_dict = {} + new_dict_added_to = False + # first, add possible inputs into out_dict for i, name in enumerate(names): expected_id = finalize_prefix(curr_prefix, name) + # required + if i < min and template_required: + out_dict["required"][expected_id] = template_input + type_dict = new_dict.setdefault("required", {}) + # optional + else: + out_dict["optional"][expected_id] = template_input + type_dict = new_dict.setdefault("optional", {}) if expected_id in live_inputs: - # required - if i < min: - type_dict = new_dict.setdefault("required", {}) - # optional - else: - type_dict = new_dict.setdefault("optional", {}) + # NOTE: prefix gets added in parse_class_inputs type_dict[name] = template_input + new_dict_added_to = True + # account for the edge case that all inputs are optional and no values are received + if not new_dict_added_to: + finalized_prefix = finalize_prefix(curr_prefix) + out_dict["dynamic_paths"][finalized_prefix] = finalized_prefix + out_dict["dynamic_paths_default_value"][finalized_prefix] = DynamicPathsDefaultValue.EMPTY_DICT parse_class_inputs(out_dict, live_inputs, new_dict, curr_prefix) @comfytype(io_type="COMFY_DYNAMICCOMBO_V3") @@ -1151,6 +1169,8 @@ class V3Data(TypedDict): 'Dictionary where the keys are the hidden input ids and the values are the values of the hidden inputs.' dynamic_paths: dict[str, Any] 'Dictionary where the keys are the input ids and the values dictate how to turn the inputs into a nested dictionary.' + dynamic_paths_default_value: dict[str, Any] + 'Dictionary where the keys are the input ids and the values are a string from DynamicPathsDefaultValue for the inputs if value is None.' create_dynamic_tuple: bool 'When True, the value of the dynamic input will be in the format (value, path_key).' @@ -1504,6 +1524,7 @@ def get_finalized_class_inputs(d: dict[str, Any], live_inputs: dict[str, Any], i "required": {}, "optional": {}, "dynamic_paths": {}, + "dynamic_paths_default_value": {}, } d = d.copy() # ignore hidden for parsing @@ -1513,8 +1534,12 @@ def get_finalized_class_inputs(d: dict[str, Any], live_inputs: dict[str, Any], i out_dict["hidden"] = hidden v3_data = {} dynamic_paths = out_dict.pop("dynamic_paths", None) - if dynamic_paths is not None: + if dynamic_paths is not None and len(dynamic_paths) > 0: v3_data["dynamic_paths"] = dynamic_paths + # this list is used for autogrow, in the case all inputs are optional and no values are passed + dynamic_paths_default_value = out_dict.pop("dynamic_paths_default_value", None) + if dynamic_paths_default_value is not None and len(dynamic_paths_default_value) > 0: + v3_data["dynamic_paths_default_value"] = dynamic_paths_default_value return out_dict, hidden, v3_data def parse_class_inputs(out_dict: dict[str, Any], live_inputs: dict[str, Any], curr_dict: dict[str, Any], curr_prefix: list[str] | None=None) -> None: @@ -1551,11 +1576,16 @@ def add_to_dict_v1(i: Input, d: dict): def add_to_dict_v3(io: Input | Output, d: dict): d[io.id] = (io.get_io_type(), io.as_dict()) +class DynamicPathsDefaultValue: + EMPTY_DICT = "empty_dict" + def build_nested_inputs(values: dict[str, Any], v3_data: V3Data): paths = v3_data.get("dynamic_paths", None) + default_value_dict = v3_data.get("dynamic_paths_default_value", {}) if paths is None: return values values = values.copy() + result = {} create_tuple = v3_data.get("create_dynamic_tuple", False) @@ -1569,6 +1599,11 @@ def build_nested_inputs(values: dict[str, Any], v3_data: V3Data): if is_last: value = values.pop(key, None) + if value is None: + # see if a default value was provided for this key + default_option = default_value_dict.get(key, None) + if default_option == DynamicPathsDefaultValue.EMPTY_DICT: + value = {} if create_tuple: value = (value, key) current[p] = value From e0eacb06883c1f7ddf8af249cd461d7c2ebcbaae Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 19 Jan 2026 19:00:36 -0800 Subject: [PATCH 099/119] Simpler way to implement the #11980 loras. (#11981) --- comfy/utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/comfy/utils.py b/comfy/utils.py index 2e33a4258..5e79fb449 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -639,6 +639,8 @@ def flux_to_diffusers(mmdit_config, output_prefix=""): "proj_out.bias": "linear2.bias", "attn.norm_q.weight": "norm.query_norm.scale", "attn.norm_k.weight": "norm.key_norm.scale", + "attn.to_qkv_mlp_proj.weight": "linear1.weight", # Flux 2 + "attn.to_out.weight": "linear2.weight", # Flux 2 } for k in block_map: From 0da5a0fe58ae940726a61b94698e303fb39d73c1 Mon Sep 17 00:00:00 2001 From: rkfg Date: Tue, 20 Jan 2026 06:12:02 +0300 Subject: [PATCH 100/119] Convert mono audio to fake stereo for LTXV VAE encoding (#11965) --- comfy/ldm/lightricks/vae/audio_vae.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/comfy/ldm/lightricks/vae/audio_vae.py b/comfy/ldm/lightricks/vae/audio_vae.py index a9111d3bd..29d9e6c29 100644 --- a/comfy/ldm/lightricks/vae/audio_vae.py +++ b/comfy/ldm/lightricks/vae/audio_vae.py @@ -189,9 +189,12 @@ class AudioVAE(torch.nn.Module): waveform = self.device_manager.move_to_load_device(waveform) expected_channels = self.autoencoder.encoder.in_channels if waveform.shape[1] != expected_channels: - raise ValueError( - f"Input audio must have {expected_channels} channels, got {waveform.shape[1]}" - ) + if waveform.shape[1] == 1: + waveform = waveform.expand(-1, expected_channels, *waveform.shape[2:]) + else: + raise ValueError( + f"Input audio must have {expected_channels} channels, got {waveform.shape[1]}" + ) mel_spec = self.preprocessor.waveform_to_mel( waveform, waveform_sample_rate, device=self.device_manager.load_device From 70c91b8248e08492cf16bfebdc83579b801a6ee0 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 19 Jan 2026 19:32:40 -0800 Subject: [PATCH 101/119] Fix #11963 (#11982) --- comfy/text_encoders/ovis.py | 1 + comfy/text_encoders/z_image.py | 1 + 2 files changed, 2 insertions(+) diff --git a/comfy/text_encoders/ovis.py b/comfy/text_encoders/ovis.py index 5754424d2..2cc0867c3 100644 --- a/comfy/text_encoders/ovis.py +++ b/comfy/text_encoders/ovis.py @@ -61,6 +61,7 @@ def te(dtype_llama=None, llama_quantization_metadata=None): if dtype_llama is not None: dtype = dtype_llama if llama_quantization_metadata is not None: + model_options = model_options.copy() model_options["quantization_metadata"] = llama_quantization_metadata super().__init__(device=device, dtype=dtype, model_options=model_options) return OvisTEModel_ diff --git a/comfy/text_encoders/z_image.py b/comfy/text_encoders/z_image.py index 19adde0b7..ad41bfb1e 100644 --- a/comfy/text_encoders/z_image.py +++ b/comfy/text_encoders/z_image.py @@ -40,6 +40,7 @@ def te(dtype_llama=None, llama_quantization_metadata=None): if dtype_llama is not None: dtype = dtype_llama if llama_quantization_metadata is not None: + model_options = model_options.copy() model_options["quantization_metadata"] = llama_quantization_metadata super().__init__(device=device, dtype=dtype, model_options=model_options) return ZImageTEModel_ From 9d273d3ab1fb1d2c8b34de4d54cabe50a5a3e5bc Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 19 Jan 2026 22:40:18 -0500 Subject: [PATCH 102/119] ComfyUI v0.10.0 --- comfyui_version.py | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/comfyui_version.py b/comfyui_version.py index dbb57b4e5..952d413db 100644 --- a/comfyui_version.py +++ b/comfyui_version.py @@ -1,3 +1,3 @@ # This file is automatically generated by the build process when version is # updated in pyproject.toml. -__version__ = "0.9.2" +__version__ = "0.10.0" diff --git a/pyproject.toml b/pyproject.toml index 9ea73da05..120b6c751 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ComfyUI" -version = "0.9.2" +version = "0.10.0" readme = "README.md" license = { file = "LICENSE" } requires-python = ">=3.10" From 2108167f9f70cfd4874945b31a916680f959a6d7 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 19 Jan 2026 20:17:38 -0800 Subject: [PATCH 103/119] Support zimage omni base model. (#11979) --- comfy/ldm/lumina/model.py | 317 ++++++++++++++++++++++++++++------- comfy/model_base.py | 30 ++++ comfy/model_detection.py | 3 + comfy_extras/nodes_zimage.py | 88 ++++++++++ nodes.py | 1 + 5 files changed, 381 insertions(+), 58 deletions(-) create mode 100644 comfy_extras/nodes_zimage.py diff --git a/comfy/ldm/lumina/model.py b/comfy/ldm/lumina/model.py index afbab2ac7..139f879a1 100644 --- a/comfy/ldm/lumina/model.py +++ b/comfy/ldm/lumina/model.py @@ -13,10 +13,53 @@ from comfy.ldm.modules.attention import optimized_attention_masked from comfy.ldm.flux.layers import EmbedND from comfy.ldm.flux.math import apply_rope import comfy.patcher_extension +import comfy.utils -def modulate(x, scale): - return x * (1 + scale.unsqueeze(1)) +def invert_slices(slices, length): + sorted_slices = sorted(slices) + result = [] + current = 0 + + for start, end in sorted_slices: + if current < start: + result.append((current, start)) + current = max(current, end) + + if current < length: + result.append((current, length)) + + return result + + +def modulate(x, scale, timestep_zero_index=None): + if timestep_zero_index is None: + return x * (1 + scale.unsqueeze(1)) + else: + scale = (1 + scale.unsqueeze(1)) + actual_batch = scale.size(0) // 2 + slices = timestep_zero_index + invert = invert_slices(timestep_zero_index, x.shape[1]) + for s in slices: + x[:, s[0]:s[1]] *= scale[actual_batch:] + for s in invert: + x[:, s[0]:s[1]] *= scale[:actual_batch] + return x + + +def apply_gate(gate, x, timestep_zero_index=None): + if timestep_zero_index is None: + return gate * x + else: + actual_batch = gate.size(0) // 2 + + slices = timestep_zero_index + invert = invert_slices(timestep_zero_index, x.shape[1]) + for s in slices: + x[:, s[0]:s[1]] *= gate[actual_batch:] + for s in invert: + x[:, s[0]:s[1]] *= gate[:actual_batch] + return x ############################################################################# # Core NextDiT Model # @@ -258,6 +301,7 @@ class JointTransformerBlock(nn.Module): x_mask: torch.Tensor, freqs_cis: torch.Tensor, adaln_input: Optional[torch.Tensor]=None, + timestep_zero_index=None, transformer_options={}, ): """ @@ -276,18 +320,18 @@ class JointTransformerBlock(nn.Module): assert adaln_input is not None scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(adaln_input).chunk(4, dim=1) - x = x + gate_msa.unsqueeze(1).tanh() * self.attention_norm2( + x = x + apply_gate(gate_msa.unsqueeze(1).tanh(), self.attention_norm2( clamp_fp16(self.attention( - modulate(self.attention_norm1(x), scale_msa), + modulate(self.attention_norm1(x), scale_msa, timestep_zero_index=timestep_zero_index), x_mask, freqs_cis, transformer_options=transformer_options, - )) + ))), timestep_zero_index=timestep_zero_index ) - x = x + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2( + x = x + apply_gate(gate_mlp.unsqueeze(1).tanh(), self.ffn_norm2( clamp_fp16(self.feed_forward( - modulate(self.ffn_norm1(x), scale_mlp), - )) + modulate(self.ffn_norm1(x), scale_mlp, timestep_zero_index=timestep_zero_index), + ))), timestep_zero_index=timestep_zero_index ) else: assert adaln_input is None @@ -345,13 +389,37 @@ class FinalLayer(nn.Module): ), ) - def forward(self, x, c): + def forward(self, x, c, timestep_zero_index=None): scale = self.adaLN_modulation(c) - x = modulate(self.norm_final(x), scale) + x = modulate(self.norm_final(x), scale, timestep_zero_index=timestep_zero_index) x = self.linear(x) return x +def pad_zimage(feats, pad_token, pad_tokens_multiple): + pad_extra = (-feats.shape[1]) % pad_tokens_multiple + return torch.cat((feats, pad_token.to(device=feats.device, dtype=feats.dtype, copy=True).unsqueeze(0).repeat(feats.shape[0], pad_extra, 1)), dim=1), pad_extra + + +def pos_ids_x(start_t, H_tokens, W_tokens, batch_size, device, transformer_options={}): + rope_options = transformer_options.get("rope_options", None) + h_scale = 1.0 + w_scale = 1.0 + h_start = 0 + w_start = 0 + if rope_options is not None: + h_scale = rope_options.get("scale_y", 1.0) + w_scale = rope_options.get("scale_x", 1.0) + + h_start = rope_options.get("shift_y", 0.0) + w_start = rope_options.get("shift_x", 0.0) + x_pos_ids = torch.zeros((batch_size, H_tokens * W_tokens, 3), dtype=torch.float32, device=device) + x_pos_ids[:, :, 0] = start_t + x_pos_ids[:, :, 1] = (torch.arange(H_tokens, dtype=torch.float32, device=device) * h_scale + h_start).view(-1, 1).repeat(1, W_tokens).flatten() + x_pos_ids[:, :, 2] = (torch.arange(W_tokens, dtype=torch.float32, device=device) * w_scale + w_start).view(1, -1).repeat(H_tokens, 1).flatten() + return x_pos_ids + + class NextDiT(nn.Module): """ Diffusion model with a Transformer backbone. @@ -378,6 +446,7 @@ class NextDiT(nn.Module): time_scale=1.0, pad_tokens_multiple=None, clip_text_dim=None, + siglip_feat_dim=None, image_model=None, device=None, dtype=None, @@ -491,6 +560,41 @@ class NextDiT(nn.Module): for layer_id in range(n_layers) ] ) + + if siglip_feat_dim is not None: + self.siglip_embedder = nn.Sequential( + operation_settings.get("operations").RMSNorm(siglip_feat_dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")), + operation_settings.get("operations").Linear( + siglip_feat_dim, + dim, + bias=True, + device=operation_settings.get("device"), + dtype=operation_settings.get("dtype"), + ), + ) + self.siglip_refiner = nn.ModuleList( + [ + JointTransformerBlock( + layer_id, + dim, + n_heads, + n_kv_heads, + multiple_of, + ffn_dim_multiplier, + norm_eps, + qk_norm, + modulation=False, + operation_settings=operation_settings, + ) + for layer_id in range(n_refiner_layers) + ] + ) + self.siglip_pad_token = nn.Parameter(torch.empty((1, dim), device=device, dtype=dtype)) + else: + self.siglip_embedder = None + self.siglip_refiner = None + self.siglip_pad_token = None + # This norm final is in the lumina 2.0 code but isn't actually used for anything. # self.norm_final = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) self.final_layer = FinalLayer(dim, patch_size, self.out_channels, z_image_modulation=z_image_modulation, operation_settings=operation_settings) @@ -531,70 +635,166 @@ class NextDiT(nn.Module): imgs = torch.stack(imgs, dim=0) return imgs - def patchify_and_embed( - self, x: List[torch.Tensor] | torch.Tensor, cap_feats: torch.Tensor, cap_mask: torch.Tensor, t: torch.Tensor, num_tokens, transformer_options={} - ) -> Tuple[torch.Tensor, torch.Tensor, List[Tuple[int, int]], List[int], torch.Tensor]: - bsz = len(x) - pH = pW = self.patch_size - device = x[0].device - orig_x = x - - if self.pad_tokens_multiple is not None: - pad_extra = (-cap_feats.shape[1]) % self.pad_tokens_multiple - cap_feats = torch.cat((cap_feats, self.cap_pad_token.to(device=cap_feats.device, dtype=cap_feats.dtype, copy=True).unsqueeze(0).repeat(cap_feats.shape[0], pad_extra, 1)), dim=1) + def embed_cap(self, cap_feats=None, offset=0, bsz=1, device=None, dtype=None): + if cap_feats is not None: + cap_feats = self.cap_embedder(cap_feats) + cap_feats_len = cap_feats.shape[1] + if self.pad_tokens_multiple is not None: + cap_feats, _ = pad_zimage(cap_feats, self.cap_pad_token, self.pad_tokens_multiple) + else: + cap_feats_len = 0 + cap_feats = self.cap_pad_token.to(device=device, dtype=dtype, copy=True).unsqueeze(0).repeat(bsz, self.pad_tokens_multiple, 1) cap_pos_ids = torch.zeros(bsz, cap_feats.shape[1], 3, dtype=torch.float32, device=device) - cap_pos_ids[:, :, 0] = torch.arange(cap_feats.shape[1], dtype=torch.float32, device=device) + 1.0 + cap_pos_ids[:, :, 0] = torch.arange(cap_feats.shape[1], dtype=torch.float32, device=device) + 1.0 + offset + embeds = (cap_feats,) + freqs_cis = (self.rope_embedder(cap_pos_ids).movedim(1, 2),) + return embeds, freqs_cis, cap_feats_len + + def embed_all(self, x, cap_feats=None, siglip_feats=None, offset=0, omni=False, transformer_options={}): + bsz = 1 + pH = pW = self.patch_size + device = x.device + embeds, freqs_cis, cap_feats_len = self.embed_cap(cap_feats, offset=offset, bsz=bsz, device=device, dtype=x.dtype) + + if not omni: + cap_feats_len = embeds[0].shape[1] + offset + embeds += (None,) + freqs_cis += (None,) + else: + cap_feats_len += offset + if siglip_feats is not None: + b, h, w, c = siglip_feats.shape + siglip_feats = siglip_feats.permute(0, 3, 1, 2).reshape(b, h * w, c) + siglip_feats = self.siglip_embedder(siglip_feats) + siglip_pos_ids = torch.zeros((bsz, siglip_feats.shape[1], 3), dtype=torch.float32, device=device) + siglip_pos_ids[:, :, 0] = cap_feats_len + 2 + siglip_pos_ids[:, :, 1] = (torch.linspace(0, h * 8 - 1, steps=h, dtype=torch.float32, device=device).floor()).view(-1, 1).repeat(1, w).flatten() + siglip_pos_ids[:, :, 2] = (torch.linspace(0, w * 8 - 1, steps=w, dtype=torch.float32, device=device).floor()).view(1, -1).repeat(h, 1).flatten() + if self.siglip_pad_token is not None: + siglip_feats, pad_extra = pad_zimage(siglip_feats, self.siglip_pad_token, self.pad_tokens_multiple) # TODO: double check + siglip_pos_ids = torch.nn.functional.pad(siglip_pos_ids, (0, 0, 0, pad_extra)) + else: + siglip_feats = self.siglip_pad_token.to(device=device, dtype=x.dtype, copy=True).unsqueeze(0).repeat(bsz, self.pad_tokens_multiple, 1) + siglip_pos_ids = torch.zeros((bsz, siglip_feats.shape[1], 3), dtype=torch.float32, device=device) + + if siglip_feats is None: + embeds += (None,) + freqs_cis += (None,) + else: + embeds += (siglip_feats,) + freqs_cis += (self.rope_embedder(siglip_pos_ids).movedim(1, 2),) B, C, H, W = x.shape x = self.x_embedder(x.view(B, C, H // pH, pH, W // pW, pW).permute(0, 2, 4, 3, 5, 1).flatten(3).flatten(1, 2)) - - rope_options = transformer_options.get("rope_options", None) - h_scale = 1.0 - w_scale = 1.0 - h_start = 0 - w_start = 0 - if rope_options is not None: - h_scale = rope_options.get("scale_y", 1.0) - w_scale = rope_options.get("scale_x", 1.0) - - h_start = rope_options.get("shift_y", 0.0) - w_start = rope_options.get("shift_x", 0.0) - - H_tokens, W_tokens = H // pH, W // pW - x_pos_ids = torch.zeros((bsz, x.shape[1], 3), dtype=torch.float32, device=device) - x_pos_ids[:, :, 0] = cap_feats.shape[1] + 1 - x_pos_ids[:, :, 1] = (torch.arange(H_tokens, dtype=torch.float32, device=device) * h_scale + h_start).view(-1, 1).repeat(1, W_tokens).flatten() - x_pos_ids[:, :, 2] = (torch.arange(W_tokens, dtype=torch.float32, device=device) * w_scale + w_start).view(1, -1).repeat(H_tokens, 1).flatten() - + x_pos_ids = pos_ids_x(cap_feats_len + 1, H // pH, W // pW, bsz, device, transformer_options=transformer_options) if self.pad_tokens_multiple is not None: - pad_extra = (-x.shape[1]) % self.pad_tokens_multiple - x = torch.cat((x, self.x_pad_token.to(device=x.device, dtype=x.dtype, copy=True).unsqueeze(0).repeat(x.shape[0], pad_extra, 1)), dim=1) + x, pad_extra = pad_zimage(x, self.x_pad_token, self.pad_tokens_multiple) x_pos_ids = torch.nn.functional.pad(x_pos_ids, (0, 0, 0, pad_extra)) - freqs_cis = self.rope_embedder(torch.cat((cap_pos_ids, x_pos_ids), dim=1)).movedim(1, 2) + embeds += (x,) + freqs_cis += (self.rope_embedder(x_pos_ids).movedim(1, 2),) + return embeds, freqs_cis, cap_feats_len + len(freqs_cis) - 1 + + + def patchify_and_embed( + self, x: torch.Tensor, cap_feats: torch.Tensor, cap_mask: torch.Tensor, t: torch.Tensor, num_tokens, ref_latents=[], ref_contexts=[], siglip_feats=[], transformer_options={} + ) -> Tuple[torch.Tensor, torch.Tensor, List[Tuple[int, int]], List[int], torch.Tensor]: + bsz = x.shape[0] + cap_mask = None # TODO? + main_siglip = None + orig_x = x + + embeds = ([], [], []) + freqs_cis = ([], [], []) + leftover_cap = [] + + start_t = 0 + omni = len(ref_latents) > 0 + if omni: + for i, ref in enumerate(ref_latents): + if i < len(ref_contexts): + ref_con = ref_contexts[i] + else: + ref_con = None + if i < len(siglip_feats): + sig_feat = siglip_feats[i] + else: + sig_feat = None + + out = self.embed_all(ref, ref_con, sig_feat, offset=start_t, omni=omni, transformer_options=transformer_options) + for i, e in enumerate(out[0]): + embeds[i].append(comfy.utils.repeat_to_batch_size(e, bsz)) + freqs_cis[i].append(out[1][i]) + start_t = out[2] + leftover_cap = ref_contexts[len(ref_latents):] + + H, W = x.shape[-2], x.shape[-1] + img_sizes = [(H, W)] * bsz + out = self.embed_all(x, cap_feats, main_siglip, offset=start_t, omni=omni, transformer_options=transformer_options) + img_len = out[0][-1].shape[1] + cap_len = out[0][0].shape[1] + for i, e in enumerate(out[0]): + if e is not None: + e = comfy.utils.repeat_to_batch_size(e, bsz) + embeds[i].append(e) + freqs_cis[i].append(out[1][i]) + start_t = out[2] + + for cap in leftover_cap: + out = self.embed_cap(cap, offset=start_t, bsz=bsz, device=x.device, dtype=x.dtype) + cap_len += out[0][0].shape[1] + embeds[0].append(comfy.utils.repeat_to_batch_size(out[0][0], bsz)) + freqs_cis[0].append(out[1][0]) + start_t += out[2] patches = transformer_options.get("patches", {}) # refine context + cap_feats = torch.cat(embeds[0], dim=1) + cap_freqs_cis = torch.cat(freqs_cis[0], dim=1) for layer in self.context_refiner: - cap_feats = layer(cap_feats, cap_mask, freqs_cis[:, :cap_pos_ids.shape[1]], transformer_options=transformer_options) + cap_feats = layer(cap_feats, cap_mask, cap_freqs_cis, transformer_options=transformer_options) + + feats = (cap_feats,) + fc = (cap_freqs_cis,) + + if omni: + siglip_mask = None + siglip_feats_combined = torch.cat(embeds[1], dim=1) + siglip_feats_freqs_cis = torch.cat(freqs_cis[1], dim=1) + if self.siglip_refiner is not None: + for layer in self.siglip_refiner: + siglip_feats_combined = layer(siglip_feats_combined, siglip_mask, siglip_feats_freqs_cis, transformer_options=transformer_options) + feats += (siglip_feats_combined,) + fc += (siglip_feats_freqs_cis,) padded_img_mask = None + x = torch.cat(embeds[-1], dim=1) + fc_x = torch.cat(freqs_cis[-1], dim=1) + if omni: + timestep_zero_index = [(x.shape[1] - img_len, x.shape[1])] + else: + timestep_zero_index = None + x_input = x for i, layer in enumerate(self.noise_refiner): - x = layer(x, padded_img_mask, freqs_cis[:, cap_pos_ids.shape[1]:], t, transformer_options=transformer_options) + x = layer(x, padded_img_mask, fc_x, t, timestep_zero_index=timestep_zero_index, transformer_options=transformer_options) if "noise_refiner" in patches: for p in patches["noise_refiner"]: - out = p({"img": x, "img_input": x_input, "txt": cap_feats, "pe": freqs_cis[:, cap_pos_ids.shape[1]:], "vec": t, "x": orig_x, "block_index": i, "transformer_options": transformer_options, "block_type": "noise_refiner"}) + out = p({"img": x, "img_input": x_input, "txt": cap_feats, "pe": fc_x, "vec": t, "x": orig_x, "block_index": i, "transformer_options": transformer_options, "block_type": "noise_refiner"}) if "img" in out: x = out["img"] - padded_full_embed = torch.cat((cap_feats, x), dim=1) + padded_full_embed = torch.cat(feats + (x,), dim=1) + if timestep_zero_index is not None: + ind = padded_full_embed.shape[1] - x.shape[1] + timestep_zero_index = [(ind + x.shape[1] - img_len, ind + x.shape[1])] + timestep_zero_index.append((feats[0].shape[1] - cap_len, feats[0].shape[1])) + mask = None - img_sizes = [(H, W)] * bsz - l_effective_cap_len = [cap_feats.shape[1]] * bsz - return padded_full_embed, mask, img_sizes, l_effective_cap_len, freqs_cis + l_effective_cap_len = [padded_full_embed.shape[1] - img_len] * bsz + return padded_full_embed, mask, img_sizes, l_effective_cap_len, torch.cat(fc + (fc_x,), dim=1), timestep_zero_index def forward(self, x, timesteps, context, num_tokens, attention_mask=None, **kwargs): return comfy.patcher_extension.WrapperExecutor.new_class_executor( @@ -604,7 +804,11 @@ class NextDiT(nn.Module): ).execute(x, timesteps, context, num_tokens, attention_mask, **kwargs) # def forward(self, x, t, cap_feats, cap_mask): - def _forward(self, x, timesteps, context, num_tokens, attention_mask=None, transformer_options={}, **kwargs): + def _forward(self, x, timesteps, context, num_tokens, attention_mask=None, ref_latents=[], ref_contexts=[], siglip_feats=[], transformer_options={}, **kwargs): + omni = len(ref_latents) > 0 + if omni: + timesteps = torch.cat([timesteps * 0, timesteps], dim=0) + t = 1.0 - timesteps cap_feats = context cap_mask = attention_mask @@ -619,8 +823,6 @@ class NextDiT(nn.Module): t = self.t_embedder(t * self.time_scale, dtype=x.dtype) # (N, D) adaln_input = t - cap_feats = self.cap_embedder(cap_feats) # (N, L, D) # todo check if able to batchify w.o. redundant compute - if self.clip_text_pooled_proj is not None: pooled = kwargs.get("clip_text_pooled", None) if pooled is not None: @@ -632,7 +834,7 @@ class NextDiT(nn.Module): patches = transformer_options.get("patches", {}) x_is_tensor = isinstance(x, torch.Tensor) - img, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, adaln_input, num_tokens, transformer_options=transformer_options) + img, mask, img_size, cap_size, freqs_cis, timestep_zero_index = self.patchify_and_embed(x, cap_feats, cap_mask, adaln_input, num_tokens, ref_latents=ref_latents, ref_contexts=ref_contexts, siglip_feats=siglip_feats, transformer_options=transformer_options) freqs_cis = freqs_cis.to(img.device) transformer_options["total_blocks"] = len(self.layers) @@ -640,7 +842,7 @@ class NextDiT(nn.Module): img_input = img for i, layer in enumerate(self.layers): transformer_options["block_index"] = i - img = layer(img, mask, freqs_cis, adaln_input, transformer_options=transformer_options) + img = layer(img, mask, freqs_cis, adaln_input, timestep_zero_index=timestep_zero_index, transformer_options=transformer_options) if "double_block" in patches: for p in patches["double_block"]: out = p({"img": img[:, cap_size[0]:], "img_input": img_input[:, cap_size[0]:], "txt": img[:, :cap_size[0]], "pe": freqs_cis[:, cap_size[0]:], "vec": adaln_input, "x": x, "block_index": i, "transformer_options": transformer_options}) @@ -649,8 +851,7 @@ class NextDiT(nn.Module): if "txt" in out: img[:, :cap_size[0]] = out["txt"] - img = self.final_layer(img, adaln_input) + img = self.final_layer(img, adaln_input, timestep_zero_index=timestep_zero_index) img = self.unpatchify(img, img_size, cap_size, return_tensor=x_is_tensor)[:, :, :h, :w] - return -img diff --git a/comfy/model_base.py b/comfy/model_base.py index 49efd700b..28ba2643e 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1150,6 +1150,7 @@ class CosmosPredict2(BaseModel): class Lumina2(BaseModel): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.lumina.model.NextDiT) + self.memory_usage_factor_conds = ("ref_latents",) def extra_conds(self, **kwargs): out = super().extra_conds(**kwargs) @@ -1169,6 +1170,35 @@ class Lumina2(BaseModel): if clip_text_pooled is not None: out['clip_text_pooled'] = comfy.conds.CONDRegular(clip_text_pooled) + clip_vision_outputs = kwargs.get("clip_vision_outputs", list(map(lambda a: a.get("clip_vision_output"), kwargs.get("unclip_conditioning", [{}])))) # Z Image omni + if clip_vision_outputs is not None and len(clip_vision_outputs) > 0: + sigfeats = [] + for clip_vision_output in clip_vision_outputs: + if clip_vision_output is not None: + image_size = clip_vision_output.image_sizes[0] + shape = clip_vision_output.last_hidden_state.shape + sigfeats.append(clip_vision_output.last_hidden_state.reshape(shape[0], image_size[1] // 16, image_size[2] // 16, shape[-1])) + if len(sigfeats) > 0: + out['siglip_feats'] = comfy.conds.CONDList(sigfeats) + + ref_latents = kwargs.get("reference_latents", None) + if ref_latents is not None: + latents = [] + for lat in ref_latents: + latents.append(self.process_latent_in(lat)) + out['ref_latents'] = comfy.conds.CONDList(latents) + + ref_contexts = kwargs.get("reference_latents_text_embeds", None) + if ref_contexts is not None: + out['ref_contexts'] = comfy.conds.CONDList(ref_contexts) + + return out + + def extra_conds_shapes(self, **kwargs): + out = {} + ref_latents = kwargs.get("reference_latents", None) + if ref_latents is not None: + out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()[2:]), ref_latents))]) return out class WAN21(BaseModel): diff --git a/comfy/model_detection.py b/comfy/model_detection.py index aff5a50b9..42884f797 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -446,6 +446,9 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["time_scale"] = 1000.0 if '{}cap_pad_token'.format(key_prefix) in state_dict_keys: dit_config["pad_tokens_multiple"] = 32 + sig_weight = state_dict.get('{}siglip_embedder.0.weight'.format(key_prefix), None) + if sig_weight is not None: + dit_config["siglip_feat_dim"] = sig_weight.shape[0] return dit_config diff --git a/comfy_extras/nodes_zimage.py b/comfy_extras/nodes_zimage.py new file mode 100644 index 000000000..2ee3c43b1 --- /dev/null +++ b/comfy_extras/nodes_zimage.py @@ -0,0 +1,88 @@ +import node_helpers +from typing_extensions import override +from comfy_api.latest import ComfyExtension, io +import math +import comfy.utils + + +class TextEncodeZImageOmni(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="TextEncodeZImageOmni", + category="advanced/conditioning", + is_experimental=True, + inputs=[ + io.Clip.Input("clip"), + io.ClipVision.Input("image_encoder", optional=True), + io.String.Input("prompt", multiline=True, dynamic_prompts=True), + io.Boolean.Input("auto_resize_images", default=True), + io.Vae.Input("vae", optional=True), + io.Image.Input("image1", optional=True), + io.Image.Input("image2", optional=True), + io.Image.Input("image3", optional=True), + ], + outputs=[ + io.Conditioning.Output(), + ], + ) + + @classmethod + def execute(cls, clip, prompt, image_encoder=None, auto_resize_images=True, vae=None, image1=None, image2=None, image3=None) -> io.NodeOutput: + ref_latents = [] + images = list(filter(lambda a: a is not None, [image1, image2, image3])) + + prompt_list = [] + template = None + if len(images) > 0: + prompt_list = ["<|im_start|>user\n<|vision_start|>"] + prompt_list += ["<|vision_end|><|vision_start|>"] * (len(images) - 1) + prompt_list += ["<|vision_end|><|im_end|>"] + template = "<|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n<|vision_start|>" + + encoded_images = [] + + for i, image in enumerate(images): + if image_encoder is not None: + encoded_images.append(image_encoder.encode_image(image)) + + if vae is not None: + if auto_resize_images: + samples = image.movedim(-1, 1) + total = int(1024 * 1024) + scale_by = math.sqrt(total / (samples.shape[3] * samples.shape[2])) + width = round(samples.shape[3] * scale_by / 8.0) * 8 + height = round(samples.shape[2] * scale_by / 8.0) * 8 + + image = comfy.utils.common_upscale(samples, width, height, "area", "disabled").movedim(1, -1) + ref_latents.append(vae.encode(image)) + + tokens = clip.tokenize(prompt, llama_template=template) + conditioning = clip.encode_from_tokens_scheduled(tokens) + + extra_text_embeds = [] + for p in prompt_list: + tokens = clip.tokenize(p, llama_template="{}") + text_embeds = clip.encode_from_tokens_scheduled(tokens) + extra_text_embeds.append(text_embeds[0][0]) + + if len(ref_latents) > 0: + conditioning = node_helpers.conditioning_set_values(conditioning, {"reference_latents": ref_latents}, append=True) + if len(encoded_images) > 0: + conditioning = node_helpers.conditioning_set_values(conditioning, {"clip_vision_outputs": encoded_images}, append=True) + if len(extra_text_embeds) > 0: + conditioning = node_helpers.conditioning_set_values(conditioning, {"reference_latents_text_embeds": extra_text_embeds}, append=True) + + return io.NodeOutput(conditioning) + + +class ZImageExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + TextEncodeZImageOmni, + ] + + +async def comfy_entrypoint() -> ZImageExtension: + return ZImageExtension() diff --git a/nodes.py b/nodes.py index cba8eacc2..ea5d6e525 100644 --- a/nodes.py +++ b/nodes.py @@ -2373,6 +2373,7 @@ async def init_builtin_extra_nodes(): "nodes_kandinsky5.py", "nodes_wanmove.py", "nodes_image_compare.py", + "nodes_zimage.py", ] import_failed = [] From 0fc3b6e3a6f1d8fdffca3a51cb4d10a06f4e079d Mon Sep 17 00:00:00 2001 From: ComfyUI Wiki Date: Tue, 20 Jan 2026 12:17:56 +0800 Subject: [PATCH 104/119] chore: update workflow templates to v0.8.15 (#11984) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 312c7c137..35543525d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.36.14 -comfyui-workflow-templates==0.8.14 +comfyui-workflow-templates==0.8.15 comfyui-embedded-docs==0.4.0 torch torchsde From 4edb87aa50190139a38a2ccd6b6ee35ba9df4da1 Mon Sep 17 00:00:00 2001 From: Comfy Org PR Bot Date: Tue, 20 Jan 2026 13:57:50 +0900 Subject: [PATCH 105/119] Bump comfyui-frontend-package to 1.37.11 (#11976) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 35543525d..ec89dccd2 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -comfyui-frontend-package==1.36.14 +comfyui-frontend-package==1.37.11 comfyui-workflow-templates==0.8.15 comfyui-embedded-docs==0.4.0 torch From 8ccc0c94fa0d8e43fffe7190e6a36551a53df54a Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 19 Jan 2026 21:32:00 -0800 Subject: [PATCH 106/119] Make omni stuff work on regular z image for easier testing. (#11985) --- comfy/ldm/lumina/model.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/comfy/ldm/lumina/model.py b/comfy/ldm/lumina/model.py index 139f879a1..b114d9e31 100644 --- a/comfy/ldm/lumina/model.py +++ b/comfy/ldm/lumina/model.py @@ -657,7 +657,7 @@ class NextDiT(nn.Module): device = x.device embeds, freqs_cis, cap_feats_len = self.embed_cap(cap_feats, offset=offset, bsz=bsz, device=device, dtype=x.dtype) - if not omni: + if (not omni) or self.siglip_embedder is None: cap_feats_len = embeds[0].shape[1] + offset embeds += (None,) freqs_cis += (None,) @@ -675,8 +675,9 @@ class NextDiT(nn.Module): siglip_feats, pad_extra = pad_zimage(siglip_feats, self.siglip_pad_token, self.pad_tokens_multiple) # TODO: double check siglip_pos_ids = torch.nn.functional.pad(siglip_pos_ids, (0, 0, 0, pad_extra)) else: - siglip_feats = self.siglip_pad_token.to(device=device, dtype=x.dtype, copy=True).unsqueeze(0).repeat(bsz, self.pad_tokens_multiple, 1) - siglip_pos_ids = torch.zeros((bsz, siglip_feats.shape[1], 3), dtype=torch.float32, device=device) + if self.siglip_pad_token is not None: + siglip_feats = self.siglip_pad_token.to(device=device, dtype=x.dtype, copy=True).unsqueeze(0).repeat(bsz, self.pad_tokens_multiple, 1) + siglip_pos_ids = torch.zeros((bsz, siglip_feats.shape[1], 3), dtype=torch.float32, device=device) if siglip_feats is None: embeds += (None,) @@ -724,8 +725,9 @@ class NextDiT(nn.Module): out = self.embed_all(ref, ref_con, sig_feat, offset=start_t, omni=omni, transformer_options=transformer_options) for i, e in enumerate(out[0]): - embeds[i].append(comfy.utils.repeat_to_batch_size(e, bsz)) - freqs_cis[i].append(out[1][i]) + if e is not None: + embeds[i].append(comfy.utils.repeat_to_batch_size(e, bsz)) + freqs_cis[i].append(out[1][i]) start_t = out[2] leftover_cap = ref_contexts[len(ref_latents):] @@ -759,7 +761,7 @@ class NextDiT(nn.Module): feats = (cap_feats,) fc = (cap_freqs_cis,) - if omni: + if omni and len(embeds[1]) > 0: siglip_mask = None siglip_feats_combined = torch.cat(embeds[1], dim=1) siglip_feats_freqs_cis = torch.cat(freqs_cis[1], dim=1) From ddc541ffdae0fe626de5a33192001f31c6ab93c6 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Tue, 20 Jan 2026 23:05:40 +0200 Subject: [PATCH 107/119] feat(api-nodes): add WaveSpeed nodes (#11945) --- comfy_api_nodes/apis/wavespeed.py | 35 ++++++ comfy_api_nodes/nodes_wavespeed.py | 178 +++++++++++++++++++++++++++++ 2 files changed, 213 insertions(+) create mode 100644 comfy_api_nodes/apis/wavespeed.py create mode 100644 comfy_api_nodes/nodes_wavespeed.py diff --git a/comfy_api_nodes/apis/wavespeed.py b/comfy_api_nodes/apis/wavespeed.py new file mode 100644 index 000000000..07a7bfa5d --- /dev/null +++ b/comfy_api_nodes/apis/wavespeed.py @@ -0,0 +1,35 @@ +from pydantic import BaseModel, Field + + +class SeedVR2ImageRequest(BaseModel): + image: str = Field(...) + target_resolution: str = Field(...) + output_format: str = Field("png") + enable_sync_mode: bool = Field(False) + + +class FlashVSRRequest(BaseModel): + target_resolution: str = Field(...) + video: str = Field(...) + duration: float = Field(...) + + +class TaskCreatedDataResponse(BaseModel): + id: str = Field(...) + + +class TaskCreatedResponse(BaseModel): + code: int = Field(...) + message: str = Field(...) + data: TaskCreatedDataResponse | None = Field(None) + + +class TaskResultDataResponse(BaseModel): + status: str = Field(...) + outputs: list[str] = Field([]) + + +class TaskResultResponse(BaseModel): + code: int = Field(...) + message: str = Field(...) + data: TaskResultDataResponse | None = Field(None) diff --git a/comfy_api_nodes/nodes_wavespeed.py b/comfy_api_nodes/nodes_wavespeed.py new file mode 100644 index 000000000..c59fafd3b --- /dev/null +++ b/comfy_api_nodes/nodes_wavespeed.py @@ -0,0 +1,178 @@ +from typing_extensions import override + +from comfy_api.latest import IO, ComfyExtension, Input +from comfy_api_nodes.apis.wavespeed import ( + FlashVSRRequest, + TaskCreatedResponse, + TaskResultResponse, + SeedVR2ImageRequest, +) +from comfy_api_nodes.util import ( + ApiEndpoint, + download_url_to_video_output, + poll_op, + sync_op, + upload_video_to_comfyapi, + validate_container_format_is_mp4, + validate_video_duration, + upload_images_to_comfyapi, + get_number_of_images, + download_url_to_image_tensor, +) + + +class WavespeedFlashVSRNode(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="WavespeedFlashVSRNode", + display_name="FlashVSR Video Upscale", + category="api node/video/WaveSpeed", + description="Fast, high-quality video upscaler that " + "boosts resolution and restores clarity for low-resolution or blurry footage.", + inputs=[ + IO.Video.Input("video"), + IO.Combo.Input("target_resolution", options=["720p", "1080p", "2K", "4K"]), + ], + outputs=[ + IO.Video.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["target_resolution"]), + expr=""" + ( + $price_for_1sec := {"720p": 0.012, "1080p": 0.018, "2k": 0.024, "4k": 0.032}; + { + "type":"usd", + "usd": $lookup($price_for_1sec, widgets.target_resolution), + "format":{"suffix": "/second", "approximate": true} + } + ) + """, + ), + ) + + @classmethod + async def execute( + cls, + video: Input.Video, + target_resolution: str, + ) -> IO.NodeOutput: + validate_container_format_is_mp4(video) + validate_video_duration(video, min_duration=5, max_duration=60 * 10) + initial_res = await sync_op( + cls, + ApiEndpoint(path="/proxy/wavespeed/api/v3/wavespeed-ai/flashvsr", method="POST"), + response_model=TaskCreatedResponse, + data=FlashVSRRequest( + target_resolution=target_resolution.lower(), + video=await upload_video_to_comfyapi(cls, video), + duration=video.get_duration(), + ), + ) + if initial_res.code != 200: + raise ValueError(f"Task creation fails with code={initial_res.code} and message={initial_res.message}") + final_response = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/wavespeed/api/v3/predictions/{initial_res.data.id}/result"), + response_model=TaskResultResponse, + status_extractor=lambda x: "failed" if x.data is None else x.data.status, + poll_interval=10.0, + max_poll_attempts=480, + ) + if final_response.code != 200: + raise ValueError( + f"Task processing failed with code={final_response.code} and message={final_response.message}" + ) + return IO.NodeOutput(await download_url_to_video_output(final_response.data.outputs[0])) + + +class WavespeedImageUpscaleNode(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="WavespeedImageUpscaleNode", + display_name="WaveSpeed Image Upscale", + category="api node/image/WaveSpeed", + description="Boost image resolution and quality, upscaling photos to 4K or 8K for sharp, detailed results.", + inputs=[ + IO.Combo.Input("model", options=["SeedVR2", "Ultimate"]), + IO.Image.Input("image"), + IO.Combo.Input("target_resolution", options=["2K", "4K", "8K"]), + ], + outputs=[ + IO.Image.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["model"]), + expr=""" + ( + $prices := {"seedvr2": 0.01, "ultimate": 0.06}; + {"type":"usd", "usd": $lookup($prices, widgets.model)} + ) + """, + ), + ) + + @classmethod + async def execute( + cls, + model: str, + image: Input.Image, + target_resolution: str, + ) -> IO.NodeOutput: + if get_number_of_images(image) != 1: + raise ValueError("Exactly one input image is required.") + if model == "SeedVR2": + model_path = "seedvr2/image" + else: + model_path = "ultimate-image-upscaler" + initial_res = await sync_op( + cls, + ApiEndpoint(path=f"/proxy/wavespeed/api/v3/wavespeed-ai/{model_path}", method="POST"), + response_model=TaskCreatedResponse, + data=SeedVR2ImageRequest( + target_resolution=target_resolution.lower(), + image=(await upload_images_to_comfyapi(cls, image, max_images=1))[0], + ), + ) + if initial_res.code != 200: + raise ValueError(f"Task creation fails with code={initial_res.code} and message={initial_res.message}") + final_response = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/wavespeed/api/v3/predictions/{initial_res.data.id}/result"), + response_model=TaskResultResponse, + status_extractor=lambda x: "failed" if x.data is None else x.data.status, + poll_interval=10.0, + max_poll_attempts=480, + ) + if final_response.code != 200: + raise ValueError( + f"Task processing failed with code={final_response.code} and message={final_response.message}" + ) + return IO.NodeOutput(await download_url_to_image_tensor(final_response.data.outputs[0])) + + +class WavespeedExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[IO.ComfyNode]]: + return [ + WavespeedFlashVSRNode, + WavespeedImageUpscaleNode, + ] + + +async def comfy_entrypoint() -> WavespeedExtension: + return WavespeedExtension() From 965d0ed509ce46a3328c342aee23a234ba6e4f88 Mon Sep 17 00:00:00 2001 From: Ivan Zorin Date: Wed, 21 Jan 2026 01:44:28 +0200 Subject: [PATCH 108/119] fix: remove normalization of audio in LTX Mel spectrogram creation (#11990) For LTX Audio VAE, remove normalization of audio during MEL spectrogram creation. This aligs inference with training and prevents loud audio from being attenuated. --- comfy/ldm/lightricks/vae/audio_vae.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/comfy/ldm/lightricks/vae/audio_vae.py b/comfy/ldm/lightricks/vae/audio_vae.py index 29d9e6c29..55a074661 100644 --- a/comfy/ldm/lightricks/vae/audio_vae.py +++ b/comfy/ldm/lightricks/vae/audio_vae.py @@ -103,20 +103,10 @@ class AudioPreprocessor: return waveform return torchaudio.functional.resample(waveform, source_rate, self.target_sample_rate) - @staticmethod - def normalize_amplitude( - waveform: torch.Tensor, max_amplitude: float = 0.5, eps: float = 1e-5 - ) -> torch.Tensor: - waveform = waveform - waveform.mean(dim=2, keepdim=True) - peak = torch.max(torch.abs(waveform)) + eps - scale = peak.clamp(max=max_amplitude) / peak - return waveform * scale - def waveform_to_mel( self, waveform: torch.Tensor, waveform_sample_rate: int, device ) -> torch.Tensor: waveform = self.resample(waveform, waveform_sample_rate) - waveform = self.normalize_amplitude(waveform) mel_transform = torchaudio.transforms.MelSpectrogram( sample_rate=self.target_sample_rate, From c4a14df9a35336dbfff096683c5015ce726c269d Mon Sep 17 00:00:00 2001 From: Mylo <36931363+gitmylo@users.noreply.github.com> Date: Wed, 21 Jan 2026 00:46:11 +0100 Subject: [PATCH 109/119] Dynamically detect chroma radiance patch size (#11991) --- comfy/model_detection.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 42884f797..dad206a2f 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -253,7 +253,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["image_model"] = "chroma_radiance" dit_config["in_channels"] = 3 dit_config["out_channels"] = 3 - dit_config["patch_size"] = 16 + dit_config["patch_size"] = state_dict.get('{}img_in_patch.weight'.format(key_prefix)).size(dim=-1) dit_config["nerf_hidden_size"] = 64 dit_config["nerf_mlp_ratio"] = 4 dit_config["nerf_depth"] = 4 From e755268e7b7843695f52b87595afcb09c1e9fd87 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 20 Jan 2026 20:08:31 -0800 Subject: [PATCH 110/119] Config for Qwen 3 0.6B model. (#11998) --- comfy/text_encoders/llama.py | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/comfy/text_encoders/llama.py b/comfy/text_encoders/llama.py index 331a30f61..3080a3e09 100644 --- a/comfy/text_encoders/llama.py +++ b/comfy/text_encoders/llama.py @@ -77,6 +77,28 @@ class Qwen25_3BConfig: rope_scale = None final_norm: bool = True +@dataclass +class Qwen3_06BConfig: + vocab_size: int = 151936 + hidden_size: int = 1024 + intermediate_size: int = 3072 + num_hidden_layers: int = 28 + num_attention_heads: int = 16 + num_key_value_heads: int = 8 + max_position_embeddings: int = 32768 + rms_norm_eps: float = 1e-6 + rope_theta: float = 1000000.0 + transformer_type: str = "llama" + head_dim = 128 + rms_norm_add = False + mlp_activation = "silu" + qkv_bias = False + rope_dims = None + q_norm = "gemma3" + k_norm = "gemma3" + rope_scale = None + final_norm: bool = True + @dataclass class Qwen3_4BConfig: vocab_size: int = 151936 @@ -641,6 +663,15 @@ class Qwen25_3B(BaseLlama, torch.nn.Module): self.model = Llama2_(config, device=device, dtype=dtype, ops=operations) self.dtype = dtype +class Qwen3_06B(BaseLlama, torch.nn.Module): + def __init__(self, config_dict, dtype, device, operations): + super().__init__() + config = Qwen3_06BConfig(**config_dict) + self.num_layers = config.num_hidden_layers + + self.model = Llama2_(config, device=device, dtype=dtype, ops=operations) + self.dtype = dtype + class Qwen3_4B(BaseLlama, torch.nn.Module): def __init__(self, config_dict, dtype, device, operations): super().__init__() From 0fc15700be9b555f351034942b5bd7243bdf6bcc Mon Sep 17 00:00:00 2001 From: Markury Date: Tue, 20 Jan 2026 23:18:33 -0500 Subject: [PATCH 111/119] Add LyCoris LoKr MLP layer support for Flux2 (#11997) --- comfy/utils.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/comfy/utils.py b/comfy/utils.py index 5e79fb449..d97d753e6 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -611,6 +611,14 @@ def flux_to_diffusers(mmdit_config, output_prefix=""): "ff_context.net.0.proj.bias": "txt_mlp.0.bias", "ff_context.net.2.weight": "txt_mlp.2.weight", "ff_context.net.2.bias": "txt_mlp.2.bias", + "ff.linear_in.weight": "img_mlp.0.weight", # LyCoris LoKr + "ff.linear_in.bias": "img_mlp.0.bias", + "ff.linear_out.weight": "img_mlp.2.weight", + "ff.linear_out.bias": "img_mlp.2.bias", + "ff_context.linear_in.weight": "txt_mlp.0.weight", + "ff_context.linear_in.bias": "txt_mlp.0.bias", + "ff_context.linear_out.weight": "txt_mlp.2.weight", + "ff_context.linear_out.bias": "txt_mlp.2.bias", "attn.norm_q.weight": "img_attn.norm.query_norm.scale", "attn.norm_k.weight": "img_attn.norm.key_norm.scale", "attn.norm_added_q.weight": "txt_attn.norm.query_norm.scale", From 451af7015435df22e6313ae79f25fe2ef336a96d Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Wed, 21 Jan 2026 14:03:45 +0200 Subject: [PATCH 112/119] fix(api-nodes-Vidu): allow passing up to 7 subjects in Vidu Reference node (#12002) --- comfy_api_nodes/nodes_vidu.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/comfy_api_nodes/nodes_vidu.py b/comfy_api_nodes/nodes_vidu.py index 8edb02f39..b9114c4bb 100644 --- a/comfy_api_nodes/nodes_vidu.py +++ b/comfy_api_nodes/nodes_vidu.py @@ -703,7 +703,7 @@ class Vidu2ReferenceVideoNode(IO.ComfyNode): "subjects", template=IO.Autogrow.TemplateNames( IO.Image.Input("reference_images"), - names=["subject1", "subject2", "subject3"], + names=["subject1", "subject2", "subject3", "subject4", "subject5", "subject6", "subject7"], min=1, ), tooltip="For each subject, provide up to 3 reference images (7 images total across all subjects). " @@ -738,7 +738,7 @@ class Vidu2ReferenceVideoNode(IO.ComfyNode): control_after_generate=True, ), IO.Combo.Input("aspect_ratio", options=["16:9", "9:16", "4:3", "3:4", "1:1"]), - IO.Combo.Input("resolution", options=["720p"]), + IO.Combo.Input("resolution", options=["720p", "1080p"]), IO.Combo.Input( "movement_amplitude", options=["auto", "small", "medium", "large"], From bdeac8897e522b9637a6a427fdc8a50a6abd6b20 Mon Sep 17 00:00:00 2001 From: Christian Byrne Date: Wed, 21 Jan 2026 15:36:02 -0800 Subject: [PATCH 113/119] feat: Add search_aliases field to node schema (#12010) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * feat: Add search_aliases field to node schema Adds `search_aliases` field to improve node discoverability. Users can define alternative search terms for nodes (e.g., "text concat" → StringConcatenate). Changes: - Add `search_aliases: list[str]` to V3 Schema - Add `SEARCH_ALIASES` support for V1 nodes - Include field in `/object_info` response - Add aliases to high-priority core nodes V1 usage: ```python class MyNode: SEARCH_ALIASES = ["alt name", "synonym"] ``` V3 usage: ```python io.Schema( node_id="MyNode", search_aliases=["alt name", "synonym"], ... ) ``` ## Related PRs - Frontend: Comfy-Org/ComfyUI_frontend#XXXX (draft - merge after this) - Docs: Comfy-Org/docs#XXXX (draft - merge after stable) * Propagate search_aliases through V3 Schema.get_v1_info to NodeInfoV1 --- comfy_api/latest/_io.py | 4 ++++ comfy_extras/nodes_post_processing.py | 1 + comfy_extras/nodes_preview_any.py | 1 + comfy_extras/nodes_string.py | 1 + comfy_extras/nodes_upscale_model.py | 1 + nodes.py | 15 +++++++++++++++ server.py | 2 ++ 7 files changed, 25 insertions(+) diff --git a/comfy_api/latest/_io.py b/comfy_api/latest/_io.py index 4969d3506..a60020ca8 100644 --- a/comfy_api/latest/_io.py +++ b/comfy_api/latest/_io.py @@ -1249,6 +1249,7 @@ class NodeInfoV1: experimental: bool=None api_node: bool=None price_badge: dict | None = None + search_aliases: list[str]=None @dataclass class NodeInfoV3: @@ -1346,6 +1347,8 @@ class Schema: hidden: list[Hidden] = field(default_factory=list) description: str="" """Node description, shown as a tooltip when hovering over the node.""" + search_aliases: list[str] = field(default_factory=list) + """Alternative names for search. Useful for synonyms, abbreviations, or old names after renaming.""" is_input_list: bool = False """A flag indicating if this node implements the additional code necessary to deal with OUTPUT_IS_LIST nodes. @@ -1483,6 +1486,7 @@ class Schema: api_node=self.is_api_node, python_module=getattr(cls, "RELATIVE_PYTHON_MODULE", "nodes"), price_badge=self.price_badge.as_dict(self.inputs) if self.price_badge is not None else None, + search_aliases=self.search_aliases if self.search_aliases else None, ) return info diff --git a/comfy_extras/nodes_post_processing.py b/comfy_extras/nodes_post_processing.py index 2e559c35c..6011275d6 100644 --- a/comfy_extras/nodes_post_processing.py +++ b/comfy_extras/nodes_post_processing.py @@ -550,6 +550,7 @@ class BatchImagesNode(io.ComfyNode): node_id="BatchImagesNode", display_name="Batch Images", category="image", + search_aliases=["batch", "image batch", "batch images", "combine images", "merge images", "stack images"], inputs=[ io.Autogrow.Input("images", template=autogrow_template) ], diff --git a/comfy_extras/nodes_preview_any.py b/comfy_extras/nodes_preview_any.py index 139b07c93..91502ebf2 100644 --- a/comfy_extras/nodes_preview_any.py +++ b/comfy_extras/nodes_preview_any.py @@ -16,6 +16,7 @@ class PreviewAny(): OUTPUT_NODE = True CATEGORY = "utils" + SEARCH_ALIASES = ["preview", "show", "display", "view", "show text", "display text", "preview text", "show output", "inspect", "debug"] def main(self, source=None): value = 'None' diff --git a/comfy_extras/nodes_string.py b/comfy_extras/nodes_string.py index 571d89f62..a2d5f0d94 100644 --- a/comfy_extras/nodes_string.py +++ b/comfy_extras/nodes_string.py @@ -11,6 +11,7 @@ class StringConcatenate(io.ComfyNode): node_id="StringConcatenate", display_name="Concatenate", category="utils/string", + search_aliases=["text concat", "join text", "merge text", "combine strings", "concat", "concatenate", "append text", "combine text", "string"], inputs=[ io.String.Input("string_a", multiline=True), io.String.Input("string_b", multiline=True), diff --git a/comfy_extras/nodes_upscale_model.py b/comfy_extras/nodes_upscale_model.py index ed587851c..97b9e948d 100644 --- a/comfy_extras/nodes_upscale_model.py +++ b/comfy_extras/nodes_upscale_model.py @@ -53,6 +53,7 @@ class ImageUpscaleWithModel(io.ComfyNode): node_id="ImageUpscaleWithModel", display_name="Upscale Image (using Model)", category="image/upscaling", + search_aliases=["upscale", "upscaler", "upsc", "enlarge image", "super resolution", "hires", "superres", "increase resolution"], inputs=[ io.UpscaleModel.Input("upscale_model"), io.Image.Input("image"), diff --git a/nodes.py b/nodes.py index ea5d6e525..67b61dcfe 100644 --- a/nodes.py +++ b/nodes.py @@ -70,6 +70,7 @@ class CLIPTextEncode(ComfyNodeABC): CATEGORY = "conditioning" DESCRIPTION = "Encodes a text prompt using a CLIP model into an embedding that can be used to guide the diffusion model towards generating specific images." + SEARCH_ALIASES = ["text", "prompt", "text prompt", "positive prompt", "negative prompt", "encode text", "text encoder", "encode prompt"] def encode(self, clip, text): if clip is None: @@ -86,6 +87,7 @@ class ConditioningCombine: FUNCTION = "combine" CATEGORY = "conditioning" + SEARCH_ALIASES = ["combine", "merge conditioning", "combine prompts", "merge prompts", "mix prompts", "add prompt"] def combine(self, conditioning_1, conditioning_2): return (conditioning_1 + conditioning_2, ) @@ -294,6 +296,7 @@ class VAEDecode: CATEGORY = "latent" DESCRIPTION = "Decodes latent images back into pixel space images." + SEARCH_ALIASES = ["decode", "decode latent", "latent to image", "render latent"] def decode(self, vae, samples): latent = samples["samples"] @@ -346,6 +349,7 @@ class VAEEncode: FUNCTION = "encode" CATEGORY = "latent" + SEARCH_ALIASES = ["encode", "encode image", "image to latent"] def encode(self, vae, pixels): t = vae.encode(pixels) @@ -581,6 +585,7 @@ class CheckpointLoaderSimple: CATEGORY = "loaders" DESCRIPTION = "Loads a diffusion model checkpoint, diffusion models are used to denoise latents." + SEARCH_ALIASES = ["load model", "checkpoint", "model loader", "load checkpoint", "ckpt", "model"] def load_checkpoint(self, ckpt_name): ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name) @@ -667,6 +672,7 @@ class LoraLoader: CATEGORY = "loaders" DESCRIPTION = "LoRAs are used to modify diffusion and CLIP models, altering the way in which latents are denoised such as applying styles. Multiple LoRA nodes can be linked together." + SEARCH_ALIASES = ["lora", "load lora", "apply lora", "lora loader", "lora model"] def load_lora(self, model, clip, lora_name, strength_model, strength_clip): if strength_model == 0 and strength_clip == 0: @@ -814,6 +820,7 @@ class ControlNetLoader: FUNCTION = "load_controlnet" CATEGORY = "loaders" + SEARCH_ALIASES = ["controlnet", "control net", "cn", "load controlnet", "controlnet loader"] def load_controlnet(self, control_net_name): controlnet_path = folder_paths.get_full_path_or_raise("controlnet", control_net_name) @@ -890,6 +897,7 @@ class ControlNetApplyAdvanced: FUNCTION = "apply_controlnet" CATEGORY = "conditioning/controlnet" + SEARCH_ALIASES = ["controlnet", "apply controlnet", "use controlnet", "control net"] def apply_controlnet(self, positive, negative, control_net, image, strength, start_percent, end_percent, vae=None, extra_concat=[]): if strength == 0: @@ -1200,6 +1208,7 @@ class EmptyLatentImage: CATEGORY = "latent" DESCRIPTION = "Create a new batch of empty latent images to be denoised via sampling." + SEARCH_ALIASES = ["empty", "empty latent", "new latent", "create latent", "blank latent", "blank"] def generate(self, width, height, batch_size=1): latent = torch.zeros([batch_size, 4, height // 8, width // 8], device=self.device) @@ -1540,6 +1549,7 @@ class KSampler: CATEGORY = "sampling" DESCRIPTION = "Uses the provided model, positive and negative conditioning to denoise the latent image." + SEARCH_ALIASES = ["sampler", "sample", "generate", "denoise", "diffuse", "txt2img", "img2img"] def sample(self, model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0): return common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=denoise) @@ -1604,6 +1614,7 @@ class SaveImage: CATEGORY = "image" DESCRIPTION = "Saves the input images to your ComfyUI output directory." + SEARCH_ALIASES = ["save", "save image", "export image", "output image", "write image", "download"] def save_images(self, images, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None): filename_prefix += self.prefix_append @@ -1640,6 +1651,8 @@ class PreviewImage(SaveImage): self.prefix_append = "_temp_" + ''.join(random.choice("abcdefghijklmnopqrstupvxyz") for x in range(5)) self.compress_level = 1 + SEARCH_ALIASES = ["preview", "preview image", "show image", "view image", "display image", "image viewer"] + @classmethod def INPUT_TYPES(s): return {"required": @@ -1658,6 +1671,7 @@ class LoadImage: } CATEGORY = "image" + SEARCH_ALIASES = ["load image", "open image", "import image", "image input", "upload image", "read image", "image loader"] RETURN_TYPES = ("IMAGE", "MASK") FUNCTION = "load_image" @@ -1810,6 +1824,7 @@ class ImageScale: FUNCTION = "upscale" CATEGORY = "image/upscaling" + SEARCH_ALIASES = ["resize", "resize image", "scale image", "image resize", "zoom", "zoom in", "change size"] def upscale(self, image, upscale_method, width, height, crop): if width == 0 and height == 0: diff --git a/server.py b/server.py index 04a577488..1888745b7 100644 --- a/server.py +++ b/server.py @@ -682,6 +682,8 @@ class PromptServer(): if hasattr(obj_class, 'API_NODE'): info['api_node'] = obj_class.API_NODE + + info['search_aliases'] = getattr(obj_class, 'SEARCH_ALIASES', []) return info @routes.get("/object_info") From abe2ec26a61ff670b9c0e71e4821c873368c8728 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 21 Jan 2026 16:44:28 -0800 Subject: [PATCH 114/119] Support the Anima model. (#12012) --- comfy/ldm/anima/model.py | 202 +++++++++++++++++++++++++++++++++++ comfy/model_base.py | 22 ++++ comfy/model_detection.py | 2 + comfy/sd.py | 7 ++ comfy/supported_models.py | 33 +++++- comfy/text_encoders/anima.py | 61 +++++++++++ 6 files changed, 326 insertions(+), 1 deletion(-) create mode 100644 comfy/ldm/anima/model.py create mode 100644 comfy/text_encoders/anima.py diff --git a/comfy/ldm/anima/model.py b/comfy/ldm/anima/model.py new file mode 100644 index 000000000..2e6ed58fa --- /dev/null +++ b/comfy/ldm/anima/model.py @@ -0,0 +1,202 @@ +from comfy.ldm.cosmos.predict2 import MiniTrainDIT +import torch +from torch import nn +import torch.nn.functional as F + + +def rotate_half(x): + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(x, cos, sin, unsqueeze_dim=1): + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + x_embed = (x * cos) + (rotate_half(x) * sin) + return x_embed + + +class RotaryEmbedding(nn.Module): + def __init__(self, head_dim): + super().__init__() + self.rope_theta = 10000 + inv_freq = 1.0 / (self.rope_theta ** (torch.arange(0, head_dim, 2, dtype=torch.int64).to(dtype=torch.float) / head_dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + @torch.no_grad() + def forward(self, x, position_ids): + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) + position_ids_expanded = position_ids[:, None, :].float() + + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): # Force float32 + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +class Attention(nn.Module): + def __init__(self, query_dim, context_dim, n_heads, head_dim, device=None, dtype=None, operations=None): + super().__init__() + + inner_dim = head_dim * n_heads + self.n_heads = n_heads + self.head_dim = head_dim + self.query_dim = query_dim + self.context_dim = context_dim + + self.q_proj = operations.Linear(query_dim, inner_dim, bias=False, device=device, dtype=dtype) + self.q_norm = operations.RMSNorm(self.head_dim, eps=1e-6, device=device, dtype=dtype) + + self.k_proj = operations.Linear(context_dim, inner_dim, bias=False, device=device, dtype=dtype) + self.k_norm = operations.RMSNorm(self.head_dim, eps=1e-6, device=device, dtype=dtype) + + self.v_proj = operations.Linear(context_dim, inner_dim, bias=False, device=device, dtype=dtype) + + self.o_proj = operations.Linear(inner_dim, query_dim, bias=False, device=device, dtype=dtype) + + def forward(self, x, mask=None, context=None, position_embeddings=None, position_embeddings_context=None): + context = x if context is None else context + input_shape = x.shape[:-1] + q_shape = (*input_shape, self.n_heads, self.head_dim) + context_shape = context.shape[:-1] + kv_shape = (*context_shape, self.n_heads, self.head_dim) + + query_states = self.q_norm(self.q_proj(x).view(q_shape)).transpose(1, 2) + key_states = self.k_norm(self.k_proj(context).view(kv_shape)).transpose(1, 2) + value_states = self.v_proj(context).view(kv_shape).transpose(1, 2) + + if position_embeddings is not None: + assert position_embeddings_context is not None + cos, sin = position_embeddings + query_states = apply_rotary_pos_emb(query_states, cos, sin) + cos, sin = position_embeddings_context + key_states = apply_rotary_pos_emb(key_states, cos, sin) + + attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states, attn_mask=mask) + + attn_output = attn_output.transpose(1, 2).reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output + + def init_weights(self): + torch.nn.init.zeros_(self.o_proj.weight) + + +class TransformerBlock(nn.Module): + def __init__(self, source_dim, model_dim, num_heads=16, mlp_ratio=4.0, use_self_attn=False, layer_norm=False, device=None, dtype=None, operations=None): + super().__init__() + self.use_self_attn = use_self_attn + + if self.use_self_attn: + self.norm_self_attn = operations.LayerNorm(model_dim, device=device, dtype=dtype) if layer_norm else operations.RMSNorm(model_dim, eps=1e-6, device=device, dtype=dtype) + self.self_attn = Attention( + query_dim=model_dim, + context_dim=model_dim, + n_heads=num_heads, + head_dim=model_dim//num_heads, + device=device, + dtype=dtype, + operations=operations, + ) + + self.norm_cross_attn = operations.LayerNorm(model_dim, device=device, dtype=dtype) if layer_norm else operations.RMSNorm(model_dim, eps=1e-6, device=device, dtype=dtype) + self.cross_attn = Attention( + query_dim=model_dim, + context_dim=source_dim, + n_heads=num_heads, + head_dim=model_dim//num_heads, + device=device, + dtype=dtype, + operations=operations, + ) + + self.norm_mlp = operations.LayerNorm(model_dim, device=device, dtype=dtype) if layer_norm else operations.RMSNorm(model_dim, eps=1e-6, device=device, dtype=dtype) + self.mlp = nn.Sequential( + operations.Linear(model_dim, int(model_dim * mlp_ratio), device=device, dtype=dtype), + nn.GELU(), + operations.Linear(int(model_dim * mlp_ratio), model_dim, device=device, dtype=dtype) + ) + + def forward(self, x, context, target_attention_mask=None, source_attention_mask=None, position_embeddings=None, position_embeddings_context=None): + if self.use_self_attn: + normed = self.norm_self_attn(x) + attn_out = self.self_attn(normed, mask=target_attention_mask, position_embeddings=position_embeddings, position_embeddings_context=position_embeddings) + x = x + attn_out + + normed = self.norm_cross_attn(x) + attn_out = self.cross_attn(normed, mask=source_attention_mask, context=context, position_embeddings=position_embeddings, position_embeddings_context=position_embeddings_context) + x = x + attn_out + + x = x + self.mlp(self.norm_mlp(x)) + return x + + def init_weights(self): + torch.nn.init.zeros_(self.mlp[2].weight) + self.cross_attn.init_weights() + + +class LLMAdapter(nn.Module): + def __init__( + self, + source_dim=1024, + target_dim=1024, + model_dim=1024, + num_layers=6, + num_heads=16, + use_self_attn=True, + layer_norm=False, + device=None, + dtype=None, + operations=None, + ): + super().__init__() + + self.embed = operations.Embedding(32128, target_dim, device=device, dtype=dtype) + if model_dim != target_dim: + self.in_proj = operations.Linear(target_dim, model_dim, device=device, dtype=dtype) + else: + self.in_proj = nn.Identity() + self.rotary_emb = RotaryEmbedding(model_dim//num_heads) + self.blocks = nn.ModuleList([ + TransformerBlock(source_dim, model_dim, num_heads=num_heads, use_self_attn=use_self_attn, layer_norm=layer_norm, device=device, dtype=dtype, operations=operations) for _ in range(num_layers) + ]) + self.out_proj = operations.Linear(model_dim, target_dim, device=device, dtype=dtype) + self.norm = operations.RMSNorm(target_dim, eps=1e-6, device=device, dtype=dtype) + + def forward(self, source_hidden_states, target_input_ids, target_attention_mask=None, source_attention_mask=None): + if target_attention_mask is not None: + target_attention_mask = target_attention_mask.to(torch.bool) + if target_attention_mask.ndim == 2: + target_attention_mask = target_attention_mask.unsqueeze(1).unsqueeze(1) + + if source_attention_mask is not None: + source_attention_mask = source_attention_mask.to(torch.bool) + if source_attention_mask.ndim == 2: + source_attention_mask = source_attention_mask.unsqueeze(1).unsqueeze(1) + + x = self.in_proj(self.embed(target_input_ids)) + context = source_hidden_states + position_ids = torch.arange(x.shape[1], device=x.device).unsqueeze(0) + position_ids_context = torch.arange(context.shape[1], device=x.device).unsqueeze(0) + position_embeddings = self.rotary_emb(x, position_ids) + position_embeddings_context = self.rotary_emb(x, position_ids_context) + for block in self.blocks: + x = block(x, context, target_attention_mask=target_attention_mask, source_attention_mask=source_attention_mask, position_embeddings=position_embeddings, position_embeddings_context=position_embeddings_context) + return self.norm(self.out_proj(x)) + + +class Anima(MiniTrainDIT): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.llm_adapter = LLMAdapter(device=kwargs.get("device"), dtype=kwargs.get("dtype"), operations=kwargs.get("operations")) + + def preprocess_text_embeds(self, text_embeds, text_ids): + if text_ids is not None: + return self.llm_adapter(text_embeds, text_ids) + else: + return text_embeds diff --git a/comfy/model_base.py b/comfy/model_base.py index 28ba2643e..1d57562cc 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -49,6 +49,7 @@ import comfy.ldm.ace.model import comfy.ldm.omnigen.omnigen2 import comfy.ldm.qwen_image.model import comfy.ldm.kandinsky5.model +import comfy.ldm.anima.model import comfy.model_management import comfy.patcher_extension @@ -1147,6 +1148,27 @@ class CosmosPredict2(BaseModel): sigma = (sigma / (sigma + 1)) return latent_image / (1.0 - sigma) +class Anima(BaseModel): + def __init__(self, model_config, model_type=ModelType.FLOW, device=None): + super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.anima.model.Anima) + + def extra_conds(self, **kwargs): + out = super().extra_conds(**kwargs) + cross_attn = kwargs.get("cross_attn", None) + t5xxl_ids = kwargs.get("t5xxl_ids", None) + t5xxl_weights = kwargs.get("t5xxl_weights", None) + device = kwargs["device"] + if cross_attn is not None: + if t5xxl_ids is not None: + cross_attn = self.diffusion_model.preprocess_text_embeds(cross_attn.to(device=device, dtype=self.get_dtype()), t5xxl_ids.unsqueeze(0).to(device=device)) + if t5xxl_weights is not None: + cross_attn *= t5xxl_weights.unsqueeze(0).unsqueeze(-1).to(cross_attn) + + if cross_attn.shape[1] < 512: + cross_attn = torch.nn.functional.pad(cross_attn, (0, 0, 0, 512 - cross_attn.shape[1])) + out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) + return out + class Lumina2(BaseModel): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.lumina.model.NextDiT) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index dad206a2f..b29a033cc 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -550,6 +550,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): if '{}blocks.0.mlp.layer1.weight'.format(key_prefix) in state_dict_keys: # Cosmos predict2 dit_config = {} dit_config["image_model"] = "cosmos_predict2" + if "{}llm_adapter.blocks.0.cross_attn.q_proj.weight".format(key_prefix) in state_dict_keys: + dit_config["image_model"] = "anima" dit_config["max_img_h"] = 240 dit_config["max_img_w"] = 240 dit_config["max_frames"] = 128 diff --git a/comfy/sd.py b/comfy/sd.py index 77700dfd3..f7f6a44a0 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -57,6 +57,7 @@ import comfy.text_encoders.ovis import comfy.text_encoders.kandinsky5 import comfy.text_encoders.jina_clip_2 import comfy.text_encoders.newbie +import comfy.text_encoders.anima import comfy.model_patcher import comfy.lora @@ -1048,6 +1049,7 @@ class TEModel(Enum): GEMMA_3_12B = 18 JINA_CLIP_2 = 19 QWEN3_8B = 20 + QWEN3_06B = 21 def detect_te_model(sd): @@ -1093,6 +1095,8 @@ def detect_te_model(sd): return TEModel.QWEN3_2B elif weight.shape[0] == 4096: return TEModel.QWEN3_8B + elif weight.shape[0] == 1024: + return TEModel.QWEN3_06B if weight.shape[0] == 5120: if "model.layers.39.post_attention_layernorm.weight" in sd: return TEModel.MISTRAL3_24B @@ -1233,6 +1237,9 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip elif te_model == TEModel.JINA_CLIP_2: clip_target.clip = comfy.text_encoders.jina_clip_2.JinaClip2TextModelWrapper clip_target.tokenizer = comfy.text_encoders.jina_clip_2.JinaClip2TokenizerWrapper + elif te_model == TEModel.QWEN3_06B: + clip_target.clip = comfy.text_encoders.anima.te(**llama_detect(clip_data)) + clip_target.tokenizer = comfy.text_encoders.anima.AnimaTokenizer else: # clip_l if clip_type == CLIPType.SD3: diff --git a/comfy/supported_models.py b/comfy/supported_models.py index c8a7f6efb..70abebf46 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -23,6 +23,7 @@ import comfy.text_encoders.qwen_image import comfy.text_encoders.hunyuan_image import comfy.text_encoders.kandinsky5 import comfy.text_encoders.z_image +import comfy.text_encoders.anima from . import supported_models_base from . import latent_formats @@ -992,6 +993,36 @@ class CosmosT2IPredict2(supported_models_base.BASE): t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref)) return supported_models_base.ClipTarget(comfy.text_encoders.cosmos.CosmosT5Tokenizer, comfy.text_encoders.cosmos.te(**t5_detect)) +class Anima(supported_models_base.BASE): + unet_config = { + "image_model": "anima", + } + + sampling_settings = { + "multiplier": 1.0, + "shift": 3.0, + } + + unet_extra_config = {} + latent_format = latent_formats.Wan21 + + memory_usage_factor = 1.0 + + supported_inference_dtypes = [torch.bfloat16, torch.float32] + + def __init__(self, unet_config): + super().__init__(unet_config) + self.memory_usage_factor = (unet_config.get("model_channels", 2048) / 2048) * 0.95 + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.Anima(self, device=device) + return out + + def clip_target(self, state_dict={}): + pref = self.text_encoder_key_prefix[0] + detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen3_06b.transformer.".format(pref)) + return supported_models_base.ClipTarget(comfy.text_encoders.anima.AnimaTokenizer, comfy.text_encoders.anima.te(**detect)) + class CosmosI2VPredict2(CosmosT2IPredict2): unet_config = { "image_model": "cosmos_predict2", @@ -1551,6 +1582,6 @@ class Kandinsky5Image(Kandinsky5): return supported_models_base.ClipTarget(comfy.text_encoders.kandinsky5.Kandinsky5TokenizerImage, comfy.text_encoders.kandinsky5.te(**hunyuan_detect)) -models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5] +models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima] models += [SVD_img2vid] diff --git a/comfy/text_encoders/anima.py b/comfy/text_encoders/anima.py new file mode 100644 index 000000000..41f95bcb6 --- /dev/null +++ b/comfy/text_encoders/anima.py @@ -0,0 +1,61 @@ +from transformers import Qwen2Tokenizer, T5TokenizerFast +import comfy.text_encoders.llama +from comfy import sd1_clip +import os +import torch + + +class Qwen3Tokenizer(sd1_clip.SDTokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "qwen25_tokenizer") + super().__init__(tokenizer_path, pad_with_end=False, embedding_size=1024, embedding_key='qwen3_06b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, pad_token=151643, tokenizer_data=tokenizer_data) + +class T5XXLTokenizer(sd1_clip.SDTokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer") + super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_data=tokenizer_data) + +class AnimaTokenizer: + def __init__(self, embedding_directory=None, tokenizer_data={}): + self.qwen3_06b = Qwen3Tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data) + self.t5xxl = T5XXLTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data) + + def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs): + out = {} + qwen_ids = self.qwen3_06b.tokenize_with_weights(text, return_word_ids, **kwargs) + out["qwen3_06b"] = [[(token, 1.0) for token, _ in inner_list] for inner_list in qwen_ids] # Set weights to 1.0 + out["t5xxl"] = self.t5xxl.tokenize_with_weights(text, return_word_ids, **kwargs) + return out + + def untokenize(self, token_weight_pair): + return self.t5xxl.untokenize(token_weight_pair) + + def state_dict(self): + return {} + + +class Qwen3_06BModel(sd1_clip.SDClipModel): + def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options={}): + super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Qwen3_06B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options) + + +class AnimaTEModel(sd1_clip.SD1ClipModel): + def __init__(self, device="cpu", dtype=None, model_options={}): + super().__init__(device=device, dtype=dtype, name="qwen3_06b", clip_model=Qwen3_06BModel, model_options=model_options) + + def encode_token_weights(self, token_weight_pairs): + out = super().encode_token_weights(token_weight_pairs) + out[2]["t5xxl_ids"] = torch.tensor(list(map(lambda a: a[0], token_weight_pairs["t5xxl"][0])), dtype=torch.int) + out[2]["t5xxl_weights"] = torch.tensor(list(map(lambda a: a[1], token_weight_pairs["t5xxl"][0]))) + return out + +def te(dtype_llama=None, llama_quantization_metadata=None): + class AnimaTEModel_(AnimaTEModel): + def __init__(self, device="cpu", dtype=None, model_options={}): + if dtype_llama is not None: + dtype = dtype_llama + if llama_quantization_metadata is not None: + model_options = model_options.copy() + model_options["quantization_metadata"] = llama_quantization_metadata + super().__init__(device=device, dtype=dtype, model_options=model_options) + return AnimaTEModel_ From f09904720dc8b56bc6823ebdaf5de69465448e46 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Wed, 21 Jan 2026 20:01:35 -0800 Subject: [PATCH 115/119] Fix for edge case of EasyCache when conditionings change during a sampling run (like with timestep scheduling) (#12020) --- comfy_extras/nodes_easycache.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/comfy_extras/nodes_easycache.py b/comfy_extras/nodes_easycache.py index 11b23ffdb..90d730df6 100644 --- a/comfy_extras/nodes_easycache.py +++ b/comfy_extras/nodes_easycache.py @@ -29,8 +29,10 @@ def easycache_forward_wrapper(executor, *args, **kwargs): do_easycache = easycache.should_do_easycache(sigmas) if do_easycache: easycache.check_metadata(x) + # if there isn't a cache diff for current conds, we cannot skip this step + can_apply_cache_diff = easycache.can_apply_cache_diff(uuids) # if first cond marked this step for skipping, skip it and use appropriate cached values - if easycache.skip_current_step: + if easycache.skip_current_step and can_apply_cache_diff: if easycache.verbose: logging.info(f"EasyCache [verbose] - was marked to skip this step by {easycache.first_cond_uuid}. Present uuids: {uuids}") return easycache.apply_cache_diff(x, uuids) @@ -44,7 +46,7 @@ def easycache_forward_wrapper(executor, *args, **kwargs): if easycache.has_output_prev_norm() and easycache.has_relative_transformation_rate(): approx_output_change_rate = (easycache.relative_transformation_rate * input_change) / easycache.output_prev_norm easycache.cumulative_change_rate += approx_output_change_rate - if easycache.cumulative_change_rate < easycache.reuse_threshold: + if easycache.cumulative_change_rate < easycache.reuse_threshold and can_apply_cache_diff: if easycache.verbose: logging.info(f"EasyCache [verbose] - skipping step; cumulative_change_rate: {easycache.cumulative_change_rate}, reuse_threshold: {easycache.reuse_threshold}") # other conds should also skip this step, and instead use their cached values @@ -240,6 +242,9 @@ class EasyCacheHolder: return to_return.clone() return to_return + def can_apply_cache_diff(self, uuids: list[UUID]) -> bool: + return all(uuid in self.uuid_cache_diffs for uuid in uuids) + def apply_cache_diff(self, x: torch.Tensor, uuids: list[UUID]): if self.first_cond_uuid in uuids: self.total_steps_skipped += 1 From 3365ad18a5e0c86b23c6272e5adcedd333fc45cb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jukka=20Sepp=C3=A4nen?= <40791699+kijai@users.noreply.github.com> Date: Thu, 22 Jan 2026 06:03:51 +0200 Subject: [PATCH 116/119] Support LTX2 tiny vae (taeltx_2) (#11929) --- comfy/sd.py | 5 ++--- comfy/taesd/taehv.py | 53 ++++++++++++++++++++++++++++---------------- latent_preview.py | 2 +- nodes.py | 2 +- 4 files changed, 38 insertions(+), 24 deletions(-) diff --git a/comfy/sd.py b/comfy/sd.py index f7f6a44a0..ce7e6bcff 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -636,14 +636,13 @@ class VAE: self.upscale_index_formula = (4, 16, 16) self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 16, 16) self.downscale_index_formula = (4, 16, 16) - if self.latent_channels == 48: # Wan 2.2 + if self.latent_channels in [48, 128]: # Wan 2.2 and LTX2 self.first_stage_model = comfy.taesd.taehv.TAEHV(latent_channels=self.latent_channels, latent_format=None) # taehv doesn't need scaling - self.process_input = lambda image: (_ for _ in ()).throw(NotImplementedError("This light tae doesn't support encoding currently")) + self.process_input = self.process_output = lambda image: image self.process_output = lambda image: image self.memory_used_decode = lambda shape, dtype: (1800 * (max(1, (shape[-3] ** 0.7 * 0.1)) * shape[-2] * shape[-1] * 16 * 16) * model_management.dtype_size(dtype)) elif self.latent_channels == 32 and sd["decoder.22.bias"].shape[0] == 12: # lighttae_hv15 self.first_stage_model = comfy.taesd.taehv.TAEHV(latent_channels=self.latent_channels, latent_format=comfy.latent_formats.HunyuanVideo15) - self.process_input = lambda image: (_ for _ in ()).throw(NotImplementedError("This light tae doesn't support encoding currently")) self.memory_used_decode = lambda shape, dtype: (1200 * (max(1, (shape[-3] ** 0.7 * 0.05)) * shape[-2] * shape[-1] * 32 * 32) * model_management.dtype_size(dtype)) else: if sd["decoder.1.weight"].dtype == torch.float16: # taehv currently only available in float16, so assume it's not lighttaew2_1 as otherwise state dicts are identical diff --git a/comfy/taesd/taehv.py b/comfy/taesd/taehv.py index 0e5f9a378..6c06ce19d 100644 --- a/comfy/taesd/taehv.py +++ b/comfy/taesd/taehv.py @@ -112,7 +112,8 @@ def apply_model_with_memblocks(model, x, parallel, show_progress_bar): class TAEHV(nn.Module): - def __init__(self, latent_channels, parallel=False, decoder_time_upscale=(True, True), decoder_space_upscale=(True, True, True), latent_format=None, show_progress_bar=True): + def __init__(self, latent_channels, parallel=False, encoder_time_downscale=(True, True, False), decoder_time_upscale=(False, True, True), decoder_space_upscale=(True, True, True), + latent_format=None, show_progress_bar=False): super().__init__() self.image_channels = 3 self.patch_size = 1 @@ -124,6 +125,9 @@ class TAEHV(nn.Module): self.process_out = latent_format().process_out if latent_format is not None else (lambda x: x) if self.latent_channels in [48, 32]: # Wan 2.2 and HunyuanVideo1.5 self.patch_size = 2 + elif self.latent_channels == 128: # LTX2 + self.patch_size, self.latent_channels, encoder_time_downscale, decoder_time_upscale = 4, 128, (True, True, True), (True, True, True) + if self.latent_channels == 32: # HunyuanVideo1.5 act_func = nn.LeakyReLU(0.2, inplace=True) else: # HunyuanVideo, Wan 2.1 @@ -131,41 +135,52 @@ class TAEHV(nn.Module): self.encoder = nn.Sequential( conv(self.image_channels*self.patch_size**2, 64), act_func, - TPool(64, 2), conv(64, 64, stride=2, bias=False), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), - TPool(64, 2), conv(64, 64, stride=2, bias=False), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), - TPool(64, 1), conv(64, 64, stride=2, bias=False), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), + TPool(64, 2 if encoder_time_downscale[0] else 1), conv(64, 64, stride=2, bias=False), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), + TPool(64, 2 if encoder_time_downscale[1] else 1), conv(64, 64, stride=2, bias=False), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), + TPool(64, 2 if encoder_time_downscale[2] else 1), conv(64, 64, stride=2, bias=False), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), conv(64, self.latent_channels), ) n_f = [256, 128, 64, 64] - self.frames_to_trim = 2**sum(decoder_time_upscale) - 1 + self.decoder = nn.Sequential( Clamp(), conv(self.latent_channels, n_f[0]), act_func, - MemBlock(n_f[0], n_f[0], act_func), MemBlock(n_f[0], n_f[0], act_func), MemBlock(n_f[0], n_f[0], act_func), nn.Upsample(scale_factor=2 if decoder_space_upscale[0] else 1), TGrow(n_f[0], 1), conv(n_f[0], n_f[1], bias=False), - MemBlock(n_f[1], n_f[1], act_func), MemBlock(n_f[1], n_f[1], act_func), MemBlock(n_f[1], n_f[1], act_func), nn.Upsample(scale_factor=2 if decoder_space_upscale[1] else 1), TGrow(n_f[1], 2 if decoder_time_upscale[0] else 1), conv(n_f[1], n_f[2], bias=False), - MemBlock(n_f[2], n_f[2], act_func), MemBlock(n_f[2], n_f[2], act_func), MemBlock(n_f[2], n_f[2], act_func), nn.Upsample(scale_factor=2 if decoder_space_upscale[2] else 1), TGrow(n_f[2], 2 if decoder_time_upscale[1] else 1), conv(n_f[2], n_f[3], bias=False), + MemBlock(n_f[0], n_f[0], act_func), MemBlock(n_f[0], n_f[0], act_func), MemBlock(n_f[0], n_f[0], act_func), nn.Upsample(scale_factor=2 if decoder_space_upscale[0] else 1), TGrow(n_f[0], 2 if decoder_time_upscale[0] else 1), conv(n_f[0], n_f[1], bias=False), + MemBlock(n_f[1], n_f[1], act_func), MemBlock(n_f[1], n_f[1], act_func), MemBlock(n_f[1], n_f[1], act_func), nn.Upsample(scale_factor=2 if decoder_space_upscale[1] else 1), TGrow(n_f[1], 2 if decoder_time_upscale[1] else 1), conv(n_f[1], n_f[2], bias=False), + MemBlock(n_f[2], n_f[2], act_func), MemBlock(n_f[2], n_f[2], act_func), MemBlock(n_f[2], n_f[2], act_func), nn.Upsample(scale_factor=2 if decoder_space_upscale[2] else 1), TGrow(n_f[2], 2 if decoder_time_upscale[2] else 1), conv(n_f[2], n_f[3], bias=False), act_func, conv(n_f[3], self.image_channels*self.patch_size**2), ) - @property - def show_progress_bar(self): - return self._show_progress_bar - @show_progress_bar.setter - def show_progress_bar(self, value): - self._show_progress_bar = value + self.t_downscale = 2**sum(t.stride == 2 for t in self.encoder if isinstance(t, TPool)) + self.t_upscale = 2**sum(t.stride == 2 for t in self.decoder if isinstance(t, TGrow)) + self.frames_to_trim = self.t_upscale - 1 + self._show_progress_bar = show_progress_bar + + @property + def show_progress_bar(self): + return self._show_progress_bar + + @show_progress_bar.setter + def show_progress_bar(self, value): + self._show_progress_bar = value def encode(self, x, **kwargs): - if self.patch_size > 1: - x = F.pixel_unshuffle(x, self.patch_size) x = x.movedim(2, 1) # [B, C, T, H, W] -> [B, T, C, H, W] - if x.shape[1] % 4 != 0: - # pad at end to multiple of 4 - n_pad = 4 - x.shape[1] % 4 + if self.patch_size > 1: + B, T, C, H, W = x.shape + x = x.reshape(B * T, C, H, W) + x = F.pixel_unshuffle(x, self.patch_size) + x = x.reshape(B, T, C * self.patch_size ** 2, H // self.patch_size, W // self.patch_size) + if x.shape[1] % self.t_downscale != 0: + # pad at end to multiple of t_downscale + n_pad = self.t_downscale - x.shape[1] % self.t_downscale padding = x[:, -1:].repeat_interleave(n_pad, dim=1) x = torch.cat([x, padding], 1) x = apply_model_with_memblocks(self.encoder, x, self.parallel, self.show_progress_bar).movedim(2, 1) return self.process_out(x) def decode(self, x, **kwargs): + x = x.unsqueeze(0) if x.ndim == 4 else x # [T, C, H, W] -> [1, T, C, H, W] + x = x.movedim(1, 2) if x.shape[1] != self.latent_channels else x # [B, T, C, H, W] or [B, C, T, H, W] x = self.process_in(x).movedim(2, 1) # [B, C, T, H, W] -> [B, T, C, H, W] x = apply_model_with_memblocks(self.decoder, x, self.parallel, self.show_progress_bar) if self.patch_size > 1: diff --git a/latent_preview.py b/latent_preview.py index d52e3f7a1..a9d777661 100644 --- a/latent_preview.py +++ b/latent_preview.py @@ -11,7 +11,7 @@ import logging default_preview_method = args.preview_method MAX_PREVIEW_RESOLUTION = args.preview_size -VIDEO_TAES = ["taehv", "lighttaew2_2", "lighttaew2_1", "lighttaehy1_5"] +VIDEO_TAES = ["taehv", "lighttaew2_2", "lighttaew2_1", "lighttaehy1_5", "taeltx_2"] def preview_to_image(latent_image, do_scale=True): if do_scale: diff --git a/nodes.py b/nodes.py index 67b61dcfe..8864fda60 100644 --- a/nodes.py +++ b/nodes.py @@ -707,7 +707,7 @@ class LoraLoaderModelOnly(LoraLoader): return (self.load_lora(model, None, lora_name, strength_model, 0)[0],) class VAELoader: - video_taes = ["taehv", "lighttaew2_2", "lighttaew2_1", "lighttaehy1_5"] + video_taes = ["taehv", "lighttaew2_2", "lighttaew2_1", "lighttaehy1_5", "taeltx_2"] image_taes = ["taesd", "taesdxl", "taesd3", "taef1"] @staticmethod def vae_list(s): From 245f6139b65899112d11ff294d36a820f2d69496 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jukka=20Sepp=C3=A4nen?= <40791699+kijai@users.noreply.github.com> Date: Thu, 22 Jan 2026 06:05:06 +0200 Subject: [PATCH 117/119] More targeted embedding_connector loading for LTX2 text encoder (#11992) Reduces errors --- comfy/text_encoders/lt.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/comfy/text_encoders/lt.py b/comfy/text_encoders/lt.py index c33c77db7..e49161964 100644 --- a/comfy/text_encoders/lt.py +++ b/comfy/text_encoders/lt.py @@ -118,9 +118,18 @@ class LTXAVTEModel(torch.nn.Module): sdo = comfy.utils.state_dict_prefix_replace(sd, {"text_embedding_projection.aggregate_embed.weight": "text_embedding_projection.weight", "model.diffusion_model.video_embeddings_connector.": "video_embeddings_connector.", "model.diffusion_model.audio_embeddings_connector.": "audio_embeddings_connector."}, filter_keys=True) if len(sdo) == 0: sdo = sd - missing, unexpected = self.load_state_dict(sdo, strict=False) - missing = [k for k in missing if not k.startswith("gemma3_12b.")] # filter out keys that belong to the main gemma model - return (missing, unexpected) + + missing_all = [] + unexpected_all = [] + + for prefix, component in [("text_embedding_projection.", self.text_embedding_projection), ("video_embeddings_connector.", self.video_embeddings_connector), ("audio_embeddings_connector.", self.audio_embeddings_connector)]: + component_sd = {k.replace(prefix, ""): v for k, v in sdo.items() if k.startswith(prefix)} + if component_sd: + missing, unexpected = component.load_state_dict(component_sd, strict=False) + missing_all.extend([f"{prefix}{k}" for k in missing]) + unexpected_all.extend([f"{prefix}{k}" for k in unexpected]) + + return (missing_all, unexpected_all) def memory_estimation_function(self, token_weight_pairs, device=None): constant = 6.0 From 16b9aabd52c3b81b365fbf562bbcc4528111ef6b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jukka=20Sepp=C3=A4nen?= <40791699+kijai@users.noreply.github.com> Date: Thu, 22 Jan 2026 06:09:48 +0200 Subject: [PATCH 118/119] Support Multi/InfiniteTalk (#10179) * re-init * Update model_multitalk.py * whitespace... * Update model_multitalk.py * remove print * this is redundant * remove import * Restore preview functionality * Move block_idx to transformer_options * Remove LoopingSamplerCustomAdvanced * Remove looping functionality, keep extension functionality * Update model_multitalk.py * Handle ref_attn_mask with separate patch to avoid having to always return q and k from self_attn * Chunk attention map calculation for multiple speakers to reduce peak VRAM usage * Update model_multitalk.py * Add ModelPatch type back * Fix for latest upstream * Use DynamicCombo for cleaner node Basically just so that single_speaker mode hides mask inputs and 2nd audio input * Update nodes_wan.py --- comfy/ldm/wan/model.py | 17 +- comfy/ldm/wan/model_multitalk.py | 500 ++++++++++++++++++++++++++++++ comfy_api/latest/_io.py | 3 +- comfy_extras/nodes_model_patch.py | 41 +++ comfy_extras/nodes_wan.py | 169 +++++++++- 5 files changed, 727 insertions(+), 3 deletions(-) create mode 100644 comfy/ldm/wan/model_multitalk.py diff --git a/comfy/ldm/wan/model.py b/comfy/ldm/wan/model.py index 4216ce831..ea123acb4 100644 --- a/comfy/ldm/wan/model.py +++ b/comfy/ldm/wan/model.py @@ -62,6 +62,8 @@ class WanSelfAttention(nn.Module): x(Tensor): Shape [B, L, num_heads, C / num_heads] freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2] """ + patches = transformer_options.get("patches", {}) + b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim def qkv_fn_q(x): @@ -86,6 +88,10 @@ class WanSelfAttention(nn.Module): transformer_options=transformer_options, ) + if "attn1_patch" in patches: + for p in patches["attn1_patch"]: + x = p({"x": x, "q": q, "k": k, "transformer_options": transformer_options}) + x = self.o(x) return x @@ -225,6 +231,8 @@ class WanAttentionBlock(nn.Module): """ # assert e.dtype == torch.float32 + patches = transformer_options.get("patches", {}) + if e.ndim < 4: e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device) + e).chunk(6, dim=1) else: @@ -242,6 +250,11 @@ class WanAttentionBlock(nn.Module): # cross-attention & ffn x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len, transformer_options=transformer_options) + + if "attn2_patch" in patches: + for p in patches["attn2_patch"]: + x = p({"x": x, "transformer_options": transformer_options}) + y = self.ffn(torch.addcmul(repeat_e(e[3], x), self.norm2(x), 1 + repeat_e(e[4], x))) x = torch.addcmul(x, y, repeat_e(e[5], x)) return x @@ -488,7 +501,7 @@ class WanModel(torch.nn.Module): self.blocks = nn.ModuleList([ wan_attn_block_class(cross_attn_type, dim, ffn_dim, num_heads, window_size, qk_norm, cross_attn_norm, eps, operation_settings=operation_settings) - for _ in range(num_layers) + for i in range(num_layers) ]) # head @@ -541,6 +554,7 @@ class WanModel(torch.nn.Module): # embeddings x = self.patch_embedding(x.float()).to(x.dtype) grid_sizes = x.shape[2:] + transformer_options["grid_sizes"] = grid_sizes x = x.flatten(2).transpose(1, 2) # time embeddings @@ -738,6 +752,7 @@ class VaceWanModel(WanModel): # embeddings x = self.patch_embedding(x.float()).to(x.dtype) grid_sizes = x.shape[2:] + transformer_options["grid_sizes"] = grid_sizes x = x.flatten(2).transpose(1, 2) # time embeddings diff --git a/comfy/ldm/wan/model_multitalk.py b/comfy/ldm/wan/model_multitalk.py new file mode 100644 index 000000000..c9dd98c4d --- /dev/null +++ b/comfy/ldm/wan/model_multitalk.py @@ -0,0 +1,500 @@ +import torch +from einops import rearrange, repeat +import comfy +from comfy.ldm.modules.attention import optimized_attention + + +def calculate_x_ref_attn_map(visual_q, ref_k, ref_target_masks, split_num=8): + scale = 1.0 / visual_q.shape[-1] ** 0.5 + visual_q = visual_q.transpose(1, 2) * scale + + B, H, x_seqlens, K = visual_q.shape + + x_ref_attn_maps = [] + for class_idx, ref_target_mask in enumerate(ref_target_masks): + ref_target_mask = ref_target_mask.view(1, 1, 1, -1) + + x_ref_attnmap = torch.zeros(B, H, x_seqlens, device=visual_q.device, dtype=visual_q.dtype) + chunk_size = min(max(x_seqlens // split_num, 1), x_seqlens) + + for i in range(0, x_seqlens, chunk_size): + end_i = min(i + chunk_size, x_seqlens) + + attn_chunk = visual_q[:, :, i:end_i] @ ref_k.permute(0, 2, 3, 1) # B, H, chunk, ref_seqlens + + # Apply softmax + attn_max = attn_chunk.max(dim=-1, keepdim=True).values + attn_chunk = (attn_chunk - attn_max).exp() + attn_sum = attn_chunk.sum(dim=-1, keepdim=True) + attn_chunk = attn_chunk / (attn_sum + 1e-8) + + # Apply mask and sum + masked_attn = attn_chunk * ref_target_mask + x_ref_attnmap[:, :, i:end_i] = masked_attn.sum(-1) / (ref_target_mask.sum() + 1e-8) + + del attn_chunk, masked_attn + + # Average across heads + x_ref_attnmap = x_ref_attnmap.mean(dim=1) # B, x_seqlens + x_ref_attn_maps.append(x_ref_attnmap) + + del visual_q, ref_k + + return torch.cat(x_ref_attn_maps, dim=0) + +def get_attn_map_with_target(visual_q, ref_k, shape, ref_target_masks=None, split_num=2): + """Args: + query (torch.tensor): B M H K + key (torch.tensor): B M H K + shape (tuple): (N_t, N_h, N_w) + ref_target_masks: [B, N_h * N_w] + """ + + N_t, N_h, N_w = shape + + x_seqlens = N_h * N_w + ref_k = ref_k[:, :x_seqlens] + _, seq_lens, heads, _ = visual_q.shape + class_num, _ = ref_target_masks.shape + x_ref_attn_maps = torch.zeros(class_num, seq_lens).to(visual_q) + + split_chunk = heads // split_num + + for i in range(split_num): + x_ref_attn_maps_perhead = calculate_x_ref_attn_map( + visual_q[:, :, i*split_chunk:(i+1)*split_chunk, :], + ref_k[:, :, i*split_chunk:(i+1)*split_chunk, :], + ref_target_masks + ) + x_ref_attn_maps += x_ref_attn_maps_perhead + + return x_ref_attn_maps / split_num + + +def normalize_and_scale(column, source_range, target_range, epsilon=1e-8): + source_min, source_max = source_range + new_min, new_max = target_range + normalized = (column - source_min) / (source_max - source_min + epsilon) + scaled = normalized * (new_max - new_min) + new_min + return scaled + + +def rotate_half(x): + x = rearrange(x, "... (d r) -> ... d r", r=2) + x1, x2 = x.unbind(dim=-1) + x = torch.stack((-x2, x1), dim=-1) + return rearrange(x, "... d r -> ... (d r)") + + +def get_audio_embeds(encoded_audio, audio_start, audio_end): + audio_embs = [] + human_num = len(encoded_audio) + audio_frames = encoded_audio[0].shape[0] + + indices = (torch.arange(4 + 1) - 2) * 1 + + for human_idx in range(human_num): + if audio_end > audio_frames: # in case of not enough audio for current window, pad with first audio frame as that's most likely silence + pad_len = audio_end - audio_frames + pad_shape = list(encoded_audio[human_idx].shape) + pad_shape[0] = pad_len + pad_tensor = encoded_audio[human_idx][:1].repeat(pad_len, *([1] * (encoded_audio[human_idx].dim() - 1))) + encoded_audio_in = torch.cat([encoded_audio[human_idx], pad_tensor], dim=0) + else: + encoded_audio_in = encoded_audio[human_idx] + center_indices = torch.arange(audio_start, audio_end, 1).unsqueeze(1) + indices.unsqueeze(0) + center_indices = torch.clamp(center_indices, min=0, max=encoded_audio_in.shape[0] - 1) + audio_emb = encoded_audio_in[center_indices].unsqueeze(0) + audio_embs.append(audio_emb) + + return torch.cat(audio_embs, dim=0) + + +def project_audio_features(audio_proj, encoded_audio, audio_start, audio_end): + audio_embs = get_audio_embeds(encoded_audio, audio_start, audio_end) + + first_frame_audio_emb_s = audio_embs[:, :1, ...] + latter_frame_audio_emb = audio_embs[:, 1:, ...] + latter_frame_audio_emb = rearrange(latter_frame_audio_emb, "b (n_t n) w s c -> b n_t n w s c", n=4) + + middle_index = audio_proj.seq_len // 2 + + latter_first_frame_audio_emb = latter_frame_audio_emb[:, :, :1, :middle_index+1, ...] + latter_first_frame_audio_emb = rearrange(latter_first_frame_audio_emb, "b n_t n w s c -> b n_t (n w) s c") + latter_last_frame_audio_emb = latter_frame_audio_emb[:, :, -1:, middle_index:, ...] + latter_last_frame_audio_emb = rearrange(latter_last_frame_audio_emb, "b n_t n w s c -> b n_t (n w) s c") + latter_middle_frame_audio_emb = latter_frame_audio_emb[:, :, 1:-1, middle_index:middle_index+1, ...] + latter_middle_frame_audio_emb = rearrange(latter_middle_frame_audio_emb, "b n_t n w s c -> b n_t (n w) s c") + latter_frame_audio_emb_s = torch.cat([latter_first_frame_audio_emb, latter_middle_frame_audio_emb, latter_last_frame_audio_emb], dim=2) + + audio_emb = audio_proj(first_frame_audio_emb_s, latter_frame_audio_emb_s) + audio_emb = torch.cat(audio_emb.split(1), dim=2) + + return audio_emb + + +class RotaryPositionalEmbedding1D(torch.nn.Module): + def __init__(self, + head_dim, + ): + super().__init__() + self.head_dim = head_dim + self.base = 10000 + + def precompute_freqs_cis_1d(self, pos_indices): + freqs = 1.0 / (self.base ** (torch.arange(0, self.head_dim, 2)[: (self.head_dim // 2)].float() / self.head_dim)) + freqs = freqs.to(pos_indices.device) + freqs = torch.einsum("..., f -> ... f", pos_indices.float(), freqs) + freqs = repeat(freqs, "... n -> ... (n r)", r=2) + return freqs + + def forward(self, x, pos_indices): + freqs_cis = self.precompute_freqs_cis_1d(pos_indices) + + x_ = x.float() + + freqs_cis = freqs_cis.float().to(x.device) + cos, sin = freqs_cis.cos(), freqs_cis.sin() + cos, sin = rearrange(cos, 'n d -> 1 1 n d'), rearrange(sin, 'n d -> 1 1 n d') + x_ = (x_ * cos) + (rotate_half(x_) * sin) + + return x_.type_as(x) + +class SingleStreamAttention(torch.nn.Module): + def __init__( + self, + dim: int, + encoder_hidden_states_dim: int, + num_heads: int, + qkv_bias: bool, + device=None, dtype=None, operations=None + ) -> None: + super().__init__() + self.dim = dim + self.encoder_hidden_states_dim = encoder_hidden_states_dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + + self.q_linear = operations.Linear(dim, dim, bias=qkv_bias, device=device, dtype=dtype) + self.proj = operations.Linear(dim, dim, device=device, dtype=dtype) + self.kv_linear = operations.Linear(encoder_hidden_states_dim, dim * 2, bias=qkv_bias, device=device, dtype=dtype) + + def forward(self, x: torch.Tensor, encoder_hidden_states: torch.Tensor, shape=None) -> torch.Tensor: + N_t, N_h, N_w = shape + + expected_tokens = N_t * N_h * N_w + actual_tokens = x.shape[1] + x_extra = None + + if actual_tokens != expected_tokens: + x_extra = x[:, -N_h * N_w:, :] + x = x[:, :-N_h * N_w, :] + N_t = N_t - 1 + + B = x.shape[0] + S = N_h * N_w + x = x.view(B * N_t, S, self.dim) + + # get q for hidden_state + q = self.q_linear(x).view(B * N_t, S, self.num_heads, self.head_dim) + + # get kv from encoder_hidden_states # shape: (B, N, num_heads, head_dim) + kv = self.kv_linear(encoder_hidden_states) + encoder_k, encoder_v = kv.view(B * N_t, encoder_hidden_states.shape[1], 2, self.num_heads, self.head_dim).unbind(2) + + #print("q.shape", q.shape) #torch.Size([21, 1024, 40, 128]) + x = optimized_attention( + q.transpose(1, 2), + encoder_k.transpose(1, 2), + encoder_v.transpose(1, 2), + heads=self.num_heads, skip_reshape=True, skip_output_reshape=True).transpose(1, 2) + + # linear transform + x = self.proj(x.reshape(B * N_t, S, self.dim)) + x = x.view(B, N_t * S, self.dim) + + if x_extra is not None: + x = torch.cat([x, torch.zeros_like(x_extra)], dim=1) + + return x + +class SingleStreamMultiAttention(SingleStreamAttention): + def __init__( + self, + dim: int, + encoder_hidden_states_dim: int, + num_heads: int, + qkv_bias: bool, + class_range: int = 24, + class_interval: int = 4, + device=None, dtype=None, operations=None + ) -> None: + super().__init__( + dim=dim, + encoder_hidden_states_dim=encoder_hidden_states_dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + device=device, + dtype=dtype, + operations=operations + ) + + # Rotary-embedding layout parameters + self.class_interval = class_interval + self.class_range = class_range + self.max_humans = self.class_range // self.class_interval + + # Constant bucket used for background tokens + self.rope_bak = int(self.class_range // 2) + + self.rope_1d = RotaryPositionalEmbedding1D(self.head_dim) + + def forward( + self, + x: torch.Tensor, + encoder_hidden_states: torch.Tensor, + shape=None, + x_ref_attn_map=None + ) -> torch.Tensor: + encoder_hidden_states = encoder_hidden_states.squeeze(0).to(x.device) + human_num = x_ref_attn_map.shape[0] if x_ref_attn_map is not None else 1 + # Single-speaker fall-through + if human_num <= 1: + return super().forward(x, encoder_hidden_states, shape) + + N_t, N_h, N_w = shape + + x_extra = None + if x.shape[0] * N_t != encoder_hidden_states.shape[0]: + x_extra = x[:, -N_h * N_w:, :] + x = x[:, :-N_h * N_w, :] + N_t = N_t - 1 + x = rearrange(x, "B (N_t S) C -> (B N_t) S C", N_t=N_t) + + # Query projection + B, N, C = x.shape + q = self.q_linear(x) + q = q.view(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3) + + # Use `class_range` logic for 2 speakers + rope_h1 = (0, self.class_interval) + rope_h2 = (self.class_range - self.class_interval, self.class_range) + rope_bak = int(self.class_range // 2) + + # Normalize and scale attention maps for each speaker + max_values = x_ref_attn_map.max(1).values[:, None, None] + min_values = x_ref_attn_map.min(1).values[:, None, None] + max_min_values = torch.cat([max_values, min_values], dim=2) + + human1_max_value, human1_min_value = max_min_values[0, :, 0].max(), max_min_values[0, :, 1].min() + human2_max_value, human2_min_value = max_min_values[1, :, 0].max(), max_min_values[1, :, 1].min() + + human1 = normalize_and_scale(x_ref_attn_map[0], (human1_min_value, human1_max_value), rope_h1) + human2 = normalize_and_scale(x_ref_attn_map[1], (human2_min_value, human2_max_value), rope_h2) + back = torch.full((x_ref_attn_map.size(1),), rope_bak, dtype=human1.dtype, device=human1.device) + + # Token-wise speaker dominance + max_indices = x_ref_attn_map.argmax(dim=0) + normalized_map = torch.stack([human1, human2, back], dim=1) + normalized_pos = normalized_map[torch.arange(x_ref_attn_map.size(1)), max_indices] + + # Apply rotary to Q + q = rearrange(q, "(B N_t) H S C -> B H (N_t S) C", N_t=N_t) + q = self.rope_1d(q, normalized_pos) + q = rearrange(q, "B H (N_t S) C -> (B N_t) H S C", N_t=N_t) + + # Keys / Values + _, N_a, _ = encoder_hidden_states.shape + encoder_kv = self.kv_linear(encoder_hidden_states) + encoder_kv = encoder_kv.view(B, N_a, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) + encoder_k, encoder_v = encoder_kv.unbind(0) + + # Rotary for keys – assign centre of each speaker bucket to its context tokens + per_frame = torch.zeros(N_a, dtype=encoder_k.dtype, device=encoder_k.device) + per_frame[: per_frame.size(0) // 2] = (rope_h1[0] + rope_h1[1]) / 2 + per_frame[per_frame.size(0) // 2 :] = (rope_h2[0] + rope_h2[1]) / 2 + encoder_pos = torch.cat([per_frame] * N_t, dim=0) + + encoder_k = rearrange(encoder_k, "(B N_t) H S C -> B H (N_t S) C", N_t=N_t) + encoder_k = self.rope_1d(encoder_k, encoder_pos) + encoder_k = rearrange(encoder_k, "B H (N_t S) C -> (B N_t) H S C", N_t=N_t) + + # Final attention + q = rearrange(q, "B H M K -> B M H K") + encoder_k = rearrange(encoder_k, "B H M K -> B M H K") + encoder_v = rearrange(encoder_v, "B H M K -> B M H K") + + x = optimized_attention( + q.transpose(1, 2), + encoder_k.transpose(1, 2), + encoder_v.transpose(1, 2), + heads=self.num_heads, skip_reshape=True, skip_output_reshape=True).transpose(1, 2) + + # Linear projection + x = x.reshape(B, N, C) + x = self.proj(x) + + # Restore original layout + x = rearrange(x, "(B N_t) S C -> B (N_t S) C", N_t=N_t) + if x_extra is not None: + x = torch.cat([x, torch.zeros_like(x_extra)], dim=1) + + return x + + +class MultiTalkAudioProjModel(torch.nn.Module): + def __init__( + self, + seq_len: int = 5, + seq_len_vf: int = 12, + blocks: int = 12, + channels: int = 768, + intermediate_dim: int = 512, + out_dim: int = 768, + context_tokens: int = 32, + device=None, dtype=None, operations=None + ): + super().__init__() + + self.seq_len = seq_len + self.blocks = blocks + self.channels = channels + self.input_dim = seq_len * blocks * channels + self.input_dim_vf = seq_len_vf * blocks * channels + self.intermediate_dim = intermediate_dim + self.context_tokens = context_tokens + self.out_dim = out_dim + + # define multiple linear layers + self.proj1 = operations.Linear(self.input_dim, intermediate_dim, device=device, dtype=dtype) + self.proj1_vf = operations.Linear(self.input_dim_vf, intermediate_dim, device=device, dtype=dtype) + self.proj2 = operations.Linear(intermediate_dim, intermediate_dim, device=device, dtype=dtype) + self.proj3 = operations.Linear(intermediate_dim, context_tokens * out_dim, device=device, dtype=dtype) + self.norm = operations.LayerNorm(out_dim, device=device, dtype=dtype) + + def forward(self, audio_embeds, audio_embeds_vf): + video_length = audio_embeds.shape[1] + audio_embeds_vf.shape[1] + B, _, _, S, C = audio_embeds.shape + + # process audio of first frame + audio_embeds = rearrange(audio_embeds, "bz f w b c -> (bz f) w b c") + batch_size, window_size, blocks, channels = audio_embeds.shape + audio_embeds = audio_embeds.view(batch_size, window_size * blocks * channels) + + # process audio of latter frame + audio_embeds_vf = rearrange(audio_embeds_vf, "bz f w b c -> (bz f) w b c") + batch_size_vf, window_size_vf, blocks_vf, channels_vf = audio_embeds_vf.shape + audio_embeds_vf = audio_embeds_vf.view(batch_size_vf, window_size_vf * blocks_vf * channels_vf) + + # first projection + audio_embeds = torch.relu(self.proj1(audio_embeds)) + audio_embeds_vf = torch.relu(self.proj1_vf(audio_embeds_vf)) + audio_embeds = rearrange(audio_embeds, "(bz f) c -> bz f c", bz=B) + audio_embeds_vf = rearrange(audio_embeds_vf, "(bz f) c -> bz f c", bz=B) + audio_embeds_c = torch.concat([audio_embeds, audio_embeds_vf], dim=1) + batch_size_c, N_t, C_a = audio_embeds_c.shape + audio_embeds_c = audio_embeds_c.view(batch_size_c*N_t, C_a) + + # second projection + audio_embeds_c = torch.relu(self.proj2(audio_embeds_c)) + + context_tokens = self.proj3(audio_embeds_c).reshape(batch_size_c*N_t, self.context_tokens, self.out_dim) + + # normalization and reshape + context_tokens = self.norm(context_tokens) + context_tokens = rearrange(context_tokens, "(bz f) m c -> bz f m c", f=video_length) + + return context_tokens + + +class WanMultiTalkAttentionBlock(torch.nn.Module): + def __init__(self, in_dim=5120, out_dim=768, device=None, dtype=None, operations=None): + super().__init__() + self.audio_cross_attn = SingleStreamMultiAttention(in_dim, out_dim, num_heads=40, qkv_bias=True, device=device, dtype=dtype, operations=operations) + self.norm_x = operations.LayerNorm(in_dim, device=device, dtype=dtype, elementwise_affine=True) + + +class MultiTalkGetAttnMapPatch: + def __init__(self, ref_target_masks=None): + self.ref_target_masks = ref_target_masks + + def __call__(self, kwargs): + transformer_options = kwargs.get("transformer_options", {}) + x = kwargs["x"] + + if self.ref_target_masks is not None: + x_ref_attn_map = get_attn_map_with_target(kwargs["q"], kwargs["k"], transformer_options["grid_sizes"], ref_target_masks=self.ref_target_masks.to(x.device)) + transformer_options["x_ref_attn_map"] = x_ref_attn_map + return x + + +class MultiTalkCrossAttnPatch: + def __init__(self, model_patch, audio_scale=1.0, ref_target_masks=None): + self.model_patch = model_patch + self.audio_scale = audio_scale + self.ref_target_masks = ref_target_masks + + def __call__(self, kwargs): + transformer_options = kwargs.get("transformer_options", {}) + block_idx = transformer_options.get("block_index", None) + x = kwargs["x"] + if block_idx is None: + return torch.zeros_like(x) + + audio_embeds = transformer_options.get("audio_embeds") + x_ref_attn_map = transformer_options.pop("x_ref_attn_map", None) + + norm_x = self.model_patch.model.blocks[block_idx].norm_x(x) + x_audio = self.model_patch.model.blocks[block_idx].audio_cross_attn( + norm_x, audio_embeds.to(x.dtype), + shape=transformer_options["grid_sizes"], + x_ref_attn_map=x_ref_attn_map + ) + x = x + x_audio * self.audio_scale + return x + + def models(self): + return [self.model_patch] + +class MultiTalkApplyModelWrapper: + def __init__(self, init_latents): + self.init_latents = init_latents + + def __call__(self, executor, x, *args, **kwargs): + x[:, :, :self.init_latents.shape[2]] = self.init_latents.to(x) + samples = executor(x, *args, **kwargs) + return samples + + +class InfiniteTalkOuterSampleWrapper: + def __init__(self, motion_frames_latent, model_patch, is_extend=False): + self.motion_frames_latent = motion_frames_latent + self.model_patch = model_patch + self.is_extend = is_extend + + def __call__(self, executor, *args, **kwargs): + model_patcher = executor.class_obj.model_patcher + model_options = executor.class_obj.model_options + process_latent_in = model_patcher.model.process_latent_in + + # for InfiniteTalk, model input first latent(s) need to always be replaced on every step + if self.motion_frames_latent is not None: + wrappers = model_options["transformer_options"]["wrappers"] + w = wrappers.setdefault(comfy.patcher_extension.WrappersMP.APPLY_MODEL, {}) + w["MultiTalk_apply_model"] = [MultiTalkApplyModelWrapper(process_latent_in(self.motion_frames_latent))] + + # run the sampling process + result = executor(*args, **kwargs) + + # insert motion frames before decoding + if self.is_extend: + overlap = self.motion_frames_latent.shape[2] + result = torch.cat([self.motion_frames_latent.to(result), result[:, :, overlap:]], dim=2) + + return result + + def to(self, device_or_dtype): + if isinstance(device_or_dtype, torch.device): + if self.motion_frames_latent is not None: + self.motion_frames_latent = self.motion_frames_latent.to(device_or_dtype) + return self diff --git a/comfy_api/latest/_io.py b/comfy_api/latest/_io.py index a60020ca8..2ec8d6e4b 100644 --- a/comfy_api/latest/_io.py +++ b/comfy_api/latest/_io.py @@ -754,7 +754,7 @@ class AnyType(ComfyTypeIO): Type = Any @comfytype(io_type="MODEL_PATCH") -class MODEL_PATCH(ComfyTypeIO): +class ModelPatch(ComfyTypeIO): Type = Any @comfytype(io_type="AUDIO_ENCODER") @@ -2038,6 +2038,7 @@ __all__ = [ "ControlNet", "Vae", "Model", + "ModelPatch", "ClipVision", "ClipVisionOutput", "AudioEncoder", diff --git a/comfy_extras/nodes_model_patch.py b/comfy_extras/nodes_model_patch.py index f66d28fc9..82c4754a3 100644 --- a/comfy_extras/nodes_model_patch.py +++ b/comfy_extras/nodes_model_patch.py @@ -7,6 +7,7 @@ import comfy.model_management import comfy.ldm.common_dit import comfy.latent_formats import comfy.ldm.lumina.controlnet +from comfy.ldm.wan.model_multitalk import WanMultiTalkAttentionBlock, MultiTalkAudioProjModel class BlockWiseControlBlock(torch.nn.Module): @@ -257,6 +258,14 @@ class ModelPatchLoader: if torch.count_nonzero(ref_weight) == 0: config['broken'] = True model = comfy.ldm.lumina.controlnet.ZImage_Control(device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast, **config) + elif "audio_proj.proj1.weight" in sd: + model = MultiTalkModelPatch( + audio_window=5, context_tokens=32, vae_scale=4, + in_dim=sd["blocks.0.audio_cross_attn.proj.weight"].shape[0], + intermediate_dim=sd["audio_proj.proj1.weight"].shape[0], + out_dim=sd["audio_proj.norm.weight"].shape[0], + device=comfy.model_management.unet_offload_device(), + operations=comfy.ops.manual_cast) model.load_state_dict(sd) model = comfy.model_patcher.ModelPatcher(model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device()) @@ -524,6 +533,38 @@ class USOStyleReference: return (model_patched,) +class MultiTalkModelPatch(torch.nn.Module): + def __init__( + self, + audio_window: int = 5, + intermediate_dim: int = 512, + in_dim: int = 5120, + out_dim: int = 768, + context_tokens: int = 32, + vae_scale: int = 4, + num_layers: int = 40, + + device=None, dtype=None, operations=None + ): + super().__init__() + self.audio_proj = MultiTalkAudioProjModel( + seq_len=audio_window, + seq_len_vf=audio_window+vae_scale-1, + intermediate_dim=intermediate_dim, + out_dim=out_dim, + context_tokens=context_tokens, + device=device, + dtype=dtype, + operations=operations + ) + self.blocks = torch.nn.ModuleList( + [ + WanMultiTalkAttentionBlock(in_dim, out_dim, device=device, dtype=dtype, operations=operations) + for _ in range(num_layers) + ] + ) + + NODE_CLASS_MAPPINGS = { "ModelPatchLoader": ModelPatchLoader, "QwenImageDiffsynthControlnet": QwenImageDiffsynthControlnet, diff --git a/comfy_extras/nodes_wan.py b/comfy_extras/nodes_wan.py index d32aad98e..90deb0077 100644 --- a/comfy_extras/nodes_wan.py +++ b/comfy_extras/nodes_wan.py @@ -8,9 +8,10 @@ import comfy.latent_formats import comfy.clip_vision import json import numpy as np -from typing import Tuple +from typing import Tuple, TypedDict from typing_extensions import override from comfy_api.latest import ComfyExtension, io +import logging class WanImageToVideo(io.ComfyNode): @classmethod @@ -1288,6 +1289,171 @@ class Wan22ImageToVideoLatent(io.ComfyNode): return io.NodeOutput(out_latent) +from comfy.ldm.wan.model_multitalk import InfiniteTalkOuterSampleWrapper, MultiTalkCrossAttnPatch, MultiTalkGetAttnMapPatch, project_audio_features +class WanInfiniteTalkToVideo(io.ComfyNode): + class DCValues(TypedDict): + mode: str + audio_encoder_output_2: io.AudioEncoderOutput.Type + mask: io.Mask.Type + + @classmethod + def define_schema(cls): + return io.Schema( + node_id="WanInfiniteTalkToVideo", + category="conditioning/video_models", + inputs=[ + io.DynamicCombo.Input("mode", options=[ + io.DynamicCombo.Option("single_speaker", []), + io.DynamicCombo.Option("two_speakers", [ + io.AudioEncoderOutput.Input("audio_encoder_output_2", optional=True), + io.Mask.Input("mask_1", optional=True, tooltip="Mask for the first speaker, required if using two audio inputs."), + io.Mask.Input("mask_2", optional=True, tooltip="Mask for the second speaker, required if using two audio inputs."), + ]), + ]), + io.Model.Input("model"), + io.ModelPatch.Input("model_patch"), + io.Conditioning.Input("positive"), + io.Conditioning.Input("negative"), + io.Vae.Input("vae"), + io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4), + io.ClipVisionOutput.Input("clip_vision_output", optional=True), + io.Image.Input("start_image", optional=True), + io.AudioEncoderOutput.Input("audio_encoder_output_1"), + io.Int.Input("motion_frame_count", default=9, min=1, max=33, step=1, tooltip="Number of previous frames to use as motion context."), + io.Float.Input("audio_scale", default=1.0, min=-10.0, max=10.0, step=0.01), + io.Image.Input("previous_frames", optional=True), + ], + outputs=[ + io.Model.Output(display_name="model"), + io.Conditioning.Output(display_name="positive"), + io.Conditioning.Output(display_name="negative"), + io.Latent.Output(display_name="latent"), + io.Int.Output(display_name="trim_image"), + ], + ) + + @classmethod + def execute(cls, mode: DCValues, model, model_patch, positive, negative, vae, width, height, length, audio_encoder_output_1, motion_frame_count, + start_image=None, previous_frames=None, audio_scale=None, clip_vision_output=None, audio_encoder_output_2=None, mask_1=None, mask_2=None) -> io.NodeOutput: + + if previous_frames is not None and previous_frames.shape[0] < motion_frame_count: + raise ValueError("Not enough previous frames provided.") + + if mode["mode"] == "two_speakers": + audio_encoder_output_2 = mode["audio_encoder_output_2"] + mask_1 = mode["mask_1"] + mask_2 = mode["mask_2"] + + if audio_encoder_output_2 is not None: + if mask_1 is None or mask_2 is None: + raise ValueError("Masks must be provided if two audio encoder outputs are used.") + + ref_masks = None + if mask_1 is not None and mask_2 is not None: + if audio_encoder_output_2 is None: + raise ValueError("Second audio encoder output must be provided if two masks are used.") + ref_masks = torch.cat([mask_1, mask_2]) + + latent = torch.zeros([1, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) + if start_image is not None: + start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) + image = torch.ones((length, height, width, start_image.shape[-1]), device=start_image.device, dtype=start_image.dtype) * 0.5 + image[:start_image.shape[0]] = start_image + + concat_latent_image = vae.encode(image[:, :, :, :3]) + concat_mask = torch.ones((1, 1, latent.shape[2], concat_latent_image.shape[-2], concat_latent_image.shape[-1]), device=start_image.device, dtype=start_image.dtype) + concat_mask[:, :, :((start_image.shape[0] - 1) // 4) + 1] = 0.0 + + positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent_image, "concat_mask": concat_mask}) + negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent_image, "concat_mask": concat_mask}) + + if clip_vision_output is not None: + positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output}) + negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output}) + + model_patched = model.clone() + + encoded_audio_list = [] + seq_lengths = [] + + for audio_encoder_output in [audio_encoder_output_1, audio_encoder_output_2]: + if audio_encoder_output is None: + continue + all_layers = audio_encoder_output["encoded_audio_all_layers"] + encoded_audio = torch.stack(all_layers, dim=0).squeeze(1)[1:] # shape: [num_layers, T, 512] + encoded_audio = linear_interpolation(encoded_audio, input_fps=50, output_fps=25).movedim(0, 1) # shape: [T, num_layers, 512] + encoded_audio_list.append(encoded_audio) + seq_lengths.append(encoded_audio.shape[0]) + + # Pad / combine depending on multi_audio_type + multi_audio_type = "add" + if len(encoded_audio_list) > 1: + if multi_audio_type == "para": + max_len = max(seq_lengths) + padded = [] + for emb in encoded_audio_list: + if emb.shape[0] < max_len: + pad = torch.zeros(max_len - emb.shape[0], *emb.shape[1:], dtype=emb.dtype) + emb = torch.cat([emb, pad], dim=0) + padded.append(emb) + encoded_audio_list = padded + elif multi_audio_type == "add": + total_len = sum(seq_lengths) + full_list = [] + offset = 0 + for emb, seq_len in zip(encoded_audio_list, seq_lengths): + full = torch.zeros(total_len, *emb.shape[1:], dtype=emb.dtype) + full[offset:offset+seq_len] = emb + full_list.append(full) + offset += seq_len + encoded_audio_list = full_list + + token_ref_target_masks = None + if ref_masks is not None: + token_ref_target_masks = torch.nn.functional.interpolate( + ref_masks.unsqueeze(0), size=(latent.shape[-2] // 2, latent.shape[-1] // 2), mode='nearest')[0] + token_ref_target_masks = (token_ref_target_masks > 0).view(token_ref_target_masks.shape[0], -1) + + # when extending from previous frames + if previous_frames is not None: + motion_frames = comfy.utils.common_upscale(previous_frames[-motion_frame_count:].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) + frame_offset = previous_frames.shape[0] - motion_frame_count + + audio_start = frame_offset + audio_end = audio_start + length + logging.info(f"InfiniteTalk: Processing audio frames {audio_start} - {audio_end}") + + motion_frames_latent = vae.encode(motion_frames[:, :, :, :3]) + trim_image = motion_frame_count + else: + audio_start = trim_image = 0 + audio_end = length + motion_frames_latent = concat_latent_image[:, :, :1] + + audio_embed = project_audio_features(model_patch.model.audio_proj, encoded_audio_list, audio_start, audio_end).to(model_patched.model_dtype()) + model_patched.model_options["transformer_options"]["audio_embeds"] = audio_embed + + # add outer sample wrapper + model_patched.add_wrapper_with_key( + comfy.patcher_extension.WrappersMP.OUTER_SAMPLE, + "infinite_talk_outer_sample", + InfiniteTalkOuterSampleWrapper( + motion_frames_latent, + model_patch, + is_extend=previous_frames is not None, + )) + # add cross-attention patch + model_patched.set_model_patch(MultiTalkCrossAttnPatch(model_patch, audio_scale), "attn2_patch") + if token_ref_target_masks is not None: + model_patched.set_model_patch(MultiTalkGetAttnMapPatch(token_ref_target_masks), "attn1_patch") + + out_latent = {} + out_latent["samples"] = latent + return io.NodeOutput(model_patched, positive, negative, out_latent, trim_image) + + class WanExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[io.ComfyNode]]: @@ -1307,6 +1473,7 @@ class WanExtension(ComfyExtension): WanHuMoImageToVideo, WanAnimateToVideo, Wan22ImageToVideoLatent, + WanInfiniteTalkToVideo, ] async def comfy_entrypoint() -> WanExtension: From 72f6be1690868af852a624084a949b785fc056ea Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Thu, 22 Jan 2026 09:42:04 +0200 Subject: [PATCH 119/119] chore(api-nodes): rename BriaImage and OpenAIGImage nodes (#12022) --- comfy_api_nodes/nodes_bria.py | 2 +- comfy_api_nodes/nodes_openai.py | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/comfy_api_nodes/nodes_bria.py b/comfy_api_nodes/nodes_bria.py index 72a3055a7..d3a52bc1b 100644 --- a/comfy_api_nodes/nodes_bria.py +++ b/comfy_api_nodes/nodes_bria.py @@ -24,7 +24,7 @@ class BriaImageEditNode(IO.ComfyNode): def define_schema(cls): return IO.Schema( node_id="BriaImageEditNode", - display_name="Bria Image Edit", + display_name="Bria FIBO Image Edit", category="api node/image/Bria", description="Edit images using Bria latest model", inputs=[ diff --git a/comfy_api_nodes/nodes_openai.py b/comfy_api_nodes/nodes_openai.py index a12acc06b..f05aaab7b 100644 --- a/comfy_api_nodes/nodes_openai.py +++ b/comfy_api_nodes/nodes_openai.py @@ -364,9 +364,9 @@ class OpenAIGPTImage1(IO.ComfyNode): def define_schema(cls): return IO.Schema( node_id="OpenAIGPTImage1", - display_name="OpenAI GPT Image 1", + display_name="OpenAI GPT Image 1.5", category="api node/image/OpenAI", - description="Generates images synchronously via OpenAI's GPT Image 1 endpoint.", + description="Generates images synchronously via OpenAI's GPT Image endpoint.", inputs=[ IO.String.Input( "prompt", @@ -429,6 +429,7 @@ class OpenAIGPTImage1(IO.ComfyNode): IO.Combo.Input( "model", options=["gpt-image-1", "gpt-image-1.5"], + default="gpt-image-1.5", optional=True, ), ],