From 61b08d4ba65fec37070376bf50da3ec3c534e859 Mon Sep 17 00:00:00 2001 From: chaObserv <154517000+chaObserv@users.noreply.github.com> Date: Thu, 31 Jul 2025 07:25:56 +0800 Subject: [PATCH 001/325] Replace manual x * sigmoid(x) with torch silu in VAE nonlinearity (#9057) --- comfy/ldm/cosmos/cosmos_tokenizer/utils.py | 3 ++- comfy/ldm/modules/diffusionmodules/model.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/comfy/ldm/cosmos/cosmos_tokenizer/utils.py b/comfy/ldm/cosmos/cosmos_tokenizer/utils.py index 3af8d0d05..ca993006f 100644 --- a/comfy/ldm/cosmos/cosmos_tokenizer/utils.py +++ b/comfy/ldm/cosmos/cosmos_tokenizer/utils.py @@ -58,7 +58,8 @@ def is_odd(n: int) -> bool: def nonlinearity(x): - return x * torch.sigmoid(x) + # x * sigmoid(x) + return torch.nn.functional.silu(x) def Normalize(in_channels, num_groups=32): diff --git a/comfy/ldm/modules/diffusionmodules/model.py b/comfy/ldm/modules/diffusionmodules/model.py index 8162742cf..5c0373b74 100644 --- a/comfy/ldm/modules/diffusionmodules/model.py +++ b/comfy/ldm/modules/diffusionmodules/model.py @@ -36,7 +36,7 @@ def get_timestep_embedding(timesteps, embedding_dim): def nonlinearity(x): # swish - return x*torch.sigmoid(x) + return torch.nn.functional.silu(x) def Normalize(in_channels, num_groups=32): From 97eb256a355b434bbc96ec27bbce33dd10273857 Mon Sep 17 00:00:00 2001 From: guill Date: Wed, 30 Jul 2025 19:55:28 -0700 Subject: [PATCH 002/325] Add support for partial execution in backend (#9123) When a prompt is submitted, it can optionally include `partial_execution_targets` as a list of ids. If it does, rather than adding all outputs to the execution list, we add only those in the list. --- execution.py | 7 +- server.py | 7 +- tests/inference/test_async_nodes.py | 15 +- tests/inference/test_execution.py | 202 ++++++++++++++++-- .../testing-pack/specific_tests.py | 21 ++ 5 files changed, 233 insertions(+), 19 deletions(-) diff --git a/execution.py b/execution.py index 8a9663a7d..cde14c52f 100644 --- a/execution.py +++ b/execution.py @@ -7,7 +7,7 @@ import threading import time import traceback from enum import Enum -from typing import List, Literal, NamedTuple, Optional +from typing import List, Literal, NamedTuple, Optional, Union import asyncio import torch @@ -891,7 +891,7 @@ def full_type_name(klass): return klass.__qualname__ return module + '.' 
+ klass.__qualname__ -async def validate_prompt(prompt_id, prompt): +async def validate_prompt(prompt_id, prompt, partial_execution_list: Union[list[str], None]): outputs = set() for x in prompt: if 'class_type' not in prompt[x]: @@ -915,7 +915,8 @@ async def validate_prompt(prompt_id, prompt): return (False, error, [], {}) if hasattr(class_, 'OUTPUT_NODE') and class_.OUTPUT_NODE is True: - outputs.add(x) + if partial_execution_list is None or x in partial_execution_list: + outputs.add(x) if len(outputs) == 0: error = { diff --git a/server.py b/server.py index f4de0079b..3e06d2fbb 100644 --- a/server.py +++ b/server.py @@ -681,7 +681,12 @@ class PromptServer(): if "prompt" in json_data: prompt = json_data["prompt"] prompt_id = str(json_data.get("prompt_id", uuid.uuid4())) - valid = await execution.validate_prompt(prompt_id, prompt) + + partial_execution_targets = None + if "partial_execution_targets" in json_data: + partial_execution_targets = json_data["partial_execution_targets"] + + valid = await execution.validate_prompt(prompt_id, prompt, partial_execution_targets) extra_data = {} if "extra_data" in json_data: extra_data = json_data["extra_data"] diff --git a/tests/inference/test_async_nodes.py b/tests/inference/test_async_nodes.py index b243bbca9..f029953dd 100644 --- a/tests/inference/test_async_nodes.py +++ b/tests/inference/test_async_nodes.py @@ -7,7 +7,7 @@ import subprocess from pytest import fixture from comfy_execution.graph_utils import GraphBuilder -from tests.inference.test_execution import ComfyClient +from tests.inference.test_execution import ComfyClient, run_warmup @pytest.mark.execution @@ -24,6 +24,7 @@ class TestAsyncNodes: '--listen', args_pytest["listen"], '--port', str(args_pytest["port"]), '--extra-model-paths-config', 'tests/inference/extra_model_paths.yaml', + '--cpu', ] use_lru, lru_size = request.param if use_lru: @@ -82,6 +83,9 @@ class TestAsyncNodes: def test_multiple_async_parallel_execution(self, client: ComfyClient, builder: GraphBuilder): """Test that multiple async nodes execute in parallel.""" + # Warmup execution to ensure server is fully initialized + run_warmup(client) + g = builder image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) @@ -148,6 +152,9 @@ class TestAsyncNodes: def test_async_lazy_evaluation(self, client: ComfyClient, builder: GraphBuilder): """Test async nodes with lazy evaluation.""" + # Warmup execution to ensure server is fully initialized + run_warmup(client, prefix="warmup_lazy") + g = builder input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1) @@ -305,6 +312,9 @@ class TestAsyncNodes: def test_async_caching_behavior(self, client: ComfyClient, builder: GraphBuilder): """Test that async nodes are properly cached.""" + # Warmup execution to ensure server is fully initialized + run_warmup(client, prefix="warmup_cache") + g = builder image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) sleep_node = g.node("TestSleep", value=image.out(0), seconds=0.2) @@ -324,6 +334,9 @@ class TestAsyncNodes: def test_async_with_dynamic_prompts(self, client: ComfyClient, builder: GraphBuilder): """Test async nodes within dynamically generated prompts.""" + # Warmup execution to ensure server is fully initialized + run_warmup(client, prefix="warmup_dynamic") + g = builder image1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) image2 = 
g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1) diff --git a/tests/inference/test_execution.py b/tests/inference/test_execution.py index 9d3d685cc..e7b29302e 100644 --- a/tests/inference/test_execution.py +++ b/tests/inference/test_execution.py @@ -15,10 +15,18 @@ import urllib.parse import urllib.error from comfy_execution.graph_utils import GraphBuilder, Node +def run_warmup(client, prefix="warmup"): + """Run a simple workflow to warm up the server.""" + warmup_g = GraphBuilder(prefix=prefix) + warmup_image = warmup_g.node("StubImage", content="BLACK", height=32, width=32, batch_size=1) + warmup_g.node("PreviewImage", images=warmup_image.out(0)) + client.run(warmup_g) + class RunResult: def __init__(self, prompt_id: str): self.outputs: Dict[str,Dict] = {} self.runs: Dict[str,bool] = {} + self.cached: Dict[str,bool] = {} self.prompt_id: str = prompt_id def get_output(self, node: Node): @@ -27,6 +35,13 @@ class RunResult: def did_run(self, node: Node): return self.runs.get(node.id, False) + def was_cached(self, node: Node): + return self.cached.get(node.id, False) + + def was_executed(self, node: Node): + """Returns True if node was either run or cached""" + return self.did_run(node) or self.was_cached(node) + def get_images(self, node: Node): output = self.get_output(node) if output is None: @@ -51,8 +66,10 @@ class ComfyClient: ws.connect("ws://{}/ws?clientId={}".format(self.server_address, self.client_id)) self.ws = ws - def queue_prompt(self, prompt): + def queue_prompt(self, prompt, partial_execution_targets=None): p = {"prompt": prompt, "client_id": self.client_id} + if partial_execution_targets is not None: + p["partial_execution_targets"] = partial_execution_targets data = json.dumps(p).encode('utf-8') req = urllib.request.Request("http://{}/prompt".format(self.server_address), data=data) return json.loads(urllib.request.urlopen(req).read()) @@ -70,13 +87,13 @@ class ComfyClient: def set_test_name(self, name): self.test_name = name - def run(self, graph): + def run(self, graph, partial_execution_targets=None): prompt = graph.finalize() for node in graph.nodes.values(): if node.class_type == 'SaveImage': node.inputs['filename_prefix'] = self.test_name - prompt_id = self.queue_prompt(prompt)['prompt_id'] + prompt_id = self.queue_prompt(prompt, partial_execution_targets)['prompt_id'] result = RunResult(prompt_id) while True: out = self.ws.recv() @@ -92,7 +109,10 @@ class ComfyClient: elif message['type'] == 'execution_error': raise Exception(message['data']) elif message['type'] == 'execution_cached': - pass # Probably want to store this off for testing + if message['data']['prompt_id'] == prompt_id: + cached_nodes = message['data'].get('nodes', []) + for node_id in cached_nodes: + result.cached[node_id] = True history = self.get_history(prompt_id)[prompt_id] for node_id in history['outputs']: @@ -130,6 +150,7 @@ class TestExecution: '--listen', args_pytest["listen"], '--port', str(args_pytest["port"]), '--extra-model-paths-config', 'tests/inference/extra_model_paths.yaml', + '--cpu', ] use_lru, lru_size = request.param if use_lru: @@ -498,12 +519,15 @@ class TestExecution: assert not result.did_run(test_node), "The execution should have been cached" def test_parallel_sleep_nodes(self, client: ComfyClient, builder: GraphBuilder): + # Warmup execution to ensure server is fully initialized + run_warmup(client) + g = builder image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) # Create sleep nodes for each duration - sleep_node1 = 
g.node("TestSleep", value=image.out(0), seconds=2.8) - sleep_node2 = g.node("TestSleep", value=image.out(0), seconds=2.9) + sleep_node1 = g.node("TestSleep", value=image.out(0), seconds=2.9) + sleep_node2 = g.node("TestSleep", value=image.out(0), seconds=3.1) sleep_node3 = g.node("TestSleep", value=image.out(0), seconds=3.0) # Add outputs to verify the execution @@ -515,10 +539,9 @@ class TestExecution: result = client.run(g) elapsed_time = time.time() - start_time - # The test should take around 0.4 seconds (the longest sleep duration) - # plus some overhead, but definitely less than the sum of all sleeps (0.9s) - # We'll allow for up to 0.8s total to account for overhead - assert elapsed_time < 4.0, f"Parallel execution took {elapsed_time}s, expected less than 0.8s" + # The test should take around 3.0 seconds (the longest sleep duration) + # plus some overhead, but definitely less than the sum of all sleeps (9.0s) + assert elapsed_time < 8.9, f"Parallel execution took {elapsed_time}s, expected less than 8.9s" # Verify that all nodes executed assert result.did_run(sleep_node1), "Sleep node 1 should have run" @@ -526,6 +549,9 @@ class TestExecution: assert result.did_run(sleep_node3), "Sleep node 3 should have run" def test_parallel_sleep_expansion(self, client: ComfyClient, builder: GraphBuilder): + # Warmup execution to ensure server is fully initialized + run_warmup(client) + g = builder # Create input images with different values image1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) @@ -537,9 +563,9 @@ class TestExecution: image1=image1.out(0), image2=image2.out(0), image3=image3.out(0), - sleep1=0.4, - sleep2=0.5, - sleep3=0.6) + sleep1=4.8, + sleep2=4.9, + sleep3=5.0) output = g.node("SaveImage", images=parallel_sleep.out(0)) start_time = time.time() @@ -548,7 +574,7 @@ class TestExecution: # Similar to the previous test, expect parallel execution of the sleep nodes # which should complete in less than the sum of all sleeps - assert elapsed_time < 0.8, f"Expansion execution took {elapsed_time}s, expected less than 0.8s" + assert elapsed_time < 10.0, f"Expansion execution took {elapsed_time}s, expected less than 5.5s" # Verify the parallel sleep node executed assert result.did_run(parallel_sleep), "ParallelSleep node should have run" @@ -585,3 +611,151 @@ class TestExecution: assert len(images) == 2, "Should have 2 images" assert numpy.array(images[0]).min() == 0 and numpy.array(images[0]).max() == 0, "First image should be black" assert numpy.array(images[1]).min() == 0 and numpy.array(images[1]).max() == 0, "Second image should also be black" + + # Output nodes included in the partial execution list are executed + def test_partial_execution_included_outputs(self, client: ComfyClient, builder: GraphBuilder): + g = builder + input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) + input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1) + + # Create two separate output nodes + output1 = g.node("SaveImage", images=input1.out(0)) + output2 = g.node("SaveImage", images=input2.out(0)) + + # Run with partial execution targeting only output1 + result = client.run(g, partial_execution_targets=[output1.id]) + + assert result.was_executed(input1), "Input1 should have been executed (run or cached)" + assert result.was_executed(output1), "Output1 should have been executed (run or cached)" + assert not result.did_run(input2), "Input2 should not have run" + assert not result.did_run(output2), "Output2 should not 
have run" + + # Verify only output1 produced results + assert len(result.get_images(output1)) == 1, "Output1 should have produced an image" + assert len(result.get_images(output2)) == 0, "Output2 should not have produced an image" + + # Output nodes NOT included in the partial execution list are NOT executed + def test_partial_execution_excluded_outputs(self, client: ComfyClient, builder: GraphBuilder): + g = builder + input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) + input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1) + input3 = g.node("StubImage", content="NOISE", height=512, width=512, batch_size=1) + + # Create three output nodes + output1 = g.node("SaveImage", images=input1.out(0)) + output2 = g.node("SaveImage", images=input2.out(0)) + output3 = g.node("SaveImage", images=input3.out(0)) + + # Run with partial execution targeting only output1 and output3 + result = client.run(g, partial_execution_targets=[output1.id, output3.id]) + + assert result.was_executed(input1), "Input1 should have been executed" + assert result.was_executed(input3), "Input3 should have been executed" + assert result.was_executed(output1), "Output1 should have been executed" + assert result.was_executed(output3), "Output3 should have been executed" + assert not result.did_run(input2), "Input2 should not have run" + assert not result.did_run(output2), "Output2 should not have run" + + # Output nodes NOT in list ARE executed if necessary for nodes that are in the list + def test_partial_execution_dependencies(self, client: ComfyClient, builder: GraphBuilder): + g = builder + input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) + + # Create a processing chain with an OUTPUT_NODE that has socket outputs + output_with_socket = g.node("TestOutputNodeWithSocketOutput", image=input1.out(0), value=2.0) + + # Create another node that depends on the output_with_socket + dependent_node = g.node("TestLazyMixImages", + image1=output_with_socket.out(0), + image2=input1.out(0), + mask=g.node("StubMask", value=0.5, height=512, width=512, batch_size=1).out(0)) + + # Create the final output + final_output = g.node("SaveImage", images=dependent_node.out(0)) + + # Run with partial execution targeting only the final output + result = client.run(g, partial_execution_targets=[final_output.id]) + + # All nodes should have been executed because they're dependencies + assert result.was_executed(input1), "Input1 should have been executed" + assert result.was_executed(output_with_socket), "Output with socket should have been executed (dependency)" + assert result.was_executed(dependent_node), "Dependent node should have been executed" + assert result.was_executed(final_output), "Final output should have been executed" + + # Lazy execution works with partial execution + def test_partial_execution_with_lazy_nodes(self, client: ComfyClient, builder: GraphBuilder): + g = builder + input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) + input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1) + input3 = g.node("StubImage", content="NOISE", height=512, width=512, batch_size=1) + + # Create masks that will trigger different lazy execution paths + mask1 = g.node("StubMask", value=0.0, height=512, width=512, batch_size=1) # Will only need image1 + mask2 = g.node("StubMask", value=0.5, height=512, width=512, batch_size=1) # Will need both images + + # Create two lazy mix nodes + lazy_mix1 = 
g.node("TestLazyMixImages", image1=input1.out(0), image2=input2.out(0), mask=mask1.out(0)) + lazy_mix2 = g.node("TestLazyMixImages", image1=input2.out(0), image2=input3.out(0), mask=mask2.out(0)) + + output1 = g.node("SaveImage", images=lazy_mix1.out(0)) + output2 = g.node("SaveImage", images=lazy_mix2.out(0)) + + # Run with partial execution targeting only output1 + result = client.run(g, partial_execution_targets=[output1.id]) + + # For output1 path - only input1 should run due to lazy evaluation (mask=0.0) + assert result.was_executed(input1), "Input1 should have been executed" + assert not result.did_run(input2), "Input2 should not have run (lazy evaluation)" + assert result.was_executed(mask1), "Mask1 should have been executed" + assert result.was_executed(lazy_mix1), "Lazy mix1 should have been executed" + assert result.was_executed(output1), "Output1 should have been executed" + + # Nothing from output2 path should run + assert not result.did_run(input3), "Input3 should not have run" + assert not result.did_run(mask2), "Mask2 should not have run" + assert not result.did_run(lazy_mix2), "Lazy mix2 should not have run" + assert not result.did_run(output2), "Output2 should not have run" + + # Multiple OUTPUT_NODEs with dependencies + def test_partial_execution_multiple_output_nodes(self, client: ComfyClient, builder: GraphBuilder): + g = builder + input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) + input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1) + + # Create a chain of OUTPUT_NODEs + output_node1 = g.node("TestOutputNodeWithSocketOutput", image=input1.out(0), value=1.5) + output_node2 = g.node("TestOutputNodeWithSocketOutput", image=output_node1.out(0), value=2.0) + + # Create regular output nodes + save1 = g.node("SaveImage", images=output_node1.out(0)) + save2 = g.node("SaveImage", images=output_node2.out(0)) + save3 = g.node("SaveImage", images=input2.out(0)) + + # Run targeting only save2 + result = client.run(g, partial_execution_targets=[save2.id]) + + # Should run: input1, output_node1, output_node2, save2 + assert result.was_executed(input1), "Input1 should have been executed" + assert result.was_executed(output_node1), "Output node 1 should have been executed (dependency)" + assert result.was_executed(output_node2), "Output node 2 should have been executed (dependency)" + assert result.was_executed(save2), "Save2 should have been executed" + + # Should NOT run: input2, save1, save3 + assert not result.did_run(input2), "Input2 should not have run" + assert not result.did_run(save1), "Save1 should not have run" + assert not result.did_run(save3), "Save3 should not have run" + + # Empty partial execution list (should execute nothing) + def test_partial_execution_empty_list(self, client: ComfyClient, builder: GraphBuilder): + g = builder + input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) + _output1 = g.node("SaveImage", images=input1.out(0)) + + # Run with empty partial execution list + try: + _result = client.run(g, partial_execution_targets=[]) + # Should get an error because no outputs are selected + assert False, "Should have raised an error for empty partial execution list" + except urllib.error.HTTPError: + pass # Expected behavior + diff --git a/tests/inference/testing_nodes/testing-pack/specific_tests.py b/tests/inference/testing_nodes/testing-pack/specific_tests.py index 657d49f2f..4f8f01ae4 100644 --- a/tests/inference/testing_nodes/testing-pack/specific_tests.py +++ 
b/tests/inference/testing_nodes/testing-pack/specific_tests.py @@ -463,6 +463,25 @@ class TestParallelSleep(ComfyNodeABC): "expand": g.finalize(), } +class TestOutputNodeWithSocketOutput: + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "image": ("IMAGE",), + "value": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0}), + }, + } + RETURN_TYPES = ("IMAGE",) + FUNCTION = "process" + CATEGORY = "_for_testing" + OUTPUT_NODE = True + + def process(self, image, value): + # Apply value scaling and return both as output and socket + result = image * value + return (result,) + TEST_NODE_CLASS_MAPPINGS = { "TestLazyMixImages": TestLazyMixImages, "TestVariadicAverage": TestVariadicAverage, @@ -478,6 +497,7 @@ TEST_NODE_CLASS_MAPPINGS = { "TestSamplingInExpansion": TestSamplingInExpansion, "TestSleep": TestSleep, "TestParallelSleep": TestParallelSleep, + "TestOutputNodeWithSocketOutput": TestOutputNodeWithSocketOutput, } TEST_NODE_DISPLAY_NAME_MAPPINGS = { @@ -495,4 +515,5 @@ TEST_NODE_DISPLAY_NAME_MAPPINGS = { "TestSamplingInExpansion": "Sampling In Expansion", "TestSleep": "Test Sleep", "TestParallelSleep": "Test Parallel Sleep", + "TestOutputNodeWithSocketOutput": "Test Output Node With Socket Output", } From 97b8a2c26a335fe70ac6cfb44bf225454f51d700 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Thu, 31 Jul 2025 02:46:23 -0700 Subject: [PATCH 003/325] More accurate explanation of release process. (#9126) --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index befc4c006..2abd8e600 100644 --- a/README.md +++ b/README.md @@ -111,7 +111,7 @@ Workflow examples can be found on the [Examples page](https://comfyanonymous.git ## Release Process -ComfyUI follows a weekly release cycle every Friday, with three interconnected repositories: +ComfyUI follows a weekly release cycle targeting Friday but this regularly changes because of model releases or large changes to the codebase. There are three interconnected repositories: 1. 
**[ComfyUI Core](https://github.com/comfyanonymous/ComfyUI)** - Releases a new stable version (e.g., v0.7.0) From 4887743a2aef67e05909aeea61f6cdc93e269de3 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Thu, 31 Jul 2025 15:02:12 -0700 Subject: [PATCH 004/325] V3 Node Schema Definition - initial (#8656) --- comfy_api/internal/__init__.py | 143 +++ comfy_api/latest/__init__.py | 18 + comfy_api/latest/_io.py | 1618 ++++++++++++++++++++++++++++++++ comfy_api/latest/_resources.py | 72 ++ comfy_api/latest/_ui.py | 457 +++++++++ comfy_api/v0_0_2/__init__.py | 2 + comfy_execution/graph.py | 23 +- comfy_execution/graph_utils.py | 16 + execution.py | 140 ++- nodes.py | 37 +- server.py | 3 + 11 files changed, 2475 insertions(+), 54 deletions(-) create mode 100644 comfy_api/latest/_io.py create mode 100644 comfy_api/latest/_resources.py create mode 100644 comfy_api/latest/_ui.py diff --git a/comfy_api/internal/__init__.py b/comfy_api/internal/__init__.py index c00b1fdbb..4ca02e320 100644 --- a/comfy_api/internal/__init__.py +++ b/comfy_api/internal/__init__.py @@ -5,3 +5,146 @@ from .api_registry import ( register_versions as register_versions, get_all_versions as get_all_versions, ) + +import asyncio +from dataclasses import asdict +from typing import Callable, Optional + + +def first_real_override(cls: type, name: str, *, base: type=None) -> Optional[Callable]: + """Return the *callable* override of `name` visible on `cls`, or None if every + implementation up to (and including) `base` is the placeholder defined on `base`. + + If base is not provided, it will assume cls has a GET_BASE_CLASS + """ + if base is None: + if not hasattr(cls, "GET_BASE_CLASS"): + raise ValueError("base is required if cls does not have a GET_BASE_CLASS; is this a valid ComfyNode subclass?") + base = cls.GET_BASE_CLASS() + base_attr = getattr(base, name, None) + if base_attr is None: + return None + base_func = base_attr.__func__ + for c in cls.mro(): # NodeB, NodeA, ComfyNode, object … + if c is base: # reached the placeholder – we're done + break + if name in c.__dict__: # first class that *defines* the attr + func = getattr(c, name).__func__ + if func is not base_func: # real override + return getattr(cls, name) # bound to *cls* + return None + + +class _ComfyNodeInternal: + """Class that all V3-based APIs inherit from for ComfyNode. + + This is intended to only be referenced within execution.py, as it has to handle all V3 APIs going forward.""" + @classmethod + def GET_NODE_INFO_V1(cls): + ... + + +class _NodeOutputInternal: + """Class that all V3-based APIs inherit from for NodeOutput. + + This is intended to only be referenced within execution.py, as it has to handle all V3 APIs going forward.""" + ... + + +def as_pruned_dict(dataclass_obj): + '''Return dict of dataclass object with pruned None values.''' + return prune_dict(asdict(dataclass_obj)) + +def prune_dict(d: dict): + return {k: v for k,v in d.items() if v is not None} + + +def is_class(obj): + ''' + Returns True if is a class type. + Returns False if is a class instance. + ''' + return isinstance(obj, type) + + +def copy_class(cls: type) -> type: + ''' + Copy a class and its attributes. 
+ ''' + if cls is None: + return None + cls_dict = { + k: v for k, v in cls.__dict__.items() + if k not in ('__dict__', '__weakref__', '__module__', '__doc__') + } + # new class + new_cls = type( + cls.__name__, + (cls,), + cls_dict + ) + # metadata preservation + new_cls.__module__ = cls.__module__ + new_cls.__doc__ = cls.__doc__ + return new_cls + + +class classproperty(object): + def __init__(self, f): + self.f = f + def __get__(self, obj, owner): + return self.f(owner) + + +# NOTE: this was ai generated and validated by hand +def shallow_clone_class(cls, new_name=None): + ''' + Shallow clone a class while preserving super() functionality. + ''' + new_name = new_name or f"{cls.__name__}Clone" + # Include the original class in the bases to maintain proper inheritance + new_bases = (cls,) + cls.__bases__ + return type(new_name, new_bases, dict(cls.__dict__)) + +# NOTE: this was ai generated and validated by hand +def lock_class(cls): + ''' + Lock a class so that its top-levelattributes cannot be modified. + ''' + # Locked instance __setattr__ + def locked_instance_setattr(self, name, value): + raise AttributeError( + f"Cannot set attribute '{name}' on immutable instance of {type(self).__name__}" + ) + # Locked metaclass + class LockedMeta(type(cls)): + def __setattr__(cls_, name, value): + raise AttributeError( + f"Cannot modify class attribute '{name}' on locked class '{cls_.__name__}'" + ) + # Rebuild class with locked behavior + locked_dict = dict(cls.__dict__) + locked_dict['__setattr__'] = locked_instance_setattr + + return LockedMeta(cls.__name__, cls.__bases__, locked_dict) + + +def make_locked_method_func(type_obj, func, class_clone): + """ + Returns a function that, when called with **inputs, will execute: + getattr(type_obj, func).__func__(lock_class(class_clone), **inputs) + + Supports both synchronous and asynchronous methods. + """ + locked_class = lock_class(class_clone) + method = getattr(type_obj, func).__func__ + + # Check if the original method is async + if asyncio.iscoroutinefunction(method): + async def wrapped_async_func(**inputs): + return await method(locked_class, **inputs) + return wrapped_async_func + else: + def wrapped_func(**inputs): + return method(locked_class, **inputs) + return wrapped_func diff --git a/comfy_api/latest/__init__.py b/comfy_api/latest/__init__.py index e1f3a3655..2cee65aa9 100644 --- a/comfy_api/latest/__init__.py +++ b/comfy_api/latest/__init__.py @@ -1,5 +1,6 @@ from __future__ import annotations +from abc import ABC, abstractmethod from typing import Type, TYPE_CHECKING from comfy_api.internal import ComfyAPIBase from comfy_api.internal.singleton import ProxiedSingleton @@ -7,6 +8,9 @@ from comfy_api.internal.async_to_sync import create_sync_class from comfy_api.latest._input import ImageInput, AudioInput, MaskInput, LatentInput, VideoInput from comfy_api.latest._input_impl import VideoFromFile, VideoFromComponents from comfy_api.latest._util import VideoCodec, VideoContainer, VideoComponents +from comfy_api.latest._io import _IO as io #noqa: F401 +from comfy_api.latest._ui import _UI as ui #noqa: F401 +# from comfy_api.latest._resources import _RESOURCES as resources #noqa: F401 from comfy_execution.utils import get_executing_context from comfy_execution.progress import get_progress_state, PreviewImageTuple from PIL import Image @@ -72,6 +76,19 @@ class ComfyAPI_latest(ComfyAPIBase): execution: Execution +class ComfyExtension(ABC): + async def on_load(self) -> None: + """ + Called when an extension is loaded. 
+ This should be used to initialize any global resources neeeded by the extension. + """ + + @abstractmethod + async def get_node_list(self) -> list[type[io.ComfyNode]]: + """ + Returns a list of nodes that this extension provides. + """ + class Input: Image = ImageInput Audio = AudioInput @@ -103,4 +120,5 @@ __all__ = [ "Input", "InputImpl", "Types", + "ComfyExtension", ] diff --git a/comfy_api/latest/_io.py b/comfy_api/latest/_io.py new file mode 100644 index 000000000..ec1efb51d --- /dev/null +++ b/comfy_api/latest/_io.py @@ -0,0 +1,1618 @@ +from __future__ import annotations + +import copy +import inspect +from abc import ABC, abstractmethod +from collections import Counter +from dataclasses import asdict, dataclass +from enum import Enum +from typing import Any, Callable, Literal, TypedDict, TypeVar, TYPE_CHECKING +from typing_extensions import NotRequired, final + +# used for type hinting +import torch + +if TYPE_CHECKING: + from spandrel import ImageModelDescriptor + from comfy.clip_vision import ClipVisionModel + from comfy.clip_vision import Output as ClipVisionOutput_ + from comfy.controlnet import ControlNet + from comfy.hooks import HookGroup, HookKeyframeGroup + from comfy.model_patcher import ModelPatcher + from comfy.samplers import CFGGuider, Sampler + from comfy.sd import CLIP, VAE + from comfy.sd import StyleModel as StyleModel_ + from comfy_api.input import VideoInput +from comfy_api.internal import (_ComfyNodeInternal, _NodeOutputInternal, classproperty, copy_class, first_real_override, is_class, + prune_dict, shallow_clone_class) +from comfy_api.latest._resources import Resources, ResourcesLocal +from comfy_execution.graph_utils import ExecutionBlocker + +# from comfy_extras.nodes_images import SVG as SVG_ # NOTE: needs to be moved before can be imported due to circular reference + +class FolderType(str, Enum): + input = "input" + output = "output" + temp = "temp" + + +class UploadType(str, Enum): + image = "image_upload" + audio = "audio_upload" + video = "video_upload" + model = "file_upload" + + +class RemoteOptions: + def __init__(self, route: str, refresh_button: bool, control_after_refresh: Literal["first", "last"]="first", + timeout: int=None, max_retries: int=None, refresh: int=None): + self.route = route + """The route to the remote source.""" + self.refresh_button = refresh_button + """Specifies whether to show a refresh button in the UI below the widget.""" + self.control_after_refresh = control_after_refresh + """Specifies the control after the refresh button is clicked. If "first", the first item will be automatically selected, and so on.""" + self.timeout = timeout + """The maximum amount of time to wait for a response from the remote source in milliseconds.""" + self.max_retries = max_retries + """The maximum number of retries before aborting the request.""" + self.refresh = refresh + """The TTL of the remote input's value in milliseconds. 
Specifies the interval at which the remote input's value is refreshed.""" + + def as_dict(self): + return prune_dict({ + "route": self.route, + "refresh_button": self.refresh_button, + "control_after_refresh": self.control_after_refresh, + "timeout": self.timeout, + "max_retries": self.max_retries, + "refresh": self.refresh, + }) + + +class NumberDisplay(str, Enum): + number = "number" + slider = "slider" + + +class _StringIOType(str): + def __ne__(self, value: object) -> bool: + if self == "*" or value == "*": + return False + if not isinstance(value, str): + return True + a = frozenset(self.split(",")) + b = frozenset(value.split(",")) + return not (b.issubset(a) or a.issubset(b)) + +class _ComfyType(ABC): + Type = Any + io_type: str = None + +# NOTE: this is a workaround to make the decorator return the correct type +T = TypeVar("T", bound=type) +def comfytype(io_type: str, **kwargs): + ''' + Decorator to mark nested classes as ComfyType; io_type will be bound to the class. + + A ComfyType may have the following attributes: + - Type = + - class Input(Input): ... + - class Output(Output): ... + ''' + def decorator(cls: T) -> T: + if isinstance(cls, _ComfyType) or issubclass(cls, _ComfyType): + # clone Input and Output classes to avoid modifying the original class + new_cls = cls + if hasattr(new_cls, "Input"): + new_cls.Input = copy_class(new_cls.Input) + if hasattr(new_cls, "Output"): + new_cls.Output = copy_class(new_cls.Output) + else: + # copy class attributes except for special ones that shouldn't be in type() + cls_dict = { + k: v for k, v in cls.__dict__.items() + if k not in ('__dict__', '__weakref__', '__module__', '__doc__') + } + # new class + new_cls: ComfyTypeIO = type( + cls.__name__, + (cls, ComfyTypeIO), + cls_dict + ) + # metadata preservation + new_cls.__module__ = cls.__module__ + new_cls.__doc__ = cls.__doc__ + # assign ComfyType attributes, if needed + # NOTE: use __ne__ trick for io_type (see node_typing.IO.__ne__ for details) + new_cls.io_type = _StringIOType(io_type) + if hasattr(new_cls, "Input") and new_cls.Input is not None: + new_cls.Input.Parent = new_cls + if hasattr(new_cls, "Output") and new_cls.Output is not None: + new_cls.Output.Parent = new_cls + return new_cls + return decorator + +def Custom(io_type: str) -> type[ComfyTypeIO]: + '''Create a ComfyType for a custom io_type.''' + @comfytype(io_type=io_type) + class CustomComfyType(ComfyTypeIO): + ... + return CustomComfyType + +class _IO_V3: + ''' + Base class for V3 Inputs and Outputs. + ''' + Parent: _ComfyType = None + + def __init__(self): + pass + + @property + def io_type(self): + return self.Parent.io_type + + @property + def Type(self): + return self.Parent.Type + +class Input(_IO_V3): + ''' + Base class for a V3 Input. + ''' + def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None): + super().__init__() + self.id = id + self.display_name = display_name + self.optional = optional + self.tooltip = tooltip + self.lazy = lazy + self.extra_dict = extra_dict if extra_dict is not None else {} + + def as_dict(self): + return prune_dict({ + "display_name": self.display_name, + "optional": self.optional, + "tooltip": self.tooltip, + "lazy": self.lazy, + }) | prune_dict(self.extra_dict) + + def get_io_type(self): + return _StringIOType(self.io_type) + +class WidgetInput(Input): + ''' + Base class for a V3 Input with widget. 
+ ''' + def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, + default: Any=None, + socketless: bool=None, widget_type: str=None, force_input: bool=None, extra_dict=None): + super().__init__(id, display_name, optional, tooltip, lazy, extra_dict) + self.default = default + self.socketless = socketless + self.widget_type = widget_type + self.force_input = force_input + + def as_dict(self): + return super().as_dict() | prune_dict({ + "default": self.default, + "socketless": self.socketless, + "widgetType": self.widget_type, + "forceInput": self.force_input, + }) + + def get_io_type(self): + return self.widget_type if self.widget_type is not None else super().get_io_type() + + +class Output(_IO_V3): + def __init__(self, id: str=None, display_name: str=None, tooltip: str=None, + is_output_list=False): + self.id = id + self.display_name = display_name + self.tooltip = tooltip + self.is_output_list = is_output_list + + def as_dict(self): + return prune_dict({ + "display_name": self.display_name, + "tooltip": self.tooltip, + "is_output_list": self.is_output_list, + }) + + def get_io_type(self): + return self.io_type + + +class ComfyTypeI(_ComfyType): + '''ComfyType subclass that only has a default Input class - intended for types that only have Inputs.''' + class Input(Input): + ... + +class ComfyTypeIO(ComfyTypeI): + '''ComfyType subclass that has default Input and Output classes; useful for types with both Inputs and Outputs.''' + class Output(Output): + ... + + +@comfytype(io_type="BOOLEAN") +class Boolean(ComfyTypeIO): + Type = bool + + class Input(WidgetInput): + '''Boolean input.''' + def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, + default: bool=None, label_on: str=None, label_off: str=None, + socketless: bool=None, force_input: bool=None): + super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input) + self.label_on = label_on + self.label_off = label_off + self.default: bool + + def as_dict(self): + return super().as_dict() | prune_dict({ + "label_on": self.label_on, + "label_off": self.label_off, + }) + +@comfytype(io_type="INT") +class Int(ComfyTypeIO): + Type = int + + class Input(WidgetInput): + '''Integer input.''' + def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, + default: int=None, min: int=None, max: int=None, step: int=None, control_after_generate: bool=None, + display_mode: NumberDisplay=None, socketless: bool=None, force_input: bool=None): + super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input) + self.min = min + self.max = max + self.step = step + self.control_after_generate = control_after_generate + self.display_mode = display_mode + self.default: int + + def as_dict(self): + return super().as_dict() | prune_dict({ + "min": self.min, + "max": self.max, + "step": self.step, + "control_after_generate": self.control_after_generate, + "display": self.display_mode.value if self.display_mode else None, + }) + +@comfytype(io_type="FLOAT") +class Float(ComfyTypeIO): + Type = float + + class Input(WidgetInput): + '''Float input.''' + def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, + default: float=None, min: float=None, max: float=None, step: float=None, round: float=None, + display_mode: NumberDisplay=None, socketless: bool=None, force_input: bool=None): + super().__init__(id, 
display_name, optional, tooltip, lazy, default, socketless, None, force_input) + self.min = min + self.max = max + self.step = step + self.round = round + self.display_mode = display_mode + self.default: float + + def as_dict(self): + return super().as_dict() | prune_dict({ + "min": self.min, + "max": self.max, + "step": self.step, + "round": self.round, + "display": self.display_mode, + }) + +@comfytype(io_type="STRING") +class String(ComfyTypeIO): + Type = str + + class Input(WidgetInput): + '''String input.''' + def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, + multiline=False, placeholder: str=None, default: str=None, dynamic_prompts: bool=None, + socketless: bool=None, force_input: bool=None): + super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input) + self.multiline = multiline + self.placeholder = placeholder + self.dynamic_prompts = dynamic_prompts + self.default: str + + def as_dict(self): + return super().as_dict() | prune_dict({ + "multiline": self.multiline, + "placeholder": self.placeholder, + "dynamicPrompts": self.dynamic_prompts, + }) + +@comfytype(io_type="COMBO") +class Combo(ComfyTypeI): + Type = str + class Input(WidgetInput): + """Combo input (dropdown).""" + Type = str + def __init__(self, id: str, options: list[str]=None, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, + default: str=None, control_after_generate: bool=None, + upload: UploadType=None, image_folder: FolderType=None, + remote: RemoteOptions=None, + socketless: bool=None): + super().__init__(id, display_name, optional, tooltip, lazy, default, socketless) + self.multiselect = False + self.options = options + self.control_after_generate = control_after_generate + self.upload = upload + self.image_folder = image_folder + self.remote = remote + self.default: str + + def as_dict(self): + return super().as_dict() | prune_dict({ + "multiselect": self.multiselect, + "options": self.options, + "control_after_generate": self.control_after_generate, + **({self.upload.value: True} if self.upload is not None else {}), + "image_folder": self.image_folder.value if self.image_folder else None, + "remote": self.remote.as_dict() if self.remote else None, + }) + + +@comfytype(io_type="COMBO") +class MultiCombo(ComfyTypeI): + '''Multiselect Combo input (dropdown for selecting potentially more than one value).''' + # TODO: something is wrong with the serialization, frontend does not recognize it as multiselect + Type = list[str] + class Input(Combo.Input): + def __init__(self, id: str, options: list[str], display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, + default: list[str]=None, placeholder: str=None, chip: bool=None, control_after_generate: bool=None, + socketless: bool=None): + super().__init__(id, options, display_name, optional, tooltip, lazy, default, control_after_generate, socketless=socketless) + self.multiselect = True + self.placeholder = placeholder + self.chip = chip + self.default: list[str] + + def as_dict(self): + to_return = super().as_dict() | prune_dict({ + "multi_select": self.multiselect, + "placeholder": self.placeholder, + "chip": self.chip, + }) + return to_return + +@comfytype(io_type="IMAGE") +class Image(ComfyTypeIO): + Type = torch.Tensor + + +@comfytype(io_type="WAN_CAMERA_EMBEDDING") +class WanCameraEmbedding(ComfyTypeIO): + Type = torch.Tensor + + +@comfytype(io_type="WEBCAM") +class Webcam(ComfyTypeIO): + Type = str + + class 
Input(WidgetInput): + """Webcam input.""" + Type = str + def __init__( + self, id: str, display_name: str=None, optional=False, + tooltip: str=None, lazy: bool=None, default: str=None, socketless: bool=None + ): + super().__init__(id, display_name, optional, tooltip, lazy, default, socketless) + + +@comfytype(io_type="MASK") +class Mask(ComfyTypeIO): + Type = torch.Tensor + +@comfytype(io_type="LATENT") +class Latent(ComfyTypeIO): + '''Latents are stored as a dictionary.''' + class LatentDict(TypedDict): + samples: torch.Tensor + '''Latent tensors.''' + noise_mask: NotRequired[torch.Tensor] + batch_index: NotRequired[list[int]] + type: NotRequired[str] + '''Only needed if dealing with these types: audio, hunyuan3dv2''' + Type = LatentDict + +@comfytype(io_type="CONDITIONING") +class Conditioning(ComfyTypeIO): + class PooledDict(TypedDict): + pooled_output: torch.Tensor + '''Pooled output from CLIP.''' + control: NotRequired[ControlNet] + '''ControlNet to apply to conditioning.''' + control_apply_to_uncond: NotRequired[bool] + '''Whether to apply ControlNet to matching negative conditioning at sample time, if applicable.''' + cross_attn_controlnet: NotRequired[torch.Tensor] + '''CrossAttn from CLIP to use for controlnet only.''' + pooled_output_controlnet: NotRequired[torch.Tensor] + '''Pooled output from CLIP to use for controlnet only.''' + gligen: NotRequired[tuple[str, Gligen, list[tuple[torch.Tensor, int, ...]]]] + '''GLIGEN to apply to conditioning.''' + area: NotRequired[tuple[int, ...] | tuple[str, float, ...]] + '''Set area of conditioning. First half of values apply to dimensions, the second half apply to coordinates. + By default, the dimensions are based on total pixel amount, but the first value can be set to "percentage" to use a percentage of the image size instead. + + (1024, 1024, 0, 0) would apply conditioning to the top-left 1024x1024 pixels. + + ("percentage", 0.5, 0.5, 0, 0) would apply conditioning to the top-left 50% of the image.''' # TODO: verify its actually top-left + strength: NotRequired[float] + '''Strength of conditioning. Default strength is 1.0.''' + mask: NotRequired[torch.Tensor] + '''Mask to apply conditioning to.''' + mask_strength: NotRequired[float] + '''Strength of conditioning mask. 
Default strength is 1.0.''' + set_area_to_bounds: NotRequired[bool] + '''Whether conditioning mask should determine bounds of area - if set to false, latents are sampled at full resolution and result is applied in mask.''' + concat_latent_image: NotRequired[torch.Tensor] + '''Used for inpainting and specific models.''' + concat_mask: NotRequired[torch.Tensor] + '''Used for inpainting and specific models.''' + concat_image: NotRequired[torch.Tensor] + '''Used by SD_4XUpscale_Conditioning.''' + noise_augmentation: NotRequired[float] + '''Used by SD_4XUpscale_Conditioning.''' + hooks: NotRequired[HookGroup] + '''Applies hooks to conditioning.''' + default: NotRequired[bool] + '''Whether to this conditioning is 'default'; default conditioning gets applied to any areas of the image that have no masks/areas applied, assuming at least one area/mask is present during sampling.''' + start_percent: NotRequired[float] + '''Determines relative step to begin applying conditioning, expressed as a float between 0.0 and 1.0.''' + end_percent: NotRequired[float] + '''Determines relative step to end applying conditioning, expressed as a float between 0.0 and 1.0.''' + clip_start_percent: NotRequired[float] + '''Internal variable for conditioning scheduling - start of application, expressed as a float between 0.0 and 1.0.''' + clip_end_percent: NotRequired[float] + '''Internal variable for conditioning scheduling - end of application, expressed as a float between 0.0 and 1.0.''' + attention_mask: NotRequired[torch.Tensor] + '''Masks text conditioning; used by StyleModel among others.''' + attention_mask_img_shape: NotRequired[tuple[int, ...]] + '''Masks text conditioning; used by StyleModel among others.''' + unclip_conditioning: NotRequired[list[dict]] + '''Used by unCLIP.''' + conditioning_lyrics: NotRequired[torch.Tensor] + '''Used by AceT5Model.''' + seconds_start: NotRequired[float] + '''Used by StableAudio.''' + seconds_total: NotRequired[float] + '''Used by StableAudio.''' + lyrics_strength: NotRequired[float] + '''Used by AceStepAudio.''' + width: NotRequired[int] + '''Used by certain models (e.g. CLIPTextEncodeSDXL/Refiner, PixArtAlpha).''' + height: NotRequired[int] + '''Used by certain models (e.g. 
CLIPTextEncodeSDXL/Refiner, PixArtAlpha).''' + aesthetic_score: NotRequired[float] + '''Used by CLIPTextEncodeSDXL/Refiner.''' + crop_w: NotRequired[int] + '''Used by CLIPTextEncodeSDXL.''' + crop_h: NotRequired[int] + '''Used by CLIPTextEncodeSDXL.''' + target_width: NotRequired[int] + '''Used by CLIPTextEncodeSDXL.''' + target_height: NotRequired[int] + '''Used by CLIPTextEncodeSDXL.''' + reference_latents: NotRequired[list[torch.Tensor]] + '''Used by ReferenceLatent.''' + guidance: NotRequired[float] + '''Used by Flux-like models with guidance embed.''' + guiding_frame_index: NotRequired[int] + '''Used by Hunyuan ImageToVideo.''' + ref_latent: NotRequired[torch.Tensor] + '''Used by Hunyuan ImageToVideo.''' + keyframe_idxs: NotRequired[list[int]] + '''Used by LTXV.''' + frame_rate: NotRequired[float] + '''Used by LTXV.''' + stable_cascade_prior: NotRequired[torch.Tensor] + '''Used by StableCascade.''' + elevation: NotRequired[list[float]] + '''Used by SV3D.''' + azimuth: NotRequired[list[float]] + '''Used by SV3D.''' + motion_bucket_id: NotRequired[int] + '''Used by SVD-like models.''' + fps: NotRequired[int] + '''Used by SVD-like models.''' + augmentation_level: NotRequired[float] + '''Used by SVD-like models.''' + clip_vision_output: NotRequired[ClipVisionOutput_] + '''Used by WAN-like models.''' + vace_frames: NotRequired[torch.Tensor] + '''Used by WAN VACE.''' + vace_mask: NotRequired[torch.Tensor] + '''Used by WAN VACE.''' + vace_strength: NotRequired[float] + '''Used by WAN VACE.''' + camera_conditions: NotRequired[Any] # TODO: assign proper type once defined + '''Used by WAN Camera.''' + time_dim_concat: NotRequired[torch.Tensor] + '''Used by WAN Phantom Subject.''' + + CondList = list[tuple[torch.Tensor, PooledDict]] + Type = CondList + +@comfytype(io_type="SAMPLER") +class Sampler(ComfyTypeIO): + if TYPE_CHECKING: + Type = Sampler + +@comfytype(io_type="SIGMAS") +class Sigmas(ComfyTypeIO): + Type = torch.Tensor + +@comfytype(io_type="NOISE") +class Noise(ComfyTypeIO): + Type = torch.Tensor + +@comfytype(io_type="GUIDER") +class Guider(ComfyTypeIO): + if TYPE_CHECKING: + Type = CFGGuider + +@comfytype(io_type="CLIP") +class Clip(ComfyTypeIO): + if TYPE_CHECKING: + Type = CLIP + +@comfytype(io_type="CONTROL_NET") +class ControlNet(ComfyTypeIO): + if TYPE_CHECKING: + Type = ControlNet + +@comfytype(io_type="VAE") +class Vae(ComfyTypeIO): + if TYPE_CHECKING: + Type = VAE + +@comfytype(io_type="MODEL") +class Model(ComfyTypeIO): + if TYPE_CHECKING: + Type = ModelPatcher + +@comfytype(io_type="CLIP_VISION") +class ClipVision(ComfyTypeIO): + if TYPE_CHECKING: + Type = ClipVisionModel + +@comfytype(io_type="CLIP_VISION_OUTPUT") +class ClipVisionOutput(ComfyTypeIO): + if TYPE_CHECKING: + Type = ClipVisionOutput_ + +@comfytype(io_type="STYLE_MODEL") +class StyleModel(ComfyTypeIO): + if TYPE_CHECKING: + Type = StyleModel_ + +@comfytype(io_type="GLIGEN") +class Gligen(ComfyTypeIO): + '''ModelPatcher that wraps around a 'Gligen' model.''' + if TYPE_CHECKING: + Type = ModelPatcher + +@comfytype(io_type="UPSCALE_MODEL") +class UpscaleModel(ComfyTypeIO): + if TYPE_CHECKING: + Type = ImageModelDescriptor + +@comfytype(io_type="AUDIO") +class Audio(ComfyTypeIO): + class AudioDict(TypedDict): + waveform: torch.Tensor + sampler_rate: int + Type = AudioDict + +@comfytype(io_type="VIDEO") +class Video(ComfyTypeIO): + if TYPE_CHECKING: + Type = VideoInput + +@comfytype(io_type="SVG") +class SVG(ComfyTypeIO): + Type = Any # TODO: SVG class is defined in comfy_extras/nodes_images.py, causing circular 
reference; should be moved to somewhere else before referenced directly in v3 + +@comfytype(io_type="LORA_MODEL") +class LoraModel(ComfyTypeIO): + Type = dict[str, torch.Tensor] + +@comfytype(io_type="LOSS_MAP") +class LossMap(ComfyTypeIO): + class LossMapDict(TypedDict): + loss: list[torch.Tensor] + Type = LossMapDict + +@comfytype(io_type="VOXEL") +class Voxel(ComfyTypeIO): + Type = Any # TODO: VOXEL class is defined in comfy_extras/nodes_hunyuan3d.py; should be moved to somewhere else before referenced directly in v3 + +@comfytype(io_type="MESH") +class Mesh(ComfyTypeIO): + Type = Any # TODO: MESH class is defined in comfy_extras/nodes_hunyuan3d.py; should be moved to somewhere else before referenced directly in v3 + +@comfytype(io_type="HOOKS") +class Hooks(ComfyTypeIO): + if TYPE_CHECKING: + Type = HookGroup + +@comfytype(io_type="HOOK_KEYFRAMES") +class HookKeyframes(ComfyTypeIO): + if TYPE_CHECKING: + Type = HookKeyframeGroup + +@comfytype(io_type="TIMESTEPS_RANGE") +class TimestepsRange(ComfyTypeIO): + '''Range defined by start and endpoint, between 0.0 and 1.0.''' + Type = tuple[int, int] + +@comfytype(io_type="LATENT_OPERATION") +class LatentOperation(ComfyTypeIO): + Type = Callable[[torch.Tensor], torch.Tensor] + +@comfytype(io_type="FLOW_CONTROL") +class FlowControl(ComfyTypeIO): + # NOTE: only used in testing_nodes right now + Type = tuple[str, Any] + +@comfytype(io_type="ACCUMULATION") +class Accumulation(ComfyTypeIO): + # NOTE: only used in testing_nodes right now + class AccumulationDict(TypedDict): + accum: list[Any] + Type = AccumulationDict + + +@comfytype(io_type="LOAD3D_CAMERA") +class Load3DCamera(ComfyTypeIO): + class CameraInfo(TypedDict): + position: dict[str, float | int] + target: dict[str, float | int] + zoom: int + cameraType: str + + Type = CameraInfo + + +@comfytype(io_type="LOAD_3D") +class Load3D(ComfyTypeIO): + """3D models are stored as a dictionary.""" + class Model3DDict(TypedDict): + image: str + mask: str + normal: str + camera_info: Load3DCamera.CameraInfo + recording: NotRequired[str] + + Type = Model3DDict + + +@comfytype(io_type="LOAD_3D_ANIMATION") +class Load3DAnimation(Load3D): + ... + + +@comfytype(io_type="PHOTOMAKER") +class Photomaker(ComfyTypeIO): + Type = Any + + +@comfytype(io_type="POINT") +class Point(ComfyTypeIO): + Type = Any # NOTE: I couldn't find any references in core code to POINT io_type. Does this exist? + +@comfytype(io_type="FACE_ANALYSIS") +class FaceAnalysis(ComfyTypeIO): + Type = Any # NOTE: I couldn't find any references in core code to POINT io_type. Does this exist? + +@comfytype(io_type="BBOX") +class BBOX(ComfyTypeIO): + Type = Any # NOTE: I couldn't find any references in core code to POINT io_type. Does this exist? + +@comfytype(io_type="SEGS") +class SEGS(ComfyTypeIO): + Type = Any # NOTE: I couldn't find any references in core code to POINT io_type. Does this exist? + +@comfytype(io_type="*") +class AnyType(ComfyTypeIO): + Type = Any + +@comfytype(io_type="COMFY_MULTITYPED_V3") +class MultiType: + Type = Any + class Input(Input): + ''' + Input that permits more than one input type; if `id` is an instance of `ComfyType.Input`, then that input will be used to create a widget (if applicable) with overridden values. 
+ ''' + def __init__(self, id: str | Input, types: list[type[_ComfyType] | _ComfyType], display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None): + # if id is an Input, then use that Input with overridden values + self.input_override = None + if isinstance(id, Input): + self.input_override = copy.copy(id) + optional = id.optional if id.optional is True else optional + tooltip = id.tooltip if id.tooltip is not None else tooltip + display_name = id.display_name if id.display_name is not None else display_name + lazy = id.lazy if id.lazy is not None else lazy + id = id.id + # if is a widget input, make sure widget_type is set appropriately + if isinstance(self.input_override, WidgetInput): + self.input_override.widget_type = self.input_override.get_io_type() + super().__init__(id, display_name, optional, tooltip, lazy, extra_dict) + self._io_types = types + + @property + def io_types(self) -> list[type[Input]]: + ''' + Returns list of Input class types permitted. + ''' + io_types = [] + for x in self._io_types: + if not is_class(x): + io_types.append(type(x)) + else: + io_types.append(x) + return io_types + + def get_io_type(self): + # ensure types are unique and order is preserved + str_types = [x.io_type for x in self.io_types] + if self.input_override is not None: + str_types.insert(0, self.input_override.get_io_type()) + return ",".join(list(dict.fromkeys(str_types))) + + def as_dict(self): + if self.input_override is not None: + return self.input_override.as_dict() | super().as_dict() + else: + return super().as_dict() + +class DynamicInput(Input, ABC): + ''' + Abstract class for dynamic input registration. + ''' + @abstractmethod + def get_dynamic(self) -> list[Input]: + ... + +class DynamicOutput(Output, ABC): + ''' + Abstract class for dynamic output registration. + ''' + def __init__(self, id: str=None, display_name: str=None, tooltip: str=None, + is_output_list=False): + super().__init__(id, display_name, tooltip, is_output_list) + + @abstractmethod + def get_dynamic(self) -> list[Output]: + ... 
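A minimal, hypothetical sketch of how the widget, type, and schema classes defined in this file might be combined when declaring node inputs. It assumes the `io` namespace added to `comfy_api/latest/__init__.py` (`from comfy_api.latest._io import _IO as io`) re-exports the `Int`, `Image`, `Mask`, `MultiType`, and `Schema` names below; the node id and values are made up for illustration and are not part of this patch.

```python
from comfy_api.latest import io  # assumed public accessor for the classes defined in this file

# An input that accepts IMAGE or MASK connections in addition to the wrapped
# Int.Input, which supplies the widget settings (default/min/max).
value_input = io.MultiType.Input(
    io.Int.Input("value", default=1, min=0, max=10),
    types=[io.Image, io.Mask],
)

# A schema wiring that input to a single IMAGE output.
example_schema = io.Schema(
    node_id="ExampleMultiTypeNode",  # hypothetical id, for illustration only
    display_name="Example MultiType Node",
    category="_for_testing",
    inputs=[value_input],
    outputs=[io.Image.Output()],
)
example_schema.validate()  # raises ValueError if input/output ids collide
```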
+ + +@comfytype(io_type="COMFY_AUTOGROW_V3") +class AutogrowDynamic(ComfyTypeI): + Type = list[Any] + class Input(DynamicInput): + def __init__(self, id: str, template_input: Input, min: int=1, max: int=None, + display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None): + super().__init__(id, display_name, optional, tooltip, lazy, extra_dict) + self.template_input = template_input + if min is not None: + assert(min >= 1) + if max is not None: + assert(max >= 1) + self.min = min + self.max = max + + def get_dynamic(self) -> list[Input]: + curr_count = 1 + new_inputs = [] + for i in range(self.min): + new_input = copy.copy(self.template_input) + new_input.id = f"{new_input.id}{curr_count}_${self.id}_ag$" + if new_input.display_name is not None: + new_input.display_name = f"{new_input.display_name}{curr_count}" + new_input.optional = self.optional or new_input.optional + if isinstance(self.template_input, WidgetInput): + new_input.force_input = True + new_inputs.append(new_input) + curr_count += 1 + # pretend to expand up to max + for i in range(curr_count-1, self.max): + new_input = copy.copy(self.template_input) + new_input.id = f"{new_input.id}{curr_count}_${self.id}_ag$" + if new_input.display_name is not None: + new_input.display_name = f"{new_input.display_name}{curr_count}" + new_input.optional = True + if isinstance(self.template_input, WidgetInput): + new_input.force_input = True + new_inputs.append(new_input) + curr_count += 1 + return new_inputs + +@comfytype(io_type="COMFY_COMBODYNAMIC_V3") +class ComboDynamic(ComfyTypeI): + class Input(DynamicInput): + def __init__(self, id: str): + pass + +@comfytype(io_type="COMFY_MATCHTYPE_V3") +class MatchType(ComfyTypeIO): + class Template: + def __init__(self, template_id: str, allowed_types: _ComfyType | list[_ComfyType]): + self.template_id = template_id + self.allowed_types = [allowed_types] if isinstance(allowed_types, _ComfyType) else allowed_types + + def as_dict(self): + return { + "template_id": self.template_id, + "allowed_types": "".join(t.io_type for t in self.allowed_types), + } + + class Input(DynamicInput): + def __init__(self, id: str, template: MatchType.Template, + display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None): + super().__init__(id, display_name, optional, tooltip, lazy, extra_dict) + self.template = template + + def get_dynamic(self) -> list[Input]: + return [self] + + def as_dict(self): + return super().as_dict() | prune_dict({ + "template": self.template.as_dict(), + }) + + class Output(DynamicOutput): + def __init__(self, id: str, template: MatchType.Template, display_name: str=None, tooltip: str=None, + is_output_list=False): + super().__init__(id, display_name, tooltip, is_output_list) + self.template = template + + def get_dynamic(self) -> list[Output]: + return [self] + + def as_dict(self): + return super().as_dict() | prune_dict({ + "template": self.template.as_dict(), + }) + + +class HiddenHolder: + def __init__(self, unique_id: str, prompt: Any, + extra_pnginfo: Any, dynprompt: Any, + auth_token_comfy_org: str, api_key_comfy_org: str, **kwargs): + self.unique_id = unique_id + """UNIQUE_ID is the unique identifier of the node, and matches the id property of the node on the client side. It is commonly used in client-server communications (see messages).""" + self.prompt = prompt + """PROMPT is the complete prompt sent by the client to the server. 
See the prompt object for a full description.""" + self.extra_pnginfo = extra_pnginfo + """EXTRA_PNGINFO is a dictionary that will be copied into the metadata of any .png files saved. Custom nodes can store additional information in this dictionary for saving (or as a way to communicate with a downstream node).""" + self.dynprompt = dynprompt + """DYNPROMPT is an instance of comfy_execution.graph.DynamicPrompt. It differs from PROMPT in that it may mutate during the course of execution in response to Node Expansion.""" + self.auth_token_comfy_org = auth_token_comfy_org + """AUTH_TOKEN_COMFY_ORG is a token acquired from signing into a ComfyOrg account on frontend.""" + self.api_key_comfy_org = api_key_comfy_org + """API_KEY_COMFY_ORG is an API Key generated by ComfyOrg that allows skipping signing into a ComfyOrg account on frontend.""" + + def __getattr__(self, key: str): + '''If hidden variable not found, return None.''' + return None + + @classmethod + def from_dict(cls, d: dict | None): + if d is None: + d = {} + return cls( + unique_id=d.get(Hidden.unique_id, None), + prompt=d.get(Hidden.prompt, None), + extra_pnginfo=d.get(Hidden.extra_pnginfo, None), + dynprompt=d.get(Hidden.dynprompt, None), + auth_token_comfy_org=d.get(Hidden.auth_token_comfy_org, None), + api_key_comfy_org=d.get(Hidden.api_key_comfy_org, None), + ) + +class Hidden(str, Enum): + ''' + Enumerator for requesting hidden variables in nodes. + ''' + unique_id = "UNIQUE_ID" + """UNIQUE_ID is the unique identifier of the node, and matches the id property of the node on the client side. It is commonly used in client-server communications (see messages).""" + prompt = "PROMPT" + """PROMPT is the complete prompt sent by the client to the server. See the prompt object for a full description.""" + extra_pnginfo = "EXTRA_PNGINFO" + """EXTRA_PNGINFO is a dictionary that will be copied into the metadata of any .png files saved. Custom nodes can store additional information in this dictionary for saving (or as a way to communicate with a downstream node).""" + dynprompt = "DYNPROMPT" + """DYNPROMPT is an instance of comfy_execution.graph.DynamicPrompt. It differs from PROMPT in that it may mutate during the course of execution in response to Node Expansion.""" + auth_token_comfy_org = "AUTH_TOKEN_COMFY_ORG" + """AUTH_TOKEN_COMFY_ORG is a token acquired from signing into a ComfyOrg account on frontend.""" + api_key_comfy_org = "API_KEY_COMFY_ORG" + """API_KEY_COMFY_ORG is an API Key generated by ComfyOrg that allows skipping signing into a ComfyOrg account on frontend.""" + + +@dataclass +class NodeInfoV1: + input: dict=None + input_order: dict[str, list[str]]=None + output: list[str]=None + output_is_list: list[bool]=None + output_name: list[str]=None + output_tooltips: list[str]=None + name: str=None + display_name: str=None + description: str=None + python_module: Any=None + category: str=None + output_node: bool=None + deprecated: bool=None + experimental: bool=None + api_node: bool=None + +@dataclass +class NodeInfoV3: + input: dict=None + output: dict=None + hidden: list[str]=None + name: str=None + display_name: str=None + description: str=None + category: str=None + output_node: bool=None + deprecated: bool=None + experimental: bool=None + api_node: bool=None + + +@dataclass +class Schema: + """Definition of V3 node properties.""" + + node_id: str + """ID of node - should be globally unique. 
If this is a custom node, add a prefix or postfix to avoid name clashes.""" + display_name: str = None + """Display name of node.""" + category: str = "sd" + """The category of the node, as per the "Add Node" menu.""" + inputs: list[Input]=None + outputs: list[Output]=None + hidden: list[Hidden]=None + description: str="" + """Node description, shown as a tooltip when hovering over the node.""" + is_input_list: bool = False + """A flag indicating if this node implements the additional code necessary to deal with OUTPUT_IS_LIST nodes. + + All inputs of ``type`` will become ``list[type]``, regardless of how many items are passed in. This also affects ``check_lazy_status``. + + From the docs: + + A node can also override the default input behaviour and receive the whole list in a single call. This is done by setting a class attribute `INPUT_IS_LIST` to ``True``. + + Comfy Docs: https://docs.comfy.org/custom-nodes/backend/lists#list-processing + """ + is_output_node: bool=False + """Flags this node as an output node, causing any inputs it requires to be executed. + + If a node is not connected to any output nodes, that node will not be executed. Usage:: + + From the docs: + + By default, a node is not considered an output. Set ``OUTPUT_NODE = True`` to specify that it is. + + Comfy Docs: https://docs.comfy.org/custom-nodes/backend/server_overview#output-node + """ + is_deprecated: bool=False + """Flags a node as deprecated, indicating to users that they should find alternatives to this node.""" + is_experimental: bool=False + """Flags a node as experimental, informing users that it may change or not work as expected.""" + is_api_node: bool=False + """Flags a node as an API node. See: https://docs.comfy.org/tutorials/api-nodes/overview.""" + not_idempotent: bool=False + """Flags a node as not idempotent; when True, the node will run and not reuse the cached outputs when identical inputs are provided on a different node in the graph.""" + enable_expand: bool=False + """Flags a node as expandable, allowing NodeOutput to include 'expand' property.""" + + def validate(self): + '''Validate the schema: + - verify ids on inputs and outputs are unique - both internally and in relation to each other + ''' + input_ids = [i.id for i in self.inputs] if self.inputs is not None else [] + output_ids = [o.id for o in self.outputs] if self.outputs is not None else [] + input_set = set(input_ids) + output_set = set(output_ids) + issues = [] + # verify ids are unique per list + if len(input_set) != len(input_ids): + issues.append(f"Input ids must be unique, but {[item for item, count in Counter(input_ids).items() if count > 1]} are not.") + if len(output_set) != len(output_ids): + issues.append(f"Output ids must be unique, but {[item for item, count in Counter(output_ids).items() if count > 1]} are not.") + # verify ids are unique between lists + intersection = input_set & output_set + if len(intersection) > 0: + issues.append(f"Ids must be unique between inputs and outputs, but {intersection} are not.") + if len(issues) > 0: + raise ValueError("\n".join(issues)) + + def finalize(self): + """Add hidden based on selected schema options, and give outputs without ids default ids.""" + # if is an api_node, will need key-related hidden + if self.is_api_node: + if self.hidden is None: + self.hidden = [] + if Hidden.auth_token_comfy_org not in self.hidden: + self.hidden.append(Hidden.auth_token_comfy_org) + if Hidden.api_key_comfy_org not in self.hidden: + self.hidden.append(Hidden.api_key_comfy_org) + # if is an 
output_node, will need prompt and extra_pnginfo + if self.is_output_node: + if self.hidden is None: + self.hidden = [] + if Hidden.prompt not in self.hidden: + self.hidden.append(Hidden.prompt) + if Hidden.extra_pnginfo not in self.hidden: + self.hidden.append(Hidden.extra_pnginfo) + # give outputs without ids default ids + if self.outputs is not None: + for i, output in enumerate(self.outputs): + if output.id is None: + output.id = f"_{i}_{output.io_type}_" + + def get_v1_info(self, cls) -> NodeInfoV1: + # get V1 inputs + input = { + "required": {} + } + if self.inputs: + for i in self.inputs: + if isinstance(i, DynamicInput): + dynamic_inputs = i.get_dynamic() + for d in dynamic_inputs: + add_to_dict_v1(d, input) + else: + add_to_dict_v1(i, input) + if self.hidden: + for hidden in self.hidden: + input.setdefault("hidden", {})[hidden.name] = (hidden.value,) + # create separate lists from output fields + output = [] + output_is_list = [] + output_name = [] + output_tooltips = [] + if self.outputs: + for o in self.outputs: + output.append(o.io_type) + output_is_list.append(o.is_output_list) + output_name.append(o.display_name if o.display_name else o.io_type) + output_tooltips.append(o.tooltip if o.tooltip else None) + + info = NodeInfoV1( + input=input, + input_order={key: list(value.keys()) for (key, value) in input.items()}, + output=output, + output_is_list=output_is_list, + output_name=output_name, + output_tooltips=output_tooltips, + name=self.node_id, + display_name=self.display_name, + category=self.category, + description=self.description, + output_node=self.is_output_node, + deprecated=self.is_deprecated, + experimental=self.is_experimental, + api_node=self.is_api_node, + python_module=getattr(cls, "RELATIVE_PYTHON_MODULE", "nodes") + ) + return info + + + def get_v3_info(self, cls) -> NodeInfoV3: + input_dict = {} + output_dict = {} + hidden_list = [] + # TODO: make sure dynamic types will be handled correctly + if self.inputs: + for input in self.inputs: + add_to_dict_v3(input, input_dict) + if self.outputs: + for output in self.outputs: + add_to_dict_v3(output, output_dict) + if self.hidden: + for hidden in self.hidden: + hidden_list.append(hidden.value) + + info = NodeInfoV3( + input=input_dict, + output=output_dict, + hidden=hidden_list, + name=self.node_id, + display_name=self.display_name, + description=self.description, + category=self.category, + output_node=self.is_output_node, + deprecated=self.is_deprecated, + experimental=self.is_experimental, + api_node=self.is_api_node, + python_module=getattr(cls, "RELATIVE_PYTHON_MODULE", "nodes") + ) + return info + + +def add_to_dict_v1(i: Input, input: dict): + key = "optional" if i.optional else "required" + as_dict = i.as_dict() + # for v1, we don't want to include the optional key + as_dict.pop("optional", None) + input.setdefault(key, {})[i.id] = (i.get_io_type(), as_dict) + +def add_to_dict_v3(io: Input | Output, d: dict): + d[io.id] = (io.get_io_type(), io.as_dict()) + + + +class _ComfyNodeBaseInternal(_ComfyNodeInternal): + """Common base class for storing internal methods and properties; DO NOT USE for defining nodes.""" + + RELATIVE_PYTHON_MODULE = None + SCHEMA = None + + # filled in during execution + resources: Resources = None + hidden: HiddenHolder = None + + @classmethod + @abstractmethod + def define_schema(cls) -> Schema: + """Override this function with one that returns a Schema instance.""" + raise NotImplementedError + + @classmethod + @abstractmethod + def execute(cls, **kwargs) -> NodeOutput: + """Override 
this function with one that performs node's actions.""" + raise NotImplementedError + + @classmethod + def validate_inputs(cls, **kwargs) -> bool: + """Optionally, define this function to validate inputs; equivalent to V1's VALIDATE_INPUTS.""" + raise NotImplementedError + + @classmethod + def fingerprint_inputs(cls, **kwargs) -> Any: + """Optionally, define this function to fingerprint inputs; equivalent to V1's IS_CHANGED.""" + raise NotImplementedError + + @classmethod + def check_lazy_status(cls, **kwargs) -> list[str]: + """Optionally, define this function to return a list of input names that should be evaluated. + + This basic mixin impl. requires all inputs. + + :kwargs: All node inputs will be included here. If the input is ``None``, it should be assumed that it has not yet been evaluated. \ + When using ``INPUT_IS_LIST = True``, unevaluated will instead be ``(None,)``. + + Params should match the nodes execution ``FUNCTION`` (self, and all inputs by name). + Will be executed repeatedly until it returns an empty list, or all requested items were already evaluated (and sent as params). + + Comfy Docs: https://docs.comfy.org/custom-nodes/backend/lazy_evaluation#defining-check-lazy-status + """ + return [name for name in kwargs if kwargs[name] is None] + + def __init__(self): + self.local_resources: ResourcesLocal = None + self.__class__.VALIDATE_CLASS() + + @classmethod + def GET_BASE_CLASS(cls): + return _ComfyNodeBaseInternal + + @final + @classmethod + def VALIDATE_CLASS(cls): + if first_real_override(cls, "define_schema") is None: + raise Exception(f"No define_schema function was defined for node class {cls.__name__}.") + if first_real_override(cls, "execute") is None: + raise Exception(f"No execute function was defined for node class {cls.__name__}.") + + @classproperty + def FUNCTION(cls): # noqa + if inspect.iscoroutinefunction(cls.execute): + return "EXECUTE_NORMALIZED_ASYNC" + return "EXECUTE_NORMALIZED" + + @final + @classmethod + def EXECUTE_NORMALIZED(cls, *args, **kwargs) -> NodeOutput: + to_return = cls.execute(*args, **kwargs) + if to_return is None: + to_return = NodeOutput() + elif isinstance(to_return, NodeOutput): + pass + elif isinstance(to_return, tuple): + to_return = NodeOutput(*to_return) + elif isinstance(to_return, dict): + to_return = NodeOutput.from_dict(to_return) + elif isinstance(to_return, ExecutionBlocker): + to_return = NodeOutput(block_execution=to_return.message) + else: + raise Exception(f"Invalid return type from node: {type(to_return)}") + if to_return.expand is not None and not cls.SCHEMA.enable_expand: + raise Exception(f"Node {cls.__name__} is not expandable, but expand included in NodeOutput; developer should set enable_expand=True on node's Schema to allow this.") + return to_return + + @final + @classmethod + async def EXECUTE_NORMALIZED_ASYNC(cls, *args, **kwargs) -> NodeOutput: + to_return = await cls.execute(*args, **kwargs) + if to_return is None: + to_return = NodeOutput() + elif isinstance(to_return, NodeOutput): + pass + elif isinstance(to_return, tuple): + to_return = NodeOutput(*to_return) + elif isinstance(to_return, dict): + to_return = NodeOutput.from_dict(to_return) + elif isinstance(to_return, ExecutionBlocker): + to_return = NodeOutput(block_execution=to_return.message) + else: + raise Exception(f"Invalid return type from node: {type(to_return)}") + if to_return.expand is not None and not cls.SCHEMA.enable_expand: + raise Exception(f"Node {cls.__name__} is not expandable, but expand included in NodeOutput; developer should 
set enable_expand=True on node's Schema to allow this.") + return to_return + + @final + @classmethod + def PREPARE_CLASS_CLONE(cls, hidden_inputs: dict) -> type[ComfyNode]: + """Creates clone of real node class to prevent monkey-patching.""" + c_type: type[ComfyNode] = cls if is_class(cls) else type(cls) + type_clone: type[ComfyNode] = shallow_clone_class(c_type) + # set hidden + type_clone.hidden = HiddenHolder.from_dict(hidden_inputs) + return type_clone + + @final + @classmethod + def GET_NODE_INFO_V3(cls) -> dict[str, Any]: + schema = cls.GET_SCHEMA() + info = schema.get_v3_info(cls) + return asdict(info) + ############################################# + # V1 Backwards Compatibility code + #-------------------------------------------- + @final + @classmethod + def GET_NODE_INFO_V1(cls) -> dict[str, Any]: + schema = cls.GET_SCHEMA() + info = schema.get_v1_info(cls) + return asdict(info) + + _DESCRIPTION = None + @final + @classproperty + def DESCRIPTION(cls): # noqa + if cls._DESCRIPTION is None: + cls.GET_SCHEMA() + return cls._DESCRIPTION + + _CATEGORY = None + @final + @classproperty + def CATEGORY(cls): # noqa + if cls._CATEGORY is None: + cls.GET_SCHEMA() + return cls._CATEGORY + + _EXPERIMENTAL = None + @final + @classproperty + def EXPERIMENTAL(cls): # noqa + if cls._EXPERIMENTAL is None: + cls.GET_SCHEMA() + return cls._EXPERIMENTAL + + _DEPRECATED = None + @final + @classproperty + def DEPRECATED(cls): # noqa + if cls._DEPRECATED is None: + cls.GET_SCHEMA() + return cls._DEPRECATED + + _API_NODE = None + @final + @classproperty + def API_NODE(cls): # noqa + if cls._API_NODE is None: + cls.GET_SCHEMA() + return cls._API_NODE + + _OUTPUT_NODE = None + @final + @classproperty + def OUTPUT_NODE(cls): # noqa + if cls._OUTPUT_NODE is None: + cls.GET_SCHEMA() + return cls._OUTPUT_NODE + + _INPUT_IS_LIST = None + @final + @classproperty + def INPUT_IS_LIST(cls): # noqa + if cls._INPUT_IS_LIST is None: + cls.GET_SCHEMA() + return cls._INPUT_IS_LIST + _OUTPUT_IS_LIST = None + + @final + @classproperty + def OUTPUT_IS_LIST(cls): # noqa + if cls._OUTPUT_IS_LIST is None: + cls.GET_SCHEMA() + return cls._OUTPUT_IS_LIST + + _RETURN_TYPES = None + @final + @classproperty + def RETURN_TYPES(cls): # noqa + if cls._RETURN_TYPES is None: + cls.GET_SCHEMA() + return cls._RETURN_TYPES + + _RETURN_NAMES = None + @final + @classproperty + def RETURN_NAMES(cls): # noqa + if cls._RETURN_NAMES is None: + cls.GET_SCHEMA() + return cls._RETURN_NAMES + + _OUTPUT_TOOLTIPS = None + @final + @classproperty + def OUTPUT_TOOLTIPS(cls): # noqa + if cls._OUTPUT_TOOLTIPS is None: + cls.GET_SCHEMA() + return cls._OUTPUT_TOOLTIPS + + _NOT_IDEMPOTENT = None + @final + @classproperty + def NOT_IDEMPOTENT(cls): # noqa + if cls._NOT_IDEMPOTENT is None: + cls.GET_SCHEMA() + return cls._NOT_IDEMPOTENT + + @final + @classmethod + def INPUT_TYPES(cls, include_hidden=True, return_schema=False) -> dict[str, dict] | tuple[dict[str, dict], Schema]: + schema = cls.FINALIZE_SCHEMA() + info = schema.get_v1_info(cls) + input = info.input + if not include_hidden: + input.pop("hidden", None) + if return_schema: + return input, schema + return input + + @final + @classmethod + def FINALIZE_SCHEMA(cls): + """Call define_schema and finalize it.""" + schema = cls.define_schema() + schema.finalize() + return schema + + @final + @classmethod + def GET_SCHEMA(cls) -> Schema: + """Validate node class, finalize schema, validate schema, and set expected class properties.""" + cls.VALIDATE_CLASS() + schema = cls.FINALIZE_SCHEMA() + 
schema.validate() + if cls._DESCRIPTION is None: + cls._DESCRIPTION = schema.description + if cls._CATEGORY is None: + cls._CATEGORY = schema.category + if cls._EXPERIMENTAL is None: + cls._EXPERIMENTAL = schema.is_experimental + if cls._DEPRECATED is None: + cls._DEPRECATED = schema.is_deprecated + if cls._API_NODE is None: + cls._API_NODE = schema.is_api_node + if cls._OUTPUT_NODE is None: + cls._OUTPUT_NODE = schema.is_output_node + if cls._INPUT_IS_LIST is None: + cls._INPUT_IS_LIST = schema.is_input_list + if cls._NOT_IDEMPOTENT is None: + cls._NOT_IDEMPOTENT = schema.not_idempotent + + if cls._RETURN_TYPES is None: + output = [] + output_name = [] + output_is_list = [] + output_tooltips = [] + if schema.outputs: + for o in schema.outputs: + output.append(o.io_type) + output_name.append(o.display_name if o.display_name else o.io_type) + output_is_list.append(o.is_output_list) + output_tooltips.append(o.tooltip if o.tooltip else None) + + cls._RETURN_TYPES = output + cls._RETURN_NAMES = output_name + cls._OUTPUT_IS_LIST = output_is_list + cls._OUTPUT_TOOLTIPS = output_tooltips + cls.SCHEMA = schema + return schema + #-------------------------------------------- + ############################################# + + +class ComfyNode(_ComfyNodeBaseInternal): + """Common base class for all V3 nodes.""" + + @classmethod + @abstractmethod + def define_schema(cls) -> Schema: + """Override this function with one that returns a Schema instance.""" + raise NotImplementedError + + @classmethod + @abstractmethod + def execute(cls, **kwargs) -> NodeOutput: + """Override this function with one that performs node's actions.""" + raise NotImplementedError + + @classmethod + def validate_inputs(cls, **kwargs) -> bool: + """Optionally, define this function to validate inputs; equivalent to V1's VALIDATE_INPUTS.""" + raise NotImplementedError + + @classmethod + def fingerprint_inputs(cls, **kwargs) -> Any: + """Optionally, define this function to fingerprint inputs; equivalent to V1's IS_CHANGED.""" + raise NotImplementedError + + @classmethod + def check_lazy_status(cls, **kwargs) -> list[str]: + """Optionally, define this function to return a list of input names that should be evaluated. + + This basic mixin impl. requires all inputs. + + :kwargs: All node inputs will be included here. If the input is ``None``, it should be assumed that it has not yet been evaluated. \ + When using ``INPUT_IS_LIST = True``, unevaluated will instead be ``(None,)``. + + Params should match the nodes execution ``FUNCTION`` (self, and all inputs by name). + Will be executed repeatedly until it returns an empty list, or all requested items were already evaluated (and sent as params). + + Comfy Docs: https://docs.comfy.org/custom-nodes/backend/lazy_evaluation#defining-check-lazy-status + """ + return [name for name in kwargs if kwargs[name] is None] + + @final + @classmethod + def GET_BASE_CLASS(cls): + """DO NOT override this class. Will break things in execution.py.""" + return ComfyNode + + +class NodeOutput(_NodeOutputInternal): + ''' + Standardized output of a node; can pass in any number of args and/or a UIOutput into 'ui' kwarg. 
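+
+    For example (illustrative only), an execute() implementation might end with
+    ``return NodeOutput(image)`` for plain outputs, or
+    ``return NodeOutput(image, ui=PreviewImage(image, cls=cls))`` to attach a UI preview
+    as well (PreviewImage is one of the helpers in the separate ui module).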
+ ''' + def __init__(self, *args: Any, ui: _UIOutput | dict=None, expand: dict=None, block_execution: str=None): + self.args = args + self.ui = ui + self.expand = expand + self.block_execution = block_execution + + @property + def result(self): + return self.args if len(self.args) > 0 else None + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "NodeOutput": + args = () + ui = None + expand = None + if "result" in data: + result = data["result"] + if isinstance(result, ExecutionBlocker): + return cls(block_execution=result.message) + args = result + if "ui" in data: + ui = data["ui"] + if "expand" in data: + expand = data["expand"] + return cls(args=args, ui=ui, expand=expand) + + def __getitem__(self, index) -> Any: + return self.args[index] + +class _UIOutput(ABC): + def __init__(self): + pass + + @abstractmethod + def as_dict(self) -> dict: + ... + + +class _IO: + FolderType = FolderType + UploadType = UploadType + RemoteOptions = RemoteOptions + NumberDisplay = NumberDisplay + + comfytype = staticmethod(comfytype) + Custom = staticmethod(Custom) + Input = Input + WidgetInput = WidgetInput + Output = Output + ComfyTypeI = ComfyTypeI + ComfyTypeIO = ComfyTypeIO + #--------------------------------- + # Supported Types + Boolean = Boolean + Int = Int + Float = Float + String = String + Combo = Combo + MultiCombo = MultiCombo + Image = Image + WanCameraEmbedding = WanCameraEmbedding + Webcam = Webcam + Mask = Mask + Latent = Latent + Conditioning = Conditioning + Sampler = Sampler + Sigmas = Sigmas + Noise = Noise + Guider = Guider + Clip = Clip + ControlNet = ControlNet + Vae = Vae + Model = Model + ClipVision = ClipVision + ClipVisionOutput = ClipVisionOutput + StyleModel = StyleModel + Gligen = Gligen + UpscaleModel = UpscaleModel + Audio = Audio + Video = Video + SVG = SVG + LoraModel = LoraModel + LossMap = LossMap + Voxel = Voxel + Mesh = Mesh + Hooks = Hooks + HookKeyframes = HookKeyframes + TimestepsRange = TimestepsRange + LatentOperation = LatentOperation + FlowControl = FlowControl + Accumulation = Accumulation + Load3DCamera = Load3DCamera + Load3D = Load3D + Load3DAnimation = Load3DAnimation + Photomaker = Photomaker + Point = Point + FaceAnalysis = FaceAnalysis + BBOX = BBOX + SEGS = SEGS + AnyType = AnyType + MultiType = MultiType + #--------------------------------- + HiddenHolder = HiddenHolder + Hidden = Hidden + NodeInfoV1 = NodeInfoV1 + NodeInfoV3 = NodeInfoV3 + Schema = Schema + ComfyNode = ComfyNode + NodeOutput = NodeOutput + add_to_dict_v1 = staticmethod(add_to_dict_v1) + add_to_dict_v3 = staticmethod(add_to_dict_v3) diff --git a/comfy_api/latest/_resources.py b/comfy_api/latest/_resources.py new file mode 100644 index 000000000..a6bdda972 --- /dev/null +++ b/comfy_api/latest/_resources.py @@ -0,0 +1,72 @@ +from __future__ import annotations +import comfy.utils +import folder_paths +import logging +from abc import ABC, abstractmethod +from typing import Any +import torch + +class ResourceKey(ABC): + Type = Any + def __init__(self): + ... 
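+
+# Editorial sketch (not part of this patch): inside a V3 node's execute(), a state dict
+# could be requested through the per-node resource cache along these lines (the folder and
+# file names below are hypothetical):
+#
+#     sd = cls.resources.get(TorchDictFolderFilename("loras", lora_name))
+#
+# `cls.resources` is populated during execution; see _ComfyNodeBaseInternal.resources in _io.py.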
+ +class TorchDictFolderFilename(ResourceKey): + '''Key for requesting a torch file via file_name from a folder category.''' + Type = dict[str, torch.Tensor] + def __init__(self, folder_name: str, file_name: str): + self.folder_name = folder_name + self.file_name = file_name + + def __hash__(self): + return hash((self.folder_name, self.file_name)) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, TorchDictFolderFilename): + return False + return self.folder_name == other.folder_name and self.file_name == other.file_name + + def __str__(self): + return f"{self.folder_name} -> {self.file_name}" + +class Resources(ABC): + def __init__(self): + ... + + @abstractmethod + def get(self, key: ResourceKey, default: Any=...) -> Any: + pass + +class ResourcesLocal(Resources): + def __init__(self): + super().__init__() + self.local_resources: dict[ResourceKey, Any] = {} + + def get(self, key: ResourceKey, default: Any=...) -> Any: + cached = self.local_resources.get(key, None) + if cached is not None: + logging.info(f"Using cached resource '{key}'") + return cached + logging.info(f"Loading resource '{key}'") + to_return = None + if isinstance(key, TorchDictFolderFilename): + if default is ...: + to_return = comfy.utils.load_torch_file(folder_paths.get_full_path_or_raise(key.folder_name, key.file_name), safe_load=True) + else: + full_path = folder_paths.get_full_path(key.folder_name, key.file_name) + if full_path is not None: + to_return = comfy.utils.load_torch_file(full_path, safe_load=True) + + if to_return is not None: + self.local_resources[key] = to_return + return to_return + if default is not ...: + return default + raise Exception(f"Unsupported resource key type: {type(key)}") + + +class _RESOURCES: + ResourceKey = ResourceKey + TorchDictFolderFilename = TorchDictFolderFilename + Resources = Resources + ResourcesLocal = ResourcesLocal diff --git a/comfy_api/latest/_ui.py b/comfy_api/latest/_ui.py new file mode 100644 index 000000000..6b8a39d58 --- /dev/null +++ b/comfy_api/latest/_ui.py @@ -0,0 +1,457 @@ +from __future__ import annotations + +import json +import os +import random +from io import BytesIO +from typing import Type + +import av +import numpy as np +import torch +import torchaudio +from PIL import Image as PILImage +from PIL.PngImagePlugin import PngInfo + +import folder_paths + +# used for image preview +from comfy.cli_args import args +from comfy_api.latest._io import ComfyNode, FolderType, Image, _UIOutput + + +class SavedResult(dict): + def __init__(self, filename: str, subfolder: str, type: FolderType): + super().__init__(filename=filename, subfolder=subfolder,type=type.value) + + @property + def filename(self) -> str: + return self["filename"] + + @property + def subfolder(self) -> str: + return self["subfolder"] + + @property + def type(self) -> FolderType: + return FolderType(self["type"]) + + +class SavedImages(_UIOutput): + """A UI output class to represent one or more saved images, potentially animated.""" + def __init__(self, results: list[SavedResult], is_animated: bool = False): + super().__init__() + self.results = results + self.is_animated = is_animated + + def as_dict(self) -> dict: + data = {"images": self.results} + if self.is_animated: + data["animated"] = (True,) + return data + + +class SavedAudios(_UIOutput): + """UI wrapper around one or more audio files on disk (FLAC / MP3 / Opus).""" + def __init__(self, results: list[SavedResult]): + super().__init__() + self.results = results + + def as_dict(self) -> dict: + return {"audio": 
self.results} + + +def _get_directory_by_folder_type(folder_type: FolderType) -> str: + if folder_type == FolderType.input: + return folder_paths.get_input_directory() + if folder_type == FolderType.output: + return folder_paths.get_output_directory() + return folder_paths.get_temp_directory() + + +class ImageSaveHelper: + """A helper class with static methods to handle image saving and metadata.""" + + @staticmethod + def _convert_tensor_to_pil(image_tensor: torch.Tensor) -> PILImage.Image: + """Converts a single torch tensor to a PIL Image.""" + return PILImage.fromarray(np.clip(255.0 * image_tensor.cpu().numpy(), 0, 255).astype(np.uint8)) + + @staticmethod + def _create_png_metadata(cls: Type[ComfyNode] | None) -> PngInfo | None: + """Creates a PngInfo object with prompt and extra_pnginfo.""" + if args.disable_metadata or cls is None or not cls.hidden: + return None + metadata = PngInfo() + if cls.hidden.prompt: + metadata.add_text("prompt", json.dumps(cls.hidden.prompt)) + if cls.hidden.extra_pnginfo: + for x in cls.hidden.extra_pnginfo: + metadata.add_text(x, json.dumps(cls.hidden.extra_pnginfo[x])) + return metadata + + @staticmethod + def _create_animated_png_metadata(cls: Type[ComfyNode] | None) -> PngInfo | None: + """Creates a PngInfo object with prompt and extra_pnginfo for animated PNGs (APNG).""" + if args.disable_metadata or cls is None or not cls.hidden: + return None + metadata = PngInfo() + if cls.hidden.prompt: + metadata.add( + b"comf", + "prompt".encode("latin-1", "strict") + + b"\0" + + json.dumps(cls.hidden.prompt).encode("latin-1", "strict"), + after_idat=True, + ) + if cls.hidden.extra_pnginfo: + for x in cls.hidden.extra_pnginfo: + metadata.add( + b"comf", + x.encode("latin-1", "strict") + + b"\0" + + json.dumps(cls.hidden.extra_pnginfo[x]).encode("latin-1", "strict"), + after_idat=True, + ) + return metadata + + @staticmethod + def _create_webp_metadata(pil_image: PILImage.Image, cls: Type[ComfyNode] | None) -> PILImage.Exif: + """Creates EXIF metadata bytes for WebP images.""" + exif_data = pil_image.getexif() + if args.disable_metadata or cls is None or cls.hidden is None: + return exif_data + if cls.hidden.prompt is not None: + exif_data[0x0110] = "prompt:{}".format(json.dumps(cls.hidden.prompt)) # EXIF 0x0110 = Model + if cls.hidden.extra_pnginfo is not None: + inital_exif_tag = 0x010F # EXIF 0x010f = Make + for key, value in cls.hidden.extra_pnginfo.items(): + exif_data[inital_exif_tag] = "{}:{}".format(key, json.dumps(value)) + inital_exif_tag -= 1 + return exif_data + + @staticmethod + def save_images( + images, filename_prefix: str, folder_type: FolderType, cls: Type[ComfyNode] | None, compress_level = 4, + ) -> list[SavedResult]: + """Saves a batch of images as individual PNG files.""" + full_output_folder, filename, counter, subfolder, _ = folder_paths.get_save_image_path( + filename_prefix, _get_directory_by_folder_type(folder_type), images[0].shape[1], images[0].shape[0] + ) + results = [] + metadata = ImageSaveHelper._create_png_metadata(cls) + for batch_number, image_tensor in enumerate(images): + img = ImageSaveHelper._convert_tensor_to_pil(image_tensor) + filename_with_batch_num = filename.replace("%batch_num%", str(batch_number)) + file = f"{filename_with_batch_num}_{counter:05}_.png" + img.save(os.path.join(full_output_folder, file), pnginfo=metadata, compress_level=compress_level) + results.append(SavedResult(file, subfolder, folder_type)) + counter += 1 + return results + + @staticmethod + def get_save_images_ui(images, filename_prefix: str, 
cls: Type[ComfyNode] | None, compress_level=4) -> SavedImages: + """Saves a batch of images and returns a UI object for the node output.""" + return SavedImages( + ImageSaveHelper.save_images( + images, + filename_prefix=filename_prefix, + folder_type=FolderType.output, + cls=cls, + compress_level=compress_level, + ) + ) + + @staticmethod + def save_animated_png( + images, filename_prefix: str, folder_type: FolderType, cls: Type[ComfyNode] | None, fps: float, compress_level: int + ) -> SavedResult: + """Saves a batch of images as a single animated PNG.""" + full_output_folder, filename, counter, subfolder, _ = folder_paths.get_save_image_path( + filename_prefix, _get_directory_by_folder_type(folder_type), images[0].shape[1], images[0].shape[0] + ) + pil_images = [ImageSaveHelper._convert_tensor_to_pil(img) for img in images] + metadata = ImageSaveHelper._create_animated_png_metadata(cls) + file = f"{filename}_{counter:05}_.png" + save_path = os.path.join(full_output_folder, file) + pil_images[0].save( + save_path, + pnginfo=metadata, + compress_level=compress_level, + save_all=True, + duration=int(1000.0 / fps), + append_images=pil_images[1:], + ) + return SavedResult(file, subfolder, folder_type) + + @staticmethod + def get_save_animated_png_ui( + images, filename_prefix: str, cls: Type[ComfyNode] | None, fps: float, compress_level: int + ) -> SavedImages: + """Saves an animated PNG and returns a UI object for the node output.""" + result = ImageSaveHelper.save_animated_png( + images, + filename_prefix=filename_prefix, + folder_type=FolderType.output, + cls=cls, + fps=fps, + compress_level=compress_level, + ) + return SavedImages([result], is_animated=len(images) > 1) + + @staticmethod + def save_animated_webp( + images, + filename_prefix: str, + folder_type: FolderType, + cls: Type[ComfyNode] | None, + fps: float, + lossless: bool, + quality: int, + method: int, + ) -> SavedResult: + """Saves a batch of images as a single animated WebP.""" + full_output_folder, filename, counter, subfolder, _ = folder_paths.get_save_image_path( + filename_prefix, _get_directory_by_folder_type(folder_type), images[0].shape[1], images[0].shape[0] + ) + pil_images = [ImageSaveHelper._convert_tensor_to_pil(img) for img in images] + pil_exif = ImageSaveHelper._create_webp_metadata(pil_images[0], cls) + file = f"{filename}_{counter:05}_.webp" + pil_images[0].save( + os.path.join(full_output_folder, file), + save_all=True, + duration=int(1000.0 / fps), + append_images=pil_images[1:], + exif=pil_exif, + lossless=lossless, + quality=quality, + method=method, + ) + return SavedResult(file, subfolder, folder_type) + + @staticmethod + def get_save_animated_webp_ui( + images, + filename_prefix: str, + cls: Type[ComfyNode] | None, + fps: float, + lossless: bool, + quality: int, + method: int, + ) -> SavedImages: + """Saves an animated WebP and returns a UI object for the node output.""" + result = ImageSaveHelper.save_animated_webp( + images, + filename_prefix=filename_prefix, + folder_type=FolderType.output, + cls=cls, + fps=fps, + lossless=lossless, + quality=quality, + method=method, + ) + return SavedImages([result], is_animated=len(images) > 1) + + +class AudioSaveHelper: + """A helper class with static methods to handle audio saving and metadata.""" + _OPUS_RATES = [8000, 12000, 16000, 24000, 48000] + + @staticmethod + def save_audio( + audio: dict, + filename_prefix: str, + folder_type: FolderType, + cls: Type[ComfyNode] | None, + format: str = "flac", + quality: str = "128k", + ) -> list[SavedResult]: + 
full_output_folder, filename, counter, subfolder, _ = folder_paths.get_save_image_path( + filename_prefix, _get_directory_by_folder_type(folder_type) + ) + + metadata = {} + if not args.disable_metadata and cls is not None: + if cls.hidden.prompt is not None: + metadata["prompt"] = json.dumps(cls.hidden.prompt) + if cls.hidden.extra_pnginfo is not None: + for x in cls.hidden.extra_pnginfo: + metadata[x] = json.dumps(cls.hidden.extra_pnginfo[x]) + + results = [] + for batch_number, waveform in enumerate(audio["waveform"].cpu()): + filename_with_batch_num = filename.replace("%batch_num%", str(batch_number)) + file = f"{filename_with_batch_num}_{counter:05}_.{format}" + output_path = os.path.join(full_output_folder, file) + + # Use original sample rate initially + sample_rate = audio["sample_rate"] + + # Handle Opus sample rate requirements + if format == "opus": + if sample_rate > 48000: + sample_rate = 48000 + elif sample_rate not in AudioSaveHelper._OPUS_RATES: + # Find the next highest supported rate + for rate in sorted(AudioSaveHelper._OPUS_RATES): + if rate > sample_rate: + sample_rate = rate + break + if sample_rate not in AudioSaveHelper._OPUS_RATES: # Fallback if still not supported + sample_rate = 48000 + + # Resample if necessary + if sample_rate != audio["sample_rate"]: + waveform = torchaudio.functional.resample(waveform, audio["sample_rate"], sample_rate) + + # Create output with specified format + output_buffer = BytesIO() + output_container = av.open(output_buffer, mode="w", format=format) + + # Set metadata on the container + for key, value in metadata.items(): + output_container.metadata[key] = value + + # Set up the output stream with appropriate properties + if format == "opus": + out_stream = output_container.add_stream("libopus", rate=sample_rate) + if quality == "64k": + out_stream.bit_rate = 64000 + elif quality == "96k": + out_stream.bit_rate = 96000 + elif quality == "128k": + out_stream.bit_rate = 128000 + elif quality == "192k": + out_stream.bit_rate = 192000 + elif quality == "320k": + out_stream.bit_rate = 320000 + elif format == "mp3": + out_stream = output_container.add_stream("libmp3lame", rate=sample_rate) + if quality == "V0": + # TODO i would really love to support V3 and V5 but there doesn't seem to be a way to set the qscale level, the property below is a bool + out_stream.codec_context.qscale = 1 + elif quality == "128k": + out_stream.bit_rate = 128000 + elif quality == "320k": + out_stream.bit_rate = 320000 + else: # format == "flac": + out_stream = output_container.add_stream("flac", rate=sample_rate) + + frame = av.AudioFrame.from_ndarray( + waveform.movedim(0, 1).reshape(1, -1).float().numpy(), + format="flt", + layout="mono" if waveform.shape[0] == 1 else "stereo", + ) + frame.sample_rate = sample_rate + frame.pts = 0 + output_container.mux(out_stream.encode(frame)) + + # Flush encoder + output_container.mux(out_stream.encode(None)) + + # Close containers + output_container.close() + + # Write the output to file + output_buffer.seek(0) + with open(output_path, "wb") as f: + f.write(output_buffer.getbuffer()) + + results.append(SavedResult(file, subfolder, folder_type)) + counter += 1 + + return results + + @staticmethod + def get_save_audio_ui( + audio, filename_prefix: str, cls: Type[ComfyNode] | None, format: str = "flac", quality: str = "128k", + ) -> SavedAudios: + """Save and instantly wrap for UI.""" + return SavedAudios( + AudioSaveHelper.save_audio( + audio, + filename_prefix=filename_prefix, + folder_type=FolderType.output, + cls=cls, + 
format=format, + quality=quality, + ) + ) + + +class PreviewImage(_UIOutput): + def __init__(self, image: Image.Type, animated: bool = False, cls: Type[ComfyNode] = None, **kwargs): + self.values = ImageSaveHelper.save_images( + image, + filename_prefix="ComfyUI_temp_" + ''.join(random.choice("abcdefghijklmnopqrstupvxyz") for _ in range(5)), + folder_type=FolderType.temp, + cls=cls, + compress_level=1, + ) + self.animated = animated + + def as_dict(self): + return { + "images": self.values, + "animated": (self.animated,) + } + + +class PreviewMask(PreviewImage): + def __init__(self, mask: PreviewMask.Type, animated: bool=False, cls: ComfyNode=None, **kwargs): + preview = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])).movedim(1, -1).expand(-1, -1, -1, 3) + super().__init__(preview, animated, cls, **kwargs) + + +class PreviewAudio(_UIOutput): + def __init__(self, audio: dict, cls: Type[ComfyNode] = None, **kwargs): + self.values = AudioSaveHelper.save_audio( + audio, + filename_prefix="ComfyUI_temp_" + "".join(random.choice("abcdefghijklmnopqrstuvwxyz") for _ in range(5)), + folder_type=FolderType.temp, + cls=cls, + format="flac", + quality="128k", + ) + + def as_dict(self) -> dict: + return {"audio": self.values} + + +class PreviewVideo(_UIOutput): + def __init__(self, values: list[SavedResult | dict], **kwargs): + self.values = values + + def as_dict(self): + return {"images": self.values, "animated": (True,)} + + +class PreviewUI3D(_UIOutput): + def __init__(self, model_file, camera_info, **kwargs): + self.model_file = model_file + self.camera_info = camera_info + + def as_dict(self): + return {"result": [self.model_file, self.camera_info]} + + +class PreviewText(_UIOutput): + def __init__(self, value: str, **kwargs): + self.value = value + + def as_dict(self): + return {"text": (self.value,)} + + +class _UI: + SavedResult = SavedResult + SavedImages = SavedImages + SavedAudios = SavedAudios + ImageSaveHelper = ImageSaveHelper + AudioSaveHelper = AudioSaveHelper + PreviewImage = PreviewImage + PreviewMask = PreviewMask + PreviewAudio = PreviewAudio + PreviewVideo = PreviewVideo + PreviewUI3D = PreviewUI3D + PreviewText = PreviewText diff --git a/comfy_api/v0_0_2/__init__.py b/comfy_api/v0_0_2/__init__.py index ea83833fb..de0f95001 100644 --- a/comfy_api/v0_0_2/__init__.py +++ b/comfy_api/v0_0_2/__init__.py @@ -6,6 +6,7 @@ from comfy_api.latest import ( ) from typing import Type, TYPE_CHECKING from comfy_api.internal.async_to_sync import create_sync_class +from comfy_api.latest import io, ui, ComfyExtension #noqa: F401 class ComfyAPIAdapter_v0_0_2(ComfyAPI_latest): @@ -40,4 +41,5 @@ __all__ = [ "Input", "InputImpl", "Types", + "ComfyExtension", ] diff --git a/comfy_execution/graph.py b/comfy_execution/graph.py index 60e2ab91e..f4b427265 100644 --- a/comfy_execution/graph.py +++ b/comfy_execution/graph.py @@ -4,9 +4,12 @@ from typing import Type, Literal import nodes import asyncio import inspect -from comfy_execution.graph_utils import is_link +from comfy_execution.graph_utils import is_link, ExecutionBlocker from comfy.comfy_types.node_typing import ComfyNodeABC, InputTypeDict, InputTypeOptions +# NOTE: ExecutionBlocker code got moved to graph_utils.py to prevent torch being imported too soon during unit tests +ExecutionBlocker = ExecutionBlocker + class DependencyCycleError(Exception): pass @@ -294,21 +297,3 @@ class ExecutionList(TopologicalSort): del blocked_by[node_id] to_remove = [node_id for node_id in blocked_by if len(blocked_by[node_id]) == 0] return 
list(blocked_by.keys()) - -class ExecutionBlocker: - """ - Return this from a node and any users will be blocked with the given error message. - If the message is None, execution will be blocked silently instead. - Generally, you should avoid using this functionality unless absolutely necessary. Whenever it's - possible, a lazy input will be more efficient and have a better user experience. - This functionality is useful in two cases: - 1. You want to conditionally prevent an output node from executing. (Particularly a built-in node - like SaveImage. For your own output nodes, I would recommend just adding a BOOL input and using - lazy evaluation to let it conditionally disable itself.) - 2. You have a node with multiple possible outputs, some of which are invalid and should not be used. - (I would recommend not making nodes like this in the future -- instead, make multiple nodes with - different outputs. Unfortunately, there are several popular existing nodes using this pattern.) - """ - def __init__(self, message): - self.message = message - diff --git a/comfy_execution/graph_utils.py b/comfy_execution/graph_utils.py index 8595e942d..496d2c634 100644 --- a/comfy_execution/graph_utils.py +++ b/comfy_execution/graph_utils.py @@ -137,3 +137,19 @@ def add_graph_prefix(graph, outputs, prefix): return new_graph, tuple(new_outputs) +class ExecutionBlocker: + """ + Return this from a node and any users will be blocked with the given error message. + If the message is None, execution will be blocked silently instead. + Generally, you should avoid using this functionality unless absolutely necessary. Whenever it's + possible, a lazy input will be more efficient and have a better user experience. + This functionality is useful in two cases: + 1. You want to conditionally prevent an output node from executing. (Particularly a built-in node + like SaveImage. For your own output nodes, I would recommend just adding a BOOL input and using + lazy evaluation to let it conditionally disable itself.) + 2. You have a node with multiple possible outputs, some of which are invalid and should not be used. + (I would recommend not making nodes like this in the future -- instead, make multiple nodes with + different outputs. Unfortunately, there are several popular existing nodes using this pattern.) 
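+
+    Illustrative usage (editorial sketch): a node can place an ExecutionBlocker in one of its
+    returned output slots, e.g. ``return (ExecutionBlocker(None),)``, to silently block every
+    downstream node that consumes that output.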
+ """ + def __init__(self, message): + self.message = message diff --git a/execution.py b/execution.py index cde14c52f..952f0cc5c 100644 --- a/execution.py +++ b/execution.py @@ -32,6 +32,8 @@ from comfy_execution.graph_utils import GraphBuilder, is_link from comfy_execution.validation import validate_node_input from comfy_execution.progress import get_progress_state, reset_progress_state, add_progress_handler, WebUIProgressHandler from comfy_execution.utils import CurrentNodeContext +from comfy_api.internal import _ComfyNodeInternal, _NodeOutputInternal, first_real_override, is_class, make_locked_method_func +from comfy_api.latest import io class ExecutionResult(Enum): @@ -56,7 +58,15 @@ class IsChangedCache: node = self.dynprompt.get_node(node_id) class_type = node["class_type"] class_def = nodes.NODE_CLASS_MAPPINGS[class_type] - if not hasattr(class_def, "IS_CHANGED"): + has_is_changed = False + is_changed_name = None + if issubclass(class_def, _ComfyNodeInternal) and first_real_override(class_def, "fingerprint_inputs") is not None: + has_is_changed = True + is_changed_name = "fingerprint_inputs" + elif hasattr(class_def, "IS_CHANGED"): + has_is_changed = True + is_changed_name = "IS_CHANGED" + if not has_is_changed: self.is_changed[node_id] = False return self.is_changed[node_id] @@ -65,9 +75,9 @@ class IsChangedCache: return self.is_changed[node_id] # Intentionally do not use cached outputs here. We only want constants in IS_CHANGED - input_data_all, _ = get_input_data(node["inputs"], class_def, node_id, None) + input_data_all, _, hidden_inputs = get_input_data(node["inputs"], class_def, node_id, None) try: - is_changed = await _async_map_node_over_list(self.prompt_id, node_id, class_def, input_data_all, "IS_CHANGED") + is_changed = await _async_map_node_over_list(self.prompt_id, node_id, class_def, input_data_all, is_changed_name) is_changed = await resolve_map_node_over_list_results(is_changed) node["is_changed"] = [None if isinstance(x, ExecutionBlocker) else x for x in is_changed] except Exception as e: @@ -126,9 +136,14 @@ class CacheSet: SENSITIVE_EXTRA_DATA_KEYS = ("auth_token_comfy_org", "api_key_comfy_org") def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, extra_data={}): - valid_inputs = class_def.INPUT_TYPES() + is_v3 = issubclass(class_def, _ComfyNodeInternal) + if is_v3: + valid_inputs, schema = class_def.INPUT_TYPES(include_hidden=False, return_schema=True) + else: + valid_inputs = class_def.INPUT_TYPES() input_data_all = {} missing_keys = {} + hidden_inputs_v3 = {} for x in inputs: input_data = inputs[x] _, input_category, input_info = get_input_info(class_def, x, valid_inputs) @@ -153,22 +168,37 @@ def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, e elif input_category is not None: input_data_all[x] = [input_data] - if "hidden" in valid_inputs: - h = valid_inputs["hidden"] - for x in h: - if h[x] == "PROMPT": - input_data_all[x] = [dynprompt.get_original_prompt() if dynprompt is not None else {}] - if h[x] == "DYNPROMPT": - input_data_all[x] = [dynprompt] - if h[x] == "EXTRA_PNGINFO": - input_data_all[x] = [extra_data.get('extra_pnginfo', None)] - if h[x] == "UNIQUE_ID": - input_data_all[x] = [unique_id] - if h[x] == "AUTH_TOKEN_COMFY_ORG": - input_data_all[x] = [extra_data.get("auth_token_comfy_org", None)] - if h[x] == "API_KEY_COMFY_ORG": - input_data_all[x] = [extra_data.get("api_key_comfy_org", None)] - return input_data_all, missing_keys + if is_v3: + if schema.hidden: + if io.Hidden.prompt in 
schema.hidden: + hidden_inputs_v3[io.Hidden.prompt] = dynprompt.get_original_prompt() if dynprompt is not None else {} + if io.Hidden.dynprompt in schema.hidden: + hidden_inputs_v3[io.Hidden.dynprompt] = dynprompt + if io.Hidden.extra_pnginfo in schema.hidden: + hidden_inputs_v3[io.Hidden.extra_pnginfo] = extra_data.get('extra_pnginfo', None) + if io.Hidden.unique_id in schema.hidden: + hidden_inputs_v3[io.Hidden.unique_id] = unique_id + if io.Hidden.auth_token_comfy_org in schema.hidden: + hidden_inputs_v3[io.Hidden.auth_token_comfy_org] = extra_data.get("auth_token_comfy_org", None) + if io.Hidden.api_key_comfy_org in schema.hidden: + hidden_inputs_v3[io.Hidden.api_key_comfy_org] = extra_data.get("api_key_comfy_org", None) + else: + if "hidden" in valid_inputs: + h = valid_inputs["hidden"] + for x in h: + if h[x] == "PROMPT": + input_data_all[x] = [dynprompt.get_original_prompt() if dynprompt is not None else {}] + if h[x] == "DYNPROMPT": + input_data_all[x] = [dynprompt] + if h[x] == "EXTRA_PNGINFO": + input_data_all[x] = [extra_data.get('extra_pnginfo', None)] + if h[x] == "UNIQUE_ID": + input_data_all[x] = [unique_id] + if h[x] == "AUTH_TOKEN_COMFY_ORG": + input_data_all[x] = [extra_data.get("auth_token_comfy_org", None)] + if h[x] == "API_KEY_COMFY_ORG": + input_data_all[x] = [extra_data.get("api_key_comfy_org", None)] + return input_data_all, missing_keys, hidden_inputs_v3 map_node_over_list = None #Don't hook this please @@ -184,7 +214,7 @@ async def resolve_map_node_over_list_results(results): raise exc return [x.result() if isinstance(x, asyncio.Task) else x for x in results] -async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None): +async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None, hidden_inputs=None): # check if node wants the lists input_is_list = getattr(obj, "INPUT_IS_LIST", False) @@ -214,7 +244,22 @@ async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, f if execution_block is None: if pre_execute_cb is not None and index is not None: pre_execute_cb(index) - f = getattr(obj, func) + # V3 + if isinstance(obj, _ComfyNodeInternal) or (is_class(obj) and issubclass(obj, _ComfyNodeInternal)): + # if is just a class, then assign no resources or state, just create clone + if is_class(obj): + type_obj = obj + obj.VALIDATE_CLASS() + class_clone = obj.PREPARE_CLASS_CLONE(hidden_inputs) + # otherwise, use class instance to populate/reuse some fields + else: + type_obj = type(obj) + type_obj.VALIDATE_CLASS() + class_clone = type_obj.PREPARE_CLASS_CLONE(hidden_inputs) + f = make_locked_method_func(type_obj, func, class_clone) + # V1 + else: + f = getattr(obj, func) if inspect.iscoroutinefunction(f): async def async_wrapper(f, prompt_id, unique_id, list_index, args): with CurrentNodeContext(prompt_id, unique_id, list_index): @@ -266,8 +311,8 @@ def merge_result_data(results, obj): output.append([o[i] for o in results]) return output -async def get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=None, pre_execute_cb=None): - return_values = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb) +async def get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=None, pre_execute_cb=None, 
hidden_inputs=None): + return_values = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, hidden_inputs=hidden_inputs) has_pending_task = any(isinstance(r, asyncio.Task) and not r.done() for r in return_values) if has_pending_task: return return_values, {}, False, has_pending_task @@ -298,6 +343,26 @@ def get_output_from_returns(return_values, obj): result = tuple([result] * len(obj.RETURN_TYPES)) results.append(result) subgraph_results.append((None, result)) + elif isinstance(r, _NodeOutputInternal): + # V3 + if r.ui is not None: + if isinstance(r.ui, dict): + uis.append(r.ui) + else: + uis.append(r.ui.as_dict()) + if r.expand is not None: + has_subgraph = True + new_graph = r.expand + result = r.result + if r.block_execution is not None: + result = tuple([ExecutionBlocker(r.block_execution)] * len(obj.RETURN_TYPES)) + subgraph_results.append((new_graph, result)) + elif r.result is not None: + result = r.result + if r.block_execution is not None: + result = tuple([ExecutionBlocker(r.block_execution)] * len(obj.RETURN_TYPES)) + results.append(result) + subgraph_results.append((None, result)) else: if isinstance(r, ExecutionBlocker): r = tuple([r] * len(obj.RETURN_TYPES)) @@ -381,7 +446,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, has_subgraph = False else: get_progress_state().start_progress(unique_id) - input_data_all, missing_keys = get_input_data(inputs, class_def, unique_id, caches.outputs, dynprompt, extra_data) + input_data_all, missing_keys, hidden_inputs = get_input_data(inputs, class_def, unique_id, caches.outputs, dynprompt, extra_data) if server.client_id is not None: server.last_node_id = display_node_id server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id }, server.client_id) @@ -391,8 +456,12 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, obj = class_def() caches.objects.set(unique_id, obj) - if hasattr(obj, "check_lazy_status"): - required_inputs = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, "check_lazy_status", allow_interrupt=True) + if issubclass(class_def, _ComfyNodeInternal): + lazy_status_present = first_real_override(class_def, "check_lazy_status") is not None + else: + lazy_status_present = getattr(obj, "check_lazy_status", None) is not None + if lazy_status_present: + required_inputs = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, "check_lazy_status", allow_interrupt=True, hidden_inputs=hidden_inputs) required_inputs = await resolve_map_node_over_list_results(required_inputs) required_inputs = set(sum([r for r in required_inputs if isinstance(r,list)], [])) required_inputs = [x for x in required_inputs if isinstance(x,str) and ( @@ -424,7 +493,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, def pre_execute_cb(call_index): # TODO - How to handle this with async functions without contextvars (which requires Python 3.12)? 
GraphBuilder.set_default_prefix(unique_id, call_index, 0) - output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb) + output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, hidden_inputs=hidden_inputs) if has_pending_tasks: pending_async_nodes[unique_id] = output_data unblock = execution_list.add_external_block(unique_id) @@ -672,8 +741,14 @@ async def validate_inputs(prompt_id, prompt, item, validated): validate_function_inputs = [] validate_has_kwargs = False - if hasattr(obj_class, "VALIDATE_INPUTS"): - argspec = inspect.getfullargspec(obj_class.VALIDATE_INPUTS) + if issubclass(obj_class, _ComfyNodeInternal): + validate_function_name = "validate_inputs" + validate_function = first_real_override(obj_class, validate_function_name) + else: + validate_function_name = "VALIDATE_INPUTS" + validate_function = getattr(obj_class, validate_function_name, None) + if validate_function is not None: + argspec = inspect.getfullargspec(validate_function) validate_function_inputs = argspec.args validate_has_kwargs = argspec.varkw is not None received_types = {} @@ -848,7 +923,7 @@ async def validate_inputs(prompt_id, prompt, item, validated): continue if len(validate_function_inputs) > 0 or validate_has_kwargs: - input_data_all, _ = get_input_data(inputs, obj_class, unique_id) + input_data_all, _, hidden_inputs = get_input_data(inputs, obj_class, unique_id) input_filtered = {} for x in input_data_all: if x in validate_function_inputs or validate_has_kwargs: @@ -856,8 +931,7 @@ async def validate_inputs(prompt_id, prompt, item, validated): if 'input_types' in validate_function_inputs: input_filtered['input_types'] = [received_types] - #ret = obj_class.VALIDATE_INPUTS(**input_filtered) - ret = await _async_map_node_over_list(prompt_id, unique_id, obj_class, input_filtered, "VALIDATE_INPUTS") + ret = await _async_map_node_over_list(prompt_id, unique_id, obj_class, input_filtered, validate_function_name, hidden_inputs=hidden_inputs) ret = await resolve_map_node_over_list_results(ret) for x in input_filtered: for i, r in enumerate(ret): diff --git a/nodes.py b/nodes.py index 54e530388..da4a46366 100644 --- a/nodes.py +++ b/nodes.py @@ -6,6 +6,7 @@ import os import sys import json import hashlib +import inspect import traceback import math import time @@ -29,6 +30,7 @@ import comfy.controlnet from comfy.comfy_types import IO, ComfyNodeABC, InputTypeDict, FileLocator from comfy_api.internal import register_versions, ComfyAPIWithVersion from comfy_api.version_list import supported_versions +from comfy_api.latest import io, ComfyExtension import comfy.clip_vision @@ -2152,6 +2154,7 @@ async def load_custom_node(module_path: str, ignore=set(), module_parent="custom if os.path.isdir(web_dir): EXTENSION_WEB_DIRS[module_name] = web_dir + # V1 node definition if hasattr(module, "NODE_CLASS_MAPPINGS") and getattr(module, "NODE_CLASS_MAPPINGS") is not None: for name, node_cls in module.NODE_CLASS_MAPPINGS.items(): if name not in ignore: @@ -2160,8 +2163,38 @@ async def load_custom_node(module_path: str, ignore=set(), module_parent="custom if hasattr(module, "NODE_DISPLAY_NAME_MAPPINGS") and getattr(module, "NODE_DISPLAY_NAME_MAPPINGS") is not None: NODE_DISPLAY_NAME_MAPPINGS.update(module.NODE_DISPLAY_NAME_MAPPINGS) return True + # V3 Extension 
Definition + elif hasattr(module, "comfy_entrypoint"): + entrypoint = getattr(module, "comfy_entrypoint") + if not callable(entrypoint): + logging.warning(f"comfy_entrypoint in {module_path} is not callable, skipping.") + return False + try: + if inspect.iscoroutinefunction(entrypoint): + extension = await entrypoint() + else: + extension = entrypoint() + if not isinstance(extension, ComfyExtension): + logging.warning(f"comfy_entrypoint in {module_path} did not return a ComfyExtension, skipping.") + return False + node_list = await extension.get_node_list() + if not isinstance(node_list, list): + logging.warning(f"comfy_entrypoint in {module_path} did not return a list of nodes, skipping.") + return False + for node_cls in node_list: + node_cls: io.ComfyNode + schema = node_cls.GET_SCHEMA() + if schema.node_id not in ignore: + NODE_CLASS_MAPPINGS[schema.node_id] = node_cls + node_cls.RELATIVE_PYTHON_MODULE = "{}.{}".format(module_parent, get_module_name(module_path)) + if schema.display_name is not None: + NODE_DISPLAY_NAME_MAPPINGS[schema.node_id] = schema.display_name + return True + except Exception as e: + logging.warning(f"Error while calling comfy_entrypoint in {module_path}: {e}") + return False else: - logging.warning(f"Skip {module_path} module for custom nodes due to the lack of NODE_CLASS_MAPPINGS.") + logging.warning(f"Skip {module_path} module for custom nodes due to the lack of NODE_CLASS_MAPPINGS or NODES_LIST (need one).") return False except Exception as e: logging.warning(traceback.format_exc()) @@ -2286,7 +2319,7 @@ async def init_builtin_extra_nodes(): "nodes_string.py", "nodes_camera_trajectory.py", "nodes_edit_model.py", - "nodes_tcfg.py" + "nodes_tcfg.py", ] import_failed = [] diff --git a/server.py b/server.py index 3e06d2fbb..0553a0dd7 100644 --- a/server.py +++ b/server.py @@ -30,6 +30,7 @@ from comfy_api import feature_flags import node_helpers from comfyui_version import __version__ from app.frontend_management import FrontendManager +from comfy_api.internal import _ComfyNodeInternal from app.user_manager import UserManager from app.model_manager import ModelFileManager @@ -591,6 +592,8 @@ class PromptServer(): def node_info(node_class): obj_class = nodes.NODE_CLASS_MAPPINGS[node_class] + if issubclass(obj_class, _ComfyNodeInternal): + return obj_class.GET_NODE_INFO_V1() info = {} info['input'] = obj_class.INPUT_TYPES() info['input_order'] = {key: list(value.keys()) for (key, value) in obj_class.INPUT_TYPES().items()} From 5ee381c058d606209dcafb568af20196e7884fc8 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Thu, 31 Jul 2025 20:33:27 -0700 Subject: [PATCH 005/325] Fix WanFirstLastFrameToVideo node when no clip vision. 
(#9134) --- comfy_extras/nodes_wan.py | 1 + 1 file changed, 1 insertion(+) diff --git a/comfy_extras/nodes_wan.py b/comfy_extras/nodes_wan.py index 0b92c68ac..0067d054d 100644 --- a/comfy_extras/nodes_wan.py +++ b/comfy_extras/nodes_wan.py @@ -149,6 +149,7 @@ class WanFirstLastFrameToVideo: positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent_image, "concat_mask": mask}) negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent_image, "concat_mask": mask}) + clip_vision_output = None if clip_vision_start_image is not None: clip_vision_output = clip_vision_start_image From 4696d74305e98a96bda5685b7f11d6ba167c2ed3 Mon Sep 17 00:00:00 2001 From: ComfyUI Wiki Date: Fri, 1 Aug 2025 15:06:18 +0800 Subject: [PATCH 006/325] update template to 0.1.45 (#9135) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 8f2f6a56c..3828c5b91 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.23.4 -comfyui-workflow-templates==0.1.44 +comfyui-workflow-templates==0.1.45 comfyui-embedded-docs==0.2.4 torch torchsde From 1e638a140b2f459595fafc73ade5ea5b4024d4b4 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Fri, 1 Aug 2025 02:25:38 -0700 Subject: [PATCH 007/325] Tiny wan vae optimizations. (#9136) --- comfy/ldm/wan/vae.py | 13 +++++++++---- comfy/ldm/wan/vae2_2.py | 2 +- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/comfy/ldm/wan/vae.py b/comfy/ldm/wan/vae.py index 6b07840fc..791596938 100644 --- a/comfy/ldm/wan/vae.py +++ b/comfy/ldm/wan/vae.py @@ -24,12 +24,17 @@ class CausalConv3d(ops.Conv3d): self.padding[1], 2 * self.padding[0], 0) self.padding = (0, 0, 0) - def forward(self, x, cache_x=None): + def forward(self, x, cache_x=None, cache_list=None, cache_idx=None): + if cache_list is not None: + cache_x = cache_list[cache_idx] + cache_list[cache_idx] = None + padding = list(self._padding) if cache_x is not None and self._padding[4] > 0: cache_x = cache_x.to(x.device) x = torch.cat([cache_x, x], dim=2) padding[4] -= cache_x.shape[2] + del cache_x x = F.pad(x, padding) return super().forward(x) @@ -166,7 +171,7 @@ class ResidualBlock(nn.Module): if in_dim != out_dim else nn.Identity() def forward(self, x, feat_cache=None, feat_idx=[0]): - h = self.shortcut(x) + old_x = x for layer in self.residual: if isinstance(layer, CausalConv3d) and feat_cache is not None: idx = feat_idx[0] @@ -178,12 +183,12 @@ class ResidualBlock(nn.Module): cache_x.device), cache_x ], dim=2) - x = layer(x, feat_cache[idx]) + x = layer(x, cache_list=feat_cache, cache_idx=idx) feat_cache[idx] = cache_x feat_idx[0] += 1 else: x = layer(x) - return x + h + return x + self.shortcut(old_x) class AttentionBlock(nn.Module): diff --git a/comfy/ldm/wan/vae2_2.py b/comfy/ldm/wan/vae2_2.py index b9c2d1a26..1f6d584a2 100644 --- a/comfy/ldm/wan/vae2_2.py +++ b/comfy/ldm/wan/vae2_2.py @@ -151,7 +151,7 @@ class ResidualBlock(nn.Module): ], dim=2, ) - x = layer(x, feat_cache[idx]) + x = layer(x, cache_list=feat_cache, cache_idx=idx) feat_cache[idx] = cache_x feat_idx[0] += 1 else: From bff60b5cfc10d1b037a95746226ac6698dc3e373 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 1 Aug 2025 20:03:22 -0400 Subject: [PATCH 008/325] ComfyUI version 0.3.48 --- comfyui_version.py | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/comfyui_version.py 
b/comfyui_version.py index 20a2e892a..7b29e338d 100644 --- a/comfyui_version.py +++ b/comfyui_version.py @@ -1,3 +1,3 @@ # This file is automatically generated by the build process when version is # updated in pyproject.toml. -__version__ = "0.3.47" +__version__ = "0.3.48" diff --git a/pyproject.toml b/pyproject.toml index 244fdd232..256677fad 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ComfyUI" -version = "0.3.47" +version = "0.3.48" readme = "README.md" license = { file = "LICENSE" } requires-python = ">=3.9" From 3dfefc88d00bde744b729b073058a57e149cddc1 Mon Sep 17 00:00:00 2001 From: Johnpaul Chiwetelu <49923152+Myestery@users.noreply.github.com> Date: Sat, 2 Aug 2025 03:02:06 +0100 Subject: [PATCH 009/325] API for Recently Used Items (#8792) * feat: add file creation time to model file metadata and user file info * fix linting --- app/model_manager.py | 21 ++++++++++++++++----- app/user_manager.py | 4 +++- 2 files changed, 19 insertions(+), 6 deletions(-) diff --git a/app/model_manager.py b/app/model_manager.py index 74d942fb8..ab36bca74 100644 --- a/app/model_manager.py +++ b/app/model_manager.py @@ -130,10 +130,21 @@ class ModelFileManager: for file_name in filenames: try: - relative_path = os.path.relpath(os.path.join(dirpath, file_name), directory) - result.append(relative_path) - except: - logging.warning(f"Warning: Unable to access {file_name}. Skipping this file.") + full_path = os.path.join(dirpath, file_name) + relative_path = os.path.relpath(full_path, directory) + + # Get file metadata + file_info = { + "name": relative_path, + "pathIndex": pathIndex, + "modified": os.path.getmtime(full_path), # Add modification time + "created": os.path.getctime(full_path), # Add creation time + "size": os.path.getsize(full_path) # Add file size + } + result.append(file_info) + + except Exception as e: + logging.warning(f"Warning: Unable to access {file_name}. Error: {e}. Skipping this file.") continue for d in subdirs: @@ -144,7 +155,7 @@ class ModelFileManager: logging.warning(f"Warning: Unable to access {path}. 
Skipping this path.") continue - return [{"name": f, "pathIndex": pathIndex} for f in result], dirs, time.perf_counter() + return result, dirs, time.perf_counter() def get_model_previews(self, filepath: str) -> list[str | BytesIO]: dirname = os.path.dirname(filepath) diff --git a/app/user_manager.py b/app/user_manager.py index d31da5b9b..0ec3e46ea 100644 --- a/app/user_manager.py +++ b/app/user_manager.py @@ -20,13 +20,15 @@ class FileInfo(TypedDict): path: str size: int modified: int + created: int def get_file_info(path: str, relative_to: str) -> FileInfo: return { "path": os.path.relpath(path, relative_to).replace(os.sep, '/'), "size": os.path.getsize(path), - "modified": os.path.getmtime(path) + "modified": os.path.getmtime(path), + "created": os.path.getctime(path) } From fbcc23945dc377c8623bbee6132f15a93ac0c84a Mon Sep 17 00:00:00 2001 From: ComfyUI Wiki Date: Sun, 3 Aug 2025 02:15:29 +0800 Subject: [PATCH 010/325] Update template to 0.1.47 (#9153) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 3828c5b91..ffa7dce65 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.23.4 -comfyui-workflow-templates==0.1.45 +comfyui-workflow-templates==0.1.47 comfyui-embedded-docs==0.2.4 torch torchsde From 5f582a97572e87ebfa655d379e8c8f7611c0249f Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sat, 2 Aug 2025 12:00:13 -0700 Subject: [PATCH 011/325] Make sure all the conds are on the right device. (#9151) --- comfy/model_base.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/comfy/model_base.py b/comfy/model_base.py index 6b7978949..3ff8106d7 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -106,10 +106,12 @@ def model_sampling(model_config, model_type): return ModelSampling(model_config) -def convert_tensor(extra, dtype): +def convert_tensor(extra, dtype, device): if hasattr(extra, "dtype"): if extra.dtype != torch.int and extra.dtype != torch.long: - extra = extra.to(dtype) + extra = extra.to(dtype=dtype, device=device) + else: + extra = extra.to(device=device) return extra @@ -174,15 +176,16 @@ class BaseModel(torch.nn.Module): context = context.to(dtype) extra_conds = {} + device = xc.device for o in kwargs: extra = kwargs[o] if hasattr(extra, "dtype"): - extra = convert_tensor(extra, dtype) + extra = convert_tensor(extra, dtype, device) elif isinstance(extra, list): ex = [] for ext in extra: - ex.append(convert_tensor(ext, dtype)) + ex.append(convert_tensor(ext, dtype, device)) extra = ex extra_conds[o] = extra From 13aaa66ec21c397240a9b972d818430b39112588 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sat, 2 Aug 2025 12:09:23 -0700 Subject: [PATCH 012/325] Make sure context is on the right device. 
(#9154) --- comfy/model_base.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/comfy/model_base.py b/comfy/model_base.py index 3ff8106d7..4556ee138 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -171,12 +171,12 @@ class BaseModel(torch.nn.Module): dtype = self.manual_cast_dtype xc = xc.to(dtype) + device = xc.device t = self.model_sampling.timestep(t).float() if context is not None: - context = context.to(dtype) + context = context.to(dtype=dtype, device=device) extra_conds = {} - device = xc.device for o in kwargs: extra = kwargs[o] From aebac221937b511d46fe601656acdc753435b849 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sun, 3 Aug 2025 04:08:11 -0700 Subject: [PATCH 013/325] Cleanup. (#9160) --- comfy/controlnet.py | 1 - 1 file changed, 1 deletion(-) diff --git a/comfy/controlnet.py b/comfy/controlnet.py index 9a47b86f2..6ed8bd756 100644 --- a/comfy/controlnet.py +++ b/comfy/controlnet.py @@ -43,7 +43,6 @@ if TYPE_CHECKING: def broadcast_image_to(tensor, target_batch_size, batched_number): current_batch_size = tensor.shape[0] - #print(current_batch_size, target_batch_size) if current_batch_size == 1: return tensor From 182f90b5eca2baa25474223759039925b286d562 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 4 Aug 2025 00:11:53 -0700 Subject: [PATCH 014/325] Lower cond vram use by casting at the same time as device transfer. (#9159) --- comfy/conds.py | 14 +++++++------- comfy/model_base.py | 6 +++--- comfy/samplers.py | 2 +- 3 files changed, 11 insertions(+), 11 deletions(-) diff --git a/comfy/conds.py b/comfy/conds.py index 2af2a43a3..f2564e7ef 100644 --- a/comfy/conds.py +++ b/comfy/conds.py @@ -10,8 +10,8 @@ class CONDRegular: def _copy_with(self, cond): return self.__class__(cond) - def process_cond(self, batch_size, device, **kwargs): - return self._copy_with(comfy.utils.repeat_to_batch_size(self.cond, batch_size).to(device)) + def process_cond(self, batch_size, **kwargs): + return self._copy_with(comfy.utils.repeat_to_batch_size(self.cond, batch_size)) def can_concat(self, other): if self.cond.shape != other.cond.shape: @@ -29,14 +29,14 @@ class CONDRegular: class CONDNoiseShape(CONDRegular): - def process_cond(self, batch_size, device, area, **kwargs): + def process_cond(self, batch_size, area, **kwargs): data = self.cond if area is not None: dims = len(area) // 2 for i in range(dims): data = data.narrow(i + 2, area[i + dims], area[i]) - return self._copy_with(comfy.utils.repeat_to_batch_size(data, batch_size).to(device)) + return self._copy_with(comfy.utils.repeat_to_batch_size(data, batch_size)) class CONDCrossAttn(CONDRegular): @@ -73,7 +73,7 @@ class CONDConstant(CONDRegular): def __init__(self, cond): self.cond = cond - def process_cond(self, batch_size, device, **kwargs): + def process_cond(self, batch_size, **kwargs): return self._copy_with(self.cond) def can_concat(self, other): @@ -92,10 +92,10 @@ class CONDList(CONDRegular): def __init__(self, cond): self.cond = cond - def process_cond(self, batch_size, device, **kwargs): + def process_cond(self, batch_size, **kwargs): out = [] for c in self.cond: - out.append(comfy.utils.repeat_to_batch_size(c, batch_size).to(device)) + out.append(comfy.utils.repeat_to_batch_size(c, batch_size)) return self._copy_with(out) diff --git a/comfy/model_base.py b/comfy/model_base.py index 4556ee138..3a9c031ea 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -109,9 +109,9 @@ 
def model_sampling(model_config, model_type): def convert_tensor(extra, dtype, device): if hasattr(extra, "dtype"): if extra.dtype != torch.int and extra.dtype != torch.long: - extra = extra.to(dtype=dtype, device=device) + extra = comfy.model_management.cast_to_device(extra, device, dtype) else: - extra = extra.to(device=device) + extra = comfy.model_management.cast_to_device(extra, device, None) return extra @@ -174,7 +174,7 @@ class BaseModel(torch.nn.Module): device = xc.device t = self.model_sampling.timestep(t).float() if context is not None: - context = context.to(dtype=dtype, device=device) + context = comfy.model_management.cast_to_device(context, device, dtype) extra_conds = {} for o in kwargs: diff --git a/comfy/samplers.py b/comfy/samplers.py index e93d2a315..ad2f40cdc 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -89,7 +89,7 @@ def get_area_and_mult(conds, x_in, timestep_in): conditioning = {} model_conds = conds["model_conds"] for c in model_conds: - conditioning[c] = model_conds[c].process_cond(batch_size=x_in.shape[0], device=x_in.device, area=area) + conditioning[c] = model_conds[c].process_cond(batch_size=x_in.shape[0], area=area) hooks = conds.get('hooks', None) control = conds.get('control', None) From 140ffc7fdc53e810030f060e421c1f528c2d2ab9 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 4 Aug 2025 00:28:12 -0700 Subject: [PATCH 015/325] Fix broken controlnet from last PR. (#9167) --- comfy/controlnet.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/comfy/controlnet.py b/comfy/controlnet.py index 6ed8bd756..988acdb57 100644 --- a/comfy/controlnet.py +++ b/comfy/controlnet.py @@ -28,6 +28,7 @@ import comfy.model_detection import comfy.model_patcher import comfy.ops import comfy.latent_formats +import comfy.model_base import comfy.cldm.cldm import comfy.t2i_adapter.adapter @@ -264,12 +265,12 @@ class ControlNet(ControlBase): for c in self.extra_conds: temp = cond.get(c, None) if temp is not None: - extra[c] = temp.to(dtype) + extra[c] = comfy.model_base.convert_tensor(temp, dtype, x_noisy.device) timestep = self.model_sampling_current.timestep(t) x_noisy = self.model_sampling_current.calculate_input(t, x_noisy) - control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.to(dtype), context=context.to(dtype), **extra) + control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.to(dtype), context=comfy.model_management.cast_to_device(context, x_noisy.device, dtype), **extra) return self.control_merge(control, control_prev, output_dtype=None) def copy(self): From 7991341e89cab521441641505ac4b0eea292a829 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 4 Aug 2025 01:02:40 -0700 Subject: [PATCH 016/325] Various fixes for broken things from earlier PR. 
(#9168) --- comfy/model_base.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/comfy/model_base.py b/comfy/model_base.py index 3a9c031ea..f9591f292 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -401,7 +401,7 @@ class SD21UNCLIP(BaseModel): unclip_conditioning = kwargs.get("unclip_conditioning", None) device = kwargs["device"] if unclip_conditioning is None: - return torch.zeros((1, self.adm_channels)) + return torch.zeros((1, self.adm_channels), device=device) else: return unclip_adm(unclip_conditioning, device, self.noise_augmentor, kwargs.get("unclip_noise_augment_merge", 0.05), kwargs.get("seed", 0) - 10) @@ -409,7 +409,7 @@ def sdxl_pooled(args, noise_augmentor): if "unclip_conditioning" in args: return unclip_adm(args.get("unclip_conditioning", None), args["device"], noise_augmentor, seed=args.get("seed", 0) - 10)[:,:1280] else: - return args["pooled_output"] + return args["pooled_output"].to(device=args["device"]) class SDXLRefiner(BaseModel): def __init__(self, model_config, model_type=ModelType.EPS, device=None): @@ -615,9 +615,11 @@ class IP2P: if image is None: image = torch.zeros_like(noise) + else: + image = image.to(device=device) if image.shape[1:] != noise.shape[1:]: - image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center") + image = utils.common_upscale(image, noise.shape[-1], noise.shape[-2], "bilinear", "center") image = utils.resize_to_batch_size(image, noise.shape[0]) return self.process_ip2p_image_in(image) @@ -696,7 +698,7 @@ class StableCascade_B(BaseModel): #size of prior doesn't really matter if zeros because it gets resized but I still want it to get batched prior = kwargs.get("stable_cascade_prior", torch.zeros((1, 16, (noise.shape[2] * 4) // 42, (noise.shape[3] * 4) // 42), dtype=noise.dtype, layout=noise.layout, device=noise.device)) - out["effnet"] = comfy.conds.CONDRegular(prior) + out["effnet"] = comfy.conds.CONDRegular(prior.to(device=noise.device)) out["sca"] = comfy.conds.CONDRegular(torch.zeros((1,))) return out @@ -1161,10 +1163,10 @@ class WAN21_Vace(WAN21): vace_frames_out = [] for j in range(len(vace_frames)): - vf = vace_frames[j].clone() + vf = vace_frames[j].to(device=noise.device, dtype=noise.dtype, copy=True) for i in range(0, vf.shape[1], 16): vf[:, i:i + 16] = self.process_latent_in(vf[:, i:i + 16]) - vf = torch.cat([vf, mask[j]], dim=1) + vf = torch.cat([vf, mask[j].to(device=noise.device, dtype=noise.dtype)], dim=1) vace_frames_out.append(vf) vace_frames = torch.stack(vace_frames_out, dim=1) From 84f9759424ccbd8de710960c79f0f1d28eef2776 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 4 Aug 2025 01:20:12 -0700 Subject: [PATCH 017/325] Add some warnings and prevent crash when cond devices don't match. 
(#9169) --- comfy/conds.py | 7 +++++++ comfy/model_base.py | 2 +- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/comfy/conds.py b/comfy/conds.py index f2564e7ef..5af3e93ea 100644 --- a/comfy/conds.py +++ b/comfy/conds.py @@ -1,6 +1,7 @@ import torch import math import comfy.utils +import logging class CONDRegular: @@ -16,6 +17,9 @@ class CONDRegular: def can_concat(self, other): if self.cond.shape != other.cond.shape: return False + if self.cond.device != other.cond.device: + logging.warning("WARNING: conds not on same device, skipping concat.") + return False return True def concat(self, others): @@ -51,6 +55,9 @@ class CONDCrossAttn(CONDRegular): diff = mult_min // min(s1[1], s2[1]) if diff > 4: #arbitrary limit on the padding because it's probably going to impact performance negatively if it's too much return False + if self.cond.device != other.cond.device: + logging.warning("WARNING: conds not on same device: skipping concat.") + return False return True def concat(self, others): diff --git a/comfy/model_base.py b/comfy/model_base.py index f9591f292..2db81e244 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -409,7 +409,7 @@ def sdxl_pooled(args, noise_augmentor): if "unclip_conditioning" in args: return unclip_adm(args.get("unclip_conditioning", None), args["device"], noise_augmentor, seed=args.get("seed", 0) - 10)[:,:1280] else: - return args["pooled_output"].to(device=args["device"]) + return args["pooled_output"] class SDXLRefiner(BaseModel): def __init__(self, model_config, model_type=ModelType.EPS, device=None): From 03895dea7c4a6cc93fa362cd11ca450217d74b13 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 4 Aug 2025 01:33:04 -0700 Subject: [PATCH 018/325] Fix another issue with the PR. (#9170) --- comfy/model_base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/model_base.py b/comfy/model_base.py index 2db81e244..a06686436 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -162,7 +162,7 @@ class BaseModel(torch.nn.Module): xc = self.model_sampling.calculate_input(sigma, x) if c_concat is not None: - xc = torch.cat([xc] + [c_concat], dim=1) + xc = torch.cat([xc] + [comfy.model_management.cast_to_device(c_concat, xc.device, xc.dtype)], dim=1) context = c_crossattn dtype = self.get_dtype() From c012400240d4867cd63a45220eb791b91ad47617 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 4 Aug 2025 19:53:25 -0700 Subject: [PATCH 019/325] Initial support for qwen image model. 
(#9179) --- comfy/ldm/qwen_image/model.py | 399 ++++++++++++++++++++++++++++++ comfy/model_base.py | 12 + comfy/model_detection.py | 7 +- comfy/sd.py | 12 +- comfy/supported_models.py | 32 ++- comfy/text_encoders/llama.py | 26 ++ comfy/text_encoders/qwen_image.py | 71 ++++++ nodes.py | 2 +- 8 files changed, 557 insertions(+), 4 deletions(-) create mode 100644 comfy/ldm/qwen_image/model.py create mode 100644 comfy/text_encoders/qwen_image.py diff --git a/comfy/ldm/qwen_image/model.py b/comfy/ldm/qwen_image/model.py new file mode 100644 index 000000000..ff631a60f --- /dev/null +++ b/comfy/ldm/qwen_image/model.py @@ -0,0 +1,399 @@ +# https://github.com/QwenLM/Qwen-Image (Apache 2.0) +import torch +import torch.nn as nn +import torch.nn.functional as F +from typing import Optional, Tuple +from einops import repeat + +from comfy.ldm.lightricks.model import TimestepEmbedding, Timesteps +from comfy.ldm.modules.attention import optimized_attention_masked +from comfy.ldm.flux.layers import EmbedND + + +class GELU(nn.Module): + def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: bool = True, dtype=None, device=None, operations=None): + super().__init__() + self.proj = operations.Linear(dim_in, dim_out, bias=bias, dtype=dtype, device=device) + self.approximate = approximate + + def forward(self, hidden_states): + hidden_states = self.proj(hidden_states) + hidden_states = F.gelu(hidden_states, approximate=self.approximate) + return hidden_states + + +class FeedForward(nn.Module): + def __init__( + self, + dim: int, + dim_out: Optional[int] = None, + mult: int = 4, + dropout: float = 0.0, + inner_dim=None, + bias: bool = True, + dtype=None, device=None, operations=None + ): + super().__init__() + if inner_dim is None: + inner_dim = int(dim * mult) + dim_out = dim_out if dim_out is not None else dim + + self.net = nn.ModuleList([]) + self.net.append(GELU(dim, inner_dim, approximate="tanh", bias=bias, dtype=dtype, device=device, operations=operations)) + self.net.append(nn.Dropout(dropout)) + self.net.append(operations.Linear(inner_dim, dim_out, bias=bias, dtype=dtype, device=device)) + + def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor: + for module in self.net: + hidden_states = module(hidden_states) + return hidden_states + + +def apply_rotary_emb(x, freqs_cis): + if x.shape[1] == 0: + return x + + t_ = x.reshape(*x.shape[:-1], -1, 1, 2) + t_out = freqs_cis[..., 0] * t_[..., 0] + freqs_cis[..., 1] * t_[..., 1] + return t_out.reshape(*x.shape) + + +class QwenTimestepProjEmbeddings(nn.Module): + def __init__(self, embedding_dim, pooled_projection_dim, dtype=None, device=None, operations=None): + super().__init__() + self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0, scale=1000) + self.timestep_embedder = TimestepEmbedding( + in_channels=256, + time_embed_dim=embedding_dim, + dtype=dtype, + device=device, + operations=operations + ) + + def forward(self, timestep, hidden_states): + timesteps_proj = self.time_proj(timestep) + timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_states.dtype)) + return timesteps_emb + + +class Attention(nn.Module): + def __init__( + self, + query_dim: int, + dim_head: int = 64, + heads: int = 8, + dropout: float = 0.0, + bias: bool = False, + eps: float = 1e-5, + out_bias: bool = True, + out_dim: int = None, + out_context_dim: int = None, + dtype=None, + device=None, + operations=None + ): + super().__init__() + self.inner_dim = out_dim if out_dim is not None 
else dim_head * heads + self.inner_kv_dim = self.inner_dim + self.heads = heads + self.dim_head = dim_head + self.out_dim = out_dim if out_dim is not None else query_dim + self.out_context_dim = out_context_dim if out_context_dim is not None else query_dim + self.dropout = dropout + + # Q/K normalization + self.norm_q = operations.RMSNorm(dim_head, eps=eps, elementwise_affine=True, dtype=dtype, device=device) + self.norm_k = operations.RMSNorm(dim_head, eps=eps, elementwise_affine=True, dtype=dtype, device=device) + self.norm_added_q = operations.RMSNorm(dim_head, eps=eps, dtype=dtype, device=device) + self.norm_added_k = operations.RMSNorm(dim_head, eps=eps, dtype=dtype, device=device) + + # Image stream projections + self.to_q = operations.Linear(query_dim, self.inner_dim, bias=bias, dtype=dtype, device=device) + self.to_k = operations.Linear(query_dim, self.inner_kv_dim, bias=bias, dtype=dtype, device=device) + self.to_v = operations.Linear(query_dim, self.inner_kv_dim, bias=bias, dtype=dtype, device=device) + + # Text stream projections + self.add_q_proj = operations.Linear(query_dim, self.inner_dim, bias=bias, dtype=dtype, device=device) + self.add_k_proj = operations.Linear(query_dim, self.inner_kv_dim, bias=bias, dtype=dtype, device=device) + self.add_v_proj = operations.Linear(query_dim, self.inner_kv_dim, bias=bias, dtype=dtype, device=device) + + # Output projections + self.to_out = nn.ModuleList([ + operations.Linear(self.inner_dim, self.out_dim, bias=out_bias, dtype=dtype, device=device), + nn.Dropout(dropout) + ]) + self.to_add_out = operations.Linear(self.inner_dim, self.out_context_dim, bias=out_bias, dtype=dtype, device=device) + + def forward( + self, + hidden_states: torch.FloatTensor, # Image stream + encoder_hidden_states: torch.FloatTensor = None, # Text stream + encoder_hidden_states_mask: torch.FloatTensor = None, + attention_mask: Optional[torch.FloatTensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + seq_txt = encoder_hidden_states.shape[1] + + img_query = self.to_q(hidden_states).unflatten(-1, (self.heads, -1)) + img_key = self.to_k(hidden_states).unflatten(-1, (self.heads, -1)) + img_value = self.to_v(hidden_states).unflatten(-1, (self.heads, -1)) + + txt_query = self.add_q_proj(encoder_hidden_states).unflatten(-1, (self.heads, -1)) + txt_key = self.add_k_proj(encoder_hidden_states).unflatten(-1, (self.heads, -1)) + txt_value = self.add_v_proj(encoder_hidden_states).unflatten(-1, (self.heads, -1)) + + img_query = self.norm_q(img_query) + img_key = self.norm_k(img_key) + txt_query = self.norm_added_q(txt_query) + txt_key = self.norm_added_k(txt_key) + + joint_query = torch.cat([txt_query, img_query], dim=1) + joint_key = torch.cat([txt_key, img_key], dim=1) + joint_value = torch.cat([txt_value, img_value], dim=1) + + joint_query = apply_rotary_emb(joint_query, image_rotary_emb) + joint_key = apply_rotary_emb(joint_key, image_rotary_emb) + + joint_query = joint_query.flatten(start_dim=2) + joint_key = joint_key.flatten(start_dim=2) + joint_value = joint_value.flatten(start_dim=2) + + joint_hidden_states = optimized_attention_masked(joint_query, joint_key, joint_value, self.heads, attention_mask) + + txt_attn_output = joint_hidden_states[:, :seq_txt, :] + img_attn_output = joint_hidden_states[:, seq_txt:, :] + + img_attn_output = self.to_out[0](img_attn_output) + img_attn_output = self.to_out[1](img_attn_output) + txt_attn_output = self.to_add_out(txt_attn_output) + + return img_attn_output, txt_attn_output 
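# A minimal standalone sketch of the joint text+image attention pattern used by
# Attention.forward above: each stream keeps its own QKV projection, the two
# sequences are concatenated text-first, a single attention pass runs over the
# joint sequence, and the output is split back at the text length (seq_txt).
# Plain scaled_dot_product_attention stands in for optimized_attention_masked;
# the Q/K RMSNorm and rotary-embedding steps are omitted, and the sizes below
# are illustrative assumptions, not the model's actual configuration.
import torch
import torch.nn.functional as F

def joint_stream_attention(txt, img, qkv_txt, qkv_img, heads):
    # txt: [B, T_txt, C], img: [B, T_img, C]
    t_txt = txt.shape[1]
    q_t, k_t, v_t = qkv_txt(txt).chunk(3, dim=-1)
    q_i, k_i, v_i = qkv_img(img).chunk(3, dim=-1)
    # joint sequence with text tokens first, mirroring torch.cat([txt_*, img_*], dim=1)
    q = torch.cat([q_t, q_i], dim=1)
    k = torch.cat([k_t, k_i], dim=1)
    v = torch.cat([v_t, v_i], dim=1)

    def split_heads(x):  # [B, T, C] -> [B, heads, T, C // heads]
        b, t, c = x.shape
        return x.reshape(b, t, heads, c // heads).transpose(1, 2)

    out = F.scaled_dot_product_attention(split_heads(q), split_heads(k), split_heads(v))
    out = out.transpose(1, 2).flatten(2)  # back to [B, T_txt + T_img, C]
    # split the joint result into per-stream outputs, like the seq_txt slicing above
    return out[:, :t_txt], out[:, t_txt:]

# usage with assumed toy sizes
if __name__ == "__main__":
    channels, n_heads = 128, 8
    qkv_txt = torch.nn.Linear(channels, 3 * channels)
    qkv_img = torch.nn.Linear(channels, 3 * channels)
    txt, img = torch.randn(1, 16, channels), torch.randn(1, 64, channels)
    txt_out, img_out = joint_stream_attention(txt, img, qkv_txt, qkv_img, n_heads)
    print(txt_out.shape, img_out.shape)  # torch.Size([1, 16, 128]) torch.Size([1, 64, 128])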
+ + +class QwenImageTransformerBlock(nn.Module): + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + eps: float = 1e-6, + dtype=None, + device=None, + operations=None + ): + super().__init__() + self.dim = dim + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + + self.img_mod = nn.Sequential( + nn.SiLU(), + operations.Linear(dim, 6 * dim, bias=True, dtype=dtype, device=device), + ) + self.img_norm1 = operations.LayerNorm(dim, elementwise_affine=False, eps=eps, dtype=dtype, device=device) + self.img_norm2 = operations.LayerNorm(dim, elementwise_affine=False, eps=eps, dtype=dtype, device=device) + self.img_mlp = FeedForward(dim=dim, dim_out=dim, dtype=dtype, device=device, operations=operations) + + self.txt_mod = nn.Sequential( + nn.SiLU(), + operations.Linear(dim, 6 * dim, bias=True, dtype=dtype, device=device), + ) + self.txt_norm1 = operations.LayerNorm(dim, elementwise_affine=False, eps=eps, dtype=dtype, device=device) + self.txt_norm2 = operations.LayerNorm(dim, elementwise_affine=False, eps=eps, dtype=dtype, device=device) + self.txt_mlp = FeedForward(dim=dim, dim_out=dim, dtype=dtype, device=device, operations=operations) + + self.attn = Attention( + query_dim=dim, + dim_head=attention_head_dim, + heads=num_attention_heads, + out_dim=dim, + bias=True, + eps=eps, + dtype=dtype, + device=device, + operations=operations, + ) + + def _modulate(self, x, mod_params): + shift, scale, gate = mod_params.chunk(3, dim=-1) + return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1), gate.unsqueeze(1) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + encoder_hidden_states_mask: torch.Tensor, + temb: torch.Tensor, + image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + img_mod_params = self.img_mod(temb) + txt_mod_params = self.txt_mod(temb) + img_mod1, img_mod2 = img_mod_params.chunk(2, dim=-1) + txt_mod1, txt_mod2 = txt_mod_params.chunk(2, dim=-1) + + img_normed = self.img_norm1(hidden_states) + img_modulated, img_gate1 = self._modulate(img_normed, img_mod1) + txt_normed = self.txt_norm1(encoder_hidden_states) + txt_modulated, txt_gate1 = self._modulate(txt_normed, txt_mod1) + + img_attn_output, txt_attn_output = self.attn( + hidden_states=img_modulated, + encoder_hidden_states=txt_modulated, + encoder_hidden_states_mask=encoder_hidden_states_mask, + image_rotary_emb=image_rotary_emb, + ) + + hidden_states = hidden_states + img_gate1 * img_attn_output + encoder_hidden_states = encoder_hidden_states + txt_gate1 * txt_attn_output + + img_normed2 = self.img_norm2(hidden_states) + img_modulated2, img_gate2 = self._modulate(img_normed2, img_mod2) + hidden_states = hidden_states + img_gate2 * self.img_mlp(img_modulated2) + + txt_normed2 = self.txt_norm2(encoder_hidden_states) + txt_modulated2, txt_gate2 = self._modulate(txt_normed2, txt_mod2) + encoder_hidden_states = encoder_hidden_states + txt_gate2 * self.txt_mlp(txt_modulated2) + + return encoder_hidden_states, hidden_states + + +class LastLayer(nn.Module): + def __init__( + self, + embedding_dim: int, + conditioning_embedding_dim: int, + elementwise_affine=False, + eps=1e-6, + bias=True, + dtype=None, device=None, operations=None + ): + super().__init__() + self.silu = nn.SiLU() + self.linear = operations.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=bias, dtype=dtype, device=device) + self.norm = operations.LayerNorm(embedding_dim, eps, 
elementwise_affine=False, bias=bias, dtype=dtype, device=device) + + def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torch.Tensor: + emb = self.linear(self.silu(conditioning_embedding)) + scale, shift = torch.chunk(emb, 2, dim=1) + x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :] + return x + + +class QwenImageTransformer2DModel(nn.Module): + def __init__( + self, + patch_size: int = 2, + in_channels: int = 64, + out_channels: Optional[int] = 16, + num_layers: int = 60, + attention_head_dim: int = 128, + num_attention_heads: int = 24, + joint_attention_dim: int = 3584, + pooled_projection_dim: int = 768, + guidance_embeds: bool = False, + axes_dims_rope: Tuple[int, int, int] = (16, 56, 56), + image_model=None, + dtype=None, + device=None, + operations=None, + ): + super().__init__() + self.dtype = dtype + self.patch_size = patch_size + self.out_channels = out_channels or in_channels + self.inner_dim = num_attention_heads * attention_head_dim + + self.pe_embedder = EmbedND(dim=attention_head_dim, theta=10000, axes_dim=list(axes_dims_rope)) + + self.time_text_embed = QwenTimestepProjEmbeddings( + embedding_dim=self.inner_dim, + pooled_projection_dim=pooled_projection_dim, + dtype=dtype, + device=device, + operations=operations + ) + + self.txt_norm = operations.RMSNorm(joint_attention_dim, eps=1e-6, dtype=dtype, device=device) + self.img_in = operations.Linear(in_channels, self.inner_dim, dtype=dtype, device=device) + self.txt_in = operations.Linear(joint_attention_dim, self.inner_dim, dtype=dtype, device=device) + + self.transformer_blocks = nn.ModuleList([ + QwenImageTransformerBlock( + dim=self.inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + dtype=dtype, + device=device, + operations=operations + ) + for _ in range(num_layers) + ]) + + self.norm_out = LastLayer(self.inner_dim, self.inner_dim, dtype=dtype, device=device, operations=operations) + self.proj_out = operations.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True, dtype=dtype, device=device) + self.gradient_checkpointing = False + + def pos_embeds(self, x, context): + bs, c, t, h, w = x.shape + patch_size = self.patch_size + h_len = ((h + (patch_size // 2)) // patch_size) + w_len = ((w + (patch_size // 2)) // patch_size) + + img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype) + img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1) + img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0) + img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs) + + txt_start = round(max(h_len, w_len)) + txt_ids = torch.linspace(txt_start, txt_start + context.shape[1], steps=context.shape[1], device=x.device, dtype=x.dtype).reshape(1, -1, 1).repeat(bs, 1, 3) + ids = torch.cat((txt_ids, img_ids), dim=1) + return self.pe_embedder(ids).squeeze(1).unsqueeze(2).to(x.dtype) + + def forward( + self, + x, + timesteps, + context, + attention_mask=None, + guidance: torch.Tensor = None, + **kwargs + ): + timestep = timesteps + encoder_hidden_states = context + encoder_hidden_states_mask = attention_mask + + image_rotary_emb = self.pos_embeds(x, context) + + orig_shape = x.shape + hidden_states = x.view(orig_shape[0], orig_shape[1], orig_shape[-2] // 2, 2, orig_shape[-1] // 2, 2) + hidden_states = hidden_states.permute(0, 2, 4, 1, 3, 5) + hidden_states = 
hidden_states.reshape(orig_shape[0], (orig_shape[-2] // 2) * (orig_shape[-1] // 2), orig_shape[1] * 4) + + hidden_states = self.img_in(hidden_states) + encoder_hidden_states = self.txt_norm(encoder_hidden_states) + encoder_hidden_states = self.txt_in(encoder_hidden_states) + + if guidance is not None: + guidance = guidance * 1000 + + temb = ( + self.time_text_embed(timestep, hidden_states) + if guidance is None + else self.time_text_embed(timestep, guidance, hidden_states) + ) + + for block in self.transformer_blocks: + encoder_hidden_states, hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + encoder_hidden_states_mask=encoder_hidden_states_mask, + temb=temb, + image_rotary_emb=image_rotary_emb, + ) + + hidden_states = self.norm_out(hidden_states, temb) + hidden_states = self.proj_out(hidden_states) + + hidden_states = hidden_states.view(orig_shape[0], orig_shape[-2] // 2, orig_shape[-1] // 2, orig_shape[1], 2, 2) + hidden_states = hidden_states.permute(0, 3, 1, 4, 2, 5) + return hidden_states.reshape(orig_shape) diff --git a/comfy/model_base.py b/comfy/model_base.py index a06686436..8a2d9cbe6 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -42,6 +42,7 @@ import comfy.ldm.hidream.model import comfy.ldm.chroma.model import comfy.ldm.ace.model import comfy.ldm.omnigen.omnigen2 +import comfy.ldm.qwen_image.model import comfy.model_management import comfy.patcher_extension @@ -1308,3 +1309,14 @@ class Omnigen2(BaseModel): if ref_latents is not None: out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()), ref_latents)) // 16]) return out + +class QwenImage(BaseModel): + def __init__(self, model_config, model_type=ModelType.FLUX, device=None): + super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.qwen_image.model.QwenImageTransformer2DModel) + + def extra_conds(self, **kwargs): + out = super().extra_conds(**kwargs) + cross_attn = kwargs.get("cross_attn", None) + if cross_attn is not None: + out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) + return out diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 9fc1f42de..8b57ebd2f 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -481,6 +481,11 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["timestep_scale"] = 1000.0 return dit_config + if '{}txt_norm.weight'.format(key_prefix) in state_dict_keys: # Qwen Image + dit_config = {} + dit_config["image_model"] = "qwen_image" + return dit_config + if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys: return None @@ -867,7 +872,7 @@ def convert_diffusers_mmdit(state_dict, output_prefix=""): depth_single_blocks = count_blocks(state_dict, 'single_transformer_blocks.{}.') hidden_size = state_dict["x_embedder.bias"].shape[0] sd_map = comfy.utils.flux_to_diffusers({"depth": depth, "depth_single_blocks": depth_single_blocks, "hidden_size": hidden_size}, output_prefix=output_prefix) - elif 'transformer_blocks.0.attn.add_q_proj.weight' in state_dict: #SD3 + elif 'transformer_blocks.0.attn.add_q_proj.weight' in state_dict and 'pos_embed.proj.weight' in state_dict: #SD3 num_blocks = count_blocks(state_dict, 'transformer_blocks.{}.') depth = state_dict["pos_embed.proj.weight"].shape[0] // 64 sd_map = comfy.utils.mmdit_to_diffusers({"depth": depth, "num_blocks": num_blocks}, output_prefix=output_prefix) diff --git a/comfy/sd.py b/comfy/sd.py index e0498e585..bb5d61fb3 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ 
-47,6 +47,7 @@ import comfy.text_encoders.wan import comfy.text_encoders.hidream import comfy.text_encoders.ace import comfy.text_encoders.omnigen2 +import comfy.text_encoders.qwen_image import comfy.model_patcher import comfy.lora @@ -771,6 +772,7 @@ class CLIPType(Enum): CHROMA = 15 ACE = 16 OMNIGEN2 = 17 + QWEN_IMAGE = 18 def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}): @@ -791,6 +793,7 @@ class TEModel(Enum): T5_XXL_OLD = 8 GEMMA_2_2B = 9 QWEN25_3B = 10 + QWEN25_7B = 11 def detect_te_model(sd): if "text_model.encoder.layers.30.mlp.fc1.weight" in sd: @@ -812,7 +815,11 @@ def detect_te_model(sd): if 'model.layers.0.post_feedforward_layernorm.weight' in sd: return TEModel.GEMMA_2_2B if 'model.layers.0.self_attn.k_proj.bias' in sd: - return TEModel.QWEN25_3B + weight = sd['model.layers.0.self_attn.k_proj.bias'] + if weight.shape[0] == 256: + return TEModel.QWEN25_3B + if weight.shape[0] == 512: + return TEModel.QWEN25_7B if "model.layers.0.post_attention_layernorm.weight" in sd: return TEModel.LLAMA3_8 return None @@ -917,6 +924,9 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip elif te_model == TEModel.QWEN25_3B: clip_target.clip = comfy.text_encoders.omnigen2.te(**llama_detect(clip_data)) clip_target.tokenizer = comfy.text_encoders.omnigen2.Omnigen2Tokenizer + elif te_model == TEModel.QWEN25_7B: + clip_target.clip = comfy.text_encoders.qwen_image.te(**llama_detect(clip_data)) + clip_target.tokenizer = comfy.text_encoders.qwen_image.QwenImageTokenizer else: # clip_l if clip_type == CLIPType.SD3: diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 8f3f4652d..880055bd3 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -19,6 +19,7 @@ import comfy.text_encoders.lumina2 import comfy.text_encoders.wan import comfy.text_encoders.ace import comfy.text_encoders.omnigen2 +import comfy.text_encoders.qwen_image from . import supported_models_base from . 
import latent_formats @@ -1229,7 +1230,36 @@ class Omnigen2(supported_models_base.BASE): hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_3b.transformer.".format(pref)) return supported_models_base.ClipTarget(comfy.text_encoders.omnigen2.Omnigen2Tokenizer, comfy.text_encoders.omnigen2.te(**hunyuan_detect)) +class QwenImage(supported_models_base.BASE): + unet_config = { + "image_model": "qwen_image", + } -models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma, ACEStep, Omnigen2] + sampling_settings = { + "multiplier": 1.0, + "shift": 2.6, + } + + memory_usage_factor = 1.8 #TODO + + unet_extra_config = {} + latent_format = latent_formats.Wan21 + + supported_inference_dtypes = [torch.bfloat16, torch.float32] + + vae_key_prefix = ["vae."] + text_encoder_key_prefix = ["text_encoders."] + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.QwenImage(self, device=device) + return out + + def clip_target(self, state_dict={}): + pref = self.text_encoder_key_prefix[0] + hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref)) + return supported_models_base.ClipTarget(comfy.text_encoders.qwen_image.QwenImageTokenizer, comfy.text_encoders.qwen_image.te(**hunyuan_detect)) + + +models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma, ACEStep, Omnigen2, QwenImage] models += [SVD_img2vid] diff --git a/comfy/text_encoders/llama.py b/comfy/text_encoders/llama.py index 7fbd0f604..1da6a0c94 100644 --- a/comfy/text_encoders/llama.py +++ b/comfy/text_encoders/llama.py @@ -43,6 +43,23 @@ class Qwen25_3BConfig: mlp_activation = "silu" qkv_bias = True +@dataclass +class Qwen25_7BVLI_Config: + vocab_size: int = 152064 + hidden_size: int = 3584 + intermediate_size: int = 18944 + num_hidden_layers: int = 28 + num_attention_heads: int = 28 + num_key_value_heads: int = 4 + max_position_embeddings: int = 128000 + rms_norm_eps: float = 1e-6 + rope_theta: float = 1000000.0 + transformer_type: str = "llama" + head_dim = 128 + rms_norm_add = False + mlp_activation = "silu" + qkv_bias = True + @dataclass class Gemma2_2B_Config: vocab_size: int = 256000 @@ -348,6 +365,15 @@ class Qwen25_3B(BaseLlama, torch.nn.Module): self.model = Llama2_(config, device=device, dtype=dtype, ops=operations) self.dtype = dtype +class Qwen25_7BVLI(BaseLlama, torch.nn.Module): + def __init__(self, config_dict, 
dtype, device, operations): + super().__init__() + config = Qwen25_7BVLI_Config(**config_dict) + self.num_layers = config.num_hidden_layers + + self.model = Llama2_(config, device=device, dtype=dtype, ops=operations) + self.dtype = dtype + class Gemma2_2B(BaseLlama, torch.nn.Module): def __init__(self, config_dict, dtype, device, operations): super().__init__() diff --git a/comfy/text_encoders/qwen_image.py b/comfy/text_encoders/qwen_image.py new file mode 100644 index 000000000..ce5c98097 --- /dev/null +++ b/comfy/text_encoders/qwen_image.py @@ -0,0 +1,71 @@ +from transformers import Qwen2Tokenizer +from comfy import sd1_clip +import comfy.text_encoders.llama +import os +import torch +import numbers + +class Qwen25_7BVLITokenizer(sd1_clip.SDTokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "qwen25_tokenizer") + super().__init__(tokenizer_path, pad_with_end=False, embedding_size=3584, embedding_key='qwen25_7b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, pad_token=151643, tokenizer_data=tokenizer_data) + + +class QwenImageTokenizer(sd1_clip.SD1Tokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="qwen25_7b", tokenizer=Qwen25_7BVLITokenizer) + self.llama_template = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n" + + def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None,**kwargs): + if llama_template is None: + llama_text = self.llama_template.format(text) + else: + llama_text = llama_template.format(text) + return super().tokenize_with_weights(llama_text, return_word_ids=return_word_ids, **kwargs) + + +class Qwen25_7BVLIModel(sd1_clip.SDClipModel): + def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options={}): + super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Qwen25_7BVLI, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options) + + +class QwenImageTEModel(sd1_clip.SD1ClipModel): + def __init__(self, device="cpu", dtype=None, model_options={}): + super().__init__(device=device, dtype=dtype, name="qwen25_7b", clip_model=Qwen25_7BVLIModel, model_options=model_options) + + def encode_token_weights(self, token_weight_pairs): + out, pooled, extra = super().encode_token_weights(token_weight_pairs) + tok_pairs = token_weight_pairs["qwen25_7b"][0] + count_im_start = 0 + for i, v in enumerate(tok_pairs): + elem = v[0] + if not torch.is_tensor(elem): + if isinstance(elem, numbers.Integral): + if elem == 151644 and count_im_start < 2: + template_end = i + count_im_start += 1 + + if out.shape[1] > (template_end + 3): + if tok_pairs[template_end + 1][0] == 872: + if tok_pairs[template_end + 2][0] == 198: + template_end += 3 + + out = out[:, template_end:] + + extra["attention_mask"] = extra["attention_mask"][:, template_end:] + if extra["attention_mask"].sum() == torch.numel(extra["attention_mask"]): + extra.pop("attention_mask") # 
attention mask is useless if no masked elements + + return out, pooled, extra + + +def te(dtype_llama=None, llama_scaled_fp8=None): + class QwenImageTEModel_(QwenImageTEModel): + def __init__(self, device="cpu", dtype=None, model_options={}): + if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options: + model_options = model_options.copy() + model_options["scaled_fp8"] = llama_scaled_fp8 + if dtype_llama is not None: + dtype = dtype_llama + super().__init__(device=device, dtype=dtype, model_options=model_options) + return QwenImageTEModel_ diff --git a/nodes.py b/nodes.py index da4a46366..9bedbcaca 100644 --- a/nodes.py +++ b/nodes.py @@ -925,7 +925,7 @@ class CLIPLoader: @classmethod def INPUT_TYPES(s): return {"required": { "clip_name": (folder_paths.get_filename_list("text_encoders"), ), - "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2"], ), + "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image"], ), }, "optional": { "device": (["default", "cpu"], {"advanced": True}), From f69609bbd6c20f4814e313f8974656b187a9bee2 Mon Sep 17 00:00:00 2001 From: Christian Byrne Date: Mon, 4 Aug 2025 22:52:25 -0700 Subject: [PATCH 020/325] Add Veo3 video generation node with audio support (#9110) - Create new Veo3VideoGenerationNode that extends VeoVideoGenerationNode - Add support for generateAudio parameter (only for Veo3 models) - Support new Veo3 models: veo-3.0-generate-001, veo-3.0-fast-generate-001 - Fix Veo3 duration constraint to 8 seconds only - Update original node to be clearly Veo 2 only - Update API paths to use model parameter: /proxy/veo/{model}/generate - Regenerate API types from staging to include generateAudio parameter - Fix TripoModelVersion enum reference after regeneration - Mark generated API types file in .gitattributes --- .gitattributes | 1 + comfy_api_nodes/apis/__init__.py | 2656 ++++++++++++++++++++++++++++- comfy_api_nodes/apis/tripo_api.py | 2 +- comfy_api_nodes/nodes_veo2.py | 98 +- 4 files changed, 2664 insertions(+), 93 deletions(-) diff --git a/.gitattributes b/.gitattributes index 4391de678..5b3c15bb4 100644 --- a/.gitattributes +++ b/.gitattributes @@ -1,2 +1,3 @@ /web/assets/** linguist-generated /web/** linguist-vendored +comfy_api_nodes/apis/__init__.py linguist-generated diff --git a/comfy_api_nodes/apis/__init__.py b/comfy_api_nodes/apis/__init__.py index 086028abe..54298e8a9 100644 --- a/comfy_api_nodes/apis/__init__.py +++ b/comfy_api_nodes/apis/__init__.py @@ -1,6 +1,6 @@ # generated by datamodel-codegen: # filename: filtered-openapi.yaml -# timestamp: 2025-07-06T09:47:31+00:00 +# timestamp: 2025-07-30T08:54:00+00:00 from __future__ import annotations @@ -37,6 +37,99 @@ class AuditLog(BaseModel): ) +class BFLAsyncResponse(BaseModel): + id: str = Field(..., title='Id') + polling_url: str = Field(..., title='Polling Url') + + +class BFLAsyncWebhookResponse(BaseModel): + id: str = Field(..., title='Id') + status: str = Field(..., title='Status') + webhook_url: str = Field(..., title='Webhook Url') + + +class CannyHighThreshold(RootModel[int]): + root: int = Field( + ..., + description='High threshold for Canny edge detection', + ge=0, + le=500, + title='Canny High Threshold', + ) + + +class CannyLowThreshold(RootModel[int]): + root: int = Field( + ..., + description='Low threshold for Canny edge detection', + 
ge=0, + le=500, + title='Canny Low Threshold', + ) + + +class Guidance(RootModel[float]): + root: float = Field( + ..., + description='Guidance strength for the image generation process', + ge=1.0, + le=100.0, + title='Guidance', + ) + + +class Steps(RootModel[int]): + root: int = Field( + ..., + description='Number of steps for the image generation process', + ge=15, + le=50, + title='Steps', + ) + + +class WebhookUrl(RootModel[AnyUrl]): + root: AnyUrl = Field( + ..., description='URL to receive webhook notifications', title='Webhook Url' + ) + + +class BFLFluxKontextMaxGenerateRequest(BaseModel): + guidance: Optional[float] = Field( + 3, description='The guidance scale for generation', ge=1.0, le=20.0 + ) + input_image: str = Field(..., description='Base64 encoded image to be edited') + prompt: str = Field( + ..., description='The text prompt describing what to edit on the image' + ) + steps: Optional[int] = Field( + 50, description='Number of inference steps', ge=1, le=50 + ) + + +class BFLFluxKontextMaxGenerateResponse(BaseModel): + id: str = Field(..., description='Job ID for tracking') + polling_url: str = Field(..., description='URL to poll for results') + + +class BFLFluxKontextProGenerateRequest(BaseModel): + guidance: Optional[float] = Field( + 3, description='The guidance scale for generation', ge=1.0, le=20.0 + ) + input_image: str = Field(..., description='Base64 encoded image to be edited') + prompt: str = Field( + ..., description='The text prompt describing what to edit on the image' + ) + steps: Optional[int] = Field( + 50, description='Number of inference steps', ge=1, le=50 + ) + + +class BFLFluxKontextProGenerateResponse(BaseModel): + id: str = Field(..., description='Job ID for tracking') + polling_url: str = Field(..., description='URL to poll for results') + + class OutputFormat(str, Enum): jpeg = 'jpeg' png = 'png' @@ -68,6 +161,67 @@ class BFLFluxPro11GenerateResponse(BaseModel): polling_url: str = Field(..., description='URL to poll for results') +class Bottom(RootModel[int]): + root: int = Field( + ..., + description='Number of pixels to expand at the bottom of the image', + ge=0, + le=2048, + title='Bottom', + ) + + +class Guidance2(RootModel[float]): + root: float = Field( + ..., + description='Guidance strength for the image generation process', + ge=1.5, + le=100.0, + title='Guidance', + ) + + +class Left(RootModel[int]): + root: int = Field( + ..., + description='Number of pixels to expand on the left side of the image', + ge=0, + le=2048, + title='Left', + ) + + +class Right(RootModel[int]): + root: int = Field( + ..., + description='Number of pixels to expand on the right side of the image', + ge=0, + le=2048, + title='Right', + ) + + +class Steps2(RootModel[int]): + root: int = Field( + ..., + description='Number of steps for the image generation process', + examples=[50], + ge=15, + le=50, + title='Steps', + ) + + +class Top(RootModel[int]): + root: int = Field( + ..., + description='Number of pixels to expand at the top of the image', + ge=0, + le=2048, + title='Top', + ) + + class BFLFluxProGenerateRequest(BaseModel): guidance_scale: Optional[float] = Field( None, description='The guidance scale for generation.', ge=1.0, le=20.0 @@ -96,7 +250,71 @@ class BFLFluxProGenerateResponse(BaseModel): polling_url: str = Field(..., description='URL to poll for the generation result.') +class BFLOutputFormat(str, Enum): + jpeg = 'jpeg' + png = 'png' + + +class BFLValidationError(BaseModel): + loc: List[Union[str, int]] = Field(..., title='Location') + msg: str = 
Field(..., title='Message') + type: str = Field(..., title='Error Type') + + class Status(str, Enum): + success = 'success' + not_found = 'not_found' + error = 'error' + + +class ClaimMyNodeRequest(BaseModel): + GH_TOKEN: str = Field( + ..., description='GitHub token to verify if the user owns the repo of the node' + ) + + +class ComfyNode(BaseModel): + category: Optional[str] = Field( + None, + description='UI category where the node is listed, used for grouping nodes.', + ) + comfy_node_name: Optional[str] = Field( + None, description='Unique identifier for the node' + ) + deprecated: Optional[bool] = Field( + None, + description='Indicates if the node is deprecated. Deprecated nodes are hidden in the UI.', + ) + description: Optional[str] = Field( + None, description="Brief description of the node's functionality or purpose." + ) + experimental: Optional[bool] = Field( + None, + description='Indicates if the node is experimental, subject to changes or removal.', + ) + function: Optional[str] = Field( + None, description='Name of the entry-point function to execute the node.' + ) + input_types: Optional[str] = Field(None, description='Defines input parameters') + output_is_list: Optional[List[bool]] = Field( + None, description='Boolean values indicating if each output is a list.' + ) + return_names: Optional[str] = Field( + None, description='Names of the outputs for clarity in workflows.' + ) + return_types: Optional[str] = Field( + None, description='Specifies the types of outputs produced by the node.' + ) + + +class ComfyNodeCloudBuildInfo(BaseModel): + build_id: Optional[str] = None + location: Optional[str] = None + project_id: Optional[str] = None + project_number: Optional[str] = None + + +class Status1(str, Enum): in_progress = 'in_progress' completed = 'completed' incomplete = 'incomplete' @@ -113,7 +331,7 @@ class ComputerToolCall(BaseModel): description='An identifier used when responding to the tool call with output.\n', ) id: str = Field(..., description='The unique ID of the computer call.') - status: Status = Field( + status: Status1 = Field( ..., description='The status of the item. One of `in_progress`, `completed`, or\n`incomplete`. Populated when items are returned via API.\n', ) @@ -156,6 +374,7 @@ class Customer(BaseModel): None, description='The date and time the user was created' ) email: Optional[str] = Field(None, description='The email address for this user') + has_fund: Optional[bool] = Field(None, description='Whether the user has funds') id: str = Field(..., description='The firebase UID of the user') is_admin: Optional[bool] = Field(None, description='Whether the user is an admin') metronome_id: Optional[str] = Field(None, description='The Metronome customer ID') @@ -194,6 +413,16 @@ class Type2(str, Enum): message = 'message' +class Error(BaseModel): + details: Optional[List[str]] = Field( + None, + description='Optional detailed information about the error or hints for resolving it.', + ) + message: Optional[str] = Field( + None, description='A clear and concise description of the error.' 
+ ) + + class ErrorResponse(BaseModel): error: str message: str @@ -221,7 +450,7 @@ class Result(BaseModel): ) -class Status1(str, Enum): +class Status2(str, Enum): in_progress = 'in_progress' searching = 'searching' completed = 'completed' @@ -241,7 +470,7 @@ class FileSearchToolCall(BaseModel): results: Optional[List[Result]] = Field( None, description='The results of the file search tool call.\n' ) - status: Status1 = Field( + status: Status2 = Field( ..., description='The status of the file search tool call. One of `in_progress`, \n`searching`, `incomplete` or `failed`,\n', ) @@ -266,7 +495,7 @@ class FunctionTool(BaseModel): type: Literal['FunctionTool'] = Field(..., description='The type of tool') -class Status2(str, Enum): +class Status3(str, Enum): in_progress = 'in_progress' completed = 'completed' incomplete = 'incomplete' @@ -288,7 +517,7 @@ class FunctionToolCall(BaseModel): None, description='The unique ID of the function tool call.\n' ) name: str = Field(..., description='The name of the function to run.\n') - status: Optional[Status2] = Field( + status: Optional[Status3] = Field( None, description='The status of the item. One of `in_progress`, `completed`, or\n`incomplete`. Populated when items are returned via API.\n', ) @@ -442,6 +671,95 @@ class GeminiVideoMetadata(BaseModel): startOffset: Optional[GeminiOffset] = None +class GitCommitSummary(BaseModel): + author: Optional[str] = Field(None, description='The author of the commit') + branch_name: Optional[str] = Field( + None, description='The branch where the commit was made' + ) + commit_hash: Optional[str] = Field(None, description='The hash of the commit') + commit_name: Optional[str] = Field(None, description='The name of the commit') + status_summary: Optional[Dict[str, str]] = Field( + None, description='A map of operating system to status pairs' + ) + timestamp: Optional[datetime] = Field( + None, description='The timestamp when the commit was made' + ) + + +class GithubEnterprise(BaseModel): + avatar_url: str = Field(..., description='URL to the enterprise avatar') + created_at: datetime = Field(..., description='When the enterprise was created') + description: Optional[str] = Field(None, description='The enterprise description') + html_url: str = Field(..., description='The HTML URL of the enterprise') + id: int = Field(..., description='The enterprise ID') + name: str = Field(..., description='The enterprise name') + node_id: str = Field(..., description='The enterprise node ID') + slug: str = Field(..., description='The enterprise slug') + updated_at: datetime = Field( + ..., description='When the enterprise was last updated' + ) + website_url: Optional[str] = Field(None, description='The enterprise website URL') + + +class RepositorySelection(str, Enum): + selected = 'selected' + all = 'all' + + +class GithubOrganization(BaseModel): + avatar_url: str = Field(..., description="URL to the organization's avatar") + description: Optional[str] = Field(None, description='The organization description') + events_url: str = Field(..., description="The API URL of the organization's events") + hooks_url: str = Field(..., description="The API URL of the organization's hooks") + id: int = Field(..., description='The organization ID') + issues_url: str = Field(..., description="The API URL of the organization's issues") + login: str = Field(..., description="The organization's login name") + members_url: str = Field( + ..., description="The API URL of the organization's members" + ) + node_id: str = Field(..., 
description='The organization node ID') + public_members_url: str = Field( + ..., description="The API URL of the organization's public members" + ) + repos_url: str = Field( + ..., description="The API URL of the organization's repositories" + ) + url: str = Field(..., description='The API URL of the organization') + + +class State(str, Enum): + uploaded = 'uploaded' + open = 'open' + + +class Action(str, Enum): + published = 'published' + unpublished = 'unpublished' + created = 'created' + edited = 'edited' + deleted = 'deleted' + prereleased = 'prereleased' + released = 'released' + + +class Type7(str, Enum): + Bot = 'Bot' + User = 'User' + Organization = 'Organization' + + +class GithubUser(BaseModel): + avatar_url: str = Field(..., description="URL to the user's avatar") + gravatar_id: Optional[str] = Field(None, description="The user's gravatar ID") + html_url: str = Field(..., description='The HTML URL of the user') + id: int = Field(..., description="The user's ID") + login: str = Field(..., description="The user's login name") + node_id: str = Field(..., description="The user's node ID") + site_admin: bool = Field(..., description='Whether the user is a site admin') + type: Type7 = Field(..., description='The type of user') + url: str = Field(..., description='The API URL of the user') + + class IdeogramColorPalette1(BaseModel): name: str = Field(..., description='Name of the preset color palette') @@ -689,7 +1007,7 @@ class Includable(str, Enum): computer_call_output_output_image_url = 'computer_call_output.output.image_url' -class Type7(str, Enum): +class Type8(str, Enum): input_file = 'input_file' @@ -703,7 +1021,7 @@ class InputFileContent(BaseModel): filename: Optional[str] = Field( None, description='The name of the file to be sent to the model.' ) - type: Type7 = Field( + type: Type8 = Field( ..., description='The type of the input item. Always `input_file`.' ) @@ -714,7 +1032,7 @@ class Detail(str, Enum): auto = 'auto' -class Type8(str, Enum): +class Type9(str, Enum): input_image = 'input_image' @@ -730,7 +1048,7 @@ class InputImageContent(BaseModel): None, description='The URL of the image to be sent to the model. A fully qualified URL or base64 encoded image in a data URL.', ) - type: Type8 = Field( + type: Type9 = Field( ..., description='The type of the input item. Always `input_image`.' ) @@ -741,17 +1059,17 @@ class Role3(str, Enum): developer = 'developer' -class Type9(str, Enum): +class Type10(str, Enum): message = 'message' -class Type10(str, Enum): +class Type11(str, Enum): input_text = 'input_text' class InputTextContent(BaseModel): text: str = Field(..., description='The text input to the model.') - type: Type10 = Field( + type: Type11 = Field( ..., description='The type of the input item. Always `input_text`.' 
) @@ -923,7 +1241,7 @@ class ResourcePackType(str, Enum): constant_period = 'constant_period' -class Status4(str, Enum): +class Status5(str, Enum): toBeOnline = 'toBeOnline' online = 'online' expired = 'expired' @@ -949,7 +1267,7 @@ class ResourcePackSubscribeInfo(BaseModel): None, description='Resource package type (decreasing_total=decreasing total, constant_period=constant periodicity)', ) - status: Optional[Status4] = Field(None, description='Resource Package Status') + status: Optional[Status5] = Field(None, description='Resource Package Status') total_quantity: Optional[float] = Field(None, description='Total quantity') @@ -1113,7 +1431,7 @@ class LumaError(BaseModel): detail: Optional[str] = Field(None, description='The error message') -class Type11(str, Enum): +class Type12(str, Enum): generation = 'generation' @@ -1153,7 +1471,7 @@ class LumaImageRef(BaseModel): ) -class Type12(str, Enum): +class Type13(str, Enum): image = 'image' @@ -1223,6 +1541,36 @@ class LumaVideoModelOutputResolution( root: Union[LumaVideoModelOutputResolution1, str] +class MachineStats(BaseModel): + cpu_capacity: Optional[str] = Field(None, description='Total CPU on the machine.') + disk_capacity: Optional[str] = Field( + None, description='Total disk capacity on the machine.' + ) + gpu_type: Optional[str] = Field( + None, description='The GPU type. eg. NVIDIA Tesla K80' + ) + initial_cpu: Optional[str] = Field( + None, description='Initial CPU available before the job starts.' + ) + initial_disk: Optional[str] = Field( + None, description='Initial disk available before the job starts.' + ) + initial_ram: Optional[str] = Field( + None, description='Initial RAM available before the job starts.' + ) + machine_name: Optional[str] = Field(None, description='Name of the machine.') + memory_capacity: Optional[str] = Field( + None, description='Total memory on the machine.' + ) + os_version: Optional[str] = Field( + None, description='The operating system version. eg. Ubuntu Linux 20.04' + ) + pip_freeze: Optional[str] = Field(None, description='The pip freeze output') + vram_time_series: Optional[Dict[str, Any]] = Field( + None, description='Time series of VRAM usage.' + ) + + class MinimaxBaseResponse(BaseModel): status_code: int = Field( ..., @@ -1251,7 +1599,7 @@ class MinimaxFileRetrieveResponse(BaseModel): file: File -class Status5(str, Enum): +class Status6(str, Enum): Queueing = 'Queueing' Preparing = 'Preparing' Processing = 'Processing' @@ -1265,7 +1613,7 @@ class MinimaxTaskResultResponse(BaseModel): None, description='After the task status changes to Success, this field returns the file ID corresponding to the generated video.', ) - status: Status5 = Field( + status: Status6 = Field( ..., description="Task status: 'Queueing' (in queue), 'Preparing' (task is preparing), 'Processing' (generating), 'Success' (task completed successfully), or 'Fail' (task failed).", ) @@ -1326,6 +1674,22 @@ class MinimaxVideoGenerationResponse(BaseModel): ) +class Modality(str, Enum): + MODALITY_UNSPECIFIED = 'MODALITY_UNSPECIFIED' + TEXT = 'TEXT' + IMAGE = 'IMAGE' + VIDEO = 'VIDEO' + AUDIO = 'AUDIO' + DOCUMENT = 'DOCUMENT' + + +class ModalityTokenCount(BaseModel): + modality: Optional[Modality] = None + tokenCount: Optional[int] = Field( + None, description='Number of tokens for the given modality.' 
+ ) + + class Truncation(str, Enum): disabled = 'disabled' auto = 'auto' @@ -1391,13 +1755,13 @@ class MoonvalleyTextToVideoInferenceParams(BaseModel): 0, description='Index of the conditioning frame' ) cooldown_steps: Optional[int] = Field( - None, description='Number of cooldown steps (calculated based on num_frames)' + 75, description='Number of cooldown steps (calculated based on num_frames)' ) fps: Optional[int] = Field( 24, description='Frames per second of the generated video' ) guidance_scale: Optional[float] = Field( - 12.5, description='Guidance scale for generation control' + 10, description='Guidance scale for generation control' ) height: Optional[int] = Field( 1080, description='Height of the generated video in pixels' @@ -1421,7 +1785,7 @@ class MoonvalleyTextToVideoInferenceParams(BaseModel): True, description='Whether to use timestep transformation' ) warmup_steps: Optional[int] = Field( - None, description='Number of warmup steps (calculated based on num_frames)' + 0, description='Number of warmup steps (calculated based on num_frames)' ) width: Optional[int] = Field( 1920, description='Width of the generated video in pixels' @@ -1463,10 +1827,10 @@ class MoonvalleyVideoToVideoInferenceParams(BaseModel): 0, description='Index of the conditioning frame' ) cooldown_steps: Optional[int] = Field( - None, description='Number of cooldown steps (calculated based on num_frames)' + 36, description='Number of cooldown steps (calculated based on num_frames)' ) guidance_scale: Optional[float] = Field( - 12.5, description='Guidance scale for generation control' + 15, description='Guidance scale for generation control' ) negative_prompt: Optional[str] = Field(None, description='Negative prompt text') seed: Optional[int] = Field( @@ -1486,7 +1850,7 @@ class MoonvalleyVideoToVideoInferenceParams(BaseModel): True, description='Whether to use timestep transformation' ) warmup_steps: Optional[int] = Field( - None, description='Number of warmup steps (calculated based on num_frames)' + 24, description='Number of warmup steps (calculated based on num_frames)' ) @@ -1507,6 +1871,34 @@ class MoonvalleyVideoToVideoRequest(BaseModel): ) +class NodeStatus(str, Enum): + NodeStatusActive = 'NodeStatusActive' + NodeStatusDeleted = 'NodeStatusDeleted' + NodeStatusBanned = 'NodeStatusBanned' + + +class NodeVersionIdentifier(BaseModel): + node_id: str = Field(..., description='The unique identifier of the node') + version: str = Field(..., description='The version of the node') + + +class NodeVersionStatus(str, Enum): + NodeVersionStatusActive = 'NodeVersionStatusActive' + NodeVersionStatusDeleted = 'NodeVersionStatusDeleted' + NodeVersionStatusBanned = 'NodeVersionStatusBanned' + NodeVersionStatusPending = 'NodeVersionStatusPending' + NodeVersionStatusFlagged = 'NodeVersionStatusFlagged' + + +class NodeVersionUpdateRequest(BaseModel): + changelog: Optional[str] = Field( + None, description='The changelog describing the version changes.' + ) + deprecated: Optional[bool] = Field( + None, description='Whether the version is deprecated.' 
+ ) + + class Moderation(str, Enum): low = 'low' auto = 'auto' @@ -1723,38 +2115,57 @@ class Object(str, Enum): response = 'response' -class Status6(str, Enum): +class Status7(str, Enum): completed = 'completed' failed = 'failed' in_progress = 'in_progress' incomplete = 'incomplete' -class Type13(str, Enum): +class Type14(str, Enum): output_audio = 'output_audio' class OutputAudioContent(BaseModel): data: str = Field(..., description='Base64-encoded audio data') transcript: str = Field(..., description='Transcript of the audio') - type: Type13 = Field(..., description='The type of output content') + type: Type14 = Field(..., description='The type of output content') class Role4(str, Enum): assistant = 'assistant' -class Type14(str, Enum): +class Type15(str, Enum): message = 'message' -class Type15(str, Enum): +class Type16(str, Enum): output_text = 'output_text' class OutputTextContent(BaseModel): text: str = Field(..., description='The text content') - type: Type15 = Field(..., description='The type of output content') + type: Type16 = Field(..., description='The type of output content') + + +class PersonalAccessToken(BaseModel): + createdAt: Optional[datetime] = Field( + None, description='[Output Only]The date and time the token was created.' + ) + description: Optional[str] = Field( + None, + description="Optional. A more detailed description of the token's intended use.", + ) + id: Optional[UUID] = Field(None, description='Unique identifier for the GitCommit') + name: Optional[str] = Field( + None, + description='Required. The name of the token. Can be a simple description.', + ) + token: Optional[str] = Field( + None, + description='[Output Only]. The personal access token. Only returned during creation.', + ) class AspectRatio1(RootModel[float]): @@ -1961,7 +2372,7 @@ class PixverseVideoResponse(BaseModel): Resp: Optional[Resp1] = None -class Status7(int, Enum): +class Status8(int, Enum): integer_1 = 1 integer_5 = 5 integer_6 = 6 @@ -1980,7 +2391,7 @@ class Resp2(BaseModel): resolution_ratio: Optional[int] = None seed: Optional[int] = None size: Optional[int] = None - status: Optional[Status7] = Field( + status: Optional[Status8] = Field( None, description='Video generation status codes:\n* 1 - Generation successful\n* 5 - Generating\n* 6 - Deleted\n* 7 - Contents moderation failed\n* 8 - Generation failed\n', ) @@ -1994,6 +2405,17 @@ class PixverseVideoResultResponse(BaseModel): Resp: Optional[Resp2] = None +class PublisherStatus(str, Enum): + PublisherStatusActive = 'PublisherStatusActive' + PublisherStatusBanned = 'PublisherStatusBanned' + + +class PublisherUser(BaseModel): + email: Optional[str] = Field(None, description='The email address for this user.') + id: Optional[str] = Field(None, description='The unique id for this user.') + name: Optional[str] = Field(None, description='The name for this user.') + + class RgbItem(RootModel[int]): root: int = Field(..., ge=0, le=255) @@ -2020,13 +2442,13 @@ class ReasoningEffort(str, Enum): high = 'high' -class Status8(str, Enum): +class Status9(str, Enum): in_progress = 'in_progress' completed = 'completed' incomplete = 'incomplete' -class Type16(str, Enum): +class Type17(str, Enum): summary_text = 'summary_text' @@ -2035,12 +2457,12 @@ class SummaryItem(BaseModel): ..., description='A short summary of the reasoning used by the model when generating\nthe response.\n', ) - type: Type16 = Field( + type: Type17 = Field( ..., description='The type of the object. 
Always `summary_text`.\n' ) -class Type17(str, Enum): +class Type18(str, Enum): reasoning = 'reasoning' @@ -2048,16 +2470,31 @@ class ReasoningItem(BaseModel): id: str = Field( ..., description='The unique identifier of the reasoning content.\n' ) - status: Optional[Status8] = Field( + status: Optional[Status9] = Field( None, description='The status of the item. One of `in_progress`, `completed`, or\n`incomplete`. Populated when items are returned via API.\n', ) summary: List[SummaryItem] = Field(..., description='Reasoning text contents.\n') - type: Type17 = Field( + type: Type18 = Field( ..., description='The type of the object. Always `reasoning`.\n' ) +class RecraftImageColor(BaseModel): + rgb: Optional[List[int]] = None + std: Optional[List[float]] = None + weight: Optional[float] = None + + +class RecraftImageFeatures(BaseModel): + nsfw_score: Optional[float] = None + + +class RecraftImageFormat(str, Enum): + webp = 'webp' + png = 'png' + + class Controls(BaseModel): artistic_level: Optional[int] = Field( None, @@ -2111,12 +2548,143 @@ class RecraftImageGenerationResponse(BaseModel): data: List[Datum3] = Field(..., description='Array of generated image information') +class RecraftImageStyle(str, Enum): + digital_illustration = 'digital_illustration' + icon = 'icon' + realistic_image = 'realistic_image' + vector_illustration = 'vector_illustration' + + +class RecraftImageSubStyle(str, Enum): + field_2d_art_poster = '2d_art_poster' + field_3d = '3d' + field_80s = '80s' + glow = 'glow' + grain = 'grain' + hand_drawn = 'hand_drawn' + infantile_sketch = 'infantile_sketch' + kawaii = 'kawaii' + pixel_art = 'pixel_art' + psychedelic = 'psychedelic' + seamless = 'seamless' + voxel = 'voxel' + watercolor = 'watercolor' + broken_line = 'broken_line' + colored_outline = 'colored_outline' + colored_shapes = 'colored_shapes' + colored_shapes_gradient = 'colored_shapes_gradient' + doodle_fill = 'doodle_fill' + doodle_offset_fill = 'doodle_offset_fill' + offset_fill = 'offset_fill' + outline = 'outline' + outline_gradient = 'outline_gradient' + uneven_fill = 'uneven_fill' + field_70s = '70s' + cartoon = 'cartoon' + doodle_line_art = 'doodle_line_art' + engraving = 'engraving' + flat_2 = 'flat_2' + kawaii_1 = 'kawaii' + line_art = 'line_art' + linocut = 'linocut' + seamless_1 = 'seamless' + b_and_w = 'b_and_w' + enterprise = 'enterprise' + hard_flash = 'hard_flash' + hdr = 'hdr' + motion_blur = 'motion_blur' + natural_light = 'natural_light' + studio_portrait = 'studio_portrait' + line_circuit = 'line_circuit' + field_2d_art_poster_2 = '2d_art_poster_2' + engraving_color = 'engraving_color' + flat_air_art = 'flat_air_art' + hand_drawn_outline = 'hand_drawn_outline' + handmade_3d = 'handmade_3d' + stickers_drawings = 'stickers_drawings' + plastic = 'plastic' + pictogram = 'pictogram' + + +class RecraftResponseFormat(str, Enum): + url = 'url' + b64_json = 'b64_json' + + +class RecraftTextLayoutItem(BaseModel): + bbox: List[List[float]] + text: str + + +class RecraftTransformModel(str, Enum): + refm1 = 'refm1' + recraft20b = 'recraft20b' + recraftv2 = 'recraftv2' + recraftv3 = 'recraftv3' + flux1_1pro = 'flux1_1pro' + flux1dev = 'flux1dev' + imagen3 = 'imagen3' + hidream_i1_dev = 'hidream_i1_dev' + + +class RecraftUserControls(BaseModel): + artistic_level: Optional[int] = None + background_color: Optional[RecraftImageColor] = None + colors: Optional[List[RecraftImageColor]] = None + no_text: Optional[bool] = None + + +class Attention(str, Enum): + low = 'low' + medium = 'medium' + high = 'high' + + 
+class Project(str, Enum): + comfyui = 'comfyui' + comfyui_frontend = 'comfyui_frontend' + desktop = 'desktop' + + +class ReleaseNote(BaseModel): + attention: Attention = Field( + ..., description='The attention level for this release' + ) + content: str = Field( + ..., description='The content of the release note in markdown format' + ) + id: int = Field(..., description='Unique identifier for the release note') + project: Project = Field( + ..., description='The project this release note belongs to' + ) + published_at: datetime = Field( + ..., description='When the release note was published' + ) + version: str = Field(..., description='The version of the release') + + class RenderingSpeed(str, Enum): BALANCED = 'BALANCED' TURBO = 'TURBO' QUALITY = 'QUALITY' +class Type19(str, Enum): + response_completed = 'response.completed' + + +class Type20(str, Enum): + response_content_part_added = 'response.content_part.added' + + +class Type21(str, Enum): + response_content_part_done = 'response.content_part.done' + + +class Type22(str, Enum): + response_created = 'response.created' + + class ResponseErrorCode(str, Enum): server_error = 'server_error' rate_limit_exceeded = 'rate_limit_exceeded' @@ -2138,12 +2706,27 @@ class ResponseErrorCode(str, Enum): image_file_not_found = 'image_file_not_found' -class Type18(str, Enum): +class Type23(str, Enum): + error = 'error' + + +class ResponseErrorEvent(BaseModel): + code: str = Field(..., description='The error code.\n') + message: str = Field(..., description='The error message.\n') + param: str = Field(..., description='The error parameter.\n') + type: Type23 = Field(..., description='The type of the event. Always `error`.\n') + + +class Type24(str, Enum): + response_failed = 'response.failed' + + +class Type25(str, Enum): json_object = 'json_object' class ResponseFormatJsonObject(BaseModel): - type: Type18 = Field( + type: Type25 = Field( ..., description='The type of response format being defined. Always `json_object`.', ) @@ -2156,16 +2739,32 @@ class ResponseFormatJsonSchemaSchema(BaseModel): ) -class Type19(str, Enum): +class Type26(str, Enum): text = 'text' class ResponseFormatText(BaseModel): - type: Type19 = Field( + type: Type26 = Field( ..., description='The type of response format being defined. Always `text`.' 
) +class Type27(str, Enum): + response_in_progress = 'response.in_progress' + + +class Type28(str, Enum): + response_incomplete = 'response.incomplete' + + +class Type29(str, Enum): + response_output_item_added = 'response.output_item.added' + + +class Type30(str, Enum): + response_output_item_done = 'response.output_item.done' + + class Truncation1(str, Enum): auto = 'auto' disabled = 'disabled' @@ -2200,10 +2799,6 @@ class Rodin3DCheckStatusRequest(BaseModel): ) -class Rodin3DCheckStatusResponse(BaseModel): - pass - - class Rodin3DDownloadRequest(BaseModel): task_uuid: str = Field(..., description='Task UUID') @@ -2235,6 +2830,13 @@ class RodinResourceItem(BaseModel): url: Optional[str] = Field(None, description='Download url') +class RodinStatusOptions(str, Enum): + Done = 'Done' + Failed = 'Failed' + Generating = 'Generating' + Waiting = 'Waiting' + + class RodinTierType(str, Enum): Regular = 'Regular' Sketch = 'Sketch' @@ -2325,6 +2927,7 @@ class RunwayTextToImageAspectRatioEnum(str, Enum): field_1808_768 = '1808:768' field_2112_912 = '2112:912' + class Model4(str, Enum): gen4_image = 'gen4_image' @@ -2350,6 +2953,38 @@ class RunwayTextToImageResponse(BaseModel): id: Optional[str] = Field(None, description='Task ID') +class Name(str, Enum): + content_moderation = 'content_moderation' + + +class StabilityContentModerationResponse(BaseModel): + errors: List[str] = Field( + ..., + description='One or more error messages indicating what went wrong.', + examples=[['some-field: is required']], + min_length=1, + ) + id: str = Field( + ..., + description='A unique identifier associated with this error. Please include this in any [support tickets](https://kb.stability.ai/knowledge-base/kb-tickets/new) you file, as it will greatly assist us in diagnosing the root cause of the problem.', + examples=['a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4'], + min_length=1, + ) + name: Name = Field( + ..., + description='Our content moderation system has flagged some part of your request and subsequently denied it. You were not charged for this request. While this may at times be frustrating, it is necessary to maintain the integrity of our platform and ensure a safe experience for all users. 
If you would like to provide feedback, please use the [Support Form](https://kb.stability.ai/knowledge-base/kb-tickets/new).', + ) + + +class StabilityCreativity(RootModel[float]): + root: float = Field( + ..., + description='Controls the likelihood of creating additional details not heavily conditioned by the init image.', + ge=0.2, + le=0.5, + ) + + class StabilityError(BaseModel): errors: List[str] = Field( ..., @@ -2371,7 +3006,17 @@ class StabilityError(BaseModel): ) -class Status9(str, Enum): +class StabilityGenerationID(RootModel[str]): + root: str = Field( + ..., + description='The `id` of a generation, typically used for async generations, that can be used to check the status of the generation or retrieve the result.', + examples=['a6dc6c6e20acda010fe14d71f180658f2896ed9b4ec25aa99a6ff06c796987c4'], + max_length=64, + min_length=64, + ) + + +class Status10(str, Enum): in_progress = 'in-progress' @@ -2379,10 +3024,860 @@ class StabilityGetResultResponse202(BaseModel): id: Optional[str] = Field( None, description='The ID of the generation result.', examples=[1234567890] ) - status: Optional[Status9] = None + status: Optional[Status10] = None -class Type20(str, Enum): +class AspectRatio3(str, Enum): + field_21_9 = '21:9' + field_16_9 = '16:9' + field_3_2 = '3:2' + field_5_4 = '5:4' + field_1_1 = '1:1' + field_4_5 = '4:5' + field_2_3 = '2:3' + field_9_16 = '9:16' + field_9_21 = '9:21' + + +class Mode(str, Enum): + text_to_image = 'text-to-image' + image_to_image = 'image-to-image' + + +class Model5(str, Enum): + sd3_5_large = 'sd3.5-large' + sd3_5_large_turbo = 'sd3.5-large-turbo' + sd3_5_medium = 'sd3.5-medium' + + +class OutputFormat3(str, Enum): + png = 'png' + jpeg = 'jpeg' + + +class StylePreset(str, Enum): + enhance = 'enhance' + anime = 'anime' + photographic = 'photographic' + digital_art = 'digital-art' + comic_book = 'comic-book' + fantasy_art = 'fantasy-art' + line_art = 'line-art' + analog_film = 'analog-film' + neon_punk = 'neon-punk' + isometric = 'isometric' + low_poly = 'low-poly' + origami = 'origami' + modeling_compound = 'modeling-compound' + cinematic = 'cinematic' + field_3d_model = '3d-model' + pixel_art = 'pixel-art' + tile_texture = 'tile-texture' + + +class StabilityImageGenerationSD3Request(BaseModel): + aspect_ratio: Optional[AspectRatio3] = Field( + '1:1', + description='Controls the aspect ratio of the generated image. Defaults to 1:1.\n\n> **Important:** This parameter is only valid for **text-to-image** requests.', + ) + cfg_scale: Optional[float] = Field( + None, + description='How strictly the diffusion process adheres to the prompt text (higher values keep your image closer to your prompt). The _Large_ and _Medium_ models use a default of `4`. 
The _Turbo_ model uses a default of `1`.', + ge=1.0, + le=10.0, + ) + image: Optional[StrictBytes] = Field( + None, + description='The image to use as the starting point for the generation.\n\nSupported formats:\n\n\n\n - jpeg\n - png\n - webp\n\nSupported dimensions:\n\n\n\n - Every side must be at least 64 pixels\n\n> **Important:** This parameter is only valid for **image-to-image** requests.', + ) + mode: Optional[Mode] = Field( + 'text-to-image', + description='Controls whether this is a text-to-image or image-to-image generation, which affects which parameters are required:\n- **text-to-image** requires only the `prompt` parameter\n- **image-to-image** requires the `prompt`, `image`, and `strength` parameters', + title='GenerationMode', + ) + model: Optional[Model5] = Field( + 'sd3.5-large', + description='The model to use for generation.\n\n- `sd3.5-large` requires 6.5 credits per generation\n- `sd3.5-large-turbo` requires 4 credits per generation\n- `sd3.5-medium` requires 3.5 credits per generation\n- As of the April 17, 2025, `sd3-large`, `sd3-large-turbo` and `sd3-medium`\n\n\n\n are re-routed to their `sd3.5-[model version]` equivalent, at the same price.', + ) + negative_prompt: Optional[str] = Field( + None, + description='Keywords of what you **do not** wish to see in the output image.\nThis is an advanced feature.', + max_length=10000, + ) + output_format: Optional[OutputFormat3] = Field( + 'png', description='Dictates the `content-type` of the generated image.' + ) + prompt: str = Field( + ..., + description='What you wish to see in the output image. A strong, descriptive prompt that clearly defines\nelements, colors, and subjects will lead to better results.', + max_length=10000, + min_length=1, + ) + seed: Optional[float] = Field( + 0, + description="A specific value that is used to guide the 'randomness' of the generation. (Omit this parameter or pass `0` to use a random seed.)", + ge=0.0, + le=4294967294.0, + ) + strength: Optional[float] = Field( + None, + description='Sometimes referred to as _denoising_, this parameter controls how much influence the\n`image` parameter has on the generated image. A value of 0 would yield an image that\nis identical to the input. A value of 1 would be as if you passed in no image at all.\n\n> **Important:** This parameter is only valid for **image-to-image** requests.', + ge=0.0, + le=1.0, + ) + style_preset: Optional[StylePreset] = Field( + None, description='Guides the image model towards a particular style.' 
+ ) + + +class FinishReason(str, Enum): + SUCCESS = 'SUCCESS' + CONTENT_FILTERED = 'CONTENT_FILTERED' + + +class StabilityImageGenrationSD3Response200(BaseModel): + finish_reason: FinishReason = Field( + ..., + description='The reason the generation finished.\n\n- `SUCCESS` = successful generation.\n- `CONTENT_FILTERED` = successful generation, however the output violated our content moderation\npolicy and has been blurred as a result.', + examples=['SUCCESS'], + ) + image: str = Field( + ..., + description='The generated image, encoded to base64.', + examples=['AAAAIGZ0eXBpc29tAAACAGlzb21pc28yYXZjMW1...'], + ) + seed: Optional[float] = Field( + 0, + description='The seed used as random noise for this generation.', + examples=[343940597], + ge=0.0, + le=4294967294.0, + ) + + +class StabilityImageGenrationSD3Response400(BaseModel): + errors: List[str] = Field( + ..., + description='One or more error messages indicating what went wrong.', + examples=[['some-field: is required']], + min_length=1, + ) + id: str = Field( + ..., + description='A unique identifier associated with this error. Please include this in any [support tickets](https://kb.stability.ai/knowledge-base/kb-tickets/new)\nyou file, as it will greatly assist us in diagnosing the root cause of the problem.', + examples=['a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4'], + min_length=1, + ) + name: str = Field( + ..., + description='Short-hand name for an error, useful for discriminating between errors with the same status code.', + examples=['bad_request'], + min_length=1, + ) + + +class StabilityImageGenrationSD3Response413(BaseModel): + errors: List[str] = Field( + ..., + description='One or more error messages indicating what went wrong.', + examples=[['some-field: is required']], + min_length=1, + ) + id: str = Field( + ..., + description='A unique identifier associated with this error. Please include this in any [support tickets](https://kb.stability.ai/knowledge-base/kb-tickets/new)\nyou file, as it will greatly assist us in diagnosing the root cause of the problem.', + examples=['a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4'], + min_length=1, + ) + name: str = Field( + ..., + description='Short-hand name for an error, useful for discriminating between errors with the same status code.', + examples=['bad_request'], + min_length=1, + ) + + +class StabilityImageGenrationSD3Response422(BaseModel): + errors: List[str] = Field( + ..., + description='One or more error messages indicating what went wrong.', + examples=[['some-field: is required']], + min_length=1, + ) + id: str = Field( + ..., + description='A unique identifier associated with this error. Please include this in any [support tickets](https://kb.stability.ai/knowledge-base/kb-tickets/new)\nyou file, as it will greatly assist us in diagnosing the root cause of the problem.', + examples=['a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4'], + min_length=1, + ) + name: str = Field( + ..., + description='Short-hand name for an error, useful for discriminating between errors with the same status code.', + examples=['bad_request'], + min_length=1, + ) + + +class StabilityImageGenrationSD3Response429(BaseModel): + errors: List[str] = Field( + ..., + description='One or more error messages indicating what went wrong.', + examples=[['some-field: is required']], + min_length=1, + ) + id: str = Field( + ..., + description='A unique identifier associated with this error. 
Please include this in any [support tickets](https://kb.stability.ai/knowledge-base/kb-tickets/new)\nyou file, as it will greatly assist us in diagnosing the root cause of the problem.', + examples=['a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4'], + min_length=1, + ) + name: str = Field( + ..., + description='Short-hand name for an error, useful for discriminating between errors with the same status code.', + examples=['bad_request'], + min_length=1, + ) + + +class StabilityImageGenrationSD3Response500(BaseModel): + errors: List[str] = Field( + ..., + description='One or more error messages indicating what went wrong.', + examples=[['some-field: is required']], + min_length=1, + ) + id: str = Field( + ..., + description='A unique identifier associated with this error. Please include this in any [support tickets](https://kb.stability.ai/knowledge-base/kb-tickets/new)\nyou file, as it will greatly assist us in diagnosing the root cause of the problem.', + examples=['a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4'], + min_length=1, + ) + name: str = Field( + ..., + description='Short-hand name for an error, useful for discriminating between errors with the same status code.', + examples=['bad_request'], + min_length=1, + ) + + +class OutputFormat4(str, Enum): + jpeg = 'jpeg' + png = 'png' + webp = 'webp' + + +class StabilityImageGenrationUpscaleConservativeRequest(BaseModel): + creativity: Optional[StabilityCreativity] = Field( + default_factory=lambda: StabilityCreativity.model_validate(0.35) + ) + image: StrictBytes = Field( + ..., + description='The image you wish to upscale.\n\nSupported Formats:\n- jpeg\n- png\n- webp\n\nValidation Rules:\n- Every side must be at least 64 pixels\n- Total pixel count must be between 4,096 and 9,437,184 pixels\n- The aspect ratio must be between 1:2.5 and 2.5:1', + examples=['./some/image.png'], + ) + negative_prompt: Optional[str] = Field( + None, + description='A blurb of text describing what you **do not** wish to see in the output image.\nThis is an advanced feature.', + max_length=10000, + ) + output_format: Optional[OutputFormat4] = Field( + 'png', description='Dictates the `content-type` of the generated image.' + ) + prompt: str = Field( + ..., + description="What you wish to see in the output image. A strong, descriptive prompt that clearly defines\nelements, colors, and subjects will lead to better results.\n\nTo control the weight of a given word use the format `(word:weight)`,\nwhere `word` is the word you'd like to control the weight of and `weight`\nis a value between 0 and 1. For example: `The sky was a crisp (blue:0.3) and (green:0.8)`\nwould convey a sky that was blue and green, but more green than blue.", + max_length=10000, + min_length=1, + ) + seed: Optional[float] = Field( + 0, + description="A specific value that is used to guide the 'randomness' of the generation. 
(Omit this parameter or pass `0` to use a random seed.)", + ge=0.0, + le=4294967294.0, + ) + + +class StabilityImageGenrationUpscaleConservativeResponse200(BaseModel): + finish_reason: FinishReason = Field( + ..., + description='The reason the generation finished.\n\n- `SUCCESS` = successful generation.\n- `CONTENT_FILTERED` = successful generation, however the output violated our content moderation\npolicy and has been blurred as a result.', + examples=['SUCCESS'], + ) + image: str = Field( + ..., + description='The generated image, encoded to base64.', + examples=['AAAAIGZ0eXBpc29tAAACAGlzb21pc28yYXZjMW1...'], + ) + seed: Optional[float] = Field( + 0, + description='The seed used as random noise for this generation.', + examples=[343940597], + ge=0.0, + le=4294967294.0, + ) + + +class StabilityImageGenrationUpscaleConservativeResponse400(BaseModel): + errors: List[str] = Field( + ..., + description='One or more error messages indicating what went wrong.', + examples=[['some-field: is required']], + min_length=1, + ) + id: str = Field( + ..., + description='A unique identifier associated with this error. Please include this in any [support tickets](https://kb.stability.ai/knowledge-base/kb-tickets/new)\nyou file, as it will greatly assist us in diagnosing the root cause of the problem.', + examples=['a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4'], + min_length=1, + ) + name: str = Field( + ..., + description='Short-hand name for an error, useful for discriminating between errors with the same status code.', + examples=['bad_request'], + min_length=1, + ) + + +class StabilityImageGenrationUpscaleConservativeResponse413(BaseModel): + errors: List[str] = Field( + ..., + description='One or more error messages indicating what went wrong.', + examples=[['some-field: is required']], + min_length=1, + ) + id: str = Field( + ..., + description='A unique identifier associated with this error. Please include this in any [support tickets](https://kb.stability.ai/knowledge-base/kb-tickets/new)\nyou file, as it will greatly assist us in diagnosing the root cause of the problem.', + examples=['a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4'], + min_length=1, + ) + name: str = Field( + ..., + description='Short-hand name for an error, useful for discriminating between errors with the same status code.', + examples=['bad_request'], + min_length=1, + ) + + +class StabilityImageGenrationUpscaleConservativeResponse422(BaseModel): + errors: List[str] = Field( + ..., + description='One or more error messages indicating what went wrong.', + examples=[['some-field: is required']], + min_length=1, + ) + id: str = Field( + ..., + description='A unique identifier associated with this error. Please include this in any [support tickets](https://kb.stability.ai/knowledge-base/kb-tickets/new)\nyou file, as it will greatly assist us in diagnosing the root cause of the problem.', + examples=['a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4'], + min_length=1, + ) + name: str = Field( + ..., + description='Short-hand name for an error, useful for discriminating between errors with the same status code.', + examples=['bad_request'], + min_length=1, + ) + + +class StabilityImageGenrationUpscaleConservativeResponse429(BaseModel): + errors: List[str] = Field( + ..., + description='One or more error messages indicating what went wrong.', + examples=[['some-field: is required']], + min_length=1, + ) + id: str = Field( + ..., + description='A unique identifier associated with this error. 
Please include this in any [support tickets](https://kb.stability.ai/knowledge-base/kb-tickets/new)\nyou file, as it will greatly assist us in diagnosing the root cause of the problem.', + examples=['a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4'], + min_length=1, + ) + name: str = Field( + ..., + description='Short-hand name for an error, useful for discriminating between errors with the same status code.', + examples=['bad_request'], + min_length=1, + ) + + +class StabilityImageGenrationUpscaleConservativeResponse500(BaseModel): + errors: List[str] = Field( + ..., + description='One or more error messages indicating what went wrong.', + examples=[['some-field: is required']], + min_length=1, + ) + id: str = Field( + ..., + description='A unique identifier associated with this error. Please include this in any [support tickets](https://kb.stability.ai/knowledge-base/kb-tickets/new)\nyou file, as it will greatly assist us in diagnosing the root cause of the problem.', + examples=['a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4'], + min_length=1, + ) + name: str = Field( + ..., + description='Short-hand name for an error, useful for discriminating between errors with the same status code.', + examples=['bad_request'], + min_length=1, + ) + + +class StabilityImageGenrationUpscaleCreativeRequest(BaseModel): + creativity: Optional[float] = Field( + 0.3, + description='Indicates how creative the model should be when upscaling an image.\nHigher values will result in more details being added to the image during upscaling.', + ge=0.1, + le=0.5, + ) + image: StrictBytes = Field( + ..., + description='The image you wish to upscale.\n\nSupported Formats:\n- jpeg\n- png\n- webp\n\nValidation Rules:\n- Every side must be at least 64 pixels\n- Total pixel count must be between 4,096 and 1,048,576 pixels', + examples=['./some/image.png'], + ) + negative_prompt: Optional[str] = Field( + None, + description='A blurb of text describing what you **do not** wish to see in the output image.\nThis is an advanced feature.', + max_length=10000, + ) + output_format: Optional[OutputFormat4] = Field( + 'png', description='Dictates the `content-type` of the generated image.' + ) + prompt: str = Field( + ..., + description="What you wish to see in the output image. A strong, descriptive prompt that clearly defines\nelements, colors, and subjects will lead to better results.\n\nTo control the weight of a given word use the format `(word:weight)`,\nwhere `word` is the word you'd like to control the weight of and `weight`\nis a value between 0 and 1. For example: `The sky was a crisp (blue:0.3) and (green:0.8)`\nwould convey a sky that was blue and green, but more green than blue.", + max_length=10000, + min_length=1, + ) + seed: Optional[float] = Field( + 0, + description="A specific value that is used to guide the 'randomness' of the generation. (Omit this parameter or pass `0` to use a random seed.)", + ge=0.0, + le=4294967294.0, + ) + style_preset: Optional[StylePreset] = Field( + None, description='Guides the image model towards a particular style.' + ) + + +class StabilityImageGenrationUpscaleCreativeResponse200(BaseModel): + id: StabilityGenerationID + + +class StabilityImageGenrationUpscaleCreativeResponse400(BaseModel): + errors: List[str] = Field( + ..., + description='One or more error messages indicating what went wrong.', + examples=[['some-field: is required']], + min_length=1, + ) + id: str = Field( + ..., + description='A unique identifier associated with this error. 
Please include this in any [support tickets](https://kb.stability.ai/knowledge-base/kb-tickets/new)\nyou file, as it will greatly assist us in diagnosing the root cause of the problem.', + examples=['a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4'], + min_length=1, + ) + name: str = Field( + ..., + description='Short-hand name for an error, useful for discriminating between errors with the same status code.', + examples=['bad_request'], + min_length=1, + ) + + +class StabilityImageGenrationUpscaleCreativeResponse413(BaseModel): + errors: List[str] = Field( + ..., + description='One or more error messages indicating what went wrong.', + examples=[['some-field: is required']], + min_length=1, + ) + id: str = Field( + ..., + description='A unique identifier associated with this error. Please include this in any [support tickets](https://kb.stability.ai/knowledge-base/kb-tickets/new)\nyou file, as it will greatly assist us in diagnosing the root cause of the problem.', + examples=['a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4'], + min_length=1, + ) + name: str = Field( + ..., + description='Short-hand name for an error, useful for discriminating between errors with the same status code.', + examples=['bad_request'], + min_length=1, + ) + + +class StabilityImageGenrationUpscaleCreativeResponse422(BaseModel): + errors: List[str] = Field( + ..., + description='One or more error messages indicating what went wrong.', + examples=[['some-field: is required']], + min_length=1, + ) + id: str = Field( + ..., + description='A unique identifier associated with this error. Please include this in any [support tickets](https://kb.stability.ai/knowledge-base/kb-tickets/new)\nyou file, as it will greatly assist us in diagnosing the root cause of the problem.', + examples=['a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4'], + min_length=1, + ) + name: str = Field( + ..., + description='Short-hand name for an error, useful for discriminating between errors with the same status code.', + examples=['bad_request'], + min_length=1, + ) + + +class StabilityImageGenrationUpscaleCreativeResponse429(BaseModel): + errors: List[str] = Field( + ..., + description='One or more error messages indicating what went wrong.', + examples=[['some-field: is required']], + min_length=1, + ) + id: str = Field( + ..., + description='A unique identifier associated with this error. Please include this in any [support tickets](https://kb.stability.ai/knowledge-base/kb-tickets/new)\nyou file, as it will greatly assist us in diagnosing the root cause of the problem.', + examples=['a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4'], + min_length=1, + ) + name: str = Field( + ..., + description='Short-hand name for an error, useful for discriminating between errors with the same status code.', + examples=['bad_request'], + min_length=1, + ) + + +class StabilityImageGenrationUpscaleCreativeResponse500(BaseModel): + errors: List[str] = Field( + ..., + description='One or more error messages indicating what went wrong.', + examples=[['some-field: is required']], + min_length=1, + ) + id: str = Field( + ..., + description='A unique identifier associated with this error. 
Please include this in any [support tickets](https://kb.stability.ai/knowledge-base/kb-tickets/new)\nyou file, as it will greatly assist us in diagnosing the root cause of the problem.', + examples=['a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4'], + min_length=1, + ) + name: str = Field( + ..., + description='Short-hand name for an error, useful for discriminating between errors with the same status code.', + examples=['bad_request'], + min_length=1, + ) + + +class StabilityImageGenrationUpscaleFastRequest(BaseModel): + image: StrictBytes = Field( + ..., + description='The image you wish to upscale.\n\nSupported Formats:\n- jpeg\n- png\n- webp\n\nValidation Rules:\n- Width must be between 32 and 1,536 pixels\n- Height must be between 32 and 1,536 pixels\n- Total pixel count must be between 1,024 and 1,048,576 pixels', + examples=['./some/image.png'], + ) + output_format: Optional[OutputFormat4] = Field( + 'png', description='Dictates the `content-type` of the generated image.' + ) + + +class StabilityImageGenrationUpscaleFastResponse200(BaseModel): + finish_reason: FinishReason = Field( + ..., + description='The reason the generation finished.\n\n- `SUCCESS` = successful generation.\n- `CONTENT_FILTERED` = successful generation, however the output violated our content moderation\npolicy and has been blurred as a result.', + examples=['SUCCESS'], + ) + image: str = Field( + ..., + description='The generated image, encoded to base64.', + examples=['AAAAIGZ0eXBpc29tAAACAGlzb21pc28yYXZjMW1...'], + ) + seed: Optional[float] = Field( + 0, + description='The seed used as random noise for this generation.', + examples=[343940597], + ge=0.0, + le=4294967294.0, + ) + + +class StabilityImageGenrationUpscaleFastResponse400(BaseModel): + errors: List[str] = Field( + ..., + description='One or more error messages indicating what went wrong.', + examples=[['some-field: is required']], + min_length=1, + ) + id: str = Field( + ..., + description='A unique identifier associated with this error. Please include this in any [support tickets](https://kb.stability.ai/knowledge-base/kb-tickets/new)\nyou file, as it will greatly assist us in diagnosing the root cause of the problem.', + examples=['a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4'], + min_length=1, + ) + name: str = Field( + ..., + description='Short-hand name for an error, useful for discriminating between errors with the same status code.', + examples=['bad_request'], + min_length=1, + ) + + +class StabilityImageGenrationUpscaleFastResponse413(BaseModel): + errors: List[str] = Field( + ..., + description='One or more error messages indicating what went wrong.', + examples=[['some-field: is required']], + min_length=1, + ) + id: str = Field( + ..., + description='A unique identifier associated with this error. Please include this in any [support tickets](https://kb.stability.ai/knowledge-base/kb-tickets/new)\nyou file, as it will greatly assist us in diagnosing the root cause of the problem.', + examples=['a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4'], + min_length=1, + ) + name: str = Field( + ..., + description='Short-hand name for an error, useful for discriminating between errors with the same status code.', + examples=['bad_request'], + min_length=1, + ) + + +class StabilityImageGenrationUpscaleFastResponse422(BaseModel): + errors: List[str] = Field( + ..., + description='One or more error messages indicating what went wrong.', + examples=[['some-field: is required']], + min_length=1, + ) + id: str = Field( + ..., + description='A unique identifier associated with this error. 
Please include this in any [support tickets](https://kb.stability.ai/knowledge-base/kb-tickets/new)\nyou file, as it will greatly assist us in diagnosing the root cause of the problem.', + examples=['a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4'], + min_length=1, + ) + name: str = Field( + ..., + description='Short-hand name for an error, useful for discriminating between errors with the same status code.', + examples=['bad_request'], + min_length=1, + ) + + +class StabilityImageGenrationUpscaleFastResponse429(BaseModel): + errors: List[str] = Field( + ..., + description='One or more error messages indicating what went wrong.', + examples=[['some-field: is required']], + min_length=1, + ) + id: str = Field( + ..., + description='A unique identifier associated with this error. Please include this in any [support tickets](https://kb.stability.ai/knowledge-base/kb-tickets/new)\nyou file, as it will greatly assist us in diagnosing the root cause of the problem.', + examples=['a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4'], + min_length=1, + ) + name: str = Field( + ..., + description='Short-hand name for an error, useful for discriminating between errors with the same status code.', + examples=['bad_request'], + min_length=1, + ) + + +class StabilityImageGenrationUpscaleFastResponse500(BaseModel): + errors: List[str] = Field( + ..., + description='One or more error messages indicating what went wrong.', + examples=[['some-field: is required']], + min_length=1, + ) + id: str = Field( + ..., + description='A unique identifier associated with this error. Please include this in any [support tickets](https://kb.stability.ai/knowledge-base/kb-tickets/new)\nyou file, as it will greatly assist us in diagnosing the root cause of the problem.', + examples=['a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4'], + min_length=1, + ) + name: str = Field( + ..., + description='Short-hand name for an error, useful for discriminating between errors with the same status code.', + examples=['bad_request'], + min_length=1, + ) + + +class StabilityStabilityClientID(RootModel[str]): + root: str = Field( + ..., + description='The name of your application, used to help us communicate app-specific debugging or moderation issues to you.', + examples=['my-awesome-app'], + max_length=256, + ) + + +class StabilityStabilityClientUserID(RootModel[str]): + root: str = Field( + ..., + description='A unique identifier for your end user. Used to help us communicate user-specific debugging or moderation issues to you. 
Feel free to obfuscate this value to protect user privacy.', + examples=['DiscordUser#9999'], + max_length=256, + ) + + +class StabilityStabilityClientVersion(RootModel[str]): + root: str = Field( + ..., + description='The version of your application, used to help us communicate version-specific debugging or moderation issues to you.', + examples=['1.2.1'], + max_length=256, + ) + + +class StorageFile(BaseModel): + file_path: Optional[str] = Field(None, description='Path to the file in storage') + id: Optional[UUID] = Field( + None, description='Unique identifier for the storage file' + ) + public_url: Optional[str] = Field(None, description='Public URL') + + +class StripeAddress(BaseModel): + city: Optional[str] = None + country: Optional[str] = None + line1: Optional[str] = None + line2: Optional[str] = None + postal_code: Optional[str] = None + state: Optional[str] = None + + +class StripeAmountDetails(BaseModel): + tip: Optional[Dict[str, Any]] = None + + +class StripeBillingDetails(BaseModel): + address: Optional[StripeAddress] = None + email: Optional[str] = None + name: Optional[str] = None + phone: Optional[str] = None + tax_id: Optional[Any] = None + + +class Checks(BaseModel): + address_line1_check: Optional[Any] = None + address_postal_code_check: Optional[Any] = None + cvc_check: Optional[str] = None + + +class ExtendedAuthorization(BaseModel): + status: Optional[str] = None + + +class IncrementalAuthorization(BaseModel): + status: Optional[str] = None + + +class Multicapture(BaseModel): + status: Optional[str] = None + + +class NetworkToken(BaseModel): + used: Optional[bool] = None + + +class Overcapture(BaseModel): + maximum_amount_capturable: Optional[int] = None + status: Optional[str] = None + + +class StripeCardDetails(BaseModel): + amount_authorized: Optional[int] = None + authorization_code: Optional[Any] = None + brand: Optional[str] = None + checks: Optional[Checks] = None + country: Optional[str] = None + exp_month: Optional[int] = None + exp_year: Optional[int] = None + extended_authorization: Optional[ExtendedAuthorization] = None + fingerprint: Optional[str] = None + funding: Optional[str] = None + incremental_authorization: Optional[IncrementalAuthorization] = None + installments: Optional[Any] = None + last4: Optional[str] = None + mandate: Optional[Any] = None + multicapture: Optional[Multicapture] = None + network: Optional[str] = None + network_token: Optional[NetworkToken] = None + network_transaction_id: Optional[str] = None + overcapture: Optional[Overcapture] = None + regulated_status: Optional[str] = None + three_d_secure: Optional[Any] = None + wallet: Optional[Any] = None + + +class Object1(str, Enum): + charge = 'charge' + + +class Object2(str, Enum): + event = 'event' + + +class Type31(str, Enum): + payment_intent_succeeded = 'payment_intent.succeeded' + + +class StripeOutcome(BaseModel): + advice_code: Optional[Any] = None + network_advice_code: Optional[Any] = None + network_decline_code: Optional[Any] = None + network_status: Optional[str] = None + reason: Optional[Any] = None + risk_level: Optional[str] = None + risk_score: Optional[int] = None + seller_message: Optional[str] = None + type: Optional[str] = None + + +class Object3(str, Enum): + payment_intent = 'payment_intent' + + +class StripePaymentMethodDetails(BaseModel): + card: Optional[StripeCardDetails] = None + type: Optional[str] = None + + +class Card(BaseModel): + installments: Optional[Any] = None + mandate_options: Optional[Any] = None + network: Optional[Any] = None + 
request_three_d_secure: Optional[str] = None + + +class StripePaymentMethodOptions(BaseModel): + card: Optional[Card] = None + + +class StripeRefundList(BaseModel): + data: Optional[List[Dict[str, Any]]] = None + has_more: Optional[bool] = None + object: Optional[str] = None + total_count: Optional[int] = None + url: Optional[str] = None + + +class StripeRequestInfo(BaseModel): + id: Optional[str] = None + idempotency_key: Optional[str] = None + + +class StripeShipping(BaseModel): + address: Optional[StripeAddress] = None + carrier: Optional[str] = None + name: Optional[str] = None + phone: Optional[str] = None + tracking_number: Optional[str] = None + + +class Type32(str, Enum): json_schema = 'json_schema' @@ -2400,19 +3895,19 @@ class TextResponseFormatJsonSchema(BaseModel): False, description='Whether to enable strict schema adherence when generating the output.\nIf set to true, the model will always follow the exact schema defined\nin the `schema` field. Only a subset of JSON Schema is supported when\n`strict` is `true`. To learn more, read the [Structured Outputs\nguide](/docs/guides/structured-outputs).\n', ) - type: Type20 = Field( + type: Type32 = Field( ..., description='The type of response format being defined. Always `json_schema`.', ) -class Type21(str, Enum): +class Type33(str, Enum): function = 'function' class ToolChoiceFunction(BaseModel): name: str = Field(..., description='The name of the function to call.') - type: Type21 = Field( + type: Type33 = Field( ..., description='For function calling, the type is always `function`.' ) @@ -2423,7 +3918,7 @@ class ToolChoiceOptions(str, Enum): required = 'required' -class Type22(str, Enum): +class Type34(str, Enum): file_search = 'file_search' web_search_preview = 'web_search_preview' computer_use_preview = 'computer_use_preview' @@ -2431,7 +3926,7 @@ class Type22(str, Enum): class ToolChoiceTypes(BaseModel): - type: Type22 = Field( + type: Type34 = Field( ..., description='The type of hosted tool the model should to use. Learn more about\n[built-in tools](/docs/guides/tools).\n\nAllowed values are:\n- `file_search`\n- `web_search_preview`\n- `computer_use_preview`\n', ) @@ -2499,9 +3994,9 @@ class TripoModelStyle(str, Enum): class TripoModelVersion(str, Enum): - V2_5 = 'v2.5-20250123' - V2_0 = 'v2.0-20240919' - V1_4 = 'v1.4-20240625' + v2_5_20250123 = 'v2.5-20250123' + v2_0_20240919 = 'v2.0-20240919' + v1_4_20240625 = 'v1.4-20240625' class TripoMultiviewMode(str, Enum): @@ -2547,13 +4042,13 @@ class Code1(int, Enum): integer_0 = 0 -class Data8(BaseModel): +class Data9(BaseModel): task_id: str = Field(..., description='used for getTask') class TripoSuccessTask(BaseModel): code: Code1 - data: Data8 + data: Data9 class Topology(str, Enum): @@ -2570,7 +4065,7 @@ class Output(BaseModel): topology: Optional[Topology] = None -class Status10(str, Enum): +class Status11(str, Enum): queued = 'queued' running = 'running' success = 'success' @@ -2586,7 +4081,7 @@ class TripoTask(BaseModel): input: Dict[str, Any] output: Output progress: int = Field(..., ge=0, le=100) - status: Status10 + status: Status11 task_id: str type: str @@ -2650,6 +4145,18 @@ class TripoTypeTextureModel(str, Enum): texture_model = 'texture_model' +class User(BaseModel): + email: Optional[str] = Field(None, description='The email address for this user.') + id: Optional[str] = Field(None, description='The unique id for this user.') + isAdmin: Optional[bool] = Field( + None, description='Indicates if the user has admin privileges.' 
+ ) + isApproved: Optional[bool] = Field( + None, description='Indicates if the user is approved.' + ) + name: Optional[str] = Field(None, description='The name for this user.') + + class Veo2GenVidPollRequest(BaseModel): operationName: str = Field( ..., @@ -2660,7 +4167,7 @@ class Veo2GenVidPollRequest(BaseModel): ) -class Error(BaseModel): +class Error1(BaseModel): code: Optional[int] = Field(None, description='Error code') message: Optional[str] = Field(None, description='Error message') @@ -2692,7 +4199,7 @@ class Response(BaseModel): class Veo2GenVidPollResponse(BaseModel): done: Optional[bool] = None - error: Optional[Error] = Field( + error: Optional[Error1] = Field( None, description='Error details if operation failed' ) name: Optional[str] = None @@ -2753,13 +4260,102 @@ class Veo2GenVidResponse(BaseModel): ) +class VeoGenVidPollRequest(BaseModel): + operationName: str = Field( + ..., + description='Full operation name (from predict response)', + examples=[ + 'projects/PROJECT_ID/locations/us-central1/publishers/google/models/MODEL_ID/operations/OPERATION_ID' + ], + ) + + +class Response1(BaseModel): + field_type: Optional[str] = Field( + None, + alias='@type', + examples=[ + 'type.googleapis.com/cloud.ai.large_models.vision.GenerateVideoResponse' + ], + ) + raiMediaFilteredCount: Optional[int] = Field( + None, description='Count of media filtered by responsible AI policies' + ) + raiMediaFilteredReasons: Optional[List[str]] = Field( + None, description='Reasons why media was filtered by responsible AI policies' + ) + videos: Optional[List[Video]] = None + + +class VeoGenVidPollResponse(BaseModel): + done: Optional[bool] = None + error: Optional[Error1] = Field( + None, description='Error details if operation failed' + ) + name: Optional[str] = None + response: Optional[Response1] = Field( + None, description='The actual prediction response if done is true' + ) + + +class Image2(BaseModel): + bytesBase64Encoded: str + gcsUri: Optional[str] = None + mimeType: Optional[str] = None + + +class Image3(BaseModel): + bytesBase64Encoded: Optional[str] = None + gcsUri: str + mimeType: Optional[str] = None + + +class Instance1(BaseModel): + image: Optional[Union[Image2, Image3]] = Field( + None, description='Optional image to guide video generation' + ) + prompt: str = Field(..., description='Text description of the video') + + +class Parameters1(BaseModel): + aspectRatio: Optional[str] = Field(None, examples=['16:9']) + durationSeconds: Optional[int] = None + enhancePrompt: Optional[bool] = None + generateAudio: Optional[bool] = Field( + None, + description='Generate audio for the video. 
Only supported by veo 3 models.', + ) + negativePrompt: Optional[str] = None + personGeneration: Optional[PersonGeneration1] = None + sampleCount: Optional[int] = None + seed: Optional[int] = None + storageUri: Optional[str] = Field( + None, description='Optional Cloud Storage URI to upload the video' + ) + + +class VeoGenVidRequest(BaseModel): + instances: Optional[List[Instance1]] = None + parameters: Optional[Parameters1] = None + + +class VeoGenVidResponse(BaseModel): + name: str = Field( + ..., + description='Operation resource name', + examples=[ + 'projects/PROJECT_ID/locations/us-central1/publishers/google/models/MODEL_ID/operations/a1b07c8e-7b5a-4aba-bb34-3e1ccb8afcc8' + ], + ) + + class SearchContextSize(str, Enum): low = 'low' medium = 'medium' high = 'high' -class Type23(str, Enum): +class Type35(str, Enum): web_search_preview = 'web_search_preview' web_search_preview_2025_03_11 = 'web_search_preview_2025_03_11' @@ -2775,30 +4371,348 @@ class WebSearchPreviewTool(BaseModel): ) -class Status11(str, Enum): +class Status12(str, Enum): in_progress = 'in_progress' searching = 'searching' completed = 'completed' failed = 'failed' -class Type24(str, Enum): +class Type36(str, Enum): web_search_call = 'web_search_call' class WebSearchToolCall(BaseModel): id: str = Field(..., description='The unique ID of the web search tool call.\n') - status: Status11 = Field( + status: Status12 = Field( ..., description='The status of the web search tool call.\n' ) - type: Type24 = Field( + type: Type36 = Field( ..., description='The type of the web search tool call. Always `web_search_call`.\n', ) -class CreateModelResponseProperties(ModelResponseProperties): - pass +class WorkflowRunStatus(str, Enum): + WorkflowRunStatusStarted = 'WorkflowRunStatusStarted' + WorkflowRunStatusFailed = 'WorkflowRunStatusFailed' + WorkflowRunStatusCompleted = 'WorkflowRunStatusCompleted' + + +class ActionJobResult(BaseModel): + action_job_id: Optional[str] = Field( + None, description='Identifier of the job this result belongs to' + ) + action_run_id: Optional[str] = Field( + None, description='Identifier of the run this result belongs to' + ) + author: Optional[str] = Field(None, description='The author of the commit') + avg_vram: Optional[int] = Field( + None, description='The average VRAM used by the job' + ) + branch_name: Optional[str] = Field( + None, description='Name of the relevant git branch' + ) + comfy_run_flags: Optional[str] = Field( + None, description='The comfy run flags. E.g. `--low-vram`' + ) + commit_hash: Optional[str] = Field(None, description='The hash of the commit') + commit_id: Optional[str] = Field(None, description='The ID of the commit') + commit_message: Optional[str] = Field(None, description='The message of the commit') + commit_time: Optional[int] = Field( + None, description='The Unix timestamp when the commit was made' + ) + cuda_version: Optional[str] = Field(None, description='CUDA version used') + end_time: Optional[int] = Field( + None, description='The end time of the job as a Unix timestamp.' + ) + git_repo: Optional[str] = Field(None, description='The repository name') + id: Optional[UUID] = Field(None, description='Unique identifier for the job result') + job_trigger_user: Optional[str] = Field( + None, description='The user who triggered the job.' 
+ ) + machine_stats: Optional[MachineStats] = None + operating_system: Optional[str] = Field(None, description='Operating system used') + peak_vram: Optional[int] = Field(None, description='The peak VRAM used by the job') + pr_number: Optional[str] = Field(None, description='The pull request number') + python_version: Optional[str] = Field(None, description='PyTorch version used') + pytorch_version: Optional[str] = Field(None, description='PyTorch version used') + start_time: Optional[int] = Field( + None, description='The start time of the job as a Unix timestamp.' + ) + status: Optional[WorkflowRunStatus] = None + storage_file: Optional[StorageFile] = None + workflow_name: Optional[str] = Field(None, description='Name of the workflow') + + +class BFLCannyInputs(BaseModel): + canny_high_threshold: Optional[CannyHighThreshold] = Field( + default_factory=lambda: CannyHighThreshold.model_validate(200), + description='High threshold for Canny edge detection', + title='Canny High Threshold', + ) + canny_low_threshold: Optional[CannyLowThreshold] = Field( + default_factory=lambda: CannyLowThreshold.model_validate(50), + description='Low threshold for Canny edge detection', + title='Canny Low Threshold', + ) + control_image: Optional[str] = Field( + None, + description='Base64 encoded image to use as control input if no preprocessed image is provided', + title='Control Image', + ) + guidance: Optional[Guidance] = Field( + default_factory=lambda: Guidance.model_validate(30), + description='Guidance strength for the image generation process', + title='Guidance', + ) + output_format: Optional[BFLOutputFormat] = Field( + 'jpeg', + description="Output format for the generated image. Can be 'jpeg' or 'png'.", + ) + preprocessed_image: Optional[str] = Field( + None, + description='Optional pre-processed image that will bypass the control preprocessing step', + title='Preprocessed Image', + ) + prompt: str = Field( + ..., + description='Text prompt for image generation', + examples=['ein fantastisches bild'], + title='Prompt', + ) + prompt_upsampling: Optional[bool] = Field( + False, + description='Whether to perform upsampling on the prompt', + title='Prompt Upsampling', + ) + safety_tolerance: Optional[int] = Field( + 2, + description='Tolerance level for input and output moderation. Between 0 and 6, 0 being most strict, 6 being least strict.', + ge=0, + le=6, + title='Safety Tolerance', + ) + seed: Optional[int] = Field( + None, + description='Optional seed for reproducibility', + examples=[42], + title='Seed', + ) + steps: Optional[Steps] = Field( + default_factory=lambda: Steps.model_validate(50), + description='Number of steps for the image generation process', + title='Steps', + ) + webhook_secret: Optional[str] = Field( + None, + description='Optional secret for webhook signature verification', + title='Webhook Secret', + ) + webhook_url: Optional[WebhookUrl] = Field( + None, description='URL to receive webhook notifications', title='Webhook Url' + ) + + +class BFLDepthInputs(BaseModel): + control_image: Optional[str] = Field( + None, + description='Base64 encoded image to use as control input', + title='Control Image', + ) + guidance: Optional[Guidance] = Field( + default_factory=lambda: Guidance.model_validate(15), + description='Guidance strength for the image generation process', + title='Guidance', + ) + output_format: Optional[BFLOutputFormat] = Field( + 'jpeg', + description="Output format for the generated image. 
Can be 'jpeg' or 'png'.", + ) + preprocessed_image: Optional[str] = Field( + None, + description='Optional pre-processed image that will bypass the control preprocessing step', + title='Preprocessed Image', + ) + prompt: str = Field( + ..., + description='Text prompt for image generation', + examples=['ein fantastisches bild'], + title='Prompt', + ) + prompt_upsampling: Optional[bool] = Field( + False, + description='Whether to perform upsampling on the prompt', + title='Prompt Upsampling', + ) + safety_tolerance: Optional[int] = Field( + 2, + description='Tolerance level for input and output moderation. Between 0 and 6, 0 being most strict, 6 being least strict.', + ge=0, + le=6, + title='Safety Tolerance', + ) + seed: Optional[int] = Field( + None, + description='Optional seed for reproducibility', + examples=[42], + title='Seed', + ) + steps: Optional[Steps] = Field( + default_factory=lambda: Steps.model_validate(50), + description='Number of steps for the image generation process', + title='Steps', + ) + webhook_secret: Optional[str] = Field( + None, + description='Optional secret for webhook signature verification', + title='Webhook Secret', + ) + webhook_url: Optional[WebhookUrl] = Field( + None, description='URL to receive webhook notifications', title='Webhook Url' + ) + + +class BFLFluxProExpandInputs(BaseModel): + bottom: Optional[Bottom] = Field( + 0, + description='Number of pixels to expand at the bottom of the image', + title='Bottom', + ) + guidance: Optional[Guidance2] = Field( + default_factory=lambda: Guidance2.model_validate(60), + description='Guidance strength for the image generation process', + title='Guidance', + ) + image: str = Field( + ..., + description='A Base64-encoded string representing the image you wish to expand.', + title='Image', + ) + left: Optional[Left] = Field( + 0, + description='Number of pixels to expand on the left side of the image', + title='Left', + ) + output_format: Optional[BFLOutputFormat] = Field( + 'jpeg', + description="Output format for the generated image. Can be 'jpeg' or 'png'.", + ) + prompt: Optional[str] = Field( + '', + description='The description of the changes you want to make. This text guides the expansion process, allowing you to specify features, styles, or modifications for the expanded areas.', + examples=['ein fantastisches bild'], + title='Prompt', + ) + prompt_upsampling: Optional[bool] = Field( + False, + description='Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation', + title='Prompt Upsampling', + ) + right: Optional[Right] = Field( + 0, + description='Number of pixels to expand on the right side of the image', + title='Right', + ) + safety_tolerance: Optional[int] = Field( + 2, + description='Tolerance level for input and output moderation. 
Between 0 and 6, 0 being most strict, 6 being least strict.', + examples=[2], + ge=0, + le=6, + title='Safety Tolerance', + ) + seed: Optional[int] = Field( + None, description='Optional seed for reproducibility', title='Seed' + ) + steps: Optional[Steps2] = Field( + default_factory=lambda: Steps2.model_validate(50), + description='Number of steps for the image generation process', + examples=[50], + title='Steps', + ) + top: Optional[Top] = Field( + 0, description='Number of pixels to expand at the top of the image', title='Top' + ) + webhook_secret: Optional[str] = Field( + None, + description='Optional secret for webhook signature verification', + title='Webhook Secret', + ) + webhook_url: Optional[WebhookUrl] = Field( + None, description='URL to receive webhook notifications', title='Webhook Url' + ) + + +class BFLFluxProFillInputs(BaseModel): + guidance: Optional[Guidance2] = Field( + default_factory=lambda: Guidance2.model_validate(60), + description='Guidance strength for the image generation process', + title='Guidance', + ) + image: str = Field( + ..., + description='A Base64-encoded string representing the image you wish to modify. Can contain alpha mask if desired.', + title='Image', + ) + mask: Optional[str] = Field( + None, + description='A Base64-encoded string representing a mask for the areas you want to modify in the image. The mask should be the same dimensions as the image and in black and white. Black areas (0%) indicate no modification, while white areas (100%) specify areas for inpainting. Optional if you provide an alpha mask in the original image. Validation: The endpoint verifies that the dimensions of the mask match the original image.', + title='Mask', + ) + output_format: Optional[BFLOutputFormat] = Field( + 'jpeg', + description="Output format for the generated image. Can be 'jpeg' or 'png'.", + ) + prompt: Optional[str] = Field( + '', + description='The description of the changes you want to make. This text guides the inpainting process, allowing you to specify features, styles, or modifications for the masked area.', + examples=['ein fantastisches bild'], + title='Prompt', + ) + prompt_upsampling: Optional[bool] = Field( + False, + description='Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation', + title='Prompt Upsampling', + ) + safety_tolerance: Optional[int] = Field( + 2, + description='Tolerance level for input and output moderation. 
Between 0 and 6, 0 being most strict, 6 being least strict.', + examples=[2], + ge=0, + le=6, + title='Safety Tolerance', + ) + seed: Optional[int] = Field( + None, description='Optional seed for reproducibility', title='Seed' + ) + steps: Optional[Steps2] = Field( + default_factory=lambda: Steps2.model_validate(50), + description='Number of steps for the image generation process', + examples=[50], + title='Steps', + ) + webhook_secret: Optional[str] = Field( + None, + description='Optional secret for webhook signature verification', + title='Webhook Secret', + ) + webhook_url: Optional[WebhookUrl] = Field( + None, description='URL to receive webhook notifications', title='Webhook Url' + ) + + +class BFLHTTPValidationError(BaseModel): + detail: Optional[List[BFLValidationError]] = Field(None, title='Detail') + + +class BulkNodeVersionsRequest(BaseModel): + node_versions: List[NodeVersionIdentifier] = Field( + ..., description='List of node ID and version pairs to retrieve' + ) + + +CreateModelResponseProperties = ModelResponseProperties class GeminiInlineData(BaseModel): @@ -2841,6 +4755,125 @@ class GeminiSystemInstructionContent(BaseModel): ) +class GeminiUsageMetadata(BaseModel): + cachedContentTokenCount: Optional[int] = Field( + None, + description='Output only. Number of tokens in the cached part in the input (the cached content).', + ) + candidatesTokenCount: Optional[int] = Field( + None, description='Number of tokens in the response(s).' + ) + candidatesTokensDetails: Optional[List[ModalityTokenCount]] = Field( + None, description='Breakdown of candidate tokens by modality.' + ) + promptTokenCount: Optional[int] = Field( + None, + description='Number of tokens in the request. When cachedContent is set, this is still the total effective prompt size meaning this includes the number of tokens in the cached content.', + ) + promptTokensDetails: Optional[List[ModalityTokenCount]] = Field( + None, description='Breakdown of prompt tokens by modality.' + ) + thoughtsTokenCount: Optional[int] = Field( + None, description='Number of tokens present in thoughts output.' + ) + toolUsePromptTokenCount: Optional[int] = Field( + None, description='Number of tokens present in tool-use prompt(s).' 
+ ) + + +class GithubInstallation(BaseModel): + access_tokens_url: str = Field(..., description='The API URL for access tokens') + account: GithubUser + app_id: int = Field(..., description='The GitHub App ID') + created_at: datetime = Field(..., description='When the installation was created') + events: List[str] = Field( + ..., description='The events the installation subscribes to' + ) + html_url: str = Field(..., description='The HTML URL of the installation') + id: int = Field(..., description='The installation ID') + permissions: Dict[str, Any] = Field(..., description='The installation permissions') + repositories_url: str = Field(..., description='The API URL for repositories') + repository_selection: RepositorySelection = Field( + ..., description='Repository selection for the installation' + ) + single_file_name: Optional[str] = Field( + None, description='The single file name if applicable' + ) + target_id: int = Field(..., description='The target ID') + target_type: str = Field(..., description='The target type') + updated_at: datetime = Field( + ..., description='When the installation was last updated' + ) + + +class GithubReleaseAsset(BaseModel): + browser_download_url: str = Field(..., description='The browser download URL') + content_type: str = Field(..., description='The content type of the asset') + created_at: datetime = Field(..., description='When the asset was created') + download_count: int = Field(..., description='The number of downloads') + id: int = Field(..., description='The asset ID') + label: Optional[str] = Field(None, description='The label of the asset') + name: str = Field(..., description='The name of the asset') + node_id: str = Field(..., description='The asset node ID') + size: int = Field(..., description='The size of the asset in bytes') + state: State = Field(..., description='The state of the asset') + updated_at: datetime = Field(..., description='When the asset was last updated') + uploader: GithubUser + + +class Release(BaseModel): + assets: List[GithubReleaseAsset] = Field(..., description='Array of release assets') + assets_url: Optional[str] = Field(None, description='The URL to the release assets') + author: GithubUser + body: Optional[str] = Field(None, description='The release notes/body') + created_at: datetime = Field(..., description='When the release was created') + draft: bool = Field(..., description='Whether the release is a draft') + html_url: str = Field(..., description='The HTML URL of the release') + id: int = Field(..., description='The ID of the release') + name: Optional[str] = Field(None, description='The name of the release') + node_id: str = Field(..., description='The node ID of the release') + prerelease: bool = Field(..., description='Whether the release is a prerelease') + published_at: Optional[datetime] = Field( + None, description='When the release was published' + ) + tag_name: str = Field(..., description='The tag name of the release') + tarball_url: str = Field(..., description='URL to the tarball') + target_commitish: str = Field( + ..., description='The branch or commit the release was created from' + ) + upload_url: Optional[str] = Field( + None, description='The URL to upload release assets' + ) + url: str = Field(..., description='The API URL of the release') + zipball_url: str = Field(..., description='URL to the zipball') + + +class GithubRepository(BaseModel): + clone_url: str = Field(..., description='The clone URL of the repository') + created_at: datetime = Field(..., description='When the 
repository was created') + default_branch: str = Field(..., description='The default branch of the repository') + description: Optional[str] = Field(None, description='The repository description') + fork: bool = Field(..., description='Whether the repository is a fork') + full_name: str = Field( + ..., description='The full name of the repository (owner/repo)' + ) + git_url: str = Field(..., description='The git URL of the repository') + html_url: str = Field(..., description='The HTML URL of the repository') + id: int = Field(..., description='The repository ID') + name: str = Field(..., description='The name of the repository') + node_id: str = Field(..., description='The repository node ID') + owner: GithubUser + private: bool = Field(..., description='Whether the repository is private') + pushed_at: datetime = Field( + ..., description='When the repository was last pushed to' + ) + ssh_url: str = Field(..., description='The SSH URL of the repository') + updated_at: datetime = Field( + ..., description='When the repository was last updated' + ) + url: str = Field(..., description='The API URL of the repository') + + class IdeogramV3EditRequest(BaseModel): color_palette: Optional[IdeogramColorPalette] = None image: Optional[StrictBytes] = Field( @@ -3276,6 +5309,52 @@ class MoonvalleyTextToImageRequest(BaseModel): webhook_url: Optional[str] = None +class NodeVersion(BaseModel): + changelog: Optional[str] = Field( + None, description='Summary of changes made in this version' + ) + comfy_node_extract_status: Optional[str] = Field( + None, description='The status of comfy node extraction process.' + ) + createdAt: Optional[datetime] = Field( + None, description='The date and time the version was created.' + ) + dependencies: Optional[List[str]] = Field( + None, description='A list of pip dependencies required by the node.' + ) + deprecated: Optional[bool] = Field( + None, description='Indicates if this version is deprecated.' + ) + downloadUrl: Optional[str] = Field( + None, description='[Output Only] URL to download this version of the node' + ) + id: Optional[str] = None + node_id: Optional[str] = Field( + None, description='The unique identifier of the node.' + ) + status: Optional[NodeVersionStatus] = None + status_reason: Optional[str] = Field( + None, description='The reason for the status change.' + ) + supported_accelerators: Optional[List[str]] = Field( + None, + description='List of accelerators (e.g. CUDA, DirectML, ROCm) that this node supports', + ) + supported_comfyui_frontend_version: Optional[str] = Field( + None, description='Supported versions of ComfyUI frontend' + ) + supported_comfyui_version: Optional[str] = Field( + None, description='Supported versions of ComfyUI' + ) + supported_os: Optional[List[str]] = Field( + None, description='List of operating systems that this node supports' + ) + version: Optional[str] = Field( + None, + description='The version identifier, following semantic versioning. 
Must be unique for the node.', + ) + + class OutputContent(RootModel[Union[OutputTextContent, OutputAudioContent]]): root: Union[OutputTextContent, OutputAudioContent] @@ -3283,7 +5362,7 @@ class OutputContent(RootModel[Union[OutputTextContent, OutputAudioContent]]): class OutputMessage(BaseModel): content: List[OutputContent] = Field(..., description='The content of the message') role: Role4 = Field(..., description='The role of the message') - type: Type14 = Field(..., description='The type of output item') + type: Type15 = Field(..., description='The type of output item') class PikaBodyGenerate22I2vGenerate22I2vPost(BaseModel): @@ -3333,6 +5412,16 @@ class PikaHTTPValidationError(BaseModel): detail: Optional[List[PikaValidationError]] = Field(None, title='Detail') +class PublisherMember(BaseModel): + id: Optional[str] = Field( + None, description='The unique identifier for the publisher member.' + ) + role: Optional[str] = Field( + None, description='The role of the user in the publisher.' + ) + user: Optional[PublisherUser] = None + + class Reasoning(BaseModel): effort: Optional[ReasoningEffort] = 'medium' generate_summary: Optional[GenerateSummary] = Field( @@ -3345,13 +5434,88 @@ class Reasoning(BaseModel): ) +class RecraftImage(BaseModel): + b64_json: Optional[str] = None + features: Optional[RecraftImageFeatures] = None + image_id: UUID + revised_prompt: Optional[str] = None + url: Optional[str] = None + + +class RecraftProcessImageRequest(BaseModel): + image: StrictBytes + image_format: Optional[RecraftImageFormat] = None + response_format: Optional[RecraftResponseFormat] = None + + +class RecraftProcessImageResponse(BaseModel): + created: int + credits: int + image: RecraftImage + + +class RecraftTextLayout(RootModel[List[RecraftTextLayoutItem]]): + root: List[RecraftTextLayoutItem] + + +class RecraftTransformImageWithMaskRequest(BaseModel): + block_nsfw: Optional[bool] = None + calculate_features: Optional[bool] = None + image: StrictBytes + image_format: Optional[RecraftImageFormat] = None + mask: StrictBytes + model: Optional[RecraftTransformModel] = None + n: Optional[int] = None + negative_prompt: Optional[str] = None + prompt: str + response_format: Optional[RecraftResponseFormat] = None + style: Optional[RecraftImageStyle] = None + style_id: Optional[UUID] = None + substyle: Optional[RecraftImageSubStyle] = None + text_layout: Optional[RecraftTextLayout] = None + + +class ResponseContentPartAddedEvent(BaseModel): + content_index: int = Field( + ..., description='The index of the content part that was added.' + ) + item_id: str = Field( + ..., description='The ID of the output item that the content part was added to.' + ) + output_index: int = Field( + ..., + description='The index of the output item that the content part was added to.', + ) + part: OutputContent + type: Type20 = Field( + ..., description='The type of the event. Always `response.content_part.added`.' + ) + + +class ResponseContentPartDoneEvent(BaseModel): + content_index: int = Field( + ..., description='The index of the content part that is done.' + ) + item_id: str = Field( + ..., description='The ID of the output item that the content part was added to.' + ) + output_index: int = Field( + ..., + description='The index of the output item that the content part was added to.', + ) + part: OutputContent + type: Type21 = Field( + ..., description='The type of the event. Always `response.content_part.done`.' 
+ ) + + class ResponseError(BaseModel): code: ResponseErrorCode message: str = Field(..., description='A human-readable description of the error.') class Rodin3DDownloadResponse(BaseModel): - list: Optional[RodinResourceItem] = None + list: Optional[List[RodinResourceItem]] = None class Rodin3DGenerateRequest(BaseModel): @@ -3371,6 +5535,11 @@ class Rodin3DGenerateResponse(BaseModel): uuid: Optional[str] = Field(None, description='Task UUID') +class RodinCheckStatusJobItem(BaseModel): + status: Optional[RodinStatusOptions] = None + uuid: Optional[str] = Field(None, description='sub uuid') + + class RunwayImageToVideoRequest(BaseModel): duration: RunwayDurationEnum model: RunwayModelEnum @@ -3384,6 +5553,109 @@ class RunwayImageToVideoRequest(BaseModel): ) +class StripeCharge(BaseModel): + amount: Optional[int] = None + amount_captured: Optional[int] = None + amount_refunded: Optional[int] = None + application: Optional[str] = None + application_fee: Optional[str] = None + application_fee_amount: Optional[int] = None + balance_transaction: Optional[str] = None + billing_details: Optional[StripeBillingDetails] = None + calculated_statement_descriptor: Optional[str] = None + captured: Optional[bool] = None + created: Optional[int] = None + currency: Optional[str] = None + customer: Optional[str] = None + description: Optional[str] = None + destination: Optional[Any] = None + dispute: Optional[Any] = None + disputed: Optional[bool] = None + failure_balance_transaction: Optional[Any] = None + failure_code: Optional[Any] = None + failure_message: Optional[Any] = None + fraud_details: Optional[Dict[str, Any]] = None + id: Optional[str] = None + invoice: Optional[Any] = None + livemode: Optional[bool] = None + metadata: Optional[Dict[str, Any]] = None + object: Optional[Object1] = None + on_behalf_of: Optional[Any] = None + order: Optional[Any] = None + outcome: Optional[StripeOutcome] = None + paid: Optional[bool] = None + payment_intent: Optional[str] = None + payment_method: Optional[str] = None + payment_method_details: Optional[StripePaymentMethodDetails] = None + radar_options: Optional[Dict[str, Any]] = None + receipt_email: Optional[str] = None + receipt_number: Optional[str] = None + receipt_url: Optional[str] = None + refunded: Optional[bool] = None + refunds: Optional[StripeRefundList] = None + review: Optional[Any] = None + shipping: Optional[StripeShipping] = None + source: Optional[Any] = None + source_transfer: Optional[Any] = None + statement_descriptor: Optional[Any] = None + statement_descriptor_suffix: Optional[Any] = None + status: Optional[str] = None + transfer_data: Optional[Any] = None + transfer_group: Optional[Any] = None + + +class StripeChargeList(BaseModel): + data: Optional[List[StripeCharge]] = None + has_more: Optional[bool] = None + object: Optional[str] = None + total_count: Optional[int] = None + url: Optional[str] = None + + +class StripePaymentIntent(BaseModel): + amount: Optional[int] = None + amount_capturable: Optional[int] = None + amount_details: Optional[StripeAmountDetails] = None + amount_received: Optional[int] = None + application: Optional[str] = None + application_fee_amount: Optional[int] = None + automatic_payment_methods: Optional[Any] = None + canceled_at: Optional[int] = None + cancellation_reason: Optional[str] = None + capture_method: Optional[str] = None + charges: Optional[StripeChargeList] = None + client_secret: Optional[str] = None + confirmation_method: Optional[str] = None + created: Optional[int] = None + currency: Optional[str] = 
None + customer: Optional[str] = None + description: Optional[str] = None + id: Optional[str] = None + invoice: Optional[str] = None + last_payment_error: Optional[Any] = None + latest_charge: Optional[str] = None + livemode: Optional[bool] = None + metadata: Optional[Dict[str, Any]] = None + next_action: Optional[Any] = None + object: Optional[Object3] = None + on_behalf_of: Optional[Any] = None + payment_method: Optional[str] = None + payment_method_configuration_details: Optional[Any] = None + payment_method_options: Optional[StripePaymentMethodOptions] = None + payment_method_types: Optional[List[str]] = None + processing: Optional[Any] = None + receipt_email: Optional[str] = None + review: Optional[Any] = None + setup_future_usage: Optional[Any] = None + shipping: Optional[StripeShipping] = None + source: Optional[Any] = None + statement_descriptor: Optional[Any] = None + statement_descriptor_suffix: Optional[Any] = None + status: Optional[str] = None + transfer_data: Optional[Any] = None + transfer_group: Optional[Any] = None + + class TextResponseFormatConfiguration( RootModel[ Union[ @@ -3411,6 +5683,22 @@ class Tool( ] = Field(..., discriminator='type') +class BulkNodeVersionResult(BaseModel): + error_message: Optional[str] = Field( + None, + description='Error message if retrieval failed (only present if status is error)', + ) + identifier: NodeVersionIdentifier + node_version: Optional[NodeVersion] = None + status: Status = Field(..., description='Status of the retrieval operation') + + +class BulkNodeVersionsResponse(BaseModel): + node_versions: List[BulkNodeVersionResult] = Field( + ..., description='List of retrieved node versions with their status' + ) + + class EasyInputMessage(BaseModel): content: Union[str, InputMessageContentList] = Field( ..., @@ -3439,6 +5727,16 @@ class GeminiGenerateContentRequest(BaseModel): videoMetadata: Optional[GeminiVideoMetadata] = None +class GithubReleaseWebhook(BaseModel): + action: Action = Field(..., description='The action performed on the release') + enterprise: Optional[GithubEnterprise] = None + installation: Optional[GithubInstallation] = None + organization: Optional[GithubOrganization] = None + release: Release = Field(..., description='The release object') + repository: GithubRepository + sender: GithubUser + + class ImagenGenerateImageRequest(BaseModel): instances: List[ImagenImageGenerationInstance] parameters: ImagenImageGenerationParameters @@ -3447,8 +5745,8 @@ class ImagenGenerateImageRequest(BaseModel): class InputMessage(BaseModel): content: Optional[InputMessageContentList] = None role: Optional[Role3] = None - status: Optional[Status2] = None - type: Optional[Type9] = None + status: Optional[Status3] = None + type: Optional[Type10] = None class Item( @@ -3519,6 +5817,70 @@ class OutputItem( ] +class Publisher(BaseModel): + createdAt: Optional[datetime] = Field( + None, description='The date and time the publisher was created.' + ) + description: Optional[str] = None + id: Optional[str] = Field( + None, + description="The unique identifier for the publisher. It's akin to a username. Should be lowercase.", + ) + logo: Optional[str] = Field(None, description="URL to the publisher's logo.") + members: Optional[List[PublisherMember]] = Field( + None, description='A list of members in the publisher.' 
+ ) + name: Optional[str] = None + source_code_repo: Optional[str] = None + status: Optional[PublisherStatus] = None + support: Optional[str] = None + website: Optional[str] = None + + +class RecraftGenerateImageResponse(BaseModel): + created: int + credits: int + data: List[RecraftImage] + + +class RecraftImageToImageRequest(BaseModel): + block_nsfw: Optional[bool] = None + calculate_features: Optional[bool] = None + controls: Optional[RecraftUserControls] = None + image: StrictBytes + image_format: Optional[RecraftImageFormat] = None + model: Optional[RecraftTransformModel] = None + n: Optional[int] = None + negative_prompt: Optional[str] = None + prompt: str + response_format: Optional[RecraftResponseFormat] = None + strength: float + style: Optional[RecraftImageStyle] = None + style_id: Optional[UUID] = None + substyle: Optional[RecraftImageSubStyle] = None + text_layout: Optional[RecraftTextLayout] = None + + +class ResponseOutputItemAddedEvent(BaseModel): + item: OutputItem + output_index: int = Field( + ..., description='The index of the output item that was added.\n' + ) + type: Type29 = Field( + ..., description='The type of the event. Always `response.output_item.added`.\n' + ) + + +class ResponseOutputItemDoneEvent(BaseModel): + item: OutputItem + output_index: int = Field( + ..., description='The index of the output item that was marked done.\n' + ) + type: Type30 = Field( + ..., description='The type of the event. Always `response.output_item.done`.\n' + ) + + class Text(BaseModel): format: Optional[TextResponseFormatConfiguration] = None @@ -3552,6 +5914,28 @@ class ResponseProperties(BaseModel): ) +class Rodin3DCheckStatusResponse(BaseModel): + jobs: Optional[List[RodinCheckStatusJobItem]] = Field( + None, description='Details for the generation status.' + ) + + +class Data8(BaseModel): + object: Optional[StripePaymentIntent] = None + + +class StripeEvent(BaseModel): + api_version: Optional[str] = None + created: Optional[int] = None + data: Data8 + id: str + livemode: Optional[bool] = None + object: Object2 + pending_webhooks: Optional[int] = None + request: Optional[StripeRequestInfo] = None + type: Type31 + + class GeminiCandidate(BaseModel): citationMetadata: Optional[GeminiCitationMetadata] = None content: Optional[GeminiContent] = None @@ -3562,12 +5946,67 @@ class GeminiCandidate(BaseModel): class GeminiGenerateContentResponse(BaseModel): candidates: Optional[List[GeminiCandidate]] = None promptFeedback: Optional[GeminiPromptFeedback] = None + usageMetadata: Optional[GeminiUsageMetadata] = None class InputItem(RootModel[Union[EasyInputMessage, Item]]): root: Union[EasyInputMessage, Item] +class Node(BaseModel): + author: Optional[str] = None + banner_url: Optional[str] = Field(None, description="URL to the node's banner.") + category: Optional[str] = Field(None, description='The category of the node.') + created_at: Optional[datetime] = Field( + None, description='The date and time when the node was created' + ) + description: Optional[str] = None + downloads: Optional[int] = Field( + None, description='The number of downloads of the node.' + ) + github_stars: Optional[int] = Field( + None, description='Number of stars on the GitHub repository.' + ) + icon: Optional[str] = Field(None, description="URL to the node's icon.") + id: Optional[str] = Field(None, description='The unique identifier of the node.') + latest_version: Optional[NodeVersion] = None + license: Optional[str] = Field( + None, description="The path to the LICENSE file in the node's repository." 
+ ) + name: Optional[str] = Field(None, description='The display name of the node.') + preempted_comfy_node_names: Optional[List[str]] = Field( + None, description='A list of Comfy node names that are preempted by this node.' + ) + publisher: Optional[Publisher] = None + rating: Optional[float] = Field(None, description='The average rating of the node.') + repository: Optional[str] = Field(None, description="URL to the node's repository.") + search_ranking: Optional[int] = Field( + None, + description="A numerical value representing the node's search ranking, used for sorting search results.", + ) + status: Optional[NodeStatus] = None + status_detail: Optional[str] = Field( + None, description='The status detail of the node.' + ) + supported_accelerators: Optional[List[str]] = Field( + None, + description='List of accelerators (e.g. CUDA, DirectML, ROCm) that this node supports', + ) + supported_comfyui_frontend_version: Optional[str] = Field( + None, description='Supported versions of ComfyUI frontend' + ) + supported_comfyui_version: Optional[str] = Field( + None, description='Supported versions of ComfyUI' + ) + supported_os: Optional[List[str]] = Field( + None, description='List of operating systems that this node supports' + ) + tags: Optional[List[str]] = None + translations: Optional[Dict[str, Dict[str, Any]]] = Field( + None, description='Translations of node metadata in different languages.' + ) + + class OpenAICreateResponse(CreateModelResponseProperties, ResponseProperties): include: Optional[List[Includable]] = Field( None, @@ -3615,8 +6054,73 @@ class OpenAIResponse(ModelResponseProperties, ResponseProperties): parallel_tool_calls: Optional[bool] = Field( True, description='Whether to allow the model to run tool calls in parallel.\n' ) - status: Optional[Status6] = Field( + status: Optional[Status7] = Field( None, description='The status of the response generation. One of `completed`, `failed`, `in_progress`, or `incomplete`.', ) usage: Optional[ResponseUsage] = None + + +class ResponseCompletedEvent(BaseModel): + response: OpenAIResponse + type: Type19 = Field( + ..., description='The type of the event. Always `response.completed`.' + ) + + +class ResponseCreatedEvent(BaseModel): + response: OpenAIResponse + type: Type22 = Field( + ..., description='The type of the event. Always `response.created`.' + ) + + +class ResponseFailedEvent(BaseModel): + response: OpenAIResponse + type: Type24 = Field( + ..., description='The type of the event. Always `response.failed`.\n' + ) + + +class ResponseInProgressEvent(BaseModel): + response: OpenAIResponse + type: Type27 = Field( + ..., description='The type of the event. Always `response.in_progress`.\n' + ) + + +class ResponseIncompleteEvent(BaseModel): + response: OpenAIResponse + type: Type28 = Field( + ..., description='The type of the event. 
Always `response.incomplete`.\n' + ) + + +class OpenAIResponseStreamEvent( + RootModel[ + Union[ + ResponseCreatedEvent, + ResponseInProgressEvent, + ResponseCompletedEvent, + ResponseFailedEvent, + ResponseIncompleteEvent, + ResponseOutputItemAddedEvent, + ResponseOutputItemDoneEvent, + ResponseContentPartAddedEvent, + ResponseContentPartDoneEvent, + ResponseErrorEvent, + ] + ] +): + root: Union[ + ResponseCreatedEvent, + ResponseInProgressEvent, + ResponseCompletedEvent, + ResponseFailedEvent, + ResponseIncompleteEvent, + ResponseOutputItemAddedEvent, + ResponseOutputItemDoneEvent, + ResponseContentPartAddedEvent, + ResponseContentPartDoneEvent, + ResponseErrorEvent, + ] = Field(..., description='Events that can be emitted during response streaming') diff --git a/comfy_api_nodes/apis/tripo_api.py b/comfy_api_nodes/apis/tripo_api.py index 626e8d277..9f43d4d09 100644 --- a/comfy_api_nodes/apis/tripo_api.py +++ b/comfy_api_nodes/apis/tripo_api.py @@ -127,7 +127,7 @@ class TripoTextToModelRequest(BaseModel): type: TripoTaskType = Field(TripoTaskType.TEXT_TO_MODEL, description='Type of task') prompt: str = Field(..., description='The text prompt describing the model to generate', max_length=1024) negative_prompt: Optional[str] = Field(None, description='The negative text prompt', max_length=1024) - model_version: Optional[TripoModelVersion] = TripoModelVersion.V2_5 + model_version: Optional[TripoModelVersion] = TripoModelVersion.v2_5_20250123 face_limit: Optional[int] = Field(None, description='The number of faces to limit the generation to') texture: Optional[bool] = Field(True, description='Whether to apply texture to the generated model') pbr: Optional[bool] = Field(True, description='Whether to apply PBR to the generated model') diff --git a/comfy_api_nodes/nodes_veo2.py b/comfy_api_nodes/nodes_veo2.py index df846d5dd..97bfe20e6 100644 --- a/comfy_api_nodes/nodes_veo2.py +++ b/comfy_api_nodes/nodes_veo2.py @@ -8,10 +8,10 @@ from typing import Optional from comfy.comfy_types.node_typing import IO, ComfyNodeABC from comfy_api.input_impl.video_types import VideoFromFile from comfy_api_nodes.apis import ( - Veo2GenVidRequest, - Veo2GenVidResponse, - Veo2GenVidPollRequest, - Veo2GenVidPollResponse + VeoGenVidRequest, + VeoGenVidResponse, + VeoGenVidPollRequest, + VeoGenVidPollResponse ) from comfy_api_nodes.apis.client import ( ApiEndpoint, @@ -35,7 +35,7 @@ def convert_image_to_base64(image: torch.Tensor): return tensor_to_base64_string(scaled_image) -def get_video_url_from_response(poll_response: Veo2GenVidPollResponse) -> Optional[str]: +def get_video_url_from_response(poll_response: VeoGenVidPollResponse) -> Optional[str]: if ( poll_response.response and hasattr(poll_response.response, "videos") @@ -130,6 +130,14 @@ class VeoVideoGenerationNode(ComfyNodeABC): "default": None, "tooltip": "Optional reference image to guide video generation", }), + "model": ( + IO.COMBO, + { + "options": ["veo-2.0-generate-001"], + "default": "veo-2.0-generate-001", + "tooltip": "Veo 2 model to use for video generation", + }, + ), }, "hidden": { "auth_token": "AUTH_TOKEN_COMFY_ORG", @@ -141,7 +149,7 @@ class VeoVideoGenerationNode(ComfyNodeABC): RETURN_TYPES = (IO.VIDEO,) FUNCTION = "generate_video" CATEGORY = "api node/video/Veo" - DESCRIPTION = "Generates videos from text prompts using Google's Veo API" + DESCRIPTION = "Generates videos from text prompts using Google's Veo 2 API" API_NODE = True def generate_video( @@ -154,6 +162,8 @@ class VeoVideoGenerationNode(ComfyNodeABC): person_generation="ALLOW", 
seed=0, image=None, + model="veo-2.0-generate-001", + generate_audio=False, unique_id: Optional[str] = None, **kwargs, ): @@ -188,16 +198,19 @@ class VeoVideoGenerationNode(ComfyNodeABC): parameters["negativePrompt"] = negative_prompt if seed > 0: parameters["seed"] = seed + # Only add generateAudio for Veo 3 models + if "veo-3.0" in model: + parameters["generateAudio"] = generate_audio # Initial request to start video generation initial_operation = SynchronousOperation( endpoint=ApiEndpoint( - path="/proxy/veo/generate", + path=f"/proxy/veo/{model}/generate", method=HttpMethod.POST, - request_model=Veo2GenVidRequest, - response_model=Veo2GenVidResponse + request_model=VeoGenVidRequest, + response_model=VeoGenVidResponse ), - request=Veo2GenVidRequest( + request=VeoGenVidRequest( instances=instances, parameters=parameters ), @@ -223,16 +236,16 @@ class VeoVideoGenerationNode(ComfyNodeABC): # Define the polling operation poll_operation = PollingOperation( poll_endpoint=ApiEndpoint( - path="/proxy/veo/poll", + path=f"/proxy/veo/{model}/poll", method=HttpMethod.POST, - request_model=Veo2GenVidPollRequest, - response_model=Veo2GenVidPollResponse + request_model=VeoGenVidPollRequest, + response_model=VeoGenVidPollResponse ), completed_statuses=["completed"], failed_statuses=[], # No failed statuses, we'll handle errors after polling status_extractor=status_extractor, progress_extractor=progress_extractor, - request=Veo2GenVidPollRequest( + request=VeoGenVidPollRequest( operationName=operation_name ), auth_kwargs=kwargs, @@ -298,11 +311,64 @@ class VeoVideoGenerationNode(ComfyNodeABC): return (VideoFromFile(video_io),) -# Register the node +class Veo3VideoGenerationNode(VeoVideoGenerationNode): + """ + Generates videos from text prompts using Google's Veo 3 API. + + Supported models: + - veo-3.0-generate-001 + - veo-3.0-fast-generate-001 + + This node extends the base Veo node with Veo 3 specific features including + audio generation and fixed 8-second duration. + """ + + @classmethod + def INPUT_TYPES(s): + parent_input = super().INPUT_TYPES() + + # Update model options for Veo 3 + parent_input["optional"]["model"] = ( + IO.COMBO, + { + "options": ["veo-3.0-generate-001", "veo-3.0-fast-generate-001"], + "default": "veo-3.0-generate-001", + "tooltip": "Veo 3 model to use for video generation", + }, + ) + + # Add generateAudio parameter + parent_input["optional"]["generate_audio"] = ( + IO.BOOLEAN, + { + "default": False, + "tooltip": "Generate audio for the video. 
Supported by all Veo 3 models.", + } + ) + + # Update duration constraints for Veo 3 (only 8 seconds supported) + parent_input["optional"]["duration_seconds"] = ( + IO.INT, + { + "default": 8, + "min": 8, + "max": 8, + "step": 1, + "display": "number", + "tooltip": "Duration of the output video in seconds (Veo 3 only supports 8 seconds)", + }, + ) + + return parent_input + + +# Register the nodes NODE_CLASS_MAPPINGS = { "VeoVideoGenerationNode": VeoVideoGenerationNode, + "Veo3VideoGenerationNode": Veo3VideoGenerationNode, } NODE_DISPLAY_NAME_MAPPINGS = { - "VeoVideoGenerationNode": "Google Veo2 Video Generation", + "VeoVideoGenerationNode": "Google Veo 2 Video Generation", + "Veo3VideoGenerationNode": "Google Veo 3 Video Generation", } From 5be6fd09ffb46cfdff240fb5b96dd8c06b2a0344 Mon Sep 17 00:00:00 2001 From: ComfyUI Wiki Date: Tue, 5 Aug 2025 15:48:56 +0800 Subject: [PATCH 021/325] Update template to 0.1.48 (#9182) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index ffa7dce65..470060ab4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.23.4 -comfyui-workflow-templates==0.1.47 +comfyui-workflow-templates==0.1.48 comfyui-embedded-docs==0.2.4 torch torchsde From d044a243986700aae552acdebf7e767ae8282e37 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 5 Aug 2025 03:12:27 -0700 Subject: [PATCH 022/325] Fix default shift and any latent size for qwen image model. (#9186) --- comfy/ldm/qwen_image/model.py | 9 +++++---- comfy/supported_models.py | 2 +- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/comfy/ldm/qwen_image/model.py b/comfy/ldm/qwen_image/model.py index ff631a60f..c15ab8e40 100644 --- a/comfy/ldm/qwen_image/model.py +++ b/comfy/ldm/qwen_image/model.py @@ -8,7 +8,7 @@ from einops import repeat from comfy.ldm.lightricks.model import TimestepEmbedding, Timesteps from comfy.ldm.modules.attention import optimized_attention_masked from comfy.ldm.flux.layers import EmbedND - +import comfy.ldm.common_dit class GELU(nn.Module): def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: bool = True, dtype=None, device=None, operations=None): @@ -364,8 +364,9 @@ class QwenImageTransformer2DModel(nn.Module): image_rotary_emb = self.pos_embeds(x, context) - orig_shape = x.shape - hidden_states = x.view(orig_shape[0], orig_shape[1], orig_shape[-2] // 2, 2, orig_shape[-1] // 2, 2) + hidden_states = comfy.ldm.common_dit.pad_to_patch_size(x, (1, self.patch_size, self.patch_size)) + orig_shape = hidden_states.shape + hidden_states = hidden_states.view(orig_shape[0], orig_shape[1], orig_shape[-2] // 2, 2, orig_shape[-1] // 2, 2) hidden_states = hidden_states.permute(0, 2, 4, 1, 3, 5) hidden_states = hidden_states.reshape(orig_shape[0], (orig_shape[-2] // 2) * (orig_shape[-1] // 2), orig_shape[1] * 4) @@ -396,4 +397,4 @@ class QwenImageTransformer2DModel(nn.Module): hidden_states = hidden_states.view(orig_shape[0], orig_shape[-2] // 2, orig_shape[-1] // 2, orig_shape[1], 2, 2) hidden_states = hidden_states.permute(0, 3, 1, 4, 2, 5) - return hidden_states.reshape(orig_shape) + return hidden_states.reshape(orig_shape)[:, :, :, :x.shape[-2], :x.shape[-1]] diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 880055bd3..156ff9e26 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -1237,7 +1237,7 @@ class QwenImage(supported_models_base.BASE): 
sampling_settings = { "multiplier": 1.0, - "shift": 2.6, + "shift": 1.15, } memory_usage_factor = 1.8 #TODO From da1ad9b5163fb848f3ec87ccc4fd0c8069f6eff0 Mon Sep 17 00:00:00 2001 From: ComfyUI Wiki Date: Tue, 5 Aug 2025 19:24:12 +0800 Subject: [PATCH 023/325] Update template to 0.1.51 (#9187) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 470060ab4..9a6b04cf2 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.23.4 -comfyui-workflow-templates==0.1.48 +comfyui-workflow-templates==0.1.51 comfyui-embedded-docs==0.2.4 torch torchsde From 32a95bba8ac91e8a610c35ce4d9963d2453118c1 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 5 Aug 2025 07:33:02 -0400 Subject: [PATCH 024/325] ComfyUI version 0.3.49 --- comfyui_version.py | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/comfyui_version.py b/comfyui_version.py index 7b29e338d..5e2d09c81 100644 --- a/comfyui_version.py +++ b/comfyui_version.py @@ -1,3 +1,3 @@ # This file is automatically generated by the build process when version is # updated in pyproject.toml. -__version__ = "0.3.48" +__version__ = "0.3.49" diff --git a/pyproject.toml b/pyproject.toml index 256677fad..3c530ba85 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ComfyUI" -version = "0.3.48" +version = "0.3.49" readme = "README.md" license = { file = "LICENSE" } requires-python = ">=3.9" From d8c51ba15aef6b0df86a7ea0203881be55d7579b Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 5 Aug 2025 04:41:18 -0700 Subject: [PATCH 025/325] Add Qwen Image model to readme. (#9191) --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index 2abd8e600..119098f5c 100644 --- a/README.md +++ b/README.md @@ -66,6 +66,7 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith - [Lumina Image 2.0](https://comfyanonymous.github.io/ComfyUI_examples/lumina2/) - [HiDream](https://comfyanonymous.github.io/ComfyUI_examples/hidream/) - [Cosmos Predict2](https://comfyanonymous.github.io/ComfyUI_examples/cosmos_predict2/) + - [Qwen Image](https://comfyanonymous.github.io/ComfyUI_examples/qwen_image/) - Image Editing Models - [Omnigen 2](https://comfyanonymous.github.io/ComfyUI_examples/omnigen/) - [Flux Kontext](https://comfyanonymous.github.io/ComfyUI_examples/flux/#flux-kontext-image-editing-model) From 9126c0cfe49508a64c429f97b45664b241aab3f2 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 6 Aug 2025 01:07:04 -0700 Subject: [PATCH 026/325] Qwen Image model merging node. 
(#9202) --- .../nodes_model_merging_model_specific.py | 24 +++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/comfy_extras/nodes_model_merging_model_specific.py b/comfy_extras/nodes_model_merging_model_specific.py index 2c93cd84f..55eb3ccfe 100644 --- a/comfy_extras/nodes_model_merging_model_specific.py +++ b/comfy_extras/nodes_model_merging_model_specific.py @@ -314,6 +314,29 @@ class ModelMergeCosmosPredict2_14B(comfy_extras.nodes_model_merging.ModelMergeBl return {"required": arg_dict} +class ModelMergeQwenImage(comfy_extras.nodes_model_merging.ModelMergeBlocks): + CATEGORY = "advanced/model_merging/model_specific" + + @classmethod + def INPUT_TYPES(s): + arg_dict = { "model1": ("MODEL",), + "model2": ("MODEL",)} + + argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}) + + arg_dict["pos_embeds."] = argument + arg_dict["img_in."] = argument + arg_dict["txt_norm."] = argument + arg_dict["txt_in."] = argument + arg_dict["time_text_embed."] = argument + + for i in range(60): + arg_dict["transformer_blocks.{}.".format(i)] = argument + + arg_dict["proj_out."] = argument + + return {"required": arg_dict} + NODE_CLASS_MAPPINGS = { "ModelMergeSD1": ModelMergeSD1, "ModelMergeSD2": ModelMergeSD1, #SD1 and SD2 have the same blocks @@ -329,4 +352,5 @@ NODE_CLASS_MAPPINGS = { "ModelMergeWAN2_1": ModelMergeWAN2_1, "ModelMergeCosmosPredict2_2B": ModelMergeCosmosPredict2_2B, "ModelMergeCosmosPredict2_14B": ModelMergeCosmosPredict2_14B, + "ModelMergeQwenImage": ModelMergeQwenImage, } From 4c3e57b0ae9fb7ff1322977915efe7e98544d15d Mon Sep 17 00:00:00 2001 From: flybirdxx <1119577418@qq.com> Date: Thu, 7 Aug 2025 01:23:11 +0800 Subject: [PATCH 027/325] Fixed an issue where qwenLora could not be loaded properly. (#9208) --- comfy/lora.py | 9 +++++++++ comfy/weight_adapter/lora.py | 5 +++++ 2 files changed, 14 insertions(+) diff --git a/comfy/lora.py b/comfy/lora.py index 387d5c52a..6686b7229 100644 --- a/comfy/lora.py +++ b/comfy/lora.py @@ -293,6 +293,15 @@ def model_lora_keys_unet(model, key_map={}): key_lora = k[len("diffusion_model."):-len(".weight")] key_map["{}".format(key_lora)] = k + if isinstance(model, comfy.model_base.QwenImage): + for k in sdk: + if k.startswith("diffusion_model.") and k.endswith(".weight"): #QwenImage lora format + key_lora = k[len("diffusion_model."):-len(".weight")] + # Direct mapping for transformer_blocks format (QwenImage LoRA format) + key_map["{}".format(key_lora)] = k + # Support transformer prefix format + key_map["transformer.{}".format(key_lora)] = k + return key_map diff --git a/comfy/weight_adapter/lora.py b/comfy/weight_adapter/lora.py index 729dbd9e6..47aa17d13 100644 --- a/comfy/weight_adapter/lora.py +++ b/comfy/weight_adapter/lora.py @@ -96,6 +96,7 @@ class LoRAAdapter(WeightAdapterBase): diffusers3_lora = "{}.lora.up.weight".format(x) mochi_lora = "{}.lora_B".format(x) transformers_lora = "{}.lora_linear_layer.up.weight".format(x) + qwen_default_lora = "{}.lora_B.default.weight".format(x) A_name = None if regular_lora in lora.keys(): @@ -122,6 +123,10 @@ class LoRAAdapter(WeightAdapterBase): A_name = transformers_lora B_name = "{}.lora_linear_layer.down.weight".format(x) mid_name = None + elif qwen_default_lora in lora.keys(): + A_name = qwen_default_lora + B_name = "{}.lora_A.default.weight".format(x) + mid_name = None if A_name is not None: mid = None From 32691b16f4e1a897e461e77f9d6dceba2d6f0cd1 Mon Sep 17 00:00:00 2001 From: ComfyUI Wiki Date: Thu, 7 Aug 2025 01:26:29 +0800 Subject: [PATCH 028/325] Update template to 
0.1.52 (#9206) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 9a6b04cf2..d6926d610 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.23.4 -comfyui-workflow-templates==0.1.51 +comfyui-workflow-templates==0.1.52 comfyui-embedded-docs==0.2.4 torch torchsde From 37d620a6b85f61b824363ed8170db373726ca45a Mon Sep 17 00:00:00 2001 From: Christian Byrne Date: Wed, 6 Aug 2025 16:52:39 -0700 Subject: [PATCH 029/325] Update frontend to v1.24.4 (#9175) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index d6926d610..2f4692b03 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -comfyui-frontend-package==1.23.4 +comfyui-frontend-package==1.24.4 comfyui-workflow-templates==0.1.52 comfyui-embedded-docs==0.2.4 torch From 05df2df489f6b237f63c5f7d42a943ae2be417e9 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Thu, 7 Aug 2025 08:20:40 -0700 Subject: [PATCH 030/325] Fix RepeatLatentBatch not working on multi dim latents. (#9227) --- nodes.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/nodes.py b/nodes.py index 9bedbcaca..9448f9c1b 100644 --- a/nodes.py +++ b/nodes.py @@ -1229,12 +1229,12 @@ class RepeatLatentBatch: s = samples.copy() s_in = samples["samples"] - s["samples"] = s_in.repeat((amount, 1,1,1)) + s["samples"] = s_in.repeat((amount,) + ((1,) * (s_in.ndim - 1))) if "noise_mask" in samples and samples["noise_mask"].shape[0] > 1: masks = samples["noise_mask"] if masks.shape[0] < s_in.shape[0]: - masks = masks.repeat(math.ceil(s_in.shape[0] / masks.shape[0]), 1, 1, 1)[:s_in.shape[0]] - s["noise_mask"] = samples["noise_mask"].repeat((amount, 1,1,1)) + masks = masks.repeat((math.ceil(s_in.shape[0] / masks.shape[0]),) + ((1,) * (masks.ndim - 1)))[:s_in.shape[0]] + s["noise_mask"] = samples["noise_mask"].repeat((amount,) + ((1,) * (samples["noise_mask"].ndim - 1))) if "batch_index" in s: offset = max(s["batch_index"]) - min(s["batch_index"]) + 1 s["batch_index"] = s["batch_index"] + [x + (i * offset) for i in range(1, amount) for x in s["batch_index"]] From 42974a448c39af50c5f72d8c70267f9fe2971cd2 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Thu, 7 Aug 2025 14:54:09 -0700 Subject: [PATCH 031/325] _ui.py import torchaudio safety check (#9234) * Added safety around torchaudio import in _ui.py * Trusted cursor too much, fixed torchaudio bool --- comfy_api/latest/_ui.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/comfy_api/latest/_ui.py b/comfy_api/latest/_ui.py index 6b8a39d58..61597038f 100644 --- a/comfy_api/latest/_ui.py +++ b/comfy_api/latest/_ui.py @@ -9,7 +9,11 @@ from typing import Type import av import numpy as np import torch -import torchaudio +try: + import torchaudio + TORCH_AUDIO_AVAILABLE = True +except ImportError: + TORCH_AUDIO_AVAILABLE = False from PIL import Image as PILImage from PIL.PngImagePlugin import PngInfo @@ -302,6 +306,8 @@ class AudioSaveHelper: # Resample if necessary if sample_rate != audio["sample_rate"]: + if not TORCH_AUDIO_AVAILABLE: + raise Exception("torchaudio is not available; cannot resample audio.") waveform = torchaudio.functional.resample(waveform, audio["sample_rate"], sample_rate) # Create output with specified format From bf2a1b5b1ef72b240454f3ac44f5209af45efe00 Mon Sep 17 00:00:00 2001 From: Alexander Piskun
<13381981+bigcat88@users.noreply.github.com> Date: Fri, 8 Aug 2025 06:37:50 +0300 Subject: [PATCH 032/325] async API nodes (#9129) * converted API nodes to async * converted BFL API nodes to async * fixed client bug; converted gemini, ideogram, minimax * fixed client bug; converted openai nodes * fixed client bug; converted moonvalley, pika nodes * fixed client bug; converted kling, luma nodes * converted pixverse, rodin nodes * converted tripo, veo2 * converted recraft nodes * add lost log_request_response call --- comfy_api_nodes/apinode_utils.py | 152 ++--- comfy_api_nodes/apis/client.py | 901 +++++++++++----------------- comfy_api_nodes/nodes_bfl.py | 134 +++-- comfy_api_nodes/nodes_gemini.py | 4 +- comfy_api_nodes/nodes_ideogram.py | 22 +- comfy_api_nodes/nodes_kling.py | 130 ++-- comfy_api_nodes/nodes_luma.py | 70 ++- comfy_api_nodes/nodes_minimax.py | 14 +- comfy_api_nodes/nodes_moonvalley.py | 38 +- comfy_api_nodes/nodes_openai.py | 28 +- comfy_api_nodes/nodes_pika.py | 63 +- comfy_api_nodes/nodes_pixverse.py | 46 +- comfy_api_nodes/nodes_recraft.py | 44 +- comfy_api_nodes/nodes_rodin.py | 147 ++--- comfy_api_nodes/nodes_runway.py | 54 +- comfy_api_nodes/nodes_stability.py | 23 +- comfy_api_nodes/nodes_tripo.py | 69 ++- comfy_api_nodes/nodes_veo2.py | 15 +- 18 files changed, 878 insertions(+), 1076 deletions(-) diff --git a/comfy_api_nodes/apinode_utils.py b/comfy_api_nodes/apinode_utils.py index 788e2803f..f953f86df 100644 --- a/comfy_api_nodes/apinode_utils.py +++ b/comfy_api_nodes/apinode_utils.py @@ -1,4 +1,5 @@ from __future__ import annotations +import aiohttp import io import logging import mimetypes @@ -21,7 +22,6 @@ from server import PromptServer import numpy as np from PIL import Image -import requests import torch import math import base64 @@ -30,7 +30,7 @@ from io import BytesIO import av -def download_url_to_video_output(video_url: str, timeout: int = None) -> VideoFromFile: +async def download_url_to_video_output(video_url: str, timeout: int = None) -> VideoFromFile: """Downloads a video from a URL and returns a `VIDEO` output. Args: @@ -39,7 +39,7 @@ def download_url_to_video_output(video_url: str, timeout: int = None) -> VideoFr Returns: A Comfy node `VIDEO` output. """ - video_io = download_url_to_bytesio(video_url, timeout) + video_io = await download_url_to_bytesio(video_url, timeout) if video_io is None: error_msg = f"Failed to download video from {video_url}" logging.error(error_msg) @@ -62,7 +62,7 @@ def downscale_image_tensor(image, total_pixels=1536 * 1024) -> torch.Tensor: return s -def validate_and_cast_response( +async def validate_and_cast_response( response, timeout: int = None, node_id: Union[str, None] = None ) -> torch.Tensor: """Validates and casts a response to a torch.Tensor. 
@@ -86,35 +86,24 @@ def validate_and_cast_response( image_tensors: list[torch.Tensor] = [] # Process each image in the data array - for image_data in data: - image_url = image_data.url - b64_data = image_data.b64_json + async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=timeout)) as session: + for img_data in data: + img_bytes: bytes + if img_data.b64_json: + img_bytes = base64.b64decode(img_data.b64_json) + elif img_data.url: + if node_id: + PromptServer.instance.send_progress_text(f"Result URL: {img_data.url}", node_id) + async with session.get(img_data.url) as resp: + if resp.status != 200: + raise ValueError("Failed to download generated image") + img_bytes = await resp.read() + else: + raise ValueError("Invalid image payload – neither URL nor base64 data present.") - if not image_url and not b64_data: - raise ValueError("No image was generated in the response") - - if b64_data: - img_data = base64.b64decode(b64_data) - img = Image.open(io.BytesIO(img_data)) - - elif image_url: - if node_id: - PromptServer.instance.send_progress_text( - f"Result URL: {image_url}", node_id - ) - img_response = requests.get(image_url, timeout=timeout) - if img_response.status_code != 200: - raise ValueError("Failed to download the image") - img = Image.open(io.BytesIO(img_response.content)) - - img = img.convert("RGBA") - - # Convert to numpy array, normalize to float32 between 0 and 1 - img_array = np.array(img).astype(np.float32) / 255.0 - img_tensor = torch.from_numpy(img_array) - - # Add to list of tensors - image_tensors.append(img_tensor) + pil_img = Image.open(BytesIO(img_bytes)).convert("RGBA") + arr = np.asarray(pil_img).astype(np.float32) / 255.0 + image_tensors.append(torch.from_numpy(arr)) return torch.stack(image_tensors, dim=0) @@ -175,7 +164,7 @@ def mimetype_to_extension(mime_type: str) -> str: return mime_type.split("/")[-1].lower() -def download_url_to_bytesio(url: str, timeout: int = None) -> BytesIO: +async def download_url_to_bytesio(url: str, timeout: int = None) -> BytesIO: """Downloads content from a URL using requests and returns it as BytesIO. Args: @@ -185,9 +174,11 @@ def download_url_to_bytesio(url: str, timeout: int = None) -> BytesIO: Returns: BytesIO object containing the downloaded content. 
""" - response = requests.get(url, stream=True, timeout=timeout) - response.raise_for_status() # Raises HTTPError for bad responses (4XX or 5XX) - return BytesIO(response.content) + timeout_cfg = aiohttp.ClientTimeout(total=timeout) if timeout else None + async with aiohttp.ClientSession(timeout=timeout_cfg) as session: + async with session.get(url) as resp: + resp.raise_for_status() # Raises HTTPError for bad responses (4XX or 5XX) + return BytesIO(await resp.read()) def bytesio_to_image_tensor(image_bytesio: BytesIO, mode: str = "RGBA") -> torch.Tensor: @@ -210,15 +201,15 @@ def bytesio_to_image_tensor(image_bytesio: BytesIO, mode: str = "RGBA") -> torch return torch.from_numpy(image_array).unsqueeze(0) -def download_url_to_image_tensor(url: str, timeout: int = None) -> torch.Tensor: +async def download_url_to_image_tensor(url: str, timeout: int = None) -> torch.Tensor: """Downloads an image from a URL and returns a [B, H, W, C] tensor.""" - image_bytesio = download_url_to_bytesio(url, timeout) + image_bytesio = await download_url_to_bytesio(url, timeout) return bytesio_to_image_tensor(image_bytesio) -def process_image_response(response: requests.Response) -> torch.Tensor: +def process_image_response(response_content: bytes | str) -> torch.Tensor: """Uses content from a Response object and converts it to a torch.Tensor""" - return bytesio_to_image_tensor(BytesIO(response.content)) + return bytesio_to_image_tensor(BytesIO(response_content)) def _tensor_to_pil(image: torch.Tensor, total_pixels: int = 2048 * 2048) -> Image.Image: @@ -336,10 +327,10 @@ def text_filepath_to_data_uri(filepath: str) -> str: return f"data:{mime_type};base64,{base64_string}" -def upload_file_to_comfyapi( +async def upload_file_to_comfyapi( file_bytes_io: BytesIO, filename: str, - upload_mime_type: str, + upload_mime_type: Optional[str], auth_kwargs: Optional[dict[str, str]] = None, ) -> str: """ @@ -354,7 +345,10 @@ def upload_file_to_comfyapi( Returns: The download URL for the uploaded file. 
""" - request_object = UploadRequest(file_name=filename, content_type=upload_mime_type) + if upload_mime_type is None: + request_object = UploadRequest(file_name=filename) + else: + request_object = UploadRequest(file_name=filename, content_type=upload_mime_type) operation = SynchronousOperation( endpoint=ApiEndpoint( path="/customers/storage", @@ -366,12 +360,8 @@ def upload_file_to_comfyapi( auth_kwargs=auth_kwargs, ) - response: UploadResponse = operation.execute() - upload_response = ApiClient.upload_file( - response.upload_url, file_bytes_io, content_type=upload_mime_type - ) - upload_response.raise_for_status() - + response: UploadResponse = await operation.execute() + await ApiClient.upload_file(response.upload_url, file_bytes_io, content_type=upload_mime_type) return response.download_url @@ -399,7 +389,7 @@ def video_to_base64_string( return base64.b64encode(video_bytes_io.getvalue()).decode("utf-8") -def upload_video_to_comfyapi( +async def upload_video_to_comfyapi( video: VideoInput, auth_kwargs: Optional[dict[str, str]] = None, container: VideoContainer = VideoContainer.MP4, @@ -439,9 +429,7 @@ def upload_video_to_comfyapi( video.save_to(video_bytes_io, format=container, codec=codec) video_bytes_io.seek(0) - return upload_file_to_comfyapi( - video_bytes_io, filename, upload_mime_type, auth_kwargs - ) + return await upload_file_to_comfyapi(video_bytes_io, filename, upload_mime_type, auth_kwargs) def audio_tensor_to_contiguous_ndarray(waveform: torch.Tensor) -> np.ndarray: @@ -501,7 +489,7 @@ def audio_ndarray_to_bytesio( return audio_bytes_io -def upload_audio_to_comfyapi( +async def upload_audio_to_comfyapi( audio: AudioInput, auth_kwargs: Optional[dict[str, str]] = None, container_format: str = "mp4", @@ -527,7 +515,7 @@ def upload_audio_to_comfyapi( audio_data_np, sample_rate, container_format, codec_name ) - return upload_file_to_comfyapi(audio_bytes_io, filename, mime_type, auth_kwargs) + return await upload_file_to_comfyapi(audio_bytes_io, filename, mime_type, auth_kwargs) def audio_to_base64_string( @@ -544,7 +532,7 @@ def audio_to_base64_string( return base64.b64encode(audio_bytes).decode("utf-8") -def upload_images_to_comfyapi( +async def upload_images_to_comfyapi( image: torch.Tensor, max_images=8, auth_kwargs: Optional[dict[str, str]] = None, @@ -561,55 +549,15 @@ def upload_images_to_comfyapi( mime_type: Optional MIME type for the image. 
""" # if batch, try to upload each file if max_images is greater than 0 - idx_image = 0 download_urls: list[str] = [] is_batch = len(image.shape) > 3 - batch_length = 1 - if is_batch: - batch_length = image.shape[0] - while True: - curr_image = image - if len(image.shape) > 3: - curr_image = image[idx_image] - # get BytesIO version of image - img_binary = tensor_to_bytesio(curr_image, mime_type=mime_type) - # first, request upload/download urls from comfy API - if not mime_type: - request_object = UploadRequest(file_name=img_binary.name) - else: - request_object = UploadRequest( - file_name=img_binary.name, content_type=mime_type - ) - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path="/customers/storage", - method=HttpMethod.POST, - request_model=UploadRequest, - response_model=UploadResponse, - ), - request=request_object, - auth_kwargs=auth_kwargs, - ) - response = operation.execute() + batch_len = image.shape[0] if is_batch else 1 - upload_response = ApiClient.upload_file( - response.upload_url, img_binary, content_type=mime_type - ) - # verify success - try: - upload_response.raise_for_status() - except requests.exceptions.HTTPError as e: - raise ValueError(f"Could not upload one or more images: {e}") from e - # add download_url to list - download_urls.append(response.download_url) - - idx_image += 1 - # stop uploading additional files if done - if is_batch and max_images > 0: - if idx_image >= max_images: - break - if idx_image >= batch_length: - break + for idx in range(min(batch_len, max_images)): + tensor = image[idx] if is_batch else image + img_io = tensor_to_bytesio(tensor, mime_type=mime_type) + url = await upload_file_to_comfyapi(img_io, img_io.name, mime_type, auth_kwargs) + download_urls.append(url) return download_urls diff --git a/comfy_api_nodes/apis/client.py b/comfy_api_nodes/apis/client.py index 2a4bac88b..4ad0b783b 100644 --- a/comfy_api_nodes/apis/client.py +++ b/comfy_api_nodes/apis/client.py @@ -43,7 +43,7 @@ operation = ApiOperation( endpoint=user_info_endpoint, request=request ) -user_profile = operation.execute(client=api_client) # Returns immediately with the result +user_profile = await operation.execute(client=api_client) # Returns immediately with the result # Example 2: Asynchronous API Operation with Polling @@ -87,18 +87,19 @@ operation = PollingOperation( ) # This will make the initial request and then poll until completion -result = operation.execute(client=api_client) # Returns the final ImageGenerationResult when done +result = await operation.execute(client=api_client) # Returns the final ImageGenerationResult when done """ from __future__ import annotations +import aiohttp +import asyncio import logging -import time import io import socket +from aiohttp.client_exceptions import ClientError, ClientResponseError from typing import Dict, Type, Optional, Any, TypeVar, Generic, Callable, Tuple from enum import Enum import json -import requests from urllib.parse import urljoin, urlparse from pydantic import BaseModel, Field import uuid # For generating unique operation IDs @@ -174,6 +175,7 @@ class ApiClient: retry_delay: float = 1.0, retry_backoff_factor: float = 2.0, retry_status_codes: Optional[Tuple[int, ...]] = None, + session: Optional[aiohttp.ClientSession] = None, ): self.base_url = base_url self.auth_token = auth_token @@ -186,13 +188,16 @@ class ApiClient: # Default retry status codes: 408 (Request Timeout), 429 (Too Many Requests), # 500, 502, 503, 504 (Server Errors) self.retry_status_codes = retry_status_codes or (408, 429, 
500, 502, 503, 504) + self._session: Optional[aiohttp.ClientSession] = session + self._owns_session = session is None # Track if we have to close it - def _generate_operation_id(self, path: str) -> str: + @staticmethod + def _generate_operation_id(path: str) -> str: """Generates a unique operation ID for logging.""" return f"{path.strip('/').replace('/', '_')}_{uuid.uuid4().hex[:8]}" + @staticmethod def _create_json_payload_args( - self, data: Optional[Dict[str, Any]] = None, headers: Optional[Dict[str, str]] = None, ) -> Dict[str, Any]: @@ -203,31 +208,53 @@ class ApiClient: def _create_form_data_args( self, - data: Dict[str, Any], - files: Dict[str, Any], + data: Dict[str, Any] | None, + files: Dict[str, Any] | None, headers: Optional[Dict[str, str]] = None, - multipart_parser = None, + multipart_parser: Callable | None = None, ) -> Dict[str, Any]: if headers and "Content-Type" in headers: del headers["Content-Type"] - if multipart_parser: + if multipart_parser and data: data = multipart_parser(data) - return { - "data": data, - "files": files, - "headers": headers, - } + form = aiohttp.FormData(default_to_multipart=True) + if data: # regular text fields + for k, v in data.items(): + if v is None: + continue # aiohttp fails to serialize "None" values + # aiohttp expects strings or bytes; convert enums etc. + form.add_field(k, str(v) if not isinstance(v, (bytes, bytearray)) else v) + if files: + file_iter = files if isinstance(files, list) else files.items() + for field_name, file_obj in file_iter: + if file_obj is None: + continue # aiohttp fails to serialize "None" values + # file_obj can be (filename, bytes/io.BytesIO, content_type) tuple + if isinstance(file_obj, tuple): + filename, file_value, content_type = self._unpack_tuple(file_obj) + else: + file_value = file_obj + filename = getattr(file_obj, "name", field_name) + content_type = "application/octet-stream" + + form.add_field( + name=field_name, + value=file_value, + filename=filename, + content_type=content_type, + ) + return {"data": form, "headers": headers or {}} + + @staticmethod def _create_urlencoded_form_data_args( - self, data: Dict[str, Any], headers: Optional[Dict[str, str]] = None, ) -> Dict[str, Any]: headers = headers or {} headers["Content-Type"] = "application/x-www-form-urlencoded" - return { "data": data, "headers": headers, @@ -244,7 +271,7 @@ class ApiClient: return headers - def _check_connectivity(self, target_url: str) -> Dict[str, bool]: + async def _check_connectivity(self, target_url: str) -> Dict[str, bool]: """ Check connectivity to determine if network issues are local or server-related. 
@@ -258,52 +285,39 @@ class ApiClient: "internet_accessible": False, "api_accessible": False, "is_local_issue": False, - "is_api_issue": False + "is_api_issue": False, } + timeout = aiohttp.ClientTimeout(total=5.0) + async with aiohttp.ClientSession(timeout=timeout) as session: + try: + async with session.get("https://www.google.com", ssl=self.verify_ssl) as resp: + results["internet_accessible"] = resp.status < 500 + except (ClientError, asyncio.TimeoutError, socket.gaierror): + results["is_local_issue"] = True + return results # cannot reach the internet – early exit - # First check basic internet connectivity using a reliable external site - try: - # Use a reliable external domain for checking basic connectivity - check_response = requests.get("https://www.google.com", - timeout=5.0, - verify=self.verify_ssl) - if check_response.status_code < 500: - results["internet_accessible"] = True - except (requests.RequestException, socket.error): - results["internet_accessible"] = False - results["is_local_issue"] = True - return results - - # Now check API server connectivity - try: - # Extract domain from the target URL to do a simpler health check - parsed_url = urlparse(target_url) - api_base = f"{parsed_url.scheme}://{parsed_url.netloc}" - - # Try to reach the API domain - api_response = requests.get(f"{api_base}/health", timeout=5.0, verify=self.verify_ssl) - if api_response.status_code < 500: - results["api_accessible"] = True - else: - results["api_accessible"] = False - results["is_api_issue"] = True - except requests.RequestException: - results["api_accessible"] = False - # If we can reach the internet but not the API, it's an API issue - results["is_api_issue"] = True + # Now check API health endpoint + parsed = urlparse(target_url) + health_url = f"{parsed.scheme}://{parsed.netloc}/health" + try: + async with session.get(health_url, ssl=self.verify_ssl) as resp: + results["api_accessible"] = resp.status < 500 + except ClientError: + pass # leave as False + results["is_api_issue"] = results["internet_accessible"] and not results["api_accessible"] return results - def request( + async def request( self, method: str, path: str, params: Optional[Dict[str, Any]] = None, data: Optional[Dict[str, Any]] = None, - files: Optional[Dict[str, Any]] = None, + files: Optional[Dict[str, Any] | list[tuple[str, Any]]] = None, headers: Optional[Dict[str, str]] = None, content_type: str = "application/json", - multipart_parser: Callable = None, + multipart_parser: Callable | None = None, retry_count: int = 0, # Used internally for tracking retries ) -> Dict[str, Any]: """ @@ -327,18 +341,19 @@ class ApiClient: ApiServerError: If the API server is unreachable but internet is working Exception: For other request failures """ - # Use urljoin but ensure path is relative to avoid absolute path behavior - relative_path = path.lstrip('/') + + # Build full URL and merge headers + relative_path = path.lstrip("/") url = urljoin(self.base_url, relative_path) - self.check_auth(self.auth_token, self.comfy_api_key) - # Combine default headers with any provided headers + self._check_auth(self.auth_token, self.comfy_api_key) + request_headers = self.get_headers() if headers: request_headers.update(headers) - - # Let requests handle the content type when files are present. 
if files: - del request_headers["Content-Type"] + request_headers.pop("Content-Type", None) + if params: + params = {k: v for k, v in params.items() if v is not None} # aiohttp fails to serialize None values logging.debug(f"[DEBUG] Request Headers: {request_headers}") logging.debug(f"[DEBUG] Files: {files}") @@ -346,11 +361,9 @@ class ApiClient: logging.debug(f"[DEBUG] Data: {data}") if content_type == "application/x-www-form-urlencoded": - payload_args = self._create_urlencoded_form_data_args(data, request_headers) + payload_args = self._create_urlencoded_form_data_args(data or {}, request_headers) elif content_type == "multipart/form-data": - payload_args = self._create_form_data_args( - data, files, request_headers, multipart_parser - ) + payload_args = self._create_form_data_args(data, files, request_headers, multipart_parser) else: payload_args = self._create_json_payload_args(data, request_headers) @@ -361,220 +374,67 @@ class ApiClient: request_url=url, request_headers=request_headers, request_params=params, - request_data=data if content_type == "application/json" else "[form-data or other]" + request_data=data if content_type == "application/json" else "[form-data or other]", ) + session = await self._get_session() try: - response = requests.request( - method=method, - url=url, + async with session.request( + method, + url, params=params, - timeout=self.timeout, - verify=self.verify_ssl, + ssl=self.verify_ssl, **payload_args, - ) + ) as resp: + if resp.status >= 400: + try: + error_data = await resp.json() + except (aiohttp.ContentTypeError, json.JSONDecodeError): + error_data = await resp.text() - # Check if we should retry based on status code - if (response.status_code in self.retry_status_codes and - retry_count < self.max_retries): + return await self._handle_http_error( + ClientResponseError(resp.request_info, resp.history, status=resp.status, message=error_data), + operation_id, + method, + url, + params, + data, + files, + headers, + content_type, + multipart_parser, + retry_count=retry_count, + response_content=error_data, + ) - # Calculate delay with exponential backoff - delay = self.retry_delay * (self.retry_backoff_factor ** retry_count) - - logging.warning( - f"Request failed with status {response.status_code}. 
" - f"Retrying in {delay:.2f}s ({retry_count + 1}/{self.max_retries})" - ) - - time.sleep(delay) - return self.request( - method=method, - path=path, - params=params, - data=data, - files=files, - headers=headers, - content_type=content_type, - multipart_parser=multipart_parser, - retry_count=retry_count + 1, - ) - - # Raise exception for error status codes - response.raise_for_status() - - # Log successful response - response_content_to_log = response.content - try: - # Attempt to parse JSON for prettier logging, fallback to raw content - response_content_to_log = response.json() - except json.JSONDecodeError: - pass # Keep as bytes/str if not JSON - - request_logger.log_request_response( - operation_id=operation_id, - request_method=method, # Pass request details again for context in log - request_url=url, - response_status_code=response.status_code, - response_headers=dict(response.headers), - response_content=response_content_to_log - ) - - except requests.ConnectionError as e: - error_message = f"ConnectionError: {str(e)}" - request_logger.log_request_response( - operation_id=operation_id, - request_method=method, - request_url=url, - error_message=error_message - ) - # Only perform connectivity check if we've exhausted all retries - if retry_count >= self.max_retries: - # Check connectivity to determine if it's a local or API issue - connectivity = self._check_connectivity(self.base_url) - - if connectivity["is_local_issue"]: - raise LocalNetworkError( - "Unable to connect to the API server due to local network issues. " - "Please check your internet connection and try again." - ) from e - elif connectivity["is_api_issue"]: - raise ApiServerError( - f"The API server at {self.base_url} is currently unreachable. " - f"The service may be experiencing issues. Please try again later." - ) from e - - # If we haven't exhausted retries yet, retry the request - if retry_count < self.max_retries: - delay = self.retry_delay * (self.retry_backoff_factor ** retry_count) - logging.warning( - f"Connection error: {str(e)}. " - f"Retrying in {delay:.2f}s ({retry_count + 1}/{self.max_retries})" - ) - time.sleep(delay) - return self.request( - method=method, - path=path, - params=params, - data=data, - files=files, - headers=headers, - content_type=content_type, - multipart_parser=multipart_parser, - retry_count=retry_count + 1, - ) - - # If we've exhausted retries and didn't identify the specific issue, - # raise a generic exception - final_error_message = ( - f"Unable to connect to the API server after {self.max_retries} attempts. " - f"Please check your internet connection or try again later." - ) - request_logger.log_request_response( # Log final failure - operation_id=operation_id, - request_method=method, request_url=url, - error_message=final_error_message - ) - raise Exception(final_error_message) from e - - except requests.Timeout as e: - error_message = f"Timeout: {str(e)}" - request_logger.log_request_response( - operation_id=operation_id, - request_method=method, request_url=url, - error_message=error_message - ) - # Retry timeouts if we haven't exhausted retries - if retry_count < self.max_retries: - delay = self.retry_delay * (self.retry_backoff_factor ** retry_count) - logging.warning( - f"Request timed out. 
" - f"Retrying in {delay:.2f}s ({retry_count + 1}/{self.max_retries})" - ) - time.sleep(delay) - return self.request( - method=method, - path=path, - params=params, - data=data, - files=files, - headers=headers, - content_type=content_type, - multipart_parser=multipart_parser, - retry_count=retry_count + 1, - ) - final_error_message = ( - f"Request timed out after {self.timeout} seconds and {self.max_retries} retry attempts. " - f"The server might be experiencing high load or the operation is taking longer than expected." - ) - request_logger.log_request_response( # Log final failure - operation_id=operation_id, - request_method=method, request_url=url, - error_message=final_error_message - ) - raise Exception(final_error_message) from e - - except requests.HTTPError as e: - status_code = e.response.status_code if hasattr(e, "response") else None - original_error_message = f"HTTP Error: {str(e)}" - error_content_for_log = None - if hasattr(e, "response") and e.response is not None: - error_content_for_log = e.response.content + # Success – parse JSON (safely) and log try: - error_content_for_log = e.response.json() - except json.JSONDecodeError: - pass + payload = await resp.json() + response_content_to_log = payload + except (aiohttp.ContentTypeError, json.JSONDecodeError): + payload = {} + response_content_to_log = await resp.text() - - # Try to extract detailed error message from JSON response for user display - # but log the full error content. - user_display_error_message = original_error_message - - try: - if hasattr(e, "response") and e.response is not None and e.response.content: - error_json = e.response.json() - if "error" in error_json and "message" in error_json["error"]: - user_display_error_message = f"API Error: {error_json['error']['message']}" - if "type" in error_json["error"]: - user_display_error_message += f" (Type: {error_json['error']['type']})" - elif isinstance(error_json, dict): # Handle cases where error is just a JSON dict - user_display_error_message = f"API Error: {json.dumps(error_json)}" - else: # Non-dict JSON error - user_display_error_message = f"API Error: {str(error_json)}" - except json.JSONDecodeError: - # If not JSON, use the raw content if it's not too long, or a summary - if hasattr(e, "response") and e.response is not None and e.response.content: - raw_content = e.response.content.decode(errors='ignore') - if len(raw_content) < 200: # Arbitrary limit for display - user_display_error_message = f"API Error (raw): {raw_content}" - else: - user_display_error_message = f"API Error (raw, status {status_code})" - - request_logger.log_request_response( - operation_id=operation_id, - request_method=method, request_url=url, - response_status_code=status_code, - response_headers=dict(e.response.headers) if hasattr(e, "response") and e.response is not None else None, - response_content=error_content_for_log, - error_message=original_error_message # Log the original exception string as error - ) - - logging.debug(f"[DEBUG] API Error: {user_display_error_message} (Status: {status_code})") - if hasattr(e, "response") and e.response is not None and e.response.content: - logging.debug(f"[DEBUG] Response content: {e.response.content}") - - # Retry if the status code is in our retry list and we haven't exhausted retries - if (status_code in self.retry_status_codes and - retry_count < self.max_retries): - - delay = self.retry_delay * (self.retry_backoff_factor ** retry_count) - logging.warning( - f"HTTP error {status_code}. 
" - f"Retrying in {delay:.2f}s ({retry_count + 1}/{self.max_retries})" + request_logger.log_request_response( + operation_id=operation_id, + request_method=method, + request_url=url, + response_status_code=resp.status, + response_headers=dict(resp.headers), + response_content=response_content_to_log, ) - time.sleep(delay) - return self.request( - method=method, - path=path, + return payload + + except (ClientError, asyncio.TimeoutError, socket.gaierror) as e: + # Treat as *connection* problem – optionally retry, else escalate + if retry_count < self.max_retries: + delay = self.retry_delay * (self.retry_backoff_factor ** retry_count) + logging.warning("Connection error. Retrying in %.2fs (%s/%s): %s", delay, retry_count + 1, + self.max_retries, str(e)) + await asyncio.sleep(delay) + return await self.request( + method, + path, params=params, data=data, files=files, @@ -583,40 +443,34 @@ class ApiClient: multipart_parser=multipart_parser, retry_count=retry_count + 1, ) + # One final connectivity check for diagnostics + connectivity = await self._check_connectivity(self.base_url) + if connectivity["is_local_issue"]: + raise LocalNetworkError( + "Unable to connect to the API server due to local network issues. " + "Please check your internet connection and try again." + ) from e + raise ApiServerError( + f"The API server at {self.base_url} is currently unreachable. " + f"The service may be experiencing issues. Please try again later." + ) from e - # Specific error messages for common status codes for user display - if status_code == 401: - user_display_error_message = "Unauthorized: Please login first to use this node." - elif status_code == 402: - user_display_error_message = "Payment Required: Please add credits to your account to use this node." - elif status_code == 409: - user_display_error_message = "There is a problem with your account. Please contact support@comfy.org." - elif status_code == 429: - user_display_error_message = "Rate Limit Exceeded: Please try again later." - # else, user_display_error_message remains as parsed from response or original HTTPError string - - raise Exception(user_display_error_message) # Raise with the user-friendly message - - # Parse and return JSON response - if response.content: - return response.json() - return {} - - def check_auth(self, auth_token, comfy_api_key): + @staticmethod + def _check_auth(auth_token, comfy_api_key): """Verify that an auth token is present or comfy_api_key is present""" if auth_token is None and comfy_api_key is None: raise Exception("Unauthorized: Please login first to use this node.") return auth_token or comfy_api_key @staticmethod - def upload_file( + async def upload_file( upload_url: str, file: io.BytesIO | str, content_type: str | None = None, max_retries: int = 3, retry_delay: float = 1.0, retry_backoff_factor: float = 2.0, - ): + ) -> aiohttp.ClientResponse: """Upload a file to the API with retry logic. Args: @@ -627,112 +481,167 @@ class ApiClient: retry_delay: Initial delay between retries in seconds retry_backoff_factor: Multiplier for the delay after each retry """ - headers = {} + headers: Dict[str, str] = {} + skip_auto_headers: set[str] = set() if content_type: headers["Content-Type"] = content_type + else: + # tell aiohttp not to add Content-Type that will break the request signature and result in a 403 status. 
+ skip_auto_headers.add("Content-Type") - # Prepare the file data + # Extract file bytes if isinstance(file, io.BytesIO): - file.seek(0) # Ensure we're at the start of the file + file.seek(0) data = file.read() elif isinstance(file, str): with open(file, "rb") as f: data = f.read() else: - raise ValueError("File must be either a BytesIO object or a file path string") + raise ValueError("File must be BytesIO or str path") - # Try the upload with retries - last_exception = None - operation_id = f"upload_{upload_url.split('/')[-1]}_{uuid.uuid4().hex[:8]}" # Simplified ID for uploads - - # Log initial attempt (without full file data for brevity) + operation_id = f"upload_{upload_url.split('/')[-1]}_{uuid.uuid4().hex[:8]}" request_logger.log_request_response( operation_id=operation_id, request_method="PUT", request_url=upload_url, request_headers=headers, - request_data=f"[File data of type {content_type or 'unknown'}, size {len(data)} bytes]" + request_data=f"[File data {len(data)} bytes]", ) - for retry_attempt in range(max_retries + 1): + delay = retry_delay + for attempt in range(max_retries + 1): try: - response = requests.put(upload_url, data=data, headers=headers) - response.raise_for_status() + timeout = aiohttp.ClientTimeout(total=None) # honour server side timeouts + async with aiohttp.ClientSession(timeout=timeout) as session: + async with session.put( + upload_url, data=data, headers=headers, skip_auto_headers=skip_auto_headers, + ) as resp: + resp.raise_for_status() + request_logger.log_request_response( + operation_id=operation_id, + request_method="PUT", + request_url=upload_url, + response_status_code=resp.status, + response_headers=dict(resp.headers), + response_content="File uploaded successfully.", + ) + return resp + except (ClientError, asyncio.TimeoutError) as e: request_logger.log_request_response( operation_id=operation_id, - request_method="PUT", request_url=upload_url, # For context - response_status_code=response.status_code, - response_headers=dict(response.headers), - response_content="File uploaded successfully." # Or response.text if available + request_method="PUT", + request_url=upload_url, + response_status_code=e.status if hasattr(e, "status") else None, + response_headers=dict(e.headers) if getattr(e, "headers") else None, + response_content=None, + error_message=f"{type(e).__name__}: {str(e)}", ) - return response - - except (requests.ConnectionError, requests.Timeout, requests.HTTPError) as e: - last_exception = e - error_message_for_log = f"{type(e).__name__}: {str(e)}" - response_content_for_log = None - status_code_for_log = None - headers_for_log = None - - if hasattr(e, 'response') and e.response is not None: - status_code_for_log = e.response.status_code - headers_for_log = dict(e.response.headers) - try: - response_content_for_log = e.response.json() - except json.JSONDecodeError: - response_content_for_log = e.response.content - - - request_logger.log_request_response( - operation_id=operation_id, - request_method="PUT", request_url=upload_url, - response_status_code=status_code_for_log, - response_headers=headers_for_log, - response_content=response_content_for_log, - error_message=error_message_for_log - ) - - if retry_attempt < max_retries: - delay = retry_delay * (retry_backoff_factor ** retry_attempt) + if attempt < max_retries: logging.warning( - f"File upload failed: {str(e)}. " - f"Retrying in {delay:.2f}s ({retry_attempt + 1}/{max_retries})" + "Upload failed (%s/%s). Retrying in %.2fs. 
%s", attempt + 1, max_retries, delay, str(e) ) - time.sleep(delay) + await asyncio.sleep(delay) + delay *= retry_backoff_factor else: - break # Max retries reached + raise NetworkError(f"Failed to upload file after {max_retries + 1} attempts: {e}") from e - # If we've exhausted all retries, determine the final error type and raise - final_error_message = f"Failed to upload file after {max_retries + 1} attempts. Error: {str(last_exception)}" - try: - # Check basic internet connectivity - check_response = requests.get("https://www.google.com", timeout=5.0, verify=True) # Assuming verify=True is desired - if check_response.status_code >= 500: # Google itself has an issue (rare) - final_error_message = (f"Failed to upload file. Internet connectivity check to Google failed " - f"(status {check_response.status_code}). Original error: {str(last_exception)}") - # Not raising LocalNetworkError here as Google itself might be down. - # If Google is reachable, the issue is likely with the upload server or a more specific local problem - # not caught by a simple Google ping (e.g., DNS for the specific upload URL, firewall). - # The original last_exception is probably most relevant. + async def _handle_http_error( + self, + exc: ClientResponseError, + operation_id: str, + *req_meta, + retry_count: int, + response_content: dict | str = "", + ) -> Dict[str, Any]: + status_code = exc.status + if status_code == 401: + user_friendly = "Unauthorized: Please login first to use this node." + elif status_code == 402: + user_friendly = "Payment Required: Please add credits to your account to use this node." + elif status_code == 409: + user_friendly = "There is a problem with your account. Please contact support@comfy.org." + elif status_code == 429: + user_friendly = "Rate Limit Exceeded: Please try again later." + else: + if isinstance(response_content, dict): + if "error" in response_content and "message" in response_content["error"]: + user_friendly = f"API Error: {response_content['error']['message']}" + if "type" in response_content["error"]: + user_friendly += f" (Type: {response_content['error']['type']})" + else: # Handle cases where error is just a JSON dict with unknown format + user_friendly = f"API Error: {json.dumps(response_content)}" + else: + if len(response_content) < 200: # Arbitrary limit for display + user_friendly = f"API Error (raw): {response_content}" + else: + user_friendly = f"API Error (raw, status {response_content})" - except (requests.RequestException, socket.error) as conn_check_exc: - # Could not reach Google, likely a local network issue - final_error_message = (f"Failed to upload file due to network connectivity issues " - f"(cannot reach Google: {str(conn_check_exc)}). 
" - f"Original upload error: {str(last_exception)}") - request_logger.log_request_response( # Log final failure reason - operation_id=operation_id, - request_method="PUT", request_url=upload_url, - error_message=final_error_message - ) - raise LocalNetworkError(final_error_message) from last_exception - - request_logger.log_request_response( # Log final failure reason if not LocalNetworkError + request_logger.log_request_response( operation_id=operation_id, - request_method="PUT", request_url=upload_url, - error_message=final_error_message + request_method=req_meta[0], + request_url=req_meta[1], + response_status_code=exc.status, + response_headers=dict(req_meta[5]) if req_meta[5] else None, + response_content=response_content, + error_message=f"HTTP Error {exc.status}", ) - raise Exception(final_error_message) from last_exception + + logging.debug(f"[DEBUG] API Error: {user_friendly} (Status: {status_code})") + if response_content: + logging.debug(f"[DEBUG] Response content: {response_content}") + + # Retry if eligible + if status_code in self.retry_status_codes and retry_count < self.max_retries: + delay = self.retry_delay * (self.retry_backoff_factor ** retry_count) + logging.warning( + "HTTP error %s. Retrying in %.2fs (%s/%s)", + status_code, + delay, + retry_count + 1, + self.max_retries, + ) + await asyncio.sleep(delay) + return await self.request( + req_meta[0], # method + req_meta[1].replace(self.base_url, ""), # path + params=req_meta[2], + data=req_meta[3], + files=req_meta[4], + headers=req_meta[5], + content_type=req_meta[6], + multipart_parser=req_meta[7], + retry_count=retry_count + 1, + ) + + raise Exception(user_friendly) from exc + + @staticmethod + def _unpack_tuple(t): + """Helper to normalise (filename, file, content_type) tuples.""" + if len(t) == 3: + return t + elif len(t) == 2: + return t[0], t[1], "application/octet-stream" + else: + raise ValueError("files tuple must be (filename, file[, content_type])") + + async def _get_session(self) -> aiohttp.ClientSession: + if self._session is None or self._session.closed: + timeout = aiohttp.ClientTimeout(total=self.timeout) + self._session = aiohttp.ClientSession(timeout=timeout) + self._owns_session = True + return self._session + + async def close(self) -> None: + if self._owns_session and self._session and not self._session.closed: + await self._session.close() + + async def __aenter__(self) -> "ApiClient": + """Allow usage as async‑context‑manager – ensures clean teardown""" + return self + + async def __aexit__(self, exc_type, exc, tb): + await self.close() class ApiEndpoint(Generic[T, R]): @@ -763,31 +672,28 @@ class ApiEndpoint(Generic[T, R]): class SynchronousOperation(Generic[T, R]): - """ - Represents a single synchronous API operation. 
- """ + """Represents a single synchronous API operation.""" def __init__( self, endpoint: ApiEndpoint[T, R], request: T, - files: Optional[Dict[str, Any]] = None, + files: Optional[Dict[str, Any] | list[tuple[str, Any]]] = None, api_base: str | None = None, auth_token: Optional[str] = None, comfy_api_key: Optional[str] = None, - auth_kwargs: Optional[Dict[str,str]] = None, + auth_kwargs: Optional[Dict[str, str]] = None, timeout: float = 604800.0, verify_ssl: bool = True, content_type: str = "application/json", - multipart_parser: Callable = None, + multipart_parser: Callable | None = None, max_retries: int = 3, retry_delay: float = 1.0, retry_backoff_factor: float = 2.0, - ): + ) -> None: self.endpoint = endpoint self.request = request - self.response = None - self.error = None + self.files = files self.api_base: str = api_base or args.comfy_api_base self.auth_token = auth_token self.comfy_api_key = comfy_api_key @@ -796,91 +702,64 @@ class SynchronousOperation(Generic[T, R]): self.comfy_api_key = auth_kwargs.get("comfy_api_key", self.comfy_api_key) self.timeout = timeout self.verify_ssl = verify_ssl - self.files = files self.content_type = content_type self.multipart_parser = multipart_parser self.max_retries = max_retries self.retry_delay = retry_delay self.retry_backoff_factor = retry_backoff_factor - def execute(self, client: Optional[ApiClient] = None) -> R: - """Execute the API operation using the provided client or create one with retry support""" - try: - # Create client if not provided - if client is None: - client = ApiClient( - base_url=self.api_base, - auth_token=self.auth_token, - comfy_api_key=self.comfy_api_key, - timeout=self.timeout, - verify_ssl=self.verify_ssl, - max_retries=self.max_retries, - retry_delay=self.retry_delay, - retry_backoff_factor=self.retry_backoff_factor, - ) - - # Convert request model to dict, but use None for EmptyRequest - request_dict = ( - None - if isinstance(self.request, EmptyRequest) - else self.request.model_dump(exclude_none=True) + async def execute(self, client: Optional[ApiClient] = None) -> R: + owns_client = client is None + if owns_client: + client = ApiClient( + base_url=self.api_base, + auth_token=self.auth_token, + comfy_api_key=self.comfy_api_key, + timeout=self.timeout, + verify_ssl=self.verify_ssl, + max_retries=self.max_retries, + retry_delay=self.retry_delay, + retry_backoff_factor=self.retry_backoff_factor, ) - if request_dict: - for key, value in request_dict.items(): - if isinstance(value, Enum): - request_dict[key] = value.value - # Debug log for request + try: + request_dict: Optional[Dict[str, Any]] + if isinstance(self.request, EmptyRequest): + request_dict = None + else: + request_dict = self.request.model_dump(exclude_none=True) + for k, v in list(request_dict.items()): + if isinstance(v, Enum): + request_dict[k] = v.value + logging.debug( f"[DEBUG] API Request: {self.endpoint.method.value} {self.endpoint.path}" ) logging.debug(f"[DEBUG] Request Data: {json.dumps(request_dict, indent=2)}") logging.debug(f"[DEBUG] Query Params: {self.endpoint.query_params}") - # Make the request with built-in retry - resp = client.request( - method=self.endpoint.method.value, - path=self.endpoint.path, - data=request_dict, + response_json = await client.request( + self.endpoint.method.value, + self.endpoint.path, params=self.endpoint.query_params, + data=request_dict, files=self.files, content_type=self.content_type, - multipart_parser=self.multipart_parser + multipart_parser=self.multipart_parser, ) - # Debug log for response 
logging.debug("=" * 50) logging.debug("[DEBUG] RESPONSE DETAILS:") logging.debug("[DEBUG] Status Code: 200 (Success)") - logging.debug(f"[DEBUG] Response Body: {json.dumps(resp, indent=2)}") + logging.debug(f"[DEBUG] Response Body: {json.dumps(response_json, indent=2)}") logging.debug("=" * 50) - # Parse and return the response - return self._parse_response(resp) - - except LocalNetworkError as e: - # Propagate specific network error types - logging.error(f"[ERROR] Local network error: {str(e)}") - raise - - except ApiServerError as e: - # Propagate API server errors - logging.error(f"[ERROR] API server error: {str(e)}") - raise - - except Exception as e: - logging.error(f"[ERROR] API Exception: {str(e)}") - raise Exception(str(e)) - - def _parse_response(self, resp): - """Parse response data - can be overridden by subclasses""" - # The response is already the complete object, don't extract just the "data" field - # as that would lose the outer structure (created timestamp, etc.) - - # Parse response using the provided model - self.response = self.endpoint.response_model.model_validate(resp) - logging.debug(f"[DEBUG] Parsed Response: {self.response}") - return self.response + parsed_response = self.endpoint.response_model.model_validate(response_json) + logging.debug(f"[DEBUG] Parsed Response: {parsed_response}") + return parsed_response + finally: + if owns_client: + await client.close() class TaskStatus(str, Enum): @@ -892,23 +771,21 @@ class TaskStatus(str, Enum): class PollingOperation(Generic[T, R]): - """ - Represents an asynchronous API operation that requires polling for completion. - """ + """Represents an asynchronous API operation that requires polling for completion.""" def __init__( self, poll_endpoint: ApiEndpoint[EmptyRequest, R], - completed_statuses: list, - failed_statuses: list, + completed_statuses: list[str], + failed_statuses: list[str], status_extractor: Callable[[R], str], - progress_extractor: Callable[[R], float] = None, - result_url_extractor: Callable[[R], str] = None, + progress_extractor: Callable[[R], float] | None = None, + result_url_extractor: Callable[[R], str] | None = None, request: Optional[T] = None, api_base: str | None = None, auth_token: Optional[str] = None, comfy_api_key: Optional[str] = None, - auth_kwargs: Optional[Dict[str,str]] = None, + auth_kwargs: Optional[Dict[str, str]] = None, poll_interval: float = 5.0, max_poll_attempts: int = 120, # Default max polling attempts (10 minutes with 5s interval) max_retries: int = 3, # Max retries per individual API call @@ -916,7 +793,7 @@ class PollingOperation(Generic[T, R]): retry_backoff_factor: float = 2.0, estimated_duration: Optional[float] = None, node_id: Optional[str] = None, - ): + ) -> None: self.poll_endpoint = poll_endpoint self.request = request self.api_base: str = api_base or args.comfy_api_base @@ -931,100 +808,73 @@ class PollingOperation(Generic[T, R]): self.retry_delay = retry_delay self.retry_backoff_factor = retry_backoff_factor self.estimated_duration = estimated_duration - - # Polling configuration - self.status_extractor = status_extractor or ( - lambda x: getattr(x, "status", None) - ) + self.status_extractor = status_extractor or (lambda x: getattr(x, "status", None)) self.progress_extractor = progress_extractor self.result_url_extractor = result_url_extractor self.node_id = node_id self.completed_statuses = completed_statuses self.failed_statuses = failed_statuses + self.final_response: Optional[R] = None - # For storing response data - self.final_response = None - self.error 
= None - - def execute(self, client: Optional[ApiClient] = None) -> R: - """Execute the polling operation using the provided client. If failed, raise an exception.""" + async def execute(self, client: Optional[ApiClient] = None) -> R: + owns_client = client is None + if owns_client: + client = ApiClient( + base_url=self.api_base, + auth_token=self.auth_token, + comfy_api_key=self.comfy_api_key, + max_retries=self.max_retries, + retry_delay=self.retry_delay, + retry_backoff_factor=self.retry_backoff_factor, + ) try: - if client is None: - client = ApiClient( - base_url=self.api_base, - auth_token=self.auth_token, - comfy_api_key=self.comfy_api_key, - max_retries=self.max_retries, - retry_delay=self.retry_delay, - retry_backoff_factor=self.retry_backoff_factor, - ) - return self._poll_until_complete(client) - except LocalNetworkError as e: - # Provide clear message for local network issues - raise Exception( - f"Polling failed due to local network issues. Please check your internet connection. " - f"Details: {str(e)}" - ) from e - except ApiServerError as e: - # Provide clear message for API server issues - raise Exception( - f"Polling failed due to API server issues. The service may be experiencing problems. " - f"Please try again later. Details: {str(e)}" - ) from e - except Exception as e: - raise Exception(f"Error during polling: {str(e)}") + return await self._poll_until_complete(client) + finally: + if owns_client: + await client.close() def _display_text_on_node(self, text: str): - """Sends text to the client which will be displayed on the node in the UI""" if not self.node_id: return - PromptServer.instance.send_progress_text(text, self.node_id) - def _display_time_progress_on_node(self, time_completed: int): + def _display_time_progress_on_node(self, time_completed: int | float): if not self.node_id: return - if self.estimated_duration is not None: - estimated_time_remaining = max( - 0, int(self.estimated_duration) - int(time_completed) - ) - message = f"Task in progress: {time_completed:.0f}s (~{estimated_time_remaining:.0f}s remaining)" + remaining = max(0, int(self.estimated_duration) - time_completed) + message = f"Task in progress: {time_completed}s (~{remaining}s remaining)" else: - message = f"Task in progress: {time_completed:.0f}s" + message = f"Task in progress: {time_completed}s" self._display_text_on_node(message) def _check_task_status(self, response: R) -> TaskStatus: - """Check task status using the status extractor function""" try: status = self.status_extractor(response) if status in self.completed_statuses: return TaskStatus.COMPLETED - elif status in self.failed_statuses: + if status in self.failed_statuses: return TaskStatus.FAILED return TaskStatus.PENDING except Exception as e: - logging.error(f"Error extracting status: {e}") + logging.error("Error extracting status: %s", e) return TaskStatus.PENDING - def _poll_until_complete(self, client: ApiClient) -> R: + async def _poll_until_complete(self, client: ApiClient) -> R: """Poll until the task is complete""" - poll_count = 0 consecutive_errors = 0 max_consecutive_errors = min(5, self.max_retries * 2) # Limit consecutive errors if self.progress_extractor: progress = utils.ProgressBar(PROGRESS_BAR_MAX) - while poll_count < self.max_poll_attempts: + status = TaskStatus.PENDING + for poll_count in range(1, self.max_poll_attempts + 1): try: - poll_count += 1 logging.debug(f"[DEBUG] Polling attempt #{poll_count}") request_dict = ( - self.request.model_dump(exclude_none=True) - if self.request is not None - else None + 
None if self.request is None else self.request.model_dump(exclude_none=True) ) if poll_count == 1: @@ -1036,18 +886,14 @@ class PollingOperation(Generic[T, R]): ) # Query task status - resp = client.request( - method=self.poll_endpoint.method.value, - path=self.poll_endpoint.path, + resp = await client.request( + self.poll_endpoint.method.value, + self.poll_endpoint.path, params=self.poll_endpoint.query_params, data=request_dict, ) - - # Successfully got a response, reset consecutive error count - consecutive_errors = 0 - - # Parse response - response_obj = self.poll_endpoint.response_model.model_validate(resp) + consecutive_errors = 0 # reset on success + response_obj: R = self.poll_endpoint.response_model.model_validate(resp) # Check if task is complete status = self._check_task_status(response_obj) @@ -1065,45 +911,30 @@ class PollingOperation(Generic[T, R]): result_url = self.result_url_extractor(response_obj) if result_url: message = f"Result URL: {result_url}" - else: - message = "Task completed successfully!" logging.debug(f"[DEBUG] {message}") self._display_text_on_node(message) self.final_response = response_obj if self.progress_extractor: progress.update(100) return self.final_response - elif status == TaskStatus.FAILED: + if status == TaskStatus.FAILED: message = f"Task failed: {json.dumps(resp)}" logging.error(f"[DEBUG] {message}") raise Exception(message) - else: - logging.debug("[DEBUG] Task still pending, continuing to poll...") - - # Wait before polling again - logging.debug( - f"[DEBUG] Waiting {self.poll_interval} seconds before next poll" - ) + logging.debug("[DEBUG] Task still pending, continuing to poll...") + # Task pending – wait for i in range(int(self.poll_interval)): - time_completed = (poll_count * self.poll_interval) + i - self._display_time_progress_on_node(time_completed) - time.sleep(1) + self._display_time_progress_on_node((poll_count - 1) * self.poll_interval + i) + await asyncio.sleep(1) - except (LocalNetworkError, ApiServerError) as e: - # For network-related errors, increment error count and potentially abort + except (LocalNetworkError, ApiServerError, NetworkError) as e: consecutive_errors += 1 if consecutive_errors >= max_consecutive_errors: raise Exception( - f"Polling aborted after {consecutive_errors} consecutive network errors: {str(e)}" + f"Polling aborted after {consecutive_errors} network errors: {str(e)}" ) from e - - # Log the error but continue polling - logging.warning( - f"Network error during polling (attempt {poll_count}/{self.max_poll_attempts}): {str(e)}. " - f"Will retry in {self.poll_interval} seconds." - ) - time.sleep(self.poll_interval) - + logging.warning("Network error (%s/%s): %s", consecutive_errors, max_consecutive_errors, str(e)) + await asyncio.sleep(self.poll_interval) except Exception as e: # For other errors, increment count and potentially abort consecutive_errors += 1 @@ -1117,10 +948,10 @@ class PollingOperation(Generic[T, R]): f"Error during polling (attempt {poll_count}/{self.max_poll_attempts}): {str(e)}. " f"Will retry in {self.poll_interval} seconds." ) - time.sleep(self.poll_interval) + await asyncio.sleep(self.poll_interval) # If we've exhausted all polling attempts raise Exception( - f"Polling timed out after {poll_count} attempts ({poll_count * self.poll_interval} seconds). " - f"The operation may still be running on the server but is taking longer than expected." + f"Polling timed out after {self.max_poll_attempts} attempts (" f"{self.max_poll_attempts * self.poll_interval} seconds). 
" + "The operation may still be running on the server but is taking longer than expected." ) diff --git a/comfy_api_nodes/nodes_bfl.py b/comfy_api_nodes/nodes_bfl.py index d93fbd778..c09be8d5b 100644 --- a/comfy_api_nodes/nodes_bfl.py +++ b/comfy_api_nodes/nodes_bfl.py @@ -1,3 +1,4 @@ +import asyncio import io from inspect import cleandoc from typing import Union, Optional @@ -28,7 +29,7 @@ from comfy_api_nodes.apinode_utils import ( import numpy as np from PIL import Image -import requests +import aiohttp import torch import base64 import time @@ -44,18 +45,18 @@ def convert_mask_to_image(mask: torch.Tensor): return mask -def handle_bfl_synchronous_operation( +async def handle_bfl_synchronous_operation( operation: SynchronousOperation, timeout_bfl_calls=360, node_id: Union[str, None] = None, ): - response_api: BFLFluxProGenerateResponse = operation.execute() - return _poll_until_generated( + response_api: BFLFluxProGenerateResponse = await operation.execute() + return await _poll_until_generated( response_api.polling_url, timeout=timeout_bfl_calls, node_id=node_id ) -def _poll_until_generated( +async def _poll_until_generated( polling_url: str, timeout=360, node_id: Union[str, None] = None ): # used bfl-comfy-nodes to verify code implementation: @@ -66,55 +67,56 @@ def _poll_until_generated( retry_404_seconds = 2 retry_202_seconds = 2 retry_pending_seconds = 1 - request = requests.Request(method=HttpMethod.GET, url=polling_url) - # NOTE: should True loop be replaced with checking if workflow has been interrupted? - while True: - if node_id: - time_elapsed = time.time() - start_time - PromptServer.instance.send_progress_text( - f"Generating ({time_elapsed:.0f}s)", node_id - ) - response = requests.Session().send(request.prepare()) - if response.status_code == 200: - result = response.json() - if result["status"] == BFLStatus.ready: - img_url = result["result"]["sample"] - if node_id: - PromptServer.instance.send_progress_text( - f"Result URL: {img_url}", node_id - ) - img_response = requests.get(img_url) - return process_image_response(img_response) - elif result["status"] in [ - BFLStatus.request_moderated, - BFLStatus.content_moderated, - ]: - status = result["status"] - raise Exception( - f"BFL API did not return an image due to: {status}." + async with aiohttp.ClientSession() as session: + # NOTE: should True loop be replaced with checking if workflow has been interrupted? + while True: + if node_id: + time_elapsed = time.time() - start_time + PromptServer.instance.send_progress_text( + f"Generating ({time_elapsed:.0f}s)", node_id ) - elif result["status"] == BFLStatus.error: - raise Exception(f"BFL API encountered an error: {result}.") - elif result["status"] == BFLStatus.pending: - time.sleep(retry_pending_seconds) - continue - elif response.status_code == 404: - if retries_404 < max_retries_404: - retries_404 += 1 - time.sleep(retry_404_seconds) - continue - raise Exception( - f"BFL API could not find task after {max_retries_404} tries." - ) - elif response.status_code == 202: - time.sleep(retry_202_seconds) - elif time.time() - start_time > timeout: - raise Exception( - f"BFL API experienced a timeout; could not return request under {timeout} seconds." 
- ) - else: - raise Exception(f"BFL API encountered an error: {response.json()}") + + async with session.get(polling_url) as response: + if response.status == 200: + result = await response.json() + if result["status"] == BFLStatus.ready: + img_url = result["result"]["sample"] + if node_id: + PromptServer.instance.send_progress_text( + f"Result URL: {img_url}", node_id + ) + async with session.get(img_url) as img_resp: + return process_image_response(await img_resp.content.read()) + elif result["status"] in [ + BFLStatus.request_moderated, + BFLStatus.content_moderated, + ]: + status = result["status"] + raise Exception( + f"BFL API did not return an image due to: {status}." + ) + elif result["status"] == BFLStatus.error: + raise Exception(f"BFL API encountered an error: {result}.") + elif result["status"] == BFLStatus.pending: + await asyncio.sleep(retry_pending_seconds) + continue + elif response.status == 404: + if retries_404 < max_retries_404: + retries_404 += 1 + await asyncio.sleep(retry_404_seconds) + continue + raise Exception( + f"BFL API could not find task after {max_retries_404} tries." + ) + elif response.status == 202: + await asyncio.sleep(retry_202_seconds) + elif time.time() - start_time > timeout: + raise Exception( + f"BFL API experienced a timeout; could not return request under {timeout} seconds." + ) + else: + raise Exception(f"BFL API encountered an error: {response.json()}") def convert_image_to_base64(image: torch.Tensor): scaled_image = downscale_image_tensor(image, total_pixels=2048 * 2048) @@ -222,7 +224,7 @@ class FluxProUltraImageNode(ComfyNodeABC): API_NODE = True CATEGORY = "api node/image/BFL" - def api_call( + async def api_call( self, prompt: str, aspect_ratio: str, @@ -266,7 +268,7 @@ class FluxProUltraImageNode(ComfyNodeABC): ), auth_kwargs=kwargs, ) - output_image = handle_bfl_synchronous_operation(operation, node_id=unique_id) + output_image = await handle_bfl_synchronous_operation(operation, node_id=unique_id) return (output_image,) @@ -354,7 +356,7 @@ class FluxKontextProImageNode(ComfyNodeABC): BFL_PATH = "/proxy/bfl/flux-kontext-pro/generate" - def api_call( + async def api_call( self, prompt: str, aspect_ratio: str, @@ -397,7 +399,7 @@ class FluxKontextProImageNode(ComfyNodeABC): ), auth_kwargs=kwargs, ) - output_image = handle_bfl_synchronous_operation(operation, node_id=unique_id) + output_image = await handle_bfl_synchronous_operation(operation, node_id=unique_id) return (output_image,) @@ -489,7 +491,7 @@ class FluxProImageNode(ComfyNodeABC): API_NODE = True CATEGORY = "api node/image/BFL" - def api_call( + async def api_call( self, prompt: str, prompt_upsampling, @@ -524,7 +526,7 @@ class FluxProImageNode(ComfyNodeABC): ), auth_kwargs=kwargs, ) - output_image = handle_bfl_synchronous_operation(operation, node_id=unique_id) + output_image = await handle_bfl_synchronous_operation(operation, node_id=unique_id) return (output_image,) @@ -632,7 +634,7 @@ class FluxProExpandNode(ComfyNodeABC): API_NODE = True CATEGORY = "api node/image/BFL" - def api_call( + async def api_call( self, image: torch.Tensor, prompt: str, @@ -670,7 +672,7 @@ class FluxProExpandNode(ComfyNodeABC): ), auth_kwargs=kwargs, ) - output_image = handle_bfl_synchronous_operation(operation, node_id=unique_id) + output_image = await handle_bfl_synchronous_operation(operation, node_id=unique_id) return (output_image,) @@ -744,7 +746,7 @@ class FluxProFillNode(ComfyNodeABC): API_NODE = True CATEGORY = "api node/image/BFL" - def api_call( + async def api_call( self, image: 
torch.Tensor, mask: torch.Tensor, @@ -780,7 +782,7 @@ class FluxProFillNode(ComfyNodeABC): ), auth_kwargs=kwargs, ) - output_image = handle_bfl_synchronous_operation(operation, node_id=unique_id) + output_image = await handle_bfl_synchronous_operation(operation, node_id=unique_id) return (output_image,) @@ -879,7 +881,7 @@ class FluxProCannyNode(ComfyNodeABC): API_NODE = True CATEGORY = "api node/image/BFL" - def api_call( + async def api_call( self, control_image: torch.Tensor, prompt: str, @@ -929,7 +931,7 @@ class FluxProCannyNode(ComfyNodeABC): ), auth_kwargs=kwargs, ) - output_image = handle_bfl_synchronous_operation(operation, node_id=unique_id) + output_image = await handle_bfl_synchronous_operation(operation, node_id=unique_id) return (output_image,) @@ -1008,7 +1010,7 @@ class FluxProDepthNode(ComfyNodeABC): API_NODE = True CATEGORY = "api node/image/BFL" - def api_call( + async def api_call( self, control_image: torch.Tensor, prompt: str, @@ -1045,7 +1047,7 @@ class FluxProDepthNode(ComfyNodeABC): ), auth_kwargs=kwargs, ) - output_image = handle_bfl_synchronous_operation(operation, node_id=unique_id) + output_image = await handle_bfl_synchronous_operation(operation, node_id=unique_id) return (output_image,) diff --git a/comfy_api_nodes/nodes_gemini.py b/comfy_api_nodes/nodes_gemini.py index af33279d5..3751fb2a1 100644 --- a/comfy_api_nodes/nodes_gemini.py +++ b/comfy_api_nodes/nodes_gemini.py @@ -303,7 +303,7 @@ class GeminiNode(ComfyNodeABC): """ return GeminiPart(text=text) - def api_call( + async def api_call( self, prompt: str, model: GeminiModel, @@ -332,7 +332,7 @@ class GeminiNode(ComfyNodeABC): parts.extend(files) # Create response - response = SynchronousOperation( + response = await SynchronousOperation( endpoint=get_gemini_endpoint(model), request=GeminiGenerateContentRequest( contents=[ diff --git a/comfy_api_nodes/nodes_ideogram.py b/comfy_api_nodes/nodes_ideogram.py index b8487355f..db24e6da4 100644 --- a/comfy_api_nodes/nodes_ideogram.py +++ b/comfy_api_nodes/nodes_ideogram.py @@ -212,7 +212,7 @@ V3_RESOLUTIONS= [ "1536x640" ] -def download_and_process_images(image_urls): +async def download_and_process_images(image_urls): """Helper function to download and process multiple images from URLs""" # Initialize list to store image tensors @@ -220,7 +220,7 @@ def download_and_process_images(image_urls): for image_url in image_urls: # Using functions from apinode_utils.py to handle downloading and processing - image_bytesio = download_url_to_bytesio(image_url) # Download image content to BytesIO + image_bytesio = await download_url_to_bytesio(image_url) # Download image content to BytesIO img_tensor = bytesio_to_image_tensor(image_bytesio, mode="RGB") # Convert to torch.Tensor with RGB mode image_tensors.append(img_tensor) @@ -328,7 +328,7 @@ class IdeogramV1(ComfyNodeABC): DESCRIPTION = cleandoc(__doc__ or "") API_NODE = True - def api_call( + async def api_call( self, prompt, turbo=False, @@ -367,7 +367,7 @@ class IdeogramV1(ComfyNodeABC): auth_kwargs=kwargs, ) - response = operation.execute() + response = await operation.execute() if not response.data or len(response.data) == 0: raise Exception("No images were generated in the response") @@ -378,7 +378,7 @@ class IdeogramV1(ComfyNodeABC): raise Exception("No image URLs were generated in the response") display_image_urls_on_node(image_urls, unique_id) - return (download_and_process_images(image_urls),) + return (await download_and_process_images(image_urls),) class IdeogramV2(ComfyNodeABC): @@ -487,7 +487,7 @@ class 
IdeogramV2(ComfyNodeABC): DESCRIPTION = cleandoc(__doc__ or "") API_NODE = True - def api_call( + async def api_call( self, prompt, turbo=False, @@ -543,7 +543,7 @@ class IdeogramV2(ComfyNodeABC): auth_kwargs=kwargs, ) - response = operation.execute() + response = await operation.execute() if not response.data or len(response.data) == 0: raise Exception("No images were generated in the response") @@ -554,7 +554,7 @@ class IdeogramV2(ComfyNodeABC): raise Exception("No image URLs were generated in the response") display_image_urls_on_node(image_urls, unique_id) - return (download_and_process_images(image_urls),) + return (await download_and_process_images(image_urls),) class IdeogramV3(ComfyNodeABC): """ @@ -653,7 +653,7 @@ class IdeogramV3(ComfyNodeABC): DESCRIPTION = cleandoc(__doc__ or "") API_NODE = True - def api_call( + async def api_call( self, prompt, image=None, @@ -774,7 +774,7 @@ class IdeogramV3(ComfyNodeABC): ) # Execute the operation and process response - response = operation.execute() + response = await operation.execute() if not response.data or len(response.data) == 0: raise Exception("No images were generated in the response") @@ -785,7 +785,7 @@ class IdeogramV3(ComfyNodeABC): raise Exception("No image URLs were generated in the response") display_image_urls_on_node(image_urls, unique_id) - return (download_and_process_images(image_urls),) + return (await download_and_process_images(image_urls),) NODE_CLASS_MAPPINGS = { diff --git a/comfy_api_nodes/nodes_kling.py b/comfy_api_nodes/nodes_kling.py index 69e9e5cf0..9d9eb5628 100644 --- a/comfy_api_nodes/nodes_kling.py +++ b/comfy_api_nodes/nodes_kling.py @@ -109,7 +109,7 @@ class KlingApiError(Exception): pass -def poll_until_finished( +async def poll_until_finished( auth_kwargs: dict[str, str], api_endpoint: ApiEndpoint[Any, R], result_url_extractor: Optional[Callable[[R], str]] = None, @@ -117,7 +117,7 @@ def poll_until_finished( node_id: Optional[str] = None, ) -> R: """Polls the Kling API endpoint until the task reaches a terminal state, then returns the response.""" - return PollingOperation( + return await PollingOperation( poll_endpoint=api_endpoint, completed_statuses=[ KlingTaskStatus.succeed.value, @@ -278,18 +278,18 @@ def get_images_urls_from_response(response) -> Optional[str]: return None -def video_result_to_node_output( +async def video_result_to_node_output( video: KlingVideoResult, ) -> tuple[VideoFromFile, str, str]: """Converts a KlingVideoResult to a tuple of (VideoFromFile, str, str) to be used as a ComfyUI node output.""" return ( - download_url_to_video_output(video.url), + await download_url_to_video_output(str(video.url)), str(video.id), str(video.duration), ) -def image_result_to_node_output( +async def image_result_to_node_output( images: list[KlingImageResult], ) -> torch.Tensor: """ @@ -297,9 +297,9 @@ def image_result_to_node_output( If multiple images are returned, they will be stacked along the batch dimension. 
""" if len(images) == 1: - return download_url_to_image_tensor(images[0].url) + return await download_url_to_image_tensor(str(images[0].url)) else: - return torch.cat([download_url_to_image_tensor(image.url) for image in images]) + return torch.cat([await download_url_to_image_tensor(str(image.url)) for image in images]) class KlingNodeBase(ComfyNodeABC): @@ -467,10 +467,10 @@ class KlingTextToVideoNode(KlingNodeBase): RETURN_NAMES = ("VIDEO", "video_id", "duration") DESCRIPTION = "Kling Text to Video Node" - def get_response( + async def get_response( self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None ) -> KlingText2VideoResponse: - return poll_until_finished( + return await poll_until_finished( auth_kwargs, ApiEndpoint( path=f"{PATH_TEXT_TO_VIDEO}/{task_id}", @@ -483,7 +483,7 @@ class KlingTextToVideoNode(KlingNodeBase): node_id=node_id, ) - def api_call( + async def api_call( self, prompt: str, negative_prompt: str, @@ -519,17 +519,17 @@ class KlingTextToVideoNode(KlingNodeBase): auth_kwargs=kwargs, ) - task_creation_response = initial_operation.execute() + task_creation_response = await initial_operation.execute() validate_task_creation_response(task_creation_response) task_id = task_creation_response.data.task_id - final_response = self.get_response( + final_response = await self.get_response( task_id, auth_kwargs=kwargs, node_id=unique_id ) validate_video_result_response(final_response) video = get_video_from_response(final_response) - return video_result_to_node_output(video) + return await video_result_to_node_output(video) class KlingCameraControlT2VNode(KlingTextToVideoNode): @@ -581,7 +581,7 @@ class KlingCameraControlT2VNode(KlingTextToVideoNode): DESCRIPTION = "Transform text into cinematic videos with professional camera movements that simulate real-world cinematography. Control virtual camera actions including zoom, rotation, pan, tilt, and first-person view, while maintaining focus on your original text." 
- def api_call( + async def api_call( self, prompt: str, negative_prompt: str, @@ -591,7 +591,7 @@ class KlingCameraControlT2VNode(KlingTextToVideoNode): unique_id: Optional[str] = None, **kwargs, ): - return super().api_call( + return await super().api_call( model_name=KlingVideoGenModelName.kling_v1, cfg_scale=cfg_scale, mode=KlingVideoGenMode.std, @@ -670,10 +670,10 @@ class KlingImage2VideoNode(KlingNodeBase): RETURN_NAMES = ("VIDEO", "video_id", "duration") DESCRIPTION = "Kling Image to Video Node" - def get_response( + async def get_response( self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None ) -> KlingImage2VideoResponse: - return poll_until_finished( + return await poll_until_finished( auth_kwargs, ApiEndpoint( path=f"{PATH_IMAGE_TO_VIDEO}/{task_id}", @@ -686,7 +686,7 @@ class KlingImage2VideoNode(KlingNodeBase): node_id=node_id, ) - def api_call( + async def api_call( self, start_frame: torch.Tensor, prompt: str, @@ -733,17 +733,17 @@ class KlingImage2VideoNode(KlingNodeBase): auth_kwargs=kwargs, ) - task_creation_response = initial_operation.execute() + task_creation_response = await initial_operation.execute() validate_task_creation_response(task_creation_response) task_id = task_creation_response.data.task_id - final_response = self.get_response( + final_response = await self.get_response( task_id, auth_kwargs=kwargs, node_id=unique_id ) validate_video_result_response(final_response) video = get_video_from_response(final_response) - return video_result_to_node_output(video) + return await video_result_to_node_output(video) class KlingCameraControlI2VNode(KlingImage2VideoNode): @@ -798,7 +798,7 @@ class KlingCameraControlI2VNode(KlingImage2VideoNode): DESCRIPTION = "Transform still images into cinematic videos with professional camera movements that simulate real-world cinematography. Control virtual camera actions including zoom, rotation, pan, tilt, and first-person view, while maintaining focus on your original image." - def api_call( + async def api_call( self, start_frame: torch.Tensor, prompt: str, @@ -809,7 +809,7 @@ class KlingCameraControlI2VNode(KlingImage2VideoNode): unique_id: Optional[str] = None, **kwargs, ): - return super().api_call( + return await super().api_call( model_name=KlingVideoGenModelName.kling_v1_5, start_frame=start_frame, cfg_scale=cfg_scale, @@ -897,7 +897,7 @@ class KlingStartEndFrameNode(KlingImage2VideoNode): DESCRIPTION = "Generate a video sequence that transitions between your provided start and end images. The node creates all frames in between, producing a smooth transformation from the first frame to the last." - def api_call( + async def api_call( self, start_frame: torch.Tensor, end_frame: torch.Tensor, @@ -912,7 +912,7 @@ class KlingStartEndFrameNode(KlingImage2VideoNode): mode, duration, model_name = KlingStartEndFrameNode.get_mode_string_mapping()[ mode ] - return super().api_call( + return await super().api_call( prompt=prompt, negative_prompt=negative_prompt, model_name=model_name, @@ -964,10 +964,10 @@ class KlingVideoExtendNode(KlingNodeBase): RETURN_NAMES = ("VIDEO", "video_id", "duration") DESCRIPTION = "Kling Video Extend Node. Extend videos made by other Kling nodes. The video_id is created by using other Kling Nodes." 
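The thin wrapper nodes in this file (the camera-control, start/end-frame, and lip-sync variants) only adjust their arguments and delegate to the parent implementation, so each override becomes a coroutine that awaits super().api_call(...); without the await, the node would return an un-awaited coroutine object instead of its output tuple. A minimal sketch of the pattern, with an illustrative subclass name and a trimmed parameter list:

    class ExampleKlingWrapperNode(KlingTextToVideoNode):  # illustrative subclass
        async def api_call(self, prompt, negative_prompt, cfg_scale, unique_id=None, **kwargs):
            # Delegate to the parent coroutine; the await is the essential change.
            return await super().api_call(
                prompt=prompt,
                negative_prompt=negative_prompt,
                cfg_scale=cfg_scale,
                unique_id=unique_id,
                **kwargs,
            )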
- def get_response( + async def get_response( self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None ) -> KlingVideoExtendResponse: - return poll_until_finished( + return await poll_until_finished( auth_kwargs, ApiEndpoint( path=f"{PATH_VIDEO_EXTEND}/{task_id}", @@ -980,7 +980,7 @@ class KlingVideoExtendNode(KlingNodeBase): node_id=node_id, ) - def api_call( + async def api_call( self, prompt: str, negative_prompt: str, @@ -1006,17 +1006,17 @@ class KlingVideoExtendNode(KlingNodeBase): auth_kwargs=kwargs, ) - task_creation_response = initial_operation.execute() + task_creation_response = await initial_operation.execute() validate_task_creation_response(task_creation_response) task_id = task_creation_response.data.task_id - final_response = self.get_response( + final_response = await self.get_response( task_id, auth_kwargs=kwargs, node_id=unique_id ) validate_video_result_response(final_response) video = get_video_from_response(final_response) - return video_result_to_node_output(video) + return await video_result_to_node_output(video) class KlingVideoEffectsBase(KlingNodeBase): @@ -1025,10 +1025,10 @@ class KlingVideoEffectsBase(KlingNodeBase): RETURN_TYPES = ("VIDEO", "STRING", "STRING") RETURN_NAMES = ("VIDEO", "video_id", "duration") - def get_response( + async def get_response( self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None ) -> KlingVideoEffectsResponse: - return poll_until_finished( + return await poll_until_finished( auth_kwargs, ApiEndpoint( path=f"{PATH_VIDEO_EFFECTS}/{task_id}", @@ -1041,7 +1041,7 @@ class KlingVideoEffectsBase(KlingNodeBase): node_id=node_id, ) - def api_call( + async def api_call( self, dual_character: bool, effect_scene: KlingDualCharacterEffectsScene | KlingSingleImageEffectsScene, @@ -1084,17 +1084,17 @@ class KlingVideoEffectsBase(KlingNodeBase): auth_kwargs=kwargs, ) - task_creation_response = initial_operation.execute() + task_creation_response = await initial_operation.execute() validate_task_creation_response(task_creation_response) task_id = task_creation_response.data.task_id - final_response = self.get_response( + final_response = await self.get_response( task_id, auth_kwargs=kwargs, node_id=unique_id ) validate_video_result_response(final_response) video = get_video_from_response(final_response) - return video_result_to_node_output(video) + return await video_result_to_node_output(video) class KlingDualCharacterVideoEffectNode(KlingVideoEffectsBase): @@ -1142,7 +1142,7 @@ class KlingDualCharacterVideoEffectNode(KlingVideoEffectsBase): RETURN_TYPES = ("VIDEO", "STRING") RETURN_NAMES = ("VIDEO", "duration") - def api_call( + async def api_call( self, image_left: torch.Tensor, image_right: torch.Tensor, @@ -1153,7 +1153,7 @@ class KlingDualCharacterVideoEffectNode(KlingVideoEffectsBase): unique_id: Optional[str] = None, **kwargs, ): - video, _, duration = super().api_call( + video, _, duration = await super().api_call( dual_character=True, effect_scene=effect_scene, model_name=model_name, @@ -1208,7 +1208,7 @@ class KlingSingleImageVideoEffectNode(KlingVideoEffectsBase): DESCRIPTION = "Achieve different special effects when generating a video based on the effect_scene." 
- def api_call( + async def api_call( self, image: torch.Tensor, effect_scene: KlingSingleImageEffectsScene, @@ -1217,7 +1217,7 @@ class KlingSingleImageVideoEffectNode(KlingVideoEffectsBase): unique_id: Optional[str] = None, **kwargs, ): - return super().api_call( + return await super().api_call( dual_character=False, effect_scene=effect_scene, model_name=model_name, @@ -1253,11 +1253,11 @@ class KlingLipSyncBase(KlingNodeBase): f"Text is too long. Maximum length is {MAX_PROMPT_LENGTH_LIP_SYNC} characters." ) - def get_response( + async def get_response( self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None ) -> KlingLipSyncResponse: """Polls the Kling API endpoint until the task reaches a terminal state.""" - return poll_until_finished( + return await poll_until_finished( auth_kwargs, ApiEndpoint( path=f"{PATH_LIP_SYNC}/{task_id}", @@ -1270,7 +1270,7 @@ class KlingLipSyncBase(KlingNodeBase): node_id=node_id, ) - def api_call( + async def api_call( self, video: VideoInput, audio: Optional[AudioInput] = None, @@ -1287,12 +1287,12 @@ class KlingLipSyncBase(KlingNodeBase): self.validate_lip_sync_video(video) # Upload video to Comfy API and get download URL - video_url = upload_video_to_comfyapi(video, auth_kwargs=kwargs) + video_url = await upload_video_to_comfyapi(video, auth_kwargs=kwargs) logging.info("Uploaded video to Comfy API. URL: %s", video_url) # Upload the audio file to Comfy API and get download URL if audio: - audio_url = upload_audio_to_comfyapi(audio, auth_kwargs=kwargs) + audio_url = await upload_audio_to_comfyapi(audio, auth_kwargs=kwargs) logging.info("Uploaded audio to Comfy API. URL: %s", audio_url) else: audio_url = None @@ -1319,17 +1319,17 @@ class KlingLipSyncBase(KlingNodeBase): auth_kwargs=kwargs, ) - task_creation_response = initial_operation.execute() + task_creation_response = await initial_operation.execute() validate_task_creation_response(task_creation_response) task_id = task_creation_response.data.task_id - final_response = self.get_response( + final_response = await self.get_response( task_id, auth_kwargs=kwargs, node_id=unique_id ) validate_video_result_response(final_response) video = get_video_from_response(final_response) - return video_result_to_node_output(video) + return await video_result_to_node_output(video) class KlingLipSyncAudioToVideoNode(KlingLipSyncBase): @@ -1357,7 +1357,7 @@ class KlingLipSyncAudioToVideoNode(KlingLipSyncBase): DESCRIPTION = "Kling Lip Sync Audio to Video Node. Syncs mouth movements in a video file to the audio content of an audio file. When using, ensure that the audio contains clearly distinguishable vocals and that the video contains a distinct face. The audio file should not be larger than 5MB. The video file should not be larger than 100MB, should have height/width between 720px and 1920px, and should be between 2s and 10s in length." - def api_call( + async def api_call( self, video: VideoInput, audio: AudioInput, @@ -1365,7 +1365,7 @@ class KlingLipSyncAudioToVideoNode(KlingLipSyncBase): unique_id: Optional[str] = None, **kwargs, ): - return super().api_call( + return await super().api_call( video=video, audio=audio, voice_language=voice_language, @@ -1469,7 +1469,7 @@ class KlingLipSyncTextToVideoNode(KlingLipSyncBase): DESCRIPTION = "Kling Lip Sync Text to Video Node. Syncs mouth movements in a video file to a text prompt. The video file should not be larger than 100MB, should have height/width between 720px and 1920px, and should be between 2s and 10s in length." 
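Several other hunks in this patch (the BFL, Luma, and PixVerse nodes) swap blocking requests calls for aiohttp sessions when fetching results. For reference, the download pattern those hunks adopt reduces to the sketch below; the URL is a placeholder, and raise_for_status() is added for the sketch only, since the patch checks response statuses explicitly instead:

    import aiohttp
    from io import BytesIO

    async def fetch_bytes(url: str) -> BytesIO:
        async with aiohttp.ClientSession() as session:
            async with session.get(url) as response:
                response.raise_for_status()  # sketch-only convenience check
                return BytesIO(await response.content.read())

Keeping both the session and the response inside async with blocks releases the connection even if reading the body fails.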
- def api_call( + async def api_call( self, video: VideoInput, text: str, @@ -1479,7 +1479,7 @@ class KlingLipSyncTextToVideoNode(KlingLipSyncBase): **kwargs, ): voice_id, voice_language = KlingLipSyncTextToVideoNode.get_voice_config()[voice] - return super().api_call( + return await super().api_call( video=video, text=text, voice_language=voice_language, @@ -1533,10 +1533,10 @@ class KlingVirtualTryOnNode(KlingImageGenerationBase): DESCRIPTION = "Kling Virtual Try On Node. Input a human image and a cloth image to try on the cloth on the human. You can merge multiple clothing item pictures into one image with a white background." - def get_response( + async def get_response( self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None ) -> KlingVirtualTryOnResponse: - return poll_until_finished( + return await poll_until_finished( auth_kwargs, ApiEndpoint( path=f"{PATH_VIRTUAL_TRY_ON}/{task_id}", @@ -1549,7 +1549,7 @@ class KlingVirtualTryOnNode(KlingImageGenerationBase): node_id=node_id, ) - def api_call( + async def api_call( self, human_image: torch.Tensor, cloth_image: torch.Tensor, @@ -1572,17 +1572,17 @@ class KlingVirtualTryOnNode(KlingImageGenerationBase): auth_kwargs=kwargs, ) - task_creation_response = initial_operation.execute() + task_creation_response = await initial_operation.execute() validate_task_creation_response(task_creation_response) task_id = task_creation_response.data.task_id - final_response = self.get_response( + final_response = await self.get_response( task_id, auth_kwargs=kwargs, node_id=unique_id ) validate_image_result_response(final_response) images = get_images_from_response(final_response) - return (image_result_to_node_output(images),) + return (await image_result_to_node_output(images),) class KlingImageGenerationNode(KlingImageGenerationBase): @@ -1655,13 +1655,13 @@ class KlingImageGenerationNode(KlingImageGenerationBase): DESCRIPTION = "Kling Image Generation Node. Generate an image from a text prompt with an optional reference image." 
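Note that image_result_to_node_output above awaits each download inside the list comprehension, so multi-image results are fetched one at a time; that is the most direct translation of the old code. If concurrency were wanted, a variant (not part of this patch) could gather the downloads before stacking, assuming download_url_to_image_tensor is importable from comfy_api_nodes.apinode_utils like the other helpers used by these nodes:

    import asyncio
    import torch

    from comfy_api_nodes.apinode_utils import download_url_to_image_tensor  # assumed location

    async def images_to_batch(images) -> torch.Tensor:
        # Fetch every image URL concurrently, then stack along the batch dimension.
        tensors = await asyncio.gather(
            *(download_url_to_image_tensor(str(image.url)) for image in images)
        )
        return tensors[0] if len(tensors) == 1 else torch.cat(tensors)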
- def get_response( + async def get_response( self, task_id: str, auth_kwargs: Optional[dict[str, str]], node_id: Optional[str] = None, ) -> KlingImageGenerationsResponse: - return poll_until_finished( + return await poll_until_finished( auth_kwargs, ApiEndpoint( path=f"{PATH_IMAGE_GENERATIONS}/{task_id}", @@ -1674,7 +1674,7 @@ class KlingImageGenerationNode(KlingImageGenerationBase): node_id=node_id, ) - def api_call( + async def api_call( self, model_name: KlingImageGenModelName, prompt: str, @@ -1714,17 +1714,17 @@ class KlingImageGenerationNode(KlingImageGenerationBase): auth_kwargs=kwargs, ) - task_creation_response = initial_operation.execute() + task_creation_response = await initial_operation.execute() validate_task_creation_response(task_creation_response) task_id = task_creation_response.data.task_id - final_response = self.get_response( + final_response = await self.get_response( task_id, auth_kwargs=kwargs, node_id=unique_id ) validate_image_result_response(final_response) images = get_images_from_response(final_response) - return (image_result_to_node_output(images),) + return (await image_result_to_node_output(images),) NODE_CLASS_MAPPINGS = { diff --git a/comfy_api_nodes/nodes_luma.py b/comfy_api_nodes/nodes_luma.py index 525dc38e6..b3c32bed5 100644 --- a/comfy_api_nodes/nodes_luma.py +++ b/comfy_api_nodes/nodes_luma.py @@ -38,7 +38,7 @@ from comfy_api_nodes.apinode_utils import ( ) from server import PromptServer -import requests +import aiohttp import torch from io import BytesIO @@ -217,7 +217,7 @@ class LumaImageGenerationNode(ComfyNodeABC): }, } - def api_call( + async def api_call( self, prompt: str, model: str, @@ -234,19 +234,19 @@ class LumaImageGenerationNode(ComfyNodeABC): # handle image_luma_ref api_image_ref = None if image_luma_ref is not None: - api_image_ref = self._convert_luma_refs( + api_image_ref = await self._convert_luma_refs( image_luma_ref, max_refs=4, auth_kwargs=kwargs, ) # handle style_luma_ref api_style_ref = None if style_image is not None: - api_style_ref = self._convert_style_image( + api_style_ref = await self._convert_style_image( style_image, weight=style_image_weight, auth_kwargs=kwargs, ) # handle character_ref images character_ref = None if character_image is not None: - download_urls = upload_images_to_comfyapi( + download_urls = await upload_images_to_comfyapi( character_image, max_images=4, auth_kwargs=kwargs, ) character_ref = LumaCharacterRef( @@ -270,7 +270,7 @@ class LumaImageGenerationNode(ComfyNodeABC): ), auth_kwargs=kwargs, ) - response_api: LumaGeneration = operation.execute() + response_api: LumaGeneration = await operation.execute() operation = PollingOperation( poll_endpoint=ApiEndpoint( @@ -286,19 +286,20 @@ class LumaImageGenerationNode(ComfyNodeABC): node_id=unique_id, auth_kwargs=kwargs, ) - response_poll = operation.execute() + response_poll = await operation.execute() - img_response = requests.get(response_poll.assets.image) - img = process_image_response(img_response) + async with aiohttp.ClientSession() as session: + async with session.get(response_poll.assets.image) as img_response: + img = process_image_response(await img_response.content.read()) return (img,) - def _convert_luma_refs( + async def _convert_luma_refs( self, luma_ref: LumaReferenceChain, max_refs: int, auth_kwargs: Optional[dict[str,str]] = None ): luma_urls = [] ref_count = 0 for ref in luma_ref.refs: - download_urls = upload_images_to_comfyapi( + download_urls = await upload_images_to_comfyapi( ref.image, max_images=1, auth_kwargs=auth_kwargs ) 
luma_urls.append(download_urls[0]) @@ -307,13 +308,13 @@ class LumaImageGenerationNode(ComfyNodeABC): break return luma_ref.create_api_model(download_urls=luma_urls, max_refs=max_refs) - def _convert_style_image( + async def _convert_style_image( self, style_image: torch.Tensor, weight: float, auth_kwargs: Optional[dict[str,str]] = None ): chain = LumaReferenceChain( first_ref=LumaReference(image=style_image, weight=weight) ) - return self._convert_luma_refs(chain, max_refs=1, auth_kwargs=auth_kwargs) + return await self._convert_luma_refs(chain, max_refs=1, auth_kwargs=auth_kwargs) class LumaImageModifyNode(ComfyNodeABC): @@ -370,7 +371,7 @@ class LumaImageModifyNode(ComfyNodeABC): }, } - def api_call( + async def api_call( self, prompt: str, model: str, @@ -381,7 +382,7 @@ class LumaImageModifyNode(ComfyNodeABC): **kwargs, ): # first, upload image - download_urls = upload_images_to_comfyapi( + download_urls = await upload_images_to_comfyapi( image, max_images=1, auth_kwargs=kwargs, ) image_url = download_urls[0] @@ -402,7 +403,7 @@ class LumaImageModifyNode(ComfyNodeABC): ), auth_kwargs=kwargs, ) - response_api: LumaGeneration = operation.execute() + response_api: LumaGeneration = await operation.execute() operation = PollingOperation( poll_endpoint=ApiEndpoint( @@ -418,10 +419,11 @@ class LumaImageModifyNode(ComfyNodeABC): node_id=unique_id, auth_kwargs=kwargs, ) - response_poll = operation.execute() + response_poll = await operation.execute() - img_response = requests.get(response_poll.assets.image) - img = process_image_response(img_response) + async with aiohttp.ClientSession() as session: + async with session.get(response_poll.assets.image) as img_response: + img = process_image_response(await img_response.content.read()) return (img,) @@ -494,7 +496,7 @@ class LumaTextToVideoGenerationNode(ComfyNodeABC): }, } - def api_call( + async def api_call( self, prompt: str, model: str, @@ -529,7 +531,7 @@ class LumaTextToVideoGenerationNode(ComfyNodeABC): ), auth_kwargs=kwargs, ) - response_api: LumaGeneration = operation.execute() + response_api: LumaGeneration = await operation.execute() if unique_id: PromptServer.instance.send_progress_text(f"Luma video generation started: {response_api.id}", unique_id) @@ -549,10 +551,11 @@ class LumaTextToVideoGenerationNode(ComfyNodeABC): estimated_duration=LUMA_T2V_AVERAGE_DURATION, auth_kwargs=kwargs, ) - response_poll = operation.execute() + response_poll = await operation.execute() - vid_response = requests.get(response_poll.assets.video) - return (VideoFromFile(BytesIO(vid_response.content)),) + async with aiohttp.ClientSession() as session: + async with session.get(response_poll.assets.video) as vid_response: + return (VideoFromFile(BytesIO(await vid_response.content.read())),) class LumaImageToVideoGenerationNode(ComfyNodeABC): @@ -626,7 +629,7 @@ class LumaImageToVideoGenerationNode(ComfyNodeABC): }, } - def api_call( + async def api_call( self, prompt: str, model: str, @@ -644,7 +647,7 @@ class LumaImageToVideoGenerationNode(ComfyNodeABC): raise Exception( "At least one of first_image and last_image requires an input." 
) - keyframes = self._convert_to_keyframes(first_image, last_image, auth_kwargs=kwargs) + keyframes = await self._convert_to_keyframes(first_image, last_image, auth_kwargs=kwargs) duration = duration if model != LumaVideoModel.ray_1_6 else None resolution = resolution if model != LumaVideoModel.ray_1_6 else None @@ -667,7 +670,7 @@ class LumaImageToVideoGenerationNode(ComfyNodeABC): ), auth_kwargs=kwargs, ) - response_api: LumaGeneration = operation.execute() + response_api: LumaGeneration = await operation.execute() if unique_id: PromptServer.instance.send_progress_text(f"Luma video generation started: {response_api.id}", unique_id) @@ -687,12 +690,13 @@ class LumaImageToVideoGenerationNode(ComfyNodeABC): estimated_duration=LUMA_I2V_AVERAGE_DURATION, auth_kwargs=kwargs, ) - response_poll = operation.execute() + response_poll = await operation.execute() - vid_response = requests.get(response_poll.assets.video) - return (VideoFromFile(BytesIO(vid_response.content)),) + async with aiohttp.ClientSession() as session: + async with session.get(response_poll.assets.video) as vid_response: + return (VideoFromFile(BytesIO(await vid_response.content.read())),) - def _convert_to_keyframes( + async def _convert_to_keyframes( self, first_image: torch.Tensor = None, last_image: torch.Tensor = None, @@ -703,12 +707,12 @@ class LumaImageToVideoGenerationNode(ComfyNodeABC): frame0 = None frame1 = None if first_image is not None: - download_urls = upload_images_to_comfyapi( + download_urls = await upload_images_to_comfyapi( first_image, max_images=1, auth_kwargs=auth_kwargs, ) frame0 = LumaImageReference(type="image", url=download_urls[0]) if last_image is not None: - download_urls = upload_images_to_comfyapi( + download_urls = await upload_images_to_comfyapi( last_image, max_images=1, auth_kwargs=auth_kwargs, ) frame1 = LumaImageReference(type="image", url=download_urls[0]) diff --git a/comfy_api_nodes/nodes_minimax.py b/comfy_api_nodes/nodes_minimax.py index 9b46636db..58d2ed90c 100644 --- a/comfy_api_nodes/nodes_minimax.py +++ b/comfy_api_nodes/nodes_minimax.py @@ -86,7 +86,7 @@ class MinimaxTextToVideoNode: API_NODE = True OUTPUT_NODE = True - def generate_video( + async def generate_video( self, prompt_text, seed=0, @@ -104,12 +104,12 @@ class MinimaxTextToVideoNode: # upload image, if passed in image_url = None if image is not None: - image_url = upload_images_to_comfyapi(image, max_images=1, auth_kwargs=kwargs)[0] + image_url = (await upload_images_to_comfyapi(image, max_images=1, auth_kwargs=kwargs))[0] # TODO: figure out how to deal with subject properly, API returns invalid params when using S2V-01 model subject_reference = None if subject is not None: - subject_url = upload_images_to_comfyapi(subject, max_images=1, auth_kwargs=kwargs)[0] + subject_url = (await upload_images_to_comfyapi(subject, max_images=1, auth_kwargs=kwargs))[0] subject_reference = [SubjectReferenceItem(image=subject_url)] @@ -130,7 +130,7 @@ class MinimaxTextToVideoNode: ), auth_kwargs=kwargs, ) - response = video_generate_operation.execute() + response = await video_generate_operation.execute() task_id = response.task_id if not task_id: @@ -151,7 +151,7 @@ class MinimaxTextToVideoNode: node_id=unique_id, auth_kwargs=kwargs, ) - task_result = video_generate_operation.execute() + task_result = await video_generate_operation.execute() file_id = task_result.file_id if file_id is None: @@ -167,7 +167,7 @@ class MinimaxTextToVideoNode: request=EmptyRequest(), auth_kwargs=kwargs, ) - file_result = file_retrieve_operation.execute() 
+ file_result = await file_retrieve_operation.execute() file_url = file_result.file.download_url if file_url is None: @@ -182,7 +182,7 @@ class MinimaxTextToVideoNode: message = f"Result URL: {file_url}" PromptServer.instance.send_progress_text(message, unique_id) - video_io = download_url_to_bytesio(file_url) + video_io = await download_url_to_bytesio(file_url) if video_io is None: error_msg = f"Failed to download video from {file_url}" logging.error(error_msg) diff --git a/comfy_api_nodes/nodes_moonvalley.py b/comfy_api_nodes/nodes_moonvalley.py index 789fcef02..164ca3ea5 100644 --- a/comfy_api_nodes/nodes_moonvalley.py +++ b/comfy_api_nodes/nodes_moonvalley.py @@ -95,14 +95,14 @@ def get_video_url_from_response(response) -> Optional[str]: return None -def poll_until_finished( +async def poll_until_finished( auth_kwargs: dict[str, str], api_endpoint: ApiEndpoint[Any, R], result_url_extractor: Optional[Callable[[R], str]] = None, node_id: Optional[str] = None, ) -> R: """Polls the Moonvalley API endpoint until the task reaches a terminal state, then returns the response.""" - return PollingOperation( + return await PollingOperation( poll_endpoint=api_endpoint, completed_statuses=[ "completed", @@ -394,10 +394,10 @@ class BaseMoonvalleyVideoNode: else: return control_map["Motion Transfer"] - def get_response( + async def get_response( self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None ) -> MoonvalleyPromptResponse: - return poll_until_finished( + return await poll_until_finished( auth_kwargs, ApiEndpoint( path=f"{API_PROMPTS_ENDPOINT}/{task_id}", @@ -507,7 +507,7 @@ class MoonvalleyImg2VideoNode(BaseMoonvalleyVideoNode): RETURN_NAMES = ("video",) DESCRIPTION = "Moonvalley Marey Image to Video Node" - def generate( + async def generate( self, prompt, negative_prompt, unique_id: Optional[str] = None, **kwargs ): image = kwargs.get("image", None) @@ -532,9 +532,9 @@ class MoonvalleyImg2VideoNode(BaseMoonvalleyVideoNode): # Get MIME type from tensor - assuming PNG format for image tensors mime_type = "image/png" - image_url = upload_images_to_comfyapi( + image_url = (await upload_images_to_comfyapi( image, max_images=1, auth_kwargs=kwargs, mime_type=mime_type - )[0] + ))[0] request = MoonvalleyTextToVideoRequest( image_url=image_url, prompt_text=prompt, inference_params=inference_params @@ -549,14 +549,14 @@ class MoonvalleyImg2VideoNode(BaseMoonvalleyVideoNode): request=request, auth_kwargs=kwargs, ) - task_creation_response = initial_operation.execute() + task_creation_response = await initial_operation.execute() validate_task_creation_response(task_creation_response) task_id = task_creation_response.id - final_response = self.get_response( + final_response = await self.get_response( task_id, auth_kwargs=kwargs, node_id=unique_id ) - video = download_url_to_video_output(final_response.output_url) + video = await download_url_to_video_output(final_response.output_url) return (video,) @@ -609,7 +609,7 @@ class MoonvalleyVideo2VideoNode(BaseMoonvalleyVideoNode): RETURN_TYPES = ("VIDEO",) RETURN_NAMES = ("video",) - def generate( + async def generate( self, prompt, negative_prompt, unique_id: Optional[str] = None, **kwargs ): video = kwargs.get("video") @@ -620,7 +620,7 @@ class MoonvalleyVideo2VideoNode(BaseMoonvalleyVideoNode): video_url = "" if video: validated_video = validate_video_to_video_input(video) - video_url = upload_video_to_comfyapi(validated_video, auth_kwargs=kwargs) + video_url = await upload_video_to_comfyapi(validated_video, auth_kwargs=kwargs) 
control_type = kwargs.get("control_type") motion_intensity = kwargs.get("motion_intensity") @@ -658,15 +658,15 @@ class MoonvalleyVideo2VideoNode(BaseMoonvalleyVideoNode): request=request, auth_kwargs=kwargs, ) - task_creation_response = initial_operation.execute() + task_creation_response = await initial_operation.execute() validate_task_creation_response(task_creation_response) task_id = task_creation_response.id - final_response = self.get_response( + final_response = await self.get_response( task_id, auth_kwargs=kwargs, node_id=unique_id ) - video = download_url_to_video_output(final_response.output_url) + video = await download_url_to_video_output(final_response.output_url) return (video,) @@ -688,7 +688,7 @@ class MoonvalleyTxt2VideoNode(BaseMoonvalleyVideoNode): del input_types["optional"][param] return input_types - def generate( + async def generate( self, prompt, negative_prompt, unique_id: Optional[str] = None, **kwargs ): validate_prompts(prompt, negative_prompt, MOONVALLEY_MAREY_MAX_PROMPT_LENGTH) @@ -717,15 +717,15 @@ class MoonvalleyTxt2VideoNode(BaseMoonvalleyVideoNode): request=request, auth_kwargs=kwargs, ) - task_creation_response = initial_operation.execute() + task_creation_response = await initial_operation.execute() validate_task_creation_response(task_creation_response) task_id = task_creation_response.id - final_response = self.get_response( + final_response = await self.get_response( task_id, auth_kwargs=kwargs, node_id=unique_id ) - video = download_url_to_video_output(final_response.output_url) + video = await download_url_to_video_output(final_response.output_url) return (video,) diff --git a/comfy_api_nodes/nodes_openai.py b/comfy_api_nodes/nodes_openai.py index be1d2de4a..ab3c5363b 100644 --- a/comfy_api_nodes/nodes_openai.py +++ b/comfy_api_nodes/nodes_openai.py @@ -163,7 +163,7 @@ class OpenAIDalle2(ComfyNodeABC): DESCRIPTION = cleandoc(__doc__ or "") API_NODE = True - def api_call( + async def api_call( self, prompt, seed=0, @@ -233,9 +233,9 @@ class OpenAIDalle2(ComfyNodeABC): auth_kwargs=kwargs, ) - response = operation.execute() + response = await operation.execute() - img_tensor = validate_and_cast_response(response, node_id=unique_id) + img_tensor = await validate_and_cast_response(response, node_id=unique_id) return (img_tensor,) @@ -311,7 +311,7 @@ class OpenAIDalle3(ComfyNodeABC): DESCRIPTION = cleandoc(__doc__ or "") API_NODE = True - def api_call( + async def api_call( self, prompt, seed=0, @@ -343,9 +343,9 @@ class OpenAIDalle3(ComfyNodeABC): auth_kwargs=kwargs, ) - response = operation.execute() + response = await operation.execute() - img_tensor = validate_and_cast_response(response, node_id=unique_id) + img_tensor = await validate_and_cast_response(response, node_id=unique_id) return (img_tensor,) @@ -446,7 +446,7 @@ class OpenAIGPTImage1(ComfyNodeABC): DESCRIPTION = cleandoc(__doc__ or "") API_NODE = True - def api_call( + async def api_call( self, prompt, seed=0, @@ -537,9 +537,9 @@ class OpenAIGPTImage1(ComfyNodeABC): auth_kwargs=kwargs, ) - response = operation.execute() + response = await operation.execute() - img_tensor = validate_and_cast_response(response, node_id=unique_id) + img_tensor = await validate_and_cast_response(response, node_id=unique_id) return (img_tensor,) @@ -623,7 +623,7 @@ class OpenAIChatNode(OpenAITextNode): DESCRIPTION = "Generate text responses from an OpenAI model." 
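The get_result_response method below is typical of the call sites that now await PollingOperation(...).execute(); since the operation sleeps with asyncio.sleep between attempts, polling no longer blocks the server's event loop. A minimal usage sketch, assuming the constructor keywords mirror the attributes and call sites visible in this diff (the status values and extractor are illustrative):

    async def wait_for_task(poll_endpoint, auth_kwargs, node_id=None):
        operation = PollingOperation(
            poll_endpoint=poll_endpoint,               # ApiEndpoint used for the status query
            completed_statuses=["completed"],          # illustrative terminal statuses
            failed_statuses=["failed", "cancelled"],
            status_extractor=lambda response: response.status,
            node_id=node_id,
            auth_kwargs=auth_kwargs,
        )
        return await operation.execute()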
- def get_result_response( + async def get_result_response( self, response_id: str, include: Optional[list[Includable]] = None, @@ -639,7 +639,7 @@ class OpenAIChatNode(OpenAITextNode): creation above for more information. """ - return PollingOperation( + return await PollingOperation( poll_endpoint=ApiEndpoint( path=f"{RESPONSES_ENDPOINT}/{response_id}", method=HttpMethod.GET, @@ -784,7 +784,7 @@ class OpenAIChatNode(OpenAITextNode): self.history[session_id] = new_history - def api_call( + async def api_call( self, prompt: str, persist_context: bool, @@ -815,7 +815,7 @@ class OpenAIChatNode(OpenAITextNode): previous_response_id = None # Create response - create_response = SynchronousOperation( + create_response = await SynchronousOperation( endpoint=ApiEndpoint( path=RESPONSES_ENDPOINT, method=HttpMethod.POST, @@ -848,7 +848,7 @@ class OpenAIChatNode(OpenAITextNode): response_id = create_response.id # Get result output - result_response = self.get_result_response(response_id, auth_kwargs=kwargs) + result_response = await self.get_result_response(response_id, auth_kwargs=kwargs) output_text = self.parse_output_text_from_response(result_response) # Update history diff --git a/comfy_api_nodes/nodes_pika.py b/comfy_api_nodes/nodes_pika.py index 1cc708564..a8dc43cb3 100644 --- a/comfy_api_nodes/nodes_pika.py +++ b/comfy_api_nodes/nodes_pika.py @@ -122,7 +122,7 @@ class PikaNodeBase(ComfyNodeABC): FUNCTION = "api_call" RETURN_TYPES = ("VIDEO",) - def poll_for_task_status( + async def poll_for_task_status( self, task_id: str, auth_kwargs: Optional[dict[str, str]] = None, @@ -152,9 +152,9 @@ class PikaNodeBase(ComfyNodeABC): node_id=node_id, estimated_duration=60 ) - return polling_operation.execute() + return await polling_operation.execute() - def execute_task( + async def execute_task( self, initial_operation: SynchronousOperation[R, PikaGenerateResponse], auth_kwargs: Optional[dict[str, str]] = None, @@ -169,14 +169,14 @@ class PikaNodeBase(ComfyNodeABC): Returns: A tuple containing the video file as a VIDEO output. """ - initial_response = initial_operation.execute() + initial_response = await initial_operation.execute() if not is_valid_initial_response(initial_response): error_msg = f"Pika initial request failed. Code: {initial_response.code}, Message: {initial_response.message}, Data: {initial_response.data}" logging.error(error_msg) raise PikaApiError(error_msg) task_id = initial_response.video_id - final_response = self.poll_for_task_status(task_id, auth_kwargs) + final_response = await self.poll_for_task_status(task_id, auth_kwargs) if not is_valid_video_response(final_response): error_msg = ( f"Pika task {task_id} succeeded but no video data found in response." @@ -187,7 +187,7 @@ class PikaNodeBase(ComfyNodeABC): video_url = str(final_response.url) logging.info("Pika task %s succeeded. Video URL: %s", task_id, video_url) - return (download_url_to_video_output(video_url),) + return (await download_url_to_video_output(video_url),) class PikaImageToVideoV2_2(PikaNodeBase): @@ -212,7 +212,7 @@ class PikaImageToVideoV2_2(PikaNodeBase): DESCRIPTION = "Sends an image and prompt to the Pika API v2.2 to generate a video." 
- def api_call( + async def api_call( self, image: torch.Tensor, prompt_text: str, @@ -251,7 +251,7 @@ class PikaImageToVideoV2_2(PikaNodeBase): auth_kwargs=kwargs, ) - return self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id) + return await self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id) class PikaTextToVideoNodeV2_2(PikaNodeBase): @@ -281,7 +281,7 @@ class PikaTextToVideoNodeV2_2(PikaNodeBase): DESCRIPTION = "Sends a text prompt to the Pika API v2.2 to generate a video." - def api_call( + async def api_call( self, prompt_text: str, negative_prompt: str, @@ -311,7 +311,7 @@ class PikaTextToVideoNodeV2_2(PikaNodeBase): content_type="application/x-www-form-urlencoded", ) - return self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id) + return await self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id) class PikaScenesV2_2(PikaNodeBase): @@ -361,7 +361,7 @@ class PikaScenesV2_2(PikaNodeBase): DESCRIPTION = "Combine your images to create a video with the objects in them. Upload multiple images as ingredients and generate a high-quality video that incorporates all of them." - def api_call( + async def api_call( self, prompt_text: str, negative_prompt: str, @@ -420,7 +420,7 @@ class PikaScenesV2_2(PikaNodeBase): auth_kwargs=kwargs, ) - return self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id) + return await self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id) class PikAdditionsNode(PikaNodeBase): @@ -462,7 +462,7 @@ class PikAdditionsNode(PikaNodeBase): DESCRIPTION = "Add any object or image into your video. Upload a video and specify what you'd like to add to create a seamlessly integrated result." - def api_call( + async def api_call( self, video: VideoInput, image: torch.Tensor, @@ -481,10 +481,10 @@ class PikAdditionsNode(PikaNodeBase): image_bytes_io = tensor_to_bytesio(image) image_bytes_io.seek(0) - pika_files = [ - ("video", ("video.mp4", video_bytes_io, "video/mp4")), - ("image", ("image.png", image_bytes_io, "image/png")), - ] + pika_files = { + "video": ("video.mp4", video_bytes_io, "video/mp4"), + "image": ("image.png", image_bytes_io, "image/png"), + } # Prepare non-file data pika_request_data = PikaBodyGeneratePikadditionsGeneratePikadditionsPost( @@ -506,7 +506,7 @@ class PikAdditionsNode(PikaNodeBase): auth_kwargs=kwargs, ) - return self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id) + return await self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id) class PikaSwapsNode(PikaNodeBase): @@ -558,7 +558,7 @@ class PikaSwapsNode(PikaNodeBase): DESCRIPTION = "Swap out any object or region of your video with a new image or object. Define areas to replace either with a mask or coordinates." 
RETURN_TYPES = ("VIDEO",) - def api_call( + async def api_call( self, video: VideoInput, image: torch.Tensor, @@ -587,11 +587,11 @@ class PikaSwapsNode(PikaNodeBase): image_bytes_io = tensor_to_bytesio(image) image_bytes_io.seek(0) - pika_files = [ - ("video", ("video.mp4", video_bytes_io, "video/mp4")), - ("image", ("image.png", image_bytes_io, "image/png")), - ("modifyRegionMask", ("mask.png", mask_bytes_io, "image/png")), - ] + pika_files = { + "video": ("video.mp4", video_bytes_io, "video/mp4"), + "image": ("image.png", image_bytes_io, "image/png"), + "modifyRegionMask": ("mask.png", mask_bytes_io, "image/png"), + } # Prepare non-file data pika_request_data = PikaBodyGeneratePikaswapsGeneratePikaswapsPost( @@ -613,7 +613,7 @@ class PikaSwapsNode(PikaNodeBase): auth_kwargs=kwargs, ) - return self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id) + return await self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id) class PikaffectsNode(PikaNodeBase): @@ -664,7 +664,7 @@ class PikaffectsNode(PikaNodeBase): DESCRIPTION = "Generate a video with a specific Pikaffect. Supported Pikaffects: Cake-ify, Crumble, Crush, Decapitate, Deflate, Dissolve, Explode, Eye-pop, Inflate, Levitate, Melt, Peel, Poke, Squish, Ta-da, Tear" - def api_call( + async def api_call( self, image: torch.Tensor, pikaffect: str, @@ -693,7 +693,7 @@ class PikaffectsNode(PikaNodeBase): auth_kwargs=kwargs, ) - return self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id) + return await self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id) class PikaStartEndFrameNode2_2(PikaNodeBase): @@ -718,7 +718,7 @@ class PikaStartEndFrameNode2_2(PikaNodeBase): DESCRIPTION = "Generate a video by combining your first and last frame. Upload two images to define the start and end points, and let the AI create a smooth transition between them." 
- def api_call( + async def api_call( self, image_start: torch.Tensor, image_end: torch.Tensor, @@ -732,10 +732,7 @@ class PikaStartEndFrameNode2_2(PikaNodeBase): ) -> tuple[VideoFromFile]: pika_files = [ - ( - "keyFrames", - ("image_start.png", tensor_to_bytesio(image_start), "image/png"), - ), + ("keyFrames", ("image_start.png", tensor_to_bytesio(image_start), "image/png")), ("keyFrames", ("image_end.png", tensor_to_bytesio(image_end), "image/png")), ] @@ -758,7 +755,7 @@ class PikaStartEndFrameNode2_2(PikaNodeBase): auth_kwargs=kwargs, ) - return self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id) + return await self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id) NODE_CLASS_MAPPINGS = { diff --git a/comfy_api_nodes/nodes_pixverse.py b/comfy_api_nodes/nodes_pixverse.py index ef4a9a802..7c5a52feb 100644 --- a/comfy_api_nodes/nodes_pixverse.py +++ b/comfy_api_nodes/nodes_pixverse.py @@ -30,7 +30,7 @@ from comfy.comfy_types.node_typing import IO, ComfyNodeABC from comfy_api.input_impl import VideoFromFile import torch -import requests +import aiohttp from io import BytesIO @@ -47,7 +47,7 @@ def get_video_url_from_response( return str(response.Resp.url) -def upload_image_to_pixverse(image: torch.Tensor, auth_kwargs=None): +async def upload_image_to_pixverse(image: torch.Tensor, auth_kwargs=None): # first, upload image to Pixverse and get image id to use in actual generation call files = {"image": tensor_to_bytesio(image)} operation = SynchronousOperation( @@ -62,7 +62,7 @@ def upload_image_to_pixverse(image: torch.Tensor, auth_kwargs=None): content_type="multipart/form-data", auth_kwargs=auth_kwargs, ) - response_upload: PixverseImageUploadResponse = operation.execute() + response_upload: PixverseImageUploadResponse = await operation.execute() if response_upload.Resp is None: raise Exception( @@ -164,7 +164,7 @@ class PixverseTextToVideoNode(ComfyNodeABC): }, } - def api_call( + async def api_call( self, prompt: str, aspect_ratio: str, @@ -205,7 +205,7 @@ class PixverseTextToVideoNode(ComfyNodeABC): ), auth_kwargs=kwargs, ) - response_api = operation.execute() + response_api = await operation.execute() if response_api.Resp is None: raise Exception(f"PixVerse request failed: '{response_api.ErrMsg}'") @@ -229,11 +229,11 @@ class PixverseTextToVideoNode(ComfyNodeABC): result_url_extractor=get_video_url_from_response, estimated_duration=AVERAGE_DURATION_T2V, ) - response_poll = operation.execute() + response_poll = await operation.execute() - vid_response = requests.get(response_poll.Resp.url) - - return (VideoFromFile(BytesIO(vid_response.content)),) + async with aiohttp.ClientSession() as session: + async with session.get(response_poll.Resp.url) as vid_response: + return (VideoFromFile(BytesIO(await vid_response.content.read())),) class PixverseImageToVideoNode(ComfyNodeABC): @@ -302,7 +302,7 @@ class PixverseImageToVideoNode(ComfyNodeABC): }, } - def api_call( + async def api_call( self, image: torch.Tensor, prompt: str, @@ -316,7 +316,7 @@ class PixverseImageToVideoNode(ComfyNodeABC): **kwargs, ): validate_string(prompt, strip_whitespace=False) - img_id = upload_image_to_pixverse(image, auth_kwargs=kwargs) + img_id = await upload_image_to_pixverse(image, auth_kwargs=kwargs) # 1080p is limited to 5 seconds duration # only normal motion_mode supported for 1080p or for non-5 second duration @@ -345,7 +345,7 @@ class PixverseImageToVideoNode(ComfyNodeABC): ), auth_kwargs=kwargs, ) - response_api = operation.execute() + response_api = await 
operation.execute() if response_api.Resp is None: raise Exception(f"PixVerse request failed: '{response_api.ErrMsg}'") @@ -369,10 +369,11 @@ class PixverseImageToVideoNode(ComfyNodeABC): result_url_extractor=get_video_url_from_response, estimated_duration=AVERAGE_DURATION_I2V, ) - response_poll = operation.execute() + response_poll = await operation.execute() - vid_response = requests.get(response_poll.Resp.url) - return (VideoFromFile(BytesIO(vid_response.content)),) + async with aiohttp.ClientSession() as session: + async with session.get(response_poll.Resp.url) as vid_response: + return (VideoFromFile(BytesIO(await vid_response.content.read())),) class PixverseTransitionVideoNode(ComfyNodeABC): @@ -436,7 +437,7 @@ class PixverseTransitionVideoNode(ComfyNodeABC): }, } - def api_call( + async def api_call( self, first_frame: torch.Tensor, last_frame: torch.Tensor, @@ -450,8 +451,8 @@ class PixverseTransitionVideoNode(ComfyNodeABC): **kwargs, ): validate_string(prompt, strip_whitespace=False) - first_frame_id = upload_image_to_pixverse(first_frame, auth_kwargs=kwargs) - last_frame_id = upload_image_to_pixverse(last_frame, auth_kwargs=kwargs) + first_frame_id = await upload_image_to_pixverse(first_frame, auth_kwargs=kwargs) + last_frame_id = await upload_image_to_pixverse(last_frame, auth_kwargs=kwargs) # 1080p is limited to 5 seconds duration # only normal motion_mode supported for 1080p or for non-5 second duration @@ -480,7 +481,7 @@ class PixverseTransitionVideoNode(ComfyNodeABC): ), auth_kwargs=kwargs, ) - response_api = operation.execute() + response_api = await operation.execute() if response_api.Resp is None: raise Exception(f"PixVerse request failed: '{response_api.ErrMsg}'") @@ -504,10 +505,11 @@ class PixverseTransitionVideoNode(ComfyNodeABC): result_url_extractor=get_video_url_from_response, estimated_duration=AVERAGE_DURATION_T2V, ) - response_poll = operation.execute() + response_poll = await operation.execute() - vid_response = requests.get(response_poll.Resp.url) - return (VideoFromFile(BytesIO(vid_response.content)),) + async with aiohttp.ClientSession() as session: + async with session.get(response_poll.Resp.url) as vid_response: + return (VideoFromFile(BytesIO(await vid_response.content.read())),) NODE_CLASS_MAPPINGS = { diff --git a/comfy_api_nodes/nodes_recraft.py b/comfy_api_nodes/nodes_recraft.py index e369c4b7e..c8516b368 100644 --- a/comfy_api_nodes/nodes_recraft.py +++ b/comfy_api_nodes/nodes_recraft.py @@ -37,7 +37,7 @@ from io import BytesIO from PIL import UnidentifiedImageError -def handle_recraft_file_request( +async def handle_recraft_file_request( image: torch.Tensor, path: str, mask: torch.Tensor=None, @@ -71,13 +71,13 @@ def handle_recraft_file_request( auth_kwargs=auth_kwargs, multipart_parser=recraft_multipart_parser, ) - response: RecraftImageGenerationResponse = operation.execute() + response: RecraftImageGenerationResponse = await operation.execute() all_bytesio = [] if response.image is not None: - all_bytesio.append(download_url_to_bytesio(response.image.url, timeout=timeout)) + all_bytesio.append(await download_url_to_bytesio(response.image.url, timeout=timeout)) else: for data in response.data: - all_bytesio.append(download_url_to_bytesio(data.url, timeout=timeout)) + all_bytesio.append(await download_url_to_bytesio(data.url, timeout=timeout)) return all_bytesio @@ -395,7 +395,7 @@ class RecraftTextToImageNode: }, } - def api_call( + async def api_call( self, prompt: str, size: str, @@ -439,7 +439,7 @@ class RecraftTextToImageNode: ), 
auth_kwargs=kwargs, ) - response: RecraftImageGenerationResponse = operation.execute() + response: RecraftImageGenerationResponse = await operation.execute() images = [] urls = [] for data in response.data: @@ -451,7 +451,7 @@ class RecraftTextToImageNode: f"Result URL: {urls_string}", unique_id ) image = bytesio_to_image_tensor( - download_url_to_bytesio(data.url, timeout=1024) + await download_url_to_bytesio(data.url, timeout=1024) ) if len(image.shape) < 4: image = image.unsqueeze(0) @@ -538,7 +538,7 @@ class RecraftImageToImageNode: }, } - def api_call( + async def api_call( self, image: torch.Tensor, prompt: str, @@ -578,7 +578,7 @@ class RecraftImageToImageNode: total = image.shape[0] pbar = ProgressBar(total) for i in range(total): - sub_bytes = handle_recraft_file_request( + sub_bytes = await handle_recraft_file_request( image=image[i], path="/proxy/recraft/images/imageToImage", request=request, @@ -654,7 +654,7 @@ class RecraftImageInpaintingNode: }, } - def api_call( + async def api_call( self, image: torch.Tensor, mask: torch.Tensor, @@ -690,7 +690,7 @@ class RecraftImageInpaintingNode: total = image.shape[0] pbar = ProgressBar(total) for i in range(total): - sub_bytes = handle_recraft_file_request( + sub_bytes = await handle_recraft_file_request( image=image[i], mask=mask[i:i+1], path="/proxy/recraft/images/inpaint", @@ -779,7 +779,7 @@ class RecraftTextToVectorNode: }, } - def api_call( + async def api_call( self, prompt: str, substyle: str, @@ -821,7 +821,7 @@ class RecraftTextToVectorNode: ), auth_kwargs=kwargs, ) - response: RecraftImageGenerationResponse = operation.execute() + response: RecraftImageGenerationResponse = await operation.execute() svg_data = [] urls = [] for data in response.data: @@ -831,7 +831,7 @@ class RecraftTextToVectorNode: PromptServer.instance.send_progress_text( f"Result URL: {' '.join(urls)}", unique_id ) - svg_data.append(download_url_to_bytesio(data.url, timeout=1024)) + svg_data.append(await download_url_to_bytesio(data.url, timeout=1024)) return (SVG(svg_data),) @@ -861,7 +861,7 @@ class RecraftVectorizeImageNode: }, } - def api_call( + async def api_call( self, image: torch.Tensor, **kwargs, @@ -870,7 +870,7 @@ class RecraftVectorizeImageNode: total = image.shape[0] pbar = ProgressBar(total) for i in range(total): - sub_bytes = handle_recraft_file_request( + sub_bytes = await handle_recraft_file_request( image=image[i], path="/proxy/recraft/images/vectorize", auth_kwargs=kwargs, @@ -942,7 +942,7 @@ class RecraftReplaceBackgroundNode: }, } - def api_call( + async def api_call( self, image: torch.Tensor, prompt: str, @@ -973,7 +973,7 @@ class RecraftReplaceBackgroundNode: total = image.shape[0] pbar = ProgressBar(total) for i in range(total): - sub_bytes = handle_recraft_file_request( + sub_bytes = await handle_recraft_file_request( image=image[i], path="/proxy/recraft/images/replaceBackground", request=request, @@ -1011,7 +1011,7 @@ class RecraftRemoveBackgroundNode: }, } - def api_call( + async def api_call( self, image: torch.Tensor, **kwargs, @@ -1020,7 +1020,7 @@ class RecraftRemoveBackgroundNode: total = image.shape[0] pbar = ProgressBar(total) for i in range(total): - sub_bytes = handle_recraft_file_request( + sub_bytes = await handle_recraft_file_request( image=image[i], path="/proxy/recraft/images/removeBackground", auth_kwargs=kwargs, @@ -1062,7 +1062,7 @@ class RecraftCrispUpscaleNode: }, } - def api_call( + async def api_call( self, image: torch.Tensor, **kwargs, @@ -1071,7 +1071,7 @@ class RecraftCrispUpscaleNode: total = 
image.shape[0] pbar = ProgressBar(total) for i in range(total): - sub_bytes = handle_recraft_file_request( + sub_bytes = await handle_recraft_file_request( image=image[i], path=self.RECRAFT_PATH, auth_kwargs=kwargs, diff --git a/comfy_api_nodes/nodes_rodin.py b/comfy_api_nodes/nodes_rodin.py index 67f90478c..c89d087e5 100644 --- a/comfy_api_nodes/nodes_rodin.py +++ b/comfy_api_nodes/nodes_rodin.py @@ -9,11 +9,10 @@ from __future__ import annotations from inspect import cleandoc from comfy.comfy_types.node_typing import IO import folder_paths as comfy_paths -import requests +import aiohttp import os import datetime -import shutil -import time +import asyncio import io import logging import math @@ -66,7 +65,6 @@ def create_task_error(response: Rodin3DGenerateResponse): return hasattr(response, "error") - class Rodin3DAPI: """ Generate 3D Assets using Rodin API @@ -123,8 +121,8 @@ class Rodin3DAPI: else: return "Generating" - def CreateGenerateTask(self, images=None, seed=1, material="PBR", quality="medium", tier="Regular", mesh_mode="Quad", **kwargs): - if images == None: + async def create_generate_task(self, images=None, seed=1, material="PBR", quality="medium", tier="Regular", mesh_mode="Quad", **kwargs): + if images is None: raise Exception("Rodin 3D generate requires at least 1 image.") if len(images) >= 5: raise Exception("Rodin 3D generate requires up to 5 image.") @@ -155,7 +153,7 @@ class Rodin3DAPI: auth_kwargs=kwargs, ) - response = operation.execute() + response = await operation.execute() if create_task_error(response): error_message = f"Rodin3D Create 3D generate Task Failed. Message: {response.message}, error: {response.error}" @@ -168,7 +166,7 @@ class Rodin3DAPI: logging.info(f"[ Rodin3D API - Submit Jobs ] UUID: {task_uuid}") return task_uuid, subscription_key - def poll_for_task_status(self, subscription_key, **kwargs) -> Rodin3DCheckStatusResponse: + async def poll_for_task_status(self, subscription_key, **kwargs) -> Rodin3DCheckStatusResponse: path = "/proxy/rodin/api/v2/status" @@ -191,11 +189,9 @@ class Rodin3DAPI: logging.info("[ Rodin3D API - CheckStatus ] Generate Start!") - return poll_operation.execute() + return await poll_operation.execute() - - - def GetRodinDownloadList(self, uuid, **kwargs) -> Rodin3DDownloadResponse: + async def get_rodin_download_list(self, uuid, **kwargs) -> Rodin3DDownloadResponse: logging.info("[ Rodin3D API - Downloading ] Generate Successfully!") path = "/proxy/rodin/api/v2/download" @@ -212,53 +208,59 @@ class Rodin3DAPI: auth_kwargs=kwargs ) - return operation.execute() + return await operation.execute() - def GetQualityAndMode(self, PolyCount): - if PolyCount == "200K-Triangle": + def get_quality_mode(self, poly_count): + if poly_count == "200K-Triangle": mesh_mode = "Raw" quality = "medium" else: mesh_mode = "Quad" - if PolyCount == "4K-Quad": + if poly_count == "4K-Quad": quality = "extra-low" - elif PolyCount == "8K-Quad": + elif poly_count == "8K-Quad": quality = "low" - elif PolyCount == "18K-Quad": + elif poly_count == "18K-Quad": quality = "medium" - elif PolyCount == "50K-Quad": + elif poly_count == "50K-Quad": quality = "high" else: quality = "medium" return mesh_mode, quality - def DownLoadFiles(self, Url_List): - Save_path = os.path.join(comfy_paths.get_output_directory(), "Rodin3D", datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")) - os.makedirs(Save_path, exist_ok=True) + async def download_files(self, url_list): + save_path = os.path.join(comfy_paths.get_output_directory(), "Rodin3D", 
datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")) + os.makedirs(save_path, exist_ok=True) model_file_path = None - for Item in Url_List.list: - url = Item.url - file_name = Item.name - file_path = os.path.join(Save_path, file_name) - if file_path.endswith(".glb"): - model_file_path = file_path - logging.info(f"[ Rodin3D API - download_files ] Downloading file: {file_path}") - max_retries = 5 - for attempt in range(max_retries): - try: - with requests.get(url, stream=True) as r: - r.raise_for_status() - with open(file_path, "wb") as f: - shutil.copyfileobj(r.raw, f) - break - except Exception as e: - logging.info(f"[ Rodin3D API - download_files ] Error downloading {file_path}:{e}") - if attempt < max_retries - 1: - logging.info("Retrying...") - time.sleep(2) - else: - logging.info(f"[ Rodin3D API - download_files ] Failed to download {file_path} after {max_retries} attempts.") + async with aiohttp.ClientSession() as session: + for i in url_list.list: + url = i.url + file_name = i.name + file_path = os.path.join(save_path, file_name) + if file_path.endswith(".glb"): + model_file_path = file_path + logging.info(f"[ Rodin3D API - download_files ] Downloading file: {file_path}") + max_retries = 5 + for attempt in range(max_retries): + try: + async with session.get(url) as resp: + resp.raise_for_status() + with open(file_path, "wb") as f: + async for chunk in resp.content.iter_chunked(32 * 1024): + f.write(chunk) + break + except Exception as e: + logging.info(f"[ Rodin3D API - download_files ] Error downloading {file_path}:{e}") + if attempt < max_retries - 1: + logging.info("Retrying...") + await asyncio.sleep(2) + else: + logging.info( + "[ Rodin3D API - download_files ] Failed to download %s after %s attempts.", + file_path, + max_retries, + ) return model_file_path @@ -285,7 +287,7 @@ class Rodin3D_Regular(Rodin3DAPI): }, } - def api_call( + async def api_call( self, Images, Seed, @@ -298,14 +300,17 @@ class Rodin3D_Regular(Rodin3DAPI): m_images = [] for i in range(num_images): m_images.append(Images[i]) - mesh_mode, quality = self.GetQualityAndMode(Polygon_count) - task_uuid, subscription_key = self.CreateGenerateTask(images=m_images, seed=Seed, material=Material_Type, quality=quality, tier=tier, mesh_mode=mesh_mode, **kwargs) - self.poll_for_task_status(subscription_key, **kwargs) - Download_List = self.GetRodinDownloadList(task_uuid, **kwargs) - model = self.DownLoadFiles(Download_List) + mesh_mode, quality = self.get_quality_mode(Polygon_count) + task_uuid, subscription_key = await self.create_generate_task(images=m_images, seed=Seed, material=Material_Type, + quality=quality, tier=tier, mesh_mode=mesh_mode, + **kwargs) + await self.poll_for_task_status(subscription_key, **kwargs) + download_list = await self.get_rodin_download_list(task_uuid, **kwargs) + model = await self.download_files(download_list) return (model,) + class Rodin3D_Detail(Rodin3DAPI): @classmethod def INPUT_TYPES(s): @@ -328,7 +333,7 @@ class Rodin3D_Detail(Rodin3DAPI): }, } - def api_call( + async def api_call( self, Images, Seed, @@ -341,14 +346,17 @@ class Rodin3D_Detail(Rodin3DAPI): m_images = [] for i in range(num_images): m_images.append(Images[i]) - mesh_mode, quality = self.GetQualityAndMode(Polygon_count) - task_uuid, subscription_key = self.CreateGenerateTask(images=m_images, seed=Seed, material=Material_Type, quality=quality, tier=tier, mesh_mode=mesh_mode, **kwargs) - self.poll_for_task_status(subscription_key, **kwargs) - Download_List = self.GetRodinDownloadList(task_uuid, **kwargs) - model = 
self.DownLoadFiles(Download_List) + mesh_mode, quality = self.get_quality_mode(Polygon_count) + task_uuid, subscription_key = await self.create_generate_task(images=m_images, seed=Seed, material=Material_Type, + quality=quality, tier=tier, mesh_mode=mesh_mode, + **kwargs) + await self.poll_for_task_status(subscription_key, **kwargs) + download_list = await self.get_rodin_download_list(task_uuid, **kwargs) + model = await self.download_files(download_list) return (model,) + class Rodin3D_Smooth(Rodin3DAPI): @classmethod def INPUT_TYPES(s): @@ -371,7 +379,7 @@ class Rodin3D_Smooth(Rodin3DAPI): }, } - def api_call( + async def api_call( self, Images, Seed, @@ -384,14 +392,17 @@ class Rodin3D_Smooth(Rodin3DAPI): m_images = [] for i in range(num_images): m_images.append(Images[i]) - mesh_mode, quality = self.GetQualityAndMode(Polygon_count) - task_uuid, subscription_key = self.CreateGenerateTask(images=m_images, seed=Seed, material=Material_Type, quality=quality, tier=tier, mesh_mode=mesh_mode, **kwargs) - self.poll_for_task_status(subscription_key, **kwargs) - Download_List = self.GetRodinDownloadList(task_uuid, **kwargs) - model = self.DownLoadFiles(Download_List) + mesh_mode, quality = self.get_quality_mode(Polygon_count) + task_uuid, subscription_key = await self.create_generate_task(images=m_images, seed=Seed, material=Material_Type, + quality=quality, tier=tier, mesh_mode=mesh_mode, + **kwargs) + await self.poll_for_task_status(subscription_key, **kwargs) + download_list = await self.get_rodin_download_list(task_uuid, **kwargs) + model = await self.download_files(download_list) return (model,) + class Rodin3D_Sketch(Rodin3DAPI): @classmethod def INPUT_TYPES(s): @@ -423,7 +434,7 @@ class Rodin3D_Sketch(Rodin3DAPI): }, } - def api_call( + async def api_call( self, Images, Seed, @@ -437,10 +448,12 @@ class Rodin3D_Sketch(Rodin3DAPI): material_type = "PBR" quality = "medium" mesh_mode = "Quad" - task_uuid, subscription_key = self.CreateGenerateTask(images=m_images, seed=Seed, material=material_type, quality=quality, tier=tier, mesh_mode=mesh_mode, **kwargs) - self.poll_for_task_status(subscription_key, **kwargs) - Download_List = self.GetRodinDownloadList(task_uuid, **kwargs) - model = self.DownLoadFiles(Download_List) + task_uuid, subscription_key = await self.create_generate_task( + images=m_images, seed=Seed, material=material_type, quality=quality, tier=tier, mesh_mode=mesh_mode, **kwargs + ) + await self.poll_for_task_status(subscription_key, **kwargs) + download_list = await self.get_rodin_download_list(task_uuid, **kwargs) + model = await self.download_files(download_list) return (model,) diff --git a/comfy_api_nodes/nodes_runway.py b/comfy_api_nodes/nodes_runway.py index af4b321f9..98024a9fa 100644 --- a/comfy_api_nodes/nodes_runway.py +++ b/comfy_api_nodes/nodes_runway.py @@ -99,14 +99,14 @@ def validate_input_image(image: torch.Tensor) -> bool: return image.shape[2] < 8000 and image.shape[1] < 8000 -def poll_until_finished( +async def poll_until_finished( auth_kwargs: dict[str, str], api_endpoint: ApiEndpoint[Any, TaskStatusResponse], estimated_duration: Optional[int] = None, node_id: Optional[str] = None, ) -> TaskStatusResponse: """Polls the Runway API endpoint until the task reaches a terminal state, then returns the response.""" - return PollingOperation( + return await PollingOperation( poll_endpoint=api_endpoint, completed_statuses=[ TaskStatus.SUCCEEDED.value, @@ -115,7 +115,7 @@ def poll_until_finished( TaskStatus.FAILED.value, TaskStatus.CANCELLED.value, ], - 
status_extractor=lambda response: (response.status.value), + status_extractor=lambda response: response.status.value, auth_kwargs=auth_kwargs, result_url_extractor=get_video_url_from_task_status, estimated_duration=estimated_duration, @@ -167,11 +167,11 @@ class RunwayVideoGenNode(ComfyNodeABC): ) return True - def get_response( + async def get_response( self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None ) -> RunwayImageToVideoResponse: """Poll the task status until it is finished then get the response.""" - return poll_until_finished( + return await poll_until_finished( auth_kwargs, ApiEndpoint( path=f"{PATH_GET_TASK_STATUS}/{task_id}", @@ -183,7 +183,7 @@ class RunwayVideoGenNode(ComfyNodeABC): node_id=node_id, ) - def generate_video( + async def generate_video( self, request: RunwayImageToVideoRequest, auth_kwargs: dict[str, str], @@ -200,15 +200,15 @@ class RunwayVideoGenNode(ComfyNodeABC): auth_kwargs=auth_kwargs, ) - initial_response = initial_operation.execute() + initial_response = await initial_operation.execute() self.validate_task_created(initial_response) task_id = initial_response.id - final_response = self.get_response(task_id, auth_kwargs, node_id) + final_response = await self.get_response(task_id, auth_kwargs, node_id) self.validate_response(final_response) video_url = get_video_url_from_task_status(final_response) - return (download_url_to_video_output(video_url),) + return (await download_url_to_video_output(video_url),) class RunwayImageToVideoNodeGen3a(RunwayVideoGenNode): @@ -250,7 +250,7 @@ class RunwayImageToVideoNodeGen3a(RunwayVideoGenNode): }, } - def api_call( + async def api_call( self, prompt: str, start_frame: torch.Tensor, @@ -265,7 +265,7 @@ class RunwayImageToVideoNodeGen3a(RunwayVideoGenNode): validate_input_image(start_frame) # Upload image - download_urls = upload_images_to_comfyapi( + download_urls = await upload_images_to_comfyapi( start_frame, max_images=1, mime_type="image/png", @@ -274,7 +274,7 @@ class RunwayImageToVideoNodeGen3a(RunwayVideoGenNode): if len(download_urls) != 1: raise RunwayApiError("Failed to upload one or more images to comfy api.") - return self.generate_video( + return await self.generate_video( RunwayImageToVideoRequest( promptText=prompt, seed=seed, @@ -333,7 +333,7 @@ class RunwayImageToVideoNodeGen4(RunwayVideoGenNode): }, } - def api_call( + async def api_call( self, prompt: str, start_frame: torch.Tensor, @@ -348,7 +348,7 @@ class RunwayImageToVideoNodeGen4(RunwayVideoGenNode): validate_input_image(start_frame) # Upload image - download_urls = upload_images_to_comfyapi( + download_urls = await upload_images_to_comfyapi( start_frame, max_images=1, mime_type="image/png", @@ -357,7 +357,7 @@ class RunwayImageToVideoNodeGen4(RunwayVideoGenNode): if len(download_urls) != 1: raise RunwayApiError("Failed to upload one or more images to comfy api.") - return self.generate_video( + return await self.generate_video( RunwayImageToVideoRequest( promptText=prompt, seed=seed, @@ -382,10 +382,10 @@ class RunwayFirstLastFrameNode(RunwayVideoGenNode): DESCRIPTION = "Upload first and last keyframes, draft a prompt, and generate a video. More complex transitions, such as cases where the Last frame is completely different from the First frame, may benefit from the longer 10s duration. This would give the generation more time to smoothly transition between the two inputs. 
Before diving in, review these best practices to ensure that your input selections will set your generation up for success: https://help.runwayml.com/hc/en-us/articles/34170748696595-Creating-with-Keyframes-on-Gen-3." - def get_response( + async def get_response( self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None ) -> RunwayImageToVideoResponse: - return poll_until_finished( + return await poll_until_finished( auth_kwargs, ApiEndpoint( path=f"{PATH_GET_TASK_STATUS}/{task_id}", @@ -437,7 +437,7 @@ class RunwayFirstLastFrameNode(RunwayVideoGenNode): }, } - def api_call( + async def api_call( self, prompt: str, start_frame: torch.Tensor, @@ -455,7 +455,7 @@ class RunwayFirstLastFrameNode(RunwayVideoGenNode): # Upload images stacked_input_images = image_tensor_pair_to_batch(start_frame, end_frame) - download_urls = upload_images_to_comfyapi( + download_urls = await upload_images_to_comfyapi( stacked_input_images, max_images=2, mime_type="image/png", @@ -464,7 +464,7 @@ class RunwayFirstLastFrameNode(RunwayVideoGenNode): if len(download_urls) != 2: raise RunwayApiError("Failed to upload one or more images to comfy api.") - return self.generate_video( + return await self.generate_video( RunwayImageToVideoRequest( promptText=prompt, seed=seed, @@ -543,11 +543,11 @@ class RunwayTextToImageNode(ComfyNodeABC): ) return True - def get_response( + async def get_response( self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None ) -> TaskStatusResponse: """Poll the task status until it is finished then get the response.""" - return poll_until_finished( + return await poll_until_finished( auth_kwargs, ApiEndpoint( path=f"{PATH_GET_TASK_STATUS}/{task_id}", @@ -559,7 +559,7 @@ class RunwayTextToImageNode(ComfyNodeABC): node_id=node_id, ) - def api_call( + async def api_call( self, prompt: str, ratio: str, @@ -574,7 +574,7 @@ class RunwayTextToImageNode(ComfyNodeABC): reference_images = None if reference_image is not None: validate_input_image(reference_image) - download_urls = upload_images_to_comfyapi( + download_urls = await upload_images_to_comfyapi( reference_image, max_images=1, mime_type="image/png", @@ -605,19 +605,19 @@ class RunwayTextToImageNode(ComfyNodeABC): auth_kwargs=kwargs, ) - initial_response = initial_operation.execute() + initial_response = await initial_operation.execute() self.validate_task_created(initial_response) task_id = initial_response.id # Poll for completion - final_response = self.get_response( + final_response = await self.get_response( task_id, auth_kwargs=kwargs, node_id=unique_id ) self.validate_response(final_response) # Download and return image image_url = get_image_url_from_task_status(final_response) - return (download_url_to_image_tensor(image_url),) + return (await download_url_to_image_tensor(image_url),) NODE_CLASS_MAPPINGS = { diff --git a/comfy_api_nodes/nodes_stability.py b/comfy_api_nodes/nodes_stability.py index 02e421678..31309d831 100644 --- a/comfy_api_nodes/nodes_stability.py +++ b/comfy_api_nodes/nodes_stability.py @@ -124,7 +124,7 @@ class StabilityStableImageUltraNode: }, } - def api_call(self, prompt: str, aspect_ratio: str, style_preset: str, seed: int, + async def api_call(self, prompt: str, aspect_ratio: str, style_preset: str, seed: int, negative_prompt: str=None, image: torch.Tensor = None, image_denoise: float=None, **kwargs): validate_string(prompt, strip_whitespace=False) @@ -163,7 +163,7 @@ class StabilityStableImageUltraNode: content_type="multipart/form-data", auth_kwargs=kwargs, ) - 
response_api = operation.execute() + response_api = await operation.execute() if response_api.finish_reason != "SUCCESS": raise Exception(f"Stable Image Ultra generation failed: {response_api.finish_reason}.") @@ -257,7 +257,7 @@ class StabilityStableImageSD_3_5Node: }, } - def api_call(self, model: str, prompt: str, aspect_ratio: str, style_preset: str, seed: int, cfg_scale: float, + async def api_call(self, model: str, prompt: str, aspect_ratio: str, style_preset: str, seed: int, cfg_scale: float, negative_prompt: str=None, image: torch.Tensor = None, image_denoise: float=None, **kwargs): validate_string(prompt, strip_whitespace=False) @@ -302,7 +302,7 @@ class StabilityStableImageSD_3_5Node: content_type="multipart/form-data", auth_kwargs=kwargs, ) - response_api = operation.execute() + response_api = await operation.execute() if response_api.finish_reason != "SUCCESS": raise Exception(f"Stable Diffusion 3.5 Image generation failed: {response_api.finish_reason}.") @@ -374,7 +374,7 @@ class StabilityUpscaleConservativeNode: }, } - def api_call(self, image: torch.Tensor, prompt: str, creativity: float, seed: int, negative_prompt: str=None, + async def api_call(self, image: torch.Tensor, prompt: str, creativity: float, seed: int, negative_prompt: str=None, **kwargs): validate_string(prompt, strip_whitespace=False) image_binary = tensor_to_bytesio(image, total_pixels=1024*1024).read() @@ -403,7 +403,7 @@ class StabilityUpscaleConservativeNode: content_type="multipart/form-data", auth_kwargs=kwargs, ) - response_api = operation.execute() + response_api = await operation.execute() if response_api.finish_reason != "SUCCESS": raise Exception(f"Stability Upscale Conservative generation failed: {response_api.finish_reason}.") @@ -480,7 +480,7 @@ class StabilityUpscaleCreativeNode: }, } - def api_call(self, image: torch.Tensor, prompt: str, creativity: float, style_preset: str, seed: int, negative_prompt: str=None, + async def api_call(self, image: torch.Tensor, prompt: str, creativity: float, style_preset: str, seed: int, negative_prompt: str=None, **kwargs): validate_string(prompt, strip_whitespace=False) image_binary = tensor_to_bytesio(image, total_pixels=1024*1024).read() @@ -512,7 +512,7 @@ class StabilityUpscaleCreativeNode: content_type="multipart/form-data", auth_kwargs=kwargs, ) - response_api = operation.execute() + response_api = await operation.execute() operation = PollingOperation( poll_endpoint=ApiEndpoint( @@ -527,7 +527,7 @@ class StabilityUpscaleCreativeNode: status_extractor=lambda x: get_async_dummy_status(x), auth_kwargs=kwargs, ) - response_poll: StabilityResultsGetResponse = operation.execute() + response_poll: StabilityResultsGetResponse = await operation.execute() if response_poll.finish_reason != "SUCCESS": raise Exception(f"Stability Upscale Creative generation failed: {response_poll.finish_reason}.") @@ -563,8 +563,7 @@ class StabilityUpscaleFastNode: }, } - def api_call(self, image: torch.Tensor, - **kwargs): + async def api_call(self, image: torch.Tensor, **kwargs): image_binary = tensor_to_bytesio(image, total_pixels=4096*4096).read() files = { @@ -583,7 +582,7 @@ class StabilityUpscaleFastNode: content_type="multipart/form-data", auth_kwargs=kwargs, ) - response_api = operation.execute() + response_api = await operation.execute() if response_api.finish_reason != "SUCCESS": raise Exception(f"Stability Upscale Fast failed: {response_api.finish_reason}.") diff --git a/comfy_api_nodes/nodes_tripo.py b/comfy_api_nodes/nodes_tripo.py index 65f3b21f5..d08cf9007 100644 --- 
a/comfy_api_nodes/nodes_tripo.py +++ b/comfy_api_nodes/nodes_tripo.py @@ -37,8 +37,8 @@ from comfy_api_nodes.apinode_utils import ( ) -def upload_image_to_tripo(image, **kwargs): - urls = upload_images_to_comfyapi(image, max_images=1, auth_kwargs=kwargs) +async def upload_image_to_tripo(image, **kwargs): + urls = await upload_images_to_comfyapi(image, max_images=1, auth_kwargs=kwargs) return TripoFileReference(TripoUrlReference(url=urls[0], type="jpeg")) def get_model_url_from_response(response: TripoTaskResponse) -> str: @@ -49,7 +49,7 @@ def get_model_url_from_response(response: TripoTaskResponse) -> str: raise RuntimeError(f"Failed to get model url from response: {response}") -def poll_until_finished( +async def poll_until_finished( kwargs: dict[str, str], response: TripoTaskResponse, ) -> tuple[str, str]: @@ -57,7 +57,7 @@ def poll_until_finished( if response.code != 0: raise RuntimeError(f"Failed to generate mesh: {response.error}") task_id = response.data.task_id - response_poll = PollingOperation( + response_poll = await PollingOperation( poll_endpoint=ApiEndpoint( path=f"/proxy/tripo/v2/openapi/task/{task_id}", method=HttpMethod.GET, @@ -80,7 +80,7 @@ def poll_until_finished( ).execute() if response_poll.data.status == TripoTaskStatus.SUCCESS: url = get_model_url_from_response(response_poll) - bytesio = download_url_to_bytesio(url) + bytesio = await download_url_to_bytesio(url) # Save the downloaded model file model_file = f"tripo_model_{task_id}.glb" with open(os.path.join(get_output_directory(), model_file), "wb") as f: @@ -88,6 +88,7 @@ def poll_until_finished( return model_file, task_id raise RuntimeError(f"Failed to generate mesh: {response_poll}") + class TripoTextToModelNode: """ Generates 3D models synchronously based on a text prompt using Tripo's API. 
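Across the Tripo and Runway patches, every SynchronousOperation and PollingOperation execution is now awaited, so status polling yields to the event loop between requests instead of blocking the worker thread. Reduced to plain asyncio, the loop has roughly this shape; this is a sketch under assumptions (fetch_status and the status strings are placeholders, not the actual Tripo or Runway schema):

import asyncio
from typing import Awaitable, Callable

TERMINAL = {"SUCCEEDED", "FAILED", "CANCELLED"}


async def poll_until_finished(
    fetch_status: Callable[[], Awaitable[str]],
    interval_s: float = 2.0,
    timeout_s: float = 600.0,
) -> str:
    """Poll an async status callable until it reports a terminal state."""
    loop = asyncio.get_running_loop()
    deadline = loop.time() + timeout_s
    while True:
        status = await fetch_status()
        if status in TERMINAL:
            return status
        if loop.time() >= deadline:
            raise TimeoutError(f"task still '{status}' after {timeout_s}s")
        # awaiting the sleep yields control, so other prompts keep executing
        await asyncio.sleep(interval_s)


if __name__ == "__main__":
    # Fake backend that "finishes" after three polls.
    calls = iter(["PENDING", "RUNNING", "SUCCEEDED"])

    async def fake_status() -> str:
        return next(calls)

    print(asyncio.run(poll_until_finished(fake_status, interval_s=0.01)))

The real nodes also extract result URLs and report estimated durations through PollingOperation; the point of the sketch is only that the wait between polls is awaited rather than slept synchronously.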
@@ -126,11 +127,11 @@ class TripoTextToModelNode: API_NODE = True OUTPUT_NODE = True - def generate_mesh(self, prompt, negative_prompt=None, model_version=None, style=None, texture=None, pbr=None, image_seed=None, model_seed=None, texture_seed=None, texture_quality=None, face_limit=None, quad=None, **kwargs): + async def generate_mesh(self, prompt, negative_prompt=None, model_version=None, style=None, texture=None, pbr=None, image_seed=None, model_seed=None, texture_seed=None, texture_quality=None, face_limit=None, quad=None, **kwargs): style_enum = None if style == "None" else style if not prompt: raise RuntimeError("Prompt is required") - response = SynchronousOperation( + response = await SynchronousOperation( endpoint=ApiEndpoint( path="/proxy/tripo/v2/openapi/task", method=HttpMethod.POST, @@ -155,7 +156,8 @@ class TripoTextToModelNode: ), auth_kwargs=kwargs, ).execute() - return poll_until_finished(kwargs, response) + return await poll_until_finished(kwargs, response) + class TripoImageToModelNode: """ @@ -195,12 +197,12 @@ class TripoImageToModelNode: API_NODE = True OUTPUT_NODE = True - def generate_mesh(self, image, model_version=None, style=None, texture=None, pbr=None, model_seed=None, orientation=None, texture_alignment=None, texture_seed=None, texture_quality=None, face_limit=None, quad=None, **kwargs): + async def generate_mesh(self, image, model_version=None, style=None, texture=None, pbr=None, model_seed=None, orientation=None, texture_alignment=None, texture_seed=None, texture_quality=None, face_limit=None, quad=None, **kwargs): style_enum = None if style == "None" else style if image is None: raise RuntimeError("Image is required") - tripo_file = upload_image_to_tripo(image, **kwargs) - response = SynchronousOperation( + tripo_file = await upload_image_to_tripo(image, **kwargs) + response = await SynchronousOperation( endpoint=ApiEndpoint( path="/proxy/tripo/v2/openapi/task", method=HttpMethod.POST, @@ -225,7 +227,8 @@ class TripoImageToModelNode: ), auth_kwargs=kwargs, ).execute() - return poll_until_finished(kwargs, response) + return await poll_until_finished(kwargs, response) + class TripoMultiviewToModelNode: """ @@ -267,7 +270,7 @@ class TripoMultiviewToModelNode: API_NODE = True OUTPUT_NODE = True - def generate_mesh(self, image, image_left=None, image_back=None, image_right=None, model_version=None, orientation=None, texture=None, pbr=None, model_seed=None, texture_seed=None, texture_quality=None, texture_alignment=None, face_limit=None, quad=None, **kwargs): + async def generate_mesh(self, image, image_left=None, image_back=None, image_right=None, model_version=None, orientation=None, texture=None, pbr=None, model_seed=None, texture_seed=None, texture_quality=None, texture_alignment=None, face_limit=None, quad=None, **kwargs): if image is None: raise RuntimeError("front image for multiview is required") images = [] @@ -282,11 +285,11 @@ class TripoMultiviewToModelNode: for image_name in ["image", "image_left", "image_back", "image_right"]: image_ = image_dict[image_name] if image_ is not None: - tripo_file = upload_image_to_tripo(image_, **kwargs) + tripo_file = await upload_image_to_tripo(image_, **kwargs) images.append(tripo_file) else: images.append(TripoFileEmptyReference()) - response = SynchronousOperation( + response = await SynchronousOperation( endpoint=ApiEndpoint( path="/proxy/tripo/v2/openapi/task", method=HttpMethod.POST, @@ -309,7 +312,8 @@ class TripoMultiviewToModelNode: ), auth_kwargs=kwargs, ).execute() - return poll_until_finished(kwargs, 
response) + return await poll_until_finished(kwargs, response) + class TripoTextureNode: @classmethod @@ -340,8 +344,8 @@ class TripoTextureNode: OUTPUT_NODE = True AVERAGE_DURATION = 80 - def generate_mesh(self, model_task_id, texture=None, pbr=None, texture_seed=None, texture_quality=None, texture_alignment=None, **kwargs): - response = SynchronousOperation( + async def generate_mesh(self, model_task_id, texture=None, pbr=None, texture_seed=None, texture_quality=None, texture_alignment=None, **kwargs): + response = await SynchronousOperation( endpoint=ApiEndpoint( path="/proxy/tripo/v2/openapi/task", method=HttpMethod.POST, @@ -358,7 +362,7 @@ class TripoTextureNode: ), auth_kwargs=kwargs, ).execute() - return poll_until_finished(kwargs, response) + return await poll_until_finished(kwargs, response) class TripoRefineNode: @@ -387,8 +391,8 @@ class TripoRefineNode: OUTPUT_NODE = True AVERAGE_DURATION = 240 - def generate_mesh(self, model_task_id, **kwargs): - response = SynchronousOperation( + async def generate_mesh(self, model_task_id, **kwargs): + response = await SynchronousOperation( endpoint=ApiEndpoint( path="/proxy/tripo/v2/openapi/task", method=HttpMethod.POST, @@ -400,7 +404,7 @@ class TripoRefineNode: ), auth_kwargs=kwargs, ).execute() - return poll_until_finished(kwargs, response) + return await poll_until_finished(kwargs, response) class TripoRigNode: @@ -425,8 +429,8 @@ class TripoRigNode: OUTPUT_NODE = True AVERAGE_DURATION = 180 - def generate_mesh(self, original_model_task_id, **kwargs): - response = SynchronousOperation( + async def generate_mesh(self, original_model_task_id, **kwargs): + response = await SynchronousOperation( endpoint=ApiEndpoint( path="/proxy/tripo/v2/openapi/task", method=HttpMethod.POST, @@ -440,7 +444,8 @@ class TripoRigNode: ), auth_kwargs=kwargs, ).execute() - return poll_until_finished(kwargs, response) + return await poll_until_finished(kwargs, response) + class TripoRetargetNode: @classmethod @@ -475,8 +480,8 @@ class TripoRetargetNode: OUTPUT_NODE = True AVERAGE_DURATION = 30 - def generate_mesh(self, animation, original_model_task_id, **kwargs): - response = SynchronousOperation( + async def generate_mesh(self, animation, original_model_task_id, **kwargs): + response = await SynchronousOperation( endpoint=ApiEndpoint( path="/proxy/tripo/v2/openapi/task", method=HttpMethod.POST, @@ -491,7 +496,8 @@ class TripoRetargetNode: ), auth_kwargs=kwargs, ).execute() - return poll_until_finished(kwargs, response) + return await poll_until_finished(kwargs, response) + class TripoConversionNode: @classmethod @@ -529,10 +535,10 @@ class TripoConversionNode: OUTPUT_NODE = True AVERAGE_DURATION = 30 - def generate_mesh(self, original_model_task_id, format, quad, face_limit, texture_size, texture_format, **kwargs): + async def generate_mesh(self, original_model_task_id, format, quad, face_limit, texture_size, texture_format, **kwargs): if not original_model_task_id: raise RuntimeError("original_model_task_id is required") - response = SynchronousOperation( + response = await SynchronousOperation( endpoint=ApiEndpoint( path="/proxy/tripo/v2/openapi/task", method=HttpMethod.POST, @@ -549,7 +555,8 @@ class TripoConversionNode: ), auth_kwargs=kwargs, ).execute() - return poll_until_finished(kwargs, response) + return await poll_until_finished(kwargs, response) + NODE_CLASS_MAPPINGS = { "TripoTextToModelNode": TripoTextToModelNode, diff --git a/comfy_api_nodes/nodes_veo2.py b/comfy_api_nodes/nodes_veo2.py index 97bfe20e6..e25dab2f5 100644 --- 
a/comfy_api_nodes/nodes_veo2.py +++ b/comfy_api_nodes/nodes_veo2.py @@ -1,7 +1,7 @@ import io import logging import base64 -import requests +import aiohttp import torch from typing import Optional @@ -152,7 +152,7 @@ class VeoVideoGenerationNode(ComfyNodeABC): DESCRIPTION = "Generates videos from text prompts using Google's Veo 2 API" API_NODE = True - def generate_video( + async def generate_video( self, prompt, aspect_ratio="16:9", @@ -217,7 +217,7 @@ class VeoVideoGenerationNode(ComfyNodeABC): auth_kwargs=kwargs, ) - initial_response = initial_operation.execute() + initial_response = await initial_operation.execute() operation_name = initial_response.name logging.info(f"Veo generation started with operation name: {operation_name}") @@ -256,7 +256,7 @@ class VeoVideoGenerationNode(ComfyNodeABC): ) # Execute the polling operation - poll_response = poll_operation.execute() + poll_response = await poll_operation.execute() # Now check for errors in the final response # Check for error in poll response @@ -281,7 +281,6 @@ class VeoVideoGenerationNode(ComfyNodeABC): raise Exception(error_message) # Extract video data - video_data = None if poll_response.response and hasattr(poll_response.response, 'videos') and poll_response.response.videos and len(poll_response.response.videos) > 0: video = poll_response.response.videos[0] @@ -291,9 +290,9 @@ class VeoVideoGenerationNode(ComfyNodeABC): video_data = base64.b64decode(video.bytesBase64Encoded) elif hasattr(video, 'gcsUri') and video.gcsUri: # Download from URL - video_url = video.gcsUri - video_response = requests.get(video_url) - video_data = video_response.content + async with aiohttp.ClientSession() as session: + async with session.get(video.gcsUri) as video_response: + video_data = await video_response.content.read() else: raise Exception("Video returned but no data or URL was provided") else: From 735bb4bdb186bd4f39b9c924c24b8b39a7ef8b0d Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Fri, 8 Aug 2025 01:21:00 -0700 Subject: [PATCH 033/325] Users report gfx1201 is buggy on flux with pytorch attention. (#9244) --- comfy/model_management.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 9e6149d60..dc5b4711d 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -321,9 +321,9 @@ try: if torch_version_numeric >= (2, 7): # works on 2.6 but doesn't actually seem to improve much if any((a in arch) for a in ["gfx90a", "gfx942", "gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches, TODO: gfx950 ENABLE_PYTORCH_ATTENTION = True - if torch_version_numeric >= (2, 8): - if any((a in arch) for a in ["gfx1201"]): - ENABLE_PYTORCH_ATTENTION = True +# if torch_version_numeric >= (2, 8): +# if any((a in arch) for a in ["gfx1201"]): +# ENABLE_PYTORCH_ATTENTION = True if torch_version_numeric >= (2, 7) and rocm_version >= (6, 4): if any((a in arch) for a in ["gfx1201", "gfx942", "gfx950"]): # TODO: more arches SUPPORT_FP8_OPS = True From 5828607ccfef82a82931d8b66f3fd1176e04588f Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sat, 9 Aug 2025 09:49:25 -0700 Subject: [PATCH 034/325] Not sure if AMD actually support fp16 acc but it doesn't crash. 
(#9258) --- comfy/model_management.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index dc5b4711d..c08f759e5 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -340,7 +340,7 @@ if ENABLE_PYTORCH_ATTENTION: PRIORITIZE_FP16 = False # TODO: remove and replace with something that shows exactly which dtype is faster than the other try: - if is_nvidia() and PerformanceFeature.Fp16Accumulation in args.fast: + if (is_nvidia() or is_amd()) and PerformanceFeature.Fp16Accumulation in args.fast: torch.backends.cuda.matmul.allow_fp16_accumulation = True PRIORITIZE_FP16 = True # TODO: limit to cards where it actually boosts performance logging.info("Enabled fp16 accumulation.") From 0552de7c7d6bcdd515da115d6756fd30494c7ff4 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sun, 10 Aug 2025 02:03:47 -0700 Subject: [PATCH 035/325] Bump pytorch cuda and rocm versions in readme instructions. (#9273) --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 119098f5c..e4cff01a9 100644 --- a/README.md +++ b/README.md @@ -203,7 +203,7 @@ Put your VAE in: models/vae ### AMD GPUs (Linux only) AMD users can install rocm and pytorch with pip if you don't have it already installed, this is the command to install the stable version: -```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.3``` +```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.4``` This is the command to install the nightly with ROCm 6.4 which might have some performance improvements: @@ -237,7 +237,7 @@ Additional discussion and help can be found [here](https://github.com/comfyanony Nvidia users should install stable pytorch using this command: -```pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu128``` +```pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu129``` This is the command to install pytorch nightly instead which might have performance improvements. From 966f3a52061b5e300f36c6de0d07c47d6ad12f76 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 11 Aug 2025 02:53:01 -0700 Subject: [PATCH 036/325] Only show feature flags log when verbose. 
(#9281) --- server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server.py b/server.py index 0553a0dd7..8f9c88ebf 100644 --- a/server.py +++ b/server.py @@ -235,7 +235,7 @@ class PromptServer(): sid, ) - logging.info( + logging.debug( f"Feature flags negotiated for client {sid}: {client_flags}" ) first_message = False From fa340add552497a264071fd7f6c407ff4aa10449 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Mon, 11 Aug 2025 23:48:17 +0300 Subject: [PATCH 037/325] remove creation of non-used asyncio_loop (#9284) --- execution.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/execution.py b/execution.py index 952f0cc5c..1dc35738b 100644 --- a/execution.py +++ b/execution.py @@ -646,8 +646,6 @@ class PromptExecutor: self.add_message("execution_error", mes, broadcast=False) def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]): - asyncio_loop = asyncio.new_event_loop() - asyncio.set_event_loop(asyncio_loop) asyncio.run(self.execute_async(prompt, prompt_id, extra_data, execute_outputs)) async def execute_async(self, prompt, prompt_id, extra_data={}, execute_outputs=[]): From 629b17383718e1f46dbba101ea83ec897fbe3082 Mon Sep 17 00:00:00 2001 From: ComfyUI Wiki Date: Tue, 12 Aug 2025 04:52:12 +0800 Subject: [PATCH 038/325] Update template & embedded docs (#9283) * Update template & embedded docs * Update embedded docs to 0.2.6 --- requirements.txt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/requirements.txt b/requirements.txt index 2f4692b03..2fb38ef27 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ comfyui-frontend-package==1.24.4 -comfyui-workflow-templates==0.1.52 -comfyui-embedded-docs==0.2.4 +comfyui-workflow-templates==0.1.53 +comfyui-embedded-docs==0.2.6 torch torchsde torchvision From 2208aa616d3ad193cd37ef57076d4f5243cecdd3 Mon Sep 17 00:00:00 2001 From: PsychoLogicAu Date: Tue, 12 Aug 2025 06:56:16 +1000 Subject: [PATCH 039/325] Support SimpleTuner lycoris lora for Qwen-Image (#9280) --- comfy/lora.py | 1 + 1 file changed, 1 insertion(+) diff --git a/comfy/lora.py b/comfy/lora.py index 6686b7229..00358884b 100644 --- a/comfy/lora.py +++ b/comfy/lora.py @@ -301,6 +301,7 @@ def model_lora_keys_unet(model, key_map={}): key_map["{}".format(key_lora)] = k # Support transformer prefix format key_map["transformer.{}".format(key_lora)] = k + key_map["lycoris_{}".format(key_lora.replace(".", "_"))] = k #SimpleTuner lycoris format return key_map From f4231a80b1b904b45ade0def9b37320c4adfe71b Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Tue, 12 Aug 2025 00:15:14 +0300 Subject: [PATCH 040/325] fix(Kling Image API Node): do not pass "image_type" when no image (#9271) * fix(Kling Image API Node): do not pass "image_type" when no image * fix(Kling Image API Node): raise client-side error when kling_v1 is used with reference image --- comfy_api_nodes/nodes_kling.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/comfy_api_nodes/nodes_kling.py b/comfy_api_nodes/nodes_kling.py index 9d9eb5628..9d483bb0e 100644 --- a/comfy_api_nodes/nodes_kling.py +++ b/comfy_api_nodes/nodes_kling.py @@ -1690,7 +1690,11 @@ class KlingImageGenerationNode(KlingImageGenerationBase): ): self.validate_prompt(prompt, negative_prompt) - if image is not None: + if image is None: + image_type = None + elif model_name == KlingImageGenModelName.kling_v1: + raise ValueError(f"The model 
{KlingImageGenModelName.kling_v1.value} does not support reference images.") + else: image = tensor_to_base64_string(image) initial_operation = SynchronousOperation( From 1e3ae1eed8b925430e3b114ea6b7d08ea698e305 Mon Sep 17 00:00:00 2001 From: ComfyUI Wiki Date: Wed, 13 Aug 2025 05:14:27 +0800 Subject: [PATCH 041/325] Update template to 0.1.58 (#9302) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 2fb38ef27..82af5690b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.24.4 -comfyui-workflow-templates==0.1.53 +comfyui-workflow-templates==0.1.58 comfyui-embedded-docs==0.2.6 torch torchsde From e1d4f36d8df7446ebb1a5f2bf9c708c38a159f22 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 12 Aug 2025 17:13:04 -0700 Subject: [PATCH 042/325] Update test release package workflow with python 3.13 cu129. (#9306) --- .github/workflows/windows_release_package.yml | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/.github/workflows/windows_release_package.yml b/.github/workflows/windows_release_package.yml index 3926a65f3..b51746285 100644 --- a/.github/workflows/windows_release_package.yml +++ b/.github/workflows/windows_release_package.yml @@ -7,19 +7,19 @@ on: description: 'cuda version' required: true type: string - default: "128" + default: "129" python_minor: description: 'python minor version' required: true type: string - default: "12" + default: "13" python_patch: description: 'python patch version' required: true type: string - default: "10" + default: "6" # push: # branches: # - master @@ -64,6 +64,8 @@ jobs: ./python.exe get-pip.py ./python.exe -s -m pip install ../cu${{ inputs.cu }}_python_deps/* sed -i '1i../ComfyUI' ./python3${{ inputs.python_minor }}._pth + + rm ./Lib/site-packages/torch/lib/dnnl.lib #I don't think this is actually used and I need the space cd .. git clone --depth 1 https://github.com/comfyanonymous/taesd From 560d38f34c5bd532f89f2178f01ee819cf145820 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 12 Aug 2025 20:26:33 -0700 Subject: [PATCH 043/325] Wan2.2 fun control support. 
(#9292) --- comfy/ldm/wan/model.py | 19 +++++++++++++ comfy/model_base.py | 10 ++++++- comfy/model_detection.py | 5 ++++ comfy_extras/nodes_wan.py | 58 +++++++++++++++++++++++++++++++++++++++ 4 files changed, 91 insertions(+), 1 deletion(-) diff --git a/comfy/ldm/wan/model.py b/comfy/ldm/wan/model.py index 86d0795e9..4e2d99566 100644 --- a/comfy/ldm/wan/model.py +++ b/comfy/ldm/wan/model.py @@ -391,6 +391,7 @@ class WanModel(torch.nn.Module): cross_attn_norm=True, eps=1e-6, flf_pos_embed_token_number=None, + in_dim_ref_conv=None, image_model=None, device=None, dtype=None, @@ -484,6 +485,11 @@ class WanModel(torch.nn.Module): else: self.img_emb = None + if in_dim_ref_conv is not None: + self.ref_conv = operations.Conv2d(in_dim_ref_conv, dim, kernel_size=patch_size[1:], stride=patch_size[1:], device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + else: + self.ref_conv = None + def forward_orig( self, x, @@ -526,6 +532,13 @@ class WanModel(torch.nn.Module): e = e.reshape(t.shape[0], -1, e.shape[-1]) e0 = self.time_projection(e).unflatten(2, (6, self.dim)) + full_ref = None + if self.ref_conv is not None: + full_ref = kwargs.get("reference_latent", None) + if full_ref is not None: + full_ref = self.ref_conv(full_ref).flatten(2).transpose(1, 2) + x = torch.concat((full_ref, x), dim=1) + # context context = self.text_embedding(context) @@ -552,6 +565,9 @@ class WanModel(torch.nn.Module): # head x = self.head(x, e) + if full_ref is not None: + x = x[:, full_ref.shape[1]:] + # unpatchify x = self.unpatchify(x, grid_sizes) return x @@ -570,6 +586,9 @@ class WanModel(torch.nn.Module): x = torch.cat([x, time_dim_concat], dim=2) t_len = ((x.shape[2] + (patch_size[0] // 2)) // patch_size[0]) + if self.ref_conv is not None and "reference_latent" in kwargs: + t_len += 1 + img_ids = torch.zeros((t_len, h_len, w_len, 3), device=x.device, dtype=x.dtype) img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(0, t_len - 1, steps=t_len, device=x.device, dtype=x.dtype).reshape(-1, 1, 1) img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).reshape(1, -1, 1) diff --git a/comfy/model_base.py b/comfy/model_base.py index 8a2d9cbe6..cde61df7c 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1124,7 +1124,11 @@ class WAN21(BaseModel): mask = mask.repeat(1, 4, 1, 1, 1) mask = utils.resize_to_batch_size(mask, noise.shape[0]) - return torch.cat((mask, image), dim=1) + concat_mask_index = kwargs.get("concat_mask_index", 0) + if concat_mask_index != 0: + return torch.cat((image[:, :concat_mask_index], mask, image[:, concat_mask_index:]), dim=1) + else: + return torch.cat((mask, image), dim=1) def extra_conds(self, **kwargs): out = super().extra_conds(**kwargs) @@ -1140,6 +1144,10 @@ class WAN21(BaseModel): if time_dim_concat is not None: out['time_dim_concat'] = comfy.conds.CONDRegular(self.process_latent_in(time_dim_concat)) + reference_latents = kwargs.get("reference_latents", None) + if reference_latents is not None: + out['reference_latent'] = comfy.conds.CONDRegular(self.process_latent_in(reference_latents[-1])[:, :, 0]) + return out diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 8b57ebd2f..8acc51e20 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -373,6 +373,11 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): flf_weight = state_dict.get('{}img_emb.emb_pos'.format(key_prefix)) if flf_weight is not None: dit_config["flf_pos_embed_token_number"] = 
flf_weight.shape[1] + + ref_conv_weight = state_dict.get('{}ref_conv.weight'.format(key_prefix)) + if ref_conv_weight is not None: + dit_config["in_dim_ref_conv"] = ref_conv_weight.shape[1] + return dit_config if '{}latent_in.weight'.format(key_prefix) in state_dict_keys: # Hunyuan 3D diff --git a/comfy_extras/nodes_wan.py b/comfy_extras/nodes_wan.py index 0067d054d..f80c83ba6 100644 --- a/comfy_extras/nodes_wan.py +++ b/comfy_extras/nodes_wan.py @@ -103,6 +103,63 @@ class WanFunControlToVideo: out_latent["samples"] = latent return (positive, negative, out_latent) +class Wan22FunControlToVideo: + @classmethod + def INPUT_TYPES(s): + return {"required": {"positive": ("CONDITIONING", ), + "negative": ("CONDITIONING", ), + "vae": ("VAE", ), + "width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), + "height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), + "length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}), + "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), + }, + "optional": {"ref_image": ("IMAGE", ), + "control_video": ("IMAGE", ), + # "start_image": ("IMAGE", ), + }} + + RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") + RETURN_NAMES = ("positive", "negative", "latent") + FUNCTION = "encode" + + CATEGORY = "conditioning/video_models" + + def encode(self, positive, negative, vae, width, height, length, batch_size, ref_image=None, start_image=None, control_video=None): + latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) + concat_latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) + concat_latent = comfy.latent_formats.Wan21().process_out(concat_latent) + concat_latent = concat_latent.repeat(1, 2, 1, 1, 1) + mask = torch.ones((1, 1, latent.shape[2] * 4, latent.shape[-2], latent.shape[-1])) + + if start_image is not None: + start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) + concat_latent_image = vae.encode(start_image[:, :, :, :3]) + concat_latent[:,16:,:concat_latent_image.shape[2]] = concat_latent_image[:,:,:concat_latent.shape[2]] + mask[:, :, :start_image.shape[0] + 3] = 0.0 + + ref_latent = None + if ref_image is not None: + ref_image = comfy.utils.common_upscale(ref_image[:1].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) + ref_latent = vae.encode(ref_image[:, :, :, :3]) + + if control_video is not None: + control_video = comfy.utils.common_upscale(control_video[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) + concat_latent_image = vae.encode(control_video[:, :, :, :3]) + concat_latent[:,:16,:concat_latent_image.shape[2]] = concat_latent_image[:,:,:concat_latent.shape[2]] + + mask = mask.view(1, mask.shape[2] // 4, 4, mask.shape[3], mask.shape[4]).transpose(1, 2) + positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent, "concat_mask": mask, "concat_mask_index": 16}) + negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent, "concat_mask": mask, "concat_mask_index": 16}) + + if ref_latent is not None: + positive = node_helpers.conditioning_set_values(positive, {"reference_latents": [ref_latent]}, append=True) + negative = node_helpers.conditioning_set_values(negative, 
{"reference_latents": [ref_latent]}, append=True) + + out_latent = {} + out_latent["samples"] = latent + return (positive, negative, out_latent) + class WanFirstLastFrameToVideo: @classmethod def INPUT_TYPES(s): @@ -733,6 +790,7 @@ NODE_CLASS_MAPPINGS = { "WanTrackToVideo": WanTrackToVideo, "WanImageToVideo": WanImageToVideo, "WanFunControlToVideo": WanFunControlToVideo, + "Wan22FunControlToVideo": Wan22FunControlToVideo, "WanFunInpaintToVideo": WanFunInpaintToVideo, "WanFirstLastFrameToVideo": WanFirstLastFrameToVideo, "WanVaceToVideo": WanVaceToVideo, From 898d88e10e45f38500ca6044280bab4ca2f2d273 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Tue, 12 Aug 2025 20:34:58 -0700 Subject: [PATCH 044/325] Make torchaudio exception catching less specific (#9309) --- comfy_api/latest/_ui.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy_api/latest/_ui.py b/comfy_api/latest/_ui.py index 61597038f..26a55615f 100644 --- a/comfy_api/latest/_ui.py +++ b/comfy_api/latest/_ui.py @@ -12,7 +12,7 @@ import torch try: import torchaudio TORCH_AUDIO_AVAILABLE = True -except ImportError: +except: TORCH_AUDIO_AVAILABLE = False from PIL import Image as PILImage from PIL.PngImagePlugin import PngInfo From 3294782d19c3af0c6166aafe0465fe6b59571d17 Mon Sep 17 00:00:00 2001 From: ComfyUI Wiki Date: Wed, 13 Aug 2025 14:50:50 +0800 Subject: [PATCH 045/325] Update template to 0.1.59 (#9313) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 82af5690b..bfb31a73f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.24.4 -comfyui-workflow-templates==0.1.58 +comfyui-workflow-templates==0.1.59 comfyui-embedded-docs==0.2.6 torch torchsde From 5ca8e2fac3b6826261c5991b0663b69eff60b3a1 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 13 Aug 2025 00:01:12 -0700 Subject: [PATCH 046/325] Update release workflow to python3.13 pytorch cu129 (#9315) * Try to reduce size of portable even more. * Update stable release workflow to python 3.13 cu129 * Update dependencies workflow to python3.13 cu129 --- .github/workflows/stable-release.yml | 15 ++++++++++----- .../workflows/windows_release_dependencies.yml | 6 +++--- .github/workflows/windows_release_package.yml | 2 ++ 3 files changed, 15 insertions(+), 8 deletions(-) diff --git a/.github/workflows/stable-release.yml b/.github/workflows/stable-release.yml index 61105abe4..a5a1ed2d0 100644 --- a/.github/workflows/stable-release.yml +++ b/.github/workflows/stable-release.yml @@ -12,17 +12,17 @@ on: description: 'CUDA version' required: true type: string - default: "128" + default: "129" python_minor: description: 'Python minor version' required: true type: string - default: "12" + default: "13" python_patch: description: 'Python patch version' required: true type: string - default: "10" + default: "6" jobs: @@ -66,8 +66,13 @@ jobs: curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py ./python.exe get-pip.py ./python.exe -s -m pip install ../cu${{ inputs.cu }}_python_deps/* - sed -i '1i../ComfyUI' ./python3${{ inputs.python_minor }}._pth - cd .. + sed -i '1i../ComfyUI' ./python3${{ inputs.python_minor }}._pth + + rm ./Lib/site-packages/torch/lib/dnnl.lib #I don't think this is actually used and I need the space + rm ./Lib/site-packages/torch/lib/libprotoc.lib + rm ./Lib/site-packages/torch/lib/libprotobuf.lib + + cd .. 
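Returning to the Wan2.2 Fun control patch earlier in this series: the new ref_conv path patchifies a single reference latent into extra tokens, prepends them to the video tokens, and slices them back off after the transformer blocks (t_len is bumped by one so the positional ids cover them). A toy, self-contained sketch of that prepend-then-strip pattern with dummy tensors; the 2x2 patch size and channel counts are illustrative, not the model's real configuration:

import torch

batch, dim, latent_channels = 1, 64, 16
patch = (2, 2)

# Stand-ins for the model's projections (shapes chosen for illustration only).
ref_conv = torch.nn.Conv2d(latent_channels, dim, kernel_size=patch, stride=patch)
video_tokens = torch.randn(batch, 120, dim)                      # patchified video latent
reference_latent = torch.randn(batch, latent_channels, 16, 16)   # single reference frame

# Patchify the reference image into tokens and prepend them to the sequence.
ref_tokens = ref_conv(reference_latent).flatten(2).transpose(1, 2)   # [B, 64 tokens, dim]
x = torch.cat((ref_tokens, video_tokens), dim=1)                     # [B, 184 tokens, dim]

# ... transformer blocks would run on the combined sequence here ...

# Strip the reference tokens again before unpatchifying the video.
x = x[:, ref_tokens.shape[1]:]
assert x.shape == video_tokens.shape
print(ref_tokens.shape, x.shape)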
git clone --depth 1 https://github.com/comfyanonymous/taesd cp taesd/*.safetensors ./ComfyUI_copy/models/vae_approx/ diff --git a/.github/workflows/windows_release_dependencies.yml b/.github/workflows/windows_release_dependencies.yml index dfdb96d50..7761cc1ed 100644 --- a/.github/workflows/windows_release_dependencies.yml +++ b/.github/workflows/windows_release_dependencies.yml @@ -17,19 +17,19 @@ on: description: 'cuda version' required: true type: string - default: "128" + default: "129" python_minor: description: 'python minor version' required: true type: string - default: "12" + default: "13" python_patch: description: 'python patch version' required: true type: string - default: "10" + default: "6" # push: # branches: # - master diff --git a/.github/workflows/windows_release_package.yml b/.github/workflows/windows_release_package.yml index b51746285..3334e6839 100644 --- a/.github/workflows/windows_release_package.yml +++ b/.github/workflows/windows_release_package.yml @@ -66,6 +66,8 @@ jobs: sed -i '1i../ComfyUI' ./python3${{ inputs.python_minor }}._pth rm ./Lib/site-packages/torch/lib/dnnl.lib #I don't think this is actually used and I need the space + rm ./Lib/site-packages/torch/lib/libprotoc.lib + rm ./Lib/site-packages/torch/lib/libprotobuf.lib cd .. git clone --depth 1 https://github.com/comfyanonymous/taesd From e400f26c8fc9867248394616a4b58ecc4c53fbfd Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 13 Aug 2025 00:44:54 -0700 Subject: [PATCH 047/325] Downgrade frontend for release. (#9316) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index bfb31a73f..56ed85e01 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -comfyui-frontend-package==1.24.4 +comfyui-frontend-package==1.23.4 comfyui-workflow-templates==0.1.59 comfyui-embedded-docs==0.2.6 torch From d5c1954d5cd4a789bbf84d2b75a955a5a3f93de8 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 13 Aug 2025 03:46:38 -0400 Subject: [PATCH 048/325] ComfyUI version 0.3.50 --- comfyui_version.py | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/comfyui_version.py b/comfyui_version.py index 5e2d09c81..29ec07ca6 100644 --- a/comfyui_version.py +++ b/comfyui_version.py @@ -1,3 +1,3 @@ # This file is automatically generated by the build process when version is # updated in pyproject.toml. -__version__ = "0.3.49" +__version__ = "0.3.50" diff --git a/pyproject.toml b/pyproject.toml index 3c530ba85..659b5730a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ComfyUI" -version = "0.3.49" +version = "0.3.50" readme = "README.md" license = { file = "LICENSE" } requires-python = ">=3.9" From 615eb52049df98cebdd67bc672b66dc059171d7c Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 13 Aug 2025 00:48:06 -0700 Subject: [PATCH 049/325] Put back frontend version. 
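The final hunks route attention through a comfy.ops scaled_dot_product_attention wrapper instead of calling torch.nn.functional directly; the wrapper added to comfy/ops.py is not shown in this excerpt. One plausible shape for such a wrapper, using torch.nn.attention.sdpa_kernel to express a backend preference, is sketched below; the priority order and the set_priority keyword (which needs a recent PyTorch) are assumptions, not the project's actual implementation:

import torch
import torch.nn.functional as F

try:
    from torch.nn.attention import SDPBackend, sdpa_kernel
    _SDPA_PRIORITY = [
        SDPBackend.FLASH_ATTENTION,
        SDPBackend.EFFICIENT_ATTENTION,
        SDPBackend.MATH,
    ]
except ImportError:  # older PyTorch: no sdpa_kernel, fall back to the plain call
    sdpa_kernel = None


def scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False):
    """Thin wrapper so a preferred SDPA backend order can be applied in one place."""
    if sdpa_kernel is None:
        return F.scaled_dot_product_attention(
            q, k, v, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal)
    try:
        # set_priority=True (recent PyTorch) treats the list as a priority order
        # rather than a hard restriction on which backends may run.
        ctx = sdpa_kernel(_SDPA_PRIORITY, set_priority=True)
    except TypeError:
        ctx = sdpa_kernel(_SDPA_PRIORITY)
    with ctx:
        return F.scaled_dot_product_attention(
            q, k, v, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal)


if __name__ == "__main__":
    q = k = v = torch.randn(1, 8, 16, 64)
    print(scaled_dot_product_attention(q, k, v).shape)

Routing every call site through one function like this also explains why the diffs in hunyuan3d/vae.py, attention.py, and diffusionmodules/model.py are one-line substitutions.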
(#9317) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 56ed85e01..bfb31a73f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -comfyui-frontend-package==1.23.4 +comfyui-frontend-package==1.24.4 comfyui-workflow-templates==0.1.59 comfyui-embedded-docs==0.2.6 torch From afa0a45206832b0e64e38454b7841d1da7ca56e4 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 13 Aug 2025 11:42:08 -0700 Subject: [PATCH 050/325] Reduce portable size again. (#9323) * compress more * test * not needed --- .github/workflows/stable-release.yml | 2 +- .github/workflows/windows_release_package.yml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/stable-release.yml b/.github/workflows/stable-release.yml index a5a1ed2d0..2bc8e5905 100644 --- a/.github/workflows/stable-release.yml +++ b/.github/workflows/stable-release.yml @@ -90,7 +90,7 @@ jobs: cd .. - "C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma2 -mx=9 -mfb=128 -md=512m -ms=on -mf=BCJ2 ComfyUI_windows_portable.7z ComfyUI_windows_portable + "C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma2 -mx=9 -mfb=128 -md=768m -ms=on -mf=BCJ2 ComfyUI_windows_portable.7z ComfyUI_windows_portable mv ComfyUI_windows_portable.7z ComfyUI/ComfyUI_windows_portable_nvidia.7z cd ComfyUI_windows_portable diff --git a/.github/workflows/windows_release_package.yml b/.github/workflows/windows_release_package.yml index 3334e6839..46375698e 100644 --- a/.github/workflows/windows_release_package.yml +++ b/.github/workflows/windows_release_package.yml @@ -86,7 +86,7 @@ jobs: cd .. - "C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma2 -mx=9 -mfb=128 -md=512m -ms=on -mf=BCJ2 ComfyUI_windows_portable.7z ComfyUI_windows_portable + "C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma2 -mx=9 -mfb=128 -md=768m -ms=on -mf=BCJ2 ComfyUI_windows_portable.7z ComfyUI_windows_portable mv ComfyUI_windows_portable.7z ComfyUI/new_ComfyUI_windows_portable_nvidia_cu${{ inputs.cu }}_or_cpu.7z cd ComfyUI_windows_portable From 3da5a07510794c37d437cbea1d94065bb0aa8ebc Mon Sep 17 00:00:00 2001 From: contentis Date: Wed, 13 Aug 2025 20:53:27 +0200 Subject: [PATCH 051/325] SDPA backend priority (#9299) --- comfy/ldm/hunyuan3d/vae.py | 2 +- comfy/ldm/modules/attention.py | 4 ++-- comfy/ldm/modules/diffusionmodules/model.py | 2 +- comfy/ops.py | 13 +++++++++++++ 4 files changed, 17 insertions(+), 4 deletions(-) diff --git a/comfy/ldm/hunyuan3d/vae.py b/comfy/ldm/hunyuan3d/vae.py index 5eb2c6548..bea6090a2 100644 --- a/comfy/ldm/hunyuan3d/vae.py +++ b/comfy/ldm/hunyuan3d/vae.py @@ -178,7 +178,7 @@ class FourierEmbedder(nn.Module): class CrossAttentionProcessor: def __call__(self, attn, q, k, v): - out = F.scaled_dot_product_attention(q, k, v) + out = ops.scaled_dot_product_attention(q, k, v) return out diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index 35d2270ee..19c3c7af1 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -448,7 +448,7 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha mask = mask.unsqueeze(1) if SDP_BATCH_LIMIT >= b: - out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False) + out = ops.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False) if not skip_output_reshape: out = ( out.transpose(1, 2).reshape(b, -1, heads * dim_head) @@ -461,7 
+461,7 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha if mask.shape[0] > 1: m = mask[i : i + SDP_BATCH_LIMIT] - out[i : i + SDP_BATCH_LIMIT] = torch.nn.functional.scaled_dot_product_attention( + out[i : i + SDP_BATCH_LIMIT] = ops.scaled_dot_product_attention( q[i : i + SDP_BATCH_LIMIT], k[i : i + SDP_BATCH_LIMIT], v[i : i + SDP_BATCH_LIMIT], diff --git a/comfy/ldm/modules/diffusionmodules/model.py b/comfy/ldm/modules/diffusionmodules/model.py index 5c0373b74..79160412f 100644 --- a/comfy/ldm/modules/diffusionmodules/model.py +++ b/comfy/ldm/modules/diffusionmodules/model.py @@ -285,7 +285,7 @@ def pytorch_attention(q, k, v): ) try: - out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False) + out = ops.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False) out = out.transpose(2, 3).reshape(orig_shape) except model_management.OOM_EXCEPTION: logging.warning("scaled_dot_product_attention OOMed: switched to slice attention") diff --git a/comfy/ops.py b/comfy/ops.py index 2cc9bbc27..8b7b662b6 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -23,9 +23,18 @@ from comfy.cli_args import args, PerformanceFeature import comfy.float import comfy.rmsnorm import contextlib +from torch.nn.attention import SDPBackend, sdpa_kernel cast_to = comfy.model_management.cast_to #TODO: remove once no more references +SDPA_BACKEND_PRIORITY = [ + SDPBackend.FLASH_ATTENTION, + SDPBackend.EFFICIENT_ATTENTION, + SDPBackend.MATH, +] +if torch.cuda.is_available(): + SDPA_BACKEND_PRIORITY.insert(0, SDPBackend.CUDNN_ATTENTION) + def cast_to_input(weight, input, non_blocking=False, copy=True): return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy) @@ -249,6 +258,10 @@ class disable_weight_init: else: raise ValueError(f"unsupported dimensions: {dims}") + @staticmethod + @sdpa_kernel(backends=SDPA_BACKEND_PRIORITY, set_priority=True) + def scaled_dot_product_attention(q, k, v, *args, **kwargs): + return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs) class manual_cast(disable_weight_init): class Linear(disable_weight_init.Linear): From 9df8792d4b894a8ea8034414ef63f70deee4b1af Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 13 Aug 2025 12:12:41 -0700 Subject: [PATCH 052/325] Make last PR not crash comfy on old pytorch. 
(#9324) --- comfy/ldm/hunyuan3d/vae.py | 2 +- comfy/ldm/modules/attention.py | 4 +-- comfy/ldm/modules/diffusionmodules/model.py | 2 +- comfy/ops.py | 36 +++++++++++++-------- 4 files changed, 27 insertions(+), 17 deletions(-) diff --git a/comfy/ldm/hunyuan3d/vae.py b/comfy/ldm/hunyuan3d/vae.py index bea6090a2..6e8cbf1d9 100644 --- a/comfy/ldm/hunyuan3d/vae.py +++ b/comfy/ldm/hunyuan3d/vae.py @@ -178,7 +178,7 @@ class FourierEmbedder(nn.Module): class CrossAttentionProcessor: def __call__(self, attn, q, k, v): - out = ops.scaled_dot_product_attention(q, k, v) + out = comfy.ops.scaled_dot_product_attention(q, k, v) return out diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index 19c3c7af1..043df28df 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -448,7 +448,7 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha mask = mask.unsqueeze(1) if SDP_BATCH_LIMIT >= b: - out = ops.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False) + out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False) if not skip_output_reshape: out = ( out.transpose(1, 2).reshape(b, -1, heads * dim_head) @@ -461,7 +461,7 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha if mask.shape[0] > 1: m = mask[i : i + SDP_BATCH_LIMIT] - out[i : i + SDP_BATCH_LIMIT] = ops.scaled_dot_product_attention( + out[i : i + SDP_BATCH_LIMIT] = comfy.ops.scaled_dot_product_attention( q[i : i + SDP_BATCH_LIMIT], k[i : i + SDP_BATCH_LIMIT], v[i : i + SDP_BATCH_LIMIT], diff --git a/comfy/ldm/modules/diffusionmodules/model.py b/comfy/ldm/modules/diffusionmodules/model.py index 79160412f..1fd12b35a 100644 --- a/comfy/ldm/modules/diffusionmodules/model.py +++ b/comfy/ldm/modules/diffusionmodules/model.py @@ -285,7 +285,7 @@ def pytorch_attention(q, k, v): ) try: - out = ops.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False) + out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False) out = out.transpose(2, 3).reshape(orig_shape) except model_management.OOM_EXCEPTION: logging.warning("scaled_dot_product_attention OOMed: switched to slice attention") diff --git a/comfy/ops.py b/comfy/ops.py index 8b7b662b6..be312d714 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -23,18 +23,32 @@ from comfy.cli_args import args, PerformanceFeature import comfy.float import comfy.rmsnorm import contextlib -from torch.nn.attention import SDPBackend, sdpa_kernel + + +def scaled_dot_product_attention(q, k, v, *args, **kwargs): + return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs) + + +try: + if torch.cuda.is_available(): + from torch.nn.attention import SDPBackend, sdpa_kernel + + SDPA_BACKEND_PRIORITY = [ + SDPBackend.FLASH_ATTENTION, + SDPBackend.EFFICIENT_ATTENTION, + SDPBackend.MATH, + ] + + SDPA_BACKEND_PRIORITY.insert(0, SDPBackend.CUDNN_ATTENTION) + + @sdpa_kernel(backends=SDPA_BACKEND_PRIORITY, set_priority=True) + def scaled_dot_product_attention(q, k, v, *args, **kwargs): + return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs) +except (ModuleNotFoundError, TypeError): + logging.warning("Could not set sdpa backend priority.") cast_to = comfy.model_management.cast_to #TODO: remove once no more references -SDPA_BACKEND_PRIORITY = [ - SDPBackend.FLASH_ATTENTION, - SDPBackend.EFFICIENT_ATTENTION, - SDPBackend.MATH, -] -if 
torch.cuda.is_available(): - SDPA_BACKEND_PRIORITY.insert(0, SDPBackend.CUDNN_ATTENTION) - def cast_to_input(weight, input, non_blocking=False, copy=True): return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy) @@ -258,10 +272,6 @@ class disable_weight_init: else: raise ValueError(f"unsupported dimensions: {dims}") - @staticmethod - @sdpa_kernel(backends=SDPA_BACKEND_PRIORITY, set_priority=True) - def scaled_dot_product_attention(q, k, v, *args, **kwargs): - return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs) class manual_cast(disable_weight_init): class Linear(disable_weight_init.Linear): From c991a5da658667cf29f2916bef096fa7b18afd47 Mon Sep 17 00:00:00 2001 From: Simon Lui <502929+simonlui@users.noreply.github.com> Date: Wed, 13 Aug 2025 16:13:35 -0700 Subject: [PATCH 053/325] Fix XPU iGPU regressions (#9322) * Change bf16 check and switch non-blocking to off default with option to force to regain speed on certain classes of iGPUs and refactor xpu check. * Turn non_blocking off by default for xpu. * Update README.md for Intel GPUs. --- README.md | 28 ++++++++++------------------ comfy/cli_args.py | 2 ++ comfy/model_management.py | 21 +++++++++++++-------- 3 files changed, 25 insertions(+), 26 deletions(-) diff --git a/README.md b/README.md index e4cff01a9..fa99a8cbe 100644 --- a/README.md +++ b/README.md @@ -39,7 +39,7 @@ ComfyUI lets you design and execute advanced stable diffusion pipelines using a ## Get Started #### [Desktop Application](https://www.comfy.org/download) -- The easiest way to get started. +- The easiest way to get started. - Available on Windows & macOS. #### [Windows Portable Package](#installing) @@ -211,27 +211,19 @@ This is the command to install the nightly with ROCm 6.4 which might have some p ### Intel GPUs (Windows and Linux) -(Option 1) Intel Arc GPU users can install native PyTorch with torch.xpu support using pip (currently available in PyTorch nightly builds). More information can be found [here](https://pytorch.org/docs/main/notes/get_start_xpu.html) - -1. To install PyTorch nightly, use the following command: +(Option 1) Intel Arc GPU users can install native PyTorch with torch.xpu support using pip. More information can be found [here](https://pytorch.org/docs/main/notes/get_start_xpu.html) + +1. To install PyTorch xpu, use the following command: + +```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/xpu``` + +This is the command to install the Pytorch xpu nightly which might have some performance improvements: ```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/xpu``` -2. Launch ComfyUI by running `python main.py` - - (Option 2) Alternatively, Intel GPUs supported by Intel Extension for PyTorch (IPEX) can leverage IPEX for improved performance. -1. For Intel® Arc™ A-Series Graphics utilizing IPEX, create a conda environment and use the commands below: - -``` -conda install libuv -pip install torch==2.3.1.post0+cxx11.abi torchvision==0.18.1.post0+cxx11.abi torchaudio==2.3.1.post0+cxx11.abi intel-extension-for-pytorch==2.3.110.post0+xpu --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/ --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/cn/ -``` - -For other supported Intel GPUs with IPEX, visit [Installation](https://intel.github.io/intel-extension-for-pytorch/index.html#installation?platform=gpu) for more information. 
- -Additional discussion and help can be found [here](https://github.com/comfyanonymous/ComfyUI/discussions/476). +1. visit [Installation](https://intel.github.io/intel-extension-for-pytorch/index.html#installation?platform=gpu) for more information. ### NVIDIA @@ -352,7 +344,7 @@ Generate a self-signed certificate (not appropriate for shared/production use) a Use `--tls-keyfile key.pem --tls-certfile cert.pem` to enable TLS/SSL, the app will now be accessible with `https://...` instead of `http://...`. -> Note: Windows users can use [alexisrolland/docker-openssl](https://github.com/alexisrolland/docker-openssl) or one of the [3rd party binary distributions](https://wiki.openssl.org/index.php/Binaries) to run the command example above. +> Note: Windows users can use [alexisrolland/docker-openssl](https://github.com/alexisrolland/docker-openssl) or one of the [3rd party binary distributions](https://wiki.openssl.org/index.php/Binaries) to run the command example above.
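For reference, the self-signed key and certificate referred to above can be produced with a standard OpenSSL invocation. This is an illustrative sketch only (the exact flags and subject string documented in the README may differ); the file names match the `--tls-keyfile key.pem --tls-certfile cert.pem` example:

```
# writes key.pem and cert.pem, valid ~10 years; adjust -subj to your hostname
openssl req -x509 -newkey rsa:4096 -nodes -days 3650 -keyout key.pem -out cert.pem -subj "/CN=localhost"
```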

If you use a container, note that the volume mount `-v` can be a relative path so `... -v ".\:/openssl-certs" ...` would create the key & cert files in the current directory of your command prompt or powershell terminal. ## Support and dev channel diff --git a/comfy/cli_args.py b/comfy/cli_args.py index 0d760d524..de3e85c08 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -132,6 +132,8 @@ parser.add_argument("--reserve-vram", type=float, default=None, help="Set the am parser.add_argument("--async-offload", action="store_true", help="Use async weight offloading.") +parser.add_argument("--force-non-blocking", action="store_true", help="Force ComfyUI to use non-blocking operations for all applicable tensors. This may improve performance on some non-Nvidia systems but can cause issues with some workflows.") + parser.add_argument("--default-hashing-function", type=str, choices=['md5', 'sha1', 'sha256', 'sha512'], default='sha256', help="Allows you to choose the hash function to use for duplicate filename / contents comparison. Default is sha256.") parser.add_argument("--disable-smart-memory", action="store_true", help="Force ComfyUI to agressively offload to regular ram instead of keeping models in vram when it can.") diff --git a/comfy/model_management.py b/comfy/model_management.py index c08f759e5..2a9f18068 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -78,7 +78,6 @@ try: torch_version = torch.version.__version__ temp = torch_version.split(".") torch_version_numeric = (int(temp[0]), int(temp[1])) - xpu_available = (torch_version_numeric[0] < 2 or (torch_version_numeric[0] == 2 and torch_version_numeric[1] <= 4)) and torch.xpu.is_available() except: pass @@ -102,10 +101,14 @@ if args.directml is not None: try: import intel_extension_for_pytorch as ipex # noqa: F401 - _ = torch.xpu.device_count() - xpu_available = xpu_available or torch.xpu.is_available() except: - xpu_available = xpu_available or (hasattr(torch, "xpu") and torch.xpu.is_available()) + pass + +try: + _ = torch.xpu.device_count() + xpu_available = torch.xpu.is_available() +except: + xpu_available = False try: if torch.backends.mps.is_available(): @@ -946,10 +949,12 @@ def pick_weight_dtype(dtype, fallback_dtype, device=None): return dtype def device_supports_non_blocking(device): + if args.force_non_blocking: + return True if is_device_mps(device): return False #pytorch bug? 
mps doesn't support non blocking - if is_intel_xpu(): - return True + if is_intel_xpu(): #xpu does support non blocking but it is slower on iGPUs for some reason so disable by default until situation changes + return False if args.deterministic: #TODO: figure out why deterministic breaks non blocking from gpu to cpu (previews) return False if directml_enabled: @@ -1282,10 +1287,10 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma return False if is_intel_xpu(): - if torch_version_numeric < (2, 6): + if torch_version_numeric < (2, 3): return True else: - return torch.xpu.get_device_capability(device)['has_bfloat16_conversions'] + return torch.xpu.is_bf16_supported() if is_ascend_npu(): return True From e4f7ea105f4b3034593f316560d952b80453e344 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Wed, 13 Aug 2025 18:33:05 -0700 Subject: [PATCH 054/325] Added context window support to core sampling code (#9238) * Added initial support for basic context windows - in progress * Add prepare_sampling wrapper for context window to more accurately estimate latent memory requirements, fixed merging wrappers/callbacks dicts in prepare_model_patcher * Made context windows compatible with different dimensions; works for WAN, but results are bad * Fix comfy.patcher_extension.merge_nested_dicts calls in prepare_model_patcher in sampler_helpers.py * Considering adding some callbacks to context window code to allow extensions of behavior without the need to rewrite code * Made dim slicing cleaner * Add Wan Context WIndows node for testing * Made context schedule and fuse method functions be stored on the handler instead of needing to be registered in core code to be found * Moved some code around between node_context_windows.py and context_windows.py * Change manual context window nodes names/ids * Added callbacks to IndexListContexHandler * Adjusted default values for context_length and context_overlap, made schema.inputs definition for WAN Context Windows less annoying * Make get_resized_cond more robust for various dim sizes * Fix typo * Another small fix --- comfy/context_windows.py | 537 ++++++++++++++++++++++++++ comfy/sampler_helpers.py | 6 +- comfy/samplers.py | 11 +- comfy_extras/nodes_context_windows.py | 89 +++++ nodes.py | 1 + 5 files changed, 639 insertions(+), 5 deletions(-) create mode 100644 comfy/context_windows.py create mode 100644 comfy_extras/nodes_context_windows.py diff --git a/comfy/context_windows.py b/comfy/context_windows.py new file mode 100644 index 000000000..928b111df --- /dev/null +++ b/comfy/context_windows.py @@ -0,0 +1,537 @@ +from __future__ import annotations +from typing import TYPE_CHECKING, Callable +import torch +import numpy as np +import collections +from dataclasses import dataclass +from abc import ABC, abstractmethod +import logging +import comfy.model_management +import comfy.patcher_extension +if TYPE_CHECKING: + from comfy.model_base import BaseModel + from comfy.model_patcher import ModelPatcher + from comfy.controlnet import ControlBase + + +class ContextWindowABC(ABC): + def __init__(self): + ... + + @abstractmethod + def get_tensor(self, full: torch.Tensor) -> torch.Tensor: + """ + Get torch.Tensor applicable to current window. + """ + raise NotImplementedError("Not implemented.") + + @abstractmethod + def add_window(self, full: torch.Tensor, to_add: torch.Tensor) -> torch.Tensor: + """ + Apply torch.Tensor of window to the full tensor, in place. Returns reference to updated full tensor, not a copy. 
+ """ + raise NotImplementedError("Not implemented.") + +class ContextHandlerABC(ABC): + def __init__(self): + ... + + @abstractmethod + def should_use_context(self, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]) -> bool: + raise NotImplementedError("Not implemented.") + + @abstractmethod + def get_resized_cond(self, cond_in: list[dict], x_in: torch.Tensor, window: ContextWindowABC, device=None) -> list: + raise NotImplementedError("Not implemented.") + + @abstractmethod + def execute(self, calc_cond_batch: Callable, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]): + raise NotImplementedError("Not implemented.") + + + +class IndexListContextWindow(ContextWindowABC): + def __init__(self, index_list: list[int], dim: int=0): + self.index_list = index_list + self.context_length = len(index_list) + self.dim = dim + + def get_tensor(self, full: torch.Tensor, device=None, dim=None) -> torch.Tensor: + if dim is None: + dim = self.dim + if dim == 0 and full.shape[dim] == 1: + return full + idx = [slice(None)] * dim + [self.index_list] + return full[idx].to(device) + + def add_window(self, full: torch.Tensor, to_add: torch.Tensor, dim=None) -> torch.Tensor: + if dim is None: + dim = self.dim + idx = [slice(None)] * dim + [self.index_list] + full[idx] += to_add + return full + + +class IndexListCallbacks: + EVALUATE_CONTEXT_WINDOWS = "evaluate_context_windows" + COMBINE_CONTEXT_WINDOW_RESULTS = "combine_context_window_results" + EXECUTE_START = "execute_start" + EXECUTE_CLEANUP = "execute_cleanup" + + def init_callbacks(self): + return {} + + +@dataclass +class ContextSchedule: + name: str + func: Callable + +@dataclass +class ContextFuseMethod: + name: str + func: Callable + +ContextResults = collections.namedtuple("ContextResults", ['window_idx', 'sub_conds_out', 'sub_conds', 'window']) +class IndexListContextHandler(ContextHandlerABC): + def __init__(self, context_schedule: ContextSchedule, fuse_method: ContextFuseMethod, context_length: int=1, context_overlap: int=0, context_stride: int=1, closed_loop=False, dim=0): + self.context_schedule = context_schedule + self.fuse_method = fuse_method + self.context_length = context_length + self.context_overlap = context_overlap + self.context_stride = context_stride + self.closed_loop = closed_loop + self.dim = dim + self._step = 0 + + self.callbacks = {} + + def should_use_context(self, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]) -> bool: + # for now, assume first dim is batch - should have stored on BaseModel in actual implementation + if x_in.size(self.dim) > self.context_length: + logging.info(f"Using context windows {self.context_length} for {x_in.size(self.dim)} frames.") + return True + return False + + def prepare_control_objects(self, control: ControlBase, device=None) -> ControlBase: + if control.previous_controlnet is not None: + self.prepare_control_objects(control.previous_controlnet, device) + return control + + def get_resized_cond(self, cond_in: list[dict], x_in: torch.Tensor, window: IndexListContextWindow, device=None) -> list: + if cond_in is None: + return None + # reuse or resize cond items to match context requirements + resized_cond = [] + # cond object is a list containing a dict - outer list is irrelevant, so just loop through it + for actual_cond in cond_in: + resized_actual_cond = actual_cond.copy() + # now we are in the inner dict 
- "pooled_output" is a tensor, "control" is a ControlBase object, "model_conds" is dictionary + for key in actual_cond: + try: + cond_item = actual_cond[key] + if isinstance(cond_item, torch.Tensor): + # check that tensor is the expected length - x.size(0) + if self.dim < cond_item.ndim and cond_item.size(self.dim) == x_in.size(self.dim): + # if so, it's subsetting time - tell controls the expected indeces so they can handle them + actual_cond_item = window.get_tensor(cond_item) + resized_actual_cond[key] = actual_cond_item.to(device) + else: + resized_actual_cond[key] = cond_item.to(device) + # look for control + elif key == "control": + resized_actual_cond[key] = self.prepare_control_objects(cond_item, device) + elif isinstance(cond_item, dict): + new_cond_item = cond_item.copy() + # when in dictionary, look for tensors and CONDCrossAttn [comfy/conds.py] (has cond attr that is a tensor) + for cond_key, cond_value in new_cond_item.items(): + if isinstance(cond_value, torch.Tensor): + if cond_value.ndim < self.dim and cond_value.size(0) == x_in.size(self.dim): + new_cond_item[cond_key] = window.get_tensor(cond_value, device) + # if has cond that is a Tensor, check if needs to be subset + elif hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor): + if cond_value.cond.ndim < self.dim and cond_value.cond.size(0) == x_in.size(self.dim): + new_cond_item[cond_key] = cond_value._copy_with(window.get_tensor(cond_value.cond, device)) + elif cond_key == "num_video_frames": # for SVD + new_cond_item[cond_key] = cond_value._copy_with(cond_value.cond) + new_cond_item[cond_key].cond = window.context_length + resized_actual_cond[key] = new_cond_item + else: + resized_actual_cond[key] = cond_item + finally: + del cond_item # just in case to prevent VRAM issues + resized_cond.append(resized_actual_cond) + return resized_cond + + def set_step(self, timestep: torch.Tensor, model_options: dict[str]): + indexes = torch.where(model_options["transformer_options"]["sample_sigmas"] == timestep[0]) + self._step = int(indexes[0]) + + def get_context_windows(self, model: BaseModel, x_in: torch.Tensor, model_options: dict[str]) -> list[IndexListContextWindow]: + full_length = x_in.size(self.dim) # TODO: choose dim based on model + context_windows = self.context_schedule.func(full_length, self, model_options) + context_windows = [IndexListContextWindow(window, dim=self.dim) for window in context_windows] + return context_windows + + def execute(self, calc_cond_batch: Callable, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]): + self.set_step(timestep, model_options) + context_windows = self.get_context_windows(model, x_in, model_options) + enumerated_context_windows = list(enumerate(context_windows)) + + conds_final = [torch.zeros_like(x_in) for _ in conds] + if self.fuse_method.name == ContextFuseMethods.RELATIVE: + counts_final = [torch.ones(get_shape_for_dim(x_in, self.dim), device=x_in.device) for _ in conds] + else: + counts_final = [torch.zeros(get_shape_for_dim(x_in, self.dim), device=x_in.device) for _ in conds] + biases_final = [([0.0] * x_in.shape[self.dim]) for _ in conds] + + for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EXECUTE_START, self.callbacks): + callback(self, model, x_in, conds, timestep, model_options) + + for enum_window in enumerated_context_windows: + results = self.evaluate_context_windows(calc_cond_batch, model, x_in, conds, timestep, [enum_window], model_options) + for result 
in results: + self.combine_context_window_results(x_in, result.sub_conds_out, result.sub_conds, result.window, result.window_idx, len(enumerated_context_windows), timestep, + conds_final, counts_final, biases_final) + try: + # finalize conds + if self.fuse_method.name == ContextFuseMethods.RELATIVE: + # relative is already normalized, so return as is + del counts_final + return conds_final + else: + # normalize conds via division by context usage counts + for i in range(len(conds_final)): + conds_final[i] /= counts_final[i] + del counts_final + return conds_final + finally: + for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EXECUTE_CLEANUP, self.callbacks): + callback(self, model, x_in, conds, timestep, model_options) + + def evaluate_context_windows(self, calc_cond_batch: Callable, model: BaseModel, x_in: torch.Tensor, conds, timestep: torch.Tensor, enumerated_context_windows: list[tuple[int, IndexListContextWindow]], + model_options, device=None, first_device=None): + results: list[ContextResults] = [] + for window_idx, window in enumerated_context_windows: + # allow processing to end between context window executions for faster Cancel + comfy.model_management.throw_exception_if_processing_interrupted() + + for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EVALUATE_CONTEXT_WINDOWS, self.callbacks): + callback(self, model, x_in, conds, timestep, model_options, window_idx, window, model_options, device, first_device) + + # update exposed params + model_options["transformer_options"]["context_window"] = window + # get subsections of x, timestep, conds + sub_x = window.get_tensor(x_in, device) + sub_timestep = window.get_tensor(timestep, device, dim=0) + sub_conds = [self.get_resized_cond(cond, x_in, window, device) for cond in conds] + + sub_conds_out = calc_cond_batch(model, sub_conds, sub_x, sub_timestep, model_options) + if device is not None: + for i in range(len(sub_conds_out)): + sub_conds_out[i] = sub_conds_out[i].to(x_in.device) + results.append(ContextResults(window_idx, sub_conds_out, sub_conds, window)) + return results + + + def combine_context_window_results(self, x_in: torch.Tensor, sub_conds_out, sub_conds, window: IndexListContextWindow, window_idx: int, total_windows: int, timestep: torch.Tensor, + conds_final: list[torch.Tensor], counts_final: list[torch.Tensor], biases_final: list[torch.Tensor]): + if self.fuse_method.name == ContextFuseMethods.RELATIVE: + for pos, idx in enumerate(window.index_list): + # bias is the influence of a specific index in relation to the whole context window + bias = 1 - abs(idx - (window.index_list[0] + window.index_list[-1]) / 2) / ((window.index_list[-1] - window.index_list[0] + 1e-2) / 2) + bias = max(1e-2, bias) + # take weighted average relative to total bias of current idx + for i in range(len(sub_conds_out)): + bias_total = biases_final[i][idx] + prev_weight = (bias_total / (bias_total + bias)) + new_weight = (bias / (bias_total + bias)) + # account for dims of tensors + idx_window = [slice(None)] * self.dim + [idx] + pos_window = [slice(None)] * self.dim + [pos] + # apply new values + conds_final[i][idx_window] = conds_final[i][idx_window] * prev_weight + sub_conds_out[i][pos_window] * new_weight + biases_final[i][idx] = bias_total + bias + else: + # add conds and counts based on weights of fuse method + weights = get_context_weights(window.context_length, x_in.shape[self.dim], window.index_list, self, sigma=timestep) + weights_tensor = match_weights_to_dim(weights, x_in, self.dim, 
device=x_in.device) + for i in range(len(sub_conds_out)): + window.add_window(conds_final[i], sub_conds_out[i] * weights_tensor) + window.add_window(counts_final[i], weights_tensor) + + for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.COMBINE_CONTEXT_WINDOW_RESULTS, self.callbacks): + callback(self, x_in, sub_conds_out, sub_conds, window, window_idx, total_windows, timestep, conds_final, counts_final, biases_final) + + +def _prepare_sampling_wrapper(executor, model, noise_shape: torch.Tensor, *args, **kwargs): + # limit noise_shape length to context_length for more accurate vram use estimation + model_options = kwargs.get("model_options", None) + if model_options is None: + raise Exception("model_options not found in prepare_sampling_wrapper; this should never happen, something went wrong.") + handler: IndexListContextHandler = model_options.get("context_handler", None) + if handler is not None: + noise_shape = list(noise_shape) + noise_shape[handler.dim] = min(noise_shape[handler.dim], handler.context_length) + return executor(model, noise_shape, *args, **kwargs) + + +def create_prepare_sampling_wrapper(model: ModelPatcher): + model.add_wrapper_with_key( + comfy.patcher_extension.WrappersMP.PREPARE_SAMPLING, + "ContextWindows_prepare_sampling", + _prepare_sampling_wrapper + ) + + +def match_weights_to_dim(weights: list[float], x_in: torch.Tensor, dim: int, device=None) -> torch.Tensor: + total_dims = len(x_in.shape) + weights_tensor = torch.Tensor(weights).to(device=device) + for _ in range(dim): + weights_tensor = weights_tensor.unsqueeze(0) + for _ in range(total_dims - dim - 1): + weights_tensor = weights_tensor.unsqueeze(-1) + return weights_tensor + +def get_shape_for_dim(x_in: torch.Tensor, dim: int) -> list[int]: + total_dims = len(x_in.shape) + shape = [] + for _ in range(dim): + shape.append(1) + shape.append(x_in.shape[dim]) + for _ in range(total_dims - dim - 1): + shape.append(1) + return shape + +class ContextSchedules: + UNIFORM_LOOPED = "looped_uniform" + UNIFORM_STANDARD = "standard_uniform" + STATIC_STANDARD = "standard_static" + BATCHED = "batched" + + +# from https://github.com/neggles/animatediff-cli/blob/main/src/animatediff/pipelines/context.py +def create_windows_uniform_looped(num_frames: int, handler: IndexListContextHandler, model_options: dict[str]): + windows = [] + if num_frames < handler.context_length: + windows.append(list(range(num_frames))) + return windows + + context_stride = min(handler.context_stride, int(np.ceil(np.log2(num_frames / handler.context_length))) + 1) + # obtain uniform windows as normal, looping and all + for context_step in 1 << np.arange(context_stride): + pad = int(round(num_frames * ordered_halving(handler._step))) + for j in range( + int(ordered_halving(handler._step) * context_step) + pad, + num_frames + pad + (0 if handler.closed_loop else -handler.context_overlap), + (handler.context_length * context_step - handler.context_overlap), + ): + windows.append([e % num_frames for e in range(j, j + handler.context_length * context_step, context_step)]) + + return windows + +def create_windows_uniform_standard(num_frames: int, handler: IndexListContextHandler, model_options: dict[str]): + # unlike looped, uniform_straight does NOT allow windows that loop back to the beginning; + # instead, they get shifted to the corresponding end of the frames. + # in the case that a window (shifted or not) is identical to the previous one, it gets skipped. 
+ windows = [] + if num_frames <= handler.context_length: + windows.append(list(range(num_frames))) + return windows + + context_stride = min(handler.context_stride, int(np.ceil(np.log2(num_frames / handler.context_length))) + 1) + # first, obtain uniform windows as normal, looping and all + for context_step in 1 << np.arange(context_stride): + pad = int(round(num_frames * ordered_halving(handler._step))) + for j in range( + int(ordered_halving(handler._step) * context_step) + pad, + num_frames + pad + (-handler.context_overlap), + (handler.context_length * context_step - handler.context_overlap), + ): + windows.append([e % num_frames for e in range(j, j + handler.context_length * context_step, context_step)]) + + # now that windows are created, shift any windows that loop, and delete duplicate windows + delete_idxs = [] + win_i = 0 + while win_i < len(windows): + # if window is rolls over itself, need to shift it + is_roll, roll_idx = does_window_roll_over(windows[win_i], num_frames) + if is_roll: + roll_val = windows[win_i][roll_idx] # roll_val might not be 0 for windows of higher strides + shift_window_to_end(windows[win_i], num_frames=num_frames) + # check if next window (cyclical) is missing roll_val + if roll_val not in windows[(win_i+1) % len(windows)]: + # need to insert new window here - just insert window starting at roll_val + windows.insert(win_i+1, list(range(roll_val, roll_val + handler.context_length))) + # delete window if it's not unique + for pre_i in range(0, win_i): + if windows[win_i] == windows[pre_i]: + delete_idxs.append(win_i) + break + win_i += 1 + + # reverse delete_idxs so that they will be deleted in an order that doesn't break idx correlation + delete_idxs.reverse() + for i in delete_idxs: + windows.pop(i) + + return windows + + +def create_windows_static_standard(num_frames: int, handler: IndexListContextHandler, model_options: dict[str]): + windows = [] + if num_frames <= handler.context_length: + windows.append(list(range(num_frames))) + return windows + # always return the same set of windows + delta = handler.context_length - handler.context_overlap + for start_idx in range(0, num_frames, delta): + # if past the end of frames, move start_idx back to allow same context_length + ending = start_idx + handler.context_length + if ending >= num_frames: + final_delta = ending - num_frames + final_start_idx = start_idx - final_delta + windows.append(list(range(final_start_idx, final_start_idx + handler.context_length))) + break + windows.append(list(range(start_idx, start_idx + handler.context_length))) + return windows + + +def create_windows_batched(num_frames: int, handler: IndexListContextHandler, model_options: dict[str]): + windows = [] + if num_frames <= handler.context_length: + windows.append(list(range(num_frames))) + return windows + # always return the same set of windows; + # no overlap, just cut up based on context_length; + # last window size will be different if num_frames % opts.context_length != 0 + for start_idx in range(0, num_frames, handler.context_length): + windows.append(list(range(start_idx, min(start_idx + handler.context_length, num_frames)))) + return windows + + +def create_windows_default(num_frames: int, handler: IndexListContextHandler): + return [list(range(num_frames))] + + +CONTEXT_MAPPING = { + ContextSchedules.UNIFORM_LOOPED: create_windows_uniform_looped, + ContextSchedules.UNIFORM_STANDARD: create_windows_uniform_standard, + ContextSchedules.STATIC_STANDARD: create_windows_static_standard, + ContextSchedules.BATCHED: 
create_windows_batched, +} + + +def get_matching_context_schedule(context_schedule: str) -> ContextSchedule: + func = CONTEXT_MAPPING.get(context_schedule, None) + if func is None: + raise ValueError(f"Unknown context_schedule '{context_schedule}'.") + return ContextSchedule(context_schedule, func) + + +def get_context_weights(length: int, full_length: int, idxs: list[int], handler: IndexListContextHandler, sigma: torch.Tensor=None): + return handler.fuse_method.func(length, sigma=sigma, handler=handler, full_length=full_length, idxs=idxs) + + +def create_weights_flat(length: int, **kwargs) -> list[float]: + # weight is the same for all + return [1.0] * length + +def create_weights_pyramid(length: int, **kwargs) -> list[float]: + # weight is based on the distance away from the edge of the context window; + # based on weighted average concept in FreeNoise paper + if length % 2 == 0: + max_weight = length // 2 + weight_sequence = list(range(1, max_weight + 1, 1)) + list(range(max_weight, 0, -1)) + else: + max_weight = (length + 1) // 2 + weight_sequence = list(range(1, max_weight, 1)) + [max_weight] + list(range(max_weight - 1, 0, -1)) + return weight_sequence + +def create_weights_overlap_linear(length: int, full_length: int, idxs: list[int], handler: IndexListContextHandler, **kwargs): + # based on code in Kijai's WanVideoWrapper: https://github.com/kijai/ComfyUI-WanVideoWrapper/blob/dbb2523b37e4ccdf45127e5ae33e31362f755c8e/nodes.py#L1302 + # only expected overlap is given different weights + weights_torch = torch.ones((length)) + # blend left-side on all except first window + if min(idxs) > 0: + ramp_up = torch.linspace(1e-37, 1, handler.context_overlap) + weights_torch[:handler.context_overlap] = ramp_up + # blend right-side on all except last window + if max(idxs) < full_length-1: + ramp_down = torch.linspace(1, 1e-37, handler.context_overlap) + weights_torch[-handler.context_overlap:] = ramp_down + return weights_torch + +class ContextFuseMethods: + FLAT = "flat" + PYRAMID = "pyramid" + RELATIVE = "relative" + OVERLAP_LINEAR = "overlap-linear" + + LIST = [PYRAMID, FLAT, OVERLAP_LINEAR] + LIST_STATIC = [PYRAMID, RELATIVE, FLAT, OVERLAP_LINEAR] + + +FUSE_MAPPING = { + ContextFuseMethods.FLAT: create_weights_flat, + ContextFuseMethods.PYRAMID: create_weights_pyramid, + ContextFuseMethods.RELATIVE: create_weights_pyramid, + ContextFuseMethods.OVERLAP_LINEAR: create_weights_overlap_linear, +} + +def get_matching_fuse_method(fuse_method: str) -> ContextFuseMethod: + func = FUSE_MAPPING.get(fuse_method, None) + if func is None: + raise ValueError(f"Unknown fuse_method '{fuse_method}'.") + return ContextFuseMethod(fuse_method, func) + +# Returns fraction that has denominator that is a power of 2 +def ordered_halving(val): + # get binary value, padded with 0s for 64 bits + bin_str = f"{val:064b}" + # flip binary value, padding included + bin_flip = bin_str[::-1] + # convert binary to int + as_int = int(bin_flip, 2) + # divide by 1 << 64, equivalent to 2**64, or 18446744073709551616, + # or b10000000000000000000000000000000000000000000000000000000000000000 (1 with 64 zero's) + return as_int / (1 << 64) + + +def get_missing_indexes(windows: list[list[int]], num_frames: int) -> list[int]: + all_indexes = list(range(num_frames)) + for w in windows: + for val in w: + try: + all_indexes.remove(val) + except ValueError: + pass + return all_indexes + + +def does_window_roll_over(window: list[int], num_frames: int) -> tuple[bool, int]: + prev_val = -1 + for i, val in enumerate(window): + val = val % 
num_frames + if val < prev_val: + return True, i + prev_val = val + return False, -1 + + +def shift_window_to_start(window: list[int], num_frames: int): + start_val = window[0] + for i in range(len(window)): + # 1) subtract each element by start_val to move vals relative to the start of all frames + # 2) add num_frames and take modulus to get adjusted vals + window[i] = ((window[i] - start_val) + num_frames) % num_frames + + +def shift_window_to_end(window: list[int], num_frames: int): + # 1) shift window to start + shift_window_to_start(window, num_frames) + end_val = window[-1] + end_delta = num_frames - end_val - 1 + for i in range(len(window)): + # 2) add end_delta to each val to slide windows to end + window[i] = window[i] + end_delta diff --git a/comfy/sampler_helpers.py b/comfy/sampler_helpers.py index 8dbc41455..e46971afb 100644 --- a/comfy/sampler_helpers.py +++ b/comfy/sampler_helpers.py @@ -149,7 +149,7 @@ def cleanup_models(conds, models): cleanup_additional_models(set(control_cleanup)) -def prepare_model_patcher(model: 'ModelPatcher', conds, model_options: dict): +def prepare_model_patcher(model: ModelPatcher, conds, model_options: dict): ''' Registers hooks from conds. ''' @@ -158,8 +158,8 @@ def prepare_model_patcher(model: 'ModelPatcher', conds, model_options: dict): for k in conds: get_hooks_from_cond(conds[k], hooks) # add wrappers and callbacks from ModelPatcher to transformer_options - model_options["transformer_options"]["wrappers"] = comfy.patcher_extension.copy_nested_dicts(model.wrappers) - model_options["transformer_options"]["callbacks"] = comfy.patcher_extension.copy_nested_dicts(model.callbacks) + comfy.patcher_extension.merge_nested_dicts(model_options["transformer_options"].setdefault("wrappers", {}), model.wrappers, copy_dict1=False) + comfy.patcher_extension.merge_nested_dicts(model_options["transformer_options"].setdefault("callbacks", {}), model.callbacks, copy_dict1=False) # begin registering hooks registered = comfy.hooks.HookGroup() target_dict = comfy.hooks.create_target_dict(comfy.hooks.EnumWeightTarget.Model) diff --git a/comfy/samplers.py b/comfy/samplers.py index ad2f40cdc..d5390d64e 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -16,6 +16,7 @@ import comfy.sampler_helpers import comfy.model_patcher import comfy.patcher_extension import comfy.hooks +import comfy.context_windows import scipy.stats import numpy @@ -198,14 +199,20 @@ def finalize_default_conds(model: 'BaseModel', hooked_to_run: dict[comfy.hooks.H hooked_to_run.setdefault(p.hooks, list()) hooked_to_run[p.hooks] += [(p, i)] -def calc_cond_batch(model: 'BaseModel', conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options): +def calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options: dict[str]): + handler: comfy.context_windows.ContextHandlerABC = model_options.get("context_handler", None) + if handler is None or not handler.should_use_context(model, conds, x_in, timestep, model_options): + return _calc_cond_batch_outer(model, conds, x_in, timestep, model_options) + return handler.execute(_calc_cond_batch_outer, model, conds, x_in, timestep, model_options) + +def _calc_cond_batch_outer(model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options): executor = comfy.patcher_extension.WrapperExecutor.new_executor( _calc_cond_batch, comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.CALC_COND_BATCH, model_options, is_model_options=True) ) return executor.execute(model, 
conds, x_in, timestep, model_options) -def _calc_cond_batch(model: 'BaseModel', conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options): +def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options): out_conds = [] out_counts = [] # separate conds by matching hooks diff --git a/comfy_extras/nodes_context_windows.py b/comfy_extras/nodes_context_windows.py new file mode 100644 index 000000000..1c3d9e697 --- /dev/null +++ b/comfy_extras/nodes_context_windows.py @@ -0,0 +1,89 @@ +from __future__ import annotations +from comfy_api.latest import ComfyExtension, io +import comfy.context_windows +import nodes + + +class ContextWindowsManualNode(io.ComfyNode): + @classmethod + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="ContextWindowsManual", + display_name="Context Windows (Manual)", + category="context", + description="Manually set context windows.", + inputs=[ + io.Model.Input("model", tooltip="The model to apply context windows to during sampling."), + io.Int.Input("context_length", min=1, default=16, tooltip="The length of the context window."), + io.Int.Input("context_overlap", min=0, default=4, tooltip="The overlap of the context window."), + io.Combo.Input("context_schedule", options=[ + comfy.context_windows.ContextSchedules.STATIC_STANDARD, + comfy.context_windows.ContextSchedules.UNIFORM_STANDARD, + comfy.context_windows.ContextSchedules.UNIFORM_LOOPED, + comfy.context_windows.ContextSchedules.BATCHED, + ], tooltip="The stride of the context window."), + io.Int.Input("context_stride", min=1, default=1, tooltip="The stride of the context window; only applicable to uniform schedules."), + io.Boolean.Input("closed_loop", default=False, tooltip="Whether to close the context window loop; only applicable to looped schedules."), + io.Combo.Input("fuse_method", options=comfy.context_windows.ContextFuseMethods.LIST_STATIC, default=comfy.context_windows.ContextFuseMethods.PYRAMID, tooltip="The method to use to fuse the context windows."), + io.Int.Input("dim", min=0, max=5, default=0, tooltip="The dimension to apply the context windows to."), + ], + outputs=[ + io.Model.Output(tooltip="The model with context windows applied during sampling."), + ], + is_experimental=True, + ) + + @classmethod + def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, context_stride: int, closed_loop: bool, fuse_method: str, dim: int) -> io.Model: + model = model.clone() + model.model_options["context_handler"] = comfy.context_windows.IndexListContextHandler( + context_schedule=comfy.context_windows.get_matching_context_schedule(context_schedule), + fuse_method=comfy.context_windows.get_matching_fuse_method(fuse_method), + context_length=context_length, + context_overlap=context_overlap, + context_stride=context_stride, + closed_loop=closed_loop, + dim=dim) + # make memory usage calculation only take into account the context window latents + comfy.context_windows.create_prepare_sampling_wrapper(model) + return io.NodeOutput(model) + +class WanContextWindowsManualNode(ContextWindowsManualNode): + @classmethod + def define_schema(cls) -> io.Schema: + schema = super().define_schema() + schema.node_id = "WanContextWindowsManual" + schema.display_name = "WAN Context Windows (Manual)" + schema.description = "Manually set context windows for WAN-like models (dim=2)." 
+ schema.inputs = [ + io.Model.Input("model", tooltip="The model to apply context windows to during sampling."), + io.Int.Input("context_length", min=1, max=nodes.MAX_RESOLUTION, step=4, default=81, tooltip="The length of the context window."), + io.Int.Input("context_overlap", min=0, default=30, tooltip="The overlap of the context window."), + io.Combo.Input("context_schedule", options=[ + comfy.context_windows.ContextSchedules.STATIC_STANDARD, + comfy.context_windows.ContextSchedules.UNIFORM_STANDARD, + comfy.context_windows.ContextSchedules.UNIFORM_LOOPED, + comfy.context_windows.ContextSchedules.BATCHED, + ], tooltip="The stride of the context window."), + io.Int.Input("context_stride", min=1, default=1, tooltip="The stride of the context window; only applicable to uniform schedules."), + io.Boolean.Input("closed_loop", default=False, tooltip="Whether to close the context window loop; only applicable to looped schedules."), + io.Combo.Input("fuse_method", options=comfy.context_windows.ContextFuseMethods.LIST_STATIC, default=comfy.context_windows.ContextFuseMethods.PYRAMID, tooltip="The method to use to fuse the context windows."), + ] + return schema + + @classmethod + def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, context_stride: int, closed_loop: bool, fuse_method: str) -> io.Model: + context_length = max(((context_length - 1) // 4) + 1, 1) # at least length 1 + context_overlap = max(((context_overlap - 1) // 4) + 1, 0) # at least overlap 0 + return super().execute(model, context_length, context_overlap, context_schedule, context_stride, closed_loop, fuse_method, dim=2) + + +class ContextWindowsExtension(ComfyExtension): + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + ContextWindowsManualNode, + WanContextWindowsManualNode, + ] + +def comfy_entrypoint(): + return ContextWindowsExtension() diff --git a/nodes.py b/nodes.py index 9448f9c1b..860a236aa 100644 --- a/nodes.py +++ b/nodes.py @@ -2320,6 +2320,7 @@ async def init_builtin_extra_nodes(): "nodes_camera_trajectory.py", "nodes_edit_model.py", "nodes_tcfg.py", + "nodes_context_windows.py", ] import_failed = [] From 72fd4d22b6a4fa11a3f737c9a633e7d635a42181 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Thu, 14 Aug 2025 13:03:21 -0700 Subject: [PATCH 055/325] av is an essential dependency. (#9341) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index bfb31a73f..551002b5b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -20,11 +20,11 @@ tqdm psutil alembic SQLAlchemy +av>=14.2.0 #non essential dependencies: kornia>=0.7.1 spandrel soundfile -av>=14.2.0 pydantic~=2.0 pydantic-settings~=2.0 From 644b23ac0b92442b475e44397c62aa8de929d546 Mon Sep 17 00:00:00 2001 From: filtered <176114999+webfiltered@users.noreply.github.com> Date: Fri, 15 Aug 2025 07:36:53 +1000 Subject: [PATCH 056/325] Make custom node testing checkbox optional in issue templates (#9342) The checkbox for confirming custom node testing is now optional in both bug report and user support templates. This allows users to submit issues even if they haven't been able to test with custom nodes disabled, making the reporting process more accessible. 
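For context, GitHub issue forms declare these confirmations as a `checkboxes` element whose individual options carry the `required` flag. A minimal sketch of the relaxed field is shown below; the structure follows GitHub's issue forms schema, but the `label` strings are placeholders rather than the exact template text:

```
# one element of the issue form's `body` list
- type: checkboxes
  attributes:
    label: Custom Node Testing            # placeholder label
    description: Please confirm you have tried to reproduce the issue with all custom nodes disabled.
    options:
      - label: I have tried disabling custom nodes and the issue persists
        required: false                   # previously true; submission no longer blocks on this
```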
--- .github/ISSUE_TEMPLATE/bug-report.yml | 2 +- .github/ISSUE_TEMPLATE/user-support.yml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/ISSUE_TEMPLATE/bug-report.yml b/.github/ISSUE_TEMPLATE/bug-report.yml index 69ce998eb..3cf2717b7 100644 --- a/.github/ISSUE_TEMPLATE/bug-report.yml +++ b/.github/ISSUE_TEMPLATE/bug-report.yml @@ -22,7 +22,7 @@ body: description: Please confirm you have tried to reproduce the issue with all custom nodes disabled. options: - label: I have tried disabling custom nodes and the issue persists (see [how to disable custom nodes](https://docs.comfy.org/troubleshooting/custom-node-issues#step-1%3A-test-with-all-custom-nodes-disabled) if you need help) - required: true + required: false - type: textarea attributes: label: Expected Behavior diff --git a/.github/ISSUE_TEMPLATE/user-support.yml b/.github/ISSUE_TEMPLATE/user-support.yml index 50657d493..281661f92 100644 --- a/.github/ISSUE_TEMPLATE/user-support.yml +++ b/.github/ISSUE_TEMPLATE/user-support.yml @@ -18,7 +18,7 @@ body: description: Please confirm you have tried to reproduce the issue with all custom nodes disabled. options: - label: I have tried disabling custom nodes and the issue persists (see [how to disable custom nodes](https://docs.comfy.org/troubleshooting/custom-node-issues#step-1%3A-test-with-all-custom-nodes-disabled) if you need help) - required: true + required: false - type: textarea attributes: label: Your question From fa570cbf599657e73c636872616c0b1f8e74f692 Mon Sep 17 00:00:00 2001 From: Yoland Yan <4950057+yoland68@users.noreply.github.com> Date: Thu, 14 Aug 2025 16:44:22 -0700 Subject: [PATCH 057/325] Update CODEOWNERS (#9343) --- CODEOWNERS | 27 ++++++++++++++------------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/CODEOWNERS b/CODEOWNERS index c4acbf06e..c8acd66d5 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -5,20 +5,21 @@ # Inlined the team members for now. 
# Maintainers -*.md @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne -/tests/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne -/tests-unit/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne -/notebooks/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne -/script_examples/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne -/.github/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne -/requirements.txt @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne -/pyproject.toml @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne +*.md @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne @guill +/tests/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne @guill +/tests-unit/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne @guill +/notebooks/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne @guill +/script_examples/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne @guill +/.github/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne @guill +/requirements.txt @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne @guill +/pyproject.toml @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne @guill # Python web server -/api_server/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @christian-byrne -/app/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @christian-byrne -/utils/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @christian-byrne +/api_server/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @christian-byrne @guill +/app/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @christian-byrne @guill +/utils/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @christian-byrne @guill # Node developers -/comfy_extras/ @yoland68 @robinjhuang @pythongosssss @ltdrdata @Kosinkadink @webfiltered @christian-byrne -/comfy/comfy_types/ @yoland68 @robinjhuang @pythongosssss @ltdrdata @Kosinkadink @webfiltered @christian-byrne +/comfy_extras/ @yoland68 @robinjhuang @pythongosssss @ltdrdata @Kosinkadink @webfiltered @christian-byrne @guill +/comfy/comfy_types/ @yoland68 @robinjhuang @pythongosssss @ltdrdata @Kosinkadink @webfiltered @christian-byrne @guill +/comfy_api_nodes/ @yoland68 @robinjhuang @pythongosssss @ltdrdata @Kosinkadink @webfiltered @christian-byrne @guill From deebee4ff6fd1b2713683d22e5e2e07170daa867 Mon Sep 17 00:00:00 2001 From: guill Date: Thu, 14 Aug 2025 18:46:55 -0700 Subject: [PATCH 058/325] Update default parameters for Moonvalley video nodes (#9290) * Update default parameters for Moonvalley video nodes - Changed default negative prompts to a more extensive list for both BaseMoonvalleyVideoNode and MoonvalleyVideo2VideoNode. - Updated default guidance scale values for both nodes to enhance prompt adherence. - Set a fixed default seed value for consistency in video generation. 
* no message * ruff fix --------- Co-authored-by: thorsten --- comfy_api_nodes/nodes_moonvalley.py | 128 ++++++++++++++++++++-------- 1 file changed, 91 insertions(+), 37 deletions(-) diff --git a/comfy_api_nodes/nodes_moonvalley.py b/comfy_api_nodes/nodes_moonvalley.py index 164ca3ea5..806a70e06 100644 --- a/comfy_api_nodes/nodes_moonvalley.py +++ b/comfy_api_nodes/nodes_moonvalley.py @@ -1,6 +1,5 @@ import logging from typing import Any, Callable, Optional, TypeVar -import random import torch from comfy_api_nodes.util.validation_utils import ( get_image_dimensions, @@ -208,20 +207,29 @@ def _get_video_dimensions(video: VideoInput) -> tuple[int, int]: def _validate_video_dimensions(width: int, height: int) -> None: """Validates video dimensions meet Moonvalley V2V requirements.""" supported_resolutions = { - (1920, 1080), (1080, 1920), (1152, 1152), - (1536, 1152), (1152, 1536) + (1920, 1080), + (1080, 1920), + (1152, 1152), + (1536, 1152), + (1152, 1536), } if (width, height) not in supported_resolutions: - supported_list = ', '.join([f'{w}x{h}' for w, h in sorted(supported_resolutions)]) - raise ValueError(f"Resolution {width}x{height} not supported. Supported: {supported_list}") + supported_list = ", ".join( + [f"{w}x{h}" for w, h in sorted(supported_resolutions)] + ) + raise ValueError( + f"Resolution {width}x{height} not supported. Supported: {supported_list}" + ) def _validate_container_format(video: VideoInput) -> None: """Validates video container format is MP4.""" container_format = video.get_container_format() - if container_format not in ['mp4', 'mov,mp4,m4a,3gp,3g2,mj2']: - raise ValueError(f"Only MP4 container format supported. Got: {container_format}") + if container_format not in ["mp4", "mov,mp4,m4a,3gp,3g2,mj2"]: + raise ValueError( + f"Only MP4 container format supported. 
Got: {container_format}" + ) def _validate_and_trim_duration(video: VideoInput) -> VideoInput: @@ -244,7 +252,6 @@ def _trim_if_too_long(video: VideoInput, duration: float) -> VideoInput: return video - def trim_video(video: VideoInput, duration_sec: float) -> VideoInput: """ Returns a new VideoInput object trimmed from the beginning to the specified duration, @@ -302,7 +309,9 @@ def trim_video(video: VideoInput, duration_sec: float) -> VideoInput: # Calculate target frame count that's divisible by 16 fps = input_container.streams.video[0].average_rate estimated_frames = int(duration_sec * fps) - target_frames = (estimated_frames // 16) * 16 # Round down to nearest multiple of 16 + target_frames = ( + estimated_frames // 16 + ) * 16 # Round down to nearest multiple of 16 if target_frames == 0: raise ValueError("Video too short: need at least 16 frames for Moonvalley") @@ -424,7 +433,7 @@ class BaseMoonvalleyVideoNode: MoonvalleyTextToVideoInferenceParams, "negative_prompt", multiline=True, - default="low-poly, flat shader, bad rigging, stiff animation, uncanny eyes, low-quality textures, looping glitch, cheap effect, overbloom, bloom spam, default lighting, game asset, stiff face, ugly specular, AI artifacts", + default=" gopro, bright, contrast, static, overexposed, vignette, artifacts, still, noise, texture, scanlines, videogame, 360 camera, VR, transition, flare, saturation, distorted, warped, wide angle, saturated, vibrant, glowing, cross dissolve, cheesy, ugly hands, mutated hands, mutant, disfigured, extra fingers, blown out, horrible, blurry, worst quality, bad, dissolve, melt, fade in, fade out, wobbly, weird, low quality, plastic, stock footage, video camera, boring", ), "resolution": ( IO.COMBO, @@ -441,12 +450,11 @@ class BaseMoonvalleyVideoNode: "tooltip": "Resolution of the output video", }, ), - # "length": (IO.COMBO,{"options":['5s','10s'], "default": '5s'}), "prompt_adherence": model_field_to_node_input( IO.FLOAT, MoonvalleyTextToVideoInferenceParams, "guidance_scale", - default=7.0, + default=10.0, step=1, min=1, max=20, @@ -455,13 +463,12 @@ class BaseMoonvalleyVideoNode: IO.INT, MoonvalleyTextToVideoInferenceParams, "seed", - default=random.randint(0, 2**32 - 1), + default=9, min=0, max=4294967295, step=1, display="number", tooltip="Random seed value", - control_after_generate=True, ), "steps": model_field_to_node_input( IO.INT, @@ -532,9 +539,11 @@ class MoonvalleyImg2VideoNode(BaseMoonvalleyVideoNode): # Get MIME type from tensor - assuming PNG format for image tensors mime_type = "image/png" - image_url = (await upload_images_to_comfyapi( - image, max_images=1, auth_kwargs=kwargs, mime_type=mime_type - ))[0] + image_url = ( + await upload_images_to_comfyapi( + image, max_images=1, auth_kwargs=kwargs, mime_type=mime_type + ) + )[0] request = MoonvalleyTextToVideoRequest( image_url=image_url, prompt_text=prompt, inference_params=inference_params @@ -570,17 +579,39 @@ class MoonvalleyVideo2VideoNode(BaseMoonvalleyVideoNode): return { "required": { "prompt": model_field_to_node_input( - IO.STRING, MoonvalleyVideoToVideoRequest, "prompt_text", - multiline=True + IO.STRING, + MoonvalleyVideoToVideoRequest, + "prompt_text", + multiline=True, ), "negative_prompt": model_field_to_node_input( IO.STRING, MoonvalleyVideoToVideoInferenceParams, "negative_prompt", multiline=True, - default="low-poly, flat shader, bad rigging, stiff animation, uncanny eyes, low-quality textures, looping glitch, cheap effect, overbloom, bloom spam, default lighting, game asset, stiff face, ugly 
specular, AI artifacts" + default=" gopro, bright, contrast, static, overexposed, vignette, artifacts, still, noise, texture, scanlines, videogame, 360 camera, VR, transition, flare, saturation, distorted, warped, wide angle, saturated, vibrant, glowing, cross dissolve, cheesy, ugly hands, mutated hands, mutant, disfigured, extra fingers, blown out, horrible, blurry, worst quality, bad, dissolve, melt, fade in, fade out, wobbly, weird, low quality, plastic, stock footage, video camera, boring", + ), + "seed": model_field_to_node_input( + IO.INT, + MoonvalleyVideoToVideoInferenceParams, + "seed", + default=9, + min=0, + max=4294967295, + step=1, + display="number", + tooltip="Random seed value", + control_after_generate=False, + ), + "prompt_adherence": model_field_to_node_input( + IO.FLOAT, + MoonvalleyVideoToVideoInferenceParams, + "guidance_scale", + default=10.0, + step=1, + min=1, + max=20, ), - "seed": model_field_to_node_input(IO.INT,MoonvalleyVideoToVideoInferenceParams, "seed", default=random.randint(0, 2**32 - 1), min=0, max=4294967295, step=1, display="number", tooltip="Random seed value", control_after_generate=True), }, "hidden": { "auth_token": "AUTH_TOKEN_COMFY_ORG", @@ -588,7 +619,14 @@ class MoonvalleyVideo2VideoNode(BaseMoonvalleyVideoNode): "unique_id": "UNIQUE_ID", }, "optional": { - "video": (IO.VIDEO, {"default": "", "multiline": False, "tooltip": "The reference video used to generate the output video. Must be at least 5 seconds long. Videos longer than 5s will be automatically trimmed. Only MP4 format supported."}), + "video": ( + IO.VIDEO, + { + "default": "", + "multiline": False, + "tooltip": "The reference video used to generate the output video. Must be at least 5 seconds long. Videos longer than 5s will be automatically trimmed. 
Only MP4 format supported.", + }, + ), "control_type": ( ["Motion Transfer", "Pose Transfer"], {"default": "Motion Transfer"}, @@ -602,8 +640,14 @@ class MoonvalleyVideo2VideoNode(BaseMoonvalleyVideoNode): "max": 100, "tooltip": "Only used if control_type is 'Motion Transfer'", }, - ) - } + ), + "image": model_field_to_node_input( + IO.IMAGE, + MoonvalleyTextToVideoRequest, + "image_url", + tooltip="The reference image used to generate the video", + ), + }, } RETURN_TYPES = ("VIDEO",) @@ -613,6 +657,7 @@ class MoonvalleyVideo2VideoNode(BaseMoonvalleyVideoNode): self, prompt, negative_prompt, unique_id: Optional[str] = None, **kwargs ): video = kwargs.get("video") + image = kwargs.get("image", None) if not video: raise MoonvalleyApiError("video is required") @@ -620,8 +665,16 @@ class MoonvalleyVideo2VideoNode(BaseMoonvalleyVideoNode): video_url = "" if video: validated_video = validate_video_to_video_input(video) - video_url = await upload_video_to_comfyapi(validated_video, auth_kwargs=kwargs) + video_url = await upload_video_to_comfyapi( + validated_video, auth_kwargs=kwargs + ) + mime_type = "image/png" + if not image is None: + validate_input_image(image, with_frame_conditioning=True) + image_url = await upload_images_to_comfyapi( + image=image, auth_kwargs=kwargs, max_images=1, mime_type=mime_type + ) control_type = kwargs.get("control_type") motion_intensity = kwargs.get("motion_intensity") @@ -631,12 +684,12 @@ class MoonvalleyVideo2VideoNode(BaseMoonvalleyVideoNode): # Only include motion_intensity for Motion Transfer control_params = {} if control_type == "Motion Transfer" and motion_intensity is not None: - control_params['motion_intensity'] = motion_intensity + control_params["motion_intensity"] = motion_intensity - inference_params=MoonvalleyVideoToVideoInferenceParams( + inference_params = MoonvalleyVideoToVideoInferenceParams( negative_prompt=negative_prompt, seed=kwargs.get("seed"), - control_params=control_params + control_params=control_params, ) control = self.parseControlParameter(control_type) @@ -647,6 +700,7 @@ class MoonvalleyVideo2VideoNode(BaseMoonvalleyVideoNode): prompt_text=prompt, inference_params=inference_params, ) + request.image_url = image_url if not image is None else None initial_operation = SynchronousOperation( endpoint=ApiEndpoint( @@ -694,15 +748,15 @@ class MoonvalleyTxt2VideoNode(BaseMoonvalleyVideoNode): validate_prompts(prompt, negative_prompt, MOONVALLEY_MAREY_MAX_PROMPT_LENGTH) width_height = self.parseWidthHeightFromRes(kwargs.get("resolution")) - inference_params=MoonvalleyTextToVideoInferenceParams( - negative_prompt=negative_prompt, - steps=kwargs.get("steps"), - seed=kwargs.get("seed"), - guidance_scale=kwargs.get("prompt_adherence"), - num_frames=128, - width=width_height.get("width"), - height=width_height.get("height"), - ) + inference_params = MoonvalleyTextToVideoInferenceParams( + negative_prompt=negative_prompt, + steps=kwargs.get("steps"), + seed=kwargs.get("seed"), + guidance_scale=kwargs.get("prompt_adherence"), + num_frames=128, + width=width_height.get("width"), + height=width_height.get("height"), + ) request = MoonvalleyTextToVideoRequest( prompt_text=prompt, inference_params=inference_params ) From 5d65d6753b195d674ce16522d6c34f9a33f36269 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Fri, 15 Aug 2025 04:48:41 +0300 Subject: [PATCH 059/325] convert WAN nodes to V3 schema (#9201) --- comfy_extras/nodes_wan.py | 549 +++++++++++++++++++++----------------- 1 file changed, 298 
insertions(+), 251 deletions(-) diff --git a/comfy_extras/nodes_wan.py b/comfy_extras/nodes_wan.py index f80c83ba6..694a183f6 100644 --- a/comfy_extras/nodes_wan.py +++ b/comfy_extras/nodes_wan.py @@ -9,29 +9,35 @@ import comfy.clip_vision import json import numpy as np from typing import Tuple +from typing_extensions import override +from comfy_api.latest import ComfyExtension, io -class WanImageToVideo: +class WanImageToVideo(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": {"positive": ("CONDITIONING", ), - "negative": ("CONDITIONING", ), - "vae": ("VAE", ), - "width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), - "height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), - "length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}), - "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), - }, - "optional": {"clip_vision_output": ("CLIP_VISION_OUTPUT", ), - "start_image": ("IMAGE", ), - }} + def define_schema(cls): + return io.Schema( + node_id="WanImageToVideo", + category="conditioning/video_models", + inputs=[ + io.Conditioning.Input("positive"), + io.Conditioning.Input("negative"), + io.Vae.Input("vae"), + io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4), + io.Int.Input("batch_size", default=1, min=1, max=4096), + io.ClipVisionOutput.Input("clip_vision_output", optional=True), + io.Image.Input("start_image", optional=True), + ], + outputs=[ + io.Conditioning.Output(display_name="positive"), + io.Conditioning.Output(display_name="negative"), + io.Latent.Output(display_name="latent"), + ], + ) - RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") - RETURN_NAMES = ("positive", "negative", "latent") - FUNCTION = "encode" - - CATEGORY = "conditioning/video_models" - - def encode(self, positive, negative, vae, width, height, length, batch_size, start_image=None, clip_vision_output=None): + @classmethod + def execute(cls, positive, negative, vae, width, height, length, batch_size, start_image=None, clip_vision_output=None) -> io.NodeOutput: latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) if start_image is not None: start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) @@ -51,32 +57,36 @@ class WanImageToVideo: out_latent = {} out_latent["samples"] = latent - return (positive, negative, out_latent) + return io.NodeOutput(positive, negative, out_latent) -class WanFunControlToVideo: +class WanFunControlToVideo(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": {"positive": ("CONDITIONING", ), - "negative": ("CONDITIONING", ), - "vae": ("VAE", ), - "width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), - "height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), - "length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}), - "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), - }, - "optional": {"clip_vision_output": ("CLIP_VISION_OUTPUT", ), - "start_image": ("IMAGE", ), - "control_video": ("IMAGE", ), - }} + def define_schema(cls): + return io.Schema( + node_id="WanFunControlToVideo", + 
category="conditioning/video_models", + inputs=[ + io.Conditioning.Input("positive"), + io.Conditioning.Input("negative"), + io.Vae.Input("vae"), + io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4), + io.Int.Input("batch_size", default=1, min=1, max=4096), + io.ClipVisionOutput.Input("clip_vision_output", optional=True), + io.Image.Input("start_image", optional=True), + io.Image.Input("control_video", optional=True), + ], + outputs=[ + io.Conditioning.Output(display_name="positive"), + io.Conditioning.Output(display_name="negative"), + io.Latent.Output(display_name="latent"), + ], + ) - RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") - RETURN_NAMES = ("positive", "negative", "latent") - FUNCTION = "encode" - - CATEGORY = "conditioning/video_models" - - def encode(self, positive, negative, vae, width, height, length, batch_size, start_image=None, clip_vision_output=None, control_video=None): + @classmethod + def execute(cls, positive, negative, vae, width, height, length, batch_size, start_image=None, clip_vision_output=None, control_video=None) -> io.NodeOutput: latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) concat_latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) concat_latent = comfy.latent_formats.Wan21().process_out(concat_latent) @@ -101,31 +111,34 @@ class WanFunControlToVideo: out_latent = {} out_latent["samples"] = latent - return (positive, negative, out_latent) + return io.NodeOutput(positive, negative, out_latent) -class Wan22FunControlToVideo: +class Wan22FunControlToVideo(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": {"positive": ("CONDITIONING", ), - "negative": ("CONDITIONING", ), - "vae": ("VAE", ), - "width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), - "height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), - "length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}), - "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), - }, - "optional": {"ref_image": ("IMAGE", ), - "control_video": ("IMAGE", ), - # "start_image": ("IMAGE", ), - }} + def define_schema(cls): + return io.Schema( + node_id="Wan22FunControlToVideo", + category="conditioning/video_models", + inputs=[ + io.Conditioning.Input("positive"), + io.Conditioning.Input("negative"), + io.Vae.Input("vae"), + io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4), + io.Int.Input("batch_size", default=1, min=1, max=4096), + io.Image.Input("ref_image", optional=True), + io.Image.Input("control_video", optional=True), + ], + outputs=[ + io.Conditioning.Output(display_name="positive"), + io.Conditioning.Output(display_name="negative"), + io.Latent.Output(display_name="latent"), + ], + ) - RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") - RETURN_NAMES = ("positive", "negative", "latent") - FUNCTION = "encode" - - CATEGORY = "conditioning/video_models" - - def encode(self, positive, negative, vae, width, height, length, 
batch_size, ref_image=None, start_image=None, control_video=None): + @classmethod + def execute(cls, positive, negative, vae, width, height, length, batch_size, ref_image=None, start_image=None, control_video=None) -> io.NodeOutput: latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) concat_latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) concat_latent = comfy.latent_formats.Wan21().process_out(concat_latent) @@ -158,32 +171,36 @@ class Wan22FunControlToVideo: out_latent = {} out_latent["samples"] = latent - return (positive, negative, out_latent) + return io.NodeOutput(positive, negative, out_latent) -class WanFirstLastFrameToVideo: +class WanFirstLastFrameToVideo(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": {"positive": ("CONDITIONING", ), - "negative": ("CONDITIONING", ), - "vae": ("VAE", ), - "width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), - "height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), - "length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}), - "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), - }, - "optional": {"clip_vision_start_image": ("CLIP_VISION_OUTPUT", ), - "clip_vision_end_image": ("CLIP_VISION_OUTPUT", ), - "start_image": ("IMAGE", ), - "end_image": ("IMAGE", ), - }} + def define_schema(cls): + return io.Schema( + node_id="WanFirstLastFrameToVideo", + category="conditioning/video_models", + inputs=[ + io.Conditioning.Input("positive"), + io.Conditioning.Input("negative"), + io.Vae.Input("vae"), + io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4), + io.Int.Input("batch_size", default=1, min=1, max=4096), + io.ClipVisionOutput.Input("clip_vision_start_image", optional=True), + io.ClipVisionOutput.Input("clip_vision_end_image", optional=True), + io.Image.Input("start_image", optional=True), + io.Image.Input("end_image", optional=True), + ], + outputs=[ + io.Conditioning.Output(display_name="positive"), + io.Conditioning.Output(display_name="negative"), + io.Latent.Output(display_name="latent"), + ], + ) - RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") - RETURN_NAMES = ("positive", "negative", "latent") - FUNCTION = "encode" - - CATEGORY = "conditioning/video_models" - - def encode(self, positive, negative, vae, width, height, length, batch_size, start_image=None, end_image=None, clip_vision_start_image=None, clip_vision_end_image=None): + @classmethod + def execute(cls, positive, negative, vae, width, height, length, batch_size, start_image=None, end_image=None, clip_vision_start_image=None, clip_vision_end_image=None) -> io.NodeOutput: latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) if start_image is not None: start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) @@ -224,62 +241,70 @@ class WanFirstLastFrameToVideo: out_latent = {} out_latent["samples"] = latent - return (positive, negative, out_latent) + return io.NodeOutput(positive, negative, out_latent) -class 
WanFunInpaintToVideo: +class WanFunInpaintToVideo(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": {"positive": ("CONDITIONING", ), - "negative": ("CONDITIONING", ), - "vae": ("VAE", ), - "width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), - "height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), - "length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}), - "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), - }, - "optional": {"clip_vision_output": ("CLIP_VISION_OUTPUT", ), - "start_image": ("IMAGE", ), - "end_image": ("IMAGE", ), - }} + def define_schema(cls): + return io.Schema( + node_id="WanFunInpaintToVideo", + category="conditioning/video_models", + inputs=[ + io.Conditioning.Input("positive"), + io.Conditioning.Input("negative"), + io.Vae.Input("vae"), + io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4), + io.Int.Input("batch_size", default=1, min=1, max=4096), + io.ClipVisionOutput.Input("clip_vision_output", optional=True), + io.Image.Input("start_image", optional=True), + io.Image.Input("end_image", optional=True), + ], + outputs=[ + io.Conditioning.Output(display_name="positive"), + io.Conditioning.Output(display_name="negative"), + io.Latent.Output(display_name="latent"), + ], + ) - RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") - RETURN_NAMES = ("positive", "negative", "latent") - FUNCTION = "encode" - - CATEGORY = "conditioning/video_models" - - def encode(self, positive, negative, vae, width, height, length, batch_size, start_image=None, end_image=None, clip_vision_output=None): + @classmethod + def execute(cls, positive, negative, vae, width, height, length, batch_size, start_image=None, end_image=None, clip_vision_output=None) -> io.NodeOutput: flfv = WanFirstLastFrameToVideo() - return flfv.encode(positive, negative, vae, width, height, length, batch_size, start_image=start_image, end_image=end_image, clip_vision_start_image=clip_vision_output) + return flfv.execute(positive, negative, vae, width, height, length, batch_size, start_image=start_image, end_image=end_image, clip_vision_start_image=clip_vision_output) -class WanVaceToVideo: +class WanVaceToVideo(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": {"positive": ("CONDITIONING", ), - "negative": ("CONDITIONING", ), - "vae": ("VAE", ), - "width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), - "height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), - "length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}), - "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), - "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1000.0, "step": 0.01}), - }, - "optional": {"control_video": ("IMAGE", ), - "control_masks": ("MASK", ), - "reference_image": ("IMAGE", ), - }} + def define_schema(cls): + return io.Schema( + node_id="WanVaceToVideo", + category="conditioning/video_models", + is_experimental=True, + inputs=[ + io.Conditioning.Input("positive"), + io.Conditioning.Input("negative"), + io.Vae.Input("vae"), + io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16), + 
io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4), + io.Int.Input("batch_size", default=1, min=1, max=4096), + io.Float.Input("strength", default=1.0, min=0.0, max=1000.0, step=0.01), + io.Image.Input("control_video", optional=True), + io.Mask.Input("control_masks", optional=True), + io.Image.Input("reference_image", optional=True), + ], + outputs=[ + io.Conditioning.Output(display_name="positive"), + io.Conditioning.Output(display_name="negative"), + io.Latent.Output(display_name="latent"), + io.Int.Output(display_name="trim_latent"), + ], + ) - RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT", "INT") - RETURN_NAMES = ("positive", "negative", "latent", "trim_latent") - FUNCTION = "encode" - - CATEGORY = "conditioning/video_models" - - EXPERIMENTAL = True - - def encode(self, positive, negative, vae, width, height, length, batch_size, strength, control_video=None, control_masks=None, reference_image=None): + @classmethod + def execute(cls, positive, negative, vae, width, height, length, batch_size, strength, control_video=None, control_masks=None, reference_image=None) -> io.NodeOutput: latent_length = ((length - 1) // 4) + 1 if control_video is not None: control_video = comfy.utils.common_upscale(control_video[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) @@ -336,52 +361,59 @@ class WanVaceToVideo: latent = torch.zeros([batch_size, 16, latent_length, height // 8, width // 8], device=comfy.model_management.intermediate_device()) out_latent = {} out_latent["samples"] = latent - return (positive, negative, out_latent, trim_latent) + return io.NodeOutput(positive, negative, out_latent, trim_latent) -class TrimVideoLatent: +class TrimVideoLatent(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "samples": ("LATENT",), - "trim_amount": ("INT", {"default": 0, "min": 0, "max": 99999}), - }} + def define_schema(cls): + return io.Schema( + node_id="TrimVideoLatent", + category="latent/video", + is_experimental=True, + inputs=[ + io.Latent.Input("samples"), + io.Int.Input("trim_amount", default=0, min=0, max=99999), + ], + outputs=[ + io.Latent.Output(), + ], + ) - RETURN_TYPES = ("LATENT",) - FUNCTION = "op" - - CATEGORY = "latent/video" - - EXPERIMENTAL = True - - def op(self, samples, trim_amount): + @classmethod + def execute(cls, samples, trim_amount) -> io.NodeOutput: samples_out = samples.copy() s1 = samples["samples"] samples_out["samples"] = s1[:, :, trim_amount:] - return (samples_out,) + return io.NodeOutput(samples_out) -class WanCameraImageToVideo: +class WanCameraImageToVideo(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": {"positive": ("CONDITIONING", ), - "negative": ("CONDITIONING", ), - "vae": ("VAE", ), - "width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), - "height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), - "length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}), - "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), - }, - "optional": {"clip_vision_output": ("CLIP_VISION_OUTPUT", ), - "start_image": ("IMAGE", ), - "camera_conditions": ("WAN_CAMERA_EMBEDDING", ), - }} + def define_schema(cls): + return io.Schema( + node_id="WanCameraImageToVideo", + category="conditioning/video_models", + inputs=[ + io.Conditioning.Input("positive"), + io.Conditioning.Input("negative"), + io.Vae.Input("vae"), + io.Int.Input("width", default=832, min=16, 
max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4), + io.Int.Input("batch_size", default=1, min=1, max=4096), + io.ClipVisionOutput.Input("clip_vision_output", optional=True), + io.Image.Input("start_image", optional=True), + io.WanCameraEmbedding.Input("camera_conditions", optional=True), + ], + outputs=[ + io.Conditioning.Output(display_name="positive"), + io.Conditioning.Output(display_name="negative"), + io.Latent.Output(display_name="latent"), + ], + ) - RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") - RETURN_NAMES = ("positive", "negative", "latent") - FUNCTION = "encode" - - CATEGORY = "conditioning/video_models" - - def encode(self, positive, negative, vae, width, height, length, batch_size, start_image=None, clip_vision_output=None, camera_conditions=None): + @classmethod + def execute(cls, positive, negative, vae, width, height, length, batch_size, start_image=None, clip_vision_output=None, camera_conditions=None) -> io.NodeOutput: latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) concat_latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) concat_latent = comfy.latent_formats.Wan21().process_out(concat_latent) @@ -404,29 +436,34 @@ class WanCameraImageToVideo: out_latent = {} out_latent["samples"] = latent - return (positive, negative, out_latent) + return io.NodeOutput(positive, negative, out_latent) -class WanPhantomSubjectToVideo: +class WanPhantomSubjectToVideo(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": {"positive": ("CONDITIONING", ), - "negative": ("CONDITIONING", ), - "vae": ("VAE", ), - "width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), - "height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), - "length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}), - "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), - }, - "optional": {"images": ("IMAGE", ), - }} + def define_schema(cls): + return io.Schema( + node_id="WanPhantomSubjectToVideo", + category="conditioning/video_models", + inputs=[ + io.Conditioning.Input("positive"), + io.Conditioning.Input("negative"), + io.Vae.Input("vae"), + io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4), + io.Int.Input("batch_size", default=1, min=1, max=4096), + io.Image.Input("images", optional=True), + ], + outputs=[ + io.Conditioning.Output(display_name="positive"), + io.Conditioning.Output(display_name="negative_text"), + io.Conditioning.Output(display_name="negative_img_text"), + io.Latent.Output(display_name="latent"), + ], + ) - RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "CONDITIONING", "LATENT") - RETURN_NAMES = ("positive", "negative_text", "negative_img_text", "latent") - FUNCTION = "encode" - - CATEGORY = "conditioning/video_models" - - def encode(self, positive, negative, vae, width, height, length, batch_size, images): + @classmethod + def execute(cls, positive, negative, vae, width, height, length, batch_size, images) -> io.NodeOutput: latent = torch.zeros([batch_size, 16, 
((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) cond2 = negative if images is not None: @@ -442,7 +479,7 @@ class WanPhantomSubjectToVideo: out_latent = {} out_latent["samples"] = latent - return (positive, cond2, negative, out_latent) + return io.NodeOutput(positive, cond2, negative, out_latent) def parse_json_tracks(tracks): """Parse JSON track data into a standardized format""" @@ -655,39 +692,41 @@ def patch_motion( return out_mask_full, out_feature_full -class WanTrackToVideo: +class WanTrackToVideo(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { - "positive": ("CONDITIONING", ), - "negative": ("CONDITIONING", ), - "vae": ("VAE", ), - "tracks": ("STRING", {"multiline": True, "default": "[]"}), - "width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), - "height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), - "length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}), - "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), - "temperature": ("FLOAT", {"default": 220.0, "min": 1.0, "max": 1000.0, "step": 0.1}), - "topk": ("INT", {"default": 2, "min": 1, "max": 10}), - "start_image": ("IMAGE", ), - }, - "optional": { - "clip_vision_output": ("CLIP_VISION_OUTPUT", ), - }} + def define_schema(cls): + return io.Schema( + node_id="WanPhantomSubjectToVideo", + category="conditioning/video_models", + inputs=[ + io.Conditioning.Input("positive"), + io.Conditioning.Input("negative"), + io.Vae.Input("vae"), + io.String.Input("tracks", multiline=True, default="[]"), + io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4), + io.Int.Input("batch_size", default=1, min=1, max=4096), + io.Float.Input("temperature", default=220.0, min=1.0, max=1000.0, step=0.1), + io.Int.Input("topk", default=2, min=1, max=10), + io.Image.Input("start_image"), + io.ClipVisionOutput.Input("clip_vision_output", optional=True), + ], + outputs=[ + io.Conditioning.Output(display_name="positive"), + io.Conditioning.Output(display_name="negative"), + io.Latent.Output(display_name="latent"), + ], + ) - RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") - RETURN_NAMES = ("positive", "negative", "latent") - FUNCTION = "encode" - - CATEGORY = "conditioning/video_models" - - def encode(self, positive, negative, vae, tracks, width, height, length, batch_size, - temperature, topk, start_image=None, clip_vision_output=None): + @classmethod + def execute(cls, positive, negative, vae, tracks, width, height, length, batch_size, + temperature, topk, start_image=None, clip_vision_output=None) -> io.NodeOutput: tracks_data = parse_json_tracks(tracks) if not tracks_data: - return WanImageToVideo().encode(positive, negative, vae, width, height, length, batch_size, start_image=start_image, clip_vision_output=clip_vision_output) + return WanImageToVideo().execute(positive, negative, vae, width, height, length, batch_size, start_image=start_image, clip_vision_output=clip_vision_output) latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) @@ -741,34 +780,36 @@ class WanTrackToVideo: out_latent = {} out_latent["samples"] = latent - return (positive, negative, out_latent) + return 
io.NodeOutput(positive, negative, out_latent) -class Wan22ImageToVideoLatent: +class Wan22ImageToVideoLatent(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": {"vae": ("VAE", ), - "width": ("INT", {"default": 1280, "min": 32, "max": nodes.MAX_RESOLUTION, "step": 32}), - "height": ("INT", {"default": 704, "min": 32, "max": nodes.MAX_RESOLUTION, "step": 32}), - "length": ("INT", {"default": 49, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}), - "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), - }, - "optional": {"start_image": ("IMAGE", ), - }} + def define_schema(cls): + return io.Schema( + node_id="Wan22ImageToVideoLatent", + category="conditioning/inpaint", + inputs=[ + io.Vae.Input("vae"), + io.Int.Input("width", default=1280, min=32, max=nodes.MAX_RESOLUTION, step=32), + io.Int.Input("height", default=704, min=32, max=nodes.MAX_RESOLUTION, step=32), + io.Int.Input("length", default=49, min=1, max=nodes.MAX_RESOLUTION, step=4), + io.Int.Input("batch_size", default=1, min=1, max=4096), + io.Image.Input("start_image", optional=True), + ], + outputs=[ + io.Latent.Output(), + ], + ) - - RETURN_TYPES = ("LATENT",) - FUNCTION = "encode" - - CATEGORY = "conditioning/inpaint" - - def encode(self, vae, width, height, length, batch_size, start_image=None): + @classmethod + def execute(cls, vae, width, height, length, batch_size, start_image=None) -> io.NodeOutput: latent = torch.zeros([1, 48, ((length - 1) // 4) + 1, height // 16, width // 16], device=comfy.model_management.intermediate_device()) if start_image is None: out_latent = {} out_latent["samples"] = latent - return (out_latent,) + return io.NodeOutput(out_latent) mask = torch.ones([latent.shape[0], 1, ((length - 1) // 4) + 1, latent.shape[-2], latent.shape[-1]], device=comfy.model_management.intermediate_device()) @@ -783,19 +824,25 @@ class Wan22ImageToVideoLatent: latent = latent_format.process_out(latent) * mask + latent * (1.0 - mask) out_latent["samples"] = latent.repeat((batch_size, ) + (1,) * (latent.ndim - 1)) out_latent["noise_mask"] = mask.repeat((batch_size, ) + (1,) * (mask.ndim - 1)) - return (out_latent,) + return io.NodeOutput(out_latent) -NODE_CLASS_MAPPINGS = { - "WanTrackToVideo": WanTrackToVideo, - "WanImageToVideo": WanImageToVideo, - "WanFunControlToVideo": WanFunControlToVideo, - "Wan22FunControlToVideo": Wan22FunControlToVideo, - "WanFunInpaintToVideo": WanFunInpaintToVideo, - "WanFirstLastFrameToVideo": WanFirstLastFrameToVideo, - "WanVaceToVideo": WanVaceToVideo, - "TrimVideoLatent": TrimVideoLatent, - "WanCameraImageToVideo": WanCameraImageToVideo, - "WanPhantomSubjectToVideo": WanPhantomSubjectToVideo, - "Wan22ImageToVideoLatent": Wan22ImageToVideoLatent, -} +class WanExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + WanTrackToVideo, + WanImageToVideo, + WanFunControlToVideo, + Wan22FunControlToVideo, + WanFunInpaintToVideo, + WanFirstLastFrameToVideo, + WanVaceToVideo, + TrimVideoLatent, + WanCameraImageToVideo, + WanPhantomSubjectToVideo, + Wan22ImageToVideoLatent, + ] + +async def comfy_entrypoint() -> WanExtension: + return WanExtension() From ad19a069f68a19566632b9bda3e72f4eed8a22d8 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Thu, 14 Aug 2025 20:16:01 -0700 Subject: [PATCH 060/325] Make SLG nodes work on Qwen Image model. 
(#9345) --- comfy/ldm/qwen_image/model.py | 29 +++++++++++++++++++++-------- 1 file changed, 21 insertions(+), 8 deletions(-) diff --git a/comfy/ldm/qwen_image/model.py b/comfy/ldm/qwen_image/model.py index c15ab8e40..99843f88d 100644 --- a/comfy/ldm/qwen_image/model.py +++ b/comfy/ldm/qwen_image/model.py @@ -356,6 +356,7 @@ class QwenImageTransformer2DModel(nn.Module): context, attention_mask=None, guidance: torch.Tensor = None, + transformer_options={}, **kwargs ): timestep = timesteps @@ -383,14 +384,26 @@ class QwenImageTransformer2DModel(nn.Module): else self.time_text_embed(timestep, guidance, hidden_states) ) - for block in self.transformer_blocks: - encoder_hidden_states, hidden_states = block( - hidden_states=hidden_states, - encoder_hidden_states=encoder_hidden_states, - encoder_hidden_states_mask=encoder_hidden_states_mask, - temb=temb, - image_rotary_emb=image_rotary_emb, - ) + patches_replace = transformer_options.get("patches_replace", {}) + blocks_replace = patches_replace.get("dit", {}) + + for i, block in enumerate(self.transformer_blocks): + if ("double_block", i) in blocks_replace: + def block_wrap(args): + out = {} + out["txt"], out["img"] = block(hidden_states=args["img"], encoder_hidden_states=args["txt"], encoder_hidden_states_mask=encoder_hidden_states_mask, temb=args["vec"], image_rotary_emb=args["pe"]) + return out + out = blocks_replace[("double_block", i)]({"img": hidden_states, "txt": encoder_hidden_states, "vec": temb, "pe": image_rotary_emb}, {"original_block": block_wrap}) + hidden_states = out["img"] + encoder_hidden_states = out["txt"] + else: + encoder_hidden_states, hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + encoder_hidden_states_mask=encoder_hidden_states_mask, + temb=temb, + image_rotary_emb=image_rotary_emb, + ) hidden_states = self.norm_out(hidden_states, temb) hidden_states = self.proj_out(hidden_states) From f0d5d0111f1f78bc8ce5d1f3968f19e40cd2ce7b Mon Sep 17 00:00:00 2001 From: "Xiangxi Guo (Ryan)" Date: Thu, 14 Aug 2025 20:41:37 -0700 Subject: [PATCH 061/325] Avoid torch compile graphbreak for older pytorch versions (#9344) Turns out torch.compile has some gaps in context manager decorator syntax support. I've sent patches to fix that in PyTorch, but it won't be available for all the folks running older versions of PyTorch, hence this trivial patch. --- comfy/ops.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/comfy/ops.py b/comfy/ops.py index be312d714..2be35f76a 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -41,9 +41,11 @@ try: SDPA_BACKEND_PRIORITY.insert(0, SDPBackend.CUDNN_ATTENTION) - @sdpa_kernel(backends=SDPA_BACKEND_PRIORITY, set_priority=True) def scaled_dot_product_attention(q, k, v, *args, **kwargs): - return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs) + # Use this (rather than the decorator syntax) to eliminate graph + # break for pytorch < 2.9 + with sdpa_kernel(SDPA_BACKEND_PRIORITY, set_priority=True): + return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs) except (ModuleNotFoundError, TypeError): logging.warning("Could not set sdpa backend priority.") From 4e5c230f6a957962961794c07f02be748076c771 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Thu, 14 Aug 2025 20:44:02 -0700 Subject: [PATCH 062/325] Fix last commit not working on older pytorch. 
(#9346) --- comfy/ops.py | 25 +++++++++++++------------ 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/comfy/ops.py b/comfy/ops.py index 2be35f76a..18e7db705 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -32,20 +32,21 @@ def scaled_dot_product_attention(q, k, v, *args, **kwargs): try: if torch.cuda.is_available(): from torch.nn.attention import SDPBackend, sdpa_kernel + import inspect + if "set_priority" in inspect.signature(sdpa_kernel).parameters: + SDPA_BACKEND_PRIORITY = [ + SDPBackend.FLASH_ATTENTION, + SDPBackend.EFFICIENT_ATTENTION, + SDPBackend.MATH, + ] - SDPA_BACKEND_PRIORITY = [ - SDPBackend.FLASH_ATTENTION, - SDPBackend.EFFICIENT_ATTENTION, - SDPBackend.MATH, - ] + SDPA_BACKEND_PRIORITY.insert(0, SDPBackend.CUDNN_ATTENTION) - SDPA_BACKEND_PRIORITY.insert(0, SDPBackend.CUDNN_ATTENTION) - - def scaled_dot_product_attention(q, k, v, *args, **kwargs): - # Use this (rather than the decorator syntax) to eliminate graph - # break for pytorch < 2.9 - with sdpa_kernel(SDPA_BACKEND_PRIORITY, set_priority=True): - return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs) + def scaled_dot_product_attention(q, k, v, *args, **kwargs): + with sdpa_kernel(SDPA_BACKEND_PRIORITY, set_priority=True): + return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs) + else: + logging.warning("Torch version too old to set sdpa backend priority.") except (ModuleNotFoundError, TypeError): logging.warning("Could not set sdpa backend priority.") From e08ecfbd8a9deda8939b14d7f1ff7d7139f1a4ed Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Thu, 14 Aug 2025 21:22:26 -0700 Subject: [PATCH 063/325] Add warning when using old pytorch. (#9347) --- comfy/rmsnorm.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/comfy/rmsnorm.py b/comfy/rmsnorm.py index 66ae8321d..555542a46 100644 --- a/comfy/rmsnorm.py +++ b/comfy/rmsnorm.py @@ -1,6 +1,7 @@ import torch import comfy.model_management import numbers +import logging RMSNorm = None @@ -9,6 +10,7 @@ try: RMSNorm = torch.nn.RMSNorm except: rms_norm_torch = None + logging.warning("Please update pytorch to use native RMSNorm") def rms_norm(x, weight=None, eps=1e-6): From 027c63f63a7f5f380a4df1057c548410b0a87606 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Fri, 15 Aug 2025 21:57:47 +0300 Subject: [PATCH 064/325] fix(OpenAIGPTImage1): set correct MIME type for multipart uploads to OpenAI edits (#9348) --- comfy_api_nodes/nodes_openai.py | 13 +++---------- 1 file changed, 3 insertions(+), 10 deletions(-) diff --git a/comfy_api_nodes/nodes_openai.py b/comfy_api_nodes/nodes_openai.py index ab3c5363b..cbff2b2d2 100644 --- a/comfy_api_nodes/nodes_openai.py +++ b/comfy_api_nodes/nodes_openai.py @@ -464,8 +464,6 @@ class OpenAIGPTImage1(ComfyNodeABC): path = "/proxy/openai/images/generations" content_type = "application/json" request_class = OpenAIImageGenerationRequest - img_binaries = [] - mask_binary = None files = [] if image is not None: @@ -484,14 +482,11 @@ class OpenAIGPTImage1(ComfyNodeABC): img_byte_arr = io.BytesIO() img.save(img_byte_arr, format="PNG") img_byte_arr.seek(0) - img_binary = img_byte_arr - img_binary.name = f"image_{i}.png" - img_binaries.append(img_binary) if batch_size == 1: - files.append(("image", img_binary)) + files.append(("image", (f"image_{i}.png", img_byte_arr, "image/png"))) else: - files.append(("image[]", img_binary)) + files.append(("image[]", (f"image_{i}.png", 
img_byte_arr, "image/png"))) if mask is not None: if image is None: @@ -511,9 +506,7 @@ class OpenAIGPTImage1(ComfyNodeABC): mask_img_byte_arr = io.BytesIO() mask_img.save(mask_img_byte_arr, format="PNG") mask_img_byte_arr.seek(0) - mask_binary = mask_img_byte_arr - mask_binary.name = "mask.png" - files.append(("mask", mask_binary)) + files.append(("mask", ("mask.png", mask_img_byte_arr, "image/png"))) # Build the operation operation = SynchronousOperation( From c308a8840aebf06649364e8e175862250a2d8823 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Fri, 15 Aug 2025 12:50:39 -0700 Subject: [PATCH 065/325] Add FluxKontextMultiReferenceLatentMethod node. (#9356) This node is only useful if someone trains the kontext model to properly use multiple reference images via the index method. The default is the offset method which feeds the multiple images like if they were stitched together as one. This method works with the current flux kontext model. --- comfy/ldm/flux/model.py | 24 ++++++++++++++++-------- comfy/model_base.py | 4 ++++ comfy_extras/nodes_flux.py | 19 +++++++++++++++++++ 3 files changed, 39 insertions(+), 8 deletions(-) diff --git a/comfy/ldm/flux/model.py b/comfy/ldm/flux/model.py index 8f4d99f54..c4de82795 100644 --- a/comfy/ldm/flux/model.py +++ b/comfy/ldm/flux/model.py @@ -224,19 +224,27 @@ class Flux(nn.Module): if ref_latents is not None: h = 0 w = 0 + index = 0 + index_ref_method = kwargs.get("ref_latents_method", "offset") == "index" for ref in ref_latents: - h_offset = 0 - w_offset = 0 - if ref.shape[-2] + h > ref.shape[-1] + w: - w_offset = w + if index_ref_method: + index += 1 + h_offset = 0 + w_offset = 0 else: - h_offset = h + index = 1 + h_offset = 0 + w_offset = 0 + if ref.shape[-2] + h > ref.shape[-1] + w: + w_offset = w + else: + h_offset = h + h = max(h, ref.shape[-2] + h_offset) + w = max(w, ref.shape[-1] + w_offset) - kontext, kontext_ids = self.process_img(ref, index=1, h_offset=h_offset, w_offset=w_offset) + kontext, kontext_ids = self.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset) img = torch.cat([img, kontext], dim=1) img_ids = torch.cat([img_ids, kontext_ids], dim=1) - h = max(h, ref.shape[-2] + h_offset) - w = max(w, ref.shape[-1] + w_offset) txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype) out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None)) diff --git a/comfy/model_base.py b/comfy/model_base.py index cde61df7c..bf874b875 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -890,6 +890,10 @@ class Flux(BaseModel): for lat in ref_latents: latents.append(self.process_latent_in(lat)) out['ref_latents'] = comfy.conds.CONDList(latents) + + ref_latents_method = kwargs.get("reference_latents_method", None) + if ref_latents_method is not None: + out['ref_latents_method'] = comfy.conds.CONDConstant(ref_latents_method) return out def extra_conds_shapes(self, **kwargs): diff --git a/comfy_extras/nodes_flux.py b/comfy_extras/nodes_flux.py index 8a8a17698..c8db75bb3 100644 --- a/comfy_extras/nodes_flux.py +++ b/comfy_extras/nodes_flux.py @@ -100,9 +100,28 @@ class FluxKontextImageScale: return (image, ) +class FluxKontextMultiReferenceLatentMethod: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "conditioning": ("CONDITIONING", ), + "reference_latents_method": (("offset", "index"), ), + }} + + RETURN_TYPES = ("CONDITIONING",) + 
FUNCTION = "append" + EXPERIMENTAL = True + + CATEGORY = "advanced/conditioning/flux" + + def append(self, conditioning, reference_latents_method): + c = node_helpers.conditioning_set_values(conditioning, {"reference_latents_method": reference_latents_method}) + return (c, ) + NODE_CLASS_MAPPINGS = { "CLIPTextEncodeFlux": CLIPTextEncodeFlux, "FluxGuidance": FluxGuidance, "FluxDisableGuidance": FluxDisableGuidance, "FluxKontextImageScale": FluxKontextImageScale, + "FluxKontextMultiReferenceLatentMethod": FluxKontextMultiReferenceLatentMethod, } From 1702e6df16b0a52e147f19e3d5c5548c25a64339 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Fri, 15 Aug 2025 14:29:58 -0700 Subject: [PATCH 066/325] Implement wan2.2 camera model. (#9357) Use the old WanCameraImageToVideo node. --- comfy/ldm/wan/model.py | 7 ++++++- comfy/model_detection.py | 5 ++++- comfy/supported_models.py | 14 +++++++++++++- comfy_extras/nodes_wan.py | 7 +++++-- 4 files changed, 28 insertions(+), 5 deletions(-) diff --git a/comfy/ldm/wan/model.py b/comfy/ldm/wan/model.py index 4e2d99566..9d3741be3 100644 --- a/comfy/ldm/wan/model.py +++ b/comfy/ldm/wan/model.py @@ -768,7 +768,12 @@ class CameraWanModel(WanModel): operations=None, ): - super().__init__(model_type='i2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, image_model=image_model, device=device, dtype=dtype, operations=operations) + if model_type == 'camera': + model_type = 'i2v' + else: + model_type = 't2v' + + super().__init__(model_type=model_type, patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, image_model=image_model, device=device, dtype=dtype, operations=operations) operation_settings = {"operations": operations, "device": device, "dtype": dtype} self.control_adapter = WanCamAdapter(in_dim_control_adapter, dim, kernel_size=patch_size[1:], stride=patch_size[1:], operation_settings=operation_settings) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 8acc51e20..2bec0541e 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -364,7 +364,10 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["vace_in_dim"] = state_dict['{}vace_patch_embedding.weight'.format(key_prefix)].shape[1] dit_config["vace_layers"] = count_blocks(state_dict_keys, '{}vace_blocks.'.format(key_prefix) + '{}.') elif '{}control_adapter.conv.weight'.format(key_prefix) in state_dict_keys: - dit_config["model_type"] = "camera" + if '{}img_emb.proj.0.bias'.format(key_prefix) in state_dict_keys: + dit_config["model_type"] = "camera" + else: + dit_config["model_type"] = "camera_2.2" else: if '{}img_emb.proj.0.bias'.format(key_prefix) in state_dict_keys: dit_config["model_type"] = "i2v" diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 156ff9e26..7ed6dfd69 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -1046,6 +1046,18 @@ class WAN21_Camera(WAN21_T2V): def get_model(self, state_dict, prefix="", 
device=None): out = model_base.WAN21_Camera(self, image_to_video=False, device=device) return out + +class WAN22_Camera(WAN21_T2V): + unet_config = { + "image_model": "wan2.1", + "model_type": "camera_2.2", + "in_dim": 36, + } + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.WAN21_Camera(self, image_to_video=False, device=device) + return out + class WAN21_Vace(WAN21_T2V): unet_config = { "image_model": "wan2.1", @@ -1260,6 +1272,6 @@ class QwenImage(supported_models_base.BASE): return supported_models_base.ClipTarget(comfy.text_encoders.qwen_image.QwenImageTokenizer, comfy.text_encoders.qwen_image.te(**hunyuan_detect)) -models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma, ACEStep, Omnigen2, QwenImage] +models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma, ACEStep, Omnigen2, QwenImage] models += [SVD_img2vid] diff --git a/comfy_extras/nodes_wan.py b/comfy_extras/nodes_wan.py index 694a183f6..83a990688 100644 --- a/comfy_extras/nodes_wan.py +++ b/comfy_extras/nodes_wan.py @@ -422,9 +422,12 @@ class WanCameraImageToVideo(io.ComfyNode): start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) concat_latent_image = vae.encode(start_image[:, :, :, :3]) concat_latent[:,:,:concat_latent_image.shape[2]] = concat_latent_image[:,:,:concat_latent.shape[2]] + mask = torch.ones((1, 1, latent.shape[2] * 4, latent.shape[-2], latent.shape[-1])) + mask[:, :, :start_image.shape[0] + 3] = 0.0 + mask = mask.view(1, mask.shape[2] // 4, 4, mask.shape[3], mask.shape[4]).transpose(1, 2) - positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent}) - negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent}) + positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent, "concat_mask": mask}) + negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent, "concat_mask": mask}) if camera_conditions is not None: positive = node_helpers.conditioning_set_values(positive, {'camera_conditions': camera_conditions}) From ed2e33c69a291094c4fcc13d8426c49844a6363c Mon Sep 17 00:00:00 2001 From: Christian Byrne Date: Fri, 15 Aug 2025 20:32:58 -0700 Subject: [PATCH 067/325] bump frontend version to 1.25.8 (#9361) --- requirements.txt | 2 +- 1 
file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 551002b5b..2ae44ebe1 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -comfyui-frontend-package==1.24.4 +comfyui-frontend-package==1.25.8 comfyui-workflow-templates==0.1.59 comfyui-embedded-docs==0.2.6 torch From 20a84166d0d37dd6833caa6cadf3bfac8c241b48 Mon Sep 17 00:00:00 2001 From: Terry Jia Date: Sat, 16 Aug 2025 02:07:12 -0400 Subject: [PATCH 068/325] record audio node (#8716) * record audio node * sf --- comfy_extras/nodes_audio.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/comfy_extras/nodes_audio.py b/comfy_extras/nodes_audio.py index a90b31779..3b23f65d8 100644 --- a/comfy_extras/nodes_audio.py +++ b/comfy_extras/nodes_audio.py @@ -346,6 +346,24 @@ class LoadAudio: return "Invalid audio file: {}".format(audio) return True +class RecordAudio: + @classmethod + def INPUT_TYPES(s): + return {"required": {"audio": ("AUDIO_RECORD", {})}} + + CATEGORY = "audio" + + RETURN_TYPES = ("AUDIO", ) + FUNCTION = "load" + + def load(self, audio): + audio_path = folder_paths.get_annotated_filepath(audio) + + waveform, sample_rate = torchaudio.load(audio_path) + audio = {"waveform": waveform.unsqueeze(0), "sample_rate": sample_rate} + return (audio, ) + + NODE_CLASS_MAPPINGS = { "EmptyLatentAudio": EmptyLatentAudio, "VAEEncodeAudio": VAEEncodeAudio, @@ -356,6 +374,7 @@ NODE_CLASS_MAPPINGS = { "LoadAudio": LoadAudio, "PreviewAudio": PreviewAudio, "ConditioningStableAudio": ConditioningStableAudio, + "RecordAudio": RecordAudio, } NODE_DISPLAY_NAME_MAPPINGS = { @@ -367,4 +386,5 @@ NODE_DISPLAY_NAME_MAPPINGS = { "SaveAudio": "Save Audio (FLAC)", "SaveAudioMP3": "Save Audio (MP3)", "SaveAudioOpus": "Save Audio (Opus)", + "RecordAudio": "Record Audio", } From 0f2b8525bcafe213e8421a49564a90f926e81f2e Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sat, 16 Aug 2025 14:51:28 -0700 Subject: [PATCH 069/325] Qwen image model refactor. 
(#9375) --- comfy/ldm/qwen_image/model.py | 36 +++++++++++++++++++---------------- 1 file changed, 20 insertions(+), 16 deletions(-) diff --git a/comfy/ldm/qwen_image/model.py b/comfy/ldm/qwen_image/model.py index 99843f88d..40d8fd979 100644 --- a/comfy/ldm/qwen_image/model.py +++ b/comfy/ldm/qwen_image/model.py @@ -333,21 +333,25 @@ class QwenImageTransformer2DModel(nn.Module): self.proj_out = operations.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True, dtype=dtype, device=device) self.gradient_checkpointing = False - def pos_embeds(self, x, context): + def process_img(self, x, index=0, h_offset=0, w_offset=0): bs, c, t, h, w = x.shape patch_size = self.patch_size + hidden_states = comfy.ldm.common_dit.pad_to_patch_size(x, (1, self.patch_size, self.patch_size)) + orig_shape = hidden_states.shape + hidden_states = hidden_states.view(orig_shape[0], orig_shape[1], orig_shape[-2] // 2, 2, orig_shape[-1] // 2, 2) + hidden_states = hidden_states.permute(0, 2, 4, 1, 3, 5) + hidden_states = hidden_states.reshape(orig_shape[0], (orig_shape[-2] // 2) * (orig_shape[-1] // 2), orig_shape[1] * 4) h_len = ((h + (patch_size // 2)) // patch_size) w_len = ((w + (patch_size // 2)) // patch_size) - img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype) - img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1) - img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0) - img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs) + h_offset = ((h_offset + (patch_size // 2)) // patch_size) + w_offset = ((w_offset + (patch_size // 2)) // patch_size) - txt_start = round(max(h_len, w_len)) - txt_ids = torch.linspace(txt_start, txt_start + context.shape[1], steps=context.shape[1], device=x.device, dtype=x.dtype).reshape(1, -1, 1).repeat(bs, 1, 3) - ids = torch.cat((txt_ids, img_ids), dim=1) - return self.pe_embedder(ids).squeeze(1).unsqueeze(2).to(x.dtype) + img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype) + img_ids[:, :, 0] = img_ids[:, :, 1] + index + img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1) + img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0) + return hidden_states, repeat(img_ids, "h w c -> b (h w) c", b=bs), orig_shape def forward( self, @@ -363,13 +367,13 @@ class QwenImageTransformer2DModel(nn.Module): encoder_hidden_states = context encoder_hidden_states_mask = attention_mask - image_rotary_emb = self.pos_embeds(x, context) + hidden_states, img_ids, orig_shape = self.process_img(x) + num_embeds = hidden_states.shape[1] - hidden_states = comfy.ldm.common_dit.pad_to_patch_size(x, (1, self.patch_size, self.patch_size)) - orig_shape = hidden_states.shape - hidden_states = hidden_states.view(orig_shape[0], orig_shape[1], orig_shape[-2] // 2, 2, orig_shape[-1] // 2, 2) - hidden_states = hidden_states.permute(0, 2, 4, 1, 3, 5) - hidden_states = hidden_states.reshape(orig_shape[0], (orig_shape[-2] // 2) * (orig_shape[-1] // 2), orig_shape[1] * 4) + txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size), ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size))) + txt_ids = torch.linspace(txt_start, txt_start + context.shape[1], steps=context.shape[1], device=x.device, dtype=x.dtype).reshape(1, 
-1, 1).repeat(x.shape[0], 1, 3) + ids = torch.cat((txt_ids, img_ids), dim=1) + image_rotary_emb = self.pe_embedder(ids).squeeze(1).unsqueeze(2).to(x.dtype) hidden_states = self.img_in(hidden_states) encoder_hidden_states = self.txt_norm(encoder_hidden_states) @@ -408,6 +412,6 @@ class QwenImageTransformer2DModel(nn.Module): hidden_states = self.norm_out(hidden_states, temb) hidden_states = self.proj_out(hidden_states) - hidden_states = hidden_states.view(orig_shape[0], orig_shape[-2] // 2, orig_shape[-1] // 2, orig_shape[1], 2, 2) + hidden_states = hidden_states[:, :num_embeds].view(orig_shape[0], orig_shape[-2] // 2, orig_shape[-1] // 2, orig_shape[1], 2, 2) hidden_states = hidden_states.permute(0, 3, 1, 4, 2, 5) return hidden_states.reshape(orig_shape)[:, :, :, :x.shape[-2], :x.shape[-1]] From ed43784b0d04e5b8e8ff2c057fa84b9c5132aaf2 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sun, 17 Aug 2025 13:45:39 -0700 Subject: [PATCH 070/325] WIP Qwen edit model: The diffusion model part. (#9383) --- comfy/ldm/qwen_image/model.py | 26 ++++++++++++++++++++++++++ comfy/model_base.py | 10 ++++++++++ 2 files changed, 36 insertions(+) diff --git a/comfy/ldm/qwen_image/model.py b/comfy/ldm/qwen_image/model.py index 40d8fd979..a3c726299 100644 --- a/comfy/ldm/qwen_image/model.py +++ b/comfy/ldm/qwen_image/model.py @@ -360,6 +360,7 @@ class QwenImageTransformer2DModel(nn.Module): context, attention_mask=None, guidance: torch.Tensor = None, + ref_latents=None, transformer_options={}, **kwargs ): @@ -370,6 +371,31 @@ class QwenImageTransformer2DModel(nn.Module): hidden_states, img_ids, orig_shape = self.process_img(x) num_embeds = hidden_states.shape[1] + if ref_latents is not None: + h = 0 + w = 0 + index = 0 + index_ref_method = kwargs.get("ref_latents_method", "index") == "index" + for ref in ref_latents: + if index_ref_method: + index += 1 + h_offset = 0 + w_offset = 0 + else: + index = 1 + h_offset = 0 + w_offset = 0 + if ref.shape[-2] + h > ref.shape[-1] + w: + w_offset = w + else: + h_offset = h + h = max(h, ref.shape[-2] + h_offset) + w = max(w, ref.shape[-1] + w_offset) + + kontext, kontext_ids, _ = self.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset) + hidden_states = torch.cat([hidden_states, kontext], dim=1) + img_ids = torch.cat([img_ids, kontext_ids], dim=1) + txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size), ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size))) txt_ids = torch.linspace(txt_start, txt_start + context.shape[1], steps=context.shape[1], device=x.device, dtype=x.dtype).reshape(1, -1, 1).repeat(x.shape[0], 1, 3) ids = torch.cat((txt_ids, img_ids), dim=1) diff --git a/comfy/model_base.py b/comfy/model_base.py index bf874b875..15bd7abef 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1331,4 +1331,14 @@ class QwenImage(BaseModel): cross_attn = kwargs.get("cross_attn", None) if cross_attn is not None: out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) + ref_latents = kwargs.get("reference_latents", None) + if ref_latents is not None: + latents = [] + for lat in ref_latents: + latents.append(self.process_latent_in(lat)) + out['ref_latents'] = comfy.conds.CONDList(latents) + + ref_latents_method = kwargs.get("reference_latents_method", None) + if ref_latents_method is not None: + out['ref_latents_method'] = comfy.conds.CONDConstant(ref_latents_method) return out From d4e353a94ec5a8cb15ed151990a9518b890e5d4f Mon Sep 17 00:00:00 2001 From: ComfyUI 
Wiki Date: Mon, 18 Aug 2025 05:38:40 +0800 Subject: [PATCH 071/325] Update template to 0.1.60 (#9377) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 2ae44ebe1..72a700028 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.25.8 -comfyui-workflow-templates==0.1.59 +comfyui-workflow-templates==0.1.60 comfyui-embedded-docs==0.2.6 torch torchsde From 7f3b9b16c6636cb1201213574892d33c2a35e4ba Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Sun, 17 Aug 2025 15:54:07 -0700 Subject: [PATCH 072/325] Make step index detection much more robust (#9392) --- comfy/context_windows.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/comfy/context_windows.py b/comfy/context_windows.py index 928b111df..041f380f9 100644 --- a/comfy/context_windows.py +++ b/comfy/context_windows.py @@ -164,8 +164,11 @@ class IndexListContextHandler(ContextHandlerABC): return resized_cond def set_step(self, timestep: torch.Tensor, model_options: dict[str]): - indexes = torch.where(model_options["transformer_options"]["sample_sigmas"] == timestep[0]) - self._step = int(indexes[0]) + mask = torch.isclose(model_options["transformer_options"]["sample_sigmas"], timestep, rtol=0.0001) + matches = torch.nonzero(mask) + if torch.numel(matches) == 0: + raise Exception("No sample_sigmas matched current timestep; something went wrong.") + self._step = int(matches[0].item()) def get_context_windows(self, model: BaseModel, x_in: torch.Tensor, model_options: dict[str]) -> list[IndexListContextWindow]: full_length = x_in.size(self.dim) # TODO: choose dim based on model From da2efeaec6609265051165bfb413a2a4c84cf4bb Mon Sep 17 00:00:00 2001 From: Christian Byrne Date: Sun, 17 Aug 2025 20:21:02 -0700 Subject: [PATCH 073/325] Bump frontend to 1.25.9 (#9394) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 72a700028..c7a5c47ab 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -comfyui-frontend-package==1.25.8 +comfyui-frontend-package==1.25.9 comfyui-workflow-templates==0.1.60 comfyui-embedded-docs==0.2.6 torch From bd2ab73976a4e245db3e057795328c89bfd98a88 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Mon, 18 Aug 2025 10:26:55 +0300 Subject: [PATCH 074/325] fix(WAN-nodes): invalid nodeid for WanTrackToVideo (#9396) --- comfy_extras/nodes_wan.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy_extras/nodes_wan.py b/comfy_extras/nodes_wan.py index 83a990688..0fff02f76 100644 --- a/comfy_extras/nodes_wan.py +++ b/comfy_extras/nodes_wan.py @@ -699,7 +699,7 @@ class WanTrackToVideo(io.ComfyNode): @classmethod def define_schema(cls): return io.Schema( - node_id="WanPhantomSubjectToVideo", + node_id="WanTrackToVideo", category="conditioning/video_models", inputs=[ io.Conditioning.Input("positive"), From 4977f203fa8e9e3ab22884c8ace8f9b540d48952 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 18 Aug 2025 19:38:34 -0700 Subject: [PATCH 075/325] P2 of qwen edit model. (#9412) * P2 of qwen edit model. * Typo. * Fix normal qwen. * Fix. * Make the TextEncodeQwenImageEdit also set the ref latent. If you don't want it to set the ref latent and want to use the ReferenceLatent node with your custom latent instead just disconnect the VAE. 
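For reference, a minimal usage sketch of the new node (here `clip`, `vae` and `image` are placeholders for objects loaded elsewhere in a workflow, not part of this patch):

    from comfy_extras.nodes_qwen import TextEncodeQwenImageEdit

    node = TextEncodeQwenImageEdit()
    # With a VAE connected, the node encodes the image itself and appends the
    # result to the conditioning as "reference_latents".
    (conditioning,) = node.encode(clip, "replace the sky with a sunset", vae=vae, image=image)
    # With vae=None the ref latent step is skipped; pair the output with a
    # ReferenceLatent node and a custom latent instead.
    (conditioning_manual,) = node.encode(clip, "replace the sky with a sunset", vae=None, image=image)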
--- comfy/clip_model.py | 2 +- comfy/model_base.py | 8 + comfy/sd1_clip.py | 11 +- comfy/text_encoders/bert.py | 2 +- comfy/text_encoders/llama.py | 43 ++- comfy/text_encoders/qwen_image.py | 20 +- comfy/text_encoders/qwen_vl.py | 428 ++++++++++++++++++++++++++++++ comfy/text_encoders/t5.py | 2 +- comfy_extras/nodes_qwen.py | 63 +++++ nodes.py | 1 + 10 files changed, 565 insertions(+), 15 deletions(-) create mode 100644 comfy/text_encoders/qwen_vl.py create mode 100644 comfy_extras/nodes_qwen.py diff --git a/comfy/clip_model.py b/comfy/clip_model.py index c8294d483..7e47d8a55 100644 --- a/comfy/clip_model.py +++ b/comfy/clip_model.py @@ -97,7 +97,7 @@ class CLIPTextModel_(torch.nn.Module): self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations) self.final_layer_norm = operations.LayerNorm(embed_dim, dtype=dtype, device=device) - def forward(self, input_tokens=None, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=torch.float32): + def forward(self, input_tokens=None, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=torch.float32, embeds_info=[]): if embeds is not None: x = embeds + comfy.ops.cast_to(self.embeddings.position_embedding.weight, dtype=dtype, device=embeds.device) else: diff --git a/comfy/model_base.py b/comfy/model_base.py index 15bd7abef..6c861b15e 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1325,6 +1325,7 @@ class Omnigen2(BaseModel): class QwenImage(BaseModel): def __init__(self, model_config, model_type=ModelType.FLUX, device=None): super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.qwen_image.model.QwenImageTransformer2DModel) + self.memory_usage_factor_conds = ("ref_latents",) def extra_conds(self, **kwargs): out = super().extra_conds(**kwargs) @@ -1342,3 +1343,10 @@ class QwenImage(BaseModel): if ref_latents_method is not None: out['ref_latents_method'] = comfy.conds.CONDConstant(ref_latents_method) return out + + def extra_conds_shapes(self, **kwargs): + out = {} + ref_latents = kwargs.get("reference_latents", None) + if ref_latents is not None: + out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()), ref_latents)) // 16]) + return out diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index ade340fd1..1e8adbe69 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -204,17 +204,19 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): tokens_embed = self.transformer.get_input_embeddings()(tokens_embed, out_dtype=torch.float32) index = 0 pad_extra = 0 + embeds_info = [] for o in other_embeds: emb = o[1] if torch.is_tensor(emb): emb = {"type": "embedding", "data": emb} + extra = None emb_type = emb.get("type", None) if emb_type == "embedding": emb = emb.get("data", None) else: if hasattr(self.transformer, "preprocess_embed"): - emb = self.transformer.preprocess_embed(emb, device=device) + emb, extra = self.transformer.preprocess_embed(emb, device=device) else: emb = None @@ -229,6 +231,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): tokens_embed = torch.cat([tokens_embed[:, :ind], emb, tokens_embed[:, ind:]], dim=1) attention_mask = attention_mask[:ind] + [1] * emb_shape + attention_mask[ind:] index += emb_shape - 1 + embeds_info.append({"type": emb_type, "index": ind, "size": emb_shape, "extra": extra}) else: index += -1 pad_extra += emb_shape @@ -243,11 +246,11 @@ 
class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): attention_masks.append(attention_mask) num_tokens.append(sum(attention_mask)) - return torch.cat(embeds_out), torch.tensor(attention_masks, device=device, dtype=torch.long), num_tokens + return torch.cat(embeds_out), torch.tensor(attention_masks, device=device, dtype=torch.long), num_tokens, embeds_info def forward(self, tokens): device = self.transformer.get_input_embeddings().weight.device - embeds, attention_mask, num_tokens = self.process_tokens(tokens, device) + embeds, attention_mask, num_tokens, embeds_info = self.process_tokens(tokens, device) attention_mask_model = None if self.enable_attention_masks: @@ -258,7 +261,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): else: intermediate_output = self.layer_idx - outputs = self.transformer(None, attention_mask_model, embeds=embeds, num_tokens=num_tokens, intermediate_output=intermediate_output, final_layer_norm_intermediate=self.layer_norm_hidden_state, dtype=torch.float32) + outputs = self.transformer(None, attention_mask_model, embeds=embeds, num_tokens=num_tokens, intermediate_output=intermediate_output, final_layer_norm_intermediate=self.layer_norm_hidden_state, dtype=torch.float32, embeds_info=embeds_info) if self.layer == "last": z = outputs[0].float() diff --git a/comfy/text_encoders/bert.py b/comfy/text_encoders/bert.py index 551b03162..ed4638a9a 100644 --- a/comfy/text_encoders/bert.py +++ b/comfy/text_encoders/bert.py @@ -116,7 +116,7 @@ class BertModel_(torch.nn.Module): self.embeddings = BertEmbeddings(config_dict["vocab_size"], config_dict["max_position_embeddings"], config_dict["type_vocab_size"], config_dict["pad_token_id"], embed_dim, layer_norm_eps, dtype, device, operations) self.encoder = BertEncoder(config_dict["num_hidden_layers"], embed_dim, config_dict["intermediate_size"], config_dict["num_attention_heads"], layer_norm_eps, dtype, device, operations) - def forward(self, input_tokens, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None): + def forward(self, input_tokens, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, embeds_info=[]): x = self.embeddings(input_tokens, embeds=embeds, dtype=dtype) mask = None if attention_mask is not None: diff --git a/comfy/text_encoders/llama.py b/comfy/text_encoders/llama.py index 1da6a0c94..9d90d5a61 100644 --- a/comfy/text_encoders/llama.py +++ b/comfy/text_encoders/llama.py @@ -2,12 +2,14 @@ import torch import torch.nn as nn from dataclasses import dataclass from typing import Optional, Any +import math from comfy.ldm.modules.attention import optimized_attention_for_device import comfy.model_management import comfy.ldm.common_dit import comfy.model_management +from . 
import qwen_vl @dataclass class Llama2Config: @@ -100,12 +102,10 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -def precompute_freqs_cis(head_dim, seq_len, theta, device=None): +def precompute_freqs_cis(head_dim, position_ids, theta, device=None): theta_numerator = torch.arange(0, head_dim, 2, device=device).float() inv_freq = 1.0 / (theta ** (theta_numerator / head_dim)) - position_ids = torch.arange(0, seq_len, device=device).unsqueeze(0) - inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) position_ids_expanded = position_ids[:, None, :].float() freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) @@ -277,7 +277,7 @@ class Llama2_(nn.Module): self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype) # self.lm_head = ops.Linear(config.hidden_size, config.vocab_size, bias=False, device=device, dtype=dtype) - def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None): + def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, position_ids=None, embeds_info=[]): if embeds is not None: x = embeds else: @@ -286,8 +286,11 @@ class Llama2_(nn.Module): if self.normalize_in: x *= self.config.hidden_size ** 0.5 + if position_ids is None: + position_ids = torch.arange(0, x.shape[1], device=x.device).unsqueeze(0) + freqs_cis = precompute_freqs_cis(self.config.head_dim, - x.shape[1], + position_ids, self.config.rope_theta, device=x.device) @@ -372,8 +375,38 @@ class Qwen25_7BVLI(BaseLlama, torch.nn.Module): self.num_layers = config.num_hidden_layers self.model = Llama2_(config, device=device, dtype=dtype, ops=operations) + self.visual = qwen_vl.Qwen2VLVisionTransformer(hidden_size=1280, output_hidden_size=config.hidden_size, device=device, dtype=dtype, ops=operations) self.dtype = dtype + def preprocess_embed(self, embed, device): + if embed["type"] == "image": + image, grid = qwen_vl.process_qwen2vl_images(embed["data"]) + return self.visual(image.to(device, dtype=torch.float32), grid), grid + return None, None + + def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, embeds_info=[]): + grid = None + for e in embeds_info: + if e.get("type") == "image": + grid = e.get("extra", None) + position_ids = torch.zeros((3, embeds.shape[1]), device=embeds.device) + start = e.get("index") + position_ids[:, :start] = torch.arange(0, start, device=embeds.device) + end = e.get("size") + start + len_max = int(grid.max()) // 2 + start_next = len_max + start + position_ids[:, end:] = torch.arange(start_next, start_next + (embeds.shape[1] - end), device=embeds.device) + position_ids[0, start:end] = start + max_d = int(grid[0][1]) // 2 + position_ids[1, start:end] = torch.arange(start, start + max_d, device=embeds.device).unsqueeze(1).repeat(1, math.ceil((end - start) / max_d)).flatten(0)[:end - start] + max_d = int(grid[0][2]) // 2 + position_ids[2, start:end] = torch.arange(start, start + max_d, device=embeds.device).unsqueeze(0).repeat(math.ceil((end - start) / max_d), 1).flatten(0)[:end - start] + + if grid is None: + position_ids = None + + return super().forward(x, attention_mask=attention_mask, embeds=embeds, num_tokens=num_tokens, intermediate_output=intermediate_output, 
final_layer_norm_intermediate=final_layer_norm_intermediate, dtype=dtype, position_ids=position_ids) + class Gemma2_2B(BaseLlama, torch.nn.Module): def __init__(self, config_dict, dtype, device, operations): super().__init__() diff --git a/comfy/text_encoders/qwen_image.py b/comfy/text_encoders/qwen_image.py index ce5c98097..f07318d6c 100644 --- a/comfy/text_encoders/qwen_image.py +++ b/comfy/text_encoders/qwen_image.py @@ -15,13 +15,27 @@ class QwenImageTokenizer(sd1_clip.SD1Tokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}): super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="qwen25_7b", tokenizer=Qwen25_7BVLITokenizer) self.llama_template = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n" + self.llama_template_images = "<|im_start|>system\nDescribe the key features of the input image \\(color, shape, size, texture, objects, background\\), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n" - def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None,**kwargs): + def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, images=[], **kwargs): if llama_template is None: - llama_text = self.llama_template.format(text) + if len(images) > 0: + llama_text = self.llama_template_images.format(text) + else: + llama_text = self.llama_template.format(text) else: llama_text = llama_template.format(text) - return super().tokenize_with_weights(llama_text, return_word_ids=return_word_ids, **kwargs) + tokens = super().tokenize_with_weights(llama_text, return_word_ids=return_word_ids, **kwargs) + key_name = next(iter(tokens)) + embed_count = 0 + qwen_tokens = tokens[key_name] + for r in qwen_tokens: + for i in range(len(r)): + if r[i][0] == 151655: + if len(images) > embed_count: + r[i] = ({"type": "image", "data": images[embed_count], "original_type": "image"},) + r[i][1:] + embed_count += 1 + return tokens class Qwen25_7BVLIModel(sd1_clip.SDClipModel): diff --git a/comfy/text_encoders/qwen_vl.py b/comfy/text_encoders/qwen_vl.py new file mode 100644 index 000000000..3b18ce730 --- /dev/null +++ b/comfy/text_encoders/qwen_vl.py @@ -0,0 +1,428 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from typing import Optional, Tuple +import math +from comfy.ldm.modules.attention import optimized_attention_for_device + + +def process_qwen2vl_images( + images: torch.Tensor, + min_pixels: int = 3136, + max_pixels: int = 12845056, + patch_size: int = 14, + temporal_patch_size: int = 2, + merge_size: int = 2, + image_mean: list = None, + image_std: list = None, +): + if image_mean is None: + image_mean = [0.48145466, 0.4578275, 0.40821073] + if image_std is None: + image_std = [0.26862954, 0.26130258, 0.27577711] + + batch_size, height, width, channels = images.shape + device = images.device + # dtype = images.dtype + + images = images.permute(0, 3, 1, 2) + + grid_thw_list = [] + img = images[0] + + factor = patch_size * merge_size + + h_bar = round(height / factor) * factor + w_bar = round(width / factor) * factor + + if h_bar * w_bar > max_pixels: 
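+        # Over the max_pixels budget: scale both sides down by a common factor beta,
+        # snapping to multiples of factor (patch_size * merge_size) but never below one factor.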
+ beta = math.sqrt((height * width) / max_pixels) + h_bar = max(factor, math.floor(height / beta / factor) * factor) + w_bar = max(factor, math.floor(width / beta / factor) * factor) + elif h_bar * w_bar < min_pixels: + beta = math.sqrt(min_pixels / (height * width)) + h_bar = math.ceil(height * beta / factor) * factor + w_bar = math.ceil(width * beta / factor) * factor + + img_resized = F.interpolate( + img.unsqueeze(0), + size=(h_bar, w_bar), + mode='bilinear', + align_corners=False + ).squeeze(0) + + normalized = img_resized.clone() + for c in range(3): + normalized[c] = (img_resized[c] - image_mean[c]) / image_std[c] + + grid_h = h_bar // patch_size + grid_w = w_bar // patch_size + grid_thw = torch.tensor([1, grid_h, grid_w], device=device, dtype=torch.long) + + pixel_values = normalized + grid_thw_list.append(grid_thw) + image_grid_thw = torch.stack(grid_thw_list) + + grid_t = 1 + channel = pixel_values.shape[0] + pixel_values = pixel_values.unsqueeze(0).repeat(2, 1, 1, 1) + + patches = pixel_values.reshape( + grid_t, + temporal_patch_size, + channel, + grid_h // merge_size, + merge_size, + patch_size, + grid_w // merge_size, + merge_size, + patch_size, + ) + + patches = patches.permute(0, 3, 6, 4, 7, 2, 1, 5, 8) + flatten_patches = patches.reshape( + grid_t * grid_h * grid_w, + channel * temporal_patch_size * patch_size * patch_size + ) + + return flatten_patches, image_grid_thw + + +class VisionPatchEmbed(nn.Module): + def __init__( + self, + patch_size: int = 14, + temporal_patch_size: int = 2, + in_channels: int = 3, + embed_dim: int = 3584, + device=None, + dtype=None, + ops=None, + ): + super().__init__() + self.patch_size = patch_size + self.temporal_patch_size = temporal_patch_size + self.in_channels = in_channels + self.embed_dim = embed_dim + + kernel_size = [temporal_patch_size, patch_size, patch_size] + self.proj = ops.Conv3d( + in_channels, + embed_dim, + kernel_size=kernel_size, + stride=kernel_size, + bias=False, + device=device, + dtype=dtype + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = hidden_states.view( + -1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size + ) + hidden_states = self.proj(hidden_states) + return hidden_states.view(-1, self.embed_dim) + + +def rotate_half(x): + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb_vision(q, k, cos, sin): + cos, sin = cos.unsqueeze(-2).float(), sin.unsqueeze(-2).float() + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class VisionRotaryEmbedding(nn.Module): + def __init__(self, dim: int, theta: float = 10000.0): + super().__init__() + self.dim = dim + self.theta = theta + + def forward(self, seqlen: int, device) -> torch.Tensor: + inv_freq = 1.0 / (self.theta ** (torch.arange(0, self.dim, 2, dtype=torch.float, device=device) / self.dim)) + seq = torch.arange(seqlen, device=inv_freq.device, dtype=inv_freq.dtype) + freqs = torch.outer(seq, inv_freq) + return freqs + + +class PatchMerger(nn.Module): + def __init__(self, dim: int, context_dim: int, spatial_merge_size: int = 2, device=None, dtype=None, ops=None): + super().__init__() + self.hidden_size = context_dim * (spatial_merge_size ** 2) + self.ln_q = ops.RMSNorm(context_dim, eps=1e-6, device=device, dtype=dtype) + self.mlp = nn.Sequential( + ops.Linear(self.hidden_size, self.hidden_size, device=device, dtype=dtype), + nn.GELU(), + 
ops.Linear(self.hidden_size, dim, device=device, dtype=dtype), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.ln_q(x).reshape(-1, self.hidden_size) + x = self.mlp(x) + return x + + +class VisionAttention(nn.Module): + def __init__(self, hidden_size: int, num_heads: int, device=None, dtype=None, ops=None): + super().__init__() + self.hidden_size = hidden_size + self.num_heads = num_heads + self.head_dim = hidden_size // num_heads + self.scaling = self.head_dim ** -0.5 + + self.qkv = ops.Linear(hidden_size, hidden_size * 3, bias=True, device=device, dtype=dtype) + self.proj = ops.Linear(hidden_size, hidden_size, bias=True, device=device, dtype=dtype) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + cu_seqlens=None, + optimized_attention=None, + ) -> torch.Tensor: + if hidden_states.dim() == 2: + seq_length, _ = hidden_states.shape + batch_size = 1 + hidden_states = hidden_states.unsqueeze(0) + else: + batch_size, seq_length, _ = hidden_states.shape + + qkv = self.qkv(hidden_states) + qkv = qkv.reshape(batch_size, seq_length, 3, self.num_heads, self.head_dim) + query_states, key_states, value_states = qkv.reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0) + + if position_embeddings is not None: + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb_vision(query_states, key_states, cos, sin) + + query_states = query_states.transpose(0, 1).unsqueeze(0) + key_states = key_states.transpose(0, 1).unsqueeze(0) + value_states = value_states.transpose(0, 1).unsqueeze(0) + + lengths = cu_seqlens[1:] - cu_seqlens[:-1] + splits = [ + torch.split(tensor, lengths.tolist(), dim=2) for tensor in (query_states, key_states, value_states) + ] + + attn_outputs = [ + optimized_attention(q, k, v, self.num_heads, skip_reshape=True) + for q, k, v in zip(*splits) + ] + attn_output = torch.cat(attn_outputs, dim=1) + attn_output = attn_output.reshape(seq_length, -1) + attn_output = self.proj(attn_output) + + return attn_output + + +class VisionMLP(nn.Module): + def __init__(self, hidden_size: int, intermediate_size: int, device=None, dtype=None, ops=None): + super().__init__() + self.gate_proj = ops.Linear(hidden_size, intermediate_size, bias=True, device=device, dtype=dtype) + self.up_proj = ops.Linear(hidden_size, intermediate_size, bias=True, device=device, dtype=dtype) + self.down_proj = ops.Linear(intermediate_size, hidden_size, bias=True, device=device, dtype=dtype) + self.act_fn = nn.SiLU() + + def forward(self, hidden_state): + return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state)) + + +class VisionBlock(nn.Module): + def __init__(self, hidden_size: int, intermediate_size: int, num_heads: int, device=None, dtype=None, ops=None): + super().__init__() + self.norm1 = ops.RMSNorm(hidden_size, eps=1e-6, device=device, dtype=dtype) + self.norm2 = ops.RMSNorm(hidden_size, eps=1e-6, device=device, dtype=dtype) + self.attn = VisionAttention(hidden_size, num_heads, device=device, dtype=dtype, ops=ops) + self.mlp = VisionMLP(hidden_size, intermediate_size, device=device, dtype=dtype, ops=ops) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + cu_seqlens=None, + optimized_attention=None, + ) -> torch.Tensor: + residual = hidden_states + hidden_states = self.norm1(hidden_states) + hidden_states = self.attn(hidden_states, position_embeddings, 
cu_seqlens, optimized_attention) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.norm2(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states + + +class Qwen2VLVisionTransformer(nn.Module): + def __init__( + self, + hidden_size: int = 3584, + output_hidden_size: int = 3584, + intermediate_size: int = 3420, + num_heads: int = 16, + num_layers: int = 32, + patch_size: int = 14, + temporal_patch_size: int = 2, + spatial_merge_size: int = 2, + window_size: int = 112, + device=None, + dtype=None, + ops=None + ): + super().__init__() + self.hidden_size = hidden_size + self.patch_size = patch_size + self.spatial_merge_size = spatial_merge_size + self.window_size = window_size + self.fullatt_block_indexes = [7, 15, 23, 31] + + self.patch_embed = VisionPatchEmbed( + patch_size=patch_size, + temporal_patch_size=temporal_patch_size, + in_channels=3, + embed_dim=hidden_size, + device=device, + dtype=dtype, + ops=ops, + ) + + head_dim = hidden_size // num_heads + self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2) + + self.blocks = nn.ModuleList([ + VisionBlock(hidden_size, intermediate_size, num_heads, device, dtype, ops) + for _ in range(num_layers) + ]) + + self.merger = PatchMerger( + dim=output_hidden_size, + context_dim=hidden_size, + spatial_merge_size=spatial_merge_size, + device=device, + dtype=dtype, + ops=ops, + ) + + def get_window_index(self, grid_thw): + window_index = [] + cu_window_seqlens = [0] + window_index_id = 0 + vit_merger_window_size = self.window_size // self.spatial_merge_size // self.patch_size + + for grid_t, grid_h, grid_w in grid_thw: + llm_grid_h = grid_h // self.spatial_merge_size + llm_grid_w = grid_w // self.spatial_merge_size + + index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(grid_t, llm_grid_h, llm_grid_w) + + pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size + pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size + num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size + num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size + + index_padded = F.pad(index, (0, pad_w, 0, pad_h), "constant", -100) + index_padded = index_padded.reshape( + grid_t, + num_windows_h, + vit_merger_window_size, + num_windows_w, + vit_merger_window_size, + ) + index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape( + grid_t, + num_windows_h * num_windows_w, + vit_merger_window_size, + vit_merger_window_size, + ) + + seqlens = (index_padded != -100).sum([2, 3]).reshape(-1) + index_padded = index_padded.reshape(-1) + index_new = index_padded[index_padded != -100] + window_index.append(index_new + window_index_id) + + cu_seqlens_tmp = seqlens.cumsum(0) * self.spatial_merge_size * self.spatial_merge_size + cu_window_seqlens[-1] + cu_window_seqlens.extend(cu_seqlens_tmp.tolist()) + window_index_id += (grid_t * llm_grid_h * llm_grid_w).item() + + window_index = torch.cat(window_index, dim=0) + return window_index, cu_window_seqlens + + def get_position_embeddings(self, grid_thw, device): + pos_ids = [] + + for t, h, w in grid_thw: + hpos_ids = torch.arange(h, device=device).unsqueeze(1).expand(-1, w) + hpos_ids = hpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + hpos_ids = hpos_ids.permute(0, 2, 1, 3).flatten() + + wpos_ids = torch.arange(w, device=device).unsqueeze(0).expand(h, -1) + wpos_ids = wpos_ids.reshape( + h 
// self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + wpos_ids = wpos_ids.permute(0, 2, 1, 3).flatten() + + pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) + + pos_ids = torch.cat(pos_ids, dim=0) + max_grid_size = grid_thw[:, 1:].max() + rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size, device) + return rotary_pos_emb_full[pos_ids].flatten(1) + + def forward( + self, + pixel_values: torch.Tensor, + image_grid_thw: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + optimized_attention = optimized_attention_for_device(pixel_values.device, mask=False, small_input=True) + + hidden_states = self.patch_embed(pixel_values) + + window_index, cu_window_seqlens = self.get_window_index(image_grid_thw) + cu_window_seqlens = torch.tensor(cu_window_seqlens, device=hidden_states.device) + cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens) + + position_embeddings = self.get_position_embeddings(image_grid_thw, hidden_states.device) + + seq_len, _ = hidden_states.size() + spatial_merge_unit = self.spatial_merge_size * self.spatial_merge_size + + hidden_states = hidden_states.reshape(seq_len // spatial_merge_unit, spatial_merge_unit, -1) + hidden_states = hidden_states[window_index, :, :] + hidden_states = hidden_states.reshape(seq_len, -1) + + position_embeddings = position_embeddings.reshape(seq_len // spatial_merge_unit, spatial_merge_unit, -1) + position_embeddings = position_embeddings[window_index, :, :] + position_embeddings = position_embeddings.reshape(seq_len, -1) + position_embeddings = torch.cat((position_embeddings, position_embeddings), dim=-1) + position_embeddings = (position_embeddings.cos(), position_embeddings.sin()) + + cu_seqlens = torch.repeat_interleave(image_grid_thw[:, 1] * image_grid_thw[:, 2], image_grid_thw[:, 0]).cumsum( + dim=0, + dtype=torch.int32, + ) + cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) + + for i, block in enumerate(self.blocks): + if i in self.fullatt_block_indexes: + cu_seqlens_now = cu_seqlens + else: + cu_seqlens_now = cu_window_seqlens + hidden_states = block(hidden_states, position_embeddings, cu_seqlens_now, optimized_attention=optimized_attention) + + hidden_states = self.merger(hidden_states) + return hidden_states diff --git a/comfy/text_encoders/t5.py b/comfy/text_encoders/t5.py index 36bf35309..e8588992a 100644 --- a/comfy/text_encoders/t5.py +++ b/comfy/text_encoders/t5.py @@ -199,7 +199,7 @@ class T5Stack(torch.nn.Module): self.final_layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device, operations=operations) # self.dropout = nn.Dropout(config.dropout_rate) - def forward(self, x, attention_mask=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None): + def forward(self, x, attention_mask=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, embeds_info=[]): mask = None if attention_mask is not None: mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1]) diff --git a/comfy_extras/nodes_qwen.py b/comfy_extras/nodes_qwen.py new file mode 100644 index 000000000..b5088fae2 --- /dev/null +++ b/comfy_extras/nodes_qwen.py @@ -0,0 +1,63 @@ +import node_helpers +import comfy.utils + +PREFERRED_QWENIMAGE_RESOLUTIONS = [ + (672, 1568), + (688, 1504), + (720, 1456), + (752, 1392), + (800, 1328), + (832, 1248), + (880, 1184), + (944, 1104), + (1024, 1024), 
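+    # The entries below mirror the portrait buckets above with width/height swapped,
+    # covering the landscape orientations.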
+ (1104, 944), + (1184, 880), + (1248, 832), + (1328, 800), + (1392, 752), + (1456, 720), + (1504, 688), + (1568, 672), +] + + +class TextEncodeQwenImageEdit: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "clip": ("CLIP", ), + "prompt": ("STRING", {"multiline": True, "dynamicPrompts": True}), + }, + "optional": {"vae": ("VAE", ), + "image": ("IMAGE", ),}} + + RETURN_TYPES = ("CONDITIONING",) + FUNCTION = "encode" + + CATEGORY = "advanced/conditioning" + + def encode(self, clip, prompt, vae=None, image=None): + ref_latent = None + if image is None: + images = [] + else: + images = [image] + if vae is not None: + width = image.shape[2] + height = image.shape[1] + aspect_ratio = width / height + _, width, height = min((abs(aspect_ratio - w / h), w, h) for w, h in PREFERRED_QWENIMAGE_RESOLUTIONS) + image = comfy.utils.common_upscale(image.movedim(-1, 1), width, height, "lanczos", "center").movedim(1, -1) + ref_latent = vae.encode(image[:, :, :, :3]) + + tokens = clip.tokenize(prompt, images=images) + conditioning = clip.encode_from_tokens_scheduled(tokens) + if ref_latent is not None: + conditioning = node_helpers.conditioning_set_values(conditioning, {"reference_latents": [ref_latent]}, append=True) + return (conditioning, ) + + +NODE_CLASS_MAPPINGS = { + "TextEncodeQwenImageEdit": TextEncodeQwenImageEdit, +} diff --git a/nodes.py b/nodes.py index 860a236aa..b3fa9c51a 100644 --- a/nodes.py +++ b/nodes.py @@ -2321,6 +2321,7 @@ async def init_builtin_extra_nodes(): "nodes_edit_model.py", "nodes_tcfg.py", "nodes_context_windows.py", + "nodes_qwen.py", ] import_failed = [] From 36b5127fd3eee8eaf95ff7296a61269ed56d53c0 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Tue, 19 Aug 2025 23:28:07 +0300 Subject: [PATCH 076/325] api_nodes: add kling-v2-1 and v2-1-master (#9257) --- comfy_api_nodes/apis/__init__.py | 3 +++ comfy_api_nodes/nodes_kling.py | 2 ++ 2 files changed, 5 insertions(+) diff --git a/comfy_api_nodes/apis/__init__.py b/comfy_api_nodes/apis/__init__.py index 54298e8a9..c6f91e9d6 100644 --- a/comfy_api_nodes/apis/__init__.py +++ b/comfy_api_nodes/apis/__init__.py @@ -1315,6 +1315,7 @@ class KlingTaskStatus(str, Enum): class KlingTextToVideoModelName(str, Enum): kling_v1 = 'kling-v1' kling_v1_6 = 'kling-v1-6' + kling_v2_1_master = 'kling-v2-1-master' class KlingVideoGenAspectRatio(str, Enum): @@ -1347,6 +1348,8 @@ class KlingVideoGenModelName(str, Enum): kling_v1_5 = 'kling-v1-5' kling_v1_6 = 'kling-v1-6' kling_v2_master = 'kling-v2-master' + kling_v2_1 = 'kling-v2-1' + kling_v2_1_master = 'kling-v2-1-master' class KlingVideoResult(BaseModel): diff --git a/comfy_api_nodes/nodes_kling.py b/comfy_api_nodes/nodes_kling.py index 9d483bb0e..9fa390985 100644 --- a/comfy_api_nodes/nodes_kling.py +++ b/comfy_api_nodes/nodes_kling.py @@ -421,6 +421,8 @@ class KlingTextToVideoNode(KlingNodeBase): "pro mode / 10s duration / kling-v2-master": ("pro", "10", "kling-v2-master"), "standard mode / 5s duration / kling-v2-master": ("std", "5", "kling-v2-master"), "standard mode / 10s duration / kling-v2-master": ("std", "10", "kling-v2-master"), + "pro mode / 5s duration / kling-v2-1-master": ("pro", "5", "kling-v2-1-master"), + "pro mode / 10s duration / kling-v2-1-master": ("pro", "10", "kling-v2-1-master"), } @classmethod From f16a70ba670e11de549af188663a87c77c5bc0c2 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Tue, 19 Aug 2025 23:28:27 +0300 Subject: [PATCH 077/325] api_nodes: add 
MinimaxHailuoVideoNode node (#9262) --- comfy_api_nodes/apis/__init__.py | 13 ++- comfy_api_nodes/nodes_minimax.py | 185 ++++++++++++++++++++++++++++++- 2 files changed, 191 insertions(+), 7 deletions(-) diff --git a/comfy_api_nodes/apis/__init__.py b/comfy_api_nodes/apis/__init__.py index c6f91e9d6..7a09df55b 100644 --- a/comfy_api_nodes/apis/__init__.py +++ b/comfy_api_nodes/apis/__init__.py @@ -1623,13 +1623,14 @@ class MinimaxTaskResultResponse(BaseModel): task_id: str = Field(..., description='The task ID being queried.') -class Model(str, Enum): +class MiniMaxModel(str, Enum): T2V_01_Director = 'T2V-01-Director' I2V_01_Director = 'I2V-01-Director' S2V_01 = 'S2V-01' I2V_01 = 'I2V-01' I2V_01_live = 'I2V-01-live' T2V_01 = 'T2V-01' + Hailuo_02 = 'MiniMax-Hailuo-02' class SubjectReferenceItem(BaseModel): @@ -1651,7 +1652,7 @@ class MinimaxVideoGenerationRequest(BaseModel): None, description='URL or base64 encoding of the first frame image. Required when model is I2V-01, I2V-01-Director, or I2V-01-live.', ) - model: Model = Field( + model: MiniMaxModel = Field( ..., description='Required. ID of model. Options: T2V-01-Director, I2V-01-Director, S2V-01, I2V-01, I2V-01-live, T2V-01', ) @@ -1668,6 +1669,14 @@ class MinimaxVideoGenerationRequest(BaseModel): None, description='Only available when model is S2V-01. The model will generate a video based on the subject uploaded through this parameter.', ) + duration: Optional[int] = Field( + None, + description="The length of the output video in seconds." + ) + resolution: Optional[str] = Field( + None, + description="The dimensions of the video display. 1080p corresponds to 1920 x 1080 pixels, 768p corresponds to 1366 x 768 pixels." + ) class MinimaxVideoGenerationResponse(BaseModel): diff --git a/comfy_api_nodes/nodes_minimax.py b/comfy_api_nodes/nodes_minimax.py index 58d2ed90c..bb3c9e710 100644 --- a/comfy_api_nodes/nodes_minimax.py +++ b/comfy_api_nodes/nodes_minimax.py @@ -1,3 +1,4 @@ +from inspect import cleandoc from typing import Union import logging import torch @@ -10,7 +11,7 @@ from comfy_api_nodes.apis import ( MinimaxFileRetrieveResponse, MinimaxTaskResultResponse, SubjectReferenceItem, - Model + MiniMaxModel ) from comfy_api_nodes.apis.client import ( ApiEndpoint, @@ -84,7 +85,6 @@ class MinimaxTextToVideoNode: FUNCTION = "generate_video" CATEGORY = "api node/video/MiniMax" API_NODE = True - OUTPUT_NODE = True async def generate_video( self, @@ -121,7 +121,7 @@ class MinimaxTextToVideoNode: response_model=MinimaxVideoGenerationResponse, ), request=MinimaxVideoGenerationRequest( - model=Model(model), + model=MiniMaxModel(model), prompt=prompt_text, callback_url=None, first_frame_image=image_url, @@ -251,7 +251,6 @@ class MinimaxImageToVideoNode(MinimaxTextToVideoNode): FUNCTION = "generate_video" CATEGORY = "api node/video/MiniMax" API_NODE = True - OUTPUT_NODE = True class MinimaxSubjectToVideoNode(MinimaxTextToVideoNode): @@ -313,7 +312,181 @@ class MinimaxSubjectToVideoNode(MinimaxTextToVideoNode): FUNCTION = "generate_video" CATEGORY = "api node/video/MiniMax" API_NODE = True - OUTPUT_NODE = True + + +class MinimaxHailuoVideoNode: + """Generates videos from prompt, with optional start frame using the new MiniMax Hailuo-02 model.""" + + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "prompt_text": ( + "STRING", + { + "multiline": True, + "default": "", + "tooltip": "Text prompt to guide the video generation.", + }, + ), + }, + "optional": { + "seed": ( + IO.INT, + { + "default": 0, + "min": 0, + "max": 
0xFFFFFFFFFFFFFFFF, + "control_after_generate": True, + "tooltip": "The random seed used for creating the noise.", + }, + ), + "first_frame_image": ( + IO.IMAGE, + { + "tooltip": "Optional image to use as the first frame to generate a video." + }, + ), + "prompt_optimizer": ( + IO.BOOLEAN, + { + "tooltip": "Optimize prompt to improve generation quality when needed.", + "default": True, + }, + ), + "duration": ( + IO.COMBO, + { + "tooltip": "The length of the output video in seconds.", + "default": 6, + "options": [6, 10], + }, + ), + "resolution": ( + IO.COMBO, + { + "tooltip": "The dimensions of the video display. " + "1080p corresponds to 1920 x 1080 pixels, 768p corresponds to 1366 x 768 pixels.", + "default": "768P", + "options": ["768P", "1080P"], + }, + ), + }, + "hidden": { + "auth_token": "AUTH_TOKEN_COMFY_ORG", + "comfy_api_key": "API_KEY_COMFY_ORG", + "unique_id": "UNIQUE_ID", + }, + } + + RETURN_TYPES = ("VIDEO",) + DESCRIPTION = cleandoc(__doc__ or "") + FUNCTION = "generate_video" + CATEGORY = "api node/video/MiniMax" + API_NODE = True + + async def generate_video( + self, + prompt_text, + seed=0, + first_frame_image: torch.Tensor=None, # used for ImageToVideo + prompt_optimizer=True, + duration=6, + resolution="768P", + model="MiniMax-Hailuo-02", + unique_id: Union[str, None]=None, + **kwargs, + ): + if first_frame_image is None: + validate_string(prompt_text, field_name="prompt_text") + + if model == "MiniMax-Hailuo-02" and resolution.upper() == "1080P" and duration != 6: + raise Exception( + "When model is MiniMax-Hailuo-02 and resolution is 1080P, duration is limited to 6 seconds." + ) + + # upload image, if passed in + image_url = None + if first_frame_image is not None: + image_url = (await upload_images_to_comfyapi(first_frame_image, max_images=1, auth_kwargs=kwargs))[0] + + video_generate_operation = SynchronousOperation( + endpoint=ApiEndpoint( + path="/proxy/minimax/video_generation", + method=HttpMethod.POST, + request_model=MinimaxVideoGenerationRequest, + response_model=MinimaxVideoGenerationResponse, + ), + request=MinimaxVideoGenerationRequest( + model=MiniMaxModel(model), + prompt=prompt_text, + callback_url=None, + first_frame_image=image_url, + prompt_optimizer=prompt_optimizer, + duration=duration, + resolution=resolution, + ), + auth_kwargs=kwargs, + ) + response = await video_generate_operation.execute() + + task_id = response.task_id + if not task_id: + raise Exception(f"MiniMax generation failed: {response.base_resp}") + + average_duration = 120 if resolution == "768P" else 240 + video_generate_operation = PollingOperation( + poll_endpoint=ApiEndpoint( + path="/proxy/minimax/query/video_generation", + method=HttpMethod.GET, + request_model=EmptyRequest, + response_model=MinimaxTaskResultResponse, + query_params={"task_id": task_id}, + ), + completed_statuses=["Success"], + failed_statuses=["Fail"], + status_extractor=lambda x: x.status.value, + estimated_duration=average_duration, + node_id=unique_id, + auth_kwargs=kwargs, + ) + task_result = await video_generate_operation.execute() + + file_id = task_result.file_id + if file_id is None: + raise Exception("Request was not successful. 
Missing file ID.") + file_retrieve_operation = SynchronousOperation( + endpoint=ApiEndpoint( + path="/proxy/minimax/files/retrieve", + method=HttpMethod.GET, + request_model=EmptyRequest, + response_model=MinimaxFileRetrieveResponse, + query_params={"file_id": int(file_id)}, + ), + request=EmptyRequest(), + auth_kwargs=kwargs, + ) + file_result = await file_retrieve_operation.execute() + + file_url = file_result.file.download_url + if file_url is None: + raise Exception( + f"No video was found in the response. Full response: {file_result.model_dump()}" + ) + logging.info(f"Generated video URL: {file_url}") + if unique_id: + if hasattr(file_result.file, "backup_download_url"): + message = f"Result URL: {file_url}\nBackup URL: {file_result.file.backup_download_url}" + else: + message = f"Result URL: {file_url}" + PromptServer.instance.send_progress_text(message, unique_id) + + video_io = await download_url_to_bytesio(file_url) + if video_io is None: + error_msg = f"Failed to download video from {file_url}" + logging.error(error_msg) + raise Exception(error_msg) + return (VideoFromFile(video_io),) # A dictionary that contains all nodes you want to export with their names @@ -322,6 +495,7 @@ NODE_CLASS_MAPPINGS = { "MinimaxTextToVideoNode": MinimaxTextToVideoNode, "MinimaxImageToVideoNode": MinimaxImageToVideoNode, # "MinimaxSubjectToVideoNode": MinimaxSubjectToVideoNode, + "MinimaxHailuoVideoNode": MinimaxHailuoVideoNode, } # A dictionary that contains the friendly/humanly readable titles for the nodes @@ -329,4 +503,5 @@ NODE_DISPLAY_NAME_MAPPINGS = { "MinimaxTextToVideoNode": "MiniMax Text to Video", "MinimaxImageToVideoNode": "MiniMax Image to Video", "MinimaxSubjectToVideoNode": "MiniMax Subject to Video", + "MinimaxHailuoVideoNode": "MiniMax Hailuo Video", } From 07a927517cfaf099fec3903e8973f758e62d65f9 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Tue, 19 Aug 2025 23:29:01 +0300 Subject: [PATCH 078/325] api_nodes: add GPT-5 series models (#9325) --- comfy_api_nodes/nodes_openai.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/comfy_api_nodes/nodes_openai.py b/comfy_api_nodes/nodes_openai.py index cbff2b2d2..674c9ede0 100644 --- a/comfy_api_nodes/nodes_openai.py +++ b/comfy_api_nodes/nodes_openai.py @@ -80,6 +80,9 @@ class SupportedOpenAIModel(str, Enum): gpt_4_1 = "gpt-4.1" gpt_4_1_mini = "gpt-4.1-mini" gpt_4_1_nano = "gpt-4.1-nano" + gpt_5 = "gpt-5" + gpt_5_mini = "gpt-5-mini" + gpt_5_nano = "gpt-5-nano" class OpenAIDalle2(ComfyNodeABC): From d844d8b13bfd6a83b0a7d0491aa2978ac44a6158 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Tue, 19 Aug 2025 23:29:24 +0300 Subject: [PATCH 079/325] api_nodes: added release version of google's models (#9304) --- comfy_api_nodes/nodes_gemini.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/comfy_api_nodes/nodes_gemini.py b/comfy_api_nodes/nodes_gemini.py index 3751fb2a1..ba4167a50 100644 --- a/comfy_api_nodes/nodes_gemini.py +++ b/comfy_api_nodes/nodes_gemini.py @@ -46,6 +46,8 @@ class GeminiModel(str, Enum): gemini_2_5_pro_preview_05_06 = "gemini-2.5-pro-preview-05-06" gemini_2_5_flash_preview_04_17 = "gemini-2.5-flash-preview-04-17" + gemini_2_5_pro = "gemini-2.5-pro" + gemini_2_5_flash = "gemini-2.5-flash" def get_gemini_endpoint( @@ -97,7 +99,7 @@ class GeminiNode(ComfyNodeABC): { "tooltip": "The Gemini model to use for generating responses.", "options": [model.value for model in GeminiModel], - "default": 
GeminiModel.gemini_2_5_pro_preview_05_06.value, + "default": GeminiModel.gemini_2_5_pro.value, }, ), "seed": ( From 54d8fdbed0a7b171ab8cfb02e29a7e0dc5fe78fd Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Tue, 19 Aug 2025 23:30:06 +0300 Subject: [PATCH 080/325] feat(api-nodes): add Vidu Video nodes (#9368) --- comfy_api_nodes/nodes_vidu.py | 622 +++++++++++++++++++++++ comfy_api_nodes/util/validation_utils.py | 53 ++ nodes.py | 1 + 3 files changed, 676 insertions(+) create mode 100644 comfy_api_nodes/nodes_vidu.py diff --git a/comfy_api_nodes/nodes_vidu.py b/comfy_api_nodes/nodes_vidu.py new file mode 100644 index 000000000..2f441948c --- /dev/null +++ b/comfy_api_nodes/nodes_vidu.py @@ -0,0 +1,622 @@ +import logging +from enum import Enum +from typing import Any, Callable, Optional, Literal, TypeVar +from typing_extensions import override + +import torch +from pydantic import BaseModel, Field + +from comfy_api.latest import ComfyExtension, io as comfy_io +from comfy_api_nodes.util.validation_utils import ( + validate_aspect_ratio_closeness, + validate_image_dimensions, + validate_image_aspect_ratio_range, + get_number_of_images, +) +from comfy_api_nodes.apis.client import ( + ApiEndpoint, + HttpMethod, + SynchronousOperation, + PollingOperation, + EmptyRequest, +) +from comfy_api_nodes.apinode_utils import download_url_to_video_output, upload_images_to_comfyapi + + +VIDU_TEXT_TO_VIDEO = "/proxy/vidu/text2video" +VIDU_IMAGE_TO_VIDEO = "/proxy/vidu/img2video" +VIDU_REFERENCE_VIDEO = "/proxy/vidu/reference2video" +VIDU_START_END_VIDEO = "/proxy/vidu/start-end2video" +VIDU_GET_GENERATION_STATUS = "/proxy/vidu/tasks/%s/creations" + +R = TypeVar("R") + +class VideoModelName(str, Enum): + vidu_q1 = 'viduq1' + + +class AspectRatio(str, Enum): + r_16_9 = "16:9" + r_9_16 = "9:16" + r_1_1 = "1:1" + + +class Resolution(str, Enum): + r_1080p = "1080p" + + +class MovementAmplitude(str, Enum): + auto = "auto" + small = "small" + medium = "medium" + large = "large" + + +class TaskCreationRequest(BaseModel): + model: VideoModelName = VideoModelName.vidu_q1 + prompt: Optional[str] = Field(None, max_length=1500) + duration: Optional[Literal[5]] = 5 + seed: Optional[int] = Field(0, ge=0, le=2147483647) + aspect_ratio: Optional[AspectRatio] = AspectRatio.r_16_9 + resolution: Optional[Resolution] = Resolution.r_1080p + movement_amplitude: Optional[MovementAmplitude] = MovementAmplitude.auto + images: Optional[list[str]] = Field(None, description="Base64 encoded string or image URL") + + +class TaskStatus(str, Enum): + created = "created" + queueing = "queueing" + processing = "processing" + success = "success" + failed = "failed" + + +class TaskCreationResponse(BaseModel): + task_id: str = Field(...) + state: TaskStatus = Field(...) + created_at: str = Field(...) + code: Optional[int] = Field(None, description="Error code") + + +class TaskResult(BaseModel): + id: str = Field(..., description="Creation id") + url: str = Field(..., description="The URL of the generated results, valid for one hour") + cover_url: str = Field(..., description="The cover URL of the generated results, valid for one hour") + + +class TaskStatusResponse(BaseModel): + state: TaskStatus = Field(...) 
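+    # err_code is reported on failure; creations holds the generated results (URLs expire after one hour).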
+ err_code: Optional[str] = Field(None) + creations: list[TaskResult] = Field(..., description="Generated results") + + +async def poll_until_finished( + auth_kwargs: dict[str, str], + api_endpoint: ApiEndpoint[Any, R], + result_url_extractor: Optional[Callable[[R], str]] = None, + estimated_duration: Optional[int] = None, + node_id: Optional[str] = None, +) -> R: + return await PollingOperation( + poll_endpoint=api_endpoint, + completed_statuses=[TaskStatus.success.value], + failed_statuses=[TaskStatus.failed.value], + status_extractor=lambda response: response.state.value, + auth_kwargs=auth_kwargs, + result_url_extractor=result_url_extractor, + estimated_duration=estimated_duration, + node_id=node_id, + poll_interval=16.0, + max_poll_attempts=256, + ).execute() + + +def get_video_url_from_response(response) -> Optional[str]: + if response.creations: + return response.creations[0].url + return None + + +def get_video_from_response(response) -> TaskResult: + if not response.creations: + error_msg = f"Vidu request does not contain results. State: {response.state}, Error Code: {response.err_code}" + logging.info(error_msg) + raise RuntimeError(error_msg) + logging.info("Vidu task %s succeeded. Video URL: %s", response.creations[0].id, response.creations[0].url) + return response.creations[0] + + +async def execute_task( + vidu_endpoint: str, + auth_kwargs: Optional[dict[str, str]], + payload: TaskCreationRequest, + estimated_duration: int, + node_id: str, +) -> R: + response = await SynchronousOperation( + endpoint=ApiEndpoint( + path=vidu_endpoint, + method=HttpMethod.POST, + request_model=TaskCreationRequest, + response_model=TaskCreationResponse, + ), + request=payload, + auth_kwargs=auth_kwargs, + ).execute() + if response.state == TaskStatus.failed: + error_msg = f"Vidu request failed. 
Code: {response.code}" + logging.error(error_msg) + raise RuntimeError(error_msg) + return await poll_until_finished( + auth_kwargs, + ApiEndpoint( + path=VIDU_GET_GENERATION_STATUS % response.task_id, + method=HttpMethod.GET, + request_model=EmptyRequest, + response_model=TaskStatusResponse, + ), + result_url_extractor=get_video_url_from_response, + estimated_duration=estimated_duration, + node_id=node_id, + ) + + +class ViduTextToVideoNode(comfy_io.ComfyNode): + + @classmethod + def define_schema(cls): + return comfy_io.Schema( + node_id="ViduTextToVideoNode", + display_name="Vidu Text To Video Generation", + category="api node/video/Vidu", + description="Generate video from text prompt", + inputs=[ + comfy_io.Combo.Input( + "model", + options=[model.value for model in VideoModelName], + default=VideoModelName.vidu_q1.value, + tooltip="Model name", + ), + comfy_io.String.Input( + "prompt", + multiline=True, + tooltip="A textual description for video generation", + ), + comfy_io.Int.Input( + "duration", + default=5, + min=5, + max=5, + step=1, + display_mode=comfy_io.NumberDisplay.number, + tooltip="Duration of the output video in seconds", + optional=True, + ), + comfy_io.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + step=1, + display_mode=comfy_io.NumberDisplay.number, + control_after_generate=True, + tooltip="Seed for video generation (0 for random)", + optional=True, + ), + comfy_io.Combo.Input( + "aspect_ratio", + options=[model.value for model in AspectRatio], + default=AspectRatio.r_16_9.value, + tooltip="The aspect ratio of the output video", + optional=True, + ), + comfy_io.Combo.Input( + "resolution", + options=[model.value for model in Resolution], + default=Resolution.r_1080p.value, + tooltip="Supported values may vary by model & duration", + optional=True, + ), + comfy_io.Combo.Input( + "movement_amplitude", + options=[model.value for model in MovementAmplitude], + default=MovementAmplitude.auto.value, + tooltip="The movement amplitude of objects in the frame", + optional=True, + ), + ], + outputs=[ + comfy_io.Video.Output(), + ], + hidden=[ + comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + comfy_io.Hidden.unique_id, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, + model: str, + prompt: str, + duration: int, + seed: int, + aspect_ratio: str, + resolution: str, + movement_amplitude: str, + ) -> comfy_io.NodeOutput: + if not prompt: + raise ValueError("The prompt field is required and cannot be empty.") + payload = TaskCreationRequest( + model_name=model, + prompt=prompt, + duration=duration, + seed=seed, + aspect_ratio=aspect_ratio, + resolution=resolution, + movement_amplitude=movement_amplitude, + ) + auth = { + "auth_token": cls.hidden.auth_token_comfy_org, + "comfy_api_key": cls.hidden.api_key_comfy_org, + } + results = await execute_task(VIDU_TEXT_TO_VIDEO, auth, payload, 320, cls.hidden.unique_id) + return comfy_io.NodeOutput(await download_url_to_video_output(get_video_from_response(results).url)) + + +class ViduImageToVideoNode(comfy_io.ComfyNode): + + @classmethod + def define_schema(cls): + return comfy_io.Schema( + node_id="ViduImageToVideoNode", + display_name="Vidu Image To Video Generation", + category="api node/video/Vidu", + description="Generate video from image and optional prompt", + inputs=[ + comfy_io.Combo.Input( + "model", + options=[model.value for model in VideoModelName], + default=VideoModelName.vidu_q1.value, + tooltip="Model name", + ), + comfy_io.Image.Input( + "image", + 
tooltip="An image to be used as the start frame of the generated video", + ), + comfy_io.String.Input( + "prompt", + multiline=True, + default="", + tooltip="A textual description for video generation", + optional=True, + ), + comfy_io.Int.Input( + "duration", + default=5, + min=5, + max=5, + step=1, + display_mode=comfy_io.NumberDisplay.number, + tooltip="Duration of the output video in seconds", + optional=True, + ), + comfy_io.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + step=1, + display_mode=comfy_io.NumberDisplay.number, + control_after_generate=True, + tooltip="Seed for video generation (0 for random)", + optional=True, + ), + comfy_io.Combo.Input( + "resolution", + options=[model.value for model in Resolution], + default=Resolution.r_1080p.value, + tooltip="Supported values may vary by model & duration", + optional=True, + ), + comfy_io.Combo.Input( + "movement_amplitude", + options=[model.value for model in MovementAmplitude], + default=MovementAmplitude.auto.value, + tooltip="The movement amplitude of objects in the frame", + optional=True, + ), + ], + outputs=[ + comfy_io.Video.Output(), + ], + hidden=[ + comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + comfy_io.Hidden.unique_id, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, + model: str, + image: torch.Tensor, + prompt: str, + duration: int, + seed: int, + resolution: str, + movement_amplitude: str, + ) -> comfy_io.NodeOutput: + if get_number_of_images(image) > 1: + raise ValueError("Only one input image is allowed.") + validate_image_aspect_ratio_range(image, (1, 4), (4, 1)) + payload = TaskCreationRequest( + model_name=model, + prompt=prompt, + duration=duration, + seed=seed, + resolution=resolution, + movement_amplitude=movement_amplitude, + ) + auth = { + "auth_token": cls.hidden.auth_token_comfy_org, + "comfy_api_key": cls.hidden.api_key_comfy_org, + } + payload.images = await upload_images_to_comfyapi( + image, + max_images=1, + mime_type="image/png", + auth_kwargs=auth, + ) + results = await execute_task(VIDU_IMAGE_TO_VIDEO, auth, payload, 120, cls.hidden.unique_id) + return comfy_io.NodeOutput(await download_url_to_video_output(get_video_from_response(results).url)) + + +class ViduReferenceVideoNode(comfy_io.ComfyNode): + + @classmethod + def define_schema(cls): + return comfy_io.Schema( + node_id="ViduReferenceVideoNode", + display_name="Vidu Reference To Video Generation", + category="api node/video/Vidu", + description="Generate video from multiple images and prompt", + inputs=[ + comfy_io.Combo.Input( + "model", + options=[model.value for model in VideoModelName], + default=VideoModelName.vidu_q1.value, + tooltip="Model name", + ), + comfy_io.Image.Input( + "images", + tooltip="Images to use as references to generate a video with consistent subjects (max 7 images).", + ), + comfy_io.String.Input( + "prompt", + multiline=True, + tooltip="A textual description for video generation", + ), + comfy_io.Int.Input( + "duration", + default=5, + min=5, + max=5, + step=1, + display_mode=comfy_io.NumberDisplay.number, + tooltip="Duration of the output video in seconds", + optional=True, + ), + comfy_io.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + step=1, + display_mode=comfy_io.NumberDisplay.number, + control_after_generate=True, + tooltip="Seed for video generation (0 for random)", + optional=True, + ), + comfy_io.Combo.Input( + "aspect_ratio", + options=[model.value for model in AspectRatio], + default=AspectRatio.r_16_9.value, + 
tooltip="The aspect ratio of the output video", + optional=True, + ), + comfy_io.Combo.Input( + "resolution", + options=[model.value for model in Resolution], + default=Resolution.r_1080p.value, + tooltip="Supported values may vary by model & duration", + optional=True, + ), + comfy_io.Combo.Input( + "movement_amplitude", + options=[model.value for model in MovementAmplitude], + default=MovementAmplitude.auto.value, + tooltip="The movement amplitude of objects in the frame", + optional=True, + ), + ], + outputs=[ + comfy_io.Video.Output(), + ], + hidden=[ + comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + comfy_io.Hidden.unique_id, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, + model: str, + images: torch.Tensor, + prompt: str, + duration: int, + seed: int, + aspect_ratio: str, + resolution: str, + movement_amplitude: str, + ) -> comfy_io.NodeOutput: + if not prompt: + raise ValueError("The prompt field is required and cannot be empty.") + a = get_number_of_images(images) + if a > 7: + raise ValueError("Too many images, maximum allowed is 7.") + for image in images: + validate_image_aspect_ratio_range(image, (1, 4), (4, 1)) + validate_image_dimensions(image, min_width=128, min_height=128) + payload = TaskCreationRequest( + model_name=model, + prompt=prompt, + duration=duration, + seed=seed, + aspect_ratio=aspect_ratio, + resolution=resolution, + movement_amplitude=movement_amplitude, + ) + auth = { + "auth_token": cls.hidden.auth_token_comfy_org, + "comfy_api_key": cls.hidden.api_key_comfy_org, + } + payload.images = await upload_images_to_comfyapi( + images, + max_images=7, + mime_type="image/png", + auth_kwargs=auth, + ) + results = await execute_task(VIDU_REFERENCE_VIDEO, auth, payload, 120, cls.hidden.unique_id) + return comfy_io.NodeOutput(await download_url_to_video_output(get_video_from_response(results).url)) + + +class ViduStartEndToVideoNode(comfy_io.ComfyNode): + + @classmethod + def define_schema(cls): + return comfy_io.Schema( + node_id="ViduStartEndToVideoNode", + display_name="Vidu Start End To Video Generation", + category="api node/video/Vidu", + description="Generate a video from start and end frames and a prompt", + inputs=[ + comfy_io.Combo.Input( + "model", + options=[model.value for model in VideoModelName], + default=VideoModelName.vidu_q1.value, + tooltip="Model name", + ), + comfy_io.Image.Input( + "first_frame", + tooltip="Start frame", + ), + comfy_io.Image.Input( + "end_frame", + tooltip="End frame", + ), + comfy_io.String.Input( + "prompt", + multiline=True, + tooltip="A textual description for video generation", + optional=True, + ), + comfy_io.Int.Input( + "duration", + default=5, + min=5, + max=5, + step=1, + display_mode=comfy_io.NumberDisplay.number, + tooltip="Duration of the output video in seconds", + optional=True, + ), + comfy_io.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + step=1, + display_mode=comfy_io.NumberDisplay.number, + control_after_generate=True, + tooltip="Seed for video generation (0 for random)", + optional=True, + ), + comfy_io.Combo.Input( + "resolution", + options=[model.value for model in Resolution], + default=Resolution.r_1080p.value, + tooltip="Supported values may vary by model & duration", + optional=True, + ), + comfy_io.Combo.Input( + "movement_amplitude", + options=[model.value for model in MovementAmplitude], + default=MovementAmplitude.auto.value, + tooltip="The movement amplitude of objects in the frame", + optional=True, + ), + ], + outputs=[ + 
comfy_io.Video.Output(), + ], + hidden=[ + comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + comfy_io.Hidden.unique_id, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, + model: str, + first_frame: torch.Tensor, + end_frame: torch.Tensor, + prompt: str, + duration: int, + seed: int, + resolution: str, + movement_amplitude: str, + ) -> comfy_io.NodeOutput: + validate_aspect_ratio_closeness(first_frame, end_frame, min_rel=0.8, max_rel=1.25, strict=False) + payload = TaskCreationRequest( + model_name=model, + prompt=prompt, + duration=duration, + seed=seed, + resolution=resolution, + movement_amplitude=movement_amplitude, + ) + auth = { + "auth_token": cls.hidden.auth_token_comfy_org, + "comfy_api_key": cls.hidden.api_key_comfy_org, + } + payload.images = [ + (await upload_images_to_comfyapi(frame, max_images=1, mime_type="image/png", auth_kwargs=auth))[0] + for frame in (first_frame, end_frame) + ] + results = await execute_task(VIDU_START_END_VIDEO, auth, payload, 96, cls.hidden.unique_id) + return comfy_io.NodeOutput(await download_url_to_video_output(get_video_from_response(results).url)) + + +class ViduExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[comfy_io.ComfyNode]]: + return [ + ViduTextToVideoNode, + ViduImageToVideoNode, + ViduReferenceVideoNode, + ViduStartEndToVideoNode, + ] + +async def comfy_entrypoint() -> ViduExtension: + return ViduExtension() diff --git a/comfy_api_nodes/util/validation_utils.py b/comfy_api_nodes/util/validation_utils.py index 031b9fbd3..606b794bf 100644 --- a/comfy_api_nodes/util/validation_utils.py +++ b/comfy_api_nodes/util/validation_utils.py @@ -53,6 +53,53 @@ def validate_image_aspect_ratio( ) +def validate_image_aspect_ratio_range( + image: torch.Tensor, + min_ratio: tuple[float, float], # e.g. (1, 4) + max_ratio: tuple[float, float], # e.g. 
(4, 1) + *, + strict: bool = True, # True -> (min, max); False -> [min, max] +) -> float: + a1, b1 = min_ratio + a2, b2 = max_ratio + if a1 <= 0 or b1 <= 0 or a2 <= 0 or b2 <= 0: + raise ValueError("Ratios must be positive, like (1, 4) or (4, 1).") + lo, hi = (a1 / b1), (a2 / b2) + if lo > hi: + lo, hi = hi, lo + a1, b1, a2, b2 = a2, b2, a1, b1 # swap only for error text + w, h = get_image_dimensions(image) + if w <= 0 or h <= 0: + raise ValueError(f"Invalid image dimensions: {w}x{h}") + ar = w / h + ok = (lo < ar < hi) if strict else (lo <= ar <= hi) + if not ok: + op = "<" if strict else "≤" + raise ValueError(f"Image aspect ratio {ar:.6g} is outside allowed range: {a1}:{b1} {op} ratio {op} {a2}:{b2}") + return ar + + +def validate_aspect_ratio_closeness( + start_img, + end_img, + min_rel: float, + max_rel: float, + *, + strict: bool = False, # True => exclusive, False => inclusive +) -> None: + w1, h1 = get_image_dimensions(start_img) + w2, h2 = get_image_dimensions(end_img) + if min(w1, h1, w2, h2) <= 0: + raise ValueError("Invalid image dimensions") + ar1 = w1 / h1 + ar2 = w2 / h2 + # Normalize so it is symmetric (no need to check both ar1/ar2 and ar2/ar1) + closeness = max(ar1, ar2) / min(ar1, ar2) + limit = max(max_rel, 1.0 / min_rel) # for 0.8..1.25 this is 1.25 + if (closeness >= limit) if strict else (closeness > limit): + raise ValueError(f"Aspect ratios must be close: start/end={ar1/ar2:.4f}, allowed range {min_rel}–{max_rel}.") + + def validate_video_dimensions( video: VideoInput, min_width: Optional[int] = None, @@ -98,3 +145,9 @@ def validate_video_duration( raise ValueError( f"Video duration must be at most {max_duration}s, got {duration}s" ) + + +def get_number_of_images(images): + if isinstance(images, torch.Tensor): + return images.shape[0] if images.ndim >= 4 else 1 + return len(images) diff --git a/nodes.py b/nodes.py index b3fa9c51a..35dda1b19 100644 --- a/nodes.py +++ b/nodes.py @@ -2351,6 +2351,7 @@ async def init_builtin_api_nodes(): "nodes_moonvalley.py", "nodes_rodin.py", "nodes_gemini.py", + "nodes_vidu.py", ] if not await load_custom_node(os.path.join(api_nodes_dir, "canary.py"), module_parent="comfy_api_nodes"): From bddd69618bf4463209c3681babfcbebd9b9aed85 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 19 Aug 2025 13:49:01 -0700 Subject: [PATCH 081/325] Change the TextEncodeQwenImageEdit node to use logic closer to reference. 
(#9432) --- comfy_extras/nodes_qwen.py | 37 +++++++++++-------------------------- 1 file changed, 11 insertions(+), 26 deletions(-) diff --git a/comfy_extras/nodes_qwen.py b/comfy_extras/nodes_qwen.py index b5088fae2..fff89556f 100644 --- a/comfy_extras/nodes_qwen.py +++ b/comfy_extras/nodes_qwen.py @@ -1,25 +1,6 @@ import node_helpers import comfy.utils - -PREFERRED_QWENIMAGE_RESOLUTIONS = [ - (672, 1568), - (688, 1504), - (720, 1456), - (752, 1392), - (800, 1328), - (832, 1248), - (880, 1184), - (944, 1104), - (1024, 1024), - (1104, 944), - (1184, 880), - (1248, 832), - (1328, 800), - (1392, 752), - (1456, 720), - (1504, 688), - (1568, 672), -] +import math class TextEncodeQwenImageEdit: @@ -42,13 +23,17 @@ class TextEncodeQwenImageEdit: if image is None: images = [] else: - images = [image] + samples = image.movedim(-1, 1) + total = int(1024 * 1024) + + scale_by = math.sqrt(total / (samples.shape[3] * samples.shape[2])) + width = round(samples.shape[3] * scale_by) + height = round(samples.shape[2] * scale_by) + + s = comfy.utils.common_upscale(samples, width, height, "area", "disabled") + image = s.movedim(1, -1) + images = [image[:, :, :, :3]] if vae is not None: - width = image.shape[2] - height = image.shape[1] - aspect_ratio = width / height - _, width, height = min((abs(aspect_ratio - w / h), w, h) for w, h in PREFERRED_QWENIMAGE_RESOLUTIONS) - image = comfy.utils.common_upscale(image.movedim(-1, 1), width, height, "lanczos", "center").movedim(1, -1) ref_latent = vae.encode(image[:, :, :, :3]) tokens = clip.tokenize(prompt, images=images) From dfa791eb4bfcaac3eb9b2b33fa15ae5a25589bb8 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 19 Aug 2025 17:47:42 -0700 Subject: [PATCH 082/325] Rope fix for qwen vl. 
(#9435) --- comfy/text_encoders/llama.py | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/comfy/text_encoders/llama.py b/comfy/text_encoders/llama.py index 9d90d5a61..4c976058f 100644 --- a/comfy/text_encoders/llama.py +++ b/comfy/text_encoders/llama.py @@ -27,6 +27,7 @@ class Llama2Config: rms_norm_add = False mlp_activation = "silu" qkv_bias = False + rope_dims = None @dataclass class Qwen25_3BConfig: @@ -44,6 +45,7 @@ class Qwen25_3BConfig: rms_norm_add = False mlp_activation = "silu" qkv_bias = True + rope_dims = None @dataclass class Qwen25_7BVLI_Config: @@ -61,6 +63,7 @@ class Qwen25_7BVLI_Config: rms_norm_add = False mlp_activation = "silu" qkv_bias = True + rope_dims = [16, 24, 24] @dataclass class Gemma2_2B_Config: @@ -78,6 +81,7 @@ class Gemma2_2B_Config: rms_norm_add = True mlp_activation = "gelu_pytorch_tanh" qkv_bias = False + rope_dims = None class RMSNorm(nn.Module): def __init__(self, dim: int, eps: float = 1e-5, add=False, device=None, dtype=None): @@ -102,7 +106,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -def precompute_freqs_cis(head_dim, position_ids, theta, device=None): +def precompute_freqs_cis(head_dim, position_ids, theta, rope_dims=None, device=None): theta_numerator = torch.arange(0, head_dim, 2, device=device).float() inv_freq = 1.0 / (theta ** (theta_numerator / head_dim)) @@ -112,12 +116,20 @@ def precompute_freqs_cis(head_dim, position_ids, theta, device=None): emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() sin = emb.sin() + if rope_dims is not None and position_ids.shape[0] > 1: + mrope_section = rope_dims * 2 + cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(0) + sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze(0) + else: + cos = cos.unsqueeze(1) + sin = sin.unsqueeze(1) + return (cos, sin) def apply_rope(xq, xk, freqs_cis): - cos = freqs_cis[0].unsqueeze(1) - sin = freqs_cis[1].unsqueeze(1) + cos = freqs_cis[0] + sin = freqs_cis[1] q_embed = (xq * cos) + (rotate_half(xq) * sin) k_embed = (xk * cos) + (rotate_half(xk) * sin) return q_embed, k_embed @@ -292,6 +304,7 @@ class Llama2_(nn.Module): freqs_cis = precompute_freqs_cis(self.config.head_dim, position_ids, self.config.rope_theta, + self.config.rope_dims, device=x.device) mask = None From 7cd2c4bd6ab20f35a6bb1b1f2252c3ea16da4777 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 19 Aug 2025 21:45:27 -0700 Subject: [PATCH 083/325] Qwen rotary embeddings should now match reference code. 
(#9437) --- comfy/ldm/qwen_image/model.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/comfy/ldm/qwen_image/model.py b/comfy/ldm/qwen_image/model.py index a3c726299..bf3940313 100644 --- a/comfy/ldm/qwen_image/model.py +++ b/comfy/ldm/qwen_image/model.py @@ -349,8 +349,8 @@ class QwenImageTransformer2DModel(nn.Module): img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype) img_ids[:, :, 0] = img_ids[:, :, 1] + index - img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1) - img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0) + img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1) - (h_len // 2) + img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0) - (w_len // 2) return hidden_states, repeat(img_ids, "h w c -> b (h w) c", b=bs), orig_shape def forward( @@ -396,7 +396,7 @@ class QwenImageTransformer2DModel(nn.Module): hidden_states = torch.cat([hidden_states, kontext], dim=1) img_ids = torch.cat([img_ids, kontext_ids], dim=1) - txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size), ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size))) + txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2, ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) // 2)) txt_ids = torch.linspace(txt_start, txt_start + context.shape[1], steps=context.shape[1], device=x.device, dtype=x.dtype).reshape(1, -1, 1).repeat(x.shape[0], 1, 3) ids = torch.cat((txt_ids, img_ids), dim=1) image_rotary_emb = self.pe_embedder(ids).squeeze(1).unsqueeze(2).to(x.dtype) From 5a8f502db5889873ffa13132b603b7b6daac605a Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 19 Aug 2025 22:08:11 -0700 Subject: [PATCH 084/325] Disable prompt weights for qwen. 
(#9438) --- comfy/sd1_clip.py | 5 ++++- comfy/text_encoders/qwen_image.py | 4 ++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index 1e8adbe69..f8a7c2a1b 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -534,7 +534,10 @@ class SDTokenizer: min_padding = tokenizer_options.get("{}_min_padding".format(self.embedding_key), self.min_padding) text = escape_important(text) - parsed_weights = token_weights(text, 1.0) + if kwargs.get("disable_weights", False): + parsed_weights = [(text, 1.0)] + else: + parsed_weights = token_weights(text, 1.0) # tokenize words tokens = [] diff --git a/comfy/text_encoders/qwen_image.py b/comfy/text_encoders/qwen_image.py index f07318d6c..6646b1003 100644 --- a/comfy/text_encoders/qwen_image.py +++ b/comfy/text_encoders/qwen_image.py @@ -15,7 +15,7 @@ class QwenImageTokenizer(sd1_clip.SD1Tokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}): super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="qwen25_7b", tokenizer=Qwen25_7BVLITokenizer) self.llama_template = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n" - self.llama_template_images = "<|im_start|>system\nDescribe the key features of the input image \\(color, shape, size, texture, objects, background\\), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n" + self.llama_template_images = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n" def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, images=[], **kwargs): if llama_template is None: @@ -25,7 +25,7 @@ class QwenImageTokenizer(sd1_clip.SD1Tokenizer): llama_text = self.llama_template.format(text) else: llama_text = llama_template.format(text) - tokens = super().tokenize_with_weights(llama_text, return_word_ids=return_word_ids, **kwargs) + tokens = super().tokenize_with_weights(llama_text, return_word_ids=return_word_ids, disable_weights=True, **kwargs) key_name = next(iter(tokens)) embed_count = 0 qwen_tokens = tokens[key_name] From 8d38ea3bbf7e77ed7e7aee401b157dab211c5307 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 19 Aug 2025 23:58:54 -0700 Subject: [PATCH 085/325] Fix bf16 precision issue with qwen image embeddings. 
(#9441) --- comfy/ldm/qwen_image/model.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/comfy/ldm/qwen_image/model.py b/comfy/ldm/qwen_image/model.py index bf3940313..49f66b90a 100644 --- a/comfy/ldm/qwen_image/model.py +++ b/comfy/ldm/qwen_image/model.py @@ -347,7 +347,7 @@ class QwenImageTransformer2DModel(nn.Module): h_offset = ((h_offset + (patch_size // 2)) // patch_size) w_offset = ((w_offset + (patch_size // 2)) // patch_size) - img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype) + img_ids = torch.zeros((h_len, w_len, 3), device=x.device) img_ids[:, :, 0] = img_ids[:, :, 1] + index img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1) - (h_len // 2) img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0) - (w_len // 2) @@ -397,9 +397,10 @@ class QwenImageTransformer2DModel(nn.Module): img_ids = torch.cat([img_ids, kontext_ids], dim=1) txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2, ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) // 2)) - txt_ids = torch.linspace(txt_start, txt_start + context.shape[1], steps=context.shape[1], device=x.device, dtype=x.dtype).reshape(1, -1, 1).repeat(x.shape[0], 1, 3) + txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3) ids = torch.cat((txt_ids, img_ids), dim=1) image_rotary_emb = self.pe_embedder(ids).squeeze(1).unsqueeze(2).to(x.dtype) + del ids, txt_ids, img_ids hidden_states = self.img_in(hidden_states) encoder_hidden_states = self.txt_norm(encoder_hidden_states) From 2f52e8f05f2039dd67e9b9783b8397350a548b95 Mon Sep 17 00:00:00 2001 From: ComfyUI Wiki Date: Wed, 20 Aug 2025 15:15:09 +0800 Subject: [PATCH 086/325] Bump template to 0.1.62 (#9419) * Bump template to 0.1.61 * Bump template to 0.1.62 --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index c7a5c47ab..8d928d826 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.25.9 -comfyui-workflow-templates==0.1.60 +comfyui-workflow-templates==0.1.62 comfyui-embedded-docs==0.2.6 torch torchsde From 7139d6d93fc7b5481a69b687080bd36f7b531c46 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 20 Aug 2025 03:15:30 -0400 Subject: [PATCH 087/325] ComfyUI version 0.3.51 --- comfyui_version.py | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/comfyui_version.py b/comfyui_version.py index 29ec07ca6..65f06cf37 100644 --- a/comfyui_version.py +++ b/comfyui_version.py @@ -1,3 +1,3 @@ # This file is automatically generated by the build process when version is # updated in pyproject.toml. 
-__version__ = "0.3.50" +__version__ = "0.3.51" diff --git a/pyproject.toml b/pyproject.toml index 659b5730a..ecbf04303 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ComfyUI" -version = "0.3.50" +version = "0.3.51" readme = "README.md" license = { file = "LICENSE" } requires-python = ">=3.9" From fe01885acf892de636b1b2743903812099bd42e3 Mon Sep 17 00:00:00 2001 From: Harel Cain Date: Wed, 20 Aug 2025 09:33:10 +0200 Subject: [PATCH 088/325] LTXV: fix key frame noise mask dimensions for when real noise mask exists (#9425) --- comfy_extras/nodes_lt.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy_extras/nodes_lt.py b/comfy_extras/nodes_lt.py index b5058667a..f82337a67 100644 --- a/comfy_extras/nodes_lt.py +++ b/comfy_extras/nodes_lt.py @@ -166,7 +166,7 @@ class LTXVAddGuide: negative = self.add_keyframe_index(negative, frame_idx, guiding_latent, scale_factors) mask = torch.full( - (noise_mask.shape[0], 1, guiding_latent.shape[2], 1, 1), + (noise_mask.shape[0], 1, guiding_latent.shape[2], noise_mask.shape[3], noise_mask.shape[4]), 1.0 - strength, dtype=noise_mask.dtype, device=noise_mask.device, From e73a9dbe30434280c69d852ea78cc4bf88bfd501 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 20 Aug 2025 14:34:13 -0700 Subject: [PATCH 089/325] Add that qwen edit model is supported to readme. (#9463) --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index fa99a8cbe..79a8a8c79 100644 --- a/README.md +++ b/README.md @@ -71,6 +71,7 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith - [Omnigen 2](https://comfyanonymous.github.io/ComfyUI_examples/omnigen/) - [Flux Kontext](https://comfyanonymous.github.io/ComfyUI_examples/flux/#flux-kontext-image-editing-model) - [HiDream E1.1](https://comfyanonymous.github.io/ComfyUI_examples/hidream/#hidream-e11) + - [Qwen Image Edit](https://comfyanonymous.github.io/ComfyUI_examples/qwen_image/#edit-model) - Video Models - [Stable Video Diffusion](https://comfyanonymous.github.io/ComfyUI_examples/video/) - [Mochi](https://comfyanonymous.github.io/ComfyUI_examples/mochi/) From 0963493a9c3b6565f8537288a0fb90991391ec41 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 20 Aug 2025 19:26:37 -0700 Subject: [PATCH 090/325] Support for Qwen Diffsynth Controlnets canny and depth. (#9465) These are not real controlnets but actually a patch on the model so they will be treated as such. Put them in the models/model_patches/ folder. Use the new ModelPatchLoader and QwenImageDiffsynthControlnet nodes. 
--- comfy/ldm/qwen_image/model.py | 7 + comfy/model_management.py | 8 +- comfy/model_patcher.py | 27 ++++ comfy_api/latest/_io.py | 4 + comfy_extras/nodes_model_patch.py | 138 ++++++++++++++++++++ models/model_patches/put_model_patches_here | 0 nodes.py | 1 + 7 files changed, 184 insertions(+), 1 deletion(-) create mode 100644 comfy_extras/nodes_model_patch.py create mode 100644 models/model_patches/put_model_patches_here diff --git a/comfy/ldm/qwen_image/model.py b/comfy/ldm/qwen_image/model.py index 49f66b90a..2503583cb 100644 --- a/comfy/ldm/qwen_image/model.py +++ b/comfy/ldm/qwen_image/model.py @@ -416,6 +416,7 @@ class QwenImageTransformer2DModel(nn.Module): ) patches_replace = transformer_options.get("patches_replace", {}) + patches = transformer_options.get("patches", {}) blocks_replace = patches_replace.get("dit", {}) for i, block in enumerate(self.transformer_blocks): @@ -436,6 +437,12 @@ class QwenImageTransformer2DModel(nn.Module): image_rotary_emb=image_rotary_emb, ) + if "double_block" in patches: + for p in patches["double_block"]: + out = p({"img": hidden_states, "txt": encoder_hidden_states, "x": x, "block_index": i}) + hidden_states = out["img"] + encoder_hidden_states = out["txt"] + hidden_states = self.norm_out(hidden_states, temb) hidden_states = self.proj_out(hidden_states) diff --git a/comfy/model_management.py b/comfy/model_management.py index 2a9f18068..d08aee1fe 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -593,7 +593,13 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu else: minimum_memory_required = max(inference_memory, minimum_memory_required + extra_reserved_memory()) - models = set(models) + models_temp = set() + for m in models: + models_temp.add(m) + for mm in m.model_patches_models(): + models_temp.add(mm) + + models = models_temp models_to_load = [] diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 52e76b5f3..a944cb421 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -430,6 +430,9 @@ class ModelPatcher: def set_model_forward_timestep_embed_patch(self, patch): self.set_model_patch(patch, "forward_timestep_embed_patch") + def set_model_double_block_patch(self, patch): + self.set_model_patch(patch, "double_block") + def add_object_patch(self, name, obj): self.object_patches[name] = obj @@ -486,6 +489,30 @@ class ModelPatcher: if hasattr(wrap_func, "to"): self.model_options["model_function_wrapper"] = wrap_func.to(device) + def model_patches_models(self): + to = self.model_options["transformer_options"] + models = [] + if "patches" in to: + patches = to["patches"] + for name in patches: + patch_list = patches[name] + for i in range(len(patch_list)): + if hasattr(patch_list[i], "models"): + models += patch_list[i].models() + if "patches_replace" in to: + patches = to["patches_replace"] + for name in patches: + patch_list = patches[name] + for k in patch_list: + if hasattr(patch_list[k], "models"): + models += patch_list[k].models() + if "model_function_wrapper" in self.model_options: + wrap_func = self.model_options["model_function_wrapper"] + if hasattr(wrap_func, "models"): + models += wrap_func.models() + + return models + def model_dtype(self): if hasattr(self.model, "get_dtype"): return self.model.get_dtype() diff --git a/comfy_api/latest/_io.py b/comfy_api/latest/_io.py index ec1efb51d..a3a21facc 100644 --- a/comfy_api/latest/_io.py +++ b/comfy_api/latest/_io.py @@ -726,6 +726,10 @@ class SEGS(ComfyTypeIO): class AnyType(ComfyTypeIO): Type = Any 
+@comfytype(io_type="MODEL_PATCH") +class MODEL_PATCH(ComfyTypeIO): + Type = Any + @comfytype(io_type="COMFY_MULTITYPED_V3") class MultiType: Type = Any diff --git a/comfy_extras/nodes_model_patch.py b/comfy_extras/nodes_model_patch.py new file mode 100644 index 000000000..bb239bc45 --- /dev/null +++ b/comfy_extras/nodes_model_patch.py @@ -0,0 +1,138 @@ +import torch +import folder_paths +import comfy.utils +import comfy.ops +import comfy.model_management +import comfy.ldm.common_dit +import comfy.latent_formats + + +class BlockWiseControlBlock(torch.nn.Module): + # [linear, gelu, linear] + def __init__(self, dim: int = 3072, device=None, dtype=None, operations=None): + super().__init__() + self.x_rms = operations.RMSNorm(dim, eps=1e-6) + self.y_rms = operations.RMSNorm(dim, eps=1e-6) + self.input_proj = operations.Linear(dim, dim) + self.act = torch.nn.GELU() + self.output_proj = operations.Linear(dim, dim) + + def forward(self, x, y): + x, y = self.x_rms(x), self.y_rms(y) + x = self.input_proj(x + y) + x = self.act(x) + x = self.output_proj(x) + return x + + +class QwenImageBlockWiseControlNet(torch.nn.Module): + def __init__( + self, + num_layers: int = 60, + in_dim: int = 64, + additional_in_dim: int = 0, + dim: int = 3072, + device=None, dtype=None, operations=None + ): + super().__init__() + self.img_in = operations.Linear(in_dim + additional_in_dim, dim, device=device, dtype=dtype) + self.controlnet_blocks = torch.nn.ModuleList( + [ + BlockWiseControlBlock(dim, device=device, dtype=dtype, operations=operations) + for _ in range(num_layers) + ] + ) + + def process_input_latent_image(self, latent_image): + latent_image = comfy.latent_formats.Wan21().process_in(latent_image) + patch_size = 2 + hidden_states = comfy.ldm.common_dit.pad_to_patch_size(latent_image, (1, patch_size, patch_size)) + orig_shape = hidden_states.shape + hidden_states = hidden_states.view(orig_shape[0], orig_shape[1], orig_shape[-2] // 2, 2, orig_shape[-1] // 2, 2) + hidden_states = hidden_states.permute(0, 2, 4, 1, 3, 5) + hidden_states = hidden_states.reshape(orig_shape[0], (orig_shape[-2] // 2) * (orig_shape[-1] // 2), orig_shape[1] * 4) + return self.img_in(hidden_states) + + def control_block(self, img, controlnet_conditioning, block_id): + return self.controlnet_blocks[block_id](img, controlnet_conditioning) + + +class ModelPatchLoader: + @classmethod + def INPUT_TYPES(s): + return {"required": { "name": (folder_paths.get_filename_list("model_patches"), ), + }} + RETURN_TYPES = ("MODEL_PATCH",) + FUNCTION = "load_model_patch" + EXPERIMENTAL = True + + CATEGORY = "advanced/loaders" + + def load_model_patch(self, name): + model_patch_path = folder_paths.get_full_path_or_raise("model_patches", name) + sd = comfy.utils.load_torch_file(model_patch_path, safe_load=True) + dtype = comfy.utils.weight_dtype(sd) + # TODO: this node will work with more types of model patches + model = QwenImageBlockWiseControlNet(device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast) + model.load_state_dict(sd) + model = comfy.model_patcher.ModelPatcher(model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device()) + return (model,) + + +class DiffSynthCnetPatch: + def __init__(self, model_patch, vae, image, strength): + self.encoded_image = model_patch.model.process_input_latent_image(vae.encode(image)) + self.model_patch = model_patch + self.vae = vae + self.image = image + self.strength = strength + + def __call__(self, kwargs): 
+ x = kwargs.get("x") + img = kwargs.get("img") + block_index = kwargs.get("block_index") + if self.encoded_image is None or self.encoded_image.shape[1:] != img.shape[1:]: + spacial_compression = self.vae.spacial_compression_encode() + image_scaled = comfy.utils.common_upscale(self.image.movedim(-1, 1), x.shape[-1] * spacial_compression, x.shape[-2] * spacial_compression, "area", "center") + loaded_models = comfy.model_management.loaded_models(only_currently_used=True) + self.encoded_image = self.model_patch.model.process_input_latent_image(self.vae.encode(image_scaled.movedim(1, -1))) + comfy.model_management.load_models_gpu(loaded_models) + + img = img + (self.model_patch.model.control_block(img, self.encoded_image.to(img.dtype), block_index) * self.strength) + kwargs['img'] = img + return kwargs + + def to(self, device_or_dtype): + if isinstance(device_or_dtype, torch.device): + self.encoded_image = self.encoded_image.to(device_or_dtype) + return self + + def models(self): + return [self.model_patch] + +class QwenImageDiffsynthControlnet: + @classmethod + def INPUT_TYPES(s): + return {"required": { "model": ("MODEL",), + "model_patch": ("MODEL_PATCH",), + "vae": ("VAE",), + "image": ("IMAGE",), + "strength": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}), + }} + RETURN_TYPES = ("MODEL",) + FUNCTION = "diffsynth_controlnet" + EXPERIMENTAL = True + + CATEGORY = "advanced/loaders/qwen" + + def diffsynth_controlnet(self, model, model_patch, vae, image, strength): + model_patched = model.clone() + image = image[:, :, :, :3] + model_patched.set_model_double_block_patch(DiffSynthCnetPatch(model_patch, vae, image, strength)) + return (model_patched,) + + +NODE_CLASS_MAPPINGS = { + "ModelPatchLoader": ModelPatchLoader, + "QwenImageDiffsynthControlnet": QwenImageDiffsynthControlnet, +} diff --git a/models/model_patches/put_model_patches_here b/models/model_patches/put_model_patches_here new file mode 100644 index 000000000..e69de29bb diff --git a/nodes.py b/nodes.py index 35dda1b19..9681750d3 100644 --- a/nodes.py +++ b/nodes.py @@ -2322,6 +2322,7 @@ async def init_builtin_extra_nodes(): "nodes_tcfg.py", "nodes_context_windows.py", "nodes_qwen.py", + "nodes_model_patch.py" ] import_failed = [] From 0737b7e0d245de20192064da4888debbef3241c2 Mon Sep 17 00:00:00 2001 From: saurabh-pingale Date: Thu, 21 Aug 2025 07:57:57 +0530 Subject: [PATCH 091/325] fix(userdata): catch invalid workflow filenames (#9434) (#9445) --- app/user_manager.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/app/user_manager.py b/app/user_manager.py index 0ec3e46ea..a2d376c0c 100644 --- a/app/user_manager.py +++ b/app/user_manager.py @@ -363,10 +363,17 @@ class UserManager(): if not overwrite and os.path.exists(path): return web.Response(status=409, text="File already exists") - body = await request.read() + try: + body = await request.read() - with open(path, "wb") as f: - f.write(body) + with open(path, "wb") as f: + f.write(body) + except OSError as e: + logging.warning(f"Error saving file '{path}': {e}") + return web.Response( + status=400, + reason="Invalid filename. Please avoid special characters like :\\/*?\"<>|" + ) user_path = self.get_request_user_filepath(request, None) if full_info: From 9fa1036f60b5264302072453be524aa55928bbaf Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 20 Aug 2025 20:09:35 -0700 Subject: [PATCH 092/325] Forgot this. 
(#9470) --- folder_paths.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/folder_paths.py b/folder_paths.py index 9ec952940..b34af39e8 100644 --- a/folder_paths.py +++ b/folder_paths.py @@ -46,6 +46,8 @@ folder_names_and_paths["photomaker"] = ([os.path.join(models_dir, "photomaker")] folder_names_and_paths["classifiers"] = ([os.path.join(models_dir, "classifiers")], {""}) +folder_names_and_paths["model_patches"] = ([os.path.join(models_dir, "model_patches")], supported_pt_extensions) + output_directory = os.path.join(base_path, "output") temp_directory = os.path.join(base_path, "temp") input_directory = os.path.join(base_path, "input") From 1b2de2642d38099acdde7c460d133d93e91074f0 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 20 Aug 2025 21:33:49 -0700 Subject: [PATCH 093/325] Support diffsynth inpaint controlnet (model patch). (#9471) --- comfy_extras/nodes_model_patch.py | 39 ++++++++++++++++++++++++------- 1 file changed, 31 insertions(+), 8 deletions(-) diff --git a/comfy_extras/nodes_model_patch.py b/comfy_extras/nodes_model_patch.py index bb239bc45..3eaada9bc 100644 --- a/comfy_extras/nodes_model_patch.py +++ b/comfy_extras/nodes_model_patch.py @@ -35,6 +35,7 @@ class QwenImageBlockWiseControlNet(torch.nn.Module): device=None, dtype=None, operations=None ): super().__init__() + self.additional_in_dim = additional_in_dim self.img_in = operations.Linear(in_dim + additional_in_dim, dim, device=device, dtype=dtype) self.controlnet_blocks = torch.nn.ModuleList( [ @@ -44,7 +45,7 @@ class QwenImageBlockWiseControlNet(torch.nn.Module): ) def process_input_latent_image(self, latent_image): - latent_image = comfy.latent_formats.Wan21().process_in(latent_image) + latent_image[:, :16] = comfy.latent_formats.Wan21().process_in(latent_image[:, :16]) patch_size = 2 hidden_states = comfy.ldm.common_dit.pad_to_patch_size(latent_image, (1, patch_size, patch_size)) orig_shape = hidden_states.shape @@ -73,19 +74,33 @@ class ModelPatchLoader: sd = comfy.utils.load_torch_file(model_patch_path, safe_load=True) dtype = comfy.utils.weight_dtype(sd) # TODO: this node will work with more types of model patches - model = QwenImageBlockWiseControlNet(device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast) + additional_in_dim = sd["img_in.weight"].shape[1] - 64 + model = QwenImageBlockWiseControlNet(additional_in_dim=additional_in_dim, device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast) model.load_state_dict(sd) model = comfy.model_patcher.ModelPatcher(model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device()) return (model,) class DiffSynthCnetPatch: - def __init__(self, model_patch, vae, image, strength): - self.encoded_image = model_patch.model.process_input_latent_image(vae.encode(image)) + def __init__(self, model_patch, vae, image, strength, mask=None): self.model_patch = model_patch self.vae = vae self.image = image self.strength = strength + self.mask = mask + self.encoded_image = model_patch.model.process_input_latent_image(self.encode_latent_cond(image)) + + def encode_latent_cond(self, image): + latent_image = self.vae.encode(image) + if self.model_patch.model.additional_in_dim > 0: + if self.mask is None: + mask_ = torch.ones_like(latent_image)[:, :self.model_patch.model.additional_in_dim // 4] + else: + mask_ = comfy.utils.common_upscale(self.mask.mean(dim=1, keepdim=True), 
latent_image.shape[-1], latent_image.shape[-2], "bilinear", "none") + + return torch.cat([latent_image, mask_], dim=1) + else: + return latent_image def __call__(self, kwargs): x = kwargs.get("x") @@ -95,7 +110,7 @@ class DiffSynthCnetPatch: spacial_compression = self.vae.spacial_compression_encode() image_scaled = comfy.utils.common_upscale(self.image.movedim(-1, 1), x.shape[-1] * spacial_compression, x.shape[-2] * spacial_compression, "area", "center") loaded_models = comfy.model_management.loaded_models(only_currently_used=True) - self.encoded_image = self.model_patch.model.process_input_latent_image(self.vae.encode(image_scaled.movedim(1, -1))) + self.encoded_image = self.model_patch.model.process_input_latent_image(self.encode_latent_cond(image_scaled.movedim(1, -1))) comfy.model_management.load_models_gpu(loaded_models) img = img + (self.model_patch.model.control_block(img, self.encoded_image.to(img.dtype), block_index) * self.strength) @@ -118,17 +133,25 @@ class QwenImageDiffsynthControlnet: "vae": ("VAE",), "image": ("IMAGE",), "strength": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}), - }} + }, + "optional": {"mask": ("MASK",)}} RETURN_TYPES = ("MODEL",) FUNCTION = "diffsynth_controlnet" EXPERIMENTAL = True CATEGORY = "advanced/loaders/qwen" - def diffsynth_controlnet(self, model, model_patch, vae, image, strength): + def diffsynth_controlnet(self, model, model_patch, vae, image, strength, mask=None): model_patched = model.clone() image = image[:, :, :, :3] - model_patched.set_model_double_block_patch(DiffSynthCnetPatch(model_patch, vae, image, strength)) + if mask is not None: + if mask.ndim == 3: + mask = mask.unsqueeze(1) + if mask.ndim == 4: + mask = mask.unsqueeze(2) + mask = 1.0 - mask + + model_patched.set_model_double_block_patch(DiffSynthCnetPatch(model_patch, vae, image, strength, mask)) return (model_patched,) From bc49106837b627eb657fc86f2e475770ac5ce68a Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Fri, 22 Aug 2025 05:03:57 +0300 Subject: [PATCH 094/325] convert String nodes to V3 schema (#9370) --- comfy_extras/nodes_string.py | 449 ++++++++++++++++++----------------- 1 file changed, 237 insertions(+), 212 deletions(-) diff --git a/comfy_extras/nodes_string.py b/comfy_extras/nodes_string.py index b1a8ceef0..571d89f62 100644 --- a/comfy_extras/nodes_string.py +++ b/comfy_extras/nodes_string.py @@ -1,77 +1,91 @@ import re +from typing_extensions import override -from comfy.comfy_types.node_typing import IO +from comfy_api.latest import ComfyExtension, io -class StringConcatenate(): + +class StringConcatenate(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return { - "required": { - "string_a": (IO.STRING, {"multiline": True}), - "string_b": (IO.STRING, {"multiline": True}), - "delimiter": (IO.STRING, {"multiline": False, "default": ""}) - } - } + def define_schema(cls): + return io.Schema( + node_id="StringConcatenate", + display_name="Concatenate", + category="utils/string", + inputs=[ + io.String.Input("string_a", multiline=True), + io.String.Input("string_b", multiline=True), + io.String.Input("delimiter", multiline=False, default=""), + ], + outputs=[ + io.String.Output(), + ] + ) - RETURN_TYPES = (IO.STRING,) - FUNCTION = "execute" - CATEGORY = "utils/string" - - def execute(self, string_a, string_b, delimiter, **kwargs): - return delimiter.join((string_a, string_b)), - -class StringSubstring(): @classmethod - def INPUT_TYPES(s): - return { - "required": { - "string": (IO.STRING, 
{"multiline": True}), - "start": (IO.INT, {}), - "end": (IO.INT, {}), - } - } + def execute(cls, string_a, string_b, delimiter): + return io.NodeOutput(delimiter.join((string_a, string_b))) - RETURN_TYPES = (IO.STRING,) - FUNCTION = "execute" - CATEGORY = "utils/string" - def execute(self, string, start, end, **kwargs): - return string[start:end], - -class StringLength(): +class StringSubstring(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return { - "required": { - "string": (IO.STRING, {"multiline": True}) - } - } + def define_schema(cls): + return io.Schema( + node_id="StringSubstring", + display_name="Substring", + category="utils/string", + inputs=[ + io.String.Input("string", multiline=True), + io.Int.Input("start"), + io.Int.Input("end"), + ], + outputs=[ + io.String.Output(), + ] + ) - RETURN_TYPES = (IO.INT,) - RETURN_NAMES = ("length",) - FUNCTION = "execute" - CATEGORY = "utils/string" - - def execute(self, string, **kwargs): - length = len(string) - - return length, - -class CaseConverter(): @classmethod - def INPUT_TYPES(s): - return { - "required": { - "string": (IO.STRING, {"multiline": True}), - "mode": (IO.COMBO, {"options": ["UPPERCASE", "lowercase", "Capitalize", "Title Case"]}) - } - } + def execute(cls, string, start, end): + return io.NodeOutput(string[start:end]) - RETURN_TYPES = (IO.STRING,) - FUNCTION = "execute" - CATEGORY = "utils/string" - def execute(self, string, mode, **kwargs): +class StringLength(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="StringLength", + display_name="Length", + category="utils/string", + inputs=[ + io.String.Input("string", multiline=True), + ], + outputs=[ + io.Int.Output(display_name="length"), + ] + ) + + @classmethod + def execute(cls, string): + return io.NodeOutput(len(string)) + + +class CaseConverter(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="CaseConverter", + display_name="Case Converter", + category="utils/string", + inputs=[ + io.String.Input("string", multiline=True), + io.Combo.Input("mode", options=["UPPERCASE", "lowercase", "Capitalize", "Title Case"]), + ], + outputs=[ + io.String.Output(), + ] + ) + + @classmethod + def execute(cls, string, mode): if mode == "UPPERCASE": result = string.upper() elif mode == "lowercase": @@ -83,24 +97,27 @@ class CaseConverter(): else: result = string - return result, + return io.NodeOutput(result) -class StringTrim(): +class StringTrim(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return { - "required": { - "string": (IO.STRING, {"multiline": True}), - "mode": (IO.COMBO, {"options": ["Both", "Left", "Right"]}) - } - } + def define_schema(cls): + return io.Schema( + node_id="StringTrim", + display_name="Trim", + category="utils/string", + inputs=[ + io.String.Input("string", multiline=True), + io.Combo.Input("mode", options=["Both", "Left", "Right"]), + ], + outputs=[ + io.String.Output(), + ] + ) - RETURN_TYPES = (IO.STRING,) - FUNCTION = "execute" - CATEGORY = "utils/string" - - def execute(self, string, mode, **kwargs): + @classmethod + def execute(cls, string, mode): if mode == "Both": result = string.strip() elif mode == "Left": @@ -110,70 +127,78 @@ class StringTrim(): else: result = string - return result, + return io.NodeOutput(result) -class StringReplace(): + +class StringReplace(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return { - "required": { - "string": (IO.STRING, {"multiline": True}), - "find": (IO.STRING, {"multiline": True}), - "replace": (IO.STRING, {"multiline": 
True}) - } - } + def define_schema(cls): + return io.Schema( + node_id="StringReplace", + display_name="Replace", + category="utils/string", + inputs=[ + io.String.Input("string", multiline=True), + io.String.Input("find", multiline=True), + io.String.Input("replace", multiline=True), + ], + outputs=[ + io.String.Output(), + ] + ) - RETURN_TYPES = (IO.STRING,) - FUNCTION = "execute" - CATEGORY = "utils/string" - - def execute(self, string, find, replace, **kwargs): - result = string.replace(find, replace) - return result, - - -class StringContains(): @classmethod - def INPUT_TYPES(s): - return { - "required": { - "string": (IO.STRING, {"multiline": True}), - "substring": (IO.STRING, {"multiline": True}), - "case_sensitive": (IO.BOOLEAN, {"default": True}) - } - } + def execute(cls, string, find, replace): + return io.NodeOutput(string.replace(find, replace)) - RETURN_TYPES = (IO.BOOLEAN,) - RETURN_NAMES = ("contains",) - FUNCTION = "execute" - CATEGORY = "utils/string" - def execute(self, string, substring, case_sensitive, **kwargs): +class StringContains(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="StringContains", + display_name="Contains", + category="utils/string", + inputs=[ + io.String.Input("string", multiline=True), + io.String.Input("substring", multiline=True), + io.Boolean.Input("case_sensitive", default=True), + ], + outputs=[ + io.Boolean.Output(display_name="contains"), + ] + ) + + @classmethod + def execute(cls, string, substring, case_sensitive): if case_sensitive: contains = substring in string else: contains = substring.lower() in string.lower() - return contains, + return io.NodeOutput(contains) -class StringCompare(): +class StringCompare(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return { - "required": { - "string_a": (IO.STRING, {"multiline": True}), - "string_b": (IO.STRING, {"multiline": True}), - "mode": (IO.COMBO, {"options": ["Starts With", "Ends With", "Equal"]}), - "case_sensitive": (IO.BOOLEAN, {"default": True}) - } - } + def define_schema(cls): + return io.Schema( + node_id="StringCompare", + display_name="Compare", + category="utils/string", + inputs=[ + io.String.Input("string_a", multiline=True), + io.String.Input("string_b", multiline=True), + io.Combo.Input("mode", options=["Starts With", "Ends With", "Equal"]), + io.Boolean.Input("case_sensitive", default=True), + ], + outputs=[ + io.Boolean.Output(), + ] + ) - RETURN_TYPES = (IO.BOOLEAN,) - FUNCTION = "execute" - CATEGORY = "utils/string" - - def execute(self, string_a, string_b, mode, case_sensitive, **kwargs): + @classmethod + def execute(cls, string_a, string_b, mode, case_sensitive): if case_sensitive: a = string_a b = string_b @@ -182,31 +207,34 @@ class StringCompare(): b = string_b.lower() if mode == "Equal": - return a == b, + return io.NodeOutput(a == b) elif mode == "Starts With": - return a.startswith(b), + return io.NodeOutput(a.startswith(b)) elif mode == "Ends With": - return a.endswith(b), + return io.NodeOutput(a.endswith(b)) -class RegexMatch(): + +class RegexMatch(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return { - "required": { - "string": (IO.STRING, {"multiline": True}), - "regex_pattern": (IO.STRING, {"multiline": True}), - "case_insensitive": (IO.BOOLEAN, {"default": True}), - "multiline": (IO.BOOLEAN, {"default": False}), - "dotall": (IO.BOOLEAN, {"default": False}) - } - } + def define_schema(cls): + return io.Schema( + node_id="RegexMatch", + display_name="Regex Match", + category="utils/string", + inputs=[ + 
io.String.Input("string", multiline=True), + io.String.Input("regex_pattern", multiline=True), + io.Boolean.Input("case_insensitive", default=True), + io.Boolean.Input("multiline", default=False), + io.Boolean.Input("dotall", default=False), + ], + outputs=[ + io.Boolean.Output(display_name="matches"), + ] + ) - RETURN_TYPES = (IO.BOOLEAN,) - RETURN_NAMES = ("matches",) - FUNCTION = "execute" - CATEGORY = "utils/string" - - def execute(self, string, regex_pattern, case_insensitive, multiline, dotall, **kwargs): + @classmethod + def execute(cls, string, regex_pattern, case_insensitive, multiline, dotall): flags = 0 if case_insensitive: @@ -223,29 +251,32 @@ class RegexMatch(): except re.error: result = False - return result, + return io.NodeOutput(result) -class RegexExtract(): +class RegexExtract(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return { - "required": { - "string": (IO.STRING, {"multiline": True}), - "regex_pattern": (IO.STRING, {"multiline": True}), - "mode": (IO.COMBO, {"options": ["First Match", "All Matches", "First Group", "All Groups"]}), - "case_insensitive": (IO.BOOLEAN, {"default": True}), - "multiline": (IO.BOOLEAN, {"default": False}), - "dotall": (IO.BOOLEAN, {"default": False}), - "group_index": (IO.INT, {"default": 1, "min": 0, "max": 100}) - } - } + def define_schema(cls): + return io.Schema( + node_id="RegexExtract", + display_name="Regex Extract", + category="utils/string", + inputs=[ + io.String.Input("string", multiline=True), + io.String.Input("regex_pattern", multiline=True), + io.Combo.Input("mode", options=["First Match", "All Matches", "First Group", "All Groups"]), + io.Boolean.Input("case_insensitive", default=True), + io.Boolean.Input("multiline", default=False), + io.Boolean.Input("dotall", default=False), + io.Int.Input("group_index", default=1, min=0, max=100), + ], + outputs=[ + io.String.Output(), + ] + ) - RETURN_TYPES = (IO.STRING,) - FUNCTION = "execute" - CATEGORY = "utils/string" - - def execute(self, string, regex_pattern, mode, case_insensitive, multiline, dotall, group_index, **kwargs): + @classmethod + def execute(cls, string, regex_pattern, mode, case_insensitive, multiline, dotall, group_index): join_delimiter = "\n" flags = 0 @@ -294,32 +325,33 @@ class RegexExtract(): except re.error: result = "" - return result, + return io.NodeOutput(result) -class RegexReplace(): - DESCRIPTION = "Find and replace text using regex patterns." +class RegexReplace(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return { - "required": { - "string": (IO.STRING, {"multiline": True}), - "regex_pattern": (IO.STRING, {"multiline": True}), - "replace": (IO.STRING, {"multiline": True}), - }, - "optional": { - "case_insensitive": (IO.BOOLEAN, {"default": True}), - "multiline": (IO.BOOLEAN, {"default": False}), - "dotall": (IO.BOOLEAN, {"default": False, "tooltip": "When enabled, the dot (.) character will match any character including newline characters. When disabled, dots won't match newlines."}), - "count": (IO.INT, {"default": 0, "min": 0, "max": 100, "tooltip": "Maximum number of replacements to make. Set to 0 to replace all occurrences (default). 
Set to 1 to replace only the first match, 2 for the first two matches, etc."}), - } - } + def define_schema(cls): + return io.Schema( + node_id="RegexReplace", + display_name="Regex Replace", + category="utils/string", + description="Find and replace text using regex patterns.", + inputs=[ + io.String.Input("string", multiline=True), + io.String.Input("regex_pattern", multiline=True), + io.String.Input("replace", multiline=True), + io.Boolean.Input("case_insensitive", default=True, optional=True), + io.Boolean.Input("multiline", default=False, optional=True), + io.Boolean.Input("dotall", default=False, optional=True, tooltip="When enabled, the dot (.) character will match any character including newline characters. When disabled, dots won't match newlines."), + io.Int.Input("count", default=0, min=0, max=100, optional=True, tooltip="Maximum number of replacements to make. Set to 0 to replace all occurrences (default). Set to 1 to replace only the first match, 2 for the first two matches, etc."), + ], + outputs=[ + io.String.Output(), + ] + ) - RETURN_TYPES = (IO.STRING,) - FUNCTION = "execute" - CATEGORY = "utils/string" - - def execute(self, string, regex_pattern, replace, case_insensitive=True, multiline=False, dotall=False, count=0, **kwargs): + @classmethod + def execute(cls, string, regex_pattern, replace, case_insensitive=True, multiline=False, dotall=False, count=0): flags = 0 if case_insensitive: @@ -329,32 +361,25 @@ class RegexReplace(): if dotall: flags |= re.DOTALL result = re.sub(regex_pattern, replace, string, count=count, flags=flags) - return result, + return io.NodeOutput(result) -NODE_CLASS_MAPPINGS = { - "StringConcatenate": StringConcatenate, - "StringSubstring": StringSubstring, - "StringLength": StringLength, - "CaseConverter": CaseConverter, - "StringTrim": StringTrim, - "StringReplace": StringReplace, - "StringContains": StringContains, - "StringCompare": StringCompare, - "RegexMatch": RegexMatch, - "RegexExtract": RegexExtract, - "RegexReplace": RegexReplace, -} -NODE_DISPLAY_NAME_MAPPINGS = { - "StringConcatenate": "Concatenate", - "StringSubstring": "Substring", - "StringLength": "Length", - "CaseConverter": "Case Converter", - "StringTrim": "Trim", - "StringReplace": "Replace", - "StringContains": "Contains", - "StringCompare": "Compare", - "RegexMatch": "Regex Match", - "RegexExtract": "Regex Extract", - "RegexReplace": "Regex Replace", -} +class StringExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + StringConcatenate, + StringSubstring, + StringLength, + CaseConverter, + StringTrim, + StringReplace, + StringContains, + StringCompare, + RegexMatch, + RegexExtract, + RegexReplace, + ] + +async def comfy_entrypoint() -> StringExtension: + return StringExtension() From bab08f40d10c8737c3424e35bbff873bcb2333bd Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Fri, 22 Aug 2025 05:05:36 +0300 Subject: [PATCH 095/325] v3 nodes (part a) (#9149) --- comfy_extras/nodes_ace.py | 80 +++++++----- comfy_extras/nodes_advanced_samplers.py | 88 +++++++------ comfy_extras/nodes_apg.py | 72 +++++++---- comfy_extras/nodes_attention_multiply.py | 154 ++++++++++++++--------- 4 files changed, 239 insertions(+), 155 deletions(-) diff --git a/comfy_extras/nodes_ace.py b/comfy_extras/nodes_ace.py index cbfec15a2..1409233c9 100644 --- a/comfy_extras/nodes_ace.py +++ b/comfy_extras/nodes_ace.py @@ -1,49 +1,63 @@ import torch +from typing_extensions import override + import 
comfy.model_management import node_helpers +from comfy_api.latest import ComfyExtension, io -class TextEncodeAceStepAudio: + +class TextEncodeAceStepAudio(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { - "clip": ("CLIP", ), - "tags": ("STRING", {"multiline": True, "dynamicPrompts": True}), - "lyrics": ("STRING", {"multiline": True, "dynamicPrompts": True}), - "lyrics_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), - }} - RETURN_TYPES = ("CONDITIONING",) - FUNCTION = "encode" + def define_schema(cls): + return io.Schema( + node_id="TextEncodeAceStepAudio", + category="conditioning", + inputs=[ + io.Clip.Input("clip"), + io.String.Input("tags", multiline=True, dynamic_prompts=True), + io.String.Input("lyrics", multiline=True, dynamic_prompts=True), + io.Float.Input("lyrics_strength", default=1.0, min=0.0, max=10.0, step=0.01), + ], + outputs=[io.Conditioning.Output()], + ) - CATEGORY = "conditioning" - - def encode(self, clip, tags, lyrics, lyrics_strength): + @classmethod + def execute(cls, clip, tags, lyrics, lyrics_strength) -> io.NodeOutput: tokens = clip.tokenize(tags, lyrics=lyrics) conditioning = clip.encode_from_tokens_scheduled(tokens) conditioning = node_helpers.conditioning_set_values(conditioning, {"lyrics_strength": lyrics_strength}) - return (conditioning, ) + return io.NodeOutput(conditioning) -class EmptyAceStepLatentAudio: - def __init__(self): - self.device = comfy.model_management.intermediate_device() +class EmptyAceStepLatentAudio(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="EmptyAceStepLatentAudio", + category="latent/audio", + inputs=[ + io.Float.Input("seconds", default=120.0, min=1.0, max=1000.0, step=0.1), + io.Int.Input( + "batch_size", default=1, min=1, max=4096, tooltip="The number of latent images in the batch." 
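                # The seconds/batch_size inputs feed execute() below, where the latent length
                # is derived as int(seconds * 44100 / 512 / 8) and a zero latent of shape
                # [batch_size, 8, 16, length] is allocated, so longer clips only grow the
                # last latent dimension.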
+ ), + ], + outputs=[io.Latent.Output()], + ) @classmethod - def INPUT_TYPES(s): - return {"required": {"seconds": ("FLOAT", {"default": 120.0, "min": 1.0, "max": 1000.0, "step": 0.1}), - "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096, "tooltip": "The number of latent images in the batch."}), - }} - RETURN_TYPES = ("LATENT",) - FUNCTION = "generate" - - CATEGORY = "latent/audio" - - def generate(self, seconds, batch_size): + def execute(cls, seconds, batch_size) -> io.NodeOutput: length = int(seconds * 44100 / 512 / 8) - latent = torch.zeros([batch_size, 8, 16, length], device=self.device) - return ({"samples": latent, "type": "audio"}, ) + latent = torch.zeros([batch_size, 8, 16, length], device=comfy.model_management.intermediate_device()) + return io.NodeOutput({"samples": latent, "type": "audio"}) -NODE_CLASS_MAPPINGS = { - "TextEncodeAceStepAudio": TextEncodeAceStepAudio, - "EmptyAceStepLatentAudio": EmptyAceStepLatentAudio, -} +class AceExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + TextEncodeAceStepAudio, + EmptyAceStepLatentAudio, + ] + +async def comfy_entrypoint() -> AceExtension: + return AceExtension() diff --git a/comfy_extras/nodes_advanced_samplers.py b/comfy_extras/nodes_advanced_samplers.py index 5fbb096fb..5532ffe6a 100644 --- a/comfy_extras/nodes_advanced_samplers.py +++ b/comfy_extras/nodes_advanced_samplers.py @@ -1,8 +1,13 @@ +import numpy as np +import torch +from tqdm.auto import trange +from typing_extensions import override + +import comfy.model_patcher import comfy.samplers import comfy.utils -import torch -import numpy as np -from tqdm.auto import trange +from comfy.k_diffusion.sampling import to_d +from comfy_api.latest import ComfyExtension, io @torch.no_grad() @@ -33,30 +38,29 @@ def sample_lcm_upscale(model, x, sigmas, extra_args=None, callback=None, disable return x -class SamplerLCMUpscale: - upscale_methods = ["bislerp", "nearest-exact", "bilinear", "area", "bicubic"] +class SamplerLCMUpscale(io.ComfyNode): + UPSCALE_METHODS = ["bislerp", "nearest-exact", "bilinear", "area", "bicubic"] @classmethod - def INPUT_TYPES(s): - return {"required": - {"scale_ratio": ("FLOAT", {"default": 1.0, "min": 0.1, "max": 20.0, "step": 0.01}), - "scale_steps": ("INT", {"default": -1, "min": -1, "max": 1000, "step": 1}), - "upscale_method": (s.upscale_methods,), - } - } - RETURN_TYPES = ("SAMPLER",) - CATEGORY = "sampling/custom_sampling/samplers" + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="SamplerLCMUpscale", + category="sampling/custom_sampling/samplers", + inputs=[ + io.Float.Input("scale_ratio", default=1.0, min=0.1, max=20.0, step=0.01), + io.Int.Input("scale_steps", default=-1, min=-1, max=1000, step=1), + io.Combo.Input("upscale_method", options=cls.UPSCALE_METHODS), + ], + outputs=[io.Sampler.Output()], + ) - FUNCTION = "get_sampler" - - def get_sampler(self, scale_ratio, scale_steps, upscale_method): + @classmethod + def execute(cls, scale_ratio, scale_steps, upscale_method) -> io.NodeOutput: if scale_steps < 0: scale_steps = None sampler = comfy.samplers.KSAMPLER(sample_lcm_upscale, extra_options={"total_upscale": scale_ratio, "upscale_steps": scale_steps, "upscale_method": upscale_method}) - return (sampler, ) + return io.NodeOutput(sampler) -from comfy.k_diffusion.sampling import to_d -import comfy.model_patcher @torch.no_grad() def sample_euler_pp(model, x, sigmas, extra_args=None, callback=None, disable=None): @@ -82,30 +86,36 @@ def sample_euler_pp(model, 
x, sigmas, extra_args=None, callback=None, disable=No return x -class SamplerEulerCFGpp: +class SamplerEulerCFGpp(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"version": (["regular", "alternative"],),} - } - RETURN_TYPES = ("SAMPLER",) - # CATEGORY = "sampling/custom_sampling/samplers" - CATEGORY = "_for_testing" + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="SamplerEulerCFGpp", + display_name="SamplerEulerCFG++", + category="_for_testing", # "sampling/custom_sampling/samplers" + inputs=[ + io.Combo.Input("version", options=["regular", "alternative"]), + ], + outputs=[io.Sampler.Output()], + is_experimental=True, + ) - FUNCTION = "get_sampler" - - def get_sampler(self, version): + @classmethod + def execute(cls, version) -> io.NodeOutput: if version == "alternative": sampler = comfy.samplers.KSAMPLER(sample_euler_pp) else: sampler = comfy.samplers.ksampler("euler_cfg_pp") - return (sampler, ) + return io.NodeOutput(sampler) -NODE_CLASS_MAPPINGS = { - "SamplerLCMUpscale": SamplerLCMUpscale, - "SamplerEulerCFGpp": SamplerEulerCFGpp, -} -NODE_DISPLAY_NAME_MAPPINGS = { - "SamplerEulerCFGpp": "SamplerEulerCFG++", -} +class AdvancedSamplersExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + SamplerLCMUpscale, + SamplerEulerCFGpp, + ] + +async def comfy_entrypoint() -> AdvancedSamplersExtension: + return AdvancedSamplersExtension() diff --git a/comfy_extras/nodes_apg.py b/comfy_extras/nodes_apg.py index 25b21b1b8..f27ae7da8 100644 --- a/comfy_extras/nodes_apg.py +++ b/comfy_extras/nodes_apg.py @@ -1,4 +1,8 @@ import torch +from typing_extensions import override + +from comfy_api.latest import ComfyExtension, io + def project(v0, v1): v1 = torch.nn.functional.normalize(v1, dim=[-1, -2, -3]) @@ -6,22 +10,45 @@ def project(v0, v1): v0_orthogonal = v0 - v0_parallel return v0_parallel, v0_orthogonal -class APG: +class APG(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return { - "required": { - "model": ("MODEL",), - "eta": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01, "tooltip": "Controls the scale of the parallel guidance vector. Default CFG behavior at a setting of 1."}), - "norm_threshold": ("FLOAT", {"default": 5.0, "min": 0.0, "max": 50.0, "step": 0.1, "tooltip": "Normalize guidance vector to this value, normalization disable at a setting of 0."}), - "momentum": ("FLOAT", {"default": 0.0, "min": -5.0, "max": 1.0, "step": 0.01, "tooltip":"Controls a running average of guidance during diffusion, disabled at a setting of 0."}), - } - } - RETURN_TYPES = ("MODEL",) - FUNCTION = "patch" - CATEGORY = "sampling/custom_sampling" + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="APG", + display_name="Adaptive Projected Guidance", + category="sampling/custom_sampling", + inputs=[ + io.Model.Input("model"), + io.Float.Input( + "eta", + default=1.0, + min=-10.0, + max=10.0, + step=0.01, + tooltip="Controls the scale of the parallel guidance vector. 
Default CFG behavior at a setting of 1.", + ), + io.Float.Input( + "norm_threshold", + default=5.0, + min=0.0, + max=50.0, + step=0.1, + tooltip="Normalize guidance vector to this value, normalization disable at a setting of 0.", + ), + io.Float.Input( + "momentum", + default=0.0, + min=-5.0, + max=1.0, + step=0.01, + tooltip="Controls a running average of guidance during diffusion, disabled at a setting of 0.", + ), + ], + outputs=[io.Model.Output()], + ) - def patch(self, model, eta, norm_threshold, momentum): + @classmethod + def execute(cls, model, eta, norm_threshold, momentum) -> io.NodeOutput: running_avg = 0 prev_sigma = None @@ -65,12 +92,15 @@ class APG: m = model.clone() m.set_model_sampler_pre_cfg_function(pre_cfg_function) - return (m,) + return io.NodeOutput(m) -NODE_CLASS_MAPPINGS = { - "APG": APG, -} -NODE_DISPLAY_NAME_MAPPINGS = { - "APG": "Adaptive Projected Guidance", -} +class ApgExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + APG, + ] + +async def comfy_entrypoint() -> ApgExtension: + return ApgExtension() diff --git a/comfy_extras/nodes_attention_multiply.py b/comfy_extras/nodes_attention_multiply.py index 4747eb395..c0e494c2a 100644 --- a/comfy_extras/nodes_attention_multiply.py +++ b/comfy_extras/nodes_attention_multiply.py @@ -1,3 +1,7 @@ +from typing_extensions import override + +from comfy_api.latest import ComfyExtension, io + def attention_multiply(attn, model, q, k, v, out): m = model.clone() @@ -16,57 +20,71 @@ def attention_multiply(attn, model, q, k, v, out): return m -class UNetSelfAttentionMultiply: +class UNetSelfAttentionMultiply(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "model": ("MODEL",), - "q": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), - "k": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), - "v": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), - "out": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), - }} - RETURN_TYPES = ("MODEL",) - FUNCTION = "patch" + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="UNetSelfAttentionMultiply", + category="_for_testing/attention_experiments", + inputs=[ + io.Model.Input("model"), + io.Float.Input("q", default=1.0, min=0.0, max=10.0, step=0.01), + io.Float.Input("k", default=1.0, min=0.0, max=10.0, step=0.01), + io.Float.Input("v", default=1.0, min=0.0, max=10.0, step=0.01), + io.Float.Input("out", default=1.0, min=0.0, max=10.0, step=0.01), + ], + outputs=[io.Model.Output()], + is_experimental=True, + ) - CATEGORY = "_for_testing/attention_experiments" - - def patch(self, model, q, k, v, out): + @classmethod + def execute(cls, model, q, k, v, out) -> io.NodeOutput: m = attention_multiply("attn1", model, q, k, v, out) - return (m, ) + return io.NodeOutput(m) -class UNetCrossAttentionMultiply: + +class UNetCrossAttentionMultiply(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "model": ("MODEL",), - "q": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), - "k": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), - "v": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), - "out": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), - }} - RETURN_TYPES = ("MODEL",) - FUNCTION = "patch" + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="UNetCrossAttentionMultiply", + category="_for_testing/attention_experiments", + 
inputs=[ + io.Model.Input("model"), + io.Float.Input("q", default=1.0, min=0.0, max=10.0, step=0.01), + io.Float.Input("k", default=1.0, min=0.0, max=10.0, step=0.01), + io.Float.Input("v", default=1.0, min=0.0, max=10.0, step=0.01), + io.Float.Input("out", default=1.0, min=0.0, max=10.0, step=0.01), + ], + outputs=[io.Model.Output()], + is_experimental=True, + ) - CATEGORY = "_for_testing/attention_experiments" - - def patch(self, model, q, k, v, out): + @classmethod + def execute(cls, model, q, k, v, out) -> io.NodeOutput: m = attention_multiply("attn2", model, q, k, v, out) - return (m, ) + return io.NodeOutput(m) -class CLIPAttentionMultiply: + +class CLIPAttentionMultiply(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "clip": ("CLIP",), - "q": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), - "k": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), - "v": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), - "out": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), - }} - RETURN_TYPES = ("CLIP",) - FUNCTION = "patch" + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="CLIPAttentionMultiply", + category="_for_testing/attention_experiments", + inputs=[ + io.Clip.Input("clip"), + io.Float.Input("q", default=1.0, min=0.0, max=10.0, step=0.01), + io.Float.Input("k", default=1.0, min=0.0, max=10.0, step=0.01), + io.Float.Input("v", default=1.0, min=0.0, max=10.0, step=0.01), + io.Float.Input("out", default=1.0, min=0.0, max=10.0, step=0.01), + ], + outputs=[io.Clip.Output()], + is_experimental=True, + ) - CATEGORY = "_for_testing/attention_experiments" - - def patch(self, clip, q, k, v, out): + @classmethod + def execute(cls, clip, q, k, v, out) -> io.NodeOutput: m = clip.clone() sd = m.patcher.model_state_dict() @@ -79,23 +97,28 @@ class CLIPAttentionMultiply: m.add_patches({key: (None,)}, 0.0, v) if key.endswith("self_attn.out_proj.weight") or key.endswith("self_attn.out_proj.bias"): m.add_patches({key: (None,)}, 0.0, out) - return (m, ) + return io.NodeOutput(m) -class UNetTemporalAttentionMultiply: + +class UNetTemporalAttentionMultiply(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "model": ("MODEL",), - "self_structural": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), - "self_temporal": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), - "cross_structural": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), - "cross_temporal": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), - }} - RETURN_TYPES = ("MODEL",) - FUNCTION = "patch" + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="UNetTemporalAttentionMultiply", + category="_for_testing/attention_experiments", + inputs=[ + io.Model.Input("model"), + io.Float.Input("self_structural", default=1.0, min=0.0, max=10.0, step=0.01), + io.Float.Input("self_temporal", default=1.0, min=0.0, max=10.0, step=0.01), + io.Float.Input("cross_structural", default=1.0, min=0.0, max=10.0, step=0.01), + io.Float.Input("cross_temporal", default=1.0, min=0.0, max=10.0, step=0.01), + ], + outputs=[io.Model.Output()], + is_experimental=True, + ) - CATEGORY = "_for_testing/attention_experiments" - - def patch(self, model, self_structural, self_temporal, cross_structural, cross_temporal): + @classmethod + def execute(cls, model, self_structural, self_temporal, cross_structural, cross_temporal) -> io.NodeOutput: m = model.clone() sd 
= model.model_state_dict() @@ -110,11 +133,18 @@ class UNetTemporalAttentionMultiply: m.add_patches({k: (None,)}, 0.0, cross_temporal) else: m.add_patches({k: (None,)}, 0.0, cross_structural) - return (m, ) + return io.NodeOutput(m) -NODE_CLASS_MAPPINGS = { - "UNetSelfAttentionMultiply": UNetSelfAttentionMultiply, - "UNetCrossAttentionMultiply": UNetCrossAttentionMultiply, - "CLIPAttentionMultiply": CLIPAttentionMultiply, - "UNetTemporalAttentionMultiply": UNetTemporalAttentionMultiply, -} + +class AttentionMultiplyExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + UNetSelfAttentionMultiply, + UNetCrossAttentionMultiply, + CLIPAttentionMultiply, + UNetTemporalAttentionMultiply, + ] + +async def comfy_entrypoint() -> AttentionMultiplyExtension: + return AttentionMultiplyExtension() From eb39019daae96128ee848d0b7837ede299518a7c Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Fri, 22 Aug 2025 05:06:13 +0300 Subject: [PATCH 096/325] [V3] convert Google Veo API node to the V3 schema (#9272) * convert Google Veo API node to the V3 schema * use own full io.Schema for Veo3VideoGenerationNode * fixed typo * use auth_kwargs instead of auth_token/comfy_api_key --- comfy_api_nodes/nodes_veo2.py | 332 +++++++++++++++++++--------------- 1 file changed, 190 insertions(+), 142 deletions(-) diff --git a/comfy_api_nodes/nodes_veo2.py b/comfy_api_nodes/nodes_veo2.py index e25dab2f5..251aecd42 100644 --- a/comfy_api_nodes/nodes_veo2.py +++ b/comfy_api_nodes/nodes_veo2.py @@ -1,17 +1,18 @@ -import io import logging import base64 import aiohttp import torch +from io import BytesIO from typing import Optional +from typing_extensions import override -from comfy.comfy_types.node_typing import IO, ComfyNodeABC +from comfy_api.latest import ComfyExtension, io as comfy_io from comfy_api.input_impl.video_types import VideoFromFile from comfy_api_nodes.apis import ( VeoGenVidRequest, VeoGenVidResponse, VeoGenVidPollRequest, - VeoGenVidPollResponse + VeoGenVidPollResponse, ) from comfy_api_nodes.apis.client import ( ApiEndpoint, @@ -22,7 +23,7 @@ from comfy_api_nodes.apis.client import ( from comfy_api_nodes.apinode_utils import ( downscale_image_tensor, - tensor_to_base64_string + tensor_to_base64_string, ) AVERAGE_DURATION_VIDEO_GEN = 32 @@ -50,7 +51,7 @@ def get_video_url_from_response(poll_response: VeoGenVidPollResponse) -> Optiona return None -class VeoVideoGenerationNode(ComfyNodeABC): +class VeoVideoGenerationNode(comfy_io.ComfyNode): """ Generates videos from text prompts using Google's Veo API. 
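# The Veo rewrite below follows the same V3 conversion pattern used throughout this
# patch series: INPUT_TYPES/RETURN_TYPES/FUNCTION class attributes are replaced by a
# single io.Schema returned from define_schema(), execute() becomes a classmethod that
# returns io.NodeOutput, and hidden values (auth tokens, unique_id) are read from
# cls.hidden. A minimal sketch of that shape, with illustrative names that are not
# taken from this patch:
#
#     from comfy_api.latest import ComfyExtension, io
#
#     class ExampleNode(io.ComfyNode):
#         @classmethod
#         def define_schema(cls):
#             return io.Schema(
#                 node_id="ExampleNode",
#                 category="utils/example",
#                 inputs=[io.String.Input("text", multiline=True)],
#                 outputs=[io.String.Output()],
#             )
#
#         @classmethod
#         def execute(cls, text) -> io.NodeOutput:
#             return io.NodeOutput(text.upper())
#
#     class ExampleExtension(ComfyExtension):
#         async def get_node_list(self) -> list[type[io.ComfyNode]]:
#             return [ExampleNode]
#
#     async def comfy_entrypoint() -> ExampleExtension:
#         return ExampleExtension()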
@@ -59,101 +60,93 @@ class VeoVideoGenerationNode(ComfyNodeABC): """ @classmethod - def INPUT_TYPES(s): - return { - "required": { - "prompt": ( - IO.STRING, - { - "multiline": True, - "default": "", - "tooltip": "Text description of the video", - }, + def define_schema(cls): + return comfy_io.Schema( + node_id="VeoVideoGenerationNode", + display_name="Google Veo 2 Video Generation", + category="api node/video/Veo", + description="Generates videos from text prompts using Google's Veo 2 API", + inputs=[ + comfy_io.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Text description of the video", ), - "aspect_ratio": ( - IO.COMBO, - { - "options": ["16:9", "9:16"], - "default": "16:9", - "tooltip": "Aspect ratio of the output video", - }, + comfy_io.Combo.Input( + "aspect_ratio", + options=["16:9", "9:16"], + default="16:9", + tooltip="Aspect ratio of the output video", ), - }, - "optional": { - "negative_prompt": ( - IO.STRING, - { - "multiline": True, - "default": "", - "tooltip": "Negative text prompt to guide what to avoid in the video", - }, + comfy_io.String.Input( + "negative_prompt", + multiline=True, + default="", + tooltip="Negative text prompt to guide what to avoid in the video", + optional=True, ), - "duration_seconds": ( - IO.INT, - { - "default": 5, - "min": 5, - "max": 8, - "step": 1, - "display": "number", - "tooltip": "Duration of the output video in seconds", - }, + comfy_io.Int.Input( + "duration_seconds", + default=5, + min=5, + max=8, + step=1, + display_mode=comfy_io.NumberDisplay.number, + tooltip="Duration of the output video in seconds", + optional=True, ), - "enhance_prompt": ( - IO.BOOLEAN, - { - "default": True, - "tooltip": "Whether to enhance the prompt with AI assistance", - } + comfy_io.Boolean.Input( + "enhance_prompt", + default=True, + tooltip="Whether to enhance the prompt with AI assistance", + optional=True, ), - "person_generation": ( - IO.COMBO, - { - "options": ["ALLOW", "BLOCK"], - "default": "ALLOW", - "tooltip": "Whether to allow generating people in the video", - }, + comfy_io.Combo.Input( + "person_generation", + options=["ALLOW", "BLOCK"], + default="ALLOW", + tooltip="Whether to allow generating people in the video", + optional=True, ), - "seed": ( - IO.INT, - { - "default": 0, - "min": 0, - "max": 0xFFFFFFFF, - "step": 1, - "display": "number", - "control_after_generate": True, - "tooltip": "Seed for video generation (0 for random)", - }, + comfy_io.Int.Input( + "seed", + default=0, + min=0, + max=0xFFFFFFFF, + step=1, + display_mode=comfy_io.NumberDisplay.number, + control_after_generate=True, + tooltip="Seed for video generation (0 for random)", + optional=True, ), - "image": (IO.IMAGE, { - "default": None, - "tooltip": "Optional reference image to guide video generation", - }), - "model": ( - IO.COMBO, - { - "options": ["veo-2.0-generate-001"], - "default": "veo-2.0-generate-001", - "tooltip": "Veo 2 model to use for video generation", - }, + comfy_io.Image.Input( + "image", + tooltip="Optional reference image to guide video generation", + optional=True, ), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } + comfy_io.Combo.Input( + "model", + options=["veo-2.0-generate-001"], + default="veo-2.0-generate-001", + tooltip="Veo 2 model to use for video generation", + optional=True, + ), + ], + outputs=[ + comfy_io.Video.Output(), + ], + hidden=[ + comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + 
comfy_io.Hidden.unique_id, + ], + is_api_node=True, + ) - RETURN_TYPES = (IO.VIDEO,) - FUNCTION = "generate_video" - CATEGORY = "api node/video/Veo" - DESCRIPTION = "Generates videos from text prompts using Google's Veo 2 API" - API_NODE = True - - async def generate_video( - self, + @classmethod + async def execute( + cls, prompt, aspect_ratio="16:9", negative_prompt="", @@ -164,8 +157,6 @@ class VeoVideoGenerationNode(ComfyNodeABC): image=None, model="veo-2.0-generate-001", generate_audio=False, - unique_id: Optional[str] = None, - **kwargs, ): # Prepare the instances for the request instances = [] @@ -202,6 +193,10 @@ class VeoVideoGenerationNode(ComfyNodeABC): if "veo-3.0" in model: parameters["generateAudio"] = generate_audio + auth = { + "auth_token": cls.hidden.auth_token_comfy_org, + "comfy_api_key": cls.hidden.api_key_comfy_org, + } # Initial request to start video generation initial_operation = SynchronousOperation( endpoint=ApiEndpoint( @@ -214,7 +209,7 @@ class VeoVideoGenerationNode(ComfyNodeABC): instances=instances, parameters=parameters ), - auth_kwargs=kwargs, + auth_kwargs=auth, ) initial_response = await initial_operation.execute() @@ -248,10 +243,10 @@ class VeoVideoGenerationNode(ComfyNodeABC): request=VeoGenVidPollRequest( operationName=operation_name ), - auth_kwargs=kwargs, + auth_kwargs=auth, poll_interval=5.0, result_url_extractor=get_video_url_from_response, - node_id=unique_id, + node_id=cls.hidden.unique_id, estimated_duration=AVERAGE_DURATION_VIDEO_GEN, ) @@ -304,10 +299,10 @@ class VeoVideoGenerationNode(ComfyNodeABC): logging.info("Video generation completed successfully") # Convert video data to BytesIO object - video_io = io.BytesIO(video_data) + video_io = BytesIO(video_data) # Return VideoFromFile object - return (VideoFromFile(video_io),) + return comfy_io.NodeOutput(VideoFromFile(video_io)) class Veo3VideoGenerationNode(VeoVideoGenerationNode): @@ -323,51 +318,104 @@ class Veo3VideoGenerationNode(VeoVideoGenerationNode): """ @classmethod - def INPUT_TYPES(s): - parent_input = super().INPUT_TYPES() - - # Update model options for Veo 3 - parent_input["optional"]["model"] = ( - IO.COMBO, - { - "options": ["veo-3.0-generate-001", "veo-3.0-fast-generate-001"], - "default": "veo-3.0-generate-001", - "tooltip": "Veo 3 model to use for video generation", - }, + def define_schema(cls): + return comfy_io.Schema( + node_id="Veo3VideoGenerationNode", + display_name="Google Veo 3 Video Generation", + category="api node/video/Veo", + description="Generates videos from text prompts using Google's Veo 3 API", + inputs=[ + comfy_io.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Text description of the video", + ), + comfy_io.Combo.Input( + "aspect_ratio", + options=["16:9", "9:16"], + default="16:9", + tooltip="Aspect ratio of the output video", + ), + comfy_io.String.Input( + "negative_prompt", + multiline=True, + default="", + tooltip="Negative text prompt to guide what to avoid in the video", + optional=True, + ), + comfy_io.Int.Input( + "duration_seconds", + default=8, + min=8, + max=8, + step=1, + display_mode=comfy_io.NumberDisplay.number, + tooltip="Duration of the output video in seconds (Veo 3 only supports 8 seconds)", + optional=True, + ), + comfy_io.Boolean.Input( + "enhance_prompt", + default=True, + tooltip="Whether to enhance the prompt with AI assistance", + optional=True, + ), + comfy_io.Combo.Input( + "person_generation", + options=["ALLOW", "BLOCK"], + default="ALLOW", + tooltip="Whether to allow generating people in the video", 
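                    # Veo 3 differences from the Veo 2 schema above: duration_seconds is pinned
                    # to 8 seconds, the model list is restricted to the veo-3.0 variants, and a
                    # generate_audio toggle is exposed; the shared execute() only forwards
                    # generateAudio when "veo-3.0" appears in the selected model name.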
+ optional=True, + ), + comfy_io.Int.Input( + "seed", + default=0, + min=0, + max=0xFFFFFFFF, + step=1, + display_mode=comfy_io.NumberDisplay.number, + control_after_generate=True, + tooltip="Seed for video generation (0 for random)", + optional=True, + ), + comfy_io.Image.Input( + "image", + tooltip="Optional reference image to guide video generation", + optional=True, + ), + comfy_io.Combo.Input( + "model", + options=["veo-3.0-generate-001", "veo-3.0-fast-generate-001"], + default="veo-3.0-generate-001", + tooltip="Veo 3 model to use for video generation", + optional=True, + ), + comfy_io.Boolean.Input( + "generate_audio", + default=False, + tooltip="Generate audio for the video. Supported by all Veo 3 models.", + optional=True, + ), + ], + outputs=[ + comfy_io.Video.Output(), + ], + hidden=[ + comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + comfy_io.Hidden.unique_id, + ], + is_api_node=True, ) - # Add generateAudio parameter - parent_input["optional"]["generate_audio"] = ( - IO.BOOLEAN, - { - "default": False, - "tooltip": "Generate audio for the video. Supported by all Veo 3 models.", - } - ) - # Update duration constraints for Veo 3 (only 8 seconds supported) - parent_input["optional"]["duration_seconds"] = ( - IO.INT, - { - "default": 8, - "min": 8, - "max": 8, - "step": 1, - "display": "number", - "tooltip": "Duration of the output video in seconds (Veo 3 only supports 8 seconds)", - }, - ) +class VeoExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[comfy_io.ComfyNode]]: + return [ + VeoVideoGenerationNode, + Veo3VideoGenerationNode, + ] - return parent_input - - -# Register the nodes -NODE_CLASS_MAPPINGS = { - "VeoVideoGenerationNode": VeoVideoGenerationNode, - "Veo3VideoGenerationNode": Veo3VideoGenerationNode, -} - -NODE_DISPLAY_NAME_MAPPINGS = { - "VeoVideoGenerationNode": "Google Veo 2 Video Generation", - "Veo3VideoGenerationNode": "Google Veo 3 Video Generation", -} +async def comfy_entrypoint() -> VeoExtension: + return VeoExtension() From 7ed73d12d13c2c389e0469c46c2db635a7d74278 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Fri, 22 Aug 2025 05:06:51 +0300 Subject: [PATCH 097/325] [V3] convert Ideogram API nodes to the V3 schema (#9278) * convert Ideogram API nodes to the V3 schema * use auth_kwargs instead of auth_token/comfy_api_key --- comfy_api_nodes/nodes_ideogram.py | 536 ++++++++++++++---------------- 1 file changed, 257 insertions(+), 279 deletions(-) diff --git a/comfy_api_nodes/nodes_ideogram.py b/comfy_api_nodes/nodes_ideogram.py index db24e6da4..d28895f3e 100644 --- a/comfy_api_nodes/nodes_ideogram.py +++ b/comfy_api_nodes/nodes_ideogram.py @@ -1,8 +1,8 @@ -from comfy.comfy_types.node_typing import IO, ComfyNodeABC, InputTypeDict -from inspect import cleandoc +from io import BytesIO +from typing_extensions import override +from comfy_api.latest import ComfyExtension, io as comfy_io from PIL import Image import numpy as np -import io import torch from comfy_api_nodes.apis import ( IdeogramGenerateRequest, @@ -246,90 +246,81 @@ def display_image_urls_on_node(image_urls, node_id): PromptServer.instance.send_progress_text(urls_text, node_id) -class IdeogramV1(ComfyNodeABC): - """ - Generates images using the Ideogram V1 model. 
- """ - - def __init__(self): - pass +class IdeogramV1(comfy_io.ComfyNode): @classmethod - def INPUT_TYPES(cls) -> InputTypeDict: - return { - "required": { - "prompt": ( - IO.STRING, - { - "multiline": True, - "default": "", - "tooltip": "Prompt for the image generation", - }, + def define_schema(cls): + return comfy_io.Schema( + node_id="IdeogramV1", + display_name="Ideogram V1", + category="api node/image/Ideogram", + description="Generates images using the Ideogram V1 model.", + inputs=[ + comfy_io.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Prompt for the image generation", ), - "turbo": ( - IO.BOOLEAN, - { - "default": False, - "tooltip": "Whether to use turbo mode (faster generation, potentially lower quality)", - } + comfy_io.Boolean.Input( + "turbo", + default=False, + tooltip="Whether to use turbo mode (faster generation, potentially lower quality)", ), - }, - "optional": { - "aspect_ratio": ( - IO.COMBO, - { - "options": list(V1_V2_RATIO_MAP.keys()), - "default": "1:1", - "tooltip": "The aspect ratio for image generation.", - }, + comfy_io.Combo.Input( + "aspect_ratio", + options=list(V1_V2_RATIO_MAP.keys()), + default="1:1", + tooltip="The aspect ratio for image generation.", + optional=True, ), - "magic_prompt_option": ( - IO.COMBO, - { - "options": ["AUTO", "ON", "OFF"], - "default": "AUTO", - "tooltip": "Determine if MagicPrompt should be used in generation", - }, + comfy_io.Combo.Input( + "magic_prompt_option", + options=["AUTO", "ON", "OFF"], + default="AUTO", + tooltip="Determine if MagicPrompt should be used in generation", + optional=True, ), - "seed": ( - IO.INT, - { - "default": 0, - "min": 0, - "max": 2147483647, - "step": 1, - "control_after_generate": True, - "display": "number", - }, + comfy_io.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + step=1, + control_after_generate=True, + display_mode=comfy_io.NumberDisplay.number, + optional=True, ), - "negative_prompt": ( - IO.STRING, - { - "multiline": True, - "default": "", - "tooltip": "Description of what to exclude from the image", - }, + comfy_io.String.Input( + "negative_prompt", + multiline=True, + default="", + tooltip="Description of what to exclude from the image", + optional=True, ), - "num_images": ( - IO.INT, - {"default": 1, "min": 1, "max": 8, "step": 1, "display": "number"}, + comfy_io.Int.Input( + "num_images", + default=1, + min=1, + max=8, + step=1, + display_mode=comfy_io.NumberDisplay.number, + optional=True, ), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } + ], + outputs=[ + comfy_io.Image.Output(), + ], + hidden=[ + comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + comfy_io.Hidden.unique_id, + ], + ) - RETURN_TYPES = (IO.IMAGE,) - FUNCTION = "api_call" - CATEGORY = "api node/image/Ideogram" - DESCRIPTION = cleandoc(__doc__ or "") - API_NODE = True - - async def api_call( - self, + @classmethod + async def execute( + cls, prompt, turbo=False, aspect_ratio="1:1", @@ -337,13 +328,15 @@ class IdeogramV1(ComfyNodeABC): seed=0, negative_prompt="", num_images=1, - unique_id=None, - **kwargs, ): # Determine the model based on turbo setting aspect_ratio = V1_V2_RATIO_MAP.get(aspect_ratio, None) model = "V_1_TURBO" if turbo else "V_1" + auth = { + "auth_token": cls.hidden.auth_token_comfy_org, + "comfy_api_key": cls.hidden.api_key_comfy_org, + } operation = SynchronousOperation( endpoint=ApiEndpoint( path="/proxy/ideogram/generate", @@ -364,7 +357,7 
@@ class IdeogramV1(ComfyNodeABC): negative_prompt=negative_prompt if negative_prompt else None, ) ), - auth_kwargs=kwargs, + auth_kwargs=auth, ) response = await operation.execute() @@ -377,93 +370,85 @@ class IdeogramV1(ComfyNodeABC): if not image_urls: raise Exception("No image URLs were generated in the response") - display_image_urls_on_node(image_urls, unique_id) - return (await download_and_process_images(image_urls),) + display_image_urls_on_node(image_urls, cls.hidden.unique_id) + return comfy_io.NodeOutput(await download_and_process_images(image_urls)) -class IdeogramV2(ComfyNodeABC): - """ - Generates images using the Ideogram V2 model. - """ - - def __init__(self): - pass +class IdeogramV2(comfy_io.ComfyNode): @classmethod - def INPUT_TYPES(cls) -> InputTypeDict: - return { - "required": { - "prompt": ( - IO.STRING, - { - "multiline": True, - "default": "", - "tooltip": "Prompt for the image generation", - }, + def define_schema(cls): + return comfy_io.Schema( + node_id="IdeogramV2", + display_name="Ideogram V2", + category="api node/image/Ideogram", + description="Generates images using the Ideogram V2 model.", + inputs=[ + comfy_io.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Prompt for the image generation", ), - "turbo": ( - IO.BOOLEAN, - { - "default": False, - "tooltip": "Whether to use turbo mode (faster generation, potentially lower quality)", - } + comfy_io.Boolean.Input( + "turbo", + default=False, + tooltip="Whether to use turbo mode (faster generation, potentially lower quality)", ), - }, - "optional": { - "aspect_ratio": ( - IO.COMBO, - { - "options": list(V1_V2_RATIO_MAP.keys()), - "default": "1:1", - "tooltip": "The aspect ratio for image generation. Ignored if resolution is not set to AUTO.", - }, + comfy_io.Combo.Input( + "aspect_ratio", + options=list(V1_V2_RATIO_MAP.keys()), + default="1:1", + tooltip="The aspect ratio for image generation. Ignored if resolution is not set to AUTO.", + optional=True, ), - "resolution": ( - IO.COMBO, - { - "options": list(V1_V1_RES_MAP.keys()), - "default": "Auto", - "tooltip": "The resolution for image generation. If not set to AUTO, this overrides the aspect_ratio setting.", - }, + comfy_io.Combo.Input( + "resolution", + options=list(V1_V1_RES_MAP.keys()), + default="Auto", + tooltip="The resolution for image generation. 
" + "If not set to AUTO, this overrides the aspect_ratio setting.", + optional=True, ), - "magic_prompt_option": ( - IO.COMBO, - { - "options": ["AUTO", "ON", "OFF"], - "default": "AUTO", - "tooltip": "Determine if MagicPrompt should be used in generation", - }, + comfy_io.Combo.Input( + "magic_prompt_option", + options=["AUTO", "ON", "OFF"], + default="AUTO", + tooltip="Determine if MagicPrompt should be used in generation", + optional=True, ), - "seed": ( - IO.INT, - { - "default": 0, - "min": 0, - "max": 2147483647, - "step": 1, - "control_after_generate": True, - "display": "number", - }, + comfy_io.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + step=1, + control_after_generate=True, + display_mode=comfy_io.NumberDisplay.number, + optional=True, ), - "style_type": ( - IO.COMBO, - { - "options": ["AUTO", "GENERAL", "REALISTIC", "DESIGN", "RENDER_3D", "ANIME"], - "default": "NONE", - "tooltip": "Style type for generation (V2 only)", - }, + comfy_io.Combo.Input( + "style_type", + options=["AUTO", "GENERAL", "REALISTIC", "DESIGN", "RENDER_3D", "ANIME"], + default="NONE", + tooltip="Style type for generation (V2 only)", + optional=True, ), - "negative_prompt": ( - IO.STRING, - { - "multiline": True, - "default": "", - "tooltip": "Description of what to exclude from the image", - }, + comfy_io.String.Input( + "negative_prompt", + multiline=True, + default="", + tooltip="Description of what to exclude from the image", + optional=True, ), - "num_images": ( - IO.INT, - {"default": 1, "min": 1, "max": 8, "step": 1, "display": "number"}, + comfy_io.Int.Input( + "num_images", + default=1, + min=1, + max=8, + step=1, + display_mode=comfy_io.NumberDisplay.number, + optional=True, ), #"color_palette": ( # IO.STRING, @@ -473,22 +458,20 @@ class IdeogramV2(ComfyNodeABC): # "tooltip": "Color palette preset name or hex colors with weights", # }, #), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } + ], + outputs=[ + comfy_io.Image.Output(), + ], + hidden=[ + comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + comfy_io.Hidden.unique_id, + ], + ) - RETURN_TYPES = (IO.IMAGE,) - FUNCTION = "api_call" - CATEGORY = "api node/image/Ideogram" - DESCRIPTION = cleandoc(__doc__ or "") - API_NODE = True - - async def api_call( - self, + @classmethod + async def execute( + cls, prompt, turbo=False, aspect_ratio="1:1", @@ -499,8 +482,6 @@ class IdeogramV2(ComfyNodeABC): negative_prompt="", num_images=1, color_palette="", - unique_id=None, - **kwargs, ): aspect_ratio = V1_V2_RATIO_MAP.get(aspect_ratio, None) resolution = V1_V1_RES_MAP.get(resolution, None) @@ -517,6 +498,10 @@ class IdeogramV2(ComfyNodeABC): else: final_aspect_ratio = aspect_ratio if aspect_ratio != "ASPECT_1_1" else None + auth = { + "auth_token": cls.hidden.auth_token_comfy_org, + "comfy_api_key": cls.hidden.api_key_comfy_org, + } operation = SynchronousOperation( endpoint=ApiEndpoint( path="/proxy/ideogram/generate", @@ -540,7 +525,7 @@ class IdeogramV2(ComfyNodeABC): color_palette=color_palette if color_palette else None, ) ), - auth_kwargs=kwargs, + auth_kwargs=auth, ) response = await operation.execute() @@ -553,108 +538,99 @@ class IdeogramV2(ComfyNodeABC): if not image_urls: raise Exception("No image URLs were generated in the response") - display_image_urls_on_node(image_urls, unique_id) - return (await download_and_process_images(image_urls),) + display_image_urls_on_node(image_urls, cls.hidden.unique_id) + return 
comfy_io.NodeOutput(await download_and_process_images(image_urls)) -class IdeogramV3(ComfyNodeABC): - """ - Generates images using the Ideogram V3 model. Supports both regular image generation from text prompts and image editing with mask. - """ - def __init__(self): - pass +class IdeogramV3(comfy_io.ComfyNode): @classmethod - def INPUT_TYPES(cls) -> InputTypeDict: - return { - "required": { - "prompt": ( - IO.STRING, - { - "multiline": True, - "default": "", - "tooltip": "Prompt for the image generation or editing", - }, + def define_schema(cls): + return comfy_io.Schema( + node_id="IdeogramV3", + display_name="Ideogram V3", + category="api node/image/Ideogram", + description="Generates images using the Ideogram V3 model. " + "Supports both regular image generation from text prompts and image editing with mask.", + inputs=[ + comfy_io.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Prompt for the image generation or editing", ), - }, - "optional": { - "image": ( - IO.IMAGE, - { - "default": None, - "tooltip": "Optional reference image for image editing.", - }, + comfy_io.Image.Input( + "image", + tooltip="Optional reference image for image editing.", + optional=True, ), - "mask": ( - IO.MASK, - { - "default": None, - "tooltip": "Optional mask for inpainting (white areas will be replaced)", - }, + comfy_io.Mask.Input( + "mask", + tooltip="Optional mask for inpainting (white areas will be replaced)", + optional=True, ), - "aspect_ratio": ( - IO.COMBO, - { - "options": list(V3_RATIO_MAP.keys()), - "default": "1:1", - "tooltip": "The aspect ratio for image generation. Ignored if resolution is not set to Auto.", - }, + comfy_io.Combo.Input( + "aspect_ratio", + options=list(V3_RATIO_MAP.keys()), + default="1:1", + tooltip="The aspect ratio for image generation. Ignored if resolution is not set to Auto.", + optional=True, ), - "resolution": ( - IO.COMBO, - { - "options": V3_RESOLUTIONS, - "default": "Auto", - "tooltip": "The resolution for image generation. If not set to Auto, this overrides the aspect_ratio setting.", - }, + comfy_io.Combo.Input( + "resolution", + options=V3_RESOLUTIONS, + default="Auto", + tooltip="The resolution for image generation. 
" + "If not set to Auto, this overrides the aspect_ratio setting.", + optional=True, ), - "magic_prompt_option": ( - IO.COMBO, - { - "options": ["AUTO", "ON", "OFF"], - "default": "AUTO", - "tooltip": "Determine if MagicPrompt should be used in generation", - }, + comfy_io.Combo.Input( + "magic_prompt_option", + options=["AUTO", "ON", "OFF"], + default="AUTO", + tooltip="Determine if MagicPrompt should be used in generation", + optional=True, ), - "seed": ( - IO.INT, - { - "default": 0, - "min": 0, - "max": 2147483647, - "step": 1, - "control_after_generate": True, - "display": "number", - }, + comfy_io.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + step=1, + control_after_generate=True, + display_mode=comfy_io.NumberDisplay.number, + optional=True, ), - "num_images": ( - IO.INT, - {"default": 1, "min": 1, "max": 8, "step": 1, "display": "number"}, + comfy_io.Int.Input( + "num_images", + default=1, + min=1, + max=8, + step=1, + display_mode=comfy_io.NumberDisplay.number, + optional=True, ), - "rendering_speed": ( - IO.COMBO, - { - "options": ["BALANCED", "TURBO", "QUALITY"], - "default": "BALANCED", - "tooltip": "Controls the trade-off between generation speed and quality", - }, + comfy_io.Combo.Input( + "rendering_speed", + options=["BALANCED", "TURBO", "QUALITY"], + default="BALANCED", + tooltip="Controls the trade-off between generation speed and quality", + optional=True, ), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } + ], + outputs=[ + comfy_io.Image.Output(), + ], + hidden=[ + comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + comfy_io.Hidden.unique_id, + ], + ) - RETURN_TYPES = (IO.IMAGE,) - FUNCTION = "api_call" - CATEGORY = "api node/image/Ideogram" - DESCRIPTION = cleandoc(__doc__ or "") - API_NODE = True - - async def api_call( - self, + @classmethod + async def execute( + cls, prompt, image=None, mask=None, @@ -664,9 +640,11 @@ class IdeogramV3(ComfyNodeABC): seed=0, num_images=1, rendering_speed="BALANCED", - unique_id=None, - **kwargs, ): + auth = { + "auth_token": cls.hidden.auth_token_comfy_org, + "comfy_api_key": cls.hidden.api_key_comfy_org, + } # Check if both image and mask are provided for editing mode if image is not None and mask is not None: # Edit mode @@ -686,7 +664,7 @@ class IdeogramV3(ComfyNodeABC): # Process image img_np = (input_tensor.numpy() * 255).astype(np.uint8) img = Image.fromarray(img_np) - img_byte_arr = io.BytesIO() + img_byte_arr = BytesIO() img.save(img_byte_arr, format="PNG") img_byte_arr.seek(0) img_binary = img_byte_arr @@ -695,7 +673,7 @@ class IdeogramV3(ComfyNodeABC): # Process mask - white areas will be replaced mask_np = (mask.squeeze().cpu().numpy() * 255).astype(np.uint8) mask_img = Image.fromarray(mask_np) - mask_byte_arr = io.BytesIO() + mask_byte_arr = BytesIO() mask_img.save(mask_byte_arr, format="PNG") mask_byte_arr.seek(0) mask_binary = mask_byte_arr @@ -729,7 +707,7 @@ class IdeogramV3(ComfyNodeABC): "mask": mask_binary, }, content_type="multipart/form-data", - auth_kwargs=kwargs, + auth_kwargs=auth, ) elif image is not None or mask is not None: @@ -770,7 +748,7 @@ class IdeogramV3(ComfyNodeABC): response_model=IdeogramGenerateResponse, ), request=gen_request, - auth_kwargs=kwargs, + auth_kwargs=auth, ) # Execute the operation and process response @@ -784,18 +762,18 @@ class IdeogramV3(ComfyNodeABC): if not image_urls: raise Exception("No image URLs were generated in the response") - 
display_image_urls_on_node(image_urls, unique_id) - return (await download_and_process_images(image_urls),) + display_image_urls_on_node(image_urls, cls.hidden.unique_id) + return comfy_io.NodeOutput(await download_and_process_images(image_urls)) -NODE_CLASS_MAPPINGS = { - "IdeogramV1": IdeogramV1, - "IdeogramV2": IdeogramV2, - "IdeogramV3": IdeogramV3, -} +class IdeogramExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[comfy_io.ComfyNode]]: + return [ + IdeogramV1, + IdeogramV2, + IdeogramV3, + ] -NODE_DISPLAY_NAME_MAPPINGS = { - "IdeogramV1": "Ideogram V1", - "IdeogramV2": "Ideogram V2", - "IdeogramV3": "Ideogram V3", -} +async def comfy_entrypoint() -> IdeogramExtension: + return IdeogramExtension() From f7bd5e58dd03e799e02f6851b84b51e14ad0da7b Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Thu, 21 Aug 2025 20:18:04 -0700 Subject: [PATCH 098/325] Make it easier to implement future qwen controlnets. (#9485) --- comfy/controlnet.py | 4 ++-- comfy/ldm/qwen_image/model.py | 16 +++++++++++++--- comfy/model_detection.py | 2 ++ 3 files changed, 17 insertions(+), 5 deletions(-) diff --git a/comfy/controlnet.py b/comfy/controlnet.py index 988acdb57..6cb69dcdf 100644 --- a/comfy/controlnet.py +++ b/comfy/controlnet.py @@ -236,11 +236,11 @@ class ControlNet(ControlBase): self.cond_hint = None compression_ratio = self.compression_ratio if self.vae is not None: - compression_ratio *= self.vae.downscale_ratio + compression_ratio *= self.vae.spacial_compression_encode() else: if self.latent_format is not None: raise ValueError("This Controlnet needs a VAE but none was provided, please use a ControlNetApply node with a VAE input and connect it.") - self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * compression_ratio, x_noisy.shape[2] * compression_ratio, self.upscale_algorithm, "center") + self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[-1] * compression_ratio, x_noisy.shape[-2] * compression_ratio, self.upscale_algorithm, "center") self.cond_hint = self.preprocess_image(self.cond_hint) if self.vae is not None: loaded_models = comfy.model_management.loaded_models(only_currently_used=True) diff --git a/comfy/ldm/qwen_image/model.py b/comfy/ldm/qwen_image/model.py index 2503583cb..d0e39833a 100644 --- a/comfy/ldm/qwen_image/model.py +++ b/comfy/ldm/qwen_image/model.py @@ -293,6 +293,7 @@ class QwenImageTransformer2DModel(nn.Module): guidance_embeds: bool = False, axes_dims_rope: Tuple[int, int, int] = (16, 56, 56), image_model=None, + final_layer=True, dtype=None, device=None, operations=None, @@ -300,6 +301,7 @@ class QwenImageTransformer2DModel(nn.Module): super().__init__() self.dtype = dtype self.patch_size = patch_size + self.in_channels = in_channels self.out_channels = out_channels or in_channels self.inner_dim = num_attention_heads * attention_head_dim @@ -329,9 +331,9 @@ class QwenImageTransformer2DModel(nn.Module): for _ in range(num_layers) ]) - self.norm_out = LastLayer(self.inner_dim, self.inner_dim, dtype=dtype, device=device, operations=operations) - self.proj_out = operations.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True, dtype=dtype, device=device) - self.gradient_checkpointing = False + if final_layer: + self.norm_out = LastLayer(self.inner_dim, self.inner_dim, dtype=dtype, device=device, operations=operations) + self.proj_out = operations.Linear(self.inner_dim, patch_size * patch_size * 
self.out_channels, bias=True, dtype=dtype, device=device) def process_img(self, x, index=0, h_offset=0, w_offset=0): bs, c, t, h, w = x.shape @@ -362,6 +364,7 @@ class QwenImageTransformer2DModel(nn.Module): guidance: torch.Tensor = None, ref_latents=None, transformer_options={}, + control=None, **kwargs ): timestep = timesteps @@ -443,6 +446,13 @@ class QwenImageTransformer2DModel(nn.Module): hidden_states = out["img"] encoder_hidden_states = out["txt"] + if control is not None: # Controlnet + control_i = control.get("input") + if i < len(control_i): + add = control_i[i] + if add is not None: + hidden_states += add + hidden_states = self.norm_out(hidden_states, temb) hidden_states = self.proj_out(hidden_states) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 2bec0541e..0caff53e0 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -492,6 +492,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): if '{}txt_norm.weight'.format(key_prefix) in state_dict_keys: # Qwen Image dit_config = {} dit_config["image_model"] = "qwen_image" + dit_config["in_channels"] = state_dict['{}img_in.weight'.format(key_prefix)].shape[1] + dit_config["num_layers"] = count_blocks(state_dict_keys, '{}transformer_blocks.'.format(key_prefix) + '{}.') return dit_config if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys: From ff57793659702d502506047445f0972b10b6b9fe Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Thu, 21 Aug 2025 21:53:11 -0700 Subject: [PATCH 099/325] Support InstantX Qwen controlnet. (#9488) --- comfy/controlnet.py | 13 +++++ comfy/ldm/qwen_image/controlnet.py | 77 ++++++++++++++++++++++++++++++ 2 files changed, 90 insertions(+) create mode 100644 comfy/ldm/qwen_image/controlnet.py diff --git a/comfy/controlnet.py b/comfy/controlnet.py index 6cb69dcdf..e3dfedf55 100644 --- a/comfy/controlnet.py +++ b/comfy/controlnet.py @@ -36,6 +36,7 @@ import comfy.ldm.cascade.controlnet import comfy.cldm.mmdit import comfy.ldm.hydit.controlnet import comfy.ldm.flux.controlnet +import comfy.ldm.qwen_image.controlnet import comfy.cldm.dit_embedder from typing import TYPE_CHECKING if TYPE_CHECKING: @@ -582,6 +583,15 @@ def load_controlnet_flux_instantx(sd, model_options={}): control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, concat_mask=concat_mask, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds) return control +def load_controlnet_qwen_instantx(sd, model_options={}): + model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(sd, model_options=model_options) + control_model = comfy.ldm.qwen_image.controlnet.QwenImageControlNetModel(operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config) + control_model = controlnet_load_state_dict(control_model, sd) + latent_format = comfy.latent_formats.Wan21() + extra_conds = [] + control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds) + return control + def convert_mistoline(sd): return comfy.utils.state_dict_prefix_replace(sd, {"single_controlnet_blocks.": "controlnet_single_blocks."}) @@ -655,8 +665,11 @@ def load_controlnet_state_dict(state_dict, model=None, model_options={}): return load_controlnet_sd35(controlnet_data, model_options=model_options) #Stability sd3.5 format else: 
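            # The loader selects an implementation by sniffing characteristic state-dict
            # keys; the new elif below routes Qwen InstantX checkpoints, detected via
            # "transformer_blocks.0.img_mlp.net.0.proj.weight", to
            # load_controlnet_qwen_instantx ahead of the existing Flux InstantX branch.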
return load_controlnet_mmdit(controlnet_data, model_options=model_options) #SD3 diffusers controlnet + elif "transformer_blocks.0.img_mlp.net.0.proj.weight" in controlnet_data: + return load_controlnet_qwen_instantx(controlnet_data, model_options=model_options) elif "controlnet_x_embedder.weight" in controlnet_data: return load_controlnet_flux_instantx(controlnet_data, model_options=model_options) + elif "controlnet_blocks.0.linear.weight" in controlnet_data: #mistoline flux return load_controlnet_flux_xlabs_mistoline(convert_mistoline(controlnet_data), mistoline=True, model_options=model_options) diff --git a/comfy/ldm/qwen_image/controlnet.py b/comfy/ldm/qwen_image/controlnet.py new file mode 100644 index 000000000..92ac3cf0a --- /dev/null +++ b/comfy/ldm/qwen_image/controlnet.py @@ -0,0 +1,77 @@ +import torch +import math + +from .model import QwenImageTransformer2DModel + + +class QwenImageControlNetModel(QwenImageTransformer2DModel): + def __init__( + self, + extra_condition_channels=0, + dtype=None, + device=None, + operations=None, + **kwargs + ): + super().__init__(final_layer=False, dtype=dtype, device=device, operations=operations, **kwargs) + self.main_model_double = 60 + + # controlnet_blocks + self.controlnet_blocks = torch.nn.ModuleList([]) + for _ in range(len(self.transformer_blocks)): + self.controlnet_blocks.append(operations.Linear(self.inner_dim, self.inner_dim, device=device, dtype=dtype)) + self.controlnet_x_embedder = operations.Linear(self.in_channels + extra_condition_channels, self.inner_dim, device=device, dtype=dtype) + + def forward( + self, + x, + timesteps, + context, + attention_mask=None, + guidance: torch.Tensor = None, + ref_latents=None, + hint=None, + transformer_options={}, + **kwargs + ): + timestep = timesteps + encoder_hidden_states = context + encoder_hidden_states_mask = attention_mask + + hidden_states, img_ids, orig_shape = self.process_img(x) + hint, _, _ = self.process_img(hint) + + txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2, ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) // 2)) + txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3) + ids = torch.cat((txt_ids, img_ids), dim=1) + image_rotary_emb = self.pe_embedder(ids).squeeze(1).unsqueeze(2).to(x.dtype) + del ids, txt_ids, img_ids + + hidden_states = self.img_in(hidden_states) + self.controlnet_x_embedder(hint) + encoder_hidden_states = self.txt_norm(encoder_hidden_states) + encoder_hidden_states = self.txt_in(encoder_hidden_states) + + if guidance is not None: + guidance = guidance * 1000 + + temb = ( + self.time_text_embed(timestep, hidden_states) + if guidance is None + else self.time_text_embed(timestep, guidance, hidden_states) + ) + + repeat = math.ceil(self.main_model_double / len(self.controlnet_blocks)) + + controlnet_block_samples = () + for i, block in enumerate(self.transformer_blocks): + encoder_hidden_states, hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + encoder_hidden_states_mask=encoder_hidden_states_mask, + temb=temb, + image_rotary_emb=image_rotary_emb, + ) + + controlnet_block_samples = controlnet_block_samples + (self.controlnet_blocks[i](hidden_states),) * repeat + + return {"input": controlnet_block_samples[:self.main_model_double]} From 497d41fb500668635aa4a782549e9a0caa24375e Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Fri, 22 Aug 
2025 20:50:35 +0300 Subject: [PATCH 100/325] feat(api-nodes): change "OpenAI Chat" display name to "OpenAI ChatGPT" (#9443) --- comfy_api_nodes/nodes_openai.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/comfy_api_nodes/nodes_openai.py b/comfy_api_nodes/nodes_openai.py index 674c9ede0..e3b81de75 100644 --- a/comfy_api_nodes/nodes_openai.py +++ b/comfy_api_nodes/nodes_openai.py @@ -998,7 +998,7 @@ NODE_DISPLAY_NAME_MAPPINGS = { "OpenAIDalle2": "OpenAI DALL·E 2", "OpenAIDalle3": "OpenAI DALL·E 3", "OpenAIGPTImage1": "OpenAI GPT Image 1", - "OpenAIChatNode": "OpenAI Chat", - "OpenAIInputFiles": "OpenAI Chat Input Files", - "OpenAIChatConfig": "OpenAI Chat Advanced Options", + "OpenAIChatNode": "OpenAI ChatGPT", + "OpenAIInputFiles": "OpenAI ChatGPT Input Files", + "OpenAIChatConfig": "OpenAI ChatGPT Advanced Options", } From 050c67323c33f6543309d4f09df706ec8c9a1389 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Fri, 22 Aug 2025 20:51:14 +0300 Subject: [PATCH 101/325] feat(api-nodes): add copy button to Gemini Chat node (#9440) --- comfy_api_nodes/nodes_gemini.py | 25 ++++++++++++++++++++++++- 1 file changed, 24 insertions(+), 1 deletion(-) diff --git a/comfy_api_nodes/nodes_gemini.py b/comfy_api_nodes/nodes_gemini.py index ba4167a50..78c402a7a 100644 --- a/comfy_api_nodes/nodes_gemini.py +++ b/comfy_api_nodes/nodes_gemini.py @@ -5,7 +5,10 @@ See: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/infer from __future__ import annotations +import json +import time import os +import uuid from enum import Enum from typing import Optional, Literal @@ -350,7 +353,27 @@ class GeminiNode(ComfyNodeABC): # Get result output output_text = self.get_text_from_response(response) if unique_id and output_text: - PromptServer.instance.send_progress_text(output_text, node_id=unique_id) + # Not a true chat history like the OpenAI Chat node. It is emulated so the frontend can show a copy button. 
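+            # The block below serializes a single prompt/response pair to JSON and sends it to
+            # the frontend over the "display_component" event, where the ChatHistoryWidget
+            # component renders it (that widget is what provides the copy button).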
+ render_spec = { + "node_id": unique_id, + "component": "ChatHistoryWidget", + "props": { + "history": json.dumps( + [ + { + "prompt": prompt, + "response": output_text, + "response_id": str(uuid.uuid4()), + "timestamp": time.time(), + } + ] + ), + }, + } + PromptServer.instance.send_sync( + "display_component", + render_spec, + ) return (output_text or "Empty response from Gemini model...",) From ca4e96a8ae6c9ee8d40fe35100ed9b2247e71e40 Mon Sep 17 00:00:00 2001 From: ComfyUI Wiki Date: Sat, 23 Aug 2025 05:40:18 +0800 Subject: [PATCH 102/325] Update template to 0.1.65 (#9501) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 8d928d826..6b53fabc1 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.25.9 -comfyui-workflow-templates==0.1.62 +comfyui-workflow-templates==0.1.65 comfyui-embedded-docs==0.2.6 torch torchsde From fe31ad02768c66c61b3dc12f5d4bdfe8990ce25c Mon Sep 17 00:00:00 2001 From: contentis Date: Sat, 23 Aug 2025 01:39:15 +0200 Subject: [PATCH 103/325] Add elementwise fusions (#9495) * Add elementwise fusions * Add addcmul pattern to Qwen --- comfy/ldm/modules/diffusionmodules/mmdit.py | 12 +++++++----- comfy/ldm/qwen_image/model.py | 12 ++++++------ comfy/ldm/wan/model.py | 14 +++++++------- 3 files changed, 20 insertions(+), 18 deletions(-) diff --git a/comfy/ldm/modules/diffusionmodules/mmdit.py b/comfy/ldm/modules/diffusionmodules/mmdit.py index eaf3e73a4..4d6beba2d 100644 --- a/comfy/ldm/modules/diffusionmodules/mmdit.py +++ b/comfy/ldm/modules/diffusionmodules/mmdit.py @@ -109,7 +109,7 @@ class PatchEmbed(nn.Module): def modulate(x, shift, scale): if shift is None: shift = torch.zeros_like(scale) - return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) + return torch.addcmul(shift.unsqueeze(1), x, 1+ scale.unsqueeze(1)) ################################################################################# @@ -564,10 +564,7 @@ class DismantledBlock(nn.Module): assert not self.pre_only attn1 = self.attn.post_attention(attn) attn2 = self.attn2.post_attention(attn2) - out1 = gate_msa.unsqueeze(1) * attn1 - out2 = gate_msa2.unsqueeze(1) * attn2 - x = x + out1 - x = x + out2 + x = gate_cat(x, gate_msa, gate_msa2, attn1, attn2) x = x + gate_mlp.unsqueeze(1) * self.mlp( modulate(self.norm2(x), shift_mlp, scale_mlp) ) @@ -594,6 +591,11 @@ class DismantledBlock(nn.Module): ) return self.post_attention(attn, *intermediates) +def gate_cat(x, gate_msa, gate_msa2, attn1, attn2): + out1 = gate_msa.unsqueeze(1) * attn1 + out2 = gate_msa2.unsqueeze(1) * attn2 + x = torch.stack([x, out1, out2], dim=0).sum(dim=0) + return x def block_mixing(*args, use_checkpoint=True, **kwargs): if use_checkpoint: diff --git a/comfy/ldm/qwen_image/model.py b/comfy/ldm/qwen_image/model.py index d0e39833a..af00ff119 100644 --- a/comfy/ldm/qwen_image/model.py +++ b/comfy/ldm/qwen_image/model.py @@ -214,9 +214,9 @@ class QwenImageTransformerBlock(nn.Module): operations=operations, ) - def _modulate(self, x, mod_params): - shift, scale, gate = mod_params.chunk(3, dim=-1) - return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1), gate.unsqueeze(1) + def _modulate(self, x: torch.Tensor, mod_params: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + shift, scale, gate = torch.chunk(mod_params, 3, dim=-1) + return torch.addcmul(shift.unsqueeze(1), x, 1 + scale.unsqueeze(1)), gate.unsqueeze(1) def forward( self, @@ -248,11 +248,11 @@ class QwenImageTransformerBlock(nn.Module): img_normed2 = 
self.img_norm2(hidden_states) img_modulated2, img_gate2 = self._modulate(img_normed2, img_mod2) - hidden_states = hidden_states + img_gate2 * self.img_mlp(img_modulated2) + hidden_states = torch.addcmul(hidden_states, img_gate2, self.img_mlp(img_modulated2)) txt_normed2 = self.txt_norm2(encoder_hidden_states) txt_modulated2, txt_gate2 = self._modulate(txt_normed2, txt_mod2) - encoder_hidden_states = encoder_hidden_states + txt_gate2 * self.txt_mlp(txt_modulated2) + encoder_hidden_states = torch.addcmul(encoder_hidden_states, txt_gate2, self.txt_mlp(txt_modulated2)) return encoder_hidden_states, hidden_states @@ -275,7 +275,7 @@ class LastLayer(nn.Module): def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torch.Tensor: emb = self.linear(self.silu(conditioning_embedding)) scale, shift = torch.chunk(emb, 2, dim=1) - x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :] + x = torch.addcmul(shift[:, None, :], self.norm(x), (1 + scale)[:, None, :]) return x diff --git a/comfy/ldm/wan/model.py b/comfy/ldm/wan/model.py index 9d3741be3..0726b8e1b 100644 --- a/comfy/ldm/wan/model.py +++ b/comfy/ldm/wan/model.py @@ -148,8 +148,8 @@ WAN_CROSSATTENTION_CLASSES = { def repeat_e(e, x): repeats = 1 - if e.shape[1] > 1: - repeats = x.shape[1] // e.shape[1] + if e.size(1) > 1: + repeats = x.size(1) // e.size(1) if repeats == 1: return e return torch.repeat_interleave(e, repeats, dim=1) @@ -219,15 +219,15 @@ class WanAttentionBlock(nn.Module): # self-attention y = self.self_attn( - self.norm1(x) * (1 + repeat_e(e[1], x)) + repeat_e(e[0], x), + torch.addcmul(repeat_e(e[0], x), self.norm1(x), 1 + repeat_e(e[1], x)), freqs) - x = x + y * repeat_e(e[2], x) + x = torch.addcmul(x, y, repeat_e(e[2], x)) # cross-attention & ffn x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len) - y = self.ffn(self.norm2(x) * (1 + repeat_e(e[4], x)) + repeat_e(e[3], x)) - x = x + y * repeat_e(e[5], x) + y = self.ffn(torch.addcmul(repeat_e(e[3], x), self.norm2(x), 1 + repeat_e(e[4], x))) + x = torch.addcmul(x, y, repeat_e(e[5], x)) return x @@ -342,7 +342,7 @@ class Head(nn.Module): else: e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device).unsqueeze(0) + e.unsqueeze(2)).unbind(2) - x = (self.head(self.norm(x) * (1 + repeat_e(e[1], x)) + repeat_e(e[0], x))) + x = (self.head(torch.addcmul(repeat_e(e[0], x), self.norm(x), 1 + repeat_e(e[1], x)))) return x From fc247150fec502b1834390516b556a87003f1d79 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Fri, 22 Aug 2025 19:41:08 -0700 Subject: [PATCH 104/325] Implement EasyCache and Invent LazyCache (#9496) * Attempting a universal implementation of EasyCache, starting with flux as test; I screwed up the math a bit, but when I set it just right it works. 
* Fixed math to make threshold work as expected, refactored code to use EasyCacheHolder instead of a dict wrapped by object * Use sigmas from transformer_options instead of timesteps to be compatible with a greater amount of models, make end_percent work * Make log statement when not skipping useful, preparing for per-cond caching * Added DIFFUSION_MODEL wrapper around forward function for wan model * Add subsampling for heuristic inputs * Add subsampling to output_prev (output_prev_subsampled now) * Properly consider conds in EasyCache logic * Created SuperEasyCache to test what happens if caching and reuse is moved outside the scope of conds, added PREDICT_NOISE wrapper to facilitate this test * Change max reuse_threshold to 3.0 * Mark EasyCache/SuperEasyCache as experimental (beta) * Make Lumina2 compatible with EasyCache * Add EasyCache support for Qwen Image * Fix missing comma, curse you Cursor * Add EasyCache support to AceStep * Add EasyCache support to Chroma * Added EasyCache support to Cosmos Predict t2i * Make EasyCache not crash with Cosmos Predict ImagToVideo latents, but does not work well at all * Add EasyCache support to hidream * Added EasyCache support to hunyuan video * Added EasyCache support to hunyuan3d * Added EasyCache support to LTXV (not very good, but does not crash) * Implemented EasyCache for aura_flow * Renamed SuperEasyCache to LazyCache, hardcoded subsample_factor to 8 on nodes * Eatra logging when verbose is true for EasyCache --- comfy/ldm/ace/model.py | 24 +- comfy/ldm/aura/mmdit.py | 8 + comfy/ldm/chroma/model.py | 8 + comfy/ldm/cosmos/model.py | 38 +++ comfy/ldm/cosmos/predict2.py | 17 +- comfy/ldm/flux/model.py | 8 + comfy/ldm/hidream/model.py | 19 +- comfy/ldm/hunyuan3d/model.py | 8 + comfy/ldm/hunyuan_video/model.py | 8 + comfy/ldm/lightricks/model.py | 8 + comfy/ldm/lumina/model.py | 10 +- comfy/ldm/qwen_image/model.py | 10 +- comfy/ldm/wan/model.py | 8 + comfy/patcher_extension.py | 1 + comfy/samplers.py | 9 +- comfy_extras/nodes_easycache.py | 459 +++++++++++++++++++++++++++++++ nodes.py | 3 +- 17 files changed, 639 insertions(+), 7 deletions(-) create mode 100644 comfy_extras/nodes_easycache.py diff --git a/comfy/ldm/ace/model.py b/comfy/ldm/ace/model.py index 12c524701..41d85eeb5 100644 --- a/comfy/ldm/ace/model.py +++ b/comfy/ldm/ace/model.py @@ -19,6 +19,7 @@ import torch from torch import nn import comfy.model_management +import comfy.patcher_extension from comfy.ldm.lightricks.model import TimestepEmbedding, Timesteps from .attention import LinearTransformerBlock, t2i_modulate @@ -343,7 +344,28 @@ class ACEStepTransformer2DModel(nn.Module): output = self.final_layer(hidden_states, embedded_timestep, output_length) return output - def forward( + def forward(self, + x, + timestep, + attention_mask=None, + context: Optional[torch.Tensor] = None, + text_attention_mask: Optional[torch.LongTensor] = None, + speaker_embeds: Optional[torch.FloatTensor] = None, + lyric_token_idx: Optional[torch.LongTensor] = None, + lyric_mask: Optional[torch.LongTensor] = None, + block_controlnet_hidden_states: Optional[Union[List[torch.Tensor], torch.Tensor]] = None, + controlnet_scale: Union[float, torch.Tensor] = 1.0, + lyrics_strength=1.0, + **kwargs + ): + return comfy.patcher_extension.WrapperExecutor.new_class_executor( + self._forward, + self, + comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, kwargs.get("transformer_options", {})) + ).execute(x, timestep, attention_mask, context, text_attention_mask, 
speaker_embeds, lyric_token_idx, lyric_mask, block_controlnet_hidden_states, + controlnet_scale, lyrics_strength, **kwargs) + + def _forward( self, x, timestep, diff --git a/comfy/ldm/aura/mmdit.py b/comfy/ldm/aura/mmdit.py index 1258ae11f..d7f32b5e8 100644 --- a/comfy/ldm/aura/mmdit.py +++ b/comfy/ldm/aura/mmdit.py @@ -9,6 +9,7 @@ import torch.nn.functional as F from comfy.ldm.modules.attention import optimized_attention import comfy.ops +import comfy.patcher_extension import comfy.ldm.common_dit def modulate(x, shift, scale): @@ -436,6 +437,13 @@ class MMDiT(nn.Module): return x + pos_encoding.reshape(1, -1, self.positional_encoding.shape[-1]) def forward(self, x, timestep, context, transformer_options={}, **kwargs): + return comfy.patcher_extension.WrapperExecutor.new_class_executor( + self._forward, + self, + comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options) + ).execute(x, timestep, context, transformer_options, **kwargs) + + def _forward(self, x, timestep, context, transformer_options={}, **kwargs): patches_replace = transformer_options.get("patches_replace", {}) # patchify x, add PE b, c, h, w = x.shape diff --git a/comfy/ldm/chroma/model.py b/comfy/ldm/chroma/model.py index 06021d4f2..5cff44dc8 100644 --- a/comfy/ldm/chroma/model.py +++ b/comfy/ldm/chroma/model.py @@ -5,6 +5,7 @@ from dataclasses import dataclass import torch from torch import Tensor, nn from einops import rearrange, repeat +import comfy.patcher_extension import comfy.ldm.common_dit from comfy.ldm.flux.layers import ( @@ -253,6 +254,13 @@ class Chroma(nn.Module): return img def forward(self, x, timestep, context, guidance, control=None, transformer_options={}, **kwargs): + return comfy.patcher_extension.WrapperExecutor.new_class_executor( + self._forward, + self, + comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options) + ).execute(x, timestep, context, guidance, control, transformer_options, **kwargs) + + def _forward(self, x, timestep, context, guidance, control=None, transformer_options={}, **kwargs): bs, c, h, w = x.shape x = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size)) diff --git a/comfy/ldm/cosmos/model.py b/comfy/ldm/cosmos/model.py index 4836e0b69..53698b758 100644 --- a/comfy/ldm/cosmos/model.py +++ b/comfy/ldm/cosmos/model.py @@ -27,6 +27,8 @@ from torchvision import transforms from enum import Enum import logging +import comfy.patcher_extension + from .blocks import ( FinalLayer, GeneralDITTransformerBlock, @@ -435,6 +437,42 @@ class GeneralDIT(nn.Module): latent_condition_sigma: Optional[torch.Tensor] = None, condition_video_augment_sigma: Optional[torch.Tensor] = None, **kwargs, + ): + return comfy.patcher_extension.WrapperExecutor.new_class_executor( + self._forward, + self, + comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, kwargs.get("transformer_options", {})) + ).execute(x, + timesteps, + context, + attention_mask, + fps, + image_size, + padding_mask, + scalar_feature, + data_type, + latent_condition, + latent_condition_sigma, + condition_video_augment_sigma, + **kwargs) + + def _forward( + self, + x: torch.Tensor, + timesteps: torch.Tensor, + context: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + # crossattn_emb: torch.Tensor, + # crossattn_mask: Optional[torch.Tensor] = None, + fps: Optional[torch.Tensor] = None, + image_size: Optional[torch.Tensor] = None, + padding_mask: 
Optional[torch.Tensor] = None, + scalar_feature: Optional[torch.Tensor] = None, + data_type: Optional[DataType] = DataType.VIDEO, + latent_condition: Optional[torch.Tensor] = None, + latent_condition_sigma: Optional[torch.Tensor] = None, + condition_video_augment_sigma: Optional[torch.Tensor] = None, + **kwargs, ): """ Args: diff --git a/comfy/ldm/cosmos/predict2.py b/comfy/ldm/cosmos/predict2.py index 316117f77..fcc83ba76 100644 --- a/comfy/ldm/cosmos/predict2.py +++ b/comfy/ldm/cosmos/predict2.py @@ -11,6 +11,7 @@ import math from .position_embedding import VideoRopePosition3DEmb, LearnablePosEmbAxis from torchvision import transforms +import comfy.patcher_extension from comfy.ldm.modules.attention import optimized_attention def apply_rotary_pos_emb( @@ -805,7 +806,21 @@ class MiniTrainDIT(nn.Module): ) return x_B_C_Tt_Hp_Wp - def forward( + def forward(self, + x: torch.Tensor, + timesteps: torch.Tensor, + context: torch.Tensor, + fps: Optional[torch.Tensor] = None, + padding_mask: Optional[torch.Tensor] = None, + **kwargs, + ): + return comfy.patcher_extension.WrapperExecutor.new_class_executor( + self._forward, + self, + comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, kwargs.get("transformer_options", {})) + ).execute(x, timesteps, context, fps, padding_mask, **kwargs) + + def _forward( self, x: torch.Tensor, timesteps: torch.Tensor, diff --git a/comfy/ldm/flux/model.py b/comfy/ldm/flux/model.py index c4de82795..0a77fa097 100644 --- a/comfy/ldm/flux/model.py +++ b/comfy/ldm/flux/model.py @@ -6,6 +6,7 @@ import torch from torch import Tensor, nn from einops import rearrange, repeat import comfy.ldm.common_dit +import comfy.patcher_extension from .layers import ( DoubleStreamBlock, @@ -214,6 +215,13 @@ class Flux(nn.Module): return img, repeat(img_ids, "h w c -> b (h w) c", b=bs) def forward(self, x, timestep, context, y=None, guidance=None, ref_latents=None, control=None, transformer_options={}, **kwargs): + return comfy.patcher_extension.WrapperExecutor.new_class_executor( + self._forward, + self, + comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options) + ).execute(x, timestep, context, y, guidance, ref_latents, control, transformer_options, **kwargs) + + def _forward(self, x, timestep, context, y=None, guidance=None, ref_latents=None, control=None, transformer_options={}, **kwargs): bs, c, h_orig, w_orig = x.shape patch_size = self.patch_size diff --git a/comfy/ldm/hidream/model.py b/comfy/ldm/hidream/model.py index 0305747bf..ae49cf945 100644 --- a/comfy/ldm/hidream/model.py +++ b/comfy/ldm/hidream/model.py @@ -13,6 +13,7 @@ from comfy.ldm.flux.layers import LastLayer from comfy.ldm.modules.attention import optimized_attention import comfy.model_management +import comfy.patcher_extension import comfy.ldm.common_dit @@ -692,7 +693,23 @@ class HiDreamImageTransformer2DModel(nn.Module): raise NotImplementedError return x, x_masks, img_sizes - def forward( + def forward(self, + x: torch.Tensor, + t: torch.Tensor, + y: Optional[torch.Tensor] = None, + context: Optional[torch.Tensor] = None, + encoder_hidden_states_llama3=None, + image_cond=None, + control = None, + transformer_options = {}, + ): + return comfy.patcher_extension.WrapperExecutor.new_class_executor( + self._forward, + self, + comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options) + ).execute(x, t, y, context, encoder_hidden_states_llama3, image_cond, control, 
transformer_options) + + def _forward( self, x: torch.Tensor, t: torch.Tensor, diff --git a/comfy/ldm/hunyuan3d/model.py b/comfy/ldm/hunyuan3d/model.py index 4e18358f0..0fa5e78c1 100644 --- a/comfy/ldm/hunyuan3d/model.py +++ b/comfy/ldm/hunyuan3d/model.py @@ -7,6 +7,7 @@ from comfy.ldm.flux.layers import ( SingleStreamBlock, timestep_embedding, ) +import comfy.patcher_extension class Hunyuan3Dv2(nn.Module): @@ -67,6 +68,13 @@ class Hunyuan3Dv2(nn.Module): self.final_layer = LastLayer(hidden_size, 1, in_channels, dtype=dtype, device=device, operations=operations) def forward(self, x, timestep, context, guidance=None, transformer_options={}, **kwargs): + return comfy.patcher_extension.WrapperExecutor.new_class_executor( + self._forward, + self, + comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options) + ).execute(x, timestep, context, guidance, transformer_options, **kwargs) + + def _forward(self, x, timestep, context, guidance=None, transformer_options={}, **kwargs): x = x.movedim(-1, -2) timestep = 1.0 - timestep txt = context diff --git a/comfy/ldm/hunyuan_video/model.py b/comfy/ldm/hunyuan_video/model.py index fbd8d4196..da1011596 100644 --- a/comfy/ldm/hunyuan_video/model.py +++ b/comfy/ldm/hunyuan_video/model.py @@ -1,6 +1,7 @@ #Based on Flux code because of weird hunyuan video code license. import torch +import comfy.patcher_extension import comfy.ldm.flux.layers import comfy.ldm.modules.diffusionmodules.mmdit from comfy.ldm.modules.attention import optimized_attention @@ -348,6 +349,13 @@ class HunyuanVideo(nn.Module): return repeat(img_ids, "t h w c -> b (t h w) c", b=bs) def forward(self, x, timestep, context, y, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, control=None, transformer_options={}, **kwargs): + return comfy.patcher_extension.WrapperExecutor.new_class_executor( + self._forward, + self, + comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options) + ).execute(x, timestep, context, y, guidance, attention_mask, guiding_frame_index, ref_latent, control, transformer_options, **kwargs) + + def _forward(self, x, timestep, context, y, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, control=None, transformer_options={}, **kwargs): bs, c, t, h, w = x.shape img_ids = self.img_ids(x) txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype) diff --git a/comfy/ldm/lightricks/model.py b/comfy/ldm/lightricks/model.py index ad9a7daea..aa2ea62b1 100644 --- a/comfy/ldm/lightricks/model.py +++ b/comfy/ldm/lightricks/model.py @@ -1,5 +1,6 @@ import torch from torch import nn +import comfy.patcher_extension import comfy.ldm.modules.attention import comfy.ldm.common_dit from einops import rearrange @@ -420,6 +421,13 @@ class LTXVModel(torch.nn.Module): self.patchifier = SymmetricPatchifier(1) def forward(self, x, timestep, context, attention_mask, frame_rate=25, transformer_options={}, keyframe_idxs=None, **kwargs): + return comfy.patcher_extension.WrapperExecutor.new_class_executor( + self._forward, + self, + comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options) + ).execute(x, timestep, context, attention_mask, frame_rate, transformer_options, keyframe_idxs, **kwargs) + + def _forward(self, x, timestep, context, attention_mask, frame_rate=25, transformer_options={}, keyframe_idxs=None, **kwargs): patches_replace = 
transformer_options.get("patches_replace", {}) orig_shape = list(x.shape) diff --git a/comfy/ldm/lumina/model.py b/comfy/ldm/lumina/model.py index f8dc4d7db..e08ed817d 100644 --- a/comfy/ldm/lumina/model.py +++ b/comfy/ldm/lumina/model.py @@ -11,6 +11,7 @@ import comfy.ldm.common_dit from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder from comfy.ldm.modules.attention import optimized_attention_masked from comfy.ldm.flux.layers import EmbedND +import comfy.patcher_extension def modulate(x, scale): @@ -590,8 +591,15 @@ class NextDiT(nn.Module): return padded_full_embed, mask, img_sizes, l_effective_cap_len, freqs_cis - # def forward(self, x, t, cap_feats, cap_mask): def forward(self, x, timesteps, context, num_tokens, attention_mask=None, **kwargs): + return comfy.patcher_extension.WrapperExecutor.new_class_executor( + self._forward, + self, + comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, kwargs.get("transformer_options", {})) + ).execute(x, timesteps, context, num_tokens, attention_mask, **kwargs) + + # def forward(self, x, t, cap_feats, cap_mask): + def _forward(self, x, timesteps, context, num_tokens, attention_mask=None, **kwargs): t = 1.0 - timesteps cap_feats = context cap_mask = attention_mask diff --git a/comfy/ldm/qwen_image/model.py b/comfy/ldm/qwen_image/model.py index af00ff119..57a458210 100644 --- a/comfy/ldm/qwen_image/model.py +++ b/comfy/ldm/qwen_image/model.py @@ -9,6 +9,7 @@ from comfy.ldm.lightricks.model import TimestepEmbedding, Timesteps from comfy.ldm.modules.attention import optimized_attention_masked from comfy.ldm.flux.layers import EmbedND import comfy.ldm.common_dit +import comfy.patcher_extension class GELU(nn.Module): def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: bool = True, dtype=None, device=None, operations=None): @@ -355,7 +356,14 @@ class QwenImageTransformer2DModel(nn.Module): img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0) - (w_len // 2) return hidden_states, repeat(img_ids, "h w c -> b (h w) c", b=bs), orig_shape - def forward( + def forward(self, x, timestep, context, attention_mask=None, guidance=None, ref_latents=None, transformer_options={}, **kwargs): + return comfy.patcher_extension.WrapperExecutor.new_class_executor( + self._forward, + self, + comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options) + ).execute(x, timestep, context, attention_mask, guidance, ref_latents, transformer_options, **kwargs) + + def _forward( self, x, timesteps, diff --git a/comfy/ldm/wan/model.py b/comfy/ldm/wan/model.py index 0726b8e1b..1885d9730 100644 --- a/comfy/ldm/wan/model.py +++ b/comfy/ldm/wan/model.py @@ -11,6 +11,7 @@ from comfy.ldm.flux.layers import EmbedND from comfy.ldm.flux.math import apply_rope import comfy.ldm.common_dit import comfy.model_management +import comfy.patcher_extension def sinusoidal_embedding_1d(dim, position): @@ -573,6 +574,13 @@ class WanModel(torch.nn.Module): return x def forward(self, x, timestep, context, clip_fea=None, time_dim_concat=None, transformer_options={}, **kwargs): + return comfy.patcher_extension.WrapperExecutor.new_class_executor( + self._forward, + self, + comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options) + ).execute(x, timestep, context, clip_fea, time_dim_concat, transformer_options, **kwargs) + + def 
_forward(self, x, timestep, context, clip_fea=None, time_dim_concat=None, transformer_options={}, **kwargs): bs, c, t, h, w = x.shape x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size) diff --git a/comfy/patcher_extension.py b/comfy/patcher_extension.py index 965958f4c..46cc7b2a8 100644 --- a/comfy/patcher_extension.py +++ b/comfy/patcher_extension.py @@ -50,6 +50,7 @@ class WrappersMP: OUTER_SAMPLE = "outer_sample" PREPARE_SAMPLING = "prepare_sampling" SAMPLER_SAMPLE = "sampler_sample" + PREDICT_NOISE = "predict_noise" CALC_COND_BATCH = "calc_cond_batch" APPLY_MODEL = "apply_model" DIFFUSION_MODEL = "diffusion_model" diff --git a/comfy/samplers.py b/comfy/samplers.py index d5390d64e..ec7e0b350 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -953,7 +953,14 @@ class CFGGuider: self.original_conds[k] = comfy.sampler_helpers.convert_cond(conds[k]) def __call__(self, *args, **kwargs): - return self.predict_noise(*args, **kwargs) + return self.outer_predict_noise(*args, **kwargs) + + def outer_predict_noise(self, x, timestep, model_options={}, seed=None): + return comfy.patcher_extension.WrapperExecutor.new_class_executor( + self.predict_noise, + self, + comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.PREDICT_NOISE, self.model_options, is_model_options=True) + ).execute(x, timestep, model_options, seed) def predict_noise(self, x, timestep, model_options={}, seed=None): return sampling_function(self.inner_model, x, timestep, self.conds.get("negative", None), self.conds.get("positive", None), self.cfg, model_options=model_options, seed=seed) diff --git a/comfy_extras/nodes_easycache.py b/comfy_extras/nodes_easycache.py new file mode 100644 index 000000000..e2b2efcd9 --- /dev/null +++ b/comfy_extras/nodes_easycache.py @@ -0,0 +1,459 @@ +from __future__ import annotations +from typing import TYPE_CHECKING, Union +from comfy_api.latest import io, ComfyExtension +import comfy.patcher_extension +import logging +import torch +import comfy.model_patcher +if TYPE_CHECKING: + from uuid import UUID + + +def easycache_forward_wrapper(executor, *args, **kwargs): + # get values from args + x: torch.Tensor = args[0] + transformer_options: dict[str] = args[-1] + if not isinstance(transformer_options, dict): + transformer_options = kwargs.get("transformer_options") + if not transformer_options: + transformer_options = args[-2] + easycache: EasyCacheHolder = transformer_options["easycache"] + sigmas = transformer_options["sigmas"] + uuids = transformer_options["uuids"] + if sigmas is not None and easycache.is_past_end_timestep(sigmas): + return executor(*args, **kwargs) + # prepare next x_prev + has_first_cond_uuid = easycache.has_first_cond_uuid(uuids) + next_x_prev = x + input_change = None + do_easycache = easycache.should_do_easycache(sigmas) + if do_easycache: + # if first cond marked this step for skipping, skip it and use appropriate cached values + if easycache.skip_current_step: + if easycache.verbose: + logging.info(f"EasyCache [verbose] - was marked to skip this step by {easycache.first_cond_uuid}. 
Present uuids: {uuids}") + return easycache.apply_cache_diff(x, uuids) + if easycache.initial_step: + easycache.first_cond_uuid = uuids[0] + has_first_cond_uuid = easycache.has_first_cond_uuid(uuids) + easycache.initial_step = False + if has_first_cond_uuid: + if easycache.has_x_prev_subsampled(): + input_change = (easycache.subsample(x, uuids, clone=False) - easycache.x_prev_subsampled).flatten().abs().mean() + if easycache.has_output_prev_norm() and easycache.has_relative_transformation_rate(): + approx_output_change_rate = (easycache.relative_transformation_rate * input_change) / easycache.output_prev_norm + easycache.cumulative_change_rate += approx_output_change_rate + if easycache.cumulative_change_rate < easycache.reuse_threshold: + if easycache.verbose: + logging.info(f"EasyCache [verbose] - skipping step; cumulative_change_rate: {easycache.cumulative_change_rate}, reuse_threshold: {easycache.reuse_threshold}") + # other conds should also skip this step, and instead use their cached values + easycache.skip_current_step = True + return easycache.apply_cache_diff(x, uuids) + else: + if easycache.verbose: + logging.info(f"EasyCache [verbose] - NOT skipping step; cumulative_change_rate: {easycache.cumulative_change_rate}, reuse_threshold: {easycache.reuse_threshold}") + easycache.cumulative_change_rate = 0.0 + + output: torch.Tensor = executor(*args, **kwargs) + if has_first_cond_uuid and easycache.has_output_prev_norm(): + output_change = (easycache.subsample(output, uuids, clone=False) - easycache.output_prev_subsampled).flatten().abs().mean() + if easycache.verbose: + output_change_rate = output_change / easycache.output_prev_norm + easycache.output_change_rates.append(output_change_rate.item()) + if easycache.has_relative_transformation_rate(): + approx_output_change_rate = (easycache.relative_transformation_rate * input_change) / easycache.output_prev_norm + easycache.approx_output_change_rates.append(approx_output_change_rate.item()) + if easycache.verbose: + logging.info(f"EasyCache [verbose] - approx_output_change_rate: {approx_output_change_rate}") + if input_change is not None: + easycache.relative_transformation_rate = output_change / input_change + if easycache.verbose: + logging.info(f"EasyCache [verbose] - output_change_rate: {output_change_rate}") + # TODO: allow cache_diff to be offloaded + easycache.update_cache_diff(output, next_x_prev, uuids) + if has_first_cond_uuid: + easycache.x_prev_subsampled = easycache.subsample(next_x_prev, uuids) + easycache.output_prev_subsampled = easycache.subsample(output, uuids) + easycache.output_prev_norm = output.flatten().abs().mean() + if easycache.verbose: + logging.info(f"EasyCache [verbose] - x_prev_subsampled: {easycache.x_prev_subsampled.shape}") + return output + +def lazycache_predict_noise_wrapper(executor, *args, **kwargs): + # get values from args + x: torch.Tensor = args[0] + timestep: float = args[1] + model_options: dict[str] = args[2] + easycache: LazyCacheHolder = model_options["transformer_options"]["easycache"] + if easycache.is_past_end_timestep(timestep): + return executor(*args, **kwargs) + # prepare next x_prev + next_x_prev = x + input_change = None + do_easycache = easycache.should_do_easycache(timestep) + if do_easycache: + if easycache.has_x_prev_subsampled(): + if easycache.has_x_prev_subsampled(): + input_change = (easycache.subsample(x, clone=False) - easycache.x_prev_subsampled).flatten().abs().mean() + if easycache.has_output_prev_norm() and easycache.has_relative_transformation_rate(): + 
approx_output_change_rate = (easycache.relative_transformation_rate * input_change) / easycache.output_prev_norm + easycache.cumulative_change_rate += approx_output_change_rate + if easycache.cumulative_change_rate < easycache.reuse_threshold: + if easycache.verbose: + logging.info(f"LazyCache [verbose] - skipping step; cumulative_change_rate: {easycache.cumulative_change_rate}, reuse_threshold: {easycache.reuse_threshold}") + # other conds should also skip this step, and instead use their cached values + easycache.skip_current_step = True + return easycache.apply_cache_diff(x) + else: + if easycache.verbose: + logging.info(f"LazyCache [verbose] - NOT skipping step; cumulative_change_rate: {easycache.cumulative_change_rate}, reuse_threshold: {easycache.reuse_threshold}") + easycache.cumulative_change_rate = 0.0 + output: torch.Tensor = executor(*args, **kwargs) + if easycache.has_output_prev_norm(): + output_change = (easycache.subsample(output, clone=False) - easycache.output_prev_subsampled).flatten().abs().mean() + if easycache.verbose: + output_change_rate = output_change / easycache.output_prev_norm + easycache.output_change_rates.append(output_change_rate.item()) + if easycache.has_relative_transformation_rate(): + approx_output_change_rate = (easycache.relative_transformation_rate * input_change) / easycache.output_prev_norm + easycache.approx_output_change_rates.append(approx_output_change_rate.item()) + if easycache.verbose: + logging.info(f"LazyCache [verbose] - approx_output_change_rate: {approx_output_change_rate}") + if input_change is not None: + easycache.relative_transformation_rate = output_change / input_change + if easycache.verbose: + logging.info(f"LazyCache [verbose] - output_change_rate: {output_change_rate}") + # TODO: allow cache_diff to be offloaded + easycache.update_cache_diff(output, next_x_prev) + easycache.x_prev_subsampled = easycache.subsample(next_x_prev) + easycache.output_prev_subsampled = easycache.subsample(output) + easycache.output_prev_norm = output.flatten().abs().mean() + if easycache.verbose: + logging.info(f"LazyCache [verbose] - x_prev_subsampled: {easycache.x_prev_subsampled.shape}") + return output + +def easycache_calc_cond_batch_wrapper(executor, *args, **kwargs): + model_options = args[-1] + easycache: EasyCacheHolder = model_options["transformer_options"]["easycache"] + easycache.skip_current_step = False + # TODO: check if first_cond_uuid is active at this timestep; otherwise, EasyCache needs to be partially reset + return executor(*args, **kwargs) + +def easycache_sample_wrapper(executor, *args, **kwargs): + """ + This OUTER_SAMPLE wrapper makes sure easycache is prepped for current run, and all memory usage is cleared at the end. 
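+    On exit it logs how many steps were skipped, resets the cache state and restores the guider's original model_options.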
+ """ + try: + guider = executor.class_obj + orig_model_options = guider.model_options + guider.model_options = comfy.model_patcher.create_model_options_clone(orig_model_options) + # clone and prepare timesteps + guider.model_options["transformer_options"]["easycache"] = guider.model_options["transformer_options"]["easycache"].clone().prepare_timesteps(guider.model_patcher.model.model_sampling) + easycache: Union[EasyCacheHolder, LazyCacheHolder] = guider.model_options['transformer_options']['easycache'] + logging.info(f"{easycache.name} enabled - threshold: {easycache.reuse_threshold}, start_percent: {easycache.start_percent}, end_percent: {easycache.end_percent}") + return executor(*args, **kwargs) + finally: + easycache = guider.model_options['transformer_options']['easycache'] + output_change_rates = easycache.output_change_rates + approx_output_change_rates = easycache.approx_output_change_rates + if easycache.verbose: + logging.info(f"{easycache.name} [verbose] - output_change_rates {len(output_change_rates)}: {output_change_rates}") + logging.info(f"{easycache.name} [verbose] - approx_output_change_rates {len(approx_output_change_rates)}: {approx_output_change_rates}") + total_steps = len(args[3])-1 + logging.info(f"{easycache.name} - skipped {easycache.total_steps_skipped}/{total_steps} steps ({total_steps/(total_steps-easycache.total_steps_skipped):.2f}x speedup).") + easycache.reset() + guider.model_options = orig_model_options + + +class EasyCacheHolder: + def __init__(self, reuse_threshold: float, start_percent: float, end_percent: float, subsample_factor: int, offload_cache_diff: bool, verbose: bool=False): + self.name = "EasyCache" + self.reuse_threshold = reuse_threshold + self.start_percent = start_percent + self.end_percent = end_percent + self.subsample_factor = subsample_factor + self.offload_cache_diff = offload_cache_diff + self.verbose = verbose + # timestep values + self.start_t = 0.0 + self.end_t = 0.0 + # control values + self.relative_transformation_rate: float = None + self.cumulative_change_rate = 0.0 + self.initial_step = True + self.skip_current_step = False + # cache values + self.first_cond_uuid = None + self.x_prev_subsampled: torch.Tensor = None + self.output_prev_subsampled: torch.Tensor = None + self.output_prev_norm: torch.Tensor = None + self.uuid_cache_diffs: dict[UUID, torch.Tensor] = {} + self.output_change_rates = [] + self.approx_output_change_rates = [] + self.total_steps_skipped = 0 + # how to deal with mismatched dims + self.allow_mismatch = True + self.cut_from_start = True + + def is_past_end_timestep(self, timestep: float) -> bool: + return not (timestep[0] > self.end_t).item() + + def should_do_easycache(self, timestep: float) -> bool: + return (timestep[0] <= self.start_t).item() + + def has_x_prev_subsampled(self) -> bool: + return self.x_prev_subsampled is not None + + def has_output_prev_subsampled(self) -> bool: + return self.output_prev_subsampled is not None + + def has_output_prev_norm(self) -> bool: + return self.output_prev_norm is not None + + def has_relative_transformation_rate(self) -> bool: + return self.relative_transformation_rate is not None + + def prepare_timesteps(self, model_sampling): + self.start_t = model_sampling.percent_to_sigma(self.start_percent) + self.end_t = model_sampling.percent_to_sigma(self.end_percent) + return self + + def subsample(self, x: torch.Tensor, uuids: list[UUID], clone: bool = True) -> torch.Tensor: + batch_offset = x.shape[0] // len(uuids) + uuid_idx = uuids.index(self.first_cond_uuid) + if 
self.subsample_factor > 1: + to_return = x[uuid_idx*batch_offset:(uuid_idx+1)*batch_offset, ..., ::self.subsample_factor, ::self.subsample_factor] + if clone: + return to_return.clone() + return to_return + to_return = x[uuid_idx*batch_offset:(uuid_idx+1)*batch_offset, ...] + if clone: + return to_return.clone() + return to_return + + def apply_cache_diff(self, x: torch.Tensor, uuids: list[UUID]): + if self.first_cond_uuid in uuids: + self.total_steps_skipped += 1 + batch_offset = x.shape[0] // len(uuids) + for i, uuid in enumerate(uuids): + # if cached dims don't match x dims, cut off excess and hope for the best (cosmos world2video) + if x.shape[1:] != self.uuid_cache_diffs[uuid].shape[1:]: + if not self.allow_mismatch: + raise ValueError(f"Cached dims {self.uuid_cache_diffs[uuid].shape} don't match x dims {x.shape} - this is no good") + slicing = [] + skip_this_dim = True + for dim_u, dim_x in zip(self.uuid_cache_diffs[uuid].shape, x.shape): + if skip_this_dim: + skip_this_dim = False + continue + if dim_u != dim_x: + if self.cut_from_start: + slicing.append(slice(dim_x-dim_u, None)) + else: + slicing.append(slice(None, dim_u)) + else: + slicing.append(slice(None)) + slicing = [slice(i*batch_offset,(i+1)*batch_offset)] + slicing + x = x[slicing] + x += self.uuid_cache_diffs[uuid].to(x.device) + return x + + def update_cache_diff(self, output: torch.Tensor, x: torch.Tensor, uuids: list[UUID]): + # if output dims don't match x dims, cut off excess and hope for the best (cosmos world2video) + if output.shape[1:] != x.shape[1:]: + if not self.allow_mismatch: + raise ValueError(f"Output dims {output.shape} don't match x dims {x.shape} - this is no good") + slicing = [] + skip_dim = True + for dim_o, dim_x in zip(output.shape, x.shape): + if not skip_dim and dim_o != dim_x: + if self.cut_from_start: + slicing.append(slice(dim_x-dim_o, None)) + else: + slicing.append(slice(None, dim_o)) + else: + slicing.append(slice(None)) + skip_dim = False + x = x[slicing] + diff = output - x + batch_offset = diff.shape[0] // len(uuids) + for i, uuid in enumerate(uuids): + self.uuid_cache_diffs[uuid] = diff[i*batch_offset:(i+1)*batch_offset, ...] 
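+        # Each cond's residual (output minus x) is stored under its own UUID so that
+        # apply_cache_diff() can add it back per-cond whenever a step is skipped.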
+ + def has_first_cond_uuid(self, uuids: list[UUID]) -> bool: + return self.first_cond_uuid in uuids + + def reset(self): + self.relative_transformation_rate = 0.0 + self.cumulative_change_rate = 0.0 + self.initial_step = True + self.skip_current_step = False + self.output_change_rates = [] + self.first_cond_uuid = None + del self.x_prev_subsampled + self.x_prev_subsampled = None + del self.output_prev_subsampled + self.output_prev_subsampled = None + del self.output_prev_norm + self.output_prev_norm = None + del self.uuid_cache_diffs + self.uuid_cache_diffs = {} + self.total_steps_skipped = 0 + return self + + def clone(self): + return EasyCacheHolder(self.reuse_threshold, self.start_percent, self.end_percent, self.subsample_factor, self.offload_cache_diff, self.verbose) + + +class EasyCacheNode(io.ComfyNode): + @classmethod + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="EasyCache", + display_name="EasyCache", + description="Native EasyCache implementation.", + category="advanced/debug/model", + is_experimental=True, + inputs=[ + io.Model.Input("model", tooltip="The model to add EasyCache to."), + io.Float.Input("reuse_threshold", min=0.0, default=0.2, max=3.0, step=0.01, tooltip="The threshold for reusing cached steps."), + io.Float.Input("start_percent", min=0.0, default=0.15, max=1.0, step=0.01, tooltip="The relative sampling step to begin use of EasyCache."), + io.Float.Input("end_percent", min=0.0, default=0.95, max=1.0, step=0.01, tooltip="The relative sampling step to end use of EasyCache."), + io.Boolean.Input("verbose", default=False, tooltip="Whether to log verbose information."), + ], + outputs=[ + io.Model.Output(tooltip="The model with EasyCache."), + ], + ) + + @classmethod + def execute(cls, model: io.Model.Type, reuse_threshold: float, start_percent: float, end_percent: float, verbose: bool) -> io.NodeOutput: + model = model.clone() + model.model_options["transformer_options"]["easycache"] = EasyCacheHolder(reuse_threshold, start_percent, end_percent, subsample_factor=8, offload_cache_diff=False, verbose=verbose) + model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.OUTER_SAMPLE, "easycache", easycache_sample_wrapper) + model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.CALC_COND_BATCH, "easycache", easycache_calc_cond_batch_wrapper) + model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, "easycache", easycache_forward_wrapper) + return io.NodeOutput(model) + + +class LazyCacheHolder: + def __init__(self, reuse_threshold: float, start_percent: float, end_percent: float, subsample_factor: int, offload_cache_diff: bool, verbose: bool=False): + self.name = "LazyCache" + self.reuse_threshold = reuse_threshold + self.start_percent = start_percent + self.end_percent = end_percent + self.subsample_factor = subsample_factor + self.offload_cache_diff = offload_cache_diff + self.verbose = verbose + # timestep values + self.start_t = 0.0 + self.end_t = 0.0 + # control values + self.relative_transformation_rate: float = None + self.cumulative_change_rate = 0.0 + self.initial_step = True + # cache values + self.x_prev_subsampled: torch.Tensor = None + self.output_prev_subsampled: torch.Tensor = None + self.output_prev_norm: torch.Tensor = None + self.cache_diff: torch.Tensor = None + self.output_change_rates = [] + self.approx_output_change_rates = [] + self.total_steps_skipped = 0 + + def has_cache_diff(self) -> bool: + return self.cache_diff is not None + + def is_past_end_timestep(self, timestep: float) -> bool: + 
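+        """True once the current sigma is at or below the sigma mapped from end_percent, i.e. LazyCache is past its active window."""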
return not (timestep[0] > self.end_t).item() + + def should_do_easycache(self, timestep: float) -> bool: + return (timestep[0] <= self.start_t).item() + + def has_x_prev_subsampled(self) -> bool: + return self.x_prev_subsampled is not None + + def has_output_prev_subsampled(self) -> bool: + return self.output_prev_subsampled is not None + + def has_output_prev_norm(self) -> bool: + return self.output_prev_norm is not None + + def has_relative_transformation_rate(self) -> bool: + return self.relative_transformation_rate is not None + + def prepare_timesteps(self, model_sampling): + self.start_t = model_sampling.percent_to_sigma(self.start_percent) + self.end_t = model_sampling.percent_to_sigma(self.end_percent) + return self + + def subsample(self, x: torch.Tensor, clone: bool = True) -> torch.Tensor: + if self.subsample_factor > 1: + to_return = x[..., ::self.subsample_factor, ::self.subsample_factor] + if clone: + return to_return.clone() + return to_return + if clone: + return x.clone() + return x + + def apply_cache_diff(self, x: torch.Tensor): + self.total_steps_skipped += 1 + return x + self.cache_diff.to(x.device) + + def update_cache_diff(self, output: torch.Tensor, x: torch.Tensor): + self.cache_diff = output - x + + def reset(self): + self.relative_transformation_rate = 0.0 + self.cumulative_change_rate = 0.0 + self.initial_step = True + self.output_change_rates = [] + self.approx_output_change_rates = [] + del self.cache_diff + self.cache_diff = None + self.total_steps_skipped = 0 + return self + + def clone(self): + return LazyCacheHolder(self.reuse_threshold, self.start_percent, self.end_percent, self.subsample_factor, self.offload_cache_diff, self.verbose) + +class LazyCacheNode(io.ComfyNode): + @classmethod + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="LazyCache", + display_name="LazyCache", + description="A homebrew version of EasyCache - even 'easier' version of EasyCache to implement. 
Overall works worse than EasyCache, but better in some rare cases AND universal compatibility with everything in ComfyUI.", + category="advanced/debug/model", + is_experimental=True, + inputs=[ + io.Model.Input("model", tooltip="The model to add LazyCache to."), + io.Float.Input("reuse_threshold", min=0.0, default=0.2, max=3.0, step=0.01, tooltip="The threshold for reusing cached steps."), + io.Float.Input("start_percent", min=0.0, default=0.15, max=1.0, step=0.01, tooltip="The relative sampling step to begin use of LazyCache."), + io.Float.Input("end_percent", min=0.0, default=0.95, max=1.0, step=0.01, tooltip="The relative sampling step to end use of LazyCache."), + io.Boolean.Input("verbose", default=False, tooltip="Whether to log verbose information."), + ], + outputs=[ + io.Model.Output(tooltip="The model with LazyCache."), + ], + ) + + @classmethod + def execute(cls, model: io.Model.Type, reuse_threshold: float, start_percent: float, end_percent: float, verbose: bool) -> io.NodeOutput: + model = model.clone() + model.model_options["transformer_options"]["easycache"] = LazyCacheHolder(reuse_threshold, start_percent, end_percent, subsample_factor=8, offload_cache_diff=False, verbose=verbose) + model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.OUTER_SAMPLE, "lazycache", easycache_sample_wrapper) + model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.PREDICT_NOISE, "lazycache", lazycache_predict_noise_wrapper) + return io.NodeOutput(model) + + +class EasyCacheExtension(ComfyExtension): + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + EasyCacheNode, + LazyCacheNode, + ] + +def comfy_entrypoint(): + return EasyCacheExtension() diff --git a/nodes.py b/nodes.py index 9681750d3..723ce3384 100644 --- a/nodes.py +++ b/nodes.py @@ -2322,7 +2322,8 @@ async def init_builtin_extra_nodes(): "nodes_tcfg.py", "nodes_context_windows.py", "nodes_qwen.py", - "nodes_model_patch.py" + "nodes_model_patch.py", + "nodes_easycache.py", ] import_failed = [] From 41048c69b4ccf63f876213a95a51cdde1cb0ab84 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Fri, 22 Aug 2025 20:15:44 -0700 Subject: [PATCH 105/325] Fix Conditioning masks on 3d latents. 
(#9506) --- comfy/samplers.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/comfy/samplers.py b/comfy/samplers.py index ec7e0b350..c7dfef4ea 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -17,6 +17,7 @@ import comfy.model_patcher import comfy.patcher_extension import comfy.hooks import comfy.context_windows +import comfy.utils import scipy.stats import numpy @@ -61,7 +62,7 @@ def get_area_and_mult(conds, x_in, timestep_in): if "mask_strength" in conds: mask_strength = conds["mask_strength"] mask = conds['mask'] - assert (mask.shape[1:] == x_in.shape[2:]) + # assert (mask.shape[1:] == x_in.shape[2:]) mask = mask[:input_x.shape[0]] if area is not None: @@ -69,7 +70,7 @@ def get_area_and_mult(conds, x_in, timestep_in): mask = mask.narrow(i + 1, area[len(dims) + i], area[i]) mask = mask * mask_strength - mask = mask.unsqueeze(1).repeat(input_x.shape[0] // mask.shape[0], input_x.shape[1], 1, 1) + mask = mask.unsqueeze(1).repeat((input_x.shape[0] // mask.shape[0], input_x.shape[1]) + (1, ) * (mask.ndim - 1)) else: mask = torch.ones_like(input_x) mult = mask * strength @@ -553,7 +554,10 @@ def resolve_areas_and_cond_masks_multidim(conditions, dims, device): if len(mask.shape) == len(dims): mask = mask.unsqueeze(0) if mask.shape[1:] != dims: - mask = torch.nn.functional.interpolate(mask.unsqueeze(1), size=dims, mode='bilinear', align_corners=False).squeeze(1) + if mask.ndim < 4: + mask = comfy.utils.common_upscale(mask.unsqueeze(1), dims[-1], dims[-2], 'bilinear', 'none').squeeze(1) + else: + mask = comfy.utils.common_upscale(mask, dims[-1], dims[-2], 'bilinear', 'none') if modified.get("set_area_to_bounds", False): #TODO: handle dim != 2 bounds = torch.max(torch.abs(mask),dim=0).values.unsqueeze(0) From 59eddda90030b61f172e155bc1e2526a51a27dff Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Fri, 22 Aug 2025 22:36:44 -0700 Subject: [PATCH 106/325] Python 3.13 is well supported. (#9511) --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 79a8a8c79..99a50571b 100644 --- a/README.md +++ b/README.md @@ -192,7 +192,7 @@ comfy install ## Manual Install (Windows, Linux) -python 3.13 is supported but using 3.12 is recommended because some custom nodes and their dependencies might not support it yet. +Python 3.13 is very well supported. If you have trouble with some custom node dependencies you can try 3.12 Git clone this repo. From 8be0d22ab76a3d548c9c376fd816b39d4c028c12 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sat, 23 Aug 2025 10:56:17 -0700 Subject: [PATCH 107/325] Don't use the annoying new navigation mode by default. 
(#9518) --- app/app_settings.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/app/app_settings.py b/app/app_settings.py index c7ac73bf6..eb69133a3 100644 --- a/app/app_settings.py +++ b/app/app_settings.py @@ -25,7 +25,7 @@ class AppSettings(): logging.error(f"The user settings file is corrupted: {file}") return {} else: - return {} + return {"Comfy.Canvas.NavigationMode": "legacy"} def save_settings(self, request, settings): file = self.user_manager.get_request_user_filepath( From 3e316c6338503a535801db3ddac9572a38a607ef Mon Sep 17 00:00:00 2001 From: Christian Byrne Date: Sat, 23 Aug 2025 14:54:01 -0700 Subject: [PATCH 108/325] Update frontend to v1.25.10 and revert navigation mode override (#9522) - Update comfyui-frontend-package from 1.25.9 to 1.25.10 - Revert forced legacy navigation mode from PR #9518 - Frontend v1.25.10 includes proper navigation mode fixes and improved display text --- app/app_settings.py | 2 +- requirements.txt | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/app/app_settings.py b/app/app_settings.py index eb69133a3..c7ac73bf6 100644 --- a/app/app_settings.py +++ b/app/app_settings.py @@ -25,7 +25,7 @@ class AppSettings(): logging.error(f"The user settings file is corrupted: {file}") return {} else: - return {"Comfy.Canvas.NavigationMode": "legacy"} + return {} def save_settings(self, request, settings): file = self.user_manager.get_request_user_filepath( diff --git a/requirements.txt b/requirements.txt index 6b53fabc1..131484ce8 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -comfyui-frontend-package==1.25.9 +comfyui-frontend-package==1.25.10 comfyui-workflow-templates==0.1.65 comfyui-embedded-docs==0.2.6 torch From 71ed4a399ec76a75aa2870b772d2022e4b9a69a3 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 23 Aug 2025 18:57:09 -0400 Subject: [PATCH 109/325] ComfyUI version 0.3.52 --- comfyui_version.py | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/comfyui_version.py b/comfyui_version.py index 65f06cf37..834c3e8c2 100644 --- a/comfyui_version.py +++ b/comfyui_version.py @@ -1,3 +1,3 @@ # This file is automatically generated by the build process when version is # updated in pyproject.toml. 
-__version__ = "0.3.51" +__version__ = "0.3.52" diff --git a/pyproject.toml b/pyproject.toml index ecbf04303..f6e765a81 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ComfyUI" -version = "0.3.51" +version = "0.3.52" readme = "README.md" license = { file = "LICENSE" } requires-python = ">=3.9" From 95ac7794b7c735de8e5426442507d08edd29bec5 Mon Sep 17 00:00:00 2001 From: blepping <157360029+blepping@users.noreply.github.com> Date: Sun, 24 Aug 2025 13:29:49 -0600 Subject: [PATCH 110/325] Fix EasyCache/LazyCache crash when tensor shape/dtype/device changes during sampling (#9528) * Fix EasyCache/LazyCache crash when tensor shape/dtype/device changes during sampling * Fix missing LazyCache check_metadata method Ensure LazyCache reset method resets all the tensor state values --- comfy_extras/nodes_easycache.py | 34 +++++++++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/comfy_extras/nodes_easycache.py b/comfy_extras/nodes_easycache.py index e2b2efcd9..9d2988f5f 100644 --- a/comfy_extras/nodes_easycache.py +++ b/comfy_extras/nodes_easycache.py @@ -28,6 +28,7 @@ def easycache_forward_wrapper(executor, *args, **kwargs): input_change = None do_easycache = easycache.should_do_easycache(sigmas) if do_easycache: + easycache.check_metadata(x) # if first cond marked this step for skipping, skip it and use appropriate cached values if easycache.skip_current_step: if easycache.verbose: @@ -92,6 +93,7 @@ def lazycache_predict_noise_wrapper(executor, *args, **kwargs): input_change = None do_easycache = easycache.should_do_easycache(timestep) if do_easycache: + easycache.check_metadata(x) if easycache.has_x_prev_subsampled(): if easycache.has_x_prev_subsampled(): input_change = (easycache.subsample(x, clone=False) - easycache.x_prev_subsampled).flatten().abs().mean() @@ -194,6 +196,7 @@ class EasyCacheHolder: # how to deal with mismatched dims self.allow_mismatch = True self.cut_from_start = True + self.state_metadata = None def is_past_end_timestep(self, timestep: float) -> bool: return not (timestep[0] > self.end_t).item() @@ -283,6 +286,17 @@ class EasyCacheHolder: def has_first_cond_uuid(self, uuids: list[UUID]) -> bool: return self.first_cond_uuid in uuids + def check_metadata(self, x: torch.Tensor) -> bool: + metadata = (x.device, x.dtype, x.shape[1:]) + if self.state_metadata is None: + self.state_metadata = metadata + return True + if metadata == self.state_metadata: + return True + logging.warn(f"{self.name} - Tensor shape, dtype or device changed, resetting state") + self.reset() + return False + def reset(self): self.relative_transformation_rate = 0.0 self.cumulative_change_rate = 0.0 @@ -299,6 +313,7 @@ class EasyCacheHolder: del self.uuid_cache_diffs self.uuid_cache_diffs = {} self.total_steps_skipped = 0 + self.state_metadata = None return self def clone(self): @@ -360,6 +375,7 @@ class LazyCacheHolder: self.output_change_rates = [] self.approx_output_change_rates = [] self.total_steps_skipped = 0 + self.state_metadata = None def has_cache_diff(self) -> bool: return self.cache_diff is not None @@ -404,6 +420,17 @@ class LazyCacheHolder: def update_cache_diff(self, output: torch.Tensor, x: torch.Tensor): self.cache_diff = output - x + def check_metadata(self, x: torch.Tensor) -> bool: + metadata = (x.device, x.dtype, x.shape) + if self.state_metadata is None: + self.state_metadata = metadata + return True + if metadata == self.state_metadata: + return True + logging.warn(f"{self.name} - Tensor shape, dtype or device changed, resetting 
state") + self.reset() + return False + def reset(self): self.relative_transformation_rate = 0.0 self.cumulative_change_rate = 0.0 @@ -412,7 +439,14 @@ class LazyCacheHolder: self.approx_output_change_rates = [] del self.cache_diff self.cache_diff = None + del self.x_prev_subsampled + self.x_prev_subsampled = None + del self.output_prev_subsampled + self.output_prev_subsampled = None + del self.output_prev_norm + self.output_prev_norm = None self.total_steps_skipped = 0 + self.state_metadata = None return self def clone(self): From f6b93d41a03081fad3c1a01221eac9c42d6790df Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sun, 24 Aug 2025 12:40:32 -0700 Subject: [PATCH 111/325] Remove models from readme that are not fully implemented. (#9535) Cosmos model implementations are currently missing the safety part so it is technically not fully implemented and should not be advertised as such. --- README.md | 2 -- 1 file changed, 2 deletions(-) diff --git a/README.md b/README.md index 99a50571b..8024870c2 100644 --- a/README.md +++ b/README.md @@ -65,7 +65,6 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith - [Flux](https://comfyanonymous.github.io/ComfyUI_examples/flux/) - [Lumina Image 2.0](https://comfyanonymous.github.io/ComfyUI_examples/lumina2/) - [HiDream](https://comfyanonymous.github.io/ComfyUI_examples/hidream/) - - [Cosmos Predict2](https://comfyanonymous.github.io/ComfyUI_examples/cosmos_predict2/) - [Qwen Image](https://comfyanonymous.github.io/ComfyUI_examples/qwen_image/) - Image Editing Models - [Omnigen 2](https://comfyanonymous.github.io/ComfyUI_examples/omnigen/) @@ -77,7 +76,6 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith - [Mochi](https://comfyanonymous.github.io/ComfyUI_examples/mochi/) - [LTX-Video](https://comfyanonymous.github.io/ComfyUI_examples/ltxv/) - [Hunyuan Video](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_video/) - - [Nvidia Cosmos](https://comfyanonymous.github.io/ComfyUI_examples/cosmos/) and [Cosmos Predict2](https://comfyanonymous.github.io/ComfyUI_examples/cosmos_predict2/) - [Wan 2.1](https://comfyanonymous.github.io/ComfyUI_examples/wan/) - [Wan 2.2](https://comfyanonymous.github.io/ComfyUI_examples/wan22/) - Audio Models From e633a47ad1b875e52758be27ec34cb8907ebe1fb Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 25 Aug 2025 17:13:54 -0700 Subject: [PATCH 112/325] Add models/audio_encoders directory. 
(#9548) --- folder_paths.py | 2 ++ models/audio_encoders/put_audio_encoder_models_here | 0 2 files changed, 2 insertions(+) create mode 100644 models/audio_encoders/put_audio_encoder_models_here diff --git a/folder_paths.py b/folder_paths.py index b34af39e8..f110d832b 100644 --- a/folder_paths.py +++ b/folder_paths.py @@ -48,6 +48,8 @@ folder_names_and_paths["classifiers"] = ([os.path.join(models_dir, "classifiers" folder_names_and_paths["model_patches"] = ([os.path.join(models_dir, "model_patches")], supported_pt_extensions) +folder_names_and_paths["audio_encoders"] = ([os.path.join(models_dir, "audio_encoders")], supported_pt_extensions) + output_directory = os.path.join(base_path, "output") temp_directory = os.path.join(base_path, "temp") input_directory = os.path.join(base_path, "input") diff --git a/models/audio_encoders/put_audio_encoder_models_here b/models/audio_encoders/put_audio_encoder_models_here new file mode 100644 index 000000000..e69de29bb From 914c2a29731be9c082f773c4b95892f553ac5ae8 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 25 Aug 2025 20:26:47 -0700 Subject: [PATCH 113/325] Implement wav2vec2 as an audio encoder model. (#9549) This is useless on its own but there are multiple models that use it. --- comfy/audio_encoders/audio_encoders.py | 42 +++++ comfy/audio_encoders/wav2vec2.py | 207 +++++++++++++++++++++++++ comfy_api/latest/_io.py | 8 + comfy_extras/nodes_audio_encoder.py | 44 ++++++ nodes.py | 1 + 5 files changed, 302 insertions(+) create mode 100644 comfy/audio_encoders/audio_encoders.py create mode 100644 comfy/audio_encoders/wav2vec2.py create mode 100644 comfy_extras/nodes_audio_encoder.py diff --git a/comfy/audio_encoders/audio_encoders.py b/comfy/audio_encoders/audio_encoders.py new file mode 100644 index 000000000..538c21bd5 --- /dev/null +++ b/comfy/audio_encoders/audio_encoders.py @@ -0,0 +1,42 @@ +from .wav2vec2 import Wav2Vec2Model +import comfy.model_management +import comfy.ops +import comfy.utils +import logging +import torchaudio + + +class AudioEncoderModel(): + def __init__(self, config): + self.load_device = comfy.model_management.text_encoder_device() + offload_device = comfy.model_management.text_encoder_offload_device() + self.dtype = comfy.model_management.text_encoder_dtype(self.load_device) + self.model = Wav2Vec2Model(dtype=self.dtype, device=offload_device, operations=comfy.ops.manual_cast) + self.model.eval() + self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device) + self.model_sample_rate = 16000 + + def load_sd(self, sd): + return self.model.load_state_dict(sd, strict=False) + + def get_sd(self): + return self.model.state_dict() + + def encode_audio(self, audio, sample_rate): + comfy.model_management.load_model_gpu(self.patcher) + audio = torchaudio.functional.resample(audio, sample_rate, self.model_sample_rate) + out, all_layers = self.model(audio.to(self.load_device)) + outputs = {} + outputs["encoded_audio"] = out + outputs["encoded_audio_all_layers"] = all_layers + return outputs + + +def load_audio_encoder_from_sd(sd, prefix=""): + audio_encoder = AudioEncoderModel(None) + sd = comfy.utils.state_dict_prefix_replace(sd, {"wav2vec2.": ""}) + m, u = audio_encoder.load_sd(sd) + if len(m) > 0: + logging.warning("missing audio encoder: {}".format(m)) + + return audio_encoder diff --git a/comfy/audio_encoders/wav2vec2.py b/comfy/audio_encoders/wav2vec2.py new file mode 100644 index 000000000..de906622a --- 
/dev/null +++ b/comfy/audio_encoders/wav2vec2.py @@ -0,0 +1,207 @@ +import torch +import torch.nn as nn +from comfy.ldm.modules.attention import optimized_attention_masked + + +class LayerNormConv(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, stride, bias=False, dtype=None, device=None, operations=None): + super().__init__() + self.conv = operations.Conv1d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, bias=bias, device=device, dtype=dtype) + self.layer_norm = operations.LayerNorm(out_channels, elementwise_affine=True, device=device, dtype=dtype) + + def forward(self, x): + x = self.conv(x) + return torch.nn.functional.gelu(self.layer_norm(x.transpose(-2, -1)).transpose(-2, -1)) + + +class ConvFeatureEncoder(nn.Module): + def __init__(self, conv_dim, dtype=None, device=None, operations=None): + super().__init__() + self.conv_layers = nn.ModuleList([ + LayerNormConv(1, conv_dim, kernel_size=10, stride=5, bias=True, device=device, dtype=dtype, operations=operations), + LayerNormConv(conv_dim, conv_dim, kernel_size=3, stride=2, bias=True, device=device, dtype=dtype, operations=operations), + LayerNormConv(conv_dim, conv_dim, kernel_size=3, stride=2, bias=True, device=device, dtype=dtype, operations=operations), + LayerNormConv(conv_dim, conv_dim, kernel_size=3, stride=2, bias=True, device=device, dtype=dtype, operations=operations), + LayerNormConv(conv_dim, conv_dim, kernel_size=3, stride=2, bias=True, device=device, dtype=dtype, operations=operations), + LayerNormConv(conv_dim, conv_dim, kernel_size=2, stride=2, bias=True, device=device, dtype=dtype, operations=operations), + LayerNormConv(conv_dim, conv_dim, kernel_size=2, stride=2, bias=True, device=device, dtype=dtype, operations=operations), + ]) + + def forward(self, x): + x = x.unsqueeze(1) + + for conv in self.conv_layers: + x = conv(x) + + return x.transpose(1, 2) + + +class FeatureProjection(nn.Module): + def __init__(self, conv_dim, embed_dim, dtype=None, device=None, operations=None): + super().__init__() + self.layer_norm = operations.LayerNorm(conv_dim, eps=1e-05, device=device, dtype=dtype) + self.projection = operations.Linear(conv_dim, embed_dim, device=device, dtype=dtype) + + def forward(self, x): + x = self.layer_norm(x) + x = self.projection(x) + return x + + +class PositionalConvEmbedding(nn.Module): + def __init__(self, embed_dim=768, kernel_size=128, groups=16): + super().__init__() + self.conv = nn.Conv1d( + embed_dim, + embed_dim, + kernel_size=kernel_size, + padding=kernel_size // 2, + groups=groups, + ) + self.conv = torch.nn.utils.parametrizations.weight_norm(self.conv, name="weight", dim=2) + self.activation = nn.GELU() + + def forward(self, x): + x = x.transpose(1, 2) + x = self.conv(x)[:, :, :-1] + x = self.activation(x) + x = x.transpose(1, 2) + return x + + +class TransformerEncoder(nn.Module): + def __init__( + self, + embed_dim=768, + num_heads=12, + num_layers=12, + mlp_ratio=4.0, + dtype=None, device=None, operations=None + ): + super().__init__() + + self.pos_conv_embed = PositionalConvEmbedding(embed_dim=embed_dim) + self.layers = nn.ModuleList([ + TransformerEncoderLayer( + embed_dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + device=device, dtype=dtype, operations=operations + ) + for _ in range(num_layers) + ]) + + self.layer_norm = operations.LayerNorm(embed_dim, eps=1e-05, device=device, dtype=dtype) + + def forward(self, x, mask=None): + x = x + self.pos_conv_embed(x) + all_x = () + for layer in self.layers: + all_x += (x,) + x = 
layer(x, mask) + x = self.layer_norm(x) + all_x += (x,) + return x, all_x + + +class Attention(nn.Module): + def __init__(self, embed_dim, num_heads, bias=True, dtype=None, device=None, operations=None): + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.head_dim = embed_dim // num_heads + + self.k_proj = operations.Linear(embed_dim, embed_dim, bias=bias, device=device, dtype=dtype) + self.v_proj = operations.Linear(embed_dim, embed_dim, bias=bias, device=device, dtype=dtype) + self.q_proj = operations.Linear(embed_dim, embed_dim, bias=bias, device=device, dtype=dtype) + self.out_proj = operations.Linear(embed_dim, embed_dim, bias=bias, device=device, dtype=dtype) + + def forward(self, x, mask=None): + assert (mask is None) # TODO? + q = self.q_proj(x) + k = self.k_proj(x) + v = self.v_proj(x) + + out = optimized_attention_masked(q, k, v, self.num_heads) + return self.out_proj(out) + + +class FeedForward(nn.Module): + def __init__(self, embed_dim, mlp_ratio, dtype=None, device=None, operations=None): + super().__init__() + self.intermediate_dense = operations.Linear(embed_dim, int(embed_dim * mlp_ratio), device=device, dtype=dtype) + self.output_dense = operations.Linear(int(embed_dim * mlp_ratio), embed_dim, device=device, dtype=dtype) + + def forward(self, x): + x = self.intermediate_dense(x) + x = torch.nn.functional.gelu(x) + x = self.output_dense(x) + return x + + +class TransformerEncoderLayer(nn.Module): + def __init__( + self, + embed_dim=768, + num_heads=12, + mlp_ratio=4.0, + dtype=None, device=None, operations=None + ): + super().__init__() + + self.attention = Attention(embed_dim, num_heads, device=device, dtype=dtype, operations=operations) + + self.layer_norm = operations.LayerNorm(embed_dim, device=device, dtype=dtype) + self.feed_forward = FeedForward(embed_dim, mlp_ratio, device=device, dtype=dtype, operations=operations) + self.final_layer_norm = operations.LayerNorm(embed_dim, device=device, dtype=dtype) + + def forward(self, x, mask=None): + residual = x + x = self.layer_norm(x) + x = self.attention(x, mask=mask) + x = residual + x + + x = x + self.feed_forward(self.final_layer_norm(x)) + return x + + +class Wav2Vec2Model(nn.Module): + """Complete Wav2Vec 2.0 model.""" + + def __init__( + self, + embed_dim=1024, + final_dim=256, + num_heads=16, + num_layers=24, + dtype=None, device=None, operations=None + ): + super().__init__() + + conv_dim = 512 + self.feature_extractor = ConvFeatureEncoder(conv_dim, device=device, dtype=dtype, operations=operations) + self.feature_projection = FeatureProjection(conv_dim, embed_dim, device=device, dtype=dtype, operations=operations) + + self.masked_spec_embed = nn.Parameter(torch.empty(embed_dim, device=device, dtype=dtype)) + + self.encoder = TransformerEncoder( + embed_dim=embed_dim, + num_heads=num_heads, + num_layers=num_layers, + device=device, dtype=dtype, operations=operations + ) + + def forward(self, x, mask_time_indices=None, return_dict=False): + + x = torch.mean(x, dim=1) + + x = (x - x.mean()) / torch.sqrt(x.var() + 1e-7) + + features = self.feature_extractor(x) + features = self.feature_projection(features) + + batch_size, seq_len, _ = features.shape + + x, all_x = self.encoder(features) + + return x, all_x diff --git a/comfy_api/latest/_io.py b/comfy_api/latest/_io.py index a3a21facc..5cb474459 100644 --- a/comfy_api/latest/_io.py +++ b/comfy_api/latest/_io.py @@ -730,6 +730,14 @@ class AnyType(ComfyTypeIO): class MODEL_PATCH(ComfyTypeIO): Type = Any 
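The wrapper added in #9549 above resamples incoming audio to the model's 16 kHz rate, and Wav2Vec2Model.forward then averages channels and normalizes the waveform before the conv feature encoder. A small sketch of that preprocessing path on its own (it mirrors, rather than reuses, the code above and assumes torchaudio is available):

import torch
import torchaudio

def preprocess_for_wav2vec2(waveform: torch.Tensor, sample_rate: int) -> torch.Tensor:
    # waveform: [batch, channels, samples], as in ComfyUI AUDIO dicts.
    # 1) resample to 16 kHz, as AudioEncoderModel.encode_audio does
    waveform = torchaudio.functional.resample(waveform, sample_rate, 16000)
    # 2) mix channels down and normalize, as in Wav2Vec2Model.forward
    x = torch.mean(waveform, dim=1)
    x = (x - x.mean()) / torch.sqrt(x.var() + 1e-7)
    return x

audio = {"waveform": torch.randn(1, 2, 44100), "sample_rate": 44100}
features_in = preprocess_for_wav2vec2(audio["waveform"], audio["sample_rate"])
print(features_in.shape)  # roughly [1, 16000] for one second of input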
+@comfytype(io_type="AUDIO_ENCODER") +class AUDIO_ENCODER(ComfyTypeIO): + Type = Any + +@comfytype(io_type="AUDIO_ENCODER_OUTPUT") +class AUDIO_ENCODER_OUTPUT(ComfyTypeIO): + Type = Any + @comfytype(io_type="COMFY_MULTITYPED_V3") class MultiType: Type = Any diff --git a/comfy_extras/nodes_audio_encoder.py b/comfy_extras/nodes_audio_encoder.py new file mode 100644 index 000000000..39a140fef --- /dev/null +++ b/comfy_extras/nodes_audio_encoder.py @@ -0,0 +1,44 @@ +import folder_paths +import comfy.audio_encoders.audio_encoders +import comfy.utils + + +class AudioEncoderLoader: + @classmethod + def INPUT_TYPES(s): + return {"required": { "audio_encoder_name": (folder_paths.get_filename_list("audio_encoders"), ), + }} + RETURN_TYPES = ("AUDIO_ENCODER",) + FUNCTION = "load_model" + + CATEGORY = "loaders" + + def load_model(self, audio_encoder_name): + audio_encoder_name = folder_paths.get_full_path_or_raise("audio_encoders", audio_encoder_name) + sd = comfy.utils.load_torch_file(audio_encoder_name, safe_load=True) + audio_encoder = comfy.audio_encoders.audio_encoders.load_audio_encoder_from_sd(sd) + if audio_encoder is None: + raise RuntimeError("ERROR: audio encoder file is invalid and does not contain a valid model.") + return (audio_encoder,) + + +class AudioEncoderEncode: + @classmethod + def INPUT_TYPES(s): + return {"required": { "audio_encoder": ("AUDIO_ENCODER",), + "audio": ("AUDIO",), + }} + RETURN_TYPES = ("AUDIO_ENCODER_OUTPUT",) + FUNCTION = "encode" + + CATEGORY = "conditioning" + + def encode(self, audio_encoder, audio): + output = audio_encoder.encode_audio(audio["waveform"], audio["sample_rate"]) + return (output,) + + +NODE_CLASS_MAPPINGS = { + "AudioEncoderLoader": AudioEncoderLoader, + "AudioEncoderEncode": AudioEncoderEncode, +} diff --git a/nodes.py b/nodes.py index 723ce3384..0aff6b14a 100644 --- a/nodes.py +++ b/nodes.py @@ -2324,6 +2324,7 @@ async def init_builtin_extra_nodes(): "nodes_qwen.py", "nodes_model_patch.py", "nodes_easycache.py", + "nodes_audio_encoder.py", ] import_failed = [] From 39aa06bd5d630e50c88d3be1586d21737c4387c1 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 26 Aug 2025 09:50:46 -0700 Subject: [PATCH 114/325] Make AudioEncoderOutput usable in v3 node schema. 
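Putting the pieces from #9549 together, the loader and encode nodes above boil down to two calls. A sketch of the same flow outside the node graph (assumes the ComfyUI Python environment; the checkpoint filename is a placeholder):

import torch
import comfy.utils
import comfy.audio_encoders.audio_encoders as audio_encoders

# Placeholder filename; any wav2vec2-style checkpoint in models/audio_encoders works.
sd = comfy.utils.load_torch_file("models/audio_encoders/wav2vec2.safetensors", safe_load=True)
encoder = audio_encoders.load_audio_encoder_from_sd(sd)

waveform, sample_rate = torch.randn(1, 1, 16000), 16000   # one second of mono audio
out = encoder.encode_audio(waveform, sample_rate)
print(out.keys())   # dict_keys(['encoded_audio', 'encoded_audio_all_layers'])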
(#9554) --- comfy_api/latest/_io.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/comfy_api/latest/_io.py b/comfy_api/latest/_io.py index 5cb474459..e0ee943a7 100644 --- a/comfy_api/latest/_io.py +++ b/comfy_api/latest/_io.py @@ -731,11 +731,11 @@ class MODEL_PATCH(ComfyTypeIO): Type = Any @comfytype(io_type="AUDIO_ENCODER") -class AUDIO_ENCODER(ComfyTypeIO): +class AudioEncoder(ComfyTypeIO): Type = Any @comfytype(io_type="AUDIO_ENCODER_OUTPUT") -class AUDIO_ENCODER_OUTPUT(ComfyTypeIO): +class AudioEncoderOutput(ComfyTypeIO): Type = Any @comfytype(io_type="COMFY_MULTITYPED_V3") @@ -1592,6 +1592,7 @@ class _IO: Model = Model ClipVision = ClipVision ClipVisionOutput = ClipVisionOutput + AudioEncoderOutput = AudioEncoderOutput StyleModel = StyleModel Gligen = Gligen UpscaleModel = UpscaleModel From 5352abc6d389570455776c457738db54367cd6cb Mon Sep 17 00:00:00 2001 From: ComfyUI Wiki Date: Wed, 27 Aug 2025 01:33:54 +0800 Subject: [PATCH 115/325] Update template to 0.1.66 (#9557) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 131484ce8..db59bb38c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.25.10 -comfyui-workflow-templates==0.1.65 +comfyui-workflow-templates==0.1.66 comfyui-embedded-docs==0.2.6 torch torchsde From 47f4db3e84874ca6076e5cdbb345444faec83028 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Tue, 26 Aug 2025 19:20:44 -0700 Subject: [PATCH 116/325] Adding Google Gemini Image API node (#9566) * bigcat88's progress on adding Google Gemini Image node * Made Google Gemini Image node functional * Bump frontend version to get static pricing badge on Gemini Image node --- comfy_api_nodes/apis/gemini_api.py | 19 ++ comfy_api_nodes/nodes_gemini.py | 388 ++++++++++++++++++++++------- requirements.txt | 2 +- 3 files changed, 314 insertions(+), 95 deletions(-) create mode 100644 comfy_api_nodes/apis/gemini_api.py diff --git a/comfy_api_nodes/apis/gemini_api.py b/comfy_api_nodes/apis/gemini_api.py new file mode 100644 index 000000000..138bf035d --- /dev/null +++ b/comfy_api_nodes/apis/gemini_api.py @@ -0,0 +1,19 @@ +from __future__ import annotations + +from typing import List, Optional + +from comfy_api_nodes.apis import GeminiGenerationConfig, GeminiContent, GeminiSafetySetting, GeminiSystemInstructionContent, GeminiTool, GeminiVideoMetadata +from pydantic import BaseModel + + +class GeminiImageGenerationConfig(GeminiGenerationConfig): + responseModalities: Optional[List[str]] = None + + +class GeminiImageGenerateContentRequest(BaseModel): + contents: List[GeminiContent] + generationConfig: Optional[GeminiImageGenerationConfig] = None + safetySettings: Optional[List[GeminiSafetySetting]] = None + systemInstruction: Optional[GeminiSystemInstructionContent] = None + tools: Optional[List[GeminiTool]] = None + videoMetadata: Optional[GeminiVideoMetadata] = None diff --git a/comfy_api_nodes/nodes_gemini.py b/comfy_api_nodes/nodes_gemini.py index 78c402a7a..baa379b75 100644 --- a/comfy_api_nodes/nodes_gemini.py +++ b/comfy_api_nodes/nodes_gemini.py @@ -4,11 +4,12 @@ See: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/infer """ from __future__ import annotations - import json import time import os import uuid +import base64 +from io import BytesIO from enum import Enum from typing import Optional, Literal @@ -25,6 +26,7 @@ from comfy_api_nodes.apis import ( GeminiPart, GeminiMimeType, ) +from 
comfy_api_nodes.apis.gemini_api import GeminiImageGenerationConfig, GeminiImageGenerateContentRequest from comfy_api_nodes.apis.client import ( ApiEndpoint, HttpMethod, @@ -35,6 +37,7 @@ from comfy_api_nodes.apinode_utils import ( audio_to_base64_string, video_to_base64_string, tensor_to_base64_string, + bytesio_to_image_tensor, ) @@ -53,6 +56,14 @@ class GeminiModel(str, Enum): gemini_2_5_flash = "gemini-2.5-flash" +class GeminiImageModel(str, Enum): + """ + Gemini Image Model Names allowed by comfy-api + """ + + gemini_2_5_flash_image_preview = "gemini-2.5-flash-image-preview" + + def get_gemini_endpoint( model: GeminiModel, ) -> ApiEndpoint[GeminiGenerateContentRequest, GeminiGenerateContentResponse]: @@ -75,6 +86,135 @@ def get_gemini_endpoint( ) +def get_gemini_image_endpoint( + model: GeminiImageModel, +) -> ApiEndpoint[GeminiGenerateContentRequest, GeminiGenerateContentResponse]: + """ + Get the API endpoint for a given Gemini model. + + Args: + model: The Gemini model to use, either as enum or string value. + + Returns: + ApiEndpoint configured for the specific Gemini model. + """ + if isinstance(model, str): + model = GeminiImageModel(model) + return ApiEndpoint( + path=f"{GEMINI_BASE_ENDPOINT}/{model.value}", + method=HttpMethod.POST, + request_model=GeminiImageGenerateContentRequest, + response_model=GeminiGenerateContentResponse, + ) + + +def create_image_parts(image_input: torch.Tensor) -> list[GeminiPart]: + """ + Convert image tensor input to Gemini API compatible parts. + + Args: + image_input: Batch of image tensors from ComfyUI. + + Returns: + List of GeminiPart objects containing the encoded images. + """ + image_parts: list[GeminiPart] = [] + for image_index in range(image_input.shape[0]): + image_as_b64 = tensor_to_base64_string( + image_input[image_index].unsqueeze(0) + ) + image_parts.append( + GeminiPart( + inlineData=GeminiInlineData( + mimeType=GeminiMimeType.image_png, + data=image_as_b64, + ) + ) + ) + return image_parts + + +def create_text_part(text: str) -> GeminiPart: + """ + Create a text part for the Gemini API request. + + Args: + text: The text content to include in the request. + + Returns: + A GeminiPart object with the text content. + """ + return GeminiPart(text=text) + + +def get_parts_from_response( + response: GeminiGenerateContentResponse +) -> list[GeminiPart]: + """ + Extract all parts from the Gemini API response. + + Args: + response: The API response from Gemini. + + Returns: + List of response parts from the first candidate. + """ + return response.candidates[0].content.parts + + +def get_parts_by_type( + response: GeminiGenerateContentResponse, part_type: Literal["text"] | str +) -> list[GeminiPart]: + """ + Filter response parts by their type. + + Args: + response: The API response from Gemini. + part_type: Type of parts to extract ("text" or a MIME type). + + Returns: + List of response parts matching the requested type. + """ + parts = [] + for part in get_parts_from_response(response): + if part_type == "text" and hasattr(part, "text") and part.text: + parts.append(part) + elif ( + hasattr(part, "inlineData") + and part.inlineData + and part.inlineData.mimeType == part_type + ): + parts.append(part) + # Skip parts that don't match the requested type + return parts + + +def get_text_from_response(response: GeminiGenerateContentResponse) -> str: + """ + Extract and concatenate all text parts from the response. + + Args: + response: The API response from Gemini. + + Returns: + Combined text from all text parts in the response. 
+ """ + parts = get_parts_by_type(response, "text") + return "\n".join([part.text for part in parts]) + + +def get_image_from_response(response: GeminiGenerateContentResponse) -> torch.Tensor: + image_tensors: list[torch.Tensor] = [] + parts = get_parts_by_type(response, "image/png") + for part in parts: + image_data = base64.b64decode(part.inlineData.data) + returned_image = bytesio_to_image_tensor(BytesIO(image_data)) + image_tensors.append(returned_image) + if len(image_tensors) == 0: + return torch.zeros((1,1024,1024,4)) + return torch.cat(image_tensors, dim=0) + + class GeminiNode(ComfyNodeABC): """ Node to generate text responses from a Gemini model. @@ -159,59 +299,6 @@ class GeminiNode(ComfyNodeABC): CATEGORY = "api node/text/Gemini" API_NODE = True - def get_parts_from_response( - self, response: GeminiGenerateContentResponse - ) -> list[GeminiPart]: - """ - Extract all parts from the Gemini API response. - - Args: - response: The API response from Gemini. - - Returns: - List of response parts from the first candidate. - """ - return response.candidates[0].content.parts - - def get_parts_by_type( - self, response: GeminiGenerateContentResponse, part_type: Literal["text"] | str - ) -> list[GeminiPart]: - """ - Filter response parts by their type. - - Args: - response: The API response from Gemini. - part_type: Type of parts to extract ("text" or a MIME type). - - Returns: - List of response parts matching the requested type. - """ - parts = [] - for part in self.get_parts_from_response(response): - if part_type == "text" and hasattr(part, "text") and part.text: - parts.append(part) - elif ( - hasattr(part, "inlineData") - and part.inlineData - and part.inlineData.mimeType == part_type - ): - parts.append(part) - # Skip parts that don't match the requested type - return parts - - def get_text_from_response(self, response: GeminiGenerateContentResponse) -> str: - """ - Extract and concatenate all text parts from the response. - - Args: - response: The API response from Gemini. - - Returns: - Combined text from all text parts in the response. - """ - parts = self.get_parts_by_type(response, "text") - return "\n".join([part.text for part in parts]) - def create_video_parts(self, video_input: IO.VIDEO, **kwargs) -> list[GeminiPart]: """ Convert video input to Gemini API compatible parts. @@ -271,43 +358,6 @@ class GeminiNode(ComfyNodeABC): ) return audio_parts - def create_image_parts(self, image_input: torch.Tensor) -> list[GeminiPart]: - """ - Convert image tensor input to Gemini API compatible parts. - - Args: - image_input: Batch of image tensors from ComfyUI. - - Returns: - List of GeminiPart objects containing the encoded images. - """ - image_parts: list[GeminiPart] = [] - for image_index in range(image_input.shape[0]): - image_as_b64 = tensor_to_base64_string( - image_input[image_index].unsqueeze(0) - ) - image_parts.append( - GeminiPart( - inlineData=GeminiInlineData( - mimeType=GeminiMimeType.image_png, - data=image_as_b64, - ) - ) - ) - return image_parts - - def create_text_part(self, text: str) -> GeminiPart: - """ - Create a text part for the Gemini API request. - - Args: - text: The text content to include in the request. - - Returns: - A GeminiPart object with the text content. 
- """ - return GeminiPart(text=text) - async def api_call( self, prompt: str, @@ -323,11 +373,11 @@ class GeminiNode(ComfyNodeABC): validate_string(prompt, strip_whitespace=False) # Create parts list with text prompt as the first part - parts: list[GeminiPart] = [self.create_text_part(prompt)] + parts: list[GeminiPart] = [create_text_part(prompt)] # Add other modal parts if images is not None: - image_parts = self.create_image_parts(images) + image_parts = create_image_parts(images) parts.extend(image_parts) if audio is not None: parts.extend(self.create_audio_parts(audio)) @@ -351,7 +401,7 @@ class GeminiNode(ComfyNodeABC): ).execute() # Get result output - output_text = self.get_text_from_response(response) + output_text = get_text_from_response(response) if unique_id and output_text: # Not a true chat history like the OpenAI Chat node. It is emulated so the frontend can show a copy button. render_spec = { @@ -462,12 +512,162 @@ class GeminiInputFiles(ComfyNodeABC): return (files,) +class GeminiImage(ComfyNodeABC): + """ + Node to generate text and image responses from a Gemini model. + + This node allows users to interact with Google's Gemini AI models, providing + multimodal inputs (text, images, files) to generate coherent + text and image responses. The node works with the latest Gemini models, handling the + API communication and response parsing. + """ + @classmethod + def INPUT_TYPES(cls) -> InputTypeDict: + return { + "required": { + "prompt": ( + IO.STRING, + { + "multiline": True, + "default": "", + "tooltip": "Text prompt for generation", + }, + ), + "model": ( + IO.COMBO, + { + "tooltip": "The Gemini model to use for generating responses.", + "options": [model.value for model in GeminiImageModel], + "default": GeminiImageModel.gemini_2_5_flash_image_preview.value, + }, + ), + "seed": ( + IO.INT, + { + "default": 42, + "min": 0, + "max": 0xFFFFFFFFFFFFFFFF, + "control_after_generate": True, + "tooltip": "When seed is fixed to a specific value, the model makes a best effort to provide the same response for repeated requests. Deterministic output isn't guaranteed. Also, changing the model or parameter settings, such as the temperature, can cause variations in the response even when you use the same seed value. By default, a random seed value is used.", + }, + ), + }, + "optional": { + "images": ( + IO.IMAGE, + { + "default": None, + "tooltip": "Optional image(s) to use as context for the model. To include multiple images, you can use the Batch Images node.", + }, + ), + "files": ( + "GEMINI_INPUT_FILES", + { + "default": None, + "tooltip": "Optional file(s) to use as context for the model. Accepts inputs from the Gemini Generate Content Input Files node.", + }, + ), + # TODO: later we can add this parameter later + # "n": ( + # IO.INT, + # { + # "default": 1, + # "min": 1, + # "max": 8, + # "step": 1, + # "display": "number", + # "tooltip": "How many images to generate", + # }, + # ), + }, + "hidden": { + "auth_token": "AUTH_TOKEN_COMFY_ORG", + "comfy_api_key": "API_KEY_COMFY_ORG", + "unique_id": "UNIQUE_ID", + }, + } + + RETURN_TYPES = (IO.IMAGE, IO.STRING) + FUNCTION = "api_call" + CATEGORY = "api node/image/Gemini" + DESCRIPTION = "Edit images synchronously via Google API." 
+ API_NODE = True + + async def api_call( + self, + prompt: str, + model: GeminiImageModel, + images: Optional[IO.IMAGE] = None, + files: Optional[list[GeminiPart]] = None, + n=1, + unique_id: Optional[str] = None, + **kwargs, + ): + # Validate inputs + validate_string(prompt, strip_whitespace=True, min_length=1) + # Create parts list with text prompt as the first part + parts: list[GeminiPart] = [create_text_part(prompt)] + + # Add other modal parts + if images is not None: + image_parts = create_image_parts(images) + parts.extend(image_parts) + if files is not None: + parts.extend(files) + + response = await SynchronousOperation( + endpoint=get_gemini_image_endpoint(model), + request=GeminiImageGenerateContentRequest( + contents=[ + GeminiContent( + role="user", + parts=parts, + ), + ], + generationConfig=GeminiImageGenerationConfig( + responseModalities=["TEXT","IMAGE"] + ) + ), + auth_kwargs=kwargs, + ).execute() + + output_image = get_image_from_response(response) + output_text = get_text_from_response(response) + if unique_id and output_text: + # Not a true chat history like the OpenAI Chat node. It is emulated so the frontend can show a copy button. + render_spec = { + "node_id": unique_id, + "component": "ChatHistoryWidget", + "props": { + "history": json.dumps( + [ + { + "prompt": prompt, + "response": output_text, + "response_id": str(uuid.uuid4()), + "timestamp": time.time(), + } + ] + ), + }, + } + PromptServer.instance.send_sync( + "display_component", + render_spec, + ) + + output_text = output_text or "Empty response from Gemini model..." + return (output_image, output_text,) + + NODE_CLASS_MAPPINGS = { "GeminiNode": GeminiNode, + "GeminiImageNode": GeminiImage, "GeminiInputFiles": GeminiInputFiles, } NODE_DISPLAY_NAME_MAPPINGS = { "GeminiNode": "Google Gemini", + "GeminiImageNode": "Google Gemini Image", "GeminiInputFiles": "Gemini Input Files", } diff --git a/requirements.txt b/requirements.txt index db59bb38c..174f3d4d1 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -comfyui-frontend-package==1.25.10 +comfyui-frontend-package==1.25.11 comfyui-workflow-templates==0.1.66 comfyui-embedded-docs==0.2.6 torch From 6a193ac557b2b35a6d2ea1916b0b8d5d9ee9b1ba Mon Sep 17 00:00:00 2001 From: ComfyUI Wiki Date: Wed, 27 Aug 2025 12:10:20 +0800 Subject: [PATCH 117/325] Update template to 0.1.68 (#9569) * Update template to 0.1.67 * Update template to 0.1.68 --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 174f3d4d1..93d88859d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.25.11 -comfyui-workflow-templates==0.1.66 +comfyui-workflow-templates==0.1.68 comfyui-embedded-docs==0.2.6 torch torchsde From 88aee596a30e9b80ca831c42a0ae70e0d22b61ae Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 26 Aug 2025 22:10:34 -0700 Subject: [PATCH 118/325] WIP Wan 2.2 S2V model. 
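Returning to the Gemini Image node added in #9566 above, the image handling reduces to decoding base64 inlineData parts into ComfyUI image tensors. A self-contained sketch of that step (it uses PIL directly instead of the bytesio_to_image_tensor helper):

import base64
from io import BytesIO

import numpy as np
import torch
from PIL import Image

def decode_inline_png(b64_data: str) -> torch.Tensor:
    # Decode one Gemini inlineData payload (base64 PNG) to a [1, H, W, C] float
    # tensor with values in [0, 1], matching ComfyUI's IMAGE convention.
    image = Image.open(BytesIO(base64.b64decode(b64_data))).convert("RGB")
    array = np.asarray(image).astype(np.float32) / 255.0
    return torch.from_numpy(array).unsqueeze(0)

# Example with a dummy 2x2 PNG payload:
buf = BytesIO()
Image.new("RGB", (2, 2), (255, 0, 0)).save(buf, format="PNG")
tensor = decode_inline_png(base64.b64encode(buf.getvalue()).decode())
print(tensor.shape)  # torch.Size([1, 2, 2, 3])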
(#9568) --- comfy/ldm/wan/model.py | 508 ++++++++++++++++++++++++++++++++++++-- comfy/model_base.py | 23 ++ comfy/model_detection.py | 2 + comfy/supported_models.py | 15 +- comfy_extras/nodes_wan.py | 175 +++++++++++++ 5 files changed, 707 insertions(+), 16 deletions(-) diff --git a/comfy/ldm/wan/model.py b/comfy/ldm/wan/model.py index 1885d9730..dedfb47e2 100644 --- a/comfy/ldm/wan/model.py +++ b/comfy/ldm/wan/model.py @@ -4,7 +4,7 @@ import math import torch import torch.nn as nn -from einops import repeat +from einops import rearrange from comfy.ldm.modules.attention import optimized_attention from comfy.ldm.flux.layers import EmbedND @@ -153,7 +153,10 @@ def repeat_e(e, x): repeats = x.size(1) // e.size(1) if repeats == 1: return e - return torch.repeat_interleave(e, repeats, dim=1) + if repeats * e.size(1) == x.size(1): + return torch.repeat_interleave(e, repeats, dim=1) + else: + return torch.repeat_interleave(e, repeats + 1, dim=1)[:, :x.size(1)] class WanAttentionBlock(nn.Module): @@ -573,6 +576,28 @@ class WanModel(torch.nn.Module): x = self.unpatchify(x, grid_sizes) return x + def rope_encode(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None): + patch_size = self.patch_size + t_len = ((t + (patch_size[0] // 2)) // patch_size[0]) + h_len = ((h + (patch_size[1] // 2)) // patch_size[1]) + w_len = ((w + (patch_size[2] // 2)) // patch_size[2]) + + if steps_t is None: + steps_t = t_len + if steps_h is None: + steps_h = h_len + if steps_w is None: + steps_w = w_len + + img_ids = torch.zeros((steps_t, steps_h, steps_w, 3), device=device, dtype=dtype) + img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(t_start, t_start + (t_len - 1), steps=steps_t, device=device, dtype=dtype).reshape(-1, 1, 1) + img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(0, h_len - 1, steps=steps_h, device=device, dtype=dtype).reshape(1, -1, 1) + img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(0, w_len - 1, steps=steps_w, device=device, dtype=dtype).reshape(1, 1, -1) + img_ids = img_ids.reshape(1, -1, img_ids.shape[-1]) + + freqs = self.rope_embedder(img_ids).movedim(1, 2) + return freqs + def forward(self, x, timestep, context, clip_fea=None, time_dim_concat=None, transformer_options={}, **kwargs): return comfy.patcher_extension.WrapperExecutor.new_class_executor( self._forward, @@ -584,26 +609,16 @@ class WanModel(torch.nn.Module): bs, c, t, h, w = x.shape x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size) - patch_size = self.patch_size - t_len = ((t + (patch_size[0] // 2)) // patch_size[0]) - h_len = ((h + (patch_size[1] // 2)) // patch_size[1]) - w_len = ((w + (patch_size[2] // 2)) // patch_size[2]) - + t_len = t if time_dim_concat is not None: time_dim_concat = comfy.ldm.common_dit.pad_to_patch_size(time_dim_concat, self.patch_size) x = torch.cat([x, time_dim_concat], dim=2) - t_len = ((x.shape[2] + (patch_size[0] // 2)) // patch_size[0]) + t_len = x.shape[2] if self.ref_conv is not None and "reference_latent" in kwargs: t_len += 1 - img_ids = torch.zeros((t_len, h_len, w_len, 3), device=x.device, dtype=x.dtype) - img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(0, t_len - 1, steps=t_len, device=x.device, dtype=x.dtype).reshape(-1, 1, 1) - img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).reshape(1, -1, 1) - img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).reshape(1, 1, -1) - img_ids 
= repeat(img_ids, "t h w c -> b (t h w) c", b=bs) - - freqs = self.rope_embedder(img_ids).movedim(1, 2) + freqs = self.rope_encode(t_len, h, w, device=x.device, dtype=x.dtype) return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, transformer_options=transformer_options, **kwargs)[:, :, :t, :h, :w] def unpatchify(self, x, grid_sizes): @@ -839,3 +854,466 @@ class CameraWanModel(WanModel): # unpatchify x = self.unpatchify(x, grid_sizes) return x + + +class CausalConv1d(nn.Module): + + def __init__(self, + chan_in, + chan_out, + kernel_size=3, + stride=1, + dilation=1, + pad_mode='replicate', + operations=None, + **kwargs): + super().__init__() + + self.pad_mode = pad_mode + padding = (kernel_size - 1, 0) # T + self.time_causal_padding = padding + + self.conv = operations.Conv1d( + chan_in, + chan_out, + kernel_size, + stride=stride, + dilation=dilation, + **kwargs) + + def forward(self, x): + x = torch.nn.functional.pad(x, self.time_causal_padding, mode=self.pad_mode) + return self.conv(x) + + +class MotionEncoder_tc(nn.Module): + + def __init__(self, + in_dim: int, + hidden_dim: int, + num_heads=int, + need_global=True, + dtype=None, + device=None, + operations=None,): + factory_kwargs = {"dtype": dtype, "device": device} + super().__init__() + + self.num_heads = num_heads + self.need_global = need_global + self.conv1_local = CausalConv1d(in_dim, hidden_dim // 4 * num_heads, 3, stride=1, operations=operations, **factory_kwargs) + if need_global: + self.conv1_global = CausalConv1d( + in_dim, hidden_dim // 4, 3, stride=1, operations=operations, **factory_kwargs) + self.norm1 = operations.LayerNorm( + hidden_dim // 4, + elementwise_affine=False, + eps=1e-6, + **factory_kwargs) + self.act = nn.SiLU() + self.conv2 = CausalConv1d(hidden_dim // 4, hidden_dim // 2, 3, stride=2, operations=operations, **factory_kwargs) + self.conv3 = CausalConv1d(hidden_dim // 2, hidden_dim, 3, stride=2, operations=operations, **factory_kwargs) + + if need_global: + self.final_linear = operations.Linear(hidden_dim, hidden_dim, **factory_kwargs) + + self.norm1 = operations.LayerNorm( + hidden_dim // 4, + elementwise_affine=False, + eps=1e-6, + **factory_kwargs) + + self.norm2 = operations.LayerNorm( + hidden_dim // 2, + elementwise_affine=False, + eps=1e-6, + **factory_kwargs) + + self.norm3 = operations.LayerNorm( + hidden_dim, elementwise_affine=False, eps=1e-6, **factory_kwargs) + + self.padding_tokens = nn.Parameter(torch.empty(1, 1, 1, hidden_dim, **factory_kwargs)) + + def forward(self, x): + x = rearrange(x, 'b t c -> b c t') + x_ori = x.clone() + b, c, t = x.shape + x = self.conv1_local(x) + x = rearrange(x, 'b (n c) t -> (b n) t c', n=self.num_heads) + x = self.norm1(x) + x = self.act(x) + x = rearrange(x, 'b t c -> b c t') + x = self.conv2(x) + x = rearrange(x, 'b c t -> b t c') + x = self.norm2(x) + x = self.act(x) + x = rearrange(x, 'b t c -> b c t') + x = self.conv3(x) + x = rearrange(x, 'b c t -> b t c') + x = self.norm3(x) + x = self.act(x) + x = rearrange(x, '(b n) t c -> b t n c', b=b) + padding = comfy.model_management.cast_to(self.padding_tokens, dtype=x.dtype, device=x.device).repeat(b, x.shape[1], 1, 1) + x = torch.cat([x, padding], dim=-2) + x_local = x.clone() + + if not self.need_global: + return x_local + + x = self.conv1_global(x_ori) + x = rearrange(x, 'b c t -> b t c') + x = self.norm1(x) + x = self.act(x) + x = rearrange(x, 'b t c -> b c t') + x = self.conv2(x) + x = rearrange(x, 'b c t -> b t c') + x = self.norm2(x) + x = self.act(x) + x = rearrange(x, 'b t c -> b 
c t') + x = self.conv3(x) + x = rearrange(x, 'b c t -> b t c') + x = self.norm3(x) + x = self.act(x) + x = self.final_linear(x) + x = rearrange(x, '(b n) t c -> b t n c', b=b) + + return x, x_local + + +class CausalAudioEncoder(nn.Module): + + def __init__(self, + dim=5120, + num_layers=25, + out_dim=2048, + video_rate=8, + num_token=4, + need_global=False, + dtype=None, + device=None, + operations=None): + super().__init__() + self.encoder = MotionEncoder_tc( + in_dim=dim, + hidden_dim=out_dim, + num_heads=num_token, + need_global=need_global, dtype=dtype, device=device, operations=operations) + weight = torch.empty((1, num_layers, 1, 1), dtype=dtype, device=device) + + self.weights = torch.nn.Parameter(weight) + self.act = torch.nn.SiLU() + + def forward(self, features): + # features B * num_layers * dim * video_length + weights = self.act(comfy.model_management.cast_to(self.weights, dtype=features.dtype, device=features.device)) + weights_sum = weights.sum(dim=1, keepdims=True) + weighted_feat = ((features * weights) / weights_sum).sum( + dim=1) # b dim f + weighted_feat = weighted_feat.permute(0, 2, 1) # b f dim + res = self.encoder(weighted_feat) # b f n dim + return res # b f n dim + + +class AdaLayerNorm(nn.Module): + def __init__(self, embedding_dim, output_dim=None, norm_elementwise_affine=False, norm_eps=1e-5, dtype=None, device=None, operations=None): + super().__init__() + + output_dim = output_dim or embedding_dim * 2 + + self.silu = nn.SiLU() + self.linear = operations.Linear(embedding_dim, output_dim, dtype=dtype, device=device) + self.norm = operations.LayerNorm(output_dim // 2, norm_eps, norm_elementwise_affine, dtype=dtype, device=device) + + def forward(self, x, temb): + temb = self.linear(self.silu(temb)) + shift, scale = temb.chunk(2, dim=1) + shift = shift[:, None, :] + scale = scale[:, None, :] + x = self.norm(x) * (1 + scale) + shift + return x + + +class AudioInjector_WAN(nn.Module): + + def __init__(self, + dim=2048, + num_heads=32, + inject_layer=[0, 27], + root_net=None, + enable_adain=False, + adain_dim=2048, + adain_mode=None, + dtype=None, + device=None, + operations=None): + super().__init__() + self.enable_adain = enable_adain + self.adain_mode = adain_mode + self.injected_block_id = {} + audio_injector_id = 0 + for inject_id in inject_layer: + self.injected_block_id[inject_id] = audio_injector_id + audio_injector_id += 1 + + self.injector = nn.ModuleList([ + WanT2VCrossAttention( + dim=dim, + num_heads=num_heads, + qk_norm=True, operation_settings={"operations": operations, "device": device, "dtype": dtype} + ) for _ in range(audio_injector_id) + ]) + self.injector_pre_norm_feat = nn.ModuleList([ + operations.LayerNorm( + dim, + elementwise_affine=False, + eps=1e-6, dtype=dtype, device=device + ) for _ in range(audio_injector_id) + ]) + self.injector_pre_norm_vec = nn.ModuleList([ + operations.LayerNorm( + dim, + elementwise_affine=False, + eps=1e-6, dtype=dtype, device=device + ) for _ in range(audio_injector_id) + ]) + if enable_adain: + self.injector_adain_layers = nn.ModuleList([ + AdaLayerNorm( + output_dim=dim * 2, embedding_dim=adain_dim, dtype=dtype, device=device, operations=operations) + for _ in range(audio_injector_id) + ]) + if adain_mode != "attn_norm": + self.injector_adain_output_layers = nn.ModuleList( + [operations.Linear(dim, dim, dtype=dtype, device=device) for _ in range(audio_injector_id)]) + + def forward(self, x, block_id, audio_emb, audio_emb_global, seq_len): + audio_attn_id = self.injected_block_id.get(block_id, None) + if 
audio_attn_id is None: + return x + + num_frames = audio_emb.shape[1] + input_hidden_states = rearrange(x[:, :seq_len], "b (t n) c -> (b t) n c", t=num_frames) + if self.enable_adain and self.adain_mode == "attn_norm": + audio_emb_global = rearrange(audio_emb_global, "b t n c -> (b t) n c") + adain_hidden_states = self.injector_adain_layers[audio_attn_id](input_hidden_states, temb=audio_emb_global[:, 0]) + attn_hidden_states = adain_hidden_states + else: + attn_hidden_states = self.injector_pre_norm_feat[audio_attn_id](input_hidden_states) + audio_emb = rearrange(audio_emb, "b t n c -> (b t) n c", t=num_frames) + attn_audio_emb = audio_emb + residual_out = self.injector[audio_attn_id](x=attn_hidden_states, context=attn_audio_emb) + residual_out = rearrange( + residual_out, "(b t) n c -> b (t n) c", t=num_frames) + x[:, :seq_len] = x[:, :seq_len] + residual_out + return x + + +class FramePackMotioner(nn.Module): + def __init__( + self, + inner_dim=1024, + num_heads=16, # Used to indicate the number of heads in the backbone network; unrelated to this module's design + zip_frame_buckets=[ + 1, 2, 16 + ], # Three numbers representing the number of frames sampled for patch operations from the nearest to the farthest frames + drop_mode="drop", # If not "drop", it will use "padd", meaning padding instead of deletion + dtype=None, + device=None, + operations=None): + super().__init__() + self.proj = operations.Conv3d(16, inner_dim, kernel_size=(1, 2, 2), stride=(1, 2, 2), dtype=dtype, device=device) + self.proj_2x = operations.Conv3d(16, inner_dim, kernel_size=(2, 4, 4), stride=(2, 4, 4), dtype=dtype, device=device) + self.proj_4x = operations.Conv3d(16, inner_dim, kernel_size=(4, 8, 8), stride=(4, 8, 8), dtype=dtype, device=device) + self.zip_frame_buckets = zip_frame_buckets + + self.inner_dim = inner_dim + self.num_heads = num_heads + + self.drop_mode = drop_mode + + def forward(self, motion_latents, rope_embedder, add_last_motion=2): + lat_height, lat_width = motion_latents.shape[3], motion_latents.shape[4] + padd_lat = torch.zeros(motion_latents.shape[0], 16, sum(self.zip_frame_buckets), lat_height, lat_width).to(device=motion_latents.device, dtype=motion_latents.dtype) + overlap_frame = min(padd_lat.shape[2], motion_latents.shape[2]) + if overlap_frame > 0: + padd_lat[:, :, -overlap_frame:] = motion_latents[:, :, -overlap_frame:] + + if add_last_motion < 2 and self.drop_mode != "drop": + zero_end_frame = sum(self.zip_frame_buckets[:len(self.zip_frame_buckets) - add_last_motion - 1]) + padd_lat[:, :, -zero_end_frame:] = 0 + + clean_latents_4x, clean_latents_2x, clean_latents_post = padd_lat[:, :, -sum(self.zip_frame_buckets):, :, :].split(self.zip_frame_buckets[::-1], dim=2) # 16, 2 ,1 + + # patchfy + clean_latents_post = self.proj(clean_latents_post).flatten(2).transpose(1, 2) + clean_latents_2x = self.proj_2x(clean_latents_2x) + l_2x_shape = clean_latents_2x.shape + clean_latents_2x = clean_latents_2x.flatten(2).transpose(1, 2) + clean_latents_4x = self.proj_4x(clean_latents_4x) + l_4x_shape = clean_latents_4x.shape + clean_latents_4x = clean_latents_4x.flatten(2).transpose(1, 2) + + if add_last_motion < 2 and self.drop_mode == "drop": + clean_latents_post = clean_latents_post[:, : + 0] if add_last_motion < 2 else clean_latents_post + clean_latents_2x = clean_latents_2x[:, : + 0] if add_last_motion < 1 else clean_latents_2x + + motion_lat = torch.cat([clean_latents_post, clean_latents_2x, clean_latents_4x], dim=1) + + rope_post = rope_embedder.rope_encode(1, lat_height, lat_width, 
t_start=-1, device=motion_latents.device, dtype=motion_latents.dtype) + rope_2x = rope_embedder.rope_encode(1, lat_height, lat_width, t_start=-3, steps_h=l_2x_shape[-2], steps_w=l_2x_shape[-1], device=motion_latents.device, dtype=motion_latents.dtype) + rope_4x = rope_embedder.rope_encode(4, lat_height, lat_width, t_start=-19, steps_h=l_4x_shape[-2], steps_w=l_4x_shape[-1], device=motion_latents.device, dtype=motion_latents.dtype) + + rope = torch.cat([rope_post, rope_2x, rope_4x], dim=1) + return motion_lat, rope + + +class WanModel_S2V(WanModel): + def __init__(self, + model_type='s2v', + patch_size=(1, 2, 2), + text_len=512, + in_dim=16, + dim=2048, + ffn_dim=8192, + freq_dim=256, + text_dim=4096, + out_dim=16, + num_heads=16, + num_layers=32, + window_size=(-1, -1), + qk_norm=True, + cross_attn_norm=True, + eps=1e-6, + audio_dim=1024, + num_audio_token=4, + enable_adain=True, + cond_dim=16, + audio_inject_layers=[0, 4, 8, 12, 16, 20, 24, 27, 30, 33, 36, 39], + adain_mode="attn_norm", + framepack_drop_mode="padd", + image_model=None, + device=None, + dtype=None, + operations=None, + ): + + super().__init__(model_type='t2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, image_model=image_model, device=device, dtype=dtype, operations=operations) + + self.trainable_cond_mask = operations.Embedding(3, self.dim, device=device, dtype=dtype) + + self.casual_audio_encoder = CausalAudioEncoder( + dim=audio_dim, + out_dim=self.dim, + num_token=num_audio_token, + need_global=enable_adain, dtype=dtype, device=device, operations=operations) + + if cond_dim > 0: + self.cond_encoder = operations.Conv3d( + cond_dim, + self.dim, + kernel_size=self.patch_size, + stride=self.patch_size, device=device, dtype=dtype) + + self.audio_injector = AudioInjector_WAN( + dim=self.dim, + num_heads=self.num_heads, + inject_layer=audio_inject_layers, + root_net=self, + enable_adain=enable_adain, + adain_dim=self.dim, + adain_mode=adain_mode, + dtype=dtype, device=device, operations=operations + ) + + self.frame_packer = FramePackMotioner( + inner_dim=self.dim, + num_heads=self.num_heads, + zip_frame_buckets=[1, 2, 16], + drop_mode=framepack_drop_mode, + dtype=dtype, device=device, operations=operations) + + def forward_orig( + self, + x, + t, + context, + audio_embed=None, + reference_latent=None, + control_video=None, + reference_motion=None, + clip_fea=None, + freqs=None, + transformer_options={}, + **kwargs, + ): + if audio_embed is not None: + num_embeds = x.shape[-3] * 4 + audio_emb_global, audio_emb = self.casual_audio_encoder(audio_embed[:, :, :, :num_embeds]) + else: + audio_emb = None + + # embeddings + x = self.patch_embedding(x.float()).to(x.dtype) + if control_video is not None: + x = x + self.cond_encoder(control_video) + + if t.ndim == 1: + t = t.unsqueeze(1).repeat(1, x.shape[2]) + + grid_sizes = x.shape[2:] + x = x.flatten(2).transpose(1, 2) + seq_len = x.size(1) + + cond_mask_weight = comfy.model_management.cast_to(self.trainable_cond_mask.weight, dtype=x.dtype, device=x.device).unsqueeze(1).unsqueeze(1) + x = x + cond_mask_weight[0] + + if reference_latent is not None: + ref = self.patch_embedding(reference_latent.float()).to(x.dtype) + ref = ref.flatten(2).transpose(1, 2) + freqs_ref = self.rope_encode(reference_latent.shape[-3], reference_latent.shape[-2], reference_latent.shape[-1], 
t_start=30, device=x.device, dtype=x.dtype) + ref = ref + cond_mask_weight[1] + x = torch.cat([x, ref], dim=1) + freqs = torch.cat([freqs, freqs_ref], dim=1) + t = torch.cat([t, torch.zeros((t.shape[0], reference_latent.shape[-3]), device=t.device, dtype=t.dtype)], dim=1) + + if reference_motion is not None: + motion_encoded, freqs_motion = self.frame_packer(reference_motion, self) + motion_encoded = motion_encoded + cond_mask_weight[2] + x = torch.cat([x, motion_encoded], dim=1) + freqs = torch.cat([freqs, freqs_motion], dim=1) + + t = torch.repeat_interleave(t, 2, dim=1) + t = torch.cat([t, torch.zeros((t.shape[0], 3), device=t.device, dtype=t.dtype)], dim=1) + + # time embeddings + e = self.time_embedding( + sinusoidal_embedding_1d(self.freq_dim, t.flatten()).to(dtype=x[0].dtype)) + e = e.reshape(t.shape[0], -1, e.shape[-1]) + e0 = self.time_projection(e).unflatten(2, (6, self.dim)) + + # context + context = self.text_embedding(context) + + + patches_replace = transformer_options.get("patches_replace", {}) + blocks_replace = patches_replace.get("dit", {}) + for i, block in enumerate(self.blocks): + if ("double_block", i) in blocks_replace: + def block_wrap(args): + out = {} + out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"]) + return out + out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap}) + x = out["img"] + else: + x = block(x, e=e0, freqs=freqs, context=context) + if audio_emb is not None: + x = self.audio_injector(x, i, audio_emb, audio_emb_global, seq_len) + # head + x = self.head(x, e) + + # unpatchify + x = self.unpatchify(x, grid_sizes) + return x diff --git a/comfy/model_base.py b/comfy/model_base.py index 6c861b15e..18d55c1c4 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1201,6 +1201,29 @@ class WAN21_Camera(WAN21): out['camera_conditions'] = comfy.conds.CONDRegular(camera_conditions) return out +class WAN22_S2V(WAN21): + def __init__(self, model_config, model_type=ModelType.FLOW, device=None): + super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel_S2V) + + def extra_conds(self, **kwargs): + out = super().extra_conds(**kwargs) + audio_embed = kwargs.get("audio_embed", None) + if audio_embed is not None: + out['audio_embed'] = comfy.conds.CONDRegular(audio_embed) + + reference_latents = kwargs.get("reference_latents", None) + if reference_latents is not None: + out['reference_latent'] = comfy.conds.CONDRegular(self.process_latent_in(reference_latents[-1])) + + reference_motion = kwargs.get("reference_motion", None) + if reference_motion is not None: + out['reference_motion'] = comfy.conds.CONDRegular(self.process_latent_in(reference_motion)) + + control_video = kwargs.get("control_video", None) + if control_video is not None: + out['control_video'] = comfy.conds.CONDRegular(self.process_latent_in(control_video)) + return out + class WAN22(BaseModel): def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None): super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 0caff53e0..9f3ab64df 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -368,6 +368,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["model_type"] = "camera" else: dit_config["model_type"] = "camera_2.2" + elif 
'{}casual_audio_encoder.encoder.final_linear.weight'.format(key_prefix) in state_dict_keys: + dit_config["model_type"] = "s2v" else: if '{}img_emb.proj.0.bias'.format(key_prefix) in state_dict_keys: dit_config["model_type"] = "i2v" diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 7ed6dfd69..ce571e6cb 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -1072,6 +1072,19 @@ class WAN21_Vace(WAN21_T2V): out = model_base.WAN21_Vace(self, image_to_video=False, device=device) return out +class WAN22_S2V(WAN21_T2V): + unet_config = { + "image_model": "wan2.1", + "model_type": "s2v", + } + + def __init__(self, unet_config): + super().__init__(unet_config) + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.WAN22_S2V(self, device=device) + return out + class WAN22_T2V(WAN21_T2V): unet_config = { "image_model": "wan2.1", @@ -1272,6 +1285,6 @@ class QwenImage(supported_models_base.BASE): return supported_models_base.ClipTarget(comfy.text_encoders.qwen_image.QwenImageTokenizer, comfy.text_encoders.qwen_image.te(**hunyuan_detect)) -models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma, ACEStep, Omnigen2, QwenImage] +models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma, ACEStep, Omnigen2, QwenImage] models += [SVD_img2vid] diff --git a/comfy_extras/nodes_wan.py b/comfy_extras/nodes_wan.py index 0fff02f76..89ff74d85 100644 --- a/comfy_extras/nodes_wan.py +++ b/comfy_extras/nodes_wan.py @@ -786,6 +786,180 @@ class WanTrackToVideo(io.ComfyNode): return io.NodeOutput(positive, negative, out_latent) +def linear_interpolation(features, input_fps, output_fps, output_len=None): + """ + features: shape=[1, T, 512] + input_fps: fps for audio, f_a + output_fps: fps for video, f_m + output_len: video length + """ + features = features.transpose(1, 2) # [1, 512, T] + seq_len = features.shape[2] / float(input_fps) # T/f_a + if output_len is None: + output_len = int(seq_len * output_fps) # f_m*T/f_a + output_features = torch.nn.functional.interpolate( + features, size=output_len, align_corners=True, + mode='linear') # [1, 512, output_len] + return output_features.transpose(1, 2) # [1, output_len, 512] + + +def get_sample_indices(original_fps, + total_frames, + target_fps, + num_sample, + fixed_start=None): + required_duration = num_sample / target_fps + required_origin_frames = 
int(np.ceil(required_duration * original_fps)) + if required_duration > total_frames / original_fps: + raise ValueError("required_duration must be less than video length") + + if not fixed_start is None and fixed_start >= 0: + start_frame = fixed_start + else: + max_start = total_frames - required_origin_frames + if max_start < 0: + raise ValueError("video length is too short") + start_frame = np.random.randint(0, max_start + 1) + start_time = start_frame / original_fps + + end_time = start_time + required_duration + time_points = np.linspace(start_time, end_time, num_sample, endpoint=False) + + frame_indices = np.round(np.array(time_points) * original_fps).astype(int) + frame_indices = np.clip(frame_indices, 0, total_frames - 1) + return frame_indices + + +def get_audio_embed_bucket_fps(audio_embed, fps=16, batch_frames=81, m=0, video_rate=30): + num_layers, audio_frame_num, audio_dim = audio_embed.shape + + if num_layers > 1: + return_all_layers = True + else: + return_all_layers = False + + scale = video_rate / fps + + min_batch_num = int(audio_frame_num / (batch_frames * scale)) + 1 + + bucket_num = min_batch_num * batch_frames + padd_audio_num = math.ceil(min_batch_num * batch_frames / fps * video_rate) - audio_frame_num + batch_idx = get_sample_indices( + original_fps=video_rate, + total_frames=audio_frame_num + padd_audio_num, + target_fps=fps, + num_sample=bucket_num, + fixed_start=0) + batch_audio_eb = [] + audio_sample_stride = int(video_rate / fps) + for bi in batch_idx: + if bi < audio_frame_num: + + chosen_idx = list( + range(bi - m * audio_sample_stride, bi + (m + 1) * audio_sample_stride, audio_sample_stride)) + chosen_idx = [0 if c < 0 else c for c in chosen_idx] + chosen_idx = [ + audio_frame_num - 1 if c >= audio_frame_num else c + for c in chosen_idx + ] + + if return_all_layers: + frame_audio_embed = audio_embed[:, chosen_idx].flatten( + start_dim=-2, end_dim=-1) + else: + frame_audio_embed = audio_embed[0][chosen_idx].flatten() + else: + frame_audio_embed = torch.zeros([audio_dim * (2 * m + 1)], device=audio_embed.device) if not return_all_layers \ + else torch.zeros([num_layers, audio_dim * (2 * m + 1)], device=audio_embed.device) + batch_audio_eb.append(frame_audio_embed) + batch_audio_eb = torch.cat([c.unsqueeze(0) for c in batch_audio_eb], dim=0) + + return batch_audio_eb, min_batch_num + + +class WanSoundImageToVideo(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="WanSoundImageToVideo", + category="conditioning/video_models", + inputs=[ + io.Conditioning.Input("positive"), + io.Conditioning.Input("negative"), + io.Vae.Input("vae"), + io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("length", default=77, min=1, max=nodes.MAX_RESOLUTION, step=4), + io.Int.Input("batch_size", default=1, min=1, max=4096), + io.AudioEncoderOutput.Input("audio_encoder_output", optional=True), + io.Image.Input("ref_image", optional=True), + io.Image.Input("control_video", optional=True), + io.Image.Input("ref_motion", optional=True), + ], + outputs=[ + io.Conditioning.Output(display_name="positive"), + io.Conditioning.Output(display_name="negative"), + io.Latent.Output(display_name="latent"), + ], + is_experimental=True, + ) + + @classmethod + def execute(cls, positive, negative, vae, width, height, length, batch_size, ref_image=None, audio_encoder_output=None, control_video=None, ref_motion=None) -> io.NodeOutput: + 
latent_t = ((length - 1) // 4) + 1 + if audio_encoder_output is not None: + feat = torch.cat(audio_encoder_output["encoded_audio_all_layers"]) + video_rate = 30 + fps = 16 + feat = linear_interpolation(feat, input_fps=50, output_fps=video_rate) + audio_embed_bucket, num_repeat = get_audio_embed_bucket_fps(feat, fps=fps, batch_frames=latent_t * 4, m=0, video_rate=video_rate) + audio_embed_bucket = audio_embed_bucket.unsqueeze(0) + if len(audio_embed_bucket.shape) == 3: + audio_embed_bucket = audio_embed_bucket.permute(0, 2, 1) + elif len(audio_embed_bucket.shape) == 4: + audio_embed_bucket = audio_embed_bucket.permute(0, 2, 3, 1) + + positive = node_helpers.conditioning_set_values(positive, {"audio_embed": audio_embed_bucket}) + negative = node_helpers.conditioning_set_values(negative, {"audio_embed": audio_embed_bucket}) + + if ref_image is not None: + ref_image = comfy.utils.common_upscale(ref_image[:1].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) + ref_latent = vae.encode(ref_image[:, :, :, :3]) + positive = node_helpers.conditioning_set_values(positive, {"reference_latents": [ref_latent]}, append=True) + negative = node_helpers.conditioning_set_values(negative, {"reference_latents": [ref_latent]}, append=True) + + if ref_motion is not None: + if ref_motion.shape[0] > 73: + ref_motion = ref_motion[-73:] + + ref_motion = comfy.utils.common_upscale(ref_motion.movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) + + if ref_motion.shape[0] < 73: + r = torch.ones([73, height, width, 3]) * 0.5 + r[-ref_motion.shape[0]:] = ref_motion + ref_motion = r + + ref_motion = vae.encode(ref_motion[:, :, :, :3]) + positive = node_helpers.conditioning_set_values(positive, {"reference_motion": ref_motion}) + negative = node_helpers.conditioning_set_values(negative, {"reference_motion": ref_motion}) + + latent = torch.zeros([batch_size, 16, latent_t, height // 8, width // 8], device=comfy.model_management.intermediate_device()) + + control_video_out = comfy.latent_formats.Wan21().process_out(torch.zeros_like(latent)) + if control_video is not None: + control_video = comfy.utils.common_upscale(control_video[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) + control_video = vae.encode(control_video[:, :, :, :3]) + control_video_out[:, :, :control_video.shape[2]] = control_video + + # TODO: check if zero is better than none if none provided + positive = node_helpers.conditioning_set_values(positive, {"control_video": control_video_out}) + negative = node_helpers.conditioning_set_values(negative, {"control_video": control_video_out}) + + out_latent = {} + out_latent["samples"] = latent + return io.NodeOutput(positive, negative, out_latent) + + class Wan22ImageToVideoLatent(io.ComfyNode): @classmethod def define_schema(cls): @@ -844,6 +1018,7 @@ class WanExtension(ComfyExtension): TrimVideoLatent, WanCameraImageToVideo, WanPhantomSubjectToVideo, + WanSoundImageToVideo, Wan22ImageToVideoLatent, ] From 31a37686d02aeaba8ea827933832be7601b31fac Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 27 Aug 2025 09:44:29 -0700 Subject: [PATCH 119/325] Negative audio in s2v should be zeros. 
(#9578) --- comfy_extras/nodes_wan.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy_extras/nodes_wan.py b/comfy_extras/nodes_wan.py index 89ff74d85..312260f00 100644 --- a/comfy_extras/nodes_wan.py +++ b/comfy_extras/nodes_wan.py @@ -920,7 +920,7 @@ class WanSoundImageToVideo(io.ComfyNode): audio_embed_bucket = audio_embed_bucket.permute(0, 2, 3, 1) positive = node_helpers.conditioning_set_values(positive, {"audio_embed": audio_embed_bucket}) - negative = node_helpers.conditioning_set_values(negative, {"audio_embed": audio_embed_bucket}) + negative = node_helpers.conditioning_set_values(negative, {"audio_embed": audio_embed_bucket * 0.0}) if ref_image is not None: ref_image = comfy.utils.common_upscale(ref_image[:1].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) From b20ba1f27cbd4e1c84cf8ec72b345723de9e7c80 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Thu, 28 Aug 2025 00:45:02 +0800 Subject: [PATCH 120/325] Fix #9537 (#9576) --- comfy/weight_adapter/lokr.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/comfy/weight_adapter/lokr.py b/comfy/weight_adapter/lokr.py index 49b0be55f..563c835f5 100644 --- a/comfy/weight_adapter/lokr.py +++ b/comfy/weight_adapter/lokr.py @@ -97,6 +97,9 @@ class LoKrAdapter(WeightAdapterBase): (mat1, mat2, alpha, None, None, None, None, None, None) ) + def to_train(self): + return LokrDiff(self.weights) + @classmethod def load( cls, From b5ac6ed7ce73294e0025ffe3b16452d8434b83c7 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 27 Aug 2025 12:26:28 -0700 Subject: [PATCH 121/325] Fixes to make controlnet type models work on qwen edit and kontext. (#9581) --- comfy/ldm/flux/model.py | 4 ++-- comfy/ldm/qwen_image/model.py | 2 +- comfy_extras/nodes_model_patch.py | 8 +++++--- 3 files changed, 8 insertions(+), 6 deletions(-) diff --git a/comfy/ldm/flux/model.py b/comfy/ldm/flux/model.py index 0a77fa097..1344c3a57 100644 --- a/comfy/ldm/flux/model.py +++ b/comfy/ldm/flux/model.py @@ -158,7 +158,7 @@ class Flux(nn.Module): if i < len(control_i): add = control_i[i] if add is not None: - img += add + img[:, :add.shape[1]] += add if img.dtype == torch.float16: img = torch.nan_to_num(img, nan=0.0, posinf=65504, neginf=-65504) @@ -189,7 +189,7 @@ class Flux(nn.Module): if i < len(control_o): add = control_o[i] if add is not None: - img[:, txt.shape[1] :, ...] += add + img[:, txt.shape[1] : txt.shape[1] + add.shape[1], ...] += add img = img[:, txt.shape[1] :, ...] 
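# A minimal sketch of the prefix-slice pattern used in this patch, assuming (as
# with Kontext and Qwen edit) that reference-image tokens are appended after the
# generated-image tokens: the controlnet residual only covers the generated
# tokens, so it is added to a prefix slice instead of the whole sequence. The
# toy sizes below are illustrative, not taken from the models.
import torch
img = torch.zeros(1, 6, 4)    # 4 generated tokens followed by 2 appended reference tokens
add = torch.ones(1, 4, 4)     # residual computed for the 4 generated tokens only
img[:, :add.shape[1]] += add  # reference tokens stay untouched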
diff --git a/comfy/ldm/qwen_image/model.py b/comfy/ldm/qwen_image/model.py index 57a458210..04071f31c 100644 --- a/comfy/ldm/qwen_image/model.py +++ b/comfy/ldm/qwen_image/model.py @@ -459,7 +459,7 @@ class QwenImageTransformer2DModel(nn.Module): if i < len(control_i): add = control_i[i] if add is not None: - hidden_states += add + hidden_states[:, :add.shape[1]] += add hidden_states = self.norm_out(hidden_states, temb) hidden_states = self.proj_out(hidden_states) diff --git a/comfy_extras/nodes_model_patch.py b/comfy_extras/nodes_model_patch.py index 3eaada9bc..32c40ced3 100644 --- a/comfy_extras/nodes_model_patch.py +++ b/comfy_extras/nodes_model_patch.py @@ -89,6 +89,7 @@ class DiffSynthCnetPatch: self.strength = strength self.mask = mask self.encoded_image = model_patch.model.process_input_latent_image(self.encode_latent_cond(image)) + self.encoded_image_size = (image.shape[1], image.shape[2]) def encode_latent_cond(self, image): latent_image = self.vae.encode(image) @@ -106,14 +107,15 @@ class DiffSynthCnetPatch: x = kwargs.get("x") img = kwargs.get("img") block_index = kwargs.get("block_index") - if self.encoded_image is None or self.encoded_image.shape[1:] != img.shape[1:]: - spacial_compression = self.vae.spacial_compression_encode() + spacial_compression = self.vae.spacial_compression_encode() + if self.encoded_image is None or self.encoded_image_size != (x.shape[-1] * spacial_compression, x.shape[-2] * spacial_compression): image_scaled = comfy.utils.common_upscale(self.image.movedim(-1, 1), x.shape[-1] * spacial_compression, x.shape[-2] * spacial_compression, "area", "center") loaded_models = comfy.model_management.loaded_models(only_currently_used=True) self.encoded_image = self.model_patch.model.process_input_latent_image(self.encode_latent_cond(image_scaled.movedim(1, -1))) + self.encoded_image_size = (image_scaled.shape[-2], image_scaled.shape[-1]) comfy.model_management.load_models_gpu(loaded_models) - img = img + (self.model_patch.model.control_block(img, self.encoded_image.to(img.dtype), block_index) * self.strength) + img[:, :self.encoded_image.shape[1]] += (self.model_patch.model.control_block(img[:, :self.encoded_image.shape[1]], self.encoded_image.to(img.dtype), block_index) * self.strength) kwargs['img'] = img return kwargs From 496888fd68813033c260195bf70e4d11181e5454 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 27 Aug 2025 13:06:40 -0700 Subject: [PATCH 122/325] Improve s2v performance when generating videos longer than 120 frames. 
(#9582) --- comfy/ldm/wan/model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/comfy/ldm/wan/model.py b/comfy/ldm/wan/model.py index dedfb47e2..e70446c86 100644 --- a/comfy/ldm/wan/model.py +++ b/comfy/ldm/wan/model.py @@ -1255,6 +1255,7 @@ class WanModel_S2V(WanModel): audio_emb = None # embeddings + bs, _, time, height, width = x.shape x = self.patch_embedding(x.float()).to(x.dtype) if control_video is not None: x = x + self.cond_encoder(control_video) @@ -1272,7 +1273,7 @@ class WanModel_S2V(WanModel): if reference_latent is not None: ref = self.patch_embedding(reference_latent.float()).to(x.dtype) ref = ref.flatten(2).transpose(1, 2) - freqs_ref = self.rope_encode(reference_latent.shape[-3], reference_latent.shape[-2], reference_latent.shape[-1], t_start=30, device=x.device, dtype=x.dtype) + freqs_ref = self.rope_encode(reference_latent.shape[-3], reference_latent.shape[-2], reference_latent.shape[-1], t_start=max(30, time + 9), device=x.device, dtype=x.dtype) ref = ref + cond_mask_weight[1] x = torch.cat([x, ref], dim=1) freqs = torch.cat([freqs, freqs_ref], dim=1) @@ -1296,7 +1297,6 @@ class WanModel_S2V(WanModel): # context context = self.text_embedding(context) - patches_replace = transformer_options.get("patches_replace", {}) blocks_replace = patches_replace.get("dit", {}) for i, block in enumerate(self.blocks): From 491755325cc189d0aa1513b12fac738c87e38de6 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 27 Aug 2025 16:02:42 -0700 Subject: [PATCH 123/325] Better s2v memory estimation. (#9584) --- comfy/ldm/wan/model.py | 2 ++ comfy/model_base.py | 25 +++++++++++++++++++++++-- 2 files changed, 25 insertions(+), 2 deletions(-) diff --git a/comfy/ldm/wan/model.py b/comfy/ldm/wan/model.py index e70446c86..47857dc2b 100644 --- a/comfy/ldm/wan/model.py +++ b/comfy/ldm/wan/model.py @@ -1278,6 +1278,7 @@ class WanModel_S2V(WanModel): x = torch.cat([x, ref], dim=1) freqs = torch.cat([freqs, freqs_ref], dim=1) t = torch.cat([t, torch.zeros((t.shape[0], reference_latent.shape[-3]), device=t.device, dtype=t.dtype)], dim=1) + del ref, freqs_ref if reference_motion is not None: motion_encoded, freqs_motion = self.frame_packer(reference_motion, self) @@ -1287,6 +1288,7 @@ class WanModel_S2V(WanModel): t = torch.repeat_interleave(t, 2, dim=1) t = torch.cat([t, torch.zeros((t.shape[0], 3), device=t.device, dtype=t.dtype)], dim=1) + del motion_encoded, freqs_motion # time embeddings e = self.time_embedding( diff --git a/comfy/model_base.py b/comfy/model_base.py index 18d55c1c4..ce29fdc49 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -150,6 +150,7 @@ class BaseModel(torch.nn.Module): logging.debug("adm {}".format(self.adm_channels)) self.memory_usage_factor = model_config.memory_usage_factor self.memory_usage_factor_conds = () + self.memory_usage_shape_process = {} def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs): return comfy.patcher_extension.WrapperExecutor.new_class_executor( @@ -350,8 +351,15 @@ class BaseModel(torch.nn.Module): input_shapes = [input_shape] for c in self.memory_usage_factor_conds: shape = cond_shapes.get(c, None) - if shape is not None and len(shape) > 0: - input_shapes += shape + if shape is not None: + if c in self.memory_usage_shape_process: + out = [] + for s in shape: + out.append(self.memory_usage_shape_process[c](s)) + shape = out + + if len(shape) > 0: + input_shapes += shape if 
comfy.model_management.xformers_enabled() or comfy.model_management.pytorch_attention_flash_attention(): dtype = self.get_dtype() @@ -1204,6 +1212,8 @@ class WAN21_Camera(WAN21): class WAN22_S2V(WAN21): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel_S2V) + self.memory_usage_factor_conds = ("reference_latent", "reference_motion") + self.memory_usage_shape_process = {"reference_motion": lambda shape: [shape[0], shape[1], 1.5, shape[-2], shape[-1]]} def extra_conds(self, **kwargs): out = super().extra_conds(**kwargs) @@ -1224,6 +1234,17 @@ class WAN22_S2V(WAN21): out['control_video'] = comfy.conds.CONDRegular(self.process_latent_in(control_video)) return out + def extra_conds_shapes(self, **kwargs): + out = {} + ref_latents = kwargs.get("reference_latents", None) + if ref_latents is not None: + out['reference_latent'] = list([1, 16, sum(map(lambda a: math.prod(a.size()), ref_latents)) // 16]) + + reference_motion = kwargs.get("reference_motion", None) + if reference_motion is not None: + out['reference_motion'] = reference_motion.shape + return out + class WAN22(BaseModel): def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None): super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel) From 3aad339b63f03e17dc6ebae035b90afc2fefb627 Mon Sep 17 00:00:00 2001 From: Gangin Park Date: Thu, 28 Aug 2025 08:07:31 +0900 Subject: [PATCH 124/325] Add DPM++ 2M SDE Heun (RES) sampler (#9542) --- comfy/k_diffusion/sampling.py | 15 +++++++++++++++ comfy/samplers.py | 2 +- 2 files changed, 16 insertions(+), 1 deletion(-) mode change 100644 => 100755 comfy/samplers.py diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py index a2bc492fd..fe6844b17 100644 --- a/comfy/k_diffusion/sampling.py +++ b/comfy/k_diffusion/sampling.py @@ -853,6 +853,11 @@ def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disabl return x +@torch.no_grad() +def sample_dpmpp_2m_sde_heun(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='heun'): + return sample_dpmpp_2m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, solver_type=solver_type) + + @torch.no_grad() def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None): """DPM-Solver++(3M) SDE.""" @@ -925,6 +930,16 @@ def sample_dpmpp_3m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, di return sample_dpmpp_3m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler) +@torch.no_grad() +def sample_dpmpp_2m_sde_heun_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='heun'): + if len(sigmas) <= 1: + return x + extra_args = {} if extra_args is None else extra_args + sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max() + noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler + return sample_dpmpp_2m_sde_heun(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, 
solver_type=solver_type) + + @torch.no_grad() def sample_dpmpp_2m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'): if len(sigmas) <= 1: diff --git a/comfy/samplers.py b/comfy/samplers.py old mode 100644 new mode 100755 index c7dfef4ea..b3202cec6 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -729,7 +729,7 @@ class Sampler: KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_cfg_pp", "heun", "heunpp2","dpm_2", "dpm_2_ancestral", "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_2s_ancestral_cfg_pp", "dpmpp_sde", "dpmpp_sde_gpu", - "dpmpp_2m", "dpmpp_2m_cfg_pp", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm", + "dpmpp_2m", "dpmpp_2m_cfg_pp", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_2m_sde_heun", "dpmpp_2m_sde_heun_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm", "ipndm", "ipndm_v", "deis", "res_multistep", "res_multistep_cfg_pp", "res_multistep_ancestral", "res_multistep_ancestral_cfg_pp", "gradient_estimation", "gradient_estimation_cfg_pp", "er_sde", "seeds_2", "seeds_3", "sa_solver", "sa_solver_pece"] From 38f697d953c3989db67e543795768bf954ae0231 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 27 Aug 2025 19:28:10 -0700 Subject: [PATCH 125/325] Add a LatentConcat node. (#9587) --- comfy_extras/nodes_latent.py | 33 +++++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/comfy_extras/nodes_latent.py b/comfy_extras/nodes_latent.py index f33ed1bee..247d886a1 100644 --- a/comfy_extras/nodes_latent.py +++ b/comfy_extras/nodes_latent.py @@ -105,6 +105,38 @@ class LatentInterpolate: samples_out["samples"] = st * (m1 * ratio + m2 * (1.0 - ratio)) return (samples_out,) +class LatentConcat: + @classmethod + def INPUT_TYPES(s): + return {"required": { "samples1": ("LATENT",), "samples2": ("LATENT",), "dim": (["x", "-x", "y", "-y", "t", "-t"], )}} + + RETURN_TYPES = ("LATENT",) + FUNCTION = "op" + + CATEGORY = "latent/advanced" + + def op(self, samples1, samples2, dim): + samples_out = samples1.copy() + + s1 = samples1["samples"] + s2 = samples2["samples"] + s2 = comfy.utils.repeat_to_batch_size(s2, s1.shape[0]) + + if "-" in dim: + c = (s2, s1) + else: + c = (s1, s2) + + if "x" in dim: + dim = -1 + elif "y" in dim: + dim = -2 + elif "t" in dim: + dim = -3 + + samples_out["samples"] = torch.cat(c, dim=dim) + return (samples_out,) + class LatentBatch: @classmethod def INPUT_TYPES(s): @@ -279,6 +311,7 @@ NODE_CLASS_MAPPINGS = { "LatentSubtract": LatentSubtract, "LatentMultiply": LatentMultiply, "LatentInterpolate": LatentInterpolate, + "LatentConcat": LatentConcat, "LatentBatch": LatentBatch, "LatentBatchSeedBehavior": LatentBatchSeedBehavior, "LatentApplyOperation": LatentApplyOperation, From 4aa79dbf2c5118853659fc7f7f8590594ab72417 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 27 Aug 2025 20:08:17 -0700 Subject: [PATCH 126/325] Adjust flux mem usage factor a bit. 
(#9588) --- comfy/supported_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/supported_models.py b/comfy/supported_models.py index ce571e6cb..76260de00 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -700,7 +700,7 @@ class Flux(supported_models_base.BASE): unet_extra_config = {} latent_format = latent_formats.Flux - memory_usage_factor = 2.8 + memory_usage_factor = 3.1 # TODO: debug why flux mem usage is so weird on windows. supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32] From 0eb821a7b6612af0fa3aaa8302739788a4bd629e Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 27 Aug 2025 23:09:06 -0400 Subject: [PATCH 127/325] ComfyUI 0.3.53 --- comfyui_version.py | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/comfyui_version.py b/comfyui_version.py index 834c3e8c2..d6fdc47fe 100644 --- a/comfyui_version.py +++ b/comfyui_version.py @@ -1,3 +1,3 @@ # This file is automatically generated by the build process when version is # updated in pyproject.toml. -__version__ = "0.3.52" +__version__ = "0.3.53" diff --git a/pyproject.toml b/pyproject.toml index f6e765a81..a71ad2bbf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ComfyUI" -version = "0.3.52" +version = "0.3.53" readme = "README.md" license = { file = "LICENSE" } requires-python = ">=3.9" From ce0052c087cb1e81ba01e8afbe362bec54eeb665 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Thu, 28 Aug 2025 07:37:42 -0700 Subject: [PATCH 128/325] Fix diffsynth controlnet regression. (#9597) --- comfy_extras/nodes_model_patch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy_extras/nodes_model_patch.py b/comfy_extras/nodes_model_patch.py index 32c40ced3..65e766b52 100644 --- a/comfy_extras/nodes_model_patch.py +++ b/comfy_extras/nodes_model_patch.py @@ -108,7 +108,7 @@ class DiffSynthCnetPatch: img = kwargs.get("img") block_index = kwargs.get("block_index") spacial_compression = self.vae.spacial_compression_encode() - if self.encoded_image is None or self.encoded_image_size != (x.shape[-1] * spacial_compression, x.shape[-2] * spacial_compression): + if self.encoded_image is None or self.encoded_image_size != (x.shape[-2] * spacial_compression, x.shape[-1] * spacial_compression): image_scaled = comfy.utils.common_upscale(self.image.movedim(-1, 1), x.shape[-1] * spacial_compression, x.shape[-2] * spacial_compression, "area", "center") loaded_models = comfy.model_management.loaded_models(only_currently_used=True) self.encoded_image = self.model_patch.model.process_input_latent_image(self.encode_latent_cond(image_scaled.movedim(1, -1))) From 00636101771cb373354d6294cc6567deda2635f6 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 28 Aug 2025 10:44:57 -0400 Subject: [PATCH 129/325] ComfyUI version 0.3.54 --- comfyui_version.py | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/comfyui_version.py b/comfyui_version.py index d6fdc47fe..7034953fd 100644 --- a/comfyui_version.py +++ b/comfyui_version.py @@ -1,3 +1,3 @@ # This file is automatically generated by the build process when version is # updated in pyproject.toml. 
-__version__ = "0.3.53" +__version__ = "0.3.54" diff --git a/pyproject.toml b/pyproject.toml index a71ad2bbf..9f9ac1e21 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ComfyUI" -version = "0.3.53" +version = "0.3.54" readme = "README.md" license = { file = "LICENSE" } requires-python = ">=3.9" From edde0b50431e296f61f79205e25cb01f653013a2 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Thu, 28 Aug 2025 14:59:48 -0700 Subject: [PATCH 130/325] WanSoundImageToVideoExtend node to manually extend s2v video. (#9606) --- comfy_extras/nodes_wan.py | 145 +++++++++++++++++++++++++------------- 1 file changed, 97 insertions(+), 48 deletions(-) diff --git a/comfy_extras/nodes_wan.py b/comfy_extras/nodes_wan.py index 312260f00..0a55bd5d0 100644 --- a/comfy_extras/nodes_wan.py +++ b/comfy_extras/nodes_wan.py @@ -877,6 +877,67 @@ def get_audio_embed_bucket_fps(audio_embed, fps=16, batch_frames=81, m=0, video_ return batch_audio_eb, min_batch_num +def wan_sound_to_video(positive, negative, vae, width, height, length, batch_size, frame_offset=0, ref_image=None, audio_encoder_output=None, control_video=None, ref_motion=None, ref_motion_latent=None): + latent_t = ((length - 1) // 4) + 1 + if audio_encoder_output is not None: + feat = torch.cat(audio_encoder_output["encoded_audio_all_layers"]) + video_rate = 30 + fps = 16 + feat = linear_interpolation(feat, input_fps=50, output_fps=video_rate) + batch_frames = latent_t * 4 + audio_embed_bucket, num_repeat = get_audio_embed_bucket_fps(feat, fps=fps, batch_frames=batch_frames, m=0, video_rate=video_rate) + audio_embed_bucket = audio_embed_bucket.unsqueeze(0) + if len(audio_embed_bucket.shape) == 3: + audio_embed_bucket = audio_embed_bucket.permute(0, 2, 1) + elif len(audio_embed_bucket.shape) == 4: + audio_embed_bucket = audio_embed_bucket.permute(0, 2, 3, 1) + + audio_embed_bucket = audio_embed_bucket[:, :, :, frame_offset:frame_offset + batch_frames] + positive = node_helpers.conditioning_set_values(positive, {"audio_embed": audio_embed_bucket}) + negative = node_helpers.conditioning_set_values(negative, {"audio_embed": audio_embed_bucket * 0.0}) + frame_offset += batch_frames + + if ref_image is not None: + ref_image = comfy.utils.common_upscale(ref_image[:1].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) + ref_latent = vae.encode(ref_image[:, :, :, :3]) + positive = node_helpers.conditioning_set_values(positive, {"reference_latents": [ref_latent]}, append=True) + negative = node_helpers.conditioning_set_values(negative, {"reference_latents": [ref_latent]}, append=True) + + if ref_motion is not None: + if ref_motion.shape[0] > 73: + ref_motion = ref_motion[-73:] + + ref_motion = comfy.utils.common_upscale(ref_motion.movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) + + if ref_motion.shape[0] < 73: + r = torch.ones([73, height, width, 3]) * 0.5 + r[-ref_motion.shape[0]:] = ref_motion + ref_motion = r + + ref_motion_latent = vae.encode(ref_motion[:, :, :, :3]) + + if ref_motion_latent is not None: + ref_motion_latent = ref_motion_latent[:, :, -19:] + positive = node_helpers.conditioning_set_values(positive, {"reference_motion": ref_motion_latent}) + negative = node_helpers.conditioning_set_values(negative, {"reference_motion": ref_motion_latent}) + + latent = torch.zeros([batch_size, 16, latent_t, height // 8, width // 8], device=comfy.model_management.intermediate_device()) + + control_video_out = 
comfy.latent_formats.Wan21().process_out(torch.zeros_like(latent)) + if control_video is not None: + control_video = comfy.utils.common_upscale(control_video[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) + control_video = vae.encode(control_video[:, :, :, :3]) + control_video_out[:, :, :control_video.shape[2]] = control_video + + # TODO: check if zero is better than none if none provided + positive = node_helpers.conditioning_set_values(positive, {"control_video": control_video_out}) + negative = node_helpers.conditioning_set_values(negative, {"control_video": control_video_out}) + + out_latent = {} + out_latent["samples"] = latent + return positive, negative, out_latent, frame_offset + + class WanSoundImageToVideo(io.ComfyNode): @classmethod def define_schema(cls): @@ -906,57 +967,44 @@ class WanSoundImageToVideo(io.ComfyNode): @classmethod def execute(cls, positive, negative, vae, width, height, length, batch_size, ref_image=None, audio_encoder_output=None, control_video=None, ref_motion=None) -> io.NodeOutput: - latent_t = ((length - 1) // 4) + 1 - if audio_encoder_output is not None: - feat = torch.cat(audio_encoder_output["encoded_audio_all_layers"]) - video_rate = 30 - fps = 16 - feat = linear_interpolation(feat, input_fps=50, output_fps=video_rate) - audio_embed_bucket, num_repeat = get_audio_embed_bucket_fps(feat, fps=fps, batch_frames=latent_t * 4, m=0, video_rate=video_rate) - audio_embed_bucket = audio_embed_bucket.unsqueeze(0) - if len(audio_embed_bucket.shape) == 3: - audio_embed_bucket = audio_embed_bucket.permute(0, 2, 1) - elif len(audio_embed_bucket.shape) == 4: - audio_embed_bucket = audio_embed_bucket.permute(0, 2, 3, 1) + positive, negative, out_latent, frame_offset = wan_sound_to_video(positive, negative, vae, width, height, length, batch_size, ref_image=ref_image, audio_encoder_output=audio_encoder_output, + control_video=control_video, ref_motion=ref_motion) + return io.NodeOutput(positive, negative, out_latent) - positive = node_helpers.conditioning_set_values(positive, {"audio_embed": audio_embed_bucket}) - negative = node_helpers.conditioning_set_values(negative, {"audio_embed": audio_embed_bucket * 0.0}) - if ref_image is not None: - ref_image = comfy.utils.common_upscale(ref_image[:1].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) - ref_latent = vae.encode(ref_image[:, :, :, :3]) - positive = node_helpers.conditioning_set_values(positive, {"reference_latents": [ref_latent]}, append=True) - negative = node_helpers.conditioning_set_values(negative, {"reference_latents": [ref_latent]}, append=True) +class WanSoundImageToVideoExtend(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="WanSoundImageToVideoExtend", + category="conditioning/video_models", + inputs=[ + io.Conditioning.Input("positive"), + io.Conditioning.Input("negative"), + io.Vae.Input("vae"), + io.Int.Input("length", default=77, min=1, max=nodes.MAX_RESOLUTION, step=4), + io.Latent.Input("video_latent"), + io.AudioEncoderOutput.Input("audio_encoder_output", optional=True), + io.Image.Input("ref_image", optional=True), + io.Image.Input("control_video", optional=True), + ], + outputs=[ + io.Conditioning.Output(display_name="positive"), + io.Conditioning.Output(display_name="negative"), + io.Latent.Output(display_name="latent"), + ], + is_experimental=True, + ) - if ref_motion is not None: - if ref_motion.shape[0] > 73: - ref_motion = ref_motion[-73:] - - ref_motion = comfy.utils.common_upscale(ref_motion.movedim(-1, 
1), width, height, "bilinear", "center").movedim(1, -1) - - if ref_motion.shape[0] < 73: - r = torch.ones([73, height, width, 3]) * 0.5 - r[-ref_motion.shape[0]:] = ref_motion - ref_motion = r - - ref_motion = vae.encode(ref_motion[:, :, :, :3]) - positive = node_helpers.conditioning_set_values(positive, {"reference_motion": ref_motion}) - negative = node_helpers.conditioning_set_values(negative, {"reference_motion": ref_motion}) - - latent = torch.zeros([batch_size, 16, latent_t, height // 8, width // 8], device=comfy.model_management.intermediate_device()) - - control_video_out = comfy.latent_formats.Wan21().process_out(torch.zeros_like(latent)) - if control_video is not None: - control_video = comfy.utils.common_upscale(control_video[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) - control_video = vae.encode(control_video[:, :, :, :3]) - control_video_out[:, :, :control_video.shape[2]] = control_video - - # TODO: check if zero is better than none if none provided - positive = node_helpers.conditioning_set_values(positive, {"control_video": control_video_out}) - negative = node_helpers.conditioning_set_values(negative, {"control_video": control_video_out}) - - out_latent = {} - out_latent["samples"] = latent + @classmethod + def execute(cls, positive, negative, vae, length, video_latent, ref_image=None, audio_encoder_output=None, control_video=None) -> io.NodeOutput: + video_latent = video_latent["samples"] + width = video_latent.shape[-1] * 8 + height = video_latent.shape[-2] * 8 + batch_size = video_latent.shape[0] + frame_offset = video_latent.shape[-3] * 4 + positive, negative, out_latent, frame_offset = wan_sound_to_video(positive, negative, vae, width, height, length, batch_size, frame_offset=frame_offset, ref_image=ref_image, audio_encoder_output=audio_encoder_output, + control_video=control_video, ref_motion=None, ref_motion_latent=video_latent) return io.NodeOutput(positive, negative, out_latent) @@ -1019,6 +1067,7 @@ class WanExtension(ComfyExtension): WanCameraImageToVideo, WanPhantomSubjectToVideo, WanSoundImageToVideo, + WanSoundImageToVideoExtend, Wan22ImageToVideoLatent, ] From 1c184c29eb2a8f6fdd4e49f27347809090038e3f Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Thu, 28 Aug 2025 15:34:01 -0700 Subject: [PATCH 131/325] Fix issue with s2v node when extending past audio length. 
(#9608) --- comfy_extras/nodes_wan.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/comfy_extras/nodes_wan.py b/comfy_extras/nodes_wan.py index 0a55bd5d0..2cbc93ceb 100644 --- a/comfy_extras/nodes_wan.py +++ b/comfy_extras/nodes_wan.py @@ -893,9 +893,10 @@ def wan_sound_to_video(positive, negative, vae, width, height, length, batch_siz audio_embed_bucket = audio_embed_bucket.permute(0, 2, 3, 1) audio_embed_bucket = audio_embed_bucket[:, :, :, frame_offset:frame_offset + batch_frames] - positive = node_helpers.conditioning_set_values(positive, {"audio_embed": audio_embed_bucket}) - negative = node_helpers.conditioning_set_values(negative, {"audio_embed": audio_embed_bucket * 0.0}) - frame_offset += batch_frames + if audio_embed_bucket.shape[3] > 0: + positive = node_helpers.conditioning_set_values(positive, {"audio_embed": audio_embed_bucket}) + negative = node_helpers.conditioning_set_values(negative, {"audio_embed": audio_embed_bucket * 0.0}) + frame_offset += batch_frames if ref_image is not None: ref_image = comfy.utils.common_upscale(ref_image[:1].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) From d28b39d93dc498110e28ca32c8f39e6de631aa42 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Thu, 28 Aug 2025 16:38:28 -0700 Subject: [PATCH 132/325] Add a LatentCut node to cut latents. (#9609) --- comfy_extras/nodes_latent.py | 37 ++++++++++++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/comfy_extras/nodes_latent.py b/comfy_extras/nodes_latent.py index 247d886a1..0f90cf60c 100644 --- a/comfy_extras/nodes_latent.py +++ b/comfy_extras/nodes_latent.py @@ -1,6 +1,7 @@ import comfy.utils import comfy_extras.nodes_post_processing import torch +import nodes def reshape_latent_to(target_shape, latent, repeat_batch=True): @@ -137,6 +138,41 @@ class LatentConcat: samples_out["samples"] = torch.cat(c, dim=dim) return (samples_out,) +class LatentCut: + @classmethod + def INPUT_TYPES(s): + return {"required": {"samples": ("LATENT",), + "dim": (["x", "y", "t"], ), + "index": ("INT", {"default": 0, "min": -nodes.MAX_RESOLUTION, "max": nodes.MAX_RESOLUTION, "step": 1}), + "amount": ("INT", {"default": 1, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 1})}} + + RETURN_TYPES = ("LATENT",) + FUNCTION = "op" + + CATEGORY = "latent/advanced" + + def op(self, samples, dim, index, amount): + samples_out = samples.copy() + + s1 = samples["samples"] + + if "x" in dim: + dim = s1.ndim - 1 + elif "y" in dim: + dim = s1.ndim - 2 + elif "t" in dim: + dim = s1.ndim - 3 + + if index >= 0: + index = min(index, s1.shape[dim] - 1) + amount = min(s1.shape[dim] - index, amount) + else: + index = max(index, -s1.shape[dim]) + amount = min(-index, amount) + + samples_out["samples"] = torch.narrow(s1, dim, index, amount) + return (samples_out,) + class LatentBatch: @classmethod def INPUT_TYPES(s): @@ -312,6 +348,7 @@ NODE_CLASS_MAPPINGS = { "LatentMultiply": LatentMultiply, "LatentInterpolate": LatentInterpolate, "LatentConcat": LatentConcat, + "LatentCut": LatentCut, "LatentBatch": LatentBatch, "LatentBatchSeedBehavior": LatentBatchSeedBehavior, "LatentApplyOperation": LatentApplyOperation, From e80a14ad5073d9eba175c2d2c768a5ca8e4c63ea Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Thu, 28 Aug 2025 19:13:07 -0700 Subject: [PATCH 133/325] Support wan2.2 5B fun control model. (#9611) Use the Wan22FunControlToVideo node. 
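A minimal sketch of the sizing logic this relies on: the node now reads
vae.latent_channels and vae.spacial_compression_encode() so the same code path
serves the Wan 2.1 VAE and the 5B Wan 2.2 VAE. The helper below mirrors that
logic; the 16x spatial compression and the example resolutions are assumptions
for illustration, while the channel counts come from the patch itself.

    import torch

    def empty_wan_fun_latents(latent_channels, spatial_scale, width, height, length, batch_size=1):
        # Wan packs 4 video frames into each latent frame along the time axis.
        latent_t = ((length - 1) // 4) + 1
        latent = torch.zeros([batch_size, latent_channels, latent_t,
                              height // spatial_scale, width // spatial_scale])
        # Control and start-image latents are stacked on the channel axis, so the
        # concat conditioning latent carries twice the channels of the plain latent.
        concat_latent = latent.repeat(1, 2, 1, 1, 1)
        return latent, concat_latent

    # Wan 2.1 VAE: 16 latent channels, 8x spatial compression.
    lat21, cat21 = empty_wan_fun_latents(16, 8, 832, 480, 81)
    # Wan 2.2 5B VAE: 48 latent channels; 16x spatial compression assumed here.
    lat22, cat22 = empty_wan_fun_latents(48, 16, 1280, 704, 121)

Apart from loading the 5B model, the wiring is the same as the existing
fun-control workflow: the control video (and optional start and reference
images) still feed Wan22FunControlToVideo.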
--- comfy/model_base.py | 15 ++++++--------- comfy_extras/nodes_wan.py | 19 ++++++++++++------- 2 files changed, 18 insertions(+), 16 deletions(-) diff --git a/comfy/model_base.py b/comfy/model_base.py index ce29fdc49..56a6798be 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1110,9 +1110,10 @@ class WAN21(BaseModel): shape_image[1] = extra_channels image = torch.zeros(shape_image, dtype=noise.dtype, layout=noise.layout, device=noise.device) else: + latent_dim = self.latent_format.latent_channels image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center") - for i in range(0, image.shape[1], 16): - image[:, i: i + 16] = self.process_latent_in(image[:, i: i + 16]) + for i in range(0, image.shape[1], latent_dim): + image[:, i: i + latent_dim] = self.process_latent_in(image[:, i: i + latent_dim]) image = utils.resize_to_batch_size(image, noise.shape[0]) if extra_channels != image.shape[1] + 4: @@ -1245,18 +1246,14 @@ class WAN22_S2V(WAN21): out['reference_motion'] = reference_motion.shape return out -class WAN22(BaseModel): +class WAN22(WAN21): def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None): - super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel) + super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel) self.image_to_video = image_to_video def extra_conds(self, **kwargs): out = super().extra_conds(**kwargs) - cross_attn = kwargs.get("cross_attn", None) - if cross_attn is not None: - out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) - - denoise_mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None)) + denoise_mask = kwargs.get("denoise_mask", None) if denoise_mask is not None: out["denoise_mask"] = comfy.conds.CONDRegular(denoise_mask) return out diff --git a/comfy_extras/nodes_wan.py b/comfy_extras/nodes_wan.py index 2cbc93ceb..8c1d36613 100644 --- a/comfy_extras/nodes_wan.py +++ b/comfy_extras/nodes_wan.py @@ -139,16 +139,21 @@ class Wan22FunControlToVideo(io.ComfyNode): @classmethod def execute(cls, positive, negative, vae, width, height, length, batch_size, ref_image=None, start_image=None, control_video=None) -> io.NodeOutput: - latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) - concat_latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) - concat_latent = comfy.latent_formats.Wan21().process_out(concat_latent) + spacial_scale = vae.spacial_compression_encode() + latent_channels = vae.latent_channels + latent = torch.zeros([batch_size, latent_channels, ((length - 1) // 4) + 1, height // spacial_scale, width // spacial_scale], device=comfy.model_management.intermediate_device()) + concat_latent = torch.zeros([batch_size, latent_channels, ((length - 1) // 4) + 1, height // spacial_scale, width // spacial_scale], device=comfy.model_management.intermediate_device()) + if latent_channels == 48: + concat_latent = comfy.latent_formats.Wan22().process_out(concat_latent) + else: + concat_latent = comfy.latent_formats.Wan21().process_out(concat_latent) concat_latent = concat_latent.repeat(1, 2, 1, 1, 1) mask = torch.ones((1, 1, latent.shape[2] * 4, latent.shape[-2], latent.shape[-1])) if start_image is not None: start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, 
"bilinear", "center").movedim(1, -1) concat_latent_image = vae.encode(start_image[:, :, :, :3]) - concat_latent[:,16:,:concat_latent_image.shape[2]] = concat_latent_image[:,:,:concat_latent.shape[2]] + concat_latent[:,latent_channels:,:concat_latent_image.shape[2]] = concat_latent_image[:,:,:concat_latent.shape[2]] mask[:, :, :start_image.shape[0] + 3] = 0.0 ref_latent = None @@ -159,11 +164,11 @@ class Wan22FunControlToVideo(io.ComfyNode): if control_video is not None: control_video = comfy.utils.common_upscale(control_video[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) concat_latent_image = vae.encode(control_video[:, :, :, :3]) - concat_latent[:,:16,:concat_latent_image.shape[2]] = concat_latent_image[:,:,:concat_latent.shape[2]] + concat_latent[:,:latent_channels,:concat_latent_image.shape[2]] = concat_latent_image[:,:,:concat_latent.shape[2]] mask = mask.view(1, mask.shape[2] // 4, 4, mask.shape[3], mask.shape[4]).transpose(1, 2) - positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent, "concat_mask": mask, "concat_mask_index": 16}) - negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent, "concat_mask": mask, "concat_mask_index": 16}) + positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent, "concat_mask": mask, "concat_mask_index": latent_channels}) + negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent, "concat_mask": mask, "concat_mask_index": latent_channels}) if ref_latent is not None: positive = node_helpers.conditioning_set_values(positive, {"reference_latents": [ref_latent]}, append=True) From c7bb3e2bceaad7accd52c23d22b97a1b6808304b Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Thu, 28 Aug 2025 19:46:57 -0700 Subject: [PATCH 134/325] Support the 5B fun inpaint model. (#9614) Use the WanFunInpaintToVideo node without the clip_vision_output. --- comfy_extras/nodes_wan.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/comfy_extras/nodes_wan.py b/comfy_extras/nodes_wan.py index 8c1d36613..4f73369f5 100644 --- a/comfy_extras/nodes_wan.py +++ b/comfy_extras/nodes_wan.py @@ -206,7 +206,8 @@ class WanFirstLastFrameToVideo(io.ComfyNode): @classmethod def execute(cls, positive, negative, vae, width, height, length, batch_size, start_image=None, end_image=None, clip_vision_start_image=None, clip_vision_end_image=None) -> io.NodeOutput: - latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) + spacial_scale = vae.spacial_compression_encode() + latent = torch.zeros([batch_size, vae.latent_channels, ((length - 1) // 4) + 1, height // spacial_scale, width // spacial_scale], device=comfy.model_management.intermediate_device()) if start_image is not None: start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) if end_image is not None: From 15aa9222c4d1fc74f5190d7c7e56ef986d0d7146 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Fri, 29 Aug 2025 01:12:00 -0700 Subject: [PATCH 135/325] Trim audio to video when saving video. 
(#9617) --- comfy_api/latest/_input_impl/video_types.py | 34 ++++++--------------- 1 file changed, 9 insertions(+), 25 deletions(-) diff --git a/comfy_api/latest/_input_impl/video_types.py b/comfy_api/latest/_input_impl/video_types.py index 28de9651d..f646504c8 100644 --- a/comfy_api/latest/_input_impl/video_types.py +++ b/comfy_api/latest/_input_impl/video_types.py @@ -8,6 +8,7 @@ import av import io import json import numpy as np +import math import torch from comfy_api.latest._util import VideoContainer, VideoCodec, VideoComponents @@ -282,8 +283,6 @@ class VideoFromComponents(VideoInput): if self.__components.audio: audio_sample_rate = int(self.__components.audio['sample_rate']) audio_stream = output.add_stream('aac', rate=audio_sample_rate) - audio_stream.sample_rate = audio_sample_rate - audio_stream.format = 'fltp' # Encode video for i, frame in enumerate(self.__components.images): @@ -298,27 +297,12 @@ class VideoFromComponents(VideoInput): output.mux(packet) if audio_stream and self.__components.audio: - # Encode audio - samples_per_frame = int(audio_sample_rate / frame_rate) - num_frames = self.__components.audio['waveform'].shape[2] // samples_per_frame - for i in range(num_frames): - start = i * samples_per_frame - end = start + samples_per_frame - # TODO(Feature) - Add support for stereo audio - chunk = ( - self.__components.audio["waveform"][0, 0, start:end] - .unsqueeze(0) - .contiguous() - .numpy() - ) - audio_frame = av.AudioFrame.from_ndarray(chunk, format='fltp', layout='mono') - audio_frame.sample_rate = audio_sample_rate - audio_frame.pts = i * samples_per_frame - for packet in audio_stream.encode(audio_frame): - output.mux(packet) - - # Flush audio - for packet in audio_stream.encode(None): - output.mux(packet) - + waveform = self.__components.audio['waveform'] + waveform = waveform[:, :, :math.ceil((audio_sample_rate / frame_rate) * self.__components.images.shape[0])] + frame = av.AudioFrame.from_ndarray(waveform.movedim(2, 1).reshape(1, -1).float().numpy(), format='flt', layout='mono' if waveform.shape[1] == 1 else 'stereo') + frame.sample_rate = audio_sample_rate + frame.pts = 0 + output.mux(audio_stream.encode(frame)) + # Flush encoder + output.mux(audio_stream.encode(None)) From 2efb2cbc38714074b0a48a9f4d70fa43f41499f4 Mon Sep 17 00:00:00 2001 From: ComfyUI Wiki Date: Fri, 29 Aug 2025 18:03:25 +0800 Subject: [PATCH 136/325] Update template to 0.1.70 (#9620) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 93d88859d..7f64aacca 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.25.11 -comfyui-workflow-templates==0.1.68 +comfyui-workflow-templates==0.1.70 comfyui-embedded-docs==0.2.6 torch torchsde From a86aaa430183068e2a264495c802c81d05eb350a Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 29 Aug 2025 05:33:29 -0400 Subject: [PATCH 137/325] ComfyUI v0.3.55 --- comfyui_version.py | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/comfyui_version.py b/comfyui_version.py index 7034953fd..36777e285 100644 --- a/comfyui_version.py +++ b/comfyui_version.py @@ -1,3 +1,3 @@ # This file is automatically generated by the build process when version is # updated in pyproject.toml. 
-__version__ = "0.3.54" +__version__ = "0.3.55" diff --git a/pyproject.toml b/pyproject.toml index 9f9ac1e21..04514b4a8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ComfyUI" -version = "0.3.54" +version = "0.3.55" readme = "README.md" license = { file = "LICENSE" } requires-python = ">=3.9" From 885015eecf649d6e49e1ade68e4475b434517b82 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Fri, 29 Aug 2025 20:06:04 -0700 Subject: [PATCH 138/325] Lower ram usage on windows. (#9628) --- main.py | 1 + 1 file changed, 1 insertion(+) diff --git a/main.py b/main.py index 9b2a33011..b23d50816 100644 --- a/main.py +++ b/main.py @@ -112,6 +112,7 @@ import gc if os.name == "nt": + os.environ['MIMALLOC_PURGE_DELAY'] = '0' logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage()) if __name__ == "__main__": From 4449e147692366ac8b9bd3b8834c771bc81e91ac Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 30 Aug 2025 06:31:19 -0400 Subject: [PATCH 139/325] ComfyUI version 0.3.56 --- comfyui_version.py | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/comfyui_version.py b/comfyui_version.py index 36777e285..e8e039373 100644 --- a/comfyui_version.py +++ b/comfyui_version.py @@ -1,3 +1,3 @@ # This file is automatically generated by the build process when version is # updated in pyproject.toml. -__version__ = "0.3.55" +__version__ = "0.3.56" diff --git a/pyproject.toml b/pyproject.toml index 04514b4a8..cfd5d45ef 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ComfyUI" -version = "0.3.55" +version = "0.3.56" readme = "README.md" license = { file = "LICENSE" } requires-python = ">=3.9" From f949094b3cbc33779dbf8d3fd140028f8044d5c1 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Sun, 31 Aug 2025 06:19:21 +0300 Subject: [PATCH 140/325] convert Stable Cascade nodes to V3 schema (#9373) --- comfy_extras/nodes_stable_cascade.py | 165 +++++++++++++++------------ 1 file changed, 93 insertions(+), 72 deletions(-) diff --git a/comfy_extras/nodes_stable_cascade.py b/comfy_extras/nodes_stable_cascade.py index 003403215..04c0b366a 100644 --- a/comfy_extras/nodes_stable_cascade.py +++ b/comfy_extras/nodes_stable_cascade.py @@ -17,55 +17,61 @@ """ import torch -import nodes +from typing_extensions import override + import comfy.utils +import nodes +from comfy_api.latest import ComfyExtension, io -class StableCascade_EmptyLatentImage: - def __init__(self, device="cpu"): - self.device = device +class StableCascade_EmptyLatentImage(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="StableCascade_EmptyLatentImage", + category="latent/stable_cascade", + inputs=[ + io.Int.Input("width", default=1024, min=256, max=nodes.MAX_RESOLUTION, step=8), + io.Int.Input("height", default=1024, min=256, max=nodes.MAX_RESOLUTION, step=8), + io.Int.Input("compression", default=42, min=4, max=128, step=1), + io.Int.Input("batch_size", default=1, min=1, max=4096), + ], + outputs=[ + io.Latent.Output(display_name="stage_c"), + io.Latent.Output(display_name="stage_b"), + ], + ) @classmethod - def INPUT_TYPES(s): - return {"required": { - "width": ("INT", {"default": 1024, "min": 256, "max": nodes.MAX_RESOLUTION, "step": 8}), - "height": ("INT", {"default": 1024, "min": 256, "max": nodes.MAX_RESOLUTION, "step": 8}), - "compression": 
("INT", {"default": 42, "min": 4, "max": 128, "step": 1}), - "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}) - }} - RETURN_TYPES = ("LATENT", "LATENT") - RETURN_NAMES = ("stage_c", "stage_b") - FUNCTION = "generate" - - CATEGORY = "latent/stable_cascade" - - def generate(self, width, height, compression, batch_size=1): + def execute(cls, width, height, compression, batch_size=1): c_latent = torch.zeros([batch_size, 16, height // compression, width // compression]) b_latent = torch.zeros([batch_size, 4, height // 4, width // 4]) - return ({ + return io.NodeOutput({ "samples": c_latent, }, { "samples": b_latent, }) -class StableCascade_StageC_VAEEncode: - def __init__(self, device="cpu"): - self.device = device + +class StableCascade_StageC_VAEEncode(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="StableCascade_StageC_VAEEncode", + category="latent/stable_cascade", + inputs=[ + io.Image.Input("image"), + io.Vae.Input("vae"), + io.Int.Input("compression", default=42, min=4, max=128, step=1), + ], + outputs=[ + io.Latent.Output(display_name="stage_c"), + io.Latent.Output(display_name="stage_b"), + ], + ) @classmethod - def INPUT_TYPES(s): - return {"required": { - "image": ("IMAGE",), - "vae": ("VAE", ), - "compression": ("INT", {"default": 42, "min": 4, "max": 128, "step": 1}), - }} - RETURN_TYPES = ("LATENT", "LATENT") - RETURN_NAMES = ("stage_c", "stage_b") - FUNCTION = "generate" - - CATEGORY = "latent/stable_cascade" - - def generate(self, image, vae, compression): + def execute(cls, image, vae, compression): width = image.shape[-2] height = image.shape[-3] out_width = (width // compression) * vae.downscale_ratio @@ -75,51 +81,59 @@ class StableCascade_StageC_VAEEncode: c_latent = vae.encode(s[:,:,:,:3]) b_latent = torch.zeros([c_latent.shape[0], 4, (height // 8) * 2, (width // 8) * 2]) - return ({ + return io.NodeOutput({ "samples": c_latent, }, { "samples": b_latent, }) -class StableCascade_StageB_Conditioning: + +class StableCascade_StageB_Conditioning(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "conditioning": ("CONDITIONING",), - "stage_c": ("LATENT",), - }} - RETURN_TYPES = ("CONDITIONING",) + def define_schema(cls): + return io.Schema( + node_id="StableCascade_StageB_Conditioning", + category="conditioning/stable_cascade", + inputs=[ + io.Conditioning.Input("conditioning"), + io.Latent.Input("stage_c"), + ], + outputs=[ + io.Conditioning.Output(), + ], + ) - FUNCTION = "set_prior" - - CATEGORY = "conditioning/stable_cascade" - - def set_prior(self, conditioning, stage_c): + @classmethod + def execute(cls, conditioning, stage_c): c = [] for t in conditioning: d = t[1].copy() - d['stable_cascade_prior'] = stage_c['samples'] + d["stable_cascade_prior"] = stage_c["samples"] n = [t[0], d] c.append(n) - return (c, ) + return io.NodeOutput(c) -class StableCascade_SuperResolutionControlnet: - def __init__(self, device="cpu"): - self.device = device + +class StableCascade_SuperResolutionControlnet(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="StableCascade_SuperResolutionControlnet", + category="_for_testing/stable_cascade", + is_experimental=True, + inputs=[ + io.Image.Input("image"), + io.Vae.Input("vae"), + ], + outputs=[ + io.Image.Output(display_name="controlnet_input"), + io.Latent.Output(display_name="stage_c"), + io.Latent.Output(display_name="stage_b"), + ], + ) @classmethod - def INPUT_TYPES(s): - return {"required": { - "image": ("IMAGE",), - "vae": ("VAE", 
), - }} - RETURN_TYPES = ("IMAGE", "LATENT", "LATENT") - RETURN_NAMES = ("controlnet_input", "stage_c", "stage_b") - FUNCTION = "generate" - - EXPERIMENTAL = True - CATEGORY = "_for_testing/stable_cascade" - - def generate(self, image, vae): + def execute(cls, image, vae): width = image.shape[-2] height = image.shape[-3] batch_size = image.shape[0] @@ -127,15 +141,22 @@ class StableCascade_SuperResolutionControlnet: c_latent = torch.zeros([batch_size, 16, height // 16, width // 16]) b_latent = torch.zeros([batch_size, 4, height // 2, width // 2]) - return (controlnet_input, { + return io.NodeOutput(controlnet_input, { "samples": c_latent, }, { "samples": b_latent, }) -NODE_CLASS_MAPPINGS = { - "StableCascade_EmptyLatentImage": StableCascade_EmptyLatentImage, - "StableCascade_StageB_Conditioning": StableCascade_StageB_Conditioning, - "StableCascade_StageC_VAEEncode": StableCascade_StageC_VAEEncode, - "StableCascade_SuperResolutionControlnet": StableCascade_SuperResolutionControlnet, -} + +class StableCascadeExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + StableCascade_EmptyLatentImage, + StableCascade_StageB_Conditioning, + StableCascade_StageC_VAEEncode, + StableCascade_SuperResolutionControlnet, + ] + +async def comfy_entrypoint() -> StableCascadeExtension: + return StableCascadeExtension() From fea9ea8268d9fc0f4245f3fdc4a417ab802033e9 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Sun, 31 Aug 2025 06:19:54 +0300 Subject: [PATCH 141/325] convert Video nodes to V3 schema (#9489) --- comfy_extras/nodes_video.py | 286 +++++++++++++++++------------------- 1 file changed, 132 insertions(+), 154 deletions(-) diff --git a/comfy_extras/nodes_video.py b/comfy_extras/nodes_video.py index 969f888b9..69fabb12e 100644 --- a/comfy_extras/nodes_video.py +++ b/comfy_extras/nodes_video.py @@ -5,52 +5,49 @@ import av import torch import folder_paths import json -from typing import Optional, Literal +from typing import Optional +from typing_extensions import override from fractions import Fraction -from comfy.comfy_types import IO, FileLocator, ComfyNodeABC -from comfy_api.latest import Input, InputImpl, Types +from comfy_api.input import AudioInput, ImageInput, VideoInput +from comfy_api.input_impl import VideoFromComponents, VideoFromFile +from comfy_api.util import VideoCodec, VideoComponents, VideoContainer +from comfy_api.latest import ComfyExtension, io, ui from comfy.cli_args import args -class SaveWEBM: - def __init__(self): - self.output_dir = folder_paths.get_output_directory() - self.type = "output" - self.prefix_append = "" +class SaveWEBM(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SaveWEBM", + category="image/video", + is_experimental=True, + inputs=[ + io.Image.Input("images"), + io.String.Input("filename_prefix", default="ComfyUI"), + io.Combo.Input("codec", options=["vp9", "av1"]), + io.Float.Input("fps", default=24.0, min=0.01, max=1000.0, step=0.01), + io.Float.Input("crf", default=32.0, min=0, max=63.0, step=1, tooltip="Higher crf means lower quality with a smaller file size, lower crf means higher quality higher filesize."), + ], + outputs=[], + hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo], + is_output_node=True, + ) @classmethod - def INPUT_TYPES(s): - return {"required": - {"images": ("IMAGE", ), - "filename_prefix": ("STRING", {"default": "ComfyUI"}), - "codec": (["vp9", "av1"],), - "fps": ("FLOAT", {"default": 24.0, 
"min": 0.01, "max": 1000.0, "step": 0.01}), - "crf": ("FLOAT", {"default": 32.0, "min": 0, "max": 63.0, "step": 1, "tooltip": "Higher crf means lower quality with a smaller file size, lower crf means higher quality higher filesize."}), - }, - "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}, - } - - RETURN_TYPES = () - FUNCTION = "save_images" - - OUTPUT_NODE = True - - CATEGORY = "image/video" - - EXPERIMENTAL = True - - def save_images(self, images, codec, fps, filename_prefix, crf, prompt=None, extra_pnginfo=None): - filename_prefix += self.prefix_append - full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0]) + def execute(cls, images, codec, fps, filename_prefix, crf) -> io.NodeOutput: + full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path( + filename_prefix, folder_paths.get_output_directory(), images[0].shape[1], images[0].shape[0] + ) file = f"{filename}_{counter:05}_.webm" container = av.open(os.path.join(full_output_folder, file), mode="w") - if prompt is not None: - container.metadata["prompt"] = json.dumps(prompt) + if cls.hidden.prompt is not None: + container.metadata["prompt"] = json.dumps(cls.hidden.prompt) - if extra_pnginfo is not None: - for x in extra_pnginfo: - container.metadata[x] = json.dumps(extra_pnginfo[x]) + if cls.hidden.extra_pnginfo is not None: + for x in cls.hidden.extra_pnginfo: + container.metadata[x] = json.dumps(cls.hidden.extra_pnginfo[x]) codec_map = {"vp9": "libvpx-vp9", "av1": "libsvtav1"} stream = container.add_stream(codec_map[codec], rate=Fraction(round(fps * 1000), 1000)) @@ -69,63 +66,46 @@ class SaveWEBM: container.mux(stream.encode()) container.close() - results: list[FileLocator] = [{ - "filename": file, - "subfolder": subfolder, - "type": self.type - }] + return io.NodeOutput(ui=ui.PreviewVideo([ui.SavedResult(file, subfolder, io.FolderType.output)])) - return {"ui": {"images": results, "animated": (True,)}} # TODO: frontend side - -class SaveVideo(ComfyNodeABC): - def __init__(self): - self.output_dir = folder_paths.get_output_directory() - self.type: Literal["output"] = "output" - self.prefix_append = "" +class SaveVideo(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SaveVideo", + display_name="Save Video", + category="image/video", + description="Saves the input images to your ComfyUI output directory.", + inputs=[ + io.Video.Input("video", tooltip="The video to save."), + io.String.Input("filename_prefix", default="video/ComfyUI", tooltip="The prefix for the file to save. This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes."), + io.Combo.Input("format", options=VideoContainer.as_input(), default="auto", tooltip="The format to save the video as."), + io.Combo.Input("codec", options=VideoCodec.as_input(), default="auto", tooltip="The codec to use for the video."), + ], + outputs=[], + hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo], + is_output_node=True, + ) @classmethod - def INPUT_TYPES(cls): - return { - "required": { - "video": (IO.VIDEO, {"tooltip": "The video to save."}), - "filename_prefix": ("STRING", {"default": "video/ComfyUI", "tooltip": "The prefix for the file to save. 
This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes."}), - "format": (Types.VideoContainer.as_input(), {"default": "auto", "tooltip": "The format to save the video as."}), - "codec": (Types.VideoCodec.as_input(), {"default": "auto", "tooltip": "The codec to use for the video."}), - }, - "hidden": { - "prompt": "PROMPT", - "extra_pnginfo": "EXTRA_PNGINFO" - }, - } - - RETURN_TYPES = () - FUNCTION = "save_video" - - OUTPUT_NODE = True - - CATEGORY = "image/video" - DESCRIPTION = "Saves the input images to your ComfyUI output directory." - - def save_video(self, video: Input.Video, filename_prefix, format, codec, prompt=None, extra_pnginfo=None): - filename_prefix += self.prefix_append + def execute(cls, video: VideoInput, filename_prefix, format, codec) -> io.NodeOutput: width, height = video.get_dimensions() full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path( filename_prefix, - self.output_dir, + folder_paths.get_output_directory(), width, height ) - results: list[FileLocator] = list() saved_metadata = None if not args.disable_metadata: metadata = {} - if extra_pnginfo is not None: - metadata.update(extra_pnginfo) - if prompt is not None: - metadata["prompt"] = prompt + if cls.hidden.extra_pnginfo is not None: + metadata.update(cls.hidden.extra_pnginfo) + if cls.hidden.prompt is not None: + metadata["prompt"] = cls.hidden.prompt if len(metadata) > 0: saved_metadata = metadata - file = f"{filename}_{counter:05}_.{Types.VideoContainer.get_extension(format)}" + file = f"{filename}_{counter:05}_.{VideoContainer.get_extension(format)}" video.save_to( os.path.join(full_output_folder, file), format=format, @@ -133,83 +113,82 @@ class SaveVideo(ComfyNodeABC): metadata=saved_metadata ) - results.append({ - "filename": file, - "subfolder": subfolder, - "type": self.type - }) - counter += 1 + return io.NodeOutput(ui=ui.PreviewVideo([ui.SavedResult(file, subfolder, io.FolderType.output)])) - return { "ui": { "images": results, "animated": (True,) } } -class CreateVideo(ComfyNodeABC): +class CreateVideo(io.ComfyNode): @classmethod - def INPUT_TYPES(cls): - return { - "required": { - "images": (IO.IMAGE, {"tooltip": "The images to create a video from."}), - "fps": ("FLOAT", {"default": 30.0, "min": 1.0, "max": 120.0, "step": 1.0}), - }, - "optional": { - "audio": (IO.AUDIO, {"tooltip": "The audio to add to the video."}), - } - } + def define_schema(cls): + return io.Schema( + node_id="CreateVideo", + display_name="Create Video", + category="image/video", + description="Create a video from images.", + inputs=[ + io.Image.Input("images", tooltip="The images to create a video from."), + io.Float.Input("fps", default=30.0, min=1.0, max=120.0, step=1.0), + io.Audio.Input("audio", optional=True, tooltip="The audio to add to the video."), + ], + outputs=[ + io.Video.Output(), + ], + ) - RETURN_TYPES = (IO.VIDEO,) - FUNCTION = "create_video" - - CATEGORY = "image/video" - DESCRIPTION = "Create a video from images." 
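The conversions in this series all follow the same V3 shape: a stateless io.ComfyNode subclass whose define_schema classmethod declares typed inputs and outputs, and whose execute classmethod returns io.NodeOutput. Below is a minimal sketch of that pattern, using only the io API visible in these hunks; ExampleDouble and ExampleExtension are made-up names for illustration, not part of the patch.

from typing_extensions import override

from comfy_api.latest import ComfyExtension, io


class ExampleDouble(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="ExampleDouble",          # hypothetical node id
            category="utils/example",
            inputs=[io.Float.Input("value", default=1.0)],
            outputs=[io.Float.Output()],
        )

    @classmethod
    def execute(cls, value: float) -> io.NodeOutput:
        # execute is a classmethod: no per-instance state, result wrapped in NodeOutput
        return io.NodeOutput(value * 2.0)


class ExampleExtension(ComfyExtension):
    @override
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        return [ExampleDouble]


async def comfy_entrypoint() -> ExampleExtension:
    # takes the place of the old NODE_CLASS_MAPPINGS / NODE_DISPLAY_NAME_MAPPINGS dictionaries
    return ExampleExtension()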
- - def create_video(self, images: Input.Image, fps: float, audio: Optional[Input.Audio] = None): - return (InputImpl.VideoFromComponents( - Types.VideoComponents( - images=images, - audio=audio, - frame_rate=Fraction(fps), - ) - ),) - -class GetVideoComponents(ComfyNodeABC): @classmethod - def INPUT_TYPES(cls): - return { - "required": { - "video": (IO.VIDEO, {"tooltip": "The video to extract components from."}), - } - } - RETURN_TYPES = (IO.IMAGE, IO.AUDIO, IO.FLOAT) - RETURN_NAMES = ("images", "audio", "fps") - FUNCTION = "get_components" + def execute(cls, images: ImageInput, fps: float, audio: Optional[AudioInput] = None) -> io.NodeOutput: + return io.NodeOutput( + VideoFromComponents(VideoComponents(images=images, audio=audio, frame_rate=Fraction(fps))) + ) - CATEGORY = "image/video" - DESCRIPTION = "Extracts all components from a video: frames, audio, and framerate." +class GetVideoComponents(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="GetVideoComponents", + display_name="Get Video Components", + category="image/video", + description="Extracts all components from a video: frames, audio, and framerate.", + inputs=[ + io.Video.Input("video", tooltip="The video to extract components from."), + ], + outputs=[ + io.Image.Output(display_name="images"), + io.Audio.Output(display_name="audio"), + io.Float.Output(display_name="fps"), + ], + ) - def get_components(self, video: Input.Video): + @classmethod + def execute(cls, video: VideoInput) -> io.NodeOutput: components = video.get_components() - return (components.images, components.audio, float(components.frame_rate)) + return io.NodeOutput(components.images, components.audio, float(components.frame_rate)) -class LoadVideo(ComfyNodeABC): +class LoadVideo(io.ComfyNode): @classmethod - def INPUT_TYPES(cls): + def define_schema(cls): input_dir = folder_paths.get_input_directory() files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))] files = folder_paths.filter_files_content_types(files, ["video"]) - return {"required": - {"file": (sorted(files), {"video_upload": True})}, - } - - CATEGORY = "image/video" - - RETURN_TYPES = (IO.VIDEO,) - FUNCTION = "load_video" - def load_video(self, file): - video_path = folder_paths.get_annotated_filepath(file) - return (InputImpl.VideoFromFile(video_path),) + return io.Schema( + node_id="LoadVideo", + display_name="Load Video", + category="image/video", + inputs=[ + io.Combo.Input("file", options=sorted(files), upload=io.UploadType.video), + ], + outputs=[ + io.Video.Output(), + ], + ) @classmethod - def IS_CHANGED(cls, file): + def execute(cls, file) -> io.NodeOutput: + video_path = folder_paths.get_annotated_filepath(file) + return io.NodeOutput(VideoFromFile(video_path)) + + @classmethod + def fingerprint_inputs(s, file): video_path = folder_paths.get_annotated_filepath(file) mod_time = os.path.getmtime(video_path) # Instead of hashing the file, we can just use the modification time to avoid @@ -217,24 +196,23 @@ class LoadVideo(ComfyNodeABC): return mod_time @classmethod - def VALIDATE_INPUTS(cls, file): + def validate_inputs(s, file): if not folder_paths.exists_annotated_filepath(file): return "Invalid video file: {}".format(file) return True -NODE_CLASS_MAPPINGS = { - "SaveWEBM": SaveWEBM, - "SaveVideo": SaveVideo, - "CreateVideo": CreateVideo, - "GetVideoComponents": GetVideoComponents, - "LoadVideo": LoadVideo, -} -NODE_DISPLAY_NAME_MAPPINGS = { - "SaveVideo": "Save Video", - "CreateVideo": "Create Video", - 
"GetVideoComponents": "Get Video Components", - "LoadVideo": "Load Video", -} +class VideoExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + SaveWEBM, + SaveVideo, + CreateVideo, + GetVideoComponents, + LoadVideo, + ] +async def comfy_entrypoint() -> VideoExtension: + return VideoExtension() From d2c502e629ba948029abc13ef1b456b9f4bbbdaa Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Sun, 31 Aug 2025 06:20:17 +0300 Subject: [PATCH 142/325] convert nodes_stability.py to V3 schema (#9497) --- comfy_api_nodes/nodes_stability.py | 678 ++++++++++++++++------------- 1 file changed, 365 insertions(+), 313 deletions(-) diff --git a/comfy_api_nodes/nodes_stability.py b/comfy_api_nodes/nodes_stability.py index 31309d831..e05cb6bb2 100644 --- a/comfy_api_nodes/nodes_stability.py +++ b/comfy_api_nodes/nodes_stability.py @@ -1,5 +1,8 @@ from inspect import cleandoc -from comfy.comfy_types.node_typing import IO +from typing import Optional +from typing_extensions import override + +from comfy_api.latest import ComfyExtension, io as comfy_io from comfy_api_nodes.apis.stability_api import ( StabilityUpscaleConservativeRequest, StabilityUpscaleCreativeRequest, @@ -46,87 +49,94 @@ def get_async_dummy_status(x: StabilityResultsGetResponse): return StabilityPollStatus.in_progress -class StabilityStableImageUltraNode: +class StabilityStableImageUltraNode(comfy_io.ComfyNode): """ Generates images synchronously based on prompt and resolution. """ - RETURN_TYPES = (IO.IMAGE,) - DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value - FUNCTION = "api_call" - API_NODE = True - CATEGORY = "api node/image/Stability AI" - @classmethod - def INPUT_TYPES(s): - return { - "required": { - "prompt": ( - IO.STRING, - { - "multiline": True, - "default": "", - "tooltip": "What you wish to see in the output image. A strong, descriptive prompt that clearly defines" + - "What you wish to see in the output image. A strong, descriptive prompt that clearly defines" + + def define_schema(cls): + return comfy_io.Schema( + node_id="StabilityStableImageUltraNode", + display_name="Stability AI Stable Image Ultra", + category="api node/image/Stability AI", + description=cleandoc(cls.__doc__ or ""), + inputs=[ + comfy_io.String.Input( + "prompt", + multiline=True, + default="", + tooltip="What you wish to see in the output image. A strong, descriptive prompt that clearly defines" + "elements, colors, and subjects will lead to better results. " + "To control the weight of a given word use the format `(word:weight)`," + "where `word` is the word you'd like to control the weight of and `weight`" + "is a value between 0 and 1. For example: `The sky was a crisp (blue:0.3) and (green:0.8)`" + - "would convey a sky that was blue and green, but more green than blue." 
- }, + "would convey a sky that was blue and green, but more green than blue.", ), - "aspect_ratio": ([x.value for x in StabilityAspectRatio], - { - "default": StabilityAspectRatio.ratio_1_1, - "tooltip": "Aspect ratio of generated image.", - }, + comfy_io.Combo.Input( + "aspect_ratio", + options=[x.value for x in StabilityAspectRatio], + default=StabilityAspectRatio.ratio_1_1.value, + tooltip="Aspect ratio of generated image.", ), - "style_preset": (get_stability_style_presets(), - { - "tooltip": "Optional desired style of generated image.", - }, + comfy_io.Combo.Input( + "style_preset", + options=get_stability_style_presets(), + tooltip="Optional desired style of generated image.", ), - "seed": ( - IO.INT, - { - "default": 0, - "min": 0, - "max": 4294967294, - "control_after_generate": True, - "tooltip": "The random seed used for creating the noise.", - }, + comfy_io.Int.Input( + "seed", + default=0, + min=0, + max=4294967294, + step=1, + display_mode=comfy_io.NumberDisplay.number, + control_after_generate=True, + tooltip="The random seed used for creating the noise.", ), - }, - "optional": { - "image": (IO.IMAGE,), - "negative_prompt": ( - IO.STRING, - { - "default": "", - "forceInput": True, - "tooltip": "A blurb of text describing what you do not wish to see in the output image. This is an advanced feature." - }, + comfy_io.Image.Input( + "image", + optional=True, ), - "image_denoise": ( - IO.FLOAT, - { - "default": 0.5, - "min": 0.0, - "max": 1.0, - "step": 0.01, - "tooltip": "Denoise of input image; 0.0 yields image identical to input, 1.0 is as if no image was provided at all.", - }, + comfy_io.String.Input( + "negative_prompt", + default="", + tooltip="A blurb of text describing what you do not wish to see in the output image. This is an advanced feature.", + force_input=True, + optional=True, ), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - }, - } + comfy_io.Float.Input( + "image_denoise", + default=0.5, + min=0.0, + max=1.0, + step=0.01, + tooltip="Denoise of input image; 0.0 yields image identical to input, 1.0 is as if no image was provided at all.", + optional=True, + ), + ], + outputs=[ + comfy_io.Image.Output(), + ], + hidden=[ + comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + comfy_io.Hidden.unique_id, + ], + is_api_node=True, + ) - async def api_call(self, prompt: str, aspect_ratio: str, style_preset: str, seed: int, - negative_prompt: str=None, image: torch.Tensor = None, image_denoise: float=None, - **kwargs): + @classmethod + async def execute( + cls, + prompt: str, + aspect_ratio: str, + style_preset: str, + seed: int, + image: Optional[torch.Tensor] = None, + negative_prompt: str = "", + image_denoise: Optional[float] = 0.5, + ) -> comfy_io.NodeOutput: validate_string(prompt, strip_whitespace=False) # prepare image binary if image present image_binary = None @@ -144,6 +154,11 @@ class StabilityStableImageUltraNode: "image": image_binary } + auth = { + "auth_token": cls.hidden.auth_token_comfy_org, + "comfy_api_key": cls.hidden.api_key_comfy_org, + } + operation = SynchronousOperation( endpoint=ApiEndpoint( path="/proxy/stability/v2beta/stable-image/generate/ultra", @@ -161,7 +176,7 @@ class StabilityStableImageUltraNode: ), files=files, content_type="multipart/form-data", - auth_kwargs=kwargs, + auth_kwargs=auth, ) response_api = await operation.execute() @@ -171,95 +186,106 @@ class StabilityStableImageUltraNode: image_data = base64.b64decode(response_api.image) returned_image = 
bytesio_to_image_tensor(BytesIO(image_data)) - return (returned_image,) + return comfy_io.NodeOutput(returned_image) -class StabilityStableImageSD_3_5Node: +class StabilityStableImageSD_3_5Node(comfy_io.ComfyNode): """ Generates images synchronously based on prompt and resolution. """ - RETURN_TYPES = (IO.IMAGE,) - DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value - FUNCTION = "api_call" - API_NODE = True - CATEGORY = "api node/image/Stability AI" + @classmethod + def define_schema(cls): + return comfy_io.Schema( + node_id="StabilityStableImageSD_3_5Node", + display_name="Stability AI Stable Diffusion 3.5 Image", + category="api node/image/Stability AI", + description=cleandoc(cls.__doc__ or ""), + inputs=[ + comfy_io.String.Input( + "prompt", + multiline=True, + default="", + tooltip="What you wish to see in the output image. A strong, descriptive prompt that clearly defines elements, colors, and subjects will lead to better results.", + ), + comfy_io.Combo.Input( + "model", + options=[x.value for x in Stability_SD3_5_Model], + ), + comfy_io.Combo.Input( + "aspect_ratio", + options=[x.value for x in StabilityAspectRatio], + default=StabilityAspectRatio.ratio_1_1.value, + tooltip="Aspect ratio of generated image.", + ), + comfy_io.Combo.Input( + "style_preset", + options=get_stability_style_presets(), + tooltip="Optional desired style of generated image.", + ), + comfy_io.Float.Input( + "cfg_scale", + default=4.0, + min=1.0, + max=10.0, + step=0.1, + tooltip="How strictly the diffusion process adheres to the prompt text (higher values keep your image closer to your prompt)", + ), + comfy_io.Int.Input( + "seed", + default=0, + min=0, + max=4294967294, + step=1, + display_mode=comfy_io.NumberDisplay.number, + control_after_generate=True, + tooltip="The random seed used for creating the noise.", + ), + comfy_io.Image.Input( + "image", + optional=True, + ), + comfy_io.String.Input( + "negative_prompt", + default="", + tooltip="Keywords of what you do not wish to see in the output image. This is an advanced feature.", + force_input=True, + optional=True, + ), + comfy_io.Float.Input( + "image_denoise", + default=0.5, + min=0.0, + max=1.0, + step=0.01, + tooltip="Denoise of input image; 0.0 yields image identical to input, 1.0 is as if no image was provided at all.", + optional=True, + ), + ], + outputs=[ + comfy_io.Image.Output(), + ], + hidden=[ + comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + comfy_io.Hidden.unique_id, + ], + is_api_node=True, + ) @classmethod - def INPUT_TYPES(s): - return { - "required": { - "prompt": ( - IO.STRING, - { - "multiline": True, - "default": "", - "tooltip": "What you wish to see in the output image. A strong, descriptive prompt that clearly defines elements, colors, and subjects will lead to better results." 
- }, - ), - "model": ([x.value for x in Stability_SD3_5_Model],), - "aspect_ratio": ([x.value for x in StabilityAspectRatio], - { - "default": StabilityAspectRatio.ratio_1_1, - "tooltip": "Aspect ratio of generated image.", - }, - ), - "style_preset": (get_stability_style_presets(), - { - "tooltip": "Optional desired style of generated image.", - }, - ), - "cfg_scale": ( - IO.FLOAT, - { - "default": 4.0, - "min": 1.0, - "max": 10.0, - "step": 0.1, - "tooltip": "How strictly the diffusion process adheres to the prompt text (higher values keep your image closer to your prompt)", - }, - ), - "seed": ( - IO.INT, - { - "default": 0, - "min": 0, - "max": 4294967294, - "control_after_generate": True, - "tooltip": "The random seed used for creating the noise.", - }, - ), - }, - "optional": { - "image": (IO.IMAGE,), - "negative_prompt": ( - IO.STRING, - { - "default": "", - "forceInput": True, - "tooltip": "Keywords of what you do not wish to see in the output image. This is an advanced feature." - }, - ), - "image_denoise": ( - IO.FLOAT, - { - "default": 0.5, - "min": 0.0, - "max": 1.0, - "step": 0.01, - "tooltip": "Denoise of input image; 0.0 yields image identical to input, 1.0 is as if no image was provided at all.", - }, - ), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - }, - } - - async def api_call(self, model: str, prompt: str, aspect_ratio: str, style_preset: str, seed: int, cfg_scale: float, - negative_prompt: str=None, image: torch.Tensor = None, image_denoise: float=None, - **kwargs): + async def execute( + cls, + model: str, + prompt: str, + aspect_ratio: str, + style_preset: str, + seed: int, + cfg_scale: float, + image: Optional[torch.Tensor] = None, + negative_prompt: str = "", + image_denoise: Optional[float] = 0.5, + ) -> comfy_io.NodeOutput: validate_string(prompt, strip_whitespace=False) # prepare image binary if image present image_binary = None @@ -280,6 +306,11 @@ class StabilityStableImageSD_3_5Node: "image": image_binary } + auth = { + "auth_token": cls.hidden.auth_token_comfy_org, + "comfy_api_key": cls.hidden.api_key_comfy_org, + } + operation = SynchronousOperation( endpoint=ApiEndpoint( path="/proxy/stability/v2beta/stable-image/generate/sd3", @@ -300,7 +331,7 @@ class StabilityStableImageSD_3_5Node: ), files=files, content_type="multipart/form-data", - auth_kwargs=kwargs, + auth_kwargs=auth, ) response_api = await operation.execute() @@ -310,72 +341,75 @@ class StabilityStableImageSD_3_5Node: image_data = base64.b64decode(response_api.image) returned_image = bytesio_to_image_tensor(BytesIO(image_data)) - return (returned_image,) + return comfy_io.NodeOutput(returned_image) -class StabilityUpscaleConservativeNode: +class StabilityUpscaleConservativeNode(comfy_io.ComfyNode): """ Upscale image with minimal alterations to 4K resolution. """ - RETURN_TYPES = (IO.IMAGE,) - DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value - FUNCTION = "api_call" - API_NODE = True - CATEGORY = "api node/image/Stability AI" + @classmethod + def define_schema(cls): + return comfy_io.Schema( + node_id="StabilityUpscaleConservativeNode", + display_name="Stability AI Upscale Conservative", + category="api node/image/Stability AI", + description=cleandoc(cls.__doc__ or ""), + inputs=[ + comfy_io.Image.Input("image"), + comfy_io.String.Input( + "prompt", + multiline=True, + default="", + tooltip="What you wish to see in the output image. 
A strong, descriptive prompt that clearly defines elements, colors, and subjects will lead to better results.", + ), + comfy_io.Float.Input( + "creativity", + default=0.35, + min=0.2, + max=0.5, + step=0.01, + tooltip="Controls the likelihood of creating additional details not heavily conditioned by the init image.", + ), + comfy_io.Int.Input( + "seed", + default=0, + min=0, + max=4294967294, + step=1, + display_mode=comfy_io.NumberDisplay.number, + control_after_generate=True, + tooltip="The random seed used for creating the noise.", + ), + comfy_io.String.Input( + "negative_prompt", + default="", + tooltip="Keywords of what you do not wish to see in the output image. This is an advanced feature.", + force_input=True, + optional=True, + ), + ], + outputs=[ + comfy_io.Image.Output(), + ], + hidden=[ + comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + comfy_io.Hidden.unique_id, + ], + is_api_node=True, + ) @classmethod - def INPUT_TYPES(s): - return { - "required": { - "image": (IO.IMAGE,), - "prompt": ( - IO.STRING, - { - "multiline": True, - "default": "", - "tooltip": "What you wish to see in the output image. A strong, descriptive prompt that clearly defines elements, colors, and subjects will lead to better results." - }, - ), - "creativity": ( - IO.FLOAT, - { - "default": 0.35, - "min": 0.2, - "max": 0.5, - "step": 0.01, - "tooltip": "Controls the likelihood of creating additional details not heavily conditioned by the init image.", - }, - ), - "seed": ( - IO.INT, - { - "default": 0, - "min": 0, - "max": 4294967294, - "control_after_generate": True, - "tooltip": "The random seed used for creating the noise.", - }, - ), - }, - "optional": { - "negative_prompt": ( - IO.STRING, - { - "default": "", - "forceInput": True, - "tooltip": "Keywords of what you do not wish to see in the output image. This is an advanced feature." - }, - ), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - }, - } - - async def api_call(self, image: torch.Tensor, prompt: str, creativity: float, seed: int, negative_prompt: str=None, - **kwargs): + async def execute( + cls, + image: torch.Tensor, + prompt: str, + creativity: float, + seed: int, + negative_prompt: str = "", + ) -> comfy_io.NodeOutput: validate_string(prompt, strip_whitespace=False) image_binary = tensor_to_bytesio(image, total_pixels=1024*1024).read() @@ -386,6 +420,11 @@ class StabilityUpscaleConservativeNode: "image": image_binary } + auth = { + "auth_token": cls.hidden.auth_token_comfy_org, + "comfy_api_key": cls.hidden.api_key_comfy_org, + } + operation = SynchronousOperation( endpoint=ApiEndpoint( path="/proxy/stability/v2beta/stable-image/upscale/conservative", @@ -401,7 +440,7 @@ class StabilityUpscaleConservativeNode: ), files=files, content_type="multipart/form-data", - auth_kwargs=kwargs, + auth_kwargs=auth, ) response_api = await operation.execute() @@ -411,77 +450,81 @@ class StabilityUpscaleConservativeNode: image_data = base64.b64decode(response_api.image) returned_image = bytesio_to_image_tensor(BytesIO(image_data)) - return (returned_image,) + return comfy_io.NodeOutput(returned_image) -class StabilityUpscaleCreativeNode: +class StabilityUpscaleCreativeNode(comfy_io.ComfyNode): """ Upscale image with minimal alterations to 4K resolution. 
""" - RETURN_TYPES = (IO.IMAGE,) - DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value - FUNCTION = "api_call" - API_NODE = True - CATEGORY = "api node/image/Stability AI" + @classmethod + def define_schema(cls): + return comfy_io.Schema( + node_id="StabilityUpscaleCreativeNode", + display_name="Stability AI Upscale Creative", + category="api node/image/Stability AI", + description=cleandoc(cls.__doc__ or ""), + inputs=[ + comfy_io.Image.Input("image"), + comfy_io.String.Input( + "prompt", + multiline=True, + default="", + tooltip="What you wish to see in the output image. A strong, descriptive prompt that clearly defines elements, colors, and subjects will lead to better results.", + ), + comfy_io.Float.Input( + "creativity", + default=0.3, + min=0.1, + max=0.5, + step=0.01, + tooltip="Controls the likelihood of creating additional details not heavily conditioned by the init image.", + ), + comfy_io.Combo.Input( + "style_preset", + options=get_stability_style_presets(), + tooltip="Optional desired style of generated image.", + ), + comfy_io.Int.Input( + "seed", + default=0, + min=0, + max=4294967294, + step=1, + display_mode=comfy_io.NumberDisplay.number, + control_after_generate=True, + tooltip="The random seed used for creating the noise.", + ), + comfy_io.String.Input( + "negative_prompt", + default="", + tooltip="Keywords of what you do not wish to see in the output image. This is an advanced feature.", + force_input=True, + optional=True, + ), + ], + outputs=[ + comfy_io.Image.Output(), + ], + hidden=[ + comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + comfy_io.Hidden.unique_id, + ], + is_api_node=True, + ) @classmethod - def INPUT_TYPES(s): - return { - "required": { - "image": (IO.IMAGE,), - "prompt": ( - IO.STRING, - { - "multiline": True, - "default": "", - "tooltip": "What you wish to see in the output image. A strong, descriptive prompt that clearly defines elements, colors, and subjects will lead to better results." - }, - ), - "creativity": ( - IO.FLOAT, - { - "default": 0.3, - "min": 0.1, - "max": 0.5, - "step": 0.01, - "tooltip": "Controls the likelihood of creating additional details not heavily conditioned by the init image.", - }, - ), - "style_preset": (get_stability_style_presets(), - { - "tooltip": "Optional desired style of generated image.", - }, - ), - "seed": ( - IO.INT, - { - "default": 0, - "min": 0, - "max": 4294967294, - "control_after_generate": True, - "tooltip": "The random seed used for creating the noise.", - }, - ), - }, - "optional": { - "negative_prompt": ( - IO.STRING, - { - "default": "", - "forceInput": True, - "tooltip": "Keywords of what you do not wish to see in the output image. This is an advanced feature." 
- }, - ), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - }, - } - - async def api_call(self, image: torch.Tensor, prompt: str, creativity: float, style_preset: str, seed: int, negative_prompt: str=None, - **kwargs): + async def execute( + cls, + image: torch.Tensor, + prompt: str, + creativity: float, + style_preset: str, + seed: int, + negative_prompt: str = "", + ) -> comfy_io.NodeOutput: validate_string(prompt, strip_whitespace=False) image_binary = tensor_to_bytesio(image, total_pixels=1024*1024).read() @@ -494,6 +537,11 @@ class StabilityUpscaleCreativeNode: "image": image_binary } + auth = { + "auth_token": cls.hidden.auth_token_comfy_org, + "comfy_api_key": cls.hidden.api_key_comfy_org, + } + operation = SynchronousOperation( endpoint=ApiEndpoint( path="/proxy/stability/v2beta/stable-image/upscale/creative", @@ -510,7 +558,7 @@ class StabilityUpscaleCreativeNode: ), files=files, content_type="multipart/form-data", - auth_kwargs=kwargs, + auth_kwargs=auth, ) response_api = await operation.execute() @@ -525,7 +573,8 @@ class StabilityUpscaleCreativeNode: completed_statuses=[StabilityPollStatus.finished], failed_statuses=[StabilityPollStatus.failed], status_extractor=lambda x: get_async_dummy_status(x), - auth_kwargs=kwargs, + auth_kwargs=auth, + node_id=cls.hidden.unique_id, ) response_poll: StabilityResultsGetResponse = await operation.execute() @@ -535,41 +584,48 @@ class StabilityUpscaleCreativeNode: image_data = base64.b64decode(response_poll.result) returned_image = bytesio_to_image_tensor(BytesIO(image_data)) - return (returned_image,) + return comfy_io.NodeOutput(returned_image) -class StabilityUpscaleFastNode: +class StabilityUpscaleFastNode(comfy_io.ComfyNode): """ Quickly upscales an image via Stability API call to 4x its original size; intended for upscaling low-quality/compressed images. 
""" - RETURN_TYPES = (IO.IMAGE,) - DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value - FUNCTION = "api_call" - API_NODE = True - CATEGORY = "api node/image/Stability AI" + @classmethod + def define_schema(cls): + return comfy_io.Schema( + node_id="StabilityUpscaleFastNode", + display_name="Stability AI Upscale Fast", + category="api node/image/Stability AI", + description=cleandoc(cls.__doc__ or ""), + inputs=[ + comfy_io.Image.Input("image"), + ], + outputs=[ + comfy_io.Image.Output(), + ], + hidden=[ + comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + comfy_io.Hidden.unique_id, + ], + is_api_node=True, + ) @classmethod - def INPUT_TYPES(s): - return { - "required": { - "image": (IO.IMAGE,), - }, - "optional": { - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - }, - } - - async def api_call(self, image: torch.Tensor, **kwargs): + async def execute(cls, image: torch.Tensor) -> comfy_io.NodeOutput: image_binary = tensor_to_bytesio(image, total_pixels=4096*4096).read() files = { "image": image_binary } + auth = { + "auth_token": cls.hidden.auth_token_comfy_org, + "comfy_api_key": cls.hidden.api_key_comfy_org, + } + operation = SynchronousOperation( endpoint=ApiEndpoint( path="/proxy/stability/v2beta/stable-image/upscale/fast", @@ -580,7 +636,7 @@ class StabilityUpscaleFastNode: request=EmptyRequest(), files=files, content_type="multipart/form-data", - auth_kwargs=kwargs, + auth_kwargs=auth, ) response_api = await operation.execute() @@ -590,24 +646,20 @@ class StabilityUpscaleFastNode: image_data = base64.b64decode(response_api.image) returned_image = bytesio_to_image_tensor(BytesIO(image_data)) - return (returned_image,) + return comfy_io.NodeOutput(returned_image) -# A dictionary that contains all nodes you want to export with their names -# NOTE: names should be globally unique -NODE_CLASS_MAPPINGS = { - "StabilityStableImageUltraNode": StabilityStableImageUltraNode, - "StabilityStableImageSD_3_5Node": StabilityStableImageSD_3_5Node, - "StabilityUpscaleConservativeNode": StabilityUpscaleConservativeNode, - "StabilityUpscaleCreativeNode": StabilityUpscaleCreativeNode, - "StabilityUpscaleFastNode": StabilityUpscaleFastNode, -} +class StabilityExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[comfy_io.ComfyNode]]: + return [ + StabilityStableImageUltraNode, + StabilityStableImageSD_3_5Node, + StabilityUpscaleConservativeNode, + StabilityUpscaleCreativeNode, + StabilityUpscaleFastNode, + ] -# A dictionary that contains the friendly/humanly readable titles for the nodes -NODE_DISPLAY_NAME_MAPPINGS = { - "StabilityStableImageUltraNode": "Stability AI Stable Image Ultra", - "StabilityStableImageSD_3_5Node": "Stability AI Stable Diffusion 3.5 Image", - "StabilityUpscaleConservativeNode": "Stability AI Upscale Conservative", - "StabilityUpscaleCreativeNode": "Stability AI Upscale Creative", - "StabilityUpscaleFastNode": "Stability AI Upscale Fast", -} + +async def comfy_entrypoint() -> StabilityExtension: + return StabilityExtension() From fe442fac2eccd0cc66999b48d3c518623cafe4fc Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Sun, 31 Aug 2025 06:21:58 +0300 Subject: [PATCH 143/325] convert Primitive nodes to V3 schema (#9372) --- comfy_extras/nodes_primitive.py | 169 +++++++++++++++++--------------- 1 file changed, 90 insertions(+), 79 deletions(-) diff --git a/comfy_extras/nodes_primitive.py 
b/comfy_extras/nodes_primitive.py index 1f93f87a7..5a1aeba80 100644 --- a/comfy_extras/nodes_primitive.py +++ b/comfy_extras/nodes_primitive.py @@ -1,98 +1,109 @@ -# Primitive nodes that are evaluated at backend. -from __future__ import annotations - import sys +from typing_extensions import override -from comfy.comfy_types.node_typing import ComfyNodeABC, InputTypeDict, IO +from comfy_api.latest import ComfyExtension, io -class String(ComfyNodeABC): +class String(io.ComfyNode): @classmethod - def INPUT_TYPES(cls) -> InputTypeDict: - return { - "required": {"value": (IO.STRING, {})}, - } + def define_schema(cls): + return io.Schema( + node_id="PrimitiveString", + display_name="String", + category="utils/primitive", + inputs=[ + io.String.Input("value"), + ], + outputs=[io.String.Output()], + ) - RETURN_TYPES = (IO.STRING,) - FUNCTION = "execute" - CATEGORY = "utils/primitive" - - def execute(self, value: str) -> tuple[str]: - return (value,) - - -class StringMultiline(ComfyNodeABC): @classmethod - def INPUT_TYPES(cls) -> InputTypeDict: - return { - "required": {"value": (IO.STRING, {"multiline": True,},)}, - } - - RETURN_TYPES = (IO.STRING,) - FUNCTION = "execute" - CATEGORY = "utils/primitive" - - def execute(self, value: str) -> tuple[str]: - return (value,) + def execute(cls, value: str) -> io.NodeOutput: + return io.NodeOutput(value) -class Int(ComfyNodeABC): +class StringMultiline(io.ComfyNode): @classmethod - def INPUT_TYPES(cls) -> InputTypeDict: - return { - "required": {"value": (IO.INT, {"min": -sys.maxsize, "max": sys.maxsize, "control_after_generate": True})}, - } + def define_schema(cls): + return io.Schema( + node_id="PrimitiveStringMultiline", + display_name="String (Multiline)", + category="utils/primitive", + inputs=[ + io.String.Input("value", multiline=True), + ], + outputs=[io.String.Output()], + ) - RETURN_TYPES = (IO.INT,) - FUNCTION = "execute" - CATEGORY = "utils/primitive" - - def execute(self, value: int) -> tuple[int]: - return (value,) - - -class Float(ComfyNodeABC): @classmethod - def INPUT_TYPES(cls) -> InputTypeDict: - return { - "required": {"value": (IO.FLOAT, {"min": -sys.maxsize, "max": sys.maxsize})}, - } - - RETURN_TYPES = (IO.FLOAT,) - FUNCTION = "execute" - CATEGORY = "utils/primitive" - - def execute(self, value: float) -> tuple[float]: - return (value,) + def execute(cls, value: str) -> io.NodeOutput: + return io.NodeOutput(value) -class Boolean(ComfyNodeABC): +class Int(io.ComfyNode): @classmethod - def INPUT_TYPES(cls) -> InputTypeDict: - return { - "required": {"value": (IO.BOOLEAN, {})}, - } + def define_schema(cls): + return io.Schema( + node_id="PrimitiveInt", + display_name="Int", + category="utils/primitive", + inputs=[ + io.Int.Input("value", min=-sys.maxsize, max=sys.maxsize, control_after_generate=True), + ], + outputs=[io.Int.Output()], + ) - RETURN_TYPES = (IO.BOOLEAN,) - FUNCTION = "execute" - CATEGORY = "utils/primitive" - - def execute(self, value: bool) -> tuple[bool]: - return (value,) + @classmethod + def execute(cls, value: int) -> io.NodeOutput: + return io.NodeOutput(value) -NODE_CLASS_MAPPINGS = { - "PrimitiveString": String, - "PrimitiveStringMultiline": StringMultiline, - "PrimitiveInt": Int, - "PrimitiveFloat": Float, - "PrimitiveBoolean": Boolean, -} +class Float(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="PrimitiveFloat", + display_name="Float", + category="utils/primitive", + inputs=[ + io.Float.Input("value", min=-sys.maxsize, max=sys.maxsize), + ], + 
outputs=[io.Float.Output()], + ) -NODE_DISPLAY_NAME_MAPPINGS = { - "PrimitiveString": "String", - "PrimitiveStringMultiline": "String (Multiline)", - "PrimitiveInt": "Int", - "PrimitiveFloat": "Float", - "PrimitiveBoolean": "Boolean", -} + @classmethod + def execute(cls, value: float) -> io.NodeOutput: + return io.NodeOutput(value) + + +class Boolean(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="PrimitiveBoolean", + display_name="Boolean", + category="utils/primitive", + inputs=[ + io.Boolean.Input("value"), + ], + outputs=[io.Boolean.Output()], + ) + + @classmethod + def execute(cls, value: bool) -> io.NodeOutput: + return io.NodeOutput(value) + + +class PrimitivesExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + String, + StringMultiline, + Int, + Float, + Boolean, + ] + +async def comfy_entrypoint() -> PrimitivesExtension: + return PrimitivesExtension() From 32a627bf1feadb83abba97906a27978b927abd33 Mon Sep 17 00:00:00 2001 From: chaObserv <154517000+chaObserv@users.noreply.github.com> Date: Sun, 31 Aug 2025 12:01:45 +0800 Subject: [PATCH 144/325] SEEDS: update noise decomposition and refactor (#9633) - Update the decomposition to reflect interval dependency - Extract phi computations into functions - Use torch.lerp for interpolation --- comfy/k_diffusion/sampling.py | 135 ++++++++++++++++++---------------- 1 file changed, 73 insertions(+), 62 deletions(-) diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py index fe6844b17..2d7e09838 100644 --- a/comfy/k_diffusion/sampling.py +++ b/comfy/k_diffusion/sampling.py @@ -171,6 +171,16 @@ def offset_first_sigma_for_snr(sigmas, model_sampling, percent_offset=1e-4): return sigmas +def ei_h_phi_1(h: torch.Tensor) -> torch.Tensor: + """Compute the result of h*phi_1(h) in exponential integrator methods.""" + return torch.expm1(h) + + +def ei_h_phi_2(h: torch.Tensor) -> torch.Tensor: + """Compute the result of h*phi_2(h) in exponential integrator methods.""" + return (torch.expm1(h) - h) / h + + @torch.no_grad() def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.): """Implements Algorithm 2 (Euler steps) from Karras et al. (2022).""" @@ -1550,13 +1560,12 @@ def sample_er_sde(model, x, sigmas, extra_args=None, callback=None, disable=None @torch.no_grad() def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=0.5): """SEEDS-2 - Stochastic Explicit Exponential Derivative-free Solvers (VP Data Prediction) stage 2. 
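For reference, the two helpers added above are the usual exponential-integrator quantities phi_1(h) = (e^h - 1) / h and phi_2(h) = (e^h - 1 - h) / h^2, returned in the numerically safer h * phi(h) form via expm1. A small standalone PyTorch check of those identities (nothing ComfyUI-specific):

import torch

h = torch.tensor([-0.5, 0.25, 1.0])
phi_1 = (torch.exp(h) - 1.0) / h              # phi_1(h) = (e^h - 1) / h
phi_2 = (torch.exp(h) - 1.0 - h) / h**2       # phi_2(h) = (e^h - 1 - h) / h^2

# h * phi_1(h) and h * phi_2(h) are exactly what ei_h_phi_1 / ei_h_phi_2 compute
assert torch.allclose(h * phi_1, torch.expm1(h))
assert torch.allclose(h * phi_2, (torch.expm1(h) - h) / h)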
- arXiv: https://arxiv.org/abs/2305.14267 + arXiv: https://arxiv.org/abs/2305.14267 (NeurIPS 2023) """ extra_args = {} if extra_args is None else extra_args seed = extra_args.get("seed", None) noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler s_in = x.new_ones([x.shape[0]]) - inject_noise = eta > 0 and s_noise > 0 model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling') @@ -1564,55 +1573,53 @@ def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=Non lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling) sigmas = offset_first_sigma_for_snr(sigmas, model_sampling) + fac = 1 / (2 * r) + for i in trange(len(sigmas) - 1, disable=disable): denoised = model(x, sigmas[i] * s_in, **extra_args) if callback is not None: callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) + if sigmas[i + 1] == 0: x = denoised - else: - lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1]) - h = lambda_t - lambda_s - h_eta = h * (eta + 1) - lambda_s_1 = lambda_s + r * h - fac = 1 / (2 * r) - sigma_s_1 = sigma_fn(lambda_s_1) + continue - # alpha_t = sigma_t * exp(log(alpha_t / sigma_t)) = sigma_t * exp(lambda_t) - alpha_s_1 = sigma_s_1 * lambda_s_1.exp() - alpha_t = sigmas[i + 1] * lambda_t.exp() + lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1]) + h = lambda_t - lambda_s + h_eta = h * (eta + 1) + lambda_s_1 = torch.lerp(lambda_s, lambda_t, r) + sigma_s_1 = sigma_fn(lambda_s_1) - coeff_1, coeff_2 = (-r * h_eta).expm1(), (-h_eta).expm1() - if inject_noise: - # 0 < r < 1 - noise_coeff_1 = (-2 * r * h * eta).expm1().neg().sqrt() - noise_coeff_2 = (-r * h * eta).exp() * (-2 * (1 - r) * h * eta).expm1().neg().sqrt() - noise_1, noise_2 = noise_sampler(sigmas[i], sigma_s_1), noise_sampler(sigma_s_1, sigmas[i + 1]) + alpha_s_1 = sigma_s_1 * lambda_s_1.exp() + alpha_t = sigmas[i + 1] * lambda_t.exp() - # Step 1 - x_2 = sigma_s_1 / sigmas[i] * (-r * h * eta).exp() * x - alpha_s_1 * coeff_1 * denoised - if inject_noise: - x_2 = x_2 + sigma_s_1 * (noise_coeff_1 * noise_1) * s_noise - denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args) + # Step 1 + x_2 = sigma_s_1 / sigmas[i] * (-r * h * eta).exp() * x - alpha_s_1 * ei_h_phi_1(-r * h_eta) * denoised + if inject_noise: + sde_noise = (-2 * r * h * eta).expm1().neg().sqrt() * noise_sampler(sigmas[i], sigma_s_1) + x_2 = x_2 + sde_noise * sigma_s_1 * s_noise + denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args) - # Step 2 - denoised_d = (1 - fac) * denoised + fac * denoised_2 - x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * coeff_2 * denoised_d - if inject_noise: - x = x + sigmas[i + 1] * (noise_coeff_2 * noise_1 + noise_coeff_1 * noise_2) * s_noise + # Step 2 + denoised_d = torch.lerp(denoised, denoised_2, fac) + x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * ei_h_phi_1(-h_eta) * denoised_d + if inject_noise: + segment_factor = (r - 1) * h * eta + sde_noise = sde_noise * segment_factor.exp() + sde_noise = sde_noise + segment_factor.mul(2).expm1().neg().sqrt() * noise_sampler(sigma_s_1, sigmas[i + 1]) + x = x + sde_noise * sigmas[i + 1] * s_noise return x @torch.no_grad() def sample_seeds_3(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r_1=1./3, r_2=2./3): """SEEDS-3 - Stochastic Explicit Exponential Derivative-free Solvers (VP Data Prediction) stage 3. 
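The reworked noise injection no longer combines independently drawn terms with precomputed coefficients at the end of the step; instead each sub-step rescales the noise already injected by exp((r_prev - r_next) * h * eta) and adds a fresh component for the new sub-interval, which appears to be the interval dependency referred to in the commit message. A quick scalar check (assuming h > 0 and eta > 0) that the carried-plus-fresh variance over the whole step still equals the single-interval level 1 - exp(-2 * h * eta):

import math

h, eta, r = 0.7, 1.0, 0.5                                  # arbitrary positive step, eta, midpoint

v_first = -math.expm1(-2.0 * r * h * eta)                  # variance injected on the first sub-interval
carried = math.exp(2.0 * (r - 1.0) * h * eta) * v_first    # after rescaling by exp((r - 1) * h * eta)
v_fresh = -math.expm1(2.0 * (r - 1.0) * h * eta)           # fresh variance on the second sub-interval

assert abs(carried + v_fresh - (-math.expm1(-2.0 * h * eta))) < 1e-12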
- arXiv: https://arxiv.org/abs/2305.14267 + arXiv: https://arxiv.org/abs/2305.14267 (NeurIPS 2023) """ extra_args = {} if extra_args is None else extra_args seed = extra_args.get("seed", None) noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler s_in = x.new_ones([x.shape[0]]) - inject_noise = eta > 0 and s_noise > 0 model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling') @@ -1624,45 +1631,49 @@ def sample_seeds_3(model, x, sigmas, extra_args=None, callback=None, disable=Non denoised = model(x, sigmas[i] * s_in, **extra_args) if callback is not None: callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) + if sigmas[i + 1] == 0: x = denoised - else: - lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1]) - h = lambda_t - lambda_s - h_eta = h * (eta + 1) - lambda_s_1 = lambda_s + r_1 * h - lambda_s_2 = lambda_s + r_2 * h - sigma_s_1, sigma_s_2 = sigma_fn(lambda_s_1), sigma_fn(lambda_s_2) + continue - # alpha_t = sigma_t * exp(log(alpha_t / sigma_t)) = sigma_t * exp(lambda_t) - alpha_s_1 = sigma_s_1 * lambda_s_1.exp() - alpha_s_2 = sigma_s_2 * lambda_s_2.exp() - alpha_t = sigmas[i + 1] * lambda_t.exp() + lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1]) + h = lambda_t - lambda_s + h_eta = h * (eta + 1) + lambda_s_1 = torch.lerp(lambda_s, lambda_t, r_1) + lambda_s_2 = torch.lerp(lambda_s, lambda_t, r_2) + sigma_s_1, sigma_s_2 = sigma_fn(lambda_s_1), sigma_fn(lambda_s_2) - coeff_1, coeff_2, coeff_3 = (-r_1 * h_eta).expm1(), (-r_2 * h_eta).expm1(), (-h_eta).expm1() - if inject_noise: - # 0 < r_1 < r_2 < 1 - noise_coeff_1 = (-2 * r_1 * h * eta).expm1().neg().sqrt() - noise_coeff_2 = (-r_1 * h * eta).exp() * (-2 * (r_2 - r_1) * h * eta).expm1().neg().sqrt() - noise_coeff_3 = (-r_2 * h * eta).exp() * (-2 * (1 - r_2) * h * eta).expm1().neg().sqrt() - noise_1, noise_2, noise_3 = noise_sampler(sigmas[i], sigma_s_1), noise_sampler(sigma_s_1, sigma_s_2), noise_sampler(sigma_s_2, sigmas[i + 1]) + alpha_s_1 = sigma_s_1 * lambda_s_1.exp() + alpha_s_2 = sigma_s_2 * lambda_s_2.exp() + alpha_t = sigmas[i + 1] * lambda_t.exp() - # Step 1 - x_2 = sigma_s_1 / sigmas[i] * (-r_1 * h * eta).exp() * x - alpha_s_1 * coeff_1 * denoised - if inject_noise: - x_2 = x_2 + sigma_s_1 * (noise_coeff_1 * noise_1) * s_noise - denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args) + # Step 1 + x_2 = sigma_s_1 / sigmas[i] * (-r_1 * h * eta).exp() * x - alpha_s_1 * ei_h_phi_1(-r_1 * h_eta) * denoised + if inject_noise: + sde_noise = (-2 * r_1 * h * eta).expm1().neg().sqrt() * noise_sampler(sigmas[i], sigma_s_1) + x_2 = x_2 + sde_noise * sigma_s_1 * s_noise + denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args) - # Step 2 - x_3 = sigma_s_2 / sigmas[i] * (-r_2 * h * eta).exp() * x - alpha_s_2 * coeff_2 * denoised + (r_2 / r_1) * alpha_s_2 * (coeff_2 / (r_2 * h_eta) + 1) * (denoised_2 - denoised) - if inject_noise: - x_3 = x_3 + sigma_s_2 * (noise_coeff_2 * noise_1 + noise_coeff_1 * noise_2) * s_noise - denoised_3 = model(x_3, sigma_s_2 * s_in, **extra_args) + # Step 2 + a3_2 = r_2 / r_1 * ei_h_phi_2(-r_2 * h_eta) + a3_1 = ei_h_phi_1(-r_2 * h_eta) - a3_2 + x_3 = sigma_s_2 / sigmas[i] * (-r_2 * h * eta).exp() * x - alpha_s_2 * (a3_1 * denoised + a3_2 * denoised_2) + if inject_noise: + segment_factor = (r_1 - r_2) * h * eta + sde_noise = sde_noise * segment_factor.exp() + sde_noise = sde_noise + segment_factor.mul(2).expm1().neg().sqrt() * noise_sampler(sigma_s_1, sigma_s_2) + x_3 = x_3 + 
sde_noise * sigma_s_2 * s_noise + denoised_3 = model(x_3, sigma_s_2 * s_in, **extra_args) - # Step 3 - x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * coeff_3 * denoised + (1. / r_2) * alpha_t * (coeff_3 / h_eta + 1) * (denoised_3 - denoised) - if inject_noise: - x = x + sigmas[i + 1] * (noise_coeff_3 * noise_1 + noise_coeff_2 * noise_2 + noise_coeff_1 * noise_3) * s_noise + # Step 3 + b3 = ei_h_phi_2(-h_eta) / r_2 + b1 = ei_h_phi_1(-h_eta) - b3 + x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * (b1 * denoised + b3 * denoised_3) + if inject_noise: + segment_factor = (r_2 - 1) * h * eta + sde_noise = sde_noise * segment_factor.exp() + sde_noise = sde_noise + segment_factor.mul(2).expm1().neg().sqrt() * noise_sampler(sigma_s_2, sigmas[i + 1]) + x = x + sde_noise * sigmas[i + 1] * s_noise return x From 9b151559721ff6c8d93150f3d8a53259a23911cd Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sat, 30 Aug 2025 22:32:10 -0700 Subject: [PATCH 145/325] Probably not necessary anymore. (#9646) --- main.py | 1 - 1 file changed, 1 deletion(-) diff --git a/main.py b/main.py index b23d50816..c33f0e17b 100644 --- a/main.py +++ b/main.py @@ -113,7 +113,6 @@ import gc if os.name == "nt": os.environ['MIMALLOC_PURGE_DELAY'] = '0' - logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage()) if __name__ == "__main__": if args.default_device is not None: From 27e067ce505c102fd0f2be0f1242016c59a6816f Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 1 Sep 2025 15:54:02 -0700 Subject: [PATCH 146/325] Implement the USO subject identity lora. (#9674) Use the lora with FluxContextMultiReferenceLatentMethod node set to "uso" and a ReferenceLatent node with the reference image. --- comfy/ldm/flux/model.py | 10 ++++++++-- comfy/lora.py | 4 ++++ comfy/lora_convert.py | 19 +++++++++++++++++++ comfy_extras/nodes_flux.py | 2 +- 4 files changed, 32 insertions(+), 3 deletions(-) diff --git a/comfy/ldm/flux/model.py b/comfy/ldm/flux/model.py index 1344c3a57..1e62f4626 100644 --- a/comfy/ldm/flux/model.py +++ b/comfy/ldm/flux/model.py @@ -233,12 +233,18 @@ class Flux(nn.Module): h = 0 w = 0 index = 0 - index_ref_method = kwargs.get("ref_latents_method", "offset") == "index" + ref_latents_method = kwargs.get("ref_latents_method", "offset") for ref in ref_latents: - if index_ref_method: + if ref_latents_method == "index": index += 1 h_offset = 0 w_offset = 0 + elif ref_latents_method == "uso": + index = 0 + h_offset = h_len * patch_size + h + w_offset = w_len * patch_size + w + h += ref.shape[-2] + w += ref.shape[-1] else: index = 1 h_offset = 0 diff --git a/comfy/lora.py b/comfy/lora.py index 00358884b..4a44f1318 100644 --- a/comfy/lora.py +++ b/comfy/lora.py @@ -260,6 +260,10 @@ def model_lora_keys_unet(model, key_map={}): key_map["transformer.{}".format(k[:-len(".weight")])] = to #simpletrainer and probably regular diffusers flux lora format key_map["lycoris_{}".format(k[:-len(".weight")].replace(".", "_"))] = to #simpletrainer lycoris key_map["lora_transformer_{}".format(k[:-len(".weight")].replace(".", "_"))] = to #onetrainer + for k in sdk: + hidden_size = model.model_config.unet_config.get("hidden_size", 0) + if k.endswith(".weight") and ".linear1." 
in k: + key_map["{}".format(k.replace(".linear1.weight", ".linear1_qkv"))] = (k, (0, 0, hidden_size * 3)) if isinstance(model, comfy.model_base.GenmoMochi): for k in sdk: diff --git a/comfy/lora_convert.py b/comfy/lora_convert.py index 3e00b63db..9d8d21efe 100644 --- a/comfy/lora_convert.py +++ b/comfy/lora_convert.py @@ -15,10 +15,29 @@ def convert_lora_bfl_control(sd): #BFL loras for Flux def convert_lora_wan_fun(sd): #Wan Fun loras return comfy.utils.state_dict_prefix_replace(sd, {"lora_unet__": "lora_unet_"}) +def convert_uso_lora(sd): + sd_out = {} + for k in sd: + tensor = sd[k] + k_to = "diffusion_model.{}".format(k.replace(".down.weight", ".lora_down.weight") + .replace(".up.weight", ".lora_up.weight") + .replace(".qkv_lora2.", ".txt_attn.qkv.") + .replace(".qkv_lora1.", ".img_attn.qkv.") + .replace(".proj_lora1.", ".img_attn.proj.") + .replace(".proj_lora2.", ".txt_attn.proj.") + .replace(".qkv_lora.", ".linear1_qkv.") + .replace(".proj_lora.", ".linear2.") + .replace(".processor.", ".") + ) + sd_out[k_to] = tensor + return sd_out + def convert_lora(sd): if "img_in.lora_A.weight" in sd and "single_blocks.0.norm.key_norm.scale" in sd: return convert_lora_bfl_control(sd) if "lora_unet__blocks_0_cross_attn_k.lora_down.weight" in sd: return convert_lora_wan_fun(sd) + if "single_blocks.37.processor.qkv_lora.up.weight" in sd and "double_blocks.18.processor.qkv_lora2.up.weight" in sd: + return convert_uso_lora(sd) return sd diff --git a/comfy_extras/nodes_flux.py b/comfy_extras/nodes_flux.py index c8db75bb3..1bf7ddd92 100644 --- a/comfy_extras/nodes_flux.py +++ b/comfy_extras/nodes_flux.py @@ -105,7 +105,7 @@ class FluxKontextMultiReferenceLatentMethod: def INPUT_TYPES(s): return {"required": { "conditioning": ("CONDITIONING", ), - "reference_latents_method": (("offset", "index"), ), + "reference_latents_method": (("offset", "index", "uso"), ), }} RETURN_TYPES = ("CONDITIONING",) From e2d1e5dad98dbbcf505703ea8663f20101e6570a Mon Sep 17 00:00:00 2001 From: contentis Date: Tue, 2 Sep 2025 02:33:50 +0200 Subject: [PATCH 147/325] Enable Convolution AutoTuning (#9301) --- comfy/cli_args.py | 1 + comfy/ops.py | 3 +++ 2 files changed, 4 insertions(+) diff --git a/comfy/cli_args.py b/comfy/cli_args.py index de3e85c08..72eeaea9a 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -143,6 +143,7 @@ class PerformanceFeature(enum.Enum): Fp16Accumulation = "fp16_accumulation" Fp8MatrixMultiplication = "fp8_matrix_mult" CublasOps = "cublas_ops" + AutoTune = "autotune" parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. 
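Like the other PerformanceFeature entries, the new value is opt-in through the existing --fast flag (presumably "python main.py --fast autotune", or bare --fast, which the help text above says enables everything). A sketch of what the ops.py hunk below amounts to:

import torch

# Guarded the same way as in ops.py: only meaningful on CUDA with cuDNN available.
if torch.cuda.is_available() and torch.backends.cudnn.is_available():
    # cuDNN times candidate convolution algorithms per input shape and caches the fastest,
    # which helps with static shapes but adds warm-up cost when shapes change frequently.
    torch.backends.cudnn.benchmark = True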
Current valid optimizations: fp16_accumulation fp8_matrix_mult cublas_ops") diff --git a/comfy/ops.py b/comfy/ops.py index 18e7db705..55e958adb 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -52,6 +52,9 @@ except (ModuleNotFoundError, TypeError): cast_to = comfy.model_management.cast_to #TODO: remove once no more references +if torch.cuda.is_available() and torch.backends.cudnn.is_available() and PerformanceFeature.AutoTune in args.fast: + torch.backends.cudnn.benchmark = True + def cast_to_input(weight, input, non_blocking=False, copy=True): return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy) From 3412d53b1d69e4dfedf7e86c3092cea085094053 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 2 Sep 2025 12:36:22 -0700 Subject: [PATCH 148/325] USO style reference. (#9677) Load the projector.safetensors file with the ModelPatchLoader node and use the siglip_vision_patch14_384.safetensors "clip vision" model and the USOStyleReferenceNode. --- comfy/clip_model.py | 12 +- comfy/clip_vision.py | 18 ++- comfy/ldm/flux/model.py | 11 +- comfy/model_patcher.py | 3 + comfy_extras/nodes_model_patch.py | 186 +++++++++++++++++++++++++++++- 5 files changed, 222 insertions(+), 8 deletions(-) diff --git a/comfy/clip_model.py b/comfy/clip_model.py index 7e47d8a55..7c0cadab5 100644 --- a/comfy/clip_model.py +++ b/comfy/clip_model.py @@ -61,8 +61,12 @@ class CLIPEncoder(torch.nn.Module): def forward(self, x, mask=None, intermediate_output=None): optimized_attention = optimized_attention_for_device(x.device, mask=mask is not None, small_input=True) + all_intermediate = None if intermediate_output is not None: - if intermediate_output < 0: + if intermediate_output == "all": + all_intermediate = [] + intermediate_output = None + elif intermediate_output < 0: intermediate_output = len(self.layers) + intermediate_output intermediate = None @@ -70,6 +74,12 @@ class CLIPEncoder(torch.nn.Module): x = l(x, mask, optimized_attention) if i == intermediate_output: intermediate = x.clone() + if all_intermediate is not None: + all_intermediate.append(x.unsqueeze(1).clone()) + + if all_intermediate is not None: + intermediate = torch.cat(all_intermediate, dim=1) + return x, intermediate class CLIPEmbeddings(torch.nn.Module): diff --git a/comfy/clip_vision.py b/comfy/clip_vision.py index 00aab9164..2fa410cb7 100644 --- a/comfy/clip_vision.py +++ b/comfy/clip_vision.py @@ -50,7 +50,13 @@ class ClipVisionModel(): self.image_size = config.get("image_size", 224) self.image_mean = config.get("image_mean", [0.48145466, 0.4578275, 0.40821073]) self.image_std = config.get("image_std", [0.26862954, 0.26130258, 0.27577711]) - model_class = IMAGE_ENCODERS.get(config.get("model_type", "clip_vision_model")) + model_type = config.get("model_type", "clip_vision_model") + model_class = IMAGE_ENCODERS.get(model_type) + if model_type == "siglip_vision_model": + self.return_all_hidden_states = True + else: + self.return_all_hidden_states = False + self.load_device = comfy.model_management.text_encoder_device() offload_device = comfy.model_management.text_encoder_offload_device() self.dtype = comfy.model_management.text_encoder_dtype(self.load_device) @@ -68,12 +74,18 @@ class ClipVisionModel(): def encode_image(self, image, crop=True): comfy.model_management.load_model_gpu(self.patcher) pixel_values = clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std, crop=crop).float() - out 
= self.model(pixel_values=pixel_values, intermediate_output=-2) + out = self.model(pixel_values=pixel_values, intermediate_output='all' if self.return_all_hidden_states else -2) outputs = Output() outputs["last_hidden_state"] = out[0].to(comfy.model_management.intermediate_device()) outputs["image_embeds"] = out[2].to(comfy.model_management.intermediate_device()) - outputs["penultimate_hidden_states"] = out[1].to(comfy.model_management.intermediate_device()) + if self.return_all_hidden_states: + all_hs = out[1].to(comfy.model_management.intermediate_device()) + outputs["penultimate_hidden_states"] = all_hs[:, -2] + outputs["all_hidden_states"] = all_hs + else: + outputs["penultimate_hidden_states"] = out[1].to(comfy.model_management.intermediate_device()) + outputs["mm_projected"] = out[3] return outputs diff --git a/comfy/ldm/flux/model.py b/comfy/ldm/flux/model.py index 1e62f4626..d4be6bb61 100644 --- a/comfy/ldm/flux/model.py +++ b/comfy/ldm/flux/model.py @@ -106,6 +106,7 @@ class Flux(nn.Module): if y is None: y = torch.zeros((img.shape[0], self.params.vec_in_dim), device=img.device, dtype=img.dtype) + patches = transformer_options.get("patches", {}) patches_replace = transformer_options.get("patches_replace", {}) if img.ndim != 3 or txt.ndim != 3: raise ValueError("Input img and txt tensors must have 3 dimensions.") @@ -117,9 +118,17 @@ class Flux(nn.Module): if guidance is not None: vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype)) - vec = vec + self.vector_in(y[:,:self.params.vec_in_dim]) + vec = vec + self.vector_in(y[:, :self.params.vec_in_dim]) txt = self.txt_in(txt) + if "post_input" in patches: + for p in patches["post_input"]: + out = p({"img": img, "txt": txt, "img_ids": img_ids, "txt_ids": txt_ids}) + img = out["img"] + txt = out["txt"] + img_ids = out["img_ids"] + txt_ids = out["txt_ids"] + if img_ids is not None: ids = torch.cat((txt_ids, img_ids), dim=1) pe = self.pe_embedder(ids) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index a944cb421..1fd03d9d1 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -433,6 +433,9 @@ class ModelPatcher: def set_model_double_block_patch(self, patch): self.set_model_patch(patch, "double_block") + def set_model_post_input_patch(self, patch): + self.set_model_patch(patch, "post_input") + def add_object_patch(self, name, obj): self.object_patches[name] = obj diff --git a/comfy_extras/nodes_model_patch.py b/comfy_extras/nodes_model_patch.py index 65e766b52..783c59b6b 100644 --- a/comfy_extras/nodes_model_patch.py +++ b/comfy_extras/nodes_model_patch.py @@ -1,4 +1,5 @@ import torch +from torch import nn import folder_paths import comfy.utils import comfy.ops @@ -58,6 +59,136 @@ class QwenImageBlockWiseControlNet(torch.nn.Module): return self.controlnet_blocks[block_id](img, controlnet_conditioning) +class SigLIPMultiFeatProjModel(torch.nn.Module): + """ + SigLIP Multi-Feature Projection Model for processing style features from different layers + and projecting them into a unified hidden space. 
+
+    Args:
+        siglip_token_nums (int): Number of SigLIP tokens, default 257
+        style_token_nums (int): Number of style tokens, default 256
+        siglip_token_dims (int): Dimension of SigLIP tokens, default 1536
+        hidden_size (int): Hidden layer size, default 3072
+        context_layer_norm (bool): Whether to use context layer normalization, default False
+    """
+
+    def __init__(
+        self,
+        siglip_token_nums: int = 729,
+        style_token_nums: int = 64,
+        siglip_token_dims: int = 1152,
+        hidden_size: int = 3072,
+        context_layer_norm: bool = True,
+        device=None, dtype=None, operations=None
+    ):
+        super().__init__()
+
+        # High-level feature processing (layer -2)
+        self.high_embedding_linear = nn.Sequential(
+            operations.Linear(siglip_token_nums, style_token_nums),
+            nn.SiLU()
+        )
+        self.high_layer_norm = (
+            operations.LayerNorm(siglip_token_dims) if context_layer_norm else nn.Identity()
+        )
+        self.high_projection = operations.Linear(siglip_token_dims, hidden_size, bias=True)
+
+        # Mid-level feature processing (layer -11)
+        self.mid_embedding_linear = nn.Sequential(
+            operations.Linear(siglip_token_nums, style_token_nums),
+            nn.SiLU()
+        )
+        self.mid_layer_norm = (
+            operations.LayerNorm(siglip_token_dims) if context_layer_norm else nn.Identity()
+        )
+        self.mid_projection = operations.Linear(siglip_token_dims, hidden_size, bias=True)
+
+        # Low-level feature processing (layer -20)
+        self.low_embedding_linear = nn.Sequential(
+            operations.Linear(siglip_token_nums, style_token_nums),
+            nn.SiLU()
+        )
+        self.low_layer_norm = (
+            operations.LayerNorm(siglip_token_dims) if context_layer_norm else nn.Identity()
+        )
+        self.low_projection = operations.Linear(siglip_token_dims, hidden_size, bias=True)
+
+    def forward(self, siglip_outputs):
+        """
+        Forward pass function
+
+        Args:
+            siglip_outputs: Output from SigLIP model, containing hidden_states
+
+        Returns:
+            torch.Tensor: Concatenated multi-layer features with shape [bs, 3*style_token_nums, hidden_size]
+        """
+        dtype = next(self.high_embedding_linear.parameters()).dtype
+
+        # Process high-level features (layer -2)
+        high_embedding = self._process_layer_features(
+            siglip_outputs[2],
+            self.high_embedding_linear,
+            self.high_layer_norm,
+            self.high_projection,
+            dtype
+        )
+
+        # Process mid-level features (layer -11)
+        mid_embedding = self._process_layer_features(
+            siglip_outputs[1],
+            self.mid_embedding_linear,
+            self.mid_layer_norm,
+            self.mid_projection,
+            dtype
+        )
+
+        # Process low-level features (layer -20)
+        low_embedding = self._process_layer_features(
+            siglip_outputs[0],
+            self.low_embedding_linear,
+            self.low_layer_norm,
+            self.low_projection,
+            dtype
+        )
+
+        # Concatenate features from all layers
+        return torch.cat((high_embedding, mid_embedding, low_embedding), dim=1)
+
+    def _process_layer_features(
+        self,
+        hidden_states: torch.Tensor,
+        embedding_linear: nn.Module,
+        layer_norm: nn.Module,
+        projection: nn.Module,
+        dtype: torch.dtype
+    ) -> torch.Tensor:
+        """
+        Helper function to process features from a single layer
+
+        Args:
+            hidden_states: Input hidden states [bs, seq_len, dim]
+            embedding_linear: Embedding linear layer
+            layer_norm: Layer normalization
+            projection: Projection layer
+            dtype: Target data type
+
+        Returns:
+            torch.Tensor: Processed features [bs, style_token_nums, hidden_size]
+        """
+        # Transform dimensions: [bs, seq_len, dim] -> [bs, dim, seq_len] -> [bs, dim, style_token_nums] -> [bs, style_token_nums, dim]
+        embedding = embedding_linear(
+            hidden_states.to(dtype).transpose(1, 2)
+        ).transpose(1, 2)
+
+        # Apply layer 
normalization + embedding = layer_norm(embedding) + + # Project to target hidden space + embedding = projection(embedding) + + return embedding + class ModelPatchLoader: @classmethod def INPUT_TYPES(s): @@ -73,9 +204,14 @@ class ModelPatchLoader: model_patch_path = folder_paths.get_full_path_or_raise("model_patches", name) sd = comfy.utils.load_torch_file(model_patch_path, safe_load=True) dtype = comfy.utils.weight_dtype(sd) - # TODO: this node will work with more types of model patches - additional_in_dim = sd["img_in.weight"].shape[1] - 64 - model = QwenImageBlockWiseControlNet(additional_in_dim=additional_in_dim, device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast) + + if 'controlnet_blocks.0.y_rms.weight' in sd: + additional_in_dim = sd["img_in.weight"].shape[1] - 64 + model = QwenImageBlockWiseControlNet(additional_in_dim=additional_in_dim, device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast) + elif 'feature_embedder.mid_layer_norm.bias' in sd: + sd = comfy.utils.state_dict_prefix_replace(sd, {"feature_embedder.": ""}, filter_keys=True) + model = SigLIPMultiFeatProjModel(device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast) + model.load_state_dict(sd) model = comfy.model_patcher.ModelPatcher(model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device()) return (model,) @@ -157,7 +293,51 @@ class QwenImageDiffsynthControlnet: return (model_patched,) +class UsoStyleProjectorPatch: + def __init__(self, model_patch, encoded_image): + self.model_patch = model_patch + self.encoded_image = encoded_image + + def __call__(self, kwargs): + txt_ids = kwargs.get("txt_ids") + txt = kwargs.get("txt") + siglip_embedding = self.model_patch.model(self.encoded_image.to(txt.dtype)).to(txt.dtype) + txt = torch.cat([siglip_embedding, txt], dim=1) + kwargs['txt'] = txt + kwargs['txt_ids'] = torch.cat([torch.zeros(siglip_embedding.shape[0], siglip_embedding.shape[1], 3, dtype=txt_ids.dtype, device=txt_ids.device), txt_ids], dim=1) + return kwargs + + def to(self, device_or_dtype): + if isinstance(device_or_dtype, torch.device): + self.encoded_image = self.encoded_image.to(device_or_dtype) + return self + + def models(self): + return [self.model_patch] + + +class USOStyleReference: + @classmethod + def INPUT_TYPES(s): + return {"required": {"model": ("MODEL",), + "model_patch": ("MODEL_PATCH",), + "clip_vision_output": ("CLIP_VISION_OUTPUT", ), + }} + RETURN_TYPES = ("MODEL",) + FUNCTION = "apply_patch" + EXPERIMENTAL = True + + CATEGORY = "advanced/model_patches/flux" + + def apply_patch(self, model, model_patch, clip_vision_output): + encoded_image = torch.stack((clip_vision_output.all_hidden_states[:, -20], clip_vision_output.all_hidden_states[:, -11], clip_vision_output.penultimate_hidden_states)) + model_patched = model.clone() + model_patched.set_model_post_input_patch(UsoStyleProjectorPatch(model_patch, encoded_image)) + return (model_patched,) + + NODE_CLASS_MAPPINGS = { "ModelPatchLoader": ModelPatchLoader, "QwenImageDiffsynthControlnet": QwenImageDiffsynthControlnet, + "USOStyleReference": USOStyleReference, } From e3018c2a5aeb99f0c5b595621949a451686ce55a Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 2 Sep 2025 13:12:07 -0700 Subject: [PATCH 149/325] uso -> uxo/uno as requested. 
(#9688) --- comfy/ldm/flux/model.py | 2 +- comfy_extras/nodes_flux.py | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/comfy/ldm/flux/model.py b/comfy/ldm/flux/model.py index d4be6bb61..8ea7d4f57 100644 --- a/comfy/ldm/flux/model.py +++ b/comfy/ldm/flux/model.py @@ -248,7 +248,7 @@ class Flux(nn.Module): index += 1 h_offset = 0 w_offset = 0 - elif ref_latents_method == "uso": + elif ref_latents_method == "uxo": index = 0 h_offset = h_len * patch_size + h w_offset = w_len * patch_size + w diff --git a/comfy_extras/nodes_flux.py b/comfy_extras/nodes_flux.py index 1bf7ddd92..25e029ffd 100644 --- a/comfy_extras/nodes_flux.py +++ b/comfy_extras/nodes_flux.py @@ -105,7 +105,7 @@ class FluxKontextMultiReferenceLatentMethod: def INPUT_TYPES(s): return {"required": { "conditioning": ("CONDITIONING", ), - "reference_latents_method": (("offset", "index", "uso"), ), + "reference_latents_method": (("offset", "index", "uxo/uno"), ), }} RETURN_TYPES = ("CONDITIONING",) @@ -115,6 +115,8 @@ class FluxKontextMultiReferenceLatentMethod: CATEGORY = "advanced/conditioning/flux" def append(self, conditioning, reference_latents_method): + if "uxo" in reference_latents_method or "uso" in reference_latents_method: + reference_latents_method = "uxo" c = node_helpers.conditioning_set_values(conditioning, {"reference_latents_method": reference_latents_method}) return (c, ) From 464ba1d6140eda6a0173703ac00c69f7fddab6ba Mon Sep 17 00:00:00 2001 From: Deep Roy Date: Tue, 2 Sep 2025 19:41:10 -0400 Subject: [PATCH 150/325] Accept prompt_id in interrupt handler (#9607) * Accept prompt_id in interrupt handler * remove a log --- server.py | 29 ++++++++++++++++++++++++++++- 1 file changed, 28 insertions(+), 1 deletion(-) diff --git a/server.py b/server.py index 8f9c88ebf..3d323eaf8 100644 --- a/server.py +++ b/server.py @@ -729,7 +729,34 @@ class PromptServer(): @routes.post("/interrupt") async def post_interrupt(request): - nodes.interrupt_processing() + try: + json_data = await request.json() + except json.JSONDecodeError: + json_data = {} + + # Check if a specific prompt_id was provided for targeted interruption + prompt_id = json_data.get('prompt_id') + if prompt_id: + currently_running, _ = self.prompt_queue.get_current_queue() + + # Check if the prompt_id matches any currently running prompt + should_interrupt = False + for item in currently_running: + # item structure: (number, prompt_id, prompt, extra_data, outputs_to_execute) + if item[1] == prompt_id: + logging.info(f"Interrupting prompt {prompt_id}") + should_interrupt = True + break + + if should_interrupt: + nodes.interrupt_processing() + else: + logging.info(f"Prompt {prompt_id} is not currently running, skipping interrupt") + else: + # No prompt_id provided, do a global interrupt + logging.info("Global interrupt (no prompt_id specified)") + nodes.interrupt_processing() + return web.Response(status=200) @routes.post("/free") From 1bcb469089a71fb1946b9f14e994df1b42b83def Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 2 Sep 2025 17:05:57 -0700 Subject: [PATCH 151/325] ImageScaleToMaxDimension node. 
(#9689) --- comfy_extras/nodes_images.py | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/comfy_extras/nodes_images.py b/comfy_extras/nodes_images.py index fba80e2ae..392aea32c 100644 --- a/comfy_extras/nodes_images.py +++ b/comfy_extras/nodes_images.py @@ -625,6 +625,37 @@ class ImageFlip: return (image,) +class ImageScaleToMaxDimension: + upscale_methods = ["area", "lanczos", "bilinear", "nearest-exact", "bilinear", "bicubic"] + + @classmethod + def INPUT_TYPES(s): + return {"required": {"image": ("IMAGE",), + "upscale_method": (s.upscale_methods,), + "largest_size": ("INT", {"default": 512, "min": 0, "max": MAX_RESOLUTION, "step": 1})}} + RETURN_TYPES = ("IMAGE",) + FUNCTION = "upscale" + + CATEGORY = "image/upscaling" + + def upscale(self, image, upscale_method, largest_size): + height = image.shape[1] + width = image.shape[2] + + if height > width: + width = round((width / height) * largest_size) + height = largest_size + elif width > height: + height = round((height / width) * largest_size) + width = largest_size + else: + height = largest_size + width = largest_size + + samples = image.movedim(-1, 1) + s = comfy.utils.common_upscale(samples, width, height, upscale_method, "disabled") + s = s.movedim(1, -1) + return (s,) NODE_CLASS_MAPPINGS = { "ImageCrop": ImageCrop, @@ -639,4 +670,5 @@ NODE_CLASS_MAPPINGS = { "GetImageSize": GetImageSize, "ImageRotate": ImageRotate, "ImageFlip": ImageFlip, + "ImageScaleToMaxDimension": ImageScaleToMaxDimension, } From 4f5812b93712e0f52ae8fe80a89e8b5e7d0fa309 Mon Sep 17 00:00:00 2001 From: ComfyUI Wiki Date: Wed, 3 Sep 2025 08:06:41 +0800 Subject: [PATCH 152/325] Update template to 0.1.73 (#9686) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 7f64aacca..4ebe6cc2a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.25.11 -comfyui-workflow-templates==0.1.70 +comfyui-workflow-templates==0.1.73 comfyui-embedded-docs==0.2.6 torch torchsde From 26d5b86da8ceb4589ee70f12ff2209b93a2d99e0 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Wed, 3 Sep 2025 23:17:07 +0300 Subject: [PATCH 153/325] feat(api-nodes): add ByteDance Image nodes (#9477) --- comfy_api_nodes/nodes_bytedance.py | 336 +++++++++++++++++++++++++++++ nodes.py | 1 + 2 files changed, 337 insertions(+) create mode 100644 comfy_api_nodes/nodes_bytedance.py diff --git a/comfy_api_nodes/nodes_bytedance.py b/comfy_api_nodes/nodes_bytedance.py new file mode 100644 index 000000000..fb6aba7fa --- /dev/null +++ b/comfy_api_nodes/nodes_bytedance.py @@ -0,0 +1,336 @@ +import logging +from enum import Enum +from typing import Optional +from typing_extensions import override + +import torch +from pydantic import BaseModel, Field + +from comfy_api.latest import ComfyExtension, io as comfy_io +from comfy_api_nodes.util.validation_utils import ( + validate_image_aspect_ratio_range, + get_number_of_images, +) +from comfy_api_nodes.apis.client import ( + ApiEndpoint, + HttpMethod, + SynchronousOperation, +) +from comfy_api_nodes.apinode_utils import download_url_to_image_tensor, upload_images_to_comfyapi, validate_string + + +BYTEPLUS_ENDPOINT = "/proxy/byteplus/api/v3/images/generations" + + +class Text2ImageModelName(str, Enum): + seedream3 = "seedream-3-0-t2i-250415" + + +class Image2ImageModelName(str, Enum): + seededit3 = "seededit-3-0-i2i-250628" + + +class Text2ImageTaskCreationRequest(BaseModel): + 
model: Text2ImageModelName = Text2ImageModelName.seedream3 + prompt: str = Field(...) + response_format: Optional[str] = Field("url") + size: Optional[str] = Field(None) + seed: Optional[int] = Field(0, ge=0, le=2147483647) + guidance_scale: Optional[float] = Field(..., ge=1.0, le=10.0) + watermark: Optional[bool] = Field(True) + + +class Image2ImageTaskCreationRequest(BaseModel): + model: Image2ImageModelName = Image2ImageModelName.seededit3 + prompt: str = Field(...) + response_format: Optional[str] = Field("url") + image: str = Field(..., description="Base64 encoded string or image URL") + size: Optional[str] = Field("adaptive") + seed: Optional[int] = Field(..., ge=0, le=2147483647) + guidance_scale: Optional[float] = Field(..., ge=1.0, le=10.0) + watermark: Optional[bool] = Field(True) + + +class ImageTaskCreationResponse(BaseModel): + model: str = Field(...) + created: int = Field(..., description="Unix timestamp (in seconds) indicating time when the request was created.") + data: list = Field([], description="Contains information about the generated image(s).") + error: dict = Field({}, description="Contains `code` and `message` fields in case of error.") + + +RECOMMENDED_PRESETS = [ + ("1024x1024 (1:1)", 1024, 1024), + ("864x1152 (3:4)", 864, 1152), + ("1152x864 (4:3)", 1152, 864), + ("1280x720 (16:9)", 1280, 720), + ("720x1280 (9:16)", 720, 1280), + ("832x1248 (2:3)", 832, 1248), + ("1248x832 (3:2)", 1248, 832), + ("1512x648 (21:9)", 1512, 648), + ("2048x2048 (1:1)", 2048, 2048), + ("Custom", None, None), +] + + +def get_image_url_from_response(response: ImageTaskCreationResponse) -> str: + if response.error: + error_msg = f"ByteDance request failed. Code: {response.error['code']}, message: {response.error['message']}" + logging.info(error_msg) + raise RuntimeError(error_msg) + logging.info("ByteDance task succeeded, image URL: %s", response.data[0]["url"]) + return response.data[0]["url"] + + +class ByteDanceImageNode(comfy_io.ComfyNode): + + @classmethod + def define_schema(cls): + return comfy_io.Schema( + node_id="ByteDanceImageNode", + display_name="ByteDance Image", + category="api node/image/ByteDance", + description="Generate images using ByteDance models via api based on prompt", + inputs=[ + comfy_io.Combo.Input( + "model", + options=[model.value for model in Text2ImageModelName], + default=Text2ImageModelName.seedream3.value, + tooltip="Model name", + ), + comfy_io.String.Input( + "prompt", + multiline=True, + tooltip="The text prompt used to generate the image", + ), + comfy_io.Combo.Input( + "size_preset", + options=[label for label, _, _ in RECOMMENDED_PRESETS], + tooltip="Pick a recommended size. Select Custom to use the width and height below", + ), + comfy_io.Int.Input( + "width", + default=1024, + min=512, + max=2048, + step=64, + tooltip="Custom width for image. Value is working only if `size_preset` is set to `Custom`", + ), + comfy_io.Int.Input( + "height", + default=1024, + min=512, + max=2048, + step=64, + tooltip="Custom height for image. 
Value is working only if `size_preset` is set to `Custom`", + ), + comfy_io.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + step=1, + display_mode=comfy_io.NumberDisplay.number, + control_after_generate=True, + tooltip="Seed to use for generation", + optional=True, + ), + comfy_io.Float.Input( + "guidance_scale", + default=2.5, + min=1.0, + max=10.0, + step=0.01, + display_mode=comfy_io.NumberDisplay.number, + tooltip="Higher value makes the image follow the prompt more closely", + optional=True, + ), + comfy_io.Boolean.Input( + "watermark", + default=True, + tooltip="Whether to add an \"AI generated\" watermark to the image", + optional=True, + ), + ], + outputs=[ + comfy_io.Image.Output(), + ], + hidden=[ + comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + comfy_io.Hidden.unique_id, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, + model: str, + prompt: str, + size_preset: str, + width: int, + height: int, + seed: int, + guidance_scale: float, + watermark: bool, + ) -> comfy_io.NodeOutput: + validate_string(prompt, strip_whitespace=True, min_length=1) + w = h = None + for label, tw, th in RECOMMENDED_PRESETS: + if label == size_preset: + w, h = tw, th + break + + if w is None or h is None: + w, h = width, height + if not (512 <= w <= 2048) or not (512 <= h <= 2048): + raise ValueError( + f"Custom size out of range: {w}x{h}. " + "Both width and height must be between 512 and 2048 pixels." + ) + + payload = Text2ImageTaskCreationRequest( + model=model, + prompt=prompt, + size=f"{w}x{h}", + seed=seed, + guidance_scale=guidance_scale, + watermark=watermark, + ) + auth_kwargs = { + "auth_token": cls.hidden.auth_token_comfy_org, + "comfy_api_key": cls.hidden.api_key_comfy_org, + } + response = await SynchronousOperation( + endpoint=ApiEndpoint( + path=BYTEPLUS_ENDPOINT, + method=HttpMethod.POST, + request_model=Text2ImageTaskCreationRequest, + response_model=ImageTaskCreationResponse, + ), + request=payload, + auth_kwargs=auth_kwargs, + ).execute() + return comfy_io.NodeOutput(await download_url_to_image_tensor(get_image_url_from_response(response))) + + +class ByteDanceImageEditNode(comfy_io.ComfyNode): + + @classmethod + def define_schema(cls): + return comfy_io.Schema( + node_id="ByteDanceImageEditNode", + display_name="ByteDance Image Edit", + category="api node/video/ByteDance", + description="Edit images using ByteDance models via api based on prompt", + inputs=[ + comfy_io.Combo.Input( + "model", + options=[model.value for model in Image2ImageModelName], + default=Image2ImageModelName.seededit3.value, + tooltip="Model name", + ), + comfy_io.Image.Input( + "image", + tooltip="The base image to edit", + ), + comfy_io.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Instruction to edit image", + ), + comfy_io.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + step=1, + display_mode=comfy_io.NumberDisplay.number, + control_after_generate=True, + tooltip="Seed to use for generation", + optional=True, + ), + comfy_io.Float.Input( + "guidance_scale", + default=5.5, + min=1.0, + max=10.0, + step=0.01, + display_mode=comfy_io.NumberDisplay.number, + tooltip="Higher value makes the image follow the prompt more closely", + optional=True, + ), + comfy_io.Boolean.Input( + "watermark", + default=True, + tooltip="Whether to add an \"AI generated\" watermark to the image", + optional=True, + ), + ], + outputs=[ + comfy_io.Image.Output(), + ], + hidden=[ + comfy_io.Hidden.auth_token_comfy_org, + 
comfy_io.Hidden.api_key_comfy_org, + comfy_io.Hidden.unique_id, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, + model: str, + image: torch.Tensor, + prompt: str, + seed: int, + guidance_scale: float, + watermark: bool, + ) -> comfy_io.NodeOutput: + validate_string(prompt, strip_whitespace=True, min_length=1) + if get_number_of_images(image) != 1: + raise ValueError("Exactly one input image is required.") + validate_image_aspect_ratio_range(image, (1, 3), (3, 1)) + auth_kwargs = { + "auth_token": cls.hidden.auth_token_comfy_org, + "comfy_api_key": cls.hidden.api_key_comfy_org, + } + source_url = (await upload_images_to_comfyapi( + image, + max_images=1, + mime_type="image/png", + auth_kwargs=auth_kwargs, + ))[0] + payload = Image2ImageTaskCreationRequest( + model=model, + prompt=prompt, + image=source_url, + seed=seed, + guidance_scale=guidance_scale, + watermark=watermark, + ) + response = await SynchronousOperation( + endpoint=ApiEndpoint( + path=BYTEPLUS_ENDPOINT, + method=HttpMethod.POST, + request_model=Image2ImageTaskCreationRequest, + response_model=ImageTaskCreationResponse, + ), + request=payload, + auth_kwargs=auth_kwargs, + ).execute() + return comfy_io.NodeOutput(await download_url_to_image_tensor(get_image_url_from_response(response))) + + +class ByteDanceExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[comfy_io.ComfyNode]]: + return [ + ByteDanceImageNode, + ByteDanceImageEditNode, + ] + +async def comfy_entrypoint() -> ByteDanceExtension: + return ByteDanceExtension() diff --git a/nodes.py b/nodes.py index 0aff6b14a..6c2f9dd14 100644 --- a/nodes.py +++ b/nodes.py @@ -2344,6 +2344,7 @@ async def init_builtin_api_nodes(): "nodes_veo2.py", "nodes_kling.py", "nodes_bfl.py", + "nodes_bytedance.py", "nodes_luma.py", "nodes_recraft.py", "nodes_pixverse.py", From 50333f1715c03aa4100711eb6d075516a4021d24 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Wed, 3 Sep 2025 23:17:37 +0300 Subject: [PATCH 154/325] api nodes(Ideogram): add Ideogram Character (#9616) * api nodes(Ideogram): add Ideogram Character * rename renderingSpeed default value from 'balanced' to 'default' --- comfy_api_nodes/apis/__init__.py | 22 ++++++++- comfy_api_nodes/nodes_ideogram.py | 77 ++++++++++++++++++++++++++++--- 2 files changed, 91 insertions(+), 8 deletions(-) diff --git a/comfy_api_nodes/apis/__init__.py b/comfy_api_nodes/apis/__init__.py index 7a09df55b..78a23db30 100644 --- a/comfy_api_nodes/apis/__init__.py +++ b/comfy_api_nodes/apis/__init__.py @@ -951,7 +951,11 @@ class MagicPrompt2(str, Enum): class StyleType1(str, Enum): + AUTO = 'AUTO' GENERAL = 'GENERAL' + REALISTIC = 'REALISTIC' + DESIGN = 'DESIGN' + FICTION = 'FICTION' class ImagenImageGenerationInstance(BaseModel): @@ -2676,7 +2680,7 @@ class ReleaseNote(BaseModel): class RenderingSpeed(str, Enum): - BALANCED = 'BALANCED' + DEFAULT = 'DEFAULT' TURBO = 'TURBO' QUALITY = 'QUALITY' @@ -4918,6 +4922,14 @@ class IdeogramV3EditRequest(BaseModel): None, description='A set of images to use as style references (maximum total size 10MB across all style references). The images should be in JPEG, PNG or WebP format.', ) + character_reference_images: Optional[List[str]] = Field( + None, + description='Generations with character reference are subject to the character reference pricing. A set of images to use as character references (maximum total size 10MB across all character references), currently only supports 1 character reference image. 
The images should be in JPEG, PNG or WebP format.' + ) + character_reference_images_mask: Optional[List[str]] = Field( + None, + description='Optional masks for character reference images. When provided, must match the number of character_reference_images. Each mask should be a grayscale image of the same dimensions as the corresponding character reference image. The images should be in JPEG, PNG or WebP format.' + ) class IdeogramV3Request(BaseModel): @@ -4951,6 +4963,14 @@ class IdeogramV3Request(BaseModel): style_type: Optional[StyleType1] = Field( None, description='The type of style to apply' ) + character_reference_images: Optional[List[str]] = Field( + None, + description='Generations with character reference are subject to the character reference pricing. A set of images to use as character references (maximum total size 10MB across all character references), currently only supports 1 character reference image. The images should be in JPEG, PNG or WebP format.' + ) + character_reference_images_mask: Optional[List[str]] = Field( + None, + description='Optional masks for character reference images. When provided, must match the number of character_reference_images. Each mask should be a grayscale image of the same dimensions as the corresponding character reference image. The images should be in JPEG, PNG or WebP format.' + ) class ImagenGenerateImageResponse(BaseModel): diff --git a/comfy_api_nodes/nodes_ideogram.py b/comfy_api_nodes/nodes_ideogram.py index d28895f3e..2d1c32e4f 100644 --- a/comfy_api_nodes/nodes_ideogram.py +++ b/comfy_api_nodes/nodes_ideogram.py @@ -255,6 +255,7 @@ class IdeogramV1(comfy_io.ComfyNode): display_name="Ideogram V1", category="api node/image/Ideogram", description="Generates images using the Ideogram V1 model.", + is_api_node=True, inputs=[ comfy_io.String.Input( "prompt", @@ -383,6 +384,7 @@ class IdeogramV2(comfy_io.ComfyNode): display_name="Ideogram V2", category="api node/image/Ideogram", description="Generates images using the Ideogram V2 model.", + is_api_node=True, inputs=[ comfy_io.String.Input( "prompt", @@ -552,6 +554,7 @@ class IdeogramV3(comfy_io.ComfyNode): category="api node/image/Ideogram", description="Generates images using the Ideogram V3 model. 
" "Supports both regular image generation from text prompts and image editing with mask.", + is_api_node=True, inputs=[ comfy_io.String.Input( "prompt", @@ -612,11 +615,21 @@ class IdeogramV3(comfy_io.ComfyNode): ), comfy_io.Combo.Input( "rendering_speed", - options=["BALANCED", "TURBO", "QUALITY"], - default="BALANCED", + options=["DEFAULT", "TURBO", "QUALITY"], + default="DEFAULT", tooltip="Controls the trade-off between generation speed and quality", optional=True, ), + comfy_io.Image.Input( + "character_image", + tooltip="Image to use as character reference.", + optional=True, + ), + comfy_io.Mask.Input( + "character_mask", + tooltip="Optional mask for character reference image.", + optional=True, + ), ], outputs=[ comfy_io.Image.Output(), @@ -639,12 +652,46 @@ class IdeogramV3(comfy_io.ComfyNode): magic_prompt_option="AUTO", seed=0, num_images=1, - rendering_speed="BALANCED", + rendering_speed="DEFAULT", + character_image=None, + character_mask=None, ): auth = { "auth_token": cls.hidden.auth_token_comfy_org, "comfy_api_key": cls.hidden.api_key_comfy_org, } + if rendering_speed == "BALANCED": # for backward compatibility + rendering_speed = "DEFAULT" + + character_img_binary = None + character_mask_binary = None + + if character_image is not None: + input_tensor = character_image.squeeze().cpu() + if character_mask is not None: + character_mask = resize_mask_to_image(character_mask, character_image, allow_gradient=False) + character_mask = 1.0 - character_mask + if character_mask.shape[1:] != character_image.shape[1:-1]: + raise Exception("Character mask and image must be the same size") + + mask_np = (character_mask.squeeze().cpu().numpy() * 255).astype(np.uint8) + mask_img = Image.fromarray(mask_np) + mask_byte_arr = BytesIO() + mask_img.save(mask_byte_arr, format="PNG") + mask_byte_arr.seek(0) + character_mask_binary = mask_byte_arr + character_mask_binary.name = "mask.png" + + img_np = (input_tensor.numpy() * 255).astype(np.uint8) + img = Image.fromarray(img_np) + img_byte_arr = BytesIO() + img.save(img_byte_arr, format="PNG") + img_byte_arr.seek(0) + character_img_binary = img_byte_arr + character_img_binary.name = "image.png" + elif character_mask is not None: + raise Exception("Character mask requires character image to be present") + # Check if both image and mask are provided for editing mode if image is not None and mask is not None: # Edit mode @@ -693,6 +740,15 @@ class IdeogramV3(comfy_io.ComfyNode): if num_images > 1: edit_request.num_images = num_images + files = { + "image": img_binary, + "mask": mask_binary, + } + if character_img_binary: + files["character_reference_images"] = character_img_binary + if character_mask_binary: + files["character_mask_binary"] = character_mask_binary + # Execute the operation for edit mode operation = SynchronousOperation( endpoint=ApiEndpoint( @@ -702,10 +758,7 @@ class IdeogramV3(comfy_io.ComfyNode): response_model=IdeogramGenerateResponse, ), request=edit_request, - files={ - "image": img_binary, - "mask": mask_binary, - }, + files=files, content_type="multipart/form-data", auth_kwargs=auth, ) @@ -739,6 +792,14 @@ class IdeogramV3(comfy_io.ComfyNode): if num_images > 1: gen_request.num_images = num_images + files = {} + if character_img_binary: + files["character_reference_images"] = character_img_binary + if character_mask_binary: + files["character_mask_binary"] = character_mask_binary + if files: + gen_request.style_type = "AUTO" + # Execute the operation for generation mode operation = SynchronousOperation( endpoint=ApiEndpoint( 
@@ -748,6 +809,8 @@ class IdeogramV3(comfy_io.ComfyNode): response_model=IdeogramGenerateResponse, ), request=gen_request, + files=files if files else None, + content_type="multipart/form-data", auth_kwargs=auth, ) From 22da0a83e9a251ca16b9753bf808bfa9f4b023d8 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Wed, 3 Sep 2025 23:18:27 +0300 Subject: [PATCH 155/325] [V3] convert Runway API nodes to the V3 schema (#9487) * convert RunAway API nodes to the V3 schema * fixed small typo * fix: add tooltip for "seed" input --- comfy_api_nodes/nodes_runway.py | 744 +++++++++++++++----------------- 1 file changed, 357 insertions(+), 387 deletions(-) diff --git a/comfy_api_nodes/nodes_runway.py b/comfy_api_nodes/nodes_runway.py index 98024a9fa..27b2bf748 100644 --- a/comfy_api_nodes/nodes_runway.py +++ b/comfy_api_nodes/nodes_runway.py @@ -12,6 +12,7 @@ User Guides: """ from typing import Union, Optional, Any +from typing_extensions import override from enum import Enum import torch @@ -46,9 +47,9 @@ from comfy_api_nodes.apinode_utils import ( validate_string, download_url_to_image_tensor, ) -from comfy_api_nodes.mapper_utils import model_field_to_node_input from comfy_api.input_impl import VideoFromFile -from comfy.comfy_types.node_typing import IO, ComfyNodeABC +from comfy_api.latest import ComfyExtension, io as comfy_io +from comfy_api_nodes.util.validation_utils import validate_image_dimensions, validate_image_aspect_ratio PATH_IMAGE_TO_VIDEO = "/proxy/runway/image_to_video" PATH_TEXT_TO_IMAGE = "/proxy/runway/text_to_image" @@ -85,20 +86,11 @@ class RunwayGen3aAspectRatio(str, Enum): def get_video_url_from_task_status(response: TaskStatusResponse) -> Union[str, None]: """Returns the video URL from the task status response if it exists.""" - if response.output and len(response.output) > 0: + if hasattr(response, "output") and len(response.output) > 0: return response.output[0] return None -# TODO: replace with updated image validation utils (upstream) -def validate_input_image(image: torch.Tensor) -> bool: - """ - Validate the input image is within the size limits for the Runway API. 
- See: https://docs.dev.runwayml.com/assets/inputs/#common-error-reasons - """ - return image.shape[2] < 8000 and image.shape[1] < 8000 - - async def poll_until_finished( auth_kwargs: dict[str, str], api_endpoint: ApiEndpoint[Any, TaskStatusResponse], @@ -134,458 +126,438 @@ def extract_progress_from_task_status( def get_image_url_from_task_status(response: TaskStatusResponse) -> Union[str, None]: """Returns the image URL from the task status response if it exists.""" - if response.output and len(response.output) > 0: + if hasattr(response, "output") and len(response.output) > 0: return response.output[0] return None -class RunwayVideoGenNode(ComfyNodeABC): - """Runway Video Node Base.""" +async def get_response( + task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None, estimated_duration: Optional[int] = None +) -> TaskStatusResponse: + """Poll the task status until it is finished then get the response.""" + return await poll_until_finished( + auth_kwargs, + ApiEndpoint( + path=f"{PATH_GET_TASK_STATUS}/{task_id}", + method=HttpMethod.GET, + request_model=EmptyRequest, + response_model=TaskStatusResponse, + ), + estimated_duration=estimated_duration, + node_id=node_id, + ) - RETURN_TYPES = ("VIDEO",) - FUNCTION = "api_call" - CATEGORY = "api node/video/Runway" - API_NODE = True - def validate_task_created(self, response: RunwayImageToVideoResponse) -> bool: - """ - Validate the task creation response from the Runway API matches - expected format. - """ - if not bool(response.id): - raise RunwayApiError("Invalid initial response from Runway API.") - return True +async def generate_video( + request: RunwayImageToVideoRequest, + auth_kwargs: dict[str, str], + node_id: Optional[str] = None, + estimated_duration: Optional[int] = None, +) -> VideoFromFile: + initial_operation = SynchronousOperation( + endpoint=ApiEndpoint( + path=PATH_IMAGE_TO_VIDEO, + method=HttpMethod.POST, + request_model=RunwayImageToVideoRequest, + response_model=RunwayImageToVideoResponse, + ), + request=request, + auth_kwargs=auth_kwargs, + ) - def validate_response(self, response: RunwayImageToVideoResponse) -> bool: - """ - Validate the successful task status response from the Runway API - matches expected format. - """ - if not response.output or len(response.output) == 0: - raise RunwayApiError( - "Runway task succeeded but no video data found in response." 
- ) - return True + initial_response = await initial_operation.execute() - async def get_response( - self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None - ) -> RunwayImageToVideoResponse: - """Poll the task status until it is finished then get the response.""" - return await poll_until_finished( - auth_kwargs, - ApiEndpoint( - path=f"{PATH_GET_TASK_STATUS}/{task_id}", - method=HttpMethod.GET, - request_model=EmptyRequest, - response_model=TaskStatusResponse, - ), - estimated_duration=AVERAGE_DURATION_FLF_SECONDS, - node_id=node_id, + final_response = await get_response(initial_response.id, auth_kwargs, node_id, estimated_duration) + if not final_response.output: + raise RunwayApiError("Runway task succeeded but no video data found in response.") + + video_url = get_video_url_from_task_status(final_response) + return await download_url_to_video_output(video_url) + + +class RunwayImageToVideoNodeGen3a(comfy_io.ComfyNode): + + @classmethod + def define_schema(cls): + return comfy_io.Schema( + node_id="RunwayImageToVideoNodeGen3a", + display_name="Runway Image to Video (Gen3a Turbo)", + category="api node/video/Runway", + description="Generate a video from a single starting frame using Gen3a Turbo model. " + "Before diving in, review these best practices to ensure that " + "your input selections will set your generation up for success: " + "https://help.runwayml.com/hc/en-us/articles/33927968552339-Creating-with-Act-One-on-Gen-3-Alpha-and-Turbo.", + inputs=[ + comfy_io.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Text prompt for the generation", + ), + comfy_io.Image.Input( + "start_frame", + tooltip="Start frame to be used for the video", + ), + comfy_io.Combo.Input( + "duration", + options=[model.value for model in Duration], + ), + comfy_io.Combo.Input( + "ratio", + options=[model.value for model in RunwayGen3aAspectRatio], + ), + comfy_io.Int.Input( + "seed", + default=0, + min=0, + max=4294967295, + step=1, + control_after_generate=True, + display_mode=comfy_io.NumberDisplay.number, + tooltip="Random seed for generation", + ), + ], + outputs=[ + comfy_io.Video.Output(), + ], + hidden=[ + comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + comfy_io.Hidden.unique_id, + ], + is_api_node=True, ) - async def generate_video( - self, - request: RunwayImageToVideoRequest, - auth_kwargs: dict[str, str], - node_id: Optional[str] = None, - ) -> tuple[VideoFromFile]: - initial_operation = SynchronousOperation( - endpoint=ApiEndpoint( - path=PATH_IMAGE_TO_VIDEO, - method=HttpMethod.POST, - request_model=RunwayImageToVideoRequest, - response_model=RunwayImageToVideoResponse, - ), - request=request, + @classmethod + async def execute( + cls, + prompt: str, + start_frame: torch.Tensor, + duration: str, + ratio: str, + seed: int, + ) -> comfy_io.NodeOutput: + validate_string(prompt, min_length=1) + validate_image_dimensions(start_frame, max_width=7999, max_height=7999) + validate_image_aspect_ratio(start_frame, min_aspect_ratio=0.5, max_aspect_ratio=2.0) + + auth_kwargs = { + "auth_token": cls.hidden.auth_token_comfy_org, + "comfy_api_key": cls.hidden.api_key_comfy_org, + } + + download_urls = await upload_images_to_comfyapi( + start_frame, + max_images=1, + mime_type="image/png", auth_kwargs=auth_kwargs, ) - initial_response = await initial_operation.execute() - self.validate_task_created(initial_response) - task_id = initial_response.id - - final_response = await self.get_response(task_id, auth_kwargs, node_id) - 
self.validate_response(final_response) - - video_url = get_video_url_from_task_status(final_response) - return (await download_url_to_video_output(video_url),) + return comfy_io.NodeOutput( + await generate_video( + RunwayImageToVideoRequest( + promptText=prompt, + seed=seed, + model=Model("gen3a_turbo"), + duration=Duration(duration), + ratio=AspectRatio(ratio), + promptImage=RunwayPromptImageObject( + root=[ + RunwayPromptImageDetailedObject( + uri=str(download_urls[0]), position="first" + ) + ] + ), + ), + auth_kwargs=auth_kwargs, + node_id=cls.hidden.unique_id, + ) + ) -class RunwayImageToVideoNodeGen3a(RunwayVideoGenNode): - """Runway Image to Video Node using Gen3a Turbo model.""" - - DESCRIPTION = "Generate a video from a single starting frame using Gen3a Turbo model. Before diving in, review these best practices to ensure that your input selections will set your generation up for success: https://help.runwayml.com/hc/en-us/articles/33927968552339-Creating-with-Act-One-on-Gen-3-Alpha-and-Turbo." +class RunwayImageToVideoNodeGen4(comfy_io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return { - "required": { - "prompt": model_field_to_node_input( - IO.STRING, RunwayImageToVideoRequest, "promptText", multiline=True + def define_schema(cls): + return comfy_io.Schema( + node_id="RunwayImageToVideoNodeGen4", + display_name="Runway Image to Video (Gen4 Turbo)", + category="api node/video/Runway", + description="Generate a video from a single starting frame using Gen4 Turbo model. " + "Before diving in, review these best practices to ensure that " + "your input selections will set your generation up for success: " + "https://help.runwayml.com/hc/en-us/articles/37327109429011-Creating-with-Gen-4-Video.", + inputs=[ + comfy_io.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Text prompt for the generation", ), - "start_frame": ( - IO.IMAGE, - {"tooltip": "Start frame to be used for the video"}, + comfy_io.Image.Input( + "start_frame", + tooltip="Start frame to be used for the video", ), - "duration": model_field_to_node_input( - IO.COMBO, RunwayImageToVideoRequest, "duration", enum_type=Duration + comfy_io.Combo.Input( + "duration", + options=[model.value for model in Duration], ), - "ratio": model_field_to_node_input( - IO.COMBO, - RunwayImageToVideoRequest, + comfy_io.Combo.Input( "ratio", - enum_type=RunwayGen3aAspectRatio, + options=[model.value for model in RunwayGen4TurboAspectRatio], ), - "seed": model_field_to_node_input( - IO.INT, - RunwayImageToVideoRequest, + comfy_io.Int.Input( "seed", + default=0, + min=0, + max=4294967295, + step=1, control_after_generate=True, + display_mode=comfy_io.NumberDisplay.number, + tooltip="Random seed for generation", ), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } + ], + outputs=[ + comfy_io.Video.Output(), + ], + hidden=[ + comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + comfy_io.Hidden.unique_id, + ], + is_api_node=True, + ) - async def api_call( - self, + @classmethod + async def execute( + cls, prompt: str, start_frame: torch.Tensor, duration: str, ratio: str, seed: int, - unique_id: Optional[str] = None, - **kwargs, - ) -> tuple[VideoFromFile]: - # Validate inputs + ) -> comfy_io.NodeOutput: validate_string(prompt, min_length=1) - validate_input_image(start_frame) + validate_image_dimensions(start_frame, max_width=7999, max_height=7999) + validate_image_aspect_ratio(start_frame, min_aspect_ratio=0.5, 
max_aspect_ratio=2.0) + + auth_kwargs = { + "auth_token": cls.hidden.auth_token_comfy_org, + "comfy_api_key": cls.hidden.api_key_comfy_org, + } - # Upload image download_urls = await upload_images_to_comfyapi( start_frame, max_images=1, mime_type="image/png", - auth_kwargs=kwargs, + auth_kwargs=auth_kwargs, ) - if len(download_urls) != 1: - raise RunwayApiError("Failed to upload one or more images to comfy api.") - return await self.generate_video( - RunwayImageToVideoRequest( - promptText=prompt, - seed=seed, - model=Model("gen3a_turbo"), - duration=Duration(duration), - ratio=AspectRatio(ratio), - promptImage=RunwayPromptImageObject( - root=[ - RunwayPromptImageDetailedObject( - uri=str(download_urls[0]), position="first" - ) - ] + return comfy_io.NodeOutput( + await generate_video( + RunwayImageToVideoRequest( + promptText=prompt, + seed=seed, + model=Model("gen4_turbo"), + duration=Duration(duration), + ratio=AspectRatio(ratio), + promptImage=RunwayPromptImageObject( + root=[ + RunwayPromptImageDetailedObject( + uri=str(download_urls[0]), position="first" + ) + ] + ), ), - ), - auth_kwargs=kwargs, - node_id=unique_id, + auth_kwargs=auth_kwargs, + node_id=cls.hidden.unique_id, + estimated_duration=AVERAGE_DURATION_FLF_SECONDS, + ) ) -class RunwayImageToVideoNodeGen4(RunwayVideoGenNode): - """Runway Image to Video Node using Gen4 Turbo model.""" - - DESCRIPTION = "Generate a video from a single starting frame using Gen4 Turbo model. Before diving in, review these best practices to ensure that your input selections will set your generation up for success: https://help.runwayml.com/hc/en-us/articles/37327109429011-Creating-with-Gen-4-Video." +class RunwayFirstLastFrameNode(comfy_io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return { - "required": { - "prompt": model_field_to_node_input( - IO.STRING, RunwayImageToVideoRequest, "promptText", multiline=True + def define_schema(cls): + return comfy_io.Schema( + node_id="RunwayFirstLastFrameNode", + display_name="Runway First-Last-Frame to Video", + category="api node/video/Runway", + description="Upload first and last keyframes, draft a prompt, and generate a video. " + "More complex transitions, such as cases where the Last frame is completely different " + "from the First frame, may benefit from the longer 10s duration. " + "This would give the generation more time to smoothly transition between the two inputs. " + "Before diving in, review these best practices to ensure that your input selections " + "will set your generation up for success: " + "https://help.runwayml.com/hc/en-us/articles/34170748696595-Creating-with-Keyframes-on-Gen-3.", + inputs=[ + comfy_io.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Text prompt for the generation", ), - "start_frame": ( - IO.IMAGE, - {"tooltip": "Start frame to be used for the video"}, + comfy_io.Image.Input( + "start_frame", + tooltip="Start frame to be used for the video", ), - "duration": model_field_to_node_input( - IO.COMBO, RunwayImageToVideoRequest, "duration", enum_type=Duration + comfy_io.Image.Input( + "end_frame", + tooltip="End frame to be used for the video. 
Supported for gen3a_turbo only.", ), - "ratio": model_field_to_node_input( - IO.COMBO, - RunwayImageToVideoRequest, + comfy_io.Combo.Input( + "duration", + options=[model.value for model in Duration], + ), + comfy_io.Combo.Input( "ratio", - enum_type=RunwayGen4TurboAspectRatio, + options=[model.value for model in RunwayGen3aAspectRatio], ), - "seed": model_field_to_node_input( - IO.INT, - RunwayImageToVideoRequest, + comfy_io.Int.Input( "seed", + default=0, + min=0, + max=4294967295, + step=1, control_after_generate=True, + display_mode=comfy_io.NumberDisplay.number, + tooltip="Random seed for generation", ), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } - - async def api_call( - self, - prompt: str, - start_frame: torch.Tensor, - duration: str, - ratio: str, - seed: int, - unique_id: Optional[str] = None, - **kwargs, - ) -> tuple[VideoFromFile]: - # Validate inputs - validate_string(prompt, min_length=1) - validate_input_image(start_frame) - - # Upload image - download_urls = await upload_images_to_comfyapi( - start_frame, - max_images=1, - mime_type="image/png", - auth_kwargs=kwargs, - ) - if len(download_urls) != 1: - raise RunwayApiError("Failed to upload one or more images to comfy api.") - - return await self.generate_video( - RunwayImageToVideoRequest( - promptText=prompt, - seed=seed, - model=Model("gen4_turbo"), - duration=Duration(duration), - ratio=AspectRatio(ratio), - promptImage=RunwayPromptImageObject( - root=[ - RunwayPromptImageDetailedObject( - uri=str(download_urls[0]), position="first" - ) - ] - ), - ), - auth_kwargs=kwargs, - node_id=unique_id, - ) - - -class RunwayFirstLastFrameNode(RunwayVideoGenNode): - """Runway First-Last Frame Node.""" - - DESCRIPTION = "Upload first and last keyframes, draft a prompt, and generate a video. More complex transitions, such as cases where the Last frame is completely different from the First frame, may benefit from the longer 10s duration. This would give the generation more time to smoothly transition between the two inputs. Before diving in, review these best practices to ensure that your input selections will set your generation up for success: https://help.runwayml.com/hc/en-us/articles/34170748696595-Creating-with-Keyframes-on-Gen-3." - - async def get_response( - self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None - ) -> RunwayImageToVideoResponse: - return await poll_until_finished( - auth_kwargs, - ApiEndpoint( - path=f"{PATH_GET_TASK_STATUS}/{task_id}", - method=HttpMethod.GET, - request_model=EmptyRequest, - response_model=TaskStatusResponse, - ), - estimated_duration=AVERAGE_DURATION_FLF_SECONDS, - node_id=node_id, + ], + outputs=[ + comfy_io.Video.Output(), + ], + hidden=[ + comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + comfy_io.Hidden.unique_id, + ], + is_api_node=True, ) @classmethod - def INPUT_TYPES(s): - return { - "required": { - "prompt": model_field_to_node_input( - IO.STRING, RunwayImageToVideoRequest, "promptText", multiline=True - ), - "start_frame": ( - IO.IMAGE, - {"tooltip": "Start frame to be used for the video"}, - ), - "end_frame": ( - IO.IMAGE, - { - "tooltip": "End frame to be used for the video. Supported for gen3a_turbo only." 
- }, - ), - "duration": model_field_to_node_input( - IO.COMBO, RunwayImageToVideoRequest, "duration", enum_type=Duration - ), - "ratio": model_field_to_node_input( - IO.COMBO, - RunwayImageToVideoRequest, - "ratio", - enum_type=RunwayGen3aAspectRatio, - ), - "seed": model_field_to_node_input( - IO.INT, - RunwayImageToVideoRequest, - "seed", - control_after_generate=True, - ), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "unique_id": "UNIQUE_ID", - "comfy_api_key": "API_KEY_COMFY_ORG", - }, - } - - async def api_call( - self, + async def execute( + cls, prompt: str, start_frame: torch.Tensor, end_frame: torch.Tensor, duration: str, ratio: str, seed: int, - unique_id: Optional[str] = None, - **kwargs, - ) -> tuple[VideoFromFile]: - # Validate inputs + ) -> comfy_io.NodeOutput: validate_string(prompt, min_length=1) - validate_input_image(start_frame) - validate_input_image(end_frame) + validate_image_dimensions(start_frame, max_width=7999, max_height=7999) + validate_image_dimensions(end_frame, max_width=7999, max_height=7999) + validate_image_aspect_ratio(start_frame, min_aspect_ratio=0.5, max_aspect_ratio=2.0) + validate_image_aspect_ratio(end_frame, min_aspect_ratio=0.5, max_aspect_ratio=2.0) + + auth_kwargs = { + "auth_token": cls.hidden.auth_token_comfy_org, + "comfy_api_key": cls.hidden.api_key_comfy_org, + } - # Upload images stacked_input_images = image_tensor_pair_to_batch(start_frame, end_frame) download_urls = await upload_images_to_comfyapi( stacked_input_images, max_images=2, mime_type="image/png", - auth_kwargs=kwargs, + auth_kwargs=auth_kwargs, ) if len(download_urls) != 2: raise RunwayApiError("Failed to upload one or more images to comfy api.") - return await self.generate_video( - RunwayImageToVideoRequest( - promptText=prompt, - seed=seed, - model=Model("gen3a_turbo"), - duration=Duration(duration), - ratio=AspectRatio(ratio), - promptImage=RunwayPromptImageObject( - root=[ - RunwayPromptImageDetailedObject( - uri=str(download_urls[0]), position="first" - ), - RunwayPromptImageDetailedObject( - uri=str(download_urls[1]), position="last" - ), - ] + return comfy_io.NodeOutput( + await generate_video( + RunwayImageToVideoRequest( + promptText=prompt, + seed=seed, + model=Model("gen3a_turbo"), + duration=Duration(duration), + ratio=AspectRatio(ratio), + promptImage=RunwayPromptImageObject( + root=[ + RunwayPromptImageDetailedObject( + uri=str(download_urls[0]), position="first" + ), + RunwayPromptImageDetailedObject( + uri=str(download_urls[1]), position="last" + ), + ] + ), ), - ), - auth_kwargs=kwargs, - node_id=unique_id, + auth_kwargs=auth_kwargs, + node_id=cls.hidden.unique_id, + estimated_duration=AVERAGE_DURATION_FLF_SECONDS, + ) ) -class RunwayTextToImageNode(ComfyNodeABC): - """Runway Text to Image Node.""" - - RETURN_TYPES = ("IMAGE",) - FUNCTION = "api_call" - CATEGORY = "api node/image/Runway" - API_NODE = True - DESCRIPTION = "Generate an image from a text prompt using Runway's Gen 4 model. You can also include reference images to guide the generation." +class RunwayTextToImageNode(comfy_io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return { - "required": { - "prompt": model_field_to_node_input( - IO.STRING, RunwayTextToImageRequest, "promptText", multiline=True + def define_schema(cls): + return comfy_io.Schema( + node_id="RunwayTextToImageNode", + display_name="Runway Text to Image", + category="api node/image/Runway", + description="Generate an image from a text prompt using Runway's Gen 4 model. 
" + "You can also include reference image to guide the generation.", + inputs=[ + comfy_io.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Text prompt for the generation", ), - "ratio": model_field_to_node_input( - IO.COMBO, - RunwayTextToImageRequest, + comfy_io.Combo.Input( "ratio", - enum_type=RunwayTextToImageAspectRatioEnum, + options=[model.value for model in RunwayTextToImageAspectRatioEnum], ), - }, - "optional": { - "reference_image": ( - IO.IMAGE, - {"tooltip": "Optional reference image to guide the generation"}, - ) - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } - - def validate_task_created(self, response: RunwayTextToImageResponse) -> bool: - """ - Validate the task creation response from the Runway API matches - expected format. - """ - if not bool(response.id): - raise RunwayApiError("Invalid initial response from Runway API.") - return True - - def validate_response(self, response: TaskStatusResponse) -> bool: - """ - Validate the successful task status response from the Runway API - matches expected format. - """ - if not response.output or len(response.output) == 0: - raise RunwayApiError( - "Runway task succeeded but no image data found in response." - ) - return True - - async def get_response( - self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None - ) -> TaskStatusResponse: - """Poll the task status until it is finished then get the response.""" - return await poll_until_finished( - auth_kwargs, - ApiEndpoint( - path=f"{PATH_GET_TASK_STATUS}/{task_id}", - method=HttpMethod.GET, - request_model=EmptyRequest, - response_model=TaskStatusResponse, - ), - estimated_duration=AVERAGE_DURATION_T2I_SECONDS, - node_id=node_id, + comfy_io.Image.Input( + "reference_image", + tooltip="Optional reference image to guide the generation", + optional=True, + ), + ], + outputs=[ + comfy_io.Image.Output(), + ], + hidden=[ + comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + comfy_io.Hidden.unique_id, + ], + is_api_node=True, ) - async def api_call( - self, + @classmethod + async def execute( + cls, prompt: str, ratio: str, reference_image: Optional[torch.Tensor] = None, - unique_id: Optional[str] = None, - **kwargs, - ) -> tuple[torch.Tensor]: - # Validate inputs + ) -> comfy_io.NodeOutput: validate_string(prompt, min_length=1) + auth_kwargs = { + "auth_token": cls.hidden.auth_token_comfy_org, + "comfy_api_key": cls.hidden.api_key_comfy_org, + } + # Prepare reference images if provided reference_images = None if reference_image is not None: - validate_input_image(reference_image) + validate_image_dimensions(reference_image, max_width=7999, max_height=7999) + validate_image_aspect_ratio(reference_image, min_aspect_ratio=0.5, max_aspect_ratio=2.0) download_urls = await upload_images_to_comfyapi( reference_image, max_images=1, mime_type="image/png", - auth_kwargs=kwargs, + auth_kwargs=auth_kwargs, ) - if len(download_urls) != 1: - raise RunwayApiError("Failed to upload reference image to comfy api.") - reference_images = [ReferenceImage(uri=str(download_urls[0]))] - # Create request request = RunwayTextToImageRequest( promptText=prompt, model=Model4.gen4_image, @@ -593,7 +565,6 @@ class RunwayTextToImageNode(ComfyNodeABC): referenceImages=reference_images, ) - # Execute initial request initial_operation = SynchronousOperation( endpoint=ApiEndpoint( path=PATH_TEXT_TO_IMAGE, @@ -602,34 +573,33 @@ class 
RunwayTextToImageNode(ComfyNodeABC): response_model=RunwayTextToImageResponse, ), request=request, - auth_kwargs=kwargs, + auth_kwargs=auth_kwargs, ) initial_response = await initial_operation.execute() - self.validate_task_created(initial_response) - task_id = initial_response.id # Poll for completion - final_response = await self.get_response( - task_id, auth_kwargs=kwargs, node_id=unique_id + final_response = await get_response( + initial_response.id, + auth_kwargs=auth_kwargs, + node_id=cls.hidden.unique_id, + estimated_duration=AVERAGE_DURATION_T2I_SECONDS, ) - self.validate_response(final_response) + if not final_response.output: + raise RunwayApiError("Runway task succeeded but no image data found in response.") - # Download and return image - image_url = get_image_url_from_task_status(final_response) - return (await download_url_to_image_tensor(image_url),) + return comfy_io.NodeOutput(await download_url_to_image_tensor(get_image_url_from_task_status(final_response))) -NODE_CLASS_MAPPINGS = { - "RunwayFirstLastFrameNode": RunwayFirstLastFrameNode, - "RunwayImageToVideoNodeGen3a": RunwayImageToVideoNodeGen3a, - "RunwayImageToVideoNodeGen4": RunwayImageToVideoNodeGen4, - "RunwayTextToImageNode": RunwayTextToImageNode, -} +class RunwayExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[comfy_io.ComfyNode]]: + return [ + RunwayFirstLastFrameNode, + RunwayImageToVideoNodeGen3a, + RunwayImageToVideoNodeGen4, + RunwayTextToImageNode, + ] -NODE_DISPLAY_NAME_MAPPINGS = { - "RunwayFirstLastFrameNode": "Runway First-Last-Frame to Video", - "RunwayImageToVideoNodeGen3a": "Runway Image to Video (Gen3a Turbo)", - "RunwayImageToVideoNodeGen4": "Runway Image to Video (Gen4 Turbo)", - "RunwayTextToImageNode": "Runway Text to Image", -} +async def comfy_entrypoint() -> RunwayExtension: + return RunwayExtension() From 4368d8f87f580f526e8b2bc43bf0ac88e2b67033 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 3 Sep 2025 15:43:29 -0700 Subject: [PATCH 156/325] Update comment in api example. (#9708) --- script_examples/basic_api_example.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/script_examples/basic_api_example.py b/script_examples/basic_api_example.py index 9128420c4..7e20cc2c1 100644 --- a/script_examples/basic_api_example.py +++ b/script_examples/basic_api_example.py @@ -3,11 +3,7 @@ from urllib import request #This is the ComfyUI api prompt format. -#If you want it for a specific workflow you can "enable dev mode options" -#in the settings of the UI (gear beside the "Queue Size: ") this will enable -#a button on the UI to save workflows in api format. - -#keep in mind ComfyUI is pre alpha software so this format will change a bit. +#If you want it for a specific workflow you can "File -> Export (API)" in the interface. 
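#
#A minimal sketch of how such an exported prompt is queued (assuming `json` is
#imported alongside urllib; the server listens on http://127.0.0.1:8188 by default):
#
#    p = {"prompt": json.loads(prompt_text)}
#    data = json.dumps(p).encode('utf-8')
#    req = request.Request("http://127.0.0.1:8188/prompt", data=data)
#    request.urlopen(req)
#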
#this is the one for the default workflow prompt_text = """ From f48d05a2d17fe1a69e08fbabfb080e3779b36225 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Thu, 4 Sep 2025 04:21:38 +0300 Subject: [PATCH 157/325] convert AlignYourStepsScheduler node to V3 schema (#9226) --- comfy_extras/nodes_align_your_steps.py | 50 +++++++++++++++++--------- 1 file changed, 33 insertions(+), 17 deletions(-) diff --git a/comfy_extras/nodes_align_your_steps.py b/comfy_extras/nodes_align_your_steps.py index 8d856d0e8..edd5dadd4 100644 --- a/comfy_extras/nodes_align_your_steps.py +++ b/comfy_extras/nodes_align_your_steps.py @@ -1,6 +1,10 @@ #from: https://research.nvidia.com/labs/toronto-ai/AlignYourSteps/howto.html import numpy as np import torch +from typing_extensions import override + +from comfy_api.latest import ComfyExtension, io + def loglinear_interp(t_steps, num_steps): """ @@ -19,25 +23,30 @@ NOISE_LEVELS = {"SD1": [14.6146412293, 6.4745760956, 3.8636745985, 2.694615152 "SDXL":[14.6146412293, 6.3184485287, 3.7681790315, 2.1811480769, 1.3405244945, 0.8620721141, 0.5550693289, 0.3798540708, 0.2332364134, 0.1114188177, 0.0291671582], "SVD": [700.00, 54.5, 15.886, 7.977, 4.248, 1.789, 0.981, 0.403, 0.173, 0.034, 0.002]} -class AlignYourStepsScheduler: +class AlignYourStepsScheduler(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"model_type": (["SD1", "SDXL", "SVD"], ), - "steps": ("INT", {"default": 10, "min": 1, "max": 10000}), - "denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), - } - } - RETURN_TYPES = ("SIGMAS",) - CATEGORY = "sampling/custom_sampling/schedulers" - - FUNCTION = "get_sigmas" + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="AlignYourStepsScheduler", + category="sampling/custom_sampling/schedulers", + inputs=[ + io.Combo.Input("model_type", options=["SD1", "SDXL", "SVD"]), + io.Int.Input("steps", default=10, min=1, max=10000), + io.Float.Input("denoise", default=1.0, min=0.0, max=1.0, step=0.01), + ], + outputs=[io.Sigmas.Output()], + ) def get_sigmas(self, model_type, steps, denoise): + # Deprecated: use the V3 schema's `execute` method instead of this. + return AlignYourStepsScheduler().execute(model_type, steps, denoise).result + + @classmethod + def execute(cls, model_type, steps, denoise) -> io.NodeOutput: total_steps = steps if denoise < 1.0: if denoise <= 0.0: - return (torch.FloatTensor([]),) + return io.NodeOutput(torch.FloatTensor([])) total_steps = round(steps * denoise) sigmas = NOISE_LEVELS[model_type][:] @@ -46,8 +55,15 @@ class AlignYourStepsScheduler: sigmas = sigmas[-(total_steps + 1):] sigmas[-1] = 0 - return (torch.FloatTensor(sigmas), ) + return io.NodeOutput(torch.FloatTensor(sigmas)) -NODE_CLASS_MAPPINGS = { - "AlignYourStepsScheduler": AlignYourStepsScheduler, -} + +class AlignYourStepsExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + AlignYourStepsScheduler, + ] + +async def comfy_entrypoint() -> AlignYourStepsExtension: + return AlignYourStepsExtension() From 72855db715096bc378817b1aaffcf232fdc39659 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 3 Sep 2025 19:20:13 -0700 Subject: [PATCH 158/325] Fix potential rope issue. 
(#9710) --- comfy/ldm/audio/dit.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/ldm/audio/dit.py b/comfy/ldm/audio/dit.py index 179c5b67e..d0d69bbdc 100644 --- a/comfy/ldm/audio/dit.py +++ b/comfy/ldm/audio/dit.py @@ -632,7 +632,7 @@ class ContinuousTransformer(nn.Module): # Attention layers if self.rotary_pos_emb is not None: - rotary_pos_emb = self.rotary_pos_emb.forward_from_seq_len(x.shape[1], dtype=x.dtype, device=x.device) + rotary_pos_emb = self.rotary_pos_emb.forward_from_seq_len(x.shape[1], dtype=torch.float, device=x.device) else: rotary_pos_emb = None From b71f9bcb7143b8cd4fff627bb91b60739c915d4c Mon Sep 17 00:00:00 2001 From: ComfyUI Wiki Date: Thu, 4 Sep 2025 14:14:02 +0800 Subject: [PATCH 159/325] Update template to 0.1.75 (#9711) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 4ebe6cc2a..3008a5dc3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.25.11 -comfyui-workflow-templates==0.1.73 +comfyui-workflow-templates==0.1.75 comfyui-embedded-docs==0.2.6 torch torchsde From b0338e930bbc1f9d01f005f224573d5994732932 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 4 Sep 2025 02:15:57 -0400 Subject: [PATCH 160/325] ComfyUI 0.3.57 --- comfyui_version.py | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/comfyui_version.py b/comfyui_version.py index e8e039373..4cc3c8647 100644 --- a/comfyui_version.py +++ b/comfyui_version.py @@ -1,3 +1,3 @@ # This file is automatically generated by the build process when version is # updated in pyproject.toml. -__version__ = "0.3.56" +__version__ = "0.3.57" diff --git a/pyproject.toml b/pyproject.toml index cfd5d45ef..d75cd04a2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ComfyUI" -version = "0.3.56" +version = "0.3.57" readme = "README.md" license = { file = "LICENSE" } requires-python = ">=3.9" From a9f1bb10a52ce08a3f21e6fc562554671c85c3d5 Mon Sep 17 00:00:00 2001 From: guill Date: Thu, 4 Sep 2025 16:13:28 -0700 Subject: [PATCH 161/325] Fix progress update crossover between users (#9706) * Fix showing progress from other sessions Because `client_id` was missing from ths `progress_state` message, it was being sent to all connected sessions. This technically meant that if someone had a graph with the same nodes, they would see the progress updates for others. Also added a test to prevent reoccurance and moved the tests around to make CI easier to hook up. 
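In short, the fix is to pass the originating session's `client_id` as the
third argument of `send_sync`, as the `comfy_execution/progress.py` hunk below
shows; with it the message is routed only to that session, without it the
message is broadcast to every connected client:

    self.server_instance.send_sync(
        "progress_state",
        {"prompt_id": prompt_id, "nodes": active_nodes},
        self.server_instance.client_id,
    )
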
* Fix CI issues related to timing-sensitive tests --- .github/workflows/test-execution.yml | 30 +++ comfy_execution/progress.py | 3 +- tests/conftest.py | 6 + .../extra_model_paths.yaml | 0 .../test_async_nodes.py | 28 ++- .../test_execution.py | 13 +- tests/execution/test_progress_isolation.py | 233 ++++++++++++++++++ .../testing_nodes/testing-pack/__init__.py | 0 .../testing-pack/api_test_nodes.py | 0 .../testing-pack/async_test_nodes.py | 0 .../testing_nodes/testing-pack/conditions.py | 0 .../testing-pack/flow_control.py | 0 .../testing-pack/specific_tests.py | 0 .../testing_nodes/testing-pack/stubs.py | 0 .../testing_nodes/testing-pack/tools.py | 0 .../testing_nodes/testing-pack/util.py | 0 16 files changed, 295 insertions(+), 18 deletions(-) create mode 100644 .github/workflows/test-execution.yml rename tests/{inference => execution}/extra_model_paths.yaml (100%) rename tests/{inference => execution}/test_async_nodes.py (95%) rename tests/{inference => execution}/test_execution.py (98%) create mode 100644 tests/execution/test_progress_isolation.py rename tests/{inference => execution}/testing_nodes/testing-pack/__init__.py (100%) rename tests/{inference => execution}/testing_nodes/testing-pack/api_test_nodes.py (100%) rename tests/{inference => execution}/testing_nodes/testing-pack/async_test_nodes.py (100%) rename tests/{inference => execution}/testing_nodes/testing-pack/conditions.py (100%) rename tests/{inference => execution}/testing_nodes/testing-pack/flow_control.py (100%) rename tests/{inference => execution}/testing_nodes/testing-pack/specific_tests.py (100%) rename tests/{inference => execution}/testing_nodes/testing-pack/stubs.py (100%) rename tests/{inference => execution}/testing_nodes/testing-pack/tools.py (100%) rename tests/{inference => execution}/testing_nodes/testing-pack/util.py (100%) diff --git a/.github/workflows/test-execution.yml b/.github/workflows/test-execution.yml new file mode 100644 index 000000000..00ef07ebf --- /dev/null +++ b/.github/workflows/test-execution.yml @@ -0,0 +1,30 @@ +name: Execution Tests + +on: + push: + branches: [ main, master ] + pull_request: + branches: [ main, master ] + +jobs: + test: + strategy: + matrix: + os: [ubuntu-latest, windows-latest, macos-latest] + runs-on: ${{ matrix.os }} + continue-on-error: true + steps: + - uses: actions/checkout@v4 + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: '3.12' + - name: Install requirements + run: | + python -m pip install --upgrade pip + pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu + pip install -r requirements.txt + pip install -r tests-unit/requirements.txt + - name: Run Execution Tests + run: | + python -m pytest tests/execution -v --skip-timing-checks diff --git a/comfy_execution/progress.py b/comfy_execution/progress.py index e8f5ede1e..f951a3350 100644 --- a/comfy_execution/progress.py +++ b/comfy_execution/progress.py @@ -181,8 +181,9 @@ class WebUIProgressHandler(ProgressHandler): } # Send a combined progress_state message with all node states + # Include client_id to ensure message is only sent to the initiating client self.server_instance.send_sync( - "progress_state", {"prompt_id": prompt_id, "nodes": active_nodes} + "progress_state", {"prompt_id": prompt_id, "nodes": active_nodes}, self.server_instance.client_id ) @override diff --git a/tests/conftest.py b/tests/conftest.py index 4e30eb581..290e3a5c0 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -6,6 +6,7 @@ def pytest_addoption(parser): 
parser.addoption('--output_dir', action="store", default='tests/inference/samples', help='Output directory for generated images') parser.addoption("--listen", type=str, default="127.0.0.1", metavar="IP", nargs="?", const="0.0.0.0", help="Specify the IP address to listen on (default: 127.0.0.1). If --listen is provided without an argument, it defaults to 0.0.0.0. (listens on all)") parser.addoption("--port", type=int, default=8188, help="Set the listen port.") + parser.addoption("--skip-timing-checks", action="store_true", default=False, help="Skip timing-related assertions in tests (useful for CI environments with variable performance)") # This initializes args at the beginning of the test session @pytest.fixture(scope="session", autouse=True) @@ -19,6 +20,11 @@ def args_pytest(pytestconfig): return args +@pytest.fixture(scope="session") +def skip_timing_checks(pytestconfig): + """Fixture that returns whether timing checks should be skipped.""" + return pytestconfig.getoption("--skip-timing-checks") + def pytest_collection_modifyitems(items): # Modifies items so tests run in the correct order diff --git a/tests/inference/extra_model_paths.yaml b/tests/execution/extra_model_paths.yaml similarity index 100% rename from tests/inference/extra_model_paths.yaml rename to tests/execution/extra_model_paths.yaml diff --git a/tests/inference/test_async_nodes.py b/tests/execution/test_async_nodes.py similarity index 95% rename from tests/inference/test_async_nodes.py rename to tests/execution/test_async_nodes.py index f029953dd..c771b4b36 100644 --- a/tests/inference/test_async_nodes.py +++ b/tests/execution/test_async_nodes.py @@ -7,7 +7,7 @@ import subprocess from pytest import fixture from comfy_execution.graph_utils import GraphBuilder -from tests.inference.test_execution import ComfyClient, run_warmup +from tests.execution.test_execution import ComfyClient, run_warmup @pytest.mark.execution @@ -23,7 +23,7 @@ class TestAsyncNodes: '--output-directory', args_pytest["output_dir"], '--listen', args_pytest["listen"], '--port', str(args_pytest["port"]), - '--extra-model-paths-config', 'tests/inference/extra_model_paths.yaml', + '--extra-model-paths-config', 'tests/execution/extra_model_paths.yaml', '--cpu', ] use_lru, lru_size = request.param @@ -81,7 +81,7 @@ class TestAsyncNodes: assert len(result_images) == 1, "Should have 1 image" assert np.array(result_images[0]).min() == 0 and np.array(result_images[0]).max() == 0, "Image should be black" - def test_multiple_async_parallel_execution(self, client: ComfyClient, builder: GraphBuilder): + def test_multiple_async_parallel_execution(self, client: ComfyClient, builder: GraphBuilder, skip_timing_checks): """Test that multiple async nodes execute in parallel.""" # Warmup execution to ensure server is fully initialized run_warmup(client) @@ -104,7 +104,8 @@ class TestAsyncNodes: elapsed_time = time.time() - start_time # Should take ~0.5s (max duration) not 1.2s (sum of durations) - assert elapsed_time < 0.8, f"Parallel execution took {elapsed_time}s, expected < 0.8s" + if not skip_timing_checks: + assert elapsed_time < 0.8, f"Parallel execution took {elapsed_time}s, expected < 0.8s" # Verify all nodes executed assert result.did_run(sleep1) and result.did_run(sleep2) and result.did_run(sleep3) @@ -150,7 +151,7 @@ class TestAsyncNodes: with pytest.raises(urllib.error.HTTPError): client.run(g) - def test_async_lazy_evaluation(self, client: ComfyClient, builder: GraphBuilder): + def test_async_lazy_evaluation(self, client: ComfyClient, builder: GraphBuilder, 
skip_timing_checks): """Test async nodes with lazy evaluation.""" # Warmup execution to ensure server is fully initialized run_warmup(client, prefix="warmup_lazy") @@ -173,7 +174,8 @@ class TestAsyncNodes: elapsed_time = time.time() - start_time # Should only execute sleep1, not sleep2 - assert elapsed_time < 0.5, f"Should skip sleep2, took {elapsed_time}s" + if not skip_timing_checks: + assert elapsed_time < 0.5, f"Should skip sleep2, took {elapsed_time}s" assert result.did_run(sleep1), "Sleep1 should have executed" assert not result.did_run(sleep2), "Sleep2 should have been skipped" @@ -310,7 +312,7 @@ class TestAsyncNodes: images = result.get_images(output) assert len(images) == 1, "Should have blocked second image" - def test_async_caching_behavior(self, client: ComfyClient, builder: GraphBuilder): + def test_async_caching_behavior(self, client: ComfyClient, builder: GraphBuilder, skip_timing_checks): """Test that async nodes are properly cached.""" # Warmup execution to ensure server is fully initialized run_warmup(client, prefix="warmup_cache") @@ -330,9 +332,10 @@ class TestAsyncNodes: elapsed_time = time.time() - start_time assert not result2.did_run(sleep_node), "Should be cached" - assert elapsed_time < 0.1, f"Cached run took {elapsed_time}s, should be instant" + if not skip_timing_checks: + assert elapsed_time < 0.1, f"Cached run took {elapsed_time}s, should be instant" - def test_async_with_dynamic_prompts(self, client: ComfyClient, builder: GraphBuilder): + def test_async_with_dynamic_prompts(self, client: ComfyClient, builder: GraphBuilder, skip_timing_checks): """Test async nodes within dynamically generated prompts.""" # Warmup execution to ensure server is fully initialized run_warmup(client, prefix="warmup_dynamic") @@ -345,8 +348,8 @@ class TestAsyncNodes: dynamic_async = g.node("TestDynamicAsyncGeneration", image1=image1.out(0), image2=image2.out(0), - num_async_nodes=3, - sleep_duration=0.2) + num_async_nodes=5, + sleep_duration=0.4) g.node("SaveImage", images=dynamic_async.out(0)) start_time = time.time() @@ -354,7 +357,8 @@ class TestAsyncNodes: elapsed_time = time.time() - start_time # Should execute async nodes in parallel within dynamic prompt - assert elapsed_time < 0.5, f"Dynamic async execution took {elapsed_time}s" + if not skip_timing_checks: + assert elapsed_time < 1.0, f"Dynamic async execution took {elapsed_time}s" assert result.did_run(dynamic_async) def test_async_resource_cleanup(self, client: ComfyClient, builder: GraphBuilder): diff --git a/tests/inference/test_execution.py b/tests/execution/test_execution.py similarity index 98% rename from tests/inference/test_execution.py rename to tests/execution/test_execution.py index e7b29302e..8ea05fdd8 100644 --- a/tests/inference/test_execution.py +++ b/tests/execution/test_execution.py @@ -149,7 +149,7 @@ class TestExecution: '--output-directory', args_pytest["output_dir"], '--listen', args_pytest["listen"], '--port', str(args_pytest["port"]), - '--extra-model-paths-config', 'tests/inference/extra_model_paths.yaml', + '--extra-model-paths-config', 'tests/execution/extra_model_paths.yaml', '--cpu', ] use_lru, lru_size = request.param @@ -518,7 +518,7 @@ class TestExecution: assert numpy.array(images[0]).min() == 63 and numpy.array(images[0]).max() == 63, "Image should have value 0.25" assert not result.did_run(test_node), "The execution should have been cached" - def test_parallel_sleep_nodes(self, client: ComfyClient, builder: GraphBuilder): + def test_parallel_sleep_nodes(self, client: ComfyClient, 
builder: GraphBuilder, skip_timing_checks): # Warmup execution to ensure server is fully initialized run_warmup(client) @@ -541,14 +541,15 @@ class TestExecution: # The test should take around 3.0 seconds (the longest sleep duration) # plus some overhead, but definitely less than the sum of all sleeps (9.0s) - assert elapsed_time < 8.9, f"Parallel execution took {elapsed_time}s, expected less than 8.9s" + if not skip_timing_checks: + assert elapsed_time < 8.9, f"Parallel execution took {elapsed_time}s, expected less than 8.9s" # Verify that all nodes executed assert result.did_run(sleep_node1), "Sleep node 1 should have run" assert result.did_run(sleep_node2), "Sleep node 2 should have run" assert result.did_run(sleep_node3), "Sleep node 3 should have run" - def test_parallel_sleep_expansion(self, client: ComfyClient, builder: GraphBuilder): + def test_parallel_sleep_expansion(self, client: ComfyClient, builder: GraphBuilder, skip_timing_checks): # Warmup execution to ensure server is fully initialized run_warmup(client) @@ -574,7 +575,9 @@ class TestExecution: # Similar to the previous test, expect parallel execution of the sleep nodes # which should complete in less than the sum of all sleeps - assert elapsed_time < 10.0, f"Expansion execution took {elapsed_time}s, expected less than 5.5s" + # Lots of leeway here since Windows CI is slow + if not skip_timing_checks: + assert elapsed_time < 13.0, f"Expansion execution took {elapsed_time}s" # Verify the parallel sleep node executed assert result.did_run(parallel_sleep), "ParallelSleep node should have run" diff --git a/tests/execution/test_progress_isolation.py b/tests/execution/test_progress_isolation.py new file mode 100644 index 000000000..93dc0d41b --- /dev/null +++ b/tests/execution/test_progress_isolation.py @@ -0,0 +1,233 @@ +"""Test that progress updates are properly isolated between WebSocket clients.""" + +import json +import pytest +import time +import threading +import uuid +import websocket +from typing import List, Dict, Any +from comfy_execution.graph_utils import GraphBuilder +from tests.execution.test_execution import ComfyClient + + +class ProgressTracker: + """Tracks progress messages received by a WebSocket client.""" + + def __init__(self, client_id: str): + self.client_id = client_id + self.progress_messages: List[Dict[str, Any]] = [] + self.lock = threading.Lock() + + def add_message(self, message: Dict[str, Any]): + """Thread-safe addition of progress messages.""" + with self.lock: + self.progress_messages.append(message) + + def get_messages_for_prompt(self, prompt_id: str) -> List[Dict[str, Any]]: + """Get all progress messages for a specific prompt_id.""" + with self.lock: + return [ + msg for msg in self.progress_messages + if msg.get('data', {}).get('prompt_id') == prompt_id + ] + + def has_cross_contamination(self, own_prompt_id: str) -> bool: + """Check if this client received progress for other prompts.""" + with self.lock: + for msg in self.progress_messages: + msg_prompt_id = msg.get('data', {}).get('prompt_id') + if msg_prompt_id and msg_prompt_id != own_prompt_id: + return True + return False + + +class IsolatedClient(ComfyClient): + """Extended ComfyClient that tracks all WebSocket messages.""" + + def __init__(self): + super().__init__() + self.progress_tracker = None + self.all_messages: List[Dict[str, Any]] = [] + + def connect(self, listen='127.0.0.1', port=8188, client_id=None): + """Connect with a specific client_id and set up message tracking.""" + if client_id is None: + client_id = 
str(uuid.uuid4()) + super().connect(listen, port, client_id) + self.progress_tracker = ProgressTracker(client_id) + + def listen_for_messages(self, duration: float = 5.0): + """Listen for WebSocket messages for a specified duration.""" + end_time = time.time() + duration + self.ws.settimeout(0.5) # Non-blocking with timeout + + while time.time() < end_time: + try: + out = self.ws.recv() + if isinstance(out, str): + message = json.loads(out) + self.all_messages.append(message) + + # Track progress_state messages + if message.get('type') == 'progress_state': + self.progress_tracker.add_message(message) + except websocket.WebSocketTimeoutException: + continue + except Exception: + # Log error silently in test context + break + + +@pytest.mark.execution +class TestProgressIsolation: + """Test suite for verifying progress update isolation between clients.""" + + @pytest.fixture(scope="class", autouse=True) + def _server(self, args_pytest): + """Start the ComfyUI server for testing.""" + import subprocess + pargs = [ + 'python', 'main.py', + '--output-directory', args_pytest["output_dir"], + '--listen', args_pytest["listen"], + '--port', str(args_pytest["port"]), + '--extra-model-paths-config', 'tests/execution/extra_model_paths.yaml', + '--cpu', + ] + p = subprocess.Popen(pargs) + yield + p.kill() + + def start_client_with_retry(self, listen: str, port: int, client_id: str = None): + """Start client with connection retries.""" + client = IsolatedClient() + # Connect to server (with retries) + n_tries = 5 + for i in range(n_tries): + time.sleep(4) + try: + client.connect(listen, port, client_id) + return client + except ConnectionRefusedError as e: + print(e) # noqa: T201 + print(f"({i+1}/{n_tries}) Retrying...") # noqa: T201 + raise ConnectionRefusedError(f"Failed to connect after {n_tries} attempts") + + def test_progress_isolation_between_clients(self, args_pytest): + """Test that progress updates are isolated between different clients.""" + listen = args_pytest["listen"] + port = args_pytest["port"] + + # Create two separate clients with unique IDs + client_a_id = "client_a_" + str(uuid.uuid4()) + client_b_id = "client_b_" + str(uuid.uuid4()) + + try: + # Connect both clients with retries + client_a = self.start_client_with_retry(listen, port, client_a_id) + client_b = self.start_client_with_retry(listen, port, client_b_id) + + # Create simple workflows for both clients + graph_a = GraphBuilder(prefix="client_a") + image_a = graph_a.node("StubImage", content="BLACK", height=256, width=256, batch_size=1) + graph_a.node("PreviewImage", images=image_a.out(0)) + + graph_b = GraphBuilder(prefix="client_b") + image_b = graph_b.node("StubImage", content="WHITE", height=256, width=256, batch_size=1) + graph_b.node("PreviewImage", images=image_b.out(0)) + + # Submit workflows from both clients + prompt_a = graph_a.finalize() + prompt_b = graph_b.finalize() + + response_a = client_a.queue_prompt(prompt_a) + prompt_id_a = response_a['prompt_id'] + + response_b = client_b.queue_prompt(prompt_b) + prompt_id_b = response_b['prompt_id'] + + # Start threads to listen for messages on both clients + def listen_client_a(): + client_a.listen_for_messages(duration=10.0) + + def listen_client_b(): + client_b.listen_for_messages(duration=10.0) + + thread_a = threading.Thread(target=listen_client_a) + thread_b = threading.Thread(target=listen_client_b) + + thread_a.start() + thread_b.start() + + # Wait for threads to complete + thread_a.join() + thread_b.join() + + # Verify isolation + # Client A should only 
receive progress for prompt_id_a + assert not client_a.progress_tracker.has_cross_contamination(prompt_id_a), \ + f"Client A received progress updates for other clients' workflows. " \ + f"Expected only {prompt_id_a}, but got messages for multiple prompts." + + # Client B should only receive progress for prompt_id_b + assert not client_b.progress_tracker.has_cross_contamination(prompt_id_b), \ + f"Client B received progress updates for other clients' workflows. " \ + f"Expected only {prompt_id_b}, but got messages for multiple prompts." + + # Verify each client received their own progress updates + client_a_messages = client_a.progress_tracker.get_messages_for_prompt(prompt_id_a) + client_b_messages = client_b.progress_tracker.get_messages_for_prompt(prompt_id_b) + + assert len(client_a_messages) > 0, \ + "Client A did not receive any progress updates for its own workflow" + assert len(client_b_messages) > 0, \ + "Client B did not receive any progress updates for its own workflow" + + # Ensure no cross-contamination + client_a_other = client_a.progress_tracker.get_messages_for_prompt(prompt_id_b) + client_b_other = client_b.progress_tracker.get_messages_for_prompt(prompt_id_a) + + assert len(client_a_other) == 0, \ + f"Client A incorrectly received {len(client_a_other)} progress updates for Client B's workflow" + assert len(client_b_other) == 0, \ + f"Client B incorrectly received {len(client_b_other)} progress updates for Client A's workflow" + + finally: + # Clean up connections + if hasattr(client_a, 'ws'): + client_a.ws.close() + if hasattr(client_b, 'ws'): + client_b.ws.close() + + def test_progress_with_missing_client_id(self, args_pytest): + """Test that progress updates handle missing client_id gracefully.""" + listen = args_pytest["listen"] + port = args_pytest["port"] + + try: + # Connect client with retries + client = self.start_client_with_retry(listen, port) + + # Create a simple workflow + graph = GraphBuilder(prefix="test_missing_id") + image = graph.node("StubImage", content="BLACK", height=128, width=128, batch_size=1) + graph.node("PreviewImage", images=image.out(0)) + + # Submit workflow + prompt = graph.finalize() + response = client.queue_prompt(prompt) + prompt_id = response['prompt_id'] + + # Listen for messages + client.listen_for_messages(duration=5.0) + + # Should still receive progress updates for own workflow + messages = client.progress_tracker.get_messages_for_prompt(prompt_id) + assert len(messages) > 0, \ + "Client did not receive progress updates even though it initiated the workflow" + + finally: + if hasattr(client, 'ws'): + client.ws.close() + diff --git a/tests/inference/testing_nodes/testing-pack/__init__.py b/tests/execution/testing_nodes/testing-pack/__init__.py similarity index 100% rename from tests/inference/testing_nodes/testing-pack/__init__.py rename to tests/execution/testing_nodes/testing-pack/__init__.py diff --git a/tests/inference/testing_nodes/testing-pack/api_test_nodes.py b/tests/execution/testing_nodes/testing-pack/api_test_nodes.py similarity index 100% rename from tests/inference/testing_nodes/testing-pack/api_test_nodes.py rename to tests/execution/testing_nodes/testing-pack/api_test_nodes.py diff --git a/tests/inference/testing_nodes/testing-pack/async_test_nodes.py b/tests/execution/testing_nodes/testing-pack/async_test_nodes.py similarity index 100% rename from tests/inference/testing_nodes/testing-pack/async_test_nodes.py rename to tests/execution/testing_nodes/testing-pack/async_test_nodes.py diff --git 
a/tests/inference/testing_nodes/testing-pack/conditions.py b/tests/execution/testing_nodes/testing-pack/conditions.py similarity index 100% rename from tests/inference/testing_nodes/testing-pack/conditions.py rename to tests/execution/testing_nodes/testing-pack/conditions.py diff --git a/tests/inference/testing_nodes/testing-pack/flow_control.py b/tests/execution/testing_nodes/testing-pack/flow_control.py similarity index 100% rename from tests/inference/testing_nodes/testing-pack/flow_control.py rename to tests/execution/testing_nodes/testing-pack/flow_control.py diff --git a/tests/inference/testing_nodes/testing-pack/specific_tests.py b/tests/execution/testing_nodes/testing-pack/specific_tests.py similarity index 100% rename from tests/inference/testing_nodes/testing-pack/specific_tests.py rename to tests/execution/testing_nodes/testing-pack/specific_tests.py diff --git a/tests/inference/testing_nodes/testing-pack/stubs.py b/tests/execution/testing_nodes/testing-pack/stubs.py similarity index 100% rename from tests/inference/testing_nodes/testing-pack/stubs.py rename to tests/execution/testing_nodes/testing-pack/stubs.py diff --git a/tests/inference/testing_nodes/testing-pack/tools.py b/tests/execution/testing_nodes/testing-pack/tools.py similarity index 100% rename from tests/inference/testing_nodes/testing-pack/tools.py rename to tests/execution/testing_nodes/testing-pack/tools.py diff --git a/tests/inference/testing_nodes/testing-pack/util.py b/tests/execution/testing_nodes/testing-pack/util.py similarity index 100% rename from tests/inference/testing_nodes/testing-pack/util.py rename to tests/execution/testing_nodes/testing-pack/util.py From 261421e21899abc8168c71efd8694ade020bcee2 Mon Sep 17 00:00:00 2001 From: "Yousef R. Gamaleldin" <81116377+yousef-rafat@users.noreply.github.com> Date: Fri, 5 Sep 2025 03:36:20 +0300 Subject: [PATCH 162/325] Add Hunyuan 3D 2.1 Support (#8714) --- comfy/clip_vision.py | 231 ++++++++- comfy/image_encoders/dino2.py | 33 +- comfy/image_encoders/dino2_large.json | 22 + comfy/latent_formats.py | 5 + comfy/ldm/hunyuan3d/vae.py | 569 ++++++++++++++++++---- comfy/ldm/hunyuan3dv2_1/hunyuandit.py | 658 ++++++++++++++++++++++++++ comfy/model_base.py | 17 + comfy/model_detection.py | 14 + comfy/sd.py | 49 +- comfy/supported_models.py | 13 +- comfy_extras/nodes_hunyuan3d.py | 24 +- nodes.py | 29 +- requirements.txt | 2 +- 13 files changed, 1537 insertions(+), 129 deletions(-) create mode 100644 comfy/image_encoders/dino2_large.json create mode 100644 comfy/ldm/hunyuan3dv2_1/hunyuandit.py diff --git a/comfy/clip_vision.py b/comfy/clip_vision.py index 2fa410cb7..4bc640e8b 100644 --- a/comfy/clip_vision.py +++ b/comfy/clip_vision.py @@ -17,10 +17,227 @@ class Output: def __setitem__(self, key, item): setattr(self, key, item) -def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711], crop=True): + +def cubic_kernel(x, a: float = -0.75): + absx = x.abs() + absx2 = absx ** 2 + absx3 = absx ** 3 + + w = (a + 2) * absx3 - (a + 3) * absx2 + 1 + w2 = a * absx3 - 5*a * absx2 + 8*a * absx - 4*a + + return torch.where(absx <= 1, w, torch.where(absx < 2, w2, torch.zeros_like(x))) + +def get_indices_weights(in_size, out_size, scale): + # OpenCV-style half-pixel mapping + x = torch.arange(out_size, dtype=torch.float32) + x = (x + 0.5) / scale - 0.5 + + x0 = x.floor().long() + dx = x.unsqueeze(1) - (x0.unsqueeze(1) + torch.arange(-1, 3)) + + weights = cubic_kernel(dx) + weights = weights / weights.sum(dim=1, 
keepdim=True) + + indices = x0.unsqueeze(1) + torch.arange(-1, 3) + indices = indices.clamp(0, in_size - 1) + + return indices, weights + +def resize_cubic_1d(x, out_size, dim): + b, c, h, w = x.shape + in_size = h if dim == 2 else w + scale = out_size / in_size + + indices, weights = get_indices_weights(in_size, out_size, scale) + + if dim == 2: + x = x.permute(0, 1, 3, 2) + x = x.reshape(-1, h) + else: + x = x.reshape(-1, w) + + gathered = x[:, indices] + out = (gathered * weights.unsqueeze(0)).sum(dim=2) + + if dim == 2: + out = out.reshape(b, c, w, out_size).permute(0, 1, 3, 2) + else: + out = out.reshape(b, c, h, out_size) + + return out + +def resize_cubic(img: torch.Tensor, size: tuple) -> torch.Tensor: + """ + Resize image using OpenCV-equivalent INTER_CUBIC interpolation. + Implemented in pure PyTorch + """ + + if img.ndim == 3: + img = img.unsqueeze(0) + + img = img.permute(0, 3, 1, 2) + + out_h, out_w = size + img = resize_cubic_1d(img, out_h, dim=2) + img = resize_cubic_1d(img, out_w, dim=3) + return img + +def resize_area(img: torch.Tensor, size: tuple) -> torch.Tensor: + # vectorized implementation for OpenCV's INTER_AREA using pure PyTorch + original_shape = img.shape + is_hwc = False + + if img.ndim == 3: + if img.shape[0] <= 4: + img = img.unsqueeze(0) + else: + is_hwc = True + img = img.permute(2, 0, 1).unsqueeze(0) + elif img.ndim == 4: + pass + else: + raise ValueError("Expected image with 3 or 4 dims.") + + B, C, H, W = img.shape + out_h, out_w = size + scale_y = H / out_h + scale_x = W / out_w + + device = img.device + + # compute the grid boundries + y_start = torch.arange(out_h, device=device).float() * scale_y + y_end = y_start + scale_y + x_start = torch.arange(out_w, device=device).float() * scale_x + x_end = x_start + scale_x + + # for each output pixel, we will compute the range for it + y_start_int = torch.floor(y_start).long() + y_end_int = torch.ceil(y_end).long() + x_start_int = torch.floor(x_start).long() + x_end_int = torch.ceil(x_end).long() + + # We will build the weighted sums by iterating over contributing input pixels once + output = torch.zeros((B, C, out_h, out_w), dtype=torch.float32, device=device) + area = torch.zeros((out_h, out_w), dtype=torch.float32, device=device) + + max_kernel_h = int(torch.max(y_end_int - y_start_int).item()) + max_kernel_w = int(torch.max(x_end_int - x_start_int).item()) + + for dy in range(max_kernel_h): + for dx in range(max_kernel_w): + # compute the weights for this offset for all output pixels + + y_idx = y_start_int.unsqueeze(1) + dy + x_idx = x_start_int.unsqueeze(0) + dx + + # clamp indices to image boundaries + y_idx_clamped = torch.clamp(y_idx, 0, H - 1) + x_idx_clamped = torch.clamp(x_idx, 0, W - 1) + + # compute weights by broadcasting + y_weight = (torch.min(y_end.unsqueeze(1), y_idx_clamped.float() + 1.0) - torch.max(y_start.unsqueeze(1), y_idx_clamped.float())).clamp(min=0) + x_weight = (torch.min(x_end.unsqueeze(0), x_idx_clamped.float() + 1.0) - torch.max(x_start.unsqueeze(0), x_idx_clamped.float())).clamp(min=0) + + weight = (y_weight * x_weight) + + y_expand = y_idx_clamped.expand(out_h, out_w) + x_expand = x_idx_clamped.expand(out_h, out_w) + + + pixels = img[:, :, y_expand, x_expand] + + # unsqueeze to broadcast + w = weight.unsqueeze(0).unsqueeze(0) + + output += pixels * w + area += weight + + # Normalize by area + output /= area.unsqueeze(0).unsqueeze(0) + + if is_hwc: + return output[0].permute(1, 2, 0) + elif img.shape[0] == 1 and original_shape[0] <= 4: + return output[0] + else: + return 
output + +def recenter(image, border_ratio: float = 0.2): + + if image.shape[-1] == 4: + mask = image[..., 3] + else: + mask = torch.ones_like(image[..., 0:1]) * 255 + image = torch.concatenate([image, mask], axis=-1) + mask = mask[..., 0] + + H, W, C = image.shape + + size = max(H, W) + result = torch.zeros((size, size, C), dtype = torch.uint8) + + # as_tuple to match numpy behaviour + x_coords, y_coords = torch.nonzero(mask, as_tuple=True) + + y_min, y_max = y_coords.min(), y_coords.max() + x_min, x_max = x_coords.min(), x_coords.max() + + h = x_max - x_min + w = y_max - y_min + + if h == 0 or w == 0: + raise ValueError('input image is empty') + + desired_size = int(size * (1 - border_ratio)) + scale = desired_size / max(h, w) + + h2 = int(h * scale) + w2 = int(w * scale) + + x2_min = (size - h2) // 2 + x2_max = x2_min + h2 + + y2_min = (size - w2) // 2 + y2_max = y2_min + w2 + + # note: opencv takes columns first (opposite to pytorch and numpy that take the row first) + result[x2_min:x2_max, y2_min:y2_max] = resize_area(image[x_min:x_max, y_min:y_max], (h2, w2)) + + bg = torch.ones((result.shape[0], result.shape[1], 3), dtype = torch.uint8) * 255 + + mask = result[..., 3:].to(torch.float32) / 255 + result = result[..., :3] * mask + bg * (1 - mask) + + mask = mask * 255 + result = result.clip(0, 255).to(torch.uint8) + mask = mask.clip(0, 255).to(torch.uint8) + + return result + +def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711], + crop=True, value_range = (-1, 1), border_ratio: float = None, recenter_size: int = 512): + + if border_ratio is not None: + + image = (image * 255).clamp(0, 255).to(torch.uint8) + image = [recenter(img, border_ratio = border_ratio) for img in image] + + image = torch.stack(image, dim = 0) + image = resize_cubic(image, size = (recenter_size, recenter_size)) + + image = image / 255 * 2 - 1 + low, high = value_range + + image = (image - low) / (high - low) + image = image.permute(0, 2, 3, 1) + image = image[:, :, :, :3] if image.shape[3] > 3 else image + mean = torch.tensor(mean, device=image.device, dtype=image.dtype) std = torch.tensor(std, device=image.device, dtype=image.dtype) + image = image.movedim(-1, 1) if not (image.shape[2] == size and image.shape[3] == size): if crop: @@ -29,7 +246,7 @@ def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], s else: scale_size = (size, size) - image = torch.nn.functional.interpolate(image, size=scale_size, mode="bicubic", antialias=True) + image = torch.nn.functional.interpolate(image, size=scale_size, mode="bilinear" if border_ratio is not None else "bicubic", antialias=True) h = (image.shape[2] - size)//2 w = (image.shape[3] - size)//2 image = image[:,:,h:h+size,w:w+size] @@ -71,9 +288,9 @@ class ClipVisionModel(): def get_sd(self): return self.model.state_dict() - def encode_image(self, image, crop=True): + def encode_image(self, image, crop=True, border_ratio: float = None): comfy.model_management.load_model_gpu(self.patcher) - pixel_values = clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std, crop=crop).float() + pixel_values = clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std, crop=crop, border_ratio=border_ratio).float() out = self.model(pixel_values=pixel_values, intermediate_output='all' if self.return_all_hidden_states else -2) outputs = Output() @@ -136,8 +353,12 @@ def load_clipvision_from_sd(sd, 
prefix="", convert_keys=False): json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl_336.json") else: json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl.json") - elif "embeddings.patch_embeddings.projection.weight" in sd: + + # Dinov2 + elif 'encoder.layer.39.layer_scale2.lambda1' in sd: json_config = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "image_encoders"), "dino2_giant.json") + elif 'encoder.layer.23.layer_scale2.lambda1' in sd: + json_config = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "image_encoders"), "dino2_large.json") else: return None diff --git a/comfy/image_encoders/dino2.py b/comfy/image_encoders/dino2.py index 976f98c65..9b6dace9d 100644 --- a/comfy/image_encoders/dino2.py +++ b/comfy/image_encoders/dino2.py @@ -31,6 +31,20 @@ class LayerScale(torch.nn.Module): def forward(self, x): return x * comfy.model_management.cast_to_device(self.lambda1, x.device, x.dtype) +class Dinov2MLP(torch.nn.Module): + def __init__(self, hidden_size: int, dtype, device, operations): + super().__init__() + + mlp_ratio = 4 + hidden_features = int(hidden_size * mlp_ratio) + self.fc1 = operations.Linear(hidden_size, hidden_features, bias = True, device=device, dtype=dtype) + self.fc2 = operations.Linear(hidden_features, hidden_size, bias = True, device=device, dtype=dtype) + + def forward(self, hidden_state: torch.Tensor) -> torch.Tensor: + hidden_state = self.fc1(hidden_state) + hidden_state = torch.nn.functional.gelu(hidden_state) + hidden_state = self.fc2(hidden_state) + return hidden_state class SwiGLUFFN(torch.nn.Module): def __init__(self, dim, dtype, device, operations): @@ -50,12 +64,15 @@ class SwiGLUFFN(torch.nn.Module): class Dino2Block(torch.nn.Module): - def __init__(self, dim, num_heads, layer_norm_eps, dtype, device, operations): + def __init__(self, dim, num_heads, layer_norm_eps, dtype, device, operations, use_swiglu_ffn): super().__init__() self.attention = Dino2AttentionBlock(dim, num_heads, layer_norm_eps, dtype, device, operations) self.layer_scale1 = LayerScale(dim, dtype, device, operations) self.layer_scale2 = LayerScale(dim, dtype, device, operations) - self.mlp = SwiGLUFFN(dim, dtype, device, operations) + if use_swiglu_ffn: + self.mlp = SwiGLUFFN(dim, dtype, device, operations) + else: + self.mlp = Dinov2MLP(dim, dtype, device, operations) self.norm1 = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device) self.norm2 = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device) @@ -66,9 +83,10 @@ class Dino2Block(torch.nn.Module): class Dino2Encoder(torch.nn.Module): - def __init__(self, dim, num_heads, layer_norm_eps, num_layers, dtype, device, operations): + def __init__(self, dim, num_heads, layer_norm_eps, num_layers, dtype, device, operations, use_swiglu_ffn): super().__init__() - self.layer = torch.nn.ModuleList([Dino2Block(dim, num_heads, layer_norm_eps, dtype, device, operations) for _ in range(num_layers)]) + self.layer = torch.nn.ModuleList([Dino2Block(dim, num_heads, layer_norm_eps, dtype, device, operations, use_swiglu_ffn = use_swiglu_ffn) + for _ in range(num_layers)]) def forward(self, x, intermediate_output=None): optimized_attention = optimized_attention_for_device(x.device, False, small_input=True) @@ -78,8 +96,8 @@ class Dino2Encoder(torch.nn.Module): intermediate_output = len(self.layer) + intermediate_output intermediate = None - for i, l in enumerate(self.layer): - x = 
l(x, optimized_attention) + for i, layer in enumerate(self.layer): + x = layer(x, optimized_attention) if i == intermediate_output: intermediate = x.clone() return x, intermediate @@ -128,9 +146,10 @@ class Dinov2Model(torch.nn.Module): dim = config_dict["hidden_size"] heads = config_dict["num_attention_heads"] layer_norm_eps = config_dict["layer_norm_eps"] + use_swiglu_ffn = config_dict["use_swiglu_ffn"] self.embeddings = Dino2Embeddings(dim, dtype, device, operations) - self.encoder = Dino2Encoder(dim, heads, layer_norm_eps, num_layers, dtype, device, operations) + self.encoder = Dino2Encoder(dim, heads, layer_norm_eps, num_layers, dtype, device, operations, use_swiglu_ffn = use_swiglu_ffn) self.layernorm = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device) def forward(self, pixel_values, attention_mask=None, intermediate_output=None): diff --git a/comfy/image_encoders/dino2_large.json b/comfy/image_encoders/dino2_large.json new file mode 100644 index 000000000..43fbb58ff --- /dev/null +++ b/comfy/image_encoders/dino2_large.json @@ -0,0 +1,22 @@ +{ + "hidden_size": 1024, + "use_mask_token": true, + "patch_size": 14, + "image_size": 518, + "num_channels": 3, + "num_attention_heads": 16, + "initializer_range": 0.02, + "attention_probs_dropout_prob": 0.0, + "hidden_dropout_prob": 0.0, + "hidden_act": "gelu", + "mlp_ratio": 4, + "model_type": "dinov2", + "num_hidden_layers": 24, + "layer_norm_eps": 1e-6, + "qkv_bias": true, + "use_swiglu_ffn": false, + "layerscale_value": 1.0, + "drop_path_rate": 0.0, + "image_mean": [0.485, 0.456, 0.406], + "image_std": [0.229, 0.224, 0.225] +} diff --git a/comfy/latent_formats.py b/comfy/latent_formats.py index caf4991fc..0d84994b0 100644 --- a/comfy/latent_formats.py +++ b/comfy/latent_formats.py @@ -538,6 +538,11 @@ class Hunyuan3Dv2(LatentFormat): latent_dimensions = 1 scale_factor = 0.9990943042622529 +class Hunyuan3Dv2_1(LatentFormat): + scale_factor = 1.0039506158752403 + latent_channels = 64 + latent_dimensions = 1 + class Hunyuan3Dv2mini(LatentFormat): latent_channels = 64 latent_dimensions = 1 diff --git a/comfy/ldm/hunyuan3d/vae.py b/comfy/ldm/hunyuan3d/vae.py index 6e8cbf1d9..760944827 100644 --- a/comfy/ldm/hunyuan3d/vae.py +++ b/comfy/ldm/hunyuan3d/vae.py @@ -4,81 +4,458 @@ import torch import torch.nn as nn import torch.nn.functional as F - - -from typing import Union, Tuple, List, Callable, Optional - import numpy as np -from einops import repeat, rearrange +import math from tqdm import tqdm + +from typing import Optional + import logging import comfy.ops ops = comfy.ops.disable_weight_init -def generate_dense_grid_points( - bbox_min: np.ndarray, - bbox_max: np.ndarray, - octree_resolution: int, - indexing: str = "ij", -): - length = bbox_max - bbox_min - num_cells = octree_resolution +def fps(src: torch.Tensor, batch: torch.Tensor, sampling_ratio: float, start_random: bool = True): - x = np.linspace(bbox_min[0], bbox_max[0], int(num_cells) + 1, dtype=np.float32) - y = np.linspace(bbox_min[1], bbox_max[1], int(num_cells) + 1, dtype=np.float32) - z = np.linspace(bbox_min[2], bbox_max[2], int(num_cells) + 1, dtype=np.float32) - [xs, ys, zs] = np.meshgrid(x, y, z, indexing=indexing) - xyz = np.stack((xs, ys, zs), axis=-1) - grid_size = [int(num_cells) + 1, int(num_cells) + 1, int(num_cells) + 1] + # manually create the pointer vector + assert src.size(0) == batch.numel() - return xyz, grid_size, length + batch_size = int(batch.max()) + 1 + deg = src.new_zeros(batch_size, dtype = torch.long) + + deg.scatter_add_(0, batch, 
torch.ones_like(batch)) + + ptr_vec = deg.new_zeros(batch_size + 1) + torch.cumsum(deg, 0, out=ptr_vec[1:]) + + #return fps_sampling(src, ptr_vec, ratio) + sampled_indicies = [] + + for b in range(batch_size): + # start and the end of each batch + start, end = ptr_vec[b].item(), ptr_vec[b + 1].item() + # points from the point cloud + points = src[start:end] + + num_points = points.size(0) + num_samples = max(1, math.ceil(num_points * sampling_ratio)) + + selected = torch.zeros(num_samples, device = src.device, dtype = torch.long) + distances = torch.full((num_points,), float("inf"), device = src.device) + + # select a random start point + if start_random: + farthest = torch.randint(0, num_points, (1,), device = src.device) + else: + farthest = torch.tensor([0], device = src.device, dtype = torch.long) + + for i in range(num_samples): + selected[i] = farthest + centroid = points[farthest].squeeze(0) + dist = torch.norm(points - centroid, dim = 1) # compute euclidean distance + distances = torch.minimum(distances, dist) + farthest = torch.argmax(distances) + + sampled_indicies.append(torch.arange(start, end)[selected]) + + return torch.cat(sampled_indicies, dim = 0) +class PointCrossAttention(nn.Module): + def __init__(self, + num_latents: int, + downsample_ratio: float, + pc_size: int, + pc_sharpedge_size: int, + point_feats: int, + width: int, + heads: int, + layers: int, + fourier_embedder, + normal_pe: bool = False, + qkv_bias: bool = False, + use_ln_post: bool = True, + qk_norm: bool = True): + + super().__init__() + + self.fourier_embedder = fourier_embedder + + self.pc_size = pc_size + self.normal_pe = normal_pe + self.downsample_ratio = downsample_ratio + self.pc_sharpedge_size = pc_sharpedge_size + self.num_latents = num_latents + self.point_feats = point_feats + + self.input_proj = nn.Linear(self.fourier_embedder.out_dim + point_feats, width) + + self.cross_attn = ResidualCrossAttentionBlock( + width = width, + heads = heads, + qkv_bias = qkv_bias, + qk_norm = qk_norm + ) + + self.self_attn = None + if layers > 0: + self.self_attn = Transformer( + width = width, + heads = heads, + qkv_bias = qkv_bias, + qk_norm = qk_norm, + layers = layers + ) + + if use_ln_post: + self.ln_post = nn.LayerNorm(width) + else: + self.ln_post = None + + def sample_points_and_latents(self, point_cloud: torch.Tensor, features: torch.Tensor): + + """ + Subsample points randomly from the point cloud (input_pc) + Further sample the subsampled points to get query_pc + take the fourier embeddings for both input and query pc + + Mental Note: FPS-sampled points (query_pc) act as latent tokens that attend to and learn from the broader context in input_pc. + Goal: get a smaller represenation (query_pc) to represent the entire scence structure by learning from a broader subset (input_pc). + More computationally efficient. 
+ + Features are additional information for each point in the cloud + """ + + B, _, D = point_cloud.shape + + num_latents = int(self.num_latents) + + num_random_query = self.pc_size / (self.pc_size + self.pc_sharpedge_size) * num_latents + num_sharpedge_query = num_latents - num_random_query + + # Split random and sharpedge surface points + random_pc, sharpedge_pc = torch.split(point_cloud, [self.pc_size, self.pc_sharpedge_size], dim=1) + + # assert statements + assert random_pc.shape[1] <= self.pc_size, "Random surface points size must be less than or equal to pc_size" + assert sharpedge_pc.shape[1] <= self.pc_sharpedge_size, "Sharpedge surface points size must be less than or equal to pc_sharpedge_size" + + input_random_pc_size = int(num_random_query * self.downsample_ratio) + random_query_pc, random_input_pc, random_idx_pc, random_idx_query = \ + self.subsample(pc = random_pc, num_query = num_random_query, input_pc_size = input_random_pc_size) + + input_sharpedge_pc_size = int(num_sharpedge_query * self.downsample_ratio) + + if input_sharpedge_pc_size == 0: + sharpedge_input_pc = torch.zeros(B, 0, D, dtype = random_input_pc.dtype).to(point_cloud.device) + sharpedge_query_pc = torch.zeros(B, 0, D, dtype= random_query_pc.dtype).to(point_cloud.device) + + else: + sharpedge_query_pc, sharpedge_input_pc, sharpedge_idx_pc, sharpedge_idx_query = \ + self.subsample(pc = sharpedge_pc, num_query = num_sharpedge_query, input_pc_size = input_sharpedge_pc_size) + + # concat the random and sharpedges + query_pc = torch.cat([random_query_pc, sharpedge_query_pc], dim = 1) + input_pc = torch.cat([random_input_pc, sharpedge_input_pc], dim = 1) + + query = self.fourier_embedder(query_pc) + data = self.fourier_embedder(input_pc) + + if self.point_feats > 0: + random_surface_features, sharpedge_surface_features = torch.split(features, [self.pc_size, self.pc_sharpedge_size], dim = 1) + + input_random_surface_features, query_random_features = \ + self.handle_features(features = random_surface_features, idx_pc = random_idx_pc, batch_size = B, + input_pc_size = input_random_pc_size, idx_query = random_idx_query) + + if input_sharpedge_pc_size == 0: + input_sharpedge_surface_features = torch.zeros(B, 0, self.point_feats, + dtype = input_random_surface_features.dtype, device = point_cloud.device) + + query_sharpedge_features = torch.zeros(B, 0, self.point_feats, + dtype = query_random_features.dtype, device = point_cloud.device) + else: + + input_sharpedge_surface_features, query_sharpedge_features = \ + self.handle_features(idx_pc = sharpedge_idx_pc, features = sharpedge_surface_features, + batch_size = B, idx_query = sharpedge_idx_query, input_pc_size = input_sharpedge_pc_size) + + query_features = torch.cat([query_random_features, query_sharpedge_features], dim = 1) + input_features = torch.cat([input_random_surface_features, input_sharpedge_surface_features], dim = 1) + + if self.normal_pe: + # apply the fourier embeddings on the first 3 dims (xyz) + input_features_pe = self.fourier_embedder(input_features[..., :3]) + query_features_pe = self.fourier_embedder(query_features[..., :3]) + # replace the first 3 dims with the new PE ones + input_features = torch.cat([input_features_pe, input_features[..., :3]], dim = -1) + query_features = torch.cat([query_features_pe, query_features[..., :3]], dim = -1) + + # concat at the channels dim + query = torch.cat([query, query_features], dim = -1) + data = torch.cat([data, input_features], dim = -1) + + # don't return pc_info to avoid unnecessary memory usuage + return 
query.view(B, -1, query.shape[-1]), data.view(B, -1, data.shape[-1]) + + def forward(self, point_cloud: torch.Tensor, features: torch.Tensor): + + query, data = self.sample_points_and_latents(point_cloud = point_cloud, features = features) + + # apply projections + query = self.input_proj(query) + data = self.input_proj(data) + + # apply cross attention between query and data + latents = self.cross_attn(query, data) + + if self.self_attn is not None: + latents = self.self_attn(latents) + + if self.ln_post is not None: + latents = self.ln_post(latents) + + return latents -class VanillaVolumeDecoder: + def subsample(self, pc, num_query, input_pc_size: int): + + """ + num_query: number of points to keep after FPS + input_pc_size: number of points to select before FPS + """ + + B, _, D = pc.shape + query_ratio = num_query / input_pc_size + + # random subsampling of points inside the point cloud + idx_pc = torch.randperm(pc.shape[1], device = pc.device)[:input_pc_size] + input_pc = pc[:, idx_pc, :] + + # flatten to allow applying fps across the whole batch + flattent_input_pc = input_pc.view(B * input_pc_size, D) + + # construct a batch_down tensor to tell fps + # which points belong to which batch + N_down = int(flattent_input_pc.shape[0] / B) + batch_down = torch.arange(B).to(pc.device) + batch_down = torch.repeat_interleave(batch_down, N_down) + + idx_query = fps(flattent_input_pc, batch_down, sampling_ratio = query_ratio) + query_pc = flattent_input_pc[idx_query].view(B, -1, D) + + return query_pc, input_pc, idx_pc, idx_query + + def handle_features(self, features, idx_pc, input_pc_size, batch_size: int, idx_query): + + B = batch_size + + input_surface_features = features[:, idx_pc, :] + flattent_input_features = input_surface_features.view(B * input_pc_size, -1) + query_features = flattent_input_features[idx_query].view(B, -1, + flattent_input_features.shape[-1]) + + return input_surface_features, query_features + +def normalize_mesh(mesh, scale = 0.9999): + """Normalize mesh to fit in [-scale, scale]. Translate mesh so its center is [0,0,0]""" + + bbox = mesh.bounds + center = (bbox[1] + bbox[0]) / 2 + + max_extent = (bbox[1] - bbox[0]).max() + mesh.apply_translation(-center) + mesh.apply_scale((2 * scale) / max_extent) + + return mesh + +def sample_pointcloud(mesh, num = 200000): + """ Uniformly sample points from the surface of the mesh """ + + points, face_idx = mesh.sample(num, return_index = True) + normals = mesh.face_normals[face_idx] + return torch.from_numpy(points.astype(np.float32)), torch.from_numpy(normals.astype(np.float32)) + +def detect_sharp_edges(mesh, threshold=0.985): + """Return edge indices (a, b) that lie on sharp boundaries of the mesh.""" + + V, F = mesh.vertices, mesh.faces + VN, FN = mesh.vertex_normals, mesh.face_normals + + sharp_mask = np.ones(V.shape[0]) + for i in range(3): + indices = F[:, i] + alignment = np.einsum('ij,ij->i', VN[indices], FN) + dot_stack = np.stack((sharp_mask[indices], alignment), axis=-1) + sharp_mask[indices] = np.min(dot_stack, axis=-1) + + edge_a = np.concatenate([F[:, 0], F[:, 1], F[:, 2]]) + edge_b = np.concatenate([F[:, 1], F[:, 2], F[:, 0]]) + sharp_edges = (sharp_mask[edge_a] < threshold) & (sharp_mask[edge_b] < threshold) + + return edge_a[sharp_edges], edge_b[sharp_edges] + + +def sharp_sample_pointcloud(mesh, num = 16384): + """ Sample points preferentially from sharp edges in the mesh. 
""" + + edge_a, edge_b = detect_sharp_edges(mesh) + V, VN = mesh.vertices, mesh.vertex_normals + + va, vb = V[edge_a], V[edge_b] + na, nb = VN[edge_a], VN[edge_b] + + edge_lengths = np.linalg.norm(vb - va, axis=-1) + weights = edge_lengths / edge_lengths.sum() + + indices = np.searchsorted(np.cumsum(weights), np.random.rand(num)) + t = np.random.rand(num, 1) + + samples = t * va[indices] + (1 - t) * vb[indices] + normals = t * na[indices] + (1 - t) * nb[indices] + + return samples.astype(np.float32), normals.astype(np.float32) + +def load_surface_sharpedge(mesh, num_points=4096, num_sharp_points=4096, sharpedge_flag = True, device = "cuda"): + """Load a surface with optional sharp-edge annotations from a trimesh mesh.""" + + import trimesh + + try: + mesh_full = trimesh.util.concatenate(mesh.dump()) + except Exception: + mesh_full = trimesh.util.concatenate(mesh) + + mesh_full = normalize_mesh(mesh_full) + + faces = mesh_full.faces + vertices = mesh_full.vertices + origin_face_count = faces.shape[0] + + mesh_surface = trimesh.Trimesh(vertices=vertices, faces=faces[:origin_face_count]) + mesh_fill = trimesh.Trimesh(vertices=vertices, faces=faces[origin_face_count:]) + + area_surface = mesh_surface.area + area_fill = mesh_fill.area + total_area = area_surface + area_fill + + sample_num = 499712 // 2 + fill_ratio = area_fill / total_area if total_area > 0 else 0 + + num_fill = int(sample_num * fill_ratio) + num_surface = sample_num - num_fill + + surf_pts, surf_normals = sample_pointcloud(mesh_surface, num_surface) + fill_pts, fill_normals = (torch.zeros(0, 3), torch.zeros(0, 3)) if num_fill == 0 else sample_pointcloud(mesh_fill, num_fill) + + sharp_pts, sharp_normals = sharp_sample_pointcloud(mesh_surface, sample_num) + + def assemble_tensor(points, normals, label=None): + + data = torch.cat([points, normals], dim=1).half().to(device) + + if label is not None: + label_tensor = torch.full((data.shape[0], 1), float(label), dtype=torch.float16).to(device) + data = torch.cat([data, label_tensor], dim=1) + + return data + + surface = assemble_tensor(torch.cat([surf_pts.to(device), fill_pts.to(device)], dim=0), + torch.cat([surf_normals.to(device), fill_normals.to(device)], dim=0), + label = 0 if sharpedge_flag else None) + + sharp_surface = assemble_tensor(torch.from_numpy(sharp_pts), torch.from_numpy(sharp_normals), + label = 1 if sharpedge_flag else None) + + rng = np.random.default_rng() + + surface = surface[rng.choice(surface.shape[0], num_points, replace = False)] + sharp_surface = sharp_surface[rng.choice(sharp_surface.shape[0], num_sharp_points, replace = False)] + + full = torch.cat([surface, sharp_surface], dim = 0).unsqueeze(0) + + return full + +class SharpEdgeSurfaceLoader: + """ Load mesh surface and sharp edge samples. 
""" + + def __init__(self, num_uniform_points = 8192, num_sharp_points = 8192): + + self.num_uniform_points = num_uniform_points + self.num_sharp_points = num_sharp_points + self.total_points = num_uniform_points + num_sharp_points + + def __call__(self, mesh_input, device = "cuda"): + mesh = self._load_mesh(mesh_input) + return load_surface_sharpedge(mesh, self.num_uniform_points, self.num_sharp_points, device = device) + + @staticmethod + def _load_mesh(mesh_input): + import trimesh + + if isinstance(mesh_input, str): + mesh = trimesh.load(mesh_input, force="mesh", merge_primitives = True) + else: + mesh = mesh_input + + if isinstance(mesh, trimesh.Scene): + combined = None + for obj in mesh.geometry.values(): + combined = obj if combined is None else combined + obj + return combined + + return mesh + +class DiagonalGaussianDistribution: + def __init__(self, params: torch.Tensor, feature_dim: int = -1): + + # divide quant channels (8) into mean and log variance + self.mean, self.logvar = torch.chunk(params, 2, dim = feature_dim) + + self.logvar = torch.clamp(self.logvar, -30.0, 20.0) + self.std = torch.exp(0.5 * self.logvar) + + def sample(self): + + eps = torch.randn_like(self.std) + z = self.mean + eps * self.std + + return z + +################################################ +# Volume Decoder +################################################ + +class VanillaVolumeDecoder(): @torch.no_grad() - def __call__( - self, - latents: torch.FloatTensor, - geo_decoder: Callable, - bounds: Union[Tuple[float], List[float], float] = 1.01, - num_chunks: int = 10000, - octree_resolution: int = None, - enable_pbar: bool = True, - **kwargs, - ): - device = latents.device - dtype = latents.dtype - batch_size = latents.shape[0] + def __call__(self, latents: torch.Tensor, geo_decoder: callable, octree_resolution: int, bounds = 1.01, + num_chunks: int = 10_000, enable_pbar: bool = True, **kwargs): - # 1. generate query points if isinstance(bounds, float): bounds = [-bounds, -bounds, -bounds, bounds, bounds, bounds] - bbox_min, bbox_max = np.array(bounds[0:3]), np.array(bounds[3:6]) - xyz_samples, grid_size, length = generate_dense_grid_points( - bbox_min=bbox_min, - bbox_max=bbox_max, - octree_resolution=octree_resolution, - indexing="ij" - ) - xyz_samples = torch.from_numpy(xyz_samples).to(device, dtype=dtype).contiguous().reshape(-1, 3) + bbox_min, bbox_max = torch.tensor(bounds[:3]), torch.tensor(bounds[3:]) + + x = torch.linspace(bbox_min[0], bbox_max[0], int(octree_resolution) + 1, dtype = torch.float32) + y = torch.linspace(bbox_min[1], bbox_max[1], int(octree_resolution) + 1, dtype = torch.float32) + z = torch.linspace(bbox_min[2], bbox_max[2], int(octree_resolution) + 1, dtype = torch.float32) + + [xs, ys, zs] = torch.meshgrid(x, y, z, indexing = "ij") + xyz = torch.stack((xs, ys, zs), axis=-1).to(latents.device, dtype = latents.dtype).contiguous().reshape(-1, 3) + grid_size = [int(octree_resolution) + 1, int(octree_resolution) + 1, int(octree_resolution) + 1] - # 2. 
latents to 3d volume batch_logits = [] - for start in tqdm(range(0, xyz_samples.shape[0], num_chunks), desc="Volume Decoding", + for start in tqdm(range(0, xyz.shape[0], num_chunks), desc="Volume Decoding", disable=not enable_pbar): - chunk_queries = xyz_samples[start: start + num_chunks, :] - chunk_queries = repeat(chunk_queries, "p c -> b p c", b=batch_size) - logits = geo_decoder(queries=chunk_queries, latents=latents) + + chunk_queries = xyz[start: start + num_chunks, :] + chunk_queries = chunk_queries.unsqueeze(0).repeat(latents.shape[0], 1, 1) + logits = geo_decoder(queries = chunk_queries, latents = latents) batch_logits.append(logits) - grid_logits = torch.cat(batch_logits, dim=1) - grid_logits = grid_logits.view((batch_size, *grid_size)).float() + grid_logits = torch.cat(batch_logits, dim = 1) + grid_logits = grid_logits.view((latents.shape[0], *grid_size)).float() return grid_logits - class FourierEmbedder(nn.Module): """The sin/cosine positional embedding. Given an input tensor `x` of shape [n_batch, ..., c_dim], it converts each feature dimension of `x[..., i]` into: @@ -175,13 +552,11 @@ class FourierEmbedder(nn.Module): else: return x - class CrossAttentionProcessor: def __call__(self, attn, q, k, v): out = comfy.ops.scaled_dot_product_attention(q, k, v) return out - class DropPath(nn.Module): """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). """ @@ -232,38 +607,41 @@ class MLP(nn.Module): def forward(self, x): return self.drop_path(self.c_proj(self.gelu(self.c_fc(x)))) - class QKVMultiheadCrossAttention(nn.Module): def __init__( self, - *, heads: int, + n_data = None, width=None, qk_norm=False, norm_layer=ops.LayerNorm ): super().__init__() self.heads = heads + self.n_data = n_data self.q_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity() self.k_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity() - self.attn_processor = CrossAttentionProcessor() - def forward(self, q, kv): + _, n_ctx, _ = q.shape bs, n_data, width = kv.shape + attn_ch = width // self.heads // 2 q = q.view(bs, n_ctx, self.heads, -1) + kv = kv.view(bs, n_data, self.heads, -1) k, v = torch.split(kv, attn_ch, dim=-1) q = self.q_norm(q) k = self.k_norm(k) - q, k, v = map(lambda t: rearrange(t, 'b n h d -> b h n d', h=self.heads), (q, k, v)) - out = self.attn_processor(self, q, k, v) - out = out.transpose(1, 2).reshape(bs, n_ctx, -1) - return out + q, k, v = [t.permute(0, 2, 1, 3) for t in (q, k, v)] + out = F.scaled_dot_product_attention(q, k, v) + + out = out.transpose(1, 2).reshape(bs, n_ctx, -1) + + return out class MultiheadCrossAttention(nn.Module): def __init__( @@ -306,7 +684,6 @@ class MultiheadCrossAttention(nn.Module): x = self.c_proj(x) return x - class ResidualCrossAttentionBlock(nn.Module): def __init__( self, @@ -366,7 +743,7 @@ class QKVMultiheadAttention(nn.Module): q = self.q_norm(q) k = self.k_norm(k) - q, k, v = map(lambda t: rearrange(t, 'b n h d -> b h n d', h=self.heads), (q, k, v)) + q, k, v = [t.permute(0, 2, 1, 3) for t in (q, k, v)] out = F.scaled_dot_product_attention(q, k, v).transpose(1, 2).reshape(bs, n_ctx, -1) return out @@ -383,8 +760,7 @@ class MultiheadAttention(nn.Module): drop_path_rate: float = 0.0 ): super().__init__() - self.width = width - self.heads = heads + self.c_qkv = ops.Linear(width, width * 3, bias=qkv_bias) self.c_proj = ops.Linear(width, width) self.attention = QKVMultiheadAttention( @@ -491,7 +867,7 @@ class 
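Several attention blocks above swap the einops `rearrange` call for a plain `permute`; the two are equivalent for this pattern, which a quick shape check confirms:

```python
import torch

t = torch.randn(2, 10, 4, 8)        # (batch, tokens, heads, head_dim)
out = t.permute(0, 2, 1, 3)         # what the patched code does
# einops.rearrange(t, 'b n h d -> b h n d') produces the same tensor
assert out.shape == (2, 4, 10, 8)
```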
CrossAttentionDecoder(nn.Module): self.query_proj = ops.Linear(self.fourier_embedder.out_dim, width) if self.downsample_ratio != 1: self.latents_proj = ops.Linear(width * downsample_ratio, width) - if self.enable_ln_post == False: + if not self.enable_ln_post: qk_norm = False self.cross_attn_decoder = ResidualCrossAttentionBlock( width=width, @@ -522,28 +898,44 @@ class CrossAttentionDecoder(nn.Module): class ShapeVAE(nn.Module): def __init__( - self, - *, - embed_dim: int, - width: int, - heads: int, - num_decoder_layers: int, - geo_decoder_downsample_ratio: int = 1, - geo_decoder_mlp_expand_ratio: int = 4, - geo_decoder_ln_post: bool = True, - num_freqs: int = 8, - include_pi: bool = True, - qkv_bias: bool = True, - qk_norm: bool = False, - label_type: str = "binary", - drop_path_rate: float = 0.0, - scale_factor: float = 1.0, + self, + *, + num_latents: int = 4096, + embed_dim: int = 64, + width: int = 1024, + heads: int = 16, + num_decoder_layers: int = 16, + num_encoder_layers: int = 8, + pc_size: int = 81920, + pc_sharpedge_size: int = 0, + point_feats: int = 4, + downsample_ratio: int = 20, + geo_decoder_downsample_ratio: int = 1, + geo_decoder_mlp_expand_ratio: int = 4, + geo_decoder_ln_post: bool = True, + num_freqs: int = 8, + qkv_bias: bool = False, + qk_norm: bool = True, + drop_path_rate: float = 0.0, + include_pi: bool = False, + scale_factor: float = 1.0039506158752403, + label_type: str = "binary", ): super().__init__() self.geo_decoder_ln_post = geo_decoder_ln_post self.fourier_embedder = FourierEmbedder(num_freqs=num_freqs, include_pi=include_pi) + self.encoder = PointCrossAttention(layers = num_encoder_layers, + num_latents = num_latents, + downsample_ratio = downsample_ratio, + heads = heads, + pc_size = pc_size, + width = width, + point_feats = point_feats, + fourier_embedder = self.fourier_embedder, + pc_sharpedge_size = pc_sharpedge_size) + self.post_kl = ops.Linear(embed_dim, width) self.transformer = Transformer( @@ -583,5 +975,14 @@ class ShapeVAE(nn.Module): grid_logits = self.volume_decoder(latents, self.geo_decoder, bounds=bounds, num_chunks=num_chunks, octree_resolution=octree_resolution, enable_pbar=enable_pbar) return grid_logits.movedim(-2, -1) - def encode(self, x): - return None + def encode(self, surface): + + pc, feats = surface[:, :, :3], surface[:, :, 3:] + latents = self.encoder(pc, feats) + + moments = self.pre_kl(latents) + posterior = DiagonalGaussianDistribution(moments, feature_dim = -1) + + latents = posterior.sample() + + return latents diff --git a/comfy/ldm/hunyuan3dv2_1/hunyuandit.py b/comfy/ldm/hunyuan3dv2_1/hunyuandit.py new file mode 100644 index 000000000..48575bb3c --- /dev/null +++ b/comfy/ldm/hunyuan3dv2_1/hunyuandit.py @@ -0,0 +1,658 @@ +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +from comfy.ldm.modules.attention import optimized_attention + +class GELU(nn.Module): + + def __init__(self, dim_in: int, dim_out: int, operations, device, dtype): + super().__init__() + self.proj = operations.Linear(dim_in, dim_out, device = device, dtype = dtype) + + def gelu(self, gate: torch.Tensor) -> torch.Tensor: + + if gate.device.type == "mps": + return F.gelu(gate.to(dtype = torch.float32)).to(dtype = gate.dtype) + + return F.gelu(gate) + + def forward(self, hidden_states): + + hidden_states = self.proj(hidden_states) + hidden_states = self.gelu(hidden_states) + + return hidden_states + +class FeedForward(nn.Module): + + def __init__(self, dim: int, dim_out = None, mult: int = 4, + dropout: float = 0.0, 
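The new `encode` path ends in `DiagonalGaussianDistribution`, defined earlier in this file. A minimal sketch of the reparameterisation it performs, assuming for the example that `pre_kl` emits twice the latent width (the exact widths come from the checkpoint, not from this sketch):

```python
import torch

embed_dim = 64                                   # assumed latent width for the example
moments = torch.randn(1, 4096, 2 * embed_dim)    # stand-in for self.pre_kl(latents)
mean, logvar = torch.chunk(moments, 2, dim=-1)
std = torch.exp(0.5 * torch.clamp(logvar, -30.0, 20.0))
z = mean + torch.randn_like(std) * std           # [1, 4096, 64], same as posterior.sample()
```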
inner_dim = None, operations = None, device = None, dtype = None): + + super().__init__() + if inner_dim is None: + inner_dim = int(dim * mult) + + dim_out = dim_out if dim_out is not None else dim + + act_fn = GELU(dim, inner_dim, operations = operations, device = device, dtype = dtype) + + self.net = nn.ModuleList([]) + self.net.append(act_fn) + + self.net.append(nn.Dropout(dropout)) + self.net.append(operations.Linear(inner_dim, dim_out, device = device, dtype = dtype)) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + for module in self.net: + hidden_states = module(hidden_states) + return hidden_states + +class AddAuxLoss(torch.autograd.Function): + + @staticmethod + def forward(ctx, x, loss): + # do nothing in forward (no computation) + ctx.requires_aux_loss = loss.requires_grad + ctx.dtype = loss.dtype + + return x + + @staticmethod + def backward(ctx, grad_output): + # add the aux loss gradients + grad_loss = None + # put the aux grad the same as the main grad loss + # aux grad contributes equally + if ctx.requires_aux_loss: + grad_loss = torch.ones(1, dtype = ctx.dtype, device = grad_output.device) + + return grad_output, grad_loss + +class MoEGate(nn.Module): + + def __init__(self, embed_dim, num_experts=16, num_experts_per_tok=2, aux_loss_alpha=0.01, device = None, dtype = None): + + super().__init__() + self.top_k = num_experts_per_tok + self.n_routed_experts = num_experts + + self.alpha = aux_loss_alpha + + self.gating_dim = embed_dim + self.weight = nn.Parameter(torch.empty((self.n_routed_experts, self.gating_dim), device = device, dtype = dtype)) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + + # flatten hidden states + hidden_states = hidden_states.view(-1, hidden_states.size(-1)) + + # get logits and pass it to softmax + logits = F.linear(hidden_states, self.weight, bias = None) + scores = logits.softmax(dim = -1) + + topk_weight, topk_idx = torch.topk(scores, k = self.top_k, dim = -1, sorted = False) + + if self.training and self.alpha > 0.0: + scores_for_aux = scores + + # used bincount instead of one hot encoding + counts = torch.bincount(topk_idx.view(-1), minlength = self.n_routed_experts).float() + ce = counts / topk_idx.numel() # normalized expert usage + + # mean expert score + Pi = scores_for_aux.mean(0) + + # expert balance loss + aux_loss = (Pi * ce * self.n_routed_experts).sum() * self.alpha + else: + aux_loss = None + + return topk_idx, topk_weight, aux_loss + +class MoEBlock(nn.Module): + def __init__(self, dim, num_experts: int = 6, moe_top_k: int = 2, dropout: float = 0.0, + ff_inner_dim: int = None, operations = None, device = None, dtype = None): + super().__init__() + + self.moe_top_k = moe_top_k + self.num_experts = num_experts + + self.experts = nn.ModuleList([ + FeedForward(dim, dropout = dropout, inner_dim = ff_inner_dim, operations = operations, device = device, dtype = dtype) + for _ in range(num_experts) + ]) + + self.gate = MoEGate(dim, num_experts = num_experts, num_experts_per_tok = moe_top_k, device = device, dtype = dtype) + self.shared_experts = FeedForward(dim, dropout = dropout, inner_dim = ff_inner_dim, operations = operations, device = device, dtype = dtype) + + def forward(self, hidden_states) -> torch.Tensor: + + identity = hidden_states + orig_shape = hidden_states.shape + topk_idx, topk_weight, aux_loss = self.gate(hidden_states) + + hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) + flat_topk_idx = topk_idx.view(-1) + + if self.training: + + hidden_states = 
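`MoEGate` above routes each token to its top-k experts and, during training, adds a load-balancing loss built from a `bincount` over the chosen expert ids. A self-contained sketch of that routing math:

```python
import torch

tokens, n_experts, top_k, alpha = 6, 4, 2, 0.01
scores = torch.randn(tokens, n_experts).softmax(dim=-1)
topk_weight, topk_idx = torch.topk(scores, k=top_k, dim=-1, sorted=False)

counts = torch.bincount(topk_idx.view(-1), minlength=n_experts).float()
ce = counts / topk_idx.numel()                    # fraction of routed slots per expert
Pi = scores.mean(0)                               # mean gate probability per expert
aux_loss = (Pi * ce * n_experts).sum() * alpha    # ~= alpha when expert usage is perfectly balanced
```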
hidden_states.repeat_interleave(self.moe_top_k, dim = 0) + y = torch.empty_like(hidden_states, dtype = hidden_states.dtype) + + for i, expert in enumerate(self.experts): + tmp = expert(hidden_states[flat_topk_idx == i]) + y[flat_topk_idx == i] = tmp.to(hidden_states.dtype) + + y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim = 1) + y = y.view(*orig_shape) + + y = AddAuxLoss.apply(y, aux_loss) + else: + y = self.moe_infer(hidden_states, flat_expert_indices = flat_topk_idx,flat_expert_weights = topk_weight.view(-1, 1)).view(*orig_shape) + + y = y + self.shared_experts(identity) + + return y + + @torch.no_grad() + def moe_infer(self, x, flat_expert_indices, flat_expert_weights): + + expert_cache = torch.zeros_like(x) + idxs = flat_expert_indices.argsort() + + # no need for .numpy().cpu() here + tokens_per_expert = flat_expert_indices.bincount().cumsum(0) + token_idxs = idxs // self.moe_top_k + + for i, end_idx in enumerate(tokens_per_expert): + + start_idx = 0 if i == 0 else tokens_per_expert[i-1] + + if start_idx == end_idx: + continue + + expert = self.experts[i] + exp_token_idx = token_idxs[start_idx:end_idx] + + expert_tokens = x[exp_token_idx] + expert_out = expert(expert_tokens) + + expert_out.mul_(flat_expert_weights[idxs[start_idx:end_idx]]) + + # use index_add_ with a 1-D index tensor directly avoids building a large [N, D] index map and extra memcopy required by scatter_reduce_ + # + avoid dtype conversion + expert_cache.index_add_(0, exp_token_idx, expert_out) + + return expert_cache + +class Timesteps(nn.Module): + def __init__(self, num_channels: int, downscale_freq_shift: float = 0.0, + scale: float = 1.0, max_period: int = 10000): + super().__init__() + + self.num_channels = num_channels + half_dim = num_channels // 2 + + # precompute the “inv_freq” vector once + exponent = -math.log(max_period) * torch.arange( + half_dim, dtype=torch.float32 + ) / (half_dim - downscale_freq_shift) + + inv_freq = torch.exp(exponent) + + # pad + if num_channels % 2 == 1: + # we’ll pad a zero at the end of the cos-half + inv_freq = torch.cat([inv_freq, inv_freq.new_zeros(1)]) + + # register to buffer so it moves with the device + self.register_buffer("inv_freq", inv_freq, persistent = False) + self.scale = scale + + def forward(self, timesteps: torch.Tensor): + + x = timesteps.float().unsqueeze(1) * self.inv_freq.to(timesteps.device).unsqueeze(0) + + + # fused CUDA kernels for sin and cos + sin_emb = x.sin() + cos_emb = x.cos() + + emb = torch.cat([sin_emb, cos_emb], dim = 1) + + # scale factor + if self.scale != 1.0: + emb = emb * self.scale + + # If we padded inv_freq for odd, emb is already wide enough; otherwise: + if emb.shape[1] > self.num_channels: + emb = emb[:, :self.num_channels] + + return emb + +class TimestepEmbedder(nn.Module): + def __init__(self, hidden_size, frequency_embedding_size = 256, cond_proj_dim = None, operations = None, device = None, dtype = None): + super().__init__() + + self.mlp = nn.Sequential( + operations.Linear(hidden_size, frequency_embedding_size, bias=True, device = device, dtype = dtype), + nn.GELU(), + operations.Linear(frequency_embedding_size, hidden_size, bias=True, device = device, dtype = dtype), + ) + self.frequency_embedding_size = frequency_embedding_size + + if cond_proj_dim is not None: + self.cond_proj = operations.Linear(cond_proj_dim, frequency_embedding_size, bias=False, device = device, dtype = dtype) + + self.time_embed = Timesteps(hidden_size) + + def forward(self, timesteps, condition): + + timestep_embed = 
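`Timesteps` above precomputes the inverse frequencies once and registers them as a buffer; the embedding itself is the usual sin/cos pair. A compact sketch of the even-channel case with `downscale_freq_shift = 0`:

```python
import math
import torch

num_channels, max_period = 8, 10000
half = num_channels // 2
inv_freq = torch.exp(-math.log(max_period) * torch.arange(half, dtype=torch.float32) / half)

t = torch.tensor([0.0, 10.0, 500.0])
x = t.unsqueeze(1) * inv_freq.unsqueeze(0)
emb = torch.cat([x.sin(), x.cos()], dim=1)   # shape [3, num_channels]
```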
self.time_embed(timesteps).type(self.mlp[0].weight.dtype) + + if condition is not None: + cond_embed = self.cond_proj(condition) + timestep_embed = timestep_embed + cond_embed + + time_conditioned = self.mlp(timestep_embed.to(self.mlp[0].weight.device)) + + # for broadcasting with image tokens + return time_conditioned.unsqueeze(1) + +class MLP(nn.Module): + def __init__(self, *, width: int, operations = None, device = None, dtype = None): + super().__init__() + self.width = width + self.fc1 = operations.Linear(width, width * 4, device = device, dtype = dtype) + self.fc2 = operations.Linear(width * 4, width, device = device, dtype = dtype) + self.gelu = nn.GELU() + + def forward(self, x): + return self.fc2(self.gelu(self.fc1(x))) + +class CrossAttention(nn.Module): + def __init__( + self, + qdim, + kdim, + num_heads, + qkv_bias=True, + qk_norm=False, + norm_layer=nn.LayerNorm, + use_fp16: bool = False, + operations = None, + dtype = None, + device = None, + **kwargs, + ): + super().__init__() + self.qdim = qdim + self.kdim = kdim + + self.num_heads = num_heads + self.head_dim = self.qdim // num_heads + + self.scale = self.head_dim ** -0.5 + + self.to_q = operations.Linear(qdim, qdim, bias=qkv_bias, device = device, dtype = dtype) + self.to_k = operations.Linear(kdim, qdim, bias=qkv_bias, device = device, dtype = dtype) + self.to_v = operations.Linear(kdim, qdim, bias=qkv_bias, device = device, dtype = dtype) + + if use_fp16: + eps = 1.0 / 65504 + else: + eps = 1e-6 + + if norm_layer == nn.LayerNorm: + norm_layer = operations.LayerNorm + else: + norm_layer = operations.RMSNorm + + self.q_norm = norm_layer(self.head_dim, elementwise_affine=True, eps = eps, device = device, dtype = dtype) if qk_norm else nn.Identity() + self.k_norm = norm_layer(self.head_dim, elementwise_affine=True, eps = eps, device = device, dtype = dtype) if qk_norm else nn.Identity() + self.out_proj = operations.Linear(qdim, qdim, bias=True, device = device, dtype = dtype) + + def forward(self, x, y): + + b, s1, _ = x.shape + _, s2, _ = y.shape + + y = y.to(next(self.to_k.parameters()).dtype) + + q = self.to_q(x) + k = self.to_k(y) + v = self.to_v(y) + + kv = torch.cat((k, v), dim=-1) + split_size = kv.shape[-1] // self.num_heads // 2 + + kv = kv.view(1, -1, self.num_heads, split_size * 2) + k, v = torch.split(kv, split_size, dim=-1) + + q = q.view(b, s1, self.num_heads, self.head_dim) + k = k.view(b, s2, self.num_heads, self.head_dim) + v = v.reshape(b, s2, self.num_heads * self.head_dim) + + q = self.q_norm(q) + k = self.k_norm(k) + + x = optimized_attention( + q.reshape(b, s1, self.num_heads * self.head_dim), + k.reshape(b, s2, self.num_heads * self.head_dim), + v, + heads=self.num_heads, + ) + + out = self.out_proj(x) + + return out + +class Attention(nn.Module): + + def __init__( + self, + dim, + num_heads, + qkv_bias = True, + qk_norm = False, + norm_layer = nn.LayerNorm, + use_fp16: bool = False, + operations = None, + device = None, + dtype = None + ): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.head_dim = self.dim // num_heads + self.scale = self.head_dim ** -0.5 + + self.to_q = operations.Linear(dim, dim, bias = qkv_bias, device = device, dtype = dtype) + self.to_k = operations.Linear(dim, dim, bias = qkv_bias, device = device, dtype = dtype) + self.to_v = operations.Linear(dim, dim, bias = qkv_bias, device = device, dtype = dtype) + + if use_fp16: + eps = 1.0 / 65504 + else: + eps = 1e-6 + + if norm_layer == nn.LayerNorm: + norm_layer = operations.LayerNorm + else: + norm_layer = 
operations.RMSNorm + + self.q_norm = norm_layer(self.head_dim, elementwise_affine=True, eps = eps, device = device, dtype = dtype) if qk_norm else nn.Identity() + self.k_norm = norm_layer(self.head_dim, elementwise_affine=True, eps = eps, device = device, dtype = dtype) if qk_norm else nn.Identity() + self.out_proj = operations.Linear(dim, dim, device = device, dtype = dtype) + + def forward(self, x): + B, N, _ = x.shape + + query = self.to_q(x) + key = self.to_k(x) + value = self.to_v(x) + + qkv_combined = torch.cat((query, key, value), dim=-1) + split_size = qkv_combined.shape[-1] // self.num_heads // 3 + + qkv = qkv_combined.view(1, -1, self.num_heads, split_size * 3) + query, key, value = torch.split(qkv, split_size, dim=-1) + + query = query.reshape(B, N, self.num_heads, self.head_dim) + key = key.reshape(B, N, self.num_heads, self.head_dim) + value = value.reshape(B, N, self.num_heads * self.head_dim) + + query = self.q_norm(query) + key = self.k_norm(key) + + x = optimized_attention( + query.reshape(B, N, self.num_heads * self.head_dim), + key.reshape(B, N, self.num_heads * self.head_dim), + value, + heads=self.num_heads, + ) + + x = self.out_proj(x) + return x + +class HunYuanDiTBlock(nn.Module): + def __init__( + self, + hidden_size, + c_emb_size, + num_heads, + text_states_dim=1024, + qk_norm=False, + norm_layer=nn.LayerNorm, + qk_norm_layer=nn.RMSNorm, + qkv_bias=True, + skip_connection=True, + timested_modulate=False, + use_moe: bool = False, + num_experts: int = 8, + moe_top_k: int = 2, + use_fp16: bool = False, + operations = None, + device = None, dtype = None + ): + super().__init__() + + # eps can't be 1e-6 in fp16 mode because of numerical stability issues + if use_fp16: + eps = 1.0 / 65504 + else: + eps = 1e-6 + + self.norm1 = norm_layer(hidden_size, elementwise_affine = True, eps = eps, device = device, dtype = dtype) + + self.attn1 = Attention(hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, qk_norm=qk_norm, + norm_layer=qk_norm_layer, use_fp16 = use_fp16, device = device, dtype = dtype, operations = operations) + + self.norm2 = norm_layer(hidden_size, elementwise_affine = True, eps = eps, device = device, dtype = dtype) + + self.timested_modulate = timested_modulate + if self.timested_modulate: + self.default_modulation = nn.Sequential( + nn.SiLU(), + operations.Linear(c_emb_size, hidden_size, bias=True, device = device, dtype = dtype) + ) + + self.attn2 = CrossAttention(hidden_size, text_states_dim, num_heads=num_heads, qkv_bias=qkv_bias, + qk_norm=qk_norm, norm_layer=qk_norm_layer, use_fp16 = use_fp16, + device = device, dtype = dtype, operations = operations) + + self.norm3 = norm_layer(hidden_size, elementwise_affine = True, eps = eps, device = device, dtype = dtype) + + if skip_connection: + self.skip_norm = norm_layer(hidden_size, elementwise_affine = True, eps = eps, device = device, dtype = dtype) + self.skip_linear = operations.Linear(2 * hidden_size, hidden_size, device = device, dtype = dtype) + else: + self.skip_linear = None + + self.use_moe = use_moe + + if self.use_moe: + self.moe = MoEBlock( + hidden_size, + num_experts = num_experts, + moe_top_k = moe_top_k, + dropout = 0.0, + ff_inner_dim = int(hidden_size * 4.0), + device = device, dtype = dtype, + operations = operations + ) + else: + self.mlp = MLP(width=hidden_size, operations=operations, device = device, dtype = dtype) + + def forward(self, hidden_states, conditioning=None, text_states=None, skip_tensor=None): + + if self.skip_linear is not None: + combined = torch.cat([skip_tensor, 
hidden_states], dim=-1) + hidden_states = self.skip_linear(combined) + hidden_states = self.skip_norm(hidden_states) + + # self attention + if self.timested_modulate: + modulation_shift = self.default_modulation(conditioning).unsqueeze(dim=1) + hidden_states = hidden_states + modulation_shift + + self_attn_out = self.attn1(self.norm1(hidden_states)) + hidden_states = hidden_states + self_attn_out + + # cross attention + hidden_states = hidden_states + self.attn2(self.norm2(hidden_states), text_states) + + # MLP Layer + mlp_input = self.norm3(hidden_states) + + if self.use_moe: + hidden_states = hidden_states + self.moe(mlp_input) + else: + hidden_states = hidden_states + self.mlp(mlp_input) + + return hidden_states + +class FinalLayer(nn.Module): + + def __init__(self, final_hidden_size, out_channels, operations, use_fp16: bool = False, device = None, dtype = None): + super().__init__() + + if use_fp16: + eps = 1.0 / 65504 + else: + eps = 1e-6 + + self.norm_final = operations.LayerNorm(final_hidden_size, elementwise_affine = True, eps = eps, device = device, dtype = dtype) + self.linear = operations.Linear(final_hidden_size, out_channels, bias = True, device = device, dtype = dtype) + + def forward(self, x): + x = self.norm_final(x) + x = x[:, 1:] + x = self.linear(x) + return x + +class HunYuanDiTPlain(nn.Module): + + # init with the defaults values from https://huggingface.co/tencent/Hunyuan3D-2.1/blob/main/hunyuan3d-dit-v2-1/config.yaml + def __init__( + self, + in_channels: int = 64, + hidden_size: int = 2048, + context_dim: int = 1024, + depth: int = 21, + num_heads: int = 16, + qk_norm: bool = True, + qkv_bias: bool = False, + num_moe_layers: int = 6, + guidance_cond_proj_dim = 2048, + norm_type = 'layer', + num_experts: int = 8, + moe_top_k: int = 2, + use_fp16: bool = False, + dtype = None, + device = None, + operations = None, + **kwargs + ): + + self.dtype = dtype + + super().__init__() + + self.depth = depth + + self.in_channels = in_channels + self.out_channels = in_channels + + self.num_heads = num_heads + self.hidden_size = hidden_size + + norm = operations.LayerNorm if norm_type == 'layer' else operations.RMSNorm + qk_norm = operations.RMSNorm + + self.context_dim = context_dim + self.guidance_cond_proj_dim = guidance_cond_proj_dim + + self.x_embedder = operations.Linear(in_channels, hidden_size, bias = True, device = device, dtype = dtype) + self.t_embedder = TimestepEmbedder(hidden_size, hidden_size * 4, cond_proj_dim = guidance_cond_proj_dim, device = device, dtype = dtype, operations = operations) + + + # HUnYuanDiT Blocks + self.blocks = nn.ModuleList([ + HunYuanDiTBlock(hidden_size=hidden_size, + c_emb_size=hidden_size, + num_heads=num_heads, + text_states_dim=context_dim, + qk_norm=qk_norm, + norm_layer = norm, + qk_norm_layer = qk_norm, + skip_connection=layer > depth // 2, + qkv_bias=qkv_bias, + use_moe=True if depth - layer <= num_moe_layers else False, + num_experts=num_experts, + moe_top_k=moe_top_k, + use_fp16 = use_fp16, + device = device, dtype = dtype, operations = operations) + for layer in range(depth) + ]) + + self.depth = depth + + self.final_layer = FinalLayer(hidden_size, self.out_channels, use_fp16 = use_fp16, operations = operations, device = device, dtype = dtype) + + def forward(self, x, t, context, transformer_options = {}, **kwargs): + + x = x.movedim(-1, -2) + uncond_emb, cond_emb = context.chunk(2, dim = 0) + + context = torch.cat([cond_emb, uncond_emb], dim = 0) + main_condition = context + + t = 1.0 - t + + time_embedded = self.t_embedder(t, 
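One detail worth keeping in mind when reading `FinalLayer` above: the model prepends the timestep embedding as an extra token (see the `torch.cat` in the forward pass below), and `x = x[:, 1:]` strips it again before the output projection. A shape-level sketch with illustrative sizes:

```python
import torch

B, T, H = 2, 3072, 2048                                    # placeholder batch/token/width sizes
time_token = torch.randn(B, 1, H)                          # stands in for self.t_embedder(...)
latent_tokens = torch.randn(B, T, H)                       # stands in for self.x_embedder(...)
combined = torch.cat([time_token, latent_tokens], dim=1)   # [B, T + 1, H] fed to the blocks
trimmed = combined[:, 1:]                                  # [B, T, H] — what the final linear sees
```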
condition = kwargs.get('guidance_cond')) + + x = x.to(dtype = next(self.x_embedder.parameters()).dtype) + x_embedded = self.x_embedder(x) + + combined = torch.cat([time_embedded, x_embedded], dim=1) + + def block_wrap(args): + return block( + args["x"], + args["t"], + args["cond"], + skip_tensor=args.get("skip"),) + + skip_stack = [] + patches_replace = transformer_options.get("patches_replace", {}) + blocks_replace = patches_replace.get("dit", {}) + for idx, block in enumerate(self.blocks): + if idx <= self.depth // 2: + skip_input = None + else: + skip_input = skip_stack.pop() + + if ("block", idx) in blocks_replace: + + combined = blocks_replace[("block", idx)]( + { + "x": combined, + "t": time_embedded, + "cond": main_condition, + "skip": skip_input, + }, + {"original_block": block_wrap}, + ) + else: + combined = block(combined, time_embedded, main_condition, skip_tensor=skip_input) + + if idx < self.depth // 2: + skip_stack.append(combined) + + output = self.final_layer(combined) + output = output.movedim(-2, -1) * (-1.0) + + cond_emb, uncond_emb = output.chunk(2, dim = 0) + return torch.cat([uncond_emb, cond_emb]) diff --git a/comfy/model_base.py b/comfy/model_base.py index 56a6798be..39a3344bc 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -16,6 +16,8 @@ along with this program. If not, see . """ +import comfy.ldm.hunyuan3dv2_1 +import comfy.ldm.hunyuan3dv2_1.hunyuandit import torch import logging from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep @@ -1282,6 +1284,21 @@ class Hunyuan3Dv2(BaseModel): out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance])) return out +class Hunyuan3Dv2_1(BaseModel): + def __init__(self, model_config, model_type=ModelType.FLOW, device=None): + super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan3dv2_1.hunyuandit.HunYuanDiTPlain) + + def extra_conds(self, **kwargs): + out = super().extra_conds(**kwargs) + cross_attn = kwargs.get("cross_attn", None) + if cross_attn is not None: + out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) + + guidance = kwargs.get("guidance", 5.0) + if guidance is not None: + out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance])) + return out + class HiDream(BaseModel): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hidream.model.HiDreamImageTransformer2DModel) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 9f3ab64df..75552ede9 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -400,6 +400,20 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys return dit_config + if f"{key_prefix}t_embedder.mlp.2.weight" in state_dict_keys: # Hunyuan 3D 2.1 + + dit_config = {} + dit_config["image_model"] = "hunyuan3d2_1" + dit_config["in_channels"] = state_dict[f"{key_prefix}x_embedder.weight"].shape[1] + dit_config["context_dim"] = 1024 + dit_config["hidden_size"] = state_dict[f"{key_prefix}x_embedder.weight"].shape[0] + dit_config["mlp_ratio"] = 4.0 + dit_config["num_heads"] = 16 + dit_config["depth"] = count_blocks(state_dict_keys, f"{key_prefix}blocks.{{}}") + dit_config["qkv_bias"] = False + dit_config["guidance_cond_proj_dim"] = None#f"{key_prefix}t_embedder.cond_proj.weight" in state_dict_keys + return dit_config + if 
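The forward pass above threads a U-ViT style long-skip stack through the blocks: the first half push their outputs, the later half pop and fuse them via `skip_linear`. A toy trace of the indexing, assuming `depth = 5` for brevity:

```python
depth = 5
skip_stack, trace = [], []
for idx in range(depth):
    skip = skip_stack.pop() if idx > depth // 2 else None
    trace.append((idx, skip))
    if idx < depth // 2:
        skip_stack.append(idx)

print(trace)   # [(0, None), (1, None), (2, None), (3, 1), (4, 0)]
```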
'{}caption_projection.0.linear.weight'.format(key_prefix) in state_dict_keys: # HiDream dit_config = {} dit_config["image_model"] = "hidream" diff --git a/comfy/sd.py b/comfy/sd.py index bb5d61fb3..014f797ca 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -446,17 +446,29 @@ class VAE: self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32] self.memory_used_encode = lambda shape, dtype: 6000 * shape[3] * shape[4] * model_management.dtype_size(dtype) self.memory_used_decode = lambda shape, dtype: 7000 * shape[3] * shape[4] * (8 * 8) * model_management.dtype_size(dtype) + # Hunyuan 3d v2 2.0 & 2.1 elif "geo_decoder.cross_attn_decoder.ln_1.bias" in sd: + self.latent_dim = 1 - ln_post = "geo_decoder.ln_post.weight" in sd - inner_size = sd["geo_decoder.output_proj.weight"].shape[1] - downsample_ratio = sd["post_kl.weight"].shape[0] // inner_size - mlp_expand = sd["geo_decoder.cross_attn_decoder.mlp.c_fc.weight"].shape[0] // inner_size - self.memory_used_encode = lambda shape, dtype: (1000 * shape[2]) * model_management.dtype_size(dtype) # TODO - self.memory_used_decode = lambda shape, dtype: (1024 * 1024 * 1024 * 2.0) * model_management.dtype_size(dtype) # TODO - ddconfig = {"embed_dim": 64, "num_freqs": 8, "include_pi": False, "heads": 16, "width": 1024, "num_decoder_layers": 16, "qkv_bias": False, "qk_norm": True, "geo_decoder_mlp_expand_ratio": mlp_expand, "geo_decoder_downsample_ratio": downsample_ratio, "geo_decoder_ln_post": ln_post} - self.first_stage_model = comfy.ldm.hunyuan3d.vae.ShapeVAE(**ddconfig) + + def estimate_memory(shape, dtype, num_layers = 16, kv_cache_multiplier = 2): + batch, num_tokens, hidden_dim = shape + dtype_size = model_management.dtype_size(dtype) + + total_mem = batch * num_tokens * hidden_dim * dtype_size * (1 + kv_cache_multiplier * num_layers) + return total_mem + + # better memory estimations + self.memory_used_encode = lambda shape, dtype, num_layers = 8, kv_cache_multiplier = 0:\ + estimate_memory(shape, dtype, num_layers, kv_cache_multiplier) + + self.memory_used_decode = lambda shape, dtype, num_layers = 16, kv_cache_multiplier = 2: \ + estimate_memory(shape, dtype, num_layers, kv_cache_multiplier) + + self.first_stage_model = comfy.ldm.hunyuan3d.vae.ShapeVAE() self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32] + + elif "vocoder.backbone.channel_layers.0.0.bias" in sd: #Ace Step Audio self.first_stage_model = comfy.ldm.ace.vae.music_dcae_pipeline.MusicDCAE(source_sample_rate=44100) self.memory_used_encode = lambda shape, dtype: (shape[2] * 330) * model_management.dtype_size(dtype) @@ -1046,6 +1058,27 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c model = None model_patcher = None + if isinstance(sd, dict) and all(k in sd for k in ["model", "vae", "conditioner"]): + from collections import OrderedDict + import gc + + merged_sd = OrderedDict() + + for k, v in sd["model"].items(): + merged_sd[f"model.{k}"] = v + + for k, v in sd["vae"].items(): + merged_sd[f"vae.{k}"] = v + + for key, value in sd["conditioner"].items(): + merged_sd[f"conditioner.{key}"] = value + + sd = merged_sd + + del merged_sd + gc.collect() + torch.cuda.empty_cache() + diffusion_model_prefix = model_detection.unet_prefix_from_state_dict(sd) parameters = comfy.utils.calculate_parameters(sd, diffusion_model_prefix) weight_dtype = comfy.utils.weight_dtype(sd, diffusion_model_prefix) diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 76260de00..75dad277d 100644 --- a/comfy/supported_models.py +++ 
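For a feel of what the new `estimate_memory` lambdas above return, here is the same formula evaluated for one illustrative decode call (the shape and dtype are assumptions for the example, not asserted defaults):

```python
batch, num_tokens, hidden_dim = 1, 4096, 64           # placeholder latent shape
dtype_size = 2                                        # fp16 / bf16
num_layers, kv_cache_multiplier = 16, 2               # the decode-path defaults above

total = batch * num_tokens * hidden_dim * dtype_size * (1 + kv_cache_multiplier * num_layers)
print(total / 2**20)                                  # ~= 16.5 MiB
```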
b/comfy/supported_models.py @@ -1128,6 +1128,17 @@ class Hunyuan3Dv2(supported_models_base.BASE): def clip_target(self, state_dict={}): return None +class Hunyuan3Dv2_1(Hunyuan3Dv2): + unet_config = { + "image_model": "hunyuan3d2_1", + } + + latent_format = latent_formats.Hunyuan3Dv2_1 + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.Hunyuan3Dv2_1(self, device = device) + return out + class Hunyuan3Dv2mini(Hunyuan3Dv2): unet_config = { "image_model": "hunyuan3d2", @@ -1285,6 +1296,6 @@ class QwenImage(supported_models_base.BASE): return supported_models_base.ClipTarget(comfy.text_encoders.qwen_image.QwenImageTokenizer, comfy.text_encoders.qwen_image.te(**hunyuan_detect)) -models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma, ACEStep, Omnigen2, QwenImage] +models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ACEStep, Omnigen2, QwenImage] models += [SVD_img2vid] diff --git a/comfy_extras/nodes_hunyuan3d.py b/comfy_extras/nodes_hunyuan3d.py index 51e45336a..f6e71e0a8 100644 --- a/comfy_extras/nodes_hunyuan3d.py +++ b/comfy_extras/nodes_hunyuan3d.py @@ -8,13 +8,16 @@ import folder_paths import comfy.model_management from comfy.cli_args import args - class EmptyLatentHunyuan3Dv2: @classmethod def INPUT_TYPES(s): - return {"required": {"resolution": ("INT", {"default": 3072, "min": 1, "max": 8192}), - "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096, "tooltip": "The number of latent images in the batch."}), - }} + return { + "required": { + "resolution": ("INT", {"default": 3072, "min": 1, "max": 8192}), + "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096, "tooltip": "The number of latent images in the batch."}), + } + } + RETURN_TYPES = ("LATENT",) FUNCTION = "generate" @@ -24,7 +27,6 @@ class EmptyLatentHunyuan3Dv2: latent = torch.zeros([batch_size, 64, resolution], device=comfy.model_management.intermediate_device()) return ({"samples": latent, "type": "hunyuan3dv2"}, ) - class Hunyuan3Dv2Conditioning: @classmethod def INPUT_TYPES(s): @@ -81,7 +83,6 @@ class VOXEL: def __init__(self, data): self.data = data - class VAEDecodeHunyuan3D: @classmethod def INPUT_TYPES(s): @@ -99,7 +100,6 @@ class VAEDecodeHunyuan3D: voxels = VOXEL(vae.decode(samples["samples"], vae_options={"num_chunks": num_chunks, "octree_resolution": octree_resolution})) return (voxels, ) - def 
voxel_to_mesh(voxels, threshold=0.5, device=None): if device is None: device = torch.device("cpu") @@ -230,13 +230,9 @@ def voxel_to_mesh_surfnet(voxels, threshold=0.5, device=None): [0, 0, 1], [1, 0, 1], [0, 1, 1], [1, 1, 1] ], device=device) - corner_values = torch.zeros((cell_positions.shape[0], 8), device=device) - for c, (dz, dy, dx) in enumerate(corner_offsets): - corner_values[:, c] = padded[ - cell_positions[:, 0] + dz, - cell_positions[:, 1] + dy, - cell_positions[:, 2] + dx - ] + pos = cell_positions.unsqueeze(1) + corner_offsets.unsqueeze(0) + z_idx, y_idx, x_idx = pos.unbind(-1) + corner_values = padded[z_idx, y_idx, x_idx] corner_signs = corner_values > threshold has_inside = torch.any(corner_signs, dim=1) diff --git a/nodes.py b/nodes.py index 6c2f9dd14..1afe5601a 100644 --- a/nodes.py +++ b/nodes.py @@ -998,20 +998,31 @@ class CLIPVisionLoader: class CLIPVisionEncode: @classmethod def INPUT_TYPES(s): - return {"required": { "clip_vision": ("CLIP_VISION",), - "image": ("IMAGE",), - "crop": (["center", "none"],) - }} + return { + "required": { + "clip_vision": ("CLIP_VISION",), + "image": ("IMAGE",), + "crop": (["center", "none", "recenter"],), + }, + "optional": { + "border_ratio": ("FLOAT", {"default": 0.15, "min": 0.0, "max": 0.5, "step": 0.01, "visible_if": {"crop": "recenter"},}), + } + } + RETURN_TYPES = ("CLIP_VISION_OUTPUT",) FUNCTION = "encode" CATEGORY = "conditioning" - def encode(self, clip_vision, image, crop): - crop_image = True - if crop != "center": - crop_image = False - output = clip_vision.encode_image(image, crop=crop_image) + def encode(self, clip_vision, image, crop, border_ratio): + crop_image = crop == "center" + + if crop == "recenter": + crop_image = True + else: + border_ratio = None + + output = clip_vision.encode_image(image, crop=crop_image, border_ratio = border_ratio) return (output,) class StyleModelLoader: diff --git a/requirements.txt b/requirements.txt index 3008a5dc3..564fa6e23 100644 --- a/requirements.txt +++ b/requirements.txt @@ -27,4 +27,4 @@ kornia>=0.7.1 spandrel soundfile pydantic~=2.0 -pydantic-settings~=2.0 +pydantic-settings~=2.0 \ No newline at end of file From c9ebe70072213a875ffbe40cc1b36820b2005211 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Thu, 4 Sep 2025 17:39:02 -0700 Subject: [PATCH 163/325] Some changes to the previous hunyuan PR. 
(#9725) --- comfy/clip_vision.py | 225 +------------------------------------------ comfy/sd.py | 21 ---- nodes.py | 29 ++---- requirements.txt | 2 +- 4 files changed, 14 insertions(+), 263 deletions(-) diff --git a/comfy/clip_vision.py b/comfy/clip_vision.py index 4bc640e8b..447b1ce4a 100644 --- a/comfy/clip_vision.py +++ b/comfy/clip_vision.py @@ -17,227 +17,10 @@ class Output: def __setitem__(self, key, item): setattr(self, key, item) - -def cubic_kernel(x, a: float = -0.75): - absx = x.abs() - absx2 = absx ** 2 - absx3 = absx ** 3 - - w = (a + 2) * absx3 - (a + 3) * absx2 + 1 - w2 = a * absx3 - 5*a * absx2 + 8*a * absx - 4*a - - return torch.where(absx <= 1, w, torch.where(absx < 2, w2, torch.zeros_like(x))) - -def get_indices_weights(in_size, out_size, scale): - # OpenCV-style half-pixel mapping - x = torch.arange(out_size, dtype=torch.float32) - x = (x + 0.5) / scale - 0.5 - - x0 = x.floor().long() - dx = x.unsqueeze(1) - (x0.unsqueeze(1) + torch.arange(-1, 3)) - - weights = cubic_kernel(dx) - weights = weights / weights.sum(dim=1, keepdim=True) - - indices = x0.unsqueeze(1) + torch.arange(-1, 3) - indices = indices.clamp(0, in_size - 1) - - return indices, weights - -def resize_cubic_1d(x, out_size, dim): - b, c, h, w = x.shape - in_size = h if dim == 2 else w - scale = out_size / in_size - - indices, weights = get_indices_weights(in_size, out_size, scale) - - if dim == 2: - x = x.permute(0, 1, 3, 2) - x = x.reshape(-1, h) - else: - x = x.reshape(-1, w) - - gathered = x[:, indices] - out = (gathered * weights.unsqueeze(0)).sum(dim=2) - - if dim == 2: - out = out.reshape(b, c, w, out_size).permute(0, 1, 3, 2) - else: - out = out.reshape(b, c, h, out_size) - - return out - -def resize_cubic(img: torch.Tensor, size: tuple) -> torch.Tensor: - """ - Resize image using OpenCV-equivalent INTER_CUBIC interpolation. 
- Implemented in pure PyTorch - """ - - if img.ndim == 3: - img = img.unsqueeze(0) - - img = img.permute(0, 3, 1, 2) - - out_h, out_w = size - img = resize_cubic_1d(img, out_h, dim=2) - img = resize_cubic_1d(img, out_w, dim=3) - return img - -def resize_area(img: torch.Tensor, size: tuple) -> torch.Tensor: - # vectorized implementation for OpenCV's INTER_AREA using pure PyTorch - original_shape = img.shape - is_hwc = False - - if img.ndim == 3: - if img.shape[0] <= 4: - img = img.unsqueeze(0) - else: - is_hwc = True - img = img.permute(2, 0, 1).unsqueeze(0) - elif img.ndim == 4: - pass - else: - raise ValueError("Expected image with 3 or 4 dims.") - - B, C, H, W = img.shape - out_h, out_w = size - scale_y = H / out_h - scale_x = W / out_w - - device = img.device - - # compute the grid boundries - y_start = torch.arange(out_h, device=device).float() * scale_y - y_end = y_start + scale_y - x_start = torch.arange(out_w, device=device).float() * scale_x - x_end = x_start + scale_x - - # for each output pixel, we will compute the range for it - y_start_int = torch.floor(y_start).long() - y_end_int = torch.ceil(y_end).long() - x_start_int = torch.floor(x_start).long() - x_end_int = torch.ceil(x_end).long() - - # We will build the weighted sums by iterating over contributing input pixels once - output = torch.zeros((B, C, out_h, out_w), dtype=torch.float32, device=device) - area = torch.zeros((out_h, out_w), dtype=torch.float32, device=device) - - max_kernel_h = int(torch.max(y_end_int - y_start_int).item()) - max_kernel_w = int(torch.max(x_end_int - x_start_int).item()) - - for dy in range(max_kernel_h): - for dx in range(max_kernel_w): - # compute the weights for this offset for all output pixels - - y_idx = y_start_int.unsqueeze(1) + dy - x_idx = x_start_int.unsqueeze(0) + dx - - # clamp indices to image boundaries - y_idx_clamped = torch.clamp(y_idx, 0, H - 1) - x_idx_clamped = torch.clamp(x_idx, 0, W - 1) - - # compute weights by broadcasting - y_weight = (torch.min(y_end.unsqueeze(1), y_idx_clamped.float() + 1.0) - torch.max(y_start.unsqueeze(1), y_idx_clamped.float())).clamp(min=0) - x_weight = (torch.min(x_end.unsqueeze(0), x_idx_clamped.float() + 1.0) - torch.max(x_start.unsqueeze(0), x_idx_clamped.float())).clamp(min=0) - - weight = (y_weight * x_weight) - - y_expand = y_idx_clamped.expand(out_h, out_w) - x_expand = x_idx_clamped.expand(out_h, out_w) - - - pixels = img[:, :, y_expand, x_expand] - - # unsqueeze to broadcast - w = weight.unsqueeze(0).unsqueeze(0) - - output += pixels * w - area += weight - - # Normalize by area - output /= area.unsqueeze(0).unsqueeze(0) - - if is_hwc: - return output[0].permute(1, 2, 0) - elif img.shape[0] == 1 and original_shape[0] <= 4: - return output[0] - else: - return output - -def recenter(image, border_ratio: float = 0.2): - - if image.shape[-1] == 4: - mask = image[..., 3] - else: - mask = torch.ones_like(image[..., 0:1]) * 255 - image = torch.concatenate([image, mask], axis=-1) - mask = mask[..., 0] - - H, W, C = image.shape - - size = max(H, W) - result = torch.zeros((size, size, C), dtype = torch.uint8) - - # as_tuple to match numpy behaviour - x_coords, y_coords = torch.nonzero(mask, as_tuple=True) - - y_min, y_max = y_coords.min(), y_coords.max() - x_min, x_max = x_coords.min(), x_coords.max() - - h = x_max - x_min - w = y_max - y_min - - if h == 0 or w == 0: - raise ValueError('input image is empty') - - desired_size = int(size * (1 - border_ratio)) - scale = desired_size / max(h, w) - - h2 = int(h * scale) - w2 = int(w * scale) - - x2_min 
= (size - h2) // 2 - x2_max = x2_min + h2 - - y2_min = (size - w2) // 2 - y2_max = y2_min + w2 - - # note: opencv takes columns first (opposite to pytorch and numpy that take the row first) - result[x2_min:x2_max, y2_min:y2_max] = resize_area(image[x_min:x_max, y_min:y_max], (h2, w2)) - - bg = torch.ones((result.shape[0], result.shape[1], 3), dtype = torch.uint8) * 255 - - mask = result[..., 3:].to(torch.float32) / 255 - result = result[..., :3] * mask + bg * (1 - mask) - - mask = mask * 255 - result = result.clip(0, 255).to(torch.uint8) - mask = mask.clip(0, 255).to(torch.uint8) - - return result - -def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711], - crop=True, value_range = (-1, 1), border_ratio: float = None, recenter_size: int = 512): - - if border_ratio is not None: - - image = (image * 255).clamp(0, 255).to(torch.uint8) - image = [recenter(img, border_ratio = border_ratio) for img in image] - - image = torch.stack(image, dim = 0) - image = resize_cubic(image, size = (recenter_size, recenter_size)) - - image = image / 255 * 2 - 1 - low, high = value_range - - image = (image - low) / (high - low) - image = image.permute(0, 2, 3, 1) - +def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711], crop=True): image = image[:, :, :, :3] if image.shape[3] > 3 else image - mean = torch.tensor(mean, device=image.device, dtype=image.dtype) std = torch.tensor(std, device=image.device, dtype=image.dtype) - image = image.movedim(-1, 1) if not (image.shape[2] == size and image.shape[3] == size): if crop: @@ -246,7 +29,7 @@ def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], s else: scale_size = (size, size) - image = torch.nn.functional.interpolate(image, size=scale_size, mode="bilinear" if border_ratio is not None else "bicubic", antialias=True) + image = torch.nn.functional.interpolate(image, size=scale_size, mode="bicubic", antialias=True) h = (image.shape[2] - size)//2 w = (image.shape[3] - size)//2 image = image[:,:,h:h+size,w:w+size] @@ -288,9 +71,9 @@ class ClipVisionModel(): def get_sd(self): return self.model.state_dict() - def encode_image(self, image, crop=True, border_ratio: float = None): + def encode_image(self, image, crop=True): comfy.model_management.load_model_gpu(self.patcher) - pixel_values = clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std, crop=crop, border_ratio=border_ratio).float() + pixel_values = clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std, crop=crop).float() out = self.model(pixel_values=pixel_values, intermediate_output='all' if self.return_all_hidden_states else -2) outputs = Output() diff --git a/comfy/sd.py b/comfy/sd.py index 014f797ca..be5aa8dc8 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -1058,27 +1058,6 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c model = None model_patcher = None - if isinstance(sd, dict) and all(k in sd for k in ["model", "vae", "conditioner"]): - from collections import OrderedDict - import gc - - merged_sd = OrderedDict() - - for k, v in sd["model"].items(): - merged_sd[f"model.{k}"] = v - - for k, v in sd["vae"].items(): - merged_sd[f"vae.{k}"] = v - - for key, value in sd["conditioner"].items(): - merged_sd[f"conditioner.{key}"] = value - - sd = merged_sd - - del merged_sd - gc.collect() - torch.cuda.empty_cache() - 
diffusion_model_prefix = model_detection.unet_prefix_from_state_dict(sd) parameters = comfy.utils.calculate_parameters(sd, diffusion_model_prefix) weight_dtype = comfy.utils.weight_dtype(sd, diffusion_model_prefix) diff --git a/nodes.py b/nodes.py index 1afe5601a..6c2f9dd14 100644 --- a/nodes.py +++ b/nodes.py @@ -998,31 +998,20 @@ class CLIPVisionLoader: class CLIPVisionEncode: @classmethod def INPUT_TYPES(s): - return { - "required": { - "clip_vision": ("CLIP_VISION",), - "image": ("IMAGE",), - "crop": (["center", "none", "recenter"],), - }, - "optional": { - "border_ratio": ("FLOAT", {"default": 0.15, "min": 0.0, "max": 0.5, "step": 0.01, "visible_if": {"crop": "recenter"},}), - } - } - + return {"required": { "clip_vision": ("CLIP_VISION",), + "image": ("IMAGE",), + "crop": (["center", "none"],) + }} RETURN_TYPES = ("CLIP_VISION_OUTPUT",) FUNCTION = "encode" CATEGORY = "conditioning" - def encode(self, clip_vision, image, crop, border_ratio): - crop_image = crop == "center" - - if crop == "recenter": - crop_image = True - else: - border_ratio = None - - output = clip_vision.encode_image(image, crop=crop_image, border_ratio = border_ratio) + def encode(self, clip_vision, image, crop): + crop_image = True + if crop != "center": + crop_image = False + output = clip_vision.encode_image(image, crop=crop_image) return (output,) class StyleModelLoader: diff --git a/requirements.txt b/requirements.txt index 564fa6e23..3008a5dc3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -27,4 +27,4 @@ kornia>=0.7.1 spandrel soundfile pydantic~=2.0 -pydantic-settings~=2.0 \ No newline at end of file +pydantic-settings~=2.0 From 3493b9cb1f9a9a66b1b86ed908cf87bc382b647a Mon Sep 17 00:00:00 2001 From: Arjan Singh <1598641+arjansingh@users.noreply.github.com> Date: Fri, 5 Sep 2025 11:32:25 -0700 Subject: [PATCH 164/325] fix: add cache headers for images (#9560) --- middleware/__init__.py | 1 + middleware/cache_middleware.py | 52 ++++ server.py | 11 +- tests-unit/server_test/test_cache_control.py | 255 +++++++++++++++++++ 4 files changed, 311 insertions(+), 8 deletions(-) create mode 100644 middleware/__init__.py create mode 100644 middleware/cache_middleware.py create mode 100644 tests-unit/server_test/test_cache_control.py diff --git a/middleware/__init__.py b/middleware/__init__.py new file mode 100644 index 000000000..2d7c7c3a9 --- /dev/null +++ b/middleware/__init__.py @@ -0,0 +1 @@ +"""Server middleware modules""" diff --git a/middleware/cache_middleware.py b/middleware/cache_middleware.py new file mode 100644 index 000000000..374ef7934 --- /dev/null +++ b/middleware/cache_middleware.py @@ -0,0 +1,52 @@ +"""Cache control middleware for ComfyUI server""" + +from aiohttp import web +from typing import Callable, Awaitable + +# Time in seconds +ONE_HOUR: int = 3600 +ONE_DAY: int = 86400 +IMG_EXTENSIONS = ( + ".jpg", + ".jpeg", + ".png", + ".ppm", + ".bmp", + ".pgm", + ".tif", + ".tiff", + ".webp", +) + + +@web.middleware +async def cache_control( + request: web.Request, handler: Callable[[web.Request], Awaitable[web.Response]] +) -> web.Response: + """Cache control middleware that sets appropriate cache headers based on file type and response status""" + response: web.Response = await handler(request) + + if ( + request.path.endswith(".js") + or request.path.endswith(".css") + or request.path.endswith("index.json") + ): + response.headers.setdefault("Cache-Control", "no-cache") + return response + + # Early return for non-image files - no cache headers needed + if not 
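Note that the middleware uses `headers.setdefault`, so any Cache-Control value a handler already set wins (the unit tests below rely on this). A quick check of that behaviour with aiohttp's header multidict:

```python
from aiohttp import web

resp = web.Response(headers={"Cache-Control": "private, no-cache"})
resp.headers.setdefault("Cache-Control", "public, max-age=86400")
assert resp.headers["Cache-Control"] == "private, no-cache"   # the existing header is preserved
```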
request.path.lower().endswith(IMG_EXTENSIONS): + return response + + # Handle image files + if response.status == 404: + response.headers.setdefault("Cache-Control", f"public, max-age={ONE_HOUR}") + elif response.status in (200, 201, 202, 203, 204, 205, 206, 301, 308): + # Success responses and permanent redirects - cache for 1 day + response.headers.setdefault("Cache-Control", f"public, max-age={ONE_DAY}") + elif response.status in (302, 303, 307): + # Temporary redirects - no cache + response.headers.setdefault("Cache-Control", "no-cache") + # Note: 304 Not Modified falls through - no cache headers set + + return response diff --git a/server.py b/server.py index 3d323eaf8..43816a8cd 100644 --- a/server.py +++ b/server.py @@ -39,20 +39,15 @@ from typing import Optional, Union from api_server.routes.internal.internal_routes import InternalRoutes from protocol import BinaryEventTypes +# Import cache control middleware +from middleware.cache_middleware import cache_control + async def send_socket_catch_exception(function, message): try: await function(message) except (aiohttp.ClientError, aiohttp.ClientPayloadError, ConnectionResetError, BrokenPipeError, ConnectionError) as err: logging.warning("send error: {}".format(err)) -@web.middleware -async def cache_control(request: web.Request, handler): - response: web.Response = await handler(request) - if request.path.endswith('.js') or request.path.endswith('.css') or request.path.endswith('index.json'): - response.headers.setdefault('Cache-Control', 'no-cache') - return response - - @web.middleware async def compress_body(request: web.Request, handler): accept_encoding = request.headers.get("Accept-Encoding", "") diff --git a/tests-unit/server_test/test_cache_control.py b/tests-unit/server_test/test_cache_control.py new file mode 100644 index 000000000..8de59125a --- /dev/null +++ b/tests-unit/server_test/test_cache_control.py @@ -0,0 +1,255 @@ +"""Tests for server cache control middleware""" + +import pytest +from aiohttp import web +from aiohttp.test_utils import make_mocked_request +from typing import Dict, Any + +from middleware.cache_middleware import cache_control, ONE_HOUR, ONE_DAY, IMG_EXTENSIONS + +pytestmark = pytest.mark.asyncio # Apply asyncio mark to all tests + +# Test configuration data +CACHE_SCENARIOS = [ + # Image file scenarios + { + "name": "image_200_status", + "path": "/test.jpg", + "status": 200, + "expected_cache": f"public, max-age={ONE_DAY}", + "should_have_header": True, + }, + { + "name": "image_404_status", + "path": "/missing.jpg", + "status": 404, + "expected_cache": f"public, max-age={ONE_HOUR}", + "should_have_header": True, + }, + # JavaScript/CSS scenarios + { + "name": "js_no_cache", + "path": "/script.js", + "status": 200, + "expected_cache": "no-cache", + "should_have_header": True, + }, + { + "name": "css_no_cache", + "path": "/styles.css", + "status": 200, + "expected_cache": "no-cache", + "should_have_header": True, + }, + { + "name": "index_json_no_cache", + "path": "/api/index.json", + "status": 200, + "expected_cache": "no-cache", + "should_have_header": True, + }, + # Non-matching files + { + "name": "html_no_header", + "path": "/index.html", + "status": 200, + "expected_cache": None, + "should_have_header": False, + }, + { + "name": "txt_no_header", + "path": "/data.txt", + "status": 200, + "expected_cache": None, + "should_have_header": False, + }, + { + "name": "api_endpoint_no_header", + "path": "/api/endpoint", + "status": 200, + "expected_cache": None, + "should_have_header": False, + }, + { + 
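A hedged sketch of the middleware in isolation, outside the ComfyUI server wiring shown above; the route and payload are made up for the example:

```python
from aiohttp import web

from middleware.cache_middleware import cache_control

async def preview(request):
    # fake image payload; any path ending in an image extension triggers the image rules
    return web.Response(body=b"\x89PNG...", content_type="image/png")

app = web.Application(middlewares=[cache_control])
app.router.add_get("/preview.png", preview)   # 200 responses get "public, max-age=86400"
# web.run_app(app, port=8080)                 # port is an arbitrary choice for the demo
```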
"name": "pdf_no_header", + "path": "/file.pdf", + "status": 200, + "expected_cache": None, + "should_have_header": False, + }, +] + +# Status code scenarios for images +IMAGE_STATUS_SCENARIOS = [ + # Success statuses get long cache + {"status": 200, "expected": f"public, max-age={ONE_DAY}"}, + {"status": 201, "expected": f"public, max-age={ONE_DAY}"}, + {"status": 202, "expected": f"public, max-age={ONE_DAY}"}, + {"status": 204, "expected": f"public, max-age={ONE_DAY}"}, + {"status": 206, "expected": f"public, max-age={ONE_DAY}"}, + # Permanent redirects get long cache + {"status": 301, "expected": f"public, max-age={ONE_DAY}"}, + {"status": 308, "expected": f"public, max-age={ONE_DAY}"}, + # Temporary redirects get no cache + {"status": 302, "expected": "no-cache"}, + {"status": 303, "expected": "no-cache"}, + {"status": 307, "expected": "no-cache"}, + # 404 gets short cache + {"status": 404, "expected": f"public, max-age={ONE_HOUR}"}, +] + +# Case sensitivity test paths +CASE_SENSITIVITY_PATHS = ["/image.JPG", "/photo.PNG", "/pic.JpEg"] + +# Edge case test paths +EDGE_CASE_PATHS = [ + { + "name": "query_strings_ignored", + "path": "/image.jpg?v=123&size=large", + "expected": f"public, max-age={ONE_DAY}", + }, + { + "name": "multiple_dots_in_path", + "path": "/image.min.jpg", + "expected": f"public, max-age={ONE_DAY}", + }, + { + "name": "nested_paths_with_images", + "path": "/static/images/photo.jpg", + "expected": f"public, max-age={ONE_DAY}", + }, +] + + +class TestCacheControl: + """Test cache control middleware functionality""" + + @pytest.fixture + def status_handler_factory(self): + """Create a factory for handlers that return specific status codes""" + + def factory(status: int, headers: Dict[str, str] = None): + async def handler(request): + return web.Response(status=status, headers=headers or {}) + + return handler + + return factory + + @pytest.fixture + def mock_handler(self, status_handler_factory): + """Create a mock handler that returns a response with 200 status""" + return status_handler_factory(200) + + @pytest.fixture + def handler_with_existing_cache(self, status_handler_factory): + """Create a handler that returns response with existing Cache-Control header""" + return status_handler_factory(200, {"Cache-Control": "max-age=3600"}) + + async def assert_cache_header( + self, + response: web.Response, + expected_cache: str = None, + should_have_header: bool = True, + ): + """Helper to assert cache control headers""" + if should_have_header: + assert "Cache-Control" in response.headers + if expected_cache: + assert response.headers["Cache-Control"] == expected_cache + else: + assert "Cache-Control" not in response.headers + + # Parameterized tests + @pytest.mark.parametrize("scenario", CACHE_SCENARIOS, ids=lambda x: x["name"]) + async def test_cache_control_scenarios( + self, scenario: Dict[str, Any], status_handler_factory + ): + """Test various cache control scenarios""" + handler = status_handler_factory(scenario["status"]) + request = make_mocked_request("GET", scenario["path"]) + response = await cache_control(request, handler) + + assert response.status == scenario["status"] + await self.assert_cache_header( + response, scenario["expected_cache"], scenario["should_have_header"] + ) + + @pytest.mark.parametrize("ext", IMG_EXTENSIONS) + async def test_all_image_extensions(self, ext: str, mock_handler): + """Test all defined image extensions are handled correctly""" + request = make_mocked_request("GET", f"/image{ext}") + response = await cache_control(request, 
mock_handler) + + assert response.status == 200 + assert "Cache-Control" in response.headers + assert response.headers["Cache-Control"] == f"public, max-age={ONE_DAY}" + + @pytest.mark.parametrize( + "status_scenario", IMAGE_STATUS_SCENARIOS, ids=lambda x: f"status_{x['status']}" + ) + async def test_image_status_codes( + self, status_scenario: Dict[str, Any], status_handler_factory + ): + """Test different status codes for image requests""" + handler = status_handler_factory(status_scenario["status"]) + request = make_mocked_request("GET", "/image.jpg") + response = await cache_control(request, handler) + + assert response.status == status_scenario["status"] + assert "Cache-Control" in response.headers + assert response.headers["Cache-Control"] == status_scenario["expected"] + + @pytest.mark.parametrize("path", CASE_SENSITIVITY_PATHS) + async def test_case_insensitive_image_extension(self, path: str, mock_handler): + """Test that image extensions are matched case-insensitively""" + request = make_mocked_request("GET", path) + response = await cache_control(request, mock_handler) + + assert "Cache-Control" in response.headers + assert response.headers["Cache-Control"] == f"public, max-age={ONE_DAY}" + + @pytest.mark.parametrize("edge_case", EDGE_CASE_PATHS, ids=lambda x: x["name"]) + async def test_edge_cases(self, edge_case: Dict[str, str], mock_handler): + """Test edge cases like query strings, nested paths, etc.""" + request = make_mocked_request("GET", edge_case["path"]) + response = await cache_control(request, mock_handler) + + assert "Cache-Control" in response.headers + assert response.headers["Cache-Control"] == edge_case["expected"] + + # Header preservation tests (special cases not covered by parameterization) + async def test_js_preserves_existing_headers(self, handler_with_existing_cache): + """Test that .js files preserve existing Cache-Control headers""" + request = make_mocked_request("GET", "/script.js") + response = await cache_control(request, handler_with_existing_cache) + + # setdefault should preserve existing header + assert response.headers["Cache-Control"] == "max-age=3600" + + async def test_css_preserves_existing_headers(self, handler_with_existing_cache): + """Test that .css files preserve existing Cache-Control headers""" + request = make_mocked_request("GET", "/styles.css") + response = await cache_control(request, handler_with_existing_cache) + + # setdefault should preserve existing header + assert response.headers["Cache-Control"] == "max-age=3600" + + async def test_image_preserves_existing_headers(self, status_handler_factory): + """Test that image cache headers preserve existing Cache-Control""" + handler = status_handler_factory(200, {"Cache-Control": "private, no-cache"}) + request = make_mocked_request("GET", "/image.jpg") + response = await cache_control(request, handler) + + # setdefault should preserve existing header + assert response.headers["Cache-Control"] == "private, no-cache" + + async def test_304_not_modified_inherits_cache(self, status_handler_factory): + """Test that 304 Not Modified doesn't set cache headers for images""" + handler = status_handler_factory(304, {"Cache-Control": "max-age=7200"}) + request = make_mocked_request("GET", "/not-modified.jpg") + response = await cache_control(request, handler) + + assert response.status == 304 + # Should preserve existing cache header, not override + assert response.headers["Cache-Control"] == "max-age=7200" From 2ee7879a0bdbf507bfd26f8b36eca2fef147c29d Mon Sep 17 00:00:00 2001 From: 
comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Fri, 5 Sep 2025 11:57:35 -0700 Subject: [PATCH 165/325] Fix lowvram issues with hunyuan3d 2.1 (#9735) --- comfy/ldm/hunyuan3dv2_1/hunyuandit.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/comfy/ldm/hunyuan3dv2_1/hunyuandit.py b/comfy/ldm/hunyuan3dv2_1/hunyuandit.py index 48575bb3c..ca1a83001 100644 --- a/comfy/ldm/hunyuan3dv2_1/hunyuandit.py +++ b/comfy/ldm/hunyuan3dv2_1/hunyuandit.py @@ -3,6 +3,7 @@ import torch import torch.nn as nn import torch.nn.functional as F from comfy.ldm.modules.attention import optimized_attention +import comfy.model_management class GELU(nn.Module): @@ -88,7 +89,7 @@ class MoEGate(nn.Module): hidden_states = hidden_states.view(-1, hidden_states.size(-1)) # get logits and pass it to softmax - logits = F.linear(hidden_states, self.weight, bias = None) + logits = F.linear(hidden_states, comfy.model_management.cast_to(self.weight, dtype=hidden_states.dtype, device=hidden_states.device), bias = None) scores = logits.softmax(dim = -1) topk_weight, topk_idx = torch.topk(scores, k = self.top_k, dim = -1, sorted = False) @@ -255,7 +256,7 @@ class TimestepEmbedder(nn.Module): cond_embed = self.cond_proj(condition) timestep_embed = timestep_embed + cond_embed - time_conditioned = self.mlp(timestep_embed.to(self.mlp[0].weight.device)) + time_conditioned = self.mlp(timestep_embed) # for broadcasting with image tokens return time_conditioned.unsqueeze(1) From ea6cdd2631fbca6ed81b95796150c32c9a029f0d Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Fri, 5 Sep 2025 22:05:05 -0700 Subject: [PATCH 166/325] Print all fast options in --help (#9737) --- comfy/cli_args.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/cli_args.py b/comfy/cli_args.py index 72eeaea9a..cc1f12482 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -145,7 +145,7 @@ class PerformanceFeature(enum.Enum): CublasOps = "cublas_ops" AutoTune = "autotune" -parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: fp16_accumulation fp8_matrix_mult cublas_ops") +parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: {}".format(" ".join(map(lambda c: c.value, PerformanceFeature)))) parser.add_argument("--mmap-torch-files", action="store_true", help="Use mmap when loading ckpt/pt files.") parser.add_argument("--disable-mmap", action="store_true", help="Don't use mmap when loading safetensors.") From 27a0fcccc376fef6f035ed97664db8aa7e2e6117 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sat, 6 Sep 2025 20:25:22 -0700 Subject: [PATCH 167/325] Enable bf16 VAE on RDNA4. 
(#9746) --- comfy/model_management.py | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index d08aee1fe..17516b6ed 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -289,6 +289,21 @@ def is_amd(): return True return False +def amd_min_version(device=None, min_rdna_version=0): + if not is_amd(): + return False + + arch = torch.cuda.get_device_properties(device).gcnArchName + if arch.startswith('gfx') and len(arch) == 7: + try: + cmp_rdna_version = int(arch[4]) + 2 + except: + cmp_rdna_version = 0 + if cmp_rdna_version >= min_rdna_version: + return True + + return False + MIN_WEIGHT_MEMORY_RATIO = 0.4 if is_nvidia(): MIN_WEIGHT_MEMORY_RATIO = 0.0 @@ -905,7 +920,9 @@ def vae_dtype(device=None, allowed_dtypes=[]): # NOTE: bfloat16 seems to work on AMD for the VAE but is extremely slow in some cases compared to fp32 # slowness still a problem on pytorch nightly 2.9.0.dev20250720+rocm6.4 tested on RDNA3 - if d == torch.bfloat16 and (not is_amd()) and should_use_bf16(device): + # also a problem on RDNA4 except fp32 is also slow there. + # This is due to large bf16 convolutions being extremely slow. + if d == torch.bfloat16 and ((not is_amd()) or amd_min_version(device, min_rdna_version=4)) and should_use_bf16(device): return d return torch.float32 From bcbd7884e3af5ee8b6ab848da2a3123f247d6114 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sat, 6 Sep 2025 21:29:38 -0700 Subject: [PATCH 168/325] Don't enable pytorch attention on AMD if triton isn't available. (#9747) --- comfy/model_management.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 17516b6ed..cb6580f73 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -22,6 +22,7 @@ from enum import Enum from comfy.cli_args import args, PerformanceFeature import torch import sys +import importlib import platform import weakref import gc @@ -336,12 +337,13 @@ try: logging.info("AMD arch: {}".format(arch)) logging.info("ROCm version: {}".format(rocm_version)) if args.use_split_cross_attention == False and args.use_quad_cross_attention == False: - if torch_version_numeric >= (2, 7): # works on 2.6 but doesn't actually seem to improve much - if any((a in arch) for a in ["gfx90a", "gfx942", "gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches, TODO: gfx950 - ENABLE_PYTORCH_ATTENTION = True -# if torch_version_numeric >= (2, 8): -# if any((a in arch) for a in ["gfx1201"]): -# ENABLE_PYTORCH_ATTENTION = True + if importlib.util.find_spec('triton') is not None: # AMD efficient attention implementation depends on triton. TODO: better way of detecting if it's compiled in or not. 
+ if torch_version_numeric >= (2, 7): # works on 2.6 but doesn't actually seem to improve much + if any((a in arch) for a in ["gfx90a", "gfx942", "gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches, TODO: gfx950 + ENABLE_PYTORCH_ATTENTION = True +# if torch_version_numeric >= (2, 8): +# if any((a in arch) for a in ["gfx1201"]): +# ENABLE_PYTORCH_ATTENTION = True if torch_version_numeric >= (2, 7) and rocm_version >= (6, 4): if any((a in arch) for a in ["gfx1201", "gfx942", "gfx950"]): # TODO: more arches SUPPORT_FP8_OPS = True From fb763d43332aaf15e96350cf1c25e2a1927423f1 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sun, 7 Sep 2025 18:16:29 -0700 Subject: [PATCH 169/325] Fix amd_min_version crash when cpu device. (#9754) --- comfy/model_management.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/comfy/model_management.py b/comfy/model_management.py index cb6580f73..bbfc3c7a1 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -294,6 +294,9 @@ def amd_min_version(device=None, min_rdna_version=0): if not is_amd(): return False + if is_device_cpu(device): + return False + arch = torch.cuda.get_device_properties(device).gcnArchName if arch.startswith('gfx') and len(arch) == 7: try: From bd1d9bcd5fcdb8379ce5a8020cb2b8f42de1b7c7 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Mon, 8 Sep 2025 12:07:04 -0700 Subject: [PATCH 170/325] Add ZeroDivisionError catch for EasyCache logging statement (#9768) --- comfy_extras/nodes_easycache.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/comfy_extras/nodes_easycache.py b/comfy_extras/nodes_easycache.py index 9d2988f5f..c170e9fd9 100644 --- a/comfy_extras/nodes_easycache.py +++ b/comfy_extras/nodes_easycache.py @@ -162,7 +162,12 @@ def easycache_sample_wrapper(executor, *args, **kwargs): logging.info(f"{easycache.name} [verbose] - output_change_rates {len(output_change_rates)}: {output_change_rates}") logging.info(f"{easycache.name} [verbose] - approx_output_change_rates {len(approx_output_change_rates)}: {approx_output_change_rates}") total_steps = len(args[3])-1 - logging.info(f"{easycache.name} - skipped {easycache.total_steps_skipped}/{total_steps} steps ({total_steps/(total_steps-easycache.total_steps_skipped):.2f}x speedup).") + # catch division by zero for log statement; sucks to crash after all sampling is done + try: + speedup = total_steps/(total_steps-easycache.total_steps_skipped) + except ZeroDivisionError: + speedup = 1.0 + logging.info(f"{easycache.name} - skipped {easycache.total_steps_skipped}/{total_steps} steps ({speedup:.2f}x speedup).") easycache.reset() guider.model_options = orig_model_options From 97652d26b81f83fc9a3675be55ede7762fafb7bd Mon Sep 17 00:00:00 2001 From: contentis Date: Mon, 8 Sep 2025 21:08:18 +0200 Subject: [PATCH 171/325] Add explicit casting in apply_rope for Qwen VL (#9759) --- comfy/text_encoders/llama.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/comfy/text_encoders/llama.py b/comfy/text_encoders/llama.py index 4c976058f..5e11956b5 100644 --- a/comfy/text_encoders/llama.py +++ b/comfy/text_encoders/llama.py @@ -128,11 +128,12 @@ def precompute_freqs_cis(head_dim, position_ids, theta, rope_dims=None, device=N def apply_rope(xq, xk, freqs_cis): + org_dtype = xq.dtype cos = freqs_cis[0] sin = freqs_cis[1] q_embed = (xq * cos) + (rotate_half(xq) * sin) k_embed = (xk * cos) + (rotate_half(xk) * sin) - return q_embed, k_embed + return q_embed.to(org_dtype), 
k_embed.to(org_dtype) class Attention(nn.Module): From 103a12cb668303f197b22f52bb2981bb1539beea Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 8 Sep 2025 14:30:26 -0700 Subject: [PATCH 172/325] Support qwen inpaint controlnet. (#9772) --- comfy/controlnet.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/comfy/controlnet.py b/comfy/controlnet.py index e3dfedf55..f08ff4b36 100644 --- a/comfy/controlnet.py +++ b/comfy/controlnet.py @@ -253,7 +253,10 @@ class ControlNet(ControlBase): to_concat = [] for c in self.extra_concat_orig: c = c.to(self.cond_hint.device) - c = comfy.utils.common_upscale(c, self.cond_hint.shape[3], self.cond_hint.shape[2], self.upscale_algorithm, "center") + c = comfy.utils.common_upscale(c, self.cond_hint.shape[-1], self.cond_hint.shape[-2], self.upscale_algorithm, "center") + if c.ndim < self.cond_hint.ndim: + c = c.unsqueeze(2) + c = comfy.utils.repeat_to_batch_size(c, self.cond_hint.shape[2], dim=2) to_concat.append(comfy.utils.repeat_to_batch_size(c, self.cond_hint.shape[0])) self.cond_hint = torch.cat([self.cond_hint] + to_concat, dim=1) @@ -585,11 +588,18 @@ def load_controlnet_flux_instantx(sd, model_options={}): def load_controlnet_qwen_instantx(sd, model_options={}): model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(sd, model_options=model_options) - control_model = comfy.ldm.qwen_image.controlnet.QwenImageControlNetModel(operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config) + control_latent_channels = sd.get("controlnet_x_embedder.weight").shape[1] + + extra_condition_channels = 0 + concat_mask = False + if control_latent_channels == 68: #inpaint controlnet + extra_condition_channels = control_latent_channels - 64 + concat_mask = True + control_model = comfy.ldm.qwen_image.controlnet.QwenImageControlNetModel(extra_condition_channels=extra_condition_channels, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config) control_model = controlnet_load_state_dict(control_model, sd) latent_format = comfy.latent_formats.Wan21() extra_conds = [] - control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds) + control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, concat_mask=concat_mask, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds) return control def convert_mistoline(sd): From f73b176abd6b3e3b587b668fa6748107deef311c Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Tue, 9 Sep 2025 21:40:29 +0300 Subject: [PATCH 173/325] add ByteDance video API nodes (#9712) --- comfy_api_nodes/nodes_bytedance.py | 697 ++++++++++++++++++++++++++++- 1 file changed, 686 insertions(+), 11 deletions(-) diff --git a/comfy_api_nodes/nodes_bytedance.py b/comfy_api_nodes/nodes_bytedance.py index fb6aba7fa..064df2d10 100644 --- a/comfy_api_nodes/nodes_bytedance.py +++ b/comfy_api_nodes/nodes_bytedance.py @@ -1,6 +1,7 @@ import logging +import math from enum import Enum -from typing import Optional +from typing import Literal, Optional, Type, Union from typing_extensions import override import torch @@ -10,28 +11,53 @@ from comfy_api.latest import ComfyExtension, io as comfy_io from comfy_api_nodes.util.validation_utils import ( 
validate_image_aspect_ratio_range, get_number_of_images, + validate_image_dimensions, ) from comfy_api_nodes.apis.client import ( ApiEndpoint, + EmptyRequest, HttpMethod, SynchronousOperation, + PollingOperation, + T, +) +from comfy_api_nodes.apinode_utils import ( + download_url_to_image_tensor, + download_url_to_video_output, + upload_images_to_comfyapi, + validate_string, + image_tensor_pair_to_batch, ) -from comfy_api_nodes.apinode_utils import download_url_to_image_tensor, upload_images_to_comfyapi, validate_string -BYTEPLUS_ENDPOINT = "/proxy/byteplus/api/v3/images/generations" +BYTEPLUS_IMAGE_ENDPOINT = "/proxy/byteplus/api/v3/images/generations" + +# Long-running tasks endpoints(e.g., video) +BYTEPLUS_TASK_ENDPOINT = "/proxy/byteplus/api/v3/contents/generations/tasks" +BYTEPLUS_TASK_STATUS_ENDPOINT = "/proxy/byteplus/api/v3/contents/generations/tasks" # + /{task_id} class Text2ImageModelName(str, Enum): - seedream3 = "seedream-3-0-t2i-250415" + seedream_3 = "seedream-3-0-t2i-250415" class Image2ImageModelName(str, Enum): - seededit3 = "seededit-3-0-i2i-250628" + seededit_3 = "seededit-3-0-i2i-250628" + + +class Text2VideoModelName(str, Enum): + seedance_1_pro = "seedance-1-0-pro-250528" + seedance_1_lite = "seedance-1-0-lite-t2v-250428" + + +class Image2VideoModelName(str, Enum): + """note(August 31): Pro model only supports FirstFrame: https://docs.byteplus.com/en/docs/ModelArk/1520757""" + seedance_1_pro = "seedance-1-0-pro-250528" + seedance_1_lite = "seedance-1-0-lite-i2v-250428" class Text2ImageTaskCreationRequest(BaseModel): - model: Text2ImageModelName = Text2ImageModelName.seedream3 + model: Text2ImageModelName = Text2ImageModelName.seedream_3 prompt: str = Field(...) response_format: Optional[str] = Field("url") size: Optional[str] = Field(None) @@ -41,7 +67,7 @@ class Text2ImageTaskCreationRequest(BaseModel): class Image2ImageTaskCreationRequest(BaseModel): - model: Image2ImageModelName = Image2ImageModelName.seededit3 + model: Image2ImageModelName = Image2ImageModelName.seededit_3 prompt: str = Field(...) response_format: Optional[str] = Field("url") image: str = Field(..., description="Base64 encoded string or image URL") @@ -58,6 +84,52 @@ class ImageTaskCreationResponse(BaseModel): error: dict = Field({}, description="Contains `code` and `message` fields in case of error.") +class TaskTextContent(BaseModel): + type: str = Field("text") + text: str = Field(...) + + +class TaskImageContentUrl(BaseModel): + url: str = Field(...) + + +class TaskImageContent(BaseModel): + type: str = Field("image_url") + image_url: TaskImageContentUrl = Field(...) + role: Optional[Literal["first_frame", "last_frame", "reference_image"]] = Field(None) + + +class Text2VideoTaskCreationRequest(BaseModel): + model: Text2VideoModelName = Text2VideoModelName.seedance_1_pro + content: list[TaskTextContent] = Field(..., min_length=1) + + +class Image2VideoTaskCreationRequest(BaseModel): + model: Image2VideoModelName = Image2VideoModelName.seedance_1_pro + content: list[Union[TaskTextContent, TaskImageContent]] = Field(..., min_length=2) + + +class TaskCreationResponse(BaseModel): + id: str = Field(...) + + +class TaskStatusError(BaseModel): + code: str = Field(...) + message: str = Field(...) + + +class TaskStatusResult(BaseModel): + video_url: str = Field(...) + + +class TaskStatusResponse(BaseModel): + id: str = Field(...) + model: str = Field(...) + status: Literal["queued", "running", "cancelled", "succeeded", "failed"] = Field(...) 
+ error: Optional[TaskStatusError] = Field(None) + content: Optional[TaskStatusResult] = Field(None) + + RECOMMENDED_PRESETS = [ ("1024x1024 (1:1)", 1024, 1024), ("864x1152 (3:4)", 864, 1152), @@ -71,6 +143,25 @@ RECOMMENDED_PRESETS = [ ("Custom", None, None), ] +# The time in this dictionary are given for 10 seconds duration. +VIDEO_TASKS_EXECUTION_TIME = { + "seedance-1-0-lite-t2v-250428": { + "480p": 40, + "720p": 60, + "1080p": 90, + }, + "seedance-1-0-lite-i2v-250428": { + "480p": 40, + "720p": 60, + "1080p": 90, + }, + "seedance-1-0-pro-250528": { + "480p": 70, + "720p": 85, + "1080p": 115, + }, +} + def get_image_url_from_response(response: ImageTaskCreationResponse) -> str: if response.error: @@ -81,6 +172,42 @@ def get_image_url_from_response(response: ImageTaskCreationResponse) -> str: return response.data[0]["url"] +def get_video_url_from_task_status(response: TaskStatusResponse) -> Union[str, None]: + """Returns the video URL from the task status response if it exists.""" + if hasattr(response, "content") and response.content: + return response.content.video_url + return None + + +async def poll_until_finished( + auth_kwargs: dict[str, str], + task_id: str, + estimated_duration: Optional[int] = None, + node_id: Optional[str] = None, +) -> TaskStatusResponse: + """Polls the ByteDance API endpoint until the task reaches a terminal state, then returns the response.""" + return await PollingOperation( + poll_endpoint=ApiEndpoint( + path=f"{BYTEPLUS_TASK_STATUS_ENDPOINT}/{task_id}", + method=HttpMethod.GET, + request_model=EmptyRequest, + response_model=TaskStatusResponse, + ), + completed_statuses=[ + "succeeded", + ], + failed_statuses=[ + "cancelled", + "failed", + ], + status_extractor=lambda response: response.status, + auth_kwargs=auth_kwargs, + result_url_extractor=get_video_url_from_task_status, + estimated_duration=estimated_duration, + node_id=node_id, + ).execute() + + class ByteDanceImageNode(comfy_io.ComfyNode): @classmethod @@ -94,7 +221,7 @@ class ByteDanceImageNode(comfy_io.ComfyNode): comfy_io.Combo.Input( "model", options=[model.value for model in Text2ImageModelName], - default=Text2ImageModelName.seedream3.value, + default=Text2ImageModelName.seedream_3.value, tooltip="Model name", ), comfy_io.String.Input( @@ -203,7 +330,7 @@ class ByteDanceImageNode(comfy_io.ComfyNode): } response = await SynchronousOperation( endpoint=ApiEndpoint( - path=BYTEPLUS_ENDPOINT, + path=BYTEPLUS_IMAGE_ENDPOINT, method=HttpMethod.POST, request_model=Text2ImageTaskCreationRequest, response_model=ImageTaskCreationResponse, @@ -227,7 +354,7 @@ class ByteDanceImageEditNode(comfy_io.ComfyNode): comfy_io.Combo.Input( "model", options=[model.value for model in Image2ImageModelName], - default=Image2ImageModelName.seededit3.value, + default=Image2ImageModelName.seededit_3.value, tooltip="Model name", ), comfy_io.Image.Input( @@ -313,7 +440,7 @@ class ByteDanceImageEditNode(comfy_io.ComfyNode): ) response = await SynchronousOperation( endpoint=ApiEndpoint( - path=BYTEPLUS_ENDPOINT, + path=BYTEPLUS_IMAGE_ENDPOINT, method=HttpMethod.POST, request_model=Image2ImageTaskCreationRequest, response_model=ImageTaskCreationResponse, @@ -324,12 +451,560 @@ class ByteDanceImageEditNode(comfy_io.ComfyNode): return comfy_io.NodeOutput(await download_url_to_image_tensor(get_image_url_from_response(response))) +class ByteDanceTextToVideoNode(comfy_io.ComfyNode): + + @classmethod + def define_schema(cls): + return comfy_io.Schema( + node_id="ByteDanceTextToVideoNode", + display_name="ByteDance Text to Video", 
+ category="api node/video/ByteDance", + description="Generate video using ByteDance models via api based on prompt", + inputs=[ + comfy_io.Combo.Input( + "model", + options=[model.value for model in Text2VideoModelName], + default=Text2VideoModelName.seedance_1_pro.value, + tooltip="Model name", + ), + comfy_io.String.Input( + "prompt", + multiline=True, + tooltip="The text prompt used to generate the video.", + ), + comfy_io.Combo.Input( + "resolution", + options=["480p", "720p", "1080p"], + tooltip="The resolution of the output video.", + ), + comfy_io.Combo.Input( + "aspect_ratio", + options=["16:9", "4:3", "1:1", "3:4", "9:16", "21:9"], + tooltip="The aspect ratio of the output video.", + ), + comfy_io.Int.Input( + "duration", + default=5, + min=3, + max=12, + step=1, + tooltip="The duration of the output video in seconds.", + display_mode=comfy_io.NumberDisplay.slider, + ), + comfy_io.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + step=1, + display_mode=comfy_io.NumberDisplay.number, + control_after_generate=True, + tooltip="Seed to use for generation.", + optional=True, + ), + comfy_io.Boolean.Input( + "camera_fixed", + default=False, + tooltip="Specifies whether to fix the camera. The platform appends an instruction " + "to fix the camera to your prompt, but does not guarantee the actual effect.", + optional=True, + ), + comfy_io.Boolean.Input( + "watermark", + default=True, + tooltip="Whether to add an \"AI generated\" watermark to the video.", + optional=True, + ), + ], + outputs=[ + comfy_io.Video.Output(), + ], + hidden=[ + comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + comfy_io.Hidden.unique_id, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, + model: str, + prompt: str, + resolution: str, + aspect_ratio: str, + duration: int, + seed: int, + camera_fixed: bool, + watermark: bool, + ) -> comfy_io.NodeOutput: + validate_string(prompt, strip_whitespace=True, min_length=1) + raise_if_text_params(prompt, ["resolution", "ratio", "duration", "seed", "camerafixed", "watermark"]) + + prompt = ( + f"{prompt} " + f"--resolution {resolution} " + f"--ratio {aspect_ratio} " + f"--duration {duration} " + f"--seed {seed} " + f"--camerafixed {str(camera_fixed).lower()} " + f"--watermark {str(watermark).lower()}" + ) + + auth_kwargs = { + "auth_token": cls.hidden.auth_token_comfy_org, + "comfy_api_key": cls.hidden.api_key_comfy_org, + } + return await process_video_task( + request_model=Text2VideoTaskCreationRequest, + payload=Text2VideoTaskCreationRequest( + model=model, + content=[TaskTextContent(text=prompt)], + ), + auth_kwargs=auth_kwargs, + node_id=cls.hidden.unique_id, + estimated_duration=max(1, math.ceil(VIDEO_TASKS_EXECUTION_TIME[model][resolution] * (duration / 10.0))), + ) + + +class ByteDanceImageToVideoNode(comfy_io.ComfyNode): + + @classmethod + def define_schema(cls): + return comfy_io.Schema( + node_id="ByteDanceImageToVideoNode", + display_name="ByteDance Image to Video", + category="api node/video/ByteDance", + description="Generate video using ByteDance models via api based on image and prompt", + inputs=[ + comfy_io.Combo.Input( + "model", + options=[model.value for model in Image2VideoModelName], + default=Image2VideoModelName.seedance_1_pro.value, + tooltip="Model name", + ), + comfy_io.String.Input( + "prompt", + multiline=True, + tooltip="The text prompt used to generate the video.", + ), + comfy_io.Image.Input( + "image", + tooltip="First frame to be used for the video.", + ), + 
comfy_io.Combo.Input( + "resolution", + options=["480p", "720p", "1080p"], + tooltip="The resolution of the output video.", + ), + comfy_io.Combo.Input( + "aspect_ratio", + options=["adaptive", "16:9", "4:3", "1:1", "3:4", "9:16", "21:9"], + tooltip="The aspect ratio of the output video.", + ), + comfy_io.Int.Input( + "duration", + default=5, + min=3, + max=12, + step=1, + tooltip="The duration of the output video in seconds.", + display_mode=comfy_io.NumberDisplay.slider, + ), + comfy_io.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + step=1, + display_mode=comfy_io.NumberDisplay.number, + control_after_generate=True, + tooltip="Seed to use for generation.", + optional=True, + ), + comfy_io.Boolean.Input( + "camera_fixed", + default=False, + tooltip="Specifies whether to fix the camera. The platform appends an instruction " + "to fix the camera to your prompt, but does not guarantee the actual effect.", + optional=True, + ), + comfy_io.Boolean.Input( + "watermark", + default=True, + tooltip="Whether to add an \"AI generated\" watermark to the video.", + optional=True, + ), + ], + outputs=[ + comfy_io.Video.Output(), + ], + hidden=[ + comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + comfy_io.Hidden.unique_id, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, + model: str, + prompt: str, + image: torch.Tensor, + resolution: str, + aspect_ratio: str, + duration: int, + seed: int, + camera_fixed: bool, + watermark: bool, + ) -> comfy_io.NodeOutput: + validate_string(prompt, strip_whitespace=True, min_length=1) + raise_if_text_params(prompt, ["resolution", "ratio", "duration", "seed", "camerafixed", "watermark"]) + validate_image_dimensions(image, min_width=300, min_height=300, max_width=6000, max_height=6000) + validate_image_aspect_ratio_range(image, (2, 5), (5, 2), strict=False) # 0.4 to 2.5 + + auth_kwargs = { + "auth_token": cls.hidden.auth_token_comfy_org, + "comfy_api_key": cls.hidden.api_key_comfy_org, + } + + image_url = (await upload_images_to_comfyapi(image, max_images=1, auth_kwargs=auth_kwargs))[0] + + prompt = ( + f"{prompt} " + f"--resolution {resolution} " + f"--ratio {aspect_ratio} " + f"--duration {duration} " + f"--seed {seed} " + f"--camerafixed {str(camera_fixed).lower()} " + f"--watermark {str(watermark).lower()}" + ) + + return await process_video_task( + request_model=Image2VideoTaskCreationRequest, + payload=Image2VideoTaskCreationRequest( + model=model, + content=[TaskTextContent(text=prompt), TaskImageContent(image_url=TaskImageContentUrl(url=image_url))], + ), + auth_kwargs=auth_kwargs, + node_id=cls.hidden.unique_id, + estimated_duration=max(1, math.ceil(VIDEO_TASKS_EXECUTION_TIME[model][resolution] * (duration / 10.0))), + ) + + +class ByteDanceFirstLastFrameNode(comfy_io.ComfyNode): + + @classmethod + def define_schema(cls): + return comfy_io.Schema( + node_id="ByteDanceFirstLastFrameNode", + display_name="ByteDance First-Last-Frame to Video", + category="api node/video/ByteDance", + description="Generate video using prompt and first and last frames.", + inputs=[ + comfy_io.Combo.Input( + "model", + options=[Image2VideoModelName.seedance_1_lite.value], + default=Image2VideoModelName.seedance_1_lite.value, + tooltip="Model name", + ), + comfy_io.String.Input( + "prompt", + multiline=True, + tooltip="The text prompt used to generate the video.", + ), + comfy_io.Image.Input( + "first_frame", + tooltip="First frame to be used for the video.", + ), + comfy_io.Image.Input( + "last_frame", + 
tooltip="Last frame to be used for the video.", + ), + comfy_io.Combo.Input( + "resolution", + options=["480p", "720p", "1080p"], + tooltip="The resolution of the output video.", + ), + comfy_io.Combo.Input( + "aspect_ratio", + options=["adaptive", "16:9", "4:3", "1:1", "3:4", "9:16", "21:9"], + tooltip="The aspect ratio of the output video.", + ), + comfy_io.Int.Input( + "duration", + default=5, + min=3, + max=12, + step=1, + tooltip="The duration of the output video in seconds.", + display_mode=comfy_io.NumberDisplay.slider, + ), + comfy_io.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + step=1, + display_mode=comfy_io.NumberDisplay.number, + control_after_generate=True, + tooltip="Seed to use for generation.", + optional=True, + ), + comfy_io.Boolean.Input( + "camera_fixed", + default=False, + tooltip="Specifies whether to fix the camera. The platform appends an instruction " + "to fix the camera to your prompt, but does not guarantee the actual effect.", + optional=True, + ), + comfy_io.Boolean.Input( + "watermark", + default=True, + tooltip="Whether to add an \"AI generated\" watermark to the video.", + optional=True, + ), + ], + outputs=[ + comfy_io.Video.Output(), + ], + hidden=[ + comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + comfy_io.Hidden.unique_id, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, + model: str, + prompt: str, + first_frame: torch.Tensor, + last_frame: torch.Tensor, + resolution: str, + aspect_ratio: str, + duration: int, + seed: int, + camera_fixed: bool, + watermark: bool, + ) -> comfy_io.NodeOutput: + validate_string(prompt, strip_whitespace=True, min_length=1) + raise_if_text_params(prompt, ["resolution", "ratio", "duration", "seed", "camerafixed", "watermark"]) + for i in (first_frame, last_frame): + validate_image_dimensions(i, min_width=300, min_height=300, max_width=6000, max_height=6000) + validate_image_aspect_ratio_range(i, (2, 5), (5, 2), strict=False) # 0.4 to 2.5 + + auth_kwargs = { + "auth_token": cls.hidden.auth_token_comfy_org, + "comfy_api_key": cls.hidden.api_key_comfy_org, + } + + download_urls = await upload_images_to_comfyapi( + image_tensor_pair_to_batch(first_frame, last_frame), + max_images=2, + mime_type="image/png", + auth_kwargs=auth_kwargs, + ) + + prompt = ( + f"{prompt} " + f"--resolution {resolution} " + f"--ratio {aspect_ratio} " + f"--duration {duration} " + f"--seed {seed} " + f"--camerafixed {str(camera_fixed).lower()} " + f"--watermark {str(watermark).lower()}" + ) + + return await process_video_task( + request_model=Image2VideoTaskCreationRequest, + payload=Image2VideoTaskCreationRequest( + model=model, + content=[ + TaskTextContent(text=prompt), + TaskImageContent(image_url=TaskImageContentUrl(url=str(download_urls[0])), role="first_frame"), + TaskImageContent(image_url=TaskImageContentUrl(url=str(download_urls[1])), role="last_frame"), + ], + ), + auth_kwargs=auth_kwargs, + node_id=cls.hidden.unique_id, + estimated_duration=max(1, math.ceil(VIDEO_TASKS_EXECUTION_TIME[model][resolution] * (duration / 10.0))), + ) + + +class ByteDanceImageReferenceNode(comfy_io.ComfyNode): + + @classmethod + def define_schema(cls): + return comfy_io.Schema( + node_id="ByteDanceImageReferenceNode", + display_name="ByteDance Reference Images to Video", + category="api node/video/ByteDance", + description="Generate video using prompt and reference images.", + inputs=[ + comfy_io.Combo.Input( + "model", + options=[Image2VideoModelName.seedance_1_lite.value], + 
default=Image2VideoModelName.seedance_1_lite.value, + tooltip="Model name", + ), + comfy_io.String.Input( + "prompt", + multiline=True, + tooltip="The text prompt used to generate the video.", + ), + comfy_io.Image.Input( + "images", + tooltip="One to four images.", + ), + comfy_io.Combo.Input( + "resolution", + options=["480p", "720p"], + tooltip="The resolution of the output video.", + ), + comfy_io.Combo.Input( + "aspect_ratio", + options=["adaptive", "16:9", "4:3", "1:1", "3:4", "9:16", "21:9"], + tooltip="The aspect ratio of the output video.", + ), + comfy_io.Int.Input( + "duration", + default=5, + min=3, + max=12, + step=1, + tooltip="The duration of the output video in seconds.", + display_mode=comfy_io.NumberDisplay.slider, + ), + comfy_io.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + step=1, + display_mode=comfy_io.NumberDisplay.number, + control_after_generate=True, + tooltip="Seed to use for generation.", + optional=True, + ), + comfy_io.Boolean.Input( + "watermark", + default=True, + tooltip="Whether to add an \"AI generated\" watermark to the video.", + optional=True, + ), + ], + outputs=[ + comfy_io.Video.Output(), + ], + hidden=[ + comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + comfy_io.Hidden.unique_id, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, + model: str, + prompt: str, + images: torch.Tensor, + resolution: str, + aspect_ratio: str, + duration: int, + seed: int, + watermark: bool, + ) -> comfy_io.NodeOutput: + validate_string(prompt, strip_whitespace=True, min_length=1) + raise_if_text_params(prompt, ["resolution", "ratio", "duration", "seed", "watermark"]) + for image in images: + validate_image_dimensions(image, min_width=300, min_height=300, max_width=6000, max_height=6000) + validate_image_aspect_ratio_range(image, (2, 5), (5, 2), strict=False) # 0.4 to 2.5 + + auth_kwargs = { + "auth_token": cls.hidden.auth_token_comfy_org, + "comfy_api_key": cls.hidden.api_key_comfy_org, + } + + image_urls = await upload_images_to_comfyapi( + images, max_images=4, mime_type="image/png", auth_kwargs=auth_kwargs + ) + + prompt = ( + f"{prompt} " + f"--resolution {resolution} " + f"--ratio {aspect_ratio} " + f"--duration {duration} " + f"--seed {seed} " + f"--watermark {str(watermark).lower()}" + ) + x = [ + TaskTextContent(text=prompt), + *[TaskImageContent(image_url=TaskImageContentUrl(url=str(i)), role="reference_image") for i in image_urls] + ] + return await process_video_task( + request_model=Image2VideoTaskCreationRequest, + payload=Image2VideoTaskCreationRequest( + model=model, + content=x, + ), + auth_kwargs=auth_kwargs, + node_id=cls.hidden.unique_id, + estimated_duration=max(1, math.ceil(VIDEO_TASKS_EXECUTION_TIME[model][resolution] * (duration / 10.0))), + ) + + +async def process_video_task( + request_model: Type[T], + payload: Union[Text2VideoTaskCreationRequest, Image2VideoTaskCreationRequest], + auth_kwargs: dict, + node_id: str, + estimated_duration: int | None, +) -> comfy_io.NodeOutput: + initial_response = await SynchronousOperation( + endpoint=ApiEndpoint( + path=BYTEPLUS_TASK_ENDPOINT, + method=HttpMethod.POST, + request_model=request_model, + response_model=TaskCreationResponse, + ), + request=payload, + auth_kwargs=auth_kwargs, + ).execute() + response = await poll_until_finished( + auth_kwargs, + initial_response.id, + estimated_duration=estimated_duration, + node_id=node_id, + ) + return comfy_io.NodeOutput(await 
download_url_to_video_output(get_video_url_from_task_status(response))) + + +def raise_if_text_params(prompt: str, text_params: list[str]) -> None: + for i in text_params: + if f"--{i} " in prompt: + raise ValueError( + f"--{i} is not allowed in the prompt, use the appropriated widget input to change this value." + ) + + class ByteDanceExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[comfy_io.ComfyNode]]: return [ ByteDanceImageNode, ByteDanceImageEditNode, + ByteDanceTextToVideoNode, + ByteDanceImageToVideoNode, + ByteDanceFirstLastFrameNode, + ByteDanceImageReferenceNode, ] async def comfy_entrypoint() -> ByteDanceExtension: From b288fb0db88281532d813d4fb83f715f88b54ffc Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 9 Sep 2025 15:09:56 -0700 Subject: [PATCH 174/325] Small refactor of some vae code. (#9787) --- comfy/ldm/modules/diffusionmodules/model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/comfy/ldm/modules/diffusionmodules/model.py b/comfy/ldm/modules/diffusionmodules/model.py index 1fd12b35a..8f598a848 100644 --- a/comfy/ldm/modules/diffusionmodules/model.py +++ b/comfy/ldm/modules/diffusionmodules/model.py @@ -145,7 +145,7 @@ class Downsample(nn.Module): class ResnetBlock(nn.Module): def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, - dropout, temb_channels=512, conv_op=ops.Conv2d): + dropout=0.0, temb_channels=512, conv_op=ops.Conv2d): super().__init__() self.in_channels = in_channels out_channels = in_channels if out_channels is None else out_channels @@ -183,7 +183,7 @@ class ResnetBlock(nn.Module): stride=1, padding=0) - def forward(self, x, temb): + def forward(self, x, temb=None): h = x h = self.norm1(h) h = self.swish(h) From 206595f854c67538d5921d36326acbfeb69c5ac2 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Tue, 9 Sep 2025 18:33:36 -0700 Subject: [PATCH 175/325] Change validate_inputs' output typehint to 'bool | str' and update docstrings (#9786) --- comfy_api/latest/_io.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/comfy_api/latest/_io.py b/comfy_api/latest/_io.py index e0ee943a7..f770109d5 100644 --- a/comfy_api/latest/_io.py +++ b/comfy_api/latest/_io.py @@ -1190,13 +1190,18 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal): raise NotImplementedError @classmethod - def validate_inputs(cls, **kwargs) -> bool: - """Optionally, define this function to validate inputs; equivalent to V1's VALIDATE_INPUTS.""" + def validate_inputs(cls, **kwargs) -> bool | str: + """Optionally, define this function to validate inputs; equivalent to V1's VALIDATE_INPUTS. + + If the function returns a string, it will be used as the validation error message for the node. + """ raise NotImplementedError @classmethod def fingerprint_inputs(cls, **kwargs) -> Any: - """Optionally, define this function to fingerprint inputs; equivalent to V1's IS_CHANGED.""" + """Optionally, define this function to fingerprint inputs; equivalent to V1's IS_CHANGED. + + If this function returns the same value as last run, the node will not be executed.""" raise NotImplementedError @classmethod From 5c33872e2f355e51adf212d5b5c83815b7fe77b0 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 9 Sep 2025 21:23:47 -0700 Subject: [PATCH 176/325] Fix issue on old torch. 
(#9791) --- comfy/ldm/hunyuan3dv2_1/hunyuandit.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/ldm/hunyuan3dv2_1/hunyuandit.py b/comfy/ldm/hunyuan3dv2_1/hunyuandit.py index ca1a83001..d48d9d642 100644 --- a/comfy/ldm/hunyuan3dv2_1/hunyuandit.py +++ b/comfy/ldm/hunyuan3dv2_1/hunyuandit.py @@ -426,7 +426,7 @@ class HunYuanDiTBlock(nn.Module): text_states_dim=1024, qk_norm=False, norm_layer=nn.LayerNorm, - qk_norm_layer=nn.RMSNorm, + qk_norm_layer=True, qkv_bias=True, skip_connection=True, timested_modulate=False, From 85e34643f874aec2ab9eed6a8499f2aefa81486e Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 9 Sep 2025 23:05:07 -0700 Subject: [PATCH 177/325] Support hunyuan image 2.1 regular model. (#9792) --- comfy/latent_formats.py | 5 + comfy/ldm/hunyuan_video/model.py | 102 +- comfy/ldm/hunyuan_video/vae.py | 136 ++ comfy/model_base.py | 24 + comfy/model_detection.py | 28 +- comfy/sd.py | 31 +- comfy/supported_models.py | 27 +- .../byt5_config_small_glyph.json | 22 + .../byt5_tokenizer/added_tokens.json | 127 ++ .../byt5_tokenizer/special_tokens_map.json | 150 +++ .../byt5_tokenizer/tokenizer_config.json | 1163 +++++++++++++++++ comfy/text_encoders/hunyuan_image.py | 100 ++ comfy_extras/nodes_hunyuan.py | 15 + nodes.py | 6 +- 14 files changed, 1906 insertions(+), 30 deletions(-) create mode 100644 comfy/ldm/hunyuan_video/vae.py create mode 100644 comfy/text_encoders/byt5_config_small_glyph.json create mode 100644 comfy/text_encoders/byt5_tokenizer/added_tokens.json create mode 100644 comfy/text_encoders/byt5_tokenizer/special_tokens_map.json create mode 100644 comfy/text_encoders/byt5_tokenizer/tokenizer_config.json create mode 100644 comfy/text_encoders/hunyuan_image.py diff --git a/comfy/latent_formats.py b/comfy/latent_formats.py index 0d84994b0..859ae8421 100644 --- a/comfy/latent_formats.py +++ b/comfy/latent_formats.py @@ -533,6 +533,11 @@ class Wan22(Wan21): 0.3971, 1.0600, 0.3943, 0.5537, 0.5444, 0.4089, 0.7468, 0.7744 ]).view(1, self.latent_channels, 1, 1, 1) +class HunyuanImage21(LatentFormat): + latent_channels = 64 + latent_dimensions = 2 + scale_factor = 0.75289 + class Hunyuan3Dv2(LatentFormat): latent_channels = 64 latent_dimensions = 1 diff --git a/comfy/ldm/hunyuan_video/model.py b/comfy/ldm/hunyuan_video/model.py index da1011596..ca289c5bd 100644 --- a/comfy/ldm/hunyuan_video/model.py +++ b/comfy/ldm/hunyuan_video/model.py @@ -40,6 +40,7 @@ class HunyuanVideoParams: patch_size: list qkv_bias: bool guidance_embed: bool + byt5: bool class SelfAttentionRef(nn.Module): @@ -161,6 +162,30 @@ class TokenRefiner(nn.Module): x = self.individual_token_refiner(x, c, mask) return x + +class ByT5Mapper(nn.Module): + def __init__(self, in_dim, out_dim, hidden_dim, out_dim1, use_res=False, dtype=None, device=None, operations=None): + super().__init__() + self.layernorm = operations.LayerNorm(in_dim, dtype=dtype, device=device) + self.fc1 = operations.Linear(in_dim, hidden_dim, dtype=dtype, device=device) + self.fc2 = operations.Linear(hidden_dim, out_dim, dtype=dtype, device=device) + self.fc3 = operations.Linear(out_dim, out_dim1, dtype=dtype, device=device) + self.use_res = use_res + self.act_fn = nn.GELU() + + def forward(self, x): + if self.use_res: + res = x + x = self.layernorm(x) + x = self.fc1(x) + x = self.act_fn(x) + x = self.fc2(x) + x2 = self.act_fn(x) + x2 = self.fc3(x2) + if self.use_res: + x2 = x2 + res + return x2 + class HunyuanVideo(nn.Module): """ Transformer model for flow matching 
on sequences. @@ -185,9 +210,13 @@ class HunyuanVideo(nn.Module): self.num_heads = params.num_heads self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim) - self.img_in = comfy.ldm.modules.diffusionmodules.mmdit.PatchEmbed(None, self.patch_size, self.in_channels, self.hidden_size, conv3d=True, dtype=dtype, device=device, operations=operations) + self.img_in = comfy.ldm.modules.diffusionmodules.mmdit.PatchEmbed(None, self.patch_size, self.in_channels, self.hidden_size, conv3d=len(self.patch_size) == 3, dtype=dtype, device=device, operations=operations) self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations) - self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size, dtype=dtype, device=device, operations=operations) + if params.vec_in_dim is not None: + self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size, dtype=dtype, device=device, operations=operations) + else: + self.vector_in = None + self.guidance_in = ( MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations) if params.guidance_embed else nn.Identity() ) @@ -215,6 +244,18 @@ class HunyuanVideo(nn.Module): ] ) + if params.byt5: + self.byt5_in = ByT5Mapper( + in_dim=1472, + out_dim=2048, + hidden_dim=2048, + out_dim1=self.hidden_size, + use_res=False, + dtype=dtype, device=device, operations=operations + ) + else: + self.byt5_in = None + if final_layer: self.final_layer = LastLayer(self.hidden_size, self.patch_size[-1], self.out_channels, dtype=dtype, device=device, operations=operations) @@ -226,7 +267,8 @@ class HunyuanVideo(nn.Module): txt_ids: Tensor, txt_mask: Tensor, timesteps: Tensor, - y: Tensor, + y: Tensor = None, + txt_byt5=None, guidance: Tensor = None, guiding_frame_index=None, ref_latent=None, @@ -250,13 +292,17 @@ class HunyuanVideo(nn.Module): if guiding_frame_index is not None: token_replace_vec = self.time_in(timestep_embedding(guiding_frame_index, 256, time_factor=1.0)) - vec_ = self.vector_in(y[:, :self.params.vec_in_dim]) - vec = torch.cat([(vec_ + token_replace_vec).unsqueeze(1), (vec_ + vec).unsqueeze(1)], dim=1) + if self.vector_in is not None: + vec_ = self.vector_in(y[:, :self.params.vec_in_dim]) + vec = torch.cat([(vec_ + token_replace_vec).unsqueeze(1), (vec_ + vec).unsqueeze(1)], dim=1) + else: + vec = torch.cat([(token_replace_vec).unsqueeze(1), (vec).unsqueeze(1)], dim=1) frame_tokens = (initial_shape[-1] // self.patch_size[-1]) * (initial_shape[-2] // self.patch_size[-2]) modulation_dims = [(0, frame_tokens, 0), (frame_tokens, None, 1)] modulation_dims_txt = [(0, None, 1)] else: - vec = vec + self.vector_in(y[:, :self.params.vec_in_dim]) + if self.vector_in is not None: + vec = vec + self.vector_in(y[:, :self.params.vec_in_dim]) modulation_dims = None modulation_dims_txt = None @@ -269,6 +315,12 @@ class HunyuanVideo(nn.Module): txt = self.txt_in(txt, timesteps, txt_mask) + if self.byt5_in is not None and txt_byt5 is not None: + txt_byt5 = self.byt5_in(txt_byt5) + txt_byt5_ids = torch.zeros((txt_ids.shape[0], txt_byt5.shape[1], txt_ids.shape[-1]), device=txt_ids.device, dtype=txt_ids.dtype) + txt = torch.cat((txt, txt_byt5), dim=1) + txt_ids = torch.cat((txt_ids, txt_byt5_ids), dim=1) + ids = torch.cat((img_ids, txt_ids), dim=1) pe = self.pe_embedder(ids) @@ -328,12 +380,16 @@ class HunyuanVideo(nn.Module): img = self.final_layer(img, vec, modulation_dims=modulation_dims) # (N, T, patch_size ** 2 * out_channels) - shape = initial_shape[-3:] + 
shape = initial_shape[-len(self.patch_size):] for i in range(len(shape)): shape[i] = shape[i] // self.patch_size[i] img = img.reshape([img.shape[0]] + shape + [self.out_channels] + self.patch_size) - img = img.permute(0, 4, 1, 5, 2, 6, 3, 7) - img = img.reshape(initial_shape[0], self.out_channels, initial_shape[2], initial_shape[3], initial_shape[4]) + if img.ndim == 8: + img = img.permute(0, 4, 1, 5, 2, 6, 3, 7) + img = img.reshape(initial_shape[0], self.out_channels, initial_shape[2], initial_shape[3], initial_shape[4]) + else: + img = img.permute(0, 3, 1, 4, 2, 5) + img = img.reshape(initial_shape[0], self.out_channels, initial_shape[2], initial_shape[3]) return img def img_ids(self, x): @@ -348,16 +404,30 @@ class HunyuanVideo(nn.Module): img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).reshape(1, 1, -1) return repeat(img_ids, "t h w c -> b (t h w) c", b=bs) - def forward(self, x, timestep, context, y, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, control=None, transformer_options={}, **kwargs): + def img_ids_2d(self, x): + bs, c, h, w = x.shape + patch_size = self.patch_size + h_len = ((h + (patch_size[0] // 2)) // patch_size[0]) + w_len = ((w + (patch_size[1] // 2)) // patch_size[1]) + img_ids = torch.zeros((h_len, w_len, 2), device=x.device, dtype=x.dtype) + img_ids[:, :, 0] = img_ids[:, :, 0] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1) + img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0) + return repeat(img_ids, "h w c -> b (h w) c", b=bs) + + def forward(self, x, timestep, context, y=None, txt_byt5=None, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, control=None, transformer_options={}, **kwargs): return comfy.patcher_extension.WrapperExecutor.new_class_executor( self._forward, self, comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options) - ).execute(x, timestep, context, y, guidance, attention_mask, guiding_frame_index, ref_latent, control, transformer_options, **kwargs) + ).execute(x, timestep, context, y, txt_byt5, guidance, attention_mask, guiding_frame_index, ref_latent, control, transformer_options, **kwargs) - def _forward(self, x, timestep, context, y, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, control=None, transformer_options={}, **kwargs): - bs, c, t, h, w = x.shape - img_ids = self.img_ids(x) - txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype) - out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, guidance, guiding_frame_index, ref_latent, control=control, transformer_options=transformer_options) + def _forward(self, x, timestep, context, y=None, txt_byt5=None, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, control=None, transformer_options={}, **kwargs): + bs = x.shape[0] + if len(self.patch_size) == 3: + img_ids = self.img_ids(x) + txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype) + else: + img_ids = self.img_ids_2d(x) + txt_ids = torch.zeros((bs, context.shape[1], 2), device=x.device, dtype=x.dtype) + out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, txt_byt5, guidance, guiding_frame_index, ref_latent, control=control, transformer_options=transformer_options) return out 
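Note on the new 2-D (image) path above: when patch_size has length 2, forward_orig unpatchifies the transformer output with reshape -> permute(0, 3, 1, 4, 2, 5) -> reshape. The standalone round-trip check below illustrates that layout; it is not part of the patch, and the shapes, names, and the explicit patchify are illustrative assumptions (the real model embeds patches with a conv PatchEmbed).

import torch

# Hypothetical toy shapes: batch 1, 4 channels, an 8x8 latent, 2x2 patches.
B, C, H, W, p = 1, 4, 8, 8, 2
x = torch.arange(B * C * H * W, dtype=torch.float32).reshape(B, C, H, W)

# A patchify chosen to be the exact inverse of the unpatchify below (check only):
# (B, C, H, W) -> (B, H', W', C, p, p) -> (B, H'*W', C*p*p)
tokens = x.reshape(B, C, H // p, p, W // p, p).permute(0, 2, 4, 1, 3, 5).reshape(B, (H // p) * (W // p), C * p * p)

# Unpatchify as in the 2-D branch above: (B, H', W', C, p, p), permute(0, 3, 1, 4, 2, 5)
# to (B, C, H', p, W', p), then flatten back to (B, C, H, W).
img = tokens.reshape(B, H // p, W // p, C, p, p).permute(0, 3, 1, 4, 2, 5).reshape(B, C, H, W)

assert torch.equal(img, x)  # the permutation pattern is a lossless round trip

The 3-D video branch (img.ndim == 8) follows the same idea with an extra time axis and permute(0, 4, 1, 5, 2, 6, 3, 7).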
diff --git a/comfy/ldm/hunyuan_video/vae.py b/comfy/ldm/hunyuan_video/vae.py new file mode 100644 index 000000000..8d406089b --- /dev/null +++ b/comfy/ldm/hunyuan_video/vae.py @@ -0,0 +1,136 @@ +import torch.nn as nn +import torch.nn.functional as F +from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, AttnBlock +import comfy.ops +ops = comfy.ops.disable_weight_init + + +class PixelShuffle2D(nn.Module): + def __init__(self, in_dim, out_dim, op=ops.Conv2d): + super().__init__() + self.conv = op(in_dim, out_dim >> 2, 3, 1, 1) + self.ratio = (in_dim << 2) // out_dim + + def forward(self, x): + b, c, h, w = x.shape + h2, w2 = h >> 1, w >> 1 + y = self.conv(x).view(b, -1, h2, 2, w2, 2).permute(0, 3, 5, 1, 2, 4).reshape(b, -1, h2, w2) + r = x.view(b, c, h2, 2, w2, 2).permute(0, 3, 5, 1, 2, 4).reshape(b, c << 2, h2, w2) + return y + r.view(b, y.shape[1], self.ratio, h2, w2).mean(2) + + +class PixelUnshuffle2D(nn.Module): + def __init__(self, in_dim, out_dim, op=ops.Conv2d): + super().__init__() + self.conv = op(in_dim, out_dim << 2, 3, 1, 1) + self.scale = (out_dim << 2) // in_dim + + def forward(self, x): + b, c, h, w = x.shape + h2, w2 = h << 1, w << 1 + y = self.conv(x).view(b, 2, 2, -1, h, w).permute(0, 3, 4, 1, 5, 2).reshape(b, -1, h2, w2) + r = x.repeat_interleave(self.scale, 1).view(b, 2, 2, -1, h, w).permute(0, 3, 4, 1, 5, 2).reshape(b, -1, h2, w2) + return y + r + + +class Encoder(nn.Module): + def __init__(self, in_channels, z_channels, block_out_channels, num_res_blocks, + ffactor_spatial, downsample_match_channel=True, **_): + super().__init__() + self.z_channels = z_channels + self.block_out_channels = block_out_channels + self.num_res_blocks = num_res_blocks + self.conv_in = ops.Conv2d(in_channels, block_out_channels[0], 3, 1, 1) + + self.down = nn.ModuleList() + ch = block_out_channels[0] + depth = (ffactor_spatial >> 1).bit_length() + + for i, tgt in enumerate(block_out_channels): + stage = nn.Module() + stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt, + out_channels=tgt, + temb_channels=0, + conv_op=ops.Conv2d) + for j in range(num_res_blocks)]) + ch = tgt + if i < depth: + nxt = block_out_channels[i + 1] if i + 1 < len(block_out_channels) and downsample_match_channel else ch + stage.downsample = PixelShuffle2D(ch, nxt, ops.Conv2d) + ch = nxt + self.down.append(stage) + + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=ops.Conv2d) + self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv2d) + self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=ops.Conv2d) + + self.norm_out = nn.GroupNorm(32, ch, 1e-6, True) + self.conv_out = ops.Conv2d(ch, z_channels << 1, 3, 1, 1) + + def forward(self, x): + x = self.conv_in(x) + + for stage in self.down: + for blk in stage.block: + x = blk(x) + if hasattr(stage, 'downsample'): + x = stage.downsample(x) + + x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(x))) + + b, c, h, w = x.shape + grp = c // (self.z_channels << 1) + skip = x.view(b, c // grp, grp, h, w).mean(2) + + return self.conv_out(F.silu(self.norm_out(x))) + skip + + +class Decoder(nn.Module): + def __init__(self, z_channels, out_channels, block_out_channels, num_res_blocks, + ffactor_spatial, upsample_match_channel=True, **_): + super().__init__() + block_out_channels = block_out_channels[::-1] + self.z_channels = z_channels + self.block_out_channels = block_out_channels + self.num_res_blocks = num_res_blocks + + ch = block_out_channels[0] + 
self.conv_in = ops.Conv2d(z_channels, ch, 3, 1, 1) + + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=ops.Conv2d) + self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv2d) + self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=ops.Conv2d) + + self.up = nn.ModuleList() + depth = (ffactor_spatial >> 1).bit_length() + + for i, tgt in enumerate(block_out_channels): + stage = nn.Module() + stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt, + out_channels=tgt, + temb_channels=0, + conv_op=ops.Conv2d) + for j in range(num_res_blocks + 1)]) + ch = tgt + if i < depth: + nxt = block_out_channels[i + 1] if i + 1 < len(block_out_channels) and upsample_match_channel else ch + stage.upsample = PixelUnshuffle2D(ch, nxt, ops.Conv2d) + ch = nxt + self.up.append(stage) + + self.norm_out = nn.GroupNorm(32, ch, 1e-6, True) + self.conv_out = ops.Conv2d(ch, out_channels, 3, 1, 1) + + def forward(self, z): + x = self.conv_in(z) + z.repeat_interleave(self.block_out_channels[0] // self.z_channels, 1) + x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(x))) + + for stage in self.up: + for blk in stage.block: + x = blk(x) + if hasattr(stage, 'upsample'): + x = stage.upsample(x) + + return self.conv_out(F.silu(self.norm_out(x))) diff --git a/comfy/model_base.py b/comfy/model_base.py index 39a3344bc..993ff65e6 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1408,3 +1408,27 @@ class QwenImage(BaseModel): if ref_latents is not None: out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()), ref_latents)) // 16]) return out + +class HunyuanImage21(BaseModel): + def __init__(self, model_config, model_type=ModelType.FLOW, device=None): + super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan_video.model.HunyuanVideo) + + def extra_conds(self, **kwargs): + out = super().extra_conds(**kwargs) + attention_mask = kwargs.get("attention_mask", None) + if attention_mask is not None: + if torch.numel(attention_mask) != attention_mask.sum(): + out['attention_mask'] = comfy.conds.CONDRegular(attention_mask) + cross_attn = kwargs.get("cross_attn", None) + if cross_attn is not None: + out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) + + conditioning_byt5small = kwargs.get("conditioning_byt5small", None) + if conditioning_byt5small is not None: + out['txt_byt5'] = comfy.conds.CONDRegular(conditioning_byt5small) + + guidance = kwargs.get("guidance", 6.0) + if guidance is not None: + out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance])) + + return out diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 75552ede9..dbcbe5f5a 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -136,20 +136,32 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): if '{}txt_in.individual_token_refiner.blocks.0.norm1.weight'.format(key_prefix) in state_dict_keys: #Hunyuan Video dit_config = {} + in_w = state_dict['{}img_in.proj.weight'.format(key_prefix)] + out_w = state_dict['{}final_layer.linear.weight'.format(key_prefix)] dit_config["image_model"] = "hunyuan_video" - dit_config["in_channels"] = state_dict['{}img_in.proj.weight'.format(key_prefix)].shape[1] #SkyReels img2video has 32 input channels - dit_config["patch_size"] = [1, 2, 2] - dit_config["out_channels"] = 16 - dit_config["vec_in_dim"] = 768 - dit_config["context_in_dim"] = 4096 - dit_config["hidden_size"] = 3072 + 
dit_config["in_channels"] = in_w.shape[1] #SkyReels img2video has 32 input channels + dit_config["patch_size"] = list(in_w.shape[2:]) + dit_config["out_channels"] = out_w.shape[0] // math.prod(dit_config["patch_size"]) + if '{}vector_in.in_layer.weight'.format(key_prefix) in state_dict: + dit_config["vec_in_dim"] = 768 + dit_config["axes_dim"] = [16, 56, 56] + else: + dit_config["vec_in_dim"] = None + dit_config["axes_dim"] = [64, 64] + + dit_config["context_in_dim"] = state_dict['{}txt_in.input_embedder.weight'.format(key_prefix)].shape[1] + dit_config["hidden_size"] = in_w.shape[0] dit_config["mlp_ratio"] = 4.0 - dit_config["num_heads"] = 24 + dit_config["num_heads"] = in_w.shape[0] // 128 dit_config["depth"] = count_blocks(state_dict_keys, '{}double_blocks.'.format(key_prefix) + '{}.') dit_config["depth_single_blocks"] = count_blocks(state_dict_keys, '{}single_blocks.'.format(key_prefix) + '{}.') - dit_config["axes_dim"] = [16, 56, 56] dit_config["theta"] = 256 dit_config["qkv_bias"] = True + if '{}byt5_in.fc1.weight'.format(key_prefix) in state_dict: + dit_config["byt5"] = True + else: + dit_config["byt5"] = False + guidance_keys = list(filter(lambda a: a.startswith("{}guidance_in.".format(key_prefix)), state_dict_keys)) dit_config["guidance_embed"] = len(guidance_keys) > 0 return dit_config diff --git a/comfy/sd.py b/comfy/sd.py index be5aa8dc8..9dd9a74d4 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -17,6 +17,7 @@ import comfy.ldm.wan.vae import comfy.ldm.wan.vae2_2 import comfy.ldm.hunyuan3d.vae import comfy.ldm.ace.vae.music_dcae_pipeline +import comfy.ldm.hunyuan_video.vae import yaml import math import os @@ -48,6 +49,7 @@ import comfy.text_encoders.hidream import comfy.text_encoders.ace import comfy.text_encoders.omnigen2 import comfy.text_encoders.qwen_image +import comfy.text_encoders.hunyuan_image import comfy.model_patcher import comfy.lora @@ -328,6 +330,19 @@ class VAE: self.first_stage_model = StageC_coder() self.downscale_ratio = 32 self.latent_channels = 16 + elif "decoder.conv_in.weight" in sd and sd['decoder.conv_in.weight'].shape[1] == 64: + ddconfig = {"block_out_channels": [128, 256, 512, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 32, "downsample_match_channel": True, "upsample_match_channel": True} + self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1] + self.downscale_ratio = 32 + self.upscale_ratio = 32 + self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32] + self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer"}, + encoder_config={'target': "comfy.ldm.hunyuan_video.vae.Encoder", 'params': ddconfig}, + decoder_config={'target': "comfy.ldm.hunyuan_video.vae.Decoder", 'params': ddconfig}) + + self.memory_used_encode = lambda shape, dtype: (700 * shape[2] * shape[3]) * model_management.dtype_size(dtype) + self.memory_used_decode = lambda shape, dtype: (700 * shape[2] * shape[3] * 32 * 32) * model_management.dtype_size(dtype) + elif "decoder.conv_in.weight" in sd: #default SD1.x/SD2.x VAE parameters ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0} @@ -785,6 +800,7 @@ class CLIPType(Enum): ACE = 16 OMNIGEN2 = 17 QWEN_IMAGE = 18 + HUNYUAN_IMAGE = 19 def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, 
model_options={}): @@ -806,6 +822,7 @@ class TEModel(Enum): GEMMA_2_2B = 9 QWEN25_3B = 10 QWEN25_7B = 11 + BYT5_SMALL_GLYPH = 12 def detect_te_model(sd): if "text_model.encoder.layers.30.mlp.fc1.weight" in sd: @@ -823,6 +840,9 @@ def detect_te_model(sd): if 'encoder.block.23.layer.1.DenseReluDense.wi.weight' in sd: return TEModel.T5_XXL_OLD if "encoder.block.0.layer.0.SelfAttention.k.weight" in sd: + weight = sd['encoder.block.0.layer.0.SelfAttention.k.weight'] + if weight.shape[0] == 384: + return TEModel.BYT5_SMALL_GLYPH return TEModel.T5_BASE if 'model.layers.0.post_feedforward_layernorm.weight' in sd: return TEModel.GEMMA_2_2B @@ -937,8 +957,12 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip clip_target.clip = comfy.text_encoders.omnigen2.te(**llama_detect(clip_data)) clip_target.tokenizer = comfy.text_encoders.omnigen2.Omnigen2Tokenizer elif te_model == TEModel.QWEN25_7B: - clip_target.clip = comfy.text_encoders.qwen_image.te(**llama_detect(clip_data)) - clip_target.tokenizer = comfy.text_encoders.qwen_image.QwenImageTokenizer + if clip_type == CLIPType.HUNYUAN_IMAGE: + clip_target.clip = comfy.text_encoders.hunyuan_image.te(byt5=False, **llama_detect(clip_data)) + clip_target.tokenizer = comfy.text_encoders.hunyuan_image.HunyuanImageTokenizer + else: + clip_target.clip = comfy.text_encoders.qwen_image.te(**llama_detect(clip_data)) + clip_target.tokenizer = comfy.text_encoders.qwen_image.QwenImageTokenizer else: # clip_l if clip_type == CLIPType.SD3: @@ -982,6 +1006,9 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip clip_target.clip = comfy.text_encoders.hidream.hidream_clip(clip_l=clip_l, clip_g=clip_g, t5=t5, llama=llama, **t5_kwargs, **llama_kwargs) clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer + elif clip_type == CLIPType.HUNYUAN_IMAGE: + clip_target.clip = comfy.text_encoders.hunyuan_image.te(**llama_detect(clip_data)) + clip_target.tokenizer = comfy.text_encoders.hunyuan_image.HunyuanImageTokenizer else: clip_target.clip = sdxl_clip.SDXLClipModel clip_target.tokenizer = sdxl_clip.SDXLTokenizer diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 75dad277d..aa953b462 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -20,6 +20,7 @@ import comfy.text_encoders.wan import comfy.text_encoders.ace import comfy.text_encoders.omnigen2 import comfy.text_encoders.qwen_image +import comfy.text_encoders.hunyuan_image from . import supported_models_base from . 
import latent_formats @@ -1295,7 +1296,31 @@ class QwenImage(supported_models_base.BASE): hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref)) return supported_models_base.ClipTarget(comfy.text_encoders.qwen_image.QwenImageTokenizer, comfy.text_encoders.qwen_image.te(**hunyuan_detect)) +class HunyuanImage21(HunyuanVideo): + unet_config = { + "image_model": "hunyuan_video", + "vec_in_dim": None, + } -models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ACEStep, Omnigen2, QwenImage] + sampling_settings = { + "shift": 5.0, + } + + latent_format = latent_formats.HunyuanImage21 + + memory_usage_factor = 7.7 + + supported_inference_dtypes = [torch.bfloat16, torch.float32] + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.HunyuanImage21(self, device=device) + return out + + def clip_target(self, state_dict={}): + pref = self.text_encoder_key_prefix[0] + hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref)) + return supported_models_base.ClipTarget(comfy.text_encoders.hunyuan_image.HunyuanImageTokenizer, comfy.text_encoders.hunyuan_image.te(**hunyuan_detect)) + +models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ACEStep, Omnigen2, QwenImage] models += [SVD_img2vid] diff --git a/comfy/text_encoders/byt5_config_small_glyph.json b/comfy/text_encoders/byt5_config_small_glyph.json new file mode 100644 index 000000000..0239c7164 --- /dev/null +++ b/comfy/text_encoders/byt5_config_small_glyph.json @@ -0,0 +1,22 @@ +{ + "d_ff": 3584, + "d_kv": 64, + "d_model": 1472, + "decoder_start_token_id": 0, + "dropout_rate": 0.1, + "eos_token_id": 1, + "dense_act_fn": "gelu_pytorch_tanh", + "initializer_factor": 1.0, + "is_encoder_decoder": true, + "is_gated_act": true, + "layer_norm_epsilon": 1e-06, + "model_type": "t5", + "num_decoder_layers": 4, + "num_heads": 6, + "num_layers": 12, + "output_past": true, + "pad_token_id": 0, + "relative_attention_num_buckets": 32, + "tie_word_embeddings": false, + "vocab_size": 1510 +} diff --git a/comfy/text_encoders/byt5_tokenizer/added_tokens.json b/comfy/text_encoders/byt5_tokenizer/added_tokens.json new file mode 100644 index 000000000..93c190b56 --- /dev/null +++ 
b/comfy/text_encoders/byt5_tokenizer/added_tokens.json @@ -0,0 +1,127 @@ +{ + "": 259, + "": 359, + "": 360, + "": 361, + "": 362, + "": 363, + "": 364, + "": 365, + "": 366, + "": 367, + "": 368, + "": 269, + "": 369, + "": 370, + "": 371, + "": 372, + "": 373, + "": 374, + "": 375, + "": 376, + "": 377, + "": 378, + "": 270, + "": 379, + "": 380, + "": 381, + "": 382, + "": 383, + "": 271, + "": 272, + "": 273, + "": 274, + "": 275, + "": 276, + "": 277, + "": 278, + "": 260, + "": 279, + "": 280, + "": 281, + "": 282, + "": 283, + "": 284, + "": 285, + "": 286, + "": 287, + "": 288, + "": 261, + "": 289, + "": 290, + "": 291, + "": 292, + "": 293, + "": 294, + "": 295, + "": 296, + "": 297, + "": 298, + "": 262, + "": 299, + "": 300, + "": 301, + "": 302, + "": 303, + "": 304, + "": 305, + "": 306, + "": 307, + "": 308, + "": 263, + "": 309, + "": 310, + "": 311, + "": 312, + "": 313, + "": 314, + "": 315, + "": 316, + "": 317, + "": 318, + "": 264, + "": 319, + "": 320, + "": 321, + "": 322, + "": 323, + "": 324, + "": 325, + "": 326, + "": 327, + "": 328, + "": 265, + "": 329, + "": 330, + "": 331, + "": 332, + "": 333, + "": 334, + "": 335, + "": 336, + "": 337, + "": 338, + "": 266, + "": 339, + "": 340, + "": 341, + "": 342, + "": 343, + "": 344, + "": 345, + "": 346, + "": 347, + "": 348, + "": 267, + "": 349, + "": 350, + "": 351, + "": 352, + "": 353, + "": 354, + "": 355, + "": 356, + "": 357, + "": 358, + "": 268 +} diff --git a/comfy/text_encoders/byt5_tokenizer/special_tokens_map.json b/comfy/text_encoders/byt5_tokenizer/special_tokens_map.json new file mode 100644 index 000000000..04fd58b5f --- /dev/null +++ b/comfy/text_encoders/byt5_tokenizer/special_tokens_map.json @@ -0,0 +1,150 @@ +{ + "additional_special_tokens": [ + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "" + ], + "eos_token": { + "content": "", + "lstrip": false, + "normalized": true, + "rstrip": false, + "single_word": false + }, + "pad_token": { + "content": "", + "lstrip": false, + "normalized": true, + "rstrip": false, + "single_word": false + }, + "unk_token": { + "content": "", + "lstrip": false, + "normalized": true, + "rstrip": false, + "single_word": false + } +} diff --git a/comfy/text_encoders/byt5_tokenizer/tokenizer_config.json b/comfy/text_encoders/byt5_tokenizer/tokenizer_config.json new file mode 100644 index 000000000..5b1fe24c1 --- /dev/null +++ b/comfy/text_encoders/byt5_tokenizer/tokenizer_config.json @@ -0,0 +1,1163 @@ +{ + "added_tokens_decoder": { + "0": { + "content": "", + "lstrip": false, + "normalized": true, + "rstrip": false, + "single_word": false, + "special": true + }, + "1": { + "content": "", + "lstrip": false, + "normalized": true, + "rstrip": false, + "single_word": false, + "special": true + }, + "2": { + "content": "", + "lstrip": false, + "normalized": true, + "rstrip": false, + 
"single_word": false, + "special": true + }, + "259": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "260": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "261": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "262": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "263": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "264": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "265": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "266": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "267": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "268": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "269": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "270": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "271": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "272": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "273": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "274": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "275": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "276": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "277": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "278": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "279": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "280": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "281": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "282": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "283": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "284": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": 
false, + "special": true + }, + "285": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "286": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "287": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "288": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "289": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "290": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "291": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "292": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "293": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "294": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "295": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "296": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "297": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "298": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "299": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "300": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "301": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "302": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "303": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "304": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "305": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "306": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "307": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "308": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "309": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "310": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + 
"special": true + }, + "311": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "312": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "313": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "314": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "315": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "316": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "317": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "318": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "319": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "320": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "321": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "322": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "323": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "324": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "325": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "326": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "327": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "328": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "329": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "330": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "331": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "332": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "333": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "334": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "335": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "336": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + 
}, + "337": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "338": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "339": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "340": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "341": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "342": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "343": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "344": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "345": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "346": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "347": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "348": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "349": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "350": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "351": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "352": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "353": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "354": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "355": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "356": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "357": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "358": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "359": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "360": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "361": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "362": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "363": { + 
"content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "364": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "365": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "366": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "367": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "368": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "369": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "370": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "371": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "372": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "373": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "374": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "375": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "376": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "377": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "378": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "379": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "380": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "381": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "382": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "383": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + } + }, + "additional_special_tokens": [ + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + 
"", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "" + ], + "clean_up_tokenization_spaces": false, + "eos_token": "", + "extra_ids": 0, + "extra_special_tokens": {}, + "model_max_length": 1000000000000000019884624838656, + "pad_token": "", + "tokenizer_class": "ByT5Tokenizer", + "unk_token": "" +} diff --git a/comfy/text_encoders/hunyuan_image.py b/comfy/text_encoders/hunyuan_image.py new file mode 100644 index 000000000..be396cae7 --- /dev/null +++ b/comfy/text_encoders/hunyuan_image.py @@ -0,0 +1,100 @@ +from comfy import sd1_clip +import comfy.text_encoders.llama +from .qwen_image import QwenImageTokenizer, QwenImageTEModel +from transformers import ByT5Tokenizer +import os +import re + +class ByT5SmallTokenizer(sd1_clip.SDTokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "byt5_tokenizer") + super().__init__(tokenizer_path, pad_with_end=False, embedding_size=1472, embedding_key='byt5_small', tokenizer_class=ByT5Tokenizer, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_data=tokenizer_data) + +class HunyuanImageTokenizer(QwenImageTokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data) + self.llama_template = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>" + # self.llama_template_images = "{}" + self.byt5 = ByT5SmallTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data) + + def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs): + out = super().tokenize_with_weights(text, return_word_ids, **kwargs) + + # ByT5 processing for HunyuanImage + text_prompt_texts = [] + pattern_quote_single = r'\'(.*?)\'' + pattern_quote_double = r'\"(.*?)\"' + pattern_quote_chinese_single = r'‘(.*?)’' + pattern_quote_chinese_double = r'“(.*?)”' + + matches_quote_single = re.findall(pattern_quote_single, text) + matches_quote_double = re.findall(pattern_quote_double, text) + matches_quote_chinese_single = re.findall(pattern_quote_chinese_single, text) + matches_quote_chinese_double = re.findall(pattern_quote_chinese_double, text) + + text_prompt_texts.extend(matches_quote_single) + text_prompt_texts.extend(matches_quote_double) + text_prompt_texts.extend(matches_quote_chinese_single) + text_prompt_texts.extend(matches_quote_chinese_double) + + if len(text_prompt_texts) > 0: + out['byt5'] = self.byt5.tokenize_with_weights(''.join(map(lambda a: 'Text "{}". 
'.format(a), text_prompt_texts)), return_word_ids, **kwargs) + return out + +class Qwen25_7BVLIModel(sd1_clip.SDClipModel): + def __init__(self, device="cpu", layer="hidden", layer_idx=-3, dtype=None, attention_mask=True, model_options={}): + llama_scaled_fp8 = model_options.get("qwen_scaled_fp8", None) + if llama_scaled_fp8 is not None: + model_options = model_options.copy() + model_options["scaled_fp8"] = llama_scaled_fp8 + super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Qwen25_7BVLI, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options) + + +class ByT5SmallModel(sd1_clip.SDClipModel): + def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, model_options={}): + textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "byt5_config_small_glyph.json") + super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, model_options=model_options, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=True, zero_out_masked=True) + + +class HunyuanImageTEModel(QwenImageTEModel): + def __init__(self, byt5=True, device="cpu", dtype=None, model_options={}): + super(QwenImageTEModel, self).__init__(device=device, dtype=dtype, name="qwen25_7b", clip_model=Qwen25_7BVLIModel, model_options=model_options) + + if byt5: + self.byt5_small = ByT5SmallModel(device=device, dtype=dtype, model_options=model_options) + else: + self.byt5_small = None + + def encode_token_weights(self, token_weight_pairs): + cond, p, extra = super().encode_token_weights(token_weight_pairs) + if self.byt5_small is not None and "byt5" in token_weight_pairs: + out = self.byt5_small.encode_token_weights(token_weight_pairs["byt5"]) + extra["conditioning_byt5small"] = out[0] + return cond, p, extra + + def set_clip_options(self, options): + super().set_clip_options(options) + if self.byt5_small is not None: + self.byt5_small.set_clip_options(options) + + def reset_clip_options(self): + super().reset_clip_options() + if self.byt5_small is not None: + self.byt5_small.reset_clip_options() + + def load_sd(self, sd): + if "encoder.block.0.layer.0.SelfAttention.o.weight" in sd: + return self.byt5_small.load_sd(sd) + else: + return super().load_sd(sd) + +def te(byt5=True, dtype_llama=None, llama_scaled_fp8=None): + class QwenImageTEModel_(HunyuanImageTEModel): + def __init__(self, device="cpu", dtype=None, model_options={}): + if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options: + model_options = model_options.copy() + model_options["qwen_scaled_fp8"] = llama_scaled_fp8 + if dtype_llama is not None: + dtype = dtype_llama + super().__init__(byt5=byt5, device=device, dtype=dtype, model_options=model_options) + return QwenImageTEModel_ diff --git a/comfy_extras/nodes_hunyuan.py b/comfy_extras/nodes_hunyuan.py index d7278e7a7..ce031ceb2 100644 --- a/comfy_extras/nodes_hunyuan.py +++ b/comfy_extras/nodes_hunyuan.py @@ -113,6 +113,20 @@ class HunyuanImageToVideo: out_latent["samples"] = latent return (positive, out_latent) +class EmptyHunyuanImageLatent: + @classmethod + def INPUT_TYPES(s): + return {"required": { "width": ("INT", {"default": 2048, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}), + "height": ("INT", {"default": 2048, "min": 64, "max": 
nodes.MAX_RESOLUTION, "step": 32}), + "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})}} + RETURN_TYPES = ("LATENT",) + FUNCTION = "generate" + + CATEGORY = "latent" + + def generate(self, width, height, batch_size=1): + latent = torch.zeros([batch_size, 64, height // 32, width // 32], device=comfy.model_management.intermediate_device()) + return ({"samples":latent}, ) NODE_CLASS_MAPPINGS = { @@ -120,4 +134,5 @@ NODE_CLASS_MAPPINGS = { "TextEncodeHunyuanVideo_ImageToVideo": TextEncodeHunyuanVideo_ImageToVideo, "EmptyHunyuanLatentVideo": EmptyHunyuanLatentVideo, "HunyuanImageToVideo": HunyuanImageToVideo, + "EmptyHunyuanImageLatent": EmptyHunyuanImageLatent, } diff --git a/nodes.py b/nodes.py index 6c2f9dd14..2befb4b75 100644 --- a/nodes.py +++ b/nodes.py @@ -925,7 +925,7 @@ class CLIPLoader: @classmethod def INPUT_TYPES(s): return {"required": { "clip_name": (folder_paths.get_filename_list("text_encoders"), ), - "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image"], ), + "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image"], ), }, "optional": { "device": (["default", "cpu"], {"advanced": True}), @@ -953,7 +953,7 @@ class DualCLIPLoader: def INPUT_TYPES(s): return {"required": { "clip_name1": (folder_paths.get_filename_list("text_encoders"), ), "clip_name2": (folder_paths.get_filename_list("text_encoders"), ), - "type": (["sdxl", "sd3", "flux", "hunyuan_video", "hidream"], ), + "type": (["sdxl", "sd3", "flux", "hunyuan_video", "hidream", "hunyuan_image"], ), }, "optional": { "device": (["default", "cpu"], {"advanced": True}), @@ -963,7 +963,7 @@ class DualCLIPLoader: CATEGORY = "advanced/loaders" - DESCRIPTION = "[Recipes]\n\nsdxl: clip-l, clip-g\nsd3: clip-l, clip-g / clip-l, t5 / clip-g, t5\nflux: clip-l, t5\nhidream: at least one of t5 or llama, recommended t5 and llama" + DESCRIPTION = "[Recipes]\n\nsdxl: clip-l, clip-g\nsd3: clip-l, clip-g / clip-l, t5 / clip-g, t5\nflux: clip-l, t5\nhidream: at least one of t5 or llama, recommended t5 and llama\nhunyuan_image: qwen2.5vl 7b and byt5 small" def load_clip(self, clip_name1, clip_name2, type, device="default"): clip_type = getattr(comfy.sd.CLIPType, type.upper(), comfy.sd.CLIPType.STABLE_DIFFUSION) From 70fc0425b36515926c6414aee9f2269b27880cc2 Mon Sep 17 00:00:00 2001 From: ComfyUI Wiki Date: Wed, 10 Sep 2025 14:09:16 +0800 Subject: [PATCH 178/325] Update template to 0.1.76 (#9793) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 3008a5dc3..ea1931d78 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.25.11 -comfyui-workflow-templates==0.1.75 +comfyui-workflow-templates==0.1.76 comfyui-embedded-docs==0.2.6 torch torchsde From 543888d3d84a6ec4c4273838d5179845840e3226 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 9 Sep 2025 23:15:34 -0700 Subject: [PATCH 179/325] Fix lowvram issue with hunyuan image vae. 
(#9794) --- comfy/ldm/hunyuan_video/vae.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/comfy/ldm/hunyuan_video/vae.py b/comfy/ldm/hunyuan_video/vae.py index 8d406089b..40c12b183 100644 --- a/comfy/ldm/hunyuan_video/vae.py +++ b/comfy/ldm/hunyuan_video/vae.py @@ -65,7 +65,7 @@ class Encoder(nn.Module): self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv2d) self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=ops.Conv2d) - self.norm_out = nn.GroupNorm(32, ch, 1e-6, True) + self.norm_out = ops.GroupNorm(32, ch, 1e-6, True) self.conv_out = ops.Conv2d(ch, z_channels << 1, 3, 1, 1) def forward(self, x): @@ -120,7 +120,7 @@ class Decoder(nn.Module): ch = nxt self.up.append(stage) - self.norm_out = nn.GroupNorm(32, ch, 1e-6, True) + self.norm_out = ops.GroupNorm(32, ch, 1e-6, True) self.conv_out = ops.Conv2d(ch, out_channels, 3, 1, 1) def forward(self, z): From de44b95db6c7ef107f26e7edf30748b608afaa48 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Wed, 10 Sep 2025 12:06:47 +0300 Subject: [PATCH 180/325] add StabilityAudio API nodes (#9749) --- comfy_api_nodes/apinode_utils.py | 65 +++++ comfy_api_nodes/apis/stability_api.py | 22 ++ comfy_api_nodes/nodes_stability.py | 312 ++++++++++++++++++++++- comfy_api_nodes/util/validation_utils.py | 20 +- 4 files changed, 415 insertions(+), 4 deletions(-) diff --git a/comfy_api_nodes/apinode_utils.py b/comfy_api_nodes/apinode_utils.py index f953f86df..37438f835 100644 --- a/comfy_api_nodes/apinode_utils.py +++ b/comfy_api_nodes/apinode_utils.py @@ -518,6 +518,71 @@ async def upload_audio_to_comfyapi( return await upload_file_to_comfyapi(audio_bytes_io, filename, mime_type, auth_kwargs) +def f32_pcm(wav: torch.Tensor) -> torch.Tensor: + """Convert audio to float 32 bits PCM format. Copy-paste from nodes_audio.py file.""" + if wav.dtype.is_floating_point: + return wav + elif wav.dtype == torch.int16: + return wav.float() / (2 ** 15) + elif wav.dtype == torch.int32: + return wav.float() / (2 ** 31) + raise ValueError(f"Unsupported wav dtype: {wav.dtype}") + + +def audio_bytes_to_audio_input(audio_bytes: bytes,) -> dict: + """ + Decode any common audio container from bytes using PyAV and return + a Comfy AUDIO dict: {"waveform": [1, C, T] float32, "sample_rate": int}. 
+ """ + with av.open(io.BytesIO(audio_bytes)) as af: + if not af.streams.audio: + raise ValueError("No audio stream found in response.") + stream = af.streams.audio[0] + + in_sr = int(stream.codec_context.sample_rate) + out_sr = in_sr + + frames: list[torch.Tensor] = [] + n_channels = stream.channels or 1 + + for frame in af.decode(streams=stream.index): + arr = frame.to_ndarray() # shape can be [C, T] or [T, C] or [T] + buf = torch.from_numpy(arr) + if buf.ndim == 1: + buf = buf.unsqueeze(0) # [T] -> [1, T] + elif buf.shape[0] != n_channels and buf.shape[-1] == n_channels: + buf = buf.transpose(0, 1).contiguous() # [T, C] -> [C, T] + elif buf.shape[0] != n_channels: + buf = buf.reshape(-1, n_channels).t().contiguous() # fallback to [C, T] + frames.append(buf) + + if not frames: + raise ValueError("Decoded zero audio frames.") + + wav = torch.cat(frames, dim=1) # [C, T] + wav = f32_pcm(wav) + return {"waveform": wav.unsqueeze(0).contiguous(), "sample_rate": out_sr} + + +def audio_input_to_mp3(audio: AudioInput) -> io.BytesIO: + waveform = audio["waveform"].cpu() + + output_buffer = io.BytesIO() + output_container = av.open(output_buffer, mode='w', format="mp3") + + out_stream = output_container.add_stream("libmp3lame", rate=audio["sample_rate"]) + out_stream.bit_rate = 320000 + + frame = av.AudioFrame.from_ndarray(waveform.movedim(0, 1).reshape(1, -1).float().numpy(), format='flt', layout='mono' if waveform.shape[0] == 1 else 'stereo') + frame.sample_rate = audio["sample_rate"] + frame.pts = 0 + output_container.mux(out_stream.encode(frame)) + output_container.mux(out_stream.encode(None)) + output_container.close() + output_buffer.seek(0) + return output_buffer + + def audio_to_base64_string( audio: AudioInput, container_format: str = "mp4", codec_name: str = "aac" ) -> str: diff --git a/comfy_api_nodes/apis/stability_api.py b/comfy_api_nodes/apis/stability_api.py index 47c87daec..718360187 100644 --- a/comfy_api_nodes/apis/stability_api.py +++ b/comfy_api_nodes/apis/stability_api.py @@ -125,3 +125,25 @@ class StabilityResultsGetResponse(BaseModel): class StabilityAsyncResponse(BaseModel): id: Optional[str] = Field(None) + + +class StabilityTextToAudioRequest(BaseModel): + model: str = Field(...) + prompt: str = Field(...) 
+ duration: int = Field(190, ge=1, le=190) + seed: int = Field(0, ge=0, le=4294967294) + steps: int = Field(8, ge=4, le=8) + output_format: str = Field("wav") + + +class StabilityAudioToAudioRequest(StabilityTextToAudioRequest): + strength: float = Field(0.01, ge=0.01, le=1.0) + + +class StabilityAudioInpaintRequest(StabilityTextToAudioRequest): + mask_start: int = Field(30, ge=0, le=190) + mask_end: int = Field(190, ge=0, le=190) + + +class StabilityAudioResponse(BaseModel): + audio: Optional[str] = Field(None) diff --git a/comfy_api_nodes/nodes_stability.py b/comfy_api_nodes/nodes_stability.py index e05cb6bb2..5ba5ed986 100644 --- a/comfy_api_nodes/nodes_stability.py +++ b/comfy_api_nodes/nodes_stability.py @@ -2,7 +2,7 @@ from inspect import cleandoc from typing import Optional from typing_extensions import override -from comfy_api.latest import ComfyExtension, io as comfy_io +from comfy_api.latest import ComfyExtension, Input, io as comfy_io from comfy_api_nodes.apis.stability_api import ( StabilityUpscaleConservativeRequest, StabilityUpscaleCreativeRequest, @@ -15,6 +15,10 @@ from comfy_api_nodes.apis.stability_api import ( Stability_SD3_5_Model, Stability_SD3_5_GenerationMode, get_stability_style_presets, + StabilityTextToAudioRequest, + StabilityAudioToAudioRequest, + StabilityAudioInpaintRequest, + StabilityAudioResponse, ) from comfy_api_nodes.apis.client import ( ApiEndpoint, @@ -27,7 +31,10 @@ from comfy_api_nodes.apinode_utils import ( bytesio_to_image_tensor, tensor_to_bytesio, validate_string, + audio_bytes_to_audio_input, + audio_input_to_mp3, ) +from comfy_api_nodes.util.validation_utils import validate_audio_duration import torch import base64 @@ -649,6 +656,306 @@ class StabilityUpscaleFastNode(comfy_io.ComfyNode): return comfy_io.NodeOutput(returned_image) +class StabilityTextToAudio(comfy_io.ComfyNode): + """Generates high-quality music and sound effects from text descriptions.""" + + @classmethod + def define_schema(cls): + return comfy_io.Schema( + node_id="StabilityTextToAudio", + display_name="Stability AI Text To Audio", + category="api node/audio/Stability AI", + description=cleandoc(cls.__doc__ or ""), + inputs=[ + comfy_io.Combo.Input( + "model", + options=["stable-audio-2.5"], + ), + comfy_io.String.Input("prompt", multiline=True, default=""), + comfy_io.Int.Input( + "duration", + default=190, + min=1, + max=190, + step=1, + tooltip="Controls the duration in seconds of the generated audio.", + optional=True, + ), + comfy_io.Int.Input( + "seed", + default=0, + min=0, + max=4294967294, + step=1, + display_mode=comfy_io.NumberDisplay.number, + control_after_generate=True, + tooltip="The random seed used for generation.", + optional=True, + ), + comfy_io.Int.Input( + "steps", + default=8, + min=4, + max=8, + step=1, + tooltip="Controls the number of sampling steps.", + optional=True, + ), + ], + outputs=[ + comfy_io.Audio.Output(), + ], + hidden=[ + comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + comfy_io.Hidden.unique_id, + ], + is_api_node=True, + ) + + @classmethod + async def execute(cls, model: str, prompt: str, duration: int, seed: int, steps: int) -> comfy_io.NodeOutput: + validate_string(prompt, max_length=10000) + payload = StabilityTextToAudioRequest(prompt=prompt, model=model, duration=duration, seed=seed, steps=steps) + operation = SynchronousOperation( + endpoint=ApiEndpoint( + path="/proxy/stability/v2beta/audio/stable-audio-2/text-to-audio", + method=HttpMethod.POST, + request_model=StabilityTextToAudioRequest, + 
response_model=StabilityAudioResponse, + ), + request=payload, + content_type="multipart/form-data", + auth_kwargs= { + "auth_token": cls.hidden.auth_token_comfy_org, + "comfy_api_key": cls.hidden.api_key_comfy_org, + }, + ) + response_api = await operation.execute() + if not response_api.audio: + raise ValueError("No audio file was received in response.") + return comfy_io.NodeOutput(audio_bytes_to_audio_input(base64.b64decode(response_api.audio))) + + +class StabilityAudioToAudio(comfy_io.ComfyNode): + """Transforms existing audio samples into new high-quality compositions using text instructions.""" + + @classmethod + def define_schema(cls): + return comfy_io.Schema( + node_id="StabilityAudioToAudio", + display_name="Stability AI Audio To Audio", + category="api node/audio/Stability AI", + description=cleandoc(cls.__doc__ or ""), + inputs=[ + comfy_io.Combo.Input( + "model", + options=["stable-audio-2.5"], + ), + comfy_io.String.Input("prompt", multiline=True, default=""), + comfy_io.Audio.Input("audio", tooltip="Audio must be between 6 and 190 seconds long."), + comfy_io.Int.Input( + "duration", + default=190, + min=1, + max=190, + step=1, + tooltip="Controls the duration in seconds of the generated audio.", + optional=True, + ), + comfy_io.Int.Input( + "seed", + default=0, + min=0, + max=4294967294, + step=1, + display_mode=comfy_io.NumberDisplay.number, + control_after_generate=True, + tooltip="The random seed used for generation.", + optional=True, + ), + comfy_io.Int.Input( + "steps", + default=8, + min=4, + max=8, + step=1, + tooltip="Controls the number of sampling steps.", + optional=True, + ), + comfy_io.Float.Input( + "strength", + default=1, + min=0.01, + max=1.0, + step=0.01, + display_mode=comfy_io.NumberDisplay.slider, + tooltip="Parameter controls how much influence the audio parameter has on the generated audio.", + optional=True, + ), + ], + outputs=[ + comfy_io.Audio.Output(), + ], + hidden=[ + comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + comfy_io.Hidden.unique_id, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, model: str, prompt: str, audio: Input.Audio, duration: int, seed: int, steps: int, strength: float + ) -> comfy_io.NodeOutput: + validate_string(prompt, max_length=10000) + validate_audio_duration(audio, 6, 190) + payload = StabilityAudioToAudioRequest( + prompt=prompt, model=model, duration=duration, seed=seed, steps=steps, strength=strength + ) + operation = SynchronousOperation( + endpoint=ApiEndpoint( + path="/proxy/stability/v2beta/audio/stable-audio-2/audio-to-audio", + method=HttpMethod.POST, + request_model=StabilityAudioToAudioRequest, + response_model=StabilityAudioResponse, + ), + request=payload, + content_type="multipart/form-data", + files={"audio": audio_input_to_mp3(audio)}, + auth_kwargs= { + "auth_token": cls.hidden.auth_token_comfy_org, + "comfy_api_key": cls.hidden.api_key_comfy_org, + }, + ) + response_api = await operation.execute() + if not response_api.audio: + raise ValueError("No audio file was received in response.") + return comfy_io.NodeOutput(audio_bytes_to_audio_input(base64.b64decode(response_api.audio))) + + +class StabilityAudioInpaint(comfy_io.ComfyNode): + """Transforms part of existing audio sample using text instructions.""" + + @classmethod + def define_schema(cls): + return comfy_io.Schema( + node_id="StabilityAudioInpaint", + display_name="Stability AI Audio Inpaint", + category="api node/audio/Stability AI", + description=cleandoc(cls.__doc__ or ""), + inputs=[ + 
comfy_io.Combo.Input( + "model", + options=["stable-audio-2.5"], + ), + comfy_io.String.Input("prompt", multiline=True, default=""), + comfy_io.Audio.Input("audio", tooltip="Audio must be between 6 and 190 seconds long."), + comfy_io.Int.Input( + "duration", + default=190, + min=1, + max=190, + step=1, + tooltip="Controls the duration in seconds of the generated audio.", + optional=True, + ), + comfy_io.Int.Input( + "seed", + default=0, + min=0, + max=4294967294, + step=1, + display_mode=comfy_io.NumberDisplay.number, + control_after_generate=True, + tooltip="The random seed used for generation.", + optional=True, + ), + comfy_io.Int.Input( + "steps", + default=8, + min=4, + max=8, + step=1, + tooltip="Controls the number of sampling steps.", + optional=True, + ), + comfy_io.Int.Input( + "mask_start", + default=30, + min=0, + max=190, + step=1, + optional=True, + ), + comfy_io.Int.Input( + "mask_end", + default=190, + min=0, + max=190, + step=1, + optional=True, + ), + ], + outputs=[ + comfy_io.Audio.Output(), + ], + hidden=[ + comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + comfy_io.Hidden.unique_id, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, + model: str, + prompt: str, + audio: Input.Audio, + duration: int, + seed: int, + steps: int, + mask_start: int, + mask_end: int, + ) -> comfy_io.NodeOutput: + validate_string(prompt, max_length=10000) + if mask_end <= mask_start: + raise ValueError(f"Value of mask_end({mask_end}) should be greater then mask_start({mask_start})") + validate_audio_duration(audio, 6, 190) + + payload = StabilityAudioInpaintRequest( + prompt=prompt, + model=model, + duration=duration, + seed=seed, + steps=steps, + mask_start=mask_start, + mask_end=mask_end, + ) + operation = SynchronousOperation( + endpoint=ApiEndpoint( + path="/proxy/stability/v2beta/audio/stable-audio-2/inpaint", + method=HttpMethod.POST, + request_model=StabilityAudioInpaintRequest, + response_model=StabilityAudioResponse, + ), + request=payload, + content_type="multipart/form-data", + files={"audio": audio_input_to_mp3(audio)}, + auth_kwargs={ + "auth_token": cls.hidden.auth_token_comfy_org, + "comfy_api_key": cls.hidden.api_key_comfy_org, + }, + ) + response_api = await operation.execute() + if not response_api.audio: + raise ValueError("No audio file was received in response.") + return comfy_io.NodeOutput(audio_bytes_to_audio_input(base64.b64decode(response_api.audio))) + + class StabilityExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[comfy_io.ComfyNode]]: @@ -658,6 +965,9 @@ class StabilityExtension(ComfyExtension): StabilityUpscaleConservativeNode, StabilityUpscaleCreativeNode, StabilityUpscaleFastNode, + StabilityTextToAudio, + StabilityAudioToAudio, + StabilityAudioInpaint, ] diff --git a/comfy_api_nodes/util/validation_utils.py b/comfy_api_nodes/util/validation_utils.py index 606b794bf..ca913e9b3 100644 --- a/comfy_api_nodes/util/validation_utils.py +++ b/comfy_api_nodes/util/validation_utils.py @@ -2,7 +2,7 @@ import logging from typing import Optional import torch -from comfy_api.input.video_types import VideoInput +from comfy_api.latest import Input def get_image_dimensions(image: torch.Tensor) -> tuple[int, int]: @@ -101,7 +101,7 @@ def validate_aspect_ratio_closeness( def validate_video_dimensions( - video: VideoInput, + video: Input.Video, min_width: Optional[int] = None, max_width: Optional[int] = None, min_height: Optional[int] = None, @@ -126,7 +126,7 @@ def validate_video_dimensions( def 
validate_video_duration( - video: VideoInput, + video: Input.Video, min_duration: Optional[float] = None, max_duration: Optional[float] = None, ): @@ -151,3 +151,17 @@ def get_number_of_images(images): if isinstance(images, torch.Tensor): return images.shape[0] if images.ndim >= 4 else 1 return len(images) + + +def validate_audio_duration( + audio: Input.Audio, + min_duration: Optional[float] = None, + max_duration: Optional[float] = None, +) -> None: + sr = int(audio["sample_rate"]) + dur = int(audio["waveform"].shape[-1]) / sr + eps = 1.0 / sr + if min_duration is not None and dur + eps < min_duration: + raise ValueError(f"Audio duration must be at least {min_duration}s, got {dur + eps:.2f}s") + if max_duration is not None and dur - eps > max_duration: + raise ValueError(f"Audio duration must be at most {max_duration}s, got {dur - eps:.2f}s") From 8d7c930246bd33c32eb957b01ab0d364af6e81c0 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 10 Sep 2025 10:51:02 -0400 Subject: [PATCH 181/325] ComfyUI version v0.3.58 --- comfyui_version.py | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/comfyui_version.py b/comfyui_version.py index 4cc3c8647..37361bd75 100644 --- a/comfyui_version.py +++ b/comfyui_version.py @@ -1,3 +1,3 @@ # This file is automatically generated by the build process when version is # updated in pyproject.toml. -__version__ = "0.3.57" +__version__ = "0.3.58" diff --git a/pyproject.toml b/pyproject.toml index d75cd04a2..f02ab9126 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ComfyUI" -version = "0.3.57" +version = "0.3.58" readme = "README.md" license = { file = "LICENSE" } requires-python = ">=3.9" From 9b0553809cbac084aac0576892aca3e448eb07c7 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Thu, 11 Sep 2025 00:13:18 +0300 Subject: [PATCH 182/325] add new ByteDanceSeedream (4.0) node (#9802) --- comfy_api_nodes/nodes_bytedance.py | 208 ++++++++++++++++++++++++++++- 1 file changed, 207 insertions(+), 1 deletion(-) diff --git a/comfy_api_nodes/nodes_bytedance.py b/comfy_api_nodes/nodes_bytedance.py index 064df2d10..369a3a4fe 100644 --- a/comfy_api_nodes/nodes_bytedance.py +++ b/comfy_api_nodes/nodes_bytedance.py @@ -77,6 +77,22 @@ class Image2ImageTaskCreationRequest(BaseModel): watermark: Optional[bool] = Field(True) +class Seedream4Options(BaseModel): + max_images: int = Field(15) + + +class Seedream4TaskCreationRequest(BaseModel): + model: str = Field("seedream-4-0-250828") + prompt: str = Field(...) + response_format: str = Field("url") + image: Optional[list[str]] = Field(None, description="Image URLs") + size: str = Field(...) + seed: int = Field(..., ge=0, le=2147483647) + sequential_image_generation: str = Field("disabled") + sequential_image_generation_options: Seedream4Options = Field(Seedream4Options(max_images=15)) + watermark: bool = Field(True) + + class ImageTaskCreationResponse(BaseModel): model: str = Field(...) 
created: int = Field(..., description="Unix timestamp (in seconds) indicating time when the request was created.") @@ -143,6 +159,19 @@ RECOMMENDED_PRESETS = [ ("Custom", None, None), ] +RECOMMENDED_PRESETS_SEEDREAM_4 = [ + ("2048x2048 (1:1)", 2048, 2048), + ("2304x1728 (4:3)", 2304, 1728), + ("1728x2304 (3:4)", 1728, 2304), + ("2560x1440 (16:9)", 2560, 1440), + ("1440x2560 (9:16)", 1440, 2560), + ("2496x1664 (3:2)", 2496, 1664), + ("1664x2496 (2:3)", 1664, 2496), + ("3024x1296 (21:9)", 3024, 1296), + ("4096x4096 (1:1)", 4096, 4096), + ("Custom", None, None), +] + # The time in this dictionary are given for 10 seconds duration. VIDEO_TASKS_EXECUTION_TIME = { "seedance-1-0-lite-t2v-250428": { @@ -348,7 +377,7 @@ class ByteDanceImageEditNode(comfy_io.ComfyNode): return comfy_io.Schema( node_id="ByteDanceImageEditNode", display_name="ByteDance Image Edit", - category="api node/video/ByteDance", + category="api node/image/ByteDance", description="Edit images using ByteDance models via api based on prompt", inputs=[ comfy_io.Combo.Input( @@ -451,6 +480,182 @@ class ByteDanceImageEditNode(comfy_io.ComfyNode): return comfy_io.NodeOutput(await download_url_to_image_tensor(get_image_url_from_response(response))) +class ByteDanceSeedreamNode(comfy_io.ComfyNode): + + @classmethod + def define_schema(cls): + return comfy_io.Schema( + node_id="ByteDanceSeedreamNode", + display_name="ByteDance Seedream 4", + category="api node/image/ByteDance", + description="Unified text-to-image generation and precise single-sentence editing at up to 4K resolution.", + inputs=[ + comfy_io.Combo.Input( + "model", + options=["seedream-4-0-250828"], + tooltip="Model name", + ), + comfy_io.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Text prompt for creating or editing an image.", + ), + comfy_io.Image.Input( + "image", + tooltip="Input image(s) for image-to-image generation. " + "List of 1-10 images for single or multi-reference generation.", + optional=True, + ), + comfy_io.Combo.Input( + "size_preset", + options=[label for label, _, _ in RECOMMENDED_PRESETS_SEEDREAM_4], + tooltip="Pick a recommended size. Select Custom to use the width and height below.", + ), + comfy_io.Int.Input( + "width", + default=2048, + min=1024, + max=4096, + step=64, + tooltip="Custom width for image. Value is working only if `size_preset` is set to `Custom`", + optional=True, + ), + comfy_io.Int.Input( + "height", + default=2048, + min=1024, + max=4096, + step=64, + tooltip="Custom height for image. Value is working only if `size_preset` is set to `Custom`", + optional=True, + ), + comfy_io.Combo.Input( + "sequential_image_generation", + options=["disabled", "auto"], + tooltip="Group image generation mode. " + "'disabled' generates a single image. " + "'auto' lets the model decide whether to generate multiple related images " + "(e.g., story scenes, character variations).", + optional=True, + ), + comfy_io.Int.Input( + "max_images", + default=1, + min=1, + max=15, + step=1, + display_mode=comfy_io.NumberDisplay.number, + tooltip="Maximum number of images to generate when sequential_image_generation='auto'. 
" + "Total images (input + generated) cannot exceed 15.", + optional=True, + ), + comfy_io.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + step=1, + display_mode=comfy_io.NumberDisplay.number, + control_after_generate=True, + tooltip="Seed to use for generation.", + optional=True, + ), + comfy_io.Boolean.Input( + "watermark", + default=True, + tooltip="Whether to add an \"AI generated\" watermark to the image.", + optional=True, + ), + ], + outputs=[ + comfy_io.Image.Output(), + ], + hidden=[ + comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + comfy_io.Hidden.unique_id, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, + model: str, + prompt: str, + image: torch.Tensor = None, + size_preset: str = RECOMMENDED_PRESETS_SEEDREAM_4[0][0], + width: int = 2048, + height: int = 2048, + sequential_image_generation: str = "disabled", + max_images: int = 1, + seed: int = 0, + watermark: bool = True, + ) -> comfy_io.NodeOutput: + validate_string(prompt, strip_whitespace=True, min_length=1) + w = h = None + for label, tw, th in RECOMMENDED_PRESETS_SEEDREAM_4: + if label == size_preset: + w, h = tw, th + break + + if w is None or h is None: + w, h = width, height + if not (1024 <= w <= 4096) or not (1024 <= h <= 4096): + raise ValueError( + f"Custom size out of range: {w}x{h}. " + "Both width and height must be between 1024 and 4096 pixels." + ) + n_input_images = get_number_of_images(image) if image is not None else 0 + if n_input_images > 10: + raise ValueError(f"Maximum of 10 reference images are supported, but {n_input_images} received.") + if sequential_image_generation == "auto" and n_input_images + max_images > 15: + raise ValueError( + "The maximum number of generated images plus the number of reference images cannot exceed 15." 
+ ) + auth_kwargs = { + "auth_token": cls.hidden.auth_token_comfy_org, + "comfy_api_key": cls.hidden.api_key_comfy_org, + } + reference_images_urls = [] + if n_input_images: + for i in image: + validate_image_aspect_ratio_range(i, (1, 3), (3, 1)) + reference_images_urls = (await upload_images_to_comfyapi( + image, + max_images=n_input_images, + mime_type="image/png", + auth_kwargs=auth_kwargs, + )) + payload = Seedream4TaskCreationRequest( + model=model, + prompt=prompt, + image=reference_images_urls, + size=f"{w}x{h}", + seed=seed, + sequential_image_generation=sequential_image_generation, + sequential_image_generation_options=Seedream4Options(max_images=max_images), + watermark=watermark, + ) + response = await SynchronousOperation( + endpoint=ApiEndpoint( + path=BYTEPLUS_IMAGE_ENDPOINT, + method=HttpMethod.POST, + request_model=Seedream4TaskCreationRequest, + response_model=ImageTaskCreationResponse, + ), + request=payload, + auth_kwargs=auth_kwargs, + ).execute() + + if len(response.data) == 1: + return comfy_io.NodeOutput(await download_url_to_image_tensor(get_image_url_from_response(response))) + return comfy_io.NodeOutput( + torch.cat([await download_url_to_image_tensor(str(i["url"])) for i in response.data]) + ) + + class ByteDanceTextToVideoNode(comfy_io.ComfyNode): @classmethod @@ -1001,6 +1206,7 @@ class ByteDanceExtension(ComfyExtension): return [ ByteDanceImageNode, ByteDanceImageEditNode, + ByteDanceSeedreamNode, ByteDanceTextToVideoNode, ByteDanceImageToVideoNode, ByteDanceFirstLastFrameNode, From df34f1549a431c85a6326e87075a206228697cde Mon Sep 17 00:00:00 2001 From: ComfyUI Wiki Date: Thu, 11 Sep 2025 05:16:41 +0800 Subject: [PATCH 183/325] Update template to 0.1.78 (#9806) * Update template to 0.1.77 * Update template to 0.1.78 --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index ea1931d78..d31df0fec 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.25.11 -comfyui-workflow-templates==0.1.76 +comfyui-workflow-templates==0.1.78 comfyui-embedded-docs==0.2.6 torch torchsde From 72212fef660bcd7d9702fa52011d089c027a64d8 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 10 Sep 2025 17:25:41 -0400 Subject: [PATCH 184/325] ComfyUI version 0.3.59 --- comfyui_version.py | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/comfyui_version.py b/comfyui_version.py index 37361bd75..ee58205f5 100644 --- a/comfyui_version.py +++ b/comfyui_version.py @@ -1,3 +1,3 @@ # This file is automatically generated by the build process when version is # updated in pyproject.toml. -__version__ = "0.3.58" +__version__ = "0.3.59" diff --git a/pyproject.toml b/pyproject.toml index f02ab9126..a7fc1a5a6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ComfyUI" -version = "0.3.58" +version = "0.3.59" readme = "README.md" license = { file = "LICENSE" } requires-python = ">=3.9" From e01e99d075852b94e93f27ea64ab862a49a7d2cc Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 10 Sep 2025 20:17:34 -0700 Subject: [PATCH 185/325] Support hunyuan image distilled model. 
(#9807) --- comfy/ldm/hunyuan_video/model.py | 14 ++++++++++++++ comfy/model_detection.py | 12 ++++++++++-- 2 files changed, 24 insertions(+), 2 deletions(-) diff --git a/comfy/ldm/hunyuan_video/model.py b/comfy/ldm/hunyuan_video/model.py index ca289c5bd..7732182a4 100644 --- a/comfy/ldm/hunyuan_video/model.py +++ b/comfy/ldm/hunyuan_video/model.py @@ -41,6 +41,7 @@ class HunyuanVideoParams: qkv_bias: bool guidance_embed: bool byt5: bool + meanflow: bool class SelfAttentionRef(nn.Module): @@ -256,6 +257,11 @@ class HunyuanVideo(nn.Module): else: self.byt5_in = None + if params.meanflow: + self.time_r_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations) + else: + self.time_r_in = None + if final_layer: self.final_layer = LastLayer(self.hidden_size, self.patch_size[-1], self.out_channels, dtype=dtype, device=device, operations=operations) @@ -282,6 +288,14 @@ class HunyuanVideo(nn.Module): img = self.img_in(img) vec = self.time_in(timestep_embedding(timesteps, 256, time_factor=1.0).to(img.dtype)) + if self.time_r_in is not None: + w = torch.where(transformer_options['sigmas'][0] == transformer_options['sample_sigmas'])[0] # This most likely could be improved + if len(w) > 0: + timesteps_r = transformer_options['sample_sigmas'][w[0] + 1] + timesteps_r = timesteps_r.unsqueeze(0).to(device=timesteps.device, dtype=timesteps.dtype) + vec_r = self.time_r_in(timestep_embedding(timesteps_r, 256, time_factor=1000.0).to(img.dtype)) + vec = (vec + vec_r) / 2 + if ref_latent is not None: ref_latent_ids = self.img_ids(ref_latent) ref_latent = self.img_in(ref_latent) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index dbcbe5f5a..fe983cede 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -142,12 +142,20 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["in_channels"] = in_w.shape[1] #SkyReels img2video has 32 input channels dit_config["patch_size"] = list(in_w.shape[2:]) dit_config["out_channels"] = out_w.shape[0] // math.prod(dit_config["patch_size"]) - if '{}vector_in.in_layer.weight'.format(key_prefix) in state_dict: + if any(s.startswith('{}vector_in.'.format(key_prefix)) for s in state_dict_keys): dit_config["vec_in_dim"] = 768 - dit_config["axes_dim"] = [16, 56, 56] else: dit_config["vec_in_dim"] = None + + if len(dit_config["patch_size"]) == 2: dit_config["axes_dim"] = [64, 64] + else: + dit_config["axes_dim"] = [16, 56, 56] + + if any(s.startswith('{}time_r_in.'.format(key_prefix)) for s in state_dict_keys): + dit_config["meanflow"] = True + else: + dit_config["meanflow"] = False dit_config["context_in_dim"] = state_dict['{}txt_in.input_embedder.weight'.format(key_prefix)].shape[1] dit_config["hidden_size"] = in_w.shape[0] From df6850fae8a75126cb7a645e38d58cebcfd51096 Mon Sep 17 00:00:00 2001 From: ComfyUI Wiki Date: Fri, 12 Sep 2025 02:59:26 +0800 Subject: [PATCH 186/325] Update template to 0.1.81 (#9811) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index d31df0fec..0e21967ef 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.25.11 -comfyui-workflow-templates==0.1.78 +comfyui-workflow-templates==0.1.81 comfyui-embedded-docs==0.2.6 torch torchsde From 18de0b28305fd8bf002d74e91c0630bd76b01d6b Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Thu, 11 Sep 2025 16:33:02 -0700 Subject: [PATCH 187/325] Fast 
preview for hunyuan image. (#9814) --- comfy/latent_formats.py | 68 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 68 insertions(+) diff --git a/comfy/latent_formats.py b/comfy/latent_formats.py index 859ae8421..f975b5e11 100644 --- a/comfy/latent_formats.py +++ b/comfy/latent_formats.py @@ -538,6 +538,74 @@ class HunyuanImage21(LatentFormat): latent_dimensions = 2 scale_factor = 0.75289 + latent_rgb_factors = [ + [-0.0154, -0.0397, -0.0521], + [ 0.0005, 0.0093, 0.0006], + [-0.0805, -0.0773, -0.0586], + [-0.0494, -0.0487, -0.0498], + [-0.0212, -0.0076, -0.0261], + [-0.0179, -0.0417, -0.0505], + [ 0.0158, 0.0310, 0.0239], + [ 0.0409, 0.0516, 0.0201], + [ 0.0350, 0.0553, 0.0036], + [-0.0447, -0.0327, -0.0479], + [-0.0038, -0.0221, -0.0365], + [-0.0423, -0.0718, -0.0654], + [ 0.0039, 0.0368, 0.0104], + [ 0.0655, 0.0217, 0.0122], + [ 0.0490, 0.1638, 0.2053], + [ 0.0932, 0.0829, 0.0650], + [-0.0186, -0.0209, -0.0135], + [-0.0080, -0.0076, -0.0148], + [-0.0284, -0.0201, 0.0011], + [-0.0642, -0.0294, -0.0777], + [-0.0035, 0.0076, -0.0140], + [ 0.0519, 0.0731, 0.0887], + [-0.0102, 0.0095, 0.0704], + [ 0.0068, 0.0218, -0.0023], + [-0.0726, -0.0486, -0.0519], + [ 0.0260, 0.0295, 0.0263], + [ 0.0250, 0.0333, 0.0341], + [ 0.0168, -0.0120, -0.0174], + [ 0.0226, 0.1037, 0.0114], + [ 0.2577, 0.1906, 0.1604], + [-0.0646, -0.0137, -0.0018], + [-0.0112, 0.0309, 0.0358], + [-0.0347, 0.0146, -0.0481], + [ 0.0234, 0.0179, 0.0201], + [ 0.0157, 0.0313, 0.0225], + [ 0.0423, 0.0675, 0.0524], + [-0.0031, 0.0027, -0.0255], + [ 0.0447, 0.0555, 0.0330], + [-0.0152, 0.0103, 0.0299], + [-0.0755, -0.0489, -0.0635], + [ 0.0853, 0.0788, 0.1017], + [-0.0272, -0.0294, -0.0471], + [ 0.0440, 0.0400, -0.0137], + [ 0.0335, 0.0317, -0.0036], + [-0.0344, -0.0621, -0.0984], + [-0.0127, -0.0630, -0.0620], + [-0.0648, 0.0360, 0.0924], + [-0.0781, -0.0801, -0.0409], + [ 0.0363, 0.0613, 0.0499], + [ 0.0238, 0.0034, 0.0041], + [-0.0135, 0.0258, 0.0310], + [ 0.0614, 0.1086, 0.0589], + [ 0.0428, 0.0350, 0.0205], + [ 0.0153, 0.0173, -0.0018], + [-0.0288, -0.0455, -0.0091], + [ 0.0344, 0.0109, -0.0157], + [-0.0205, -0.0247, -0.0187], + [ 0.0487, 0.0126, 0.0064], + [-0.0220, -0.0013, 0.0074], + [-0.0203, -0.0094, -0.0048], + [-0.0719, 0.0429, -0.0442], + [ 0.1042, 0.0497, 0.0356], + [-0.0659, -0.0578, -0.0280], + [-0.0060, -0.0322, -0.0234]] + + latent_rgb_factors_bias = [0.0007, -0.0256, -0.0206] + class Hunyuan3Dv2(LatentFormat): latent_channels = 64 latent_dimensions = 1 From 33bd9ed9cb941127b335244c6cc0a8cdc1ac1696 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Thu, 11 Sep 2025 21:43:20 -0700 Subject: [PATCH 188/325] Implement hunyuan image refiner model. 
(#9817) --- comfy/latent_formats.py | 5 + comfy/ldm/hunyuan_video/model.py | 11 +- comfy/ldm/hunyuan_video/vae_refiner.py | 268 ++++++++++++++++++++ comfy/ldm/models/autoencoder.py | 6 + comfy/ldm/modules/diffusionmodules/model.py | 10 +- comfy/model_base.py | 20 ++ comfy/sd.py | 17 +- comfy/supported_models.py | 19 +- comfy_extras/nodes_hunyuan.py | 23 ++ 9 files changed, 367 insertions(+), 12 deletions(-) create mode 100644 comfy/ldm/hunyuan_video/vae_refiner.py diff --git a/comfy/latent_formats.py b/comfy/latent_formats.py index f975b5e11..894540879 100644 --- a/comfy/latent_formats.py +++ b/comfy/latent_formats.py @@ -606,6 +606,11 @@ class HunyuanImage21(LatentFormat): latent_rgb_factors_bias = [0.0007, -0.0256, -0.0206] +class HunyuanImage21Refiner(LatentFormat): + latent_channels = 64 + latent_dimensions = 3 + scale_factor = 1.03682 + class Hunyuan3Dv2(LatentFormat): latent_channels = 64 latent_dimensions = 1 diff --git a/comfy/ldm/hunyuan_video/model.py b/comfy/ldm/hunyuan_video/model.py index 7732182a4..ca86b8bb1 100644 --- a/comfy/ldm/hunyuan_video/model.py +++ b/comfy/ldm/hunyuan_video/model.py @@ -278,6 +278,7 @@ class HunyuanVideo(nn.Module): guidance: Tensor = None, guiding_frame_index=None, ref_latent=None, + disable_time_r=False, control=None, transformer_options={}, ) -> Tensor: @@ -288,7 +289,7 @@ class HunyuanVideo(nn.Module): img = self.img_in(img) vec = self.time_in(timestep_embedding(timesteps, 256, time_factor=1.0).to(img.dtype)) - if self.time_r_in is not None: + if (self.time_r_in is not None) and (not disable_time_r): w = torch.where(transformer_options['sigmas'][0] == transformer_options['sample_sigmas'])[0] # This most likely could be improved if len(w) > 0: timesteps_r = transformer_options['sample_sigmas'][w[0] + 1] @@ -428,14 +429,14 @@ class HunyuanVideo(nn.Module): img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0) return repeat(img_ids, "h w c -> b (h w) c", b=bs) - def forward(self, x, timestep, context, y=None, txt_byt5=None, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, control=None, transformer_options={}, **kwargs): + def forward(self, x, timestep, context, y=None, txt_byt5=None, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, disable_time_r=False, control=None, transformer_options={}, **kwargs): return comfy.patcher_extension.WrapperExecutor.new_class_executor( self._forward, self, comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options) - ).execute(x, timestep, context, y, txt_byt5, guidance, attention_mask, guiding_frame_index, ref_latent, control, transformer_options, **kwargs) + ).execute(x, timestep, context, y, txt_byt5, guidance, attention_mask, guiding_frame_index, ref_latent, disable_time_r, control, transformer_options, **kwargs) - def _forward(self, x, timestep, context, y=None, txt_byt5=None, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, control=None, transformer_options={}, **kwargs): + def _forward(self, x, timestep, context, y=None, txt_byt5=None, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, disable_time_r=False, control=None, transformer_options={}, **kwargs): bs = x.shape[0] if len(self.patch_size) == 3: img_ids = self.img_ids(x) @@ -443,5 +444,5 @@ class HunyuanVideo(nn.Module): else: img_ids = self.img_ids_2d(x) txt_ids = torch.zeros((bs, context.shape[1], 2), 
device=x.device, dtype=x.dtype) - out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, txt_byt5, guidance, guiding_frame_index, ref_latent, control=control, transformer_options=transformer_options) + out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, txt_byt5, guidance, guiding_frame_index, ref_latent, disable_time_r=disable_time_r, control=control, transformer_options=transformer_options) return out diff --git a/comfy/ldm/hunyuan_video/vae_refiner.py b/comfy/ldm/hunyuan_video/vae_refiner.py new file mode 100644 index 000000000..e3fff9bbe --- /dev/null +++ b/comfy/ldm/hunyuan_video/vae_refiner.py @@ -0,0 +1,268 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, AttnBlock, VideoConv3d +import comfy.ops +import comfy.ldm.models.autoencoder +ops = comfy.ops.disable_weight_init + +class RMS_norm(nn.Module): + def __init__(self, dim): + super().__init__() + shape = (dim, 1, 1, 1) + self.scale = dim**0.5 + self.gamma = nn.Parameter(torch.empty(shape)) + + def forward(self, x): + return F.normalize(x, dim=1) * self.scale * self.gamma + +class DnSmpl(nn.Module): + def __init__(self, ic, oc, tds=True): + super().__init__() + fct = 2 * 2 * 2 if tds else 1 * 2 * 2 + assert oc % fct == 0 + self.conv = VideoConv3d(ic, oc // fct, kernel_size=3) + + self.tds = tds + self.gs = fct * ic // oc + + def forward(self, x): + r1 = 2 if self.tds else 1 + h = self.conv(x) + + if self.tds: + hf = h[:, :, :1, :, :] + b, c, f, ht, wd = hf.shape + hf = hf.reshape(b, c, f, ht // 2, 2, wd // 2, 2) + hf = hf.permute(0, 4, 6, 1, 2, 3, 5) + hf = hf.reshape(b, 2 * 2 * c, f, ht // 2, wd // 2) + hf = torch.cat([hf, hf], dim=1) + + hn = h[:, :, 1:, :, :] + b, c, frms, ht, wd = hn.shape + nf = frms // r1 + hn = hn.reshape(b, c, nf, r1, ht // 2, 2, wd // 2, 2) + hn = hn.permute(0, 3, 5, 7, 1, 2, 4, 6) + hn = hn.reshape(b, r1 * 2 * 2 * c, nf, ht // 2, wd // 2) + + h = torch.cat([hf, hn], dim=2) + + xf = x[:, :, :1, :, :] + b, ci, f, ht, wd = xf.shape + xf = xf.reshape(b, ci, f, ht // 2, 2, wd // 2, 2) + xf = xf.permute(0, 4, 6, 1, 2, 3, 5) + xf = xf.reshape(b, 2 * 2 * ci, f, ht // 2, wd // 2) + B, C, T, H, W = xf.shape + xf = xf.view(B, h.shape[1], self.gs // 2, T, H, W).mean(dim=2) + + xn = x[:, :, 1:, :, :] + b, ci, frms, ht, wd = xn.shape + nf = frms // r1 + xn = xn.reshape(b, ci, nf, r1, ht // 2, 2, wd // 2, 2) + xn = xn.permute(0, 3, 5, 7, 1, 2, 4, 6) + xn = xn.reshape(b, r1 * 2 * 2 * ci, nf, ht // 2, wd // 2) + B, C, T, H, W = xn.shape + xn = xn.view(B, h.shape[1], self.gs, T, H, W).mean(dim=2) + sc = torch.cat([xf, xn], dim=2) + else: + b, c, frms, ht, wd = h.shape + nf = frms // r1 + h = h.reshape(b, c, nf, r1, ht // 2, 2, wd // 2, 2) + h = h.permute(0, 3, 5, 7, 1, 2, 4, 6) + h = h.reshape(b, r1 * 2 * 2 * c, nf, ht // 2, wd // 2) + + b, ci, frms, ht, wd = x.shape + nf = frms // r1 + sc = x.reshape(b, ci, nf, r1, ht // 2, 2, wd // 2, 2) + sc = sc.permute(0, 3, 5, 7, 1, 2, 4, 6) + sc = sc.reshape(b, r1 * 2 * 2 * ci, nf, ht // 2, wd // 2) + B, C, T, H, W = sc.shape + sc = sc.view(B, h.shape[1], self.gs, T, H, W).mean(dim=2) + + return h + sc + + +class UpSmpl(nn.Module): + def __init__(self, ic, oc, tus=True): + super().__init__() + fct = 2 * 2 * 2 if tus else 1 * 2 * 2 + self.conv = VideoConv3d(ic, oc * fct, kernel_size=3) + + self.tus = tus + self.rp = fct * oc // ic + + def forward(self, x): + r1 = 2 if self.tus else 1 + h = self.conv(x) + + if self.tus: + hf = h[:, :, :1, 
:, :] + b, c, f, ht, wd = hf.shape + nc = c // (2 * 2) + hf = hf.reshape(b, 2, 2, nc, f, ht, wd) + hf = hf.permute(0, 3, 4, 5, 1, 6, 2) + hf = hf.reshape(b, nc, f, ht * 2, wd * 2) + hf = hf[:, : hf.shape[1] // 2] + + hn = h[:, :, 1:, :, :] + b, c, frms, ht, wd = hn.shape + nc = c // (r1 * 2 * 2) + hn = hn.reshape(b, r1, 2, 2, nc, frms, ht, wd) + hn = hn.permute(0, 4, 5, 1, 6, 2, 7, 3) + hn = hn.reshape(b, nc, frms * r1, ht * 2, wd * 2) + + h = torch.cat([hf, hn], dim=2) + + xf = x[:, :, :1, :, :] + b, ci, f, ht, wd = xf.shape + xf = xf.repeat_interleave(repeats=self.rp // 2, dim=1) + b, c, f, ht, wd = xf.shape + nc = c // (2 * 2) + xf = xf.reshape(b, 2, 2, nc, f, ht, wd) + xf = xf.permute(0, 3, 4, 5, 1, 6, 2) + xf = xf.reshape(b, nc, f, ht * 2, wd * 2) + + xn = x[:, :, 1:, :, :] + xn = xn.repeat_interleave(repeats=self.rp, dim=1) + b, c, frms, ht, wd = xn.shape + nc = c // (r1 * 2 * 2) + xn = xn.reshape(b, r1, 2, 2, nc, frms, ht, wd) + xn = xn.permute(0, 4, 5, 1, 6, 2, 7, 3) + xn = xn.reshape(b, nc, frms * r1, ht * 2, wd * 2) + sc = torch.cat([xf, xn], dim=2) + else: + b, c, frms, ht, wd = h.shape + nc = c // (r1 * 2 * 2) + h = h.reshape(b, r1, 2, 2, nc, frms, ht, wd) + h = h.permute(0, 4, 5, 1, 6, 2, 7, 3) + h = h.reshape(b, nc, frms * r1, ht * 2, wd * 2) + + sc = x.repeat_interleave(repeats=self.rp, dim=1) + b, c, frms, ht, wd = sc.shape + nc = c // (r1 * 2 * 2) + sc = sc.reshape(b, r1, 2, 2, nc, frms, ht, wd) + sc = sc.permute(0, 4, 5, 1, 6, 2, 7, 3) + sc = sc.reshape(b, nc, frms * r1, ht * 2, wd * 2) + + return h + sc + +class Encoder(nn.Module): + def __init__(self, in_channels, z_channels, block_out_channels, num_res_blocks, + ffactor_spatial, ffactor_temporal, downsample_match_channel=True, **_): + super().__init__() + self.z_channels = z_channels + self.block_out_channels = block_out_channels + self.num_res_blocks = num_res_blocks + self.conv_in = VideoConv3d(in_channels, block_out_channels[0], 3, 1, 1) + + self.down = nn.ModuleList() + ch = block_out_channels[0] + depth = (ffactor_spatial >> 1).bit_length() + depth_temporal = ((ffactor_spatial // ffactor_temporal) >> 1).bit_length() + + for i, tgt in enumerate(block_out_channels): + stage = nn.Module() + stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt, + out_channels=tgt, + temb_channels=0, + conv_op=VideoConv3d, norm_op=RMS_norm) + for j in range(num_res_blocks)]) + ch = tgt + if i < depth: + nxt = block_out_channels[i + 1] if i + 1 < len(block_out_channels) and downsample_match_channel else ch + stage.downsample = DnSmpl(ch, nxt, tds=i >= depth_temporal) + ch = nxt + self.down.append(stage) + + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=VideoConv3d, norm_op=RMS_norm) + self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=RMS_norm) + self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=VideoConv3d, norm_op=RMS_norm) + + self.norm_out = RMS_norm(ch) + self.conv_out = VideoConv3d(ch, z_channels << 1, 3, 1, 1) + + self.regul = comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer() + + def forward(self, x): + x = x.unsqueeze(2) + x = self.conv_in(x) + + for stage in self.down: + for blk in stage.block: + x = blk(x) + if hasattr(stage, 'downsample'): + x = stage.downsample(x) + + x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(x))) + + b, c, t, h, w = x.shape + grp = c // (self.z_channels << 1) + skip = x.view(b, c // grp, grp, t, h, w).mean(2) + + out = 
self.conv_out(F.silu(self.norm_out(x))) + skip + out = self.regul(out)[0] + + out = torch.cat((out[:, :, :1], out), dim=2) + out = out.permute(0, 2, 1, 3, 4) + b, f_times_2, c, h, w = out.shape + out = out.reshape(b, f_times_2 // 2, 2 * c, h, w) + out = out.permute(0, 2, 1, 3, 4).contiguous() + return out + +class Decoder(nn.Module): + def __init__(self, z_channels, out_channels, block_out_channels, num_res_blocks, + ffactor_spatial, ffactor_temporal, upsample_match_channel=True, **_): + super().__init__() + block_out_channels = block_out_channels[::-1] + self.z_channels = z_channels + self.block_out_channels = block_out_channels + self.num_res_blocks = num_res_blocks + + ch = block_out_channels[0] + self.conv_in = VideoConv3d(z_channels, ch, 3) + + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=VideoConv3d, norm_op=RMS_norm) + self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=RMS_norm) + self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=VideoConv3d, norm_op=RMS_norm) + + self.up = nn.ModuleList() + depth = (ffactor_spatial >> 1).bit_length() + depth_temporal = (ffactor_temporal >> 1).bit_length() + + for i, tgt in enumerate(block_out_channels): + stage = nn.Module() + stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt, + out_channels=tgt, + temb_channels=0, + conv_op=VideoConv3d, norm_op=RMS_norm) + for j in range(num_res_blocks + 1)]) + ch = tgt + if i < depth: + nxt = block_out_channels[i + 1] if i + 1 < len(block_out_channels) and upsample_match_channel else ch + stage.upsample = UpSmpl(ch, nxt, tus=i < depth_temporal) + ch = nxt + self.up.append(stage) + + self.norm_out = RMS_norm(ch) + self.conv_out = VideoConv3d(ch, out_channels, 3) + + def forward(self, z): + z = z.permute(0, 2, 1, 3, 4) + b, f, c, h, w = z.shape + z = z.reshape(b, f, 2, c // 2, h, w) + z = z.permute(0, 1, 2, 3, 4, 5).reshape(b, f * 2, c // 2, h, w) + z = z.permute(0, 2, 1, 3, 4) + z = z[:, :, 1:] + + x = self.conv_in(z) + z.repeat_interleave(self.block_out_channels[0] // self.z_channels, 1) + x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(x))) + + for stage in self.up: + for blk in stage.block: + x = blk(x) + if hasattr(stage, 'upsample'): + x = stage.upsample(x) + + return self.conv_out(F.silu(self.norm_out(x))) diff --git a/comfy/ldm/models/autoencoder.py b/comfy/ldm/models/autoencoder.py index 13bd6e16b..611d36a1b 100644 --- a/comfy/ldm/models/autoencoder.py +++ b/comfy/ldm/models/autoencoder.py @@ -26,6 +26,12 @@ class DiagonalGaussianRegularizer(torch.nn.Module): z = posterior.mode() return z, None +class EmptyRegularizer(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]: + return z, None class AbstractAutoencoder(torch.nn.Module): """ diff --git a/comfy/ldm/modules/diffusionmodules/model.py b/comfy/ldm/modules/diffusionmodules/model.py index 8f598a848..4245eedca 100644 --- a/comfy/ldm/modules/diffusionmodules/model.py +++ b/comfy/ldm/modules/diffusionmodules/model.py @@ -145,7 +145,7 @@ class Downsample(nn.Module): class ResnetBlock(nn.Module): def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, - dropout=0.0, temb_channels=512, conv_op=ops.Conv2d): + dropout=0.0, temb_channels=512, conv_op=ops.Conv2d, norm_op=Normalize): super().__init__() self.in_channels = in_channels out_channels = in_channels if out_channels is None else out_channels @@ -153,7 +153,7 @@ class 
ResnetBlock(nn.Module): self.use_conv_shortcut = conv_shortcut self.swish = torch.nn.SiLU(inplace=True) - self.norm1 = Normalize(in_channels) + self.norm1 = norm_op(in_channels) self.conv1 = conv_op(in_channels, out_channels, kernel_size=3, @@ -162,7 +162,7 @@ class ResnetBlock(nn.Module): if temb_channels > 0: self.temb_proj = ops.Linear(temb_channels, out_channels) - self.norm2 = Normalize(out_channels) + self.norm2 = norm_op(out_channels) self.dropout = torch.nn.Dropout(dropout, inplace=True) self.conv2 = conv_op(out_channels, out_channels, @@ -305,11 +305,11 @@ def vae_attention(): return normal_attention class AttnBlock(nn.Module): - def __init__(self, in_channels, conv_op=ops.Conv2d): + def __init__(self, in_channels, conv_op=ops.Conv2d, norm_op=Normalize): super().__init__() self.in_channels = in_channels - self.norm = Normalize(in_channels) + self.norm = norm_op(in_channels) self.q = conv_op(in_channels, in_channels, kernel_size=1, diff --git a/comfy/model_base.py b/comfy/model_base.py index 993ff65e6..c69a9d1ad 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1432,3 +1432,23 @@ class HunyuanImage21(BaseModel): out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance])) return out + +class HunyuanImage21Refiner(HunyuanImage21): + def concat_cond(self, **kwargs): + noise = kwargs.get("noise", None) + image = kwargs.get("concat_latent_image", None) + device = kwargs["device"] + + if image is None: + shape_image = list(noise.shape) + image = torch.zeros(shape_image, dtype=noise.dtype, layout=noise.layout, device=noise.device) + else: + image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center") + image = self.process_latent_in(image) + image = utils.resize_to_batch_size(image, noise.shape[0]) + return image + + def extra_conds(self, **kwargs): + out = super().extra_conds(**kwargs) + out['disable_time_r'] = comfy.conds.CONDConstant(True) + return out diff --git a/comfy/sd.py b/comfy/sd.py index 9dd9a74d4..02ddc7239 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -285,6 +285,7 @@ class VAE: self.process_output = lambda image: torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0) self.working_dtypes = [torch.bfloat16, torch.float32] self.disable_offload = False + self.not_video = False self.downscale_index_formula = None self.upscale_index_formula = None @@ -409,6 +410,20 @@ class VAE: self.downscale_ratio = (lambda a: max(0, math.floor((a + 7) / 8)), 32, 32) self.downscale_index_formula = (8, 32, 32) self.working_dtypes = [torch.bfloat16, torch.float32] + elif "decoder.conv_in.conv.weight" in sd and sd['decoder.conv_in.conv.weight'].shape[1] == 32: + ddconfig = {"block_out_channels": [128, 256, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 16, "ffactor_temporal": 4, "downsample_match_channel": True, "upsample_match_channel": True} + self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.conv.weight"].shape[1] + self.downscale_ratio = 16 + self.upscale_ratio = 16 + self.latent_dim = 3 + self.not_video = True + self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32] + self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.EmptyRegularizer"}, + encoder_config={'target': "comfy.ldm.hunyuan_video.vae_refiner.Encoder", 'params': ddconfig}, + decoder_config={'target': "comfy.ldm.hunyuan_video.vae_refiner.Decoder", 'params': ddconfig}) + + self.memory_used_encode = lambda shape, dtype: (1400 * shape[-2] * 
shape[-1]) * model_management.dtype_size(dtype) + self.memory_used_decode = lambda shape, dtype: (1400 * shape[-3] * shape[-2] * shape[-1] * 16 * 16) * model_management.dtype_size(dtype) elif "decoder.conv_in.conv.weight" in sd: ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0} ddconfig["conv3d"] = True @@ -669,7 +684,7 @@ class VAE: self.throw_exception_if_invalid() pixel_samples = self.vae_encode_crop_pixels(pixel_samples) pixel_samples = pixel_samples.movedim(-1, 1) - if self.latent_dim == 3 and pixel_samples.ndim < 5: + if not self.not_video and self.latent_dim == 3 and pixel_samples.ndim < 5: pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0) try: memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype) diff --git a/comfy/supported_models.py b/comfy/supported_models.py index aa953b462..ba1b8c313 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -1321,6 +1321,23 @@ class HunyuanImage21(HunyuanVideo): hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref)) return supported_models_base.ClipTarget(comfy.text_encoders.hunyuan_image.HunyuanImageTokenizer, comfy.text_encoders.hunyuan_image.te(**hunyuan_detect)) -models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ACEStep, Omnigen2, QwenImage] +class HunyuanImage21Refiner(HunyuanVideo): + unet_config = { + "image_model": "hunyuan_video", + "patch_size": [1, 1, 1], + "vec_in_dim": None, + } + + sampling_settings = { + "shift": 1.0, + } + + latent_format = latent_formats.HunyuanImage21Refiner + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.HunyuanImage21Refiner(self, device=device) + return out + +models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ACEStep, Omnigen2, QwenImage] models += [SVD_img2vid] diff --git a/comfy_extras/nodes_hunyuan.py b/comfy_extras/nodes_hunyuan.py index ce031ceb2..351a7e2cb 100644 --- a/comfy_extras/nodes_hunyuan.py +++ b/comfy_extras/nodes_hunyuan.py @@ -128,6 +128,28 @@ class EmptyHunyuanImageLatent: latent = torch.zeros([batch_size, 64, height // 32, width // 32], 
device=comfy.model_management.intermediate_device()) return ({"samples":latent}, ) +class HunyuanRefinerLatent: + @classmethod + def INPUT_TYPES(s): + return {"required": {"positive": ("CONDITIONING", ), + "negative": ("CONDITIONING", ), + "latent": ("LATENT", ), + }} + + RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") + RETURN_NAMES = ("positive", "negative", "latent") + + FUNCTION = "execute" + + def execute(self, positive, negative, latent): + latent = latent["samples"] + + positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": latent}) + negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": latent}) + out_latent = {} + out_latent["samples"] = torch.zeros([latent.shape[0], 32, latent.shape[-3], latent.shape[-2], latent.shape[-1]], device=comfy.model_management.intermediate_device()) + return (positive, negative, out_latent) + NODE_CLASS_MAPPINGS = { "CLIPTextEncodeHunyuanDiT": CLIPTextEncodeHunyuanDiT, @@ -135,4 +157,5 @@ NODE_CLASS_MAPPINGS = { "EmptyHunyuanLatentVideo": EmptyHunyuanLatentVideo, "HunyuanImageToVideo": HunyuanImageToVideo, "EmptyHunyuanImageLatent": EmptyHunyuanImageLatent, + "HunyuanRefinerLatent": HunyuanRefinerLatent, } From 15ec9ea958d1c5d374add598b571a585541d4863 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Thu, 11 Sep 2025 21:44:20 -0700 Subject: [PATCH 189/325] Add Output to V3 Combo type to match what is possible with V1 (#9813) --- comfy_api/latest/_io.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/comfy_api/latest/_io.py b/comfy_api/latest/_io.py index f770109d5..4826818df 100644 --- a/comfy_api/latest/_io.py +++ b/comfy_api/latest/_io.py @@ -331,7 +331,7 @@ class String(ComfyTypeIO): }) @comfytype(io_type="COMBO") -class Combo(ComfyTypeI): +class Combo(ComfyTypeIO): Type = str class Input(WidgetInput): """Combo input (dropdown).""" @@ -360,6 +360,14 @@ class Combo(ComfyTypeI): "remote": self.remote.as_dict() if self.remote else None, }) + class Output(Output): + def __init__(self, id: str=None, display_name: str=None, options: list[str]=None, tooltip: str=None, is_output_list=False): + super().__init__(id, display_name, tooltip, is_output_list) + self.options = options if options is not None else [] + + @property + def io_type(self): + return self.options @comfytype(io_type="COMBO") class MultiCombo(ComfyTypeI): From d6b977b2e680e98ad18a37ee13783da4f30e15f4 Mon Sep 17 00:00:00 2001 From: Benjamin Lu Date: Thu, 11 Sep 2025 21:46:01 -0700 Subject: [PATCH 190/325] Bump frontend to 1.26.11 (#9809) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 0e21967ef..de5af5fac 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -comfyui-frontend-package==1.25.11 +comfyui-frontend-package==1.26.11 comfyui-workflow-templates==0.1.81 comfyui-embedded-docs==0.2.6 torch From fd2b820ec28e9575877dc6c51949b2d28dc78894 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Fri, 12 Sep 2025 13:03:08 -0700 Subject: [PATCH 191/325] Add noise augmentation to hunyuan image refiner. (#9831) This was missing and should help with colors being blown out. 
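The augmentation added here is a convex blend of the reference latent with reproducible, seeded Gaussian noise, controlled by the new noise_augmentation input. A minimal standalone sketch of that blend follows (the tensor shape and seed value are illustrative only, not taken from the patch; the real code operates inside HunyuanImage21Refiner.concat_cond):

    import torch

    def noise_augment(latent, noise_augmentation, seed=0):
        # Blend the conditioning latent with seeded noise, mirroring the hunk below:
        # a higher noise_augmentation weakens the reference image's influence.
        if noise_augmentation > 0:
            noise = torch.randn(
                latent.shape,
                generator=torch.manual_seed(seed - 10),  # same seed offset as the patch
                dtype=latent.dtype,
            )
            return noise_augmentation * noise + (1.0 - noise_augmentation) * latent
        return latent

    # Example with a dummy refiner-sized latent (batch 1, 32 channels):
    augmented = noise_augment(torch.zeros(1, 32, 1, 64, 64), noise_augmentation=0.10, seed=42)

A follow-up commit (#9832) additionally caps the image coefficient at 0.75, which addresses blown-out colors at low augmentation strengths.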
--- comfy/model_base.py | 4 ++++ comfy_extras/nodes_hunyuan.py | 8 ++++---- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/comfy/model_base.py b/comfy/model_base.py index c69a9d1ad..8422051bf 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1437,6 +1437,7 @@ class HunyuanImage21Refiner(HunyuanImage21): def concat_cond(self, **kwargs): noise = kwargs.get("noise", None) image = kwargs.get("concat_latent_image", None) + noise_augmentation = kwargs.get("noise_augmentation", 0.0) device = kwargs["device"] if image is None: @@ -1446,6 +1447,9 @@ class HunyuanImage21Refiner(HunyuanImage21): image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center") image = self.process_latent_in(image) image = utils.resize_to_batch_size(image, noise.shape[0]) + if noise_augmentation > 0: + noise = torch.randn(image.shape, generator=torch.manual_seed(kwargs.get("seed", 0) - 10), dtype=image.dtype, device="cpu").to(image.device) + image = noise_augmentation * noise + (1.0 - noise_augmentation) * image return image def extra_conds(self, **kwargs): diff --git a/comfy_extras/nodes_hunyuan.py b/comfy_extras/nodes_hunyuan.py index 351a7e2cb..db398cdf1 100644 --- a/comfy_extras/nodes_hunyuan.py +++ b/comfy_extras/nodes_hunyuan.py @@ -134,6 +134,7 @@ class HunyuanRefinerLatent: return {"required": {"positive": ("CONDITIONING", ), "negative": ("CONDITIONING", ), "latent": ("LATENT", ), + "noise_augmentation": ("FLOAT", {"default": 0.10, "min": 0.0, "max": 1.0, "step": 0.01}), }} RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") @@ -141,11 +142,10 @@ class HunyuanRefinerLatent: FUNCTION = "execute" - def execute(self, positive, negative, latent): + def execute(self, positive, negative, latent, noise_augmentation): latent = latent["samples"] - - positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": latent}) - negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": latent}) + positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": latent, "noise_augmentation": noise_augmentation}) + negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": latent, "noise_augmentation": noise_augmentation}) out_latent = {} out_latent["samples"] = torch.zeros([latent.shape[0], 32, latent.shape[-3], latent.shape[-2], latent.shape[-1]], device=comfy.model_management.intermediate_device()) return (positive, negative, out_latent) From e600520f8aa583c79caa286a8d7d584edc3d059b Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Fri, 12 Sep 2025 13:35:34 -0700 Subject: [PATCH 192/325] Fix hunyuan refiner blownout colors at noise aug less than 0.25 (#9832) --- comfy/model_base.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/comfy/model_base.py b/comfy/model_base.py index 8422051bf..4176bca25 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1449,7 +1449,9 @@ class HunyuanImage21Refiner(HunyuanImage21): image = utils.resize_to_batch_size(image, noise.shape[0]) if noise_augmentation > 0: noise = torch.randn(image.shape, generator=torch.manual_seed(kwargs.get("seed", 0) - 10), dtype=image.dtype, device="cpu").to(image.device) - image = noise_augmentation * noise + (1.0 - noise_augmentation) * image + image = noise_augmentation * noise + min(1.0 - noise_augmentation, 0.75) * image + else: + image = 0.75 * image return image def extra_conds(self, **kwargs): From 
7757d5a657cbe9b22d1f3538ee0bc5387d3f5459 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Fri, 12 Sep 2025 13:40:12 -0700 Subject: [PATCH 193/325] Set default hunyuan refiner shift to 4.0 (#9833) --- comfy/supported_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/supported_models.py b/comfy/supported_models.py index ba1b8c313..472ea0ae9 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -1329,7 +1329,7 @@ class HunyuanImage21Refiner(HunyuanVideo): } sampling_settings = { - "shift": 1.0, + "shift": 4.0, } latent_format = latent_formats.HunyuanImage21Refiner From 0aa074a420c450fd7793d83c6f8d66009a1ca2a2 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Sat, 13 Sep 2025 00:29:03 +0300 Subject: [PATCH 194/325] add kling-v2-1 model to the KlingStartEndFrame node (#9630) --- comfy_api_nodes/nodes_kling.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/comfy_api_nodes/nodes_kling.py b/comfy_api_nodes/nodes_kling.py index 9fa390985..5f55b2cc9 100644 --- a/comfy_api_nodes/nodes_kling.py +++ b/comfy_api_nodes/nodes_kling.py @@ -846,6 +846,8 @@ class KlingStartEndFrameNode(KlingImage2VideoNode): "pro mode / 10s duration / kling-v1-5": ("pro", "10", "kling-v1-5"), "pro mode / 5s duration / kling-v1-6": ("pro", "5", "kling-v1-6"), "pro mode / 10s duration / kling-v1-6": ("pro", "10", "kling-v1-6"), + "pro mode / 5s duration / kling-v2-1": ("pro", "5", "kling-v2-1"), + "pro mode / 10s duration / kling-v2-1": ("pro", "10", "kling-v2-1"), } @classmethod From 45bc1f5c00307f3e85871ecfb46acaa2365b0096 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Sat, 13 Sep 2025 00:37:31 +0300 Subject: [PATCH 195/325] convert Minimax API nodes to the V3 schema (#9693) --- comfy_api_nodes/nodes_minimax.py | 732 ++++++++++++++++--------------- 1 file changed, 378 insertions(+), 354 deletions(-) diff --git a/comfy_api_nodes/nodes_minimax.py b/comfy_api_nodes/nodes_minimax.py index bb3c9e710..bf560661c 100644 --- a/comfy_api_nodes/nodes_minimax.py +++ b/comfy_api_nodes/nodes_minimax.py @@ -1,9 +1,10 @@ from inspect import cleandoc -from typing import Union +from typing import Optional import logging import torch -from comfy.comfy_types.node_typing import IO +from typing_extensions import override +from comfy_api.latest import ComfyExtension, io as comfy_io from comfy_api.input_impl.video_types import VideoFromFile from comfy_api_nodes.apis import ( MinimaxVideoGenerationRequest, @@ -11,7 +12,7 @@ from comfy_api_nodes.apis import ( MinimaxFileRetrieveResponse, MinimaxTaskResultResponse, SubjectReferenceItem, - MiniMaxModel + MiniMaxModel, ) from comfy_api_nodes.apis.client import ( ApiEndpoint, @@ -31,372 +32,398 @@ from server import PromptServer I2V_AVERAGE_DURATION = 114 T2V_AVERAGE_DURATION = 234 -class MinimaxTextToVideoNode: + +async def _generate_mm_video( + *, + auth: dict[str, str], + node_id: str, + prompt_text: str, + seed: int, + model: str, + image: Optional[torch.Tensor] = None, # used for ImageToVideo + subject: Optional[torch.Tensor] = None, # used for SubjectToVideo + average_duration: Optional[int] = None, +) -> comfy_io.NodeOutput: + if image is None: + validate_string(prompt_text, field_name="prompt_text") + # upload image, if passed in + image_url = None + if image is not None: + image_url = (await upload_images_to_comfyapi(image, max_images=1, auth_kwargs=auth))[0] + + # TODO: figure out how 
to deal with subject properly, API returns invalid params when using S2V-01 model + subject_reference = None + if subject is not None: + subject_url = (await upload_images_to_comfyapi(subject, max_images=1, auth_kwargs=auth))[0] + subject_reference = [SubjectReferenceItem(image=subject_url)] + + + video_generate_operation = SynchronousOperation( + endpoint=ApiEndpoint( + path="/proxy/minimax/video_generation", + method=HttpMethod.POST, + request_model=MinimaxVideoGenerationRequest, + response_model=MinimaxVideoGenerationResponse, + ), + request=MinimaxVideoGenerationRequest( + model=MiniMaxModel(model), + prompt=prompt_text, + callback_url=None, + first_frame_image=image_url, + subject_reference=subject_reference, + prompt_optimizer=None, + ), + auth_kwargs=auth, + ) + response = await video_generate_operation.execute() + + task_id = response.task_id + if not task_id: + raise Exception(f"MiniMax generation failed: {response.base_resp}") + + video_generate_operation = PollingOperation( + poll_endpoint=ApiEndpoint( + path="/proxy/minimax/query/video_generation", + method=HttpMethod.GET, + request_model=EmptyRequest, + response_model=MinimaxTaskResultResponse, + query_params={"task_id": task_id}, + ), + completed_statuses=["Success"], + failed_statuses=["Fail"], + status_extractor=lambda x: x.status.value, + estimated_duration=average_duration, + node_id=node_id, + auth_kwargs=auth, + ) + task_result = await video_generate_operation.execute() + + file_id = task_result.file_id + if file_id is None: + raise Exception("Request was not successful. Missing file ID.") + file_retrieve_operation = SynchronousOperation( + endpoint=ApiEndpoint( + path="/proxy/minimax/files/retrieve", + method=HttpMethod.GET, + request_model=EmptyRequest, + response_model=MinimaxFileRetrieveResponse, + query_params={"file_id": int(file_id)}, + ), + request=EmptyRequest(), + auth_kwargs=auth, + ) + file_result = await file_retrieve_operation.execute() + + file_url = file_result.file.download_url + if file_url is None: + raise Exception( + f"No video was found in the response. Full response: {file_result.model_dump()}" + ) + logging.info("Generated video URL: %s", file_url) + if node_id: + if hasattr(file_result.file, "backup_download_url"): + message = f"Result URL: {file_url}\nBackup URL: {file_result.file.backup_download_url}" + else: + message = f"Result URL: {file_url}" + PromptServer.instance.send_progress_text(message, node_id) + + # Download and return as VideoFromFile + video_io = await download_url_to_bytesio(file_url) + if video_io is None: + error_msg = f"Failed to download video from {file_url}" + logging.error(error_msg) + raise Exception(error_msg) + return comfy_io.NodeOutput(VideoFromFile(video_io)) + + +class MinimaxTextToVideoNode(comfy_io.ComfyNode): """ Generates videos synchronously based on a prompt, and optional parameters using MiniMax's API. 
""" - AVERAGE_DURATION = T2V_AVERAGE_DURATION + @classmethod + def define_schema(cls) -> comfy_io.Schema: + return comfy_io.Schema( + node_id="MinimaxTextToVideoNode", + display_name="MiniMax Text to Video", + category="api node/video/MiniMax", + description=cleandoc(cls.__doc__ or ""), + inputs=[ + comfy_io.String.Input( + "prompt_text", + multiline=True, + default="", + tooltip="Text prompt to guide the video generation", + ), + comfy_io.Combo.Input( + "model", + options=["T2V-01", "T2V-01-Director"], + default="T2V-01", + tooltip="Model to use for video generation", + ), + comfy_io.Int.Input( + "seed", + default=0, + min=0, + max=0xFFFFFFFFFFFFFFFF, + step=1, + control_after_generate=True, + tooltip="The random seed used for creating the noise.", + optional=True, + ), + ], + outputs=[comfy_io.Video.Output()], + hidden=[ + comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + comfy_io.Hidden.unique_id, + ], + is_api_node=True, + ) @classmethod - def INPUT_TYPES(s): - return { - "required": { - "prompt_text": ( - "STRING", - { - "multiline": True, - "default": "", - "tooltip": "Text prompt to guide the video generation", - }, - ), - "model": ( - [ - "T2V-01", - "T2V-01-Director", - ], - { - "default": "T2V-01", - "tooltip": "Model to use for video generation", - }, - ), + async def execute( + cls, + prompt_text: str, + model: str = "T2V-01", + seed: int = 0, + ) -> comfy_io.NodeOutput: + return await _generate_mm_video( + auth={ + "auth_token": cls.hidden.auth_token_comfy_org, + "comfy_api_key": cls.hidden.api_key_comfy_org, }, - "optional": { - "seed": ( - IO.INT, - { - "default": 0, - "min": 0, - "max": 0xFFFFFFFFFFFFFFFF, - "control_after_generate": True, - "tooltip": "The random seed used for creating the noise.", - }, - ), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } - - RETURN_TYPES = ("VIDEO",) - DESCRIPTION = "Generates videos from prompts using MiniMax's API" - FUNCTION = "generate_video" - CATEGORY = "api node/video/MiniMax" - API_NODE = True - - async def generate_video( - self, - prompt_text, - seed=0, - model="T2V-01", - image: torch.Tensor=None, # used for ImageToVideo - subject: torch.Tensor=None, # used for SubjectToVideo - unique_id: Union[str, None]=None, - **kwargs, - ): - ''' - Function used between MiniMax nodes - supports T2V, I2V, and S2V, based on provided arguments. 
- ''' - if image is None: - validate_string(prompt_text, field_name="prompt_text") - # upload image, if passed in - image_url = None - if image is not None: - image_url = (await upload_images_to_comfyapi(image, max_images=1, auth_kwargs=kwargs))[0] - - # TODO: figure out how to deal with subject properly, API returns invalid params when using S2V-01 model - subject_reference = None - if subject is not None: - subject_url = (await upload_images_to_comfyapi(subject, max_images=1, auth_kwargs=kwargs))[0] - subject_reference = [SubjectReferenceItem(image=subject_url)] - - - video_generate_operation = SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/minimax/video_generation", - method=HttpMethod.POST, - request_model=MinimaxVideoGenerationRequest, - response_model=MinimaxVideoGenerationResponse, - ), - request=MinimaxVideoGenerationRequest( - model=MiniMaxModel(model), - prompt=prompt_text, - callback_url=None, - first_frame_image=image_url, - subject_reference=subject_reference, - prompt_optimizer=None, - ), - auth_kwargs=kwargs, + node_id=cls.hidden.unique_id, + prompt_text=prompt_text, + seed=seed, + model=model, + image=None, + subject=None, + average_duration=T2V_AVERAGE_DURATION, ) - response = await video_generate_operation.execute() - - task_id = response.task_id - if not task_id: - raise Exception(f"MiniMax generation failed: {response.base_resp}") - - video_generate_operation = PollingOperation( - poll_endpoint=ApiEndpoint( - path="/proxy/minimax/query/video_generation", - method=HttpMethod.GET, - request_model=EmptyRequest, - response_model=MinimaxTaskResultResponse, - query_params={"task_id": task_id}, - ), - completed_statuses=["Success"], - failed_statuses=["Fail"], - status_extractor=lambda x: x.status.value, - estimated_duration=self.AVERAGE_DURATION, - node_id=unique_id, - auth_kwargs=kwargs, - ) - task_result = await video_generate_operation.execute() - - file_id = task_result.file_id - if file_id is None: - raise Exception("Request was not successful. Missing file ID.") - file_retrieve_operation = SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/minimax/files/retrieve", - method=HttpMethod.GET, - request_model=EmptyRequest, - response_model=MinimaxFileRetrieveResponse, - query_params={"file_id": int(file_id)}, - ), - request=EmptyRequest(), - auth_kwargs=kwargs, - ) - file_result = await file_retrieve_operation.execute() - - file_url = file_result.file.download_url - if file_url is None: - raise Exception( - f"No video was found in the response. Full response: {file_result.model_dump()}" - ) - logging.info(f"Generated video URL: {file_url}") - if unique_id: - if hasattr(file_result.file, "backup_download_url"): - message = f"Result URL: {file_url}\nBackup URL: {file_result.file.backup_download_url}" - else: - message = f"Result URL: {file_url}" - PromptServer.instance.send_progress_text(message, unique_id) - - video_io = await download_url_to_bytesio(file_url) - if video_io is None: - error_msg = f"Failed to download video from {file_url}" - logging.error(error_msg) - raise Exception(error_msg) - return (VideoFromFile(video_io),) -class MinimaxImageToVideoNode(MinimaxTextToVideoNode): +class MinimaxImageToVideoNode(comfy_io.ComfyNode): """ Generates videos synchronously based on an image and prompt, and optional parameters using MiniMax's API. 
""" - AVERAGE_DURATION = I2V_AVERAGE_DURATION + @classmethod + def define_schema(cls) -> comfy_io.Schema: + return comfy_io.Schema( + node_id="MinimaxImageToVideoNode", + display_name="MiniMax Image to Video", + category="api node/video/MiniMax", + description=cleandoc(cls.__doc__ or ""), + inputs=[ + comfy_io.Image.Input( + "image", + tooltip="Image to use as first frame of video generation", + ), + comfy_io.String.Input( + "prompt_text", + multiline=True, + default="", + tooltip="Text prompt to guide the video generation", + ), + comfy_io.Combo.Input( + "model", + options=["I2V-01-Director", "I2V-01", "I2V-01-live"], + default="I2V-01", + tooltip="Model to use for video generation", + ), + comfy_io.Int.Input( + "seed", + default=0, + min=0, + max=0xFFFFFFFFFFFFFFFF, + step=1, + control_after_generate=True, + tooltip="The random seed used for creating the noise.", + optional=True, + ), + ], + outputs=[comfy_io.Video.Output()], + hidden=[ + comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + comfy_io.Hidden.unique_id, + ], + is_api_node=True, + ) @classmethod - def INPUT_TYPES(s): - return { - "required": { - "image": ( - IO.IMAGE, - { - "tooltip": "Image to use as first frame of video generation" - }, - ), - "prompt_text": ( - "STRING", - { - "multiline": True, - "default": "", - "tooltip": "Text prompt to guide the video generation", - }, - ), - "model": ( - [ - "I2V-01-Director", - "I2V-01", - "I2V-01-live", - ], - { - "default": "I2V-01", - "tooltip": "Model to use for video generation", - }, - ), + async def execute( + cls, + image: torch.Tensor, + prompt_text: str, + model: str = "I2V-01", + seed: int = 0, + ) -> comfy_io.NodeOutput: + return await _generate_mm_video( + auth={ + "auth_token": cls.hidden.auth_token_comfy_org, + "comfy_api_key": cls.hidden.api_key_comfy_org, }, - "optional": { - "seed": ( - IO.INT, - { - "default": 0, - "min": 0, - "max": 0xFFFFFFFFFFFFFFFF, - "control_after_generate": True, - "tooltip": "The random seed used for creating the noise.", - }, - ), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } - - RETURN_TYPES = ("VIDEO",) - DESCRIPTION = "Generates videos from an image and prompts using MiniMax's API" - FUNCTION = "generate_video" - CATEGORY = "api node/video/MiniMax" - API_NODE = True + node_id=cls.hidden.unique_id, + prompt_text=prompt_text, + seed=seed, + model=model, + image=image, + subject=None, + average_duration=I2V_AVERAGE_DURATION, + ) -class MinimaxSubjectToVideoNode(MinimaxTextToVideoNode): +class MinimaxSubjectToVideoNode(comfy_io.ComfyNode): """ Generates videos synchronously based on an image and prompt, and optional parameters using MiniMax's API. 
""" - AVERAGE_DURATION = T2V_AVERAGE_DURATION + @classmethod + def define_schema(cls) -> comfy_io.Schema: + return comfy_io.Schema( + node_id="MinimaxSubjectToVideoNode", + display_name="MiniMax Subject to Video", + category="api node/video/MiniMax", + description=cleandoc(cls.__doc__ or ""), + inputs=[ + comfy_io.Image.Input( + "subject", + tooltip="Image of subject to reference for video generation", + ), + comfy_io.String.Input( + "prompt_text", + multiline=True, + default="", + tooltip="Text prompt to guide the video generation", + ), + comfy_io.Combo.Input( + "model", + options=["S2V-01"], + default="S2V-01", + tooltip="Model to use for video generation", + ), + comfy_io.Int.Input( + "seed", + default=0, + min=0, + max=0xFFFFFFFFFFFFFFFF, + step=1, + control_after_generate=True, + tooltip="The random seed used for creating the noise.", + optional=True, + ), + ], + outputs=[comfy_io.Video.Output()], + hidden=[ + comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + comfy_io.Hidden.unique_id, + ], + is_api_node=True, + ) @classmethod - def INPUT_TYPES(s): - return { - "required": { - "subject": ( - IO.IMAGE, - { - "tooltip": "Image of subject to reference video generation" - }, - ), - "prompt_text": ( - "STRING", - { - "multiline": True, - "default": "", - "tooltip": "Text prompt to guide the video generation", - }, - ), - "model": ( - [ - "S2V-01", - ], - { - "default": "S2V-01", - "tooltip": "Model to use for video generation", - }, - ), + async def execute( + cls, + subject: torch.Tensor, + prompt_text: str, + model: str = "S2V-01", + seed: int = 0, + ) -> comfy_io.NodeOutput: + return await _generate_mm_video( + auth={ + "auth_token": cls.hidden.auth_token_comfy_org, + "comfy_api_key": cls.hidden.api_key_comfy_org, }, - "optional": { - "seed": ( - IO.INT, - { - "default": 0, - "min": 0, - "max": 0xFFFFFFFFFFFFFFFF, - "control_after_generate": True, - "tooltip": "The random seed used for creating the noise.", - }, - ), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } - - RETURN_TYPES = ("VIDEO",) - DESCRIPTION = "Generates videos from an image and prompts using MiniMax's API" - FUNCTION = "generate_video" - CATEGORY = "api node/video/MiniMax" - API_NODE = True + node_id=cls.hidden.unique_id, + prompt_text=prompt_text, + seed=seed, + model=model, + image=None, + subject=subject, + average_duration=T2V_AVERAGE_DURATION, + ) -class MinimaxHailuoVideoNode: +class MinimaxHailuoVideoNode(comfy_io.ComfyNode): """Generates videos from prompt, with optional start frame using the new MiniMax Hailuo-02 model.""" @classmethod - def INPUT_TYPES(s): - return { - "required": { - "prompt_text": ( - "STRING", - { - "multiline": True, - "default": "", - "tooltip": "Text prompt to guide the video generation.", - }, + def define_schema(cls) -> comfy_io.Schema: + return comfy_io.Schema( + node_id="MinimaxHailuoVideoNode", + display_name="MiniMax Hailuo Video", + category="api node/video/MiniMax", + description=cleandoc(cls.__doc__ or ""), + inputs=[ + comfy_io.String.Input( + "prompt_text", + multiline=True, + default="", + tooltip="Text prompt to guide the video generation.", ), - }, - "optional": { - "seed": ( - IO.INT, - { - "default": 0, - "min": 0, - "max": 0xFFFFFFFFFFFFFFFF, - "control_after_generate": True, - "tooltip": "The random seed used for creating the noise.", - }, + comfy_io.Int.Input( + "seed", + default=0, + min=0, + max=0xFFFFFFFFFFFFFFFF, + step=1, + 
control_after_generate=True, + tooltip="The random seed used for creating the noise.", + optional=True, ), - "first_frame_image": ( - IO.IMAGE, - { - "tooltip": "Optional image to use as the first frame to generate a video." - }, + comfy_io.Image.Input( + "first_frame_image", + tooltip="Optional image to use as the first frame to generate a video.", + optional=True, ), - "prompt_optimizer": ( - IO.BOOLEAN, - { - "tooltip": "Optimize prompt to improve generation quality when needed.", - "default": True, - }, + comfy_io.Boolean.Input( + "prompt_optimizer", + default=True, + tooltip="Optimize prompt to improve generation quality when needed.", + optional=True, ), - "duration": ( - IO.COMBO, - { - "tooltip": "The length of the output video in seconds.", - "default": 6, - "options": [6, 10], - }, + comfy_io.Combo.Input( + "duration", + options=[6, 10], + default=6, + tooltip="The length of the output video in seconds.", + optional=True, ), - "resolution": ( - IO.COMBO, - { - "tooltip": "The dimensions of the video display. " - "1080p corresponds to 1920 x 1080 pixels, 768p corresponds to 1366 x 768 pixels.", - "default": "768P", - "options": ["768P", "1080P"], - }, + comfy_io.Combo.Input( + "resolution", + options=["768P", "1080P"], + default="768P", + tooltip="The dimensions of the video display. 1080p is 1920x1080, 768p is 1366x768.", + optional=True, ), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, + ], + outputs=[comfy_io.Video.Output()], + hidden=[ + comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + comfy_io.Hidden.unique_id, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, + prompt_text: str, + seed: int = 0, + first_frame_image: Optional[torch.Tensor] = None, # used for ImageToVideo + prompt_optimizer: bool = True, + duration: int = 6, + resolution: str = "768P", + model: str = "MiniMax-Hailuo-02", + ) -> comfy_io.NodeOutput: + auth = { + "auth_token": cls.hidden.auth_token_comfy_org, + "comfy_api_key": cls.hidden.api_key_comfy_org, } - - RETURN_TYPES = ("VIDEO",) - DESCRIPTION = cleandoc(__doc__ or "") - FUNCTION = "generate_video" - CATEGORY = "api node/video/MiniMax" - API_NODE = True - - async def generate_video( - self, - prompt_text, - seed=0, - first_frame_image: torch.Tensor=None, # used for ImageToVideo - prompt_optimizer=True, - duration=6, - resolution="768P", - model="MiniMax-Hailuo-02", - unique_id: Union[str, None]=None, - **kwargs, - ): if first_frame_image is None: validate_string(prompt_text, field_name="prompt_text") @@ -408,7 +435,7 @@ class MinimaxHailuoVideoNode: # upload image, if passed in image_url = None if first_frame_image is not None: - image_url = (await upload_images_to_comfyapi(first_frame_image, max_images=1, auth_kwargs=kwargs))[0] + image_url = (await upload_images_to_comfyapi(first_frame_image, max_images=1, auth_kwargs=auth))[0] video_generate_operation = SynchronousOperation( endpoint=ApiEndpoint( @@ -426,7 +453,7 @@ class MinimaxHailuoVideoNode: duration=duration, resolution=resolution, ), - auth_kwargs=kwargs, + auth_kwargs=auth, ) response = await video_generate_operation.execute() @@ -447,8 +474,8 @@ class MinimaxHailuoVideoNode: failed_statuses=["Fail"], status_extractor=lambda x: x.status.value, estimated_duration=average_duration, - node_id=unique_id, - auth_kwargs=kwargs, + node_id=cls.hidden.unique_id, + auth_kwargs=auth, ) task_result = await video_generate_operation.execute() @@ -464,7 +491,7 @@ class 
MinimaxHailuoVideoNode: query_params={"file_id": int(file_id)}, ), request=EmptyRequest(), - auth_kwargs=kwargs, + auth_kwargs=auth, ) file_result = await file_retrieve_operation.execute() @@ -474,34 +501,31 @@ class MinimaxHailuoVideoNode: f"No video was found in the response. Full response: {file_result.model_dump()}" ) logging.info(f"Generated video URL: {file_url}") - if unique_id: + if cls.hidden.unique_id: if hasattr(file_result.file, "backup_download_url"): message = f"Result URL: {file_url}\nBackup URL: {file_result.file.backup_download_url}" else: message = f"Result URL: {file_url}" - PromptServer.instance.send_progress_text(message, unique_id) + PromptServer.instance.send_progress_text(message, cls.hidden.unique_id) video_io = await download_url_to_bytesio(file_url) if video_io is None: error_msg = f"Failed to download video from {file_url}" logging.error(error_msg) raise Exception(error_msg) - return (VideoFromFile(video_io),) + return comfy_io.NodeOutput(VideoFromFile(video_io)) -# A dictionary that contains all nodes you want to export with their names -# NOTE: names should be globally unique -NODE_CLASS_MAPPINGS = { - "MinimaxTextToVideoNode": MinimaxTextToVideoNode, - "MinimaxImageToVideoNode": MinimaxImageToVideoNode, - # "MinimaxSubjectToVideoNode": MinimaxSubjectToVideoNode, - "MinimaxHailuoVideoNode": MinimaxHailuoVideoNode, -} +class MinimaxExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[comfy_io.ComfyNode]]: + return [ + MinimaxTextToVideoNode, + MinimaxImageToVideoNode, + # MinimaxSubjectToVideoNode, + MinimaxHailuoVideoNode, + ] -# A dictionary that contains the friendly/humanly readable titles for the nodes -NODE_DISPLAY_NAME_MAPPINGS = { - "MinimaxTextToVideoNode": "MiniMax Text to Video", - "MinimaxImageToVideoNode": "MiniMax Image to Video", - "MinimaxSubjectToVideoNode": "MiniMax Subject to Video", - "MinimaxHailuoVideoNode": "MiniMax Hailuo Video", -} + +async def comfy_entrypoint() -> MinimaxExtension: + return MinimaxExtension() From f9d2e4b742af9aea3c9ffa822397c1b86cec9304 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Sat, 13 Sep 2025 00:38:12 +0300 Subject: [PATCH 196/325] convert WanCameraEmbedding node to V3 schema (#9714) --- comfy_extras/nodes_camera_trajectory.py | 81 ++++++++++++++++--------- 1 file changed, 51 insertions(+), 30 deletions(-) diff --git a/comfy_extras/nodes_camera_trajectory.py b/comfy_extras/nodes_camera_trajectory.py index 5e0e39f91..eb7ef363c 100644 --- a/comfy_extras/nodes_camera_trajectory.py +++ b/comfy_extras/nodes_camera_trajectory.py @@ -2,12 +2,12 @@ import nodes import torch import numpy as np from einops import rearrange +from typing_extensions import override import comfy.model_management +from comfy_api.latest import ComfyExtension, io -MAX_RESOLUTION = nodes.MAX_RESOLUTION - CAMERA_DICT = { "base_T_norm": 1.5, "base_angle": np.pi/3, @@ -148,32 +148,47 @@ def get_camera_motion(angle, T, speed, n=81): RT = np.stack(RT) return RT -class WanCameraEmbedding: +class WanCameraEmbedding(io.ComfyNode): @classmethod - def INPUT_TYPES(cls): - return { - "required": { - "camera_pose":(["Static","Pan Up","Pan Down","Pan Left","Pan Right","Zoom In","Zoom Out","Anti Clockwise (ACW)", "ClockWise (CW)"],{"default":"Static"}), - "width": ("INT", {"default": 832, "min": 16, "max": MAX_RESOLUTION, "step": 16}), - "height": ("INT", {"default": 480, "min": 16, "max": MAX_RESOLUTION, "step": 16}), - "length": ("INT", {"default": 81, "min": 1, "max": 
MAX_RESOLUTION, "step": 4}), - }, - "optional":{ - "speed":("FLOAT",{"default":1.0, "min": 0, "max": 10.0, "step": 0.1}), - "fx":("FLOAT",{"default":0.5, "min": 0, "max": 1, "step": 0.000000001}), - "fy":("FLOAT",{"default":0.5, "min": 0, "max": 1, "step": 0.000000001}), - "cx":("FLOAT",{"default":0.5, "min": 0, "max": 1, "step": 0.01}), - "cy":("FLOAT",{"default":0.5, "min": 0, "max": 1, "step": 0.01}), - } + def define_schema(cls): + return io.Schema( + node_id="WanCameraEmbedding", + category="camera", + inputs=[ + io.Combo.Input( + "camera_pose", + options=[ + "Static", + "Pan Up", + "Pan Down", + "Pan Left", + "Pan Right", + "Zoom In", + "Zoom Out", + "Anti Clockwise (ACW)", + "ClockWise (CW)", + ], + default="Static", + ), + io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4), + io.Float.Input("speed", default=1.0, min=0, max=10.0, step=0.1, optional=True), + io.Float.Input("fx", default=0.5, min=0, max=1, step=0.000000001, optional=True), + io.Float.Input("fy", default=0.5, min=0, max=1, step=0.000000001, optional=True), + io.Float.Input("cx", default=0.5, min=0, max=1, step=0.01, optional=True), + io.Float.Input("cy", default=0.5, min=0, max=1, step=0.01, optional=True), + ], + outputs=[ + io.WanCameraEmbedding.Output(display_name="camera_embedding"), + io.Int.Output(display_name="width"), + io.Int.Output(display_name="height"), + io.Int.Output(display_name="length"), + ], + ) - } - - RETURN_TYPES = ("WAN_CAMERA_EMBEDDING","INT","INT","INT") - RETURN_NAMES = ("camera_embedding","width","height","length") - FUNCTION = "run" - CATEGORY = "camera" - - def run(self, camera_pose, width, height, length, speed=1.0, fx=0.5, fy=0.5, cx=0.5, cy=0.5): + @classmethod + def execute(cls, camera_pose, width, height, length, speed=1.0, fx=0.5, fy=0.5, cx=0.5, cy=0.5) -> io.NodeOutput: """ Use Camera trajectory as extrinsic parameters to calculate Plücker embeddings (Sitzmannet al., 2021) Adapted from https://github.com/aigc-apps/VideoX-Fun/blob/main/comfyui/comfyui_nodes.py @@ -210,9 +225,15 @@ class WanCameraEmbedding: control_camera_video = control_camera_video.contiguous().view(b, f // 4, 4, c, h, w).transpose(2, 3) control_camera_video = control_camera_video.contiguous().view(b, f // 4, c * 4, h, w).transpose(1, 2) - return (control_camera_video, width, height, length) + return io.NodeOutput(control_camera_video, width, height, length) -NODE_CLASS_MAPPINGS = { - "WanCameraEmbedding": WanCameraEmbedding, -} +class CameraTrajectoryExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + WanCameraEmbedding, + ] + +async def comfy_entrypoint() -> CameraTrajectoryExtension: + return CameraTrajectoryExtension() From dcb883498337bcb2758fa9e7b326ea3b63c6b8d4 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Sat, 13 Sep 2025 00:38:46 +0300 Subject: [PATCH 197/325] convert Cosmos nodes to V3 schema (#9721) --- comfy_extras/nodes_cosmos.py | 129 +++++++++++++++++++---------------- 1 file changed, 72 insertions(+), 57 deletions(-) diff --git a/comfy_extras/nodes_cosmos.py b/comfy_extras/nodes_cosmos.py index 4f4960551..7dd129d19 100644 --- a/comfy_extras/nodes_cosmos.py +++ b/comfy_extras/nodes_cosmos.py @@ -1,25 +1,32 @@ +from typing_extensions import override import nodes import torch import comfy.model_management 
import comfy.utils import comfy.latent_formats +from comfy_api.latest import ComfyExtension, io -class EmptyCosmosLatentVideo: + +class EmptyCosmosLatentVideo(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "width": ("INT", {"default": 1280, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), - "height": ("INT", {"default": 704, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), - "length": ("INT", {"default": 121, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 8}), - "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})}} - RETURN_TYPES = ("LATENT",) - FUNCTION = "generate" + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="EmptyCosmosLatentVideo", + category="latent/video", + inputs=[ + io.Int.Input("width", default=1280, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("height", default=704, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("length", default=121, min=1, max=nodes.MAX_RESOLUTION, step=8), + io.Int.Input("batch_size", default=1, min=1, max=4096), + ], + outputs=[io.Latent.Output()], + ) - CATEGORY = "latent/video" - - def generate(self, width, height, length, batch_size=1): + @classmethod + def execute(cls, width, height, length, batch_size=1) -> io.NodeOutput: latent = torch.zeros([batch_size, 16, ((length - 1) // 8) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) - return ({"samples": latent}, ) + return io.NodeOutput({"samples": latent}) def vae_encode_with_padding(vae, image, width, height, length, padding=0): @@ -33,31 +40,31 @@ def vae_encode_with_padding(vae, image, width, height, length, padding=0): return latent_temp[:, :, :latent_len] -class CosmosImageToVideoLatent: +class CosmosImageToVideoLatent(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": {"vae": ("VAE", ), - "width": ("INT", {"default": 1280, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), - "height": ("INT", {"default": 704, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), - "length": ("INT", {"default": 121, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 8}), - "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), - }, - "optional": {"start_image": ("IMAGE", ), - "end_image": ("IMAGE", ), - }} + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="CosmosImageToVideoLatent", + category="conditioning/inpaint", + inputs=[ + io.Vae.Input("vae"), + io.Int.Input("width", default=1280, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("height", default=704, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("length", default=121, min=1, max=nodes.MAX_RESOLUTION, step=8), + io.Int.Input("batch_size", default=1, min=1, max=4096), + io.Image.Input("start_image", optional=True), + io.Image.Input("end_image", optional=True), + ], + outputs=[io.Latent.Output()], + ) - - RETURN_TYPES = ("LATENT",) - FUNCTION = "encode" - - CATEGORY = "conditioning/inpaint" - - def encode(self, vae, width, height, length, batch_size, start_image=None, end_image=None): + @classmethod + def execute(cls, vae, width, height, length, batch_size, start_image=None, end_image=None) -> io.NodeOutput: latent = torch.zeros([1, 16, ((length - 1) // 8) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) if start_image is None and end_image is None: out_latent = {} out_latent["samples"] = latent - return (out_latent,) + return io.NodeOutput(out_latent) mask = torch.ones([latent.shape[0], 1, ((length - 1) // 8) + 1, latent.shape[-2], 
latent.shape[-1]], device=comfy.model_management.intermediate_device()) @@ -74,33 +81,33 @@ class CosmosImageToVideoLatent: out_latent = {} out_latent["samples"] = latent.repeat((batch_size, ) + (1,) * (latent.ndim - 1)) out_latent["noise_mask"] = mask.repeat((batch_size, ) + (1,) * (mask.ndim - 1)) - return (out_latent,) + return io.NodeOutput(out_latent) -class CosmosPredict2ImageToVideoLatent: +class CosmosPredict2ImageToVideoLatent(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": {"vae": ("VAE", ), - "width": ("INT", {"default": 848, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), - "height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), - "length": ("INT", {"default": 93, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}), - "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), - }, - "optional": {"start_image": ("IMAGE", ), - "end_image": ("IMAGE", ), - }} + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="CosmosPredict2ImageToVideoLatent", + category="conditioning/inpaint", + inputs=[ + io.Vae.Input("vae"), + io.Int.Input("width", default=848, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("length", default=93, min=1, max=nodes.MAX_RESOLUTION, step=4), + io.Int.Input("batch_size", default=1, min=1, max=4096), + io.Image.Input("start_image", optional=True), + io.Image.Input("end_image", optional=True), + ], + outputs=[io.Latent.Output()], + ) - - RETURN_TYPES = ("LATENT",) - FUNCTION = "encode" - - CATEGORY = "conditioning/inpaint" - - def encode(self, vae, width, height, length, batch_size, start_image=None, end_image=None): + @classmethod + def execute(cls, vae, width, height, length, batch_size, start_image=None, end_image=None) -> io.NodeOutput: latent = torch.zeros([1, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) if start_image is None and end_image is None: out_latent = {} out_latent["samples"] = latent - return (out_latent,) + return io.NodeOutput(out_latent) mask = torch.ones([latent.shape[0], 1, ((length - 1) // 4) + 1, latent.shape[-2], latent.shape[-1]], device=comfy.model_management.intermediate_device()) @@ -119,10 +126,18 @@ class CosmosPredict2ImageToVideoLatent: latent = latent_format.process_out(latent) * mask + latent * (1.0 - mask) out_latent["samples"] = latent.repeat((batch_size, ) + (1,) * (latent.ndim - 1)) out_latent["noise_mask"] = mask.repeat((batch_size, ) + (1,) * (mask.ndim - 1)) - return (out_latent,) + return io.NodeOutput(out_latent) -NODE_CLASS_MAPPINGS = { - "EmptyCosmosLatentVideo": EmptyCosmosLatentVideo, - "CosmosImageToVideoLatent": CosmosImageToVideoLatent, - "CosmosPredict2ImageToVideoLatent": CosmosPredict2ImageToVideoLatent, -} + +class CosmosExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + EmptyCosmosLatentVideo, + CosmosImageToVideoLatent, + CosmosPredict2ImageToVideoLatent, + ] + + +async def comfy_entrypoint() -> CosmosExtension: + return CosmosExtension() From ba68e83f1c103eb4cb57fe01328706a0574fff3c Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Sat, 13 Sep 2025 00:39:30 +0300 Subject: [PATCH 198/325] convert nodes_cond.py to V3 schema (#9719) --- comfy_extras/nodes_cond.py | 75 ++++++++++++++++++++++++-------------- 1 file changed, 47 insertions(+), 28 deletions(-) diff --git 
a/comfy_extras/nodes_cond.py b/comfy_extras/nodes_cond.py index 58c16f621..8b06e3de9 100644 --- a/comfy_extras/nodes_cond.py +++ b/comfy_extras/nodes_cond.py @@ -1,15 +1,25 @@ +from typing_extensions import override + +from comfy_api.latest import ComfyExtension, io -class CLIPTextEncodeControlnet: +class CLIPTextEncodeControlnet(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": {"clip": ("CLIP", ), "conditioning": ("CONDITIONING", ), "text": ("STRING", {"multiline": True, "dynamicPrompts": True})}} - RETURN_TYPES = ("CONDITIONING",) - FUNCTION = "encode" + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="CLIPTextEncodeControlnet", + category="_for_testing/conditioning", + inputs=[ + io.Clip.Input("clip"), + io.Conditioning.Input("conditioning"), + io.String.Input("text", multiline=True, dynamic_prompts=True), + ], + outputs=[io.Conditioning.Output()], + is_experimental=True, + ) - CATEGORY = "_for_testing/conditioning" - - def encode(self, clip, conditioning, text): + @classmethod + def execute(cls, clip, conditioning, text) -> io.NodeOutput: tokens = clip.tokenize(text) cond, pooled = clip.encode_from_tokens(tokens, return_pooled=True) c = [] @@ -18,32 +28,41 @@ class CLIPTextEncodeControlnet: n[1]['cross_attn_controlnet'] = cond n[1]['pooled_output_controlnet'] = pooled c.append(n) - return (c, ) + return io.NodeOutput(c) -class T5TokenizerOptions: +class T5TokenizerOptions(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return { - "required": { - "clip": ("CLIP", ), - "min_padding": ("INT", {"default": 0, "min": 0, "max": 10000, "step": 1}), - "min_length": ("INT", {"default": 0, "min": 0, "max": 10000, "step": 1}), - } - } + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="T5TokenizerOptions", + category="_for_testing/conditioning", + inputs=[ + io.Clip.Input("clip"), + io.Int.Input("min_padding", default=0, min=0, max=10000, step=1), + io.Int.Input("min_length", default=0, min=0, max=10000, step=1), + ], + outputs=[io.Clip.Output()], + is_experimental=True, + ) - CATEGORY = "_for_testing/conditioning" - RETURN_TYPES = ("CLIP",) - FUNCTION = "set_options" - - def set_options(self, clip, min_padding, min_length): + @classmethod + def execute(cls, clip, min_padding, min_length) -> io.NodeOutput: clip = clip.clone() for t5_type in ["t5xxl", "pile_t5xl", "t5base", "mt5xl", "umt5xxl"]: clip.set_tokenizer_option("{}_min_padding".format(t5_type), min_padding) clip.set_tokenizer_option("{}_min_length".format(t5_type), min_length) - return (clip, ) + return io.NodeOutput(clip) -NODE_CLASS_MAPPINGS = { - "CLIPTextEncodeControlnet": CLIPTextEncodeControlnet, - "T5TokenizerOptions": T5TokenizerOptions, -} + +class CondExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + CLIPTextEncodeControlnet, + T5TokenizerOptions, + ] + + +async def comfy_entrypoint() -> CondExtension: + return CondExtension() From 53c9c7d39ad8a459e84a29e46a3e053154ef6013 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Sat, 13 Sep 2025 00:39:55 +0300 Subject: [PATCH 199/325] convert CFG nodes to V3 schema (#9717) --- comfy_extras/nodes_cfg.py | 71 +++++++++++++++++++++++++-------------- 1 file changed, 45 insertions(+), 26 deletions(-) diff --git a/comfy_extras/nodes_cfg.py b/comfy_extras/nodes_cfg.py index 5abdc115a..4ebb4b51e 100644 --- a/comfy_extras/nodes_cfg.py +++ b/comfy_extras/nodes_cfg.py @@ -1,5 +1,10 @@ +from typing_extensions import override + 
import torch +from comfy_api.latest import ComfyExtension, io + + # https://github.com/WeichenFan/CFG-Zero-star def optimized_scale(positive, negative): positive_flat = positive.reshape(positive.shape[0], -1) @@ -16,17 +21,20 @@ def optimized_scale(positive, negative): return st_star.reshape([positive.shape[0]] + [1] * (positive.ndim - 1)) -class CFGZeroStar: +class CFGZeroStar(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": {"model": ("MODEL",), - }} - RETURN_TYPES = ("MODEL",) - RETURN_NAMES = ("patched_model",) - FUNCTION = "patch" - CATEGORY = "advanced/guidance" + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="CFGZeroStar", + category="advanced/guidance", + inputs=[ + io.Model.Input("model"), + ], + outputs=[io.Model.Output(display_name="patched_model")], + ) - def patch(self, model): + @classmethod + def execute(cls, model) -> io.NodeOutput: m = model.clone() def cfg_zero_star(args): guidance_scale = args['cond_scale'] @@ -38,21 +46,24 @@ class CFGZeroStar: return out + uncond_p * (alpha - 1.0) + guidance_scale * uncond_p * (1.0 - alpha) m.set_model_sampler_post_cfg_function(cfg_zero_star) - return (m, ) + return io.NodeOutput(m) -class CFGNorm: +class CFGNorm(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": {"model": ("MODEL",), - "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01}), - }} - RETURN_TYPES = ("MODEL",) - RETURN_NAMES = ("patched_model",) - FUNCTION = "patch" - CATEGORY = "advanced/guidance" - EXPERIMENTAL = True + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="CFGNorm", + category="advanced/guidance", + inputs=[ + io.Model.Input("model"), + io.Float.Input("strength", default=1.0, min=0.0, max=100.0, step=0.01), + ], + outputs=[io.Model.Output(display_name="patched_model")], + is_experimental=True, + ) - def patch(self, model, strength): + @classmethod + def execute(cls, model, strength) -> io.NodeOutput: m = model.clone() def cfg_norm(args): cond_p = args['cond_denoised'] @@ -64,9 +75,17 @@ class CFGNorm: return pred_text_ * scale * strength m.set_model_sampler_post_cfg_function(cfg_norm) - return (m, ) + return io.NodeOutput(m) -NODE_CLASS_MAPPINGS = { - "CFGZeroStar": CFGZeroStar, - "CFGNorm": CFGNorm, -} + +class CfgExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + CFGZeroStar, + CFGNorm, + ] + + +async def comfy_entrypoint() -> CfgExtension: + return CfgExtension() From af99928f2218fc240dcfb3688ec47317ca146a78 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Sat, 13 Sep 2025 00:40:34 +0300 Subject: [PATCH 200/325] convert Canny node to V3 schema (#9743) --- comfy_extras/nodes_canny.py | 46 +++++++++++++++++++++++++------------ 1 file changed, 31 insertions(+), 15 deletions(-) diff --git a/comfy_extras/nodes_canny.py b/comfy_extras/nodes_canny.py index d85e6b856..576f3640a 100644 --- a/comfy_extras/nodes_canny.py +++ b/comfy_extras/nodes_canny.py @@ -1,25 +1,41 @@ from kornia.filters import canny +from typing_extensions import override + import comfy.model_management +from comfy_api.latest import ComfyExtension, io -class Canny: +class Canny(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": {"image": ("IMAGE",), - "low_threshold": ("FLOAT", {"default": 0.4, "min": 0.01, "max": 0.99, "step": 0.01}), - "high_threshold": ("FLOAT", {"default": 0.8, "min": 0.01, "max": 0.99, "step": 0.01}) - }} + def define_schema(cls): + 
return io.Schema( + node_id="Canny", + category="image/preprocessors", + inputs=[ + io.Image.Input("image"), + io.Float.Input("low_threshold", default=0.4, min=0.01, max=0.99, step=0.01), + io.Float.Input("high_threshold", default=0.8, min=0.01, max=0.99, step=0.01), + ], + outputs=[io.Image.Output()], + ) - RETURN_TYPES = ("IMAGE",) - FUNCTION = "detect_edge" + @classmethod + def detect_edge(cls, image, low_threshold, high_threshold): + # Deprecated: use the V3 schema's `execute` method instead of this. + return cls.execute(image, low_threshold, high_threshold) - CATEGORY = "image/preprocessors" - - def detect_edge(self, image, low_threshold, high_threshold): + @classmethod + def execute(cls, image, low_threshold, high_threshold) -> io.NodeOutput: output = canny(image.to(comfy.model_management.get_torch_device()).movedim(-1, 1), low_threshold, high_threshold) img_out = output[1].to(comfy.model_management.intermediate_device()).repeat(1, 3, 1, 1).movedim(1, -1) - return (img_out,) + return io.NodeOutput(img_out) -NODE_CLASS_MAPPINGS = { - "Canny": Canny, -} + +class CannyExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [Canny] + + +async def comfy_entrypoint() -> CannyExtension: + return CannyExtension() From 581bae2af30b0839a39734bd97006c4009f9d70a Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Sat, 13 Sep 2025 00:41:26 +0300 Subject: [PATCH 201/325] convert Moonvalley API nodes to the V3 schema (#9698) --- comfy_api_nodes/nodes_moonvalley.py | 572 +++++++++++++++------------- 1 file changed, 298 insertions(+), 274 deletions(-) diff --git a/comfy_api_nodes/nodes_moonvalley.py b/comfy_api_nodes/nodes_moonvalley.py index 806a70e06..08e838fef 100644 --- a/comfy_api_nodes/nodes_moonvalley.py +++ b/comfy_api_nodes/nodes_moonvalley.py @@ -1,6 +1,7 @@ import logging from typing import Any, Callable, Optional, TypeVar import torch +from typing_extensions import override from comfy_api_nodes.util.validation_utils import ( get_image_dimensions, validate_image_dimensions, @@ -26,11 +27,9 @@ from comfy_api_nodes.apinode_utils import ( upload_images_to_comfyapi, upload_video_to_comfyapi, ) -from comfy_api_nodes.mapper_utils import model_field_to_node_input -from comfy_api.input.video_types import VideoInput -from comfy.comfy_types.node_typing import IO -from comfy_api.input_impl import VideoFromFile +from comfy_api.input import VideoInput +from comfy_api.latest import ComfyExtension, InputImpl, io as comfy_io import av import io @@ -362,7 +361,7 @@ def trim_video(video: VideoInput, duration_sec: float) -> VideoInput: # Return as VideoFromFile using the buffer output_buffer.seek(0) - return VideoFromFile(output_buffer) + return InputImpl.VideoFromFile(output_buffer) except Exception as e: # Clean up on error @@ -373,166 +372,150 @@ def trim_video(video: VideoInput, duration_sec: float) -> VideoInput: raise RuntimeError(f"Failed to trim video: {str(e)}") from e -# --- BaseMoonvalleyVideoNode --- -class BaseMoonvalleyVideoNode: - def parseWidthHeightFromRes(self, resolution: str): - # Accepts a string like "16:9 (1920 x 1080)" and returns width, height as a dict - res_map = { - "16:9 (1920 x 1080)": {"width": 1920, "height": 1080}, - "9:16 (1080 x 1920)": {"width": 1080, "height": 1920}, - "1:1 (1152 x 1152)": {"width": 1152, "height": 1152}, - "4:3 (1536 x 1152)": {"width": 1536, "height": 1152}, - "3:4 (1152 x 1536)": {"width": 1152, "height": 1536}, - "21:9 (2560 x 1080)": {"width": 2560, 
"height": 1080}, - } - if resolution in res_map: - return res_map[resolution] - else: - # Default to 1920x1080 if unknown - return {"width": 1920, "height": 1080} +def parse_width_height_from_res(resolution: str): + # Accepts a string like "16:9 (1920 x 1080)" and returns width, height as a dict + res_map = { + "16:9 (1920 x 1080)": {"width": 1920, "height": 1080}, + "9:16 (1080 x 1920)": {"width": 1080, "height": 1920}, + "1:1 (1152 x 1152)": {"width": 1152, "height": 1152}, + "4:3 (1536 x 1152)": {"width": 1536, "height": 1152}, + "3:4 (1152 x 1536)": {"width": 1152, "height": 1536}, + "21:9 (2560 x 1080)": {"width": 2560, "height": 1080}, + } + return res_map.get(resolution, {"width": 1920, "height": 1080}) - def parseControlParameter(self, value): - control_map = { - "Motion Transfer": "motion_control", - "Canny": "canny_control", - "Pose Transfer": "pose_control", - "Depth": "depth_control", - } - if value in control_map: - return control_map[value] - else: - return control_map["Motion Transfer"] - async def get_response( - self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None - ) -> MoonvalleyPromptResponse: - return await poll_until_finished( - auth_kwargs, - ApiEndpoint( - path=f"{API_PROMPTS_ENDPOINT}/{task_id}", - method=HttpMethod.GET, - request_model=EmptyRequest, - response_model=MoonvalleyPromptResponse, - ), - result_url_extractor=get_video_url_from_response, - node_id=node_id, - ) +def parse_control_parameter(value): + control_map = { + "Motion Transfer": "motion_control", + "Canny": "canny_control", + "Pose Transfer": "pose_control", + "Depth": "depth_control", + } + return control_map.get(value, control_map["Motion Transfer"]) + + +async def get_response( + task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None +) -> MoonvalleyPromptResponse: + return await poll_until_finished( + auth_kwargs, + ApiEndpoint( + path=f"{API_PROMPTS_ENDPOINT}/{task_id}", + method=HttpMethod.GET, + request_model=EmptyRequest, + response_model=MoonvalleyPromptResponse, + ), + result_url_extractor=get_video_url_from_response, + node_id=node_id, + ) + + +class MoonvalleyImg2VideoNode(comfy_io.ComfyNode): @classmethod - def INPUT_TYPES(cls): - return { - "required": { - "prompt": model_field_to_node_input( - IO.STRING, - MoonvalleyTextToVideoRequest, - "prompt_text", + def define_schema(cls) -> comfy_io.Schema: + return comfy_io.Schema( + node_id="MoonvalleyImg2VideoNode", + display_name="Moonvalley Marey Image to Video", + category="api node/video/Moonvalley Marey", + description="Moonvalley Marey Image to Video Node", + inputs=[ + comfy_io.Image.Input( + "image", + tooltip="The reference image used to generate the video", + ), + comfy_io.String.Input( + "prompt", multiline=True, ), - "negative_prompt": model_field_to_node_input( - IO.STRING, - MoonvalleyTextToVideoInferenceParams, + comfy_io.String.Input( "negative_prompt", multiline=True, - default=" gopro, bright, contrast, static, overexposed, vignette, artifacts, still, noise, texture, scanlines, videogame, 360 camera, VR, transition, flare, saturation, distorted, warped, wide angle, saturated, vibrant, glowing, cross dissolve, cheesy, ugly hands, mutated hands, mutant, disfigured, extra fingers, blown out, horrible, blurry, worst quality, bad, dissolve, melt, fade in, fade out, wobbly, weird, low quality, plastic, stock footage, video camera, boring", + default=" gopro, bright, contrast, static, overexposed, vignette, " + "artifacts, still, noise, texture, scanlines, videogame, 360 camera, VR, 
transition, " + "flare, saturation, distorted, warped, wide angle, saturated, vibrant, glowing, " + "cross dissolve, cheesy, ugly hands, mutated hands, mutant, disfigured, extra fingers, " + "blown out, horrible, blurry, worst quality, bad, dissolve, melt, fade in, fade out, " + "wobbly, weird, low quality, plastic, stock footage, video camera, boring", + tooltip="Negative prompt text", ), - "resolution": ( - IO.COMBO, - { - "options": [ - "16:9 (1920 x 1080)", - "9:16 (1080 x 1920)", - "1:1 (1152 x 1152)", - "4:3 (1440 x 1080)", - "3:4 (1080 x 1440)", - "21:9 (2560 x 1080)", - ], - "default": "16:9 (1920 x 1080)", - "tooltip": "Resolution of the output video", - }, + comfy_io.Combo.Input( + "resolution", + options=[ + "16:9 (1920 x 1080)", + "9:16 (1080 x 1920)", + "1:1 (1152 x 1152)", + "4:3 (1536 x 1152)", + "3:4 (1152 x 1536)", + "21:9 (2560 x 1080)", + ], + default="16:9 (1920 x 1080)", + tooltip="Resolution of the output video", ), - "prompt_adherence": model_field_to_node_input( - IO.FLOAT, - MoonvalleyTextToVideoInferenceParams, - "guidance_scale", + comfy_io.Float.Input( + "prompt_adherence", default=10.0, - step=1, - min=1, - max=20, + min=1.0, + max=20.0, + step=1.0, + tooltip="Guidance scale for generation control", ), - "seed": model_field_to_node_input( - IO.INT, - MoonvalleyTextToVideoInferenceParams, + comfy_io.Int.Input( "seed", default=9, min=0, max=4294967295, step=1, - display="number", + display_mode=comfy_io.NumberDisplay.number, tooltip="Random seed value", ), - "steps": model_field_to_node_input( - IO.INT, - MoonvalleyTextToVideoInferenceParams, + comfy_io.Int.Input( "steps", default=100, min=1, max=100, + step=1, + tooltip="Number of denoising steps", ), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - "optional": { - "image": model_field_to_node_input( - IO.IMAGE, - MoonvalleyTextToVideoRequest, - "image_url", - tooltip="The reference image used to generate the video", - ), - }, - } - - RETURN_TYPES = ("STRING",) - FUNCTION = "generate" - CATEGORY = "api node/video/Moonvalley Marey" - API_NODE = True - - def generate(self, **kwargs): - return None - - -# --- MoonvalleyImg2VideoNode --- -class MoonvalleyImg2VideoNode(BaseMoonvalleyVideoNode): + ], + outputs=[comfy_io.Video.Output()], + hidden=[ + comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + comfy_io.Hidden.unique_id, + ], + is_api_node=True, + ) @classmethod - def INPUT_TYPES(cls): - return super().INPUT_TYPES() - - RETURN_TYPES = ("VIDEO",) - RETURN_NAMES = ("video",) - DESCRIPTION = "Moonvalley Marey Image to Video Node" - - async def generate( - self, prompt, negative_prompt, unique_id: Optional[str] = None, **kwargs - ): - image = kwargs.get("image", None) - if image is None: - raise MoonvalleyApiError("image is required") - + async def execute( + cls, + image: torch.Tensor, + prompt: str, + negative_prompt: str, + resolution: str, + prompt_adherence: float, + seed: int, + steps: int, + ) -> comfy_io.NodeOutput: validate_input_image(image, True) validate_prompts(prompt, negative_prompt, MOONVALLEY_MAREY_MAX_PROMPT_LENGTH) - width_height = self.parseWidthHeightFromRes(kwargs.get("resolution")) + width_height = parse_width_height_from_res(resolution) + + auth = { + "auth_token": cls.hidden.auth_token_comfy_org, + "comfy_api_key": cls.hidden.api_key_comfy_org, + } inference_params = MoonvalleyTextToVideoInferenceParams( negative_prompt=negative_prompt, - steps=kwargs.get("steps"), - 
seed=kwargs.get("seed"), - guidance_scale=kwargs.get("prompt_adherence"), + steps=steps, + seed=seed, + guidance_scale=prompt_adherence, num_frames=128, - width=width_height.get("width"), - height=width_height.get("height"), + width=width_height["width"], + height=width_height["height"], use_negative_prompts=True, ) """Upload image to comfy backend to have a URL available for further processing""" @@ -541,7 +524,7 @@ class MoonvalleyImg2VideoNode(BaseMoonvalleyVideoNode): image_url = ( await upload_images_to_comfyapi( - image, max_images=1, auth_kwargs=kwargs, mime_type=mime_type + image, max_images=1, auth_kwargs=auth, mime_type=mime_type ) )[0] @@ -556,127 +539,102 @@ class MoonvalleyImg2VideoNode(BaseMoonvalleyVideoNode): response_model=MoonvalleyPromptResponse, ), request=request, - auth_kwargs=kwargs, + auth_kwargs=auth, ) task_creation_response = await initial_operation.execute() validate_task_creation_response(task_creation_response) task_id = task_creation_response.id - final_response = await self.get_response( - task_id, auth_kwargs=kwargs, node_id=unique_id + final_response = await get_response( + task_id, auth_kwargs=auth, node_id=cls.hidden.unique_id ) video = await download_url_to_video_output(final_response.output_url) - return (video,) + return comfy_io.NodeOutput(video) -# --- MoonvalleyVid2VidNode --- -class MoonvalleyVideo2VideoNode(BaseMoonvalleyVideoNode): - def __init__(self): - super().__init__() +class MoonvalleyVideo2VideoNode(comfy_io.ComfyNode): @classmethod - def INPUT_TYPES(cls): - return { - "required": { - "prompt": model_field_to_node_input( - IO.STRING, - MoonvalleyVideoToVideoRequest, - "prompt_text", + def define_schema(cls) -> comfy_io.Schema: + return comfy_io.Schema( + node_id="MoonvalleyVideo2VideoNode", + display_name="Moonvalley Marey Video to Video", + category="api node/video/Moonvalley Marey", + description="", + inputs=[ + comfy_io.String.Input( + "prompt", multiline=True, + tooltip="Describes the video to generate", ), - "negative_prompt": model_field_to_node_input( - IO.STRING, - MoonvalleyVideoToVideoInferenceParams, + comfy_io.String.Input( "negative_prompt", multiline=True, - default=" gopro, bright, contrast, static, overexposed, vignette, artifacts, still, noise, texture, scanlines, videogame, 360 camera, VR, transition, flare, saturation, distorted, warped, wide angle, saturated, vibrant, glowing, cross dissolve, cheesy, ugly hands, mutated hands, mutant, disfigured, extra fingers, blown out, horrible, blurry, worst quality, bad, dissolve, melt, fade in, fade out, wobbly, weird, low quality, plastic, stock footage, video camera, boring", + default=" gopro, bright, contrast, static, overexposed, vignette, " + "artifacts, still, noise, texture, scanlines, videogame, 360 camera, VR, transition, " + "flare, saturation, distorted, warped, wide angle, saturated, vibrant, glowing, " + "cross dissolve, cheesy, ugly hands, mutated hands, mutant, disfigured, extra fingers, " + "blown out, horrible, blurry, worst quality, bad, dissolve, melt, fade in, fade out, " + "wobbly, weird, low quality, plastic, stock footage, video camera, boring", + tooltip="Negative prompt text", ), - "seed": model_field_to_node_input( - IO.INT, - MoonvalleyVideoToVideoInferenceParams, + comfy_io.Int.Input( "seed", default=9, min=0, max=4294967295, step=1, - display="number", + display_mode=comfy_io.NumberDisplay.number, tooltip="Random seed value", control_after_generate=False, ), - "prompt_adherence": model_field_to_node_input( - IO.FLOAT, - 
MoonvalleyVideoToVideoInferenceParams, - "guidance_scale", - default=10.0, + comfy_io.Video.Input( + "video", + tooltip="The reference video used to generate the output video. Must be at least 5 seconds long. " + "Videos longer than 5s will be automatically trimmed. Only MP4 format supported.", + ), + comfy_io.Combo.Input( + "control_type", + options=["Motion Transfer", "Pose Transfer"], + default="Motion Transfer", + optional=True, + ), + comfy_io.Int.Input( + "motion_intensity", + default=100, + min=0, + max=100, step=1, - min=1, - max=20, + tooltip="Only used if control_type is 'Motion Transfer'", + optional=True, ), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - "optional": { - "video": ( - IO.VIDEO, - { - "default": "", - "multiline": False, - "tooltip": "The reference video used to generate the output video. Must be at least 5 seconds long. Videos longer than 5s will be automatically trimmed. Only MP4 format supported.", - }, - ), - "control_type": ( - ["Motion Transfer", "Pose Transfer"], - {"default": "Motion Transfer"}, - ), - "motion_intensity": ( - "INT", - { - "default": 100, - "step": 1, - "min": 0, - "max": 100, - "tooltip": "Only used if control_type is 'Motion Transfer'", - }, - ), - "image": model_field_to_node_input( - IO.IMAGE, - MoonvalleyTextToVideoRequest, - "image_url", - tooltip="The reference image used to generate the video", - ), - }, + ], + outputs=[comfy_io.Video.Output()], + hidden=[ + comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + comfy_io.Hidden.unique_id, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, + prompt: str, + negative_prompt: str, + seed: int, + video: Optional[VideoInput] = None, + control_type: str = "Motion Transfer", + motion_intensity: Optional[int] = 100, + ) -> comfy_io.NodeOutput: + auth = { + "auth_token": cls.hidden.auth_token_comfy_org, + "comfy_api_key": cls.hidden.api_key_comfy_org, } - RETURN_TYPES = ("VIDEO",) - RETURN_NAMES = ("video",) - - async def generate( - self, prompt, negative_prompt, unique_id: Optional[str] = None, **kwargs - ): - video = kwargs.get("video") - image = kwargs.get("image", None) - - if not video: - raise MoonvalleyApiError("video is required") - - video_url = "" - if video: - validated_video = validate_video_to_video_input(video) - video_url = await upload_video_to_comfyapi( - validated_video, auth_kwargs=kwargs - ) - mime_type = "image/png" - - if not image is None: - validate_input_image(image, with_frame_conditioning=True) - image_url = await upload_images_to_comfyapi( - image=image, auth_kwargs=kwargs, max_images=1, mime_type=mime_type - ) - control_type = kwargs.get("control_type") - motion_intensity = kwargs.get("motion_intensity") + validated_video = validate_video_to_video_input(video) + video_url = await upload_video_to_comfyapi(validated_video, auth_kwargs=auth) """Validate prompts and inference input""" validate_prompts(prompt, negative_prompt) @@ -688,11 +646,11 @@ class MoonvalleyVideo2VideoNode(BaseMoonvalleyVideoNode): inference_params = MoonvalleyVideoToVideoInferenceParams( negative_prompt=negative_prompt, - seed=kwargs.get("seed"), + seed=seed, control_params=control_params, ) - control = self.parseControlParameter(control_type) + control = parse_control_parameter(control_type) request = MoonvalleyVideoToVideoRequest( control_type=control, @@ -700,7 +658,6 @@ class MoonvalleyVideo2VideoNode(BaseMoonvalleyVideoNode): prompt_text=prompt, 
inference_params=inference_params, ) - request.image_url = image_url if not image is None else None initial_operation = SynchronousOperation( endpoint=ApiEndpoint( @@ -710,58 +667,125 @@ class MoonvalleyVideo2VideoNode(BaseMoonvalleyVideoNode): response_model=MoonvalleyPromptResponse, ), request=request, - auth_kwargs=kwargs, + auth_kwargs=auth, ) task_creation_response = await initial_operation.execute() validate_task_creation_response(task_creation_response) task_id = task_creation_response.id - final_response = await self.get_response( - task_id, auth_kwargs=kwargs, node_id=unique_id + final_response = await get_response( + task_id, auth_kwargs=auth, node_id=cls.hidden.unique_id ) video = await download_url_to_video_output(final_response.output_url) - - return (video,) + return comfy_io.NodeOutput(video) -# --- MoonvalleyTxt2VideoNode --- -class MoonvalleyTxt2VideoNode(BaseMoonvalleyVideoNode): - def __init__(self): - super().__init__() - - RETURN_TYPES = ("VIDEO",) - RETURN_NAMES = ("video",) +class MoonvalleyTxt2VideoNode(comfy_io.ComfyNode): @classmethod - def INPUT_TYPES(cls): - input_types = super().INPUT_TYPES() - # Remove image-specific parameters - for param in ["image"]: - if param in input_types["optional"]: - del input_types["optional"][param] - return input_types + def define_schema(cls) -> comfy_io.Schema: + return comfy_io.Schema( + node_id="MoonvalleyTxt2VideoNode", + display_name="Moonvalley Marey Text to Video", + category="api node/video/Moonvalley Marey", + description="", + inputs=[ + comfy_io.String.Input( + "prompt", + multiline=True, + ), + comfy_io.String.Input( + "negative_prompt", + multiline=True, + default=" gopro, bright, contrast, static, overexposed, vignette, " + "artifacts, still, noise, texture, scanlines, videogame, 360 camera, VR, transition, " + "flare, saturation, distorted, warped, wide angle, saturated, vibrant, glowing, " + "cross dissolve, cheesy, ugly hands, mutated hands, mutant, disfigured, extra fingers, " + "blown out, horrible, blurry, worst quality, bad, dissolve, melt, fade in, fade out, " + "wobbly, weird, low quality, plastic, stock footage, video camera, boring", + tooltip="Negative prompt text", + ), + comfy_io.Combo.Input( + "resolution", + options=[ + "16:9 (1920 x 1080)", + "9:16 (1080 x 1920)", + "1:1 (1152 x 1152)", + "4:3 (1536 x 1152)", + "3:4 (1152 x 1536)", + "21:9 (2560 x 1080)", + ], + default="16:9 (1920 x 1080)", + tooltip="Resolution of the output video", + ), + comfy_io.Float.Input( + "prompt_adherence", + default=10.0, + min=1.0, + max=20.0, + step=1.0, + tooltip="Guidance scale for generation control", + ), + comfy_io.Int.Input( + "seed", + default=9, + min=0, + max=4294967295, + step=1, + display_mode=comfy_io.NumberDisplay.number, + tooltip="Random seed value", + ), + comfy_io.Int.Input( + "steps", + default=100, + min=1, + max=100, + step=1, + tooltip="Inference steps", + ), + ], + outputs=[comfy_io.Video.Output()], + hidden=[ + comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + comfy_io.Hidden.unique_id, + ], + is_api_node=True, + ) - async def generate( - self, prompt, negative_prompt, unique_id: Optional[str] = None, **kwargs - ): + @classmethod + async def execute( + cls, + prompt: str, + negative_prompt: str, + resolution: str, + prompt_adherence: float, + seed: int, + steps: int, + ) -> comfy_io.NodeOutput: validate_prompts(prompt, negative_prompt, MOONVALLEY_MAREY_MAX_PROMPT_LENGTH) - width_height = self.parseWidthHeightFromRes(kwargs.get("resolution")) + width_height = 
parse_width_height_from_res(resolution) + + auth = { + "auth_token": cls.hidden.auth_token_comfy_org, + "comfy_api_key": cls.hidden.api_key_comfy_org, + } inference_params = MoonvalleyTextToVideoInferenceParams( negative_prompt=negative_prompt, - steps=kwargs.get("steps"), - seed=kwargs.get("seed"), - guidance_scale=kwargs.get("prompt_adherence"), + steps=steps, + seed=seed, + guidance_scale=prompt_adherence, num_frames=128, - width=width_height.get("width"), - height=width_height.get("height"), + width=width_height["width"], + height=width_height["height"], ) request = MoonvalleyTextToVideoRequest( prompt_text=prompt, inference_params=inference_params ) - initial_operation = SynchronousOperation( + init_op = SynchronousOperation( endpoint=ApiEndpoint( path=API_TXT2VIDEO_ENDPOINT, method=HttpMethod.POST, @@ -769,29 +793,29 @@ class MoonvalleyTxt2VideoNode(BaseMoonvalleyVideoNode): response_model=MoonvalleyPromptResponse, ), request=request, - auth_kwargs=kwargs, + auth_kwargs=auth, ) - task_creation_response = await initial_operation.execute() + task_creation_response = await init_op.execute() validate_task_creation_response(task_creation_response) task_id = task_creation_response.id - final_response = await self.get_response( - task_id, auth_kwargs=kwargs, node_id=unique_id + final_response = await get_response( + task_id, auth_kwargs=auth, node_id=cls.hidden.unique_id ) video = await download_url_to_video_output(final_response.output_url) - return (video,) + return comfy_io.NodeOutput(video) -NODE_CLASS_MAPPINGS = { - "MoonvalleyImg2VideoNode": MoonvalleyImg2VideoNode, - "MoonvalleyTxt2VideoNode": MoonvalleyTxt2VideoNode, - "MoonvalleyVideo2VideoNode": MoonvalleyVideo2VideoNode, -} +class MoonvalleyExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[comfy_io.ComfyNode]]: + return [ + MoonvalleyImg2VideoNode, + MoonvalleyTxt2VideoNode, + MoonvalleyVideo2VideoNode, + ] -NODE_DISPLAY_NAME_MAPPINGS = { - "MoonvalleyImg2VideoNode": "Moonvalley Marey Image to Video", - "MoonvalleyTxt2VideoNode": "Moonvalley Marey Text to Video", - "MoonvalleyVideo2VideoNode": "Moonvalley Marey Video to Video", -} +async def comfy_entrypoint() -> MoonvalleyExtension: + return MoonvalleyExtension() From b149e2e1e302e75ce5b47e9b823b42b304d70b4b Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Fri, 12 Sep 2025 14:53:15 -0700 Subject: [PATCH 202/325] Better way of doing the generator for the hunyuan image noise aug. 
(#9834) --- comfy/model_base.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/comfy/model_base.py b/comfy/model_base.py index 4176bca25..324d89cff 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1448,7 +1448,9 @@ class HunyuanImage21Refiner(HunyuanImage21): image = self.process_latent_in(image) image = utils.resize_to_batch_size(image, noise.shape[0]) if noise_augmentation > 0: - noise = torch.randn(image.shape, generator=torch.manual_seed(kwargs.get("seed", 0) - 10), dtype=image.dtype, device="cpu").to(image.device) + generator = torch.Generator(device="cpu") + generator.manual_seed(kwargs.get("seed", 0) - 10) + noise = torch.randn(image.shape, generator=generator, dtype=image.dtype, device="cpu").to(image.device) image = noise_augmentation * noise + min(1.0 - noise_augmentation, 0.75) * image else: image = 0.75 * image From d7f40442f91a02946cab7445c6204bf154b1e86f Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Fri, 12 Sep 2025 15:07:38 -0700 Subject: [PATCH 203/325] Enable Runtime Selection of Attention Functions (#9639) * Looking into a @wrap_attn decorator to look for 'optimized_attention_override' entry in transformer_options * Created logging code for this branch so that it can be used to track down all the code paths where transformer_options would need to be added * Fix memory usage issue with inspect * Made WAN attention receive transformer_options, test node added to wan to test out attention override later * Added **kwargs to all attention functions so transformer_options could potentially be passed through * Make sure wrap_attn doesn't make itself recurse infinitely, attempt to load SageAttention and FlashAttention if not enabled so that they can be marked as available or not, create registry for available attention * Turn off attention logging for now, make AttentionOverrideTestNode have a dropdown with available attention (this is a test node only) * Make flux work with optimized_attention_override * Add logs to verify optimized_attention_override is passed all the way into attention function * Make Qwen work with optimized_attention_override * Made hidream work with optimized_attention_override * Made wan patches_replace work with optimized_attention_override * Made SD3 work with optimized_attention_override * Made HunyuanVideo work with optimized_attention_override * Made Mochi work with optimized_attention_override * Made LTX work with optimized_attention_override * Made StableAudio work with optimized_attention_override * Made optimized_attention_override work with ACE Step * Made Hunyuan3D work with optimized_attention_override * Make CosmosPredict2 work with optimized_attention_override * Made CosmosVideo work with optimized_attention_override * Made Omnigen 2 work with optimized_attention_override * Made StableCascade work with optimized_attention_override * Made AuraFlow work with optimized_attention_override * Made Lumina work with optimized_attention_override * Made Chroma work with optimized_attention_override * Made SVD work with optimized_attention_override * Fix WanI2VCrossAttention so that it expects to receive transformer_options * Fixed Wan2.1 Fun Camera transformer_options passthrough * Fixed WAN 2.1 VACE transformer_options passthrough * Add optimized to get_attention_function * Disable attention logs for now * Remove attention logging code * Remove _register_core_attention_functions, as we wouldn't want someone to call that, just in case * Satisfy ruff * Remove AttentionOverrideTest node, that's something to 
cook up for later --- comfy/ldm/ace/attention.py | 9 +- comfy/ldm/ace/model.py | 4 + comfy/ldm/audio/dit.py | 25 ++-- comfy/ldm/aura/mmdit.py | 29 ++--- comfy/ldm/cascade/common.py | 12 +- comfy/ldm/cascade/stage_b.py | 14 +-- comfy/ldm/cascade/stage_c.py | 14 +-- comfy/ldm/chroma/layers.py | 8 +- comfy/ldm/chroma/model.py | 17 ++- comfy/ldm/cosmos/blocks.py | 10 +- comfy/ldm/cosmos/model.py | 2 + comfy/ldm/cosmos/predict2.py | 17 ++- comfy/ldm/flux/layers.py | 10 +- comfy/ldm/flux/math.py | 4 +- comfy/ldm/flux/model.py | 17 ++- .../genmo/joint_model/asymm_models_joint.py | 11 +- comfy/ldm/hidream/model.py | 18 ++- comfy/ldm/hunyuan3d/model.py | 17 ++- comfy/ldm/hunyuan_video/model.py | 25 ++-- comfy/ldm/lightricks/model.py | 19 +-- comfy/ldm/lumina/model.py | 17 ++- comfy/ldm/modules/attention.py | 114 +++++++++++++----- comfy/ldm/modules/diffusionmodules/mmdit.py | 9 +- comfy/ldm/omnigen/omnigen2.py | 23 ++-- comfy/ldm/qwen_image/model.py | 12 +- comfy/ldm/wan/model.py | 38 +++--- 26 files changed, 316 insertions(+), 179 deletions(-) diff --git a/comfy/ldm/ace/attention.py b/comfy/ldm/ace/attention.py index f20a01669..670eb9783 100644 --- a/comfy/ldm/ace/attention.py +++ b/comfy/ldm/ace/attention.py @@ -133,6 +133,7 @@ class Attention(nn.Module): hidden_states: torch.Tensor, encoder_hidden_states: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, + transformer_options={}, **cross_attention_kwargs, ) -> torch.Tensor: return self.processor( @@ -140,6 +141,7 @@ class Attention(nn.Module): hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, + transformer_options=transformer_options, **cross_attention_kwargs, ) @@ -366,6 +368,7 @@ class CustomerAttnProcessor2_0: encoder_attention_mask: Optional[torch.FloatTensor] = None, rotary_freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]] = None, rotary_freqs_cis_cross: Union[torch.Tensor, Tuple[torch.Tensor]] = None, + transformer_options={}, *args, **kwargs, ) -> torch.Tensor: @@ -433,7 +436,7 @@ class CustomerAttnProcessor2_0: # the output of sdp = (batch, num_heads, seq_len, head_dim) hidden_states = optimized_attention( - query, key, value, heads=query.shape[1], mask=attention_mask, skip_reshape=True, + query, key, value, heads=query.shape[1], mask=attention_mask, skip_reshape=True, transformer_options=transformer_options, ).to(query.dtype) # linear proj @@ -697,6 +700,7 @@ class LinearTransformerBlock(nn.Module): rotary_freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]] = None, rotary_freqs_cis_cross: Union[torch.Tensor, Tuple[torch.Tensor]] = None, temb: torch.FloatTensor = None, + transformer_options={}, ): N = hidden_states.shape[0] @@ -720,6 +724,7 @@ class LinearTransformerBlock(nn.Module): encoder_attention_mask=encoder_attention_mask, rotary_freqs_cis=rotary_freqs_cis, rotary_freqs_cis_cross=rotary_freqs_cis_cross, + transformer_options=transformer_options, ) else: attn_output, _ = self.attn( @@ -729,6 +734,7 @@ class LinearTransformerBlock(nn.Module): encoder_attention_mask=None, rotary_freqs_cis=rotary_freqs_cis, rotary_freqs_cis_cross=None, + transformer_options=transformer_options, ) if self.use_adaln_single: @@ -743,6 +749,7 @@ class LinearTransformerBlock(nn.Module): encoder_attention_mask=encoder_attention_mask, rotary_freqs_cis=rotary_freqs_cis, rotary_freqs_cis_cross=rotary_freqs_cis_cross, + transformer_options=transformer_options, ) hidden_states = attn_output + hidden_states diff --git a/comfy/ldm/ace/model.py b/comfy/ldm/ace/model.py index 41d85eeb5..399329853 
100644 --- a/comfy/ldm/ace/model.py +++ b/comfy/ldm/ace/model.py @@ -314,6 +314,7 @@ class ACEStepTransformer2DModel(nn.Module): output_length: int = 0, block_controlnet_hidden_states: Optional[Union[List[torch.Tensor], torch.Tensor]] = None, controlnet_scale: Union[float, torch.Tensor] = 1.0, + transformer_options={}, ): embedded_timestep = self.timestep_embedder(self.time_proj(timestep).to(dtype=hidden_states.dtype)) temb = self.t_block(embedded_timestep) @@ -339,6 +340,7 @@ class ACEStepTransformer2DModel(nn.Module): rotary_freqs_cis=rotary_freqs_cis, rotary_freqs_cis_cross=encoder_rotary_freqs_cis, temb=temb, + transformer_options=transformer_options, ) output = self.final_layer(hidden_states, embedded_timestep, output_length) @@ -393,6 +395,7 @@ class ACEStepTransformer2DModel(nn.Module): output_length = hidden_states.shape[-1] + transformer_options = kwargs.get("transformer_options", {}) output = self.decode( hidden_states=hidden_states, attention_mask=attention_mask, @@ -402,6 +405,7 @@ class ACEStepTransformer2DModel(nn.Module): output_length=output_length, block_controlnet_hidden_states=block_controlnet_hidden_states, controlnet_scale=controlnet_scale, + transformer_options=transformer_options, ) return output diff --git a/comfy/ldm/audio/dit.py b/comfy/ldm/audio/dit.py index d0d69bbdc..ca865189e 100644 --- a/comfy/ldm/audio/dit.py +++ b/comfy/ldm/audio/dit.py @@ -298,7 +298,8 @@ class Attention(nn.Module): mask = None, context_mask = None, rotary_pos_emb = None, - causal = None + causal = None, + transformer_options={}, ): h, kv_h, has_context = self.num_heads, self.kv_heads, context is not None @@ -363,7 +364,7 @@ class Attention(nn.Module): heads_per_kv_head = h // kv_h k, v = map(lambda t: t.repeat_interleave(heads_per_kv_head, dim = 1), (k, v)) - out = optimized_attention(q, k, v, h, skip_reshape=True) + out = optimized_attention(q, k, v, h, skip_reshape=True, transformer_options=transformer_options) out = self.to_out(out) if mask is not None: @@ -488,7 +489,8 @@ class TransformerBlock(nn.Module): global_cond=None, mask = None, context_mask = None, - rotary_pos_emb = None + rotary_pos_emb = None, + transformer_options={} ): if self.global_cond_dim is not None and self.global_cond_dim > 0 and global_cond is not None: @@ -498,12 +500,12 @@ class TransformerBlock(nn.Module): residual = x x = self.pre_norm(x) x = x * (1 + scale_self) + shift_self - x = self.self_attn(x, mask = mask, rotary_pos_emb = rotary_pos_emb) + x = self.self_attn(x, mask = mask, rotary_pos_emb = rotary_pos_emb, transformer_options=transformer_options) x = x * torch.sigmoid(1 - gate_self) x = x + residual if context is not None: - x = x + self.cross_attn(self.cross_attend_norm(x), context = context, context_mask = context_mask) + x = x + self.cross_attn(self.cross_attend_norm(x), context = context, context_mask = context_mask, transformer_options=transformer_options) if self.conformer is not None: x = x + self.conformer(x) @@ -517,10 +519,10 @@ class TransformerBlock(nn.Module): x = x + residual else: - x = x + self.self_attn(self.pre_norm(x), mask = mask, rotary_pos_emb = rotary_pos_emb) + x = x + self.self_attn(self.pre_norm(x), mask = mask, rotary_pos_emb = rotary_pos_emb, transformer_options=transformer_options) if context is not None: - x = x + self.cross_attn(self.cross_attend_norm(x), context = context, context_mask = context_mask) + x = x + self.cross_attn(self.cross_attend_norm(x), context = context, context_mask = context_mask, transformer_options=transformer_options) if self.conformer is not None: 
x = x + self.conformer(x) @@ -606,7 +608,8 @@ class ContinuousTransformer(nn.Module): return_info = False, **kwargs ): - patches_replace = kwargs.get("transformer_options", {}).get("patches_replace", {}) + transformer_options = kwargs.get("transformer_options", {}) + patches_replace = transformer_options.get("patches_replace", {}) batch, seq, device = *x.shape[:2], x.device context = kwargs["context"] @@ -645,13 +648,13 @@ class ContinuousTransformer(nn.Module): if ("double_block", i) in blocks_replace: def block_wrap(args): out = {} - out["img"] = layer(args["img"], rotary_pos_emb=args["pe"], global_cond=args["vec"], context=args["txt"]) + out["img"] = layer(args["img"], rotary_pos_emb=args["pe"], global_cond=args["vec"], context=args["txt"], transformer_options=args["transformer_options"]) return out - out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": global_cond, "pe": rotary_pos_emb}, {"original_block": block_wrap}) + out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": global_cond, "pe": rotary_pos_emb, "transformer_options": transformer_options}, {"original_block": block_wrap}) x = out["img"] else: - x = layer(x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, context=context) + x = layer(x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, context=context, transformer_options=transformer_options) # x = checkpoint(layer, x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, **kwargs) if return_info: diff --git a/comfy/ldm/aura/mmdit.py b/comfy/ldm/aura/mmdit.py index d7f32b5e8..66d9613b6 100644 --- a/comfy/ldm/aura/mmdit.py +++ b/comfy/ldm/aura/mmdit.py @@ -85,7 +85,7 @@ class SingleAttention(nn.Module): ) #@torch.compile() - def forward(self, c): + def forward(self, c, transformer_options={}): bsz, seqlen1, _ = c.shape @@ -95,7 +95,7 @@ class SingleAttention(nn.Module): v = v.view(bsz, seqlen1, self.n_heads, self.head_dim) q, k = self.q_norm1(q), self.k_norm1(k) - output = optimized_attention(q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3), self.n_heads, skip_reshape=True) + output = optimized_attention(q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3), self.n_heads, skip_reshape=True, transformer_options=transformer_options) c = self.w1o(output) return c @@ -144,7 +144,7 @@ class DoubleAttention(nn.Module): #@torch.compile() - def forward(self, c, x): + def forward(self, c, x, transformer_options={}): bsz, seqlen1, _ = c.shape bsz, seqlen2, _ = x.shape @@ -168,7 +168,7 @@ class DoubleAttention(nn.Module): torch.cat([cv, xv], dim=1), ) - output = optimized_attention(q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3), self.n_heads, skip_reshape=True) + output = optimized_attention(q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3), self.n_heads, skip_reshape=True, transformer_options=transformer_options) c, x = output.split([seqlen1, seqlen2], dim=1) c = self.w1o(c) @@ -207,7 +207,7 @@ class MMDiTBlock(nn.Module): self.is_last = is_last #@torch.compile() - def forward(self, c, x, global_cond, **kwargs): + def forward(self, c, x, global_cond, transformer_options={}, **kwargs): cres, xres = c, x @@ -225,7 +225,7 @@ class MMDiTBlock(nn.Module): x = modulate(self.normX1(x), xshift_msa, xscale_msa) # attention - c, x = self.attn(c, x) + c, x = self.attn(c, x, transformer_options=transformer_options) c = self.normC2(cres + cgate_msa.unsqueeze(1) * c) @@ -255,13 +255,13 @@ class DiTBlock(nn.Module): self.mlp = MLP(dim, hidden_dim=dim * 4, 
dtype=dtype, device=device, operations=operations) #@torch.compile() - def forward(self, cx, global_cond, **kwargs): + def forward(self, cx, global_cond, transformer_options={}, **kwargs): cxres = cx shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.modCX( global_cond ).chunk(6, dim=1) cx = modulate(self.norm1(cx), shift_msa, scale_msa) - cx = self.attn(cx) + cx = self.attn(cx, transformer_options=transformer_options) cx = self.norm2(cxres + gate_msa.unsqueeze(1) * cx) mlpout = self.mlp(modulate(cx, shift_mlp, scale_mlp)) cx = gate_mlp.unsqueeze(1) * mlpout @@ -473,13 +473,14 @@ class MMDiT(nn.Module): out = {} out["txt"], out["img"] = layer(args["txt"], args["img"], - args["vec"]) + args["vec"], + transformer_options=args["transformer_options"]) return out - out = blocks_replace[("double_block", i)]({"img": x, "txt": c, "vec": global_cond}, {"original_block": block_wrap}) + out = blocks_replace[("double_block", i)]({"img": x, "txt": c, "vec": global_cond, "transformer_options": transformer_options}, {"original_block": block_wrap}) c = out["txt"] x = out["img"] else: - c, x = layer(c, x, global_cond, **kwargs) + c, x = layer(c, x, global_cond, transformer_options=transformer_options, **kwargs) if len(self.single_layers) > 0: c_len = c.size(1) @@ -488,13 +489,13 @@ class MMDiT(nn.Module): if ("single_block", i) in blocks_replace: def block_wrap(args): out = {} - out["img"] = layer(args["img"], args["vec"]) + out["img"] = layer(args["img"], args["vec"], transformer_options=args["transformer_options"]) return out - out = blocks_replace[("single_block", i)]({"img": cx, "vec": global_cond}, {"original_block": block_wrap}) + out = blocks_replace[("single_block", i)]({"img": cx, "vec": global_cond, "transformer_options": transformer_options}, {"original_block": block_wrap}) cx = out["img"] else: - cx = layer(cx, global_cond, **kwargs) + cx = layer(cx, global_cond, transformer_options=transformer_options, **kwargs) x = cx[:, c_len:] diff --git a/comfy/ldm/cascade/common.py b/comfy/ldm/cascade/common.py index 3eaa0c821..42ef98c7a 100644 --- a/comfy/ldm/cascade/common.py +++ b/comfy/ldm/cascade/common.py @@ -32,12 +32,12 @@ class OptimizedAttention(nn.Module): self.out_proj = operations.Linear(c, c, bias=True, dtype=dtype, device=device) - def forward(self, q, k, v): + def forward(self, q, k, v, transformer_options={}): q = self.to_q(q) k = self.to_k(k) v = self.to_v(v) - out = optimized_attention(q, k, v, self.heads) + out = optimized_attention(q, k, v, self.heads, transformer_options=transformer_options) return self.out_proj(out) @@ -47,13 +47,13 @@ class Attention2D(nn.Module): self.attn = OptimizedAttention(c, nhead, dtype=dtype, device=device, operations=operations) # self.attn = nn.MultiheadAttention(c, nhead, dropout=dropout, bias=True, batch_first=True, dtype=dtype, device=device) - def forward(self, x, kv, self_attn=False): + def forward(self, x, kv, self_attn=False, transformer_options={}): orig_shape = x.shape x = x.view(x.size(0), x.size(1), -1).permute(0, 2, 1) # Bx4xHxW -> Bx(HxW)x4 if self_attn: kv = torch.cat([x, kv], dim=1) # x = self.attn(x, kv, kv, need_weights=False)[0] - x = self.attn(x, kv, kv) + x = self.attn(x, kv, kv, transformer_options=transformer_options) x = x.permute(0, 2, 1).view(*orig_shape) return x @@ -114,9 +114,9 @@ class AttnBlock(nn.Module): operations.Linear(c_cond, c, dtype=dtype, device=device) ) - def forward(self, x, kv): + def forward(self, x, kv, transformer_options={}): kv = self.kv_mapper(kv) - x = x + self.attention(self.norm(x), 
kv, self_attn=self.self_attn) + x = x + self.attention(self.norm(x), kv, self_attn=self.self_attn, transformer_options=transformer_options) return x diff --git a/comfy/ldm/cascade/stage_b.py b/comfy/ldm/cascade/stage_b.py index 773830956..428c67fdf 100644 --- a/comfy/ldm/cascade/stage_b.py +++ b/comfy/ldm/cascade/stage_b.py @@ -173,7 +173,7 @@ class StageB(nn.Module): clip = self.clip_norm(clip) return clip - def _down_encode(self, x, r_embed, clip): + def _down_encode(self, x, r_embed, clip, transformer_options={}): level_outputs = [] block_group = zip(self.down_blocks, self.down_downscalers, self.down_repeat_mappers) for down_block, downscaler, repmap in block_group: @@ -187,7 +187,7 @@ class StageB(nn.Module): elif isinstance(block, AttnBlock) or ( hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, AttnBlock)): - x = block(x, clip) + x = block(x, clip, transformer_options=transformer_options) elif isinstance(block, TimestepBlock) or ( hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, TimestepBlock)): @@ -199,7 +199,7 @@ class StageB(nn.Module): level_outputs.insert(0, x) return level_outputs - def _up_decode(self, level_outputs, r_embed, clip): + def _up_decode(self, level_outputs, r_embed, clip, transformer_options={}): x = level_outputs[0] block_group = zip(self.up_blocks, self.up_upscalers, self.up_repeat_mappers) for i, (up_block, upscaler, repmap) in enumerate(block_group): @@ -216,7 +216,7 @@ class StageB(nn.Module): elif isinstance(block, AttnBlock) or ( hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, AttnBlock)): - x = block(x, clip) + x = block(x, clip, transformer_options=transformer_options) elif isinstance(block, TimestepBlock) or ( hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, TimestepBlock)): @@ -228,7 +228,7 @@ class StageB(nn.Module): x = upscaler(x) return x - def forward(self, x, r, effnet, clip, pixels=None, **kwargs): + def forward(self, x, r, effnet, clip, pixels=None, transformer_options={}, **kwargs): if pixels is None: pixels = x.new_zeros(x.size(0), 3, 8, 8) @@ -245,8 +245,8 @@ class StageB(nn.Module): nn.functional.interpolate(effnet, size=x.shape[-2:], mode='bilinear', align_corners=True)) x = x + nn.functional.interpolate(self.pixels_mapper(pixels), size=x.shape[-2:], mode='bilinear', align_corners=True) - level_outputs = self._down_encode(x, r_embed, clip) - x = self._up_decode(level_outputs, r_embed, clip) + level_outputs = self._down_encode(x, r_embed, clip, transformer_options=transformer_options) + x = self._up_decode(level_outputs, r_embed, clip, transformer_options=transformer_options) return self.clf(x) def update_weights_ema(self, src_model, beta=0.999): diff --git a/comfy/ldm/cascade/stage_c.py b/comfy/ldm/cascade/stage_c.py index b952d0349..ebc4434e2 100644 --- a/comfy/ldm/cascade/stage_c.py +++ b/comfy/ldm/cascade/stage_c.py @@ -182,7 +182,7 @@ class StageC(nn.Module): clip = self.clip_norm(clip) return clip - def _down_encode(self, x, r_embed, clip, cnet=None): + def _down_encode(self, x, r_embed, clip, cnet=None, transformer_options={}): level_outputs = [] block_group = zip(self.down_blocks, self.down_downscalers, self.down_repeat_mappers) for down_block, downscaler, repmap in block_group: @@ -201,7 +201,7 @@ class StageC(nn.Module): elif isinstance(block, AttnBlock) or ( hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, AttnBlock)): - x = block(x, clip) + x = block(x, clip, 
transformer_options=transformer_options) elif isinstance(block, TimestepBlock) or ( hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, TimestepBlock)): @@ -213,7 +213,7 @@ class StageC(nn.Module): level_outputs.insert(0, x) return level_outputs - def _up_decode(self, level_outputs, r_embed, clip, cnet=None): + def _up_decode(self, level_outputs, r_embed, clip, cnet=None, transformer_options={}): x = level_outputs[0] block_group = zip(self.up_blocks, self.up_upscalers, self.up_repeat_mappers) for i, (up_block, upscaler, repmap) in enumerate(block_group): @@ -235,7 +235,7 @@ class StageC(nn.Module): elif isinstance(block, AttnBlock) or ( hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, AttnBlock)): - x = block(x, clip) + x = block(x, clip, transformer_options=transformer_options) elif isinstance(block, TimestepBlock) or ( hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, TimestepBlock)): @@ -247,7 +247,7 @@ class StageC(nn.Module): x = upscaler(x) return x - def forward(self, x, r, clip_text, clip_text_pooled, clip_img, control=None, **kwargs): + def forward(self, x, r, clip_text, clip_text_pooled, clip_img, control=None, transformer_options={}, **kwargs): # Process the conditioning embeddings r_embed = self.gen_r_embedding(r).to(dtype=x.dtype) for c in self.t_conds: @@ -262,8 +262,8 @@ class StageC(nn.Module): # Model Blocks x = self.embedding(x) - level_outputs = self._down_encode(x, r_embed, clip, cnet) - x = self._up_decode(level_outputs, r_embed, clip, cnet) + level_outputs = self._down_encode(x, r_embed, clip, cnet, transformer_options=transformer_options) + x = self._up_decode(level_outputs, r_embed, clip, cnet, transformer_options=transformer_options) return self.clf(x) def update_weights_ema(self, src_model, beta=0.999): diff --git a/comfy/ldm/chroma/layers.py b/comfy/ldm/chroma/layers.py index 2a0dec606..fc7110cce 100644 --- a/comfy/ldm/chroma/layers.py +++ b/comfy/ldm/chroma/layers.py @@ -76,7 +76,7 @@ class DoubleStreamBlock(nn.Module): ) self.flipped_img_txt = flipped_img_txt - def forward(self, img: Tensor, txt: Tensor, pe: Tensor, vec: Tensor, attn_mask=None): + def forward(self, img: Tensor, txt: Tensor, pe: Tensor, vec: Tensor, attn_mask=None, transformer_options={}): (img_mod1, img_mod2), (txt_mod1, txt_mod2) = vec # prepare image for attention @@ -95,7 +95,7 @@ class DoubleStreamBlock(nn.Module): attn = attention(torch.cat((txt_q, img_q), dim=2), torch.cat((txt_k, img_k), dim=2), torch.cat((txt_v, img_v), dim=2), - pe=pe, mask=attn_mask) + pe=pe, mask=attn_mask, transformer_options=transformer_options) txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :] @@ -148,7 +148,7 @@ class SingleStreamBlock(nn.Module): self.mlp_act = nn.GELU(approximate="tanh") - def forward(self, x: Tensor, pe: Tensor, vec: Tensor, attn_mask=None) -> Tensor: + def forward(self, x: Tensor, pe: Tensor, vec: Tensor, attn_mask=None, transformer_options={}) -> Tensor: mod = vec x_mod = torch.addcmul(mod.shift, 1 + mod.scale, self.pre_norm(x)) qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1) @@ -157,7 +157,7 @@ class SingleStreamBlock(nn.Module): q, k = self.norm(q, k, v) # compute attention - attn = attention(q, k, v, pe=pe, mask=attn_mask) + attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options) # compute activation in mlp stream, cat again and run second linear layer output = self.linear2(torch.cat((attn, 
self.mlp_act(mlp)), 2)) x.addcmul_(mod.gate, output) diff --git a/comfy/ldm/chroma/model.py b/comfy/ldm/chroma/model.py index 5cff44dc8..4f709f87d 100644 --- a/comfy/ldm/chroma/model.py +++ b/comfy/ldm/chroma/model.py @@ -193,14 +193,16 @@ class Chroma(nn.Module): txt=args["txt"], vec=args["vec"], pe=args["pe"], - attn_mask=args.get("attn_mask")) + attn_mask=args.get("attn_mask"), + transformer_options=args.get("transformer_options")) return out out = blocks_replace[("double_block", i)]({"img": img, "txt": txt, "vec": double_mod, "pe": pe, - "attn_mask": attn_mask}, + "attn_mask": attn_mask, + "transformer_options": transformer_options}, {"original_block": block_wrap}) txt = out["txt"] img = out["img"] @@ -209,7 +211,8 @@ class Chroma(nn.Module): txt=txt, vec=double_mod, pe=pe, - attn_mask=attn_mask) + attn_mask=attn_mask, + transformer_options=transformer_options) if control is not None: # Controlnet control_i = control.get("input") @@ -229,17 +232,19 @@ class Chroma(nn.Module): out["img"] = block(args["img"], vec=args["vec"], pe=args["pe"], - attn_mask=args.get("attn_mask")) + attn_mask=args.get("attn_mask"), + transformer_options=args.get("transformer_options")) return out out = blocks_replace[("single_block", i)]({"img": img, "vec": single_mod, "pe": pe, - "attn_mask": attn_mask}, + "attn_mask": attn_mask, + "transformer_options": transformer_options}, {"original_block": block_wrap}) img = out["img"] else: - img = block(img, vec=single_mod, pe=pe, attn_mask=attn_mask) + img = block(img, vec=single_mod, pe=pe, attn_mask=attn_mask, transformer_options=transformer_options) if control is not None: # Controlnet control_o = control.get("output") diff --git a/comfy/ldm/cosmos/blocks.py b/comfy/ldm/cosmos/blocks.py index 5c4356a3f..afb43d469 100644 --- a/comfy/ldm/cosmos/blocks.py +++ b/comfy/ldm/cosmos/blocks.py @@ -176,6 +176,7 @@ class Attention(nn.Module): context=None, mask=None, rope_emb=None, + transformer_options={}, **kwargs, ): """ @@ -184,7 +185,7 @@ class Attention(nn.Module): context (Optional[Tensor]): The key tensor of shape [B, Mk, K] or use x as context [self attention] if None """ q, k, v = self.cal_qkv(x, context, mask, rope_emb=rope_emb, **kwargs) - out = optimized_attention(q, k, v, self.heads, skip_reshape=True, mask=mask, skip_output_reshape=True) + out = optimized_attention(q, k, v, self.heads, skip_reshape=True, mask=mask, skip_output_reshape=True, transformer_options=transformer_options) del q, k, v out = rearrange(out, " b n s c -> s b (n c)") return self.to_out(out) @@ -546,6 +547,7 @@ class VideoAttn(nn.Module): context: Optional[torch.Tensor] = None, crossattn_mask: Optional[torch.Tensor] = None, rope_emb_L_1_1_D: Optional[torch.Tensor] = None, + transformer_options: Optional[dict] = {}, ) -> torch.Tensor: """ Forward pass for video attention. @@ -571,6 +573,7 @@ class VideoAttn(nn.Module): context_M_B_D, crossattn_mask, rope_emb=rope_emb_L_1_1_D, + transformer_options=transformer_options, ) x_T_H_W_B_D = rearrange(x_THW_B_D, "(t h w) b d -> t h w b d", h=H, w=W) return x_T_H_W_B_D @@ -665,6 +668,7 @@ class DITBuildingBlock(nn.Module): crossattn_mask: Optional[torch.Tensor] = None, rope_emb_L_1_1_D: Optional[torch.Tensor] = None, adaln_lora_B_3D: Optional[torch.Tensor] = None, + transformer_options: Optional[dict] = {}, ) -> torch.Tensor: """ Forward pass for dynamically configured blocks with adaptive normalization. 
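The hunks above and below all make the same mechanical change: each attention module and transformer block gains a transformer_options={} parameter and forwards it into optimized_attention(). A minimal, self-contained sketch of that threading pattern follows; the class name, the stand-in optimized_attention, and the toy shapes are illustrative assumptions, not ComfyUI source.

import torch
import torch.nn.functional as F
from torch import nn


def optimized_attention(q, k, v, heads, mask=None, transformer_options={}):
    # Stand-in for comfy.ldm.modules.attention.optimized_attention: plain SDPA.
    # In the real dispatcher, transformer_options is what wrap_attn inspects for
    # an "optimized_attention_override" entry; here it is simply accepted.
    b, s, _ = q.shape
    q, k, v = (t.view(b, s, heads, -1).transpose(1, 2) for t in (q, k, v))
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
    return out.transpose(1, 2).reshape(b, s, -1)


class SketchSelfAttention(nn.Module):
    # Illustrative module showing the signature change applied across these diffs.
    def __init__(self, dim, heads):
        super().__init__()
        self.heads = heads
        self.qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x, transformer_options={}):
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        # The options dict is passed through untouched, mirroring the repeated
        # "+ transformer_options=transformer_options" additions in these hunks.
        out = optimized_attention(q, k, v, self.heads, transformer_options=transformer_options)
        return self.proj(out)


# Usage: the caller (ultimately the model's forward) owns the dict and threads it down.
y = SketchSelfAttention(64, 8)(torch.randn(1, 16, 64), transformer_options={})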
@@ -702,6 +706,7 @@ class DITBuildingBlock(nn.Module): adaln_norm_state(self.norm_state, x, scale_1_1_1_B_D, shift_1_1_1_B_D), context=None, rope_emb_L_1_1_D=rope_emb_L_1_1_D, + transformer_options=transformer_options, ) elif self.block_type in ["cross_attn", "ca"]: x = x + gate_1_1_1_B_D * self.block( @@ -709,6 +714,7 @@ class DITBuildingBlock(nn.Module): context=crossattn_emb, crossattn_mask=crossattn_mask, rope_emb_L_1_1_D=rope_emb_L_1_1_D, + transformer_options=transformer_options, ) else: raise ValueError(f"Unknown block type: {self.block_type}") @@ -784,6 +790,7 @@ class GeneralDITTransformerBlock(nn.Module): crossattn_mask: Optional[torch.Tensor] = None, rope_emb_L_1_1_D: Optional[torch.Tensor] = None, adaln_lora_B_3D: Optional[torch.Tensor] = None, + transformer_options: Optional[dict] = {}, ) -> torch.Tensor: for block in self.blocks: x = block( @@ -793,5 +800,6 @@ class GeneralDITTransformerBlock(nn.Module): crossattn_mask, rope_emb_L_1_1_D=rope_emb_L_1_1_D, adaln_lora_B_3D=adaln_lora_B_3D, + transformer_options=transformer_options, ) return x diff --git a/comfy/ldm/cosmos/model.py b/comfy/ldm/cosmos/model.py index 53698b758..52ef7ef43 100644 --- a/comfy/ldm/cosmos/model.py +++ b/comfy/ldm/cosmos/model.py @@ -520,6 +520,7 @@ class GeneralDIT(nn.Module): x.shape == extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape ), f"{x.shape} != {extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape} {original_shape}" + transformer_options = kwargs.get("transformer_options", {}) for _, block in self.blocks.items(): assert ( self.blocks["block0"].x_format == block.x_format @@ -534,6 +535,7 @@ class GeneralDIT(nn.Module): crossattn_mask, rope_emb_L_1_1_D=rope_emb_L_1_1_D, adaln_lora_B_3D=adaln_lora_B_3D, + transformer_options=transformer_options, ) x_B_T_H_W_D = rearrange(x, "T H W B D -> B T H W D") diff --git a/comfy/ldm/cosmos/predict2.py b/comfy/ldm/cosmos/predict2.py index fcc83ba76..07a4fc79f 100644 --- a/comfy/ldm/cosmos/predict2.py +++ b/comfy/ldm/cosmos/predict2.py @@ -44,7 +44,7 @@ class GPT2FeedForward(nn.Module): return x -def torch_attention_op(q_B_S_H_D: torch.Tensor, k_B_S_H_D: torch.Tensor, v_B_S_H_D: torch.Tensor) -> torch.Tensor: +def torch_attention_op(q_B_S_H_D: torch.Tensor, k_B_S_H_D: torch.Tensor, v_B_S_H_D: torch.Tensor, transformer_options: Optional[dict] = {}) -> torch.Tensor: """Computes multi-head attention using PyTorch's native implementation. This function provides a PyTorch backend alternative to Transformer Engine's attention operation. @@ -71,7 +71,7 @@ def torch_attention_op(q_B_S_H_D: torch.Tensor, k_B_S_H_D: torch.Tensor, v_B_S_H q_B_H_S_D = rearrange(q_B_S_H_D, "b ... h k -> b h ... k").view(in_q_shape[0], in_q_shape[-2], -1, in_q_shape[-1]) k_B_H_S_D = rearrange(k_B_S_H_D, "b ... h v -> b h ... v").view(in_k_shape[0], in_k_shape[-2], -1, in_k_shape[-1]) v_B_H_S_D = rearrange(v_B_S_H_D, "b ... h v -> b h ... 
v").view(in_k_shape[0], in_k_shape[-2], -1, in_k_shape[-1]) - return optimized_attention(q_B_H_S_D, k_B_H_S_D, v_B_H_S_D, in_q_shape[-2], skip_reshape=True) + return optimized_attention(q_B_H_S_D, k_B_H_S_D, v_B_H_S_D, in_q_shape[-2], skip_reshape=True, transformer_options=transformer_options) class Attention(nn.Module): @@ -180,8 +180,8 @@ class Attention(nn.Module): return q, k, v - def compute_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor: - result = self.attn_op(q, k, v) # [B, S, H, D] + def compute_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, transformer_options: Optional[dict] = {}) -> torch.Tensor: + result = self.attn_op(q, k, v, transformer_options=transformer_options) # [B, S, H, D] return self.output_dropout(self.output_proj(result)) def forward( @@ -189,6 +189,7 @@ class Attention(nn.Module): x: torch.Tensor, context: Optional[torch.Tensor] = None, rope_emb: Optional[torch.Tensor] = None, + transformer_options: Optional[dict] = {}, ) -> torch.Tensor: """ Args: @@ -196,7 +197,7 @@ class Attention(nn.Module): context (Optional[Tensor]): The key tensor of shape [B, Mk, K] or use x as context [self attention] if None """ q, k, v = self.compute_qkv(x, context, rope_emb=rope_emb) - return self.compute_attention(q, k, v) + return self.compute_attention(q, k, v, transformer_options=transformer_options) class Timesteps(nn.Module): @@ -459,6 +460,7 @@ class Block(nn.Module): rope_emb_L_1_1_D: Optional[torch.Tensor] = None, adaln_lora_B_T_3D: Optional[torch.Tensor] = None, extra_per_block_pos_emb: Optional[torch.Tensor] = None, + transformer_options: Optional[dict] = {}, ) -> torch.Tensor: if extra_per_block_pos_emb is not None: x_B_T_H_W_D = x_B_T_H_W_D + extra_per_block_pos_emb @@ -512,6 +514,7 @@ class Block(nn.Module): rearrange(normalized_x_B_T_H_W_D, "b t h w d -> b (t h w) d"), None, rope_emb=rope_emb_L_1_1_D, + transformer_options=transformer_options, ), "b (t h w) d -> b t h w d", t=T, @@ -525,6 +528,7 @@ class Block(nn.Module): layer_norm_cross_attn: Callable, _scale_cross_attn_B_T_1_1_D: torch.Tensor, _shift_cross_attn_B_T_1_1_D: torch.Tensor, + transformer_options: Optional[dict] = {}, ) -> torch.Tensor: _normalized_x_B_T_H_W_D = _fn( _x_B_T_H_W_D, layer_norm_cross_attn, _scale_cross_attn_B_T_1_1_D, _shift_cross_attn_B_T_1_1_D @@ -534,6 +538,7 @@ class Block(nn.Module): rearrange(_normalized_x_B_T_H_W_D, "b t h w d -> b (t h w) d"), crossattn_emb, rope_emb=rope_emb_L_1_1_D, + transformer_options=transformer_options, ), "b (t h w) d -> b t h w d", t=T, @@ -547,6 +552,7 @@ class Block(nn.Module): self.layer_norm_cross_attn, scale_cross_attn_B_T_1_1_D, shift_cross_attn_B_T_1_1_D, + transformer_options=transformer_options, ) x_B_T_H_W_D = result_B_T_H_W_D * gate_cross_attn_B_T_1_1_D + x_B_T_H_W_D @@ -865,6 +871,7 @@ class MiniTrainDIT(nn.Module): "rope_emb_L_1_1_D": rope_emb_L_1_1_D.unsqueeze(1).unsqueeze(0), "adaln_lora_B_T_3D": adaln_lora_B_T_3D, "extra_per_block_pos_emb": extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, + "transformer_options": kwargs.get("transformer_options", {}), } for block in self.blocks: x_B_T_H_W_D = block( diff --git a/comfy/ldm/flux/layers.py b/comfy/ldm/flux/layers.py index 113eb2096..ef21b416b 100644 --- a/comfy/ldm/flux/layers.py +++ b/comfy/ldm/flux/layers.py @@ -159,7 +159,7 @@ class DoubleStreamBlock(nn.Module): ) self.flipped_img_txt = flipped_img_txt - def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims_img=None, modulation_dims_txt=None): + def 
forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims_img=None, modulation_dims_txt=None, transformer_options={}): img_mod1, img_mod2 = self.img_mod(vec) txt_mod1, txt_mod2 = self.txt_mod(vec) @@ -182,7 +182,7 @@ class DoubleStreamBlock(nn.Module): attn = attention(torch.cat((img_q, txt_q), dim=2), torch.cat((img_k, txt_k), dim=2), torch.cat((img_v, txt_v), dim=2), - pe=pe, mask=attn_mask) + pe=pe, mask=attn_mask, transformer_options=transformer_options) img_attn, txt_attn = attn[:, : img.shape[1]], attn[:, img.shape[1]:] else: @@ -190,7 +190,7 @@ class DoubleStreamBlock(nn.Module): attn = attention(torch.cat((txt_q, img_q), dim=2), torch.cat((txt_k, img_k), dim=2), torch.cat((txt_v, img_v), dim=2), - pe=pe, mask=attn_mask) + pe=pe, mask=attn_mask, transformer_options=transformer_options) txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:] @@ -244,7 +244,7 @@ class SingleStreamBlock(nn.Module): self.mlp_act = nn.GELU(approximate="tanh") self.modulation = Modulation(hidden_size, double=False, dtype=dtype, device=device, operations=operations) - def forward(self, x: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims=None) -> Tensor: + def forward(self, x: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims=None, transformer_options={}) -> Tensor: mod, _ = self.modulation(vec) qkv, mlp = torch.split(self.linear1(apply_mod(self.pre_norm(x), (1 + mod.scale), mod.shift, modulation_dims)), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1) @@ -252,7 +252,7 @@ class SingleStreamBlock(nn.Module): q, k = self.norm(q, k, v) # compute attention - attn = attention(q, k, v, pe=pe, mask=attn_mask) + attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options) # compute activation in mlp stream, cat again and run second linear layer output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2)) x += apply_mod(output, mod.gate, None, modulation_dims) diff --git a/comfy/ldm/flux/math.py b/comfy/ldm/flux/math.py index 3e0978176..4d743cda2 100644 --- a/comfy/ldm/flux/math.py +++ b/comfy/ldm/flux/math.py @@ -6,7 +6,7 @@ from comfy.ldm.modules.attention import optimized_attention import comfy.model_management -def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None) -> Tensor: +def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None, transformer_options={}) -> Tensor: q_shape = q.shape k_shape = k.shape @@ -17,7 +17,7 @@ def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None) -> Tensor: k = (pe[..., 0] * k[..., 0] + pe[..., 1] * k[..., 1]).reshape(*k_shape).type_as(v) heads = q.shape[1] - x = optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask) + x = optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask, transformer_options=transformer_options) return x diff --git a/comfy/ldm/flux/model.py b/comfy/ldm/flux/model.py index 8ea7d4f57..14f90cea5 100644 --- a/comfy/ldm/flux/model.py +++ b/comfy/ldm/flux/model.py @@ -144,14 +144,16 @@ class Flux(nn.Module): txt=args["txt"], vec=args["vec"], pe=args["pe"], - attn_mask=args.get("attn_mask")) + attn_mask=args.get("attn_mask"), + transformer_options=args.get("transformer_options")) return out out = blocks_replace[("double_block", i)]({"img": img, "txt": txt, "vec": vec, "pe": pe, - "attn_mask": attn_mask}, + "attn_mask": attn_mask, + "transformer_options": transformer_options}, {"original_block": block_wrap}) txt = out["txt"] img = out["img"] @@ -160,7 +162,8 @@ class Flux(nn.Module): 
txt=txt, vec=vec, pe=pe, - attn_mask=attn_mask) + attn_mask=attn_mask, + transformer_options=transformer_options) if control is not None: # Controlnet control_i = control.get("input") @@ -181,17 +184,19 @@ class Flux(nn.Module): out["img"] = block(args["img"], vec=args["vec"], pe=args["pe"], - attn_mask=args.get("attn_mask")) + attn_mask=args.get("attn_mask"), + transformer_options=args.get("transformer_options")) return out out = blocks_replace[("single_block", i)]({"img": img, "vec": vec, "pe": pe, - "attn_mask": attn_mask}, + "attn_mask": attn_mask, + "transformer_options": transformer_options}, {"original_block": block_wrap}) img = out["img"] else: - img = block(img, vec=vec, pe=pe, attn_mask=attn_mask) + img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, transformer_options=transformer_options) if control is not None: # Controlnet control_o = control.get("output") diff --git a/comfy/ldm/genmo/joint_model/asymm_models_joint.py b/comfy/ldm/genmo/joint_model/asymm_models_joint.py index 366a8b713..5c1bb4d42 100644 --- a/comfy/ldm/genmo/joint_model/asymm_models_joint.py +++ b/comfy/ldm/genmo/joint_model/asymm_models_joint.py @@ -109,6 +109,7 @@ class AsymmetricAttention(nn.Module): scale_x: torch.Tensor, # (B, dim_x), modulation for pre-RMSNorm. scale_y: torch.Tensor, # (B, dim_y), modulation for pre-RMSNorm. crop_y, + transformer_options={}, **rope_rotation, ) -> Tuple[torch.Tensor, torch.Tensor]: rope_cos = rope_rotation.get("rope_cos") @@ -143,7 +144,7 @@ class AsymmetricAttention(nn.Module): xy = optimized_attention(q, k, - v, self.num_heads, skip_reshape=True) + v, self.num_heads, skip_reshape=True, transformer_options=transformer_options) x, y = torch.tensor_split(xy, (q_x.shape[1],), dim=1) x = self.proj_x(x) @@ -224,6 +225,7 @@ class AsymmetricJointBlock(nn.Module): x: torch.Tensor, c: torch.Tensor, y: torch.Tensor, + transformer_options={}, **attn_kwargs, ): """Forward pass of a block. @@ -256,6 +258,7 @@ class AsymmetricJointBlock(nn.Module): y, scale_x=scale_msa_x, scale_y=scale_msa_y, + transformer_options=transformer_options, **attn_kwargs, ) @@ -524,10 +527,11 @@ class AsymmDiTJoint(nn.Module): args["txt"], rope_cos=args["rope_cos"], rope_sin=args["rope_sin"], - crop_y=args["num_tokens"] + crop_y=args["num_tokens"], + transformer_options=args["transformer_options"] ) return out - out = blocks_replace[("double_block", i)]({"img": x, "txt": y_feat, "vec": c, "rope_cos": rope_cos, "rope_sin": rope_sin, "num_tokens": num_tokens}, {"original_block": block_wrap}) + out = blocks_replace[("double_block", i)]({"img": x, "txt": y_feat, "vec": c, "rope_cos": rope_cos, "rope_sin": rope_sin, "num_tokens": num_tokens, "transformer_options": transformer_options}, {"original_block": block_wrap}) y_feat = out["txt"] x = out["img"] else: @@ -538,6 +542,7 @@ class AsymmDiTJoint(nn.Module): rope_cos=rope_cos, rope_sin=rope_sin, crop_y=num_tokens, + transformer_options=transformer_options, ) # (B, M, D), (B, L, D) del y_feat # Final layers don't use dense text features. 
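In the blocks_replace path shown just above and in the Flux, Chroma and audio hunks earlier, the args dict handed to a registered replacement now carries a "transformer_options" key, and the block_wrap closure forwards it to the original block. A hedged sketch of a replacement patch written against that calling convention; the patch body and its passthrough_override helper are illustrative, only the args/extra shape comes from these diffs.

def example_double_block_patch(args, extra):
    # args mirrors what the model's forward builds, e.g.
    # {"img": ..., "txt": ..., "vec": ..., "pe": ..., "transformer_options": {...}}
    options = dict(args.get("transformer_options") or {})

    def passthrough_override(default_attention, *attn_args, **attn_kwargs):
        # wrap_attn invokes the override as override(func, *args, **kwargs);
        # calling func back unchanged keeps the normal attention behaviour.
        return default_attention(*attn_args, **attn_kwargs)

    # A patch can now inspect or extend the options before running the block,
    # for example by attaching an attention override for just this block.
    options["optimized_attention_override"] = passthrough_override
    patched_args = {**args, "transformer_options": options}

    # "original_block" is the block_wrap closure defined inside the model's forward.
    return extra["original_block"](patched_args)

Such a patch sits under a ("double_block", i) key in the patches_replace mapping, which is what the loops in these hunks check before falling back to calling the block directly.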
diff --git a/comfy/ldm/hidream/model.py b/comfy/ldm/hidream/model.py index ae49cf945..28d81c79e 100644 --- a/comfy/ldm/hidream/model.py +++ b/comfy/ldm/hidream/model.py @@ -72,8 +72,8 @@ class TimestepEmbed(nn.Module): return t_emb -def attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor): - return optimized_attention(query.view(query.shape[0], -1, query.shape[-1] * query.shape[-2]), key.view(key.shape[0], -1, key.shape[-1] * key.shape[-2]), value.view(value.shape[0], -1, value.shape[-1] * value.shape[-2]), query.shape[2]) +def attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, transformer_options={}): + return optimized_attention(query.view(query.shape[0], -1, query.shape[-1] * query.shape[-2]), key.view(key.shape[0], -1, key.shape[-1] * key.shape[-2]), value.view(value.shape[0], -1, value.shape[-1] * value.shape[-2]), query.shape[2], transformer_options=transformer_options) class HiDreamAttnProcessor_flashattn: @@ -86,6 +86,7 @@ class HiDreamAttnProcessor_flashattn: image_tokens_masks: Optional[torch.FloatTensor] = None, text_tokens: Optional[torch.FloatTensor] = None, rope: torch.FloatTensor = None, + transformer_options={}, *args, **kwargs, ) -> torch.FloatTensor: @@ -133,7 +134,7 @@ class HiDreamAttnProcessor_flashattn: query = torch.cat([query_1, query_2], dim=-1) key = torch.cat([key_1, key_2], dim=-1) - hidden_states = attention(query, key, value) + hidden_states = attention(query, key, value, transformer_options=transformer_options) if not attn.single: hidden_states_i, hidden_states_t = torch.split(hidden_states, [num_image_tokens, num_text_tokens], dim=1) @@ -199,6 +200,7 @@ class HiDreamAttention(nn.Module): image_tokens_masks: torch.FloatTensor = None, norm_text_tokens: torch.FloatTensor = None, rope: torch.FloatTensor = None, + transformer_options={}, ) -> torch.Tensor: return self.processor( self, @@ -206,6 +208,7 @@ class HiDreamAttention(nn.Module): image_tokens_masks = image_tokens_masks, text_tokens = norm_text_tokens, rope = rope, + transformer_options=transformer_options, ) @@ -406,7 +409,7 @@ class HiDreamImageSingleTransformerBlock(nn.Module): text_tokens: Optional[torch.FloatTensor] = None, adaln_input: Optional[torch.FloatTensor] = None, rope: torch.FloatTensor = None, - + transformer_options={}, ) -> torch.FloatTensor: wtype = image_tokens.dtype shift_msa_i, scale_msa_i, gate_msa_i, shift_mlp_i, scale_mlp_i, gate_mlp_i = \ @@ -419,6 +422,7 @@ class HiDreamImageSingleTransformerBlock(nn.Module): norm_image_tokens, image_tokens_masks, rope = rope, + transformer_options=transformer_options, ) image_tokens = gate_msa_i * attn_output_i + image_tokens @@ -483,6 +487,7 @@ class HiDreamImageTransformerBlock(nn.Module): text_tokens: Optional[torch.FloatTensor] = None, adaln_input: Optional[torch.FloatTensor] = None, rope: torch.FloatTensor = None, + transformer_options={}, ) -> torch.FloatTensor: wtype = image_tokens.dtype shift_msa_i, scale_msa_i, gate_msa_i, shift_mlp_i, scale_mlp_i, gate_mlp_i, \ @@ -500,6 +505,7 @@ class HiDreamImageTransformerBlock(nn.Module): image_tokens_masks, norm_text_tokens, rope = rope, + transformer_options=transformer_options, ) image_tokens = gate_msa_i * attn_output_i + image_tokens @@ -550,6 +556,7 @@ class HiDreamImageBlock(nn.Module): text_tokens: Optional[torch.FloatTensor] = None, adaln_input: torch.FloatTensor = None, rope: torch.FloatTensor = None, + transformer_options={}, ) -> torch.FloatTensor: return self.block( image_tokens, @@ -557,6 +564,7 @@ class HiDreamImageBlock(nn.Module): text_tokens, 
adaln_input, rope, + transformer_options=transformer_options, ) @@ -786,6 +794,7 @@ class HiDreamImageTransformer2DModel(nn.Module): text_tokens = cur_encoder_hidden_states, adaln_input = adaln_input, rope = rope, + transformer_options=transformer_options, ) initial_encoder_hidden_states = initial_encoder_hidden_states[:, :initial_encoder_hidden_states_seq_len] block_id += 1 @@ -809,6 +818,7 @@ class HiDreamImageTransformer2DModel(nn.Module): text_tokens=None, adaln_input=adaln_input, rope=rope, + transformer_options=transformer_options, ) hidden_states = hidden_states[:, :hidden_states_seq_len] block_id += 1 diff --git a/comfy/ldm/hunyuan3d/model.py b/comfy/ldm/hunyuan3d/model.py index 0fa5e78c1..4991b1645 100644 --- a/comfy/ldm/hunyuan3d/model.py +++ b/comfy/ldm/hunyuan3d/model.py @@ -99,14 +99,16 @@ class Hunyuan3Dv2(nn.Module): txt=args["txt"], vec=args["vec"], pe=args["pe"], - attn_mask=args.get("attn_mask")) + attn_mask=args.get("attn_mask"), + transformer_options=args["transformer_options"]) return out out = blocks_replace[("double_block", i)]({"img": img, "txt": txt, "vec": vec, "pe": pe, - "attn_mask": attn_mask}, + "attn_mask": attn_mask, + "transformer_options": transformer_options}, {"original_block": block_wrap}) txt = out["txt"] img = out["img"] @@ -115,7 +117,8 @@ class Hunyuan3Dv2(nn.Module): txt=txt, vec=vec, pe=pe, - attn_mask=attn_mask) + attn_mask=attn_mask, + transformer_options=transformer_options) img = torch.cat((txt, img), 1) @@ -126,17 +129,19 @@ class Hunyuan3Dv2(nn.Module): out["img"] = block(args["img"], vec=args["vec"], pe=args["pe"], - attn_mask=args.get("attn_mask")) + attn_mask=args.get("attn_mask"), + transformer_options=args["transformer_options"]) return out out = blocks_replace[("single_block", i)]({"img": img, "vec": vec, "pe": pe, - "attn_mask": attn_mask}, + "attn_mask": attn_mask, + "transformer_options": transformer_options}, {"original_block": block_wrap}) img = out["img"] else: - img = block(img, vec=vec, pe=pe, attn_mask=attn_mask) + img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, transformer_options=transformer_options) img = img[:, txt.shape[1]:, ...] 
img = self.final_layer(img, vec) diff --git a/comfy/ldm/hunyuan_video/model.py b/comfy/ldm/hunyuan_video/model.py index ca86b8bb1..5132e6c07 100644 --- a/comfy/ldm/hunyuan_video/model.py +++ b/comfy/ldm/hunyuan_video/model.py @@ -80,13 +80,13 @@ class TokenRefinerBlock(nn.Module): operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device), ) - def forward(self, x, c, mask): + def forward(self, x, c, mask, transformer_options={}): mod1, mod2 = self.adaLN_modulation(c).chunk(2, dim=1) norm_x = self.norm1(x) qkv = self.self_attn.qkv(norm_x) q, k, v = qkv.reshape(qkv.shape[0], qkv.shape[1], 3, self.heads, -1).permute(2, 0, 3, 1, 4) - attn = optimized_attention(q, k, v, self.heads, mask=mask, skip_reshape=True) + attn = optimized_attention(q, k, v, self.heads, mask=mask, skip_reshape=True, transformer_options=transformer_options) x = x + self.self_attn.proj(attn) * mod1.unsqueeze(1) x = x + self.mlp(self.norm2(x)) * mod2.unsqueeze(1) @@ -117,14 +117,14 @@ class IndividualTokenRefiner(nn.Module): ] ) - def forward(self, x, c, mask): + def forward(self, x, c, mask, transformer_options={}): m = None if mask is not None: m = mask.view(mask.shape[0], 1, 1, mask.shape[1]).repeat(1, 1, mask.shape[1], 1) m = m + m.transpose(2, 3) for block in self.blocks: - x = block(x, c, m) + x = block(x, c, m, transformer_options=transformer_options) return x @@ -152,6 +152,7 @@ class TokenRefiner(nn.Module): x, timesteps, mask, + transformer_options={}, ): t = self.t_embedder(timestep_embedding(timesteps, 256, time_factor=1.0).to(x.dtype)) # m = mask.float().unsqueeze(-1) @@ -160,7 +161,7 @@ class TokenRefiner(nn.Module): c = t + self.c_embedder(c.to(x.dtype)) x = self.input_embedder(x) - x = self.individual_token_refiner(x, c, mask) + x = self.individual_token_refiner(x, c, mask, transformer_options=transformer_options) return x @@ -328,7 +329,7 @@ class HunyuanVideo(nn.Module): if txt_mask is not None and not torch.is_floating_point(txt_mask): txt_mask = (txt_mask - 1).to(img.dtype) * torch.finfo(img.dtype).max - txt = self.txt_in(txt, timesteps, txt_mask) + txt = self.txt_in(txt, timesteps, txt_mask, transformer_options=transformer_options) if self.byt5_in is not None and txt_byt5 is not None: txt_byt5 = self.byt5_in(txt_byt5) @@ -352,14 +353,14 @@ class HunyuanVideo(nn.Module): if ("double_block", i) in blocks_replace: def block_wrap(args): out = {} - out["img"], out["txt"] = block(img=args["img"], txt=args["txt"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"], modulation_dims_img=args["modulation_dims_img"], modulation_dims_txt=args["modulation_dims_txt"]) + out["img"], out["txt"] = block(img=args["img"], txt=args["txt"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"], modulation_dims_img=args["modulation_dims_img"], modulation_dims_txt=args["modulation_dims_txt"], transformer_options=args["transformer_options"]) return out - out = blocks_replace[("double_block", i)]({"img": img, "txt": txt, "vec": vec, "pe": pe, "attention_mask": attn_mask, 'modulation_dims_img': modulation_dims, 'modulation_dims_txt': modulation_dims_txt}, {"original_block": block_wrap}) + out = blocks_replace[("double_block", i)]({"img": img, "txt": txt, "vec": vec, "pe": pe, "attention_mask": attn_mask, 'modulation_dims_img': modulation_dims, 'modulation_dims_txt': modulation_dims_txt, 'transformer_options': transformer_options}, {"original_block": block_wrap}) txt = out["txt"] img = out["img"] else: - img, txt = block(img=img, txt=txt, vec=vec, pe=pe, attn_mask=attn_mask, 
modulation_dims_img=modulation_dims, modulation_dims_txt=modulation_dims_txt) + img, txt = block(img=img, txt=txt, vec=vec, pe=pe, attn_mask=attn_mask, modulation_dims_img=modulation_dims, modulation_dims_txt=modulation_dims_txt, transformer_options=transformer_options) if control is not None: # Controlnet control_i = control.get("input") @@ -374,13 +375,13 @@ class HunyuanVideo(nn.Module): if ("single_block", i) in blocks_replace: def block_wrap(args): out = {} - out["img"] = block(args["img"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"], modulation_dims=args["modulation_dims"]) + out["img"] = block(args["img"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"], modulation_dims=args["modulation_dims"], transformer_options=args["transformer_options"]) return out - out = blocks_replace[("single_block", i)]({"img": img, "vec": vec, "pe": pe, "attention_mask": attn_mask, 'modulation_dims': modulation_dims}, {"original_block": block_wrap}) + out = blocks_replace[("single_block", i)]({"img": img, "vec": vec, "pe": pe, "attention_mask": attn_mask, 'modulation_dims': modulation_dims, 'transformer_options': transformer_options}, {"original_block": block_wrap}) img = out["img"] else: - img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, modulation_dims=modulation_dims) + img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, modulation_dims=modulation_dims, transformer_options=transformer_options) if control is not None: # Controlnet control_o = control.get("output") diff --git a/comfy/ldm/lightricks/model.py b/comfy/ldm/lightricks/model.py index aa2ea62b1..def365ba7 100644 --- a/comfy/ldm/lightricks/model.py +++ b/comfy/ldm/lightricks/model.py @@ -271,7 +271,7 @@ class CrossAttention(nn.Module): self.to_out = nn.Sequential(operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout)) - def forward(self, x, context=None, mask=None, pe=None): + def forward(self, x, context=None, mask=None, pe=None, transformer_options={}): q = self.to_q(x) context = x if context is None else context k = self.to_k(context) @@ -285,9 +285,9 @@ class CrossAttention(nn.Module): k = apply_rotary_emb(k, pe) if mask is None: - out = comfy.ldm.modules.attention.optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision) + out = comfy.ldm.modules.attention.optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision, transformer_options=transformer_options) else: - out = comfy.ldm.modules.attention.optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision) + out = comfy.ldm.modules.attention.optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision, transformer_options=transformer_options) return self.to_out(out) @@ -303,12 +303,12 @@ class BasicTransformerBlock(nn.Module): self.scale_shift_table = nn.Parameter(torch.empty(6, dim, device=device, dtype=dtype)) - def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None): + def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None, transformer_options={}): shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)).unbind(dim=2) - x += self.attn1(comfy.ldm.common_dit.rms_norm(x) * (1 + scale_msa) + shift_msa, pe=pe) * gate_msa + x += self.attn1(comfy.ldm.common_dit.rms_norm(x) * (1 + scale_msa) + shift_msa, pe=pe, 
transformer_options=transformer_options) * gate_msa - x += self.attn2(x, context=context, mask=attention_mask) + x += self.attn2(x, context=context, mask=attention_mask, transformer_options=transformer_options) y = comfy.ldm.common_dit.rms_norm(x) * (1 + scale_mlp) + shift_mlp x += self.ff(y) * gate_mlp @@ -479,10 +479,10 @@ class LTXVModel(torch.nn.Module): if ("double_block", i) in blocks_replace: def block_wrap(args): out = {} - out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"]) + out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"], transformer_options=args["transformer_options"]) return out - out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "attention_mask": attention_mask, "vec": timestep, "pe": pe}, {"original_block": block_wrap}) + out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "attention_mask": attention_mask, "vec": timestep, "pe": pe, "transformer_options": transformer_options}, {"original_block": block_wrap}) x = out["img"] else: x = block( @@ -490,7 +490,8 @@ class LTXVModel(torch.nn.Module): context=context, attention_mask=attention_mask, timestep=timestep, - pe=pe + pe=pe, + transformer_options=transformer_options, ) # 3. Output diff --git a/comfy/ldm/lumina/model.py b/comfy/ldm/lumina/model.py index e08ed817d..f87d98ac0 100644 --- a/comfy/ldm/lumina/model.py +++ b/comfy/ldm/lumina/model.py @@ -104,6 +104,7 @@ class JointAttention(nn.Module): x: torch.Tensor, x_mask: torch.Tensor, freqs_cis: torch.Tensor, + transformer_options={}, ) -> torch.Tensor: """ @@ -140,7 +141,7 @@ class JointAttention(nn.Module): if n_rep >= 1: xk = xk.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3) xv = xv.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3) - output = optimized_attention_masked(xq.movedim(1, 2), xk.movedim(1, 2), xv.movedim(1, 2), self.n_local_heads, x_mask, skip_reshape=True) + output = optimized_attention_masked(xq.movedim(1, 2), xk.movedim(1, 2), xv.movedim(1, 2), self.n_local_heads, x_mask, skip_reshape=True, transformer_options=transformer_options) return self.out(output) @@ -268,6 +269,7 @@ class JointTransformerBlock(nn.Module): x_mask: torch.Tensor, freqs_cis: torch.Tensor, adaln_input: Optional[torch.Tensor]=None, + transformer_options={}, ): """ Perform a forward pass through the TransformerBlock. 
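The comfy/ldm/modules/attention.py changes a little further down add the machinery the bullet points at the top of this patch describe: a registry populated via register_attention_function ("pytorch", "sub_quad" and "split" always; "sage", "flash" and "xformers" when importable or enabled), get_attention_function for lookup (with "optimized" returning the default optimized_attention), and a wrap_attn decorator that lets transformer_options["optimized_attention_override"] intercept any wrapped backend. A hedged usage sketch, assuming a ComfyUI environment where this module imports; the force_backend_override helper and the commented wiring are illustrative, not part of the patch.

from comfy.ldm.modules.attention import get_attention_function


def force_backend_override(backend_name):
    # Look up a registered backend; passing default=None avoids the KeyError that
    # get_attention_function raises for names that were never registered.
    backend = get_attention_function(backend_name, default=None)

    def override(default_attention, *args, **kwargs):
        # wrap_attn calls the override as override(func, *args, **kwargs), with
        # "_inside_attn_wrapper" already set in kwargs so the call below does not
        # re-enter the override.
        if backend is None:
            return default_attention(*args, **kwargs)
        return backend(*args, **kwargs)

    return override


# Illustrative wiring only: attach the override wherever transformer_options is
# assembled, e.g. on a cloned ModelPatcher's model_options["transformer_options"].
# m = model.clone()
# m.model_options["transformer_options"]["optimized_attention_override"] = force_backend_override("sage")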
@@ -290,6 +292,7 @@ class JointTransformerBlock(nn.Module): modulate(self.attention_norm1(x), scale_msa), x_mask, freqs_cis, + transformer_options=transformer_options, ) ) x = x + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2( @@ -304,6 +307,7 @@ class JointTransformerBlock(nn.Module): self.attention_norm1(x), x_mask, freqs_cis, + transformer_options=transformer_options, ) ) x = x + self.ffn_norm2( @@ -494,7 +498,7 @@ class NextDiT(nn.Module): return imgs def patchify_and_embed( - self, x: List[torch.Tensor] | torch.Tensor, cap_feats: torch.Tensor, cap_mask: torch.Tensor, t: torch.Tensor, num_tokens + self, x: List[torch.Tensor] | torch.Tensor, cap_feats: torch.Tensor, cap_mask: torch.Tensor, t: torch.Tensor, num_tokens, transformer_options={} ) -> Tuple[torch.Tensor, torch.Tensor, List[Tuple[int, int]], List[int], torch.Tensor]: bsz = len(x) pH = pW = self.patch_size @@ -554,7 +558,7 @@ class NextDiT(nn.Module): # refine context for layer in self.context_refiner: - cap_feats = layer(cap_feats, cap_mask, cap_freqs_cis) + cap_feats = layer(cap_feats, cap_mask, cap_freqs_cis, transformer_options=transformer_options) # refine image flat_x = [] @@ -573,7 +577,7 @@ class NextDiT(nn.Module): padded_img_embed = self.x_embedder(padded_img_embed) padded_img_mask = padded_img_mask.unsqueeze(1) for layer in self.noise_refiner: - padded_img_embed = layer(padded_img_embed, padded_img_mask, img_freqs_cis, t) + padded_img_embed = layer(padded_img_embed, padded_img_mask, img_freqs_cis, t, transformer_options=transformer_options) if cap_mask is not None: mask = torch.zeros(bsz, max_seq_len, dtype=dtype, device=device) @@ -616,12 +620,13 @@ class NextDiT(nn.Module): cap_feats = self.cap_embedder(cap_feats) # (N, L, D) # todo check if able to batchify w.o. redundant compute + transformer_options = kwargs.get("transformer_options", {}) x_is_tensor = isinstance(x, torch.Tensor) - x, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, t, num_tokens) + x, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, t, num_tokens, transformer_options=transformer_options) freqs_cis = freqs_cis.to(x.device) for layer in self.layers: - x = layer(x, mask, freqs_cis, adaln_input) + x = layer(x, mask, freqs_cis, adaln_input, transformer_options=transformer_options) x = self.final_layer(x, adaln_input) x = self.unpatchify(x, img_size, cap_size, return_tensor=x_is_tensor)[:,:,:h,:w] diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index 043df28df..bf2553c37 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -5,8 +5,9 @@ import torch import torch.nn.functional as F from torch import nn, einsum from einops import rearrange, repeat -from typing import Optional +from typing import Optional, Any, Callable, Union import logging +import functools from .diffusionmodules.util import AlphaBlender, timestep_embedding from .sub_quadratic_attention import efficient_dot_product_attention @@ -17,23 +18,45 @@ if model_management.xformers_enabled(): import xformers import xformers.ops -if model_management.sage_attention_enabled(): - try: - from sageattention import sageattn - except ModuleNotFoundError as e: +SAGE_ATTENTION_IS_AVAILABLE = False +try: + from sageattention import sageattn + SAGE_ATTENTION_IS_AVAILABLE = True +except ModuleNotFoundError as e: + if model_management.sage_attention_enabled(): if e.name == "sageattention": logging.error(f"\n\nTo use the `--use-sage-attention` feature, the `sageattention` 
package must be installed first.\ncommand:\n\t{sys.executable} -m pip install sageattention") else: raise e exit(-1) -if model_management.flash_attention_enabled(): - try: - from flash_attn import flash_attn_func - except ModuleNotFoundError: +FLASH_ATTENTION_IS_AVAILABLE = False +try: + from flash_attn import flash_attn_func + FLASH_ATTENTION_IS_AVAILABLE = True +except ModuleNotFoundError: + if model_management.flash_attention_enabled(): logging.error(f"\n\nTo use the `--use-flash-attention` feature, the `flash-attn` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install flash-attn") exit(-1) +REGISTERED_ATTENTION_FUNCTIONS = {} +def register_attention_function(name: str, func: Callable): + # avoid replacing existing functions + if name not in REGISTERED_ATTENTION_FUNCTIONS: + REGISTERED_ATTENTION_FUNCTIONS[name] = func + else: + logging.warning(f"Attention function {name} already registered, skipping registration.") + +def get_attention_function(name: str, default: Any=...) -> Union[Callable, None]: + if name == "optimized": + return optimized_attention + elif name not in REGISTERED_ATTENTION_FUNCTIONS: + if default is ...: + raise KeyError(f"Attention function {name} not found.") + else: + return default + return REGISTERED_ATTENTION_FUNCTIONS[name] + from comfy.cli_args import args import comfy.ops ops = comfy.ops.disable_weight_init @@ -91,7 +114,27 @@ class FeedForward(nn.Module): def Normalize(in_channels, dtype=None, device=None): return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device) -def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False): + +def wrap_attn(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + remove_attn_wrapper_key = False + try: + if "_inside_attn_wrapper" not in kwargs: + transformer_options = kwargs.get("transformer_options", None) + remove_attn_wrapper_key = True + kwargs["_inside_attn_wrapper"] = True + if transformer_options is not None: + if "optimized_attention_override" in transformer_options: + return transformer_options["optimized_attention_override"](func, *args, **kwargs) + return func(*args, **kwargs) + finally: + if remove_attn_wrapper_key: + del kwargs["_inside_attn_wrapper"] + return wrapper + +@wrap_attn +def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs): attn_precision = get_attn_precision(attn_precision, q.dtype) if skip_reshape: @@ -159,8 +202,8 @@ def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape ) return out - -def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False): +@wrap_attn +def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs): attn_precision = get_attn_precision(attn_precision, query.dtype) if skip_reshape: @@ -230,7 +273,8 @@ def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None, hidden_states = hidden_states.unflatten(0, (-1, heads)).transpose(1,2).flatten(start_dim=2) return hidden_states -def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False): +@wrap_attn +def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs): attn_precision = get_attn_precision(attn_precision, 
q.dtype) if skip_reshape: @@ -359,7 +403,8 @@ try: except: pass -def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False): +@wrap_attn +def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs): b = q.shape[0] dim_head = q.shape[-1] # check to make sure xformers isn't broken @@ -374,7 +419,7 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh disabled_xformers = True if disabled_xformers: - return attention_pytorch(q, k, v, heads, mask, skip_reshape=skip_reshape) + return attention_pytorch(q, k, v, heads, mask, skip_reshape=skip_reshape, **kwargs) if skip_reshape: # b h k d -> b k h d @@ -427,8 +472,8 @@ else: #TODO: other GPUs ? SDP_BATCH_LIMIT = 2**31 - -def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False): +@wrap_attn +def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs): if skip_reshape: b, _, _, dim_head = q.shape else: @@ -470,8 +515,8 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha ).transpose(1, 2).reshape(-1, q.shape[2], heads * dim_head) return out - -def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False): +@wrap_attn +def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs): if skip_reshape: b, _, _, dim_head = q.shape tensor_layout = "HND" @@ -501,7 +546,7 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape= lambda t: t.transpose(1, 2), (q, k, v), ) - return attention_pytorch(q, k, v, heads, mask=mask, skip_reshape=True, skip_output_reshape=skip_output_reshape) + return attention_pytorch(q, k, v, heads, mask=mask, skip_reshape=True, skip_output_reshape=skip_output_reshape, **kwargs) if tensor_layout == "HND": if not skip_output_reshape: @@ -534,8 +579,8 @@ except AttributeError as error: dropout_p: float = 0.0, causal: bool = False) -> torch.Tensor: assert False, f"Could not define flash_attn_wrapper: {FLASH_ATTN_ERROR}" - -def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False): +@wrap_attn +def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs): if skip_reshape: b, _, _, dim_head = q.shape else: @@ -597,6 +642,19 @@ else: optimized_attention_masked = optimized_attention + +# register core-supported attention functions +if SAGE_ATTENTION_IS_AVAILABLE: + register_attention_function("sage", attention_sage) +if FLASH_ATTENTION_IS_AVAILABLE: + register_attention_function("flash", attention_flash) +if model_management.xformers_enabled(): + register_attention_function("xformers", attention_xformers) +register_attention_function("pytorch", attention_pytorch) +register_attention_function("sub_quad", attention_sub_quad) +register_attention_function("split", attention_split) + + def optimized_attention_for_device(device, mask=False, small_input=False): if small_input: if model_management.pytorch_attention_enabled(): @@ -629,7 +687,7 @@ class CrossAttention(nn.Module): self.to_out = nn.Sequential(operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout)) - def forward(self, x, context=None, value=None, mask=None): + def forward(self, x, 
context=None, value=None, mask=None, transformer_options={}): q = self.to_q(x) context = default(context, x) k = self.to_k(context) @@ -640,9 +698,9 @@ class CrossAttention(nn.Module): v = self.to_v(context) if mask is None: - out = optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision) + out = optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision, transformer_options=transformer_options) else: - out = optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision) + out = optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision, transformer_options=transformer_options) return self.to_out(out) @@ -746,7 +804,7 @@ class BasicTransformerBlock(nn.Module): n = attn1_replace_patch[block_attn1](n, context_attn1, value_attn1, extra_options) n = self.attn1.to_out(n) else: - n = self.attn1(n, context=context_attn1, value=value_attn1) + n = self.attn1(n, context=context_attn1, value=value_attn1, transformer_options=transformer_options) if "attn1_output_patch" in transformer_patches: patch = transformer_patches["attn1_output_patch"] @@ -786,7 +844,7 @@ class BasicTransformerBlock(nn.Module): n = attn2_replace_patch[block_attn2](n, context_attn2, value_attn2, extra_options) n = self.attn2.to_out(n) else: - n = self.attn2(n, context=context_attn2, value=value_attn2) + n = self.attn2(n, context=context_attn2, value=value_attn2, transformer_options=transformer_options) if "attn2_output_patch" in transformer_patches: patch = transformer_patches["attn2_output_patch"] @@ -1017,7 +1075,7 @@ class SpatialVideoTransformer(SpatialTransformer): B, S, C = x_mix.shape x_mix = rearrange(x_mix, "(b t) s c -> (b s) t c", t=timesteps) - x_mix = mix_block(x_mix, context=time_context) #TODO: transformer_options + x_mix = mix_block(x_mix, context=time_context, transformer_options=transformer_options) x_mix = rearrange( x_mix, "(b s) t c -> (b t) s c", s=S, b=B // timesteps, c=C, t=timesteps ) diff --git a/comfy/ldm/modules/diffusionmodules/mmdit.py b/comfy/ldm/modules/diffusionmodules/mmdit.py index 4d6beba2d..42f406f1a 100644 --- a/comfy/ldm/modules/diffusionmodules/mmdit.py +++ b/comfy/ldm/modules/diffusionmodules/mmdit.py @@ -606,7 +606,7 @@ def block_mixing(*args, use_checkpoint=True, **kwargs): return _block_mixing(*args, **kwargs) -def _block_mixing(context, x, context_block, x_block, c): +def _block_mixing(context, x, context_block, x_block, c, transformer_options={}): context_qkv, context_intermediates = context_block.pre_attention(context, c) if x_block.x_block_self_attn: @@ -622,6 +622,7 @@ def _block_mixing(context, x, context_block, x_block, c): attn = optimized_attention( qkv[0], qkv[1], qkv[2], heads=x_block.attn.num_heads, + transformer_options=transformer_options, ) context_attn, x_attn = ( attn[:, : context_qkv[0].shape[1]], @@ -637,6 +638,7 @@ def _block_mixing(context, x, context_block, x_block, c): attn2 = optimized_attention( x_qkv2[0], x_qkv2[1], x_qkv2[2], heads=x_block.attn2.num_heads, + transformer_options=transformer_options, ) x = x_block.post_attention_x(x_attn, attn2, *x_intermediates) else: @@ -958,10 +960,10 @@ class MMDiT(nn.Module): if ("double_block", i) in blocks_replace: def block_wrap(args): out = {} - out["txt"], out["img"] = self.joint_blocks[i](args["txt"], args["img"], c=args["vec"]) + out["txt"], out["img"] = self.joint_blocks[i](args["txt"], args["img"], c=args["vec"], transformer_options=args["transformer_options"]) return out - out = blocks_replace[("double_block", i)]({"img": 
x, "txt": context, "vec": c_mod}, {"original_block": block_wrap}) + out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": c_mod, "transformer_options": transformer_options}, {"original_block": block_wrap}) context = out["txt"] x = out["img"] else: @@ -970,6 +972,7 @@ class MMDiT(nn.Module): x, c=c_mod, use_checkpoint=self.use_checkpoint, + transformer_options=transformer_options, ) if control is not None: control_o = control.get("output") diff --git a/comfy/ldm/omnigen/omnigen2.py b/comfy/ldm/omnigen/omnigen2.py index 4884449f8..82edc92da 100644 --- a/comfy/ldm/omnigen/omnigen2.py +++ b/comfy/ldm/omnigen/omnigen2.py @@ -120,7 +120,7 @@ class Attention(nn.Module): nn.Dropout(0.0) ) - def forward(self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, image_rotary_emb: Optional[torch.Tensor] = None) -> torch.Tensor: + def forward(self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, image_rotary_emb: Optional[torch.Tensor] = None, transformer_options={}) -> torch.Tensor: batch_size, sequence_length, _ = hidden_states.shape query = self.to_q(hidden_states) @@ -146,7 +146,7 @@ class Attention(nn.Module): key = key.repeat_interleave(self.heads // self.kv_heads, dim=1) value = value.repeat_interleave(self.heads // self.kv_heads, dim=1) - hidden_states = optimized_attention_masked(query, key, value, self.heads, attention_mask, skip_reshape=True) + hidden_states = optimized_attention_masked(query, key, value, self.heads, attention_mask, skip_reshape=True, transformer_options=transformer_options) hidden_states = self.to_out[0](hidden_states) return hidden_states @@ -182,16 +182,16 @@ class OmniGen2TransformerBlock(nn.Module): self.norm2 = operations.RMSNorm(dim, eps=norm_eps, dtype=dtype, device=device) self.ffn_norm2 = operations.RMSNorm(dim, eps=norm_eps, dtype=dtype, device=device) - def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, image_rotary_emb: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor: + def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, image_rotary_emb: torch.Tensor, temb: Optional[torch.Tensor] = None, transformer_options={}) -> torch.Tensor: if self.modulation: norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1(hidden_states, temb) - attn_output = self.attn(norm_hidden_states, norm_hidden_states, attention_mask, image_rotary_emb) + attn_output = self.attn(norm_hidden_states, norm_hidden_states, attention_mask, image_rotary_emb, transformer_options=transformer_options) hidden_states = hidden_states + gate_msa.unsqueeze(1).tanh() * self.norm2(attn_output) mlp_output = self.feed_forward(self.ffn_norm1(hidden_states) * (1 + scale_mlp.unsqueeze(1))) hidden_states = hidden_states + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(mlp_output) else: norm_hidden_states = self.norm1(hidden_states) - attn_output = self.attn(norm_hidden_states, norm_hidden_states, attention_mask, image_rotary_emb) + attn_output = self.attn(norm_hidden_states, norm_hidden_states, attention_mask, image_rotary_emb, transformer_options=transformer_options) hidden_states = hidden_states + self.norm2(attn_output) mlp_output = self.feed_forward(self.ffn_norm1(hidden_states)) hidden_states = hidden_states + self.ffn_norm2(mlp_output) @@ -390,7 +390,7 @@ class OmniGen2Transformer2DModel(nn.Module): ref_img_sizes, img_sizes, ) - def img_patch_embed_and_refine(self, hidden_states, 
ref_image_hidden_states, padded_img_mask, padded_ref_img_mask, noise_rotary_emb, ref_img_rotary_emb, l_effective_ref_img_len, l_effective_img_len, temb): + def img_patch_embed_and_refine(self, hidden_states, ref_image_hidden_states, padded_img_mask, padded_ref_img_mask, noise_rotary_emb, ref_img_rotary_emb, l_effective_ref_img_len, l_effective_img_len, temb, transformer_options={}): batch_size = len(hidden_states) hidden_states = self.x_embedder(hidden_states) @@ -405,17 +405,17 @@ class OmniGen2Transformer2DModel(nn.Module): shift += ref_img_len for layer in self.noise_refiner: - hidden_states = layer(hidden_states, padded_img_mask, noise_rotary_emb, temb) + hidden_states = layer(hidden_states, padded_img_mask, noise_rotary_emb, temb, transformer_options=transformer_options) if ref_image_hidden_states is not None: for layer in self.ref_image_refiner: - ref_image_hidden_states = layer(ref_image_hidden_states, padded_ref_img_mask, ref_img_rotary_emb, temb) + ref_image_hidden_states = layer(ref_image_hidden_states, padded_ref_img_mask, ref_img_rotary_emb, temb, transformer_options=transformer_options) hidden_states = torch.cat([ref_image_hidden_states, hidden_states], dim=1) return hidden_states - def forward(self, x, timesteps, context, num_tokens, ref_latents=None, attention_mask=None, **kwargs): + def forward(self, x, timesteps, context, num_tokens, ref_latents=None, attention_mask=None, transformer_options={}, **kwargs): B, C, H, W = x.shape hidden_states = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size)) _, _, H_padded, W_padded = hidden_states.shape @@ -444,7 +444,7 @@ class OmniGen2Transformer2DModel(nn.Module): ) for layer in self.context_refiner: - text_hidden_states = layer(text_hidden_states, text_attention_mask, context_rotary_emb) + text_hidden_states = layer(text_hidden_states, text_attention_mask, context_rotary_emb, transformer_options=transformer_options) img_len = hidden_states.shape[1] combined_img_hidden_states = self.img_patch_embed_and_refine( @@ -453,13 +453,14 @@ class OmniGen2Transformer2DModel(nn.Module): noise_rotary_emb, ref_img_rotary_emb, l_effective_ref_img_len, l_effective_img_len, temb, + transformer_options=transformer_options, ) hidden_states = torch.cat([text_hidden_states, combined_img_hidden_states], dim=1) attention_mask = None for layer in self.layers: - hidden_states = layer(hidden_states, attention_mask, rotary_emb, temb) + hidden_states = layer(hidden_states, attention_mask, rotary_emb, temb, transformer_options=transformer_options) hidden_states = self.norm_out(hidden_states, temb) diff --git a/comfy/ldm/qwen_image/model.py b/comfy/ldm/qwen_image/model.py index 04071f31c..b9f60c2b7 100644 --- a/comfy/ldm/qwen_image/model.py +++ b/comfy/ldm/qwen_image/model.py @@ -132,6 +132,7 @@ class Attention(nn.Module): encoder_hidden_states_mask: torch.FloatTensor = None, attention_mask: Optional[torch.FloatTensor] = None, image_rotary_emb: Optional[torch.Tensor] = None, + transformer_options={}, ) -> Tuple[torch.Tensor, torch.Tensor]: seq_txt = encoder_hidden_states.shape[1] @@ -159,7 +160,7 @@ class Attention(nn.Module): joint_key = joint_key.flatten(start_dim=2) joint_value = joint_value.flatten(start_dim=2) - joint_hidden_states = optimized_attention_masked(joint_query, joint_key, joint_value, self.heads, attention_mask) + joint_hidden_states = optimized_attention_masked(joint_query, joint_key, joint_value, self.heads, attention_mask, transformer_options=transformer_options) txt_attn_output = joint_hidden_states[:, :seq_txt, :] 
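# --- Editor's aside (illustrative sketch, not part of this diff) --------------
# The hunks in this patch thread `transformer_options` from each model's forward
# pass down into optimized_attention*(). Together with wrap_attn,
# register_attention_function and get_attention_function added to
# comfy/ldm/modules/attention.py above, this lets a caller swap the attention
# backend per call. `route_to_pytorch` and the option dict below are
# hypothetical names used only for illustration.
from comfy.ldm.modules.attention import get_attention_function

def route_to_pytorch(wrapped_attn, *args, **kwargs):
    # wrap_attn invokes the override as override(func, *args, **kwargs), so the
    # first argument is the attention implementation that was about to run.
    alt = get_attention_function("pytorch", default=None)
    return (alt or wrapped_attn)(*args, **kwargs)

transformer_options = {"optimized_attention_override": route_to_pytorch}
# Any @wrap_attn-decorated attention function that receives these
# transformer_options (as the optimized_attention_masked call in this hunk now
# does) dispatches through route_to_pytorch instead of its own kernel. A custom
# backend could likewise be registered by name, e.g.
# register_attention_function("my_backend", my_attention_fn), and fetched later
# with get_attention_function("my_backend"); both names here are hypothetical.
# ------------------------------------------------------------------------------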
img_attn_output = joint_hidden_states[:, seq_txt:, :] @@ -226,6 +227,7 @@ class QwenImageTransformerBlock(nn.Module): encoder_hidden_states_mask: torch.Tensor, temb: torch.Tensor, image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + transformer_options={}, ) -> Tuple[torch.Tensor, torch.Tensor]: img_mod_params = self.img_mod(temb) txt_mod_params = self.txt_mod(temb) @@ -242,6 +244,7 @@ class QwenImageTransformerBlock(nn.Module): encoder_hidden_states=txt_modulated, encoder_hidden_states_mask=encoder_hidden_states_mask, image_rotary_emb=image_rotary_emb, + transformer_options=transformer_options, ) hidden_states = hidden_states + img_gate1 * img_attn_output @@ -434,9 +437,9 @@ class QwenImageTransformer2DModel(nn.Module): if ("double_block", i) in blocks_replace: def block_wrap(args): out = {} - out["txt"], out["img"] = block(hidden_states=args["img"], encoder_hidden_states=args["txt"], encoder_hidden_states_mask=encoder_hidden_states_mask, temb=args["vec"], image_rotary_emb=args["pe"]) + out["txt"], out["img"] = block(hidden_states=args["img"], encoder_hidden_states=args["txt"], encoder_hidden_states_mask=encoder_hidden_states_mask, temb=args["vec"], image_rotary_emb=args["pe"], transformer_options=args["transformer_options"]) return out - out = blocks_replace[("double_block", i)]({"img": hidden_states, "txt": encoder_hidden_states, "vec": temb, "pe": image_rotary_emb}, {"original_block": block_wrap}) + out = blocks_replace[("double_block", i)]({"img": hidden_states, "txt": encoder_hidden_states, "vec": temb, "pe": image_rotary_emb, "transformer_options": transformer_options}, {"original_block": block_wrap}) hidden_states = out["img"] encoder_hidden_states = out["txt"] else: @@ -446,11 +449,12 @@ class QwenImageTransformer2DModel(nn.Module): encoder_hidden_states_mask=encoder_hidden_states_mask, temb=temb, image_rotary_emb=image_rotary_emb, + transformer_options=transformer_options, ) if "double_block" in patches: for p in patches["double_block"]: - out = p({"img": hidden_states, "txt": encoder_hidden_states, "x": x, "block_index": i}) + out = p({"img": hidden_states, "txt": encoder_hidden_states, "x": x, "block_index": i, "transformer_options": transformer_options}) hidden_states = out["img"] encoder_hidden_states = out["txt"] diff --git a/comfy/ldm/wan/model.py b/comfy/ldm/wan/model.py index 47857dc2b..63472ada2 100644 --- a/comfy/ldm/wan/model.py +++ b/comfy/ldm/wan/model.py @@ -52,7 +52,7 @@ class WanSelfAttention(nn.Module): self.norm_q = operation_settings.get("operations").RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity() self.norm_k = operation_settings.get("operations").RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity() - def forward(self, x, freqs): + def forward(self, x, freqs, transformer_options={}): r""" Args: x(Tensor): Shape [B, L, num_heads, C / num_heads] @@ -75,6 +75,7 @@ class WanSelfAttention(nn.Module): k.view(b, s, n * d), v, heads=self.num_heads, + transformer_options=transformer_options, ) x = self.o(x) @@ -83,7 +84,7 @@ class WanSelfAttention(nn.Module): class WanT2VCrossAttention(WanSelfAttention): - def forward(self, x, context, **kwargs): + def forward(self, x, context, transformer_options={}, **kwargs): r""" Args: x(Tensor): Shape [B, L1, C] @@ -95,7 +96,7 @@ class WanT2VCrossAttention(WanSelfAttention): v = self.v(context) # 
compute attention - x = optimized_attention(q, k, v, heads=self.num_heads) + x = optimized_attention(q, k, v, heads=self.num_heads, transformer_options=transformer_options) x = self.o(x) return x @@ -116,7 +117,7 @@ class WanI2VCrossAttention(WanSelfAttention): # self.alpha = nn.Parameter(torch.zeros((1, ))) self.norm_k_img = operation_settings.get("operations").RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity() - def forward(self, x, context, context_img_len): + def forward(self, x, context, context_img_len, transformer_options={}): r""" Args: x(Tensor): Shape [B, L1, C] @@ -131,9 +132,9 @@ class WanI2VCrossAttention(WanSelfAttention): v = self.v(context) k_img = self.norm_k_img(self.k_img(context_img)) v_img = self.v_img(context_img) - img_x = optimized_attention(q, k_img, v_img, heads=self.num_heads) + img_x = optimized_attention(q, k_img, v_img, heads=self.num_heads, transformer_options=transformer_options) # compute attention - x = optimized_attention(q, k, v, heads=self.num_heads) + x = optimized_attention(q, k, v, heads=self.num_heads, transformer_options=transformer_options) # output x = x + img_x @@ -206,6 +207,7 @@ class WanAttentionBlock(nn.Module): freqs, context, context_img_len=257, + transformer_options={}, ): r""" Args: @@ -224,12 +226,12 @@ class WanAttentionBlock(nn.Module): # self-attention y = self.self_attn( torch.addcmul(repeat_e(e[0], x), self.norm1(x), 1 + repeat_e(e[1], x)), - freqs) + freqs, transformer_options=transformer_options) x = torch.addcmul(x, y, repeat_e(e[2], x)) # cross-attention & ffn - x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len) + x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len, transformer_options=transformer_options) y = self.ffn(torch.addcmul(repeat_e(e[3], x), self.norm2(x), 1 + repeat_e(e[4], x))) x = torch.addcmul(x, y, repeat_e(e[5], x)) return x @@ -559,12 +561,12 @@ class WanModel(torch.nn.Module): if ("double_block", i) in blocks_replace: def block_wrap(args): out = {} - out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len) + out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len, transformer_options=args["transformer_options"]) return out - out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap}) + out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap}) x = out["img"] else: - x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len) + x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len, transformer_options=transformer_options) # head x = self.head(x, e) @@ -742,17 +744,17 @@ class VaceWanModel(WanModel): if ("double_block", i) in blocks_replace: def block_wrap(args): out = {} - out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len) + out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len, transformer_options=args["transformer_options"]) return out - out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap}) + out = 
blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap}) x = out["img"] else: - x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len) + x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len, transformer_options=transformer_options) ii = self.vace_layers_mapping.get(i, None) if ii is not None: for iii in range(len(c)): - c_skip, c[iii] = self.vace_blocks[ii](c[iii], x=x_orig, e=e0, freqs=freqs, context=context, context_img_len=context_img_len) + c_skip, c[iii] = self.vace_blocks[ii](c[iii], x=x_orig, e=e0, freqs=freqs, context=context, context_img_len=context_img_len, transformer_options=transformer_options) x += c_skip * vace_strength[iii] del c_skip # head @@ -841,12 +843,12 @@ class CameraWanModel(WanModel): if ("double_block", i) in blocks_replace: def block_wrap(args): out = {} - out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len) + out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len, transformer_options=args["transformer_options"]) return out - out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap}) + out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap}) x = out["img"] else: - x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len) + x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len, transformer_options=transformer_options) # head x = self.head(x, e) From a3b04de7004cc19dee9364bd71e62bab05475810 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Fri, 12 Sep 2025 16:46:46 -0700 Subject: [PATCH 204/325] Hunyuan refiner vae now works with tiled. 
(#9836) --- comfy/ldm/hunyuan_video/vae_refiner.py | 1 - comfy/sd.py | 21 +++++++++++++++------ 2 files changed, 15 insertions(+), 7 deletions(-) diff --git a/comfy/ldm/hunyuan_video/vae_refiner.py b/comfy/ldm/hunyuan_video/vae_refiner.py index e3fff9bbe..c6f742710 100644 --- a/comfy/ldm/hunyuan_video/vae_refiner.py +++ b/comfy/ldm/hunyuan_video/vae_refiner.py @@ -185,7 +185,6 @@ class Encoder(nn.Module): self.regul = comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer() def forward(self, x): - x = x.unsqueeze(2) x = self.conv_in(x) for stage in self.down: diff --git a/comfy/sd.py b/comfy/sd.py index 02ddc7239..f8f1a89e8 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -412,9 +412,12 @@ class VAE: self.working_dtypes = [torch.bfloat16, torch.float32] elif "decoder.conv_in.conv.weight" in sd and sd['decoder.conv_in.conv.weight'].shape[1] == 32: ddconfig = {"block_out_channels": [128, 256, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 16, "ffactor_temporal": 4, "downsample_match_channel": True, "upsample_match_channel": True} - self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.conv.weight"].shape[1] - self.downscale_ratio = 16 - self.upscale_ratio = 16 + ddconfig['z_channels'] = sd["decoder.conv_in.conv.weight"].shape[1] + self.latent_channels = 64 + self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 16, 16) + self.upscale_index_formula = (4, 16, 16) + self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 16, 16) + self.downscale_index_formula = (4, 16, 16) self.latent_dim = 3 self.not_video = True self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32] @@ -684,8 +687,11 @@ class VAE: self.throw_exception_if_invalid() pixel_samples = self.vae_encode_crop_pixels(pixel_samples) pixel_samples = pixel_samples.movedim(-1, 1) - if not self.not_video and self.latent_dim == 3 and pixel_samples.ndim < 5: - pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0) + if self.latent_dim == 3 and pixel_samples.ndim < 5: + if not self.not_video: + pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0) + else: + pixel_samples = pixel_samples.unsqueeze(2) try: memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype) model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload) @@ -719,7 +725,10 @@ class VAE: dims = self.latent_dim pixel_samples = pixel_samples.movedim(-1, 1) if dims == 3: - pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0) + if not self.not_video: + pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0) + else: + pixel_samples = pixel_samples.unsqueeze(2) memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype) # TODO: calculate mem required for tile model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload) From 2559dee49202365bc97218b98121e796f57dfcb2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jukka=20Sepp=C3=A4nen?= <40791699+kijai@users.noreply.github.com> Date: Sat, 13 Sep 2025 04:52:58 +0300 Subject: [PATCH 205/325] Support wav2vec base models (#9637) * Support wav2vec base models * trim trailing whitespace * Do interpolation after --- comfy/audio_encoders/audio_encoders.py | 36 ++++++++++- comfy/audio_encoders/wav2vec2.py | 87 +++++++++++++++++++------- 2 files changed, 99 insertions(+), 24 deletions(-) diff --git a/comfy/audio_encoders/audio_encoders.py b/comfy/audio_encoders/audio_encoders.py index 538c21bd5..d1ec78f69 100644 --- 
a/comfy/audio_encoders/audio_encoders.py +++ b/comfy/audio_encoders/audio_encoders.py @@ -11,7 +11,13 @@ class AudioEncoderModel(): self.load_device = comfy.model_management.text_encoder_device() offload_device = comfy.model_management.text_encoder_offload_device() self.dtype = comfy.model_management.text_encoder_dtype(self.load_device) - self.model = Wav2Vec2Model(dtype=self.dtype, device=offload_device, operations=comfy.ops.manual_cast) + model_config = dict(config) + model_config.update({ + "dtype": self.dtype, + "device": offload_device, + "operations": comfy.ops.manual_cast + }) + self.model = Wav2Vec2Model(**model_config) self.model.eval() self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device) self.model_sample_rate = 16000 @@ -25,7 +31,7 @@ class AudioEncoderModel(): def encode_audio(self, audio, sample_rate): comfy.model_management.load_model_gpu(self.patcher) audio = torchaudio.functional.resample(audio, sample_rate, self.model_sample_rate) - out, all_layers = self.model(audio.to(self.load_device)) + out, all_layers = self.model(audio.to(self.load_device), sr=self.model_sample_rate) outputs = {} outputs["encoded_audio"] = out outputs["encoded_audio_all_layers"] = all_layers @@ -33,8 +39,32 @@ class AudioEncoderModel(): def load_audio_encoder_from_sd(sd, prefix=""): - audio_encoder = AudioEncoderModel(None) sd = comfy.utils.state_dict_prefix_replace(sd, {"wav2vec2.": ""}) + embed_dim = sd["encoder.layer_norm.bias"].shape[0] + if embed_dim == 1024:# large + config = { + "embed_dim": 1024, + "num_heads": 16, + "num_layers": 24, + "conv_norm": True, + "conv_bias": True, + "do_normalize": True, + "do_stable_layer_norm": True + } + elif embed_dim == 768: # base + config = { + "embed_dim": 768, + "num_heads": 12, + "num_layers": 12, + "conv_norm": False, + "conv_bias": False, + "do_normalize": False, # chinese-wav2vec2-base has this False + "do_stable_layer_norm": False + } + else: + raise RuntimeError("ERROR: audio encoder file is invalid or unsupported embed_dim: {}".format(embed_dim)) + + audio_encoder = AudioEncoderModel(config) m, u = audio_encoder.load_sd(sd) if len(m) > 0: logging.warning("missing audio encoder: {}".format(m)) diff --git a/comfy/audio_encoders/wav2vec2.py b/comfy/audio_encoders/wav2vec2.py index de906622a..ef10dcd2a 100644 --- a/comfy/audio_encoders/wav2vec2.py +++ b/comfy/audio_encoders/wav2vec2.py @@ -13,19 +13,49 @@ class LayerNormConv(nn.Module): x = self.conv(x) return torch.nn.functional.gelu(self.layer_norm(x.transpose(-2, -1)).transpose(-2, -1)) +class LayerGroupNormConv(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, stride, bias=False, dtype=None, device=None, operations=None): + super().__init__() + self.conv = operations.Conv1d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, bias=bias, device=device, dtype=dtype) + self.layer_norm = operations.GroupNorm(num_groups=out_channels, num_channels=out_channels, affine=True, device=device, dtype=dtype) + + def forward(self, x): + x = self.conv(x) + return torch.nn.functional.gelu(self.layer_norm(x)) + +class ConvNoNorm(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, stride, bias=False, dtype=None, device=None, operations=None): + super().__init__() + self.conv = operations.Conv1d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, bias=bias, device=device, dtype=dtype) + + def forward(self, x): + x = self.conv(x) + return torch.nn.functional.gelu(x) + class 
ConvFeatureEncoder(nn.Module): - def __init__(self, conv_dim, dtype=None, device=None, operations=None): + def __init__(self, conv_dim, conv_bias=False, conv_norm=True, dtype=None, device=None, operations=None): super().__init__() - self.conv_layers = nn.ModuleList([ - LayerNormConv(1, conv_dim, kernel_size=10, stride=5, bias=True, device=device, dtype=dtype, operations=operations), - LayerNormConv(conv_dim, conv_dim, kernel_size=3, stride=2, bias=True, device=device, dtype=dtype, operations=operations), - LayerNormConv(conv_dim, conv_dim, kernel_size=3, stride=2, bias=True, device=device, dtype=dtype, operations=operations), - LayerNormConv(conv_dim, conv_dim, kernel_size=3, stride=2, bias=True, device=device, dtype=dtype, operations=operations), - LayerNormConv(conv_dim, conv_dim, kernel_size=3, stride=2, bias=True, device=device, dtype=dtype, operations=operations), - LayerNormConv(conv_dim, conv_dim, kernel_size=2, stride=2, bias=True, device=device, dtype=dtype, operations=operations), - LayerNormConv(conv_dim, conv_dim, kernel_size=2, stride=2, bias=True, device=device, dtype=dtype, operations=operations), - ]) + if conv_norm: + self.conv_layers = nn.ModuleList([ + LayerNormConv(1, conv_dim, kernel_size=10, stride=5, bias=True, device=device, dtype=dtype, operations=operations), + LayerNormConv(conv_dim, conv_dim, kernel_size=3, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations), + LayerNormConv(conv_dim, conv_dim, kernel_size=3, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations), + LayerNormConv(conv_dim, conv_dim, kernel_size=3, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations), + LayerNormConv(conv_dim, conv_dim, kernel_size=3, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations), + LayerNormConv(conv_dim, conv_dim, kernel_size=2, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations), + LayerNormConv(conv_dim, conv_dim, kernel_size=2, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations), + ]) + else: + self.conv_layers = nn.ModuleList([ + LayerGroupNormConv(1, conv_dim, kernel_size=10, stride=5, bias=conv_bias, device=device, dtype=dtype, operations=operations), + ConvNoNorm(conv_dim, conv_dim, kernel_size=3, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations), + ConvNoNorm(conv_dim, conv_dim, kernel_size=3, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations), + ConvNoNorm(conv_dim, conv_dim, kernel_size=3, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations), + ConvNoNorm(conv_dim, conv_dim, kernel_size=3, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations), + ConvNoNorm(conv_dim, conv_dim, kernel_size=2, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations), + ConvNoNorm(conv_dim, conv_dim, kernel_size=2, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations), + ]) def forward(self, x): x = x.unsqueeze(1) @@ -76,6 +106,7 @@ class TransformerEncoder(nn.Module): num_heads=12, num_layers=12, mlp_ratio=4.0, + do_stable_layer_norm=True, dtype=None, device=None, operations=None ): super().__init__() @@ -86,20 +117,25 @@ class TransformerEncoder(nn.Module): embed_dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, + do_stable_layer_norm=do_stable_layer_norm, device=device, dtype=dtype, operations=operations ) for _ in range(num_layers) ]) self.layer_norm = operations.LayerNorm(embed_dim, 
eps=1e-05, device=device, dtype=dtype) + self.do_stable_layer_norm = do_stable_layer_norm def forward(self, x, mask=None): x = x + self.pos_conv_embed(x) all_x = () + if not self.do_stable_layer_norm: + x = self.layer_norm(x) for layer in self.layers: all_x += (x,) x = layer(x, mask) - x = self.layer_norm(x) + if self.do_stable_layer_norm: + x = self.layer_norm(x) all_x += (x,) return x, all_x @@ -145,6 +181,7 @@ class TransformerEncoderLayer(nn.Module): embed_dim=768, num_heads=12, mlp_ratio=4.0, + do_stable_layer_norm=True, dtype=None, device=None, operations=None ): super().__init__() @@ -154,15 +191,19 @@ class TransformerEncoderLayer(nn.Module): self.layer_norm = operations.LayerNorm(embed_dim, device=device, dtype=dtype) self.feed_forward = FeedForward(embed_dim, mlp_ratio, device=device, dtype=dtype, operations=operations) self.final_layer_norm = operations.LayerNorm(embed_dim, device=device, dtype=dtype) + self.do_stable_layer_norm = do_stable_layer_norm def forward(self, x, mask=None): residual = x - x = self.layer_norm(x) + if self.do_stable_layer_norm: + x = self.layer_norm(x) x = self.attention(x, mask=mask) x = residual + x - - x = x + self.feed_forward(self.final_layer_norm(x)) - return x + if not self.do_stable_layer_norm: + x = self.layer_norm(x) + return self.final_layer_norm(x + self.feed_forward(x)) + else: + return x + self.feed_forward(self.final_layer_norm(x)) class Wav2Vec2Model(nn.Module): @@ -174,34 +215,38 @@ class Wav2Vec2Model(nn.Module): final_dim=256, num_heads=16, num_layers=24, + conv_norm=True, + conv_bias=True, + do_normalize=True, + do_stable_layer_norm=True, dtype=None, device=None, operations=None ): super().__init__() conv_dim = 512 - self.feature_extractor = ConvFeatureEncoder(conv_dim, device=device, dtype=dtype, operations=operations) + self.feature_extractor = ConvFeatureEncoder(conv_dim, conv_norm=conv_norm, conv_bias=conv_bias, device=device, dtype=dtype, operations=operations) self.feature_projection = FeatureProjection(conv_dim, embed_dim, device=device, dtype=dtype, operations=operations) self.masked_spec_embed = nn.Parameter(torch.empty(embed_dim, device=device, dtype=dtype)) + self.do_normalize = do_normalize self.encoder = TransformerEncoder( embed_dim=embed_dim, num_heads=num_heads, num_layers=num_layers, + do_stable_layer_norm=do_stable_layer_norm, device=device, dtype=dtype, operations=operations ) - def forward(self, x, mask_time_indices=None, return_dict=False): - + def forward(self, x, sr=16000, mask_time_indices=None, return_dict=False): x = torch.mean(x, dim=1) - x = (x - x.mean()) / torch.sqrt(x.var() + 1e-7) + if self.do_normalize: + x = (x - x.mean()) / torch.sqrt(x.var() + 1e-7) features = self.feature_extractor(x) features = self.feature_projection(features) - batch_size, seq_len, _ = features.shape x, all_x = self.encoder(features) - return x, all_x From 29bf807b0e2d89402d555d08bd8e9df15e636f0c Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Fri, 12 Sep 2025 18:57:04 -0700 Subject: [PATCH 206/325] Cleanup. 
(#9838) --- comfy/audio_encoders/audio_encoders.py | 2 +- comfy/audio_encoders/wav2vec2.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/comfy/audio_encoders/audio_encoders.py b/comfy/audio_encoders/audio_encoders.py index d1ec78f69..6fb5b08e9 100644 --- a/comfy/audio_encoders/audio_encoders.py +++ b/comfy/audio_encoders/audio_encoders.py @@ -31,7 +31,7 @@ class AudioEncoderModel(): def encode_audio(self, audio, sample_rate): comfy.model_management.load_model_gpu(self.patcher) audio = torchaudio.functional.resample(audio, sample_rate, self.model_sample_rate) - out, all_layers = self.model(audio.to(self.load_device), sr=self.model_sample_rate) + out, all_layers = self.model(audio.to(self.load_device)) outputs = {} outputs["encoded_audio"] = out outputs["encoded_audio_all_layers"] = all_layers diff --git a/comfy/audio_encoders/wav2vec2.py b/comfy/audio_encoders/wav2vec2.py index ef10dcd2a..4e34a40a7 100644 --- a/comfy/audio_encoders/wav2vec2.py +++ b/comfy/audio_encoders/wav2vec2.py @@ -238,7 +238,7 @@ class Wav2Vec2Model(nn.Module): device=device, dtype=dtype, operations=operations ) - def forward(self, x, sr=16000, mask_time_indices=None, return_dict=False): + def forward(self, x, mask_time_indices=None, return_dict=False): x = torch.mean(x, dim=1) if self.do_normalize: From e5e70636e7b7b54695220a88ab036c1607959736 Mon Sep 17 00:00:00 2001 From: Kimbing Ng <50580578+KimbingNg@users.noreply.github.com> Date: Sun, 14 Sep 2025 04:59:19 +0800 Subject: [PATCH 207/325] Remove single quote pattern to avoid wrong matches (#9842) --- comfy/text_encoders/hunyuan_image.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/comfy/text_encoders/hunyuan_image.py b/comfy/text_encoders/hunyuan_image.py index be396cae7..699eddc33 100644 --- a/comfy/text_encoders/hunyuan_image.py +++ b/comfy/text_encoders/hunyuan_image.py @@ -22,17 +22,14 @@ class HunyuanImageTokenizer(QwenImageTokenizer): # ByT5 processing for HunyuanImage text_prompt_texts = [] - pattern_quote_single = r'\'(.*?)\'' pattern_quote_double = r'\"(.*?)\"' pattern_quote_chinese_single = r'‘(.*?)’' pattern_quote_chinese_double = r'“(.*?)”' - matches_quote_single = re.findall(pattern_quote_single, text) matches_quote_double = re.findall(pattern_quote_double, text) matches_quote_chinese_single = re.findall(pattern_quote_chinese_single, text) matches_quote_chinese_double = re.findall(pattern_quote_chinese_double, text) - text_prompt_texts.extend(matches_quote_single) text_prompt_texts.extend(matches_quote_double) text_prompt_texts.extend(matches_quote_chinese_single) text_prompt_texts.extend(matches_quote_chinese_double) From c1297f4eb38a63e2f99c9fa76e32e3a36c933b85 Mon Sep 17 00:00:00 2001 From: blepping <157360029+blepping@users.noreply.github.com> Date: Sat, 13 Sep 2025 15:58:43 -0600 Subject: [PATCH 208/325] Add support for Chroma Radiance (#9682) * Initial Chroma Radiance support * Minor Chroma Radiance cleanups * Update Radiance nodes to ensure latents/images are on the intermediate device * Fix Chroma Radiance memory estimation. * Increase Chroma Radiance memory usage factor * Increase Chroma Radiance memory usage factor once again * Ensure images are multiples of 16 for Chroma Radiance Add batch dimension and fix channels when necessary in ChromaRadianceImageToLatent node * Tile Chroma Radiance NeRF to reduce memory consumption, update memory usage factor * Update Radiance to support conv nerf final head type. 
* Allow setting NeRF embedder dtype for Radiance Bump Radiance nerf tile size to 32 Support EasyCache/LazyCache on Radiance (maybe) * Add ChromaRadianceStubVAE node * Crop Radiance image inputs to multiples of 16 instead of erroring to be in line with existing VAE behavior * Convert Chroma Radiance nodes to V3 schema. * Add ChromaRadianceOptions node and backend support. Cleanups/refactoring to reduce code duplication with Chroma. * Fix overriding the NeRF embedder dtype for Chroma Radiance * Minor Chroma Radiance cleanups * Move Chroma Radiance to its own directory in ldm Minor code cleanups and tooltip improvements * Fix Chroma Radiance embedder dtype overriding * Remove Radiance dynamic nerf_embedder dtype override feature * Unbork Radiance NeRF embedder init * Remove Chroma Radiance image conversion and stub VAE nodes Add a chroma_radiance option to the VAELoader builtin node which uses comfy.sd.PixelspaceConversionVAE Add a PixelspaceConversionVAE to comfy.sd for converting BHWC 0..1 <-> BCHW -1..1 --- comfy/latent_formats.py | 17 ++ comfy/ldm/chroma/model.py | 10 +- comfy/ldm/chroma_radiance/layers.py | 206 ++++++++++++++++ comfy/ldm/chroma_radiance/model.py | 328 ++++++++++++++++++++++++++ comfy/model_base.py | 9 +- comfy/model_detection.py | 14 +- comfy/sd.py | 60 +++++ comfy/supported_models.py | 15 +- comfy_extras/nodes_chroma_radiance.py | 114 +++++++++ nodes.py | 6 +- 10 files changed, 770 insertions(+), 9 deletions(-) create mode 100644 comfy/ldm/chroma_radiance/layers.py create mode 100644 comfy/ldm/chroma_radiance/model.py create mode 100644 comfy_extras/nodes_chroma_radiance.py diff --git a/comfy/latent_formats.py b/comfy/latent_formats.py index 894540879..77e642a94 100644 --- a/comfy/latent_formats.py +++ b/comfy/latent_formats.py @@ -629,3 +629,20 @@ class Hunyuan3Dv2mini(LatentFormat): class ACEAudio(LatentFormat): latent_channels = 8 latent_dimensions = 2 + +class ChromaRadiance(LatentFormat): + latent_channels = 3 + + def __init__(self): + self.latent_rgb_factors = [ + # R G B + [ 1.0, 0.0, 0.0 ], + [ 0.0, 1.0, 0.0 ], + [ 0.0, 0.0, 1.0 ] + ] + + def process_in(self, latent): + return latent + + def process_out(self, latent): + return latent diff --git a/comfy/ldm/chroma/model.py b/comfy/ldm/chroma/model.py index 4f709f87d..ad1c523fe 100644 --- a/comfy/ldm/chroma/model.py +++ b/comfy/ldm/chroma/model.py @@ -151,8 +151,6 @@ class Chroma(nn.Module): attn_mask: Tensor = None, ) -> Tensor: patches_replace = transformer_options.get("patches_replace", {}) - if img.ndim != 3 or txt.ndim != 3: - raise ValueError("Input img and txt tensors must have 3 dimensions.") # running on sequences img img = self.img_in(img) @@ -254,8 +252,9 @@ class Chroma(nn.Module): img[:, txt.shape[1] :, ...] += add img = img[:, txt.shape[1] :, ...] 
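# --- Editor's aside (illustrative sketch, not part of this diff) --------------
# The ChromaRadiance latent format added to comfy/latent_formats.py above is an
# identity mapping with latent_channels = 3: the model works directly on RGB
# pixels, so "encoding"/"decoding" is only a layout and range conversion. The
# commit message describes comfy.sd.PixelspaceConversionVAE as converting
# BHWC 0..1 <-> BCHW -1..1; its actual implementation is not shown in this
# excerpt, so the helpers below are only a sketch of that stated mapping.
import torch

def image_to_pixel_latent(img_bhwc: torch.Tensor) -> torch.Tensor:
    # [B, H, W, C] in 0..1  ->  [B, C, H, W] in -1..1
    return img_bhwc.movedim(-1, 1) * 2.0 - 1.0

def pixel_latent_to_image(lat_bchw: torch.Tensor) -> torch.Tensor:
    # [B, C, H, W] in -1..1  ->  [B, H, W, C] in 0..1
    return (lat_bchw.movedim(1, -1) + 1.0) / 2.0
# ------------------------------------------------------------------------------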
- final_mod = self.get_modulations(mod_vectors, "final") - img = self.final_layer(img, vec=final_mod) # (N, T, patch_size ** 2 * out_channels) + if hasattr(self, "final_layer"): + final_mod = self.get_modulations(mod_vectors, "final") + img = self.final_layer(img, vec=final_mod) # (N, T, patch_size ** 2 * out_channels) return img def forward(self, x, timestep, context, guidance, control=None, transformer_options={}, **kwargs): @@ -271,6 +270,9 @@ class Chroma(nn.Module): img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=self.patch_size, pw=self.patch_size) + if img.ndim != 3 or context.ndim != 3: + raise ValueError("Input img and txt tensors must have 3 dimensions.") + h_len = ((h + (self.patch_size // 2)) // self.patch_size) w_len = ((w + (self.patch_size // 2)) // self.patch_size) img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype) diff --git a/comfy/ldm/chroma_radiance/layers.py b/comfy/ldm/chroma_radiance/layers.py new file mode 100644 index 000000000..3c7bc9b6b --- /dev/null +++ b/comfy/ldm/chroma_radiance/layers.py @@ -0,0 +1,206 @@ +# Adapted from https://github.com/lodestone-rock/flow +from functools import lru_cache + +import torch +from torch import nn + +from comfy.ldm.flux.layers import RMSNorm + + +class NerfEmbedder(nn.Module): + """ + An embedder module that combines input features with a 2D positional + encoding that mimics the Discrete Cosine Transform (DCT). + + This module takes an input tensor of shape (B, P^2, C), where P is the + patch size, and enriches it with positional information before projecting + it to a new hidden size. + """ + def __init__( + self, + in_channels: int, + hidden_size_input: int, + max_freqs: int, + dtype=None, + device=None, + operations=None, + ): + """ + Initializes the NerfEmbedder. + + Args: + in_channels (int): The number of channels in the input tensor. + hidden_size_input (int): The desired dimension of the output embedding. + max_freqs (int): The number of frequency components to use for both + the x and y dimensions of the positional encoding. + The total number of positional features will be max_freqs^2. + """ + super().__init__() + self.dtype = dtype + self.max_freqs = max_freqs + self.hidden_size_input = hidden_size_input + + # A linear layer to project the concatenated input features and + # positional encodings to the final output dimension. + self.embedder = nn.Sequential( + operations.Linear(in_channels + max_freqs**2, hidden_size_input, dtype=dtype, device=device) + ) + + @lru_cache(maxsize=4) + def fetch_pos(self, patch_size: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor: + """ + Generates and caches 2D DCT-like positional embeddings for a given patch size. + + The LRU cache is a performance optimization that avoids recomputing the + same positional grid on every forward pass. + + Args: + patch_size (int): The side length of the square input patch. + device: The torch device to create the tensors on. + dtype: The torch dtype for the tensors. + + Returns: + A tensor of shape (1, patch_size^2, max_freqs^2) containing the + positional embeddings. + """ + # Create normalized 1D coordinate grids from 0 to 1. + pos_x = torch.linspace(0, 1, patch_size, device=device, dtype=dtype) + pos_y = torch.linspace(0, 1, patch_size, device=device, dtype=dtype) + + # Create a 2D meshgrid of coordinates. + pos_y, pos_x = torch.meshgrid(pos_y, pos_x, indexing="ij") + + # Reshape positions to be broadcastable with frequencies. + # Shape becomes (patch_size^2, 1, 1). 
+ pos_x = pos_x.reshape(-1, 1, 1) + pos_y = pos_y.reshape(-1, 1, 1) + + # Create a 1D tensor of frequency values from 0 to max_freqs-1. + freqs = torch.linspace(0, self.max_freqs - 1, self.max_freqs, dtype=dtype, device=device) + + # Reshape frequencies to be broadcastable for creating 2D basis functions. + # freqs_x shape: (1, max_freqs, 1) + # freqs_y shape: (1, 1, max_freqs) + freqs_x = freqs[None, :, None] + freqs_y = freqs[None, None, :] + + # A custom weighting coefficient, not part of standard DCT. + # This seems to down-weight the contribution of higher-frequency interactions. + coeffs = (1 + freqs_x * freqs_y) ** -1 + + # Calculate the 1D cosine basis functions for x and y coordinates. + # This is the core of the DCT formulation. + dct_x = torch.cos(pos_x * freqs_x * torch.pi) + dct_y = torch.cos(pos_y * freqs_y * torch.pi) + + # Combine the 1D basis functions to create 2D basis functions by element-wise + # multiplication, and apply the custom coefficients. Broadcasting handles the + # combination of all (pos_x, freqs_x) with all (pos_y, freqs_y). + # The result is flattened into a feature vector for each position. + dct = (dct_x * dct_y * coeffs).view(1, -1, self.max_freqs ** 2) + + return dct + + def forward(self, inputs: torch.Tensor) -> torch.Tensor: + """ + Forward pass for the embedder. + + Args: + inputs (Tensor): The input tensor of shape (B, P^2, C). + + Returns: + Tensor: The output tensor of shape (B, P^2, hidden_size_input). + """ + # Get the batch size, number of pixels, and number of channels. + B, P2, C = inputs.shape + + # Infer the patch side length from the number of pixels (P^2). + patch_size = int(P2 ** 0.5) + + input_dtype = inputs.dtype + inputs = inputs.to(dtype=self.dtype) + + # Fetch the pre-computed or cached positional embeddings. + dct = self.fetch_pos(patch_size, inputs.device, self.dtype) + + # Repeat the positional embeddings for each item in the batch. + dct = dct.repeat(B, 1, 1) + + # Concatenate the original input features with the positional embeddings + # along the feature dimension. + inputs = torch.cat((inputs, dct), dim=-1) + + # Project the combined tensor to the target hidden size. + return self.embedder(inputs).to(dtype=input_dtype) + + +class NerfGLUBlock(nn.Module): + """ + A NerfBlock using a Gated Linear Unit (GLU) like MLP. + """ + def __init__(self, hidden_size_s: int, hidden_size_x: int, mlp_ratio, dtype=None, device=None, operations=None): + super().__init__() + # The total number of parameters for the MLP is increased to accommodate + # the gate, value, and output projection matrices. + # We now need to generate parameters for 3 matrices. + total_params = 3 * hidden_size_x**2 * mlp_ratio + self.param_generator = operations.Linear(hidden_size_s, total_params, dtype=dtype, device=device) + self.norm = RMSNorm(hidden_size_x, dtype=dtype, device=device, operations=operations) + self.mlp_ratio = mlp_ratio + + + def forward(self, x: torch.Tensor, s: torch.Tensor) -> torch.Tensor: + batch_size, num_x, hidden_size_x = x.shape + mlp_params = self.param_generator(s) + + # Split the generated parameters into three parts for the gate, value, and output projection. + fc1_gate_params, fc1_value_params, fc2_params = mlp_params.chunk(3, dim=-1) + + # Reshape the parameters into matrices for batch matrix multiplication. 
+ fc1_gate = fc1_gate_params.view(batch_size, hidden_size_x, hidden_size_x * self.mlp_ratio) + fc1_value = fc1_value_params.view(batch_size, hidden_size_x, hidden_size_x * self.mlp_ratio) + fc2 = fc2_params.view(batch_size, hidden_size_x * self.mlp_ratio, hidden_size_x) + + # Normalize the generated weight matrices as in the original implementation. + fc1_gate = torch.nn.functional.normalize(fc1_gate, dim=-2) + fc1_value = torch.nn.functional.normalize(fc1_value, dim=-2) + fc2 = torch.nn.functional.normalize(fc2, dim=-2) + + res_x = x + x = self.norm(x) + + # Apply the final output projection. + x = torch.bmm(torch.nn.functional.silu(torch.bmm(x, fc1_gate)) * torch.bmm(x, fc1_value), fc2) + + return x + res_x + + +class NerfFinalLayer(nn.Module): + def __init__(self, hidden_size, out_channels, dtype=None, device=None, operations=None): + super().__init__() + self.norm = RMSNorm(hidden_size, dtype=dtype, device=device, operations=operations) + self.linear = operations.Linear(hidden_size, out_channels, dtype=dtype, device=device) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # RMSNorm normalizes over the last dimension, but our channel dim (C) is at dim=1. + # So we temporarily move the channel dimension to the end for the norm operation. + return self.linear(self.norm(x.movedim(1, -1))).movedim(-1, 1) + + +class NerfFinalLayerConv(nn.Module): + def __init__(self, hidden_size: int, out_channels: int, dtype=None, device=None, operations=None): + super().__init__() + self.norm = RMSNorm(hidden_size, dtype=dtype, device=device, operations=operations) + self.conv = operations.Conv2d( + in_channels=hidden_size, + out_channels=out_channels, + kernel_size=3, + padding=1, + dtype=dtype, + device=device, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # RMSNorm normalizes over the last dimension, but our channel dim (C) is at dim=1. + # So we temporarily move the channel dimension to the end for the norm operation. + return self.conv(self.norm(x.movedim(1, -1)).movedim(-1, 1)) diff --git a/comfy/ldm/chroma_radiance/model.py b/comfy/ldm/chroma_radiance/model.py new file mode 100644 index 000000000..f7eb7a22e --- /dev/null +++ b/comfy/ldm/chroma_radiance/model.py @@ -0,0 +1,328 @@ +# Credits: +# Original Flux code can be found on: https://github.com/black-forest-labs/flux +# Chroma Radiance adaption referenced from https://github.com/lodestone-rock/flow + +from dataclasses import dataclass +from typing import Optional + +import torch +from torch import Tensor, nn +from einops import repeat +import comfy.ldm.common_dit + +from comfy.ldm.flux.layers import EmbedND + +from comfy.ldm.chroma.model import Chroma, ChromaParams +from comfy.ldm.chroma.layers import ( + DoubleStreamBlock, + SingleStreamBlock, + Approximator, +) +from .layers import ( + NerfEmbedder, + NerfGLUBlock, + NerfFinalLayer, + NerfFinalLayerConv, +) + + +@dataclass +class ChromaRadianceParams(ChromaParams): + patch_size: int + nerf_hidden_size: int + nerf_mlp_ratio: int + nerf_depth: int + nerf_max_freqs: int + # Setting nerf_tile_size to 0 disables tiling. + nerf_tile_size: int + # Currently one of linear (legacy) or conv. + nerf_final_head_type: str + # None means use the same dtype as the model. + nerf_embedder_dtype: Optional[torch.dtype] + + +class ChromaRadiance(Chroma): + """ + Transformer model for flow matching on sequences. 
+ """ + + def __init__(self, image_model=None, final_layer=True, dtype=None, device=None, operations=None, **kwargs): + if operations is None: + raise RuntimeError("Attempt to create ChromaRadiance object without setting operations") + nn.Module.__init__(self) + self.dtype = dtype + params = ChromaRadianceParams(**kwargs) + self.params = params + self.patch_size = params.patch_size + self.in_channels = params.in_channels + self.out_channels = params.out_channels + if params.hidden_size % params.num_heads != 0: + raise ValueError( + f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}" + ) + pe_dim = params.hidden_size // params.num_heads + if sum(params.axes_dim) != pe_dim: + raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}") + self.hidden_size = params.hidden_size + self.num_heads = params.num_heads + self.in_dim = params.in_dim + self.out_dim = params.out_dim + self.hidden_dim = params.hidden_dim + self.n_layers = params.n_layers + self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim) + self.img_in_patch = operations.Conv2d( + params.in_channels, + params.hidden_size, + kernel_size=params.patch_size, + stride=params.patch_size, + bias=True, + dtype=dtype, + device=device, + ) + self.txt_in = operations.Linear(params.context_in_dim, self.hidden_size, dtype=dtype, device=device) + # set as nn identity for now, will overwrite it later. + self.distilled_guidance_layer = Approximator( + in_dim=self.in_dim, + hidden_dim=self.hidden_dim, + out_dim=self.out_dim, + n_layers=self.n_layers, + dtype=dtype, device=device, operations=operations + ) + + + self.double_blocks = nn.ModuleList( + [ + DoubleStreamBlock( + self.hidden_size, + self.num_heads, + mlp_ratio=params.mlp_ratio, + qkv_bias=params.qkv_bias, + dtype=dtype, device=device, operations=operations + ) + for _ in range(params.depth) + ] + ) + + self.single_blocks = nn.ModuleList( + [ + SingleStreamBlock( + self.hidden_size, + self.num_heads, + mlp_ratio=params.mlp_ratio, + dtype=dtype, device=device, operations=operations, + ) + for _ in range(params.depth_single_blocks) + ] + ) + + # pixel channel concat with DCT + self.nerf_image_embedder = NerfEmbedder( + in_channels=params.in_channels, + hidden_size_input=params.nerf_hidden_size, + max_freqs=params.nerf_max_freqs, + dtype=params.nerf_embedder_dtype or dtype, + device=device, + operations=operations, + ) + + self.nerf_blocks = nn.ModuleList([ + NerfGLUBlock( + hidden_size_s=params.hidden_size, + hidden_size_x=params.nerf_hidden_size, + mlp_ratio=params.nerf_mlp_ratio, + dtype=dtype, + device=device, + operations=operations, + ) for _ in range(params.nerf_depth) + ]) + + if params.nerf_final_head_type == "linear": + self.nerf_final_layer = NerfFinalLayer( + params.nerf_hidden_size, + out_channels=params.in_channels, + dtype=dtype, + device=device, + operations=operations, + ) + elif params.nerf_final_head_type == "conv": + self.nerf_final_layer_conv = NerfFinalLayerConv( + params.nerf_hidden_size, + out_channels=params.in_channels, + dtype=dtype, + device=device, + operations=operations, + ) + else: + errstr = f"Unsupported nerf_final_head_type {params.nerf_final_head_type}" + raise ValueError(errstr) + + self.skip_mmdit = [] + self.skip_dit = [] + self.lite = False + + @property + def _nerf_final_layer(self) -> nn.Module: + if self.params.nerf_final_head_type == "linear": + return self.nerf_final_layer + if self.params.nerf_final_head_type == "conv": + return self.nerf_final_layer_conv + # 
Impossible to get here as we raise an error on unexpected types on initialization. + raise NotImplementedError + + def img_in(self, img: Tensor) -> Tensor: + img = self.img_in_patch(img) # -> [B, Hidden, H/P, W/P] + # flatten into a sequence for the transformer. + return img.flatten(2).transpose(1, 2) # -> [B, NumPatches, Hidden] + + def forward_nerf( + self, + img_orig: Tensor, + img_out: Tensor, + params: ChromaRadianceParams, + ) -> Tensor: + B, C, H, W = img_orig.shape + num_patches = img_out.shape[1] + patch_size = params.patch_size + + # Store the raw pixel values of each patch for the NeRF head later. + # unfold creates patches: [B, C * P * P, NumPatches] + nerf_pixels = nn.functional.unfold(img_orig, kernel_size=patch_size, stride=patch_size) + nerf_pixels = nerf_pixels.transpose(1, 2) # -> [B, NumPatches, C * P * P] + + if params.nerf_tile_size > 0 and num_patches > params.nerf_tile_size: + # Enable tiling if nerf_tile_size isn't 0 and we actually have more patches than + # the tile size. + img_dct = self.forward_tiled_nerf(img_out, nerf_pixels, B, C, num_patches, patch_size, params) + else: + # Reshape for per-patch processing + nerf_hidden = img_out.reshape(B * num_patches, params.hidden_size) + nerf_pixels = nerf_pixels.reshape(B * num_patches, C, patch_size**2).transpose(1, 2) + + # Get DCT-encoded pixel embeddings [pixel-dct] + img_dct = self.nerf_image_embedder(nerf_pixels) + + # Pass through the dynamic MLP blocks (the NeRF) + for block in self.nerf_blocks: + img_dct = block(img_dct, nerf_hidden) + + # Reassemble the patches into the final image. + img_dct = img_dct.transpose(1, 2) # -> [B*NumPatches, C, P*P] + # Reshape to combine with batch dimension for fold + img_dct = img_dct.reshape(B, num_patches, -1) # -> [B, NumPatches, C*P*P] + img_dct = img_dct.transpose(1, 2) # -> [B, C*P*P, NumPatches] + img_dct = nn.functional.fold( + img_dct, + output_size=(H, W), + kernel_size=patch_size, + stride=patch_size, + ) + return self._nerf_final_layer(img_dct) + + def forward_tiled_nerf( + self, + nerf_hidden: Tensor, + nerf_pixels: Tensor, + batch: int, + channels: int, + num_patches: int, + patch_size: int, + params: ChromaRadianceParams, + ) -> Tensor: + """ + Processes the NeRF head in tiles to save memory. + nerf_hidden has shape [B, L, D] + nerf_pixels has shape [B, L, C * P * P] + """ + tile_size = params.nerf_tile_size + output_tiles = [] + # Iterate over the patches in tiles. The dimension L (num_patches) is at index 1. 
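+        # Illustrative sizing, not part of the original code: with the default
+        # patch_size=16 and nerf_tile_size=32, a 1024x1024 input has
+        # num_patches = (1024 // 16) ** 2 = 4096, so this loop runs 4096 / 32 = 128
+        # times and each iteration only materializes [B * 32, ...] per-patch
+        # tensors for the NeRF blocks instead of [B * 4096, ...] all at once,
+        # which is what bounds peak memory.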
+ for i in range(0, num_patches, tile_size): + end = min(i + tile_size, num_patches) + + # Slice the current tile from the input tensors + nerf_hidden_tile = nerf_hidden[:, i:end, :] + nerf_pixels_tile = nerf_pixels[:, i:end, :] + + # Get the actual number of patches in this tile (can be smaller for the last tile) + num_patches_tile = nerf_hidden_tile.shape[1] + + # Reshape the tile for per-patch processing + # [B, NumPatches_tile, D] -> [B * NumPatches_tile, D] + nerf_hidden_tile = nerf_hidden_tile.reshape(batch * num_patches_tile, params.hidden_size) + # [B, NumPatches_tile, C*P*P] -> [B*NumPatches_tile, C, P*P] -> [B*NumPatches_tile, P*P, C] + nerf_pixels_tile = nerf_pixels_tile.reshape(batch * num_patches_tile, channels, patch_size**2).transpose(1, 2) + + # get DCT-encoded pixel embeddings [pixel-dct] + img_dct_tile = self.nerf_image_embedder(nerf_pixels_tile) + + # pass through the dynamic MLP blocks (the NeRF) + for block in self.nerf_blocks: + img_dct_tile = block(img_dct_tile, nerf_hidden_tile) + + output_tiles.append(img_dct_tile) + + # Concatenate the processed tiles along the patch dimension + return torch.cat(output_tiles, dim=0) + + def radiance_get_override_params(self, overrides: dict) -> ChromaRadianceParams: + params = self.params + if not overrides: + return params + params_dict = {k: getattr(params, k) for k in params.__dataclass_fields__} + nullable_keys = frozenset(("nerf_embedder_dtype",)) + bad_keys = tuple(k for k in overrides if k not in params_dict) + if bad_keys: + e = f"Unknown key(s) in transformer_options chroma_radiance_options: {', '.join(bad_keys)}" + raise ValueError(e) + bad_keys = tuple( + k + for k, v in overrides.items() + if type(v) != type(getattr(params, k)) and (v is not None or k not in nullable_keys) + ) + if bad_keys: + e = f"Invalid value(s) in transformer_options chroma_radiance_options: {', '.join(bad_keys)}" + raise ValueError(e) + # At this point it's all valid keys and values so we can merge with the existing params. 
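+        # For example, the ChromaRadianceOptions node can pass
+        #   transformer_options["chroma_radiance_options"] = {"nerf_tile_size": 0}
+        # to disable NeRF tiling for that forward pass; the override is applied
+        # to a copy, so self.params itself is never mutated.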
+ params_dict |= overrides + return params.__class__(**params_dict) + + def _forward( + self, + x: Tensor, + timestep: Tensor, + context: Tensor, + guidance: Optional[Tensor], + control: Optional[dict]=None, + transformer_options: dict={}, + **kwargs: dict, + ) -> Tensor: + bs, c, h, w = x.shape + img = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size)) + + if img.ndim != 4: + raise ValueError("Input img tensor must be in [B, C, H, W] format.") + if context.ndim != 3: + raise ValueError("Input txt tensors must have 3 dimensions.") + + params = self.radiance_get_override_params(transformer_options.get("chroma_radiance_options", {})) + + h_len = ((h + (self.patch_size // 2)) // self.patch_size) + w_len = ((w + (self.patch_size // 2)) // self.patch_size) + img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype) + img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1) + img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0) + img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs) + txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype) + + img_out = self.forward_orig( + img, + img_ids, + context, + txt_ids, + timestep, + guidance, + control, + transformer_options, + attn_mask=kwargs.get("attention_mask", None), + ) + return self.forward_nerf(img, img_out, params) diff --git a/comfy/model_base.py b/comfy/model_base.py index 324d89cff..252dfcf69 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -42,6 +42,7 @@ import comfy.ldm.wan.model import comfy.ldm.hunyuan3d.model import comfy.ldm.hidream.model import comfy.ldm.chroma.model +import comfy.ldm.chroma_radiance.model import comfy.ldm.ace.model import comfy.ldm.omnigen.omnigen2 import comfy.ldm.qwen_image.model @@ -1320,8 +1321,8 @@ class HiDream(BaseModel): return out class Chroma(Flux): - def __init__(self, model_config, model_type=ModelType.FLUX, device=None): - super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.chroma.model.Chroma) + def __init__(self, model_config, model_type=ModelType.FLUX, device=None, unet_model=comfy.ldm.chroma.model.Chroma): + super().__init__(model_config, model_type, device=device, unet_model=unet_model) def extra_conds(self, **kwargs): out = super().extra_conds(**kwargs) @@ -1331,6 +1332,10 @@ class Chroma(Flux): out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance])) return out +class ChromaRadiance(Chroma): + def __init__(self, model_config, model_type=ModelType.FLUX, device=None): + super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.chroma_radiance.model.ChromaRadiance) + class ACEStep(BaseModel): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.ace.model.ACEStepTransformer2DModel) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index fe983cede..03d44f65e 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -174,7 +174,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["guidance_embed"] = len(guidance_keys) > 0 return dit_config - if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys and '{}img_in.weight'.format(key_prefix) in state_dict_keys: #Flux + if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in 
state_dict_keys and ('{}img_in.weight'.format(key_prefix) in state_dict_keys or f"{key_prefix}distilled_guidance_layer.norms.0.scale" in state_dict_keys): #Flux, Chroma or Chroma Radiance (has no img_in.weight) dit_config = {} dit_config["image_model"] = "flux" dit_config["in_channels"] = 16 @@ -204,6 +204,18 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["out_dim"] = 3072 dit_config["hidden_dim"] = 5120 dit_config["n_layers"] = 5 + if f"{key_prefix}nerf_blocks.0.norm.scale" in state_dict_keys: #Chroma Radiance + dit_config["image_model"] = "chroma_radiance" + dit_config["in_channels"] = 3 + dit_config["out_channels"] = 3 + dit_config["patch_size"] = 16 + dit_config["nerf_hidden_size"] = 64 + dit_config["nerf_mlp_ratio"] = 4 + dit_config["nerf_depth"] = 4 + dit_config["nerf_max_freqs"] = 8 + dit_config["nerf_tile_size"] = 32 + dit_config["nerf_final_head_type"] = "conv" if f"{key_prefix}nerf_final_layer_conv.norm.scale" in state_dict_keys else "linear" + dit_config["nerf_embedder_dtype"] = torch.float32 else: dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys return dit_config diff --git a/comfy/sd.py b/comfy/sd.py index f8f1a89e8..cb92802e9 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -785,6 +785,66 @@ class VAE: except: return None +# "Fake" VAE that converts from IMAGE B, H, W, C and values on the scale of 0..1 +# to LATENT B, C, H, W and values on the scale of -1..1. +class PixelspaceConversionVAE: + def __init__(self, size_increment: int=16): + self.intermediate_device = comfy.model_management.intermediate_device() + self.size_increment = size_increment + + def vae_encode_crop_pixels(self, pixels: torch.Tensor) -> torch.Tensor: + if self.size_increment == 1: + return pixels + dims = pixels.shape[1:-1] + for d in range(len(dims)): + d_adj = (dims[d] // self.size_increment) * self.size_increment + if d_adj == d: + continue + d_offset = (dims[d] % self.size_increment) // 2 + pixels = pixels.narrow(d + 1, d_offset, d_adj) + return pixels + + def encode(self, pixels: torch.Tensor, *_args, **_kwargs) -> torch.Tensor: + if pixels.ndim == 3: + pixels = pixels.unsqueeze(0) + elif pixels.ndim != 4: + raise ValueError("Unexpected input image shape") + # Ensure the image has spatial dimensions that are multiples of 16. + pixels = self.vae_encode_crop_pixels(pixels) + h, w, c = pixels.shape[1:] + if h < self.size_increment or w < self.size_increment: + raise ValueError(f"Image inputs must have height/width of at least {self.size_increment} pixel(s).") + pixels= pixels[..., :3] + if c == 1: + pixels = pixels.expand(-1, -1, -1, 3) + elif c != 3: + raise ValueError("Unexpected number of channels in input image") + # Rescale to -1..1 and move the channel dimension to position 1. + latent = pixels.to(device=self.intermediate_device, dtype=torch.float32, copy=True) + latent = latent.clamp_(0, 1).movedim(-1, 1).contiguous() + latent -= 0.5 + latent *= 2 + return latent.clamp_(-1, 1) + + def decode(self, samples: torch.Tensor, *_args, **_kwargs) -> torch.Tensor: + # Rescale to 0..1 and move the channel dimension to the end. + img = samples.to(device=self.intermediate_device, dtype=torch.float32, copy=True) + img = img.clamp_(-1, 1).movedim(1, -1).contiguous() + img += 1.0 + img *= 0.5 + return img.clamp_(0, 1) + + encode_tiled = encode + decode_tiled = decode + + @classmethod + def spacial_compression_decode(cls) -> int: + # This just exists so the tiled VAE nodes don't crash. 
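+        # Pixel-space "latents" are full-resolution images, so there is no
+        # spatial compression to report.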
+ return 1 + + spacial_compression_encode = spacial_compression_decode + temporal_compression_decode = spacial_compression_decode + class StyleModel: def __init__(self, model, device="cpu"): self.model = model diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 472ea0ae9..be36b5dfe 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -1205,6 +1205,19 @@ class Chroma(supported_models_base.BASE): t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref)) return supported_models_base.ClipTarget(comfy.text_encoders.pixart_t5.PixArtTokenizer, comfy.text_encoders.pixart_t5.pixart_te(**t5_detect)) +class ChromaRadiance(Chroma): + unet_config = { + "image_model": "chroma_radiance", + } + + latent_format = comfy.latent_formats.ChromaRadiance + + # Pixel-space model, no spatial compression for model input. + memory_usage_factor = 0.0325 + + def get_model(self, state_dict, prefix="", device=None): + return model_base.ChromaRadiance(self, device=device) + class ACEStep(supported_models_base.BASE): unet_config = { "audio_model": "ace", @@ -1338,6 +1351,6 @@ class HunyuanImage21Refiner(HunyuanVideo): out = model_base.HunyuanImage21Refiner(self, device=device) return out -models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ACEStep, Omnigen2, QwenImage] +models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage] models += [SVD_img2vid] diff --git a/comfy_extras/nodes_chroma_radiance.py b/comfy_extras/nodes_chroma_radiance.py new file mode 100644 index 000000000..381989818 --- /dev/null +++ b/comfy_extras/nodes_chroma_radiance.py @@ -0,0 +1,114 @@ +from typing_extensions import override +from typing import Callable + +import torch + +import comfy.model_management +from comfy_api.latest import ComfyExtension, io + +import nodes + +class EmptyChromaRadianceLatentImage(io.ComfyNode): + @classmethod + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="EmptyChromaRadianceLatentImage", + category="latent/chroma_radiance", + inputs=[ + io.Int.Input(id="width", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input(id="height", default=1024, min=16, max=nodes.MAX_RESOLUTION, 
step=16), + io.Int.Input(id="batch_size", default=1, min=1, max=4096), + ], + outputs=[io.Latent().Output()], + ) + + @classmethod + def execute(cls, *, width: int, height: int, batch_size: int=1) -> io.NodeOutput: + latent = torch.zeros((batch_size, 3, height, width), device=comfy.model_management.intermediate_device()) + return io.NodeOutput({"samples":latent}) + + +class ChromaRadianceOptions(io.ComfyNode): + @classmethod + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="ChromaRadianceOptions", + category="model_patches/chroma_radiance", + description="Allows setting advanced options for the Chroma Radiance model.", + inputs=[ + io.Model.Input(id="model"), + io.Boolean.Input( + id="preserve_wrapper", + default=True, + tooltip="When enabled, will delegate to an existing model function wrapper if it exists. Generally should be left enabled.", + ), + io.Float.Input( + id="start_sigma", + default=1.0, + min=0.0, + max=1.0, + tooltip="First sigma that these options will be in effect.", + ), + io.Float.Input( + id="end_sigma", + default=0.0, + min=0.0, + max=1.0, + tooltip="Last sigma that these options will be in effect.", + ), + io.Int.Input( + id="nerf_tile_size", + default=-1, + min=-1, + tooltip="Allows overriding the default NeRF tile size. -1 means use the default (32). 0 means use non-tiling mode (may require a lot of VRAM).", + ), + ], + outputs=[io.Model.Output()], + ) + + @classmethod + def execute( + cls, + *, + model: io.Model.Type, + preserve_wrapper: bool, + start_sigma: float, + end_sigma: float, + nerf_tile_size: int, + ) -> io.NodeOutput: + radiance_options = {} + if nerf_tile_size >= 0: + radiance_options["nerf_tile_size"] = nerf_tile_size + + if not radiance_options: + return io.NodeOutput(model) + + old_wrapper = model.model_options.get("model_function_wrapper") + + def model_function_wrapper(apply_model: Callable, args: dict) -> torch.Tensor: + c = args["c"].copy() + sigma = args["timestep"].max().detach().cpu().item() + if end_sigma <= sigma <= start_sigma: + transformer_options = c.get("transformer_options", {}).copy() + transformer_options["chroma_radiance_options"] = radiance_options.copy() + c["transformer_options"] = transformer_options + if not (preserve_wrapper and old_wrapper): + return apply_model(args["input"], args["timestep"], **c) + return old_wrapper(apply_model, args | {"c": c}) + + model = model.clone() + model.set_model_unet_function_wrapper(model_function_wrapper) + return io.NodeOutput(model) + + +class ChromaRadianceExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + EmptyChromaRadianceLatentImage, + ChromaRadianceOptions, + ] + + +async def comfy_entrypoint() -> ChromaRadianceExtension: + return ChromaRadianceExtension() diff --git a/nodes.py b/nodes.py index 2befb4b75..76b8cbac8 100644 --- a/nodes.py +++ b/nodes.py @@ -730,6 +730,7 @@ class VAELoader: vaes.append("taesd3") if f1_taesd_dec and f1_taesd_enc: vaes.append("taef1") + vaes.append("chroma_radiance") return vaes @staticmethod @@ -772,7 +773,9 @@ class VAELoader: #TODO: scale factor? 
def load_vae(self, vae_name): - if vae_name in ["taesd", "taesdxl", "taesd3", "taef1"]: + if vae_name == "chroma_radiance": + return (comfy.sd.PixelspaceConversionVAE(),) + elif vae_name in ["taesd", "taesdxl", "taesd3", "taef1"]: sd = self.load_taesd(vae_name) else: vae_path = folder_paths.get_full_path_or_raise("vae", vae_name) @@ -2322,6 +2325,7 @@ async def init_builtin_extra_nodes(): "nodes_tcfg.py", "nodes_context_windows.py", "nodes_qwen.py", + "nodes_chroma_radiance.py", "nodes_model_patch.py", "nodes_easycache.py", "nodes_audio_encoder.py", From 80b7c9455bf7afba7a9e95a1eb76b172408ab56c Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sat, 13 Sep 2025 15:03:34 -0700 Subject: [PATCH 209/325] Changes to the previous radiance commit. (#9851) --- comfy/ldm/chroma_radiance/model.py | 7 +-- comfy/pixel_space_convert.py | 16 +++++++ comfy/sd.py | 69 +++++------------------------- comfy/supported_models.py | 2 +- nodes.py | 7 +-- 5 files changed, 35 insertions(+), 66 deletions(-) create mode 100644 comfy/pixel_space_convert.py diff --git a/comfy/ldm/chroma_radiance/model.py b/comfy/ldm/chroma_radiance/model.py index f7eb7a22e..47aa11b04 100644 --- a/comfy/ldm/chroma_radiance/model.py +++ b/comfy/ldm/chroma_radiance/model.py @@ -306,8 +306,9 @@ class ChromaRadiance(Chroma): params = self.radiance_get_override_params(transformer_options.get("chroma_radiance_options", {})) - h_len = ((h + (self.patch_size // 2)) // self.patch_size) - w_len = ((w + (self.patch_size // 2)) // self.patch_size) + h_len = (img.shape[-2] // self.patch_size) + w_len = (img.shape[-1] // self.patch_size) + img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype) img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1) img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0) @@ -325,4 +326,4 @@ class ChromaRadiance(Chroma): transformer_options, attn_mask=kwargs.get("attention_mask", None), ) - return self.forward_nerf(img, img_out, params) + return self.forward_nerf(img, img_out, params)[:, :, :h, :w] diff --git a/comfy/pixel_space_convert.py b/comfy/pixel_space_convert.py new file mode 100644 index 000000000..049bbcfb4 --- /dev/null +++ b/comfy/pixel_space_convert.py @@ -0,0 +1,16 @@ +import torch + + +# "Fake" VAE that converts from IMAGE B, H, W, C and values on the scale of 0..1 +# to LATENT B, C, H, W and values on the scale of -1..1. 
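+#
+# Note: in this revision the conversion itself is an identity pass-through; the
+# 0..1 <-> -1..1 rescaling and the BHWC <-> BCHW permutation are assumed to be
+# handled by the comfy.sd.VAE wrapper and the ChromaRadiance latent format that
+# load this module. The dummy pixel_space_vae parameter exists only so the
+# loader can detect this fake VAE from the "pixel_space_vae" state dict key.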
+class PixelspaceConversionVAE(torch.nn.Module): + def __init__(self): + super().__init__() + self.pixel_space_vae = torch.nn.Parameter(torch.tensor(1.0)) + + def encode(self, pixels: torch.Tensor, *_args, **_kwargs) -> torch.Tensor: + return pixels + + def decode(self, samples: torch.Tensor, *_args, **_kwargs) -> torch.Tensor: + return samples + diff --git a/comfy/sd.py b/comfy/sd.py index cb92802e9..2df340739 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -18,6 +18,7 @@ import comfy.ldm.wan.vae2_2 import comfy.ldm.hunyuan3d.vae import comfy.ldm.ace.vae.music_dcae_pipeline import comfy.ldm.hunyuan_video.vae +import comfy.pixel_space_convert import yaml import math import os @@ -516,6 +517,15 @@ class VAE: self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32] self.disable_offload = True self.extra_1d_channel = 16 + elif "pixel_space_vae" in sd: + self.first_stage_model = comfy.pixel_space_convert.PixelspaceConversionVAE() + self.memory_used_encode = lambda shape, dtype: (1 * shape[2] * shape[3]) * model_management.dtype_size(dtype) + self.memory_used_decode = lambda shape, dtype: (1 * shape[2] * shape[3]) * model_management.dtype_size(dtype) + self.downscale_ratio = 1 + self.upscale_ratio = 1 + self.latent_channels = 3 + self.latent_dim = 2 + self.output_channels = 3 else: logging.warning("WARNING: No VAE weights detected, VAE not initalized.") self.first_stage_model = None @@ -785,65 +795,6 @@ class VAE: except: return None -# "Fake" VAE that converts from IMAGE B, H, W, C and values on the scale of 0..1 -# to LATENT B, C, H, W and values on the scale of -1..1. -class PixelspaceConversionVAE: - def __init__(self, size_increment: int=16): - self.intermediate_device = comfy.model_management.intermediate_device() - self.size_increment = size_increment - - def vae_encode_crop_pixels(self, pixels: torch.Tensor) -> torch.Tensor: - if self.size_increment == 1: - return pixels - dims = pixels.shape[1:-1] - for d in range(len(dims)): - d_adj = (dims[d] // self.size_increment) * self.size_increment - if d_adj == d: - continue - d_offset = (dims[d] % self.size_increment) // 2 - pixels = pixels.narrow(d + 1, d_offset, d_adj) - return pixels - - def encode(self, pixels: torch.Tensor, *_args, **_kwargs) -> torch.Tensor: - if pixels.ndim == 3: - pixels = pixels.unsqueeze(0) - elif pixels.ndim != 4: - raise ValueError("Unexpected input image shape") - # Ensure the image has spatial dimensions that are multiples of 16. - pixels = self.vae_encode_crop_pixels(pixels) - h, w, c = pixels.shape[1:] - if h < self.size_increment or w < self.size_increment: - raise ValueError(f"Image inputs must have height/width of at least {self.size_increment} pixel(s).") - pixels= pixels[..., :3] - if c == 1: - pixels = pixels.expand(-1, -1, -1, 3) - elif c != 3: - raise ValueError("Unexpected number of channels in input image") - # Rescale to -1..1 and move the channel dimension to position 1. - latent = pixels.to(device=self.intermediate_device, dtype=torch.float32, copy=True) - latent = latent.clamp_(0, 1).movedim(-1, 1).contiguous() - latent -= 0.5 - latent *= 2 - return latent.clamp_(-1, 1) - - def decode(self, samples: torch.Tensor, *_args, **_kwargs) -> torch.Tensor: - # Rescale to 0..1 and move the channel dimension to the end. 
- img = samples.to(device=self.intermediate_device, dtype=torch.float32, copy=True) - img = img.clamp_(-1, 1).movedim(1, -1).contiguous() - img += 1.0 - img *= 0.5 - return img.clamp_(0, 1) - - encode_tiled = encode - decode_tiled = decode - - @classmethod - def spacial_compression_decode(cls) -> int: - # This just exists so the tiled VAE nodes don't crash. - return 1 - - spacial_compression_encode = spacial_compression_decode - temporal_compression_decode = spacial_compression_decode class StyleModel: def __init__(self, model, device="cpu"): diff --git a/comfy/supported_models.py b/comfy/supported_models.py index be36b5dfe..557902d11 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -1213,7 +1213,7 @@ class ChromaRadiance(Chroma): latent_format = comfy.latent_formats.ChromaRadiance # Pixel-space model, no spatial compression for model input. - memory_usage_factor = 0.0325 + memory_usage_factor = 0.038 def get_model(self, state_dict, prefix="", device=None): return model_base.ChromaRadiance(self, device=device) diff --git a/nodes.py b/nodes.py index 76b8cbac8..5a5fdcb8e 100644 --- a/nodes.py +++ b/nodes.py @@ -730,7 +730,7 @@ class VAELoader: vaes.append("taesd3") if f1_taesd_dec and f1_taesd_enc: vaes.append("taef1") - vaes.append("chroma_radiance") + vaes.append("pixel_space") return vaes @staticmethod @@ -773,8 +773,9 @@ class VAELoader: #TODO: scale factor? def load_vae(self, vae_name): - if vae_name == "chroma_radiance": - return (comfy.sd.PixelspaceConversionVAE(),) + if vae_name == "pixel_space": + sd = {} + sd["pixel_space_vae"] = torch.tensor(1.0) elif vae_name in ["taesd", "taesdxl", "taesd3", "taef1"]: sd = self.load_taesd(vae_name) else: From f228367c5e3906de194968fa9b6fbe7aa9987bfa Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Sat, 13 Sep 2025 18:34:21 -0700 Subject: [PATCH 210/325] Make ModuleNotFoundError ImportError instead (#9850) --- comfy/ldm/modules/attention.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index bf2553c37..9dd1a43c1 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -22,7 +22,7 @@ SAGE_ATTENTION_IS_AVAILABLE = False try: from sageattention import sageattn SAGE_ATTENTION_IS_AVAILABLE = True -except ModuleNotFoundError as e: +except ImportError as e: if model_management.sage_attention_enabled(): if e.name == "sageattention": logging.error(f"\n\nTo use the `--use-sage-attention` feature, the `sageattention` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install sageattention") @@ -34,7 +34,7 @@ FLASH_ATTENTION_IS_AVAILABLE = False try: from flash_attn import flash_attn_func FLASH_ATTENTION_IS_AVAILABLE = True -except ModuleNotFoundError: +except ImportError: if model_management.flash_attention_enabled(): logging.error(f"\n\nTo use the `--use-flash-attention` feature, the `flash-attn` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install flash-attn") exit(-1) From 4f1f26ac6c11b803bbc83cb347178e2f9b5e421b Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sun, 14 Sep 2025 01:05:38 -0700 Subject: [PATCH 211/325] Add that hunyuan image is supported to readme. 
(#9857) --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index 8024870c2..3f6cfc2ed 100644 --- a/README.md +++ b/README.md @@ -66,6 +66,7 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith - [Lumina Image 2.0](https://comfyanonymous.github.io/ComfyUI_examples/lumina2/) - [HiDream](https://comfyanonymous.github.io/ComfyUI_examples/hidream/) - [Qwen Image](https://comfyanonymous.github.io/ComfyUI_examples/qwen_image/) + - [Hunyuan Image 2.1](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_image/) - Image Editing Models - [Omnigen 2](https://comfyanonymous.github.io/ComfyUI_examples/omnigen/) - [Flux Kontext](https://comfyanonymous.github.io/ComfyUI_examples/flux/#flux-kontext-image-editing-model) From 47a9cde5d3045c42f20baafb9855fb96959124f0 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 15 Sep 2025 15:10:55 -0700 Subject: [PATCH 212/325] Support the omnigen2 umo lora. (#9886) --- comfy/lora.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/comfy/lora.py b/comfy/lora.py index 4a44f1318..36d26293a 100644 --- a/comfy/lora.py +++ b/comfy/lora.py @@ -297,6 +297,12 @@ def model_lora_keys_unet(model, key_map={}): key_lora = k[len("diffusion_model."):-len(".weight")] key_map["{}".format(key_lora)] = k + if isinstance(model, comfy.model_base.Omnigen2): + for k in sdk: + if k.startswith("diffusion_model.") and k.endswith(".weight"): + key_lora = k[len("diffusion_model."):-len(".weight")] + key_map["{}".format(key_lora)] = k + if isinstance(model, comfy.model_base.QwenImage): for k in sdk: if k.startswith("diffusion_model.") and k.endswith(".weight"): #QwenImage lora format From 1a85483da159f2800407ae5a8a45eb0d88ffce2d Mon Sep 17 00:00:00 2001 From: blepping <157360029+blepping@users.noreply.github.com> Date: Mon, 15 Sep 2025 18:05:03 -0600 Subject: [PATCH 213/325] Fix depending on asserts to raise an exception in BatchedBrownianTree and Flash attn module (#9884) Correctly handle the case where w0 is passed by kwargs in BatchedBrownianTree --- comfy/k_diffusion/sampling.py | 35 +++++++++++++++++----------------- comfy/ldm/modules/attention.py | 3 ++- 2 files changed, 19 insertions(+), 19 deletions(-) diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py index 2d7e09838..0e2cda291 100644 --- a/comfy/k_diffusion/sampling.py +++ b/comfy/k_diffusion/sampling.py @@ -86,24 +86,24 @@ class BatchedBrownianTree: """A wrapper around torchsde.BrownianTree that enables batches of entropy.""" def __init__(self, x, t0, t1, seed=None, **kwargs): - self.cpu_tree = True - if "cpu" in kwargs: - self.cpu_tree = kwargs.pop("cpu") + self.cpu_tree = kwargs.pop("cpu", True) t0, t1, self.sign = self.sort(t0, t1) - w0 = kwargs.get('w0', torch.zeros_like(x)) + w0 = kwargs.pop('w0', None) + if w0 is None: + w0 = torch.zeros_like(x) + self.batched = False if seed is None: - seed = torch.randint(0, 2 ** 63 - 1, []).item() - self.batched = True - try: - assert len(seed) == x.shape[0] + seed = (torch.randint(0, 2 ** 63 - 1, ()).item(),) + elif isinstance(seed, (tuple, list)): + if len(seed) != x.shape[0]: + raise ValueError("Passing a list or tuple of seeds to BatchedBrownianTree requires a length matching the batch size.") + self.batched = True w0 = w0[0] - except TypeError: - seed = [seed] - self.batched = False - if self.cpu_tree: - self.trees = [torchsde.BrownianTree(t0.cpu(), w0.cpu(), t1.cpu(), entropy=s, **kwargs) for s in seed] else: - self.trees = 
[torchsde.BrownianTree(t0, w0, t1, entropy=s, **kwargs) for s in seed] + seed = (seed,) + if self.cpu_tree: + t0, w0, t1 = t0.detach().cpu(), w0.detach().cpu(), t1.detach().cpu() + self.trees = tuple(torchsde.BrownianTree(t0, w0, t1, entropy=s, **kwargs) for s in seed) @staticmethod def sort(a, b): @@ -111,11 +111,10 @@ class BatchedBrownianTree: def __call__(self, t0, t1): t0, t1, sign = self.sort(t0, t1) + device, dtype = t0.device, t0.dtype if self.cpu_tree: - w = torch.stack([tree(t0.cpu().float(), t1.cpu().float()).to(t0.dtype).to(t0.device) for tree in self.trees]) * (self.sign * sign) - else: - w = torch.stack([tree(t0, t1) for tree in self.trees]) * (self.sign * sign) - + t0, t1 = t0.detach().cpu().float(), t1.detach().cpu().float() + w = torch.stack([tree(t0, t1) for tree in self.trees]).to(device=device, dtype=dtype) * (self.sign * sign) return w if self.batched else w[0] diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index 9dd1a43c1..7437e0567 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -600,7 +600,8 @@ def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape mask = mask.unsqueeze(1) try: - assert mask is None + if mask is not None: + raise RuntimeError("Mask must not be set for Flash attention") out = flash_attn_wrapper( q.transpose(1, 2), k.transpose(1, 2), From a39ac59c3e3fddc8b278899814f0bd5371abb11f Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 15 Sep 2025 22:19:50 -0700 Subject: [PATCH 214/325] Add encoder part of whisper large v3 as an audio encoder model. (#9894) Not useful yet but some models use it. --- comfy/audio_encoders/audio_encoders.py | 58 +++++--- comfy/audio_encoders/whisper.py | 186 +++++++++++++++++++++++++ 2 files changed, 224 insertions(+), 20 deletions(-) create mode 100755 comfy/audio_encoders/whisper.py diff --git a/comfy/audio_encoders/audio_encoders.py b/comfy/audio_encoders/audio_encoders.py index 6fb5b08e9..0550b2f9b 100644 --- a/comfy/audio_encoders/audio_encoders.py +++ b/comfy/audio_encoders/audio_encoders.py @@ -1,4 +1,5 @@ from .wav2vec2 import Wav2Vec2Model +from .whisper import WhisperLargeV3 import comfy.model_management import comfy.ops import comfy.utils @@ -11,13 +12,18 @@ class AudioEncoderModel(): self.load_device = comfy.model_management.text_encoder_device() offload_device = comfy.model_management.text_encoder_offload_device() self.dtype = comfy.model_management.text_encoder_dtype(self.load_device) + model_type = config.pop("model_type") model_config = dict(config) model_config.update({ "dtype": self.dtype, "device": offload_device, "operations": comfy.ops.manual_cast }) - self.model = Wav2Vec2Model(**model_config) + + if model_type == "wav2vec2": + self.model = Wav2Vec2Model(**model_config) + elif model_type == "whisper3": + self.model = WhisperLargeV3(**model_config) self.model.eval() self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device) self.model_sample_rate = 16000 @@ -40,33 +46,45 @@ class AudioEncoderModel(): def load_audio_encoder_from_sd(sd, prefix=""): sd = comfy.utils.state_dict_prefix_replace(sd, {"wav2vec2.": ""}) - embed_dim = sd["encoder.layer_norm.bias"].shape[0] - if embed_dim == 1024:# large - config = { - "embed_dim": 1024, - "num_heads": 16, - "num_layers": 24, - "conv_norm": True, - "conv_bias": True, - "do_normalize": True, - "do_stable_layer_norm": True + if "encoder.layer_norm.bias" in sd: 
#wav2vec2 + embed_dim = sd["encoder.layer_norm.bias"].shape[0] + if embed_dim == 1024:# large + config = { + "model_type": "wav2vec2", + "embed_dim": 1024, + "num_heads": 16, + "num_layers": 24, + "conv_norm": True, + "conv_bias": True, + "do_normalize": True, + "do_stable_layer_norm": True + } + elif embed_dim == 768: # base + config = { + "model_type": "wav2vec2", + "embed_dim": 768, + "num_heads": 12, + "num_layers": 12, + "conv_norm": False, + "conv_bias": False, + "do_normalize": False, # chinese-wav2vec2-base has this False + "do_stable_layer_norm": False } - elif embed_dim == 768: # base + else: + raise RuntimeError("ERROR: audio encoder file is invalid or unsupported embed_dim: {}".format(embed_dim)) + elif "model.encoder.embed_positions.weight" in sd: + sd = comfy.utils.state_dict_prefix_replace(sd, {"model.": ""}) config = { - "embed_dim": 768, - "num_heads": 12, - "num_layers": 12, - "conv_norm": False, - "conv_bias": False, - "do_normalize": False, # chinese-wav2vec2-base has this False - "do_stable_layer_norm": False + "model_type": "whisper3", } else: - raise RuntimeError("ERROR: audio encoder file is invalid or unsupported embed_dim: {}".format(embed_dim)) + raise RuntimeError("ERROR: audio encoder not supported.") audio_encoder = AudioEncoderModel(config) m, u = audio_encoder.load_sd(sd) if len(m) > 0: logging.warning("missing audio encoder: {}".format(m)) + if len(u) > 0: + logging.warning("unexpected audio encoder: {}".format(u)) return audio_encoder diff --git a/comfy/audio_encoders/whisper.py b/comfy/audio_encoders/whisper.py new file mode 100755 index 000000000..93d3782f1 --- /dev/null +++ b/comfy/audio_encoders/whisper.py @@ -0,0 +1,186 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchaudio +from typing import Optional +from comfy.ldm.modules.attention import optimized_attention_masked +import comfy.ops + +class WhisperFeatureExtractor(nn.Module): + def __init__(self, n_mels=128, device=None): + super().__init__() + self.sample_rate = 16000 + self.n_fft = 400 + self.hop_length = 160 + self.n_mels = n_mels + self.chunk_length = 30 + self.n_samples = 480000 + + self.mel_spectrogram = torchaudio.transforms.MelSpectrogram( + sample_rate=self.sample_rate, + n_fft=self.n_fft, + hop_length=self.hop_length, + n_mels=self.n_mels, + f_min=0, + f_max=8000, + norm="slaney", + mel_scale="slaney", + ).to(device) + + def __call__(self, audio): + audio = torch.mean(audio, dim=1) + batch_size = audio.shape[0] + processed_audio = [] + + for i in range(batch_size): + aud = audio[i] + if aud.shape[0] > self.n_samples: + aud = aud[:self.n_samples] + elif aud.shape[0] < self.n_samples: + aud = F.pad(aud, (0, self.n_samples - aud.shape[0])) + processed_audio.append(aud) + + audio = torch.stack(processed_audio) + + mel_spec = self.mel_spectrogram(audio.to(self.mel_spectrogram.spectrogram.window.device))[:, :, :-1].to(audio.device) + + log_mel_spec = torch.clamp(mel_spec, min=1e-10).log10() + log_mel_spec = torch.maximum(log_mel_spec, log_mel_spec.max() - 8.0) + log_mel_spec = (log_mel_spec + 4.0) / 4.0 + + return log_mel_spec + + +class MultiHeadAttention(nn.Module): + def __init__(self, d_model: int, n_heads: int, dtype=None, device=None, operations=None): + super().__init__() + assert d_model % n_heads == 0 + + self.d_model = d_model + self.n_heads = n_heads + self.d_k = d_model // n_heads + + self.q_proj = operations.Linear(d_model, d_model, dtype=dtype, device=device) + self.k_proj = operations.Linear(d_model, d_model, bias=False, dtype=dtype, 
device=device) + self.v_proj = operations.Linear(d_model, d_model, dtype=dtype, device=device) + self.out_proj = operations.Linear(d_model, d_model, dtype=dtype, device=device) + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + batch_size, seq_len, _ = query.shape + + q = self.q_proj(query) + k = self.k_proj(key) + v = self.v_proj(value) + + attn_output = optimized_attention_masked(q, k, v, self.n_heads, mask) + attn_output = self.out_proj(attn_output) + + return attn_output + + +class EncoderLayer(nn.Module): + def __init__(self, d_model: int, n_heads: int, d_ff: int, dtype=None, device=None, operations=None): + super().__init__() + + self.self_attn = MultiHeadAttention(d_model, n_heads, dtype=dtype, device=device, operations=operations) + self.self_attn_layer_norm = operations.LayerNorm(d_model, dtype=dtype, device=device) + + self.fc1 = operations.Linear(d_model, d_ff, dtype=dtype, device=device) + self.fc2 = operations.Linear(d_ff, d_model, dtype=dtype, device=device) + self.final_layer_norm = operations.LayerNorm(d_model, dtype=dtype, device=device) + + def forward( + self, + x: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None + ) -> torch.Tensor: + residual = x + x = self.self_attn_layer_norm(x) + x = self.self_attn(x, x, x, attention_mask) + x = residual + x + + residual = x + x = self.final_layer_norm(x) + x = self.fc1(x) + x = F.gelu(x) + x = self.fc2(x) + x = residual + x + + return x + + +class AudioEncoder(nn.Module): + def __init__( + self, + n_mels: int = 128, + n_ctx: int = 1500, + n_state: int = 1280, + n_head: int = 20, + n_layer: int = 32, + dtype=None, + device=None, + operations=None + ): + super().__init__() + + self.conv1 = operations.Conv1d(n_mels, n_state, kernel_size=3, padding=1, dtype=dtype, device=device) + self.conv2 = operations.Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1, dtype=dtype, device=device) + + self.embed_positions = operations.Embedding(n_ctx, n_state, dtype=dtype, device=device) + + self.layers = nn.ModuleList([ + EncoderLayer(n_state, n_head, n_state * 4, dtype=dtype, device=device, operations=operations) + for _ in range(n_layer) + ]) + + self.layer_norm = operations.LayerNorm(n_state, dtype=dtype, device=device) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = F.gelu(self.conv1(x)) + x = F.gelu(self.conv2(x)) + + x = x.transpose(1, 2) + + x = x + comfy.ops.cast_to_input(self.embed_positions.weight[:, :x.shape[1]], x) + + all_x = () + for layer in self.layers: + all_x += (x,) + x = layer(x) + + x = self.layer_norm(x) + all_x += (x,) + return x, all_x + + +class WhisperLargeV3(nn.Module): + def __init__( + self, + n_mels: int = 128, + n_audio_ctx: int = 1500, + n_audio_state: int = 1280, + n_audio_head: int = 20, + n_audio_layer: int = 32, + dtype=None, + device=None, + operations=None + ): + super().__init__() + + self.feature_extractor = WhisperFeatureExtractor(n_mels=n_mels, device=device) + + self.encoder = AudioEncoder( + n_mels, n_audio_ctx, n_audio_state, n_audio_head, n_audio_layer, + dtype=dtype, device=device, operations=operations + ) + + def forward(self, audio): + mel = self.feature_extractor(audio) + x, all_x = self.encoder(mel) + return x, all_x From e42682b24ef033a93001ba27cc5c5aa461a61d8d Mon Sep 17 00:00:00 2001 From: rattus128 <46076784+rattus128@users.noreply.github.com> Date: Wed, 17 Sep 2025 09:21:14 +1000 Subject: [PATCH 215/325] Reduce Peak WAN inference VRAM usage (#9898) * flux: 
Do the xq and xk ropes one at a time This was doing independendent interleaved tensor math on the q and k tensors, leading to the holding of more than the minimum intermediates in VRAM. On a bad day, it would VRAM OOM on xk intermediates. Do everything q and then everything k, so torch can garbage collect all of qs intermediates before k allocates its intermediates. This reduces peak VRAM usage for some WAN2.2 inferences (at least). * wan: Optimize qkv intermediates on attention As commented. The former logic computed independent pieces of QKV in parallel which help more inference intermediates in VRAM spiking VRAM usage. Fully roping Q and garbage collecting the intermediates before touching K reduces the peak inference VRAM usage. --- comfy/ldm/flux/math.py | 11 +++++------ comfy/ldm/wan/model.py | 22 +++++++++++++--------- 2 files changed, 18 insertions(+), 15 deletions(-) diff --git a/comfy/ldm/flux/math.py b/comfy/ldm/flux/math.py index 4d743cda2..fb7cd7586 100644 --- a/comfy/ldm/flux/math.py +++ b/comfy/ldm/flux/math.py @@ -35,11 +35,10 @@ def rope(pos: Tensor, dim: int, theta: int) -> Tensor: out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2) return out.to(dtype=torch.float32, device=pos.device) +def apply_rope1(x: Tensor, freqs_cis: Tensor): + x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2) + x_out = freqs_cis[..., 0] * x_[..., 0] + freqs_cis[..., 1] * x_[..., 1] + return x_out.reshape(*x.shape).type_as(x) def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor): - xq_ = xq.to(dtype=freqs_cis.dtype).reshape(*xq.shape[:-1], -1, 1, 2) - xk_ = xk.to(dtype=freqs_cis.dtype).reshape(*xk.shape[:-1], -1, 1, 2) - xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1] - xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1] - return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk) - + return apply_rope1(xq, freqs_cis), apply_rope1(xk, freqs_cis) diff --git a/comfy/ldm/wan/model.py b/comfy/ldm/wan/model.py index 63472ada2..67dcf8f1e 100644 --- a/comfy/ldm/wan/model.py +++ b/comfy/ldm/wan/model.py @@ -8,7 +8,7 @@ from einops import rearrange from comfy.ldm.modules.attention import optimized_attention from comfy.ldm.flux.layers import EmbedND -from comfy.ldm.flux.math import apply_rope +from comfy.ldm.flux.math import apply_rope1 import comfy.ldm.common_dit import comfy.model_management import comfy.patcher_extension @@ -60,20 +60,24 @@ class WanSelfAttention(nn.Module): """ b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim - # query, key, value function - def qkv_fn(x): + def qkv_fn_q(x): q = self.norm_q(self.q(x)).view(b, s, n, d) - k = self.norm_k(self.k(x)).view(b, s, n, d) - v = self.v(x).view(b, s, n * d) - return q, k, v + return apply_rope1(q, freqs) - q, k, v = qkv_fn(x) - q, k = apply_rope(q, k, freqs) + def qkv_fn_k(x): + k = self.norm_k(self.k(x)).view(b, s, n, d) + return apply_rope1(k, freqs) + + #These two are VRAM hogs, so we want to do all of q computation and + #have pytorch garbage collect the intermediates on the sub function + #return before we touch k + q = qkv_fn_q(x) + k = qkv_fn_k(x) x = optimized_attention( q.view(b, s, n * d), k.view(b, s, n * d), - v, + self.v(x).view(b, s, n * d), heads=self.num_heads, transformer_options=transformer_options, ) From 9288c78fc5fae74d3fa7787736dea442e996303f Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 16 Sep 2025 21:12:48 -0700 Subject: [PATCH 216/325] Support the HuMo 
model. (#9903) --- comfy/audio_encoders/audio_encoders.py | 1 + comfy/ldm/wan/model.py | 259 ++++++++++++++++++++++++- comfy/model_base.py | 17 ++ comfy/model_detection.py | 2 + comfy/supported_models.py | 12 +- comfy_extras/nodes_wan.py | 98 ++++++++++ 6 files changed, 383 insertions(+), 6 deletions(-) diff --git a/comfy/audio_encoders/audio_encoders.py b/comfy/audio_encoders/audio_encoders.py index 0550b2f9b..46ef21c95 100644 --- a/comfy/audio_encoders/audio_encoders.py +++ b/comfy/audio_encoders/audio_encoders.py @@ -41,6 +41,7 @@ class AudioEncoderModel(): outputs = {} outputs["encoded_audio"] = out outputs["encoded_audio_all_layers"] = all_layers + outputs["audio_samples"] = audio.shape[2] return outputs diff --git a/comfy/ldm/wan/model.py b/comfy/ldm/wan/model.py index 67dcf8f1e..b3b7da5d5 100644 --- a/comfy/ldm/wan/model.py +++ b/comfy/ldm/wan/model.py @@ -34,7 +34,9 @@ class WanSelfAttention(nn.Module): num_heads, window_size=(-1, -1), qk_norm=True, - eps=1e-6, operation_settings={}): + eps=1e-6, + kv_dim=None, + operation_settings={}): assert dim % num_heads == 0 super().__init__() self.dim = dim @@ -43,11 +45,13 @@ class WanSelfAttention(nn.Module): self.window_size = window_size self.qk_norm = qk_norm self.eps = eps + if kv_dim is None: + kv_dim = dim # layers self.q = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) - self.k = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) - self.v = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + self.k = operation_settings.get("operations").Linear(kv_dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + self.v = operation_settings.get("operations").Linear(kv_dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) self.o = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) self.norm_q = operation_settings.get("operations").RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity() self.norm_k = operation_settings.get("operations").RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity() @@ -402,6 +406,7 @@ class WanModel(torch.nn.Module): eps=1e-6, flf_pos_embed_token_number=None, in_dim_ref_conv=None, + wan_attn_block_class=WanAttentionBlock, image_model=None, device=None, dtype=None, @@ -479,8 +484,8 @@ class WanModel(torch.nn.Module): # blocks cross_attn_type = 't2v_cross_attn' if model_type == 't2v' else 'i2v_cross_attn' self.blocks = nn.ModuleList([ - WanAttentionBlock(cross_attn_type, dim, ffn_dim, num_heads, - window_size, qk_norm, cross_attn_norm, eps, operation_settings=operation_settings) + wan_attn_block_class(cross_attn_type, dim, ffn_dim, num_heads, + window_size, qk_norm, cross_attn_norm, eps, operation_settings=operation_settings) for _ in range(num_layers) ]) @@ -1325,3 +1330,247 @@ class WanModel_S2V(WanModel): # unpatchify x = self.unpatchify(x, grid_sizes) return x + + +class WanT2VCrossAttentionGather(WanSelfAttention): + + def forward(self, x, context, transformer_options={}, **kwargs): + r""" + Args: + x(Tensor): Shape 
[B, L1, C] - video tokens + context(Tensor): Shape [B, L2, C] - audio tokens with shape [B, frames*16, 1536] + """ + b, n, d = x.size(0), self.num_heads, self.head_dim + + q = self.norm_q(self.q(x)) + k = self.norm_k(self.k(context)) + v = self.v(context) + + # Handle audio temporal structure (16 tokens per frame) + k = k.reshape(-1, 16, n, d).transpose(1, 2) + v = v.reshape(-1, 16, n, d).transpose(1, 2) + + # Handle video spatial structure + q = q.reshape(k.shape[0], -1, n, d).transpose(1, 2) + + x = optimized_attention(q, k, v, heads=self.num_heads, skip_reshape=True, skip_output_reshape=True, transformer_options=transformer_options) + + x = x.transpose(1, 2).view(b, -1, n, d).flatten(2) + x = self.o(x) + return x + + +class AudioCrossAttentionWrapper(nn.Module): + def __init__(self, dim, kv_dim, num_heads, qk_norm=True, eps=1e-6, operation_settings={}): + super().__init__() + + self.audio_cross_attn = WanT2VCrossAttentionGather(dim, num_heads, qk_norm, kv_dim, eps, operation_settings=operation_settings) + self.norm1_audio = operation_settings.get("operations").LayerNorm(dim, eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + + def forward(self, x, audio, transformer_options={}): + x = x + self.audio_cross_attn(self.norm1_audio(x), audio, transformer_options=transformer_options) + return x + + +class WanAttentionBlockAudio(WanAttentionBlock): + + def __init__(self, + cross_attn_type, + dim, + ffn_dim, + num_heads, + window_size=(-1, -1), + qk_norm=True, + cross_attn_norm=False, + eps=1e-6, operation_settings={}): + super().__init__(cross_attn_type, dim, ffn_dim, num_heads, window_size, qk_norm, cross_attn_norm, eps, operation_settings) + self.audio_cross_attn_wrapper = AudioCrossAttentionWrapper(dim, 1536, num_heads, qk_norm, eps, operation_settings=operation_settings) + + def forward( + self, + x, + e, + freqs, + context, + context_img_len=257, + audio=None, + transformer_options={}, + ): + r""" + Args: + x(Tensor): Shape [B, L, C] + e(Tensor): Shape [B, 6, C] + freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2] + """ + # assert e.dtype == torch.float32 + + if e.ndim < 4: + e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device) + e).chunk(6, dim=1) + else: + e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device).unsqueeze(0) + e).unbind(2) + # assert e[0].dtype == torch.float32 + + # self-attention + y = self.self_attn( + torch.addcmul(repeat_e(e[0], x), self.norm1(x), 1 + repeat_e(e[1], x)), + freqs, transformer_options=transformer_options) + + x = torch.addcmul(x, y, repeat_e(e[2], x)) + + # cross-attention & ffn + x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len, transformer_options=transformer_options) + if audio is not None: + x = self.audio_cross_attn_wrapper(x, audio, transformer_options=transformer_options) + y = self.ffn(torch.addcmul(repeat_e(e[3], x), self.norm2(x), 1 + repeat_e(e[4], x))) + x = torch.addcmul(x, y, repeat_e(e[5], x)) + return x + +class DummyAdapterLayer(nn.Module): + def __init__(self, layer): + super().__init__() + self.layer = layer + + def forward(self, *args, **kwargs): + return self.layer(*args, **kwargs) + + +class AudioProjModel(nn.Module): + def __init__( + self, + seq_len=5, + blocks=13, # add a new parameter blocks + channels=768, # add a new parameter channels + intermediate_dim=512, + output_dim=1536, + context_tokens=16, + device=None, + dtype=None, + operations=None, + ): + 
super().__init__() + + self.seq_len = seq_len + self.blocks = blocks + self.channels = channels + self.input_dim = seq_len * blocks * channels # update input_dim to be the product of blocks and channels. + self.intermediate_dim = intermediate_dim + self.context_tokens = context_tokens + self.output_dim = output_dim + + # define multiple linear layers + self.audio_proj_glob_1 = DummyAdapterLayer(operations.Linear(self.input_dim, intermediate_dim, dtype=dtype, device=device)) + self.audio_proj_glob_2 = DummyAdapterLayer(operations.Linear(intermediate_dim, intermediate_dim, dtype=dtype, device=device)) + self.audio_proj_glob_3 = DummyAdapterLayer(operations.Linear(intermediate_dim, context_tokens * output_dim, dtype=dtype, device=device)) + + self.audio_proj_glob_norm = DummyAdapterLayer(operations.LayerNorm(output_dim, dtype=dtype, device=device)) + + def forward(self, audio_embeds): + video_length = audio_embeds.shape[1] + audio_embeds = rearrange(audio_embeds, "bz f w b c -> (bz f) w b c") + batch_size, window_size, blocks, channels = audio_embeds.shape + audio_embeds = audio_embeds.view(batch_size, window_size * blocks * channels) + + audio_embeds = torch.relu(self.audio_proj_glob_1(audio_embeds)) + audio_embeds = torch.relu(self.audio_proj_glob_2(audio_embeds)) + + context_tokens = self.audio_proj_glob_3(audio_embeds).reshape(batch_size, self.context_tokens, self.output_dim) + + context_tokens = self.audio_proj_glob_norm(context_tokens) + context_tokens = rearrange(context_tokens, "(bz f) m c -> bz f m c", f=video_length) + + return context_tokens + + +class HumoWanModel(WanModel): + r""" + Wan diffusion backbone supporting both text-to-video and image-to-video. + """ + + def __init__(self, + model_type='humo', + patch_size=(1, 2, 2), + text_len=512, + in_dim=16, + dim=2048, + ffn_dim=8192, + freq_dim=256, + text_dim=4096, + out_dim=16, + num_heads=16, + num_layers=32, + window_size=(-1, -1), + qk_norm=True, + cross_attn_norm=True, + eps=1e-6, + flf_pos_embed_token_number=None, + image_model=None, + audio_token_num=16, + device=None, + dtype=None, + operations=None, + ): + + super().__init__(model_type='t2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, wan_attn_block_class=WanAttentionBlockAudio, image_model=image_model, device=device, dtype=dtype, operations=operations) + + self.audio_proj = AudioProjModel(seq_len=8, blocks=5, channels=1280, intermediate_dim=512, output_dim=1536, context_tokens=audio_token_num, dtype=dtype, device=device, operations=operations) + + def forward_orig( + self, + x, + t, + context, + freqs=None, + audio_embed=None, + reference_latent=None, + transformer_options={}, + **kwargs, + ): + bs, _, time, height, width = x.shape + + # embeddings + x = self.patch_embedding(x.float()).to(x.dtype) + grid_sizes = x.shape[2:] + x = x.flatten(2).transpose(1, 2) + + # time embeddings + e = self.time_embedding( + sinusoidal_embedding_1d(self.freq_dim, t.flatten()).to(dtype=x[0].dtype)) + e = e.reshape(t.shape[0], -1, e.shape[-1]) + e0 = self.time_projection(e).unflatten(2, (6, self.dim)) + + if reference_latent is not None: + ref = self.patch_embedding(reference_latent.float()).to(x.dtype) + ref = ref.flatten(2).transpose(1, 2) + freqs_ref = self.rope_encode(reference_latent.shape[-3], 
reference_latent.shape[-2], reference_latent.shape[-1], t_start=time, device=x.device, dtype=x.dtype) + x = torch.cat([x, ref], dim=1) + freqs = torch.cat([freqs, freqs_ref], dim=1) + del ref, freqs_ref + + # context + context = self.text_embedding(context) + context_img_len = None + + if audio_embed is not None: + audio = self.audio_proj(audio_embed).permute(0, 3, 1, 2).flatten(2).transpose(1, 2) + else: + audio = None + + patches_replace = transformer_options.get("patches_replace", {}) + blocks_replace = patches_replace.get("dit", {}) + for i, block in enumerate(self.blocks): + if ("double_block", i) in blocks_replace: + def block_wrap(args): + out = {} + out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len, audio=audio, transformer_options=args["transformer_options"]) + return out + out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap}) + x = out["img"] + else: + x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len, audio=audio, transformer_options=transformer_options) + + # head + x = self.head(x, e) + + # unpatchify + x = self.unpatchify(x, grid_sizes) + return x diff --git a/comfy/model_base.py b/comfy/model_base.py index 252dfcf69..cf99035da 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1213,6 +1213,23 @@ class WAN21_Camera(WAN21): out['camera_conditions'] = comfy.conds.CONDRegular(camera_conditions) return out +class WAN21_HuMo(WAN21): + def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None): + super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.HumoWanModel) + self.image_to_video = image_to_video + + def extra_conds(self, **kwargs): + out = super().extra_conds(**kwargs) + + audio_embed = kwargs.get("audio_embed", None) + if audio_embed is not None: + out['audio_embed'] = comfy.conds.CONDRegular(audio_embed) + + reference_latents = kwargs.get("reference_latents", None) + if reference_latents is not None: + out['reference_latent'] = comfy.conds.CONDRegular(self.process_latent_in(reference_latents[-1])) + return out + class WAN22_S2V(WAN21): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel_S2V) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 03d44f65e..72621bed6 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -402,6 +402,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["model_type"] = "camera_2.2" elif '{}casual_audio_encoder.encoder.final_linear.weight'.format(key_prefix) in state_dict_keys: dit_config["model_type"] = "s2v" + elif '{}audio_proj.audio_proj_glob_1.layer.bias'.format(key_prefix) in state_dict_keys: + dit_config["model_type"] = "humo" else: if '{}img_emb.proj.0.bias'.format(key_prefix) in state_dict_keys: dit_config["model_type"] = "i2v" diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 557902d11..213b5b92c 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -1073,6 +1073,16 @@ class WAN21_Vace(WAN21_T2V): out = model_base.WAN21_Vace(self, image_to_video=False, device=device) return out +class WAN21_HuMo(WAN21_T2V): + unet_config = { + "image_model": "wan2.1", + "model_type": "humo", + } + + def get_model(self, 
state_dict, prefix="", device=None): + out = model_base.WAN21_HuMo(self, image_to_video=False, device=device) + return out + class WAN22_S2V(WAN21_T2V): unet_config = { "image_model": "wan2.1", @@ -1351,6 +1361,6 @@ class HunyuanImage21Refiner(HunyuanVideo): out = model_base.HunyuanImage21Refiner(self, device=device) return out -models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage] +models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage] models += [SVD_img2vid] diff --git a/comfy_extras/nodes_wan.py b/comfy_extras/nodes_wan.py index 4f73369f5..0b8b55813 100644 --- a/comfy_extras/nodes_wan.py +++ b/comfy_extras/nodes_wan.py @@ -1015,6 +1015,103 @@ class WanSoundImageToVideoExtend(io.ComfyNode): return io.NodeOutput(positive, negative, out_latent) +def get_audio_emb_window(audio_emb, frame_num, frame0_idx, audio_shift=2): + zero_audio_embed = torch.zeros((audio_emb.shape[1], audio_emb.shape[2]), dtype=audio_emb.dtype, device=audio_emb.device) + zero_audio_embed_3 = torch.zeros((3, audio_emb.shape[1], audio_emb.shape[2]), dtype=audio_emb.dtype, device=audio_emb.device) # device=audio_emb.device + iter_ = 1 + (frame_num - 1) // 4 + audio_emb_wind = [] + for lt_i in range(iter_): + if lt_i == 0: + st = frame0_idx + lt_i - 2 + ed = frame0_idx + lt_i + 3 + wind_feat = torch.stack([ + audio_emb[i] if (0 <= i < audio_emb.shape[0]) else zero_audio_embed + for i in range(st, ed) + ], dim=0) + wind_feat = torch.cat((zero_audio_embed_3, wind_feat), dim=0) + else: + st = frame0_idx + 1 + 4 * (lt_i - 1) - audio_shift + ed = frame0_idx + 1 + 4 * lt_i + audio_shift + wind_feat = torch.stack([ + audio_emb[i] if (0 <= i < audio_emb.shape[0]) else zero_audio_embed + for i in range(st, ed) + ], dim=0) + audio_emb_wind.append(wind_feat) + audio_emb_wind = torch.stack(audio_emb_wind, dim=0) + + return audio_emb_wind, ed - audio_shift + + +class WanHuMoImageToVideo(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="WanHuMoImageToVideo", + category="conditioning/video_models", + inputs=[ + io.Conditioning.Input("positive"), + io.Conditioning.Input("negative"), + io.Vae.Input("vae"), + 
io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("length", default=97, min=1, max=nodes.MAX_RESOLUTION, step=4), + io.Int.Input("batch_size", default=1, min=1, max=4096), + io.AudioEncoderOutput.Input("audio_encoder_output", optional=True), + io.Image.Input("ref_image", optional=True), + ], + outputs=[ + io.Conditioning.Output(display_name="positive"), + io.Conditioning.Output(display_name="negative"), + io.Latent.Output(display_name="latent"), + ], + is_experimental=True, + ) + + @classmethod + def execute(cls, positive, negative, vae, width, height, length, batch_size, ref_image=None, audio_encoder_output=None) -> io.NodeOutput: + latent_t = ((length - 1) // 4) + 1 + latent = torch.zeros([batch_size, 16, latent_t, height // 8, width // 8], device=comfy.model_management.intermediate_device()) + + if ref_image is not None: + ref_image = comfy.utils.common_upscale(ref_image[:1].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) + ref_latent = vae.encode(ref_image[:, :, :, :3]) + positive = node_helpers.conditioning_set_values(positive, {"reference_latents": [ref_latent]}, append=True) + negative = node_helpers.conditioning_set_values(negative, {"reference_latents": [torch.zeros_like(ref_latent)]}, append=True) + else: + zero_latent = torch.zeros([batch_size, 16, 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) + positive = node_helpers.conditioning_set_values(positive, {"reference_latents": [zero_latent]}, append=True) + negative = node_helpers.conditioning_set_values(negative, {"reference_latents": [zero_latent]}, append=True) + + if audio_encoder_output is not None: + audio_emb = torch.stack(audio_encoder_output["encoded_audio_all_layers"], dim=2) + audio_len = audio_encoder_output["audio_samples"] // 640 + audio_emb = audio_emb[:, :audio_len * 2] + + feat0 = linear_interpolation(audio_emb[:, :, 0: 8].mean(dim=2), 50, 25) + feat1 = linear_interpolation(audio_emb[:, :, 8: 16].mean(dim=2), 50, 25) + feat2 = linear_interpolation(audio_emb[:, :, 16: 24].mean(dim=2), 50, 25) + feat3 = linear_interpolation(audio_emb[:, :, 24: 32].mean(dim=2), 50, 25) + feat4 = linear_interpolation(audio_emb[:, :, 32], 50, 25) + audio_emb = torch.stack([feat0, feat1, feat2, feat3, feat4], dim=2)[0] # [T, 5, 1280] + audio_emb, _ = get_audio_emb_window(audio_emb, length, frame0_idx=0) + + # pad for ref latent + zero_audio_pad = torch.zeros(ref_latent.shape[2], *audio_emb.shape[1:], device=audio_emb.device, dtype=audio_emb.dtype) + audio_emb = torch.cat([audio_emb, zero_audio_pad], dim=0) + + audio_emb = audio_emb.unsqueeze(0) + audio_emb_neg = torch.zeros_like(audio_emb) + positive = node_helpers.conditioning_set_values(positive, {"audio_embed": audio_emb}) + negative = node_helpers.conditioning_set_values(negative, {"audio_embed": audio_emb_neg}) + else: + zero_audio = torch.zeros([batch_size, latent_t + 1, 8, 5, 1280], device=comfy.model_management.intermediate_device()) + positive = node_helpers.conditioning_set_values(positive, {"audio_embed": zero_audio}) + negative = node_helpers.conditioning_set_values(negative, {"audio_embed": zero_audio}) + + out_latent = {} + out_latent["samples"] = latent + return io.NodeOutput(positive, negative, out_latent) + class Wan22ImageToVideoLatent(io.ComfyNode): @classmethod def define_schema(cls): @@ -1075,6 +1172,7 @@ class WanExtension(ComfyExtension): WanPhantomSubjectToVideo, WanSoundImageToVideo, 
WanSoundImageToVideoExtend, + WanHuMoImageToVideo, Wan22ImageToVideoLatent, ] From dd611a7700956f45f393dee32fb8505de176dc66 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 17 Sep 2025 15:39:24 -0700 Subject: [PATCH 217/325] Support the HuMo 17B model. (#9912) --- comfy/ldm/wan/model.py | 2 +- comfy/model_base.py | 29 ++++++++++++++++++++++++++--- 2 files changed, 27 insertions(+), 4 deletions(-) diff --git a/comfy/ldm/wan/model.py b/comfy/ldm/wan/model.py index b3b7da5d5..9cf3c171d 100644 --- a/comfy/ldm/wan/model.py +++ b/comfy/ldm/wan/model.py @@ -1364,7 +1364,7 @@ class AudioCrossAttentionWrapper(nn.Module): def __init__(self, dim, kv_dim, num_heads, qk_norm=True, eps=1e-6, operation_settings={}): super().__init__() - self.audio_cross_attn = WanT2VCrossAttentionGather(dim, num_heads, qk_norm, kv_dim, eps, operation_settings=operation_settings) + self.audio_cross_attn = WanT2VCrossAttentionGather(dim, num_heads, qk_norm=qk_norm, kv_dim=kv_dim, eps=eps, operation_settings=operation_settings) self.norm1_audio = operation_settings.get("operations").LayerNorm(dim, eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) def forward(self, x, audio, transformer_options={}): diff --git a/comfy/model_base.py b/comfy/model_base.py index cf99035da..70b67b7c1 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1220,14 +1220,37 @@ class WAN21_HuMo(WAN21): def extra_conds(self, **kwargs): out = super().extra_conds(**kwargs) + noise = kwargs.get("noise", None) audio_embed = kwargs.get("audio_embed", None) if audio_embed is not None: out['audio_embed'] = comfy.conds.CONDRegular(audio_embed) - reference_latents = kwargs.get("reference_latents", None) - if reference_latents is not None: - out['reference_latent'] = comfy.conds.CONDRegular(self.process_latent_in(reference_latents[-1])) + if "c_concat" not in out: # 1.7B model + reference_latents = kwargs.get("reference_latents", None) + if reference_latents is not None: + out['reference_latent'] = comfy.conds.CONDRegular(self.process_latent_in(reference_latents[-1])) + else: + noise_shape = list(noise.shape) + noise_shape[1] += 4 + concat_latent = torch.zeros(noise_shape, device=noise.device, dtype=noise.dtype) + zero_vae_values_first = torch.tensor([0.8660, -0.4326, -0.0017, -0.4884, -0.5283, 0.9207, -0.9896, 0.4433, -0.5543, -0.0113, 0.5753, -0.6000, -0.8346, -0.3497, -0.1926, -0.6938]).view(1, 16, 1, 1, 1) + zero_vae_values_second = torch.tensor([1.0869, -1.2370, 0.0206, -0.4357, -0.6411, 2.0307, -1.5972, 1.2659, -0.8595, -0.4654, 0.9638, -1.6330, -1.4310, -0.1098, -0.3856, -1.4583]).view(1, 16, 1, 1, 1) + zero_vae_values = torch.tensor([0.8642, -1.8583, 0.1577, 0.1350, -0.3641, 2.5863, -1.9670, 1.6065, -1.0475, -0.8678, 1.1734, -1.8138, -1.5933, -0.7721, -0.3289, -1.3745]).view(1, 16, 1, 1, 1) + concat_latent[:, 4:] = zero_vae_values + concat_latent[:, 4:, :1] = zero_vae_values_first + concat_latent[:, 4:, 1:2] = zero_vae_values_second + out['c_concat'] = comfy.conds.CONDNoiseShape(concat_latent) + reference_latents = kwargs.get("reference_latents", None) + if reference_latents is not None: + ref_latent = self.process_latent_in(reference_latents[-1]) + ref_latent_shape = list(ref_latent.shape) + ref_latent_shape[1] += 4 + ref_latent_shape[1] + ref_latent_full = torch.zeros(ref_latent_shape, device=ref_latent.device, dtype=ref_latent.dtype) + ref_latent_full[:, 20:] = ref_latent + ref_latent_full[:, 16:20] = 1.0 + 
out['reference_latent'] = comfy.conds.CONDRegular(ref_latent_full) + return out class WAN22_S2V(WAN21): From 8d6653fca676a08df3e11654672fed92a183d147 Mon Sep 17 00:00:00 2001 From: DELUXA Date: Fri, 19 Sep 2025 02:50:37 +0300 Subject: [PATCH 218/325] Enable fp8 ops by default on gfx1200 (#9926) --- comfy/model_management.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index bbfc3c7a1..d880f1970 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -348,7 +348,7 @@ try: # if any((a in arch) for a in ["gfx1201"]): # ENABLE_PYTORCH_ATTENTION = True if torch_version_numeric >= (2, 7) and rocm_version >= (6, 4): - if any((a in arch) for a in ["gfx1201", "gfx942", "gfx950"]): # TODO: more arches + if any((a in arch) for a in ["gfx1200", "gfx1201", "gfx942", "gfx950"]): # TODO: more arches SUPPORT_FP8_OPS = True except: From 1ea8c540640913b247248e46c907fb9b92a9dd4b Mon Sep 17 00:00:00 2001 From: Jodh Singh Date: Thu, 18 Sep 2025 19:51:16 -0400 Subject: [PATCH 219/325] make kernel of same type as image to avoid mismatch issues (#9932) --- comfy_extras/nodes_post_processing.py | 1 + 1 file changed, 1 insertion(+) diff --git a/comfy_extras/nodes_post_processing.py b/comfy_extras/nodes_post_processing.py index cb1a0d883..ed7a07152 100644 --- a/comfy_extras/nodes_post_processing.py +++ b/comfy_extras/nodes_post_processing.py @@ -233,6 +233,7 @@ class Sharpen: kernel_size = sharpen_radius * 2 + 1 kernel = gaussian_kernel(kernel_size, sigma, device=image.device) * -(alpha*10) + kernel = kernel.to(dtype=image.dtype) center = kernel_size // 2 kernel[center, center] = kernel[center, center] - kernel.sum() + 1.0 kernel = kernel.repeat(channels, 1, 1).unsqueeze(1) From 24b0fce099c56d18ceb1f4f6b9455fee55e154ce Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Thu, 18 Sep 2025 16:54:16 -0700 Subject: [PATCH 220/325] Do padding of audio embed in model for humo for more flexibility. 
(#9935) --- comfy/ldm/wan/model.py | 3 +++ comfy_extras/nodes_wan.py | 4 ---- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/comfy/ldm/wan/model.py b/comfy/ldm/wan/model.py index 9cf3c171d..2dac5980c 100644 --- a/comfy/ldm/wan/model.py +++ b/comfy/ldm/wan/model.py @@ -1551,6 +1551,9 @@ class HumoWanModel(WanModel): context_img_len = None if audio_embed is not None: + if reference_latent is not None: + zero_audio_pad = torch.zeros(audio_embed.shape[0], reference_latent.shape[-3], *audio_embed.shape[2:], device=audio_embed.device, dtype=audio_embed.dtype) + audio_embed = torch.cat([audio_embed, zero_audio_pad], dim=1) audio = self.audio_proj(audio_embed).permute(0, 3, 1, 2).flatten(2).transpose(1, 2) else: audio = None diff --git a/comfy_extras/nodes_wan.py b/comfy_extras/nodes_wan.py index 0b8b55813..5f10edcff 100644 --- a/comfy_extras/nodes_wan.py +++ b/comfy_extras/nodes_wan.py @@ -1095,10 +1095,6 @@ class WanHuMoImageToVideo(io.ComfyNode): audio_emb = torch.stack([feat0, feat1, feat2, feat3, feat4], dim=2)[0] # [T, 5, 1280] audio_emb, _ = get_audio_emb_window(audio_emb, length, frame0_idx=0) - # pad for ref latent - zero_audio_pad = torch.zeros(ref_latent.shape[2], *audio_emb.shape[1:], device=audio_emb.device, dtype=audio_emb.dtype) - audio_emb = torch.cat([audio_emb, zero_audio_pad], dim=0) - audio_emb = audio_emb.unsqueeze(0) audio_emb_neg = torch.zeros_like(audio_emb) positive = node_helpers.conditioning_set_values(positive, {"audio_embed": audio_emb}) From 711bcf33ee505a997674f4a9125e69d2a5a3c180 Mon Sep 17 00:00:00 2001 From: Christian Byrne Date: Fri, 19 Sep 2025 00:03:30 -0700 Subject: [PATCH 221/325] Bump frontend to 1.26.13 (#9933) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index de5af5fac..79187efaa 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -comfyui-frontend-package==1.26.11 +comfyui-frontend-package==1.26.13 comfyui-workflow-templates==0.1.81 comfyui-embedded-docs==0.2.6 torch From dc95b6acc0ef4962460592d417db4024f7160586 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Fri, 19 Sep 2025 00:07:17 -0700 Subject: [PATCH 222/325] Basic WIP support for the wan animate model. 
(#9939) --- comfy/ldm/wan/model_animate.py | 548 +++++++++++++++++++++++++++++++++ comfy/model_base.py | 18 ++ comfy/model_detection.py | 2 + comfy/supported_models.py | 15 +- comfy_extras/nodes_wan.py | 84 +++++ 5 files changed, 666 insertions(+), 1 deletion(-) create mode 100644 comfy/ldm/wan/model_animate.py diff --git a/comfy/ldm/wan/model_animate.py b/comfy/ldm/wan/model_animate.py new file mode 100644 index 000000000..542f54110 --- /dev/null +++ b/comfy/ldm/wan/model_animate.py @@ -0,0 +1,548 @@ +from torch import nn +import torch +from typing import Tuple, Optional +from einops import rearrange +import torch.nn.functional as F +import math +from .model import WanModel, sinusoidal_embedding_1d +from comfy.ldm.modules.attention import optimized_attention +import comfy.model_management + +class CausalConv1d(nn.Module): + + def __init__(self, chan_in, chan_out, kernel_size=3, stride=1, dilation=1, pad_mode="replicate", operations=None, **kwargs): + super().__init__() + + self.pad_mode = pad_mode + padding = (kernel_size - 1, 0) # T + self.time_causal_padding = padding + + self.conv = operations.Conv1d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs) + + def forward(self, x): + x = F.pad(x, self.time_causal_padding, mode=self.pad_mode) + return self.conv(x) + + +class FaceEncoder(nn.Module): + def __init__(self, in_dim: int, hidden_dim: int, num_heads=int, dtype=None, device=None, operations=None): + factory_kwargs = {"dtype": dtype, "device": device} + super().__init__() + + self.num_heads = num_heads + self.conv1_local = CausalConv1d(in_dim, 1024 * num_heads, 3, stride=1, operations=operations, **factory_kwargs) + self.norm1 = operations.LayerNorm(hidden_dim // 8, elementwise_affine=False, eps=1e-6, **factory_kwargs) + self.act = nn.SiLU() + self.conv2 = CausalConv1d(1024, 1024, 3, stride=2, operations=operations, **factory_kwargs) + self.conv3 = CausalConv1d(1024, 1024, 3, stride=2, operations=operations, **factory_kwargs) + + self.out_proj = operations.Linear(1024, hidden_dim, **factory_kwargs) + self.norm1 = operations.LayerNorm(1024, elementwise_affine=False, eps=1e-6, **factory_kwargs) + + self.norm2 = operations.LayerNorm(1024, elementwise_affine=False, eps=1e-6, **factory_kwargs) + + self.norm3 = operations.LayerNorm(1024, elementwise_affine=False, eps=1e-6, **factory_kwargs) + + self.padding_tokens = nn.Parameter(torch.empty(1, 1, 1, hidden_dim, **factory_kwargs)) + + def forward(self, x): + + x = rearrange(x, "b t c -> b c t") + b, c, t = x.shape + + x = self.conv1_local(x) + x = rearrange(x, "b (n c) t -> (b n) t c", n=self.num_heads) + + x = self.norm1(x) + x = self.act(x) + x = rearrange(x, "b t c -> b c t") + x = self.conv2(x) + x = rearrange(x, "b c t -> b t c") + x = self.norm2(x) + x = self.act(x) + x = rearrange(x, "b t c -> b c t") + x = self.conv3(x) + x = rearrange(x, "b c t -> b t c") + x = self.norm3(x) + x = self.act(x) + x = self.out_proj(x) + x = rearrange(x, "(b n) t c -> b t n c", b=b) + padding = comfy.model_management.cast_to(self.padding_tokens, dtype=x.dtype, device=x.device).repeat(b, x.shape[1], 1, 1) + x = torch.cat([x, padding], dim=-2) + x_local = x.clone() + + return x_local + + +def get_norm_layer(norm_layer, operations=None): + """ + Get the normalization layer. + + Args: + norm_layer (str): The type of normalization layer. + + Returns: + norm_layer (nn.Module): The normalization layer. 
+ """ + if norm_layer == "layer": + return operations.LayerNorm + elif norm_layer == "rms": + return operations.RMSNorm + else: + raise NotImplementedError(f"Norm layer {norm_layer} is not implemented") + + +class FaceAdapter(nn.Module): + def __init__( + self, + hidden_dim: int, + heads_num: int, + qk_norm: bool = True, + qk_norm_type: str = "rms", + num_adapter_layers: int = 1, + dtype=None, device=None, operations=None + ): + + factory_kwargs = {"dtype": dtype, "device": device} + super().__init__() + self.hidden_size = hidden_dim + self.heads_num = heads_num + self.fuser_blocks = nn.ModuleList( + [ + FaceBlock( + self.hidden_size, + self.heads_num, + qk_norm=qk_norm, + qk_norm_type=qk_norm_type, + operations=operations, + **factory_kwargs, + ) + for _ in range(num_adapter_layers) + ] + ) + + def forward( + self, + x: torch.Tensor, + motion_embed: torch.Tensor, + idx: int, + freqs_cis_q: Tuple[torch.Tensor, torch.Tensor] = None, + freqs_cis_k: Tuple[torch.Tensor, torch.Tensor] = None, + ) -> torch.Tensor: + + return self.fuser_blocks[idx](x, motion_embed, freqs_cis_q, freqs_cis_k) + + + +class FaceBlock(nn.Module): + def __init__( + self, + hidden_size: int, + heads_num: int, + qk_norm: bool = True, + qk_norm_type: str = "rms", + qk_scale: float = None, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + operations=None + ): + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + + self.deterministic = False + self.hidden_size = hidden_size + self.heads_num = heads_num + head_dim = hidden_size // heads_num + self.scale = qk_scale or head_dim**-0.5 + + self.linear1_kv = operations.Linear(hidden_size, hidden_size * 2, **factory_kwargs) + self.linear1_q = operations.Linear(hidden_size, hidden_size, **factory_kwargs) + + self.linear2 = operations.Linear(hidden_size, hidden_size, **factory_kwargs) + + qk_norm_layer = get_norm_layer(qk_norm_type, operations=operations) + self.q_norm = ( + qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity() + ) + self.k_norm = ( + qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity() + ) + + self.pre_norm_feat = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs) + + self.pre_norm_motion = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs) + + def forward( + self, + x: torch.Tensor, + motion_vec: torch.Tensor, + motion_mask: Optional[torch.Tensor] = None, + # use_context_parallel=False, + ) -> torch.Tensor: + + B, T, N, C = motion_vec.shape + T_comp = T + + x_motion = self.pre_norm_motion(motion_vec) + x_feat = self.pre_norm_feat(x) + + kv = self.linear1_kv(x_motion) + q = self.linear1_q(x_feat) + + k, v = rearrange(kv, "B L N (K H D) -> K B L N H D", K=2, H=self.heads_num) + q = rearrange(q, "B S (H D) -> B S H D", H=self.heads_num) + + # Apply QK-Norm if needed. 
+ q = self.q_norm(q).to(v) + k = self.k_norm(k).to(v) + + k = rearrange(k, "B L N H D -> (B L) N H D") + v = rearrange(v, "B L N H D -> (B L) N H D") + + q = rearrange(q, "B (L S) H D -> (B L) S (H D)", L=T_comp) + + attn = optimized_attention(q, k, v, heads=self.heads_num) + + attn = rearrange(attn, "(B L) S C -> B (L S) C", L=T_comp) + + output = self.linear2(attn) + + if motion_mask is not None: + output = output * rearrange(motion_mask, "B T H W -> B (T H W)").unsqueeze(-1) + + return output + +# https://github.com/XPixelGroup/BasicSR/blob/8d56e3a045f9fb3e1d8872f92ee4a4f07f886b0a/basicsr/ops/upfirdn2d/upfirdn2d.py#L162 +def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1): + _, minor, in_h, in_w = input.shape + kernel_h, kernel_w = kernel.shape + + out = input.view(-1, minor, in_h, 1, in_w, 1) + out = F.pad(out, [0, up_x - 1, 0, 0, 0, up_y - 1, 0, 0]) + out = out.view(-1, minor, in_h * up_y, in_w * up_x) + + out = F.pad(out, [max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]) + out = out[:, :, max(-pad_y0, 0): out.shape[2] - max(-pad_y1, 0), max(-pad_x0, 0): out.shape[3] - max(-pad_x1, 0)] + + out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]) + w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) + out = F.conv2d(out, w) + out = out.reshape(-1, minor, in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1) + return out[:, :, ::down_y, ::down_x] + +def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)): + return upfirdn2d_native(input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1]) + +# https://github.com/XPixelGroup/BasicSR/blob/8d56e3a045f9fb3e1d8872f92ee4a4f07f886b0a/basicsr/ops/fused_act/fused_act.py#L81 +class FusedLeakyReLU(torch.nn.Module): + def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5, dtype=None, device=None): + super().__init__() + self.bias = torch.nn.Parameter(torch.empty(1, channel, 1, 1, dtype=dtype, device=device)) + self.negative_slope = negative_slope + self.scale = scale + + def forward(self, input): + return fused_leaky_relu(input, comfy.model_management.cast_to(self.bias, device=input.device, dtype=input.dtype), self.negative_slope, self.scale) + +def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5): + return F.leaky_relu(input + bias, negative_slope) * scale + +class Blur(torch.nn.Module): + def __init__(self, kernel, pad, dtype=None, device=None): + super().__init__() + kernel = torch.tensor(kernel, dtype=dtype, device=device) + kernel = kernel[None, :] * kernel[:, None] + kernel = kernel / kernel.sum() + self.register_buffer('kernel', kernel) + self.pad = pad + + def forward(self, input): + return upfirdn2d(input, comfy.model_management.cast_to(self.kernel, dtype=input.dtype, device=input.device), pad=self.pad) + +#https://github.com/XPixelGroup/BasicSR/blob/8d56e3a045f9fb3e1d8872f92ee4a4f07f886b0a/basicsr/archs/stylegan2_arch.py#L590 +class ScaledLeakyReLU(torch.nn.Module): + def __init__(self, negative_slope=0.2): + super().__init__() + self.negative_slope = negative_slope + + def forward(self, input): + return F.leaky_relu(input, negative_slope=self.negative_slope) + +# https://github.com/XPixelGroup/BasicSR/blob/8d56e3a045f9fb3e1d8872f92ee4a4f07f886b0a/basicsr/archs/stylegan2_arch.py#L605 +class EqualConv2d(torch.nn.Module): + def __init__(self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True, dtype=None, device=None, operations=None): + 
super().__init__() + self.weight = torch.nn.Parameter(torch.empty(out_channel, in_channel, kernel_size, kernel_size, device=device, dtype=dtype)) + self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2) + self.stride = stride + self.padding = padding + self.bias = torch.nn.Parameter(torch.empty(out_channel, device=device, dtype=dtype)) if bias else None + + def forward(self, input): + if self.bias is None: + bias = None + else: + bias = comfy.model_management.cast_to(self.bias, device=input.device, dtype=input.dtype) + + return F.conv2d(input, comfy.model_management.cast_to(self.weight, device=input.device, dtype=input.dtype) * self.scale, bias=bias, stride=self.stride, padding=self.padding) + +# https://github.com/XPixelGroup/BasicSR/blob/8d56e3a045f9fb3e1d8872f92ee4a4f07f886b0a/basicsr/archs/stylegan2_arch.py#L134 +class EqualLinear(torch.nn.Module): + def __init__(self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None, dtype=None, device=None, operations=None): + super().__init__() + self.weight = torch.nn.Parameter(torch.empty(out_dim, in_dim, device=device, dtype=dtype)) + self.bias = torch.nn.Parameter(torch.empty(out_dim, device=device, dtype=dtype)) if bias else None + self.activation = activation + self.scale = (1 / math.sqrt(in_dim)) * lr_mul + self.lr_mul = lr_mul + + def forward(self, input): + if self.bias is None: + bias = None + else: + bias = comfy.model_management.cast_to(self.bias, device=input.device, dtype=input.dtype) * self.lr_mul + + if self.activation: + out = F.linear(input, comfy.model_management.cast_to(self.weight, device=input.device, dtype=input.dtype) * self.scale) + return fused_leaky_relu(out, bias) + return F.linear(input, comfy.model_management.cast_to(self.weight, device=input.device, dtype=input.dtype) * self.scale, bias=bias) + +# https://github.com/XPixelGroup/BasicSR/blob/8d56e3a045f9fb3e1d8872f92ee4a4f07f886b0a/basicsr/archs/stylegan2_arch.py#L654 +class ConvLayer(torch.nn.Sequential): + def __init__(self, in_channel, out_channel, kernel_size, downsample=False, blur_kernel=[1, 3, 3, 1], bias=True, activate=True, dtype=None, device=None, operations=None): + layers = [] + + if downsample: + factor = 2 + p = (len(blur_kernel) - factor) + (kernel_size - 1) + layers.append(Blur(blur_kernel, pad=((p + 1) // 2, p // 2))) + stride, padding = 2, 0 + else: + stride, padding = 1, kernel_size // 2 + + layers.append(EqualConv2d(in_channel, out_channel, kernel_size, padding=padding, stride=stride, bias=bias and not activate, dtype=dtype, device=device, operations=operations)) + + if activate: + layers.append(FusedLeakyReLU(out_channel) if bias else ScaledLeakyReLU(0.2)) + + super().__init__(*layers) + +# https://github.com/XPixelGroup/BasicSR/blob/8d56e3a045f9fb3e1d8872f92ee4a4f07f886b0a/basicsr/archs/stylegan2_arch.py#L704 +class ResBlock(torch.nn.Module): + def __init__(self, in_channel, out_channel, dtype=None, device=None, operations=None): + super().__init__() + self.conv1 = ConvLayer(in_channel, in_channel, 3, dtype=dtype, device=device, operations=operations) + self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True, dtype=dtype, device=device, operations=operations) + self.skip = ConvLayer(in_channel, out_channel, 1, downsample=True, activate=False, bias=False, dtype=dtype, device=device, operations=operations) + + def forward(self, input): + out = self.conv2(self.conv1(input)) + skip = self.skip(input) + return (out + skip) / math.sqrt(2) + + +class EncoderApp(torch.nn.Module): + def __init__(self, w_dim=512, dtype=None, 
device=None, operations=None): + super().__init__() + kwargs = {"device": device, "dtype": dtype, "operations": operations} + + self.convs = torch.nn.ModuleList([ + ConvLayer(3, 32, 1, **kwargs), ResBlock(32, 64, **kwargs), + ResBlock(64, 128, **kwargs), ResBlock(128, 256, **kwargs), + ResBlock(256, 512, **kwargs), ResBlock(512, 512, **kwargs), + ResBlock(512, 512, **kwargs), ResBlock(512, 512, **kwargs), + EqualConv2d(512, w_dim, 4, padding=0, bias=False, **kwargs) + ]) + + def forward(self, x): + h = x + for conv in self.convs: + h = conv(h) + return h.squeeze(-1).squeeze(-1) + +class Encoder(torch.nn.Module): + def __init__(self, dim=512, motion_dim=20, dtype=None, device=None, operations=None): + super().__init__() + self.net_app = EncoderApp(dim, dtype=dtype, device=device, operations=operations) + self.fc = torch.nn.Sequential(*[EqualLinear(dim, dim, dtype=dtype, device=device, operations=operations) for _ in range(4)] + [EqualLinear(dim, motion_dim, dtype=dtype, device=device, operations=operations)]) + + def encode_motion(self, x): + return self.fc(self.net_app(x)) + +class Direction(torch.nn.Module): + def __init__(self, motion_dim, dtype=None, device=None, operations=None): + super().__init__() + self.weight = torch.nn.Parameter(torch.empty(512, motion_dim, device=device, dtype=dtype)) + self.motion_dim = motion_dim + + def forward(self, input): + stabilized_weight = comfy.model_management.cast_to(self.weight, device=input.device, dtype=input.dtype) + 1e-8 * torch.eye(512, self.motion_dim, device=input.device, dtype=input.dtype) + Q, _ = torch.linalg.qr(stabilized_weight.float()) + if input is None: + return Q + return torch.sum(input.unsqueeze(-1) * Q.T.to(input.dtype), dim=1) + +class Synthesis(torch.nn.Module): + def __init__(self, motion_dim, dtype=None, device=None, operations=None): + super().__init__() + self.direction = Direction(motion_dim, dtype=dtype, device=device, operations=operations) + +class Generator(torch.nn.Module): + def __init__(self, style_dim=512, motion_dim=20, dtype=None, device=None, operations=None): + super().__init__() + self.enc = Encoder(style_dim, motion_dim, dtype=dtype, device=device, operations=operations) + self.dec = Synthesis(motion_dim, dtype=dtype, device=device, operations=operations) + + def get_motion(self, img): + motion_feat = self.enc.encode_motion(img) + return self.dec.direction(motion_feat) + +class AnimateWanModel(WanModel): + r""" + Wan diffusion backbone supporting both text-to-video and image-to-video. 
+ """ + + def __init__(self, + model_type='animate', + patch_size=(1, 2, 2), + text_len=512, + in_dim=16, + dim=2048, + ffn_dim=8192, + freq_dim=256, + text_dim=4096, + out_dim=16, + num_heads=16, + num_layers=32, + window_size=(-1, -1), + qk_norm=True, + cross_attn_norm=True, + eps=1e-6, + flf_pos_embed_token_number=None, + motion_encoder_dim=512, + image_model=None, + device=None, + dtype=None, + operations=None, + ): + + super().__init__(model_type='i2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, image_model=image_model, device=device, dtype=dtype, operations=operations) + + self.pose_patch_embedding = operations.Conv3d( + 16, dim, kernel_size=patch_size, stride=patch_size, device=device, dtype=dtype + ) + + self.motion_encoder = Generator(style_dim=512, motion_dim=20, device=device, dtype=dtype, operations=operations) + + self.face_adapter = FaceAdapter( + heads_num=self.num_heads, + hidden_dim=self.dim, + num_adapter_layers=self.num_layers // 5, + device=device, dtype=dtype, operations=operations + ) + + self.face_encoder = FaceEncoder( + in_dim=motion_encoder_dim, + hidden_dim=self.dim, + num_heads=4, + device=device, dtype=dtype, operations=operations + ) + + def after_patch_embedding(self, x, pose_latents, face_pixel_values): + if pose_latents is not None: + pose_latents = self.pose_patch_embedding(pose_latents) + x[:, :, 1:] += pose_latents + + if face_pixel_values is None: + return x, None + + b, c, T, h, w = face_pixel_values.shape + face_pixel_values = rearrange(face_pixel_values, "b c t h w -> (b t) c h w") + encode_bs = 8 + face_pixel_values_tmp = [] + for i in range(math.ceil(face_pixel_values.shape[0] / encode_bs)): + face_pixel_values_tmp.append(self.motion_encoder.get_motion(face_pixel_values[i * encode_bs: (i + 1) * encode_bs])) + + motion_vec = torch.cat(face_pixel_values_tmp) + + motion_vec = rearrange(motion_vec, "(b t) c -> b t c", t=T) + motion_vec = self.face_encoder(motion_vec) + + B, L, H, C = motion_vec.shape + pad_face = torch.zeros(B, 1, H, C).type_as(motion_vec) + motion_vec = torch.cat([pad_face, motion_vec], dim=1) + + if motion_vec.shape[1] < x.shape[2]: + B, L, H, C = motion_vec.shape + pad = torch.zeros(B, x.shape[2] - motion_vec.shape[1], H, C).type_as(motion_vec) + motion_vec = torch.cat([motion_vec, pad], dim=1) + else: + motion_vec = motion_vec[:, :x.shape[2]] + return x, motion_vec + + def forward_orig( + self, + x, + t, + context, + clip_fea=None, + pose_latents=None, + face_pixel_values=None, + freqs=None, + transformer_options={}, + **kwargs, + ): + # embeddings + x = self.patch_embedding(x.float()).to(x.dtype) + x, motion_vec = self.after_patch_embedding(x, pose_latents, face_pixel_values) + grid_sizes = x.shape[2:] + x = x.flatten(2).transpose(1, 2) + + # time embeddings + e = self.time_embedding( + sinusoidal_embedding_1d(self.freq_dim, t.flatten()).to(dtype=x[0].dtype)) + e = e.reshape(t.shape[0], -1, e.shape[-1]) + e0 = self.time_projection(e).unflatten(2, (6, self.dim)) + + full_ref = None + if self.ref_conv is not None: + full_ref = kwargs.get("reference_latent", None) + if full_ref is not None: + full_ref = self.ref_conv(full_ref).flatten(2).transpose(1, 2) + x = torch.concat((full_ref, x), dim=1) + + # context + context = self.text_embedding(context) + + 
context_img_len = None + if clip_fea is not None: + if self.img_emb is not None: + context_clip = self.img_emb(clip_fea) # bs x 257 x dim + context = torch.concat([context_clip, context], dim=1) + context_img_len = clip_fea.shape[-2] + + patches_replace = transformer_options.get("patches_replace", {}) + blocks_replace = patches_replace.get("dit", {}) + for i, block in enumerate(self.blocks): + if ("double_block", i) in blocks_replace: + def block_wrap(args): + out = {} + out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len, transformer_options=args["transformer_options"]) + return out + out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap}) + x = out["img"] + else: + x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len, transformer_options=transformer_options) + + if i % 5 == 0 and motion_vec is not None: + x = x + self.face_adapter.fuser_blocks[i // 5](x, motion_vec) + + # head + x = self.head(x, e) + + if full_ref is not None: + x = x[:, full_ref.shape[1]:] + + # unpatchify + x = self.unpatchify(x, grid_sizes) + return x diff --git a/comfy/model_base.py b/comfy/model_base.py index 70b67b7c1..b0b9cde7d 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -39,6 +39,7 @@ import comfy.ldm.cosmos.model import comfy.ldm.cosmos.predict2 import comfy.ldm.lumina.model import comfy.ldm.wan.model +import comfy.ldm.wan.model_animate import comfy.ldm.hunyuan3d.model import comfy.ldm.hidream.model import comfy.ldm.chroma.model @@ -1253,6 +1254,23 @@ class WAN21_HuMo(WAN21): return out +class WAN22_Animate(WAN21): + def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None): + super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model_animate.AnimateWanModel) + self.image_to_video = image_to_video + + def extra_conds(self, **kwargs): + out = super().extra_conds(**kwargs) + + face_video_pixels = kwargs.get("face_video_pixels", None) + if face_video_pixels is not None: + out['face_pixel_values'] = comfy.conds.CONDRegular(face_video_pixels) + + pose_latents = kwargs.get("pose_video_latent", None) + if pose_latents is not None: + out['pose_latents'] = comfy.conds.CONDRegular(self.process_latent_in(pose_latents)) + return out + class WAN22_S2V(WAN21): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel_S2V) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 72621bed6..46415c17a 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -404,6 +404,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["model_type"] = "s2v" elif '{}audio_proj.audio_proj_glob_1.layer.bias'.format(key_prefix) in state_dict_keys: dit_config["model_type"] = "humo" + elif '{}face_adapter.fuser_blocks.0.k_norm.weight'.format(key_prefix) in state_dict_keys: + dit_config["model_type"] = "animate" else: if '{}img_emb.proj.0.bias'.format(key_prefix) in state_dict_keys: dit_config["model_type"] = "i2v" diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 213b5b92c..1fbb6aef4 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -1096,6 +1096,19 @@ class WAN22_S2V(WAN21_T2V): out = model_base.WAN22_S2V(self, device=device) 
return out +class WAN22_Animate(WAN21_T2V): + unet_config = { + "image_model": "wan2.1", + "model_type": "animate", + } + + def __init__(self, unet_config): + super().__init__(unet_config) + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.WAN22_Animate(self, device=device) + return out + class WAN22_T2V(WAN21_T2V): unet_config = { "image_model": "wan2.1", @@ -1361,6 +1374,6 @@ class HunyuanImage21Refiner(HunyuanVideo): out = model_base.HunyuanImage21Refiner(self, device=device) return out -models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage] +models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage] models += [SVD_img2vid] diff --git a/comfy_extras/nodes_wan.py b/comfy_extras/nodes_wan.py index 5f10edcff..4187a5619 100644 --- a/comfy_extras/nodes_wan.py +++ b/comfy_extras/nodes_wan.py @@ -1108,6 +1108,89 @@ class WanHuMoImageToVideo(io.ComfyNode): out_latent["samples"] = latent return io.NodeOutput(positive, negative, out_latent) +class WanAnimateToVideo(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="WanAnimateToVideo", + category="conditioning/video_models", + inputs=[ + io.Conditioning.Input("positive"), + io.Conditioning.Input("negative"), + io.Vae.Input("vae"), + io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("length", default=77, min=1, max=nodes.MAX_RESOLUTION, step=4), + io.Int.Input("batch_size", default=1, min=1, max=4096), + io.ClipVisionOutput.Input("clip_vision_output", optional=True), + io.Image.Input("reference_image", optional=True), + io.Image.Input("face_video", optional=True), + io.Image.Input("pose_video", optional=True), + io.Int.Input("continue_motion_max_frames", default=5, min=1, max=nodes.MAX_RESOLUTION, step=4), + io.Image.Input("continue_motion", optional=True), + ], + outputs=[ + io.Conditioning.Output(display_name="positive"), + io.Conditioning.Output(display_name="negative"), + io.Latent.Output(display_name="latent"), + io.Int.Output(display_name="trim_latent"), + 
], + is_experimental=True, + ) + + @classmethod + def execute(cls, positive, negative, vae, width, height, length, batch_size, continue_motion_max_frames, reference_image=None, clip_vision_output=None, face_video=None, pose_video=None, continue_motion=None) -> io.NodeOutput: + latent_length = ((length - 1) // 4) + 1 + latent_width = width // 8 + latent_height = height // 8 + trim_latent = 0 + + if reference_image is None: + reference_image = torch.zeros((1, height, width, 3)) + + image = comfy.utils.common_upscale(reference_image[:length].movedim(-1, 1), width, height, "area", "center").movedim(1, -1) + concat_latent_image = vae.encode(image[:, :, :, :3]) + mask = torch.zeros((1, 1, concat_latent_image.shape[2], concat_latent_image.shape[-2], concat_latent_image.shape[-1]), device=concat_latent_image.device, dtype=concat_latent_image.dtype) + trim_latent += concat_latent_image.shape[2] + + if clip_vision_output is not None: + positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output}) + negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output}) + + if face_video is not None: + face_video = comfy.utils.common_upscale(face_video[:length].movedim(-1, 1), 512, 512, "area", "center") * 2.0 - 1.0 + face_video = face_video.movedim(0, 1).unsqueeze(0) + positive = node_helpers.conditioning_set_values(positive, {"face_video_pixels": face_video}) + negative = node_helpers.conditioning_set_values(negative, {"face_video_pixels": face_video * 0.0 - 1.0}) + + if pose_video is not None: + pose_video = comfy.utils.common_upscale(pose_video[:length].movedim(-1, 1), width, height, "area", "center").movedim(1, -1) + pose_video_latent = vae.encode(pose_video[:, :, :, :3]) + positive = node_helpers.conditioning_set_values(positive, {"pose_video_latent": pose_video_latent}) + negative = node_helpers.conditioning_set_values(negative, {"pose_video_latent": pose_video_latent}) + + if continue_motion is None: + image = torch.ones((length, height, width, 3)) * 0.5 + else: + continue_motion = continue_motion[-continue_motion_max_frames:] + continue_motion = comfy.utils.common_upscale(continue_motion[-length:].movedim(-1, 1), width, height, "area", "center").movedim(1, -1) + image = torch.ones((length, height, width, continue_motion.shape[-1]), device=continue_motion.device, dtype=continue_motion.dtype) * 0.5 + image[:continue_motion.shape[0]] = continue_motion + + concat_latent_image = torch.cat((concat_latent_image, vae.encode(image[:, :, :, :3])), dim=2) + mask_refmotion = torch.ones((1, 1, latent_length, concat_latent_image.shape[-2], concat_latent_image.shape[-1]), device=mask.device, dtype=mask.dtype) + if continue_motion is not None: + mask_refmotion[:, :, :((continue_motion.shape[0] - 1) // 4) + 1] = 0.0 + + mask = torch.cat((mask, mask_refmotion), dim=2) + positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent_image, "concat_mask": mask}) + negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent_image, "concat_mask": mask}) + + latent = torch.zeros([batch_size, 16, latent_length + trim_latent, latent_height, latent_width], device=comfy.model_management.intermediate_device()) + out_latent = {} + out_latent["samples"] = latent + return io.NodeOutput(positive, negative, out_latent, trim_latent) + class Wan22ImageToVideoLatent(io.ComfyNode): @classmethod def define_schema(cls): @@ -1169,6 +1252,7 @@ class WanExtension(ComfyExtension): 
WanSoundImageToVideo, WanSoundImageToVideoExtend, WanHuMoImageToVideo, + WanAnimateToVideo, Wan22ImageToVideoLatent, ] From 9fdf8c25abb2133803063a9be395cac774fce611 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Fri, 19 Sep 2025 23:02:43 +0300 Subject: [PATCH 223/325] api_nodes: reduce default timeout from 7 days to 2 hours (#9918) --- comfy_api_nodes/apis/client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy_api_nodes/apis/client.py b/comfy_api_nodes/apis/client.py index 4ad0b783b..0aed906fb 100644 --- a/comfy_api_nodes/apis/client.py +++ b/comfy_api_nodes/apis/client.py @@ -683,7 +683,7 @@ class SynchronousOperation(Generic[T, R]): auth_token: Optional[str] = None, comfy_api_key: Optional[str] = None, auth_kwargs: Optional[Dict[str, str]] = None, - timeout: float = 604800.0, + timeout: float = 7200.0, verify_ssl: bool = True, content_type: str = "application/json", multipart_parser: Callable | None = None, From 852704c81a652cc53fbe53c5f47dea0e50d0534e Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Fri, 19 Sep 2025 23:04:51 +0300 Subject: [PATCH 224/325] fix(seedream4): add flag to ignore error on partial success (#9952) --- comfy_api_nodes/nodes_bytedance.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/comfy_api_nodes/nodes_bytedance.py b/comfy_api_nodes/nodes_bytedance.py index 369a3a4fe..a7eeaf15a 100644 --- a/comfy_api_nodes/nodes_bytedance.py +++ b/comfy_api_nodes/nodes_bytedance.py @@ -567,6 +567,12 @@ class ByteDanceSeedreamNode(comfy_io.ComfyNode): tooltip="Whether to add an \"AI generated\" watermark to the image.", optional=True, ), + comfy_io.Boolean.Input( + "fail_on_partial", + default=True, + tooltip="If enabled, abort execution if any requested images are missing or return an error.", + optional=True, + ), ], outputs=[ comfy_io.Image.Output(), @@ -592,6 +598,7 @@ class ByteDanceSeedreamNode(comfy_io.ComfyNode): max_images: int = 1, seed: int = 0, watermark: bool = True, + fail_on_partial: bool = True, ) -> comfy_io.NodeOutput: validate_string(prompt, strip_whitespace=True, min_length=1) w = h = None @@ -651,9 +658,10 @@ class ByteDanceSeedreamNode(comfy_io.ComfyNode): if len(response.data) == 1: return comfy_io.NodeOutput(await download_url_to_image_tensor(get_image_url_from_response(response))) - return comfy_io.NodeOutput( - torch.cat([await download_url_to_image_tensor(str(i["url"])) for i in response.data]) - ) + urls = [str(d["url"]) for d in response.data if isinstance(d, dict) and "url" in d] + if fail_on_partial and len(urls) < len(response.data): + raise RuntimeError(f"Only {len(urls)} of {len(response.data)} images were generated before error.") + return comfy_io.NodeOutput(torch.cat([await download_url_to_image_tensor(i) for i in urls])) class ByteDanceTextToVideoNode(comfy_io.ComfyNode): @@ -1171,7 +1179,7 @@ async def process_video_task( payload: Union[Text2VideoTaskCreationRequest, Image2VideoTaskCreationRequest], auth_kwargs: dict, node_id: str, - estimated_duration: int | None, + estimated_duration: Optional[int], ) -> comfy_io.NodeOutput: initial_response = await SynchronousOperation( endpoint=ApiEndpoint( From e8df53b764c7dfce1a9235f6ee70a17cfdece3ff Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Fri, 19 Sep 2025 15:48:56 -0700 Subject: [PATCH 225/325] Update WanAnimateToVideo to more easily extend videos. 
(#9959) --- comfy/ldm/wan/model_animate.py | 2 +- comfy_extras/nodes_wan.py | 63 +++++++++++++++++++++++++--------- 2 files changed, 47 insertions(+), 18 deletions(-) diff --git a/comfy/ldm/wan/model_animate.py b/comfy/ldm/wan/model_animate.py index 542f54110..7c87835d4 100644 --- a/comfy/ldm/wan/model_animate.py +++ b/comfy/ldm/wan/model_animate.py @@ -451,7 +451,7 @@ class AnimateWanModel(WanModel): def after_patch_embedding(self, x, pose_latents, face_pixel_values): if pose_latents is not None: pose_latents = self.pose_patch_embedding(pose_latents) - x[:, :, 1:] += pose_latents + x[:, :, 1:pose_latents.shape[2] + 1] += pose_latents[:, :, :x.shape[2] - 1] if face_pixel_values is None: return x, None diff --git a/comfy_extras/nodes_wan.py b/comfy_extras/nodes_wan.py index 4187a5619..3e5fef535 100644 --- a/comfy_extras/nodes_wan.py +++ b/comfy_extras/nodes_wan.py @@ -1128,18 +1128,22 @@ class WanAnimateToVideo(io.ComfyNode): io.Image.Input("pose_video", optional=True), io.Int.Input("continue_motion_max_frames", default=5, min=1, max=nodes.MAX_RESOLUTION, step=4), io.Image.Input("continue_motion", optional=True), + io.Int.Input("video_frame_offset", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1, tooltip="The amount of frames to seek in all the input videos. Used for generating longer videos by chunk. Connect to the video_frame_offset output of the previous node for extending a video."), ], outputs=[ io.Conditioning.Output(display_name="positive"), io.Conditioning.Output(display_name="negative"), io.Latent.Output(display_name="latent"), io.Int.Output(display_name="trim_latent"), + io.Int.Output(display_name="trim_image"), + io.Int.Output(display_name="video_frame_offset"), ], is_experimental=True, ) @classmethod - def execute(cls, positive, negative, vae, width, height, length, batch_size, continue_motion_max_frames, reference_image=None, clip_vision_output=None, face_video=None, pose_video=None, continue_motion=None) -> io.NodeOutput: + def execute(cls, positive, negative, vae, width, height, length, batch_size, continue_motion_max_frames, video_frame_offset, reference_image=None, clip_vision_output=None, face_video=None, pose_video=None, continue_motion=None) -> io.NodeOutput: + trim_to_pose_video = False latent_length = ((length - 1) // 4) + 1 latent_width = width // 8 latent_height = height // 8 @@ -1152,35 +1156,60 @@ class WanAnimateToVideo(io.ComfyNode): concat_latent_image = vae.encode(image[:, :, :, :3]) mask = torch.zeros((1, 1, concat_latent_image.shape[2], concat_latent_image.shape[-2], concat_latent_image.shape[-1]), device=concat_latent_image.device, dtype=concat_latent_image.dtype) trim_latent += concat_latent_image.shape[2] + ref_motion_latent_length = 0 + + if continue_motion is None: + image = torch.ones((length, height, width, 3)) * 0.5 + else: + continue_motion = continue_motion[-continue_motion_max_frames:] + video_frame_offset -= continue_motion.shape[0] + video_frame_offset = max(0, video_frame_offset) + continue_motion = comfy.utils.common_upscale(continue_motion[-length:].movedim(-1, 1), width, height, "area", "center").movedim(1, -1) + image = torch.ones((length, height, width, continue_motion.shape[-1]), device=continue_motion.device, dtype=continue_motion.dtype) * 0.5 + image[:continue_motion.shape[0]] = continue_motion + ref_motion_latent_length += ((continue_motion.shape[0] - 1) // 4) + 1 if clip_vision_output is not None: positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output}) negative = 
node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output}) + if pose_video is not None: + if pose_video.shape[0] <= video_frame_offset: + pose_video = None + else: + pose_video = pose_video[video_frame_offset:] + + if pose_video is not None: + pose_video = comfy.utils.common_upscale(pose_video[:length].movedim(-1, 1), width, height, "area", "center").movedim(1, -1) + if not trim_to_pose_video: + if pose_video.shape[0] < length: + pose_video = torch.cat((pose_video,) + (pose_video[-1:],) * (length - pose_video.shape[0]), dim=0) + + pose_video_latent = vae.encode(pose_video[:, :, :, :3]) + positive = node_helpers.conditioning_set_values(positive, {"pose_video_latent": pose_video_latent}) + negative = node_helpers.conditioning_set_values(negative, {"pose_video_latent": pose_video_latent}) + + if trim_to_pose_video: + latent_length = pose_video_latent.shape[2] + length = latent_length * 4 - 3 + image = image[:length] + + if face_video is not None: + if face_video.shape[0] <= video_frame_offset: + face_video = None + else: + face_video = face_video[video_frame_offset:] + if face_video is not None: face_video = comfy.utils.common_upscale(face_video[:length].movedim(-1, 1), 512, 512, "area", "center") * 2.0 - 1.0 face_video = face_video.movedim(0, 1).unsqueeze(0) positive = node_helpers.conditioning_set_values(positive, {"face_video_pixels": face_video}) negative = node_helpers.conditioning_set_values(negative, {"face_video_pixels": face_video * 0.0 - 1.0}) - if pose_video is not None: - pose_video = comfy.utils.common_upscale(pose_video[:length].movedim(-1, 1), width, height, "area", "center").movedim(1, -1) - pose_video_latent = vae.encode(pose_video[:, :, :, :3]) - positive = node_helpers.conditioning_set_values(positive, {"pose_video_latent": pose_video_latent}) - negative = node_helpers.conditioning_set_values(negative, {"pose_video_latent": pose_video_latent}) - - if continue_motion is None: - image = torch.ones((length, height, width, 3)) * 0.5 - else: - continue_motion = continue_motion[-continue_motion_max_frames:] - continue_motion = comfy.utils.common_upscale(continue_motion[-length:].movedim(-1, 1), width, height, "area", "center").movedim(1, -1) - image = torch.ones((length, height, width, continue_motion.shape[-1]), device=continue_motion.device, dtype=continue_motion.dtype) * 0.5 - image[:continue_motion.shape[0]] = continue_motion - concat_latent_image = torch.cat((concat_latent_image, vae.encode(image[:, :, :, :3])), dim=2) mask_refmotion = torch.ones((1, 1, latent_length, concat_latent_image.shape[-2], concat_latent_image.shape[-1]), device=mask.device, dtype=mask.dtype) if continue_motion is not None: - mask_refmotion[:, :, :((continue_motion.shape[0] - 1) // 4) + 1] = 0.0 + mask_refmotion[:, :, :ref_motion_latent_length] = 0.0 mask = torch.cat((mask, mask_refmotion), dim=2) positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent_image, "concat_mask": mask}) @@ -1189,7 +1218,7 @@ class WanAnimateToVideo(io.ComfyNode): latent = torch.zeros([batch_size, 16, latent_length + trim_latent, latent_height, latent_width], device=comfy.model_management.intermediate_device()) out_latent = {} out_latent["samples"] = latent - return io.NodeOutput(positive, negative, out_latent, trim_latent) + return io.NodeOutput(positive, negative, out_latent, trim_latent, max(0, ref_motion_latent_length * 4 - 3), video_frame_offset + length) class Wan22ImageToVideoLatent(io.ComfyNode): @classmethod From 
66241cef31f21247ec8b450d699250fd83b3ff7c Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Fri, 19 Sep 2025 23:24:10 -0700 Subject: [PATCH 226/325] Add inputs for character replacement to the WanAnimateToVideo node. (#9960) --- comfy_extras/nodes_wan.py | 40 +++++++++++++++++++++++++++++++++------ 1 file changed, 34 insertions(+), 6 deletions(-) diff --git a/comfy_extras/nodes_wan.py b/comfy_extras/nodes_wan.py index 3e5fef535..9cca6fb2e 100644 --- a/comfy_extras/nodes_wan.py +++ b/comfy_extras/nodes_wan.py @@ -1127,6 +1127,8 @@ class WanAnimateToVideo(io.ComfyNode): io.Image.Input("face_video", optional=True), io.Image.Input("pose_video", optional=True), io.Int.Input("continue_motion_max_frames", default=5, min=1, max=nodes.MAX_RESOLUTION, step=4), + io.Image.Input("background_video", optional=True), + io.Mask.Input("character_mask", optional=True), io.Image.Input("continue_motion", optional=True), io.Int.Input("video_frame_offset", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1, tooltip="The amount of frames to seek in all the input videos. Used for generating longer videos by chunk. Connect to the video_frame_offset output of the previous node for extending a video."), ], @@ -1142,7 +1144,7 @@ class WanAnimateToVideo(io.ComfyNode): ) @classmethod - def execute(cls, positive, negative, vae, width, height, length, batch_size, continue_motion_max_frames, video_frame_offset, reference_image=None, clip_vision_output=None, face_video=None, pose_video=None, continue_motion=None) -> io.NodeOutput: + def execute(cls, positive, negative, vae, width, height, length, batch_size, continue_motion_max_frames, video_frame_offset, reference_image=None, clip_vision_output=None, face_video=None, pose_video=None, continue_motion=None, background_video=None, character_mask=None) -> io.NodeOutput: trim_to_pose_video = False latent_length = ((length - 1) // 4) + 1 latent_width = width // 8 @@ -1154,7 +1156,7 @@ class WanAnimateToVideo(io.ComfyNode): image = comfy.utils.common_upscale(reference_image[:length].movedim(-1, 1), width, height, "area", "center").movedim(1, -1) concat_latent_image = vae.encode(image[:, :, :, :3]) - mask = torch.zeros((1, 1, concat_latent_image.shape[2], concat_latent_image.shape[-2], concat_latent_image.shape[-1]), device=concat_latent_image.device, dtype=concat_latent_image.dtype) + mask = torch.zeros((1, 4, concat_latent_image.shape[-3], concat_latent_image.shape[-2], concat_latent_image.shape[-1]), device=concat_latent_image.device, dtype=concat_latent_image.dtype) trim_latent += concat_latent_image.shape[2] ref_motion_latent_length = 0 @@ -1206,11 +1208,37 @@ class WanAnimateToVideo(io.ComfyNode): positive = node_helpers.conditioning_set_values(positive, {"face_video_pixels": face_video}) negative = node_helpers.conditioning_set_values(negative, {"face_video_pixels": face_video * 0.0 - 1.0}) - concat_latent_image = torch.cat((concat_latent_image, vae.encode(image[:, :, :, :3])), dim=2) - mask_refmotion = torch.ones((1, 1, latent_length, concat_latent_image.shape[-2], concat_latent_image.shape[-1]), device=mask.device, dtype=mask.dtype) - if continue_motion is not None: - mask_refmotion[:, :, :ref_motion_latent_length] = 0.0 + ref_images_num = max(0, ref_motion_latent_length * 4 - 3) + if background_video is not None: + if background_video.shape[0] > video_frame_offset: + background_video = background_video[video_frame_offset:] + background_video = comfy.utils.common_upscale(background_video[:length].movedim(-1, 1), width, 
height, "area", "center").movedim(1, -1) + if background_video.shape[0] > ref_images_num: + image[ref_images_num:background_video.shape[0] - ref_images_num] = background_video[ref_images_num:] + mask_refmotion = torch.ones((1, 1, latent_length * 4, concat_latent_image.shape[-2], concat_latent_image.shape[-1]), device=mask.device, dtype=mask.dtype) + if continue_motion is not None: + mask_refmotion[:, :, :ref_motion_latent_length * 4] = 0.0 + + if character_mask is not None: + if character_mask.shape[0] > video_frame_offset or character_mask.shape[0] == 1: + if character_mask.shape[0] == 1: + character_mask = character_mask.repeat((length,) + (1,) * (character_mask.ndim - 1)) + else: + character_mask = character_mask[video_frame_offset:] + if character_mask.ndim == 3: + character_mask = character_mask.unsqueeze(1) + character_mask = character_mask.movedim(0, 1) + if character_mask.ndim == 4: + character_mask = character_mask.unsqueeze(1) + character_mask = comfy.utils.common_upscale(character_mask[:, :, :length], concat_latent_image.shape[-1], concat_latent_image.shape[-2], "nearest-exact", "center") + if character_mask.shape[2] > ref_images_num: + mask_refmotion[:, :, ref_images_num:character_mask.shape[2] + ref_images_num] = character_mask[:, :, ref_images_num:] + + concat_latent_image = torch.cat((concat_latent_image, vae.encode(image[:, :, :, :3])), dim=2) + + + mask_refmotion = mask_refmotion.view(1, mask_refmotion.shape[2] // 4, 4, mask_refmotion.shape[3], mask_refmotion.shape[4]).transpose(1, 2) mask = torch.cat((mask, mask_refmotion), dim=2) positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent_image, "concat_mask": mask}) negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent_image, "concat_mask": mask}) From 9ed3c5cc09c55d2fffa67b59d9d21e3b44d7653e Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Sat, 20 Sep 2025 18:10:39 -0700 Subject: [PATCH 227/325] [Reviving #5709] Add strength input to Differential Diffusion (#9957) * Update nodes_differential_diffusion.py * Update nodes_differential_diffusion.py * Make strength optional to avoid validation errors when loading old workflows, adjust step --------- Co-authored-by: ThereforeGames --- comfy_extras/nodes_differential_diffusion.py | 33 +++++++++++++++----- 1 file changed, 26 insertions(+), 7 deletions(-) diff --git a/comfy_extras/nodes_differential_diffusion.py b/comfy_extras/nodes_differential_diffusion.py index 98dbbf102..255ac420d 100644 --- a/comfy_extras/nodes_differential_diffusion.py +++ b/comfy_extras/nodes_differential_diffusion.py @@ -5,19 +5,30 @@ import torch class DifferentialDiffusion(): @classmethod def INPUT_TYPES(s): - return {"required": {"model": ("MODEL", ), - }} + return { + "required": { + "model": ("MODEL", ), + }, + "optional": { + "strength": ("FLOAT", { + "default": 1.0, + "min": 0.0, + "max": 1.0, + "step": 0.01, + }), + } + } RETURN_TYPES = ("MODEL",) FUNCTION = "apply" CATEGORY = "_for_testing" INIT = False - def apply(self, model): + def apply(self, model, strength=1.0): model = model.clone() - model.set_model_denoise_mask_function(self.forward) - return (model,) + model.set_model_denoise_mask_function(lambda *args, **kwargs: self.forward(*args, **kwargs, strength=strength)) + return (model, ) - def forward(self, sigma: torch.Tensor, denoise_mask: torch.Tensor, extra_options: dict): + def forward(self, sigma: torch.Tensor, denoise_mask: torch.Tensor, extra_options: dict, strength: float): model = 
extra_options["model"] step_sigmas = extra_options["sigmas"] sigma_to = model.inner_model.model_sampling.sigma_min @@ -31,7 +42,15 @@ class DifferentialDiffusion(): threshold = (current_ts - ts_to) / (ts_from - ts_to) - return (denoise_mask >= threshold).to(denoise_mask.dtype) + # Generate the binary mask based on the threshold + binary_mask = (denoise_mask >= threshold).to(denoise_mask.dtype) + + # Blend binary mask with the original denoise_mask using strength + if strength and strength < 1: + blended_mask = strength * binary_mask + (1 - strength) * denoise_mask + return blended_mask + else: + return binary_mask NODE_CLASS_MAPPINGS = { From 7be2b49b6b3430783555bc6bc8fcb3f46d5392e7 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Sun, 21 Sep 2025 09:24:48 +0800 Subject: [PATCH 228/325] Fix LoRA Trainer bugs with FP8 models. (#9854) * Fix adapter weight init * Fix fp8 model training * Avoid inference tensor --- comfy/ops.py | 13 +++++++------ comfy/weight_adapter/loha.py | 8 ++++---- comfy/weight_adapter/lokr.py | 4 ++-- comfy/weight_adapter/lora.py | 4 ++-- comfy/weight_adapter/oft.py | 2 +- comfy_extras/nodes_train.py | 18 ++++++++++++++++++ 6 files changed, 34 insertions(+), 15 deletions(-) diff --git a/comfy/ops.py b/comfy/ops.py index 55e958adb..9d7dedd37 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -365,12 +365,13 @@ class fp8_ops(manual_cast): return None def forward_comfy_cast_weights(self, input): - try: - out = fp8_linear(self, input) - if out is not None: - return out - except Exception as e: - logging.info("Exception during fp8 op: {}".format(e)) + if not self.training: + try: + out = fp8_linear(self, input) + if out is not None: + return out + except Exception as e: + logging.info("Exception during fp8 op: {}".format(e)) weight, bias = cast_bias_weight(self, input) return torch.nn.functional.linear(input, weight, bias) diff --git a/comfy/weight_adapter/loha.py b/comfy/weight_adapter/loha.py index 55c97a3af..0abb2d403 100644 --- a/comfy/weight_adapter/loha.py +++ b/comfy/weight_adapter/loha.py @@ -130,12 +130,12 @@ class LoHaAdapter(WeightAdapterBase): def create_train(cls, weight, rank=1, alpha=1.0): out_dim = weight.shape[0] in_dim = weight.shape[1:].numel() - mat1 = torch.empty(out_dim, rank, device=weight.device, dtype=weight.dtype) - mat2 = torch.empty(rank, in_dim, device=weight.device, dtype=weight.dtype) + mat1 = torch.empty(out_dim, rank, device=weight.device, dtype=torch.float32) + mat2 = torch.empty(rank, in_dim, device=weight.device, dtype=torch.float32) torch.nn.init.normal_(mat1, 0.1) torch.nn.init.constant_(mat2, 0.0) - mat3 = torch.empty(out_dim, rank, device=weight.device, dtype=weight.dtype) - mat4 = torch.empty(rank, in_dim, device=weight.device, dtype=weight.dtype) + mat3 = torch.empty(out_dim, rank, device=weight.device, dtype=torch.float32) + mat4 = torch.empty(rank, in_dim, device=weight.device, dtype=torch.float32) torch.nn.init.normal_(mat3, 0.1) torch.nn.init.normal_(mat4, 0.01) return LohaDiff( diff --git a/comfy/weight_adapter/lokr.py b/comfy/weight_adapter/lokr.py index 563c835f5..9b2aff2d7 100644 --- a/comfy/weight_adapter/lokr.py +++ b/comfy/weight_adapter/lokr.py @@ -89,8 +89,8 @@ class LoKrAdapter(WeightAdapterBase): in_dim = weight.shape[1:].numel() out1, out2 = factorization(out_dim, rank) in1, in2 = factorization(in_dim, rank) - mat1 = torch.empty(out1, in1, device=weight.device, dtype=weight.dtype) - mat2 = torch.empty(out2, in2, device=weight.device, dtype=weight.dtype) + mat1 = 
torch.empty(out1, in1, device=weight.device, dtype=torch.float32) + mat2 = torch.empty(out2, in2, device=weight.device, dtype=torch.float32) torch.nn.init.kaiming_uniform_(mat2, a=5**0.5) torch.nn.init.constant_(mat1, 0.0) return LokrDiff( diff --git a/comfy/weight_adapter/lora.py b/comfy/weight_adapter/lora.py index 47aa17d13..4db004e50 100644 --- a/comfy/weight_adapter/lora.py +++ b/comfy/weight_adapter/lora.py @@ -66,8 +66,8 @@ class LoRAAdapter(WeightAdapterBase): def create_train(cls, weight, rank=1, alpha=1.0): out_dim = weight.shape[0] in_dim = weight.shape[1:].numel() - mat1 = torch.empty(out_dim, rank, device=weight.device, dtype=weight.dtype) - mat2 = torch.empty(rank, in_dim, device=weight.device, dtype=weight.dtype) + mat1 = torch.empty(out_dim, rank, device=weight.device, dtype=torch.float32) + mat2 = torch.empty(rank, in_dim, device=weight.device, dtype=torch.float32) torch.nn.init.kaiming_uniform_(mat1, a=5**0.5) torch.nn.init.constant_(mat2, 0.0) return LoraDiff( diff --git a/comfy/weight_adapter/oft.py b/comfy/weight_adapter/oft.py index 9d4982083..c0aab9635 100644 --- a/comfy/weight_adapter/oft.py +++ b/comfy/weight_adapter/oft.py @@ -68,7 +68,7 @@ class OFTAdapter(WeightAdapterBase): def create_train(cls, weight, rank=1, alpha=1.0): out_dim = weight.shape[0] block_size, block_num = factorization(out_dim, rank) - block = torch.zeros(block_num, block_size, block_size, device=weight.device, dtype=weight.dtype) + block = torch.zeros(block_num, block_size, block_size, device=weight.device, dtype=torch.float32) return OFTDiff( (block, None, alpha, None) ) diff --git a/comfy_extras/nodes_train.py b/comfy_extras/nodes_train.py index c3aaaee9b..9e6ec6780 100644 --- a/comfy_extras/nodes_train.py +++ b/comfy_extras/nodes_train.py @@ -38,6 +38,23 @@ def make_batch_extra_option_dict(d, indicies, full_size=None): return new_dict +def process_cond_list(d, prefix=""): + if hasattr(d, "__iter__") and not hasattr(d, "items"): + for index, item in enumerate(d): + process_cond_list(item, f"{prefix}.{index}") + return d + elif hasattr(d, "items"): + for k, v in list(d.items()): + if isinstance(v, dict): + process_cond_list(v, f"{prefix}.{k}") + elif isinstance(v, torch.Tensor): + d[k] = v.clone() + elif isinstance(v, (list, tuple)): + for index, item in enumerate(v): + process_cond_list(item, f"{prefix}.{k}.{index}") + return d + + class TrainSampler(comfy.samplers.Sampler): def __init__(self, loss_fn, optimizer, loss_callback=None, batch_size=1, grad_acc=1, total_steps=1, seed=0, training_dtype=torch.bfloat16): self.loss_fn = loss_fn @@ -50,6 +67,7 @@ class TrainSampler(comfy.samplers.Sampler): self.training_dtype = training_dtype def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False): + model_wrap.conds = process_cond_list(model_wrap.conds) cond = model_wrap.conds["positive"] dataset_size = sigmas.size(0) torch.cuda.empty_cache() From d1d9eb94b1096c9b3f963bf152bd6b9cd330c3a4 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sat, 20 Sep 2025 19:09:35 -0700 Subject: [PATCH 229/325] Lower wan memory estimation value a bit. (#9964) Previous pr reduced the peak memory requirement. 
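For a rough sense of the effect (a sketch only; the dim values below are assumed examples for the large and small Wan variants, not values read from this patch):

    # Relative memory factor before and after this change (illustrative).
    for dim in (5120, 1536):   # assumed example dims, not taken from the config
        print(dim, dim / 2000, dim / 2222)
    # 5120 -> 2.56  vs ~2.30
    # 1536 -> 0.768 vs ~0.69

Both the class-level default (1.0 -> 0.9) and the per-dim formula (dim / 2000 -> dim / 2222) scale the estimate down by roughly 10%.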
--- comfy/supported_models.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 1fbb6aef4..4064bdae1 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -995,7 +995,7 @@ class WAN21_T2V(supported_models_base.BASE): unet_extra_config = {} latent_format = latent_formats.Wan21 - memory_usage_factor = 1.0 + memory_usage_factor = 0.9 supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32] @@ -1004,7 +1004,7 @@ class WAN21_T2V(supported_models_base.BASE): def __init__(self, unet_config): super().__init__(unet_config) - self.memory_usage_factor = self.unet_config.get("dim", 2000) / 2000 + self.memory_usage_factor = self.unet_config.get("dim", 2000) / 2222 def get_model(self, state_dict, prefix="", device=None): out = model_base.WAN21(self, device=device) From 27bc181c49249f11da2d8a14f84f3bdb58a0615f Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sun, 21 Sep 2025 16:48:31 -0700 Subject: [PATCH 230/325] Set some wan nodes as no longer experimental. (#9976) --- comfy_extras/nodes_wan.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/comfy_extras/nodes_wan.py b/comfy_extras/nodes_wan.py index 9cca6fb2e..b1e9babb5 100644 --- a/comfy_extras/nodes_wan.py +++ b/comfy_extras/nodes_wan.py @@ -287,7 +287,6 @@ class WanVaceToVideo(io.ComfyNode): return io.Schema( node_id="WanVaceToVideo", category="conditioning/video_models", - is_experimental=True, inputs=[ io.Conditioning.Input("positive"), io.Conditioning.Input("negative"), @@ -375,7 +374,6 @@ class TrimVideoLatent(io.ComfyNode): return io.Schema( node_id="TrimVideoLatent", category="latent/video", - is_experimental=True, inputs=[ io.Latent.Input("samples"), io.Int.Input("trim_amount", default=0, min=0, max=99999), @@ -969,7 +967,6 @@ class WanSoundImageToVideo(io.ComfyNode): io.Conditioning.Output(display_name="negative"), io.Latent.Output(display_name="latent"), ], - is_experimental=True, ) @classmethod @@ -1000,7 +997,6 @@ class WanSoundImageToVideoExtend(io.ComfyNode): io.Conditioning.Output(display_name="negative"), io.Latent.Output(display_name="latent"), ], - is_experimental=True, ) @classmethod From 1fee8827cb8160c85d96c375413ac590311525dc Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 22 Sep 2025 13:49:48 -0700 Subject: [PATCH 231/325] Support for qwen edit plus model. Use the new TextEncodeQwenImageEditPlus. 
(#9986) --- comfy/text_encoders/llama.py | 16 +++++++---- comfy_extras/nodes_qwen.py | 55 ++++++++++++++++++++++++++++++++++++ 2 files changed, 65 insertions(+), 6 deletions(-) diff --git a/comfy/text_encoders/llama.py b/comfy/text_encoders/llama.py index 5e11956b5..c5a48ba9f 100644 --- a/comfy/text_encoders/llama.py +++ b/comfy/text_encoders/llama.py @@ -400,21 +400,25 @@ class Qwen25_7BVLI(BaseLlama, torch.nn.Module): def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, embeds_info=[]): grid = None + position_ids = None + offset = 0 for e in embeds_info: if e.get("type") == "image": grid = e.get("extra", None) - position_ids = torch.zeros((3, embeds.shape[1]), device=embeds.device) start = e.get("index") - position_ids[:, :start] = torch.arange(0, start, device=embeds.device) + if position_ids is None: + position_ids = torch.zeros((3, embeds.shape[1]), device=embeds.device) + position_ids[:, :start] = torch.arange(0, start, device=embeds.device) end = e.get("size") + start len_max = int(grid.max()) // 2 start_next = len_max + start - position_ids[:, end:] = torch.arange(start_next, start_next + (embeds.shape[1] - end), device=embeds.device) - position_ids[0, start:end] = start + position_ids[:, end:] = torch.arange(start_next + offset, start_next + (embeds.shape[1] - end) + offset, device=embeds.device) + position_ids[0, start:end] = start + offset max_d = int(grid[0][1]) // 2 - position_ids[1, start:end] = torch.arange(start, start + max_d, device=embeds.device).unsqueeze(1).repeat(1, math.ceil((end - start) / max_d)).flatten(0)[:end - start] + position_ids[1, start:end] = torch.arange(start + offset, start + max_d + offset, device=embeds.device).unsqueeze(1).repeat(1, math.ceil((end - start) / max_d)).flatten(0)[:end - start] max_d = int(grid[0][2]) // 2 - position_ids[2, start:end] = torch.arange(start, start + max_d, device=embeds.device).unsqueeze(0).repeat(math.ceil((end - start) / max_d), 1).flatten(0)[:end - start] + position_ids[2, start:end] = torch.arange(start + offset, start + max_d + offset, device=embeds.device).unsqueeze(0).repeat(math.ceil((end - start) / max_d), 1).flatten(0)[:end - start] + offset += len_max - (end - start) if grid is None: position_ids = None diff --git a/comfy_extras/nodes_qwen.py b/comfy_extras/nodes_qwen.py index fff89556f..49747dc7a 100644 --- a/comfy_extras/nodes_qwen.py +++ b/comfy_extras/nodes_qwen.py @@ -43,6 +43,61 @@ class TextEncodeQwenImageEdit: return (conditioning, ) +class TextEncodeQwenImageEditPlus: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "clip": ("CLIP", ), + "prompt": ("STRING", {"multiline": True, "dynamicPrompts": True}), + }, + "optional": {"vae": ("VAE", ), + "image1": ("IMAGE", ), + "image2": ("IMAGE", ), + "image3": ("IMAGE", ), + }} + + RETURN_TYPES = ("CONDITIONING",) + FUNCTION = "encode" + + CATEGORY = "advanced/conditioning" + + def encode(self, clip, prompt, vae=None, image1=None, image2=None, image3=None): + ref_latents = [] + images = [image1, image2, image3] + images_vl = [] + llama_template = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. 
Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n" + image_prompt = "" + + for i, image in enumerate(images): + if image is not None: + samples = image.movedim(-1, 1) + total = int(384 * 384) + + scale_by = math.sqrt(total / (samples.shape[3] * samples.shape[2])) + width = round(samples.shape[3] * scale_by) + height = round(samples.shape[2] * scale_by) + + s = comfy.utils.common_upscale(samples, width, height, "area", "disabled") + images_vl.append(s.movedim(1, -1)) + if vae is not None: + total = int(1024 * 1024) + scale_by = math.sqrt(total / (samples.shape[3] * samples.shape[2])) + width = round(samples.shape[3] * scale_by / 8.0) * 8 + height = round(samples.shape[2] * scale_by / 8.0) * 8 + + s = comfy.utils.common_upscale(samples, width, height, "area", "disabled") + ref_latents.append(vae.encode(s.movedim(1, -1)[:, :, :, :3])) + + image_prompt += "Picture {}: <|vision_start|><|image_pad|><|vision_end|>".format(i + 1) + + tokens = clip.tokenize(image_prompt + prompt, images=images_vl, llama_template=llama_template) + conditioning = clip.encode_from_tokens_scheduled(tokens) + if len(ref_latents) > 0: + conditioning = node_helpers.conditioning_set_values(conditioning, {"reference_latents": ref_latents}, append=True) + return (conditioning, ) + + NODE_CLASS_MAPPINGS = { "TextEncodeQwenImageEdit": TextEncodeQwenImageEdit, + "TextEncodeQwenImageEditPlus": TextEncodeQwenImageEditPlus, } From e3206351b07852f2127a56abd898ee77f7f4c25f Mon Sep 17 00:00:00 2001 From: Christian Byrne Date: Mon, 22 Sep 2025 14:12:32 -0700 Subject: [PATCH 232/325] add offset param (#9977) --- server.py | 9 ++- tests/execution/test_execution.py | 105 +++++++++++++++++++++++++++++- 2 files changed, 112 insertions(+), 2 deletions(-) diff --git a/server.py b/server.py index 43816a8cd..603677397 100644 --- a/server.py +++ b/server.py @@ -645,7 +645,14 @@ class PromptServer(): max_items = request.rel_url.query.get("max_items", None) if max_items is not None: max_items = int(max_items) - return web.json_response(self.prompt_queue.get_history(max_items=max_items)) + + offset = request.rel_url.query.get("offset", None) + if offset is not None: + offset = int(offset) + else: + offset = -1 + + return web.json_response(self.prompt_queue.get_history(max_items=max_items, offset=offset)) @routes.get("/history/{prompt_id}") async def get_history_prompt_id(request): diff --git a/tests/execution/test_execution.py b/tests/execution/test_execution.py index 8ea05fdd8..ef73ad9fd 100644 --- a/tests/execution/test_execution.py +++ b/tests/execution/test_execution.py @@ -84,6 +84,21 @@ class ComfyClient: with urllib.request.urlopen("http://{}/history/{}".format(self.server_address, prompt_id)) as response: return json.loads(response.read()) + def get_all_history(self, max_items=None, offset=None): + url = "http://{}/history".format(self.server_address) + params = {} + if max_items is not None: + params["max_items"] = max_items + if offset is not None: + params["offset"] = offset + + if params: + url_values = urllib.parse.urlencode(params) + url = "{}?{}".format(url, url_values) + + with urllib.request.urlopen(url) as response: + return json.loads(response.read()) + def set_test_name(self, name): self.test_name = name @@ -498,7 +513,6 @@ class TestExecution: assert len(images1) == 1, "Should have 1 image" assert len(images2) == 1, "Should have 1 image" - # This tests that only constant outputs are 
used in the call to `IS_CHANGED` def test_is_changed_with_outputs(self, client: ComfyClient, builder: GraphBuilder): g = builder @@ -762,3 +776,92 @@ class TestExecution: except urllib.error.HTTPError: pass # Expected behavior + def _create_history_item(self, client, builder): + g = GraphBuilder(prefix="offset_test") + input_node = g.node( + "StubImage", content="BLACK", height=32, width=32, batch_size=1 + ) + g.node("SaveImage", images=input_node.out(0)) + return client.run(g) + + def test_offset_returns_different_items_than_beginning_of_history( + self, client: ComfyClient, builder: GraphBuilder + ): + """Test that offset skips items at the beginning""" + for _ in range(5): + self._create_history_item(client, builder) + + first_two = client.get_all_history(max_items=2, offset=0) + next_two = client.get_all_history(max_items=2, offset=2) + + assert set(first_two.keys()).isdisjoint( + set(next_two.keys()) + ), "Offset should skip initial items" + + def test_offset_beyond_history_length_returns_empty( + self, client: ComfyClient, builder: GraphBuilder + ): + """Test offset larger than total history returns empty result""" + self._create_history_item(client, builder) + + result = client.get_all_history(offset=100) + assert len(result) == 0, "Large offset should return no items" + + def test_offset_at_exact_history_length_returns_empty( + self, client: ComfyClient, builder: GraphBuilder + ): + """Test offset equal to history length returns empty""" + for _ in range(3): + self._create_history_item(client, builder) + + all_history = client.get_all_history() + result = client.get_all_history(offset=len(all_history)) + assert len(result) == 0, "Offset at history length should return empty" + + def test_offset_zero_equals_no_offset_parameter( + self, client: ComfyClient, builder: GraphBuilder + ): + """Test offset=0 behaves same as omitting offset""" + self._create_history_item(client, builder) + + with_zero = client.get_all_history(offset=0) + without_offset = client.get_all_history() + + assert with_zero == without_offset, "offset=0 should equal no offset" + + def test_offset_without_max_items_skips_from_beginning( + self, client: ComfyClient, builder: GraphBuilder + ): + """Test offset alone (no max_items) returns remaining items""" + for _ in range(4): + self._create_history_item(client, builder) + + all_items = client.get_all_history() + offset_items = client.get_all_history(offset=2) + + assert ( + len(offset_items) == len(all_items) - 2 + ), "Offset should skip specified number of items" + + def test_offset_with_max_items_returns_correct_window( + self, client: ComfyClient, builder: GraphBuilder + ): + """Test offset + max_items returns correct slice of history""" + for _ in range(6): + self._create_history_item(client, builder) + + window = client.get_all_history(max_items=2, offset=1) + assert len(window) <= 2, "Should respect max_items limit" + + def test_offset_near_end_returns_remaining_items_only( + self, client: ComfyClient, builder: GraphBuilder + ): + """Test offset near end of history returns only remaining items""" + for _ in range(3): + self._create_history_item(client, builder) + + all_history = client.get_all_history() + # Offset to near the end + result = client.get_all_history(max_items=5, offset=len(all_history) - 1) + + assert len(result) <= 1, "Should return at most 1 item when offset is near end" From 8a5ac527e60fcd48ec228d309d49ab28ac79def8 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 22 Sep 2025 14:26:58 
-0700 Subject: [PATCH 233/325] Fix bug with WanAnimateToVideo node. (#9988) --- comfy_extras/nodes_wan.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy_extras/nodes_wan.py b/comfy_extras/nodes_wan.py index b1e9babb5..6c16a2673 100644 --- a/comfy_extras/nodes_wan.py +++ b/comfy_extras/nodes_wan.py @@ -1210,7 +1210,7 @@ class WanAnimateToVideo(io.ComfyNode): background_video = background_video[video_frame_offset:] background_video = comfy.utils.common_upscale(background_video[:length].movedim(-1, 1), width, height, "area", "center").movedim(1, -1) if background_video.shape[0] > ref_images_num: - image[ref_images_num:background_video.shape[0] - ref_images_num] = background_video[ref_images_num:] + image[ref_images_num:background_video.shape[0]] = background_video[ref_images_num:] mask_refmotion = torch.ones((1, 1, latent_length * 4, concat_latent_image.shape[-2], concat_latent_image.shape[-1]), device=mask.device, dtype=mask.dtype) if continue_motion is not None: From 707b2638ecd82360c0a67e1d86cc4fdeae218d03 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 22 Sep 2025 14:34:33 -0700 Subject: [PATCH 234/325] Fix bug with WanAnimateToVideo. (#9990) --- comfy_extras/nodes_wan.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy_extras/nodes_wan.py b/comfy_extras/nodes_wan.py index 6c16a2673..b0bd471bf 100644 --- a/comfy_extras/nodes_wan.py +++ b/comfy_extras/nodes_wan.py @@ -1229,7 +1229,7 @@ class WanAnimateToVideo(io.ComfyNode): character_mask = character_mask.unsqueeze(1) character_mask = comfy.utils.common_upscale(character_mask[:, :, :length], concat_latent_image.shape[-1], concat_latent_image.shape[-2], "nearest-exact", "center") if character_mask.shape[2] > ref_images_num: - mask_refmotion[:, :, ref_images_num:character_mask.shape[2] + ref_images_num] = character_mask[:, :, ref_images_num:] + mask_refmotion[:, :, ref_images_num:character_mask.shape[2]] = character_mask[:, :, ref_images_num:] concat_latent_image = torch.cat((concat_latent_image, vae.encode(image[:, :, :, :3])), dim=2) From 145b0e4f79b5d9e815bb781ba29ccd057bb52dab Mon Sep 17 00:00:00 2001 From: ComfyUI Wiki Date: Tue, 23 Sep 2025 23:22:35 +0800 Subject: [PATCH 235/325] update template to 0.1.86 (#9998) * update template to 0.1.84 * update template to 0.1.85 * Update template to 0.1.86 --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 79187efaa..2980bebdd 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.26.13 -comfyui-workflow-templates==0.1.81 +comfyui-workflow-templates==0.1.86 comfyui-embedded-docs==0.2.6 torch torchsde From e8087907995497c6971ee64bd5fa02cb49c1eda6 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Tue, 23 Sep 2025 18:36:47 +0300 Subject: [PATCH 236/325] feat(api-nodes): add wan t2i, t2v, i2v nodes (#9996) --- comfy_api_nodes/nodes_wan.py | 602 +++++++++++++++++++++++++++++++++++ nodes.py | 1 + 2 files changed, 603 insertions(+) create mode 100644 comfy_api_nodes/nodes_wan.py diff --git a/comfy_api_nodes/nodes_wan.py b/comfy_api_nodes/nodes_wan.py new file mode 100644 index 000000000..db5bd41c1 --- /dev/null +++ b/comfy_api_nodes/nodes_wan.py @@ -0,0 +1,602 @@ +import re +from typing import Optional, Type, Union +from typing_extensions import override + +import torch +from pydantic import BaseModel, Field +from comfy_api.latest 
import ComfyExtension, Input, io as comfy_io +from comfy_api_nodes.apis.client import ( + ApiEndpoint, + HttpMethod, + SynchronousOperation, + PollingOperation, + EmptyRequest, + R, + T, +) +from comfy_api_nodes.util.validation_utils import get_number_of_images, validate_audio_duration + +from comfy_api_nodes.apinode_utils import ( + download_url_to_image_tensor, + download_url_to_video_output, + tensor_to_base64_string, + audio_to_base64_string, +) + +class Text2ImageInputField(BaseModel): + prompt: str = Field(...) + negative_prompt: Optional[str] = Field(None) + + +class Text2VideoInputField(BaseModel): + prompt: str = Field(...) + negative_prompt: Optional[str] = Field(None) + audio_url: Optional[str] = Field(None) + + +class Image2VideoInputField(BaseModel): + prompt: str = Field(...) + negative_prompt: Optional[str] = Field(None) + img_url: str = Field(...) + audio_url: Optional[str] = Field(None) + + +class Txt2ImageParametersField(BaseModel): + size: str = Field(...) + n: int = Field(1, description="Number of images to generate.") # we support only value=1 + seed: int = Field(..., ge=0, le=2147483647) + prompt_extend: bool = Field(True) + watermark: bool = Field(True) + + +class Text2VideoParametersField(BaseModel): + size: str = Field(...) + seed: int = Field(..., ge=0, le=2147483647) + duration: int = Field(5, ge=5, le=10) + prompt_extend: bool = Field(True) + watermark: bool = Field(True) + audio: bool = Field(False, description="Should be audio generated automatically") + + +class Image2VideoParametersField(BaseModel): + resolution: str = Field(...) + seed: int = Field(..., ge=0, le=2147483647) + duration: int = Field(5, ge=5, le=10) + prompt_extend: bool = Field(True) + watermark: bool = Field(True) + audio: bool = Field(False, description="Should be audio generated automatically") + + +class Text2ImageTaskCreationRequest(BaseModel): + model: str = Field(...) + input: Text2ImageInputField = Field(...) + parameters: Txt2ImageParametersField = Field(...) + + +class Text2VideoTaskCreationRequest(BaseModel): + model: str = Field(...) + input: Text2VideoInputField = Field(...) + parameters: Text2VideoParametersField = Field(...) + + +class Image2VideoTaskCreationRequest(BaseModel): + model: str = Field(...) + input: Image2VideoInputField = Field(...) + parameters: Image2VideoParametersField = Field(...) + + +class TaskCreationOutputField(BaseModel): + task_id: str = Field(...) + task_status: str = Field(...) + + +class TaskCreationResponse(BaseModel): + output: Optional[TaskCreationOutputField] = Field(None) + request_id: str = Field(...) + code: Optional[str] = Field(None, description="The error code of the failed request.") + message: Optional[str] = Field(None, description="Details of the failed request.") + + +class TaskResult(BaseModel): + url: Optional[str] = Field(None) + code: Optional[str] = Field(None) + message: Optional[str] = Field(None) + + +class ImageTaskStatusOutputField(TaskCreationOutputField): + task_id: str = Field(...) + task_status: str = Field(...) + results: Optional[list[TaskResult]] = Field(None) + + +class VideoTaskStatusOutputField(TaskCreationOutputField): + task_id: str = Field(...) + task_status: str = Field(...) + video_url: Optional[str] = Field(None) + code: Optional[str] = Field(None) + message: Optional[str] = Field(None) + + +class ImageTaskStatusResponse(BaseModel): + output: Optional[ImageTaskStatusOutputField] = Field(None) + request_id: str = Field(...) 
+ + +class VideoTaskStatusResponse(BaseModel): + output: Optional[VideoTaskStatusOutputField] = Field(None) + request_id: str = Field(...) + + +RES_IN_PARENS = re.compile(r'\((\d+)\s*[x×]\s*(\d+)\)') + + +async def process_task( + auth_kwargs: dict[str, str], + url: str, + request_model: Type[T], + response_model: Type[R], + payload: Union[Text2ImageTaskCreationRequest, Text2VideoTaskCreationRequest, Image2VideoTaskCreationRequest], + node_id: str, + estimated_duration: int, + poll_interval: int, +) -> Type[R]: + initial_response = await SynchronousOperation( + endpoint=ApiEndpoint( + path=url, + method=HttpMethod.POST, + request_model=request_model, + response_model=TaskCreationResponse, + ), + request=payload, + auth_kwargs=auth_kwargs, + ).execute() + + if not initial_response.output: + raise Exception(f"Unknown error occurred: {initial_response.code} - {initial_response.message}") + + return await PollingOperation( + poll_endpoint=ApiEndpoint( + path=f"/proxy/wan/api/v1/tasks/{initial_response.output.task_id}", + method=HttpMethod.GET, + request_model=EmptyRequest, + response_model=response_model, + ), + completed_statuses=["SUCCEEDED"], + failed_statuses=["FAILED", "CANCELED", "UNKNOWN"], + status_extractor=lambda x: x.output.task_status, + estimated_duration=estimated_duration, + poll_interval=poll_interval, + node_id=node_id, + auth_kwargs=auth_kwargs, + ).execute() + + +class WanTextToImageApi(comfy_io.ComfyNode): + @classmethod + def define_schema(cls): + return comfy_io.Schema( + node_id="WanTextToImageApi", + display_name="Wan Text to Image", + category="api node/image/Wan", + description="Generates image based on text prompt.", + inputs=[ + comfy_io.Combo.Input( + "model", + options=["wan2.5-t2i-preview"], + default="wan2.5-t2i-preview", + tooltip="Model to use.", + ), + comfy_io.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Prompt used to describe the elements and visual features, supports English/Chinese.", + ), + comfy_io.String.Input( + "negative_prompt", + multiline=True, + default="", + tooltip="Negative text prompt to guide what to avoid.", + optional=True, + ), + comfy_io.Int.Input( + "width", + default=1024, + min=768, + max=1440, + step=32, + optional=True, + ), + comfy_io.Int.Input( + "height", + default=1024, + min=768, + max=1440, + step=32, + optional=True, + ), + comfy_io.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + step=1, + display_mode=comfy_io.NumberDisplay.number, + control_after_generate=True, + tooltip="Seed to use for generation.", + optional=True, + ), + comfy_io.Boolean.Input( + "prompt_extend", + default=True, + tooltip="Whether to enhance the prompt with AI assistance.", + optional=True, + ), + comfy_io.Boolean.Input( + "watermark", + default=True, + tooltip="Whether to add an \"AI generated\" watermark to the result.", + optional=True, + ), + ], + outputs=[ + comfy_io.Image.Output(), + ], + hidden=[ + comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + comfy_io.Hidden.unique_id, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, + model: str, + prompt: str, + negative_prompt: str = "", + width: int = 1024, + height: int = 1024, + seed: int = 0, + prompt_extend: bool = True, + watermark: bool = True, + ): + payload = Text2ImageTaskCreationRequest( + model=model, + input=Text2ImageInputField(prompt=prompt, negative_prompt=negative_prompt), + parameters=Txt2ImageParametersField( + size=f"{width}*{height}", + seed=seed, + prompt_extend=prompt_extend, + 
watermark=watermark, + ), + ) + response = await process_task( + { + "auth_token": cls.hidden.auth_token_comfy_org, + "comfy_api_key": cls.hidden.api_key_comfy_org, + }, + "/proxy/wan/api/v1/services/aigc/text2image/image-synthesis", + request_model=Text2ImageTaskCreationRequest, + response_model=ImageTaskStatusResponse, + payload=payload, + node_id=cls.hidden.unique_id, + estimated_duration=9, + poll_interval=3, + ) + return comfy_io.NodeOutput(await download_url_to_image_tensor(str(response.output.results[0].url))) + + +class WanTextToVideoApi(comfy_io.ComfyNode): + @classmethod + def define_schema(cls): + return comfy_io.Schema( + node_id="WanTextToVideoApi", + display_name="Wan Text to Video", + category="api node/video/Wan", + description="Generates video based on text prompt.", + inputs=[ + comfy_io.Combo.Input( + "model", + options=["wan2.5-t2v-preview"], + default="wan2.5-t2v-preview", + tooltip="Model to use.", + ), + comfy_io.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Prompt used to describe the elements and visual features, supports English/Chinese.", + ), + comfy_io.String.Input( + "negative_prompt", + multiline=True, + default="", + tooltip="Negative text prompt to guide what to avoid.", + optional=True, + ), + comfy_io.Combo.Input( + "size", + options=[ + "480p: 1:1 (624x624)", + "480p: 16:9 (832x480)", + "480p: 9:16 (480x832)", + "720p: 1:1 (960x960)", + "720p: 16:9 (1280x720)", + "720p: 9:16 (720x1280)", + "720p: 4:3 (1088x832)", + "720p: 3:4 (832x1088)", + "1080p: 1:1 (1440x1440)", + "1080p: 16:9 (1920x1080)", + "1080p: 9:16 (1080x1920)", + "1080p: 4:3 (1632x1248)", + "1080p: 3:4 (1248x1632)", + ], + default="480p: 1:1 (624x624)", + optional=True, + ), + comfy_io.Int.Input( + "duration", + default=5, + min=5, + max=10, + step=5, + display_mode=comfy_io.NumberDisplay.number, + tooltip="Available durations: 5 and 10 seconds", + optional=True, + ), + comfy_io.Audio.Input( + "audio", + optional=True, + tooltip="Audio must contain a clear, loud voice, without extraneous noise, background music.", + ), + comfy_io.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + step=1, + display_mode=comfy_io.NumberDisplay.number, + control_after_generate=True, + tooltip="Seed to use for generation.", + optional=True, + ), + comfy_io.Boolean.Input( + "generate_audio", + default=False, + optional=True, + tooltip="If there is no audio input, generate audio automatically.", + ), + comfy_io.Boolean.Input( + "prompt_extend", + default=True, + tooltip="Whether to enhance the prompt with AI assistance.", + optional=True, + ), + comfy_io.Boolean.Input( + "watermark", + default=True, + tooltip="Whether to add an \"AI generated\" watermark to the result.", + optional=True, + ), + ], + outputs=[ + comfy_io.Video.Output(), + ], + hidden=[ + comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + comfy_io.Hidden.unique_id, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, + model: str, + prompt: str, + negative_prompt: str = "", + size: str = "480p: 1:1 (624x624)", + duration: int = 5, + audio: Optional[Input.Audio] = None, + seed: int = 0, + generate_audio: bool = False, + prompt_extend: bool = True, + watermark: bool = True, + ): + width, height = RES_IN_PARENS.search(size).groups() + audio_url = None + if audio is not None: + validate_audio_duration(audio, 3.0, 29.0) + audio_url = "data:audio/mp3;base64," + audio_to_base64_string(audio, "mp3", "libmp3lame") + payload = Text2VideoTaskCreationRequest( + model=model, + 
input=Text2VideoInputField(prompt=prompt, negative_prompt=negative_prompt, audio_url=audio_url), + parameters=Text2VideoParametersField( + size=f"{width}*{height}", + duration=duration, + seed=seed, + audio=generate_audio, + prompt_extend=prompt_extend, + watermark=watermark, + ), + ) + response = await process_task( + { + "auth_token": cls.hidden.auth_token_comfy_org, + "comfy_api_key": cls.hidden.api_key_comfy_org, + }, + "/proxy/wan/api/v1/services/aigc/video-generation/video-synthesis", + request_model=Text2VideoTaskCreationRequest, + response_model=VideoTaskStatusResponse, + payload=payload, + node_id=cls.hidden.unique_id, + estimated_duration=120 * int(duration / 5), + poll_interval=6, + ) + return comfy_io.NodeOutput(await download_url_to_video_output(response.output.video_url)) + + +class WanImageToVideoApi(comfy_io.ComfyNode): + @classmethod + def define_schema(cls): + return comfy_io.Schema( + node_id="WanImageToVideoApi", + display_name="Wan Image to Video", + category="api node/video/Wan", + description="Generates video based on the first frame and text prompt.", + inputs=[ + comfy_io.Combo.Input( + "model", + options=["wan2.5-i2v-preview"], + default="wan2.5-i2v-preview", + tooltip="Model to use.", + ), + comfy_io.Image.Input( + "image", + ), + comfy_io.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Prompt used to describe the elements and visual features, supports English/Chinese.", + ), + comfy_io.String.Input( + "negative_prompt", + multiline=True, + default="", + tooltip="Negative text prompt to guide what to avoid.", + optional=True, + ), + comfy_io.Combo.Input( + "resolution", + options=[ + "480P", + "720P", + "1080P", + ], + default="480P", + optional=True, + ), + comfy_io.Int.Input( + "duration", + default=5, + min=5, + max=10, + step=5, + display_mode=comfy_io.NumberDisplay.number, + tooltip="Available durations: 5 and 10 seconds", + optional=True, + ), + comfy_io.Audio.Input( + "audio", + optional=True, + tooltip="Audio must contain a clear, loud voice, without extraneous noise, background music.", + ), + comfy_io.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + step=1, + display_mode=comfy_io.NumberDisplay.number, + control_after_generate=True, + tooltip="Seed to use for generation.", + optional=True, + ), + comfy_io.Boolean.Input( + "generate_audio", + default=False, + optional=True, + tooltip="If there is no audio input, generate audio automatically.", + ), + comfy_io.Boolean.Input( + "prompt_extend", + default=True, + tooltip="Whether to enhance the prompt with AI assistance.", + optional=True, + ), + comfy_io.Boolean.Input( + "watermark", + default=True, + tooltip="Whether to add an \"AI generated\" watermark to the result.", + optional=True, + ), + ], + outputs=[ + comfy_io.Video.Output(), + ], + hidden=[ + comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + comfy_io.Hidden.unique_id, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, + model: str, + image: torch.Tensor, + prompt: str, + negative_prompt: str = "", + resolution: str = "480P", + duration: int = 5, + audio: Optional[Input.Audio] = None, + seed: int = 0, + generate_audio: bool = False, + prompt_extend: bool = True, + watermark: bool = True, + ): + if get_number_of_images(image) != 1: + raise ValueError("Exactly one input image is required.") + image_url = "data:image/png;base64," + tensor_to_base64_string(image, total_pixels=2000*2000) + audio_url = None + if audio is not None: + validate_audio_duration(audio, 3.0, 
29.0) + audio_url = "data:audio/mp3;base64," + audio_to_base64_string(audio, "mp3", "libmp3lame") + payload = Image2VideoTaskCreationRequest( + model=model, + input=Image2VideoInputField( + prompt=prompt, negative_prompt=negative_prompt, img_url=image_url, audio_url=audio_url + ), + parameters=Image2VideoParametersField( + resolution=resolution, + duration=duration, + seed=seed, + audio=generate_audio, + prompt_extend=prompt_extend, + watermark=watermark, + ), + ) + response = await process_task( + { + "auth_token": cls.hidden.auth_token_comfy_org, + "comfy_api_key": cls.hidden.api_key_comfy_org, + }, + "/proxy/wan/api/v1/services/aigc/video-generation/video-synthesis", + request_model=Image2VideoTaskCreationRequest, + response_model=VideoTaskStatusResponse, + payload=payload, + node_id=cls.hidden.unique_id, + estimated_duration=120 * int(duration / 5), + poll_interval=6, + ) + return comfy_io.NodeOutput(await download_url_to_video_output(response.output.video_url)) + + +class WanApiExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[comfy_io.ComfyNode]]: + return [ + WanTextToImageApi, + WanTextToVideoApi, + WanImageToVideoApi, + ] + + +async def comfy_entrypoint() -> WanApiExtension: + return WanApiExtension() diff --git a/nodes.py b/nodes.py index 5a5fdcb8e..1a6784b68 100644 --- a/nodes.py +++ b/nodes.py @@ -2361,6 +2361,7 @@ async def init_builtin_api_nodes(): "nodes_rodin.py", "nodes_gemini.py", "nodes_vidu.py", + "nodes_wan.py", ] if not await load_custom_node(os.path.join(api_nodes_dir, "canary.py"), module_parent="comfy_api_nodes"): From b8730510db30c8858e1e5d8e126ef19eac395560 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 23 Sep 2025 11:50:33 -0400 Subject: [PATCH 237/325] ComfyUI version 0.3.60 --- comfyui_version.py | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/comfyui_version.py b/comfyui_version.py index ee58205f5..d469a8194 100644 --- a/comfyui_version.py +++ b/comfyui_version.py @@ -1,3 +1,3 @@ # This file is automatically generated by the build process when version is # updated in pyproject.toml. 
-__version__ = "0.3.59" +__version__ = "0.3.60" diff --git a/pyproject.toml b/pyproject.toml index a7fc1a5a6..7340c320b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ComfyUI" -version = "0.3.59" +version = "0.3.60" readme = "README.md" license = { file = "LICENSE" } requires-python = ">=3.9" From 341b4adefd308cbcf82c07effc255f2770b3b3e2 Mon Sep 17 00:00:00 2001 From: Changrz <51637999+WhiteGiven@users.noreply.github.com> Date: Thu, 25 Sep 2025 02:05:37 +0800 Subject: [PATCH 238/325] Rodin3D - add [Rodin3D Gen-2 generate] api-node (#9994) * update Rodin api node * update rodin3d gen2 api node * fix images limited bug --- comfy_api_nodes/apis/rodin_api.py | 3 +- comfy_api_nodes/nodes_rodin.py | 140 ++++++++++++++++++++++++------ 2 files changed, 117 insertions(+), 26 deletions(-) diff --git a/comfy_api_nodes/apis/rodin_api.py b/comfy_api_nodes/apis/rodin_api.py index b0cf171fa..02cf42c29 100644 --- a/comfy_api_nodes/apis/rodin_api.py +++ b/comfy_api_nodes/apis/rodin_api.py @@ -9,8 +9,9 @@ class Rodin3DGenerateRequest(BaseModel): seed: int = Field(..., description="seed_") tier: str = Field(..., description="Tier of generation.") material: str = Field(..., description="The material type.") - quality: str = Field(..., description="The generation quality of the mesh.") + quality_override: int = Field(..., description="The poly count of the mesh.") mesh_mode: str = Field(..., description="It controls the type of faces of generated models.") + TAPose: Optional[bool] = Field(None, description="") class GenerateJobsData(BaseModel): uuids: List[str] = Field(..., description="str LIST") diff --git a/comfy_api_nodes/nodes_rodin.py b/comfy_api_nodes/nodes_rodin.py index c89d087e5..1af393eba 100644 --- a/comfy_api_nodes/nodes_rodin.py +++ b/comfy_api_nodes/nodes_rodin.py @@ -121,10 +121,10 @@ class Rodin3DAPI: else: return "Generating" - async def create_generate_task(self, images=None, seed=1, material="PBR", quality="medium", tier="Regular", mesh_mode="Quad", **kwargs): + async def create_generate_task(self, images=None, seed=1, material="PBR", quality_override=18000, tier="Regular", mesh_mode="Quad", TAPose = False, **kwargs): if images is None: raise Exception("Rodin 3D generate requires at least 1 image.") - if len(images) >= 5: + if len(images) > 5: raise Exception("Rodin 3D generate requires up to 5 image.") path = "/proxy/rodin/api/v2/rodin" @@ -139,8 +139,9 @@ class Rodin3DAPI: seed=seed, tier=tier, material=material, - quality=quality, - mesh_mode=mesh_mode + quality_override=quality_override, + mesh_mode=mesh_mode, + TAPose=TAPose, ), files=[ ( @@ -211,23 +212,36 @@ class Rodin3DAPI: return await operation.execute() def get_quality_mode(self, poly_count): - if poly_count == "200K-Triangle": + polycount = poly_count.split("-") + poly = polycount[1] + count = polycount[0] + if poly == "Triangle": mesh_mode = "Raw" - quality = "medium" + elif poly == "Quad": + mesh_mode = "Quad" else: mesh_mode = "Quad" - if poly_count == "4K-Quad": - quality = "extra-low" - elif poly_count == "8K-Quad": - quality = "low" - elif poly_count == "18K-Quad": - quality = "medium" - elif poly_count == "50K-Quad": - quality = "high" - else: - quality = "medium" - return mesh_mode, quality + if count == "4K": + quality_override = 4000 + elif count == "8K": + quality_override = 8000 + elif count == "18K": + quality_override = 18000 + elif count == "50K": + quality_override = 50000 + elif count == "2K": + quality_override = 2000 + elif count == "20K": + quality_override = 20000 + elif 
count == "150K": + quality_override = 150000 + elif count == "500K": + quality_override = 500000 + else: + quality_override = 18000 + + return mesh_mode, quality_override async def download_files(self, url_list): save_path = os.path.join(comfy_paths.get_output_directory(), "Rodin3D", datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")) @@ -300,9 +314,9 @@ class Rodin3D_Regular(Rodin3DAPI): m_images = [] for i in range(num_images): m_images.append(Images[i]) - mesh_mode, quality = self.get_quality_mode(Polygon_count) + mesh_mode, quality_override = self.get_quality_mode(Polygon_count) task_uuid, subscription_key = await self.create_generate_task(images=m_images, seed=Seed, material=Material_Type, - quality=quality, tier=tier, mesh_mode=mesh_mode, + quality_override=quality_override, tier=tier, mesh_mode=mesh_mode, **kwargs) await self.poll_for_task_status(subscription_key, **kwargs) download_list = await self.get_rodin_download_list(task_uuid, **kwargs) @@ -346,9 +360,9 @@ class Rodin3D_Detail(Rodin3DAPI): m_images = [] for i in range(num_images): m_images.append(Images[i]) - mesh_mode, quality = self.get_quality_mode(Polygon_count) + mesh_mode, quality_override = self.get_quality_mode(Polygon_count) task_uuid, subscription_key = await self.create_generate_task(images=m_images, seed=Seed, material=Material_Type, - quality=quality, tier=tier, mesh_mode=mesh_mode, + quality_override=quality_override, tier=tier, mesh_mode=mesh_mode, **kwargs) await self.poll_for_task_status(subscription_key, **kwargs) download_list = await self.get_rodin_download_list(task_uuid, **kwargs) @@ -392,9 +406,9 @@ class Rodin3D_Smooth(Rodin3DAPI): m_images = [] for i in range(num_images): m_images.append(Images[i]) - mesh_mode, quality = self.get_quality_mode(Polygon_count) + mesh_mode, quality_override = self.get_quality_mode(Polygon_count) task_uuid, subscription_key = await self.create_generate_task(images=m_images, seed=Seed, material=Material_Type, - quality=quality, tier=tier, mesh_mode=mesh_mode, + quality_override=quality_override, tier=tier, mesh_mode=mesh_mode, **kwargs) await self.poll_for_task_status(subscription_key, **kwargs) download_list = await self.get_rodin_download_list(task_uuid, **kwargs) @@ -446,10 +460,10 @@ class Rodin3D_Sketch(Rodin3DAPI): for i in range(num_images): m_images.append(Images[i]) material_type = "PBR" - quality = "medium" + quality_override = 18000 mesh_mode = "Quad" task_uuid, subscription_key = await self.create_generate_task( - images=m_images, seed=Seed, material=material_type, quality=quality, tier=tier, mesh_mode=mesh_mode, **kwargs + images=m_images, seed=Seed, material=material_type, quality_override=quality_override, tier=tier, mesh_mode=mesh_mode, **kwargs ) await self.poll_for_task_status(subscription_key, **kwargs) download_list = await self.get_rodin_download_list(task_uuid, **kwargs) @@ -457,6 +471,80 @@ class Rodin3D_Sketch(Rodin3DAPI): return (model,) +class Rodin3D_Gen2(Rodin3DAPI): + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "Images": + ( + IO.IMAGE, + { + "forceInput":True, + } + ) + }, + "optional": { + "Seed": ( + IO.INT, + { + "default":0, + "min":0, + "max":65535, + "display":"number" + } + ), + "Material_Type": ( + IO.COMBO, + { + "options": ["PBR", "Shaded"], + "default": "PBR" + } + ), + "Polygon_count": ( + IO.COMBO, + { + "options": ["4K-Quad", "8K-Quad", "18K-Quad", "50K-Quad", "2K-Triangle", "20K-Triangle", "150K-Triangle", "500K-Triangle"], + "default": "500K-Triangle" + } + ), + "TAPose": ( + IO.BOOLEAN, + { + "default": 
False, + } + ) + }, + "hidden": { + "auth_token": "AUTH_TOKEN_COMFY_ORG", + "comfy_api_key": "API_KEY_COMFY_ORG", + }, + } + + async def api_call( + self, + Images, + Seed, + Material_Type, + Polygon_count, + TAPose, + **kwargs + ): + tier = "Gen-2" + num_images = Images.shape[0] + m_images = [] + for i in range(num_images): + m_images.append(Images[i]) + mesh_mode, quality_override = self.get_quality_mode(Polygon_count) + task_uuid, subscription_key = await self.create_generate_task(images=m_images, seed=Seed, material=Material_Type, + quality_override=quality_override, tier=tier, mesh_mode=mesh_mode, TAPose=TAPose, + **kwargs) + await self.poll_for_task_status(subscription_key, **kwargs) + download_list = await self.get_rodin_download_list(task_uuid, **kwargs) + model = await self.download_files(download_list) + + return (model,) + # A dictionary that contains all nodes you want to export with their names # NOTE: names should be globally unique NODE_CLASS_MAPPINGS = { @@ -464,6 +552,7 @@ NODE_CLASS_MAPPINGS = { "Rodin3D_Detail": Rodin3D_Detail, "Rodin3D_Smooth": Rodin3D_Smooth, "Rodin3D_Sketch": Rodin3D_Sketch, + "Rodin3D_Gen2": Rodin3D_Gen2, } # A dictionary that contains the friendly/humanly readable titles for the nodes @@ -472,4 +561,5 @@ NODE_DISPLAY_NAME_MAPPINGS = { "Rodin3D_Detail": "Rodin 3D Generate - Detail Generate", "Rodin3D_Smooth": "Rodin 3D Generate - Smooth Generate", "Rodin3D_Sketch": "Rodin 3D Generate - Sketch Generate", + "Rodin3D_Gen2": "Rodin 3D Generate - Gen-2 Generate", } From fd79d32f38fd24adca5a6e8214f05050f287c9db Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jukka=20Sepp=C3=A4nen?= <40791699+kijai@users.noreply.github.com> Date: Thu, 25 Sep 2025 01:59:29 +0300 Subject: [PATCH 239/325] Add new audio nodes (#9908) * Add new audio nodes - TrimAudioDuration - SplitAudioChannels - AudioConcat - AudioMerge - AudioAdjustVolume * Update nodes_audio.py * Add EmptyAudio -node * Change duration to Float (allows sub seconds) --- comfy_extras/nodes_audio.py | 223 ++++++++++++++++++++++++++++++++++++ 1 file changed, 223 insertions(+) diff --git a/comfy_extras/nodes_audio.py b/comfy_extras/nodes_audio.py index 3b23f65d8..51c8b9dd9 100644 --- a/comfy_extras/nodes_audio.py +++ b/comfy_extras/nodes_audio.py @@ -11,6 +11,7 @@ import json import random import hashlib import node_helpers +import logging from comfy.cli_args import args from comfy.comfy_types import FileLocator @@ -364,6 +365,216 @@ class RecordAudio: return (audio, ) +class TrimAudioDuration: + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "audio": ("AUDIO",), + "start_index": ("FLOAT", {"default": 0.0, "min": -0xffffffffffffffff, "max": 0xffffffffffffffff, "step": 0.01, "tooltip": "Start time in seconds, can be negative to count from the end (supports sub-seconds)."}), + "duration": ("FLOAT", {"default": 60.0, "min": 0.0, "step": 0.01, "tooltip": "Duration in seconds"}), + }, + } + + FUNCTION = "trim" + RETURN_TYPES = ("AUDIO",) + CATEGORY = "audio" + DESCRIPTION = "Trim audio tensor into chosen time range." 
+ + def trim(self, audio, start_index, duration): + waveform = audio["waveform"] + sample_rate = audio["sample_rate"] + audio_length = waveform.shape[-1] + + if start_index < 0: + start_frame = audio_length + int(round(start_index * sample_rate)) + else: + start_frame = int(round(start_index * sample_rate)) + start_frame = max(0, min(start_frame, audio_length - 1)) + + end_frame = start_frame + int(round(duration * sample_rate)) + end_frame = max(0, min(end_frame, audio_length)) + + if start_frame >= end_frame: + raise ValueError("AudioTrim: Start time must be less than end time and be within the audio length.") + + return ({"waveform": waveform[..., start_frame:end_frame], "sample_rate": sample_rate},) + + +class SplitAudioChannels: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "audio": ("AUDIO",), + }} + + RETURN_TYPES = ("AUDIO", "AUDIO") + RETURN_NAMES = ("left", "right") + FUNCTION = "separate" + CATEGORY = "audio" + DESCRIPTION = "Separates the audio into left and right channels." + + def separate(self, audio): + waveform = audio["waveform"] + sample_rate = audio["sample_rate"] + + if waveform.shape[1] != 2: + raise ValueError("AudioSplit: Input audio has only one channel.") + + left_channel = waveform[..., 0:1, :] + right_channel = waveform[..., 1:2, :] + + return ({"waveform": left_channel, "sample_rate": sample_rate}, {"waveform": right_channel, "sample_rate": sample_rate}) + + +def match_audio_sample_rates(waveform_1, sample_rate_1, waveform_2, sample_rate_2): + if sample_rate_1 != sample_rate_2: + if sample_rate_1 > sample_rate_2: + waveform_2 = torchaudio.functional.resample(waveform_2, sample_rate_2, sample_rate_1) + output_sample_rate = sample_rate_1 + logging.info(f"Resampling audio2 from {sample_rate_2}Hz to {sample_rate_1}Hz for merging.") + else: + waveform_1 = torchaudio.functional.resample(waveform_1, sample_rate_1, sample_rate_2) + output_sample_rate = sample_rate_2 + logging.info(f"Resampling audio1 from {sample_rate_1}Hz to {sample_rate_2}Hz for merging.") + else: + output_sample_rate = sample_rate_1 + return waveform_1, waveform_2, output_sample_rate + + +class AudioConcat: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "audio1": ("AUDIO",), + "audio2": ("AUDIO",), + "direction": (['after', 'before'], {"default": 'after', "tooltip": "Whether to append audio2 after or before audio1."}), + }} + + RETURN_TYPES = ("AUDIO",) + FUNCTION = "concat" + CATEGORY = "audio" + DESCRIPTION = "Concatenates the audio1 to audio2 in the specified direction." 
+ + def concat(self, audio1, audio2, direction): + waveform_1 = audio1["waveform"] + waveform_2 = audio2["waveform"] + sample_rate_1 = audio1["sample_rate"] + sample_rate_2 = audio2["sample_rate"] + + if waveform_1.shape[1] == 1: + waveform_1 = waveform_1.repeat(1, 2, 1) + logging.info("AudioConcat: Converted mono audio1 to stereo by duplicating the channel.") + if waveform_2.shape[1] == 1: + waveform_2 = waveform_2.repeat(1, 2, 1) + logging.info("AudioConcat: Converted mono audio2 to stereo by duplicating the channel.") + + waveform_1, waveform_2, output_sample_rate = match_audio_sample_rates(waveform_1, sample_rate_1, waveform_2, sample_rate_2) + + if direction == 'after': + concatenated_audio = torch.cat((waveform_1, waveform_2), dim=2) + elif direction == 'before': + concatenated_audio = torch.cat((waveform_2, waveform_1), dim=2) + + return ({"waveform": concatenated_audio, "sample_rate": output_sample_rate},) + + +class AudioMerge: + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "audio1": ("AUDIO",), + "audio2": ("AUDIO",), + "merge_method": (["add", "mean", "subtract", "multiply"], {"tooltip": "The method used to combine the audio waveforms."}), + }, + } + + FUNCTION = "merge" + RETURN_TYPES = ("AUDIO",) + CATEGORY = "audio" + DESCRIPTION = "Combine two audio tracks by overlaying their waveforms." + + def merge(self, audio1, audio2, merge_method): + waveform_1 = audio1["waveform"] + waveform_2 = audio2["waveform"] + sample_rate_1 = audio1["sample_rate"] + sample_rate_2 = audio2["sample_rate"] + + waveform_1, waveform_2, output_sample_rate = match_audio_sample_rates(waveform_1, sample_rate_1, waveform_2, sample_rate_2) + + length_1 = waveform_1.shape[-1] + length_2 = waveform_2.shape[-1] + + if length_2 > length_1: + logging.info(f"AudioMerge: Trimming audio2 from {length_2} to {length_1} samples to match audio1 length.") + waveform_2 = waveform_2[..., :length_1] + elif length_2 < length_1: + logging.info(f"AudioMerge: Padding audio2 from {length_2} to {length_1} samples to match audio1 length.") + pad_shape = list(waveform_2.shape) + pad_shape[-1] = length_1 - length_2 + pad_tensor = torch.zeros(pad_shape, dtype=waveform_2.dtype, device=waveform_2.device) + waveform_2 = torch.cat((waveform_2, pad_tensor), dim=-1) + + if merge_method == "add": + waveform = waveform_1 + waveform_2 + elif merge_method == "subtract": + waveform = waveform_1 - waveform_2 + elif merge_method == "multiply": + waveform = waveform_1 * waveform_2 + elif merge_method == "mean": + waveform = (waveform_1 + waveform_2) / 2 + + max_val = waveform.abs().max() + if max_val > 1.0: + waveform = waveform / max_val + + return ({"waveform": waveform, "sample_rate": output_sample_rate},) + + +class AudioAdjustVolume: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "audio": ("AUDIO",), + "volume": ("INT", {"default": 1.0, "min": -100, "max": 100, "tooltip": "Volume adjustment in decibels (dB). 
0 = no change, +6 = double, -6 = half, etc"}), + }} + + RETURN_TYPES = ("AUDIO",) + FUNCTION = "adjust_volume" + CATEGORY = "audio" + + def adjust_volume(self, audio, volume): + if volume == 0: + return (audio,) + waveform = audio["waveform"] + sample_rate = audio["sample_rate"] + + gain = 10 ** (volume / 20) + waveform = waveform * gain + + return ({"waveform": waveform, "sample_rate": sample_rate},) + + +class EmptyAudio: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "duration": ("FLOAT", {"default": 60.0, "min": 0.0, "max": 0xffffffffffffffff, "step": 0.01, "tooltip": "Duration of the empty audio clip in seconds"}), + "sample_rate": ("INT", {"default": 44100, "tooltip": "Sample rate of the empty audio clip."}), + "channels": ("INT", {"default": 2, "min": 1, "max": 2, "tooltip": "Number of audio channels (1 for mono, 2 for stereo)."}), + }} + + RETURN_TYPES = ("AUDIO",) + FUNCTION = "create_empty_audio" + CATEGORY = "audio" + + def create_empty_audio(self, duration, sample_rate, channels): + num_samples = int(round(duration * sample_rate)) + waveform = torch.zeros((1, channels, num_samples), dtype=torch.float32) + return ({"waveform": waveform, "sample_rate": sample_rate},) + + NODE_CLASS_MAPPINGS = { "EmptyLatentAudio": EmptyLatentAudio, "VAEEncodeAudio": VAEEncodeAudio, @@ -375,6 +586,12 @@ NODE_CLASS_MAPPINGS = { "PreviewAudio": PreviewAudio, "ConditioningStableAudio": ConditioningStableAudio, "RecordAudio": RecordAudio, + "TrimAudioDuration": TrimAudioDuration, + "SplitAudioChannels": SplitAudioChannels, + "AudioConcat": AudioConcat, + "AudioMerge": AudioMerge, + "AudioAdjustVolume": AudioAdjustVolume, + "EmptyAudio": EmptyAudio, } NODE_DISPLAY_NAME_MAPPINGS = { @@ -387,4 +604,10 @@ NODE_DISPLAY_NAME_MAPPINGS = { "SaveAudioMP3": "Save Audio (MP3)", "SaveAudioOpus": "Save Audio (Opus)", "RecordAudio": "Record Audio", + "TrimAudioDuration": "Trim Audio Duration", + "SplitAudioChannels": "Split Audio Channels", + "AudioConcat": "Audio Concat", + "AudioMerge": "Audio Merge", + "AudioAdjustVolume": "Audio Adjust Volume", + "EmptyAudio": "Empty Audio", } From fccab99ec0fcd13e80fa59bc73bccff31f9450ca Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 24 Sep 2025 17:09:42 -0700 Subject: [PATCH 240/325] Fix issue with .view() in HuMo. (#10014) --- comfy/ldm/wan/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/ldm/wan/model.py b/comfy/ldm/wan/model.py index 2dac5980c..54616e6eb 100644 --- a/comfy/ldm/wan/model.py +++ b/comfy/ldm/wan/model.py @@ -1355,7 +1355,7 @@ class WanT2VCrossAttentionGather(WanSelfAttention): x = optimized_attention(q, k, v, heads=self.num_heads, skip_reshape=True, skip_output_reshape=True, transformer_options=transformer_options) - x = x.transpose(1, 2).view(b, -1, n, d).flatten(2) + x = x.transpose(1, 2).reshape(b, -1, n * d) x = self.o(x) return x From c8d2117f02bcad6d8316ffd8273bdc27adf83b44 Mon Sep 17 00:00:00 2001 From: Guy Niv <43928922+guyniv@users.noreply.github.com> Date: Thu, 25 Sep 2025 05:35:12 +0300 Subject: [PATCH 241/325] Fix memory leak by properly detaching model finalizer (#9979) When unloading models in load_models_gpu(), the model finalizer was not being explicitly detached, leading to a memory leak. This caused linear memory consumption increase over time as models are repeatedly loaded and unloaded. This change prevents orphaned finalizer references from accumulating in memory during model switching operations. 
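For illustration, a minimal, hypothetical sketch of the pattern this change relies on (LoadedResource and _cleanup are made-up names, not ComfyUI's actual classes): a weakref.finalize handle stays registered until it either fires or is explicitly detached, so code that tears a resource down by hand should detach the finalizer at the same time.

    import weakref

    def _cleanup(name):
        print(f"cleanup {name}")

    class LoadedResource:
        def __init__(self, name):
            self.name = name
            # weakref.finalize keeps this handle (and anything the callback
            # closes over) alive until it runs or is detached.
            self.finalizer = weakref.finalize(self, _cleanup, name)

        def unload(self):
            # Cleanup happens manually here, so also detach the finalizer;
            # otherwise the handle and whatever its callback references stay
            # registered until this wrapper is eventually collected.
            self.finalizer.detach()
            self.name = None

    r = LoadedResource("model-a")
    r.unload()  # prints nothing: the finalizer was detached before it could fire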
--- comfy/model_management.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index d880f1970..c5b817b62 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -645,7 +645,9 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu if loaded_model.model.is_clone(current_loaded_models[i].model): to_unload = [i] + to_unload for i in to_unload: - current_loaded_models.pop(i).model.detach(unpatch_all=False) + model_to_unload = current_loaded_models.pop(i) + model_to_unload.model.detach(unpatch_all=False) + model_to_unload.model_finalizer.detach() total_memory_required = {} for loaded_model in models_to_load: From ce4cb2389c8ce63cf8735f200b8672a2c1be0950 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Thu, 25 Sep 2025 14:20:13 -0700 Subject: [PATCH 242/325] Make LatentCompositeMasked work with basic video latents. (#10023) --- comfy_extras/nodes_mask.py | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/comfy_extras/nodes_mask.py b/comfy_extras/nodes_mask.py index 2b0f8dd5d..a5e405008 100644 --- a/comfy_extras/nodes_mask.py +++ b/comfy_extras/nodes_mask.py @@ -12,35 +12,38 @@ from nodes import MAX_RESOLUTION def composite(destination, source, x, y, mask = None, multiplier = 8, resize_source = False): source = source.to(destination.device) if resize_source: - source = torch.nn.functional.interpolate(source, size=(destination.shape[2], destination.shape[3]), mode="bilinear") + source = torch.nn.functional.interpolate(source, size=(destination.shape[-2], destination.shape[-1]), mode="bilinear") source = comfy.utils.repeat_to_batch_size(source, destination.shape[0]) - x = max(-source.shape[3] * multiplier, min(x, destination.shape[3] * multiplier)) - y = max(-source.shape[2] * multiplier, min(y, destination.shape[2] * multiplier)) + x = max(-source.shape[-1] * multiplier, min(x, destination.shape[-1] * multiplier)) + y = max(-source.shape[-2] * multiplier, min(y, destination.shape[-2] * multiplier)) left, top = (x // multiplier, y // multiplier) - right, bottom = (left + source.shape[3], top + source.shape[2],) + right, bottom = (left + source.shape[-1], top + source.shape[-2],) if mask is None: mask = torch.ones_like(source) else: mask = mask.to(destination.device, copy=True) - mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(source.shape[2], source.shape[3]), mode="bilinear") + mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(source.shape[-2], source.shape[-1]), mode="bilinear") mask = comfy.utils.repeat_to_batch_size(mask, source.shape[0]) # calculate the bounds of the source that will be overlapping the destination # this prevents the source trying to overwrite latent pixels that are out of bounds # of the destination - visible_width, visible_height = (destination.shape[3] - left + min(0, x), destination.shape[2] - top + min(0, y),) + visible_width, visible_height = (destination.shape[-1] - left + min(0, x), destination.shape[-2] - top + min(0, y),) mask = mask[:, :, :visible_height, :visible_width] + if mask.ndim < source.ndim: + mask = mask.unsqueeze(1) + inverse_mask = torch.ones_like(mask) - mask - source_portion = mask * source[:, :, :visible_height, :visible_width] - destination_portion = inverse_mask * destination[:, :, top:bottom, left:right] + source_portion = mask * 
source[..., :visible_height, :visible_width] + destination_portion = inverse_mask * destination[..., top:bottom, left:right] - destination[:, :, top:bottom, left:right] = source_portion + destination_portion + destination[..., top:bottom, left:right] = source_portion + destination_portion return destination class LatentCompositeMasked: From 2b7f9a8196304badb5fe58e5c734e4b182ad0fdf Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Fri, 26 Sep 2025 11:12:43 -0700 Subject: [PATCH 243/325] Fix the failing unit test. (#10037) --- .github/workflows/test-unit.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test-unit.yml b/.github/workflows/test-unit.yml index 78c918031..00caf5b8a 100644 --- a/.github/workflows/test-unit.yml +++ b/.github/workflows/test-unit.yml @@ -10,7 +10,7 @@ jobs: test: strategy: matrix: - os: [ubuntu-latest, windows-latest, macos-latest] + os: [ubuntu-latest, windows-2022, macos-latest] runs-on: ${{ matrix.os }} continue-on-error: true steps: From c4a46e943c12c7f3f6ac72f8fb51caad514ec9b6 Mon Sep 17 00:00:00 2001 From: Yoland Yan <4950057+yoland68@users.noreply.github.com> Date: Fri, 26 Sep 2025 14:08:16 -0700 Subject: [PATCH 244/325] Add @kosinkadink as code owner (#10041) Updated CODEOWNERS to include @kosinkadink as a code owner. --- CODEOWNERS | 24 +----------------------- 1 file changed, 1 insertion(+), 23 deletions(-) diff --git a/CODEOWNERS b/CODEOWNERS index c8acd66d5..b7aca9b26 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -1,25 +1,3 @@ # Admins * @comfyanonymous - -# Note: Github teams syntax cannot be used here as the repo is not owned by Comfy-Org. -# Inlined the team members for now. - -# Maintainers -*.md @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne @guill -/tests/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne @guill -/tests-unit/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne @guill -/notebooks/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne @guill -/script_examples/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne @guill -/.github/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne @guill -/requirements.txt @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne @guill -/pyproject.toml @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne @guill - -# Python web server -/api_server/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @christian-byrne @guill -/app/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @christian-byrne @guill -/utils/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @christian-byrne @guill - -# Node developers -/comfy_extras/ @yoland68 @robinjhuang @pythongosssss @ltdrdata @Kosinkadink @webfiltered @christian-byrne @guill -/comfy/comfy_types/ @yoland68 @robinjhuang @pythongosssss @ltdrdata @Kosinkadink @webfiltered @christian-byrne @guill -/comfy_api_nodes/ @yoland68 @robinjhuang @pythongosssss @ltdrdata @Kosinkadink @webfiltered @christian-byrne @guill +* @kosinkadink From 76eb1d72c3e5bef51d6ca8a26bf996972d3f6d1a Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Sat, 27 Sep 2025 00:10:49 +0300 Subject: [PATCH 245/325] convert 
nodes_rebatch.py to V3 schema (#9945) --- comfy_extras/nodes_rebatch.py | 97 ++++++++++++++++++++--------------- 1 file changed, 56 insertions(+), 41 deletions(-) diff --git a/comfy_extras/nodes_rebatch.py b/comfy_extras/nodes_rebatch.py index e29cb9ed1..5f4e82aef 100644 --- a/comfy_extras/nodes_rebatch.py +++ b/comfy_extras/nodes_rebatch.py @@ -1,18 +1,25 @@ +from typing_extensions import override import torch -class LatentRebatch: +from comfy_api.latest import ComfyExtension, io + + +class LatentRebatch(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "latents": ("LATENT",), - "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), - }} - RETURN_TYPES = ("LATENT",) - INPUT_IS_LIST = True - OUTPUT_IS_LIST = (True, ) - - FUNCTION = "rebatch" - - CATEGORY = "latent/batch" + def define_schema(cls): + return io.Schema( + node_id="RebatchLatents", + display_name="Rebatch Latents", + category="latent/batch", + is_input_list=True, + inputs=[ + io.Latent.Input("latents"), + io.Int.Input("batch_size", default=1, min=1, max=4096), + ], + outputs=[ + io.Latent.Output(is_output_list=True), + ], + ) @staticmethod def get_batch(latents, list_ind, offset): @@ -53,7 +60,8 @@ class LatentRebatch: result = [torch.cat((b1, b2)) if torch.is_tensor(b1) else b1 + b2 for b1, b2 in zip(batch1, batch2)] return result - def rebatch(self, latents, batch_size): + @classmethod + def execute(cls, latents, batch_size): batch_size = batch_size[0] output_list = [] @@ -63,24 +71,24 @@ class LatentRebatch: for i in range(len(latents)): # fetch new entry of list #samples, masks, indices = self.get_batch(latents, i) - next_batch = self.get_batch(latents, i, processed) + next_batch = cls.get_batch(latents, i, processed) processed += len(next_batch[2]) # set to current if current is None if current_batch[0] is None: current_batch = next_batch # add previous to list if dimensions do not match elif next_batch[0].shape[-1] != current_batch[0].shape[-1] or next_batch[0].shape[-2] != current_batch[0].shape[-2]: - sliced, _ = self.slice_batch(current_batch, 1, batch_size) + sliced, _ = cls.slice_batch(current_batch, 1, batch_size) output_list.append({'samples': sliced[0][0], 'noise_mask': sliced[1][0], 'batch_index': sliced[2][0]}) current_batch = next_batch # cat if everything checks out else: - current_batch = self.cat_batch(current_batch, next_batch) + current_batch = cls.cat_batch(current_batch, next_batch) # add to list if dimensions gone above target batch size if current_batch[0].shape[0] > batch_size: num = current_batch[0].shape[0] // batch_size - sliced, remainder = self.slice_batch(current_batch, num, batch_size) + sliced, remainder = cls.slice_batch(current_batch, num, batch_size) for i in range(num): output_list.append({'samples': sliced[0][i], 'noise_mask': sliced[1][i], 'batch_index': sliced[2][i]}) @@ -89,7 +97,7 @@ class LatentRebatch: #add remainder if current_batch[0] is not None: - sliced, _ = self.slice_batch(current_batch, 1, batch_size) + sliced, _ = cls.slice_batch(current_batch, 1, batch_size) output_list.append({'samples': sliced[0][0], 'noise_mask': sliced[1][0], 'batch_index': sliced[2][0]}) #get rid of empty masks @@ -97,23 +105,27 @@ class LatentRebatch: if s['noise_mask'].mean() == 1.0: del s['noise_mask'] - return (output_list,) + return io.NodeOutput(output_list) -class ImageRebatch: +class ImageRebatch(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "images": ("IMAGE",), - "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), - }} - 
RETURN_TYPES = ("IMAGE",) - INPUT_IS_LIST = True - OUTPUT_IS_LIST = (True, ) + def define_schema(cls): + return io.Schema( + node_id="RebatchImages", + display_name="Rebatch Images", + category="image/batch", + is_input_list=True, + inputs=[ + io.Image.Input("images"), + io.Int.Input("batch_size", default=1, min=1, max=4096), + ], + outputs=[ + io.Image.Output(is_output_list=True), + ], + ) - FUNCTION = "rebatch" - - CATEGORY = "image/batch" - - def rebatch(self, images, batch_size): + @classmethod + def execute(cls, images, batch_size): batch_size = batch_size[0] output_list = [] @@ -125,14 +137,17 @@ class ImageRebatch: for i in range(0, len(all_images), batch_size): output_list.append(torch.cat(all_images[i:i+batch_size], dim=0)) - return (output_list,) + return io.NodeOutput(output_list) -NODE_CLASS_MAPPINGS = { - "RebatchLatents": LatentRebatch, - "RebatchImages": ImageRebatch, -} -NODE_DISPLAY_NAME_MAPPINGS = { - "RebatchLatents": "Rebatch Latents", - "RebatchImages": "Rebatch Images", -} +class RebatchExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + LatentRebatch, + ImageRebatch, + ] + + +async def comfy_entrypoint() -> RebatchExtension: + return RebatchExtension() From 7ea173c1873ec22df6edabc80a912a08ae2d521b Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Sat, 27 Sep 2025 00:12:04 +0300 Subject: [PATCH 246/325] convert nodes_fresca.py to V3 schema (#9951) --- comfy_extras/nodes_fresca.py | 61 +++++++++++++++++++++--------------- 1 file changed, 36 insertions(+), 25 deletions(-) diff --git a/comfy_extras/nodes_fresca.py b/comfy_extras/nodes_fresca.py index 65c2d0d0e..f308eb0c1 100644 --- a/comfy_extras/nodes_fresca.py +++ b/comfy_extras/nodes_fresca.py @@ -1,6 +1,8 @@ # Code based on https://github.com/WikiChao/FreSca (MIT License) import torch import torch.fft as fft +from typing_extensions import override +from comfy_api.latest import ComfyExtension, io def Fourier_filter(x, scale_low=1.0, scale_high=1.5, freq_cutoff=20): @@ -51,25 +53,31 @@ def Fourier_filter(x, scale_low=1.0, scale_high=1.5, freq_cutoff=20): return x_filtered -class FreSca: +class FreSca(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return { - "required": { - "model": ("MODEL",), - "scale_low": ("FLOAT", {"default": 1.0, "min": 0, "max": 10, "step": 0.01, - "tooltip": "Scaling factor for low-frequency components"}), - "scale_high": ("FLOAT", {"default": 1.25, "min": 0, "max": 10, "step": 0.01, - "tooltip": "Scaling factor for high-frequency components"}), - "freq_cutoff": ("INT", {"default": 20, "min": 1, "max": 10000, "step": 1, - "tooltip": "Number of frequency indices around center to consider as low-frequency"}), - } - } - RETURN_TYPES = ("MODEL",) - FUNCTION = "patch" - CATEGORY = "_for_testing" - DESCRIPTION = "Applies frequency-dependent scaling to the guidance" - def patch(self, model, scale_low, scale_high, freq_cutoff): + def define_schema(cls): + return io.Schema( + node_id="FreSca", + display_name="FreSca", + category="_for_testing", + description="Applies frequency-dependent scaling to the guidance", + inputs=[ + io.Model.Input("model"), + io.Float.Input("scale_low", default=1.0, min=0, max=10, step=0.01, + tooltip="Scaling factor for low-frequency components"), + io.Float.Input("scale_high", default=1.25, min=0, max=10, step=0.01, + tooltip="Scaling factor for high-frequency components"), + io.Int.Input("freq_cutoff", default=20, min=1, max=10000, step=1, + tooltip="Number of 
frequency indices around center to consider as low-frequency"), + ], + outputs=[ + io.Model.Output(), + ], + is_experimental=True, + ) + + @classmethod + def execute(cls, model, scale_low, scale_high, freq_cutoff): def custom_cfg_function(args): conds_out = args["conds_out"] if len(conds_out) <= 1 or None in args["conds"][:2]: @@ -91,13 +99,16 @@ class FreSca: m = model.clone() m.set_model_sampler_pre_cfg_function(custom_cfg_function) - return (m,) + return io.NodeOutput(m) -NODE_CLASS_MAPPINGS = { - "FreSca": FreSca, -} +class FreScaExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + FreSca, + ] -NODE_DISPLAY_NAME_MAPPINGS = { - "FreSca": "FreSca", -} + +async def comfy_entrypoint() -> FreScaExtension: + return FreScaExtension() From 80718908a9ac1045ece84285ca568511dcc9bc46 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Sat, 27 Sep 2025 00:12:38 +0300 Subject: [PATCH 247/325] convert nodes_sdupscale.py to V3 schema (#9943) --- comfy_extras/nodes_sdupscale.py | 54 +++++++++++++++++++++------------ 1 file changed, 35 insertions(+), 19 deletions(-) diff --git a/comfy_extras/nodes_sdupscale.py b/comfy_extras/nodes_sdupscale.py index bba67e8dd..31b373370 100644 --- a/comfy_extras/nodes_sdupscale.py +++ b/comfy_extras/nodes_sdupscale.py @@ -1,23 +1,31 @@ +from typing_extensions import override + import torch import comfy.utils +from comfy_api.latest import ComfyExtension, io -class SD_4XUpscale_Conditioning: +class SD_4XUpscale_Conditioning(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "images": ("IMAGE",), - "positive": ("CONDITIONING",), - "negative": ("CONDITIONING",), - "scale_ratio": ("FLOAT", {"default": 4.0, "min": 0.0, "max": 10.0, "step": 0.01}), - "noise_augmentation": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}), - }} - RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") - RETURN_NAMES = ("positive", "negative", "latent") + def define_schema(cls): + return io.Schema( + node_id="SD_4XUpscale_Conditioning", + category="conditioning/upscale_diffusion", + inputs=[ + io.Image.Input("images"), + io.Conditioning.Input("positive"), + io.Conditioning.Input("negative"), + io.Float.Input("scale_ratio", default=4.0, min=0.0, max=10.0, step=0.01), + io.Float.Input("noise_augmentation", default=0.0, min=0.0, max=1.0, step=0.001), + ], + outputs=[ + io.Conditioning.Output(display_name="positive"), + io.Conditioning.Output(display_name="negative"), + io.Latent.Output(display_name="latent"), + ], + ) - FUNCTION = "encode" - - CATEGORY = "conditioning/upscale_diffusion" - - def encode(self, images, positive, negative, scale_ratio, noise_augmentation): + @classmethod + def execute(cls, images, positive, negative, scale_ratio, noise_augmentation): width = max(1, round(images.shape[-2] * scale_ratio)) height = max(1, round(images.shape[-3] * scale_ratio)) @@ -39,8 +47,16 @@ class SD_4XUpscale_Conditioning: out_cn.append(n) latent = torch.zeros([images.shape[0], 4, height // 4, width // 4]) - return (out_cp, out_cn, {"samples":latent}) + return io.NodeOutput(out_cp, out_cn, {"samples":latent}) -NODE_CLASS_MAPPINGS = { - "SD_4XUpscale_Conditioning": SD_4XUpscale_Conditioning, -} + +class SdUpscaleExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + SD_4XUpscale_Conditioning, + ] + + +async def comfy_entrypoint() -> SdUpscaleExtension: + return SdUpscaleExtension() From 
a061b06321b4e91d05c7c436b1e9b188360c5377 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Sat, 27 Sep 2025 00:13:05 +0300 Subject: [PATCH 248/325] convert nodes_tcfg.py to V3 schema (#9942) --- comfy_extras/nodes_tcfg.py | 51 +++++++++++++++++++++----------------- 1 file changed, 28 insertions(+), 23 deletions(-) diff --git a/comfy_extras/nodes_tcfg.py b/comfy_extras/nodes_tcfg.py index 35b89a73f..1a6767770 100644 --- a/comfy_extras/nodes_tcfg.py +++ b/comfy_extras/nodes_tcfg.py @@ -1,8 +1,9 @@ # TCFG: Tangential Damping Classifier-free Guidance - (arXiv: https://arxiv.org/abs/2503.18137) +from typing_extensions import override import torch -from comfy.comfy_types import IO, ComfyNodeABC, InputTypeDict +from comfy_api.latest import ComfyExtension, io def score_tangential_damping(cond_score: torch.Tensor, uncond_score: torch.Tensor) -> torch.Tensor: @@ -26,23 +27,24 @@ def score_tangential_damping(cond_score: torch.Tensor, uncond_score: torch.Tenso return uncond_score_td.reshape_as(uncond_score).to(uncond_score.dtype) -class TCFG(ComfyNodeABC): +class TCFG(io.ComfyNode): @classmethod - def INPUT_TYPES(cls) -> InputTypeDict: - return { - "required": { - "model": (IO.MODEL, {}), - } - } + def define_schema(cls): + return io.Schema( + node_id="TCFG", + display_name="Tangential Damping CFG", + category="advanced/guidance", + description="TCFG – Tangential Damping CFG (2503.18137)\n\nRefine the uncond (negative) to align with the cond (positive) for improving quality.", + inputs=[ + io.Model.Input("model"), + ], + outputs=[ + io.Model.Output(display_name="patched_model"), + ], + ) - RETURN_TYPES = (IO.MODEL,) - RETURN_NAMES = ("patched_model",) - FUNCTION = "patch" - - CATEGORY = "advanced/guidance" - DESCRIPTION = "TCFG – Tangential Damping CFG (2503.18137)\n\nRefine the uncond (negative) to align with the cond (positive) for improving quality." 
- - def patch(self, model): + @classmethod + def execute(cls, model): m = model.clone() def tangential_damping_cfg(args): @@ -59,13 +61,16 @@ class TCFG(ComfyNodeABC): return [cond_pred, uncond_pred_td] + conds_out[2:] m.set_model_sampler_pre_cfg_function(tangential_damping_cfg) - return (m,) + return io.NodeOutput(m) -NODE_CLASS_MAPPINGS = { - "TCFG": TCFG, -} +class TcfgExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + TCFG, + ] -NODE_DISPLAY_NAME_MAPPINGS = { - "TCFG": "Tangential Damping CFG", -} + +async def comfy_entrypoint() -> TcfgExtension: + return TcfgExtension() From d20576e6a3527d0763ba8d7a72c70ee66829690a Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Sat, 27 Sep 2025 00:13:52 +0300 Subject: [PATCH 249/325] convert nodes_sag.py to V3 schema (#9940) --- comfy_extras/nodes_sag.py | 50 +++++++++++++++++++++++++-------------- 1 file changed, 32 insertions(+), 18 deletions(-) diff --git a/comfy_extras/nodes_sag.py b/comfy_extras/nodes_sag.py index 1bd8d7364..0f47db30b 100644 --- a/comfy_extras/nodes_sag.py +++ b/comfy_extras/nodes_sag.py @@ -2,10 +2,13 @@ import torch from torch import einsum import torch.nn.functional as F import math +from typing_extensions import override from einops import rearrange, repeat from comfy.ldm.modules.attention import optimized_attention import comfy.samplers +from comfy_api.latest import ComfyExtension, io + # from comfy/ldm/modules/attention.py # but modified to return attention scores as well as output @@ -104,19 +107,26 @@ def gaussian_blur_2d(img, kernel_size, sigma): img = F.conv2d(img, kernel2d, groups=img.shape[-3]) return img -class SelfAttentionGuidance: +class SelfAttentionGuidance(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "model": ("MODEL",), - "scale": ("FLOAT", {"default": 0.5, "min": -2.0, "max": 5.0, "step": 0.01}), - "blur_sigma": ("FLOAT", {"default": 2.0, "min": 0.0, "max": 10.0, "step": 0.1}), - }} - RETURN_TYPES = ("MODEL",) - FUNCTION = "patch" + def define_schema(cls): + return io.Schema( + node_id="SelfAttentionGuidance", + display_name="Self-Attention Guidance", + category="_for_testing", + inputs=[ + io.Model.Input("model"), + io.Float.Input("scale", default=0.5, min=-2.0, max=5.0, step=0.01), + io.Float.Input("blur_sigma", default=2.0, min=0.0, max=10.0, step=0.1), + ], + outputs=[ + io.Model.Output(), + ], + is_experimental=True, + ) - CATEGORY = "_for_testing" - - def patch(self, model, scale, blur_sigma): + @classmethod + def execute(cls, model, scale, blur_sigma): m = model.clone() attn_scores = None @@ -170,12 +180,16 @@ class SelfAttentionGuidance: # unet.mid_block.attentions[0].transformer_blocks[0].attn1.patch m.set_model_attn1_replace(attn_and_record, "middle", 0, 0) - return (m, ) + return io.NodeOutput(m) -NODE_CLASS_MAPPINGS = { - "SelfAttentionGuidance": SelfAttentionGuidance, -} -NODE_DISPLAY_NAME_MAPPINGS = { - "SelfAttentionGuidance": "Self-Attention Guidance", -} +class SagExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + SelfAttentionGuidance, + ] + + +async def comfy_entrypoint() -> SagExtension: + return SagExtension() From 2103e393350d297ef77497a1b14a8199d4a1f1b4 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Sat, 27 Sep 2025 00:14:42 +0300 Subject: [PATCH 250/325] convert nodes_post_processing to V3 schema (#9491) --- 
comfy_extras/nodes_post_processing.py | 249 ++++++++++++-------------- 1 file changed, 111 insertions(+), 138 deletions(-) diff --git a/comfy_extras/nodes_post_processing.py b/comfy_extras/nodes_post_processing.py index ed7a07152..34c388a5a 100644 --- a/comfy_extras/nodes_post_processing.py +++ b/comfy_extras/nodes_post_processing.py @@ -1,3 +1,4 @@ +from typing_extensions import override import numpy as np import torch import torch.nn.functional as F @@ -7,33 +8,27 @@ import math import comfy.utils import comfy.model_management import node_helpers +from comfy_api.latest import ComfyExtension, io -class Blend: - def __init__(self): - pass +class Blend(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="ImageBlend", + category="image/postprocessing", + inputs=[ + io.Image.Input("image1"), + io.Image.Input("image2"), + io.Float.Input("blend_factor", default=0.5, min=0.0, max=1.0, step=0.01), + io.Combo.Input("blend_mode", options=["normal", "multiply", "screen", "overlay", "soft_light", "difference"]), + ], + outputs=[ + io.Image.Output(), + ], + ) @classmethod - def INPUT_TYPES(s): - return { - "required": { - "image1": ("IMAGE",), - "image2": ("IMAGE",), - "blend_factor": ("FLOAT", { - "default": 0.5, - "min": 0.0, - "max": 1.0, - "step": 0.01 - }), - "blend_mode": (["normal", "multiply", "screen", "overlay", "soft_light", "difference"],), - }, - } - - RETURN_TYPES = ("IMAGE",) - FUNCTION = "blend_images" - - CATEGORY = "image/postprocessing" - - def blend_images(self, image1: torch.Tensor, image2: torch.Tensor, blend_factor: float, blend_mode: str): + def execute(cls, image1: torch.Tensor, image2: torch.Tensor, blend_factor: float, blend_mode: str) -> io.NodeOutput: image1, image2 = node_helpers.image_alpha_fix(image1, image2) image2 = image2.to(image1.device) if image1.shape != image2.shape: @@ -41,12 +36,13 @@ class Blend: image2 = comfy.utils.common_upscale(image2, image1.shape[2], image1.shape[1], upscale_method='bicubic', crop='center') image2 = image2.permute(0, 2, 3, 1) - blended_image = self.blend_mode(image1, image2, blend_mode) + blended_image = cls.blend_mode(image1, image2, blend_mode) blended_image = image1 * (1 - blend_factor) + blended_image * blend_factor blended_image = torch.clamp(blended_image, 0, 1) - return (blended_image,) + return io.NodeOutput(blended_image) - def blend_mode(self, img1, img2, mode): + @classmethod + def blend_mode(cls, img1, img2, mode): if mode == "normal": return img2 elif mode == "multiply": @@ -56,13 +52,13 @@ class Blend: elif mode == "overlay": return torch.where(img1 <= 0.5, 2 * img1 * img2, 1 - 2 * (1 - img1) * (1 - img2)) elif mode == "soft_light": - return torch.where(img2 <= 0.5, img1 - (1 - 2 * img2) * img1 * (1 - img1), img1 + (2 * img2 - 1) * (self.g(img1) - img1)) + return torch.where(img2 <= 0.5, img1 - (1 - 2 * img2) * img1 * (1 - img1), img1 + (2 * img2 - 1) * (cls.g(img1) - img1)) elif mode == "difference": return img1 - img2 - else: - raise ValueError(f"Unsupported blend mode: {mode}") + raise ValueError(f"Unsupported blend mode: {mode}") - def g(self, x): + @classmethod + def g(cls, x): return torch.where(x <= 0.25, ((16 * x - 12) * x + 4) * x, torch.sqrt(x)) def gaussian_kernel(kernel_size: int, sigma: float, device=None): @@ -71,38 +67,26 @@ def gaussian_kernel(kernel_size: int, sigma: float, device=None): g = torch.exp(-(d * d) / (2.0 * sigma * sigma)) return g / g.sum() -class Blur: - def __init__(self): - pass +class Blur(io.ComfyNode): + @classmethod + def define_schema(cls): + return 
io.Schema( + node_id="ImageBlur", + category="image/postprocessing", + inputs=[ + io.Image.Input("image"), + io.Int.Input("blur_radius", default=1, min=1, max=31, step=1), + io.Float.Input("sigma", default=1.0, min=0.1, max=10.0, step=0.1), + ], + outputs=[ + io.Image.Output(), + ], + ) @classmethod - def INPUT_TYPES(s): - return { - "required": { - "image": ("IMAGE",), - "blur_radius": ("INT", { - "default": 1, - "min": 1, - "max": 31, - "step": 1 - }), - "sigma": ("FLOAT", { - "default": 1.0, - "min": 0.1, - "max": 10.0, - "step": 0.1 - }), - }, - } - - RETURN_TYPES = ("IMAGE",) - FUNCTION = "blur" - - CATEGORY = "image/postprocessing" - - def blur(self, image: torch.Tensor, blur_radius: int, sigma: float): + def execute(cls, image: torch.Tensor, blur_radius: int, sigma: float) -> io.NodeOutput: if blur_radius == 0: - return (image,) + return io.NodeOutput(image) image = image.to(comfy.model_management.get_torch_device()) batch_size, height, width, channels = image.shape @@ -115,31 +99,24 @@ class Blur: blurred = F.conv2d(padded_image, kernel, padding=kernel_size // 2, groups=channels)[:,:,blur_radius:-blur_radius, blur_radius:-blur_radius] blurred = blurred.permute(0, 2, 3, 1) - return (blurred.to(comfy.model_management.intermediate_device()),) + return io.NodeOutput(blurred.to(comfy.model_management.intermediate_device())) -class Quantize: - def __init__(self): - pass +class Quantize(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return { - "required": { - "image": ("IMAGE",), - "colors": ("INT", { - "default": 256, - "min": 1, - "max": 256, - "step": 1 - }), - "dither": (["none", "floyd-steinberg", "bayer-2", "bayer-4", "bayer-8", "bayer-16"],), - }, - } - - RETURN_TYPES = ("IMAGE",) - FUNCTION = "quantize" - - CATEGORY = "image/postprocessing" + def define_schema(cls): + return io.Schema( + node_id="ImageQuantize", + category="image/postprocessing", + inputs=[ + io.Image.Input("image"), + io.Int.Input("colors", default=256, min=1, max=256, step=1), + io.Combo.Input("dither", options=["none", "floyd-steinberg", "bayer-2", "bayer-4", "bayer-8", "bayer-16"]), + ], + outputs=[ + io.Image.Output(), + ], + ) @staticmethod def bayer(im, pal_im, order): @@ -167,7 +144,8 @@ class Quantize: im = im.quantize(palette=pal_im, dither=Image.Dither.NONE) return im - def quantize(self, image: torch.Tensor, colors: int, dither: str): + @classmethod + def execute(cls, image: torch.Tensor, colors: int, dither: str) -> io.NodeOutput: batch_size, height, width, _ = image.shape result = torch.zeros_like(image) @@ -187,46 +165,29 @@ class Quantize: quantized_array = torch.tensor(np.array(quantized_image.convert("RGB"))).float() / 255 result[b] = quantized_array - return (result,) + return io.NodeOutput(result) -class Sharpen: - def __init__(self): - pass +class Sharpen(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="ImageSharpen", + category="image/postprocessing", + inputs=[ + io.Image.Input("image"), + io.Int.Input("sharpen_radius", default=1, min=1, max=31, step=1), + io.Float.Input("sigma", default=1.0, min=0.1, max=10.0, step=0.01), + io.Float.Input("alpha", default=1.0, min=0.0, max=5.0, step=0.01), + ], + outputs=[ + io.Image.Output(), + ], + ) @classmethod - def INPUT_TYPES(s): - return { - "required": { - "image": ("IMAGE",), - "sharpen_radius": ("INT", { - "default": 1, - "min": 1, - "max": 31, - "step": 1 - }), - "sigma": ("FLOAT", { - "default": 1.0, - "min": 0.1, - "max": 10.0, - "step": 0.01 - }), - "alpha": ("FLOAT", { - "default": 1.0, - "min": 
0.0, - "max": 5.0, - "step": 0.01 - }), - }, - } - - RETURN_TYPES = ("IMAGE",) - FUNCTION = "sharpen" - - CATEGORY = "image/postprocessing" - - def sharpen(self, image: torch.Tensor, sharpen_radius: int, sigma:float, alpha: float): + def execute(cls, image: torch.Tensor, sharpen_radius: int, sigma:float, alpha: float) -> io.NodeOutput: if sharpen_radius == 0: - return (image,) + return io.NodeOutput(image) batch_size, height, width, channels = image.shape image = image.to(comfy.model_management.get_torch_device()) @@ -245,23 +206,29 @@ class Sharpen: result = torch.clamp(sharpened, 0, 1) - return (result.to(comfy.model_management.intermediate_device()),) + return io.NodeOutput(result.to(comfy.model_management.intermediate_device())) -class ImageScaleToTotalPixels: +class ImageScaleToTotalPixels(io.ComfyNode): upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "lanczos"] crop_methods = ["disabled", "center"] @classmethod - def INPUT_TYPES(s): - return {"required": { "image": ("IMAGE",), "upscale_method": (s.upscale_methods,), - "megapixels": ("FLOAT", {"default": 1.0, "min": 0.01, "max": 16.0, "step": 0.01}), - }} - RETURN_TYPES = ("IMAGE",) - FUNCTION = "upscale" + def define_schema(cls): + return io.Schema( + node_id="ImageScaleToTotalPixels", + category="image/upscaling", + inputs=[ + io.Image.Input("image"), + io.Combo.Input("upscale_method", options=cls.upscale_methods), + io.Float.Input("megapixels", default=1.0, min=0.01, max=16.0, step=0.01), + ], + outputs=[ + io.Image.Output(), + ], + ) - CATEGORY = "image/upscaling" - - def upscale(self, image, upscale_method, megapixels): + @classmethod + def execute(cls, image, upscale_method, megapixels) -> io.NodeOutput: samples = image.movedim(-1,1) total = int(megapixels * 1024 * 1024) @@ -271,12 +238,18 @@ class ImageScaleToTotalPixels: s = comfy.utils.common_upscale(samples, width, height, upscale_method, "disabled") s = s.movedim(1,-1) - return (s,) + return io.NodeOutput(s) -NODE_CLASS_MAPPINGS = { - "ImageBlend": Blend, - "ImageBlur": Blur, - "ImageQuantize": Quantize, - "ImageSharpen": Sharpen, - "ImageScaleToTotalPixels": ImageScaleToTotalPixels, -} +class PostProcessingExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + Blend, + Blur, + Quantize, + Sharpen, + ImageScaleToTotalPixels, + ] + +async def comfy_entrypoint() -> PostProcessingExtension: + return PostProcessingExtension() From cd66d72b464fd9d344baa426b50a5f0e5e512f99 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Sat, 27 Sep 2025 00:15:44 +0300 Subject: [PATCH 251/325] convert CLIPTextEncodeSDXL nodes to V3 schema (#9716) --- comfy_extras/nodes_clip_sdxl.py | 93 +++++++++++++++++++-------------- 1 file changed, 55 insertions(+), 38 deletions(-) diff --git a/comfy_extras/nodes_clip_sdxl.py b/comfy_extras/nodes_clip_sdxl.py index 14269caf3..520ff0e3c 100644 --- a/comfy_extras/nodes_clip_sdxl.py +++ b/comfy_extras/nodes_clip_sdxl.py @@ -1,43 +1,52 @@ -from nodes import MAX_RESOLUTION +from typing_extensions import override -class CLIPTextEncodeSDXLRefiner: +import nodes +from comfy_api.latest import ComfyExtension, io + + +class CLIPTextEncodeSDXLRefiner(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { - "ascore": ("FLOAT", {"default": 6.0, "min": 0.0, "max": 1000.0, "step": 0.01}), - "width": ("INT", {"default": 1024.0, "min": 0, "max": MAX_RESOLUTION}), - "height": ("INT", {"default": 1024.0, "min": 0, "max": 
MAX_RESOLUTION}), - "text": ("STRING", {"multiline": True, "dynamicPrompts": True}), "clip": ("CLIP", ), - }} - RETURN_TYPES = ("CONDITIONING",) - FUNCTION = "encode" + def define_schema(cls): + return io.Schema( + node_id="CLIPTextEncodeSDXLRefiner", + category="advanced/conditioning", + inputs=[ + io.Float.Input("ascore", default=6.0, min=0.0, max=1000.0, step=0.01), + io.Int.Input("width", default=1024, min=0, max=nodes.MAX_RESOLUTION), + io.Int.Input("height", default=1024, min=0, max=nodes.MAX_RESOLUTION), + io.String.Input("text", multiline=True, dynamic_prompts=True), + io.Clip.Input("clip"), + ], + outputs=[io.Conditioning.Output()], + ) - CATEGORY = "advanced/conditioning" - - def encode(self, clip, ascore, width, height, text): + @classmethod + def execute(cls, clip, ascore, width, height, text) -> io.NodeOutput: tokens = clip.tokenize(text) - return (clip.encode_from_tokens_scheduled(tokens, add_dict={"aesthetic_score": ascore, "width": width, "height": height}), ) + return io.NodeOutput(clip.encode_from_tokens_scheduled(tokens, add_dict={"aesthetic_score": ascore, "width": width, "height": height})) -class CLIPTextEncodeSDXL: +class CLIPTextEncodeSDXL(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { - "clip": ("CLIP", ), - "width": ("INT", {"default": 1024.0, "min": 0, "max": MAX_RESOLUTION}), - "height": ("INT", {"default": 1024.0, "min": 0, "max": MAX_RESOLUTION}), - "crop_w": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION}), - "crop_h": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION}), - "target_width": ("INT", {"default": 1024.0, "min": 0, "max": MAX_RESOLUTION}), - "target_height": ("INT", {"default": 1024.0, "min": 0, "max": MAX_RESOLUTION}), - "text_g": ("STRING", {"multiline": True, "dynamicPrompts": True}), - "text_l": ("STRING", {"multiline": True, "dynamicPrompts": True}), - }} - RETURN_TYPES = ("CONDITIONING",) - FUNCTION = "encode" + def define_schema(cls): + return io.Schema( + node_id="CLIPTextEncodeSDXL", + category="advanced/conditioning", + inputs=[ + io.Clip.Input("clip"), + io.Int.Input("width", default=1024, min=0, max=nodes.MAX_RESOLUTION), + io.Int.Input("height", default=1024, min=0, max=nodes.MAX_RESOLUTION), + io.Int.Input("crop_w", default=0, min=0, max=nodes.MAX_RESOLUTION), + io.Int.Input("crop_h", default=0, min=0, max=nodes.MAX_RESOLUTION), + io.Int.Input("target_width", default=1024, min=0, max=nodes.MAX_RESOLUTION), + io.Int.Input("target_height", default=1024, min=0, max=nodes.MAX_RESOLUTION), + io.String.Input("text_g", multiline=True, dynamic_prompts=True), + io.String.Input("text_l", multiline=True, dynamic_prompts=True), + ], + outputs=[io.Conditioning.Output()], + ) - CATEGORY = "advanced/conditioning" - - def encode(self, clip, width, height, crop_w, crop_h, target_width, target_height, text_g, text_l): + @classmethod + def execute(cls, clip, width, height, crop_w, crop_h, target_width, target_height, text_g, text_l) -> io.NodeOutput: tokens = clip.tokenize(text_g) tokens["l"] = clip.tokenize(text_l)["l"] if len(tokens["l"]) != len(tokens["g"]): @@ -46,9 +55,17 @@ class CLIPTextEncodeSDXL: tokens["l"] += empty["l"] while len(tokens["l"]) > len(tokens["g"]): tokens["g"] += empty["g"] - return (clip.encode_from_tokens_scheduled(tokens, add_dict={"width": width, "height": height, "crop_w": crop_w, "crop_h": crop_h, "target_width": target_width, "target_height": target_height}), ) + return io.NodeOutput(clip.encode_from_tokens_scheduled(tokens, add_dict={"width": width, "height": height, "crop_w": 
crop_w, "crop_h": crop_h, "target_width": target_width, "target_height": target_height})) -NODE_CLASS_MAPPINGS = { - "CLIPTextEncodeSDXLRefiner": CLIPTextEncodeSDXLRefiner, - "CLIPTextEncodeSDXL": CLIPTextEncodeSDXL, -} + +class ClipSdxlExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + CLIPTextEncodeSDXLRefiner, + CLIPTextEncodeSDXL, + ] + + +async def comfy_entrypoint() -> ClipSdxlExtension: + return ClipSdxlExtension() From 1e098d61327e1c02c1a47b2626514474aa8e3c7e Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Fri, 26 Sep 2025 15:34:17 -0700 Subject: [PATCH 252/325] Don't add template to qwen2.5vl when template is in prompt. (#10043) Make the hunyuan image refiner template_end 36. --- comfy/text_encoders/hunyuan_image.py | 8 ++++- comfy/text_encoders/qwen_image.py | 46 +++++++++++++++++----------- 2 files changed, 35 insertions(+), 19 deletions(-) diff --git a/comfy/text_encoders/hunyuan_image.py b/comfy/text_encoders/hunyuan_image.py index 699eddc33..ff04726e1 100644 --- a/comfy/text_encoders/hunyuan_image.py +++ b/comfy/text_encoders/hunyuan_image.py @@ -63,7 +63,13 @@ class HunyuanImageTEModel(QwenImageTEModel): self.byt5_small = None def encode_token_weights(self, token_weight_pairs): - cond, p, extra = super().encode_token_weights(token_weight_pairs) + tok_pairs = token_weight_pairs["qwen25_7b"][0] + template_end = -1 + if tok_pairs[0][0] == 27: + if len(tok_pairs) > 36: # refiner prompt uses a fixed 36 template_end + template_end = 36 + + cond, p, extra = super().encode_token_weights(token_weight_pairs, template_end=template_end) if self.byt5_small is not None and "byt5" in token_weight_pairs: out = self.byt5_small.encode_token_weights(token_weight_pairs["byt5"]) extra["conditioning_byt5small"] = out[0] diff --git a/comfy/text_encoders/qwen_image.py b/comfy/text_encoders/qwen_image.py index 6646b1003..40fa67937 100644 --- a/comfy/text_encoders/qwen_image.py +++ b/comfy/text_encoders/qwen_image.py @@ -18,13 +18,22 @@ class QwenImageTokenizer(sd1_clip.SD1Tokenizer): self.llama_template_images = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. 
Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n" def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, images=[], **kwargs): - if llama_template is None: - if len(images) > 0: - llama_text = self.llama_template_images.format(text) - else: - llama_text = self.llama_template.format(text) + skip_template = False + if text.startswith('<|im_start|>'): + skip_template = True + if text.startswith('<|start_header_id|>'): + skip_template = True + + if skip_template: + llama_text = text else: - llama_text = llama_template.format(text) + if llama_template is None: + if len(images) > 0: + llama_text = self.llama_template_images.format(text) + else: + llama_text = self.llama_template.format(text) + else: + llama_text = llama_template.format(text) tokens = super().tokenize_with_weights(llama_text, return_word_ids=return_word_ids, disable_weights=True, **kwargs) key_name = next(iter(tokens)) embed_count = 0 @@ -47,22 +56,23 @@ class QwenImageTEModel(sd1_clip.SD1ClipModel): def __init__(self, device="cpu", dtype=None, model_options={}): super().__init__(device=device, dtype=dtype, name="qwen25_7b", clip_model=Qwen25_7BVLIModel, model_options=model_options) - def encode_token_weights(self, token_weight_pairs): + def encode_token_weights(self, token_weight_pairs, template_end=-1): out, pooled, extra = super().encode_token_weights(token_weight_pairs) tok_pairs = token_weight_pairs["qwen25_7b"][0] count_im_start = 0 - for i, v in enumerate(tok_pairs): - elem = v[0] - if not torch.is_tensor(elem): - if isinstance(elem, numbers.Integral): - if elem == 151644 and count_im_start < 2: - template_end = i - count_im_start += 1 + if template_end == -1: + for i, v in enumerate(tok_pairs): + elem = v[0] + if not torch.is_tensor(elem): + if isinstance(elem, numbers.Integral): + if elem == 151644 and count_im_start < 2: + template_end = i + count_im_start += 1 - if out.shape[1] > (template_end + 3): - if tok_pairs[template_end + 1][0] == 872: - if tok_pairs[template_end + 2][0] == 198: - template_end += 3 + if out.shape[1] > (template_end + 3): + if tok_pairs[template_end + 1][0] == 872: + if tok_pairs[template_end + 2][0] == 198: + template_end += 3 out = out[:, template_end:] From 196954ab8c55bc4ac48113686a57ce250677c7b5 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Fri, 26 Sep 2025 19:55:03 -0700 Subject: [PATCH 253/325] Add 'input_cond' and 'input_uncond' to the args dictionary passed into sampler_cfg_function (#10044) --- comfy/samplers.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/comfy/samplers.py b/comfy/samplers.py index b3202cec6..c59e296a1 100755 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -360,7 +360,7 @@ def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options): def cfg_function(model, cond_pred, uncond_pred, cond_scale, x, timestep, model_options={}, cond=None, uncond=None): if "sampler_cfg_function" in model_options: args = {"cond": x - cond_pred, "uncond": x - uncond_pred, "cond_scale": cond_scale, "timestep": timestep, "input": x, "sigma": timestep, - "cond_denoised": cond_pred, "uncond_denoised": uncond_pred, "model": model, "model_options": model_options} + "cond_denoised": cond_pred, "uncond_denoised": uncond_pred, "model": model, "model_options": model_options, "input_cond": cond, "input_uncond": uncond} cfg_result = x - 
model_options["sampler_cfg_function"](args) else: cfg_result = uncond_pred + (cond_pred - uncond_pred) * cond_scale @@ -390,7 +390,7 @@ def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_option for fn in model_options.get("sampler_pre_cfg_function", []): args = {"conds":conds, "conds_out": out, "cond_scale": cond_scale, "timestep": timestep, "input": x, "sigma": timestep, "model": model, "model_options": model_options} - out = fn(args) + out = fn(args) return cfg_function(model, out[0], out[1], cond_scale, x, timestep, model_options=model_options, cond=cond, uncond=uncond_) From 0572029fee48741a8cf34a8e4d485898c5ab5dfd Mon Sep 17 00:00:00 2001 From: ComfyUI Wiki Date: Sat, 27 Sep 2025 12:18:16 +0800 Subject: [PATCH 254/325] Update template to 0.1.88 (#10046) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 2980bebdd..b3f81e8fa 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.26.13 -comfyui-workflow-templates==0.1.86 +comfyui-workflow-templates==0.1.88 comfyui-embedded-docs==0.2.6 torch torchsde From 255572188f79e5c58fa997bf73529021129459a9 Mon Sep 17 00:00:00 2001 From: Christian Byrne Date: Fri, 26 Sep 2025 21:29:13 -0700 Subject: [PATCH 255/325] Add workflow templates version tracking to system_stats (#9089) Adds installed and required workflow templates version information to the /system_stats endpoint, allowing the frontend to detect and notify users when their templates package is outdated. - Add get_installed_templates_version() and get_required_templates_version() methods to FrontendManager - Include templates version info in system_stats response - Add comprehensive unit tests for the new functionality --- app/frontend_management.py | 33 +++++++++ server.py | 4 ++ tests-unit/app_test/frontend_manager_test.py | 71 ++++++++++++++++++++ 3 files changed, 108 insertions(+) diff --git a/app/frontend_management.py b/app/frontend_management.py index 0bee73685..cce0c117d 100644 --- a/app/frontend_management.py +++ b/app/frontend_management.py @@ -42,6 +42,7 @@ def get_installed_frontend_version(): frontend_version_str = version("comfyui-frontend-package") return frontend_version_str + def get_required_frontend_version(): """Get the required frontend version from requirements.txt.""" try: @@ -63,6 +64,7 @@ def get_required_frontend_version(): logging.error(f"Error reading requirements.txt: {e}") return None + def check_frontend_version(): """Check if the frontend version is up to date.""" @@ -203,6 +205,37 @@ class FrontendManager: """Get the required frontend package version.""" return get_required_frontend_version() + @classmethod + def get_installed_templates_version(cls) -> str: + """Get the currently installed workflow templates package version.""" + try: + templates_version_str = version("comfyui-workflow-templates") + return templates_version_str + except Exception: + return None + + @classmethod + def get_required_templates_version(cls) -> str: + """Get the required workflow templates version from requirements.txt.""" + try: + with open(requirements_path, "r", encoding="utf-8") as f: + for line in f: + line = line.strip() + if line.startswith("comfyui-workflow-templates=="): + version_str = line.split("==")[-1] + if not is_valid_version(version_str): + logging.error(f"Invalid templates version format in requirements.txt: {version_str}") + return None + return version_str + logging.error("comfyui-workflow-templates not found in 
requirements.txt") + return None + except FileNotFoundError: + logging.error("requirements.txt not found. Cannot determine required templates version.") + return None + except Exception as e: + logging.error(f"Error reading requirements.txt: {e}") + return None + @classmethod def default_frontend_path(cls) -> str: try: diff --git a/server.py b/server.py index 603677397..80e9d3fa7 100644 --- a/server.py +++ b/server.py @@ -550,6 +550,8 @@ class PromptServer(): vram_total, torch_vram_total = comfy.model_management.get_total_memory(device, torch_total_too=True) vram_free, torch_vram_free = comfy.model_management.get_free_memory(device, torch_free_too=True) required_frontend_version = FrontendManager.get_required_frontend_version() + installed_templates_version = FrontendManager.get_installed_templates_version() + required_templates_version = FrontendManager.get_required_templates_version() system_stats = { "system": { @@ -558,6 +560,8 @@ class PromptServer(): "ram_free": ram_free, "comfyui_version": __version__, "required_frontend_version": required_frontend_version, + "installed_templates_version": installed_templates_version, + "required_templates_version": required_templates_version, "python_version": sys.version, "pytorch_version": comfy.model_management.torch_version, "embedded_python": os.path.split(os.path.split(sys.executable)[0])[1] == "python_embeded", diff --git a/tests-unit/app_test/frontend_manager_test.py b/tests-unit/app_test/frontend_manager_test.py index ce43ac564..643f04e72 100644 --- a/tests-unit/app_test/frontend_manager_test.py +++ b/tests-unit/app_test/frontend_manager_test.py @@ -205,3 +205,74 @@ numpy""" # Assert assert version is None + + +def test_get_templates_version(): + # Arrange + expected_version = "0.1.41" + mock_requirements_content = """torch +torchsde +comfyui-frontend-package==1.25.0 +comfyui-workflow-templates==0.1.41 +other-package==1.0.0 +numpy""" + + # Act + with patch("builtins.open", mock_open(read_data=mock_requirements_content)): + version = FrontendManager.get_required_templates_version() + + # Assert + assert version == expected_version + + +def test_get_templates_version_not_found(): + # Arrange + mock_requirements_content = """torch +torchsde +comfyui-frontend-package==1.25.0 +other-package==1.0.0 +numpy""" + + # Act + with patch("builtins.open", mock_open(read_data=mock_requirements_content)): + version = FrontendManager.get_required_templates_version() + + # Assert + assert version is None + + +def test_get_templates_version_invalid_semver(): + # Arrange + mock_requirements_content = """torch +torchsde +comfyui-workflow-templates==1.0.0.beta +other-package==1.0.0 +numpy""" + + # Act + with patch("builtins.open", mock_open(read_data=mock_requirements_content)): + version = FrontendManager.get_required_templates_version() + + # Assert + assert version is None + + +def test_get_installed_templates_version(): + # Arrange + expected_version = "0.1.40" + + # Act + with patch("app.frontend_management.version", return_value=expected_version): + version = FrontendManager.get_installed_templates_version() + + # Assert + assert version == expected_version + + +def test_get_installed_templates_version_not_installed(): + # Act + with patch("app.frontend_management.version", side_effect=Exception("Package not found")): + version = FrontendManager.get_installed_templates_version() + + # Assert + assert version is None From a9cf1cd249773632949bec2262f921f64378127f Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> 
Date: Sat, 27 Sep 2025 09:13:05 +0300 Subject: [PATCH 256/325] convert nodes_hidream.py to V3 schema (#9946) --- comfy_extras/nodes_hidream.py | 88 +++++++++++++++++++++-------------- 1 file changed, 53 insertions(+), 35 deletions(-) diff --git a/comfy_extras/nodes_hidream.py b/comfy_extras/nodes_hidream.py index dfb98597b..eee683ee1 100644 --- a/comfy_extras/nodes_hidream.py +++ b/comfy_extras/nodes_hidream.py @@ -1,55 +1,73 @@ +from typing_extensions import override + import folder_paths import comfy.sd import comfy.model_management +from comfy_api.latest import ComfyExtension, io -class QuadrupleCLIPLoader: +class QuadrupleCLIPLoader(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "clip_name1": (folder_paths.get_filename_list("text_encoders"), ), - "clip_name2": (folder_paths.get_filename_list("text_encoders"), ), - "clip_name3": (folder_paths.get_filename_list("text_encoders"), ), - "clip_name4": (folder_paths.get_filename_list("text_encoders"), ) - }} - RETURN_TYPES = ("CLIP",) - FUNCTION = "load_clip" + def define_schema(cls): + return io.Schema( + node_id="QuadrupleCLIPLoader", + category="advanced/loaders", + description="[Recipes]\n\nhidream: long clip-l, long clip-g, t5xxl, llama_8b_3.1_instruct", + inputs=[ + io.Combo.Input("clip_name1", options=folder_paths.get_filename_list("text_encoders")), + io.Combo.Input("clip_name2", options=folder_paths.get_filename_list("text_encoders")), + io.Combo.Input("clip_name3", options=folder_paths.get_filename_list("text_encoders")), + io.Combo.Input("clip_name4", options=folder_paths.get_filename_list("text_encoders")), + ], + outputs=[ + io.Clip.Output(), + ] + ) - CATEGORY = "advanced/loaders" - - DESCRIPTION = "[Recipes]\n\nhidream: long clip-l, long clip-g, t5xxl, llama_8b_3.1_instruct" - - def load_clip(self, clip_name1, clip_name2, clip_name3, clip_name4): + @classmethod + def execute(cls, clip_name1, clip_name2, clip_name3, clip_name4): clip_path1 = folder_paths.get_full_path_or_raise("text_encoders", clip_name1) clip_path2 = folder_paths.get_full_path_or_raise("text_encoders", clip_name2) clip_path3 = folder_paths.get_full_path_or_raise("text_encoders", clip_name3) clip_path4 = folder_paths.get_full_path_or_raise("text_encoders", clip_name4) clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2, clip_path3, clip_path4], embedding_directory=folder_paths.get_folder_paths("embeddings")) - return (clip,) + return io.NodeOutput(clip) -class CLIPTextEncodeHiDream: +class CLIPTextEncodeHiDream(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { - "clip": ("CLIP", ), - "clip_l": ("STRING", {"multiline": True, "dynamicPrompts": True}), - "clip_g": ("STRING", {"multiline": True, "dynamicPrompts": True}), - "t5xxl": ("STRING", {"multiline": True, "dynamicPrompts": True}), - "llama": ("STRING", {"multiline": True, "dynamicPrompts": True}) - }} - RETURN_TYPES = ("CONDITIONING",) - FUNCTION = "encode" - - CATEGORY = "advanced/conditioning" - - def encode(self, clip, clip_l, clip_g, t5xxl, llama): + def define_schema(cls): + return io.Schema( + node_id="CLIPTextEncodeHiDream", + category="advanced/conditioning", + inputs=[ + io.Clip.Input("clip"), + io.String.Input("clip_l", multiline=True, dynamic_prompts=True), + io.String.Input("clip_g", multiline=True, dynamic_prompts=True), + io.String.Input("t5xxl", multiline=True, dynamic_prompts=True), + io.String.Input("llama", multiline=True, dynamic_prompts=True), + ], + outputs=[ + io.Conditioning.Output(), + ] + ) + @classmethod + def execute(cls, 
clip, clip_l, clip_g, t5xxl, llama): tokens = clip.tokenize(clip_g) tokens["l"] = clip.tokenize(clip_l)["l"] tokens["t5xxl"] = clip.tokenize(t5xxl)["t5xxl"] tokens["llama"] = clip.tokenize(llama)["llama"] - return (clip.encode_from_tokens_scheduled(tokens), ) + return io.NodeOutput(clip.encode_from_tokens_scheduled(tokens)) -NODE_CLASS_MAPPINGS = { - "QuadrupleCLIPLoader": QuadrupleCLIPLoader, - "CLIPTextEncodeHiDream": CLIPTextEncodeHiDream, -} + +class HiDreamExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + QuadrupleCLIPLoader, + CLIPTextEncodeHiDream, + ] + + +async def comfy_entrypoint() -> HiDreamExtension: + return HiDreamExtension() From 6b4b671ce7b6c412c2db9f9f83ff8e27dbcfd959 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Sat, 27 Sep 2025 12:27:01 +0300 Subject: [PATCH 257/325] convert nodes_bfl.py to V3 schema (#10033) --- comfy_api_nodes/nodes_bfl.py | 1056 ++++++++++++++++------------------ 1 file changed, 489 insertions(+), 567 deletions(-) diff --git a/comfy_api_nodes/nodes_bfl.py b/comfy_api_nodes/nodes_bfl.py index c09be8d5b..77914021d 100644 --- a/comfy_api_nodes/nodes_bfl.py +++ b/comfy_api_nodes/nodes_bfl.py @@ -2,7 +2,8 @@ import asyncio import io from inspect import cleandoc from typing import Union, Optional -from comfy.comfy_types.node_typing import IO, ComfyNodeABC +from typing_extensions import override +from comfy_api.latest import ComfyExtension, io as comfy_io from comfy_api_nodes.apis.bfl_api import ( BFLStatus, BFLFluxExpandImageRequest, @@ -130,7 +131,7 @@ def convert_image_to_base64(image: torch.Tensor): return base64.b64encode(img_byte_arr.getvalue()).decode() -class FluxProUltraImageNode(ComfyNodeABC): +class FluxProUltraImageNode(comfy_io.ComfyNode): """ Generates images using Flux Pro 1.1 Ultra via api based on prompt and resolution. """ @@ -141,71 +142,67 @@ class FluxProUltraImageNode(ComfyNodeABC): MAXIMUM_RATIO_STR = "4:1" @classmethod - def INPUT_TYPES(s): - return { - "required": { - "prompt": ( - IO.STRING, - { - "multiline": True, - "default": "", - "tooltip": "Prompt for the image generation", - }, + def define_schema(cls) -> comfy_io.Schema: + return comfy_io.Schema( + node_id="FluxProUltraImageNode", + display_name="Flux 1.1 [pro] Ultra Image", + category="api node/image/BFL", + description=cleandoc(cls.__doc__ or ""), + inputs=[ + comfy_io.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Prompt for the image generation", ), - "prompt_upsampling": ( - IO.BOOLEAN, - { - "default": False, - "tooltip": "Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation, but results are nondeterministic (same seed will not produce exactly the same result).", - }, + comfy_io.Boolean.Input( + "prompt_upsampling", + default=False, + tooltip="Whether to perform upsampling on the prompt. 
If active, automatically modifies the prompt for more creative generation, but results are nondeterministic (same seed will not produce exactly the same result).", ), - "seed": ( - IO.INT, - { - "default": 0, - "min": 0, - "max": 0xFFFFFFFFFFFFFFFF, - "control_after_generate": True, - "tooltip": "The random seed used for creating the noise.", - }, + comfy_io.Int.Input( + "seed", + default=0, + min=0, + max=0xFFFFFFFFFFFFFFFF, + control_after_generate=True, + tooltip="The random seed used for creating the noise.", ), - "aspect_ratio": ( - IO.STRING, - { - "default": "16:9", - "tooltip": "Aspect ratio of image; must be between 1:4 and 4:1.", - }, + comfy_io.String.Input( + "aspect_ratio", + default="16:9", + tooltip="Aspect ratio of image; must be between 1:4 and 4:1.", ), - "raw": ( - IO.BOOLEAN, - { - "default": False, - "tooltip": "When True, generate less processed, more natural-looking images.", - }, + comfy_io.Boolean.Input( + "raw", + default=False, + tooltip="When True, generate less processed, more natural-looking images.", ), - }, - "optional": { - "image_prompt": (IO.IMAGE,), - "image_prompt_strength": ( - IO.FLOAT, - { - "default": 0.1, - "min": 0.0, - "max": 1.0, - "step": 0.01, - "tooltip": "Blend between the prompt and the image prompt.", - }, + comfy_io.Image.Input( + "image_prompt", + optional=True, ), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } + comfy_io.Float.Input( + "image_prompt_strength", + default=0.1, + min=0.0, + max=1.0, + step=0.01, + tooltip="Blend between the prompt and the image prompt.", + optional=True, + ), + ], + outputs=[comfy_io.Image.Output()], + hidden=[ + comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + comfy_io.Hidden.unique_id, + ], + is_api_node=True, + ) @classmethod - def VALIDATE_INPUTS(cls, aspect_ratio: str): + def validate_inputs(cls, aspect_ratio: str): try: validate_aspect_ratio( aspect_ratio, @@ -218,14 +215,9 @@ class FluxProUltraImageNode(ComfyNodeABC): return str(e) return True - RETURN_TYPES = (IO.IMAGE,) - DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value - FUNCTION = "api_call" - API_NODE = True - CATEGORY = "api node/image/BFL" - - async def api_call( - self, + @classmethod + async def execute( + cls, prompt: str, aspect_ratio: str, prompt_upsampling=False, @@ -233,9 +225,7 @@ class FluxProUltraImageNode(ComfyNodeABC): seed=0, image_prompt=None, image_prompt_strength=0.1, - unique_id: Union[str, None] = None, - **kwargs, - ): + ) -> comfy_io.NodeOutput: if image_prompt is None: validate_string(prompt, strip_whitespace=False) operation = SynchronousOperation( @@ -251,10 +241,10 @@ class FluxProUltraImageNode(ComfyNodeABC): seed=seed, aspect_ratio=validate_aspect_ratio( aspect_ratio, - minimum_ratio=self.MINIMUM_RATIO, - maximum_ratio=self.MAXIMUM_RATIO, - minimum_ratio_str=self.MINIMUM_RATIO_STR, - maximum_ratio_str=self.MAXIMUM_RATIO_STR, + minimum_ratio=cls.MINIMUM_RATIO, + maximum_ratio=cls.MAXIMUM_RATIO, + minimum_ratio_str=cls.MINIMUM_RATIO_STR, + maximum_ratio_str=cls.MAXIMUM_RATIO_STR, ), raw=raw, image_prompt=( @@ -266,13 +256,16 @@ class FluxProUltraImageNode(ComfyNodeABC): None if image_prompt is None else round(image_prompt_strength, 2) ), ), - auth_kwargs=kwargs, + auth_kwargs={ + "auth_token": cls.hidden.auth_token_comfy_org, + "comfy_api_key": cls.hidden.api_key_comfy_org, + }, ) - output_image = await handle_bfl_synchronous_operation(operation, node_id=unique_id) - return 
(output_image,) + output_image = await handle_bfl_synchronous_operation(operation, node_id=cls.hidden.unique_id) + return comfy_io.NodeOutput(output_image) -class FluxKontextProImageNode(ComfyNodeABC): +class FluxKontextProImageNode(comfy_io.ComfyNode): """ Edits images using Flux.1 Kontext [pro] via api based on prompt and aspect ratio. """ @@ -283,81 +276,73 @@ class FluxKontextProImageNode(ComfyNodeABC): MAXIMUM_RATIO_STR = "4:1" @classmethod - def INPUT_TYPES(s): - return { - "required": { - "prompt": ( - IO.STRING, - { - "multiline": True, - "default": "", - "tooltip": "Prompt for the image generation - specify what and how to edit.", - }, + def define_schema(cls) -> comfy_io.Schema: + return comfy_io.Schema( + node_id=cls.NODE_ID, + display_name=cls.DISPLAY_NAME, + category="api node/image/BFL", + description=cleandoc(cls.__doc__ or ""), + inputs=[ + comfy_io.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Prompt for the image generation - specify what and how to edit.", ), - "aspect_ratio": ( - IO.STRING, - { - "default": "16:9", - "tooltip": "Aspect ratio of image; must be between 1:4 and 4:1.", - }, + comfy_io.String.Input( + "aspect_ratio", + default="16:9", + tooltip="Aspect ratio of image; must be between 1:4 and 4:1.", ), - "guidance": ( - IO.FLOAT, - { - "default": 3.0, - "min": 0.1, - "max": 99.0, - "step": 0.1, - "tooltip": "Guidance strength for the image generation process" - }, + comfy_io.Float.Input( + "guidance", + default=3.0, + min=0.1, + max=99.0, + step=0.1, + tooltip="Guidance strength for the image generation process", ), - "steps": ( - IO.INT, - { - "default": 50, - "min": 1, - "max": 150, - "tooltip": "Number of steps for the image generation process" - }, + comfy_io.Int.Input( + "steps", + default=50, + min=1, + max=150, + tooltip="Number of steps for the image generation process", ), - "seed": ( - IO.INT, - { - "default": 1234, - "min": 0, - "max": 0xFFFFFFFFFFFFFFFF, - "control_after_generate": True, - "tooltip": "The random seed used for creating the noise.", - }, + comfy_io.Int.Input( + "seed", + default=1234, + min=0, + max=0xFFFFFFFFFFFFFFFF, + control_after_generate=True, + tooltip="The random seed used for creating the noise.", ), - "prompt_upsampling": ( - IO.BOOLEAN, - { - "default": False, - "tooltip": "Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation, but results are nondeterministic (same seed will not produce exactly the same result).", - }, + comfy_io.Boolean.Input( + "prompt_upsampling", + default=False, + tooltip="Whether to perform upsampling on the prompt. 
If active, automatically modifies the prompt for more creative generation, but results are nondeterministic (same seed will not produce exactly the same result).", ), - }, - "optional": { - "input_image": (IO.IMAGE,), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } - - RETURN_TYPES = (IO.IMAGE,) - DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value - FUNCTION = "api_call" - API_NODE = True - CATEGORY = "api node/image/BFL" + comfy_io.Image.Input( + "input_image", + optional=True, + ), + ], + outputs=[comfy_io.Image.Output()], + hidden=[ + comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + comfy_io.Hidden.unique_id, + ], + is_api_node=True, + ) BFL_PATH = "/proxy/bfl/flux-kontext-pro/generate" + NODE_ID = "FluxKontextProImageNode" + DISPLAY_NAME = "Flux.1 Kontext [pro] Image" - async def api_call( - self, + @classmethod + async def execute( + cls, prompt: str, aspect_ratio: str, guidance: float, @@ -365,21 +350,19 @@ class FluxKontextProImageNode(ComfyNodeABC): input_image: Optional[torch.Tensor]=None, seed=0, prompt_upsampling=False, - unique_id: Union[str, None] = None, - **kwargs, - ): + ) -> comfy_io.NodeOutput: aspect_ratio = validate_aspect_ratio( aspect_ratio, - minimum_ratio=self.MINIMUM_RATIO, - maximum_ratio=self.MAXIMUM_RATIO, - minimum_ratio_str=self.MINIMUM_RATIO_STR, - maximum_ratio_str=self.MAXIMUM_RATIO_STR, + minimum_ratio=cls.MINIMUM_RATIO, + maximum_ratio=cls.MAXIMUM_RATIO, + minimum_ratio_str=cls.MINIMUM_RATIO_STR, + maximum_ratio_str=cls.MAXIMUM_RATIO_STR, ) if input_image is None: validate_string(prompt, strip_whitespace=False) operation = SynchronousOperation( endpoint=ApiEndpoint( - path=self.BFL_PATH, + path=cls.BFL_PATH, method=HttpMethod.POST, request_model=BFLFluxKontextProGenerateRequest, response_model=BFLFluxProGenerateResponse, @@ -397,10 +380,13 @@ class FluxKontextProImageNode(ComfyNodeABC): else convert_image_to_base64(input_image) ) ), - auth_kwargs=kwargs, + auth_kwargs={ + "auth_token": cls.hidden.auth_token_comfy_org, + "comfy_api_key": cls.hidden.api_key_comfy_org, + }, ) - output_image = await handle_bfl_synchronous_operation(operation, node_id=unique_id) - return (output_image,) + output_image = await handle_bfl_synchronous_operation(operation, node_id=cls.hidden.unique_id) + return comfy_io.NodeOutput(output_image) class FluxKontextMaxImageNode(FluxKontextProImageNode): @@ -410,63 +396,60 @@ class FluxKontextMaxImageNode(FluxKontextProImageNode): DESCRIPTION = cleandoc(__doc__ or "") BFL_PATH = "/proxy/bfl/flux-kontext-max/generate" + NODE_ID = "FluxKontextMaxImageNode" + DISPLAY_NAME = "Flux.1 Kontext [max] Image" -class FluxProImageNode(ComfyNodeABC): +class FluxProImageNode(comfy_io.ComfyNode): """ Generates images synchronously based on prompt and resolution. """ @classmethod - def INPUT_TYPES(s): - return { - "required": { - "prompt": ( - IO.STRING, - { - "multiline": True, - "default": "", - "tooltip": "Prompt for the image generation", - }, + def define_schema(cls) -> comfy_io.Schema: + return comfy_io.Schema( + node_id="FluxProImageNode", + display_name="Flux 1.1 [pro] Image", + category="api node/image/BFL", + description=cleandoc(cls.__doc__ or ""), + inputs=[ + comfy_io.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Prompt for the image generation", ), - "prompt_upsampling": ( - IO.BOOLEAN, - { - "default": False, - "tooltip": "Whether to perform upsampling on the prompt. 
If active, automatically modifies the prompt for more creative generation, but results are nondeterministic (same seed will not produce exactly the same result).", - }, + comfy_io.Boolean.Input( + "prompt_upsampling", + default=False, + tooltip="Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation, but results are nondeterministic (same seed will not produce exactly the same result).", ), - "width": ( - IO.INT, - { - "default": 1024, - "min": 256, - "max": 1440, - "step": 32, - }, + comfy_io.Int.Input( + "width", + default=1024, + min=256, + max=1440, + step=32, ), - "height": ( - IO.INT, - { - "default": 768, - "min": 256, - "max": 1440, - "step": 32, - }, + comfy_io.Int.Input( + "height", + default=768, + min=256, + max=1440, + step=32, ), - "seed": ( - IO.INT, - { - "default": 0, - "min": 0, - "max": 0xFFFFFFFFFFFFFFFF, - "control_after_generate": True, - "tooltip": "The random seed used for creating the noise.", - }, + comfy_io.Int.Input( + "seed", + default=0, + min=0, + max=0xFFFFFFFFFFFFFFFF, + control_after_generate=True, + tooltip="The random seed used for creating the noise.", + ), + comfy_io.Image.Input( + "image_prompt", + optional=True, ), - }, - "optional": { - "image_prompt": (IO.IMAGE,), # "image_prompt_strength": ( # IO.FLOAT, # { @@ -477,22 +460,19 @@ class FluxProImageNode(ComfyNodeABC): # "tooltip": "Blend between the prompt and the image prompt.", # }, # ), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } + ], + outputs=[comfy_io.Image.Output()], + hidden=[ + comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + comfy_io.Hidden.unique_id, + ], + is_api_node=True, + ) - RETURN_TYPES = (IO.IMAGE,) - DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value - FUNCTION = "api_call" - API_NODE = True - CATEGORY = "api node/image/BFL" - - async def api_call( - self, + @classmethod + async def execute( + cls, prompt: str, prompt_upsampling, width: int, @@ -500,9 +480,7 @@ class FluxProImageNode(ComfyNodeABC): seed=0, image_prompt=None, # image_prompt_strength=0.1, - unique_id: Union[str, None] = None, - **kwargs, - ): + ) -> comfy_io.NodeOutput: image_prompt = ( image_prompt if image_prompt is None @@ -524,118 +502,103 @@ class FluxProImageNode(ComfyNodeABC): seed=seed, image_prompt=image_prompt, ), - auth_kwargs=kwargs, + auth_kwargs={ + "auth_token": cls.hidden.auth_token_comfy_org, + "comfy_api_key": cls.hidden.api_key_comfy_org, + }, ) - output_image = await handle_bfl_synchronous_operation(operation, node_id=unique_id) - return (output_image,) + output_image = await handle_bfl_synchronous_operation(operation, node_id=cls.hidden.unique_id) + return comfy_io.NodeOutput(output_image) -class FluxProExpandNode(ComfyNodeABC): +class FluxProExpandNode(comfy_io.ComfyNode): """ Outpaints image based on prompt. 
""" @classmethod - def INPUT_TYPES(s): - return { - "required": { - "image": (IO.IMAGE,), - "prompt": ( - IO.STRING, - { - "multiline": True, - "default": "", - "tooltip": "Prompt for the image generation", - }, + def define_schema(cls) -> comfy_io.Schema: + return comfy_io.Schema( + node_id="FluxProExpandNode", + display_name="Flux.1 Expand Image", + category="api node/image/BFL", + description=cleandoc(cls.__doc__ or ""), + inputs=[ + comfy_io.Image.Input("image"), + comfy_io.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Prompt for the image generation", ), - "prompt_upsampling": ( - IO.BOOLEAN, - { - "default": False, - "tooltip": "Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation, but results are nondeterministic (same seed will not produce exactly the same result).", - }, + comfy_io.Boolean.Input( + "prompt_upsampling", + default=False, + tooltip="Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation, but results are nondeterministic (same seed will not produce exactly the same result).", ), - "top": ( - IO.INT, - { - "default": 0, - "min": 0, - "max": 2048, - "tooltip": "Number of pixels to expand at the top of the image" - }, + comfy_io.Int.Input( + "top", + default=0, + min=0, + max=2048, + tooltip="Number of pixels to expand at the top of the image", ), - "bottom": ( - IO.INT, - { - "default": 0, - "min": 0, - "max": 2048, - "tooltip": "Number of pixels to expand at the bottom of the image" - }, + comfy_io.Int.Input( + "bottom", + default=0, + min=0, + max=2048, + tooltip="Number of pixels to expand at the bottom of the image", ), - "left": ( - IO.INT, - { - "default": 0, - "min": 0, - "max": 2048, - "tooltip": "Number of pixels to expand at the left side of the image" - }, + comfy_io.Int.Input( + "left", + default=0, + min=0, + max=2048, + tooltip="Number of pixels to expand at the left of the image", ), - "right": ( - IO.INT, - { - "default": 0, - "min": 0, - "max": 2048, - "tooltip": "Number of pixels to expand at the right side of the image" - }, + comfy_io.Int.Input( + "right", + default=0, + min=0, + max=2048, + tooltip="Number of pixels to expand at the right of the image", ), - "guidance": ( - IO.FLOAT, - { - "default": 60, - "min": 1.5, - "max": 100, - "tooltip": "Guidance strength for the image generation process" - }, + comfy_io.Float.Input( + "guidance", + default=60, + min=1.5, + max=100, + tooltip="Guidance strength for the image generation process", ), - "steps": ( - IO.INT, - { - "default": 50, - "min": 15, - "max": 50, - "tooltip": "Number of steps for the image generation process" - }, + comfy_io.Int.Input( + "steps", + default=50, + min=15, + max=50, + tooltip="Number of steps for the image generation process", ), - "seed": ( - IO.INT, - { - "default": 0, - "min": 0, - "max": 0xFFFFFFFFFFFFFFFF, - "control_after_generate": True, - "tooltip": "The random seed used for creating the noise.", - }, + comfy_io.Int.Input( + "seed", + default=0, + min=0, + max=0xFFFFFFFFFFFFFFFF, + control_after_generate=True, + tooltip="The random seed used for creating the noise.", ), - }, - "optional": {}, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } + ], + outputs=[comfy_io.Image.Output()], + hidden=[ + comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + comfy_io.Hidden.unique_id, + ], + is_api_node=True, + ) - 
RETURN_TYPES = (IO.IMAGE,) - DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value - FUNCTION = "api_call" - API_NODE = True - CATEGORY = "api node/image/BFL" - - async def api_call( - self, + @classmethod + async def execute( + cls, image: torch.Tensor, prompt: str, prompt_upsampling: bool, @@ -646,9 +609,7 @@ class FluxProExpandNode(ComfyNodeABC): steps: int, guidance: float, seed=0, - unique_id: Union[str, None] = None, - **kwargs, - ): + ) -> comfy_io.NodeOutput: image = convert_image_to_base64(image) operation = SynchronousOperation( @@ -670,84 +631,77 @@ class FluxProExpandNode(ComfyNodeABC): seed=seed, image=image, ), - auth_kwargs=kwargs, + auth_kwargs={ + "auth_token": cls.hidden.auth_token_comfy_org, + "comfy_api_key": cls.hidden.api_key_comfy_org, + }, ) - output_image = await handle_bfl_synchronous_operation(operation, node_id=unique_id) - return (output_image,) + output_image = await handle_bfl_synchronous_operation(operation, node_id=cls.hidden.unique_id) + return comfy_io.NodeOutput(output_image) -class FluxProFillNode(ComfyNodeABC): +class FluxProFillNode(comfy_io.ComfyNode): """ Inpaints image based on mask and prompt. """ @classmethod - def INPUT_TYPES(s): - return { - "required": { - "image": (IO.IMAGE,), - "mask": (IO.MASK,), - "prompt": ( - IO.STRING, - { - "multiline": True, - "default": "", - "tooltip": "Prompt for the image generation", - }, + def define_schema(cls) -> comfy_io.Schema: + return comfy_io.Schema( + node_id="FluxProFillNode", + display_name="Flux.1 Fill Image", + category="api node/image/BFL", + description=cleandoc(cls.__doc__ or ""), + inputs=[ + comfy_io.Image.Input("image"), + comfy_io.Mask.Input("mask"), + comfy_io.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Prompt for the image generation", ), - "prompt_upsampling": ( - IO.BOOLEAN, - { - "default": False, - "tooltip": "Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation, but results are nondeterministic (same seed will not produce exactly the same result).", - }, + comfy_io.Boolean.Input( + "prompt_upsampling", + default=False, + tooltip="Whether to perform upsampling on the prompt. 
If active, automatically modifies the prompt for more creative generation, but results are nondeterministic (same seed will not produce exactly the same result).", ), - "guidance": ( - IO.FLOAT, - { - "default": 60, - "min": 1.5, - "max": 100, - "tooltip": "Guidance strength for the image generation process" - }, + comfy_io.Float.Input( + "guidance", + default=60, + min=1.5, + max=100, + tooltip="Guidance strength for the image generation process", ), - "steps": ( - IO.INT, - { - "default": 50, - "min": 15, - "max": 50, - "tooltip": "Number of steps for the image generation process" - }, + comfy_io.Int.Input( + "steps", + default=50, + min=15, + max=50, + tooltip="Number of steps for the image generation process", ), - "seed": ( - IO.INT, - { - "default": 0, - "min": 0, - "max": 0xFFFFFFFFFFFFFFFF, - "control_after_generate": True, - "tooltip": "The random seed used for creating the noise.", - }, + comfy_io.Int.Input( + "seed", + default=0, + min=0, + max=0xFFFFFFFFFFFFFFFF, + control_after_generate=True, + tooltip="The random seed used for creating the noise.", ), - }, - "optional": {}, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } + ], + outputs=[comfy_io.Image.Output()], + hidden=[ + comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + comfy_io.Hidden.unique_id, + ], + is_api_node=True, + ) - RETURN_TYPES = (IO.IMAGE,) - DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value - FUNCTION = "api_call" - API_NODE = True - CATEGORY = "api node/image/BFL" - - async def api_call( - self, + @classmethod + async def execute( + cls, image: torch.Tensor, mask: torch.Tensor, prompt: str, @@ -755,9 +709,7 @@ class FluxProFillNode(ComfyNodeABC): steps: int, guidance: float, seed=0, - unique_id: Union[str, None] = None, - **kwargs, - ): + ) -> comfy_io.NodeOutput: # prepare mask mask = resize_mask_to_image(mask, image) mask = convert_image_to_base64(convert_mask_to_image(mask)) @@ -780,109 +732,96 @@ class FluxProFillNode(ComfyNodeABC): image=image, mask=mask, ), - auth_kwargs=kwargs, + auth_kwargs={ + "auth_token": cls.hidden.auth_token_comfy_org, + "comfy_api_key": cls.hidden.api_key_comfy_org, + }, ) - output_image = await handle_bfl_synchronous_operation(operation, node_id=unique_id) - return (output_image,) + output_image = await handle_bfl_synchronous_operation(operation, node_id=cls.hidden.unique_id) + return comfy_io.NodeOutput(output_image) -class FluxProCannyNode(ComfyNodeABC): +class FluxProCannyNode(comfy_io.ComfyNode): """ Generate image using a control image (canny). """ @classmethod - def INPUT_TYPES(s): - return { - "required": { - "control_image": (IO.IMAGE,), - "prompt": ( - IO.STRING, - { - "multiline": True, - "default": "", - "tooltip": "Prompt for the image generation", - }, + def define_schema(cls) -> comfy_io.Schema: + return comfy_io.Schema( + node_id="FluxProCannyNode", + display_name="Flux.1 Canny Control Image", + category="api node/image/BFL", + description=cleandoc(cls.__doc__ or ""), + inputs=[ + comfy_io.Image.Input("control_image"), + comfy_io.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Prompt for the image generation", ), - "prompt_upsampling": ( - IO.BOOLEAN, - { - "default": False, - "tooltip": "Whether to perform upsampling on the prompt. 
If active, automatically modifies the prompt for more creative generation, but results are nondeterministic (same seed will not produce exactly the same result).", - }, + comfy_io.Boolean.Input( + "prompt_upsampling", + default=False, + tooltip="Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation, but results are nondeterministic (same seed will not produce exactly the same result).", ), - "canny_low_threshold": ( - IO.FLOAT, - { - "default": 0.1, - "min": 0.01, - "max": 0.99, - "step": 0.01, - "tooltip": "Low threshold for Canny edge detection; ignored if skip_processing is True" - }, + comfy_io.Float.Input( + "canny_low_threshold", + default=0.1, + min=0.01, + max=0.99, + step=0.01, + tooltip="Low threshold for Canny edge detection; ignored if skip_processing is True", ), - "canny_high_threshold": ( - IO.FLOAT, - { - "default": 0.4, - "min": 0.01, - "max": 0.99, - "step": 0.01, - "tooltip": "High threshold for Canny edge detection; ignored if skip_processing is True" - }, + comfy_io.Float.Input( + "canny_high_threshold", + default=0.4, + min=0.01, + max=0.99, + step=0.01, + tooltip="High threshold for Canny edge detection; ignored if skip_processing is True", ), - "skip_preprocessing": ( - IO.BOOLEAN, - { - "default": False, - "tooltip": "Whether to skip preprocessing; set to True if control_image already is canny-fied, False if it is a raw image.", - }, + comfy_io.Boolean.Input( + "skip_preprocessing", + default=False, + tooltip="Whether to skip preprocessing; set to True if control_image already is canny-fied, False if it is a raw image.", ), - "guidance": ( - IO.FLOAT, - { - "default": 30, - "min": 1, - "max": 100, - "tooltip": "Guidance strength for the image generation process" - }, + comfy_io.Float.Input( + "guidance", + default=30, + min=1, + max=100, + tooltip="Guidance strength for the image generation process", ), - "steps": ( - IO.INT, - { - "default": 50, - "min": 15, - "max": 50, - "tooltip": "Number of steps for the image generation process" - }, + comfy_io.Int.Input( + "steps", + default=50, + min=15, + max=50, + tooltip="Number of steps for the image generation process", ), - "seed": ( - IO.INT, - { - "default": 0, - "min": 0, - "max": 0xFFFFFFFFFFFFFFFF, - "control_after_generate": True, - "tooltip": "The random seed used for creating the noise.", - }, + comfy_io.Int.Input( + "seed", + default=0, + min=0, + max=0xFFFFFFFFFFFFFFFF, + control_after_generate=True, + tooltip="The random seed used for creating the noise.", ), - }, - "optional": {}, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } + ], + outputs=[comfy_io.Image.Output()], + hidden=[ + comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + comfy_io.Hidden.unique_id, + ], + is_api_node=True, + ) - RETURN_TYPES = (IO.IMAGE,) - DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value - FUNCTION = "api_call" - API_NODE = True - CATEGORY = "api node/image/BFL" - - async def api_call( - self, + @classmethod + async def execute( + cls, control_image: torch.Tensor, prompt: str, prompt_upsampling: bool, @@ -892,9 +831,7 @@ class FluxProCannyNode(ComfyNodeABC): steps: int, guidance: float, seed=0, - unique_id: Union[str, None] = None, - **kwargs, - ): + ) -> comfy_io.NodeOutput: control_image = convert_image_to_base64(control_image[:, :, :, :3]) preprocessed_image = None @@ -929,89 +866,80 @@ class FluxProCannyNode(ComfyNodeABC): 
canny_high_threshold=canny_high_threshold, preprocessed_image=preprocessed_image, ), - auth_kwargs=kwargs, + auth_kwargs={ + "auth_token": cls.hidden.auth_token_comfy_org, + "comfy_api_key": cls.hidden.api_key_comfy_org, + }, ) - output_image = await handle_bfl_synchronous_operation(operation, node_id=unique_id) - return (output_image,) + output_image = await handle_bfl_synchronous_operation(operation, node_id=cls.hidden.unique_id) + return comfy_io.NodeOutput(output_image) -class FluxProDepthNode(ComfyNodeABC): +class FluxProDepthNode(comfy_io.ComfyNode): """ Generate image using a control image (depth). """ @classmethod - def INPUT_TYPES(s): - return { - "required": { - "control_image": (IO.IMAGE,), - "prompt": ( - IO.STRING, - { - "multiline": True, - "default": "", - "tooltip": "Prompt for the image generation", - }, + def define_schema(cls) -> comfy_io.Schema: + return comfy_io.Schema( + node_id="FluxProDepthNode", + display_name="Flux.1 Depth Control Image", + category="api node/image/BFL", + description=cleandoc(cls.__doc__ or ""), + inputs=[ + comfy_io.Image.Input("control_image"), + comfy_io.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Prompt for the image generation", ), - "prompt_upsampling": ( - IO.BOOLEAN, - { - "default": False, - "tooltip": "Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation, but results are nondeterministic (same seed will not produce exactly the same result).", - }, + comfy_io.Boolean.Input( + "prompt_upsampling", + default=False, + tooltip="Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation, but results are nondeterministic (same seed will not produce exactly the same result).", ), - "skip_preprocessing": ( - IO.BOOLEAN, - { - "default": False, - "tooltip": "Whether to skip preprocessing; set to True if control_image already is depth-ified, False if it is a raw image.", - }, + comfy_io.Boolean.Input( + "skip_preprocessing", + default=False, + tooltip="Whether to skip preprocessing; set to True if control_image already is depth-ified, False if it is a raw image.", ), - "guidance": ( - IO.FLOAT, - { - "default": 15, - "min": 1, - "max": 100, - "tooltip": "Guidance strength for the image generation process" - }, + comfy_io.Float.Input( + "guidance", + default=15, + min=1, + max=100, + tooltip="Guidance strength for the image generation process", ), - "steps": ( - IO.INT, - { - "default": 50, - "min": 15, - "max": 50, - "tooltip": "Number of steps for the image generation process" - }, + comfy_io.Int.Input( + "steps", + default=50, + min=15, + max=50, + tooltip="Number of steps for the image generation process", ), - "seed": ( - IO.INT, - { - "default": 0, - "min": 0, - "max": 0xFFFFFFFFFFFFFFFF, - "control_after_generate": True, - "tooltip": "The random seed used for creating the noise.", - }, + comfy_io.Int.Input( + "seed", + default=0, + min=0, + max=0xFFFFFFFFFFFFFFFF, + control_after_generate=True, + tooltip="The random seed used for creating the noise.", ), - }, - "optional": {}, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } + ], + outputs=[comfy_io.Image.Output()], + hidden=[ + comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + comfy_io.Hidden.unique_id, + ], + is_api_node=True, + ) - RETURN_TYPES = (IO.IMAGE,) - DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value - 
FUNCTION = "api_call" - API_NODE = True - CATEGORY = "api node/image/BFL" - - async def api_call( - self, + @classmethod + async def execute( + cls, control_image: torch.Tensor, prompt: str, prompt_upsampling: bool, @@ -1019,9 +947,7 @@ class FluxProDepthNode(ComfyNodeABC): steps: int, guidance: float, seed=0, - unique_id: Union[str, None] = None, - **kwargs, - ): + ) -> comfy_io.NodeOutput: control_image = convert_image_to_base64(control_image[:,:,:,:3]) preprocessed_image = None @@ -1045,33 +971,29 @@ class FluxProDepthNode(ComfyNodeABC): control_image=control_image, preprocessed_image=preprocessed_image, ), - auth_kwargs=kwargs, + auth_kwargs={ + "auth_token": cls.hidden.auth_token_comfy_org, + "comfy_api_key": cls.hidden.api_key_comfy_org, + }, ) - output_image = await handle_bfl_synchronous_operation(operation, node_id=unique_id) - return (output_image,) + output_image = await handle_bfl_synchronous_operation(operation, node_id=cls.hidden.unique_id) + return comfy_io.NodeOutput(output_image) -# A dictionary that contains all nodes you want to export with their names -# NOTE: names should be globally unique -NODE_CLASS_MAPPINGS = { - "FluxProUltraImageNode": FluxProUltraImageNode, - # "FluxProImageNode": FluxProImageNode, - "FluxKontextProImageNode": FluxKontextProImageNode, - "FluxKontextMaxImageNode": FluxKontextMaxImageNode, - "FluxProExpandNode": FluxProExpandNode, - "FluxProFillNode": FluxProFillNode, - "FluxProCannyNode": FluxProCannyNode, - "FluxProDepthNode": FluxProDepthNode, -} +class BFLExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[comfy_io.ComfyNode]]: + return [ + FluxProUltraImageNode, + # FluxProImageNode, + FluxKontextProImageNode, + FluxKontextMaxImageNode, + FluxProExpandNode, + FluxProFillNode, + FluxProCannyNode, + FluxProDepthNode, + ] -# A dictionary that contains the friendly/humanly readable titles for the nodes -NODE_DISPLAY_NAME_MAPPINGS = { - "FluxProUltraImageNode": "Flux 1.1 [pro] Ultra Image", - # "FluxProImageNode": "Flux 1.1 [pro] Image", - "FluxKontextProImageNode": "Flux.1 Kontext [pro] Image", - "FluxKontextMaxImageNode": "Flux.1 Kontext [max] Image", - "FluxProExpandNode": "Flux.1 Expand Image", - "FluxProFillNode": "Flux.1 Fill Image", - "FluxProCannyNode": "Flux.1 Canny Control Image", - "FluxProDepthNode": "Flux.1 Depth Control Image", -} + +async def comfy_entrypoint() -> BFLExtension: + return BFLExtension() From bcfd80dd79ccfa77a7da69380795fbb55b65b1ba Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Sat, 27 Sep 2025 12:28:11 +0300 Subject: [PATCH 258/325] convert nodes_luma.py to V3 schema (#10030) --- comfy_api_nodes/nodes_luma.py | 774 +++++++++++++++++----------------- 1 file changed, 396 insertions(+), 378 deletions(-) diff --git a/comfy_api_nodes/nodes_luma.py b/comfy_api_nodes/nodes_luma.py index b3c32bed5..9cd02ffd2 100644 --- a/comfy_api_nodes/nodes_luma.py +++ b/comfy_api_nodes/nodes_luma.py @@ -1,7 +1,8 @@ from __future__ import annotations from inspect import cleandoc from typing import Optional -from comfy.comfy_types.node_typing import IO, ComfyNodeABC +from typing_extensions import override +from comfy_api.latest import ComfyExtension, io as comfy_io from comfy_api.input_impl.video_types import VideoFromFile from comfy_api_nodes.apis.luma_api import ( LumaImageModel, @@ -51,174 +52,186 @@ def image_result_url_extractor(response: LumaGeneration): def video_result_url_extractor(response: LumaGeneration): return response.assets.video if 
hasattr(response, "assets") and hasattr(response.assets, "video") else None -class LumaReferenceNode(ComfyNodeABC): +class LumaReferenceNode(comfy_io.ComfyNode): """ Holds an image and weight for use with Luma Generate Image node. """ - RETURN_TYPES = (LumaIO.LUMA_REF,) - RETURN_NAMES = ("luma_ref",) - DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value - FUNCTION = "create_luma_reference" - CATEGORY = "api node/image/Luma" + @classmethod + def define_schema(cls) -> comfy_io.Schema: + return comfy_io.Schema( + node_id="LumaReferenceNode", + display_name="Luma Reference", + category="api node/image/Luma", + description=cleandoc(cls.__doc__ or ""), + inputs=[ + comfy_io.Image.Input( + "image", + tooltip="Image to use as reference.", + ), + comfy_io.Float.Input( + "weight", + default=1.0, + min=0.0, + max=1.0, + step=0.01, + tooltip="Weight of image reference.", + ), + comfy_io.Custom(LumaIO.LUMA_REF).Input( + "luma_ref", + optional=True, + ), + ], + outputs=[comfy_io.Custom(LumaIO.LUMA_REF).Output(display_name="luma_ref")], + hidden=[ + comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + comfy_io.Hidden.unique_id, + ], + ) @classmethod - def INPUT_TYPES(s): - return { - "required": { - "image": ( - IO.IMAGE, - { - "tooltip": "Image to use as reference.", - }, - ), - "weight": ( - IO.FLOAT, - { - "default": 1.0, - "min": 0.0, - "max": 1.0, - "step": 0.01, - "tooltip": "Weight of image reference.", - }, - ), - }, - "optional": {"luma_ref": (LumaIO.LUMA_REF,)}, - } - - def create_luma_reference( - self, image: torch.Tensor, weight: float, luma_ref: LumaReferenceChain = None - ): + def execute( + cls, image: torch.Tensor, weight: float, luma_ref: LumaReferenceChain = None + ) -> comfy_io.NodeOutput: if luma_ref is not None: luma_ref = luma_ref.clone() else: luma_ref = LumaReferenceChain() luma_ref.add(LumaReference(image=image, weight=round(weight, 2))) - return (luma_ref,) + return comfy_io.NodeOutput(luma_ref) -class LumaConceptsNode(ComfyNodeABC): +class LumaConceptsNode(comfy_io.ComfyNode): """ Holds one or more Camera Concepts for use with Luma Text to Video and Luma Image to Video nodes. 
""" - RETURN_TYPES = (LumaIO.LUMA_CONCEPTS,) - RETURN_NAMES = ("luma_concepts",) - DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value - FUNCTION = "create_concepts" - CATEGORY = "api node/video/Luma" + @classmethod + def define_schema(cls) -> comfy_io.Schema: + return comfy_io.Schema( + node_id="LumaConceptsNode", + display_name="Luma Concepts", + category="api node/video/Luma", + description=cleandoc(cls.__doc__ or ""), + inputs=[ + comfy_io.Combo.Input( + "concept1", + options=get_luma_concepts(include_none=True), + ), + comfy_io.Combo.Input( + "concept2", + options=get_luma_concepts(include_none=True), + ), + comfy_io.Combo.Input( + "concept3", + options=get_luma_concepts(include_none=True), + ), + comfy_io.Combo.Input( + "concept4", + options=get_luma_concepts(include_none=True), + ), + comfy_io.Custom(LumaIO.LUMA_CONCEPTS).Input( + "luma_concepts", + tooltip="Optional Camera Concepts to add to the ones chosen here.", + optional=True, + ), + ], + outputs=[comfy_io.Custom(LumaIO.LUMA_CONCEPTS).Output(display_name="luma_concepts")], + hidden=[ + comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + comfy_io.Hidden.unique_id, + ], + ) @classmethod - def INPUT_TYPES(s): - return { - "required": { - "concept1": (get_luma_concepts(include_none=True),), - "concept2": (get_luma_concepts(include_none=True),), - "concept3": (get_luma_concepts(include_none=True),), - "concept4": (get_luma_concepts(include_none=True),), - }, - "optional": { - "luma_concepts": ( - LumaIO.LUMA_CONCEPTS, - { - "tooltip": "Optional Camera Concepts to add to the ones chosen here." - }, - ), - }, - } - - def create_concepts( - self, + def execute( + cls, concept1: str, concept2: str, concept3: str, concept4: str, luma_concepts: LumaConceptChain = None, - ): + ) -> comfy_io.NodeOutput: chain = LumaConceptChain(str_list=[concept1, concept2, concept3, concept4]) if luma_concepts is not None: chain = luma_concepts.clone_and_merge(chain) - return (chain,) + return comfy_io.NodeOutput(chain) -class LumaImageGenerationNode(ComfyNodeABC): +class LumaImageGenerationNode(comfy_io.ComfyNode): """ Generates images synchronously based on prompt and aspect ratio. """ - RETURN_TYPES = (IO.IMAGE,) - DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value - FUNCTION = "api_call" - API_NODE = True - CATEGORY = "api node/image/Luma" + @classmethod + def define_schema(cls) -> comfy_io.Schema: + return comfy_io.Schema( + node_id="LumaImageNode", + display_name="Luma Text to Image", + category="api node/image/Luma", + description=cleandoc(cls.__doc__ or ""), + inputs=[ + comfy_io.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Prompt for the image generation", + ), + comfy_io.Combo.Input( + "model", + options=[model.value for model in LumaImageModel], + ), + comfy_io.Combo.Input( + "aspect_ratio", + options=[ratio.value for ratio in LumaAspectRatio], + default=LumaAspectRatio.ratio_16_9, + ), + comfy_io.Int.Input( + "seed", + default=0, + min=0, + max=0xFFFFFFFFFFFFFFFF, + control_after_generate=True, + tooltip="Seed to determine if node should re-run; actual results are nondeterministic regardless of seed.", + ), + comfy_io.Float.Input( + "style_image_weight", + default=1.0, + min=0.0, + max=1.0, + step=0.01, + tooltip="Weight of style image. 
Ignored if no style_image provided.", + ), + comfy_io.Custom(LumaIO.LUMA_REF).Input( + "image_luma_ref", + tooltip="Luma Reference node connection to influence generation with input images; up to 4 images can be considered.", + optional=True, + ), + comfy_io.Image.Input( + "style_image", + tooltip="Style reference image; only 1 image will be used.", + optional=True, + ), + comfy_io.Image.Input( + "character_image", + tooltip="Character reference images; can be a batch of multiple, up to 4 images can be considered.", + optional=True, + ), + ], + outputs=[comfy_io.Image.Output()], + hidden=[ + comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + comfy_io.Hidden.unique_id, + ], + is_api_node=True, + ) @classmethod - def INPUT_TYPES(s): - return { - "required": { - "prompt": ( - IO.STRING, - { - "multiline": True, - "default": "", - "tooltip": "Prompt for the image generation", - }, - ), - "model": ([model.value for model in LumaImageModel],), - "aspect_ratio": ( - [ratio.value for ratio in LumaAspectRatio], - { - "default": LumaAspectRatio.ratio_16_9, - }, - ), - "seed": ( - IO.INT, - { - "default": 0, - "min": 0, - "max": 0xFFFFFFFFFFFFFFFF, - "control_after_generate": True, - "tooltip": "Seed to determine if node should re-run; actual results are nondeterministic regardless of seed.", - }, - ), - "style_image_weight": ( - IO.FLOAT, - { - "default": 1.0, - "min": 0.0, - "max": 1.0, - "step": 0.01, - "tooltip": "Weight of style image. Ignored if no style_image provided.", - }, - ), - }, - "optional": { - "image_luma_ref": ( - LumaIO.LUMA_REF, - { - "tooltip": "Luma Reference node connection to influence generation with input images; up to 4 images can be considered." - }, - ), - "style_image": ( - IO.IMAGE, - {"tooltip": "Style reference image; only 1 image will be used."}, - ), - "character_image": ( - IO.IMAGE, - { - "tooltip": "Character reference images; can be a batch of multiple, up to 4 images can be considered." 
- }, - ), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } - - async def api_call( - self, + async def execute( + cls, prompt: str, model: str, aspect_ratio: str, @@ -227,27 +240,29 @@ class LumaImageGenerationNode(ComfyNodeABC): image_luma_ref: LumaReferenceChain = None, style_image: torch.Tensor = None, character_image: torch.Tensor = None, - unique_id: str = None, - **kwargs, - ): + ) -> comfy_io.NodeOutput: validate_string(prompt, strip_whitespace=True, min_length=3) + auth_kwargs = { + "auth_token": cls.hidden.auth_token_comfy_org, + "comfy_api_key": cls.hidden.api_key_comfy_org, + } # handle image_luma_ref api_image_ref = None if image_luma_ref is not None: - api_image_ref = await self._convert_luma_refs( - image_luma_ref, max_refs=4, auth_kwargs=kwargs, + api_image_ref = await cls._convert_luma_refs( + image_luma_ref, max_refs=4, auth_kwargs=auth_kwargs, ) # handle style_luma_ref api_style_ref = None if style_image is not None: - api_style_ref = await self._convert_style_image( - style_image, weight=style_image_weight, auth_kwargs=kwargs, + api_style_ref = await cls._convert_style_image( + style_image, weight=style_image_weight, auth_kwargs=auth_kwargs, ) # handle character_ref images character_ref = None if character_image is not None: download_urls = await upload_images_to_comfyapi( - character_image, max_images=4, auth_kwargs=kwargs, + character_image, max_images=4, auth_kwargs=auth_kwargs, ) character_ref = LumaCharacterRef( identity0=LumaImageIdentity(images=download_urls) @@ -268,7 +283,7 @@ class LumaImageGenerationNode(ComfyNodeABC): style_ref=api_style_ref, character_ref=character_ref, ), - auth_kwargs=kwargs, + auth_kwargs=auth_kwargs, ) response_api: LumaGeneration = await operation.execute() @@ -283,18 +298,19 @@ class LumaImageGenerationNode(ComfyNodeABC): failed_statuses=[LumaState.failed], status_extractor=lambda x: x.state, result_url_extractor=image_result_url_extractor, - node_id=unique_id, - auth_kwargs=kwargs, + node_id=cls.hidden.unique_id, + auth_kwargs=auth_kwargs, ) response_poll = await operation.execute() async with aiohttp.ClientSession() as session: async with session.get(response_poll.assets.image) as img_response: img = process_image_response(await img_response.content.read()) - return (img,) + return comfy_io.NodeOutput(img) + @classmethod async def _convert_luma_refs( - self, luma_ref: LumaReferenceChain, max_refs: int, auth_kwargs: Optional[dict[str,str]] = None + cls, luma_ref: LumaReferenceChain, max_refs: int, auth_kwargs: Optional[dict[str,str]] = None ): luma_urls = [] ref_count = 0 @@ -308,82 +324,84 @@ class LumaImageGenerationNode(ComfyNodeABC): break return luma_ref.create_api_model(download_urls=luma_urls, max_refs=max_refs) + @classmethod async def _convert_style_image( - self, style_image: torch.Tensor, weight: float, auth_kwargs: Optional[dict[str,str]] = None + cls, style_image: torch.Tensor, weight: float, auth_kwargs: Optional[dict[str,str]] = None ): chain = LumaReferenceChain( first_ref=LumaReference(image=style_image, weight=weight) ) - return await self._convert_luma_refs(chain, max_refs=1, auth_kwargs=auth_kwargs) + return await cls._convert_luma_refs(chain, max_refs=1, auth_kwargs=auth_kwargs) -class LumaImageModifyNode(ComfyNodeABC): +class LumaImageModifyNode(comfy_io.ComfyNode): """ Modifies images synchronously based on prompt and aspect ratio. 
""" - RETURN_TYPES = (IO.IMAGE,) - DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value - FUNCTION = "api_call" - API_NODE = True - CATEGORY = "api node/image/Luma" + @classmethod + def define_schema(cls) -> comfy_io.Schema: + return comfy_io.Schema( + node_id="LumaImageModifyNode", + display_name="Luma Image to Image", + category="api node/image/Luma", + description=cleandoc(cls.__doc__ or ""), + inputs=[ + comfy_io.Image.Input( + "image", + ), + comfy_io.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Prompt for the image generation", + ), + comfy_io.Float.Input( + "image_weight", + default=0.1, + min=0.0, + max=0.98, + step=0.01, + tooltip="Weight of the image; the closer to 1.0, the less the image will be modified.", + ), + comfy_io.Combo.Input( + "model", + options=[model.value for model in LumaImageModel], + ), + comfy_io.Int.Input( + "seed", + default=0, + min=0, + max=0xFFFFFFFFFFFFFFFF, + control_after_generate=True, + tooltip="Seed to determine if node should re-run; actual results are nondeterministic regardless of seed.", + ), + ], + outputs=[comfy_io.Image.Output()], + hidden=[ + comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + comfy_io.Hidden.unique_id, + ], + is_api_node=True, + ) @classmethod - def INPUT_TYPES(s): - return { - "required": { - "image": (IO.IMAGE,), - "prompt": ( - IO.STRING, - { - "multiline": True, - "default": "", - "tooltip": "Prompt for the image generation", - }, - ), - "image_weight": ( - IO.FLOAT, - { - "default": 0.1, - "min": 0.0, - "max": 0.98, - "step": 0.01, - "tooltip": "Weight of the image; the closer to 1.0, the less the image will be modified.", - }, - ), - "model": ([model.value for model in LumaImageModel],), - "seed": ( - IO.INT, - { - "default": 0, - "min": 0, - "max": 0xFFFFFFFFFFFFFFFF, - "control_after_generate": True, - "tooltip": "Seed to determine if node should re-run; actual results are nondeterministic regardless of seed.", - }, - ), - }, - "optional": {}, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } - - async def api_call( - self, + async def execute( + cls, prompt: str, model: str, image: torch.Tensor, image_weight: float, seed, - unique_id: str = None, - **kwargs, - ): + ) -> comfy_io.NodeOutput: + auth_kwargs = { + "auth_token": cls.hidden.auth_token_comfy_org, + "comfy_api_key": cls.hidden.api_key_comfy_org, + } # first, upload image download_urls = await upload_images_to_comfyapi( - image, max_images=1, auth_kwargs=kwargs, + image, max_images=1, auth_kwargs=auth_kwargs, ) image_url = download_urls[0] # next, make Luma call with download url provided @@ -401,7 +419,7 @@ class LumaImageModifyNode(ComfyNodeABC): url=image_url, weight=round(max(min(1.0-image_weight, 0.98), 0.0), 2) ), ), - auth_kwargs=kwargs, + auth_kwargs=auth_kwargs, ) response_api: LumaGeneration = await operation.execute() @@ -416,88 +434,84 @@ class LumaImageModifyNode(ComfyNodeABC): failed_statuses=[LumaState.failed], status_extractor=lambda x: x.state, result_url_extractor=image_result_url_extractor, - node_id=unique_id, - auth_kwargs=kwargs, + node_id=cls.hidden.unique_id, + auth_kwargs=auth_kwargs, ) response_poll = await operation.execute() async with aiohttp.ClientSession() as session: async with session.get(response_poll.assets.image) as img_response: img = process_image_response(await img_response.content.read()) - return (img,) + return comfy_io.NodeOutput(img) -class 
LumaTextToVideoGenerationNode(ComfyNodeABC): +class LumaTextToVideoGenerationNode(comfy_io.ComfyNode): """ Generates videos synchronously based on prompt and output_size. """ - RETURN_TYPES = (IO.VIDEO,) - DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value - FUNCTION = "api_call" - API_NODE = True - CATEGORY = "api node/video/Luma" + @classmethod + def define_schema(cls) -> comfy_io.Schema: + return comfy_io.Schema( + node_id="LumaVideoNode", + display_name="Luma Text to Video", + category="api node/video/Luma", + description=cleandoc(cls.__doc__ or ""), + inputs=[ + comfy_io.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Prompt for the video generation", + ), + comfy_io.Combo.Input( + "model", + options=[model.value for model in LumaVideoModel], + ), + comfy_io.Combo.Input( + "aspect_ratio", + options=[ratio.value for ratio in LumaAspectRatio], + default=LumaAspectRatio.ratio_16_9, + ), + comfy_io.Combo.Input( + "resolution", + options=[resolution.value for resolution in LumaVideoOutputResolution], + default=LumaVideoOutputResolution.res_540p, + ), + comfy_io.Combo.Input( + "duration", + options=[dur.value for dur in LumaVideoModelOutputDuration], + ), + comfy_io.Boolean.Input( + "loop", + default=False, + ), + comfy_io.Int.Input( + "seed", + default=0, + min=0, + max=0xFFFFFFFFFFFFFFFF, + control_after_generate=True, + tooltip="Seed to determine if node should re-run; actual results are nondeterministic regardless of seed.", + ), + comfy_io.Custom(LumaIO.LUMA_CONCEPTS).Input( + "luma_concepts", + tooltip="Optional Camera Concepts to dictate camera motion via the Luma Concepts node.", + optional=True, + ) + ], + outputs=[comfy_io.Video.Output()], + hidden=[ + comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + comfy_io.Hidden.unique_id, + ], + is_api_node=True, + ) @classmethod - def INPUT_TYPES(s): - return { - "required": { - "prompt": ( - IO.STRING, - { - "multiline": True, - "default": "", - "tooltip": "Prompt for the video generation", - }, - ), - "model": ([model.value for model in LumaVideoModel],), - "aspect_ratio": ( - [ratio.value for ratio in LumaAspectRatio], - { - "default": LumaAspectRatio.ratio_16_9, - }, - ), - "resolution": ( - [resolution.value for resolution in LumaVideoOutputResolution], - { - "default": LumaVideoOutputResolution.res_540p, - }, - ), - "duration": ([dur.value for dur in LumaVideoModelOutputDuration],), - "loop": ( - IO.BOOLEAN, - { - "default": False, - }, - ), - "seed": ( - IO.INT, - { - "default": 0, - "min": 0, - "max": 0xFFFFFFFFFFFFFFFF, - "control_after_generate": True, - "tooltip": "Seed to determine if node should re-run; actual results are nondeterministic regardless of seed.", - }, - ), - }, - "optional": { - "luma_concepts": ( - LumaIO.LUMA_CONCEPTS, - { - "tooltip": "Optional Camera Concepts to dictate camera motion via the Luma Concepts node." 
- }, - ), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } - - async def api_call( - self, + async def execute( + cls, prompt: str, model: str, aspect_ratio: str, @@ -506,13 +520,15 @@ class LumaTextToVideoGenerationNode(ComfyNodeABC): loop: bool, seed, luma_concepts: LumaConceptChain = None, - unique_id: str = None, - **kwargs, - ): + ) -> comfy_io.NodeOutput: validate_string(prompt, strip_whitespace=False, min_length=3) duration = duration if model != LumaVideoModel.ray_1_6 else None resolution = resolution if model != LumaVideoModel.ray_1_6 else None + auth_kwargs = { + "auth_token": cls.hidden.auth_token_comfy_org, + "comfy_api_key": cls.hidden.api_key_comfy_org, + } operation = SynchronousOperation( endpoint=ApiEndpoint( path="/proxy/luma/generations", @@ -529,12 +545,12 @@ class LumaTextToVideoGenerationNode(ComfyNodeABC): loop=loop, concepts=luma_concepts.create_api_model() if luma_concepts else None, ), - auth_kwargs=kwargs, + auth_kwargs=auth_kwargs, ) response_api: LumaGeneration = await operation.execute() - if unique_id: - PromptServer.instance.send_progress_text(f"Luma video generation started: {response_api.id}", unique_id) + if cls.hidden.unique_id: + PromptServer.instance.send_progress_text(f"Luma video generation started: {response_api.id}", cls.hidden.unique_id) operation = PollingOperation( poll_endpoint=ApiEndpoint( @@ -547,90 +563,94 @@ class LumaTextToVideoGenerationNode(ComfyNodeABC): failed_statuses=[LumaState.failed], status_extractor=lambda x: x.state, result_url_extractor=video_result_url_extractor, - node_id=unique_id, + node_id=cls.hidden.unique_id, estimated_duration=LUMA_T2V_AVERAGE_DURATION, - auth_kwargs=kwargs, + auth_kwargs=auth_kwargs, ) response_poll = await operation.execute() async with aiohttp.ClientSession() as session: async with session.get(response_poll.assets.video) as vid_response: - return (VideoFromFile(BytesIO(await vid_response.content.read())),) + return comfy_io.NodeOutput(VideoFromFile(BytesIO(await vid_response.content.read()))) -class LumaImageToVideoGenerationNode(ComfyNodeABC): +class LumaImageToVideoGenerationNode(comfy_io.ComfyNode): """ Generates videos synchronously based on prompt, input images, and output_size. 
""" - RETURN_TYPES = (IO.VIDEO,) - DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value - FUNCTION = "api_call" - API_NODE = True - CATEGORY = "api node/video/Luma" + @classmethod + def define_schema(cls) -> comfy_io.Schema: + return comfy_io.Schema( + node_id="LumaImageToVideoNode", + display_name="Luma Image to Video", + category="api node/video/Luma", + description=cleandoc(cls.__doc__ or ""), + inputs=[ + comfy_io.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Prompt for the video generation", + ), + comfy_io.Combo.Input( + "model", + options=[model.value for model in LumaVideoModel], + ), + # comfy_io.Combo.Input( + # "aspect_ratio", + # options=[ratio.value for ratio in LumaAspectRatio], + # default=LumaAspectRatio.ratio_16_9, + # ), + comfy_io.Combo.Input( + "resolution", + options=[resolution.value for resolution in LumaVideoOutputResolution], + default=LumaVideoOutputResolution.res_540p, + ), + comfy_io.Combo.Input( + "duration", + options=[dur.value for dur in LumaVideoModelOutputDuration], + ), + comfy_io.Boolean.Input( + "loop", + default=False, + ), + comfy_io.Int.Input( + "seed", + default=0, + min=0, + max=0xFFFFFFFFFFFFFFFF, + control_after_generate=True, + tooltip="Seed to determine if node should re-run; actual results are nondeterministic regardless of seed.", + ), + comfy_io.Image.Input( + "first_image", + tooltip="First frame of generated video.", + optional=True, + ), + comfy_io.Image.Input( + "last_image", + tooltip="Last frame of generated video.", + optional=True, + ), + comfy_io.Custom(LumaIO.LUMA_CONCEPTS).Input( + "luma_concepts", + tooltip="Optional Camera Concepts to dictate camera motion via the Luma Concepts node.", + optional=True, + ) + ], + outputs=[comfy_io.Video.Output()], + hidden=[ + comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + comfy_io.Hidden.unique_id, + ], + is_api_node=True, + ) @classmethod - def INPUT_TYPES(s): - return { - "required": { - "prompt": ( - IO.STRING, - { - "multiline": True, - "default": "", - "tooltip": "Prompt for the video generation", - }, - ), - "model": ([model.value for model in LumaVideoModel],), - # "aspect_ratio": ([ratio.value for ratio in LumaAspectRatio], { - # "default": LumaAspectRatio.ratio_16_9, - # }), - "resolution": ( - [resolution.value for resolution in LumaVideoOutputResolution], - { - "default": LumaVideoOutputResolution.res_540p, - }, - ), - "duration": ([dur.value for dur in LumaVideoModelOutputDuration],), - "loop": ( - IO.BOOLEAN, - { - "default": False, - }, - ), - "seed": ( - IO.INT, - { - "default": 0, - "min": 0, - "max": 0xFFFFFFFFFFFFFFFF, - "control_after_generate": True, - "tooltip": "Seed to determine if node should re-run; actual results are nondeterministic regardless of seed.", - }, - ), - }, - "optional": { - "first_image": ( - IO.IMAGE, - {"tooltip": "First frame of generated video."}, - ), - "last_image": (IO.IMAGE, {"tooltip": "Last frame of generated video."}), - "luma_concepts": ( - LumaIO.LUMA_CONCEPTS, - { - "tooltip": "Optional Camera Concepts to dictate camera motion via the Luma Concepts node." 
- }, - ), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } - - async def api_call( - self, + async def execute( + cls, prompt: str, model: str, resolution: str, @@ -640,14 +660,16 @@ class LumaImageToVideoGenerationNode(ComfyNodeABC): first_image: torch.Tensor = None, last_image: torch.Tensor = None, luma_concepts: LumaConceptChain = None, - unique_id: str = None, - **kwargs, - ): + ) -> comfy_io.NodeOutput: if first_image is None and last_image is None: raise Exception( "At least one of first_image and last_image requires an input." ) - keyframes = await self._convert_to_keyframes(first_image, last_image, auth_kwargs=kwargs) + auth_kwargs = { + "auth_token": cls.hidden.auth_token_comfy_org, + "comfy_api_key": cls.hidden.api_key_comfy_org, + } + keyframes = await cls._convert_to_keyframes(first_image, last_image, auth_kwargs=auth_kwargs) duration = duration if model != LumaVideoModel.ray_1_6 else None resolution = resolution if model != LumaVideoModel.ray_1_6 else None @@ -668,12 +690,12 @@ class LumaImageToVideoGenerationNode(ComfyNodeABC): keyframes=keyframes, concepts=luma_concepts.create_api_model() if luma_concepts else None, ), - auth_kwargs=kwargs, + auth_kwargs=auth_kwargs, ) response_api: LumaGeneration = await operation.execute() - if unique_id: - PromptServer.instance.send_progress_text(f"Luma video generation started: {response_api.id}", unique_id) + if cls.hidden.unique_id: + PromptServer.instance.send_progress_text(f"Luma video generation started: {response_api.id}", cls.hidden.unique_id) operation = PollingOperation( poll_endpoint=ApiEndpoint( @@ -686,18 +708,19 @@ class LumaImageToVideoGenerationNode(ComfyNodeABC): failed_statuses=[LumaState.failed], status_extractor=lambda x: x.state, result_url_extractor=video_result_url_extractor, - node_id=unique_id, + node_id=cls.hidden.unique_id, estimated_duration=LUMA_I2V_AVERAGE_DURATION, - auth_kwargs=kwargs, + auth_kwargs=auth_kwargs, ) response_poll = await operation.execute() async with aiohttp.ClientSession() as session: async with session.get(response_poll.assets.video) as vid_response: - return (VideoFromFile(BytesIO(await vid_response.content.read())),) + return comfy_io.NodeOutput(VideoFromFile(BytesIO(await vid_response.content.read()))) + @classmethod async def _convert_to_keyframes( - self, + cls, first_image: torch.Tensor = None, last_image: torch.Tensor = None, auth_kwargs: Optional[dict[str,str]] = None, @@ -719,23 +742,18 @@ class LumaImageToVideoGenerationNode(ComfyNodeABC): return LumaKeyframes(frame0=frame0, frame1=frame1) -# A dictionary that contains all nodes you want to export with their names -# NOTE: names should be globally unique -NODE_CLASS_MAPPINGS = { - "LumaImageNode": LumaImageGenerationNode, - "LumaImageModifyNode": LumaImageModifyNode, - "LumaVideoNode": LumaTextToVideoGenerationNode, - "LumaImageToVideoNode": LumaImageToVideoGenerationNode, - "LumaReferenceNode": LumaReferenceNode, - "LumaConceptsNode": LumaConceptsNode, -} +class LumaExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[comfy_io.ComfyNode]]: + return [ + LumaImageGenerationNode, + LumaImageModifyNode, + LumaTextToVideoGenerationNode, + LumaImageToVideoGenerationNode, + LumaReferenceNode, + LumaConceptsNode, + ] -# A dictionary that contains the friendly/humanly readable titles for the nodes -NODE_DISPLAY_NAME_MAPPINGS = { - "LumaImageNode": "Luma Text to Image", - "LumaImageModifyNode": "Luma Image to Image", - 
"LumaVideoNode": "Luma Text to Video", - "LumaImageToVideoNode": "Luma Image to Video", - "LumaReferenceNode": "Luma Reference", - "LumaConceptsNode": "Luma Concepts", -} + +async def comfy_entrypoint() -> LumaExtension: + return LumaExtension() From ad5aef2d0c8517e971129db1dfb0d0108d8341a8 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Sat, 27 Sep 2025 12:34:32 +0300 Subject: [PATCH 259/325] convert nodes_pixart.py to V3 schema (#10019) --- comfy_extras/nodes_pixart.py | 52 +++++++++++++++++++++++------------- 1 file changed, 33 insertions(+), 19 deletions(-) diff --git a/comfy_extras/nodes_pixart.py b/comfy_extras/nodes_pixart.py index 8d9276afe..a23e87b1f 100644 --- a/comfy_extras/nodes_pixart.py +++ b/comfy_extras/nodes_pixart.py @@ -1,24 +1,38 @@ -from nodes import MAX_RESOLUTION +from typing_extensions import override +import nodes +from comfy_api.latest import ComfyExtension, io -class CLIPTextEncodePixArtAlpha: +class CLIPTextEncodePixArtAlpha(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { - "width": ("INT", {"default": 1024.0, "min": 0, "max": MAX_RESOLUTION}), - "height": ("INT", {"default": 1024.0, "min": 0, "max": MAX_RESOLUTION}), - # "aspect_ratio": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), - "text": ("STRING", {"multiline": True, "dynamicPrompts": True}), "clip": ("CLIP", ), - }} + def define_schema(cls): + return io.Schema( + node_id="CLIPTextEncodePixArtAlpha", + category="advanced/conditioning", + description="Encodes text and sets the resolution conditioning for PixArt Alpha. Does not apply to PixArt Sigma.", + inputs=[ + io.Int.Input("width", default=1024, min=0, max=nodes.MAX_RESOLUTION), + io.Int.Input("height", default=1024, min=0, max=nodes.MAX_RESOLUTION), + # "aspect_ratio": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), + io.String.Input("text", multiline=True, dynamic_prompts=True), + io.Clip.Input("clip"), + ], + outputs=[ + io.Conditioning.Output(), + ], + ) - RETURN_TYPES = ("CONDITIONING",) - FUNCTION = "encode" - CATEGORY = "advanced/conditioning" - DESCRIPTION = "Encodes text and sets the resolution conditioning for PixArt Alpha. Does not apply to PixArt Sigma." 
- - def encode(self, clip, width, height, text): + @classmethod + def execute(cls, clip, width, height, text): tokens = clip.tokenize(text) - return (clip.encode_from_tokens_scheduled(tokens, add_dict={"width": width, "height": height}),) + return io.NodeOutput(clip.encode_from_tokens_scheduled(tokens, add_dict={"width": width, "height": height})) -NODE_CLASS_MAPPINGS = { - "CLIPTextEncodePixArtAlpha": CLIPTextEncodePixArtAlpha, -} + +class PixArtExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + CLIPTextEncodePixArtAlpha, + ] + +async def comfy_entrypoint() -> PixArtExtension: + return PixArtExtension() From 7eca95657cf7a70c15d598c969b890a164a300a1 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Sat, 27 Sep 2025 12:36:43 +0300 Subject: [PATCH 260/325] convert nodes_photomaker.py to V3 schema (#10017) --- comfy_extras/nodes_photomaker.py | 74 ++++++++++++++++++++------------ 1 file changed, 46 insertions(+), 28 deletions(-) diff --git a/comfy_extras/nodes_photomaker.py b/comfy_extras/nodes_photomaker.py index d358ed6d5..228183c07 100644 --- a/comfy_extras/nodes_photomaker.py +++ b/comfy_extras/nodes_photomaker.py @@ -4,6 +4,8 @@ import folder_paths import comfy.clip_model import comfy.clip_vision import comfy.ops +from typing_extensions import override +from comfy_api.latest import ComfyExtension, io # code for model from: https://github.com/TencentARC/PhotoMaker/blob/main/photomaker/model.py under Apache License Version 2.0 VISION_CONFIG_DICT = { @@ -116,41 +118,52 @@ class PhotoMakerIDEncoder(comfy.clip_model.CLIPVisionModelProjection): return updated_prompt_embeds -class PhotoMakerLoader: +class PhotoMakerLoader(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "photomaker_model_name": (folder_paths.get_filename_list("photomaker"), )}} + def define_schema(cls): + return io.Schema( + node_id="PhotoMakerLoader", + category="_for_testing/photomaker", + inputs=[ + io.Combo.Input("photomaker_model_name", options=folder_paths.get_filename_list("photomaker")), + ], + outputs=[ + io.Photomaker.Output(), + ], + is_experimental=True, + ) - RETURN_TYPES = ("PHOTOMAKER",) - FUNCTION = "load_photomaker_model" - - CATEGORY = "_for_testing/photomaker" - - def load_photomaker_model(self, photomaker_model_name): + @classmethod + def execute(cls, photomaker_model_name): photomaker_model_path = folder_paths.get_full_path_or_raise("photomaker", photomaker_model_name) photomaker_model = PhotoMakerIDEncoder() data = comfy.utils.load_torch_file(photomaker_model_path, safe_load=True) if "id_encoder" in data: data = data["id_encoder"] photomaker_model.load_state_dict(data) - return (photomaker_model,) + return io.NodeOutput(photomaker_model) -class PhotoMakerEncode: +class PhotoMakerEncode(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "photomaker": ("PHOTOMAKER",), - "image": ("IMAGE",), - "clip": ("CLIP", ), - "text": ("STRING", {"multiline": True, "dynamicPrompts": True, "default": "photograph of photomaker"}), - }} + def define_schema(cls): + return io.Schema( + node_id="PhotoMakerEncode", + category="_for_testing/photomaker", + inputs=[ + io.Photomaker.Input("photomaker"), + io.Image.Input("image"), + io.Clip.Input("clip"), + io.String.Input("text", multiline=True, dynamic_prompts=True, default="photograph of photomaker"), + ], + outputs=[ + io.Conditioning.Output(), + ], + is_experimental=True, + ) - RETURN_TYPES = ("CONDITIONING",) - FUNCTION 
= "apply_photomaker" - - CATEGORY = "_for_testing/photomaker" - - def apply_photomaker(self, photomaker, image, clip, text): + @classmethod + def execute(cls, photomaker, image, clip, text): special_token = "photomaker" pixel_values = comfy.clip_vision.clip_preprocess(image.to(photomaker.load_device)).float() try: @@ -178,11 +191,16 @@ class PhotoMakerEncode: else: out = cond - return ([[out, {"pooled_output": pooled}]], ) + return io.NodeOutput([[out, {"pooled_output": pooled}]]) -NODE_CLASS_MAPPINGS = { - "PhotoMakerLoader": PhotoMakerLoader, - "PhotoMakerEncode": PhotoMakerEncode, -} +class PhotomakerExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + PhotoMakerLoader, + PhotoMakerEncode, + ] +async def comfy_entrypoint() -> PhotomakerExtension: + return PhotomakerExtension() From 160698eb418269d64fbbe8c34db27a4d1ddb0540 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Sat, 27 Sep 2025 22:25:35 +0300 Subject: [PATCH 261/325] convert nodes_qwen.py to V3 schema (#10049) --- comfy_extras/nodes_qwen.py | 88 ++++++++++++++++++++++---------------- 1 file changed, 51 insertions(+), 37 deletions(-) diff --git a/comfy_extras/nodes_qwen.py b/comfy_extras/nodes_qwen.py index 49747dc7a..525239ae5 100644 --- a/comfy_extras/nodes_qwen.py +++ b/comfy_extras/nodes_qwen.py @@ -1,24 +1,29 @@ import node_helpers import comfy.utils import math +from typing_extensions import override +from comfy_api.latest import ComfyExtension, io -class TextEncodeQwenImageEdit: +class TextEncodeQwenImageEdit(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { - "clip": ("CLIP", ), - "prompt": ("STRING", {"multiline": True, "dynamicPrompts": True}), - }, - "optional": {"vae": ("VAE", ), - "image": ("IMAGE", ),}} + def define_schema(cls): + return io.Schema( + node_id="TextEncodeQwenImageEdit", + category="advanced/conditioning", + inputs=[ + io.Clip.Input("clip"), + io.String.Input("prompt", multiline=True, dynamic_prompts=True), + io.Vae.Input("vae", optional=True), + io.Image.Input("image", optional=True), + ], + outputs=[ + io.Conditioning.Output(), + ], + ) - RETURN_TYPES = ("CONDITIONING",) - FUNCTION = "encode" - - CATEGORY = "advanced/conditioning" - - def encode(self, clip, prompt, vae=None, image=None): + @classmethod + def execute(cls, clip, prompt, vae=None, image=None) -> io.NodeOutput: ref_latent = None if image is None: images = [] @@ -40,28 +45,30 @@ class TextEncodeQwenImageEdit: conditioning = clip.encode_from_tokens_scheduled(tokens) if ref_latent is not None: conditioning = node_helpers.conditioning_set_values(conditioning, {"reference_latents": [ref_latent]}, append=True) - return (conditioning, ) + return io.NodeOutput(conditioning) -class TextEncodeQwenImageEditPlus: +class TextEncodeQwenImageEditPlus(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { - "clip": ("CLIP", ), - "prompt": ("STRING", {"multiline": True, "dynamicPrompts": True}), - }, - "optional": {"vae": ("VAE", ), - "image1": ("IMAGE", ), - "image2": ("IMAGE", ), - "image3": ("IMAGE", ), - }} + def define_schema(cls): + return io.Schema( + node_id="TextEncodeQwenImageEditPlus", + category="advanced/conditioning", + inputs=[ + io.Clip.Input("clip"), + io.String.Input("prompt", multiline=True, dynamic_prompts=True), + io.Vae.Input("vae", optional=True), + io.Image.Input("image1", optional=True), + io.Image.Input("image2", optional=True), + io.Image.Input("image3", optional=True), + ], + 
outputs=[ + io.Conditioning.Output(), + ], + ) - RETURN_TYPES = ("CONDITIONING",) - FUNCTION = "encode" - - CATEGORY = "advanced/conditioning" - - def encode(self, clip, prompt, vae=None, image1=None, image2=None, image3=None): + @classmethod + def execute(cls, clip, prompt, vae=None, image1=None, image2=None, image3=None) -> io.NodeOutput: ref_latents = [] images = [image1, image2, image3] images_vl = [] @@ -94,10 +101,17 @@ class TextEncodeQwenImageEditPlus: conditioning = clip.encode_from_tokens_scheduled(tokens) if len(ref_latents) > 0: conditioning = node_helpers.conditioning_set_values(conditioning, {"reference_latents": ref_latents}, append=True) - return (conditioning, ) + return io.NodeOutput(conditioning) -NODE_CLASS_MAPPINGS = { - "TextEncodeQwenImageEdit": TextEncodeQwenImageEdit, - "TextEncodeQwenImageEditPlus": TextEncodeQwenImageEditPlus, -} +class QwenExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + TextEncodeQwenImageEdit, + TextEncodeQwenImageEditPlus, + ] + + +async def comfy_entrypoint() -> QwenExtension: + return QwenExtension() From 653ceab4148a9fbc050ebceb674acef760792b77 Mon Sep 17 00:00:00 2001 From: rattus128 <46076784+rattus128@users.noreply.github.com> Date: Sun, 28 Sep 2025 08:14:16 +1000 Subject: [PATCH 262/325] Reduce Peak WAN inference VRAM usage - part II (#10062) * flux: math: Use _addcmul to avoid expensive VRAM intermediate The rope process can be the VRAM peak and this intermediate for the addition result before releasing the original can OOM. addcmul_ it. * wan: Delete the self attention before cross attention This saves VRAM when the cross attention and FFN are in play as the VRAM peak. --- comfy/ldm/flux/math.py | 5 ++++- comfy/ldm/wan/model.py | 1 + 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/comfy/ldm/flux/math.py b/comfy/ldm/flux/math.py index fb7cd7586..8deda0d4a 100644 --- a/comfy/ldm/flux/math.py +++ b/comfy/ldm/flux/math.py @@ -37,7 +37,10 @@ def rope(pos: Tensor, dim: int, theta: int) -> Tensor: def apply_rope1(x: Tensor, freqs_cis: Tensor): x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2) - x_out = freqs_cis[..., 0] * x_[..., 0] + freqs_cis[..., 1] * x_[..., 1] + + x_out = freqs_cis[..., 0] * x_[..., 0] + x_out.addcmul_(freqs_cis[..., 1], x_[..., 1]) + return x_out.reshape(*x.shape).type_as(x) def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor): diff --git a/comfy/ldm/wan/model.py b/comfy/ldm/wan/model.py index 54616e6eb..0dc650ced 100644 --- a/comfy/ldm/wan/model.py +++ b/comfy/ldm/wan/model.py @@ -237,6 +237,7 @@ class WanAttentionBlock(nn.Module): freqs, transformer_options=transformer_options) x = torch.addcmul(x, y, repeat_e(e[2], x)) + del y # cross-attention & ffn x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len, transformer_options=transformer_options) From 40ae495ddcbc04846e91ccad3e844bb34d98c6fd Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sat, 27 Sep 2025 17:28:49 -0700 Subject: [PATCH 263/325] Improvements to the stable release workflow. 
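
PATCH 262 above trims the RoPE VRAM peak by accumulating the second product in place instead of materializing both products before the add. A minimal sketch of the same pattern, using torch.Tensor.addcmul_ on illustrative shapes rather than the real rope layout:

    import torch

    freqs = torch.randn(4, 64, 2)  # illustrative shapes only
    x = torch.randn(4, 64, 2)

    # Two temporaries alive at once: both products exist before the add.
    out_two_temps = freqs[..., 0] * x[..., 0] + freqs[..., 1] * x[..., 1]

    # One temporary: reuse the first product's buffer and accumulate the
    # second product into it in place.
    out_in_place = freqs[..., 0] * x[..., 0]
    out_in_place.addcmul_(freqs[..., 1], x[..., 1])

    assert torch.allclose(out_two_temps, out_in_place)

The result is numerically the same; only the number of simultaneously allocated intermediates changes.
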
(#10065) --- .github/workflows/stable-release.yml | 39 ++++++++++++------- .../windows_release_dependencies.yml | 3 +- 2 files changed, 27 insertions(+), 15 deletions(-) diff --git a/.github/workflows/stable-release.yml b/.github/workflows/stable-release.yml index 2bc8e5905..b39b42acd 100644 --- a/.github/workflows/stable-release.yml +++ b/.github/workflows/stable-release.yml @@ -8,11 +8,11 @@ on: description: 'Git tag' required: true type: string - cu: - description: 'CUDA version' + cache_tag: + description: 'Cached dependencies tag' required: true type: string - default: "129" + default: "cu129" python_minor: description: 'Python minor version' required: true @@ -23,7 +23,11 @@ on: required: true type: string default: "6" - + rel_name: + description: 'Release name' + required: true + type: string + default: "nvidia" jobs: package_comfy_windows: @@ -42,15 +46,15 @@ jobs: id: cache with: path: | - cu${{ inputs.cu }}_python_deps.tar + ${{ inputs.cache_tag }}_python_deps.tar update_comfyui_and_python_dependencies.bat - key: ${{ runner.os }}-build-cu${{ inputs.cu }}-${{ inputs.python_minor }} + key: ${{ runner.os }}-build-${{ inputs.cache_tag }}-${{ inputs.python_minor }} - shell: bash run: | - mv cu${{ inputs.cu }}_python_deps.tar ../ + mv ${{ inputs.cache_tag }}_python_deps.tar ../ mv update_comfyui_and_python_dependencies.bat ../ cd .. - tar xf cu${{ inputs.cu }}_python_deps.tar + tar xf ${{ inputs.cache_tag }}_python_deps.tar pwd ls @@ -65,12 +69,19 @@ jobs: echo 'import site' >> ./python3${{ inputs.python_minor }}._pth curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py ./python.exe get-pip.py - ./python.exe -s -m pip install ../cu${{ inputs.cu }}_python_deps/* + ./python.exe -s -m pip install ../${{ inputs.cache_tag }}_python_deps/* + + grep comfyui ../ComfyUI/requirements.txt ./requirements_comfyui.txt + ./python.exe -s -m pip install -r requirements_comfyui.txt + rm requirements_comfyui.txt + sed -i '1i../ComfyUI' ./python3${{ inputs.python_minor }}._pth - rm ./Lib/site-packages/torch/lib/dnnl.lib #I don't think this is actually used and I need the space - rm ./Lib/site-packages/torch/lib/libprotoc.lib - rm ./Lib/site-packages/torch/lib/libprotobuf.lib + if test -f ./Lib/site-packages/torch/lib/dnnl.lib; then + rm ./Lib/site-packages/torch/lib/dnnl.lib #I don't think this is actually used and I need the space + rm ./Lib/site-packages/torch/lib/libprotoc.lib + rm ./Lib/site-packages/torch/lib/libprotobuf.lib + fi cd .. @@ -91,7 +102,7 @@ jobs: cd .. 
"C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma2 -mx=9 -mfb=128 -md=768m -ms=on -mf=BCJ2 ComfyUI_windows_portable.7z ComfyUI_windows_portable - mv ComfyUI_windows_portable.7z ComfyUI/ComfyUI_windows_portable_nvidia.7z + mv ComfyUI_windows_portable.7z ComfyUI/ComfyUI_windows_portable_${{ inputs.rel_name }}.7z cd ComfyUI_windows_portable python_embeded/python.exe -s ComfyUI/main.py --quick-test-for-ci --cpu @@ -104,7 +115,7 @@ jobs: uses: svenstaro/upload-release-action@v2 with: repo_token: ${{ secrets.GITHUB_TOKEN }} - file: ComfyUI_windows_portable_nvidia.7z + file: ComfyUI_windows_portable_${{ inputs.rel_name }}.7z tag: ${{ inputs.git_tag }} overwrite: true draft: true diff --git a/.github/workflows/windows_release_dependencies.yml b/.github/workflows/windows_release_dependencies.yml index 7761cc1ed..f1e2946e6 100644 --- a/.github/workflows/windows_release_dependencies.yml +++ b/.github/workflows/windows_release_dependencies.yml @@ -56,7 +56,8 @@ jobs: ..\python_embeded\python.exe -s -m pip install --upgrade torch torchvision torchaudio ${{ inputs.xformers }} --extra-index-url https://download.pytorch.org/whl/cu${{ inputs.cu }} -r ../ComfyUI/requirements.txt pygit2 pause" > update_comfyui_and_python_dependencies.bat - python -m pip wheel --no-cache-dir torch torchvision torchaudio ${{ inputs.xformers }} ${{ inputs.extra_dependencies }} --extra-index-url https://download.pytorch.org/whl/cu${{ inputs.cu }} -r requirements.txt pygit2 -w ./temp_wheel_dir + grep -v comfyui requirements.txt > requirements_nocomfyui.txt + python -m pip wheel --no-cache-dir torch torchvision torchaudio ${{ inputs.xformers }} ${{ inputs.extra_dependencies }} --extra-index-url https://download.pytorch.org/whl/cu${{ inputs.cu }} -r requirements_nocomfyui.txt pygit2 -w ./temp_wheel_dir python -m pip install --no-cache-dir ./temp_wheel_dir/* echo installed basic ls -lah temp_wheel_dir From 896f2e653c02769371e113906d70a24306d87a58 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sat, 27 Sep 2025 18:30:35 -0700 Subject: [PATCH 264/325] Fix typo in release workflow. 
(#10066) --- .github/workflows/stable-release.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/stable-release.yml b/.github/workflows/stable-release.yml index b39b42acd..619b0e995 100644 --- a/.github/workflows/stable-release.yml +++ b/.github/workflows/stable-release.yml @@ -71,7 +71,7 @@ jobs: ./python.exe get-pip.py ./python.exe -s -m pip install ../${{ inputs.cache_tag }}_python_deps/* - grep comfyui ../ComfyUI/requirements.txt ./requirements_comfyui.txt + grep comfyui ../ComfyUI/requirements.txt > ./requirements_comfyui.txt ./python.exe -s -m pip install -r requirements_comfyui.txt rm requirements_comfyui.txt From a1127b232d221432be065f8e765f3538e62a2f41 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Sun, 28 Sep 2025 05:11:36 +0300 Subject: [PATCH 265/325] convert nodes_lotus.py to V3 schema (#10057) --- comfy_extras/nodes_lotus.py | 42 +++++++++++++++++++++++-------------- 1 file changed, 26 insertions(+), 16 deletions(-) diff --git a/comfy_extras/nodes_lotus.py b/comfy_extras/nodes_lotus.py index 739dbdd3d..9f62ba2bf 100644 --- a/comfy_extras/nodes_lotus.py +++ b/comfy_extras/nodes_lotus.py @@ -1,20 +1,22 @@ +from typing_extensions import override + import torch import comfy.model_management as mm +from comfy_api.latest import ComfyExtension, io -class LotusConditioning: + +class LotusConditioning(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return { - "required": { - }, - } + def define_schema(cls): + return io.Schema( + node_id="LotusConditioning", + category="conditioning/lotus", + inputs=[], + outputs=[io.Conditioning.Output(display_name="conditioning")], + ) - RETURN_TYPES = ("CONDITIONING",) - RETURN_NAMES = ("conditioning",) - FUNCTION = "conditioning" - CATEGORY = "conditioning/lotus" - - def conditioning(self): + @classmethod + def execute(cls) -> io.NodeOutput: device = mm.get_torch_device() #lotus uses a frozen encoder and null conditioning, i'm just inlining the results of that operation since it doesn't change #and getting parity with the reference implementation would otherwise require inference and 800mb of tensors @@ -22,8 +24,16 @@ class LotusConditioning: cond = [[prompt_embeds, {}]] - return (cond,) + return io.NodeOutput(cond) -NODE_CLASS_MAPPINGS = { - "LotusConditioning" : LotusConditioning, -} + +class LotusExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + LotusConditioning, + ] + + +async def comfy_entrypoint() -> LotusExtension: + return LotusExtension() From 1cf86f5ae5706ff141f8d51ed9ba96ecdcdcb695 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Sun, 28 Sep 2025 05:12:51 +0300 Subject: [PATCH 266/325] convert nodes_lumina2.py to V3 schema (#10058) --- comfy_extras/nodes_lumina2.py | 99 +++++++++++++++++++++-------------- 1 file changed, 61 insertions(+), 38 deletions(-) diff --git a/comfy_extras/nodes_lumina2.py b/comfy_extras/nodes_lumina2.py index 275189785..89ff2397a 100644 --- a/comfy_extras/nodes_lumina2.py +++ b/comfy_extras/nodes_lumina2.py @@ -1,20 +1,27 @@ -from comfy.comfy_types import IO, ComfyNodeABC, InputTypeDict +from typing_extensions import override import torch +from comfy_api.latest import ComfyExtension, io -class RenormCFG: + +class RenormCFG(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "model": ("MODEL",), - "cfg_trunc": ("FLOAT", {"default": 100, "min": 0.0, "max": 100.0, "step": 0.01}), - 
"renorm_cfg": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01}), - }} - RETURN_TYPES = ("MODEL",) - FUNCTION = "patch" + def define_schema(cls): + return io.Schema( + node_id="RenormCFG", + category="advanced/model", + inputs=[ + io.Model.Input("model"), + io.Float.Input("cfg_trunc", default=100, min=0.0, max=100.0, step=0.01), + io.Float.Input("renorm_cfg", default=1.0, min=0.0, max=100.0, step=0.01), + ], + outputs=[ + io.Model.Output(), + ], + ) - CATEGORY = "advanced/model" - - def patch(self, model, cfg_trunc, renorm_cfg): + @classmethod + def execute(cls, model, cfg_trunc, renorm_cfg) -> io.NodeOutput: def renorm_cfg_func(args): cond_denoised = args["cond_denoised"] uncond_denoised = args["uncond_denoised"] @@ -53,10 +60,10 @@ class RenormCFG: m = model.clone() m.set_model_sampler_cfg_function(renorm_cfg_func) - return (m, ) + return io.NodeOutput(m) -class CLIPTextEncodeLumina2(ComfyNodeABC): +class CLIPTextEncodeLumina2(io.ComfyNode): SYSTEM_PROMPT = { "superior": "You are an assistant designed to generate superior images with the superior "\ "degree of image-text alignment based on textual prompts or user prompts.", @@ -69,36 +76,52 @@ class CLIPTextEncodeLumina2(ComfyNodeABC): "Alignment: You are an assistant designed to generate high-quality images with the highest "\ "degree of image-text alignment based on textual prompts." @classmethod - def INPUT_TYPES(s) -> InputTypeDict: - return { - "required": { - "system_prompt": (list(CLIPTextEncodeLumina2.SYSTEM_PROMPT.keys()), {"tooltip": CLIPTextEncodeLumina2.SYSTEM_PROMPT_TIP}), - "user_prompt": (IO.STRING, {"multiline": True, "dynamicPrompts": True, "tooltip": "The text to be encoded."}), - "clip": (IO.CLIP, {"tooltip": "The CLIP model used for encoding the text."}) - } - } - RETURN_TYPES = (IO.CONDITIONING,) - OUTPUT_TOOLTIPS = ("A conditioning containing the embedded text used to guide the diffusion model.",) - FUNCTION = "encode" + def define_schema(cls): + return io.Schema( + node_id="CLIPTextEncodeLumina2", + display_name="CLIP Text Encode for Lumina2", + category="conditioning", + description="Encodes a system prompt and a user prompt using a CLIP model into an embedding " + "that can be used to guide the diffusion model towards generating specific images.", + inputs=[ + io.Combo.Input( + "system_prompt", + options=list(cls.SYSTEM_PROMPT.keys()), + tooltip=cls.SYSTEM_PROMPT_TIP, + ), + io.String.Input( + "user_prompt", + multiline=True, + dynamic_prompts=True, + tooltip="The text to be encoded.", + ), + io.Clip.Input("clip", tooltip="The CLIP model used for encoding the text."), + ], + outputs=[ + io.Conditioning.Output( + tooltip="A conditioning containing the embedded text used to guide the diffusion model.", + ), + ], + ) - CATEGORY = "conditioning" - DESCRIPTION = "Encodes a system prompt and a user prompt using a CLIP model into an embedding that can be used to guide the diffusion model towards generating specific images." 
- - def encode(self, clip, user_prompt, system_prompt): + @classmethod + def execute(cls, clip, user_prompt, system_prompt) -> io.NodeOutput: if clip is None: raise RuntimeError("ERROR: clip input is invalid: None\n\nIf the clip is from a checkpoint loader node your checkpoint does not contain a valid clip or text encoder model.") - system_prompt = CLIPTextEncodeLumina2.SYSTEM_PROMPT[system_prompt] + system_prompt = cls.SYSTEM_PROMPT[system_prompt] prompt = f'{system_prompt} {user_prompt}' tokens = clip.tokenize(prompt) - return (clip.encode_from_tokens_scheduled(tokens), ) + return io.NodeOutput(clip.encode_from_tokens_scheduled(tokens)) -NODE_CLASS_MAPPINGS = { - "CLIPTextEncodeLumina2": CLIPTextEncodeLumina2, - "RenormCFG": RenormCFG -} +class Lumina2Extension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + CLIPTextEncodeLumina2, + RenormCFG, + ] -NODE_DISPLAY_NAME_MAPPINGS = { - "CLIPTextEncodeLumina2": "CLIP Text Encode for Lumina2", -} +async def comfy_entrypoint() -> Lumina2Extension: + return Lumina2Extension() From 2dadb348602f8f452eb2a1d8720f6029dc4039a2 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Sun, 28 Sep 2025 05:16:22 +0300 Subject: [PATCH 267/325] convert nodes_hypertile.py to V3 schema (#10061) --- comfy_extras/nodes_hypertile.py | 59 +++++++++++++++++++++------------ 1 file changed, 38 insertions(+), 21 deletions(-) diff --git a/comfy_extras/nodes_hypertile.py b/comfy_extras/nodes_hypertile.py index b366117c7..0ad5e6773 100644 --- a/comfy_extras/nodes_hypertile.py +++ b/comfy_extras/nodes_hypertile.py @@ -1,9 +1,11 @@ #Taken from: https://github.com/tfernd/HyperTile/ import math +from typing_extensions import override from einops import rearrange # Use torch rng for consistency across generations from torch import randint +from comfy_api.latest import ComfyExtension, io def random_divisor(value: int, min_value: int, /, max_options: int = 1) -> int: min_value = min(min_value, value) @@ -20,25 +22,31 @@ def random_divisor(value: int, min_value: int, /, max_options: int = 1) -> int: return ns[idx] -class HyperTile: +class HyperTile(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "model": ("MODEL",), - "tile_size": ("INT", {"default": 256, "min": 1, "max": 2048}), - "swap_size": ("INT", {"default": 2, "min": 1, "max": 128}), - "max_depth": ("INT", {"default": 0, "min": 0, "max": 10}), - "scale_depth": ("BOOLEAN", {"default": False}), - }} - RETURN_TYPES = ("MODEL",) - FUNCTION = "patch" + def define_schema(cls): + return io.Schema( + node_id="HyperTile", + category="model_patches/unet", + inputs=[ + io.Model.Input("model"), + io.Int.Input("tile_size", default=256, min=1, max=2048), + io.Int.Input("swap_size", default=2, min=1, max=128), + io.Int.Input("max_depth", default=0, min=0, max=10), + io.Boolean.Input("scale_depth", default=False), + ], + outputs=[ + io.Model.Output(), + ], + ) - CATEGORY = "model_patches/unet" - - def patch(self, model, tile_size, swap_size, max_depth, scale_depth): + @classmethod + def execute(cls, model, tile_size, swap_size, max_depth, scale_depth) -> io.NodeOutput: latent_tile_size = max(32, tile_size) // 8 - self.temp = None + temp = None def hypertile_in(q, k, v, extra_options): + nonlocal temp model_chans = q.shape[-2] orig_shape = extra_options['original_shape'] apply_to = [] @@ -58,14 +66,15 @@ class HyperTile: if nh * nw > 1: q = rearrange(q, "b (nh h nw w) c -> (b nh nw) (h w) c", h=h // nh, w=w // nw, 
nh=nh, nw=nw) - self.temp = (nh, nw, h, w) + temp = (nh, nw, h, w) return q, k, v return q, k, v def hypertile_out(out, extra_options): - if self.temp is not None: - nh, nw, h, w = self.temp - self.temp = None + nonlocal temp + if temp is not None: + nh, nw, h, w = temp + temp = None out = rearrange(out, "(b nh nw) hw c -> b nh nw hw c", nh=nh, nw=nw) out = rearrange(out, "b nh nw (h w) c -> b (nh h nw w) c", h=h // nh, w=w // nw) return out @@ -76,6 +85,14 @@ class HyperTile: m.set_model_attn1_output_patch(hypertile_out) return (m, ) -NODE_CLASS_MAPPINGS = { - "HyperTile": HyperTile, -} + +class HyperTileExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + HyperTile, + ] + + +async def comfy_entrypoint() -> HyperTileExtension: + return HyperTileExtension() From 1364548c721a466adcdc60e49ee291b0d4255245 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rui=20Wang=20=28=E7=8E=8B=E7=91=9E=29?= Date: Sun, 28 Sep 2025 10:36:02 +0800 Subject: [PATCH 268/325] feat: ComfyUI can be run on the specified Ascend NPU (#9663) * feature: Set the Ascend NPU to use a single one * Enable the `--cuda-device` parameter to support both CUDA and Ascend NPUs simultaneously. * Make the code just set the ASCENT_RT_VISIBLE_DEVICES environment variable without any other edits to master branch --------- Co-authored-by: Jedrzej Kosinski --- main.py | 1 + 1 file changed, 1 insertion(+) diff --git a/main.py b/main.py index c33f0e17b..70696fcc3 100644 --- a/main.py +++ b/main.py @@ -127,6 +127,7 @@ if __name__ == "__main__": if args.cuda_device is not None: os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda_device) os.environ['HIP_VISIBLE_DEVICES'] = str(args.cuda_device) + os.environ["ASCEND_RT_VISIBLE_DEVICES"] = str(args.cuda_device) logging.info("Set cuda device to: {}".format(args.cuda_device)) if args.oneapi_device_selector is not None: From 555f902fc1ed20e98201f9102172f0fc190c2c42 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sat, 27 Sep 2025 19:43:25 -0700 Subject: [PATCH 269/325] Fix stable workflow creating multiple draft releases. (#10067) --- .github/workflows/stable-release.yml | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/.github/workflows/stable-release.yml b/.github/workflows/stable-release.yml index 619b0e995..924bdec90 100644 --- a/.github/workflows/stable-release.yml +++ b/.github/workflows/stable-release.yml @@ -112,10 +112,9 @@ jobs: ls - name: Upload binaries to release - uses: svenstaro/upload-release-action@v2 + uses: softprops/action-gh-release@v2 with: - repo_token: ${{ secrets.GITHUB_TOKEN }} - file: ComfyUI_windows_portable_${{ inputs.rel_name }}.7z - tag: ${{ inputs.git_tag }} - overwrite: true + files: ComfyUI_windows_portable_${{ inputs.rel_name }}.7z + tag_name: ${{ inputs.git_tag }} draft: true + overwrite_files: true From b60dc316272ba139e06b8a7b2f5f5b622c9afe20 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sun, 28 Sep 2025 10:41:32 -0700 Subject: [PATCH 270/325] Update command to install latest nighly pytorch. (#10085) --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 3f6cfc2ed..5a257687b 100644 --- a/README.md +++ b/README.md @@ -233,7 +233,7 @@ Nvidia users should install stable pytorch using this command: This is the command to install pytorch nightly instead which might have performance improvements. 
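
PATCH 267 above moves HyperTile's per-call state off the node instance and into a closure variable, since the V3 execute hook is a classmethod and no longer has a per-node instance to stash state on. A stripped-down sketch of that closure pattern, with hypothetical hook names:

    def make_hooks():
        temp = None                      # shared between the two hooks via the closure

        def attn_in(q):
            nonlocal temp
            temp = q                     # stash state for the matching output hook
            return q

        def attn_out(out):
            nonlocal temp
            stashed, temp = temp, None   # consume and clear, as the patch does
            return (stashed, out)

        return attn_in, attn_out

    attn_in, attn_out = make_hooks()
    attn_in("q")
    print(attn_out("out"))               # ('q', 'out')
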
-```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu129``` +```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu130``` #### Troubleshooting From 6ec1cfe101206229ff3af5c3d3675b3b92477067 Mon Sep 17 00:00:00 2001 From: Changrz <51637999+WhiteGiven@users.noreply.github.com> Date: Tue, 30 Sep 2025 02:59:12 +0800 Subject: [PATCH 271/325] [Rodin3d api nodes] Updated the name of the save file path (changed from timestamp to UUID). (#10011) * Update savepath name from time to uuid * delete lib --- comfy_api_nodes/nodes_rodin.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/comfy_api_nodes/nodes_rodin.py b/comfy_api_nodes/nodes_rodin.py index 1af393eba..817efb0f5 100644 --- a/comfy_api_nodes/nodes_rodin.py +++ b/comfy_api_nodes/nodes_rodin.py @@ -11,7 +11,6 @@ from comfy.comfy_types.node_typing import IO import folder_paths as comfy_paths import aiohttp import os -import datetime import asyncio import io import logging @@ -243,8 +242,8 @@ class Rodin3DAPI: return mesh_mode, quality_override - async def download_files(self, url_list): - save_path = os.path.join(comfy_paths.get_output_directory(), "Rodin3D", datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")) + async def download_files(self, url_list, task_uuid): + save_path = os.path.join(comfy_paths.get_output_directory(), f"Rodin3D_{task_uuid}") os.makedirs(save_path, exist_ok=True) model_file_path = None async with aiohttp.ClientSession() as session: @@ -320,7 +319,7 @@ class Rodin3D_Regular(Rodin3DAPI): **kwargs) await self.poll_for_task_status(subscription_key, **kwargs) download_list = await self.get_rodin_download_list(task_uuid, **kwargs) - model = await self.download_files(download_list) + model = await self.download_files(download_list, task_uuid) return (model,) @@ -366,7 +365,7 @@ class Rodin3D_Detail(Rodin3DAPI): **kwargs) await self.poll_for_task_status(subscription_key, **kwargs) download_list = await self.get_rodin_download_list(task_uuid, **kwargs) - model = await self.download_files(download_list) + model = await self.download_files(download_list, task_uuid) return (model,) @@ -412,7 +411,7 @@ class Rodin3D_Smooth(Rodin3DAPI): **kwargs) await self.poll_for_task_status(subscription_key, **kwargs) download_list = await self.get_rodin_download_list(task_uuid, **kwargs) - model = await self.download_files(download_list) + model = await self.download_files(download_list, task_uuid) return (model,) @@ -467,7 +466,7 @@ class Rodin3D_Sketch(Rodin3DAPI): ) await self.poll_for_task_status(subscription_key, **kwargs) download_list = await self.get_rodin_download_list(task_uuid, **kwargs) - model = await self.download_files(download_list) + model = await self.download_files(download_list, task_uuid) return (model,) From c8276f8c6bee54b494fd5bec8dfb87ed21a3fa65 Mon Sep 17 00:00:00 2001 From: ComfyUI Wiki Date: Tue, 30 Sep 2025 02:59:42 +0800 Subject: [PATCH 272/325] Update template to 0.1.91 (#10096) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index b3f81e8fa..45d3e1607 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.26.13 -comfyui-workflow-templates==0.1.88 +comfyui-workflow-templates==0.1.91 comfyui-embedded-docs==0.2.6 torch torchsde From 05a258efd84bfb00e2618eb9b7937b8fef1e82ed Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> 
Date: Mon, 29 Sep 2025 22:01:04 +0300 Subject: [PATCH 273/325] add WanImageToImageApi node (#10094) --- comfy_api_nodes/nodes_wan.py | 149 ++++++++++++++++++++++++++++++++++- 1 file changed, 148 insertions(+), 1 deletion(-) diff --git a/comfy_api_nodes/nodes_wan.py b/comfy_api_nodes/nodes_wan.py index db5bd41c1..0be5daadb 100644 --- a/comfy_api_nodes/nodes_wan.py +++ b/comfy_api_nodes/nodes_wan.py @@ -28,6 +28,12 @@ class Text2ImageInputField(BaseModel): negative_prompt: Optional[str] = Field(None) +class Image2ImageInputField(BaseModel): + prompt: str = Field(...) + negative_prompt: Optional[str] = Field(None) + images: list[str] = Field(..., min_length=1, max_length=2) + + class Text2VideoInputField(BaseModel): prompt: str = Field(...) negative_prompt: Optional[str] = Field(None) @@ -49,6 +55,13 @@ class Txt2ImageParametersField(BaseModel): watermark: bool = Field(True) +class Image2ImageParametersField(BaseModel): + size: Optional[str] = Field(None) + n: int = Field(1, description="Number of images to generate.") # we support only value=1 + seed: int = Field(..., ge=0, le=2147483647) + watermark: bool = Field(True) + + class Text2VideoParametersField(BaseModel): size: str = Field(...) seed: int = Field(..., ge=0, le=2147483647) @@ -73,6 +86,12 @@ class Text2ImageTaskCreationRequest(BaseModel): parameters: Txt2ImageParametersField = Field(...) +class Image2ImageTaskCreationRequest(BaseModel): + model: str = Field(...) + input: Image2ImageInputField = Field(...) + parameters: Image2ImageParametersField = Field(...) + + class Text2VideoTaskCreationRequest(BaseModel): model: str = Field(...) input: Text2VideoInputField = Field(...) @@ -135,7 +154,12 @@ async def process_task( url: str, request_model: Type[T], response_model: Type[R], - payload: Union[Text2ImageTaskCreationRequest, Text2VideoTaskCreationRequest, Image2VideoTaskCreationRequest], + payload: Union[ + Text2ImageTaskCreationRequest, + Image2ImageTaskCreationRequest, + Text2VideoTaskCreationRequest, + Image2VideoTaskCreationRequest, + ], node_id: str, estimated_duration: int, poll_interval: int, @@ -288,6 +312,128 @@ class WanTextToImageApi(comfy_io.ComfyNode): return comfy_io.NodeOutput(await download_url_to_image_tensor(str(response.output.results[0].url))) +class WanImageToImageApi(comfy_io.ComfyNode): + @classmethod + def define_schema(cls): + return comfy_io.Schema( + node_id="WanImageToImageApi", + display_name="Wan Image to Image", + category="api node/image/Wan", + description="Generates an image from one or two input images and a text prompt. 
" + "The output image is currently fixed at 1.6 MP; its aspect ratio matches the input image(s).", + inputs=[ + comfy_io.Combo.Input( + "model", + options=["wan2.5-i2i-preview"], + default="wan2.5-i2i-preview", + tooltip="Model to use.", + ), + comfy_io.Image.Input( + "image", + tooltip="Single-image editing or multi-image fusion, maximum 2 images.", + ), + comfy_io.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Prompt used to describe the elements and visual features, supports English/Chinese.", + ), + comfy_io.String.Input( + "negative_prompt", + multiline=True, + default="", + tooltip="Negative text prompt to guide what to avoid.", + optional=True, + ), + # redo this later as an optional combo of recommended resolutions + # comfy_io.Int.Input( + # "width", + # default=1280, + # min=384, + # max=1440, + # step=16, + # optional=True, + # ), + # comfy_io.Int.Input( + # "height", + # default=1280, + # min=384, + # max=1440, + # step=16, + # optional=True, + # ), + comfy_io.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + step=1, + display_mode=comfy_io.NumberDisplay.number, + control_after_generate=True, + tooltip="Seed to use for generation.", + optional=True, + ), + comfy_io.Boolean.Input( + "watermark", + default=True, + tooltip="Whether to add an \"AI generated\" watermark to the result.", + optional=True, + ), + ], + outputs=[ + comfy_io.Image.Output(), + ], + hidden=[ + comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + comfy_io.Hidden.unique_id, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, + model: str, + image: torch.Tensor, + prompt: str, + negative_prompt: str = "", + # width: int = 1024, + # height: int = 1024, + seed: int = 0, + watermark: bool = True, + ): + n_images = get_number_of_images(image) + if n_images not in (1, 2): + raise ValueError(f"Expected 1 or 2 input images, got {n_images}.") + images = [] + for i in image: + images.append("data:image/png;base64," + tensor_to_base64_string(i, total_pixels=4096*4096)) + payload = Image2ImageTaskCreationRequest( + model=model, + input=Image2ImageInputField(prompt=prompt, negative_prompt=negative_prompt, images=images), + parameters=Image2ImageParametersField( + # size=f"{width}*{height}", + seed=seed, + watermark=watermark, + ), + ) + response = await process_task( + { + "auth_token": cls.hidden.auth_token_comfy_org, + "comfy_api_key": cls.hidden.api_key_comfy_org, + }, + "/proxy/wan/api/v1/services/aigc/image2image/image-synthesis", + request_model=Image2ImageTaskCreationRequest, + response_model=ImageTaskStatusResponse, + payload=payload, + node_id=cls.hidden.unique_id, + estimated_duration=42, + poll_interval=3, + ) + return comfy_io.NodeOutput(await download_url_to_image_tensor(str(response.output.results[0].url))) + + class WanTextToVideoApi(comfy_io.ComfyNode): @classmethod def define_schema(cls): @@ -593,6 +739,7 @@ class WanApiExtension(ComfyExtension): async def get_node_list(self) -> list[type[comfy_io.ComfyNode]]: return [ WanTextToImageApi, + WanImageToImageApi, WanTextToVideoApi, WanImageToVideoApi, ] From b1111c2062ce35d4292bcd94f27c099a13c619cb Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Mon, 29 Sep 2025 22:03:35 +0300 Subject: [PATCH 274/325] convert nodes_mochi.py to V3 schema (#10069) --- comfy_extras/nodes_mochi.py | 49 +++++++++++++++++++++++++------------ 1 file changed, 33 insertions(+), 16 deletions(-) diff --git a/comfy_extras/nodes_mochi.py 
b/comfy_extras/nodes_mochi.py index 1c474faa9..d750194fc 100644 --- a/comfy_extras/nodes_mochi.py +++ b/comfy_extras/nodes_mochi.py @@ -1,23 +1,40 @@ -import nodes +from typing_extensions import override import torch import comfy.model_management +import nodes +from comfy_api.latest import ComfyExtension, io -class EmptyMochiLatentVideo: + +class EmptyMochiLatentVideo(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "width": ("INT", {"default": 848, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), - "height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), - "length": ("INT", {"default": 25, "min": 7, "max": nodes.MAX_RESOLUTION, "step": 6}), - "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})}} - RETURN_TYPES = ("LATENT",) - FUNCTION = "generate" + def define_schema(cls): + return io.Schema( + node_id="EmptyMochiLatentVideo", + category="latent/video", + inputs=[ + io.Int.Input("width", default=848, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("length", default=25, min=7, max=nodes.MAX_RESOLUTION, step=6), + io.Int.Input("batch_size", default=1, min=1, max=4096), + ], + outputs=[ + io.Latent.Output(), + ], + ) - CATEGORY = "latent/video" - - def generate(self, width, height, length, batch_size=1): + @classmethod + def execute(cls, width, height, length, batch_size=1) -> io.NodeOutput: latent = torch.zeros([batch_size, 12, ((length - 1) // 6) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) - return ({"samples":latent}, ) + return io.NodeOutput({"samples": latent}) -NODE_CLASS_MAPPINGS = { - "EmptyMochiLatentVideo": EmptyMochiLatentVideo, -} + +class MochiExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + EmptyMochiLatentVideo, + ] + + +async def comfy_entrypoint() -> MochiExtension: + return MochiExtension() From 041b8824f50e01803637d5e83c3f4edaf628f43a Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Mon, 29 Sep 2025 22:05:28 +0300 Subject: [PATCH 275/325] convert nodes_perpneg.py to V3 schema (#10081) --- comfy_extras/nodes_perpneg.py | 93 +++++++++++++++++++++-------------- 1 file changed, 55 insertions(+), 38 deletions(-) diff --git a/comfy_extras/nodes_perpneg.py b/comfy_extras/nodes_perpneg.py index 89e5eef90..cd068ce9c 100644 --- a/comfy_extras/nodes_perpneg.py +++ b/comfy_extras/nodes_perpneg.py @@ -5,6 +5,9 @@ import comfy.samplers import comfy.utils import node_helpers import math +from typing_extensions import override +from comfy_api.latest import ComfyExtension, io + def perp_neg(x, noise_pred_pos, noise_pred_neg, noise_pred_nocond, neg_scale, cond_scale): pos = noise_pred_pos - noise_pred_nocond @@ -16,20 +19,27 @@ def perp_neg(x, noise_pred_pos, noise_pred_neg, noise_pred_nocond, neg_scale, co return cfg_result #TODO: This node should be removed, it has been replaced with PerpNegGuider -class PerpNeg: +class PerpNeg(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": {"model": ("MODEL", ), - "empty_conditioning": ("CONDITIONING", ), - "neg_scale": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01}), - }} - RETURN_TYPES = ("MODEL",) - FUNCTION = "patch" + def define_schema(cls): + return io.Schema( + node_id="PerpNeg", + display_name="Perp-Neg (DEPRECATED by PerpNegGuider)", + category="_for_testing", + inputs=[ + 
io.Model.Input("model"), + io.Conditioning.Input("empty_conditioning"), + io.Float.Input("neg_scale", default=1.0, min=0.0, max=100.0, step=0.01), + ], + outputs=[ + io.Model.Output(), + ], + is_experimental=True, + is_deprecated=True, + ) - CATEGORY = "_for_testing" - DEPRECATED = True - - def patch(self, model, empty_conditioning, neg_scale): + @classmethod + def execute(cls, model, empty_conditioning, neg_scale) -> io.NodeOutput: m = model.clone() nocond = comfy.sampler_helpers.convert_cond(empty_conditioning) @@ -50,7 +60,7 @@ class PerpNeg: m.set_model_sampler_cfg_function(cfg_function) - return (m, ) + return io.NodeOutput(m) class Guider_PerpNeg(comfy.samplers.CFGGuider): @@ -112,35 +122,42 @@ class Guider_PerpNeg(comfy.samplers.CFGGuider): return cfg_result -class PerpNegGuider: +class PerpNegGuider(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"model": ("MODEL",), - "positive": ("CONDITIONING", ), - "negative": ("CONDITIONING", ), - "empty_conditioning": ("CONDITIONING", ), - "cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01}), - "neg_scale": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01}), - } - } + def define_schema(cls): + return io.Schema( + node_id="PerpNegGuider", + category="_for_testing", + inputs=[ + io.Model.Input("model"), + io.Conditioning.Input("positive"), + io.Conditioning.Input("negative"), + io.Conditioning.Input("empty_conditioning"), + io.Float.Input("cfg", default=8.0, min=0.0, max=100.0, step=0.1, round=0.01), + io.Float.Input("neg_scale", default=1.0, min=0.0, max=100.0, step=0.01), + ], + outputs=[ + io.Guider.Output(), + ], + is_experimental=True, + ) - RETURN_TYPES = ("GUIDER",) - - FUNCTION = "get_guider" - CATEGORY = "_for_testing" - - def get_guider(self, model, positive, negative, empty_conditioning, cfg, neg_scale): + @classmethod + def execute(cls, model, positive, negative, empty_conditioning, cfg, neg_scale) -> io.NodeOutput: guider = Guider_PerpNeg(model) guider.set_conds(positive, negative, empty_conditioning) guider.set_cfg(cfg, neg_scale) - return (guider,) + return io.NodeOutput(guider) -NODE_CLASS_MAPPINGS = { - "PerpNeg": PerpNeg, - "PerpNegGuider": PerpNegGuider, -} -NODE_DISPLAY_NAME_MAPPINGS = { - "PerpNeg": "Perp-Neg (DEPRECATED by PerpNegGuider)", -} +class PerpNegExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + PerpNeg, + PerpNegGuider, + ] + + +async def comfy_entrypoint() -> PerpNegExtension: + return PerpNegExtension() From ed0f4a609b5e6821f97db5cb1715068c25f78e7b Mon Sep 17 00:00:00 2001 From: Christian Byrne Date: Mon, 29 Sep 2025 12:16:02 -0700 Subject: [PATCH 276/325] dont cache new locale entry points (#10101) --- middleware/cache_middleware.py | 11 ++++++----- tests-unit/server_test/test_cache_control.py | 7 +++++++ 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/middleware/cache_middleware.py b/middleware/cache_middleware.py index 374ef7934..f02135369 100644 --- a/middleware/cache_middleware.py +++ b/middleware/cache_middleware.py @@ -26,11 +26,12 @@ async def cache_control( """Cache control middleware that sets appropriate cache headers based on file type and response status""" response: web.Response = await handler(request) - if ( - request.path.endswith(".js") - or request.path.endswith(".css") - or request.path.endswith("index.json") - ): + path_filename = request.path.rsplit("/", 1)[-1] + is_entry_point = path_filename.startswith("index") and 
path_filename.endswith( + ".json" + ) + + if request.path.endswith(".js") or request.path.endswith(".css") or is_entry_point: response.headers.setdefault("Cache-Control", "no-cache") return response diff --git a/tests-unit/server_test/test_cache_control.py b/tests-unit/server_test/test_cache_control.py index 8de59125a..fa68d9408 100644 --- a/tests-unit/server_test/test_cache_control.py +++ b/tests-unit/server_test/test_cache_control.py @@ -48,6 +48,13 @@ CACHE_SCENARIOS = [ "expected_cache": "no-cache", "should_have_header": True, }, + { + "name": "localized_index_json_no_cache", + "path": "/templates/index.zh.json", + "status": 200, + "expected_cache": "no-cache", + "should_have_header": True, + }, # Non-matching files { "name": "html_no_header", From 8accf50908094d9cd39168981fa5394274d25491 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Mon, 29 Sep 2025 22:35:51 +0300 Subject: [PATCH 277/325] convert nodes_mahiro.py to V3 schema (#10070) --- comfy_extras/nodes_mahiro.py | 50 ++++++++++++++++++++++++------------ 1 file changed, 33 insertions(+), 17 deletions(-) diff --git a/comfy_extras/nodes_mahiro.py b/comfy_extras/nodes_mahiro.py index 8fcdfba75..07b3353f4 100644 --- a/comfy_extras/nodes_mahiro.py +++ b/comfy_extras/nodes_mahiro.py @@ -1,17 +1,29 @@ +from typing_extensions import override import torch import torch.nn.functional as F -class Mahiro: +from comfy_api.latest import ComfyExtension, io + + +class Mahiro(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": {"model": ("MODEL",), - }} - RETURN_TYPES = ("MODEL",) - RETURN_NAMES = ("patched_model",) - FUNCTION = "patch" - CATEGORY = "_for_testing" - DESCRIPTION = "Modify the guidance to scale more on the 'direction' of the positive prompt rather than the difference between the negative prompt." - def patch(self, model): + def define_schema(cls): + return io.Schema( + node_id="Mahiro", + display_name="Mahiro is so cute that she deserves a better guidance function!! (。・ω・。)", + category="_for_testing", + description="Modify the guidance to scale more on the 'direction' of the positive prompt rather than the difference between the negative prompt.", + inputs=[ + io.Model.Input("model"), + ], + outputs=[ + io.Model.Output(display_name="patched_model"), + ], + is_experimental=True, + ) + + @classmethod + def execute(cls, model) -> io.NodeOutput: m = model.clone() def mahiro_normd(args): scale: float = args['cond_scale'] @@ -30,12 +42,16 @@ class Mahiro: wm = (simsc*cfg + (4-simsc)*leap) / 4 return wm m.set_model_sampler_post_cfg_function(mahiro_normd) - return (m, ) + return io.NodeOutput(m) -NODE_CLASS_MAPPINGS = { - "Mahiro": Mahiro -} -NODE_DISPLAY_NAME_MAPPINGS = { - "Mahiro": "Mahiro is so cute that she deserves a better guidance function!! (。・ω・。)", -} +class MahiroExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + Mahiro, + ] + + +async def comfy_entrypoint() -> MahiroExtension: + return MahiroExtension() From 7f38e4c538de2fa38d0539c18577cdd0e5d251c2 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 29 Sep 2025 14:27:52 -0700 Subject: [PATCH 278/325] Add action to create cached deps with manually specified torch. 
(#10102) --- .../windows_release_dependencies_manual.yml | 64 +++++++++++++++++++ 1 file changed, 64 insertions(+) create mode 100644 .github/workflows/windows_release_dependencies_manual.yml diff --git a/.github/workflows/windows_release_dependencies_manual.yml b/.github/workflows/windows_release_dependencies_manual.yml new file mode 100644 index 000000000..0799feef1 --- /dev/null +++ b/.github/workflows/windows_release_dependencies_manual.yml @@ -0,0 +1,64 @@ +name: "Windows Release dependencies Manual" + +on: + workflow_dispatch: + inputs: + torch_dependencies: + description: 'torch dependencies' + required: false + type: string + default: "torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu128" + cache_tag: + description: 'Cached dependencies tag' + required: true + type: string + default: "cu128" + + python_minor: + description: 'python minor version' + required: true + type: string + default: "12" + + python_patch: + description: 'python patch version' + required: true + type: string + default: "10" + +jobs: + build_dependencies: + runs-on: windows-latest + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: 3.${{ inputs.python_minor }}.${{ inputs.python_patch }} + + - shell: bash + run: | + echo "@echo off + call update_comfyui.bat nopause + echo - + echo This will try to update pytorch and all python dependencies. + echo - + echo If you just want to update normally, close this and run update_comfyui.bat instead. + echo - + pause + ..\python_embeded\python.exe -s -m pip install --upgrade ${{ inputs.torch_dependencies }} -r ../ComfyUI/requirements.txt pygit2 + pause" > update_comfyui_and_python_dependencies.bat + + grep -v comfyui requirements.txt > requirements_nocomfyui.txt + python -m pip wheel --no-cache-dir ${{ inputs.torch_dependencies }} -r requirements_nocomfyui.txt pygit2 -w ./temp_wheel_dir + python -m pip install --no-cache-dir ./temp_wheel_dir/* + echo installed basic + ls -lah temp_wheel_dir + mv temp_wheel_dir ${{ inputs.cache_tag }}_python_deps + tar cf ${{ inputs.cache_tag }}_python_deps.tar ${{ inputs.cache_tag }}_python_deps + + - uses: actions/cache/save@v4 + with: + path: | + ${{ inputs.cache_tag }}_python_deps.tar + update_comfyui_and_python_dependencies.bat + key: ${{ runner.os }}-build-${{ inputs.cache_tag }}-${{ inputs.python_minor }} From 1673ace19b9d63a8dc0d388aafdb54abf2497892 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 29 Sep 2025 16:08:42 -0700 Subject: [PATCH 279/325] Make the final release test optional in the stable release action. (#10103) --- .github/workflows/stable-release.yml | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/.github/workflows/stable-release.yml b/.github/workflows/stable-release.yml index 924bdec90..5eb4a0783 100644 --- a/.github/workflows/stable-release.yml +++ b/.github/workflows/stable-release.yml @@ -28,6 +28,11 @@ on: required: true type: string default: "nvidia" + test_release: + description: 'Test Release' + required: true + type: boolean + default: true jobs: package_comfy_windows: @@ -104,6 +109,10 @@ jobs: "C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma2 -mx=9 -mfb=128 -md=768m -ms=on -mf=BCJ2 ComfyUI_windows_portable.7z ComfyUI_windows_portable mv ComfyUI_windows_portable.7z ComfyUI/ComfyUI_windows_portable_${{ inputs.rel_name }}.7z + - shell: bash + if: ${{ inputs.test_release }} + run: | + cd .. 
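
PATCH 276 above widens the front-end no-cache rule from the literal index.json entry point to any locale-specific variant such as index.zh.json. A small standalone sketch of the filename check it introduces (the helper name here is illustrative; in the middleware the logic is inlined):

    def is_entry_point(path: str) -> bool:
        filename = path.rsplit("/", 1)[-1]
        return filename.startswith("index") and filename.endswith(".json")

    assert is_entry_point("/templates/index.json")
    assert is_entry_point("/templates/index.zh.json")   # new locale entry points stay uncached
    assert not is_entry_point("/templates/other.json")
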
cd ComfyUI_windows_portable python_embeded/python.exe -s ComfyUI/main.py --quick-test-for-ci --cpu From 0db6aabed3942ea71258d25d32dc971a2a2421af Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 29 Sep 2025 16:54:05 -0700 Subject: [PATCH 280/325] Different base files for different release. (#10104) --- .github/workflows/stable-release.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/stable-release.yml b/.github/workflows/stable-release.yml index 5eb4a0783..40e1bc157 100644 --- a/.github/workflows/stable-release.yml +++ b/.github/workflows/stable-release.yml @@ -101,7 +101,7 @@ jobs: mkdir update cp -r ComfyUI/.ci/update_windows/* ./update/ - cp -r ComfyUI/.ci/windows_base_files/* ./ + cp -r ComfyUI/.ci/windows_${{ inputs.rel_name }}_base_files/* ./ cp ../update_comfyui_and_python_dependencies.bat ./update/ cd .. From 375884842314a2234ddc29132b03c741ce81443b Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 29 Sep 2025 16:54:37 -0700 Subject: [PATCH 281/325] Different base files for nvidia and amd portables. (#10105) --- .../run_amd_gpu.bat} | 0 .../README_VERY_IMPORTANT.txt | 0 .../run_cpu.bat | 0 .ci/windows_nvidia_base_files/run_nvidia_gpu.bat | 2 ++ .../run_nvidia_gpu_fast_fp16_accumulation.bat | 0 .github/workflows/windows_release_nightly_pytorch.yml | 2 +- .github/workflows/windows_release_package.yml | 2 +- 7 files changed, 4 insertions(+), 2 deletions(-) rename .ci/{windows_base_files/run_nvidia_gpu.bat => windows_amd_base_files/run_amd_gpu.bat} (100%) rename .ci/{windows_base_files => windows_nvidia_base_files}/README_VERY_IMPORTANT.txt (100%) rename .ci/{windows_base_files => windows_nvidia_base_files}/run_cpu.bat (100%) create mode 100755 .ci/windows_nvidia_base_files/run_nvidia_gpu.bat rename .ci/{windows_base_files => windows_nvidia_base_files}/run_nvidia_gpu_fast_fp16_accumulation.bat (100%) diff --git a/.ci/windows_base_files/run_nvidia_gpu.bat b/.ci/windows_amd_base_files/run_amd_gpu.bat similarity index 100% rename from .ci/windows_base_files/run_nvidia_gpu.bat rename to .ci/windows_amd_base_files/run_amd_gpu.bat diff --git a/.ci/windows_base_files/README_VERY_IMPORTANT.txt b/.ci/windows_nvidia_base_files/README_VERY_IMPORTANT.txt similarity index 100% rename from .ci/windows_base_files/README_VERY_IMPORTANT.txt rename to .ci/windows_nvidia_base_files/README_VERY_IMPORTANT.txt diff --git a/.ci/windows_base_files/run_cpu.bat b/.ci/windows_nvidia_base_files/run_cpu.bat similarity index 100% rename from .ci/windows_base_files/run_cpu.bat rename to .ci/windows_nvidia_base_files/run_cpu.bat diff --git a/.ci/windows_nvidia_base_files/run_nvidia_gpu.bat b/.ci/windows_nvidia_base_files/run_nvidia_gpu.bat new file mode 100755 index 000000000..274d7c948 --- /dev/null +++ b/.ci/windows_nvidia_base_files/run_nvidia_gpu.bat @@ -0,0 +1,2 @@ +.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build +pause diff --git a/.ci/windows_base_files/run_nvidia_gpu_fast_fp16_accumulation.bat b/.ci/windows_nvidia_base_files/run_nvidia_gpu_fast_fp16_accumulation.bat similarity index 100% rename from .ci/windows_base_files/run_nvidia_gpu_fast_fp16_accumulation.bat rename to .ci/windows_nvidia_base_files/run_nvidia_gpu_fast_fp16_accumulation.bat diff --git a/.github/workflows/windows_release_nightly_pytorch.yml b/.github/workflows/windows_release_nightly_pytorch.yml index 5bdc940de..ca1ef71ae 100644 --- 
a/.github/workflows/windows_release_nightly_pytorch.yml +++ b/.github/workflows/windows_release_nightly_pytorch.yml @@ -68,7 +68,7 @@ jobs: mkdir update cp -r ComfyUI/.ci/update_windows/* ./update/ - cp -r ComfyUI/.ci/windows_base_files/* ./ + cp -r ComfyUI/.ci/windows_nvidia_base_files/* ./ cp -r ComfyUI/.ci/windows_nightly_base_files/* ./ echo "call update_comfyui.bat nopause diff --git a/.github/workflows/windows_release_package.yml b/.github/workflows/windows_release_package.yml index 46375698e..7955325fc 100644 --- a/.github/workflows/windows_release_package.yml +++ b/.github/workflows/windows_release_package.yml @@ -81,7 +81,7 @@ jobs: mkdir update cp -r ComfyUI/.ci/update_windows/* ./update/ - cp -r ComfyUI/.ci/windows_base_files/* ./ + cp -r ComfyUI/.ci/windows_nvidia_base_files/* ./ cp ../update_comfyui_and_python_dependencies.bat ./update/ cd .. From 342cf644ce495dafaa31dd49d42c47c5e242e701 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 29 Sep 2025 17:05:44 -0700 Subject: [PATCH 282/325] Add a way to have different names for stable nvidia portables. (#10106) --- .github/workflows/stable-release.yml | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/.github/workflows/stable-release.yml b/.github/workflows/stable-release.yml index 40e1bc157..1cbbfbf69 100644 --- a/.github/workflows/stable-release.yml +++ b/.github/workflows/stable-release.yml @@ -28,6 +28,11 @@ on: required: true type: string default: "nvidia" + rel_extra_name: + description: 'Release extra name' + required: false + type: string + default: "" test_release: description: 'Test Release' required: true @@ -107,7 +112,7 @@ jobs: cd .. "C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma2 -mx=9 -mfb=128 -md=768m -ms=on -mf=BCJ2 ComfyUI_windows_portable.7z ComfyUI_windows_portable - mv ComfyUI_windows_portable.7z ComfyUI/ComfyUI_windows_portable_${{ inputs.rel_name }}.7z + mv ComfyUI_windows_portable.7z ComfyUI/ComfyUI_windows_portable_${{ inputs.rel_name }}${{ inputs.rel_extra_name }}.7z - shell: bash if: ${{ inputs.test_release }} @@ -123,7 +128,7 @@ jobs: - name: Upload binaries to release uses: softprops/action-gh-release@v2 with: - files: ComfyUI_windows_portable_${{ inputs.rel_name }}.7z + files: ComfyUI_windows_portable_${{ inputs.rel_name }}${{ inputs.rel_extra_name }}.7z tag_name: ${{ inputs.git_tag }} draft: true overwrite_files: true From bed4b49d08d80e195cb42d5294037fc6b631942e Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 29 Sep 2025 17:31:15 -0700 Subject: [PATCH 283/325] Add action to do the full stable release. 
(#10107) --- .github/workflows/release-stable-all.yml | 49 ++++++++++++++++++++++++ 1 file changed, 49 insertions(+) create mode 100644 .github/workflows/release-stable-all.yml diff --git a/.github/workflows/release-stable-all.yml b/.github/workflows/release-stable-all.yml new file mode 100644 index 000000000..aac84d637 --- /dev/null +++ b/.github/workflows/release-stable-all.yml @@ -0,0 +1,49 @@ +name: "Release Stable All Portable Versions" + +on: + workflow_dispatch: + inputs: + git_tag: + description: 'Git tag' + required: true + type: string + +jobs: + release_nvidia_default: + name: "Release NVIDIA Default (cu129)" + uses: ./.github/workflows/stable-release.yml + with: + git_tag: ${{ inputs.git_tag }} + cache_tag: "cu129" + python_minor: "13" + python_patch: "6" + rel_name: "nvidia" + rel_extra_name: "" + test_release: true + secrets: inherit + + release_nvidia_cu128: + name: "Release NVIDIA cu128" + uses: ./.github/workflows/stable-release.yml + with: + git_tag: ${{ inputs.git_tag }} + cache_tag: "cu128" + python_minor: "12" + python_patch: "10" + rel_name: "nvidia" + rel_extra_name: "_cu128" + test_release: true + secrets: inherit + + release_amd_rocm: + name: "Release AMD ROCm 6.4.4" + uses: ./.github/workflows/stable-release.yml + with: + git_tag: ${{ inputs.git_tag }} + cache_tag: "rocm644" + python_minor: "12" + python_patch: "10" + rel_name: "amd" + rel_extra_name: "" + test_release: false + secrets: inherit From 447884b65740d9f4160ef13d55adb49ca111140e Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 29 Sep 2025 17:37:51 -0700 Subject: [PATCH 284/325] Make stable release workflow callable. (#10108) --- .github/workflows/stable-release.yml | 36 ++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/.github/workflows/stable-release.yml b/.github/workflows/stable-release.yml index 1cbbfbf69..28484a9d1 100644 --- a/.github/workflows/stable-release.yml +++ b/.github/workflows/stable-release.yml @@ -2,6 +2,42 @@ name: "Release Stable Version" on: + workflow_call: + inputs: + git_tag: + description: 'Git tag' + required: true + type: string + cache_tag: + description: 'Cached dependencies tag' + required: true + type: string + default: "cu129" + python_minor: + description: 'Python minor version' + required: true + type: string + default: "13" + python_patch: + description: 'Python patch version' + required: true + type: string + default: "6" + rel_name: + description: 'Release name' + required: true + type: string + default: "nvidia" + rel_extra_name: + description: 'Release extra name' + required: false + type: string + default: "" + test_release: + description: 'Test Release' + required: true + type: boolean + default: true workflow_dispatch: inputs: git_tag: From 414a178fb690ef9998f65419f03ef1a83cf559de Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 29 Sep 2025 20:03:02 -0700 Subject: [PATCH 285/325] Add basic readme for AMD portable. 
(#10109) --- .../README_VERY_IMPORTANT.txt | 24 +++++++++++++++++++ 1 file changed, 24 insertions(+) create mode 100755 .ci/windows_amd_base_files/README_VERY_IMPORTANT.txt diff --git a/.ci/windows_amd_base_files/README_VERY_IMPORTANT.txt b/.ci/windows_amd_base_files/README_VERY_IMPORTANT.txt new file mode 100755 index 000000000..570ac3398 --- /dev/null +++ b/.ci/windows_amd_base_files/README_VERY_IMPORTANT.txt @@ -0,0 +1,24 @@ +As of the time of writing this you need this preview driver for best results: +https://www.amd.com/en/resources/support-articles/release-notes/RN-AMDGPU-WINDOWS-PYTORCH-PREVIEW.html + +HOW TO RUN: + +if you have a AMD gpu: + +run_amd_gpu.bat + + +IF YOU GET A RED ERROR IN THE UI MAKE SURE YOU HAVE A MODEL/CHECKPOINT IN: ComfyUI\models\checkpoints + +You can download the stable diffusion XL one from: https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_base_1.0_0.9vae.safetensors + + +RECOMMENDED WAY TO UPDATE: +To update the ComfyUI code: update\update_comfyui.bat + + +TO SHARE MODELS BETWEEN COMFYUI AND ANOTHER UI: +In the ComfyUI directory you will find a file: extra_model_paths.yaml.example +Rename this file to: extra_model_paths.yaml and edit it with your favorite text editor. + + From 977a4ed8c55ade53d0d6cfe1fe8a6396ee35a2ec Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 29 Sep 2025 23:04:42 -0400 Subject: [PATCH 286/325] ComfyUI version 0.3.61 --- comfyui_version.py | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/comfyui_version.py b/comfyui_version.py index d469a8194..737b72131 100644 --- a/comfyui_version.py +++ b/comfyui_version.py @@ -1,3 +1,3 @@ # This file is automatically generated by the build process when version is # updated in pyproject.toml. -__version__ = "0.3.60" +__version__ = "0.3.61" diff --git a/pyproject.toml b/pyproject.toml index 7340c320b..e851560f7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ComfyUI" -version = "0.3.60" +version = "0.3.61" readme = "README.md" license = { file = "LICENSE" } requires-python = ">=3.9" From 6e079abc3a3fc0fb98e2a0848877874151310ed1 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 29 Sep 2025 20:11:37 -0700 Subject: [PATCH 287/325] Workflow permission fix. 
(#10110) --- .github/workflows/release-stable-all.yml | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/.github/workflows/release-stable-all.yml b/.github/workflows/release-stable-all.yml index aac84d637..5c1024599 100644 --- a/.github/workflows/release-stable-all.yml +++ b/.github/workflows/release-stable-all.yml @@ -10,6 +10,10 @@ on: jobs: release_nvidia_default: + permissions: + contents: "write" + packages: "write" + pull-requests: "read" name: "Release NVIDIA Default (cu129)" uses: ./.github/workflows/stable-release.yml with: @@ -23,6 +27,10 @@ jobs: secrets: inherit release_nvidia_cu128: + permissions: + contents: "write" + packages: "write" + pull-requests: "read" name: "Release NVIDIA cu128" uses: ./.github/workflows/stable-release.yml with: @@ -36,6 +44,10 @@ jobs: secrets: inherit release_amd_rocm: + permissions: + contents: "write" + packages: "write" + pull-requests: "read" name: "Release AMD ROCm 6.4.4" uses: ./.github/workflows/stable-release.yml with: From f48d7230de2f7b10fe8bfda3d7f53241d19c7266 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 30 Sep 2025 09:17:49 -0700 Subject: [PATCH 288/325] Add new portable links to readme. (#10112) --- README.md | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/README.md b/README.md index 5a257687b..8f24a33ee 100644 --- a/README.md +++ b/README.md @@ -176,6 +176,12 @@ Simply download, extract with [7-Zip](https://7-zip.org) and run. Make sure you If you have trouble extracting it, right click the file -> properties -> unblock +#### Alternative Downloads: + +[Experimental portable for AMD GPUs](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_amd.7z) + +[Portable with pytorch cuda 12.8 and python 3.12](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_nvidia_cu128.7z) (Supports Nvidia 10 series and older GPUs). + #### How do I share models between another UI and ComfyUI? See the [Config file](extra_model_paths.yaml.example) to set the search paths for models. In the standalone windows build you can find this file in the ComfyUI directory. Rename this file to extra_model_paths.yaml and edit it with your favorite text editor. 
From 631b9ae861bf8bdd3c538da232e4c8938448e59d Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Tue, 30 Sep 2025 20:21:47 +0300 Subject: [PATCH 289/325] fix(Rodin3D-Gen2): missing "task_uuid" parameter (#10128) --- comfy_api_nodes/nodes_rodin.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy_api_nodes/nodes_rodin.py b/comfy_api_nodes/nodes_rodin.py index 817efb0f5..633ac46d3 100644 --- a/comfy_api_nodes/nodes_rodin.py +++ b/comfy_api_nodes/nodes_rodin.py @@ -540,7 +540,7 @@ class Rodin3D_Gen2(Rodin3DAPI): **kwargs) await self.poll_for_task_status(subscription_key, **kwargs) download_list = await self.get_rodin_download_list(task_uuid, **kwargs) - model = await self.download_files(download_list) + model = await self.download_files(download_list, task_uuid) return (model,) From b682a73c55a6434fdd9293d45ace969597f8ad65 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Tue, 30 Sep 2025 20:43:41 +0300 Subject: [PATCH 290/325] enable Seedance Pro model in the FirstLastFrame node (#10120) --- comfy_api_nodes/nodes_bytedance.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy_api_nodes/nodes_bytedance.py b/comfy_api_nodes/nodes_bytedance.py index a7eeaf15a..654d6a362 100644 --- a/comfy_api_nodes/nodes_bytedance.py +++ b/comfy_api_nodes/nodes_bytedance.py @@ -920,7 +920,7 @@ class ByteDanceFirstLastFrameNode(comfy_io.ComfyNode): inputs=[ comfy_io.Combo.Input( "model", - options=[Image2VideoModelName.seedance_1_lite.value], + options=[model.value for model in Image2VideoModelName], default=Image2VideoModelName.seedance_1_lite.value, tooltip="Model name", ), From bab8ba20bf47d985d6b1d73627c2add76bd4e716 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 30 Sep 2025 15:12:07 -0400 Subject: [PATCH 291/325] ComfyUI version 0.3.62. --- comfyui_version.py | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/comfyui_version.py b/comfyui_version.py index 737b72131..ac76fbe35 100644 --- a/comfyui_version.py +++ b/comfyui_version.py @@ -1,3 +1,3 @@ # This file is automatically generated by the build process when version is # updated in pyproject.toml. 
-__version__ = "0.3.61" +__version__ = "0.3.62" diff --git a/pyproject.toml b/pyproject.toml index e851560f7..d0a76c6d0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ComfyUI" -version = "0.3.61" +version = "0.3.62" readme = "README.md" license = { file = "LICENSE" } requires-python = ">=3.9" From c4a8cf60ab5d6eaf052b7a08f5ee97104acf7a2f Mon Sep 17 00:00:00 2001 From: AustinMroz Date: Tue, 30 Sep 2025 22:12:32 -0700 Subject: [PATCH 292/325] Bump frontend to 1.27.7 (#10133) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 45d3e1607..588c5dcf0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -comfyui-frontend-package==1.26.13 +comfyui-frontend-package==1.27.7 comfyui-workflow-templates==0.1.91 comfyui-embedded-docs==0.2.6 torch From 638097829d2352a1c78ab4fbb1e028d1e7cff012 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Wed, 1 Oct 2025 09:00:22 +0300 Subject: [PATCH 293/325] convert nodes_audio_encoder.py to V3 schema (#10123) --- comfy_api/latest/_io.py | 1 + comfy_extras/nodes_audio_encoder.py | 68 ++++++++++++++++++----------- 2 files changed, 44 insertions(+), 25 deletions(-) diff --git a/comfy_api/latest/_io.py b/comfy_api/latest/_io.py index 4826818df..2d95cffd6 100644 --- a/comfy_api/latest/_io.py +++ b/comfy_api/latest/_io.py @@ -1605,6 +1605,7 @@ class _IO: Model = Model ClipVision = ClipVision ClipVisionOutput = ClipVisionOutput + AudioEncoder = AudioEncoder AudioEncoderOutput = AudioEncoderOutput StyleModel = StyleModel Gligen = Gligen diff --git a/comfy_extras/nodes_audio_encoder.py b/comfy_extras/nodes_audio_encoder.py index 39a140fef..13aacd41a 100644 --- a/comfy_extras/nodes_audio_encoder.py +++ b/comfy_extras/nodes_audio_encoder.py @@ -1,44 +1,62 @@ import folder_paths import comfy.audio_encoders.audio_encoders import comfy.utils +from typing_extensions import override +from comfy_api.latest import ComfyExtension, io -class AudioEncoderLoader: +class AudioEncoderLoader(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "audio_encoder_name": (folder_paths.get_filename_list("audio_encoders"), ), - }} - RETURN_TYPES = ("AUDIO_ENCODER",) - FUNCTION = "load_model" + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="AudioEncoderLoader", + category="loaders", + inputs=[ + io.Combo.Input( + "audio_encoder_name", + options=folder_paths.get_filename_list("audio_encoders"), + ), + ], + outputs=[io.AudioEncoder.Output()], + ) - CATEGORY = "loaders" - - def load_model(self, audio_encoder_name): + @classmethod + def execute(cls, audio_encoder_name) -> io.NodeOutput: audio_encoder_name = folder_paths.get_full_path_or_raise("audio_encoders", audio_encoder_name) sd = comfy.utils.load_torch_file(audio_encoder_name, safe_load=True) audio_encoder = comfy.audio_encoders.audio_encoders.load_audio_encoder_from_sd(sd) if audio_encoder is None: raise RuntimeError("ERROR: audio encoder file is invalid and does not contain a valid model.") - return (audio_encoder,) + return io.NodeOutput(audio_encoder) -class AudioEncoderEncode: +class AudioEncoderEncode(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "audio_encoder": ("AUDIO_ENCODER",), - "audio": ("AUDIO",), - }} - RETURN_TYPES = ("AUDIO_ENCODER_OUTPUT",) - FUNCTION = "encode" + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="AudioEncoderEncode", + category="conditioning", + 
inputs=[ + io.AudioEncoder.Input("audio_encoder"), + io.Audio.Input("audio"), + ], + outputs=[io.AudioEncoderOutput.Output()], + ) - CATEGORY = "conditioning" - - def encode(self, audio_encoder, audio): + @classmethod + def execute(cls, audio_encoder, audio) -> io.NodeOutput: output = audio_encoder.encode_audio(audio["waveform"], audio["sample_rate"]) - return (output,) + return io.NodeOutput(output) -NODE_CLASS_MAPPINGS = { - "AudioEncoderLoader": AudioEncoderLoader, - "AudioEncoderEncode": AudioEncoderEncode, -} +class AudioEncoder(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + AudioEncoderLoader, + AudioEncoderEncode, + ] + + +async def comfy_entrypoint() -> AudioEncoder: + return AudioEncoder() From 7eb7160db487feb891ceabdf985b09f9a8091869 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Wed, 1 Oct 2025 22:16:59 +0300 Subject: [PATCH 294/325] convert nodes_gits.py to V3 schema (#9949) --- comfy_extras/nodes_gits.py | 49 ++++++++++++++++++++++++-------------- 1 file changed, 31 insertions(+), 18 deletions(-) diff --git a/comfy_extras/nodes_gits.py b/comfy_extras/nodes_gits.py index 47b1dd049..25367560a 100644 --- a/comfy_extras/nodes_gits.py +++ b/comfy_extras/nodes_gits.py @@ -1,6 +1,8 @@ # from https://github.com/zju-pi/diff-sampler/tree/main/gits-main import numpy as np import torch +from typing_extensions import override +from comfy_api.latest import ComfyExtension, io def loglinear_interp(t_steps, num_steps): """ @@ -333,25 +335,28 @@ NOISE_LEVELS = { ], } -class GITSScheduler: +class GITSScheduler(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"coeff": ("FLOAT", {"default": 1.20, "min": 0.80, "max": 1.50, "step": 0.05}), - "steps": ("INT", {"default": 10, "min": 2, "max": 1000}), - "denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), - } - } - RETURN_TYPES = ("SIGMAS",) - CATEGORY = "sampling/custom_sampling/schedulers" + def define_schema(cls): + return io.Schema( + node_id="GITSScheduler", + category="sampling/custom_sampling/schedulers", + inputs=[ + io.Float.Input("coeff", default=1.20, min=0.80, max=1.50, step=0.05), + io.Int.Input("steps", default=10, min=2, max=1000), + io.Float.Input("denoise", default=1.0, min=0.0, max=1.0, step=0.01), + ], + outputs=[ + io.Sigmas.Output(), + ], + ) - FUNCTION = "get_sigmas" - - def get_sigmas(self, coeff, steps, denoise): + @classmethod + def execute(cls, coeff, steps, denoise): total_steps = steps if denoise < 1.0: if denoise <= 0.0: - return (torch.FloatTensor([]),) + return io.NodeOutput(torch.FloatTensor([])) total_steps = round(steps * denoise) if steps <= 20: @@ -362,8 +367,16 @@ class GITSScheduler: sigmas = sigmas[-(total_steps + 1):] sigmas[-1] = 0 - return (torch.FloatTensor(sigmas), ) + return io.NodeOutput(torch.FloatTensor(sigmas)) -NODE_CLASS_MAPPINGS = { - "GITSScheduler": GITSScheduler, -} + +class GITSSchedulerExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + GITSScheduler, + ] + + +async def comfy_entrypoint() -> GITSSchedulerExtension: + return GITSSchedulerExtension() From e0210ce0a7140e0c61bce7fdb964b5e5e8d31619 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Wed, 1 Oct 2025 22:17:33 +0300 Subject: [PATCH 295/325] convert nodes_differential_diffusion.py to V3 schema (#10056) --- comfy_extras/nodes_differential_diffusion.py | 69 ++++++++++++-------- 
1 file changed, 40 insertions(+), 29 deletions(-) diff --git a/comfy_extras/nodes_differential_diffusion.py b/comfy_extras/nodes_differential_diffusion.py index 255ac420d..6dfdf466c 100644 --- a/comfy_extras/nodes_differential_diffusion.py +++ b/comfy_extras/nodes_differential_diffusion.py @@ -1,34 +1,41 @@ # code adapted from https://github.com/exx8/differential-diffusion +from typing_extensions import override + import torch +from comfy_api.latest import ComfyExtension, io -class DifferentialDiffusion(): + +class DifferentialDiffusion(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return { - "required": { - "model": ("MODEL", ), - }, - "optional": { - "strength": ("FLOAT", { - "default": 1.0, - "min": 0.0, - "max": 1.0, - "step": 0.01, - }), - } - } - RETURN_TYPES = ("MODEL",) - FUNCTION = "apply" - CATEGORY = "_for_testing" - INIT = False + def define_schema(cls): + return io.Schema( + node_id="DifferentialDiffusion", + display_name="Differential Diffusion", + category="_for_testing", + inputs=[ + io.Model.Input("model"), + io.Float.Input( + "strength", + default=1.0, + min=0.0, + max=1.0, + step=0.01, + optional=True, + ), + ], + outputs=[io.Model.Output()], + is_experimental=True, + ) - def apply(self, model, strength=1.0): + @classmethod + def execute(cls, model, strength=1.0) -> io.NodeOutput: model = model.clone() - model.set_model_denoise_mask_function(lambda *args, **kwargs: self.forward(*args, **kwargs, strength=strength)) - return (model, ) + model.set_model_denoise_mask_function(lambda *args, **kwargs: cls.forward(*args, **kwargs, strength=strength)) + return io.NodeOutput(model) - def forward(self, sigma: torch.Tensor, denoise_mask: torch.Tensor, extra_options: dict, strength: float): + @classmethod + def forward(cls, sigma: torch.Tensor, denoise_mask: torch.Tensor, extra_options: dict, strength: float): model = extra_options["model"] step_sigmas = extra_options["sigmas"] sigma_to = model.inner_model.model_sampling.sigma_min @@ -53,9 +60,13 @@ class DifferentialDiffusion(): return binary_mask -NODE_CLASS_MAPPINGS = { - "DifferentialDiffusion": DifferentialDiffusion, -} -NODE_DISPLAY_NAME_MAPPINGS = { - "DifferentialDiffusion": "Differential Diffusion", -} +class DifferentialDiffusionExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + DifferentialDiffusion, + ] + + +async def comfy_entrypoint() -> DifferentialDiffusionExtension: + return DifferentialDiffusionExtension() From 3af1881455fb0c44c3030b2d61b79302933386d2 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Wed, 1 Oct 2025 22:18:04 +0300 Subject: [PATCH 296/325] convert nodes_optimalsteps.py to V3 schema (#10074) --- comfy_extras/nodes_optimalsteps.py | 52 +++++++++++++++++++----------- 1 file changed, 33 insertions(+), 19 deletions(-) diff --git a/comfy_extras/nodes_optimalsteps.py b/comfy_extras/nodes_optimalsteps.py index e7c851ca2..73f0104d8 100644 --- a/comfy_extras/nodes_optimalsteps.py +++ b/comfy_extras/nodes_optimalsteps.py @@ -1,9 +1,12 @@ # from https://github.com/bebebe666/OptimalSteps - import numpy as np import torch +from typing_extensions import override +from comfy_api.latest import ComfyExtension, io + + def loglinear_interp(t_steps, num_steps): """ Performs log-linear interpolation of a given array of decreasing numbers. 
@@ -23,25 +26,28 @@ NOISE_LEVELS = {"FLUX": [0.9968, 0.9886, 0.9819, 0.975, 0.966, 0.9471, 0.9158, 0 "Chroma": [0.992, 0.99, 0.988, 0.985, 0.982, 0.978, 0.973, 0.968, 0.961, 0.953, 0.943, 0.931, 0.917, 0.9, 0.881, 0.858, 0.832, 0.802, 0.769, 0.731, 0.69, 0.646, 0.599, 0.55, 0.501, 0.451, 0.402, 0.355, 0.311, 0.27, 0.232, 0.199, 0.169, 0.143, 0.12, 0.101, 0.084, 0.07, 0.058, 0.048, 0.001], } -class OptimalStepsScheduler: +class OptimalStepsScheduler(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"model_type": (["FLUX", "Wan", "Chroma"], ), - "steps": ("INT", {"default": 20, "min": 3, "max": 1000}), - "denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), - } - } - RETURN_TYPES = ("SIGMAS",) - CATEGORY = "sampling/custom_sampling/schedulers" + def define_schema(cls): + return io.Schema( + node_id="OptimalStepsScheduler", + category="sampling/custom_sampling/schedulers", + inputs=[ + io.Combo.Input("model_type", options=["FLUX", "Wan", "Chroma"]), + io.Int.Input("steps", default=20, min=3, max=1000), + io.Float.Input("denoise", default=1.0, min=0.0, max=1.0, step=0.01), + ], + outputs=[ + io.Sigmas.Output(), + ], + ) - FUNCTION = "get_sigmas" - - def get_sigmas(self, model_type, steps, denoise): + @classmethod + def execute(cls, model_type, steps, denoise) ->io.NodeOutput: total_steps = steps if denoise < 1.0: if denoise <= 0.0: - return (torch.FloatTensor([]),) + return io.NodeOutput(torch.FloatTensor([])) total_steps = round(steps * denoise) sigmas = NOISE_LEVELS[model_type][:] @@ -50,8 +56,16 @@ class OptimalStepsScheduler: sigmas = sigmas[-(total_steps + 1):] sigmas[-1] = 0 - return (torch.FloatTensor(sigmas), ) + return io.NodeOutput(torch.FloatTensor(sigmas)) -NODE_CLASS_MAPPINGS = { - "OptimalStepsScheduler": OptimalStepsScheduler, -} + +class OptimalStepsExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + OptimalStepsScheduler, + ] + + +async def comfy_entrypoint() -> OptimalStepsExtension: + return OptimalStepsExtension() From 11bab7be76d0bfdb326e8aea53cdfebd99b42cc5 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Wed, 1 Oct 2025 22:18:49 +0300 Subject: [PATCH 297/325] convert nodes_pag.py to V3 schema (#10080) --- comfy_extras/nodes_pag.py | 49 +++++++++++++++++++++++++-------------- 1 file changed, 31 insertions(+), 18 deletions(-) diff --git a/comfy_extras/nodes_pag.py b/comfy_extras/nodes_pag.py index eb28196f4..79fea5f0c 100644 --- a/comfy_extras/nodes_pag.py +++ b/comfy_extras/nodes_pag.py @@ -3,25 +3,30 @@ #My modified one here is more basic but has less chances of breaking with ComfyUI updates. 
+from typing_extensions import override + import comfy.model_patcher import comfy.samplers +from comfy_api.latest import ComfyExtension, io -class PerturbedAttentionGuidance: + +class PerturbedAttentionGuidance(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return { - "required": { - "model": ("MODEL",), - "scale": ("FLOAT", {"default": 3.0, "min": 0.0, "max": 100.0, "step": 0.01, "round": 0.01}), - } - } + def define_schema(cls): + return io.Schema( + node_id="PerturbedAttentionGuidance", + category="model_patches/unet", + inputs=[ + io.Model.Input("model"), + io.Float.Input("scale", default=3.0, min=0.0, max=100.0, step=0.01, round=0.01), + ], + outputs=[ + io.Model.Output(), + ], + ) - RETURN_TYPES = ("MODEL",) - FUNCTION = "patch" - - CATEGORY = "model_patches/unet" - - def patch(self, model, scale): + @classmethod + def execute(cls, model, scale) -> io.NodeOutput: unet_block = "middle" unet_block_id = 0 m = model.clone() @@ -49,8 +54,16 @@ class PerturbedAttentionGuidance: m.set_model_sampler_post_cfg_function(post_cfg_function) - return (m,) + return io.NodeOutput(m) -NODE_CLASS_MAPPINGS = { - "PerturbedAttentionGuidance": PerturbedAttentionGuidance, -} + +class PAGExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + PerturbedAttentionGuidance, + ] + + +async def comfy_entrypoint() -> PAGExtension: + return PAGExtension() From d9c0a4053d955c7fd3400be07001bc4e774591e1 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Wed, 1 Oct 2025 22:19:56 +0300 Subject: [PATCH 298/325] convert nodes_lt.py to V3 schema (#10084) --- comfy_extras/nodes_lt.py | 412 ++++++++++++++++++++++----------------- 1 file changed, 228 insertions(+), 184 deletions(-) diff --git a/comfy_extras/nodes_lt.py b/comfy_extras/nodes_lt.py index f82337a67..b51d15804 100644 --- a/comfy_extras/nodes_lt.py +++ b/comfy_extras/nodes_lt.py @@ -1,4 +1,3 @@ -import io import nodes import node_helpers import torch @@ -8,46 +7,60 @@ import comfy.utils import math import numpy as np import av +from io import BytesIO +from typing_extensions import override from comfy.ldm.lightricks.symmetric_patchifier import SymmetricPatchifier, latent_to_pixel_coords +from comfy_api.latest import ComfyExtension, io -class EmptyLTXVLatentVideo: +class EmptyLTXVLatentVideo(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "width": ("INT", {"default": 768, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}), - "height": ("INT", {"default": 512, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}), - "length": ("INT", {"default": 97, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 8}), - "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})}} - RETURN_TYPES = ("LATENT",) - FUNCTION = "generate" + def define_schema(cls): + return io.Schema( + node_id="EmptyLTXVLatentVideo", + category="latent/video/ltxv", + inputs=[ + io.Int.Input("width", default=768, min=64, max=nodes.MAX_RESOLUTION, step=32), + io.Int.Input("height", default=512, min=64, max=nodes.MAX_RESOLUTION, step=32), + io.Int.Input("length", default=97, min=1, max=nodes.MAX_RESOLUTION, step=8), + io.Int.Input("batch_size", default=1, min=1, max=4096), + ], + outputs=[ + io.Latent.Output(), + ], + ) - CATEGORY = "latent/video/ltxv" - - def generate(self, width, height, length, batch_size=1): + @classmethod + def execute(cls, width, height, length, batch_size=1) -> io.NodeOutput: latent = torch.zeros([batch_size, 128, ((length - 1) // 8) + 1, height // 
32, width // 32], device=comfy.model_management.intermediate_device()) - return ({"samples": latent}, ) + return io.NodeOutput({"samples": latent}) -class LTXVImgToVideo: +class LTXVImgToVideo(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": {"positive": ("CONDITIONING", ), - "negative": ("CONDITIONING", ), - "vae": ("VAE",), - "image": ("IMAGE",), - "width": ("INT", {"default": 768, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}), - "height": ("INT", {"default": 512, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}), - "length": ("INT", {"default": 97, "min": 9, "max": nodes.MAX_RESOLUTION, "step": 8}), - "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), - "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0}), - }} + def define_schema(cls): + return io.Schema( + node_id="LTXVImgToVideo", + category="conditioning/video_models", + inputs=[ + io.Conditioning.Input("positive"), + io.Conditioning.Input("negative"), + io.Vae.Input("vae"), + io.Image.Input("image"), + io.Int.Input("width", default=768, min=64, max=nodes.MAX_RESOLUTION, step=32), + io.Int.Input("height", default=512, min=64, max=nodes.MAX_RESOLUTION, step=32), + io.Int.Input("length", default=97, min=9, max=nodes.MAX_RESOLUTION, step=8), + io.Int.Input("batch_size", default=1, min=1, max=4096), + io.Float.Input("strength", default=1.0, min=0.0, max=1.0), + ], + outputs=[ + io.Conditioning.Output(display_name="positive"), + io.Conditioning.Output(display_name="negative"), + io.Latent.Output(display_name="latent"), + ], + ) - RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") - RETURN_NAMES = ("positive", "negative", "latent") - - CATEGORY = "conditioning/video_models" - FUNCTION = "generate" - - def generate(self, positive, negative, image, vae, width, height, length, batch_size, strength): + @classmethod + def execute(cls, positive, negative, image, vae, width, height, length, batch_size, strength) -> io.NodeOutput: pixels = comfy.utils.common_upscale(image.movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) encode_pixels = pixels[:, :, :, :3] t = vae.encode(encode_pixels) @@ -62,7 +75,7 @@ class LTXVImgToVideo: ) conditioning_latent_frames_mask[:, :, :t.shape[2]] = 1.0 - strength - return (positive, negative, {"samples": latent, "noise_mask": conditioning_latent_frames_mask}, ) + return io.NodeOutput(positive, negative, {"samples": latent, "noise_mask": conditioning_latent_frames_mask}) def conditioning_get_any_value(conditioning, key, default=None): @@ -93,35 +106,46 @@ def get_keyframe_idxs(cond): num_keyframes = torch.unique(keyframe_idxs[:, 0]).shape[0] return keyframe_idxs, num_keyframes -class LTXVAddGuide: +class LTXVAddGuide(io.ComfyNode): + NUM_PREFIX_FRAMES = 2 + PATCHIFIER = SymmetricPatchifier(1) + @classmethod - def INPUT_TYPES(s): - return {"required": {"positive": ("CONDITIONING", ), - "negative": ("CONDITIONING", ), - "vae": ("VAE",), - "latent": ("LATENT",), - "image": ("IMAGE", {"tooltip": "Image or video to condition the latent video on. Must be 8*n + 1 frames." - "If the video is not 8*n + 1 frames, it will be cropped to the nearest 8*n + 1 frames."}), - "frame_idx": ("INT", {"default": 0, "min": -9999, "max": 9999, - "tooltip": "Frame index to start the conditioning at. For single-frame images or " - "videos with 1-8 frames, any frame_idx value is acceptable. For videos with 9+ " - "frames, frame_idx must be divisible by 8, otherwise it will be rounded down to " - "the nearest multiple of 8. 
Negative values are counted from the end of the video."}), - "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), - } - } + def define_schema(cls): + return io.Schema( + node_id="LTXVAddGuide", + category="conditioning/video_models", + inputs=[ + io.Conditioning.Input("positive"), + io.Conditioning.Input("negative"), + io.Vae.Input("vae"), + io.Latent.Input("latent"), + io.Image.Input( + "image", + tooltip="Image or video to condition the latent video on. Must be 8*n + 1 frames. " + "If the video is not 8*n + 1 frames, it will be cropped to the nearest 8*n + 1 frames.", + ), + io.Int.Input( + "frame_idx", + default=0, + min=-9999, + max=9999, + tooltip="Frame index to start the conditioning at. " + "For single-frame images or videos with 1-8 frames, any frame_idx value is acceptable. " + "For videos with 9+ frames, frame_idx must be divisible by 8, otherwise it will be rounded " + "down to the nearest multiple of 8. Negative values are counted from the end of the video.", + ), + io.Float.Input("strength", default=1.0, min=0.0, max=1.0, step=0.01), + ], + outputs=[ + io.Conditioning.Output(display_name="positive"), + io.Conditioning.Output(display_name="negative"), + io.Latent.Output(display_name="latent"), + ], + ) - RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") - RETURN_NAMES = ("positive", "negative", "latent") - - CATEGORY = "conditioning/video_models" - FUNCTION = "generate" - - def __init__(self): - self._num_prefix_frames = 2 - self._patchifier = SymmetricPatchifier(1) - - def encode(self, vae, latent_width, latent_height, images, scale_factors): + @classmethod + def encode(cls, vae, latent_width, latent_height, images, scale_factors): time_scale_factor, width_scale_factor, height_scale_factor = scale_factors images = images[:(images.shape[0] - 1) // time_scale_factor * time_scale_factor + 1] pixels = comfy.utils.common_upscale(images.movedim(-1, 1), latent_width * width_scale_factor, latent_height * height_scale_factor, "bilinear", crop="disabled").movedim(1, -1) @@ -129,7 +153,8 @@ class LTXVAddGuide: t = vae.encode(encode_pixels) return encode_pixels, t - def get_latent_index(self, cond, latent_length, guide_length, frame_idx, scale_factors): + @classmethod + def get_latent_index(cls, cond, latent_length, guide_length, frame_idx, scale_factors): time_scale_factor, _, _ = scale_factors _, num_keyframes = get_keyframe_idxs(cond) latent_count = latent_length - num_keyframes @@ -141,9 +166,10 @@ class LTXVAddGuide: return frame_idx, latent_idx - def add_keyframe_index(self, cond, frame_idx, guiding_latent, scale_factors): + @classmethod + def add_keyframe_index(cls, cond, frame_idx, guiding_latent, scale_factors): keyframe_idxs, _ = get_keyframe_idxs(cond) - _, latent_coords = self._patchifier.patchify(guiding_latent) + _, latent_coords = cls.PATCHIFIER.patchify(guiding_latent) pixel_coords = latent_to_pixel_coords(latent_coords, scale_factors, causal_fix=frame_idx == 0) # we need the causal fix only if we're placing the new latents at index 0 pixel_coords[:, 0] += frame_idx if keyframe_idxs is None: @@ -152,8 +178,9 @@ class LTXVAddGuide: keyframe_idxs = torch.cat([keyframe_idxs, pixel_coords], dim=2) return node_helpers.conditioning_set_values(cond, {"keyframe_idxs": keyframe_idxs}) - def append_keyframe(self, positive, negative, frame_idx, latent_image, noise_mask, guiding_latent, strength, scale_factors): - _, latent_idx = self.get_latent_index( + @classmethod + def append_keyframe(cls, positive, negative, frame_idx, latent_image, noise_mask, 
guiding_latent, strength, scale_factors): + _, latent_idx = cls.get_latent_index( cond=positive, latent_length=latent_image.shape[2], guide_length=guiding_latent.shape[2], @@ -162,8 +189,8 @@ class LTXVAddGuide: ) noise_mask[:, :, latent_idx:latent_idx + guiding_latent.shape[2]] = 1.0 - positive = self.add_keyframe_index(positive, frame_idx, guiding_latent, scale_factors) - negative = self.add_keyframe_index(negative, frame_idx, guiding_latent, scale_factors) + positive = cls.add_keyframe_index(positive, frame_idx, guiding_latent, scale_factors) + negative = cls.add_keyframe_index(negative, frame_idx, guiding_latent, scale_factors) mask = torch.full( (noise_mask.shape[0], 1, guiding_latent.shape[2], noise_mask.shape[3], noise_mask.shape[4]), @@ -176,7 +203,8 @@ class LTXVAddGuide: noise_mask = torch.cat([noise_mask, mask], dim=2) return positive, negative, latent_image, noise_mask - def replace_latent_frames(self, latent_image, noise_mask, guiding_latent, latent_idx, strength): + @classmethod + def replace_latent_frames(cls, latent_image, noise_mask, guiding_latent, latent_idx, strength): cond_length = guiding_latent.shape[2] assert latent_image.shape[2] >= latent_idx + cond_length, "Conditioning frames exceed the length of the latent sequence." @@ -195,20 +223,21 @@ class LTXVAddGuide: return latent_image, noise_mask - def generate(self, positive, negative, vae, latent, image, frame_idx, strength): + @classmethod + def execute(cls, positive, negative, vae, latent, image, frame_idx, strength) -> io.NodeOutput: scale_factors = vae.downscale_index_formula latent_image = latent["samples"] noise_mask = get_noise_mask(latent) _, _, latent_length, latent_height, latent_width = latent_image.shape - image, t = self.encode(vae, latent_width, latent_height, image, scale_factors) + image, t = cls.encode(vae, latent_width, latent_height, image, scale_factors) - frame_idx, latent_idx = self.get_latent_index(positive, latent_length, len(image), frame_idx, scale_factors) + frame_idx, latent_idx = cls.get_latent_index(positive, latent_length, len(image), frame_idx, scale_factors) assert latent_idx + t.shape[2] <= latent_length, "Conditioning frames exceed the length of the latent sequence." 
- num_prefix_frames = min(self._num_prefix_frames, t.shape[2]) + num_prefix_frames = min(cls.NUM_PREFIX_FRAMES, t.shape[2]) - positive, negative, latent_image, noise_mask = self.append_keyframe( + positive, negative, latent_image, noise_mask = cls.append_keyframe( positive, negative, frame_idx, @@ -223,9 +252,9 @@ class LTXVAddGuide: t = t[:, :, num_prefix_frames:] if t.shape[2] == 0: - return (positive, negative, {"samples": latent_image, "noise_mask": noise_mask},) + return io.NodeOutput(positive, negative, {"samples": latent_image, "noise_mask": noise_mask}) - latent_image, noise_mask = self.replace_latent_frames( + latent_image, noise_mask = cls.replace_latent_frames( latent_image, noise_mask, t, @@ -233,34 +262,35 @@ class LTXVAddGuide: strength, ) - return (positive, negative, {"samples": latent_image, "noise_mask": noise_mask},) + return io.NodeOutput(positive, negative, {"samples": latent_image, "noise_mask": noise_mask}) -class LTXVCropGuides: +class LTXVCropGuides(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": {"positive": ("CONDITIONING", ), - "negative": ("CONDITIONING", ), - "latent": ("LATENT",), - } - } + def define_schema(cls): + return io.Schema( + node_id="LTXVCropGuides", + category="conditioning/video_models", + inputs=[ + io.Conditioning.Input("positive"), + io.Conditioning.Input("negative"), + io.Latent.Input("latent"), + ], + outputs=[ + io.Conditioning.Output(display_name="positive"), + io.Conditioning.Output(display_name="negative"), + io.Latent.Output(display_name="latent"), + ], + ) - RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") - RETURN_NAMES = ("positive", "negative", "latent") - - CATEGORY = "conditioning/video_models" - FUNCTION = "crop" - - def __init__(self): - self._patchifier = SymmetricPatchifier(1) - - def crop(self, positive, negative, latent): + @classmethod + def execute(cls, positive, negative, latent) -> io.NodeOutput: latent_image = latent["samples"].clone() noise_mask = get_noise_mask(latent) _, num_keyframes = get_keyframe_idxs(positive) if num_keyframes == 0: - return (positive, negative, {"samples": latent_image, "noise_mask": noise_mask},) + return io.NodeOutput(positive, negative, {"samples": latent_image, "noise_mask": noise_mask},) latent_image = latent_image[:, :, :-num_keyframes] noise_mask = noise_mask[:, :, :-num_keyframes] @@ -268,44 +298,52 @@ class LTXVCropGuides: positive = node_helpers.conditioning_set_values(positive, {"keyframe_idxs": None}) negative = node_helpers.conditioning_set_values(negative, {"keyframe_idxs": None}) - return (positive, negative, {"samples": latent_image, "noise_mask": noise_mask},) + return io.NodeOutput(positive, negative, {"samples": latent_image, "noise_mask": noise_mask}) -class LTXVConditioning: +class LTXVConditioning(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": {"positive": ("CONDITIONING", ), - "negative": ("CONDITIONING", ), - "frame_rate": ("FLOAT", {"default": 25.0, "min": 0.0, "max": 1000.0, "step": 0.01}), - }} - RETURN_TYPES = ("CONDITIONING", "CONDITIONING") - RETURN_NAMES = ("positive", "negative") - FUNCTION = "append" + def define_schema(cls): + return io.Schema( + node_id="LTXVConditioning", + category="conditioning/video_models", + inputs=[ + io.Conditioning.Input("positive"), + io.Conditioning.Input("negative"), + io.Float.Input("frame_rate", default=25.0, min=0.0, max=1000.0, step=0.01), + ], + outputs=[ + io.Conditioning.Output(display_name="positive"), + io.Conditioning.Output(display_name="negative"), + ], + ) - 
CATEGORY = "conditioning/video_models" - - def append(self, positive, negative, frame_rate): + @classmethod + def execute(cls, positive, negative, frame_rate) -> io.NodeOutput: positive = node_helpers.conditioning_set_values(positive, {"frame_rate": frame_rate}) negative = node_helpers.conditioning_set_values(negative, {"frame_rate": frame_rate}) - return (positive, negative) + return io.NodeOutput(positive, negative) -class ModelSamplingLTXV: +class ModelSamplingLTXV(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "model": ("MODEL",), - "max_shift": ("FLOAT", {"default": 2.05, "min": 0.0, "max": 100.0, "step":0.01}), - "base_shift": ("FLOAT", {"default": 0.95, "min": 0.0, "max": 100.0, "step":0.01}), - }, - "optional": {"latent": ("LATENT",), } - } + def define_schema(cls): + return io.Schema( + node_id="ModelSamplingLTXV", + category="advanced/model", + inputs=[ + io.Model.Input("model"), + io.Float.Input("max_shift", default=2.05, min=0.0, max=100.0, step=0.01), + io.Float.Input("base_shift", default=0.95, min=0.0, max=100.0, step=0.01), + io.Latent.Input("latent", optional=True), + ], + outputs=[ + io.Model.Output(), + ], + ) - RETURN_TYPES = ("MODEL",) - FUNCTION = "patch" - - CATEGORY = "advanced/model" - - def patch(self, model, max_shift, base_shift, latent=None): + @classmethod + def execute(cls, model, max_shift, base_shift, latent=None) -> io.NodeOutput: m = model.clone() if latent is None: @@ -329,37 +367,41 @@ class ModelSamplingLTXV: model_sampling.set_parameters(shift=shift) m.add_object_patch("model_sampling", model_sampling) - return (m, ) + return io.NodeOutput(m) -class LTXVScheduler: +class LTXVScheduler(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"steps": ("INT", {"default": 20, "min": 1, "max": 10000}), - "max_shift": ("FLOAT", {"default": 2.05, "min": 0.0, "max": 100.0, "step":0.01}), - "base_shift": ("FLOAT", {"default": 0.95, "min": 0.0, "max": 100.0, "step":0.01}), - "stretch": ("BOOLEAN", { - "default": True, - "tooltip": "Stretch the sigmas to be in the range [terminal, 1]." - }), - "terminal": ( - "FLOAT", - { - "default": 0.1, "min": 0.0, "max": 0.99, "step": 0.01, - "tooltip": "The terminal value of the sigmas after stretching." 
- }, - ), - }, - "optional": {"latent": ("LATENT",), } - } + def define_schema(cls): + return io.Schema( + node_id="LTXVScheduler", + category="sampling/custom_sampling/schedulers", + inputs=[ + io.Int.Input("steps", default=20, min=1, max=10000), + io.Float.Input("max_shift", default=2.05, min=0.0, max=100.0, step=0.01), + io.Float.Input("base_shift", default=0.95, min=0.0, max=100.0, step=0.01), + io.Boolean.Input( + id="stretch", + default=True, + tooltip="Stretch the sigmas to be in the range [terminal, 1].", + ), + io.Float.Input( + id="terminal", + default=0.1, + min=0.0, + max=0.99, + step=0.01, + tooltip="The terminal value of the sigmas after stretching.", + ), + io.Latent.Input("latent", optional=True), + ], + outputs=[ + io.Sigmas.Output(), + ], + ) - RETURN_TYPES = ("SIGMAS",) - CATEGORY = "sampling/custom_sampling/schedulers" - - FUNCTION = "get_sigmas" - - def get_sigmas(self, steps, max_shift, base_shift, stretch, terminal, latent=None): + @classmethod + def execute(cls, steps, max_shift, base_shift, stretch, terminal, latent=None) -> io.NodeOutput: if latent is None: tokens = 4096 else: @@ -389,7 +431,7 @@ class LTXVScheduler: stretched = 1.0 - (one_minus_z / scale_factor) sigmas[non_zero_mask] = stretched - return (sigmas,) + return io.NodeOutput(sigmas) def encode_single_frame(output_file, image_array: np.ndarray, crf): container = av.open(output_file, "w", format="mp4") @@ -423,52 +465,54 @@ def preprocess(image: torch.Tensor, crf=29): return image image_array = (image[:(image.shape[0] // 2) * 2, :(image.shape[1] // 2) * 2] * 255.0).byte().cpu().numpy() - with io.BytesIO() as output_file: + with BytesIO() as output_file: encode_single_frame(output_file, image_array, crf) video_bytes = output_file.getvalue() - with io.BytesIO(video_bytes) as video_file: + with BytesIO(video_bytes) as video_file: image_array = decode_single_frame(video_file) tensor = torch.tensor(image_array, dtype=image.dtype, device=image.device) / 255.0 return tensor -class LTXVPreprocess: +class LTXVPreprocess(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return { - "required": { - "image": ("IMAGE",), - "img_compression": ( - "INT", - { - "default": 35, - "min": 0, - "max": 100, - "tooltip": "Amount of compression to apply on image.", - }, + def define_schema(cls): + return io.Schema( + node_id="LTXVPreprocess", + category="image", + inputs=[ + io.Image.Input("image"), + io.Int.Input( + id="img_compression", default=35, min=0, max=100, tooltip="Amount of compression to apply on image." 
), - } - } + ], + outputs=[ + io.Image.Output(display_name="output_image"), + ], + ) - FUNCTION = "preprocess" - RETURN_TYPES = ("IMAGE",) - RETURN_NAMES = ("output_image",) - CATEGORY = "image" - - def preprocess(self, image, img_compression): + @classmethod + def execute(cls, image, img_compression) -> io.NodeOutput: output_images = [] for i in range(image.shape[0]): output_images.append(preprocess(image[i], img_compression)) - return (torch.stack(output_images),) + return io.NodeOutput(torch.stack(output_images)) -NODE_CLASS_MAPPINGS = { - "EmptyLTXVLatentVideo": EmptyLTXVLatentVideo, - "LTXVImgToVideo": LTXVImgToVideo, - "ModelSamplingLTXV": ModelSamplingLTXV, - "LTXVConditioning": LTXVConditioning, - "LTXVScheduler": LTXVScheduler, - "LTXVAddGuide": LTXVAddGuide, - "LTXVPreprocess": LTXVPreprocess, - "LTXVCropGuides": LTXVCropGuides, -} +class LtxvExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + EmptyLTXVLatentVideo, + LTXVImgToVideo, + ModelSamplingLTXV, + LTXVConditioning, + LTXVScheduler, + LTXVAddGuide, + LTXVPreprocess, + LTXVCropGuides, + ] + + +async def comfy_entrypoint() -> LtxvExtension: + return LtxvExtension() From e4f99b479a19730bea890567129f4032b4dd4787 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Wed, 1 Oct 2025 22:20:30 +0300 Subject: [PATCH 299/325] convert nodes_ip2p.pt to V3 schema (#10097) --- comfy_extras/nodes_ip2p.py | 54 +++++++++++++++++++++++++------------- 1 file changed, 36 insertions(+), 18 deletions(-) diff --git a/comfy_extras/nodes_ip2p.py b/comfy_extras/nodes_ip2p.py index c2e70a84c..78f29915d 100644 --- a/comfy_extras/nodes_ip2p.py +++ b/comfy_extras/nodes_ip2p.py @@ -1,21 +1,30 @@ import torch -class InstructPixToPixConditioning: +from typing_extensions import override +from comfy_api.latest import ComfyExtension, io + + +class InstructPixToPixConditioning(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": {"positive": ("CONDITIONING", ), - "negative": ("CONDITIONING", ), - "vae": ("VAE", ), - "pixels": ("IMAGE", ), - }} + def define_schema(cls): + return io.Schema( + node_id="InstructPixToPixConditioning", + category="conditioning/instructpix2pix", + inputs=[ + io.Conditioning.Input("positive"), + io.Conditioning.Input("negative"), + io.Vae.Input("vae"), + io.Image.Input("pixels"), + ], + outputs=[ + io.Conditioning.Output(display_name="positive"), + io.Conditioning.Output(display_name="negative"), + io.Latent.Output(display_name="latent"), + ], + ) - RETURN_TYPES = ("CONDITIONING","CONDITIONING","LATENT") - RETURN_NAMES = ("positive", "negative", "latent") - FUNCTION = "encode" - - CATEGORY = "conditioning/instructpix2pix" - - def encode(self, positive, negative, pixels, vae): + @classmethod + def execute(cls, positive, negative, pixels, vae) -> io.NodeOutput: x = (pixels.shape[1] // 8) * 8 y = (pixels.shape[2] // 8) * 8 @@ -38,8 +47,17 @@ class InstructPixToPixConditioning: n = [t[0], d] c.append(n) out.append(c) - return (out[0], out[1], out_latent) + return io.NodeOutput(out[0], out[1], out_latent) + + +class InstructPix2PixExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + InstructPixToPixConditioning, + ] + + +async def comfy_entrypoint() -> InstructPix2PixExtension: + return InstructPix2PixExtension() -NODE_CLASS_MAPPINGS = { - "InstructPixToPixConditioning": InstructPixToPixConditioning, -} From a6f83a4a1a70d720c16d66feb5d87fee5998acdf Mon Sep 
17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 1 Oct 2025 14:19:13 -0700 Subject: [PATCH 300/325] Support the new hunyuan vae. (#10150) --- comfy/ldm/hunyuan_video/vae_refiner.py | 112 ++++++++++++++++--------- comfy/sd.py | 70 ++++++++++------ 2 files changed, 116 insertions(+), 66 deletions(-) diff --git a/comfy/ldm/hunyuan_video/vae_refiner.py b/comfy/ldm/hunyuan_video/vae_refiner.py index c6f742710..c2a0b507d 100644 --- a/comfy/ldm/hunyuan_video/vae_refiner.py +++ b/comfy/ldm/hunyuan_video/vae_refiner.py @@ -1,7 +1,7 @@ import torch import torch.nn as nn import torch.nn.functional as F -from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, AttnBlock, VideoConv3d +from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, AttnBlock, VideoConv3d, Normalize import comfy.ops import comfy.ldm.models.autoencoder ops = comfy.ops.disable_weight_init @@ -17,11 +17,12 @@ class RMS_norm(nn.Module): return F.normalize(x, dim=1) * self.scale * self.gamma class DnSmpl(nn.Module): - def __init__(self, ic, oc, tds=True): + def __init__(self, ic, oc, tds=True, refiner_vae=True, op=VideoConv3d): super().__init__() fct = 2 * 2 * 2 if tds else 1 * 2 * 2 assert oc % fct == 0 - self.conv = VideoConv3d(ic, oc // fct, kernel_size=3) + self.conv = op(ic, oc // fct, kernel_size=3, stride=1, padding=1) + self.refiner_vae = refiner_vae self.tds = tds self.gs = fct * ic // oc @@ -30,7 +31,7 @@ class DnSmpl(nn.Module): r1 = 2 if self.tds else 1 h = self.conv(x) - if self.tds: + if self.tds and self.refiner_vae: hf = h[:, :, :1, :, :] b, c, f, ht, wd = hf.shape hf = hf.reshape(b, c, f, ht // 2, 2, wd // 2, 2) @@ -66,6 +67,7 @@ class DnSmpl(nn.Module): sc = torch.cat([xf, xn], dim=2) else: b, c, frms, ht, wd = h.shape + nf = frms // r1 h = h.reshape(b, c, nf, r1, ht // 2, 2, wd // 2, 2) h = h.permute(0, 3, 5, 7, 1, 2, 4, 6) @@ -83,10 +85,11 @@ class DnSmpl(nn.Module): class UpSmpl(nn.Module): - def __init__(self, ic, oc, tus=True): + def __init__(self, ic, oc, tus=True, refiner_vae=True, op=VideoConv3d): super().__init__() fct = 2 * 2 * 2 if tus else 1 * 2 * 2 - self.conv = VideoConv3d(ic, oc * fct, kernel_size=3) + self.conv = op(ic, oc * fct, kernel_size=3, stride=1, padding=1) + self.refiner_vae = refiner_vae self.tus = tus self.rp = fct * oc // ic @@ -95,7 +98,7 @@ class UpSmpl(nn.Module): r1 = 2 if self.tus else 1 h = self.conv(x) - if self.tus: + if self.tus and self.refiner_vae: hf = h[:, :, :1, :, :] b, c, f, ht, wd = hf.shape nc = c // (2 * 2) @@ -148,43 +151,56 @@ class UpSmpl(nn.Module): class Encoder(nn.Module): def __init__(self, in_channels, z_channels, block_out_channels, num_res_blocks, - ffactor_spatial, ffactor_temporal, downsample_match_channel=True, **_): + ffactor_spatial, ffactor_temporal, downsample_match_channel=True, refiner_vae=True, **_): super().__init__() self.z_channels = z_channels self.block_out_channels = block_out_channels self.num_res_blocks = num_res_blocks - self.conv_in = VideoConv3d(in_channels, block_out_channels[0], 3, 1, 1) + self.ffactor_temporal = ffactor_temporal + + self.refiner_vae = refiner_vae + if self.refiner_vae: + conv_op = VideoConv3d + norm_op = RMS_norm + else: + conv_op = ops.Conv3d + norm_op = Normalize + + self.conv_in = conv_op(in_channels, block_out_channels[0], 3, 1, 1) self.down = nn.ModuleList() ch = block_out_channels[0] depth = (ffactor_spatial >> 1).bit_length() - depth_temporal = ((ffactor_spatial // ffactor_temporal) >> 1).bit_length() + depth_temporal = ((ffactor_spatial // 
self.ffactor_temporal) >> 1).bit_length() for i, tgt in enumerate(block_out_channels): stage = nn.Module() stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt, out_channels=tgt, temb_channels=0, - conv_op=VideoConv3d, norm_op=RMS_norm) + conv_op=conv_op, norm_op=norm_op) for j in range(num_res_blocks)]) ch = tgt if i < depth: nxt = block_out_channels[i + 1] if i + 1 < len(block_out_channels) and downsample_match_channel else ch - stage.downsample = DnSmpl(ch, nxt, tds=i >= depth_temporal) + stage.downsample = DnSmpl(ch, nxt, tds=i >= depth_temporal, refiner_vae=self.refiner_vae, op=conv_op) ch = nxt self.down.append(stage) self.mid = nn.Module() - self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=VideoConv3d, norm_op=RMS_norm) - self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=RMS_norm) - self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=VideoConv3d, norm_op=RMS_norm) + self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=conv_op, norm_op=norm_op) + self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=norm_op) + self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=conv_op, norm_op=norm_op) - self.norm_out = RMS_norm(ch) - self.conv_out = VideoConv3d(ch, z_channels << 1, 3, 1, 1) + self.norm_out = norm_op(ch) + self.conv_out = conv_op(ch, z_channels << 1, 3, 1, 1) self.regul = comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer() def forward(self, x): + if not self.refiner_vae and x.shape[2] == 1: + x = x.expand(-1, -1, self.ffactor_temporal, -1, -1) + x = self.conv_in(x) for stage in self.down: @@ -200,31 +216,42 @@ class Encoder(nn.Module): skip = x.view(b, c // grp, grp, t, h, w).mean(2) out = self.conv_out(F.silu(self.norm_out(x))) + skip - out = self.regul(out)[0] - out = torch.cat((out[:, :, :1], out), dim=2) - out = out.permute(0, 2, 1, 3, 4) - b, f_times_2, c, h, w = out.shape - out = out.reshape(b, f_times_2 // 2, 2 * c, h, w) - out = out.permute(0, 2, 1, 3, 4).contiguous() + if self.refiner_vae: + out = self.regul(out)[0] + + out = torch.cat((out[:, :, :1], out), dim=2) + out = out.permute(0, 2, 1, 3, 4) + b, f_times_2, c, h, w = out.shape + out = out.reshape(b, f_times_2 // 2, 2 * c, h, w) + out = out.permute(0, 2, 1, 3, 4).contiguous() + return out class Decoder(nn.Module): def __init__(self, z_channels, out_channels, block_out_channels, num_res_blocks, - ffactor_spatial, ffactor_temporal, upsample_match_channel=True, **_): + ffactor_spatial, ffactor_temporal, upsample_match_channel=True, refiner_vae=True, **_): super().__init__() block_out_channels = block_out_channels[::-1] self.z_channels = z_channels self.block_out_channels = block_out_channels self.num_res_blocks = num_res_blocks + self.refiner_vae = refiner_vae + if self.refiner_vae: + conv_op = VideoConv3d + norm_op = RMS_norm + else: + conv_op = ops.Conv3d + norm_op = Normalize + ch = block_out_channels[0] - self.conv_in = VideoConv3d(z_channels, ch, 3) + self.conv_in = conv_op(z_channels, ch, kernel_size=3, stride=1, padding=1) self.mid = nn.Module() - self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=VideoConv3d, norm_op=RMS_norm) - self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=RMS_norm) - self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=VideoConv3d, norm_op=RMS_norm) + self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, 
temb_channels=0, conv_op=conv_op, norm_op=norm_op) + self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=norm_op) + self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=conv_op, norm_op=norm_op) self.up = nn.ModuleList() depth = (ffactor_spatial >> 1).bit_length() @@ -235,25 +262,26 @@ class Decoder(nn.Module): stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt, out_channels=tgt, temb_channels=0, - conv_op=VideoConv3d, norm_op=RMS_norm) + conv_op=conv_op, norm_op=norm_op) for j in range(num_res_blocks + 1)]) ch = tgt if i < depth: nxt = block_out_channels[i + 1] if i + 1 < len(block_out_channels) and upsample_match_channel else ch - stage.upsample = UpSmpl(ch, nxt, tus=i < depth_temporal) + stage.upsample = UpSmpl(ch, nxt, tus=i < depth_temporal, refiner_vae=self.refiner_vae, op=conv_op) ch = nxt self.up.append(stage) - self.norm_out = RMS_norm(ch) - self.conv_out = VideoConv3d(ch, out_channels, 3) + self.norm_out = norm_op(ch) + self.conv_out = conv_op(ch, out_channels, 3, stride=1, padding=1) def forward(self, z): - z = z.permute(0, 2, 1, 3, 4) - b, f, c, h, w = z.shape - z = z.reshape(b, f, 2, c // 2, h, w) - z = z.permute(0, 1, 2, 3, 4, 5).reshape(b, f * 2, c // 2, h, w) - z = z.permute(0, 2, 1, 3, 4) - z = z[:, :, 1:] + if self.refiner_vae: + z = z.permute(0, 2, 1, 3, 4) + b, f, c, h, w = z.shape + z = z.reshape(b, f, 2, c // 2, h, w) + z = z.permute(0, 1, 2, 3, 4, 5).reshape(b, f * 2, c // 2, h, w) + z = z.permute(0, 2, 1, 3, 4) + z = z[:, :, 1:] x = self.conv_in(z) + z.repeat_interleave(self.block_out_channels[0] // self.z_channels, 1) x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(x))) @@ -264,4 +292,10 @@ class Decoder(nn.Module): if hasattr(stage, 'upsample'): x = stage.upsample(x) - return self.conv_out(F.silu(self.norm_out(x))) + out = self.conv_out(F.silu(self.norm_out(x))) + + if not self.refiner_vae: + if z.shape[-3] == 1: + out = out[:, :, -1:] + + return out diff --git a/comfy/sd.py b/comfy/sd.py index 2df340739..873ad20f2 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -332,35 +332,51 @@ class VAE: self.first_stage_model = StageC_coder() self.downscale_ratio = 32 self.latent_channels = 16 - elif "decoder.conv_in.weight" in sd and sd['decoder.conv_in.weight'].shape[1] == 64: - ddconfig = {"block_out_channels": [128, 256, 512, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 32, "downsample_match_channel": True, "upsample_match_channel": True} - self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1] - self.downscale_ratio = 32 - self.upscale_ratio = 32 - self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32] - self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer"}, - encoder_config={'target': "comfy.ldm.hunyuan_video.vae.Encoder", 'params': ddconfig}, - decoder_config={'target': "comfy.ldm.hunyuan_video.vae.Decoder", 'params': ddconfig}) - - self.memory_used_encode = lambda shape, dtype: (700 * shape[2] * shape[3]) * model_management.dtype_size(dtype) - self.memory_used_decode = lambda shape, dtype: (700 * shape[2] * shape[3] * 32 * 32) * model_management.dtype_size(dtype) - elif "decoder.conv_in.weight" in sd: - #default SD1.x/SD2.x VAE parameters - ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 
0.0} - - if 'encoder.down.2.downsample.conv.weight' not in sd and 'decoder.up.3.upsample.conv.weight' not in sd: #Stable diffusion x4 upscaler VAE - ddconfig['ch_mult'] = [1, 2, 4] - self.downscale_ratio = 4 - self.upscale_ratio = 4 - - self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1] - if 'post_quant_conv.weight' in sd: - self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=sd['post_quant_conv.weight'].shape[1]) - else: + if sd['decoder.conv_in.weight'].shape[1] == 64: + ddconfig = {"block_out_channels": [128, 256, 512, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 32, "downsample_match_channel": True, "upsample_match_channel": True} + self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1] + self.downscale_ratio = 32 + self.upscale_ratio = 32 + self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32] self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer"}, - encoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Encoder", 'params': ddconfig}, - decoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Decoder", 'params': ddconfig}) + encoder_config={'target': "comfy.ldm.hunyuan_video.vae.Encoder", 'params': ddconfig}, + decoder_config={'target': "comfy.ldm.hunyuan_video.vae.Decoder", 'params': ddconfig}) + + self.memory_used_encode = lambda shape, dtype: (700 * shape[2] * shape[3]) * model_management.dtype_size(dtype) + self.memory_used_decode = lambda shape, dtype: (700 * shape[2] * shape[3] * 32 * 32) * model_management.dtype_size(dtype) + elif sd['decoder.conv_in.weight'].shape[1] == 32: + ddconfig = {"block_out_channels": [128, 256, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 16, "ffactor_temporal": 4, "downsample_match_channel": True, "upsample_match_channel": True, "refiner_vae": False} + self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1] + self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32] + self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 16, 16) + self.upscale_index_formula = (4, 16, 16) + self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 16, 16) + self.downscale_index_formula = (4, 16, 16) + self.latent_dim = 3 + self.not_video = True + self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer"}, + encoder_config={'target': "comfy.ldm.hunyuan_video.vae_refiner.Encoder", 'params': ddconfig}, + decoder_config={'target': "comfy.ldm.hunyuan_video.vae_refiner.Decoder", 'params': ddconfig}) + + self.memory_used_encode = lambda shape, dtype: (2800 * shape[-2] * shape[-1]) * model_management.dtype_size(dtype) + self.memory_used_decode = lambda shape, dtype: (2800 * shape[-3] * shape[-2] * shape[-1] * 16 * 16) * model_management.dtype_size(dtype) + else: + #default SD1.x/SD2.x VAE parameters + ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0} + + if 'encoder.down.2.downsample.conv.weight' not in sd and 'decoder.up.3.upsample.conv.weight' not in sd: #Stable diffusion x4 upscaler VAE + ddconfig['ch_mult'] = [1, 2, 4] + self.downscale_ratio = 4 + self.upscale_ratio = 4 + + self.latent_channels = 
ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1] + if 'post_quant_conv.weight' in sd: + self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=sd['post_quant_conv.weight'].shape[1]) + else: + self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer"}, + encoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Encoder", 'params': ddconfig}, + decoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Decoder", 'params': ddconfig}) elif "decoder.layers.1.layers.0.beta" in sd: self.first_stage_model = AudioOobleckVAE() self.memory_used_encode = lambda shape, dtype: (1000 * shape[2]) * model_management.dtype_size(dtype) From bb32d4ec3141333df26fcdaee0c3c08e41b7b249 Mon Sep 17 00:00:00 2001 From: Koratahiu Date: Thu, 2 Oct 2025 00:59:07 +0300 Subject: [PATCH 301/325] feat: Add Epsilon Scaling node for exposure bias correction (#10132) --- comfy_extras/nodes_eps.py | 60 +++++++++++++++++++++++++++++++++++++++ nodes.py | 1 + 2 files changed, 61 insertions(+) create mode 100644 comfy_extras/nodes_eps.py diff --git a/comfy_extras/nodes_eps.py b/comfy_extras/nodes_eps.py new file mode 100644 index 000000000..c8818f096 --- /dev/null +++ b/comfy_extras/nodes_eps.py @@ -0,0 +1,60 @@ +class EpsilonScaling: + """ + Implements the Epsilon Scaling method from 'Elucidating the Exposure Bias in Diffusion Models' + (https://arxiv.org/abs/2308.15321v6). + + This method mitigates exposure bias by scaling the predicted noise during sampling, + which can significantly improve sample quality. This implementation uses the "uniform schedule" + recommended by the paper for its practicality and effectiveness. + """ + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "model": ("MODEL",), + "scaling_factor": ("FLOAT", { + "default": 1.005, + "min": 0.5, + "max": 1.5, + "step": 0.001, + "display": "number" + }), + } + } + + RETURN_TYPES = ("MODEL",) + FUNCTION = "patch" + + CATEGORY = "model_patches/unet" + + def patch(self, model, scaling_factor): + # Prevent division by zero, though the UI's min value should prevent this. + if scaling_factor == 0: + scaling_factor = 1e-9 + + def epsilon_scaling_function(args): + """ + This function is applied after the CFG guidance has been calculated. + It recalculates the denoised latent by scaling the predicted noise. + """ + denoised = args["denoised"] + x = args["input"] + + noise_pred = x - denoised + + scaled_noise_pred = noise_pred / scaling_factor + + new_denoised = x - scaled_noise_pred + + return new_denoised + + # Clone the model patcher to avoid modifying the original model in place + model_clone = model.clone() + + model_clone.set_model_sampler_post_cfg_function(epsilon_scaling_function) + + return (model_clone,) + +NODE_CLASS_MAPPINGS = { + "Epsilon Scaling": EpsilonScaling +} diff --git a/nodes.py b/nodes.py index 1a6784b68..88d712993 100644 --- a/nodes.py +++ b/nodes.py @@ -2297,6 +2297,7 @@ async def init_builtin_extra_nodes(): "nodes_gits.py", "nodes_controlnet.py", "nodes_hunyuan.py", + "nodes_eps.py", "nodes_flux.py", "nodes_lora_extract.py", "nodes_torch_compile.py", From 911331c06c16aa80633c5438c58edb32dbfdff50 Mon Sep 17 00:00:00 2001 From: rattus128 <46076784+rattus128@users.noreply.github.com> Date: Thu, 2 Oct 2025 08:40:28 +1000 Subject: [PATCH 302/325] sd: fix VAE tiled fallback VRAM leak (#10139) When the VAE catches this VRAM OOM, it launches the fallback logic straight from the exception context. 
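The shape of the problem and of the fix can be sketched in isolation; FakeOOM, full_decode and tiled_decode below are illustrative stand-ins, not the actual comfy.sd functions:

    class FakeOOM(RuntimeError):
        """Stand-in for the VRAM out-of-memory exception."""

    def full_decode(samples):
        raise FakeOOM("pretend the non-tiled decode ran out of VRAM")

    def tiled_decode(samples):
        return [s * 0.5 for s in samples]  # stand-in for the tiled fallback

    def decode_with_flag(samples):
        do_tile = False
        try:
            return full_decode(samples)
        except FakeOOM:
            # Keep nothing from the failed attempt except a flag; once this
            # block exits, the exception (and the frame locals its traceback
            # pins) become collectable before the fallback allocates.
            do_tile = True
        if do_tile:
            return tiled_decode(samples)

    print(decode_with_flag([1.0, 2.0]))  # -> [0.5, 1.0]

Running the fallback inside the except block instead keeps every local of the failed attempt reachable for the whole fallback, which is the leak described next.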
Python however refs the entire call stack that caused the exception including any local variables for the sake of exception report and debugging. In the case of tensors, this can hold on the references to GBs of VRAM and inhibit the VRAM allocated from freeing them. So dump the except context completely before going back to the VAE via the tiler by getting out of the except block with nothing but a flag. The greately increases the reliability of the tiler fallback, especially on low VRAM cards, as with the bug, if the leak randomly leaked more than the headroom needed for a single tile, the tiler would fallback would OOM and fail the flow. --- comfy/sd.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/comfy/sd.py b/comfy/sd.py index 873ad20f2..be225ad03 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -652,6 +652,7 @@ class VAE: def decode(self, samples_in, vae_options={}): self.throw_exception_if_invalid() pixel_samples = None + do_tile = False try: memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype) model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload) @@ -667,6 +668,13 @@ class VAE: pixel_samples[x:x+batch_number] = out except model_management.OOM_EXCEPTION: logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.") + #NOTE: We don't know what tensors were allocated to stack variables at the time of the + #exception and the exception itself refs them all until we get out of this except block. + #So we just set a flag for tiler fallback so that tensor gc can happen once the + #exception is fully off the books. + do_tile = True + + if do_tile: dims = samples_in.ndim - 2 if dims == 1 or self.extra_1d_channel is not None: pixel_samples = self.decode_tiled_1d(samples_in) @@ -713,6 +721,7 @@ class VAE: self.throw_exception_if_invalid() pixel_samples = self.vae_encode_crop_pixels(pixel_samples) pixel_samples = pixel_samples.movedim(-1, 1) + do_tile = False if self.latent_dim == 3 and pixel_samples.ndim < 5: if not self.not_video: pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0) @@ -734,6 +743,13 @@ class VAE: except model_management.OOM_EXCEPTION: logging.warning("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.") + #NOTE: We don't know what tensors were allocated to stack variables at the time of the + #exception and the exception itself refs them all until we get out of this except block. + #So we just set a flag for tiler fallback so that tensor gc can happen once the + #exception is fully off the books. + do_tile = True + + if do_tile: if self.latent_dim == 3: tile = 256 overlap = tile // 4 From 4965c0e2acf39d84e82cb63dd6cc4400299d0a61 Mon Sep 17 00:00:00 2001 From: rattus128 <46076784+rattus128@users.noreply.github.com> Date: Thu, 2 Oct 2025 08:42:16 +1000 Subject: [PATCH 303/325] WAN: Fix cache VRAM leak on error (#10141) If this suffers an exception (such as a VRAM oom) it will leave the encode() and decode() methods which skips the cleanup of the WAN feature cache. The comfy node cache then ultimately keeps a reference this object which is in turn reffing large tensors from the failed execution. The feature cache is currently setup at a class variable on the encoder/decoder however, the encode and decode functions always clear it on both entry and exit of normal execution. 
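The difference between caching on the object and caching in a local can be shown with a minimal stand-alone sketch (Codec and the float "tensors" are placeholders, not the actual WanVAE code):

    class Codec:
        def __init__(self):
            self._feat_cache = None  # instance attribute: lives as long as the codec object

        def decode_cached_on_self(self, frames):
            self._feat_cache = [f * 2.0 for f in frames]  # stand-in for the per-conv feature cache
            raise RuntimeError("pretend VRAM OOM before the cleanup at the end of decode")

        def decode_cached_locally(self, frames):
            feat_cache = [f * 2.0 for f in frames]  # local: unreachable once the call has unwound
            raise RuntimeError("pretend VRAM OOM")

    codec = Codec()  # imagine the node cache still holding this reference after the failure
    try:
        codec.decode_cached_on_self([1.0, 2.0])
    except RuntimeError:
        pass
    print(codec._feat_cache)  # -> [2.0, 4.0]: the failed run's cache stays pinned via the codec

With the local variant, nothing outside the call can reach feat_cache once the exception has been handled, so the failed run's tensors can be garbage collected.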
Its likely the design intent is this is usable as a streaming encoder where the input comes in batches, however the functions as they are today don't support that. So simplify by bringing the cache back to local variable, so that if it does VRAM OOM the cache itself is properly garbage when the encode()/decode() functions dissappear from the stack. --- comfy/ldm/wan/vae.py | 37 ++++++++++++++----------------------- 1 file changed, 14 insertions(+), 23 deletions(-) diff --git a/comfy/ldm/wan/vae.py b/comfy/ldm/wan/vae.py index 791596938..ccbb25822 100644 --- a/comfy/ldm/wan/vae.py +++ b/comfy/ldm/wan/vae.py @@ -468,55 +468,46 @@ class WanVAE(nn.Module): attn_scales, self.temperal_upsample, dropout) def encode(self, x): - self.clear_cache() + conv_idx = [0] + feat_map = [None] * count_conv3d(self.decoder) ## cache t = x.shape[2] iter_ = 1 + (t - 1) // 4 ## 对encode输入的x,按时间拆分为1、4、4、4.... for i in range(iter_): - self._enc_conv_idx = [0] + conv_idx = [0] if i == 0: out = self.encoder( x[:, :, :1, :, :], - feat_cache=self._enc_feat_map, - feat_idx=self._enc_conv_idx) + feat_cache=feat_map, + feat_idx=conv_idx) else: out_ = self.encoder( x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :], - feat_cache=self._enc_feat_map, - feat_idx=self._enc_conv_idx) + feat_cache=feat_map, + feat_idx=conv_idx) out = torch.cat([out, out_], 2) mu, log_var = self.conv1(out).chunk(2, dim=1) - self.clear_cache() return mu def decode(self, z): - self.clear_cache() + conv_idx = [0] + feat_map = [None] * count_conv3d(self.decoder) # z: [b,c,t,h,w] iter_ = z.shape[2] x = self.conv2(z) for i in range(iter_): - self._conv_idx = [0] + conv_idx = [0] if i == 0: out = self.decoder( x[:, :, i:i + 1, :, :], - feat_cache=self._feat_map, - feat_idx=self._conv_idx) + feat_cache=feat_map, + feat_idx=conv_idx) else: out_ = self.decoder( x[:, :, i:i + 1, :, :], - feat_cache=self._feat_map, - feat_idx=self._conv_idx) + feat_cache=feat_map, + feat_idx=conv_idx) out = torch.cat([out, out_], 2) - self.clear_cache() return out - - def clear_cache(self): - self._conv_num = count_conv3d(self.decoder) - self._conv_idx = [0] - self._feat_map = [None] * self._conv_num - #cache encode - self._enc_conv_num = count_conv3d(self.encoder) - self._enc_conv_idx = [0] - self._enc_feat_map = [None] * self._enc_conv_num From 0e9d1724be327c79ba86159d868f0b57adb8c384 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 1 Oct 2025 21:33:05 -0700 Subject: [PATCH 304/325] Add a .bat to the AMD portable to disable smart memory. 
(#10153) --- .ci/windows_amd_base_files/README_VERY_IMPORTANT.txt | 5 ++++- .../run_amd_gpu_disable_smart_memory.bat | 2 ++ 2 files changed, 6 insertions(+), 1 deletion(-) create mode 100755 .ci/windows_amd_base_files/run_amd_gpu_disable_smart_memory.bat diff --git a/.ci/windows_amd_base_files/README_VERY_IMPORTANT.txt b/.ci/windows_amd_base_files/README_VERY_IMPORTANT.txt index 570ac3398..96a500be2 100755 --- a/.ci/windows_amd_base_files/README_VERY_IMPORTANT.txt +++ b/.ci/windows_amd_base_files/README_VERY_IMPORTANT.txt @@ -3,10 +3,13 @@ https://www.amd.com/en/resources/support-articles/release-notes/RN-AMDGPU-WINDOW HOW TO RUN: -if you have a AMD gpu: +If you have a AMD gpu: run_amd_gpu.bat +If you have memory issues you can try disabling the smart memory management by running comfyui with: + +run_amd_gpu_disable_smart_memory.bat IF YOU GET A RED ERROR IN THE UI MAKE SURE YOU HAVE A MODEL/CHECKPOINT IN: ComfyUI\models\checkpoints diff --git a/.ci/windows_amd_base_files/run_amd_gpu_disable_smart_memory.bat b/.ci/windows_amd_base_files/run_amd_gpu_disable_smart_memory.bat new file mode 100755 index 000000000..cece0aeb2 --- /dev/null +++ b/.ci/windows_amd_base_files/run_amd_gpu_disable_smart_memory.bat @@ -0,0 +1,2 @@ +.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build --disable-smart-memory +pause From 8f4ee9984c0c3864290e4fea81cfea2ba281717d Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Thu, 2 Oct 2025 23:53:00 +0300 Subject: [PATCH 305/325] convert nodes_morphology.py to V3 schema (#10159) --- comfy_extras/nodes_morphology.py | 116 +++++++++++++++++++------------ 1 file changed, 70 insertions(+), 46 deletions(-) diff --git a/comfy_extras/nodes_morphology.py b/comfy_extras/nodes_morphology.py index 075b26c40..67377e1bc 100644 --- a/comfy_extras/nodes_morphology.py +++ b/comfy_extras/nodes_morphology.py @@ -1,24 +1,34 @@ import torch import comfy.model_management +from typing_extensions import override +from comfy_api.latest import ComfyExtension, io from kornia.morphology import dilation, erosion, opening, closing, gradient, top_hat, bottom_hat import kornia.color -class Morphology: +class Morphology(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": {"image": ("IMAGE",), - "operation": (["erode", "dilate", "open", "close", "gradient", "bottom_hat", "top_hat"],), - "kernel_size": ("INT", {"default": 3, "min": 3, "max": 999, "step": 1}), - }} + def define_schema(cls): + return io.Schema( + node_id="Morphology", + display_name="ImageMorphology", + category="image/postprocessing", + inputs=[ + io.Image.Input("image"), + io.Combo.Input( + "operation", + options=["erode", "dilate", "open", "close", "gradient", "bottom_hat", "top_hat"], + ), + io.Int.Input("kernel_size", default=3, min=3, max=999, step=1), + ], + outputs=[ + io.Image.Output(), + ], + ) - RETURN_TYPES = ("IMAGE",) - FUNCTION = "process" - - CATEGORY = "image/postprocessing" - - def process(self, image, operation, kernel_size): + @classmethod + def execute(cls, image, operation, kernel_size) -> io.NodeOutput: device = comfy.model_management.get_torch_device() kernel = torch.ones(kernel_size, kernel_size, device=device) image_k = image.to(device).movedim(-1, 1) @@ -39,49 +49,63 @@ class Morphology: else: raise ValueError(f"Invalid operation {operation} for morphology. 
Must be one of 'erode', 'dilate', 'open', 'close', 'gradient', 'tophat', 'bottomhat'") img_out = output.to(comfy.model_management.intermediate_device()).movedim(1, -1) - return (img_out,) + return io.NodeOutput(img_out) -class ImageRGBToYUV: +class ImageRGBToYUV(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "image": ("IMAGE",), - }} + def define_schema(cls): + return io.Schema( + node_id="ImageRGBToYUV", + category="image/batch", + inputs=[ + io.Image.Input("image"), + ], + outputs=[ + io.Image.Output(display_name="Y"), + io.Image.Output(display_name="U"), + io.Image.Output(display_name="V"), + ], + ) - RETURN_TYPES = ("IMAGE", "IMAGE", "IMAGE") - RETURN_NAMES = ("Y", "U", "V") - FUNCTION = "execute" - - CATEGORY = "image/batch" - - def execute(self, image): + @classmethod + def execute(cls, image) -> io.NodeOutput: out = kornia.color.rgb_to_ycbcr(image.movedim(-1, 1)).movedim(1, -1) - return (out[..., 0:1].expand_as(image), out[..., 1:2].expand_as(image), out[..., 2:3].expand_as(image)) + return io.NodeOutput(out[..., 0:1].expand_as(image), out[..., 1:2].expand_as(image), out[..., 2:3].expand_as(image)) -class ImageYUVToRGB: +class ImageYUVToRGB(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": {"Y": ("IMAGE",), - "U": ("IMAGE",), - "V": ("IMAGE",), - }} + def define_schema(cls): + return io.Schema( + node_id="ImageYUVToRGB", + category="image/batch", + inputs=[ + io.Image.Input("Y"), + io.Image.Input("U"), + io.Image.Input("V"), + ], + outputs=[ + io.Image.Output(), + ], + ) - RETURN_TYPES = ("IMAGE",) - FUNCTION = "execute" - - CATEGORY = "image/batch" - - def execute(self, Y, U, V): + @classmethod + def execute(cls, Y, U, V) -> io.NodeOutput: image = torch.cat([torch.mean(Y, dim=-1, keepdim=True), torch.mean(U, dim=-1, keepdim=True), torch.mean(V, dim=-1, keepdim=True)], dim=-1) out = kornia.color.ycbcr_to_rgb(image.movedim(-1, 1)).movedim(1, -1) - return (out,) + return io.NodeOutput(out) -NODE_CLASS_MAPPINGS = { - "Morphology": Morphology, - "ImageRGBToYUV": ImageRGBToYUV, - "ImageYUVToRGB": ImageYUVToRGB, -} -NODE_DISPLAY_NAME_MAPPINGS = { - "Morphology": "ImageMorphology", -} +class MorphologyExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + Morphology, + ImageRGBToYUV, + ImageYUVToRGB, + ] + + +async def comfy_entrypoint() -> MorphologyExtension: + return MorphologyExtension() + From f6e3e9a456127a7e539929f42ea6cac838197879 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Fri, 3 Oct 2025 00:50:31 +0300 Subject: [PATCH 306/325] fix(api-nodes): made logging path to be smaller (#10156) --- comfy_api_nodes/apis/client.py | 5 +- comfy_api_nodes/apis/request_logger.py | 72 ++++++++++++++++++++------ 2 files changed, 59 insertions(+), 18 deletions(-) diff --git a/comfy_api_nodes/apis/client.py b/comfy_api_nodes/apis/client.py index 0aed906fb..18a694675 100644 --- a/comfy_api_nodes/apis/client.py +++ b/comfy_api_nodes/apis/client.py @@ -95,6 +95,7 @@ import aiohttp import asyncio import logging import io +import os import socket from aiohttp.client_exceptions import ClientError, ClientResponseError from typing import Dict, Type, Optional, Any, TypeVar, Generic, Callable, Tuple @@ -499,7 +500,9 @@ class ApiClient: else: raise ValueError("File must be BytesIO or str path") - operation_id = f"upload_{upload_url.split('/')[-1]}_{uuid.uuid4().hex[:8]}" + parsed = urlparse(upload_url) + basename = os.path.basename(parsed.path) or 
parsed.netloc or "upload" + operation_id = f"upload_{basename}_{uuid.uuid4().hex[:8]}" request_logger.log_request_response( operation_id=operation_id, request_method="PUT", diff --git a/comfy_api_nodes/apis/request_logger.py b/comfy_api_nodes/apis/request_logger.py index 42901e141..2e0ca5380 100644 --- a/comfy_api_nodes/apis/request_logger.py +++ b/comfy_api_nodes/apis/request_logger.py @@ -4,16 +4,18 @@ import os import datetime import json import logging +import re +import hashlib +from typing import Any + import folder_paths # Get the logger instance logger = logging.getLogger(__name__) + def get_log_directory(): - """ - Ensures the API log directory exists within ComfyUI's temp directory - and returns its path. - """ + """Ensures the API log directory exists within ComfyUI's temp directory and returns its path.""" base_temp_dir = folder_paths.get_temp_directory() log_dir = os.path.join(base_temp_dir, "api_logs") try: @@ -24,42 +26,77 @@ def get_log_directory(): return base_temp_dir return log_dir -def _format_data_for_logging(data): + +def _sanitize_filename_component(name: str) -> str: + if not name: + return "log" + sanitized = re.sub(r"[^A-Za-z0-9._-]+", "_", name) # Replace disallowed characters with underscore + sanitized = sanitized.strip(" ._") # Windows: trailing dots or spaces are not allowed + if not sanitized: + sanitized = "log" + return sanitized + + +def _short_hash(*parts: str, length: int = 10) -> str: + return hashlib.sha1(("|".join(parts)).encode("utf-8")).hexdigest()[:length] + + +def _build_log_filepath(log_dir: str, operation_id: str, request_url: str) -> str: + """Build log filepath. We keep it well under common path length limits aiming for <= 240 characters total.""" + timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S_%f") + slug = _sanitize_filename_component(operation_id) # Best-effort human-readable slug from operation_id + h = _short_hash(operation_id or "", request_url or "") # Short hash ties log to the full operation and URL + + # Compute how much room we have for the slug given the directory length + # Keep total path length reasonably below ~260 on Windows. 
+ max_total_path = 240 + prefix = f"{timestamp}_" + suffix = f"_{h}.log" + if not slug: + slug = "op" + max_filename_len = max(60, max_total_path - len(log_dir) - 1) + max_slug_len = max(8, max_filename_len - len(prefix) - len(suffix)) + if len(slug) > max_slug_len: + slug = slug[:max_slug_len].rstrip(" ._-") + return os.path.join(log_dir, f"{prefix}{slug}{suffix}") + + +def _format_data_for_logging(data: Any) -> str: """Helper to format data (dict, str, bytes) for logging.""" if isinstance(data, bytes): try: - return data.decode('utf-8') # Try to decode as text + return data.decode("utf-8") # Try to decode as text except UnicodeDecodeError: return f"[Binary data of length {len(data)} bytes]" elif isinstance(data, (dict, list)): try: return json.dumps(data, indent=2, ensure_ascii=False) except TypeError: - return str(data) # Fallback for non-serializable objects + return str(data) # Fallback for non-serializable objects return str(data) + def log_request_response( operation_id: str, request_method: str, request_url: str, request_headers: dict | None = None, request_params: dict | None = None, - request_data: any = None, + request_data: Any = None, response_status_code: int | None = None, response_headers: dict | None = None, - response_content: any = None, - error_message: str | None = None + response_content: Any = None, + error_message: str | None = None, ): """ Logs API request and response details to a file in the temp/api_logs directory. + Filenames are sanitized and length-limited for cross-platform safety. + If we still fail to write, we fall back to appending into api.log. """ log_dir = get_log_directory() - timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S_%f") - filename = f"{timestamp}_{operation_id.replace('/', '_').replace(':', '_')}.log" - filepath = os.path.join(log_dir, filename) - - log_content = [] + filepath = _build_log_filepath(log_dir, operation_id, request_url) + log_content: list[str] = [] log_content.append(f"Timestamp: {datetime.datetime.now().isoformat()}") log_content.append(f"Operation ID: {operation_id}") log_content.append("-" * 30 + " REQUEST " + "-" * 30) @@ -69,7 +106,7 @@ def log_request_response( log_content.append(f"Headers:\n{_format_data_for_logging(request_headers)}") if request_params: log_content.append(f"Params:\n{_format_data_for_logging(request_params)}") - if request_data: + if request_data is not None: log_content.append(f"Data/Body:\n{_format_data_for_logging(request_data)}") log_content.append("\n" + "-" * 30 + " RESPONSE " + "-" * 30) @@ -77,7 +114,7 @@ def log_request_response( log_content.append(f"Status Code: {response_status_code}") if response_headers: log_content.append(f"Headers:\n{_format_data_for_logging(response_headers)}") - if response_content: + if response_content is not None: log_content.append(f"Content:\n{_format_data_for_logging(response_content)}") if error_message: log_content.append(f"Error:\n{error_message}") @@ -89,6 +126,7 @@ def log_request_response( except Exception as e: logger.error(f"Error writing API log to {filepath}: {e}") + if __name__ == '__main__': # Example usage (for testing the logger directly) logger.setLevel(logging.DEBUG) From e9364ee279f65d0546fea1796c3cd2e0b7e1965f Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Thu, 2 Oct 2025 14:57:15 -0700 Subject: [PATCH 307/325] Turn on TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL by default. 
(#10168) --- main.py | 1 + 1 file changed, 1 insertion(+) diff --git a/main.py b/main.py index 70696fcc3..35857dba8 100644 --- a/main.py +++ b/main.py @@ -115,6 +115,7 @@ if os.name == "nt": os.environ['MIMALLOC_PURGE_DELAY'] = '0' if __name__ == "__main__": + os.environ['TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL'] = '1' if args.default_device is not None: default_dev = args.default_device devices = list(range(32)) From 1395bce9f707e52ec613eeaa87ea690518cfe0a8 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Fri, 3 Oct 2025 01:20:29 +0300 Subject: [PATCH 308/325] update example_node to use V3 schema (#9723) --- custom_nodes/example_node.py.example | 161 +++++++++++---------------- 1 file changed, 68 insertions(+), 93 deletions(-) diff --git a/custom_nodes/example_node.py.example b/custom_nodes/example_node.py.example index 29ab2aa72..779c35787 100644 --- a/custom_nodes/example_node.py.example +++ b/custom_nodes/example_node.py.example @@ -1,96 +1,70 @@ -class Example: +from typing_extensions import override + +from comfy_api.latest import ComfyExtension, io + + +class Example(io.ComfyNode): """ - A example node + An example node Class methods ------------- - INPUT_TYPES (dict): - Tell the main program input parameters of nodes. - IS_CHANGED: + define_schema (io.Schema): + Tell the main program the metadata, input, output parameters of nodes. + fingerprint_inputs: optional method to control when the node is re executed. + check_lazy_status: + optional method to control list of input names that need to be evaluated. - Attributes - ---------- - RETURN_TYPES (`tuple`): - The type of each element in the output tuple. - RETURN_NAMES (`tuple`): - Optional: The name of each output in the output tuple. - FUNCTION (`str`): - The name of the entry-point method. For example, if `FUNCTION = "execute"` then it will run Example().execute() - OUTPUT_NODE ([`bool`]): - If this node is an output node that outputs a result/image from the graph. The SaveImage node is an example. - The backend iterates on these output nodes and tries to execute all their parents if their parent graph is properly connected. - Assumed to be False if not present. - CATEGORY (`str`): - The category the node should appear in the UI. - DEPRECATED (`bool`): - Indicates whether the node is deprecated. Deprecated nodes are hidden by default in the UI, but remain - functional in existing workflows that use them. - EXPERIMENTAL (`bool`): - Indicates whether the node is experimental. Experimental nodes are marked as such in the UI and may be subject to - significant changes or removal in future versions. Use with caution in production workflows. - execute(s) -> tuple || None: - The entry point method. The name of this method must be the same as the value of property `FUNCTION`. - For example, if `FUNCTION = "execute"` then this method's name must be `execute`, if `FUNCTION = "foo"` then it must be `foo`. """ - def __init__(self): - pass @classmethod - def INPUT_TYPES(s): + def define_schema(cls) -> io.Schema: """ - Return a dictionary which contains config for all input fields. - Some types (string): "MODEL", "VAE", "CLIP", "CONDITIONING", "LATENT", "IMAGE", "INT", "STRING", "FLOAT". - Input types "INT", "STRING" or "FLOAT" are special values for fields on the node. - The type can be a list for selection. - - Returns: `dict`: - - Key input_fields_group (`string`): Can be either required, hidden or optional. 
A node class must have property `required` - - Value input_fields (`dict`): Contains input fields config: - * Key field_name (`string`): Name of a entry-point method's argument - * Value field_config (`tuple`): - + First value is a string indicate the type of field or a list for selection. - + Second value is a config for type "INT", "STRING" or "FLOAT". + Return a schema which contains all information about the node. + Some types: "Model", "Vae", "Clip", "Conditioning", "Latent", "Image", "Int", "String", "Float", "Combo". + For outputs the "io.Model.Output" should be used, for inputs the "io.Model.Input" can be used. + The type can be a "Combo" - this will be a list for selection. """ - return { - "required": { - "image": ("IMAGE",), - "int_field": ("INT", { - "default": 0, - "min": 0, #Minimum value - "max": 4096, #Maximum value - "step": 64, #Slider's step - "display": "number", # Cosmetic only: display as "number" or "slider" - "lazy": True # Will only be evaluated if check_lazy_status requires it - }), - "float_field": ("FLOAT", { - "default": 1.0, - "min": 0.0, - "max": 10.0, - "step": 0.01, - "round": 0.001, #The value representing the precision to round to, will be set to the step value by default. Can be set to False to disable rounding. - "display": "number", - "lazy": True - }), - "print_to_screen": (["enable", "disable"],), - "string_field": ("STRING", { - "multiline": False, #True if you want the field to look like the one on the ClipTextEncode node - "default": "Hello World!", - "lazy": True - }), - }, - } + return io.Schema( + node_id="Example", + display_name="Example Node", + category="Example", + inputs=[ + io.Image.Input("image"), + io.Int.Input( + "int_field", + min=0, + max=4096, + step=64, # Slider's step + display_mode=io.NumberDisplay.number, # Cosmetic only: display as "number" or "slider" + lazy=True, # Will only be evaluated if check_lazy_status requires it + ), + io.Float.Input( + "float_field", + default=1.0, + min=0.0, + max=10.0, + step=0.01, + round=0.001, #The value representing the precision to round to, will be set to the step value by default. Can be set to False to disable rounding. + display_mode=io.NumberDisplay.number, + lazy=True, + ), + io.Combo.Input("print_to_screen", options=["enable", "disable"]), + io.String.Input( + "string_field", + multiline=False, # True if you want the field to look like the one on the ClipTextEncode node + default="Hello world!", + lazy=True, + ) + ], + outputs=[ + io.Image.Output(), + ], + ) - RETURN_TYPES = ("IMAGE",) - #RETURN_NAMES = ("image_output_name",) - - FUNCTION = "test" - - #OUTPUT_NODE = False - - CATEGORY = "Example" - - def check_lazy_status(self, image, string_field, int_field, float_field, print_to_screen): + @classmethod + def check_lazy_status(cls, image, string_field, int_field, float_field, print_to_screen): """ Return a list of input names that need to be evaluated. 
@@ -107,7 +81,8 @@ class Example: else: return [] - def test(self, image, string_field, int_field, float_field, print_to_screen): + @classmethod + def execute(cls, image, string_field, int_field, float_field, print_to_screen) -> io.NodeOutput: if print_to_screen == "enable": print(f"""Your input contains: string_field aka input text: {string_field} @@ -116,7 +91,7 @@ class Example: """) #do some processing on the image, in this example I just invert it image = 1.0 - image - return (image,) + return io.NodeOutput(image) """ The node will always be re executed if any of the inputs change but @@ -127,7 +102,7 @@ class Example: changes between executions the LoadImage node is executed again. """ #@classmethod - #def IS_CHANGED(s, image, string_field, int_field, float_field, print_to_screen): + #def fingerprint_inputs(s, image, string_field, int_field, float_field, print_to_screen): # return "" # Set the web directory, any .js file in that directory will be loaded by the frontend as a frontend extension @@ -143,13 +118,13 @@ async def get_hello(request): return web.json_response("hello") -# A dictionary that contains all nodes you want to export with their names -# NOTE: names should be globally unique -NODE_CLASS_MAPPINGS = { - "Example": Example -} +class ExampleExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + Example, + ] -# A dictionary that contains the friendly/humanly readable titles for the nodes -NODE_DISPLAY_NAME_MAPPINGS = { - "Example": "Example Node" -} + +async def comfy_entrypoint() -> ExampleExtension: # ComfyUI calls this to load your extension and its nodes. + return ExampleExtension() From 4ffea0e864275301329ddb5ecc3fbc7211d7a802 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Fri, 3 Oct 2025 02:14:28 +0300 Subject: [PATCH 309/325] feat(linter, api-nodes): add pylint for comfy_api_nodes folder (#10157) --- .github/workflows/ruff.yml | 25 ++++++++++++++ comfy_api_nodes/apis/__init__.py | 1 + comfy_api_nodes/apis/client.py | 2 +- comfy_api_nodes/apis/rodin_api.py | 4 --- pyproject.toml | 54 +++++++++++++++++++++++++++++++ 5 files changed, 81 insertions(+), 5 deletions(-) diff --git a/.github/workflows/ruff.yml b/.github/workflows/ruff.yml index 4c1a02594..b24d86a6b 100644 --- a/.github/workflows/ruff.yml +++ b/.github/workflows/ruff.yml @@ -21,3 +21,28 @@ jobs: - name: Run Ruff run: ruff check . 
+ + pylint: + name: Run Pylint + runs-on: ubuntu-latest + + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: '3.12' + + - name: Install requirements + run: | + python -m pip install --upgrade pip + pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu + pip install -r requirements.txt + + - name: Install Pylint + run: pip install pylint + + - name: Run Pylint + run: pylint comfy_api_nodes diff --git a/comfy_api_nodes/apis/__init__.py b/comfy_api_nodes/apis/__init__.py index 78a23db30..98f9e540d 100644 --- a/comfy_api_nodes/apis/__init__.py +++ b/comfy_api_nodes/apis/__init__.py @@ -2,6 +2,7 @@ # filename: filtered-openapi.yaml # timestamp: 2025-07-30T08:54:00+00:00 +# pylint: disable from __future__ import annotations from datetime import date, datetime diff --git a/comfy_api_nodes/apis/client.py b/comfy_api_nodes/apis/client.py index 18a694675..79de3c262 100644 --- a/comfy_api_nodes/apis/client.py +++ b/comfy_api_nodes/apis/client.py @@ -535,7 +535,7 @@ class ApiClient: request_method="PUT", request_url=upload_url, response_status_code=e.status if hasattr(e, "status") else None, - response_headers=dict(e.headers) if getattr(e, "headers") else None, + response_headers=dict(e.headers) if hasattr(e, "headers") else None, response_content=None, error_message=f"{type(e).__name__}: {str(e)}", ) diff --git a/comfy_api_nodes/apis/rodin_api.py b/comfy_api_nodes/apis/rodin_api.py index 02cf42c29..fc26a6e73 100644 --- a/comfy_api_nodes/apis/rodin_api.py +++ b/comfy_api_nodes/apis/rodin_api.py @@ -52,7 +52,3 @@ class RodinResourceItem(BaseModel): class Rodin3DDownloadResponse(BaseModel): list: List[RodinResourceItem] = Field(..., description="Source List") - - - - diff --git a/pyproject.toml b/pyproject.toml index d0a76c6d0..598af4157 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,3 +22,57 @@ lint.select = [ "F", ] exclude = ["*.ipynb", "**/generated/*.pyi"] + +[tool.pylint] +master.py-version = "3.9" +master.extension-pkg-allow-list = [ + "pydantic", +] +reports.output-format = "colorized" +similarities.ignore-imports = "yes" +messages_control.disable = [ + "missing-module-docstring", + "missing-class-docstring", + "missing-function-docstring", + "line-too-long", + "too-few-public-methods", + "too-many-public-methods", + "too-many-instance-attributes", + "too-many-positional-arguments", + "broad-exception-raised", + "too-many-lines", + "invalid-name", + "unused-argument", + "broad-exception-caught", + "consider-using-with", + "fixme", + "too-many-statements", + "too-many-branches", + "too-many-locals", + "too-many-arguments", + "duplicate-code", + "abstract-method", + "superfluous-parens", + "arguments-differ", + "redefined-builtin", + "unnecessary-lambda", + "dangerous-default-value", + # next warnings should be fixed in future + "bad-classmethod-argument", # Class method should have 'cls' as first argument + "wrong-import-order", # Standard imports should be placed before third party imports + "logging-fstring-interpolation", # Use lazy % formatting in logging functions + "ungrouped-imports", + "unnecessary-pass", + "unidiomatic-typecheck", + "unnecessary-lambda-assignment", + "bad-indentation", + "no-else-return", + "no-else-raise", + "invalid-overridden-method", + "unused-variable", + "pointless-string-statement", + "inconsistent-return-statements", + "import-outside-toplevel", + "reimported", + "redefined-outer-name", +] From 
ed3ca78e080d697b6cf29497c07e14ee9c27a3ac Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Fri, 3 Oct 2025 21:26:34 +0300 Subject: [PATCH 310/325] feat(api-nodes): add kling-2-5-turbo to txt2video and img2video nodes (#10155) --- comfy_api_nodes/apis/__init__.py | 2 ++ comfy_api_nodes/nodes_kling.py | 2 ++ 2 files changed, 4 insertions(+) diff --git a/comfy_api_nodes/apis/__init__.py b/comfy_api_nodes/apis/__init__.py index 98f9e540d..ee2aa1ce6 100644 --- a/comfy_api_nodes/apis/__init__.py +++ b/comfy_api_nodes/apis/__init__.py @@ -1321,6 +1321,7 @@ class KlingTextToVideoModelName(str, Enum): kling_v1 = 'kling-v1' kling_v1_6 = 'kling-v1-6' kling_v2_1_master = 'kling-v2-1-master' + kling_v2_5_turbo = 'kling-v2-5-turbo' class KlingVideoGenAspectRatio(str, Enum): @@ -1355,6 +1356,7 @@ class KlingVideoGenModelName(str, Enum): kling_v2_master = 'kling-v2-master' kling_v2_1 = 'kling-v2-1' kling_v2_1_master = 'kling-v2-1-master' + kling_v2_5_turbo = 'kling-v2-5-turbo' class KlingVideoResult(BaseModel): diff --git a/comfy_api_nodes/nodes_kling.py b/comfy_api_nodes/nodes_kling.py index 5f55b2cc9..d8646f106 100644 --- a/comfy_api_nodes/nodes_kling.py +++ b/comfy_api_nodes/nodes_kling.py @@ -423,6 +423,8 @@ class KlingTextToVideoNode(KlingNodeBase): "standard mode / 10s duration / kling-v2-master": ("std", "10", "kling-v2-master"), "pro mode / 5s duration / kling-v2-1-master": ("pro", "5", "kling-v2-1-master"), "pro mode / 10s duration / kling-v2-1-master": ("pro", "10", "kling-v2-1-master"), + "pro mode / 5s duration / kling-v2-5-turbo": ("pro", "5", "kling-v2-5-turbo"), + "pro mode / 10s duration / kling-v2-5-turbo": ("pro", "10", "kling-v2-5-turbo"), } @classmethod From 8a293372ecdea0ff8647921eaf3bb10c3d992abf Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Fri, 3 Oct 2025 21:40:27 +0300 Subject: [PATCH 311/325] fix(api-nodes): reimport of base64 in Gemini node (#10181) --- comfy_api_nodes/nodes_gemini.py | 1 - pyproject.toml | 1 - 2 files changed, 2 deletions(-) diff --git a/comfy_api_nodes/nodes_gemini.py b/comfy_api_nodes/nodes_gemini.py index baa379b75..151cb4044 100644 --- a/comfy_api_nodes/nodes_gemini.py +++ b/comfy_api_nodes/nodes_gemini.py @@ -490,7 +490,6 @@ class GeminiInputFiles(ComfyNodeABC): # Use base64 string directly, not the data URI with open(file_path, "rb") as f: file_content = f.read() - import base64 base64_str = base64.b64encode(file_content).decode("utf-8") return GeminiPart( diff --git a/pyproject.toml b/pyproject.toml index 598af4157..7952f7f37 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -73,6 +73,5 @@ messages_control.disable = [ "pointless-string-statement", "inconsistent-return-statements", "import-outside-toplevel", - "reimported", "redefined-outer-name", ] From c2c5a7d5f80579bb44c11de0ce6eff94d1c111b9 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Fri, 3 Oct 2025 21:41:06 +0300 Subject: [PATCH 312/325] fix(api-nodes): bad indentation in Recraft API node function (#10175) --- comfy_api_nodes/nodes_recraft.py | 78 ++++++++++++++++---------------- pyproject.toml | 1 - 2 files changed, 39 insertions(+), 40 deletions(-) diff --git a/comfy_api_nodes/nodes_recraft.py b/comfy_api_nodes/nodes_recraft.py index c8516b368..a006104b7 100644 --- a/comfy_api_nodes/nodes_recraft.py +++ b/comfy_api_nodes/nodes_recraft.py @@ -38,48 +38,48 @@ from PIL import UnidentifiedImageError async def handle_recraft_file_request( - image: 
torch.Tensor, - path: str, - mask: torch.Tensor=None, - total_pixels=4096*4096, - timeout=1024, - request=None, - auth_kwargs: dict[str,str] = None, - ) -> list[BytesIO]: - """ - Handle sending common Recraft file-only request to get back file bytes. - """ - if request is None: - request = EmptyRequest() + image: torch.Tensor, + path: str, + mask: torch.Tensor=None, + total_pixels=4096*4096, + timeout=1024, + request=None, + auth_kwargs: dict[str,str] = None, +) -> list[BytesIO]: + """ + Handle sending common Recraft file-only request to get back file bytes. + """ + if request is None: + request = EmptyRequest() - files = { - 'image': tensor_to_bytesio(image, total_pixels=total_pixels).read() - } - if mask is not None: - files['mask'] = tensor_to_bytesio(mask, total_pixels=total_pixels).read() + files = { + 'image': tensor_to_bytesio(image, total_pixels=total_pixels).read() + } + if mask is not None: + files['mask'] = tensor_to_bytesio(mask, total_pixels=total_pixels).read() - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path=path, - method=HttpMethod.POST, - request_model=type(request), - response_model=RecraftImageGenerationResponse, - ), - request=request, - files=files, - content_type="multipart/form-data", - auth_kwargs=auth_kwargs, - multipart_parser=recraft_multipart_parser, - ) - response: RecraftImageGenerationResponse = await operation.execute() - all_bytesio = [] - if response.image is not None: - all_bytesio.append(await download_url_to_bytesio(response.image.url, timeout=timeout)) - else: - for data in response.data: - all_bytesio.append(await download_url_to_bytesio(data.url, timeout=timeout)) + operation = SynchronousOperation( + endpoint=ApiEndpoint( + path=path, + method=HttpMethod.POST, + request_model=type(request), + response_model=RecraftImageGenerationResponse, + ), + request=request, + files=files, + content_type="multipart/form-data", + auth_kwargs=auth_kwargs, + multipart_parser=recraft_multipart_parser, + ) + response: RecraftImageGenerationResponse = await operation.execute() + all_bytesio = [] + if response.image is not None: + all_bytesio.append(await download_url_to_bytesio(response.image.url, timeout=timeout)) + else: + for data in response.data: + all_bytesio.append(await download_url_to_bytesio(data.url, timeout=timeout)) - return all_bytesio + return all_bytesio def recraft_multipart_parser(data, parent_key=None, formatter: callable=None, converted_to_check: list[list]=None, is_list=False) -> dict: diff --git a/pyproject.toml b/pyproject.toml index 7952f7f37..240919a43 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -65,7 +65,6 @@ messages_control.disable = [ "unnecessary-pass", "unidiomatic-typecheck", "unnecessary-lambda-assignment", - "bad-indentation", "no-else-return", "no-else-raise", "invalid-overridden-method", From 3e68bc342cd60b909b4117c1b68a3afc62ef875c Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Fri, 3 Oct 2025 21:43:54 +0300 Subject: [PATCH 313/325] convert nodes_torch_compile.py to V3 schema (#10173) --- comfy_extras/nodes_torch_compile.py | 46 +++++++++++++++++++---------- 1 file changed, 31 insertions(+), 15 deletions(-) diff --git a/comfy_extras/nodes_torch_compile.py b/comfy_extras/nodes_torch_compile.py index 605536678..adbeece2f 100644 --- a/comfy_extras/nodes_torch_compile.py +++ b/comfy_extras/nodes_torch_compile.py @@ -1,23 +1,39 @@ +from typing_extensions import override +from comfy_api.latest import ComfyExtension, io from comfy_api.torch_helpers import 
set_torch_compile_wrapper -class TorchCompileModel: +class TorchCompileModel(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "model": ("MODEL",), - "backend": (["inductor", "cudagraphs"],), - }} - RETURN_TYPES = ("MODEL",) - FUNCTION = "patch" + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="TorchCompileModel", + category="_for_testing", + inputs=[ + io.Model.Input("model"), + io.Combo.Input( + "backend", + options=["inductor", "cudagraphs"], + ), + ], + outputs=[io.Model.Output()], + is_experimental=True, + ) - CATEGORY = "_for_testing" - EXPERIMENTAL = True - - def patch(self, model, backend): + @classmethod + def execute(cls, model, backend) -> io.NodeOutput: m = model.clone() set_torch_compile_wrapper(model=m, backend=backend) - return (m, ) + return io.NodeOutput(m) -NODE_CLASS_MAPPINGS = { - "TorchCompileModel": TorchCompileModel, -} + +class TorchCompileExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + TorchCompileModel, + ] + + +async def comfy_entrypoint() -> TorchCompileExtension: + return TorchCompileExtension() From d7aa414141f02a456801704a3da323fa2ed8f5cc Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Fri, 3 Oct 2025 21:45:02 +0300 Subject: [PATCH 314/325] convert nodes_eps.py to V3 schema (#10172) --- comfy_extras/nodes_eps.py | 62 ++++++++++++++++++++++++--------------- 1 file changed, 38 insertions(+), 24 deletions(-) diff --git a/comfy_extras/nodes_eps.py b/comfy_extras/nodes_eps.py index c8818f096..7852d85e5 100644 --- a/comfy_extras/nodes_eps.py +++ b/comfy_extras/nodes_eps.py @@ -1,4 +1,9 @@ -class EpsilonScaling: +from typing_extensions import override + +from comfy_api.latest import ComfyExtension, io + + +class EpsilonScaling(io.ComfyNode): """ Implements the Epsilon Scaling method from 'Elucidating the Exposure Bias in Diffusion Models' (https://arxiv.org/abs/2308.15321v6). @@ -8,26 +13,28 @@ class EpsilonScaling: recommended by the paper for its practicality and effectiveness. """ @classmethod - def INPUT_TYPES(s): - return { - "required": { - "model": ("MODEL",), - "scaling_factor": ("FLOAT", { - "default": 1.005, - "min": 0.5, - "max": 1.5, - "step": 0.001, - "display": "number" - }), - } - } + def define_schema(cls): + return io.Schema( + node_id="Epsilon Scaling", + category="model_patches/unet", + inputs=[ + io.Model.Input("model"), + io.Float.Input( + "scaling_factor", + default=1.005, + min=0.5, + max=1.5, + step=0.001, + display_mode=io.NumberDisplay.number, + ), + ], + outputs=[ + io.Model.Output(), + ], + ) - RETURN_TYPES = ("MODEL",) - FUNCTION = "patch" - - CATEGORY = "model_patches/unet" - - def patch(self, model, scaling_factor): + @classmethod + def execute(cls, model, scaling_factor) -> io.NodeOutput: # Prevent division by zero, though the UI's min value should prevent this. 
if scaling_factor == 0: scaling_factor = 1e-9 @@ -53,8 +60,15 @@ class EpsilonScaling: model_clone.set_model_sampler_post_cfg_function(epsilon_scaling_function) - return (model_clone,) + return io.NodeOutput(model_clone) -NODE_CLASS_MAPPINGS = { - "Epsilon Scaling": EpsilonScaling -} + +class EpsilonScalingExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + EpsilonScaling, + ] + +async def comfy_entrypoint() -> EpsilonScalingExtension: + return EpsilonScalingExtension() From 8c26d7bbe6663f589f0a9562921aafb3c48955c6 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Fri, 3 Oct 2025 21:48:21 +0300 Subject: [PATCH 315/325] convert nodes_pixverse.py to V3 schema (#10177) --- comfy_api_nodes/nodes_pixverse.py | 471 +++++++++++++++--------------- 1 file changed, 238 insertions(+), 233 deletions(-) diff --git a/comfy_api_nodes/nodes_pixverse.py b/comfy_api_nodes/nodes_pixverse.py index 7c5a52feb..eb98e9653 100644 --- a/comfy_api_nodes/nodes_pixverse.py +++ b/comfy_api_nodes/nodes_pixverse.py @@ -1,5 +1,7 @@ from inspect import cleandoc from typing import Optional +from typing_extensions import override +from io import BytesIO from comfy_api_nodes.apis.pixverse_api import ( PixverseTextVideoRequest, PixverseImageVideoRequest, @@ -26,12 +28,11 @@ from comfy_api_nodes.apinode_utils import ( tensor_to_bytesio, validate_string, ) -from comfy.comfy_types.node_typing import IO, ComfyNodeABC from comfy_api.input_impl import VideoFromFile +from comfy_api.latest import ComfyExtension, io as comfy_io import torch import aiohttp -from io import BytesIO AVERAGE_DURATION_T2V = 32 @@ -72,100 +73,101 @@ async def upload_image_to_pixverse(image: torch.Tensor, auth_kwargs=None): return response_upload.Resp.img_id -class PixverseTemplateNode: +class PixverseTemplateNode(comfy_io.ComfyNode): """ Select template for PixVerse Video generation. """ - RETURN_TYPES = (PixverseIO.TEMPLATE,) - RETURN_NAMES = ("pixverse_template",) - FUNCTION = "create_template" - CATEGORY = "api node/video/PixVerse" + @classmethod + def define_schema(cls) -> comfy_io.Schema: + return comfy_io.Schema( + node_id="PixverseTemplateNode", + display_name="PixVerse Template", + category="api node/video/PixVerse", + inputs=[ + comfy_io.Combo.Input("template", options=[list(pixverse_templates.keys())]), + ], + outputs=[comfy_io.Custom(PixverseIO.TEMPLATE).Output(display_name="pixverse_template")], + ) @classmethod - def INPUT_TYPES(s): - return { - "required": { - "template": (list(pixverse_templates.keys()),), - } - } - - def create_template(self, template: str): + def execute(cls, template: str) -> comfy_io.NodeOutput: template_id = pixverse_templates.get(template, None) if template_id is None: raise Exception(f"Template '{template}' is not recognized.") # just return the integer - return (template_id,) + return comfy_io.NodeOutput(template_id) -class PixverseTextToVideoNode(ComfyNodeABC): +class PixverseTextToVideoNode(comfy_io.ComfyNode): """ Generates videos based on prompt and output_size. 
""" - RETURN_TYPES = (IO.VIDEO,) - DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value - FUNCTION = "api_call" - API_NODE = True - CATEGORY = "api node/video/PixVerse" + @classmethod + def define_schema(cls) -> comfy_io.Schema: + return comfy_io.Schema( + node_id="PixverseTextToVideoNode", + display_name="PixVerse Text to Video", + category="api node/video/PixVerse", + description=cleandoc(cls.__doc__ or ""), + inputs=[ + comfy_io.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Prompt for the video generation", + ), + comfy_io.Combo.Input( + "aspect_ratio", + options=[ratio.value for ratio in PixverseAspectRatio], + ), + comfy_io.Combo.Input( + "quality", + options=[resolution.value for resolution in PixverseQuality], + default=PixverseQuality.res_540p, + ), + comfy_io.Combo.Input( + "duration_seconds", + options=[dur.value for dur in PixverseDuration], + ), + comfy_io.Combo.Input( + "motion_mode", + options=[mode.value for mode in PixverseMotionMode], + ), + comfy_io.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + control_after_generate=True, + tooltip="Seed for video generation.", + ), + comfy_io.String.Input( + "negative_prompt", + default="", + force_input=True, + tooltip="An optional text description of undesired elements on an image.", + optional=True, + ), + comfy_io.Custom(PixverseIO.TEMPLATE).Input( + "pixverse_template", + tooltip="An optional template to influence style of generation, created by the PixVerse Template node.", + optional=True, + ), + ], + outputs=[comfy_io.Video.Output()], + hidden=[ + comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + comfy_io.Hidden.unique_id, + ], + is_api_node=True, + ) @classmethod - def INPUT_TYPES(s): - return { - "required": { - "prompt": ( - IO.STRING, - { - "multiline": True, - "default": "", - "tooltip": "Prompt for the video generation", - }, - ), - "aspect_ratio": ([ratio.value for ratio in PixverseAspectRatio],), - "quality": ( - [resolution.value for resolution in PixverseQuality], - { - "default": PixverseQuality.res_540p, - }, - ), - "duration_seconds": ([dur.value for dur in PixverseDuration],), - "motion_mode": ([mode.value for mode in PixverseMotionMode],), - "seed": ( - IO.INT, - { - "default": 0, - "min": 0, - "max": 2147483647, - "control_after_generate": True, - "tooltip": "Seed for video generation.", - }, - ), - }, - "optional": { - "negative_prompt": ( - IO.STRING, - { - "default": "", - "forceInput": True, - "tooltip": "An optional text description of undesired elements on an image.", - }, - ), - "pixverse_template": ( - PixverseIO.TEMPLATE, - { - "tooltip": "An optional template to influence style of generation, created by the PixVerse Template node." 
- }, - ), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } - - async def api_call( - self, + async def execute( + cls, prompt: str, aspect_ratio: str, quality: str, @@ -174,9 +176,7 @@ class PixverseTextToVideoNode(ComfyNodeABC): seed, negative_prompt: str = None, pixverse_template: int = None, - unique_id: Optional[str] = None, - **kwargs, - ): + ) -> comfy_io.NodeOutput: validate_string(prompt, strip_whitespace=False) # 1080p is limited to 5 seconds duration # only normal motion_mode supported for 1080p or for non-5 second duration @@ -186,6 +186,10 @@ class PixverseTextToVideoNode(ComfyNodeABC): elif duration_seconds != PixverseDuration.dur_5: motion_mode = PixverseMotionMode.normal + auth = { + "auth_token": cls.hidden.auth_token_comfy_org, + "comfy_api_key": cls.hidden.api_key_comfy_org, + } operation = SynchronousOperation( endpoint=ApiEndpoint( path="/proxy/pixverse/video/text/generate", @@ -203,7 +207,7 @@ class PixverseTextToVideoNode(ComfyNodeABC): template_id=pixverse_template, seed=seed, ), - auth_kwargs=kwargs, + auth_kwargs=auth, ) response_api = await operation.execute() @@ -224,8 +228,8 @@ class PixverseTextToVideoNode(ComfyNodeABC): PixverseStatus.deleted, ], status_extractor=lambda x: x.Resp.status, - auth_kwargs=kwargs, - node_id=unique_id, + auth_kwargs=auth, + node_id=cls.hidden.unique_id, result_url_extractor=get_video_url_from_response, estimated_duration=AVERAGE_DURATION_T2V, ) @@ -233,77 +237,75 @@ class PixverseTextToVideoNode(ComfyNodeABC): async with aiohttp.ClientSession() as session: async with session.get(response_poll.Resp.url) as vid_response: - return (VideoFromFile(BytesIO(await vid_response.content.read())),) + return comfy_io.NodeOutput(VideoFromFile(BytesIO(await vid_response.content.read()))) -class PixverseImageToVideoNode(ComfyNodeABC): +class PixverseImageToVideoNode(comfy_io.ComfyNode): """ Generates videos based on prompt and output_size. 
""" - RETURN_TYPES = (IO.VIDEO,) - DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value - FUNCTION = "api_call" - API_NODE = True - CATEGORY = "api node/video/PixVerse" + @classmethod + def define_schema(cls) -> comfy_io.Schema: + return comfy_io.Schema( + node_id="PixverseImageToVideoNode", + display_name="PixVerse Image to Video", + category="api node/video/PixVerse", + description=cleandoc(cls.__doc__ or ""), + inputs=[ + comfy_io.Image.Input("image"), + comfy_io.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Prompt for the video generation", + ), + comfy_io.Combo.Input( + "quality", + options=[resolution.value for resolution in PixverseQuality], + default=PixverseQuality.res_540p, + ), + comfy_io.Combo.Input( + "duration_seconds", + options=[dur.value for dur in PixverseDuration], + ), + comfy_io.Combo.Input( + "motion_mode", + options=[mode.value for mode in PixverseMotionMode], + ), + comfy_io.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + control_after_generate=True, + tooltip="Seed for video generation.", + ), + comfy_io.String.Input( + "negative_prompt", + default="", + force_input=True, + tooltip="An optional text description of undesired elements on an image.", + optional=True, + ), + comfy_io.Custom(PixverseIO.TEMPLATE).Input( + "pixverse_template", + tooltip="An optional template to influence style of generation, created by the PixVerse Template node.", + optional=True, + ), + ], + outputs=[comfy_io.Video.Output()], + hidden=[ + comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + comfy_io.Hidden.unique_id, + ], + is_api_node=True, + ) @classmethod - def INPUT_TYPES(s): - return { - "required": { - "image": (IO.IMAGE,), - "prompt": ( - IO.STRING, - { - "multiline": True, - "default": "", - "tooltip": "Prompt for the video generation", - }, - ), - "quality": ( - [resolution.value for resolution in PixverseQuality], - { - "default": PixverseQuality.res_540p, - }, - ), - "duration_seconds": ([dur.value for dur in PixverseDuration],), - "motion_mode": ([mode.value for mode in PixverseMotionMode],), - "seed": ( - IO.INT, - { - "default": 0, - "min": 0, - "max": 2147483647, - "control_after_generate": True, - "tooltip": "Seed for video generation.", - }, - ), - }, - "optional": { - "negative_prompt": ( - IO.STRING, - { - "default": "", - "forceInput": True, - "tooltip": "An optional text description of undesired elements on an image.", - }, - ), - "pixverse_template": ( - PixverseIO.TEMPLATE, - { - "tooltip": "An optional template to influence style of generation, created by the PixVerse Template node." 
- }, - ), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } - - async def api_call( - self, + async def execute( + cls, image: torch.Tensor, prompt: str, quality: str, @@ -312,11 +314,13 @@ class PixverseImageToVideoNode(ComfyNodeABC): seed, negative_prompt: str = None, pixverse_template: int = None, - unique_id: Optional[str] = None, - **kwargs, - ): + ) -> comfy_io.NodeOutput: validate_string(prompt, strip_whitespace=False) - img_id = await upload_image_to_pixverse(image, auth_kwargs=kwargs) + auth = { + "auth_token": cls.hidden.auth_token_comfy_org, + "comfy_api_key": cls.hidden.api_key_comfy_org, + } + img_id = await upload_image_to_pixverse(image, auth_kwargs=auth) # 1080p is limited to 5 seconds duration # only normal motion_mode supported for 1080p or for non-5 second duration @@ -343,7 +347,7 @@ class PixverseImageToVideoNode(ComfyNodeABC): template_id=pixverse_template, seed=seed, ), - auth_kwargs=kwargs, + auth_kwargs=auth, ) response_api = await operation.execute() @@ -364,8 +368,8 @@ class PixverseImageToVideoNode(ComfyNodeABC): PixverseStatus.deleted, ], status_extractor=lambda x: x.Resp.status, - auth_kwargs=kwargs, - node_id=unique_id, + auth_kwargs=auth, + node_id=cls.hidden.unique_id, result_url_extractor=get_video_url_from_response, estimated_duration=AVERAGE_DURATION_I2V, ) @@ -373,72 +377,71 @@ class PixverseImageToVideoNode(ComfyNodeABC): async with aiohttp.ClientSession() as session: async with session.get(response_poll.Resp.url) as vid_response: - return (VideoFromFile(BytesIO(await vid_response.content.read())),) + return comfy_io.NodeOutput(VideoFromFile(BytesIO(await vid_response.content.read()))) -class PixverseTransitionVideoNode(ComfyNodeABC): +class PixverseTransitionVideoNode(comfy_io.ComfyNode): """ Generates videos based on prompt and output_size. 
""" - RETURN_TYPES = (IO.VIDEO,) - DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value - FUNCTION = "api_call" - API_NODE = True - CATEGORY = "api node/video/PixVerse" + @classmethod + def define_schema(cls) -> comfy_io.Schema: + return comfy_io.Schema( + node_id="PixverseTransitionVideoNode", + display_name="PixVerse Transition Video", + category="api node/video/PixVerse", + description=cleandoc(cls.__doc__ or ""), + inputs=[ + comfy_io.Image.Input("first_frame"), + comfy_io.Image.Input("last_frame"), + comfy_io.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Prompt for the video generation", + ), + comfy_io.Combo.Input( + "quality", + options=[resolution.value for resolution in PixverseQuality], + default=PixverseQuality.res_540p, + ), + comfy_io.Combo.Input( + "duration_seconds", + options=[dur.value for dur in PixverseDuration], + ), + comfy_io.Combo.Input( + "motion_mode", + options=[mode.value for mode in PixverseMotionMode], + ), + comfy_io.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + control_after_generate=True, + tooltip="Seed for video generation.", + ), + comfy_io.String.Input( + "negative_prompt", + default="", + force_input=True, + tooltip="An optional text description of undesired elements on an image.", + optional=True, + ), + ], + outputs=[comfy_io.Video.Output()], + hidden=[ + comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + comfy_io.Hidden.unique_id, + ], + is_api_node=True, + ) @classmethod - def INPUT_TYPES(s): - return { - "required": { - "first_frame": (IO.IMAGE,), - "last_frame": (IO.IMAGE,), - "prompt": ( - IO.STRING, - { - "multiline": True, - "default": "", - "tooltip": "Prompt for the video generation", - }, - ), - "quality": ( - [resolution.value for resolution in PixverseQuality], - { - "default": PixverseQuality.res_540p, - }, - ), - "duration_seconds": ([dur.value for dur in PixverseDuration],), - "motion_mode": ([mode.value for mode in PixverseMotionMode],), - "seed": ( - IO.INT, - { - "default": 0, - "min": 0, - "max": 2147483647, - "control_after_generate": True, - "tooltip": "Seed for video generation.", - }, - ), - }, - "optional": { - "negative_prompt": ( - IO.STRING, - { - "default": "", - "forceInput": True, - "tooltip": "An optional text description of undesired elements on an image.", - }, - ), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } - - async def api_call( - self, + async def execute( + cls, first_frame: torch.Tensor, last_frame: torch.Tensor, prompt: str, @@ -447,12 +450,14 @@ class PixverseTransitionVideoNode(ComfyNodeABC): motion_mode: str, seed, negative_prompt: str = None, - unique_id: Optional[str] = None, - **kwargs, - ): + ) -> comfy_io.NodeOutput: validate_string(prompt, strip_whitespace=False) - first_frame_id = await upload_image_to_pixverse(first_frame, auth_kwargs=kwargs) - last_frame_id = await upload_image_to_pixverse(last_frame, auth_kwargs=kwargs) + auth = { + "auth_token": cls.hidden.auth_token_comfy_org, + "comfy_api_key": cls.hidden.api_key_comfy_org, + } + first_frame_id = await upload_image_to_pixverse(first_frame, auth_kwargs=auth) + last_frame_id = await upload_image_to_pixverse(last_frame, auth_kwargs=auth) # 1080p is limited to 5 seconds duration # only normal motion_mode supported for 1080p or for non-5 second duration @@ -479,7 +484,7 @@ class PixverseTransitionVideoNode(ComfyNodeABC): negative_prompt=negative_prompt if negative_prompt else 
None, seed=seed, ), - auth_kwargs=kwargs, + auth_kwargs=auth, ) response_api = await operation.execute() @@ -500,8 +505,8 @@ class PixverseTransitionVideoNode(ComfyNodeABC): PixverseStatus.deleted, ], status_extractor=lambda x: x.Resp.status, - auth_kwargs=kwargs, - node_id=unique_id, + auth_kwargs=auth, + node_id=cls.hidden.unique_id, result_url_extractor=get_video_url_from_response, estimated_duration=AVERAGE_DURATION_T2V, ) @@ -509,19 +514,19 @@ class PixverseTransitionVideoNode(ComfyNodeABC): async with aiohttp.ClientSession() as session: async with session.get(response_poll.Resp.url) as vid_response: - return (VideoFromFile(BytesIO(await vid_response.content.read())),) + return comfy_io.NodeOutput(VideoFromFile(BytesIO(await vid_response.content.read()))) -NODE_CLASS_MAPPINGS = { - "PixverseTextToVideoNode": PixverseTextToVideoNode, - "PixverseImageToVideoNode": PixverseImageToVideoNode, - "PixverseTransitionVideoNode": PixverseTransitionVideoNode, - "PixverseTemplateNode": PixverseTemplateNode, -} +class PixVerseExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[comfy_io.ComfyNode]]: + return [ + PixverseTextToVideoNode, + PixverseImageToVideoNode, + PixverseTransitionVideoNode, + PixverseTemplateNode, + ] -NODE_DISPLAY_NAME_MAPPINGS = { - "PixverseTextToVideoNode": "PixVerse Text to Video", - "PixverseImageToVideoNode": "PixVerse Image to Video", - "PixverseTransitionVideoNode": "PixVerse Transition Video", - "PixverseTemplateNode": "PixVerse Template", -} + +async def comfy_entrypoint() -> PixVerseExtension: + return PixVerseExtension() From 5c8e986e273d8af8b976fddbaed726e8278cf1fe Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Fri, 3 Oct 2025 21:50:38 +0300 Subject: [PATCH 316/325] convert nodes_tomesd.py to V3 schema (#10180) --- comfy_extras/nodes_tomesd.py | 50 +++++++++++++++++++++++------------- 1 file changed, 32 insertions(+), 18 deletions(-) diff --git a/comfy_extras/nodes_tomesd.py b/comfy_extras/nodes_tomesd.py index 9f77c06fc..87bf29b8f 100644 --- a/comfy_extras/nodes_tomesd.py +++ b/comfy_extras/nodes_tomesd.py @@ -1,7 +1,9 @@ #Taken from: https://github.com/dbolya/tomesd import torch -from typing import Tuple, Callable +from typing import Tuple, Callable, Optional +from typing_extensions import override +from comfy_api.latest import ComfyExtension, io import math def do_nothing(x: torch.Tensor, mode:str=None): @@ -144,33 +146,45 @@ def get_functions(x, ratio, original_shape): -class TomePatchModel: +class TomePatchModel(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "model": ("MODEL",), - "ratio": ("FLOAT", {"default": 0.3, "min": 0.0, "max": 1.0, "step": 0.01}), - }} - RETURN_TYPES = ("MODEL",) - FUNCTION = "patch" + def define_schema(cls): + return io.Schema( + node_id="TomePatchModel", + category="model_patches/unet", + inputs=[ + io.Model.Input("model"), + io.Float.Input("ratio", default=0.3, min=0.0, max=1.0, step=0.01), + ], + outputs=[io.Model.Output()], + ) - CATEGORY = "model_patches/unet" - - def patch(self, model, ratio): - self.u = None + @classmethod + def execute(cls, model, ratio) -> io.NodeOutput: + u: Optional[Callable] = None def tomesd_m(q, k, v, extra_options): + nonlocal u #NOTE: In the reference code get_functions takes x (input of the transformer block) as the argument instead of q #however from my basic testing it seems that using q instead gives better results - m, self.u = get_functions(q, ratio, extra_options["original_shape"]) + 
m, u = get_functions(q, ratio, extra_options["original_shape"]) return m(q), k, v def tomesd_u(n, extra_options): - return self.u(n) + nonlocal u + return u(n) m = model.clone() m.set_model_attn1_patch(tomesd_m) m.set_model_attn1_output_patch(tomesd_u) - return (m, ) + return io.NodeOutput(m) -NODE_CLASS_MAPPINGS = { - "TomePatchModel": TomePatchModel, -} +class TomePatchModelExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + TomePatchModel, + ] + + +async def comfy_entrypoint() -> TomePatchModelExtension: + return TomePatchModelExtension() From 4614ee09ca1aaca7ee8067d6c5c30695582326ff Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Fri, 3 Oct 2025 23:24:42 +0300 Subject: [PATCH 317/325] convert nodes_edit_model.py to V3 schema (#10147) --- comfy_extras/nodes_edit_model.py | 46 ++++++++++++++++++++------------ 1 file changed, 29 insertions(+), 17 deletions(-) diff --git a/comfy_extras/nodes_edit_model.py b/comfy_extras/nodes_edit_model.py index b69f79715..36da66f34 100644 --- a/comfy_extras/nodes_edit_model.py +++ b/comfy_extras/nodes_edit_model.py @@ -1,26 +1,38 @@ import node_helpers +from typing_extensions import override +from comfy_api.latest import ComfyExtension, io -class ReferenceLatent: +class ReferenceLatent(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": {"conditioning": ("CONDITIONING", ), - }, - "optional": {"latent": ("LATENT", ),} - } + def define_schema(cls): + return io.Schema( + node_id="ReferenceLatent", + category="advanced/conditioning/edit_models", + description="This node sets the guiding latent for an edit model. If the model supports it you can chain multiple to set multiple reference images.", + inputs=[ + io.Conditioning.Input("conditioning"), + io.Latent.Input("latent", optional=True), + ], + outputs=[ + io.Conditioning.Output(), + ] + ) - RETURN_TYPES = ("CONDITIONING",) - FUNCTION = "append" - - CATEGORY = "advanced/conditioning/edit_models" - DESCRIPTION = "This node sets the guiding latent for an edit model. If the model supports it you can chain multiple to set multiple reference images." - - def append(self, conditioning, latent=None): + @classmethod + def execute(cls, conditioning, latent=None) -> io.NodeOutput: if latent is not None: conditioning = node_helpers.conditioning_set_values(conditioning, {"reference_latents": [latent["samples"]]}, append=True) - return (conditioning, ) + return io.NodeOutput(conditioning) -NODE_CLASS_MAPPINGS = { - "ReferenceLatent": ReferenceLatent, -} +class EditModelExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + ReferenceLatent, + ] + + +def comfy_entrypoint() -> EditModelExtension: + return EditModelExtension() From 93d859cfaaad150c2a1e5e54c8f14765fa79ecb5 Mon Sep 17 00:00:00 2001 From: Finn-Hecker Date: Fri, 3 Oct 2025 23:32:19 +0200 Subject: [PATCH 318/325] Fix type annotation syntax in MotionEncoder_tc __init__ (#10186) ## Summary Fixed incorrect type hint syntax in `MotionEncoder_tc.__init__()` parameter list. ## Changes - Line 647: Changed `num_heads=int` to `num_heads: int` - This corrects the parameter annotation from a default value assignment to proper type hint syntax ## Details The parameter was using assignment syntax (`=`) instead of type annotation syntax (`:`), which would incorrectly set the default value to the `int` class itself rather than annotating the expected type. 
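
## Example (illustrative)

For illustration, a minimal, self-contained sketch of the difference; the class names below are hypothetical and are not taken from `comfy/ldm/wan/model.py`:

```python
import inspect


class WithDefault:
    # The `=` spelling makes the built-in type object `int` the *default value*.
    def __init__(self, num_heads=int):
        self.num_heads = num_heads


class WithAnnotation:
    # The `:` spelling is a type annotation; the argument stays required.
    def __init__(self, num_heads: int):
        self.num_heads = num_heads


print(inspect.signature(WithDefault.__init__))     # (self, num_heads=<class 'int'>)
print(inspect.signature(WithAnnotation.__init__))  # (self, num_heads: int)

print(WithDefault().num_heads)  # <class 'int'> -- silently accepted, almost certainly unintended
# WithAnnotation() would raise TypeError: missing required argument 'num_heads'
```

With the `=` spelling, a caller that omits `num_heads` silently gets the `int` class itself as the value; with the `:` annotation the parameter remains required and type checkers can validate callers.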
--- comfy/ldm/wan/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/ldm/wan/model.py b/comfy/ldm/wan/model.py index 0dc650ced..90c347d3d 100644 --- a/comfy/ldm/wan/model.py +++ b/comfy/ldm/wan/model.py @@ -903,7 +903,7 @@ class MotionEncoder_tc(nn.Module): def __init__(self, in_dim: int, hidden_dim: int, - num_heads=int, + num_heads: int, need_global=True, dtype=None, device=None, From 08726b64fe767f47bf074a05bedd6db45314c4c9 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Fri, 3 Oct 2025 15:22:43 -0700 Subject: [PATCH 319/325] Update amd nightly command in readme. (#10189) --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 8f24a33ee..1224a6176 100644 --- a/README.md +++ b/README.md @@ -211,9 +211,9 @@ AMD users can install rocm and pytorch with pip if you don't have it already ins ```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.4``` -This is the command to install the nightly with ROCm 6.4 which might have some performance improvements: +This is the command to install the nightly with ROCm 7.0 which might have some performance improvements: -```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm6.4``` +```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm7.0``` ### Intel GPUs (Windows and Linux) From bbd683098e7d18700f025b2f0a4f6a44a3176602 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Fri, 3 Oct 2025 20:37:43 -0700 Subject: [PATCH 320/325] Add instructions to install nightly AMD pytorch for windows. (#10190) * Add instructions to install nightly AMD pytorch for windows. * Update README.md --- README.md | 26 +++++++++++++++++++------- 1 file changed, 19 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index 1224a6176..4a5a17cda 100644 --- a/README.md +++ b/README.md @@ -206,7 +206,8 @@ Put your SD checkpoints (the huge ckpt/safetensors files) in: models/checkpoints Put your VAE in: models/vae -### AMD GPUs (Linux only) +### AMD GPUs (Linux) + AMD users can install rocm and pytorch with pip if you don't have it already installed, this is the command to install the stable version: ```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.4``` @@ -215,6 +216,23 @@ This is the command to install the nightly with ROCm 7.0 which might have some p ```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm7.0``` + +### AMD GPUs (Experimental: Windows and Linux), RDNA 3, 3.5 and 4 only. + +These have less hardware support than the builds above but they work on windows. You also need to install the pytorch version specific to your hardware. + +RDNA 3 (RX 7000 series): + +```pip install --pre torch torchvision torchaudio --index-url https://rocm.nightlies.amd.com/v2/gfx110X-dgpu/``` + +RDNA 3.5 (Strix halo/Ryzen AI Max+ 365): + +```pip install --pre torch torchvision torchaudio --index-url https://rocm.nightlies.amd.com/v2/gfx1151/``` + +RDNA 4 (RX 9000 series): + +```pip install --pre torch torchvision torchaudio --index-url https://rocm.nightlies.amd.com/v2/gfx120X-all/``` + ### Intel GPUs (Windows and Linux) (Option 1) Intel Arc GPU users can install native PyTorch with torch.xpu support using pip. 
More information can be found [here](https://pytorch.org/docs/main/notes/get_start_xpu.html) @@ -270,12 +288,6 @@ You can install ComfyUI in Apple Mac silicon (M1 or M2) with any recent macOS ve > **Note**: Remember to add your models, VAE, LoRAs etc. to the corresponding Comfy folders, as discussed in [ComfyUI manual installation](#manual-install-windows-linux). -#### DirectML (AMD Cards on Windows) - -This is very badly supported and is not recommended. There are some unofficial builds of pytorch ROCm on windows that exist that will give you a much better experience than this. This readme will be updated once official pytorch ROCm builds for windows come out. - -```pip install torch-directml``` Then you can launch ComfyUI with: ```python main.py --directml``` - #### Ascend NPUs For models compatible with Ascend Extension for PyTorch (torch_npu). To get started, ensure your environment meets the prerequisites outlined on the [installation](https://ascend.github.io/docs/sources/ascend/quick_install.html) page. Here's a step-by-step guide tailored to your platform and installation method: From 22f99fb97edaccf450152c8bf7c4068c1d331899 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Sat, 4 Oct 2025 22:22:57 +0300 Subject: [PATCH 321/325] fix(api-nodes): enable 2 more pylint rules, removed non needed code (#10192) --- comfy_api_nodes/nodes_gemini.py | 3 +- comfy_api_nodes/nodes_moonvalley.py | 49 ++--------------------------- pyproject.toml | 2 -- 3 files changed, 4 insertions(+), 50 deletions(-) diff --git a/comfy_api_nodes/nodes_gemini.py b/comfy_api_nodes/nodes_gemini.py index 151cb4044..309e9a2d2 100644 --- a/comfy_api_nodes/nodes_gemini.py +++ b/comfy_api_nodes/nodes_gemini.py @@ -39,6 +39,7 @@ from comfy_api_nodes.apinode_utils import ( tensor_to_base64_string, bytesio_to_image_tensor, ) +from comfy_api.util import VideoContainer, VideoCodec GEMINI_BASE_ENDPOINT = "/proxy/vertexai/gemini" @@ -310,7 +311,7 @@ class GeminiNode(ComfyNodeABC): Returns: List of GeminiPart objects containing the encoded video. """ - from comfy_api.util import VideoContainer, VideoCodec + base_64_string = video_to_base64_string( video_input, container_format=VideoContainer.MP4, diff --git a/comfy_api_nodes/nodes_moonvalley.py b/comfy_api_nodes/nodes_moonvalley.py index 08e838fef..6467dd614 100644 --- a/comfy_api_nodes/nodes_moonvalley.py +++ b/comfy_api_nodes/nodes_moonvalley.py @@ -2,11 +2,7 @@ import logging from typing import Any, Callable, Optional, TypeVar import torch from typing_extensions import override -from comfy_api_nodes.util.validation_utils import ( - get_image_dimensions, - validate_image_dimensions, -) - +from comfy_api_nodes.util.validation_utils import validate_image_dimensions from comfy_api_nodes.apis import ( MoonvalleyTextToVideoRequest, @@ -132,47 +128,6 @@ def validate_prompts( return True -def validate_input_media(width, height, with_frame_conditioning, num_frames_in=None): - # inference validation - # T = num_frames - # in all cases, the following must be true: T divisible by 16 and H,W by 8. in addition... 
- # with image conditioning: H*W must be divisible by 8192 - # without image conditioning: T divisible by 32 - if num_frames_in and not num_frames_in % 16 == 0: - return False, ("The input video total frame count must be divisible by 16!") - - if height % 8 != 0 or width % 8 != 0: - return False, ( - f"Height ({height}) and width ({width}) must be " "divisible by 8" - ) - - if with_frame_conditioning: - if (height * width) % 8192 != 0: - return False, ( - f"Height * width ({height * width}) must be " - "divisible by 8192 for frame conditioning" - ) - else: - if num_frames_in and not num_frames_in % 32 == 0: - return False, ("The input video total frame count must be divisible by 32!") - - -def validate_input_image( - image: torch.Tensor, with_frame_conditioning: bool = False -) -> None: - """ - Validates the input image adheres to the expectations of the API: - - The image resolution should not be less than 300*300px - - The aspect ratio of the image should be between 1:2.5 ~ 2.5:1 - - """ - height, width = get_image_dimensions(image) - validate_input_media(width, height, with_frame_conditioning) - validate_image_dimensions( - image, min_width=300, min_height=300, max_height=MAX_HEIGHT, max_width=MAX_WIDTH - ) - - def validate_video_to_video_input(video: VideoInput) -> VideoInput: """ Validates and processes video input for Moonvalley Video-to-Video generation. @@ -499,7 +454,7 @@ class MoonvalleyImg2VideoNode(comfy_io.ComfyNode): seed: int, steps: int, ) -> comfy_io.NodeOutput: - validate_input_image(image, True) + validate_image_dimensions(image, min_width=300, min_height=300, max_height=MAX_HEIGHT, max_width=MAX_WIDTH) validate_prompts(prompt, negative_prompt, MOONVALLEY_MAREY_MAX_PROMPT_LENGTH) width_height = parse_width_height_from_res(resolution) diff --git a/pyproject.toml b/pyproject.toml index 240919a43..383e7d10a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -70,7 +70,5 @@ messages_control.disable = [ "invalid-overridden-method", "unused-variable", "pointless-string-statement", - "inconsistent-return-statements", - "import-outside-toplevel", "redefined-outer-name", ] From 2ed74f7ac78d3ff713d0a8583695c31055914b76 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Sat, 4 Oct 2025 22:29:09 +0300 Subject: [PATCH 322/325] convert nodes_rodin.py to V3 schema (#10195) --- comfy_api_nodes/nodes_rodin.py | 941 +++++++++++++++++---------------- 1 file changed, 478 insertions(+), 463 deletions(-) diff --git a/comfy_api_nodes/nodes_rodin.py b/comfy_api_nodes/nodes_rodin.py index 633ac46d3..bd758f762 100644 --- a/comfy_api_nodes/nodes_rodin.py +++ b/comfy_api_nodes/nodes_rodin.py @@ -7,14 +7,15 @@ Rodin API docs: https://developer.hyper3d.ai/ from __future__ import annotations from inspect import cleandoc -from comfy.comfy_types.node_typing import IO import folder_paths as comfy_paths import aiohttp import os import asyncio -import io import logging import math +from typing import Optional +from io import BytesIO +from typing_extensions import override from PIL import Image from comfy_api_nodes.apis.rodin_api import ( Rodin3DGenerateRequest, @@ -31,428 +32,436 @@ from comfy_api_nodes.apis.client import ( SynchronousOperation, PollingOperation, ) +from comfy_api.latest import ComfyExtension, io as comfy_io -COMMON_PARAMETERS = { - "Seed": ( - IO.INT, - { - "default":0, - "min":0, - "max":65535, - "display":"number" - } +COMMON_PARAMETERS = [ + comfy_io.Int.Input( + "Seed", + default=0, + min=0, + max=65535, + 
display_mode=comfy_io.NumberDisplay.number, + optional=True, ), - "Material_Type": ( - IO.COMBO, - { - "options": ["PBR", "Shaded"], - "default": "PBR" - } + comfy_io.Combo.Input("Material_Type", options=["PBR", "Shaded"], default="PBR", optional=True), + comfy_io.Combo.Input( + "Polygon_count", + options=["4K-Quad", "8K-Quad", "18K-Quad", "50K-Quad", "200K-Triangle"], + default="18K-Quad", + optional=True, ), - "Polygon_count": ( - IO.COMBO, - { - "options": ["4K-Quad", "8K-Quad", "18K-Quad", "50K-Quad", "200K-Triangle"], - "default": "18K-Quad" - } +] + + +def get_quality_mode(poly_count): + polycount = poly_count.split("-") + poly = polycount[1] + count = polycount[0] + if poly == "Triangle": + mesh_mode = "Raw" + elif poly == "Quad": + mesh_mode = "Quad" + else: + mesh_mode = "Quad" + + if count == "4K": + quality_override = 4000 + elif count == "8K": + quality_override = 8000 + elif count == "18K": + quality_override = 18000 + elif count == "50K": + quality_override = 50000 + elif count == "2K": + quality_override = 2000 + elif count == "20K": + quality_override = 20000 + elif count == "150K": + quality_override = 150000 + elif count == "500K": + quality_override = 500000 + else: + quality_override = 18000 + + return mesh_mode, quality_override + + +def tensor_to_filelike(tensor, max_pixels: int = 2048*2048): + """ + Converts a PyTorch tensor to a file-like object. + + Args: + - tensor (torch.Tensor): A tensor representing an image of shape (H, W, C) + where C is the number of channels (3 for RGB), H is height, and W is width. + + Returns: + - io.BytesIO: A file-like object containing the image data. + """ + array = tensor.cpu().numpy() + array = (array * 255).astype('uint8') + image = Image.fromarray(array, 'RGB') + + original_width, original_height = image.size + original_pixels = original_width * original_height + if original_pixels > max_pixels: + scale = math.sqrt(max_pixels / original_pixels) + new_width = int(original_width * scale) + new_height = int(original_height * scale) + else: + new_width, new_height = original_width, original_height + + if new_width != original_width or new_height != original_height: + image = image.resize((new_width, new_height), Image.Resampling.LANCZOS) + + img_byte_arr = BytesIO() + image.save(img_byte_arr, format='PNG') # PNG is used for lossless compression + img_byte_arr.seek(0) + return img_byte_arr + + +async def create_generate_task( + images=None, + seed=1, + material="PBR", + quality_override=18000, + tier="Regular", + mesh_mode="Quad", + TAPose = False, + auth_kwargs: Optional[dict[str, str]] = None, +): + if images is None: + raise Exception("Rodin 3D generate requires at least 1 image.") + if len(images) > 5: + raise Exception("Rodin 3D generate requires up to 5 image.") + + path = "/proxy/rodin/api/v2/rodin" + operation = SynchronousOperation( + endpoint=ApiEndpoint( + path=path, + method=HttpMethod.POST, + request_model=Rodin3DGenerateRequest, + response_model=Rodin3DGenerateResponse, + ), + request=Rodin3DGenerateRequest( + seed=seed, + tier=tier, + material=material, + quality_override=quality_override, + mesh_mode=mesh_mode, + TAPose=TAPose, + ), + files=[ + ( + "images", + open(image, "rb") if isinstance(image, str) else tensor_to_filelike(image) + ) + for image in images if image is not None + ], + content_type="multipart/form-data", + auth_kwargs=auth_kwargs, ) -} -def create_task_error(response: Rodin3DGenerateResponse): - """Check if the response has error""" - return hasattr(response, "error") + response = await 
operation.execute() + + if hasattr(response, "error"): + error_message = f"Rodin3D Create 3D generate Task Failed. Message: {response.message}, error: {response.error}" + logging.error(error_message) + raise Exception(error_message) + + logging.info("[ Rodin3D API - Submit Jobs ] Submit Generate Task Success!") + subscription_key = response.jobs.subscription_key + task_uuid = response.uuid + logging.info(f"[ Rodin3D API - Submit Jobs ] UUID: {task_uuid}") + return task_uuid, subscription_key -class Rodin3DAPI: - """ - Generate 3D Assets using Rodin API - """ - RETURN_TYPES = (IO.STRING,) - RETURN_NAMES = ("3D Model Path",) - CATEGORY = "api node/3d/Rodin" - DESCRIPTION = cleandoc(__doc__ or "") - FUNCTION = "api_call" - API_NODE = True - - def tensor_to_filelike(self, tensor, max_pixels: int = 2048*2048): - """ - Converts a PyTorch tensor to a file-like object. - - Args: - - tensor (torch.Tensor): A tensor representing an image of shape (H, W, C) - where C is the number of channels (3 for RGB), H is height, and W is width. - - Returns: - - io.BytesIO: A file-like object containing the image data. - """ - array = tensor.cpu().numpy() - array = (array * 255).astype('uint8') - image = Image.fromarray(array, 'RGB') - - original_width, original_height = image.size - original_pixels = original_width * original_height - if original_pixels > max_pixels: - scale = math.sqrt(max_pixels / original_pixels) - new_width = int(original_width * scale) - new_height = int(original_height * scale) - else: - new_width, new_height = original_width, original_height - - if new_width != original_width or new_height != original_height: - image = image.resize((new_width, new_height), Image.Resampling.LANCZOS) - - img_byte_arr = io.BytesIO() - image.save(img_byte_arr, format='PNG') # PNG is used for lossless compression - img_byte_arr.seek(0) - return img_byte_arr - - def check_rodin_status(self, response: Rodin3DCheckStatusResponse) -> str: - has_failed = any(job.status == JobStatus.Failed for job in response.jobs) - all_done = all(job.status == JobStatus.Done for job in response.jobs) - status_list = [str(job.status) for job in response.jobs] - logging.info(f"[ Rodin3D API - CheckStatus ] Generate Status: {status_list}") - if has_failed: - logging.error(f"[ Rodin3D API - CheckStatus ] Generate Failed: {status_list}, Please try again.") - raise Exception("[ Rodin3D API ] Generate Failed, Please Try again.") - elif all_done: - return "DONE" - else: - return "Generating" - - async def create_generate_task(self, images=None, seed=1, material="PBR", quality_override=18000, tier="Regular", mesh_mode="Quad", TAPose = False, **kwargs): - if images is None: - raise Exception("Rodin 3D generate requires at least 1 image.") - if len(images) > 5: - raise Exception("Rodin 3D generate requires up to 5 image.") - - path = "/proxy/rodin/api/v2/rodin" - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path=path, - method=HttpMethod.POST, - request_model=Rodin3DGenerateRequest, - response_model=Rodin3DGenerateResponse, - ), - request=Rodin3DGenerateRequest( - seed=seed, - tier=tier, - material=material, - quality_override=quality_override, - mesh_mode=mesh_mode, - TAPose=TAPose, - ), - files=[ - ( - "images", - open(image, "rb") if isinstance(image, str) else self.tensor_to_filelike(image) - ) - for image in images if image is not None - ], - content_type = "multipart/form-data", - auth_kwargs=kwargs, - ) - - response = await operation.execute() - - if create_task_error(response): - error_message = f"Rodin3D Create 3D 
generate Task Failed. Message: {response.message}, error: {response.error}" - logging.error(error_message) - raise Exception(error_message) - - logging.info("[ Rodin3D API - Submit Jobs ] Submit Generate Task Success!") - subscription_key = response.jobs.subscription_key - task_uuid = response.uuid - logging.info(f"[ Rodin3D API - Submit Jobs ] UUID: {task_uuid}") - return task_uuid, subscription_key - - async def poll_for_task_status(self, subscription_key, **kwargs) -> Rodin3DCheckStatusResponse: - - path = "/proxy/rodin/api/v2/status" - - poll_operation = PollingOperation( - poll_endpoint=ApiEndpoint( - path = path, - method=HttpMethod.POST, - request_model=Rodin3DCheckStatusRequest, - response_model=Rodin3DCheckStatusResponse, - ), - request=Rodin3DCheckStatusRequest( - subscription_key = subscription_key - ), - completed_statuses=["DONE"], - failed_statuses=["FAILED"], - status_extractor=self.check_rodin_status, - poll_interval=3.0, - auth_kwargs=kwargs, - ) - - logging.info("[ Rodin3D API - CheckStatus ] Generate Start!") - - return await poll_operation.execute() - - async def get_rodin_download_list(self, uuid, **kwargs) -> Rodin3DDownloadResponse: - logging.info("[ Rodin3D API - Downloading ] Generate Successfully!") - - path = "/proxy/rodin/api/v2/download" - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path=path, - method=HttpMethod.POST, - request_model=Rodin3DDownloadRequest, - response_model=Rodin3DDownloadResponse, - ), - request=Rodin3DDownloadRequest( - task_uuid=uuid - ), - auth_kwargs=kwargs - ) - - return await operation.execute() - - def get_quality_mode(self, poly_count): - polycount = poly_count.split("-") - poly = polycount[1] - count = polycount[0] - if poly == "Triangle": - mesh_mode = "Raw" - elif poly == "Quad": - mesh_mode = "Quad" - else: - mesh_mode = "Quad" - - if count == "4K": - quality_override = 4000 - elif count == "8K": - quality_override = 8000 - elif count == "18K": - quality_override = 18000 - elif count == "50K": - quality_override = 50000 - elif count == "2K": - quality_override = 2000 - elif count == "20K": - quality_override = 20000 - elif count == "150K": - quality_override = 150000 - elif count == "500K": - quality_override = 500000 - else: - quality_override = 18000 - - return mesh_mode, quality_override - - async def download_files(self, url_list, task_uuid): - save_path = os.path.join(comfy_paths.get_output_directory(), f"Rodin3D_{task_uuid}") - os.makedirs(save_path, exist_ok=True) - model_file_path = None - async with aiohttp.ClientSession() as session: - for i in url_list.list: - url = i.url - file_name = i.name - file_path = os.path.join(save_path, file_name) - if file_path.endswith(".glb"): - model_file_path = file_path - logging.info(f"[ Rodin3D API - download_files ] Downloading file: {file_path}") - max_retries = 5 - for attempt in range(max_retries): - try: - async with session.get(url) as resp: - resp.raise_for_status() - with open(file_path, "wb") as f: - async for chunk in resp.content.iter_chunked(32 * 1024): - f.write(chunk) - break - except Exception as e: - logging.info(f"[ Rodin3D API - download_files ] Error downloading {file_path}:{e}") - if attempt < max_retries - 1: - logging.info("Retrying...") - await asyncio.sleep(2) - else: - logging.info( - "[ Rodin3D API - download_files ] Failed to download %s after %s attempts.", - file_path, - max_retries, - ) - - return model_file_path +def check_rodin_status(response: Rodin3DCheckStatusResponse) -> str: + all_done = all(job.status == JobStatus.Done for job in 
response.jobs) + status_list = [str(job.status) for job in response.jobs] + logging.info(f"[ Rodin3D API - CheckStatus ] Generate Status: {status_list}") + if any(job.status == JobStatus.Failed for job in response.jobs): + logging.error(f"[ Rodin3D API - CheckStatus ] Generate Failed: {status_list}, Please try again.") + raise Exception("[ Rodin3D API ] Generate Failed, Please Try again.") + if all_done: + return "DONE" + return "Generating" -class Rodin3D_Regular(Rodin3DAPI): +async def poll_for_task_status( + subscription_key, auth_kwargs: Optional[dict[str, str]] = None, +) -> Rodin3DCheckStatusResponse: + poll_operation = PollingOperation( + poll_endpoint=ApiEndpoint( + path="/proxy/rodin/api/v2/status", + method=HttpMethod.POST, + request_model=Rodin3DCheckStatusRequest, + response_model=Rodin3DCheckStatusResponse, + ), + request=Rodin3DCheckStatusRequest(subscription_key=subscription_key), + completed_statuses=["DONE"], + failed_statuses=["FAILED"], + status_extractor=check_rodin_status, + poll_interval=3.0, + auth_kwargs=auth_kwargs, + ) + logging.info("[ Rodin3D API - CheckStatus ] Generate Start!") + return await poll_operation.execute() + + +async def get_rodin_download_list(uuid, auth_kwargs: Optional[dict[str, str]] = None) -> Rodin3DDownloadResponse: + logging.info("[ Rodin3D API - Downloading ] Generate Successfully!") + operation = SynchronousOperation( + endpoint=ApiEndpoint( + path="/proxy/rodin/api/v2/download", + method=HttpMethod.POST, + request_model=Rodin3DDownloadRequest, + response_model=Rodin3DDownloadResponse, + ), + request=Rodin3DDownloadRequest(task_uuid=uuid), + auth_kwargs=auth_kwargs, + ) + return await operation.execute() + + +async def download_files(url_list, task_uuid): + save_path = os.path.join(comfy_paths.get_output_directory(), f"Rodin3D_{task_uuid}") + os.makedirs(save_path, exist_ok=True) + model_file_path = None + async with aiohttp.ClientSession() as session: + for i in url_list.list: + url = i.url + file_name = i.name + file_path = os.path.join(save_path, file_name) + if file_path.endswith(".glb"): + model_file_path = file_path + logging.info(f"[ Rodin3D API - download_files ] Downloading file: {file_path}") + max_retries = 5 + for attempt in range(max_retries): + try: + async with session.get(url) as resp: + resp.raise_for_status() + with open(file_path, "wb") as f: + async for chunk in resp.content.iter_chunked(32 * 1024): + f.write(chunk) + break + except Exception as e: + logging.info(f"[ Rodin3D API - download_files ] Error downloading {file_path}:{e}") + if attempt < max_retries - 1: + logging.info("Retrying...") + await asyncio.sleep(2) + else: + logging.info( + "[ Rodin3D API - download_files ] Failed to download %s after %s attempts.", + file_path, + max_retries, + ) + return model_file_path + + +class Rodin3D_Regular(comfy_io.ComfyNode): + """Generate 3D Assets using Rodin API""" + @classmethod - def INPUT_TYPES(s): - return { - "required": { - "Images": - ( - IO.IMAGE, - { - "forceInput":True, - } - ) - }, - "optional": { - **COMMON_PARAMETERS - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - }, - } + def define_schema(cls) -> comfy_io.Schema: + return comfy_io.Schema( + node_id="Rodin3D_Regular", + display_name="Rodin 3D Generate - Regular Generate", + category="api node/3d/Rodin", + description=cleandoc(cls.__doc__ or ""), + inputs=[ + comfy_io.Image.Input("Images"), + *COMMON_PARAMETERS, + ], + outputs=[comfy_io.String.Output(display_name="3D Model Path")], + hidden=[ + 
comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + ], + is_api_node=True, + ) - async def api_call( - self, + @classmethod + async def execute( + cls, Images, Seed, Material_Type, Polygon_count, - **kwargs - ): + ) -> comfy_io.NodeOutput: tier = "Regular" num_images = Images.shape[0] m_images = [] for i in range(num_images): m_images.append(Images[i]) - mesh_mode, quality_override = self.get_quality_mode(Polygon_count) - task_uuid, subscription_key = await self.create_generate_task(images=m_images, seed=Seed, material=Material_Type, - quality_override=quality_override, tier=tier, mesh_mode=mesh_mode, - **kwargs) - await self.poll_for_task_status(subscription_key, **kwargs) - download_list = await self.get_rodin_download_list(task_uuid, **kwargs) - model = await self.download_files(download_list, task_uuid) - - return (model,) - - -class Rodin3D_Detail(Rodin3DAPI): - @classmethod - def INPUT_TYPES(s): - return { - "required": { - "Images": - ( - IO.IMAGE, - { - "forceInput":True, - } - ) - }, - "optional": { - **COMMON_PARAMETERS - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - }, + mesh_mode, quality_override = get_quality_mode(Polygon_count) + auth = { + "auth_token": cls.hidden.auth_token_comfy_org, + "comfy_api_key": cls.hidden.api_key_comfy_org, } + task_uuid, subscription_key = await create_generate_task( + images=m_images, + seed=Seed, + material=Material_Type, + quality_override=quality_override, + tier=tier, + mesh_mode=mesh_mode, + auth_kwargs=auth, + ) + await poll_for_task_status(subscription_key, auth_kwargs=auth) + download_list = await get_rodin_download_list(task_uuid, auth_kwargs=auth) + model = await download_files(download_list, task_uuid) - async def api_call( - self, + return comfy_io.NodeOutput(model) + + +class Rodin3D_Detail(comfy_io.ComfyNode): + """Generate 3D Assets using Rodin API""" + + @classmethod + def define_schema(cls) -> comfy_io.Schema: + return comfy_io.Schema( + node_id="Rodin3D_Detail", + display_name="Rodin 3D Generate - Detail Generate", + category="api node/3d/Rodin", + description=cleandoc(cls.__doc__ or ""), + inputs=[ + comfy_io.Image.Input("Images"), + *COMMON_PARAMETERS, + ], + outputs=[comfy_io.String.Output(display_name="3D Model Path")], + hidden=[ + comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, Images, Seed, Material_Type, Polygon_count, - **kwargs - ): + ) -> comfy_io.NodeOutput: tier = "Detail" num_images = Images.shape[0] m_images = [] for i in range(num_images): m_images.append(Images[i]) - mesh_mode, quality_override = self.get_quality_mode(Polygon_count) - task_uuid, subscription_key = await self.create_generate_task(images=m_images, seed=Seed, material=Material_Type, - quality_override=quality_override, tier=tier, mesh_mode=mesh_mode, - **kwargs) - await self.poll_for_task_status(subscription_key, **kwargs) - download_list = await self.get_rodin_download_list(task_uuid, **kwargs) - model = await self.download_files(download_list, task_uuid) - - return (model,) - - -class Rodin3D_Smooth(Rodin3DAPI): - @classmethod - def INPUT_TYPES(s): - return { - "required": { - "Images": - ( - IO.IMAGE, - { - "forceInput":True, - } - ) - }, - "optional": { - **COMMON_PARAMETERS - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - }, + mesh_mode, quality_override = get_quality_mode(Polygon_count) + auth = { + "auth_token": 
cls.hidden.auth_token_comfy_org, + "comfy_api_key": cls.hidden.api_key_comfy_org, } + task_uuid, subscription_key = await create_generate_task( + images=m_images, + seed=Seed, + material=Material_Type, + quality_override=quality_override, + tier=tier, + mesh_mode=mesh_mode, + auth_kwargs=auth, + ) + await poll_for_task_status(subscription_key, auth_kwargs=auth) + download_list = await get_rodin_download_list(task_uuid, auth_kwargs=auth) + model = await download_files(download_list, task_uuid) - async def api_call( - self, + return comfy_io.NodeOutput(model) + + +class Rodin3D_Smooth(comfy_io.ComfyNode): + """Generate 3D Assets using Rodin API""" + + @classmethod + def define_schema(cls) -> comfy_io.Schema: + return comfy_io.Schema( + node_id="Rodin3D_Smooth", + display_name="Rodin 3D Generate - Smooth Generate", + category="api node/3d/Rodin", + description=cleandoc(cls.__doc__ or ""), + inputs=[ + comfy_io.Image.Input("Images"), + *COMMON_PARAMETERS, + ], + outputs=[comfy_io.String.Output(display_name="3D Model Path")], + hidden=[ + comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, Images, Seed, Material_Type, Polygon_count, - **kwargs - ): + ) -> comfy_io.NodeOutput: tier = "Smooth" num_images = Images.shape[0] m_images = [] for i in range(num_images): m_images.append(Images[i]) - mesh_mode, quality_override = self.get_quality_mode(Polygon_count) - task_uuid, subscription_key = await self.create_generate_task(images=m_images, seed=Seed, material=Material_Type, - quality_override=quality_override, tier=tier, mesh_mode=mesh_mode, - **kwargs) - await self.poll_for_task_status(subscription_key, **kwargs) - download_list = await self.get_rodin_download_list(task_uuid, **kwargs) - model = await self.download_files(download_list, task_uuid) - - return (model,) - - -class Rodin3D_Sketch(Rodin3DAPI): - @classmethod - def INPUT_TYPES(s): - return { - "required": { - "Images": - ( - IO.IMAGE, - { - "forceInput":True, - } - ) - }, - "optional": { - "Seed": - ( - IO.INT, - { - "default":0, - "min":0, - "max":65535, - "display":"number" - } - ) - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - }, + mesh_mode, quality_override = get_quality_mode(Polygon_count) + auth = { + "auth_token": cls.hidden.auth_token_comfy_org, + "comfy_api_key": cls.hidden.api_key_comfy_org, } + task_uuid, subscription_key = await create_generate_task( + images=m_images, + seed=Seed, + material=Material_Type, + quality_override=quality_override, + tier=tier, + mesh_mode=mesh_mode, + auth_kwargs=auth, + ) + await poll_for_task_status(subscription_key, auth_kwargs=auth) + download_list = await get_rodin_download_list(task_uuid, auth_kwargs=auth) + model = await download_files(download_list, task_uuid) - async def api_call( - self, + return comfy_io.NodeOutput(model) + + +class Rodin3D_Sketch(comfy_io.ComfyNode): + """Generate 3D Assets using Rodin API""" + + @classmethod + def define_schema(cls) -> comfy_io.Schema: + return comfy_io.Schema( + node_id="Rodin3D_Sketch", + display_name="Rodin 3D Generate - Sketch Generate", + category="api node/3d/Rodin", + description=cleandoc(cls.__doc__ or ""), + inputs=[ + comfy_io.Image.Input("Images"), + comfy_io.Int.Input( + "Seed", + default=0, + min=0, + max=65535, + display_mode=comfy_io.NumberDisplay.number, + optional=True, + ), + ], + outputs=[comfy_io.String.Output(display_name="3D Model Path")], + hidden=[ + 
comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, Images, Seed, - **kwargs - ): + ) -> comfy_io.NodeOutput: tier = "Sketch" num_images = Images.shape[0] m_images = [] @@ -461,104 +470,110 @@ class Rodin3D_Sketch(Rodin3DAPI): material_type = "PBR" quality_override = 18000 mesh_mode = "Quad" - task_uuid, subscription_key = await self.create_generate_task( - images=m_images, seed=Seed, material=material_type, quality_override=quality_override, tier=tier, mesh_mode=mesh_mode, **kwargs - ) - await self.poll_for_task_status(subscription_key, **kwargs) - download_list = await self.get_rodin_download_list(task_uuid, **kwargs) - model = await self.download_files(download_list, task_uuid) - - return (model,) - -class Rodin3D_Gen2(Rodin3DAPI): - @classmethod - def INPUT_TYPES(s): - return { - "required": { - "Images": - ( - IO.IMAGE, - { - "forceInput":True, - } - ) - }, - "optional": { - "Seed": ( - IO.INT, - { - "default":0, - "min":0, - "max":65535, - "display":"number" - } - ), - "Material_Type": ( - IO.COMBO, - { - "options": ["PBR", "Shaded"], - "default": "PBR" - } - ), - "Polygon_count": ( - IO.COMBO, - { - "options": ["4K-Quad", "8K-Quad", "18K-Quad", "50K-Quad", "2K-Triangle", "20K-Triangle", "150K-Triangle", "500K-Triangle"], - "default": "500K-Triangle" - } - ), - "TAPose": ( - IO.BOOLEAN, - { - "default": False, - } - ) - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - }, + auth = { + "auth_token": cls.hidden.auth_token_comfy_org, + "comfy_api_key": cls.hidden.api_key_comfy_org, } + task_uuid, subscription_key = await create_generate_task( + images=m_images, + seed=Seed, + material=material_type, + quality_override=quality_override, + tier=tier, + mesh_mode=mesh_mode, + auth_kwargs=auth, + ) + await poll_for_task_status(subscription_key, auth_kwargs=auth) + download_list = await get_rodin_download_list(task_uuid, auth_kwargs=auth) + model = await download_files(download_list, task_uuid) - async def api_call( - self, + return comfy_io.NodeOutput(model) + + +class Rodin3D_Gen2(comfy_io.ComfyNode): + """Generate 3D Assets using Rodin API""" + + @classmethod + def define_schema(cls) -> comfy_io.Schema: + return comfy_io.Schema( + node_id="Rodin3D_Gen2", + display_name="Rodin 3D Generate - Gen-2 Generate", + category="api node/3d/Rodin", + description=cleandoc(cls.__doc__ or ""), + inputs=[ + comfy_io.Image.Input("Images"), + comfy_io.Int.Input( + "Seed", + default=0, + min=0, + max=65535, + display_mode=comfy_io.NumberDisplay.number, + optional=True, + ), + comfy_io.Combo.Input("Material_Type", options=["PBR", "Shaded"], default="PBR", optional=True), + comfy_io.Combo.Input( + "Polygon_count", + options=["4K-Quad", "8K-Quad", "18K-Quad", "50K-Quad", "2K-Triangle", "20K-Triangle", "150K-Triangle", "500K-Triangle"], + default="500K-Triangle", + optional=True, + ), + comfy_io.Boolean.Input("TAPose", default=False), + ], + outputs=[comfy_io.String.Output(display_name="3D Model Path")], + hidden=[ + comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, Images, Seed, Material_Type, Polygon_count, TAPose, - **kwargs - ): + ) -> comfy_io.NodeOutput: tier = "Gen-2" num_images = Images.shape[0] m_images = [] for i in range(num_images): m_images.append(Images[i]) - mesh_mode, quality_override = self.get_quality_mode(Polygon_count) - task_uuid, subscription_key = 
await self.create_generate_task(images=m_images, seed=Seed, material=Material_Type, - quality_override=quality_override, tier=tier, mesh_mode=mesh_mode, TAPose=TAPose, - **kwargs) - await self.poll_for_task_status(subscription_key, **kwargs) - download_list = await self.get_rodin_download_list(task_uuid, **kwargs) - model = await self.download_files(download_list, task_uuid) + mesh_mode, quality_override = get_quality_mode(Polygon_count) + auth = { + "auth_token": cls.hidden.auth_token_comfy_org, + "comfy_api_key": cls.hidden.api_key_comfy_org, + } + task_uuid, subscription_key = await create_generate_task( + images=m_images, + seed=Seed, + material=Material_Type, + quality_override=quality_override, + tier=tier, + mesh_mode=mesh_mode, + TAPose=TAPose, + auth_kwargs=auth, + ) + await poll_for_task_status(subscription_key, auth_kwargs=auth) + download_list = await get_rodin_download_list(task_uuid, auth_kwargs=auth) + model = await download_files(download_list, task_uuid) - return (model,) + return comfy_io.NodeOutput(model) -# A dictionary that contains all nodes you want to export with their names -# NOTE: names should be globally unique -NODE_CLASS_MAPPINGS = { - "Rodin3D_Regular": Rodin3D_Regular, - "Rodin3D_Detail": Rodin3D_Detail, - "Rodin3D_Smooth": Rodin3D_Smooth, - "Rodin3D_Sketch": Rodin3D_Sketch, - "Rodin3D_Gen2": Rodin3D_Gen2, -} -# A dictionary that contains the friendly/humanly readable titles for the nodes -NODE_DISPLAY_NAME_MAPPINGS = { - "Rodin3D_Regular": "Rodin 3D Generate - Regular Generate", - "Rodin3D_Detail": "Rodin 3D Generate - Detail Generate", - "Rodin3D_Smooth": "Rodin 3D Generate - Smooth Generate", - "Rodin3D_Sketch": "Rodin 3D Generate - Sketch Generate", - "Rodin3D_Gen2": "Rodin 3D Generate - Gen-2 Generate", -} +class Rodin3DExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[comfy_io.ComfyNode]]: + return [ + Rodin3D_Regular, + Rodin3D_Detail, + Rodin3D_Smooth, + Rodin3D_Sketch, + Rodin3D_Gen2, + ] + + +async def comfy_entrypoint() -> Rodin3DExtension: + return Rodin3DExtension() From b1fa1922df597af759150f4e26ecb276c9753ee4 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Sat, 4 Oct 2025 22:33:48 +0300 Subject: [PATCH 323/325] convert nodes_stable3d.py to V3 schema (#10204) --- comfy_extras/nodes_stable3d.py | 149 +++++++++++++++++++-------------- 1 file changed, 86 insertions(+), 63 deletions(-) diff --git a/comfy_extras/nodes_stable3d.py b/comfy_extras/nodes_stable3d.py index be2e34c28..c6d8a683d 100644 --- a/comfy_extras/nodes_stable3d.py +++ b/comfy_extras/nodes_stable3d.py @@ -1,6 +1,8 @@ import torch import nodes import comfy.utils +from typing_extensions import override +from comfy_api.latest import ComfyExtension, io def camera_embeddings(elevation, azimuth): elevation = torch.as_tensor([elevation]) @@ -20,26 +22,31 @@ def camera_embeddings(elevation, azimuth): return embeddings -class StableZero123_Conditioning: +class StableZero123_Conditioning(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "clip_vision": ("CLIP_VISION",), - "init_image": ("IMAGE",), - "vae": ("VAE",), - "width": ("INT", {"default": 256, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}), - "height": ("INT", {"default": 256, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}), - "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), - "elevation": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0, "step": 0.1, "round": False}), - "azimuth": ("FLOAT", 
{"default": 0.0, "min": -180.0, "max": 180.0, "step": 0.1, "round": False}), - }} - RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") - RETURN_NAMES = ("positive", "negative", "latent") + def define_schema(cls): + return io.Schema( + node_id="StableZero123_Conditioning", + category="conditioning/3d_models", + inputs=[ + io.ClipVision.Input("clip_vision"), + io.Image.Input("init_image"), + io.Vae.Input("vae"), + io.Int.Input("width", default=256, min=16, max=nodes.MAX_RESOLUTION, step=8), + io.Int.Input("height", default=256, min=16, max=nodes.MAX_RESOLUTION, step=8), + io.Int.Input("batch_size", default=1, min=1, max=4096), + io.Float.Input("elevation", default=0.0, min=-180.0, max=180.0, step=0.1, round=False), + io.Float.Input("azimuth", default=0.0, min=-180.0, max=180.0, step=0.1, round=False) + ], + outputs=[ + io.Conditioning.Output(display_name="positive"), + io.Conditioning.Output(display_name="negative"), + io.Latent.Output(display_name="latent") + ] + ) - FUNCTION = "encode" - - CATEGORY = "conditioning/3d_models" - - def encode(self, clip_vision, init_image, vae, width, height, batch_size, elevation, azimuth): + @classmethod + def execute(cls, clip_vision, init_image, vae, width, height, batch_size, elevation, azimuth) -> io.NodeOutput: output = clip_vision.encode_image(init_image) pooled = output.image_embeds.unsqueeze(0) pixels = comfy.utils.common_upscale(init_image.movedim(-1,1), width, height, "bilinear", "center").movedim(1,-1) @@ -51,30 +58,35 @@ class StableZero123_Conditioning: positive = [[cond, {"concat_latent_image": t}]] negative = [[torch.zeros_like(pooled), {"concat_latent_image": torch.zeros_like(t)}]] latent = torch.zeros([batch_size, 4, height // 8, width // 8]) - return (positive, negative, {"samples":latent}) + return io.NodeOutput(positive, negative, {"samples":latent}) -class StableZero123_Conditioning_Batched: +class StableZero123_Conditioning_Batched(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "clip_vision": ("CLIP_VISION",), - "init_image": ("IMAGE",), - "vae": ("VAE",), - "width": ("INT", {"default": 256, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}), - "height": ("INT", {"default": 256, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}), - "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), - "elevation": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0, "step": 0.1, "round": False}), - "azimuth": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0, "step": 0.1, "round": False}), - "elevation_batch_increment": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0, "step": 0.1, "round": False}), - "azimuth_batch_increment": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0, "step": 0.1, "round": False}), - }} - RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") - RETURN_NAMES = ("positive", "negative", "latent") + def define_schema(cls): + return io.Schema( + node_id="StableZero123_Conditioning_Batched", + category="conditioning/3d_models", + inputs=[ + io.ClipVision.Input("clip_vision"), + io.Image.Input("init_image"), + io.Vae.Input("vae"), + io.Int.Input("width", default=256, min=16, max=nodes.MAX_RESOLUTION, step=8), + io.Int.Input("height", default=256, min=16, max=nodes.MAX_RESOLUTION, step=8), + io.Int.Input("batch_size", default=1, min=1, max=4096), + io.Float.Input("elevation", default=0.0, min=-180.0, max=180.0, step=0.1, round=False), + io.Float.Input("azimuth", default=0.0, min=-180.0, max=180.0, step=0.1, round=False), + io.Float.Input("elevation_batch_increment", 
default=0.0, min=-180.0, max=180.0, step=0.1, round=False), + io.Float.Input("azimuth_batch_increment", default=0.0, min=-180.0, max=180.0, step=0.1, round=False) + ], + outputs=[ + io.Conditioning.Output(display_name="positive"), + io.Conditioning.Output(display_name="negative"), + io.Latent.Output(display_name="latent") + ] + ) - FUNCTION = "encode" - - CATEGORY = "conditioning/3d_models" - - def encode(self, clip_vision, init_image, vae, width, height, batch_size, elevation, azimuth, elevation_batch_increment, azimuth_batch_increment): + @classmethod + def execute(cls, clip_vision, init_image, vae, width, height, batch_size, elevation, azimuth, elevation_batch_increment, azimuth_batch_increment) -> io.NodeOutput: output = clip_vision.encode_image(init_image) pooled = output.image_embeds.unsqueeze(0) pixels = comfy.utils.common_upscale(init_image.movedim(-1,1), width, height, "bilinear", "center").movedim(1,-1) @@ -93,27 +105,32 @@ class StableZero123_Conditioning_Batched: positive = [[cond, {"concat_latent_image": t}]] negative = [[torch.zeros_like(pooled), {"concat_latent_image": torch.zeros_like(t)}]] latent = torch.zeros([batch_size, 4, height // 8, width // 8]) - return (positive, negative, {"samples":latent, "batch_index": [0] * batch_size}) + return io.NodeOutput(positive, negative, {"samples":latent, "batch_index": [0] * batch_size}) -class SV3D_Conditioning: +class SV3D_Conditioning(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "clip_vision": ("CLIP_VISION",), - "init_image": ("IMAGE",), - "vae": ("VAE",), - "width": ("INT", {"default": 576, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}), - "height": ("INT", {"default": 576, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}), - "video_frames": ("INT", {"default": 21, "min": 1, "max": 4096}), - "elevation": ("FLOAT", {"default": 0.0, "min": -90.0, "max": 90.0, "step": 0.1, "round": False}), - }} - RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") - RETURN_NAMES = ("positive", "negative", "latent") + def define_schema(cls): + return io.Schema( + node_id="SV3D_Conditioning", + category="conditioning/3d_models", + inputs=[ + io.ClipVision.Input("clip_vision"), + io.Image.Input("init_image"), + io.Vae.Input("vae"), + io.Int.Input("width", default=576, min=16, max=nodes.MAX_RESOLUTION, step=8), + io.Int.Input("height", default=576, min=16, max=nodes.MAX_RESOLUTION, step=8), + io.Int.Input("video_frames", default=21, min=1, max=4096), + io.Float.Input("elevation", default=0.0, min=-90.0, max=90.0, step=0.1, round=False) + ], + outputs=[ + io.Conditioning.Output(display_name="positive"), + io.Conditioning.Output(display_name="negative"), + io.Latent.Output(display_name="latent") + ] + ) - FUNCTION = "encode" - - CATEGORY = "conditioning/3d_models" - - def encode(self, clip_vision, init_image, vae, width, height, video_frames, elevation): + @classmethod + def execute(cls, clip_vision, init_image, vae, width, height, video_frames, elevation) -> io.NodeOutput: output = clip_vision.encode_image(init_image) pooled = output.image_embeds.unsqueeze(0) pixels = comfy.utils.common_upscale(init_image.movedim(-1,1), width, height, "bilinear", "center").movedim(1,-1) @@ -133,11 +150,17 @@ class SV3D_Conditioning: positive = [[pooled, {"concat_latent_image": t, "elevation": elevations, "azimuth": azimuths}]] negative = [[torch.zeros_like(pooled), {"concat_latent_image": torch.zeros_like(t), "elevation": elevations, "azimuth": azimuths}]] latent = torch.zeros([video_frames, 4, height // 8, width // 8]) - 
return (positive, negative, {"samples":latent}) + return io.NodeOutput(positive, negative, {"samples":latent}) -NODE_CLASS_MAPPINGS = { - "StableZero123_Conditioning": StableZero123_Conditioning, - "StableZero123_Conditioning_Batched": StableZero123_Conditioning_Batched, - "SV3D_Conditioning": SV3D_Conditioning, -} +class Stable3DExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + StableZero123_Conditioning, + StableZero123_Conditioning_Batched, + SV3D_Conditioning, + ] + +async def comfy_entrypoint() -> Stable3DExtension: + return Stable3DExtension() From caf07331ff1b20f4104b9693ed244d6e22f80b5a Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sat, 4 Oct 2025 19:05:05 -0700 Subject: [PATCH 324/325] Remove soundfile dependency. No more torchaudio load or save. (#10210) --- comfy_extras/nodes_audio.py | 2 +- requirements.txt | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/comfy_extras/nodes_audio.py b/comfy_extras/nodes_audio.py index 51c8b9dd9..1c868fcba 100644 --- a/comfy_extras/nodes_audio.py +++ b/comfy_extras/nodes_audio.py @@ -360,7 +360,7 @@ class RecordAudio: def load(self, audio): audio_path = folder_paths.get_annotated_filepath(audio) - waveform, sample_rate = torchaudio.load(audio_path) + waveform, sample_rate = load(audio_path) audio = {"waveform": waveform.unsqueeze(0), "sample_rate": sample_rate} return (audio, ) diff --git a/requirements.txt b/requirements.txt index 588c5dcf0..6c28f9478 100644 --- a/requirements.txt +++ b/requirements.txt @@ -25,6 +25,5 @@ av>=14.2.0 #non essential dependencies: kornia>=0.7.1 spandrel -soundfile pydantic~=2.0 pydantic-settings~=2.0 From 187f43696dd58f252075d2e3c6873706eb6b5fa1 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Sun, 5 Oct 2025 09:34:18 +0300 Subject: [PATCH 325/325] fix(api-nodes): disable "std" mode for Kling2.5-turbo (#10212) --- comfy_api_nodes/nodes_kling.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/comfy_api_nodes/nodes_kling.py b/comfy_api_nodes/nodes_kling.py index d8646f106..44fccc0c7 100644 --- a/comfy_api_nodes/nodes_kling.py +++ b/comfy_api_nodes/nodes_kling.py @@ -712,6 +712,9 @@ class KlingImage2VideoNode(KlingNodeBase): # Camera control type for image 2 video is always `simple` camera_control.type = KlingCameraControlType.simple + if mode == "std" and model_name == KlingVideoGenModelName.kling_v2_5_turbo.value: + mode = "pro" # October 5: currently "std" mode is not supported for this model + initial_operation = SynchronousOperation( endpoint=ApiEndpoint( path=PATH_IMAGE_TO_VIDEO,