From 9cdc64998f8990aed7688b0ebe89bc3b97733764 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 21 Oct 2025 16:15:23 -0700 Subject: [PATCH 001/231] Only disable cudnn on newer AMD GPUs. (#10437) --- comfy/model_management.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index a2c318ec3..79d6ff9d4 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -330,15 +330,21 @@ except: SUPPORT_FP8_OPS = args.supports_fp8_compute + +AMD_RDNA2_AND_OLDER_ARCH = ["gfx1030", "gfx1031", "gfx1010", "gfx1011", "gfx1012", "gfx906", "gfx900", "gfx803"] + try: if is_amd(): - torch.backends.cudnn.enabled = False # Seems to improve things a lot on AMD - logging.info("Set: torch.backends.cudnn.enabled = False for better AMD performance.") + arch = torch.cuda.get_device_properties(get_torch_device()).gcnArchName + if not (any((a in arch) for a in AMD_RDNA2_AND_OLDER_ARCH)): + torch.backends.cudnn.enabled = False # Seems to improve things a lot on AMD + logging.info("Set: torch.backends.cudnn.enabled = False for better AMD performance.") + try: rocm_version = tuple(map(int, str(torch.version.hip).split(".")[:2])) except: rocm_version = (6, -1) - arch = torch.cuda.get_device_properties(get_torch_device()).gcnArchName + logging.info("AMD arch: {}".format(arch)) logging.info("ROCm version: {}".format(rocm_version)) if args.use_split_cross_attention == False and args.use_quad_cross_attention == False: @@ -1331,7 +1337,7 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma if is_amd(): arch = torch.cuda.get_device_properties(device).gcnArchName - if any((a in arch) for a in ["gfx1030", "gfx1031", "gfx1010", "gfx1011", "gfx1012", "gfx906", "gfx900", "gfx803"]): # RDNA2 and older don't support bf16 + if any((a in arch) for a in AMD_RDNA2_AND_OLDER_ARCH): # RDNA2 and older don't support bf16 if manual_cast: return True return False From f13cff0be65e35d34876b173bba2fec6bd94746b Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Tue, 21 Oct 2025 20:16:16 -0700 Subject: [PATCH 002/231] Add custom node published subgraphs endpoint (#10438) * Add get_subgraphs_dir to ComfyExtension and PUBLISHED_SUBGRAPH_DIRS to nodes.py * Created initial endpoints, although the returned paths are a bit off currently * Fix path and actually return real data * Sanitize returned /api/global_subgraphs entries * Remove leftover function from early prototyping * Remove added whitespace * Add None check for sanitize_entry --- app/subgraph_manager.py | 112 ++++++++++++++++++++++++++++++++++++++++ server.py | 3 ++ 2 files changed, 115 insertions(+) create mode 100644 app/subgraph_manager.py diff --git a/app/subgraph_manager.py b/app/subgraph_manager.py new file mode 100644 index 000000000..dbe404541 --- /dev/null +++ b/app/subgraph_manager.py @@ -0,0 +1,112 @@ +from __future__ import annotations + +from typing import TypedDict +import os +import folder_paths +import glob +from aiohttp import web +import hashlib + + +class Source: + custom_node = "custom_node" + +class SubgraphEntry(TypedDict): + source: str + """ + Source of subgraph - custom_nodes vs templates. + """ + path: str + """ + Relative path of the subgraph file. + For custom nodes, will be the relative directory like /subgraphs/.json + """ + name: str + """ + Name of subgraph file. + """ + info: CustomNodeSubgraphEntryInfo + """ + Additional info about subgraph; in the case of custom_nodes, will contain nodepack name + """ + data: str + +class CustomNodeSubgraphEntryInfo(TypedDict): + node_pack: str + """Node pack name.""" + +class SubgraphManager: + def __init__(self): + self.cached_custom_node_subgraphs: dict[SubgraphEntry] | None = None + + async def load_entry_data(self, entry: SubgraphEntry): + with open(entry['path'], 'r') as f: + entry['data'] = f.read() + return entry + + async def sanitize_entry(self, entry: SubgraphEntry | None, remove_data=False) -> SubgraphEntry | None: + if entry is None: + return None + entry = entry.copy() + entry.pop('path', None) + if remove_data: + entry.pop('data', None) + return entry + + async def sanitize_entries(self, entries: dict[str, SubgraphEntry], remove_data=False) -> dict[str, SubgraphEntry]: + entries = entries.copy() + for key in list(entries.keys()): + entries[key] = await self.sanitize_entry(entries[key], remove_data) + return entries + + async def get_custom_node_subgraphs(self, loadedModules, force_reload=False): + # if not forced to reload and cached, return cache + if not force_reload and self.cached_custom_node_subgraphs is not None: + return self.cached_custom_node_subgraphs + # Load subgraphs from custom nodes + subfolder = "subgraphs" + subgraphs_dict: dict[SubgraphEntry] = {} + + for folder in folder_paths.get_folder_paths("custom_nodes"): + pattern = os.path.join(folder, f"*/{subfolder}/*.json") + matched_files = glob.glob(pattern) + for file in matched_files: + # replace backslashes with forward slashes + file = file.replace('\\', '/') + info: CustomNodeSubgraphEntryInfo = { + "node_pack": "custom_nodes." + file.split('/')[-3] + } + source = Source.custom_node + # hash source + path to make sure id will be as unique as possible, but + # reproducible across backend reloads + id = hashlib.sha256(f"{source}{file}".encode()).hexdigest() + entry: SubgraphEntry = { + "source": Source.custom_node, + "name": os.path.splitext(os.path.basename(file))[0], + "path": file, + "info": info, + } + subgraphs_dict[id] = entry + self.cached_custom_node_subgraphs = subgraphs_dict + return subgraphs_dict + + async def get_custom_node_subgraph(self, id: str, loadedModules): + subgraphs = await self.get_custom_node_subgraphs(loadedModules) + entry: SubgraphEntry = subgraphs.get(id, None) + if entry is not None and entry.get('data', None) is None: + await self.load_entry_data(entry) + return entry + + def add_routes(self, routes, loadedModules): + @routes.get("/global_subgraphs") + async def get_global_subgraphs(request): + subgraphs_dict = await self.get_custom_node_subgraphs(loadedModules) + # NOTE: we may want to include other sources of global subgraphs such as templates in the future; + # that's the reasoning for the current implementation + return web.json_response(await self.sanitize_entries(subgraphs_dict, remove_data=True)) + + @routes.get("/global_subgraphs/{id}") + async def get_global_subgraph(request): + id = request.match_info.get("id", None) + subgraph = await self.get_custom_node_subgraph(id, loadedModules) + return web.json_response(await self.sanitize_entry(subgraph)) diff --git a/server.py b/server.py index 10c2698b5..fe58db286 100644 --- a/server.py +++ b/server.py @@ -35,6 +35,7 @@ from comfy_api.internal import _ComfyNodeInternal from app.user_manager import UserManager from app.model_manager import ModelFileManager from app.custom_node_manager import CustomNodeManager +from app.subgraph_manager import SubgraphManager from typing import Optional, Union from api_server.routes.internal.internal_routes import InternalRoutes from protocol import BinaryEventTypes @@ -173,6 +174,7 @@ class PromptServer(): self.user_manager = UserManager() self.model_file_manager = ModelFileManager() self.custom_node_manager = CustomNodeManager() + self.subgraph_manager = SubgraphManager() self.internal_routes = InternalRoutes(self) self.supports = ["custom_nodes_from_web"] self.prompt_queue = execution.PromptQueue(self) @@ -819,6 +821,7 @@ class PromptServer(): self.user_manager.add_routes(self.routes) self.model_file_manager.add_routes(self.routes) self.custom_node_manager.add_routes(self.routes, self.app, nodes.LOADED_MODULE_DIRS.items()) + self.subgraph_manager.add_routes(self.routes, nodes.LOADED_MODULE_DIRS.items()) self.app.add_subapp('/internal', self.internal_routes.get_app()) # Prefix every route with /api for easier matching for delegation. From 4739d7717fea56750d0ef98c64268d9c1e487d78 Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Thu, 23 Oct 2025 05:49:05 +1000 Subject: [PATCH 003/231] execution: fold in dependency aware caching / Fix --cache-none with loops/lazy etc (Resubmit) (#10440) * execution: fold in dependency aware caching This makes --cache-none compatiable with lazy and expanded subgraphs. Currently the --cache-none option is powered by the DependencyAwareCache. The cache attempts to maintain a parallel copy of the execution list data structure, however it is only setup once at the start of execution and does not get meaninigful updates to the execution list. This causes multiple problems when --cache-none is used with lazy and expanded subgraphs as the DAC does not accurately update its copy of the execution data structure. DAC has an attempt to handle subgraphs ensure_subcache however this does not accurately connect to nodes outside the subgraph. The current semantics of DAC are to free a node ASAP after the dependent nodes are executed. This means that if a subgraph refs such a node it will be requed and re-executed by the execution_list but DAC wont see it in its to-free lists anymore and leak memory. Rather than try and cover all the cases where the execution list changes from inside the cache, move the while problem to the executor which maintains an always up-to-date copy of the wanted data-structure. The executor now has a fast-moving run-local cache of its own. Each _to node has its own mini cache, and the cache is unconditionally primed at the time of add_strong_link. add_strong_link is called for all of static workflows, lazy links and expanded subgraphs so its the singular source of truth for output dependendencies. In the case of a cache-hit, the executor cache will hold the non-none value (it will respect updates if they happen somehow as well). In the case of a cache-miss, the executor caches a None and will wait for a notification to update the value when the node completes. When a node completes execution, it simply releases its mini-cache and in turn its strong refs on its direct anscestor outputs, allowing for ASAP freeing (same as the DependencyAwareCache but a little more automatic). This now allows for re-implementation of --cache-none with no cache at all. The dependency aware cache was also observing the dependency sematics for the objects and UI cache which is not accurate (this entire logic was always outputs specific). This also prepares for more complex caching strategies (such as RAM pressure based caching), where a cache can implement any freeing strategy completely independently of the DepedancyAwareness requirement. * main: re-implement --cache-none as no cache at all The execution list now tracks the dependency aware caching more correctly that the DependancyAwareCache. Change it to a cache that does nothing. * test_execution: add --cache-none to the test suite --cache-none is now expected to work universally. Run it through the full unit test suite. Propagate the server parameterization for whether or not the server is capabale of caching, so that the minority of tests that specifically check for cache hits can if else. Hard assert NOT caching in the else to give some coverage of --cache-none expected behaviour to not acutally cache. --- comfy_execution/caching.py | 174 ++++-------------------------- comfy_execution/graph.py | 32 +++++- execution.py | 34 +++--- main.py | 2 +- tests/execution/test_execution.py | 50 +++++---- 5 files changed, 102 insertions(+), 190 deletions(-) diff --git a/comfy_execution/caching.py b/comfy_execution/caching.py index 41224ce3b..566bc3f9c 100644 --- a/comfy_execution/caching.py +++ b/comfy_execution/caching.py @@ -265,6 +265,26 @@ class HierarchicalCache(BasicCache): assert cache is not None return await cache._ensure_subcache(node_id, children_ids) +class NullCache: + + async def set_prompt(self, dynprompt, node_ids, is_changed_cache): + pass + + def all_node_ids(self): + return [] + + def clean_unused(self): + pass + + def get(self, node_id): + return None + + def set(self, node_id, value): + pass + + async def ensure_subcache_for(self, node_id, children_ids): + return self + class LRUCache(BasicCache): def __init__(self, key_class, max_size=100): super().__init__(key_class) @@ -316,157 +336,3 @@ class LRUCache(BasicCache): self._mark_used(child_id) self.children[cache_key].append(self.cache_key_set.get_data_key(child_id)) return self - - -class DependencyAwareCache(BasicCache): - """ - A cache implementation that tracks dependencies between nodes and manages - their execution and caching accordingly. It extends the BasicCache class. - Nodes are removed from this cache once all of their descendants have been - executed. - """ - - def __init__(self, key_class): - """ - Initialize the DependencyAwareCache. - - Args: - key_class: The class used for generating cache keys. - """ - super().__init__(key_class) - self.descendants = {} # Maps node_id -> set of descendant node_ids - self.ancestors = {} # Maps node_id -> set of ancestor node_ids - self.executed_nodes = set() # Tracks nodes that have been executed - - async def set_prompt(self, dynprompt, node_ids, is_changed_cache): - """ - Clear the entire cache and rebuild the dependency graph. - - Args: - dynprompt: The dynamic prompt object containing node information. - node_ids: List of node IDs to initialize the cache for. - is_changed_cache: Flag indicating if the cache has changed. - """ - # Clear all existing cache data - self.cache.clear() - self.subcaches.clear() - self.descendants.clear() - self.ancestors.clear() - self.executed_nodes.clear() - - # Call the parent method to initialize the cache with the new prompt - await super().set_prompt(dynprompt, node_ids, is_changed_cache) - - # Rebuild the dependency graph - self._build_dependency_graph(dynprompt, node_ids) - - def _build_dependency_graph(self, dynprompt, node_ids): - """ - Build the dependency graph for all nodes. - - Args: - dynprompt: The dynamic prompt object containing node information. - node_ids: List of node IDs to build the graph for. - """ - self.descendants.clear() - self.ancestors.clear() - for node_id in node_ids: - self.descendants[node_id] = set() - self.ancestors[node_id] = set() - - for node_id in node_ids: - inputs = dynprompt.get_node(node_id)["inputs"] - for input_data in inputs.values(): - if is_link(input_data): # Check if the input is a link to another node - ancestor_id = input_data[0] - self.descendants[ancestor_id].add(node_id) - self.ancestors[node_id].add(ancestor_id) - - def set(self, node_id, value): - """ - Mark a node as executed and store its value in the cache. - - Args: - node_id: The ID of the node to store. - value: The value to store for the node. - """ - self._set_immediate(node_id, value) - self.executed_nodes.add(node_id) - self._cleanup_ancestors(node_id) - - def get(self, node_id): - """ - Retrieve the cached value for a node. - - Args: - node_id: The ID of the node to retrieve. - - Returns: - The cached value for the node. - """ - return self._get_immediate(node_id) - - async def ensure_subcache_for(self, node_id, children_ids): - """ - Ensure a subcache exists for a node and update dependencies. - - Args: - node_id: The ID of the parent node. - children_ids: List of child node IDs to associate with the parent node. - - Returns: - The subcache object for the node. - """ - subcache = await super()._ensure_subcache(node_id, children_ids) - for child_id in children_ids: - self.descendants[node_id].add(child_id) - self.ancestors[child_id].add(node_id) - return subcache - - def _cleanup_ancestors(self, node_id): - """ - Check if ancestors of a node can be removed from the cache. - - Args: - node_id: The ID of the node whose ancestors are to be checked. - """ - for ancestor_id in self.ancestors.get(node_id, []): - if ancestor_id in self.executed_nodes: - # Remove ancestor if all its descendants have been executed - if all(descendant in self.executed_nodes for descendant in self.descendants[ancestor_id]): - self._remove_node(ancestor_id) - - def _remove_node(self, node_id): - """ - Remove a node from the cache. - - Args: - node_id: The ID of the node to remove. - """ - cache_key = self.cache_key_set.get_data_key(node_id) - if cache_key in self.cache: - del self.cache[cache_key] - subcache_key = self.cache_key_set.get_subcache_key(node_id) - if subcache_key in self.subcaches: - del self.subcaches[subcache_key] - - def clean_unused(self): - """ - Clean up unused nodes. This is a no-op for this cache implementation. - """ - pass - - def recursive_debug_dump(self): - """ - Dump the cache and dependency graph for debugging. - - Returns: - A list containing the cache state and dependency graph. - """ - result = super().recursive_debug_dump() - result.append({ - "descendants": self.descendants, - "ancestors": self.ancestors, - "executed_nodes": list(self.executed_nodes), - }) - return result diff --git a/comfy_execution/graph.py b/comfy_execution/graph.py index f4b427265..341c9735d 100644 --- a/comfy_execution/graph.py +++ b/comfy_execution/graph.py @@ -153,8 +153,9 @@ class TopologicalSort: continue _, _, input_info = self.get_input_info(unique_id, input_name) is_lazy = input_info is not None and "lazy" in input_info and input_info["lazy"] - if (include_lazy or not is_lazy) and not self.is_cached(from_node_id): - node_ids.append(from_node_id) + if (include_lazy or not is_lazy): + if not self.is_cached(from_node_id): + node_ids.append(from_node_id) links.append((from_node_id, from_socket, unique_id)) for link in links: @@ -194,10 +195,35 @@ class ExecutionList(TopologicalSort): super().__init__(dynprompt) self.output_cache = output_cache self.staged_node_id = None + self.execution_cache = {} + self.execution_cache_listeners = {} def is_cached(self, node_id): return self.output_cache.get(node_id) is not None + def cache_link(self, from_node_id, to_node_id): + if not to_node_id in self.execution_cache: + self.execution_cache[to_node_id] = {} + self.execution_cache[to_node_id][from_node_id] = self.output_cache.get(from_node_id) + if not from_node_id in self.execution_cache_listeners: + self.execution_cache_listeners[from_node_id] = set() + self.execution_cache_listeners[from_node_id].add(to_node_id) + + def get_output_cache(self, from_node_id, to_node_id): + if not to_node_id in self.execution_cache: + return None + return self.execution_cache[to_node_id].get(from_node_id) + + def cache_update(self, node_id, value): + if node_id in self.execution_cache_listeners: + for to_node_id in self.execution_cache_listeners[node_id]: + if to_node_id in self.execution_cache: + self.execution_cache[to_node_id][node_id] = value + + def add_strong_link(self, from_node_id, from_socket, to_node_id): + super().add_strong_link(from_node_id, from_socket, to_node_id) + self.cache_link(from_node_id, to_node_id) + async def stage_node_execution(self): assert self.staged_node_id is None if self.is_empty(): @@ -277,6 +303,8 @@ class ExecutionList(TopologicalSort): def complete_node_execution(self): node_id = self.staged_node_id self.pop_node(node_id) + self.execution_cache.pop(node_id, None) + self.execution_cache_listeners.pop(node_id, None) self.staged_node_id = None def get_nodes_in_cycle(self): diff --git a/execution.py b/execution.py index 1dc35738b..78c36a4b0 100644 --- a/execution.py +++ b/execution.py @@ -18,7 +18,7 @@ from comfy_execution.caching import ( BasicCache, CacheKeySetID, CacheKeySetInputSignature, - DependencyAwareCache, + NullCache, HierarchicalCache, LRUCache, ) @@ -91,13 +91,13 @@ class IsChangedCache: class CacheType(Enum): CLASSIC = 0 LRU = 1 - DEPENDENCY_AWARE = 2 + NONE = 2 class CacheSet: def __init__(self, cache_type=None, cache_size=None): - if cache_type == CacheType.DEPENDENCY_AWARE: - self.init_dependency_aware_cache() + if cache_type == CacheType.NONE: + self.init_null_cache() logging.info("Disabling intermediate node cache.") elif cache_type == CacheType.LRU: if cache_size is None: @@ -120,11 +120,12 @@ class CacheSet: self.ui = LRUCache(CacheKeySetInputSignature, max_size=cache_size) self.objects = HierarchicalCache(CacheKeySetID) - # only hold cached items while the decendents have not executed - def init_dependency_aware_cache(self): - self.outputs = DependencyAwareCache(CacheKeySetInputSignature) - self.ui = DependencyAwareCache(CacheKeySetInputSignature) - self.objects = DependencyAwareCache(CacheKeySetID) + def init_null_cache(self): + self.outputs = NullCache() + #The UI cache is expected to be iterable at the end of each workflow + #so it must cache at least a full workflow. Use Heirachical + self.ui = HierarchicalCache(CacheKeySetInputSignature) + self.objects = NullCache() def recursive_debug_dump(self): result = { @@ -135,7 +136,7 @@ class CacheSet: SENSITIVE_EXTRA_DATA_KEYS = ("auth_token_comfy_org", "api_key_comfy_org") -def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, extra_data={}): +def get_input_data(inputs, class_def, unique_id, execution_list=None, dynprompt=None, extra_data={}): is_v3 = issubclass(class_def, _ComfyNodeInternal) if is_v3: valid_inputs, schema = class_def.INPUT_TYPES(include_hidden=False, return_schema=True) @@ -153,10 +154,10 @@ def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, e if is_link(input_data) and (not input_info or not input_info.get("rawLink", False)): input_unique_id = input_data[0] output_index = input_data[1] - if outputs is None: + if execution_list is None: mark_missing() continue # This might be a lazily-evaluated input - cached_output = outputs.get(input_unique_id) + cached_output = execution_list.get_output_cache(input_unique_id, unique_id) if cached_output is None: mark_missing() continue @@ -405,6 +406,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, cached_output = caches.ui.get(unique_id) or {} server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": cached_output.get("output",None), "prompt_id": prompt_id }, server.client_id) get_progress_state().finish_progress(unique_id) + execution_list.cache_update(unique_id, caches.outputs.get(unique_id)) return (ExecutionResult.SUCCESS, None, None) input_data_all = None @@ -434,7 +436,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, for r in result: if is_link(r): source_node, source_output = r[0], r[1] - node_output = caches.outputs.get(source_node)[source_output] + node_output = execution_list.get_output_cache(source_node, unique_id)[source_output] for o in node_output: resolved_output.append(o) @@ -446,7 +448,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, has_subgraph = False else: get_progress_state().start_progress(unique_id) - input_data_all, missing_keys, hidden_inputs = get_input_data(inputs, class_def, unique_id, caches.outputs, dynprompt, extra_data) + input_data_all, missing_keys, hidden_inputs = get_input_data(inputs, class_def, unique_id, execution_list, dynprompt, extra_data) if server.client_id is not None: server.last_node_id = display_node_id server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id }, server.client_id) @@ -549,11 +551,15 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, subcache.clean_unused() for node_id in new_output_ids: execution_list.add_node(node_id) + execution_list.cache_link(node_id, unique_id) for link in new_output_links: execution_list.add_strong_link(link[0], link[1], unique_id) pending_subgraph_results[unique_id] = cached_outputs return (ExecutionResult.PENDING, None, None) + caches.outputs.set(unique_id, output_data) + execution_list.cache_update(unique_id, output_data) + except comfy.model_management.InterruptProcessingException as iex: logging.info("Processing interrupted") diff --git a/main.py b/main.py index 35857dba8..4b4c5dcc4 100644 --- a/main.py +++ b/main.py @@ -173,7 +173,7 @@ def prompt_worker(q, server_instance): if args.cache_lru > 0: cache_type = execution.CacheType.LRU elif args.cache_none: - cache_type = execution.CacheType.DEPENDENCY_AWARE + cache_type = execution.CacheType.NONE e = execution.PromptExecutor(server_instance, cache_type=cache_type, cache_size=args.cache_lru) last_gc_collect = 0 diff --git a/tests/execution/test_execution.py b/tests/execution/test_execution.py index ef73ad9fd..ace0d2279 100644 --- a/tests/execution/test_execution.py +++ b/tests/execution/test_execution.py @@ -152,12 +152,12 @@ class TestExecution: # Initialize server and client # @fixture(scope="class", autouse=True, params=[ - # (use_lru, lru_size) - (False, 0), - (True, 0), - (True, 100), + { "extra_args" : [], "should_cache_results" : True }, + { "extra_args" : ["--cache-lru", 0], "should_cache_results" : True }, + { "extra_args" : ["--cache-lru", 100], "should_cache_results" : True }, + { "extra_args" : ["--cache-none"], "should_cache_results" : False }, ]) - def _server(self, args_pytest, request): + def server(self, args_pytest, request): # Start server pargs = [ 'python','main.py', @@ -167,12 +167,10 @@ class TestExecution: '--extra-model-paths-config', 'tests/execution/extra_model_paths.yaml', '--cpu', ] - use_lru, lru_size = request.param - if use_lru: - pargs += ['--cache-lru', str(lru_size)] + pargs += [ str(param) for param in request.param["extra_args"] ] print("Running server with args:", pargs) # noqa: T201 p = subprocess.Popen(pargs) - yield + yield request.param p.kill() torch.cuda.empty_cache() @@ -193,7 +191,7 @@ class TestExecution: return comfy_client @fixture(scope="class", autouse=True) - def shared_client(self, args_pytest, _server): + def shared_client(self, args_pytest, server): client = self.start_client(args_pytest["listen"], args_pytest["port"]) yield client del client @@ -225,7 +223,7 @@ class TestExecution: assert result.did_run(mask) assert result.did_run(lazy_mix) - def test_full_cache(self, client: ComfyClient, builder: GraphBuilder): + def test_full_cache(self, client: ComfyClient, builder: GraphBuilder, server): g = builder input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) input2 = g.node("StubImage", content="NOISE", height=512, width=512, batch_size=1) @@ -237,9 +235,12 @@ class TestExecution: client.run(g) result2 = client.run(g) for node_id, node in g.nodes.items(): - assert not result2.did_run(node), f"Node {node_id} ran, but should have been cached" + if server["should_cache_results"]: + assert not result2.did_run(node), f"Node {node_id} ran, but should have been cached" + else: + assert result2.did_run(node), f"Node {node_id} was cached, but should have been run" - def test_partial_cache(self, client: ComfyClient, builder: GraphBuilder): + def test_partial_cache(self, client: ComfyClient, builder: GraphBuilder, server): g = builder input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) input2 = g.node("StubImage", content="NOISE", height=512, width=512, batch_size=1) @@ -251,8 +252,12 @@ class TestExecution: client.run(g) mask.inputs['value'] = 0.4 result2 = client.run(g) - assert not result2.did_run(input1), "Input1 should have been cached" - assert not result2.did_run(input2), "Input2 should have been cached" + if server["should_cache_results"]: + assert not result2.did_run(input1), "Input1 should have been cached" + assert not result2.did_run(input2), "Input2 should have been cached" + else: + assert result2.did_run(input1), "Input1 should have been rerun" + assert result2.did_run(input2), "Input2 should have been rerun" def test_error(self, client: ComfyClient, builder: GraphBuilder): g = builder @@ -411,7 +416,7 @@ class TestExecution: input2 = g.node("StubImage", id="removeme", content="WHITE", height=512, width=512, batch_size=1) client.run(g) - def test_custom_is_changed(self, client: ComfyClient, builder: GraphBuilder): + def test_custom_is_changed(self, client: ComfyClient, builder: GraphBuilder, server): g = builder # Creating the nodes in this specific order previously caused a bug save = g.node("SaveImage") @@ -427,7 +432,10 @@ class TestExecution: result3 = client.run(g) result4 = client.run(g) assert result1.did_run(is_changed), "is_changed should have been run" - assert not result2.did_run(is_changed), "is_changed should have been cached" + if server["should_cache_results"]: + assert not result2.did_run(is_changed), "is_changed should have been cached" + else: + assert result2.did_run(is_changed), "is_changed should have been re-run" assert result3.did_run(is_changed), "is_changed should have been re-run" assert result4.did_run(is_changed), "is_changed should not have been cached" @@ -514,7 +522,7 @@ class TestExecution: assert len(images2) == 1, "Should have 1 image" # This tests that only constant outputs are used in the call to `IS_CHANGED` - def test_is_changed_with_outputs(self, client: ComfyClient, builder: GraphBuilder): + def test_is_changed_with_outputs(self, client: ComfyClient, builder: GraphBuilder, server): g = builder input1 = g.node("StubConstantImage", value=0.5, height=512, width=512, batch_size=1) test_node = g.node("TestIsChangedWithConstants", image=input1.out(0), value=0.5) @@ -530,7 +538,11 @@ class TestExecution: images = result.get_images(output) assert len(images) == 1, "Should have 1 image" assert numpy.array(images[0]).min() == 63 and numpy.array(images[0]).max() == 63, "Image should have value 0.25" - assert not result.did_run(test_node), "The execution should have been cached" + if server["should_cache_results"]: + assert not result.did_run(test_node), "The execution should have been cached" + else: + assert result.did_run(test_node), "The execution should have been re-run" + def test_parallel_sleep_nodes(self, client: ComfyClient, builder: GraphBuilder, skip_timing_checks): # Warmup execution to ensure server is fully initialized From a1864c01f29cc43fe6bf823fc3fd46ba2781c2e0 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 22 Oct 2025 14:26:22 -0700 Subject: [PATCH 004/231] Small readme improvement. (#10442) --- README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/README.md b/README.md index c9a0644e3..434d4ff06 100644 --- a/README.md +++ b/README.md @@ -201,6 +201,8 @@ Python 3.14 will work if you comment out the `kornia` dependency in the requirem Python 3.13 is very well supported. If you have trouble with some custom node dependencies on 3.13 you can try 3.12 +### Instructions: + Git clone this repo. Put your SD checkpoints (the huge ckpt/safetensors files) in: models/checkpoints From 1bcda6df987a6c92b39d8b6d29e0b029450d67d0 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Thu, 23 Oct 2025 18:21:14 -0700 Subject: [PATCH 005/231] WIP way to support multi multi dimensional latents. (#10456) --- comfy/model_base.py | 10 ++++- comfy/nested_tensor.py | 91 ++++++++++++++++++++++++++++++++++++++++++ comfy/sample.py | 27 ++++++++++--- comfy/samplers.py | 23 +++++++---- comfy/utils.py | 22 ++++++++++ 5 files changed, 158 insertions(+), 15 deletions(-) create mode 100644 comfy/nested_tensor.py diff --git a/comfy/model_base.py b/comfy/model_base.py index 8274c7dea..e877f19ac 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -197,8 +197,14 @@ class BaseModel(torch.nn.Module): extra_conds[o] = extra t = self.process_timestep(t, x=x, **extra_conds) - model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds).float() - return self.model_sampling.calculate_denoised(sigma, model_output, x) + if "latent_shapes" in extra_conds: + xc = utils.unpack_latents(xc, extra_conds.pop("latent_shapes")) + + model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds) + if len(model_output) > 1 and not torch.is_tensor(model_output): + model_output, _ = utils.pack_latents(model_output) + + return self.model_sampling.calculate_denoised(sigma, model_output.float(), x) def process_timestep(self, timestep, **kwargs): return timestep diff --git a/comfy/nested_tensor.py b/comfy/nested_tensor.py new file mode 100644 index 000000000..b700816fa --- /dev/null +++ b/comfy/nested_tensor.py @@ -0,0 +1,91 @@ +import torch + +class NestedTensor: + def __init__(self, tensors): + self.tensors = list(tensors) + self.is_nested = True + + def _copy(self): + return NestedTensor(self.tensors) + + def apply_operation(self, other, operation): + o = self._copy() + if isinstance(other, NestedTensor): + for i, t in enumerate(o.tensors): + o.tensors[i] = operation(t, other.tensors[i]) + else: + for i, t in enumerate(o.tensors): + o.tensors[i] = operation(t, other) + return o + + def __add__(self, b): + return self.apply_operation(b, lambda x, y: x + y) + + def __sub__(self, b): + return self.apply_operation(b, lambda x, y: x - y) + + def __mul__(self, b): + return self.apply_operation(b, lambda x, y: x * y) + + # def __itruediv__(self, b): + # return self.apply_operation(b, lambda x, y: x / y) + + def __truediv__(self, b): + return self.apply_operation(b, lambda x, y: x / y) + + def __getitem__(self, *args, **kwargs): + return self.apply_operation(None, lambda x, y: x.__getitem__(*args, **kwargs)) + + def unbind(self): + return self.tensors + + def to(self, *args, **kwargs): + o = self._copy() + for i, t in enumerate(o.tensors): + o.tensors[i] = t.to(*args, **kwargs) + return o + + def new_ones(self, *args, **kwargs): + return self.tensors[0].new_ones(*args, **kwargs) + + def float(self): + return self.to(dtype=torch.float) + + def chunk(self, *args, **kwargs): + return self.apply_operation(None, lambda x, y: x.chunk(*args, **kwargs)) + + def size(self): + return self.tensors[0].size() + + @property + def shape(self): + return self.tensors[0].shape + + @property + def ndim(self): + dims = 0 + for t in self.tensors: + dims = max(t.ndim, dims) + return dims + + @property + def device(self): + return self.tensors[0].device + + @property + def dtype(self): + return self.tensors[0].dtype + + @property + def layout(self): + return self.tensors[0].layout + + +def cat_nested(tensors, *args, **kwargs): + cated_tensors = [] + for i in range(len(tensors[0].tensors)): + tens = [] + for j in range(len(tensors)): + tens.append(tensors[j].tensors[i]) + cated_tensors.append(torch.cat(tens, *args, **kwargs)) + return NestedTensor(cated_tensors) diff --git a/comfy/sample.py b/comfy/sample.py index be5a7e246..b1395da84 100644 --- a/comfy/sample.py +++ b/comfy/sample.py @@ -4,13 +4,9 @@ import comfy.samplers import comfy.utils import numpy as np import logging +import comfy.nested_tensor -def prepare_noise(latent_image, seed, noise_inds=None): - """ - creates random noise given a latent image and a seed. - optional arg skip can be used to skip and discard x number of noise generations for a given seed - """ - generator = torch.manual_seed(seed) +def prepare_noise_inner(latent_image, generator, noise_inds=None): if noise_inds is None: return torch.randn(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu") @@ -22,9 +18,28 @@ def prepare_noise(latent_image, seed, noise_inds=None): noises.append(noise) noises = [noises[i] for i in inverse] noises = torch.cat(noises, axis=0) + +def prepare_noise(latent_image, seed, noise_inds=None): + """ + creates random noise given a latent image and a seed. + optional arg skip can be used to skip and discard x number of noise generations for a given seed + """ + generator = torch.manual_seed(seed) + + if latent_image.is_nested: + tensors = latent_image.unbind() + noises = [] + for t in tensors: + noises.append(prepare_noise_inner(t, generator, noise_inds)) + noises = comfy.nested_tensor.NestedTensor(noises) + else: + noises = prepare_noise_inner(latent_image, generator, noise_inds) + return noises def fix_empty_latent_channels(model, latent_image): + if latent_image.is_nested: + return latent_image latent_format = model.get_model_object("latent_format") #Resize the empty latent image so it has the right number of channels if latent_format.latent_channels != latent_image.shape[1] and torch.count_nonzero(latent_image) == 0: latent_image = comfy.utils.repeat_to_batch_size(latent_image, latent_format.latent_channels, dim=1) diff --git a/comfy/samplers.py b/comfy/samplers.py index e7efaf470..fa4640842 100755 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -782,7 +782,7 @@ def ksampler(sampler_name, extra_options={}, inpaint_options={}): return KSAMPLER(sampler_function, extra_options, inpaint_options) -def process_conds(model, noise, conds, device, latent_image=None, denoise_mask=None, seed=None): +def process_conds(model, noise, conds, device, latent_image=None, denoise_mask=None, seed=None, latent_shapes=None): for k in conds: conds[k] = conds[k][:] resolve_areas_and_cond_masks_multidim(conds[k], noise.shape[2:], device) @@ -792,7 +792,7 @@ def process_conds(model, noise, conds, device, latent_image=None, denoise_mask=N if hasattr(model, 'extra_conds'): for k in conds: - conds[k] = encode_model_conds(model.extra_conds, conds[k], noise, device, k, latent_image=latent_image, denoise_mask=denoise_mask, seed=seed) + conds[k] = encode_model_conds(model.extra_conds, conds[k], noise, device, k, latent_image=latent_image, denoise_mask=denoise_mask, seed=seed, latent_shapes=latent_shapes) #make sure each cond area has an opposite one with the same area for k in conds: @@ -962,11 +962,11 @@ class CFGGuider: def predict_noise(self, x, timestep, model_options={}, seed=None): return sampling_function(self.inner_model, x, timestep, self.conds.get("negative", None), self.conds.get("positive", None), self.cfg, model_options=model_options, seed=seed) - def inner_sample(self, noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed): + def inner_sample(self, noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed, latent_shapes=None): if latent_image is not None and torch.count_nonzero(latent_image) > 0: #Don't shift the empty latent image. latent_image = self.inner_model.process_latent_in(latent_image) - self.conds = process_conds(self.inner_model, noise, self.conds, device, latent_image, denoise_mask, seed) + self.conds = process_conds(self.inner_model, noise, self.conds, device, latent_image, denoise_mask, seed, latent_shapes=latent_shapes) extra_model_options = comfy.model_patcher.create_model_options_clone(self.model_options) extra_model_options.setdefault("transformer_options", {})["sample_sigmas"] = sigmas @@ -980,7 +980,7 @@ class CFGGuider: samples = executor.execute(self, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar) return self.inner_model.process_latent_out(samples.to(torch.float32)) - def outer_sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callback=None, disable_pbar=False, seed=None): + def outer_sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callback=None, disable_pbar=False, seed=None, latent_shapes=None): self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds, self.model_options) device = self.model_patcher.load_device @@ -994,7 +994,7 @@ class CFGGuider: try: self.model_patcher.pre_run() - output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed) + output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed, latent_shapes=latent_shapes) finally: self.model_patcher.cleanup() @@ -1007,6 +1007,12 @@ class CFGGuider: if sigmas.shape[-1] == 0: return latent_image + if latent_image.is_nested: + latent_image, latent_shapes = comfy.utils.pack_latents(latent_image.unbind()) + noise, _ = comfy.utils.pack_latents(noise.unbind()) + else: + latent_shapes = [latent_image.shape] + self.conds = {} for k in self.original_conds: self.conds[k] = list(map(lambda a: a.copy(), self.original_conds[k])) @@ -1026,7 +1032,7 @@ class CFGGuider: self, comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.OUTER_SAMPLE, self.model_options, is_model_options=True) ) - output = executor.execute(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed) + output = executor.execute(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed, latent_shapes=latent_shapes) finally: cast_to_load_options(self.model_options, device=self.model_patcher.offload_device) self.model_options = orig_model_options @@ -1034,6 +1040,9 @@ class CFGGuider: self.model_patcher.restore_hook_patches() del self.conds + + if len(latent_shapes) > 1: + output = comfy.nested_tensor.NestedTensor(comfy.utils.unpack_latents(output, latent_shapes)) return output diff --git a/comfy/utils.py b/comfy/utils.py index 0fd03f165..4bd281057 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -1106,3 +1106,25 @@ def upscale_dit_mask(mask: torch.Tensor, img_size_in, img_size_out): dim=1 ) return out + +def pack_latents(latents): + latent_shapes = [] + tensors = [] + for tensor in latents: + latent_shapes.append(tensor.shape) + tensors.append(tensor.reshape(tensor.shape[0], 1, -1)) + + latent = torch.cat(tensors, dim=-1) + return latent, latent_shapes + +def unpack_latents(combined_latent, latent_shapes): + if len(latent_shapes) > 1: + output_tensors = [] + for shape in latent_shapes: + cut = math.prod(shape[1:]) + tens = combined_latent[:, :, :cut] + combined_latent = combined_latent[:, :, cut:] + output_tensors.append(tens.reshape([tens.shape[0]] + list(shape)[1:])) + else: + output_tensors = combined_latent + return output_tensors From 24188b3141aace272cb91b85578c76f5a8f70e1c Mon Sep 17 00:00:00 2001 From: ComfyUI Wiki Date: Fri, 24 Oct 2025 13:36:30 +0800 Subject: [PATCH 006/231] Update template to 0.2.2 (#10461) Fix template typo issue --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index dd2afcab0..8570c66b6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.28.7 -comfyui-workflow-templates==0.2.1 +comfyui-workflow-templates==0.2.2 comfyui-embedded-docs==0.3.0 torch torchsde From 388b306a2b48070737b092b51e76de933baee9ad Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Fri, 24 Oct 2025 08:37:16 +0300 Subject: [PATCH 007/231] feat(api-nodes): network client v2: async ops, cancellation, downloads, refactor (#10390) * feat(api-nodes): implement new API client for V3 nodes * feat(api-nodes): implement new API client for V3 nodes * feat(api-nodes): implement new API client for V3 nodes * converted WAN nodes to use new client; polishing * fix(auth): do not leak authentification for the absolute urls * convert BFL API nodes to use new API client; remove deprecated BFL nodes * converted Google Veo nodes * fix(Veo3.1 model): take into account "generate_audio" parameter --- comfy_api_nodes/apinode_utils.py | 435 +--------- comfy_api_nodes/apis/bfl_api.py | 51 +- comfy_api_nodes/apis/veo_api.py | 111 +++ comfy_api_nodes/nodes_bfl.py | 605 ++++---------- comfy_api_nodes/nodes_bytedance.py | 279 ++----- comfy_api_nodes/nodes_gemini.py | 5 +- comfy_api_nodes/nodes_kling.py | 350 +++----- comfy_api_nodes/nodes_luma.py | 2 +- comfy_api_nodes/nodes_minimax.py | 2 +- comfy_api_nodes/nodes_moonvalley.py | 366 ++------- comfy_api_nodes/nodes_openai.py | 4 +- comfy_api_nodes/nodes_pika.py | 6 +- comfy_api_nodes/nodes_pixverse.py | 13 +- comfy_api_nodes/nodes_recraft.py | 4 +- comfy_api_nodes/nodes_runway.py | 199 ++--- comfy_api_nodes/nodes_sora.py | 74 +- comfy_api_nodes/nodes_stability.py | 8 +- comfy_api_nodes/nodes_veo2.py | 176 ++-- comfy_api_nodes/nodes_vidu.py | 131 +-- comfy_api_nodes/nodes_wan.py | 245 +++--- comfy_api_nodes/util/__init__.py | 87 ++ comfy_api_nodes/util/_helpers.py | 71 ++ comfy_api_nodes/util/client.py | 941 ++++++++++++++++++++++ comfy_api_nodes/util/common_exceptions.py | 14 + comfy_api_nodes/util/conversions.py | 407 ++++++++++ comfy_api_nodes/util/download_helpers.py | 249 ++++++ comfy_api_nodes/util/upload_helpers.py | 338 ++++++++ comfy_api_nodes/util/validation_utils.py | 58 +- pyproject.toml | 2 + 29 files changed, 2935 insertions(+), 2298 deletions(-) create mode 100644 comfy_api_nodes/apis/veo_api.py create mode 100644 comfy_api_nodes/util/_helpers.py create mode 100644 comfy_api_nodes/util/client.py create mode 100644 comfy_api_nodes/util/common_exceptions.py create mode 100644 comfy_api_nodes/util/conversions.py create mode 100644 comfy_api_nodes/util/download_helpers.py create mode 100644 comfy_api_nodes/util/upload_helpers.py diff --git a/comfy_api_nodes/apinode_utils.py b/comfy_api_nodes/apinode_utils.py index bc3d2d07e..e3d282059 100644 --- a/comfy_api_nodes/apinode_utils.py +++ b/comfy_api_nodes/apinode_utils.py @@ -1,15 +1,10 @@ from __future__ import annotations import aiohttp -import io -import logging import mimetypes -import os from typing import Optional, Union from comfy.utils import common_upscale -from comfy_api.input_impl import VideoFromFile from comfy_api.util import VideoContainer, VideoCodec from comfy_api.input.video_types import VideoInput -from comfy_api.input.basic_types import AudioInput from comfy_api_nodes.apis.client import ( ApiClient, ApiEndpoint, @@ -26,43 +21,8 @@ from PIL import Image import torch import math import base64 -import uuid +from .util import tensor_to_bytesio, bytesio_to_image_tensor from io import BytesIO -import av - - -async def download_url_to_video_output( - video_url: str, timeout: int = None, auth_kwargs: Optional[dict[str, str]] = None -) -> VideoFromFile: - """Downloads a video from a URL and returns a `VIDEO` output. - - Args: - video_url: The URL of the video to download. - - Returns: - A Comfy node `VIDEO` output. - """ - video_io = await download_url_to_bytesio(video_url, timeout, auth_kwargs=auth_kwargs) - if video_io is None: - error_msg = f"Failed to download video from {video_url}" - logging.error(error_msg) - raise ValueError(error_msg) - return VideoFromFile(video_io) - - -def downscale_image_tensor(image, total_pixels=1536 * 1024) -> torch.Tensor: - """Downscale input image tensor to roughly the specified total pixels.""" - samples = image.movedim(-1, 1) - total = int(total_pixels) - scale_by = math.sqrt(total / (samples.shape[3] * samples.shape[2])) - if scale_by >= 1: - return image - width = round(samples.shape[3] * scale_by) - height = round(samples.shape[2] * scale_by) - - s = common_upscale(samples, width, height, "lanczos", "disabled") - s = s.movedim(1, -1) - return s async def validate_and_cast_response( @@ -162,11 +122,6 @@ def validate_aspect_ratio( return aspect_ratio -def mimetype_to_extension(mime_type: str) -> str: - """Converts a MIME type to a file extension.""" - return mime_type.split("/")[-1].lower() - - async def download_url_to_bytesio( url: str, timeout: int = None, auth_kwargs: Optional[dict[str, str]] = None ) -> BytesIO: @@ -195,136 +150,11 @@ async def download_url_to_bytesio( return BytesIO(await resp.read()) -def bytesio_to_image_tensor(image_bytesio: BytesIO, mode: str = "RGBA") -> torch.Tensor: - """Converts image data from BytesIO to a torch.Tensor. - - Args: - image_bytesio: BytesIO object containing the image data. - mode: The PIL mode to convert the image to (e.g., "RGB", "RGBA"). - - Returns: - A torch.Tensor representing the image (1, H, W, C). - - Raises: - PIL.UnidentifiedImageError: If the image data cannot be identified. - ValueError: If the specified mode is invalid. - """ - image = Image.open(image_bytesio) - image = image.convert(mode) - image_array = np.array(image).astype(np.float32) / 255.0 - return torch.from_numpy(image_array).unsqueeze(0) - - -async def download_url_to_image_tensor(url: str, timeout: int = None) -> torch.Tensor: - """Downloads an image from a URL and returns a [B, H, W, C] tensor.""" - image_bytesio = await download_url_to_bytesio(url, timeout) - return bytesio_to_image_tensor(image_bytesio) - - def process_image_response(response_content: bytes | str) -> torch.Tensor: """Uses content from a Response object and converts it to a torch.Tensor""" return bytesio_to_image_tensor(BytesIO(response_content)) -def _tensor_to_pil(image: torch.Tensor, total_pixels: int = 2048 * 2048) -> Image.Image: - """Converts a single torch.Tensor image [H, W, C] to a PIL Image, optionally downscaling.""" - if len(image.shape) > 3: - image = image[0] - # TODO: remove alpha if not allowed and present - input_tensor = image.cpu() - input_tensor = downscale_image_tensor( - input_tensor.unsqueeze(0), total_pixels=total_pixels - ).squeeze() - image_np = (input_tensor.numpy() * 255).astype(np.uint8) - img = Image.fromarray(image_np) - return img - - -def _pil_to_bytesio(img: Image.Image, mime_type: str = "image/png") -> BytesIO: - """Converts a PIL Image to a BytesIO object.""" - if not mime_type: - mime_type = "image/png" - - img_byte_arr = io.BytesIO() - # Derive PIL format from MIME type (e.g., 'image/png' -> 'PNG') - pil_format = mime_type.split("/")[-1].upper() - if pil_format == "JPG": - pil_format = "JPEG" - img.save(img_byte_arr, format=pil_format) - img_byte_arr.seek(0) - return img_byte_arr - - -def tensor_to_bytesio( - image: torch.Tensor, - name: Optional[str] = None, - total_pixels: int = 2048 * 2048, - mime_type: str = "image/png", -) -> BytesIO: - """Converts a torch.Tensor image to a named BytesIO object. - - Args: - image: Input torch.Tensor image. - name: Optional filename for the BytesIO object. - total_pixels: Maximum total pixels for potential downscaling. - mime_type: Target image MIME type (e.g., 'image/png', 'image/jpeg', 'image/webp', 'video/mp4'). - - Returns: - Named BytesIO object containing the image data, with pointer set to the start of buffer. - """ - if not mime_type: - mime_type = "image/png" - - pil_image = _tensor_to_pil(image, total_pixels=total_pixels) - img_binary = _pil_to_bytesio(pil_image, mime_type=mime_type) - img_binary.name = ( - f"{name if name else uuid.uuid4()}.{mimetype_to_extension(mime_type)}" - ) - return img_binary - - -def tensor_to_base64_string( - image_tensor: torch.Tensor, - total_pixels: int = 2048 * 2048, - mime_type: str = "image/png", -) -> str: - """Convert [B, H, W, C] or [H, W, C] tensor to a base64 string. - - Args: - image_tensor: Input torch.Tensor image. - total_pixels: Maximum total pixels for potential downscaling. - mime_type: Target image MIME type (e.g., 'image/png', 'image/jpeg', 'image/webp', 'video/mp4'). - - Returns: - Base64 encoded string of the image. - """ - pil_image = _tensor_to_pil(image_tensor, total_pixels=total_pixels) - img_byte_arr = _pil_to_bytesio(pil_image, mime_type=mime_type) - img_bytes = img_byte_arr.getvalue() - # Encode bytes to base64 string - base64_encoded_string = base64.b64encode(img_bytes).decode("utf-8") - return base64_encoded_string - - -def tensor_to_data_uri( - image_tensor: torch.Tensor, - total_pixels: int = 2048 * 2048, - mime_type: str = "image/png", -) -> str: - """Converts a tensor image to a Data URI string. - - Args: - image_tensor: Input torch.Tensor image. - total_pixels: Maximum total pixels for potential downscaling. - mime_type: Target image MIME type (e.g., 'image/png', 'image/jpeg', 'image/webp'). - - Returns: - Data URI string (e.g., 'data:image/png;base64,...'). - """ - base64_string = tensor_to_base64_string(image_tensor, total_pixels, mime_type) - return f"data:{mime_type};base64,{base64_string}" - - def text_filepath_to_base64_string(filepath: str) -> str: """Converts a text file to a base64 string.""" with open(filepath, "rb") as f: @@ -392,7 +222,7 @@ def video_to_base64_string( container_format: Optional container format to use (defaults to video.container if available) codec: Optional codec to use (defaults to video.codec if available) """ - video_bytes_io = io.BytesIO() + video_bytes_io = BytesIO() # Use provided format/codec if specified, otherwise use video's own if available format_to_use = container_format if container_format is not None else getattr(video, 'container', VideoContainer.MP4) @@ -403,214 +233,6 @@ def video_to_base64_string( return base64.b64encode(video_bytes_io.getvalue()).decode("utf-8") -async def upload_video_to_comfyapi( - video: VideoInput, - auth_kwargs: Optional[dict[str, str]] = None, - container: VideoContainer = VideoContainer.MP4, - codec: VideoCodec = VideoCodec.H264, - max_duration: Optional[int] = None, -) -> str: - """ - Uploads a single video to ComfyUI API and returns its download URL. - Uses the specified container and codec for saving the video before upload. - - Args: - video: VideoInput object (Comfy VIDEO type). - auth_kwargs: Optional authentication token(s). - container: The video container format to use (default: MP4). - codec: The video codec to use (default: H264). - max_duration: Optional maximum duration of the video in seconds. If the video is longer than this, an error will be raised. - - Returns: - The download URL for the uploaded video file. - """ - if max_duration is not None: - try: - actual_duration = video.duration_seconds - if actual_duration is not None and actual_duration > max_duration: - raise ValueError( - f"Video duration ({actual_duration:.2f}s) exceeds the maximum allowed ({max_duration}s)." - ) - except Exception as e: - logging.error("Error getting video duration: %s", str(e)) - raise ValueError(f"Could not verify video duration from source: {e}") from e - - upload_mime_type = f"video/{container.value.lower()}" - filename = f"uploaded_video.{container.value.lower()}" - - # Convert VideoInput to BytesIO using specified container/codec - video_bytes_io = io.BytesIO() - video.save_to(video_bytes_io, format=container, codec=codec) - video_bytes_io.seek(0) - - return await upload_file_to_comfyapi(video_bytes_io, filename, upload_mime_type, auth_kwargs) - - -def audio_tensor_to_contiguous_ndarray(waveform: torch.Tensor) -> np.ndarray: - """ - Prepares audio waveform for av library by converting to a contiguous numpy array. - - Args: - waveform: a tensor of shape (1, channels, samples) derived from a Comfy `AUDIO` type. - - Returns: - Contiguous numpy array of the audio waveform. If the audio was batched, - the first item is taken. - """ - if waveform.ndim != 3 or waveform.shape[0] != 1: - raise ValueError("Expected waveform tensor shape (1, channels, samples)") - - # If batch is > 1, take first item - if waveform.shape[0] > 1: - waveform = waveform[0] - - # Prepare for av: remove batch dim, move to CPU, make contiguous, convert to numpy array - audio_data_np = waveform.squeeze(0).cpu().contiguous().numpy() - if audio_data_np.dtype != np.float32: - audio_data_np = audio_data_np.astype(np.float32) - - return audio_data_np - - -def audio_ndarray_to_bytesio( - audio_data_np: np.ndarray, - sample_rate: int, - container_format: str = "mp4", - codec_name: str = "aac", -) -> BytesIO: - """ - Encodes a numpy array of audio data into a BytesIO object. - """ - audio_bytes_io = io.BytesIO() - with av.open(audio_bytes_io, mode="w", format=container_format) as output_container: - audio_stream = output_container.add_stream(codec_name, rate=sample_rate) - frame = av.AudioFrame.from_ndarray( - audio_data_np, - format="fltp", - layout="stereo" if audio_data_np.shape[0] > 1 else "mono", - ) - frame.sample_rate = sample_rate - frame.pts = 0 - - for packet in audio_stream.encode(frame): - output_container.mux(packet) - - # Flush stream - for packet in audio_stream.encode(None): - output_container.mux(packet) - - audio_bytes_io.seek(0) - return audio_bytes_io - - -async def upload_audio_to_comfyapi( - audio: AudioInput, - auth_kwargs: Optional[dict[str, str]] = None, - container_format: str = "mp4", - codec_name: str = "aac", - mime_type: str = "audio/mp4", - filename: str = "uploaded_audio.mp4", -) -> str: - """ - Uploads a single audio input to ComfyUI API and returns its download URL. - Encodes the raw waveform into the specified format before uploading. - - Args: - audio: a Comfy `AUDIO` type (contains waveform tensor and sample_rate) - auth_kwargs: Optional authentication token(s). - - Returns: - The download URL for the uploaded audio file. - """ - sample_rate: int = audio["sample_rate"] - waveform: torch.Tensor = audio["waveform"] - audio_data_np = audio_tensor_to_contiguous_ndarray(waveform) - audio_bytes_io = audio_ndarray_to_bytesio( - audio_data_np, sample_rate, container_format, codec_name - ) - - return await upload_file_to_comfyapi(audio_bytes_io, filename, mime_type, auth_kwargs) - - -def f32_pcm(wav: torch.Tensor) -> torch.Tensor: - """Convert audio to float 32 bits PCM format. Copy-paste from nodes_audio.py file.""" - if wav.dtype.is_floating_point: - return wav - elif wav.dtype == torch.int16: - return wav.float() / (2 ** 15) - elif wav.dtype == torch.int32: - return wav.float() / (2 ** 31) - raise ValueError(f"Unsupported wav dtype: {wav.dtype}") - - -def audio_bytes_to_audio_input(audio_bytes: bytes,) -> dict: - """ - Decode any common audio container from bytes using PyAV and return - a Comfy AUDIO dict: {"waveform": [1, C, T] float32, "sample_rate": int}. - """ - with av.open(io.BytesIO(audio_bytes)) as af: - if not af.streams.audio: - raise ValueError("No audio stream found in response.") - stream = af.streams.audio[0] - - in_sr = int(stream.codec_context.sample_rate) - out_sr = in_sr - - frames: list[torch.Tensor] = [] - n_channels = stream.channels or 1 - - for frame in af.decode(streams=stream.index): - arr = frame.to_ndarray() # shape can be [C, T] or [T, C] or [T] - buf = torch.from_numpy(arr) - if buf.ndim == 1: - buf = buf.unsqueeze(0) # [T] -> [1, T] - elif buf.shape[0] != n_channels and buf.shape[-1] == n_channels: - buf = buf.transpose(0, 1).contiguous() # [T, C] -> [C, T] - elif buf.shape[0] != n_channels: - buf = buf.reshape(-1, n_channels).t().contiguous() # fallback to [C, T] - frames.append(buf) - - if not frames: - raise ValueError("Decoded zero audio frames.") - - wav = torch.cat(frames, dim=1) # [C, T] - wav = f32_pcm(wav) - return {"waveform": wav.unsqueeze(0).contiguous(), "sample_rate": out_sr} - - -def audio_input_to_mp3(audio: AudioInput) -> io.BytesIO: - waveform = audio["waveform"].cpu() - - output_buffer = io.BytesIO() - output_container = av.open(output_buffer, mode='w', format="mp3") - - out_stream = output_container.add_stream("libmp3lame", rate=audio["sample_rate"]) - out_stream.bit_rate = 320000 - - frame = av.AudioFrame.from_ndarray(waveform.movedim(0, 1).reshape(1, -1).float().numpy(), format='flt', layout='mono' if waveform.shape[0] == 1 else 'stereo') - frame.sample_rate = audio["sample_rate"] - frame.pts = 0 - output_container.mux(out_stream.encode(frame)) - output_container.mux(out_stream.encode(None)) - output_container.close() - output_buffer.seek(0) - return output_buffer - - -def audio_to_base64_string( - audio: AudioInput, container_format: str = "mp4", codec_name: str = "aac" -) -> str: - """Converts an audio input to a base64 string.""" - sample_rate: int = audio["sample_rate"] - waveform: torch.Tensor = audio["waveform"] - audio_data_np = audio_tensor_to_contiguous_ndarray(waveform) - audio_bytes_io = audio_ndarray_to_bytesio( - audio_data_np, sample_rate, container_format, codec_name - ) - audio_bytes = audio_bytes_io.getvalue() - return base64.b64encode(audio_bytes).decode("utf-8") - - async def upload_images_to_comfyapi( image: torch.Tensor, max_images=8, @@ -663,56 +285,3 @@ def resize_mask_to_image( if not allow_gradient: mask = (mask > 0.5).float() return mask - - -def validate_string( - string: str, - strip_whitespace=True, - field_name="prompt", - min_length=None, - max_length=None, -): - if string is None: - raise Exception(f"Field '{field_name}' cannot be empty.") - if strip_whitespace: - string = string.strip() - if min_length and len(string) < min_length: - raise Exception( - f"Field '{field_name}' cannot be shorter than {min_length} characters; was {len(string)} characters long." - ) - if max_length and len(string) > max_length: - raise Exception( - f" Field '{field_name} cannot be longer than {max_length} characters; was {len(string)} characters long." - ) - - -def image_tensor_pair_to_batch( - image1: torch.Tensor, image2: torch.Tensor -) -> torch.Tensor: - """ - Converts a pair of image tensors to a batch tensor. - If the images are not the same size, the smaller image is resized to - match the larger image. - """ - if image1.shape[1:] != image2.shape[1:]: - image2 = common_upscale( - image2.movedim(-1, 1), - image1.shape[2], - image1.shape[1], - "bilinear", - "center", - ).movedim(1, -1) - return torch.cat((image1, image2), dim=0) - - -def get_size(path_or_object: Union[str, io.BytesIO]) -> int: - if isinstance(path_or_object, str): - return os.path.getsize(path_or_object) - return len(path_or_object.getvalue()) - - -def validate_container_format_is_mp4(video: VideoInput) -> None: - """Validates video container format is MP4.""" - container_format = video.get_container_format() - if container_format not in ["mp4", "mov,mp4,m4a,3gp,3g2,mj2"]: - raise ValueError(f"Only MP4 container format supported. Got: {container_format}") diff --git a/comfy_api_nodes/apis/bfl_api.py b/comfy_api_nodes/apis/bfl_api.py index 0e90aef7c..0fc8c0607 100644 --- a/comfy_api_nodes/apis/bfl_api.py +++ b/comfy_api_nodes/apis/bfl_api.py @@ -50,44 +50,6 @@ class BFLFluxFillImageRequest(BaseModel): mask: str = Field(None, description='A Base64-encoded string representing the mask of the areas you with to modify.') -class BFLFluxCannyImageRequest(BaseModel): - prompt: str = Field(..., description='Text prompt for image generation') - prompt_upsampling: Optional[bool] = Field( - None, description='Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation.' - ) - canny_low_threshold: Optional[int] = Field(None, description='Low threshold for Canny edge detection') - canny_high_threshold: Optional[int] = Field(None, description='High threshold for Canny edge detection') - seed: Optional[int] = Field(None, description='The seed value for reproducibility.') - steps: conint(ge=15, le=50) = Field(..., description='Number of steps for the image generation process') - guidance: confloat(ge=1, le=100) = Field(..., description='Guidance strength for the image generation process') - safety_tolerance: Optional[conint(ge=0, le=6)] = Field( - 6, description='Tolerance level for input and output moderation. Between 0 and 6, 0 being most strict, 6 being least strict. Defaults to 2.' - ) - output_format: Optional[BFLOutputFormat] = Field( - BFLOutputFormat.png, description="Output format for the generated image. Can be 'jpeg' or 'png'.", examples=['png'] - ) - control_image: Optional[str] = Field(None, description='Base64 encoded image to use as control input if no preprocessed image is provided') - preprocessed_image: Optional[str] = Field(None, description='Optional pre-processed image that will bypass the control preprocessing step') - - -class BFLFluxDepthImageRequest(BaseModel): - prompt: str = Field(..., description='Text prompt for image generation') - prompt_upsampling: Optional[bool] = Field( - None, description='Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation.' - ) - seed: Optional[int] = Field(None, description='The seed value for reproducibility.') - steps: conint(ge=15, le=50) = Field(..., description='Number of steps for the image generation process') - guidance: confloat(ge=1, le=100) = Field(..., description='Guidance strength for the image generation process') - safety_tolerance: Optional[conint(ge=0, le=6)] = Field( - 6, description='Tolerance level for input and output moderation. Between 0 and 6, 0 being most strict, 6 being least strict. Defaults to 2.' - ) - output_format: Optional[BFLOutputFormat] = Field( - BFLOutputFormat.png, description="Output format for the generated image. Can be 'jpeg' or 'png'.", examples=['png'] - ) - control_image: Optional[str] = Field(None, description='Base64 encoded image to use as control input if no preprocessed image is provided') - preprocessed_image: Optional[str] = Field(None, description='Optional pre-processed image that will bypass the control preprocessing step') - - class BFLFluxProGenerateRequest(BaseModel): prompt: str = Field(..., description='The text prompt for image generation.') prompt_upsampling: Optional[bool] = Field( @@ -160,15 +122,8 @@ class BFLStatus(str, Enum): error = "Error" -class BFLFluxProStatusResponse(BaseModel): +class BFLFluxStatusResponse(BaseModel): id: str = Field(..., description="The unique identifier for the generation task.") status: BFLStatus = Field(..., description="The status of the task.") - result: Optional[Dict[str, Any]] = Field( - None, description="The result of the task (null if not completed)." - ) - progress: confloat(ge=0.0, le=1.0) = Field( - ..., description="The progress of the task (0.0 to 1.0)." - ) - details: Optional[Dict[str, Any]] = Field( - None, description="Additional details about the task (null if not available)." - ) + result: Optional[Dict[str, Any]] = Field(None, description="The result of the task (null if not completed).") + progress: Optional[float] = Field(None, description="The progress of the task (0.0 to 1.0).", ge=0.0, le=1.0) diff --git a/comfy_api_nodes/apis/veo_api.py b/comfy_api_nodes/apis/veo_api.py new file mode 100644 index 000000000..a55137afb --- /dev/null +++ b/comfy_api_nodes/apis/veo_api.py @@ -0,0 +1,111 @@ +from typing import Optional, Union +from enum import Enum + +from pydantic import BaseModel, Field + + +class Image2(BaseModel): + bytesBase64Encoded: str + gcsUri: Optional[str] = None + mimeType: Optional[str] = None + + +class Image3(BaseModel): + bytesBase64Encoded: Optional[str] = None + gcsUri: str + mimeType: Optional[str] = None + + +class Instance1(BaseModel): + image: Optional[Union[Image2, Image3]] = Field( + None, description='Optional image to guide video generation' + ) + prompt: str = Field(..., description='Text description of the video') + + +class PersonGeneration1(str, Enum): + ALLOW = 'ALLOW' + BLOCK = 'BLOCK' + + +class Parameters1(BaseModel): + aspectRatio: Optional[str] = Field(None, examples=['16:9']) + durationSeconds: Optional[int] = None + enhancePrompt: Optional[bool] = None + generateAudio: Optional[bool] = Field( + None, + description='Generate audio for the video. Only supported by veo 3 models.', + ) + negativePrompt: Optional[str] = None + personGeneration: Optional[PersonGeneration1] = None + sampleCount: Optional[int] = None + seed: Optional[int] = None + storageUri: Optional[str] = Field( + None, description='Optional Cloud Storage URI to upload the video' + ) + + +class VeoGenVidRequest(BaseModel): + instances: Optional[list[Instance1]] = None + parameters: Optional[Parameters1] = None + + +class VeoGenVidResponse(BaseModel): + name: str = Field( + ..., + description='Operation resource name', + examples=[ + 'projects/PROJECT_ID/locations/us-central1/publishers/google/models/MODEL_ID/operations/a1b07c8e-7b5a-4aba-bb34-3e1ccb8afcc8' + ], + ) + + +class VeoGenVidPollRequest(BaseModel): + operationName: str = Field( + ..., + description='Full operation name (from predict response)', + examples=[ + 'projects/PROJECT_ID/locations/us-central1/publishers/google/models/MODEL_ID/operations/OPERATION_ID' + ], + ) + + +class Video(BaseModel): + bytesBase64Encoded: Optional[str] = Field( + None, description='Base64-encoded video content' + ) + gcsUri: Optional[str] = Field(None, description='Cloud Storage URI of the video') + mimeType: Optional[str] = Field(None, description='Video MIME type') + + +class Error1(BaseModel): + code: Optional[int] = Field(None, description='Error code') + message: Optional[str] = Field(None, description='Error message') + + +class Response1(BaseModel): + field_type: Optional[str] = Field( + None, + alias='@type', + examples=[ + 'type.googleapis.com/cloud.ai.large_models.vision.GenerateVideoResponse' + ], + ) + raiMediaFilteredCount: Optional[int] = Field( + None, description='Count of media filtered by responsible AI policies' + ) + raiMediaFilteredReasons: Optional[list[str]] = Field( + None, description='Reasons why media was filtered by responsible AI policies' + ) + videos: Optional[list[Video]] = None + + +class VeoGenVidPollResponse(BaseModel): + done: Optional[bool] = None + error: Optional[Error1] = Field( + None, description='Error details if operation failed' + ) + name: Optional[str] = None + response: Optional[Response1] = Field( + None, description='The actual prediction response if done is true' + ) diff --git a/comfy_api_nodes/nodes_bfl.py b/comfy_api_nodes/nodes_bfl.py index b6cc90f05..baa74fd52 100644 --- a/comfy_api_nodes/nodes_bfl.py +++ b/comfy_api_nodes/nodes_bfl.py @@ -1,136 +1,43 @@ -import asyncio -import io from inspect import cleandoc -from typing import Union, Optional +from typing import Optional + +import torch from typing_extensions import override -from comfy_api.latest import ComfyExtension, IO + +from comfy_api.latest import IO, ComfyExtension +from comfy_api_nodes.apinode_utils import ( + resize_mask_to_image, + validate_aspect_ratio, +) from comfy_api_nodes.apis.bfl_api import ( - BFLStatus, BFLFluxExpandImageRequest, BFLFluxFillImageRequest, - BFLFluxCannyImageRequest, - BFLFluxDepthImageRequest, - BFLFluxProGenerateRequest, BFLFluxKontextProGenerateRequest, - BFLFluxProUltraGenerateRequest, + BFLFluxProGenerateRequest, BFLFluxProGenerateResponse, + BFLFluxProUltraGenerateRequest, + BFLFluxStatusResponse, + BFLStatus, ) -from comfy_api_nodes.apis.client import ( +from comfy_api_nodes.util import ( ApiEndpoint, - HttpMethod, - SynchronousOperation, -) -from comfy_api_nodes.apinode_utils import ( - downscale_image_tensor, - validate_aspect_ratio, - process_image_response, - resize_mask_to_image, + download_url_to_image_tensor, + poll_op, + sync_op, + tensor_to_base64_string, validate_string, ) -import numpy as np -from PIL import Image -import aiohttp -import torch -import base64 -import time -from server import PromptServer - def convert_mask_to_image(mask: torch.Tensor): """ Make mask have the expected amount of dims (4) and channels (3) to be recognized as an image. """ mask = mask.unsqueeze(-1) - mask = torch.cat([mask]*3, dim=-1) + mask = torch.cat([mask] * 3, dim=-1) return mask -async def handle_bfl_synchronous_operation( - operation: SynchronousOperation, - timeout_bfl_calls=360, - node_id: Union[str, None] = None, -): - response_api: BFLFluxProGenerateResponse = await operation.execute() - return await _poll_until_generated( - response_api.polling_url, timeout=timeout_bfl_calls, node_id=node_id - ) - - -async def _poll_until_generated( - polling_url: str, timeout=360, node_id: Union[str, None] = None -): - # used bfl-comfy-nodes to verify code implementation: - # https://github.com/black-forest-labs/bfl-comfy-nodes/tree/main - start_time = time.time() - retries_404 = 0 - max_retries_404 = 5 - retry_404_seconds = 2 - retry_202_seconds = 2 - retry_pending_seconds = 1 - - async with aiohttp.ClientSession() as session: - # NOTE: should True loop be replaced with checking if workflow has been interrupted? - while True: - if node_id: - time_elapsed = time.time() - start_time - PromptServer.instance.send_progress_text( - f"Generating ({time_elapsed:.0f}s)", node_id - ) - - async with session.get(polling_url) as response: - if response.status == 200: - result = await response.json() - if result["status"] == BFLStatus.ready: - img_url = result["result"]["sample"] - if node_id: - PromptServer.instance.send_progress_text( - f"Result URL: {img_url}", node_id - ) - async with session.get(img_url) as img_resp: - return process_image_response(await img_resp.content.read()) - elif result["status"] in [ - BFLStatus.request_moderated, - BFLStatus.content_moderated, - ]: - status = result["status"] - raise Exception( - f"BFL API did not return an image due to: {status}." - ) - elif result["status"] == BFLStatus.error: - raise Exception(f"BFL API encountered an error: {result}.") - elif result["status"] == BFLStatus.pending: - await asyncio.sleep(retry_pending_seconds) - continue - elif response.status == 404: - if retries_404 < max_retries_404: - retries_404 += 1 - await asyncio.sleep(retry_404_seconds) - continue - raise Exception( - f"BFL API could not find task after {max_retries_404} tries." - ) - elif response.status == 202: - await asyncio.sleep(retry_202_seconds) - elif time.time() - start_time > timeout: - raise Exception( - f"BFL API experienced a timeout; could not return request under {timeout} seconds." - ) - else: - raise Exception(f"BFL API encountered an error: {response.json()}") - -def convert_image_to_base64(image: torch.Tensor): - scaled_image = downscale_image_tensor(image, total_pixels=2048 * 2048) - # remove batch dimension if present - if len(scaled_image.shape) > 3: - scaled_image = scaled_image[0] - image_np = (scaled_image.numpy() * 255).astype(np.uint8) - img = Image.fromarray(image_np) - img_byte_arr = io.BytesIO() - img.save(img_byte_arr, format="PNG") - return base64.b64encode(img_byte_arr.getvalue()).decode() - - class FluxProUltraImageNode(IO.ComfyNode): """ Generates images using Flux Pro 1.1 Ultra via api based on prompt and resolution. @@ -158,7 +65,9 @@ class FluxProUltraImageNode(IO.ComfyNode): IO.Boolean.Input( "prompt_upsampling", default=False, - tooltip="Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation, but results are nondeterministic (same seed will not produce exactly the same result).", + tooltip="Whether to perform upsampling on the prompt. " + "If active, automatically modifies the prompt for more creative generation, " + "but results are nondeterministic (same seed will not produce exactly the same result).", ), IO.Int.Input( "seed", @@ -220,22 +129,19 @@ class FluxProUltraImageNode(IO.ComfyNode): cls, prompt: str, aspect_ratio: str, - prompt_upsampling=False, - raw=False, - seed=0, - image_prompt=None, - image_prompt_strength=0.1, + prompt_upsampling: bool = False, + raw: bool = False, + seed: int = 0, + image_prompt: Optional[torch.Tensor] = None, + image_prompt_strength: float = 0.1, ) -> IO.NodeOutput: if image_prompt is None: validate_string(prompt, strip_whitespace=False) - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/bfl/flux-pro-1.1-ultra/generate", - method=HttpMethod.POST, - request_model=BFLFluxProUltraGenerateRequest, - response_model=BFLFluxProGenerateResponse, - ), - request=BFLFluxProUltraGenerateRequest( + initial_response = await sync_op( + cls, + ApiEndpoint(path="/proxy/bfl/flux-pro-1.1-ultra/generate", method="POST"), + response_model=BFLFluxProGenerateResponse, + data=BFLFluxProUltraGenerateRequest( prompt=prompt, prompt_upsampling=prompt_upsampling, seed=seed, @@ -247,22 +153,26 @@ class FluxProUltraImageNode(IO.ComfyNode): maximum_ratio_str=cls.MAXIMUM_RATIO_STR, ), raw=raw, - image_prompt=( - image_prompt - if image_prompt is None - else convert_image_to_base64(image_prompt) - ), - image_prompt_strength=( - None if image_prompt is None else round(image_prompt_strength, 2) - ), + image_prompt=(image_prompt if image_prompt is None else tensor_to_base64_string(image_prompt)), + image_prompt_strength=(None if image_prompt is None else round(image_prompt_strength, 2)), ), - auth_kwargs={ - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - }, ) - output_image = await handle_bfl_synchronous_operation(operation, node_id=cls.hidden.unique_id) - return IO.NodeOutput(output_image) + response = await poll_op( + cls, + ApiEndpoint(initial_response.polling_url), + response_model=BFLFluxStatusResponse, + status_extractor=lambda r: r.status, + progress_extractor=lambda r: r.progress, + completed_statuses=[BFLStatus.ready], + failed_statuses=[ + BFLStatus.request_moderated, + BFLStatus.content_moderated, + BFLStatus.error, + BFLStatus.task_not_found, + ], + queued_statuses=[], + ) + return IO.NodeOutput(await download_url_to_image_tensor(response.result["sample"])) class FluxKontextProImageNode(IO.ComfyNode): @@ -347,7 +257,7 @@ class FluxKontextProImageNode(IO.ComfyNode): aspect_ratio: str, guidance: float, steps: int, - input_image: Optional[torch.Tensor]=None, + input_image: Optional[torch.Tensor] = None, seed=0, prompt_upsampling=False, ) -> IO.NodeOutput: @@ -360,33 +270,36 @@ class FluxKontextProImageNode(IO.ComfyNode): ) if input_image is None: validate_string(prompt, strip_whitespace=False) - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path=cls.BFL_PATH, - method=HttpMethod.POST, - request_model=BFLFluxKontextProGenerateRequest, - response_model=BFLFluxProGenerateResponse, - ), - request=BFLFluxKontextProGenerateRequest( + initial_response = await sync_op( + cls, + ApiEndpoint(path=cls.BFL_PATH, method="POST"), + response_model=BFLFluxProGenerateResponse, + data=BFLFluxKontextProGenerateRequest( prompt=prompt, prompt_upsampling=prompt_upsampling, guidance=round(guidance, 1), steps=steps, seed=seed, aspect_ratio=aspect_ratio, - input_image=( - input_image - if input_image is None - else convert_image_to_base64(input_image) - ) + input_image=(input_image if input_image is None else tensor_to_base64_string(input_image)), ), - auth_kwargs={ - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - }, ) - output_image = await handle_bfl_synchronous_operation(operation, node_id=cls.hidden.unique_id) - return IO.NodeOutput(output_image) + response = await poll_op( + cls, + ApiEndpoint(initial_response.polling_url), + response_model=BFLFluxStatusResponse, + status_extractor=lambda r: r.status, + progress_extractor=lambda r: r.progress, + completed_statuses=[BFLStatus.ready], + failed_statuses=[ + BFLStatus.request_moderated, + BFLStatus.content_moderated, + BFLStatus.error, + BFLStatus.task_not_found, + ], + queued_statuses=[], + ) + return IO.NodeOutput(await download_url_to_image_tensor(response.result["sample"])) class FluxKontextMaxImageNode(FluxKontextProImageNode): @@ -422,7 +335,9 @@ class FluxProImageNode(IO.ComfyNode): IO.Boolean.Input( "prompt_upsampling", default=False, - tooltip="Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation, but results are nondeterministic (same seed will not produce exactly the same result).", + tooltip="Whether to perform upsampling on the prompt. " + "If active, automatically modifies the prompt for more creative generation, " + "but results are nondeterministic (same seed will not produce exactly the same result).", ), IO.Int.Input( "width", @@ -481,20 +396,15 @@ class FluxProImageNode(IO.ComfyNode): image_prompt=None, # image_prompt_strength=0.1, ) -> IO.NodeOutput: - image_prompt = ( - image_prompt - if image_prompt is None - else convert_image_to_base64(image_prompt) - ) - - operation = SynchronousOperation( - endpoint=ApiEndpoint( + image_prompt = image_prompt if image_prompt is None else tensor_to_base64_string(image_prompt) + initial_response = await sync_op( + cls, + ApiEndpoint( path="/proxy/bfl/flux-pro-1.1/generate", - method=HttpMethod.POST, - request_model=BFLFluxProGenerateRequest, - response_model=BFLFluxProGenerateResponse, + method="POST", ), - request=BFLFluxProGenerateRequest( + response_model=BFLFluxProGenerateResponse, + data=BFLFluxProGenerateRequest( prompt=prompt, prompt_upsampling=prompt_upsampling, width=width, @@ -502,13 +412,23 @@ class FluxProImageNode(IO.ComfyNode): seed=seed, image_prompt=image_prompt, ), - auth_kwargs={ - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - }, ) - output_image = await handle_bfl_synchronous_operation(operation, node_id=cls.hidden.unique_id) - return IO.NodeOutput(output_image) + response = await poll_op( + cls, + ApiEndpoint(initial_response.polling_url), + response_model=BFLFluxStatusResponse, + status_extractor=lambda r: r.status, + progress_extractor=lambda r: r.progress, + completed_statuses=[BFLStatus.ready], + failed_statuses=[ + BFLStatus.request_moderated, + BFLStatus.content_moderated, + BFLStatus.error, + BFLStatus.task_not_found, + ], + queued_statuses=[], + ) + return IO.NodeOutput(await download_url_to_image_tensor(response.result["sample"])) class FluxProExpandNode(IO.ComfyNode): @@ -534,7 +454,9 @@ class FluxProExpandNode(IO.ComfyNode): IO.Boolean.Input( "prompt_upsampling", default=False, - tooltip="Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation, but results are nondeterministic (same seed will not produce exactly the same result).", + tooltip="Whether to perform upsampling on the prompt. " + "If active, automatically modifies the prompt for more creative generation, " + "but results are nondeterministic (same seed will not produce exactly the same result).", ), IO.Int.Input( "top", @@ -610,16 +532,11 @@ class FluxProExpandNode(IO.ComfyNode): guidance: float, seed=0, ) -> IO.NodeOutput: - image = convert_image_to_base64(image) - - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/bfl/flux-pro-1.0-expand/generate", - method=HttpMethod.POST, - request_model=BFLFluxExpandImageRequest, - response_model=BFLFluxProGenerateResponse, - ), - request=BFLFluxExpandImageRequest( + initial_response = await sync_op( + cls, + ApiEndpoint(path="/proxy/bfl/flux-pro-1.0-expand/generate", method="POST"), + response_model=BFLFluxProGenerateResponse, + data=BFLFluxExpandImageRequest( prompt=prompt, prompt_upsampling=prompt_upsampling, top=top, @@ -629,16 +546,25 @@ class FluxProExpandNode(IO.ComfyNode): steps=steps, guidance=guidance, seed=seed, - image=image, + image=tensor_to_base64_string(image), ), - auth_kwargs={ - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - }, ) - output_image = await handle_bfl_synchronous_operation(operation, node_id=cls.hidden.unique_id) - return IO.NodeOutput(output_image) - + response = await poll_op( + cls, + ApiEndpoint(initial_response.polling_url), + response_model=BFLFluxStatusResponse, + status_extractor=lambda r: r.status, + progress_extractor=lambda r: r.progress, + completed_statuses=[BFLStatus.ready], + failed_statuses=[ + BFLStatus.request_moderated, + BFLStatus.content_moderated, + BFLStatus.error, + BFLStatus.task_not_found, + ], + queued_statuses=[], + ) + return IO.NodeOutput(await download_url_to_image_tensor(response.result["sample"])) class FluxProFillNode(IO.ComfyNode): @@ -665,7 +591,9 @@ class FluxProFillNode(IO.ComfyNode): IO.Boolean.Input( "prompt_upsampling", default=False, - tooltip="Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation, but results are nondeterministic (same seed will not produce exactly the same result).", + tooltip="Whether to perform upsampling on the prompt. " + "If active, automatically modifies the prompt for more creative generation, " + "but results are nondeterministic (same seed will not produce exactly the same result).", ), IO.Float.Input( "guidance", @@ -712,272 +640,37 @@ class FluxProFillNode(IO.ComfyNode): ) -> IO.NodeOutput: # prepare mask mask = resize_mask_to_image(mask, image) - mask = convert_image_to_base64(convert_mask_to_image(mask)) - # make sure image will have alpha channel removed - image = convert_image_to_base64(image[:, :, :, :3]) - - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/bfl/flux-pro-1.0-fill/generate", - method=HttpMethod.POST, - request_model=BFLFluxFillImageRequest, - response_model=BFLFluxProGenerateResponse, - ), - request=BFLFluxFillImageRequest( + mask = tensor_to_base64_string(convert_mask_to_image(mask)) + initial_response = await sync_op( + cls, + ApiEndpoint(path="/proxy/bfl/flux-pro-1.0-fill/generate", method="POST"), + response_model=BFLFluxProGenerateResponse, + data=BFLFluxFillImageRequest( prompt=prompt, prompt_upsampling=prompt_upsampling, steps=steps, guidance=guidance, seed=seed, - image=image, + image=tensor_to_base64_string(image[:, :, :, :3]), # make sure image will have alpha channel removed mask=mask, ), - auth_kwargs={ - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - }, ) - output_image = await handle_bfl_synchronous_operation(operation, node_id=cls.hidden.unique_id) - return IO.NodeOutput(output_image) - - -class FluxProCannyNode(IO.ComfyNode): - """ - Generate image using a control image (canny). - """ - - @classmethod - def define_schema(cls) -> IO.Schema: - return IO.Schema( - node_id="FluxProCannyNode", - display_name="Flux.1 Canny Control Image", - category="api node/image/BFL", - description=cleandoc(cls.__doc__ or ""), - inputs=[ - IO.Image.Input("control_image"), - IO.String.Input( - "prompt", - multiline=True, - default="", - tooltip="Prompt for the image generation", - ), - IO.Boolean.Input( - "prompt_upsampling", - default=False, - tooltip="Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation, but results are nondeterministic (same seed will not produce exactly the same result).", - ), - IO.Float.Input( - "canny_low_threshold", - default=0.1, - min=0.01, - max=0.99, - step=0.01, - tooltip="Low threshold for Canny edge detection; ignored if skip_processing is True", - ), - IO.Float.Input( - "canny_high_threshold", - default=0.4, - min=0.01, - max=0.99, - step=0.01, - tooltip="High threshold for Canny edge detection; ignored if skip_processing is True", - ), - IO.Boolean.Input( - "skip_preprocessing", - default=False, - tooltip="Whether to skip preprocessing; set to True if control_image already is canny-fied, False if it is a raw image.", - ), - IO.Float.Input( - "guidance", - default=30, - min=1, - max=100, - tooltip="Guidance strength for the image generation process", - ), - IO.Int.Input( - "steps", - default=50, - min=15, - max=50, - tooltip="Number of steps for the image generation process", - ), - IO.Int.Input( - "seed", - default=0, - min=0, - max=0xFFFFFFFFFFFFFFFF, - control_after_generate=True, - tooltip="The random seed used for creating the noise.", - ), + response = await poll_op( + cls, + ApiEndpoint(initial_response.polling_url), + response_model=BFLFluxStatusResponse, + status_extractor=lambda r: r.status, + progress_extractor=lambda r: r.progress, + completed_statuses=[BFLStatus.ready], + failed_statuses=[ + BFLStatus.request_moderated, + BFLStatus.content_moderated, + BFLStatus.error, + BFLStatus.task_not_found, ], - outputs=[IO.Image.Output()], - hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, - IO.Hidden.unique_id, - ], - is_api_node=True, + queued_statuses=[], ) - - @classmethod - async def execute( - cls, - control_image: torch.Tensor, - prompt: str, - prompt_upsampling: bool, - canny_low_threshold: float, - canny_high_threshold: float, - skip_preprocessing: bool, - steps: int, - guidance: float, - seed=0, - ) -> IO.NodeOutput: - control_image = convert_image_to_base64(control_image[:, :, :, :3]) - preprocessed_image = None - - # scale canny threshold between 0-500, to match BFL's API - def scale_value(value: float, min_val=0, max_val=500): - return min_val + value * (max_val - min_val) - canny_low_threshold = int(round(scale_value(canny_low_threshold))) - canny_high_threshold = int(round(scale_value(canny_high_threshold))) - - - if skip_preprocessing: - preprocessed_image = control_image - control_image = None - canny_low_threshold = None - canny_high_threshold = None - - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/bfl/flux-pro-1.0-canny/generate", - method=HttpMethod.POST, - request_model=BFLFluxCannyImageRequest, - response_model=BFLFluxProGenerateResponse, - ), - request=BFLFluxCannyImageRequest( - prompt=prompt, - prompt_upsampling=prompt_upsampling, - steps=steps, - guidance=guidance, - seed=seed, - control_image=control_image, - canny_low_threshold=canny_low_threshold, - canny_high_threshold=canny_high_threshold, - preprocessed_image=preprocessed_image, - ), - auth_kwargs={ - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - }, - ) - output_image = await handle_bfl_synchronous_operation(operation, node_id=cls.hidden.unique_id) - return IO.NodeOutput(output_image) - - -class FluxProDepthNode(IO.ComfyNode): - """ - Generate image using a control image (depth). - """ - - @classmethod - def define_schema(cls) -> IO.Schema: - return IO.Schema( - node_id="FluxProDepthNode", - display_name="Flux.1 Depth Control Image", - category="api node/image/BFL", - description=cleandoc(cls.__doc__ or ""), - inputs=[ - IO.Image.Input("control_image"), - IO.String.Input( - "prompt", - multiline=True, - default="", - tooltip="Prompt for the image generation", - ), - IO.Boolean.Input( - "prompt_upsampling", - default=False, - tooltip="Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation, but results are nondeterministic (same seed will not produce exactly the same result).", - ), - IO.Boolean.Input( - "skip_preprocessing", - default=False, - tooltip="Whether to skip preprocessing; set to True if control_image already is depth-ified, False if it is a raw image.", - ), - IO.Float.Input( - "guidance", - default=15, - min=1, - max=100, - tooltip="Guidance strength for the image generation process", - ), - IO.Int.Input( - "steps", - default=50, - min=15, - max=50, - tooltip="Number of steps for the image generation process", - ), - IO.Int.Input( - "seed", - default=0, - min=0, - max=0xFFFFFFFFFFFFFFFF, - control_after_generate=True, - tooltip="The random seed used for creating the noise.", - ), - ], - outputs=[IO.Image.Output()], - hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, - IO.Hidden.unique_id, - ], - is_api_node=True, - ) - - @classmethod - async def execute( - cls, - control_image: torch.Tensor, - prompt: str, - prompt_upsampling: bool, - skip_preprocessing: bool, - steps: int, - guidance: float, - seed=0, - ) -> IO.NodeOutput: - control_image = convert_image_to_base64(control_image[:,:,:,:3]) - preprocessed_image = None - - if skip_preprocessing: - preprocessed_image = control_image - control_image = None - - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/bfl/flux-pro-1.0-depth/generate", - method=HttpMethod.POST, - request_model=BFLFluxDepthImageRequest, - response_model=BFLFluxProGenerateResponse, - ), - request=BFLFluxDepthImageRequest( - prompt=prompt, - prompt_upsampling=prompt_upsampling, - steps=steps, - guidance=guidance, - seed=seed, - control_image=control_image, - preprocessed_image=preprocessed_image, - ), - auth_kwargs={ - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - }, - ) - output_image = await handle_bfl_synchronous_operation(operation, node_id=cls.hidden.unique_id) - return IO.NodeOutput(output_image) + return IO.NodeOutput(await download_url_to_image_tensor(response.result["sample"])) class BFLExtension(ComfyExtension): @@ -990,8 +683,6 @@ class BFLExtension(ComfyExtension): FluxKontextMaxImageNode, FluxProExpandNode, FluxProFillNode, - FluxProCannyNode, - FluxProDepthNode, ] diff --git a/comfy_api_nodes/nodes_bytedance.py b/comfy_api_nodes/nodes_bytedance.py index f3d3f8d3e..534af380d 100644 --- a/comfy_api_nodes/nodes_bytedance.py +++ b/comfy_api_nodes/nodes_bytedance.py @@ -1,35 +1,27 @@ import logging import math from enum import Enum -from typing import Literal, Optional, Type, Union -from typing_extensions import override +from typing import Literal, Optional, Union import torch from pydantic import BaseModel, Field +from typing_extensions import override -from comfy_api.latest import ComfyExtension, IO -from comfy_api_nodes.util.validation_utils import ( - validate_image_aspect_ratio_range, - get_number_of_images, - validate_image_dimensions, -) -from comfy_api_nodes.apis.client import ( +from comfy_api.latest import IO, ComfyExtension +from comfy_api_nodes.util import ( ApiEndpoint, - EmptyRequest, - HttpMethod, - SynchronousOperation, - PollingOperation, - T, -) -from comfy_api_nodes.apinode_utils import ( download_url_to_image_tensor, download_url_to_video_output, - upload_images_to_comfyapi, - validate_string, + get_number_of_images, image_tensor_pair_to_batch, + poll_op, + sync_op, + upload_images_to_comfyapi, + validate_image_aspect_ratio_range, + validate_image_dimensions, + validate_string, ) - BYTEPLUS_IMAGE_ENDPOINT = "/proxy/byteplus/api/v3/images/generations" # Long-running tasks endpoints(e.g., video) @@ -46,13 +38,14 @@ class Image2ImageModelName(str, Enum): class Text2VideoModelName(str, Enum): - seedance_1_pro = "seedance-1-0-pro-250528" + seedance_1_pro = "seedance-1-0-pro-250528" seedance_1_lite = "seedance-1-0-lite-t2v-250428" class Image2VideoModelName(str, Enum): """note(August 31): Pro model only supports FirstFrame: https://docs.byteplus.com/en/docs/ModelArk/1520757""" - seedance_1_pro = "seedance-1-0-pro-250528" + + seedance_1_pro = "seedance-1-0-pro-250528" seedance_1_lite = "seedance-1-0-lite-i2v-250428" @@ -208,35 +201,6 @@ def get_video_url_from_task_status(response: TaskStatusResponse) -> Union[str, N return None -async def poll_until_finished( - auth_kwargs: dict[str, str], - task_id: str, - estimated_duration: Optional[int] = None, - node_id: Optional[str] = None, -) -> TaskStatusResponse: - """Polls the ByteDance API endpoint until the task reaches a terminal state, then returns the response.""" - return await PollingOperation( - poll_endpoint=ApiEndpoint( - path=f"{BYTEPLUS_TASK_STATUS_ENDPOINT}/{task_id}", - method=HttpMethod.GET, - request_model=EmptyRequest, - response_model=TaskStatusResponse, - ), - completed_statuses=[ - "succeeded", - ], - failed_statuses=[ - "cancelled", - "failed", - ], - status_extractor=lambda response: response.status, - auth_kwargs=auth_kwargs, - result_url_extractor=get_video_url_from_task_status, - estimated_duration=estimated_duration, - node_id=node_id, - ).execute() - - class ByteDanceImageNode(IO.ComfyNode): @classmethod @@ -303,7 +267,7 @@ class ByteDanceImageNode(IO.ComfyNode): IO.Boolean.Input( "watermark", default=True, - tooltip="Whether to add an \"AI generated\" watermark to the image", + tooltip='Whether to add an "AI generated" watermark to the image', optional=True, ), ], @@ -341,8 +305,7 @@ class ByteDanceImageNode(IO.ComfyNode): w, h = width, height if not (512 <= w <= 2048) or not (512 <= h <= 2048): raise ValueError( - f"Custom size out of range: {w}x{h}. " - "Both width and height must be between 512 and 2048 pixels." + f"Custom size out of range: {w}x{h}. " "Both width and height must be between 512 and 2048 pixels." ) payload = Text2ImageTaskCreationRequest( @@ -353,20 +316,12 @@ class ByteDanceImageNode(IO.ComfyNode): guidance_scale=guidance_scale, watermark=watermark, ) - auth_kwargs = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } - response = await SynchronousOperation( - endpoint=ApiEndpoint( - path=BYTEPLUS_IMAGE_ENDPOINT, - method=HttpMethod.POST, - request_model=Text2ImageTaskCreationRequest, - response_model=ImageTaskCreationResponse, - ), - request=payload, - auth_kwargs=auth_kwargs, - ).execute() + response = await sync_op( + cls, + ApiEndpoint(path=BYTEPLUS_IMAGE_ENDPOINT, method="POST"), + data=payload, + response_model=ImageTaskCreationResponse, + ) return IO.NodeOutput(await download_url_to_image_tensor(get_image_url_from_response(response))) @@ -420,7 +375,7 @@ class ByteDanceImageEditNode(IO.ComfyNode): IO.Boolean.Input( "watermark", default=True, - tooltip="Whether to add an \"AI generated\" watermark to the image", + tooltip='Whether to add an "AI generated" watermark to the image', optional=True, ), ], @@ -449,16 +404,7 @@ class ByteDanceImageEditNode(IO.ComfyNode): if get_number_of_images(image) != 1: raise ValueError("Exactly one input image is required.") validate_image_aspect_ratio_range(image, (1, 3), (3, 1)) - auth_kwargs = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } - source_url = (await upload_images_to_comfyapi( - image, - max_images=1, - mime_type="image/png", - auth_kwargs=auth_kwargs, - ))[0] + source_url = (await upload_images_to_comfyapi(cls, image, max_images=1, mime_type="image/png"))[0] payload = Image2ImageTaskCreationRequest( model=model, prompt=prompt, @@ -467,16 +413,12 @@ class ByteDanceImageEditNode(IO.ComfyNode): guidance_scale=guidance_scale, watermark=watermark, ) - response = await SynchronousOperation( - endpoint=ApiEndpoint( - path=BYTEPLUS_IMAGE_ENDPOINT, - method=HttpMethod.POST, - request_model=Image2ImageTaskCreationRequest, - response_model=ImageTaskCreationResponse, - ), - request=payload, - auth_kwargs=auth_kwargs, - ).execute() + response = await sync_op( + cls, + ApiEndpoint(path=BYTEPLUS_IMAGE_ENDPOINT, method="POST"), + data=payload, + response_model=ImageTaskCreationResponse, + ) return IO.NodeOutput(await download_url_to_image_tensor(get_image_url_from_response(response))) @@ -504,7 +446,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode): IO.Image.Input( "image", tooltip="Input image(s) for image-to-image generation. " - "List of 1-10 images for single or multi-reference generation.", + "List of 1-10 images for single or multi-reference generation.", optional=True, ), IO.Combo.Input( @@ -534,9 +476,9 @@ class ByteDanceSeedreamNode(IO.ComfyNode): "sequential_image_generation", options=["disabled", "auto"], tooltip="Group image generation mode. " - "'disabled' generates a single image. " - "'auto' lets the model decide whether to generate multiple related images " - "(e.g., story scenes, character variations).", + "'disabled' generates a single image. " + "'auto' lets the model decide whether to generate multiple related images " + "(e.g., story scenes, character variations).", optional=True, ), IO.Int.Input( @@ -547,7 +489,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode): step=1, display_mode=IO.NumberDisplay.number, tooltip="Maximum number of images to generate when sequential_image_generation='auto'. " - "Total images (input + generated) cannot exceed 15.", + "Total images (input + generated) cannot exceed 15.", optional=True, ), IO.Int.Input( @@ -564,7 +506,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode): IO.Boolean.Input( "watermark", default=True, - tooltip="Whether to add an \"AI generated\" watermark to the image.", + tooltip='Whether to add an "AI generated" watermark to the image.', optional=True, ), IO.Boolean.Input( @@ -611,8 +553,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode): w, h = width, height if not (1024 <= w <= 4096) or not (1024 <= h <= 4096): raise ValueError( - f"Custom size out of range: {w}x{h}. " - "Both width and height must be between 1024 and 4096 pixels." + f"Custom size out of range: {w}x{h}. " "Both width and height must be between 1024 and 4096 pixels." ) n_input_images = get_number_of_images(image) if image is not None else 0 if n_input_images > 10: @@ -621,41 +562,31 @@ class ByteDanceSeedreamNode(IO.ComfyNode): raise ValueError( "The maximum number of generated images plus the number of reference images cannot exceed 15." ) - auth_kwargs = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } reference_images_urls = [] if n_input_images: for i in image: validate_image_aspect_ratio_range(i, (1, 3), (3, 1)) - reference_images_urls = (await upload_images_to_comfyapi( + reference_images_urls = await upload_images_to_comfyapi( + cls, image, max_images=n_input_images, mime_type="image/png", - auth_kwargs=auth_kwargs, - )) - payload = Seedream4TaskCreationRequest( - model=model, - prompt=prompt, - image=reference_images_urls, - size=f"{w}x{h}", - seed=seed, - sequential_image_generation=sequential_image_generation, - sequential_image_generation_options=Seedream4Options(max_images=max_images), - watermark=watermark, - ) - response = await SynchronousOperation( - endpoint=ApiEndpoint( - path=BYTEPLUS_IMAGE_ENDPOINT, - method=HttpMethod.POST, - request_model=Seedream4TaskCreationRequest, - response_model=ImageTaskCreationResponse, + ) + response = await sync_op( + cls, + ApiEndpoint(path=BYTEPLUS_IMAGE_ENDPOINT, method="POST"), + response_model=ImageTaskCreationResponse, + data=Seedream4TaskCreationRequest( + model=model, + prompt=prompt, + image=reference_images_urls, + size=f"{w}x{h}", + seed=seed, + sequential_image_generation=sequential_image_generation, + sequential_image_generation_options=Seedream4Options(max_images=max_images), + watermark=watermark, ), - request=payload, - auth_kwargs=auth_kwargs, - ).execute() - + ) if len(response.data) == 1: return IO.NodeOutput(await download_url_to_image_tensor(get_image_url_from_response(response))) urls = [str(d["url"]) for d in response.data if isinstance(d, dict) and "url" in d] @@ -719,13 +650,13 @@ class ByteDanceTextToVideoNode(IO.ComfyNode): "camera_fixed", default=False, tooltip="Specifies whether to fix the camera. The platform appends an instruction " - "to fix the camera to your prompt, but does not guarantee the actual effect.", + "to fix the camera to your prompt, but does not guarantee the actual effect.", optional=True, ), IO.Boolean.Input( "watermark", default=True, - tooltip="Whether to add an \"AI generated\" watermark to the video.", + tooltip='Whether to add an "AI generated" watermark to the video.', optional=True, ), ], @@ -764,19 +695,9 @@ class ByteDanceTextToVideoNode(IO.ComfyNode): f"--camerafixed {str(camera_fixed).lower()} " f"--watermark {str(watermark).lower()}" ) - - auth_kwargs = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } return await process_video_task( - request_model=Text2VideoTaskCreationRequest, - payload=Text2VideoTaskCreationRequest( - model=model, - content=[TaskTextContent(text=prompt)], - ), - auth_kwargs=auth_kwargs, - node_id=cls.hidden.unique_id, + cls, + payload=Text2VideoTaskCreationRequest(model=model, content=[TaskTextContent(text=prompt)]), estimated_duration=max(1, math.ceil(VIDEO_TASKS_EXECUTION_TIME[model][resolution] * (duration / 10.0))), ) @@ -840,13 +761,13 @@ class ByteDanceImageToVideoNode(IO.ComfyNode): "camera_fixed", default=False, tooltip="Specifies whether to fix the camera. The platform appends an instruction " - "to fix the camera to your prompt, but does not guarantee the actual effect.", + "to fix the camera to your prompt, but does not guarantee the actual effect.", optional=True, ), IO.Boolean.Input( "watermark", default=True, - tooltip="Whether to add an \"AI generated\" watermark to the video.", + tooltip='Whether to add an "AI generated" watermark to the video.', optional=True, ), ], @@ -879,13 +800,7 @@ class ByteDanceImageToVideoNode(IO.ComfyNode): validate_image_dimensions(image, min_width=300, min_height=300, max_width=6000, max_height=6000) validate_image_aspect_ratio_range(image, (2, 5), (5, 2), strict=False) # 0.4 to 2.5 - auth_kwargs = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } - - image_url = (await upload_images_to_comfyapi(image, max_images=1, auth_kwargs=auth_kwargs))[0] - + image_url = (await upload_images_to_comfyapi(cls, image, max_images=1))[0] prompt = ( f"{prompt} " f"--resolution {resolution} " @@ -897,13 +812,11 @@ class ByteDanceImageToVideoNode(IO.ComfyNode): ) return await process_video_task( - request_model=Image2VideoTaskCreationRequest, + cls, payload=Image2VideoTaskCreationRequest( model=model, content=[TaskTextContent(text=prompt), TaskImageContent(image_url=TaskImageContentUrl(url=image_url))], ), - auth_kwargs=auth_kwargs, - node_id=cls.hidden.unique_id, estimated_duration=max(1, math.ceil(VIDEO_TASKS_EXECUTION_TIME[model][resolution] * (duration / 10.0))), ) @@ -971,13 +884,13 @@ class ByteDanceFirstLastFrameNode(IO.ComfyNode): "camera_fixed", default=False, tooltip="Specifies whether to fix the camera. The platform appends an instruction " - "to fix the camera to your prompt, but does not guarantee the actual effect.", + "to fix the camera to your prompt, but does not guarantee the actual effect.", optional=True, ), IO.Boolean.Input( "watermark", default=True, - tooltip="Whether to add an \"AI generated\" watermark to the video.", + tooltip='Whether to add an "AI generated" watermark to the video.', optional=True, ), ], @@ -1012,16 +925,11 @@ class ByteDanceFirstLastFrameNode(IO.ComfyNode): validate_image_dimensions(i, min_width=300, min_height=300, max_width=6000, max_height=6000) validate_image_aspect_ratio_range(i, (2, 5), (5, 2), strict=False) # 0.4 to 2.5 - auth_kwargs = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } - download_urls = await upload_images_to_comfyapi( + cls, image_tensor_pair_to_batch(first_frame, last_frame), max_images=2, mime_type="image/png", - auth_kwargs=auth_kwargs, ) prompt = ( @@ -1035,7 +943,7 @@ class ByteDanceFirstLastFrameNode(IO.ComfyNode): ) return await process_video_task( - request_model=Image2VideoTaskCreationRequest, + cls, payload=Image2VideoTaskCreationRequest( model=model, content=[ @@ -1044,8 +952,6 @@ class ByteDanceFirstLastFrameNode(IO.ComfyNode): TaskImageContent(image_url=TaskImageContentUrl(url=str(download_urls[1])), role="last_frame"), ], ), - auth_kwargs=auth_kwargs, - node_id=cls.hidden.unique_id, estimated_duration=max(1, math.ceil(VIDEO_TASKS_EXECUTION_TIME[model][resolution] * (duration / 10.0))), ) @@ -1108,7 +1014,7 @@ class ByteDanceImageReferenceNode(IO.ComfyNode): IO.Boolean.Input( "watermark", default=True, - tooltip="Whether to add an \"AI generated\" watermark to the video.", + tooltip='Whether to add an "AI generated" watermark to the video.', optional=True, ), ], @@ -1141,15 +1047,7 @@ class ByteDanceImageReferenceNode(IO.ComfyNode): validate_image_dimensions(image, min_width=300, min_height=300, max_width=6000, max_height=6000) validate_image_aspect_ratio_range(image, (2, 5), (5, 2), strict=False) # 0.4 to 2.5 - auth_kwargs = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } - - image_urls = await upload_images_to_comfyapi( - images, max_images=4, mime_type="image/png", auth_kwargs=auth_kwargs - ) - + image_urls = await upload_images_to_comfyapi(cls, images, max_images=4, mime_type="image/png") prompt = ( f"{prompt} " f"--resolution {resolution} " @@ -1160,42 +1058,32 @@ class ByteDanceImageReferenceNode(IO.ComfyNode): ) x = [ TaskTextContent(text=prompt), - *[TaskImageContent(image_url=TaskImageContentUrl(url=str(i)), role="reference_image") for i in image_urls] + *[TaskImageContent(image_url=TaskImageContentUrl(url=str(i)), role="reference_image") for i in image_urls], ] return await process_video_task( - request_model=Image2VideoTaskCreationRequest, - payload=Image2VideoTaskCreationRequest( - model=model, - content=x, - ), - auth_kwargs=auth_kwargs, - node_id=cls.hidden.unique_id, + cls, + payload=Image2VideoTaskCreationRequest(model=model, content=x), estimated_duration=max(1, math.ceil(VIDEO_TASKS_EXECUTION_TIME[model][resolution] * (duration / 10.0))), ) async def process_video_task( - request_model: Type[T], + cls: type[IO.ComfyNode], payload: Union[Text2VideoTaskCreationRequest, Image2VideoTaskCreationRequest], - auth_kwargs: dict, - node_id: str, estimated_duration: Optional[int], ) -> IO.NodeOutput: - initial_response = await SynchronousOperation( - endpoint=ApiEndpoint( - path=BYTEPLUS_TASK_ENDPOINT, - method=HttpMethod.POST, - request_model=request_model, - response_model=TaskCreationResponse, - ), - request=payload, - auth_kwargs=auth_kwargs, - ).execute() - response = await poll_until_finished( - auth_kwargs, - initial_response.id, + initial_response = await sync_op( + cls, + ApiEndpoint(path=BYTEPLUS_TASK_ENDPOINT, method="POST"), + data=payload, + response_model=TaskCreationResponse, + ) + response = await poll_op( + cls, + ApiEndpoint(path=f"{BYTEPLUS_TASK_STATUS_ENDPOINT}/{initial_response.id}"), + status_extractor=lambda r: r.status, estimated_duration=estimated_duration, - node_id=node_id, + response_model=TaskStatusResponse, ) return IO.NodeOutput(await download_url_to_video_output(get_video_url_from_task_status(response))) @@ -1221,5 +1109,6 @@ class ByteDanceExtension(ComfyExtension): ByteDanceImageReferenceNode, ] + async def comfy_entrypoint() -> ByteDanceExtension: return ByteDanceExtension() diff --git a/comfy_api_nodes/nodes_gemini.py b/comfy_api_nodes/nodes_gemini.py index c1941cbe9..ca11b67ed 100644 --- a/comfy_api_nodes/nodes_gemini.py +++ b/comfy_api_nodes/nodes_gemini.py @@ -33,12 +33,9 @@ from comfy_api_nodes.apis.client import ( SynchronousOperation, ) from comfy_api_nodes.apinode_utils import ( - validate_string, - audio_to_base64_string, video_to_base64_string, - tensor_to_base64_string, - bytesio_to_image_tensor, ) +from comfy_api_nodes.util import validate_string, tensor_to_base64_string, bytesio_to_image_tensor, audio_to_base64_string from comfy_api.util import VideoContainer, VideoCodec diff --git a/comfy_api_nodes/nodes_kling.py b/comfy_api_nodes/nodes_kling.py index 67c8307c5..eea65c9ac 100644 --- a/comfy_api_nodes/nodes_kling.py +++ b/comfy_api_nodes/nodes_kling.py @@ -5,8 +5,7 @@ For source of truth on the allowed permutations of request fields, please refere """ from __future__ import annotations -from typing import Optional, TypeVar, Any -from collections.abc import Callable +from typing import Optional, TypeVar import math import logging @@ -15,7 +14,6 @@ from typing_extensions import override import torch from comfy_api_nodes.apis import ( - KlingTaskStatus, KlingCameraControl, KlingCameraConfig, KlingCameraControlType, @@ -52,26 +50,20 @@ from comfy_api_nodes.apis import ( KlingCharacterEffectModelName, KlingSingleImageEffectModelName, ) -from comfy_api_nodes.apis.client import ( - ApiEndpoint, - HttpMethod, - SynchronousOperation, - PollingOperation, - EmptyRequest, -) -from comfy_api_nodes.apinode_utils import ( - tensor_to_base64_string, - download_url_to_video_output, - upload_video_to_comfyapi, - upload_audio_to_comfyapi, - download_url_to_image_tensor, - validate_string, -) -from comfy_api_nodes.util.validation_utils import ( +from comfy_api_nodes.util import ( validate_image_dimensions, validate_image_aspect_ratio, validate_video_dimensions, validate_video_duration, + tensor_to_base64_string, + validate_string, + upload_audio_to_comfyapi, + download_url_to_image_tensor, + upload_video_to_comfyapi, + download_url_to_video_output, + sync_op, + ApiEndpoint, + poll_op, ) from comfy_api.input_impl import VideoFromFile from comfy_api.input.basic_types import AudioInput @@ -214,34 +206,6 @@ VOICES_CONFIG = { } -async def poll_until_finished( - auth_kwargs: dict[str, str], - api_endpoint: ApiEndpoint[Any, R], - result_url_extractor: Optional[Callable[[R], str]] = None, - estimated_duration: Optional[int] = None, - node_id: Optional[str] = None, -) -> R: - """Polls the Kling API endpoint until the task reaches a terminal state, then returns the response.""" - return await PollingOperation( - poll_endpoint=api_endpoint, - completed_statuses=[ - KlingTaskStatus.succeed.value, - ], - failed_statuses=[KlingTaskStatus.failed.value], - status_extractor=lambda response: ( - response.data.task_status.value - if response.data and response.data.task_status - else None - ), - auth_kwargs=auth_kwargs, - result_url_extractor=result_url_extractor, - estimated_duration=estimated_duration, - node_id=node_id, - poll_interval=16.0, - max_poll_attempts=256, - ).execute() - - def is_valid_camera_control_configs(configs: list[float]) -> bool: """Verifies that at least one camera control configuration is non-zero.""" return any(not math.isclose(value, 0.0) for value in configs) @@ -377,8 +341,7 @@ async def image_result_to_node_output( async def execute_text2video( - auth_kwargs: dict[str, str], - node_id: str, + cls: type[IO.ComfyNode], prompt: str, negative_prompt: str, cfg_scale: float, @@ -389,14 +352,11 @@ async def execute_text2video( camera_control: Optional[KlingCameraControl] = None, ) -> IO.NodeOutput: validate_prompts(prompt, negative_prompt, MAX_PROMPT_LENGTH_T2V) - initial_operation = SynchronousOperation( - endpoint=ApiEndpoint( - path=PATH_TEXT_TO_VIDEO, - method=HttpMethod.POST, - request_model=KlingText2VideoRequest, - response_model=KlingText2VideoResponse, - ), - request=KlingText2VideoRequest( + task_creation_response = await sync_op( + cls, + ApiEndpoint(path=PATH_TEXT_TO_VIDEO, method="POST"), + response_model=KlingText2VideoResponse, + data=KlingText2VideoRequest( prompt=prompt if prompt else None, negative_prompt=negative_prompt if negative_prompt else None, duration=KlingVideoGenDuration(duration), @@ -406,24 +366,17 @@ async def execute_text2video( aspect_ratio=KlingVideoGenAspectRatio(aspect_ratio), camera_control=camera_control, ), - auth_kwargs=auth_kwargs, ) - task_creation_response = await initial_operation.execute() validate_task_creation_response(task_creation_response) task_id = task_creation_response.data.task_id - final_response = await poll_until_finished( - auth_kwargs, - ApiEndpoint( - path=f"{PATH_TEXT_TO_VIDEO}/{task_id}", - method=HttpMethod.GET, - request_model=EmptyRequest, - response_model=KlingText2VideoResponse, - ), - result_url_extractor=get_video_url_from_response, + final_response = await poll_op( + cls, + ApiEndpoint(path=f"{PATH_TEXT_TO_VIDEO}/{task_id}"), + response_model=KlingText2VideoResponse, estimated_duration=AVERAGE_DURATION_T2V, - node_id=node_id, + status_extractor=lambda r: (r.data.task_status.value if r.data and r.data.task_status else None), ) validate_video_result_response(final_response) @@ -432,8 +385,7 @@ async def execute_text2video( async def execute_image2video( - auth_kwargs: dict[str, str], - node_id: str, + cls: type[IO.ComfyNode], start_frame: torch.Tensor, prompt: str, negative_prompt: str, @@ -455,14 +407,11 @@ async def execute_image2video( if model_mode == "std" and model_name == KlingVideoGenModelName.kling_v2_5_turbo.value: model_mode = "pro" # October 5: currently "std" mode is not supported for this model - initial_operation = SynchronousOperation( - endpoint=ApiEndpoint( - path=PATH_IMAGE_TO_VIDEO, - method=HttpMethod.POST, - request_model=KlingImage2VideoRequest, - response_model=KlingImage2VideoResponse, - ), - request=KlingImage2VideoRequest( + task_creation_response = await sync_op( + cls, + ApiEndpoint(path=PATH_IMAGE_TO_VIDEO, method="POST"), + response_model=KlingImage2VideoResponse, + data=KlingImage2VideoRequest( model_name=KlingVideoGenModelName(model_name), image=tensor_to_base64_string(start_frame), image_tail=( @@ -477,24 +426,17 @@ async def execute_image2video( duration=KlingVideoGenDuration(duration), camera_control=camera_control, ), - auth_kwargs=auth_kwargs, ) - task_creation_response = await initial_operation.execute() validate_task_creation_response(task_creation_response) task_id = task_creation_response.data.task_id - final_response = await poll_until_finished( - auth_kwargs, - ApiEndpoint( - path=f"{PATH_IMAGE_TO_VIDEO}/{task_id}", - method=HttpMethod.GET, - request_model=KlingImage2VideoRequest, - response_model=KlingImage2VideoResponse, - ), - result_url_extractor=get_video_url_from_response, + final_response = await poll_op( + cls, + ApiEndpoint(path=f"{PATH_IMAGE_TO_VIDEO}/{task_id}"), + response_model=KlingImage2VideoResponse, estimated_duration=AVERAGE_DURATION_I2V, - node_id=node_id, + status_extractor=lambda r: (r.data.task_status.value if r.data and r.data.task_status else None), ) validate_video_result_response(final_response) @@ -503,8 +445,7 @@ async def execute_image2video( async def execute_video_effect( - auth_kwargs: dict[str, str], - node_id: str, + cls: type[IO.ComfyNode], dual_character: bool, effect_scene: KlingDualCharacterEffectsScene | KlingSingleImageEffectsScene, model_name: str, @@ -530,35 +471,25 @@ async def execute_video_effect( duration=duration, ) - initial_operation = SynchronousOperation( - endpoint=ApiEndpoint( - path=PATH_VIDEO_EFFECTS, - method=HttpMethod.POST, - request_model=KlingVideoEffectsRequest, - response_model=KlingVideoEffectsResponse, - ), - request=KlingVideoEffectsRequest( + task_creation_response = await sync_op( + cls, + endpoint=ApiEndpoint(path=PATH_VIDEO_EFFECTS, method="POST"), + response_model=KlingVideoEffectsResponse, + data=KlingVideoEffectsRequest( effect_scene=effect_scene, input=request_input_field, ), - auth_kwargs=auth_kwargs, ) - task_creation_response = await initial_operation.execute() validate_task_creation_response(task_creation_response) task_id = task_creation_response.data.task_id - final_response = await poll_until_finished( - auth_kwargs, - ApiEndpoint( - path=f"{PATH_VIDEO_EFFECTS}/{task_id}", - method=HttpMethod.GET, - request_model=EmptyRequest, - response_model=KlingVideoEffectsResponse, - ), - result_url_extractor=get_video_url_from_response, + final_response = await poll_op( + cls, + ApiEndpoint(path=f"{PATH_VIDEO_EFFECTS}/{task_id}"), + response_model=KlingVideoEffectsResponse, estimated_duration=AVERAGE_DURATION_VIDEO_EFFECTS, - node_id=node_id, + status_extractor=lambda r: (r.data.task_status.value if r.data and r.data.task_status else None), ) validate_video_result_response(final_response) @@ -567,8 +498,7 @@ async def execute_video_effect( async def execute_lipsync( - auth_kwargs: dict[str, str], - node_id: str, + cls: type[IO.ComfyNode], video: VideoInput, audio: Optional[AudioInput] = None, voice_language: Optional[str] = None, @@ -583,24 +513,21 @@ async def execute_lipsync( validate_video_duration(video, 2, 10) # Upload video to Comfy API and get download URL - video_url = await upload_video_to_comfyapi(video, auth_kwargs=auth_kwargs) + video_url = await upload_video_to_comfyapi(cls, video) logging.info("Uploaded video to Comfy API. URL: %s", video_url) # Upload the audio file to Comfy API and get download URL if audio: - audio_url = await upload_audio_to_comfyapi(audio, auth_kwargs=auth_kwargs) + audio_url = await upload_audio_to_comfyapi(cls, audio) logging.info("Uploaded audio to Comfy API. URL: %s", audio_url) else: audio_url = None - initial_operation = SynchronousOperation( - endpoint=ApiEndpoint( - path=PATH_LIP_SYNC, - method=HttpMethod.POST, - request_model=KlingLipSyncRequest, - response_model=KlingLipSyncResponse, - ), - request=KlingLipSyncRequest( + task_creation_response = await sync_op( + cls, + ApiEndpoint(PATH_LIP_SYNC, "POST"), + response_model=KlingLipSyncResponse, + data=KlingLipSyncRequest( input=KlingLipSyncInputObject( video_url=video_url, mode=model_mode, @@ -612,24 +539,17 @@ async def execute_lipsync( voice_id=voice_id, ), ), - auth_kwargs=auth_kwargs, ) - task_creation_response = await initial_operation.execute() validate_task_creation_response(task_creation_response) task_id = task_creation_response.data.task_id - final_response = await poll_until_finished( - auth_kwargs, - ApiEndpoint( - path=f"{PATH_LIP_SYNC}/{task_id}", - method=HttpMethod.GET, - request_model=EmptyRequest, - response_model=KlingLipSyncResponse, - ), - result_url_extractor=get_video_url_from_response, + final_response = await poll_op( + cls, + ApiEndpoint(path=f"{PATH_LIP_SYNC}/{task_id}"), + response_model=KlingLipSyncResponse, estimated_duration=AVERAGE_DURATION_LIP_SYNC, - node_id=node_id, + status_extractor=lambda r: (r.data.task_status.value if r.data and r.data.task_status else None), ) validate_video_result_response(final_response) @@ -807,11 +727,7 @@ class KlingTextToVideoNode(IO.ComfyNode): ) -> IO.NodeOutput: model_mode, duration, model_name = MODE_TEXT2VIDEO[mode] return await execute_text2video( - auth_kwargs={ - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - }, - node_id=cls.hidden.unique_id, + cls, prompt=prompt, negative_prompt=negative_prompt, cfg_scale=cfg_scale, @@ -872,11 +788,7 @@ class KlingCameraControlT2VNode(IO.ComfyNode): camera_control: Optional[KlingCameraControl] = None, ) -> IO.NodeOutput: return await execute_text2video( - auth_kwargs={ - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - }, - node_id=cls.hidden.unique_id, + cls, model_name=KlingVideoGenModelName.kling_v1, cfg_scale=cfg_scale, model_mode=KlingVideoGenMode.std, @@ -944,11 +856,7 @@ class KlingImage2VideoNode(IO.ComfyNode): end_frame: Optional[torch.Tensor] = None, ) -> IO.NodeOutput: return await execute_image2video( - auth_kwargs={ - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - }, - node_id=cls.hidden.unique_id, + cls, start_frame=start_frame, prompt=prompt, negative_prompt=negative_prompt, @@ -1017,11 +925,7 @@ class KlingCameraControlI2VNode(IO.ComfyNode): camera_control: KlingCameraControl, ) -> IO.NodeOutput: return await execute_image2video( - auth_kwargs={ - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - }, - node_id=cls.hidden.unique_id, + cls, model_name=KlingVideoGenModelName.kling_v1_5, start_frame=start_frame, cfg_scale=cfg_scale, @@ -1097,11 +1001,7 @@ class KlingStartEndFrameNode(IO.ComfyNode): ) -> IO.NodeOutput: mode, duration, model_name = MODE_START_END_FRAME[mode] return await execute_image2video( - auth_kwargs={ - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - }, - node_id=cls.hidden.unique_id, + cls, prompt=prompt, negative_prompt=negative_prompt, model_name=model_name, @@ -1162,41 +1062,27 @@ class KlingVideoExtendNode(IO.ComfyNode): video_id: str, ) -> IO.NodeOutput: validate_prompts(prompt, negative_prompt, MAX_PROMPT_LENGTH_T2V) - auth = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } - initial_operation = SynchronousOperation( - endpoint=ApiEndpoint( - path=PATH_VIDEO_EXTEND, - method=HttpMethod.POST, - request_model=KlingVideoExtendRequest, - response_model=KlingVideoExtendResponse, - ), - request=KlingVideoExtendRequest( + task_creation_response = await sync_op( + cls, + ApiEndpoint(path=PATH_VIDEO_EXTEND, method="POST"), + response_model=KlingVideoExtendResponse, + data=KlingVideoExtendRequest( prompt=prompt if prompt else None, negative_prompt=negative_prompt if negative_prompt else None, cfg_scale=cfg_scale, video_id=video_id, ), - auth_kwargs=auth, ) - task_creation_response = await initial_operation.execute() validate_task_creation_response(task_creation_response) task_id = task_creation_response.data.task_id - final_response = await poll_until_finished( - auth, - ApiEndpoint( - path=f"{PATH_VIDEO_EXTEND}/{task_id}", - method=HttpMethod.GET, - request_model=EmptyRequest, - response_model=KlingVideoExtendResponse, - ), - result_url_extractor=get_video_url_from_response, + final_response = await poll_op( + cls, + ApiEndpoint(path=f"{PATH_VIDEO_EXTEND}/{task_id}"), + response_model=KlingVideoExtendResponse, estimated_duration=AVERAGE_DURATION_VIDEO_EXTEND, - node_id=cls.hidden.unique_id, + status_extractor=lambda r: (r.data.task_status.value if r.data and r.data.task_status else None), ) validate_video_result_response(final_response) @@ -1259,11 +1145,7 @@ class KlingDualCharacterVideoEffectNode(IO.ComfyNode): duration: KlingVideoGenDuration, ) -> IO.NodeOutput: video, _, duration = await execute_video_effect( - auth_kwargs={ - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - }, - node_id=cls.hidden.unique_id, + cls, dual_character=True, effect_scene=effect_scene, model_name=model_name, @@ -1324,11 +1206,7 @@ class KlingSingleImageVideoEffectNode(IO.ComfyNode): return IO.NodeOutput( *( await execute_video_effect( - auth_kwargs={ - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - }, - node_id=cls.hidden.unique_id, + cls, dual_character=False, effect_scene=effect_scene, model_name=model_name, @@ -1379,11 +1257,7 @@ class KlingLipSyncAudioToVideoNode(IO.ComfyNode): voice_language: str, ) -> IO.NodeOutput: return await execute_lipsync( - auth_kwargs={ - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - }, - node_id=cls.hidden.unique_id, + cls, video=video, audio=audio, voice_language=voice_language, @@ -1445,11 +1319,7 @@ class KlingLipSyncTextToVideoNode(IO.ComfyNode): ) -> IO.NodeOutput: voice_id, voice_language = VOICES_CONFIG[voice] return await execute_lipsync( - auth_kwargs={ - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - }, - node_id=cls.hidden.unique_id, + cls, video=video, text=text, voice_language=voice_language, @@ -1496,40 +1366,26 @@ class KlingVirtualTryOnNode(IO.ComfyNode): cloth_image: torch.Tensor, model_name: KlingVirtualTryOnModelName, ) -> IO.NodeOutput: - auth = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } - initial_operation = SynchronousOperation( - endpoint=ApiEndpoint( - path=PATH_VIRTUAL_TRY_ON, - method=HttpMethod.POST, - request_model=KlingVirtualTryOnRequest, - response_model=KlingVirtualTryOnResponse, - ), - request=KlingVirtualTryOnRequest( + task_creation_response = await sync_op( + cls, + ApiEndpoint(path=PATH_VIRTUAL_TRY_ON, method="POST"), + response_model=KlingVirtualTryOnResponse, + data=KlingVirtualTryOnRequest( human_image=tensor_to_base64_string(human_image), cloth_image=tensor_to_base64_string(cloth_image), model_name=model_name, ), - auth_kwargs=auth, ) - task_creation_response = await initial_operation.execute() validate_task_creation_response(task_creation_response) task_id = task_creation_response.data.task_id - final_response = await poll_until_finished( - auth, - ApiEndpoint( - path=f"{PATH_VIRTUAL_TRY_ON}/{task_id}", - method=HttpMethod.GET, - request_model=EmptyRequest, - response_model=KlingVirtualTryOnResponse, - ), - result_url_extractor=get_images_urls_from_response, + final_response = await poll_op( + cls, + ApiEndpoint(path=f"{PATH_VIRTUAL_TRY_ON}/{task_id}"), + response_model=KlingVirtualTryOnResponse, estimated_duration=AVERAGE_DURATION_VIRTUAL_TRY_ON, - node_id=cls.hidden.unique_id, + status_extractor=lambda r: (r.data.task_status.value if r.data and r.data.task_status else None), ) validate_image_result_response(final_response) @@ -1625,18 +1481,11 @@ class KlingImageGenerationNode(IO.ComfyNode): else: image = tensor_to_base64_string(image) - auth = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } - initial_operation = SynchronousOperation( - endpoint=ApiEndpoint( - path=PATH_IMAGE_GENERATIONS, - method=HttpMethod.POST, - request_model=KlingImageGenerationsRequest, - response_model=KlingImageGenerationsResponse, - ), - request=KlingImageGenerationsRequest( + task_creation_response = await sync_op( + cls, + ApiEndpoint(path=PATH_IMAGE_GENERATIONS, method="POST"), + response_model=KlingImageGenerationsResponse, + data=KlingImageGenerationsRequest( model_name=model_name, prompt=prompt, negative_prompt=negative_prompt, @@ -1647,24 +1496,17 @@ class KlingImageGenerationNode(IO.ComfyNode): n=n, aspect_ratio=aspect_ratio, ), - auth_kwargs=auth, ) - task_creation_response = await initial_operation.execute() validate_task_creation_response(task_creation_response) task_id = task_creation_response.data.task_id - final_response = await poll_until_finished( - auth, - ApiEndpoint( - path=f"{PATH_IMAGE_GENERATIONS}/{task_id}", - method=HttpMethod.GET, - request_model=EmptyRequest, - response_model=KlingImageGenerationsResponse, - ), - result_url_extractor=get_images_urls_from_response, + final_response = await poll_op( + cls, + ApiEndpoint(path=f"{PATH_IMAGE_GENERATIONS}/{task_id}"), + response_model=KlingImageGenerationsResponse, estimated_duration=AVERAGE_DURATION_IMAGE_GEN, - node_id=cls.hidden.unique_id, + status_extractor=lambda r: (r.data.task_status.value if r.data and r.data.task_status else None), ) validate_image_result_response(final_response) diff --git a/comfy_api_nodes/nodes_luma.py b/comfy_api_nodes/nodes_luma.py index 610d95a77..e74441e5e 100644 --- a/comfy_api_nodes/nodes_luma.py +++ b/comfy_api_nodes/nodes_luma.py @@ -35,9 +35,9 @@ from comfy_api_nodes.apis.client import ( from comfy_api_nodes.apinode_utils import ( upload_images_to_comfyapi, process_image_response, - validate_string, ) from server import PromptServer +from comfy_api_nodes.util import validate_string import aiohttp import torch diff --git a/comfy_api_nodes/nodes_minimax.py b/comfy_api_nodes/nodes_minimax.py index 23be1ae65..e3722e79b 100644 --- a/comfy_api_nodes/nodes_minimax.py +++ b/comfy_api_nodes/nodes_minimax.py @@ -24,8 +24,8 @@ from comfy_api_nodes.apis.client import ( from comfy_api_nodes.apinode_utils import ( download_url_to_bytesio, upload_images_to_comfyapi, - validate_string, ) +from comfy_api_nodes.util import validate_string from server import PromptServer diff --git a/comfy_api_nodes/nodes_moonvalley.py b/comfy_api_nodes/nodes_moonvalley.py index 7566188dd..7c31d95b3 100644 --- a/comfy_api_nodes/nodes_moonvalley.py +++ b/comfy_api_nodes/nodes_moonvalley.py @@ -1,35 +1,31 @@ import logging -from typing import Any, Callable, Optional, TypeVar +from typing import Optional + import torch from typing_extensions import override -from comfy_api_nodes.util.validation_utils import validate_image_dimensions +from comfy_api.input import VideoInput +from comfy_api.latest import IO, ComfyExtension from comfy_api_nodes.apis import ( - MoonvalleyTextToVideoRequest, + MoonvalleyPromptResponse, MoonvalleyTextToVideoInferenceParams, + MoonvalleyTextToVideoRequest, MoonvalleyVideoToVideoInferenceParams, MoonvalleyVideoToVideoRequest, - MoonvalleyPromptResponse, ) -from comfy_api_nodes.apis.client import ( +from comfy_api_nodes.util import ( ApiEndpoint, - HttpMethod, - SynchronousOperation, - PollingOperation, - EmptyRequest, -) -from comfy_api_nodes.apinode_utils import ( download_url_to_video_output, + poll_op, + sync_op, + trim_video, upload_images_to_comfyapi, upload_video_to_comfyapi, validate_container_format_is_mp4, + validate_image_dimensions, + validate_string, ) -from comfy_api.input import VideoInput -from comfy_api.latest import ComfyExtension, InputImpl, IO -import av -import io - API_UPLOADS_ENDPOINT = "/proxy/moonvalley/uploads" API_PROMPTS_ENDPOINT = "/proxy/moonvalley/prompts" API_VIDEO2VIDEO_ENDPOINT = "/proxy/moonvalley/prompts/video-to-video" @@ -51,13 +47,6 @@ MAX_VID_HEIGHT = 10000 MAX_VIDEO_SIZE = 1024 * 1024 * 1024 # 1 GB max for in-memory video processing MOONVALLEY_MAREY_MAX_PROMPT_LENGTH = 5000 -R = TypeVar("R") - - -class MoonvalleyApiError(Exception): - """Base exception for Moonvalley API errors.""" - - pass def is_valid_task_creation_response(response: MoonvalleyPromptResponse) -> bool: @@ -69,64 +58,7 @@ def validate_task_creation_response(response) -> None: if not is_valid_task_creation_response(response): error_msg = f"Moonvalley Marey API: Initial request failed. Code: {response.code}, Message: {response.message}, Data: {response}" logging.error(error_msg) - raise MoonvalleyApiError(error_msg) - - -def get_video_from_response(response): - video = response.output_url - logging.info( - "Moonvalley Marey API: Task %s succeeded. Video URL: %s", response.id, video - ) - return video - - -def get_video_url_from_response(response) -> Optional[str]: - """Returns the first video url from the Moonvalley video generation task result. - Will not raise an error if the response is not valid. - """ - if response: - return str(get_video_from_response(response)) - else: - return None - - -async def poll_until_finished( - auth_kwargs: dict[str, str], - api_endpoint: ApiEndpoint[Any, R], - result_url_extractor: Optional[Callable[[R], str]] = None, - node_id: Optional[str] = None, -) -> R: - """Polls the Moonvalley API endpoint until the task reaches a terminal state, then returns the response.""" - return await PollingOperation( - poll_endpoint=api_endpoint, - completed_statuses=[ - "completed", - ], - max_poll_attempts=240, # 64 minutes with 16s interval - poll_interval=16.0, - failed_statuses=["error"], - status_extractor=lambda response: ( - response.status if response and response.status else None - ), - auth_kwargs=auth_kwargs, - result_url_extractor=result_url_extractor, - node_id=node_id, - ).execute() - - -def validate_prompts( - prompt: str, negative_prompt: str, max_length=MOONVALLEY_MAREY_MAX_PROMPT_LENGTH -): - """Verifies that the prompt isn't empty and that neither prompt is too long.""" - if not prompt: - raise ValueError("Positive prompt is empty") - if len(prompt) > max_length: - raise ValueError(f"Positive prompt is too long: {len(prompt)} characters") - if negative_prompt and len(negative_prompt) > max_length: - raise ValueError( - f"Negative prompt is too long: {len(negative_prompt)} characters" - ) - return True + raise RuntimeError(error_msg) def validate_video_to_video_input(video: VideoInput) -> VideoInput: @@ -170,12 +102,8 @@ def _validate_video_dimensions(width: int, height: int) -> None: } if (width, height) not in supported_resolutions: - supported_list = ", ".join( - [f"{w}x{h}" for w, h in sorted(supported_resolutions)] - ) - raise ValueError( - f"Resolution {width}x{height} not supported. Supported: {supported_list}" - ) + supported_list = ", ".join([f"{w}x{h}" for w, h in sorted(supported_resolutions)]) + raise ValueError(f"Resolution {width}x{height} not supported. Supported: {supported_list}") def _validate_and_trim_duration(video: VideoInput) -> VideoInput: @@ -188,7 +116,7 @@ def _validate_and_trim_duration(video: VideoInput) -> VideoInput: def _validate_minimum_duration(duration: float) -> None: """Ensures video is at least 5 seconds long.""" if duration < 5: - raise MoonvalleyApiError("Input video must be at least 5 seconds long.") + raise ValueError("Input video must be at least 5 seconds long.") def _trim_if_too_long(video: VideoInput, duration: float) -> VideoInput: @@ -198,123 +126,6 @@ def _trim_if_too_long(video: VideoInput, duration: float) -> VideoInput: return video -def trim_video(video: VideoInput, duration_sec: float) -> VideoInput: - """ - Returns a new VideoInput object trimmed from the beginning to the specified duration, - using av to avoid loading entire video into memory. - - Args: - video: Input video to trim - duration_sec: Duration in seconds to keep from the beginning - - Returns: - VideoFromFile object that owns the output buffer - """ - output_buffer = io.BytesIO() - - input_container = None - output_container = None - - try: - # Get the stream source - this avoids loading entire video into memory - # when the source is already a file path - input_source = video.get_stream_source() - - # Open containers - input_container = av.open(input_source, mode="r") - output_container = av.open(output_buffer, mode="w", format="mp4") - - # Set up output streams for re-encoding - video_stream = None - audio_stream = None - - for stream in input_container.streams: - logging.info("Found stream: type=%s, class=%s", stream.type, type(stream)) - if isinstance(stream, av.VideoStream): - # Create output video stream with same parameters - video_stream = output_container.add_stream( - "h264", rate=stream.average_rate - ) - video_stream.width = stream.width - video_stream.height = stream.height - video_stream.pix_fmt = "yuv420p" - logging.info( - "Added video stream: %sx%s @ %sfps", stream.width, stream.height, stream.average_rate - ) - elif isinstance(stream, av.AudioStream): - # Create output audio stream with same parameters - audio_stream = output_container.add_stream( - "aac", rate=stream.sample_rate - ) - audio_stream.sample_rate = stream.sample_rate - audio_stream.layout = stream.layout - logging.info("Added audio stream: %sHz, %s channels", stream.sample_rate, stream.channels) - - # Calculate target frame count that's divisible by 16 - fps = input_container.streams.video[0].average_rate - estimated_frames = int(duration_sec * fps) - target_frames = ( - estimated_frames // 16 - ) * 16 # Round down to nearest multiple of 16 - - if target_frames == 0: - raise ValueError("Video too short: need at least 16 frames for Moonvalley") - - frame_count = 0 - audio_frame_count = 0 - - # Decode and re-encode video frames - if video_stream: - for frame in input_container.decode(video=0): - if frame_count >= target_frames: - break - - # Re-encode frame - for packet in video_stream.encode(frame): - output_container.mux(packet) - frame_count += 1 - - # Flush encoder - for packet in video_stream.encode(): - output_container.mux(packet) - - logging.info("Encoded %s video frames (target: %s)", frame_count, target_frames) - - # Decode and re-encode audio frames - if audio_stream: - input_container.seek(0) # Reset to beginning for audio - for frame in input_container.decode(audio=0): - if frame.time >= duration_sec: - break - - # Re-encode frame - for packet in audio_stream.encode(frame): - output_container.mux(packet) - audio_frame_count += 1 - - # Flush encoder - for packet in audio_stream.encode(): - output_container.mux(packet) - - logging.info("Encoded %s audio frames", audio_frame_count) - - # Close containers - output_container.close() - input_container.close() - - # Return as VideoFromFile using the buffer - output_buffer.seek(0) - return InputImpl.VideoFromFile(output_buffer) - - except Exception as e: - # Clean up on error - if input_container is not None: - input_container.close() - if output_container is not None: - output_container.close() - raise RuntimeError(f"Failed to trim video: {str(e)}") from e - - def parse_width_height_from_res(resolution: str): # Accepts a string like "16:9 (1920 x 1080)" and returns width, height as a dict res_map = { @@ -338,19 +149,14 @@ def parse_control_parameter(value): return control_map.get(value, control_map["Motion Transfer"]) -async def get_response( - task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None -) -> MoonvalleyPromptResponse: - return await poll_until_finished( - auth_kwargs, - ApiEndpoint( - path=f"{API_PROMPTS_ENDPOINT}/{task_id}", - method=HttpMethod.GET, - request_model=EmptyRequest, - response_model=MoonvalleyPromptResponse, - ), - result_url_extractor=get_video_url_from_response, - node_id=node_id, +async def get_response(cls: type[IO.ComfyNode], task_id: str) -> MoonvalleyPromptResponse: + return await poll_op( + cls, + ApiEndpoint(path=f"{API_PROMPTS_ENDPOINT}/{task_id}"), + response_model=MoonvalleyPromptResponse, + status_extractor=lambda r: (r.status if r and r.status else None), + poll_interval=16.0, + max_poll_attempts=240, ) @@ -444,14 +250,10 @@ class MoonvalleyImg2VideoNode(IO.ComfyNode): steps: int, ) -> IO.NodeOutput: validate_image_dimensions(image, min_width=300, min_height=300, max_height=MAX_HEIGHT, max_width=MAX_WIDTH) - validate_prompts(prompt, negative_prompt, MOONVALLEY_MAREY_MAX_PROMPT_LENGTH) + validate_string(prompt, min_length=1, max_length=MOONVALLEY_MAREY_MAX_PROMPT_LENGTH) + validate_string(negative_prompt, field_name="negative_prompt", max_length=MOONVALLEY_MAREY_MAX_PROMPT_LENGTH) width_height = parse_width_height_from_res(resolution) - auth = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } - inference_params = MoonvalleyTextToVideoInferenceParams( negative_prompt=negative_prompt, steps=steps, @@ -464,33 +266,17 @@ class MoonvalleyImg2VideoNode(IO.ComfyNode): # Get MIME type from tensor - assuming PNG format for image tensors mime_type = "image/png" - - image_url = ( - await upload_images_to_comfyapi( - image, max_images=1, auth_kwargs=auth, mime_type=mime_type - ) - )[0] - - request = MoonvalleyTextToVideoRequest( - image_url=image_url, prompt_text=prompt, inference_params=inference_params - ) - initial_operation = SynchronousOperation( - endpoint=ApiEndpoint( - path=API_IMG2VIDEO_ENDPOINT, - method=HttpMethod.POST, - request_model=MoonvalleyTextToVideoRequest, - response_model=MoonvalleyPromptResponse, + image_url = (await upload_images_to_comfyapi(cls, image, max_images=1, mime_type=mime_type))[0] + task_creation_response = await sync_op( + cls, + endpoint=ApiEndpoint(path=API_IMG2VIDEO_ENDPOINT, method="POST"), + response_model=MoonvalleyPromptResponse, + data=MoonvalleyTextToVideoRequest( + image_url=image_url, prompt_text=prompt, inference_params=inference_params ), - request=request, - auth_kwargs=auth, ) - task_creation_response = await initial_operation.execute() validate_task_creation_response(task_creation_response) - task_id = task_creation_response.id - - final_response = await get_response( - task_id, auth_kwargs=auth, node_id=cls.hidden.unique_id - ) + final_response = await get_response(cls, task_creation_response.id) video = await download_url_to_video_output(final_response.output_url) return IO.NodeOutput(video) @@ -582,15 +368,10 @@ class MoonvalleyVideo2VideoNode(IO.ComfyNode): steps=33, prompt_adherence=4.5, ) -> IO.NodeOutput: - auth = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } - validated_video = validate_video_to_video_input(video) - video_url = await upload_video_to_comfyapi(validated_video, auth_kwargs=auth) - - validate_prompts(prompt, negative_prompt) + video_url = await upload_video_to_comfyapi(cls, validated_video) + validate_string(prompt, min_length=1, max_length=MOONVALLEY_MAREY_MAX_PROMPT_LENGTH) + validate_string(negative_prompt, field_name="negative_prompt", max_length=MOONVALLEY_MAREY_MAX_PROMPT_LENGTH) # Only include motion_intensity for Motion Transfer control_params = {} @@ -605,35 +386,20 @@ class MoonvalleyVideo2VideoNode(IO.ComfyNode): guidance_scale=prompt_adherence, ) - control = parse_control_parameter(control_type) - - request = MoonvalleyVideoToVideoRequest( - control_type=control, - video_url=video_url, - prompt_text=prompt, - inference_params=inference_params, - ) - - initial_operation = SynchronousOperation( - endpoint=ApiEndpoint( - path=API_VIDEO2VIDEO_ENDPOINT, - method=HttpMethod.POST, - request_model=MoonvalleyVideoToVideoRequest, - response_model=MoonvalleyPromptResponse, + task_creation_response = await sync_op( + cls, + endpoint=ApiEndpoint(path=API_VIDEO2VIDEO_ENDPOINT, method="POST"), + response_model=MoonvalleyPromptResponse, + data=MoonvalleyVideoToVideoRequest( + control_type=parse_control_parameter(control_type), + video_url=video_url, + prompt_text=prompt, + inference_params=inference_params, ), - request=request, - auth_kwargs=auth, ) - task_creation_response = await initial_operation.execute() validate_task_creation_response(task_creation_response) - task_id = task_creation_response.id - - final_response = await get_response( - task_id, auth_kwargs=auth, node_id=cls.hidden.unique_id - ) - - video = await download_url_to_video_output(final_response.output_url) - return IO.NodeOutput(video) + final_response = await get_response(cls, task_creation_response.id) + return IO.NodeOutput(await download_url_to_video_output(final_response.output_url)) class MoonvalleyTxt2VideoNode(IO.ComfyNode): @@ -720,14 +486,10 @@ class MoonvalleyTxt2VideoNode(IO.ComfyNode): seed: int, steps: int, ) -> IO.NodeOutput: - validate_prompts(prompt, negative_prompt, MOONVALLEY_MAREY_MAX_PROMPT_LENGTH) + validate_string(prompt, min_length=1, max_length=MOONVALLEY_MAREY_MAX_PROMPT_LENGTH) + validate_string(negative_prompt, field_name="negative_prompt", max_length=MOONVALLEY_MAREY_MAX_PROMPT_LENGTH) width_height = parse_width_height_from_res(resolution) - auth = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } - inference_params = MoonvalleyTextToVideoInferenceParams( negative_prompt=negative_prompt, steps=steps, @@ -737,30 +499,16 @@ class MoonvalleyTxt2VideoNode(IO.ComfyNode): width=width_height["width"], height=width_height["height"], ) - request = MoonvalleyTextToVideoRequest( - prompt_text=prompt, inference_params=inference_params - ) - init_op = SynchronousOperation( - endpoint=ApiEndpoint( - path=API_TXT2VIDEO_ENDPOINT, - method=HttpMethod.POST, - request_model=MoonvalleyTextToVideoRequest, - response_model=MoonvalleyPromptResponse, - ), - request=request, - auth_kwargs=auth, + task_creation_response = await sync_op( + cls, + endpoint=ApiEndpoint(path=API_TXT2VIDEO_ENDPOINT, method="POST"), + response_model=MoonvalleyPromptResponse, + data=MoonvalleyTextToVideoRequest(prompt_text=prompt, inference_params=inference_params), ) - task_creation_response = await init_op.execute() validate_task_creation_response(task_creation_response) - task_id = task_creation_response.id - - final_response = await get_response( - task_id, auth_kwargs=auth, node_id=cls.hidden.unique_id - ) - - video = await download_url_to_video_output(final_response.output_url) - return IO.NodeOutput(video) + final_response = await get_response(cls, task_creation_response.id) + return IO.NodeOutput(await download_url_to_video_output(final_response.output_url)) class MoonvalleyExtension(ComfyExtension): diff --git a/comfy_api_nodes/nodes_openai.py b/comfy_api_nodes/nodes_openai.py index e3b81de75..c467e840c 100644 --- a/comfy_api_nodes/nodes_openai.py +++ b/comfy_api_nodes/nodes_openai.py @@ -43,13 +43,11 @@ from comfy_api_nodes.apis.client import ( ) from comfy_api_nodes.apinode_utils import ( - downscale_image_tensor, validate_and_cast_response, - validate_string, - tensor_to_base64_string, text_filepath_to_data_uri, ) from comfy_api_nodes.mapper_utils import model_field_to_node_input +from comfy_api_nodes.util import downscale_image_tensor, validate_string, tensor_to_base64_string RESPONSES_ENDPOINT = "/proxy/openai/v1/responses" diff --git a/comfy_api_nodes/nodes_pika.py b/comfy_api_nodes/nodes_pika.py index 27cb0067b..5bb406a3b 100644 --- a/comfy_api_nodes/nodes_pika.py +++ b/comfy_api_nodes/nodes_pika.py @@ -14,11 +14,6 @@ import torch from typing_extensions import override from comfy_api.latest import ComfyExtension, IO from comfy_api.input_impl.video_types import VideoCodec, VideoContainer, VideoInput -from comfy_api_nodes.apinode_utils import ( - download_url_to_video_output, - tensor_to_bytesio, - validate_string, -) from comfy_api_nodes.apis import pika_defs from comfy_api_nodes.apis.client import ( ApiEndpoint, @@ -27,6 +22,7 @@ from comfy_api_nodes.apis.client import ( PollingOperation, SynchronousOperation, ) +from comfy_api_nodes.util import validate_string, download_url_to_video_output, tensor_to_bytesio R = TypeVar("R") diff --git a/comfy_api_nodes/nodes_pixverse.py b/comfy_api_nodes/nodes_pixverse.py index 438a7f80b..b2b841be8 100644 --- a/comfy_api_nodes/nodes_pixverse.py +++ b/comfy_api_nodes/nodes_pixverse.py @@ -24,10 +24,7 @@ from comfy_api_nodes.apis.client import ( PollingOperation, EmptyRequest, ) -from comfy_api_nodes.apinode_utils import ( - tensor_to_bytesio, - validate_string, -) +from comfy_api_nodes.util import validate_string, tensor_to_bytesio from comfy_api.input_impl import VideoFromFile from comfy_api.latest import ComfyExtension, IO @@ -50,7 +47,6 @@ def get_video_url_from_response( async def upload_image_to_pixverse(image: torch.Tensor, auth_kwargs=None): # first, upload image to Pixverse and get image id to use in actual generation call - files = {"image": tensor_to_bytesio(image)} operation = SynchronousOperation( endpoint=ApiEndpoint( path="/proxy/pixverse/image/upload", @@ -59,16 +55,14 @@ async def upload_image_to_pixverse(image: torch.Tensor, auth_kwargs=None): response_model=PixverseImageUploadResponse, ), request=EmptyRequest(), - files=files, + files={"image": tensor_to_bytesio(image)}, content_type="multipart/form-data", auth_kwargs=auth_kwargs, ) response_upload: PixverseImageUploadResponse = await operation.execute() if response_upload.Resp is None: - raise Exception( - f"PixVerse image upload request failed: '{response_upload.ErrMsg}'" - ) + raise Exception(f"PixVerse image upload request failed: '{response_upload.ErrMsg}'") return response_upload.Resp.img_id @@ -95,7 +89,6 @@ class PixverseTemplateNode(IO.ComfyNode): template_id = pixverse_templates.get(template, None) if template_id is None: raise Exception(f"Template '{template}' is not recognized.") - # just return the integer return IO.NodeOutput(template_id) diff --git a/comfy_api_nodes/nodes_recraft.py b/comfy_api_nodes/nodes_recraft.py index 8beed5675..8ee7e55c4 100644 --- a/comfy_api_nodes/nodes_recraft.py +++ b/comfy_api_nodes/nodes_recraft.py @@ -24,12 +24,10 @@ from comfy_api_nodes.apis.client import ( EmptyRequest, ) from comfy_api_nodes.apinode_utils import ( - bytesio_to_image_tensor, download_url_to_bytesio, - tensor_to_bytesio, resize_mask_to_image, - validate_string, ) +from comfy_api_nodes.util import validate_string, tensor_to_bytesio, bytesio_to_image_tensor from server import PromptServer import torch diff --git a/comfy_api_nodes/nodes_runway.py b/comfy_api_nodes/nodes_runway.py index eb03a897d..0543d1d0e 100644 --- a/comfy_api_nodes/nodes_runway.py +++ b/comfy_api_nodes/nodes_runway.py @@ -11,7 +11,7 @@ User Guides: """ -from typing import Union, Optional, Any +from typing import Union, Optional from typing_extensions import override from enum import Enum @@ -21,7 +21,6 @@ from comfy_api_nodes.apis import ( RunwayImageToVideoRequest, RunwayImageToVideoResponse, RunwayTaskStatusResponse as TaskStatusResponse, - RunwayTaskStatusEnum as TaskStatus, RunwayModelEnum as Model, RunwayDurationEnum as Duration, RunwayAspectRatioEnum as AspectRatio, @@ -33,23 +32,20 @@ from comfy_api_nodes.apis import ( ReferenceImage, RunwayTextToImageAspectRatioEnum, ) -from comfy_api_nodes.apis.client import ( - ApiEndpoint, - HttpMethod, - SynchronousOperation, - PollingOperation, - EmptyRequest, -) -from comfy_api_nodes.apinode_utils import ( - upload_images_to_comfyapi, - download_url_to_video_output, +from comfy_api_nodes.util import ( image_tensor_pair_to_batch, validate_string, + validate_image_dimensions, + validate_image_aspect_ratio, + upload_images_to_comfyapi, + download_url_to_video_output, download_url_to_image_tensor, + ApiEndpoint, + sync_op, + poll_op, ) from comfy_api.input_impl import VideoFromFile from comfy_api.latest import ComfyExtension, IO -from comfy_api_nodes.util.validation_utils import validate_image_dimensions, validate_image_aspect_ratio PATH_IMAGE_TO_VIDEO = "/proxy/runway/image_to_video" PATH_TEXT_TO_IMAGE = "/proxy/runway/text_to_image" @@ -91,31 +87,6 @@ def get_video_url_from_task_status(response: TaskStatusResponse) -> Union[str, N return None -async def poll_until_finished( - auth_kwargs: dict[str, str], - api_endpoint: ApiEndpoint[Any, TaskStatusResponse], - estimated_duration: Optional[int] = None, - node_id: Optional[str] = None, -) -> TaskStatusResponse: - """Polls the Runway API endpoint until the task reaches a terminal state, then returns the response.""" - return await PollingOperation( - poll_endpoint=api_endpoint, - completed_statuses=[ - TaskStatus.SUCCEEDED.value, - ], - failed_statuses=[ - TaskStatus.FAILED.value, - TaskStatus.CANCELLED.value, - ], - status_extractor=lambda response: response.status.value, - auth_kwargs=auth_kwargs, - result_url_extractor=get_video_url_from_task_status, - estimated_duration=estimated_duration, - node_id=node_id, - progress_extractor=extract_progress_from_task_status, - ).execute() - - def extract_progress_from_task_status( response: TaskStatusResponse, ) -> Union[float, None]: @@ -132,42 +103,32 @@ def get_image_url_from_task_status(response: TaskStatusResponse) -> Union[str, N async def get_response( - task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None, estimated_duration: Optional[int] = None + cls: type[IO.ComfyNode], task_id: str, estimated_duration: Optional[int] = None ) -> TaskStatusResponse: """Poll the task status until it is finished then get the response.""" - return await poll_until_finished( - auth_kwargs, - ApiEndpoint( - path=f"{PATH_GET_TASK_STATUS}/{task_id}", - method=HttpMethod.GET, - request_model=EmptyRequest, - response_model=TaskStatusResponse, - ), + return await poll_op( + cls, + ApiEndpoint(path=f"{PATH_GET_TASK_STATUS}/{task_id}"), + response_model=TaskStatusResponse, + status_extractor=lambda r: r.status.value, estimated_duration=estimated_duration, - node_id=node_id, + progress_extractor=extract_progress_from_task_status, ) async def generate_video( + cls: type[IO.ComfyNode], request: RunwayImageToVideoRequest, - auth_kwargs: dict[str, str], - node_id: Optional[str] = None, estimated_duration: Optional[int] = None, ) -> VideoFromFile: - initial_operation = SynchronousOperation( - endpoint=ApiEndpoint( - path=PATH_IMAGE_TO_VIDEO, - method=HttpMethod.POST, - request_model=RunwayImageToVideoRequest, - response_model=RunwayImageToVideoResponse, - ), - request=request, - auth_kwargs=auth_kwargs, + initial_response = await sync_op( + cls, + endpoint=ApiEndpoint(path=PATH_IMAGE_TO_VIDEO, method="POST"), + response_model=RunwayImageToVideoResponse, + data=request, ) - initial_response = await initial_operation.execute() - - final_response = await get_response(initial_response.id, auth_kwargs, node_id, estimated_duration) + final_response = await get_response(cls, initial_response.id, estimated_duration) if not final_response.output: raise RunwayApiError("Runway task succeeded but no video data found in response.") @@ -184,9 +145,9 @@ class RunwayImageToVideoNodeGen3a(IO.ComfyNode): display_name="Runway Image to Video (Gen3a Turbo)", category="api node/video/Runway", description="Generate a video from a single starting frame using Gen3a Turbo model. " - "Before diving in, review these best practices to ensure that " - "your input selections will set your generation up for success: " - "https://help.runwayml.com/hc/en-us/articles/33927968552339-Creating-with-Act-One-on-Gen-3-Alpha-and-Turbo.", + "Before diving in, review these best practices to ensure that " + "your input selections will set your generation up for success: " + "https://help.runwayml.com/hc/en-us/articles/33927968552339-Creating-with-Act-One-on-Gen-3-Alpha-and-Turbo.", inputs=[ IO.String.Input( "prompt", @@ -241,20 +202,16 @@ class RunwayImageToVideoNodeGen3a(IO.ComfyNode): validate_image_dimensions(start_frame, max_width=7999, max_height=7999) validate_image_aspect_ratio(start_frame, min_aspect_ratio=0.5, max_aspect_ratio=2.0) - auth_kwargs = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } - download_urls = await upload_images_to_comfyapi( + cls, start_frame, max_images=1, mime_type="image/png", - auth_kwargs=auth_kwargs, ) return IO.NodeOutput( await generate_video( + cls, RunwayImageToVideoRequest( promptText=prompt, seed=seed, @@ -262,15 +219,9 @@ class RunwayImageToVideoNodeGen3a(IO.ComfyNode): duration=Duration(duration), ratio=AspectRatio(ratio), promptImage=RunwayPromptImageObject( - root=[ - RunwayPromptImageDetailedObject( - uri=str(download_urls[0]), position="first" - ) - ] + root=[RunwayPromptImageDetailedObject(uri=str(download_urls[0]), position="first")] ), ), - auth_kwargs=auth_kwargs, - node_id=cls.hidden.unique_id, ) ) @@ -284,9 +235,9 @@ class RunwayImageToVideoNodeGen4(IO.ComfyNode): display_name="Runway Image to Video (Gen4 Turbo)", category="api node/video/Runway", description="Generate a video from a single starting frame using Gen4 Turbo model. " - "Before diving in, review these best practices to ensure that " - "your input selections will set your generation up for success: " - "https://help.runwayml.com/hc/en-us/articles/37327109429011-Creating-with-Gen-4-Video.", + "Before diving in, review these best practices to ensure that " + "your input selections will set your generation up for success: " + "https://help.runwayml.com/hc/en-us/articles/37327109429011-Creating-with-Gen-4-Video.", inputs=[ IO.String.Input( "prompt", @@ -341,20 +292,16 @@ class RunwayImageToVideoNodeGen4(IO.ComfyNode): validate_image_dimensions(start_frame, max_width=7999, max_height=7999) validate_image_aspect_ratio(start_frame, min_aspect_ratio=0.5, max_aspect_ratio=2.0) - auth_kwargs = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } - download_urls = await upload_images_to_comfyapi( + cls, start_frame, max_images=1, mime_type="image/png", - auth_kwargs=auth_kwargs, ) return IO.NodeOutput( await generate_video( + cls, RunwayImageToVideoRequest( promptText=prompt, seed=seed, @@ -362,15 +309,9 @@ class RunwayImageToVideoNodeGen4(IO.ComfyNode): duration=Duration(duration), ratio=AspectRatio(ratio), promptImage=RunwayPromptImageObject( - root=[ - RunwayPromptImageDetailedObject( - uri=str(download_urls[0]), position="first" - ) - ] + root=[RunwayPromptImageDetailedObject(uri=str(download_urls[0]), position="first")] ), ), - auth_kwargs=auth_kwargs, - node_id=cls.hidden.unique_id, estimated_duration=AVERAGE_DURATION_FLF_SECONDS, ) ) @@ -385,12 +326,12 @@ class RunwayFirstLastFrameNode(IO.ComfyNode): display_name="Runway First-Last-Frame to Video", category="api node/video/Runway", description="Upload first and last keyframes, draft a prompt, and generate a video. " - "More complex transitions, such as cases where the Last frame is completely different " - "from the First frame, may benefit from the longer 10s duration. " - "This would give the generation more time to smoothly transition between the two inputs. " - "Before diving in, review these best practices to ensure that your input selections " - "will set your generation up for success: " - "https://help.runwayml.com/hc/en-us/articles/34170748696595-Creating-with-Keyframes-on-Gen-3.", + "More complex transitions, such as cases where the Last frame is completely different " + "from the First frame, may benefit from the longer 10s duration. " + "This would give the generation more time to smoothly transition between the two inputs. " + "Before diving in, review these best practices to ensure that your input selections " + "will set your generation up for success: " + "https://help.runwayml.com/hc/en-us/articles/34170748696595-Creating-with-Keyframes-on-Gen-3.", inputs=[ IO.String.Input( "prompt", @@ -452,23 +393,19 @@ class RunwayFirstLastFrameNode(IO.ComfyNode): validate_image_aspect_ratio(start_frame, min_aspect_ratio=0.5, max_aspect_ratio=2.0) validate_image_aspect_ratio(end_frame, min_aspect_ratio=0.5, max_aspect_ratio=2.0) - auth_kwargs = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } - stacked_input_images = image_tensor_pair_to_batch(start_frame, end_frame) download_urls = await upload_images_to_comfyapi( + cls, stacked_input_images, max_images=2, mime_type="image/png", - auth_kwargs=auth_kwargs, ) if len(download_urls) != 2: raise RunwayApiError("Failed to upload one or more images to comfy api.") return IO.NodeOutput( await generate_video( + cls, RunwayImageToVideoRequest( promptText=prompt, seed=seed, @@ -477,17 +414,11 @@ class RunwayFirstLastFrameNode(IO.ComfyNode): ratio=AspectRatio(ratio), promptImage=RunwayPromptImageObject( root=[ - RunwayPromptImageDetailedObject( - uri=str(download_urls[0]), position="first" - ), - RunwayPromptImageDetailedObject( - uri=str(download_urls[1]), position="last" - ), + RunwayPromptImageDetailedObject(uri=str(download_urls[0]), position="first"), + RunwayPromptImageDetailedObject(uri=str(download_urls[1]), position="last"), ] ), ), - auth_kwargs=auth_kwargs, - node_id=cls.hidden.unique_id, estimated_duration=AVERAGE_DURATION_FLF_SECONDS, ) ) @@ -502,7 +433,7 @@ class RunwayTextToImageNode(IO.ComfyNode): display_name="Runway Text to Image", category="api node/image/Runway", description="Generate an image from a text prompt using Runway's Gen 4 model. " - "You can also include reference image to guide the generation.", + "You can also include reference image to guide the generation.", inputs=[ IO.String.Input( "prompt", @@ -540,49 +471,34 @@ class RunwayTextToImageNode(IO.ComfyNode): ) -> IO.NodeOutput: validate_string(prompt, min_length=1) - auth_kwargs = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } - # Prepare reference images if provided reference_images = None if reference_image is not None: validate_image_dimensions(reference_image, max_width=7999, max_height=7999) validate_image_aspect_ratio(reference_image, min_aspect_ratio=0.5, max_aspect_ratio=2.0) download_urls = await upload_images_to_comfyapi( + cls, reference_image, max_images=1, mime_type="image/png", - auth_kwargs=auth_kwargs, ) reference_images = [ReferenceImage(uri=str(download_urls[0]))] - request = RunwayTextToImageRequest( - promptText=prompt, - model=Model4.gen4_image, - ratio=ratio, - referenceImages=reference_images, - ) - - initial_operation = SynchronousOperation( - endpoint=ApiEndpoint( - path=PATH_TEXT_TO_IMAGE, - method=HttpMethod.POST, - request_model=RunwayTextToImageRequest, - response_model=RunwayTextToImageResponse, + initial_response = await sync_op( + cls, + endpoint=ApiEndpoint(path=PATH_TEXT_TO_IMAGE, method="POST"), + response_model=RunwayTextToImageResponse, + data=RunwayTextToImageRequest( + promptText=prompt, + model=Model4.gen4_image, + ratio=ratio, + referenceImages=reference_images, ), - request=request, - auth_kwargs=auth_kwargs, ) - initial_response = await initial_operation.execute() - - # Poll for completion final_response = await get_response( + cls, initial_response.id, - auth_kwargs=auth_kwargs, - node_id=cls.hidden.unique_id, estimated_duration=AVERAGE_DURATION_T2I_SECONDS, ) if not final_response.output: @@ -601,5 +517,6 @@ class RunwayExtension(ComfyExtension): RunwayTextToImageNode, ] + async def comfy_entrypoint() -> RunwayExtension: return RunwayExtension() diff --git a/comfy_api_nodes/nodes_sora.py b/comfy_api_nodes/nodes_sora.py index efc954869..92b225d40 100644 --- a/comfy_api_nodes/nodes_sora.py +++ b/comfy_api_nodes/nodes_sora.py @@ -1,23 +1,20 @@ from typing import Optional -from typing_extensions import override import torch from pydantic import BaseModel, Field -from comfy_api.latest import ComfyExtension, IO -from comfy_api_nodes.apis.client import ( - ApiEndpoint, - HttpMethod, - SynchronousOperation, - PollingOperation, - EmptyRequest, -) -from comfy_api_nodes.util.validation_utils import get_number_of_images +from typing_extensions import override -from comfy_api_nodes.apinode_utils import ( +from comfy_api.latest import IO, ComfyExtension +from comfy_api_nodes.util import ( + ApiEndpoint, download_url_to_video_output, + get_number_of_images, + poll_op, + sync_op, tensor_to_bytesio, ) + class Sora2GenerationRequest(BaseModel): prompt: str = Field(...) model: str = Field(...) @@ -80,7 +77,7 @@ class OpenAIVideoSora2(IO.ComfyNode): control_after_generate=True, optional=True, tooltip="Seed to determine if node should re-run; " - "actual results are nondeterministic regardless of seed.", + "actual results are nondeterministic regardless of seed.", ), ], outputs=[ @@ -111,55 +108,34 @@ class OpenAIVideoSora2(IO.ComfyNode): if get_number_of_images(image) != 1: raise ValueError("Currently only one input image is supported.") files_input = {"input_reference": ("image.png", tensor_to_bytesio(image), "image/png")} - auth = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } - payload = Sora2GenerationRequest( - model=model, - prompt=prompt, - seconds=str(duration), - size=size, - ) - initial_operation = SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/openai/v1/videos", - method=HttpMethod.POST, - request_model=Sora2GenerationRequest, - response_model=Sora2GenerationResponse + initial_response = await sync_op( + cls, + endpoint=ApiEndpoint(path="/proxy/openai/v1/videos", method="POST"), + data=Sora2GenerationRequest( + model=model, + prompt=prompt, + seconds=str(duration), + size=size, ), - request=payload, files=files_input, - auth_kwargs=auth, + response_model=Sora2GenerationResponse, content_type="multipart/form-data", ) - initial_response = await initial_operation.execute() if initial_response.error: - raise Exception(initial_response.error.message) + raise Exception(initial_response.error["message"]) model_time_multiplier = 1 if model == "sora-2" else 2 - poll_operation = PollingOperation( - poll_endpoint=ApiEndpoint( - path=f"/proxy/openai/v1/videos/{initial_response.id}", - method=HttpMethod.GET, - request_model=EmptyRequest, - response_model=Sora2GenerationResponse - ), - completed_statuses=["completed"], - failed_statuses=["failed"], + await poll_op( + cls, + poll_endpoint=ApiEndpoint(path=f"/proxy/openai/v1/videos/{initial_response.id}"), + response_model=Sora2GenerationResponse, status_extractor=lambda x: x.status, - auth_kwargs=auth, poll_interval=8.0, max_poll_attempts=160, - node_id=cls.hidden.unique_id, - estimated_duration=45 * (duration / 4) * model_time_multiplier, + estimated_duration=int(45 * (duration / 4) * model_time_multiplier), ) - await poll_operation.execute() return IO.NodeOutput( - await download_url_to_video_output( - f"/proxy/openai/v1/videos/{initial_response.id}/content", - auth_kwargs=auth, - ) + await download_url_to_video_output(f"/proxy/openai/v1/videos/{initial_response.id}/content", cls=cls), ) diff --git a/comfy_api_nodes/nodes_stability.py b/comfy_api_nodes/nodes_stability.py index 8af03cfd1..783666ddf 100644 --- a/comfy_api_nodes/nodes_stability.py +++ b/comfy_api_nodes/nodes_stability.py @@ -27,14 +27,14 @@ from comfy_api_nodes.apis.client import ( PollingOperation, EmptyRequest, ) -from comfy_api_nodes.apinode_utils import ( +from comfy_api_nodes.util import ( + validate_audio_duration, + validate_string, + audio_input_to_mp3, bytesio_to_image_tensor, tensor_to_bytesio, - validate_string, audio_bytes_to_audio_input, - audio_input_to_mp3, ) -from comfy_api_nodes.util.validation_utils import validate_audio_duration import torch import base64 diff --git a/comfy_api_nodes/nodes_veo2.py b/comfy_api_nodes/nodes_veo2.py index daeaa823e..d37e9e9b4 100644 --- a/comfy_api_nodes/nodes_veo2.py +++ b/comfy_api_nodes/nodes_veo2.py @@ -1,28 +1,21 @@ -import logging import base64 -import aiohttp -import torch from io import BytesIO -from typing import Optional + from typing_extensions import override -from comfy_api.latest import ComfyExtension, IO from comfy_api.input_impl.video_types import VideoFromFile -from comfy_api_nodes.apis import ( - VeoGenVidRequest, - VeoGenVidResponse, +from comfy_api.latest import IO, ComfyExtension +from comfy_api_nodes.apis.veo_api import ( VeoGenVidPollRequest, VeoGenVidPollResponse, + VeoGenVidRequest, + VeoGenVidResponse, ) -from comfy_api_nodes.apis.client import ( +from comfy_api_nodes.util import ( ApiEndpoint, - HttpMethod, - SynchronousOperation, - PollingOperation, -) - -from comfy_api_nodes.apinode_utils import ( - downscale_image_tensor, + download_url_to_video_output, + poll_op, + sync_op, tensor_to_base64_string, ) @@ -35,28 +28,6 @@ MODELS_MAP = { "veo-3.0-fast-generate-001": "veo-3.0-fast-generate-001", } -def convert_image_to_base64(image: torch.Tensor): - if image is None: - return None - - scaled_image = downscale_image_tensor(image, total_pixels=2048*2048) - return tensor_to_base64_string(scaled_image) - - -def get_video_url_from_response(poll_response: VeoGenVidPollResponse) -> Optional[str]: - if ( - poll_response.response - and hasattr(poll_response.response, "videos") - and poll_response.response.videos - and len(poll_response.response.videos) > 0 - ): - video = poll_response.response.videos[0] - else: - return None - if hasattr(video, "gcsUri") and video.gcsUri: - return str(video.gcsUri) - return None - class VeoVideoGenerationNode(IO.ComfyNode): """ @@ -169,18 +140,13 @@ class VeoVideoGenerationNode(IO.ComfyNode): # Prepare the instances for the request instances = [] - instance = { - "prompt": prompt - } + instance = {"prompt": prompt} # Add image if provided if image is not None: - image_base64 = convert_image_to_base64(image) + image_base64 = tensor_to_base64_string(image) if image_base64: - instance["image"] = { - "bytesBase64Encoded": image_base64, - "mimeType": "image/png" - } + instance["image"] = {"bytesBase64Encoded": image_base64, "mimeType": "image/png"} instances.append(instance) @@ -198,119 +164,77 @@ class VeoVideoGenerationNode(IO.ComfyNode): if seed > 0: parameters["seed"] = seed # Only add generateAudio for Veo 3 models - if "veo-3.0" in model: + if model.find("veo-2.0") == -1: parameters["generateAudio"] = generate_audio - auth = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } - # Initial request to start video generation - initial_operation = SynchronousOperation( - endpoint=ApiEndpoint( - path=f"/proxy/veo/{model}/generate", - method=HttpMethod.POST, - request_model=VeoGenVidRequest, - response_model=VeoGenVidResponse - ), - request=VeoGenVidRequest( + initial_response = await sync_op( + cls, + ApiEndpoint(path=f"/proxy/veo/{model}/generate", method="POST"), + response_model=VeoGenVidResponse, + data=VeoGenVidRequest( instances=instances, - parameters=parameters + parameters=parameters, ), - auth_kwargs=auth, ) - initial_response = await initial_operation.execute() - operation_name = initial_response.name - - logging.info("Veo generation started with operation name: %s", operation_name) - - # Define status extractor function def status_extractor(response): # Only return "completed" if the operation is done, regardless of success or failure # We'll check for errors after polling completes return "completed" if response.done else "pending" - # Define progress extractor function - def progress_extractor(response): - # Could be enhanced if the API provides progress information - return None - - # Define the polling operation - poll_operation = PollingOperation( - poll_endpoint=ApiEndpoint( - path=f"/proxy/veo/{model}/poll", - method=HttpMethod.POST, - request_model=VeoGenVidPollRequest, - response_model=VeoGenVidPollResponse - ), - completed_statuses=["completed"], - failed_statuses=[], # No failed statuses, we'll handle errors after polling + poll_response = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/veo/{model}/poll", method="POST"), + response_model=VeoGenVidPollResponse, status_extractor=status_extractor, - progress_extractor=progress_extractor, - request=VeoGenVidPollRequest( - operationName=operation_name + data=VeoGenVidPollRequest( + operationName=initial_response.name, ), - auth_kwargs=auth, poll_interval=5.0, - result_url_extractor=get_video_url_from_response, - node_id=cls.hidden.unique_id, estimated_duration=AVERAGE_DURATION_VIDEO_GEN, ) - # Execute the polling operation - poll_response = await poll_operation.execute() - # Now check for errors in the final response # Check for error in poll response - if hasattr(poll_response, 'error') and poll_response.error: - error_message = f"Veo API error: {poll_response.error.message} (code: {poll_response.error.code})" - logging.error(error_message) - raise Exception(error_message) + if poll_response.error: + raise Exception(f"Veo API error: {poll_response.error.message} (code: {poll_response.error.code})") # Check for RAI filtered content - if (hasattr(poll_response.response, 'raiMediaFilteredCount') and - poll_response.response.raiMediaFilteredCount > 0): + if ( + hasattr(poll_response.response, "raiMediaFilteredCount") + and poll_response.response.raiMediaFilteredCount > 0 + ): # Extract reason message if available - if (hasattr(poll_response.response, 'raiMediaFilteredReasons') and - poll_response.response.raiMediaFilteredReasons): + if ( + hasattr(poll_response.response, "raiMediaFilteredReasons") + and poll_response.response.raiMediaFilteredReasons + ): reason = poll_response.response.raiMediaFilteredReasons[0] error_message = f"Content filtered by Google's Responsible AI practices: {reason} ({poll_response.response.raiMediaFilteredCount} videos filtered.)" else: error_message = f"Content filtered by Google's Responsible AI practices ({poll_response.response.raiMediaFilteredCount} videos filtered.)" - logging.error(error_message) raise Exception(error_message) # Extract video data - if poll_response.response and hasattr(poll_response.response, 'videos') and poll_response.response.videos and len(poll_response.response.videos) > 0: + if ( + poll_response.response + and hasattr(poll_response.response, "videos") + and poll_response.response.videos + and len(poll_response.response.videos) > 0 + ): video = poll_response.response.videos[0] # Check if video is provided as base64 or URL - if hasattr(video, 'bytesBase64Encoded') and video.bytesBase64Encoded: - # Decode base64 string to bytes - video_data = base64.b64decode(video.bytesBase64Encoded) - elif hasattr(video, 'gcsUri') and video.gcsUri: - # Download from URL - async with aiohttp.ClientSession() as session: - async with session.get(video.gcsUri) as video_response: - video_data = await video_response.content.read() - else: - raise Exception("Video returned but no data or URL was provided") - else: - raise Exception("Video generation completed but no video was returned") + if hasattr(video, "bytesBase64Encoded") and video.bytesBase64Encoded: + return IO.NodeOutput(VideoFromFile(BytesIO(base64.b64decode(video.bytesBase64Encoded)))) - if not video_data: - raise Exception("No video data was returned") + if hasattr(video, "gcsUri") and video.gcsUri: + return IO.NodeOutput(await download_url_to_video_output(video.gcsUri)) - logging.info("Video generation completed successfully") - - # Convert video data to BytesIO object - video_io = BytesIO(video_data) - - # Return VideoFromFile object - return IO.NodeOutput(VideoFromFile(video_io)) + raise Exception("Video returned but no data or URL was provided") + raise Exception("Video generation completed but no video was returned") class Veo3VideoGenerationNode(VeoVideoGenerationNode): @@ -394,7 +318,10 @@ class Veo3VideoGenerationNode(VeoVideoGenerationNode): IO.Combo.Input( "model", options=[ - "veo-3.1-generate", "veo-3.1-fast-generate", "veo-3.0-generate-001", "veo-3.0-fast-generate-001" + "veo-3.1-generate", + "veo-3.1-fast-generate", + "veo-3.0-generate-001", + "veo-3.0-fast-generate-001", ], default="veo-3.0-generate-001", tooltip="Veo 3 model to use for video generation", @@ -427,5 +354,6 @@ class VeoExtension(ComfyExtension): Veo3VideoGenerationNode, ] + async def comfy_entrypoint() -> VeoExtension: return VeoExtension() diff --git a/comfy_api_nodes/nodes_vidu.py b/comfy_api_nodes/nodes_vidu.py index 639be4b2b..0e0572f8c 100644 --- a/comfy_api_nodes/nodes_vidu.py +++ b/comfy_api_nodes/nodes_vidu.py @@ -1,27 +1,23 @@ import logging from enum import Enum -from typing import Any, Callable, Optional, Literal, TypeVar -from typing_extensions import override +from typing import Literal, Optional, TypeVar import torch from pydantic import BaseModel, Field +from typing_extensions import override -from comfy_api.latest import ComfyExtension, IO -from comfy_api_nodes.util.validation_utils import ( - validate_aspect_ratio_closeness, - validate_image_dimensions, - validate_image_aspect_ratio_range, - get_number_of_images, -) -from comfy_api_nodes.apis.client import ( +from comfy_api.latest import IO, ComfyExtension +from comfy_api_nodes.util import ( ApiEndpoint, - HttpMethod, - SynchronousOperation, - PollingOperation, - EmptyRequest, + download_url_to_video_output, + get_number_of_images, + poll_op, + sync_op, + upload_images_to_comfyapi, + validate_aspect_ratio_closeness, + validate_image_aspect_ratio_range, + validate_image_dimensions, ) -from comfy_api_nodes.apinode_utils import download_url_to_video_output, upload_images_to_comfyapi - VIDU_TEXT_TO_VIDEO = "/proxy/vidu/text2video" VIDU_IMAGE_TO_VIDEO = "/proxy/vidu/img2video" @@ -31,8 +27,9 @@ VIDU_GET_GENERATION_STATUS = "/proxy/vidu/tasks/%s/creations" R = TypeVar("R") + class VideoModelName(str, Enum): - vidu_q1 = 'viduq1' + vidu_q1 = "viduq1" class AspectRatio(str, Enum): @@ -63,17 +60,9 @@ class TaskCreationRequest(BaseModel): images: Optional[list[str]] = Field(None, description="Base64 encoded string or image URL") -class TaskStatus(str, Enum): - created = "created" - queueing = "queueing" - processing = "processing" - success = "success" - failed = "failed" - - class TaskCreationResponse(BaseModel): task_id: str = Field(...) - state: TaskStatus = Field(...) + state: str = Field(...) created_at: str = Field(...) code: Optional[int] = Field(None, description="Error code") @@ -85,32 +74,11 @@ class TaskResult(BaseModel): class TaskStatusResponse(BaseModel): - state: TaskStatus = Field(...) + state: str = Field(...) err_code: Optional[str] = Field(None) creations: list[TaskResult] = Field(..., description="Generated results") -async def poll_until_finished( - auth_kwargs: dict[str, str], - api_endpoint: ApiEndpoint[Any, R], - result_url_extractor: Optional[Callable[[R], str]] = None, - estimated_duration: Optional[int] = None, - node_id: Optional[str] = None, -) -> R: - return await PollingOperation( - poll_endpoint=api_endpoint, - completed_statuses=[TaskStatus.success.value], - failed_statuses=[TaskStatus.failed.value], - status_extractor=lambda response: response.state.value, - auth_kwargs=auth_kwargs, - result_url_extractor=result_url_extractor, - estimated_duration=estimated_duration, - node_id=node_id, - poll_interval=16.0, - max_poll_attempts=256, - ).execute() - - def get_video_url_from_response(response) -> Optional[str]: if response.creations: return response.creations[0].url @@ -127,37 +95,27 @@ def get_video_from_response(response) -> TaskResult: async def execute_task( + cls: type[IO.ComfyNode], vidu_endpoint: str, - auth_kwargs: Optional[dict[str, str]], payload: TaskCreationRequest, estimated_duration: int, - node_id: str, ) -> R: - response = await SynchronousOperation( - endpoint=ApiEndpoint( - path=vidu_endpoint, - method=HttpMethod.POST, - request_model=TaskCreationRequest, - response_model=TaskCreationResponse, - ), - request=payload, - auth_kwargs=auth_kwargs, - ).execute() - if response.state == TaskStatus.failed: + response = await sync_op( + cls, + endpoint=ApiEndpoint(path=vidu_endpoint, method="POST"), + response_model=TaskCreationResponse, + data=payload, + ) + if response.state == "failed": error_msg = f"Vidu request failed. Code: {response.code}" logging.error(error_msg) raise RuntimeError(error_msg) - return await poll_until_finished( - auth_kwargs, - ApiEndpoint( - path=VIDU_GET_GENERATION_STATUS % response.task_id, - method=HttpMethod.GET, - request_model=EmptyRequest, - response_model=TaskStatusResponse, - ), - result_url_extractor=get_video_url_from_response, + return await poll_op( + cls, + ApiEndpoint(path=VIDU_GET_GENERATION_STATUS % response.task_id), + response_model=TaskStatusResponse, + status_extractor=lambda r: r.state.value, estimated_duration=estimated_duration, - node_id=node_id, ) @@ -258,11 +216,7 @@ class ViduTextToVideoNode(IO.ComfyNode): resolution=resolution, movement_amplitude=movement_amplitude, ) - auth = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } - results = await execute_task(VIDU_TEXT_TO_VIDEO, auth, payload, 320, cls.hidden.unique_id) + results = await execute_task(cls, VIDU_TEXT_TO_VIDEO, payload, 320) return IO.NodeOutput(await download_url_to_video_output(get_video_from_response(results).url)) @@ -362,17 +316,13 @@ class ViduImageToVideoNode(IO.ComfyNode): resolution=resolution, movement_amplitude=movement_amplitude, ) - auth = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } payload.images = await upload_images_to_comfyapi( + cls, image, max_images=1, mime_type="image/png", - auth_kwargs=auth, ) - results = await execute_task(VIDU_IMAGE_TO_VIDEO, auth, payload, 120, cls.hidden.unique_id) + results = await execute_task(cls, VIDU_IMAGE_TO_VIDEO, payload, 120) return IO.NodeOutput(await download_url_to_video_output(get_video_from_response(results).url)) @@ -484,17 +434,13 @@ class ViduReferenceVideoNode(IO.ComfyNode): resolution=resolution, movement_amplitude=movement_amplitude, ) - auth = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } payload.images = await upload_images_to_comfyapi( + cls, images, max_images=7, mime_type="image/png", - auth_kwargs=auth, ) - results = await execute_task(VIDU_REFERENCE_VIDEO, auth, payload, 120, cls.hidden.unique_id) + results = await execute_task(cls, VIDU_REFERENCE_VIDEO, payload, 120) return IO.NodeOutput(await download_url_to_video_output(get_video_from_response(results).url)) @@ -596,15 +542,11 @@ class ViduStartEndToVideoNode(IO.ComfyNode): resolution=resolution, movement_amplitude=movement_amplitude, ) - auth = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } payload.images = [ - (await upload_images_to_comfyapi(frame, max_images=1, mime_type="image/png", auth_kwargs=auth))[0] + (await upload_images_to_comfyapi(cls, frame, max_images=1, mime_type="image/png"))[0] for frame in (first_frame, end_frame) ] - results = await execute_task(VIDU_START_END_VIDEO, auth, payload, 96, cls.hidden.unique_id) + results = await execute_task(cls, VIDU_START_END_VIDEO, payload, 96) return IO.NodeOutput(await download_url_to_video_output(get_video_from_response(results).url)) @@ -618,5 +560,6 @@ class ViduExtension(ComfyExtension): ViduStartEndToVideoNode, ] + async def comfy_entrypoint() -> ViduExtension: return ViduExtension() diff --git a/comfy_api_nodes/nodes_wan.py b/comfy_api_nodes/nodes_wan.py index b089bd907..2aab3c2ff 100644 --- a/comfy_api_nodes/nodes_wan.py +++ b/comfy_api_nodes/nodes_wan.py @@ -1,28 +1,24 @@ import re -from typing import Optional, Type, Union -from typing_extensions import override +from typing import Optional import torch from pydantic import BaseModel, Field -from comfy_api.latest import ComfyExtension, Input, IO -from comfy_api_nodes.apis.client import ( - ApiEndpoint, - HttpMethod, - SynchronousOperation, - PollingOperation, - EmptyRequest, - R, - T, -) -from comfy_api_nodes.util.validation_utils import get_number_of_images, validate_audio_duration +from typing_extensions import override -from comfy_api_nodes.apinode_utils import ( +from comfy_api.latest import IO, ComfyExtension, Input +from comfy_api_nodes.util import ( + ApiEndpoint, + audio_to_base64_string, download_url_to_image_tensor, download_url_to_video_output, + get_number_of_images, + poll_op, + sync_op, tensor_to_base64_string, - audio_to_base64_string, + validate_audio_duration, ) + class Text2ImageInputField(BaseModel): prompt: str = Field(...) negative_prompt: Optional[str] = Field(None) @@ -146,53 +142,7 @@ class VideoTaskStatusResponse(BaseModel): request_id: str = Field(...) -RES_IN_PARENS = re.compile(r'\((\d+)\s*[x×]\s*(\d+)\)') - - -async def process_task( - auth_kwargs: dict[str, str], - url: str, - request_model: Type[T], - response_model: Type[R], - payload: Union[ - Text2ImageTaskCreationRequest, - Image2ImageTaskCreationRequest, - Text2VideoTaskCreationRequest, - Image2VideoTaskCreationRequest, - ], - node_id: str, - estimated_duration: int, - poll_interval: int, -) -> Type[R]: - initial_response = await SynchronousOperation( - endpoint=ApiEndpoint( - path=url, - method=HttpMethod.POST, - request_model=request_model, - response_model=TaskCreationResponse, - ), - request=payload, - auth_kwargs=auth_kwargs, - ).execute() - - if not initial_response.output: - raise Exception(f"Unknown error occurred: {initial_response.code} - {initial_response.message}") - - return await PollingOperation( - poll_endpoint=ApiEndpoint( - path=f"/proxy/wan/api/v1/tasks/{initial_response.output.task_id}", - method=HttpMethod.GET, - request_model=EmptyRequest, - response_model=response_model, - ), - completed_statuses=["SUCCEEDED"], - failed_statuses=["FAILED", "CANCELED", "UNKNOWN"], - status_extractor=lambda x: x.output.task_status, - estimated_duration=estimated_duration, - poll_interval=poll_interval, - node_id=node_id, - auth_kwargs=auth_kwargs, - ).execute() +RES_IN_PARENS = re.compile(r"\((\d+)\s*[x×]\s*(\d+)\)") class WanTextToImageApi(IO.ComfyNode): @@ -259,7 +209,7 @@ class WanTextToImageApi(IO.ComfyNode): IO.Boolean.Input( "watermark", default=True, - tooltip="Whether to add an \"AI generated\" watermark to the result.", + tooltip='Whether to add an "AI generated" watermark to the result.', optional=True, ), ], @@ -286,26 +236,28 @@ class WanTextToImageApi(IO.ComfyNode): prompt_extend: bool = True, watermark: bool = True, ): - payload = Text2ImageTaskCreationRequest( - model=model, - input=Text2ImageInputField(prompt=prompt, negative_prompt=negative_prompt), - parameters=Txt2ImageParametersField( - size=f"{width}*{height}", - seed=seed, - prompt_extend=prompt_extend, - watermark=watermark, + initial_response = await sync_op( + cls, + ApiEndpoint(path="/proxy/wan/api/v1/services/aigc/text2image/image-synthesis", method="POST"), + response_model=TaskCreationResponse, + data=Text2ImageTaskCreationRequest( + model=model, + input=Text2ImageInputField(prompt=prompt, negative_prompt=negative_prompt), + parameters=Txt2ImageParametersField( + size=f"{width}*{height}", + seed=seed, + prompt_extend=prompt_extend, + watermark=watermark, + ), ), ) - response = await process_task( - { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - }, - "/proxy/wan/api/v1/services/aigc/text2image/image-synthesis", - request_model=Text2ImageTaskCreationRequest, + if not initial_response.output: + raise Exception(f"Unknown error occurred: {initial_response.code} - {initial_response.message}") + response = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/wan/api/v1/tasks/{initial_response.output.task_id}"), response_model=ImageTaskStatusResponse, - payload=payload, - node_id=cls.hidden.unique_id, + status_extractor=lambda x: x.output.task_status, estimated_duration=9, poll_interval=3, ) @@ -320,7 +272,7 @@ class WanImageToImageApi(IO.ComfyNode): display_name="Wan Image to Image", category="api node/image/Wan", description="Generates an image from one or two input images and a text prompt. " - "The output image is currently fixed at 1.6 MP; its aspect ratio matches the input image(s).", + "The output image is currently fixed at 1.6 MP; its aspect ratio matches the input image(s).", inputs=[ IO.Combo.Input( "model", @@ -376,7 +328,7 @@ class WanImageToImageApi(IO.ComfyNode): IO.Boolean.Input( "watermark", default=True, - tooltip="Whether to add an \"AI generated\" watermark to the result.", + tooltip='Whether to add an "AI generated" watermark to the result.', optional=True, ), ], @@ -408,28 +360,30 @@ class WanImageToImageApi(IO.ComfyNode): raise ValueError(f"Expected 1 or 2 input images, got {n_images}.") images = [] for i in image: - images.append("data:image/png;base64," + tensor_to_base64_string(i, total_pixels=4096*4096)) - payload = Image2ImageTaskCreationRequest( - model=model, - input=Image2ImageInputField(prompt=prompt, negative_prompt=negative_prompt, images=images), - parameters=Image2ImageParametersField( - # size=f"{width}*{height}", - seed=seed, - watermark=watermark, + images.append("data:image/png;base64," + tensor_to_base64_string(i, total_pixels=4096 * 4096)) + initial_response = await sync_op( + cls, + ApiEndpoint(path="/proxy/wan/api/v1/services/aigc/image2image/image-synthesis", method="POST"), + response_model=TaskCreationResponse, + data=Image2ImageTaskCreationRequest( + model=model, + input=Image2ImageInputField(prompt=prompt, negative_prompt=negative_prompt, images=images), + parameters=Image2ImageParametersField( + # size=f"{width}*{height}", + seed=seed, + watermark=watermark, + ), ), ) - response = await process_task( - { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - }, - "/proxy/wan/api/v1/services/aigc/image2image/image-synthesis", - request_model=Image2ImageTaskCreationRequest, + if not initial_response.output: + raise Exception(f"Unknown error occurred: {initial_response.code} - {initial_response.message}") + response = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/wan/api/v1/tasks/{initial_response.output.task_id}"), response_model=ImageTaskStatusResponse, - payload=payload, - node_id=cls.hidden.unique_id, + status_extractor=lambda x: x.output.task_status, estimated_duration=42, - poll_interval=3, + poll_interval=4, ) return IO.NodeOutput(await download_url_to_image_tensor(str(response.output.results[0].url))) @@ -523,7 +477,7 @@ class WanTextToVideoApi(IO.ComfyNode): IO.Boolean.Input( "watermark", default=True, - tooltip="Whether to add an \"AI generated\" watermark to the result.", + tooltip='Whether to add an "AI generated" watermark to the result.', optional=True, ), ], @@ -557,28 +511,31 @@ class WanTextToVideoApi(IO.ComfyNode): if audio is not None: validate_audio_duration(audio, 3.0, 29.0) audio_url = "data:audio/mp3;base64," + audio_to_base64_string(audio, "mp3", "libmp3lame") - payload = Text2VideoTaskCreationRequest( - model=model, - input=Text2VideoInputField(prompt=prompt, negative_prompt=negative_prompt, audio_url=audio_url), - parameters=Text2VideoParametersField( - size=f"{width}*{height}", - duration=duration, - seed=seed, - audio=generate_audio, - prompt_extend=prompt_extend, - watermark=watermark, + + initial_response = await sync_op( + cls, + ApiEndpoint(path="/proxy/wan/api/v1/services/aigc/video-generation/video-synthesis", method="POST"), + response_model=TaskCreationResponse, + data=Text2VideoTaskCreationRequest( + model=model, + input=Text2VideoInputField(prompt=prompt, negative_prompt=negative_prompt, audio_url=audio_url), + parameters=Text2VideoParametersField( + size=f"{width}*{height}", + duration=duration, + seed=seed, + audio=generate_audio, + prompt_extend=prompt_extend, + watermark=watermark, + ), ), ) - response = await process_task( - { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - }, - "/proxy/wan/api/v1/services/aigc/video-generation/video-synthesis", - request_model=Text2VideoTaskCreationRequest, + if not initial_response.output: + raise Exception(f"Unknown error occurred: {initial_response.code} - {initial_response.message}") + response = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/wan/api/v1/tasks/{initial_response.output.task_id}"), response_model=VideoTaskStatusResponse, - payload=payload, - node_id=cls.hidden.unique_id, + status_extractor=lambda x: x.output.task_status, estimated_duration=120 * int(duration / 5), poll_interval=6, ) @@ -667,7 +624,7 @@ class WanImageToVideoApi(IO.ComfyNode): IO.Boolean.Input( "watermark", default=True, - tooltip="Whether to add an \"AI generated\" watermark to the result.", + tooltip='Whether to add an "AI generated" watermark to the result.', optional=True, ), ], @@ -699,35 +656,37 @@ class WanImageToVideoApi(IO.ComfyNode): ): if get_number_of_images(image) != 1: raise ValueError("Exactly one input image is required.") - image_url = "data:image/png;base64," + tensor_to_base64_string(image, total_pixels=2000*2000) + image_url = "data:image/png;base64," + tensor_to_base64_string(image, total_pixels=2000 * 2000) audio_url = None if audio is not None: validate_audio_duration(audio, 3.0, 29.0) audio_url = "data:audio/mp3;base64," + audio_to_base64_string(audio, "mp3", "libmp3lame") - payload = Image2VideoTaskCreationRequest( - model=model, - input=Image2VideoInputField( - prompt=prompt, negative_prompt=negative_prompt, img_url=image_url, audio_url=audio_url - ), - parameters=Image2VideoParametersField( - resolution=resolution, - duration=duration, - seed=seed, - audio=generate_audio, - prompt_extend=prompt_extend, - watermark=watermark, + initial_response = await sync_op( + cls, + ApiEndpoint(path="/proxy/wan/api/v1/services/aigc/video-generation/video-synthesis", method="POST"), + response_model=TaskCreationResponse, + data=Image2VideoTaskCreationRequest( + model=model, + input=Image2VideoInputField( + prompt=prompt, negative_prompt=negative_prompt, img_url=image_url, audio_url=audio_url + ), + parameters=Image2VideoParametersField( + resolution=resolution, + duration=duration, + seed=seed, + audio=generate_audio, + prompt_extend=prompt_extend, + watermark=watermark, + ), ), ) - response = await process_task( - { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - }, - "/proxy/wan/api/v1/services/aigc/video-generation/video-synthesis", - request_model=Image2VideoTaskCreationRequest, + if not initial_response.output: + raise Exception(f"Unknown error occurred: {initial_response.code} - {initial_response.message}") + response = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/wan/api/v1/tasks/{initial_response.output.task_id}"), response_model=VideoTaskStatusResponse, - payload=payload, - node_id=cls.hidden.unique_id, + status_extractor=lambda x: x.output.task_status, estimated_duration=120 * int(duration / 5), poll_interval=6, ) diff --git a/comfy_api_nodes/util/__init__.py b/comfy_api_nodes/util/__init__.py index e69de29bb..c2ec391aa 100644 --- a/comfy_api_nodes/util/__init__.py +++ b/comfy_api_nodes/util/__init__.py @@ -0,0 +1,87 @@ +from ._helpers import get_fs_object_size +from .client import ( + ApiEndpoint, + poll_op, + poll_op_raw, + sync_op, + sync_op_raw, +) +from .conversions import ( + audio_bytes_to_audio_input, + audio_input_to_mp3, + audio_to_base64_string, + bytesio_to_image_tensor, + downscale_image_tensor, + image_tensor_pair_to_batch, + pil_to_bytesio, + tensor_to_base64_string, + tensor_to_bytesio, + tensor_to_pil, + trim_video, +) +from .download_helpers import ( + download_url_to_bytesio, + download_url_to_image_tensor, + download_url_to_video_output, +) +from .upload_helpers import ( + upload_audio_to_comfyapi, + upload_file_to_comfyapi, + upload_images_to_comfyapi, + upload_video_to_comfyapi, +) +from .validation_utils import ( + get_number_of_images, + validate_aspect_ratio_closeness, + validate_audio_duration, + validate_container_format_is_mp4, + validate_image_aspect_ratio, + validate_image_aspect_ratio_range, + validate_image_dimensions, + validate_string, + validate_video_dimensions, + validate_video_duration, +) + +__all__ = [ + # API client + "ApiEndpoint", + "poll_op", + "poll_op_raw", + "sync_op", + "sync_op_raw", + # Upload helpers + "upload_audio_to_comfyapi", + "upload_file_to_comfyapi", + "upload_images_to_comfyapi", + "upload_video_to_comfyapi", + # Download helpers + "download_url_to_bytesio", + "download_url_to_image_tensor", + "download_url_to_video_output", + # Conversions + "audio_bytes_to_audio_input", + "audio_input_to_mp3", + "audio_to_base64_string", + "bytesio_to_image_tensor", + "downscale_image_tensor", + "image_tensor_pair_to_batch", + "pil_to_bytesio", + "tensor_to_base64_string", + "tensor_to_bytesio", + "tensor_to_pil", + "trim_video", + # Validation utilities + "get_number_of_images", + "validate_aspect_ratio_closeness", + "validate_audio_duration", + "validate_container_format_is_mp4", + "validate_image_aspect_ratio", + "validate_image_aspect_ratio_range", + "validate_image_dimensions", + "validate_string", + "validate_video_dimensions", + "validate_video_duration", + # Misc functions + "get_fs_object_size", +] diff --git a/comfy_api_nodes/util/_helpers.py b/comfy_api_nodes/util/_helpers.py new file mode 100644 index 000000000..328fe5227 --- /dev/null +++ b/comfy_api_nodes/util/_helpers.py @@ -0,0 +1,71 @@ +import asyncio +import contextlib +import os +import time +from io import BytesIO +from typing import Callable, Optional, Union + +from comfy.cli_args import args +from comfy.model_management import processing_interrupted +from comfy_api.latest import IO + +from .common_exceptions import ProcessingInterrupted + + +def is_processing_interrupted() -> bool: + """Return True if user/runtime requested interruption.""" + return processing_interrupted() + + +def get_node_id(node_cls: type[IO.ComfyNode]) -> str: + return node_cls.hidden.unique_id + + +def get_auth_header(node_cls: type[IO.ComfyNode]) -> dict[str, str]: + if node_cls.hidden.auth_token_comfy_org: + return {"Authorization": f"Bearer {node_cls.hidden.auth_token_comfy_org}"} + if node_cls.hidden.api_key_comfy_org: + return {"X-API-KEY": node_cls.hidden.api_key_comfy_org} + return {} + + +def default_base_url() -> str: + return getattr(args, "comfy_api_base", "https://api.comfy.org") + + +async def sleep_with_interrupt( + seconds: float, + node_cls: Optional[type[IO.ComfyNode]], + label: Optional[str] = None, + start_ts: Optional[float] = None, + estimated_total: Optional[int] = None, + *, + display_callback: Optional[Callable[[type[IO.ComfyNode], str, int, Optional[int]], None]] = None, +): + """ + Sleep in 1s slices while: + - Checking for interruption (raises ProcessingInterrupted). + - Optionally emitting time progress via display_callback (if provided). + """ + end = time.monotonic() + seconds + while True: + if is_processing_interrupted(): + raise ProcessingInterrupted("Task cancelled") + now = time.monotonic() + if start_ts is not None and label and display_callback: + with contextlib.suppress(Exception): + display_callback(node_cls, label, int(now - start_ts), estimated_total) + if now >= end: + break + await asyncio.sleep(min(1.0, end - now)) + + +def mimetype_to_extension(mime_type: str) -> str: + """Converts a MIME type to a file extension.""" + return mime_type.split("/")[-1].lower() + + +def get_fs_object_size(path_or_object: Union[str, BytesIO]) -> int: + if isinstance(path_or_object, str): + return os.path.getsize(path_or_object) + return len(path_or_object.getvalue()) diff --git a/comfy_api_nodes/util/client.py b/comfy_api_nodes/util/client.py new file mode 100644 index 000000000..5833b118f --- /dev/null +++ b/comfy_api_nodes/util/client.py @@ -0,0 +1,941 @@ +import asyncio +import contextlib +import json +import logging +import socket +import time +import uuid +from dataclasses import dataclass +from enum import Enum +from io import BytesIO +from typing import Any, Callable, Iterable, Literal, Optional, Type, TypeVar, Union +from urllib.parse import urljoin, urlparse + +import aiohttp +from aiohttp.client_exceptions import ClientError, ContentTypeError +from pydantic import BaseModel + +from comfy import utils +from comfy_api.latest import IO +from comfy_api_nodes.apis import request_logger +from server import PromptServer + +from ._helpers import ( + default_base_url, + get_auth_header, + get_node_id, + is_processing_interrupted, + sleep_with_interrupt, +) +from .common_exceptions import ApiServerError, LocalNetworkError, ProcessingInterrupted + +M = TypeVar("M", bound=BaseModel) + + +class ApiEndpoint: + def __init__( + self, + path: str, + method: Literal["GET", "POST", "PUT", "DELETE", "PATCH"] = "GET", + *, + query_params: Optional[dict[str, Any]] = None, + headers: Optional[dict[str, str]] = None, + ): + self.path = path + self.method = method + self.query_params = query_params or {} + self.headers = headers or {} + + +@dataclass +class _RequestConfig: + node_cls: type[IO.ComfyNode] + endpoint: ApiEndpoint + timeout: float + content_type: str + data: Optional[dict[str, Any]] + files: Optional[Union[dict[str, Any], list[tuple[str, Any]]]] + multipart_parser: Optional[Callable] + max_retries: int + retry_delay: float + retry_backoff: float + wait_label: str = "Waiting" + monitor_progress: bool = True + estimated_total: Optional[int] = None + final_label_on_success: Optional[str] = "Completed" + progress_origin_ts: Optional[float] = None + + +@dataclass +class _PollUIState: + started: float + status_label: str = "Queued" + is_queued: bool = True + price: Optional[float] = None + estimated_duration: Optional[int] = None + base_processing_elapsed: float = 0.0 # sum of completed active intervals + active_since: Optional[float] = None # start time of current active interval (None if queued) + + +_RETRY_STATUS = {408, 429, 500, 502, 503, 504} +COMPLETED_STATUSES = ["succeeded", "succeed", "success", "completed"] +FAILED_STATUSES = ["cancelled", "canceled", "failed", "error"] +QUEUED_STATUSES = ["created", "queued", "queueing", "submitted"] + + +async def sync_op( + cls: type[IO.ComfyNode], + endpoint: ApiEndpoint, + *, + response_model: Type[M], + data: Optional[BaseModel] = None, + files: Optional[Union[dict[str, Any], list[tuple[str, Any]]]] = None, + content_type: str = "application/json", + timeout: float = 3600.0, + multipart_parser: Optional[Callable] = None, + max_retries: int = 3, + retry_delay: float = 1.0, + retry_backoff: float = 2.0, + wait_label: str = "Waiting for server", + estimated_duration: Optional[int] = None, + final_label_on_success: Optional[str] = "Completed", + progress_origin_ts: Optional[float] = None, + monitor_progress: bool = True, +) -> M: + raw = await sync_op_raw( + cls, + endpoint, + data=data, + files=files, + content_type=content_type, + timeout=timeout, + multipart_parser=multipart_parser, + max_retries=max_retries, + retry_delay=retry_delay, + retry_backoff=retry_backoff, + wait_label=wait_label, + estimated_duration=estimated_duration, + as_binary=False, + final_label_on_success=final_label_on_success, + progress_origin_ts=progress_origin_ts, + monitor_progress=monitor_progress, + ) + if not isinstance(raw, dict): + raise Exception("Expected JSON response to validate into a Pydantic model, got non-JSON (binary or text).") + return _validate_or_raise(response_model, raw) + + +async def poll_op( + cls: type[IO.ComfyNode], + poll_endpoint: ApiEndpoint, + *, + response_model: Type[M], + status_extractor: Callable[[M], Optional[Union[str, int]]], + progress_extractor: Optional[Callable[[M], Optional[int]]] = None, + price_extractor: Optional[Callable[[M], Optional[float]]] = None, + completed_statuses: Optional[list[Union[str, int]]] = None, + failed_statuses: Optional[list[Union[str, int]]] = None, + queued_statuses: Optional[list[Union[str, int]]] = None, + data: Optional[BaseModel] = None, + poll_interval: float = 5.0, + max_poll_attempts: int = 120, + timeout_per_poll: float = 120.0, + max_retries_per_poll: int = 3, + retry_delay_per_poll: float = 1.0, + retry_backoff_per_poll: float = 2.0, + estimated_duration: Optional[int] = None, + cancel_endpoint: Optional[ApiEndpoint] = None, + cancel_timeout: float = 10.0, +) -> M: + raw = await poll_op_raw( + cls, + poll_endpoint=poll_endpoint, + status_extractor=_wrap_model_extractor(response_model, status_extractor), + progress_extractor=_wrap_model_extractor(response_model, progress_extractor), + price_extractor=_wrap_model_extractor(response_model, price_extractor), + completed_statuses=completed_statuses, + failed_statuses=failed_statuses, + queued_statuses=queued_statuses, + data=data, + poll_interval=poll_interval, + max_poll_attempts=max_poll_attempts, + timeout_per_poll=timeout_per_poll, + max_retries_per_poll=max_retries_per_poll, + retry_delay_per_poll=retry_delay_per_poll, + retry_backoff_per_poll=retry_backoff_per_poll, + estimated_duration=estimated_duration, + cancel_endpoint=cancel_endpoint, + cancel_timeout=cancel_timeout, + ) + if not isinstance(raw, dict): + raise Exception("Expected JSON response to validate into a Pydantic model, got non-JSON (binary or text).") + return _validate_or_raise(response_model, raw) + + +async def sync_op_raw( + cls: type[IO.ComfyNode], + endpoint: ApiEndpoint, + *, + data: Optional[Union[dict[str, Any], BaseModel]] = None, + files: Optional[Union[dict[str, Any], list[tuple[str, Any]]]] = None, + content_type: str = "application/json", + timeout: float = 3600.0, + multipart_parser: Optional[Callable] = None, + max_retries: int = 3, + retry_delay: float = 1.0, + retry_backoff: float = 2.0, + wait_label: str = "Waiting for server", + estimated_duration: Optional[int] = None, + as_binary: bool = False, + final_label_on_success: Optional[str] = "Completed", + progress_origin_ts: Optional[float] = None, + monitor_progress: bool = True, +) -> Union[dict[str, Any], bytes]: + """ + Make a single network request. + - If as_binary=False (default): returns JSON dict (or {'_raw': ''} if non-JSON). + - If as_binary=True: returns bytes. + """ + if isinstance(data, BaseModel): + data = data.model_dump(exclude_none=True) + for k, v in list(data.items()): + if isinstance(v, Enum): + data[k] = v.value + cfg = _RequestConfig( + node_cls=cls, + endpoint=endpoint, + timeout=timeout, + content_type=content_type, + data=data, + files=files, + multipart_parser=multipart_parser, + max_retries=max_retries, + retry_delay=retry_delay, + retry_backoff=retry_backoff, + wait_label=wait_label, + monitor_progress=monitor_progress, + estimated_total=estimated_duration, + final_label_on_success=final_label_on_success, + progress_origin_ts=progress_origin_ts, + ) + return await _request_base(cfg, expect_binary=as_binary) + + +async def poll_op_raw( + cls: type[IO.ComfyNode], + poll_endpoint: ApiEndpoint, + *, + status_extractor: Callable[[dict[str, Any]], Optional[Union[str, int]]], + progress_extractor: Optional[Callable[[dict[str, Any]], Optional[int]]] = None, + price_extractor: Optional[Callable[[dict[str, Any]], Optional[float]]] = None, + completed_statuses: Optional[list[Union[str, int]]] = None, + failed_statuses: Optional[list[Union[str, int]]] = None, + queued_statuses: Optional[list[Union[str, int]]] = None, + data: Optional[Union[dict[str, Any], BaseModel]] = None, + poll_interval: float = 5.0, + max_poll_attempts: int = 120, + timeout_per_poll: float = 120.0, + max_retries_per_poll: int = 3, + retry_delay_per_poll: float = 1.0, + retry_backoff_per_poll: float = 2.0, + estimated_duration: Optional[int] = None, + cancel_endpoint: Optional[ApiEndpoint] = None, + cancel_timeout: float = 10.0, +) -> dict[str, Any]: + """ + Polls an endpoint until the task reaches a terminal state. Displays time while queued/processing, + checks interruption every second, and calls Cancel endpoint (if provided) on interruption. + + Uses default complete, failed and queued states assumption. + + Returns the final JSON response from the poll endpoint. + """ + completed_states = _normalize_statuses(COMPLETED_STATUSES if completed_statuses is None else completed_statuses) + failed_states = _normalize_statuses(FAILED_STATUSES if failed_statuses is None else failed_statuses) + queued_states = _normalize_statuses(QUEUED_STATUSES if queued_statuses is None else queued_statuses) + started = time.monotonic() + consumed_attempts = 0 # counts only non-queued polls + + progress_bar = utils.ProgressBar(100) if progress_extractor else None + last_progress: Optional[int] = None + + state = _PollUIState(started=started, estimated_duration=estimated_duration) + stop_ticker = asyncio.Event() + + async def _ticker(): + """Emit a UI update every second while polling is in progress.""" + try: + while not stop_ticker.is_set(): + if is_processing_interrupted(): + break + now = time.monotonic() + proc_elapsed = state.base_processing_elapsed + ( + (now - state.active_since) if state.active_since is not None else 0.0 + ) + _display_time_progress( + cls, + status=state.status_label, + elapsed_seconds=int(now - state.started), + estimated_total=state.estimated_duration, + price=state.price, + is_queued=state.is_queued, + processing_elapsed_seconds=int(proc_elapsed), + ) + await asyncio.sleep(1.0) + except Exception as exc: + logging.debug("Polling ticker exited: %s", exc) + + ticker_task = asyncio.create_task(_ticker()) + try: + while consumed_attempts < max_poll_attempts: + try: + resp_json = await sync_op_raw( + cls, + poll_endpoint, + data=data, + timeout=timeout_per_poll, + max_retries=max_retries_per_poll, + retry_delay=retry_delay_per_poll, + retry_backoff=retry_backoff_per_poll, + wait_label="Checking", + estimated_duration=None, + as_binary=False, + final_label_on_success=None, + monitor_progress=False, + ) + if not isinstance(resp_json, dict): + raise Exception("Polling endpoint returned non-JSON response.") + except ProcessingInterrupted: + if cancel_endpoint: + with contextlib.suppress(Exception): + await sync_op_raw( + cls, + cancel_endpoint, + timeout=cancel_timeout, + max_retries=0, + wait_label="Cancelling task", + estimated_duration=None, + as_binary=False, + final_label_on_success=None, + monitor_progress=False, + ) + raise + + try: + status = _normalize_status_value(status_extractor(resp_json)) + except Exception as e: + logging.error("Status extraction failed: %s", e) + status = None + + if price_extractor: + new_price = price_extractor(resp_json) + if new_price is not None: + state.price = new_price + + if progress_extractor: + new_progress = progress_extractor(resp_json) + if new_progress is not None and last_progress != new_progress: + progress_bar.update_absolute(new_progress, total=100) + last_progress = new_progress + + now_ts = time.monotonic() + is_queued = status in queued_states + + if is_queued: + if state.active_since is not None: # If we just moved from active -> queued, close the active interval + state.base_processing_elapsed += now_ts - state.active_since + state.active_since = None + else: + if state.active_since is None: # If we just moved from queued -> active, open a new active interval + state.active_since = now_ts + + state.is_queued = is_queued + state.status_label = status or ("Queued" if is_queued else "Processing") + if status in completed_states: + if state.active_since is not None: + state.base_processing_elapsed += now_ts - state.active_since + state.active_since = None + stop_ticker.set() + with contextlib.suppress(Exception): + await ticker_task + + if progress_bar and last_progress != 100: + progress_bar.update_absolute(100, total=100) + + _display_time_progress( + cls, + status=status if status else "Completed", + elapsed_seconds=int(now_ts - started), + estimated_total=estimated_duration, + price=state.price, + is_queued=False, + processing_elapsed_seconds=int(state.base_processing_elapsed), + ) + return resp_json + + if status in failed_states: + msg = f"Task failed: {json.dumps(resp_json)}" + logging.error(msg) + raise Exception(msg) + + try: + await sleep_with_interrupt(poll_interval, cls, None, None, None) + except ProcessingInterrupted: + if cancel_endpoint: + with contextlib.suppress(Exception): + await sync_op_raw( + cls, + cancel_endpoint, + timeout=cancel_timeout, + max_retries=0, + wait_label="Cancelling task", + estimated_duration=None, + as_binary=False, + final_label_on_success=None, + monitor_progress=False, + ) + raise + if not is_queued: + consumed_attempts += 1 + + raise Exception( + f"Polling timed out after {max_poll_attempts} non-queued attempts " + f"(~{int(max_poll_attempts * poll_interval)}s of active polling)." + ) + except ProcessingInterrupted: + raise + except (LocalNetworkError, ApiServerError): + raise + except Exception as e: + raise Exception(f"Polling aborted due to error: {e}") from e + finally: + stop_ticker.set() + with contextlib.suppress(Exception): + await ticker_task + + +def _display_text( + node_cls: type[IO.ComfyNode], + text: Optional[str], + *, + status: Optional[Union[str, int]] = None, + price: Optional[float] = None, +) -> None: + display_lines: list[str] = [] + if status: + display_lines.append(f"Status: {status.capitalize() if isinstance(status, str) else status}") + if price is not None: + display_lines.append(f"Price: ${float(price):,.4f}") + if text is not None: + display_lines.append(text) + if display_lines: + PromptServer.instance.send_progress_text("\n".join(display_lines), get_node_id(node_cls)) + + +def _display_time_progress( + node_cls: type[IO.ComfyNode], + status: Optional[Union[str, int]], + elapsed_seconds: int, + estimated_total: Optional[int] = None, + *, + price: Optional[float] = None, + is_queued: Optional[bool] = None, + processing_elapsed_seconds: Optional[int] = None, +) -> None: + if estimated_total is not None and estimated_total > 0 and is_queued is False: + pe = processing_elapsed_seconds if processing_elapsed_seconds is not None else elapsed_seconds + remaining = max(0, int(estimated_total) - int(pe)) + time_line = f"Time elapsed: {int(elapsed_seconds)}s (~{remaining}s remaining)" + else: + time_line = f"Time elapsed: {int(elapsed_seconds)}s" + _display_text(node_cls, time_line, status=status, price=price) + + +async def _diagnose_connectivity() -> dict[str, bool]: + """Best-effort connectivity diagnostics to distinguish local vs. server issues.""" + results = { + "internet_accessible": False, + "api_accessible": False, + "is_local_issue": False, + "is_api_issue": False, + } + timeout = aiohttp.ClientTimeout(total=5.0) + async with aiohttp.ClientSession(timeout=timeout) as session: + try: + async with session.get("https://www.google.com") as resp: + results["internet_accessible"] = resp.status < 500 + except (ClientError, asyncio.TimeoutError, socket.gaierror): + results["is_local_issue"] = True + return results + + parsed = urlparse(default_base_url()) + health_url = f"{parsed.scheme}://{parsed.netloc}/health" + with contextlib.suppress(ClientError, asyncio.TimeoutError): + async with session.get(health_url) as resp: + results["api_accessible"] = resp.status < 500 + results["is_api_issue"] = results["internet_accessible"] and not results["api_accessible"] + return results + + +def _unpack_tuple(t: tuple) -> tuple[str, Any, str]: + """Normalize (filename, value, content_type).""" + if len(t) == 2: + return t[0], t[1], "application/octet-stream" + if len(t) == 3: + return t[0], t[1], t[2] + raise ValueError("files tuple must be (filename, file[, content_type])") + + +def _merge_params(endpoint_params: dict[str, Any], method: str, data: Optional[dict[str, Any]]) -> dict[str, Any]: + params = dict(endpoint_params or {}) + if method.upper() == "GET" and data: + for k, v in data.items(): + if v is not None: + params[k] = v + return params + + +def _friendly_http_message(status: int, body: Any) -> str: + if status == 401: + return "Unauthorized: Please login first to use this node." + if status == 402: + return "Payment Required: Please add credits to your account to use this node." + if status == 409: + return "There is a problem with your account. Please contact support@comfy.org." + if status == 429: + return "Rate Limit Exceeded: Please try again later." + try: + if isinstance(body, dict): + err = body.get("error") + if isinstance(err, dict): + msg = err.get("message") + typ = err.get("type") + if msg and typ: + return f"API Error: {msg} (Type: {typ})" + if msg: + return f"API Error: {msg}" + return f"API Error: {json.dumps(body)}" + else: + txt = str(body) + if len(txt) <= 200: + return f"API Error (raw): {txt}" + return f"API Error (status {status})" + except Exception: + return f"HTTP {status}: Unknown error" + + +def _generate_operation_id(method: str, path: str, attempt: int) -> str: + slug = path.strip("/").replace("/", "_") or "op" + return f"{method}_{slug}_try{attempt}_{uuid.uuid4().hex[:8]}" + + +def _snapshot_request_body_for_logging( + content_type: str, + method: str, + data: Optional[dict[str, Any]], + files: Optional[Union[dict[str, Any], list[tuple[str, Any]]]], +) -> Optional[Union[dict[str, Any], str]]: + if method.upper() == "GET": + return None + if content_type == "multipart/form-data": + form_fields = sorted([k for k, v in (data or {}).items() if v is not None]) + file_fields: list[dict[str, str]] = [] + if files: + file_iter = files if isinstance(files, list) else list(files.items()) + for field_name, file_obj in file_iter: + if file_obj is None: + continue + if isinstance(file_obj, tuple): + filename = file_obj[0] + else: + filename = getattr(file_obj, "name", field_name) + file_fields.append({"field": field_name, "filename": str(filename or "")}) + return {"_multipart": True, "form_fields": form_fields, "file_fields": file_fields} + if content_type == "application/x-www-form-urlencoded": + return data or {} + return data or {} + + +async def _request_base(cfg: _RequestConfig, expect_binary: bool): + """Core request with retries, per-second interruption monitoring, true cancellation, and friendly errors.""" + url = cfg.endpoint.path + parsed_url = urlparse(url) + if not parsed_url.scheme and not parsed_url.netloc: # is URL relative? + url = urljoin(default_base_url().rstrip("/") + "/", url.lstrip("/")) + + method = cfg.endpoint.method + params = _merge_params(cfg.endpoint.query_params, method, cfg.data if method == "GET" else None) + + async def _monitor(stop_evt: asyncio.Event, start_ts: float): + """Every second: update elapsed time and signal interruption.""" + try: + while not stop_evt.is_set(): + if is_processing_interrupted(): + return + if cfg.monitor_progress: + _display_time_progress( + cfg.node_cls, cfg.wait_label, int(time.monotonic() - start_ts), cfg.estimated_total + ) + await asyncio.sleep(1.0) + except asyncio.CancelledError: + return # normal shutdown + + start_time = cfg.progress_origin_ts if cfg.progress_origin_ts is not None else time.monotonic() + attempt = 0 + delay = cfg.retry_delay + operation_succeeded: bool = False + final_elapsed_seconds: Optional[int] = None + while True: + attempt += 1 + stop_event = asyncio.Event() + monitor_task: Optional[asyncio.Task] = None + sess: Optional[aiohttp.ClientSession] = None + + operation_id = _generate_operation_id(method, cfg.endpoint.path, attempt) + logging.debug("[DEBUG] HTTP %s %s (attempt %d)", method, url, attempt) + + payload_headers = {"Accept": "*/*"} + if not parsed_url.scheme and not parsed_url.netloc: # is URL relative? + payload_headers.update(get_auth_header(cfg.node_cls)) + if cfg.endpoint.headers: + payload_headers.update(cfg.endpoint.headers) + + payload_kw: dict[str, Any] = {"headers": payload_headers} + if method == "GET": + payload_headers.pop("Content-Type", None) + request_body_log = _snapshot_request_body_for_logging(cfg.content_type, method, cfg.data, cfg.files) + try: + if cfg.monitor_progress: + monitor_task = asyncio.create_task(_monitor(stop_event, start_time)) + + timeout = aiohttp.ClientTimeout(total=cfg.timeout) + sess = aiohttp.ClientSession(timeout=timeout) + + if cfg.content_type == "multipart/form-data" and method != "GET": + # aiohttp will set Content-Type boundary; remove any fixed Content-Type + payload_headers.pop("Content-Type", None) + if cfg.multipart_parser and cfg.data: + form = cfg.multipart_parser(cfg.data) + if not isinstance(form, aiohttp.FormData): + raise ValueError("multipart_parser must return aiohttp.FormData") + else: + form = aiohttp.FormData(default_to_multipart=True) + if cfg.data: + for k, v in cfg.data.items(): + if v is None: + continue + form.add_field(k, str(v) if not isinstance(v, (bytes, bytearray)) else v) + if cfg.files: + file_iter = cfg.files if isinstance(cfg.files, list) else cfg.files.items() + for field_name, file_obj in file_iter: + if file_obj is None: + continue + if isinstance(file_obj, tuple): + filename, file_value, content_type = _unpack_tuple(file_obj) + else: + filename = getattr(file_obj, "name", field_name) + file_value = file_obj + content_type = "application/octet-stream" + # Attempt to rewind BytesIO for retries + if isinstance(file_value, BytesIO): + with contextlib.suppress(Exception): + file_value.seek(0) + form.add_field(field_name, file_value, filename=filename, content_type=content_type) + payload_kw["data"] = form + elif cfg.content_type == "application/x-www-form-urlencoded" and method != "GET": + payload_headers["Content-Type"] = "application/x-www-form-urlencoded" + payload_kw["data"] = cfg.data or {} + elif method != "GET": + payload_headers["Content-Type"] = "application/json" + payload_kw["json"] = cfg.data or {} + + try: + request_logger.log_request_response( + operation_id=operation_id, + request_method=method, + request_url=url, + request_headers=dict(payload_headers) if payload_headers else None, + request_params=dict(params) if params else None, + request_data=request_body_log, + ) + except Exception as _log_e: + logging.debug("[DEBUG] request logging failed: %s", _log_e) + + req_coro = sess.request(method, url, params=params, **payload_kw) + req_task = asyncio.create_task(req_coro) + + # Race: request vs. monitor (interruption) + tasks = {req_task} + if monitor_task: + tasks.add(monitor_task) + done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED) + + if monitor_task and monitor_task in done: + # Interrupted – cancel the request and abort + if req_task in pending: + req_task.cancel() + raise ProcessingInterrupted("Task cancelled") + + # Otherwise, request finished + resp = await req_task + async with resp: + if resp.status >= 400: + try: + body = await resp.json() + except (ContentTypeError, json.JSONDecodeError): + body = await resp.text() + if resp.status in _RETRY_STATUS and attempt <= cfg.max_retries: + logging.warning( + "HTTP %s %s -> %s. Retrying in %.2fs (retry %d of %d).", + method, + url, + resp.status, + delay, + attempt, + cfg.max_retries, + ) + try: + request_logger.log_request_response( + operation_id=operation_id, + request_method=method, + request_url=url, + response_status_code=resp.status, + response_headers=dict(resp.headers), + response_content=body, + error_message=_friendly_http_message(resp.status, body), + ) + except Exception as _log_e: + logging.debug("[DEBUG] response logging failed: %s", _log_e) + + await sleep_with_interrupt( + delay, + cfg.node_cls, + cfg.wait_label if cfg.monitor_progress else None, + start_time if cfg.monitor_progress else None, + cfg.estimated_total, + display_callback=_display_time_progress if cfg.monitor_progress else None, + ) + delay *= cfg.retry_backoff + continue + msg = _friendly_http_message(resp.status, body) + try: + request_logger.log_request_response( + operation_id=operation_id, + request_method=method, + request_url=url, + response_status_code=resp.status, + response_headers=dict(resp.headers), + response_content=body, + error_message=msg, + ) + except Exception as _log_e: + logging.debug("[DEBUG] response logging failed: %s", _log_e) + raise Exception(msg) + + if expect_binary: + buff = bytearray() + last_tick = time.monotonic() + async for chunk in resp.content.iter_chunked(64 * 1024): + buff.extend(chunk) + now = time.monotonic() + if now - last_tick >= 1.0: + last_tick = now + if is_processing_interrupted(): + raise ProcessingInterrupted("Task cancelled") + if cfg.monitor_progress: + _display_time_progress( + cfg.node_cls, cfg.wait_label, int(now - start_time), cfg.estimated_total + ) + bytes_payload = bytes(buff) + operation_succeeded = True + final_elapsed_seconds = int(time.monotonic() - start_time) + try: + request_logger.log_request_response( + operation_id=operation_id, + request_method=method, + request_url=url, + response_status_code=resp.status, + response_headers=dict(resp.headers), + response_content=bytes_payload, + ) + except Exception as _log_e: + logging.debug("[DEBUG] response logging failed: %s", _log_e) + return bytes_payload + else: + try: + payload = await resp.json() + response_content_to_log: Any = payload + except (ContentTypeError, json.JSONDecodeError): + text = await resp.text() + try: + payload = json.loads(text) if text else {} + except json.JSONDecodeError: + payload = {"_raw": text} + response_content_to_log = payload if isinstance(payload, dict) else text + operation_succeeded = True + final_elapsed_seconds = int(time.monotonic() - start_time) + try: + request_logger.log_request_response( + operation_id=operation_id, + request_method=method, + request_url=url, + response_status_code=resp.status, + response_headers=dict(resp.headers), + response_content=response_content_to_log, + ) + except Exception as _log_e: + logging.debug("[DEBUG] response logging failed: %s", _log_e) + return payload + + except ProcessingInterrupted: + logging.debug("Polling was interrupted by user") + raise + except (ClientError, asyncio.TimeoutError, socket.gaierror) as e: + if attempt <= cfg.max_retries: + logging.warning( + "Connection error calling %s %s. Retrying in %.2fs (%d/%d): %s", + method, + url, + delay, + attempt, + cfg.max_retries, + str(e), + ) + try: + request_logger.log_request_response( + operation_id=operation_id, + request_method=method, + request_url=url, + request_headers=dict(payload_headers) if payload_headers else None, + request_params=dict(params) if params else None, + request_data=request_body_log, + error_message=f"{type(e).__name__}: {str(e)} (will retry)", + ) + except Exception as _log_e: + logging.debug("[DEBUG] request error logging failed: %s", _log_e) + await sleep_with_interrupt( + delay, + cfg.node_cls, + cfg.wait_label if cfg.monitor_progress else None, + start_time if cfg.monitor_progress else None, + cfg.estimated_total, + display_callback=_display_time_progress if cfg.monitor_progress else None, + ) + delay *= cfg.retry_backoff + continue + diag = await _diagnose_connectivity() + if diag.get("is_local_issue"): + try: + request_logger.log_request_response( + operation_id=operation_id, + request_method=method, + request_url=url, + request_headers=dict(payload_headers) if payload_headers else None, + request_params=dict(params) if params else None, + request_data=request_body_log, + error_message=f"LocalNetworkError: {str(e)}", + ) + except Exception as _log_e: + logging.debug("[DEBUG] final error logging failed: %s", _log_e) + raise LocalNetworkError( + "Unable to connect to the API server due to local network issues. " + "Please check your internet connection and try again." + ) from e + try: + request_logger.log_request_response( + operation_id=operation_id, + request_method=method, + request_url=url, + request_headers=dict(payload_headers) if payload_headers else None, + request_params=dict(params) if params else None, + request_data=request_body_log, + error_message=f"ApiServerError: {str(e)}", + ) + except Exception as _log_e: + logging.debug("[DEBUG] final error logging failed: %s", _log_e) + raise ApiServerError( + f"The API server at {default_base_url()} is currently unreachable. " + f"The service may be experiencing issues." + ) from e + finally: + stop_event.set() + if monitor_task: + monitor_task.cancel() + with contextlib.suppress(Exception): + await monitor_task + if sess: + with contextlib.suppress(Exception): + await sess.close() + if operation_succeeded and cfg.monitor_progress and cfg.final_label_on_success: + _display_time_progress( + cfg.node_cls, + status=cfg.final_label_on_success, + elapsed_seconds=( + final_elapsed_seconds + if final_elapsed_seconds is not None + else int(time.monotonic() - start_time) + ), + estimated_total=cfg.estimated_total, + price=None, + is_queued=False, + processing_elapsed_seconds=final_elapsed_seconds, + ) + + +def _validate_or_raise(response_model: Type[M], payload: Any) -> M: + try: + return response_model.model_validate(payload) + except Exception as e: + logging.error( + "Response validation failed for %s: %s", + getattr(response_model, "__name__", response_model), + e, + ) + raise Exception( + f"Response validation failed for {getattr(response_model, '__name__', response_model)}: {e}" + ) from e + + +def _wrap_model_extractor( + response_model: Type[M], + extractor: Optional[Callable[[M], Any]], +) -> Optional[Callable[[dict[str, Any]], Any]]: + """Wrap a typed extractor so it can be used by the dict-based poller. + Validates the dict into `response_model` before invoking `extractor`. + Uses a small per-wrapper cache keyed by `id(dict)` to avoid re-validating + the same response for multiple extractors in a single poll attempt. + """ + if extractor is None: + return None + _cache: dict[int, M] = {} + + def _wrapped(d: dict[str, Any]) -> Any: + try: + key = id(d) + model = _cache.get(key) + if model is None: + model = response_model.model_validate(d) + _cache[key] = model + return extractor(model) + except Exception as e: + logging.error("Extractor failed (typed -> dict wrapper): %s", e) + raise + + return _wrapped + + +def _normalize_statuses(values: Optional[Iterable[Union[str, int]]]) -> set[Union[str, int]]: + if not values: + return set() + out: set[Union[str, int]] = set() + for v in values: + nv = _normalize_status_value(v) + if nv is not None: + out.add(nv) + return out + + +def _normalize_status_value(val: Union[str, int, None]) -> Union[str, int, None]: + if isinstance(val, str): + return val.strip().lower() + return val diff --git a/comfy_api_nodes/util/common_exceptions.py b/comfy_api_nodes/util/common_exceptions.py new file mode 100644 index 000000000..0606a4407 --- /dev/null +++ b/comfy_api_nodes/util/common_exceptions.py @@ -0,0 +1,14 @@ +class NetworkError(Exception): + """Base exception for network-related errors with diagnostic information.""" + + +class LocalNetworkError(NetworkError): + """Exception raised when local network connectivity issues are detected.""" + + +class ApiServerError(NetworkError): + """Exception raised when the API server is unreachable but internet is working.""" + + +class ProcessingInterrupted(Exception): + """Operation was interrupted by user/runtime via processing_interrupted().""" diff --git a/comfy_api_nodes/util/conversions.py b/comfy_api_nodes/util/conversions.py new file mode 100644 index 000000000..10cd1051b --- /dev/null +++ b/comfy_api_nodes/util/conversions.py @@ -0,0 +1,407 @@ +import base64 +import logging +import math +import uuid +from io import BytesIO +from typing import Optional + +import av +import numpy as np +import torch +from PIL import Image + +from comfy.utils import common_upscale +from comfy_api.latest import Input, InputImpl + +from ._helpers import mimetype_to_extension + + +def bytesio_to_image_tensor(image_bytesio: BytesIO, mode: str = "RGBA") -> torch.Tensor: + """Converts image data from BytesIO to a torch.Tensor. + + Args: + image_bytesio: BytesIO object containing the image data. + mode: The PIL mode to convert the image to (e.g., "RGB", "RGBA"). + + Returns: + A torch.Tensor representing the image (1, H, W, C). + + Raises: + PIL.UnidentifiedImageError: If the image data cannot be identified. + ValueError: If the specified mode is invalid. + """ + image = Image.open(image_bytesio) + image = image.convert(mode) + image_array = np.array(image).astype(np.float32) / 255.0 + return torch.from_numpy(image_array).unsqueeze(0) + + +def image_tensor_pair_to_batch(image1: torch.Tensor, image2: torch.Tensor) -> torch.Tensor: + """ + Converts a pair of image tensors to a batch tensor. + If the images are not the same size, the smaller image is resized to + match the larger image. + """ + if image1.shape[1:] != image2.shape[1:]: + image2 = common_upscale( + image2.movedim(-1, 1), + image1.shape[2], + image1.shape[1], + "bilinear", + "center", + ).movedim(1, -1) + return torch.cat((image1, image2), dim=0) + + +def tensor_to_bytesio( + image: torch.Tensor, + name: Optional[str] = None, + total_pixels: int = 2048 * 2048, + mime_type: str = "image/png", +) -> BytesIO: + """Converts a torch.Tensor image to a named BytesIO object. + + Args: + image: Input torch.Tensor image. + name: Optional filename for the BytesIO object. + total_pixels: Maximum total pixels for potential downscaling. + mime_type: Target image MIME type (e.g., 'image/png', 'image/jpeg', 'image/webp', 'video/mp4'). + + Returns: + Named BytesIO object containing the image data, with pointer set to the start of buffer. + """ + if not mime_type: + mime_type = "image/png" + + pil_image = tensor_to_pil(image, total_pixels=total_pixels) + img_binary = pil_to_bytesio(pil_image, mime_type=mime_type) + img_binary.name = f"{name if name else uuid.uuid4()}.{mimetype_to_extension(mime_type)}" + return img_binary + + +def tensor_to_pil(image: torch.Tensor, total_pixels: int = 2048 * 2048) -> Image.Image: + """Converts a single torch.Tensor image [H, W, C] to a PIL Image, optionally downscaling.""" + if len(image.shape) > 3: + image = image[0] + # TODO: remove alpha if not allowed and present + input_tensor = image.cpu() + input_tensor = downscale_image_tensor(input_tensor.unsqueeze(0), total_pixels=total_pixels).squeeze() + image_np = (input_tensor.numpy() * 255).astype(np.uint8) + img = Image.fromarray(image_np) + return img + + +def tensor_to_base64_string( + image_tensor: torch.Tensor, + total_pixels: int = 2048 * 2048, + mime_type: str = "image/png", +) -> str: + """Convert [B, H, W, C] or [H, W, C] tensor to a base64 string. + + Args: + image_tensor: Input torch.Tensor image. + total_pixels: Maximum total pixels for potential downscaling. + mime_type: Target image MIME type (e.g., 'image/png', 'image/jpeg', 'image/webp', 'video/mp4'). + + Returns: + Base64 encoded string of the image. + """ + pil_image = tensor_to_pil(image_tensor, total_pixels=total_pixels) + img_byte_arr = pil_to_bytesio(pil_image, mime_type=mime_type) + img_bytes = img_byte_arr.getvalue() + # Encode bytes to base64 string + base64_encoded_string = base64.b64encode(img_bytes).decode("utf-8") + return base64_encoded_string + + +def pil_to_bytesio(img: Image.Image, mime_type: str = "image/png") -> BytesIO: + """Converts a PIL Image to a BytesIO object.""" + if not mime_type: + mime_type = "image/png" + + img_byte_arr = BytesIO() + # Derive PIL format from MIME type (e.g., 'image/png' -> 'PNG') + pil_format = mime_type.split("/")[-1].upper() + if pil_format == "JPG": + pil_format = "JPEG" + img.save(img_byte_arr, format=pil_format) + img_byte_arr.seek(0) + return img_byte_arr + + +def downscale_image_tensor(image, total_pixels=1536 * 1024) -> torch.Tensor: + """Downscale input image tensor to roughly the specified total pixels.""" + samples = image.movedim(-1, 1) + total = int(total_pixels) + scale_by = math.sqrt(total / (samples.shape[3] * samples.shape[2])) + if scale_by >= 1: + return image + width = round(samples.shape[3] * scale_by) + height = round(samples.shape[2] * scale_by) + + s = common_upscale(samples, width, height, "lanczos", "disabled") + s = s.movedim(1, -1) + return s + + +def tensor_to_data_uri( + image_tensor: torch.Tensor, + total_pixels: int = 2048 * 2048, + mime_type: str = "image/png", +) -> str: + """Converts a tensor image to a Data URI string. + + Args: + image_tensor: Input torch.Tensor image. + total_pixels: Maximum total pixels for potential downscaling. + mime_type: Target image MIME type (e.g., 'image/png', 'image/jpeg', 'image/webp'). + + Returns: + Data URI string (e.g., 'data:image/png;base64,...'). + """ + base64_string = tensor_to_base64_string(image_tensor, total_pixels, mime_type) + return f"data:{mime_type};base64,{base64_string}" + + +def audio_to_base64_string(audio: Input.Audio, container_format: str = "mp4", codec_name: str = "aac") -> str: + """Converts an audio input to a base64 string.""" + sample_rate: int = audio["sample_rate"] + waveform: torch.Tensor = audio["waveform"] + audio_data_np = audio_tensor_to_contiguous_ndarray(waveform) + audio_bytes_io = audio_ndarray_to_bytesio(audio_data_np, sample_rate, container_format, codec_name) + audio_bytes = audio_bytes_io.getvalue() + return base64.b64encode(audio_bytes).decode("utf-8") + + +def audio_ndarray_to_bytesio( + audio_data_np: np.ndarray, + sample_rate: int, + container_format: str = "mp4", + codec_name: str = "aac", +) -> BytesIO: + """ + Encodes a numpy array of audio data into a BytesIO object. + """ + audio_bytes_io = BytesIO() + with av.open(audio_bytes_io, mode="w", format=container_format) as output_container: + audio_stream = output_container.add_stream(codec_name, rate=sample_rate) + frame = av.AudioFrame.from_ndarray( + audio_data_np, + format="fltp", + layout="stereo" if audio_data_np.shape[0] > 1 else "mono", + ) + frame.sample_rate = sample_rate + frame.pts = 0 + + for packet in audio_stream.encode(frame): + output_container.mux(packet) + + # Flush stream + for packet in audio_stream.encode(None): + output_container.mux(packet) + + audio_bytes_io.seek(0) + return audio_bytes_io + + +def audio_tensor_to_contiguous_ndarray(waveform: torch.Tensor) -> np.ndarray: + """ + Prepares audio waveform for av library by converting to a contiguous numpy array. + + Args: + waveform: a tensor of shape (1, channels, samples) derived from a Comfy `AUDIO` type. + + Returns: + Contiguous numpy array of the audio waveform. If the audio was batched, + the first item is taken. + """ + if waveform.ndim != 3 or waveform.shape[0] != 1: + raise ValueError("Expected waveform tensor shape (1, channels, samples)") + + # If batch is > 1, take first item + if waveform.shape[0] > 1: + waveform = waveform[0] + + # Prepare for av: remove batch dim, move to CPU, make contiguous, convert to numpy array + audio_data_np = waveform.squeeze(0).cpu().contiguous().numpy() + if audio_data_np.dtype != np.float32: + audio_data_np = audio_data_np.astype(np.float32) + + return audio_data_np + + +def audio_input_to_mp3(audio: Input.Audio) -> BytesIO: + waveform = audio["waveform"].cpu() + + output_buffer = BytesIO() + output_container = av.open(output_buffer, mode="w", format="mp3") + + out_stream = output_container.add_stream("libmp3lame", rate=audio["sample_rate"]) + out_stream.bit_rate = 320000 + + frame = av.AudioFrame.from_ndarray( + waveform.movedim(0, 1).reshape(1, -1).float().numpy(), + format="flt", + layout="mono" if waveform.shape[0] == 1 else "stereo", + ) + frame.sample_rate = audio["sample_rate"] + frame.pts = 0 + output_container.mux(out_stream.encode(frame)) + output_container.mux(out_stream.encode(None)) + output_container.close() + output_buffer.seek(0) + return output_buffer + + +def trim_video(video: Input.Video, duration_sec: float) -> Input.Video: + """ + Returns a new VideoInput object trimmed from the beginning to the specified duration, + using av to avoid loading entire video into memory. + + Args: + video: Input video to trim + duration_sec: Duration in seconds to keep from the beginning + + Returns: + VideoFromFile object that owns the output buffer + """ + output_buffer = BytesIO() + input_container = None + output_container = None + + try: + # Get the stream source - this avoids loading entire video into memory + # when the source is already a file path + input_source = video.get_stream_source() + + # Open containers + input_container = av.open(input_source, mode="r") + output_container = av.open(output_buffer, mode="w", format="mp4") + + # Set up output streams for re-encoding + video_stream = None + audio_stream = None + + for stream in input_container.streams: + logging.info("Found stream: type=%s, class=%s", stream.type, type(stream)) + if isinstance(stream, av.VideoStream): + # Create output video stream with same parameters + video_stream = output_container.add_stream("h264", rate=stream.average_rate) + video_stream.width = stream.width + video_stream.height = stream.height + video_stream.pix_fmt = "yuv420p" + logging.info("Added video stream: %sx%s @ %sfps", stream.width, stream.height, stream.average_rate) + elif isinstance(stream, av.AudioStream): + # Create output audio stream with same parameters + audio_stream = output_container.add_stream("aac", rate=stream.sample_rate) + audio_stream.sample_rate = stream.sample_rate + audio_stream.layout = stream.layout + logging.info("Added audio stream: %sHz, %s channels", stream.sample_rate, stream.channels) + + # Calculate target frame count that's divisible by 16 + fps = input_container.streams.video[0].average_rate + estimated_frames = int(duration_sec * fps) + target_frames = (estimated_frames // 16) * 16 # Round down to nearest multiple of 16 + + if target_frames == 0: + raise ValueError("Video too short: need at least 16 frames for Moonvalley") + + frame_count = 0 + audio_frame_count = 0 + + # Decode and re-encode video frames + if video_stream: + for frame in input_container.decode(video=0): + if frame_count >= target_frames: + break + + # Re-encode frame + for packet in video_stream.encode(frame): + output_container.mux(packet) + frame_count += 1 + + # Flush encoder + for packet in video_stream.encode(): + output_container.mux(packet) + + logging.info("Encoded %s video frames (target: %s)", frame_count, target_frames) + + # Decode and re-encode audio frames + if audio_stream: + input_container.seek(0) # Reset to beginning for audio + for frame in input_container.decode(audio=0): + if frame.time >= duration_sec: + break + + # Re-encode frame + for packet in audio_stream.encode(frame): + output_container.mux(packet) + audio_frame_count += 1 + + # Flush encoder + for packet in audio_stream.encode(): + output_container.mux(packet) + + logging.info("Encoded %s audio frames", audio_frame_count) + + # Close containers + output_container.close() + input_container.close() + + # Return as VideoFromFile using the buffer + output_buffer.seek(0) + return InputImpl.VideoFromFile(output_buffer) + + except Exception as e: + # Clean up on error + if input_container is not None: + input_container.close() + if output_container is not None: + output_container.close() + raise RuntimeError(f"Failed to trim video: {str(e)}") from e + + +def _f32_pcm(wav: torch.Tensor) -> torch.Tensor: + """Convert audio to float 32 bits PCM format. Copy-paste from nodes_audio.py file.""" + if wav.dtype.is_floating_point: + return wav + elif wav.dtype == torch.int16: + return wav.float() / (2**15) + elif wav.dtype == torch.int32: + return wav.float() / (2**31) + raise ValueError(f"Unsupported wav dtype: {wav.dtype}") + + +def audio_bytes_to_audio_input(audio_bytes: bytes) -> dict: + """ + Decode any common audio container from bytes using PyAV and return + a Comfy AUDIO dict: {"waveform": [1, C, T] float32, "sample_rate": int}. + """ + with av.open(BytesIO(audio_bytes)) as af: + if not af.streams.audio: + raise ValueError("No audio stream found in response.") + stream = af.streams.audio[0] + + in_sr = int(stream.codec_context.sample_rate) + out_sr = in_sr + + frames: list[torch.Tensor] = [] + n_channels = stream.channels or 1 + + for frame in af.decode(streams=stream.index): + arr = frame.to_ndarray() # shape can be [C, T] or [T, C] or [T] + buf = torch.from_numpy(arr) + if buf.ndim == 1: + buf = buf.unsqueeze(0) # [T] -> [1, T] + elif buf.shape[0] != n_channels and buf.shape[-1] == n_channels: + buf = buf.transpose(0, 1).contiguous() # [T, C] -> [C, T] + elif buf.shape[0] != n_channels: + buf = buf.reshape(-1, n_channels).t().contiguous() # fallback to [C, T] + frames.append(buf) + + if not frames: + raise ValueError("Decoded zero audio frames.") + + wav = torch.cat(frames, dim=1) # [C, T] + wav = _f32_pcm(wav) + return {"waveform": wav.unsqueeze(0).contiguous(), "sample_rate": out_sr} diff --git a/comfy_api_nodes/util/download_helpers.py b/comfy_api_nodes/util/download_helpers.py new file mode 100644 index 000000000..055e690de --- /dev/null +++ b/comfy_api_nodes/util/download_helpers.py @@ -0,0 +1,249 @@ +import asyncio +import contextlib +import uuid +from io import BytesIO +from pathlib import Path +from typing import IO, Optional, Union +from urllib.parse import urljoin, urlparse + +import aiohttp +import torch +from aiohttp.client_exceptions import ClientError, ContentTypeError + +from comfy_api.input_impl import VideoFromFile +from comfy_api.latest import IO as COMFY_IO +from comfy_api_nodes.apis import request_logger + +from ._helpers import ( + default_base_url, + get_auth_header, + is_processing_interrupted, + sleep_with_interrupt, +) +from .client import _diagnose_connectivity +from .common_exceptions import ApiServerError, LocalNetworkError, ProcessingInterrupted +from .conversions import bytesio_to_image_tensor + +_RETRY_STATUS = {408, 429, 500, 502, 503, 504} + + +async def download_url_to_bytesio( + url: str, + dest: Optional[Union[BytesIO, IO[bytes], str, Path]], + *, + timeout: Optional[float] = None, + max_retries: int = 3, + retry_delay: float = 1.0, + retry_backoff: float = 2.0, + cls: type[COMFY_IO.ComfyNode] = None, +) -> None: + """Stream-download a URL to `dest`. + + `dest` must be one of: + - a BytesIO (rewound to 0 after write), + - a file-like object opened in binary write mode (must implement .write()), + - a filesystem path (str | pathlib.Path), which will be opened with 'wb'. + + If `url` starts with `/proxy/`, `cls` must be provided so the URL can be expanded + to an absolute URL and authentication headers can be applied. + + Raises: + ProcessingInterrupted, LocalNetworkError, ApiServerError, Exception (HTTP and other errors) + """ + if not isinstance(dest, (str, Path)) and not hasattr(dest, "write"): + raise ValueError("dest must be a path (str|Path) or a binary-writable object providing .write().") + + attempt = 0 + delay = retry_delay + headers: dict[str, str] = {} + + parsed_url = urlparse(url) + if not parsed_url.scheme and not parsed_url.netloc: # is URL relative? + if cls is None: + raise ValueError("For relative 'cloud' paths, the `cls` parameter is required.") + url = urljoin(default_base_url().rstrip("/") + "/", url.lstrip("/")) + headers = get_auth_header(cls) + + while True: + attempt += 1 + op_id = _generate_operation_id("GET", url, attempt) + timeout_cfg = aiohttp.ClientTimeout(total=timeout) + + is_path_sink = isinstance(dest, (str, Path)) + fhandle = None + session: Optional[aiohttp.ClientSession] = None + stop_evt: Optional[asyncio.Event] = None + monitor_task: Optional[asyncio.Task] = None + req_task: Optional[asyncio.Task] = None + + try: + with contextlib.suppress(Exception): + request_logger.log_request_response(operation_id=op_id, request_method="GET", request_url=url) + + session = aiohttp.ClientSession(timeout=timeout_cfg) + stop_evt = asyncio.Event() + + async def _monitor(): + try: + while not stop_evt.is_set(): + if is_processing_interrupted(): + return + await asyncio.sleep(1.0) + except asyncio.CancelledError: + return + + monitor_task = asyncio.create_task(_monitor()) + + req_task = asyncio.create_task(session.get(url, headers=headers)) + done, pending = await asyncio.wait({req_task, monitor_task}, return_when=asyncio.FIRST_COMPLETED) + + if monitor_task in done and req_task in pending: + req_task.cancel() + with contextlib.suppress(Exception): + await req_task + raise ProcessingInterrupted("Task cancelled") + + try: + resp = await req_task + except asyncio.CancelledError: + raise ProcessingInterrupted("Task cancelled") from None + + async with resp: + if resp.status >= 400: + with contextlib.suppress(Exception): + try: + body = await resp.json() + except (ContentTypeError, ValueError): + text = await resp.text() + body = text if len(text) <= 4096 else f"[text {len(text)} bytes]" + request_logger.log_request_response( + operation_id=op_id, + request_method="GET", + request_url=url, + response_status_code=resp.status, + response_headers=dict(resp.headers), + response_content=body, + error_message=f"HTTP {resp.status}", + ) + + if resp.status in _RETRY_STATUS and attempt <= max_retries: + await sleep_with_interrupt(delay, cls, None, None, None) + delay *= retry_backoff + continue + raise Exception(f"Failed to download (HTTP {resp.status}).") + + if is_path_sink: + p = Path(str(dest)) + with contextlib.suppress(Exception): + p.parent.mkdir(parents=True, exist_ok=True) + fhandle = open(p, "wb") + sink = fhandle + else: + sink = dest # BytesIO or file-like + + written = 0 + while True: + try: + chunk = await asyncio.wait_for(resp.content.read(1024 * 1024), timeout=1.0) + except asyncio.TimeoutError: + chunk = b"" + except asyncio.CancelledError: + raise ProcessingInterrupted("Task cancelled") from None + + if is_processing_interrupted(): + raise ProcessingInterrupted("Task cancelled") + + if not chunk: + if resp.content.at_eof(): + break + continue + + sink.write(chunk) + written += len(chunk) + + if isinstance(dest, BytesIO): + with contextlib.suppress(Exception): + dest.seek(0) + + with contextlib.suppress(Exception): + request_logger.log_request_response( + operation_id=op_id, + request_method="GET", + request_url=url, + response_status_code=resp.status, + response_headers=dict(resp.headers), + response_content=f"[streamed {written} bytes to dest]", + ) + return + except asyncio.CancelledError: + raise ProcessingInterrupted("Task cancelled") from None + except (ClientError, asyncio.TimeoutError) as e: + if attempt <= max_retries: + with contextlib.suppress(Exception): + request_logger.log_request_response( + operation_id=op_id, + request_method="GET", + request_url=url, + error_message=f"{type(e).__name__}: {str(e)} (will retry)", + ) + await sleep_with_interrupt(delay, cls, None, None, None) + delay *= retry_backoff + continue + + diag = await _diagnose_connectivity() + if diag.get("is_local_issue"): + raise LocalNetworkError( + "Unable to connect to the network. Please check your internet connection and try again." + ) from e + raise ApiServerError("The remote service appears unreachable at this time.") from e + finally: + if stop_evt is not None: + stop_evt.set() + if monitor_task: + monitor_task.cancel() + with contextlib.suppress(Exception): + await monitor_task + if req_task and not req_task.done(): + req_task.cancel() + with contextlib.suppress(Exception): + await req_task + if session: + with contextlib.suppress(Exception): + await session.close() + if fhandle: + with contextlib.suppress(Exception): + fhandle.flush() + fhandle.close() + + +async def download_url_to_image_tensor( + url: str, + *, + timeout: float = None, + cls: type[COMFY_IO.ComfyNode] = None, +) -> torch.Tensor: + """Downloads an image from a URL and returns a [B, H, W, C] tensor.""" + result = BytesIO() + await download_url_to_bytesio(url, result, timeout=timeout, cls=cls) + return bytesio_to_image_tensor(result) + + +async def download_url_to_video_output( + video_url: str, + *, + timeout: float = None, + cls: type[COMFY_IO.ComfyNode] = None, +) -> VideoFromFile: + """Downloads a video from a URL and returns a `VIDEO` output.""" + result = BytesIO() + await download_url_to_bytesio(video_url, result, timeout=timeout, cls=cls) + return VideoFromFile(result) + + +def _generate_operation_id(method: str, url: str, attempt: int) -> str: + try: + parsed = urlparse(url) + slug = (parsed.path.rsplit("/", 1)[-1] or parsed.netloc or "download").strip("/").replace("/", "_") + except Exception: + slug = "download" + return f"{method}_{slug}_try{attempt}_{uuid.uuid4().hex[:8]}" diff --git a/comfy_api_nodes/util/upload_helpers.py b/comfy_api_nodes/util/upload_helpers.py new file mode 100644 index 000000000..a345d451d --- /dev/null +++ b/comfy_api_nodes/util/upload_helpers.py @@ -0,0 +1,338 @@ +import asyncio +import contextlib +import logging +import time +import uuid +from io import BytesIO +from typing import Optional, Union +from urllib.parse import urlparse + +import aiohttp +import torch +from pydantic import BaseModel, Field + +from comfy_api.latest import IO, Input +from comfy_api.util import VideoCodec, VideoContainer +from comfy_api_nodes.apis import request_logger + +from ._helpers import is_processing_interrupted, sleep_with_interrupt +from .client import ( + ApiEndpoint, + _diagnose_connectivity, + _display_time_progress, + sync_op, +) +from .common_exceptions import ApiServerError, LocalNetworkError, ProcessingInterrupted +from .conversions import ( + audio_ndarray_to_bytesio, + audio_tensor_to_contiguous_ndarray, + tensor_to_bytesio, +) + + +class UploadRequest(BaseModel): + file_name: str = Field(..., description="Filename to upload") + content_type: Optional[str] = Field( + None, + description="Mime type of the file. For example: image/png, image/jpeg, video/mp4, etc.", + ) + + +class UploadResponse(BaseModel): + download_url: str = Field(..., description="URL to GET uploaded file") + upload_url: str = Field(..., description="URL to PUT file to upload") + + +async def upload_images_to_comfyapi( + cls: type[IO.ComfyNode], + image: torch.Tensor, + *, + max_images: int = 8, + mime_type: Optional[str] = None, + wait_label: Optional[str] = "Uploading", +) -> list[str]: + """ + Uploads images to ComfyUI API and returns download URLs. + To upload multiple images, stack them in the batch dimension first. + """ + # if batch, try to upload each file if max_images is greater than 0 + download_urls: list[str] = [] + is_batch = len(image.shape) > 3 + batch_len = image.shape[0] if is_batch else 1 + + for idx in range(min(batch_len, max_images)): + tensor = image[idx] if is_batch else image + img_io = tensor_to_bytesio(tensor, mime_type=mime_type) + url = await upload_file_to_comfyapi(cls, img_io, img_io.name, mime_type, wait_label) + download_urls.append(url) + return download_urls + + +async def upload_audio_to_comfyapi( + cls: type[IO.ComfyNode], + audio: Input.Audio, + *, + container_format: str = "mp4", + codec_name: str = "aac", + mime_type: str = "audio/mp4", + filename: str = "uploaded_audio.mp4", +) -> str: + """ + Uploads a single audio input to ComfyUI API and returns its download URL. + Encodes the raw waveform into the specified format before uploading. + """ + sample_rate: int = audio["sample_rate"] + waveform: torch.Tensor = audio["waveform"] + audio_data_np = audio_tensor_to_contiguous_ndarray(waveform) + audio_bytes_io = audio_ndarray_to_bytesio(audio_data_np, sample_rate, container_format, codec_name) + return await upload_file_to_comfyapi(cls, audio_bytes_io, filename, mime_type) + + +async def upload_video_to_comfyapi( + cls: type[IO.ComfyNode], + video: Input.Video, + *, + container: VideoContainer = VideoContainer.MP4, + codec: VideoCodec = VideoCodec.H264, + max_duration: Optional[int] = None, +) -> str: + """ + Uploads a single video to ComfyUI API and returns its download URL. + Uses the specified container and codec for saving the video before upload. + """ + if max_duration is not None: + try: + actual_duration = video.get_duration() + if actual_duration > max_duration: + raise ValueError( + f"Video duration ({actual_duration:.2f}s) exceeds the maximum allowed ({max_duration}s)." + ) + except Exception as e: + logging.error("Error getting video duration: %s", str(e)) + raise ValueError(f"Could not verify video duration from source: {e}") from e + + upload_mime_type = f"video/{container.value.lower()}" + filename = f"uploaded_video.{container.value.lower()}" + + # Convert VideoInput to BytesIO using specified container/codec + video_bytes_io = BytesIO() + video.save_to(video_bytes_io, format=container, codec=codec) + video_bytes_io.seek(0) + + return await upload_file_to_comfyapi(cls, video_bytes_io, filename, upload_mime_type) + + +async def upload_file_to_comfyapi( + cls: type[IO.ComfyNode], + file_bytes_io: BytesIO, + filename: str, + upload_mime_type: Optional[str], + wait_label: Optional[str] = "Uploading", +) -> str: + """Uploads a single file to ComfyUI API and returns its download URL.""" + if upload_mime_type is None: + request_object = UploadRequest(file_name=filename) + else: + request_object = UploadRequest(file_name=filename, content_type=upload_mime_type) + create_resp = await sync_op( + cls, + endpoint=ApiEndpoint(path="/customers/storage", method="POST"), + data=request_object, + response_model=UploadResponse, + final_label_on_success=None, + monitor_progress=False, + ) + await upload_file( + cls, + create_resp.upload_url, + file_bytes_io, + content_type=upload_mime_type, + wait_label=wait_label, + ) + return create_resp.download_url + + +async def upload_file( + cls: type[IO.ComfyNode], + upload_url: str, + file: Union[BytesIO, str], + *, + content_type: Optional[str] = None, + max_retries: int = 3, + retry_delay: float = 1.0, + retry_backoff: float = 2.0, + wait_label: Optional[str] = None, +) -> None: + """ + Upload a file to a signed URL (e.g., S3 pre-signed PUT) with retries, Comfy progress display, and interruption. + + Args: + cls: Node class (provides auth context + UI progress hooks). + upload_url: Pre-signed PUT URL. + file: BytesIO or path string. + content_type: Explicit MIME type. If None, we *suppress* Content-Type. + max_retries: Maximum retry attempts. + retry_delay: Initial delay in seconds. + retry_backoff: Exponential backoff factor. + wait_label: Progress label shown in Comfy UI. + + Raises: + ProcessingInterrupted, LocalNetworkError, ApiServerError, Exception + """ + if isinstance(file, BytesIO): + with contextlib.suppress(Exception): + file.seek(0) + data = file.read() + elif isinstance(file, str): + with open(file, "rb") as f: + data = f.read() + else: + raise ValueError("file must be a BytesIO or a filesystem path string") + + headers: dict[str, str] = {} + skip_auto_headers: set[str] = set() + if content_type: + headers["Content-Type"] = content_type + else: + skip_auto_headers.add("Content-Type") # Don't let aiohttp add Content-Type, it can break the signed request + + attempt = 0 + delay = retry_delay + start_ts = time.monotonic() + op_uuid = uuid.uuid4().hex[:8] + while True: + attempt += 1 + operation_id = _generate_operation_id("PUT", upload_url, attempt, op_uuid) + timeout = aiohttp.ClientTimeout(total=None) + stop_evt = asyncio.Event() + + async def _monitor(): + try: + while not stop_evt.is_set(): + if is_processing_interrupted(): + return + if wait_label: + _display_time_progress(cls, wait_label, int(time.monotonic() - start_ts), None) + await asyncio.sleep(1.0) + except asyncio.CancelledError: + return + + monitor_task = asyncio.create_task(_monitor()) + sess: Optional[aiohttp.ClientSession] = None + try: + try: + request_logger.log_request_response( + operation_id=operation_id, + request_method="PUT", + request_url=upload_url, + request_headers=headers or None, + request_params=None, + request_data=f"[File data {len(data)} bytes]", + ) + except Exception as e: + logging.debug("[DEBUG] upload request logging failed: %s", e) + + sess = aiohttp.ClientSession(timeout=timeout) + req = sess.put(upload_url, data=data, headers=headers, skip_auto_headers=skip_auto_headers) + req_task = asyncio.create_task(req) + + done, pending = await asyncio.wait({req_task, monitor_task}, return_when=asyncio.FIRST_COMPLETED) + + if monitor_task in done and req_task in pending: + req_task.cancel() + raise ProcessingInterrupted("Upload cancelled") + + try: + resp = await req_task + except asyncio.CancelledError: + raise ProcessingInterrupted("Upload cancelled") from None + + async with resp: + if resp.status >= 400: + with contextlib.suppress(Exception): + try: + body = await resp.json() + except Exception: + body = await resp.text() + msg = f"Upload failed with status {resp.status}" + request_logger.log_request_response( + operation_id=operation_id, + request_method="PUT", + request_url=upload_url, + response_status_code=resp.status, + response_headers=dict(resp.headers), + response_content=body, + error_message=msg, + ) + if resp.status in {408, 429, 500, 502, 503, 504} and attempt <= max_retries: + await sleep_with_interrupt( + delay, + cls, + wait_label, + start_ts, + None, + display_callback=_display_time_progress if wait_label else None, + ) + delay *= retry_backoff + continue + raise Exception(f"Failed to upload (HTTP {resp.status}).") + try: + request_logger.log_request_response( + operation_id=operation_id, + request_method="PUT", + request_url=upload_url, + response_status_code=resp.status, + response_headers=dict(resp.headers), + response_content="File uploaded successfully.", + ) + except Exception as e: + logging.debug("[DEBUG] upload response logging failed: %s", e) + return + except asyncio.CancelledError: + raise ProcessingInterrupted("Task cancelled") from None + except (aiohttp.ClientError, asyncio.TimeoutError) as e: + if attempt <= max_retries: + with contextlib.suppress(Exception): + request_logger.log_request_response( + operation_id=operation_id, + request_method="PUT", + request_url=upload_url, + request_headers=headers or None, + request_data=f"[File data {len(data)} bytes]", + error_message=f"{type(e).__name__}: {str(e)} (will retry)", + ) + await sleep_with_interrupt( + delay, + cls, + wait_label, + start_ts, + None, + display_callback=_display_time_progress if wait_label else None, + ) + delay *= retry_backoff + continue + + diag = await _diagnose_connectivity() + if diag.get("is_local_issue"): + raise LocalNetworkError( + "Unable to connect to the network. Please check your internet connection and try again." + ) from e + raise ApiServerError("The API service appears unreachable at this time.") from e + finally: + stop_evt.set() + if monitor_task: + monitor_task.cancel() + with contextlib.suppress(Exception): + await monitor_task + if sess: + with contextlib.suppress(Exception): + await sess.close() + + +def _generate_operation_id(method: str, url: str, attempt: int, op_uuid: str) -> str: + try: + parsed = urlparse(url) + slug = (parsed.path.rsplit("/", 1)[-1] or parsed.netloc or "upload").strip("/").replace("/", "_") + except Exception: + slug = "upload" + return f"{method}_{slug}_{op_uuid}_try{attempt}" diff --git a/comfy_api_nodes/util/validation_utils.py b/comfy_api_nodes/util/validation_utils.py index ca913e9b3..22da05bc1 100644 --- a/comfy_api_nodes/util/validation_utils.py +++ b/comfy_api_nodes/util/validation_utils.py @@ -2,6 +2,8 @@ import logging from typing import Optional import torch + +from comfy_api.input.video_types import VideoInput from comfy_api.latest import Input @@ -28,9 +30,7 @@ def validate_image_dimensions( if max_width is not None and width > max_width: raise ValueError(f"Image width must be at most {max_width}px, got {width}px") if min_height is not None and height < min_height: - raise ValueError( - f"Image height must be at least {min_height}px, got {height}px" - ) + raise ValueError(f"Image height must be at least {min_height}px, got {height}px") if max_height is not None and height > max_height: raise ValueError(f"Image height must be at most {max_height}px, got {height}px") @@ -44,13 +44,9 @@ def validate_image_aspect_ratio( aspect_ratio = width / height if min_aspect_ratio is not None and aspect_ratio < min_aspect_ratio: - raise ValueError( - f"Image aspect ratio must be at least {min_aspect_ratio}, got {aspect_ratio}" - ) + raise ValueError(f"Image aspect ratio must be at least {min_aspect_ratio}, got {aspect_ratio}") if max_aspect_ratio is not None and aspect_ratio > max_aspect_ratio: - raise ValueError( - f"Image aspect ratio must be at most {max_aspect_ratio}, got {aspect_ratio}" - ) + raise ValueError(f"Image aspect ratio must be at most {max_aspect_ratio}, got {aspect_ratio}") def validate_image_aspect_ratio_range( @@ -58,7 +54,7 @@ def validate_image_aspect_ratio_range( min_ratio: tuple[float, float], # e.g. (1, 4) max_ratio: tuple[float, float], # e.g. (4, 1) *, - strict: bool = True, # True -> (min, max); False -> [min, max] + strict: bool = True, # True -> (min, max); False -> [min, max] ) -> float: a1, b1 = min_ratio a2, b2 = max_ratio @@ -85,7 +81,7 @@ def validate_aspect_ratio_closeness( min_rel: float, max_rel: float, *, - strict: bool = False, # True => exclusive, False => inclusive + strict: bool = False, # True => exclusive, False => inclusive ) -> None: w1, h1 = get_image_dimensions(start_img) w2, h2 = get_image_dimensions(end_img) @@ -118,9 +114,7 @@ def validate_video_dimensions( if max_width is not None and width > max_width: raise ValueError(f"Video width must be at most {max_width}px, got {width}px") if min_height is not None and height < min_height: - raise ValueError( - f"Video height must be at least {min_height}px, got {height}px" - ) + raise ValueError(f"Video height must be at least {min_height}px, got {height}px") if max_height is not None and height > max_height: raise ValueError(f"Video height must be at most {max_height}px, got {height}px") @@ -138,13 +132,9 @@ def validate_video_duration( epsilon = 0.0001 if min_duration is not None and min_duration - epsilon > duration: - raise ValueError( - f"Video duration must be at least {min_duration}s, got {duration}s" - ) + raise ValueError(f"Video duration must be at least {min_duration}s, got {duration}s") if max_duration is not None and duration > max_duration + epsilon: - raise ValueError( - f"Video duration must be at most {max_duration}s, got {duration}s" - ) + raise ValueError(f"Video duration must be at most {max_duration}s, got {duration}s") def get_number_of_images(images): @@ -165,3 +155,31 @@ def validate_audio_duration( raise ValueError(f"Audio duration must be at least {min_duration}s, got {dur + eps:.2f}s") if max_duration is not None and dur - eps > max_duration: raise ValueError(f"Audio duration must be at most {max_duration}s, got {dur - eps:.2f}s") + + +def validate_string( + string: str, + strip_whitespace=True, + field_name="prompt", + min_length=None, + max_length=None, +): + if string is None: + raise Exception(f"Field '{field_name}' cannot be empty.") + if strip_whitespace: + string = string.strip() + if min_length and len(string) < min_length: + raise Exception( + f"Field '{field_name}' cannot be shorter than {min_length} characters; was {len(string)} characters long." + ) + if max_length and len(string) > max_length: + raise Exception( + f" Field '{field_name} cannot be longer than {max_length} characters; was {len(string)} characters long." + ) + + +def validate_container_format_is_mp4(video: VideoInput) -> None: + """Validates video container format is MP4.""" + container_format = video.get_container_format() + if container_format not in ["mp4", "mov,mp4,m4a,3gp,3g2,mj2"]: + raise ValueError(f"Only MP4 container format supported. Got: {container_format}") diff --git a/pyproject.toml b/pyproject.toml index 0c6b23a25..fcc4854a5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,6 +50,8 @@ messages_control.disable = [ "too-many-branches", "too-many-locals", "too-many-arguments", + "too-many-return-statements", + "too-many-nested-blocks", "duplicate-code", "abstract-method", "superfluous-parens", From dd5af0c5871376c377b2e30f9725b67a768eea6f Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Sat, 25 Oct 2025 01:48:34 +0300 Subject: [PATCH 008/231] convert Tripo API nodes to V3 schema (#10469) --- comfy_api_nodes/apis/tripo_api.py | 15 +- comfy_api_nodes/nodes_tripo.py | 892 ++++++++++++----------- comfy_api_nodes/util/__init__.py | 2 + comfy_api_nodes/util/download_helpers.py | 12 + 4 files changed, 503 insertions(+), 418 deletions(-) diff --git a/comfy_api_nodes/apis/tripo_api.py b/comfy_api_nodes/apis/tripo_api.py index 9f43d4d09..713260e2a 100644 --- a/comfy_api_nodes/apis/tripo_api.py +++ b/comfy_api_nodes/apis/tripo_api.py @@ -1,13 +1,20 @@ from __future__ import annotations -from comfy_api_nodes.apis import ( - TripoModelVersion, - TripoTextureQuality, -) from enum import Enum from typing import Optional, List, Dict, Any, Union from pydantic import BaseModel, Field, RootModel +class TripoModelVersion(str, Enum): + v2_5_20250123 = 'v2.5-20250123' + v2_0_20240919 = 'v2.0-20240919' + v1_4_20240625 = 'v1.4-20240625' + + +class TripoTextureQuality(str, Enum): + standard = 'standard' + detailed = 'detailed' + + class TripoStyle(str, Enum): PERSON_TO_CARTOON = "person:person2cartoon" ANIMAL_VENOM = "animal:venom" diff --git a/comfy_api_nodes/nodes_tripo.py b/comfy_api_nodes/nodes_tripo.py index d08cf9007..697100ff2 100644 --- a/comfy_api_nodes/nodes_tripo.py +++ b/comfy_api_nodes/nodes_tripo.py @@ -1,46 +1,39 @@ import os -from folder_paths import get_output_directory -from comfy_api_nodes.mapper_utils import model_field_to_node_input -from comfy.comfy_types.node_typing import IO -from comfy_api_nodes.apis import ( - TripoOrientation, - TripoModelVersion, -) +from typing import Optional + +import torch +from typing_extensions import override + +from comfy_api.latest import IO, ComfyExtension from comfy_api_nodes.apis.tripo_api import ( - TripoTaskType, - TripoStyle, - TripoFileReference, + TripoAnimateRetargetRequest, + TripoAnimateRigRequest, + TripoConvertModelRequest, TripoFileEmptyReference, - TripoUrlReference, + TripoFileReference, + TripoImageToModelRequest, + TripoModelVersion, + TripoMultiviewToModelRequest, + TripoOrientation, + TripoRefineModelRequest, + TripoStyle, TripoTaskResponse, TripoTaskStatus, + TripoTaskType, TripoTextToModelRequest, - TripoImageToModelRequest, - TripoMultiviewToModelRequest, TripoTextureModelRequest, - TripoRefineModelRequest, - TripoAnimateRigRequest, - TripoAnimateRetargetRequest, - TripoConvertModelRequest, + TripoUrlReference, ) - -from comfy_api_nodes.apis.client import ( +from comfy_api_nodes.util import ( ApiEndpoint, - HttpMethod, - SynchronousOperation, - PollingOperation, - EmptyRequest, -) -from comfy_api_nodes.apinode_utils import ( + download_url_as_bytesio, + poll_op, + sync_op, upload_images_to_comfyapi, - download_url_to_bytesio, ) +from folder_paths import get_output_directory -async def upload_image_to_tripo(image, **kwargs): - urls = await upload_images_to_comfyapi(image, max_images=1, auth_kwargs=kwargs) - return TripoFileReference(TripoUrlReference(url=urls[0], type="jpeg")) - def get_model_url_from_response(response: TripoTaskResponse) -> str: if response.data is not None: for key in ["pbr_model", "model", "base_model"]: @@ -50,20 +43,18 @@ def get_model_url_from_response(response: TripoTaskResponse) -> str: async def poll_until_finished( - kwargs: dict[str, str], + node_cls: type[IO.ComfyNode], response: TripoTaskResponse, -) -> tuple[str, str]: + average_duration: Optional[int] = None, +) -> IO.NodeOutput: """Polls the Tripo API endpoint until the task reaches a terminal state, then returns the response.""" if response.code != 0: raise RuntimeError(f"Failed to generate mesh: {response.error}") task_id = response.data.task_id - response_poll = await PollingOperation( - poll_endpoint=ApiEndpoint( - path=f"/proxy/tripo/v2/openapi/task/{task_id}", - method=HttpMethod.GET, - request_model=EmptyRequest, - response_model=TripoTaskResponse, - ), + response_poll = await poll_op( + node_cls, + poll_endpoint=ApiEndpoint(path=f"/proxy/tripo/v2/openapi/task/{task_id}"), + response_model=TripoTaskResponse, completed_statuses=[TripoTaskStatus.SUCCESS], failed_statuses=[ TripoTaskStatus.FAILED, @@ -73,72 +64,84 @@ async def poll_until_finished( TripoTaskStatus.EXPIRED, ], status_extractor=lambda x: x.data.status, - auth_kwargs=kwargs, - node_id=kwargs["unique_id"], - result_url_extractor=get_model_url_from_response, progress_extractor=lambda x: x.data.progress, - ).execute() + estimated_duration=average_duration, + ) if response_poll.data.status == TripoTaskStatus.SUCCESS: url = get_model_url_from_response(response_poll) - bytesio = await download_url_to_bytesio(url) + bytesio = await download_url_as_bytesio(url) # Save the downloaded model file model_file = f"tripo_model_{task_id}.glb" with open(os.path.join(get_output_directory(), model_file), "wb") as f: f.write(bytesio.getvalue()) - return model_file, task_id + return IO.NodeOutput(model_file, task_id) raise RuntimeError(f"Failed to generate mesh: {response_poll}") -class TripoTextToModelNode: +class TripoTextToModelNode(IO.ComfyNode): """ Generates 3D models synchronously based on a text prompt using Tripo's API. """ - AVERAGE_DURATION = 80 + @classmethod - def INPUT_TYPES(s): - return { - "required": { - "prompt": ("STRING", {"multiline": True}), - }, - "optional": { - "negative_prompt": ("STRING", {"multiline": True}), - "model_version": model_field_to_node_input(IO.COMBO, TripoTextToModelRequest, "model_version", enum_type=TripoModelVersion), - "style": model_field_to_node_input(IO.COMBO, TripoTextToModelRequest, "style", enum_type=TripoStyle, default="None"), - "texture": ("BOOLEAN", {"default": True}), - "pbr": ("BOOLEAN", {"default": True}), - "image_seed": ("INT", {"default": 42}), - "model_seed": ("INT", {"default": 42}), - "texture_seed": ("INT", {"default": 42}), - "texture_quality": (["standard", "detailed"], {"default": "standard"}), - "face_limit": ("INT", {"min": -1, "max": 500000, "default": -1}), - "quad": ("BOOLEAN", {"default": False}) - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } + def define_schema(cls): + return IO.Schema( + node_id="TripoTextToModelNode", + display_name="Tripo: Text to Model", + category="api node/3d/Tripo", + inputs=[ + IO.String.Input("prompt", multiline=True), + IO.String.Input("negative_prompt", multiline=True, optional=True), + IO.Combo.Input( + "model_version", options=TripoModelVersion, default=TripoModelVersion.v2_5_20250123, optional=True + ), + IO.Combo.Input("style", options=TripoStyle, default="None", optional=True), + IO.Boolean.Input("texture", default=True, optional=True), + IO.Boolean.Input("pbr", default=True, optional=True), + IO.Int.Input("image_seed", default=42, optional=True), + IO.Int.Input("model_seed", default=42, optional=True), + IO.Int.Input("texture_seed", default=42, optional=True), + IO.Combo.Input("texture_quality", default="standard", options=["standard", "detailed"], optional=True), + IO.Int.Input("face_limit", default=-1, min=-1, max=500000, optional=True), + IO.Boolean.Input("quad", default=False, optional=True), + ], + outputs=[ + IO.String.Output(display_name="model_file"), + IO.Custom("MODEL_TASK_ID").Output(display_name="model task_id"), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + is_output_node=True, + ) - RETURN_TYPES = ("STRING", "MODEL_TASK_ID",) - RETURN_NAMES = ("model_file", "model task_id") - FUNCTION = "generate_mesh" - CATEGORY = "api node/3d/Tripo" - API_NODE = True - OUTPUT_NODE = True - - async def generate_mesh(self, prompt, negative_prompt=None, model_version=None, style=None, texture=None, pbr=None, image_seed=None, model_seed=None, texture_seed=None, texture_quality=None, face_limit=None, quad=None, **kwargs): + @classmethod + async def execute( + cls, + prompt: str, + negative_prompt: Optional[str] = None, + model_version=None, + style: Optional[str] = None, + texture: Optional[bool] = None, + pbr: Optional[bool] = None, + image_seed: Optional[int] = None, + model_seed: Optional[int] = None, + texture_seed: Optional[int] = None, + texture_quality: Optional[str] = None, + face_limit: Optional[int] = None, + quad: Optional[bool] = None, + ) -> IO.NodeOutput: style_enum = None if style == "None" else style if not prompt: raise RuntimeError("Prompt is required") - response = await SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/tripo/v2/openapi/task", - method=HttpMethod.POST, - request_model=TripoTextToModelRequest, - response_model=TripoTaskResponse, - ), - request=TripoTextToModelRequest( + response = await sync_op( + cls, + endpoint=ApiEndpoint(path="/proxy/tripo/v2/openapi/task", method="POST"), + response_model=TripoTaskResponse, + data=TripoTextToModelRequest( type=TripoTaskType.TEXT_TO_MODEL, prompt=prompt, negative_prompt=negative_prompt if negative_prompt else None, @@ -152,64 +155,89 @@ class TripoTextToModelNode: texture_quality=texture_quality, face_limit=face_limit, auto_size=True, - quad=quad + quad=quad, ), - auth_kwargs=kwargs, - ).execute() - return await poll_until_finished(kwargs, response) + ) + return await poll_until_finished(cls, response, average_duration=80) -class TripoImageToModelNode: +class TripoImageToModelNode(IO.ComfyNode): """ Generates 3D models synchronously based on a single image using Tripo's API. """ - AVERAGE_DURATION = 80 + @classmethod - def INPUT_TYPES(s): - return { - "required": { - "image": ("IMAGE",), - }, - "optional": { - "model_version": model_field_to_node_input(IO.COMBO, TripoImageToModelRequest, "model_version", enum_type=TripoModelVersion), - "style": model_field_to_node_input(IO.COMBO, TripoTextToModelRequest, "style", enum_type=TripoStyle, default="None"), - "texture": ("BOOLEAN", {"default": True}), - "pbr": ("BOOLEAN", {"default": True}), - "model_seed": ("INT", {"default": 42}), - "orientation": model_field_to_node_input(IO.COMBO, TripoImageToModelRequest, "orientation", enum_type=TripoOrientation), - "texture_seed": ("INT", {"default": 42}), - "texture_quality": (["standard", "detailed"], {"default": "standard"}), - "texture_alignment": (["original_image", "geometry"], {"default": "original_image"}), - "face_limit": ("INT", {"min": -1, "max": 500000, "default": -1}), - "quad": ("BOOLEAN", {"default": False}) - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } + def define_schema(cls): + return IO.Schema( + node_id="TripoImageToModelNode", + display_name="Tripo: Image to Model", + category="api node/3d/Tripo", + inputs=[ + IO.Image.Input("image"), + IO.Combo.Input( + "model_version", + options=TripoModelVersion, + tooltip="The model version to use for generation", + optional=True, + ), + IO.Combo.Input("style", options=TripoStyle, default="None", optional=True), + IO.Boolean.Input("texture", default=True, optional=True), + IO.Boolean.Input("pbr", default=True, optional=True), + IO.Int.Input("model_seed", default=42, optional=True), + IO.Combo.Input( + "orientation", options=TripoOrientation, default=TripoOrientation.DEFAULT, optional=True + ), + IO.Int.Input("texture_seed", default=42, optional=True), + IO.Combo.Input("texture_quality", default="standard", options=["standard", "detailed"], optional=True), + IO.Combo.Input( + "texture_alignment", default="original_image", options=["original_image", "geometry"], optional=True + ), + IO.Int.Input("face_limit", default=-1, min=-1, max=500000, optional=True), + IO.Boolean.Input("quad", default=False, optional=True), + ], + outputs=[ + IO.String.Output(display_name="model_file"), + IO.Custom("MODEL_TASK_ID").Output(display_name="model task_id"), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + is_output_node=True, + ) - RETURN_TYPES = ("STRING", "MODEL_TASK_ID",) - RETURN_NAMES = ("model_file", "model task_id") - FUNCTION = "generate_mesh" - CATEGORY = "api node/3d/Tripo" - API_NODE = True - OUTPUT_NODE = True - - async def generate_mesh(self, image, model_version=None, style=None, texture=None, pbr=None, model_seed=None, orientation=None, texture_alignment=None, texture_seed=None, texture_quality=None, face_limit=None, quad=None, **kwargs): + @classmethod + async def execute( + cls, + image: torch.Tensor, + model_version: Optional[str] = None, + style: Optional[str] = None, + texture: Optional[bool] = None, + pbr: Optional[bool] = None, + model_seed: Optional[int] = None, + orientation=None, + texture_seed: Optional[int] = None, + texture_quality: Optional[str] = None, + texture_alignment: Optional[str] = None, + face_limit: Optional[int] = None, + quad: Optional[bool] = None, + ) -> IO.NodeOutput: style_enum = None if style == "None" else style if image is None: raise RuntimeError("Image is required") - tripo_file = await upload_image_to_tripo(image, **kwargs) - response = await SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/tripo/v2/openapi/task", - method=HttpMethod.POST, - request_model=TripoImageToModelRequest, - response_model=TripoTaskResponse, - ), - request=TripoImageToModelRequest( + tripo_file = TripoFileReference( + root=TripoUrlReference( + url=(await upload_images_to_comfyapi(cls, image, max_images=1))[0], + type="jpeg", + ) + ) + response = await sync_op( + cls, + endpoint=ApiEndpoint(path="/proxy/tripo/v2/openapi/task", method="POST"), + response_model=TripoTaskResponse, + data=TripoImageToModelRequest( type=TripoTaskType.IMAGE_TO_MODEL, file=tripo_file, model_version=model_version, @@ -223,80 +251,105 @@ class TripoImageToModelNode: texture_quality=texture_quality, face_limit=face_limit, auto_size=True, - quad=quad + quad=quad, ), - auth_kwargs=kwargs, - ).execute() - return await poll_until_finished(kwargs, response) + ) + return await poll_until_finished(cls, response, average_duration=80) -class TripoMultiviewToModelNode: +class TripoMultiviewToModelNode(IO.ComfyNode): """ Generates 3D models synchronously based on up to four images (front, left, back, right) using Tripo's API. """ - AVERAGE_DURATION = 80 + @classmethod - def INPUT_TYPES(s): - return { - "required": { - "image": ("IMAGE",), - }, - "optional": { - "image_left": ("IMAGE",), - "image_back": ("IMAGE",), - "image_right": ("IMAGE",), - "model_version": model_field_to_node_input(IO.COMBO, TripoMultiviewToModelRequest, "model_version", enum_type=TripoModelVersion), - "orientation": model_field_to_node_input(IO.COMBO, TripoImageToModelRequest, "orientation", enum_type=TripoOrientation), - "texture": ("BOOLEAN", {"default": True}), - "pbr": ("BOOLEAN", {"default": True}), - "model_seed": ("INT", {"default": 42}), - "texture_seed": ("INT", {"default": 42}), - "texture_quality": (["standard", "detailed"], {"default": "standard"}), - "texture_alignment": (["original_image", "geometry"], {"default": "original_image"}), - "face_limit": ("INT", {"min": -1, "max": 500000, "default": -1}), - "quad": ("BOOLEAN", {"default": False}) - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } + def define_schema(cls): + return IO.Schema( + node_id="TripoMultiviewToModelNode", + display_name="Tripo: Multiview to Model", + category="api node/3d/Tripo", + inputs=[ + IO.Image.Input("image"), + IO.Image.Input("image_left", optional=True), + IO.Image.Input("image_back", optional=True), + IO.Image.Input("image_right", optional=True), + IO.Combo.Input( + "model_version", + options=TripoModelVersion, + optional=True, + tooltip="The model version to use for generation", + ), + IO.Combo.Input( + "orientation", + options=TripoOrientation, + default=TripoOrientation.DEFAULT, + optional=True, + ), + IO.Boolean.Input("texture", default=True, optional=True), + IO.Boolean.Input("pbr", default=True, optional=True), + IO.Int.Input("model_seed", default=42, optional=True), + IO.Int.Input("texture_seed", default=42, optional=True), + IO.Combo.Input("texture_quality", default="standard", options=["standard", "detailed"], optional=True), + IO.Combo.Input( + "texture_alignment", default="original_image", options=["original_image", "geometry"], optional=True + ), + IO.Int.Input("face_limit", default=-1, min=-1, max=500000, optional=True), + IO.Boolean.Input("quad", default=False, optional=True), + ], + outputs=[ + IO.String.Output(display_name="model_file"), + IO.Custom("MODEL_TASK_ID").Output(display_name="model task_id"), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + is_output_node=True, + ) - RETURN_TYPES = ("STRING", "MODEL_TASK_ID",) - RETURN_NAMES = ("model_file", "model task_id") - FUNCTION = "generate_mesh" - CATEGORY = "api node/3d/Tripo" - API_NODE = True - OUTPUT_NODE = True - - async def generate_mesh(self, image, image_left=None, image_back=None, image_right=None, model_version=None, orientation=None, texture=None, pbr=None, model_seed=None, texture_seed=None, texture_quality=None, texture_alignment=None, face_limit=None, quad=None, **kwargs): + @classmethod + async def execute( + cls, + image: torch.Tensor, + image_left: Optional[torch.Tensor] = None, + image_back: Optional[torch.Tensor] = None, + image_right: Optional[torch.Tensor] = None, + model_version: Optional[str] = None, + orientation: Optional[str] = None, + texture: Optional[bool] = None, + pbr: Optional[bool] = None, + model_seed: Optional[int] = None, + texture_seed: Optional[int] = None, + texture_quality: Optional[str] = None, + texture_alignment: Optional[str] = None, + face_limit: Optional[int] = None, + quad: Optional[bool] = None, + ) -> IO.NodeOutput: if image is None: raise RuntimeError("front image for multiview is required") images = [] - image_dict = { - "image": image, - "image_left": image_left, - "image_back": image_back, - "image_right": image_right - } + image_dict = {"image": image, "image_left": image_left, "image_back": image_back, "image_right": image_right} if image_left is None and image_back is None and image_right is None: raise RuntimeError("At least one of left, back, or right image must be provided for multiview") for image_name in ["image", "image_left", "image_back", "image_right"]: image_ = image_dict[image_name] if image_ is not None: - tripo_file = await upload_image_to_tripo(image_, **kwargs) - images.append(tripo_file) + images.append( + TripoFileReference( + root=TripoUrlReference( + url=(await upload_images_to_comfyapi(cls, image_, max_images=1))[0], type="jpeg" + ) + ) + ) else: images.append(TripoFileEmptyReference()) - response = await SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/tripo/v2/openapi/task", - method=HttpMethod.POST, - request_model=TripoMultiviewToModelRequest, - response_model=TripoTaskResponse, - ), - request=TripoMultiviewToModelRequest( + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/tripo/v2/openapi/task", method="POST"), + response_model=TripoTaskResponse, + data=TripoMultiviewToModelRequest( type=TripoTaskType.MULTIVIEW_TO_MODEL, files=images, model_version=model_version, @@ -310,272 +363,283 @@ class TripoMultiviewToModelNode: face_limit=face_limit, quad=quad, ), - auth_kwargs=kwargs, - ).execute() - return await poll_until_finished(kwargs, response) + ) + return await poll_until_finished(cls, response, average_duration=80) -class TripoTextureNode: +class TripoTextureNode(IO.ComfyNode): + @classmethod - def INPUT_TYPES(s): - return { - "required": { - "model_task_id": ("MODEL_TASK_ID",), - }, - "optional": { - "texture": ("BOOLEAN", {"default": True}), - "pbr": ("BOOLEAN", {"default": True}), - "texture_seed": ("INT", {"default": 42}), - "texture_quality": (["standard", "detailed"], {"default": "standard"}), - "texture_alignment": (["original_image", "geometry"], {"default": "original_image"}), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } + def define_schema(cls): + return IO.Schema( + node_id="TripoTextureNode", + display_name="Tripo: Texture model", + category="api node/3d/Tripo", + inputs=[ + IO.Custom("MODEL_TASK_ID").Input("model_task_id"), + IO.Boolean.Input("texture", default=True, optional=True), + IO.Boolean.Input("pbr", default=True, optional=True), + IO.Int.Input("texture_seed", default=42, optional=True), + IO.Combo.Input("texture_quality", default="standard", options=["standard", "detailed"], optional=True), + IO.Combo.Input( + "texture_alignment", default="original_image", options=["original_image", "geometry"], optional=True + ), + ], + outputs=[ + IO.String.Output(display_name="model_file"), + IO.Custom("MODEL_TASK_ID").Output(display_name="model task_id"), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + is_output_node=True, + ) - RETURN_TYPES = ("STRING", "MODEL_TASK_ID",) - RETURN_NAMES = ("model_file", "model task_id") - FUNCTION = "generate_mesh" - CATEGORY = "api node/3d/Tripo" - API_NODE = True - OUTPUT_NODE = True - AVERAGE_DURATION = 80 - - async def generate_mesh(self, model_task_id, texture=None, pbr=None, texture_seed=None, texture_quality=None, texture_alignment=None, **kwargs): - response = await SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/tripo/v2/openapi/task", - method=HttpMethod.POST, - request_model=TripoTextureModelRequest, - response_model=TripoTaskResponse, - ), - request=TripoTextureModelRequest( + @classmethod + async def execute( + cls, + model_task_id, + texture: Optional[bool] = None, + pbr: Optional[bool] = None, + texture_seed: Optional[int] = None, + texture_quality: Optional[str] = None, + texture_alignment: Optional[str] = None, + ) -> IO.NodeOutput: + response = await sync_op( + cls, + endpoint=ApiEndpoint(path="/proxy/tripo/v2/openapi/task", method="POST"), + response_model=TripoTaskResponse, + data=TripoTextureModelRequest( original_model_task_id=model_task_id, texture=texture, pbr=pbr, texture_seed=texture_seed, texture_quality=texture_quality, - texture_alignment=texture_alignment + texture_alignment=texture_alignment, ), - auth_kwargs=kwargs, - ).execute() - return await poll_until_finished(kwargs, response) + ) + return await poll_until_finished(cls, response, average_duration=80) -class TripoRefineNode: +class TripoRefineNode(IO.ComfyNode): + @classmethod - def INPUT_TYPES(s): - return { - "required": { - "model_task_id": ("MODEL_TASK_ID", { - "tooltip": "Must be a v1.4 Tripo model" - }), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } + def define_schema(cls): + return IO.Schema( + node_id="TripoRefineNode", + display_name="Tripo: Refine Draft model", + category="api node/3d/Tripo", + description="Refine a draft model created by v1.4 Tripo models only.", + inputs=[ + IO.Custom("MODEL_TASK_ID").Input("model_task_id", tooltip="Must be a v1.4 Tripo model"), + ], + outputs=[ + IO.String.Output(display_name="model_file"), + IO.Custom("MODEL_TASK_ID").Output(display_name="model task_id"), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + is_output_node=True, + ) - DESCRIPTION = "Refine a draft model created by v1.4 Tripo models only." - - RETURN_TYPES = ("STRING", "MODEL_TASK_ID",) - RETURN_NAMES = ("model_file", "model task_id") - FUNCTION = "generate_mesh" - CATEGORY = "api node/3d/Tripo" - API_NODE = True - OUTPUT_NODE = True - AVERAGE_DURATION = 240 - - async def generate_mesh(self, model_task_id, **kwargs): - response = await SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/tripo/v2/openapi/task", - method=HttpMethod.POST, - request_model=TripoRefineModelRequest, - response_model=TripoTaskResponse, - ), - request=TripoRefineModelRequest( - draft_model_task_id=model_task_id - ), - auth_kwargs=kwargs, - ).execute() - return await poll_until_finished(kwargs, response) - - -class TripoRigNode: @classmethod - def INPUT_TYPES(s): - return { - "required": { - "original_model_task_id": ("MODEL_TASK_ID",), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } - - RETURN_TYPES = ("STRING", "RIG_TASK_ID") - RETURN_NAMES = ("model_file", "rig task_id") - FUNCTION = "generate_mesh" - CATEGORY = "api node/3d/Tripo" - API_NODE = True - OUTPUT_NODE = True - AVERAGE_DURATION = 180 - - async def generate_mesh(self, original_model_task_id, **kwargs): - response = await SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/tripo/v2/openapi/task", - method=HttpMethod.POST, - request_model=TripoAnimateRigRequest, - response_model=TripoTaskResponse, - ), - request=TripoAnimateRigRequest( - original_model_task_id=original_model_task_id, - out_format="glb", - spec="tripo" - ), - auth_kwargs=kwargs, - ).execute() - return await poll_until_finished(kwargs, response) + async def execute(cls, model_task_id) -> IO.NodeOutput: + response = await sync_op( + cls, + endpoint=ApiEndpoint(path="/proxy/tripo/v2/openapi/task", method="POST"), + response_model=TripoTaskResponse, + data=TripoRefineModelRequest(draft_model_task_id=model_task_id), + ) + return await poll_until_finished(cls, response, average_duration=240) -class TripoRetargetNode: +class TripoRigNode(IO.ComfyNode): + @classmethod - def INPUT_TYPES(s): - return { - "required": { - "original_model_task_id": ("RIG_TASK_ID",), - "animation": ([ - "preset:idle", - "preset:walk", - "preset:climb", - "preset:jump", - "preset:slash", - "preset:shoot", - "preset:hurt", - "preset:fall", - "preset:turn", - ],), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } + def define_schema(cls): + return IO.Schema( + node_id="TripoRigNode", + display_name="Tripo: Rig model", + category="api node/3d/Tripo", + inputs=[IO.Custom("MODEL_TASK_ID").Input("original_model_task_id")], + outputs=[ + IO.String.Output(display_name="model_file"), + IO.Custom("RIG_TASK_ID").Output(display_name="rig task_id"), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + is_output_node=True, + ) - RETURN_TYPES = ("STRING", "RETARGET_TASK_ID") - RETURN_NAMES = ("model_file", "retarget task_id") - FUNCTION = "generate_mesh" - CATEGORY = "api node/3d/Tripo" - API_NODE = True - OUTPUT_NODE = True - AVERAGE_DURATION = 30 + @classmethod + async def execute(cls, original_model_task_id) -> IO.NodeOutput: + response = await sync_op( + cls, + endpoint=ApiEndpoint(path="/proxy/tripo/v2/openapi/task", method="POST"), + response_model=TripoTaskResponse, + data=TripoAnimateRigRequest(original_model_task_id=original_model_task_id, out_format="glb", spec="tripo"), + ) + return await poll_until_finished(cls, response, average_duration=180) - async def generate_mesh(self, animation, original_model_task_id, **kwargs): - response = await SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/tripo/v2/openapi/task", - method=HttpMethod.POST, - request_model=TripoAnimateRetargetRequest, - response_model=TripoTaskResponse, - ), - request=TripoAnimateRetargetRequest( + +class TripoRetargetNode(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="TripoRetargetNode", + display_name="Tripo: Retarget rigged model", + category="api node/3d/Tripo", + inputs=[ + IO.Custom("RIG_TASK_ID").Input("original_model_task_id"), + IO.Combo.Input( + "animation", + options=[ + "preset:idle", + "preset:walk", + "preset:climb", + "preset:jump", + "preset:slash", + "preset:shoot", + "preset:hurt", + "preset:fall", + "preset:turn", + ], + ), + ], + outputs=[ + IO.String.Output(display_name="model_file"), + IO.Custom("RETARGET_TASK_ID").Output(display_name="retarget task_id"), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + is_output_node=True, + ) + + @classmethod + async def execute(cls, original_model_task_id, animation: str) -> IO.NodeOutput: + response = await sync_op( + cls, + endpoint=ApiEndpoint(path="/proxy/tripo/v2/openapi/task", method="POST"), + response_model=TripoTaskResponse, + data=TripoAnimateRetargetRequest( original_model_task_id=original_model_task_id, animation=animation, out_format="glb", - bake_animation=True + bake_animation=True, ), - auth_kwargs=kwargs, - ).execute() - return await poll_until_finished(kwargs, response) + ) + return await poll_until_finished(cls, response, average_duration=30) -class TripoConversionNode: - @classmethod - def INPUT_TYPES(s): - return { - "required": { - "original_model_task_id": ("MODEL_TASK_ID,RIG_TASK_ID,RETARGET_TASK_ID",), - "format": (["GLTF", "USDZ", "FBX", "OBJ", "STL", "3MF"],), - }, - "optional": { - "quad": ("BOOLEAN", {"default": False}), - "face_limit": ("INT", {"min": -1, "max": 500000, "default": -1}), - "texture_size": ("INT", {"min": 128, "max": 4096, "default": 4096}), - "texture_format": (["BMP", "DPX", "HDR", "JPEG", "OPEN_EXR", "PNG", "TARGA", "TIFF", "WEBP"], {"default": "JPEG"}) - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } +class TripoConversionNode(IO.ComfyNode): @classmethod - def VALIDATE_INPUTS(cls, input_types): + def define_schema(cls): + return IO.Schema( + node_id="TripoConversionNode", + display_name="Tripo: Convert model", + category="api node/3d/Tripo", + inputs=[ + IO.Custom("MODEL_TASK_ID,RIG_TASK_ID,RETARGET_TASK_ID").Input("original_model_task_id"), + IO.Combo.Input("format", options=["GLTF", "USDZ", "FBX", "OBJ", "STL", "3MF"]), + IO.Boolean.Input("quad", default=False, optional=True), + IO.Int.Input( + "face_limit", + default=-1, + min=-1, + max=500000, + optional=True, + ), + IO.Int.Input( + "texture_size", + default=4096, + min=128, + max=4096, + optional=True, + ), + IO.Combo.Input( + "texture_format", + options=["BMP", "DPX", "HDR", "JPEG", "OPEN_EXR", "PNG", "TARGA", "TIFF", "WEBP"], + default="JPEG", + optional=True, + ), + ], + outputs=[], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + is_output_node=True, + ) + + @classmethod + def validate_inputs(cls, input_types): # The min and max of input1 and input2 are still validated because # we didn't take `input1` or `input2` as arguments if input_types["original_model_task_id"] not in ("MODEL_TASK_ID", "RIG_TASK_ID", "RETARGET_TASK_ID"): return "original_model_task_id must be MODEL_TASK_ID, RIG_TASK_ID or RETARGET_TASK_ID type" return True - RETURN_TYPES = () - FUNCTION = "generate_mesh" - CATEGORY = "api node/3d/Tripo" - API_NODE = True - OUTPUT_NODE = True - AVERAGE_DURATION = 30 - - async def generate_mesh(self, original_model_task_id, format, quad, face_limit, texture_size, texture_format, **kwargs): + @classmethod + async def execute( + cls, + original_model_task_id, + format: str, + quad: bool, + face_limit: int, + texture_size: int, + texture_format: str, + ) -> IO.NodeOutput: if not original_model_task_id: raise RuntimeError("original_model_task_id is required") - response = await SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/tripo/v2/openapi/task", - method=HttpMethod.POST, - request_model=TripoConvertModelRequest, - response_model=TripoTaskResponse, - ), - request=TripoConvertModelRequest( + response = await sync_op( + cls, + endpoint=ApiEndpoint(path="/proxy/tripo/v2/openapi/task", method="POST"), + response_model=TripoTaskResponse, + data=TripoConvertModelRequest( original_model_task_id=original_model_task_id, format=format, quad=quad if quad else None, face_limit=face_limit if face_limit != -1 else None, texture_size=texture_size if texture_size != 4096 else None, - texture_format=texture_format if texture_format != "JPEG" else None + texture_format=texture_format if texture_format != "JPEG" else None, ), - auth_kwargs=kwargs, - ).execute() - return await poll_until_finished(kwargs, response) + ) + return await poll_until_finished(cls, response, average_duration=30) -NODE_CLASS_MAPPINGS = { - "TripoTextToModelNode": TripoTextToModelNode, - "TripoImageToModelNode": TripoImageToModelNode, - "TripoMultiviewToModelNode": TripoMultiviewToModelNode, - "TripoTextureNode": TripoTextureNode, - "TripoRefineNode": TripoRefineNode, - "TripoRigNode": TripoRigNode, - "TripoRetargetNode": TripoRetargetNode, - "TripoConversionNode": TripoConversionNode, -} +class TripoExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[IO.ComfyNode]]: + return [ + TripoTextToModelNode, + TripoImageToModelNode, + TripoMultiviewToModelNode, + TripoTextureNode, + TripoRefineNode, + TripoRigNode, + TripoRetargetNode, + TripoConversionNode, + ] -NODE_DISPLAY_NAME_MAPPINGS = { - "TripoTextToModelNode": "Tripo: Text to Model", - "TripoImageToModelNode": "Tripo: Image to Model", - "TripoMultiviewToModelNode": "Tripo: Multiview to Model", - "TripoTextureNode": "Tripo: Texture model", - "TripoRefineNode": "Tripo: Refine Draft model", - "TripoRigNode": "Tripo: Rig model", - "TripoRetargetNode": "Tripo: Retarget rigged model", - "TripoConversionNode": "Tripo: Convert model", -} + +async def comfy_entrypoint() -> TripoExtension: + return TripoExtension() diff --git a/comfy_api_nodes/util/__init__.py b/comfy_api_nodes/util/__init__.py index c2ec391aa..ab96760cb 100644 --- a/comfy_api_nodes/util/__init__.py +++ b/comfy_api_nodes/util/__init__.py @@ -20,6 +20,7 @@ from .conversions import ( trim_video, ) from .download_helpers import ( + download_url_as_bytesio, download_url_to_bytesio, download_url_to_image_tensor, download_url_to_video_output, @@ -56,6 +57,7 @@ __all__ = [ "upload_images_to_comfyapi", "upload_video_to_comfyapi", # Download helpers + "download_url_as_bytesio", "download_url_to_bytesio", "download_url_to_image_tensor", "download_url_to_video_output", diff --git a/comfy_api_nodes/util/download_helpers.py b/comfy_api_nodes/util/download_helpers.py index 055e690de..791dd5a50 100644 --- a/comfy_api_nodes/util/download_helpers.py +++ b/comfy_api_nodes/util/download_helpers.py @@ -240,6 +240,18 @@ async def download_url_to_video_output( return VideoFromFile(result) +async def download_url_as_bytesio( + url: str, + *, + timeout: float = None, + cls: type[COMFY_IO.ComfyNode] = None, +) -> BytesIO: + """Downloads content from a URL and returns a new BytesIO (rewound to 0).""" + result = BytesIO() + await download_url_to_bytesio(url, result, timeout=timeout, cls=cls) + return result + + def _generate_operation_id(method: str, url: str, attempt: int) -> str: try: parsed = urlparse(url) From 426cde37f10dc391f9601ab938e02c0faa42db14 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Fri, 24 Oct 2025 16:56:51 -0700 Subject: [PATCH 009/231] Remove useless function (#10472) --- comfy/model_management.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 79d6ff9d4..cf015a29a 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -998,12 +998,6 @@ def device_supports_non_blocking(device): return False return True -def device_should_use_non_blocking(device): - if not device_supports_non_blocking(device): - return False - return False - # return True #TODO: figure out why this causes memory issues on Nvidia and possibly others - def force_channels_last(): if args.force_channels_last: return True From e86b79ab9ea7e740b80490353f3f5763840ede81 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Sun, 26 Oct 2025 00:35:30 +0300 Subject: [PATCH 010/231] convert Gemini API nodes to V3 schema (#10476) --- comfy_api_nodes/apinode_utils.py | 26 -- comfy_api_nodes/nodes_gemini.py | 629 +++++++++++----------------- comfy_api_nodes/util/__init__.py | 2 + comfy_api_nodes/util/conversions.py | 25 ++ 4 files changed, 282 insertions(+), 400 deletions(-) diff --git a/comfy_api_nodes/apinode_utils.py b/comfy_api_nodes/apinode_utils.py index e3d282059..4182c8f80 100644 --- a/comfy_api_nodes/apinode_utils.py +++ b/comfy_api_nodes/apinode_utils.py @@ -3,8 +3,6 @@ import aiohttp import mimetypes from typing import Optional, Union from comfy.utils import common_upscale -from comfy_api.util import VideoContainer, VideoCodec -from comfy_api.input.video_types import VideoInput from comfy_api_nodes.apis.client import ( ApiClient, ApiEndpoint, @@ -209,30 +207,6 @@ async def upload_file_to_comfyapi( return response.download_url -def video_to_base64_string( - video: VideoInput, - container_format: VideoContainer = None, - codec: VideoCodec = None -) -> str: - """ - Converts a video input to a base64 string. - - Args: - video: The video input to convert - container_format: Optional container format to use (defaults to video.container if available) - codec: Optional codec to use (defaults to video.codec if available) - """ - video_bytes_io = BytesIO() - - # Use provided format/codec if specified, otherwise use video's own if available - format_to_use = container_format if container_format is not None else getattr(video, 'container', VideoContainer.MP4) - codec_to_use = codec if codec is not None else getattr(video, 'codec', VideoCodec.H264) - - video.save_to(video_bytes_io, format=format_to_use, codec=codec_to_use) - video_bytes_io.seek(0) - return base64.b64encode(video_bytes_io.getvalue()).decode("utf-8") - - async def upload_images_to_comfyapi( image: torch.Tensor, max_images=8, diff --git a/comfy_api_nodes/nodes_gemini.py b/comfy_api_nodes/nodes_gemini.py index ca11b67ed..67f2469ad 100644 --- a/comfy_api_nodes/nodes_gemini.py +++ b/comfy_api_nodes/nodes_gemini.py @@ -2,42 +2,47 @@ API Nodes for Gemini Multimodal LLM Usage via Remote API See: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/inference """ + from __future__ import annotations -import json -import time -import os -import uuid import base64 -from io import BytesIO +import json +import os +import time +import uuid from enum import Enum -from typing import Optional, Literal +from io import BytesIO +from typing import Literal, Optional import torch +from typing_extensions import override import folder_paths -from comfy.comfy_types.node_typing import IO, ComfyNodeABC, InputTypeDict -from server import PromptServer +from comfy_api.latest import IO, ComfyExtension, Input +from comfy_api.util import VideoCodec, VideoContainer from comfy_api_nodes.apis import ( GeminiContent, GeminiGenerateContentRequest, GeminiGenerateContentResponse, GeminiInlineData, - GeminiPart, GeminiMimeType, + GeminiPart, ) -from comfy_api_nodes.apis.gemini_api import GeminiImageGenerationConfig, GeminiImageGenerateContentRequest, GeminiImageConfig -from comfy_api_nodes.apis.client import ( +from comfy_api_nodes.apis.gemini_api import ( + GeminiImageConfig, + GeminiImageGenerateContentRequest, + GeminiImageGenerationConfig, +) +from comfy_api_nodes.util import ( ApiEndpoint, - HttpMethod, - SynchronousOperation, -) -from comfy_api_nodes.apinode_utils import ( + audio_to_base64_string, + bytesio_to_image_tensor, + sync_op, + tensor_to_base64_string, + validate_string, video_to_base64_string, ) -from comfy_api_nodes.util import validate_string, tensor_to_base64_string, bytesio_to_image_tensor, audio_to_base64_string -from comfy_api.util import VideoContainer, VideoCodec - +from server import PromptServer GEMINI_BASE_ENDPOINT = "/proxy/vertexai/gemini" GEMINI_MAX_INPUT_FILE_SIZE = 20 * 1024 * 1024 # 20 MB @@ -63,50 +68,6 @@ class GeminiImageModel(str, Enum): gemini_2_5_flash_image = "gemini-2.5-flash-image" -def get_gemini_endpoint( - model: GeminiModel, -) -> ApiEndpoint[GeminiGenerateContentRequest, GeminiGenerateContentResponse]: - """ - Get the API endpoint for a given Gemini model. - - Args: - model: The Gemini model to use, either as enum or string value. - - Returns: - ApiEndpoint configured for the specific Gemini model. - """ - if isinstance(model, str): - model = GeminiModel(model) - return ApiEndpoint( - path=f"{GEMINI_BASE_ENDPOINT}/{model.value}", - method=HttpMethod.POST, - request_model=GeminiGenerateContentRequest, - response_model=GeminiGenerateContentResponse, - ) - - -def get_gemini_image_endpoint( - model: GeminiImageModel, -) -> ApiEndpoint[GeminiGenerateContentRequest, GeminiGenerateContentResponse]: - """ - Get the API endpoint for a given Gemini model. - - Args: - model: The Gemini model to use, either as enum or string value. - - Returns: - ApiEndpoint configured for the specific Gemini model. - """ - if isinstance(model, str): - model = GeminiImageModel(model) - return ApiEndpoint( - path=f"{GEMINI_BASE_ENDPOINT}/{model.value}", - method=HttpMethod.POST, - request_model=GeminiImageGenerateContentRequest, - response_model=GeminiGenerateContentResponse, - ) - - def create_image_parts(image_input: torch.Tensor) -> list[GeminiPart]: """ Convert image tensor input to Gemini API compatible parts. @@ -119,9 +80,7 @@ def create_image_parts(image_input: torch.Tensor) -> list[GeminiPart]: """ image_parts: list[GeminiPart] = [] for image_index in range(image_input.shape[0]): - image_as_b64 = tensor_to_base64_string( - image_input[image_index].unsqueeze(0) - ) + image_as_b64 = tensor_to_base64_string(image_input[image_index].unsqueeze(0)) image_parts.append( GeminiPart( inlineData=GeminiInlineData( @@ -133,37 +92,7 @@ def create_image_parts(image_input: torch.Tensor) -> list[GeminiPart]: return image_parts -def create_text_part(text: str) -> GeminiPart: - """ - Create a text part for the Gemini API request. - - Args: - text: The text content to include in the request. - - Returns: - A GeminiPart object with the text content. - """ - return GeminiPart(text=text) - - -def get_parts_from_response( - response: GeminiGenerateContentResponse -) -> list[GeminiPart]: - """ - Extract all parts from the Gemini API response. - - Args: - response: The API response from Gemini. - - Returns: - List of response parts from the first candidate. - """ - return response.candidates[0].content.parts - - -def get_parts_by_type( - response: GeminiGenerateContentResponse, part_type: Literal["text"] | str -) -> list[GeminiPart]: +def get_parts_by_type(response: GeminiGenerateContentResponse, part_type: Literal["text"] | str) -> list[GeminiPart]: """ Filter response parts by their type. @@ -175,14 +104,10 @@ def get_parts_by_type( List of response parts matching the requested type. """ parts = [] - for part in get_parts_from_response(response): + for part in response.candidates[0].content.parts: if part_type == "text" and hasattr(part, "text") and part.text: parts.append(part) - elif ( - hasattr(part, "inlineData") - and part.inlineData - and part.inlineData.mimeType == part_type - ): + elif hasattr(part, "inlineData") and part.inlineData and part.inlineData.mimeType == part_type: parts.append(part) # Skip parts that don't match the requested type return parts @@ -210,11 +135,11 @@ def get_image_from_response(response: GeminiGenerateContentResponse) -> torch.Te returned_image = bytesio_to_image_tensor(BytesIO(image_data)) image_tensors.append(returned_image) if len(image_tensors) == 0: - return torch.zeros((1,1024,1024,4)) + return torch.zeros((1, 1024, 1024, 4)) return torch.cat(image_tensors, dim=0) -class GeminiNode(ComfyNodeABC): +class GeminiNode(IO.ComfyNode): """ Node to generate text responses from a Gemini model. @@ -225,96 +150,79 @@ class GeminiNode(ComfyNodeABC): """ @classmethod - def INPUT_TYPES(cls) -> InputTypeDict: - return { - "required": { - "prompt": ( - IO.STRING, - { - "multiline": True, - "default": "", - "tooltip": "Text inputs to the model, used to generate a response. You can include detailed instructions, questions, or context for the model.", - }, + def define_schema(cls): + return IO.Schema( + node_id="GeminiNode", + display_name="Google Gemini", + category="api node/text/Gemini", + description="Generate text responses with Google's Gemini AI model. " + "You can provide multiple types of inputs (text, images, audio, video) " + "as context for generating more relevant and meaningful responses.", + inputs=[ + IO.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Text inputs to the model, used to generate a response. " + "You can include detailed instructions, questions, or context for the model.", ), - "model": ( - IO.COMBO, - { - "tooltip": "The Gemini model to use for generating responses.", - "options": [model.value for model in GeminiModel], - "default": GeminiModel.gemini_2_5_pro.value, - }, + IO.Combo.Input( + "model", + options=GeminiModel, + default=GeminiModel.gemini_2_5_pro, + tooltip="The Gemini model to use for generating responses.", ), - "seed": ( - IO.INT, - { - "default": 42, - "min": 0, - "max": 0xFFFFFFFFFFFFFFFF, - "control_after_generate": True, - "tooltip": "When seed is fixed to a specific value, the model makes a best effort to provide the same response for repeated requests. Deterministic output isn't guaranteed. Also, changing the model or parameter settings, such as the temperature, can cause variations in the response even when you use the same seed value. By default, a random seed value is used.", - }, + IO.Int.Input( + "seed", + default=42, + min=0, + max=0xFFFFFFFFFFFFFFFF, + control_after_generate=True, + tooltip="When seed is fixed to a specific value, the model makes a best effort to provide " + "the same response for repeated requests. Deterministic output isn't guaranteed. " + "Also, changing the model or parameter settings, such as the temperature, " + "can cause variations in the response even when you use the same seed value. " + "By default, a random seed value is used.", ), - }, - "optional": { - "images": ( - IO.IMAGE, - { - "default": None, - "tooltip": "Optional image(s) to use as context for the model. To include multiple images, you can use the Batch Images node.", - }, + IO.Image.Input( + "images", + optional=True, + tooltip="Optional image(s) to use as context for the model. " + "To include multiple images, you can use the Batch Images node.", ), - "audio": ( - IO.AUDIO, - { - "tooltip": "Optional audio to use as context for the model.", - "default": None, - }, + IO.Audio.Input( + "audio", + optional=True, + tooltip="Optional audio to use as context for the model.", ), - "video": ( - IO.VIDEO, - { - "tooltip": "Optional video to use as context for the model.", - "default": None, - }, + IO.Video.Input( + "video", + optional=True, + tooltip="Optional video to use as context for the model.", ), - "files": ( - "GEMINI_INPUT_FILES", - { - "default": None, - "tooltip": "Optional file(s) to use as context for the model. Accepts inputs from the Gemini Generate Content Input Files node.", - }, + IO.Custom("GEMINI_INPUT_FILES").Input( + "files", + optional=True, + tooltip="Optional file(s) to use as context for the model. " + "Accepts inputs from the Gemini Generate Content Input Files node.", ), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } - - DESCRIPTION = "Generate text responses with Google's Gemini AI model. You can provide multiple types of inputs (text, images, audio, video) as context for generating more relevant and meaningful responses." - RETURN_TYPES = ("STRING",) - FUNCTION = "api_call" - CATEGORY = "api node/text/Gemini" - API_NODE = True - - def create_video_parts(self, video_input: IO.VIDEO, **kwargs) -> list[GeminiPart]: - """ - Convert video input to Gemini API compatible parts. - - Args: - video_input: Video tensor from ComfyUI. - **kwargs: Additional arguments to pass to the conversion function. - - Returns: - List of GeminiPart objects containing the encoded video. - """ - - base_64_string = video_to_base64_string( - video_input, - container_format=VideoContainer.MP4, - codec=VideoCodec.H264 + ], + outputs=[ + IO.String.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, ) + + @classmethod + def create_video_parts(cls, video_input: Input.Video) -> list[GeminiPart]: + """Convert video input to Gemini API compatible parts.""" + + base_64_string = video_to_base64_string(video_input, container_format=VideoContainer.MP4, codec=VideoCodec.H264) return [ GeminiPart( inlineData=GeminiInlineData( @@ -324,7 +232,8 @@ class GeminiNode(ComfyNodeABC): ) ] - def create_audio_parts(self, audio_input: IO.AUDIO) -> list[GeminiPart]: + @classmethod + def create_audio_parts(cls, audio_input: Input.Audio) -> list[GeminiPart]: """ Convert audio input to Gemini API compatible parts. @@ -337,10 +246,10 @@ class GeminiNode(ComfyNodeABC): audio_parts: list[GeminiPart] = [] for batch_index in range(audio_input["waveform"].shape[0]): # Recreate an IO.AUDIO object for the given batch dimension index - audio_at_index = { - "waveform": audio_input["waveform"][batch_index].unsqueeze(0), - "sample_rate": audio_input["sample_rate"], - } + audio_at_index = Input.Audio( + waveform=audio_input["waveform"][batch_index].unsqueeze(0), + sample_rate=audio_input["sample_rate"], + ) # Convert to MP3 format for compatibility with Gemini API audio_bytes = audio_to_base64_string( audio_at_index, @@ -357,38 +266,38 @@ class GeminiNode(ComfyNodeABC): ) return audio_parts - async def api_call( - self, + @classmethod + async def execute( + cls, prompt: str, - model: GeminiModel, - images: Optional[IO.IMAGE] = None, - audio: Optional[IO.AUDIO] = None, - video: Optional[IO.VIDEO] = None, + model: str, + seed: int, + images: Optional[torch.Tensor] = None, + audio: Optional[Input.Audio] = None, + video: Optional[Input.Video] = None, files: Optional[list[GeminiPart]] = None, - unique_id: Optional[str] = None, - **kwargs, - ) -> tuple[str]: - # Validate inputs + ) -> IO.NodeOutput: validate_string(prompt, strip_whitespace=False) # Create parts list with text prompt as the first part - parts: list[GeminiPart] = [create_text_part(prompt)] + parts: list[GeminiPart] = [GeminiPart(text=prompt)] # Add other modal parts if images is not None: image_parts = create_image_parts(images) parts.extend(image_parts) if audio is not None: - parts.extend(self.create_audio_parts(audio)) + parts.extend(cls.create_audio_parts(audio)) if video is not None: - parts.extend(self.create_video_parts(video)) + parts.extend(cls.create_video_parts(video)) if files is not None: parts.extend(files) # Create response - response = await SynchronousOperation( - endpoint=get_gemini_endpoint(model), - request=GeminiGenerateContentRequest( + response = await sync_op( + cls, + endpoint=ApiEndpoint(path=f"{GEMINI_BASE_ENDPOINT}/{model}", method="POST"), + data=GeminiGenerateContentRequest( contents=[ GeminiContent( role="user", @@ -396,15 +305,15 @@ class GeminiNode(ComfyNodeABC): ) ] ), - auth_kwargs=kwargs, - ).execute() + response_model=GeminiGenerateContentResponse, + ) # Get result output output_text = get_text_from_response(response) - if unique_id and output_text: + if output_text: # Not a true chat history like the OpenAI Chat node. It is emulated so the frontend can show a copy button. render_spec = { - "node_id": unique_id, + "node_id": cls.hidden.unique_id, "component": "ChatHistoryWidget", "props": { "history": json.dumps( @@ -424,10 +333,10 @@ class GeminiNode(ComfyNodeABC): render_spec, ) - return (output_text or "Empty response from Gemini model...",) + return IO.NodeOutput(output_text or "Empty response from Gemini model...") -class GeminiInputFiles(ComfyNodeABC): +class GeminiInputFiles(IO.ComfyNode): """ Loads and formats input files for use with the Gemini API. @@ -438,7 +347,7 @@ class GeminiInputFiles(ComfyNodeABC): """ @classmethod - def INPUT_TYPES(cls) -> InputTypeDict: + def define_schema(cls): """ For details about the supported file input types, see: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/inference @@ -453,39 +362,37 @@ class GeminiInputFiles(ComfyNodeABC): ] input_files = sorted(input_files, key=lambda x: x.name) input_files = [f.name for f in input_files] - return { - "required": { - "file": ( - IO.COMBO, - { - "tooltip": "Input files to include as context for the model. Only accepts text (.txt) and PDF (.pdf) files for now.", - "options": input_files, - "default": input_files[0] if input_files else None, - }, + return IO.Schema( + node_id="GeminiInputFiles", + display_name="Gemini Input Files", + category="api node/text/Gemini", + description="Loads and prepares input files to include as inputs for Gemini LLM nodes. " + "The files will be read by the Gemini model when generating a response. " + "The contents of the text file count toward the token limit. " + "🛈 TIP: Can be chained together with other Gemini Input File nodes.", + inputs=[ + IO.Combo.Input( + "file", + options=input_files, + default=input_files[0] if input_files else None, + tooltip="Input files to include as context for the model. " + "Only accepts text (.txt) and PDF (.pdf) files for now.", ), - }, - "optional": { - "GEMINI_INPUT_FILES": ( + IO.Custom("GEMINI_INPUT_FILES").Input( "GEMINI_INPUT_FILES", - { - "tooltip": "An optional additional file(s) to batch together with the file loaded from this node. Allows chaining of input files so that a single message can include multiple input files.", - "default": None, - }, + optional=True, + tooltip="An optional additional file(s) to batch together with the file loaded from this node. " + "Allows chaining of input files so that a single message can include multiple input files.", ), - }, - } - - DESCRIPTION = "Loads and prepares input files to include as inputs for Gemini LLM nodes. The files will be read by the Gemini model when generating a response. The contents of the text file count toward the token limit. 🛈 TIP: Can be chained together with other Gemini Input File nodes." - RETURN_TYPES = ("GEMINI_INPUT_FILES",) - FUNCTION = "prepare_files" - CATEGORY = "api node/text/Gemini" - - def create_file_part(self, file_path: str) -> GeminiPart: - mime_type = ( - GeminiMimeType.application_pdf - if file_path.endswith(".pdf") - else GeminiMimeType.text_plain + ], + outputs=[ + IO.Custom("GEMINI_INPUT_FILES").Output(), + ], ) + + @classmethod + def create_file_part(cls, file_path: str) -> GeminiPart: + mime_type = GeminiMimeType.application_pdf if file_path.endswith(".pdf") else GeminiMimeType.text_plain # Use base64 string directly, not the data URI with open(file_path, "rb") as f: file_content = f.read() @@ -498,120 +405,95 @@ class GeminiInputFiles(ComfyNodeABC): ) ) - def prepare_files( - self, file: str, GEMINI_INPUT_FILES: list[GeminiPart] = [] - ) -> tuple[list[GeminiPart]]: - """ - Loads and formats input files for Gemini API. - """ - file_path = folder_paths.get_annotated_filepath(file) - input_file_content = self.create_file_part(file_path) - files = [input_file_content] + GEMINI_INPUT_FILES - return (files,) - - -class GeminiImage(ComfyNodeABC): - """ - Node to generate text and image responses from a Gemini model. - - This node allows users to interact with Google's Gemini AI models, providing - multimodal inputs (text, images, files) to generate coherent - text and image responses. The node works with the latest Gemini models, handling the - API communication and response parsing. - """ @classmethod - def INPUT_TYPES(cls) -> InputTypeDict: - return { - "required": { - "prompt": ( - IO.STRING, - { - "multiline": True, - "default": "", - "tooltip": "Text prompt for generation", - }, - ), - "model": ( - IO.COMBO, - { - "tooltip": "The Gemini model to use for generating responses.", - "options": [model.value for model in GeminiImageModel], - "default": GeminiImageModel.gemini_2_5_flash_image.value, - }, - ), - "seed": ( - IO.INT, - { - "default": 42, - "min": 0, - "max": 0xFFFFFFFFFFFFFFFF, - "control_after_generate": True, - "tooltip": "When seed is fixed to a specific value, the model makes a best effort to provide the same response for repeated requests. Deterministic output isn't guaranteed. Also, changing the model or parameter settings, such as the temperature, can cause variations in the response even when you use the same seed value. By default, a random seed value is used.", - }, - ), - }, - "optional": { - "images": ( - IO.IMAGE, - { - "default": None, - "tooltip": "Optional image(s) to use as context for the model. To include multiple images, you can use the Batch Images node.", - }, - ), - "files": ( - "GEMINI_INPUT_FILES", - { - "default": None, - "tooltip": "Optional file(s) to use as context for the model. Accepts inputs from the Gemini Generate Content Input Files node.", - }, - ), - # TODO: later we can add this parameter later - # "n": ( - # IO.INT, - # { - # "default": 1, - # "min": 1, - # "max": 8, - # "step": 1, - # "display": "number", - # "tooltip": "How many images to generate", - # }, - # ), - "aspect_ratio": ( - IO.COMBO, - { - "tooltip": "Defaults to matching the output image size to that of your input image, or otherwise generates 1:1 squares.", - "options": ["auto", "1:1", "2:3", "3:2", "3:4", "4:3", "4:5", "5:4", "9:16", "16:9", "21:9"], - "default": "auto", - }, - ), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } + def execute(cls, file: str, GEMINI_INPUT_FILES: Optional[list[GeminiPart]] = None) -> IO.NodeOutput: + """Loads and formats input files for Gemini API.""" + if GEMINI_INPUT_FILES is None: + GEMINI_INPUT_FILES = [] + file_path = folder_paths.get_annotated_filepath(file) + input_file_content = cls.create_file_part(file_path) + return IO.NodeOutput([input_file_content] + GEMINI_INPUT_FILES) - RETURN_TYPES = (IO.IMAGE, IO.STRING) - FUNCTION = "api_call" - CATEGORY = "api node/image/Gemini" - DESCRIPTION = "Edit images synchronously via Google API." - API_NODE = True - async def api_call( - self, +class GeminiImage(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="GeminiImageNode", + display_name="Google Gemini Image", + category="api node/image/Gemini", + description="Edit images synchronously via Google API.", + inputs=[ + IO.String.Input( + "prompt", + multiline=True, + tooltip="Text prompt for generation", + default="", + ), + IO.Combo.Input( + "model", + options=GeminiImageModel, + default=GeminiImageModel.gemini_2_5_flash_image, + tooltip="The Gemini model to use for generating responses.", + ), + IO.Int.Input( + "seed", + default=42, + min=0, + max=0xFFFFFFFFFFFFFFFF, + control_after_generate=True, + tooltip="When seed is fixed to a specific value, the model makes a best effort to provide " + "the same response for repeated requests. Deterministic output isn't guaranteed. " + "Also, changing the model or parameter settings, such as the temperature, " + "can cause variations in the response even when you use the same seed value. " + "By default, a random seed value is used.", + ), + IO.Image.Input( + "images", + optional=True, + tooltip="Optional image(s) to use as context for the model. " + "To include multiple images, you can use the Batch Images node.", + ), + IO.Custom("GEMINI_INPUT_FILES").Input( + "files", + optional=True, + tooltip="Optional file(s) to use as context for the model. " + "Accepts inputs from the Gemini Generate Content Input Files node.", + ), + IO.Combo.Input( + "aspect_ratio", + options=["auto", "1:1", "2:3", "3:2", "3:4", "4:3", "4:5", "5:4", "9:16", "16:9", "21:9"], + default="auto", + tooltip="Defaults to matching the output image size to that of your input image, " + "or otherwise generates 1:1 squares.", + optional=True, + ), + ], + outputs=[ + IO.Image.Output(), + IO.String.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, prompt: str, - model: GeminiImageModel, - images: Optional[IO.IMAGE] = None, + model: str, + seed: int, + images: Optional[torch.Tensor] = None, files: Optional[list[GeminiPart]] = None, - n=1, aspect_ratio: str = "auto", - unique_id: Optional[str] = None, - **kwargs, - ): + ) -> IO.NodeOutput: validate_string(prompt, strip_whitespace=True, min_length=1) - parts: list[GeminiPart] = [create_text_part(prompt)] + parts: list[GeminiPart] = [GeminiPart(text=prompt)] if not aspect_ratio: aspect_ratio = "auto" # for backward compatability with old workflows; to-do remove this in December @@ -623,29 +505,27 @@ class GeminiImage(ComfyNodeABC): if files is not None: parts.extend(files) - response = await SynchronousOperation( - endpoint=get_gemini_image_endpoint(model), - request=GeminiImageGenerateContentRequest( + response = await sync_op( + cls, + endpoint=ApiEndpoint(path=f"{GEMINI_BASE_ENDPOINT}/{model}", method="POST"), + data=GeminiImageGenerateContentRequest( contents=[ - GeminiContent( - role="user", - parts=parts, - ), + GeminiContent(role="user", parts=parts), ], generationConfig=GeminiImageGenerationConfig( - responseModalities=["TEXT","IMAGE"], + responseModalities=["TEXT", "IMAGE"], imageConfig=None if aspect_ratio == "auto" else image_config, - ) + ), ), - auth_kwargs=kwargs, - ).execute() + response_model=GeminiGenerateContentResponse, + ) output_image = get_image_from_response(response) output_text = get_text_from_response(response) - if unique_id and output_text: + if output_text: # Not a true chat history like the OpenAI Chat node. It is emulated so the frontend can show a copy button. render_spec = { - "node_id": unique_id, + "node_id": cls.hidden.unique_id, "component": "ChatHistoryWidget", "props": { "history": json.dumps( @@ -666,17 +546,18 @@ class GeminiImage(ComfyNodeABC): ) output_text = output_text or "Empty response from Gemini model..." - return (output_image, output_text,) + return IO.NodeOutput(output_image, output_text) -NODE_CLASS_MAPPINGS = { - "GeminiNode": GeminiNode, - "GeminiImageNode": GeminiImage, - "GeminiInputFiles": GeminiInputFiles, -} +class GeminiExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[IO.ComfyNode]]: + return [ + GeminiNode, + GeminiImage, + GeminiInputFiles, + ] -NODE_DISPLAY_NAME_MAPPINGS = { - "GeminiNode": "Google Gemini", - "GeminiImageNode": "Google Gemini Image", - "GeminiInputFiles": "Gemini Input Files", -} + +async def comfy_entrypoint() -> GeminiExtension: + return GeminiExtension() diff --git a/comfy_api_nodes/util/__init__.py b/comfy_api_nodes/util/__init__.py index ab96760cb..0cca2b59b 100644 --- a/comfy_api_nodes/util/__init__.py +++ b/comfy_api_nodes/util/__init__.py @@ -18,6 +18,7 @@ from .conversions import ( tensor_to_bytesio, tensor_to_pil, trim_video, + video_to_base64_string, ) from .download_helpers import ( download_url_as_bytesio, @@ -73,6 +74,7 @@ __all__ = [ "tensor_to_bytesio", "tensor_to_pil", "trim_video", + "video_to_base64_string", # Validation utilities "get_number_of_images", "validate_aspect_ratio_closeness", diff --git a/comfy_api_nodes/util/conversions.py b/comfy_api_nodes/util/conversions.py index 10cd1051b..9f4c90c5c 100644 --- a/comfy_api_nodes/util/conversions.py +++ b/comfy_api_nodes/util/conversions.py @@ -12,6 +12,7 @@ from PIL import Image from comfy.utils import common_upscale from comfy_api.latest import Input, InputImpl +from comfy_api.util import VideoContainer, VideoCodec from ._helpers import mimetype_to_extension @@ -173,6 +174,30 @@ def audio_to_base64_string(audio: Input.Audio, container_format: str = "mp4", co return base64.b64encode(audio_bytes).decode("utf-8") +def video_to_base64_string( + video: Input.Video, + container_format: VideoContainer = None, + codec: VideoCodec = None +) -> str: + """ + Converts a video input to a base64 string. + + Args: + video: The video input to convert + container_format: Optional container format to use (defaults to video.container if available) + codec: Optional codec to use (defaults to video.codec if available) + """ + video_bytes_io = BytesIO() + + # Use provided format/codec if specified, otherwise use video's own if available + format_to_use = container_format if container_format is not None else getattr(video, 'container', VideoContainer.MP4) + codec_to_use = codec if codec is not None else getattr(video, 'codec', VideoCodec.H264) + + video.save_to(video_bytes_io, format=format_to_use, codec=codec_to_use) + video_bytes_io.seek(0) + return base64.b64encode(video_bytes_io.getvalue()).decode("utf-8") + + def audio_ndarray_to_bytesio( audio_data_np: np.ndarray, sample_rate: int, From 098a352f136c610071bcb74f13e5b0ca16e6e7b3 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sat, 25 Oct 2025 17:05:22 -0700 Subject: [PATCH 011/231] Add warning for torch-directml usage (#10482) Added a warning message about the state of torch-directml. --- comfy/model_management.py | 1 + 1 file changed, 1 insertion(+) diff --git a/comfy/model_management.py b/comfy/model_management.py index cf015a29a..afe78f36e 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -89,6 +89,7 @@ if args.deterministic: directml_enabled = False if args.directml is not None: + logging.warning("WARNING: torch-directml barely works, is very slow, has not been updated in over 1 year and might be removed soon, please don't use it, there are better options.") import torch_directml directml_enabled = True device_index = args.directml From f6bbc1ac846b7d9a73ae50c3a45cf5a41058c54d Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sat, 25 Oct 2025 20:07:29 -0700 Subject: [PATCH 012/231] Fix mistake. (#10484) --- comfy/sample.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/sample.py b/comfy/sample.py index b1395da84..2f8f3a51c 100644 --- a/comfy/sample.py +++ b/comfy/sample.py @@ -17,7 +17,7 @@ def prepare_noise_inner(latent_image, generator, noise_inds=None): if i in unique_inds: noises.append(noise) noises = [noises[i] for i in inverse] - noises = torch.cat(noises, axis=0) + return torch.cat(noises, axis=0) def prepare_noise(latent_image, seed, noise_inds=None): """ From 9d529e53084bdec28f684f3886a26c93598e7338 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Sun, 26 Oct 2025 08:51:06 +0200 Subject: [PATCH 013/231] fix(api-nodes): random issues on Windows by capturing general OSError for retries (#10486) --- comfy_api_nodes/util/client.py | 15 +++++---------- comfy_api_nodes/util/download_helpers.py | 6 +++--- comfy_api_nodes/util/upload_helpers.py | 4 ++-- 3 files changed, 10 insertions(+), 15 deletions(-) diff --git a/comfy_api_nodes/util/client.py b/comfy_api_nodes/util/client.py index 5833b118f..9c036d64b 100644 --- a/comfy_api_nodes/util/client.py +++ b/comfy_api_nodes/util/client.py @@ -2,7 +2,6 @@ import asyncio import contextlib import json import logging -import socket import time import uuid from dataclasses import dataclass @@ -456,24 +455,20 @@ async def _diagnose_connectivity() -> dict[str, bool]: results = { "internet_accessible": False, "api_accessible": False, - "is_local_issue": False, - "is_api_issue": False, } timeout = aiohttp.ClientTimeout(total=5.0) async with aiohttp.ClientSession(timeout=timeout) as session: - try: + with contextlib.suppress(ClientError, OSError): async with session.get("https://www.google.com") as resp: results["internet_accessible"] = resp.status < 500 - except (ClientError, asyncio.TimeoutError, socket.gaierror): - results["is_local_issue"] = True + if not results["internet_accessible"]: return results parsed = urlparse(default_base_url()) health_url = f"{parsed.scheme}://{parsed.netloc}/health" - with contextlib.suppress(ClientError, asyncio.TimeoutError): + with contextlib.suppress(ClientError, OSError): async with session.get(health_url) as resp: results["api_accessible"] = resp.status < 500 - results["is_api_issue"] = results["internet_accessible"] and not results["api_accessible"] return results @@ -790,7 +785,7 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool): except ProcessingInterrupted: logging.debug("Polling was interrupted by user") raise - except (ClientError, asyncio.TimeoutError, socket.gaierror) as e: + except (ClientError, OSError) as e: if attempt <= cfg.max_retries: logging.warning( "Connection error calling %s %s. Retrying in %.2fs (%d/%d): %s", @@ -824,7 +819,7 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool): delay *= cfg.retry_backoff continue diag = await _diagnose_connectivity() - if diag.get("is_local_issue"): + if not diag["internet_accessible"]: try: request_logger.log_request_response( operation_id=operation_id, diff --git a/comfy_api_nodes/util/download_helpers.py b/comfy_api_nodes/util/download_helpers.py index 791dd5a50..f89045e12 100644 --- a/comfy_api_nodes/util/download_helpers.py +++ b/comfy_api_nodes/util/download_helpers.py @@ -32,7 +32,7 @@ async def download_url_to_bytesio( dest: Optional[Union[BytesIO, IO[bytes], str, Path]], *, timeout: Optional[float] = None, - max_retries: int = 3, + max_retries: int = 5, retry_delay: float = 1.0, retry_backoff: float = 2.0, cls: type[COMFY_IO.ComfyNode] = None, @@ -177,7 +177,7 @@ async def download_url_to_bytesio( return except asyncio.CancelledError: raise ProcessingInterrupted("Task cancelled") from None - except (ClientError, asyncio.TimeoutError) as e: + except (ClientError, OSError) as e: if attempt <= max_retries: with contextlib.suppress(Exception): request_logger.log_request_response( @@ -191,7 +191,7 @@ async def download_url_to_bytesio( continue diag = await _diagnose_connectivity() - if diag.get("is_local_issue"): + if not diag["internet_accessible"]: raise LocalNetworkError( "Unable to connect to the network. Please check your internet connection and try again." ) from e diff --git a/comfy_api_nodes/util/upload_helpers.py b/comfy_api_nodes/util/upload_helpers.py index a345d451d..7bfc61704 100644 --- a/comfy_api_nodes/util/upload_helpers.py +++ b/comfy_api_nodes/util/upload_helpers.py @@ -290,7 +290,7 @@ async def upload_file( return except asyncio.CancelledError: raise ProcessingInterrupted("Task cancelled") from None - except (aiohttp.ClientError, asyncio.TimeoutError) as e: + except (aiohttp.ClientError, OSError) as e: if attempt <= max_retries: with contextlib.suppress(Exception): request_logger.log_request_response( @@ -313,7 +313,7 @@ async def upload_file( continue diag = await _diagnose_connectivity() - if diag.get("is_local_issue"): + if not diag["internet_accessible"]: raise LocalNetworkError( "Unable to connect to the network. Please check your internet connection and try again." ) from e From c170fd2db598a0bdce56f80e22e83e10ad731421 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sun, 26 Oct 2025 17:23:01 -0700 Subject: [PATCH 014/231] Bump portable deps workflow to torch cu130 python 3.13.9 (#10493) --- .github/workflows/windows_release_dependencies.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/windows_release_dependencies.yml b/.github/workflows/windows_release_dependencies.yml index f1e2946e6..f61ee21a2 100644 --- a/.github/workflows/windows_release_dependencies.yml +++ b/.github/workflows/windows_release_dependencies.yml @@ -17,7 +17,7 @@ on: description: 'cuda version' required: true type: string - default: "129" + default: "130" python_minor: description: 'python minor version' @@ -29,7 +29,7 @@ on: description: 'python patch version' required: true type: string - default: "6" + default: "9" # push: # branches: # - master From 601ee1775a3c06c9b4de1fa7d808af8625b2fcd5 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 27 Oct 2025 20:54:00 -0700 Subject: [PATCH 015/231] Add a bat to run comfyui portable without api nodes. (#10504) --- .../advanced/run_nvidia_gpu_disable_api_nodes.bat | 2 ++ 1 file changed, 2 insertions(+) create mode 100644 .ci/windows_nvidia_base_files/advanced/run_nvidia_gpu_disable_api_nodes.bat diff --git a/.ci/windows_nvidia_base_files/advanced/run_nvidia_gpu_disable_api_nodes.bat b/.ci/windows_nvidia_base_files/advanced/run_nvidia_gpu_disable_api_nodes.bat new file mode 100644 index 000000000..cfe4b9f0e --- /dev/null +++ b/.ci/windows_nvidia_base_files/advanced/run_nvidia_gpu_disable_api_nodes.bat @@ -0,0 +1,2 @@ +..\python_embeded\python.exe -s ..\ComfyUI\main.py --windows-standalone-build --disable-api-nodes +pause From c305deed56a6ed259563b2047d9fcd51471e6590 Mon Sep 17 00:00:00 2001 From: ComfyUI Wiki Date: Tue, 28 Oct 2025 13:24:16 +0800 Subject: [PATCH 016/231] Update template to 0.2.3 (#10503) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 8570c66b6..121301669 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.28.7 -comfyui-workflow-templates==0.2.2 +comfyui-workflow-templates==0.2.3 comfyui-embedded-docs==0.3.0 torch torchsde From 55bad303754eb60fa98f3ccf598e95502b819149 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Tue, 28 Oct 2025 07:25:29 +0200 Subject: [PATCH 017/231] feat(api-nodes): add LTXV API nodes (#10496) --- comfy_api_nodes/nodes_ltxv.py | 191 ++++++++++++++++++++++++++++++++++ nodes.py | 1 + 2 files changed, 192 insertions(+) create mode 100644 comfy_api_nodes/nodes_ltxv.py diff --git a/comfy_api_nodes/nodes_ltxv.py b/comfy_api_nodes/nodes_ltxv.py new file mode 100644 index 000000000..e6ad6e27a --- /dev/null +++ b/comfy_api_nodes/nodes_ltxv.py @@ -0,0 +1,191 @@ +from io import BytesIO +from typing import Optional + +import torch +from pydantic import BaseModel, Field +from typing_extensions import override + +from comfy_api.input_impl import VideoFromFile +from comfy_api.latest import IO, ComfyExtension +from comfy_api_nodes.util import ( + ApiEndpoint, + get_number_of_images, + sync_op_raw, + upload_images_to_comfyapi, + validate_string, +) + +MODELS_MAP = { + "LTX-2 (Pro)": "ltx-2-pro", + "LTX-2 (Fast)": "ltx-2-fast", +} + + +class ExecuteTaskRequest(BaseModel): + prompt: str = Field(...) + model: str = Field(...) + duration: int = Field(...) + resolution: str = Field(...) + fps: Optional[int] = Field(25) + generate_audio: Optional[bool] = Field(True) + image_uri: Optional[str] = Field(None) + + +class TextToVideoNode(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="LtxvApiTextToVideo", + display_name="LTXV Text To Video", + category="api node/video/LTXV", + description="Professional-quality videos with customizable duration and resolution.", + inputs=[ + IO.Combo.Input("model", options=list(MODELS_MAP.keys())), + IO.String.Input( + "prompt", + multiline=True, + default="", + ), + IO.Combo.Input("duration", options=[6, 8, 10], default=8), + IO.Combo.Input( + "resolution", + options=[ + "1920x1080", + "2560x1440", + "3840x2160", + ], + ), + IO.Combo.Input("fps", options=[25, 50], default=25), + IO.Boolean.Input( + "generate_audio", + default=False, + optional=True, + tooltip="When true, the generated video will include AI-generated audio matching the scene.", + ), + ], + outputs=[ + IO.Video.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, + model: str, + prompt: str, + duration: int, + resolution: str, + fps: int = 25, + generate_audio: bool = False, + ) -> IO.NodeOutput: + validate_string(prompt, min_length=1, max_length=10000) + response = await sync_op_raw( + cls, + ApiEndpoint("/proxy/ltx/v1/text-to-video", "POST"), + data=ExecuteTaskRequest( + prompt=prompt, + model=MODELS_MAP[model], + duration=duration, + resolution=resolution, + fps=fps, + generate_audio=generate_audio, + ), + as_binary=True, + max_retries=1, + ) + return IO.NodeOutput(VideoFromFile(BytesIO(response))) + + +class ImageToVideoNode(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="LtxvApiImageToVideo", + display_name="LTXV Image To Video", + category="api node/video/LTXV", + description="Professional-quality videos with customizable duration and resolution based on start image.", + inputs=[ + IO.Image.Input("image", tooltip="First frame to be used for the video."), + IO.Combo.Input("model", options=list(MODELS_MAP.keys())), + IO.String.Input( + "prompt", + multiline=True, + default="", + ), + IO.Combo.Input("duration", options=[6, 8, 10], default=8), + IO.Combo.Input( + "resolution", + options=[ + "1920x1080", + "2560x1440", + "3840x2160", + ], + ), + IO.Combo.Input("fps", options=[25, 50], default=25), + IO.Boolean.Input( + "generate_audio", + default=False, + optional=True, + tooltip="When true, the generated video will include AI-generated audio matching the scene.", + ), + ], + outputs=[ + IO.Video.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, + image: torch.Tensor, + model: str, + prompt: str, + duration: int, + resolution: str, + fps: int = 25, + generate_audio: bool = False, + ) -> IO.NodeOutput: + validate_string(prompt, min_length=1, max_length=10000) + if get_number_of_images(image) != 1: + raise ValueError("Currently only one input image is supported.") + response = await sync_op_raw( + cls, + ApiEndpoint("/proxy/ltx/v1/image-to-video", "POST"), + data=ExecuteTaskRequest( + image_uri=(await upload_images_to_comfyapi(cls, image, max_images=1, mime_type="image/png"))[0], + prompt=prompt, + model=MODELS_MAP[model], + duration=duration, + resolution=resolution, + fps=fps, + generate_audio=generate_audio, + ), + as_binary=True, + max_retries=1, + ) + return IO.NodeOutput(VideoFromFile(BytesIO(response))) + + +class LtxvApiExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[IO.ComfyNode]]: + return [ + TextToVideoNode, + ImageToVideoNode, + ] + + +async def comfy_entrypoint() -> LtxvApiExtension: + return LtxvApiExtension() diff --git a/nodes.py b/nodes.py index 7cfa8ca14..12e365ca9 100644 --- a/nodes.py +++ b/nodes.py @@ -2349,6 +2349,7 @@ async def init_builtin_api_nodes(): "nodes_kling.py", "nodes_bfl.py", "nodes_bytedance.py", + "nodes_ltxv.py", "nodes_luma.py", "nodes_recraft.py", "nodes_pixverse.py", From 6abc30aae9bd13f31dafd32552a365f2df2cf715 Mon Sep 17 00:00:00 2001 From: ComfyUI Wiki Date: Tue, 28 Oct 2025 13:56:30 +0800 Subject: [PATCH 018/231] Update template to 0.2.4 (#10505) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 121301669..cc3d4ca94 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.28.7 -comfyui-workflow-templates==0.2.3 +comfyui-workflow-templates==0.2.4 comfyui-embedded-docs==0.3.0 torch torchsde From 614b8d3345424481d94a22fe7496d908c1a5c526 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Tue, 28 Oct 2025 00:01:13 -0700 Subject: [PATCH 019/231] frontend bump to 1.28.8 (#10506) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index cc3d4ca94..4d84b0d3e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -comfyui-frontend-package==1.28.7 +comfyui-frontend-package==1.28.8 comfyui-workflow-templates==0.2.4 comfyui-embedded-docs==0.3.0 torch From f2bb3230b796f6a486894fc3b597db2c0b9538c9 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 28 Oct 2025 03:03:59 -0400 Subject: [PATCH 020/231] ComfyUI version v0.3.67 --- comfyui_version.py | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/comfyui_version.py b/comfyui_version.py index 33a06bbb0..db48b05c4 100644 --- a/comfyui_version.py +++ b/comfyui_version.py @@ -1,3 +1,3 @@ # This file is automatically generated by the build process when version is # updated in pyproject.toml. -__version__ = "0.3.66" +__version__ = "0.3.67" diff --git a/pyproject.toml b/pyproject.toml index fcc4854a5..ab054355c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ComfyUI" -version = "0.3.66" +version = "0.3.67" readme = "README.md" license = { file = "LICENSE" } requires-python = ">=3.9" From b61a40cbc9c2eb648b4d22bb513ed3ab2e2f0fd7 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 28 Oct 2025 00:21:45 -0700 Subject: [PATCH 021/231] Bump stable portable to cu130 python 3.13.9 (#10508) --- .github/workflows/release-stable-all.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/release-stable-all.yml b/.github/workflows/release-stable-all.yml index 5c1024599..7dca7277b 100644 --- a/.github/workflows/release-stable-all.yml +++ b/.github/workflows/release-stable-all.yml @@ -18,9 +18,9 @@ jobs: uses: ./.github/workflows/stable-release.yml with: git_tag: ${{ inputs.git_tag }} - cache_tag: "cu129" + cache_tag: "cu130" python_minor: "13" - python_patch: "6" + python_patch: "9" rel_name: "nvidia" rel_extra_name: "" test_release: true From 8cf2ba4ba64203551276513068ee81145e90f0bc Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 28 Oct 2025 00:23:52 -0700 Subject: [PATCH 022/231] Remove comfy api key from queue api. (#10502) --- execution.py | 8 +++----- main.py | 11 +++++++++-- server.py | 11 ++++++++--- 3 files changed, 20 insertions(+), 10 deletions(-) diff --git a/execution.py b/execution.py index 78c36a4b0..b14bb14c7 100644 --- a/execution.py +++ b/execution.py @@ -1116,7 +1116,7 @@ class PromptQueue: messages: List[str] def task_done(self, item_id, history_result, - status: Optional['PromptQueue.ExecutionStatus']): + status: Optional['PromptQueue.ExecutionStatus'], process_item=None): with self.mutex: prompt = self.currently_running.pop(item_id) if len(self.history) > MAXIMUM_HISTORY_SIZE: @@ -1126,10 +1126,8 @@ class PromptQueue: if status is not None: status_dict = copy.deepcopy(status._asdict()) - # Remove sensitive data from extra_data before storing in history - for sensitive_val in SENSITIVE_EXTRA_DATA_KEYS: - if sensitive_val in prompt[3]: - prompt[3].pop(sensitive_val) + if process_item is not None: + prompt = process_item(prompt) self.history[prompt[1]] = { "prompt": prompt, diff --git a/main.py b/main.py index 4b4c5dcc4..8d466d2eb 100644 --- a/main.py +++ b/main.py @@ -192,14 +192,21 @@ def prompt_worker(q, server_instance): prompt_id = item[1] server_instance.last_prompt_id = prompt_id - e.execute(item[2], prompt_id, item[3], item[4]) + sensitive = item[5] + extra_data = item[3].copy() + for k in sensitive: + extra_data[k] = sensitive[k] + + e.execute(item[2], prompt_id, extra_data, item[4]) need_gc = True + + remove_sensitive = lambda prompt: prompt[:5] + prompt[6:] q.task_done(item_id, e.history_result, status=execution.PromptQueue.ExecutionStatus( status_str='success' if e.success else 'error', completed=e.success, - messages=e.status_messages)) + messages=e.status_messages), process_item=remove_sensitive) if server_instance.client_id is not None: server_instance.send_sync("executing", {"node": None, "prompt_id": prompt_id}, server_instance.client_id) diff --git a/server.py b/server.py index fe58db286..5d773b10a 100644 --- a/server.py +++ b/server.py @@ -691,8 +691,9 @@ class PromptServer(): async def get_queue(request): queue_info = {} current_queue = self.prompt_queue.get_current_queue_volatile() - queue_info['queue_running'] = current_queue[0] - queue_info['queue_pending'] = current_queue[1] + remove_sensitive = lambda queue: [x[:5] for x in queue] + queue_info['queue_running'] = remove_sensitive(current_queue[0]) + queue_info['queue_pending'] = remove_sensitive(current_queue[1]) return web.json_response(queue_info) @routes.post("/prompt") @@ -728,7 +729,11 @@ class PromptServer(): extra_data["client_id"] = json_data["client_id"] if valid[0]: outputs_to_execute = valid[2] - self.prompt_queue.put((number, prompt_id, prompt, extra_data, outputs_to_execute)) + sensitive = {} + for sensitive_val in execution.SENSITIVE_EXTRA_DATA_KEYS: + if sensitive_val in extra_data: + sensitive[sensitive_val] = extra_data.pop(sensitive_val) + self.prompt_queue.put((number, prompt_id, prompt, extra_data, outputs_to_execute, sensitive)) response = {"prompt_id": prompt_id, "number": number, "node_errors": valid[3]} return web.json_response(response) else: From 3bea4efc6b23d76c6b0672cd90421a9024e13fdb Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 28 Oct 2025 01:45:45 -0700 Subject: [PATCH 023/231] Tell users to update nvidia drivers if problem with portable. (#10510) --- README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/README.md b/README.md index 434d4ff06..4204777e9 100644 --- a/README.md +++ b/README.md @@ -176,6 +176,8 @@ Simply download, extract with [7-Zip](https://7-zip.org) and run. Make sure you If you have trouble extracting it, right click the file -> properties -> unblock +Update your Nvidia drivers if it doesn't start. + #### Alternative Downloads: [Experimental portable for AMD GPUs](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_amd.7z) From 22e40d2ace0f53da025b3a41cbe4b664ef807097 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 28 Oct 2025 12:08:08 -0700 Subject: [PATCH 024/231] Tell users to update their nvidia drivers if portable doesn't start. (#10518) --- .../advanced/run_nvidia_gpu_disable_api_nodes.bat | 1 + .ci/windows_nvidia_base_files/run_nvidia_gpu.bat | 1 + .../run_nvidia_gpu_fast_fp16_accumulation.bat | 1 + 3 files changed, 3 insertions(+) diff --git a/.ci/windows_nvidia_base_files/advanced/run_nvidia_gpu_disable_api_nodes.bat b/.ci/windows_nvidia_base_files/advanced/run_nvidia_gpu_disable_api_nodes.bat index cfe4b9f0e..ed00583b6 100644 --- a/.ci/windows_nvidia_base_files/advanced/run_nvidia_gpu_disable_api_nodes.bat +++ b/.ci/windows_nvidia_base_files/advanced/run_nvidia_gpu_disable_api_nodes.bat @@ -1,2 +1,3 @@ ..\python_embeded\python.exe -s ..\ComfyUI\main.py --windows-standalone-build --disable-api-nodes +echo If you see this and ComfyUI did not start try updating your Nvidia Drivers to the latest. pause diff --git a/.ci/windows_nvidia_base_files/run_nvidia_gpu.bat b/.ci/windows_nvidia_base_files/run_nvidia_gpu.bat index 274d7c948..4898a424f 100755 --- a/.ci/windows_nvidia_base_files/run_nvidia_gpu.bat +++ b/.ci/windows_nvidia_base_files/run_nvidia_gpu.bat @@ -1,2 +1,3 @@ .\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build +echo If you see this and ComfyUI did not start try updating your Nvidia Drivers to the latest. pause diff --git a/.ci/windows_nvidia_base_files/run_nvidia_gpu_fast_fp16_accumulation.bat b/.ci/windows_nvidia_base_files/run_nvidia_gpu_fast_fp16_accumulation.bat index 38f06ecb2..32611e4af 100644 --- a/.ci/windows_nvidia_base_files/run_nvidia_gpu_fast_fp16_accumulation.bat +++ b/.ci/windows_nvidia_base_files/run_nvidia_gpu_fast_fp16_accumulation.bat @@ -1,2 +1,3 @@ .\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build --fast fp16_accumulation +echo If you see this and ComfyUI did not start try updating your Nvidia Drivers to the latest. pause From 8817f8fc148c5a63ffd3f854975df8e72c740540 Mon Sep 17 00:00:00 2001 From: contentis Date: Tue, 28 Oct 2025 21:20:53 +0100 Subject: [PATCH 025/231] Mixed Precision Quantization System (#10498) * Implement mixed precision operations with a registry design and metadate for quant spec in checkpoint. * Updated design using Tensor Subclasses * Fix FP8 MM * An actually functional POC * Remove CK reference and ensure correct compute dtype * Update unit tests * ruff lint * Implement mixed precision operations with a registry design and metadate for quant spec in checkpoint. * Updated design using Tensor Subclasses * Fix FP8 MM * An actually functional POC * Remove CK reference and ensure correct compute dtype * Update unit tests * ruff lint * Fix missing keys * Rename quant dtype parameter * Rename quant dtype parameter * Fix unittests for CPU build --- comfy/model_base.py | 10 +- comfy/model_detection.py | 20 + comfy/ops.py | 146 +++++- comfy/quant_ops.py | 437 ++++++++++++++++++ comfy/sd.py | 13 +- comfy/supported_models_base.py | 1 + .../comfy_quant/test_mixed_precision.py | 232 ++++++++++ tests-unit/comfy_quant/test_quant_registry.py | 190 ++++++++ 8 files changed, 1030 insertions(+), 19 deletions(-) create mode 100644 comfy/quant_ops.py create mode 100644 tests-unit/comfy_quant/test_mixed_precision.py create mode 100644 tests-unit/comfy_quant/test_quant_registry.py diff --git a/comfy/model_base.py b/comfy/model_base.py index e877f19ac..7c788d085 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -134,7 +134,7 @@ class BaseModel(torch.nn.Module): if not unet_config.get("disable_unet_model_creation", False): if model_config.custom_operations is None: fp8 = model_config.optimizations.get("fp8", False) - operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, scaled_fp8=model_config.scaled_fp8) + operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, scaled_fp8=model_config.scaled_fp8, model_config=model_config) else: operations = model_config.custom_operations self.diffusion_model = unet_model(**unet_config, device=device, operations=operations) @@ -333,6 +333,14 @@ class BaseModel(torch.nn.Module): if self.model_config.scaled_fp8 is not None: unet_state_dict["scaled_fp8"] = torch.tensor([], dtype=self.model_config.scaled_fp8) + # Save mixed precision metadata + if hasattr(self.model_config, 'layer_quant_config') and self.model_config.layer_quant_config: + metadata = { + "format_version": "1.0", + "layers": self.model_config.layer_quant_config + } + unet_state_dict["_quantization_metadata"] = metadata + unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict) if self.model_type == ModelType.V_PREDICTION: diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 141f1e164..3142a7fc3 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -6,6 +6,20 @@ import math import logging import torch + +def detect_layer_quantization(metadata): + quant_key = "_quantization_metadata" + if metadata is not None and quant_key in metadata: + quant_metadata = metadata.pop(quant_key) + quant_metadata = json.loads(quant_metadata) + if isinstance(quant_metadata, dict) and "layers" in quant_metadata: + logging.info(f"Found quantization metadata (version {quant_metadata.get('format_version', 'unknown')})") + return quant_metadata["layers"] + else: + raise ValueError("Invalid quantization metadata format") + return None + + def count_blocks(state_dict_keys, prefix_string): count = 0 while True: @@ -701,6 +715,12 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal else: model_config.optimizations["fp8"] = True + # Detect per-layer quantization (mixed precision) + layer_quant_config = detect_layer_quantization(metadata) + if layer_quant_config: + model_config.layer_quant_config = layer_quant_config + logging.info(f"Detected mixed precision quantization: {len(layer_quant_config)} layers quantized") + return model_config def unet_prefix_from_state_dict(state_dict): diff --git a/comfy/ops.py b/comfy/ops.py index 934e21261..93731eedf 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -344,6 +344,10 @@ class manual_cast(disable_weight_init): def fp8_linear(self, input): + """ + Legacy FP8 linear function for backward compatibility. + Uses QuantizedTensor subclass for dispatch. + """ dtype = self.weight.dtype if dtype not in [torch.float8_e4m3fn]: return None @@ -355,9 +359,9 @@ def fp8_linear(self, input): input_shape = input.shape input_dtype = input.dtype + if len(input.shape) == 3: w, bias = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype) - w = w.t() scale_weight = self.scale_weight scale_input = self.scale_input @@ -368,23 +372,18 @@ def fp8_linear(self, input): if scale_input is None: scale_input = torch.ones((), device=input.device, dtype=torch.float32) - input = torch.clamp(input, min=-448, max=448, out=input) - input = input.reshape(-1, input_shape[2]).to(dtype).contiguous() else: scale_input = scale_input.to(input.device) - input = (input * (1.0 / scale_input).to(input_dtype)).reshape(-1, input_shape[2]).to(dtype).contiguous() - if bias is not None: - o = torch._scaled_mm(input, w, out_dtype=input_dtype, bias=bias, scale_a=scale_input, scale_b=scale_weight) - else: - o = torch._scaled_mm(input, w, out_dtype=input_dtype, scale_a=scale_input, scale_b=scale_weight) - - if isinstance(o, tuple): - o = o[0] + # Wrap weight in QuantizedTensor - this enables unified dispatch + # Call F.linear - __torch_dispatch__ routes to fp8_linear handler in quant_ops.py! + layout_params_weight = {'scale': scale_weight, 'orig_dtype': input_dtype} + quantized_weight = QuantizedTensor(w, TensorCoreFP8Layout, layout_params_weight) + quantized_input = QuantizedTensor.from_float(input.reshape(-1, input_shape[2]), TensorCoreFP8Layout, scale=scale_input, dtype=dtype) + o = torch.nn.functional.linear(quantized_input, quantized_weight, bias) if tensor_2d: return o.reshape(input_shape[0], -1) - return o.reshape((-1, input_shape[1], self.weight.shape[0])) return None @@ -478,7 +477,128 @@ if CUBLAS_IS_AVAILABLE: def forward(self, *args, **kwargs): return super().forward(*args, **kwargs) -def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None): + +# ============================================================================== +# Mixed Precision Operations +# ============================================================================== +from .quant_ops import QuantizedTensor, TensorCoreFP8Layout + +QUANT_FORMAT_MIXINS = { + "float8_e4m3fn": { + "dtype": torch.float8_e4m3fn, + "layout_type": TensorCoreFP8Layout, + "parameters": { + "weight_scale": torch.nn.Parameter(torch.zeros((), dtype=torch.float32), requires_grad=False), + "input_scale": torch.nn.Parameter(torch.zeros((), dtype=torch.float32), requires_grad=False), + } + } +} + +class MixedPrecisionOps(disable_weight_init): + _layer_quant_config = {} + _compute_dtype = torch.bfloat16 + + class Linear(torch.nn.Module, CastWeightBiasOp): + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + device=None, + dtype=None, + ) -> None: + super().__init__() + + self.factory_kwargs = {"device": device, "dtype": MixedPrecisionOps._compute_dtype} + # self.factory_kwargs = {"device": device, "dtype": dtype} + + self.in_features = in_features + self.out_features = out_features + if bias: + self.bias = torch.nn.Parameter(torch.empty(out_features, **self.factory_kwargs)) + else: + self.register_parameter("bias", None) + + self.tensor_class = None + + def reset_parameters(self): + return None + + def _load_from_state_dict(self, state_dict, prefix, local_metadata, + strict, missing_keys, unexpected_keys, error_msgs): + + device = self.factory_kwargs["device"] + layer_name = prefix.rstrip('.') + weight_key = f"{prefix}weight" + weight = state_dict.pop(weight_key, None) + if weight is None: + raise ValueError(f"Missing weight for layer {layer_name}") + + manually_loaded_keys = [weight_key] + + if layer_name not in MixedPrecisionOps._layer_quant_config: + self.weight = torch.nn.Parameter(weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), requires_grad=False) + else: + quant_format = MixedPrecisionOps._layer_quant_config[layer_name].get("format", None) + if quant_format is None: + raise ValueError(f"Unknown quantization format for layer {layer_name}") + + mixin = QUANT_FORMAT_MIXINS[quant_format] + self.layout_type = mixin["layout_type"] + + scale_key = f"{prefix}weight_scale" + layout_params = { + 'scale': state_dict.pop(scale_key, None), + 'orig_dtype': MixedPrecisionOps._compute_dtype + } + if layout_params['scale'] is not None: + manually_loaded_keys.append(scale_key) + + self.weight = torch.nn.Parameter( + QuantizedTensor(weight.to(device=device, dtype=mixin["dtype"]), self.layout_type, layout_params), + requires_grad=False + ) + + for param_name, param_value in mixin["parameters"].items(): + param_key = f"{prefix}{param_name}" + _v = state_dict.pop(param_key, None) + if _v is None: + continue + setattr(self, param_name, torch.nn.Parameter(_v.to(device=device), requires_grad=False)) + manually_loaded_keys.append(param_key) + + super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) + + for key in manually_loaded_keys: + if key in missing_keys: + missing_keys.remove(key) + + def _forward(self, input, weight, bias): + return torch.nn.functional.linear(input, weight, bias) + + def forward_comfy_cast_weights(self, input): + weight, bias = cast_bias_weight(self, input) + return self._forward(input, weight, bias) + + def forward(self, input, *args, **kwargs): + run_every_op() + + if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0: + return self.forward_comfy_cast_weights(input, *args, **kwargs) + if (getattr(self, 'layout_type', None) is not None and + getattr(self, 'input_scale', None) is not None and + not isinstance(input, QuantizedTensor)): + input = QuantizedTensor.from_float(input, self.layout_type, scale=self.input_scale, fp8_dtype=self.weight.dtype) + return self._forward(input, self.weight, self.bias) + + +def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None, model_config=None): + if model_config and hasattr(model_config, 'layer_quant_config') and model_config.layer_quant_config: + MixedPrecisionOps._layer_quant_config = model_config.layer_quant_config + MixedPrecisionOps._compute_dtype = compute_dtype + logging.info(f"Using mixed precision operations: {len(model_config.layer_quant_config)} quantized layers") + return MixedPrecisionOps + fp8_compute = comfy.model_management.supports_fp8_compute(load_device) if scaled_fp8 is not None: return scaled_fp8_ops(fp8_matrix_mult=fp8_compute and fp8_optimizations, scale_input=fp8_optimizations, override_dtype=scaled_fp8) diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py new file mode 100644 index 000000000..b14e03084 --- /dev/null +++ b/comfy/quant_ops.py @@ -0,0 +1,437 @@ +import torch +import logging +from typing import Tuple, Dict + +_LAYOUT_REGISTRY = {} +_GENERIC_UTILS = {} + + +def register_layout_op(torch_op, layout_type): + """ + Decorator to register a layout-specific operation handler. + Args: + torch_op: PyTorch operation (e.g., torch.ops.aten.linear.default) + layout_type: Layout class (e.g., TensorCoreFP8Layout) + Example: + @register_layout_op(torch.ops.aten.linear.default, TensorCoreFP8Layout) + def fp8_linear(func, args, kwargs): + # FP8-specific linear implementation + ... + """ + def decorator(handler_func): + if torch_op not in _LAYOUT_REGISTRY: + _LAYOUT_REGISTRY[torch_op] = {} + _LAYOUT_REGISTRY[torch_op][layout_type] = handler_func + return handler_func + return decorator + + +def register_generic_util(torch_op): + """ + Decorator to register a generic utility that works for all layouts. + Args: + torch_op: PyTorch operation (e.g., torch.ops.aten.detach.default) + + Example: + @register_generic_util(torch.ops.aten.detach.default) + def generic_detach(func, args, kwargs): + # Works for any layout + ... + """ + def decorator(handler_func): + _GENERIC_UTILS[torch_op] = handler_func + return handler_func + return decorator + + +def _get_layout_from_args(args): + for arg in args: + if isinstance(arg, QuantizedTensor): + return arg._layout_type + elif isinstance(arg, (list, tuple)): + for item in arg: + if isinstance(item, QuantizedTensor): + return item._layout_type + return None + + +def _move_layout_params_to_device(params, device): + new_params = {} + for k, v in params.items(): + if isinstance(v, torch.Tensor): + new_params[k] = v.to(device=device) + else: + new_params[k] = v + return new_params + + +def _copy_layout_params(params): + new_params = {} + for k, v in params.items(): + if isinstance(v, torch.Tensor): + new_params[k] = v.clone() + else: + new_params[k] = v + return new_params + + +class QuantizedLayout: + """ + Base class for quantization layouts. + + A layout encapsulates the format-specific logic for quantization/dequantization + and provides a uniform interface for extracting raw tensors needed for computation. + + New quantization formats should subclass this and implement the required methods. + """ + @classmethod + def quantize(cls, tensor, **kwargs) -> Tuple[torch.Tensor, Dict]: + raise NotImplementedError(f"{cls.__name__} must implement quantize()") + + @staticmethod + def dequantize(qdata, **layout_params) -> torch.Tensor: + raise NotImplementedError("TensorLayout must implement dequantize()") + + @classmethod + def get_plain_tensors(cls, qtensor) -> torch.Tensor: + raise NotImplementedError(f"{cls.__name__} must implement get_plain_tensors()") + + +class QuantizedTensor(torch.Tensor): + """ + Universal quantized tensor that works with any layout. + + This tensor subclass uses a pluggable layout system to support multiple + quantization formats (FP8, INT4, INT8, etc.) without code duplication. + + The layout_type determines format-specific behavior, while common operations + (detach, clone, to) are handled generically. + + Attributes: + _qdata: The quantized tensor data + _layout_type: Layout class (e.g., TensorCoreFP8Layout) + _layout_params: Dict with layout-specific params (scale, zero_point, etc.) + """ + + @staticmethod + def __new__(cls, qdata, layout_type, layout_params): + """ + Create a quantized tensor. + + Args: + qdata: The quantized data tensor + layout_type: Layout class (subclass of QuantizedLayout) + layout_params: Dict with layout-specific parameters + """ + return torch.Tensor._make_subclass(cls, qdata, require_grad=False) + + def __init__(self, qdata, layout_type, layout_params): + self._qdata = qdata.contiguous() + self._layout_type = layout_type + self._layout_params = layout_params + + def __repr__(self): + layout_name = self._layout_type.__name__ + param_str = ", ".join(f"{k}={v}" for k, v in list(self._layout_params.items())[:2]) + return f"QuantizedTensor(shape={self.shape}, layout={layout_name}, {param_str})" + + @property + def layout_type(self): + return self._layout_type + + def __tensor_flatten__(self): + """ + Tensor flattening protocol for proper device movement. + """ + inner_tensors = ["_qdata"] + ctx = { + "layout_type": self._layout_type, + } + + tensor_params = {} + non_tensor_params = {} + for k, v in self._layout_params.items(): + if isinstance(v, torch.Tensor): + tensor_params[k] = v + else: + non_tensor_params[k] = v + + ctx["tensor_param_keys"] = list(tensor_params.keys()) + ctx["non_tensor_params"] = non_tensor_params + + for k, v in tensor_params.items(): + attr_name = f"_layout_param_{k}" + object.__setattr__(self, attr_name, v) + inner_tensors.append(attr_name) + + return inner_tensors, ctx + + @staticmethod + def __tensor_unflatten__(inner_tensors, ctx, outer_size, outer_stride): + """ + Tensor unflattening protocol for proper device movement. + Reconstructs the QuantizedTensor after device movement. + """ + layout_type = ctx["layout_type"] + layout_params = dict(ctx["non_tensor_params"]) + + for key in ctx["tensor_param_keys"]: + attr_name = f"_layout_param_{key}" + layout_params[key] = inner_tensors[attr_name] + + return QuantizedTensor(inner_tensors["_q_data"], layout_type, layout_params) + + @classmethod + def from_float(cls, tensor, layout_type, **quantize_kwargs) -> 'QuantizedTensor': + qdata, layout_params = layout_type.quantize(tensor, **quantize_kwargs) + return cls(qdata, layout_type, layout_params) + + def dequantize(self) -> torch.Tensor: + return self._layout_type.dequantize(self._qdata, **self._layout_params) + + @classmethod + def __torch_dispatch__(cls, func, types, args=(), kwargs=None): + kwargs = kwargs or {} + + # Step 1: Check generic utilities first (detach, clone, to, etc.) + if func in _GENERIC_UTILS: + return _GENERIC_UTILS[func](func, args, kwargs) + + # Step 2: Check layout-specific handlers (linear, matmul, etc.) + layout_type = _get_layout_from_args(args) + if layout_type and func in _LAYOUT_REGISTRY: + handler = _LAYOUT_REGISTRY[func].get(layout_type) + if handler: + return handler(func, args, kwargs) + + # Step 3: Fallback to dequantization + if isinstance(args[0] if args else None, QuantizedTensor): + logging.info(f"QuantizedTensor: Unhandled operation {func}, falling back to dequantization. kwargs={kwargs}") + return cls._dequant_and_fallback(func, args, kwargs) + + @classmethod + def _dequant_and_fallback(cls, func, args, kwargs): + def dequant_arg(arg): + if isinstance(arg, QuantizedTensor): + return arg.dequantize() + elif isinstance(arg, (list, tuple)): + return type(arg)(dequant_arg(a) for a in arg) + return arg + + new_args = dequant_arg(args) + new_kwargs = dequant_arg(kwargs) + return func(*new_args, **new_kwargs) + + +# ============================================================================== +# Generic Utilities (Layout-Agnostic Operations) +# ============================================================================== + +def _create_transformed_qtensor(qt, transform_fn): + new_data = transform_fn(qt._qdata) + new_params = _copy_layout_params(qt._layout_params) + return QuantizedTensor(new_data, qt._layout_type, new_params) + + +def _handle_device_transfer(qt, target_device, target_dtype=None, target_layout=None, op_name="to"): + if target_dtype is not None and target_dtype != qt.dtype: + logging.warning( + f"QuantizedTensor: dtype conversion requested to {target_dtype}, " + f"but not supported for quantized tensors. Ignoring dtype." + ) + + if target_layout is not None and target_layout != torch.strided: + logging.warning( + f"QuantizedTensor: layout change requested to {target_layout}, " + f"but not supported. Ignoring layout." + ) + + # Handle device transfer + current_device = qt._qdata.device + if target_device is not None: + # Normalize device for comparison + if isinstance(target_device, str): + target_device = torch.device(target_device) + if isinstance(current_device, str): + current_device = torch.device(current_device) + + if target_device != current_device: + logging.debug(f"QuantizedTensor.{op_name}: Moving from {current_device} to {target_device}") + new_q_data = qt._qdata.to(device=target_device) + new_params = _move_layout_params_to_device(qt._layout_params, target_device) + new_qt = QuantizedTensor(new_q_data, qt._layout_type, new_params) + logging.debug(f"QuantizedTensor.{op_name}: Created new tensor on {target_device}") + return new_qt + + logging.debug(f"QuantizedTensor.{op_name}: No device change needed, returning original") + return qt + + +@register_generic_util(torch.ops.aten.detach.default) +def generic_detach(func, args, kwargs): + """Detach operation - creates a detached copy of the quantized tensor.""" + qt = args[0] + if isinstance(qt, QuantizedTensor): + return _create_transformed_qtensor(qt, lambda x: x.detach()) + return func(*args, **kwargs) + + +@register_generic_util(torch.ops.aten.clone.default) +def generic_clone(func, args, kwargs): + """Clone operation - creates a deep copy of the quantized tensor.""" + qt = args[0] + if isinstance(qt, QuantizedTensor): + return _create_transformed_qtensor(qt, lambda x: x.clone()) + return func(*args, **kwargs) + + +@register_generic_util(torch.ops.aten._to_copy.default) +def generic_to_copy(func, args, kwargs): + """Device/dtype transfer operation - handles .to(device) calls.""" + qt = args[0] + if isinstance(qt, QuantizedTensor): + return _handle_device_transfer( + qt, + target_device=kwargs.get('device', None), + target_dtype=kwargs.get('dtype', None), + op_name="_to_copy" + ) + return func(*args, **kwargs) + + +@register_generic_util(torch.ops.aten.to.dtype_layout) +def generic_to_dtype_layout(func, args, kwargs): + """Handle .to(device) calls using the dtype_layout variant.""" + qt = args[0] + if isinstance(qt, QuantizedTensor): + return _handle_device_transfer( + qt, + target_device=kwargs.get('device', None), + target_dtype=kwargs.get('dtype', None), + target_layout=kwargs.get('layout', None), + op_name="to" + ) + return func(*args, **kwargs) + + +@register_generic_util(torch.ops.aten.copy_.default) +def generic_copy_(func, args, kwargs): + qt_dest = args[0] + src = args[1] + + if isinstance(qt_dest, QuantizedTensor): + if isinstance(src, QuantizedTensor): + # Copy from another quantized tensor + qt_dest._qdata.copy_(src._qdata) + qt_dest._layout_type = src._layout_type + qt_dest._layout_params = _copy_layout_params(src._layout_params) + else: + # Copy from regular tensor - just copy raw data + qt_dest._qdata.copy_(src) + return qt_dest + return func(*args, **kwargs) + + +@register_generic_util(torch.ops.aten._has_compatible_shallow_copy_type.default) +def generic_has_compatible_shallow_copy_type(func, args, kwargs): + return True + +# ============================================================================== +# FP8 Layout + Operation Handlers +# ============================================================================== +class TensorCoreFP8Layout(QuantizedLayout): + """ + Storage format: + - qdata: FP8 tensor (torch.float8_e4m3fn or torch.float8_e5m2) + - scale: Scalar tensor (float32) for dequantization + - orig_dtype: Original dtype before quantization (for casting back) + """ + @classmethod + def quantize(cls, tensor, scale=None, dtype=torch.float8_e4m3fn): + orig_dtype = tensor.dtype + + if scale is None: + scale = torch.amax(tensor.abs()) / torch.finfo(dtype).max + + if not isinstance(scale, torch.Tensor): + scale = torch.tensor(scale) + scale = scale.to(device=tensor.device, dtype=torch.float32) + + lp_amax = torch.finfo(dtype).max + tensor_scaled = tensor.float() / scale + torch.clamp(tensor_scaled, min=-lp_amax, max=lp_amax, out=tensor_scaled) + qdata = tensor_scaled.to(dtype, memory_format=torch.contiguous_format) + + layout_params = { + 'scale': scale, + 'orig_dtype': orig_dtype + } + return qdata, layout_params + + @staticmethod + def dequantize(qdata, scale, orig_dtype, **kwargs): + plain_tensor = torch.ops.aten._to_copy.default(qdata, dtype=orig_dtype) + return plain_tensor * scale + + @classmethod + def get_plain_tensors(cls, qtensor): + return qtensor._qdata, qtensor._layout_params['scale'] + + +@register_layout_op(torch.ops.aten.linear.default, TensorCoreFP8Layout) +def fp8_linear(func, args, kwargs): + input_tensor = args[0] + weight = args[1] + bias = args[2] if len(args) > 2 else None + + if isinstance(input_tensor, QuantizedTensor) and isinstance(weight, QuantizedTensor): + plain_input, scale_a = TensorCoreFP8Layout.get_plain_tensors(input_tensor) + plain_weight, scale_b = TensorCoreFP8Layout.get_plain_tensors(weight) + + out_dtype = kwargs.get("out_dtype") + if out_dtype is None: + out_dtype = input_tensor._layout_params['orig_dtype'] + + weight_t = plain_weight.t() + + tensor_2d = False + if len(plain_input.shape) == 2: + tensor_2d = True + plain_input = plain_input.unsqueeze(1) + + input_shape = plain_input.shape + if len(input_shape) != 3: + return None + + try: + output = torch._scaled_mm( + plain_input.reshape(-1, input_shape[2]), + weight_t, + bias=bias, + scale_a=scale_a, + scale_b=scale_b, + out_dtype=out_dtype, + ) + if not tensor_2d: + output = output.reshape((-1, input_shape[1], weight.shape[0])) + + if output.dtype in [torch.float8_e4m3fn, torch.float8_e5m2]: + output_scale = scale_a * scale_b + output_params = { + 'scale': output_scale, + 'orig_dtype': input_tensor._layout_params['orig_dtype'] + } + return QuantizedTensor(output, TensorCoreFP8Layout, output_params) + else: + return output + + except Exception as e: + raise RuntimeError(f"FP8 _scaled_mm failed, falling back to dequantization: {e}") + + # Case 2: DQ Fallback + if isinstance(weight, QuantizedTensor): + weight = weight.dequantize() + if isinstance(input_tensor, QuantizedTensor): + input_tensor = input_tensor.dequantize() + + return torch.nn.functional.linear(input_tensor, weight, bias) diff --git a/comfy/sd.py b/comfy/sd.py index 28bee248d..6411bb27d 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -1262,7 +1262,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c return (model_patcher, clip, vae, clipvision) -def load_diffusion_model_state_dict(sd, model_options={}): +def load_diffusion_model_state_dict(sd, model_options={}, metadata=None): """ Loads a UNet diffusion model from a state dictionary, supporting both diffusers and regular formats. @@ -1296,7 +1296,7 @@ def load_diffusion_model_state_dict(sd, model_options={}): weight_dtype = comfy.utils.weight_dtype(sd) load_device = model_management.get_torch_device() - model_config = model_detection.model_config_from_unet(sd, "") + model_config = model_detection.model_config_from_unet(sd, "", metadata=metadata) if model_config is not None: new_sd = sd @@ -1330,7 +1330,10 @@ def load_diffusion_model_state_dict(sd, model_options={}): else: unet_dtype = dtype - manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes) + if hasattr(model_config, "layer_quant_config"): + manual_cast_dtype = model_management.unet_manual_cast(None, load_device, model_config.supported_inference_dtypes) + else: + manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes) model_config.set_inference_dtype(unet_dtype, manual_cast_dtype) model_config.custom_operations = model_options.get("custom_operations", model_config.custom_operations) if model_options.get("fp8_optimizations", False): @@ -1346,8 +1349,8 @@ def load_diffusion_model_state_dict(sd, model_options={}): def load_diffusion_model(unet_path, model_options={}): - sd = comfy.utils.load_torch_file(unet_path) - model = load_diffusion_model_state_dict(sd, model_options=model_options) + sd, metadata = comfy.utils.load_torch_file(unet_path, return_metadata=True) + model = load_diffusion_model_state_dict(sd, model_options=model_options, metadata=metadata) if model is None: logging.error("ERROR UNSUPPORTED DIFFUSION MODEL {}".format(unet_path)) raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(unet_path, model_detection_error_hint(unet_path, sd))) diff --git a/comfy/supported_models_base.py b/comfy/supported_models_base.py index 54573abb1..e4bd74514 100644 --- a/comfy/supported_models_base.py +++ b/comfy/supported_models_base.py @@ -50,6 +50,7 @@ class BASE: manual_cast_dtype = None custom_operations = None scaled_fp8 = None + layer_quant_config = None # Per-layer quantization configuration for mixed precision optimizations = {"fp8": False} @classmethod diff --git a/tests-unit/comfy_quant/test_mixed_precision.py b/tests-unit/comfy_quant/test_mixed_precision.py new file mode 100644 index 000000000..267bc177b --- /dev/null +++ b/tests-unit/comfy_quant/test_mixed_precision.py @@ -0,0 +1,232 @@ +import unittest +import torch +import sys +import os + +# Add comfy to path +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..")) + +def has_gpu(): + return torch.cuda.is_available() + +from comfy.cli_args import args +if not has_gpu(): + args.cpu = True + +from comfy import ops +from comfy.quant_ops import QuantizedTensor, TensorCoreFP8Layout + + +class SimpleModel(torch.nn.Module): + def __init__(self, operations=ops.disable_weight_init): + super().__init__() + self.layer1 = operations.Linear(10, 20, device="cpu", dtype=torch.bfloat16) + self.layer2 = operations.Linear(20, 30, device="cpu", dtype=torch.bfloat16) + self.layer3 = operations.Linear(30, 40, device="cpu", dtype=torch.bfloat16) + + def forward(self, x): + x = self.layer1(x) + x = torch.nn.functional.relu(x) + x = self.layer2(x) + x = torch.nn.functional.relu(x) + x = self.layer3(x) + return x + + +class TestMixedPrecisionOps(unittest.TestCase): + + def test_all_layers_standard(self): + """Test that model with no quantization works normally""" + # Configure no quantization + ops.MixedPrecisionOps._layer_quant_config = {} + + # Create model + model = SimpleModel(operations=ops.MixedPrecisionOps) + + # Initialize weights manually + model.layer1.weight = torch.nn.Parameter(torch.randn(20, 10, dtype=torch.bfloat16)) + model.layer1.bias = torch.nn.Parameter(torch.randn(20, dtype=torch.bfloat16)) + model.layer2.weight = torch.nn.Parameter(torch.randn(30, 20, dtype=torch.bfloat16)) + model.layer2.bias = torch.nn.Parameter(torch.randn(30, dtype=torch.bfloat16)) + model.layer3.weight = torch.nn.Parameter(torch.randn(40, 30, dtype=torch.bfloat16)) + model.layer3.bias = torch.nn.Parameter(torch.randn(40, dtype=torch.bfloat16)) + + # Initialize weight_function and bias_function + for layer in [model.layer1, model.layer2, model.layer3]: + layer.weight_function = [] + layer.bias_function = [] + + # Forward pass + input_tensor = torch.randn(5, 10, dtype=torch.bfloat16) + output = model(input_tensor) + + self.assertEqual(output.shape, (5, 40)) + self.assertEqual(output.dtype, torch.bfloat16) + + def test_mixed_precision_load(self): + """Test loading a mixed precision model from state dict""" + # Configure mixed precision: layer1 is FP8, layer2 and layer3 are standard + layer_quant_config = { + "layer1": { + "format": "float8_e4m3fn", + "params": {} + }, + "layer3": { + "format": "float8_e4m3fn", + "params": {} + } + } + ops.MixedPrecisionOps._layer_quant_config = layer_quant_config + + # Create state dict with mixed precision + fp8_weight1 = torch.randn(20, 10, dtype=torch.float32).to(torch.float8_e4m3fn) + fp8_weight3 = torch.randn(40, 30, dtype=torch.float32).to(torch.float8_e4m3fn) + + state_dict = { + # Layer 1: FP8 E4M3FN + "layer1.weight": fp8_weight1, + "layer1.bias": torch.randn(20, dtype=torch.bfloat16), + "layer1.weight_scale": torch.tensor(2.0, dtype=torch.float32), + + # Layer 2: Standard BF16 + "layer2.weight": torch.randn(30, 20, dtype=torch.bfloat16), + "layer2.bias": torch.randn(30, dtype=torch.bfloat16), + + # Layer 3: FP8 E4M3FN + "layer3.weight": fp8_weight3, + "layer3.bias": torch.randn(40, dtype=torch.bfloat16), + "layer3.weight_scale": torch.tensor(1.5, dtype=torch.float32), + } + + # Create model and load state dict (strict=False because custom loading pops keys) + model = SimpleModel(operations=ops.MixedPrecisionOps) + model.load_state_dict(state_dict, strict=False) + + # Verify weights are wrapped in QuantizedTensor + self.assertIsInstance(model.layer1.weight, QuantizedTensor) + self.assertEqual(model.layer1.weight._layout_type, TensorCoreFP8Layout) + + # Layer 2 should NOT be quantized + self.assertNotIsInstance(model.layer2.weight, QuantizedTensor) + + # Layer 3 should be quantized + self.assertIsInstance(model.layer3.weight, QuantizedTensor) + self.assertEqual(model.layer3.weight._layout_type, TensorCoreFP8Layout) + + # Verify scales were loaded + self.assertEqual(model.layer1.weight._layout_params['scale'].item(), 2.0) + self.assertEqual(model.layer3.weight._layout_params['scale'].item(), 1.5) + + # Forward pass + input_tensor = torch.randn(5, 10, dtype=torch.bfloat16) + output = model(input_tensor) + + self.assertEqual(output.shape, (5, 40)) + + def test_state_dict_quantized_preserved(self): + """Test that quantized weights are preserved in state_dict()""" + # Configure mixed precision + layer_quant_config = { + "layer1": { + "format": "float8_e4m3fn", + "params": {} + } + } + ops.MixedPrecisionOps._layer_quant_config = layer_quant_config + + # Create and load model + fp8_weight = torch.randn(20, 10, dtype=torch.float32).to(torch.float8_e4m3fn) + state_dict1 = { + "layer1.weight": fp8_weight, + "layer1.bias": torch.randn(20, dtype=torch.bfloat16), + "layer1.weight_scale": torch.tensor(3.0, dtype=torch.float32), + "layer2.weight": torch.randn(30, 20, dtype=torch.bfloat16), + "layer2.bias": torch.randn(30, dtype=torch.bfloat16), + "layer3.weight": torch.randn(40, 30, dtype=torch.bfloat16), + "layer3.bias": torch.randn(40, dtype=torch.bfloat16), + } + + model = SimpleModel(operations=ops.MixedPrecisionOps) + model.load_state_dict(state_dict1, strict=False) + + # Save state dict + state_dict2 = model.state_dict() + + # Verify layer1.weight is a QuantizedTensor with scale preserved + self.assertIsInstance(state_dict2["layer1.weight"], QuantizedTensor) + self.assertEqual(state_dict2["layer1.weight"]._layout_params['scale'].item(), 3.0) + self.assertEqual(state_dict2["layer1.weight"]._layout_type, TensorCoreFP8Layout) + + # Verify non-quantized layers are standard tensors + self.assertNotIsInstance(state_dict2["layer2.weight"], QuantizedTensor) + self.assertNotIsInstance(state_dict2["layer3.weight"], QuantizedTensor) + + def test_weight_function_compatibility(self): + """Test that weight_function (LoRA) works with quantized layers""" + # Configure FP8 quantization + layer_quant_config = { + "layer1": { + "format": "float8_e4m3fn", + "params": {} + } + } + ops.MixedPrecisionOps._layer_quant_config = layer_quant_config + + # Create and load model + fp8_weight = torch.randn(20, 10, dtype=torch.float32).to(torch.float8_e4m3fn) + state_dict = { + "layer1.weight": fp8_weight, + "layer1.bias": torch.randn(20, dtype=torch.bfloat16), + "layer1.weight_scale": torch.tensor(2.0, dtype=torch.float32), + "layer2.weight": torch.randn(30, 20, dtype=torch.bfloat16), + "layer2.bias": torch.randn(30, dtype=torch.bfloat16), + "layer3.weight": torch.randn(40, 30, dtype=torch.bfloat16), + "layer3.bias": torch.randn(40, dtype=torch.bfloat16), + } + + model = SimpleModel(operations=ops.MixedPrecisionOps) + model.load_state_dict(state_dict, strict=False) + + # Add a weight function (simulating LoRA) + # This should trigger dequantization during forward pass + def apply_lora(weight): + lora_delta = torch.randn_like(weight) * 0.01 + return weight + lora_delta + + model.layer1.weight_function.append(apply_lora) + + # Forward pass should work with LoRA (triggers weight_function path) + input_tensor = torch.randn(5, 10, dtype=torch.bfloat16) + output = model(input_tensor) + + self.assertEqual(output.shape, (5, 40)) + + def test_error_handling_unknown_format(self): + """Test that unknown formats raise error""" + # Configure with unknown format + layer_quant_config = { + "layer1": { + "format": "unknown_format_xyz", + "params": {} + } + } + ops.MixedPrecisionOps._layer_quant_config = layer_quant_config + + # Create state dict + state_dict = { + "layer1.weight": torch.randn(20, 10, dtype=torch.bfloat16), + "layer1.bias": torch.randn(20, dtype=torch.bfloat16), + "layer2.weight": torch.randn(30, 20, dtype=torch.bfloat16), + "layer2.bias": torch.randn(30, dtype=torch.bfloat16), + "layer3.weight": torch.randn(40, 30, dtype=torch.bfloat16), + "layer3.bias": torch.randn(40, dtype=torch.bfloat16), + } + + # Load should raise KeyError for unknown format in QUANT_FORMAT_MIXINS + model = SimpleModel(operations=ops.MixedPrecisionOps) + with self.assertRaises(KeyError): + model.load_state_dict(state_dict, strict=False) + +if __name__ == "__main__": + unittest.main() + diff --git a/tests-unit/comfy_quant/test_quant_registry.py b/tests-unit/comfy_quant/test_quant_registry.py new file mode 100644 index 000000000..477811029 --- /dev/null +++ b/tests-unit/comfy_quant/test_quant_registry.py @@ -0,0 +1,190 @@ +import unittest +import torch +import sys +import os + +# Add comfy to path +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..")) + +def has_gpu(): + return torch.cuda.is_available() + +from comfy.cli_args import args +if not has_gpu(): + args.cpu = True + +from comfy.quant_ops import QuantizedTensor, TensorCoreFP8Layout + + +class TestQuantizedTensor(unittest.TestCase): + """Test the QuantizedTensor subclass with FP8 layout""" + + def test_creation(self): + """Test creating a QuantizedTensor with TensorCoreFP8Layout""" + fp8_data = torch.randn(256, 128, dtype=torch.float32).to(torch.float8_e4m3fn) + scale = torch.tensor(2.0) + layout_params = {'scale': scale, 'orig_dtype': torch.bfloat16} + + qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params) + + self.assertIsInstance(qt, QuantizedTensor) + self.assertEqual(qt.shape, (256, 128)) + self.assertEqual(qt.dtype, torch.float8_e4m3fn) + self.assertEqual(qt._layout_params['scale'], scale) + self.assertEqual(qt._layout_params['orig_dtype'], torch.bfloat16) + self.assertEqual(qt._layout_type, TensorCoreFP8Layout) + + def test_dequantize(self): + """Test explicit dequantization""" + + fp8_data = torch.ones(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn) + scale = torch.tensor(3.0) + layout_params = {'scale': scale, 'orig_dtype': torch.float32} + + qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params) + dequantized = qt.dequantize() + + self.assertEqual(dequantized.dtype, torch.float32) + self.assertTrue(torch.allclose(dequantized, torch.ones(10, 20) * 3.0, rtol=0.1)) + + def test_from_float(self): + """Test creating QuantizedTensor from float tensor""" + float_tensor = torch.randn(64, 32, dtype=torch.float32) + scale = torch.tensor(1.5) + + qt = QuantizedTensor.from_float( + float_tensor, + TensorCoreFP8Layout, + scale=scale, + dtype=torch.float8_e4m3fn + ) + + self.assertIsInstance(qt, QuantizedTensor) + self.assertEqual(qt.dtype, torch.float8_e4m3fn) + self.assertEqual(qt.shape, (64, 32)) + + # Verify dequantization gives approximately original values + dequantized = qt.dequantize() + mean_rel_error = ((dequantized - float_tensor).abs() / (float_tensor.abs() + 1e-6)).mean() + self.assertLess(mean_rel_error, 0.1) + + +class TestGenericUtilities(unittest.TestCase): + """Test generic utility operations""" + + def test_detach(self): + """Test detach operation on quantized tensor""" + fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn) + scale = torch.tensor(1.5) + layout_params = {'scale': scale, 'orig_dtype': torch.float32} + qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params) + + # Detach should return a new QuantizedTensor + qt_detached = qt.detach() + + self.assertIsInstance(qt_detached, QuantizedTensor) + self.assertEqual(qt_detached.shape, qt.shape) + self.assertEqual(qt_detached._layout_type, TensorCoreFP8Layout) + + def test_clone(self): + """Test clone operation on quantized tensor""" + fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn) + scale = torch.tensor(1.5) + layout_params = {'scale': scale, 'orig_dtype': torch.float32} + qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params) + + # Clone should return a new QuantizedTensor + qt_cloned = qt.clone() + + self.assertIsInstance(qt_cloned, QuantizedTensor) + self.assertEqual(qt_cloned.shape, qt.shape) + self.assertEqual(qt_cloned._layout_type, TensorCoreFP8Layout) + + # Verify it's a deep copy + self.assertIsNot(qt_cloned._qdata, qt._qdata) + + @unittest.skipUnless(has_gpu(), "GPU not available") + def test_to_device(self): + """Test device transfer""" + fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn) + scale = torch.tensor(1.5) + layout_params = {'scale': scale, 'orig_dtype': torch.float32} + qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params) + + # Moving to same device should work (CPU to CPU) + qt_cpu = qt.to('cpu') + + self.assertIsInstance(qt_cpu, QuantizedTensor) + self.assertEqual(qt_cpu.device.type, 'cpu') + self.assertEqual(qt_cpu._layout_params['scale'].device.type, 'cpu') + + +class TestTensorCoreFP8Layout(unittest.TestCase): + """Test the TensorCoreFP8Layout implementation""" + + def test_quantize(self): + """Test quantization method""" + float_tensor = torch.randn(32, 64, dtype=torch.float32) + scale = torch.tensor(1.5) + + qdata, layout_params = TensorCoreFP8Layout.quantize( + float_tensor, + scale=scale, + dtype=torch.float8_e4m3fn + ) + + self.assertEqual(qdata.dtype, torch.float8_e4m3fn) + self.assertEqual(qdata.shape, float_tensor.shape) + self.assertIn('scale', layout_params) + self.assertIn('orig_dtype', layout_params) + self.assertEqual(layout_params['orig_dtype'], torch.float32) + + def test_dequantize(self): + """Test dequantization method""" + float_tensor = torch.ones(10, 20, dtype=torch.float32) * 3.0 + scale = torch.tensor(1.0) + + qdata, layout_params = TensorCoreFP8Layout.quantize( + float_tensor, + scale=scale, + dtype=torch.float8_e4m3fn + ) + + dequantized = TensorCoreFP8Layout.dequantize(qdata, **layout_params) + + # Should approximately match original + self.assertTrue(torch.allclose(dequantized, float_tensor, rtol=0.1, atol=0.1)) + + +class TestFallbackMechanism(unittest.TestCase): + """Test fallback for unsupported operations""" + + def test_unsupported_op_dequantizes(self): + """Test that unsupported operations fall back to dequantization""" + # Set seed for reproducibility + torch.manual_seed(42) + + # Create quantized tensor + a_fp32 = torch.randn(10, 20, dtype=torch.float32) + scale = torch.tensor(1.0) + a_q = QuantizedTensor.from_float( + a_fp32, + TensorCoreFP8Layout, + scale=scale, + dtype=torch.float8_e4m3fn + ) + + # Call an operation that doesn't have a registered handler + # For example, torch.abs + result = torch.abs(a_q) + + # Should work via fallback (dequantize → abs → return) + self.assertNotIsInstance(result, QuantizedTensor) + expected = torch.abs(a_fp32) + # FP8 introduces quantization error, so use loose tolerance + mean_error = (result - expected).abs().mean() + self.assertLess(mean_error, 0.05, f"Mean error {mean_error:.4f} is too large") + + +if __name__ == "__main__": + unittest.main() From d202c2ba7404affd58a2199aeb514b3cc48e0ef3 Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Wed, 29 Oct 2025 06:22:08 +1000 Subject: [PATCH 026/231] execution: Allow a subgraph nodes to execute multiple times (#10499) In the case of --cache-none lazy and subgraph execution can cause anything to be run multiple times per workflow. If that rerun nodes is in itself a subgraph generator, this will crash for two reasons. pending_subgraph_results[] does not cleanup entries after their use. So when a pending_subgraph_result is consumed, remove it from the list so that if the corresponding node is fully re-executed this misses lookup and it fall through to execute the node as it should. Secondly, theres is an explicit enforcement against dups in the addition of subgraphs nodes as ephemerals to the dymprompt. Remove this enforcement as the use case is now valid. --- execution.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/execution.py b/execution.py index b14bb14c7..20e106213 100644 --- a/execution.py +++ b/execution.py @@ -445,6 +445,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, resolved_outputs.append(tuple(resolved_output)) output_data = merge_result_data(resolved_outputs, class_def) output_ui = [] + del pending_subgraph_results[unique_id] has_subgraph = False else: get_progress_state().start_progress(unique_id) @@ -527,10 +528,6 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, if new_graph is None: cached_outputs.append((False, node_outputs)) else: - # Check for conflicts - for node_id in new_graph.keys(): - if dynprompt.has_node(node_id): - raise DuplicateNodeError(f"Attempt to add duplicate node {node_id}. Ensure node ids are unique and deterministic or use graph_utils.GraphBuilder.") for node_id, node_info in new_graph.items(): new_node_ids.append(node_id) display_id = node_info.get("override_display_id", unique_id) From 210f7a1ba580d57d817ca68346cb72b8d0a26ad2 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Tue, 28 Oct 2025 23:38:05 +0200 Subject: [PATCH 027/231] convert nodes_recraft.py to V3 schema (#10507) --- comfy_api_nodes/nodes_recraft.py | 1319 +++++++++++++----------------- 1 file changed, 585 insertions(+), 734 deletions(-) diff --git a/comfy_api_nodes/nodes_recraft.py b/comfy_api_nodes/nodes_recraft.py index 8ee7e55c4..dee186cd6 100644 --- a/comfy_api_nodes/nodes_recraft.py +++ b/comfy_api_nodes/nodes_recraft.py @@ -1,82 +1,71 @@ -from __future__ import annotations -from inspect import cleandoc -from typing import Optional +from io import BytesIO +from typing import Optional, Union + +import aiohttp +import torch +from PIL import UnidentifiedImageError +from typing_extensions import override + from comfy.utils import ProgressBar -from comfy_extras.nodes_images import SVG # Added -from comfy.comfy_types.node_typing import IO +from comfy_api.latest import IO, ComfyExtension +from comfy_api_nodes.apinode_utils import ( + resize_mask_to_image, +) from comfy_api_nodes.apis.recraft_api import ( - RecraftImageGenerationRequest, - RecraftImageGenerationResponse, - RecraftImageSize, - RecraftModel, - RecraftStyle, - RecraftStyleV3, RecraftColor, RecraftColorChain, RecraftControls, + RecraftImageGenerationRequest, + RecraftImageGenerationResponse, + RecraftImageSize, RecraftIO, + RecraftModel, + RecraftStyle, + RecraftStyleV3, get_v3_substyles, ) -from comfy_api_nodes.apis.client import ( +from comfy_api_nodes.util import ( ApiEndpoint, - HttpMethod, - SynchronousOperation, - EmptyRequest, + bytesio_to_image_tensor, + download_url_as_bytesio, + sync_op, + tensor_to_bytesio, + validate_string, ) -from comfy_api_nodes.apinode_utils import ( - download_url_to_bytesio, - resize_mask_to_image, -) -from comfy_api_nodes.util import validate_string, tensor_to_bytesio, bytesio_to_image_tensor -from server import PromptServer - -import torch -from io import BytesIO -from PIL import UnidentifiedImageError -import aiohttp +from comfy_extras.nodes_images import SVG async def handle_recraft_file_request( + cls: type[IO.ComfyNode], image: torch.Tensor, path: str, - mask: torch.Tensor=None, - total_pixels=4096*4096, - timeout=1024, + mask: Optional[torch.Tensor] = None, + total_pixels: int = 4096 * 4096, + timeout: int = 1024, request=None, - auth_kwargs: dict[str,str] = None, ) -> list[BytesIO]: - """ - Handle sending common Recraft file-only request to get back file bytes. - """ - if request is None: - request = EmptyRequest() + """Handle sending common Recraft file-only request to get back file bytes.""" - files = { - 'image': tensor_to_bytesio(image, total_pixels=total_pixels).read() - } + files = {"image": tensor_to_bytesio(image, total_pixels=total_pixels).read()} if mask is not None: - files['mask'] = tensor_to_bytesio(mask, total_pixels=total_pixels).read() + files["mask"] = tensor_to_bytesio(mask, total_pixels=total_pixels).read() - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path=path, - method=HttpMethod.POST, - request_model=type(request), - response_model=RecraftImageGenerationResponse, - ), - request=request, + response = await sync_op( + cls, + endpoint=ApiEndpoint(path=path, method="POST"), + response_model=RecraftImageGenerationResponse, + data=request if request else None, files=files, content_type="multipart/form-data", - auth_kwargs=auth_kwargs, multipart_parser=recraft_multipart_parser, + max_retries=1, ) - response: RecraftImageGenerationResponse = await operation.execute() all_bytesio = [] if response.image is not None: - all_bytesio.append(await download_url_to_bytesio(response.image.url, timeout=timeout)) + all_bytesio.append(await download_url_as_bytesio(response.image.url, timeout=timeout)) else: for data in response.data: - all_bytesio.append(await download_url_to_bytesio(data.url, timeout=timeout)) + all_bytesio.append(await download_url_as_bytesio(data.url, timeout=timeout)) return all_bytesio @@ -84,11 +73,11 @@ async def handle_recraft_file_request( def recraft_multipart_parser( data, parent_key=None, - formatter: callable = None, - converted_to_check: list[list] = None, + formatter: Optional[type[callable]] = None, + converted_to_check: Optional[list[list]] = None, is_list: bool = False, - return_mode: str = "formdata" # "dict" | "formdata" -) -> dict | aiohttp.FormData: + return_mode: str = "formdata", # "dict" | "formdata" +) -> Union[dict, aiohttp.FormData]: """ Formats data such that multipart/form-data will work with aiohttp library when both files and data are present. @@ -108,8 +97,8 @@ def recraft_multipart_parser( # Modification of a function that handled a different type of multipart parsing, big ups: # https://gist.github.com/kazqvaizer/4cebebe5db654a414132809f9f88067b - def handle_converted_lists(item, parent_key, lists_to_check=tuple[list]): - # if list already exists exists, just extend list with data + def handle_converted_lists(item, parent_key, lists_to_check=list[list]): + # if list already exists, just extend list with data for check_list in lists_to_check: for conv_tuple in check_list: if conv_tuple[0] == parent_key and isinstance(conv_tuple[1], list): @@ -125,7 +114,7 @@ def recraft_multipart_parser( formatter = lambda v: v # Multipart representation of value if not isinstance(data, dict): - # if list already exists exists, just extend list with data + # if list already exists, just extend list with data added = handle_converted_lists(data, parent_key, converted_to_check) if added: return {} @@ -146,7 +135,9 @@ def recraft_multipart_parser( elif isinstance(value, list): for ind, list_value in enumerate(value): iter_key = f"{current_key}[]" - converted.extend(recraft_multipart_parser(list_value, iter_key, formatter, next_check, is_list=True).items()) + converted.extend( + recraft_multipart_parser(list_value, iter_key, formatter, next_check, is_list=True).items() + ) else: converted.append((current_key, formatter(value))) @@ -166,6 +157,7 @@ class handle_recraft_image_output: """ Catch an exception related to receiving SVG data instead of image, when Infinite Style Library style_id is in use. """ + def __init__(self): pass @@ -174,243 +166,225 @@ class handle_recraft_image_output: def __exit__(self, exc_type, exc_val, exc_tb): if exc_type is not None and exc_type is UnidentifiedImageError: - raise Exception("Received output data was not an image; likely an SVG. If you used style_id, make sure it is not a Vector art style.") + raise Exception( + "Received output data was not an image; likely an SVG. " + "If you used style_id, make sure it is not a Vector art style." + ) -class RecraftColorRGBNode: - """ - Create Recraft Color by choosing specific RGB values. - """ - - RETURN_TYPES = (RecraftIO.COLOR,) - DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value - RETURN_NAMES = ("recraft_color",) - FUNCTION = "create_color" - CATEGORY = "api node/image/Recraft" +class RecraftColorRGBNode(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="RecraftColorRGB", + display_name="Recraft Color RGB", + category="api node/image/Recraft", + description="Create Recraft Color by choosing specific RGB values.", + inputs=[ + IO.Int.Input("r", default=0, min=0, max=255, tooltip="Red value of color."), + IO.Int.Input("g", default=0, min=0, max=255, tooltip="Green value of color."), + IO.Int.Input("b", default=0, min=0, max=255, tooltip="Blue value of color."), + IO.Custom(RecraftIO.COLOR).Input("recraft_color", optional=True), + ], + outputs=[ + IO.Custom(RecraftIO.COLOR).Output(display_name="recraft_color"), + ], + ) @classmethod - def INPUT_TYPES(s): - return { - "required": { - "r": (IO.INT, { - "default": 0, - "min": 0, - "max": 255, - "tooltip": "Red value of color." - }), - "g": (IO.INT, { - "default": 0, - "min": 0, - "max": 255, - "tooltip": "Green value of color." - }), - "b": (IO.INT, { - "default": 0, - "min": 0, - "max": 255, - "tooltip": "Blue value of color." - }), - }, - "optional": { - "recraft_color": (RecraftIO.COLOR,), - } - } - - def create_color(self, r: int, g: int, b: int, recraft_color: RecraftColorChain=None): + def execute(cls, r: int, g: int, b: int, recraft_color: RecraftColorChain = None) -> IO.NodeOutput: recraft_color = recraft_color.clone() if recraft_color else RecraftColorChain() recraft_color.add(RecraftColor(r, g, b)) - return (recraft_color, ) + return IO.NodeOutput(recraft_color) -class RecraftControlsNode: - """ - Create Recraft Controls for customizing Recraft generation. - """ - - RETURN_TYPES = (RecraftIO.CONTROLS,) - RETURN_NAMES = ("recraft_controls",) - DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value - FUNCTION = "create_controls" - CATEGORY = "api node/image/Recraft" +class RecraftControlsNode(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="RecraftControls", + display_name="Recraft Controls", + category="api node/image/Recraft", + description="Create Recraft Controls for customizing Recraft generation.", + inputs=[ + IO.Custom(RecraftIO.COLOR).Input("colors", optional=True), + IO.Custom(RecraftIO.COLOR).Input("background_color", optional=True), + ], + outputs=[ + IO.Custom(RecraftIO.CONTROLS).Output(display_name="recraft_controls"), + ], + ) @classmethod - def INPUT_TYPES(s): - return { - "required": { - }, - "optional": { - "colors": (RecraftIO.COLOR,), - "background_color": (RecraftIO.COLOR,), - } - } - - def create_controls(self, colors: RecraftColorChain=None, background_color: RecraftColorChain=None): - return (RecraftControls(colors=colors, background_color=background_color), ) + def execute(cls, colors: RecraftColorChain = None, background_color: RecraftColorChain = None) -> IO.NodeOutput: + return IO.NodeOutput(RecraftControls(colors=colors, background_color=background_color)) -class RecraftStyleV3RealisticImageNode: - """ - Select realistic_image style and optional substyle. - """ - - RETURN_TYPES = (RecraftIO.STYLEV3,) - RETURN_NAMES = ("recraft_style",) - DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value - FUNCTION = "create_style" - CATEGORY = "api node/image/Recraft" - +class RecraftStyleV3RealisticImageNode(IO.ComfyNode): RECRAFT_STYLE = RecraftStyleV3.realistic_image @classmethod - def INPUT_TYPES(s): - return { - "required": { - "substyle": (get_v3_substyles(s.RECRAFT_STYLE),), - } - } + def define_schema(cls): + return IO.Schema( + node_id="RecraftStyleV3RealisticImage", + display_name="Recraft Style - Realistic Image", + category="api node/image/Recraft", + description="Select realistic_image style and optional substyle.", + inputs=[ + IO.Combo.Input("substyle", options=get_v3_substyles(cls.RECRAFT_STYLE)), + ], + outputs=[ + IO.Custom(RecraftIO.STYLEV3).Output(display_name="recraft_style"), + ], + ) - def create_style(self, substyle: str): + @classmethod + def execute(cls, substyle: str) -> IO.NodeOutput: if substyle == "None": substyle = None - return (RecraftStyle(self.RECRAFT_STYLE, substyle),) + return IO.NodeOutput(RecraftStyle(cls.RECRAFT_STYLE, substyle)) class RecraftStyleV3DigitalIllustrationNode(RecraftStyleV3RealisticImageNode): - """ - Select digital_illustration style and optional substyle. - """ - RECRAFT_STYLE = RecraftStyleV3.digital_illustration + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="RecraftStyleV3DigitalIllustration", + display_name="Recraft Style - Digital Illustration", + category="api node/image/Recraft", + description="Select realistic_image style and optional substyle.", + inputs=[ + IO.Combo.Input("substyle", options=get_v3_substyles(cls.RECRAFT_STYLE)), + ], + outputs=[ + IO.Custom(RecraftIO.STYLEV3).Output(display_name="recraft_style"), + ], + ) + class RecraftStyleV3VectorIllustrationNode(RecraftStyleV3RealisticImageNode): - """ - Select vector_illustration style and optional substyle. - """ - RECRAFT_STYLE = RecraftStyleV3.vector_illustration + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="RecraftStyleV3VectorIllustrationNode", + display_name="Recraft Style - Realistic Image", + category="api node/image/Recraft", + description="Select realistic_image style and optional substyle.", + inputs=[ + IO.Combo.Input("substyle", options=get_v3_substyles(cls.RECRAFT_STYLE)), + ], + outputs=[ + IO.Custom(RecraftIO.STYLEV3).Output(display_name="recraft_style"), + ], + ) + class RecraftStyleV3LogoRasterNode(RecraftStyleV3RealisticImageNode): - """ - Select vector_illustration style and optional substyle. - """ - - @classmethod - def INPUT_TYPES(s): - return { - "required": { - "substyle": (get_v3_substyles(s.RECRAFT_STYLE, include_none=False),), - } - } - RECRAFT_STYLE = RecraftStyleV3.logo_raster + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="RecraftStyleV3LogoRaster", + display_name="Recraft Style - Logo Raster", + category="api node/image/Recraft", + description="Select realistic_image style and optional substyle.", + inputs=[ + IO.Combo.Input("substyle", options=get_v3_substyles(cls.RECRAFT_STYLE, include_none=False)), + ], + outputs=[ + IO.Custom(RecraftIO.STYLEV3).Output(display_name="recraft_style"), + ], + ) -class RecraftStyleInfiniteStyleLibrary: - """ - Select style based on preexisting UUID from Recraft's Infinite Style Library. - """ - RETURN_TYPES = (RecraftIO.STYLEV3,) - RETURN_NAMES = ("recraft_style",) - DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value - FUNCTION = "create_style" - CATEGORY = "api node/image/Recraft" +class RecraftStyleInfiniteStyleLibrary(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="RecraftStyleV3InfiniteStyleLibrary", + display_name="Recraft Style - Infinite Style Library", + category="api node/image/Recraft", + description="Select style based on preexisting UUID from Recraft's Infinite Style Library.", + inputs=[ + IO.String.Input("style_id", default="", tooltip="UUID of style from Infinite Style Library."), + ], + outputs=[ + IO.Custom(RecraftIO.STYLEV3).Output(display_name="recraft_style"), + ], + ) @classmethod - def INPUT_TYPES(s): - return { - "required": { - "style_id": (IO.STRING, { - "default": "", - "tooltip": "UUID of style from Infinite Style Library.", - }) - } - } - - def create_style(self, style_id: str): + def execute(cls, style_id: str) -> IO.NodeOutput: if not style_id: raise Exception("The style_id input cannot be empty.") - return (RecraftStyle(style_id=style_id),) + return IO.NodeOutput(RecraftStyle(style_id=style_id)) -class RecraftTextToImageNode: - """ - Generates images synchronously based on prompt and resolution. - """ - - RETURN_TYPES = (IO.IMAGE,) - DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value - FUNCTION = "api_call" - API_NODE = True - CATEGORY = "api node/image/Recraft" +class RecraftTextToImageNode(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="RecraftTextToImageNode", + display_name="Recraft Text to Image", + category="api node/image/Recraft", + description="Generates images synchronously based on prompt and resolution.", + inputs=[ + IO.String.Input("prompt", multiline=True, default="", tooltip="Prompt for the image generation."), + IO.Combo.Input( + "size", + options=[res.value for res in RecraftImageSize], + default=RecraftImageSize.res_1024x1024, + tooltip="The size of the generated image.", + ), + IO.Int.Input( + "n", + default=1, + min=1, + max=6, + tooltip="The number of images to generate.", + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=0xFFFFFFFFFFFFFFFF, + control_after_generate=True, + tooltip="Seed to determine if node should re-run; " + "actual results are nondeterministic regardless of seed.", + ), + IO.Custom(RecraftIO.STYLEV3).Input("recraft_style", optional=True), + IO.String.Input( + "negative_prompt", + default="", + force_input=True, + tooltip="An optional text description of undesired elements on an image.", + optional=True, + ), + IO.Custom(RecraftIO.CONTROLS).Input( + "recraft_controls", + tooltip="Optional additional controls over the generation via the Recraft Controls node.", + optional=True, + ), + ], + outputs=[ + IO.Image.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) @classmethod - def INPUT_TYPES(s): - return { - "required": { - "prompt": ( - IO.STRING, - { - "multiline": True, - "default": "", - "tooltip": "Prompt for the image generation.", - }, - ), - "size": ( - [res.value for res in RecraftImageSize], - { - "default": RecraftImageSize.res_1024x1024, - "tooltip": "The size of the generated image.", - }, - ), - "n": ( - IO.INT, - { - "default": 1, - "min": 1, - "max": 6, - "tooltip": "The number of images to generate.", - }, - ), - "seed": ( - IO.INT, - { - "default": 0, - "min": 0, - "max": 0xFFFFFFFFFFFFFFFF, - "control_after_generate": True, - "tooltip": "Seed to determine if node should re-run; actual results are nondeterministic regardless of seed.", - }, - ), - }, - "optional": { - "recraft_style": (RecraftIO.STYLEV3,), - "negative_prompt": ( - IO.STRING, - { - "default": "", - "forceInput": True, - "tooltip": "An optional text description of undesired elements on an image.", - }, - ), - "recraft_controls": ( - RecraftIO.CONTROLS, - { - "tooltip": "Optional additional controls over the generation via the Recraft Controls node." - }, - ), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } - - async def api_call( - self, + async def execute( + cls, prompt: str, size: str, n: int, @@ -418,9 +392,7 @@ class RecraftTextToImageNode: recraft_style: RecraftStyle = None, negative_prompt: str = None, recraft_controls: RecraftControls = None, - unique_id: Optional[str] = None, - **kwargs, - ): + ) -> IO.NodeOutput: validate_string(prompt, strip_whitespace=False, max_length=1000) default_style = RecraftStyle(RecraftStyleV3.realistic_image) if recraft_style is None: @@ -433,14 +405,11 @@ class RecraftTextToImageNode: if not negative_prompt: negative_prompt = None - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/recraft/image_generation", - method=HttpMethod.POST, - request_model=RecraftImageGenerationRequest, - response_model=RecraftImageGenerationResponse, - ), - request=RecraftImageGenerationRequest( + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/recraft/image_generation", method="POST"), + response_model=RecraftImageGenerationResponse, + data=RecraftImageGenerationRequest( prompt=prompt, negative_prompt=negative_prompt, model=RecraftModel.recraftv3, @@ -451,109 +420,83 @@ class RecraftTextToImageNode: style_id=recraft_style.style_id, controls=controls_api, ), - auth_kwargs=kwargs, + max_retries=1, ) - response: RecraftImageGenerationResponse = await operation.execute() images = [] - urls = [] for data in response.data: with handle_recraft_image_output(): - if unique_id and data.url: - urls.append(data.url) - urls_string = '\n'.join(urls) - PromptServer.instance.send_progress_text( - f"Result URL: {urls_string}", unique_id - ) - image = bytesio_to_image_tensor( - await download_url_to_bytesio(data.url, timeout=1024) - ) + image = bytesio_to_image_tensor(await download_url_as_bytesio(data.url, timeout=1024)) if len(image.shape) < 4: image = image.unsqueeze(0) images.append(image) - output_image = torch.cat(images, dim=0) - return (output_image,) + return IO.NodeOutput(torch.cat(images, dim=0)) -class RecraftImageToImageNode: - """ - Modify image based on prompt and strength. - """ - - RETURN_TYPES = (IO.IMAGE,) - DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value - FUNCTION = "api_call" - API_NODE = True - CATEGORY = "api node/image/Recraft" +class RecraftImageToImageNode(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="RecraftImageToImageNode", + display_name="Recraft Image to Image", + category="api node/image/Recraft", + description="Modify image based on prompt and strength.", + inputs=[ + IO.Image.Input("image"), + IO.String.Input("prompt", multiline=True, default="", tooltip="Prompt for the image generation."), + IO.Int.Input( + "n", + default=1, + min=1, + max=6, + tooltip="The number of images to generate.", + ), + IO.Float.Input( + "strength", + default=0.5, + min=0.0, + max=1.0, + step=0.01, + tooltip="Defines the difference with the original image, should lie in [0, 1], " + "where 0 means almost identical, and 1 means miserable similarity.", + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=0xFFFFFFFFFFFFFFFF, + control_after_generate=True, + tooltip="Seed to determine if node should re-run; " + "actual results are nondeterministic regardless of seed.", + ), + IO.Custom(RecraftIO.STYLEV3).Input("recraft_style", optional=True), + IO.String.Input( + "negative_prompt", + default="", + force_input=True, + tooltip="An optional text description of undesired elements on an image.", + optional=True, + ), + IO.Custom(RecraftIO.CONTROLS).Input( + "recraft_controls", + tooltip="Optional additional controls over the generation via the Recraft Controls node.", + optional=True, + ), + ], + outputs=[ + IO.Image.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) @classmethod - def INPUT_TYPES(s): - return { - "required": { - "image": (IO.IMAGE, ), - "prompt": ( - IO.STRING, - { - "multiline": True, - "default": "", - "tooltip": "Prompt for the image generation.", - }, - ), - "n": ( - IO.INT, - { - "default": 1, - "min": 1, - "max": 6, - "tooltip": "The number of images to generate.", - }, - ), - "strength": ( - IO.FLOAT, - { - "default": 0.5, - "min": 0.0, - "max": 1.0, - "step": 0.01, - "tooltip": "Defines the difference with the original image, should lie in [0, 1], where 0 means almost identical, and 1 means miserable similarity." - } - ), - "seed": ( - IO.INT, - { - "default": 0, - "min": 0, - "max": 0xFFFFFFFFFFFFFFFF, - "control_after_generate": True, - "tooltip": "Seed to determine if node should re-run; actual results are nondeterministic regardless of seed.", - }, - ), - }, - "optional": { - "recraft_style": (RecraftIO.STYLEV3,), - "negative_prompt": ( - IO.STRING, - { - "default": "", - "forceInput": True, - "tooltip": "An optional text description of undesired elements on an image.", - }, - ), - "recraft_controls": ( - RecraftIO.CONTROLS, - { - "tooltip": "Optional additional controls over the generation via the Recraft Controls node." - }, - ), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - }, - } - - async def api_call( - self, + async def execute( + cls, image: torch.Tensor, prompt: str, n: int, @@ -562,8 +505,7 @@ class RecraftImageToImageNode: recraft_style: RecraftStyle = None, negative_prompt: str = None, recraft_controls: RecraftControls = None, - **kwargs, - ): + ) -> IO.NodeOutput: validate_string(prompt, strip_whitespace=False, max_length=1000) default_style = RecraftStyle(RecraftStyleV3.realistic_image) if recraft_style is None: @@ -593,83 +535,69 @@ class RecraftImageToImageNode: pbar = ProgressBar(total) for i in range(total): sub_bytes = await handle_recraft_file_request( + cls, image=image[i], path="/proxy/recraft/images/imageToImage", request=request, - auth_kwargs=kwargs, ) with handle_recraft_image_output(): images.append(torch.cat([bytesio_to_image_tensor(x) for x in sub_bytes], dim=0)) pbar.update(1) - images_tensor = torch.cat(images, dim=0) - return (images_tensor, ) + return IO.NodeOutput(torch.cat(images, dim=0)) -class RecraftImageInpaintingNode: - """ - Modify image based on prompt and mask. - """ - - RETURN_TYPES = (IO.IMAGE,) - DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value - FUNCTION = "api_call" - API_NODE = True - CATEGORY = "api node/image/Recraft" +class RecraftImageInpaintingNode(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="RecraftImageInpaintingNode", + display_name="Recraft Image Inpainting", + category="api node/image/Recraft", + description="Modify image based on prompt and mask.", + inputs=[ + IO.Image.Input("image"), + IO.Mask.Input("mask"), + IO.String.Input("prompt", multiline=True, default="", tooltip="Prompt for the image generation."), + IO.Int.Input( + "n", + default=1, + min=1, + max=6, + tooltip="The number of images to generate.", + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=0xFFFFFFFFFFFFFFFF, + control_after_generate=True, + tooltip="Seed to determine if node should re-run; " + "actual results are nondeterministic regardless of seed.", + ), + IO.Custom(RecraftIO.STYLEV3).Input("recraft_style", optional=True), + IO.String.Input( + "negative_prompt", + default="", + force_input=True, + tooltip="An optional text description of undesired elements on an image.", + optional=True, + ), + ], + outputs=[ + IO.Image.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) @classmethod - def INPUT_TYPES(s): - return { - "required": { - "image": (IO.IMAGE, ), - "mask": (IO.MASK, ), - "prompt": ( - IO.STRING, - { - "multiline": True, - "default": "", - "tooltip": "Prompt for the image generation.", - }, - ), - "n": ( - IO.INT, - { - "default": 1, - "min": 1, - "max": 6, - "tooltip": "The number of images to generate.", - }, - ), - "seed": ( - IO.INT, - { - "default": 0, - "min": 0, - "max": 0xFFFFFFFFFFFFFFFF, - "control_after_generate": True, - "tooltip": "Seed to determine if node should re-run; actual results are nondeterministic regardless of seed.", - }, - ), - }, - "optional": { - "recraft_style": (RecraftIO.STYLEV3,), - "negative_prompt": ( - IO.STRING, - { - "default": "", - "forceInput": True, - "tooltip": "An optional text description of undesired elements on an image.", - }, - ), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - }, - } - - async def api_call( - self, + async def execute( + cls, image: torch.Tensor, mask: torch.Tensor, prompt: str, @@ -677,8 +605,7 @@ class RecraftImageInpaintingNode: seed, recraft_style: RecraftStyle = None, negative_prompt: str = None, - **kwargs, - ): + ) -> IO.NodeOutput: validate_string(prompt, strip_whitespace=False, max_length=1000) default_style = RecraftStyle(RecraftStyleV3.realistic_image) if recraft_style is None: @@ -705,96 +632,73 @@ class RecraftImageInpaintingNode: pbar = ProgressBar(total) for i in range(total): sub_bytes = await handle_recraft_file_request( + cls, image=image[i], - mask=mask[i:i+1], + mask=mask[i : i + 1], path="/proxy/recraft/images/inpaint", request=request, - auth_kwargs=kwargs, ) with handle_recraft_image_output(): images.append(torch.cat([bytesio_to_image_tensor(x) for x in sub_bytes], dim=0)) pbar.update(1) - images_tensor = torch.cat(images, dim=0) - return (images_tensor, ) + return IO.NodeOutput(torch.cat(images, dim=0)) -class RecraftTextToVectorNode: - """ - Generates SVG synchronously based on prompt and resolution. - """ - - RETURN_TYPES = ("SVG",) # Changed - DESCRIPTION = cleandoc(__doc__ or "") if 'cleandoc' in globals() else __doc__ # Keep cleandoc if other nodes use it - FUNCTION = "api_call" - API_NODE = True - CATEGORY = "api node/image/Recraft" +class RecraftTextToVectorNode(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="RecraftTextToVectorNode", + display_name="Recraft Text to Vector", + category="api node/image/Recraft", + description="Generates SVG synchronously based on prompt and resolution.", + inputs=[ + IO.String.Input("prompt", default="", tooltip="Prompt for the image generation.", multiline=True), + IO.Combo.Input("substyle", options=get_v3_substyles(RecraftStyleV3.vector_illustration)), + IO.Combo.Input( + "size", + options=[res.value for res in RecraftImageSize], + default=RecraftImageSize.res_1024x1024, + tooltip="The size of the generated image.", + ), + IO.Int.Input("n", default=1, min=1, max=6, tooltip="The number of images to generate."), + IO.Int.Input( + "seed", + default=0, + min=0, + max=0xFFFFFFFFFFFFFFFF, + control_after_generate=True, + tooltip="Seed to determine if node should re-run; " + "actual results are nondeterministic regardless of seed.", + ), + IO.String.Input( + "negative_prompt", + default="", + force_input=True, + tooltip="An optional text description of undesired elements on an image.", + optional=True, + ), + IO.Custom(RecraftIO.CONTROLS).Input( + "recraft_controls", + tooltip="Optional additional controls over the generation via the Recraft Controls node.", + optional=True, + ), + ], + outputs=[ + IO.SVG.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) @classmethod - def INPUT_TYPES(s): - return { - "required": { - "prompt": ( - IO.STRING, - { - "multiline": True, - "default": "", - "tooltip": "Prompt for the image generation.", - }, - ), - "substyle": (get_v3_substyles(RecraftStyleV3.vector_illustration),), - "size": ( - [res.value for res in RecraftImageSize], - { - "default": RecraftImageSize.res_1024x1024, - "tooltip": "The size of the generated image.", - }, - ), - "n": ( - IO.INT, - { - "default": 1, - "min": 1, - "max": 6, - "tooltip": "The number of images to generate.", - }, - ), - "seed": ( - IO.INT, - { - "default": 0, - "min": 0, - "max": 0xFFFFFFFFFFFFFFFF, - "control_after_generate": True, - "tooltip": "Seed to determine if node should re-run; actual results are nondeterministic regardless of seed.", - }, - ), - }, - "optional": { - "negative_prompt": ( - IO.STRING, - { - "default": "", - "forceInput": True, - "tooltip": "An optional text description of undesired elements on an image.", - }, - ), - "recraft_controls": ( - RecraftIO.CONTROLS, - { - "tooltip": "Optional additional controls over the generation via the Recraft Controls node." - }, - ), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } - - async def api_call( - self, + async def execute( + cls, prompt: str, substyle: str, size: str, @@ -802,9 +706,7 @@ class RecraftTextToVectorNode: seed, negative_prompt: str = None, recraft_controls: RecraftControls = None, - unique_id: Optional[str] = None, - **kwargs, - ): + ) -> IO.NodeOutput: validate_string(prompt, strip_whitespace=False, max_length=1000) # create RecraftStyle so strings will be formatted properly (i.e. "None" will become None) recraft_style = RecraftStyle(RecraftStyleV3.vector_illustration, substyle=substyle) @@ -816,14 +718,11 @@ class RecraftTextToVectorNode: if not negative_prompt: negative_prompt = None - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/recraft/image_generation", - method=HttpMethod.POST, - request_model=RecraftImageGenerationRequest, - response_model=RecraftImageGenerationResponse, - ), - request=RecraftImageGenerationRequest( + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/recraft/image_generation", method="POST"), + response_model=RecraftImageGenerationResponse, + data=RecraftImageGenerationRequest( prompt=prompt, negative_prompt=negative_prompt, model=RecraftModel.recraftv3, @@ -833,139 +732,105 @@ class RecraftTextToVectorNode: substyle=recraft_style.substyle, controls=controls_api, ), - auth_kwargs=kwargs, + max_retries=1, ) - response: RecraftImageGenerationResponse = await operation.execute() svg_data = [] - urls = [] for data in response.data: - if unique_id and data.url: - urls.append(data.url) - # Print result on each iteration in case of error - PromptServer.instance.send_progress_text( - f"Result URL: {' '.join(urls)}", unique_id - ) - svg_data.append(await download_url_to_bytesio(data.url, timeout=1024)) + svg_data.append(await download_url_as_bytesio(data.url, timeout=1024)) - return (SVG(svg_data),) + return IO.NodeOutput(SVG(svg_data)) -class RecraftVectorizeImageNode: - """ - Generates SVG synchronously from an input image. - """ - - RETURN_TYPES = ("SVG",) # Changed - DESCRIPTION = cleandoc(__doc__ or "") if 'cleandoc' in globals() else __doc__ # Keep cleandoc if other nodes use it - FUNCTION = "api_call" - API_NODE = True - CATEGORY = "api node/image/Recraft" +class RecraftVectorizeImageNode(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="RecraftVectorizeImageNode", + display_name="Recraft Vectorize Image", + category="api node/image/Recraft", + description="Generates SVG synchronously from an input image.", + inputs=[ + IO.Image.Input("image"), + ], + outputs=[ + IO.SVG.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) @classmethod - def INPUT_TYPES(s): - return { - "required": { - "image": (IO.IMAGE, ), - }, - "optional": { - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - }, - } - - async def api_call( - self, - image: torch.Tensor, - **kwargs, - ): + async def execute(cls, image: torch.Tensor) -> IO.NodeOutput: svgs = [] total = image.shape[0] pbar = ProgressBar(total) for i in range(total): sub_bytes = await handle_recraft_file_request( + cls, image=image[i], path="/proxy/recraft/images/vectorize", - auth_kwargs=kwargs, ) svgs.append(SVG(sub_bytes)) pbar.update(1) - return (SVG.combine_all(svgs), ) + return IO.NodeOutput(SVG.combine_all(svgs)) -class RecraftReplaceBackgroundNode: - """ - Replace background on image, based on provided prompt. - """ - - RETURN_TYPES = (IO.IMAGE,) - DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value - FUNCTION = "api_call" - API_NODE = True - CATEGORY = "api node/image/Recraft" +class RecraftReplaceBackgroundNode(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="RecraftReplaceBackgroundNode", + display_name="Recraft Replace Background", + category="api node/image/Recraft", + description="Replace background on image, based on provided prompt.", + inputs=[ + IO.Image.Input("image"), + IO.String.Input("prompt", tooltip="Prompt for the image generation.", default="", multiline=True), + IO.Int.Input("n", default=1, min=1, max=6, tooltip="The number of images to generate."), + IO.Int.Input( + "seed", + default=0, + min=0, + max=0xFFFFFFFFFFFFFFFF, + control_after_generate=True, + tooltip="Seed to determine if node should re-run; " + "actual results are nondeterministic regardless of seed.", + ), + IO.Custom(RecraftIO.STYLEV3).Input("recraft_style", optional=True), + IO.String.Input( + "negative_prompt", + default="", + force_input=True, + tooltip="An optional text description of undesired elements on an image.", + optional=True, + ), + ], + outputs=[ + IO.Image.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) @classmethod - def INPUT_TYPES(s): - return { - "required": { - "image": (IO.IMAGE, ), - "prompt": ( - IO.STRING, - { - "multiline": True, - "default": "", - "tooltip": "Prompt for the image generation.", - }, - ), - "n": ( - IO.INT, - { - "default": 1, - "min": 1, - "max": 6, - "tooltip": "The number of images to generate.", - }, - ), - "seed": ( - IO.INT, - { - "default": 0, - "min": 0, - "max": 0xFFFFFFFFFFFFFFFF, - "control_after_generate": True, - "tooltip": "Seed to determine if node should re-run; actual results are nondeterministic regardless of seed.", - }, - ), - }, - "optional": { - "recraft_style": (RecraftIO.STYLEV3,), - "negative_prompt": ( - IO.STRING, - { - "default": "", - "forceInput": True, - "tooltip": "An optional text description of undesired elements on an image.", - }, - ), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - }, - } - - async def api_call( - self, + async def execute( + cls, image: torch.Tensor, prompt: str, n: int, seed, recraft_style: RecraftStyle = None, negative_prompt: str = None, - **kwargs, - ): + ) -> IO.NodeOutput: default_style = RecraftStyle(RecraftStyleV3.realistic_image) if recraft_style is None: recraft_style = default_style @@ -988,165 +853,151 @@ class RecraftReplaceBackgroundNode: pbar = ProgressBar(total) for i in range(total): sub_bytes = await handle_recraft_file_request( + cls, image=image[i], path="/proxy/recraft/images/replaceBackground", request=request, - auth_kwargs=kwargs, ) images.append(torch.cat([bytesio_to_image_tensor(x) for x in sub_bytes], dim=0)) pbar.update(1) - images_tensor = torch.cat(images, dim=0) - return (images_tensor, ) + return IO.NodeOutput(torch.cat(images, dim=0)) -class RecraftRemoveBackgroundNode: - """ - Remove background from image, and return processed image and mask. - """ - - RETURN_TYPES = (IO.IMAGE, IO.MASK) - DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value - FUNCTION = "api_call" - API_NODE = True - CATEGORY = "api node/image/Recraft" +class RecraftRemoveBackgroundNode(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="RecraftRemoveBackgroundNode", + display_name="Recraft Remove Background", + category="api node/image/Recraft", + description="Remove background from image, and return processed image and mask.", + inputs=[ + IO.Image.Input("image"), + ], + outputs=[ + IO.Image.Output(), + IO.Mask.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) @classmethod - def INPUT_TYPES(s): - return { - "required": { - "image": (IO.IMAGE, ), - }, - "optional": { - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - }, - } - - async def api_call( - self, - image: torch.Tensor, - **kwargs, - ): + async def execute(cls, image: torch.Tensor) -> IO.NodeOutput: images = [] total = image.shape[0] pbar = ProgressBar(total) for i in range(total): sub_bytes = await handle_recraft_file_request( + cls, image=image[i], path="/proxy/recraft/images/removeBackground", - auth_kwargs=kwargs, ) images.append(torch.cat([bytesio_to_image_tensor(x) for x in sub_bytes], dim=0)) pbar.update(1) images_tensor = torch.cat(images, dim=0) # use alpha channel as masks, in B,H,W format - masks_tensor = images_tensor[:,:,:,-1:].squeeze(-1) - return (images_tensor, masks_tensor) + masks_tensor = images_tensor[:, :, :, -1:].squeeze(-1) + return IO.NodeOutput(images_tensor, masks_tensor) -class RecraftCrispUpscaleNode: - """ - Upscale image synchronously. - Enhances a given raster image using ‘crisp upscale’ tool, increasing image resolution, making the image sharper and cleaner. - """ - - RETURN_TYPES = (IO.IMAGE,) - DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value - FUNCTION = "api_call" - API_NODE = True - CATEGORY = "api node/image/Recraft" - +class RecraftCrispUpscaleNode(IO.ComfyNode): RECRAFT_PATH = "/proxy/recraft/images/crispUpscale" @classmethod - def INPUT_TYPES(s): - return { - "required": { - "image": (IO.IMAGE, ), - }, - "optional": { - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - }, - } + def define_schema(cls): + return IO.Schema( + node_id="RecraftCrispUpscaleNode", + display_name="Recraft Crisp Upscale Image", + category="api node/image/Recraft", + description="Upscale image synchronously.\n" + "Enhances a given raster image using ‘crisp upscale’ tool, " + "increasing image resolution, making the image sharper and cleaner.", + inputs=[ + IO.Image.Input("image"), + ], + outputs=[ + IO.Image.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) - async def api_call( - self, - image: torch.Tensor, - **kwargs, - ): + @classmethod + async def execute(cls, image: torch.Tensor) -> IO.NodeOutput: images = [] total = image.shape[0] pbar = ProgressBar(total) for i in range(total): sub_bytes = await handle_recraft_file_request( + cls, image=image[i], - path=self.RECRAFT_PATH, - auth_kwargs=kwargs, + path=cls.RECRAFT_PATH, ) images.append(torch.cat([bytesio_to_image_tensor(x) for x in sub_bytes], dim=0)) pbar.update(1) - images_tensor = torch.cat(images, dim=0) - return (images_tensor,) + return IO.NodeOutput(torch.cat(images, dim=0)) class RecraftCreativeUpscaleNode(RecraftCrispUpscaleNode): - """ - Upscale image synchronously. - Enhances a given raster image using ‘creative upscale’ tool, boosting resolution with a focus on refining small details and faces. - """ - - RETURN_TYPES = (IO.IMAGE,) - DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value - FUNCTION = "api_call" - API_NODE = True - CATEGORY = "api node/image/Recraft" - RECRAFT_PATH = "/proxy/recraft/images/creativeUpscale" + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="RecraftCreativeUpscaleNode", + display_name="Recraft Creative Upscale Image", + category="api node/image/Recraft", + description="Upscale image synchronously.\n" + "Enhances a given raster image using ‘creative upscale’ tool, " + "boosting resolution with a focus on refining small details and faces.", + inputs=[ + IO.Image.Input("image"), + ], + outputs=[ + IO.Image.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) -# A dictionary that contains all nodes you want to export with their names -# NOTE: names should be globally unique -NODE_CLASS_MAPPINGS = { - "RecraftTextToImageNode": RecraftTextToImageNode, - "RecraftImageToImageNode": RecraftImageToImageNode, - "RecraftImageInpaintingNode": RecraftImageInpaintingNode, - "RecraftTextToVectorNode": RecraftTextToVectorNode, - "RecraftVectorizeImageNode": RecraftVectorizeImageNode, - "RecraftRemoveBackgroundNode": RecraftRemoveBackgroundNode, - "RecraftReplaceBackgroundNode": RecraftReplaceBackgroundNode, - "RecraftCrispUpscaleNode": RecraftCrispUpscaleNode, - "RecraftCreativeUpscaleNode": RecraftCreativeUpscaleNode, - "RecraftStyleV3RealisticImage": RecraftStyleV3RealisticImageNode, - "RecraftStyleV3DigitalIllustration": RecraftStyleV3DigitalIllustrationNode, - "RecraftStyleV3LogoRaster": RecraftStyleV3LogoRasterNode, - "RecraftStyleV3InfiniteStyleLibrary": RecraftStyleInfiniteStyleLibrary, - "RecraftColorRGB": RecraftColorRGBNode, - "RecraftControls": RecraftControlsNode, -} -# A dictionary that contains the friendly/humanly readable titles for the nodes -NODE_DISPLAY_NAME_MAPPINGS = { - "RecraftTextToImageNode": "Recraft Text to Image", - "RecraftImageToImageNode": "Recraft Image to Image", - "RecraftImageInpaintingNode": "Recraft Image Inpainting", - "RecraftTextToVectorNode": "Recraft Text to Vector", - "RecraftVectorizeImageNode": "Recraft Vectorize Image", - "RecraftRemoveBackgroundNode": "Recraft Remove Background", - "RecraftReplaceBackgroundNode": "Recraft Replace Background", - "RecraftCrispUpscaleNode": "Recraft Crisp Upscale Image", - "RecraftCreativeUpscaleNode": "Recraft Creative Upscale Image", - "RecraftStyleV3RealisticImage": "Recraft Style - Realistic Image", - "RecraftStyleV3DigitalIllustration": "Recraft Style - Digital Illustration", - "RecraftStyleV3LogoRaster": "Recraft Style - Logo Raster", - "RecraftStyleV3InfiniteStyleLibrary": "Recraft Style - Infinite Style Library", - "RecraftColorRGB": "Recraft Color RGB", - "RecraftControls": "Recraft Controls", -} +class RecraftExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[IO.ComfyNode]]: + return [ + RecraftTextToImageNode, + RecraftImageToImageNode, + RecraftImageInpaintingNode, + RecraftTextToVectorNode, + RecraftVectorizeImageNode, + RecraftRemoveBackgroundNode, + RecraftReplaceBackgroundNode, + RecraftCrispUpscaleNode, + RecraftCreativeUpscaleNode, + RecraftStyleV3RealisticImageNode, + RecraftStyleV3DigitalIllustrationNode, + RecraftStyleV3LogoRasterNode, + RecraftStyleInfiniteStyleLibrary, + RecraftColorRGBNode, + RecraftControlsNode, + ] + + +async def comfy_entrypoint() -> RecraftExtension: + return RecraftExtension() From 3fa7a5c04ae69ad168a875e8d3d453783d60899d Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 28 Oct 2025 21:21:01 -0700 Subject: [PATCH 028/231] Speed up offloading using pinned memory. (#10526) To enable this feature use: --fast pinned_memory --- comfy/cli_args.py | 1 + comfy/model_management.py | 30 ++++++++++++++++++++++++++++++ comfy/model_patcher.py | 26 +++++++++++++++++++++++++- 3 files changed, 56 insertions(+), 1 deletion(-) diff --git a/comfy/cli_args.py b/comfy/cli_args.py index cc1f12482..001abd843 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -144,6 +144,7 @@ class PerformanceFeature(enum.Enum): Fp8MatrixMultiplication = "fp8_matrix_mult" CublasOps = "cublas_ops" AutoTune = "autotune" + PinnedMem = "pinned_memory" parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: {}".format(" ".join(map(lambda c: c.value, PerformanceFeature)))) diff --git a/comfy/model_management.py b/comfy/model_management.py index afe78f36e..3e5b977d4 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -1080,6 +1080,36 @@ def cast_to_device(tensor, device, dtype, copy=False): non_blocking = device_supports_non_blocking(device) return cast_to(tensor, dtype=dtype, device=device, non_blocking=non_blocking, copy=copy) +def pin_memory(tensor): + if PerformanceFeature.PinnedMem not in args.fast: + return False + + if not is_nvidia(): + return False + + if not is_device_cpu(tensor.device): + return False + + if torch.cuda.cudart().cudaHostRegister(tensor.data_ptr(), tensor.numel() * tensor.element_size(), 1) == 0: + return True + + return False + +def unpin_memory(tensor): + if PerformanceFeature.PinnedMem not in args.fast: + return False + + if not is_nvidia(): + return False + + if not is_device_cpu(tensor.device): + return False + + if torch.cuda.cudart().cudaHostUnregister(tensor.data_ptr()) == 0: + return True + + return False + def sage_attention_enabled(): return args.use_sage_attention diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index c0b68fb8c..aec73349c 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -238,6 +238,7 @@ class ModelPatcher: self.force_cast_weights = False self.patches_uuid = uuid.uuid4() self.parent = None + self.pinned = set() self.attachments: dict[str] = {} self.additional_models: dict[str, list[ModelPatcher]] = {} @@ -618,6 +619,21 @@ class ModelPatcher: else: set_func(out_weight, inplace_update=inplace_update, seed=string_to_seed(key)) + def pin_weight_to_device(self, key): + weight, set_func, convert_func = get_key_weight(self.model, key) + if comfy.model_management.pin_memory(weight): + self.pinned.add(key) + + def unpin_weight(self, key): + if key in self.pinned: + weight, set_func, convert_func = get_key_weight(self.model, key) + comfy.model_management.unpin_memory(weight) + self.pinned.remove(key) + + def unpin_all_weights(self): + for key in list(self.pinned): + self.unpin_weight(key) + def _load_list(self): loading = [] for n, m in self.model.named_modules(): @@ -683,6 +699,8 @@ class ModelPatcher: patch_counter += 1 cast_weight = True + for param in params: + self.pin_weight_to_device("{}.{}".format(n, param)) else: if hasattr(m, "comfy_cast_weights"): wipe_lowvram_weight(m) @@ -713,7 +731,9 @@ class ModelPatcher: continue for param in params: - self.patch_weight_to_device("{}.{}".format(n, param), device_to=device_to) + key = "{}.{}".format(n, param) + self.unpin_weight(key) + self.patch_weight_to_device(key, device_to=device_to) logging.debug("lowvram: loaded module regularly {} {}".format(n, m)) m.comfy_patched_weights = True @@ -762,6 +782,7 @@ class ModelPatcher: self.eject_model() if unpatch_weights: self.unpatch_hooks() + self.unpin_all_weights() if self.model.model_lowvram: for m in self.model.modules(): move_weight_functions(m, device_to) @@ -857,6 +878,9 @@ class ModelPatcher: memory_freed += module_mem logging.debug("freed {}".format(n)) + for param in params: + self.pin_weight_to_device("{}.{}".format(n, param)) + self.model.model_lowvram = True self.model.lowvram_patch_counter += patch_counter self.model.model_loaded_weight_memory -= memory_freed From e525673f7201b6c49af0fa0e6baf44e4e98bb10c Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 28 Oct 2025 21:37:00 -0700 Subject: [PATCH 029/231] Fix issue. (#10527) --- comfy/sd.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/sd.py b/comfy/sd.py index 6411bb27d..de4eee96e 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -1330,7 +1330,7 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None): else: unet_dtype = dtype - if hasattr(model_config, "layer_quant_config"): + if model_config.layer_quant_config is not None: manual_cast_dtype = model_management.unet_manual_cast(None, load_device, model_config.supported_inference_dtypes) else: manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes) From 6c14f3afac0ea28dba24fe8783e7c1f09c03b31f Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Wed, 29 Oct 2025 20:14:56 +0200 Subject: [PATCH 030/231] use new API client in Luma and Minimax nodes (#10528) --- comfy_api_nodes/apinode_utils.py | 81 ------ comfy_api_nodes/apis/minimax_api.py | 120 ++++++++ comfy_api_nodes/nodes_ideogram.py | 2 +- comfy_api_nodes/nodes_luma.py | 354 ++++++----------------- comfy_api_nodes/nodes_minimax.py | 237 +++++---------- comfy_api_nodes/util/client.py | 2 +- comfy_api_nodes/util/download_helpers.py | 3 +- 7 files changed, 283 insertions(+), 516 deletions(-) create mode 100644 comfy_api_nodes/apis/minimax_api.py diff --git a/comfy_api_nodes/apinode_utils.py b/comfy_api_nodes/apinode_utils.py index 4182c8f80..6a72b9d1d 100644 --- a/comfy_api_nodes/apinode_utils.py +++ b/comfy_api_nodes/apinode_utils.py @@ -3,14 +3,6 @@ import aiohttp import mimetypes from typing import Optional, Union from comfy.utils import common_upscale -from comfy_api_nodes.apis.client import ( - ApiClient, - ApiEndpoint, - HttpMethod, - SynchronousOperation, - UploadRequest, - UploadResponse, -) from server import PromptServer from comfy.cli_args import args @@ -19,7 +11,6 @@ from PIL import Image import torch import math import base64 -from .util import tensor_to_bytesio, bytesio_to_image_tensor from io import BytesIO @@ -148,11 +139,6 @@ async def download_url_to_bytesio( return BytesIO(await resp.read()) -def process_image_response(response_content: bytes | str) -> torch.Tensor: - """Uses content from a Response object and converts it to a torch.Tensor""" - return bytesio_to_image_tensor(BytesIO(response_content)) - - def text_filepath_to_base64_string(filepath: str) -> str: """Converts a text file to a base64 string.""" with open(filepath, "rb") as f: @@ -169,73 +155,6 @@ def text_filepath_to_data_uri(filepath: str) -> str: return f"data:{mime_type};base64,{base64_string}" -async def upload_file_to_comfyapi( - file_bytes_io: BytesIO, - filename: str, - upload_mime_type: Optional[str], - auth_kwargs: Optional[dict[str, str]] = None, -) -> str: - """ - Uploads a single file to ComfyUI API and returns its download URL. - - Args: - file_bytes_io: BytesIO object containing the file data. - filename: The filename of the file. - upload_mime_type: MIME type of the file. - auth_kwargs: Optional authentication token(s). - - Returns: - The download URL for the uploaded file. - """ - if upload_mime_type is None: - request_object = UploadRequest(file_name=filename) - else: - request_object = UploadRequest(file_name=filename, content_type=upload_mime_type) - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path="/customers/storage", - method=HttpMethod.POST, - request_model=UploadRequest, - response_model=UploadResponse, - ), - request=request_object, - auth_kwargs=auth_kwargs, - ) - - response: UploadResponse = await operation.execute() - await ApiClient.upload_file(response.upload_url, file_bytes_io, content_type=upload_mime_type) - return response.download_url - - -async def upload_images_to_comfyapi( - image: torch.Tensor, - max_images=8, - auth_kwargs: Optional[dict[str, str]] = None, - mime_type: Optional[str] = None, -) -> list[str]: - """ - Uploads images to ComfyUI API and returns download URLs. - To upload multiple images, stack them in the batch dimension first. - - Args: - image: Input torch.Tensor image. - max_images: Maximum number of images to upload. - auth_kwargs: Optional authentication token(s). - mime_type: Optional MIME type for the image. - """ - # if batch, try to upload each file if max_images is greater than 0 - download_urls: list[str] = [] - is_batch = len(image.shape) > 3 - batch_len = image.shape[0] if is_batch else 1 - - for idx in range(min(batch_len, max_images)): - tensor = image[idx] if is_batch else image - img_io = tensor_to_bytesio(tensor, mime_type=mime_type) - url = await upload_file_to_comfyapi(img_io, img_io.name, mime_type, auth_kwargs) - download_urls.append(url) - return download_urls - - def resize_mask_to_image( mask: torch.Tensor, image: torch.Tensor, diff --git a/comfy_api_nodes/apis/minimax_api.py b/comfy_api_nodes/apis/minimax_api.py new file mode 100644 index 000000000..d747e177a --- /dev/null +++ b/comfy_api_nodes/apis/minimax_api.py @@ -0,0 +1,120 @@ +from enum import Enum +from typing import Optional + +from pydantic import BaseModel, Field + + +class MinimaxBaseResponse(BaseModel): + status_code: int = Field( + ..., + description='Status code. 0 indicates success, other values indicate errors.', + ) + status_msg: str = Field( + ..., description='Specific error details or success message.' + ) + + +class File(BaseModel): + bytes: Optional[int] = Field(None, description='File size in bytes') + created_at: Optional[int] = Field( + None, description='Unix timestamp when the file was created, in seconds' + ) + download_url: Optional[str] = Field( + None, description='The URL to download the video' + ) + backup_download_url: Optional[str] = Field( + None, description='The backup URL to download the video' + ) + + file_id: Optional[int] = Field(None, description='Unique identifier for the file') + filename: Optional[str] = Field(None, description='The name of the file') + purpose: Optional[str] = Field(None, description='The purpose of using the file') + + +class MinimaxFileRetrieveResponse(BaseModel): + base_resp: MinimaxBaseResponse + file: File + + +class MiniMaxModel(str, Enum): + T2V_01_Director = 'T2V-01-Director' + I2V_01_Director = 'I2V-01-Director' + S2V_01 = 'S2V-01' + I2V_01 = 'I2V-01' + I2V_01_live = 'I2V-01-live' + T2V_01 = 'T2V-01' + Hailuo_02 = 'MiniMax-Hailuo-02' + + +class Status6(str, Enum): + Queueing = 'Queueing' + Preparing = 'Preparing' + Processing = 'Processing' + Success = 'Success' + Fail = 'Fail' + + +class MinimaxTaskResultResponse(BaseModel): + base_resp: MinimaxBaseResponse + file_id: Optional[str] = Field( + None, + description='After the task status changes to Success, this field returns the file ID corresponding to the generated video.', + ) + status: Status6 = Field( + ..., + description="Task status: 'Queueing' (in queue), 'Preparing' (task is preparing), 'Processing' (generating), 'Success' (task completed successfully), or 'Fail' (task failed).", + ) + task_id: str = Field(..., description='The task ID being queried.') + + +class SubjectReferenceItem(BaseModel): + image: Optional[str] = Field( + None, description='URL or base64 encoding of the subject reference image.' + ) + mask: Optional[str] = Field( + None, + description='URL or base64 encoding of the mask for the subject reference image.', + ) + + +class MinimaxVideoGenerationRequest(BaseModel): + callback_url: Optional[str] = Field( + None, + description='Optional. URL to receive real-time status updates about the video generation task.', + ) + first_frame_image: Optional[str] = Field( + None, + description='URL or base64 encoding of the first frame image. Required when model is I2V-01, I2V-01-Director, or I2V-01-live.', + ) + model: MiniMaxModel = Field( + ..., + description='Required. ID of model. Options: T2V-01-Director, I2V-01-Director, S2V-01, I2V-01, I2V-01-live, T2V-01', + ) + prompt: Optional[str] = Field( + None, + description='Description of the video. Should be less than 2000 characters. Supports camera movement instructions in [brackets].', + max_length=2000, + ) + prompt_optimizer: Optional[bool] = Field( + True, + description='If true (default), the model will automatically optimize the prompt. Set to false for more precise control.', + ) + subject_reference: Optional[list[SubjectReferenceItem]] = Field( + None, + description='Only available when model is S2V-01. The model will generate a video based on the subject uploaded through this parameter.', + ) + duration: Optional[int] = Field( + None, + description="The length of the output video in seconds." + ) + resolution: Optional[str] = Field( + None, + description="The dimensions of the video display. 1080p corresponds to 1920 x 1080 pixels, 768p corresponds to 1366 x 768 pixels." + ) + + +class MinimaxVideoGenerationResponse(BaseModel): + base_resp: MinimaxBaseResponse + task_id: str = Field( + ..., description='The task ID for the asynchronous video generation task.' + ) diff --git a/comfy_api_nodes/nodes_ideogram.py b/comfy_api_nodes/nodes_ideogram.py index 9eae5f11a..d8fd3378b 100644 --- a/comfy_api_nodes/nodes_ideogram.py +++ b/comfy_api_nodes/nodes_ideogram.py @@ -20,9 +20,9 @@ from comfy_api_nodes.apis.client import ( from comfy_api_nodes.apinode_utils import ( download_url_to_bytesio, - bytesio_to_image_tensor, resize_mask_to_image, ) +from comfy_api_nodes.util import bytesio_to_image_tensor from server import PromptServer V1_V1_RES_MAP = { diff --git a/comfy_api_nodes/nodes_luma.py b/comfy_api_nodes/nodes_luma.py index e74441e5e..894f2b08c 100644 --- a/comfy_api_nodes/nodes_luma.py +++ b/comfy_api_nodes/nodes_luma.py @@ -1,69 +1,51 @@ -from __future__ import annotations -from inspect import cleandoc from typing import Optional + +import torch from typing_extensions import override -from comfy_api.latest import ComfyExtension, IO -from comfy_api.input_impl.video_types import VideoFromFile + +from comfy_api.latest import IO, ComfyExtension from comfy_api_nodes.apis.luma_api import ( - LumaImageModel, - LumaVideoModel, - LumaVideoOutputResolution, - LumaVideoModelOutputDuration, LumaAspectRatio, - LumaState, - LumaImageGenerationRequest, - LumaGenerationRequest, - LumaGeneration, LumaCharacterRef, - LumaModifyImageRef, + LumaConceptChain, + LumaGeneration, + LumaGenerationRequest, + LumaImageGenerationRequest, LumaImageIdentity, + LumaImageModel, + LumaImageReference, + LumaIO, + LumaKeyframes, + LumaModifyImageRef, LumaReference, LumaReferenceChain, - LumaImageReference, - LumaKeyframes, - LumaConceptChain, - LumaIO, + LumaVideoModel, + LumaVideoModelOutputDuration, + LumaVideoOutputResolution, get_luma_concepts, ) -from comfy_api_nodes.apis.client import ( +from comfy_api_nodes.util import ( ApiEndpoint, - HttpMethod, - SynchronousOperation, - PollingOperation, - EmptyRequest, -) -from comfy_api_nodes.apinode_utils import ( + download_url_to_image_tensor, + download_url_to_video_output, + poll_op, + sync_op, upload_images_to_comfyapi, - process_image_response, + validate_string, ) -from server import PromptServer -from comfy_api_nodes.util import validate_string - -import aiohttp -import torch -from io import BytesIO LUMA_T2V_AVERAGE_DURATION = 105 LUMA_I2V_AVERAGE_DURATION = 100 -def image_result_url_extractor(response: LumaGeneration): - return response.assets.image if hasattr(response, "assets") and hasattr(response.assets, "image") else None - -def video_result_url_extractor(response: LumaGeneration): - return response.assets.video if hasattr(response, "assets") and hasattr(response.assets, "video") else None class LumaReferenceNode(IO.ComfyNode): - """ - Holds an image and weight for use with Luma Generate Image node. - """ - @classmethod def define_schema(cls) -> IO.Schema: return IO.Schema( node_id="LumaReferenceNode", display_name="Luma Reference", category="api node/image/Luma", - description=cleandoc(cls.__doc__ or ""), + description="Holds an image and weight for use with Luma Generate Image node.", inputs=[ IO.Image.Input( "image", @@ -83,17 +65,10 @@ class LumaReferenceNode(IO.ComfyNode): ), ], outputs=[IO.Custom(LumaIO.LUMA_REF).Output(display_name="luma_ref")], - hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, - IO.Hidden.unique_id, - ], ) @classmethod - def execute( - cls, image: torch.Tensor, weight: float, luma_ref: LumaReferenceChain = None - ) -> IO.NodeOutput: + def execute(cls, image: torch.Tensor, weight: float, luma_ref: LumaReferenceChain = None) -> IO.NodeOutput: if luma_ref is not None: luma_ref = luma_ref.clone() else: @@ -103,17 +78,13 @@ class LumaReferenceNode(IO.ComfyNode): class LumaConceptsNode(IO.ComfyNode): - """ - Holds one or more Camera Concepts for use with Luma Text to Video and Luma Image to Video nodes. - """ - @classmethod def define_schema(cls) -> IO.Schema: return IO.Schema( node_id="LumaConceptsNode", display_name="Luma Concepts", category="api node/video/Luma", - description=cleandoc(cls.__doc__ or ""), + description="Camera Concepts for use with Luma Text to Video and Luma Image to Video nodes.", inputs=[ IO.Combo.Input( "concept1", @@ -138,11 +109,6 @@ class LumaConceptsNode(IO.ComfyNode): ), ], outputs=[IO.Custom(LumaIO.LUMA_CONCEPTS).Output(display_name="luma_concepts")], - hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, - IO.Hidden.unique_id, - ], ) @classmethod @@ -161,17 +127,13 @@ class LumaConceptsNode(IO.ComfyNode): class LumaImageGenerationNode(IO.ComfyNode): - """ - Generates images synchronously based on prompt and aspect ratio. - """ - @classmethod def define_schema(cls) -> IO.Schema: return IO.Schema( node_id="LumaImageNode", display_name="Luma Text to Image", category="api node/image/Luma", - description=cleandoc(cls.__doc__ or ""), + description="Generates images synchronously based on prompt and aspect ratio.", inputs=[ IO.String.Input( "prompt", @@ -237,45 +199,30 @@ class LumaImageGenerationNode(IO.ComfyNode): aspect_ratio: str, seed, style_image_weight: float, - image_luma_ref: LumaReferenceChain = None, - style_image: torch.Tensor = None, - character_image: torch.Tensor = None, + image_luma_ref: Optional[LumaReferenceChain] = None, + style_image: Optional[torch.Tensor] = None, + character_image: Optional[torch.Tensor] = None, ) -> IO.NodeOutput: validate_string(prompt, strip_whitespace=True, min_length=3) - auth_kwargs = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } # handle image_luma_ref api_image_ref = None if image_luma_ref is not None: - api_image_ref = await cls._convert_luma_refs( - image_luma_ref, max_refs=4, auth_kwargs=auth_kwargs, - ) + api_image_ref = await cls._convert_luma_refs(image_luma_ref, max_refs=4) # handle style_luma_ref api_style_ref = None if style_image is not None: - api_style_ref = await cls._convert_style_image( - style_image, weight=style_image_weight, auth_kwargs=auth_kwargs, - ) + api_style_ref = await cls._convert_style_image(style_image, weight=style_image_weight) # handle character_ref images character_ref = None if character_image is not None: - download_urls = await upload_images_to_comfyapi( - character_image, max_images=4, auth_kwargs=auth_kwargs, - ) - character_ref = LumaCharacterRef( - identity0=LumaImageIdentity(images=download_urls) - ) + download_urls = await upload_images_to_comfyapi(cls, character_image, max_images=4) + character_ref = LumaCharacterRef(identity0=LumaImageIdentity(images=download_urls)) - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/luma/generations/image", - method=HttpMethod.POST, - request_model=LumaImageGenerationRequest, - response_model=LumaGeneration, - ), - request=LumaImageGenerationRequest( + response_api = await sync_op( + cls, + ApiEndpoint(path="/proxy/luma/generations/image", method="POST"), + response_model=LumaGeneration, + data=LumaImageGenerationRequest( prompt=prompt, model=model, aspect_ratio=aspect_ratio, @@ -283,41 +230,21 @@ class LumaImageGenerationNode(IO.ComfyNode): style_ref=api_style_ref, character_ref=character_ref, ), - auth_kwargs=auth_kwargs, ) - response_api: LumaGeneration = await operation.execute() - - operation = PollingOperation( - poll_endpoint=ApiEndpoint( - path=f"/proxy/luma/generations/{response_api.id}", - method=HttpMethod.GET, - request_model=EmptyRequest, - response_model=LumaGeneration, - ), - completed_statuses=[LumaState.completed], - failed_statuses=[LumaState.failed], + response_poll = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/luma/generations/{response_api.id}"), + response_model=LumaGeneration, status_extractor=lambda x: x.state, - result_url_extractor=image_result_url_extractor, - node_id=cls.hidden.unique_id, - auth_kwargs=auth_kwargs, ) - response_poll = await operation.execute() - - async with aiohttp.ClientSession() as session: - async with session.get(response_poll.assets.image) as img_response: - img = process_image_response(await img_response.content.read()) - return IO.NodeOutput(img) + return IO.NodeOutput(await download_url_to_image_tensor(response_poll.assets.image)) @classmethod - async def _convert_luma_refs( - cls, luma_ref: LumaReferenceChain, max_refs: int, auth_kwargs: Optional[dict[str,str]] = None - ): + async def _convert_luma_refs(cls, luma_ref: LumaReferenceChain, max_refs: int): luma_urls = [] ref_count = 0 for ref in luma_ref.refs: - download_urls = await upload_images_to_comfyapi( - ref.image, max_images=1, auth_kwargs=auth_kwargs - ) + download_urls = await upload_images_to_comfyapi(cls, ref.image, max_images=1) luma_urls.append(download_urls[0]) ref_count += 1 if ref_count >= max_refs: @@ -325,27 +252,19 @@ class LumaImageGenerationNode(IO.ComfyNode): return luma_ref.create_api_model(download_urls=luma_urls, max_refs=max_refs) @classmethod - async def _convert_style_image( - cls, style_image: torch.Tensor, weight: float, auth_kwargs: Optional[dict[str,str]] = None - ): - chain = LumaReferenceChain( - first_ref=LumaReference(image=style_image, weight=weight) - ) - return await cls._convert_luma_refs(chain, max_refs=1, auth_kwargs=auth_kwargs) + async def _convert_style_image(cls, style_image: torch.Tensor, weight: float): + chain = LumaReferenceChain(first_ref=LumaReference(image=style_image, weight=weight)) + return await cls._convert_luma_refs(chain, max_refs=1) class LumaImageModifyNode(IO.ComfyNode): - """ - Modifies images synchronously based on prompt and aspect ratio. - """ - @classmethod def define_schema(cls) -> IO.Schema: return IO.Schema( node_id="LumaImageModifyNode", display_name="Luma Image to Image", category="api node/image/Luma", - description=cleandoc(cls.__doc__ or ""), + description="Modifies images synchronously based on prompt and aspect ratio.", inputs=[ IO.Image.Input( "image", @@ -395,68 +314,37 @@ class LumaImageModifyNode(IO.ComfyNode): image_weight: float, seed, ) -> IO.NodeOutput: - auth_kwargs = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } - # first, upload image - download_urls = await upload_images_to_comfyapi( - image, max_images=1, auth_kwargs=auth_kwargs, - ) + download_urls = await upload_images_to_comfyapi(cls, image, max_images=1) image_url = download_urls[0] - # next, make Luma call with download url provided - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/luma/generations/image", - method=HttpMethod.POST, - request_model=LumaImageGenerationRequest, - response_model=LumaGeneration, - ), - request=LumaImageGenerationRequest( + response_api = await sync_op( + cls, + ApiEndpoint(path="/proxy/luma/generations/image", method="POST"), + response_model=LumaGeneration, + data=LumaImageGenerationRequest( prompt=prompt, model=model, modify_image_ref=LumaModifyImageRef( - url=image_url, weight=round(max(min(1.0-image_weight, 0.98), 0.0), 2) + url=image_url, weight=round(max(min(1.0 - image_weight, 0.98), 0.0), 2) ), ), - auth_kwargs=auth_kwargs, ) - response_api: LumaGeneration = await operation.execute() - - operation = PollingOperation( - poll_endpoint=ApiEndpoint( - path=f"/proxy/luma/generations/{response_api.id}", - method=HttpMethod.GET, - request_model=EmptyRequest, - response_model=LumaGeneration, - ), - completed_statuses=[LumaState.completed], - failed_statuses=[LumaState.failed], + response_poll = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/luma/generations/{response_api.id}"), + response_model=LumaGeneration, status_extractor=lambda x: x.state, - result_url_extractor=image_result_url_extractor, - node_id=cls.hidden.unique_id, - auth_kwargs=auth_kwargs, ) - response_poll = await operation.execute() - - async with aiohttp.ClientSession() as session: - async with session.get(response_poll.assets.image) as img_response: - img = process_image_response(await img_response.content.read()) - return IO.NodeOutput(img) + return IO.NodeOutput(await download_url_to_image_tensor(response_poll.assets.image)) class LumaTextToVideoGenerationNode(IO.ComfyNode): - """ - Generates videos synchronously based on prompt and output_size. - """ - @classmethod def define_schema(cls) -> IO.Schema: return IO.Schema( node_id="LumaVideoNode", display_name="Luma Text to Video", category="api node/video/Luma", - description=cleandoc(cls.__doc__ or ""), + description="Generates videos synchronously based on prompt and output_size.", inputs=[ IO.String.Input( "prompt", @@ -498,7 +386,7 @@ class LumaTextToVideoGenerationNode(IO.ComfyNode): "luma_concepts", tooltip="Optional Camera Concepts to dictate camera motion via the Luma Concepts node.", optional=True, - ) + ), ], outputs=[IO.Video.Output()], hidden=[ @@ -519,24 +407,17 @@ class LumaTextToVideoGenerationNode(IO.ComfyNode): duration: str, loop: bool, seed, - luma_concepts: LumaConceptChain = None, + luma_concepts: Optional[LumaConceptChain] = None, ) -> IO.NodeOutput: validate_string(prompt, strip_whitespace=False, min_length=3) duration = duration if model != LumaVideoModel.ray_1_6 else None resolution = resolution if model != LumaVideoModel.ray_1_6 else None - auth_kwargs = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/luma/generations", - method=HttpMethod.POST, - request_model=LumaGenerationRequest, - response_model=LumaGeneration, - ), - request=LumaGenerationRequest( + response_api = await sync_op( + cls, + ApiEndpoint(path="/proxy/luma/generations", method="POST"), + response_model=LumaGeneration, + data=LumaGenerationRequest( prompt=prompt, model=model, resolution=resolution, @@ -545,47 +426,25 @@ class LumaTextToVideoGenerationNode(IO.ComfyNode): loop=loop, concepts=luma_concepts.create_api_model() if luma_concepts else None, ), - auth_kwargs=auth_kwargs, ) - response_api: LumaGeneration = await operation.execute() - - if cls.hidden.unique_id: - PromptServer.instance.send_progress_text(f"Luma video generation started: {response_api.id}", cls.hidden.unique_id) - - operation = PollingOperation( - poll_endpoint=ApiEndpoint( - path=f"/proxy/luma/generations/{response_api.id}", - method=HttpMethod.GET, - request_model=EmptyRequest, - response_model=LumaGeneration, - ), - completed_statuses=[LumaState.completed], - failed_statuses=[LumaState.failed], + response_poll = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/luma/generations/{response_api.id}"), + response_model=LumaGeneration, status_extractor=lambda x: x.state, - result_url_extractor=video_result_url_extractor, - node_id=cls.hidden.unique_id, estimated_duration=LUMA_T2V_AVERAGE_DURATION, - auth_kwargs=auth_kwargs, ) - response_poll = await operation.execute() - - async with aiohttp.ClientSession() as session: - async with session.get(response_poll.assets.video) as vid_response: - return IO.NodeOutput(VideoFromFile(BytesIO(await vid_response.content.read()))) + return IO.NodeOutput(await download_url_to_video_output(response_poll.assets.video)) class LumaImageToVideoGenerationNode(IO.ComfyNode): - """ - Generates videos synchronously based on prompt, input images, and output_size. - """ - @classmethod def define_schema(cls) -> IO.Schema: return IO.Schema( node_id="LumaImageToVideoNode", display_name="Luma Image to Video", category="api node/video/Luma", - description=cleandoc(cls.__doc__ or ""), + description="Generates videos synchronously based on prompt, input images, and output_size.", inputs=[ IO.String.Input( "prompt", @@ -637,7 +496,7 @@ class LumaImageToVideoGenerationNode(IO.ComfyNode): "luma_concepts", tooltip="Optional Camera Concepts to dictate camera motion via the Luma Concepts node.", optional=True, - ) + ), ], outputs=[IO.Video.Output()], hidden=[ @@ -662,25 +521,15 @@ class LumaImageToVideoGenerationNode(IO.ComfyNode): luma_concepts: LumaConceptChain = None, ) -> IO.NodeOutput: if first_image is None and last_image is None: - raise Exception( - "At least one of first_image and last_image requires an input." - ) - auth_kwargs = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } - keyframes = await cls._convert_to_keyframes(first_image, last_image, auth_kwargs=auth_kwargs) + raise Exception("At least one of first_image and last_image requires an input.") + keyframes = await cls._convert_to_keyframes(first_image, last_image) duration = duration if model != LumaVideoModel.ray_1_6 else None resolution = resolution if model != LumaVideoModel.ray_1_6 else None - - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/luma/generations", - method=HttpMethod.POST, - request_model=LumaGenerationRequest, - response_model=LumaGeneration, - ), - request=LumaGenerationRequest( + response_api = await sync_op( + cls, + ApiEndpoint(path="/proxy/luma/generations", method="POST"), + response_model=LumaGeneration, + data=LumaGenerationRequest( prompt=prompt, model=model, aspect_ratio=LumaAspectRatio.ratio_16_9, # ignored, but still needed by the API for some reason @@ -690,54 +539,31 @@ class LumaImageToVideoGenerationNode(IO.ComfyNode): keyframes=keyframes, concepts=luma_concepts.create_api_model() if luma_concepts else None, ), - auth_kwargs=auth_kwargs, ) - response_api: LumaGeneration = await operation.execute() - - if cls.hidden.unique_id: - PromptServer.instance.send_progress_text(f"Luma video generation started: {response_api.id}", cls.hidden.unique_id) - - operation = PollingOperation( - poll_endpoint=ApiEndpoint( - path=f"/proxy/luma/generations/{response_api.id}", - method=HttpMethod.GET, - request_model=EmptyRequest, - response_model=LumaGeneration, - ), - completed_statuses=[LumaState.completed], - failed_statuses=[LumaState.failed], + response_poll = await poll_op( + cls, + poll_endpoint=ApiEndpoint(path=f"/proxy/luma/generations/{response_api.id}"), + response_model=LumaGeneration, status_extractor=lambda x: x.state, - result_url_extractor=video_result_url_extractor, - node_id=cls.hidden.unique_id, estimated_duration=LUMA_I2V_AVERAGE_DURATION, - auth_kwargs=auth_kwargs, ) - response_poll = await operation.execute() - - async with aiohttp.ClientSession() as session: - async with session.get(response_poll.assets.video) as vid_response: - return IO.NodeOutput(VideoFromFile(BytesIO(await vid_response.content.read()))) + return IO.NodeOutput(await download_url_to_video_output(response_poll.assets.video)) @classmethod async def _convert_to_keyframes( cls, first_image: torch.Tensor = None, last_image: torch.Tensor = None, - auth_kwargs: Optional[dict[str,str]] = None, ): if first_image is None and last_image is None: return None frame0 = None frame1 = None if first_image is not None: - download_urls = await upload_images_to_comfyapi( - first_image, max_images=1, auth_kwargs=auth_kwargs, - ) + download_urls = await upload_images_to_comfyapi(cls, first_image, max_images=1) frame0 = LumaImageReference(type="image", url=download_urls[0]) if last_image is not None: - download_urls = await upload_images_to_comfyapi( - last_image, max_images=1, auth_kwargs=auth_kwargs, - ) + download_urls = await upload_images_to_comfyapi(cls, last_image, max_images=1) frame1 = LumaImageReference(type="image", url=download_urls[0]) return LumaKeyframes(frame0=frame0, frame1=frame1) diff --git a/comfy_api_nodes/nodes_minimax.py b/comfy_api_nodes/nodes_minimax.py index e3722e79b..05cbb700f 100644 --- a/comfy_api_nodes/nodes_minimax.py +++ b/comfy_api_nodes/nodes_minimax.py @@ -1,71 +1,57 @@ -from inspect import cleandoc from typing import Optional -import logging -import torch +import torch from typing_extensions import override -from comfy_api.latest import ComfyExtension, IO -from comfy_api.input_impl.video_types import VideoFromFile -from comfy_api_nodes.apis import ( + +from comfy_api.latest import IO, ComfyExtension +from comfy_api_nodes.apis.minimax_api import ( + MinimaxFileRetrieveResponse, + MiniMaxModel, + MinimaxTaskResultResponse, MinimaxVideoGenerationRequest, MinimaxVideoGenerationResponse, - MinimaxFileRetrieveResponse, - MinimaxTaskResultResponse, SubjectReferenceItem, - MiniMaxModel, ) -from comfy_api_nodes.apis.client import ( +from comfy_api_nodes.util import ( ApiEndpoint, - HttpMethod, - SynchronousOperation, - PollingOperation, - EmptyRequest, -) -from comfy_api_nodes.apinode_utils import ( - download_url_to_bytesio, + download_url_to_video_output, + poll_op, + sync_op, upload_images_to_comfyapi, + validate_string, ) -from comfy_api_nodes.util import validate_string -from server import PromptServer - I2V_AVERAGE_DURATION = 114 T2V_AVERAGE_DURATION = 234 async def _generate_mm_video( + cls: type[IO.ComfyNode], *, - auth: dict[str, str], - node_id: str, prompt_text: str, seed: int, model: str, - image: Optional[torch.Tensor] = None, # used for ImageToVideo - subject: Optional[torch.Tensor] = None, # used for SubjectToVideo + image: Optional[torch.Tensor] = None, # used for ImageToVideo + subject: Optional[torch.Tensor] = None, # used for SubjectToVideo average_duration: Optional[int] = None, ) -> IO.NodeOutput: if image is None: validate_string(prompt_text, field_name="prompt_text") - # upload image, if passed in image_url = None if image is not None: - image_url = (await upload_images_to_comfyapi(image, max_images=1, auth_kwargs=auth))[0] + image_url = (await upload_images_to_comfyapi(cls, image, max_images=1))[0] # TODO: figure out how to deal with subject properly, API returns invalid params when using S2V-01 model subject_reference = None if subject is not None: - subject_url = (await upload_images_to_comfyapi(subject, max_images=1, auth_kwargs=auth))[0] + subject_url = (await upload_images_to_comfyapi(cls, subject, max_images=1))[0] subject_reference = [SubjectReferenceItem(image=subject_url)] - - video_generate_operation = SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/minimax/video_generation", - method=HttpMethod.POST, - request_model=MinimaxVideoGenerationRequest, - response_model=MinimaxVideoGenerationResponse, - ), - request=MinimaxVideoGenerationRequest( + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/minimax/video_generation", method="POST"), + response_model=MinimaxVideoGenerationResponse, + data=MinimaxVideoGenerationRequest( model=MiniMaxModel(model), prompt=prompt_text, callback_url=None, @@ -73,81 +59,50 @@ async def _generate_mm_video( subject_reference=subject_reference, prompt_optimizer=None, ), - auth_kwargs=auth, ) - response = await video_generate_operation.execute() task_id = response.task_id if not task_id: raise Exception(f"MiniMax generation failed: {response.base_resp}") - video_generate_operation = PollingOperation( - poll_endpoint=ApiEndpoint( - path="/proxy/minimax/query/video_generation", - method=HttpMethod.GET, - request_model=EmptyRequest, - response_model=MinimaxTaskResultResponse, - query_params={"task_id": task_id}, - ), - completed_statuses=["Success"], - failed_statuses=["Fail"], + task_result = await poll_op( + cls, + ApiEndpoint(path="/proxy/minimax/query/video_generation", query_params={"task_id": task_id}), + response_model=MinimaxTaskResultResponse, status_extractor=lambda x: x.status.value, estimated_duration=average_duration, - node_id=node_id, - auth_kwargs=auth, ) - task_result = await video_generate_operation.execute() file_id = task_result.file_id if file_id is None: raise Exception("Request was not successful. Missing file ID.") - file_retrieve_operation = SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/minimax/files/retrieve", - method=HttpMethod.GET, - request_model=EmptyRequest, - response_model=MinimaxFileRetrieveResponse, - query_params={"file_id": int(file_id)}, - ), - request=EmptyRequest(), - auth_kwargs=auth, + file_result = await sync_op( + cls, + ApiEndpoint(path="/proxy/minimax/files/retrieve", query_params={"file_id": int(file_id)}), + response_model=MinimaxFileRetrieveResponse, ) - file_result = await file_retrieve_operation.execute() file_url = file_result.file.download_url if file_url is None: - raise Exception( - f"No video was found in the response. Full response: {file_result.model_dump()}" - ) - logging.info("Generated video URL: %s", file_url) - if node_id: - if hasattr(file_result.file, "backup_download_url"): - message = f"Result URL: {file_url}\nBackup URL: {file_result.file.backup_download_url}" - else: - message = f"Result URL: {file_url}" - PromptServer.instance.send_progress_text(message, node_id) - - # Download and return as VideoFromFile - video_io = await download_url_to_bytesio(file_url) - if video_io is None: - error_msg = f"Failed to download video from {file_url}" - logging.error(error_msg) - raise Exception(error_msg) - return IO.NodeOutput(VideoFromFile(video_io)) + raise Exception(f"No video was found in the response. Full response: {file_result.model_dump()}") + if file_result.file.backup_download_url: + try: + return IO.NodeOutput(await download_url_to_video_output(file_url, timeout=10, max_retries=2)) + except Exception: # if we have a second URL to retrieve the result, try again using that one + return IO.NodeOutput( + await download_url_to_video_output(file_result.file.backup_download_url, max_retries=3) + ) + return IO.NodeOutput(await download_url_to_video_output(file_url)) class MinimaxTextToVideoNode(IO.ComfyNode): - """ - Generates videos synchronously based on a prompt, and optional parameters using MiniMax's API. - """ - @classmethod def define_schema(cls) -> IO.Schema: return IO.Schema( node_id="MinimaxTextToVideoNode", display_name="MiniMax Text to Video", category="api node/video/MiniMax", - description=cleandoc(cls.__doc__ or ""), + description="Generates videos synchronously based on a prompt, and optional parameters.", inputs=[ IO.String.Input( "prompt_text", @@ -189,11 +144,7 @@ class MinimaxTextToVideoNode(IO.ComfyNode): seed: int = 0, ) -> IO.NodeOutput: return await _generate_mm_video( - auth={ - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - }, - node_id=cls.hidden.unique_id, + cls, prompt_text=prompt_text, seed=seed, model=model, @@ -204,17 +155,13 @@ class MinimaxTextToVideoNode(IO.ComfyNode): class MinimaxImageToVideoNode(IO.ComfyNode): - """ - Generates videos synchronously based on an image and prompt, and optional parameters using MiniMax's API. - """ - @classmethod def define_schema(cls) -> IO.Schema: return IO.Schema( node_id="MinimaxImageToVideoNode", display_name="MiniMax Image to Video", category="api node/video/MiniMax", - description=cleandoc(cls.__doc__ or ""), + description="Generates videos synchronously based on an image and prompt, and optional parameters.", inputs=[ IO.Image.Input( "image", @@ -261,11 +208,7 @@ class MinimaxImageToVideoNode(IO.ComfyNode): seed: int = 0, ) -> IO.NodeOutput: return await _generate_mm_video( - auth={ - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - }, - node_id=cls.hidden.unique_id, + cls, prompt_text=prompt_text, seed=seed, model=model, @@ -276,17 +219,13 @@ class MinimaxImageToVideoNode(IO.ComfyNode): class MinimaxSubjectToVideoNode(IO.ComfyNode): - """ - Generates videos synchronously based on an image and prompt, and optional parameters using MiniMax's API. - """ - @classmethod def define_schema(cls) -> IO.Schema: return IO.Schema( node_id="MinimaxSubjectToVideoNode", display_name="MiniMax Subject to Video", category="api node/video/MiniMax", - description=cleandoc(cls.__doc__ or ""), + description="Generates videos synchronously based on an image and prompt, and optional parameters.", inputs=[ IO.Image.Input( "subject", @@ -333,11 +272,7 @@ class MinimaxSubjectToVideoNode(IO.ComfyNode): seed: int = 0, ) -> IO.NodeOutput: return await _generate_mm_video( - auth={ - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - }, - node_id=cls.hidden.unique_id, + cls, prompt_text=prompt_text, seed=seed, model=model, @@ -348,15 +283,13 @@ class MinimaxSubjectToVideoNode(IO.ComfyNode): class MinimaxHailuoVideoNode(IO.ComfyNode): - """Generates videos from prompt, with optional start frame using the new MiniMax Hailuo-02 model.""" - @classmethod def define_schema(cls) -> IO.Schema: return IO.Schema( node_id="MinimaxHailuoVideoNode", display_name="MiniMax Hailuo Video", category="api node/video/MiniMax", - description=cleandoc(cls.__doc__ or ""), + description="Generates videos from prompt, with optional start frame using the new MiniMax Hailuo-02 model.", inputs=[ IO.String.Input( "prompt_text", @@ -420,10 +353,6 @@ class MinimaxHailuoVideoNode(IO.ComfyNode): resolution: str = "768P", model: str = "MiniMax-Hailuo-02", ) -> IO.NodeOutput: - auth = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } if first_frame_image is None: validate_string(prompt_text, field_name="prompt_text") @@ -435,16 +364,13 @@ class MinimaxHailuoVideoNode(IO.ComfyNode): # upload image, if passed in image_url = None if first_frame_image is not None: - image_url = (await upload_images_to_comfyapi(first_frame_image, max_images=1, auth_kwargs=auth))[0] + image_url = (await upload_images_to_comfyapi(cls, first_frame_image, max_images=1))[0] - video_generate_operation = SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/minimax/video_generation", - method=HttpMethod.POST, - request_model=MinimaxVideoGenerationRequest, - response_model=MinimaxVideoGenerationResponse, - ), - request=MinimaxVideoGenerationRequest( + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/minimax/video_generation", method="POST"), + response_model=MinimaxVideoGenerationResponse, + data=MinimaxVideoGenerationRequest( model=MiniMaxModel(model), prompt=prompt_text, callback_url=None, @@ -453,67 +379,42 @@ class MinimaxHailuoVideoNode(IO.ComfyNode): duration=duration, resolution=resolution, ), - auth_kwargs=auth, ) - response = await video_generate_operation.execute() task_id = response.task_id if not task_id: raise Exception(f"MiniMax generation failed: {response.base_resp}") average_duration = 120 if resolution == "768P" else 240 - video_generate_operation = PollingOperation( - poll_endpoint=ApiEndpoint( - path="/proxy/minimax/query/video_generation", - method=HttpMethod.GET, - request_model=EmptyRequest, - response_model=MinimaxTaskResultResponse, - query_params={"task_id": task_id}, - ), - completed_statuses=["Success"], - failed_statuses=["Fail"], + task_result = await poll_op( + cls, + ApiEndpoint(path="/proxy/minimax/query/video_generation", query_params={"task_id": task_id}), + response_model=MinimaxTaskResultResponse, status_extractor=lambda x: x.status.value, estimated_duration=average_duration, - node_id=cls.hidden.unique_id, - auth_kwargs=auth, ) - task_result = await video_generate_operation.execute() file_id = task_result.file_id if file_id is None: raise Exception("Request was not successful. Missing file ID.") - file_retrieve_operation = SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/minimax/files/retrieve", - method=HttpMethod.GET, - request_model=EmptyRequest, - response_model=MinimaxFileRetrieveResponse, - query_params={"file_id": int(file_id)}, - ), - request=EmptyRequest(), - auth_kwargs=auth, + file_result = await sync_op( + cls, + ApiEndpoint(path="/proxy/minimax/files/retrieve", query_params={"file_id": int(file_id)}), + response_model=MinimaxFileRetrieveResponse, ) - file_result = await file_retrieve_operation.execute() file_url = file_result.file.download_url if file_url is None: - raise Exception( - f"No video was found in the response. Full response: {file_result.model_dump()}" - ) - logging.info("Generated video URL: %s", file_url) - if cls.hidden.unique_id: - if hasattr(file_result.file, "backup_download_url"): - message = f"Result URL: {file_url}\nBackup URL: {file_result.file.backup_download_url}" - else: - message = f"Result URL: {file_url}" - PromptServer.instance.send_progress_text(message, cls.hidden.unique_id) + raise Exception(f"No video was found in the response. Full response: {file_result.model_dump()}") - video_io = await download_url_to_bytesio(file_url) - if video_io is None: - error_msg = f"Failed to download video from {file_url}" - logging.error(error_msg) - raise Exception(error_msg) - return IO.NodeOutput(VideoFromFile(video_io)) + if file_result.file.backup_download_url: + try: + return IO.NodeOutput(await download_url_to_video_output(file_url, timeout=10, max_retries=2)) + except Exception: # if we have a second URL to retrieve the result, try again using that one + return IO.NodeOutput( + await download_url_to_video_output(file_result.file.backup_download_url, max_retries=3) + ) + return IO.NodeOutput(await download_url_to_video_output(file_url)) class MinimaxExtension(ComfyExtension): diff --git a/comfy_api_nodes/util/client.py b/comfy_api_nodes/util/client.py index 9c036d64b..9ae512fe5 100644 --- a/comfy_api_nodes/util/client.py +++ b/comfy_api_nodes/util/client.py @@ -78,7 +78,7 @@ class _PollUIState: _RETRY_STATUS = {408, 429, 500, 502, 503, 504} COMPLETED_STATUSES = ["succeeded", "succeed", "success", "completed"] -FAILED_STATUSES = ["cancelled", "canceled", "failed", "error"] +FAILED_STATUSES = ["cancelled", "canceled", "fail", "failed", "error"] QUEUED_STATUSES = ["created", "queued", "queueing", "submitted"] diff --git a/comfy_api_nodes/util/download_helpers.py b/comfy_api_nodes/util/download_helpers.py index f89045e12..364874bed 100644 --- a/comfy_api_nodes/util/download_helpers.py +++ b/comfy_api_nodes/util/download_helpers.py @@ -232,11 +232,12 @@ async def download_url_to_video_output( video_url: str, *, timeout: float = None, + max_retries: int = 5, cls: type[COMFY_IO.ComfyNode] = None, ) -> VideoFromFile: """Downloads a video from a URL and returns a `VIDEO` output.""" result = BytesIO() - await download_url_to_bytesio(video_url, result, timeout=timeout, cls=cls) + await download_url_to_bytesio(video_url, result, timeout=timeout, max_retries=max_retries, cls=cls) return VideoFromFile(result) From 1a58087ac2eb64be3645934d0025aafaa5bdce38 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 29 Oct 2025 12:43:51 -0700 Subject: [PATCH 031/231] Reduce memory usage for fp8 scaled op. (#10531) --- comfy/quant_ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py index b14e03084..fb35a0d40 100644 --- a/comfy/quant_ops.py +++ b/comfy/quant_ops.py @@ -358,7 +358,7 @@ class TensorCoreFP8Layout(QuantizedLayout): scale = scale.to(device=tensor.device, dtype=torch.float32) lp_amax = torch.finfo(dtype).max - tensor_scaled = tensor.float() / scale + tensor_scaled = tensor * (1.0 / scale).to(tensor.dtype) torch.clamp(tensor_scaled, min=-lp_amax, max=lp_amax, out=tensor_scaled) qdata = tensor_scaled.to(dtype, memory_format=torch.contiguous_format) From ec4fc2a09a390d0d81500c51fb9e4d8a7a5ce1fc Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 29 Oct 2025 12:48:06 -0700 Subject: [PATCH 032/231] Fix case of weights not being unpinned. (#10533) --- comfy/model_patcher.py | 1 + 1 file changed, 1 insertion(+) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index aec73349c..119119e96 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -1283,5 +1283,6 @@ class ModelPatcher: self.clear_cached_hook_weights() def __del__(self): + self.unpin_all_weights() self.detach(unpatch_all=False) From ab7ab5be23fb9b71d1790f424e7dcf91dc1fe0cc Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Thu, 30 Oct 2025 07:17:46 +1000 Subject: [PATCH 033/231] Fix Race condition in --async-offload that can cause corruption (#10501) * mm: factor out the current stream getter Make this a reusable function. * ops: sync the offload stream with the consumption of w&b This sync is nessacary as pytorch will queue cuda async frees on the same stream as created to tensor. In the case of async offload, this will be on the offload stream. Weights and biases can go out of scope in python which then triggers the pytorch garbage collector to queue the free operation on the offload stream possible before the compute stream has used the weight. This causes a use after free on weight data leading to total corruption of some workflows. So sync the offload stream with the compute stream after the weight has been used so the free has to wait for the weight to be used. The cast_bias_weight is extended in a backwards compatible way with the new behaviour opt-in on a defaulted parameter. This handles custom node packs calling cast_bias_weight and defeatures async-offload for them (as they do not handle the race). The pattern is now: cast_bias_weight(... , offloadable=True) #This might be offloaded thing(weight, bias, ...) uncast_bias_weight(...) * controlnet: adopt new cast_bias_weight synchronization scheme This is nessacary for safe async weight offloading. * mm: sync the last stream in the queue, not the next Currently this peeks ahead to sync the next stream in the queue of streams with the compute stream. This doesnt allow a lot of parallelization, as then end result is you can only get one weight load ahead regardless of how many streams you have. Rotate the loop logic here to synchronize the end of the queue before returning the next stream. This allows weights to be loaded ahead of the compute streams position. --- comfy/controlnet.py | 17 +++--- comfy/model_management.py | 28 +++++---- comfy/ops.py | 121 ++++++++++++++++++++++++++++---------- 3 files changed, 114 insertions(+), 52 deletions(-) diff --git a/comfy/controlnet.py b/comfy/controlnet.py index f08ff4b36..0b5e30f52 100644 --- a/comfy/controlnet.py +++ b/comfy/controlnet.py @@ -310,11 +310,13 @@ class ControlLoraOps: self.bias = None def forward(self, input): - weight, bias = comfy.ops.cast_bias_weight(self, input) + weight, bias, offload_stream = comfy.ops.cast_bias_weight(self, input, offloadable=True) if self.up is not None: - return torch.nn.functional.linear(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias) + x = torch.nn.functional.linear(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias) else: - return torch.nn.functional.linear(input, weight, bias) + x = torch.nn.functional.linear(input, weight, bias) + comfy.ops.uncast_bias_weight(self, weight, bias, offload_stream) + return x class Conv2d(torch.nn.Module, comfy.ops.CastWeightBiasOp): def __init__( @@ -350,12 +352,13 @@ class ControlLoraOps: def forward(self, input): - weight, bias = comfy.ops.cast_bias_weight(self, input) + weight, bias, offload_stream = comfy.ops.cast_bias_weight(self, input, offloadable=True) if self.up is not None: - return torch.nn.functional.conv2d(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias, self.stride, self.padding, self.dilation, self.groups) + x = torch.nn.functional.conv2d(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias, self.stride, self.padding, self.dilation, self.groups) else: - return torch.nn.functional.conv2d(input, weight, bias, self.stride, self.padding, self.dilation, self.groups) - + x = torch.nn.functional.conv2d(input, weight, bias, self.stride, self.padding, self.dilation, self.groups) + comfy.ops.uncast_bias_weight(self, weight, bias, offload_stream) + return x class ControlLora(ControlNet): def __init__(self, control_weights, global_average_pooling=False, model_options={}): #TODO? model_options diff --git a/comfy/model_management.py b/comfy/model_management.py index 3e5b977d4..79c0dfdb4 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -1013,6 +1013,16 @@ if args.async_offload: NUM_STREAMS = 2 logging.info("Using async weight offloading with {} streams".format(NUM_STREAMS)) +def current_stream(device): + if device is None: + return None + if is_device_cuda(device): + return torch.cuda.current_stream() + elif is_device_xpu(device): + return torch.xpu.current_stream() + else: + return None + stream_counters = {} def get_offload_stream(device): stream_counter = stream_counters.get(device, 0) @@ -1021,21 +1031,17 @@ def get_offload_stream(device): if device in STREAMS: ss = STREAMS[device] - s = ss[stream_counter] + #Sync the oldest stream in the queue with the current + ss[stream_counter].wait_stream(current_stream(device)) stream_counter = (stream_counter + 1) % len(ss) - if is_device_cuda(device): - ss[stream_counter].wait_stream(torch.cuda.current_stream()) - elif is_device_xpu(device): - ss[stream_counter].wait_stream(torch.xpu.current_stream()) stream_counters[device] = stream_counter - return s + return ss[stream_counter] elif is_device_cuda(device): ss = [] for k in range(NUM_STREAMS): ss.append(torch.cuda.Stream(device=device, priority=0)) STREAMS[device] = ss s = ss[stream_counter] - stream_counter = (stream_counter + 1) % len(ss) stream_counters[device] = stream_counter return s elif is_device_xpu(device): @@ -1044,18 +1050,14 @@ def get_offload_stream(device): ss.append(torch.xpu.Stream(device=device, priority=0)) STREAMS[device] = ss s = ss[stream_counter] - stream_counter = (stream_counter + 1) % len(ss) stream_counters[device] = stream_counter return s return None def sync_stream(device, stream): - if stream is None: + if stream is None or current_stream(device) is None: return - if is_device_cuda(device): - torch.cuda.current_stream().wait_stream(stream) - elif is_device_xpu(device): - torch.xpu.current_stream().wait_stream(stream) + current_stream(device).wait_stream(stream) def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None): if device is None or weight.device == device: diff --git a/comfy/ops.py b/comfy/ops.py index 93731eedf..71ca7a2bd 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -70,8 +70,12 @@ cast_to = comfy.model_management.cast_to #TODO: remove once no more references def cast_to_input(weight, input, non_blocking=False, copy=True): return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy) + @torch.compiler.disable() -def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None): +def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, offloadable=False): + # NOTE: offloadable=False is a a legacy and if you are a custom node author reading this please pass + # offloadable=True and call uncast_bias_weight() after your last usage of the weight/bias. This + # will add async-offload support to your cast and improve performance. if input is not None: if dtype is None: dtype = input.dtype @@ -80,7 +84,11 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None): if device is None: device = input.device - offload_stream = comfy.model_management.get_offload_stream(device) + if offloadable: + offload_stream = comfy.model_management.get_offload_stream(device) + else: + offload_stream = None + if offload_stream is not None: wf_context = offload_stream else: @@ -105,7 +113,24 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None): weight = f(weight) comfy.model_management.sync_stream(device, offload_stream) - return weight, bias + if offloadable: + return weight, bias, offload_stream + else: + #Legacy function signature + return weight, bias + + +def uncast_bias_weight(s, weight, bias, offload_stream): + if offload_stream is None: + return + if weight is not None: + device = weight.device + else: + if bias is None: + return + device = bias.device + offload_stream.wait_stream(comfy.model_management.current_stream(device)) + class CastWeightBiasOp: comfy_cast_weights = False @@ -118,8 +143,10 @@ class disable_weight_init: return None def forward_comfy_cast_weights(self, input): - weight, bias = cast_bias_weight(self, input) - return torch.nn.functional.linear(input, weight, bias) + weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True) + x = torch.nn.functional.linear(input, weight, bias) + uncast_bias_weight(self, weight, bias, offload_stream) + return x def forward(self, *args, **kwargs): run_every_op() @@ -133,8 +160,10 @@ class disable_weight_init: return None def forward_comfy_cast_weights(self, input): - weight, bias = cast_bias_weight(self, input) - return self._conv_forward(input, weight, bias) + weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True) + x = self._conv_forward(input, weight, bias) + uncast_bias_weight(self, weight, bias, offload_stream) + return x def forward(self, *args, **kwargs): run_every_op() @@ -148,8 +177,10 @@ class disable_weight_init: return None def forward_comfy_cast_weights(self, input): - weight, bias = cast_bias_weight(self, input) - return self._conv_forward(input, weight, bias) + weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True) + x = self._conv_forward(input, weight, bias) + uncast_bias_weight(self, weight, bias, offload_stream) + return x def forward(self, *args, **kwargs): run_every_op() @@ -172,8 +203,10 @@ class disable_weight_init: return super()._conv_forward(input, weight, bias, *args, **kwargs) def forward_comfy_cast_weights(self, input): - weight, bias = cast_bias_weight(self, input) - return self._conv_forward(input, weight, bias) + weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True) + x = self._conv_forward(input, weight, bias) + uncast_bias_weight(self, weight, bias, offload_stream) + return x def forward(self, *args, **kwargs): run_every_op() @@ -187,8 +220,10 @@ class disable_weight_init: return None def forward_comfy_cast_weights(self, input): - weight, bias = cast_bias_weight(self, input) - return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps) + weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True) + x = torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps) + uncast_bias_weight(self, weight, bias, offload_stream) + return x def forward(self, *args, **kwargs): run_every_op() @@ -203,11 +238,14 @@ class disable_weight_init: def forward_comfy_cast_weights(self, input): if self.weight is not None: - weight, bias = cast_bias_weight(self, input) + weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True) else: weight = None bias = None - return torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps) + offload_stream = None + x = torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps) + uncast_bias_weight(self, weight, bias, offload_stream) + return x def forward(self, *args, **kwargs): run_every_op() @@ -223,11 +261,15 @@ class disable_weight_init: def forward_comfy_cast_weights(self, input): if self.weight is not None: - weight, bias = cast_bias_weight(self, input) + weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True) else: weight = None - return comfy.rmsnorm.rms_norm(input, weight, self.eps) # TODO: switch to commented out line when old torch is deprecated - # return torch.nn.functional.rms_norm(input, self.normalized_shape, weight, self.eps) + bias = None + offload_stream = None + x = comfy.rmsnorm.rms_norm(input, weight, self.eps) # TODO: switch to commented out line when old torch is deprecated + # x = torch.nn.functional.rms_norm(input, self.normalized_shape, weight, self.eps) + uncast_bias_weight(self, weight, bias, offload_stream) + return x def forward(self, *args, **kwargs): run_every_op() @@ -246,10 +288,12 @@ class disable_weight_init: input, output_size, self.stride, self.padding, self.kernel_size, num_spatial_dims, self.dilation) - weight, bias = cast_bias_weight(self, input) - return torch.nn.functional.conv_transpose2d( + weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True) + x = torch.nn.functional.conv_transpose2d( input, weight, bias, self.stride, self.padding, output_padding, self.groups, self.dilation) + uncast_bias_weight(self, weight, bias, offload_stream) + return x def forward(self, *args, **kwargs): run_every_op() @@ -268,10 +312,12 @@ class disable_weight_init: input, output_size, self.stride, self.padding, self.kernel_size, num_spatial_dims, self.dilation) - weight, bias = cast_bias_weight(self, input) - return torch.nn.functional.conv_transpose1d( + weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True) + x = torch.nn.functional.conv_transpose1d( input, weight, bias, self.stride, self.padding, output_padding, self.groups, self.dilation) + uncast_bias_weight(self, weight, bias, offload_stream) + return x def forward(self, *args, **kwargs): run_every_op() @@ -289,8 +335,11 @@ class disable_weight_init: output_dtype = out_dtype if self.weight.dtype == torch.float16 or self.weight.dtype == torch.bfloat16: out_dtype = None - weight, bias = cast_bias_weight(self, device=input.device, dtype=out_dtype) - return torch.nn.functional.embedding(input, weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse).to(dtype=output_dtype) + weight, bias, offload_stream = cast_bias_weight(self, device=input.device, dtype=out_dtype, offloadable=True) + x = torch.nn.functional.embedding(input, weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse).to(dtype=output_dtype) + uncast_bias_weight(self, weight, bias, offload_stream) + return x + def forward(self, *args, **kwargs): run_every_op() @@ -361,7 +410,7 @@ def fp8_linear(self, input): input_dtype = input.dtype if len(input.shape) == 3: - w, bias = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype) + w, bias, offload_stream = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype, offloadable=True) scale_weight = self.scale_weight scale_input = self.scale_input @@ -382,6 +431,8 @@ def fp8_linear(self, input): quantized_input = QuantizedTensor.from_float(input.reshape(-1, input_shape[2]), TensorCoreFP8Layout, scale=scale_input, dtype=dtype) o = torch.nn.functional.linear(quantized_input, quantized_weight, bias) + uncast_bias_weight(self, w, bias, offload_stream) + if tensor_2d: return o.reshape(input_shape[0], -1) return o.reshape((-1, input_shape[1], self.weight.shape[0])) @@ -404,8 +455,10 @@ class fp8_ops(manual_cast): except Exception as e: logging.info("Exception during fp8 op: {}".format(e)) - weight, bias = cast_bias_weight(self, input) - return torch.nn.functional.linear(input, weight, bias) + weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True) + x = torch.nn.functional.linear(input, weight, bias) + uncast_bias_weight(self, weight, bias, offload_stream) + return x def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None): logging.info("Using scaled fp8: fp8 matrix mult: {}, scale input: {}".format(fp8_matrix_mult, scale_input)) @@ -433,12 +486,14 @@ def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None if out is not None: return out - weight, bias = cast_bias_weight(self, input) + weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True) if weight.numel() < input.numel(): #TODO: optimize - return torch.nn.functional.linear(input, weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype), bias) + x = torch.nn.functional.linear(input, weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype), bias) else: - return torch.nn.functional.linear(input * self.scale_weight.to(device=weight.device, dtype=weight.dtype), weight, bias) + x = torch.nn.functional.linear(input * self.scale_weight.to(device=weight.device, dtype=weight.dtype), weight, bias) + uncast_bias_weight(self, weight, bias, offload_stream) + return x def convert_weight(self, weight, inplace=False, **kwargs): if inplace: @@ -577,8 +632,10 @@ class MixedPrecisionOps(disable_weight_init): return torch.nn.functional.linear(input, weight, bias) def forward_comfy_cast_weights(self, input): - weight, bias = cast_bias_weight(self, input) - return self._forward(input, weight, bias) + weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True) + x = self._forward(input, weight, bias) + uncast_bias_weight(self, weight, bias, offload_stream) + return x def forward(self, input, *args, **kwargs): run_every_op() From 25de7b1bfa22dd98922f047a1342cc97f8e46c5b Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 29 Oct 2025 14:20:27 -0700 Subject: [PATCH 034/231] Try to fix slow load issue on low ram hardware with pinned mem. (#10536) --- comfy/model_patcher.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 119119e96..74b9e48bc 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -658,6 +658,7 @@ class ModelPatcher: loading = self._load_list() load_completely = [] + offloaded = [] loading.sort(reverse=True) for x in loading: n = x[1] @@ -699,8 +700,7 @@ class ModelPatcher: patch_counter += 1 cast_weight = True - for param in params: - self.pin_weight_to_device("{}.{}".format(n, param)) + offloaded.append((module_mem, n, m, params)) else: if hasattr(m, "comfy_cast_weights"): wipe_lowvram_weight(m) @@ -741,6 +741,12 @@ class ModelPatcher: for x in load_completely: x[2].to(device_to) + for x in offloaded: + n = x[1] + params = x[3] + for param in params: + self.pin_weight_to_device("{}.{}".format(n, param)) + if lowvram_counter > 0: logging.info("loaded partially {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), patch_counter)) self.model.model_lowvram = True From 906c0899575a83ac69bb095e835fdec748891da4 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 29 Oct 2025 16:29:01 -0700 Subject: [PATCH 035/231] Fix small performance regression with fp8 fast and scaled fp8. (#10537) --- comfy/ops.py | 6 +++++- comfy/quant_ops.py | 5 +++-- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/comfy/ops.py b/comfy/ops.py index 71ca7a2bd..18f6b804b 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -421,14 +421,18 @@ def fp8_linear(self, input): if scale_input is None: scale_input = torch.ones((), device=input.device, dtype=torch.float32) + input = torch.clamp(input, min=-448, max=448, out=input) + input = input.reshape(-1, input_shape[2]).to(dtype).contiguous() + layout_params_weight = {'scale': scale_input, 'orig_dtype': input_dtype} + quantized_input = QuantizedTensor(input.reshape(-1, input_shape[2]).to(dtype).contiguous(), TensorCoreFP8Layout, layout_params_weight) else: scale_input = scale_input.to(input.device) + quantized_input = QuantizedTensor.from_float(input.reshape(-1, input_shape[2]), TensorCoreFP8Layout, scale=scale_input, dtype=dtype) # Wrap weight in QuantizedTensor - this enables unified dispatch # Call F.linear - __torch_dispatch__ routes to fp8_linear handler in quant_ops.py! layout_params_weight = {'scale': scale_weight, 'orig_dtype': input_dtype} quantized_weight = QuantizedTensor(w, TensorCoreFP8Layout, layout_params_weight) - quantized_input = QuantizedTensor.from_float(input.reshape(-1, input_shape[2]), TensorCoreFP8Layout, scale=scale_input, dtype=dtype) o = torch.nn.functional.linear(quantized_input, quantized_weight, bias) uncast_bias_weight(self, w, bias, offload_stream) diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py index fb35a0d40..c822fe53c 100644 --- a/comfy/quant_ops.py +++ b/comfy/quant_ops.py @@ -357,9 +357,10 @@ class TensorCoreFP8Layout(QuantizedLayout): scale = torch.tensor(scale) scale = scale.to(device=tensor.device, dtype=torch.float32) - lp_amax = torch.finfo(dtype).max tensor_scaled = tensor * (1.0 / scale).to(tensor.dtype) - torch.clamp(tensor_scaled, min=-lp_amax, max=lp_amax, out=tensor_scaled) + # TODO: uncomment this if it's actually needed because the clamp has a small performance penality' + # lp_amax = torch.finfo(dtype).max + # torch.clamp(tensor_scaled, min=-lp_amax, max=lp_amax, out=tensor_scaled) qdata = tensor_scaled.to(dtype, memory_format=torch.contiguous_format) layout_params = { From 998bf60bebd03e57a55e106434657849342b733f Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Wed, 29 Oct 2025 16:37:06 -0700 Subject: [PATCH 036/231] Add units/info for the numbers displayed on 'load completely' and 'load partially' log messages (#10538) --- comfy/model_patcher.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 74b9e48bc..ed3f3f5cb 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -655,6 +655,7 @@ class ModelPatcher: mem_counter = 0 patch_counter = 0 lowvram_counter = 0 + lowvram_mem_counter = 0 loading = self._load_list() load_completely = [] @@ -675,6 +676,7 @@ class ModelPatcher: if mem_counter + module_mem >= lowvram_model_memory: lowvram_weight = True lowvram_counter += 1 + lowvram_mem_counter += module_mem if hasattr(m, "prev_comfy_cast_weights"): #Already lowvramed continue @@ -748,10 +750,10 @@ class ModelPatcher: self.pin_weight_to_device("{}.{}".format(n, param)) if lowvram_counter > 0: - logging.info("loaded partially {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), patch_counter)) + logging.info("loaded partially; {:.2f} MB usable, {:.2f} MB loaded, {:.2f} MB offloaded, lowvram patches: {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), lowvram_mem_counter / (1024 * 1024), patch_counter)) self.model.model_lowvram = True else: - logging.info("loaded completely {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), full_load)) + logging.info("loaded completely; {:.2f} MB usable, {:.2f} MB loaded, full load: {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), full_load)) self.model.model_lowvram = False if full_load: self.model.to(device_to) From 163b629c70a349c7d1e91eebc5365713e770af8a Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Thu, 30 Oct 2025 08:49:03 +0200 Subject: [PATCH 037/231] use new API client in Pixverse and Ideogram nodes (#10543) --- comfy_api_nodes/apinode_utils.py | 109 +------------ comfy_api_nodes/nodes_bfl.py | 43 +---- comfy_api_nodes/nodes_bytedance.py | 12 +- comfy_api_nodes/nodes_ideogram.py | 137 ++++------------ comfy_api_nodes/nodes_kling.py | 2 +- comfy_api_nodes/nodes_pixverse.py | 194 ++++++----------------- comfy_api_nodes/nodes_recraft.py | 4 +- comfy_api_nodes/nodes_runway.py | 10 +- comfy_api_nodes/nodes_vidu.py | 12 +- comfy_api_nodes/util/__init__.py | 10 +- comfy_api_nodes/util/conversions.py | 21 +++ comfy_api_nodes/util/validation_utils.py | 125 ++++++++++----- 12 files changed, 220 insertions(+), 459 deletions(-) diff --git a/comfy_api_nodes/apinode_utils.py b/comfy_api_nodes/apinode_utils.py index 6a72b9d1d..ecd604ff8 100644 --- a/comfy_api_nodes/apinode_utils.py +++ b/comfy_api_nodes/apinode_utils.py @@ -1,15 +1,12 @@ from __future__ import annotations import aiohttp import mimetypes -from typing import Optional, Union -from comfy.utils import common_upscale +from typing import Union from server import PromptServer -from comfy.cli_args import args import numpy as np from PIL import Image import torch -import math import base64 from io import BytesIO @@ -60,85 +57,6 @@ async def validate_and_cast_response( return torch.stack(image_tensors, dim=0) -def validate_aspect_ratio( - aspect_ratio: str, - minimum_ratio: float, - maximum_ratio: float, - minimum_ratio_str: str, - maximum_ratio_str: str, -) -> float: - """Validates and casts an aspect ratio string to a float. - - Args: - aspect_ratio: The aspect ratio string to validate. - minimum_ratio: The minimum aspect ratio. - maximum_ratio: The maximum aspect ratio. - minimum_ratio_str: The minimum aspect ratio string. - maximum_ratio_str: The maximum aspect ratio string. - - Returns: - The validated and cast aspect ratio. - - Raises: - Exception: If the aspect ratio is not valid. - """ - # get ratio values - numbers = aspect_ratio.split(":") - if len(numbers) != 2: - raise TypeError( - f"Aspect ratio must be in the format X:Y, such as 16:9, but was {aspect_ratio}." - ) - try: - numerator = int(numbers[0]) - denominator = int(numbers[1]) - except ValueError as exc: - raise TypeError( - f"Aspect ratio must contain numbers separated by ':', such as 16:9, but was {aspect_ratio}." - ) from exc - calculated_ratio = numerator / denominator - # if not close to minimum and maximum, check bounds - if not math.isclose(calculated_ratio, minimum_ratio) or not math.isclose( - calculated_ratio, maximum_ratio - ): - if calculated_ratio < minimum_ratio: - raise TypeError( - f"Aspect ratio cannot reduce to any less than {minimum_ratio_str} ({minimum_ratio}), but was {aspect_ratio} ({calculated_ratio})." - ) - if calculated_ratio > maximum_ratio: - raise TypeError( - f"Aspect ratio cannot reduce to any greater than {maximum_ratio_str} ({maximum_ratio}), but was {aspect_ratio} ({calculated_ratio})." - ) - return aspect_ratio - - -async def download_url_to_bytesio( - url: str, timeout: int = None, auth_kwargs: Optional[dict[str, str]] = None -) -> BytesIO: - """Downloads content from a URL using requests and returns it as BytesIO. - - Args: - url: The URL to download. - timeout: Request timeout in seconds. Defaults to None (no timeout). - - Returns: - BytesIO object containing the downloaded content. - """ - headers = {} - if url.startswith("/proxy/"): - url = str(args.comfy_api_base).rstrip("/") + url - auth_token = auth_kwargs.get("auth_token") - comfy_api_key = auth_kwargs.get("comfy_api_key") - if auth_token: - headers["Authorization"] = f"Bearer {auth_token}" - elif comfy_api_key: - headers["X-API-KEY"] = comfy_api_key - timeout_cfg = aiohttp.ClientTimeout(total=timeout) if timeout else None - async with aiohttp.ClientSession(timeout=timeout_cfg) as session: - async with session.get(url, headers=headers) as resp: - resp.raise_for_status() # Raises HTTPError for bad responses (4XX or 5XX) - return BytesIO(await resp.read()) - - def text_filepath_to_base64_string(filepath: str) -> str: """Converts a text file to a base64 string.""" with open(filepath, "rb") as f: @@ -153,28 +71,3 @@ def text_filepath_to_data_uri(filepath: str) -> str: if mime_type is None: mime_type = "application/octet-stream" return f"data:{mime_type};base64,{base64_string}" - - -def resize_mask_to_image( - mask: torch.Tensor, - image: torch.Tensor, - upscale_method="nearest-exact", - crop="disabled", - allow_gradient=True, - add_channel_dim=False, -): - """ - Resize mask to be the same dimensions as an image, while maintaining proper format for API calls. - """ - _, H, W, _ = image.shape - mask = mask.unsqueeze(-1) - mask = mask.movedim(-1, 1) - mask = common_upscale( - mask, width=W, height=H, upscale_method=upscale_method, crop=crop - ) - mask = mask.movedim(1, -1) - if not add_channel_dim: - mask = mask.squeeze(-1) - if not allow_gradient: - mask = (mask > 0.5).float() - return mask diff --git a/comfy_api_nodes/nodes_bfl.py b/comfy_api_nodes/nodes_bfl.py index baa74fd52..1740fb377 100644 --- a/comfy_api_nodes/nodes_bfl.py +++ b/comfy_api_nodes/nodes_bfl.py @@ -5,10 +5,6 @@ import torch from typing_extensions import override from comfy_api.latest import IO, ComfyExtension -from comfy_api_nodes.apinode_utils import ( - resize_mask_to_image, - validate_aspect_ratio, -) from comfy_api_nodes.apis.bfl_api import ( BFLFluxExpandImageRequest, BFLFluxFillImageRequest, @@ -23,8 +19,10 @@ from comfy_api_nodes.util import ( ApiEndpoint, download_url_to_image_tensor, poll_op, + resize_mask_to_image, sync_op, tensor_to_base64_string, + validate_aspect_ratio_string, validate_string, ) @@ -43,11 +41,6 @@ class FluxProUltraImageNode(IO.ComfyNode): Generates images using Flux Pro 1.1 Ultra via api based on prompt and resolution. """ - MINIMUM_RATIO = 1 / 4 - MAXIMUM_RATIO = 4 / 1 - MINIMUM_RATIO_STR = "1:4" - MAXIMUM_RATIO_STR = "4:1" - @classmethod def define_schema(cls) -> IO.Schema: return IO.Schema( @@ -112,16 +105,7 @@ class FluxProUltraImageNode(IO.ComfyNode): @classmethod def validate_inputs(cls, aspect_ratio: str): - try: - validate_aspect_ratio( - aspect_ratio, - minimum_ratio=cls.MINIMUM_RATIO, - maximum_ratio=cls.MAXIMUM_RATIO, - minimum_ratio_str=cls.MINIMUM_RATIO_STR, - maximum_ratio_str=cls.MAXIMUM_RATIO_STR, - ) - except Exception as e: - return str(e) + validate_aspect_ratio_string(aspect_ratio, (1, 4), (4, 1)) return True @classmethod @@ -145,13 +129,7 @@ class FluxProUltraImageNode(IO.ComfyNode): prompt=prompt, prompt_upsampling=prompt_upsampling, seed=seed, - aspect_ratio=validate_aspect_ratio( - aspect_ratio, - minimum_ratio=cls.MINIMUM_RATIO, - maximum_ratio=cls.MAXIMUM_RATIO, - minimum_ratio_str=cls.MINIMUM_RATIO_STR, - maximum_ratio_str=cls.MAXIMUM_RATIO_STR, - ), + aspect_ratio=aspect_ratio, raw=raw, image_prompt=(image_prompt if image_prompt is None else tensor_to_base64_string(image_prompt)), image_prompt_strength=(None if image_prompt is None else round(image_prompt_strength, 2)), @@ -180,11 +158,6 @@ class FluxKontextProImageNode(IO.ComfyNode): Edits images using Flux.1 Kontext [pro] via api based on prompt and aspect ratio. """ - MINIMUM_RATIO = 1 / 4 - MAXIMUM_RATIO = 4 / 1 - MINIMUM_RATIO_STR = "1:4" - MAXIMUM_RATIO_STR = "4:1" - @classmethod def define_schema(cls) -> IO.Schema: return IO.Schema( @@ -261,13 +234,7 @@ class FluxKontextProImageNode(IO.ComfyNode): seed=0, prompt_upsampling=False, ) -> IO.NodeOutput: - aspect_ratio = validate_aspect_ratio( - aspect_ratio, - minimum_ratio=cls.MINIMUM_RATIO, - maximum_ratio=cls.MAXIMUM_RATIO, - minimum_ratio_str=cls.MINIMUM_RATIO_STR, - maximum_ratio_str=cls.MAXIMUM_RATIO_STR, - ) + validate_aspect_ratio_string(aspect_ratio, (1, 4), (4, 1)) if input_image is None: validate_string(prompt, strip_whitespace=False) initial_response = await sync_op( diff --git a/comfy_api_nodes/nodes_bytedance.py b/comfy_api_nodes/nodes_bytedance.py index 534af380d..caced471e 100644 --- a/comfy_api_nodes/nodes_bytedance.py +++ b/comfy_api_nodes/nodes_bytedance.py @@ -17,7 +17,7 @@ from comfy_api_nodes.util import ( poll_op, sync_op, upload_images_to_comfyapi, - validate_image_aspect_ratio_range, + validate_image_aspect_ratio, validate_image_dimensions, validate_string, ) @@ -403,7 +403,7 @@ class ByteDanceImageEditNode(IO.ComfyNode): validate_string(prompt, strip_whitespace=True, min_length=1) if get_number_of_images(image) != 1: raise ValueError("Exactly one input image is required.") - validate_image_aspect_ratio_range(image, (1, 3), (3, 1)) + validate_image_aspect_ratio(image, (1, 3), (3, 1)) source_url = (await upload_images_to_comfyapi(cls, image, max_images=1, mime_type="image/png"))[0] payload = Image2ImageTaskCreationRequest( model=model, @@ -565,7 +565,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode): reference_images_urls = [] if n_input_images: for i in image: - validate_image_aspect_ratio_range(i, (1, 3), (3, 1)) + validate_image_aspect_ratio(i, (1, 3), (3, 1)) reference_images_urls = await upload_images_to_comfyapi( cls, image, @@ -798,7 +798,7 @@ class ByteDanceImageToVideoNode(IO.ComfyNode): validate_string(prompt, strip_whitespace=True, min_length=1) raise_if_text_params(prompt, ["resolution", "ratio", "duration", "seed", "camerafixed", "watermark"]) validate_image_dimensions(image, min_width=300, min_height=300, max_width=6000, max_height=6000) - validate_image_aspect_ratio_range(image, (2, 5), (5, 2), strict=False) # 0.4 to 2.5 + validate_image_aspect_ratio(image, (2, 5), (5, 2), strict=False) # 0.4 to 2.5 image_url = (await upload_images_to_comfyapi(cls, image, max_images=1))[0] prompt = ( @@ -923,7 +923,7 @@ class ByteDanceFirstLastFrameNode(IO.ComfyNode): raise_if_text_params(prompt, ["resolution", "ratio", "duration", "seed", "camerafixed", "watermark"]) for i in (first_frame, last_frame): validate_image_dimensions(i, min_width=300, min_height=300, max_width=6000, max_height=6000) - validate_image_aspect_ratio_range(i, (2, 5), (5, 2), strict=False) # 0.4 to 2.5 + validate_image_aspect_ratio(i, (2, 5), (5, 2), strict=False) # 0.4 to 2.5 download_urls = await upload_images_to_comfyapi( cls, @@ -1045,7 +1045,7 @@ class ByteDanceImageReferenceNode(IO.ComfyNode): raise_if_text_params(prompt, ["resolution", "ratio", "duration", "seed", "watermark"]) for image in images: validate_image_dimensions(image, min_width=300, min_height=300, max_width=6000, max_height=6000) - validate_image_aspect_ratio_range(image, (2, 5), (5, 2), strict=False) # 0.4 to 2.5 + validate_image_aspect_ratio(image, (2, 5), (5, 2), strict=False) # 0.4 to 2.5 image_urls = await upload_images_to_comfyapi(cls, images, max_images=4, mime_type="image/png") prompt = ( diff --git a/comfy_api_nodes/nodes_ideogram.py b/comfy_api_nodes/nodes_ideogram.py index d8fd3378b..48f94e612 100644 --- a/comfy_api_nodes/nodes_ideogram.py +++ b/comfy_api_nodes/nodes_ideogram.py @@ -1,6 +1,6 @@ from io import BytesIO from typing_extensions import override -from comfy_api.latest import ComfyExtension, IO +from comfy_api.latest import IO, ComfyExtension from PIL import Image import numpy as np import torch @@ -11,19 +11,13 @@ from comfy_api_nodes.apis import ( IdeogramV3Request, IdeogramV3EditRequest, ) - -from comfy_api_nodes.apis.client import ( +from comfy_api_nodes.util import ( ApiEndpoint, - HttpMethod, - SynchronousOperation, -) - -from comfy_api_nodes.apinode_utils import ( - download_url_to_bytesio, + bytesio_to_image_tensor, + download_url_as_bytesio, resize_mask_to_image, + sync_op, ) -from comfy_api_nodes.util import bytesio_to_image_tensor -from server import PromptServer V1_V1_RES_MAP = { "Auto":"AUTO", @@ -220,7 +214,7 @@ async def download_and_process_images(image_urls): for image_url in image_urls: # Using functions from apinode_utils.py to handle downloading and processing - image_bytesio = await download_url_to_bytesio(image_url) # Download image content to BytesIO + image_bytesio = await download_url_as_bytesio(image_url) # Download image content to BytesIO img_tensor = bytesio_to_image_tensor(image_bytesio, mode="RGB") # Convert to torch.Tensor with RGB mode image_tensors.append(img_tensor) @@ -233,19 +227,6 @@ async def download_and_process_images(image_urls): return stacked_tensors -def display_image_urls_on_node(image_urls, node_id): - if node_id and image_urls: - if len(image_urls) == 1: - PromptServer.instance.send_progress_text( - f"Generated Image URL:\n{image_urls[0]}", node_id - ) - else: - urls_text = "Generated Image URLs:\n" + "\n".join( - f"{i+1}. {url}" for i, url in enumerate(image_urls) - ) - PromptServer.instance.send_progress_text(urls_text, node_id) - - class IdeogramV1(IO.ComfyNode): @classmethod @@ -334,44 +315,30 @@ class IdeogramV1(IO.ComfyNode): aspect_ratio = V1_V2_RATIO_MAP.get(aspect_ratio, None) model = "V_1_TURBO" if turbo else "V_1" - auth = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/ideogram/generate", - method=HttpMethod.POST, - request_model=IdeogramGenerateRequest, - response_model=IdeogramGenerateResponse, - ), - request=IdeogramGenerateRequest( + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/ideogram/generate", method="POST"), + response_model=IdeogramGenerateResponse, + data=IdeogramGenerateRequest( image_request=ImageRequest( prompt=prompt, model=model, num_images=num_images, seed=seed, aspect_ratio=aspect_ratio if aspect_ratio != "ASPECT_1_1" else None, - magic_prompt_option=( - magic_prompt_option if magic_prompt_option != "AUTO" else None - ), + magic_prompt_option=(magic_prompt_option if magic_prompt_option != "AUTO" else None), negative_prompt=negative_prompt if negative_prompt else None, ) ), - auth_kwargs=auth, + max_retries=1, ) - response = await operation.execute() - if not response.data or len(response.data) == 0: raise Exception("No images were generated in the response") image_urls = [image_data.url for image_data in response.data if image_data.url] - if not image_urls: raise Exception("No image URLs were generated in the response") - - display_image_urls_on_node(image_urls, cls.hidden.unique_id) return IO.NodeOutput(await download_and_process_images(image_urls)) @@ -500,18 +467,11 @@ class IdeogramV2(IO.ComfyNode): else: final_aspect_ratio = aspect_ratio if aspect_ratio != "ASPECT_1_1" else None - auth = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/ideogram/generate", - method=HttpMethod.POST, - request_model=IdeogramGenerateRequest, - response_model=IdeogramGenerateResponse, - ), - request=IdeogramGenerateRequest( + response = await sync_op( + cls, + endpoint=ApiEndpoint(path="/proxy/ideogram/generate", method="POST"), + response_model=IdeogramGenerateResponse, + data=IdeogramGenerateRequest( image_request=ImageRequest( prompt=prompt, model=model, @@ -519,28 +479,20 @@ class IdeogramV2(IO.ComfyNode): seed=seed, aspect_ratio=final_aspect_ratio, resolution=final_resolution, - magic_prompt_option=( - magic_prompt_option if magic_prompt_option != "AUTO" else None - ), + magic_prompt_option=(magic_prompt_option if magic_prompt_option != "AUTO" else None), style_type=style_type if style_type != "NONE" else None, negative_prompt=negative_prompt if negative_prompt else None, color_palette=color_palette if color_palette else None, ) ), - auth_kwargs=auth, + max_retries=1, ) - - response = await operation.execute() - if not response.data or len(response.data) == 0: raise Exception("No images were generated in the response") image_urls = [image_data.url for image_data in response.data if image_data.url] - if not image_urls: raise Exception("No image URLs were generated in the response") - - display_image_urls_on_node(image_urls, cls.hidden.unique_id) return IO.NodeOutput(await download_and_process_images(image_urls)) @@ -656,10 +608,6 @@ class IdeogramV3(IO.ComfyNode): character_image=None, character_mask=None, ): - auth = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } if rendering_speed == "BALANCED": # for backward compatibility rendering_speed = "DEFAULT" @@ -694,9 +642,6 @@ class IdeogramV3(IO.ComfyNode): # Check if both image and mask are provided for editing mode if image is not None and mask is not None: - # Edit mode - path = "/proxy/ideogram/ideogram-v3/edit" - # Process image and mask input_tensor = image.squeeze().cpu() # Resize mask to match image dimension @@ -749,27 +694,20 @@ class IdeogramV3(IO.ComfyNode): if character_mask_binary: files["character_mask_binary"] = character_mask_binary - # Execute the operation for edit mode - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path=path, - method=HttpMethod.POST, - request_model=IdeogramV3EditRequest, - response_model=IdeogramGenerateResponse, - ), - request=edit_request, + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/ideogram/ideogram-v3/edit", method="POST"), + response_model=IdeogramGenerateResponse, + data=edit_request, files=files, content_type="multipart/form-data", - auth_kwargs=auth, + max_retries=1, ) elif image is not None or mask is not None: # If only one of image or mask is provided, raise an error raise Exception("Ideogram V3 image editing requires both an image AND a mask") else: - # Generation mode - path = "/proxy/ideogram/ideogram-v3/generate" - # Create generation request gen_request = IdeogramV3Request( prompt=prompt, @@ -800,32 +738,22 @@ class IdeogramV3(IO.ComfyNode): if files: gen_request.style_type = "AUTO" - # Execute the operation for generation mode - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path=path, - method=HttpMethod.POST, - request_model=IdeogramV3Request, - response_model=IdeogramGenerateResponse, - ), - request=gen_request, + response = await sync_op( + cls, + endpoint=ApiEndpoint(path="/proxy/ideogram/ideogram-v3/generate", method="POST"), + response_model=IdeogramGenerateResponse, + data=gen_request, files=files if files else None, content_type="multipart/form-data", - auth_kwargs=auth, + max_retries=1, ) - # Execute the operation and process response - response = await operation.execute() - if not response.data or len(response.data) == 0: raise Exception("No images were generated in the response") image_urls = [image_data.url for image_data in response.data if image_data.url] - if not image_urls: raise Exception("No image URLs were generated in the response") - - display_image_urls_on_node(image_urls, cls.hidden.unique_id) return IO.NodeOutput(await download_and_process_images(image_urls)) @@ -838,5 +766,6 @@ class IdeogramExtension(ComfyExtension): IdeogramV3, ] + async def comfy_entrypoint() -> IdeogramExtension: return IdeogramExtension() diff --git a/comfy_api_nodes/nodes_kling.py b/comfy_api_nodes/nodes_kling.py index eea65c9ac..7b23e9cf9 100644 --- a/comfy_api_nodes/nodes_kling.py +++ b/comfy_api_nodes/nodes_kling.py @@ -282,7 +282,7 @@ def validate_input_image(image: torch.Tensor) -> None: See: https://app.klingai.com/global/dev/document-api/apiReference/model/imageToVideo """ validate_image_dimensions(image, min_width=300, min_height=300) - validate_image_aspect_ratio(image, min_aspect_ratio=1 / 2.5, max_aspect_ratio=2.5) + validate_image_aspect_ratio(image, (1, 2.5), (2.5, 1)) def get_video_from_response(response) -> KlingVideoResult: diff --git a/comfy_api_nodes/nodes_pixverse.py b/comfy_api_nodes/nodes_pixverse.py index b2b841be8..6e1686af0 100644 --- a/comfy_api_nodes/nodes_pixverse.py +++ b/comfy_api_nodes/nodes_pixverse.py @@ -1,7 +1,6 @@ -from inspect import cleandoc -from typing import Optional +import torch from typing_extensions import override -from io import BytesIO +from comfy_api.latest import IO, ComfyExtension from comfy_api_nodes.apis.pixverse_api import ( PixverseTextVideoRequest, PixverseImageVideoRequest, @@ -17,53 +16,30 @@ from comfy_api_nodes.apis.pixverse_api import ( PixverseIO, pixverse_templates, ) -from comfy_api_nodes.apis.client import ( +from comfy_api_nodes.util import ( ApiEndpoint, - HttpMethod, - SynchronousOperation, - PollingOperation, - EmptyRequest, + download_url_to_video_output, + poll_op, + sync_op, + tensor_to_bytesio, + validate_string, ) -from comfy_api_nodes.util import validate_string, tensor_to_bytesio -from comfy_api.input_impl import VideoFromFile -from comfy_api.latest import ComfyExtension, IO - -import torch -import aiohttp - AVERAGE_DURATION_T2V = 32 AVERAGE_DURATION_I2V = 30 AVERAGE_DURATION_T2T = 52 -def get_video_url_from_response( - response: PixverseGenerationStatusResponse, -) -> Optional[str]: - if response.Resp is None or response.Resp.url is None: - return None - return str(response.Resp.url) - - -async def upload_image_to_pixverse(image: torch.Tensor, auth_kwargs=None): - # first, upload image to Pixverse and get image id to use in actual generation call - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/pixverse/image/upload", - method=HttpMethod.POST, - request_model=EmptyRequest, - response_model=PixverseImageUploadResponse, - ), - request=EmptyRequest(), +async def upload_image_to_pixverse(cls: type[IO.ComfyNode], image: torch.Tensor): + response_upload = await sync_op( + cls, + ApiEndpoint(path="/proxy/pixverse/image/upload", method="POST"), + response_model=PixverseImageUploadResponse, files={"image": tensor_to_bytesio(image)}, content_type="multipart/form-data", - auth_kwargs=auth_kwargs, ) - response_upload: PixverseImageUploadResponse = await operation.execute() - if response_upload.Resp is None: raise Exception(f"PixVerse image upload request failed: '{response_upload.ErrMsg}'") - return response_upload.Resp.img_id @@ -93,17 +69,13 @@ class PixverseTemplateNode(IO.ComfyNode): class PixverseTextToVideoNode(IO.ComfyNode): - """ - Generates videos based on prompt and output_size. - """ - @classmethod def define_schema(cls) -> IO.Schema: return IO.Schema( node_id="PixverseTextToVideoNode", display_name="PixVerse Text to Video", category="api node/video/PixVerse", - description=cleandoc(cls.__doc__ or ""), + description="Generates videos based on prompt and output_size.", inputs=[ IO.String.Input( "prompt", @@ -170,7 +142,7 @@ class PixverseTextToVideoNode(IO.ComfyNode): negative_prompt: str = None, pixverse_template: int = None, ) -> IO.NodeOutput: - validate_string(prompt, strip_whitespace=False) + validate_string(prompt, strip_whitespace=False, min_length=1) # 1080p is limited to 5 seconds duration # only normal motion_mode supported for 1080p or for non-5 second duration if quality == PixverseQuality.res_1080p: @@ -179,18 +151,11 @@ class PixverseTextToVideoNode(IO.ComfyNode): elif duration_seconds != PixverseDuration.dur_5: motion_mode = PixverseMotionMode.normal - auth = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/pixverse/video/text/generate", - method=HttpMethod.POST, - request_model=PixverseTextVideoRequest, - response_model=PixverseVideoResponse, - ), - request=PixverseTextVideoRequest( + response_api = await sync_op( + cls, + ApiEndpoint(path="/proxy/pixverse/video/text/generate", method="POST"), + response_model=PixverseVideoResponse, + data=PixverseTextVideoRequest( prompt=prompt, aspect_ratio=aspect_ratio, quality=quality, @@ -200,20 +165,14 @@ class PixverseTextToVideoNode(IO.ComfyNode): template_id=pixverse_template, seed=seed, ), - auth_kwargs=auth, ) - response_api = await operation.execute() - if response_api.Resp is None: raise Exception(f"PixVerse request failed: '{response_api.ErrMsg}'") - operation = PollingOperation( - poll_endpoint=ApiEndpoint( - path=f"/proxy/pixverse/video/result/{response_api.Resp.video_id}", - method=HttpMethod.GET, - request_model=EmptyRequest, - response_model=PixverseGenerationStatusResponse, - ), + response_poll = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/pixverse/video/result/{response_api.Resp.video_id}"), + response_model=PixverseGenerationStatusResponse, completed_statuses=[PixverseStatus.successful], failed_statuses=[ PixverseStatus.contents_moderation, @@ -221,30 +180,19 @@ class PixverseTextToVideoNode(IO.ComfyNode): PixverseStatus.deleted, ], status_extractor=lambda x: x.Resp.status, - auth_kwargs=auth, - node_id=cls.hidden.unique_id, - result_url_extractor=get_video_url_from_response, estimated_duration=AVERAGE_DURATION_T2V, ) - response_poll = await operation.execute() - - async with aiohttp.ClientSession() as session: - async with session.get(response_poll.Resp.url) as vid_response: - return IO.NodeOutput(VideoFromFile(BytesIO(await vid_response.content.read()))) + return IO.NodeOutput(await download_url_to_video_output(response_poll.Resp.url)) class PixverseImageToVideoNode(IO.ComfyNode): - """ - Generates videos based on prompt and output_size. - """ - @classmethod def define_schema(cls) -> IO.Schema: return IO.Schema( node_id="PixverseImageToVideoNode", display_name="PixVerse Image to Video", category="api node/video/PixVerse", - description=cleandoc(cls.__doc__ or ""), + description="Generates videos based on prompt and output_size.", inputs=[ IO.Image.Input("image"), IO.String.Input( @@ -309,11 +257,7 @@ class PixverseImageToVideoNode(IO.ComfyNode): pixverse_template: int = None, ) -> IO.NodeOutput: validate_string(prompt, strip_whitespace=False) - auth = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } - img_id = await upload_image_to_pixverse(image, auth_kwargs=auth) + img_id = await upload_image_to_pixverse(cls, image) # 1080p is limited to 5 seconds duration # only normal motion_mode supported for 1080p or for non-5 second duration @@ -323,14 +267,11 @@ class PixverseImageToVideoNode(IO.ComfyNode): elif duration_seconds != PixverseDuration.dur_5: motion_mode = PixverseMotionMode.normal - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/pixverse/video/img/generate", - method=HttpMethod.POST, - request_model=PixverseImageVideoRequest, - response_model=PixverseVideoResponse, - ), - request=PixverseImageVideoRequest( + response_api = await sync_op( + cls, + ApiEndpoint(path="/proxy/pixverse/video/img/generate", method="POST"), + response_model=PixverseVideoResponse, + data=PixverseImageVideoRequest( img_id=img_id, prompt=prompt, quality=quality, @@ -340,20 +281,15 @@ class PixverseImageToVideoNode(IO.ComfyNode): template_id=pixverse_template, seed=seed, ), - auth_kwargs=auth, ) - response_api = await operation.execute() if response_api.Resp is None: raise Exception(f"PixVerse request failed: '{response_api.ErrMsg}'") - operation = PollingOperation( - poll_endpoint=ApiEndpoint( - path=f"/proxy/pixverse/video/result/{response_api.Resp.video_id}", - method=HttpMethod.GET, - request_model=EmptyRequest, - response_model=PixverseGenerationStatusResponse, - ), + response_poll = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/pixverse/video/result/{response_api.Resp.video_id}"), + response_model=PixverseGenerationStatusResponse, completed_statuses=[PixverseStatus.successful], failed_statuses=[ PixverseStatus.contents_moderation, @@ -361,30 +297,19 @@ class PixverseImageToVideoNode(IO.ComfyNode): PixverseStatus.deleted, ], status_extractor=lambda x: x.Resp.status, - auth_kwargs=auth, - node_id=cls.hidden.unique_id, - result_url_extractor=get_video_url_from_response, estimated_duration=AVERAGE_DURATION_I2V, ) - response_poll = await operation.execute() - - async with aiohttp.ClientSession() as session: - async with session.get(response_poll.Resp.url) as vid_response: - return IO.NodeOutput(VideoFromFile(BytesIO(await vid_response.content.read()))) + return IO.NodeOutput(await download_url_to_video_output(response_poll.Resp.url)) class PixverseTransitionVideoNode(IO.ComfyNode): - """ - Generates videos based on prompt and output_size. - """ - @classmethod def define_schema(cls) -> IO.Schema: return IO.Schema( node_id="PixverseTransitionVideoNode", display_name="PixVerse Transition Video", category="api node/video/PixVerse", - description=cleandoc(cls.__doc__ or ""), + description="Generates videos based on prompt and output_size.", inputs=[ IO.Image.Input("first_frame"), IO.Image.Input("last_frame"), @@ -445,12 +370,8 @@ class PixverseTransitionVideoNode(IO.ComfyNode): negative_prompt: str = None, ) -> IO.NodeOutput: validate_string(prompt, strip_whitespace=False) - auth = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } - first_frame_id = await upload_image_to_pixverse(first_frame, auth_kwargs=auth) - last_frame_id = await upload_image_to_pixverse(last_frame, auth_kwargs=auth) + first_frame_id = await upload_image_to_pixverse(cls, first_frame) + last_frame_id = await upload_image_to_pixverse(cls, last_frame) # 1080p is limited to 5 seconds duration # only normal motion_mode supported for 1080p or for non-5 second duration @@ -460,14 +381,11 @@ class PixverseTransitionVideoNode(IO.ComfyNode): elif duration_seconds != PixverseDuration.dur_5: motion_mode = PixverseMotionMode.normal - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/pixverse/video/transition/generate", - method=HttpMethod.POST, - request_model=PixverseTransitionVideoRequest, - response_model=PixverseVideoResponse, - ), - request=PixverseTransitionVideoRequest( + response_api = await sync_op( + cls, + ApiEndpoint(path="/proxy/pixverse/video/transition/generate", method="POST"), + response_model=PixverseVideoResponse, + data=PixverseTransitionVideoRequest( first_frame_img=first_frame_id, last_frame_img=last_frame_id, prompt=prompt, @@ -477,20 +395,15 @@ class PixverseTransitionVideoNode(IO.ComfyNode): negative_prompt=negative_prompt if negative_prompt else None, seed=seed, ), - auth_kwargs=auth, ) - response_api = await operation.execute() if response_api.Resp is None: raise Exception(f"PixVerse request failed: '{response_api.ErrMsg}'") - operation = PollingOperation( - poll_endpoint=ApiEndpoint( - path=f"/proxy/pixverse/video/result/{response_api.Resp.video_id}", - method=HttpMethod.GET, - request_model=EmptyRequest, - response_model=PixverseGenerationStatusResponse, - ), + response_poll = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/pixverse/video/result/{response_api.Resp.video_id}"), + response_model=PixverseGenerationStatusResponse, completed_statuses=[PixverseStatus.successful], failed_statuses=[ PixverseStatus.contents_moderation, @@ -498,16 +411,9 @@ class PixverseTransitionVideoNode(IO.ComfyNode): PixverseStatus.deleted, ], status_extractor=lambda x: x.Resp.status, - auth_kwargs=auth, - node_id=cls.hidden.unique_id, - result_url_extractor=get_video_url_from_response, estimated_duration=AVERAGE_DURATION_T2V, ) - response_poll = await operation.execute() - - async with aiohttp.ClientSession() as session: - async with session.get(response_poll.Resp.url) as vid_response: - return IO.NodeOutput(VideoFromFile(BytesIO(await vid_response.content.read()))) + return IO.NodeOutput(await download_url_to_video_output(response_poll.Resp.url)) class PixVerseExtension(ComfyExtension): diff --git a/comfy_api_nodes/nodes_recraft.py b/comfy_api_nodes/nodes_recraft.py index dee186cd6..e3440b946 100644 --- a/comfy_api_nodes/nodes_recraft.py +++ b/comfy_api_nodes/nodes_recraft.py @@ -8,9 +8,6 @@ from typing_extensions import override from comfy.utils import ProgressBar from comfy_api.latest import IO, ComfyExtension -from comfy_api_nodes.apinode_utils import ( - resize_mask_to_image, -) from comfy_api_nodes.apis.recraft_api import ( RecraftColor, RecraftColorChain, @@ -28,6 +25,7 @@ from comfy_api_nodes.util import ( ApiEndpoint, bytesio_to_image_tensor, download_url_as_bytesio, + resize_mask_to_image, sync_op, tensor_to_bytesio, validate_string, diff --git a/comfy_api_nodes/nodes_runway.py b/comfy_api_nodes/nodes_runway.py index 0543d1d0e..2fdafbbfe 100644 --- a/comfy_api_nodes/nodes_runway.py +++ b/comfy_api_nodes/nodes_runway.py @@ -200,7 +200,7 @@ class RunwayImageToVideoNodeGen3a(IO.ComfyNode): ) -> IO.NodeOutput: validate_string(prompt, min_length=1) validate_image_dimensions(start_frame, max_width=7999, max_height=7999) - validate_image_aspect_ratio(start_frame, min_aspect_ratio=0.5, max_aspect_ratio=2.0) + validate_image_aspect_ratio(start_frame, (1, 2), (2, 1)) download_urls = await upload_images_to_comfyapi( cls, @@ -290,7 +290,7 @@ class RunwayImageToVideoNodeGen4(IO.ComfyNode): ) -> IO.NodeOutput: validate_string(prompt, min_length=1) validate_image_dimensions(start_frame, max_width=7999, max_height=7999) - validate_image_aspect_ratio(start_frame, min_aspect_ratio=0.5, max_aspect_ratio=2.0) + validate_image_aspect_ratio(start_frame, (1, 2), (2, 1)) download_urls = await upload_images_to_comfyapi( cls, @@ -390,8 +390,8 @@ class RunwayFirstLastFrameNode(IO.ComfyNode): validate_string(prompt, min_length=1) validate_image_dimensions(start_frame, max_width=7999, max_height=7999) validate_image_dimensions(end_frame, max_width=7999, max_height=7999) - validate_image_aspect_ratio(start_frame, min_aspect_ratio=0.5, max_aspect_ratio=2.0) - validate_image_aspect_ratio(end_frame, min_aspect_ratio=0.5, max_aspect_ratio=2.0) + validate_image_aspect_ratio(start_frame, (1, 2), (2, 1)) + validate_image_aspect_ratio(end_frame, (1, 2), (2, 1)) stacked_input_images = image_tensor_pair_to_batch(start_frame, end_frame) download_urls = await upload_images_to_comfyapi( @@ -475,7 +475,7 @@ class RunwayTextToImageNode(IO.ComfyNode): reference_images = None if reference_image is not None: validate_image_dimensions(reference_image, max_width=7999, max_height=7999) - validate_image_aspect_ratio(reference_image, min_aspect_ratio=0.5, max_aspect_ratio=2.0) + validate_image_aspect_ratio(reference_image, (1, 2), (2, 1)) download_urls = await upload_images_to_comfyapi( cls, reference_image, diff --git a/comfy_api_nodes/nodes_vidu.py b/comfy_api_nodes/nodes_vidu.py index 0e0572f8c..7a679f0d9 100644 --- a/comfy_api_nodes/nodes_vidu.py +++ b/comfy_api_nodes/nodes_vidu.py @@ -14,9 +14,9 @@ from comfy_api_nodes.util import ( poll_op, sync_op, upload_images_to_comfyapi, - validate_aspect_ratio_closeness, - validate_image_aspect_ratio_range, + validate_image_aspect_ratio, validate_image_dimensions, + validate_images_aspect_ratio_closeness, ) VIDU_TEXT_TO_VIDEO = "/proxy/vidu/text2video" @@ -114,7 +114,7 @@ async def execute_task( cls, ApiEndpoint(path=VIDU_GET_GENERATION_STATUS % response.task_id), response_model=TaskStatusResponse, - status_extractor=lambda r: r.state.value, + status_extractor=lambda r: r.state, estimated_duration=estimated_duration, ) @@ -307,7 +307,7 @@ class ViduImageToVideoNode(IO.ComfyNode): ) -> IO.NodeOutput: if get_number_of_images(image) > 1: raise ValueError("Only one input image is allowed.") - validate_image_aspect_ratio_range(image, (1, 4), (4, 1)) + validate_image_aspect_ratio(image, (1, 4), (4, 1)) payload = TaskCreationRequest( model_name=model, prompt=prompt, @@ -423,7 +423,7 @@ class ViduReferenceVideoNode(IO.ComfyNode): if a > 7: raise ValueError("Too many images, maximum allowed is 7.") for image in images: - validate_image_aspect_ratio_range(image, (1, 4), (4, 1)) + validate_image_aspect_ratio(image, (1, 4), (4, 1)) validate_image_dimensions(image, min_width=128, min_height=128) payload = TaskCreationRequest( model_name=model, @@ -533,7 +533,7 @@ class ViduStartEndToVideoNode(IO.ComfyNode): resolution: str, movement_amplitude: str, ) -> IO.NodeOutput: - validate_aspect_ratio_closeness(first_frame, end_frame, min_rel=0.8, max_rel=1.25, strict=False) + validate_images_aspect_ratio_closeness(first_frame, end_frame, min_rel=0.8, max_rel=1.25, strict=False) payload = TaskCreationRequest( model_name=model, prompt=prompt, diff --git a/comfy_api_nodes/util/__init__.py b/comfy_api_nodes/util/__init__.py index 0cca2b59b..bbc71363a 100644 --- a/comfy_api_nodes/util/__init__.py +++ b/comfy_api_nodes/util/__init__.py @@ -14,6 +14,7 @@ from .conversions import ( downscale_image_tensor, image_tensor_pair_to_batch, pil_to_bytesio, + resize_mask_to_image, tensor_to_base64_string, tensor_to_bytesio, tensor_to_pil, @@ -34,12 +35,12 @@ from .upload_helpers import ( ) from .validation_utils import ( get_number_of_images, - validate_aspect_ratio_closeness, + validate_aspect_ratio_string, validate_audio_duration, validate_container_format_is_mp4, validate_image_aspect_ratio, - validate_image_aspect_ratio_range, validate_image_dimensions, + validate_images_aspect_ratio_closeness, validate_string, validate_video_dimensions, validate_video_duration, @@ -70,6 +71,7 @@ __all__ = [ "downscale_image_tensor", "image_tensor_pair_to_batch", "pil_to_bytesio", + "resize_mask_to_image", "tensor_to_base64_string", "tensor_to_bytesio", "tensor_to_pil", @@ -77,12 +79,12 @@ __all__ = [ "video_to_base64_string", # Validation utilities "get_number_of_images", - "validate_aspect_ratio_closeness", + "validate_aspect_ratio_string", "validate_audio_duration", "validate_container_format_is_mp4", "validate_image_aspect_ratio", - "validate_image_aspect_ratio_range", "validate_image_dimensions", + "validate_images_aspect_ratio_closeness", "validate_string", "validate_video_dimensions", "validate_video_duration", diff --git a/comfy_api_nodes/util/conversions.py b/comfy_api_nodes/util/conversions.py index 9f4c90c5c..b59c2bd84 100644 --- a/comfy_api_nodes/util/conversions.py +++ b/comfy_api_nodes/util/conversions.py @@ -430,3 +430,24 @@ def audio_bytes_to_audio_input(audio_bytes: bytes) -> dict: wav = torch.cat(frames, dim=1) # [C, T] wav = _f32_pcm(wav) return {"waveform": wav.unsqueeze(0).contiguous(), "sample_rate": out_sr} + + +def resize_mask_to_image( + mask: torch.Tensor, + image: torch.Tensor, + upscale_method="nearest-exact", + crop="disabled", + allow_gradient=True, + add_channel_dim=False, +): + """Resize mask to be the same dimensions as an image, while maintaining proper format for API calls.""" + _, height, width, _ = image.shape + mask = mask.unsqueeze(-1) + mask = mask.movedim(-1, 1) + mask = common_upscale(mask, width=width, height=height, upscale_method=upscale_method, crop=crop) + mask = mask.movedim(1, -1) + if not add_channel_dim: + mask = mask.squeeze(-1) + if not allow_gradient: + mask = (mask > 0.5).float() + return mask diff --git a/comfy_api_nodes/util/validation_utils.py b/comfy_api_nodes/util/validation_utils.py index 22da05bc1..ec7006aed 100644 --- a/comfy_api_nodes/util/validation_utils.py +++ b/comfy_api_nodes/util/validation_utils.py @@ -37,63 +37,62 @@ def validate_image_dimensions( def validate_image_aspect_ratio( image: torch.Tensor, - min_aspect_ratio: Optional[float] = None, - max_aspect_ratio: Optional[float] = None, -): - width, height = get_image_dimensions(image) - aspect_ratio = width / height - - if min_aspect_ratio is not None and aspect_ratio < min_aspect_ratio: - raise ValueError(f"Image aspect ratio must be at least {min_aspect_ratio}, got {aspect_ratio}") - if max_aspect_ratio is not None and aspect_ratio > max_aspect_ratio: - raise ValueError(f"Image aspect ratio must be at most {max_aspect_ratio}, got {aspect_ratio}") - - -def validate_image_aspect_ratio_range( - image: torch.Tensor, - min_ratio: tuple[float, float], # e.g. (1, 4) - max_ratio: tuple[float, float], # e.g. (4, 1) + min_ratio: Optional[tuple[float, float]] = None, # e.g. (1, 4) + max_ratio: Optional[tuple[float, float]] = None, # e.g. (4, 1) *, strict: bool = True, # True -> (min, max); False -> [min, max] ) -> float: - a1, b1 = min_ratio - a2, b2 = max_ratio - if a1 <= 0 or b1 <= 0 or a2 <= 0 or b2 <= 0: - raise ValueError("Ratios must be positive, like (1, 4) or (4, 1).") - lo, hi = (a1 / b1), (a2 / b2) - if lo > hi: - lo, hi = hi, lo - a1, b1, a2, b2 = a2, b2, a1, b1 # swap only for error text + """Validates that image aspect ratio is within min and max. If a bound is None, that side is not checked.""" w, h = get_image_dimensions(image) if w <= 0 or h <= 0: raise ValueError(f"Invalid image dimensions: {w}x{h}") ar = w / h - ok = (lo < ar < hi) if strict else (lo <= ar <= hi) - if not ok: - op = "<" if strict else "≤" - raise ValueError(f"Image aspect ratio {ar:.6g} is outside allowed range: {a1}:{b1} {op} ratio {op} {a2}:{b2}") + _assert_ratio_bounds(ar, min_ratio=min_ratio, max_ratio=max_ratio, strict=strict) return ar -def validate_aspect_ratio_closeness( - start_img, - end_img, - min_rel: float, - max_rel: float, +def validate_images_aspect_ratio_closeness( + first_image: torch.Tensor, + second_image: torch.Tensor, + min_rel: float, # e.g. 0.8 + max_rel: float, # e.g. 1.25 *, - strict: bool = False, # True => exclusive, False => inclusive -) -> None: - w1, h1 = get_image_dimensions(start_img) - w2, h2 = get_image_dimensions(end_img) + strict: bool = False, # True -> (min, max); False -> [min, max] +) -> float: + """ + Validates that the two images' aspect ratios are 'close'. + The closeness factor is C = max(ar1, ar2) / min(ar1, ar2) (C >= 1). + We require C <= limit, where limit = max(max_rel, 1.0 / min_rel). + + Returns the computed closeness factor C. + """ + w1, h1 = get_image_dimensions(first_image) + w2, h2 = get_image_dimensions(second_image) if min(w1, h1, w2, h2) <= 0: raise ValueError("Invalid image dimensions") ar1 = w1 / h1 ar2 = w2 / h2 - # Normalize so it is symmetric (no need to check both ar1/ar2 and ar2/ar1) closeness = max(ar1, ar2) / min(ar1, ar2) - limit = max(max_rel, 1.0 / min_rel) # for 0.8..1.25 this is 1.25 + limit = max(max_rel, 1.0 / min_rel) if (closeness >= limit) if strict else (closeness > limit): - raise ValueError(f"Aspect ratios must be close: start/end={ar1/ar2:.4f}, allowed range {min_rel}–{max_rel}.") + raise ValueError( + f"Aspect ratios must be close: ar1/ar2={ar1/ar2:.2g}, " + f"allowed range {min_rel}–{max_rel} (limit {limit:.2g})." + ) + return closeness + + +def validate_aspect_ratio_string( + aspect_ratio: str, + min_ratio: Optional[tuple[float, float]] = None, # e.g. (1, 4) + max_ratio: Optional[tuple[float, float]] = None, # e.g. (4, 1) + *, + strict: bool = False, # True -> (min, max); False -> [min, max] +) -> float: + """Parses 'X:Y' and validates it against optional bounds. Returns the numeric ratio.""" + ar = _parse_aspect_ratio_string(aspect_ratio) + _assert_ratio_bounds(ar, min_ratio=min_ratio, max_ratio=max_ratio, strict=strict) + return ar def validate_video_dimensions( @@ -183,3 +182,49 @@ def validate_container_format_is_mp4(video: VideoInput) -> None: container_format = video.get_container_format() if container_format not in ["mp4", "mov,mp4,m4a,3gp,3g2,mj2"]: raise ValueError(f"Only MP4 container format supported. Got: {container_format}") + + +def _ratio_from_tuple(r: tuple[float, float]) -> float: + a, b = r + if a <= 0 or b <= 0: + raise ValueError(f"Ratios must be positive, got {a}:{b}.") + return a / b + + +def _assert_ratio_bounds( + ar: float, + *, + min_ratio: Optional[tuple[float, float]] = None, + max_ratio: Optional[tuple[float, float]] = None, + strict: bool = True, +) -> None: + """Validate a numeric aspect ratio against optional min/max ratio bounds.""" + lo = _ratio_from_tuple(min_ratio) if min_ratio is not None else None + hi = _ratio_from_tuple(max_ratio) if max_ratio is not None else None + + if lo is not None and hi is not None and lo > hi: + lo, hi = hi, lo # normalize order if caller swapped them + + if lo is not None: + if (ar <= lo) if strict else (ar < lo): + op = "<" if strict else "≤" + raise ValueError(f"Aspect ratio `{ar:.2g}` must be {op} {lo:.2g}.") + if hi is not None: + if (ar >= hi) if strict else (ar > hi): + op = "<" if strict else "≤" + raise ValueError(f"Aspect ratio `{ar:.2g}` must be {op} {hi:.2g}.") + + +def _parse_aspect_ratio_string(ar_str: str) -> float: + """Parse 'X:Y' with integer parts into a positive float ratio X/Y.""" + parts = ar_str.split(":") + if len(parts) != 2: + raise ValueError(f"Aspect ratio must be 'X:Y' (e.g., 16:9), got '{ar_str}'.") + try: + a = int(parts[0].strip()) + b = int(parts[1].strip()) + except ValueError as exc: + raise ValueError(f"Aspect ratio must contain integers separated by ':', got '{ar_str}'.") from exc + if a <= 0 or b <= 0: + raise ValueError(f"Aspect ratio parts must be positive integers, got {a}:{b}.") + return a / b From dfac94695be95076d8028d04005a744f3ec0de8d Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Thu, 30 Oct 2025 19:22:35 +0200 Subject: [PATCH 038/231] fix img2img operation in Dall2 node (#10552) --- comfy_api_nodes/nodes_openai.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy_api_nodes/nodes_openai.py b/comfy_api_nodes/nodes_openai.py index c467e840c..b4568fc85 100644 --- a/comfy_api_nodes/nodes_openai.py +++ b/comfy_api_nodes/nodes_openai.py @@ -225,7 +225,7 @@ class OpenAIDalle2(ComfyNodeABC): ), files=( { - "image": img_binary, + "image": ("image.png", img_binary, "image/png"), } if img_binary else None From 513b0c46fba3bf40191d684ff81207ad935f1717 Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Fri, 31 Oct 2025 07:39:02 +1000 Subject: [PATCH 039/231] Add RAM Pressure cache mode (#10454) * execution: Roll the UI cache into the outputs Currently the UI cache is parallel to the output cache with expectations of being a content superset of the output cache. At the same time the UI and output cache are maintained completely seperately, making it awkward to free the output cache content without changing the behaviour of the UI cache. There are two actual users (getters) of the UI cache. The first is the case of a direct content hit on the output cache when executing a node. This case is very naturally handled by merging the UI and outputs cache. The second case is the history JSON generation at the end of the prompt. This currently works by asking the cache for all_node_ids and then pulling the cache contents for those nodes. all_node_ids is the nodes of the dynamic prompt. So fold the UI cache into the output cache. The current UI cache setter now writes to a prompt-scope dict. When the output cache is set, just get this value from the dict and tuple up with the outputs. When generating the history, simply iterate prompt-scope dict. This prepares support for more complex caching strategies (like RAM pressure caching) where less than 1 workflow will be cached and it will be desirable to keep the UI cache and output cache in sync. * sd: Implement RAM getter for VAE * model_patcher: Implement RAM getter for ModelPatcher * sd: Implement RAM getter for CLIP * Implement RAM Pressure cache Implement a cache sensitive to RAM pressure. When RAM headroom drops down below a certain threshold, evict RAM-expensive nodes from the cache. Models and tensors are measured directly for RAM usage. An OOM score is then computed based on the RAM usage of the node. Note the due to indirection through shared objects (like a model patcher), multiple nodes can account the same RAM as their individual usage. The intent is this will free chains of nodes particularly model loaders and associate loras as they all score similar and are sorted in close to each other. Has a bias towards unloading model nodes mid flow while being able to keep results like text encodings and VAE. * execution: Convert the cache entry to NamedTuple As commented in review. Convert this to a named tuple and abstract away the tuple type completely from graph.py. --- comfy/cli_args.py | 1 + comfy/model_patcher.py | 3 ++ comfy/sd.py | 14 +++++++ comfy_execution/caching.py | 83 ++++++++++++++++++++++++++++++++++++++ comfy_execution/graph.py | 9 ++++- execution.py | 81 +++++++++++++++++++++---------------- main.py | 4 +- 7 files changed, 157 insertions(+), 38 deletions(-) diff --git a/comfy/cli_args.py b/comfy/cli_args.py index 001abd843..3d5bc7c90 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -105,6 +105,7 @@ cache_group = parser.add_mutually_exclusive_group() cache_group.add_argument("--cache-classic", action="store_true", help="Use the old style (aggressive) caching.") cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM.") cache_group.add_argument("--cache-none", action="store_true", help="Reduced RAM/VRAM usage at the expense of executing every node for each run.") +cache_group.add_argument("--cache-ram", nargs='?', const=4.0, type=float, default=0, help="Use RAM pressure caching with the specified headroom threshold. If available RAM drops below the threhold the cache remove large items to free RAM. Default 4GB") attn_group = parser.add_mutually_exclusive_group() attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.") diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index ed3f3f5cb..674a214ca 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -276,6 +276,9 @@ class ModelPatcher: self.size = comfy.model_management.module_size(self.model) return self.size + def get_ram_usage(self): + return self.model_size() + def loaded_size(self): return self.model.model_loaded_weight_memory diff --git a/comfy/sd.py b/comfy/sd.py index de4eee96e..9e5ebbf15 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -143,6 +143,9 @@ class CLIP: n.apply_hooks_to_conds = self.apply_hooks_to_conds return n + def get_ram_usage(self): + return self.patcher.get_ram_usage() + def add_patches(self, patches, strength_patch=1.0, strength_model=1.0): return self.patcher.add_patches(patches, strength_patch, strength_model) @@ -293,6 +296,7 @@ class VAE: self.working_dtypes = [torch.bfloat16, torch.float32] self.disable_offload = False self.not_video = False + self.size = None self.downscale_index_formula = None self.upscale_index_formula = None @@ -595,6 +599,16 @@ class VAE: self.patcher = comfy.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device) logging.info("VAE load device: {}, offload device: {}, dtype: {}".format(self.device, offload_device, self.vae_dtype)) + self.model_size() + + def model_size(self): + if self.size is not None: + return self.size + self.size = comfy.model_management.module_size(self.first_stage_model) + return self.size + + def get_ram_usage(self): + return self.model_size() def throw_exception_if_invalid(self): if self.first_stage_model is None: diff --git a/comfy_execution/caching.py b/comfy_execution/caching.py index 566bc3f9c..b498f43e7 100644 --- a/comfy_execution/caching.py +++ b/comfy_execution/caching.py @@ -1,4 +1,9 @@ +import bisect +import gc import itertools +import psutil +import time +import torch from typing import Sequence, Mapping, Dict from comfy_execution.graph import DynamicPrompt from abc import ABC, abstractmethod @@ -188,6 +193,9 @@ class BasicCache: self._clean_cache() self._clean_subcaches() + def poll(self, **kwargs): + pass + def _set_immediate(self, node_id, value): assert self.initialized cache_key = self.cache_key_set.get_data_key(node_id) @@ -276,6 +284,9 @@ class NullCache: def clean_unused(self): pass + def poll(self, **kwargs): + pass + def get(self, node_id): return None @@ -336,3 +347,75 @@ class LRUCache(BasicCache): self._mark_used(child_id) self.children[cache_key].append(self.cache_key_set.get_data_key(child_id)) return self + + +#Iterating the cache for usage analysis might be expensive, so if we trigger make sure +#to take a chunk out to give breathing space on high-node / low-ram-per-node flows. + +RAM_CACHE_HYSTERESIS = 1.1 + +#This is kinda in GB but not really. It needs to be non-zero for the below heuristic +#and as long as Multi GB models dwarf this it will approximate OOM scoring OK + +RAM_CACHE_DEFAULT_RAM_USAGE = 0.1 + +#Exponential bias towards evicting older workflows so garbage will be taken out +#in constantly changing setups. + +RAM_CACHE_OLD_WORKFLOW_OOM_MULTIPLIER = 1.3 + +class RAMPressureCache(LRUCache): + + def __init__(self, key_class): + super().__init__(key_class, 0) + self.timestamps = {} + + def clean_unused(self): + self._clean_subcaches() + + def set(self, node_id, value): + self.timestamps[self.cache_key_set.get_data_key(node_id)] = time.time() + super().set(node_id, value) + + def get(self, node_id): + self.timestamps[self.cache_key_set.get_data_key(node_id)] = time.time() + return super().get(node_id) + + def poll(self, ram_headroom): + def _ram_gb(): + return psutil.virtual_memory().available / (1024**3) + + if _ram_gb() > ram_headroom: + return + gc.collect() + if _ram_gb() > ram_headroom: + return + + clean_list = [] + + for key, (outputs, _), in self.cache.items(): + oom_score = RAM_CACHE_OLD_WORKFLOW_OOM_MULTIPLIER ** (self.generation - self.used_generation[key]) + + ram_usage = RAM_CACHE_DEFAULT_RAM_USAGE + def scan_list_for_ram_usage(outputs): + nonlocal ram_usage + for output in outputs: + if isinstance(output, list): + scan_list_for_ram_usage(output) + elif isinstance(output, torch.Tensor) and output.device.type == 'cpu': + #score Tensors at a 50% discount for RAM usage as they are likely to + #be high value intermediates + ram_usage += (output.numel() * output.element_size()) * 0.5 + elif hasattr(output, "get_ram_usage"): + ram_usage += output.get_ram_usage() + scan_list_for_ram_usage(outputs) + + oom_score *= ram_usage + #In the case where we have no information on the node ram usage at all, + #break OOM score ties on the last touch timestamp (pure LRU) + bisect.insort(clean_list, (oom_score, self.timestamps[key], key)) + + while _ram_gb() < ram_headroom * RAM_CACHE_HYSTERESIS and clean_list: + _, _, key = clean_list.pop() + del self.cache[key] + gc.collect() diff --git a/comfy_execution/graph.py b/comfy_execution/graph.py index 341c9735d..0d811e354 100644 --- a/comfy_execution/graph.py +++ b/comfy_execution/graph.py @@ -209,10 +209,15 @@ class ExecutionList(TopologicalSort): self.execution_cache_listeners[from_node_id] = set() self.execution_cache_listeners[from_node_id].add(to_node_id) - def get_output_cache(self, from_node_id, to_node_id): + def get_cache(self, from_node_id, to_node_id): if not to_node_id in self.execution_cache: return None - return self.execution_cache[to_node_id].get(from_node_id) + value = self.execution_cache[to_node_id].get(from_node_id) + if value is None: + return None + #Write back to the main cache on touch. + self.output_cache.set(from_node_id, value) + return value def cache_update(self, node_id, value): if node_id in self.execution_cache_listeners: diff --git a/execution.py b/execution.py index 20e106213..17c77beab 100644 --- a/execution.py +++ b/execution.py @@ -21,6 +21,7 @@ from comfy_execution.caching import ( NullCache, HierarchicalCache, LRUCache, + RAMPressureCache, ) from comfy_execution.graph import ( DynamicPrompt, @@ -88,49 +89,56 @@ class IsChangedCache: return self.is_changed[node_id] +class CacheEntry(NamedTuple): + ui: dict + outputs: list + + class CacheType(Enum): CLASSIC = 0 LRU = 1 NONE = 2 + RAM_PRESSURE = 3 class CacheSet: - def __init__(self, cache_type=None, cache_size=None): + def __init__(self, cache_type=None, cache_args={}): if cache_type == CacheType.NONE: self.init_null_cache() logging.info("Disabling intermediate node cache.") + elif cache_type == CacheType.RAM_PRESSURE: + cache_ram = cache_args.get("ram", 16.0) + self.init_ram_cache(cache_ram) + logging.info("Using RAM pressure cache.") elif cache_type == CacheType.LRU: - if cache_size is None: - cache_size = 0 + cache_size = cache_args.get("lru", 0) self.init_lru_cache(cache_size) logging.info("Using LRU cache") else: self.init_classic_cache() - self.all = [self.outputs, self.ui, self.objects] + self.all = [self.outputs, self.objects] # Performs like the old cache -- dump data ASAP def init_classic_cache(self): self.outputs = HierarchicalCache(CacheKeySetInputSignature) - self.ui = HierarchicalCache(CacheKeySetInputSignature) self.objects = HierarchicalCache(CacheKeySetID) def init_lru_cache(self, cache_size): self.outputs = LRUCache(CacheKeySetInputSignature, max_size=cache_size) - self.ui = LRUCache(CacheKeySetInputSignature, max_size=cache_size) + self.objects = HierarchicalCache(CacheKeySetID) + + def init_ram_cache(self, min_headroom): + self.outputs = RAMPressureCache(CacheKeySetInputSignature) self.objects = HierarchicalCache(CacheKeySetID) def init_null_cache(self): self.outputs = NullCache() - #The UI cache is expected to be iterable at the end of each workflow - #so it must cache at least a full workflow. Use Heirachical - self.ui = HierarchicalCache(CacheKeySetInputSignature) self.objects = NullCache() def recursive_debug_dump(self): result = { "outputs": self.outputs.recursive_debug_dump(), - "ui": self.ui.recursive_debug_dump(), } return result @@ -157,14 +165,14 @@ def get_input_data(inputs, class_def, unique_id, execution_list=None, dynprompt= if execution_list is None: mark_missing() continue # This might be a lazily-evaluated input - cached_output = execution_list.get_output_cache(input_unique_id, unique_id) - if cached_output is None: + cached = execution_list.get_cache(input_unique_id, unique_id) + if cached is None or cached.outputs is None: mark_missing() continue - if output_index >= len(cached_output): + if output_index >= len(cached.outputs): mark_missing() continue - obj = cached_output[output_index] + obj = cached.outputs[output_index] input_data_all[x] = obj elif input_category is not None: input_data_all[x] = [input_data] @@ -393,7 +401,7 @@ def format_value(x): else: return str(x) -async def execute(server, dynprompt, caches, current_item, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes): +async def execute(server, dynprompt, caches, current_item, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_outputs): unique_id = current_item real_node_id = dynprompt.get_real_node_id(unique_id) display_node_id = dynprompt.get_display_node_id(unique_id) @@ -401,12 +409,15 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, inputs = dynprompt.get_node(unique_id)['inputs'] class_type = dynprompt.get_node(unique_id)['class_type'] class_def = nodes.NODE_CLASS_MAPPINGS[class_type] - if caches.outputs.get(unique_id) is not None: + cached = caches.outputs.get(unique_id) + if cached is not None: if server.client_id is not None: - cached_output = caches.ui.get(unique_id) or {} - server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": cached_output.get("output",None), "prompt_id": prompt_id }, server.client_id) + cached_ui = cached.ui or {} + server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": cached_ui.get("output",None), "prompt_id": prompt_id }, server.client_id) + if cached.ui is not None: + ui_outputs[unique_id] = cached.ui get_progress_state().finish_progress(unique_id) - execution_list.cache_update(unique_id, caches.outputs.get(unique_id)) + execution_list.cache_update(unique_id, cached) return (ExecutionResult.SUCCESS, None, None) input_data_all = None @@ -436,8 +447,8 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, for r in result: if is_link(r): source_node, source_output = r[0], r[1] - node_output = execution_list.get_output_cache(source_node, unique_id)[source_output] - for o in node_output: + node_cached = execution_list.get_cache(source_node, unique_id) + for o in node_cached.outputs[source_output]: resolved_output.append(o) else: @@ -507,7 +518,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, asyncio.create_task(await_completion()) return (ExecutionResult.PENDING, None, None) if len(output_ui) > 0: - caches.ui.set(unique_id, { + ui_outputs[unique_id] = { "meta": { "node_id": unique_id, "display_node": display_node_id, @@ -515,7 +526,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, "real_node_id": real_node_id, }, "output": output_ui - }) + } if server.client_id is not None: server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id) if has_subgraph: @@ -554,8 +565,9 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, pending_subgraph_results[unique_id] = cached_outputs return (ExecutionResult.PENDING, None, None) - caches.outputs.set(unique_id, output_data) - execution_list.cache_update(unique_id, output_data) + cache_entry = CacheEntry(ui=ui_outputs.get(unique_id), outputs=output_data) + execution_list.cache_update(unique_id, cache_entry) + caches.outputs.set(unique_id, cache_entry) except comfy.model_management.InterruptProcessingException as iex: logging.info("Processing interrupted") @@ -600,14 +612,14 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, return (ExecutionResult.SUCCESS, None, None) class PromptExecutor: - def __init__(self, server, cache_type=False, cache_size=None): - self.cache_size = cache_size + def __init__(self, server, cache_type=False, cache_args=None): + self.cache_args = cache_args self.cache_type = cache_type self.server = server self.reset() def reset(self): - self.caches = CacheSet(cache_type=self.cache_type, cache_size=self.cache_size) + self.caches = CacheSet(cache_type=self.cache_type, cache_args=self.cache_args) self.status_messages = [] self.success = True @@ -682,6 +694,7 @@ class PromptExecutor: broadcast=False) pending_subgraph_results = {} pending_async_nodes = {} # TODO - Unify this with pending_subgraph_results + ui_node_outputs = {} executed = set() execution_list = ExecutionList(dynamic_prompt, self.caches.outputs) current_outputs = self.caches.outputs.all_node_ids() @@ -695,7 +708,7 @@ class PromptExecutor: break assert node_id is not None, "Node ID should not be None at this point" - result, error, ex = await execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes) + result, error, ex = await execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_node_outputs) self.success = result != ExecutionResult.FAILURE if result == ExecutionResult.FAILURE: self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex) @@ -704,18 +717,16 @@ class PromptExecutor: execution_list.unstage_node_execution() else: # result == ExecutionResult.SUCCESS: execution_list.complete_node_execution() + self.caches.outputs.poll(ram_headroom=self.cache_args["ram"]) else: # Only execute when the while-loop ends without break self.add_message("execution_success", { "prompt_id": prompt_id }, broadcast=False) ui_outputs = {} meta_outputs = {} - all_node_ids = self.caches.ui.all_node_ids() - for node_id in all_node_ids: - ui_info = self.caches.ui.get(node_id) - if ui_info is not None: - ui_outputs[node_id] = ui_info["output"] - meta_outputs[node_id] = ui_info["meta"] + for node_id, ui_info in ui_node_outputs.items(): + ui_outputs[node_id] = ui_info["output"] + meta_outputs[node_id] = ui_info["meta"] self.history_result = { "outputs": ui_outputs, "meta": meta_outputs, diff --git a/main.py b/main.py index 8d466d2eb..e1b0f1620 100644 --- a/main.py +++ b/main.py @@ -172,10 +172,12 @@ def prompt_worker(q, server_instance): cache_type = execution.CacheType.CLASSIC if args.cache_lru > 0: cache_type = execution.CacheType.LRU + elif args.cache_ram > 0: + cache_type = execution.CacheType.RAM_PRESSURE elif args.cache_none: cache_type = execution.CacheType.NONE - e = execution.PromptExecutor(server_instance, cache_type=cache_type, cache_size=args.cache_lru) + e = execution.PromptExecutor(server_instance, cache_type=cache_type, cache_args={ "lru" : args.cache_lru, "ram" : args.cache_ram } ) last_gc_collect = 0 need_gc = False gc_collect_interval = 10.0 From 614cf9805e1056216487a2d1b1a07206d77f87e7 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Thu, 30 Oct 2025 19:11:38 -0700 Subject: [PATCH 040/231] Add a ScaleROPE node. Currently only works on WAN models. (#10559) --- comfy/ldm/wan/model.py | 20 ++++++++++++---- comfy/model_patcher.py | 13 +++++++++++ comfy_extras/nodes_rope.py | 47 ++++++++++++++++++++++++++++++++++++++ nodes.py | 1 + 4 files changed, 77 insertions(+), 4 deletions(-) create mode 100644 comfy_extras/nodes_rope.py diff --git a/comfy/ldm/wan/model.py b/comfy/ldm/wan/model.py index 90c347d3d..77876c2e7 100644 --- a/comfy/ldm/wan/model.py +++ b/comfy/ldm/wan/model.py @@ -588,7 +588,7 @@ class WanModel(torch.nn.Module): x = self.unpatchify(x, grid_sizes) return x - def rope_encode(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None): + def rope_encode(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None, transformer_options={}): patch_size = self.patch_size t_len = ((t + (patch_size[0] // 2)) // patch_size[0]) h_len = ((h + (patch_size[1] // 2)) // patch_size[1]) @@ -601,10 +601,22 @@ class WanModel(torch.nn.Module): if steps_w is None: steps_w = w_len + h_start = 0 + w_start = 0 + rope_options = transformer_options.get("rope_options", None) + if rope_options is not None: + t_len = t_len * rope_options.get("scale_t", 1.0) + h_len = h_len * rope_options.get("scale_y", 1.0) + w_len = w_len * rope_options.get("scale_x", 1.0) + + t_start += rope_options.get("shift_t", 0.0) + h_start += rope_options.get("shift_y", 0.0) + w_start += rope_options.get("shift_x", 0.0) + img_ids = torch.zeros((steps_t, steps_h, steps_w, 3), device=device, dtype=dtype) img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(t_start, t_start + (t_len - 1), steps=steps_t, device=device, dtype=dtype).reshape(-1, 1, 1) - img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(0, h_len - 1, steps=steps_h, device=device, dtype=dtype).reshape(1, -1, 1) - img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(0, w_len - 1, steps=steps_w, device=device, dtype=dtype).reshape(1, 1, -1) + img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(h_start, h_start + (h_len - 1), steps=steps_h, device=device, dtype=dtype).reshape(1, -1, 1) + img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(w_start, w_start + (w_len - 1), steps=steps_w, device=device, dtype=dtype).reshape(1, 1, -1) img_ids = img_ids.reshape(1, -1, img_ids.shape[-1]) freqs = self.rope_embedder(img_ids).movedim(1, 2) @@ -630,7 +642,7 @@ class WanModel(torch.nn.Module): if self.ref_conv is not None and "reference_latent" in kwargs: t_len += 1 - freqs = self.rope_encode(t_len, h, w, device=x.device, dtype=x.dtype) + freqs = self.rope_encode(t_len, h, w, device=x.device, dtype=x.dtype, transformer_options=transformer_options) return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, transformer_options=transformer_options, **kwargs)[:, :, :t, :h, :w] def unpatchify(self, x, grid_sizes): diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 674a214ca..3e8799983 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -454,6 +454,19 @@ class ModelPatcher: def set_model_post_input_patch(self, patch): self.set_model_patch(patch, "post_input") + def set_model_rope_options(self, scale_x, shift_x, scale_y, shift_y, scale_t, shift_t, **kwargs): + rope_options = self.model_options["transformer_options"].get("rope_options", {}) + rope_options["scale_x"] = scale_x + rope_options["scale_y"] = scale_y + rope_options["scale_t"] = scale_t + + rope_options["shift_x"] = shift_x + rope_options["shift_y"] = shift_y + rope_options["shift_t"] = shift_t + + self.model_options["transformer_options"]["rope_options"] = rope_options + + def add_object_patch(self, name, obj): self.object_patches[name] = obj diff --git a/comfy_extras/nodes_rope.py b/comfy_extras/nodes_rope.py new file mode 100644 index 000000000..d1feb031e --- /dev/null +++ b/comfy_extras/nodes_rope.py @@ -0,0 +1,47 @@ +from comfy_api.latest import ComfyExtension, io +from typing_extensions import override + + +class ScaleROPE(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="ScaleROPE", + category="advanced/model_patches", + description="Scale and shift the ROPE of the model.", + is_experimental=True, + inputs=[ + io.Model.Input("model"), + io.Float.Input("scale_x", default=1.0, min=0.0, max=100.0, step=0.1), + io.Float.Input("shift_x", default=0.0, min=-256.0, max=256.0, step=0.1), + + io.Float.Input("scale_y", default=1.0, min=0.0, max=100.0, step=0.1), + io.Float.Input("shift_y", default=0.0, min=-256.0, max=256.0, step=0.1), + + io.Float.Input("scale_t", default=1.0, min=0.0, max=100.0, step=0.1), + io.Float.Input("shift_t", default=0.0, min=-256.0, max=256.0, step=0.1), + + + ], + outputs=[ + io.Model.Output(), + ], + ) + + @classmethod + def execute(cls, model, scale_x, shift_x, scale_y, shift_y, scale_t, shift_t) -> io.NodeOutput: + m = model.clone() + m.set_model_rope_options(scale_x, shift_x, scale_y, shift_y, scale_t, shift_t) + return io.NodeOutput(m) + + +class RopeExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + ScaleROPE + ] + + +async def comfy_entrypoint() -> RopeExtension: + return RopeExtension() diff --git a/nodes.py b/nodes.py index 12e365ca9..5689f6fe1 100644 --- a/nodes.py +++ b/nodes.py @@ -2329,6 +2329,7 @@ async def init_builtin_extra_nodes(): "nodes_model_patch.py", "nodes_easycache.py", "nodes_audio_encoder.py", + "nodes_rope.py", ] import_failed = [] From 27d1bd882925e3bbdffb405cea098ac52bb20ac5 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Thu, 30 Oct 2025 19:51:58 -0700 Subject: [PATCH 041/231] Fix rope scaling. (#10560) --- comfy/ldm/wan/model.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/comfy/ldm/wan/model.py b/comfy/ldm/wan/model.py index 77876c2e7..5ec1511ce 100644 --- a/comfy/ldm/wan/model.py +++ b/comfy/ldm/wan/model.py @@ -605,9 +605,9 @@ class WanModel(torch.nn.Module): w_start = 0 rope_options = transformer_options.get("rope_options", None) if rope_options is not None: - t_len = t_len * rope_options.get("scale_t", 1.0) - h_len = h_len * rope_options.get("scale_y", 1.0) - w_len = w_len * rope_options.get("scale_x", 1.0) + t_len = (t_len - 1.0) * rope_options.get("scale_t", 1.0) + 1.0 + h_len = (h_len - 1.0) * rope_options.get("scale_y", 1.0) + 1.0 + w_len = (w_len - 1.0) * rope_options.get("scale_x", 1.0) + 1.0 t_start += rope_options.get("shift_t", 0.0) h_start += rope_options.get("shift_y", 0.0) From 7f374e42c833c69c71605507b90f79cc26d14a71 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Fri, 31 Oct 2025 12:41:40 -0700 Subject: [PATCH 042/231] ScaleROPE now works on Lumina models. (#10578) --- comfy/ldm/lumina/model.py | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/comfy/ldm/lumina/model.py b/comfy/ldm/lumina/model.py index f87d98ac0..b4494a51d 100644 --- a/comfy/ldm/lumina/model.py +++ b/comfy/ldm/lumina/model.py @@ -522,7 +522,7 @@ class NextDiT(nn.Module): max_cap_len = max(l_effective_cap_len) max_img_len = max(l_effective_img_len) - position_ids = torch.zeros(bsz, max_seq_len, 3, dtype=torch.int32, device=device) + position_ids = torch.zeros(bsz, max_seq_len, 3, dtype=torch.float32, device=device) for i in range(bsz): cap_len = l_effective_cap_len[i] @@ -531,10 +531,22 @@ class NextDiT(nn.Module): H_tokens, W_tokens = H // pH, W // pW assert H_tokens * W_tokens == img_len - position_ids[i, :cap_len, 0] = torch.arange(cap_len, dtype=torch.int32, device=device) + rope_options = transformer_options.get("rope_options", None) + h_scale = 1.0 + w_scale = 1.0 + h_start = 0 + w_start = 0 + if rope_options is not None: + h_scale = rope_options.get("scale_y", 1.0) + w_scale = rope_options.get("scale_x", 1.0) + + h_start = rope_options.get("shift_y", 0.0) + w_start = rope_options.get("shift_x", 0.0) + + position_ids[i, :cap_len, 0] = torch.arange(cap_len, dtype=torch.float32, device=device) position_ids[i, cap_len:cap_len+img_len, 0] = cap_len - row_ids = torch.arange(H_tokens, dtype=torch.int32, device=device).view(-1, 1).repeat(1, W_tokens).flatten() - col_ids = torch.arange(W_tokens, dtype=torch.int32, device=device).view(1, -1).repeat(H_tokens, 1).flatten() + row_ids = (torch.arange(H_tokens, dtype=torch.float32, device=device) * h_scale + h_start).view(-1, 1).repeat(1, W_tokens).flatten() + col_ids = (torch.arange(W_tokens, dtype=torch.float32, device=device) * w_scale + w_start).view(1, -1).repeat(H_tokens, 1).flatten() position_ids[i, cap_len:cap_len+img_len, 1] = row_ids position_ids[i, cap_len:cap_len+img_len, 2] = col_ids From c58c13b2bad6df0de93cc0cf107e96522a3cb5b3 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Fri, 31 Oct 2025 21:25:17 -0700 Subject: [PATCH 043/231] Fix torch compile regression on fp8 ops. (#10580) --- comfy/ops.py | 24 +++++------------ comfy/quant_ops.py | 27 +++++++++++++++---- .../comfy_quant/test_mixed_precision.py | 8 +++--- tests-unit/comfy_quant/test_quant_registry.py | 20 +++++++------- 4 files changed, 43 insertions(+), 36 deletions(-) diff --git a/comfy/ops.py b/comfy/ops.py index 18f6b804b..279f6b1a7 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -401,15 +401,9 @@ def fp8_linear(self, input): if dtype not in [torch.float8_e4m3fn]: return None - tensor_2d = False - if len(input.shape) == 2: - tensor_2d = True - input = input.unsqueeze(1) - - input_shape = input.shape input_dtype = input.dtype - if len(input.shape) == 3: + if input.ndim == 3 or input.ndim == 2: w, bias, offload_stream = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype, offloadable=True) scale_weight = self.scale_weight @@ -422,24 +416,20 @@ def fp8_linear(self, input): if scale_input is None: scale_input = torch.ones((), device=input.device, dtype=torch.float32) input = torch.clamp(input, min=-448, max=448, out=input) - input = input.reshape(-1, input_shape[2]).to(dtype).contiguous() layout_params_weight = {'scale': scale_input, 'orig_dtype': input_dtype} - quantized_input = QuantizedTensor(input.reshape(-1, input_shape[2]).to(dtype).contiguous(), TensorCoreFP8Layout, layout_params_weight) + quantized_input = QuantizedTensor(input.to(dtype).contiguous(), "TensorCoreFP8Layout", layout_params_weight) else: scale_input = scale_input.to(input.device) - quantized_input = QuantizedTensor.from_float(input.reshape(-1, input_shape[2]), TensorCoreFP8Layout, scale=scale_input, dtype=dtype) + quantized_input = QuantizedTensor.from_float(input, "TensorCoreFP8Layout", scale=scale_input, dtype=dtype) # Wrap weight in QuantizedTensor - this enables unified dispatch # Call F.linear - __torch_dispatch__ routes to fp8_linear handler in quant_ops.py! layout_params_weight = {'scale': scale_weight, 'orig_dtype': input_dtype} - quantized_weight = QuantizedTensor(w, TensorCoreFP8Layout, layout_params_weight) + quantized_weight = QuantizedTensor(w, "TensorCoreFP8Layout", layout_params_weight) o = torch.nn.functional.linear(quantized_input, quantized_weight, bias) uncast_bias_weight(self, w, bias, offload_stream) - - if tensor_2d: - return o.reshape(input_shape[0], -1) - return o.reshape((-1, input_shape[1], self.weight.shape[0])) + return o return None @@ -540,12 +530,12 @@ if CUBLAS_IS_AVAILABLE: # ============================================================================== # Mixed Precision Operations # ============================================================================== -from .quant_ops import QuantizedTensor, TensorCoreFP8Layout +from .quant_ops import QuantizedTensor QUANT_FORMAT_MIXINS = { "float8_e4m3fn": { "dtype": torch.float8_e4m3fn, - "layout_type": TensorCoreFP8Layout, + "layout_type": "TensorCoreFP8Layout", "parameters": { "weight_scale": torch.nn.Parameter(torch.zeros((), dtype=torch.float32), requires_grad=False), "input_scale": torch.nn.Parameter(torch.zeros((), dtype=torch.float32), requires_grad=False), diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py index c822fe53c..873f173ed 100644 --- a/comfy/quant_ops.py +++ b/comfy/quant_ops.py @@ -123,7 +123,7 @@ class QuantizedTensor(torch.Tensor): layout_type: Layout class (subclass of QuantizedLayout) layout_params: Dict with layout-specific parameters """ - return torch.Tensor._make_subclass(cls, qdata, require_grad=False) + return torch.Tensor._make_wrapper_subclass(cls, qdata.shape, device=qdata.device, dtype=qdata.dtype, requires_grad=False) def __init__(self, qdata, layout_type, layout_params): self._qdata = qdata.contiguous() @@ -183,11 +183,11 @@ class QuantizedTensor(torch.Tensor): @classmethod def from_float(cls, tensor, layout_type, **quantize_kwargs) -> 'QuantizedTensor': - qdata, layout_params = layout_type.quantize(tensor, **quantize_kwargs) + qdata, layout_params = LAYOUTS[layout_type].quantize(tensor, **quantize_kwargs) return cls(qdata, layout_type, layout_params) def dequantize(self) -> torch.Tensor: - return self._layout_type.dequantize(self._qdata, **self._layout_params) + return LAYOUTS[self._layout_type].dequantize(self._qdata, **self._layout_params) @classmethod def __torch_dispatch__(cls, func, types, args=(), kwargs=None): @@ -379,7 +379,12 @@ class TensorCoreFP8Layout(QuantizedLayout): return qtensor._qdata, qtensor._layout_params['scale'] -@register_layout_op(torch.ops.aten.linear.default, TensorCoreFP8Layout) +LAYOUTS = { + "TensorCoreFP8Layout": TensorCoreFP8Layout, +} + + +@register_layout_op(torch.ops.aten.linear.default, "TensorCoreFP8Layout") def fp8_linear(func, args, kwargs): input_tensor = args[0] weight = args[1] @@ -422,7 +427,7 @@ def fp8_linear(func, args, kwargs): 'scale': output_scale, 'orig_dtype': input_tensor._layout_params['orig_dtype'] } - return QuantizedTensor(output, TensorCoreFP8Layout, output_params) + return QuantizedTensor(output, "TensorCoreFP8Layout", output_params) else: return output @@ -436,3 +441,15 @@ def fp8_linear(func, args, kwargs): input_tensor = input_tensor.dequantize() return torch.nn.functional.linear(input_tensor, weight, bias) + + +@register_layout_op(torch.ops.aten.view.default, "TensorCoreFP8Layout") +@register_layout_op(torch.ops.aten.t.default, "TensorCoreFP8Layout") +def fp8_func(func, args, kwargs): + input_tensor = args[0] + if isinstance(input_tensor, QuantizedTensor): + plain_input, scale_a = TensorCoreFP8Layout.get_plain_tensors(input_tensor) + ar = list(args) + ar[0] = plain_input + return QuantizedTensor(func(*ar, **kwargs), "TensorCoreFP8Layout", input_tensor._layout_params) + return func(*args, **kwargs) diff --git a/tests-unit/comfy_quant/test_mixed_precision.py b/tests-unit/comfy_quant/test_mixed_precision.py index 267bc177b..f8d1fd04e 100644 --- a/tests-unit/comfy_quant/test_mixed_precision.py +++ b/tests-unit/comfy_quant/test_mixed_precision.py @@ -14,7 +14,7 @@ if not has_gpu(): args.cpu = True from comfy import ops -from comfy.quant_ops import QuantizedTensor, TensorCoreFP8Layout +from comfy.quant_ops import QuantizedTensor class SimpleModel(torch.nn.Module): @@ -104,14 +104,14 @@ class TestMixedPrecisionOps(unittest.TestCase): # Verify weights are wrapped in QuantizedTensor self.assertIsInstance(model.layer1.weight, QuantizedTensor) - self.assertEqual(model.layer1.weight._layout_type, TensorCoreFP8Layout) + self.assertEqual(model.layer1.weight._layout_type, "TensorCoreFP8Layout") # Layer 2 should NOT be quantized self.assertNotIsInstance(model.layer2.weight, QuantizedTensor) # Layer 3 should be quantized self.assertIsInstance(model.layer3.weight, QuantizedTensor) - self.assertEqual(model.layer3.weight._layout_type, TensorCoreFP8Layout) + self.assertEqual(model.layer3.weight._layout_type, "TensorCoreFP8Layout") # Verify scales were loaded self.assertEqual(model.layer1.weight._layout_params['scale'].item(), 2.0) @@ -155,7 +155,7 @@ class TestMixedPrecisionOps(unittest.TestCase): # Verify layer1.weight is a QuantizedTensor with scale preserved self.assertIsInstance(state_dict2["layer1.weight"], QuantizedTensor) self.assertEqual(state_dict2["layer1.weight"]._layout_params['scale'].item(), 3.0) - self.assertEqual(state_dict2["layer1.weight"]._layout_type, TensorCoreFP8Layout) + self.assertEqual(state_dict2["layer1.weight"]._layout_type, "TensorCoreFP8Layout") # Verify non-quantized layers are standard tensors self.assertNotIsInstance(state_dict2["layer2.weight"], QuantizedTensor) diff --git a/tests-unit/comfy_quant/test_quant_registry.py b/tests-unit/comfy_quant/test_quant_registry.py index 477811029..9cb54ede8 100644 --- a/tests-unit/comfy_quant/test_quant_registry.py +++ b/tests-unit/comfy_quant/test_quant_registry.py @@ -25,14 +25,14 @@ class TestQuantizedTensor(unittest.TestCase): scale = torch.tensor(2.0) layout_params = {'scale': scale, 'orig_dtype': torch.bfloat16} - qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params) + qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params) self.assertIsInstance(qt, QuantizedTensor) self.assertEqual(qt.shape, (256, 128)) self.assertEqual(qt.dtype, torch.float8_e4m3fn) self.assertEqual(qt._layout_params['scale'], scale) self.assertEqual(qt._layout_params['orig_dtype'], torch.bfloat16) - self.assertEqual(qt._layout_type, TensorCoreFP8Layout) + self.assertEqual(qt._layout_type, "TensorCoreFP8Layout") def test_dequantize(self): """Test explicit dequantization""" @@ -41,7 +41,7 @@ class TestQuantizedTensor(unittest.TestCase): scale = torch.tensor(3.0) layout_params = {'scale': scale, 'orig_dtype': torch.float32} - qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params) + qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params) dequantized = qt.dequantize() self.assertEqual(dequantized.dtype, torch.float32) @@ -54,7 +54,7 @@ class TestQuantizedTensor(unittest.TestCase): qt = QuantizedTensor.from_float( float_tensor, - TensorCoreFP8Layout, + "TensorCoreFP8Layout", scale=scale, dtype=torch.float8_e4m3fn ) @@ -77,28 +77,28 @@ class TestGenericUtilities(unittest.TestCase): fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn) scale = torch.tensor(1.5) layout_params = {'scale': scale, 'orig_dtype': torch.float32} - qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params) + qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params) # Detach should return a new QuantizedTensor qt_detached = qt.detach() self.assertIsInstance(qt_detached, QuantizedTensor) self.assertEqual(qt_detached.shape, qt.shape) - self.assertEqual(qt_detached._layout_type, TensorCoreFP8Layout) + self.assertEqual(qt_detached._layout_type, "TensorCoreFP8Layout") def test_clone(self): """Test clone operation on quantized tensor""" fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn) scale = torch.tensor(1.5) layout_params = {'scale': scale, 'orig_dtype': torch.float32} - qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params) + qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params) # Clone should return a new QuantizedTensor qt_cloned = qt.clone() self.assertIsInstance(qt_cloned, QuantizedTensor) self.assertEqual(qt_cloned.shape, qt.shape) - self.assertEqual(qt_cloned._layout_type, TensorCoreFP8Layout) + self.assertEqual(qt_cloned._layout_type, "TensorCoreFP8Layout") # Verify it's a deep copy self.assertIsNot(qt_cloned._qdata, qt._qdata) @@ -109,7 +109,7 @@ class TestGenericUtilities(unittest.TestCase): fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn) scale = torch.tensor(1.5) layout_params = {'scale': scale, 'orig_dtype': torch.float32} - qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params) + qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params) # Moving to same device should work (CPU to CPU) qt_cpu = qt.to('cpu') @@ -169,7 +169,7 @@ class TestFallbackMechanism(unittest.TestCase): scale = torch.tensor(1.0) a_q = QuantizedTensor.from_float( a_fp32, - TensorCoreFP8Layout, + "TensorCoreFP8Layout", scale=scale, dtype=torch.float8_e4m3fn ) From 5f109fe6a06a3462b31a066bcfd650de67d66102 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Sat, 1 Nov 2025 21:13:39 +0200 Subject: [PATCH 044/231] added 12s-20s as available output durations for the LTXV API nodes (#10570) --- comfy_api_nodes/nodes_ltxv.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/comfy_api_nodes/nodes_ltxv.py b/comfy_api_nodes/nodes_ltxv.py index e6ad6e27a..0b757a62b 100644 --- a/comfy_api_nodes/nodes_ltxv.py +++ b/comfy_api_nodes/nodes_ltxv.py @@ -46,7 +46,7 @@ class TextToVideoNode(IO.ComfyNode): multiline=True, default="", ), - IO.Combo.Input("duration", options=[6, 8, 10], default=8), + IO.Combo.Input("duration", options=[6, 8, 10, 12, 14, 16, 18, 20], default=8), IO.Combo.Input( "resolution", options=[ @@ -85,6 +85,10 @@ class TextToVideoNode(IO.ComfyNode): generate_audio: bool = False, ) -> IO.NodeOutput: validate_string(prompt, min_length=1, max_length=10000) + if duration > 10 and (model != "LTX-2 (Fast)" or resolution != "1920x1080" or fps != 25): + raise ValueError( + "Durations over 10s are only available for the Fast model at 1920x1080 resolution and 25 FPS." + ) response = await sync_op_raw( cls, ApiEndpoint("/proxy/ltx/v1/text-to-video", "POST"), @@ -118,7 +122,7 @@ class ImageToVideoNode(IO.ComfyNode): multiline=True, default="", ), - IO.Combo.Input("duration", options=[6, 8, 10], default=8), + IO.Combo.Input("duration", options=[6, 8, 10, 12, 14, 16, 18, 20], default=8), IO.Combo.Input( "resolution", options=[ @@ -158,6 +162,10 @@ class ImageToVideoNode(IO.ComfyNode): generate_audio: bool = False, ) -> IO.NodeOutput: validate_string(prompt, min_length=1, max_length=10000) + if duration > 10 and (model != "LTX-2 (Fast)" or resolution != "1920x1080" or fps != 25): + raise ValueError( + "Durations over 10s are only available for the Fast model at 1920x1080 resolution and 25 FPS." + ) if get_number_of_images(image) != 1: raise ValueError("Currently only one input image is supported.") response = await sync_op_raw( From 20182a393f43ab1fdf798f8da6aac0ef6116e7e6 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Sat, 1 Nov 2025 21:14:06 +0200 Subject: [PATCH 045/231] convert StabilityAI to use new API client (#10582) --- comfy_api_nodes/nodes_stability.py | 179 ++++++++--------------------- comfy_api_nodes/util/client.py | 4 +- 2 files changed, 48 insertions(+), 135 deletions(-) diff --git a/comfy_api_nodes/nodes_stability.py b/comfy_api_nodes/nodes_stability.py index 783666ddf..bb7ceed78 100644 --- a/comfy_api_nodes/nodes_stability.py +++ b/comfy_api_nodes/nodes_stability.py @@ -20,13 +20,6 @@ from comfy_api_nodes.apis.stability_api import ( StabilityAudioInpaintRequest, StabilityAudioResponse, ) -from comfy_api_nodes.apis.client import ( - ApiEndpoint, - HttpMethod, - SynchronousOperation, - PollingOperation, - EmptyRequest, -) from comfy_api_nodes.util import ( validate_audio_duration, validate_string, @@ -34,6 +27,9 @@ from comfy_api_nodes.util import ( bytesio_to_image_tensor, tensor_to_bytesio, audio_bytes_to_audio_input, + sync_op, + poll_op, + ApiEndpoint, ) import torch @@ -161,19 +157,11 @@ class StabilityStableImageUltraNode(IO.ComfyNode): "image": image_binary } - auth = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } - - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/stability/v2beta/stable-image/generate/ultra", - method=HttpMethod.POST, - request_model=StabilityStableUltraRequest, - response_model=StabilityStableUltraResponse, - ), - request=StabilityStableUltraRequest( + response_api = await sync_op( + cls, + ApiEndpoint(path="/proxy/stability/v2beta/stable-image/generate/ultra", method="POST"), + response_model=StabilityStableUltraResponse, + data=StabilityStableUltraRequest( prompt=prompt, negative_prompt=negative_prompt, aspect_ratio=aspect_ratio, @@ -183,9 +171,7 @@ class StabilityStableImageUltraNode(IO.ComfyNode): ), files=files, content_type="multipart/form-data", - auth_kwargs=auth, ) - response_api = await operation.execute() if response_api.finish_reason != "SUCCESS": raise Exception(f"Stable Image Ultra generation failed: {response_api.finish_reason}.") @@ -313,19 +299,11 @@ class StabilityStableImageSD_3_5Node(IO.ComfyNode): "image": image_binary } - auth = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } - - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/stability/v2beta/stable-image/generate/sd3", - method=HttpMethod.POST, - request_model=StabilityStable3_5Request, - response_model=StabilityStableUltraResponse, - ), - request=StabilityStable3_5Request( + response_api = await sync_op( + cls, + ApiEndpoint(path="/proxy/stability/v2beta/stable-image/generate/sd3", method="POST"), + response_model=StabilityStableUltraResponse, + data=StabilityStable3_5Request( prompt=prompt, negative_prompt=negative_prompt, aspect_ratio=aspect_ratio, @@ -338,9 +316,7 @@ class StabilityStableImageSD_3_5Node(IO.ComfyNode): ), files=files, content_type="multipart/form-data", - auth_kwargs=auth, ) - response_api = await operation.execute() if response_api.finish_reason != "SUCCESS": raise Exception(f"Stable Diffusion 3.5 Image generation failed: {response_api.finish_reason}.") @@ -427,19 +403,11 @@ class StabilityUpscaleConservativeNode(IO.ComfyNode): "image": image_binary } - auth = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } - - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/stability/v2beta/stable-image/upscale/conservative", - method=HttpMethod.POST, - request_model=StabilityUpscaleConservativeRequest, - response_model=StabilityStableUltraResponse, - ), - request=StabilityUpscaleConservativeRequest( + response_api = await sync_op( + cls, + ApiEndpoint(path="/proxy/stability/v2beta/stable-image/upscale/conservative", method="POST"), + response_model=StabilityStableUltraResponse, + data=StabilityUpscaleConservativeRequest( prompt=prompt, negative_prompt=negative_prompt, creativity=round(creativity,2), @@ -447,9 +415,7 @@ class StabilityUpscaleConservativeNode(IO.ComfyNode): ), files=files, content_type="multipart/form-data", - auth_kwargs=auth, ) - response_api = await operation.execute() if response_api.finish_reason != "SUCCESS": raise Exception(f"Stability Upscale Conservative generation failed: {response_api.finish_reason}.") @@ -544,19 +510,11 @@ class StabilityUpscaleCreativeNode(IO.ComfyNode): "image": image_binary } - auth = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } - - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/stability/v2beta/stable-image/upscale/creative", - method=HttpMethod.POST, - request_model=StabilityUpscaleCreativeRequest, - response_model=StabilityAsyncResponse, - ), - request=StabilityUpscaleCreativeRequest( + response_api = await sync_op( + cls, + ApiEndpoint(path="/proxy/stability/v2beta/stable-image/upscale/creative", method="POST"), + response_model=StabilityAsyncResponse, + data=StabilityUpscaleCreativeRequest( prompt=prompt, negative_prompt=negative_prompt, creativity=round(creativity,2), @@ -565,25 +523,15 @@ class StabilityUpscaleCreativeNode(IO.ComfyNode): ), files=files, content_type="multipart/form-data", - auth_kwargs=auth, ) - response_api = await operation.execute() - operation = PollingOperation( - poll_endpoint=ApiEndpoint( - path=f"/proxy/stability/v2beta/results/{response_api.id}", - method=HttpMethod.GET, - request_model=EmptyRequest, - response_model=StabilityResultsGetResponse, - ), + response_poll = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/stability/v2beta/results/{response_api.id}"), + response_model=StabilityResultsGetResponse, poll_interval=3, - completed_statuses=[StabilityPollStatus.finished], - failed_statuses=[StabilityPollStatus.failed], status_extractor=lambda x: get_async_dummy_status(x), - auth_kwargs=auth, - node_id=cls.hidden.unique_id, ) - response_poll: StabilityResultsGetResponse = await operation.execute() if response_poll.finish_reason != "SUCCESS": raise Exception(f"Stability Upscale Creative generation failed: {response_poll.finish_reason}.") @@ -628,24 +576,13 @@ class StabilityUpscaleFastNode(IO.ComfyNode): "image": image_binary } - auth = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } - - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/stability/v2beta/stable-image/upscale/fast", - method=HttpMethod.POST, - request_model=EmptyRequest, - response_model=StabilityStableUltraResponse, - ), - request=EmptyRequest(), + response_api = await sync_op( + cls, + ApiEndpoint(path="/proxy/stability/v2beta/stable-image/upscale/fast", method="POST"), + response_model=StabilityStableUltraResponse, files=files, content_type="multipart/form-data", - auth_kwargs=auth, ) - response_api = await operation.execute() if response_api.finish_reason != "SUCCESS": raise Exception(f"Stability Upscale Fast failed: {response_api.finish_reason}.") @@ -717,21 +654,13 @@ class StabilityTextToAudio(IO.ComfyNode): async def execute(cls, model: str, prompt: str, duration: int, seed: int, steps: int) -> IO.NodeOutput: validate_string(prompt, max_length=10000) payload = StabilityTextToAudioRequest(prompt=prompt, model=model, duration=duration, seed=seed, steps=steps) - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/stability/v2beta/audio/stable-audio-2/text-to-audio", - method=HttpMethod.POST, - request_model=StabilityTextToAudioRequest, - response_model=StabilityAudioResponse, - ), - request=payload, + response_api = await sync_op( + cls, + ApiEndpoint(path="/proxy/stability/v2beta/audio/stable-audio-2/text-to-audio", method="POST"), + response_model=StabilityAudioResponse, + data=payload, content_type="multipart/form-data", - auth_kwargs= { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - }, ) - response_api = await operation.execute() if not response_api.audio: raise ValueError("No audio file was received in response.") return IO.NodeOutput(audio_bytes_to_audio_input(base64.b64decode(response_api.audio))) @@ -814,22 +743,14 @@ class StabilityAudioToAudio(IO.ComfyNode): payload = StabilityAudioToAudioRequest( prompt=prompt, model=model, duration=duration, seed=seed, steps=steps, strength=strength ) - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/stability/v2beta/audio/stable-audio-2/audio-to-audio", - method=HttpMethod.POST, - request_model=StabilityAudioToAudioRequest, - response_model=StabilityAudioResponse, - ), - request=payload, + response_api = await sync_op( + cls, + ApiEndpoint(path="/proxy/stability/v2beta/audio/stable-audio-2/audio-to-audio", method="POST"), + response_model=StabilityAudioResponse, + data=payload, content_type="multipart/form-data", files={"audio": audio_input_to_mp3(audio)}, - auth_kwargs= { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - }, ) - response_api = await operation.execute() if not response_api.audio: raise ValueError("No audio file was received in response.") return IO.NodeOutput(audio_bytes_to_audio_input(base64.b64decode(response_api.audio))) @@ -935,22 +856,14 @@ class StabilityAudioInpaint(IO.ComfyNode): mask_start=mask_start, mask_end=mask_end, ) - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/stability/v2beta/audio/stable-audio-2/inpaint", - method=HttpMethod.POST, - request_model=StabilityAudioInpaintRequest, - response_model=StabilityAudioResponse, - ), - request=payload, + response_api = await sync_op( + cls, + endpoint=ApiEndpoint(path="/proxy/stability/v2beta/audio/stable-audio-2/inpaint", method="POST"), + response_model=StabilityAudioResponse, + data=payload, content_type="multipart/form-data", files={"audio": audio_input_to_mp3(audio)}, - auth_kwargs={ - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - }, ) - response_api = await operation.execute() if not response_api.audio: raise ValueError("No audio file was received in response.") return IO.NodeOutput(audio_bytes_to_audio_input(base64.b64decode(response_api.audio))) diff --git a/comfy_api_nodes/util/client.py b/comfy_api_nodes/util/client.py index 9ae512fe5..65bb35f0f 100644 --- a/comfy_api_nodes/util/client.py +++ b/comfy_api_nodes/util/client.py @@ -77,7 +77,7 @@ class _PollUIState: _RETRY_STATUS = {408, 429, 500, 502, 503, 504} -COMPLETED_STATUSES = ["succeeded", "succeed", "success", "completed"] +COMPLETED_STATUSES = ["succeeded", "succeed", "success", "completed", "finished"] FAILED_STATUSES = ["cancelled", "canceled", "fail", "failed", "error"] QUEUED_STATUSES = ["created", "queued", "queueing", "submitted"] @@ -589,7 +589,7 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool): operation_id = _generate_operation_id(method, cfg.endpoint.path, attempt) logging.debug("[DEBUG] HTTP %s %s (attempt %d)", method, url, attempt) - payload_headers = {"Accept": "*/*"} + payload_headers = {"Accept": "*/*"} if expect_binary else {"Accept": "application/json"} if not parsed_url.scheme and not parsed_url.netloc: # is URL relative? payload_headers.update(get_auth_header(cfg.node_cls)) if cfg.endpoint.headers: From 44869ff786dc90b36172fd766c9a110e4c40c04b Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sat, 1 Nov 2025 14:25:59 -0700 Subject: [PATCH 046/231] Fix issue with pinned memory. (#10597) --- comfy/model_patcher.py | 1 + 1 file changed, 1 insertion(+) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 3e8799983..5a31a8734 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -298,6 +298,7 @@ class ModelPatcher: n.backup = self.backup n.object_patches_backup = self.object_patches_backup n.parent = self + n.pinned = self.pinned n.force_cast_weights = self.force_cast_weights From 135fa49ec23320834f774cf3def9e51ad3773f86 Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Sun, 2 Nov 2025 08:48:53 +1000 Subject: [PATCH 047/231] Small speed improvements to --async-offload (#10593) * ops: dont take an offload stream if you dont need one * ops: prioritize mem transfer The async offload streams reason for existence is to transfer from RAM to GPU. The post processing compute steps are a bonus on the side stream, but if the compute stream is running a long kernel, it can stall the side stream, as it wait to type-cast the bias before transferring the weight. So do a pure xfer of the weight straight up, then do everything bias, then go back to fix the weight type and do weight patches. --- comfy/ops.py | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/comfy/ops.py b/comfy/ops.py index 279f6b1a7..0c8f23848 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -84,7 +84,8 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of if device is None: device = input.device - if offloadable: + if offloadable and (device != s.weight.device or + (s.bias is not None and device != s.bias.device)): offload_stream = comfy.model_management.get_offload_stream(device) else: offload_stream = None @@ -94,20 +95,24 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of else: wf_context = contextlib.nullcontext() - bias = None non_blocking = comfy.model_management.device_supports_non_blocking(device) - if s.bias is not None: - has_function = len(s.bias_function) > 0 - bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=has_function, stream=offload_stream) - if has_function: + weight_has_function = len(s.weight_function) > 0 + bias_has_function = len(s.bias_function) > 0 + + weight = comfy.model_management.cast_to(s.weight, None, device, non_blocking=non_blocking, copy=weight_has_function, stream=offload_stream) + + bias = None + if s.bias is not None: + bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=bias_has_function, stream=offload_stream) + + if bias_has_function: with wf_context: for f in s.bias_function: bias = f(bias) - has_function = len(s.weight_function) > 0 - weight = comfy.model_management.cast_to(s.weight, dtype, device, non_blocking=non_blocking, copy=has_function, stream=offload_stream) - if has_function: + weight = weight.to(dtype=dtype) + if weight_has_function: with wf_context: for f in s.weight_function: weight = f(weight) From 97ff9fae7e728cffdfc3aee6d72aa1e0d0b78702 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sun, 2 Nov 2025 10:14:04 -0800 Subject: [PATCH 048/231] Clarify help text for --fast argument (#10609) Updated help text for the --fast argument to clarify potential risks. --- comfy/cli_args.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/cli_args.py b/comfy/cli_args.py index 3d5bc7c90..3947e62a8 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -147,7 +147,7 @@ class PerformanceFeature(enum.Enum): AutoTune = "autotune" PinnedMem = "pinned_memory" -parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: {}".format(" ".join(map(lambda c: c.value, PerformanceFeature)))) +parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. This is used to test new features so using it might crash your comfyui. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: {}".format(" ".join(map(lambda c: c.value, PerformanceFeature)))) parser.add_argument("--mmap-torch-files", action="store_true", help="Use mmap when loading ckpt/pt files.") parser.add_argument("--disable-mmap", action="store_true", help="Don't use mmap when loading safetensors.") From 6d6a18b0b730351416556eeb0990ab219ffec189 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Mon, 3 Nov 2025 10:04:56 +0200 Subject: [PATCH 049/231] fix(api-nodes-cloud): stop using sub-folder and absolute path for output of Rodin3D nodes (#10556) --- comfy_api_nodes/nodes_rodin.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/comfy_api_nodes/nodes_rodin.py b/comfy_api_nodes/nodes_rodin.py index cf2172bd6..ad4029236 100644 --- a/comfy_api_nodes/nodes_rodin.py +++ b/comfy_api_nodes/nodes_rodin.py @@ -225,21 +225,20 @@ async def get_rodin_download_list(uuid, auth_kwargs: Optional[dict[str, str]] = async def download_files(url_list, task_uuid): - save_path = os.path.join(comfy_paths.get_output_directory(), f"Rodin3D_{task_uuid}") + result_folder_name = f"Rodin3D_{task_uuid}" + save_path = os.path.join(comfy_paths.get_output_directory(), result_folder_name) os.makedirs(save_path, exist_ok=True) model_file_path = None async with aiohttp.ClientSession() as session: for i in url_list.list: - url = i.url - file_name = i.name - file_path = os.path.join(save_path, file_name) + file_path = os.path.join(save_path, i.name) if file_path.endswith(".glb"): - model_file_path = file_path + model_file_path = os.path.join(result_folder_name, i.name) logging.info("[ Rodin3D API - download_files ] Downloading file: %s", file_path) max_retries = 5 for attempt in range(max_retries): try: - async with session.get(url) as resp: + async with session.get(i.url) as resp: resp.raise_for_status() with open(file_path, "wb") as f: async for chunk in resp.content.iter_chunked(32 * 1024): From 88df172790b8ed7b2e6ea7c0f0bd63ca3553921b Mon Sep 17 00:00:00 2001 From: EverNebula Date: Mon, 3 Nov 2025 16:16:40 +0800 Subject: [PATCH 050/231] fix(caching): treat bytes as hashable (#10567) --- comfy_execution/caching.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy_execution/caching.py b/comfy_execution/caching.py index b498f43e7..e077f78b0 100644 --- a/comfy_execution/caching.py +++ b/comfy_execution/caching.py @@ -53,7 +53,7 @@ class Unhashable: def to_hashable(obj): # So that we don't infinitely recurse since frozenset and tuples # are Sequences. - if isinstance(obj, (int, float, str, bool, type(None))): + if isinstance(obj, (int, float, str, bool, bytes, type(None))): return obj elif isinstance(obj, Mapping): return frozenset([(to_hashable(k), to_hashable(v)) for k, v in sorted(obj.items())]) From 1f3f7a2823017ad193e060d9221fb6a52f2eba3a Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Mon, 3 Nov 2025 10:21:47 +0200 Subject: [PATCH 051/231] convert nodes_hypernetwork.py to V3 schema (#10583) --- comfy_extras/nodes_hypernetwork.py | 48 ++++++++++++++++++++---------- 1 file changed, 33 insertions(+), 15 deletions(-) diff --git a/comfy_extras/nodes_hypernetwork.py b/comfy_extras/nodes_hypernetwork.py index 665632292..2a6a87a81 100644 --- a/comfy_extras/nodes_hypernetwork.py +++ b/comfy_extras/nodes_hypernetwork.py @@ -2,6 +2,9 @@ import comfy.utils import folder_paths import torch import logging +from comfy_api.latest import IO, ComfyExtension +from typing_extensions import override + def load_hypernetwork_patch(path, strength): sd = comfy.utils.load_torch_file(path, safe_load=True) @@ -94,27 +97,42 @@ def load_hypernetwork_patch(path, strength): return hypernetwork_patch(out, strength) -class HypernetworkLoader: +class HypernetworkLoader(IO.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "model": ("MODEL",), - "hypernetwork_name": (folder_paths.get_filename_list("hypernetworks"), ), - "strength": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}), - }} - RETURN_TYPES = ("MODEL",) - FUNCTION = "load_hypernetwork" + def define_schema(cls): + return IO.Schema( + node_id="HypernetworkLoader", + category="loaders", + inputs=[ + IO.Model.Input("model"), + IO.Combo.Input("hypernetwork_name", options=folder_paths.get_filename_list("hypernetworks")), + IO.Float.Input("strength", default=1.0, min=-10.0, max=10.0, step=0.01), + ], + outputs=[ + IO.Model.Output(), + ], + ) - CATEGORY = "loaders" - - def load_hypernetwork(self, model, hypernetwork_name, strength): + @classmethod + def execute(cls, model, hypernetwork_name, strength) -> IO.NodeOutput: hypernetwork_path = folder_paths.get_full_path_or_raise("hypernetworks", hypernetwork_name) model_hypernetwork = model.clone() patch = load_hypernetwork_patch(hypernetwork_path, strength) if patch is not None: model_hypernetwork.set_model_attn1_patch(patch) model_hypernetwork.set_model_attn2_patch(patch) - return (model_hypernetwork,) + return IO.NodeOutput(model_hypernetwork) -NODE_CLASS_MAPPINGS = { - "HypernetworkLoader": HypernetworkLoader -} + load_hypernetwork = execute # TODO: remove + + +class HyperNetworkExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[IO.ComfyNode]]: + return [ + HypernetworkLoader, + ] + + +async def comfy_entrypoint() -> HyperNetworkExtension: + return HyperNetworkExtension() From e617cddf244e4b789afba4b4ece01661b12fdde5 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Mon, 3 Nov 2025 10:28:13 +0200 Subject: [PATCH 052/231] convert nodes_openai.py to V3 schema (#10604) --- comfy_api_nodes/apinode_utils.py | 73 -- comfy_api_nodes/nodes_openai.py | 1094 +++++++++++---------------- comfy_api_nodes/util/__init__.py | 4 + comfy_api_nodes/util/conversions.py | 19 +- 4 files changed, 482 insertions(+), 708 deletions(-) delete mode 100644 comfy_api_nodes/apinode_utils.py diff --git a/comfy_api_nodes/apinode_utils.py b/comfy_api_nodes/apinode_utils.py deleted file mode 100644 index ecd604ff8..000000000 --- a/comfy_api_nodes/apinode_utils.py +++ /dev/null @@ -1,73 +0,0 @@ -from __future__ import annotations -import aiohttp -import mimetypes -from typing import Union -from server import PromptServer - -import numpy as np -from PIL import Image -import torch -import base64 -from io import BytesIO - - -async def validate_and_cast_response( - response, timeout: int = None, node_id: Union[str, None] = None -) -> torch.Tensor: - """Validates and casts a response to a torch.Tensor. - - Args: - response: The response to validate and cast. - timeout: Request timeout in seconds. Defaults to None (no timeout). - - Returns: - A torch.Tensor representing the image (1, H, W, C). - - Raises: - ValueError: If the response is not valid. - """ - # validate raw JSON response - data = response.data - if not data or len(data) == 0: - raise ValueError("No images returned from API endpoint") - - # Initialize list to store image tensors - image_tensors: list[torch.Tensor] = [] - - # Process each image in the data array - async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=timeout)) as session: - for img_data in data: - img_bytes: bytes - if img_data.b64_json: - img_bytes = base64.b64decode(img_data.b64_json) - elif img_data.url: - if node_id: - PromptServer.instance.send_progress_text(f"Result URL: {img_data.url}", node_id) - async with session.get(img_data.url) as resp: - if resp.status != 200: - raise ValueError("Failed to download generated image") - img_bytes = await resp.read() - else: - raise ValueError("Invalid image payload – neither URL nor base64 data present.") - - pil_img = Image.open(BytesIO(img_bytes)).convert("RGBA") - arr = np.asarray(pil_img).astype(np.float32) / 255.0 - image_tensors.append(torch.from_numpy(arr)) - - return torch.stack(image_tensors, dim=0) - - -def text_filepath_to_base64_string(filepath: str) -> str: - """Converts a text file to a base64 string.""" - with open(filepath, "rb") as f: - file_content = f.read() - return base64.b64encode(file_content).decode("utf-8") - - -def text_filepath_to_data_uri(filepath: str) -> str: - """Converts a text file to a data URI.""" - base64_string = text_filepath_to_base64_string(filepath) - mime_type, _ = mimetypes.guess_type(filepath) - if mime_type is None: - mime_type = "application/octet-stream" - return f"data:{mime_type};base64,{base64_string}" diff --git a/comfy_api_nodes/nodes_openai.py b/comfy_api_nodes/nodes_openai.py index b4568fc85..acf35d276 100644 --- a/comfy_api_nodes/nodes_openai.py +++ b/comfy_api_nodes/nodes_openai.py @@ -1,18 +1,19 @@ -import io -from typing import TypedDict, Optional +from io import BytesIO +from typing import Optional, Union import json import os import time -import re import uuid from enum import Enum from inspect import cleandoc import numpy as np import torch from PIL import Image -from comfy.comfy_types.node_typing import IO, ComfyNodeABC, InputTypeDict from server import PromptServer import folder_paths +import base64 +from comfy_api.latest import IO, ComfyExtension +from typing_extensions import override from comfy_api_nodes.apis import ( @@ -23,7 +24,6 @@ from comfy_api_nodes.apis import ( OpenAIResponse, CreateModelResponseProperties, Item, - Includable, OutputContent, InputImageContent, Detail, @@ -34,41 +34,22 @@ from comfy_api_nodes.apis import ( InputFileContent, ) -from comfy_api_nodes.apis.client import ( +from comfy_api_nodes.util import ( + downscale_image_tensor, + download_url_to_bytesio, + validate_string, + tensor_to_base64_string, ApiEndpoint, - HttpMethod, - SynchronousOperation, - PollingOperation, - EmptyRequest, -) - -from comfy_api_nodes.apinode_utils import ( - validate_and_cast_response, + sync_op, + poll_op, text_filepath_to_data_uri, ) -from comfy_api_nodes.mapper_utils import model_field_to_node_input -from comfy_api_nodes.util import downscale_image_tensor, validate_string, tensor_to_base64_string RESPONSES_ENDPOINT = "/proxy/openai/v1/responses" STARTING_POINT_ID_PATTERN = r"" -class HistoryEntry(TypedDict): - """Type definition for a single history entry in the chat.""" - - prompt: str - response: str - response_id: str - timestamp: float - - -class ChatHistory(TypedDict): - """Type definition for the chat history dictionary.""" - - __annotations__: dict[str, list[HistoryEntry]] - - class SupportedOpenAIModel(str, Enum): o4_mini = "o4-mini" o1 = "o1" @@ -83,98 +64,123 @@ class SupportedOpenAIModel(str, Enum): gpt_5_nano = "gpt-5-nano" -class OpenAIDalle2(ComfyNodeABC): +async def validate_and_cast_response(response, timeout: int = None) -> torch.Tensor: + """Validates and casts a response to a torch.Tensor. + + Args: + response: The response to validate and cast. + timeout: Request timeout in seconds. Defaults to None (no timeout). + + Returns: + A torch.Tensor representing the image (1, H, W, C). + + Raises: + ValueError: If the response is not valid. + """ + # validate raw JSON response + data = response.data + if not data or len(data) == 0: + raise ValueError("No images returned from API endpoint") + + # Initialize list to store image tensors + image_tensors: list[torch.Tensor] = [] + + # Process each image in the data array + for img_data in data: + if img_data.b64_json: + img_io = BytesIO(base64.b64decode(img_data.b64_json)) + elif img_data.url: + img_io = BytesIO() + await download_url_to_bytesio(img_data.url, img_io, timeout=timeout) + else: + raise ValueError("Invalid image payload – neither URL nor base64 data present.") + + pil_img = Image.open(img_io).convert("RGBA") + arr = np.asarray(pil_img).astype(np.float32) / 255.0 + image_tensors.append(torch.from_numpy(arr)) + + return torch.stack(image_tensors, dim=0) + + +class OpenAIDalle2(IO.ComfyNode): """ Generates images synchronously via OpenAI's DALL·E 2 endpoint. """ - def __init__(self): - pass + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="OpenAIDalle2", + display_name="OpenAI DALL·E 2", + category="api node/image/OpenAI", + description=cleandoc(cls.__doc__ or ""), + inputs=[ + IO.String.Input( + "prompt", + default="", + multiline=True, + tooltip="Text prompt for DALL·E", + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2**31 - 1, + step=1, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + tooltip="not implemented yet in backend", + optional=True, + ), + IO.Combo.Input( + "size", + default="1024x1024", + options=["256x256", "512x512", "1024x1024"], + tooltip="Image size", + optional=True, + ), + IO.Int.Input( + "n", + default=1, + min=1, + max=8, + step=1, + tooltip="How many images to generate", + display_mode=IO.NumberDisplay.number, + optional=True, + ), + IO.Image.Input( + "image", + tooltip="Optional reference image for image editing.", + optional=True, + ), + IO.Mask.Input( + "mask", + tooltip="Optional mask for inpainting (white areas will be replaced)", + optional=True, + ), + ], + outputs=[ + IO.Image.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) @classmethod - def INPUT_TYPES(cls) -> InputTypeDict: - return { - "required": { - "prompt": ( - IO.STRING, - { - "multiline": True, - "default": "", - "tooltip": "Text prompt for DALL·E", - }, - ), - }, - "optional": { - "seed": ( - IO.INT, - { - "default": 0, - "min": 0, - "max": 2**31 - 1, - "step": 1, - "display": "number", - "control_after_generate": True, - "tooltip": "not implemented yet in backend", - }, - ), - "size": ( - IO.COMBO, - { - "options": ["256x256", "512x512", "1024x1024"], - "default": "1024x1024", - "tooltip": "Image size", - }, - ), - "n": ( - IO.INT, - { - "default": 1, - "min": 1, - "max": 8, - "step": 1, - "display": "number", - "tooltip": "How many images to generate", - }, - ), - "image": ( - IO.IMAGE, - { - "default": None, - "tooltip": "Optional reference image for image editing.", - }, - ), - "mask": ( - IO.MASK, - { - "default": None, - "tooltip": "Optional mask for inpainting (white areas will be replaced)", - }, - ), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } - - RETURN_TYPES = (IO.IMAGE,) - FUNCTION = "api_call" - CATEGORY = "api node/image/OpenAI" - DESCRIPTION = cleandoc(__doc__ or "") - API_NODE = True - - async def api_call( - self, + async def execute( + cls, prompt, seed=0, image=None, mask=None, n=1, size="1024x1024", - unique_id=None, - **kwargs, - ): + ) -> IO.NodeOutput: validate_string(prompt, strip_whitespace=False) model = "dall-e-2" path = "/proxy/openai/images/generations" @@ -200,7 +206,7 @@ class OpenAIDalle2(ComfyNodeABC): image_np = (rgba_tensor.numpy() * 255).astype(np.uint8) img = Image.fromarray(image_np) - img_byte_arr = io.BytesIO() + img_byte_arr = BytesIO() img.save(img_byte_arr, format="PNG") img_byte_arr.seek(0) img_binary = img_byte_arr # .getvalue() @@ -208,15 +214,11 @@ class OpenAIDalle2(ComfyNodeABC): elif image is not None or mask is not None: raise Exception("Dall-E 2 image editing requires an image AND a mask") - # Build the operation - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path=path, - method=HttpMethod.POST, - request_model=request_class, - response_model=OpenAIImageGenerationResponse, - ), - request=request_class( + response = await sync_op( + cls, + ApiEndpoint(path=path, method="POST"), + response_model=OpenAIImageGenerationResponse, + data=request_class( model=model, prompt=prompt, n=n, @@ -231,109 +233,92 @@ class OpenAIDalle2(ComfyNodeABC): else None ), content_type=content_type, - auth_kwargs=kwargs, ) - response = await operation.execute() - - img_tensor = await validate_and_cast_response(response, node_id=unique_id) - return (img_tensor,) + return IO.NodeOutput(await validate_and_cast_response(response)) -class OpenAIDalle3(ComfyNodeABC): +class OpenAIDalle3(IO.ComfyNode): """ Generates images synchronously via OpenAI's DALL·E 3 endpoint. """ - def __init__(self): - pass + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="OpenAIDalle3", + display_name="OpenAI DALL·E 3", + category="api node/image/OpenAI", + description=cleandoc(cls.__doc__ or ""), + inputs=[ + IO.String.Input( + "prompt", + default="", + multiline=True, + tooltip="Text prompt for DALL·E", + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2 ** 31 - 1, + step=1, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + tooltip="not implemented yet in backend", + optional=True, + ), + IO.Combo.Input( + "quality", + default="standard", + options=["standard", "hd"], + tooltip="Image quality", + optional=True, + ), + IO.Combo.Input( + "style", + default="natural", + options=["natural", "vivid"], + tooltip="Vivid causes the model to lean towards generating hyper-real and dramatic images. Natural causes the model to produce more natural, less hyper-real looking images.", + optional=True, + ), + IO.Combo.Input( + "size", + default="1024x1024", + options=["1024x1024", "1024x1792", "1792x1024"], + tooltip="Image size", + optional=True, + ), + ], + outputs=[ + IO.Image.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) @classmethod - def INPUT_TYPES(cls) -> InputTypeDict: - return { - "required": { - "prompt": ( - IO.STRING, - { - "multiline": True, - "default": "", - "tooltip": "Text prompt for DALL·E", - }, - ), - }, - "optional": { - "seed": ( - IO.INT, - { - "default": 0, - "min": 0, - "max": 2**31 - 1, - "step": 1, - "display": "number", - "control_after_generate": True, - "tooltip": "not implemented yet in backend", - }, - ), - "quality": ( - IO.COMBO, - { - "options": ["standard", "hd"], - "default": "standard", - "tooltip": "Image quality", - }, - ), - "style": ( - IO.COMBO, - { - "options": ["natural", "vivid"], - "default": "natural", - "tooltip": "Vivid causes the model to lean towards generating hyper-real and dramatic images. Natural causes the model to produce more natural, less hyper-real looking images.", - }, - ), - "size": ( - IO.COMBO, - { - "options": ["1024x1024", "1024x1792", "1792x1024"], - "default": "1024x1024", - "tooltip": "Image size", - }, - ), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } - - RETURN_TYPES = (IO.IMAGE,) - FUNCTION = "api_call" - CATEGORY = "api node/image/OpenAI" - DESCRIPTION = cleandoc(__doc__ or "") - API_NODE = True - - async def api_call( - self, + async def execute( + cls, prompt, seed=0, style="natural", quality="standard", size="1024x1024", - unique_id=None, - **kwargs, - ): + ) -> IO.NodeOutput: validate_string(prompt, strip_whitespace=False) model = "dall-e-3" # build the operation - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/openai/images/generations", - method=HttpMethod.POST, - request_model=OpenAIImageGenerationRequest, - response_model=OpenAIImageGenerationResponse, - ), - request=OpenAIImageGenerationRequest( + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/openai/images/generations", method="POST"), + response_model=OpenAIImageGenerationResponse, + data=OpenAIImageGenerationRequest( model=model, prompt=prompt, quality=quality, @@ -341,114 +326,97 @@ class OpenAIDalle3(ComfyNodeABC): style=style, seed=seed, ), - auth_kwargs=kwargs, ) - response = await operation.execute() - - img_tensor = await validate_and_cast_response(response, node_id=unique_id) - return (img_tensor,) + return IO.NodeOutput(await validate_and_cast_response(response)) -class OpenAIGPTImage1(ComfyNodeABC): +class OpenAIGPTImage1(IO.ComfyNode): """ Generates images synchronously via OpenAI's GPT Image 1 endpoint. """ - def __init__(self): - pass + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="OpenAIGPTImage1", + display_name="OpenAI GPT Image 1", + category="api node/image/OpenAI", + description=cleandoc(cls.__doc__ or ""), + inputs=[ + IO.String.Input( + "prompt", + default="", + multiline=True, + tooltip="Text prompt for GPT Image 1", + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2 ** 31 - 1, + step=1, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + tooltip="not implemented yet in backend", + optional=True, + ), + IO.Combo.Input( + "quality", + default="low", + options=["low", "medium", "high"], + tooltip="Image quality, affects cost and generation time.", + optional=True, + ), + IO.Combo.Input( + "background", + default="opaque", + options=["opaque", "transparent"], + tooltip="Return image with or without background", + optional=True, + ), + IO.Combo.Input( + "size", + default="auto", + options=["auto", "1024x1024", "1024x1536", "1536x1024"], + tooltip="Image size", + optional=True, + ), + IO.Int.Input( + "n", + default=1, + min=1, + max=8, + step=1, + tooltip="How many images to generate", + display_mode=IO.NumberDisplay.number, + optional=True, + ), + IO.Image.Input( + "image", + tooltip="Optional reference image for image editing.", + optional=True, + ), + IO.Mask.Input( + "mask", + tooltip="Optional mask for inpainting (white areas will be replaced)", + optional=True, + ), + ], + outputs=[ + IO.Image.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) @classmethod - def INPUT_TYPES(cls) -> InputTypeDict: - return { - "required": { - "prompt": ( - IO.STRING, - { - "multiline": True, - "default": "", - "tooltip": "Text prompt for GPT Image 1", - }, - ), - }, - "optional": { - "seed": ( - IO.INT, - { - "default": 0, - "min": 0, - "max": 2**31 - 1, - "step": 1, - "display": "number", - "control_after_generate": True, - "tooltip": "not implemented yet in backend", - }, - ), - "quality": ( - IO.COMBO, - { - "options": ["low", "medium", "high"], - "default": "low", - "tooltip": "Image quality, affects cost and generation time.", - }, - ), - "background": ( - IO.COMBO, - { - "options": ["opaque", "transparent"], - "default": "opaque", - "tooltip": "Return image with or without background", - }, - ), - "size": ( - IO.COMBO, - { - "options": ["auto", "1024x1024", "1024x1536", "1536x1024"], - "default": "auto", - "tooltip": "Image size", - }, - ), - "n": ( - IO.INT, - { - "default": 1, - "min": 1, - "max": 8, - "step": 1, - "display": "number", - "tooltip": "How many images to generate", - }, - ), - "image": ( - IO.IMAGE, - { - "default": None, - "tooltip": "Optional reference image for image editing.", - }, - ), - "mask": ( - IO.MASK, - { - "default": None, - "tooltip": "Optional mask for inpainting (white areas will be replaced)", - }, - ), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } - - RETURN_TYPES = (IO.IMAGE,) - FUNCTION = "api_call" - CATEGORY = "api node/image/OpenAI" - DESCRIPTION = cleandoc(__doc__ or "") - API_NODE = True - - async def api_call( - self, + async def execute( + cls, prompt, seed=0, quality="low", @@ -457,9 +425,7 @@ class OpenAIGPTImage1(ComfyNodeABC): mask=None, n=1, size="1024x1024", - unique_id=None, - **kwargs, - ): + ) -> IO.NodeOutput: validate_string(prompt, strip_whitespace=False) model = "gpt-image-1" path = "/proxy/openai/images/generations" @@ -480,7 +446,7 @@ class OpenAIGPTImage1(ComfyNodeABC): image_np = (scaled_image.numpy() * 255).astype(np.uint8) img = Image.fromarray(image_np) - img_byte_arr = io.BytesIO() + img_byte_arr = BytesIO() img.save(img_byte_arr, format="PNG") img_byte_arr.seek(0) @@ -504,20 +470,17 @@ class OpenAIGPTImage1(ComfyNodeABC): mask_np = (scaled_mask.numpy() * 255).astype(np.uint8) mask_img = Image.fromarray(mask_np) - mask_img_byte_arr = io.BytesIO() + mask_img_byte_arr = BytesIO() mask_img.save(mask_img_byte_arr, format="PNG") mask_img_byte_arr.seek(0) files.append(("mask", ("mask.png", mask_img_byte_arr, "image/png"))) # Build the operation - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path=path, - method=HttpMethod.POST, - request_model=request_class, - response_model=OpenAIImageGenerationResponse, - ), - request=request_class( + response = await sync_op( + cls, + ApiEndpoint(path=path, method="POST"), + response_model=OpenAIImageGenerationResponse, + data=request_class( model=model, prompt=prompt, quality=quality, @@ -528,127 +491,70 @@ class OpenAIGPTImage1(ComfyNodeABC): ), files=files if files else None, content_type=content_type, - auth_kwargs=kwargs, ) - response = await operation.execute() - - img_tensor = await validate_and_cast_response(response, node_id=unique_id) - return (img_tensor,) + return IO.NodeOutput(await validate_and_cast_response(response)) -class OpenAITextNode(ComfyNodeABC): - """ - Base class for OpenAI text generation nodes. - """ - - RETURN_TYPES = (IO.STRING,) - FUNCTION = "api_call" - CATEGORY = "api node/text/OpenAI" - API_NODE = True - - -class OpenAIChatNode(OpenAITextNode): +class OpenAIChatNode(IO.ComfyNode): """ Node to generate text responses from an OpenAI model. """ - def __init__(self) -> None: - """Initialize the chat node with a new session ID and empty history.""" - self.current_session_id: str = str(uuid.uuid4()) - self.history: dict[str, list[HistoryEntry]] = {} - self.previous_response_id: Optional[str] = None + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="OpenAIChatNode", + display_name="OpenAI ChatGPT", + category="api node/text/OpenAI", + description="Generate text responses from an OpenAI model.", + inputs=[ + IO.String.Input( + "prompt", + default="", + multiline=True, + tooltip="Text inputs to the model, used to generate a response.", + ), + IO.Boolean.Input( + "persist_context", + default=False, + tooltip="This parameter is deprecated and has no effect.", + ), + IO.Combo.Input( + "model", + options=SupportedOpenAIModel, + tooltip="The model used to generate the response", + ), + IO.Image.Input( + "images", + tooltip="Optional image(s) to use as context for the model. To include multiple images, you can use the Batch Images node.", + optional=True, + ), + IO.Custom("OPENAI_INPUT_FILES").Input( + "files", + optional=True, + tooltip="Optional file(s) to use as context for the model. Accepts inputs from the OpenAI Chat Input Files node.", + ), + IO.Custom("OPENAI_CHAT_CONFIG").Input( + "advanced_options", + optional=True, + tooltip="Optional configuration for the model. Accepts inputs from the OpenAI Chat Advanced Options node.", + ), + ], + outputs=[ + IO.String.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) @classmethod - def INPUT_TYPES(cls) -> InputTypeDict: - return { - "required": { - "prompt": ( - IO.STRING, - { - "multiline": True, - "default": "", - "tooltip": "Text inputs to the model, used to generate a response.", - }, - ), - "persist_context": ( - IO.BOOLEAN, - { - "default": True, - "tooltip": "Persist chat context between calls (multi-turn conversation)", - }, - ), - "model": model_field_to_node_input( - IO.COMBO, - OpenAICreateResponse, - "model", - enum_type=SupportedOpenAIModel, - ), - }, - "optional": { - "images": ( - IO.IMAGE, - { - "default": None, - "tooltip": "Optional image(s) to use as context for the model. To include multiple images, you can use the Batch Images node.", - }, - ), - "files": ( - "OPENAI_INPUT_FILES", - { - "default": None, - "tooltip": "Optional file(s) to use as context for the model. Accepts inputs from the OpenAI Chat Input Files node.", - }, - ), - "advanced_options": ( - "OPENAI_CHAT_CONFIG", - { - "default": None, - "tooltip": "Optional configuration for the model. Accepts inputs from the OpenAI Chat Advanced Options node.", - }, - ), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } - - DESCRIPTION = "Generate text responses from an OpenAI model." - - async def get_result_response( - self, - response_id: str, - include: Optional[list[Includable]] = None, - auth_kwargs: Optional[dict[str, str]] = None, - ) -> OpenAIResponse: - """ - Retrieve a model response with the given ID from the OpenAI API. - - Args: - response_id (str): The ID of the response to retrieve. - include (Optional[List[Includable]]): Additional fields to include - in the response. See the `include` parameter for Response - creation above for more information. - - """ - return await PollingOperation( - poll_endpoint=ApiEndpoint( - path=f"{RESPONSES_ENDPOINT}/{response_id}", - method=HttpMethod.GET, - request_model=EmptyRequest, - response_model=OpenAIResponse, - query_params={"include": include}, - ), - completed_statuses=["completed"], - failed_statuses=["failed"], - status_extractor=lambda response: response.status, - auth_kwargs=auth_kwargs, - ).execute() - def get_message_content_from_response( - self, response: OpenAIResponse + cls, response: OpenAIResponse ) -> list[OutputContent]: """Extract message content from the API response.""" for output in response.output: @@ -656,8 +562,9 @@ class OpenAIChatNode(OpenAITextNode): return output.root.content raise TypeError("No output message found in response") + @classmethod def get_text_from_message_content( - self, message_content: list[OutputContent] + cls, message_content: list[OutputContent] ) -> str: """Extract text content from message content.""" for content_item in message_content: @@ -665,58 +572,9 @@ class OpenAIChatNode(OpenAITextNode): return str(content_item.root.text) return "No text output found in response" - def get_history_text(self, session_id: str) -> str: - """Convert the entire history for a given session to JSON string.""" - return json.dumps(self.history[session_id]) - - def display_history_on_node(self, session_id: str, node_id: str) -> None: - """Display formatted chat history on the node UI.""" - render_spec = { - "node_id": node_id, - "component": "ChatHistoryWidget", - "props": { - "history": self.get_history_text(session_id), - }, - } - PromptServer.instance.send_sync( - "display_component", - render_spec, - ) - - def add_to_history( - self, session_id: str, prompt: str, output_text: str, response_id: str - ) -> None: - """Add a new entry to the chat history.""" - if session_id not in self.history: - self.history[session_id] = [] - self.history[session_id].append( - { - "prompt": prompt, - "response": output_text, - "response_id": response_id, - "timestamp": time.time(), - } - ) - - def parse_output_text_from_response(self, response: OpenAIResponse) -> str: - """Extract text output from the API response.""" - message_contents = self.get_message_content_from_response(response) - return self.get_text_from_message_content(message_contents) - - def generate_new_session_id(self) -> str: - """Generate a new unique session ID.""" - return str(uuid.uuid4()) - - def get_session_id(self, persist_context: bool) -> str: - """Get the current or generate a new session ID based on context persistence.""" - return ( - self.current_session_id - if persist_context - else self.generate_new_session_id() - ) - + @classmethod def tensor_to_input_image_content( - self, image: torch.Tensor, detail_level: Detail = "auto" + cls, image: torch.Tensor, detail_level: Detail = "auto" ) -> InputImageContent: """Convert a tensor to an input image content object.""" return InputImageContent( @@ -725,21 +583,27 @@ class OpenAIChatNode(OpenAITextNode): type="input_image", ) + @classmethod def create_input_message_contents( - self, + cls, prompt: str, image: Optional[torch.Tensor] = None, files: Optional[list[InputFileContent]] = None, ) -> InputMessageContentList: """Create a list of input message contents from prompt and optional image.""" - content_list: list[InputContent] = [ + content_list: list[Union[InputContent, InputTextContent, InputImageContent, InputFileContent]] = [ InputTextContent(text=prompt, type="input_text"), ] if image is not None: for i in range(image.shape[0]): content_list.append( - self.tensor_to_input_image_content(image[i].unsqueeze(0)) + InputImageContent( + detail="auto", + image_url=f"data:image/png;base64,{tensor_to_base64_string(image[i].unsqueeze(0))}", + type="input_image", + ) ) + if files is not None: content_list.extend(files) @@ -747,80 +611,28 @@ class OpenAIChatNode(OpenAITextNode): root=content_list, ) - def parse_response_id_from_prompt(self, prompt: str) -> Optional[str]: - """Extract response ID from prompt if it exists.""" - parsed_id = re.search(STARTING_POINT_ID_PATTERN, prompt) - return parsed_id.group(1) if parsed_id else None - - def strip_response_tag_from_prompt(self, prompt: str) -> str: - """Remove the response ID tag from the prompt.""" - return re.sub(STARTING_POINT_ID_PATTERN, "", prompt.strip()) - - def delete_history_after_response_id( - self, new_start_id: str, session_id: str - ) -> None: - """Delete history entries after a specific response ID.""" - if session_id not in self.history: - return - - new_history = [] - i = 0 - while ( - i < len(self.history[session_id]) - and self.history[session_id][i]["response_id"] != new_start_id - ): - new_history.append(self.history[session_id][i]) - i += 1 - - # Since it's the new starting point (not the response being edited), we include it as well - if i < len(self.history[session_id]): - new_history.append(self.history[session_id][i]) - - self.history[session_id] = new_history - - async def api_call( - self, + @classmethod + async def execute( + cls, prompt: str, - persist_context: bool, - model: SupportedOpenAIModel, - unique_id: Optional[str] = None, + persist_context: bool = False, + model: SupportedOpenAIModel = SupportedOpenAIModel.gpt_5.value, images: Optional[torch.Tensor] = None, files: Optional[list[InputFileContent]] = None, advanced_options: Optional[CreateModelResponseProperties] = None, - **kwargs, - ) -> tuple[str]: - # Validate inputs + ) -> IO.NodeOutput: validate_string(prompt, strip_whitespace=False) - session_id = self.get_session_id(persist_context) - response_id_override = self.parse_response_id_from_prompt(prompt) - if response_id_override: - is_starting_from_beginning = response_id_override == "start" - if is_starting_from_beginning: - self.history[session_id] = [] - previous_response_id = None - else: - previous_response_id = response_id_override - self.delete_history_after_response_id(response_id_override, session_id) - prompt = self.strip_response_tag_from_prompt(prompt) - elif persist_context: - previous_response_id = self.previous_response_id - else: - previous_response_id = None - # Create response - create_response = await SynchronousOperation( - endpoint=ApiEndpoint( - path=RESPONSES_ENDPOINT, - method=HttpMethod.POST, - request_model=OpenAICreateResponse, - response_model=OpenAIResponse, - ), - request=OpenAICreateResponse( + create_response = await sync_op( + cls, + ApiEndpoint(path=RESPONSES_ENDPOINT, method="POST"), + response_model=OpenAIResponse, + data=OpenAICreateResponse( input=[ Item( root=InputMessage( - content=self.create_input_message_contents( + content=cls.create_input_message_contents( prompt, images, files ), role="user", @@ -830,36 +642,57 @@ class OpenAIChatNode(OpenAITextNode): store=True, stream=False, model=model, - previous_response_id=previous_response_id, + previous_response_id=None, **( advanced_options.model_dump(exclude_none=True) if advanced_options else {} ), ), - auth_kwargs=kwargs, - ).execute() + ) response_id = create_response.id # Get result output - result_response = await self.get_result_response(response_id, auth_kwargs=kwargs) - output_text = self.parse_output_text_from_response(result_response) + result_response = await poll_op( + cls, + ApiEndpoint(path=f"{RESPONSES_ENDPOINT}/{response_id}"), + response_model=OpenAIResponse, + status_extractor=lambda response: response.status, + completed_statuses=["incomplete", "completed"] + ) + output_text = cls.get_text_from_message_content(cls.get_message_content_from_response(result_response)) # Update history - self.add_to_history(session_id, prompt, output_text, response_id) - self.display_history_on_node(session_id, unique_id) - self.previous_response_id = response_id - - return (output_text,) + render_spec = { + "node_id": cls.hidden.unique_id, + "component": "ChatHistoryWidget", + "props": { + "history": json.dumps( + [ + { + "prompt": prompt, + "response": output_text, + "response_id": str(uuid.uuid4()), + "timestamp": time.time(), + } + ] + ), + }, + } + PromptServer.instance.send_sync( + "display_component", + render_spec, + ) + return IO.NodeOutput(output_text) -class OpenAIInputFiles(ComfyNodeABC): +class OpenAIInputFiles(IO.ComfyNode): """ Loads and formats input files for OpenAI API. """ @classmethod - def INPUT_TYPES(cls) -> InputTypeDict: + def define_schema(cls): """ For details about the supported file input types, see: https://platform.openai.com/docs/guides/pdf-files?api-mode=responses @@ -874,97 +707,92 @@ class OpenAIInputFiles(ComfyNodeABC): ] input_files = sorted(input_files, key=lambda x: x.name) input_files = [f.name for f in input_files] - return { - "required": { - "file": ( - IO.COMBO, - { - "tooltip": "Input files to include as context for the model. Only accepts text (.txt) and PDF (.pdf) files for now.", - "options": input_files, - "default": input_files[0] if input_files else None, - }, + return IO.Schema( + node_id="OpenAIInputFiles", + display_name="OpenAI ChatGPT Input Files", + category="api node/text/OpenAI", + description="Loads and prepares input files (text, pdf, etc.) to include as inputs for the OpenAI Chat Node. The files will be read by the OpenAI model when generating a response. 🛈 TIP: Can be chained together with other OpenAI Input File nodes.", + inputs=[ + IO.Combo.Input( + "file", + options=input_files, + default=input_files[0] if input_files else None, + tooltip="Input files to include as context for the model. Only accepts text (.txt) and PDF (.pdf) files for now.", ), - }, - "optional": { - "OPENAI_INPUT_FILES": ( + IO.Custom("OPENAI_INPUT_FILES").Input( "OPENAI_INPUT_FILES", - { - "tooltip": "An optional additional file(s) to batch together with the file loaded from this node. Allows chaining of input files so that a single message can include multiple input files.", - "default": None, - }, + tooltip="An optional additional file(s) to batch together with the file loaded from this node. Allows chaining of input files so that a single message can include multiple input files.", + optional=True, ), - }, - } + ], + outputs=[ + IO.Custom("OPENAI_INPUT_FILES").Output(), + ], + ) - DESCRIPTION = "Loads and prepares input files (text, pdf, etc.) to include as inputs for the OpenAI Chat Node. The files will be read by the OpenAI model when generating a response. 🛈 TIP: Can be chained together with other OpenAI Input File nodes." - RETURN_TYPES = ("OPENAI_INPUT_FILES",) - FUNCTION = "prepare_files" - CATEGORY = "api node/text/OpenAI" - - def create_input_file_content(self, file_path: str) -> InputFileContent: + @classmethod + def create_input_file_content(cls, file_path: str) -> InputFileContent: return InputFileContent( file_data=text_filepath_to_data_uri(file_path), filename=os.path.basename(file_path), type="input_file", ) - def prepare_files( - self, file: str, OPENAI_INPUT_FILES: list[InputFileContent] = [] - ) -> tuple[list[InputFileContent]]: + @classmethod + def execute(cls, file: str, OPENAI_INPUT_FILES: list[InputFileContent] = []) -> IO.NodeOutput: """ Loads and formats input files for OpenAI API. """ file_path = folder_paths.get_annotated_filepath(file) - input_file_content = self.create_input_file_content(file_path) + input_file_content = cls.create_input_file_content(file_path) files = [input_file_content] + OPENAI_INPUT_FILES - return (files,) + return IO.NodeOutput(files) -class OpenAIChatConfig(ComfyNodeABC): +class OpenAIChatConfig(IO.ComfyNode): """Allows setting additional configuration for the OpenAI Chat Node.""" - RETURN_TYPES = ("OPENAI_CHAT_CONFIG",) - FUNCTION = "configure" - DESCRIPTION = ( - "Allows specifying advanced configuration options for the OpenAI Chat Nodes." - ) - CATEGORY = "api node/text/OpenAI" - @classmethod - def INPUT_TYPES(cls) -> InputTypeDict: - return { - "required": { - "truncation": ( - IO.COMBO, - { - "options": ["auto", "disabled"], - "default": "auto", - "tooltip": "The truncation strategy to use for the model response. auto: If the context of this response and previous ones exceeds the model's context window size, the model will truncate the response to fit the context window by dropping input items in the middle of the conversation.disabled: If a model response will exceed the context window size for a model, the request will fail with a 400 error", - }, + def define_schema(cls): + return IO.Schema( + node_id="OpenAIChatConfig", + display_name="OpenAI ChatGPT Advanced Options", + category="api node/text/OpenAI", + description="Allows specifying advanced configuration options for the OpenAI Chat Nodes.", + inputs=[ + IO.Combo.Input( + "truncation", + options=["auto", "disabled"], + default="auto", + tooltip="The truncation strategy to use for the model response. auto: If the context of this response and previous ones exceeds the model's context window size, the model will truncate the response to fit the context window by dropping input items in the middle of the conversation.disabled: If a model response will exceed the context window size for a model, the request will fail with a 400 error", ), - }, - "optional": { - "max_output_tokens": model_field_to_node_input( - IO.INT, - OpenAICreateResponse, + IO.Int.Input( "max_output_tokens", min=16, default=4096, max=16384, tooltip="An upper bound for the number of tokens that can be generated for a response, including visible output tokens", + optional=True, ), - "instructions": model_field_to_node_input( - IO.STRING, OpenAICreateResponse, "instructions", multiline=True + IO.String.Input( + "instructions", + multiline=True, + optional=True, + tooltip="Instructions for the model on how to generate the response", ), - }, - } + ], + outputs=[ + IO.Custom("OPENAI_CHAT_CONFIG").Output(), + ], + ) - def configure( - self, + @classmethod + def execute( + cls, truncation: bool, instructions: Optional[str] = None, max_output_tokens: Optional[int] = None, - ) -> tuple[CreateModelResponseProperties]: + ) -> IO.NodeOutput: """ Configure advanced options for the OpenAI Chat Node. @@ -974,29 +802,27 @@ class OpenAIChatConfig(ComfyNodeABC): They are not exposed as inputs at all to avoid having to manually remove depending on model choice. """ - return ( + return IO.NodeOutput( CreateModelResponseProperties( instructions=instructions, truncation=truncation, max_output_tokens=max_output_tokens, - ), + ) ) -NODE_CLASS_MAPPINGS = { - "OpenAIDalle2": OpenAIDalle2, - "OpenAIDalle3": OpenAIDalle3, - "OpenAIGPTImage1": OpenAIGPTImage1, - "OpenAIChatNode": OpenAIChatNode, - "OpenAIInputFiles": OpenAIInputFiles, - "OpenAIChatConfig": OpenAIChatConfig, -} +class OpenAIExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[IO.ComfyNode]]: + return [ + OpenAIDalle2, + OpenAIDalle3, + OpenAIGPTImage1, + OpenAIChatNode, + OpenAIInputFiles, + OpenAIChatConfig, + ] -NODE_DISPLAY_NAME_MAPPINGS = { - "OpenAIDalle2": "OpenAI DALL·E 2", - "OpenAIDalle3": "OpenAI DALL·E 3", - "OpenAIGPTImage1": "OpenAI GPT Image 1", - "OpenAIChatNode": "OpenAI ChatGPT", - "OpenAIInputFiles": "OpenAI ChatGPT Input Files", - "OpenAIChatConfig": "OpenAI ChatGPT Advanced Options", -} + +async def comfy_entrypoint() -> OpenAIExtension: + return OpenAIExtension() diff --git a/comfy_api_nodes/util/__init__.py b/comfy_api_nodes/util/__init__.py index bbc71363a..21013b591 100644 --- a/comfy_api_nodes/util/__init__.py +++ b/comfy_api_nodes/util/__init__.py @@ -18,6 +18,8 @@ from .conversions import ( tensor_to_base64_string, tensor_to_bytesio, tensor_to_pil, + text_filepath_to_base64_string, + text_filepath_to_data_uri, trim_video, video_to_base64_string, ) @@ -75,6 +77,8 @@ __all__ = [ "tensor_to_base64_string", "tensor_to_bytesio", "tensor_to_pil", + "text_filepath_to_base64_string", + "text_filepath_to_data_uri", "trim_video", "video_to_base64_string", # Validation utilities diff --git a/comfy_api_nodes/util/conversions.py b/comfy_api_nodes/util/conversions.py index b59c2bd84..971dc57de 100644 --- a/comfy_api_nodes/util/conversions.py +++ b/comfy_api_nodes/util/conversions.py @@ -1,6 +1,7 @@ import base64 import logging import math +import mimetypes import uuid from io import BytesIO from typing import Optional @@ -12,7 +13,7 @@ from PIL import Image from comfy.utils import common_upscale from comfy_api.latest import Input, InputImpl -from comfy_api.util import VideoContainer, VideoCodec +from comfy_api.util import VideoCodec, VideoContainer from ._helpers import mimetype_to_extension @@ -451,3 +452,19 @@ def resize_mask_to_image( if not allow_gradient: mask = (mask > 0.5).float() return mask + + +def text_filepath_to_base64_string(filepath: str) -> str: + """Converts a text file to a base64 string.""" + with open(filepath, "rb") as f: + file_content = f.read() + return base64.b64encode(file_content).decode("utf-8") + + +def text_filepath_to_data_uri(filepath: str) -> str: + """Converts a text file to a data URI.""" + base64_string = text_filepath_to_base64_string(filepath) + mime_type, _ = mimetypes.guess_type(filepath) + if mime_type is None: + mime_type = "application/octet-stream" + return f"data:{mime_type};base64,{base64_string}" From 4e2110c794b187c9326d604e7c0b0a4fad81148a Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Mon, 3 Nov 2025 10:29:08 +0200 Subject: [PATCH 053/231] feat(Pika-API-nodes): use new API client (#10608) --- .../apis/{pika_defs.py => pika_api.py} | 0 comfy_api_nodes/nodes_pika.py | 181 ++++++------------ 2 files changed, 57 insertions(+), 124 deletions(-) rename comfy_api_nodes/apis/{pika_defs.py => pika_api.py} (100%) diff --git a/comfy_api_nodes/apis/pika_defs.py b/comfy_api_nodes/apis/pika_api.py similarity index 100% rename from comfy_api_nodes/apis/pika_defs.py rename to comfy_api_nodes/apis/pika_api.py diff --git a/comfy_api_nodes/nodes_pika.py b/comfy_api_nodes/nodes_pika.py index 5bb406a3b..51148211b 100644 --- a/comfy_api_nodes/nodes_pika.py +++ b/comfy_api_nodes/nodes_pika.py @@ -7,24 +7,23 @@ from __future__ import annotations from io import BytesIO import logging -from typing import Optional, TypeVar +from typing import Optional import torch from typing_extensions import override from comfy_api.latest import ComfyExtension, IO from comfy_api.input_impl.video_types import VideoCodec, VideoContainer, VideoInput -from comfy_api_nodes.apis import pika_defs -from comfy_api_nodes.apis.client import ( +from comfy_api_nodes.apis import pika_api as pika_defs +from comfy_api_nodes.util import ( + validate_string, + download_url_to_video_output, + tensor_to_bytesio, ApiEndpoint, - EmptyRequest, - HttpMethod, - PollingOperation, - SynchronousOperation, + sync_op, + poll_op, ) -from comfy_api_nodes.util import validate_string, download_url_to_video_output, tensor_to_bytesio -R = TypeVar("R") PATH_PIKADDITIONS = "/proxy/pika/generate/pikadditions" PATH_PIKASWAPS = "/proxy/pika/generate/pikaswaps" @@ -40,28 +39,18 @@ PATH_VIDEO_GET = "/proxy/pika/videos" async def execute_task( - initial_operation: SynchronousOperation[R, pika_defs.PikaGenerateResponse], - auth_kwargs: Optional[dict[str, str]] = None, - node_id: Optional[str] = None, + task_id: str, + cls: type[IO.ComfyNode], ) -> IO.NodeOutput: - task_id = (await initial_operation.execute()).video_id - final_response: pika_defs.PikaVideoResponse = await PollingOperation( - poll_endpoint=ApiEndpoint( - path=f"{PATH_VIDEO_GET}/{task_id}", - method=HttpMethod.GET, - request_model=EmptyRequest, - response_model=pika_defs.PikaVideoResponse, - ), - completed_statuses=["finished"], - failed_statuses=["failed", "cancelled"], + final_response: pika_defs.PikaVideoResponse = await poll_op( + cls, + ApiEndpoint(path=f"{PATH_VIDEO_GET}/{task_id}"), + response_model=pika_defs.PikaVideoResponse, status_extractor=lambda response: (response.status.value if response.status else None), progress_extractor=lambda response: (response.progress if hasattr(response, "progress") else None), - auth_kwargs=auth_kwargs, - result_url_extractor=lambda response: (response.url if hasattr(response, "url") else None), - node_id=node_id, estimated_duration=60, max_poll_attempts=240, - ).execute() + ) if not final_response.url: error_msg = f"Pika task {task_id} succeeded but no video data found in response:\n{final_response}" logging.error(error_msg) @@ -124,23 +113,15 @@ class PikaImageToVideo(IO.ComfyNode): resolution=resolution, duration=duration, ) - auth = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } - initial_operation = SynchronousOperation( - endpoint=ApiEndpoint( - path=PATH_IMAGE_TO_VIDEO, - method=HttpMethod.POST, - request_model=pika_defs.PikaBodyGenerate22I2vGenerate22I2vPost, - response_model=pika_defs.PikaGenerateResponse, - ), - request=pika_request_data, + initial_operation = await sync_op( + cls, + ApiEndpoint(path=PATH_IMAGE_TO_VIDEO, method="POST"), + response_model=pika_defs.PikaGenerateResponse, + data=pika_request_data, files=pika_files, content_type="multipart/form-data", - auth_kwargs=auth, ) - return await execute_task(initial_operation, auth_kwargs=auth, node_id=cls.hidden.unique_id) + return await execute_task(initial_operation.video_id, cls) class PikaTextToVideoNode(IO.ComfyNode): @@ -183,18 +164,11 @@ class PikaTextToVideoNode(IO.ComfyNode): duration: int, aspect_ratio: float, ) -> IO.NodeOutput: - auth = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } - initial_operation = SynchronousOperation( - endpoint=ApiEndpoint( - path=PATH_TEXT_TO_VIDEO, - method=HttpMethod.POST, - request_model=pika_defs.PikaBodyGenerate22T2vGenerate22T2vPost, - response_model=pika_defs.PikaGenerateResponse, - ), - request=pika_defs.PikaBodyGenerate22T2vGenerate22T2vPost( + initial_operation = await sync_op( + cls, + ApiEndpoint(path=PATH_TEXT_TO_VIDEO, method="POST"), + response_model=pika_defs.PikaGenerateResponse, + data=pika_defs.PikaBodyGenerate22T2vGenerate22T2vPost( promptText=prompt_text, negativePrompt=negative_prompt, seed=seed, @@ -202,10 +176,9 @@ class PikaTextToVideoNode(IO.ComfyNode): duration=duration, aspectRatio=aspect_ratio, ), - auth_kwargs=auth, content_type="application/x-www-form-urlencoded", ) - return await execute_task(initial_operation, auth_kwargs=auth, node_id=cls.hidden.unique_id) + return await execute_task(initial_operation.video_id, cls) class PikaScenes(IO.ComfyNode): @@ -309,24 +282,16 @@ class PikaScenes(IO.ComfyNode): duration=duration, aspectRatio=aspect_ratio, ) - auth = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } - initial_operation = SynchronousOperation( - endpoint=ApiEndpoint( - path=PATH_PIKASCENES, - method=HttpMethod.POST, - request_model=pika_defs.PikaBodyGenerate22C2vGenerate22PikascenesPost, - response_model=pika_defs.PikaGenerateResponse, - ), - request=pika_request_data, + initial_operation = await sync_op( + cls, + ApiEndpoint(path=PATH_PIKASCENES, method="POST"), + response_model=pika_defs.PikaGenerateResponse, + data=pika_request_data, files=pika_files, content_type="multipart/form-data", - auth_kwargs=auth, ) - return await execute_task(initial_operation, auth_kwargs=auth, node_id=cls.hidden.unique_id) + return await execute_task(initial_operation.video_id, cls) class PikAdditionsNode(IO.ComfyNode): @@ -383,24 +348,16 @@ class PikAdditionsNode(IO.ComfyNode): negativePrompt=negative_prompt, seed=seed, ) - auth = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } - initial_operation = SynchronousOperation( - endpoint=ApiEndpoint( - path=PATH_PIKADDITIONS, - method=HttpMethod.POST, - request_model=pika_defs.PikaBodyGeneratePikadditionsGeneratePikadditionsPost, - response_model=pika_defs.PikaGenerateResponse, - ), - request=pika_request_data, + initial_operation = await sync_op( + cls, + ApiEndpoint(path=PATH_PIKADDITIONS, method="POST"), + response_model=pika_defs.PikaGenerateResponse, + data=pika_request_data, files=pika_files, content_type="multipart/form-data", - auth_kwargs=auth, ) - return await execute_task(initial_operation, auth_kwargs=auth, node_id=cls.hidden.unique_id) + return await execute_task(initial_operation.video_id, cls) class PikaSwapsNode(IO.ComfyNode): @@ -472,23 +429,15 @@ class PikaSwapsNode(IO.ComfyNode): seed=seed, modifyRegionRoi=region_to_modify if region_to_modify else None, ) - auth = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } - initial_operation = SynchronousOperation( - endpoint=ApiEndpoint( - path=PATH_PIKASWAPS, - method=HttpMethod.POST, - request_model=pika_defs.PikaBodyGeneratePikaswapsGeneratePikaswapsPost, - response_model=pika_defs.PikaGenerateResponse, - ), - request=pika_request_data, + initial_operation = await sync_op( + cls, + ApiEndpoint(path=PATH_PIKASWAPS, method="POST"), + response_model=pika_defs.PikaGenerateResponse, + data=pika_request_data, files=pika_files, content_type="multipart/form-data", - auth_kwargs=auth, ) - return await execute_task(initial_operation, auth_kwargs=auth, node_id=cls.hidden.unique_id) + return await execute_task(initial_operation.video_id, cls) class PikaffectsNode(IO.ComfyNode): @@ -528,18 +477,11 @@ class PikaffectsNode(IO.ComfyNode): negative_prompt: str, seed: int, ) -> IO.NodeOutput: - auth = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } - initial_operation = SynchronousOperation( - endpoint=ApiEndpoint( - path=PATH_PIKAFFECTS, - method=HttpMethod.POST, - request_model=pika_defs.PikaBodyGeneratePikaffectsGeneratePikaffectsPost, - response_model=pika_defs.PikaGenerateResponse, - ), - request=pika_defs.PikaBodyGeneratePikaffectsGeneratePikaffectsPost( + initial_operation = await sync_op( + cls, + ApiEndpoint(path=PATH_PIKAFFECTS, method="POST"), + response_model=pika_defs.PikaGenerateResponse, + data=pika_defs.PikaBodyGeneratePikaffectsGeneratePikaffectsPost( pikaffect=pikaffect, promptText=prompt_text, negativePrompt=negative_prompt, @@ -547,9 +489,8 @@ class PikaffectsNode(IO.ComfyNode): ), files={"image": ("image.png", tensor_to_bytesio(image), "image/png")}, content_type="multipart/form-data", - auth_kwargs=auth, ) - return await execute_task(initial_operation, auth_kwargs=auth, node_id=cls.hidden.unique_id) + return await execute_task(initial_operation.video_id, cls) class PikaStartEndFrameNode(IO.ComfyNode): @@ -592,18 +533,11 @@ class PikaStartEndFrameNode(IO.ComfyNode): ("keyFrames", ("image_start.png", tensor_to_bytesio(image_start), "image/png")), ("keyFrames", ("image_end.png", tensor_to_bytesio(image_end), "image/png")), ] - auth = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } - initial_operation = SynchronousOperation( - endpoint=ApiEndpoint( - path=PATH_PIKAFRAMES, - method=HttpMethod.POST, - request_model=pika_defs.PikaBodyGenerate22KeyframeGenerate22PikaframesPost, - response_model=pika_defs.PikaGenerateResponse, - ), - request=pika_defs.PikaBodyGenerate22KeyframeGenerate22PikaframesPost( + initial_operation = await sync_op( + cls, + ApiEndpoint(path=PATH_PIKAFRAMES, method="POST"), + response_model=pika_defs.PikaGenerateResponse, + data=pika_defs.PikaBodyGenerate22KeyframeGenerate22PikaframesPost( promptText=prompt_text, negativePrompt=negative_prompt, seed=seed, @@ -612,9 +546,8 @@ class PikaStartEndFrameNode(IO.ComfyNode): ), files=pika_files, content_type="multipart/form-data", - auth_kwargs=auth, ) - return await execute_task(initial_operation, auth_kwargs=auth, node_id=cls.hidden.unique_id) + return await execute_task(initial_operation.video_id, cls) class PikaApiNodesExtension(ComfyExtension): From e974e554ca23be505b72bc9c1614f4285c1db6e3 Mon Sep 17 00:00:00 2001 From: ComfyUI Wiki Date: Tue, 4 Nov 2025 02:59:44 +0800 Subject: [PATCH 054/231] chore: update embedded docs to v0.3.1 (#10614) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 4d84b0d3e..856e373de 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ comfyui-frontend-package==1.28.8 comfyui-workflow-templates==0.2.4 -comfyui-embedded-docs==0.3.0 +comfyui-embedded-docs==0.3.1 torch torchsde torchvision From 958a17199ac519504e390ea0d53295ceb8cbd2c1 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 3 Nov 2025 14:08:30 -0800 Subject: [PATCH 055/231] People should update their pytorch versions. (#10618) --- comfy/quant_ops.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py index 873f173ed..5af6f118e 100644 --- a/comfy/quant_ops.py +++ b/comfy/quant_ops.py @@ -418,6 +418,10 @@ def fp8_linear(func, args, kwargs): scale_b=scale_b, out_dtype=out_dtype, ) + + if isinstance(output, tuple): # TODO: remove when we drop support for torch 2.4 + output = output[0] + if not tensor_2d: output = output.reshape((-1, input_shape[1], weight.shape[0])) From 0652cb8e2d343f68e38285755835c77bda7f6389 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 3 Nov 2025 14:37:12 -0800 Subject: [PATCH 056/231] Speed up torch.compile (#10620) --- comfy/ops.py | 1 - 1 file changed, 1 deletion(-) diff --git a/comfy/ops.py b/comfy/ops.py index 0c8f23848..afe498caa 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -71,7 +71,6 @@ def cast_to_input(weight, input, non_blocking=False, copy=True): return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy) -@torch.compiler.disable() def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, offloadable=False): # NOTE: offloadable=False is a a legacy and if you are a custom node author reading this please pass # offloadable=True and call uncast_bias_weight() after your last usage of the weight/bias. This From e199c8cc6758d388792fd66b99e8de832814ff91 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 3 Nov 2025 14:58:24 -0800 Subject: [PATCH 057/231] Fixes (#10621) --- comfy/quant_ops.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py index 5af6f118e..835fc4b8d 100644 --- a/comfy/quant_ops.py +++ b/comfy/quant_ops.py @@ -131,7 +131,7 @@ class QuantizedTensor(torch.Tensor): self._layout_params = layout_params def __repr__(self): - layout_name = self._layout_type.__name__ + layout_name = self._layout_type param_str = ", ".join(f"{k}={v}" for k, v in list(self._layout_params.items())[:2]) return f"QuantizedTensor(shape={self.shape}, layout={layout_name}, {param_str})" @@ -179,7 +179,7 @@ class QuantizedTensor(torch.Tensor): attr_name = f"_layout_param_{key}" layout_params[key] = inner_tensors[attr_name] - return QuantizedTensor(inner_tensors["_q_data"], layout_type, layout_params) + return QuantizedTensor(inner_tensors["_qdata"], layout_type, layout_params) @classmethod def from_float(cls, tensor, layout_type, **quantize_kwargs) -> 'QuantizedTensor': From 6b88478f9fe0874c0e17468c9fca3a0a84e6c781 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 3 Nov 2025 16:22:10 -0800 Subject: [PATCH 058/231] Bring back fp8 torch compile performance to what it should be. (#10622) --- comfy/quant_ops.py | 41 +++++++++++++++++++++++++++++++++++++++-- 1 file changed, 39 insertions(+), 2 deletions(-) diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py index 835fc4b8d..ed7b29963 100644 --- a/comfy/quant_ops.py +++ b/comfy/quant_ops.py @@ -126,7 +126,7 @@ class QuantizedTensor(torch.Tensor): return torch.Tensor._make_wrapper_subclass(cls, qdata.shape, device=qdata.device, dtype=qdata.dtype, requires_grad=False) def __init__(self, qdata, layout_type, layout_params): - self._qdata = qdata.contiguous() + self._qdata = qdata self._layout_type = layout_type self._layout_params = layout_params @@ -411,7 +411,7 @@ def fp8_linear(func, args, kwargs): try: output = torch._scaled_mm( - plain_input.reshape(-1, input_shape[2]), + plain_input.reshape(-1, input_shape[2]).contiguous(), weight_t, bias=bias, scale_a=scale_a, @@ -447,6 +447,43 @@ def fp8_linear(func, args, kwargs): return torch.nn.functional.linear(input_tensor, weight, bias) +@register_layout_op(torch.ops.aten.addmm.default, "TensorCoreFP8Layout") +def fp8_addmm(func, args, kwargs): + input_tensor = args[1] + weight = args[2] + bias = args[0] + + if isinstance(input_tensor, QuantizedTensor) and isinstance(weight, QuantizedTensor): + out_dtype = kwargs.get("out_dtype") + if out_dtype is None: + out_dtype = input_tensor._layout_params['orig_dtype'] + + plain_input, scale_a = TensorCoreFP8Layout.get_plain_tensors(input_tensor) + plain_weight, scale_b = TensorCoreFP8Layout.get_plain_tensors(weight) + + output = torch._scaled_mm( + plain_input.contiguous(), + plain_weight, + bias=bias, + scale_a=scale_a, + scale_b=scale_b, + out_dtype=out_dtype, + ) + + if isinstance(output, tuple): # TODO: remove when we drop support for torch 2.4 + output = output[0] + return output + + a = list(args) + if isinstance(args[0], QuantizedTensor): + a[0] = args[0].dequantize() + if isinstance(args[1], QuantizedTensor): + a[1] = args[1].dequantize() + if isinstance(args[2], QuantizedTensor): + a[2] = args[2].dequantize() + + return func(*a, **kwargs) + @register_layout_op(torch.ops.aten.view.default, "TensorCoreFP8Layout") @register_layout_op(torch.ops.aten.t.default, "TensorCoreFP8Layout") def fp8_func(func, args, kwargs): From 0f4ef3afa0772ad11d6d72ad21fb1e089c2fcf5f Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 3 Nov 2025 18:47:14 -0800 Subject: [PATCH 059/231] This seems to slow things down slightly on Linux. (#10624) --- comfy/ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/ops.py b/comfy/ops.py index afe498caa..733bff99d 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -35,7 +35,7 @@ def scaled_dot_product_attention(q, k, v, *args, **kwargs): try: - if torch.cuda.is_available(): + if torch.cuda.is_available() and comfy.model_management.WINDOWS: from torch.nn.attention import SDPBackend, sdpa_kernel import inspect if "set_priority" in inspect.signature(sdpa_kernel).parameters: From af4b7b5edb339a15aa443e32aefbceac1810baa0 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 3 Nov 2025 19:14:20 -0800 Subject: [PATCH 060/231] More fp8 torch.compile regressions fixed. (#10625) --- comfy/quant_ops.py | 54 ++++++++++++++++++++++++++++++---------------- 1 file changed, 35 insertions(+), 19 deletions(-) diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py index ed7b29963..c56e32a73 100644 --- a/comfy/quant_ops.py +++ b/comfy/quant_ops.py @@ -446,6 +446,25 @@ def fp8_linear(func, args, kwargs): return torch.nn.functional.linear(input_tensor, weight, bias) +def fp8_mm_(input_tensor, weight, bias=None, out_dtype=None): + if out_dtype is None: + out_dtype = input_tensor._layout_params['orig_dtype'] + + plain_input, scale_a = TensorCoreFP8Layout.get_plain_tensors(input_tensor) + plain_weight, scale_b = TensorCoreFP8Layout.get_plain_tensors(weight) + + output = torch._scaled_mm( + plain_input.contiguous(), + plain_weight, + bias=bias, + scale_a=scale_a, + scale_b=scale_b, + out_dtype=out_dtype, + ) + + if isinstance(output, tuple): # TODO: remove when we drop support for torch 2.4 + output = output[0] + return output @register_layout_op(torch.ops.aten.addmm.default, "TensorCoreFP8Layout") def fp8_addmm(func, args, kwargs): @@ -454,25 +473,7 @@ def fp8_addmm(func, args, kwargs): bias = args[0] if isinstance(input_tensor, QuantizedTensor) and isinstance(weight, QuantizedTensor): - out_dtype = kwargs.get("out_dtype") - if out_dtype is None: - out_dtype = input_tensor._layout_params['orig_dtype'] - - plain_input, scale_a = TensorCoreFP8Layout.get_plain_tensors(input_tensor) - plain_weight, scale_b = TensorCoreFP8Layout.get_plain_tensors(weight) - - output = torch._scaled_mm( - plain_input.contiguous(), - plain_weight, - bias=bias, - scale_a=scale_a, - scale_b=scale_b, - out_dtype=out_dtype, - ) - - if isinstance(output, tuple): # TODO: remove when we drop support for torch 2.4 - output = output[0] - return output + return fp8_mm_(input_tensor, weight, bias=bias, out_dtype=kwargs.get("out_dtype", None)) a = list(args) if isinstance(args[0], QuantizedTensor): @@ -484,6 +485,21 @@ def fp8_addmm(func, args, kwargs): return func(*a, **kwargs) +@register_layout_op(torch.ops.aten.mm.default, "TensorCoreFP8Layout") +def fp8_mm(func, args, kwargs): + input_tensor = args[0] + weight = args[1] + + if isinstance(input_tensor, QuantizedTensor) and isinstance(weight, QuantizedTensor): + return fp8_mm_(input_tensor, weight, bias=None, out_dtype=kwargs.get("out_dtype", None)) + + a = list(args) + if isinstance(args[0], QuantizedTensor): + a[0] = args[0].dequantize() + if isinstance(args[1], QuantizedTensor): + a[1] = args[1].dequantize() + return func(*a, **kwargs) + @register_layout_op(torch.ops.aten.view.default, "TensorCoreFP8Layout") @register_layout_op(torch.ops.aten.t.default, "TensorCoreFP8Layout") def fp8_func(func, args, kwargs): From 9c71a667904a049975531f2a7dd55f4a8fc92652 Mon Sep 17 00:00:00 2001 From: ComfyUI Wiki Date: Wed, 5 Nov 2025 02:51:53 +0800 Subject: [PATCH 061/231] chore: update workflow templates to v0.2.11 (#10634) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 856e373de..249c36dee 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.28.8 -comfyui-workflow-templates==0.2.4 +comfyui-workflow-templates==0.2.11 comfyui-embedded-docs==0.3.1 torch torchsde From a389ee01bb7ba5174729906a7f85bd08b5c2cb87 Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Wed, 5 Nov 2025 08:14:10 +1000 Subject: [PATCH 062/231] caching: Handle None outputs tuple case (#10637) --- comfy_execution/caching.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/comfy_execution/caching.py b/comfy_execution/caching.py index e077f78b0..326a279fc 100644 --- a/comfy_execution/caching.py +++ b/comfy_execution/caching.py @@ -399,6 +399,8 @@ class RAMPressureCache(LRUCache): ram_usage = RAM_CACHE_DEFAULT_RAM_USAGE def scan_list_for_ram_usage(outputs): nonlocal ram_usage + if outputs is None: + return for output in outputs: if isinstance(output, list): scan_list_for_ram_usage(output) From 7f3e4d486cd77c3ad30eb4714ec18bdaf29e2b5c Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 4 Nov 2025 14:37:50 -0800 Subject: [PATCH 063/231] Limit amount of pinned memory on windows to prevent issues. (#10638) --- comfy/model_management.py | 32 ++++++++++++++++++++++++++++---- 1 file changed, 28 insertions(+), 4 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 79c0dfdb4..0d040e55e 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -1082,8 +1082,20 @@ def cast_to_device(tensor, device, dtype, copy=False): non_blocking = device_supports_non_blocking(device) return cast_to(tensor, dtype=dtype, device=device, non_blocking=non_blocking, copy=copy) + +PINNED_MEMORY = {} +TOTAL_PINNED_MEMORY = 0 +if PerformanceFeature.PinnedMem in args.fast: + if WINDOWS: + MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.45 # Windows limit is apparently 50% + else: + MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.95 +else: + MAX_PINNED_MEMORY = -1 + def pin_memory(tensor): - if PerformanceFeature.PinnedMem not in args.fast: + global TOTAL_PINNED_MEMORY + if MAX_PINNED_MEMORY <= 0: return False if not is_nvidia(): @@ -1092,13 +1104,21 @@ def pin_memory(tensor): if not is_device_cpu(tensor.device): return False - if torch.cuda.cudart().cudaHostRegister(tensor.data_ptr(), tensor.numel() * tensor.element_size(), 1) == 0: + size = tensor.numel() * tensor.element_size() + if (TOTAL_PINNED_MEMORY + size) > MAX_PINNED_MEMORY: + return False + + ptr = tensor.data_ptr() + if torch.cuda.cudart().cudaHostRegister(ptr, size, 1) == 0: + PINNED_MEMORY[ptr] = size + TOTAL_PINNED_MEMORY += size return True return False def unpin_memory(tensor): - if PerformanceFeature.PinnedMem not in args.fast: + global TOTAL_PINNED_MEMORY + if MAX_PINNED_MEMORY <= 0: return False if not is_nvidia(): @@ -1107,7 +1127,11 @@ def unpin_memory(tensor): if not is_device_cpu(tensor.device): return False - if torch.cuda.cudart().cudaHostUnregister(tensor.data_ptr()) == 0: + ptr = tensor.data_ptr() + if torch.cuda.cudart().cudaHostUnregister(ptr) == 0: + TOTAL_PINNED_MEMORY -= PINNED_MEMORY.pop(ptr) + if len(PINNED_MEMORY) == 0: + TOTAL_PINNED_MEMORY = 0 return True return False From 265adad858e1f31b66cd3523a02b16f5d34ced52 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 4 Nov 2025 19:42:23 -0500 Subject: [PATCH 064/231] ComfyUI version v0.3.68 --- comfyui_version.py | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/comfyui_version.py b/comfyui_version.py index db48b05c4..25d1a4157 100644 --- a/comfyui_version.py +++ b/comfyui_version.py @@ -1,3 +1,3 @@ # This file is automatically generated by the build process when version is # updated in pyproject.toml. -__version__ = "0.3.67" +__version__ = "0.3.68" diff --git a/pyproject.toml b/pyproject.toml index ab054355c..79ff3f74a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ComfyUI" -version = "0.3.67" +version = "0.3.68" readme = "README.md" license = { file = "LICENSE" } requires-python = ">=3.9" From 4cd881866bad0cde70273cc123d725693c1f2759 Mon Sep 17 00:00:00 2001 From: contentis Date: Wed, 5 Nov 2025 02:10:11 +0100 Subject: [PATCH 065/231] Use single apply_rope function across models (#10547) --- comfy/ldm/flux/layers.py | 4 +- comfy/ldm/flux/math.py | 10 +--- comfy/ldm/lightricks/model.py | 88 ++++++++++++++--------------------- comfy/ldm/qwen_image/model.py | 36 +++++++------- comfy/ldm/wan/model.py | 1 + 5 files changed, 59 insertions(+), 80 deletions(-) diff --git a/comfy/ldm/flux/layers.py b/comfy/ldm/flux/layers.py index ef21b416b..a3eab0470 100644 --- a/comfy/ldm/flux/layers.py +++ b/comfy/ldm/flux/layers.py @@ -195,8 +195,8 @@ class DoubleStreamBlock(nn.Module): txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:] # calculate the img bloks - img = img + apply_mod(self.img_attn.proj(img_attn), img_mod1.gate, None, modulation_dims_img) - img = img + apply_mod(self.img_mlp(apply_mod(self.img_norm2(img), (1 + img_mod2.scale), img_mod2.shift, modulation_dims_img)), img_mod2.gate, None, modulation_dims_img) + img += apply_mod(self.img_attn.proj(img_attn), img_mod1.gate, None, modulation_dims_img) + img += apply_mod(self.img_mlp(apply_mod(self.img_norm2(img), (1 + img_mod2.scale), img_mod2.shift, modulation_dims_img)), img_mod2.gate, None, modulation_dims_img) # calculate the txt bloks txt += apply_mod(self.txt_attn.proj(txt_attn), txt_mod1.gate, None, modulation_dims_txt) diff --git a/comfy/ldm/flux/math.py b/comfy/ldm/flux/math.py index 8deda0d4a..158420290 100644 --- a/comfy/ldm/flux/math.py +++ b/comfy/ldm/flux/math.py @@ -7,15 +7,7 @@ import comfy.model_management def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None, transformer_options={}) -> Tensor: - q_shape = q.shape - k_shape = k.shape - - if pe is not None: - q = q.to(dtype=pe.dtype).reshape(*q.shape[:-1], -1, 1, 2) - k = k.to(dtype=pe.dtype).reshape(*k.shape[:-1], -1, 1, 2) - q = (pe[..., 0] * q[..., 0] + pe[..., 1] * q[..., 1]).reshape(*q_shape).type_as(v) - k = (pe[..., 0] * k[..., 0] + pe[..., 1] * k[..., 1]).reshape(*k_shape).type_as(v) - + q, k = apply_rope(q, k, pe) heads = q.shape[1] x = optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask, transformer_options=transformer_options) return x diff --git a/comfy/ldm/lightricks/model.py b/comfy/ldm/lightricks/model.py index def365ba7..5bcba998b 100644 --- a/comfy/ldm/lightricks/model.py +++ b/comfy/ldm/lightricks/model.py @@ -3,12 +3,11 @@ from torch import nn import comfy.patcher_extension import comfy.ldm.modules.attention import comfy.ldm.common_dit -from einops import rearrange import math from typing import Dict, Optional, Tuple from .symmetric_patchifier import SymmetricPatchifier, latent_to_pixel_coords - +from comfy.ldm.flux.math import apply_rope1 def get_timestep_embedding( timesteps: torch.Tensor, @@ -238,20 +237,6 @@ class FeedForward(nn.Module): return self.net(x) -def apply_rotary_emb(input_tensor, freqs_cis): #TODO: remove duplicate funcs and pick the best/fastest one - cos_freqs = freqs_cis[0] - sin_freqs = freqs_cis[1] - - t_dup = rearrange(input_tensor, "... (d r) -> ... d r", r=2) - t1, t2 = t_dup.unbind(dim=-1) - t_dup = torch.stack((-t2, t1), dim=-1) - input_tensor_rot = rearrange(t_dup, "... d r -> ... (d r)") - - out = input_tensor * cos_freqs + input_tensor_rot * sin_freqs - - return out - - class CrossAttention(nn.Module): def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., attn_precision=None, dtype=None, device=None, operations=None): super().__init__() @@ -281,8 +266,8 @@ class CrossAttention(nn.Module): k = self.k_norm(k) if pe is not None: - q = apply_rotary_emb(q, pe) - k = apply_rotary_emb(k, pe) + q = apply_rope1(q.unsqueeze(1), pe).squeeze(1) + k = apply_rope1(k.unsqueeze(1), pe).squeeze(1) if mask is None: out = comfy.ldm.modules.attention.optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision, transformer_options=transformer_options) @@ -306,12 +291,17 @@ class BasicTransformerBlock(nn.Module): def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None, transformer_options={}): shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)).unbind(dim=2) - x += self.attn1(comfy.ldm.common_dit.rms_norm(x) * (1 + scale_msa) + shift_msa, pe=pe, transformer_options=transformer_options) * gate_msa + norm_x = comfy.ldm.common_dit.rms_norm(x) + attn1_input = torch.addcmul(norm_x, norm_x, scale_msa).add_(shift_msa) + attn1_result = self.attn1(attn1_input, pe=pe, transformer_options=transformer_options) + x.addcmul_(attn1_result, gate_msa) x += self.attn2(x, context=context, mask=attention_mask, transformer_options=transformer_options) - y = comfy.ldm.common_dit.rms_norm(x) * (1 + scale_mlp) + shift_mlp - x += self.ff(y) * gate_mlp + norm_x = comfy.ldm.common_dit.rms_norm(x) + y = torch.addcmul(norm_x, norm_x, scale_mlp).add_(shift_mlp) + ff_result = self.ff(y) + x.addcmul_(ff_result, gate_mlp) return x @@ -327,41 +317,35 @@ def get_fractional_positions(indices_grid, max_pos): def precompute_freqs_cis(indices_grid, dim, out_dtype, theta=10000.0, max_pos=[20, 2048, 2048]): - dtype = torch.float32 #self.dtype + dtype = torch.float32 + device = indices_grid.device + # Get fractional positions and compute frequency indices fractional_positions = get_fractional_positions(indices_grid, max_pos) + indices = theta ** torch.linspace(0, 1, dim // 6, device=device, dtype=dtype) * math.pi / 2 - start = 1 - end = theta - device = fractional_positions.device + # Compute frequencies and apply cos/sin + freqs = (indices * (fractional_positions.unsqueeze(-1) * 2 - 1)).transpose(-1, -2).flatten(2) + cos_vals = freqs.cos().repeat_interleave(2, dim=-1) + sin_vals = freqs.sin().repeat_interleave(2, dim=-1) - indices = theta ** ( - torch.linspace( - math.log(start, theta), - math.log(end, theta), - dim // 6, - device=device, - dtype=dtype, - ) - ) - indices = indices.to(dtype=dtype) - - indices = indices * math.pi / 2 - - freqs = ( - (indices * (fractional_positions.unsqueeze(-1) * 2 - 1)) - .transpose(-1, -2) - .flatten(2) - ) - - cos_freq = freqs.cos().repeat_interleave(2, dim=-1) - sin_freq = freqs.sin().repeat_interleave(2, dim=-1) + # Pad if dim is not divisible by 6 if dim % 6 != 0: - cos_padding = torch.ones_like(cos_freq[:, :, : dim % 6]) - sin_padding = torch.zeros_like(cos_freq[:, :, : dim % 6]) - cos_freq = torch.cat([cos_padding, cos_freq], dim=-1) - sin_freq = torch.cat([sin_padding, sin_freq], dim=-1) - return cos_freq.to(out_dtype), sin_freq.to(out_dtype) + padding_size = dim % 6 + cos_vals = torch.cat([torch.ones_like(cos_vals[:, :, :padding_size]), cos_vals], dim=-1) + sin_vals = torch.cat([torch.zeros_like(sin_vals[:, :, :padding_size]), sin_vals], dim=-1) + + # Reshape and extract one value per pair (since repeat_interleave duplicates each value) + cos_vals = cos_vals.reshape(*cos_vals.shape[:2], -1, 2)[..., 0] # [B, N, dim//2] + sin_vals = sin_vals.reshape(*sin_vals.shape[:2], -1, 2)[..., 0] # [B, N, dim//2] + + # Build rotation matrix [[cos, -sin], [sin, cos]] and add heads dimension + freqs_cis = torch.stack([ + torch.stack([cos_vals, -sin_vals], dim=-1), + torch.stack([sin_vals, cos_vals], dim=-1) + ], dim=-2).unsqueeze(1) # [B, 1, N, dim//2, 2, 2] + + return freqs_cis.to(out_dtype) class LTXVModel(torch.nn.Module): @@ -501,7 +485,7 @@ class LTXVModel(torch.nn.Module): shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1] x = self.norm_out(x) # Modulation - x = x * (1 + scale) + shift + x = torch.addcmul(x, x, scale).add_(shift) x = self.proj_out(x) x = self.patchifier.unpatchify( diff --git a/comfy/ldm/qwen_image/model.py b/comfy/ldm/qwen_image/model.py index b9f60c2b7..81d3ee7c0 100644 --- a/comfy/ldm/qwen_image/model.py +++ b/comfy/ldm/qwen_image/model.py @@ -10,6 +10,7 @@ from comfy.ldm.modules.attention import optimized_attention_masked from comfy.ldm.flux.layers import EmbedND import comfy.ldm.common_dit import comfy.patcher_extension +from comfy.ldm.flux.math import apply_rope1 class GELU(nn.Module): def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: bool = True, dtype=None, device=None, operations=None): @@ -134,33 +135,34 @@ class Attention(nn.Module): image_rotary_emb: Optional[torch.Tensor] = None, transformer_options={}, ) -> Tuple[torch.Tensor, torch.Tensor]: + batch_size = hidden_states.shape[0] + seq_img = hidden_states.shape[1] seq_txt = encoder_hidden_states.shape[1] - img_query = self.to_q(hidden_states).unflatten(-1, (self.heads, -1)) - img_key = self.to_k(hidden_states).unflatten(-1, (self.heads, -1)) - img_value = self.to_v(hidden_states).unflatten(-1, (self.heads, -1)) + # Project and reshape to BHND format (batch, heads, seq, dim) + img_query = self.to_q(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2).contiguous() + img_key = self.to_k(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2).contiguous() + img_value = self.to_v(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2) - txt_query = self.add_q_proj(encoder_hidden_states).unflatten(-1, (self.heads, -1)) - txt_key = self.add_k_proj(encoder_hidden_states).unflatten(-1, (self.heads, -1)) - txt_value = self.add_v_proj(encoder_hidden_states).unflatten(-1, (self.heads, -1)) + txt_query = self.add_q_proj(encoder_hidden_states).view(batch_size, seq_txt, self.heads, -1).transpose(1, 2).contiguous() + txt_key = self.add_k_proj(encoder_hidden_states).view(batch_size, seq_txt, self.heads, -1).transpose(1, 2).contiguous() + txt_value = self.add_v_proj(encoder_hidden_states).view(batch_size, seq_txt, self.heads, -1).transpose(1, 2) img_query = self.norm_q(img_query) img_key = self.norm_k(img_key) txt_query = self.norm_added_q(txt_query) txt_key = self.norm_added_k(txt_key) - joint_query = torch.cat([txt_query, img_query], dim=1) - joint_key = torch.cat([txt_key, img_key], dim=1) - joint_value = torch.cat([txt_value, img_value], dim=1) + joint_query = torch.cat([txt_query, img_query], dim=2) + joint_key = torch.cat([txt_key, img_key], dim=2) + joint_value = torch.cat([txt_value, img_value], dim=2) - joint_query = apply_rotary_emb(joint_query, image_rotary_emb) - joint_key = apply_rotary_emb(joint_key, image_rotary_emb) + joint_query = apply_rope1(joint_query, image_rotary_emb) + joint_key = apply_rope1(joint_key, image_rotary_emb) - joint_query = joint_query.flatten(start_dim=2) - joint_key = joint_key.flatten(start_dim=2) - joint_value = joint_value.flatten(start_dim=2) - - joint_hidden_states = optimized_attention_masked(joint_query, joint_key, joint_value, self.heads, attention_mask, transformer_options=transformer_options) + joint_hidden_states = optimized_attention_masked(joint_query, joint_key, joint_value, self.heads, + attention_mask, transformer_options=transformer_options, + skip_reshape=True) txt_attn_output = joint_hidden_states[:, :seq_txt, :] img_attn_output = joint_hidden_states[:, seq_txt:, :] @@ -413,7 +415,7 @@ class QwenImageTransformer2DModel(nn.Module): txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2, ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) // 2)) txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3) ids = torch.cat((txt_ids, img_ids), dim=1) - image_rotary_emb = self.pe_embedder(ids).squeeze(1).unsqueeze(2).to(x.dtype) + image_rotary_emb = self.pe_embedder(ids).to(torch.float32).contiguous() del ids, txt_ids, img_ids hidden_states = self.img_in(hidden_states) diff --git a/comfy/ldm/wan/model.py b/comfy/ldm/wan/model.py index 5ec1511ce..a9d5e10d9 100644 --- a/comfy/ldm/wan/model.py +++ b/comfy/ldm/wan/model.py @@ -232,6 +232,7 @@ class WanAttentionBlock(nn.Module): # assert e[0].dtype == torch.float32 # self-attention + x = x.contiguous() # otherwise implicit in LayerNorm y = self.self_attn( torch.addcmul(repeat_e(e[0], x), self.norm1(x), 1 + repeat_e(e[1], x)), freqs, transformer_options=transformer_options) From c4a6b389de1014471a75a46ee57d2fdac4f8df93 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 4 Nov 2025 19:47:35 -0800 Subject: [PATCH 066/231] Lower ltxv mem usage to what it was before previous pr. (#10643) Bring back qwen behavior to what it was before previous pr. --- comfy/ldm/lightricks/model.py | 22 +++++++++++----------- comfy/ldm/qwen_image/model.py | 2 +- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/comfy/ldm/lightricks/model.py b/comfy/ldm/lightricks/model.py index 5bcba998b..593f7940f 100644 --- a/comfy/ldm/lightricks/model.py +++ b/comfy/ldm/lightricks/model.py @@ -291,17 +291,17 @@ class BasicTransformerBlock(nn.Module): def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None, transformer_options={}): shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)).unbind(dim=2) - norm_x = comfy.ldm.common_dit.rms_norm(x) - attn1_input = torch.addcmul(norm_x, norm_x, scale_msa).add_(shift_msa) - attn1_result = self.attn1(attn1_input, pe=pe, transformer_options=transformer_options) - x.addcmul_(attn1_result, gate_msa) + attn1_input = comfy.ldm.common_dit.rms_norm(x) + attn1_input = torch.addcmul(attn1_input, attn1_input, scale_msa).add_(shift_msa) + attn1_input = self.attn1(attn1_input, pe=pe, transformer_options=transformer_options) + x.addcmul_(attn1_input, gate_msa) + del attn1_input x += self.attn2(x, context=context, mask=attention_mask, transformer_options=transformer_options) - norm_x = comfy.ldm.common_dit.rms_norm(x) - y = torch.addcmul(norm_x, norm_x, scale_mlp).add_(shift_mlp) - ff_result = self.ff(y) - x.addcmul_(ff_result, gate_mlp) + y = comfy.ldm.common_dit.rms_norm(x) + y = torch.addcmul(y, y, scale_mlp).add_(shift_mlp) + x.addcmul_(self.ff(y), gate_mlp) return x @@ -336,8 +336,8 @@ def precompute_freqs_cis(indices_grid, dim, out_dtype, theta=10000.0, max_pos=[2 sin_vals = torch.cat([torch.zeros_like(sin_vals[:, :, :padding_size]), sin_vals], dim=-1) # Reshape and extract one value per pair (since repeat_interleave duplicates each value) - cos_vals = cos_vals.reshape(*cos_vals.shape[:2], -1, 2)[..., 0] # [B, N, dim//2] - sin_vals = sin_vals.reshape(*sin_vals.shape[:2], -1, 2)[..., 0] # [B, N, dim//2] + cos_vals = cos_vals.reshape(*cos_vals.shape[:2], -1, 2)[..., 0].to(out_dtype) # [B, N, dim//2] + sin_vals = sin_vals.reshape(*sin_vals.shape[:2], -1, 2)[..., 0].to(out_dtype) # [B, N, dim//2] # Build rotation matrix [[cos, -sin], [sin, cos]] and add heads dimension freqs_cis = torch.stack([ @@ -345,7 +345,7 @@ def precompute_freqs_cis(indices_grid, dim, out_dtype, theta=10000.0, max_pos=[2 torch.stack([sin_vals, cos_vals], dim=-1) ], dim=-2).unsqueeze(1) # [B, 1, N, dim//2, 2, 2] - return freqs_cis.to(out_dtype) + return freqs_cis class LTXVModel(torch.nn.Module): diff --git a/comfy/ldm/qwen_image/model.py b/comfy/ldm/qwen_image/model.py index 81d3ee7c0..e5d0d17c1 100644 --- a/comfy/ldm/qwen_image/model.py +++ b/comfy/ldm/qwen_image/model.py @@ -415,7 +415,7 @@ class QwenImageTransformer2DModel(nn.Module): txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2, ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) // 2)) txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3) ids = torch.cat((txt_ids, img_ids), dim=1) - image_rotary_emb = self.pe_embedder(ids).to(torch.float32).contiguous() + image_rotary_emb = self.pe_embedder(ids).to(x.dtype).contiguous() del ids, txt_ids, img_ids hidden_states = self.img_in(hidden_states) From bda0eb2448135797d5a72f7236ce26d07e555baf Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Wed, 5 Nov 2025 12:16:00 +0200 Subject: [PATCH 067/231] feat(API-nodes): move Rodin3D nodes to new client; removed old api client.py (#10645) --- comfy_api_nodes/apis/PixverseController.py | 17 - comfy_api_nodes/apis/PixverseDto.py | 57 - comfy_api_nodes/apis/client.py | 981 ------------------ comfy_api_nodes/nodes_rodin.py | 196 ++-- comfy_api_nodes/util/client.py | 4 +- comfy_api_nodes/util/download_helpers.py | 2 +- .../{apis => util}/request_logger.py | 4 +- comfy_api_nodes/util/upload_helpers.py | 2 +- 8 files changed, 75 insertions(+), 1188 deletions(-) delete mode 100644 comfy_api_nodes/apis/PixverseController.py delete mode 100644 comfy_api_nodes/apis/PixverseDto.py delete mode 100644 comfy_api_nodes/apis/client.py rename comfy_api_nodes/{apis => util}/request_logger.py (100%) diff --git a/comfy_api_nodes/apis/PixverseController.py b/comfy_api_nodes/apis/PixverseController.py deleted file mode 100644 index 310c0f546..000000000 --- a/comfy_api_nodes/apis/PixverseController.py +++ /dev/null @@ -1,17 +0,0 @@ -# generated by datamodel-codegen: -# filename: filtered-openapi.yaml -# timestamp: 2025-04-29T23:44:54+00:00 - -from __future__ import annotations - -from typing import Optional - -from pydantic import BaseModel - -from . import PixverseDto - - -class ResponseData(BaseModel): - ErrCode: Optional[int] = None - ErrMsg: Optional[str] = None - Resp: Optional[PixverseDto.V2OpenAPII2VResp] = None diff --git a/comfy_api_nodes/apis/PixverseDto.py b/comfy_api_nodes/apis/PixverseDto.py deleted file mode 100644 index 323c38e96..000000000 --- a/comfy_api_nodes/apis/PixverseDto.py +++ /dev/null @@ -1,57 +0,0 @@ -# generated by datamodel-codegen: -# filename: filtered-openapi.yaml -# timestamp: 2025-04-29T23:44:54+00:00 - -from __future__ import annotations - -from typing import Optional - -from pydantic import BaseModel, Field - - -class V2OpenAPII2VResp(BaseModel): - video_id: Optional[int] = Field(None, description='Video_id') - - -class V2OpenAPIT2VReq(BaseModel): - aspect_ratio: str = Field( - ..., description='Aspect ratio (16:9, 4:3, 1:1, 3:4, 9:16)', examples=['16:9'] - ) - duration: int = Field( - ..., - description='Video duration (5, 8 seconds, --model=v3.5 only allows 5,8; --quality=1080p does not support 8s)', - examples=[5], - ) - model: str = Field( - ..., description='Model version (only supports v3.5)', examples=['v3.5'] - ) - motion_mode: Optional[str] = Field( - 'normal', - description='Motion mode (normal, fast, --fast only available when duration=5; --quality=1080p does not support fast)', - examples=['normal'], - ) - negative_prompt: Optional[str] = Field( - None, description='Negative prompt\n', max_length=2048 - ) - prompt: str = Field(..., description='Prompt', max_length=2048) - quality: str = Field( - ..., - description='Video quality ("360p"(Turbo model), "540p", "720p", "1080p")', - examples=['540p'], - ) - seed: Optional[int] = Field(None, description='Random seed, range: 0 - 2147483647') - style: Optional[str] = Field( - None, - description='Style (effective when model=v3.5, "anime", "3d_animation", "clay", "comic", "cyberpunk") Do not include style parameter unless needed', - examples=['anime'], - ) - template_id: Optional[int] = Field( - None, - description='Template ID (template_id must be activated before use)', - examples=[302325299692608], - ) - water_mark: Optional[bool] = Field( - False, - description='Watermark (true: add watermark, false: no watermark)', - examples=[False], - ) diff --git a/comfy_api_nodes/apis/client.py b/comfy_api_nodes/apis/client.py deleted file mode 100644 index bdaddcc88..000000000 --- a/comfy_api_nodes/apis/client.py +++ /dev/null @@ -1,981 +0,0 @@ -""" -API Client Framework for api.comfy.org. - -This module provides a flexible framework for making API requests from ComfyUI nodes. -It supports both synchronous and asynchronous API operations with proper type validation. - -Key Components: --------------- -1. ApiClient - Handles HTTP requests with authentication and error handling -2. ApiEndpoint - Defines a single HTTP endpoint with its request/response models -3. ApiOperation - Executes a single synchronous API operation - -Usage Examples: --------------- - -# Example 1: Synchronous API Operation -# ------------------------------------ -# For a simple API call that returns the result immediately: - -# 1. Create the API client -api_client = ApiClient( - base_url="https://api.example.com", - auth_token="your_auth_token_here", - comfy_api_key="your_comfy_api_key_here", - timeout=30.0, - verify_ssl=True -) - -# 2. Define the endpoint -user_info_endpoint = ApiEndpoint( - path="/v1/users/me", - method=HttpMethod.GET, - request_model=EmptyRequest, # No request body needed - response_model=UserProfile, # Pydantic model for the response - query_params=None -) - -# 3. Create the request object -request = EmptyRequest() - -# 4. Create and execute the operation -operation = ApiOperation( - endpoint=user_info_endpoint, - request=request -) -user_profile = await operation.execute(client=api_client) # Returns immediately with the result - - -# Example 2: Asynchronous API Operation with Polling -# ------------------------------------------------- -# For an API that starts a task and requires polling for completion: - -# 1. Define the endpoints (initial request and polling) -generate_image_endpoint = ApiEndpoint( - path="/v1/images/generate", - method=HttpMethod.POST, - request_model=ImageGenerationRequest, - response_model=TaskCreatedResponse, - query_params=None -) - -check_task_endpoint = ApiEndpoint( - path="/v1/tasks/{task_id}", - method=HttpMethod.GET, - request_model=EmptyRequest, - response_model=ImageGenerationResult, - query_params=None -) - -# 2. Create the request object -request = ImageGenerationRequest( - prompt="a beautiful sunset over mountains", - width=1024, - height=1024, - num_images=1 -) - -# 3. Create and execute the polling operation -operation = PollingOperation( - initial_endpoint=generate_image_endpoint, - initial_request=request, - poll_endpoint=check_task_endpoint, - task_id_field="task_id", - status_field="status", - completed_statuses=["completed"], - failed_statuses=["failed", "error"] -) - -# This will make the initial request and then poll until completion -result = await operation.execute(client=api_client) # Returns the final ImageGenerationResult when done -""" - -from __future__ import annotations -import aiohttp -import asyncio -import logging -import io -import os -import socket -from aiohttp.client_exceptions import ClientError, ClientResponseError -from typing import Type, Optional, Any, TypeVar, Generic, Callable -from enum import Enum -import json -from urllib.parse import urljoin, urlparse -from pydantic import BaseModel, Field -import uuid # For generating unique operation IDs - -from server import PromptServer -from comfy.cli_args import args -from comfy import utils -from . import request_logger - -T = TypeVar("T", bound=BaseModel) -R = TypeVar("R", bound=BaseModel) -P = TypeVar("P", bound=BaseModel) # For poll response - -PROGRESS_BAR_MAX = 100 - - -class NetworkError(Exception): - """Base exception for network-related errors with diagnostic information.""" - pass - - -class LocalNetworkError(NetworkError): - """Exception raised when local network connectivity issues are detected.""" - pass - - -class ApiServerError(NetworkError): - """Exception raised when the API server is unreachable but internet is working.""" - pass - - -class EmptyRequest(BaseModel): - """Base class for empty request bodies. - For GET requests, fields will be sent as query parameters.""" - - pass - - -class UploadRequest(BaseModel): - file_name: str = Field(..., description="Filename to upload") - content_type: Optional[str] = Field( - None, - description="Mime type of the file. For example: image/png, image/jpeg, video/mp4, etc.", - ) - - -class UploadResponse(BaseModel): - download_url: str = Field(..., description="URL to GET uploaded file") - upload_url: str = Field(..., description="URL to PUT file to upload") - - -class HttpMethod(str, Enum): - GET = "GET" - POST = "POST" - PUT = "PUT" - DELETE = "DELETE" - PATCH = "PATCH" - - -class ApiClient: - """ - Client for making HTTP requests to an API with authentication, error handling, and retry logic. - """ - - def __init__( - self, - base_url: str, - auth_token: Optional[str] = None, - comfy_api_key: Optional[str] = None, - timeout: float = 3600.0, - verify_ssl: bool = True, - max_retries: int = 3, - retry_delay: float = 1.0, - retry_backoff_factor: float = 2.0, - retry_status_codes: Optional[tuple[int, ...]] = None, - session: Optional[aiohttp.ClientSession] = None, - ): - self.base_url = base_url - self.auth_token = auth_token - self.comfy_api_key = comfy_api_key - self.timeout = timeout - self.verify_ssl = verify_ssl - self.max_retries = max_retries - self.retry_delay = retry_delay - self.retry_backoff_factor = retry_backoff_factor - # Default retry status codes: 408 (Request Timeout), 429 (Too Many Requests), - # 500, 502, 503, 504 (Server Errors) - self.retry_status_codes = retry_status_codes or (408, 429, 500, 502, 503, 504) - self._session: Optional[aiohttp.ClientSession] = session - self._owns_session = session is None # Track if we have to close it - - @staticmethod - def _generate_operation_id(path: str) -> str: - """Generates a unique operation ID for logging.""" - return f"{path.strip('/').replace('/', '_')}_{uuid.uuid4().hex[:8]}" - - @staticmethod - def _create_json_payload_args( - data: Optional[dict[str, Any]] = None, - headers: Optional[dict[str, str]] = None, - ) -> dict[str, Any]: - return { - "json": data, - "headers": headers, - } - - def _create_form_data_args( - self, - data: dict[str, Any] | None, - files: dict[str, Any] | None, - headers: Optional[dict[str, str]] = None, - multipart_parser: Callable | None = None, - ) -> dict[str, Any]: - if headers and "Content-Type" in headers: - del headers["Content-Type"] - - if multipart_parser and data: - data = multipart_parser(data) - - if isinstance(data, aiohttp.FormData): - form = data # If the parser already returned a FormData, pass it through - else: - form = aiohttp.FormData(default_to_multipart=True) - if data: # regular text fields - for k, v in data.items(): - if v is None: - continue # aiohttp fails to serialize "None" values - # aiohttp expects strings or bytes; convert enums etc. - form.add_field(k, str(v) if not isinstance(v, (bytes, bytearray)) else v) - - if files: - file_iter = files if isinstance(files, list) else files.items() - for field_name, file_obj in file_iter: - if file_obj is None: - continue # aiohttp fails to serialize "None" values - # file_obj can be (filename, bytes/io.BytesIO, content_type) tuple - if isinstance(file_obj, tuple): - filename, file_value, content_type = self._unpack_tuple(file_obj) - else: - file_value = file_obj - filename = getattr(file_obj, "name", field_name) - content_type = "application/octet-stream" - - form.add_field( - name=field_name, - value=file_value, - filename=filename, - content_type=content_type, - ) - return {"data": form, "headers": headers or {}} - - @staticmethod - def _create_urlencoded_form_data_args( - data: dict[str, Any], - headers: Optional[dict[str, str]] = None, - ) -> dict[str, Any]: - headers = headers or {} - headers["Content-Type"] = "application/x-www-form-urlencoded" - return { - "data": data, - "headers": headers, - } - - def get_headers(self) -> dict[str, str]: - """Get headers for API requests, including authentication if available""" - headers = {"Content-Type": "application/json", "Accept": "application/json"} - - if self.auth_token: - headers["Authorization"] = f"Bearer {self.auth_token}" - elif self.comfy_api_key: - headers["X-API-KEY"] = self.comfy_api_key - - return headers - - async def _check_connectivity(self, target_url: str) -> dict[str, bool]: - """ - Check connectivity to determine if network issues are local or server-related. - - Args: - target_url: URL to check connectivity to - - Returns: - Dictionary with connectivity status details - """ - results = { - "internet_accessible": False, - "api_accessible": False, - "is_local_issue": False, - "is_api_issue": False, - } - timeout = aiohttp.ClientTimeout(total=5.0) - async with aiohttp.ClientSession(timeout=timeout) as session: - try: - async with session.get("https://www.google.com", ssl=self.verify_ssl) as resp: - results["internet_accessible"] = resp.status < 500 - except (ClientError, asyncio.TimeoutError, socket.gaierror): - results["is_local_issue"] = True - return results # cannot reach the internet – early exit - - # Now check API health endpoint - parsed = urlparse(target_url) - health_url = f"{parsed.scheme}://{parsed.netloc}/health" - try: - async with session.get(health_url, ssl=self.verify_ssl) as resp: - results["api_accessible"] = resp.status < 500 - except ClientError: - pass # leave as False - - results["is_api_issue"] = results["internet_accessible"] and not results["api_accessible"] - return results - - async def request( - self, - method: str, - path: str, - params: Optional[dict[str, Any]] = None, - data: Optional[dict[str, Any]] = None, - files: Optional[dict[str, Any] | list[tuple[str, Any]]] = None, - headers: Optional[dict[str, str]] = None, - content_type: str = "application/json", - multipart_parser: Callable | None = None, - retry_count: int = 0, # Used internally for tracking retries - ) -> dict[str, Any]: - """ - Make an HTTP request to the API with automatic retries for transient errors. - - Args: - method: HTTP method (GET, POST, etc.) - path: API endpoint path (will be joined with base_url) - params: Query parameters - data: body data - files: Files to upload - headers: Additional headers - content_type: Content type of the request. Defaults to application/json. - retry_count: Internal parameter for tracking retries, do not set manually - - Returns: - Parsed JSON response - - Raises: - LocalNetworkError: If local network connectivity issues are detected - ApiServerError: If the API server is unreachable but internet is working - Exception: For other request failures - """ - - # Build full URL and merge headers - relative_path = path.lstrip("/") - url = urljoin(self.base_url, relative_path) - self._check_auth(self.auth_token, self.comfy_api_key) - - request_headers = self.get_headers() - if headers: - request_headers.update(headers) - if files: - request_headers.pop("Content-Type", None) - if params: - params = {k: v for k, v in params.items() if v is not None} # aiohttp fails to serialize None values - - logging.debug("[DEBUG] Request Headers: %s", request_headers) - logging.debug("[DEBUG] Files: %s", files) - logging.debug("[DEBUG] Params: %s", params) - logging.debug("[DEBUG] Data: %s", data) - - if content_type == "application/x-www-form-urlencoded": - payload_args = self._create_urlencoded_form_data_args(data or {}, request_headers) - elif content_type == "multipart/form-data": - payload_args = self._create_form_data_args(data, files, request_headers, multipart_parser) - else: - payload_args = self._create_json_payload_args(data, request_headers) - - operation_id = self._generate_operation_id(path) - request_logger.log_request_response( - operation_id=operation_id, - request_method=method, - request_url=url, - request_headers=request_headers, - request_params=params, - request_data=data if content_type == "application/json" else "[form-data or other]", - ) - - session = await self._get_session() - try: - async with session.request( - method, - url, - params=params, - ssl=self.verify_ssl, - **payload_args, - ) as resp: - if resp.status >= 400: - try: - error_data = await resp.json() - except (aiohttp.ContentTypeError, json.JSONDecodeError): - error_data = await resp.text() - - return await self._handle_http_error( - ClientResponseError(resp.request_info, resp.history, status=resp.status, message=error_data), - operation_id, - method, - url, - params, - data, - files, - headers, - content_type, - multipart_parser, - retry_count=retry_count, - response_content=error_data, - ) - - # Success – parse JSON (safely) and log - try: - payload = await resp.json() - response_content_to_log = payload - except (aiohttp.ContentTypeError, json.JSONDecodeError): - payload = {} - response_content_to_log = await resp.text() - - request_logger.log_request_response( - operation_id=operation_id, - request_method=method, - request_url=url, - response_status_code=resp.status, - response_headers=dict(resp.headers), - response_content=response_content_to_log, - ) - return payload - - except (ClientError, asyncio.TimeoutError, socket.gaierror) as e: - # Treat as *connection* problem – optionally retry, else escalate - if retry_count < self.max_retries: - delay = self.retry_delay * (self.retry_backoff_factor ** retry_count) - logging.warning("Connection error. Retrying in %.2fs (%s/%s): %s", delay, retry_count + 1, - self.max_retries, str(e)) - await asyncio.sleep(delay) - return await self.request( - method, - path, - params=params, - data=data, - files=files, - headers=headers, - content_type=content_type, - multipart_parser=multipart_parser, - retry_count=retry_count + 1, - ) - # One final connectivity check for diagnostics - connectivity = await self._check_connectivity(self.base_url) - if connectivity["is_local_issue"]: - raise LocalNetworkError( - "Unable to connect to the API server due to local network issues. " - "Please check your internet connection and try again." - ) from e - raise ApiServerError( - f"The API server at {self.base_url} is currently unreachable. " - f"The service may be experiencing issues. Please try again later." - ) from e - - @staticmethod - def _check_auth(auth_token, comfy_api_key): - """Verify that an auth token is present or comfy_api_key is present""" - if auth_token is None and comfy_api_key is None: - raise Exception("Unauthorized: Please login first to use this node.") - return auth_token or comfy_api_key - - @staticmethod - async def upload_file( - upload_url: str, - file: io.BytesIO | str, - content_type: str | None = None, - max_retries: int = 3, - retry_delay: float = 1.0, - retry_backoff_factor: float = 2.0, - ) -> aiohttp.ClientResponse: - """Upload a file to the API with retry logic. - - Args: - upload_url: The URL to upload to - file: Either a file path string, BytesIO object, or tuple of (file_path, filename) - content_type: Optional mime type to set for the upload - max_retries: Maximum number of retry attempts - retry_delay: Initial delay between retries in seconds - retry_backoff_factor: Multiplier for the delay after each retry - """ - headers: dict[str, str] = {} - skip_auto_headers: set[str] = set() - if content_type: - headers["Content-Type"] = content_type - else: - # tell aiohttp not to add Content-Type that will break the request signature and result in a 403 status. - skip_auto_headers.add("Content-Type") - - # Extract file bytes - if isinstance(file, io.BytesIO): - file.seek(0) - data = file.read() - elif isinstance(file, str): - with open(file, "rb") as f: - data = f.read() - else: - raise ValueError("File must be BytesIO or str path") - - parsed = urlparse(upload_url) - basename = os.path.basename(parsed.path) or parsed.netloc or "upload" - operation_id = f"upload_{basename}_{uuid.uuid4().hex[:8]}" - request_logger.log_request_response( - operation_id=operation_id, - request_method="PUT", - request_url=upload_url, - request_headers=headers, - request_data=f"[File data {len(data)} bytes]", - ) - - delay = retry_delay - for attempt in range(max_retries + 1): - try: - timeout = aiohttp.ClientTimeout(total=None) # honour server side timeouts - async with aiohttp.ClientSession(timeout=timeout) as session: - async with session.put( - upload_url, data=data, headers=headers, skip_auto_headers=skip_auto_headers, - ) as resp: - resp.raise_for_status() - request_logger.log_request_response( - operation_id=operation_id, - request_method="PUT", - request_url=upload_url, - response_status_code=resp.status, - response_headers=dict(resp.headers), - response_content="File uploaded successfully.", - ) - return resp - except (ClientError, asyncio.TimeoutError) as e: - request_logger.log_request_response( - operation_id=operation_id, - request_method="PUT", - request_url=upload_url, - response_status_code=e.status if hasattr(e, "status") else None, - response_headers=dict(e.headers) if hasattr(e, "headers") else None, - response_content=None, - error_message=f"{type(e).__name__}: {str(e)}", - ) - if attempt < max_retries: - logging.warning( - "Upload failed (%s/%s). Retrying in %.2fs. %s", attempt + 1, max_retries, delay, str(e) - ) - await asyncio.sleep(delay) - delay *= retry_backoff_factor - else: - raise NetworkError(f"Failed to upload file after {max_retries + 1} attempts: {e}") from e - - async def _handle_http_error( - self, - exc: ClientResponseError, - operation_id: str, - *req_meta, - retry_count: int, - response_content: dict | str = "", - ) -> dict[str, Any]: - status_code = exc.status - if status_code == 401: - user_friendly = "Unauthorized: Please login first to use this node." - elif status_code == 402: - user_friendly = "Payment Required: Please add credits to your account to use this node." - elif status_code == 409: - user_friendly = "There is a problem with your account. Please contact support@comfy.org." - elif status_code == 429: - user_friendly = "Rate Limit Exceeded: Please try again later." - else: - if isinstance(response_content, dict): - if "error" in response_content and "message" in response_content["error"]: - user_friendly = f"API Error: {response_content['error']['message']}" - if "type" in response_content["error"]: - user_friendly += f" (Type: {response_content['error']['type']})" - else: # Handle cases where error is just a JSON dict with unknown format - user_friendly = f"API Error: {json.dumps(response_content)}" - else: - if len(response_content) < 200: # Arbitrary limit for display - user_friendly = f"API Error (raw): {response_content}" - else: - user_friendly = f"API Error (raw, status {response_content})" - - request_logger.log_request_response( - operation_id=operation_id, - request_method=req_meta[0], - request_url=req_meta[1], - response_status_code=exc.status, - response_headers=dict(req_meta[5]) if req_meta[5] else None, - response_content=response_content, - error_message=f"HTTP Error {exc.status}", - ) - - logging.debug("[DEBUG] API Error: %s (Status: %s)", user_friendly, status_code) - if response_content: - logging.debug("[DEBUG] Response content: %s", response_content) - - # Retry if eligible - if status_code in self.retry_status_codes and retry_count < self.max_retries: - delay = self.retry_delay * (self.retry_backoff_factor ** retry_count) - logging.warning( - "HTTP error %s. Retrying in %.2fs (%s/%s)", - status_code, - delay, - retry_count + 1, - self.max_retries, - ) - await asyncio.sleep(delay) - return await self.request( - req_meta[0], # method - req_meta[1].replace(self.base_url, ""), # path - params=req_meta[2], - data=req_meta[3], - files=req_meta[4], - headers=req_meta[5], - content_type=req_meta[6], - multipart_parser=req_meta[7], - retry_count=retry_count + 1, - ) - - raise Exception(user_friendly) from exc - - @staticmethod - def _unpack_tuple(t): - """Helper to normalise (filename, file, content_type) tuples.""" - if len(t) == 3: - return t - elif len(t) == 2: - return t[0], t[1], "application/octet-stream" - else: - raise ValueError("files tuple must be (filename, file[, content_type])") - - async def _get_session(self) -> aiohttp.ClientSession: - if self._session is None or self._session.closed: - timeout = aiohttp.ClientTimeout(total=self.timeout) - self._session = aiohttp.ClientSession(timeout=timeout) - self._owns_session = True - return self._session - - async def close(self) -> None: - if self._owns_session and self._session and not self._session.closed: - await self._session.close() - - async def __aenter__(self) -> "ApiClient": - """Allow usage as async‑context‑manager – ensures clean teardown""" - return self - - async def __aexit__(self, exc_type, exc, tb): - await self.close() - - -class ApiEndpoint(Generic[T, R]): - """Defines an API endpoint with its request and response types""" - - def __init__( - self, - path: str, - method: HttpMethod, - request_model: Type[T], - response_model: Type[R], - query_params: Optional[dict[str, Any]] = None, - ): - """Initialize an API endpoint definition. - - Args: - path: The URL path for this endpoint, can include placeholders like {id} - method: The HTTP method to use (GET, POST, etc.) - request_model: Pydantic model class that defines the structure and validation rules for API requests to this endpoint - response_model: Pydantic model class that defines the structure and validation rules for API responses from this endpoint - query_params: Optional dictionary of query parameters to include in the request - """ - self.path = path - self.method = method - self.request_model = request_model - self.response_model = response_model - self.query_params = query_params or {} - - -class SynchronousOperation(Generic[T, R]): - """Represents a single synchronous API operation.""" - - def __init__( - self, - endpoint: ApiEndpoint[T, R], - request: T, - files: Optional[dict[str, Any] | list[tuple[str, Any]]] = None, - api_base: str | None = None, - auth_token: Optional[str] = None, - comfy_api_key: Optional[str] = None, - auth_kwargs: Optional[dict[str, str]] = None, - timeout: float = 7200.0, - verify_ssl: bool = True, - content_type: str = "application/json", - multipart_parser: Callable | None = None, - max_retries: int = 3, - retry_delay: float = 1.0, - retry_backoff_factor: float = 2.0, - ) -> None: - self.endpoint = endpoint - self.request = request - self.files = files - self.api_base: str = api_base or args.comfy_api_base - self.auth_token = auth_token - self.comfy_api_key = comfy_api_key - if auth_kwargs is not None: - self.auth_token = auth_kwargs.get("auth_token", self.auth_token) - self.comfy_api_key = auth_kwargs.get("comfy_api_key", self.comfy_api_key) - self.timeout = timeout - self.verify_ssl = verify_ssl - self.content_type = content_type - self.multipart_parser = multipart_parser - self.max_retries = max_retries - self.retry_delay = retry_delay - self.retry_backoff_factor = retry_backoff_factor - - async def execute(self, client: Optional[ApiClient] = None) -> R: - owns_client = client is None - if owns_client: - client = ApiClient( - base_url=self.api_base, - auth_token=self.auth_token, - comfy_api_key=self.comfy_api_key, - timeout=self.timeout, - verify_ssl=self.verify_ssl, - max_retries=self.max_retries, - retry_delay=self.retry_delay, - retry_backoff_factor=self.retry_backoff_factor, - ) - - try: - request_dict: Optional[dict[str, Any]] - if isinstance(self.request, EmptyRequest): - request_dict = None - else: - request_dict = self.request.model_dump(exclude_none=True) - for k, v in list(request_dict.items()): - if isinstance(v, Enum): - request_dict[k] = v.value - - logging.debug("[DEBUG] API Request: %s %s", self.endpoint.method.value, self.endpoint.path) - logging.debug("[DEBUG] Request Data: %s", json.dumps(request_dict, indent=2)) - logging.debug("[DEBUG] Query Params: %s", self.endpoint.query_params) - - response_json = await client.request( - self.endpoint.method.value, - self.endpoint.path, - params=self.endpoint.query_params, - data=request_dict, - files=self.files, - content_type=self.content_type, - multipart_parser=self.multipart_parser, - ) - - logging.debug("=" * 50) - logging.debug("[DEBUG] RESPONSE DETAILS:") - logging.debug("[DEBUG] Status Code: 200 (Success)") - logging.debug("[DEBUG] Response Body: %s", json.dumps(response_json, indent=2)) - logging.debug("=" * 50) - - parsed_response = self.endpoint.response_model.model_validate(response_json) - logging.debug("[DEBUG] Parsed Response: %s", parsed_response) - return parsed_response - finally: - if owns_client: - await client.close() - - -class TaskStatus(str, Enum): - """Enum for task status values""" - - COMPLETED = "completed" - FAILED = "failed" - PENDING = "pending" - - -class PollingOperation(Generic[T, R]): - """Represents an asynchronous API operation that requires polling for completion.""" - - def __init__( - self, - poll_endpoint: ApiEndpoint[EmptyRequest, R], - completed_statuses: list[str], - failed_statuses: list[str], - *, - status_extractor: Callable[[R], Optional[str]], - progress_extractor: Callable[[R], Optional[float]] | None = None, - result_url_extractor: Callable[[R], Optional[str]] | None = None, - price_extractor: Callable[[R], Optional[float]] | None = None, - request: Optional[T] = None, - api_base: str | None = None, - auth_token: Optional[str] = None, - comfy_api_key: Optional[str] = None, - auth_kwargs: Optional[dict[str, str]] = None, - poll_interval: float = 5.0, - max_poll_attempts: int = 120, # Default max polling attempts (10 minutes with 5s interval) - max_retries: int = 3, # Max retries per individual API call - retry_delay: float = 1.0, - retry_backoff_factor: float = 2.0, - estimated_duration: Optional[float] = None, - node_id: Optional[str] = None, - ) -> None: - self.poll_endpoint = poll_endpoint - self.request = request - self.api_base: str = api_base or args.comfy_api_base - self.auth_token = auth_token - self.comfy_api_key = comfy_api_key - if auth_kwargs is not None: - self.auth_token = auth_kwargs.get("auth_token", self.auth_token) - self.comfy_api_key = auth_kwargs.get("comfy_api_key", self.comfy_api_key) - self.poll_interval = poll_interval - self.max_poll_attempts = max_poll_attempts - self.max_retries = max_retries - self.retry_delay = retry_delay - self.retry_backoff_factor = retry_backoff_factor - self.estimated_duration = estimated_duration - self.status_extractor = status_extractor or (lambda x: getattr(x, "status", None)) - self.progress_extractor = progress_extractor - self.result_url_extractor = result_url_extractor - self.price_extractor = price_extractor - self.node_id = node_id - self.completed_statuses = completed_statuses - self.failed_statuses = failed_statuses - self.final_response: Optional[R] = None - self.extracted_price: Optional[float] = None - - async def execute(self, client: Optional[ApiClient] = None) -> R: - owns_client = client is None - if owns_client: - client = ApiClient( - base_url=self.api_base, - auth_token=self.auth_token, - comfy_api_key=self.comfy_api_key, - max_retries=self.max_retries, - retry_delay=self.retry_delay, - retry_backoff_factor=self.retry_backoff_factor, - ) - try: - return await self._poll_until_complete(client) - finally: - if owns_client: - await client.close() - - def _display_text_on_node(self, text: str): - if not self.node_id: - return - if self.extracted_price is not None: - text = f"Price: ${self.extracted_price}\n{text}" - PromptServer.instance.send_progress_text(text, self.node_id) - - def _display_time_progress_on_node(self, time_completed: int | float): - if not self.node_id: - return - if self.estimated_duration is not None: - remaining = max(0, int(self.estimated_duration) - time_completed) - message = f"Task in progress: {time_completed}s (~{remaining}s remaining)" - else: - message = f"Task in progress: {time_completed}s" - self._display_text_on_node(message) - - def _check_task_status(self, response: R) -> TaskStatus: - try: - status = self.status_extractor(response) - if status in self.completed_statuses: - return TaskStatus.COMPLETED - if status in self.failed_statuses: - return TaskStatus.FAILED - return TaskStatus.PENDING - except Exception as e: - logging.error("Error extracting status: %s", e) - return TaskStatus.PENDING - - async def _poll_until_complete(self, client: ApiClient) -> R: - """Poll until the task is complete""" - consecutive_errors = 0 - max_consecutive_errors = min(5, self.max_retries * 2) # Limit consecutive errors - - if self.progress_extractor: - progress = utils.ProgressBar(PROGRESS_BAR_MAX) - - status = TaskStatus.PENDING - for poll_count in range(1, self.max_poll_attempts + 1): - try: - logging.debug("[DEBUG] Polling attempt #%s", poll_count) - - request_dict = None if self.request is None else self.request.model_dump(exclude_none=True) - - if poll_count == 1: - logging.debug( - "[DEBUG] Poll Request: %s %s", - self.poll_endpoint.method.value, - self.poll_endpoint.path, - ) - logging.debug( - "[DEBUG] Poll Request Data: %s", - json.dumps(request_dict, indent=2) if request_dict else "None", - ) - - # Query task status - resp = await client.request( - self.poll_endpoint.method.value, - self.poll_endpoint.path, - params=self.poll_endpoint.query_params, - data=request_dict, - ) - consecutive_errors = 0 # reset on success - response_obj: R = self.poll_endpoint.response_model.model_validate(resp) - - # Check if task is complete - status = self._check_task_status(response_obj) - logging.debug("[DEBUG] Task Status: %s", status) - - # If progress extractor is provided, extract progress - if self.progress_extractor: - new_progress = self.progress_extractor(response_obj) - if new_progress is not None: - progress.update_absolute(new_progress, total=PROGRESS_BAR_MAX) - - if self.price_extractor: - price = self.price_extractor(response_obj) - if price is not None: - self.extracted_price = price - - if status == TaskStatus.COMPLETED: - message = "Task completed successfully" - if self.result_url_extractor: - result_url = self.result_url_extractor(response_obj) - if result_url: - message = f"Result URL: {result_url}" - logging.debug("[DEBUG] %s", message) - self._display_text_on_node(message) - self.final_response = response_obj - if self.progress_extractor: - progress.update(100) - return self.final_response - if status == TaskStatus.FAILED: - message = f"Task failed: {json.dumps(resp)}" - logging.error("[DEBUG] %s", message) - raise Exception(message) - logging.debug("[DEBUG] Task still pending, continuing to poll...") - # Task pending – wait - for i in range(int(self.poll_interval)): - self._display_time_progress_on_node((poll_count - 1) * self.poll_interval + i) - await asyncio.sleep(1) - - except (LocalNetworkError, ApiServerError, NetworkError) as e: - consecutive_errors += 1 - if consecutive_errors >= max_consecutive_errors: - raise Exception( - f"Polling aborted after {consecutive_errors} network errors: {str(e)}" - ) from e - logging.warning( - "Network error (%s/%s): %s", - consecutive_errors, - max_consecutive_errors, - str(e), - ) - await asyncio.sleep(self.poll_interval) - except Exception as e: - # For other errors, increment count and potentially abort - consecutive_errors += 1 - if consecutive_errors >= max_consecutive_errors or status == TaskStatus.FAILED: - raise Exception( - f"Polling aborted after {consecutive_errors} consecutive errors: {str(e)}" - ) from e - - logging.error("[DEBUG] Polling error: %s", str(e)) - logging.warning( - "Error during polling (attempt %s/%s): %s. Will retry in %s seconds.", - poll_count, - self.max_poll_attempts, - str(e), - self.poll_interval, - ) - await asyncio.sleep(self.poll_interval) - - # If we've exhausted all polling attempts - raise Exception( - f"Polling timed out after {self.max_poll_attempts} attempts (" f"{self.max_poll_attempts * self.poll_interval} seconds). " - "The operation may still be running on the server but is taking longer than expected." - ) diff --git a/comfy_api_nodes/nodes_rodin.py b/comfy_api_nodes/nodes_rodin.py index ad4029236..e60e7a6d6 100644 --- a/comfy_api_nodes/nodes_rodin.py +++ b/comfy_api_nodes/nodes_rodin.py @@ -5,12 +5,9 @@ Rodin API docs: https://developer.hyper3d.ai/ """ -from __future__ import annotations from inspect import cleandoc import folder_paths as comfy_paths -import aiohttp import os -import asyncio import logging import math from typing import Optional @@ -26,11 +23,11 @@ from comfy_api_nodes.apis.rodin_api import ( Rodin3DDownloadResponse, JobStatus, ) -from comfy_api_nodes.apis.client import ( +from comfy_api_nodes.util import ( + sync_op, + poll_op, ApiEndpoint, - HttpMethod, - SynchronousOperation, - PollingOperation, + download_url_to_bytesio, ) from comfy_api.latest import ComfyExtension, IO @@ -121,35 +118,31 @@ def tensor_to_filelike(tensor, max_pixels: int = 2048*2048): async def create_generate_task( + cls: type[IO.ComfyNode], images=None, seed=1, material="PBR", quality_override=18000, tier="Regular", mesh_mode="Quad", - TAPose = False, - auth_kwargs: Optional[dict[str, str]] = None, + ta_pose: bool = False, ): if images is None: raise Exception("Rodin 3D generate requires at least 1 image.") if len(images) > 5: raise Exception("Rodin 3D generate requires up to 5 image.") - path = "/proxy/rodin/api/v2/rodin" - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path=path, - method=HttpMethod.POST, - request_model=Rodin3DGenerateRequest, - response_model=Rodin3DGenerateResponse, - ), - request=Rodin3DGenerateRequest( + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/rodin/api/v2/rodin", method="POST"), + response_model=Rodin3DGenerateResponse, + data=Rodin3DGenerateRequest( seed=seed, tier=tier, material=material, quality_override=quality_override, mesh_mode=mesh_mode, - TAPose=TAPose, + TAPose=ta_pose, ), files=[ ( @@ -159,11 +152,8 @@ async def create_generate_task( for image in images if image is not None ], content_type="multipart/form-data", - auth_kwargs=auth_kwargs, ) - response = await operation.execute() - if hasattr(response, "error"): error_message = f"Rodin3D Create 3D generate Task Failed. Message: {response.message}, error: {response.error}" logging.error(error_message) @@ -187,74 +177,46 @@ def check_rodin_status(response: Rodin3DCheckStatusResponse) -> str: return "DONE" return "Generating" +def extract_progress(response: Rodin3DCheckStatusResponse) -> Optional[int]: + if not response.jobs: + return None + completed_count = sum(1 for job in response.jobs if job.status == JobStatus.Done) + return int((completed_count / len(response.jobs)) * 100) -async def poll_for_task_status( - subscription_key, auth_kwargs: Optional[dict[str, str]] = None, -) -> Rodin3DCheckStatusResponse: - poll_operation = PollingOperation( - poll_endpoint=ApiEndpoint( - path="/proxy/rodin/api/v2/status", - method=HttpMethod.POST, - request_model=Rodin3DCheckStatusRequest, - response_model=Rodin3DCheckStatusResponse, - ), - request=Rodin3DCheckStatusRequest(subscription_key=subscription_key), - completed_statuses=["DONE"], - failed_statuses=["FAILED"], - status_extractor=check_rodin_status, - poll_interval=3.0, - auth_kwargs=auth_kwargs, - ) + +async def poll_for_task_status(subscription_key: str, cls: type[IO.ComfyNode]) -> Rodin3DCheckStatusResponse: logging.info("[ Rodin3D API - CheckStatus ] Generate Start!") - return await poll_operation.execute() - - -async def get_rodin_download_list(uuid, auth_kwargs: Optional[dict[str, str]] = None) -> Rodin3DDownloadResponse: - logging.info("[ Rodin3D API - Downloading ] Generate Successfully!") - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/rodin/api/v2/download", - method=HttpMethod.POST, - request_model=Rodin3DDownloadRequest, - response_model=Rodin3DDownloadResponse, - ), - request=Rodin3DDownloadRequest(task_uuid=uuid), - auth_kwargs=auth_kwargs, + return await poll_op( + cls, + ApiEndpoint(path="/proxy/rodin/api/v2/status", method="POST"), + response_model=Rodin3DCheckStatusResponse, + data=Rodin3DCheckStatusRequest(subscription_key=subscription_key), + status_extractor=check_rodin_status, + progress_extractor=extract_progress, ) - return await operation.execute() -async def download_files(url_list, task_uuid): +async def get_rodin_download_list(uuid: str, cls: type[IO.ComfyNode]) -> Rodin3DDownloadResponse: + logging.info("[ Rodin3D API - Downloading ] Generate Successfully!") + return await sync_op( + cls, + ApiEndpoint(path="/proxy/rodin/api/v2/download", method="POST"), + response_model=Rodin3DDownloadResponse, + data=Rodin3DDownloadRequest(task_uuid=uuid), + monitor_progress=False, + ) + + +async def download_files(url_list, task_uuid: str): result_folder_name = f"Rodin3D_{task_uuid}" save_path = os.path.join(comfy_paths.get_output_directory(), result_folder_name) os.makedirs(save_path, exist_ok=True) model_file_path = None - async with aiohttp.ClientSession() as session: - for i in url_list.list: - file_path = os.path.join(save_path, i.name) - if file_path.endswith(".glb"): - model_file_path = os.path.join(result_folder_name, i.name) - logging.info("[ Rodin3D API - download_files ] Downloading file: %s", file_path) - max_retries = 5 - for attempt in range(max_retries): - try: - async with session.get(i.url) as resp: - resp.raise_for_status() - with open(file_path, "wb") as f: - async for chunk in resp.content.iter_chunked(32 * 1024): - f.write(chunk) - break - except Exception as e: - logging.info("[ Rodin3D API - download_files ] Error downloading %s:%s", file_path, str(e)) - if attempt < max_retries - 1: - logging.info("Retrying...") - await asyncio.sleep(2) - else: - logging.info( - "[ Rodin3D API - download_files ] Failed to download %s after %s attempts.", - file_path, - max_retries, - ) + for i in url_list.list: + file_path = os.path.join(save_path, i.name) + if file_path.endswith(".glb"): + model_file_path = os.path.join(result_folder_name, i.name) + await download_url_to_bytesio(i.url, file_path) return model_file_path @@ -276,6 +238,7 @@ class Rodin3D_Regular(IO.ComfyNode): hidden=[ IO.Hidden.auth_token_comfy_org, IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, ], is_api_node=True, ) @@ -294,21 +257,17 @@ class Rodin3D_Regular(IO.ComfyNode): for i in range(num_images): m_images.append(Images[i]) mesh_mode, quality_override = get_quality_mode(Polygon_count) - auth = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } task_uuid, subscription_key = await create_generate_task( + cls, images=m_images, seed=Seed, material=Material_Type, quality_override=quality_override, tier=tier, mesh_mode=mesh_mode, - auth_kwargs=auth, ) - await poll_for_task_status(subscription_key, auth_kwargs=auth) - download_list = await get_rodin_download_list(task_uuid, auth_kwargs=auth) + await poll_for_task_status(subscription_key, cls) + download_list = await get_rodin_download_list(task_uuid, cls) model = await download_files(download_list, task_uuid) return IO.NodeOutput(model) @@ -332,6 +291,7 @@ class Rodin3D_Detail(IO.ComfyNode): hidden=[ IO.Hidden.auth_token_comfy_org, IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, ], is_api_node=True, ) @@ -350,21 +310,17 @@ class Rodin3D_Detail(IO.ComfyNode): for i in range(num_images): m_images.append(Images[i]) mesh_mode, quality_override = get_quality_mode(Polygon_count) - auth = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } task_uuid, subscription_key = await create_generate_task( + cls, images=m_images, seed=Seed, material=Material_Type, quality_override=quality_override, tier=tier, mesh_mode=mesh_mode, - auth_kwargs=auth, ) - await poll_for_task_status(subscription_key, auth_kwargs=auth) - download_list = await get_rodin_download_list(task_uuid, auth_kwargs=auth) + await poll_for_task_status(subscription_key, cls) + download_list = await get_rodin_download_list(task_uuid, cls) model = await download_files(download_list, task_uuid) return IO.NodeOutput(model) @@ -388,6 +344,7 @@ class Rodin3D_Smooth(IO.ComfyNode): hidden=[ IO.Hidden.auth_token_comfy_org, IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, ], is_api_node=True, ) @@ -400,27 +357,22 @@ class Rodin3D_Smooth(IO.ComfyNode): Material_Type, Polygon_count, ) -> IO.NodeOutput: - tier = "Smooth" num_images = Images.shape[0] m_images = [] for i in range(num_images): m_images.append(Images[i]) mesh_mode, quality_override = get_quality_mode(Polygon_count) - auth = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } task_uuid, subscription_key = await create_generate_task( + cls, images=m_images, seed=Seed, material=Material_Type, quality_override=quality_override, - tier=tier, + tier="Smooth", mesh_mode=mesh_mode, - auth_kwargs=auth, ) - await poll_for_task_status(subscription_key, auth_kwargs=auth) - download_list = await get_rodin_download_list(task_uuid, auth_kwargs=auth) + await poll_for_task_status(subscription_key, cls) + download_list = await get_rodin_download_list(task_uuid, cls) model = await download_files(download_list, task_uuid) return IO.NodeOutput(model) @@ -451,6 +403,7 @@ class Rodin3D_Sketch(IO.ComfyNode): hidden=[ IO.Hidden.auth_token_comfy_org, IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, ], is_api_node=True, ) @@ -461,29 +414,21 @@ class Rodin3D_Sketch(IO.ComfyNode): Images, Seed, ) -> IO.NodeOutput: - tier = "Sketch" num_images = Images.shape[0] m_images = [] for i in range(num_images): m_images.append(Images[i]) - material_type = "PBR" - quality_override = 18000 - mesh_mode = "Quad" - auth = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } task_uuid, subscription_key = await create_generate_task( + cls, images=m_images, seed=Seed, - material=material_type, - quality_override=quality_override, - tier=tier, - mesh_mode=mesh_mode, - auth_kwargs=auth, + material="PBR", + quality_override=18000, + tier="Sketch", + mesh_mode="Quad", ) - await poll_for_task_status(subscription_key, auth_kwargs=auth) - download_list = await get_rodin_download_list(task_uuid, auth_kwargs=auth) + await poll_for_task_status(subscription_key, cls) + download_list = await get_rodin_download_list(task_uuid, cls) model = await download_files(download_list, task_uuid) return IO.NodeOutput(model) @@ -522,6 +467,7 @@ class Rodin3D_Gen2(IO.ComfyNode): hidden=[ IO.Hidden.auth_token_comfy_org, IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, ], is_api_node=True, ) @@ -541,22 +487,18 @@ class Rodin3D_Gen2(IO.ComfyNode): for i in range(num_images): m_images.append(Images[i]) mesh_mode, quality_override = get_quality_mode(Polygon_count) - auth = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } task_uuid, subscription_key = await create_generate_task( + cls, images=m_images, seed=Seed, material=Material_Type, quality_override=quality_override, tier=tier, mesh_mode=mesh_mode, - TAPose=TAPose, - auth_kwargs=auth, + ta_pose=TAPose, ) - await poll_for_task_status(subscription_key, auth_kwargs=auth) - download_list = await get_rodin_download_list(task_uuid, auth_kwargs=auth) + await poll_for_task_status(subscription_key, cls) + download_list = await get_rodin_download_list(task_uuid, cls) model = await download_files(download_list, task_uuid) return IO.NodeOutput(model) diff --git a/comfy_api_nodes/util/client.py b/comfy_api_nodes/util/client.py index 65bb35f0f..2d5dcd648 100644 --- a/comfy_api_nodes/util/client.py +++ b/comfy_api_nodes/util/client.py @@ -16,9 +16,9 @@ from pydantic import BaseModel from comfy import utils from comfy_api.latest import IO -from comfy_api_nodes.apis import request_logger from server import PromptServer +from . import request_logger from ._helpers import ( default_base_url, get_auth_header, @@ -77,7 +77,7 @@ class _PollUIState: _RETRY_STATUS = {408, 429, 500, 502, 503, 504} -COMPLETED_STATUSES = ["succeeded", "succeed", "success", "completed", "finished"] +COMPLETED_STATUSES = ["succeeded", "succeed", "success", "completed", "finished", "done"] FAILED_STATUSES = ["cancelled", "canceled", "fail", "failed", "error"] QUEUED_STATUSES = ["created", "queued", "queueing", "submitted"] diff --git a/comfy_api_nodes/util/download_helpers.py b/comfy_api_nodes/util/download_helpers.py index 364874bed..14207dc68 100644 --- a/comfy_api_nodes/util/download_helpers.py +++ b/comfy_api_nodes/util/download_helpers.py @@ -12,8 +12,8 @@ from aiohttp.client_exceptions import ClientError, ContentTypeError from comfy_api.input_impl import VideoFromFile from comfy_api.latest import IO as COMFY_IO -from comfy_api_nodes.apis import request_logger +from . import request_logger from ._helpers import ( default_base_url, get_auth_header, diff --git a/comfy_api_nodes/apis/request_logger.py b/comfy_api_nodes/util/request_logger.py similarity index 100% rename from comfy_api_nodes/apis/request_logger.py rename to comfy_api_nodes/util/request_logger.py index c6974d35c..ac52e2eab 100644 --- a/comfy_api_nodes/apis/request_logger.py +++ b/comfy_api_nodes/util/request_logger.py @@ -1,11 +1,11 @@ from __future__ import annotations -import os import datetime +import hashlib import json import logging +import os import re -import hashlib from typing import Any import folder_paths diff --git a/comfy_api_nodes/util/upload_helpers.py b/comfy_api_nodes/util/upload_helpers.py index 7bfc61704..632450d9b 100644 --- a/comfy_api_nodes/util/upload_helpers.py +++ b/comfy_api_nodes/util/upload_helpers.py @@ -13,8 +13,8 @@ from pydantic import BaseModel, Field from comfy_api.latest import IO, Input from comfy_api.util import VideoCodec, VideoContainer -from comfy_api_nodes.apis import request_logger +from . import request_logger from ._helpers import is_processing_interrupted, sleep_with_interrupt from .client import ( ApiEndpoint, From 97f198e4215680a83749ba95849f3cdcfa7aa64a Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 5 Nov 2025 15:07:35 -0800 Subject: [PATCH 068/231] Fix qwen controlnet regression. (#10657) --- comfy/ldm/qwen_image/controlnet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/ldm/qwen_image/controlnet.py b/comfy/ldm/qwen_image/controlnet.py index 92ac3cf0a..a6d408104 100644 --- a/comfy/ldm/qwen_image/controlnet.py +++ b/comfy/ldm/qwen_image/controlnet.py @@ -44,7 +44,7 @@ class QwenImageControlNetModel(QwenImageTransformer2DModel): txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2, ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) // 2)) txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3) ids = torch.cat((txt_ids, img_ids), dim=1) - image_rotary_emb = self.pe_embedder(ids).squeeze(1).unsqueeze(2).to(x.dtype) + image_rotary_emb = self.pe_embedder(ids).to(x.dtype).contiguous() del ids, txt_ids, img_ids hidden_states = self.img_in(hidden_states) + self.controlnet_x_embedder(hint) From 1d69245981f9fb3861018613246042296d887dd3 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 5 Nov 2025 15:08:13 -0800 Subject: [PATCH 069/231] Enable pinned memory by default on Nvidia. (#10656) Removed the --fast pinned_memory flag. You can use --disable-pinned-memory to disable it. Please report if it causes any issues. --- comfy/cli_args.py | 3 ++- comfy/model_management.py | 22 +++++++++------------- 2 files changed, 11 insertions(+), 14 deletions(-) diff --git a/comfy/cli_args.py b/comfy/cli_args.py index 3947e62a8..2f30b72d2 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -145,10 +145,11 @@ class PerformanceFeature(enum.Enum): Fp8MatrixMultiplication = "fp8_matrix_mult" CublasOps = "cublas_ops" AutoTune = "autotune" - PinnedMem = "pinned_memory" parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. This is used to test new features so using it might crash your comfyui. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: {}".format(" ".join(map(lambda c: c.value, PerformanceFeature)))) +parser.add_argument("--disable-pinned-memory", action="store_true", help="Disable pinned memory use.") + parser.add_argument("--mmap-torch-files", action="store_true", help="Use mmap when loading ckpt/pt files.") parser.add_argument("--disable-mmap", action="store_true", help="Don't use mmap when loading safetensors.") diff --git a/comfy/model_management.py b/comfy/model_management.py index 0d040e55e..4d13c52c1 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -1085,22 +1085,21 @@ def cast_to_device(tensor, device, dtype, copy=False): PINNED_MEMORY = {} TOTAL_PINNED_MEMORY = 0 -if PerformanceFeature.PinnedMem in args.fast: - if WINDOWS: - MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.45 # Windows limit is apparently 50% - else: - MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.95 -else: - MAX_PINNED_MEMORY = -1 +MAX_PINNED_MEMORY = -1 +if not args.disable_pinned_memory: + if is_nvidia(): + if WINDOWS: + MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.45 # Windows limit is apparently 50% + else: + MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.95 + logging.info("Enabled pinned memory {}".format(MAX_PINNED_MEMORY // (1024 * 1024))) + def pin_memory(tensor): global TOTAL_PINNED_MEMORY if MAX_PINNED_MEMORY <= 0: return False - if not is_nvidia(): - return False - if not is_device_cpu(tensor.device): return False @@ -1121,9 +1120,6 @@ def unpin_memory(tensor): if MAX_PINNED_MEMORY <= 0: return False - if not is_nvidia(): - return False - if not is_device_cpu(tensor.device): return False From 09dc24c8a982776abd5cb2f71e3d041139e1d5b2 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 5 Nov 2025 16:11:15 -0800 Subject: [PATCH 070/231] Pinned mem also seems to work on AMD. (#10658) --- comfy/model_management.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 4d13c52c1..7a30c4bec 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -1087,7 +1087,7 @@ PINNED_MEMORY = {} TOTAL_PINNED_MEMORY = 0 MAX_PINNED_MEMORY = -1 if not args.disable_pinned_memory: - if is_nvidia(): + if is_nvidia() or is_amd(): if WINDOWS: MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.45 # Windows limit is apparently 50% else: From e05c90712670fa4a2ffebd44046fc78747193a36 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Thu, 6 Nov 2025 01:11:30 -0800 Subject: [PATCH 071/231] Clarify release cycle. (#10667) --- README.md | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 4204777e9..8142f595b 100644 --- a/README.md +++ b/README.md @@ -112,10 +112,11 @@ Workflow examples can be found on the [Examples page](https://comfyanonymous.git ## Release Process -ComfyUI follows a weekly release cycle targeting Friday but this regularly changes because of model releases or large changes to the codebase. There are three interconnected repositories: +ComfyUI follows a weekly release cycle targeting Monday but this regularly changes because of model releases or large changes to the codebase. There are three interconnected repositories: 1. **[ComfyUI Core](https://github.com/comfyanonymous/ComfyUI)** - - Releases a new stable version (e.g., v0.7.0) + - Releases a new stable version (e.g., v0.7.0) roughly every week. + - Commits outside of the stable release tags may be very unstable and break many custom nodes. - Serves as the foundation for the desktop release 2. **[ComfyUI Desktop](https://github.com/Comfy-Org/desktop)** From eb1c42f6498ce44aef4dbed3bb665ac98a28254d Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Thu, 6 Nov 2025 17:24:28 -0800 Subject: [PATCH 072/231] Tell users they need to upload their logs in bug reports. (#10671) --- .github/ISSUE_TEMPLATE/bug-report.yml | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/.github/ISSUE_TEMPLATE/bug-report.yml b/.github/ISSUE_TEMPLATE/bug-report.yml index 3cf2717b7..6556677e0 100644 --- a/.github/ISSUE_TEMPLATE/bug-report.yml +++ b/.github/ISSUE_TEMPLATE/bug-report.yml @@ -8,13 +8,15 @@ body: Before submitting a **Bug Report**, please ensure the following: - **1:** You are running the latest version of ComfyUI. - - **2:** You have looked at the existing bug reports and made sure this isn't already reported. + - **2:** You have your ComfyUI logs and relevant workflow on hand and will post them in this bug report. - **3:** You confirmed that the bug is not caused by a custom node. You can disable all custom nodes by passing - `--disable-all-custom-nodes` command line argument. + `--disable-all-custom-nodes` command line argument. If you have custom node try updating them to the latest version. - **4:** This is an actual bug in ComfyUI, not just a support question. A bug is when you can specify exact steps to replicate what went wrong and others will be able to repeat your steps and see the same issue happen. - If unsure, ask on the [ComfyUI Matrix Space](https://app.element.io/#/room/%23comfyui_space%3Amatrix.org) or the [Comfy Org Discord](https://discord.gg/comfyorg) first. + ## Very Important + + Please make sure that you post ALL your ComfyUI logs in the bug report. A bug report without logs will likely be ignored. - type: checkboxes id: custom-nodes-test attributes: From cf97b033ee80cf245b4592d42f89e6de67e409a4 Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Fri, 7 Nov 2025 12:20:48 +1000 Subject: [PATCH 073/231] mm: guard against double pin and unpin explicitly (#10672) As commented, if you let cuda be the one to detect double pin/unpinning it actually creates an asyc GPU error. --- comfy/model_management.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/comfy/model_management.py b/comfy/model_management.py index 7a30c4bec..a13b24cea 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -1103,6 +1103,12 @@ def pin_memory(tensor): if not is_device_cpu(tensor.device): return False + if tensor.is_pinned(): + #NOTE: Cuda does detect when a tensor is already pinned and would + #error below, but there are proven cases where this also queues an error + #on the GPU async. So dont trust the CUDA API and guard here + return False + size = tensor.numel() * tensor.element_size() if (TOTAL_PINNED_MEMORY + size) > MAX_PINNED_MEMORY: return False @@ -1123,6 +1129,12 @@ def unpin_memory(tensor): if not is_device_cpu(tensor.device): return False + if not tensor.is_pinned(): + #NOTE: Cuda does detect when a tensor is already pinned and would + #error below, but there are proven cases where this also queues an error + #on the GPU async. So dont trust the CUDA API and guard here + return False + ptr = tensor.data_ptr() if torch.cuda.cudart().cudaHostUnregister(ptr) == 0: TOTAL_PINNED_MEMORY -= PINNED_MEMORY.pop(ptr) From a1a70362ca376cff05a0514e0ce771ab26d92fd9 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Fri, 7 Nov 2025 08:15:05 -0800 Subject: [PATCH 074/231] Only unpin tensor if it was pinned by ComfyUI (#10677) --- comfy/model_management.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index a13b24cea..7012df858 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -1129,13 +1129,18 @@ def unpin_memory(tensor): if not is_device_cpu(tensor.device): return False - if not tensor.is_pinned(): - #NOTE: Cuda does detect when a tensor is already pinned and would - #error below, but there are proven cases where this also queues an error - #on the GPU async. So dont trust the CUDA API and guard here + ptr = tensor.data_ptr() + size = tensor.numel() * tensor.element_size() + + size_stored = PINNED_MEMORY.get(ptr, None) + if size_stored is None: + logging.warning("Tried to unpin tensor not pinned by ComfyUI") + return False + + if size != size_stored: + logging.warning("Size of pinned tensor changed") return False - ptr = tensor.data_ptr() if torch.cuda.cudart().cudaHostUnregister(ptr) == 0: TOTAL_PINNED_MEMORY -= PINNED_MEMORY.pop(ptr) if len(PINNED_MEMORY) == 0: From 2abd2b5c2049a9625b342bcb7decedd5d1645f66 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sat, 8 Nov 2025 12:52:02 -0800 Subject: [PATCH 075/231] Make ScaleROPE node work on Flux. (#10686) --- comfy/ldm/flux/model.py | 22 +++++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) diff --git a/comfy/ldm/flux/model.py b/comfy/ldm/flux/model.py index 14f90cea5..b9d36f202 100644 --- a/comfy/ldm/flux/model.py +++ b/comfy/ldm/flux/model.py @@ -210,7 +210,7 @@ class Flux(nn.Module): img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels) return img - def process_img(self, x, index=0, h_offset=0, w_offset=0): + def process_img(self, x, index=0, h_offset=0, w_offset=0, transformer_options={}): bs, c, h, w = x.shape patch_size = self.patch_size x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size)) @@ -222,10 +222,22 @@ class Flux(nn.Module): h_offset = ((h_offset + (patch_size // 2)) // patch_size) w_offset = ((w_offset + (patch_size // 2)) // patch_size) - img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype) + steps_h = h_len + steps_w = w_len + + rope_options = transformer_options.get("rope_options", None) + if rope_options is not None: + h_len = (h_len - 1.0) * rope_options.get("scale_y", 1.0) + 1.0 + w_len = (w_len - 1.0) * rope_options.get("scale_x", 1.0) + 1.0 + + index += rope_options.get("shift_t", 0.0) + h_offset += rope_options.get("shift_y", 0.0) + w_offset += rope_options.get("shift_x", 0.0) + + img_ids = torch.zeros((steps_h, steps_w, 3), device=x.device, dtype=x.dtype) img_ids[:, :, 0] = img_ids[:, :, 1] + index - img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1) - img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0) + img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=steps_h, device=x.device, dtype=x.dtype).unsqueeze(1) + img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=steps_w, device=x.device, dtype=x.dtype).unsqueeze(0) return img, repeat(img_ids, "h w c -> b (h w) c", b=bs) def forward(self, x, timestep, context, y=None, guidance=None, ref_latents=None, control=None, transformer_options={}, **kwargs): @@ -241,7 +253,7 @@ class Flux(nn.Module): h_len = ((h_orig + (patch_size // 2)) // patch_size) w_len = ((w_orig + (patch_size // 2)) // patch_size) - img, img_ids = self.process_img(x) + img, img_ids = self.process_img(x, transformer_options=transformer_options) img_tokens = img.shape[1] if ref_latents is not None: h = 0 From e632e5de281b91dd7199636dd6d82126fbfb07d5 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sun, 9 Nov 2025 15:06:39 -0800 Subject: [PATCH 076/231] Add logging for model unloading. (#10692) --- comfy/model_patcher.py | 1 + 1 file changed, 1 insertion(+) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 5a31a8734..17e06a869 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -909,6 +909,7 @@ class ModelPatcher: self.model.model_lowvram = True self.model.lowvram_patch_counter += patch_counter self.model.model_loaded_weight_memory -= memory_freed + logging.info("loaded partially: {:.2f} MB loaded, lowvram patches: {}".format(self.model.model_loaded_weight_memory / (1024 * 1024), self.model.lowvram_patch_counter)) return memory_freed def partially_load(self, device_to, extra_memory=0, force_patch_weights=False): From dea899f22125d38a8b48147d6cce89a2b659fdeb Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sun, 9 Nov 2025 15:51:33 -0800 Subject: [PATCH 077/231] Unload weights if vram usage goes up between runs. (#10690) --- comfy/model_management.py | 11 +++++++++-- comfy/model_patcher.py | 20 +++++++++++++------- 2 files changed, 22 insertions(+), 9 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 7012df858..a4410f2ec 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -503,7 +503,11 @@ class LoadedModel: use_more_vram = lowvram_model_memory if use_more_vram == 0: use_more_vram = 1e32 - self.model_use_more_vram(use_more_vram, force_patch_weights=force_patch_weights) + if use_more_vram > 0: + self.model_use_more_vram(use_more_vram, force_patch_weights=force_patch_weights) + else: + self.model.partially_unload(self.model.offload_device, -use_more_vram, force_patch_weights=force_patch_weights) + real_model = self.model.model if is_intel_xpu() and not args.disable_ipex_optimize and 'ipex' in globals() and real_model is not None: @@ -689,7 +693,10 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu current_free_mem = get_free_memory(torch_dev) + loaded_memory lowvram_model_memory = max(128 * 1024 * 1024, (current_free_mem - minimum_memory_required), min(current_free_mem * MIN_WEIGHT_MEMORY_RATIO, current_free_mem - minimum_inference_memory())) - lowvram_model_memory = max(0.1, lowvram_model_memory - loaded_memory) + lowvram_model_memory = lowvram_model_memory - loaded_memory + + if lowvram_model_memory == 0: + lowvram_model_memory = 0.1 if vram_set_state == VRAMState.NO_VRAM: lowvram_model_memory = 0.1 diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 17e06a869..68b0a9192 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -843,7 +843,7 @@ class ModelPatcher: self.object_patches_backup.clear() - def partially_unload(self, device_to, memory_to_free=0): + def partially_unload(self, device_to, memory_to_free=0, force_patch_weights=False): with self.use_ejected(): hooks_unpatched = False memory_freed = 0 @@ -887,13 +887,19 @@ class ModelPatcher: module_mem += move_weight_functions(m, device_to) if lowvram_possible: if weight_key in self.patches: - _, set_func, convert_func = get_key_weight(self.model, weight_key) - m.weight_function.append(LowVramPatch(weight_key, self.patches, convert_func, set_func)) - patch_counter += 1 + if force_patch_weights: + self.patch_weight_to_device(weight_key) + else: + _, set_func, convert_func = get_key_weight(self.model, weight_key) + m.weight_function.append(LowVramPatch(weight_key, self.patches, convert_func, set_func)) + patch_counter += 1 if bias_key in self.patches: - _, set_func, convert_func = get_key_weight(self.model, bias_key) - m.bias_function.append(LowVramPatch(bias_key, self.patches, convert_func, set_func)) - patch_counter += 1 + if force_patch_weights: + self.patch_weight_to_device(bias_key) + else: + _, set_func, convert_func = get_key_weight(self.model, bias_key) + m.bias_function.append(LowVramPatch(bias_key, self.patches, convert_func, set_func)) + patch_counter += 1 cast_weight = True if cast_weight: From c350009236e5d172a3050c04043ea70a301378ca Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Mon, 10 Nov 2025 13:52:11 +1000 Subject: [PATCH 078/231] ops: Put weight cast on the offload stream (#10697) This needs to be on the offload stream. This reproduced a black screen with low resolution images on a slow bus when using FP8. --- comfy/ops.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/comfy/ops.py b/comfy/ops.py index 733bff99d..96dffa85d 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -110,9 +110,9 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of for f in s.bias_function: bias = f(bias) - weight = weight.to(dtype=dtype) - if weight_has_function: + if weight_has_function or weight.dtype != dtype: with wf_context: + weight = weight.to(dtype=dtype) for f in s.weight_function: weight = f(weight) From 5ebcab3c7d974963a89cecd37296a22fdb73bd2b Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 10 Nov 2025 12:35:29 -0800 Subject: [PATCH 079/231] Update CI workflow to remove dead macOS runner. (#10704) * Update CI workflow to remove dead macOS runner. * revert * revert --- .github/workflows/test-ci.yml | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/.github/workflows/test-ci.yml b/.github/workflows/test-ci.yml index 418dca0ab..1660ec8e3 100644 --- a/.github/workflows/test-ci.yml +++ b/.github/workflows/test-ci.yml @@ -21,14 +21,15 @@ jobs: fail-fast: false matrix: # os: [macos, linux, windows] - os: [macos, linux] - python_version: ["3.9", "3.10", "3.11", "3.12"] + # os: [macos, linux] + os: [linux] + python_version: ["3.10", "3.11", "3.12"] cuda_version: ["12.1"] torch_version: ["stable"] include: - - os: macos - runner_label: [self-hosted, macOS] - flags: "--use-pytorch-cross-attention" + # - os: macos + # runner_label: [self-hosted, macOS] + # flags: "--use-pytorch-cross-attention" - os: linux runner_label: [self-hosted, Linux] flags: "" @@ -73,14 +74,15 @@ jobs: strategy: fail-fast: false matrix: - os: [macos, linux] + # os: [macos, linux] + os: [linux] python_version: ["3.11"] cuda_version: ["12.1"] torch_version: ["nightly"] include: - - os: macos - runner_label: [self-hosted, macOS] - flags: "--use-pytorch-cross-attention" + # - os: macos + # runner_label: [self-hosted, macOS] + # flags: "--use-pytorch-cross-attention" - os: linux runner_label: [self-hosted, Linux] flags: "" From 119941174704081a16a4c3f303d99f2fb1e95cde Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 11 Nov 2025 16:33:30 -0800 Subject: [PATCH 080/231] Don't pin tensor if not a torch.nn.parameter.Parameter (#10718) --- comfy/model_management.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/comfy/model_management.py b/comfy/model_management.py index a4410f2ec..d8913082a 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -1107,6 +1107,9 @@ def pin_memory(tensor): if MAX_PINNED_MEMORY <= 0: return False + if type(tensor) is not torch.nn.parameter.Parameter: + return False + if not is_device_cpu(tensor.device): return False @@ -1116,6 +1119,9 @@ def pin_memory(tensor): #on the GPU async. So dont trust the CUDA API and guard here return False + if not tensor.is_contiguous(): + return False + size = tensor.numel() * tensor.element_size() if (TOTAL_PINNED_MEMORY + size) > MAX_PINNED_MEMORY: return False From e1d85e7577d8f6355bd4cb3449bcb0a7e5f80cb8 Mon Sep 17 00:00:00 2001 From: Qiacheng Li Date: Wed, 12 Nov 2025 12:21:05 -0800 Subject: [PATCH 081/231] Update README.md for Intel Arc GPU installation, remove IPEX (#10729) IPEX is no longer needed for Intel Arc GPUs. Removing instruction to setup ipex. --- README.md | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/README.md b/README.md index 8142f595b..9e28803a2 100644 --- a/README.md +++ b/README.md @@ -242,7 +242,7 @@ RDNA 4 (RX 9000 series): ### Intel GPUs (Windows and Linux) -(Option 1) Intel Arc GPU users can install native PyTorch with torch.xpu support using pip. More information can be found [here](https://pytorch.org/docs/main/notes/get_start_xpu.html) +Intel Arc GPU users can install native PyTorch with torch.xpu support using pip. More information can be found [here](https://pytorch.org/docs/main/notes/get_start_xpu.html) 1. To install PyTorch xpu, use the following command: @@ -252,10 +252,6 @@ This is the command to install the Pytorch xpu nightly which might have some per ```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/xpu``` -(Option 2) Alternatively, Intel GPUs supported by Intel Extension for PyTorch (IPEX) can leverage IPEX for improved performance. - -1. visit [Installation](https://intel.github.io/intel-extension-for-pytorch/index.html#installation?platform=gpu) for more information. - ### NVIDIA Nvidia users should install stable pytorch using this command: From 18e7d6dba5f1012d4cf09e8f777dc85d56ff25c0 Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Thu, 13 Nov 2025 07:19:53 +1000 Subject: [PATCH 082/231] mm/mp: always unload re-used but modified models (#10724) The partial unloader path in model re-use flow skips straight to the actual unload without any check of the patching UUID. This means that if you do an upscale flow with a model patch on an existing model, it will not apply your patchings. Fix by delaying the partial_unload until after the uuid checks. This is done by making partial_unload a model of partial_load where extra_mem is -ve. --- comfy/model_management.py | 5 +---- comfy/model_patcher.py | 3 +++ 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index d8913082a..a21df54b3 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -503,10 +503,7 @@ class LoadedModel: use_more_vram = lowvram_model_memory if use_more_vram == 0: use_more_vram = 1e32 - if use_more_vram > 0: - self.model_use_more_vram(use_more_vram, force_patch_weights=force_patch_weights) - else: - self.model.partially_unload(self.model.offload_device, -use_more_vram, force_patch_weights=force_patch_weights) + self.model_use_more_vram(use_more_vram, force_patch_weights=force_patch_weights) real_model = self.model.model diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 68b0a9192..cf1b0d441 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -928,6 +928,9 @@ class ModelPatcher: extra_memory += (used - self.model.model_loaded_weight_memory) self.patch_model(load_weights=False) + if extra_memory < 0 and not unpatch_weights: + self.partially_unload(self.offload_device, -extra_memory, force_patch_weights=force_patch_weights) + return 0 full_load = False if self.model.model_lowvram == False and self.model.model_loaded_weight_memory > 0: self.apply_hooks(self.forced_hooks, force_apply=True) From 1c7eaeca1013e4315f36e0d4d274faa106001121 Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Thu, 13 Nov 2025 07:20:53 +1000 Subject: [PATCH 083/231] qwen: reduce VRAM usage (#10725) Clean up a bunch of stacked and no-longer-needed tensors on the QWEN VRAM peak (currently FFN). With this I go from OOMing at B=37x1328x1328 to being able to succesfully run B=47 (RTX5090). --- comfy/ldm/qwen_image/model.py | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/comfy/ldm/qwen_image/model.py b/comfy/ldm/qwen_image/model.py index e5d0d17c1..427ea19c1 100644 --- a/comfy/ldm/qwen_image/model.py +++ b/comfy/ldm/qwen_image/model.py @@ -236,10 +236,10 @@ class QwenImageTransformerBlock(nn.Module): img_mod1, img_mod2 = img_mod_params.chunk(2, dim=-1) txt_mod1, txt_mod2 = txt_mod_params.chunk(2, dim=-1) - img_normed = self.img_norm1(hidden_states) - img_modulated, img_gate1 = self._modulate(img_normed, img_mod1) - txt_normed = self.txt_norm1(encoder_hidden_states) - txt_modulated, txt_gate1 = self._modulate(txt_normed, txt_mod1) + img_modulated, img_gate1 = self._modulate(self.img_norm1(hidden_states), img_mod1) + del img_mod1 + txt_modulated, txt_gate1 = self._modulate(self.txt_norm1(encoder_hidden_states), txt_mod1) + del txt_mod1 img_attn_output, txt_attn_output = self.attn( hidden_states=img_modulated, @@ -248,16 +248,20 @@ class QwenImageTransformerBlock(nn.Module): image_rotary_emb=image_rotary_emb, transformer_options=transformer_options, ) + del img_modulated + del txt_modulated hidden_states = hidden_states + img_gate1 * img_attn_output encoder_hidden_states = encoder_hidden_states + txt_gate1 * txt_attn_output + del img_attn_output + del txt_attn_output + del img_gate1 + del txt_gate1 - img_normed2 = self.img_norm2(hidden_states) - img_modulated2, img_gate2 = self._modulate(img_normed2, img_mod2) + img_modulated2, img_gate2 = self._modulate(self.img_norm2(hidden_states), img_mod2) hidden_states = torch.addcmul(hidden_states, img_gate2, self.img_mlp(img_modulated2)) - txt_normed2 = self.txt_norm2(encoder_hidden_states) - txt_modulated2, txt_gate2 = self._modulate(txt_normed2, txt_mod2) + txt_modulated2, txt_gate2 = self._modulate(self.txt_norm2(encoder_hidden_states), txt_mod2) encoder_hidden_states = torch.addcmul(encoder_hidden_states, txt_gate2, self.txt_mlp(txt_modulated2)) return encoder_hidden_states, hidden_states From 8b0b93df51d04f08eb779cb84dc331fa18b43ae8 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 12 Nov 2025 14:04:41 -0800 Subject: [PATCH 084/231] Update Python 3.14 compatibility notes in README (#10730) --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 9e28803a2..f51807ad5 100644 --- a/README.md +++ b/README.md @@ -200,7 +200,7 @@ comfy install ## Manual Install (Windows, Linux) -Python 3.14 will work if you comment out the `kornia` dependency in the requirements.txt file (breaks the canny node) but it is not recommended. +Python 3.14 works but you may encounter issues with the torch compile node. The free threaded variant is still missing some dependencies. Python 3.13 is very well supported. If you have trouble with some custom node dependencies on 3.13 you can try 3.12 From 3b3ef9a77ac03ed516a45063f9736f33085cecca Mon Sep 17 00:00:00 2001 From: contentis Date: Thu, 13 Nov 2025 00:26:52 +0100 Subject: [PATCH 085/231] Quantized Ops fixes (#10715) * offload support, bug fixes, remove mixins * add readme --- QUANTIZATION.md | 168 +++++++++++++++++++++++++++++++++++++++++++++ comfy/ops.py | 37 ++++------ comfy/quant_ops.py | 39 ++++++++++- 3 files changed, 219 insertions(+), 25 deletions(-) create mode 100644 QUANTIZATION.md diff --git a/QUANTIZATION.md b/QUANTIZATION.md new file mode 100644 index 000000000..1693e13f3 --- /dev/null +++ b/QUANTIZATION.md @@ -0,0 +1,168 @@ +# The Comfy guide to Quantization + + +## How does quantization work? + +Quantization aims to map a high-precision value x_f to a lower precision format with minimal loss in accuracy. These smaller formats then serve to reduce the models memory footprint and increase throughput by using specialized hardware. + +When simply converting a value from FP16 to FP8 using the round-nearest method we might hit two issues: +- The dynamic range of FP16 (-65,504, 65,504) far exceeds FP8 formats like E4M3 (-448, 448) or E5M2 (-57,344, 57,344), potentially resulting in clipped values +- The original values are concentrated in a small range (e.g. -1,1) leaving many FP8-bits "unused" + +By using a scaling factor, we aim to map these values into the quantized-dtype range, making use of the full spectrum. One of the easiest approaches, and common, is using per-tensor absolute-maximum scaling. + +``` +absmax = max(abs(tensor)) +scale = amax / max_dynamic_range_low_precision + +# Quantization +tensor_q = (tensor / scale).to(low_precision_dtype) + +# De-Quantization +tensor_dq = tensor_q.to(fp16) * scale + +tensor_dq ~ tensor +``` + +Given that additional information (scaling factor) is needed to "interpret" the quantized values, we describe those as derived datatypes. + + +## Quantization in Comfy + +``` +QuantizedTensor (torch.Tensor subclass) + ↓ __torch_dispatch__ +Two-Level Registry (generic + layout handlers) + ↓ +MixedPrecisionOps + Metadata Detection +``` + +### Representation + +To represent these derived datatypes, ComfyUI uses a subclass of torch.Tensor to implements these using the `QuantizedTensor` class found in `comfy/quant_ops.py` + +A `Layout` class defines how a specific quantization format behaves: +- Required parameters +- Quantize method +- De-Quantize method + +```python +from comfy.quant_ops import QuantizedLayout + +class MyLayout(QuantizedLayout): + @classmethod + def quantize(cls, tensor, **kwargs): + # Convert to quantized format + qdata = ... + params = {'scale': ..., 'orig_dtype': tensor.dtype} + return qdata, params + + @staticmethod + def dequantize(qdata, scale, orig_dtype, **kwargs): + return qdata.to(orig_dtype) * scale +``` + +To then run operations using these QuantizedTensors we use two registry systems to define supported operations. +The first is a **generic registry** that handles operations common to all quantized formats (e.g., `.to()`, `.clone()`, `.reshape()`). + +The second registry is layout-specific and allows to implement fast-paths like nn.Linear. +```python +from comfy.quant_ops import register_layout_op + +@register_layout_op(torch.ops.aten.linear.default, MyLayout) +def my_linear(func, args, kwargs): + # Extract tensors, call optimized kernel + ... +``` +When `torch.nn.functional.linear()` is called with QuantizedTensor arguments, `__torch_dispatch__` automatically routes to the registered implementation. +For any unsupported operation, QuantizedTensor will fallback to call `dequantize` and dispatch using the high-precision implementation. + + +### Mixed Precision + +The `MixedPrecisionOps` class (lines 542-648 in `comfy/ops.py`) enables per-layer quantization decisions, allowing different layers in a model to use different precisions. This is activated when a model config contains a `layer_quant_config` dictionary that specifies which layers should be quantized and how. + +**Architecture:** + +```python +class MixedPrecisionOps(disable_weight_init): + _layer_quant_config = {} # Maps layer names to quantization configs + _compute_dtype = torch.bfloat16 # Default compute / dequantize precision +``` + +**Key mechanism:** + +The custom `Linear._load_from_state_dict()` method inspects each layer during model loading: +- If the layer name is **not** in `_layer_quant_config`: load weight as regular tensor in `_compute_dtype` +- If the layer name **is** in `_layer_quant_config`: + - Load weight as `QuantizedTensor` with the specified layout (e.g., `TensorCoreFP8Layout`) + - Load associated quantization parameters (scales, block_size, etc.) + +**Why it's needed:** + +Not all layers tolerate quantization equally. Sensitive operations like final projections can be kept in higher precision, while compute-heavy matmuls are quantized. This provides most of the performance benefits while maintaining quality. + +The system is selected in `pick_operations()` when `model_config.layer_quant_config` is present, making it the highest-priority operation mode. + + +## Checkpoint Format + +Quantized checkpoints are stored as standard safetensors files with quantized weight tensors and associated scaling parameters, plus a `_quantization_metadata` JSON entry describing the quantization scheme. + +The quantized checkpoint will contain the same layers as the original checkpoint but: +- The weights are stored as quantized values, sometimes using a different storage datatype. E.g. uint8 container for fp8. +- For each quantized weight a number of additional scaling parameters are stored alongside depending on the recipe. +- We store a metadata.json in the metadata of the final safetensor containing the `_quantization_metadata` describing which layers are quantized and what layout has been used. + +### Scaling Parameters details +We define 4 possible scaling parameters that should cover most recipes in the near-future: +- **weight_scale**: quantization scalers for the weights +- **weight_scale_2**: global scalers in the context of double scaling +- **pre_quant_scale**: scalers used for smoothing salient weights +- **input_scale**: quantization scalers for the activations + +| Format | Storage dtype | weight_scale | weight_scale_2 | pre_quant_scale | input_scale | +|--------|---------------|--------------|----------------|-----------------|-------------| +| float8_e4m3fn | float32 | float32 (scalar) | - | - | float32 (scalar) | + +You can find the defined formats in `comfy/quant_ops.py` (QUANT_ALGOS). + +### Quantization Metadata + +The metadata stored alongside the checkpoint contains: +- **format_version**: String to define a version of the standard +- **layers**: A dictionary mapping layer names to their quantization format. The format string maps to the definitions found in `QUANT_ALGOS`. + +Example: +```json +{ + "_quantization_metadata": { + "format_version": "1.0", + "layers": { + "model.layers.0.mlp.up_proj": "float8_e4m3fn", + "model.layers.0.mlp.down_proj": "float8_e4m3fn", + "model.layers.1.mlp.up_proj": "float8_e4m3fn" + } + } +} +``` + + +## Creating Quantized Checkpoints + +To create compatible checkpoints, use any quantization tool provided the output follows the checkpoint format described above and uses a layout defined in `QUANT_ALGOS`. + +### Weight Quantization + +Weight quantization is straightforward - compute the scaling factor directly from the weight tensor using the absolute maximum method described earlier. Each layer's weights are quantized independently and stored with their corresponding `weight_scale` parameter. + +### Calibration (for Activation Quantization) + +Activation quantization (e.g., for FP8 Tensor Core operations) requires `input_scale` parameters that cannot be determined from static weights alone. Since activation values depend on actual inputs, we use **post-training calibration (PTQ)**: + +1. **Collect statistics**: Run inference on N representative samples +2. **Track activations**: Record the absolute maximum (`amax`) of inputs to each quantized layer +3. **Compute scales**: Derive `input_scale` from collected statistics +4. **Store in checkpoint**: Save `input_scale` parameters alongside weights + +The calibration dataset should be representative of your target use case. For diffusion models, this typically means a diverse set of prompts and generation parameters. \ No newline at end of file diff --git a/comfy/ops.py b/comfy/ops.py index 96dffa85d..2a90a5ba2 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -77,7 +77,10 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of # will add async-offload support to your cast and improve performance. if input is not None: if dtype is None: - dtype = input.dtype + if isinstance(input, QuantizedTensor): + dtype = input._layout_params["orig_dtype"] + else: + dtype = input.dtype if bias_dtype is None: bias_dtype = dtype if device is None: @@ -534,18 +537,7 @@ if CUBLAS_IS_AVAILABLE: # ============================================================================== # Mixed Precision Operations # ============================================================================== -from .quant_ops import QuantizedTensor - -QUANT_FORMAT_MIXINS = { - "float8_e4m3fn": { - "dtype": torch.float8_e4m3fn, - "layout_type": "TensorCoreFP8Layout", - "parameters": { - "weight_scale": torch.nn.Parameter(torch.zeros((), dtype=torch.float32), requires_grad=False), - "input_scale": torch.nn.Parameter(torch.zeros((), dtype=torch.float32), requires_grad=False), - } - } -} +from .quant_ops import QuantizedTensor, QUANT_ALGOS class MixedPrecisionOps(disable_weight_init): _layer_quant_config = {} @@ -596,23 +588,24 @@ class MixedPrecisionOps(disable_weight_init): if quant_format is None: raise ValueError(f"Unknown quantization format for layer {layer_name}") - mixin = QUANT_FORMAT_MIXINS[quant_format] - self.layout_type = mixin["layout_type"] + qconfig = QUANT_ALGOS[quant_format] + self.layout_type = qconfig["comfy_tensor_layout"] - scale_key = f"{prefix}weight_scale" + weight_scale_key = f"{prefix}weight_scale" layout_params = { - 'scale': state_dict.pop(scale_key, None), - 'orig_dtype': MixedPrecisionOps._compute_dtype + 'scale': state_dict.pop(weight_scale_key, None), + 'orig_dtype': MixedPrecisionOps._compute_dtype, + 'block_size': qconfig.get("group_size", None), } if layout_params['scale'] is not None: - manually_loaded_keys.append(scale_key) + manually_loaded_keys.append(weight_scale_key) self.weight = torch.nn.Parameter( - QuantizedTensor(weight.to(device=device, dtype=mixin["dtype"]), self.layout_type, layout_params), + QuantizedTensor(weight.to(device=device), self.layout_type, layout_params), requires_grad=False ) - for param_name, param_value in mixin["parameters"].items(): + for param_name in qconfig["parameters"]: param_key = f"{prefix}{param_name}" _v = state_dict.pop(param_key, None) if _v is None: @@ -643,7 +636,7 @@ class MixedPrecisionOps(disable_weight_init): if (getattr(self, 'layout_type', None) is not None and getattr(self, 'input_scale', None) is not None and not isinstance(input, QuantizedTensor)): - input = QuantizedTensor.from_float(input, self.layout_type, scale=self.input_scale, fp8_dtype=self.weight.dtype) + input = QuantizedTensor.from_float(input, self.layout_type, scale=self.input_scale, dtype=self.weight.dtype) return self._forward(input, self.weight, self.bias) diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py index c56e32a73..1d058bece 100644 --- a/comfy/quant_ops.py +++ b/comfy/quant_ops.py @@ -74,6 +74,12 @@ def _copy_layout_params(params): new_params[k] = v return new_params +def _copy_layout_params_inplace(src, dst, non_blocking=False): + for k, v in src.items(): + if isinstance(v, torch.Tensor): + dst[k].copy_(v, non_blocking=non_blocking) + else: + dst[k] = v class QuantizedLayout: """ @@ -318,13 +324,13 @@ def generic_to_dtype_layout(func, args, kwargs): def generic_copy_(func, args, kwargs): qt_dest = args[0] src = args[1] - + non_blocking = args[2] if len(args) > 2 else False if isinstance(qt_dest, QuantizedTensor): if isinstance(src, QuantizedTensor): # Copy from another quantized tensor - qt_dest._qdata.copy_(src._qdata) + qt_dest._qdata.copy_(src._qdata, non_blocking=non_blocking) qt_dest._layout_type = src._layout_type - qt_dest._layout_params = _copy_layout_params(src._layout_params) + _copy_layout_params_inplace(src._layout_params, qt_dest._layout_params, non_blocking=non_blocking) else: # Copy from regular tensor - just copy raw data qt_dest._qdata.copy_(src) @@ -336,6 +342,26 @@ def generic_copy_(func, args, kwargs): def generic_has_compatible_shallow_copy_type(func, args, kwargs): return True + +@register_generic_util(torch.ops.aten.empty_like.default) +def generic_empty_like(func, args, kwargs): + """Empty_like operation - creates an empty tensor with the same quantized structure.""" + qt = args[0] + if isinstance(qt, QuantizedTensor): + # Create empty tensor with same shape and dtype as the quantized data + hp_dtype = kwargs.pop('dtype', qt._layout_params["orig_dtype"]) + new_qdata = torch.empty_like(qt._qdata, **kwargs) + + # Handle device transfer for layout params + target_device = kwargs.get('device', new_qdata.device) + new_params = _move_layout_params_to_device(qt._layout_params, target_device) + + # Update orig_dtype if dtype is specified + new_params['orig_dtype'] = hp_dtype + + return QuantizedTensor(new_qdata, qt._layout_type, new_params) + return func(*args, **kwargs) + # ============================================================================== # FP8 Layout + Operation Handlers # ============================================================================== @@ -378,6 +404,13 @@ class TensorCoreFP8Layout(QuantizedLayout): def get_plain_tensors(cls, qtensor): return qtensor._qdata, qtensor._layout_params['scale'] +QUANT_ALGOS = { + "float8_e4m3fn": { + "storage_t": torch.float8_e4m3fn, + "parameters": {"weight_scale", "input_scale"}, + "comfy_tensor_layout": "TensorCoreFP8Layout", + }, +} LAYOUTS = { "TensorCoreFP8Layout": TensorCoreFP8Layout, From f91078b1ffa484c424f78814f54de4d5846e4daa Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Thu, 13 Nov 2025 20:05:26 +0200 Subject: [PATCH 086/231] add PR template for API-Nodes (#10736) --- .github/PULL_REQUEST_TEMPLATE/api-node.md | 21 ++++++++ .github/workflows/api-node-template.yml | 58 +++++++++++++++++++++++ 2 files changed, 79 insertions(+) create mode 100644 .github/PULL_REQUEST_TEMPLATE/api-node.md create mode 100644 .github/workflows/api-node-template.yml diff --git a/.github/PULL_REQUEST_TEMPLATE/api-node.md b/.github/PULL_REQUEST_TEMPLATE/api-node.md new file mode 100644 index 000000000..f62744878 --- /dev/null +++ b/.github/PULL_REQUEST_TEMPLATE/api-node.md @@ -0,0 +1,21 @@ + + +## API Node PR Checklist + +### Scope +- [ ] **Is API Node Change** + +### Pricing & Billing +- [ ] **Need pricing update** +- [ ] **No pricing update** + +If **Need pricing update**: +- [ ] Metronome rate cards updated +- [ ] Auto‑billing tests updated and passing + +### QA +- [ ] **QA done** +- [ ] **QA not required** + +### Comms +- [ ] Informed **@Kosinkadink** diff --git a/.github/workflows/api-node-template.yml b/.github/workflows/api-node-template.yml new file mode 100644 index 000000000..0775f9979 --- /dev/null +++ b/.github/workflows/api-node-template.yml @@ -0,0 +1,58 @@ +name: Append API Node PR template + +on: + pull_request_target: + types: [opened, reopened, synchronize, edited, ready_for_review] + paths: + - 'comfy_api_nodes/**' # only run if these files changed + +permissions: + contents: read + pull-requests: write + +jobs: + inject: + runs-on: ubuntu-latest + steps: + - name: Ensure template exists and append to PR body + uses: actions/github-script@v7 + with: + script: | + const { owner, repo } = context.repo; + const number = context.payload.pull_request.number; + const templatePath = '.github/PULL_REQUEST_TEMPLATE/api-node.md'; + const marker = ''; + + const { data: pr } = await github.rest.pulls.get({ owner, repo, pull_number: number }); + + let templateText; + try { + const res = await github.rest.repos.getContent({ + owner, + repo, + path: templatePath, + ref: pr.base.ref + }); + const buf = Buffer.from(res.data.content, res.data.encoding || 'base64'); + templateText = buf.toString('utf8'); + } catch (e) { + core.setFailed(`Required PR template not found at "${templatePath}" on ${pr.base.ref}. Please add it to the repo.`); + return; + } + + // Enforce the presence of the marker inside the template (for idempotence) + if (!templateText.includes(marker)) { + core.setFailed(`Template at "${templatePath}" does not contain the required marker:\n${marker}\nAdd it so we can detect duplicates safely.`); + return; + } + + // If the PR already contains the marker, do not append again. + const body = pr.body || ''; + if (body.includes(marker)) { + core.info('Template already present in PR body; nothing to inject.'); + return; + } + + const newBody = (body ? body + '\n\n' : '') + templateText + '\n'; + await github.rest.pulls.update({ owner, repo, pull_number: number, body: newBody }); + core.notice('API Node template appended to PR description.'); From 2fde9597f4b02c5f06c1a5ceb3ca2fa6d74966ec Mon Sep 17 00:00:00 2001 From: ric-yu Date: Thu, 13 Nov 2025 15:11:52 -0800 Subject: [PATCH 087/231] feat: add create_time dict to prompt field in /history and /queue (#10741) --- server.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/server.py b/server.py index 5d773b10a..d059d3dc9 100644 --- a/server.py +++ b/server.py @@ -2,6 +2,7 @@ import os import sys import asyncio import traceback +import time import nodes import folder_paths @@ -733,6 +734,7 @@ class PromptServer(): for sensitive_val in execution.SENSITIVE_EXTRA_DATA_KEYS: if sensitive_val in extra_data: sensitive[sensitive_val] = extra_data.pop(sensitive_val) + extra_data["create_time"] = int(time.time() * 1000) # timestamp in milliseconds self.prompt_queue.put((number, prompt_id, prompt, extra_data, outputs_to_execute, sensitive)) response = {"prompt_id": prompt_id, "number": number, "node_errors": valid[3]} return web.json_response(response) From 94c298f9625b0fd9af8ea07a73075fdefe0d9e57 Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Fri, 14 Nov 2025 10:02:03 +1000 Subject: [PATCH 088/231] flux: reduce VRAM usage (#10737) Cleanup a bunch of stack tensors on Flux. This take me from B=19 to B=22 for 1600x1600 on RTX5090. --- comfy/ldm/flux/layers.py | 33 ++++++++++++++++++++++++++------- 1 file changed, 26 insertions(+), 7 deletions(-) diff --git a/comfy/ldm/flux/layers.py b/comfy/ldm/flux/layers.py index a3eab0470..f4bf56e01 100644 --- a/comfy/ldm/flux/layers.py +++ b/comfy/ldm/flux/layers.py @@ -167,39 +167,55 @@ class DoubleStreamBlock(nn.Module): img_modulated = self.img_norm1(img) img_modulated = apply_mod(img_modulated, (1 + img_mod1.scale), img_mod1.shift, modulation_dims_img) img_qkv = self.img_attn.qkv(img_modulated) + del img_modulated img_q, img_k, img_v = img_qkv.view(img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + del img_qkv img_q, img_k = self.img_attn.norm(img_q, img_k, img_v) # prepare txt for attention txt_modulated = self.txt_norm1(txt) txt_modulated = apply_mod(txt_modulated, (1 + txt_mod1.scale), txt_mod1.shift, modulation_dims_txt) txt_qkv = self.txt_attn.qkv(txt_modulated) + del txt_modulated txt_q, txt_k, txt_v = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + del txt_qkv txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v) if self.flipped_img_txt: + q = torch.cat((img_q, txt_q), dim=2) + del img_q, txt_q + k = torch.cat((img_k, txt_k), dim=2) + del img_k, txt_k + v = torch.cat((img_v, txt_v), dim=2) + del img_v, txt_v # run actual attention - attn = attention(torch.cat((img_q, txt_q), dim=2), - torch.cat((img_k, txt_k), dim=2), - torch.cat((img_v, txt_v), dim=2), + attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options) + del q, k, v img_attn, txt_attn = attn[:, : img.shape[1]], attn[:, img.shape[1]:] else: + q = torch.cat((txt_q, img_q), dim=2) + del txt_q, img_q + k = torch.cat((txt_k, img_k), dim=2) + del txt_k, img_k + v = torch.cat((txt_v, img_v), dim=2) + del txt_v, img_v # run actual attention - attn = attention(torch.cat((txt_q, img_q), dim=2), - torch.cat((txt_k, img_k), dim=2), - torch.cat((txt_v, img_v), dim=2), + attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options) + del q, k, v txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:] # calculate the img bloks img += apply_mod(self.img_attn.proj(img_attn), img_mod1.gate, None, modulation_dims_img) + del img_attn img += apply_mod(self.img_mlp(apply_mod(self.img_norm2(img), (1 + img_mod2.scale), img_mod2.shift, modulation_dims_img)), img_mod2.gate, None, modulation_dims_img) # calculate the txt bloks txt += apply_mod(self.txt_attn.proj(txt_attn), txt_mod1.gate, None, modulation_dims_txt) + del txt_attn txt += apply_mod(self.txt_mlp(apply_mod(self.txt_norm2(txt), (1 + txt_mod2.scale), txt_mod2.shift, modulation_dims_txt)), txt_mod2.gate, None, modulation_dims_txt) if txt.dtype == torch.float16: @@ -249,12 +265,15 @@ class SingleStreamBlock(nn.Module): qkv, mlp = torch.split(self.linear1(apply_mod(self.pre_norm(x), (1 + mod.scale), mod.shift, modulation_dims)), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1) q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + del qkv q, k = self.norm(q, k, v) # compute attention attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options) + del q, k, v # compute activation in mlp stream, cat again and run second linear layer - output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2)) + mlp = self.mlp_act(mlp) + output = self.linear2(torch.cat((attn, mlp), 2)) x += apply_mod(output, mod.gate, None, modulation_dims) if x.dtype == torch.float16: x = torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504) From 1ef328c007a419c2c429df0f80532cc11579dc97 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Thu, 13 Nov 2025 18:32:39 -0800 Subject: [PATCH 089/231] Better instructions for the portable. (#10743) --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index f51807ad5..cd8273b0d 100644 --- a/README.md +++ b/README.md @@ -173,7 +173,7 @@ There is a portable standalone build for Windows that should work for running on ### [Direct link to download](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_nvidia.7z) -Simply download, extract with [7-Zip](https://7-zip.org) and run. Make sure you put your Stable Diffusion checkpoints/models (the huge ckpt/safetensors files) in: ComfyUI\models\checkpoints +Simply download, extract with [7-Zip](https://7-zip.org) or with the windows explorer on recent windows versions and run. For smaller models you normally only need to put the checkpoints (the huge ckpt/safetensors files) in: ComfyUI\models\checkpoints but many of the larger models have multiple files. Make sure to follow the instructions to know which subfolder to put them in ComfyUI\models\ If you have trouble extracting it, right click the file -> properties -> unblock From f60923590c3f2fd05e166e2ec57968aaf7007dd0 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Thu, 13 Nov 2025 22:28:05 -0800 Subject: [PATCH 090/231] Use same code for chroma and flux blocks so that optimizations are shared. (#10746) --- comfy/ldm/chroma/layers.py | 121 ----------------------------- comfy/ldm/chroma/model.py | 7 +- comfy/ldm/chroma_radiance/model.py | 7 +- comfy/ldm/flux/layers.py | 31 ++++++-- 4 files changed, 31 insertions(+), 135 deletions(-) diff --git a/comfy/ldm/chroma/layers.py b/comfy/ldm/chroma/layers.py index fc7110cce..9f4ad5bd2 100644 --- a/comfy/ldm/chroma/layers.py +++ b/comfy/ldm/chroma/layers.py @@ -1,12 +1,9 @@ import torch from torch import Tensor, nn -from comfy.ldm.flux.math import attention from comfy.ldm.flux.layers import ( MLPEmbedder, RMSNorm, - QKNorm, - SelfAttention, ModulationOut, ) @@ -48,124 +45,6 @@ class Approximator(nn.Module): return x -class DoubleStreamBlock(nn.Module): - def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, dtype=None, device=None, operations=None): - super().__init__() - - mlp_hidden_dim = int(hidden_size * mlp_ratio) - self.num_heads = num_heads - self.hidden_size = hidden_size - self.img_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) - self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations) - - self.img_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) - self.img_mlp = nn.Sequential( - operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device), - nn.GELU(approximate="tanh"), - operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device), - ) - - self.txt_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) - self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations) - - self.txt_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) - self.txt_mlp = nn.Sequential( - operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device), - nn.GELU(approximate="tanh"), - operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device), - ) - self.flipped_img_txt = flipped_img_txt - - def forward(self, img: Tensor, txt: Tensor, pe: Tensor, vec: Tensor, attn_mask=None, transformer_options={}): - (img_mod1, img_mod2), (txt_mod1, txt_mod2) = vec - - # prepare image for attention - img_modulated = torch.addcmul(img_mod1.shift, 1 + img_mod1.scale, self.img_norm1(img)) - img_qkv = self.img_attn.qkv(img_modulated) - img_q, img_k, img_v = img_qkv.view(img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) - img_q, img_k = self.img_attn.norm(img_q, img_k, img_v) - - # prepare txt for attention - txt_modulated = torch.addcmul(txt_mod1.shift, 1 + txt_mod1.scale, self.txt_norm1(txt)) - txt_qkv = self.txt_attn.qkv(txt_modulated) - txt_q, txt_k, txt_v = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) - txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v) - - # run actual attention - attn = attention(torch.cat((txt_q, img_q), dim=2), - torch.cat((txt_k, img_k), dim=2), - torch.cat((txt_v, img_v), dim=2), - pe=pe, mask=attn_mask, transformer_options=transformer_options) - - txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :] - - # calculate the img bloks - img.addcmul_(img_mod1.gate, self.img_attn.proj(img_attn)) - img.addcmul_(img_mod2.gate, self.img_mlp(torch.addcmul(img_mod2.shift, 1 + img_mod2.scale, self.img_norm2(img)))) - - # calculate the txt bloks - txt.addcmul_(txt_mod1.gate, self.txt_attn.proj(txt_attn)) - txt.addcmul_(txt_mod2.gate, self.txt_mlp(torch.addcmul(txt_mod2.shift, 1 + txt_mod2.scale, self.txt_norm2(txt)))) - - if txt.dtype == torch.float16: - txt = torch.nan_to_num(txt, nan=0.0, posinf=65504, neginf=-65504) - - return img, txt - - -class SingleStreamBlock(nn.Module): - """ - A DiT block with parallel linear layers as described in - https://arxiv.org/abs/2302.05442 and adapted modulation interface. - """ - - def __init__( - self, - hidden_size: int, - num_heads: int, - mlp_ratio: float = 4.0, - qk_scale: float = None, - dtype=None, - device=None, - operations=None - ): - super().__init__() - self.hidden_dim = hidden_size - self.num_heads = num_heads - head_dim = hidden_size // num_heads - self.scale = qk_scale or head_dim**-0.5 - - self.mlp_hidden_dim = int(hidden_size * mlp_ratio) - # qkv and mlp_in - self.linear1 = operations.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim, dtype=dtype, device=device) - # proj and mlp_out - self.linear2 = operations.Linear(hidden_size + self.mlp_hidden_dim, hidden_size, dtype=dtype, device=device) - - self.norm = QKNorm(head_dim, dtype=dtype, device=device, operations=operations) - - self.hidden_size = hidden_size - self.pre_norm = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) - - self.mlp_act = nn.GELU(approximate="tanh") - - def forward(self, x: Tensor, pe: Tensor, vec: Tensor, attn_mask=None, transformer_options={}) -> Tensor: - mod = vec - x_mod = torch.addcmul(mod.shift, 1 + mod.scale, self.pre_norm(x)) - qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1) - - q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) - q, k = self.norm(q, k, v) - - # compute attention - attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options) - # compute activation in mlp stream, cat again and run second linear layer - output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2)) - x.addcmul_(mod.gate, output) - if x.dtype == torch.float16: - x = torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504) - return x - - class LastLayer(nn.Module): def __init__(self, hidden_size: int, patch_size: int, out_channels: int, dtype=None, device=None, operations=None): super().__init__() diff --git a/comfy/ldm/chroma/model.py b/comfy/ldm/chroma/model.py index ad1c523fe..67bf70eb1 100644 --- a/comfy/ldm/chroma/model.py +++ b/comfy/ldm/chroma/model.py @@ -11,12 +11,12 @@ import comfy.ldm.common_dit from comfy.ldm.flux.layers import ( EmbedND, timestep_embedding, + DoubleStreamBlock, + SingleStreamBlock, ) from .layers import ( - DoubleStreamBlock, LastLayer, - SingleStreamBlock, Approximator, ChromaModulationOut, ) @@ -90,6 +90,7 @@ class Chroma(nn.Module): self.num_heads, mlp_ratio=params.mlp_ratio, qkv_bias=params.qkv_bias, + modulation=False, dtype=dtype, device=device, operations=operations ) for _ in range(params.depth) @@ -98,7 +99,7 @@ class Chroma(nn.Module): self.single_blocks = nn.ModuleList( [ - SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, dtype=dtype, device=device, operations=operations) + SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, modulation=False, dtype=dtype, device=device, operations=operations) for _ in range(params.depth_single_blocks) ] ) diff --git a/comfy/ldm/chroma_radiance/model.py b/comfy/ldm/chroma_radiance/model.py index 7d7be80f5..e643b4414 100644 --- a/comfy/ldm/chroma_radiance/model.py +++ b/comfy/ldm/chroma_radiance/model.py @@ -10,12 +10,10 @@ from torch import Tensor, nn from einops import repeat import comfy.ldm.common_dit -from comfy.ldm.flux.layers import EmbedND +from comfy.ldm.flux.layers import EmbedND, DoubleStreamBlock, SingleStreamBlock from comfy.ldm.chroma.model import Chroma, ChromaParams from comfy.ldm.chroma.layers import ( - DoubleStreamBlock, - SingleStreamBlock, Approximator, ) from .layers import ( @@ -89,7 +87,6 @@ class ChromaRadiance(Chroma): dtype=dtype, device=device, operations=operations ) - self.double_blocks = nn.ModuleList( [ DoubleStreamBlock( @@ -97,6 +94,7 @@ class ChromaRadiance(Chroma): self.num_heads, mlp_ratio=params.mlp_ratio, qkv_bias=params.qkv_bias, + modulation=False, dtype=dtype, device=device, operations=operations ) for _ in range(params.depth) @@ -109,6 +107,7 @@ class ChromaRadiance(Chroma): self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, + modulation=False, dtype=dtype, device=device, operations=operations, ) for _ in range(params.depth_single_blocks) diff --git a/comfy/ldm/flux/layers.py b/comfy/ldm/flux/layers.py index f4bf56e01..23150a712 100644 --- a/comfy/ldm/flux/layers.py +++ b/comfy/ldm/flux/layers.py @@ -130,13 +130,17 @@ def apply_mod(tensor, m_mult, m_add=None, modulation_dims=None): class DoubleStreamBlock(nn.Module): - def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, dtype=None, device=None, operations=None): + def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, modulation=True, dtype=None, device=None, operations=None): super().__init__() mlp_hidden_dim = int(hidden_size * mlp_ratio) self.num_heads = num_heads self.hidden_size = hidden_size - self.img_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device, operations=operations) + self.modulation = modulation + + if self.modulation: + self.img_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device, operations=operations) + self.img_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations) @@ -147,7 +151,9 @@ class DoubleStreamBlock(nn.Module): operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device), ) - self.txt_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device, operations=operations) + if self.modulation: + self.txt_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device, operations=operations) + self.txt_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations) @@ -160,8 +166,11 @@ class DoubleStreamBlock(nn.Module): self.flipped_img_txt = flipped_img_txt def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims_img=None, modulation_dims_txt=None, transformer_options={}): - img_mod1, img_mod2 = self.img_mod(vec) - txt_mod1, txt_mod2 = self.txt_mod(vec) + if self.modulation: + img_mod1, img_mod2 = self.img_mod(vec) + txt_mod1, txt_mod2 = self.txt_mod(vec) + else: + (img_mod1, img_mod2), (txt_mod1, txt_mod2) = vec # prepare image for attention img_modulated = self.img_norm1(img) @@ -236,6 +245,7 @@ class SingleStreamBlock(nn.Module): num_heads: int, mlp_ratio: float = 4.0, qk_scale: float = None, + modulation=True, dtype=None, device=None, operations=None @@ -258,10 +268,17 @@ class SingleStreamBlock(nn.Module): self.pre_norm = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) self.mlp_act = nn.GELU(approximate="tanh") - self.modulation = Modulation(hidden_size, double=False, dtype=dtype, device=device, operations=operations) + if modulation: + self.modulation = Modulation(hidden_size, double=False, dtype=dtype, device=device, operations=operations) + else: + self.modulation = None def forward(self, x: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims=None, transformer_options={}) -> Tensor: - mod, _ = self.modulation(vec) + if self.modulation: + mod, _ = self.modulation(vec) + else: + mod = vec + qkv, mlp = torch.split(self.linear1(apply_mod(self.pre_norm(x), (1 + mod.scale), mod.shift, modulation_dims)), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1) q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) From 443056c401c53953bb8eee6da71b9ad29afe2581 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Fri, 14 Nov 2025 00:26:05 -0800 Subject: [PATCH 091/231] Fix custom nodes import error. (#10747) This should fix the import errors but will break if the custom nodes actually try to use the class. --- comfy/ldm/chroma/layers.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/comfy/ldm/chroma/layers.py b/comfy/ldm/chroma/layers.py index 9f4ad5bd2..2d5684348 100644 --- a/comfy/ldm/chroma/layers.py +++ b/comfy/ldm/chroma/layers.py @@ -7,6 +7,9 @@ from comfy.ldm.flux.layers import ( ModulationOut, ) +# TODO: remove this in a few months +SingleStreamBlock = None +DoubleStreamBlock = None class ChromaModulationOut(ModulationOut): From bd01d9f7fd241a45bd08b60dfedbe78577383cc4 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sat, 15 Nov 2025 03:54:40 -0800 Subject: [PATCH 092/231] Add left padding support to tokenizers. (#10753) --- comfy/sd1_clip.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index f8a7c2a1b..3066de2d7 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -460,7 +460,7 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=No return embed_out class SDTokenizer: - def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, has_end_token=True, pad_to_max_length=True, min_length=None, pad_token=None, end_token=None, min_padding=None, tokenizer_data={}, tokenizer_args={}): + def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, has_end_token=True, pad_to_max_length=True, min_length=None, pad_token=None, end_token=None, min_padding=None, pad_left=False, tokenizer_data={}, tokenizer_args={}): if tokenizer_path is None: tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer") self.tokenizer = tokenizer_class.from_pretrained(tokenizer_path, **tokenizer_args) @@ -468,6 +468,7 @@ class SDTokenizer: self.min_length = tokenizer_data.get("{}_min_length".format(embedding_key), min_length) self.end_token = None self.min_padding = min_padding + self.pad_left = pad_left empty = self.tokenizer('')["input_ids"] self.tokenizer_adds_end_token = has_end_token @@ -522,6 +523,12 @@ class SDTokenizer: return (embed, "{} {}".format(embedding_name[len(stripped):], leftover)) return (embed, leftover) + def pad_tokens(self, tokens, amount): + if self.pad_left: + for i in range(amount): + tokens.insert(0, (self.pad_token, 1.0, 0)) + else: + tokens.extend([(self.pad_token, 1.0, 0)] * amount) def tokenize_with_weights(self, text:str, return_word_ids=False, tokenizer_options={}, **kwargs): ''' @@ -600,7 +607,7 @@ class SDTokenizer: if self.end_token is not None: batch.append((self.end_token, 1.0, 0)) if self.pad_to_max_length: - batch.extend([(self.pad_token, 1.0, 0)] * (remaining_length)) + self.pad_tokens(batch, remaining_length) #start new batch batch = [] if self.start_token is not None: @@ -614,11 +621,11 @@ class SDTokenizer: if self.end_token is not None: batch.append((self.end_token, 1.0, 0)) if min_padding is not None: - batch.extend([(self.pad_token, 1.0, 0)] * min_padding) + self.pad_tokens(batch, min_padding) if self.pad_to_max_length and len(batch) < self.max_length: - batch.extend([(self.pad_token, 1.0, 0)] * (self.max_length - len(batch))) + self.pad_tokens(batch, self.max_length - len(batch)) if min_length is not None and len(batch) < min_length: - batch.extend([(self.pad_token, 1.0, 0)] * (min_length - len(batch))) + self.pad_tokens(batch, min_length - len(batch)) if not return_word_ids: batched_tokens = [[(t, w) for t, w,_ in x] for x in batched_tokens] From 9a0238256873711bd38ce0e0b1d15a617a1ee454 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Sat, 15 Nov 2025 21:18:49 +0200 Subject: [PATCH 093/231] chore(api-nodes): mark OpenAIDalle2 and OpenAIDalle3 nodes as deprecated (#10757) --- comfy_api_nodes/nodes_openai.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/comfy_api_nodes/nodes_openai.py b/comfy_api_nodes/nodes_openai.py index acf35d276..e08bec08c 100644 --- a/comfy_api_nodes/nodes_openai.py +++ b/comfy_api_nodes/nodes_openai.py @@ -169,6 +169,7 @@ class OpenAIDalle2(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + is_deprecated=True, ) @classmethod @@ -299,6 +300,7 @@ class OpenAIDalle3(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + is_deprecated=True, ) @classmethod From 2d4a08b717c492fa45e98bd70beb48d4e77cb464 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Sat, 15 Nov 2025 22:37:34 +0200 Subject: [PATCH 094/231] Revert "chore(api-nodes): mark OpenAIDalle2 and OpenAIDalle3 nodes as deprecated (#10757)" (#10759) This reverts commit 9a0238256873711bd38ce0e0b1d15a617a1ee454. --- comfy_api_nodes/nodes_openai.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/comfy_api_nodes/nodes_openai.py b/comfy_api_nodes/nodes_openai.py index e08bec08c..acf35d276 100644 --- a/comfy_api_nodes/nodes_openai.py +++ b/comfy_api_nodes/nodes_openai.py @@ -169,7 +169,6 @@ class OpenAIDalle2(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, - is_deprecated=True, ) @classmethod @@ -300,7 +299,6 @@ class OpenAIDalle3(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, - is_deprecated=True, ) @classmethod From 7d6103325e1c97aa54f963253e3e7f1d6da6947f Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sun, 16 Nov 2025 00:01:14 -0800 Subject: [PATCH 095/231] Change ROCm nightly install command to 7.1 (#10764) --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index cd8273b0d..c0384099d 100644 --- a/README.md +++ b/README.md @@ -221,7 +221,7 @@ AMD users can install rocm and pytorch with pip if you don't have it already ins This is the command to install the nightly with ROCm 7.0 which might have some performance improvements: -```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm7.0``` +```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm7.1``` ### AMD GPUs (Experimental: Windows and Linux), RDNA 3, 3.5 and 4 only. From 3d0003c24c1aec9f0c021dbc70ffb7cd8cf0685c Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 17 Nov 2025 17:17:24 -0500 Subject: [PATCH 096/231] ComfyUI version 0.3.69 --- comfyui_version.py | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/comfyui_version.py b/comfyui_version.py index 25d1a4157..1e554eb9f 100644 --- a/comfyui_version.py +++ b/comfyui_version.py @@ -1,3 +1,3 @@ # This file is automatically generated by the build process when version is # updated in pyproject.toml. -__version__ = "0.3.68" +__version__ = "0.3.69" diff --git a/pyproject.toml b/pyproject.toml index 79ff3f74a..63778286f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ComfyUI" -version = "0.3.68" +version = "0.3.69" readme = "README.md" license = { file = "LICENSE" } requires-python = ">=3.9" From 27cbac865ec226cfd9c1563327b0d62cf5dbd484 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 17 Nov 2025 16:04:04 -0800 Subject: [PATCH 097/231] Add release workflow for NVIDIA cu126 (#10777) --- .github/workflows/release-stable-all.yml | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/.github/workflows/release-stable-all.yml b/.github/workflows/release-stable-all.yml index 7dca7277b..f7de3a7c3 100644 --- a/.github/workflows/release-stable-all.yml +++ b/.github/workflows/release-stable-all.yml @@ -43,6 +43,23 @@ jobs: test_release: true secrets: inherit + release_nvidia_cu126: + permissions: + contents: "write" + packages: "write" + pull-requests: "read" + name: "Release NVIDIA cu126" + uses: ./.github/workflows/stable-release.yml + with: + git_tag: ${{ inputs.git_tag }} + cache_tag: "cu126" + python_minor: "12" + python_patch: "10" + rel_name: "nvidia" + rel_extra_name: "_cu126" + test_release: true + secrets: inherit + release_amd_rocm: permissions: contents: "write" From f41e5f398d5d4059a3c87cf157bd932afcce3c0d Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 17 Nov 2025 16:59:19 -0800 Subject: [PATCH 098/231] Update README with new portable download link (#10778) --- README.md | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index c0384099d..323dfc587 100644 --- a/README.md +++ b/README.md @@ -183,7 +183,9 @@ Update your Nvidia drivers if it doesn't start. [Experimental portable for AMD GPUs](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_amd.7z) -[Portable with pytorch cuda 12.8 and python 3.12](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_nvidia_cu128.7z) (Supports Nvidia 10 series and older GPUs). +[Portable with pytorch cuda 12.8 and python 3.12](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_nvidia_cu128.7z). + +[Portable with pytorch cuda 12.6 and python 3.12](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_nvidia_cu128.7z) (Supports Nvidia 10 series and older GPUs). #### How do I share models between another UI and ComfyUI? From fdf49a28617f742d746ad209e57ed7420b3535dc Mon Sep 17 00:00:00 2001 From: ComfyUI Wiki Date: Tue, 18 Nov 2025 11:04:06 +0800 Subject: [PATCH 099/231] Fix the portable download link for CUDA 12.6 (#10780) --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 323dfc587..28beec427 100644 --- a/README.md +++ b/README.md @@ -185,7 +185,7 @@ Update your Nvidia drivers if it doesn't start. [Portable with pytorch cuda 12.8 and python 3.12](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_nvidia_cu128.7z). -[Portable with pytorch cuda 12.6 and python 3.12](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_nvidia_cu128.7z) (Supports Nvidia 10 series and older GPUs). +[Portable with pytorch cuda 12.6 and python 3.12](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_nvidia_cu126.7z) (Supports Nvidia 10 series and older GPUs). #### How do I share models between another UI and ComfyUI? From 47bfd5a33fa984a1102fc2bd7b25c91a69ace288 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 17 Nov 2025 21:26:44 -0800 Subject: [PATCH 100/231] Native block swap custom nodes considered harmful. (#10783) --- comfy_extras/nodes_nop.py | 39 +++++++++++++++++++++++++++++++++++++++ nodes.py | 1 + 2 files changed, 40 insertions(+) create mode 100644 comfy_extras/nodes_nop.py diff --git a/comfy_extras/nodes_nop.py b/comfy_extras/nodes_nop.py new file mode 100644 index 000000000..953061bcb --- /dev/null +++ b/comfy_extras/nodes_nop.py @@ -0,0 +1,39 @@ +from comfy_api.latest import ComfyExtension, io +from typing_extensions import override +# If you write a node that is so useless that it breaks ComfyUI it will be featured in this exclusive list + +# "native" block swap nodes are placebo at best and break the ComfyUI memory management system. +# They are also considered harmful because instead of users reporting issues with the built in +# memory management they install these stupid nodes and complain even harder. Now it completely +# breaks with some of the new ComfyUI memory optimizations so I have made the decision to NOP it +# out of all workflows. +class wanBlockSwap(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="wanBlockSwap", + category="", + description="NOP", + inputs=[ + io.Model.Input("model"), + ], + outputs=[ + io.Model.Output(), + ], + is_deprecated=True, + ) + + @classmethod + def execute(cls, model) -> io.NodeOutput: + return io.NodeOutput(model) + + +class NopExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + wanBlockSwap + ] + +async def comfy_entrypoint() -> NopExtension: + return NopExtension() diff --git a/nodes.py b/nodes.py index 5689f6fe1..f6aeedc78 100644 --- a/nodes.py +++ b/nodes.py @@ -2330,6 +2330,7 @@ async def init_builtin_extra_nodes(): "nodes_easycache.py", "nodes_audio_encoder.py", "nodes_rope.py", + "nodes_nop.py", ] import_failed = [] From 048f49adbd19ac2d9c7c87682c832b7827a4b29d Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Tue, 18 Nov 2025 13:59:27 +0200 Subject: [PATCH 101/231] chore(api-nodes): adjusted PR template; set min python version for pylint to 3.10 (#10787) --- .github/PULL_REQUEST_TEMPLATE/api-node.md | 2 +- .github/workflows/api-node-template.yml | 2 +- pyproject.toml | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/PULL_REQUEST_TEMPLATE/api-node.md b/.github/PULL_REQUEST_TEMPLATE/api-node.md index f62744878..c1f1bafb1 100644 --- a/.github/PULL_REQUEST_TEMPLATE/api-node.md +++ b/.github/PULL_REQUEST_TEMPLATE/api-node.md @@ -18,4 +18,4 @@ If **Need pricing update**: - [ ] **QA not required** ### Comms -- [ ] Informed **@Kosinkadink** +- [ ] Informed **Kosinkadink** diff --git a/.github/workflows/api-node-template.yml b/.github/workflows/api-node-template.yml index 0775f9979..fdb81c0c5 100644 --- a/.github/workflows/api-node-template.yml +++ b/.github/workflows/api-node-template.yml @@ -2,7 +2,7 @@ name: Append API Node PR template on: pull_request_target: - types: [opened, reopened, synchronize, edited, ready_for_review] + types: [opened, reopened, synchronize, ready_for_review] paths: - 'comfy_api_nodes/**' # only run if these files changed diff --git a/pyproject.toml b/pyproject.toml index 63778286f..a14b383b3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,7 +24,7 @@ lint.select = [ exclude = ["*.ipynb", "**/generated/*.pyi"] [tool.pylint] -master.py-version = "3.9" +master.py-version = "3.10" master.extension-pkg-allow-list = [ "pydantic", ] From e1ab6bb394b82fa654d5bc84043f97479d12f84c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jukka=20Sepp=C3=A4nen?= <40791699+kijai@users.noreply.github.com> Date: Tue, 18 Nov 2025 17:00:21 +0200 Subject: [PATCH 102/231] EasyCache: Fix for mismatch in input/output channels with some models (#10788) Slices model input with output channels so the caching tracks only the noise channels, resolves channel mismatch with models like WanVideo I2V Also fix for slicing deprecation in pytorch 2.9 --- comfy_extras/nodes_easycache.py | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/comfy_extras/nodes_easycache.py b/comfy_extras/nodes_easycache.py index 1359e2f99..11b23ffdb 100644 --- a/comfy_extras/nodes_easycache.py +++ b/comfy_extras/nodes_easycache.py @@ -11,13 +11,13 @@ if TYPE_CHECKING: def easycache_forward_wrapper(executor, *args, **kwargs): # get values from args - x: torch.Tensor = args[0] transformer_options: dict[str] = args[-1] if not isinstance(transformer_options, dict): transformer_options = kwargs.get("transformer_options") if not transformer_options: transformer_options = args[-2] easycache: EasyCacheHolder = transformer_options["easycache"] + x: torch.Tensor = args[0][:, :easycache.output_channels] sigmas = transformer_options["sigmas"] uuids = transformer_options["uuids"] if sigmas is not None and easycache.is_past_end_timestep(sigmas): @@ -82,13 +82,13 @@ def easycache_forward_wrapper(executor, *args, **kwargs): def lazycache_predict_noise_wrapper(executor, *args, **kwargs): # get values from args - x: torch.Tensor = args[0] timestep: float = args[1] model_options: dict[str] = args[2] easycache: LazyCacheHolder = model_options["transformer_options"]["easycache"] if easycache.is_past_end_timestep(timestep): return executor(*args, **kwargs) # prepare next x_prev + x: torch.Tensor = args[0][:, :easycache.output_channels] next_x_prev = x input_change = None do_easycache = easycache.should_do_easycache(timestep) @@ -173,7 +173,7 @@ def easycache_sample_wrapper(executor, *args, **kwargs): class EasyCacheHolder: - def __init__(self, reuse_threshold: float, start_percent: float, end_percent: float, subsample_factor: int, offload_cache_diff: bool, verbose: bool=False): + def __init__(self, reuse_threshold: float, start_percent: float, end_percent: float, subsample_factor: int, offload_cache_diff: bool, verbose: bool=False, output_channels: int=None): self.name = "EasyCache" self.reuse_threshold = reuse_threshold self.start_percent = start_percent @@ -202,6 +202,7 @@ class EasyCacheHolder: self.allow_mismatch = True self.cut_from_start = True self.state_metadata = None + self.output_channels = output_channels def is_past_end_timestep(self, timestep: float) -> bool: return not (timestep[0] > self.end_t).item() @@ -264,7 +265,7 @@ class EasyCacheHolder: else: slicing.append(slice(None)) batch_slice = batch_slice + slicing - x[batch_slice] += self.uuid_cache_diffs[uuid].to(x.device) + x[tuple(batch_slice)] += self.uuid_cache_diffs[uuid].to(x.device) return x def update_cache_diff(self, output: torch.Tensor, x: torch.Tensor, uuids: list[UUID]): @@ -283,7 +284,7 @@ class EasyCacheHolder: else: slicing.append(slice(None)) skip_dim = False - x = x[slicing] + x = x[tuple(slicing)] diff = output - x batch_offset = diff.shape[0] // len(uuids) for i, uuid in enumerate(uuids): @@ -323,7 +324,7 @@ class EasyCacheHolder: return self def clone(self): - return EasyCacheHolder(self.reuse_threshold, self.start_percent, self.end_percent, self.subsample_factor, self.offload_cache_diff, self.verbose) + return EasyCacheHolder(self.reuse_threshold, self.start_percent, self.end_percent, self.subsample_factor, self.offload_cache_diff, self.verbose, output_channels=self.output_channels) class EasyCacheNode(io.ComfyNode): @@ -350,7 +351,7 @@ class EasyCacheNode(io.ComfyNode): @classmethod def execute(cls, model: io.Model.Type, reuse_threshold: float, start_percent: float, end_percent: float, verbose: bool) -> io.NodeOutput: model = model.clone() - model.model_options["transformer_options"]["easycache"] = EasyCacheHolder(reuse_threshold, start_percent, end_percent, subsample_factor=8, offload_cache_diff=False, verbose=verbose) + model.model_options["transformer_options"]["easycache"] = EasyCacheHolder(reuse_threshold, start_percent, end_percent, subsample_factor=8, offload_cache_diff=False, verbose=verbose, output_channels=model.model.latent_format.latent_channels) model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.OUTER_SAMPLE, "easycache", easycache_sample_wrapper) model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.CALC_COND_BATCH, "easycache", easycache_calc_cond_batch_wrapper) model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, "easycache", easycache_forward_wrapper) @@ -358,7 +359,7 @@ class EasyCacheNode(io.ComfyNode): class LazyCacheHolder: - def __init__(self, reuse_threshold: float, start_percent: float, end_percent: float, subsample_factor: int, offload_cache_diff: bool, verbose: bool=False): + def __init__(self, reuse_threshold: float, start_percent: float, end_percent: float, subsample_factor: int, offload_cache_diff: bool, verbose: bool=False, output_channels: int=None): self.name = "LazyCache" self.reuse_threshold = reuse_threshold self.start_percent = start_percent @@ -382,6 +383,7 @@ class LazyCacheHolder: self.approx_output_change_rates = [] self.total_steps_skipped = 0 self.state_metadata = None + self.output_channels = output_channels def has_cache_diff(self) -> bool: return self.cache_diff is not None @@ -456,7 +458,7 @@ class LazyCacheHolder: return self def clone(self): - return LazyCacheHolder(self.reuse_threshold, self.start_percent, self.end_percent, self.subsample_factor, self.offload_cache_diff, self.verbose) + return LazyCacheHolder(self.reuse_threshold, self.start_percent, self.end_percent, self.subsample_factor, self.offload_cache_diff, self.verbose, output_channels=self.output_channels) class LazyCacheNode(io.ComfyNode): @classmethod @@ -482,7 +484,7 @@ class LazyCacheNode(io.ComfyNode): @classmethod def execute(cls, model: io.Model.Type, reuse_threshold: float, start_percent: float, end_percent: float, verbose: bool) -> io.NodeOutput: model = model.clone() - model.model_options["transformer_options"]["easycache"] = LazyCacheHolder(reuse_threshold, start_percent, end_percent, subsample_factor=8, offload_cache_diff=False, verbose=verbose) + model.model_options["transformer_options"]["easycache"] = LazyCacheHolder(reuse_threshold, start_percent, end_percent, subsample_factor=8, offload_cache_diff=False, verbose=verbose, output_channels=model.model.latent_format.latent_channels) model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.OUTER_SAMPLE, "lazycache", easycache_sample_wrapper) model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.PREDICT_NOISE, "lazycache", lazycache_predict_noise_wrapper) return io.NodeOutput(model) From d52697457608a045cafc3b6d6cb89f0a49ba0709 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 18 Nov 2025 13:46:19 -0800 Subject: [PATCH 103/231] Fix hunyuan 3d 2.0 (#10792) --- comfy/ldm/flux/math.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/comfy/ldm/flux/math.py b/comfy/ldm/flux/math.py index 158420290..6a22df8bc 100644 --- a/comfy/ldm/flux/math.py +++ b/comfy/ldm/flux/math.py @@ -7,7 +7,8 @@ import comfy.model_management def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None, transformer_options={}) -> Tensor: - q, k = apply_rope(q, k, pe) + if pe is not None: + q, k = apply_rope(q, k, pe) heads = q.shape[1] x = optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask, transformer_options=transformer_options) return x From 24fdb92edf2e96fe757c480aa7f12be5bdfa3a15 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Wed, 19 Nov 2025 00:26:44 +0200 Subject: [PATCH 104/231] feat(api-nodes): add new Gemini model (#10789) --- comfy_api_nodes/apis/gemini_api.py | 231 +++++++++++++++++++++++++++-- comfy_api_nodes/nodes_gemini.py | 47 +++--- 2 files changed, 246 insertions(+), 32 deletions(-) diff --git a/comfy_api_nodes/apis/gemini_api.py b/comfy_api_nodes/apis/gemini_api.py index 2bf28bf93..f63e02693 100644 --- a/comfy_api_nodes/apis/gemini_api.py +++ b/comfy_api_nodes/apis/gemini_api.py @@ -1,22 +1,229 @@ -from typing import Optional +from datetime import date +from enum import Enum +from typing import Any -from comfy_api_nodes.apis import GeminiGenerationConfig, GeminiContent, GeminiSafetySetting, GeminiSystemInstructionContent, GeminiTool, GeminiVideoMetadata -from pydantic import BaseModel +from pydantic import BaseModel, Field + + +class GeminiSafetyCategory(str, Enum): + HARM_CATEGORY_SEXUALLY_EXPLICIT = "HARM_CATEGORY_SEXUALLY_EXPLICIT" + HARM_CATEGORY_HATE_SPEECH = "HARM_CATEGORY_HATE_SPEECH" + HARM_CATEGORY_HARASSMENT = "HARM_CATEGORY_HARASSMENT" + HARM_CATEGORY_DANGEROUS_CONTENT = "HARM_CATEGORY_DANGEROUS_CONTENT" + + +class GeminiSafetyThreshold(str, Enum): + OFF = "OFF" + BLOCK_NONE = "BLOCK_NONE" + BLOCK_LOW_AND_ABOVE = "BLOCK_LOW_AND_ABOVE" + BLOCK_MEDIUM_AND_ABOVE = "BLOCK_MEDIUM_AND_ABOVE" + BLOCK_ONLY_HIGH = "BLOCK_ONLY_HIGH" + + +class GeminiSafetySetting(BaseModel): + category: GeminiSafetyCategory + threshold: GeminiSafetyThreshold + + +class GeminiRole(str, Enum): + user = "user" + model = "model" + + +class GeminiMimeType(str, Enum): + application_pdf = "application/pdf" + audio_mpeg = "audio/mpeg" + audio_mp3 = "audio/mp3" + audio_wav = "audio/wav" + image_png = "image/png" + image_jpeg = "image/jpeg" + image_webp = "image/webp" + text_plain = "text/plain" + video_mov = "video/mov" + video_mpeg = "video/mpeg" + video_mp4 = "video/mp4" + video_mpg = "video/mpg" + video_avi = "video/avi" + video_wmv = "video/wmv" + video_mpegps = "video/mpegps" + video_flv = "video/flv" + + +class GeminiInlineData(BaseModel): + data: str | None = Field( + None, + description="The base64 encoding of the image, PDF, or video to include inline in the prompt. " + "When including media inline, you must also specify the media type (mimeType) of the data. Size limit: 20MB", + ) + mimeType: GeminiMimeType | None = Field(None) + + +class GeminiPart(BaseModel): + inlineData: GeminiInlineData | None = Field(None) + text: str | None = Field(None) + + +class GeminiTextPart(BaseModel): + text: str | None = Field(None) + + +class GeminiContent(BaseModel): + parts: list[GeminiPart] = Field(...) + role: GeminiRole = Field(..., examples=["user"]) + + +class GeminiSystemInstructionContent(BaseModel): + parts: list[GeminiTextPart] = Field( + ..., + description="A list of ordered parts that make up a single message. " + "Different parts may have different IANA MIME types.", + ) + role: GeminiRole = Field( + ..., + description="The identity of the entity that creates the message. " + "The following values are supported: " + "user: This indicates that the message is sent by a real person, typically a user-generated message. " + "model: This indicates that the message is generated by the model. " + "The model value is used to insert messages from model into the conversation during multi-turn conversations. " + "For non-multi-turn conversations, this field can be left blank or unset.", + ) + + +class GeminiFunctionDeclaration(BaseModel): + description: str | None = Field(None) + name: str = Field(...) + parameters: dict[str, Any] = Field(..., description="JSON schema for the function parameters") + + +class GeminiTool(BaseModel): + functionDeclarations: list[GeminiFunctionDeclaration] | None = Field(None) + + +class GeminiOffset(BaseModel): + nanos: int | None = Field(None, ge=0, le=999999999) + seconds: int | None = Field(None, ge=-315576000000, le=315576000000) + + +class GeminiVideoMetadata(BaseModel): + endOffset: GeminiOffset | None = Field(None) + startOffset: GeminiOffset | None = Field(None) + + +class GeminiGenerationConfig(BaseModel): + maxOutputTokens: int | None = Field(None, ge=16, le=8192) + seed: int | None = Field(None) + stopSequences: list[str] | None = Field(None) + temperature: float | None = Field(1, ge=0.0, le=2.0) + topK: int | None = Field(40, ge=1) + topP: float | None = Field(0.95, ge=0.0, le=1.0) class GeminiImageConfig(BaseModel): - aspectRatio: Optional[str] = None + aspectRatio: str | None = Field(None) + resolution: str | None = Field(None) class GeminiImageGenerationConfig(GeminiGenerationConfig): - responseModalities: Optional[list[str]] = None - imageConfig: Optional[GeminiImageConfig] = None + responseModalities: list[str] | None = Field(None) + imageConfig: GeminiImageConfig | None = Field(None) class GeminiImageGenerateContentRequest(BaseModel): - contents: list[GeminiContent] - generationConfig: Optional[GeminiImageGenerationConfig] = None - safetySettings: Optional[list[GeminiSafetySetting]] = None - systemInstruction: Optional[GeminiSystemInstructionContent] = None - tools: Optional[list[GeminiTool]] = None - videoMetadata: Optional[GeminiVideoMetadata] = None + contents: list[GeminiContent] = Field(...) + generationConfig: GeminiImageGenerationConfig | None = Field(None) + safetySettings: list[GeminiSafetySetting] | None = Field(None) + systemInstruction: GeminiSystemInstructionContent | None = Field(None) + tools: list[GeminiTool] | None = Field(None) + videoMetadata: GeminiVideoMetadata | None = Field(None) + + +class GeminiGenerateContentRequest(BaseModel): + contents: list[GeminiContent] = Field(...) + generationConfig: GeminiGenerationConfig | None = Field(None) + safetySettings: list[GeminiSafetySetting] | None = Field(None) + systemInstruction: GeminiSystemInstructionContent | None = Field(None) + tools: list[GeminiTool] | None = Field(None) + videoMetadata: GeminiVideoMetadata | None = Field(None) + + +class Modality(str, Enum): + MODALITY_UNSPECIFIED = "MODALITY_UNSPECIFIED" + TEXT = "TEXT" + IMAGE = "IMAGE" + VIDEO = "VIDEO" + AUDIO = "AUDIO" + DOCUMENT = "DOCUMENT" + + +class ModalityTokenCount(BaseModel): + modality: Modality | None = None + tokenCount: int | None = Field(None, description="Number of tokens for the given modality.") + + +class Probability(str, Enum): + NEGLIGIBLE = "NEGLIGIBLE" + LOW = "LOW" + MEDIUM = "MEDIUM" + HIGH = "HIGH" + UNKNOWN = "UNKNOWN" + + +class GeminiSafetyRating(BaseModel): + category: GeminiSafetyCategory | None = None + probability: Probability | None = Field( + None, + description="The probability that the content violates the specified safety category", + ) + + +class GeminiCitation(BaseModel): + authors: list[str] | None = None + endIndex: int | None = None + license: str | None = None + publicationDate: date | None = None + startIndex: int | None = None + title: str | None = None + uri: str | None = None + + +class GeminiCitationMetadata(BaseModel): + citations: list[GeminiCitation] | None = None + + +class GeminiCandidate(BaseModel): + citationMetadata: GeminiCitationMetadata | None = None + content: GeminiContent | None = None + finishReason: str | None = None + safetyRatings: list[GeminiSafetyRating] | None = None + + +class GeminiPromptFeedback(BaseModel): + blockReason: str | None = None + blockReasonMessage: str | None = None + safetyRatings: list[GeminiSafetyRating] | None = None + + +class GeminiUsageMetadata(BaseModel): + cachedContentTokenCount: int | None = Field( + None, + description="Output only. Number of tokens in the cached part in the input (the cached content).", + ) + candidatesTokenCount: int | None = Field(None, description="Number of tokens in the response(s).") + candidatesTokensDetails: list[ModalityTokenCount] | None = Field( + None, description="Breakdown of candidate tokens by modality." + ) + promptTokenCount: int | None = Field( + None, + description="Number of tokens in the request. When cachedContent is set, this is still the total effective prompt size meaning this includes the number of tokens in the cached content.", + ) + promptTokensDetails: list[ModalityTokenCount] | None = Field( + None, description="Breakdown of prompt tokens by modality." + ) + thoughtsTokenCount: int | None = Field(None, description="Number of tokens present in thoughts output.") + toolUsePromptTokenCount: int | None = Field(None, description="Number of tokens present in tool-use prompt(s).") + + +class GeminiGenerateContentResponse(BaseModel): + candidates: list[GeminiCandidate] | None = Field(None) + promptFeedback: GeminiPromptFeedback | None = Field(None) + usageMetadata: GeminiUsageMetadata | None = Field(None) diff --git a/comfy_api_nodes/nodes_gemini.py b/comfy_api_nodes/nodes_gemini.py index 67f2469ad..6e746eebd 100644 --- a/comfy_api_nodes/nodes_gemini.py +++ b/comfy_api_nodes/nodes_gemini.py @@ -3,8 +3,6 @@ API Nodes for Gemini Multimodal LLM Usage via Remote API See: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/inference """ -from __future__ import annotations - import base64 import json import os @@ -12,7 +10,7 @@ import time import uuid from enum import Enum from io import BytesIO -from typing import Literal, Optional +from typing import Literal import torch from typing_extensions import override @@ -20,18 +18,17 @@ from typing_extensions import override import folder_paths from comfy_api.latest import IO, ComfyExtension, Input from comfy_api.util import VideoCodec, VideoContainer -from comfy_api_nodes.apis import ( +from comfy_api_nodes.apis.gemini_api import ( GeminiContent, GeminiGenerateContentRequest, GeminiGenerateContentResponse, - GeminiInlineData, - GeminiMimeType, - GeminiPart, -) -from comfy_api_nodes.apis.gemini_api import ( GeminiImageConfig, GeminiImageGenerateContentRequest, GeminiImageGenerationConfig, + GeminiInlineData, + GeminiMimeType, + GeminiPart, + GeminiRole, ) from comfy_api_nodes.util import ( ApiEndpoint, @@ -57,6 +54,7 @@ class GeminiModel(str, Enum): gemini_2_5_flash_preview_04_17 = "gemini-2.5-flash-preview-04-17" gemini_2_5_pro = "gemini-2.5-pro" gemini_2_5_flash = "gemini-2.5-flash" + gemini_3_0_pro = "gemini-3-pro-preview" class GeminiImageModel(str, Enum): @@ -103,6 +101,16 @@ def get_parts_by_type(response: GeminiGenerateContentResponse, part_type: Litera Returns: List of response parts matching the requested type. """ + if response.candidates is None: + if response.promptFeedback.blockReason: + feedback = response.promptFeedback + raise ValueError( + f"Gemini API blocked the request. Reason: {feedback.blockReason} ({feedback.blockReasonMessage})" + ) + raise NotImplementedError( + "Gemini returned no response candidates. " + "Please report to ComfyUI repository with the example of workflow to reproduce this." + ) parts = [] for part in response.candidates[0].content.parts: if part_type == "text" and hasattr(part, "text") and part.text: @@ -272,10 +280,10 @@ class GeminiNode(IO.ComfyNode): prompt: str, model: str, seed: int, - images: Optional[torch.Tensor] = None, - audio: Optional[Input.Audio] = None, - video: Optional[Input.Video] = None, - files: Optional[list[GeminiPart]] = None, + images: torch.Tensor | None = None, + audio: Input.Audio | None = None, + video: Input.Video | None = None, + files: list[GeminiPart] | None = None, ) -> IO.NodeOutput: validate_string(prompt, strip_whitespace=False) @@ -300,7 +308,7 @@ class GeminiNode(IO.ComfyNode): data=GeminiGenerateContentRequest( contents=[ GeminiContent( - role="user", + role=GeminiRole.user, parts=parts, ) ] @@ -308,7 +316,6 @@ class GeminiNode(IO.ComfyNode): response_model=GeminiGenerateContentResponse, ) - # Get result output output_text = get_text_from_response(response) if output_text: # Not a true chat history like the OpenAI Chat node. It is emulated so the frontend can show a copy button. @@ -406,7 +413,7 @@ class GeminiInputFiles(IO.ComfyNode): ) @classmethod - def execute(cls, file: str, GEMINI_INPUT_FILES: Optional[list[GeminiPart]] = None) -> IO.NodeOutput: + def execute(cls, file: str, GEMINI_INPUT_FILES: list[GeminiPart] | None = None) -> IO.NodeOutput: """Loads and formats input files for Gemini API.""" if GEMINI_INPUT_FILES is None: GEMINI_INPUT_FILES = [] @@ -421,7 +428,7 @@ class GeminiImage(IO.ComfyNode): def define_schema(cls): return IO.Schema( node_id="GeminiImageNode", - display_name="Google Gemini Image", + display_name="Nano Banana (Google Gemini Image)", category="api node/image/Gemini", description="Edit images synchronously via Google API.", inputs=[ @@ -488,8 +495,8 @@ class GeminiImage(IO.ComfyNode): prompt: str, model: str, seed: int, - images: Optional[torch.Tensor] = None, - files: Optional[list[GeminiPart]] = None, + images: torch.Tensor | None = None, + files: list[GeminiPart] | None = None, aspect_ratio: str = "auto", ) -> IO.NodeOutput: validate_string(prompt, strip_whitespace=True, min_length=1) @@ -510,7 +517,7 @@ class GeminiImage(IO.ComfyNode): endpoint=ApiEndpoint(path=f"{GEMINI_BASE_ENDPOINT}/{model}", method="POST"), data=GeminiImageGenerateContentRequest( contents=[ - GeminiContent(role="user", parts=parts), + GeminiContent(role=GeminiRole.user, parts=parts), ], generationConfig=GeminiImageGenerationConfig( responseModalities=["TEXT", "IMAGE"], From b5c8be8b1db44ded07cb1b437b9f33ebff5848c1 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 18 Nov 2025 19:37:20 -0500 Subject: [PATCH 105/231] ComfyUI 0.3.70 --- comfyui_version.py | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/comfyui_version.py b/comfyui_version.py index 1e554eb9f..9b77aabe9 100644 --- a/comfyui_version.py +++ b/comfyui_version.py @@ -1,3 +1,3 @@ # This file is automatically generated by the build process when version is # updated in pyproject.toml. -__version__ = "0.3.69" +__version__ = "0.3.70" diff --git a/pyproject.toml b/pyproject.toml index a14b383b3..289b7145b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ComfyUI" -version = "0.3.69" +version = "0.3.70" readme = "README.md" license = { file = "LICENSE" } requires-python = ">=3.9" From 17027f2a6a20a31e2c6f3be2b1a06f39ad3a68d9 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 18 Nov 2025 19:36:03 -0800 Subject: [PATCH 106/231] Add a way to disable the final norm in the llama based TE models. (#10794) --- comfy/text_encoders/llama.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/comfy/text_encoders/llama.py b/comfy/text_encoders/llama.py index c050759fe..feb44bbb0 100644 --- a/comfy/text_encoders/llama.py +++ b/comfy/text_encoders/llama.py @@ -32,6 +32,7 @@ class Llama2Config: q_norm = None k_norm = None rope_scale = None + final_norm: bool = True @dataclass class Qwen25_3BConfig: @@ -53,6 +54,7 @@ class Qwen25_3BConfig: q_norm = None k_norm = None rope_scale = None + final_norm: bool = True @dataclass class Qwen25_7BVLI_Config: @@ -74,6 +76,7 @@ class Qwen25_7BVLI_Config: q_norm = None k_norm = None rope_scale = None + final_norm: bool = True @dataclass class Gemma2_2B_Config: @@ -96,6 +99,7 @@ class Gemma2_2B_Config: k_norm = None sliding_attention = None rope_scale = None + final_norm: bool = True @dataclass class Gemma3_4B_Config: @@ -118,6 +122,7 @@ class Gemma3_4B_Config: k_norm = "gemma3" sliding_attention = [False, False, False, False, False, 1024] rope_scale = [1.0, 8.0] + final_norm: bool = True class RMSNorm(nn.Module): def __init__(self, dim: int, eps: float = 1e-5, add=False, device=None, dtype=None): @@ -366,7 +371,12 @@ class Llama2_(nn.Module): transformer(config, index=i, device=device, dtype=dtype, ops=ops) for i in range(config.num_hidden_layers) ]) - self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype) + + if config.final_norm: + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype) + else: + self.norm = None + # self.lm_head = ops.Linear(config.hidden_size, config.vocab_size, bias=False, device=device, dtype=dtype) def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, position_ids=None, embeds_info=[]): @@ -421,14 +431,16 @@ class Llama2_(nn.Module): if i == intermediate_output: intermediate = x.clone() - x = self.norm(x) + if self.norm is not None: + x = self.norm(x) + if all_intermediate is not None: all_intermediate.append(x.unsqueeze(1).clone()) if all_intermediate is not None: intermediate = torch.cat(all_intermediate, dim=1) - if intermediate is not None and final_layer_norm_intermediate: + if intermediate is not None and final_layer_norm_intermediate and self.norm is not None: intermediate = self.norm(intermediate) return x, intermediate From 65ee24c9789b93660ebe978a3186486f105298c2 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Wed, 19 Nov 2025 11:25:28 +0200 Subject: [PATCH 107/231] change display name of PreviewAny node to "Preview as Text" (#10796) --- comfy_extras/nodes_preview_any.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy_extras/nodes_preview_any.py b/comfy_extras/nodes_preview_any.py index e749fa6ae..139b07c93 100644 --- a/comfy_extras/nodes_preview_any.py +++ b/comfy_extras/nodes_preview_any.py @@ -39,5 +39,5 @@ NODE_CLASS_MAPPINGS = { } NODE_DISPLAY_NAME_MAPPINGS = { - "PreviewAny": "Preview Any", + "PreviewAny": "Preview as Text", } From 6a1d3a1ae131f3fff7f45a7e835eb10e9d1338ee Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Thu, 20 Nov 2025 00:49:01 +0200 Subject: [PATCH 108/231] convert hunyuan3d.py to V3 schema (#10664) --- comfy_api/latest/__init__.py | 4 +- comfy_api/latest/_io.py | 5 +- comfy_api/latest/_util/__init__.py | 3 + comfy_api/latest/_util/geometry_types.py | 12 + comfy_extras/nodes_hunyuan3d.py | 274 +++++++++++++---------- 5 files changed, 178 insertions(+), 120 deletions(-) create mode 100644 comfy_api/latest/_util/geometry_types.py diff --git a/comfy_api/latest/__init__.py b/comfy_api/latest/__init__.py index b7a3fa9c1..176ae36e0 100644 --- a/comfy_api/latest/__init__.py +++ b/comfy_api/latest/__init__.py @@ -7,7 +7,7 @@ from comfy_api.internal.singleton import ProxiedSingleton from comfy_api.internal.async_to_sync import create_sync_class from comfy_api.latest._input import ImageInput, AudioInput, MaskInput, LatentInput, VideoInput from comfy_api.latest._input_impl import VideoFromFile, VideoFromComponents -from comfy_api.latest._util import VideoCodec, VideoContainer, VideoComponents +from comfy_api.latest._util import VideoCodec, VideoContainer, VideoComponents, MESH, VOXEL from . import _io as io from . import _ui as ui # from comfy_api.latest._resources import _RESOURCES as resources #noqa: F401 @@ -104,6 +104,8 @@ class Types: VideoCodec = VideoCodec VideoContainer = VideoContainer VideoComponents = VideoComponents + MESH = MESH + VOXEL = VOXEL ComfyAPI = ComfyAPI_latest diff --git a/comfy_api/latest/_io.py b/comfy_api/latest/_io.py index 0b701260f..863254ce7 100644 --- a/comfy_api/latest/_io.py +++ b/comfy_api/latest/_io.py @@ -27,6 +27,7 @@ from comfy_api.internal import (_ComfyNodeInternal, _NodeOutputInternal, classpr prune_dict, shallow_clone_class) from comfy_api.latest._resources import Resources, ResourcesLocal from comfy_execution.graph_utils import ExecutionBlocker +from ._util import MESH, VOXEL # from comfy_extras.nodes_images import SVG as SVG_ # NOTE: needs to be moved before can be imported due to circular reference @@ -656,11 +657,11 @@ class LossMap(ComfyTypeIO): @comfytype(io_type="VOXEL") class Voxel(ComfyTypeIO): - Type = Any # TODO: VOXEL class is defined in comfy_extras/nodes_hunyuan3d.py; should be moved to somewhere else before referenced directly in v3 + Type = VOXEL @comfytype(io_type="MESH") class Mesh(ComfyTypeIO): - Type = Any # TODO: MESH class is defined in comfy_extras/nodes_hunyuan3d.py; should be moved to somewhere else before referenced directly in v3 + Type = MESH @comfytype(io_type="HOOKS") class Hooks(ComfyTypeIO): diff --git a/comfy_api/latest/_util/__init__.py b/comfy_api/latest/_util/__init__.py index 9019c46db..fc5431dda 100644 --- a/comfy_api/latest/_util/__init__.py +++ b/comfy_api/latest/_util/__init__.py @@ -1,8 +1,11 @@ from .video_types import VideoContainer, VideoCodec, VideoComponents +from .geometry_types import VOXEL, MESH __all__ = [ # Utility Types "VideoContainer", "VideoCodec", "VideoComponents", + "VOXEL", + "MESH", ] diff --git a/comfy_api/latest/_util/geometry_types.py b/comfy_api/latest/_util/geometry_types.py new file mode 100644 index 000000000..385122778 --- /dev/null +++ b/comfy_api/latest/_util/geometry_types.py @@ -0,0 +1,12 @@ +import torch + + +class VOXEL: + def __init__(self, data: torch.Tensor): + self.data = data + + +class MESH: + def __init__(self, vertices: torch.Tensor, faces: torch.Tensor): + self.vertices = vertices + self.faces = faces diff --git a/comfy_extras/nodes_hunyuan3d.py b/comfy_extras/nodes_hunyuan3d.py index f6e71e0a8..adca14f62 100644 --- a/comfy_extras/nodes_hunyuan3d.py +++ b/comfy_extras/nodes_hunyuan3d.py @@ -7,63 +7,79 @@ from comfy.ldm.modules.diffusionmodules.mmdit import get_1d_sincos_pos_embed_fro import folder_paths import comfy.model_management from comfy.cli_args import args +from typing_extensions import override +from comfy_api.latest import ComfyExtension, IO, Types +from comfy_api.latest._util import MESH, VOXEL # only for backward compatibility if someone import it from this file (will be removed later) # noqa -class EmptyLatentHunyuan3Dv2: + +class EmptyLatentHunyuan3Dv2(IO.ComfyNode): @classmethod - def INPUT_TYPES(s): - return { - "required": { - "resolution": ("INT", {"default": 3072, "min": 1, "max": 8192}), - "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096, "tooltip": "The number of latent images in the batch."}), - } - } + def define_schema(cls): + return IO.Schema( + node_id="EmptyLatentHunyuan3Dv2", + category="latent/3d", + inputs=[ + IO.Int.Input("resolution", default=3072, min=1, max=8192), + IO.Int.Input("batch_size", default=1, min=1, max=4096, tooltip="The number of latent images in the batch."), + ], + outputs=[ + IO.Latent.Output(), + ] + ) - RETURN_TYPES = ("LATENT",) - FUNCTION = "generate" - - CATEGORY = "latent/3d" - - def generate(self, resolution, batch_size): + @classmethod + def execute(cls, resolution, batch_size) -> IO.NodeOutput: latent = torch.zeros([batch_size, 64, resolution], device=comfy.model_management.intermediate_device()) - return ({"samples": latent, "type": "hunyuan3dv2"}, ) + return IO.NodeOutput({"samples": latent, "type": "hunyuan3dv2"}) -class Hunyuan3Dv2Conditioning: + generate = execute # TODO: remove + + +class Hunyuan3Dv2Conditioning(IO.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": {"clip_vision_output": ("CLIP_VISION_OUTPUT",), - }} + def define_schema(cls): + return IO.Schema( + node_id="Hunyuan3Dv2Conditioning", + category="conditioning/video_models", + inputs=[ + IO.ClipVisionOutput.Input("clip_vision_output"), + ], + outputs=[ + IO.Conditioning.Output(display_name="positive"), + IO.Conditioning.Output(display_name="negative"), + ] + ) - RETURN_TYPES = ("CONDITIONING", "CONDITIONING") - RETURN_NAMES = ("positive", "negative") - - FUNCTION = "encode" - - CATEGORY = "conditioning/video_models" - - def encode(self, clip_vision_output): + @classmethod + def execute(cls, clip_vision_output) -> IO.NodeOutput: embeds = clip_vision_output.last_hidden_state positive = [[embeds, {}]] negative = [[torch.zeros_like(embeds), {}]] - return (positive, negative) + return IO.NodeOutput(positive, negative) + + encode = execute # TODO: remove -class Hunyuan3Dv2ConditioningMultiView: +class Hunyuan3Dv2ConditioningMultiView(IO.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": {}, - "optional": {"front": ("CLIP_VISION_OUTPUT",), - "left": ("CLIP_VISION_OUTPUT",), - "back": ("CLIP_VISION_OUTPUT",), - "right": ("CLIP_VISION_OUTPUT",), }} + def define_schema(cls): + return IO.Schema( + node_id="Hunyuan3Dv2ConditioningMultiView", + category="conditioning/video_models", + inputs=[ + IO.ClipVisionOutput.Input("front", optional=True), + IO.ClipVisionOutput.Input("left", optional=True), + IO.ClipVisionOutput.Input("back", optional=True), + IO.ClipVisionOutput.Input("right", optional=True), + ], + outputs=[ + IO.Conditioning.Output(display_name="positive"), + IO.Conditioning.Output(display_name="negative"), + ] + ) - RETURN_TYPES = ("CONDITIONING", "CONDITIONING") - RETURN_NAMES = ("positive", "negative") - - FUNCTION = "encode" - - CATEGORY = "conditioning/video_models" - - def encode(self, front=None, left=None, back=None, right=None): + @classmethod + def execute(cls, front=None, left=None, back=None, right=None) -> IO.NodeOutput: all_embeds = [front, left, back, right] out = [] pos_embeds = None @@ -76,29 +92,35 @@ class Hunyuan3Dv2ConditioningMultiView: embeds = torch.cat(out, dim=1) positive = [[embeds, {}]] negative = [[torch.zeros_like(embeds), {}]] - return (positive, negative) + return IO.NodeOutput(positive, negative) + + encode = execute # TODO: remove -class VOXEL: - def __init__(self, data): - self.data = data - -class VAEDecodeHunyuan3D: +class VAEDecodeHunyuan3D(IO.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": {"samples": ("LATENT", ), - "vae": ("VAE", ), - "num_chunks": ("INT", {"default": 8000, "min": 1000, "max": 500000}), - "octree_resolution": ("INT", {"default": 256, "min": 16, "max": 512}), - }} - RETURN_TYPES = ("VOXEL",) - FUNCTION = "decode" + def define_schema(cls): + return IO.Schema( + node_id="VAEDecodeHunyuan3D", + category="latent/3d", + inputs=[ + IO.Latent.Input("samples"), + IO.Vae.Input("vae"), + IO.Int.Input("num_chunks", default=8000, min=1000, max=500000), + IO.Int.Input("octree_resolution", default=256, min=16, max=512), + ], + outputs=[ + IO.Voxel.Output(), + ] + ) - CATEGORY = "latent/3d" + @classmethod + def execute(cls, vae, samples, num_chunks, octree_resolution) -> IO.NodeOutput: + voxels = Types.VOXEL(vae.decode(samples["samples"], vae_options={"num_chunks": num_chunks, "octree_resolution": octree_resolution})) + return IO.NodeOutput(voxels) + + decode = execute # TODO: remove - def decode(self, vae, samples, num_chunks, octree_resolution): - voxels = VOXEL(vae.decode(samples["samples"], vae_options={"num_chunks": num_chunks, "octree_resolution": octree_resolution})) - return (voxels, ) def voxel_to_mesh(voxels, threshold=0.5, device=None): if device is None: @@ -396,24 +418,24 @@ def voxel_to_mesh_surfnet(voxels, threshold=0.5, device=None): return final_vertices, faces -class MESH: - def __init__(self, vertices, faces): - self.vertices = vertices - self.faces = faces - -class VoxelToMeshBasic: +class VoxelToMeshBasic(IO.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": {"voxel": ("VOXEL", ), - "threshold": ("FLOAT", {"default": 0.6, "min": -1.0, "max": 1.0, "step": 0.01}), - }} - RETURN_TYPES = ("MESH",) - FUNCTION = "decode" + def define_schema(cls): + return IO.Schema( + node_id="VoxelToMeshBasic", + category="3d", + inputs=[ + IO.Voxel.Input("voxel"), + IO.Float.Input("threshold", default=0.6, min=-1.0, max=1.0, step=0.01), + ], + outputs=[ + IO.Mesh.Output(), + ] + ) - CATEGORY = "3d" - - def decode(self, voxel, threshold): + @classmethod + def execute(cls, voxel, threshold) -> IO.NodeOutput: vertices = [] faces = [] for x in voxel.data: @@ -421,21 +443,29 @@ class VoxelToMeshBasic: vertices.append(v) faces.append(f) - return (MESH(torch.stack(vertices), torch.stack(faces)), ) + return IO.NodeOutput(Types.MESH(torch.stack(vertices), torch.stack(faces))) -class VoxelToMesh: + decode = execute # TODO: remove + + +class VoxelToMesh(IO.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": {"voxel": ("VOXEL", ), - "algorithm": (["surface net", "basic"], ), - "threshold": ("FLOAT", {"default": 0.6, "min": -1.0, "max": 1.0, "step": 0.01}), - }} - RETURN_TYPES = ("MESH",) - FUNCTION = "decode" + def define_schema(cls): + return IO.Schema( + node_id="VoxelToMesh", + category="3d", + inputs=[ + IO.Voxel.Input("voxel"), + IO.Combo.Input("algorithm", options=["surface net", "basic"]), + IO.Float.Input("threshold", default=0.6, min=-1.0, max=1.0, step=0.01), + ], + outputs=[ + IO.Mesh.Output(), + ] + ) - CATEGORY = "3d" - - def decode(self, voxel, algorithm, threshold): + @classmethod + def execute(cls, voxel, algorithm, threshold) -> IO.NodeOutput: vertices = [] faces = [] @@ -449,7 +479,9 @@ class VoxelToMesh: vertices.append(v) faces.append(f) - return (MESH(torch.stack(vertices), torch.stack(faces)), ) + return IO.NodeOutput(Types.MESH(torch.stack(vertices), torch.stack(faces))) + + decode = execute # TODO: remove def save_glb(vertices, faces, filepath, metadata=None): @@ -581,31 +613,32 @@ def save_glb(vertices, faces, filepath, metadata=None): return filepath -class SaveGLB: +class SaveGLB(IO.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": {"mesh": ("MESH", ), - "filename_prefix": ("STRING", {"default": "mesh/ComfyUI"}), }, - "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}, } + def define_schema(cls): + return IO.Schema( + node_id="SaveGLB", + category="3d", + is_output_node=True, + inputs=[ + IO.Mesh.Input("mesh"), + IO.String.Input("filename_prefix", default="mesh/ComfyUI"), + ], + hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo] + ) - RETURN_TYPES = () - FUNCTION = "save" - - OUTPUT_NODE = True - - CATEGORY = "3d" - - def save(self, mesh, filename_prefix, prompt=None, extra_pnginfo=None): + @classmethod + def execute(cls, mesh, filename_prefix) -> IO.NodeOutput: full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, folder_paths.get_output_directory()) results = [] metadata = {} if not args.disable_metadata: - if prompt is not None: - metadata["prompt"] = json.dumps(prompt) - if extra_pnginfo is not None: - for x in extra_pnginfo: - metadata[x] = json.dumps(extra_pnginfo[x]) + if cls.hidden.prompt is not None: + metadata["prompt"] = json.dumps(cls.hidden.prompt) + if cls.hidden.extra_pnginfo is not None: + for x in cls.hidden.extra_pnginfo: + metadata[x] = json.dumps(cls.hidden.extra_pnginfo[x]) for i in range(mesh.vertices.shape[0]): f = f"{filename}_{counter:05}_.glb" @@ -616,15 +649,22 @@ class SaveGLB: "type": "output" }) counter += 1 - return {"ui": {"3d": results}} + return IO.NodeOutput(ui={"3d": results}) -NODE_CLASS_MAPPINGS = { - "EmptyLatentHunyuan3Dv2": EmptyLatentHunyuan3Dv2, - "Hunyuan3Dv2Conditioning": Hunyuan3Dv2Conditioning, - "Hunyuan3Dv2ConditioningMultiView": Hunyuan3Dv2ConditioningMultiView, - "VAEDecodeHunyuan3D": VAEDecodeHunyuan3D, - "VoxelToMeshBasic": VoxelToMeshBasic, - "VoxelToMesh": VoxelToMesh, - "SaveGLB": SaveGLB, -} +class Hunyuan3dExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[IO.ComfyNode]]: + return [ + EmptyLatentHunyuan3Dv2, + Hunyuan3Dv2Conditioning, + Hunyuan3Dv2ConditioningMultiView, + VAEDecodeHunyuan3D, + VoxelToMeshBasic, + VoxelToMesh, + SaveGLB, + ] + + +async def comfy_entrypoint() -> Hunyuan3dExtension: + return Hunyuan3dExtension() From 7601e89255cde24667d3b4e6022f1385d901748b Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 19 Nov 2025 17:17:15 -0800 Subject: [PATCH 109/231] Fix workflow name. (#10806) --- .github/workflows/release-stable-all.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/release-stable-all.yml b/.github/workflows/release-stable-all.yml index f7de3a7c3..9274b4170 100644 --- a/.github/workflows/release-stable-all.yml +++ b/.github/workflows/release-stable-all.yml @@ -14,7 +14,7 @@ jobs: contents: "write" packages: "write" pull-requests: "read" - name: "Release NVIDIA Default (cu129)" + name: "Release NVIDIA Default (cu130)" uses: ./.github/workflows/stable-release.yml with: git_tag: ${{ inputs.git_tag }} From 394348f5caaa062eac11a57e2997aacccd4246eb Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Thu, 20 Nov 2025 03:44:04 +0200 Subject: [PATCH 110/231] feat(api-nodes): add Topaz API nodes (#10755) --- comfy_api_nodes/apis/topaz_api.py | 133 ++++++++++ comfy_api_nodes/nodes_topaz.py | 421 ++++++++++++++++++++++++++++++ comfy_api_nodes/util/client.py | 9 +- nodes.py | 1 + 4 files changed, 560 insertions(+), 4 deletions(-) create mode 100644 comfy_api_nodes/apis/topaz_api.py create mode 100644 comfy_api_nodes/nodes_topaz.py diff --git a/comfy_api_nodes/apis/topaz_api.py b/comfy_api_nodes/apis/topaz_api.py new file mode 100644 index 000000000..4d9e62e72 --- /dev/null +++ b/comfy_api_nodes/apis/topaz_api.py @@ -0,0 +1,133 @@ +from typing import Optional, Union + +from pydantic import BaseModel, Field + + +class ImageEnhanceRequest(BaseModel): + model: str = Field("Reimagine") + output_format: str = Field("jpeg") + subject_detection: str = Field("All") + face_enhancement: bool = Field(True) + face_enhancement_creativity: float = Field(0, description="Is ignored if face_enhancement is false") + face_enhancement_strength: float = Field(0.8, description="Is ignored if face_enhancement is false") + source_url: str = Field(...) + output_width: Optional[int] = Field(None) + output_height: Optional[int] = Field(None) + crop_to_fill: bool = Field(False) + prompt: Optional[str] = Field(None, description="Text prompt for creative upscaling guidance") + creativity: int = Field(3, description="Creativity settings range from 1 to 9") + face_preservation: str = Field("true", description="To preserve the identity of characters") + color_preservation: str = Field("true", description="To preserve the original color") + + +class ImageAsyncTaskResponse(BaseModel): + process_id: str = Field(...) + + +class ImageStatusResponse(BaseModel): + process_id: str = Field(...) + status: str = Field(...) + progress: Optional[int] = Field(None) + credits: int = Field(...) + + +class ImageDownloadResponse(BaseModel): + download_url: str = Field(...) + expiry: int = Field(...) + + +class Resolution(BaseModel): + width: int = Field(...) + height: int = Field(...) + + +class CreateCreateVideoRequestSource(BaseModel): + container: str = Field(...) + size: int = Field(..., description="Size of the video file in bytes") + duration: int = Field(..., description="Duration of the video file in seconds") + frameCount: int = Field(..., description="Total number of frames in the video") + frameRate: int = Field(...) + resolution: Resolution = Field(...) + + +class VideoFrameInterpolationFilter(BaseModel): + model: str = Field(...) + slowmo: Optional[int] = Field(None) + fps: int = Field(...) + duplicate: bool = Field(...) + duplicate_threshold: float = Field(...) + + +class VideoEnhancementFilter(BaseModel): + model: str = Field(...) + auto: Optional[str] = Field(None, description="Auto, Manual, Relative") + focusFixLevel: Optional[str] = Field(None, description="Downscales video input for correction of blurred subjects") + compression: Optional[float] = Field(None, description="Strength of compression recovery") + details: Optional[float] = Field(None, description="Amount of detail reconstruction") + prenoise: Optional[float] = Field(None, description="Amount of noise to add to input to reduce over-smoothing") + noise: Optional[float] = Field(None, description="Amount of noise reduction") + halo: Optional[float] = Field(None, description="Amount of halo reduction") + preblur: Optional[float] = Field(None, description="Anti-aliasing and deblurring strength") + blur: Optional[float] = Field(None, description="Amount of sharpness applied") + grain: Optional[float] = Field(None, description="Grain after AI model processing") + grainSize: Optional[float] = Field(None, description="Size of generated grain") + recoverOriginalDetailValue: Optional[float] = Field(None, description="Source details into the output video") + creativity: Optional[str] = Field(None, description="Creativity level(high, low) for slc-1 only") + isOptimizedMode: Optional[bool] = Field(None, description="Set to true for Starlight Creative (slc-1) only") + + +class OutputInformationVideo(BaseModel): + resolution: Resolution = Field(...) + frameRate: int = Field(...) + audioCodec: Optional[str] = Field(..., description="Required if audioTransfer is Copy or Convert") + audioTransfer: str = Field(..., description="Copy, Convert, None") + dynamicCompressionLevel: str = Field(..., description="Low, Mid, High") + + +class Overrides(BaseModel): + isPaidDiffusion: bool = Field(True) + + +class CreateVideoRequest(BaseModel): + source: CreateCreateVideoRequestSource = Field(...) + filters: list[Union[VideoFrameInterpolationFilter, VideoEnhancementFilter]] = Field(...) + output: OutputInformationVideo = Field(...) + overrides: Overrides = Field(Overrides(isPaidDiffusion=True)) + + +class CreateVideoResponse(BaseModel): + requestId: str = Field(...) + + +class VideoAcceptResponse(BaseModel): + uploadId: str = Field(...) + urls: list[str] = Field(...) + + +class VideoCompleteUploadRequestPart(BaseModel): + partNum: int = Field(...) + eTag: str = Field(...) + + +class VideoCompleteUploadRequest(BaseModel): + uploadResults: list[VideoCompleteUploadRequestPart] = Field(...) + + +class VideoCompleteUploadResponse(BaseModel): + message: str = Field(..., description="Confirmation message") + + +class VideoStatusResponseEstimates(BaseModel): + cost: list[int] = Field(...) + + +class VideoStatusResponseDownloadUrl(BaseModel): + url: str = Field(...) + + +class VideoStatusResponse(BaseModel): + status: str = Field(...) + estimates: Optional[VideoStatusResponseEstimates] = Field(None) + progress: Optional[float] = Field(None) + message: Optional[str] = Field("") + download: Optional[VideoStatusResponseDownloadUrl] = Field(None) diff --git a/comfy_api_nodes/nodes_topaz.py b/comfy_api_nodes/nodes_topaz.py new file mode 100644 index 000000000..79c7bf43d --- /dev/null +++ b/comfy_api_nodes/nodes_topaz.py @@ -0,0 +1,421 @@ +import builtins +from io import BytesIO + +import aiohttp +import torch +from typing_extensions import override + +from comfy_api.input.video_types import VideoInput +from comfy_api.latest import IO, ComfyExtension +from comfy_api_nodes.apis import topaz_api +from comfy_api_nodes.util import ( + ApiEndpoint, + download_url_to_image_tensor, + download_url_to_video_output, + get_fs_object_size, + get_number_of_images, + poll_op, + sync_op, + upload_images_to_comfyapi, + validate_container_format_is_mp4, +) + +UPSCALER_MODELS_MAP = { + "Starlight (Astra) Fast": "slf-1", + "Starlight (Astra) Creative": "slc-1", +} +UPSCALER_VALUES_MAP = { + "FullHD (1080p)": 1920, + "4K (2160p)": 3840, +} + + +class TopazImageEnhance(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="TopazImageEnhance", + display_name="Topaz Image Enhance", + category="api node/image/Topaz", + description="Industry-standard upscaling and image enhancement.", + inputs=[ + IO.Combo.Input("model", options=["Reimagine"]), + IO.Image.Input("image"), + IO.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Optional text prompt for creative upscaling guidance.", + optional=True, + ), + IO.Combo.Input( + "subject_detection", + options=["All", "Foreground", "Background"], + optional=True, + ), + IO.Boolean.Input( + "face_enhancement", + default=True, + optional=True, + tooltip="Enhance faces (if present) during processing.", + ), + IO.Float.Input( + "face_enhancement_creativity", + default=0.0, + min=0.0, + max=1.0, + step=0.01, + display_mode=IO.NumberDisplay.number, + optional=True, + tooltip="Set the creativity level for face enhancement.", + ), + IO.Float.Input( + "face_enhancement_strength", + default=1.0, + min=0.0, + max=1.0, + step=0.01, + display_mode=IO.NumberDisplay.number, + optional=True, + tooltip="Controls how sharp enhanced faces are relative to the background.", + ), + IO.Boolean.Input( + "crop_to_fill", + default=False, + optional=True, + tooltip="By default, the image is letterboxed when the output aspect ratio differs. " + "Enable to crop the image to fill the output dimensions.", + ), + IO.Int.Input( + "output_width", + default=0, + min=0, + max=32000, + step=1, + display_mode=IO.NumberDisplay.number, + optional=True, + tooltip="Zero value means to calculate automatically (usually it will be original size or output_height if specified).", + ), + IO.Int.Input( + "output_height", + default=0, + min=0, + max=32000, + step=1, + display_mode=IO.NumberDisplay.number, + optional=True, + tooltip="Zero value means to output in the same height as original or output width.", + ), + IO.Int.Input( + "creativity", + default=3, + min=1, + max=9, + step=1, + display_mode=IO.NumberDisplay.slider, + optional=True, + ), + IO.Boolean.Input( + "face_preservation", + default=True, + optional=True, + tooltip="Preserve subjects' facial identity.", + ), + IO.Boolean.Input( + "color_preservation", + default=True, + optional=True, + tooltip="Preserve the original colors.", + ), + ], + outputs=[ + IO.Image.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, + model: str, + image: torch.Tensor, + prompt: str = "", + subject_detection: str = "All", + face_enhancement: bool = True, + face_enhancement_creativity: float = 1.0, + face_enhancement_strength: float = 0.8, + crop_to_fill: bool = False, + output_width: int = 0, + output_height: int = 0, + creativity: int = 3, + face_preservation: bool = True, + color_preservation: bool = True, + ) -> IO.NodeOutput: + if get_number_of_images(image) != 1: + raise ValueError("Only one input image is supported.") + download_url = await upload_images_to_comfyapi(cls, image, max_images=1, mime_type="image/png") + initial_response = await sync_op( + cls, + ApiEndpoint(path="/proxy/topaz/image/v1/enhance-gen/async", method="POST"), + response_model=topaz_api.ImageAsyncTaskResponse, + data=topaz_api.ImageEnhanceRequest( + model=model, + prompt=prompt, + subject_detection=subject_detection, + face_enhancement=face_enhancement, + face_enhancement_creativity=face_enhancement_creativity, + face_enhancement_strength=face_enhancement_strength, + crop_to_fill=crop_to_fill, + output_width=output_width if output_width else None, + output_height=output_height if output_height else None, + creativity=creativity, + face_preservation=str(face_preservation).lower(), + color_preservation=str(color_preservation).lower(), + source_url=download_url[0], + output_format="png", + ), + content_type="multipart/form-data", + ) + + await poll_op( + cls, + poll_endpoint=ApiEndpoint(path=f"/proxy/topaz/image/v1/status/{initial_response.process_id}"), + response_model=topaz_api.ImageStatusResponse, + status_extractor=lambda x: x.status, + progress_extractor=lambda x: getattr(x, "progress", 0), + price_extractor=lambda x: x.credits * 0.08, + poll_interval=8.0, + max_poll_attempts=160, + estimated_duration=60, + ) + + results = await sync_op( + cls, + ApiEndpoint(path=f"/proxy/topaz/image/v1/download/{initial_response.process_id}"), + response_model=topaz_api.ImageDownloadResponse, + monitor_progress=False, + ) + return IO.NodeOutput(await download_url_to_image_tensor(results.download_url)) + + +class TopazVideoEnhance(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="TopazVideoEnhance", + display_name="Topaz Video Enhance", + category="api node/video/Topaz", + description="Breathe new life into video with powerful upscaling and recovery technology.", + inputs=[ + IO.Video.Input("video"), + IO.Boolean.Input("upscaler_enabled", default=True), + IO.Combo.Input("upscaler_model", options=list(UPSCALER_MODELS_MAP.keys())), + IO.Combo.Input("upscaler_resolution", options=list(UPSCALER_VALUES_MAP.keys())), + IO.Combo.Input( + "upscaler_creativity", + options=["low", "middle", "high"], + default="low", + tooltip="Creativity level (applies only to Starlight (Astra) Creative).", + optional=True, + ), + IO.Boolean.Input("interpolation_enabled", default=False, optional=True), + IO.Combo.Input("interpolation_model", options=["apo-8"], default="apo-8", optional=True), + IO.Int.Input( + "interpolation_slowmo", + default=1, + min=1, + max=16, + display_mode=IO.NumberDisplay.number, + tooltip="Slow-motion factor applied to the input video. " + "For example, 2 makes the output twice as slow and doubles the duration.", + optional=True, + ), + IO.Int.Input( + "interpolation_frame_rate", + default=60, + min=15, + max=240, + display_mode=IO.NumberDisplay.number, + tooltip="Output frame rate.", + optional=True, + ), + IO.Boolean.Input( + "interpolation_duplicate", + default=False, + tooltip="Analyze the input for duplicate frames and remove them.", + optional=True, + ), + IO.Float.Input( + "interpolation_duplicate_threshold", + default=0.01, + min=0.001, + max=0.1, + step=0.001, + display_mode=IO.NumberDisplay.number, + tooltip="Detection sensitivity for duplicate frames.", + optional=True, + ), + IO.Combo.Input( + "dynamic_compression_level", + options=["Low", "Mid", "High"], + default="Low", + tooltip="CQP level.", + optional=True, + ), + ], + outputs=[ + IO.Video.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, + video: VideoInput, + upscaler_enabled: bool, + upscaler_model: str, + upscaler_resolution: str, + upscaler_creativity: str = "low", + interpolation_enabled: bool = False, + interpolation_model: str = "apo-8", + interpolation_slowmo: int = 1, + interpolation_frame_rate: int = 60, + interpolation_duplicate: bool = False, + interpolation_duplicate_threshold: float = 0.01, + dynamic_compression_level: str = "Low", + ) -> IO.NodeOutput: + if upscaler_enabled is False and interpolation_enabled is False: + raise ValueError("There is nothing to do: both upscaling and interpolation are disabled.") + src_width, src_height = video.get_dimensions() + video_components = video.get_components() + src_frame_rate = int(video_components.frame_rate) + duration_sec = video.get_duration() + estimated_frames = int(duration_sec * src_frame_rate) + validate_container_format_is_mp4(video) + src_video_stream = video.get_stream_source() + target_width = src_width + target_height = src_height + target_frame_rate = src_frame_rate + filters = [] + if upscaler_enabled: + target_width = UPSCALER_VALUES_MAP[upscaler_resolution] + target_height = UPSCALER_VALUES_MAP[upscaler_resolution] + filters.append( + topaz_api.VideoEnhancementFilter( + model=UPSCALER_MODELS_MAP[upscaler_model], + creativity=(upscaler_creativity if UPSCALER_MODELS_MAP[upscaler_model] == "slc-1" else None), + isOptimizedMode=(True if UPSCALER_MODELS_MAP[upscaler_model] == "slc-1" else None), + ), + ) + if interpolation_enabled: + target_frame_rate = interpolation_frame_rate + filters.append( + topaz_api.VideoFrameInterpolationFilter( + model=interpolation_model, + slowmo=interpolation_slowmo, + fps=interpolation_frame_rate, + duplicate=interpolation_duplicate, + duplicate_threshold=interpolation_duplicate_threshold, + ), + ) + initial_res = await sync_op( + cls, + ApiEndpoint(path="/proxy/topaz/video/", method="POST"), + response_model=topaz_api.CreateVideoResponse, + data=topaz_api.CreateVideoRequest( + source=topaz_api.CreateCreateVideoRequestSource( + container="mp4", + size=get_fs_object_size(src_video_stream), + duration=int(duration_sec), + frameCount=estimated_frames, + frameRate=src_frame_rate, + resolution=topaz_api.Resolution(width=src_width, height=src_height), + ), + filters=filters, + output=topaz_api.OutputInformationVideo( + resolution=topaz_api.Resolution(width=target_width, height=target_height), + frameRate=target_frame_rate, + audioCodec="AAC", + audioTransfer="Copy", + dynamicCompressionLevel=dynamic_compression_level, + ), + ), + wait_label="Creating task", + final_label_on_success="Task created", + ) + upload_res = await sync_op( + cls, + ApiEndpoint( + path=f"/proxy/topaz/video/{initial_res.requestId}/accept", + method="PATCH", + ), + response_model=topaz_api.VideoAcceptResponse, + wait_label="Preparing upload", + final_label_on_success="Upload started", + ) + if len(upload_res.urls) > 1: + raise NotImplementedError( + "Large files are not currently supported. Please open an issue in the ComfyUI repository." + ) + async with aiohttp.ClientSession(headers={"Content-Type": "video/mp4"}) as session: + if isinstance(src_video_stream, BytesIO): + src_video_stream.seek(0) + async with session.put(upload_res.urls[0], data=src_video_stream, raise_for_status=True) as res: + upload_etag = res.headers["Etag"] + else: + with builtins.open(src_video_stream, "rb") as video_file: + async with session.put(upload_res.urls[0], data=video_file, raise_for_status=True) as res: + upload_etag = res.headers["Etag"] + await sync_op( + cls, + ApiEndpoint( + path=f"/proxy/topaz/video/{initial_res.requestId}/complete-upload", + method="PATCH", + ), + response_model=topaz_api.VideoCompleteUploadResponse, + data=topaz_api.VideoCompleteUploadRequest( + uploadResults=[ + topaz_api.VideoCompleteUploadRequestPart( + partNum=1, + eTag=upload_etag, + ), + ], + ), + wait_label="Finalizing upload", + final_label_on_success="Upload completed", + ) + final_response = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/topaz/video/{initial_res.requestId}/status"), + response_model=topaz_api.VideoStatusResponse, + status_extractor=lambda x: x.status, + progress_extractor=lambda x: getattr(x, "progress", 0), + price_extractor=lambda x: (x.estimates.cost[0] * 0.08 if x.estimates and x.estimates.cost[0] else None), + poll_interval=10.0, + max_poll_attempts=320, + ) + return IO.NodeOutput(await download_url_to_video_output(final_response.download.url)) + + +class TopazExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[IO.ComfyNode]]: + return [ + TopazImageEnhance, + TopazVideoEnhance, + ] + + +async def comfy_entrypoint() -> TopazExtension: + return TopazExtension() diff --git a/comfy_api_nodes/util/client.py b/comfy_api_nodes/util/client.py index 2d5dcd648..ad6e3c0d0 100644 --- a/comfy_api_nodes/util/client.py +++ b/comfy_api_nodes/util/client.py @@ -77,9 +77,9 @@ class _PollUIState: _RETRY_STATUS = {408, 429, 500, 502, 503, 504} -COMPLETED_STATUSES = ["succeeded", "succeed", "success", "completed", "finished", "done"] -FAILED_STATUSES = ["cancelled", "canceled", "fail", "failed", "error"] -QUEUED_STATUSES = ["created", "queued", "queueing", "submitted"] +COMPLETED_STATUSES = ["succeeded", "succeed", "success", "completed", "finished", "done", "complete"] +FAILED_STATUSES = ["cancelled", "canceled", "canceling", "fail", "failed", "error"] +QUEUED_STATUSES = ["created", "queued", "queueing", "submitted", "initializing"] async def sync_op( @@ -424,7 +424,8 @@ def _display_text( if status: display_lines.append(f"Status: {status.capitalize() if isinstance(status, str) else status}") if price is not None: - display_lines.append(f"Price: ${float(price):,.4f}") + p = f"{float(price):,.4f}".rstrip("0").rstrip(".") + display_lines.append(f"Price: ${p}") if text is not None: display_lines.append(text) if display_lines: diff --git a/nodes.py b/nodes.py index f6aeedc78..ac14e39a7 100644 --- a/nodes.py +++ b/nodes.py @@ -2359,6 +2359,7 @@ async def init_builtin_api_nodes(): "nodes_pika.py", "nodes_runway.py", "nodes_sora.py", + "nodes_topaz.py", "nodes_tripo.py", "nodes_moonvalley.py", "nodes_rodin.py", From cb96d4d18c78ee09d5fd70954ffcb4ad2c7f0d7a Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 19 Nov 2025 20:56:23 -0800 Subject: [PATCH 111/231] Disable workaround on newer cudnn. (#10807) --- comfy/ops.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/comfy/ops.py b/comfy/ops.py index 2a90a5ba2..640622fd1 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -58,7 +58,8 @@ except (ModuleNotFoundError, TypeError): NVIDIA_MEMORY_CONV_BUG_WORKAROUND = False try: if comfy.model_management.is_nvidia(): - if torch.backends.cudnn.version() >= 91002 and comfy.model_management.torch_version_numeric >= (2, 9) and comfy.model_management.torch_version_numeric <= (2, 10): + cudnn_version = torch.backends.cudnn.version() + if (cudnn_version >= 91002 and cudnn_version < 91500) and comfy.model_management.torch_version_numeric >= (2, 9) and comfy.model_management.torch_version_numeric <= (2, 10): #TODO: change upper bound version once it's fixed' NVIDIA_MEMORY_CONV_BUG_WORKAROUND = True logging.info("working around nvidia conv3d memory bug.") From 87b0359392219841c2214e1eb06678840cae470e Mon Sep 17 00:00:00 2001 From: Christian Byrne Date: Wed, 19 Nov 2025 22:36:56 -0800 Subject: [PATCH 112/231] Update server templates handler to use new multi-package distribution (comfyui-workflow-templates versions >=0.3) (#10791) * update templates for monorepo * refactor --- app/frontend_management.py | 67 ++++++++++++++++++++++++++++++++++++-- requirements.txt | 2 +- server.py | 32 ++++++++++++++---- 3 files changed, 92 insertions(+), 9 deletions(-) diff --git a/app/frontend_management.py b/app/frontend_management.py index cce0c117d..bdaa85812 100644 --- a/app/frontend_management.py +++ b/app/frontend_management.py @@ -10,7 +10,8 @@ import importlib from dataclasses import dataclass from functools import cached_property from pathlib import Path -from typing import TypedDict, Optional +from typing import Dict, TypedDict, Optional +from aiohttp import web from importlib.metadata import version import requests @@ -257,7 +258,54 @@ comfyui-frontend-package is not installed. sys.exit(-1) @classmethod - def templates_path(cls) -> str: + def template_asset_map(cls) -> Optional[Dict[str, str]]: + """Return a mapping of template asset names to their absolute paths.""" + try: + from comfyui_workflow_templates import ( + get_asset_path, + iter_templates, + ) + except ImportError: + logging.error( + f""" +********** ERROR *********** + +comfyui-workflow-templates is not installed. + +{frontend_install_warning_message()} + +********** ERROR *********** +""".strip() + ) + return None + + try: + template_entries = list(iter_templates()) + except Exception as exc: + logging.error(f"Failed to enumerate workflow templates: {exc}") + return None + + asset_map: Dict[str, str] = {} + try: + for entry in template_entries: + for asset in entry.assets: + asset_map[asset.filename] = get_asset_path( + entry.template_id, asset.filename + ) + except Exception as exc: + logging.error(f"Failed to resolve template asset paths: {exc}") + return None + + if not asset_map: + logging.error("No workflow template assets found. Did the packages install correctly?") + return None + + return asset_map + + + @classmethod + def legacy_templates_path(cls) -> Optional[str]: + """Return the legacy templates directory shipped inside the meta package.""" try: import comfyui_workflow_templates @@ -276,6 +324,7 @@ comfyui-workflow-templates is not installed. ********** ERROR *********** """.strip() ) + return None @classmethod def embedded_docs_path(cls) -> str: @@ -392,3 +441,17 @@ comfyui-workflow-templates is not installed. logging.info("Falling back to the default frontend.") check_frontend_version() return cls.default_frontend_path() + @classmethod + def template_asset_handler(cls): + assets = cls.template_asset_map() + if not assets: + return None + + async def serve_template(request: web.Request) -> web.StreamResponse: + rel_path = request.match_info.get("path", "") + target = assets.get(rel_path) + if target is None: + raise web.HTTPNotFound() + return web.FileResponse(target) + + return serve_template diff --git a/requirements.txt b/requirements.txt index 249c36dee..36c39f338 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.28.8 -comfyui-workflow-templates==0.2.11 +comfyui-workflow-templates==0.3.1 comfyui-embedded-docs==0.3.1 torch torchsde diff --git a/server.py b/server.py index d059d3dc9..d9d5c491f 100644 --- a/server.py +++ b/server.py @@ -30,7 +30,7 @@ import comfy.model_management from comfy_api import feature_flags import node_helpers from comfyui_version import __version__ -from app.frontend_management import FrontendManager +from app.frontend_management import FrontendManager, parse_version from comfy_api.internal import _ComfyNodeInternal from app.user_manager import UserManager @@ -849,11 +849,31 @@ class PromptServer(): for name, dir in nodes.EXTENSION_WEB_DIRS.items(): self.app.add_routes([web.static('/extensions/' + name, dir)]) - workflow_templates_path = FrontendManager.templates_path() - if workflow_templates_path: - self.app.add_routes([ - web.static('/templates', workflow_templates_path) - ]) + installed_templates_version = FrontendManager.get_installed_templates_version() + use_legacy_templates = True + if installed_templates_version: + try: + use_legacy_templates = ( + parse_version(installed_templates_version) + < parse_version("0.3.0") + ) + except Exception as exc: + logging.warning( + "Unable to parse templates version '%s': %s", + installed_templates_version, + exc, + ) + + if use_legacy_templates: + workflow_templates_path = FrontendManager.legacy_templates_path() + if workflow_templates_path: + self.app.add_routes([ + web.static('/templates', workflow_templates_path) + ]) + else: + handler = FrontendManager.template_asset_handler() + if handler: + self.app.router.add_get("/templates/{path:.*}", handler) # Serve embedded documentation from the package embedded_docs_path = FrontendManager.embedded_docs_path() From f5e66d5e47271253edad5c4eddd817b0d6a23340 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Thu, 20 Nov 2025 12:08:03 -0800 Subject: [PATCH 113/231] Fix ImageBatch with different channel count. (#10815) --- nodes.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/nodes.py b/nodes.py index ac14e39a7..75e820e66 100644 --- a/nodes.py +++ b/nodes.py @@ -1852,6 +1852,10 @@ class ImageBatch: CATEGORY = "image" def batch(self, image1, image2): + if image1.shape[-1] != image2.shape[-1]: + channels = min(image1.shape[-1], image2.shape[-1]) + image1 = image1[..., :channels] + image2 = image2[..., :channels] if image1.shape[1:] != image2.shape[1:]: image2 = comfy.utils.common_upscale(image2.movedim(-1,1), image1.shape[2], image1.shape[1], "bilinear", "center").movedim(1,-1) s = torch.cat((image1, image2), dim=0) From 9e00ce5b76ec04be37375310512a443605b95077 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Thu, 20 Nov 2025 14:42:46 -0800 Subject: [PATCH 114/231] Make Batch Images node add alpha channel when one of the inputs has it (#10816) * When one Batch Image input has alpha and one does not, add empty alpha channel * Use torch.nn.functional.pad --- nodes.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/nodes.py b/nodes.py index 75e820e66..030371633 100644 --- a/nodes.py +++ b/nodes.py @@ -1853,9 +1853,10 @@ class ImageBatch: def batch(self, image1, image2): if image1.shape[-1] != image2.shape[-1]: - channels = min(image1.shape[-1], image2.shape[-1]) - image1 = image1[..., :channels] - image2 = image2[..., :channels] + if image1.shape[-1] > image2.shape[-1]: + image2 = torch.nn.functional.pad(image2, (0,1), mode='constant', value=1.0) + else: + image1 = torch.nn.functional.pad(image1, (0,1), mode='constant', value=1.0) if image1.shape[1:] != image2.shape[1:]: image2 = comfy.utils.common_upscale(image2.movedim(-1,1), image1.shape[2], image1.shape[1], "bilinear", "center").movedim(1,-1) s = torch.cat((image1, image2), dim=0) From 7b8389578e88dcd13b1cf6aea5404047298c9183 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Fri, 21 Nov 2025 02:17:47 +0200 Subject: [PATCH 115/231] feat(api-nodes): add Nano Banana Pro (#10814) * feat(api-nodes): add Nano Banana Pro * frontend bump to 1.28.9 --- comfy_api_nodes/apis/gemini_api.py | 5 +- comfy_api_nodes/nodes_gemini.py | 205 ++++++++++++++++++++++++++++- comfy_api_nodes/util/client.py | 13 +- requirements.txt | 2 +- 4 files changed, 215 insertions(+), 10 deletions(-) diff --git a/comfy_api_nodes/apis/gemini_api.py b/comfy_api_nodes/apis/gemini_api.py index f63e02693..710f173f1 100644 --- a/comfy_api_nodes/apis/gemini_api.py +++ b/comfy_api_nodes/apis/gemini_api.py @@ -68,7 +68,7 @@ class GeminiTextPart(BaseModel): class GeminiContent(BaseModel): - parts: list[GeminiPart] = Field(...) + parts: list[GeminiPart] = Field([]) role: GeminiRole = Field(..., examples=["user"]) @@ -120,7 +120,7 @@ class GeminiGenerationConfig(BaseModel): class GeminiImageConfig(BaseModel): aspectRatio: str | None = Field(None) - resolution: str | None = Field(None) + imageSize: str | None = Field(None) class GeminiImageGenerationConfig(GeminiGenerationConfig): @@ -227,3 +227,4 @@ class GeminiGenerateContentResponse(BaseModel): candidates: list[GeminiCandidate] | None = Field(None) promptFeedback: GeminiPromptFeedback | None = Field(None) usageMetadata: GeminiUsageMetadata | None = Field(None) + modelVersion: str | None = Field(None) diff --git a/comfy_api_nodes/nodes_gemini.py b/comfy_api_nodes/nodes_gemini.py index 6e746eebd..be752c885 100644 --- a/comfy_api_nodes/nodes_gemini.py +++ b/comfy_api_nodes/nodes_gemini.py @@ -29,11 +29,13 @@ from comfy_api_nodes.apis.gemini_api import ( GeminiMimeType, GeminiPart, GeminiRole, + Modality, ) from comfy_api_nodes.util import ( ApiEndpoint, audio_to_base64_string, bytesio_to_image_tensor, + get_number_of_images, sync_op, tensor_to_base64_string, validate_string, @@ -147,6 +149,49 @@ def get_image_from_response(response: GeminiGenerateContentResponse) -> torch.Te return torch.cat(image_tensors, dim=0) +def calculate_tokens_price(response: GeminiGenerateContentResponse) -> float | None: + if not response.modelVersion: + return None + # Define prices (Cost per 1,000,000 tokens), see https://cloud.google.com/vertex-ai/generative-ai/pricing + if response.modelVersion in ("gemini-2.5-pro-preview-05-06", "gemini-2.5-pro"): + input_tokens_price = 1.25 + output_text_tokens_price = 10.0 + output_image_tokens_price = 0.0 + elif response.modelVersion in ( + "gemini-2.5-flash-preview-04-17", + "gemini-2.5-flash", + ): + input_tokens_price = 0.30 + output_text_tokens_price = 2.50 + output_image_tokens_price = 0.0 + elif response.modelVersion in ( + "gemini-2.5-flash-image-preview", + "gemini-2.5-flash-image", + ): + input_tokens_price = 0.30 + output_text_tokens_price = 2.50 + output_image_tokens_price = 30.0 + elif response.modelVersion == "gemini-3-pro-preview": + input_tokens_price = 2 + output_text_tokens_price = 12.0 + output_image_tokens_price = 0.0 + elif response.modelVersion == "gemini-3-pro-image-preview": + input_tokens_price = 2 + output_text_tokens_price = 12.0 + output_image_tokens_price = 120.0 + else: + return None + final_price = response.usageMetadata.promptTokenCount * input_tokens_price + for i in response.usageMetadata.candidatesTokensDetails: + if i.modality == Modality.IMAGE: + final_price += output_image_tokens_price * i.tokenCount # for Nano Banana models + else: + final_price += output_text_tokens_price * i.tokenCount + if response.usageMetadata.thoughtsTokenCount: + final_price += output_text_tokens_price * response.usageMetadata.thoughtsTokenCount + return final_price / 1_000_000.0 + + class GeminiNode(IO.ComfyNode): """ Node to generate text responses from a Gemini model. @@ -314,6 +359,7 @@ class GeminiNode(IO.ComfyNode): ] ), response_model=GeminiGenerateContentResponse, + price_extractor=calculate_tokens_price, ) output_text = get_text_from_response(response) @@ -476,6 +522,13 @@ class GeminiImage(IO.ComfyNode): "or otherwise generates 1:1 squares.", optional=True, ), + IO.Combo.Input( + "response_modalities", + options=["IMAGE+TEXT", "IMAGE"], + tooltip="Choose 'IMAGE' for image-only output, or " + "'IMAGE+TEXT' to return both the generated image and a text response.", + optional=True, + ), ], outputs=[ IO.Image.Output(), @@ -498,6 +551,7 @@ class GeminiImage(IO.ComfyNode): images: torch.Tensor | None = None, files: list[GeminiPart] | None = None, aspect_ratio: str = "auto", + response_modalities: str = "IMAGE+TEXT", ) -> IO.NodeOutput: validate_string(prompt, strip_whitespace=True, min_length=1) parts: list[GeminiPart] = [GeminiPart(text=prompt)] @@ -520,17 +574,16 @@ class GeminiImage(IO.ComfyNode): GeminiContent(role=GeminiRole.user, parts=parts), ], generationConfig=GeminiImageGenerationConfig( - responseModalities=["TEXT", "IMAGE"], + responseModalities=(["IMAGE"] if response_modalities == "IMAGE" else ["TEXT", "IMAGE"]), imageConfig=None if aspect_ratio == "auto" else image_config, ), ), response_model=GeminiGenerateContentResponse, + price_extractor=calculate_tokens_price, ) - output_image = get_image_from_response(response) output_text = get_text_from_response(response) if output_text: - # Not a true chat history like the OpenAI Chat node. It is emulated so the frontend can show a copy button. render_spec = { "node_id": cls.hidden.unique_id, "component": "ChatHistoryWidget", @@ -551,9 +604,150 @@ class GeminiImage(IO.ComfyNode): "display_component", render_spec, ) + return IO.NodeOutput(get_image_from_response(response), output_text) - output_text = output_text or "Empty response from Gemini model..." - return IO.NodeOutput(output_image, output_text) + +class GeminiImage2(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="GeminiImage2Node", + display_name="Nano Banana Pro (Google Gemini Image)", + category="api node/image/Gemini", + description="Generate or edit images synchronously via Google Vertex API.", + inputs=[ + IO.String.Input( + "prompt", + multiline=True, + tooltip="Text prompt describing the image to generate or the edits to apply. " + "Include any constraints, styles, or details the model should follow.", + default="", + ), + IO.Combo.Input( + "model", + options=["gemini-3-pro-image-preview"], + ), + IO.Int.Input( + "seed", + default=42, + min=0, + max=0xFFFFFFFFFFFFFFFF, + control_after_generate=True, + tooltip="When the seed is fixed to a specific value, the model makes a best effort to provide " + "the same response for repeated requests. Deterministic output isn't guaranteed. " + "Also, changing the model or parameter settings, such as the temperature, " + "can cause variations in the response even when you use the same seed value. " + "By default, a random seed value is used.", + ), + IO.Combo.Input( + "aspect_ratio", + options=["auto", "1:1", "2:3", "3:2", "3:4", "4:3", "4:5", "5:4", "9:16", "16:9", "21:9"], + default="auto", + tooltip="If set to 'auto', matches your input image's aspect ratio; " + "if no image is provided, generates a 1:1 square.", + ), + IO.Combo.Input( + "resolution", + options=["1K", "2K", "4K"], + tooltip="Target output resolution. For 2K/4K the native Gemini upscaler is used.", + ), + IO.Combo.Input( + "response_modalities", + options=["IMAGE+TEXT", "IMAGE"], + tooltip="Choose 'IMAGE' for image-only output, or " + "'IMAGE+TEXT' to return both the generated image and a text response.", + ), + IO.Image.Input( + "images", + optional=True, + tooltip="Optional reference image(s). " + "To include multiple images, use the Batch Images node (up to 14).", + ), + IO.Custom("GEMINI_INPUT_FILES").Input( + "files", + optional=True, + tooltip="Optional file(s) to use as context for the model. " + "Accepts inputs from the Gemini Generate Content Input Files node.", + ), + ], + outputs=[ + IO.Image.Output(), + IO.String.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, + prompt: str, + model: str, + seed: int, + aspect_ratio: str, + resolution: str, + response_modalities: str, + images: torch.Tensor | None = None, + files: list[GeminiPart] | None = None, + ) -> IO.NodeOutput: + validate_string(prompt, strip_whitespace=True, min_length=1) + + parts: list[GeminiPart] = [GeminiPart(text=prompt)] + if images is not None: + if get_number_of_images(images) > 14: + raise ValueError("The current maximum number of supported images is 14.") + parts.extend(create_image_parts(images)) + if files is not None: + parts.extend(files) + + image_config = GeminiImageConfig(imageSize=resolution) + if aspect_ratio != "auto": + image_config.aspectRatio = aspect_ratio + + response = await sync_op( + cls, + ApiEndpoint(path=f"{GEMINI_BASE_ENDPOINT}/{model}", method="POST"), + data=GeminiImageGenerateContentRequest( + contents=[ + GeminiContent(role=GeminiRole.user, parts=parts), + ], + generationConfig=GeminiImageGenerationConfig( + responseModalities=(["IMAGE"] if response_modalities == "IMAGE" else ["TEXT", "IMAGE"]), + imageConfig=image_config, + ), + ), + response_model=GeminiGenerateContentResponse, + price_extractor=calculate_tokens_price, + ) + + output_text = get_text_from_response(response) + if output_text: + render_spec = { + "node_id": cls.hidden.unique_id, + "component": "ChatHistoryWidget", + "props": { + "history": json.dumps( + [ + { + "prompt": prompt, + "response": output_text, + "response_id": str(uuid.uuid4()), + "timestamp": time.time(), + } + ] + ), + }, + } + PromptServer.instance.send_sync( + "display_component", + render_spec, + ) + return IO.NodeOutput(get_image_from_response(response), output_text) class GeminiExtension(ComfyExtension): @@ -562,6 +756,7 @@ class GeminiExtension(ComfyExtension): return [ GeminiNode, GeminiImage, + GeminiImage2, GeminiInputFiles, ] diff --git a/comfy_api_nodes/util/client.py b/comfy_api_nodes/util/client.py index ad6e3c0d0..bf01d7d36 100644 --- a/comfy_api_nodes/util/client.py +++ b/comfy_api_nodes/util/client.py @@ -63,6 +63,7 @@ class _RequestConfig: estimated_total: Optional[int] = None final_label_on_success: Optional[str] = "Completed" progress_origin_ts: Optional[float] = None + price_extractor: Optional[Callable[[dict[str, Any]], Optional[float]]] = None @dataclass @@ -87,6 +88,7 @@ async def sync_op( endpoint: ApiEndpoint, *, response_model: Type[M], + price_extractor: Optional[Callable[[M], Optional[float]]] = None, data: Optional[BaseModel] = None, files: Optional[Union[dict[str, Any], list[tuple[str, Any]]]] = None, content_type: str = "application/json", @@ -104,6 +106,7 @@ async def sync_op( raw = await sync_op_raw( cls, endpoint, + price_extractor=_wrap_model_extractor(response_model, price_extractor), data=data, files=files, content_type=content_type, @@ -175,6 +178,7 @@ async def sync_op_raw( cls: type[IO.ComfyNode], endpoint: ApiEndpoint, *, + price_extractor: Optional[Callable[[dict[str, Any]], Optional[float]]] = None, data: Optional[Union[dict[str, Any], BaseModel]] = None, files: Optional[Union[dict[str, Any], list[tuple[str, Any]]]] = None, content_type: str = "application/json", @@ -216,6 +220,7 @@ async def sync_op_raw( estimated_total=estimated_duration, final_label_on_success=final_label_on_success, progress_origin_ts=progress_origin_ts, + price_extractor=price_extractor, ) return await _request_base(cfg, expect_binary=as_binary) @@ -425,7 +430,8 @@ def _display_text( display_lines.append(f"Status: {status.capitalize() if isinstance(status, str) else status}") if price is not None: p = f"{float(price):,.4f}".rstrip("0").rstrip(".") - display_lines.append(f"Price: ${p}") + if p != "0": + display_lines.append(f"Price: ${p}") if text is not None: display_lines.append(text) if display_lines: @@ -581,6 +587,7 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool): delay = cfg.retry_delay operation_succeeded: bool = False final_elapsed_seconds: Optional[int] = None + extracted_price: Optional[float] = None while True: attempt += 1 stop_event = asyncio.Event() @@ -768,6 +775,8 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool): except json.JSONDecodeError: payload = {"_raw": text} response_content_to_log = payload if isinstance(payload, dict) else text + with contextlib.suppress(Exception): + extracted_price = cfg.price_extractor(payload) if cfg.price_extractor else None operation_succeeded = True final_elapsed_seconds = int(time.monotonic() - start_time) try: @@ -872,7 +881,7 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool): else int(time.monotonic() - start_time) ), estimated_total=cfg.estimated_total, - price=None, + price=extracted_price, is_queued=False, processing_elapsed_seconds=final_elapsed_seconds, ) diff --git a/requirements.txt b/requirements.txt index 36c39f338..8c1946f3d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -comfyui-frontend-package==1.28.8 +comfyui-frontend-package==1.28.9 comfyui-workflow-templates==0.3.1 comfyui-embedded-docs==0.3.1 torch From b75d349f25ccb702895c6f1b8af7aded63a7f7e2 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Fri, 21 Nov 2025 02:33:54 +0200 Subject: [PATCH 116/231] fix(KlingLipSyncAudioToVideoNode): convert audio to mp3 format (#10811) --- comfy_api_nodes/nodes_kling.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/comfy_api_nodes/nodes_kling.py b/comfy_api_nodes/nodes_kling.py index 7b23e9cf9..36852038b 100644 --- a/comfy_api_nodes/nodes_kling.py +++ b/comfy_api_nodes/nodes_kling.py @@ -518,7 +518,9 @@ async def execute_lipsync( # Upload the audio file to Comfy API and get download URL if audio: - audio_url = await upload_audio_to_comfyapi(cls, audio) + audio_url = await upload_audio_to_comfyapi( + cls, audio, container_format="mp3", codec_name="libmp3lame", mime_type="audio/mpeg", filename="output.mp3" + ) logging.info("Uploaded audio to Comfy API. URL: %s", audio_url) else: audio_url = None From 10e90a5757906ecdb71b84d41173813d7f62c140 Mon Sep 17 00:00:00 2001 From: Christian Byrne Date: Thu, 20 Nov 2025 18:20:52 -0800 Subject: [PATCH 117/231] bump comfyui-workflow-templates for nano banana 2 (#10818) * bump templates * bump templates --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 8c1946f3d..624aa7362 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.28.9 -comfyui-workflow-templates==0.3.1 +comfyui-workflow-templates==0.6.0 comfyui-embedded-docs==0.3.1 torch torchsde From 943b3b615d40542ea19bc8ff8ad2950c0a094605 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Thu, 20 Nov 2025 19:44:43 -0800 Subject: [PATCH 118/231] HunyuanVideo 1.5 (#10819) * init * update * Update model.py * Update model.py * remove print * Fix text encoding * Prevent empty negative prompt Really doesn't work otherwise * fp16 works * I2V * Update model_base.py * Update nodes_hunyuan.py * Better latent rgb factors * Use the correct sigclip output... * Support HunyuanVideo1.5 SR model * whitespaces... * Proper latent channel count * SR model fixes This also still needs timesteps scheduling based on the noise scale, can be used with two samplers too already * vae_refiner: roll the convolution through temporal Work in progress. Roll the convolution through time using 2-latent-frame chunks and a FIFO queue for the convolution seams. * Support HunyuanVideo15 latent resampler * fix * Some cleanup Co-Authored-By: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> * Proper hyvid15 I2V channels Co-Authored-By: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> * Fix TokenRefiner for fp16 Otherwise x.sum has infs, just in case only casting if input is fp16, I don't know if necessary. * Bugfix for the HunyuanVideo15 SR model * vae_refiner: roll the convolution through temporal II Roll the convolution through time using 2-latent-frame chunks and a FIFO queue for the convolution seams. Added support for encoder, lowered to 1 latent frame to save more VRAM, made work for Hunyuan Image 3.0 (as code shared). Fixed names, cleaned up code. * Allow any number of input frames in VAE. * Better VAE encode mem estimation. * Lowvram fix. * Fix hunyuan image 2.1 refiner. * Fix mistake. * Name changes. * Rename. * Whitespace. * Fix. * Fix. --------- Co-authored-by: kijai <40791699+kijai@users.noreply.github.com> Co-authored-by: Rattus --- comfy/latent_formats.py | 60 ++++ comfy/ldm/hunyuan_video/model.py | 54 +++- comfy/ldm/hunyuan_video/upsampler.py | 120 ++++++++ comfy/ldm/hunyuan_video/vae_refiner.py | 284 +++++++++++------- comfy/model_base.py | 91 ++++++ comfy/model_detection.py | 10 + comfy/sd.py | 12 +- comfy/supported_models.py | 50 ++- comfy/text_encoders/hunyuan_video.py | 9 + comfy/text_encoders/qwen_image.py | 4 +- comfy_api/latest/_io.py | 4 + comfy_extras/nodes_hunyuan.py | 201 ++++++++++++- folder_paths.py | 2 + .../put_latent_upscale_models_here | 0 nodes.py | 2 +- 15 files changed, 777 insertions(+), 126 deletions(-) create mode 100644 comfy/ldm/hunyuan_video/upsampler.py create mode 100644 models/latent_upscale_models/put_latent_upscale_models_here diff --git a/comfy/latent_formats.py b/comfy/latent_formats.py index 77e642a94..204fc048d 100644 --- a/comfy/latent_formats.py +++ b/comfy/latent_formats.py @@ -611,6 +611,66 @@ class HunyuanImage21Refiner(LatentFormat): latent_dimensions = 3 scale_factor = 1.03682 + def process_in(self, latent): + out = latent * self.scale_factor + out = torch.cat((out[:, :, :1], out), dim=2) + out = out.permute(0, 2, 1, 3, 4) + b, f_times_2, c, h, w = out.shape + out = out.reshape(b, f_times_2 // 2, 2 * c, h, w) + out = out.permute(0, 2, 1, 3, 4).contiguous() + return out + + def process_out(self, latent): + z = latent / self.scale_factor + z = z.permute(0, 2, 1, 3, 4) + b, f, c, h, w = z.shape + z = z.reshape(b, f, 2, c // 2, h, w) + z = z.permute(0, 1, 2, 3, 4, 5).reshape(b, f * 2, c // 2, h, w) + z = z.permute(0, 2, 1, 3, 4) + z = z[:, :, 1:] + return z + +class HunyuanVideo15(LatentFormat): + latent_rgb_factors = [ + [ 0.0568, -0.0521, -0.0131], + [ 0.0014, 0.0735, 0.0326], + [ 0.0186, 0.0531, -0.0138], + [-0.0031, 0.0051, 0.0288], + [ 0.0110, 0.0556, 0.0432], + [-0.0041, -0.0023, -0.0485], + [ 0.0530, 0.0413, 0.0253], + [ 0.0283, 0.0251, 0.0339], + [ 0.0277, -0.0372, -0.0093], + [ 0.0393, 0.0944, 0.1131], + [ 0.0020, 0.0251, 0.0037], + [-0.0017, 0.0012, 0.0234], + [ 0.0468, 0.0436, 0.0203], + [ 0.0354, 0.0439, -0.0233], + [ 0.0090, 0.0123, 0.0346], + [ 0.0382, 0.0029, 0.0217], + [ 0.0261, -0.0300, 0.0030], + [-0.0088, -0.0220, -0.0283], + [-0.0272, -0.0121, -0.0363], + [-0.0664, -0.0622, 0.0144], + [ 0.0414, 0.0479, 0.0529], + [ 0.0355, 0.0612, -0.0247], + [ 0.0147, 0.0264, 0.0174], + [ 0.0438, 0.0038, 0.0542], + [ 0.0431, -0.0573, -0.0033], + [-0.0162, -0.0211, -0.0406], + [-0.0487, -0.0295, -0.0393], + [ 0.0005, -0.0109, 0.0253], + [ 0.0296, 0.0591, 0.0353], + [ 0.0119, 0.0181, -0.0306], + [-0.0085, -0.0362, 0.0229], + [ 0.0005, -0.0106, 0.0242] + ] + + latent_rgb_factors_bias = [ 0.0456, -0.0202, -0.0644] + latent_channels = 32 + latent_dimensions = 3 + scale_factor = 1.03682 + class Hunyuan3Dv2(LatentFormat): latent_channels = 64 latent_dimensions = 1 diff --git a/comfy/ldm/hunyuan_video/model.py b/comfy/ldm/hunyuan_video/model.py index 5132e6c07..f75c6e0e1 100644 --- a/comfy/ldm/hunyuan_video/model.py +++ b/comfy/ldm/hunyuan_video/model.py @@ -6,7 +6,6 @@ import comfy.ldm.flux.layers import comfy.ldm.modules.diffusionmodules.mmdit from comfy.ldm.modules.attention import optimized_attention - from dataclasses import dataclass from einops import repeat @@ -42,6 +41,8 @@ class HunyuanVideoParams: guidance_embed: bool byt5: bool meanflow: bool + use_cond_type_embedding: bool + vision_in_dim: int class SelfAttentionRef(nn.Module): @@ -157,7 +158,10 @@ class TokenRefiner(nn.Module): t = self.t_embedder(timestep_embedding(timesteps, 256, time_factor=1.0).to(x.dtype)) # m = mask.float().unsqueeze(-1) # c = (x.float() * m).sum(dim=1) / m.sum(dim=1) #TODO: the following works when the x.shape is the same length as the tokens but might break otherwise - c = x.sum(dim=1) / x.shape[1] + if x.dtype == torch.float16: + c = x.float().sum(dim=1) / x.shape[1] + else: + c = x.sum(dim=1) / x.shape[1] c = t + self.c_embedder(c.to(x.dtype)) x = self.input_embedder(x) @@ -196,11 +200,15 @@ class HunyuanVideo(nn.Module): def __init__(self, image_model=None, final_layer=True, dtype=None, device=None, operations=None, **kwargs): super().__init__() self.dtype = dtype + operation_settings = {"operations": operations, "device": device, "dtype": dtype} + params = HunyuanVideoParams(**kwargs) self.params = params self.patch_size = params.patch_size self.in_channels = params.in_channels self.out_channels = params.out_channels + self.use_cond_type_embedding = params.use_cond_type_embedding + self.vision_in_dim = params.vision_in_dim if params.hidden_size % params.num_heads != 0: raise ValueError( f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}" @@ -266,6 +274,18 @@ class HunyuanVideo(nn.Module): if final_layer: self.final_layer = LastLayer(self.hidden_size, self.patch_size[-1], self.out_channels, dtype=dtype, device=device, operations=operations) + # HunyuanVideo 1.5 specific modules + if self.vision_in_dim is not None: + from comfy.ldm.wan.model import MLPProj + self.vision_in = MLPProj(in_dim=self.vision_in_dim, out_dim=self.hidden_size, operation_settings=operation_settings) + else: + self.vision_in = None + if self.use_cond_type_embedding: + # 0: text_encoder feature 1: byt5 feature 2: vision_encoder feature + self.cond_type_embedding = nn.Embedding(3, self.hidden_size) + else: + self.cond_type_embedding = None + def forward_orig( self, img: Tensor, @@ -276,6 +296,7 @@ class HunyuanVideo(nn.Module): timesteps: Tensor, y: Tensor = None, txt_byt5=None, + clip_fea=None, guidance: Tensor = None, guiding_frame_index=None, ref_latent=None, @@ -331,12 +352,31 @@ class HunyuanVideo(nn.Module): txt = self.txt_in(txt, timesteps, txt_mask, transformer_options=transformer_options) + if self.cond_type_embedding is not None: + self.cond_type_embedding.to(txt.device) + cond_emb = self.cond_type_embedding(torch.zeros_like(txt[:, :, 0], device=txt.device, dtype=torch.long)) + txt = txt + cond_emb.to(txt.dtype) + if self.byt5_in is not None and txt_byt5 is not None: txt_byt5 = self.byt5_in(txt_byt5) + if self.cond_type_embedding is not None: + cond_emb = self.cond_type_embedding(torch.ones_like(txt_byt5[:, :, 0], device=txt_byt5.device, dtype=torch.long)) + txt_byt5 = txt_byt5 + cond_emb.to(txt_byt5.dtype) + txt = torch.cat((txt_byt5, txt), dim=1) # byt5 first for HunyuanVideo1.5 + else: + txt = torch.cat((txt, txt_byt5), dim=1) txt_byt5_ids = torch.zeros((txt_ids.shape[0], txt_byt5.shape[1], txt_ids.shape[-1]), device=txt_ids.device, dtype=txt_ids.dtype) - txt = torch.cat((txt, txt_byt5), dim=1) txt_ids = torch.cat((txt_ids, txt_byt5_ids), dim=1) + if clip_fea is not None: + txt_vision_states = self.vision_in(clip_fea) + if self.cond_type_embedding is not None: + cond_emb = self.cond_type_embedding(2 * torch.ones_like(txt_vision_states[:, :, 0], dtype=torch.long, device=txt_vision_states.device)) + txt_vision_states = txt_vision_states + cond_emb + txt = torch.cat((txt_vision_states.to(txt.dtype), txt), dim=1) + extra_txt_ids = torch.zeros((txt_ids.shape[0], txt_vision_states.shape[1], txt_ids.shape[-1]), device=txt_ids.device, dtype=txt_ids.dtype) + txt_ids = torch.cat((txt_ids, extra_txt_ids), dim=1) + ids = torch.cat((img_ids, txt_ids), dim=1) pe = self.pe_embedder(ids) @@ -430,14 +470,14 @@ class HunyuanVideo(nn.Module): img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0) return repeat(img_ids, "h w c -> b (h w) c", b=bs) - def forward(self, x, timestep, context, y=None, txt_byt5=None, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, disable_time_r=False, control=None, transformer_options={}, **kwargs): + def forward(self, x, timestep, context, y=None, txt_byt5=None, clip_fea=None, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, disable_time_r=False, control=None, transformer_options={}, **kwargs): return comfy.patcher_extension.WrapperExecutor.new_class_executor( self._forward, self, comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options) - ).execute(x, timestep, context, y, txt_byt5, guidance, attention_mask, guiding_frame_index, ref_latent, disable_time_r, control, transformer_options, **kwargs) + ).execute(x, timestep, context, y, txt_byt5, clip_fea, guidance, attention_mask, guiding_frame_index, ref_latent, disable_time_r, control, transformer_options, **kwargs) - def _forward(self, x, timestep, context, y=None, txt_byt5=None, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, disable_time_r=False, control=None, transformer_options={}, **kwargs): + def _forward(self, x, timestep, context, y=None, txt_byt5=None, clip_fea=None, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, disable_time_r=False, control=None, transformer_options={}, **kwargs): bs = x.shape[0] if len(self.patch_size) == 3: img_ids = self.img_ids(x) @@ -445,5 +485,5 @@ class HunyuanVideo(nn.Module): else: img_ids = self.img_ids_2d(x) txt_ids = torch.zeros((bs, context.shape[1], 2), device=x.device, dtype=x.dtype) - out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, txt_byt5, guidance, guiding_frame_index, ref_latent, disable_time_r=disable_time_r, control=control, transformer_options=transformer_options) + out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, txt_byt5, clip_fea, guidance, guiding_frame_index, ref_latent, disable_time_r=disable_time_r, control=control, transformer_options=transformer_options) return out diff --git a/comfy/ldm/hunyuan_video/upsampler.py b/comfy/ldm/hunyuan_video/upsampler.py new file mode 100644 index 000000000..9f5e91a59 --- /dev/null +++ b/comfy/ldm/hunyuan_video/upsampler.py @@ -0,0 +1,120 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from comfy.ldm.hunyuan_video.vae_refiner import RMS_norm, ResnetBlock, VideoConv3d +import model_management, model_patcher + +class SRResidualCausalBlock3D(nn.Module): + def __init__(self, channels: int): + super().__init__() + self.block = nn.Sequential( + VideoConv3d(channels, channels, kernel_size=3), + nn.SiLU(inplace=True), + VideoConv3d(channels, channels, kernel_size=3), + nn.SiLU(inplace=True), + VideoConv3d(channels, channels, kernel_size=3), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x + self.block(x) + +class SRModel3DV2(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + hidden_channels: int = 64, + num_blocks: int = 6, + global_residual: bool = False, + ): + super().__init__() + self.in_conv = VideoConv3d(in_channels, hidden_channels, kernel_size=3) + self.blocks = nn.ModuleList([SRResidualCausalBlock3D(hidden_channels) for _ in range(num_blocks)]) + self.out_conv = VideoConv3d(hidden_channels, out_channels, kernel_size=3) + self.global_residual = bool(global_residual) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + residual = x + y = self.in_conv(x) + for blk in self.blocks: + y = blk(y) + y = self.out_conv(y) + if self.global_residual and (y.shape == residual.shape): + y = y + residual + return y + + +class Upsampler(nn.Module): + def __init__( + self, + z_channels: int, + out_channels: int, + block_out_channels: tuple[int, ...], + num_res_blocks: int = 2, + ): + super().__init__() + self.num_res_blocks = num_res_blocks + self.block_out_channels = block_out_channels + self.z_channels = z_channels + + ch = block_out_channels[0] + self.conv_in = VideoConv3d(z_channels, ch, kernel_size=3) + + self.up = nn.ModuleList() + + for i, tgt in enumerate(block_out_channels): + stage = nn.Module() + stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt, + out_channels=tgt, + temb_channels=0, + conv_shortcut=False, + conv_op=VideoConv3d, norm_op=RMS_norm) + for j in range(num_res_blocks + 1)]) + ch = tgt + self.up.append(stage) + + self.norm_out = RMS_norm(ch) + self.conv_out = VideoConv3d(ch, out_channels, kernel_size=3) + + def forward(self, z): + """ + Args: + z: (B, C, T, H, W) + target_shape: (H, W) + """ + # z to block_in + repeats = self.block_out_channels[0] // (self.z_channels) + x = self.conv_in(z) + z.repeat_interleave(repeats=repeats, dim=1) + + # upsampling + for stage in self.up: + for blk in stage.block: + x = blk(x) + + out = self.conv_out(F.silu(self.norm_out(x))) + return out + +UPSAMPLERS = { + "720p": SRModel3DV2, + "1080p": Upsampler, +} + +class HunyuanVideo15SRModel(): + def __init__(self, model_type, config): + self.load_device = model_management.vae_device() + offload_device = model_management.vae_offload_device() + self.dtype = model_management.vae_dtype(self.load_device) + self.model_class = UPSAMPLERS.get(model_type) + self.model = self.model_class(**config).eval() + + self.patcher = model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device) + + def load_sd(self, sd): + return self.model.load_state_dict(sd, strict=True) + + def get_sd(self): + return self.model.state_dict() + + def resample_latent(self, latent): + model_management.load_model_gpu(self.patcher) + return self.model(latent.to(self.load_device)) diff --git a/comfy/ldm/hunyuan_video/vae_refiner.py b/comfy/ldm/hunyuan_video/vae_refiner.py index c2a0b507d..9f750dcc4 100644 --- a/comfy/ldm/hunyuan_video/vae_refiner.py +++ b/comfy/ldm/hunyuan_video/vae_refiner.py @@ -4,8 +4,40 @@ import torch.nn.functional as F from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, AttnBlock, VideoConv3d, Normalize import comfy.ops import comfy.ldm.models.autoencoder +import comfy.model_management ops = comfy.ops.disable_weight_init +class NoPadConv3d(nn.Module): + def __init__(self, n_channels, out_channels, kernel_size, stride=1, dilation=1, padding=0, **kwargs): + super().__init__() + self.conv = ops.Conv3d(n_channels, out_channels, kernel_size, stride=stride, dilation=dilation, **kwargs) + + def forward(self, x): + return self.conv(x) + + +def conv_carry_causal_3d(xl, op, conv_carry_in=None, conv_carry_out=None): + + x = xl[0] + xl.clear() + + if conv_carry_out is not None: + to_push = x[:, :, -2:, :, :].clone() + conv_carry_out.append(to_push) + + if isinstance(op, NoPadConv3d): + if conv_carry_in is None: + x = torch.nn.functional.pad(x, (1, 1, 1, 1, 2, 0), mode = 'replicate') + else: + carry_len = conv_carry_in[0].shape[2] + x = torch.cat([conv_carry_in.pop(0), x], dim=2) + x = torch.nn.functional.pad(x, (1, 1, 1, 1, 2 - carry_len, 0), mode = 'replicate') + + out = op(x) + + return out + + class RMS_norm(nn.Module): def __init__(self, dim): super().__init__() @@ -14,7 +46,7 @@ class RMS_norm(nn.Module): self.gamma = nn.Parameter(torch.empty(shape)) def forward(self, x): - return F.normalize(x, dim=1) * self.scale * self.gamma + return F.normalize(x, dim=1) * self.scale * comfy.model_management.cast_to(self.gamma, dtype=x.dtype, device=x.device) class DnSmpl(nn.Module): def __init__(self, ic, oc, tds=True, refiner_vae=True, op=VideoConv3d): @@ -27,11 +59,12 @@ class DnSmpl(nn.Module): self.tds = tds self.gs = fct * ic // oc - def forward(self, x): + def forward(self, x, conv_carry_in=None, conv_carry_out=None): r1 = 2 if self.tds else 1 - h = self.conv(x) + h = conv_carry_causal_3d([x], self.conv, conv_carry_in, conv_carry_out) + + if self.tds and self.refiner_vae and conv_carry_in is None: - if self.tds and self.refiner_vae: hf = h[:, :, :1, :, :] b, c, f, ht, wd = hf.shape hf = hf.reshape(b, c, f, ht // 2, 2, wd // 2, 2) @@ -39,14 +72,7 @@ class DnSmpl(nn.Module): hf = hf.reshape(b, 2 * 2 * c, f, ht // 2, wd // 2) hf = torch.cat([hf, hf], dim=1) - hn = h[:, :, 1:, :, :] - b, c, frms, ht, wd = hn.shape - nf = frms // r1 - hn = hn.reshape(b, c, nf, r1, ht // 2, 2, wd // 2, 2) - hn = hn.permute(0, 3, 5, 7, 1, 2, 4, 6) - hn = hn.reshape(b, r1 * 2 * 2 * c, nf, ht // 2, wd // 2) - - h = torch.cat([hf, hn], dim=2) + h = h[:, :, 1:, :, :] xf = x[:, :, :1, :, :] b, ci, f, ht, wd = xf.shape @@ -54,34 +80,32 @@ class DnSmpl(nn.Module): xf = xf.permute(0, 4, 6, 1, 2, 3, 5) xf = xf.reshape(b, 2 * 2 * ci, f, ht // 2, wd // 2) B, C, T, H, W = xf.shape - xf = xf.view(B, h.shape[1], self.gs // 2, T, H, W).mean(dim=2) + xf = xf.view(B, hf.shape[1], self.gs // 2, T, H, W).mean(dim=2) - xn = x[:, :, 1:, :, :] - b, ci, frms, ht, wd = xn.shape - nf = frms // r1 - xn = xn.reshape(b, ci, nf, r1, ht // 2, 2, wd // 2, 2) - xn = xn.permute(0, 3, 5, 7, 1, 2, 4, 6) - xn = xn.reshape(b, r1 * 2 * 2 * ci, nf, ht // 2, wd // 2) - B, C, T, H, W = xn.shape - xn = xn.view(B, h.shape[1], self.gs, T, H, W).mean(dim=2) - sc = torch.cat([xf, xn], dim=2) - else: - b, c, frms, ht, wd = h.shape + x = x[:, :, 1:, :, :] - nf = frms // r1 - h = h.reshape(b, c, nf, r1, ht // 2, 2, wd // 2, 2) - h = h.permute(0, 3, 5, 7, 1, 2, 4, 6) - h = h.reshape(b, r1 * 2 * 2 * c, nf, ht // 2, wd // 2) + if h.shape[2] == 0: + return hf + xf - b, ci, frms, ht, wd = x.shape - nf = frms // r1 - sc = x.reshape(b, ci, nf, r1, ht // 2, 2, wd // 2, 2) - sc = sc.permute(0, 3, 5, 7, 1, 2, 4, 6) - sc = sc.reshape(b, r1 * 2 * 2 * ci, nf, ht // 2, wd // 2) - B, C, T, H, W = sc.shape - sc = sc.view(B, h.shape[1], self.gs, T, H, W).mean(dim=2) + b, c, frms, ht, wd = h.shape + nf = frms // r1 + h = h.reshape(b, c, nf, r1, ht // 2, 2, wd // 2, 2) + h = h.permute(0, 3, 5, 7, 1, 2, 4, 6) + h = h.reshape(b, r1 * 2 * 2 * c, nf, ht // 2, wd // 2) - return h + sc + b, ci, frms, ht, wd = x.shape + nf = frms // r1 + x = x.reshape(b, ci, nf, r1, ht // 2, 2, wd // 2, 2) + x = x.permute(0, 3, 5, 7, 1, 2, 4, 6) + x = x.reshape(b, r1 * 2 * 2 * ci, nf, ht // 2, wd // 2) + B, C, T, H, W = x.shape + x = x.view(B, h.shape[1], self.gs, T, H, W).mean(dim=2) + + if self.tds and self.refiner_vae and conv_carry_in is None: + h = torch.cat([hf, h], dim=2) + x = torch.cat([xf, x], dim=2) + + return h + x class UpSmpl(nn.Module): @@ -94,11 +118,11 @@ class UpSmpl(nn.Module): self.tus = tus self.rp = fct * oc // ic - def forward(self, x): + def forward(self, x, conv_carry_in=None, conv_carry_out=None): r1 = 2 if self.tus else 1 - h = self.conv(x) + h = conv_carry_causal_3d([x], self.conv, conv_carry_in, conv_carry_out) - if self.tus and self.refiner_vae: + if self.tus and self.refiner_vae and conv_carry_in is None: hf = h[:, :, :1, :, :] b, c, f, ht, wd = hf.shape nc = c // (2 * 2) @@ -107,14 +131,7 @@ class UpSmpl(nn.Module): hf = hf.reshape(b, nc, f, ht * 2, wd * 2) hf = hf[:, : hf.shape[1] // 2] - hn = h[:, :, 1:, :, :] - b, c, frms, ht, wd = hn.shape - nc = c // (r1 * 2 * 2) - hn = hn.reshape(b, r1, 2, 2, nc, frms, ht, wd) - hn = hn.permute(0, 4, 5, 1, 6, 2, 7, 3) - hn = hn.reshape(b, nc, frms * r1, ht * 2, wd * 2) - - h = torch.cat([hf, hn], dim=2) + h = h[:, :, 1:, :, :] xf = x[:, :, :1, :, :] b, ci, f, ht, wd = xf.shape @@ -125,29 +142,43 @@ class UpSmpl(nn.Module): xf = xf.permute(0, 3, 4, 5, 1, 6, 2) xf = xf.reshape(b, nc, f, ht * 2, wd * 2) - xn = x[:, :, 1:, :, :] - xn = xn.repeat_interleave(repeats=self.rp, dim=1) - b, c, frms, ht, wd = xn.shape - nc = c // (r1 * 2 * 2) - xn = xn.reshape(b, r1, 2, 2, nc, frms, ht, wd) - xn = xn.permute(0, 4, 5, 1, 6, 2, 7, 3) - xn = xn.reshape(b, nc, frms * r1, ht * 2, wd * 2) - sc = torch.cat([xf, xn], dim=2) - else: - b, c, frms, ht, wd = h.shape - nc = c // (r1 * 2 * 2) - h = h.reshape(b, r1, 2, 2, nc, frms, ht, wd) - h = h.permute(0, 4, 5, 1, 6, 2, 7, 3) - h = h.reshape(b, nc, frms * r1, ht * 2, wd * 2) + x = x[:, :, 1:, :, :] - sc = x.repeat_interleave(repeats=self.rp, dim=1) - b, c, frms, ht, wd = sc.shape - nc = c // (r1 * 2 * 2) - sc = sc.reshape(b, r1, 2, 2, nc, frms, ht, wd) - sc = sc.permute(0, 4, 5, 1, 6, 2, 7, 3) - sc = sc.reshape(b, nc, frms * r1, ht * 2, wd * 2) + b, c, frms, ht, wd = h.shape + nc = c // (r1 * 2 * 2) + h = h.reshape(b, r1, 2, 2, nc, frms, ht, wd) + h = h.permute(0, 4, 5, 1, 6, 2, 7, 3) + h = h.reshape(b, nc, frms * r1, ht * 2, wd * 2) - return h + sc + x = x.repeat_interleave(repeats=self.rp, dim=1) + b, c, frms, ht, wd = x.shape + nc = c // (r1 * 2 * 2) + x = x.reshape(b, r1, 2, 2, nc, frms, ht, wd) + x = x.permute(0, 4, 5, 1, 6, 2, 7, 3) + x = x.reshape(b, nc, frms * r1, ht * 2, wd * 2) + + if self.tus and self.refiner_vae and conv_carry_in is None: + h = torch.cat([hf, h], dim=2) + x = torch.cat([xf, x], dim=2) + + return h + x + +class HunyuanRefinerResnetBlock(ResnetBlock): + def __init__(self, in_channels, out_channels, conv_op=NoPadConv3d, norm_op=RMS_norm): + super().__init__(in_channels=in_channels, out_channels=out_channels, temb_channels=0, conv_op=conv_op, norm_op=norm_op) + + def forward(self, x, conv_carry_in=None, conv_carry_out=None): + h = x + h = [ self.swish(self.norm1(x)) ] + h = conv_carry_causal_3d(h, self.conv1, conv_carry_in=conv_carry_in, conv_carry_out=conv_carry_out) + + h = [ self.dropout(self.swish(self.norm2(h))) ] + h = conv_carry_causal_3d(h, self.conv2, conv_carry_in=conv_carry_in, conv_carry_out=conv_carry_out) + + if self.in_channels != self.out_channels: + x = self.nin_shortcut(x) + + return x+h class Encoder(nn.Module): def __init__(self, in_channels, z_channels, block_out_channels, num_res_blocks, @@ -160,7 +191,7 @@ class Encoder(nn.Module): self.refiner_vae = refiner_vae if self.refiner_vae: - conv_op = VideoConv3d + conv_op = NoPadConv3d norm_op = RMS_norm else: conv_op = ops.Conv3d @@ -175,10 +206,9 @@ class Encoder(nn.Module): for i, tgt in enumerate(block_out_channels): stage = nn.Module() - stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt, - out_channels=tgt, - temb_channels=0, - conv_op=conv_op, norm_op=norm_op) + stage.block = nn.ModuleList([HunyuanRefinerResnetBlock(in_channels=ch if j == 0 else tgt, + out_channels=tgt, + conv_op=conv_op, norm_op=norm_op) for j in range(num_res_blocks)]) ch = tgt if i < depth: @@ -188,9 +218,9 @@ class Encoder(nn.Module): self.down.append(stage) self.mid = nn.Module() - self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=conv_op, norm_op=norm_op) + self.mid.block_1 = HunyuanRefinerResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op) self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=norm_op) - self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=conv_op, norm_op=norm_op) + self.mid.block_2 = HunyuanRefinerResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op) self.norm_out = norm_op(ch) self.conv_out = conv_op(ch, z_channels << 1, 3, 1, 1) @@ -201,31 +231,50 @@ class Encoder(nn.Module): if not self.refiner_vae and x.shape[2] == 1: x = x.expand(-1, -1, self.ffactor_temporal, -1, -1) - x = self.conv_in(x) + if self.refiner_vae: + xl = [x[:, :, :1, :, :]] + if x.shape[2] > self.ffactor_temporal: + xl += torch.split(x[:, :, 1: 1 + ((x.shape[2] - 1) // self.ffactor_temporal) * self.ffactor_temporal, :, :], self.ffactor_temporal * 2, dim=2) + x = xl + else: + x = [x] + out = [] - for stage in self.down: - for blk in stage.block: - x = blk(x) - if hasattr(stage, 'downsample'): - x = stage.downsample(x) + conv_carry_in = None - x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(x))) + for i, x1 in enumerate(x): + conv_carry_out = [] + if i == len(x) - 1: + conv_carry_out = None + x1 = [ x1 ] + x1 = conv_carry_causal_3d(x1, self.conv_in, conv_carry_in, conv_carry_out) + + for stage in self.down: + for blk in stage.block: + x1 = blk(x1, conv_carry_in, conv_carry_out) + if hasattr(stage, 'downsample'): + x1 = stage.downsample(x1, conv_carry_in, conv_carry_out) + + out.append(x1) + conv_carry_in = conv_carry_out + + if len(out) > 1: + out = torch.cat(out, dim=2) + else: + out = out[0] + + x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(out))) + del out b, c, t, h, w = x.shape grp = c // (self.z_channels << 1) skip = x.view(b, c // grp, grp, t, h, w).mean(2) - out = self.conv_out(F.silu(self.norm_out(x))) + skip + out = conv_carry_causal_3d([F.silu(self.norm_out(x))], self.conv_out) + skip if self.refiner_vae: out = self.regul(out)[0] - out = torch.cat((out[:, :, :1], out), dim=2) - out = out.permute(0, 2, 1, 3, 4) - b, f_times_2, c, h, w = out.shape - out = out.reshape(b, f_times_2 // 2, 2 * c, h, w) - out = out.permute(0, 2, 1, 3, 4).contiguous() - return out class Decoder(nn.Module): @@ -239,7 +288,7 @@ class Decoder(nn.Module): self.refiner_vae = refiner_vae if self.refiner_vae: - conv_op = VideoConv3d + conv_op = NoPadConv3d norm_op = RMS_norm else: conv_op = ops.Conv3d @@ -249,9 +298,9 @@ class Decoder(nn.Module): self.conv_in = conv_op(z_channels, ch, kernel_size=3, stride=1, padding=1) self.mid = nn.Module() - self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=conv_op, norm_op=norm_op) + self.mid.block_1 = HunyuanRefinerResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op) self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=norm_op) - self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=conv_op, norm_op=norm_op) + self.mid.block_2 = HunyuanRefinerResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op) self.up = nn.ModuleList() depth = (ffactor_spatial >> 1).bit_length() @@ -259,10 +308,9 @@ class Decoder(nn.Module): for i, tgt in enumerate(block_out_channels): stage = nn.Module() - stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt, - out_channels=tgt, - temb_channels=0, - conv_op=conv_op, norm_op=norm_op) + stage.block = nn.ModuleList([HunyuanRefinerResnetBlock(in_channels=ch if j == 0 else tgt, + out_channels=tgt, + conv_op=conv_op, norm_op=norm_op) for j in range(num_res_blocks + 1)]) ch = tgt if i < depth: @@ -275,27 +323,41 @@ class Decoder(nn.Module): self.conv_out = conv_op(ch, out_channels, 3, stride=1, padding=1) def forward(self, z): - if self.refiner_vae: - z = z.permute(0, 2, 1, 3, 4) - b, f, c, h, w = z.shape - z = z.reshape(b, f, 2, c // 2, h, w) - z = z.permute(0, 1, 2, 3, 4, 5).reshape(b, f * 2, c // 2, h, w) - z = z.permute(0, 2, 1, 3, 4) - z = z[:, :, 1:] - - x = self.conv_in(z) + z.repeat_interleave(self.block_out_channels[0] // self.z_channels, 1) + x = conv_carry_causal_3d([z], self.conv_in) + z.repeat_interleave(self.block_out_channels[0] // self.z_channels, 1) x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(x))) - for stage in self.up: - for blk in stage.block: - x = blk(x) - if hasattr(stage, 'upsample'): - x = stage.upsample(x) + if self.refiner_vae: + x = torch.split(x, 2, dim=2) + else: + x = [ x ] + out = [] - out = self.conv_out(F.silu(self.norm_out(x))) + conv_carry_in = None + + for i, x1 in enumerate(x): + conv_carry_out = [] + if i == len(x) - 1: + conv_carry_out = None + for stage in self.up: + for blk in stage.block: + x1 = blk(x1, conv_carry_in, conv_carry_out) + if hasattr(stage, 'upsample'): + x1 = stage.upsample(x1, conv_carry_in, conv_carry_out) + + x1 = [ F.silu(self.norm_out(x1)) ] + x1 = conv_carry_causal_3d(x1, self.conv_out, conv_carry_in, conv_carry_out) + out.append(x1) + conv_carry_in = conv_carry_out + del x + + if len(out) > 1: + out = torch.cat(out, dim=2) + else: + out = out[0] if not self.refiner_vae: if z.shape[-3] == 1: out = out[:, :, -1:] return out + diff --git a/comfy/model_base.py b/comfy/model_base.py index 7c788d085..e14b552c5 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1536,3 +1536,94 @@ class HunyuanImage21Refiner(HunyuanImage21): out = super().extra_conds(**kwargs) out['disable_time_r'] = comfy.conds.CONDConstant(True) return out + +class HunyuanVideo15(HunyuanVideo): + def __init__(self, model_config, model_type=ModelType.FLOW, device=None): + super().__init__(model_config, model_type, device=device) + + def concat_cond(self, **kwargs): + noise = kwargs.get("noise", None) + extra_channels = self.diffusion_model.img_in.proj.weight.shape[1] - noise.shape[1] - 1 #noise 32 img cond 32 + mask 1 + if extra_channels == 0: + return None + + image = kwargs.get("concat_latent_image", None) + device = kwargs["device"] + + if image is None: + shape_image = list(noise.shape) + shape_image[1] = extra_channels + image = torch.zeros(shape_image, dtype=noise.dtype, layout=noise.layout, device=noise.device) + else: + latent_dim = self.latent_format.latent_channels + image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center") + for i in range(0, image.shape[1], latent_dim): + image[:, i: i + latent_dim] = self.process_latent_in(image[:, i: i + latent_dim]) + image = utils.resize_to_batch_size(image, noise.shape[0]) + + mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None)) + if mask is None: + mask = torch.zeros_like(noise)[:, :1] + else: + mask = 1.0 - mask + mask = utils.common_upscale(mask.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center") + if mask.shape[-3] < noise.shape[-3]: + mask = torch.nn.functional.pad(mask, (0, 0, 0, 0, 0, noise.shape[-3] - mask.shape[-3]), mode='constant', value=0) + mask = utils.resize_to_batch_size(mask, noise.shape[0]) + + return torch.cat((image, mask), dim=1) + + def extra_conds(self, **kwargs): + out = super().extra_conds(**kwargs) + attention_mask = kwargs.get("attention_mask", None) + if attention_mask is not None: + if torch.numel(attention_mask) != attention_mask.sum(): + out['attention_mask'] = comfy.conds.CONDRegular(attention_mask) + cross_attn = kwargs.get("cross_attn", None) + if cross_attn is not None: + out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) + + conditioning_byt5small = kwargs.get("conditioning_byt5small", None) + if conditioning_byt5small is not None: + out['txt_byt5'] = comfy.conds.CONDRegular(conditioning_byt5small) + + guidance = kwargs.get("guidance", 6.0) + if guidance is not None: + out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance])) + + clip_vision_output = kwargs.get("clip_vision_output", None) + if clip_vision_output is not None: + out['clip_fea'] = comfy.conds.CONDRegular(clip_vision_output.last_hidden_state) + + return out + +class HunyuanVideo15_SR_Distilled(HunyuanVideo15): + def __init__(self, model_config, model_type=ModelType.FLOW, device=None): + super().__init__(model_config, model_type, device=device) + + def concat_cond(self, **kwargs): + noise = kwargs.get("noise", None) + image = kwargs.get("concat_latent_image", None) + noise_augmentation = kwargs.get("noise_augmentation", 0.0) + device = kwargs["device"] + + if image is None: + image = torch.zeros([noise.shape[0], noise.shape[1] * 2 + 2, noise.shape[-3], noise.shape[-2], noise.shape[-1]], device=comfy.model_management.intermediate_device()) + else: + image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center") + #image = self.process_latent_in(image) # scaling wasn't applied in reference code + image = utils.resize_to_batch_size(image, noise.shape[0]) + lq_image_slice = slice(noise.shape[1] + 1, 2 * noise.shape[1] + 1) + if noise_augmentation > 0: + generator = torch.Generator(device="cpu") + generator.manual_seed(kwargs.get("seed", 0) - 10) + noise = torch.randn(image[:, lq_image_slice].shape, generator=generator, dtype=image.dtype, device="cpu").to(image.device) + image[:, lq_image_slice] = noise_augmentation * noise + min(1.0 - noise_augmentation, 0.75) * image[:, lq_image_slice] + else: + image[:, lq_image_slice] = 0.75 * image[:, lq_image_slice] + return image + + def extra_conds(self, **kwargs): + out = super().extra_conds(**kwargs) + out['disable_time_r'] = comfy.conds.CONDConstant(False) + return out diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 3142a7fc3..0131ca25a 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -186,6 +186,16 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): guidance_keys = list(filter(lambda a: a.startswith("{}guidance_in.".format(key_prefix)), state_dict_keys)) dit_config["guidance_embed"] = len(guidance_keys) > 0 + + # HunyuanVideo 1.5 + if '{}cond_type_embedding.weight'.format(key_prefix) in state_dict_keys: + dit_config["use_cond_type_embedding"] = True + else: + dit_config["use_cond_type_embedding"] = False + if '{}vision_in.proj.0.weight'.format(key_prefix) in state_dict_keys: + dit_config["vision_in_dim"] = state_dict['{}vision_in.proj.0.weight'.format(key_prefix)].shape[0] + else: + dit_config["vision_in_dim"] = None return dit_config if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys and ('{}img_in.weight'.format(key_prefix) in state_dict_keys or f"{key_prefix}distilled_guidance_layer.norms.0.scale" in state_dict_keys): #Flux, Chroma or Chroma Radiance (has no img_in.weight) diff --git a/comfy/sd.py b/comfy/sd.py index 9e5ebbf15..dc0905ada 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -441,20 +441,20 @@ class VAE: elif "decoder.conv_in.conv.weight" in sd and sd['decoder.conv_in.conv.weight'].shape[1] == 32: ddconfig = {"block_out_channels": [128, 256, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 16, "ffactor_temporal": 4, "downsample_match_channel": True, "upsample_match_channel": True} ddconfig['z_channels'] = sd["decoder.conv_in.conv.weight"].shape[1] - self.latent_channels = 64 + self.latent_channels = 32 self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 16, 16) self.upscale_index_formula = (4, 16, 16) self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 16, 16) self.downscale_index_formula = (4, 16, 16) self.latent_dim = 3 - self.not_video = True + self.not_video = False self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32] self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.EmptyRegularizer"}, encoder_config={'target': "comfy.ldm.hunyuan_video.vae_refiner.Encoder", 'params': ddconfig}, decoder_config={'target': "comfy.ldm.hunyuan_video.vae_refiner.Decoder", 'params': ddconfig}) - self.memory_used_encode = lambda shape, dtype: (1400 * shape[-2] * shape[-1]) * model_management.dtype_size(dtype) - self.memory_used_decode = lambda shape, dtype: (1400 * shape[-3] * shape[-2] * shape[-1] * 16 * 16) * model_management.dtype_size(dtype) + self.memory_used_encode = lambda shape, dtype: (1400 * 9 * shape[-2] * shape[-1]) * model_management.dtype_size(dtype) + self.memory_used_decode = lambda shape, dtype: (2800 * 4 * shape[-2] * shape[-1] * 16 * 16) * model_management.dtype_size(dtype) elif "decoder.conv_in.conv.weight" in sd: ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0} ddconfig["conv3d"] = True @@ -911,6 +911,7 @@ class CLIPType(Enum): OMNIGEN2 = 17 QWEN_IMAGE = 18 HUNYUAN_IMAGE = 19 + HUNYUAN_VIDEO_15 = 20 def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}): @@ -1126,6 +1127,9 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip elif clip_type == CLIPType.HUNYUAN_IMAGE: clip_target.clip = comfy.text_encoders.hunyuan_image.te(**llama_detect(clip_data)) clip_target.tokenizer = comfy.text_encoders.hunyuan_image.HunyuanImageTokenizer + elif clip_type == CLIPType.HUNYUAN_VIDEO_15: + clip_target.clip = comfy.text_encoders.hunyuan_image.te(**llama_detect(clip_data)) + clip_target.tokenizer = comfy.text_encoders.hunyuan_video.HunyuanVideo15Tokenizer else: clip_target.clip = sdxl_clip.SDXLClipModel clip_target.tokenizer = sdxl_clip.SDXLTokenizer diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 4064bdae1..2e64b85e8 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -1374,6 +1374,54 @@ class HunyuanImage21Refiner(HunyuanVideo): out = model_base.HunyuanImage21Refiner(self, device=device) return out -models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage] +class HunyuanVideo15(HunyuanVideo): + unet_config = { + "image_model": "hunyuan_video", + "vision_in_dim": 1152, + } + + sampling_settings = { + "shift": 7.0, + } + memory_usage_factor = 4.0 #TODO + supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32] + + latent_format = latent_formats.HunyuanVideo15 + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.HunyuanVideo15(self, device=device) + return out + + def clip_target(self, state_dict={}): + pref = self.text_encoder_key_prefix[0] + hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref)) + return supported_models_base.ClipTarget(comfy.text_encoders.hunyuan_video.HunyuanVideo15Tokenizer, comfy.text_encoders.hunyuan_image.te(**hunyuan_detect)) + + +class HunyuanVideo15_SR_Distilled(HunyuanVideo): + unet_config = { + "image_model": "hunyuan_video", + "vision_in_dim": 1152, + "in_channels": 98, + } + + sampling_settings = { + "shift": 2.0, + } + memory_usage_factor = 4.0 #TODO + supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32] + + latent_format = latent_formats.HunyuanVideo15 + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.HunyuanVideo15_SR_Distilled(self, device=device) + return out + + def clip_target(self, state_dict={}): + pref = self.text_encoder_key_prefix[0] + hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref)) + return supported_models_base.ClipTarget(comfy.text_encoders.hunyuan_video.HunyuanVideo15Tokenizer, comfy.text_encoders.hunyuan_image.te(**hunyuan_detect)) + +models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage] models += [SVD_img2vid] diff --git a/comfy/text_encoders/hunyuan_video.py b/comfy/text_encoders/hunyuan_video.py index b02148b33..557094f49 100644 --- a/comfy/text_encoders/hunyuan_video.py +++ b/comfy/text_encoders/hunyuan_video.py @@ -1,6 +1,7 @@ from comfy import sd1_clip import comfy.model_management import comfy.text_encoders.llama +from .hunyuan_image import HunyuanImageTokenizer from transformers import LlamaTokenizerFast import torch import os @@ -73,6 +74,14 @@ class HunyuanVideoTokenizer: return {} +class HunyuanVideo15Tokenizer(HunyuanImageTokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data) + self.llama_template = "<|im_start|>system\nYou are a helpful assistant. Describe the video by detailing the following aspects:\n1. The main content and theme of the video.\n2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects.\n3. Actions, events, behaviors temporal relationships, physical movement changes of the objects.\n4. background environment, light, style and atmosphere.\n5. camera angles, movements, and transitions used in the video.<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n" + + def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs): + return super().tokenize_with_weights(text, return_word_ids, prevent_empty_text=True, **kwargs) + class HunyuanVideoClipModel(torch.nn.Module): def __init__(self, dtype_llama=None, device="cpu", dtype=None, model_options={}): super().__init__() diff --git a/comfy/text_encoders/qwen_image.py b/comfy/text_encoders/qwen_image.py index 40fa67937..c0d32a6ef 100644 --- a/comfy/text_encoders/qwen_image.py +++ b/comfy/text_encoders/qwen_image.py @@ -17,12 +17,14 @@ class QwenImageTokenizer(sd1_clip.SD1Tokenizer): self.llama_template = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n" self.llama_template_images = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n" - def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, images=[], **kwargs): + def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, images=[], prevent_empty_text=False, **kwargs): skip_template = False if text.startswith('<|im_start|>'): skip_template = True if text.startswith('<|start_header_id|>'): skip_template = True + if prevent_empty_text and text == '': + text = ' ' if skip_template: llama_text = text diff --git a/comfy_api/latest/_io.py b/comfy_api/latest/_io.py index 863254ce7..79c0722a9 100644 --- a/comfy_api/latest/_io.py +++ b/comfy_api/latest/_io.py @@ -629,6 +629,10 @@ class UpscaleModel(ComfyTypeIO): if TYPE_CHECKING: Type = ImageModelDescriptor +@comfytype(io_type="LATENT_UPSCALE_MODEL") +class LatentUpscaleModel(ComfyTypeIO): + Type = Any + @comfytype(io_type="AUDIO") class Audio(ComfyTypeIO): class AudioDict(TypedDict): diff --git a/comfy_extras/nodes_hunyuan.py b/comfy_extras/nodes_hunyuan.py index f7c34d059..5a2e8cc61 100644 --- a/comfy_extras/nodes_hunyuan.py +++ b/comfy_extras/nodes_hunyuan.py @@ -4,7 +4,8 @@ import torch import comfy.model_management from typing_extensions import override from comfy_api.latest import ComfyExtension, io - +from comfy.ldm.hunyuan_video.upsampler import HunyuanVideo15SRModel +import folder_paths class CLIPTextEncodeHunyuanDiT(io.ComfyNode): @classmethod @@ -57,6 +58,199 @@ class EmptyHunyuanLatentVideo(io.ComfyNode): generate = execute # TODO: remove +class EmptyHunyuanVideo15Latent(EmptyHunyuanLatentVideo): + @classmethod + def define_schema(cls): + schema = super().define_schema() + schema.node_id = "EmptyHunyuanVideo15Latent" + return schema + + @classmethod + def execute(cls, width, height, length, batch_size=1) -> io.NodeOutput: + # Using scale factor of 16 instead of 8 + latent = torch.zeros([batch_size, 32, ((length - 1) // 4) + 1, height // 16, width // 16], device=comfy.model_management.intermediate_device()) + return io.NodeOutput({"samples": latent}) + + generate = execute # TODO: remove + + +class HunyuanVideo15ImageToVideo(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="HunyuanVideo15ImageToVideo", + category="conditioning/video_models", + inputs=[ + io.Conditioning.Input("positive"), + io.Conditioning.Input("negative"), + io.Vae.Input("vae"), + io.Int.Input("width", default=848, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("length", default=33, min=1, max=nodes.MAX_RESOLUTION, step=4), + io.Int.Input("batch_size", default=1, min=1, max=4096), + io.Image.Input("start_image", optional=True), + io.ClipVisionOutput.Input("clip_vision_output", optional=True), + ], + outputs=[ + io.Conditioning.Output(display_name="positive"), + io.Conditioning.Output(display_name="negative"), + io.Latent.Output(display_name="latent"), + ], + ) + + @classmethod + def execute(cls, positive, negative, vae, width, height, length, batch_size, start_image=None, clip_vision_output=None) -> io.NodeOutput: + latent = torch.zeros([batch_size, 32, ((length - 1) // 4) + 1, height // 16, width // 16], device=comfy.model_management.intermediate_device()) + + if start_image is not None: + start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) + + encoded = vae.encode(start_image[:, :, :, :3]) + concat_latent_image = torch.zeros((latent.shape[0], 32, latent.shape[2], latent.shape[3], latent.shape[4]), device=comfy.model_management.intermediate_device()) + concat_latent_image[:, :, :encoded.shape[2], :, :] = encoded + + mask = torch.ones((1, 1, latent.shape[2], concat_latent_image.shape[-2], concat_latent_image.shape[-1]), device=start_image.device, dtype=start_image.dtype) + mask[:, :, :((start_image.shape[0] - 1) // 4) + 1] = 0.0 + + positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent_image, "concat_mask": mask}) + negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent_image, "concat_mask": mask}) + + if clip_vision_output is not None: + positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output}) + negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output}) + + out_latent = {} + out_latent["samples"] = latent + return io.NodeOutput(positive, negative, out_latent) + + +class HunyuanVideo15SuperResolution(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="HunyuanVideo15SuperResolution", + inputs=[ + io.Conditioning.Input("positive"), + io.Conditioning.Input("negative"), + io.Vae.Input("vae", optional=True), + io.Image.Input("start_image", optional=True), + io.ClipVisionOutput.Input("clip_vision_output", optional=True), + io.Latent.Input("latent"), + io.Float.Input("noise_augmentation", default=0.70, min=0.0, max=1.0, step=0.01), + + ], + outputs=[ + io.Conditioning.Output(display_name="positive"), + io.Conditioning.Output(display_name="negative"), + io.Latent.Output(display_name="latent"), + ], + ) + + @classmethod + def execute(cls, positive, negative, latent, noise_augmentation, vae=None, start_image=None, clip_vision_output=None) -> io.NodeOutput: + in_latent = latent["samples"] + in_channels = in_latent.shape[1] + cond_latent = torch.zeros([in_latent.shape[0], in_channels * 2 + 2, in_latent.shape[-3], in_latent.shape[-2], in_latent.shape[-1]], device=comfy.model_management.intermediate_device()) + cond_latent[:, in_channels + 1 : 2 * in_channels + 1] = in_latent + cond_latent[:, 2 * in_channels + 1] = 1 + if start_image is not None: + start_image = comfy.utils.common_upscale(start_image.movedim(-1, 1), in_latent.shape[-1] * 16, in_latent.shape[-2] * 16, "bilinear", "center").movedim(1, -1) + encoded = vae.encode(start_image[:, :, :, :3]) + cond_latent[:, :in_channels, :encoded.shape[2], :, :] = encoded + cond_latent[:, in_channels + 1, 0] = 1 + + positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": cond_latent, "noise_augmentation": noise_augmentation}) + negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": cond_latent, "noise_augmentation": noise_augmentation}) + if clip_vision_output is not None: + positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output}) + negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output}) + + return io.NodeOutput(positive, negative, latent) + + +class LatentUpscaleModelLoader(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="LatentUpscaleModelLoader", + display_name="Load Latent Upscale Model", + category="loaders", + inputs=[ + io.Combo.Input("model_name", options=folder_paths.get_filename_list("latent_upscale_models")), + ], + outputs=[ + io.LatentUpscaleModel.Output(), + ], + ) + + @classmethod + def execute(cls, model_name) -> io.NodeOutput: + model_path = folder_paths.get_full_path_or_raise("upscale_models", model_name) + sd = comfy.utils.load_torch_file(model_path, safe_load=True) + + if "blocks.0.block.0.conv.weight" in sd: + config = { + "in_channels": sd["in_conv.conv.weight"].shape[1], + "out_channels": sd["out_conv.conv.weight"].shape[0], + "hidden_channels": sd["in_conv.conv.weight"].shape[0], + "num_blocks": len([k for k in sd.keys() if k.startswith("blocks.") and k.endswith(".block.0.conv.weight")]), + "global_residual": False, + } + model_type = "720p" + elif "up.0.block.0.conv1.conv.weight" in sd: + sd = {key.replace("nin_shortcut", "nin_shortcut.conv", 1): value for key, value in sd.items()} + config = { + "z_channels": sd["conv_in.conv.weight"].shape[1], + "out_channels": sd["conv_out.conv.weight"].shape[0], + "block_out_channels": tuple(sd[f"up.{i}.block.0.conv1.conv.weight"].shape[0] for i in range(len([k for k in sd.keys() if k.startswith("up.") and k.endswith(".block.0.conv1.conv.weight")]))), + } + model_type = "1080p" + + model = HunyuanVideo15SRModel(model_type, config) + model.load_sd(sd) + + return io.NodeOutput(model) + + +class HunyuanVideo15LatentUpscaleWithModel(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="HunyuanVideo15LatentUpscaleWithModel", + display_name="Hunyuan Video 15 Latent Upscale With Model", + category="latent", + inputs=[ + io.LatentUpscaleModel.Input("model"), + io.Latent.Input("samples"), + io.Combo.Input("upscale_method", options=["nearest-exact", "bilinear", "area", "bicubic", "bislerp"], default="bilinear"), + io.Int.Input("width", default=1280, min=0, max=16384, step=8), + io.Int.Input("height", default=720, min=0, max=16384, step=8), + io.Combo.Input("crop", options=["disabled", "center"]), + ], + outputs=[ + io.Latent.Output(), + ], + ) + + @classmethod + def execute(cls, model, samples, upscale_method, width, height, crop) -> io.NodeOutput: + if width == 0 and height == 0: + return io.NodeOutput(samples) + else: + if width == 0: + height = max(64, height) + width = max(64, round(samples["samples"].shape[-1] * height / samples["samples"].shape[-2])) + elif height == 0: + width = max(64, width) + height = max(64, round(samples["samples"].shape[-2] * width / samples["samples"].shape[-1])) + else: + width = max(64, width) + height = max(64, height) + s = comfy.utils.common_upscale(samples["samples"], width // 16, height // 16, upscale_method, crop) + s = model.resample_latent(s) + return io.NodeOutput({"samples": s.cpu().float()}) + + PROMPT_TEMPLATE_ENCODE_VIDEO_I2V = ( "<|start_header_id|>system<|end_header_id|>\n\n\nDescribe the video by detailing the following aspects according to the reference image: " "1. The main content and theme of the video." @@ -210,6 +404,11 @@ class HunyuanExtension(ComfyExtension): CLIPTextEncodeHunyuanDiT, TextEncodeHunyuanVideo_ImageToVideo, EmptyHunyuanLatentVideo, + EmptyHunyuanVideo15Latent, + HunyuanVideo15ImageToVideo, + HunyuanVideo15SuperResolution, + HunyuanVideo15LatentUpscaleWithModel, + LatentUpscaleModelLoader, HunyuanImageToVideo, EmptyHunyuanImageLatent, HunyuanRefinerLatent, diff --git a/folder_paths.py b/folder_paths.py index f110d832b..ffdc4d020 100644 --- a/folder_paths.py +++ b/folder_paths.py @@ -38,6 +38,8 @@ folder_names_and_paths["gligen"] = ([os.path.join(models_dir, "gligen")], suppor folder_names_and_paths["upscale_models"] = ([os.path.join(models_dir, "upscale_models")], supported_pt_extensions) +folder_names_and_paths["latent_upscale_models"] = ([os.path.join(models_dir, "latent_upscale_models")], supported_pt_extensions) + folder_names_and_paths["custom_nodes"] = ([os.path.join(base_path, "custom_nodes")], set()) folder_names_and_paths["hypernetworks"] = ([os.path.join(models_dir, "hypernetworks")], supported_pt_extensions) diff --git a/models/latent_upscale_models/put_latent_upscale_models_here b/models/latent_upscale_models/put_latent_upscale_models_here new file mode 100644 index 000000000..e69de29bb diff --git a/nodes.py b/nodes.py index 030371633..f023ae3b6 100644 --- a/nodes.py +++ b/nodes.py @@ -957,7 +957,7 @@ class DualCLIPLoader: def INPUT_TYPES(s): return {"required": { "clip_name1": (folder_paths.get_filename_list("text_encoders"), ), "clip_name2": (folder_paths.get_filename_list("text_encoders"), ), - "type": (["sdxl", "sd3", "flux", "hunyuan_video", "hidream", "hunyuan_image"], ), + "type": (["sdxl", "sd3", "flux", "hunyuan_video", "hidream", "hunyuan_image", "hunyuan_video_15"], ), }, "optional": { "device": (["default", "cpu"], {"advanced": True}), From 33981237527a3d84d4e9c3b113f75d6dd37af6a4 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Thu, 20 Nov 2025 20:39:37 -0800 Subject: [PATCH 119/231] Fix wrong path. (#10821) --- comfy_extras/nodes_hunyuan.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy_extras/nodes_hunyuan.py b/comfy_extras/nodes_hunyuan.py index 5a2e8cc61..aa36a471f 100644 --- a/comfy_extras/nodes_hunyuan.py +++ b/comfy_extras/nodes_hunyuan.py @@ -185,7 +185,7 @@ class LatentUpscaleModelLoader(io.ComfyNode): @classmethod def execute(cls, model_name) -> io.NodeOutput: - model_path = folder_paths.get_full_path_or_raise("upscale_models", model_name) + model_path = folder_paths.get_full_path_or_raise("latent_upscale_models", model_name) sd = comfy.utils.load_torch_file(model_path, safe_load=True) if "blocks.0.block.0.conv.weight" in sd: From c55fd7481626d8bee8044ea7512ea996d13a1b90 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 21 Nov 2025 00:49:13 -0500 Subject: [PATCH 120/231] ComfyUI 0.3.71 --- comfyui_version.py | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/comfyui_version.py b/comfyui_version.py index 9b77aabe9..b4655d553 100644 --- a/comfyui_version.py +++ b/comfyui_version.py @@ -1,3 +1,3 @@ # This file is automatically generated by the build process when version is # updated in pyproject.toml. -__version__ = "0.3.70" +__version__ = "0.3.71" diff --git a/pyproject.toml b/pyproject.toml index 289b7145b..280dbaf53 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ComfyUI" -version = "0.3.70" +version = "0.3.71" readme = "README.md" license = { file = "LICENSE" } requires-python = ">=3.9" From ecb683b057a19f1a05d18d6d0b0ee9a6c6c8f4a0 Mon Sep 17 00:00:00 2001 From: Christian Byrne Date: Fri, 21 Nov 2025 13:34:47 -0800 Subject: [PATCH 121/231] update frontend to 1.30 (#10793) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 624aa7362..f83d561c9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -comfyui-frontend-package==1.28.9 +comfyui-frontend-package==1.30.6 comfyui-workflow-templates==0.6.0 comfyui-embedded-docs==0.3.1 torch From 532938b16b544e4492ba0ffbe18b201b1a7bc55f Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Fri, 21 Nov 2025 14:51:55 -0800 Subject: [PATCH 122/231] --disable-api-nodes now sets CSP header to force frontend offline. (#10829) --- comfy/cli_args.py | 2 +- server.py | 19 +++++++++++++++++++ 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/comfy/cli_args.py b/comfy/cli_args.py index 2f30b72d2..d2b60e347 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -160,7 +160,7 @@ parser.add_argument("--windows-standalone-build", action="store_true", help="Win parser.add_argument("--disable-metadata", action="store_true", help="Disable saving prompt metadata in files.") parser.add_argument("--disable-all-custom-nodes", action="store_true", help="Disable loading all custom nodes.") parser.add_argument("--whitelist-custom-nodes", type=str, nargs='+', default=[], help="Specify custom node folders to load even when --disable-all-custom-nodes is enabled.") -parser.add_argument("--disable-api-nodes", action="store_true", help="Disable loading all api nodes.") +parser.add_argument("--disable-api-nodes", action="store_true", help="Disable loading all api nodes. Also prevents the frontend from communicating with the internet.") parser.add_argument("--multi-user", action="store_true", help="Enables per-user storage.") diff --git a/server.py b/server.py index d9d5c491f..0fd2e49e3 100644 --- a/server.py +++ b/server.py @@ -164,6 +164,22 @@ def create_origin_only_middleware(): return origin_only_middleware + +def create_block_external_middleware(): + @web.middleware + async def block_external_middleware(request: web.Request, handler): + if request.method == "OPTIONS": + # Pre-flight request. Reply successfully: + response = web.Response() + else: + response = await handler(request) + + response.headers['Content-Security-Policy'] = "default-src 'self'; script-src 'self' 'unsafe-inline' blob:; style-src 'self' 'unsafe-inline'; img-src 'self' data: blob:; font-src 'self'; connect-src 'self'; frame-src 'self'; object-src 'self';" + return response + + return block_external_middleware + + class PromptServer(): def __init__(self, loop): PromptServer.instance = self @@ -193,6 +209,9 @@ class PromptServer(): else: middlewares.append(create_origin_only_middleware()) + if args.disable_api_nodes: + middlewares.append(create_block_external_middleware()) + max_upload_size = round(args.max_upload_size * 1024 * 1024) self.app = web.Application(client_max_size=max_upload_size, middlewares=middlewares) self.sockets = dict() From a9c35256bccd4018fbe74bf1e857cc18bd1900ed Mon Sep 17 00:00:00 2001 From: Christian Byrne Date: Sat, 22 Nov 2025 02:28:29 -0800 Subject: [PATCH 123/231] Update requirements.txt (#10834) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index f83d561c9..8e308cd6c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.30.6 -comfyui-workflow-templates==0.6.0 +comfyui-workflow-templates==0.7.9 comfyui-embedded-docs==0.3.1 torch torchsde From d89c29f25992713ec3102017c189858a457f1215 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sat, 22 Nov 2025 19:51:53 -0800 Subject: [PATCH 124/231] Add display names to Hunyuan latent video nodes. (#10837) --- comfy_extras/nodes_hunyuan.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/comfy_extras/nodes_hunyuan.py b/comfy_extras/nodes_hunyuan.py index aa36a471f..32be182f1 100644 --- a/comfy_extras/nodes_hunyuan.py +++ b/comfy_extras/nodes_hunyuan.py @@ -38,6 +38,7 @@ class EmptyHunyuanLatentVideo(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="EmptyHunyuanLatentVideo", + display_name="Empty HunyuanVideo 1.0 Latent", category="latent/video", inputs=[ io.Int.Input("width", default=848, min=16, max=nodes.MAX_RESOLUTION, step=16), @@ -63,6 +64,7 @@ class EmptyHunyuanVideo15Latent(EmptyHunyuanLatentVideo): def define_schema(cls): schema = super().define_schema() schema.node_id = "EmptyHunyuanVideo15Latent" + schema.display_name = "Empty HunyuanVideo 1.5 Latent" return schema @classmethod @@ -71,8 +73,6 @@ class EmptyHunyuanVideo15Latent(EmptyHunyuanLatentVideo): latent = torch.zeros([batch_size, 32, ((length - 1) // 4) + 1, height // 16, width // 16], device=comfy.model_management.intermediate_device()) return io.NodeOutput({"samples": latent}) - generate = execute # TODO: remove - class HunyuanVideo15ImageToVideo(io.ComfyNode): @classmethod From cbd68e3d587a1b345bdc6ebcd8a8c6ba1a9d3af3 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sun, 23 Nov 2025 01:55:22 -0800 Subject: [PATCH 125/231] Add better error message for common error. (#10846) --- comfy/cldm/cldm.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/comfy/cldm/cldm.py b/comfy/cldm/cldm.py index ec01665e2..c93c2e909 100644 --- a/comfy/cldm/cldm.py +++ b/comfy/cldm/cldm.py @@ -413,7 +413,8 @@ class ControlNet(nn.Module): out_middle = [] if self.num_classes is not None: - assert y.shape[0] == x.shape[0] + if y is None: + raise ValueError("y is None, did you try using a controlnet for SDXL on SD1?") emb = emb + self.label_emb(y) h = x From f66183a54142be693ab858e9f1f06ed62439a92e Mon Sep 17 00:00:00 2001 From: guill Date: Sun, 23 Nov 2025 22:56:20 -0800 Subject: [PATCH 126/231] [fix] Fixes non-async public API access (#10857) It looks like the synchronous version of the public API broke due to an addition of `from __future__ import annotations`. This change updates the async-to-sync adapter to work with both types of type annotations. --- comfy_api/internal/async_to_sync.py | 47 ++++++--- tests/execution/test_public_api.py | 153 ++++++++++++++++++++++++++++ 2 files changed, 184 insertions(+), 16 deletions(-) create mode 100644 tests/execution/test_public_api.py diff --git a/comfy_api/internal/async_to_sync.py b/comfy_api/internal/async_to_sync.py index f5f805a62..257ade82e 100644 --- a/comfy_api/internal/async_to_sync.py +++ b/comfy_api/internal/async_to_sync.py @@ -8,7 +8,7 @@ import os import textwrap import threading from enum import Enum -from typing import Optional, Type, get_origin, get_args +from typing import Optional, Type, get_origin, get_args, get_type_hints class TypeTracker: @@ -220,11 +220,18 @@ class AsyncToSyncConverter: self._async_instance = async_class(*args, **kwargs) # Handle annotated class attributes (like execution: Execution) - # Get all annotations from the class hierarchy - all_annotations = {} - for base_class in reversed(inspect.getmro(async_class)): - if hasattr(base_class, "__annotations__"): - all_annotations.update(base_class.__annotations__) + # Get all annotations from the class hierarchy and resolve string annotations + try: + # get_type_hints resolves string annotations to actual type objects + # This handles classes using 'from __future__ import annotations' + all_annotations = get_type_hints(async_class) + except Exception: + # Fallback to raw annotations if get_type_hints fails + # (e.g., for undefined forward references) + all_annotations = {} + for base_class in reversed(inspect.getmro(async_class)): + if hasattr(base_class, "__annotations__"): + all_annotations.update(base_class.__annotations__) # For each annotated attribute, check if it needs to be created or wrapped for attr_name, attr_type in all_annotations.items(): @@ -625,15 +632,19 @@ class AsyncToSyncConverter: """Extract class attributes that are classes themselves.""" class_attributes = [] + # Get resolved type hints to handle string annotations + try: + type_hints = get_type_hints(async_class) + except Exception: + type_hints = {} + # Look for class attributes that are classes for name, attr in sorted(inspect.getmembers(async_class)): if isinstance(attr, type) and not name.startswith("_"): class_attributes.append((name, attr)) - elif ( - hasattr(async_class, "__annotations__") - and name in async_class.__annotations__ - ): - annotation = async_class.__annotations__[name] + elif name in type_hints: + # Use resolved type hint instead of raw annotation + annotation = type_hints[name] if isinstance(annotation, type): class_attributes.append((name, annotation)) @@ -908,11 +919,15 @@ class AsyncToSyncConverter: attribute_mappings = {} # First check annotations for typed attributes (including from parent classes) - # Collect all annotations from the class hierarchy - all_annotations = {} - for base_class in reversed(inspect.getmro(async_class)): - if hasattr(base_class, "__annotations__"): - all_annotations.update(base_class.__annotations__) + # Resolve string annotations to actual types + try: + all_annotations = get_type_hints(async_class) + except Exception: + # Fallback to raw annotations + all_annotations = {} + for base_class in reversed(inspect.getmro(async_class)): + if hasattr(base_class, "__annotations__"): + all_annotations.update(base_class.__annotations__) for attr_name, attr_type in sorted(all_annotations.items()): for class_name, class_type in class_attributes: diff --git a/tests/execution/test_public_api.py b/tests/execution/test_public_api.py new file mode 100644 index 000000000..52bc2fcd8 --- /dev/null +++ b/tests/execution/test_public_api.py @@ -0,0 +1,153 @@ +""" +Tests for public ComfyAPI and ComfyAPISync functions. + +These tests verify that the public API methods work correctly in both sync and async contexts, +ensuring that the sync wrapper generation (via get_type_hints() in async_to_sync.py) correctly +handles string annotations from 'from __future__ import annotations'. +""" + +import pytest +import time +import subprocess +import torch +from pytest import fixture +from comfy_execution.graph_utils import GraphBuilder +from tests.execution.test_execution import ComfyClient + + +@pytest.mark.execution +class TestPublicAPI: + """Test suite for public ComfyAPI and ComfyAPISync methods.""" + + @fixture(scope="class", autouse=True) + def _server(self, args_pytest): + """Start ComfyUI server for testing.""" + pargs = [ + 'python', 'main.py', + '--output-directory', args_pytest["output_dir"], + '--listen', args_pytest["listen"], + '--port', str(args_pytest["port"]), + '--extra-model-paths-config', 'tests/execution/extra_model_paths.yaml', + '--cpu', + ] + p = subprocess.Popen(pargs) + yield + p.kill() + torch.cuda.empty_cache() + + @fixture(scope="class", autouse=True) + def shared_client(self, args_pytest, _server): + """Create shared client with connection retry.""" + client = ComfyClient() + n_tries = 5 + for i in range(n_tries): + time.sleep(4) + try: + client.connect(listen=args_pytest["listen"], port=args_pytest["port"]) + break + except ConnectionRefusedError: + if i == n_tries - 1: + raise + yield client + del client + torch.cuda.empty_cache() + + @fixture + def client(self, shared_client, request): + """Set test name for each test.""" + shared_client.set_test_name(f"public_api[{request.node.name}]") + yield shared_client + + @fixture + def builder(self, request): + """Create GraphBuilder for each test.""" + yield GraphBuilder(prefix=request.node.name) + + def test_sync_progress_update_executes(self, client: ComfyClient, builder: GraphBuilder): + """Test that TestSyncProgressUpdate executes without errors. + + This test validates that api_sync.execution.set_progress() works correctly, + which is the primary code path fixed by adding get_type_hints() to async_to_sync.py. + """ + g = builder + image = g.node("StubImage", content="BLACK", height=256, width=256, batch_size=1) + + # Use TestSyncProgressUpdate with short sleep + progress_node = g.node("TestSyncProgressUpdate", + value=image.out(0), + sleep_seconds=0.5) + output = g.node("SaveImage", images=progress_node.out(0)) + + # Execute workflow + result = client.run(g) + + # Verify execution + assert result.did_run(progress_node), "Progress node should have executed" + assert result.did_run(output), "Output node should have executed" + + # Verify output + images = result.get_images(output) + assert len(images) == 1, "Should have produced 1 image" + + def test_async_progress_update_executes(self, client: ComfyClient, builder: GraphBuilder): + """Test that TestAsyncProgressUpdate executes without errors. + + This test validates that await api.execution.set_progress() works correctly + in async contexts. + """ + g = builder + image = g.node("StubImage", content="WHITE", height=256, width=256, batch_size=1) + + # Use TestAsyncProgressUpdate with short sleep + progress_node = g.node("TestAsyncProgressUpdate", + value=image.out(0), + sleep_seconds=0.5) + output = g.node("SaveImage", images=progress_node.out(0)) + + # Execute workflow + result = client.run(g) + + # Verify execution + assert result.did_run(progress_node), "Async progress node should have executed" + assert result.did_run(output), "Output node should have executed" + + # Verify output + images = result.get_images(output) + assert len(images) == 1, "Should have produced 1 image" + + def test_sync_and_async_progress_together(self, client: ComfyClient, builder: GraphBuilder): + """Test both sync and async progress updates in same workflow. + + This test ensures that both ComfyAPISync and ComfyAPI can coexist and work + correctly in the same workflow execution. + """ + g = builder + image1 = g.node("StubImage", content="BLACK", height=256, width=256, batch_size=1) + image2 = g.node("StubImage", content="WHITE", height=256, width=256, batch_size=1) + + # Use both types of progress nodes + sync_progress = g.node("TestSyncProgressUpdate", + value=image1.out(0), + sleep_seconds=0.3) + async_progress = g.node("TestAsyncProgressUpdate", + value=image2.out(0), + sleep_seconds=0.3) + + # Create outputs + output1 = g.node("SaveImage", images=sync_progress.out(0)) + output2 = g.node("SaveImage", images=async_progress.out(0)) + + # Execute workflow + result = client.run(g) + + # Both should execute successfully + assert result.did_run(sync_progress), "Sync progress node should have executed" + assert result.did_run(async_progress), "Async progress node should have executed" + assert result.did_run(output1), "First output node should have executed" + assert result.did_run(output2), "Second output node should have executed" + + # Verify outputs + images1 = result.get_images(output1) + images2 = result.get_images(output2) + assert len(images1) == 1, "Should have produced 1 image from sync node" + assert len(images2) == 1, "Should have produced 1 image from async node" From 3bd71554a2df14b862cc5e1e875df37ba24af1ac Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Mon, 24 Nov 2025 19:48:37 +0200 Subject: [PATCH 127/231] fix(api-nodes): edge cases in responses for Gemini models (#10860) --- comfy_api_nodes/apis/gemini_api.py | 6 +++--- comfy_api_nodes/nodes_gemini.py | 21 +++++++++++---------- 2 files changed, 14 insertions(+), 13 deletions(-) diff --git a/comfy_api_nodes/apis/gemini_api.py b/comfy_api_nodes/apis/gemini_api.py index 710f173f1..d34590d28 100644 --- a/comfy_api_nodes/apis/gemini_api.py +++ b/comfy_api_nodes/apis/gemini_api.py @@ -113,9 +113,9 @@ class GeminiGenerationConfig(BaseModel): maxOutputTokens: int | None = Field(None, ge=16, le=8192) seed: int | None = Field(None) stopSequences: list[str] | None = Field(None) - temperature: float | None = Field(1, ge=0.0, le=2.0) - topK: int | None = Field(40, ge=1) - topP: float | None = Field(0.95, ge=0.0, le=1.0) + temperature: float | None = Field(None, ge=0.0, le=2.0) + topK: int | None = Field(None, ge=1) + topP: float | None = Field(None, ge=0.0, le=1.0) class GeminiImageConfig(BaseModel): diff --git a/comfy_api_nodes/nodes_gemini.py b/comfy_api_nodes/nodes_gemini.py index be752c885..938a20f84 100644 --- a/comfy_api_nodes/nodes_gemini.py +++ b/comfy_api_nodes/nodes_gemini.py @@ -104,14 +104,14 @@ def get_parts_by_type(response: GeminiGenerateContentResponse, part_type: Litera List of response parts matching the requested type. """ if response.candidates is None: - if response.promptFeedback.blockReason: + if response.promptFeedback and response.promptFeedback.blockReason: feedback = response.promptFeedback raise ValueError( f"Gemini API blocked the request. Reason: {feedback.blockReason} ({feedback.blockReasonMessage})" ) - raise NotImplementedError( - "Gemini returned no response candidates. " - "Please report to ComfyUI repository with the example of workflow to reproduce this." + raise ValueError( + "Gemini API returned no response candidates. If you are using the `IMAGE` modality, " + "try changing it to `IMAGE+TEXT` to view the model's reasoning and understand why image generation failed." ) parts = [] for part in response.candidates[0].content.parts: @@ -182,11 +182,12 @@ def calculate_tokens_price(response: GeminiGenerateContentResponse) -> float | N else: return None final_price = response.usageMetadata.promptTokenCount * input_tokens_price - for i in response.usageMetadata.candidatesTokensDetails: - if i.modality == Modality.IMAGE: - final_price += output_image_tokens_price * i.tokenCount # for Nano Banana models - else: - final_price += output_text_tokens_price * i.tokenCount + if response.usageMetadata.candidatesTokensDetails: + for i in response.usageMetadata.candidatesTokensDetails: + if i.modality == Modality.IMAGE: + final_price += output_image_tokens_price * i.tokenCount # for Nano Banana models + else: + final_price += output_text_tokens_price * i.tokenCount if response.usageMetadata.thoughtsTokenCount: final_price += output_text_tokens_price * response.usageMetadata.thoughtsTokenCount return final_price / 1_000_000.0 @@ -645,7 +646,7 @@ class GeminiImage2(IO.ComfyNode): options=["auto", "1:1", "2:3", "3:2", "3:4", "4:3", "4:5", "5:4", "9:16", "16:9", "21:9"], default="auto", tooltip="If set to 'auto', matches your input image's aspect ratio; " - "if no image is provided, generates a 1:1 square.", + "if no image is provided, a 16:9 square is usually generated.", ), IO.Combo.Input( "resolution", From 1286fcfe40b98052e4edbe9a02f12ad89ac74924 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Mon, 24 Nov 2025 20:24:29 +0200 Subject: [PATCH 128/231] add get_frame_count and get_frame_rate methods to VideoInput class (#10851) --- comfy_api/latest/_input/video_types.py | 28 ++++++++ comfy_api/latest/_input_impl/video_types.py | 72 +++++++++++++++++++++ comfy_api_nodes/nodes_topaz.py | 15 ++--- 3 files changed, 106 insertions(+), 9 deletions(-) diff --git a/comfy_api/latest/_input/video_types.py b/comfy_api/latest/_input/video_types.py index a335df4d0..87c81d73a 100644 --- a/comfy_api/latest/_input/video_types.py +++ b/comfy_api/latest/_input/video_types.py @@ -1,5 +1,6 @@ from __future__ import annotations from abc import ABC, abstractmethod +from fractions import Fraction from typing import Optional, Union, IO import io import av @@ -72,6 +73,33 @@ class VideoInput(ABC): frame_count = components.images.shape[0] return float(frame_count / components.frame_rate) + def get_frame_count(self) -> int: + """ + Returns the number of frames in the video. + + Default implementation uses :meth:`get_components`, which may require + loading all frames into memory. File-based implementations should + override this method and use container/stream metadata instead. + + Returns: + Total number of frames as an integer. + """ + return int(self.get_components().images.shape[0]) + + def get_frame_rate(self) -> Fraction: + """ + Returns the frame rate of the video. + + Default implementation materializes the video into memory via + `get_components()`. Subclasses that can inspect the underlying + container (e.g. `VideoFromFile`) should override this with a more + efficient implementation. + + Returns: + Frame rate as a Fraction. + """ + return self.get_components().frame_rate + def get_container_format(self) -> str: """ Returns the container format of the video (e.g., 'mp4', 'mov', 'avi'). diff --git a/comfy_api/latest/_input_impl/video_types.py b/comfy_api/latest/_input_impl/video_types.py index f646504c8..bde37f90a 100644 --- a/comfy_api/latest/_input_impl/video_types.py +++ b/comfy_api/latest/_input_impl/video_types.py @@ -121,6 +121,71 @@ class VideoFromFile(VideoInput): raise ValueError(f"Could not determine duration for file '{self.__file}'") + def get_frame_count(self) -> int: + """ + Returns the number of frames in the video without materializing them as + torch tensors. + """ + if isinstance(self.__file, io.BytesIO): + self.__file.seek(0) + + with av.open(self.__file, mode="r") as container: + video_stream = self._get_first_video_stream(container) + # 1. Prefer the frames field if available + if video_stream.frames and video_stream.frames > 0: + return int(video_stream.frames) + + # 2. Try to estimate from duration and average_rate using only metadata + if container.duration is not None and video_stream.average_rate: + duration_seconds = float(container.duration / av.time_base) + estimated_frames = int(round(duration_seconds * float(video_stream.average_rate))) + if estimated_frames > 0: + return estimated_frames + + if ( + getattr(video_stream, "duration", None) is not None + and getattr(video_stream, "time_base", None) is not None + and video_stream.average_rate + ): + duration_seconds = float(video_stream.duration * video_stream.time_base) + estimated_frames = int(round(duration_seconds * float(video_stream.average_rate))) + if estimated_frames > 0: + return estimated_frames + + # 3. Last resort: decode frames and count them (streaming) + frame_count = 0 + container.seek(0) + for packet in container.demux(video_stream): + for _ in packet.decode(): + frame_count += 1 + + if frame_count == 0: + raise ValueError(f"Could not determine frame count for file '{self.__file}'") + return frame_count + + def get_frame_rate(self) -> Fraction: + """ + Returns the average frame rate of the video using container metadata + without decoding all frames. + """ + if isinstance(self.__file, io.BytesIO): + self.__file.seek(0) + + with av.open(self.__file, mode="r") as container: + video_stream = self._get_first_video_stream(container) + # Preferred: use PyAV's average_rate (usually already a Fraction-like) + if video_stream.average_rate: + return Fraction(video_stream.average_rate) + + # Fallback: estimate from frames + duration if available + if video_stream.frames and container.duration: + duration_seconds = float(container.duration / av.time_base) + if duration_seconds > 0: + return Fraction(video_stream.frames / duration_seconds).limit_denominator() + + # Last resort: match get_components_internal default + return Fraction(1) + def get_container_format(self) -> str: """ Returns the container format of the video (e.g., 'mp4', 'mov', 'avi'). @@ -238,6 +303,13 @@ class VideoFromFile(VideoInput): packet.stream = stream_map[packet.stream] output_container.mux(packet) + def _get_first_video_stream(self, container: InputContainer): + video_stream = next((s for s in container.streams if s.type == "video"), None) + if video_stream is None: + raise ValueError(f"No video stream found in file '{self.__file}'") + return video_stream + + class VideoFromComponents(VideoInput): """ Class representing video input from tensors. diff --git a/comfy_api_nodes/nodes_topaz.py b/comfy_api_nodes/nodes_topaz.py index 79c7bf43d..f522756e5 100644 --- a/comfy_api_nodes/nodes_topaz.py +++ b/comfy_api_nodes/nodes_topaz.py @@ -5,8 +5,7 @@ import aiohttp import torch from typing_extensions import override -from comfy_api.input.video_types import VideoInput -from comfy_api.latest import IO, ComfyExtension +from comfy_api.latest import IO, ComfyExtension, Input from comfy_api_nodes.apis import topaz_api from comfy_api_nodes.util import ( ApiEndpoint, @@ -282,7 +281,7 @@ class TopazVideoEnhance(IO.ComfyNode): @classmethod async def execute( cls, - video: VideoInput, + video: Input.Video, upscaler_enabled: bool, upscaler_model: str, upscaler_resolution: str, @@ -297,12 +296,10 @@ class TopazVideoEnhance(IO.ComfyNode): ) -> IO.NodeOutput: if upscaler_enabled is False and interpolation_enabled is False: raise ValueError("There is nothing to do: both upscaling and interpolation are disabled.") - src_width, src_height = video.get_dimensions() - video_components = video.get_components() - src_frame_rate = int(video_components.frame_rate) - duration_sec = video.get_duration() - estimated_frames = int(duration_sec * src_frame_rate) validate_container_format_is_mp4(video) + src_width, src_height = video.get_dimensions() + src_frame_rate = int(video.get_frame_rate()) + duration_sec = video.get_duration() src_video_stream = video.get_stream_source() target_width = src_width target_height = src_height @@ -338,7 +335,7 @@ class TopazVideoEnhance(IO.ComfyNode): container="mp4", size=get_fs_object_size(src_video_stream), duration=int(duration_sec), - frameCount=estimated_frames, + frameCount=video.get_frame_count(), frameRate=src_frame_rate, resolution=topaz_api.Resolution(width=src_width, height=src_height), ), From 3d1fdaf9f448b34e4eba68bfd8e8de373ec0d22d Mon Sep 17 00:00:00 2001 From: Haoming <73768377+Haoming02@users.noreply.github.com> Date: Tue, 25 Nov 2025 02:30:40 +0800 Subject: [PATCH 129/231] block info (#10843) --- comfy/ldm/chroma/model.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/comfy/ldm/chroma/model.py b/comfy/ldm/chroma/model.py index 67bf70eb1..a72f8cc47 100644 --- a/comfy/ldm/chroma/model.py +++ b/comfy/ldm/chroma/model.py @@ -179,7 +179,10 @@ class Chroma(nn.Module): pe = self.pe_embedder(ids) blocks_replace = patches_replace.get("dit", {}) + transformer_options["total_blocks"] = len(self.double_blocks) + transformer_options["block_type"] = "double" for i, block in enumerate(self.double_blocks): + transformer_options["block_index"] = i if i not in self.skip_mmdit: double_mod = ( self.get_modulations(mod_vectors, "double_img", idx=i), @@ -222,7 +225,10 @@ class Chroma(nn.Module): img = torch.cat((txt, img), 1) + transformer_options["total_blocks"] = len(self.single_blocks) + transformer_options["block_type"] = "single" for i, block in enumerate(self.single_blocks): + transformer_options["block_index"] = i if i not in self.skip_dit: single_mod = self.get_modulations(mod_vectors, "single", idx=i) if ("single_block", i) in blocks_replace: From 6a6d456c88723538e3d0e5e942f78109ece5b73d Mon Sep 17 00:00:00 2001 From: Haoming <73768377+Haoming02@users.noreply.github.com> Date: Tue, 25 Nov 2025 02:38:38 +0800 Subject: [PATCH 130/231] block info (#10842) --- comfy/ldm/qwen_image/model.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/comfy/ldm/qwen_image/model.py b/comfy/ldm/qwen_image/model.py index 427ea19c1..8c75670cd 100644 --- a/comfy/ldm/qwen_image/model.py +++ b/comfy/ldm/qwen_image/model.py @@ -439,7 +439,10 @@ class QwenImageTransformer2DModel(nn.Module): patches = transformer_options.get("patches", {}) blocks_replace = patches_replace.get("dit", {}) + transformer_options["total_blocks"] = len(self.transformer_blocks) + transformer_options["block_type"] = "double" for i, block in enumerate(self.transformer_blocks): + transformer_options["block_index"] = i if ("double_block", i) in blocks_replace: def block_wrap(args): out = {} From b2ef58e2b17e73ca8cd376a1cdc976518ebbc168 Mon Sep 17 00:00:00 2001 From: Haoming <73768377+Haoming02@users.noreply.github.com> Date: Tue, 25 Nov 2025 02:40:09 +0800 Subject: [PATCH 131/231] block info (#10844) --- comfy/ldm/hunyuan_video/model.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/comfy/ldm/hunyuan_video/model.py b/comfy/ldm/hunyuan_video/model.py index f75c6e0e1..2749c53f5 100644 --- a/comfy/ldm/hunyuan_video/model.py +++ b/comfy/ldm/hunyuan_video/model.py @@ -389,7 +389,10 @@ class HunyuanVideo(nn.Module): attn_mask = None blocks_replace = patches_replace.get("dit", {}) + transformer_options["total_blocks"] = len(self.double_blocks) + transformer_options["block_type"] = "double" for i, block in enumerate(self.double_blocks): + transformer_options["block_index"] = i if ("double_block", i) in blocks_replace: def block_wrap(args): out = {} @@ -411,7 +414,10 @@ class HunyuanVideo(nn.Module): img = torch.cat((img, txt), 1) + transformer_options["total_blocks"] = len(self.single_blocks) + transformer_options["block_type"] = "single" for i, block in enumerate(self.single_blocks): + transformer_options["block_index"] = i if ("single_block", i) in blocks_replace: def block_wrap(args): out = {} From 22a2644e57530ee40e13486ccd7c953b87072093 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 24 Nov 2025 16:45:54 -0800 Subject: [PATCH 132/231] Bump transformers version in requirements.txt (#10869) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 8e308cd6c..b7014f956 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,7 +7,7 @@ torchvision torchaudio numpy>=1.25.0 einops -transformers>=4.37.2 +transformers>=4.50.3 tokenizers>=0.13.3 sentencepiece safetensors>=0.4.2 From 25022e0b0965975b35bcaf28b153184d60a4f9de Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 24 Nov 2025 22:48:53 -0800 Subject: [PATCH 133/231] Cleanup and fix issues with text encoder quants. (#10872) --- comfy/model_patcher.py | 3 +- comfy/ops.py | 168 +++++++++--------- comfy/quant_ops.py | 12 ++ comfy/sd.py | 9 +- comfy/sd1_clip.py | 18 +- comfy/text_encoders/hunyuan_video.py | 3 + .../comfy_quant/test_mixed_precision.py | 17 +- 7 files changed, 128 insertions(+), 102 deletions(-) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index cf1b0d441..6551ced5a 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -231,7 +231,6 @@ class ModelPatcher: self.object_patches_backup = {} self.weight_wrapper_patches = {} self.model_options = {"transformer_options":{}} - self.model_size() self.load_device = load_device self.offload_device = offload_device self.weight_inplace_update = weight_inplace_update @@ -286,7 +285,7 @@ class ModelPatcher: return self.model.lowvram_patch_counter def clone(self): - n = self.__class__(self.model, self.load_device, self.offload_device, self.size, weight_inplace_update=self.weight_inplace_update) + n = self.__class__(self.model, self.load_device, self.offload_device, self.model_size(), weight_inplace_update=self.weight_inplace_update) n.patches = {} for k in self.patches: n.patches[k] = self.patches[k][:] diff --git a/comfy/ops.py b/comfy/ops.py index 640622fd1..af185ec24 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -540,113 +540,115 @@ if CUBLAS_IS_AVAILABLE: # ============================================================================== from .quant_ops import QuantizedTensor, QUANT_ALGOS -class MixedPrecisionOps(disable_weight_init): - _layer_quant_config = {} - _compute_dtype = torch.bfloat16 - class Linear(torch.nn.Module, CastWeightBiasOp): - def __init__( - self, - in_features: int, - out_features: int, - bias: bool = True, - device=None, - dtype=None, - ) -> None: - super().__init__() +def mixed_precision_ops(layer_quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False): + class MixedPrecisionOps(manual_cast): + _layer_quant_config = layer_quant_config + _compute_dtype = compute_dtype + _full_precision_mm = full_precision_mm - self.factory_kwargs = {"device": device, "dtype": MixedPrecisionOps._compute_dtype} - # self.factory_kwargs = {"device": device, "dtype": dtype} + class Linear(torch.nn.Module, CastWeightBiasOp): + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + device=None, + dtype=None, + ) -> None: + super().__init__() - self.in_features = in_features - self.out_features = out_features - if bias: - self.bias = torch.nn.Parameter(torch.empty(out_features, **self.factory_kwargs)) - else: - self.register_parameter("bias", None) + self.factory_kwargs = {"device": device, "dtype": MixedPrecisionOps._compute_dtype} + # self.factory_kwargs = {"device": device, "dtype": dtype} - self.tensor_class = None + self.in_features = in_features + self.out_features = out_features + if bias: + self.bias = torch.nn.Parameter(torch.empty(out_features, **self.factory_kwargs)) + else: + self.register_parameter("bias", None) - def reset_parameters(self): - return None + self.tensor_class = None + self._full_precision_mm = MixedPrecisionOps._full_precision_mm - def _load_from_state_dict(self, state_dict, prefix, local_metadata, - strict, missing_keys, unexpected_keys, error_msgs): + def reset_parameters(self): + return None - device = self.factory_kwargs["device"] - layer_name = prefix.rstrip('.') - weight_key = f"{prefix}weight" - weight = state_dict.pop(weight_key, None) - if weight is None: - raise ValueError(f"Missing weight for layer {layer_name}") + def _load_from_state_dict(self, state_dict, prefix, local_metadata, + strict, missing_keys, unexpected_keys, error_msgs): - manually_loaded_keys = [weight_key] + device = self.factory_kwargs["device"] + layer_name = prefix.rstrip('.') + weight_key = f"{prefix}weight" + weight = state_dict.pop(weight_key, None) + if weight is None: + raise ValueError(f"Missing weight for layer {layer_name}") - if layer_name not in MixedPrecisionOps._layer_quant_config: - self.weight = torch.nn.Parameter(weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), requires_grad=False) - else: - quant_format = MixedPrecisionOps._layer_quant_config[layer_name].get("format", None) - if quant_format is None: - raise ValueError(f"Unknown quantization format for layer {layer_name}") + manually_loaded_keys = [weight_key] - qconfig = QUANT_ALGOS[quant_format] - self.layout_type = qconfig["comfy_tensor_layout"] + if layer_name not in MixedPrecisionOps._layer_quant_config: + self.weight = torch.nn.Parameter(weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), requires_grad=False) + else: + quant_format = MixedPrecisionOps._layer_quant_config[layer_name].get("format", None) + if quant_format is None: + raise ValueError(f"Unknown quantization format for layer {layer_name}") - weight_scale_key = f"{prefix}weight_scale" - layout_params = { - 'scale': state_dict.pop(weight_scale_key, None), - 'orig_dtype': MixedPrecisionOps._compute_dtype, - 'block_size': qconfig.get("group_size", None), - } - if layout_params['scale'] is not None: - manually_loaded_keys.append(weight_scale_key) + qconfig = QUANT_ALGOS[quant_format] + self.layout_type = qconfig["comfy_tensor_layout"] - self.weight = torch.nn.Parameter( - QuantizedTensor(weight.to(device=device), self.layout_type, layout_params), - requires_grad=False - ) + weight_scale_key = f"{prefix}weight_scale" + layout_params = { + 'scale': state_dict.pop(weight_scale_key, None), + 'orig_dtype': MixedPrecisionOps._compute_dtype, + 'block_size': qconfig.get("group_size", None), + } + if layout_params['scale'] is not None: + manually_loaded_keys.append(weight_scale_key) - for param_name in qconfig["parameters"]: - param_key = f"{prefix}{param_name}" - _v = state_dict.pop(param_key, None) - if _v is None: - continue - setattr(self, param_name, torch.nn.Parameter(_v.to(device=device), requires_grad=False)) - manually_loaded_keys.append(param_key) + self.weight = torch.nn.Parameter( + QuantizedTensor(weight.to(device=device), self.layout_type, layout_params), + requires_grad=False + ) - super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) + for param_name in qconfig["parameters"]: + param_key = f"{prefix}{param_name}" + _v = state_dict.pop(param_key, None) + if _v is None: + continue + setattr(self, param_name, torch.nn.Parameter(_v.to(device=device), requires_grad=False)) + manually_loaded_keys.append(param_key) - for key in manually_loaded_keys: - if key in missing_keys: - missing_keys.remove(key) + super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) - def _forward(self, input, weight, bias): - return torch.nn.functional.linear(input, weight, bias) + for key in manually_loaded_keys: + if key in missing_keys: + missing_keys.remove(key) - def forward_comfy_cast_weights(self, input): - weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True) - x = self._forward(input, weight, bias) - uncast_bias_weight(self, weight, bias, offload_stream) - return x + def _forward(self, input, weight, bias): + return torch.nn.functional.linear(input, weight, bias) - def forward(self, input, *args, **kwargs): - run_every_op() + def forward_comfy_cast_weights(self, input): + weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True) + x = self._forward(input, weight, bias) + uncast_bias_weight(self, weight, bias, offload_stream) + return x - if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0: - return self.forward_comfy_cast_weights(input, *args, **kwargs) - if (getattr(self, 'layout_type', None) is not None and - getattr(self, 'input_scale', None) is not None and - not isinstance(input, QuantizedTensor)): - input = QuantizedTensor.from_float(input, self.layout_type, scale=self.input_scale, dtype=self.weight.dtype) - return self._forward(input, self.weight, self.bias) + def forward(self, input, *args, **kwargs): + run_every_op() + if self._full_precision_mm or self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0: + return self.forward_comfy_cast_weights(input, *args, **kwargs) + if (getattr(self, 'layout_type', None) is not None and + getattr(self, 'input_scale', None) is not None and + not isinstance(input, QuantizedTensor)): + input = QuantizedTensor.from_float(input, self.layout_type, scale=self.input_scale, dtype=self.weight.dtype) + return self._forward(input, self.weight, self.bias) + return MixedPrecisionOps def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None, model_config=None): if model_config and hasattr(model_config, 'layer_quant_config') and model_config.layer_quant_config: - MixedPrecisionOps._layer_quant_config = model_config.layer_quant_config - MixedPrecisionOps._compute_dtype = compute_dtype logging.info(f"Using mixed precision operations: {len(model_config.layer_quant_config)} quantized layers") - return MixedPrecisionOps + return mixed_precision_ops(model_config.layer_quant_config, compute_dtype) fp8_compute = comfy.model_management.supports_fp8_compute(load_device) if scaled_fp8 is not None: diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py index 1d058bece..905b4729e 100644 --- a/comfy/quant_ops.py +++ b/comfy/quant_ops.py @@ -338,6 +338,18 @@ def generic_copy_(func, args, kwargs): return func(*args, **kwargs) +@register_generic_util(torch.ops.aten.to.dtype) +def generic_to_dtype(func, args, kwargs): + """Handle .to(dtype) calls - dtype conversion only.""" + src = args[0] + if isinstance(src, QuantizedTensor): + # For dtype-only conversion, just change the orig_dtype, no real cast is needed + target_dtype = args[1] if len(args) > 1 else kwargs.get('dtype') + src._layout_params["orig_dtype"] = target_dtype + return src + return func(*args, **kwargs) + + @register_generic_util(torch.ops.aten._has_compatible_shallow_copy_type.default) def generic_has_compatible_shallow_copy_type(func, args, kwargs): return True diff --git a/comfy/sd.py b/comfy/sd.py index dc0905ada..b6df0bd61 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -917,7 +917,12 @@ class CLIPType(Enum): def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}): clip_data = [] for p in ckpt_paths: - clip_data.append(comfy.utils.load_torch_file(p, safe_load=True)) + sd, metadata = comfy.utils.load_torch_file(p, safe_load=True, return_metadata=True) + if metadata is not None: + quant_metadata = metadata.get("_quantization_metadata", None) + if quant_metadata is not None: + sd["_quantization_metadata"] = quant_metadata + clip_data.append(sd) return load_text_encoder_state_dicts(clip_data, embedding_directory=embedding_directory, clip_type=clip_type, model_options=model_options) @@ -1142,6 +1147,8 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip parameters = 0 for c in clip_data: + if "_quantization_metadata" in c: + c.pop("_quantization_metadata") parameters += comfy.utils.calculate_parameters(c) tokenizer_data, model_options = comfy.text_encoders.long_clipl.model_options_long_clip(c, tokenizer_data, model_options) diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index 3066de2d7..8f509bab1 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -109,13 +109,23 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): operations = model_options.get("custom_operations", None) scaled_fp8 = None + quantization_metadata = model_options.get("quantization_metadata", None) if operations is None: - scaled_fp8 = model_options.get("scaled_fp8", None) - if scaled_fp8 is not None: - operations = comfy.ops.scaled_fp8_ops(fp8_matrix_mult=False, override_dtype=scaled_fp8) + layer_quant_config = None + if quantization_metadata is not None: + layer_quant_config = json.loads(quantization_metadata).get("layers", None) + + if layer_quant_config is not None: + operations = comfy.ops.mixed_precision_ops(layer_quant_config, dtype, full_precision_mm=True) + logging.info(f"Using MixedPrecisionOps for text encoder: {len(layer_quant_config)} quantized layers") else: - operations = comfy.ops.manual_cast + # Fallback to scaled_fp8_ops for backward compatibility + scaled_fp8 = model_options.get("scaled_fp8", None) + if scaled_fp8 is not None: + operations = comfy.ops.scaled_fp8_ops(fp8_matrix_mult=False, override_dtype=scaled_fp8) + else: + operations = comfy.ops.manual_cast self.operations = operations self.transformer = model_class(config, dtype, device, self.operations) diff --git a/comfy/text_encoders/hunyuan_video.py b/comfy/text_encoders/hunyuan_video.py index 557094f49..0110517bb 100644 --- a/comfy/text_encoders/hunyuan_video.py +++ b/comfy/text_encoders/hunyuan_video.py @@ -18,6 +18,9 @@ def llama_detect(state_dict, prefix=""): if scaled_fp8_key in state_dict: out["llama_scaled_fp8"] = state_dict[scaled_fp8_key].dtype + if "_quantization_metadata" in state_dict: + out["llama_quantization_metadata"] = state_dict["_quantization_metadata"] + return out diff --git a/tests-unit/comfy_quant/test_mixed_precision.py b/tests-unit/comfy_quant/test_mixed_precision.py index f8d1fd04e..63361309f 100644 --- a/tests-unit/comfy_quant/test_mixed_precision.py +++ b/tests-unit/comfy_quant/test_mixed_precision.py @@ -37,11 +37,8 @@ class TestMixedPrecisionOps(unittest.TestCase): def test_all_layers_standard(self): """Test that model with no quantization works normally""" - # Configure no quantization - ops.MixedPrecisionOps._layer_quant_config = {} - # Create model - model = SimpleModel(operations=ops.MixedPrecisionOps) + model = SimpleModel(operations=ops.mixed_precision_ops({})) # Initialize weights manually model.layer1.weight = torch.nn.Parameter(torch.randn(20, 10, dtype=torch.bfloat16)) @@ -76,7 +73,6 @@ class TestMixedPrecisionOps(unittest.TestCase): "params": {} } } - ops.MixedPrecisionOps._layer_quant_config = layer_quant_config # Create state dict with mixed precision fp8_weight1 = torch.randn(20, 10, dtype=torch.float32).to(torch.float8_e4m3fn) @@ -99,7 +95,7 @@ class TestMixedPrecisionOps(unittest.TestCase): } # Create model and load state dict (strict=False because custom loading pops keys) - model = SimpleModel(operations=ops.MixedPrecisionOps) + model = SimpleModel(operations=ops.mixed_precision_ops(layer_quant_config)) model.load_state_dict(state_dict, strict=False) # Verify weights are wrapped in QuantizedTensor @@ -132,7 +128,6 @@ class TestMixedPrecisionOps(unittest.TestCase): "params": {} } } - ops.MixedPrecisionOps._layer_quant_config = layer_quant_config # Create and load model fp8_weight = torch.randn(20, 10, dtype=torch.float32).to(torch.float8_e4m3fn) @@ -146,7 +141,7 @@ class TestMixedPrecisionOps(unittest.TestCase): "layer3.bias": torch.randn(40, dtype=torch.bfloat16), } - model = SimpleModel(operations=ops.MixedPrecisionOps) + model = SimpleModel(operations=ops.mixed_precision_ops(layer_quant_config)) model.load_state_dict(state_dict1, strict=False) # Save state dict @@ -170,7 +165,6 @@ class TestMixedPrecisionOps(unittest.TestCase): "params": {} } } - ops.MixedPrecisionOps._layer_quant_config = layer_quant_config # Create and load model fp8_weight = torch.randn(20, 10, dtype=torch.float32).to(torch.float8_e4m3fn) @@ -184,7 +178,7 @@ class TestMixedPrecisionOps(unittest.TestCase): "layer3.bias": torch.randn(40, dtype=torch.bfloat16), } - model = SimpleModel(operations=ops.MixedPrecisionOps) + model = SimpleModel(operations=ops.mixed_precision_ops(layer_quant_config)) model.load_state_dict(state_dict, strict=False) # Add a weight function (simulating LoRA) @@ -210,7 +204,6 @@ class TestMixedPrecisionOps(unittest.TestCase): "params": {} } } - ops.MixedPrecisionOps._layer_quant_config = layer_quant_config # Create state dict state_dict = { @@ -223,7 +216,7 @@ class TestMixedPrecisionOps(unittest.TestCase): } # Load should raise KeyError for unknown format in QUANT_FORMAT_MIXINS - model = SimpleModel(operations=ops.MixedPrecisionOps) + model = SimpleModel(operations=ops.mixed_precision_ops(layer_quant_config)) with self.assertRaises(KeyError): model.load_state_dict(state_dict, strict=False) From b6805429b9c2f3aa919035bea849ecd1de3ac8e4 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 24 Nov 2025 23:48:20 -0800 Subject: [PATCH 134/231] Allow pinning quantized tensors. (#10873) --- comfy/model_management.py | 6 +++++- comfy/quant_ops.py | 8 ++++++++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index a21df54b3..a9327ac80 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -1098,13 +1098,14 @@ if not args.disable_pinned_memory: MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.95 logging.info("Enabled pinned memory {}".format(MAX_PINNED_MEMORY // (1024 * 1024))) +PINNING_ALLOWED_TYPES = set(["Parameter", "QuantizedTensor"]) def pin_memory(tensor): global TOTAL_PINNED_MEMORY if MAX_PINNED_MEMORY <= 0: return False - if type(tensor) is not torch.nn.parameter.Parameter: + if type(tensor).__name__ not in PINNING_ALLOWED_TYPES: return False if not is_device_cpu(tensor.device): @@ -1124,6 +1125,9 @@ def pin_memory(tensor): return False ptr = tensor.data_ptr() + if ptr == 0: + return False + if torch.cuda.cudart().cudaHostRegister(ptr, size, 1) == 0: PINNED_MEMORY[ptr] = size TOTAL_PINNED_MEMORY += size diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py index 905b4729e..e938144a7 100644 --- a/comfy/quant_ops.py +++ b/comfy/quant_ops.py @@ -228,6 +228,14 @@ class QuantizedTensor(torch.Tensor): new_kwargs = dequant_arg(kwargs) return func(*new_args, **new_kwargs) + def data_ptr(self): + return self._qdata.data_ptr() + + def is_pinned(self): + return self._qdata.is_pinned() + + def is_contiguous(self): + return self._qdata.is_contiguous() # ============================================================================== # Generic Utilities (Layout-Agnostic Operations) From acfaa5c4a132e1c01bc9d94e76b0d667c899bfd1 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 24 Nov 2025 23:55:49 -0800 Subject: [PATCH 135/231] Don't try fp8 matrix mult in quantized ops if not supported by hardware. (#10874) --- comfy/ops.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/comfy/ops.py b/comfy/ops.py index af185ec24..785aa1c9f 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -646,11 +646,12 @@ def mixed_precision_ops(layer_quant_config={}, compute_dtype=torch.bfloat16, ful return MixedPrecisionOps def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None, model_config=None): + fp8_compute = comfy.model_management.supports_fp8_compute(load_device) # TODO: if we support more ops this needs to be more granular + if model_config and hasattr(model_config, 'layer_quant_config') and model_config.layer_quant_config: logging.info(f"Using mixed precision operations: {len(model_config.layer_quant_config)} quantized layers") - return mixed_precision_ops(model_config.layer_quant_config, compute_dtype) + return mixed_precision_ops(model_config.layer_quant_config, compute_dtype, full_precision_mm=not fp8_compute) - fp8_compute = comfy.model_management.supports_fp8_compute(load_device) if scaled_fp8 is not None: return scaled_fp8_ops(fp8_matrix_mult=fp8_compute and fp8_optimizations, scale_input=fp8_optimizations, override_dtype=scaled_fp8) From 015a0599d08f1072155b9213d488b73e502fea3c Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 25 Nov 2025 00:23:19 -0800 Subject: [PATCH 136/231] I found a case where this is needed (#10875) --- comfy/quant_ops.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py index e938144a7..0c16bcf8d 100644 --- a/comfy/quant_ops.py +++ b/comfy/quant_ops.py @@ -405,8 +405,8 @@ class TensorCoreFP8Layout(QuantizedLayout): tensor_scaled = tensor * (1.0 / scale).to(tensor.dtype) # TODO: uncomment this if it's actually needed because the clamp has a small performance penality' - # lp_amax = torch.finfo(dtype).max - # torch.clamp(tensor_scaled, min=-lp_amax, max=lp_amax, out=tensor_scaled) + lp_amax = torch.finfo(dtype).max + torch.clamp(tensor_scaled, min=-lp_amax, max=lp_amax, out=tensor_scaled) qdata = tensor_scaled.to(dtype, memory_format=torch.contiguous_format) layout_params = { From 6b573ae0cb11000a0330a35d9e31917c22c874a4 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 25 Nov 2025 07:50:19 -0800 Subject: [PATCH 137/231] Flux 2 (#10879) --- comfy/latent_formats.py | 9 +++ comfy/ldm/flux/layers.py | 90 +++++++++++++++++++-------- comfy/ldm/flux/model.py | 80 ++++++++++++++++++------ comfy/ldm/models/autoencoder.py | 42 +++++++++++++ comfy/model_base.py | 23 +++++-- comfy/model_detection.py | 50 +++++++++++---- comfy/sd.py | 26 +++++++- comfy/supported_models.py | 34 +++++++++- comfy/text_encoders/flux.py | 107 +++++++++++++++++++++++++++++++- comfy/text_encoders/llama.py | 31 +++++++++ comfy_extras/nodes_flux.py | 80 +++++++++++++++++++++++- nodes.py | 2 +- 12 files changed, 506 insertions(+), 68 deletions(-) diff --git a/comfy/latent_formats.py b/comfy/latent_formats.py index 204fc048d..e98c7d6d8 100644 --- a/comfy/latent_formats.py +++ b/comfy/latent_formats.py @@ -178,6 +178,15 @@ class Flux(SD3): def process_out(self, latent): return (latent / self.scale_factor) + self.shift_factor +class Flux2(LatentFormat): + latent_channels = 128 + + def process_in(self, latent): + return latent + + def process_out(self, latent): + return latent + class Mochi(LatentFormat): latent_channels = 12 latent_dimensions = 3 diff --git a/comfy/ldm/flux/layers.py b/comfy/ldm/flux/layers.py index 23150a712..2472ab79c 100644 --- a/comfy/ldm/flux/layers.py +++ b/comfy/ldm/flux/layers.py @@ -48,11 +48,11 @@ def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 10 return embedding class MLPEmbedder(nn.Module): - def __init__(self, in_dim: int, hidden_dim: int, dtype=None, device=None, operations=None): + def __init__(self, in_dim: int, hidden_dim: int, bias=True, dtype=None, device=None, operations=None): super().__init__() - self.in_layer = operations.Linear(in_dim, hidden_dim, bias=True, dtype=dtype, device=device) + self.in_layer = operations.Linear(in_dim, hidden_dim, bias=bias, dtype=dtype, device=device) self.silu = nn.SiLU() - self.out_layer = operations.Linear(hidden_dim, hidden_dim, bias=True, dtype=dtype, device=device) + self.out_layer = operations.Linear(hidden_dim, hidden_dim, bias=bias, dtype=dtype, device=device) def forward(self, x: Tensor) -> Tensor: return self.out_layer(self.silu(self.in_layer(x))) @@ -80,14 +80,14 @@ class QKNorm(torch.nn.Module): class SelfAttention(nn.Module): - def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False, dtype=None, device=None, operations=None): + def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False, proj_bias: bool = True, dtype=None, device=None, operations=None): super().__init__() self.num_heads = num_heads head_dim = dim // num_heads self.qkv = operations.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device) self.norm = QKNorm(head_dim, dtype=dtype, device=device, operations=operations) - self.proj = operations.Linear(dim, dim, dtype=dtype, device=device) + self.proj = operations.Linear(dim, dim, bias=proj_bias, dtype=dtype, device=device) @dataclass @@ -98,11 +98,11 @@ class ModulationOut: class Modulation(nn.Module): - def __init__(self, dim: int, double: bool, dtype=None, device=None, operations=None): + def __init__(self, dim: int, double: bool, bias=True, dtype=None, device=None, operations=None): super().__init__() self.is_double = double self.multiplier = 6 if double else 3 - self.lin = operations.Linear(dim, self.multiplier * dim, bias=True, dtype=dtype, device=device) + self.lin = operations.Linear(dim, self.multiplier * dim, bias=bias, dtype=dtype, device=device) def forward(self, vec: Tensor) -> tuple: if vec.ndim == 2: @@ -129,8 +129,18 @@ def apply_mod(tensor, m_mult, m_add=None, modulation_dims=None): return tensor +class SiLUActivation(nn.Module): + def __init__(self): + super().__init__() + self.gate_fn = nn.SiLU() + + def forward(self, x: Tensor) -> Tensor: + x1, x2 = x.chunk(2, dim=-1) + return self.gate_fn(x1) * x2 + + class DoubleStreamBlock(nn.Module): - def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, modulation=True, dtype=None, device=None, operations=None): + def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, modulation=True, mlp_silu_act=False, proj_bias=True, dtype=None, device=None, operations=None): super().__init__() mlp_hidden_dim = int(hidden_size * mlp_ratio) @@ -142,27 +152,44 @@ class DoubleStreamBlock(nn.Module): self.img_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device, operations=operations) self.img_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) - self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations) + self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, proj_bias=proj_bias, dtype=dtype, device=device, operations=operations) self.img_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) - self.img_mlp = nn.Sequential( - operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device), - nn.GELU(approximate="tanh"), - operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device), - ) + + if mlp_silu_act: + self.img_mlp = nn.Sequential( + operations.Linear(hidden_size, mlp_hidden_dim * 2, bias=False, dtype=dtype, device=device), + SiLUActivation(), + operations.Linear(mlp_hidden_dim, hidden_size, bias=False, dtype=dtype, device=device), + ) + else: + self.img_mlp = nn.Sequential( + operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device), + nn.GELU(approximate="tanh"), + operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device), + ) if self.modulation: self.txt_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device, operations=operations) self.txt_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) - self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations) + self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, proj_bias=proj_bias, dtype=dtype, device=device, operations=operations) self.txt_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) - self.txt_mlp = nn.Sequential( - operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device), - nn.GELU(approximate="tanh"), - operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device), - ) + + if mlp_silu_act: + self.txt_mlp = nn.Sequential( + operations.Linear(hidden_size, mlp_hidden_dim * 2, bias=False, dtype=dtype, device=device), + SiLUActivation(), + operations.Linear(mlp_hidden_dim, hidden_size, bias=False, dtype=dtype, device=device), + ) + else: + self.txt_mlp = nn.Sequential( + operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device), + nn.GELU(approximate="tanh"), + operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device), + ) + self.flipped_img_txt = flipped_img_txt def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims_img=None, modulation_dims_txt=None, transformer_options={}): @@ -246,6 +273,8 @@ class SingleStreamBlock(nn.Module): mlp_ratio: float = 4.0, qk_scale: float = None, modulation=True, + mlp_silu_act=False, + bias=True, dtype=None, device=None, operations=None @@ -257,17 +286,24 @@ class SingleStreamBlock(nn.Module): self.scale = qk_scale or head_dim**-0.5 self.mlp_hidden_dim = int(hidden_size * mlp_ratio) + + self.mlp_hidden_dim_first = self.mlp_hidden_dim + if mlp_silu_act: + self.mlp_hidden_dim_first = int(hidden_size * mlp_ratio * 2) + self.mlp_act = SiLUActivation() + else: + self.mlp_act = nn.GELU(approximate="tanh") + # qkv and mlp_in - self.linear1 = operations.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim, dtype=dtype, device=device) + self.linear1 = operations.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim_first, bias=bias, dtype=dtype, device=device) # proj and mlp_out - self.linear2 = operations.Linear(hidden_size + self.mlp_hidden_dim, hidden_size, dtype=dtype, device=device) + self.linear2 = operations.Linear(hidden_size + self.mlp_hidden_dim, hidden_size, bias=bias, dtype=dtype, device=device) self.norm = QKNorm(head_dim, dtype=dtype, device=device, operations=operations) self.hidden_size = hidden_size self.pre_norm = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) - self.mlp_act = nn.GELU(approximate="tanh") if modulation: self.modulation = Modulation(hidden_size, double=False, dtype=dtype, device=device, operations=operations) else: @@ -279,7 +315,7 @@ class SingleStreamBlock(nn.Module): else: mod = vec - qkv, mlp = torch.split(self.linear1(apply_mod(self.pre_norm(x), (1 + mod.scale), mod.shift, modulation_dims)), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1) + qkv, mlp = torch.split(self.linear1(apply_mod(self.pre_norm(x), (1 + mod.scale), mod.shift, modulation_dims)), [3 * self.hidden_size, self.mlp_hidden_dim_first], dim=-1) q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) del qkv @@ -298,11 +334,11 @@ class SingleStreamBlock(nn.Module): class LastLayer(nn.Module): - def __init__(self, hidden_size: int, patch_size: int, out_channels: int, dtype=None, device=None, operations=None): + def __init__(self, hidden_size: int, patch_size: int, out_channels: int, bias=True, dtype=None, device=None, operations=None): super().__init__() self.norm_final = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) - self.linear = operations.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True, dtype=dtype, device=device) - self.adaLN_modulation = nn.Sequential(nn.SiLU(), operations.Linear(hidden_size, 2 * hidden_size, bias=True, dtype=dtype, device=device)) + self.linear = operations.Linear(hidden_size, patch_size * patch_size * out_channels, bias=bias, dtype=dtype, device=device) + self.adaLN_modulation = nn.Sequential(nn.SiLU(), operations.Linear(hidden_size, 2 * hidden_size, bias=bias, dtype=dtype, device=device)) def forward(self, x: Tensor, vec: Tensor, modulation_dims=None) -> Tensor: if vec.ndim == 2: diff --git a/comfy/ldm/flux/model.py b/comfy/ldm/flux/model.py index b9d36f202..1a24e6d95 100644 --- a/comfy/ldm/flux/model.py +++ b/comfy/ldm/flux/model.py @@ -15,6 +15,7 @@ from .layers import ( MLPEmbedder, SingleStreamBlock, timestep_embedding, + Modulation ) @dataclass @@ -33,6 +34,11 @@ class FluxParams: patch_size: int qkv_bias: bool guidance_embed: bool + global_modulation: bool = False + mlp_silu_act: bool = False + ops_bias: bool = True + default_ref_method: str = "offset" + ref_index_scale: float = 1.0 class Flux(nn.Module): @@ -58,13 +64,17 @@ class Flux(nn.Module): self.hidden_size = params.hidden_size self.num_heads = params.num_heads self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim) - self.img_in = operations.Linear(self.in_channels, self.hidden_size, bias=True, dtype=dtype, device=device) - self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations) - self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size, dtype=dtype, device=device, operations=operations) + self.img_in = operations.Linear(self.in_channels, self.hidden_size, bias=params.ops_bias, dtype=dtype, device=device) + self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, bias=params.ops_bias, dtype=dtype, device=device, operations=operations) + if params.vec_in_dim is not None: + self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size, dtype=dtype, device=device, operations=operations) + else: + self.vector_in = None + self.guidance_in = ( - MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations) if params.guidance_embed else nn.Identity() + MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, bias=params.ops_bias, dtype=dtype, device=device, operations=operations) if params.guidance_embed else nn.Identity() ) - self.txt_in = operations.Linear(params.context_in_dim, self.hidden_size, dtype=dtype, device=device) + self.txt_in = operations.Linear(params.context_in_dim, self.hidden_size, bias=params.ops_bias, dtype=dtype, device=device) self.double_blocks = nn.ModuleList( [ @@ -73,6 +83,9 @@ class Flux(nn.Module): self.num_heads, mlp_ratio=params.mlp_ratio, qkv_bias=params.qkv_bias, + modulation=params.global_modulation is False, + mlp_silu_act=params.mlp_silu_act, + proj_bias=params.ops_bias, dtype=dtype, device=device, operations=operations ) for _ in range(params.depth) @@ -81,13 +94,30 @@ class Flux(nn.Module): self.single_blocks = nn.ModuleList( [ - SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, dtype=dtype, device=device, operations=operations) + SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, modulation=params.global_modulation is False, mlp_silu_act=params.mlp_silu_act, bias=params.ops_bias, dtype=dtype, device=device, operations=operations) for _ in range(params.depth_single_blocks) ] ) if final_layer: - self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels, dtype=dtype, device=device, operations=operations) + self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels, bias=params.ops_bias, dtype=dtype, device=device, operations=operations) + + if params.global_modulation: + self.double_stream_modulation_img = Modulation( + self.hidden_size, + double=True, + bias=False, + dtype=dtype, device=device, operations=operations + ) + self.double_stream_modulation_txt = Modulation( + self.hidden_size, + double=True, + bias=False, + dtype=dtype, device=device, operations=operations + ) + self.single_stream_modulation = Modulation( + self.hidden_size, double=False, bias=False, dtype=dtype, device=device, operations=operations + ) def forward_orig( self, @@ -103,9 +133,6 @@ class Flux(nn.Module): attn_mask: Tensor = None, ) -> Tensor: - if y is None: - y = torch.zeros((img.shape[0], self.params.vec_in_dim), device=img.device, dtype=img.dtype) - patches = transformer_options.get("patches", {}) patches_replace = transformer_options.get("patches_replace", {}) if img.ndim != 3 or txt.ndim != 3: @@ -118,9 +145,17 @@ class Flux(nn.Module): if guidance is not None: vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype)) - vec = vec + self.vector_in(y[:, :self.params.vec_in_dim]) + if self.vector_in is not None: + if y is None: + y = torch.zeros((img.shape[0], self.params.vec_in_dim), device=img.device, dtype=img.dtype) + vec = vec + self.vector_in(y[:, :self.params.vec_in_dim]) + txt = self.txt_in(txt) + vec_orig = vec + if self.params.global_modulation: + vec = (self.double_stream_modulation_img(vec_orig), self.double_stream_modulation_txt(vec_orig)) + if "post_input" in patches: for p in patches["post_input"]: out = p({"img": img, "txt": txt, "img_ids": img_ids, "txt_ids": txt_ids}) @@ -177,6 +212,9 @@ class Flux(nn.Module): img = torch.cat((txt, img), 1) + if self.params.global_modulation: + vec, _ = self.single_stream_modulation(vec_orig) + for i, block in enumerate(self.single_blocks): if ("single_block", i) in blocks_replace: def block_wrap(args): @@ -207,7 +245,7 @@ class Flux(nn.Module): img = img[:, txt.shape[1] :, ...] - img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels) + img = self.final_layer(img, vec_orig) # (N, T, patch_size ** 2 * out_channels) return img def process_img(self, x, index=0, h_offset=0, w_offset=0, transformer_options={}): @@ -234,10 +272,10 @@ class Flux(nn.Module): h_offset += rope_options.get("shift_y", 0.0) w_offset += rope_options.get("shift_x", 0.0) - img_ids = torch.zeros((steps_h, steps_w, 3), device=x.device, dtype=x.dtype) + img_ids = torch.zeros((steps_h, steps_w, len(self.params.axes_dim)), device=x.device, dtype=torch.float32) img_ids[:, :, 0] = img_ids[:, :, 1] + index - img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=steps_h, device=x.device, dtype=x.dtype).unsqueeze(1) - img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=steps_w, device=x.device, dtype=x.dtype).unsqueeze(0) + img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=steps_h, device=x.device, dtype=torch.float32).unsqueeze(1) + img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=steps_w, device=x.device, dtype=torch.float32).unsqueeze(0) return img, repeat(img_ids, "h w c -> b (h w) c", b=bs) def forward(self, x, timestep, context, y=None, guidance=None, ref_latents=None, control=None, transformer_options={}, **kwargs): @@ -259,10 +297,10 @@ class Flux(nn.Module): h = 0 w = 0 index = 0 - ref_latents_method = kwargs.get("ref_latents_method", "offset") + ref_latents_method = kwargs.get("ref_latents_method", self.params.default_ref_method) for ref in ref_latents: if ref_latents_method == "index": - index += 1 + index += self.params.ref_index_scale h_offset = 0 w_offset = 0 elif ref_latents_method == "uxo": @@ -286,7 +324,11 @@ class Flux(nn.Module): img = torch.cat([img, kontext], dim=1) img_ids = torch.cat([img_ids, kontext_ids], dim=1) - txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype) + txt_ids = torch.zeros((bs, context.shape[1], len(self.params.axes_dim)), device=x.device, dtype=torch.float32) + + if len(self.params.axes_dim) == 4: # Flux 2 + txt_ids[:, :, 3] = torch.linspace(0, context.shape[1] - 1, steps=context.shape[1], device=x.device, dtype=torch.float32) + out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None)) out = out[:, :img_tokens] - return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:,:,:h_orig,:w_orig] + return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=self.patch_size, pw=self.patch_size)[:,:,:h_orig,:w_orig] diff --git a/comfy/ldm/models/autoencoder.py b/comfy/ldm/models/autoencoder.py index 611d36a1b..4f50810dc 100644 --- a/comfy/ldm/models/autoencoder.py +++ b/comfy/ldm/models/autoencoder.py @@ -9,6 +9,8 @@ from comfy.ldm.modules.distributions.distributions import DiagonalGaussianDistri from comfy.ldm.util import get_obj_from_str, instantiate_from_config from comfy.ldm.modules.ema import LitEma import comfy.ops +from einops import rearrange +import comfy.model_management class DiagonalGaussianRegularizer(torch.nn.Module): def __init__(self, sample: bool = False): @@ -179,6 +181,21 @@ class AutoencodingEngineLegacy(AutoencodingEngine): self.post_quant_conv = conv_op(embed_dim, ddconfig["z_channels"], 1) self.embed_dim = embed_dim + if ddconfig.get("batch_norm_latent", False): + self.bn_eps = 1e-4 + self.bn_momentum = 0.1 + self.ps = [2, 2] + self.bn = torch.nn.BatchNorm2d(math.prod(self.ps) * ddconfig["z_channels"], + eps=self.bn_eps, + momentum=self.bn_momentum, + affine=False, + track_running_stats=True, + ) + self.bn.eval() + else: + self.bn = None + + def get_autoencoder_params(self) -> list: params = super().get_autoencoder_params() return params @@ -201,11 +218,36 @@ class AutoencodingEngineLegacy(AutoencodingEngine): z = torch.cat(z, 0) z, reg_log = self.regularization(z) + + if self.bn is not None: + z = rearrange(z, + "... c (i pi) (j pj) -> ... (c pi pj) i j", + pi=self.ps[0], + pj=self.ps[1], + ) + + z = torch.nn.functional.batch_norm(z, + comfy.model_management.cast_to(self.bn.running_mean, dtype=z.dtype, device=z.device), + comfy.model_management.cast_to(self.bn.running_var, dtype=z.dtype, device=z.device), + momentum=self.bn_momentum, + eps=self.bn_eps) + if return_reg_log: return z, reg_log return z def decode(self, z: torch.Tensor, **decoder_kwargs) -> torch.Tensor: + if self.bn is not None: + s = torch.sqrt(comfy.model_management.cast_to(self.bn.running_var.view(1, -1, 1, 1), dtype=z.dtype, device=z.device) + self.bn_eps) + m = comfy.model_management.cast_to(self.bn.running_mean.view(1, -1, 1, 1), dtype=z.dtype, device=z.device) + z = z * s + m + z = rearrange( + z, + "... (c pi pj) i j -> ... c (i pi) (j pj)", + pi=self.ps[0], + pj=self.ps[1], + ) + if self.max_batch_size is None: dec = self.post_quant_conv(z) dec = self.decoder(dec, **decoder_kwargs) diff --git a/comfy/model_base.py b/comfy/model_base.py index e14b552c5..cad79ecbd 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -898,12 +898,13 @@ class Flux(BaseModel): attention_mask = kwargs.get("attention_mask", None) if attention_mask is not None: shape = kwargs["noise"].shape - mask_ref_size = kwargs["attention_mask_img_shape"] - # the model will pad to the patch size, and then divide - # essentially dividing and rounding up - (h_tok, w_tok) = (math.ceil(shape[2] / self.diffusion_model.patch_size), math.ceil(shape[3] / self.diffusion_model.patch_size)) - attention_mask = utils.upscale_dit_mask(attention_mask, mask_ref_size, (h_tok, w_tok)) - out['attention_mask'] = comfy.conds.CONDRegular(attention_mask) + mask_ref_size = kwargs.get("attention_mask_img_shape", None) + if mask_ref_size is not None: + # the model will pad to the patch size, and then divide + # essentially dividing and rounding up + (h_tok, w_tok) = (math.ceil(shape[2] / self.diffusion_model.patch_size), math.ceil(shape[3] / self.diffusion_model.patch_size)) + attention_mask = utils.upscale_dit_mask(attention_mask, mask_ref_size, (h_tok, w_tok)) + out['attention_mask'] = comfy.conds.CONDRegular(attention_mask) guidance = kwargs.get("guidance", 3.5) if guidance is not None: @@ -928,6 +929,16 @@ class Flux(BaseModel): out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()), ref_latents)) // 16]) return out +class Flux2(Flux): + def extra_conds(self, **kwargs): + out = super().extra_conds(**kwargs) + cross_attn = kwargs.get("cross_attn", None) + if cross_attn is not None: + target_text_len = 512 + if cross_attn.shape[1] < target_text_len: + cross_attn = torch.nn.functional.pad(cross_attn, (0, 0, target_text_len - cross_attn.shape[1], 0)) + out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) + return out class GenmoMochi(BaseModel): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 0131ca25a..b2ba1459d 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -200,26 +200,54 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys and ('{}img_in.weight'.format(key_prefix) in state_dict_keys or f"{key_prefix}distilled_guidance_layer.norms.0.scale" in state_dict_keys): #Flux, Chroma or Chroma Radiance (has no img_in.weight) dit_config = {} - dit_config["image_model"] = "flux" + if '{}double_stream_modulation_img.lin.weight'.format(key_prefix) in state_dict_keys: + dit_config["image_model"] = "flux2" + dit_config["axes_dim"] = [32, 32, 32, 32] + dit_config["num_heads"] = 48 + dit_config["mlp_ratio"] = 3.0 + dit_config["theta"] = 2000 + dit_config["out_channels"] = 128 + dit_config["global_modulation"] = True + dit_config["vec_in_dim"] = None + dit_config["mlp_silu_act"] = True + dit_config["qkv_bias"] = False + dit_config["ops_bias"] = False + dit_config["default_ref_method"] = "index" + dit_config["ref_index_scale"] = 10.0 + patch_size = 1 + else: + dit_config["image_model"] = "flux" + dit_config["axes_dim"] = [16, 56, 56] + dit_config["num_heads"] = 24 + dit_config["mlp_ratio"] = 4.0 + dit_config["theta"] = 10000 + dit_config["out_channels"] = 16 + dit_config["qkv_bias"] = True + patch_size = 2 + dit_config["in_channels"] = 16 - patch_size = 2 + dit_config["hidden_size"] = 3072 + dit_config["context_in_dim"] = 4096 + dit_config["patch_size"] = patch_size in_key = "{}img_in.weight".format(key_prefix) if in_key in state_dict_keys: - dit_config["in_channels"] = state_dict[in_key].shape[1] // (patch_size * patch_size) - dit_config["out_channels"] = 16 + w = state_dict[in_key] + dit_config["in_channels"] = w.shape[1] // (patch_size * patch_size) + dit_config["hidden_size"] = w.shape[0] + + txt_in_key = "{}txt_in.weight".format(key_prefix) + if txt_in_key in state_dict_keys: + w = state_dict[txt_in_key] + dit_config["context_in_dim"] = w.shape[1] + dit_config["hidden_size"] = w.shape[0] + vec_in_key = '{}vector_in.in_layer.weight'.format(key_prefix) if vec_in_key in state_dict_keys: dit_config["vec_in_dim"] = state_dict[vec_in_key].shape[1] - dit_config["context_in_dim"] = 4096 - dit_config["hidden_size"] = 3072 - dit_config["mlp_ratio"] = 4.0 - dit_config["num_heads"] = 24 + dit_config["depth"] = count_blocks(state_dict_keys, '{}double_blocks.'.format(key_prefix) + '{}.') dit_config["depth_single_blocks"] = count_blocks(state_dict_keys, '{}single_blocks.'.format(key_prefix) + '{}.') - dit_config["axes_dim"] = [16, 56, 56] - dit_config["theta"] = 10000 - dit_config["qkv_bias"] = True if '{}distilled_guidance_layer.0.norms.0.scale'.format(key_prefix) in state_dict_keys or '{}distilled_guidance_layer.norms.0.scale'.format(key_prefix) in state_dict_keys: #Chroma dit_config["image_model"] = "chroma" dit_config["in_channels"] = 64 diff --git a/comfy/sd.py b/comfy/sd.py index b6df0bd61..14dd8944c 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -356,7 +356,7 @@ class VAE: self.memory_used_encode = lambda shape, dtype: (700 * shape[2] * shape[3]) * model_management.dtype_size(dtype) self.memory_used_decode = lambda shape, dtype: (700 * shape[2] * shape[3] * 32 * 32) * model_management.dtype_size(dtype) - elif sd['decoder.conv_in.weight'].shape[1] == 32: + elif sd['decoder.conv_in.weight'].shape[1] == 32 and sd['decoder.conv_in.weight'].ndim == 5: ddconfig = {"block_out_channels": [128, 256, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 16, "ffactor_temporal": 4, "downsample_match_channel": True, "upsample_match_channel": True, "refiner_vae": False} self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1] self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32] @@ -382,6 +382,17 @@ class VAE: self.upscale_ratio = 4 self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1] + if 'decoder.post_quant_conv.weight' in sd: + sd = comfy.utils.state_dict_prefix_replace(sd, {"decoder.post_quant_conv.": "post_quant_conv.", "encoder.quant_conv.": "quant_conv."}) + + if 'bn.running_mean' in sd: + ddconfig["batch_norm_latent"] = True + self.downscale_ratio *= 2 + self.upscale_ratio *= 2 + self.latent_channels *= 4 + old_memory_used_decode = self.memory_used_decode + self.memory_used_decode = lambda shape, dtype: old_memory_used_decode(shape, dtype) * 4.0 + if 'post_quant_conv.weight' in sd: self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=sd['post_quant_conv.weight'].shape[1]) else: @@ -940,6 +951,8 @@ class TEModel(Enum): QWEN25_7B = 11 BYT5_SMALL_GLYPH = 12 GEMMA_3_4B = 13 + MISTRAL3_24B = 14 + MISTRAL3_24B_PRUNED_FLUX2 = 15 def detect_te_model(sd): if "text_model.encoder.layers.30.mlp.fc1.weight" in sd: @@ -972,6 +985,13 @@ def detect_te_model(sd): if weight.shape[0] == 512: return TEModel.QWEN25_7B if "model.layers.0.post_attention_layernorm.weight" in sd: + weight = sd['model.layers.0.post_attention_layernorm.weight'] + if weight.shape[0] == 5120: + if "model.layers.39.post_attention_layernorm.weight" in sd: + return TEModel.MISTRAL3_24B + else: + return TEModel.MISTRAL3_24B_PRUNED_FLUX2 + return TEModel.LLAMA3_8 return None @@ -1086,6 +1106,10 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip else: clip_target.clip = comfy.text_encoders.qwen_image.te(**llama_detect(clip_data)) clip_target.tokenizer = comfy.text_encoders.qwen_image.QwenImageTokenizer + elif te_model == TEModel.MISTRAL3_24B or te_model == TEModel.MISTRAL3_24B_PRUNED_FLUX2: + clip_target.clip = comfy.text_encoders.flux.flux2_te(**llama_detect(clip_data), pruned=te_model == TEModel.MISTRAL3_24B_PRUNED_FLUX2) + clip_target.tokenizer = comfy.text_encoders.flux.Flux2Tokenizer + tokenizer_data["tekken_model"] = clip_data[0].get("tekken_model", None) else: # clip_l if clip_type == CLIPType.SD3: diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 2e64b85e8..8fe8e63f6 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -741,6 +741,37 @@ class FluxSchnell(Flux): out = model_base.Flux(self, model_type=model_base.ModelType.FLOW, device=device) return out +class Flux2(Flux): + unet_config = { + "image_model": "flux2", + } + + sampling_settings = { + "shift": 2.02, + } + + unet_extra_config = {} + latent_format = latent_formats.Flux2 + + supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32] + + vae_key_prefix = ["vae."] + text_encoder_key_prefix = ["text_encoders."] + + def __init__(self, unet_config): + super().__init__(unet_config) + self.memory_usage_factor = self.memory_usage_factor * (2.0 * 2.0) * 2.36 + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.Flux2(self, device=device) + return out + + def clip_target(self, state_dict={}): + return None # TODO + pref = self.text_encoder_key_prefix[0] + t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref)) + return supported_models_base.ClipTarget(comfy.text_encoders.flux.FluxTokenizer, comfy.text_encoders.flux.flux_clip(**t5_detect)) + class GenmoMochi(supported_models_base.BASE): unet_config = { "image_model": "mochi_preview", @@ -1422,6 +1453,7 @@ class HunyuanVideo15_SR_Distilled(HunyuanVideo): hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref)) return supported_models_base.ClipTarget(comfy.text_encoders.hunyuan_video.HunyuanVideo15Tokenizer, comfy.text_encoders.hunyuan_image.te(**hunyuan_detect)) -models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage] +models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage, Flux2] + models += [SVD_img2vid] diff --git a/comfy/text_encoders/flux.py b/comfy/text_encoders/flux.py index d61ef6668..8dbbca16e 100644 --- a/comfy/text_encoders/flux.py +++ b/comfy/text_encoders/flux.py @@ -1,10 +1,13 @@ from comfy import sd1_clip import comfy.text_encoders.t5 import comfy.text_encoders.sd3_clip +import comfy.text_encoders.llama import comfy.model_management -from transformers import T5TokenizerFast +from transformers import T5TokenizerFast, LlamaTokenizerFast import torch import os +import json +import base64 class T5XXLTokenizer(sd1_clip.SDTokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}): @@ -68,3 +71,105 @@ def flux_clip(dtype_t5=None, t5xxl_scaled_fp8=None): model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8 super().__init__(dtype_t5=dtype_t5, device=device, dtype=dtype, model_options=model_options) return FluxClipModel_ + +def load_mistral_tokenizer(data): + if torch.is_tensor(data): + data = data.numpy().tobytes() + + try: + from transformers.integrations.mistral import MistralConverter + except ModuleNotFoundError: + from transformers.models.pixtral.convert_pixtral_weights_to_hf import MistralConverter + + mistral_vocab = json.loads(data) + + special_tokens = {} + vocab = {} + + max_vocab = mistral_vocab["config"]["default_vocab_size"] + + for w in mistral_vocab["vocab"]: + r = w["rank"] + if r >= max_vocab: + continue + + vocab[base64.b64decode(w["token_bytes"])] = r + + for w in mistral_vocab["special_tokens"]: + if "token_bytes" in w: + special_tokens[base64.b64decode(w["token_bytes"])] = w["rank"] + else: + special_tokens[w["token_str"]] = w["rank"] + + all_special = [] + for v in special_tokens: + all_special.append(v) + + special_tokens.update(vocab) + vocab = special_tokens + return {"tokenizer_object": MistralConverter(vocab=vocab, additional_special_tokens=all_special).converted(), "legacy": False} + +class MistralTokenizerClass: + @staticmethod + def from_pretrained(path, **kwargs): + return LlamaTokenizerFast(**kwargs) + +class Mistral3Tokenizer(sd1_clip.SDTokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + self.tekken_data = tokenizer_data.get("tekken_model", None) + super().__init__("", pad_with_end=False, embedding_size=5120, embedding_key='mistral3_24b', tokenizer_class=MistralTokenizerClass, has_end_token=False, pad_to_max_length=False, pad_token=11, max_length=99999999, min_length=1, pad_left=True, tokenizer_args=load_mistral_tokenizer(self.tekken_data), tokenizer_data=tokenizer_data) + + def state_dict(self): + return {"tekken_model": self.tekken_data} + +class Flux2Tokenizer(sd1_clip.SD1Tokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="mistral3_24b", tokenizer=Mistral3Tokenizer) + self.llama_template = '[SYSTEM_PROMPT]You are an AI that reasons about image descriptions. You give structured responses focusing on object relationships, object\nattribution and actions without speculation.[/SYSTEM_PROMPT][INST]{}[/INST]' + + def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, **kwargs): + if llama_template is None: + llama_text = self.llama_template.format(text) + else: + llama_text = llama_template.format(text) + + tokens = super().tokenize_with_weights(llama_text, return_word_ids=return_word_ids, disable_weights=True, **kwargs) + return tokens + +class Mistral3_24BModel(sd1_clip.SDClipModel): + def __init__(self, device="cpu", layer="all", layer_idx=None, dtype=None, attention_mask=True, model_options={}): + textmodel_json_config = {} + num_layers = model_options.get("num_layers", None) + if num_layers is not None: + textmodel_json_config["num_hidden_layers"] = num_layers + if num_layers < 40: + textmodel_json_config["final_norm"] = False + super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"start": 1, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Mistral3Small24B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options) + +class Flux2TEModel(sd1_clip.SD1ClipModel): + def __init__(self, device="cpu", dtype=None, model_options={}, name="mistral3_24b", clip_model=Mistral3_24BModel): + super().__init__(device=device, dtype=dtype, name=name, clip_model=clip_model, model_options=model_options) + + def encode_token_weights(self, token_weight_pairs): + out, pooled, extra = super().encode_token_weights(token_weight_pairs) + + out = torch.stack((out[:, 10], out[:, 20], out[:, 30]), dim=1) + out = out.movedim(1, 2) + out = out.reshape(out.shape[0], out.shape[1], -1) + return out, pooled, extra + +def flux2_te(dtype_llama=None, llama_scaled_fp8=None, llama_quantization_metadata=None, pruned=False): + class Flux2TEModel_(Flux2TEModel): + def __init__(self, device="cpu", dtype=None, model_options={}): + if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options: + model_options = model_options.copy() + model_options["scaled_fp8"] = llama_scaled_fp8 + if dtype_llama is not None: + dtype = dtype_llama + if llama_quantization_metadata is not None: + model_options["quantization_metadata"] = llama_quantization_metadata + if pruned: + model_options = model_options.copy() + model_options["num_layers"] = 30 + super().__init__(device=device, dtype=dtype, model_options=model_options) + return Flux2TEModel_ diff --git a/comfy/text_encoders/llama.py b/comfy/text_encoders/llama.py index feb44bbb0..749ff581b 100644 --- a/comfy/text_encoders/llama.py +++ b/comfy/text_encoders/llama.py @@ -34,6 +34,28 @@ class Llama2Config: rope_scale = None final_norm: bool = True +@dataclass +class Mistral3Small24BConfig: + vocab_size: int = 131072 + hidden_size: int = 5120 + intermediate_size: int = 32768 + num_hidden_layers: int = 40 + num_attention_heads: int = 32 + num_key_value_heads: int = 8 + max_position_embeddings: int = 8192 + rms_norm_eps: float = 1e-5 + rope_theta: float = 1000000000.0 + transformer_type: str = "llama" + head_dim = 128 + rms_norm_add = False + mlp_activation = "silu" + qkv_bias = False + rope_dims = None + q_norm = None + k_norm = None + rope_scale = None + final_norm: bool = True + @dataclass class Qwen25_3BConfig: vocab_size: int = 151936 @@ -465,6 +487,15 @@ class Llama2(BaseLlama, torch.nn.Module): self.model = Llama2_(config, device=device, dtype=dtype, ops=operations) self.dtype = dtype +class Mistral3Small24B(BaseLlama, torch.nn.Module): + def __init__(self, config_dict, dtype, device, operations): + super().__init__() + config = Mistral3Small24BConfig(**config_dict) + self.num_layers = config.num_hidden_layers + + self.model = Llama2_(config, device=device, dtype=dtype, ops=operations) + self.dtype = dtype + class Qwen25_3B(BaseLlama, torch.nn.Module): def __init__(self, config_dict, dtype, device, operations): super().__init__() diff --git a/comfy_extras/nodes_flux.py b/comfy_extras/nodes_flux.py index ce1b2e89f..d9c4bba81 100644 --- a/comfy_extras/nodes_flux.py +++ b/comfy_extras/nodes_flux.py @@ -2,7 +2,10 @@ import node_helpers import comfy.utils from typing_extensions import override from comfy_api.latest import ComfyExtension, io - +import comfy.model_management +import torch +import math +import nodes class CLIPTextEncodeFlux(io.ComfyNode): @classmethod @@ -30,6 +33,27 @@ class CLIPTextEncodeFlux(io.ComfyNode): encode = execute # TODO: remove +class EmptyFlux2LatentImage(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="EmptyFlux2LatentImage", + display_name="Empty Flux 2 Latent", + category="latent", + inputs=[ + io.Int.Input("width", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("height", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("batch_size", default=1, min=1, max=4096), + ], + outputs=[ + io.Latent.Output(), + ], + ) + + @classmethod + def execute(cls, width, height, batch_size=1) -> io.NodeOutput: + latent = torch.zeros([batch_size, 128, height // 16, width // 16], device=comfy.model_management.intermediate_device()) + return io.NodeOutput({"samples": latent}) class FluxGuidance(io.ComfyNode): @classmethod @@ -154,6 +178,58 @@ class FluxKontextMultiReferenceLatentMethod(io.ComfyNode): append = execute # TODO: remove +def generalized_time_snr_shift(t, mu: float, sigma: float): + return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) + + +def compute_empirical_mu(image_seq_len: int, num_steps: int) -> float: + a1, b1 = 8.73809524e-05, 1.89833333 + a2, b2 = 0.00016927, 0.45666666 + + if image_seq_len > 4300: + mu = a2 * image_seq_len + b2 + return float(mu) + + m_200 = a2 * image_seq_len + b2 + m_10 = a1 * image_seq_len + b1 + + a = (m_200 - m_10) / 190.0 + b = m_200 - 200.0 * a + mu = a * num_steps + b + + return float(mu) + + +def get_schedule(num_steps: int, image_seq_len: int) -> list[float]: + mu = compute_empirical_mu(image_seq_len, num_steps) + timesteps = torch.linspace(1, 0, num_steps + 1) + timesteps = generalized_time_snr_shift(timesteps, mu, 1.0) + return timesteps + + +class Flux2Scheduler(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="Flux2Scheduler", + category="sampling/custom_sampling/schedulers", + inputs=[ + io.Int.Input("steps", default=20, min=1, max=4096), + io.Int.Input("width", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=1), + io.Int.Input("height", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=1), + ], + outputs=[ + io.Sigmas.Output(), + ], + ) + + @classmethod + def execute(cls, steps, width, height) -> io.NodeOutput: + seq_len = (width * height / (16 * 16)) + sigmas = get_schedule(steps, round(seq_len)) + return io.NodeOutput(sigmas) + + class FluxExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[io.ComfyNode]]: @@ -163,6 +239,8 @@ class FluxExtension(ComfyExtension): FluxDisableGuidance, FluxKontextImageScale, FluxKontextMultiReferenceLatentMethod, + EmptyFlux2LatentImage, + Flux2Scheduler, ] diff --git a/nodes.py b/nodes.py index f023ae3b6..f4835c02e 100644 --- a/nodes.py +++ b/nodes.py @@ -929,7 +929,7 @@ class CLIPLoader: @classmethod def INPUT_TYPES(s): return {"required": { "clip_name": (folder_paths.get_filename_list("text_encoders"), ), - "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image"], ), + "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image", "flux2"], ), }, "optional": { "device": (["default", "cpu"], {"advanced": True}), From 5c7b08ca58f5412b3a814b374793cacdb5b5f0a7 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Tue, 25 Nov 2025 18:09:07 +0200 Subject: [PATCH 138/231] [API Nodes] add Flux.2 Pro node (#10880) --- comfy_api_nodes/apis/bfl_api.py | 28 +++- comfy_api_nodes/nodes_bfl.py | 238 +++++++++++++++---------------- comfy_api_nodes/util/__init__.py | 2 + 3 files changed, 143 insertions(+), 125 deletions(-) diff --git a/comfy_api_nodes/apis/bfl_api.py b/comfy_api_nodes/apis/bfl_api.py index 0fc8c0607..d8d3557b3 100644 --- a/comfy_api_nodes/apis/bfl_api.py +++ b/comfy_api_nodes/apis/bfl_api.py @@ -70,6 +70,29 @@ class BFLFluxProGenerateRequest(BaseModel): # ) +class Flux2ProGenerateRequest(BaseModel): + prompt: str = Field(...) + width: int = Field(1024, description="Must be a multiple of 32.") + height: int = Field(768, description="Must be a multiple of 32.") + seed: int | None = Field(None) + prompt_upsampling: bool | None = Field(None) + input_image: str | None = Field(None, description="Base64 encoded image for image-to-image generation") + input_image_2: str | None = Field(None, description="Base64 encoded image for image-to-image generation") + input_image_3: str | None = Field(None, description="Base64 encoded image for image-to-image generation") + input_image_4: str | None = Field(None, description="Base64 encoded image for image-to-image generation") + input_image_5: str | None = Field(None, description="Base64 encoded image for image-to-image generation") + input_image_6: str | None = Field(None, description="Base64 encoded image for image-to-image generation") + input_image_7: str | None = Field(None, description="Base64 encoded image for image-to-image generation") + input_image_8: str | None = Field(None, description="Base64 encoded image for image-to-image generation") + input_image_9: str | None = Field(None, description="Base64 encoded image for image-to-image generation") + safety_tolerance: int | None = Field( + 5, description="Tolerance level for input and output moderation. Value 0 being most strict.", ge=0, le=5 + ) + output_format: str | None = Field( + "png", description="Output format for the generated image. Can be 'jpeg' or 'png'." + ) + + class BFLFluxKontextProGenerateRequest(BaseModel): prompt: str = Field(..., description='The text prompt for what you wannt to edit.') input_image: Optional[str] = Field(None, description='Image to edit in base64 format') @@ -109,8 +132,9 @@ class BFLFluxProUltraGenerateRequest(BaseModel): class BFLFluxProGenerateResponse(BaseModel): - id: str = Field(..., description='The unique identifier for the generation task.') - polling_url: str = Field(..., description='URL to poll for the generation result.') + id: str = Field(..., description="The unique identifier for the generation task.") + polling_url: str = Field(..., description="URL to poll for the generation result.") + cost: float | None = Field(None, description="Price in cents") class BFLStatus(str, Enum): diff --git a/comfy_api_nodes/nodes_bfl.py b/comfy_api_nodes/nodes_bfl.py index 1740fb377..8826dea0c 100644 --- a/comfy_api_nodes/nodes_bfl.py +++ b/comfy_api_nodes/nodes_bfl.py @@ -1,7 +1,7 @@ from inspect import cleandoc -from typing import Optional import torch +from pydantic import BaseModel from typing_extensions import override from comfy_api.latest import IO, ComfyExtension @@ -9,15 +9,16 @@ from comfy_api_nodes.apis.bfl_api import ( BFLFluxExpandImageRequest, BFLFluxFillImageRequest, BFLFluxKontextProGenerateRequest, - BFLFluxProGenerateRequest, BFLFluxProGenerateResponse, BFLFluxProUltraGenerateRequest, BFLFluxStatusResponse, BFLStatus, + Flux2ProGenerateRequest, ) from comfy_api_nodes.util import ( ApiEndpoint, download_url_to_image_tensor, + get_number_of_images, poll_op, resize_mask_to_image, sync_op, @@ -116,7 +117,7 @@ class FluxProUltraImageNode(IO.ComfyNode): prompt_upsampling: bool = False, raw: bool = False, seed: int = 0, - image_prompt: Optional[torch.Tensor] = None, + image_prompt: torch.Tensor | None = None, image_prompt_strength: float = 0.1, ) -> IO.NodeOutput: if image_prompt is None: @@ -230,7 +231,7 @@ class FluxKontextProImageNode(IO.ComfyNode): aspect_ratio: str, guidance: float, steps: int, - input_image: Optional[torch.Tensor] = None, + input_image: torch.Tensor | None = None, seed=0, prompt_upsampling=False, ) -> IO.NodeOutput: @@ -280,124 +281,6 @@ class FluxKontextMaxImageNode(FluxKontextProImageNode): DISPLAY_NAME = "Flux.1 Kontext [max] Image" -class FluxProImageNode(IO.ComfyNode): - """ - Generates images synchronously based on prompt and resolution. - """ - - @classmethod - def define_schema(cls) -> IO.Schema: - return IO.Schema( - node_id="FluxProImageNode", - display_name="Flux 1.1 [pro] Image", - category="api node/image/BFL", - description=cleandoc(cls.__doc__ or ""), - inputs=[ - IO.String.Input( - "prompt", - multiline=True, - default="", - tooltip="Prompt for the image generation", - ), - IO.Boolean.Input( - "prompt_upsampling", - default=False, - tooltip="Whether to perform upsampling on the prompt. " - "If active, automatically modifies the prompt for more creative generation, " - "but results are nondeterministic (same seed will not produce exactly the same result).", - ), - IO.Int.Input( - "width", - default=1024, - min=256, - max=1440, - step=32, - ), - IO.Int.Input( - "height", - default=768, - min=256, - max=1440, - step=32, - ), - IO.Int.Input( - "seed", - default=0, - min=0, - max=0xFFFFFFFFFFFFFFFF, - control_after_generate=True, - tooltip="The random seed used for creating the noise.", - ), - IO.Image.Input( - "image_prompt", - optional=True, - ), - # "image_prompt_strength": ( - # IO.FLOAT, - # { - # "default": 0.1, - # "min": 0.0, - # "max": 1.0, - # "step": 0.01, - # "tooltip": "Blend between the prompt and the image prompt.", - # }, - # ), - ], - outputs=[IO.Image.Output()], - hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, - IO.Hidden.unique_id, - ], - is_api_node=True, - ) - - @classmethod - async def execute( - cls, - prompt: str, - prompt_upsampling, - width: int, - height: int, - seed=0, - image_prompt=None, - # image_prompt_strength=0.1, - ) -> IO.NodeOutput: - image_prompt = image_prompt if image_prompt is None else tensor_to_base64_string(image_prompt) - initial_response = await sync_op( - cls, - ApiEndpoint( - path="/proxy/bfl/flux-pro-1.1/generate", - method="POST", - ), - response_model=BFLFluxProGenerateResponse, - data=BFLFluxProGenerateRequest( - prompt=prompt, - prompt_upsampling=prompt_upsampling, - width=width, - height=height, - seed=seed, - image_prompt=image_prompt, - ), - ) - response = await poll_op( - cls, - ApiEndpoint(initial_response.polling_url), - response_model=BFLFluxStatusResponse, - status_extractor=lambda r: r.status, - progress_extractor=lambda r: r.progress, - completed_statuses=[BFLStatus.ready], - failed_statuses=[ - BFLStatus.request_moderated, - BFLStatus.content_moderated, - BFLStatus.error, - BFLStatus.task_not_found, - ], - queued_statuses=[], - ) - return IO.NodeOutput(await download_url_to_image_tensor(response.result["sample"])) - - class FluxProExpandNode(IO.ComfyNode): """ Outpaints image based on prompt. @@ -640,16 +523,125 @@ class FluxProFillNode(IO.ComfyNode): return IO.NodeOutput(await download_url_to_image_tensor(response.result["sample"])) +class Flux2ProImageNode(IO.ComfyNode): + + @classmethod + def define_schema(cls) -> IO.Schema: + return IO.Schema( + node_id="Flux2ProImageNode", + display_name="Flux.2 [pro] Image", + category="api node/image/BFL", + description="Generates images synchronously based on prompt and resolution.", + inputs=[ + IO.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Prompt for the image generation or edit", + ), + IO.Int.Input( + "width", + default=1024, + min=256, + max=2048, + step=32, + ), + IO.Int.Input( + "height", + default=768, + min=256, + max=2048, + step=32, + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=0xFFFFFFFFFFFFFFFF, + control_after_generate=True, + tooltip="The random seed used for creating the noise.", + ), + IO.Boolean.Input( + "prompt_upsampling", + default=False, + tooltip="Whether to perform upsampling on the prompt. " + "If active, automatically modifies the prompt for more creative generation, " + "but results are nondeterministic (same seed will not produce exactly the same result).", + ), + IO.Image.Input("images", optional=True, tooltip="Up to 4 images to be used as references."), + ], + outputs=[IO.Image.Output()], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, + prompt: str, + width: int, + height: int, + seed: int, + prompt_upsampling: bool, + images: torch.Tensor | None = None, + ) -> IO.NodeOutput: + reference_images = {} + if images is not None: + if get_number_of_images(images) > 9: + raise ValueError("The current maximum number of supported images is 9.") + for image_index in range(images.shape[0]): + key_name = f"input_image_{image_index + 1}" if image_index else "input_image" + reference_images[key_name] = tensor_to_base64_string(images[image_index], total_pixels=2048 * 2048) + initial_response = await sync_op( + cls, + ApiEndpoint(path="/proxy/bfl/flux-2-pro/generate", method="POST"), + response_model=BFLFluxProGenerateResponse, + data=Flux2ProGenerateRequest( + prompt=prompt, + width=width, + height=height, + seed=seed, + prompt_upsampling=prompt_upsampling, + **reference_images, + ), + ) + + def price_extractor(_r: BaseModel) -> float | None: + return None if initial_response.cost is None else initial_response.cost / 100 + + response = await poll_op( + cls, + ApiEndpoint(initial_response.polling_url), + response_model=BFLFluxStatusResponse, + status_extractor=lambda r: r.status, + progress_extractor=lambda r: r.progress, + price_extractor=price_extractor, + completed_statuses=[BFLStatus.ready], + failed_statuses=[ + BFLStatus.request_moderated, + BFLStatus.content_moderated, + BFLStatus.error, + BFLStatus.task_not_found, + ], + queued_statuses=[], + ) + return IO.NodeOutput(await download_url_to_image_tensor(response.result["sample"])) + + class BFLExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[IO.ComfyNode]]: return [ FluxProUltraImageNode, - # FluxProImageNode, FluxKontextProImageNode, FluxKontextMaxImageNode, FluxProExpandNode, FluxProFillNode, + Flux2ProImageNode, ] diff --git a/comfy_api_nodes/util/__init__.py b/comfy_api_nodes/util/__init__.py index 21013b591..80292fb3c 100644 --- a/comfy_api_nodes/util/__init__.py +++ b/comfy_api_nodes/util/__init__.py @@ -36,6 +36,7 @@ from .upload_helpers import ( upload_video_to_comfyapi, ) from .validation_utils import ( + get_image_dimensions, get_number_of_images, validate_aspect_ratio_string, validate_audio_duration, @@ -82,6 +83,7 @@ __all__ = [ "trim_video", "video_to_base64_string", # Validation utilities + "get_image_dimensions", "get_number_of_images", "validate_aspect_ratio_string", "validate_audio_duration", From af81cb962d9dd283ddb551962cc223b5a186a1ce Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 25 Nov 2025 08:40:32 -0800 Subject: [PATCH 139/231] Add Flux 2 support to README. (#10882) --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index 28beec427..b9300ab07 100644 --- a/README.md +++ b/README.md @@ -67,6 +67,7 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith - [HiDream](https://comfyanonymous.github.io/ComfyUI_examples/hidream/) - [Qwen Image](https://comfyanonymous.github.io/ComfyUI_examples/qwen_image/) - [Hunyuan Image 2.1](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_image/) + - [Flux 2](https://comfyanonymous.github.io/ComfyUI_examples/flux2/) - Image Editing Models - [Omnigen 2](https://comfyanonymous.github.io/ComfyUI_examples/omnigen/) - [Flux Kontext](https://comfyanonymous.github.io/ComfyUI_examples/flux/#flux-kontext-image-editing-model) From 828b1b9953175b6df79459f417d1032869d0b46a Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 25 Nov 2025 12:40:58 -0500 Subject: [PATCH 140/231] ComfyUI version v0.3.72 --- comfyui_version.py | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/comfyui_version.py b/comfyui_version.py index b4655d553..dac038c26 100644 --- a/comfyui_version.py +++ b/comfyui_version.py @@ -1,3 +1,3 @@ # This file is automatically generated by the build process when version is # updated in pyproject.toml. -__version__ = "0.3.71" +__version__ = "0.3.72" diff --git a/pyproject.toml b/pyproject.toml index 280dbaf53..75df8fb7c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ComfyUI" -version = "0.3.71" +version = "0.3.72" readme = "README.md" license = { file = "LICENSE" } requires-python = ">=3.9" From dff996ca39d86265bbabf15e666484e051f0b3d5 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 25 Nov 2025 11:30:24 -0800 Subject: [PATCH 141/231] Fix crash. (#10885) --- comfy/text_encoders/flux.py | 1 + 1 file changed, 1 insertion(+) diff --git a/comfy/text_encoders/flux.py b/comfy/text_encoders/flux.py index 8dbbca16e..024504a5b 100644 --- a/comfy/text_encoders/flux.py +++ b/comfy/text_encoders/flux.py @@ -87,6 +87,7 @@ def load_mistral_tokenizer(data): vocab = {} max_vocab = mistral_vocab["config"]["default_vocab_size"] + max_vocab -= len(mistral_vocab["special_tokens"]) for w in mistral_vocab["vocab"]: r = w["rank"] From 18b79acba95d44b4ea00bbbfc1856bc71bd58841 Mon Sep 17 00:00:00 2001 From: ComfyUI Wiki Date: Wed, 26 Nov 2025 03:58:21 +0800 Subject: [PATCH 142/231] Update workflow templates to v0.7.20 (#10883) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index b7014f956..5f20816d6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.30.6 -comfyui-workflow-templates==0.7.9 +comfyui-workflow-templates==0.7.20 comfyui-embedded-docs==0.3.1 torch torchsde From d196a905bb379a6d800d0c13f9b4fdea3965311a Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 25 Nov 2025 11:58:39 -0800 Subject: [PATCH 143/231] Lower vram usage for flux 2 text encoder. (#10887) --- comfy/sd1_clip.py | 7 ++++--- comfy/text_encoders/flux.py | 4 ++-- comfy/text_encoders/llama.py | 12 +++++++++--- 3 files changed, 15 insertions(+), 8 deletions(-) diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index 8f509bab1..0fc9ab3db 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -90,7 +90,6 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=True, enable_attention_masks=False, zero_out_masked=False, return_projected_pooled=True, return_attention_masks=False, model_options={}): # clip-vit-base-patch32 super().__init__() - assert layer in self.LAYERS if textmodel_json_config is None: textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_clip_config.json") @@ -164,7 +163,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): def set_clip_options(self, options): layer_idx = options.get("layer", self.layer_idx) self.return_projected_pooled = options.get("projected_pooled", self.return_projected_pooled) - if self.layer == "all": + if isinstance(self.layer, list) or self.layer == "all": pass elif layer_idx is None or abs(layer_idx) > self.num_layers: self.layer = "last" @@ -266,7 +265,9 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): if self.enable_attention_masks: attention_mask_model = attention_mask - if self.layer == "all": + if isinstance(self.layer, list): + intermediate_output = self.layer + elif self.layer == "all": intermediate_output = "all" else: intermediate_output = self.layer_idx diff --git a/comfy/text_encoders/flux.py b/comfy/text_encoders/flux.py index 024504a5b..99f4812bb 100644 --- a/comfy/text_encoders/flux.py +++ b/comfy/text_encoders/flux.py @@ -138,7 +138,7 @@ class Flux2Tokenizer(sd1_clip.SD1Tokenizer): return tokens class Mistral3_24BModel(sd1_clip.SDClipModel): - def __init__(self, device="cpu", layer="all", layer_idx=None, dtype=None, attention_mask=True, model_options={}): + def __init__(self, device="cpu", layer=[10, 20, 30], layer_idx=None, dtype=None, attention_mask=True, model_options={}): textmodel_json_config = {} num_layers = model_options.get("num_layers", None) if num_layers is not None: @@ -154,7 +154,7 @@ class Flux2TEModel(sd1_clip.SD1ClipModel): def encode_token_weights(self, token_weight_pairs): out, pooled, extra = super().encode_token_weights(token_weight_pairs) - out = torch.stack((out[:, 10], out[:, 20], out[:, 30]), dim=1) + out = torch.stack((out[:, 0], out[:, 1], out[:, 2]), dim=1) out = out.movedim(1, 2) out = out.reshape(out.shape[0], out.shape[1], -1) return out, pooled, extra diff --git a/comfy/text_encoders/llama.py b/comfy/text_encoders/llama.py index 749ff581b..d47ed27bc 100644 --- a/comfy/text_encoders/llama.py +++ b/comfy/text_encoders/llama.py @@ -434,8 +434,12 @@ class Llama2_(nn.Module): intermediate = None all_intermediate = None + only_layers = None if intermediate_output is not None: - if intermediate_output == "all": + if isinstance(intermediate_output, list): + all_intermediate = [] + only_layers = set(intermediate_output) + elif intermediate_output == "all": all_intermediate = [] intermediate_output = None elif intermediate_output < 0: @@ -443,7 +447,8 @@ class Llama2_(nn.Module): for i, layer in enumerate(self.layers): if all_intermediate is not None: - all_intermediate.append(x.unsqueeze(1).clone()) + if only_layers is None or (i in only_layers): + all_intermediate.append(x.unsqueeze(1).clone()) x = layer( x=x, attention_mask=mask, @@ -457,7 +462,8 @@ class Llama2_(nn.Module): x = self.norm(x) if all_intermediate is not None: - all_intermediate.append(x.unsqueeze(1).clone()) + if only_layers is None or ((i + 1) in only_layers): + all_intermediate.append(x.unsqueeze(1).clone()) if all_intermediate is not None: intermediate = torch.cat(all_intermediate, dim=1) From 0c18842acbdf546883b08808dd9feea7605d7649 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 25 Nov 2025 14:59:37 -0500 Subject: [PATCH 144/231] ComfyUI v0.3.73 --- comfyui_version.py | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/comfyui_version.py b/comfyui_version.py index dac038c26..f8818838e 100644 --- a/comfyui_version.py +++ b/comfyui_version.py @@ -1,3 +1,3 @@ # This file is automatically generated by the build process when version is # updated in pyproject.toml. -__version__ = "0.3.72" +__version__ = "0.3.73" diff --git a/pyproject.toml b/pyproject.toml index 75df8fb7c..7e4bac12d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ComfyUI" -version = "0.3.72" +version = "0.3.73" readme = "README.md" license = { file = "LICENSE" } requires-python = ">=3.9" From e9aae31fa241a6a63a368800146ea91629d4e8c2 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 25 Nov 2025 15:41:45 -0800 Subject: [PATCH 145/231] Z Image model. (#10892) --- comfy/ldm/lumina/model.py | 219 +++++++------------- comfy/ldm/modules/diffusionmodules/mmdit.py | 6 +- comfy/model_base.py | 4 + comfy/model_detection.py | 29 ++- comfy/sd.py | 8 + comfy/text_encoders/llama.py | 31 +++ comfy/text_encoders/z_image.py | 48 +++++ 7 files changed, 196 insertions(+), 149 deletions(-) create mode 100644 comfy/text_encoders/z_image.py diff --git a/comfy/ldm/lumina/model.py b/comfy/ldm/lumina/model.py index b4494a51d..c8643eb82 100644 --- a/comfy/ldm/lumina/model.py +++ b/comfy/ldm/lumina/model.py @@ -11,6 +11,7 @@ import comfy.ldm.common_dit from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder from comfy.ldm.modules.attention import optimized_attention_masked from comfy.ldm.flux.layers import EmbedND +from comfy.ldm.flux.math import apply_rope import comfy.patcher_extension @@ -31,6 +32,7 @@ class JointAttention(nn.Module): n_heads: int, n_kv_heads: Optional[int], qk_norm: bool, + out_bias: bool = False, operation_settings={}, ): """ @@ -59,7 +61,7 @@ class JointAttention(nn.Module): self.out = operation_settings.get("operations").Linear( n_heads * self.head_dim, dim, - bias=False, + bias=out_bias, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"), ) @@ -70,35 +72,6 @@ class JointAttention(nn.Module): else: self.q_norm = self.k_norm = nn.Identity() - @staticmethod - def apply_rotary_emb( - x_in: torch.Tensor, - freqs_cis: torch.Tensor, - ) -> torch.Tensor: - """ - Apply rotary embeddings to input tensors using the given frequency - tensor. - - This function applies rotary embeddings to the given query 'xq' and - key 'xk' tensors using the provided frequency tensor 'freqs_cis'. The - input tensors are reshaped as complex numbers, and the frequency tensor - is reshaped for broadcasting compatibility. The resulting tensors - contain rotary embeddings and are returned as real tensors. - - Args: - x_in (torch.Tensor): Query or Key tensor to apply rotary embeddings. - freqs_cis (torch.Tensor): Precomputed frequency tensor for complex - exponentials. - - Returns: - Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor - and key tensor with rotary embeddings. - """ - - t_ = x_in.reshape(*x_in.shape[:-1], -1, 1, 2) - t_out = freqs_cis[..., 0] * t_[..., 0] + freqs_cis[..., 1] * t_[..., 1] - return t_out.reshape(*x_in.shape) - def forward( self, x: torch.Tensor, @@ -134,8 +107,7 @@ class JointAttention(nn.Module): xq = self.q_norm(xq) xk = self.k_norm(xk) - xq = JointAttention.apply_rotary_emb(xq, freqs_cis=freqs_cis) - xk = JointAttention.apply_rotary_emb(xk, freqs_cis=freqs_cis) + xq, xk = apply_rope(xq, xk, freqs_cis) n_rep = self.n_local_heads // self.n_local_kv_heads if n_rep >= 1: @@ -215,6 +187,8 @@ class JointTransformerBlock(nn.Module): norm_eps: float, qk_norm: bool, modulation=True, + z_image_modulation=False, + attn_out_bias=False, operation_settings={}, ) -> None: """ @@ -235,10 +209,10 @@ class JointTransformerBlock(nn.Module): super().__init__() self.dim = dim self.head_dim = dim // n_heads - self.attention = JointAttention(dim, n_heads, n_kv_heads, qk_norm, operation_settings=operation_settings) + self.attention = JointAttention(dim, n_heads, n_kv_heads, qk_norm, out_bias=attn_out_bias, operation_settings=operation_settings) self.feed_forward = FeedForward( dim=dim, - hidden_dim=4 * dim, + hidden_dim=dim, multiple_of=multiple_of, ffn_dim_multiplier=ffn_dim_multiplier, operation_settings=operation_settings, @@ -252,16 +226,27 @@ class JointTransformerBlock(nn.Module): self.modulation = modulation if modulation: - self.adaLN_modulation = nn.Sequential( - nn.SiLU(), - operation_settings.get("operations").Linear( - min(dim, 1024), - 4 * dim, - bias=True, - device=operation_settings.get("device"), - dtype=operation_settings.get("dtype"), - ), - ) + if z_image_modulation: + self.adaLN_modulation = nn.Sequential( + operation_settings.get("operations").Linear( + min(dim, 256), + 4 * dim, + bias=True, + device=operation_settings.get("device"), + dtype=operation_settings.get("dtype"), + ), + ) + else: + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + operation_settings.get("operations").Linear( + min(dim, 1024), + 4 * dim, + bias=True, + device=operation_settings.get("device"), + dtype=operation_settings.get("dtype"), + ), + ) def forward( self, @@ -323,7 +308,7 @@ class FinalLayer(nn.Module): The final layer of NextDiT. """ - def __init__(self, hidden_size, patch_size, out_channels, operation_settings={}): + def __init__(self, hidden_size, patch_size, out_channels, z_image_modulation=False, operation_settings={}): super().__init__() self.norm_final = operation_settings.get("operations").LayerNorm( hidden_size, @@ -340,10 +325,15 @@ class FinalLayer(nn.Module): dtype=operation_settings.get("dtype"), ) + if z_image_modulation: + min_mod = 256 + else: + min_mod = 1024 + self.adaLN_modulation = nn.Sequential( nn.SiLU(), operation_settings.get("operations").Linear( - min(hidden_size, 1024), + min(hidden_size, min_mod), hidden_size, bias=True, device=operation_settings.get("device"), @@ -373,12 +363,16 @@ class NextDiT(nn.Module): n_heads: int = 32, n_kv_heads: Optional[int] = None, multiple_of: int = 256, - ffn_dim_multiplier: Optional[float] = None, + ffn_dim_multiplier: float = 4.0, norm_eps: float = 1e-5, qk_norm: bool = False, cap_feat_dim: int = 5120, axes_dims: List[int] = (16, 56, 56), axes_lens: List[int] = (1, 512, 512), + rope_theta=10000.0, + z_image_modulation=False, + time_scale=1.0, + pad_tokens_multiple=None, image_model=None, device=None, dtype=None, @@ -390,6 +384,8 @@ class NextDiT(nn.Module): self.in_channels = in_channels self.out_channels = in_channels self.patch_size = patch_size + self.time_scale = time_scale + self.pad_tokens_multiple = pad_tokens_multiple self.x_embedder = operation_settings.get("operations").Linear( in_features=patch_size * patch_size * in_channels, @@ -411,6 +407,7 @@ class NextDiT(nn.Module): norm_eps, qk_norm, modulation=True, + z_image_modulation=z_image_modulation, operation_settings=operation_settings, ) for layer_id in range(n_refiner_layers) @@ -434,7 +431,7 @@ class NextDiT(nn.Module): ] ) - self.t_embedder = TimestepEmbedder(min(dim, 1024), **operation_settings) + self.t_embedder = TimestepEmbedder(min(dim, 1024), output_size=256 if z_image_modulation else None, **operation_settings) self.cap_embedder = nn.Sequential( operation_settings.get("operations").RMSNorm(cap_feat_dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")), operation_settings.get("operations").Linear( @@ -457,18 +454,24 @@ class NextDiT(nn.Module): ffn_dim_multiplier, norm_eps, qk_norm, + z_image_modulation=z_image_modulation, + attn_out_bias=False, operation_settings=operation_settings, ) for layer_id in range(n_layers) ] ) self.norm_final = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) - self.final_layer = FinalLayer(dim, patch_size, self.out_channels, operation_settings=operation_settings) + self.final_layer = FinalLayer(dim, patch_size, self.out_channels, z_image_modulation=z_image_modulation, operation_settings=operation_settings) + + if self.pad_tokens_multiple is not None: + self.x_pad_token = nn.Parameter(torch.empty((1, dim), device=device, dtype=dtype)) + self.cap_pad_token = nn.Parameter(torch.empty((1, dim), device=device, dtype=dtype)) assert (dim // n_heads) == sum(axes_dims) self.axes_dims = axes_dims self.axes_lens = axes_lens - self.rope_embedder = EmbedND(dim=dim // n_heads, theta=10000.0, axes_dim=axes_dims) + self.rope_embedder = EmbedND(dim=dim // n_heads, theta=rope_theta, axes_dim=axes_dims) self.dim = dim self.n_heads = n_heads @@ -503,108 +506,42 @@ class NextDiT(nn.Module): bsz = len(x) pH = pW = self.patch_size device = x[0].device - dtype = x[0].dtype - if cap_mask is not None: - l_effective_cap_len = cap_mask.sum(dim=1).tolist() - else: - l_effective_cap_len = [num_tokens] * bsz + if self.pad_tokens_multiple is not None: + pad_extra = (-cap_feats.shape[1]) % self.pad_tokens_multiple + cap_feats = torch.cat((cap_feats, self.cap_pad_token.to(device=cap_feats.device, dtype=cap_feats.dtype).unsqueeze(0).repeat(cap_feats.shape[0], pad_extra, 1)), dim=1) - if cap_mask is not None and not torch.is_floating_point(cap_mask): - cap_mask = (cap_mask - 1).to(dtype) * torch.finfo(dtype).max + cap_pos_ids = torch.zeros(bsz, cap_feats.shape[1], 3, dtype=torch.float32, device=device) + cap_pos_ids[:, :, 0] = torch.arange(cap_feats.shape[1], dtype=torch.float32, device=device) + 1.0 - img_sizes = [(img.size(1), img.size(2)) for img in x] - l_effective_img_len = [(H // pH) * (W // pW) for (H, W) in img_sizes] + B, C, H, W = x.shape + x = self.x_embedder(x.view(B, C, H // pH, pH, W // pW, pW).permute(0, 2, 4, 3, 5, 1).flatten(3).flatten(1, 2)) - max_seq_len = max( - (cap_len+img_len for cap_len, img_len in zip(l_effective_cap_len, l_effective_img_len)) - ) - max_cap_len = max(l_effective_cap_len) - max_img_len = max(l_effective_img_len) + H_tokens, W_tokens = H // pH, W // pW + x_pos_ids = torch.zeros((bsz, x.shape[1], 3), dtype=torch.float32, device=device) + x_pos_ids[:, :, 0] = cap_feats.shape[1] + 1 + x_pos_ids[:, :, 1] = torch.arange(H_tokens, dtype=torch.float32, device=device).view(-1, 1).repeat(1, W_tokens).flatten() + x_pos_ids[:, :, 2] = torch.arange(W_tokens, dtype=torch.float32, device=device).view(1, -1).repeat(H_tokens, 1).flatten() - position_ids = torch.zeros(bsz, max_seq_len, 3, dtype=torch.float32, device=device) + if self.pad_tokens_multiple is not None: + pad_extra = (-x.shape[1]) % self.pad_tokens_multiple + x = torch.cat((x, self.x_pad_token.to(device=x.device, dtype=x.dtype).unsqueeze(0).repeat(x.shape[0], pad_extra, 1)), dim=1) + x_pos_ids = torch.nn.functional.pad(x_pos_ids, (0, 0, 0, pad_extra)) - for i in range(bsz): - cap_len = l_effective_cap_len[i] - img_len = l_effective_img_len[i] - H, W = img_sizes[i] - H_tokens, W_tokens = H // pH, W // pW - assert H_tokens * W_tokens == img_len - - rope_options = transformer_options.get("rope_options", None) - h_scale = 1.0 - w_scale = 1.0 - h_start = 0 - w_start = 0 - if rope_options is not None: - h_scale = rope_options.get("scale_y", 1.0) - w_scale = rope_options.get("scale_x", 1.0) - - h_start = rope_options.get("shift_y", 0.0) - w_start = rope_options.get("shift_x", 0.0) - - position_ids[i, :cap_len, 0] = torch.arange(cap_len, dtype=torch.float32, device=device) - position_ids[i, cap_len:cap_len+img_len, 0] = cap_len - row_ids = (torch.arange(H_tokens, dtype=torch.float32, device=device) * h_scale + h_start).view(-1, 1).repeat(1, W_tokens).flatten() - col_ids = (torch.arange(W_tokens, dtype=torch.float32, device=device) * w_scale + w_start).view(1, -1).repeat(H_tokens, 1).flatten() - position_ids[i, cap_len:cap_len+img_len, 1] = row_ids - position_ids[i, cap_len:cap_len+img_len, 2] = col_ids - - freqs_cis = self.rope_embedder(position_ids).movedim(1, 2).to(dtype) - - # build freqs_cis for cap and image individually - cap_freqs_cis_shape = list(freqs_cis.shape) - # cap_freqs_cis_shape[1] = max_cap_len - cap_freqs_cis_shape[1] = cap_feats.shape[1] - cap_freqs_cis = torch.zeros(*cap_freqs_cis_shape, device=device, dtype=freqs_cis.dtype) - - img_freqs_cis_shape = list(freqs_cis.shape) - img_freqs_cis_shape[1] = max_img_len - img_freqs_cis = torch.zeros(*img_freqs_cis_shape, device=device, dtype=freqs_cis.dtype) - - for i in range(bsz): - cap_len = l_effective_cap_len[i] - img_len = l_effective_img_len[i] - cap_freqs_cis[i, :cap_len] = freqs_cis[i, :cap_len] - img_freqs_cis[i, :img_len] = freqs_cis[i, cap_len:cap_len+img_len] + freqs_cis = self.rope_embedder(torch.cat((cap_pos_ids, x_pos_ids), dim=1)).movedim(1, 2) # refine context for layer in self.context_refiner: - cap_feats = layer(cap_feats, cap_mask, cap_freqs_cis, transformer_options=transformer_options) + cap_feats = layer(cap_feats, cap_mask, freqs_cis[:, :cap_pos_ids.shape[1]], transformer_options=transformer_options) - # refine image - flat_x = [] - for i in range(bsz): - img = x[i] - C, H, W = img.size() - img = img.view(C, H // pH, pH, W // pW, pW).permute(1, 3, 2, 4, 0).flatten(2).flatten(0, 1) - flat_x.append(img) - x = flat_x - padded_img_embed = torch.zeros(bsz, max_img_len, x[0].shape[-1], device=device, dtype=x[0].dtype) - padded_img_mask = torch.zeros(bsz, max_img_len, dtype=dtype, device=device) - for i in range(bsz): - padded_img_embed[i, :l_effective_img_len[i]] = x[i] - padded_img_mask[i, l_effective_img_len[i]:] = -torch.finfo(dtype).max - - padded_img_embed = self.x_embedder(padded_img_embed) - padded_img_mask = padded_img_mask.unsqueeze(1) + padded_img_mask = None for layer in self.noise_refiner: - padded_img_embed = layer(padded_img_embed, padded_img_mask, img_freqs_cis, t, transformer_options=transformer_options) - - if cap_mask is not None: - mask = torch.zeros(bsz, max_seq_len, dtype=dtype, device=device) - mask[:, :max_cap_len] = cap_mask[:, :max_cap_len] - else: - mask = None - - padded_full_embed = torch.zeros(bsz, max_seq_len, self.dim, device=device, dtype=x[0].dtype) - for i in range(bsz): - cap_len = l_effective_cap_len[i] - img_len = l_effective_img_len[i] - - padded_full_embed[i, :cap_len] = cap_feats[i, :cap_len] - padded_full_embed[i, cap_len:cap_len+img_len] = padded_img_embed[i, :img_len] + x = layer(x, padded_img_mask, freqs_cis[:, cap_pos_ids.shape[1]:], t, transformer_options=transformer_options) + padded_full_embed = torch.cat((cap_feats, x), dim=1) + mask = None + img_sizes = [(H, W)] * bsz + l_effective_cap_len = [cap_feats.shape[1]] * bsz return padded_full_embed, mask, img_sizes, l_effective_cap_len, freqs_cis def forward(self, x, timesteps, context, num_tokens, attention_mask=None, **kwargs): @@ -627,7 +564,7 @@ class NextDiT(nn.Module): y: (N,) tensor of text tokens/features """ - t = self.t_embedder(t, dtype=x.dtype) # (N, D) + t = self.t_embedder(t * self.time_scale, dtype=x.dtype) # (N, D) adaln_input = t cap_feats = self.cap_embedder(cap_feats) # (N, L, D) # todo check if able to batchify w.o. redundant compute diff --git a/comfy/ldm/modules/diffusionmodules/mmdit.py b/comfy/ldm/modules/diffusionmodules/mmdit.py index 42f406f1a..0dc8fe789 100644 --- a/comfy/ldm/modules/diffusionmodules/mmdit.py +++ b/comfy/ldm/modules/diffusionmodules/mmdit.py @@ -211,12 +211,14 @@ class TimestepEmbedder(nn.Module): Embeds scalar timesteps into vector representations. """ - def __init__(self, hidden_size, frequency_embedding_size=256, dtype=None, device=None, operations=None): + def __init__(self, hidden_size, frequency_embedding_size=256, output_size=None, dtype=None, device=None, operations=None): super().__init__() + if output_size is None: + output_size = hidden_size self.mlp = nn.Sequential( operations.Linear(frequency_embedding_size, hidden_size, bias=True, dtype=dtype, device=device), nn.SiLU(), - operations.Linear(hidden_size, hidden_size, bias=True, dtype=dtype, device=device), + operations.Linear(hidden_size, output_size, bias=True, dtype=dtype, device=device), ) self.frequency_embedding_size = frequency_embedding_size diff --git a/comfy/model_base.py b/comfy/model_base.py index cad79ecbd..cc21b1de9 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1114,9 +1114,13 @@ class Lumina2(BaseModel): if torch.numel(attention_mask) != attention_mask.sum(): out['attention_mask'] = comfy.conds.CONDRegular(attention_mask) out['num_tokens'] = comfy.conds.CONDConstant(max(1, torch.sum(attention_mask).item())) + cross_attn = kwargs.get("cross_attn", None) if cross_attn is not None: out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) + if 'num_tokens' not in out: + out['num_tokens'] = comfy.conds.CONDConstant(cross_attn.shape[1]) + return out class WAN21(BaseModel): diff --git a/comfy/model_detection.py b/comfy/model_detection.py index b2ba1459d..7afe4a798 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -416,14 +416,31 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["image_model"] = "lumina2" dit_config["patch_size"] = 2 dit_config["in_channels"] = 16 - dit_config["dim"] = 2304 - dit_config["cap_feat_dim"] = state_dict['{}cap_embedder.1.weight'.format(key_prefix)].shape[1] + w = state_dict['{}cap_embedder.1.weight'.format(key_prefix)] + dit_config["dim"] = w.shape[0] + dit_config["cap_feat_dim"] = w.shape[1] dit_config["n_layers"] = count_blocks(state_dict_keys, '{}layers.'.format(key_prefix) + '{}.') - dit_config["n_heads"] = 24 - dit_config["n_kv_heads"] = 8 dit_config["qk_norm"] = True - dit_config["axes_dims"] = [32, 32, 32] - dit_config["axes_lens"] = [300, 512, 512] + + if dit_config["dim"] == 2304: # Original Lumina 2 + dit_config["n_heads"] = 24 + dit_config["n_kv_heads"] = 8 + dit_config["axes_dims"] = [32, 32, 32] + dit_config["axes_lens"] = [300, 512, 512] + dit_config["rope_theta"] = 10000.0 + dit_config["ffn_dim_multiplier"] = 4.0 + elif dit_config["dim"] == 3840: # Z image + dit_config["n_heads"] = 30 + dit_config["n_kv_heads"] = 30 + dit_config["axes_dims"] = [32, 48, 48] + dit_config["axes_lens"] = [1536, 512, 512] + dit_config["rope_theta"] = 256.0 + dit_config["ffn_dim_multiplier"] = (8.0 / 3.0) + dit_config["z_image_modulation"] = True + dit_config["time_scale"] = 1000.0 + if '{}cap_pad_token'.format(key_prefix) in state_dict_keys: + dit_config["pad_tokens_multiple"] = 32 + return dit_config if '{}head.modulation'.format(key_prefix) in state_dict_keys: # Wan 2.1 diff --git a/comfy/sd.py b/comfy/sd.py index 14dd8944c..350fae92b 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -52,6 +52,7 @@ import comfy.text_encoders.ace import comfy.text_encoders.omnigen2 import comfy.text_encoders.qwen_image import comfy.text_encoders.hunyuan_image +import comfy.text_encoders.z_image import comfy.model_patcher import comfy.lora @@ -953,6 +954,8 @@ class TEModel(Enum): GEMMA_3_4B = 13 MISTRAL3_24B = 14 MISTRAL3_24B_PRUNED_FLUX2 = 15 + QWEN3_4B = 16 + def detect_te_model(sd): if "text_model.encoder.layers.30.mlp.fc1.weight" in sd: @@ -985,6 +988,8 @@ def detect_te_model(sd): if weight.shape[0] == 512: return TEModel.QWEN25_7B if "model.layers.0.post_attention_layernorm.weight" in sd: + if 'model.layers.0.self_attn.q_norm.weight' in sd: + return TEModel.QWEN3_4B weight = sd['model.layers.0.post_attention_layernorm.weight'] if weight.shape[0] == 5120: if "model.layers.39.post_attention_layernorm.weight" in sd: @@ -1110,6 +1115,9 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip clip_target.clip = comfy.text_encoders.flux.flux2_te(**llama_detect(clip_data), pruned=te_model == TEModel.MISTRAL3_24B_PRUNED_FLUX2) clip_target.tokenizer = comfy.text_encoders.flux.Flux2Tokenizer tokenizer_data["tekken_model"] = clip_data[0].get("tekken_model", None) + elif te_model == TEModel.QWEN3_4B: + clip_target.clip = comfy.text_encoders.z_image.te(**llama_detect(clip_data)) + clip_target.tokenizer = comfy.text_encoders.z_image.ZImageTokenizer else: # clip_l if clip_type == CLIPType.SD3: diff --git a/comfy/text_encoders/llama.py b/comfy/text_encoders/llama.py index d47ed27bc..cd4b5f76c 100644 --- a/comfy/text_encoders/llama.py +++ b/comfy/text_encoders/llama.py @@ -78,6 +78,28 @@ class Qwen25_3BConfig: rope_scale = None final_norm: bool = True +@dataclass +class Qwen3_4BConfig: + vocab_size: int = 151936 + hidden_size: int = 2560 + intermediate_size: int = 9728 + num_hidden_layers: int = 36 + num_attention_heads: int = 32 + num_key_value_heads: int = 8 + max_position_embeddings: int = 40960 + rms_norm_eps: float = 1e-6 + rope_theta: float = 1000000.0 + transformer_type: str = "llama" + head_dim = 128 + rms_norm_add = False + mlp_activation = "silu" + qkv_bias = False + rope_dims = None + q_norm = "gemma3" + k_norm = "gemma3" + rope_scale = None + final_norm: bool = True + @dataclass class Qwen25_7BVLI_Config: vocab_size: int = 152064 @@ -511,6 +533,15 @@ class Qwen25_3B(BaseLlama, torch.nn.Module): self.model = Llama2_(config, device=device, dtype=dtype, ops=operations) self.dtype = dtype +class Qwen3_4B(BaseLlama, torch.nn.Module): + def __init__(self, config_dict, dtype, device, operations): + super().__init__() + config = Qwen3_4BConfig(**config_dict) + self.num_layers = config.num_hidden_layers + + self.model = Llama2_(config, device=device, dtype=dtype, ops=operations) + self.dtype = dtype + class Qwen25_7BVLI(BaseLlama, torch.nn.Module): def __init__(self, config_dict, dtype, device, operations): super().__init__() diff --git a/comfy/text_encoders/z_image.py b/comfy/text_encoders/z_image.py new file mode 100644 index 000000000..bb9273b20 --- /dev/null +++ b/comfy/text_encoders/z_image.py @@ -0,0 +1,48 @@ +from transformers import Qwen2Tokenizer +import comfy.text_encoders.llama +from comfy import sd1_clip +import os + +class Qwen3Tokenizer(sd1_clip.SDTokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "qwen25_tokenizer") + super().__init__(tokenizer_path, pad_with_end=False, embedding_size=2560, embedding_key='qwen3_4b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, pad_token=151643, tokenizer_data=tokenizer_data) + + +class ZImageTokenizer(sd1_clip.SD1Tokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="qwen3_4b", tokenizer=Qwen3Tokenizer) + self.llama_template = "<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n" + + def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, **kwargs): + if llama_template is None: + llama_text = self.llama_template.format(text) + else: + llama_text = llama_template.format(text) + + tokens = super().tokenize_with_weights(llama_text, return_word_ids=return_word_ids, disable_weights=True, **kwargs) + return tokens + + +class Qwen3_4BModel(sd1_clip.SDClipModel): + def __init__(self, device="cpu", layer="hidden", layer_idx=-2, dtype=None, attention_mask=True, model_options={}): + super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Qwen3_4B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options) + + +class ZImageTEModel(sd1_clip.SD1ClipModel): + def __init__(self, device="cpu", dtype=None, model_options={}): + super().__init__(device=device, dtype=dtype, name="qwen3_4b", clip_model=Qwen3_4BModel, model_options=model_options) + + +def te(dtype_llama=None, llama_scaled_fp8=None, llama_quantization_metadata=None): + class ZImageTEModel_(ZImageTEModel): + def __init__(self, device="cpu", dtype=None, model_options={}): + if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options: + model_options = model_options.copy() + model_options["scaled_fp8"] = llama_scaled_fp8 + if dtype_llama is not None: + dtype = dtype_llama + if llama_quantization_metadata is not None: + model_options["quantization_metadata"] = llama_quantization_metadata + super().__init__(device=device, dtype=dtype, model_options=model_options) + return ZImageTEModel_ From 0e24dbb19f34f242edb77c550396cf6806f7b22f Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 25 Nov 2025 16:02:51 -0800 Subject: [PATCH 146/231] Adjustments to Z Image. (#10893) --- comfy/supported_models.py | 23 +++++++++++++++++++++-- 1 file changed, 21 insertions(+), 2 deletions(-) diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 8fe8e63f6..af8120400 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -21,6 +21,7 @@ import comfy.text_encoders.ace import comfy.text_encoders.omnigen2 import comfy.text_encoders.qwen_image import comfy.text_encoders.hunyuan_image +import comfy.text_encoders.z_image from . import supported_models_base from . import latent_formats @@ -994,7 +995,7 @@ class Lumina2(supported_models_base.BASE): "shift": 6.0, } - memory_usage_factor = 1.2 + memory_usage_factor = 1.4 unet_extra_config = {} latent_format = latent_formats.Flux @@ -1013,6 +1014,24 @@ class Lumina2(supported_models_base.BASE): hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}gemma2_2b.transformer.".format(pref)) return supported_models_base.ClipTarget(comfy.text_encoders.lumina2.LuminaTokenizer, comfy.text_encoders.lumina2.te(**hunyuan_detect)) +class ZImage(Lumina2): + unet_config = { + "image_model": "lumina2", + "dim": 3840, + } + + sampling_settings = { + "multiplier": 1.0, + "shift": 3.0, + } + + memory_usage_factor = 1.7 + + def clip_target(self, state_dict={}): + pref = self.text_encoder_key_prefix[0] + hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen3_4b.transformer.".format(pref)) + return supported_models_base.ClipTarget(comfy.text_encoders.z_image.ZImageTokenizer, comfy.text_encoders.z_image.te(**hunyuan_detect)) + class WAN21_T2V(supported_models_base.BASE): unet_config = { "image_model": "wan2.1", @@ -1453,7 +1472,7 @@ class HunyuanVideo15_SR_Distilled(HunyuanVideo): hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref)) return supported_models_base.ClipTarget(comfy.text_encoders.hunyuan_video.HunyuanVideo15Tokenizer, comfy.text_encoders.hunyuan_image.te(**hunyuan_detect)) -models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage, Flux2] +models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage, Flux2] models += [SVD_img2vid] From bdb10a583f1b1e495ee00dbd1674f11016a6a93e Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 25 Nov 2025 21:07:58 -0800 Subject: [PATCH 147/231] Fix loras not working on mixed fp8. (#10899) --- comfy/model_patcher.py | 2 +- comfy/ops.py | 22 +++++++++++++++++++++- comfy/quant_ops.py | 21 ++++++++++++++------- comfy/weight_adapter/lora.py | 1 + 4 files changed, 37 insertions(+), 9 deletions(-) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 6551ced5a..73adc7f70 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -132,7 +132,7 @@ class LowVramPatch: def __call__(self, weight): intermediate_dtype = weight.dtype if self.convert_func is not None: - weight = self.convert_func(weight.to(dtype=torch.float32, copy=True), inplace=True) + weight = self.convert_func(weight, inplace=False) if intermediate_dtype not in [torch.float32, torch.float16, torch.bfloat16]: #intermediate_dtype has to be one that is supported in math ops intermediate_dtype = torch.float32 diff --git a/comfy/ops.py b/comfy/ops.py index 785aa1c9f..a0ff4e8f1 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -117,6 +117,8 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of if weight_has_function or weight.dtype != dtype: with wf_context: weight = weight.to(dtype=dtype) + if isinstance(weight, QuantizedTensor): + weight = weight.dequantize() for f in s.weight_function: weight = f(weight) @@ -502,7 +504,7 @@ def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None weight *= self.scale_weight.to(device=weight.device, dtype=weight.dtype) return weight else: - return weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype) + return weight.to(dtype=torch.float32) * self.scale_weight.to(device=weight.device, dtype=torch.float32) def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False, **kwargs): weight = comfy.float.stochastic_rounding(weight / self.scale_weight.to(device=weight.device, dtype=weight.dtype), self.weight.dtype, seed=seed) @@ -643,6 +645,24 @@ def mixed_precision_ops(layer_quant_config={}, compute_dtype=torch.bfloat16, ful not isinstance(input, QuantizedTensor)): input = QuantizedTensor.from_float(input, self.layout_type, scale=self.input_scale, dtype=self.weight.dtype) return self._forward(input, self.weight, self.bias) + + def convert_weight(self, weight, inplace=False, **kwargs): + if isinstance(weight, QuantizedTensor): + return weight.dequantize() + else: + return weight + + def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False, **kwargs): + if getattr(self, 'layout_type', None) is not None: + weight = QuantizedTensor.from_float(weight, self.layout_type, scale=None, dtype=self.weight.dtype, stochastic_rounding=seed, inplace_ops=True) + else: + weight = weight.to(self.weight.dtype) + if return_weight: + return weight + + assert inplace_update is False # TODO: eventually remove the inplace_update stuff + self.weight = torch.nn.Parameter(weight, requires_grad=False) + return MixedPrecisionOps def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None, model_config=None): diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py index 0c16bcf8d..d2f3e7397 100644 --- a/comfy/quant_ops.py +++ b/comfy/quant_ops.py @@ -1,6 +1,7 @@ import torch import logging from typing import Tuple, Dict +import comfy.float _LAYOUT_REGISTRY = {} _GENERIC_UTILS = {} @@ -393,7 +394,7 @@ class TensorCoreFP8Layout(QuantizedLayout): - orig_dtype: Original dtype before quantization (for casting back) """ @classmethod - def quantize(cls, tensor, scale=None, dtype=torch.float8_e4m3fn): + def quantize(cls, tensor, scale=None, dtype=torch.float8_e4m3fn, stochastic_rounding=0, inplace_ops=False): orig_dtype = tensor.dtype if scale is None: @@ -403,17 +404,23 @@ class TensorCoreFP8Layout(QuantizedLayout): scale = torch.tensor(scale) scale = scale.to(device=tensor.device, dtype=torch.float32) - tensor_scaled = tensor * (1.0 / scale).to(tensor.dtype) - # TODO: uncomment this if it's actually needed because the clamp has a small performance penality' - lp_amax = torch.finfo(dtype).max - torch.clamp(tensor_scaled, min=-lp_amax, max=lp_amax, out=tensor_scaled) - qdata = tensor_scaled.to(dtype, memory_format=torch.contiguous_format) + if inplace_ops: + tensor *= (1.0 / scale).to(tensor.dtype) + else: + tensor = tensor * (1.0 / scale).to(tensor.dtype) + + if stochastic_rounding > 0: + tensor = comfy.float.stochastic_rounding(tensor, dtype=dtype, seed=stochastic_rounding) + else: + lp_amax = torch.finfo(dtype).max + torch.clamp(tensor, min=-lp_amax, max=lp_amax, out=tensor) + tensor = tensor.to(dtype, memory_format=torch.contiguous_format) layout_params = { 'scale': scale, 'orig_dtype': orig_dtype } - return qdata, layout_params + return tensor, layout_params @staticmethod def dequantize(qdata, scale, orig_dtype, **kwargs): diff --git a/comfy/weight_adapter/lora.py b/comfy/weight_adapter/lora.py index 4db004e50..3cc60bb1b 100644 --- a/comfy/weight_adapter/lora.py +++ b/comfy/weight_adapter/lora.py @@ -194,6 +194,7 @@ class LoRAAdapter(WeightAdapterBase): lora_diff = torch.mm( mat1.flatten(start_dim=1), mat2.flatten(start_dim=1) ).reshape(weight.shape) + del mat1, mat2 if dora_scale is not None: weight = weight_decompose( dora_scale, From 90b3995ec842335e44d70e0521ff6ff6c3ff9aaa Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 26 Nov 2025 00:34:15 -0500 Subject: [PATCH 148/231] ComfyUI v0.3.74 --- comfyui_version.py | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/comfyui_version.py b/comfyui_version.py index f8818838e..b565c7367 100644 --- a/comfyui_version.py +++ b/comfyui_version.py @@ -1,3 +1,3 @@ # This file is automatically generated by the build process when version is # updated in pyproject.toml. -__version__ = "0.3.73" +__version__ = "0.3.74" diff --git a/pyproject.toml b/pyproject.toml index 7e4bac12d..ccf0fcdb9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ComfyUI" -version = "0.3.73" +version = "0.3.74" readme = "README.md" license = { file = "LICENSE" } requires-python = ">=3.9" From 58b85746618e2bc2dd32024c89403926aad59f48 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 25 Nov 2025 23:36:19 -0800 Subject: [PATCH 149/231] Fix Flux2 reference image mem estimation. (#10905) --- comfy/model_base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/model_base.py b/comfy/model_base.py index cc21b1de9..9b76c285e 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -926,7 +926,7 @@ class Flux(BaseModel): out = {} ref_latents = kwargs.get("reference_latents", None) if ref_latents is not None: - out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()), ref_latents)) // 16]) + out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()[2:]), ref_latents))]) return out class Flux2(Flux): From 8402c8700a29a97bc5d706d6a0b14c41bc2c2d8a Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 26 Nov 2025 02:41:13 -0500 Subject: [PATCH 150/231] ComfyUI version v0.3.75 --- comfyui_version.py | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/comfyui_version.py b/comfyui_version.py index b565c7367..fa4b4f4b0 100644 --- a/comfyui_version.py +++ b/comfyui_version.py @@ -1,3 +1,3 @@ # This file is automatically generated by the build process when version is # updated in pyproject.toml. -__version__ = "0.3.74" +__version__ = "0.3.75" diff --git a/pyproject.toml b/pyproject.toml index ccf0fcdb9..9009e65fe 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ComfyUI" -version = "0.3.74" +version = "0.3.75" readme = "README.md" license = { file = "LICENSE" } requires-python = ">=3.9" From f16219e3aadcb7a301a1a313ab8989c3ebe53764 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 26 Nov 2025 01:00:43 -0800 Subject: [PATCH 151/231] Add cheap latent preview for flux 2. (#10907) Thank you to the person who calculated them. You saved me a percent of my time. --- comfy/latent_formats.py | 40 ++++++++++++++++++++++++++++++++++++++++ latent_preview.py | 7 +++++-- 2 files changed, 45 insertions(+), 2 deletions(-) diff --git a/comfy/latent_formats.py b/comfy/latent_formats.py index e98c7d6d8..8e110f45d 100644 --- a/comfy/latent_formats.py +++ b/comfy/latent_formats.py @@ -6,6 +6,7 @@ class LatentFormat: latent_dimensions = 2 latent_rgb_factors = None latent_rgb_factors_bias = None + latent_rgb_factors_reshape = None taesd_decoder_name = None def process_in(self, latent): @@ -181,6 +182,45 @@ class Flux(SD3): class Flux2(LatentFormat): latent_channels = 128 + def __init__(self): + self.latent_rgb_factors =[ + [0.0058, 0.0113, 0.0073], + [0.0495, 0.0443, 0.0836], + [-0.0099, 0.0096, 0.0644], + [0.2144, 0.3009, 0.3652], + [0.0166, -0.0039, -0.0054], + [0.0157, 0.0103, -0.0160], + [-0.0398, 0.0902, -0.0235], + [-0.0052, 0.0095, 0.0109], + [-0.3527, -0.2712, -0.1666], + [-0.0301, -0.0356, -0.0180], + [-0.0107, 0.0078, 0.0013], + [0.0746, 0.0090, -0.0941], + [0.0156, 0.0169, 0.0070], + [-0.0034, -0.0040, -0.0114], + [0.0032, 0.0181, 0.0080], + [-0.0939, -0.0008, 0.0186], + [0.0018, 0.0043, 0.0104], + [0.0284, 0.0056, -0.0127], + [-0.0024, -0.0022, -0.0030], + [0.1207, -0.0026, 0.0065], + [0.0128, 0.0101, 0.0142], + [0.0137, -0.0072, -0.0007], + [0.0095, 0.0092, -0.0059], + [0.0000, -0.0077, -0.0049], + [-0.0465, -0.0204, -0.0312], + [0.0095, 0.0012, -0.0066], + [0.0290, -0.0034, 0.0025], + [0.0220, 0.0169, -0.0048], + [-0.0332, -0.0457, -0.0468], + [-0.0085, 0.0389, 0.0609], + [-0.0076, 0.0003, -0.0043], + [-0.0111, -0.0460, -0.0614], + ] + + self.latent_rgb_factors_bias = [-0.0329, -0.0718, -0.0851] + self.latent_rgb_factors_reshape = lambda t: t.reshape(t.shape[0], 32, 2, 2, t.shape[-2], t.shape[-1]).permute(0, 1, 4, 2, 5, 3).reshape(t.shape[0], 32, t.shape[-2] * 2, t.shape[-1] * 2) + def process_in(self, latent): return latent diff --git a/latent_preview.py b/latent_preview.py index 95d3cb733..ddf6dcf49 100644 --- a/latent_preview.py +++ b/latent_preview.py @@ -37,13 +37,16 @@ class TAESDPreviewerImpl(LatentPreviewer): class Latent2RGBPreviewer(LatentPreviewer): - def __init__(self, latent_rgb_factors, latent_rgb_factors_bias=None): + def __init__(self, latent_rgb_factors, latent_rgb_factors_bias=None, latent_rgb_factors_reshape=None): self.latent_rgb_factors = torch.tensor(latent_rgb_factors, device="cpu").transpose(0, 1) self.latent_rgb_factors_bias = None if latent_rgb_factors_bias is not None: self.latent_rgb_factors_bias = torch.tensor(latent_rgb_factors_bias, device="cpu") + self.latent_rgb_factors_reshape = latent_rgb_factors_reshape def decode_latent_to_preview(self, x0): + if self.latent_rgb_factors_reshape is not None: + x0 = self.latent_rgb_factors_reshape(x0) self.latent_rgb_factors = self.latent_rgb_factors.to(dtype=x0.dtype, device=x0.device) if self.latent_rgb_factors_bias is not None: self.latent_rgb_factors_bias = self.latent_rgb_factors_bias.to(dtype=x0.dtype, device=x0.device) @@ -85,7 +88,7 @@ def get_previewer(device, latent_format): if previewer is None: if latent_format.latent_rgb_factors is not None: - previewer = Latent2RGBPreviewer(latent_format.latent_rgb_factors, latent_format.latent_rgb_factors_bias) + previewer = Latent2RGBPreviewer(latent_format.latent_rgb_factors, latent_format.latent_rgb_factors_bias, latent_format.latent_rgb_factors_reshape) return previewer def prepare_callback(model, steps, x0_output_dict=None): From 8938aa3f3064415758fa8f3a628476535a676183 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Wed, 26 Nov 2025 19:14:02 +0200 Subject: [PATCH 152/231] add Veo3 First-Last-Frame node (#10878) --- comfy_api_nodes/apis/veo_api.py | 38 +++----- comfy_api_nodes/nodes_veo2.py | 155 ++++++++++++++++++++++++++++++++ 2 files changed, 168 insertions(+), 25 deletions(-) diff --git a/comfy_api_nodes/apis/veo_api.py b/comfy_api_nodes/apis/veo_api.py index a55137afb..8328d1aa4 100644 --- a/comfy_api_nodes/apis/veo_api.py +++ b/comfy_api_nodes/apis/veo_api.py @@ -1,34 +1,21 @@ -from typing import Optional, Union -from enum import Enum +from typing import Optional from pydantic import BaseModel, Field -class Image2(BaseModel): - bytesBase64Encoded: str - gcsUri: Optional[str] = None - mimeType: Optional[str] = None +class VeoRequestInstanceImage(BaseModel): + bytesBase64Encoded: str | None = Field(None) + gcsUri: str | None = Field(None) + mimeType: str | None = Field(None) -class Image3(BaseModel): - bytesBase64Encoded: Optional[str] = None - gcsUri: str - mimeType: Optional[str] = None - - -class Instance1(BaseModel): - image: Optional[Union[Image2, Image3]] = Field( - None, description='Optional image to guide video generation' - ) +class VeoRequestInstance(BaseModel): + image: VeoRequestInstanceImage | None = Field(None) + lastFrame: VeoRequestInstanceImage | None = Field(None) prompt: str = Field(..., description='Text description of the video') -class PersonGeneration1(str, Enum): - ALLOW = 'ALLOW' - BLOCK = 'BLOCK' - - -class Parameters1(BaseModel): +class VeoRequestParameters(BaseModel): aspectRatio: Optional[str] = Field(None, examples=['16:9']) durationSeconds: Optional[int] = None enhancePrompt: Optional[bool] = None @@ -37,17 +24,18 @@ class Parameters1(BaseModel): description='Generate audio for the video. Only supported by veo 3 models.', ) negativePrompt: Optional[str] = None - personGeneration: Optional[PersonGeneration1] = None + personGeneration: str | None = Field(None, description="ALLOW or BLOCK") sampleCount: Optional[int] = None seed: Optional[int] = None storageUri: Optional[str] = Field( None, description='Optional Cloud Storage URI to upload the video' ) + resolution: str | None = Field(None) class VeoGenVidRequest(BaseModel): - instances: Optional[list[Instance1]] = None - parameters: Optional[Parameters1] = None + instances: list[VeoRequestInstance] | None = Field(None) + parameters: VeoRequestParameters | None = Field(None) class VeoGenVidResponse(BaseModel): diff --git a/comfy_api_nodes/nodes_veo2.py b/comfy_api_nodes/nodes_veo2.py index d37e9e9b4..a54dc13ab 100644 --- a/comfy_api_nodes/nodes_veo2.py +++ b/comfy_api_nodes/nodes_veo2.py @@ -1,6 +1,7 @@ import base64 from io import BytesIO +import torch from typing_extensions import override from comfy_api.input_impl.video_types import VideoFromFile @@ -10,6 +11,9 @@ from comfy_api_nodes.apis.veo_api import ( VeoGenVidPollResponse, VeoGenVidRequest, VeoGenVidResponse, + VeoRequestInstance, + VeoRequestInstanceImage, + VeoRequestParameters, ) from comfy_api_nodes.util import ( ApiEndpoint, @@ -346,12 +350,163 @@ class Veo3VideoGenerationNode(VeoVideoGenerationNode): ) +class Veo3FirstLastFrameNode(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="Veo3FirstLastFrameNode", + display_name="Google Veo 3 First-Last-Frame to Video", + category="api node/video/Veo", + description="Generate video using prompt and first and last frames.", + inputs=[ + IO.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Text description of the video", + ), + IO.String.Input( + "negative_prompt", + multiline=True, + default="", + tooltip="Negative text prompt to guide what to avoid in the video", + ), + IO.Combo.Input("resolution", options=["720p", "1080p"]), + IO.Combo.Input( + "aspect_ratio", + options=["16:9", "9:16"], + default="16:9", + tooltip="Aspect ratio of the output video", + ), + IO.Int.Input( + "duration", + default=8, + min=4, + max=8, + step=2, + display_mode=IO.NumberDisplay.slider, + tooltip="Duration of the output video in seconds", + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=0xFFFFFFFF, + step=1, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + tooltip="Seed for video generation", + ), + IO.Image.Input("first_frame", tooltip="Start frame"), + IO.Image.Input("last_frame", tooltip="End frame"), + IO.Combo.Input( + "model", + options=["veo-3.1-generate", "veo-3.1-fast-generate"], + default="veo-3.1-fast-generate", + ), + IO.Boolean.Input( + "generate_audio", + default=True, + tooltip="Generate audio for the video.", + ), + ], + outputs=[ + IO.Video.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, + prompt: str, + negative_prompt: str, + resolution: str, + aspect_ratio: str, + duration: int, + seed: int, + first_frame: torch.Tensor, + last_frame: torch.Tensor, + model: str, + generate_audio: bool, + ): + model = MODELS_MAP[model] + initial_response = await sync_op( + cls, + ApiEndpoint(path=f"/proxy/veo/{model}/generate", method="POST"), + response_model=VeoGenVidResponse, + data=VeoGenVidRequest( + instances=[ + VeoRequestInstance( + prompt=prompt, + image=VeoRequestInstanceImage( + bytesBase64Encoded=tensor_to_base64_string(first_frame), mimeType="image/png" + ), + lastFrame=VeoRequestInstanceImage( + bytesBase64Encoded=tensor_to_base64_string(last_frame), mimeType="image/png" + ), + ), + ], + parameters=VeoRequestParameters( + aspectRatio=aspect_ratio, + personGeneration="ALLOW", + durationSeconds=duration, + enhancePrompt=True, # cannot be False for Veo3 + seed=seed, + generateAudio=generate_audio, + negativePrompt=negative_prompt, + resolution=resolution, + ), + ), + ) + poll_response = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/veo/{model}/poll", method="POST"), + response_model=VeoGenVidPollResponse, + status_extractor=lambda r: "completed" if r.done else "pending", + data=VeoGenVidPollRequest( + operationName=initial_response.name, + ), + poll_interval=5.0, + estimated_duration=AVERAGE_DURATION_VIDEO_GEN, + ) + + if poll_response.error: + raise Exception(f"Veo API error: {poll_response.error.message} (code: {poll_response.error.code})") + + response = poll_response.response + filtered_count = response.raiMediaFilteredCount + if filtered_count: + reasons = response.raiMediaFilteredReasons or [] + reason_part = f": {reasons[0]}" if reasons else "" + raise Exception( + f"Content blocked by Google's Responsible AI filters{reason_part} " + f"({filtered_count} video{'s' if filtered_count != 1 else ''} filtered)." + ) + + if response.videos: + video = response.videos[0] + if video.bytesBase64Encoded: + return IO.NodeOutput(VideoFromFile(BytesIO(base64.b64decode(video.bytesBase64Encoded)))) + if video.gcsUri: + return IO.NodeOutput(await download_url_to_video_output(video.gcsUri)) + raise Exception("Video returned but no data or URL was provided") + raise Exception("Video generation completed but no video was returned") + + class VeoExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[IO.ComfyNode]]: return [ VeoVideoGenerationNode, Veo3VideoGenerationNode, + Veo3FirstLastFrameNode, ] From 1105e0d139001ad602d0f883406bfce41e54ae67 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Wed, 26 Nov 2025 19:23:14 +0200 Subject: [PATCH 153/231] improve UX for batch uploads in upload_images_to_comfyapi (#10913) --- comfy_api_nodes/util/upload_helpers.py | 43 +++++++++++++------------- 1 file changed, 22 insertions(+), 21 deletions(-) diff --git a/comfy_api_nodes/util/upload_helpers.py b/comfy_api_nodes/util/upload_helpers.py index 632450d9b..b9019841f 100644 --- a/comfy_api_nodes/util/upload_helpers.py +++ b/comfy_api_nodes/util/upload_helpers.py @@ -4,7 +4,7 @@ import logging import time import uuid from io import BytesIO -from typing import Optional, Union +from typing import Optional from urllib.parse import urlparse import aiohttp @@ -48,8 +48,9 @@ async def upload_images_to_comfyapi( image: torch.Tensor, *, max_images: int = 8, - mime_type: Optional[str] = None, - wait_label: Optional[str] = "Uploading", + mime_type: str | None = None, + wait_label: str | None = "Uploading", + show_batch_index: bool = True, ) -> list[str]: """ Uploads images to ComfyUI API and returns download URLs. @@ -59,11 +60,18 @@ async def upload_images_to_comfyapi( download_urls: list[str] = [] is_batch = len(image.shape) > 3 batch_len = image.shape[0] if is_batch else 1 + num_to_upload = min(batch_len, max_images) + batch_start_ts = time.monotonic() - for idx in range(min(batch_len, max_images)): + for idx in range(num_to_upload): tensor = image[idx] if is_batch else image img_io = tensor_to_bytesio(tensor, mime_type=mime_type) - url = await upload_file_to_comfyapi(cls, img_io, img_io.name, mime_type, wait_label) + + effective_label = wait_label + if wait_label and show_batch_index and num_to_upload > 1: + effective_label = f"{wait_label} ({idx + 1}/{num_to_upload})" + + url = await upload_file_to_comfyapi(cls, img_io, img_io.name, mime_type, effective_label, batch_start_ts) download_urls.append(url) return download_urls @@ -126,8 +134,9 @@ async def upload_file_to_comfyapi( cls: type[IO.ComfyNode], file_bytes_io: BytesIO, filename: str, - upload_mime_type: Optional[str], - wait_label: Optional[str] = "Uploading", + upload_mime_type: str | None, + wait_label: str | None = "Uploading", + progress_origin_ts: float | None = None, ) -> str: """Uploads a single file to ComfyUI API and returns its download URL.""" if upload_mime_type is None: @@ -148,6 +157,7 @@ async def upload_file_to_comfyapi( file_bytes_io, content_type=upload_mime_type, wait_label=wait_label, + progress_origin_ts=progress_origin_ts, ) return create_resp.download_url @@ -155,27 +165,18 @@ async def upload_file_to_comfyapi( async def upload_file( cls: type[IO.ComfyNode], upload_url: str, - file: Union[BytesIO, str], + file: BytesIO | str, *, - content_type: Optional[str] = None, + content_type: str | None = None, max_retries: int = 3, retry_delay: float = 1.0, retry_backoff: float = 2.0, - wait_label: Optional[str] = None, + wait_label: str | None = None, + progress_origin_ts: float | None = None, ) -> None: """ Upload a file to a signed URL (e.g., S3 pre-signed PUT) with retries, Comfy progress display, and interruption. - Args: - cls: Node class (provides auth context + UI progress hooks). - upload_url: Pre-signed PUT URL. - file: BytesIO or path string. - content_type: Explicit MIME type. If None, we *suppress* Content-Type. - max_retries: Maximum retry attempts. - retry_delay: Initial delay in seconds. - retry_backoff: Exponential backoff factor. - wait_label: Progress label shown in Comfy UI. - Raises: ProcessingInterrupted, LocalNetworkError, ApiServerError, Exception """ @@ -198,7 +199,7 @@ async def upload_file( attempt = 0 delay = retry_delay - start_ts = time.monotonic() + start_ts = progress_origin_ts if progress_origin_ts is not None else time.monotonic() op_uuid = uuid.uuid4().hex[:8] while True: attempt += 1 From 8908ee262862f1252d1363d55c59872fb3361066 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Wed, 26 Nov 2025 20:38:30 +0200 Subject: [PATCH 154/231] fix(gemini): use first 10 images as fileData (URLs) and remaining images as inline base64 (#10918) --- comfy_api_nodes/apis/gemini_api.py | 6 ++++ comfy_api_nodes/nodes_gemini.py | 55 ++++++++++++++++++++---------- 2 files changed, 43 insertions(+), 18 deletions(-) diff --git a/comfy_api_nodes/apis/gemini_api.py b/comfy_api_nodes/apis/gemini_api.py index d34590d28..a380ecc86 100644 --- a/comfy_api_nodes/apis/gemini_api.py +++ b/comfy_api_nodes/apis/gemini_api.py @@ -58,8 +58,14 @@ class GeminiInlineData(BaseModel): mimeType: GeminiMimeType | None = Field(None) +class GeminiFileData(BaseModel): + fileUri: str | None = Field(None) + mimeType: GeminiMimeType | None = Field(None) + + class GeminiPart(BaseModel): inlineData: GeminiInlineData | None = Field(None) + fileData: GeminiFileData | None = Field(None) text: str | None = Field(None) diff --git a/comfy_api_nodes/nodes_gemini.py b/comfy_api_nodes/nodes_gemini.py index 938a20f84..976d9c225 100644 --- a/comfy_api_nodes/nodes_gemini.py +++ b/comfy_api_nodes/nodes_gemini.py @@ -20,6 +20,7 @@ from comfy_api.latest import IO, ComfyExtension, Input from comfy_api.util import VideoCodec, VideoContainer from comfy_api_nodes.apis.gemini_api import ( GeminiContent, + GeminiFileData, GeminiGenerateContentRequest, GeminiGenerateContentResponse, GeminiImageConfig, @@ -38,6 +39,7 @@ from comfy_api_nodes.util import ( get_number_of_images, sync_op, tensor_to_base64_string, + upload_images_to_comfyapi, validate_string, video_to_base64_string, ) @@ -68,24 +70,43 @@ class GeminiImageModel(str, Enum): gemini_2_5_flash_image = "gemini-2.5-flash-image" -def create_image_parts(image_input: torch.Tensor) -> list[GeminiPart]: - """ - Convert image tensor input to Gemini API compatible parts. - - Args: - image_input: Batch of image tensors from ComfyUI. - - Returns: - List of GeminiPart objects containing the encoded images. - """ +async def create_image_parts( + cls: type[IO.ComfyNode], + images: torch.Tensor, + image_limit: int = 0, +) -> list[GeminiPart]: image_parts: list[GeminiPart] = [] - for image_index in range(image_input.shape[0]): - image_as_b64 = tensor_to_base64_string(image_input[image_index].unsqueeze(0)) + if image_limit < 0: + raise ValueError("image_limit must be greater than or equal to 0 when creating Gemini image parts.") + total_images = get_number_of_images(images) + if total_images <= 0: + raise ValueError("No images provided to create_image_parts; at least one image is required.") + + # If image_limit == 0 --> use all images; otherwise clamp to image_limit. + effective_max = total_images if image_limit == 0 else min(total_images, image_limit) + + # Number of images we'll send as URLs (fileData) + num_url_images = min(effective_max, 10) # Vertex API max number of image links + reference_images_urls = await upload_images_to_comfyapi( + cls, + images, + max_images=num_url_images, + ) + for reference_image_url in reference_images_urls: + image_parts.append( + GeminiPart( + fileData=GeminiFileData( + mimeType=GeminiMimeType.image_png, + fileUri=reference_image_url, + ) + ) + ) + for idx in range(num_url_images, effective_max): image_parts.append( GeminiPart( inlineData=GeminiInlineData( mimeType=GeminiMimeType.image_png, - data=image_as_b64, + data=tensor_to_base64_string(images[idx]), ) ) ) @@ -338,8 +359,7 @@ class GeminiNode(IO.ComfyNode): # Add other modal parts if images is not None: - image_parts = create_image_parts(images) - parts.extend(image_parts) + parts.extend(await create_image_parts(cls, images)) if audio is not None: parts.extend(cls.create_audio_parts(audio)) if video is not None: @@ -562,8 +582,7 @@ class GeminiImage(IO.ComfyNode): image_config = GeminiImageConfig(aspectRatio=aspect_ratio) if images is not None: - image_parts = create_image_parts(images) - parts.extend(image_parts) + parts.extend(await create_image_parts(cls, images)) if files is not None: parts.extend(files) @@ -702,7 +721,7 @@ class GeminiImage2(IO.ComfyNode): if images is not None: if get_number_of_images(images) > 14: raise ValueError("The current maximum number of supported images is 14.") - parts.extend(create_image_parts(images)) + parts.extend(await create_image_parts(cls, images)) if files is not None: parts.extend(files) From 234c3dc85f7e61a537bbf6d8999c5880c5e0b746 Mon Sep 17 00:00:00 2001 From: Christian Byrne Date: Wed, 26 Nov 2025 11:58:08 -0800 Subject: [PATCH 155/231] Bump frontend to 1.32.9 (#10867) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 5f20816d6..9291552d3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -comfyui-frontend-package==1.30.6 +comfyui-frontend-package==1.32.9 comfyui-workflow-templates==0.7.20 comfyui-embedded-docs==0.3.1 torch From 58c6ed541d5aaf6d9b12f63bc23c33164e1cf7a3 Mon Sep 17 00:00:00 2001 From: Terry Jia Date: Wed, 26 Nov 2025 14:58:27 -0500 Subject: [PATCH 156/231] Merge 3d animation node (#10025) --- comfy_extras/nodes_load_3d.py | 110 +++++++--------------------------- 1 file changed, 23 insertions(+), 87 deletions(-) diff --git a/comfy_extras/nodes_load_3d.py b/comfy_extras/nodes_load_3d.py index 899608149..54c66ef68 100644 --- a/comfy_extras/nodes_load_3d.py +++ b/comfy_extras/nodes_load_3d.py @@ -7,6 +7,10 @@ from comfy_api.input_impl import VideoFromFile from pathlib import Path +from PIL import Image +import numpy as np + +import uuid def normalize_path(path): return path.replace('\\', '/') @@ -34,58 +38,6 @@ class Load3D(): "height": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}), }} - RETURN_TYPES = ("IMAGE", "MASK", "STRING", "IMAGE", "IMAGE", "LOAD3D_CAMERA", IO.VIDEO) - RETURN_NAMES = ("image", "mask", "mesh_path", "normal", "lineart", "camera_info", "recording_video") - - FUNCTION = "process" - EXPERIMENTAL = True - - CATEGORY = "3d" - - def process(self, model_file, image, **kwargs): - image_path = folder_paths.get_annotated_filepath(image['image']) - mask_path = folder_paths.get_annotated_filepath(image['mask']) - normal_path = folder_paths.get_annotated_filepath(image['normal']) - lineart_path = folder_paths.get_annotated_filepath(image['lineart']) - - load_image_node = nodes.LoadImage() - output_image, ignore_mask = load_image_node.load_image(image=image_path) - ignore_image, output_mask = load_image_node.load_image(image=mask_path) - normal_image, ignore_mask2 = load_image_node.load_image(image=normal_path) - lineart_image, ignore_mask3 = load_image_node.load_image(image=lineart_path) - - video = None - - if image['recording'] != "": - recording_video_path = folder_paths.get_annotated_filepath(image['recording']) - - video = VideoFromFile(recording_video_path) - - return output_image, output_mask, model_file, normal_image, lineart_image, image['camera_info'], video - -class Load3DAnimation(): - @classmethod - def INPUT_TYPES(s): - input_dir = os.path.join(folder_paths.get_input_directory(), "3d") - - os.makedirs(input_dir, exist_ok=True) - - input_path = Path(input_dir) - base_path = Path(folder_paths.get_input_directory()) - - files = [ - normalize_path(str(file_path.relative_to(base_path))) - for file_path in input_path.rglob("*") - if file_path.suffix.lower() in {'.gltf', '.glb', '.fbx'} - ] - - return {"required": { - "model_file": (sorted(files), {"file_upload": True}), - "image": ("LOAD_3D_ANIMATION", {}), - "width": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}), - "height": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}), - }} - RETURN_TYPES = ("IMAGE", "MASK", "STRING", "IMAGE", "LOAD3D_CAMERA", IO.VIDEO) RETURN_NAMES = ("image", "mask", "mesh_path", "normal", "camera_info", "recording_video") @@ -120,7 +72,8 @@ class Preview3D(): "model_file": ("STRING", {"default": "", "multiline": False}), }, "optional": { - "camera_info": ("LOAD3D_CAMERA", {}) + "camera_info": ("LOAD3D_CAMERA", {}), + "bg_image": ("IMAGE", {}) }} OUTPUT_NODE = True @@ -133,50 +86,33 @@ class Preview3D(): def process(self, model_file, **kwargs): camera_info = kwargs.get("camera_info", None) + bg_image = kwargs.get("bg_image", None) + + bg_image_path = None + if bg_image is not None: + + img_array = (bg_image[0].cpu().numpy() * 255).astype(np.uint8) + img = Image.fromarray(img_array) + + temp_dir = folder_paths.get_temp_directory() + filename = f"bg_{uuid.uuid4().hex}.png" + bg_image_path = os.path.join(temp_dir, filename) + img.save(bg_image_path, compress_level=1) + + bg_image_path = f"temp/{filename}" return { "ui": { - "result": [model_file, camera_info] - } - } - -class Preview3DAnimation(): - @classmethod - def INPUT_TYPES(s): - return {"required": { - "model_file": ("STRING", {"default": "", "multiline": False}), - }, - "optional": { - "camera_info": ("LOAD3D_CAMERA", {}) - }} - - OUTPUT_NODE = True - RETURN_TYPES = () - - CATEGORY = "3d" - - FUNCTION = "process" - EXPERIMENTAL = True - - def process(self, model_file, **kwargs): - camera_info = kwargs.get("camera_info", None) - - return { - "ui": { - "result": [model_file, camera_info] + "result": [model_file, camera_info, bg_image_path] } } NODE_CLASS_MAPPINGS = { "Load3D": Load3D, - "Load3DAnimation": Load3DAnimation, "Preview3D": Preview3D, - "Preview3DAnimation": Preview3DAnimation } NODE_DISPLAY_NAME_MAPPINGS = { - "Load3D": "Load 3D", - "Load3DAnimation": "Load 3D - Animation", - "Preview3D": "Preview 3D", - "Preview3DAnimation": "Preview 3D - Animation" + "Load3D": "Load 3D & Animation", + "Preview3D": "Preview 3D & Animation", } From 55f654db3ddaf5a10ac6dbe79774c23c350d279d Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 26 Nov 2025 12:16:40 -0800 Subject: [PATCH 157/231] Fix the CSP offline feature. (#10923) --- server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server.py b/server.py index 0fd2e49e3..fca5050bd 100644 --- a/server.py +++ b/server.py @@ -174,7 +174,7 @@ def create_block_external_middleware(): else: response = await handler(request) - response.headers['Content-Security-Policy'] = "default-src 'self'; script-src 'self' 'unsafe-inline' blob:; style-src 'self' 'unsafe-inline'; img-src 'self' data: blob:; font-src 'self'; connect-src 'self'; frame-src 'self'; object-src 'self';" + response.headers['Content-Security-Policy'] = "default-src 'self'; script-src 'self' 'unsafe-inline' 'unsafe-eval' blob:; style-src 'self' 'unsafe-inline'; img-src 'self' data: blob:; font-src 'self'; connect-src 'self'; frame-src 'self'; object-src 'self';" return response return block_external_middleware From dd41b745497cdbbafb0bd745f590726b0e41f9f3 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 26 Nov 2025 12:36:38 -0800 Subject: [PATCH 158/231] Add Z Image to readme. (#10924) --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index b9300ab07..91fb510e1 100644 --- a/README.md +++ b/README.md @@ -68,6 +68,7 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith - [Qwen Image](https://comfyanonymous.github.io/ComfyUI_examples/qwen_image/) - [Hunyuan Image 2.1](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_image/) - [Flux 2](https://comfyanonymous.github.io/ComfyUI_examples/flux2/) + - [Z Image](https://comfyanonymous.github.io/ComfyUI_examples/z_image/) - Image Editing Models - [Omnigen 2](https://comfyanonymous.github.io/ComfyUI_examples/omnigen/) - [Flux Kontext](https://comfyanonymous.github.io/ComfyUI_examples/flux/#flux-kontext-image-editing-model) From d8433c63fdacef24f40da401b02ebba272bf1fbb Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Thu, 27 Nov 2025 00:42:01 +0200 Subject: [PATCH 159/231] chore(api-nodes): remove chat widgets from OpenAI/Gemini nodes (#10861) --- comfy_api_nodes/nodes_gemini.py | 77 +-------------------------------- comfy_api_nodes/nodes_openai.py | 46 ++++---------------- 2 files changed, 11 insertions(+), 112 deletions(-) diff --git a/comfy_api_nodes/nodes_gemini.py b/comfy_api_nodes/nodes_gemini.py index 976d9c225..08f7b0f64 100644 --- a/comfy_api_nodes/nodes_gemini.py +++ b/comfy_api_nodes/nodes_gemini.py @@ -4,10 +4,7 @@ See: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/infer """ import base64 -import json import os -import time -import uuid from enum import Enum from io import BytesIO from typing import Literal @@ -43,7 +40,6 @@ from comfy_api_nodes.util import ( validate_string, video_to_base64_string, ) -from server import PromptServer GEMINI_BASE_ENDPOINT = "/proxy/vertexai/gemini" GEMINI_MAX_INPUT_FILE_SIZE = 20 * 1024 * 1024 # 20 MB @@ -384,29 +380,6 @@ class GeminiNode(IO.ComfyNode): ) output_text = get_text_from_response(response) - if output_text: - # Not a true chat history like the OpenAI Chat node. It is emulated so the frontend can show a copy button. - render_spec = { - "node_id": cls.hidden.unique_id, - "component": "ChatHistoryWidget", - "props": { - "history": json.dumps( - [ - { - "prompt": prompt, - "response": output_text, - "response_id": str(uuid.uuid4()), - "timestamp": time.time(), - } - ] - ), - }, - } - PromptServer.instance.send_sync( - "display_component", - render_spec, - ) - return IO.NodeOutput(output_text or "Empty response from Gemini model...") @@ -601,30 +574,7 @@ class GeminiImage(IO.ComfyNode): response_model=GeminiGenerateContentResponse, price_extractor=calculate_tokens_price, ) - - output_text = get_text_from_response(response) - if output_text: - render_spec = { - "node_id": cls.hidden.unique_id, - "component": "ChatHistoryWidget", - "props": { - "history": json.dumps( - [ - { - "prompt": prompt, - "response": output_text, - "response_id": str(uuid.uuid4()), - "timestamp": time.time(), - } - ] - ), - }, - } - PromptServer.instance.send_sync( - "display_component", - render_spec, - ) - return IO.NodeOutput(get_image_from_response(response), output_text) + return IO.NodeOutput(get_image_from_response(response), get_text_from_response(response)) class GeminiImage2(IO.ComfyNode): @@ -744,30 +694,7 @@ class GeminiImage2(IO.ComfyNode): response_model=GeminiGenerateContentResponse, price_extractor=calculate_tokens_price, ) - - output_text = get_text_from_response(response) - if output_text: - render_spec = { - "node_id": cls.hidden.unique_id, - "component": "ChatHistoryWidget", - "props": { - "history": json.dumps( - [ - { - "prompt": prompt, - "response": output_text, - "response_id": str(uuid.uuid4()), - "timestamp": time.time(), - } - ] - ), - }, - } - PromptServer.instance.send_sync( - "display_component", - render_spec, - ) - return IO.NodeOutput(get_image_from_response(response), output_text) + return IO.NodeOutput(get_image_from_response(response), get_text_from_response(response)) class GeminiExtension(ComfyExtension): diff --git a/comfy_api_nodes/nodes_openai.py b/comfy_api_nodes/nodes_openai.py index acf35d276..c8da5464b 100644 --- a/comfy_api_nodes/nodes_openai.py +++ b/comfy_api_nodes/nodes_openai.py @@ -1,15 +1,10 @@ from io import BytesIO -from typing import Optional, Union -import json import os -import time -import uuid from enum import Enum from inspect import cleandoc import numpy as np import torch from PIL import Image -from server import PromptServer import folder_paths import base64 from comfy_api.latest import IO, ComfyExtension @@ -587,11 +582,11 @@ class OpenAIChatNode(IO.ComfyNode): def create_input_message_contents( cls, prompt: str, - image: Optional[torch.Tensor] = None, - files: Optional[list[InputFileContent]] = None, + image: torch.Tensor | None = None, + files: list[InputFileContent] | None = None, ) -> InputMessageContentList: """Create a list of input message contents from prompt and optional image.""" - content_list: list[Union[InputContent, InputTextContent, InputImageContent, InputFileContent]] = [ + content_list: list[InputContent | InputTextContent | InputImageContent | InputFileContent] = [ InputTextContent(text=prompt, type="input_text"), ] if image is not None: @@ -617,9 +612,9 @@ class OpenAIChatNode(IO.ComfyNode): prompt: str, persist_context: bool = False, model: SupportedOpenAIModel = SupportedOpenAIModel.gpt_5.value, - images: Optional[torch.Tensor] = None, - files: Optional[list[InputFileContent]] = None, - advanced_options: Optional[CreateModelResponseProperties] = None, + images: torch.Tensor | None = None, + files: list[InputFileContent] | None = None, + advanced_options: CreateModelResponseProperties | None = None, ) -> IO.NodeOutput: validate_string(prompt, strip_whitespace=False) @@ -660,30 +655,7 @@ class OpenAIChatNode(IO.ComfyNode): status_extractor=lambda response: response.status, completed_statuses=["incomplete", "completed"] ) - output_text = cls.get_text_from_message_content(cls.get_message_content_from_response(result_response)) - - # Update history - render_spec = { - "node_id": cls.hidden.unique_id, - "component": "ChatHistoryWidget", - "props": { - "history": json.dumps( - [ - { - "prompt": prompt, - "response": output_text, - "response_id": str(uuid.uuid4()), - "timestamp": time.time(), - } - ] - ), - }, - } - PromptServer.instance.send_sync( - "display_component", - render_spec, - ) - return IO.NodeOutput(output_text) + return IO.NodeOutput(cls.get_text_from_message_content(cls.get_message_content_from_response(result_response))) class OpenAIInputFiles(IO.ComfyNode): @@ -790,8 +762,8 @@ class OpenAIChatConfig(IO.ComfyNode): def execute( cls, truncation: bool, - instructions: Optional[str] = None, - max_output_tokens: Optional[int] = None, + instructions: str | None = None, + max_output_tokens: int | None = None, ) -> IO.NodeOutput: """ Configure advanced options for the OpenAI Chat Node. From a2d60aad0f8e03657d501842460123f6eaaf6791 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Thu, 27 Nov 2025 00:55:31 +0200 Subject: [PATCH 160/231] convert nodes_customer_sampler.py to V3 schema (#10206) --- comfy_extras/nodes_custom_sampler.py | 1182 ++++++++++++++------------ 1 file changed, 633 insertions(+), 549 deletions(-) diff --git a/comfy_extras/nodes_custom_sampler.py b/comfy_extras/nodes_custom_sampler.py index d011f433b..fbb080886 100644 --- a/comfy_extras/nodes_custom_sampler.py +++ b/comfy_extras/nodes_custom_sampler.py @@ -3,272 +3,312 @@ import comfy.samplers import comfy.sample from comfy.k_diffusion import sampling as k_diffusion_sampling from comfy.k_diffusion import sa_solver -from comfy.comfy_types import IO, ComfyNodeABC, InputTypeDict import latent_preview import torch import comfy.utils import node_helpers +from typing_extensions import override +from comfy_api.latest import ComfyExtension, io -class BasicScheduler: +class BasicScheduler(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"model": ("MODEL",), - "scheduler": (comfy.samplers.SCHEDULER_NAMES, ), - "steps": ("INT", {"default": 20, "min": 1, "max": 10000}), - "denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), - } - } - RETURN_TYPES = ("SIGMAS",) - CATEGORY = "sampling/custom_sampling/schedulers" + def define_schema(cls): + return io.Schema( + node_id="BasicScheduler", + category="sampling/custom_sampling/schedulers", + inputs=[ + io.Model.Input("model"), + io.Combo.Input("scheduler", options=comfy.samplers.SCHEDULER_NAMES), + io.Int.Input("steps", default=20, min=1, max=10000), + io.Float.Input("denoise", default=1.0, min=0.0, max=1.0, step=0.01), + ], + outputs=[io.Sigmas.Output()] + ) - FUNCTION = "get_sigmas" - - def get_sigmas(self, model, scheduler, steps, denoise): + @classmethod + def execute(cls, model, scheduler, steps, denoise) -> io.NodeOutput: total_steps = steps if denoise < 1.0: if denoise <= 0.0: - return (torch.FloatTensor([]),) + return io.NodeOutput(torch.FloatTensor([])) total_steps = int(steps/denoise) sigmas = comfy.samplers.calculate_sigmas(model.get_model_object("model_sampling"), scheduler, total_steps).cpu() sigmas = sigmas[-(steps + 1):] - return (sigmas, ) + return io.NodeOutput(sigmas) + + get_sigmas = execute -class KarrasScheduler: +class KarrasScheduler(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"steps": ("INT", {"default": 20, "min": 1, "max": 10000}), - "sigma_max": ("FLOAT", {"default": 14.614642, "min": 0.0, "max": 5000.0, "step":0.01, "round": False}), - "sigma_min": ("FLOAT", {"default": 0.0291675, "min": 0.0, "max": 5000.0, "step":0.01, "round": False}), - "rho": ("FLOAT", {"default": 7.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), - } - } - RETURN_TYPES = ("SIGMAS",) - CATEGORY = "sampling/custom_sampling/schedulers" + def define_schema(cls): + return io.Schema( + node_id="KarrasScheduler", + category="sampling/custom_sampling/schedulers", + inputs=[ + io.Int.Input("steps", default=20, min=1, max=10000), + io.Float.Input("sigma_max", default=14.614642, min=0.0, max=5000.0, step=0.01, round=False), + io.Float.Input("sigma_min", default=0.0291675, min=0.0, max=5000.0, step=0.01, round=False), + io.Float.Input("rho", default=7.0, min=0.0, max=100.0, step=0.01, round=False), + ], + outputs=[io.Sigmas.Output()] + ) - FUNCTION = "get_sigmas" - - def get_sigmas(self, steps, sigma_max, sigma_min, rho): + @classmethod + def execute(cls, steps, sigma_max, sigma_min, rho) -> io.NodeOutput: sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, rho=rho) - return (sigmas, ) + return io.NodeOutput(sigmas) -class ExponentialScheduler: + get_sigmas = execute + +class ExponentialScheduler(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"steps": ("INT", {"default": 20, "min": 1, "max": 10000}), - "sigma_max": ("FLOAT", {"default": 14.614642, "min": 0.0, "max": 5000.0, "step":0.01, "round": False}), - "sigma_min": ("FLOAT", {"default": 0.0291675, "min": 0.0, "max": 5000.0, "step":0.01, "round": False}), - } - } - RETURN_TYPES = ("SIGMAS",) - CATEGORY = "sampling/custom_sampling/schedulers" + def define_schema(cls): + return io.Schema( + node_id="ExponentialScheduler", + category="sampling/custom_sampling/schedulers", + inputs=[ + io.Int.Input("steps", default=20, min=1, max=10000), + io.Float.Input("sigma_max", default=14.614642, min=0.0, max=5000.0, step=0.01, round=False), + io.Float.Input("sigma_min", default=0.0291675, min=0.0, max=5000.0, step=0.01, round=False), + ], + outputs=[io.Sigmas.Output()] + ) - FUNCTION = "get_sigmas" - - def get_sigmas(self, steps, sigma_max, sigma_min): + @classmethod + def execute(cls, steps, sigma_max, sigma_min) -> io.NodeOutput: sigmas = k_diffusion_sampling.get_sigmas_exponential(n=steps, sigma_min=sigma_min, sigma_max=sigma_max) - return (sigmas, ) + return io.NodeOutput(sigmas) -class PolyexponentialScheduler: + get_sigmas = execute + +class PolyexponentialScheduler(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"steps": ("INT", {"default": 20, "min": 1, "max": 10000}), - "sigma_max": ("FLOAT", {"default": 14.614642, "min": 0.0, "max": 5000.0, "step":0.01, "round": False}), - "sigma_min": ("FLOAT", {"default": 0.0291675, "min": 0.0, "max": 5000.0, "step":0.01, "round": False}), - "rho": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), - } - } - RETURN_TYPES = ("SIGMAS",) - CATEGORY = "sampling/custom_sampling/schedulers" + def define_schema(cls): + return io.Schema( + node_id="PolyexponentialScheduler", + category="sampling/custom_sampling/schedulers", + inputs=[ + io.Int.Input("steps", default=20, min=1, max=10000), + io.Float.Input("sigma_max", default=14.614642, min=0.0, max=5000.0, step=0.01, round=False), + io.Float.Input("sigma_min", default=0.0291675, min=0.0, max=5000.0, step=0.01, round=False), + io.Float.Input("rho", default=1.0, min=0.0, max=100.0, step=0.01, round=False), + ], + outputs=[io.Sigmas.Output()] + ) - FUNCTION = "get_sigmas" - - def get_sigmas(self, steps, sigma_max, sigma_min, rho): + @classmethod + def execute(cls, steps, sigma_max, sigma_min, rho) -> io.NodeOutput: sigmas = k_diffusion_sampling.get_sigmas_polyexponential(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, rho=rho) - return (sigmas, ) + return io.NodeOutput(sigmas) -class LaplaceScheduler: + get_sigmas = execute + +class LaplaceScheduler(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"steps": ("INT", {"default": 20, "min": 1, "max": 10000}), - "sigma_max": ("FLOAT", {"default": 14.614642, "min": 0.0, "max": 5000.0, "step":0.01, "round": False}), - "sigma_min": ("FLOAT", {"default": 0.0291675, "min": 0.0, "max": 5000.0, "step":0.01, "round": False}), - "mu": ("FLOAT", {"default": 0.0, "min": -10.0, "max": 10.0, "step":0.1, "round": False}), - "beta": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 10.0, "step":0.1, "round": False}), - } - } - RETURN_TYPES = ("SIGMAS",) - CATEGORY = "sampling/custom_sampling/schedulers" + def define_schema(cls): + return io.Schema( + node_id="LaplaceScheduler", + category="sampling/custom_sampling/schedulers", + inputs=[ + io.Int.Input("steps", default=20, min=1, max=10000), + io.Float.Input("sigma_max", default=14.614642, min=0.0, max=5000.0, step=0.01, round=False), + io.Float.Input("sigma_min", default=0.0291675, min=0.0, max=5000.0, step=0.01, round=False), + io.Float.Input("mu", default=0.0, min=-10.0, max=10.0, step=0.1, round=False), + io.Float.Input("beta", default=0.5, min=0.0, max=10.0, step=0.1, round=False), + ], + outputs=[io.Sigmas.Output()] + ) - FUNCTION = "get_sigmas" - - def get_sigmas(self, steps, sigma_max, sigma_min, mu, beta): + @classmethod + def execute(cls, steps, sigma_max, sigma_min, mu, beta) -> io.NodeOutput: sigmas = k_diffusion_sampling.get_sigmas_laplace(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, mu=mu, beta=beta) - return (sigmas, ) + return io.NodeOutput(sigmas) + + get_sigmas = execute -class SDTurboScheduler: +class SDTurboScheduler(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"model": ("MODEL",), - "steps": ("INT", {"default": 1, "min": 1, "max": 10}), - "denoise": ("FLOAT", {"default": 1.0, "min": 0, "max": 1.0, "step": 0.01}), - } - } - RETURN_TYPES = ("SIGMAS",) - CATEGORY = "sampling/custom_sampling/schedulers" + def define_schema(cls): + return io.Schema( + node_id="SDTurboScheduler", + category="sampling/custom_sampling/schedulers", + inputs=[ + io.Model.Input("model"), + io.Int.Input("steps", default=1, min=1, max=10), + io.Float.Input("denoise", default=1.0, min=0, max=1.0, step=0.01), + ], + outputs=[io.Sigmas.Output()] + ) - FUNCTION = "get_sigmas" - - def get_sigmas(self, model, steps, denoise): + @classmethod + def execute(cls, model, steps, denoise) -> io.NodeOutput: start_step = 10 - int(10 * denoise) timesteps = torch.flip(torch.arange(1, 11) * 100 - 1, (0,))[start_step:start_step + steps] sigmas = model.get_model_object("model_sampling").sigma(timesteps) sigmas = torch.cat([sigmas, sigmas.new_zeros([1])]) - return (sigmas, ) + return io.NodeOutput(sigmas) -class BetaSamplingScheduler: + get_sigmas = execute + +class BetaSamplingScheduler(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"model": ("MODEL",), - "steps": ("INT", {"default": 20, "min": 1, "max": 10000}), - "alpha": ("FLOAT", {"default": 0.6, "min": 0.0, "max": 50.0, "step":0.01, "round": False}), - "beta": ("FLOAT", {"default": 0.6, "min": 0.0, "max": 50.0, "step":0.01, "round": False}), - } - } - RETURN_TYPES = ("SIGMAS",) - CATEGORY = "sampling/custom_sampling/schedulers" + def define_schema(cls): + return io.Schema( + node_id="BetaSamplingScheduler", + category="sampling/custom_sampling/schedulers", + inputs=[ + io.Model.Input("model"), + io.Int.Input("steps", default=20, min=1, max=10000), + io.Float.Input("alpha", default=0.6, min=0.0, max=50.0, step=0.01, round=False), + io.Float.Input("beta", default=0.6, min=0.0, max=50.0, step=0.01, round=False), + ], + outputs=[io.Sigmas.Output()] + ) - FUNCTION = "get_sigmas" - - def get_sigmas(self, model, steps, alpha, beta): + @classmethod + def execute(cls, model, steps, alpha, beta) -> io.NodeOutput: sigmas = comfy.samplers.beta_scheduler(model.get_model_object("model_sampling"), steps, alpha=alpha, beta=beta) - return (sigmas, ) + return io.NodeOutput(sigmas) -class VPScheduler: + get_sigmas = execute + +class VPScheduler(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"steps": ("INT", {"default": 20, "min": 1, "max": 10000}), - "beta_d": ("FLOAT", {"default": 19.9, "min": 0.0, "max": 5000.0, "step":0.01, "round": False}), #TODO: fix default values - "beta_min": ("FLOAT", {"default": 0.1, "min": 0.0, "max": 5000.0, "step":0.01, "round": False}), - "eps_s": ("FLOAT", {"default": 0.001, "min": 0.0, "max": 1.0, "step":0.0001, "round": False}), - } - } - RETURN_TYPES = ("SIGMAS",) - CATEGORY = "sampling/custom_sampling/schedulers" + def define_schema(cls): + return io.Schema( + node_id="VPScheduler", + category="sampling/custom_sampling/schedulers", + inputs=[ + io.Int.Input("steps", default=20, min=1, max=10000), + io.Float.Input("beta_d", default=19.9, min=0.0, max=5000.0, step=0.01, round=False), #TODO: fix default values + io.Float.Input("beta_min", default=0.1, min=0.0, max=5000.0, step=0.01, round=False), + io.Float.Input("eps_s", default=0.001, min=0.0, max=1.0, step=0.0001, round=False), + ], + outputs=[io.Sigmas.Output()] + ) - FUNCTION = "get_sigmas" - - def get_sigmas(self, steps, beta_d, beta_min, eps_s): + @classmethod + def execute(cls, steps, beta_d, beta_min, eps_s) -> io.NodeOutput: sigmas = k_diffusion_sampling.get_sigmas_vp(n=steps, beta_d=beta_d, beta_min=beta_min, eps_s=eps_s) - return (sigmas, ) + return io.NodeOutput(sigmas) -class SplitSigmas: + get_sigmas = execute + +class SplitSigmas(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"sigmas": ("SIGMAS", ), - "step": ("INT", {"default": 0, "min": 0, "max": 10000}), - } - } - RETURN_TYPES = ("SIGMAS","SIGMAS") - RETURN_NAMES = ("high_sigmas", "low_sigmas") - CATEGORY = "sampling/custom_sampling/sigmas" + def define_schema(cls): + return io.Schema( + node_id="SplitSigmas", + category="sampling/custom_sampling/sigmas", + inputs=[ + io.Sigmas.Input("sigmas"), + io.Int.Input("step", default=0, min=0, max=10000), + ], + outputs=[ + io.Sigmas.Output(display_name="high_sigmas"), + io.Sigmas.Output(display_name="low_sigmas"), + ] + ) - FUNCTION = "get_sigmas" - - def get_sigmas(self, sigmas, step): + @classmethod + def execute(cls, sigmas, step) -> io.NodeOutput: sigmas1 = sigmas[:step + 1] sigmas2 = sigmas[step:] - return (sigmas1, sigmas2) + return io.NodeOutput(sigmas1, sigmas2) -class SplitSigmasDenoise: + get_sigmas = execute + +class SplitSigmasDenoise(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"sigmas": ("SIGMAS", ), - "denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), - } - } - RETURN_TYPES = ("SIGMAS","SIGMAS") - RETURN_NAMES = ("high_sigmas", "low_sigmas") - CATEGORY = "sampling/custom_sampling/sigmas" + def define_schema(cls): + return io.Schema( + node_id="SplitSigmasDenoise", + category="sampling/custom_sampling/sigmas", + inputs=[ + io.Sigmas.Input("sigmas"), + io.Float.Input("denoise", default=1.0, min=0.0, max=1.0, step=0.01), + ], + outputs=[ + io.Sigmas.Output(display_name="high_sigmas"), + io.Sigmas.Output(display_name="low_sigmas"), + ] + ) - FUNCTION = "get_sigmas" - - def get_sigmas(self, sigmas, denoise): + @classmethod + def execute(cls, sigmas, denoise) -> io.NodeOutput: steps = max(sigmas.shape[-1] - 1, 0) total_steps = round(steps * denoise) sigmas1 = sigmas[:-(total_steps)] sigmas2 = sigmas[-(total_steps + 1):] - return (sigmas1, sigmas2) + return io.NodeOutput(sigmas1, sigmas2) -class FlipSigmas: + get_sigmas = execute + +class FlipSigmas(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"sigmas": ("SIGMAS", ), - } - } - RETURN_TYPES = ("SIGMAS",) - CATEGORY = "sampling/custom_sampling/sigmas" + def define_schema(cls): + return io.Schema( + node_id="FlipSigmas", + category="sampling/custom_sampling/sigmas", + inputs=[io.Sigmas.Input("sigmas")], + outputs=[io.Sigmas.Output()] + ) - FUNCTION = "get_sigmas" - - def get_sigmas(self, sigmas): + @classmethod + def execute(cls, sigmas) -> io.NodeOutput: if len(sigmas) == 0: - return (sigmas,) + return io.NodeOutput(sigmas) sigmas = sigmas.flip(0) if sigmas[0] == 0: sigmas[0] = 0.0001 - return (sigmas,) + return io.NodeOutput(sigmas) -class SetFirstSigma: + get_sigmas = execute + +class SetFirstSigma(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"sigmas": ("SIGMAS", ), - "sigma": ("FLOAT", {"default": 136.0, "min": 0.0, "max": 20000.0, "step": 0.001, "round": False}), - } - } - RETURN_TYPES = ("SIGMAS",) - CATEGORY = "sampling/custom_sampling/sigmas" + def define_schema(cls): + return io.Schema( + node_id="SetFirstSigma", + category="sampling/custom_sampling/sigmas", + inputs=[ + io.Sigmas.Input("sigmas"), + io.Float.Input("sigma", default=136.0, min=0.0, max=20000.0, step=0.001, round=False), + ], + outputs=[io.Sigmas.Output()] + ) - FUNCTION = "set_first_sigma" - - def set_first_sigma(self, sigmas, sigma): + @classmethod + def execute(cls, sigmas, sigma) -> io.NodeOutput: sigmas = sigmas.clone() sigmas[0] = sigma - return (sigmas, ) + return io.NodeOutput(sigmas) -class ExtendIntermediateSigmas: + set_first_sigma = execute + +class ExtendIntermediateSigmas(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"sigmas": ("SIGMAS", ), - "steps": ("INT", {"default": 2, "min": 1, "max": 100}), - "start_at_sigma": ("FLOAT", {"default": -1.0, "min": -1.0, "max": 20000.0, "step": 0.01, "round": False}), - "end_at_sigma": ("FLOAT", {"default": 12.0, "min": 0.0, "max": 20000.0, "step": 0.01, "round": False}), - "spacing": (['linear', 'cosine', 'sine'],), - } - } - RETURN_TYPES = ("SIGMAS",) - CATEGORY = "sampling/custom_sampling/sigmas" + def define_schema(cls): + return io.Schema( + node_id="ExtendIntermediateSigmas", + category="sampling/custom_sampling/sigmas", + inputs=[ + io.Sigmas.Input("sigmas"), + io.Int.Input("steps", default=2, min=1, max=100), + io.Float.Input("start_at_sigma", default=-1.0, min=-1.0, max=20000.0, step=0.01, round=False), + io.Float.Input("end_at_sigma", default=12.0, min=0.0, max=20000.0, step=0.01, round=False), + io.Combo.Input("spacing", options=['linear', 'cosine', 'sine']), + ], + outputs=[io.Sigmas.Output()] + ) - FUNCTION = "extend" - - def extend(self, sigmas: torch.Tensor, steps: int, start_at_sigma: float, end_at_sigma: float, spacing: str): + @classmethod + def execute(cls, sigmas: torch.Tensor, steps: int, start_at_sigma: float, end_at_sigma: float, spacing: str) -> io.NodeOutput: if start_at_sigma < 0: start_at_sigma = float("inf") @@ -299,27 +339,27 @@ class ExtendIntermediateSigmas: extended_sigmas = torch.FloatTensor(extended_sigmas) - return (extended_sigmas,) + return io.NodeOutput(extended_sigmas) + + extend = execute -class SamplingPercentToSigma: +class SamplingPercentToSigma(io.ComfyNode): @classmethod - def INPUT_TYPES(cls) -> InputTypeDict: - return { - "required": { - "model": (IO.MODEL, {}), - "sampling_percent": (IO.FLOAT, {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.0001}), - "return_actual_sigma": (IO.BOOLEAN, {"default": False, "tooltip": "Return the actual sigma value instead of the value used for interval checks.\nThis only affects results at 0.0 and 1.0."}), - } - } + def define_schema(cls): + return io.Schema( + node_id="SamplingPercentToSigma", + category="sampling/custom_sampling/sigmas", + inputs=[ + io.Model.Input("model"), + io.Float.Input("sampling_percent", default=0.0, min=0.0, max=1.0, step=0.0001), + io.Boolean.Input("return_actual_sigma", default=False, tooltip="Return the actual sigma value instead of the value used for interval checks.\nThis only affects results at 0.0 and 1.0."), + ], + outputs=[io.Float.Output(display_name="sigma_value")] + ) - RETURN_TYPES = (IO.FLOAT,) - RETURN_NAMES = ("sigma_value",) - CATEGORY = "sampling/custom_sampling/sigmas" - - FUNCTION = "get_sigma" - - def get_sigma(self, model, sampling_percent, return_actual_sigma): + @classmethod + def execute(cls, model, sampling_percent, return_actual_sigma) -> io.NodeOutput: model_sampling = model.get_model_object("model_sampling") sigma_val = model_sampling.percent_to_sigma(sampling_percent) if return_actual_sigma: @@ -327,212 +367,234 @@ class SamplingPercentToSigma: sigma_val = model_sampling.sigma_max.item() elif sampling_percent == 1.0: sigma_val = model_sampling.sigma_min.item() - return (sigma_val,) + return io.NodeOutput(sigma_val) + + get_sigma = execute -class KSamplerSelect: +class KSamplerSelect(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"sampler_name": (comfy.samplers.SAMPLER_NAMES, ), - } - } - RETURN_TYPES = ("SAMPLER",) - CATEGORY = "sampling/custom_sampling/samplers" + def define_schema(cls): + return io.Schema( + node_id="KSamplerSelect", + category="sampling/custom_sampling/samplers", + inputs=[io.Combo.Input("sampler_name", options=comfy.samplers.SAMPLER_NAMES)], + outputs=[io.Sampler.Output()] + ) - FUNCTION = "get_sampler" - - def get_sampler(self, sampler_name): + @classmethod + def execute(cls, sampler_name) -> io.NodeOutput: sampler = comfy.samplers.sampler_object(sampler_name) - return (sampler, ) + return io.NodeOutput(sampler) -class SamplerDPMPP_3M_SDE: + get_sampler = execute + +class SamplerDPMPP_3M_SDE(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"eta": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), - "s_noise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), - "noise_device": (['gpu', 'cpu'], ), - } - } - RETURN_TYPES = ("SAMPLER",) - CATEGORY = "sampling/custom_sampling/samplers" + def define_schema(cls): + return io.Schema( + node_id="SamplerDPMPP_3M_SDE", + category="sampling/custom_sampling/samplers", + inputs=[ + io.Float.Input("eta", default=1.0, min=0.0, max=100.0, step=0.01, round=False), + io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False), + io.Combo.Input("noise_device", options=['gpu', 'cpu']), + ], + outputs=[io.Sampler.Output()] + ) - FUNCTION = "get_sampler" - - def get_sampler(self, eta, s_noise, noise_device): + @classmethod + def execute(cls, eta, s_noise, noise_device) -> io.NodeOutput: if noise_device == 'cpu': sampler_name = "dpmpp_3m_sde" else: sampler_name = "dpmpp_3m_sde_gpu" sampler = comfy.samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise}) - return (sampler, ) + return io.NodeOutput(sampler) -class SamplerDPMPP_2M_SDE: + get_sampler = execute + +class SamplerDPMPP_2M_SDE(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"solver_type": (['midpoint', 'heun'], ), - "eta": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), - "s_noise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), - "noise_device": (['gpu', 'cpu'], ), - } - } - RETURN_TYPES = ("SAMPLER",) - CATEGORY = "sampling/custom_sampling/samplers" + def define_schema(cls): + return io.Schema( + node_id="SamplerDPMPP_2M_SDE", + category="sampling/custom_sampling/samplers", + inputs=[ + io.Combo.Input("solver_type", options=['midpoint', 'heun']), + io.Float.Input("eta", default=1.0, min=0.0, max=100.0, step=0.01, round=False), + io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False), + io.Combo.Input("noise_device", options=['gpu', 'cpu']), + ], + outputs=[io.Sampler.Output()] + ) - FUNCTION = "get_sampler" - - def get_sampler(self, solver_type, eta, s_noise, noise_device): + @classmethod + def execute(cls, solver_type, eta, s_noise, noise_device) -> io.NodeOutput: if noise_device == 'cpu': sampler_name = "dpmpp_2m_sde" else: sampler_name = "dpmpp_2m_sde_gpu" sampler = comfy.samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise, "solver_type": solver_type}) - return (sampler, ) + return io.NodeOutput(sampler) + + get_sampler = execute -class SamplerDPMPP_SDE: +class SamplerDPMPP_SDE(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"eta": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), - "s_noise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), - "r": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), - "noise_device": (['gpu', 'cpu'], ), - } - } - RETURN_TYPES = ("SAMPLER",) - CATEGORY = "sampling/custom_sampling/samplers" + def define_schema(cls): + return io.Schema( + node_id="SamplerDPMPP_SDE", + category="sampling/custom_sampling/samplers", + inputs=[ + io.Float.Input("eta", default=1.0, min=0.0, max=100.0, step=0.01, round=False), + io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False), + io.Float.Input("r", default=0.5, min=0.0, max=100.0, step=0.01, round=False), + io.Combo.Input("noise_device", options=['gpu', 'cpu']), + ], + outputs=[io.Sampler.Output()] + ) - FUNCTION = "get_sampler" - - def get_sampler(self, eta, s_noise, r, noise_device): + @classmethod + def execute(cls, eta, s_noise, r, noise_device) -> io.NodeOutput: if noise_device == 'cpu': sampler_name = "dpmpp_sde" else: sampler_name = "dpmpp_sde_gpu" sampler = comfy.samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise, "r": r}) - return (sampler, ) + return io.NodeOutput(sampler) -class SamplerDPMPP_2S_Ancestral: + get_sampler = execute + +class SamplerDPMPP_2S_Ancestral(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"eta": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), - "s_noise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), - } - } - RETURN_TYPES = ("SAMPLER",) - CATEGORY = "sampling/custom_sampling/samplers" + def define_schema(cls): + return io.Schema( + node_id="SamplerDPMPP_2S_Ancestral", + category="sampling/custom_sampling/samplers", + inputs=[ + io.Float.Input("eta", default=1.0, min=0.0, max=100.0, step=0.01, round=False), + io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False), + ], + outputs=[io.Sampler.Output()] + ) - FUNCTION = "get_sampler" - - def get_sampler(self, eta, s_noise): + @classmethod + def execute(cls, eta, s_noise) -> io.NodeOutput: sampler = comfy.samplers.ksampler("dpmpp_2s_ancestral", {"eta": eta, "s_noise": s_noise}) - return (sampler, ) + return io.NodeOutput(sampler) -class SamplerEulerAncestral: + get_sampler = execute + +class SamplerEulerAncestral(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"eta": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), - "s_noise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), - } - } - RETURN_TYPES = ("SAMPLER",) - CATEGORY = "sampling/custom_sampling/samplers" + def define_schema(cls): + return io.Schema( + node_id="SamplerEulerAncestral", + category="sampling/custom_sampling/samplers", + inputs=[ + io.Float.Input("eta", default=1.0, min=0.0, max=100.0, step=0.01, round=False), + io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False), + ], + outputs=[io.Sampler.Output()] + ) - FUNCTION = "get_sampler" - - def get_sampler(self, eta, s_noise): + @classmethod + def execute(cls, eta, s_noise) -> io.NodeOutput: sampler = comfy.samplers.ksampler("euler_ancestral", {"eta": eta, "s_noise": s_noise}) - return (sampler, ) + return io.NodeOutput(sampler) -class SamplerEulerAncestralCFGPP: + get_sampler = execute + +class SamplerEulerAncestralCFGPP(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return { - "required": { - "eta": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step":0.01, "round": False}), - "s_noise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step":0.01, "round": False}), - }} - RETURN_TYPES = ("SAMPLER",) - CATEGORY = "sampling/custom_sampling/samplers" + def define_schema(cls): + return io.Schema( + node_id="SamplerEulerAncestralCFGPP", + display_name="SamplerEulerAncestralCFG++", + category="sampling/custom_sampling/samplers", + inputs=[ + io.Float.Input("eta", default=1.0, min=0.0, max=1.0, step=0.01, round=False), + io.Float.Input("s_noise", default=1.0, min=0.0, max=10.0, step=0.01, round=False), + ], + outputs=[io.Sampler.Output()] + ) - FUNCTION = "get_sampler" - - def get_sampler(self, eta, s_noise): + @classmethod + def execute(cls, eta, s_noise) -> io.NodeOutput: sampler = comfy.samplers.ksampler( "euler_ancestral_cfg_pp", {"eta": eta, "s_noise": s_noise}) - return (sampler, ) + return io.NodeOutput(sampler) -class SamplerLMS: + get_sampler = execute + +class SamplerLMS(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"order": ("INT", {"default": 4, "min": 1, "max": 100}), - } - } - RETURN_TYPES = ("SAMPLER",) - CATEGORY = "sampling/custom_sampling/samplers" + def define_schema(cls): + return io.Schema( + node_id="SamplerLMS", + category="sampling/custom_sampling/samplers", + inputs=[io.Int.Input("order", default=4, min=1, max=100)], + outputs=[io.Sampler.Output()] + ) - FUNCTION = "get_sampler" - - def get_sampler(self, order): + @classmethod + def execute(cls, order) -> io.NodeOutput: sampler = comfy.samplers.ksampler("lms", {"order": order}) - return (sampler, ) + return io.NodeOutput(sampler) -class SamplerDPMAdaptative: + get_sampler = execute + +class SamplerDPMAdaptative(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"order": ("INT", {"default": 3, "min": 2, "max": 3}), - "rtol": ("FLOAT", {"default": 0.05, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), - "atol": ("FLOAT", {"default": 0.0078, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), - "h_init": ("FLOAT", {"default": 0.05, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), - "pcoeff": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), - "icoeff": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), - "dcoeff": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), - "accept_safety": ("FLOAT", {"default": 0.81, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), - "eta": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), - "s_noise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), - } - } - RETURN_TYPES = ("SAMPLER",) - CATEGORY = "sampling/custom_sampling/samplers" + def define_schema(cls): + return io.Schema( + node_id="SamplerDPMAdaptative", + category="sampling/custom_sampling/samplers", + inputs=[ + io.Int.Input("order", default=3, min=2, max=3), + io.Float.Input("rtol", default=0.05, min=0.0, max=100.0, step=0.01, round=False), + io.Float.Input("atol", default=0.0078, min=0.0, max=100.0, step=0.01, round=False), + io.Float.Input("h_init", default=0.05, min=0.0, max=100.0, step=0.01, round=False), + io.Float.Input("pcoeff", default=0.0, min=0.0, max=100.0, step=0.01, round=False), + io.Float.Input("icoeff", default=1.0, min=0.0, max=100.0, step=0.01, round=False), + io.Float.Input("dcoeff", default=0.0, min=0.0, max=100.0, step=0.01, round=False), + io.Float.Input("accept_safety", default=0.81, min=0.0, max=100.0, step=0.01, round=False), + io.Float.Input("eta", default=0.0, min=0.0, max=100.0, step=0.01, round=False), + io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False), + ], + outputs=[io.Sampler.Output()] + ) - FUNCTION = "get_sampler" - - def get_sampler(self, order, rtol, atol, h_init, pcoeff, icoeff, dcoeff, accept_safety, eta, s_noise): + @classmethod + def execute(cls, order, rtol, atol, h_init, pcoeff, icoeff, dcoeff, accept_safety, eta, s_noise) -> io.NodeOutput: sampler = comfy.samplers.ksampler("dpm_adaptive", {"order": order, "rtol": rtol, "atol": atol, "h_init": h_init, "pcoeff": pcoeff, "icoeff": icoeff, "dcoeff": dcoeff, "accept_safety": accept_safety, "eta": eta, "s_noise":s_noise }) - return (sampler, ) + return io.NodeOutput(sampler) + + get_sampler = execute -class SamplerER_SDE(ComfyNodeABC): +class SamplerER_SDE(io.ComfyNode): @classmethod - def INPUT_TYPES(cls) -> InputTypeDict: - return { - "required": { - "solver_type": (IO.COMBO, {"options": ["ER-SDE", "Reverse-time SDE", "ODE"]}), - "max_stage": (IO.INT, {"default": 3, "min": 1, "max": 3}), - "eta": ( - IO.FLOAT, - {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01, "round": False, "tooltip": "Stochastic strength of reverse-time SDE.\nWhen eta=0, it reduces to deterministic ODE. This setting doesn't apply to ER-SDE solver type."}, - ), - "s_noise": (IO.FLOAT, {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01, "round": False}), - } - } + def define_schema(cls): + return io.Schema( + node_id="SamplerER_SDE", + category="sampling/custom_sampling/samplers", + inputs=[ + io.Combo.Input("solver_type", options=["ER-SDE", "Reverse-time SDE", "ODE"]), + io.Int.Input("max_stage", default=3, min=1, max=3), + io.Float.Input("eta", default=1.0, min=0.0, max=100.0, step=0.01, round=False, tooltip="Stochastic strength of reverse-time SDE.\nWhen eta=0, it reduces to deterministic ODE. This setting doesn't apply to ER-SDE solver type."), + io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False), + ], + outputs=[io.Sampler.Output()] + ) - RETURN_TYPES = (IO.SAMPLER,) - CATEGORY = "sampling/custom_sampling/samplers" - - FUNCTION = "get_sampler" - - def get_sampler(self, solver_type, max_stage, eta, s_noise): + @classmethod + def execute(cls, solver_type, max_stage, eta, s_noise) -> io.NodeOutput: if solver_type == "ODE" or (solver_type == "Reverse-time SDE" and eta == 0): eta = 0 s_noise = 0 @@ -548,32 +610,33 @@ class SamplerER_SDE(ComfyNodeABC): sampler_name = "er_sde" sampler = comfy.samplers.ksampler(sampler_name, {"s_noise": s_noise, "noise_scaler": noise_scaler, "max_stage": max_stage}) - return (sampler,) + return io.NodeOutput(sampler) + + get_sampler = execute -class SamplerSASolver(ComfyNodeABC): +class SamplerSASolver(io.ComfyNode): @classmethod - def INPUT_TYPES(cls) -> InputTypeDict: - return { - "required": { - "model": (IO.MODEL, {}), - "eta": (IO.FLOAT, {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01, "round": False},), - "sde_start_percent": (IO.FLOAT, {"default": 0.2, "min": 0.0, "max": 1.0, "step": 0.001},), - "sde_end_percent": (IO.FLOAT, {"default": 0.8, "min": 0.0, "max": 1.0, "step": 0.001},), - "s_noise": (IO.FLOAT, {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01, "round": False},), - "predictor_order": (IO.INT, {"default": 3, "min": 1, "max": 6}), - "corrector_order": (IO.INT, {"default": 4, "min": 0, "max": 6}), - "use_pece": (IO.BOOLEAN, {}), - "simple_order_2": (IO.BOOLEAN, {}), - } - } + def define_schema(cls): + return io.Schema( + node_id="SamplerSASolver", + category="sampling/custom_sampling/samplers", + inputs=[ + io.Model.Input("model"), + io.Float.Input("eta", default=1.0, min=0.0, max=10.0, step=0.01, round=False), + io.Float.Input("sde_start_percent", default=0.2, min=0.0, max=1.0, step=0.001), + io.Float.Input("sde_end_percent", default=0.8, min=0.0, max=1.0, step=0.001), + io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False), + io.Int.Input("predictor_order", default=3, min=1, max=6), + io.Int.Input("corrector_order", default=4, min=0, max=6), + io.Boolean.Input("use_pece"), + io.Boolean.Input("simple_order_2"), + ], + outputs=[io.Sampler.Output()] + ) - RETURN_TYPES = (IO.SAMPLER,) - CATEGORY = "sampling/custom_sampling/samplers" - - FUNCTION = "get_sampler" - - def get_sampler(self, model, eta, sde_start_percent, sde_end_percent, s_noise, predictor_order, corrector_order, use_pece, simple_order_2): + @classmethod + def execute(cls, model, eta, sde_start_percent, sde_end_percent, s_noise, predictor_order, corrector_order, use_pece, simple_order_2) -> io.NodeOutput: model_sampling = model.get_model_object("model_sampling") start_sigma = model_sampling.percent_to_sigma(sde_start_percent) end_sigma = model_sampling.percent_to_sigma(sde_end_percent) @@ -591,7 +654,9 @@ class SamplerSASolver(ComfyNodeABC): "simple_order_2": simple_order_2, }, ) - return (sampler,) + return io.NodeOutput(sampler) + + get_sampler = execute class Noise_EmptyNoise: @@ -612,30 +677,31 @@ class Noise_RandomNoise: batch_inds = input_latent["batch_index"] if "batch_index" in input_latent else None return comfy.sample.prepare_noise(latent_image, self.seed, batch_inds) -class SamplerCustom: +class SamplerCustom(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"model": ("MODEL",), - "add_noise": ("BOOLEAN", {"default": True}), - "noise_seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff, "control_after_generate": True}), - "cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01}), - "positive": ("CONDITIONING", ), - "negative": ("CONDITIONING", ), - "sampler": ("SAMPLER", ), - "sigmas": ("SIGMAS", ), - "latent_image": ("LATENT", ), - } - } + def define_schema(cls): + return io.Schema( + node_id="SamplerCustom", + category="sampling/custom_sampling", + inputs=[ + io.Model.Input("model"), + io.Boolean.Input("add_noise", default=True), + io.Int.Input("noise_seed", default=0, min=0, max=0xffffffffffffffff, control_after_generate=True), + io.Float.Input("cfg", default=8.0, min=0.0, max=100.0, step=0.1, round=0.01), + io.Conditioning.Input("positive"), + io.Conditioning.Input("negative"), + io.Sampler.Input("sampler"), + io.Sigmas.Input("sigmas"), + io.Latent.Input("latent_image"), + ], + outputs=[ + io.Latent.Output(display_name="output"), + io.Latent.Output(display_name="denoised_output"), + ] + ) - RETURN_TYPES = ("LATENT","LATENT") - RETURN_NAMES = ("output", "denoised_output") - - FUNCTION = "sample" - - CATEGORY = "sampling/custom_sampling" - - def sample(self, model, add_noise, noise_seed, cfg, positive, negative, sampler, sigmas, latent_image): + @classmethod + def execute(cls, model, add_noise, noise_seed, cfg, positive, negative, sampler, sigmas, latent_image) -> io.NodeOutput: latent = latent_image latent_image = latent["samples"] latent = latent.copy() @@ -664,52 +730,58 @@ class SamplerCustom: out_denoised["samples"] = model.model.process_latent_out(x0_output["x0"].cpu()) else: out_denoised = out - return (out, out_denoised) + return io.NodeOutput(out, out_denoised) + + sample = execute class Guider_Basic(comfy.samplers.CFGGuider): def set_conds(self, positive): self.inner_set_conds({"positive": positive}) -class BasicGuider: +class BasicGuider(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"model": ("MODEL",), - "conditioning": ("CONDITIONING", ), - } - } + def define_schema(cls): + return io.Schema( + node_id="BasicGuider", + category="sampling/custom_sampling/guiders", + inputs=[ + io.Model.Input("model"), + io.Conditioning.Input("conditioning"), + ], + outputs=[io.Guider.Output()] + ) - RETURN_TYPES = ("GUIDER",) - - FUNCTION = "get_guider" - CATEGORY = "sampling/custom_sampling/guiders" - - def get_guider(self, model, conditioning): + @classmethod + def execute(cls, model, conditioning) -> io.NodeOutput: guider = Guider_Basic(model) guider.set_conds(conditioning) - return (guider,) + return io.NodeOutput(guider) -class CFGGuider: + get_guider = execute + +class CFGGuider(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"model": ("MODEL",), - "positive": ("CONDITIONING", ), - "negative": ("CONDITIONING", ), - "cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01}), - } - } + def define_schema(cls): + return io.Schema( + node_id="CFGGuider", + category="sampling/custom_sampling/guiders", + inputs=[ + io.Model.Input("model"), + io.Conditioning.Input("positive"), + io.Conditioning.Input("negative"), + io.Float.Input("cfg", default=8.0, min=0.0, max=100.0, step=0.1, round=0.01), + ], + outputs=[io.Guider.Output()] + ) - RETURN_TYPES = ("GUIDER",) - - FUNCTION = "get_guider" - CATEGORY = "sampling/custom_sampling/guiders" - - def get_guider(self, model, positive, negative, cfg): + @classmethod + def execute(cls, model, positive, negative, cfg) -> io.NodeOutput: guider = comfy.samplers.CFGGuider(model) guider.set_conds(positive, negative) guider.set_cfg(cfg) - return (guider,) + return io.NodeOutput(guider) + + get_guider = execute class Guider_DualCFG(comfy.samplers.CFGGuider): def set_cfg(self, cfg1, cfg2, nested=False): @@ -740,84 +812,88 @@ class Guider_DualCFG(comfy.samplers.CFGGuider): out = comfy.samplers.calc_cond_batch(self.inner_model, [negative_cond, middle_cond, positive_cond], x, timestep, model_options) return comfy.samplers.cfg_function(self.inner_model, out[1], out[0], self.cfg2, x, timestep, model_options=model_options, cond=middle_cond, uncond=negative_cond) + (out[2] - out[1]) * self.cfg1 -class DualCFGGuider: +class DualCFGGuider(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"model": ("MODEL",), - "cond1": ("CONDITIONING", ), - "cond2": ("CONDITIONING", ), - "negative": ("CONDITIONING", ), - "cfg_conds": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01}), - "cfg_cond2_negative": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01}), - "style": (["regular", "nested"],), - } - } + def define_schema(cls): + return io.Schema( + node_id="DualCFGGuider", + category="sampling/custom_sampling/guiders", + inputs=[ + io.Model.Input("model"), + io.Conditioning.Input("cond1"), + io.Conditioning.Input("cond2"), + io.Conditioning.Input("negative"), + io.Float.Input("cfg_conds", default=8.0, min=0.0, max=100.0, step=0.1, round=0.01), + io.Float.Input("cfg_cond2_negative", default=8.0, min=0.0, max=100.0, step=0.1, round=0.01), + io.Combo.Input("style", options=["regular", "nested"]), + ], + outputs=[io.Guider.Output()] + ) - RETURN_TYPES = ("GUIDER",) - - FUNCTION = "get_guider" - CATEGORY = "sampling/custom_sampling/guiders" - - def get_guider(self, model, cond1, cond2, negative, cfg_conds, cfg_cond2_negative, style): + @classmethod + def execute(cls, model, cond1, cond2, negative, cfg_conds, cfg_cond2_negative, style) -> io.NodeOutput: guider = Guider_DualCFG(model) guider.set_conds(cond1, cond2, negative) guider.set_cfg(cfg_conds, cfg_cond2_negative, nested=(style == "nested")) - return (guider,) + return io.NodeOutput(guider) -class DisableNoise: + get_guider = execute + +class DisableNoise(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required":{ - } - } + def define_schema(cls): + return io.Schema( + node_id="DisableNoise", + category="sampling/custom_sampling/noise", + inputs=[], + outputs=[io.Noise.Output()] + ) - RETURN_TYPES = ("NOISE",) - FUNCTION = "get_noise" - CATEGORY = "sampling/custom_sampling/noise" - - def get_noise(self): - return (Noise_EmptyNoise(),) - - -class RandomNoise(DisableNoise): @classmethod - def INPUT_TYPES(s): - return { - "required": { - "noise_seed": ("INT", { - "default": 0, - "min": 0, - "max": 0xffffffffffffffff, - "control_after_generate": True, - }), - } - } + def execute(cls) -> io.NodeOutput: + return io.NodeOutput(Noise_EmptyNoise()) - def get_noise(self, noise_seed): - return (Noise_RandomNoise(noise_seed),) + get_noise = execute -class SamplerCustomAdvanced: +class RandomNoise(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"noise": ("NOISE", ), - "guider": ("GUIDER", ), - "sampler": ("SAMPLER", ), - "sigmas": ("SIGMAS", ), - "latent_image": ("LATENT", ), - } - } + def define_schema(cls): + return io.Schema( + node_id="RandomNoise", + category="sampling/custom_sampling/noise", + inputs=[io.Int.Input("noise_seed", default=0, min=0, max=0xffffffffffffffff, control_after_generate=True)], + outputs=[io.Noise.Output()] + ) - RETURN_TYPES = ("LATENT","LATENT") - RETURN_NAMES = ("output", "denoised_output") + @classmethod + def execute(cls, noise_seed) -> io.NodeOutput: + return io.NodeOutput(Noise_RandomNoise(noise_seed)) - FUNCTION = "sample" + get_noise = execute - CATEGORY = "sampling/custom_sampling" - def sample(self, noise, guider, sampler, sigmas, latent_image): +class SamplerCustomAdvanced(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SamplerCustomAdvanced", + category="sampling/custom_sampling", + inputs=[ + io.Noise.Input("noise"), + io.Guider.Input("guider"), + io.Sampler.Input("sampler"), + io.Sigmas.Input("sigmas"), + io.Latent.Input("latent_image"), + ], + outputs=[ + io.Latent.Output(display_name="output"), + io.Latent.Output(display_name="denoised_output"), + ] + ) + + @classmethod + def execute(cls, noise, guider, sampler, sigmas, latent_image) -> io.NodeOutput: latent = latent_image latent_image = latent["samples"] latent = latent.copy() @@ -842,28 +918,32 @@ class SamplerCustomAdvanced: out_denoised["samples"] = guider.model_patcher.model.process_latent_out(x0_output["x0"].cpu()) else: out_denoised = out - return (out, out_denoised) + return io.NodeOutput(out, out_denoised) -class AddNoise: + sample = execute + +class AddNoise(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"model": ("MODEL",), - "noise": ("NOISE", ), - "sigmas": ("SIGMAS", ), - "latent_image": ("LATENT", ), - } - } + def define_schema(cls): + return io.Schema( + node_id="AddNoise", + category="_for_testing/custom_sampling/noise", + is_experimental=True, + inputs=[ + io.Model.Input("model"), + io.Noise.Input("noise"), + io.Sigmas.Input("sigmas"), + io.Latent.Input("latent_image"), + ], + outputs=[ + io.Latent.Output(), + ] + ) - RETURN_TYPES = ("LATENT",) - - FUNCTION = "add_noise" - - CATEGORY = "_for_testing/custom_sampling/noise" - - def add_noise(self, model, noise, sigmas, latent_image): + @classmethod + def execute(cls, model, noise, sigmas, latent_image) -> io.NodeOutput: if len(sigmas) == 0: - return latent_image + return io.NodeOutput(latent_image) latent = latent_image latent_image = latent["samples"] @@ -887,46 +967,50 @@ class AddNoise: out = latent.copy() out["samples"] = noisy - return (out,) + return io.NodeOutput(out) + + add_noise = execute -NODE_CLASS_MAPPINGS = { - "SamplerCustom": SamplerCustom, - "BasicScheduler": BasicScheduler, - "KarrasScheduler": KarrasScheduler, - "ExponentialScheduler": ExponentialScheduler, - "PolyexponentialScheduler": PolyexponentialScheduler, - "LaplaceScheduler": LaplaceScheduler, - "VPScheduler": VPScheduler, - "BetaSamplingScheduler": BetaSamplingScheduler, - "SDTurboScheduler": SDTurboScheduler, - "KSamplerSelect": KSamplerSelect, - "SamplerEulerAncestral": SamplerEulerAncestral, - "SamplerEulerAncestralCFGPP": SamplerEulerAncestralCFGPP, - "SamplerLMS": SamplerLMS, - "SamplerDPMPP_3M_SDE": SamplerDPMPP_3M_SDE, - "SamplerDPMPP_2M_SDE": SamplerDPMPP_2M_SDE, - "SamplerDPMPP_SDE": SamplerDPMPP_SDE, - "SamplerDPMPP_2S_Ancestral": SamplerDPMPP_2S_Ancestral, - "SamplerDPMAdaptative": SamplerDPMAdaptative, - "SamplerER_SDE": SamplerER_SDE, - "SamplerSASolver": SamplerSASolver, - "SplitSigmas": SplitSigmas, - "SplitSigmasDenoise": SplitSigmasDenoise, - "FlipSigmas": FlipSigmas, - "SetFirstSigma": SetFirstSigma, - "ExtendIntermediateSigmas": ExtendIntermediateSigmas, - "SamplingPercentToSigma": SamplingPercentToSigma, +class CustomSamplersExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + SamplerCustom, + BasicScheduler, + KarrasScheduler, + ExponentialScheduler, + PolyexponentialScheduler, + LaplaceScheduler, + VPScheduler, + BetaSamplingScheduler, + SDTurboScheduler, + KSamplerSelect, + SamplerEulerAncestral, + SamplerEulerAncestralCFGPP, + SamplerLMS, + SamplerDPMPP_3M_SDE, + SamplerDPMPP_2M_SDE, + SamplerDPMPP_SDE, + SamplerDPMPP_2S_Ancestral, + SamplerDPMAdaptative, + SamplerER_SDE, + SamplerSASolver, + SplitSigmas, + SplitSigmasDenoise, + FlipSigmas, + SetFirstSigma, + ExtendIntermediateSigmas, + SamplingPercentToSigma, + CFGGuider, + DualCFGGuider, + BasicGuider, + RandomNoise, + DisableNoise, + AddNoise, + SamplerCustomAdvanced, + ] - "CFGGuider": CFGGuider, - "DualCFGGuider": DualCFGGuider, - "BasicGuider": BasicGuider, - "RandomNoise": RandomNoise, - "DisableNoise": DisableNoise, - "AddNoise": AddNoise, - "SamplerCustomAdvanced": SamplerCustomAdvanced, -} -NODE_DISPLAY_NAME_MAPPINGS = { - "SamplerEulerAncestralCFGPP": "SamplerEulerAncestralCFG++", -} +async def comfy_entrypoint() -> CustomSamplersExtension: + return CustomSamplersExtension() From cc6a8dcd1ad9cc9ef7602ee141174a0cea0ed4ce Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Thu, 27 Nov 2025 08:18:08 +0800 Subject: [PATCH 161/231] Dataset Processing Nodes and Improved LoRA Trainer Nodes with multi resolution supports. (#10708) * Create nodes_dataset.py * Add encoded dataset caching mechanism * make training node to work with our dataset system * allow trainer node to get different resolution dataset * move all dataset related implementation to nodes_dataset * Rewrite dataset system with new io schema * Rewrite training system with new io schema * add ui pbar * Add outputs' id/name * Fix bad id/naming * use single process instead of input list when no need * fix wrong output_list flag * use torch.load/save and fix bad behaviors --- comfy_extras/nodes_dataset.py | 1532 +++++++++++++++++++++++++++++++++ comfy_extras/nodes_train.py | 967 ++++++++++----------- nodes.py | 1 + 3 files changed, 1980 insertions(+), 520 deletions(-) create mode 100644 comfy_extras/nodes_dataset.py diff --git a/comfy_extras/nodes_dataset.py b/comfy_extras/nodes_dataset.py new file mode 100644 index 000000000..b23867505 --- /dev/null +++ b/comfy_extras/nodes_dataset.py @@ -0,0 +1,1532 @@ +import logging +import os +import math +import json + +import numpy as np +import torch +from PIL import Image +from typing_extensions import override + +import folder_paths +import node_helpers +from comfy_api.latest import ComfyExtension, io + + +def load_and_process_images(image_files, input_dir): + """Utility function to load and process a list of images. + + Args: + image_files: List of image filenames + input_dir: Base directory containing the images + resize_method: How to handle images of different sizes ("None", "Stretch", "Crop", "Pad") + + Returns: + torch.Tensor: Batch of processed images + """ + if not image_files: + raise ValueError("No valid images found in input") + + output_images = [] + + for file in image_files: + image_path = os.path.join(input_dir, file) + img = node_helpers.pillow(Image.open, image_path) + + if img.mode == "I": + img = img.point(lambda i: i * (1 / 255)) + img = img.convert("RGB") + img_array = np.array(img).astype(np.float32) / 255.0 + img_tensor = torch.from_numpy(img_array)[None,] + output_images.append(img_tensor) + + return output_images + + +class LoadImageDataSetFromFolderNode(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="LoadImageDataSetFromFolder", + display_name="Load Image Dataset from Folder", + category="dataset", + is_experimental=True, + inputs=[ + io.Combo.Input( + "folder", + options=folder_paths.get_input_subfolders(), + tooltip="The folder to load images from.", + ) + ], + outputs=[ + io.Image.Output( + display_name="images", + is_output_list=True, + tooltip="List of loaded images", + ) + ], + ) + + @classmethod + def execute(cls, folder): + sub_input_dir = os.path.join(folder_paths.get_input_directory(), folder) + valid_extensions = [".png", ".jpg", ".jpeg", ".webp"] + image_files = [ + f + for f in os.listdir(sub_input_dir) + if any(f.lower().endswith(ext) for ext in valid_extensions) + ] + output_tensor = load_and_process_images(image_files, sub_input_dir) + return io.NodeOutput(output_tensor) + + +class LoadImageTextDataSetFromFolderNode(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="LoadImageTextDataSetFromFolder", + display_name="Load Image and Text Dataset from Folder", + category="dataset", + is_experimental=True, + inputs=[ + io.Combo.Input( + "folder", + options=folder_paths.get_input_subfolders(), + tooltip="The folder to load images from.", + ) + ], + outputs=[ + io.Image.Output( + display_name="images", + is_output_list=True, + tooltip="List of loaded images", + ), + io.String.Output( + display_name="texts", + is_output_list=True, + tooltip="List of text captions", + ), + ], + ) + + @classmethod + def execute(cls, folder): + logging.info(f"Loading images from folder: {folder}") + + sub_input_dir = os.path.join(folder_paths.get_input_directory(), folder) + valid_extensions = [".png", ".jpg", ".jpeg", ".webp"] + + image_files = [] + for item in os.listdir(sub_input_dir): + path = os.path.join(sub_input_dir, item) + if any(item.lower().endswith(ext) for ext in valid_extensions): + image_files.append(path) + elif os.path.isdir(path): + # Support kohya-ss/sd-scripts folder structure + repeat = 1 + if item.split("_")[0].isdigit(): + repeat = int(item.split("_")[0]) + image_files.extend( + [ + os.path.join(path, f) + for f in os.listdir(path) + if any(f.lower().endswith(ext) for ext in valid_extensions) + ] + * repeat + ) + + caption_file_path = [ + f.replace(os.path.splitext(f)[1], ".txt") for f in image_files + ] + captions = [] + for caption_file in caption_file_path: + caption_path = os.path.join(sub_input_dir, caption_file) + if os.path.exists(caption_path): + with open(caption_path, "r", encoding="utf-8") as f: + caption = f.read().strip() + captions.append(caption) + else: + captions.append("") + + output_tensor = load_and_process_images(image_files, sub_input_dir) + + logging.info(f"Loaded {len(output_tensor)} images from {sub_input_dir}.") + return io.NodeOutput(output_tensor, captions) + + +def save_images_to_folder(image_list, output_dir, prefix="image"): + """Utility function to save a list of image tensors to disk. + + Args: + image_list: List of image tensors (each [1, H, W, C] or [H, W, C] or [C, H, W]) + output_dir: Directory to save images to + prefix: Filename prefix + + Returns: + List of saved filenames + """ + os.makedirs(output_dir, exist_ok=True) + saved_files = [] + + for idx, img_tensor in enumerate(image_list): + # Handle different tensor shapes + if isinstance(img_tensor, torch.Tensor): + # Remove batch dimension if present [1, H, W, C] -> [H, W, C] + if img_tensor.dim() == 4 and img_tensor.shape[0] == 1: + img_tensor = img_tensor.squeeze(0) + + # If tensor is [C, H, W], permute to [H, W, C] + if img_tensor.dim() == 3 and img_tensor.shape[0] in [1, 3, 4]: + if ( + img_tensor.shape[0] <= 4 + and img_tensor.shape[1] > 4 + and img_tensor.shape[2] > 4 + ): + img_tensor = img_tensor.permute(1, 2, 0) + + # Convert to numpy and scale to 0-255 + img_array = img_tensor.cpu().numpy() + img_array = np.clip(img_array * 255.0, 0, 255).astype(np.uint8) + + # Convert to PIL Image + img = Image.fromarray(img_array) + else: + raise ValueError(f"Expected torch.Tensor, got {type(img_tensor)}") + + # Save image + filename = f"{prefix}_{idx:05d}.png" + filepath = os.path.join(output_dir, filename) + img.save(filepath) + saved_files.append(filename) + + return saved_files + + +class SaveImageDataSetToFolderNode(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SaveImageDataSetToFolder", + display_name="Save Image Dataset to Folder", + category="dataset", + is_experimental=True, + is_output_node=True, + is_input_list=True, # Receive images as list + inputs=[ + io.Image.Input("images", tooltip="List of images to save."), + io.String.Input( + "folder_name", + default="dataset", + tooltip="Name of the folder to save images to (inside output directory).", + ), + io.String.Input( + "filename_prefix", + default="image", + tooltip="Prefix for saved image filenames.", + ), + ], + outputs=[], + ) + + @classmethod + def execute(cls, images, folder_name, filename_prefix): + # Extract scalar values + folder_name = folder_name[0] + filename_prefix = filename_prefix[0] + + output_dir = os.path.join(folder_paths.get_output_directory(), folder_name) + saved_files = save_images_to_folder(images, output_dir, filename_prefix) + + logging.info(f"Saved {len(saved_files)} images to {output_dir}.") + return io.NodeOutput() + + +class SaveImageTextDataSetToFolderNode(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SaveImageTextDataSetToFolder", + display_name="Save Image and Text Dataset to Folder", + category="dataset", + is_experimental=True, + is_output_node=True, + is_input_list=True, # Receive both images and texts as lists + inputs=[ + io.Image.Input("images", tooltip="List of images to save."), + io.String.Input("texts", tooltip="List of text captions to save."), + io.String.Input( + "folder_name", + default="dataset", + tooltip="Name of the folder to save images to (inside output directory).", + ), + io.String.Input( + "filename_prefix", + default="image", + tooltip="Prefix for saved image filenames.", + ), + ], + outputs=[], + ) + + @classmethod + def execute(cls, images, texts, folder_name, filename_prefix): + # Extract scalar values + folder_name = folder_name[0] + filename_prefix = filename_prefix[0] + + output_dir = os.path.join(folder_paths.get_output_directory(), folder_name) + saved_files = save_images_to_folder(images, output_dir, filename_prefix) + + # Save captions + for idx, (filename, caption) in enumerate(zip(saved_files, texts)): + caption_filename = filename.replace(".png", ".txt") + caption_path = os.path.join(output_dir, caption_filename) + with open(caption_path, "w", encoding="utf-8") as f: + f.write(caption) + + logging.info(f"Saved {len(saved_files)} images and captions to {output_dir}.") + return io.NodeOutput() + + +# ========== Helper Functions for Transform Nodes ========== + + +def tensor_to_pil(img_tensor): + """Convert tensor to PIL Image.""" + if img_tensor.dim() == 4 and img_tensor.shape[0] == 1: + img_tensor = img_tensor.squeeze(0) + img_array = (img_tensor.cpu().numpy() * 255).clip(0, 255).astype(np.uint8) + return Image.fromarray(img_array) + + +def pil_to_tensor(img): + """Convert PIL Image to tensor.""" + img_array = np.array(img).astype(np.float32) / 255.0 + return torch.from_numpy(img_array)[None,] + + +# ========== Base Classes for Transform Nodes ========== + + +class ImageProcessingNode(io.ComfyNode): + """Base class for image processing nodes that operate on images. + + Child classes should set: + node_id: Unique node identifier (required) + display_name: Display name (optional, defaults to node_id) + description: Node description (optional) + extra_inputs: List of additional io.Input objects beyond "images" (optional) + is_group_process: None (auto-detect), True (group), or False (individual) (optional) + is_output_list: True (list output) or False (single output) (optional, default True) + + Child classes must implement ONE of: + _process(cls, image, **kwargs) -> tensor (for single-item processing) + _group_process(cls, images, **kwargs) -> list[tensor] (for group processing) + """ + + node_id = None + display_name = None + description = None + extra_inputs = [] + is_group_process = None # None = auto-detect, True/False = explicit + is_output_list = None # None = auto-detect based on processing mode + + @classmethod + def _detect_processing_mode(cls): + """Detect whether this node uses group or individual processing. + + Returns: + bool: True if group processing, False if individual processing + """ + # Explicit setting takes precedence + if cls.is_group_process is not None: + return cls.is_group_process + + # Check which method is overridden by looking at the defining class in MRO + base_class = ImageProcessingNode + + # Find which class in MRO defines _process + process_definer = None + for klass in cls.__mro__: + if "_process" in klass.__dict__: + process_definer = klass + break + + # Find which class in MRO defines _group_process + group_definer = None + for klass in cls.__mro__: + if "_group_process" in klass.__dict__: + group_definer = klass + break + + # Check what was overridden (not defined in base class) + has_process = process_definer is not None and process_definer is not base_class + has_group = group_definer is not None and group_definer is not base_class + + if has_process and has_group: + raise ValueError( + f"{cls.__name__}: Cannot override both _process and _group_process. " + "Override only one, or set is_group_process explicitly." + ) + if not has_process and not has_group: + raise ValueError( + f"{cls.__name__}: Must override either _process or _group_process" + ) + + return has_group + + @classmethod + def define_schema(cls): + if cls.node_id is None: + raise NotImplementedError(f"{cls.__name__} must set node_id class variable") + + is_group = cls._detect_processing_mode() + + # Auto-detect is_output_list if not explicitly set + # Single processing: False (backend collects results into list) + # Group processing: True by default (can be False for single-output nodes) + output_is_list = ( + cls.is_output_list if cls.is_output_list is not None else is_group + ) + + inputs = [ + io.Image.Input( + "images", + tooltip=( + "List of images to process." if is_group else "Image to process." + ), + ) + ] + inputs.extend(cls.extra_inputs) + + return io.Schema( + node_id=cls.node_id, + display_name=cls.display_name or cls.node_id, + category="dataset/image", + is_experimental=True, + is_input_list=is_group, # True for group, False for individual + inputs=inputs, + outputs=[ + io.Image.Output( + display_name="images", + is_output_list=output_is_list, + tooltip="Processed images", + ) + ], + ) + + @classmethod + def execute(cls, images, **kwargs): + """Execute the node. Routes to _process or _group_process based on mode.""" + is_group = cls._detect_processing_mode() + + # Extract scalar values from lists for parameters + params = {} + for k, v in kwargs.items(): + if isinstance(v, list) and len(v) == 1: + params[k] = v[0] + else: + params[k] = v + + if is_group: + # Group processing: images is list, call _group_process + result = cls._group_process(images, **params) + else: + # Individual processing: images is single item, call _process + result = cls._process(images, **params) + + return io.NodeOutput(result) + + @classmethod + def _process(cls, image, **kwargs): + """Override this method for single-item processing. + + Args: + image: tensor - Single image tensor + **kwargs: Additional parameters (already extracted from lists) + + Returns: + tensor - Processed image + """ + raise NotImplementedError(f"{cls.__name__} must implement _process method") + + @classmethod + def _group_process(cls, images, **kwargs): + """Override this method for group processing. + + Args: + images: list[tensor] - List of image tensors + **kwargs: Additional parameters (already extracted from lists) + + Returns: + list[tensor] - Processed images + """ + raise NotImplementedError( + f"{cls.__name__} must implement _group_process method" + ) + + +class TextProcessingNode(io.ComfyNode): + """Base class for text processing nodes that operate on texts. + + Child classes should set: + node_id: Unique node identifier (required) + display_name: Display name (optional, defaults to node_id) + description: Node description (optional) + extra_inputs: List of additional io.Input objects beyond "texts" (optional) + is_group_process: None (auto-detect), True (group), or False (individual) (optional) + is_output_list: True (list output) or False (single output) (optional, default True) + + Child classes must implement ONE of: + _process(cls, text, **kwargs) -> str (for single-item processing) + _group_process(cls, texts, **kwargs) -> list[str] (for group processing) + """ + + node_id = None + display_name = None + description = None + extra_inputs = [] + is_group_process = None # None = auto-detect, True/False = explicit + is_output_list = None # None = auto-detect based on processing mode + + @classmethod + def _detect_processing_mode(cls): + """Detect whether this node uses group or individual processing. + + Returns: + bool: True if group processing, False if individual processing + """ + # Explicit setting takes precedence + if cls.is_group_process is not None: + return cls.is_group_process + + # Check which method is overridden by looking at the defining class in MRO + base_class = TextProcessingNode + + # Find which class in MRO defines _process + process_definer = None + for klass in cls.__mro__: + if "_process" in klass.__dict__: + process_definer = klass + break + + # Find which class in MRO defines _group_process + group_definer = None + for klass in cls.__mro__: + if "_group_process" in klass.__dict__: + group_definer = klass + break + + # Check what was overridden (not defined in base class) + has_process = process_definer is not None and process_definer is not base_class + has_group = group_definer is not None and group_definer is not base_class + + if has_process and has_group: + raise ValueError( + f"{cls.__name__}: Cannot override both _process and _group_process. " + "Override only one, or set is_group_process explicitly." + ) + if not has_process and not has_group: + raise ValueError( + f"{cls.__name__}: Must override either _process or _group_process" + ) + + return has_group + + @classmethod + def define_schema(cls): + if cls.node_id is None: + raise NotImplementedError(f"{cls.__name__} must set node_id class variable") + + is_group = cls._detect_processing_mode() + + inputs = [ + io.String.Input( + "texts", + tooltip="List of texts to process." if is_group else "Text to process.", + ) + ] + inputs.extend(cls.extra_inputs) + + return io.Schema( + node_id=cls.node_id, + display_name=cls.display_name or cls.node_id, + category="dataset/text", + is_experimental=True, + is_input_list=is_group, # True for group, False for individual + inputs=inputs, + outputs=[ + io.String.Output( + display_name="texts", + is_output_list=cls.is_output_list, + tooltip="Processed texts", + ) + ], + ) + + @classmethod + def execute(cls, texts, **kwargs): + """Execute the node. Routes to _process or _group_process based on mode.""" + is_group = cls._detect_processing_mode() + + # Extract scalar values from lists for parameters + params = {} + for k, v in kwargs.items(): + if isinstance(v, list) and len(v) == 1: + params[k] = v[0] + else: + params[k] = v + + if is_group: + # Group processing: texts is list, call _group_process + result = cls._group_process(texts, **params) + else: + # Individual processing: texts is single item, call _process + result = cls._process(texts, **params) + + # Wrap result based on is_output_list + if cls.is_output_list: + # Result should already be a list (or will be for individual) + return io.NodeOutput(result if is_group else [result]) + else: + # Single output - wrap in list for NodeOutput + return io.NodeOutput([result]) + + @classmethod + def _process(cls, text, **kwargs): + """Override this method for single-item processing. + + Args: + text: str - Single text string + **kwargs: Additional parameters (already extracted from lists) + + Returns: + str - Processed text + """ + raise NotImplementedError(f"{cls.__name__} must implement _process method") + + @classmethod + def _group_process(cls, texts, **kwargs): + """Override this method for group processing. + + Args: + texts: list[str] - List of text strings + **kwargs: Additional parameters (already extracted from lists) + + Returns: + list[str] - Processed texts + """ + raise NotImplementedError( + f"{cls.__name__} must implement _group_process method" + ) + + +# ========== Image Transform Nodes ========== + + +class ResizeImagesToSameSizeNode(ImageProcessingNode): + node_id = "ResizeImagesToSameSize" + display_name = "Resize Images to Same Size" + description = "Resize all images to the same width and height." + extra_inputs = [ + io.Int.Input("width", default=512, min=1, max=8192, tooltip="Target width."), + io.Int.Input("height", default=512, min=1, max=8192, tooltip="Target height."), + io.Combo.Input( + "mode", + options=["stretch", "crop_center", "pad"], + default="stretch", + tooltip="Resize mode.", + ), + ] + + @classmethod + def _process(cls, image, width, height, mode): + img = tensor_to_pil(image) + + if mode == "stretch": + img = img.resize((width, height), Image.Resampling.LANCZOS) + elif mode == "crop_center": + left = max(0, (img.width - width) // 2) + top = max(0, (img.height - height) // 2) + right = min(img.width, left + width) + bottom = min(img.height, top + height) + img = img.crop((left, top, right, bottom)) + if img.width != width or img.height != height: + img = img.resize((width, height), Image.Resampling.LANCZOS) + elif mode == "pad": + img.thumbnail((width, height), Image.Resampling.LANCZOS) + new_img = Image.new("RGB", (width, height), (0, 0, 0)) + paste_x = (width - img.width) // 2 + paste_y = (height - img.height) // 2 + new_img.paste(img, (paste_x, paste_y)) + img = new_img + + return pil_to_tensor(img) + + +class ResizeImagesToPixelCountNode(ImageProcessingNode): + node_id = "ResizeImagesToPixelCount" + display_name = "Resize Images to Pixel Count" + description = "Resize images so that the total pixel count matches the specified number while preserving aspect ratio." + extra_inputs = [ + io.Int.Input( + "pixel_count", + default=512 * 512, + min=1, + max=8192 * 8192, + tooltip="Target pixel count.", + ), + io.Int.Input( + "steps", + default=64, + min=1, + max=128, + tooltip="The stepping for resize width/height.", + ), + ] + + @classmethod + def _process(cls, image, pixel_count, steps): + img = tensor_to_pil(image) + w, h = img.size + pixel_count_ratio = math.sqrt(pixel_count / (w * h)) + new_w = int(w * pixel_count_ratio / steps) * steps + new_h = int(h * pixel_count_ratio / steps) * steps + logging.info(f"Resizing from {w}x{h} to {new_w}x{new_h}") + img = img.resize((new_w, new_h), Image.Resampling.LANCZOS) + return pil_to_tensor(img) + + +class ResizeImagesByShorterEdgeNode(ImageProcessingNode): + node_id = "ResizeImagesByShorterEdge" + display_name = "Resize Images by Shorter Edge" + description = "Resize images so that the shorter edge matches the specified length while preserving aspect ratio." + extra_inputs = [ + io.Int.Input( + "shorter_edge", + default=512, + min=1, + max=8192, + tooltip="Target length for the shorter edge.", + ), + ] + + @classmethod + def _process(cls, image, shorter_edge): + img = tensor_to_pil(image) + w, h = img.size + if w < h: + new_w = shorter_edge + new_h = int(h * (shorter_edge / w)) + else: + new_h = shorter_edge + new_w = int(w * (shorter_edge / h)) + img = img.resize((new_w, new_h), Image.Resampling.LANCZOS) + return pil_to_tensor(img) + + +class ResizeImagesByLongerEdgeNode(ImageProcessingNode): + node_id = "ResizeImagesByLongerEdge" + display_name = "Resize Images by Longer Edge" + description = "Resize images so that the longer edge matches the specified length while preserving aspect ratio." + extra_inputs = [ + io.Int.Input( + "longer_edge", + default=1024, + min=1, + max=8192, + tooltip="Target length for the longer edge.", + ), + ] + + @classmethod + def _process(cls, image, longer_edge): + img = tensor_to_pil(image) + w, h = img.size + if w > h: + new_w = longer_edge + new_h = int(h * (longer_edge / w)) + else: + new_h = longer_edge + new_w = int(w * (longer_edge / h)) + img = img.resize((new_w, new_h), Image.Resampling.LANCZOS) + return pil_to_tensor(img) + + +class CenterCropImagesNode(ImageProcessingNode): + node_id = "CenterCropImages" + display_name = "Center Crop Images" + description = "Center crop all images to the specified dimensions." + extra_inputs = [ + io.Int.Input("width", default=512, min=1, max=8192, tooltip="Crop width."), + io.Int.Input("height", default=512, min=1, max=8192, tooltip="Crop height."), + ] + + @classmethod + def _process(cls, image, width, height): + img = tensor_to_pil(image) + left = max(0, (img.width - width) // 2) + top = max(0, (img.height - height) // 2) + right = min(img.width, left + width) + bottom = min(img.height, top + height) + img = img.crop((left, top, right, bottom)) + return pil_to_tensor(img) + + +class RandomCropImagesNode(ImageProcessingNode): + node_id = "RandomCropImages" + display_name = "Random Crop Images" + description = ( + "Randomly crop all images to the specified dimensions (for data augmentation)." + ) + extra_inputs = [ + io.Int.Input("width", default=512, min=1, max=8192, tooltip="Crop width."), + io.Int.Input("height", default=512, min=1, max=8192, tooltip="Crop height."), + io.Int.Input( + "seed", default=0, min=0, max=0xFFFFFFFFFFFFFFFF, tooltip="Random seed." + ), + ] + + @classmethod + def _process(cls, image, width, height, seed): + np.random.seed(seed % (2**32 - 1)) + img = tensor_to_pil(image) + max_left = max(0, img.width - width) + max_top = max(0, img.height - height) + left = np.random.randint(0, max_left + 1) if max_left > 0 else 0 + top = np.random.randint(0, max_top + 1) if max_top > 0 else 0 + right = min(img.width, left + width) + bottom = min(img.height, top + height) + img = img.crop((left, top, right, bottom)) + return pil_to_tensor(img) + + +class FlipImagesNode(ImageProcessingNode): + node_id = "FlipImages" + display_name = "Flip Images" + description = "Flip all images horizontally or vertically." + extra_inputs = [ + io.Combo.Input( + "direction", + options=["horizontal", "vertical"], + default="horizontal", + tooltip="Flip direction.", + ), + ] + + @classmethod + def _process(cls, image, direction): + img = tensor_to_pil(image) + if direction == "horizontal": + img = img.transpose(Image.FLIP_LEFT_RIGHT) + else: + img = img.transpose(Image.FLIP_TOP_BOTTOM) + return pil_to_tensor(img) + + +class NormalizeImagesNode(ImageProcessingNode): + node_id = "NormalizeImages" + display_name = "Normalize Images" + description = "Normalize images using mean and standard deviation." + extra_inputs = [ + io.Float.Input( + "mean", + default=0.5, + min=0.0, + max=1.0, + tooltip="Mean value for normalization.", + ), + io.Float.Input( + "std", + default=0.5, + min=0.001, + max=1.0, + tooltip="Standard deviation for normalization.", + ), + ] + + @classmethod + def _process(cls, image, mean, std): + return (image - mean) / std + + +class AdjustBrightnessNode(ImageProcessingNode): + node_id = "AdjustBrightness" + display_name = "Adjust Brightness" + description = "Adjust brightness of all images." + extra_inputs = [ + io.Float.Input( + "factor", + default=1.0, + min=0.0, + max=2.0, + tooltip="Brightness factor. 1.0 = no change, <1.0 = darker, >1.0 = brighter.", + ), + ] + + @classmethod + def _process(cls, image, factor): + return (image * factor).clamp(0.0, 1.0) + + +class AdjustContrastNode(ImageProcessingNode): + node_id = "AdjustContrast" + display_name = "Adjust Contrast" + description = "Adjust contrast of all images." + extra_inputs = [ + io.Float.Input( + "factor", + default=1.0, + min=0.0, + max=2.0, + tooltip="Contrast factor. 1.0 = no change, <1.0 = less contrast, >1.0 = more contrast.", + ), + ] + + @classmethod + def _process(cls, image, factor): + return ((image - 0.5) * factor + 0.5).clamp(0.0, 1.0) + + +class ShuffleDatasetNode(ImageProcessingNode): + node_id = "ShuffleDataset" + display_name = "Shuffle Image Dataset" + description = "Randomly shuffle the order of images in the dataset." + is_group_process = True # Requires full list to shuffle + extra_inputs = [ + io.Int.Input( + "seed", default=0, min=0, max=0xFFFFFFFFFFFFFFFF, tooltip="Random seed." + ), + ] + + @classmethod + def _group_process(cls, images, seed): + np.random.seed(seed % (2**32 - 1)) + indices = np.random.permutation(len(images)) + return [images[i] for i in indices] + + +class ShuffleImageTextDatasetNode(io.ComfyNode): + """Special node that shuffles both images and texts together.""" + + @classmethod + def define_schema(cls): + return io.Schema( + node_id="ShuffleImageTextDataset", + display_name="Shuffle Image-Text Dataset", + category="dataset/image", + is_experimental=True, + is_input_list=True, + inputs=[ + io.Image.Input("images", tooltip="List of images to shuffle."), + io.String.Input("texts", tooltip="List of texts to shuffle."), + io.Int.Input( + "seed", + default=0, + min=0, + max=0xFFFFFFFFFFFFFFFF, + tooltip="Random seed.", + ), + ], + outputs=[ + io.Image.Output( + display_name="images", + is_output_list=True, + tooltip="Shuffled images", + ), + io.String.Output( + display_name="texts", is_output_list=True, tooltip="Shuffled texts" + ), + ], + ) + + @classmethod + def execute(cls, images, texts, seed): + seed = seed[0] # Extract scalar + np.random.seed(seed % (2**32 - 1)) + indices = np.random.permutation(len(images)) + shuffled_images = [images[i] for i in indices] + shuffled_texts = [texts[i] for i in indices] + return io.NodeOutput(shuffled_images, shuffled_texts) + + +# ========== Text Transform Nodes ========== + + +class TextToLowercaseNode(TextProcessingNode): + node_id = "TextToLowercase" + display_name = "Text to Lowercase" + description = "Convert all texts to lowercase." + + @classmethod + def _process(cls, text): + return text.lower() + + +class TextToUppercaseNode(TextProcessingNode): + node_id = "TextToUppercase" + display_name = "Text to Uppercase" + description = "Convert all texts to uppercase." + + @classmethod + def _process(cls, text): + return text.upper() + + +class TruncateTextNode(TextProcessingNode): + node_id = "TruncateText" + display_name = "Truncate Text" + description = "Truncate all texts to a maximum length." + extra_inputs = [ + io.Int.Input( + "max_length", default=77, min=1, max=10000, tooltip="Maximum text length." + ), + ] + + @classmethod + def _process(cls, text, max_length): + return text[:max_length] + + +class AddTextPrefixNode(TextProcessingNode): + node_id = "AddTextPrefix" + display_name = "Add Text Prefix" + description = "Add a prefix to all texts." + extra_inputs = [ + io.String.Input("prefix", default="", tooltip="Prefix to add."), + ] + + @classmethod + def _process(cls, text, prefix): + return prefix + text + + +class AddTextSuffixNode(TextProcessingNode): + node_id = "AddTextSuffix" + display_name = "Add Text Suffix" + description = "Add a suffix to all texts." + extra_inputs = [ + io.String.Input("suffix", default="", tooltip="Suffix to add."), + ] + + @classmethod + def _process(cls, text, suffix): + return text + suffix + + +class ReplaceTextNode(TextProcessingNode): + node_id = "ReplaceText" + display_name = "Replace Text" + description = "Replace text in all texts." + extra_inputs = [ + io.String.Input("find", default="", tooltip="Text to find."), + io.String.Input("replace", default="", tooltip="Text to replace with."), + ] + + @classmethod + def _process(cls, text, find, replace): + return text.replace(find, replace) + + +class StripWhitespaceNode(TextProcessingNode): + node_id = "StripWhitespace" + display_name = "Strip Whitespace" + description = "Strip leading and trailing whitespace from all texts." + + @classmethod + def _process(cls, text): + return text.strip() + + +# ========== Group Processing Example Nodes ========== + + +class ImageDeduplicationNode(ImageProcessingNode): + """Remove duplicate or very similar images from the dataset using perceptual hashing.""" + + node_id = "ImageDeduplication" + display_name = "Image Deduplication" + description = "Remove duplicate or very similar images from the dataset." + is_group_process = True # Requires full list to compare images + extra_inputs = [ + io.Float.Input( + "similarity_threshold", + default=0.95, + min=0.0, + max=1.0, + tooltip="Similarity threshold (0-1). Higher means more similar. Images above this threshold are considered duplicates.", + ), + ] + + @classmethod + def _group_process(cls, images, similarity_threshold): + """Remove duplicate images using perceptual hashing.""" + if len(images) == 0: + return [] + + # Compute simple perceptual hash for each image + def compute_hash(img_tensor): + """Compute a simple perceptual hash by resizing to 8x8 and comparing to average.""" + img = tensor_to_pil(img_tensor) + # Resize to 8x8 + img_small = img.resize((8, 8), Image.Resampling.LANCZOS).convert("L") + # Get pixels + pixels = list(img_small.getdata()) + # Compute average + avg = sum(pixels) / len(pixels) + # Create hash (1 if above average, 0 otherwise) + hash_bits = "".join("1" if p > avg else "0" for p in pixels) + return hash_bits + + def hamming_distance(hash1, hash2): + """Compute Hamming distance between two hash strings.""" + return sum(c1 != c2 for c1, c2 in zip(hash1, hash2)) + + # Compute hashes for all images + hashes = [compute_hash(img) for img in images] + + # Find duplicates + keep_indices = [] + for i in range(len(images)): + is_duplicate = False + for j in keep_indices: + # Compare hashes + distance = hamming_distance(hashes[i], hashes[j]) + similarity = 1.0 - (distance / 64.0) # 64 bits total + if similarity >= similarity_threshold: + is_duplicate = True + logging.info( + f"Image {i} is similar to image {j} (similarity: {similarity:.3f}), skipping" + ) + break + + if not is_duplicate: + keep_indices.append(i) + + # Return only unique images + unique_images = [images[i] for i in keep_indices] + logging.info( + f"Deduplication: kept {len(unique_images)} out of {len(images)} images" + ) + return unique_images + + +class ImageGridNode(ImageProcessingNode): + """Combine multiple images into a single grid/collage.""" + + node_id = "ImageGrid" + display_name = "Image Grid" + description = "Arrange multiple images into a grid layout." + is_group_process = True # Requires full list to create grid + is_output_list = False # Outputs single grid image + extra_inputs = [ + io.Int.Input( + "columns", + default=4, + min=1, + max=20, + tooltip="Number of columns in the grid.", + ), + io.Int.Input( + "cell_width", + default=256, + min=32, + max=2048, + tooltip="Width of each cell in the grid.", + ), + io.Int.Input( + "cell_height", + default=256, + min=32, + max=2048, + tooltip="Height of each cell in the grid.", + ), + io.Int.Input( + "padding", default=4, min=0, max=50, tooltip="Padding between images." + ), + ] + + @classmethod + def _group_process(cls, images, columns, cell_width, cell_height, padding): + """Arrange images into a grid.""" + if len(images) == 0: + raise ValueError("Cannot create grid from empty image list") + + # Calculate grid dimensions + num_images = len(images) + rows = (num_images + columns - 1) // columns # Ceiling division + + # Calculate total grid size + grid_width = columns * cell_width + (columns - 1) * padding + grid_height = rows * cell_height + (rows - 1) * padding + + # Create blank grid + grid = Image.new("RGB", (grid_width, grid_height), (0, 0, 0)) + + # Place images + for idx, img_tensor in enumerate(images): + row = idx // columns + col = idx % columns + + # Convert to PIL and resize to cell size + img = tensor_to_pil(img_tensor) + img = img.resize((cell_width, cell_height), Image.Resampling.LANCZOS) + + # Calculate position + x = col * (cell_width + padding) + y = row * (cell_height + padding) + + # Paste into grid + grid.paste(img, (x, y)) + + logging.info( + f"Created {columns}x{rows} grid with {num_images} images ({grid_width}x{grid_height})" + ) + return pil_to_tensor(grid) + + +class MergeImageListsNode(ImageProcessingNode): + """Merge multiple image lists into a single list.""" + + node_id = "MergeImageLists" + display_name = "Merge Image Lists" + description = "Concatenate multiple image lists into one." + is_group_process = True # Receives images as list + + @classmethod + def _group_process(cls, images): + """Simply return the images list (already merged by input handling).""" + # When multiple list inputs are connected, they're concatenated + # For now, this is a simple pass-through + logging.info(f"Merged image list contains {len(images)} images") + return images + + +class MergeTextListsNode(TextProcessingNode): + """Merge multiple text lists into a single list.""" + + node_id = "MergeTextLists" + display_name = "Merge Text Lists" + description = "Concatenate multiple text lists into one." + is_group_process = True # Receives texts as list + + @classmethod + def _group_process(cls, texts): + """Simply return the texts list (already merged by input handling).""" + # When multiple list inputs are connected, they're concatenated + # For now, this is a simple pass-through + logging.info(f"Merged text list contains {len(texts)} texts") + return texts + + +# ========== Training Dataset Nodes ========== + + +class MakeTrainingDataset(io.ComfyNode): + """Encode images with VAE and texts with CLIP to create a training dataset.""" + + @classmethod + def define_schema(cls): + return io.Schema( + node_id="MakeTrainingDataset", + display_name="Make Training Dataset", + category="dataset", + is_experimental=True, + is_input_list=True, # images and texts as lists + inputs=[ + io.Image.Input("images", tooltip="List of images to encode."), + io.Vae.Input( + "vae", tooltip="VAE model for encoding images to latents." + ), + io.Clip.Input( + "clip", tooltip="CLIP model for encoding text to conditioning." + ), + io.String.Input( + "texts", + optional=True, + tooltip="List of text captions. Can be length n (matching images), 1 (repeated for all), or omitted (uses empty string).", + ), + ], + outputs=[ + io.Latent.Output( + display_name="latents", + is_output_list=True, + tooltip="List of latent dicts", + ), + io.Conditioning.Output( + display_name="conditioning", + is_output_list=True, + tooltip="List of conditioning lists", + ), + ], + ) + + @classmethod + def execute(cls, images, vae, clip, texts=None): + # Extract scalars (vae and clip are single values wrapped in lists) + vae = vae[0] + clip = clip[0] + + # Handle text list + num_images = len(images) + + if texts is None or len(texts) == 0: + # Treat as [""] for unconditional training + texts = [""] + + if len(texts) == 1 and num_images > 1: + # Repeat single text for all images + texts = texts * num_images + elif len(texts) != num_images: + raise ValueError( + f"Number of texts ({len(texts)}) does not match number of images ({num_images}). " + f"Text list should have length {num_images}, 1, or 0." + ) + + # Encode images with VAE + logging.info(f"Encoding {num_images} images with VAE...") + latents_list = [] # list[{"samples": tensor}] + for img_tensor in images: + # img_tensor is [1, H, W, 3] + latent_tensor = vae.encode(img_tensor[:, :, :, :3]) + latents_list.append({"samples": latent_tensor}) + + # Encode texts with CLIP + logging.info(f"Encoding {len(texts)} texts with CLIP...") + conditioning_list = [] # list[list[cond]] + for text in texts: + if text == "": + cond = clip.encode_from_tokens_scheduled(clip.tokenize("")) + else: + tokens = clip.tokenize(text) + cond = clip.encode_from_tokens_scheduled(tokens) + conditioning_list.append(cond) + + logging.info( + f"Created dataset with {len(latents_list)} latents and {len(conditioning_list)} conditioning." + ) + return io.NodeOutput(latents_list, conditioning_list) + + +class SaveTrainingDataset(io.ComfyNode): + """Save encoded training dataset (latents + conditioning) to disk.""" + + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SaveTrainingDataset", + display_name="Save Training Dataset", + category="dataset", + is_experimental=True, + is_output_node=True, + is_input_list=True, # Receive lists + inputs=[ + io.Latent.Input( + "latents", + tooltip="List of latent dicts from MakeTrainingDataset.", + ), + io.Conditioning.Input( + "conditioning", + tooltip="List of conditioning lists from MakeTrainingDataset.", + ), + io.String.Input( + "folder_name", + default="training_dataset", + tooltip="Name of folder to save dataset (inside output directory).", + ), + io.Int.Input( + "shard_size", + default=1000, + min=1, + max=100000, + tooltip="Number of samples per shard file.", + ), + ], + outputs=[], + ) + + @classmethod + def execute(cls, latents, conditioning, folder_name, shard_size): + # Extract scalars + folder_name = folder_name[0] + shard_size = shard_size[0] + + # latents: list[{"samples": tensor}] + # conditioning: list[list[cond]] + + # Validate lengths match + if len(latents) != len(conditioning): + raise ValueError( + f"Number of latents ({len(latents)}) does not match number of conditions ({len(conditioning)}). " + f"Something went wrong in dataset preparation." + ) + + # Create output directory + output_dir = os.path.join(folder_paths.get_output_directory(), folder_name) + os.makedirs(output_dir, exist_ok=True) + + # Prepare data pairs + num_samples = len(latents) + num_shards = (num_samples + shard_size - 1) // shard_size # Ceiling division + + logging.info( + f"Saving {num_samples} samples to {num_shards} shards in {output_dir}..." + ) + + # Save data in shards + for shard_idx in range(num_shards): + start_idx = shard_idx * shard_size + end_idx = min(start_idx + shard_size, num_samples) + + # Get shard data (list of latent dicts and conditioning lists) + shard_data = { + "latents": latents[start_idx:end_idx], + "conditioning": conditioning[start_idx:end_idx], + } + + # Save shard + shard_filename = f"shard_{shard_idx:04d}.pkl" + shard_path = os.path.join(output_dir, shard_filename) + + with open(shard_path, "wb") as f: + torch.save(shard_data, f) + + logging.info( + f"Saved shard {shard_idx + 1}/{num_shards}: {shard_filename} ({end_idx - start_idx} samples)" + ) + + # Save metadata + metadata = { + "num_samples": num_samples, + "num_shards": num_shards, + "shard_size": shard_size, + } + metadata_path = os.path.join(output_dir, "metadata.json") + with open(metadata_path, "w") as f: + json.dump(metadata, f, indent=2) + + logging.info(f"Successfully saved {num_samples} samples to {output_dir}.") + return io.NodeOutput() + + +class LoadTrainingDataset(io.ComfyNode): + """Load encoded training dataset from disk.""" + + @classmethod + def define_schema(cls): + return io.Schema( + node_id="LoadTrainingDataset", + display_name="Load Training Dataset", + category="dataset", + is_experimental=True, + inputs=[ + io.String.Input( + "folder_name", + default="training_dataset", + tooltip="Name of folder containing the saved dataset (inside output directory).", + ), + ], + outputs=[ + io.Latent.Output( + display_name="latents", + is_output_list=True, + tooltip="List of latent dicts", + ), + io.Conditioning.Output( + display_name="conditioning", + is_output_list=True, + tooltip="List of conditioning lists", + ), + ], + ) + + @classmethod + def execute(cls, folder_name): + # Get dataset directory + dataset_dir = os.path.join(folder_paths.get_output_directory(), folder_name) + + if not os.path.exists(dataset_dir): + raise ValueError(f"Dataset directory not found: {dataset_dir}") + + # Find all shard files + shard_files = sorted( + [ + f + for f in os.listdir(dataset_dir) + if f.startswith("shard_") and f.endswith(".pkl") + ] + ) + + if not shard_files: + raise ValueError(f"No shard files found in {dataset_dir}") + + logging.info(f"Loading {len(shard_files)} shards from {dataset_dir}...") + + # Load all shards + all_latents = [] # list[{"samples": tensor}] + all_conditioning = [] # list[list[cond]] + + for shard_file in shard_files: + shard_path = os.path.join(dataset_dir, shard_file) + + with open(shard_path, "rb") as f: + shard_data = torch.load(f) + + all_latents.extend(shard_data["latents"]) + all_conditioning.extend(shard_data["conditioning"]) + + logging.info(f"Loaded {shard_file}: {len(shard_data['latents'])} samples") + + logging.info( + f"Successfully loaded {len(all_latents)} samples from {dataset_dir}." + ) + return io.NodeOutput(all_latents, all_conditioning) + + +# ========== Extension Setup ========== + + +class DatasetExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + # Data loading/saving nodes + LoadImageDataSetFromFolderNode, + LoadImageTextDataSetFromFolderNode, + SaveImageDataSetToFolderNode, + SaveImageTextDataSetToFolderNode, + # Image transform nodes + ResizeImagesToSameSizeNode, + ResizeImagesToPixelCountNode, + ResizeImagesByShorterEdgeNode, + ResizeImagesByLongerEdgeNode, + CenterCropImagesNode, + RandomCropImagesNode, + FlipImagesNode, + NormalizeImagesNode, + AdjustBrightnessNode, + AdjustContrastNode, + ShuffleDatasetNode, + ShuffleImageTextDatasetNode, + # Text transform nodes + TextToLowercaseNode, + TextToUppercaseNode, + TruncateTextNode, + AddTextPrefixNode, + AddTextSuffixNode, + ReplaceTextNode, + StripWhitespaceNode, + # Group processing examples + ImageDeduplicationNode, + ImageGridNode, + MergeImageListsNode, + MergeTextListsNode, + # Training dataset nodes + MakeTrainingDataset, + SaveTrainingDataset, + LoadTrainingDataset, + ] + + +async def comfy_entrypoint() -> DatasetExtension: + return DatasetExtension() diff --git a/comfy_extras/nodes_train.py b/comfy_extras/nodes_train.py index 9e6ec6780..cb24ab709 100644 --- a/comfy_extras/nodes_train.py +++ b/comfy_extras/nodes_train.py @@ -1,15 +1,13 @@ -import datetime -import json import logging import os import numpy as np import safetensors import torch -from PIL import Image, ImageDraw, ImageFont -from PIL.PngImagePlugin import PngInfo import torch.utils.checkpoint -import tqdm +from tqdm.auto import trange +from PIL import Image, ImageDraw, ImageFont +from typing_extensions import override import comfy.samplers import comfy.sd @@ -18,9 +16,9 @@ import comfy.model_management import comfy_extras.nodes_custom_sampler import folder_paths import node_helpers -from comfy.cli_args import args -from comfy.comfy_types.node_typing import IO from comfy.weight_adapter import adapters, adapter_maps +from comfy_api.latest import ComfyExtension, io, ui +from comfy.utils import ProgressBar def make_batch_extra_option_dict(d, indicies, full_size=None): @@ -56,7 +54,18 @@ def process_cond_list(d, prefix=""): class TrainSampler(comfy.samplers.Sampler): - def __init__(self, loss_fn, optimizer, loss_callback=None, batch_size=1, grad_acc=1, total_steps=1, seed=0, training_dtype=torch.bfloat16): + def __init__( + self, + loss_fn, + optimizer, + loss_callback=None, + batch_size=1, + grad_acc=1, + total_steps=1, + seed=0, + training_dtype=torch.bfloat16, + real_dataset=None, + ): self.loss_fn = loss_fn self.optimizer = optimizer self.loss_callback = loss_callback @@ -65,54 +74,138 @@ class TrainSampler(comfy.samplers.Sampler): self.grad_acc = grad_acc self.seed = seed self.training_dtype = training_dtype + self.real_dataset: list[torch.Tensor] | None = real_dataset - def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False): + def fwd_bwd( + self, + model_wrap, + batch_sigmas, + batch_noise, + batch_latent, + cond, + indicies, + extra_args, + dataset_size, + bwd=True, + ): + xt = model_wrap.inner_model.model_sampling.noise_scaling( + batch_sigmas, batch_noise, batch_latent, False + ) + x0 = model_wrap.inner_model.model_sampling.noise_scaling( + torch.zeros_like(batch_sigmas), + torch.zeros_like(batch_noise), + batch_latent, + False, + ) + + model_wrap.conds["positive"] = [cond[i] for i in indicies] + batch_extra_args = make_batch_extra_option_dict( + extra_args, indicies, full_size=dataset_size + ) + + with torch.autocast(xt.device.type, dtype=self.training_dtype): + x0_pred = model_wrap( + xt.requires_grad_(True), + batch_sigmas.requires_grad_(True), + **batch_extra_args, + ) + loss = self.loss_fn(x0_pred, x0) + if bwd: + bwd_loss = loss / self.grad_acc + bwd_loss.backward() + return loss + + def sample( + self, + model_wrap, + sigmas, + extra_args, + callback, + noise, + latent_image=None, + denoise_mask=None, + disable_pbar=False, + ): model_wrap.conds = process_cond_list(model_wrap.conds) cond = model_wrap.conds["positive"] dataset_size = sigmas.size(0) torch.cuda.empty_cache() - for i in (pbar:=tqdm.trange(self.total_steps, desc="Training LoRA", smoothing=0.01, disable=not comfy.utils.PROGRESS_BAR_ENABLED)): - noisegen = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(self.seed + i * 1000) - indicies = torch.randperm(dataset_size)[:self.batch_size].tolist() - - batch_latent = torch.stack([latent_image[i] for i in indicies]) - batch_noise = noisegen.generate_noise({"samples": batch_latent}).to(batch_latent.device) - batch_sigmas = [ - model_wrap.inner_model.model_sampling.percent_to_sigma( - torch.rand((1,)).item() - ) for _ in range(min(self.batch_size, dataset_size)) - ] - batch_sigmas = torch.tensor(batch_sigmas).to(batch_latent.device) - - xt = model_wrap.inner_model.model_sampling.noise_scaling( - batch_sigmas, - batch_noise, - batch_latent, - False + ui_pbar = ProgressBar(self.total_steps) + for i in ( + pbar := trange( + self.total_steps, + desc="Training LoRA", + smoothing=0.01, + disable=not comfy.utils.PROGRESS_BAR_ENABLED, ) - x0 = model_wrap.inner_model.model_sampling.noise_scaling( - torch.zeros_like(batch_sigmas), - torch.zeros_like(batch_noise), - batch_latent, - False + ): + noisegen = comfy_extras.nodes_custom_sampler.Noise_RandomNoise( + self.seed + i * 1000 ) + indicies = torch.randperm(dataset_size)[: self.batch_size].tolist() - model_wrap.conds["positive"] = [ - cond[i] for i in indicies - ] - batch_extra_args = make_batch_extra_option_dict(extra_args, indicies, full_size=dataset_size) + if self.real_dataset is None: + batch_latent = torch.stack([latent_image[i] for i in indicies]) + batch_noise = noisegen.generate_noise({"samples": batch_latent}).to( + batch_latent.device + ) + batch_sigmas = [ + model_wrap.inner_model.model_sampling.percent_to_sigma( + torch.rand((1,)).item() + ) + for _ in range(min(self.batch_size, dataset_size)) + ] + batch_sigmas = torch.tensor(batch_sigmas).to(batch_latent.device) - with torch.autocast(xt.device.type, dtype=self.training_dtype): - x0_pred = model_wrap(xt, batch_sigmas, **batch_extra_args) - loss = self.loss_fn(x0_pred, x0) - loss.backward() - if self.loss_callback: - self.loss_callback(loss.item()) - pbar.set_postfix({"loss": f"{loss.item():.4f}"}) + loss = self.fwd_bwd( + model_wrap, + batch_sigmas, + batch_noise, + batch_latent, + cond, + indicies, + extra_args, + dataset_size, + bwd=True, + ) + if self.loss_callback: + self.loss_callback(loss.item()) + pbar.set_postfix({"loss": f"{loss.item():.4f}"}) + else: + total_loss = 0 + for index in indicies: + single_latent = self.real_dataset[index].to(latent_image) + batch_noise = noisegen.generate_noise( + {"samples": single_latent} + ).to(single_latent.device) + batch_sigmas = ( + model_wrap.inner_model.model_sampling.percent_to_sigma( + torch.rand((1,)).item() + ) + ) + batch_sigmas = torch.tensor([batch_sigmas]).to(single_latent.device) + loss = self.fwd_bwd( + model_wrap, + batch_sigmas, + batch_noise, + single_latent, + cond, + [index], + extra_args, + dataset_size, + bwd=False, + ) + total_loss += loss + total_loss = total_loss / self.grad_acc / len(indicies) + total_loss.backward() + if self.loss_callback: + self.loss_callback(total_loss.item()) + pbar.set_postfix({"loss": f"{total_loss.item():.4f}"}) - if (i+1) % self.grad_acc == 0: + if (i + 1) % self.grad_acc == 0: self.optimizer.step() self.optimizer.zero_grad() + ui_pbar.update(1) torch.cuda.empty_cache() return torch.zeros_like(latent_image) @@ -134,233 +227,6 @@ class BiasDiff(torch.nn.Module): return self.passive_memory_usage() -def load_and_process_images(image_files, input_dir, resize_method="None", w=None, h=None): - """Utility function to load and process a list of images. - - Args: - image_files: List of image filenames - input_dir: Base directory containing the images - resize_method: How to handle images of different sizes ("None", "Stretch", "Crop", "Pad") - - Returns: - torch.Tensor: Batch of processed images - """ - if not image_files: - raise ValueError("No valid images found in input") - - output_images = [] - - for file in image_files: - image_path = os.path.join(input_dir, file) - img = node_helpers.pillow(Image.open, image_path) - - if img.mode == "I": - img = img.point(lambda i: i * (1 / 255)) - img = img.convert("RGB") - - if w is None and h is None: - w, h = img.size[0], img.size[1] - - # Resize image to first image - if img.size[0] != w or img.size[1] != h: - if resize_method == "Stretch": - img = img.resize((w, h), Image.Resampling.LANCZOS) - elif resize_method == "Crop": - img = img.crop((0, 0, w, h)) - elif resize_method == "Pad": - img = img.resize((w, h), Image.Resampling.LANCZOS) - elif resize_method == "None": - raise ValueError( - "Your input image size does not match the first image in the dataset. Either select a valid resize method or use the same size for all images." - ) - - img_array = np.array(img).astype(np.float32) / 255.0 - img_tensor = torch.from_numpy(img_array)[None,] - output_images.append(img_tensor) - - return torch.cat(output_images, dim=0) - - -class LoadImageSetNode: - @classmethod - def INPUT_TYPES(s): - return { - "required": { - "images": ( - [ - f - for f in os.listdir(folder_paths.get_input_directory()) - if f.endswith((".png", ".jpg", ".jpeg", ".webp", ".bmp", ".gif", ".jpe", ".apng", ".tif", ".tiff")) - ], - {"image_upload": True, "allow_batch": True}, - ) - }, - "optional": { - "resize_method": ( - ["None", "Stretch", "Crop", "Pad"], - {"default": "None"}, - ), - }, - } - - INPUT_IS_LIST = True - RETURN_TYPES = ("IMAGE",) - FUNCTION = "load_images" - CATEGORY = "loaders" - EXPERIMENTAL = True - DESCRIPTION = "Loads a batch of images from a directory for training." - - @classmethod - def VALIDATE_INPUTS(s, images, resize_method): - filenames = images[0] if isinstance(images[0], list) else images - - for image in filenames: - if not folder_paths.exists_annotated_filepath(image): - return "Invalid image file: {}".format(image) - return True - - def load_images(self, input_files, resize_method): - input_dir = folder_paths.get_input_directory() - valid_extensions = [".png", ".jpg", ".jpeg", ".webp", ".bmp", ".gif", ".jpe", ".apng", ".tif", ".tiff"] - image_files = [ - f - for f in input_files - if any(f.lower().endswith(ext) for ext in valid_extensions) - ] - output_tensor = load_and_process_images(image_files, input_dir, resize_method) - return (output_tensor,) - - -class LoadImageSetFromFolderNode: - @classmethod - def INPUT_TYPES(s): - return { - "required": { - "folder": (folder_paths.get_input_subfolders(), {"tooltip": "The folder to load images from."}) - }, - "optional": { - "resize_method": ( - ["None", "Stretch", "Crop", "Pad"], - {"default": "None"}, - ), - }, - } - - RETURN_TYPES = ("IMAGE",) - FUNCTION = "load_images" - CATEGORY = "loaders" - EXPERIMENTAL = True - DESCRIPTION = "Loads a batch of images from a directory for training." - - def load_images(self, folder, resize_method): - sub_input_dir = os.path.join(folder_paths.get_input_directory(), folder) - valid_extensions = [".png", ".jpg", ".jpeg", ".webp"] - image_files = [ - f - for f in os.listdir(sub_input_dir) - if any(f.lower().endswith(ext) for ext in valid_extensions) - ] - output_tensor = load_and_process_images(image_files, sub_input_dir, resize_method) - return (output_tensor,) - - -class LoadImageTextSetFromFolderNode: - @classmethod - def INPUT_TYPES(s): - return { - "required": { - "folder": (folder_paths.get_input_subfolders(), {"tooltip": "The folder to load images from."}), - "clip": (IO.CLIP, {"tooltip": "The CLIP model used for encoding the text."}), - }, - "optional": { - "resize_method": ( - ["None", "Stretch", "Crop", "Pad"], - {"default": "None"}, - ), - "width": ( - IO.INT, - { - "default": -1, - "min": -1, - "max": 10000, - "step": 1, - "tooltip": "The width to resize the images to. -1 means use the original width.", - }, - ), - "height": ( - IO.INT, - { - "default": -1, - "min": -1, - "max": 10000, - "step": 1, - "tooltip": "The height to resize the images to. -1 means use the original height.", - }, - ) - }, - } - - RETURN_TYPES = ("IMAGE", IO.CONDITIONING,) - FUNCTION = "load_images" - CATEGORY = "loaders" - EXPERIMENTAL = True - DESCRIPTION = "Loads a batch of images and caption from a directory for training." - - def load_images(self, folder, clip, resize_method, width=None, height=None): - if clip is None: - raise RuntimeError("ERROR: clip input is invalid: None\n\nIf the clip is from a checkpoint loader node your checkpoint does not contain a valid clip or text encoder model.") - - logging.info(f"Loading images from folder: {folder}") - - sub_input_dir = os.path.join(folder_paths.get_input_directory(), folder) - valid_extensions = [".png", ".jpg", ".jpeg", ".webp"] - - image_files = [] - for item in os.listdir(sub_input_dir): - path = os.path.join(sub_input_dir, item) - if any(item.lower().endswith(ext) for ext in valid_extensions): - image_files.append(path) - elif os.path.isdir(path): - # Support kohya-ss/sd-scripts folder structure - repeat = 1 - if item.split("_")[0].isdigit(): - repeat = int(item.split("_")[0]) - image_files.extend([ - os.path.join(path, f) for f in os.listdir(path) if any(f.lower().endswith(ext) for ext in valid_extensions) - ] * repeat) - - caption_file_path = [ - f.replace(os.path.splitext(f)[1], ".txt") - for f in image_files - ] - captions = [] - for caption_file in caption_file_path: - caption_path = os.path.join(sub_input_dir, caption_file) - if os.path.exists(caption_path): - with open(caption_path, "r", encoding="utf-8") as f: - caption = f.read().strip() - captions.append(caption) - else: - captions.append("") - - width = width if width != -1 else None - height = height if height != -1 else None - output_tensor = load_and_process_images(image_files, sub_input_dir, resize_method, width, height) - - logging.info(f"Loaded {len(output_tensor)} images from {sub_input_dir}.") - - logging.info(f"Encoding captions from {sub_input_dir}.") - conditions = [] - empty_cond = clip.encode_from_tokens_scheduled(clip.tokenize("")) - for text in captions: - if text == "": - conditions.append(empty_cond) - tokens = clip.tokenize(text) - conditions.extend(clip.encode_from_tokens_scheduled(tokens)) - logging.info(f"Encoded {len(conditions)} captions from {sub_input_dir}.") - return (output_tensor, conditions) - - def draw_loss_graph(loss_map, steps): width, height = 500, 300 img = Image.new("RGB", (width, height), "white") @@ -379,10 +245,14 @@ def draw_loss_graph(loss_map, steps): return img -def find_all_highest_child_module_with_forward(model: torch.nn.Module, result = None, name = None): +def find_all_highest_child_module_with_forward( + model: torch.nn.Module, result=None, name=None +): if result is None: result = [] - elif hasattr(model, "forward") and not isinstance(model, (torch.nn.ModuleList, torch.nn.Sequential, torch.nn.ModuleDict)): + elif hasattr(model, "forward") and not isinstance( + model, (torch.nn.ModuleList, torch.nn.Sequential, torch.nn.ModuleDict) + ): result.append(model) logging.debug(f"Found module with forward: {name} ({model.__class__.__name__})") return result @@ -396,12 +266,13 @@ def patch(m): if not hasattr(m, "forward"): return org_forward = m.forward + def fwd(args, kwargs): return org_forward(*args, **kwargs) + def checkpointing_fwd(*args, **kwargs): - return torch.utils.checkpoint.checkpoint( - fwd, args, kwargs, use_reentrant=False - ) + return torch.utils.checkpoint.checkpoint(fwd, args, kwargs, use_reentrant=False) + m.org_forward = org_forward m.forward = checkpointing_fwd @@ -412,130 +283,126 @@ def unpatch(m): del m.org_forward -class TrainLoraNode: +class TrainLoraNode(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return { - "required": { - "model": (IO.MODEL, {"tooltip": "The model to train the LoRA on."}), - "latents": ( - "LATENT", - { - "tooltip": "The Latents to use for training, serve as dataset/input of the model." - }, + def define_schema(cls): + return io.Schema( + node_id="TrainLoraNode", + display_name="Train LoRA", + category="training", + is_experimental=True, + is_input_list=True, # All inputs become lists + inputs=[ + io.Model.Input("model", tooltip="The model to train the LoRA on."), + io.Latent.Input( + "latents", + tooltip="The Latents to use for training, serve as dataset/input of the model.", ), - "positive": ( - IO.CONDITIONING, - {"tooltip": "The positive conditioning to use for training."}, + io.Conditioning.Input( + "positive", tooltip="The positive conditioning to use for training." ), - "batch_size": ( - IO.INT, - { - "default": 1, - "min": 1, - "max": 10000, - "step": 1, - "tooltip": "The batch size to use for training.", - }, + io.Int.Input( + "batch_size", + default=1, + min=1, + max=10000, + tooltip="The batch size to use for training.", ), - "grad_accumulation_steps": ( - IO.INT, - { - "default": 1, - "min": 1, - "max": 1024, - "step": 1, - "tooltip": "The number of gradient accumulation steps to use for training.", - } + io.Int.Input( + "grad_accumulation_steps", + default=1, + min=1, + max=1024, + tooltip="The number of gradient accumulation steps to use for training.", ), - "steps": ( - IO.INT, - { - "default": 16, - "min": 1, - "max": 100000, - "tooltip": "The number of steps to train the LoRA for.", - }, + io.Int.Input( + "steps", + default=16, + min=1, + max=100000, + tooltip="The number of steps to train the LoRA for.", ), - "learning_rate": ( - IO.FLOAT, - { - "default": 0.0005, - "min": 0.0000001, - "max": 1.0, - "step": 0.000001, - "tooltip": "The learning rate to use for training.", - }, + io.Float.Input( + "learning_rate", + default=0.0005, + min=0.0000001, + max=1.0, + step=0.0000001, + tooltip="The learning rate to use for training.", ), - "rank": ( - IO.INT, - { - "default": 8, - "min": 1, - "max": 128, - "tooltip": "The rank of the LoRA layers.", - }, + io.Int.Input( + "rank", + default=8, + min=1, + max=128, + tooltip="The rank of the LoRA layers.", ), - "optimizer": ( - ["AdamW", "Adam", "SGD", "RMSprop"], - { - "default": "AdamW", - "tooltip": "The optimizer to use for training.", - }, + io.Combo.Input( + "optimizer", + options=["AdamW", "Adam", "SGD", "RMSprop"], + default="AdamW", + tooltip="The optimizer to use for training.", ), - "loss_function": ( - ["MSE", "L1", "Huber", "SmoothL1"], - { - "default": "MSE", - "tooltip": "The loss function to use for training.", - }, + io.Combo.Input( + "loss_function", + options=["MSE", "L1", "Huber", "SmoothL1"], + default="MSE", + tooltip="The loss function to use for training.", ), - "seed": ( - IO.INT, - { - "default": 0, - "min": 0, - "max": 0xFFFFFFFFFFFFFFFF, - "tooltip": "The seed to use for training (used in generator for LoRA weight initialization and noise sampling)", - }, + io.Int.Input( + "seed", + default=0, + min=0, + max=0xFFFFFFFFFFFFFFFF, + tooltip="The seed to use for training (used in generator for LoRA weight initialization and noise sampling)", ), - "training_dtype": ( - ["bf16", "fp32"], - {"default": "bf16", "tooltip": "The dtype to use for training."}, + io.Combo.Input( + "training_dtype", + options=["bf16", "fp32"], + default="bf16", + tooltip="The dtype to use for training.", ), - "lora_dtype": ( - ["bf16", "fp32"], - {"default": "bf16", "tooltip": "The dtype to use for lora."}, + io.Combo.Input( + "lora_dtype", + options=["bf16", "fp32"], + default="bf16", + tooltip="The dtype to use for lora.", ), - "algorithm": ( - list(adapter_maps.keys()), - {"default": list(adapter_maps.keys())[0], "tooltip": "The algorithm to use for training."}, + io.Combo.Input( + "algorithm", + options=list(adapter_maps.keys()), + default=list(adapter_maps.keys())[0], + tooltip="The algorithm to use for training.", ), - "gradient_checkpointing": ( - IO.BOOLEAN, - { - "default": True, - "tooltip": "Use gradient checkpointing for training.", - } + io.Boolean.Input( + "gradient_checkpointing", + default=True, + tooltip="Use gradient checkpointing for training.", ), - "existing_lora": ( - folder_paths.get_filename_list("loras") + ["[None]"], - { - "default": "[None]", - "tooltip": "The existing LoRA to append to. Set to None for new LoRA.", - }, + io.Combo.Input( + "existing_lora", + options=folder_paths.get_filename_list("loras") + ["[None]"], + default="[None]", + tooltip="The existing LoRA to append to. Set to None for new LoRA.", ), - }, - } + ], + outputs=[ + io.Model.Output( + display_name="model", tooltip="Model with LoRA applied" + ), + io.Custom("LORA_MODEL").Output( + display_name="lora", tooltip="LoRA weights" + ), + io.Custom("LOSS_MAP").Output( + display_name="loss_map", tooltip="Loss history" + ), + io.Int.Output(display_name="steps", tooltip="Total training steps"), + ], + ) - RETURN_TYPES = (IO.MODEL, IO.LORA_MODEL, IO.LOSS_MAP, IO.INT) - RETURN_NAMES = ("model_with_lora", "lora", "loss", "steps") - FUNCTION = "train" - CATEGORY = "training" - EXPERIMENTAL = True - - def train( - self, + @classmethod + def execute( + cls, model, latents, positive, @@ -553,13 +420,74 @@ class TrainLoraNode: gradient_checkpointing, existing_lora, ): + # Extract scalars from lists (due to is_input_list=True) + model = model[0] + batch_size = batch_size[0] + steps = steps[0] + grad_accumulation_steps = grad_accumulation_steps[0] + learning_rate = learning_rate[0] + rank = rank[0] + optimizer = optimizer[0] + loss_function = loss_function[0] + seed = seed[0] + training_dtype = training_dtype[0] + lora_dtype = lora_dtype[0] + algorithm = algorithm[0] + gradient_checkpointing = gradient_checkpointing[0] + existing_lora = existing_lora[0] + + # Handle latents - either single dict or list of dicts + if len(latents) == 1: + latents = latents[0]["samples"] # Single latent dict + else: + latent_list = [] + for latent in latents: + latent = latent["samples"] + bs = latent.shape[0] + if bs != 1: + for sub_latent in latent: + latent_list.append(sub_latent[None]) + else: + latent_list.append(latent) + latents = latent_list + + # Handle conditioning - either single list or list of lists + if len(positive) == 1: + positive = positive[0] # Single conditioning list + else: + # Multiple conditioning lists - flatten + flat_positive = [] + for cond in positive: + if isinstance(cond, list): + flat_positive.extend(cond) + else: + flat_positive.append(cond) + positive = flat_positive + mp = model.clone() dtype = node_helpers.string_to_torch_dtype(training_dtype) lora_dtype = node_helpers.string_to_torch_dtype(lora_dtype) mp.set_model_compute_dtype(dtype) - latents = latents["samples"].to(dtype) - num_images = latents.shape[0] + # latents here can be list of different size latent or one large batch + if isinstance(latents, list): + all_shapes = set() + latents = [t.to(dtype) for t in latents] + for latent in latents: + all_shapes.add(latent.shape) + logging.info(f"Latent shapes: {all_shapes}") + if len(all_shapes) > 1: + multi_res = True + else: + multi_res = False + latents = torch.cat(latents, dim=0) + num_images = len(latents) + elif isinstance(latents, torch.Tensor): + latents = latents.to(dtype) + num_images = latents.shape[0] + else: + logging.error(f"Invalid latents type: {type(latents)}") + logging.info(f"Total Images: {num_images}, Total Captions: {len(positive)}") if len(positive) == 1 and num_images > 1: positive = positive * num_images @@ -591,9 +519,7 @@ class TrainLoraNode: shape = m.weight.shape if len(shape) >= 2: alpha = float(existing_weights.get(f"{key}.alpha", 1.0)) - dora_scale = existing_weights.get( - f"{key}.dora_scale", None - ) + dora_scale = existing_weights.get(f"{key}.dora_scale", None) for adapter_cls in adapters: existing_adapter = adapter_cls.load( n, existing_weights, alpha, dora_scale @@ -605,7 +531,9 @@ class TrainLoraNode: adapter_cls = adapter_maps[algorithm] if existing_adapter is not None: - train_adapter = existing_adapter.to_train().to(lora_dtype) + train_adapter = existing_adapter.to_train().to( + lora_dtype + ) else: # Use LoRA with alpha=1.0 by default train_adapter = adapter_cls.create_train( @@ -629,7 +557,9 @@ class TrainLoraNode: if hasattr(m, "bias") and m.bias is not None: key = "{}.bias".format(n) bias = torch.nn.Parameter( - torch.zeros(m.bias.shape, dtype=lora_dtype, requires_grad=True) + torch.zeros( + m.bias.shape, dtype=lora_dtype, requires_grad=True + ) ) bias_module = BiasDiff(bias) lora_sd["{}.diff_b".format(n)] = bias @@ -657,24 +587,31 @@ class TrainLoraNode: # setup models if gradient_checkpointing: - for m in find_all_highest_child_module_with_forward(mp.model.diffusion_model): + for m in find_all_highest_child_module_with_forward( + mp.model.diffusion_model + ): patch(m) mp.model.requires_grad_(False) - comfy.model_management.load_models_gpu([mp], memory_required=1e20, force_full_load=True) + comfy.model_management.load_models_gpu( + [mp], memory_required=1e20, force_full_load=True + ) # Setup sampler and guider like in test script loss_map = {"loss": []} + def loss_callback(loss): loss_map["loss"].append(loss) + train_sampler = TrainSampler( criterion, optimizer, loss_callback=loss_callback, batch_size=batch_size, grad_acc=grad_accumulation_steps, - total_steps=steps*grad_accumulation_steps, + total_steps=steps * grad_accumulation_steps, seed=seed, - training_dtype=dtype + training_dtype=dtype, + real_dataset=latents if multi_res else None, ) guider = comfy_extras.nodes_custom_sampler.Guider_Basic(mp) guider.set_conds(positive) # Set conditioning from input @@ -684,12 +621,15 @@ class TrainLoraNode: # Generate dummy sigmas and noise sigmas = torch.tensor(range(num_images)) noise = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(seed) + if multi_res: + # use first latent as dummy latent if multi_res + latents = latents[0].repeat(num_images, 1, 1, 1) guider.sample( noise.generate_noise({"samples": latents}), latents, train_sampler, sigmas, - seed=noise.seed + seed=noise.seed, ) finally: for m in mp.model.modules(): @@ -702,111 +642,118 @@ class TrainLoraNode: for param in lora_sd: lora_sd[param] = lora_sd[param].to(lora_dtype) - return (mp, lora_sd, loss_map, steps + existing_steps) + return io.NodeOutput(mp, lora_sd, loss_map, steps + existing_steps) -class LoraModelLoader: - def __init__(self): - self.loaded_lora = None +class LoraModelLoader(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="LoraModelLoader", + display_name="Load LoRA Model", + category="loaders", + is_experimental=True, + inputs=[ + io.Model.Input( + "model", tooltip="The diffusion model the LoRA will be applied to." + ), + io.Custom("LORA_MODEL").Input( + "lora", tooltip="The LoRA model to apply to the diffusion model." + ), + io.Float.Input( + "strength_model", + default=1.0, + min=-100.0, + max=100.0, + tooltip="How strongly to modify the diffusion model. This value can be negative.", + ), + ], + outputs=[ + io.Model.Output( + display_name="model", tooltip="The modified diffusion model." + ), + ], + ) @classmethod - def INPUT_TYPES(s): - return { - "required": { - "model": ("MODEL", {"tooltip": "The diffusion model the LoRA will be applied to."}), - "lora": (IO.LORA_MODEL, {"tooltip": "The LoRA model to apply to the diffusion model."}), - "strength_model": ("FLOAT", {"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01, "tooltip": "How strongly to modify the diffusion model. This value can be negative."}), - } - } - - RETURN_TYPES = ("MODEL",) - OUTPUT_TOOLTIPS = ("The modified diffusion model.",) - FUNCTION = "load_lora_model" - - CATEGORY = "loaders" - DESCRIPTION = "Load Trained LoRA weights from Train LoRA node." - EXPERIMENTAL = True - - def load_lora_model(self, model, lora, strength_model): + def execute(cls, model, lora, strength_model): if strength_model == 0: - return (model, ) + return io.NodeOutput(model) - model_lora, _ = comfy.sd.load_lora_for_models(model, None, lora, strength_model, 0) - return (model_lora, ) + model_lora, _ = comfy.sd.load_lora_for_models( + model, None, lora, strength_model, 0 + ) + return io.NodeOutput(model_lora) -class SaveLoRA: - def __init__(self): - self.output_dir = folder_paths.get_output_directory() +class SaveLoRA(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SaveLoRA", + display_name="Save LoRA Weights", + category="loaders", + is_experimental=True, + is_output_node=True, + inputs=[ + io.Custom("LORA_MODEL").Input( + "lora", + tooltip="The LoRA model to save. Do not use the model with LoRA layers.", + ), + io.String.Input( + "prefix", + default="loras/ComfyUI_trained_lora", + tooltip="The prefix to use for the saved LoRA file.", + ), + io.Int.Input( + "steps", + optional=True, + tooltip="Optional: The number of steps to LoRA has been trained for, used to name the saved file.", + ), + ], + outputs=[], + ) @classmethod - def INPUT_TYPES(s): - return { - "required": { - "lora": ( - IO.LORA_MODEL, - { - "tooltip": "The LoRA model to save. Do not use the model with LoRA layers." - }, - ), - "prefix": ( - "STRING", - { - "default": "loras/ComfyUI_trained_lora", - "tooltip": "The prefix to use for the saved LoRA file.", - }, - ), - }, - "optional": { - "steps": ( - IO.INT, - { - "forceInput": True, - "tooltip": "Optional: The number of steps to LoRA has been trained for, used to name the saved file.", - }, - ), - }, - } - - RETURN_TYPES = () - FUNCTION = "save" - CATEGORY = "loaders" - EXPERIMENTAL = True - OUTPUT_NODE = True - - def save(self, lora, prefix, steps=None): - full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(prefix, self.output_dir) + def execute(cls, lora, prefix, steps=None): + output_dir = folder_paths.get_output_directory() + full_output_folder, filename, counter, subfolder, filename_prefix = ( + folder_paths.get_save_image_path(prefix, output_dir) + ) if steps is None: output_checkpoint = f"{filename}_{counter:05}_.safetensors" else: output_checkpoint = f"{filename}_{steps}_steps_{counter:05}_.safetensors" output_checkpoint = os.path.join(full_output_folder, output_checkpoint) safetensors.torch.save_file(lora, output_checkpoint) - return {} + return io.NodeOutput() -class LossGraphNode: - def __init__(self): - self.output_dir = folder_paths.get_temp_directory() +class LossGraphNode(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="LossGraphNode", + display_name="Plot Loss Graph", + category="training", + is_experimental=True, + is_output_node=True, + inputs=[ + io.Custom("LOSS_MAP").Input( + "loss", tooltip="Loss map from training node." + ), + io.String.Input( + "filename_prefix", + default="loss_graph", + tooltip="Prefix for the saved loss graph image.", + ), + ], + outputs=[], + hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo], + ) @classmethod - def INPUT_TYPES(s): - return { - "required": { - "loss": (IO.LOSS_MAP, {"default": {}}), - "filename_prefix": (IO.STRING, {"default": "loss_graph"}), - }, - "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}, - } - - RETURN_TYPES = () - FUNCTION = "plot_loss" - OUTPUT_NODE = True - CATEGORY = "training" - EXPERIMENTAL = True - DESCRIPTION = "Plots the loss graph and saves it to the output directory." - - def plot_loss(self, loss, filename_prefix, prompt=None, extra_pnginfo=None): + def execute(cls, loss, filename_prefix, prompt=None, extra_pnginfo=None): loss_values = loss["loss"] width, height = 800, 480 margin = 40 @@ -849,47 +796,27 @@ class LossGraphNode: (margin - 30, height - 10), f"{min_loss:.2f}", font=font, fill="black" ) - metadata = None - if not args.disable_metadata: - metadata = PngInfo() - if prompt is not None: - metadata.add_text("prompt", json.dumps(prompt)) - if extra_pnginfo is not None: - for x in extra_pnginfo: - metadata.add_text(x, json.dumps(extra_pnginfo[x])) + # Convert PIL image to tensor for PreviewImage + img_array = np.array(img).astype(np.float32) / 255.0 + img_tensor = torch.from_numpy(img_array)[None,] # [1, H, W, 3] - date = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") - img.save( - os.path.join(self.output_dir, f"{filename_prefix}_{date}.png"), - pnginfo=metadata, - ) - return { - "ui": { - "images": [ - { - "filename": f"{filename_prefix}_{date}.png", - "subfolder": "", - "type": "temp", - } - ] - } - } + # Return preview UI + return io.NodeOutput(ui=ui.PreviewImage(img_tensor, cls=cls)) -NODE_CLASS_MAPPINGS = { - "TrainLoraNode": TrainLoraNode, - "SaveLoRANode": SaveLoRA, - "LoraModelLoader": LoraModelLoader, - "LoadImageSetFromFolderNode": LoadImageSetFromFolderNode, - "LoadImageTextSetFromFolderNode": LoadImageTextSetFromFolderNode, - "LossGraphNode": LossGraphNode, -} +# ========== Extension Setup ========== -NODE_DISPLAY_NAME_MAPPINGS = { - "TrainLoraNode": "Train LoRA", - "SaveLoRANode": "Save LoRA Weights", - "LoraModelLoader": "Load LoRA Model", - "LoadImageSetFromFolderNode": "Load Image Dataset from Folder", - "LoadImageTextSetFromFolderNode": "Load Image and Text Dataset from Folder", - "LossGraphNode": "Plot Loss Graph", -} + +class TrainingExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + TrainLoraNode, + LoraModelLoader, + SaveLoRA, + LossGraphNode, + ] + + +async def comfy_entrypoint() -> TrainingExtension: + return TrainingExtension() diff --git a/nodes.py b/nodes.py index f4835c02e..bf73eb90e 100644 --- a/nodes.py +++ b/nodes.py @@ -2278,6 +2278,7 @@ async def init_builtin_extra_nodes(): "nodes_images.py", "nodes_video_model.py", "nodes_train.py", + "nodes_dataset.py", "nodes_sag.py", "nodes_perpneg.py", "nodes_stable3d.py", From eaf68c9b5bbfbcdac8988741f3948678c9465c1d Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 26 Nov 2025 16:25:32 -0800 Subject: [PATCH 162/231] Make lora training work on Z Image and remove some redundant nodes. (#10927) --- comfy/ldm/lumina/model.py | 4 +- comfy_extras/nodes_dataset.py | 102 +--------------------------------- 2 files changed, 3 insertions(+), 103 deletions(-) diff --git a/comfy/ldm/lumina/model.py b/comfy/ldm/lumina/model.py index c8643eb82..565400b54 100644 --- a/comfy/ldm/lumina/model.py +++ b/comfy/ldm/lumina/model.py @@ -509,7 +509,7 @@ class NextDiT(nn.Module): if self.pad_tokens_multiple is not None: pad_extra = (-cap_feats.shape[1]) % self.pad_tokens_multiple - cap_feats = torch.cat((cap_feats, self.cap_pad_token.to(device=cap_feats.device, dtype=cap_feats.dtype).unsqueeze(0).repeat(cap_feats.shape[0], pad_extra, 1)), dim=1) + cap_feats = torch.cat((cap_feats, self.cap_pad_token.to(device=cap_feats.device, dtype=cap_feats.dtype, copy=True).unsqueeze(0).repeat(cap_feats.shape[0], pad_extra, 1)), dim=1) cap_pos_ids = torch.zeros(bsz, cap_feats.shape[1], 3, dtype=torch.float32, device=device) cap_pos_ids[:, :, 0] = torch.arange(cap_feats.shape[1], dtype=torch.float32, device=device) + 1.0 @@ -525,7 +525,7 @@ class NextDiT(nn.Module): if self.pad_tokens_multiple is not None: pad_extra = (-x.shape[1]) % self.pad_tokens_multiple - x = torch.cat((x, self.x_pad_token.to(device=x.device, dtype=x.dtype).unsqueeze(0).repeat(x.shape[0], pad_extra, 1)), dim=1) + x = torch.cat((x, self.x_pad_token.to(device=x.device, dtype=x.dtype, copy=True).unsqueeze(0).repeat(x.shape[0], pad_extra, 1)), dim=1) x_pos_ids = torch.nn.functional.pad(x_pos_ids, (0, 0, 0, pad_extra)) freqs_cis = self.rope_embedder(torch.cat((cap_pos_ids, x_pos_ids), dim=1)).movedim(1, 2) diff --git a/comfy_extras/nodes_dataset.py b/comfy_extras/nodes_dataset.py index b23867505..4789d7d53 100644 --- a/comfy_extras/nodes_dataset.py +++ b/comfy_extras/nodes_dataset.py @@ -1,6 +1,5 @@ import logging import os -import math import json import numpy as np @@ -624,79 +623,6 @@ class TextProcessingNode(io.ComfyNode): # ========== Image Transform Nodes ========== -class ResizeImagesToSameSizeNode(ImageProcessingNode): - node_id = "ResizeImagesToSameSize" - display_name = "Resize Images to Same Size" - description = "Resize all images to the same width and height." - extra_inputs = [ - io.Int.Input("width", default=512, min=1, max=8192, tooltip="Target width."), - io.Int.Input("height", default=512, min=1, max=8192, tooltip="Target height."), - io.Combo.Input( - "mode", - options=["stretch", "crop_center", "pad"], - default="stretch", - tooltip="Resize mode.", - ), - ] - - @classmethod - def _process(cls, image, width, height, mode): - img = tensor_to_pil(image) - - if mode == "stretch": - img = img.resize((width, height), Image.Resampling.LANCZOS) - elif mode == "crop_center": - left = max(0, (img.width - width) // 2) - top = max(0, (img.height - height) // 2) - right = min(img.width, left + width) - bottom = min(img.height, top + height) - img = img.crop((left, top, right, bottom)) - if img.width != width or img.height != height: - img = img.resize((width, height), Image.Resampling.LANCZOS) - elif mode == "pad": - img.thumbnail((width, height), Image.Resampling.LANCZOS) - new_img = Image.new("RGB", (width, height), (0, 0, 0)) - paste_x = (width - img.width) // 2 - paste_y = (height - img.height) // 2 - new_img.paste(img, (paste_x, paste_y)) - img = new_img - - return pil_to_tensor(img) - - -class ResizeImagesToPixelCountNode(ImageProcessingNode): - node_id = "ResizeImagesToPixelCount" - display_name = "Resize Images to Pixel Count" - description = "Resize images so that the total pixel count matches the specified number while preserving aspect ratio." - extra_inputs = [ - io.Int.Input( - "pixel_count", - default=512 * 512, - min=1, - max=8192 * 8192, - tooltip="Target pixel count.", - ), - io.Int.Input( - "steps", - default=64, - min=1, - max=128, - tooltip="The stepping for resize width/height.", - ), - ] - - @classmethod - def _process(cls, image, pixel_count, steps): - img = tensor_to_pil(image) - w, h = img.size - pixel_count_ratio = math.sqrt(pixel_count / (w * h)) - new_w = int(w * pixel_count_ratio / steps) * steps - new_h = int(h * pixel_count_ratio / steps) * steps - logging.info(f"Resizing from {w}x{h} to {new_w}x{new_h}") - img = img.resize((new_w, new_h), Image.Resampling.LANCZOS) - return pil_to_tensor(img) - - class ResizeImagesByShorterEdgeNode(ImageProcessingNode): node_id = "ResizeImagesByShorterEdge" display_name = "Resize Images by Shorter Edge" @@ -801,29 +727,6 @@ class RandomCropImagesNode(ImageProcessingNode): return pil_to_tensor(img) -class FlipImagesNode(ImageProcessingNode): - node_id = "FlipImages" - display_name = "Flip Images" - description = "Flip all images horizontally or vertically." - extra_inputs = [ - io.Combo.Input( - "direction", - options=["horizontal", "vertical"], - default="horizontal", - tooltip="Flip direction.", - ), - ] - - @classmethod - def _process(cls, image, direction): - img = tensor_to_pil(image) - if direction == "horizontal": - img = img.transpose(Image.FLIP_LEFT_RIGHT) - else: - img = img.transpose(Image.FLIP_TOP_BOTTOM) - return pil_to_tensor(img) - - class NormalizeImagesNode(ImageProcessingNode): node_id = "NormalizeImages" display_name = "Normalize Images" @@ -1470,7 +1373,7 @@ class LoadTrainingDataset(io.ComfyNode): shard_path = os.path.join(dataset_dir, shard_file) with open(shard_path, "rb") as f: - shard_data = torch.load(f) + shard_data = torch.load(f, weights_only=True) all_latents.extend(shard_data["latents"]) all_conditioning.extend(shard_data["conditioning"]) @@ -1496,13 +1399,10 @@ class DatasetExtension(ComfyExtension): SaveImageDataSetToFolderNode, SaveImageTextDataSetToFolderNode, # Image transform nodes - ResizeImagesToSameSizeNode, - ResizeImagesToPixelCountNode, ResizeImagesByShorterEdgeNode, ResizeImagesByLongerEdgeNode, CenterCropImagesNode, RandomCropImagesNode, - FlipImagesNode, NormalizeImagesNode, AdjustBrightnessNode, AdjustContrastNode, From c38e7d6599be1bdce580ccfdbb20b928315af05e Mon Sep 17 00:00:00 2001 From: Haoming <73768377+Haoming02@users.noreply.github.com> Date: Thu, 27 Nov 2025 12:28:44 +0800 Subject: [PATCH 163/231] block info (#10841) --- comfy/ldm/flux/model.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/comfy/ldm/flux/model.py b/comfy/ldm/flux/model.py index 1a24e6d95..d5674dea6 100644 --- a/comfy/ldm/flux/model.py +++ b/comfy/ldm/flux/model.py @@ -171,7 +171,10 @@ class Flux(nn.Module): pe = None blocks_replace = patches_replace.get("dit", {}) + transformer_options["total_blocks"] = len(self.double_blocks) + transformer_options["block_type"] = "double" for i, block in enumerate(self.double_blocks): + transformer_options["block_index"] = i if ("double_block", i) in blocks_replace: def block_wrap(args): out = {} @@ -215,7 +218,10 @@ class Flux(nn.Module): if self.params.global_modulation: vec, _ = self.single_stream_modulation(vec_orig) + transformer_options["total_blocks"] = len(self.single_blocks) + transformer_options["block_type"] = "single" for i, block in enumerate(self.single_blocks): + transformer_options["block_index"] = i if ("single_block", i) in blocks_replace: def block_wrap(args): out = {} From f17251bec65b5760cfedec29eace7d77f4b35130 Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Thu, 27 Nov 2025 16:03:03 +1000 Subject: [PATCH 164/231] Account for the VRAM cost of weight offloading (#10733) * mm: default to 0 for NUM_STREAMS Dont count the compute stream as an offload stream. This makes async offload accounting easier. * mm: remove 128MB minimum This is from a previous offloading system requirement. Remove it to make behaviour of the loader and partial unloader consistent. * mp: order the module list by offload expense Calculate an approximate offloading temporary VRAM cost to offload a weight and primary order the module load list by that. In the simple case this is just the same as the module weight, but with Loras, a weight with a lora consumes considerably more VRAM to do the Lora application on-the-fly. This will slightly prioritize lora weights, but is really for proper VRAM offload accounting. * mp: Account for the VRAM cost of weight offloading when checking the VRAM headroom, assume that the weight needs to be offloaded, and only load if it has space for both the load and offload * the number of streams. As the weights are ordered from largest to smallest by offload cost this is guaranteed to fit in VRAM (tm), as all weights that follow will be smaller. Make the partial unload aware of this system as well by saving the budget for offload VRAM to the model state and accounting accordingly. Its possible that partial unload increases the size of the largest offloaded weights, and thus needs to unload a little bit more than asked to accomodate the bigger temp buffers. Honor the existing codes floor on model weight loading of 128MB by having the patcher honor this separately withough regard to offloading. Otherwise when MM specifies its 128MB minimum, MP will see the biggest weights, and budget that 128MB to only offload buffer and load nothing which isnt the intent of these minimums. The same clamp applies in case of partial offload of the currently loading model. --- comfy/model_management.py | 6 ++-- comfy/model_patcher.py | 59 +++++++++++++++++++++++++++++---------- 2 files changed, 48 insertions(+), 17 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index a9327ac80..9c403d580 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -689,7 +689,7 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu loaded_memory = loaded_model.model_loaded_memory() current_free_mem = get_free_memory(torch_dev) + loaded_memory - lowvram_model_memory = max(128 * 1024 * 1024, (current_free_mem - minimum_memory_required), min(current_free_mem * MIN_WEIGHT_MEMORY_RATIO, current_free_mem - minimum_inference_memory())) + lowvram_model_memory = max(0, (current_free_mem - minimum_memory_required), min(current_free_mem * MIN_WEIGHT_MEMORY_RATIO, current_free_mem - minimum_inference_memory())) lowvram_model_memory = lowvram_model_memory - loaded_memory if lowvram_model_memory == 0: @@ -1012,7 +1012,7 @@ def force_channels_last(): STREAMS = {} -NUM_STREAMS = 1 +NUM_STREAMS = 0 if args.async_offload: NUM_STREAMS = 2 logging.info("Using async weight offloading with {} streams".format(NUM_STREAMS)) @@ -1030,7 +1030,7 @@ def current_stream(device): stream_counters = {} def get_offload_stream(device): stream_counter = stream_counters.get(device, 0) - if NUM_STREAMS <= 1: + if NUM_STREAMS == 0: return None if device in STREAMS: diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 73adc7f70..3eac77275 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -148,6 +148,15 @@ class LowVramPatch: else: return out +#The above patch logic may cast up the weight to fp32, and do math. Go with fp32 x 3 +LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR = 3 + +def low_vram_patch_estimate_vram(model, key): + weight, set_func, convert_func = get_key_weight(model, key) + if weight is None: + return 0 + return weight.numel() * torch.float32.itemsize * LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR + def get_key_weight(model, key): set_func = None convert_func = None @@ -269,6 +278,9 @@ class ModelPatcher: if not hasattr(self.model, 'current_weight_patches_uuid'): self.model.current_weight_patches_uuid = None + if not hasattr(self.model, 'model_offload_buffer_memory'): + self.model.model_offload_buffer_memory = 0 + def model_size(self): if self.size > 0: return self.size @@ -662,7 +674,16 @@ class ModelPatcher: skip = True # skip random weights in non leaf modules break if not skip and (hasattr(m, "comfy_cast_weights") or len(params) > 0): - loading.append((comfy.model_management.module_size(m), n, m, params)) + module_mem = comfy.model_management.module_size(m) + module_offload_mem = module_mem + if hasattr(m, "comfy_cast_weights"): + weight_key = "{}.weight".format(n) + bias_key = "{}.bias".format(n) + if weight_key in self.patches: + module_offload_mem += low_vram_patch_estimate_vram(self.model, weight_key) + if bias_key in self.patches: + module_offload_mem += low_vram_patch_estimate_vram(self.model, bias_key) + loading.append((module_offload_mem, module_mem, n, m, params)) return loading def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False): @@ -676,20 +697,22 @@ class ModelPatcher: load_completely = [] offloaded = [] + offload_buffer = 0 loading.sort(reverse=True) for x in loading: - n = x[1] - m = x[2] - params = x[3] - module_mem = x[0] + module_offload_mem, module_mem, n, m, params = x lowvram_weight = False + potential_offload = max(offload_buffer, module_offload_mem * (comfy.model_management.NUM_STREAMS + 1)) + lowvram_fits = mem_counter + module_mem + potential_offload < lowvram_model_memory + weight_key = "{}.weight".format(n) bias_key = "{}.bias".format(n) if not full_load and hasattr(m, "comfy_cast_weights"): - if mem_counter + module_mem >= lowvram_model_memory: + if not lowvram_fits: + offload_buffer = potential_offload lowvram_weight = True lowvram_counter += 1 lowvram_mem_counter += module_mem @@ -723,9 +746,11 @@ class ModelPatcher: if hasattr(m, "comfy_cast_weights"): wipe_lowvram_weight(m) - if full_load or mem_counter + module_mem < lowvram_model_memory: + if full_load or lowvram_fits: mem_counter += module_mem load_completely.append((module_mem, n, m, params)) + else: + offload_buffer = potential_offload if cast_weight and hasattr(m, "comfy_cast_weights"): m.prev_comfy_cast_weights = m.comfy_cast_weights @@ -766,7 +791,7 @@ class ModelPatcher: self.pin_weight_to_device("{}.{}".format(n, param)) if lowvram_counter > 0: - logging.info("loaded partially; {:.2f} MB usable, {:.2f} MB loaded, {:.2f} MB offloaded, lowvram patches: {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), lowvram_mem_counter / (1024 * 1024), patch_counter)) + logging.info("loaded partially; {:.2f} MB usable, {:.2f} MB loaded, {:.2f} MB offloaded, {:.2f} MB buffer reserved, lowvram patches: {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), lowvram_mem_counter / (1024 * 1024), offload_buffer / (1024 * 1024), patch_counter)) self.model.model_lowvram = True else: logging.info("loaded completely; {:.2f} MB usable, {:.2f} MB loaded, full load: {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), full_load)) @@ -778,6 +803,7 @@ class ModelPatcher: self.model.lowvram_patch_counter += patch_counter self.model.device = device_to self.model.model_loaded_weight_memory = mem_counter + self.model.model_offload_buffer_memory = offload_buffer self.model.current_weight_patches_uuid = self.patches_uuid for callback in self.get_all_callbacks(CallbacksMP.ON_LOAD): @@ -831,6 +857,7 @@ class ModelPatcher: self.model.to(device_to) self.model.device = device_to self.model.model_loaded_weight_memory = 0 + self.model.model_offload_buffer_memory = 0 for m in self.model.modules(): if hasattr(m, "comfy_patched_weights"): @@ -849,13 +876,14 @@ class ModelPatcher: patch_counter = 0 unload_list = self._load_list() unload_list.sort() + offload_buffer = self.model.model_offload_buffer_memory + for unload in unload_list: - if memory_to_free < memory_freed: + if memory_to_free + offload_buffer - self.model.model_offload_buffer_memory < memory_freed: break - module_mem = unload[0] - n = unload[1] - m = unload[2] - params = unload[3] + module_offload_mem, module_mem, n, m, params = unload + + potential_offload = (comfy.model_management.NUM_STREAMS + 1) * module_offload_mem lowvram_possible = hasattr(m, "comfy_cast_weights") if hasattr(m, "comfy_patched_weights") and m.comfy_patched_weights == True: @@ -906,15 +934,18 @@ class ModelPatcher: m.comfy_cast_weights = True m.comfy_patched_weights = False memory_freed += module_mem + offload_buffer = max(offload_buffer, potential_offload) logging.debug("freed {}".format(n)) for param in params: self.pin_weight_to_device("{}.{}".format(n, param)) + self.model.model_lowvram = True self.model.lowvram_patch_counter += patch_counter self.model.model_loaded_weight_memory -= memory_freed - logging.info("loaded partially: {:.2f} MB loaded, lowvram patches: {}".format(self.model.model_loaded_weight_memory / (1024 * 1024), self.model.lowvram_patch_counter)) + self.model.model_offload_buffer_memory = offload_buffer + logging.info("Unloaded partially: {:.2f} MB freed, {:.2f} MB remains loaded, {:.2f} MB buffer reserved, lowvram patches: {}".format(memory_freed / (1024 * 1024), self.model.model_loaded_weight_memory / (1024 * 1024), offload_buffer / (1024 * 1024), self.model.lowvram_patch_counter)) return memory_freed def partially_load(self, device_to, extra_memory=0, force_patch_weights=False): From 3f382a4f9884f7b672557028adb9bb85d075820d Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Fri, 28 Nov 2025 02:06:30 +1000 Subject: [PATCH 165/231] quant ops: Dequantize weight in-place (#10935) In flux2 these weights are huge (200MB). As plain_tensor is a throw-away deep copy, do this multiplication in-place to save VRAM. --- comfy/quant_ops.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py index d2f3e7397..9b924560b 100644 --- a/comfy/quant_ops.py +++ b/comfy/quant_ops.py @@ -425,7 +425,8 @@ class TensorCoreFP8Layout(QuantizedLayout): @staticmethod def dequantize(qdata, scale, orig_dtype, **kwargs): plain_tensor = torch.ops.aten._to_copy.default(qdata, dtype=orig_dtype) - return plain_tensor * scale + plain_tensor.mul_(scale) + return plain_tensor @classmethod def get_plain_tensors(cls, qtensor): From b59750a86a4687056528f1d59470e207063a73a3 Mon Sep 17 00:00:00 2001 From: ComfyUI Wiki Date: Fri, 28 Nov 2025 06:12:56 +0800 Subject: [PATCH 166/231] Update template to 0.7.23 (#10949) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 9291552d3..e0b2c566b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.32.9 -comfyui-workflow-templates==0.7.20 +comfyui-workflow-templates==0.7.23 comfyui-embedded-docs==0.3.1 torch torchsde From 9d8a817985bb069685e440b38762f95dc834d242 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Thu, 27 Nov 2025 14:46:12 -0800 Subject: [PATCH 167/231] Enable async offloading by default on Nvidia. (#10953) Add --disable-async-offload to disable it. If this causes OOMs that go away when you --disable-async-offload please report it. --- comfy/cli_args.py | 3 ++- comfy/model_management.py | 13 +++++++++++-- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/comfy/cli_args.py b/comfy/cli_args.py index d2b60e347..5f0dfaa10 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -131,7 +131,8 @@ vram_group.add_argument("--cpu", action="store_true", help="To use the CPU for e parser.add_argument("--reserve-vram", type=float, default=None, help="Set the amount of vram in GB you want to reserve for use by your OS/other software. By default some amount is reserved depending on your OS.") -parser.add_argument("--async-offload", action="store_true", help="Use async weight offloading.") +parser.add_argument("--async-offload", nargs='?', const=2, type=int, default=None, metavar="NUM_STREAMS", help="Use async weight offloading. An optional argument controls the amount of offload streams. Default is 2. Enabled by default on Nvidia.") +parser.add_argument("--disable-async-offload", action="store_true", help="Disable async weight offloading.") parser.add_argument("--force-non-blocking", action="store_true", help="Force ComfyUI to use non-blocking operations for all applicable tensors. This may improve performance on some non-Nvidia systems but can cause issues with some workflows.") diff --git a/comfy/model_management.py b/comfy/model_management.py index 9c403d580..38c506df5 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -1013,8 +1013,17 @@ def force_channels_last(): STREAMS = {} NUM_STREAMS = 0 -if args.async_offload: - NUM_STREAMS = 2 +if args.async_offload is not None: + NUM_STREAMS = args.async_offload +else: + # Enable by default on Nvidia + if is_nvidia(): + NUM_STREAMS = 2 + +if args.disable_async_offload: + NUM_STREAMS = 0 + +if NUM_STREAMS > 0: logging.info("Using async weight offloading with {} streams".format(NUM_STREAMS)) def current_stream(device): From 52e778fff3c1d6f32c8d14cba9864faddba8475d Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Fri, 28 Nov 2025 12:52:59 +0200 Subject: [PATCH 168/231] feat(Kling-API-Nodes): add v2-5-turbo model to FirstLastFrame node (#10938) --- comfy_api_nodes/nodes_kling.py | 60 +++++++++++++++------------------- 1 file changed, 26 insertions(+), 34 deletions(-) diff --git a/comfy_api_nodes/nodes_kling.py b/comfy_api_nodes/nodes_kling.py index 36852038b..23a7f55f1 100644 --- a/comfy_api_nodes/nodes_kling.py +++ b/comfy_api_nodes/nodes_kling.py @@ -4,8 +4,6 @@ For source of truth on the allowed permutations of request fields, please refere - [Compatibility Table](https://app.klingai.com/global/dev/document-api/apiReference/model/skillsMap) """ -from __future__ import annotations -from typing import Optional, TypeVar import math import logging @@ -66,9 +64,7 @@ from comfy_api_nodes.util import ( poll_op, ) from comfy_api.input_impl import VideoFromFile -from comfy_api.input.basic_types import AudioInput -from comfy_api.input.video_types import VideoInput -from comfy_api.latest import ComfyExtension, IO +from comfy_api.latest import ComfyExtension, IO, Input KLING_API_VERSION = "v1" PATH_TEXT_TO_VIDEO = f"/proxy/kling/{KLING_API_VERSION}/videos/text2video" @@ -94,8 +90,6 @@ AVERAGE_DURATION_IMAGE_GEN = 32 AVERAGE_DURATION_VIDEO_EFFECTS = 320 AVERAGE_DURATION_VIDEO_EXTEND = 320 -R = TypeVar("R") - MODE_TEXT2VIDEO = { "standard mode / 5s duration / kling-v1": ("std", "5", "kling-v1"), @@ -130,6 +124,8 @@ MODE_START_END_FRAME = { "pro mode / 10s duration / kling-v1-6": ("pro", "10", "kling-v1-6"), "pro mode / 5s duration / kling-v2-1": ("pro", "5", "kling-v2-1"), "pro mode / 10s duration / kling-v2-1": ("pro", "10", "kling-v2-1"), + "pro mode / 5s duration / kling-v2-5-turbo": ("pro", "5", "kling-v2-5-turbo"), + "pro mode / 10s duration / kling-v2-5-turbo": ("pro", "10", "kling-v2-5-turbo"), } """ Returns a mapping of mode strings to their corresponding (mode, duration, model_name) tuples. @@ -296,7 +292,7 @@ def get_video_from_response(response) -> KlingVideoResult: return video -def get_video_url_from_response(response) -> Optional[str]: +def get_video_url_from_response(response) -> str | None: """Returns the first video url from the Kling video generation task result. Will not raise an error if the response is not valid. """ @@ -315,7 +311,7 @@ def get_images_from_response(response) -> list[KlingImageResult]: return images -def get_images_urls_from_response(response) -> Optional[str]: +def get_images_urls_from_response(response) -> str | None: """Returns the list of image urls from the Kling image generation task result. Will not raise an error if the response is not valid. If there is only one image, returns the url as a string. If there are multiple images, returns a list of urls. """ @@ -349,7 +345,7 @@ async def execute_text2video( model_mode: str, duration: str, aspect_ratio: str, - camera_control: Optional[KlingCameraControl] = None, + camera_control: KlingCameraControl | None = None, ) -> IO.NodeOutput: validate_prompts(prompt, negative_prompt, MAX_PROMPT_LENGTH_T2V) task_creation_response = await sync_op( @@ -394,8 +390,8 @@ async def execute_image2video( model_mode: str, aspect_ratio: str, duration: str, - camera_control: Optional[KlingCameraControl] = None, - end_frame: Optional[torch.Tensor] = None, + camera_control: KlingCameraControl | None = None, + end_frame: torch.Tensor | None = None, ) -> IO.NodeOutput: validate_prompts(prompt, negative_prompt, MAX_PROMPT_LENGTH_I2V) validate_input_image(start_frame) @@ -451,8 +447,8 @@ async def execute_video_effect( model_name: str, duration: KlingVideoGenDuration, image_1: torch.Tensor, - image_2: Optional[torch.Tensor] = None, - model_mode: Optional[KlingVideoGenMode] = None, + image_2: torch.Tensor | None = None, + model_mode: KlingVideoGenMode | None = None, ) -> tuple[VideoFromFile, str, str]: if dual_character: request_input_field = KlingDualCharacterEffectInput( @@ -499,13 +495,13 @@ async def execute_video_effect( async def execute_lipsync( cls: type[IO.ComfyNode], - video: VideoInput, - audio: Optional[AudioInput] = None, - voice_language: Optional[str] = None, - model_mode: Optional[str] = None, - text: Optional[str] = None, - voice_speed: Optional[float] = None, - voice_id: Optional[str] = None, + video: Input.Video, + audio: Input.Audio | None = None, + voice_language: str | None = None, + model_mode: str | None = None, + text: str | None = None, + voice_speed: float | None = None, + voice_id: str | None = None, ) -> IO.NodeOutput: if text: validate_string(text, field_name="Text", max_length=MAX_PROMPT_LENGTH_LIP_SYNC) @@ -787,7 +783,7 @@ class KlingCameraControlT2VNode(IO.ComfyNode): negative_prompt: str, cfg_scale: float, aspect_ratio: str, - camera_control: Optional[KlingCameraControl] = None, + camera_control: KlingCameraControl | None = None, ) -> IO.NodeOutput: return await execute_text2video( cls, @@ -854,8 +850,8 @@ class KlingImage2VideoNode(IO.ComfyNode): mode: str, aspect_ratio: str, duration: str, - camera_control: Optional[KlingCameraControl] = None, - end_frame: Optional[torch.Tensor] = None, + camera_control: KlingCameraControl | None = None, + end_frame: torch.Tensor | None = None, ) -> IO.NodeOutput: return await execute_image2video( cls, @@ -965,15 +961,11 @@ class KlingStartEndFrameNode(IO.ComfyNode): IO.String.Input("prompt", multiline=True, tooltip="Positive text prompt"), IO.String.Input("negative_prompt", multiline=True, tooltip="Negative text prompt"), IO.Float.Input("cfg_scale", default=0.5, min=0.0, max=1.0), - IO.Combo.Input( - "aspect_ratio", - options=[i.value for i in KlingVideoGenAspectRatio], - default="16:9", - ), + IO.Combo.Input("aspect_ratio", options=["16:9", "9:16", "1:1"]), IO.Combo.Input( "mode", options=modes, - default=modes[2], + default=modes[8], tooltip="The configuration to use for the video generation following the format: mode / duration / model_name.", ), ], @@ -1254,8 +1246,8 @@ class KlingLipSyncAudioToVideoNode(IO.ComfyNode): @classmethod async def execute( cls, - video: VideoInput, - audio: AudioInput, + video: Input.Video, + audio: Input.Audio, voice_language: str, ) -> IO.NodeOutput: return await execute_lipsync( @@ -1314,7 +1306,7 @@ class KlingLipSyncTextToVideoNode(IO.ComfyNode): @classmethod async def execute( cls, - video: VideoInput, + video: Input.Video, text: str, voice: str, voice_speed: float, @@ -1471,7 +1463,7 @@ class KlingImageGenerationNode(IO.ComfyNode): human_fidelity: float, n: int, aspect_ratio: KlingImageGenAspectRatio, - image: Optional[torch.Tensor] = None, + image: torch.Tensor | None = None, ) -> IO.NodeOutput: validate_string(prompt, field_name="prompt", min_length=1, max_length=MAX_PROMPT_LENGTH_IMAGE_GEN) validate_string(negative_prompt, field_name="negative_prompt", max_length=MAX_PROMPT_LENGTH_IMAGE_GEN) From ca7808f240d4d53e594d3b95753240313864c992 Mon Sep 17 00:00:00 2001 From: "Dr.Lt.Data" <128333288+ltdrdata@users.noreply.github.com> Date: Sat, 29 Nov 2025 05:43:17 +0900 Subject: [PATCH 169/231] fix(user_manager): fix typo in move_userdata dest validation (#10967) Check `dest` instead of `source` when validating destination path in move_userdata endpoint. --- app/user_manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/app/user_manager.py b/app/user_manager.py index a2d376c0c..675f6c0c6 100644 --- a/app/user_manager.py +++ b/app/user_manager.py @@ -424,7 +424,7 @@ class UserManager(): return source dest = get_user_data_path(request, check_exists=False, param="dest") - if not isinstance(source, str): + if not isinstance(dest, str): return dest overwrite = request.query.get("overwrite", 'true') != "false" From f55c98a89f76fc06c435a728bc2e76b6b4051463 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Fri, 28 Nov 2025 13:16:46 -0800 Subject: [PATCH 170/231] Disable offload stream when torch compile. (#10961) --- comfy/model_management.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/comfy/model_management.py b/comfy/model_management.py index 38c506df5..d8ce80010 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -1042,6 +1042,9 @@ def get_offload_stream(device): if NUM_STREAMS == 0: return None + if torch.compiler.is_compiling(): + return None + if device in STREAMS: ss = STREAMS[device] #Sync the oldest stream in the queue with the current From 6484ac89dc683b178d9ef3f7406800f7132147ba Mon Sep 17 00:00:00 2001 From: Urle Sistiana <55231606+urlesistiana@users.noreply.github.com> Date: Sat, 29 Nov 2025 05:33:07 +0800 Subject: [PATCH 171/231] fix QuantizedTensor.is_contiguous (#10956) (#10959) --- comfy/quant_ops.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py index 9b924560b..bb1fb860c 100644 --- a/comfy/quant_ops.py +++ b/comfy/quant_ops.py @@ -235,8 +235,8 @@ class QuantizedTensor(torch.Tensor): def is_pinned(self): return self._qdata.is_pinned() - def is_contiguous(self): - return self._qdata.is_contiguous() + def is_contiguous(self, *arg, **kwargs): + return self._qdata.is_contiguous(*arg, **kwargs) # ============================================================================== # Generic Utilities (Layout-Agnostic Operations) From 0ff0457892467ef8a6ea235dcd0618c10ca44ee3 Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Sat, 29 Nov 2025 07:38:12 +1000 Subject: [PATCH 172/231] mm: wrap the raw stream in context manager (#10958) The documentation of torch.foo.Stream being usable with with: suggests it starts at version 2.7. Use the old API for backwards compatibility. --- comfy/model_management.py | 19 +++++++++++++++---- comfy/ops.py | 2 ++ 2 files changed, 17 insertions(+), 4 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index d8ce80010..aeddbaefe 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -1055,7 +1055,9 @@ def get_offload_stream(device): elif is_device_cuda(device): ss = [] for k in range(NUM_STREAMS): - ss.append(torch.cuda.Stream(device=device, priority=0)) + s1 = torch.cuda.Stream(device=device, priority=0) + s1.as_context = torch.cuda.stream + ss.append(s1) STREAMS[device] = ss s = ss[stream_counter] stream_counters[device] = stream_counter @@ -1063,7 +1065,9 @@ def get_offload_stream(device): elif is_device_xpu(device): ss = [] for k in range(NUM_STREAMS): - ss.append(torch.xpu.Stream(device=device, priority=0)) + s1 = torch.xpu.Stream(device=device, priority=0) + s1.as_context = torch.xpu.stream + ss.append(s1) STREAMS[device] = ss s = ss[stream_counter] stream_counters[device] = stream_counter @@ -1081,12 +1085,19 @@ def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, str if dtype is None or weight.dtype == dtype: return weight if stream is not None: - with stream: + wf_context = stream + if hasattr(wf_context, "as_context"): + wf_context = wf_context.as_context(stream) + with wf_context: return weight.to(dtype=dtype, copy=copy) return weight.to(dtype=dtype, copy=copy) + if stream is not None: - with stream: + wf_context = stream + if hasattr(wf_context, "as_context"): + wf_context = wf_context.as_context(stream) + with wf_context: r = torch.empty_like(weight, dtype=dtype, device=device) r.copy_(weight, non_blocking=non_blocking) else: diff --git a/comfy/ops.py b/comfy/ops.py index a0ff4e8f1..61a2f0754 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -95,6 +95,8 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of if offload_stream is not None: wf_context = offload_stream + if hasattr(wf_context, "as_context"): + wf_context = wf_context.as_context(offload_stream) else: wf_context = contextlib.nullcontext() From 065a2fbbec6af5c8e19a3add29703f83faf672d6 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Fri, 28 Nov 2025 16:37:39 -0800 Subject: [PATCH 173/231] Update driver link in AMD portable README (#10974) --- .ci/windows_amd_base_files/README_VERY_IMPORTANT.txt | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/.ci/windows_amd_base_files/README_VERY_IMPORTANT.txt b/.ci/windows_amd_base_files/README_VERY_IMPORTANT.txt index 96a500be2..2cbb00d99 100755 --- a/.ci/windows_amd_base_files/README_VERY_IMPORTANT.txt +++ b/.ci/windows_amd_base_files/README_VERY_IMPORTANT.txt @@ -1,5 +1,5 @@ -As of the time of writing this you need this preview driver for best results: -https://www.amd.com/en/resources/support-articles/release-notes/RN-AMDGPU-WINDOWS-PYTORCH-PREVIEW.html +As of the time of writing this you need this driver for best results: +https://www.amd.com/en/resources/support-articles/release-notes/RN-AMDGPU-WINDOWS-PYTORCH-7-1-1.html HOW TO RUN: @@ -25,3 +25,4 @@ In the ComfyUI directory you will find a file: extra_model_paths.yaml.example Rename this file to: extra_model_paths.yaml and edit it with your favorite text editor. + From b9070857092a78cc952d70025fdcc0ff540d96ec Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jukka=20Sepp=C3=A4nen?= <40791699+kijai@users.noreply.github.com> Date: Sat, 29 Nov 2025 02:40:19 +0200 Subject: [PATCH 174/231] Support video tiny VAEs (#10884) * Support video tiny VAEs * lighttaew scaling fix * Also support video taes in previews Only first frame for now as live preview playback is currently only available through VHS custom nodes. * Support Wan 2.1 lightVAE * Relocate elif block and set Wan VAE dim directly without using pruning rate for lightvae --- comfy/latent_formats.py | 5 +- comfy/sd.py | 34 +++++++- comfy/taesd/taehv.py | 171 ++++++++++++++++++++++++++++++++++++++++ latent_preview.py | 28 +++++-- nodes.py | 18 ++++- 5 files changed, 244 insertions(+), 12 deletions(-) create mode 100644 comfy/taesd/taehv.py diff --git a/comfy/latent_formats.py b/comfy/latent_formats.py index 8e110f45d..f1ca0151e 100644 --- a/comfy/latent_formats.py +++ b/comfy/latent_formats.py @@ -431,6 +431,7 @@ class HunyuanVideo(LatentFormat): ] latent_rgb_factors_bias = [ 0.0259, -0.0192, -0.0761] + taesd_decoder_name = "taehv" class Cosmos1CV8x8x8(LatentFormat): latent_channels = 16 @@ -494,7 +495,7 @@ class Wan21(LatentFormat): ]).view(1, self.latent_channels, 1, 1, 1) - self.taesd_decoder_name = None #TODO + self.taesd_decoder_name = "lighttaew2_1" def process_in(self, latent): latents_mean = self.latents_mean.to(latent.device, latent.dtype) @@ -565,6 +566,7 @@ class Wan22(Wan21): def __init__(self): self.scale_factor = 1.0 + self.taesd_decoder_name = "lighttaew2_2" self.latents_mean = torch.tensor([ -0.2289, -0.0052, -0.1323, -0.2339, -0.2799, 0.0174, 0.1838, 0.1557, -0.1382, 0.0542, 0.2813, 0.0891, 0.1570, -0.0098, 0.0375, -0.1825, @@ -719,6 +721,7 @@ class HunyuanVideo15(LatentFormat): latent_channels = 32 latent_dimensions = 3 scale_factor = 1.03682 + taesd_decoder_name = "lighttaehy1_5" class Hunyuan3Dv2(LatentFormat): latent_channels = 64 diff --git a/comfy/sd.py b/comfy/sd.py index 350fae92b..9eeb0c45a 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -60,6 +60,8 @@ import comfy.lora_convert import comfy.hooks import comfy.t2i_adapter.adapter import comfy.taesd.taesd +import comfy.taesd.taehv +import comfy.latent_formats import comfy.ldm.flux.redux @@ -508,13 +510,14 @@ class VAE: self.memory_used_encode = lambda shape, dtype: 3300 * shape[3] * shape[4] * model_management.dtype_size(dtype) self.memory_used_decode = lambda shape, dtype: 8000 * shape[3] * shape[4] * (16 * 16) * model_management.dtype_size(dtype) else: # Wan 2.1 VAE + dim = sd["decoder.head.0.gamma"].shape[0] self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8) self.upscale_index_formula = (4, 8, 8) self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 8, 8) self.downscale_index_formula = (4, 8, 8) self.latent_dim = 3 self.latent_channels = 16 - ddconfig = {"dim": 96, "z_dim": self.latent_channels, "dim_mult": [1, 2, 4, 4], "num_res_blocks": 2, "attn_scales": [], "temperal_downsample": [False, True, True], "dropout": 0.0} + ddconfig = {"dim": dim, "z_dim": self.latent_channels, "dim_mult": [1, 2, 4, 4], "num_res_blocks": 2, "attn_scales": [], "temperal_downsample": [False, True, True], "dropout": 0.0} self.first_stage_model = comfy.ldm.wan.vae.WanVAE(**ddconfig) self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32] self.memory_used_encode = lambda shape, dtype: 6000 * shape[3] * shape[4] * model_management.dtype_size(dtype) @@ -584,6 +587,35 @@ class VAE: self.process_input = lambda audio: audio self.working_dtypes = [torch.float32] self.crop_input = False + elif "decoder.22.bias" in sd: # taehv, taew and lighttae + self.latent_channels = sd["decoder.1.weight"].shape[1] + self.latent_dim = 3 + self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 16, 16) + self.upscale_index_formula = (4, 16, 16) + self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 16, 16) + self.downscale_index_formula = (4, 16, 16) + if self.latent_channels == 48: # Wan 2.2 + self.first_stage_model = comfy.taesd.taehv.TAEHV(latent_channels=self.latent_channels, latent_format=None) # taehv doesn't need scaling + self.process_input = lambda image: (_ for _ in ()).throw(NotImplementedError("This light tae doesn't support encoding currently")) + self.process_output = lambda image: image + self.memory_used_decode = lambda shape, dtype: (1800 * (max(1, (shape[-3] ** 0.7 * 0.1)) * shape[-2] * shape[-1] * 16 * 16) * model_management.dtype_size(dtype)) + elif self.latent_channels == 32 and sd["decoder.22.bias"].shape[0] == 12: # lighttae_hv15 + self.first_stage_model = comfy.taesd.taehv.TAEHV(latent_channels=self.latent_channels, latent_format=comfy.latent_formats.HunyuanVideo15) + self.process_input = lambda image: (_ for _ in ()).throw(NotImplementedError("This light tae doesn't support encoding currently")) + self.memory_used_decode = lambda shape, dtype: (1200 * (max(1, (shape[-3] ** 0.7 * 0.05)) * shape[-2] * shape[-1] * 32 * 32) * model_management.dtype_size(dtype)) + else: + if sd["decoder.1.weight"].dtype == torch.float16: # taehv currently only available in float16, so assume it's not lighttaew2_1 as otherwise state dicts are identical + latent_format=comfy.latent_formats.HunyuanVideo + else: + latent_format=None # lighttaew2_1 doesn't need scaling + self.first_stage_model = comfy.taesd.taehv.TAEHV(latent_channels=self.latent_channels, latent_format=latent_format) + self.process_input = self.process_output = lambda image: image + self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8) + self.upscale_index_formula = (4, 8, 8) + self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 8, 8) + self.downscale_index_formula = (4, 8, 8) + self.memory_used_encode = lambda shape, dtype: (700 * (max(1, (shape[-3] ** 0.66 * 0.11)) * shape[-2] * shape[-1]) * model_management.dtype_size(dtype)) + self.memory_used_decode = lambda shape, dtype: (50 * (max(1, (shape[-3] ** 0.65 * 0.26)) * shape[-2] * shape[-1] * 32 * 32) * model_management.dtype_size(dtype)) else: logging.warning("WARNING: No VAE weights detected, VAE not initalized.") self.first_stage_model = None diff --git a/comfy/taesd/taehv.py b/comfy/taesd/taehv.py new file mode 100644 index 000000000..3dfe1e4d4 --- /dev/null +++ b/comfy/taesd/taehv.py @@ -0,0 +1,171 @@ +# Tiny AutoEncoder for HunyuanVideo and WanVideo https://github.com/madebyollin/taehv + +import torch +import torch.nn as nn +import torch.nn.functional as F +from tqdm.auto import tqdm +from collections import namedtuple, deque + +import comfy.ops +operations=comfy.ops.disable_weight_init + +DecoderResult = namedtuple("DecoderResult", ("frame", "memory")) +TWorkItem = namedtuple("TWorkItem", ("input_tensor", "block_index")) + +def conv(n_in, n_out, **kwargs): + return operations.Conv2d(n_in, n_out, 3, padding=1, **kwargs) + +class Clamp(nn.Module): + def forward(self, x): + return torch.tanh(x / 3) * 3 + +class MemBlock(nn.Module): + def __init__(self, n_in, n_out, act_func): + super().__init__() + self.conv = nn.Sequential(conv(n_in * 2, n_out), act_func, conv(n_out, n_out), act_func, conv(n_out, n_out)) + self.skip = operations.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity() + self.act = act_func + def forward(self, x, past): + return self.act(self.conv(torch.cat([x, past], 1)) + self.skip(x)) + +class TPool(nn.Module): + def __init__(self, n_f, stride): + super().__init__() + self.stride = stride + self.conv = operations.Conv2d(n_f*stride,n_f, 1, bias=False) + def forward(self, x): + _NT, C, H, W = x.shape + return self.conv(x.reshape(-1, self.stride * C, H, W)) + +class TGrow(nn.Module): + def __init__(self, n_f, stride): + super().__init__() + self.stride = stride + self.conv = operations.Conv2d(n_f, n_f*stride, 1, bias=False) + def forward(self, x): + _NT, C, H, W = x.shape + x = self.conv(x) + return x.reshape(-1, C, H, W) + +def apply_model_with_memblocks(model, x, parallel, show_progress_bar): + + B, T, C, H, W = x.shape + if parallel: + x = x.reshape(B*T, C, H, W) + # parallel over input timesteps, iterate over blocks + for b in tqdm(model, disable=not show_progress_bar): + if isinstance(b, MemBlock): + BT, C, H, W = x.shape + T = BT // B + _x = x.reshape(B, T, C, H, W) + mem = F.pad(_x, (0,0,0,0,0,0,1,0), value=0)[:,:T].reshape(x.shape) + x = b(x, mem) + else: + x = b(x) + BT, C, H, W = x.shape + T = BT // B + x = x.view(B, T, C, H, W) + else: + out = [] + work_queue = deque([TWorkItem(xt, 0) for t, xt in enumerate(x.reshape(B, T * C, H, W).chunk(T, dim=1))]) + progress_bar = tqdm(range(T), disable=not show_progress_bar) + mem = [None] * len(model) + while work_queue: + xt, i = work_queue.popleft() + if i == 0: + progress_bar.update(1) + if i == len(model): + out.append(xt) + del xt + else: + b = model[i] + if isinstance(b, MemBlock): + if mem[i] is None: + xt_new = b(xt, xt * 0) + mem[i] = xt.detach().clone() + else: + xt_new = b(xt, mem[i]) + mem[i] = xt.detach().clone() + del xt + work_queue.appendleft(TWorkItem(xt_new, i+1)) + elif isinstance(b, TPool): + if mem[i] is None: + mem[i] = [] + mem[i].append(xt.detach().clone()) + if len(mem[i]) == b.stride: + B, C, H, W = xt.shape + xt = b(torch.cat(mem[i], 1).view(B*b.stride, C, H, W)) + mem[i] = [] + work_queue.appendleft(TWorkItem(xt, i+1)) + elif isinstance(b, TGrow): + xt = b(xt) + NT, C, H, W = xt.shape + for xt_next in reversed(xt.view(B, b.stride*C, H, W).chunk(b.stride, 1)): + work_queue.appendleft(TWorkItem(xt_next, i+1)) + del xt + else: + xt = b(xt) + work_queue.appendleft(TWorkItem(xt, i+1)) + progress_bar.close() + x = torch.stack(out, 1) + return x + + +class TAEHV(nn.Module): + def __init__(self, latent_channels, parallel=False, decoder_time_upscale=(True, True), decoder_space_upscale=(True, True, True), latent_format=None, show_progress_bar=True): + super().__init__() + self.image_channels = 3 + self.patch_size = 1 + self.latent_channels = latent_channels + self.parallel = parallel + self.latent_format = latent_format + self.show_progress_bar = show_progress_bar + self.process_in = latent_format().process_in if latent_format is not None else (lambda x: x) + self.process_out = latent_format().process_out if latent_format is not None else (lambda x: x) + if self.latent_channels in [48, 32]: # Wan 2.2 and HunyuanVideo1.5 + self.patch_size = 2 + if self.latent_channels == 32: # HunyuanVideo1.5 + act_func = nn.LeakyReLU(0.2, inplace=True) + else: # HunyuanVideo, Wan 2.1 + act_func = nn.ReLU(inplace=True) + + self.encoder = nn.Sequential( + conv(self.image_channels*self.patch_size**2, 64), act_func, + TPool(64, 2), conv(64, 64, stride=2, bias=False), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), + TPool(64, 2), conv(64, 64, stride=2, bias=False), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), + TPool(64, 1), conv(64, 64, stride=2, bias=False), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), + conv(64, self.latent_channels), + ) + n_f = [256, 128, 64, 64] + self.frames_to_trim = 2**sum(decoder_time_upscale) - 1 + self.decoder = nn.Sequential( + Clamp(), conv(self.latent_channels, n_f[0]), act_func, + MemBlock(n_f[0], n_f[0], act_func), MemBlock(n_f[0], n_f[0], act_func), MemBlock(n_f[0], n_f[0], act_func), nn.Upsample(scale_factor=2 if decoder_space_upscale[0] else 1), TGrow(n_f[0], 1), conv(n_f[0], n_f[1], bias=False), + MemBlock(n_f[1], n_f[1], act_func), MemBlock(n_f[1], n_f[1], act_func), MemBlock(n_f[1], n_f[1], act_func), nn.Upsample(scale_factor=2 if decoder_space_upscale[1] else 1), TGrow(n_f[1], 2 if decoder_time_upscale[0] else 1), conv(n_f[1], n_f[2], bias=False), + MemBlock(n_f[2], n_f[2], act_func), MemBlock(n_f[2], n_f[2], act_func), MemBlock(n_f[2], n_f[2], act_func), nn.Upsample(scale_factor=2 if decoder_space_upscale[2] else 1), TGrow(n_f[2], 2 if decoder_time_upscale[1] else 1), conv(n_f[2], n_f[3], bias=False), + act_func, conv(n_f[3], self.image_channels*self.patch_size**2), + ) + @property + def show_progress_bar(self): + return self._show_progress_bar + + @show_progress_bar.setter + def show_progress_bar(self, value): + self._show_progress_bar = value + + def encode(self, x, **kwargs): + if self.patch_size > 1: x = F.pixel_unshuffle(x, self.patch_size) + x = x.movedim(2, 1) # [B, C, T, H, W] -> [B, T, C, H, W] + if x.shape[1] % 4 != 0: + # pad at end to multiple of 4 + n_pad = 4 - x.shape[1] % 4 + padding = x[:, -1:].repeat_interleave(n_pad, dim=1) + x = torch.cat([x, padding], 1) + x = apply_model_with_memblocks(self.encoder, x, self.parallel, self.show_progress_bar).movedim(2, 1) + return self.process_out(x) + + def decode(self, x, **kwargs): + x = self.process_in(x).movedim(2, 1) # [B, C, T, H, W] -> [B, T, C, H, W] + x = apply_model_with_memblocks(self.decoder, x, self.parallel, self.show_progress_bar) + if self.patch_size > 1: x = F.pixel_shuffle(x, self.patch_size) + return x[:, self.frames_to_trim:].movedim(2, 1) diff --git a/latent_preview.py b/latent_preview.py index ddf6dcf49..66bded4b9 100644 --- a/latent_preview.py +++ b/latent_preview.py @@ -2,17 +2,24 @@ import torch from PIL import Image from comfy.cli_args import args, LatentPreviewMethod from comfy.taesd.taesd import TAESD +from comfy.sd import VAE import comfy.model_management import folder_paths import comfy.utils import logging MAX_PREVIEW_RESOLUTION = args.preview_size +VIDEO_TAES = ["taehv", "lighttaew2_2", "lighttaew2_1", "lighttaehy1_5"] -def preview_to_image(latent_image): - latents_ubyte = (((latent_image + 1.0) / 2.0).clamp(0, 1) # change scale from -1..1 to 0..1 - .mul(0xFF) # to 0..255 - ) +def preview_to_image(latent_image, do_scale=True): + if do_scale: + latents_ubyte = (((latent_image + 1.0) / 2.0).clamp(0, 1) # change scale from -1..1 to 0..1 + .mul(0xFF) # to 0..255 + ) + else: + latents_ubyte = (latent_image.clamp(0, 1) + .mul(0xFF) # to 0..255 + ) if comfy.model_management.directml_enabled: latents_ubyte = latents_ubyte.to(dtype=torch.uint8) latents_ubyte = latents_ubyte.to(device="cpu", dtype=torch.uint8, non_blocking=comfy.model_management.device_supports_non_blocking(latent_image.device)) @@ -35,6 +42,10 @@ class TAESDPreviewerImpl(LatentPreviewer): x_sample = self.taesd.decode(x0[:1])[0].movedim(0, 2) return preview_to_image(x_sample) +class TAEHVPreviewerImpl(TAESDPreviewerImpl): + def decode_latent_to_preview(self, x0): + x_sample = self.taesd.decode(x0[:1, :, :1])[0][0] + return preview_to_image(x_sample, do_scale=False) class Latent2RGBPreviewer(LatentPreviewer): def __init__(self, latent_rgb_factors, latent_rgb_factors_bias=None, latent_rgb_factors_reshape=None): @@ -81,8 +92,13 @@ def get_previewer(device, latent_format): if method == LatentPreviewMethod.TAESD: if taesd_decoder_path: - taesd = TAESD(None, taesd_decoder_path, latent_channels=latent_format.latent_channels).to(device) - previewer = TAESDPreviewerImpl(taesd) + if latent_format.taesd_decoder_name in VIDEO_TAES: + taesd = VAE(comfy.utils.load_torch_file(taesd_decoder_path)) + taesd.first_stage_model.show_progress_bar = False + previewer = TAEHVPreviewerImpl(taesd) + else: + taesd = TAESD(None, taesd_decoder_path, latent_channels=latent_format.latent_channels).to(device) + previewer = TAESDPreviewerImpl(taesd) else: logging.warning("Warning: TAESD previews enabled, but could not find models/vae_approx/{}".format(latent_format.taesd_decoder_name)) diff --git a/nodes.py b/nodes.py index bf73eb90e..495dec806 100644 --- a/nodes.py +++ b/nodes.py @@ -692,8 +692,10 @@ class LoraLoaderModelOnly(LoraLoader): return (self.load_lora(model, None, lora_name, strength_model, 0)[0],) class VAELoader: + video_taes = ["taehv", "lighttaew2_2", "lighttaew2_1", "lighttaehy1_5"] + image_taes = ["taesd", "taesdxl", "taesd3", "taef1"] @staticmethod - def vae_list(): + def vae_list(s): vaes = folder_paths.get_filename_list("vae") approx_vaes = folder_paths.get_filename_list("vae_approx") sdxl_taesd_enc = False @@ -722,6 +724,11 @@ class VAELoader: f1_taesd_dec = True elif v.startswith("taef1_decoder."): f1_taesd_enc = True + else: + for tae in s.video_taes: + if v.startswith(tae): + vaes.append(v) + if sd1_taesd_dec and sd1_taesd_enc: vaes.append("taesd") if sdxl_taesd_dec and sdxl_taesd_enc: @@ -765,7 +772,7 @@ class VAELoader: @classmethod def INPUT_TYPES(s): - return {"required": { "vae_name": (s.vae_list(), )}} + return {"required": { "vae_name": (s.vae_list(s), )}} RETURN_TYPES = ("VAE",) FUNCTION = "load_vae" @@ -776,10 +783,13 @@ class VAELoader: if vae_name == "pixel_space": sd = {} sd["pixel_space_vae"] = torch.tensor(1.0) - elif vae_name in ["taesd", "taesdxl", "taesd3", "taef1"]: + elif vae_name in self.image_taes: sd = self.load_taesd(vae_name) else: - vae_path = folder_paths.get_full_path_or_raise("vae", vae_name) + if os.path.splitext(vae_name)[0] in self.video_taes: + vae_path = folder_paths.get_full_path_or_raise("vae_approx", vae_name) + else: + vae_path = folder_paths.get_full_path_or_raise("vae", vae_name) sd = comfy.utils.load_torch_file(vae_path) vae = comfy.sd.VAE(sd=sd) vae.throw_exception_if_invalid() From 52a32e2b323b90295ab05a8c299590c890f2ecb6 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Fri, 28 Nov 2025 18:12:42 -0800 Subject: [PATCH 175/231] Support some z image lora formats. (#10978) --- comfy/lora.py | 8 ++++++++ comfy/utils.py | 45 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 53 insertions(+) diff --git a/comfy/lora.py b/comfy/lora.py index 36d26293a..360cd128f 100644 --- a/comfy/lora.py +++ b/comfy/lora.py @@ -313,6 +313,14 @@ def model_lora_keys_unet(model, key_map={}): key_map["transformer.{}".format(key_lora)] = k key_map["lycoris_{}".format(key_lora.replace(".", "_"))] = k #SimpleTuner lycoris format + if isinstance(model, comfy.model_base.Lumina2): + diffusers_keys = comfy.utils.z_image_to_diffusers(model.model_config.unet_config, output_prefix="diffusion_model.") + for k in diffusers_keys: + to = diffusers_keys[k] + key_lora = k[:-len(".weight")] + key_map["diffusion_model.{}".format(key_lora)] = to + key_map["lycoris_{}".format(key_lora.replace(".", "_"))] = to + return key_map diff --git a/comfy/utils.py b/comfy/utils.py index 4bd281057..21bd6e8cf 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -675,6 +675,51 @@ def flux_to_diffusers(mmdit_config, output_prefix=""): return key_map +def z_image_to_diffusers(mmdit_config, output_prefix=""): + n_layers = mmdit_config.get("n_layers", 0) + hidden_size = mmdit_config.get("dim", 0) + + key_map = {} + + for index in range(n_layers): + prefix_from = "layers.{}".format(index) + prefix_to = "{}layers.{}".format(output_prefix, index) + + for end in ("weight", "bias"): + k = "{}.attention.".format(prefix_from) + qkv = "{}.attention.qkv.{}".format(prefix_to, end) + + key_map["{}to_q.{}".format(k, end)] = (qkv, (0, 0, hidden_size)) + key_map["{}to_k.{}".format(k, end)] = (qkv, (0, hidden_size, hidden_size)) + key_map["{}to_v.{}".format(k, end)] = (qkv, (0, hidden_size * 2, hidden_size)) + + block_map = { + "attention.norm_q.weight": "attention.q_norm.weight", + "attention.norm_k.weight": "attention.k_norm.weight", + "attention.to_out.0.weight": "attention.out.weight", + "attention.to_out.0.bias": "attention.out.bias", + } + + for k in block_map: + key_map["{}.{}".format(prefix_from, k)] = "{}.{}".format(prefix_to, block_map[k]) + + MAP_BASIC = { + # Final layer + ("final_layer.linear.weight", "all_final_layer.2-1.linear.weight"), + ("final_layer.linear.bias", "all_final_layer.2-1.linear.bias"), + ("final_layer.adaLN_modulation.1.weight", "all_final_layer.2-1.adaLN_modulation.1.weight"), + ("final_layer.adaLN_modulation.1.bias", "all_final_layer.2-1.adaLN_modulation.1.bias"), + # X embedder + ("x_embedder.weight", "all_x_embedder.2-1.weight"), + ("x_embedder.bias", "all_x_embedder.2-1.bias"), + } + + for k in MAP_BASIC: + key_map[k[1]] = "{}{}".format(output_prefix, k[0]) + + return key_map + + def repeat_to_batch_size(tensor, batch_size, dim=0): if tensor.shape[dim] > batch_size: return tensor.narrow(dim, 0, batch_size) From af96d9812de3e420abd43275d9a5960535b6333c Mon Sep 17 00:00:00 2001 From: "Dr.Lt.Data" <128333288+ltdrdata@users.noreply.github.com> Date: Sat, 29 Nov 2025 11:28:42 +0900 Subject: [PATCH 176/231] feat(security): add System User protection with `__` prefix (#10966) * feat(security): add System User protection with `__` prefix Add protected namespace for custom nodes to store sensitive data (API keys, licenses) that cannot be accessed via HTTP endpoints. Key changes: - New API: get_system_user_directory() for internal access - New API: get_public_user_directory() with structural blocking - 3-layer defense: header validation, path blocking, creation prevention - 54 tests covering security, edge cases, and backward compatibility System Users use `__` prefix (e.g., __system, __cache) following Python's private member convention. They exist in user_directory/ but are completely blocked from /userdata HTTP endpoints. * style: remove unused imports --- app/user_manager.py | 21 +- folder_paths.py | 65 +++ .../app_test/user_manager_system_user_test.py | 193 +++++++++ .../folder_paths_test/system_user_test.py | 206 ++++++++++ .../system_user_endpoint_test.py | 375 ++++++++++++++++++ 5 files changed, 855 insertions(+), 5 deletions(-) create mode 100644 tests-unit/app_test/user_manager_system_user_test.py create mode 100644 tests-unit/folder_paths_test/system_user_test.py create mode 100644 tests-unit/prompt_server_test/system_user_endpoint_test.py diff --git a/app/user_manager.py b/app/user_manager.py index 675f6c0c6..e2c00dab2 100644 --- a/app/user_manager.py +++ b/app/user_manager.py @@ -59,6 +59,9 @@ class UserManager(): user = "default" if args.multi_user and "comfy-user" in request.headers: user = request.headers["comfy-user"] + # Block System Users (use same error message to prevent probing) + if user.startswith(folder_paths.SYSTEM_USER_PREFIX): + raise KeyError("Unknown user: " + user) if user not in self.users: raise KeyError("Unknown user: " + user) @@ -66,15 +69,16 @@ class UserManager(): return user def get_request_user_filepath(self, request, file, type="userdata", create_dir=True): - user_directory = folder_paths.get_user_directory() - if type == "userdata": - root_dir = user_directory + root_dir = folder_paths.get_user_directory() else: raise KeyError("Unknown filepath type:" + type) user = self.get_request_user_id(request) - path = user_root = os.path.abspath(os.path.join(root_dir, user)) + user_root = folder_paths.get_public_user_directory(user) + if user_root is None: + return None + path = user_root # prevent leaving /{type} if os.path.commonpath((root_dir, user_root)) != root_dir: @@ -101,7 +105,11 @@ class UserManager(): name = name.strip() if not name: raise ValueError("username not provided") + if name.startswith(folder_paths.SYSTEM_USER_PREFIX): + raise ValueError("System User prefix not allowed") user_id = re.sub("[^a-zA-Z0-9-_]+", '-', name) + if user_id.startswith(folder_paths.SYSTEM_USER_PREFIX): + raise ValueError("System User prefix not allowed") user_id = user_id + "_" + str(uuid.uuid4()) self.users[user_id] = name @@ -132,7 +140,10 @@ class UserManager(): if username in self.users.values(): return web.json_response({"error": "Duplicate username."}, status=400) - user_id = self.add_user(username) + try: + user_id = self.add_user(username) + except ValueError as e: + return web.json_response({"error": str(e)}, status=400) return web.json_response(user_id) @routes.get("/userdata") diff --git a/folder_paths.py b/folder_paths.py index ffdc4d020..9c96540e3 100644 --- a/folder_paths.py +++ b/folder_paths.py @@ -137,6 +137,71 @@ def set_user_directory(user_dir: str) -> None: user_directory = user_dir +# System User Protection - Protects system directories from HTTP endpoint access +# System Users are internal-only users that cannot be accessed via HTTP endpoints. +# They use the '__' prefix convention (similar to Python's private member convention). +SYSTEM_USER_PREFIX = "__" + + +def get_system_user_directory(name: str = "system") -> str: + """ + Get the path to a System User directory. + + System User directories (prefixed with '__') are only accessible via internal API, + not through HTTP endpoints. Use this for storing system-internal data that + should not be exposed to users. + + Args: + name: System user name (e.g., "system", "cache"). Must be alphanumeric + with underscores allowed, but cannot start with underscore. + + Returns: + Absolute path to the system user directory. + + Raises: + ValueError: If name is empty, invalid, or starts with underscore. + + Example: + >>> get_system_user_directory("cache") + '/path/to/user/__cache' + """ + if not name or not isinstance(name, str): + raise ValueError("System user name cannot be empty") + if not name.replace("_", "").isalnum(): + raise ValueError(f"Invalid system user name: '{name}'") + if name.startswith("_"): + raise ValueError("System user name should not start with underscore") + return os.path.join(get_user_directory(), f"{SYSTEM_USER_PREFIX}{name}") + + +def get_public_user_directory(user_id: str) -> str | None: + """ + Get the path to a Public User directory for HTTP endpoint access. + + This function provides structural security by returning None for any + System User (prefixed with '__'). All HTTP endpoints should use this + function instead of directly constructing user paths. + + Args: + user_id: User identifier from HTTP request. + + Returns: + Absolute path to the user directory, or None if user_id is invalid + or refers to a System User. + + Example: + >>> get_public_user_directory("default") + '/path/to/user/default' + >>> get_public_user_directory("__system") + None + """ + if not user_id or not isinstance(user_id, str): + return None + if user_id.startswith(SYSTEM_USER_PREFIX): + return None + return os.path.join(get_user_directory(), user_id) + + #NOTE: used in http server so don't put folders that should not be accessed remotely def get_directory_by_type(type_name: str) -> str | None: if type_name == "output": diff --git a/tests-unit/app_test/user_manager_system_user_test.py b/tests-unit/app_test/user_manager_system_user_test.py new file mode 100644 index 000000000..63b1ac5e5 --- /dev/null +++ b/tests-unit/app_test/user_manager_system_user_test.py @@ -0,0 +1,193 @@ +"""Tests for System User Protection in user_manager.py + +Tests cover: +- get_request_user_id(): 1st defense layer - blocks System Users from HTTP headers +- get_request_user_filepath(): 2nd defense layer - structural blocking via get_public_user_directory() +- add_user(): 3rd defense layer - prevents creation of System User names +- Defense layers integration tests +""" + +import pytest +from unittest.mock import MagicMock, patch +import tempfile + +import folder_paths +from app.user_manager import UserManager + + +@pytest.fixture +def mock_user_directory(): + """Create a temporary user directory.""" + with tempfile.TemporaryDirectory() as temp_dir: + original_dir = folder_paths.get_user_directory() + folder_paths.set_user_directory(temp_dir) + yield temp_dir + folder_paths.set_user_directory(original_dir) + + +@pytest.fixture +def user_manager(mock_user_directory): + """Create a UserManager instance for testing.""" + with patch('app.user_manager.args') as mock_args: + mock_args.multi_user = True + manager = UserManager() + # Add a default user for testing + manager.users = {"default": "default", "test_user_123": "Test User"} + yield manager + + +@pytest.fixture +def mock_request(): + """Create a mock request object.""" + request = MagicMock() + request.headers = {} + return request + + +class TestGetRequestUserId: + """Tests for get_request_user_id() - 1st defense layer. + + Verifies: + - System Users (__ prefix) in HTTP header are rejected with KeyError + - Public Users pass through successfully + """ + + def test_system_user_raises_error(self, user_manager, mock_request): + """Test System User in header raises KeyError.""" + mock_request.headers = {"comfy-user": "__system"} + + with patch('app.user_manager.args') as mock_args: + mock_args.multi_user = True + with pytest.raises(KeyError, match="Unknown user"): + user_manager.get_request_user_id(mock_request) + + def test_system_user_cache_raises_error(self, user_manager, mock_request): + """Test System User cache raises KeyError.""" + mock_request.headers = {"comfy-user": "__cache"} + + with patch('app.user_manager.args') as mock_args: + mock_args.multi_user = True + with pytest.raises(KeyError, match="Unknown user"): + user_manager.get_request_user_id(mock_request) + + def test_normal_user_works(self, user_manager, mock_request): + """Test normal user access works.""" + mock_request.headers = {"comfy-user": "default"} + + with patch('app.user_manager.args') as mock_args: + mock_args.multi_user = True + user_id = user_manager.get_request_user_id(mock_request) + assert user_id == "default" + + def test_unknown_user_raises_error(self, user_manager, mock_request): + """Test unknown user raises KeyError.""" + mock_request.headers = {"comfy-user": "unknown_user"} + + with patch('app.user_manager.args') as mock_args: + mock_args.multi_user = True + with pytest.raises(KeyError, match="Unknown user"): + user_manager.get_request_user_id(mock_request) + + +class TestGetRequestUserFilepath: + """Tests for get_request_user_filepath() - 2nd defense layer. + + Verifies: + - Returns None when get_public_user_directory() returns None (System User) + - Acts as backup defense if 1st layer is bypassed + """ + + def test_system_user_returns_none(self, user_manager, mock_request, mock_user_directory): + """Test System User returns None (structural blocking).""" + # First, we need to mock get_request_user_id to return System User + # But actually, get_request_user_id will raise KeyError first + # So we test via get_public_user_directory returning None + mock_request.headers = {"comfy-user": "default"} + + with patch('app.user_manager.args') as mock_args: + mock_args.multi_user = True + # Patch get_public_user_directory to return None for testing + with patch.object(folder_paths, 'get_public_user_directory', return_value=None): + result = user_manager.get_request_user_filepath(mock_request, "test.txt") + assert result is None + + def test_normal_user_gets_path(self, user_manager, mock_request, mock_user_directory): + """Test normal user gets valid filepath.""" + mock_request.headers = {"comfy-user": "default"} + + with patch('app.user_manager.args') as mock_args: + mock_args.multi_user = True + path = user_manager.get_request_user_filepath(mock_request, "test.txt") + assert path is not None + assert "default" in path + assert path.endswith("test.txt") + + +class TestAddUser: + """Tests for add_user() - 3rd defense layer (creation-time blocking). + + Verifies: + - System User name (__ prefix) creation is rejected with ValueError + - Sanitized usernames that become System User are also rejected + """ + + def test_system_user_prefix_name_raises(self, user_manager): + """Test System User prefix in name raises ValueError.""" + with pytest.raises(ValueError, match="System User prefix not allowed"): + user_manager.add_user("__system") + + def test_system_user_prefix_cache_raises(self, user_manager): + """Test System User cache prefix raises ValueError.""" + with pytest.raises(ValueError, match="System User prefix not allowed"): + user_manager.add_user("__cache") + + def test_sanitized_system_user_prefix_raises(self, user_manager): + """Test sanitized name becoming System User prefix raises ValueError (bypass prevention).""" + # "__test" directly starts with System User prefix + with pytest.raises(ValueError, match="System User prefix not allowed"): + user_manager.add_user("__test") + + def test_normal_user_creation(self, user_manager, mock_user_directory): + """Test normal user creation works.""" + user_id = user_manager.add_user("Normal User") + assert user_id is not None + assert not user_id.startswith("__") + assert "Normal-User" in user_id or "Normal_User" in user_id + + def test_empty_name_raises(self, user_manager): + """Test empty name raises ValueError.""" + with pytest.raises(ValueError, match="username not provided"): + user_manager.add_user("") + + def test_whitespace_only_raises(self, user_manager): + """Test whitespace-only name raises ValueError.""" + with pytest.raises(ValueError, match="username not provided"): + user_manager.add_user(" ") + + +class TestDefenseLayers: + """Integration tests for all three defense layers. + + Verifies: + - Each defense layer blocks System Users independently + - System User bypass is impossible through any layer + """ + + def test_layer1_get_request_user_id(self, user_manager, mock_request): + """Test 1st defense layer blocks System Users.""" + mock_request.headers = {"comfy-user": "__system"} + + with patch('app.user_manager.args') as mock_args: + mock_args.multi_user = True + with pytest.raises(KeyError): + user_manager.get_request_user_id(mock_request) + + def test_layer2_get_public_user_directory(self): + """Test 2nd defense layer blocks System Users.""" + result = folder_paths.get_public_user_directory("__system") + assert result is None + + def test_layer3_add_user(self, user_manager): + """Test 3rd defense layer blocks System User creation.""" + with pytest.raises(ValueError): + user_manager.add_user("__system") diff --git a/tests-unit/folder_paths_test/system_user_test.py b/tests-unit/folder_paths_test/system_user_test.py new file mode 100644 index 000000000..cd46459f1 --- /dev/null +++ b/tests-unit/folder_paths_test/system_user_test.py @@ -0,0 +1,206 @@ +"""Tests for System User Protection in folder_paths.py + +Tests cover: +- get_system_user_directory(): Internal API for custom nodes to access System User directories +- get_public_user_directory(): HTTP endpoint access with System User blocking +- Backward compatibility: Existing APIs unchanged +- Security: Path traversal and injection prevention +""" + +import pytest +import os +import tempfile + +from folder_paths import ( + get_system_user_directory, + get_public_user_directory, + get_user_directory, + set_user_directory, +) + + +@pytest.fixture(scope="module") +def mock_user_directory(): + """Create a temporary user directory for testing.""" + with tempfile.TemporaryDirectory() as temp_dir: + original_dir = get_user_directory() + set_user_directory(temp_dir) + yield temp_dir + set_user_directory(original_dir) + + +class TestGetSystemUserDirectory: + """Tests for get_system_user_directory() - internal API for System User directories. + + Verifies: + - Custom nodes can access System User directories via internal API + - Input validation prevents path traversal attacks + """ + + def test_default_name(self, mock_user_directory): + """Test default 'system' name.""" + path = get_system_user_directory() + assert path.endswith("__system") + assert mock_user_directory in path + + def test_custom_name(self, mock_user_directory): + """Test custom system user name.""" + path = get_system_user_directory("cache") + assert path.endswith("__cache") + assert "__cache" in path + + def test_name_with_underscore(self, mock_user_directory): + """Test name with underscore in middle.""" + path = get_system_user_directory("my_cache") + assert "__my_cache" in path + + def test_empty_name_raises(self): + """Test empty name raises ValueError.""" + with pytest.raises(ValueError, match="cannot be empty"): + get_system_user_directory("") + + def test_none_name_raises(self): + """Test None name raises ValueError.""" + with pytest.raises(ValueError, match="cannot be empty"): + get_system_user_directory(None) + + def test_name_starting_with_underscore_raises(self): + """Test name starting with underscore raises ValueError.""" + with pytest.raises(ValueError, match="should not start with underscore"): + get_system_user_directory("_system") + + def test_path_traversal_raises(self): + """Test path traversal attempt raises ValueError (security).""" + with pytest.raises(ValueError, match="Invalid system user name"): + get_system_user_directory("../escape") + + def test_path_traversal_middle_raises(self): + """Test path traversal in middle raises ValueError (security).""" + with pytest.raises(ValueError, match="Invalid system user name"): + get_system_user_directory("system/../other") + + def test_special_chars_raise(self): + """Test special characters raise ValueError (security).""" + with pytest.raises(ValueError, match="Invalid system user name"): + get_system_user_directory("system!") + + def test_returns_absolute_path(self, mock_user_directory): + """Test returned path is absolute.""" + path = get_system_user_directory("test") + assert os.path.isabs(path) + + +class TestGetPublicUserDirectory: + """Tests for get_public_user_directory() - HTTP endpoint access with System User blocking. + + Verifies: + - System Users (__ prefix) return None, blocking HTTP access + - Public Users get valid paths + - New endpoints using this function are automatically protected + """ + + def test_normal_user(self, mock_user_directory): + """Test normal user returns valid path.""" + path = get_public_user_directory("default") + assert path is not None + assert "default" in path + assert mock_user_directory in path + + def test_system_user_returns_none(self): + """Test System User (__ prefix) returns None - blocks HTTP access.""" + assert get_public_user_directory("__system") is None + + def test_system_user_cache_returns_none(self): + """Test System User cache returns None.""" + assert get_public_user_directory("__cache") is None + + def test_empty_user_returns_none(self): + """Test empty user returns None.""" + assert get_public_user_directory("") is None + + def test_none_user_returns_none(self): + """Test None user returns None.""" + assert get_public_user_directory(None) is None + + def test_header_injection_returns_none(self): + """Test header injection attempt returns None (security).""" + assert get_public_user_directory("__system\r\nX-Injected: true") is None + + def test_null_byte_injection_returns_none(self): + """Test null byte injection handling (security).""" + # Note: startswith check happens before any path operations + result = get_public_user_directory("user\x00__system") + # This should return a path since it doesn't start with __ + # The actual security comes from the path not being __* + assert result is not None or result is None # Depends on validation + + def test_path_traversal_attempt(self, mock_user_directory): + """Test path traversal attempt handling.""" + # This function doesn't validate paths, only reserved prefix + # Path traversal should be handled by the caller + path = get_public_user_directory("../../../etc/passwd") + # Returns path but doesn't start with __, so not None + # Actual path validation happens in user_manager + assert path is not None or "__" not in "../../../etc/passwd" + + def test_returns_absolute_path(self, mock_user_directory): + """Test returned path is absolute.""" + path = get_public_user_directory("testuser") + assert path is not None + assert os.path.isabs(path) + + +class TestBackwardCompatibility: + """Tests for backward compatibility with existing APIs. + + Verifies: + - get_user_directory() API unchanged + - Existing user data remains accessible + """ + + def test_get_user_directory_unchanged(self, mock_user_directory): + """Test get_user_directory() still works as before.""" + user_dir = get_user_directory() + assert user_dir is not None + assert os.path.isabs(user_dir) + assert user_dir == mock_user_directory + + def test_existing_user_accessible(self, mock_user_directory): + """Test existing users can access their directories.""" + path = get_public_user_directory("default") + assert path is not None + assert "default" in path + + +class TestEdgeCases: + """Tests for edge cases in System User detection. + + Verifies: + - Only __ prefix is blocked (not _, not middle __) + - Bypass attempts are prevented + """ + + def test_prefix_only(self): + """Test prefix-only string is blocked.""" + assert get_public_user_directory("__") is None + + def test_single_underscore_allowed(self): + """Test single underscore prefix is allowed (not System User).""" + path = get_public_user_directory("_system") + assert path is not None + assert "_system" in path + + def test_triple_underscore_blocked(self): + """Test triple underscore is blocked (starts with __).""" + assert get_public_user_directory("___system") is None + + def test_underscore_in_middle_allowed(self): + """Test underscore in middle is allowed.""" + path = get_public_user_directory("my__system") + assert path is not None + assert "my__system" in path + + def test_leading_space_allowed(self): + """Test leading space + prefix is allowed (doesn't start with __).""" + path = get_public_user_directory(" __system") + assert path is not None diff --git a/tests-unit/prompt_server_test/system_user_endpoint_test.py b/tests-unit/prompt_server_test/system_user_endpoint_test.py new file mode 100644 index 000000000..22ac00af9 --- /dev/null +++ b/tests-unit/prompt_server_test/system_user_endpoint_test.py @@ -0,0 +1,375 @@ +"""E2E Tests for System User Protection HTTP Endpoints + +Tests cover: +- HTTP endpoint blocking: System Users cannot access /userdata (GET, POST, DELETE, move) +- User creation blocking: System User names cannot be created via POST /users +- Backward compatibility: Public Users work as before +- Custom node scenario: Internal API works while HTTP is blocked +- Structural security: get_public_user_directory() provides automatic protection +""" + +import pytest +import os +from aiohttp import web +from app.user_manager import UserManager +from unittest.mock import patch +import folder_paths + + +@pytest.fixture +def mock_user_directory(tmp_path): + """Create a temporary user directory.""" + original_dir = folder_paths.get_user_directory() + folder_paths.set_user_directory(str(tmp_path)) + yield tmp_path + folder_paths.set_user_directory(original_dir) + + +@pytest.fixture +def user_manager_multi_user(mock_user_directory): + """Create UserManager in multi-user mode.""" + with patch('app.user_manager.args') as mock_args: + mock_args.multi_user = True + um = UserManager() + # Add test users + um.users = {"default": "default", "test_user_123": "Test User"} + yield um + + +@pytest.fixture +def app_multi_user(user_manager_multi_user): + """Create app with multi-user mode enabled.""" + app = web.Application() + routes = web.RouteTableDef() + user_manager_multi_user.add_routes(routes) + app.add_routes(routes) + return app + + +class TestSystemUserEndpointBlocking: + """E2E tests for System User blocking on all HTTP endpoints. + + Verifies: + - GET /userdata blocked for System Users + - POST /userdata blocked for System Users + - DELETE /userdata blocked for System Users + - POST /userdata/.../move/... blocked for System Users + """ + + @pytest.mark.asyncio + async def test_userdata_get_blocks_system_user( + self, aiohttp_client, app_multi_user, mock_user_directory + ): + """ + GET /userdata with System User header should be blocked. + """ + # Create test directory for System User (simulating internal creation) + system_user_dir = mock_user_directory / "__system" + system_user_dir.mkdir() + (system_user_dir / "secret.txt").write_text("sensitive data") + + client = await aiohttp_client(app_multi_user) + + with patch('app.user_manager.args') as mock_args: + mock_args.multi_user = True + # Attempt to access System User's data via HTTP + resp = await client.get( + "/userdata?dir=.", + headers={"comfy-user": "__system"} + ) + + # Should be blocked (403 Forbidden or similar error) + assert resp.status in [400, 403, 500], \ + f"System User access should be blocked, got {resp.status}" + + @pytest.mark.asyncio + async def test_userdata_post_blocks_system_user( + self, aiohttp_client, app_multi_user, mock_user_directory + ): + """ + POST /userdata with System User header should be blocked. + """ + client = await aiohttp_client(app_multi_user) + + with patch('app.user_manager.args') as mock_args: + mock_args.multi_user = True + resp = await client.post( + "/userdata/test.txt", + headers={"comfy-user": "__system"}, + data=b"malicious content" + ) + + assert resp.status in [400, 403, 500], \ + f"System User write should be blocked, got {resp.status}" + + # Verify no file was created + assert not (mock_user_directory / "__system" / "test.txt").exists() + + @pytest.mark.asyncio + async def test_userdata_delete_blocks_system_user( + self, aiohttp_client, app_multi_user, mock_user_directory + ): + """ + DELETE /userdata with System User header should be blocked. + """ + # Create a file in System User directory + system_user_dir = mock_user_directory / "__system" + system_user_dir.mkdir() + secret_file = system_user_dir / "secret.txt" + secret_file.write_text("do not delete") + + client = await aiohttp_client(app_multi_user) + + with patch('app.user_manager.args') as mock_args: + mock_args.multi_user = True + resp = await client.delete( + "/userdata/secret.txt", + headers={"comfy-user": "__system"} + ) + + assert resp.status in [400, 403, 500], \ + f"System User delete should be blocked, got {resp.status}" + + # Verify file still exists + assert secret_file.exists() + + @pytest.mark.asyncio + async def test_v2_userdata_blocks_system_user( + self, aiohttp_client, app_multi_user, mock_user_directory + ): + """ + GET /v2/userdata with System User header should be blocked. + """ + client = await aiohttp_client(app_multi_user) + + with patch('app.user_manager.args') as mock_args: + mock_args.multi_user = True + resp = await client.get( + "/v2/userdata", + headers={"comfy-user": "__system"} + ) + + assert resp.status in [400, 403, 500], \ + f"System User v2 access should be blocked, got {resp.status}" + + @pytest.mark.asyncio + async def test_move_userdata_blocks_system_user( + self, aiohttp_client, app_multi_user, mock_user_directory + ): + """ + POST /userdata/{file}/move/{dest} with System User header should be blocked. + """ + system_user_dir = mock_user_directory / "__system" + system_user_dir.mkdir() + (system_user_dir / "source.txt").write_text("sensitive data") + + client = await aiohttp_client(app_multi_user) + + with patch('app.user_manager.args') as mock_args: + mock_args.multi_user = True + resp = await client.post( + "/userdata/source.txt/move/dest.txt", + headers={"comfy-user": "__system"} + ) + + assert resp.status in [400, 403, 500], \ + f"System User move should be blocked, got {resp.status}" + + # Verify source file still exists (move was blocked) + assert (system_user_dir / "source.txt").exists() + + +class TestSystemUserCreationBlocking: + """E2E tests for blocking System User name creation via POST /users. + + Verifies: + - POST /users returns 400 for System User name (not 500) + """ + + @pytest.mark.asyncio + async def test_post_users_blocks_system_user_name( + self, aiohttp_client, app_multi_user + ): + """POST /users with System User name should return 400 Bad Request.""" + client = await aiohttp_client(app_multi_user) + + resp = await client.post( + "/users", + json={"username": "__system"} + ) + + assert resp.status == 400, \ + f"System User creation should return 400, got {resp.status}" + + @pytest.mark.asyncio + async def test_post_users_blocks_system_user_prefix_variations( + self, aiohttp_client, app_multi_user + ): + """POST /users with any System User prefix variation should return 400 Bad Request.""" + client = await aiohttp_client(app_multi_user) + + system_user_names = ["__system", "__cache", "__config", "__anything"] + + for name in system_user_names: + resp = await client.post("/users", json={"username": name}) + assert resp.status == 400, \ + f"System User name '{name}' should return 400, got {resp.status}" + + +class TestPublicUserStillWorks: + """E2E tests for backward compatibility - Public Users should work as before. + + Verifies: + - Public Users can access their data via HTTP + - Public Users can create files via HTTP + """ + + @pytest.mark.asyncio + async def test_public_user_can_access_userdata( + self, aiohttp_client, app_multi_user, mock_user_directory + ): + """ + Public Users should still be able to access their data. + """ + # Create test directory for Public User + user_dir = mock_user_directory / "default" + user_dir.mkdir() + test_dir = user_dir / "workflows" + test_dir.mkdir() + (test_dir / "test.json").write_text('{"test": true}') + + client = await aiohttp_client(app_multi_user) + + with patch('app.user_manager.args') as mock_args: + mock_args.multi_user = True + resp = await client.get( + "/userdata?dir=workflows", + headers={"comfy-user": "default"} + ) + + assert resp.status == 200 + data = await resp.json() + assert "test.json" in data + + @pytest.mark.asyncio + async def test_public_user_can_create_files( + self, aiohttp_client, app_multi_user, mock_user_directory + ): + """ + Public Users should still be able to create files. + """ + # Create user directory + user_dir = mock_user_directory / "default" + user_dir.mkdir() + + client = await aiohttp_client(app_multi_user) + + with patch('app.user_manager.args') as mock_args: + mock_args.multi_user = True + resp = await client.post( + "/userdata/newfile.txt", + headers={"comfy-user": "default"}, + data=b"user content" + ) + + assert resp.status == 200 + assert (user_dir / "newfile.txt").exists() + + +class TestCustomNodeScenario: + """Tests for custom node use case: internal API access vs HTTP blocking. + + Verifies: + - Internal API (get_system_user_directory) works for custom nodes + - HTTP endpoint cannot access data created via internal API + """ + + def test_internal_api_can_access_system_user(self, mock_user_directory): + """ + Internal API (get_system_user_directory) should work for custom nodes. + """ + # Custom node uses internal API + system_path = folder_paths.get_system_user_directory("mynode_config") + + assert system_path is not None + assert "__mynode_config" in system_path + + # Can create and write to System User directory + os.makedirs(system_path, exist_ok=True) + config_file = os.path.join(system_path, "settings.json") + with open(config_file, "w") as f: + f.write('{"api_key": "secret"}') + + assert os.path.exists(config_file) + + @pytest.mark.asyncio + async def test_http_cannot_access_internal_data( + self, aiohttp_client, app_multi_user, mock_user_directory + ): + """ + HTTP endpoint cannot access data created via internal API. + """ + # Custom node creates data via internal API + system_path = folder_paths.get_system_user_directory("mynode_config") + os.makedirs(system_path, exist_ok=True) + with open(os.path.join(system_path, "secret.json"), "w") as f: + f.write('{"api_key": "secret"}') + + client = await aiohttp_client(app_multi_user) + + # Attacker tries to access via HTTP + with patch('app.user_manager.args') as mock_args: + mock_args.multi_user = True + resp = await client.get( + "/userdata/secret.json", + headers={"comfy-user": "__mynode_config"} + ) + + # Should be blocked + assert resp.status in [400, 403, 500] + + +class TestStructuralSecurity: + """Tests for structural security pattern. + + Verifies: + - get_public_user_directory() automatically blocks System Users + - New endpoints using this function are automatically protected + """ + + def test_get_public_user_directory_blocks_system_user(self): + """ + Any code using get_public_user_directory() is automatically protected. + """ + # This is the structural security - any new endpoint using this function + # will automatically block System Users + assert folder_paths.get_public_user_directory("__system") is None + assert folder_paths.get_public_user_directory("__cache") is None + assert folder_paths.get_public_user_directory("__anything") is None + + # Public Users work + assert folder_paths.get_public_user_directory("default") is not None + assert folder_paths.get_public_user_directory("user123") is not None + + def test_structural_security_pattern(self, mock_user_directory): + """ + Demonstrate the structural security pattern for new endpoints. + + Any new endpoint should follow this pattern: + 1. Get user from request + 2. Use get_public_user_directory() - automatically blocks System Users + 3. If None, return error + """ + def new_endpoint_handler(user_id: str) -> str | None: + """Example of how new endpoints should be implemented.""" + user_path = folder_paths.get_public_user_directory(user_id) + if user_path is None: + return None # Blocked + return user_path + + # System Users are automatically blocked + assert new_endpoint_handler("__system") is None + assert new_endpoint_handler("__secret") is None + + # Public Users work + assert new_endpoint_handler("default") is not None From 5151cff293607c2191981fd16c62c1b1a6939695 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Fri, 28 Nov 2025 20:55:00 -0800 Subject: [PATCH 177/231] Add some missing z image lora layers. (#10980) --- comfy/lora.py | 9 +++++---- comfy/utils.py | 51 +++++++++++++++++++++++++++++++++++--------------- 2 files changed, 41 insertions(+), 19 deletions(-) diff --git a/comfy/lora.py b/comfy/lora.py index 360cd128f..3a9077869 100644 --- a/comfy/lora.py +++ b/comfy/lora.py @@ -316,10 +316,11 @@ def model_lora_keys_unet(model, key_map={}): if isinstance(model, comfy.model_base.Lumina2): diffusers_keys = comfy.utils.z_image_to_diffusers(model.model_config.unet_config, output_prefix="diffusion_model.") for k in diffusers_keys: - to = diffusers_keys[k] - key_lora = k[:-len(".weight")] - key_map["diffusion_model.{}".format(key_lora)] = to - key_map["lycoris_{}".format(key_lora.replace(".", "_"))] = to + if k.endswith(".weight"): + to = diffusers_keys[k] + key_lora = k[:-len(".weight")] + key_map["diffusion_model.{}".format(key_lora)] = to + key_map["lycoris_{}".format(key_lora.replace(".", "_"))] = to return key_map diff --git a/comfy/utils.py b/comfy/utils.py index 21bd6e8cf..37485e497 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -678,17 +678,14 @@ def flux_to_diffusers(mmdit_config, output_prefix=""): def z_image_to_diffusers(mmdit_config, output_prefix=""): n_layers = mmdit_config.get("n_layers", 0) hidden_size = mmdit_config.get("dim", 0) - + n_context_refiner = mmdit_config.get("n_refiner_layers", 2) + n_noise_refiner = mmdit_config.get("n_refiner_layers", 2) key_map = {} - for index in range(n_layers): - prefix_from = "layers.{}".format(index) - prefix_to = "{}layers.{}".format(output_prefix, index) - + def add_block_keys(prefix_from, prefix_to, has_adaln=True): for end in ("weight", "bias"): k = "{}.attention.".format(prefix_from) qkv = "{}.attention.qkv.{}".format(prefix_to, end) - key_map["{}to_q.{}".format(k, end)] = (qkv, (0, 0, hidden_size)) key_map["{}to_k.{}".format(k, end)] = (qkv, (0, hidden_size, hidden_size)) key_map["{}to_v.{}".format(k, end)] = (qkv, (0, hidden_size * 2, hidden_size)) @@ -698,28 +695,52 @@ def z_image_to_diffusers(mmdit_config, output_prefix=""): "attention.norm_k.weight": "attention.k_norm.weight", "attention.to_out.0.weight": "attention.out.weight", "attention.to_out.0.bias": "attention.out.bias", + "attention_norm1.weight": "attention_norm1.weight", + "attention_norm2.weight": "attention_norm2.weight", + "feed_forward.w1.weight": "feed_forward.w1.weight", + "feed_forward.w2.weight": "feed_forward.w2.weight", + "feed_forward.w3.weight": "feed_forward.w3.weight", + "ffn_norm1.weight": "ffn_norm1.weight", + "ffn_norm2.weight": "ffn_norm2.weight", } + if has_adaln: + block_map["adaLN_modulation.0.weight"] = "adaLN_modulation.0.weight" + block_map["adaLN_modulation.0.bias"] = "adaLN_modulation.0.bias" + for k, v in block_map.items(): + key_map["{}.{}".format(prefix_from, k)] = "{}.{}".format(prefix_to, v) - for k in block_map: - key_map["{}.{}".format(prefix_from, k)] = "{}.{}".format(prefix_to, block_map[k]) + for i in range(n_layers): + add_block_keys("layers.{}".format(i), "{}layers.{}".format(output_prefix, i)) - MAP_BASIC = { - # Final layer + for i in range(n_context_refiner): + add_block_keys("context_refiner.{}".format(i), "{}context_refiner.{}".format(output_prefix, i)) + + for i in range(n_noise_refiner): + add_block_keys("noise_refiner.{}".format(i), "{}noise_refiner.{}".format(output_prefix, i)) + + MAP_BASIC = [ ("final_layer.linear.weight", "all_final_layer.2-1.linear.weight"), ("final_layer.linear.bias", "all_final_layer.2-1.linear.bias"), ("final_layer.adaLN_modulation.1.weight", "all_final_layer.2-1.adaLN_modulation.1.weight"), ("final_layer.adaLN_modulation.1.bias", "all_final_layer.2-1.adaLN_modulation.1.bias"), - # X embedder ("x_embedder.weight", "all_x_embedder.2-1.weight"), ("x_embedder.bias", "all_x_embedder.2-1.bias"), - } + ("x_pad_token", "x_pad_token"), + ("cap_embedder.0.weight", "cap_embedder.0.weight"), + ("cap_embedder.1.weight", "cap_embedder.1.weight"), + ("cap_embedder.1.bias", "cap_embedder.1.bias"), + ("cap_pad_token", "cap_pad_token"), + ("t_embedder.mlp.0.weight", "t_embedder.mlp.0.weight"), + ("t_embedder.mlp.0.bias", "t_embedder.mlp.0.bias"), + ("t_embedder.mlp.2.weight", "t_embedder.mlp.2.weight"), + ("t_embedder.mlp.2.bias", "t_embedder.mlp.2.bias"), + ] - for k in MAP_BASIC: - key_map[k[1]] = "{}{}".format(output_prefix, k[0]) + for c, diffusers in MAP_BASIC: + key_map[diffusers] = "{}{}".format(output_prefix, c) return key_map - def repeat_to_batch_size(tensor, batch_size, dim=0): if tensor.shape[dim] > batch_size: return tensor.narrow(dim, 0, batch_size) From 0a6746898d6864d65e2fc7504e5e875f8c19c0ba Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sat, 29 Nov 2025 15:00:55 -0800 Subject: [PATCH 178/231] Make the ScaleRope node work on Z Image and Lumina. (#10994) --- comfy/ldm/lumina/model.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/comfy/ldm/lumina/model.py b/comfy/ldm/lumina/model.py index 565400b54..7d7e9112c 100644 --- a/comfy/ldm/lumina/model.py +++ b/comfy/ldm/lumina/model.py @@ -517,11 +517,23 @@ class NextDiT(nn.Module): B, C, H, W = x.shape x = self.x_embedder(x.view(B, C, H // pH, pH, W // pW, pW).permute(0, 2, 4, 3, 5, 1).flatten(3).flatten(1, 2)) + rope_options = transformer_options.get("rope_options", None) + h_scale = 1.0 + w_scale = 1.0 + h_start = 0 + w_start = 0 + if rope_options is not None: + h_scale = rope_options.get("scale_y", 1.0) + w_scale = rope_options.get("scale_x", 1.0) + + h_start = rope_options.get("shift_y", 0.0) + w_start = rope_options.get("shift_x", 0.0) + H_tokens, W_tokens = H // pH, W // pW x_pos_ids = torch.zeros((bsz, x.shape[1], 3), dtype=torch.float32, device=device) x_pos_ids[:, :, 0] = cap_feats.shape[1] + 1 - x_pos_ids[:, :, 1] = torch.arange(H_tokens, dtype=torch.float32, device=device).view(-1, 1).repeat(1, W_tokens).flatten() - x_pos_ids[:, :, 2] = torch.arange(W_tokens, dtype=torch.float32, device=device).view(1, -1).repeat(H_tokens, 1).flatten() + x_pos_ids[:, :, 1] = (torch.arange(H_tokens, dtype=torch.float32, device=device) * h_scale + h_start).view(-1, 1).repeat(1, W_tokens).flatten() + x_pos_ids[:, :, 2] = (torch.arange(W_tokens, dtype=torch.float32, device=device) * w_scale + w_start).view(1, -1).repeat(H_tokens, 1).flatten() if self.pad_tokens_multiple is not None: pad_extra = (-x.shape[1]) % self.pad_tokens_multiple From 4967f81778f84b41acc40ed03536dd71dd88e5f2 Mon Sep 17 00:00:00 2001 From: ComfyUI Wiki Date: Sun, 30 Nov 2025 10:07:26 +0800 Subject: [PATCH 179/231] update template to 0.7.25 (#10996) * update template to 0.7.24 * Update template to 0.7.25 --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index e0b2c566b..386477808 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.32.9 -comfyui-workflow-templates==0.7.23 +comfyui-workflow-templates==0.7.25 comfyui-embedded-docs==0.3.1 torch torchsde From f8b981ae9a5676311624bbafa636a1874db79459 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sun, 30 Nov 2025 01:21:31 -0800 Subject: [PATCH 180/231] Next AMD portable will have pytorch with ROCm 7.1.1 (#11002) --- .github/workflows/release-stable-all.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/release-stable-all.yml b/.github/workflows/release-stable-all.yml index 9274b4170..d72ece2ce 100644 --- a/.github/workflows/release-stable-all.yml +++ b/.github/workflows/release-stable-all.yml @@ -65,11 +65,11 @@ jobs: contents: "write" packages: "write" pull-requests: "read" - name: "Release AMD ROCm 6.4.4" + name: "Release AMD ROCm 7.1.1" uses: ./.github/workflows/stable-release.yml with: git_tag: ${{ inputs.git_tag }} - cache_tag: "rocm644" + cache_tag: "rocm711" python_minor: "12" python_patch: "10" rel_name: "amd" From 7dbd5dfe91f057b83dcba0c127f712f6d71f7def Mon Sep 17 00:00:00 2001 From: Christian Byrne Date: Mon, 1 Dec 2025 10:27:17 -0800 Subject: [PATCH 181/231] bump comfyui-frontend-package to 1.32.10 (#11018) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 386477808..045b2ac54 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -comfyui-frontend-package==1.32.9 +comfyui-frontend-package==1.32.10 comfyui-workflow-templates==0.7.25 comfyui-embedded-docs==0.3.1 torch From 2640acb31ccfddee57ba22d5245bf456e8dffe53 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 1 Dec 2025 14:13:48 -0800 Subject: [PATCH 182/231] Update qwen tokenizer to add qwen 3 tokens. (#11029) Doesn't actually change anything for current workflows because none of the current models have a template with the think tokens. --- .../qwen25_tokenizer/tokenizer_config.json | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/comfy/text_encoders/qwen25_tokenizer/tokenizer_config.json b/comfy/text_encoders/qwen25_tokenizer/tokenizer_config.json index 67688e82c..df5b5d7fe 100644 --- a/comfy/text_encoders/qwen25_tokenizer/tokenizer_config.json +++ b/comfy/text_encoders/qwen25_tokenizer/tokenizer_config.json @@ -179,36 +179,36 @@ "special": false }, "151665": { - "content": "<|img|>", + "content": "", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, - "special": true + "special": false }, "151666": { - "content": "<|endofimg|>", + "content": "", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, - "special": true + "special": false }, "151667": { - "content": "<|meta|>", + "content": "", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, - "special": true + "special": false }, "151668": { - "content": "<|endofmeta|>", + "content": "", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, - "special": true + "special": false } }, "additional_special_tokens": [ From 1cb7e22a95701f2619d1ddf5683ea221b58a0c13 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Tue, 2 Dec 2025 02:11:52 +0200 Subject: [PATCH 183/231] [API Nodes] add Kling O1 model support (#11025) * feat(api-nodes): add Kling O1 model support * fix: increase max allowed duration to 10.05 seconds * fix(VideoInput): respect "format" argument --- comfy_api/latest/_input_impl/video_types.py | 5 +- comfy_api_nodes/apis/kling_api.py | 66 +++ comfy_api_nodes/nodes_kling.py | 444 +++++++++++++++++++- comfy_api_nodes/util/upload_helpers.py | 3 +- 4 files changed, 499 insertions(+), 19 deletions(-) create mode 100644 comfy_api_nodes/apis/kling_api.py diff --git a/comfy_api/latest/_input_impl/video_types.py b/comfy_api/latest/_input_impl/video_types.py index bde37f90a..7231bf13c 100644 --- a/comfy_api/latest/_input_impl/video_types.py +++ b/comfy_api/latest/_input_impl/video_types.py @@ -336,7 +336,10 @@ class VideoFromComponents(VideoInput): raise ValueError("Only MP4 format is supported for now") if codec != VideoCodec.AUTO and codec != VideoCodec.H264: raise ValueError("Only H264 codec is supported for now") - with av.open(path, mode='w', options={'movflags': 'use_metadata_tags'}) as output: + extra_kwargs = {} + if format != VideoContainer.AUTO: + extra_kwargs["format"] = format.value + with av.open(path, mode='w', options={'movflags': 'use_metadata_tags'}, **extra_kwargs) as output: # Add metadata before writing any streams if metadata is not None: for key, value in metadata.items(): diff --git a/comfy_api_nodes/apis/kling_api.py b/comfy_api_nodes/apis/kling_api.py new file mode 100644 index 000000000..0a3b447c5 --- /dev/null +++ b/comfy_api_nodes/apis/kling_api.py @@ -0,0 +1,66 @@ +from pydantic import BaseModel, Field + + +class OmniProText2VideoRequest(BaseModel): + model_name: str = Field(..., description="kling-video-o1") + aspect_ratio: str = Field(..., description="'16:9', '9:16' or '1:1'") + duration: str = Field(..., description="'5' or '10'") + prompt: str = Field(...) + mode: str = Field("pro") + + +class OmniParamImage(BaseModel): + image_url: str = Field(...) + type: str | None = Field(None, description="Can be 'first_frame' or 'end_frame'") + + +class OmniParamVideo(BaseModel): + video_url: str = Field(...) + refer_type: str | None = Field(..., description="Can be 'base' or 'feature'") + keep_original_sound: str = Field(..., description="'yes' or 'no'") + + +class OmniProFirstLastFrameRequest(BaseModel): + model_name: str = Field(..., description="kling-video-o1") + image_list: list[OmniParamImage] = Field(..., min_length=1, max_length=7) + duration: str = Field(..., description="'5' or '10'") + prompt: str = Field(...) + mode: str = Field("pro") + + +class OmniProReferences2VideoRequest(BaseModel): + model_name: str = Field(..., description="kling-video-o1") + aspect_ratio: str | None = Field(..., description="'16:9', '9:16' or '1:1'") + image_list: list[OmniParamImage] | None = Field( + None, max_length=7, description="Max length 4 when video is present." + ) + video_list: list[OmniParamVideo] | None = Field(None, max_length=1) + duration: str | None = Field(..., description="From 3 to 10.") + prompt: str = Field(...) + mode: str = Field("pro") + + +class TaskStatusVideoResult(BaseModel): + duration: str | None = Field(None, description="Total video duration") + id: str | None = Field(None, description="Generated video ID") + url: str | None = Field(None, description="URL for generated video") + + +class TaskStatusVideoResults(BaseModel): + videos: list[TaskStatusVideoResult] | None = Field(None) + + +class TaskStatusVideoResponseData(BaseModel): + created_at: int | None = Field(None, description="Task creation time") + updated_at: int | None = Field(None, description="Task update time") + task_status: str | None = None + task_status_msg: str | None = Field(None, description="Additional failure reason. Only for polling endpoint.") + task_id: str | None = Field(None, description="Task ID") + task_result: TaskStatusVideoResults | None = Field(None) + + +class TaskStatusVideoResponse(BaseModel): + code: int | None = Field(None, description="Error code") + message: str | None = Field(None, description="Error message") + request_id: str | None = Field(None, description="Request ID") + data: TaskStatusVideoResponseData | None = Field(None) diff --git a/comfy_api_nodes/nodes_kling.py b/comfy_api_nodes/nodes_kling.py index 23a7f55f1..850c44db6 100644 --- a/comfy_api_nodes/nodes_kling.py +++ b/comfy_api_nodes/nodes_kling.py @@ -4,13 +4,13 @@ For source of truth on the allowed permutations of request fields, please refere - [Compatibility Table](https://app.klingai.com/global/dev/document-api/apiReference/model/skillsMap) """ -import math import logging - -from typing_extensions import override +import math import torch +from typing_extensions import override +from comfy_api.latest import IO, ComfyExtension, Input, InputImpl from comfy_api_nodes.apis import ( KlingCameraControl, KlingCameraConfig, @@ -48,23 +48,31 @@ from comfy_api_nodes.apis import ( KlingCharacterEffectModelName, KlingSingleImageEffectModelName, ) +from comfy_api_nodes.apis.kling_api import ( + OmniParamImage, + OmniParamVideo, + OmniProFirstLastFrameRequest, + OmniProReferences2VideoRequest, + OmniProText2VideoRequest, + TaskStatusVideoResponse, +) from comfy_api_nodes.util import ( - validate_image_dimensions, + ApiEndpoint, + download_url_to_image_tensor, + download_url_to_video_output, + get_number_of_images, + poll_op, + sync_op, + tensor_to_base64_string, + upload_audio_to_comfyapi, + upload_images_to_comfyapi, + upload_video_to_comfyapi, validate_image_aspect_ratio, + validate_image_dimensions, + validate_string, validate_video_dimensions, validate_video_duration, - tensor_to_base64_string, - validate_string, - upload_audio_to_comfyapi, - download_url_to_image_tensor, - upload_video_to_comfyapi, - download_url_to_video_output, - sync_op, - ApiEndpoint, - poll_op, ) -from comfy_api.input_impl import VideoFromFile -from comfy_api.latest import ComfyExtension, IO, Input KLING_API_VERSION = "v1" PATH_TEXT_TO_VIDEO = f"/proxy/kling/{KLING_API_VERSION}/videos/text2video" @@ -202,6 +210,20 @@ VOICES_CONFIG = { } +async def finish_omni_video_task(cls: type[IO.ComfyNode], response: TaskStatusVideoResponse) -> IO.NodeOutput: + if response.code: + raise RuntimeError( + f"Kling request failed. Code: {response.code}, Message: {response.message}, Data: {response.data}" + ) + final_response = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/kling/v1/videos/omni-video/{response.data.task_id}"), + response_model=TaskStatusVideoResponse, + status_extractor=lambda r: (r.data.task_status if r.data else None), + ) + return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url)) + + def is_valid_camera_control_configs(configs: list[float]) -> bool: """Verifies that at least one camera control configuration is non-zero.""" return any(not math.isclose(value, 0.0) for value in configs) @@ -449,7 +471,7 @@ async def execute_video_effect( image_1: torch.Tensor, image_2: torch.Tensor | None = None, model_mode: KlingVideoGenMode | None = None, -) -> tuple[VideoFromFile, str, str]: +) -> tuple[InputImpl.VideoFromFile, str, str]: if dual_character: request_input_field = KlingDualCharacterEffectInput( model_name=model_name, @@ -736,6 +758,386 @@ class KlingTextToVideoNode(IO.ComfyNode): ) +class OmniProTextToVideoNode(IO.ComfyNode): + + @classmethod + def define_schema(cls) -> IO.Schema: + return IO.Schema( + node_id="KlingOmniProTextToVideoNode", + display_name="Kling Omni Text to Video (Pro)", + category="api node/video/Kling", + description="Use text prompts to generate videos with the latest Kling model.", + inputs=[ + IO.Combo.Input("model_name", options=["kling-video-o1"]), + IO.String.Input( + "prompt", + multiline=True, + tooltip="A text prompt describing the video content. " + "This can include both positive and negative descriptions.", + ), + IO.Combo.Input("aspect_ratio", options=["16:9", "9:16", "1:1"]), + IO.Combo.Input("duration", options=[5, 10]), + ], + outputs=[ + IO.Video.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, + model_name: str, + prompt: str, + aspect_ratio: str, + duration: int, + ) -> IO.NodeOutput: + validate_string(prompt, min_length=1, max_length=2500) + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"), + response_model=TaskStatusVideoResponse, + data=OmniProText2VideoRequest( + model_name=model_name, + prompt=prompt, + aspect_ratio=aspect_ratio, + duration=str(duration), + ), + ) + return await finish_omni_video_task(cls, response) + + +class OmniProFirstLastFrameNode(IO.ComfyNode): + + @classmethod + def define_schema(cls) -> IO.Schema: + return IO.Schema( + node_id="KlingOmniProFirstLastFrameNode", + display_name="Kling Omni First-Last-Frame to Video (Pro)", + category="api node/video/Kling", + description="Use a start frame, an optional end frame, or reference images with the latest Kling model.", + inputs=[ + IO.Combo.Input("model_name", options=["kling-video-o1"]), + IO.String.Input( + "prompt", + multiline=True, + tooltip="A text prompt describing the video content. " + "This can include both positive and negative descriptions.", + ), + IO.Combo.Input("duration", options=["5", "10"]), + IO.Image.Input("first_frame"), + IO.Image.Input( + "end_frame", + optional=True, + tooltip="An optional end frame for the video. " + "This cannot be used simultaneously with 'reference_images'.", + ), + IO.Image.Input( + "reference_images", + optional=True, + tooltip="Up to 6 additional reference images.", + ), + ], + outputs=[ + IO.Video.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, + model_name: str, + prompt: str, + duration: int, + first_frame: Input.Image, + end_frame: Input.Image | None = None, + reference_images: Input.Image | None = None, + ) -> IO.NodeOutput: + validate_string(prompt, min_length=1, max_length=2500) + if end_frame is not None and reference_images is not None: + raise ValueError("The 'end_frame' input cannot be used simultaneously with 'reference_images'.") + validate_image_dimensions(first_frame, min_width=300, min_height=300) + validate_image_aspect_ratio(first_frame, (1, 2.5), (2.5, 1)) + image_list: list[OmniParamImage] = [ + OmniParamImage( + image_url=(await upload_images_to_comfyapi(cls, first_frame, wait_label="Uploading first frame"))[0], + type="first_frame", + ) + ] + if end_frame is not None: + validate_image_dimensions(end_frame, min_width=300, min_height=300) + validate_image_aspect_ratio(end_frame, (1, 2.5), (2.5, 1)) + image_list.append( + OmniParamImage( + image_url=(await upload_images_to_comfyapi(cls, end_frame, wait_label="Uploading end frame"))[0], + type="end_frame", + ) + ) + if reference_images is not None: + if get_number_of_images(reference_images) > 6: + raise ValueError("The maximum number of reference images allowed is 6.") + for i in reference_images: + validate_image_dimensions(i, min_width=300, min_height=300) + validate_image_aspect_ratio(i, (1, 2.5), (2.5, 1)) + for i in await upload_images_to_comfyapi(cls, reference_images, wait_label="Uploading reference frame(s)"): + image_list.append(OmniParamImage(image_url=i)) + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"), + response_model=TaskStatusVideoResponse, + data=OmniProFirstLastFrameRequest( + model_name=model_name, + prompt=prompt, + duration=str(duration), + image_list=image_list, + ), + ) + return await finish_omni_video_task(cls, response) + + +class OmniProImageToVideoNode(IO.ComfyNode): + + @classmethod + def define_schema(cls) -> IO.Schema: + return IO.Schema( + node_id="KlingOmniProImageToVideoNode", + display_name="Kling Omni Image to Video (Pro)", + category="api node/video/Kling", + description="Use up to 7 reference images to generate a video with the latest Kling model.", + inputs=[ + IO.Combo.Input("model_name", options=["kling-video-o1"]), + IO.String.Input( + "prompt", + multiline=True, + tooltip="A text prompt describing the video content. " + "This can include both positive and negative descriptions.", + ), + IO.Combo.Input("aspect_ratio", options=["16:9", "9:16", "1:1"]), + IO.Int.Input("duration", default=3, min=3, max=10, display_mode=IO.NumberDisplay.slider), + IO.Image.Input( + "reference_images", + tooltip="Up to 7 reference images.", + ), + ], + outputs=[ + IO.Video.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, + model_name: str, + prompt: str, + aspect_ratio: str, + duration: int, + reference_images: Input.Image, + ) -> IO.NodeOutput: + validate_string(prompt, min_length=1, max_length=2500) + if get_number_of_images(reference_images) > 7: + raise ValueError("The maximum number of reference images is 7.") + for i in reference_images: + validate_image_dimensions(i, min_width=300, min_height=300) + validate_image_aspect_ratio(i, (1, 2.5), (2.5, 1)) + image_list: list[OmniParamImage] = [] + for i in await upload_images_to_comfyapi(cls, reference_images, wait_label="Uploading reference image"): + image_list.append(OmniParamImage(image_url=i)) + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"), + response_model=TaskStatusVideoResponse, + data=OmniProReferences2VideoRequest( + model_name=model_name, + prompt=prompt, + aspect_ratio=aspect_ratio, + duration=str(duration), + image_list=image_list, + ), + ) + return await finish_omni_video_task(cls, response) + + +class OmniProVideoToVideoNode(IO.ComfyNode): + + @classmethod + def define_schema(cls) -> IO.Schema: + return IO.Schema( + node_id="KlingOmniProVideoToVideoNode", + display_name="Kling Omni Video to Video (Pro)", + category="api node/video/Kling", + description="Use a video and up to 4 reference images to generate a video with the latest Kling model.", + inputs=[ + IO.Combo.Input("model_name", options=["kling-video-o1"]), + IO.String.Input( + "prompt", + multiline=True, + tooltip="A text prompt describing the video content. " + "This can include both positive and negative descriptions.", + ), + IO.Combo.Input("aspect_ratio", options=["16:9", "9:16", "1:1"]), + IO.Int.Input("duration", default=3, min=3, max=10, display_mode=IO.NumberDisplay.slider), + IO.Video.Input("reference_video", tooltip="Video to use as a reference."), + IO.Boolean.Input("keep_original_sound", default=True), + IO.Image.Input( + "reference_images", + tooltip="Up to 4 additional reference images.", + optional=True, + ), + ], + outputs=[ + IO.Video.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, + model_name: str, + prompt: str, + aspect_ratio: str, + duration: int, + reference_video: Input.Video, + keep_original_sound: bool, + reference_images: Input.Image | None = None, + ) -> IO.NodeOutput: + validate_string(prompt, min_length=1, max_length=2500) + validate_video_duration(reference_video, min_duration=3.0, max_duration=10.05) + validate_video_dimensions(reference_video, min_width=720, min_height=720, max_width=2160, max_height=2160) + image_list: list[OmniParamImage] = [] + if reference_images is not None: + if get_number_of_images(reference_images) > 4: + raise ValueError("The maximum number of reference images allowed with a video input is 4.") + for i in reference_images: + validate_image_dimensions(i, min_width=300, min_height=300) + validate_image_aspect_ratio(i, (1, 2.5), (2.5, 1)) + for i in await upload_images_to_comfyapi(cls, reference_images, wait_label="Uploading reference image"): + image_list.append(OmniParamImage(image_url=i)) + video_list = [ + OmniParamVideo( + video_url=await upload_video_to_comfyapi(cls, reference_video, wait_label="Uploading reference video"), + refer_type="feature", + keep_original_sound="yes" if keep_original_sound else "no", + ) + ] + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"), + response_model=TaskStatusVideoResponse, + data=OmniProReferences2VideoRequest( + model_name=model_name, + prompt=prompt, + aspect_ratio=aspect_ratio, + duration=str(duration), + image_list=image_list if image_list else None, + video_list=video_list, + ), + ) + return await finish_omni_video_task(cls, response) + + +class OmniProEditVideoNode(IO.ComfyNode): + + @classmethod + def define_schema(cls) -> IO.Schema: + return IO.Schema( + node_id="KlingOmniProEditVideoNode", + display_name="Kling Omni Edit Video (Pro)", + category="api node/video/Kling", + description="Edit an existing video with the latest model from Kling.", + inputs=[ + IO.Combo.Input("model_name", options=["kling-video-o1"]), + IO.String.Input( + "prompt", + multiline=True, + tooltip="A text prompt describing the video content. " + "This can include both positive and negative descriptions.", + ), + IO.Video.Input("video", tooltip="Video for editing. The output video length will be the same."), + IO.Boolean.Input("keep_original_sound", default=True), + IO.Image.Input( + "reference_images", + tooltip="Up to 4 additional reference images.", + optional=True, + ), + ], + outputs=[ + IO.Video.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, + model_name: str, + prompt: str, + video: Input.Video, + keep_original_sound: bool, + reference_images: Input.Image | None = None, + ) -> IO.NodeOutput: + validate_string(prompt, min_length=1, max_length=2500) + validate_video_duration(video, min_duration=3.0, max_duration=10.05) + validate_video_dimensions(video, min_width=720, min_height=720, max_width=2160, max_height=2160) + image_list: list[OmniParamImage] = [] + if reference_images is not None: + if get_number_of_images(reference_images) > 4: + raise ValueError("The maximum number of reference images allowed with a video input is 4.") + for i in reference_images: + validate_image_dimensions(i, min_width=300, min_height=300) + validate_image_aspect_ratio(i, (1, 2.5), (2.5, 1)) + for i in await upload_images_to_comfyapi(cls, reference_images, wait_label="Uploading reference image"): + image_list.append(OmniParamImage(image_url=i)) + video_list = [ + OmniParamVideo( + video_url=await upload_video_to_comfyapi(cls, video, wait_label="Uploading base video"), + refer_type="base", + keep_original_sound="yes" if keep_original_sound else "no", + ) + ] + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"), + response_model=TaskStatusVideoResponse, + data=OmniProReferences2VideoRequest( + model_name=model_name, + prompt=prompt, + aspect_ratio=None, + duration=None, + image_list=image_list if image_list else None, + video_list=video_list, + ), + ) + return await finish_omni_video_task(cls, response) + + class KlingCameraControlT2VNode(IO.ComfyNode): """ Kling Text to Video Camera Control Node. This node is a text to video node, but it supports controlling the camera. @@ -1162,7 +1564,10 @@ class KlingSingleImageVideoEffectNode(IO.ComfyNode): category="api node/video/Kling", description="Achieve different special effects when generating a video based on the effect_scene.", inputs=[ - IO.Image.Input("image", tooltip=" Reference Image. URL or Base64 encoded string (without data:image prefix). File size cannot exceed 10MB, resolution not less than 300*300px, aspect ratio between 1:2.5 ~ 2.5:1"), + IO.Image.Input( + "image", + tooltip=" Reference Image. URL or Base64 encoded string (without data:image prefix). File size cannot exceed 10MB, resolution not less than 300*300px, aspect ratio between 1:2.5 ~ 2.5:1", + ), IO.Combo.Input( "effect_scene", options=[i.value for i in KlingSingleImageEffectsScene], @@ -1525,6 +1930,11 @@ class KlingExtension(ComfyExtension): KlingImageGenerationNode, KlingSingleImageVideoEffectNode, KlingDualCharacterVideoEffectNode, + OmniProTextToVideoNode, + OmniProFirstLastFrameNode, + OmniProImageToVideoNode, + OmniProVideoToVideoNode, + OmniProEditVideoNode, ] diff --git a/comfy_api_nodes/util/upload_helpers.py b/comfy_api_nodes/util/upload_helpers.py index b9019841f..0532bea9a 100644 --- a/comfy_api_nodes/util/upload_helpers.py +++ b/comfy_api_nodes/util/upload_helpers.py @@ -103,6 +103,7 @@ async def upload_video_to_comfyapi( container: VideoContainer = VideoContainer.MP4, codec: VideoCodec = VideoCodec.H264, max_duration: Optional[int] = None, + wait_label: str | None = "Uploading", ) -> str: """ Uploads a single video to ComfyUI API and returns its download URL. @@ -127,7 +128,7 @@ async def upload_video_to_comfyapi( video.save_to(video_bytes_io, format=container, codec=codec) video_bytes_io.seek(0) - return await upload_file_to_comfyapi(cls, video_bytes_io, filename, upload_mime_type) + return await upload_file_to_comfyapi(cls, video_bytes_io, filename, upload_mime_type, wait_label) async def upload_file_to_comfyapi( From 30c259cac8c08ff8d015f9aff3151cb525c9b702 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 1 Dec 2025 20:25:35 -0500 Subject: [PATCH 184/231] ComfyUI version v0.3.76 --- comfyui_version.py | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/comfyui_version.py b/comfyui_version.py index fa4b4f4b0..4b039356e 100644 --- a/comfyui_version.py +++ b/comfyui_version.py @@ -1,3 +1,3 @@ # This file is automatically generated by the build process when version is # updated in pyproject.toml. -__version__ = "0.3.75" +__version__ = "0.3.76" diff --git a/pyproject.toml b/pyproject.toml index 9009e65fe..02b94a0ce 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ComfyUI" -version = "0.3.75" +version = "0.3.76" readme = "README.md" license = { file = "LICENSE" } requires-python = ">=3.9" From 878db3a727c1c6049bc1c4959cdfabc35eaf3d56 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 1 Dec 2025 17:56:17 -0800 Subject: [PATCH 185/231] Implement the Ovis image model. (#11030) --- comfy/ldm/chroma/model.py | 3 +- comfy/ldm/flux/layers.py | 68 +++++++++++++++++++++-------------- comfy/ldm/flux/model.py | 21 ++++++++--- comfy/model_detection.py | 10 +++++- comfy/sd.py | 13 +++++-- comfy/text_encoders/llama.py | 31 ++++++++++++++++ comfy/text_encoders/ovis.py | 69 ++++++++++++++++++++++++++++++++++++ nodes.py | 2 +- 8 files changed, 182 insertions(+), 35 deletions(-) create mode 100644 comfy/text_encoders/ovis.py diff --git a/comfy/ldm/chroma/model.py b/comfy/ldm/chroma/model.py index a72f8cc47..2e8ef0687 100644 --- a/comfy/ldm/chroma/model.py +++ b/comfy/ldm/chroma/model.py @@ -40,7 +40,8 @@ class ChromaParams: out_dim: int hidden_dim: int n_layers: int - + txt_ids_dims: list + vec_in_dim: int diff --git a/comfy/ldm/flux/layers.py b/comfy/ldm/flux/layers.py index 2472ab79c..60f2bdae2 100644 --- a/comfy/ldm/flux/layers.py +++ b/comfy/ldm/flux/layers.py @@ -57,6 +57,35 @@ class MLPEmbedder(nn.Module): def forward(self, x: Tensor) -> Tensor: return self.out_layer(self.silu(self.in_layer(x))) +class YakMLP(nn.Module): + def __init__(self, hidden_size: int, intermediate_size: int, dtype=None, device=None, operations=None): + super().__init__() + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.gate_proj = operations.Linear(self.hidden_size, self.intermediate_size, bias=True, dtype=dtype, device=device) + self.up_proj = operations.Linear(self.hidden_size, self.intermediate_size, bias=True, dtype=dtype, device=device) + self.down_proj = operations.Linear(self.intermediate_size, self.hidden_size, bias=True, dtype=dtype, device=device) + self.act_fn = nn.SiLU() + + def forward(self, x: Tensor) -> Tensor: + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + +def build_mlp(hidden_size, mlp_hidden_dim, mlp_silu_act=False, yak_mlp=False, dtype=None, device=None, operations=None): + if yak_mlp: + return YakMLP(hidden_size, mlp_hidden_dim, dtype=dtype, device=device, operations=operations) + if mlp_silu_act: + return nn.Sequential( + operations.Linear(hidden_size, mlp_hidden_dim * 2, bias=False, dtype=dtype, device=device), + SiLUActivation(), + operations.Linear(mlp_hidden_dim, hidden_size, bias=False, dtype=dtype, device=device), + ) + else: + return nn.Sequential( + operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device), + nn.GELU(approximate="tanh"), + operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device), + ) class RMSNorm(torch.nn.Module): def __init__(self, dim: int, dtype=None, device=None, operations=None): @@ -140,7 +169,7 @@ class SiLUActivation(nn.Module): class DoubleStreamBlock(nn.Module): - def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, modulation=True, mlp_silu_act=False, proj_bias=True, dtype=None, device=None, operations=None): + def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, modulation=True, mlp_silu_act=False, proj_bias=True, yak_mlp=False, dtype=None, device=None, operations=None): super().__init__() mlp_hidden_dim = int(hidden_size * mlp_ratio) @@ -156,18 +185,7 @@ class DoubleStreamBlock(nn.Module): self.img_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) - if mlp_silu_act: - self.img_mlp = nn.Sequential( - operations.Linear(hidden_size, mlp_hidden_dim * 2, bias=False, dtype=dtype, device=device), - SiLUActivation(), - operations.Linear(mlp_hidden_dim, hidden_size, bias=False, dtype=dtype, device=device), - ) - else: - self.img_mlp = nn.Sequential( - operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device), - nn.GELU(approximate="tanh"), - operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device), - ) + self.img_mlp = build_mlp(hidden_size, mlp_hidden_dim, mlp_silu_act=mlp_silu_act, yak_mlp=yak_mlp, dtype=dtype, device=device, operations=operations) if self.modulation: self.txt_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device, operations=operations) @@ -177,18 +195,7 @@ class DoubleStreamBlock(nn.Module): self.txt_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) - if mlp_silu_act: - self.txt_mlp = nn.Sequential( - operations.Linear(hidden_size, mlp_hidden_dim * 2, bias=False, dtype=dtype, device=device), - SiLUActivation(), - operations.Linear(mlp_hidden_dim, hidden_size, bias=False, dtype=dtype, device=device), - ) - else: - self.txt_mlp = nn.Sequential( - operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device), - nn.GELU(approximate="tanh"), - operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device), - ) + self.txt_mlp = build_mlp(hidden_size, mlp_hidden_dim, mlp_silu_act=mlp_silu_act, yak_mlp=yak_mlp, dtype=dtype, device=device, operations=operations) self.flipped_img_txt = flipped_img_txt @@ -275,6 +282,7 @@ class SingleStreamBlock(nn.Module): modulation=True, mlp_silu_act=False, bias=True, + yak_mlp=False, dtype=None, device=None, operations=None @@ -288,12 +296,17 @@ class SingleStreamBlock(nn.Module): self.mlp_hidden_dim = int(hidden_size * mlp_ratio) self.mlp_hidden_dim_first = self.mlp_hidden_dim + self.yak_mlp = yak_mlp if mlp_silu_act: self.mlp_hidden_dim_first = int(hidden_size * mlp_ratio * 2) self.mlp_act = SiLUActivation() else: self.mlp_act = nn.GELU(approximate="tanh") + if self.yak_mlp: + self.mlp_hidden_dim_first *= 2 + self.mlp_act = nn.SiLU() + # qkv and mlp_in self.linear1 = operations.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim_first, bias=bias, dtype=dtype, device=device) # proj and mlp_out @@ -325,7 +338,10 @@ class SingleStreamBlock(nn.Module): attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options) del q, k, v # compute activation in mlp stream, cat again and run second linear layer - mlp = self.mlp_act(mlp) + if self.yak_mlp: + mlp = self.mlp_act(mlp[..., self.mlp_hidden_dim_first // 2:]) * mlp[..., :self.mlp_hidden_dim_first // 2] + else: + mlp = self.mlp_act(mlp) output = self.linear2(torch.cat((attn, mlp), 2)) x += apply_mod(output, mod.gate, None, modulation_dims) if x.dtype == torch.float16: diff --git a/comfy/ldm/flux/model.py b/comfy/ldm/flux/model.py index d5674dea6..f40c2a7a9 100644 --- a/comfy/ldm/flux/model.py +++ b/comfy/ldm/flux/model.py @@ -15,7 +15,8 @@ from .layers import ( MLPEmbedder, SingleStreamBlock, timestep_embedding, - Modulation + Modulation, + RMSNorm ) @dataclass @@ -34,11 +35,14 @@ class FluxParams: patch_size: int qkv_bias: bool guidance_embed: bool + txt_ids_dims: list global_modulation: bool = False mlp_silu_act: bool = False ops_bias: bool = True default_ref_method: str = "offset" ref_index_scale: float = 1.0 + yak_mlp: bool = False + txt_norm: bool = False class Flux(nn.Module): @@ -76,6 +80,11 @@ class Flux(nn.Module): ) self.txt_in = operations.Linear(params.context_in_dim, self.hidden_size, bias=params.ops_bias, dtype=dtype, device=device) + if params.txt_norm: + self.txt_norm = RMSNorm(params.context_in_dim, dtype=dtype, device=device, operations=operations) + else: + self.txt_norm = None + self.double_blocks = nn.ModuleList( [ DoubleStreamBlock( @@ -86,6 +95,7 @@ class Flux(nn.Module): modulation=params.global_modulation is False, mlp_silu_act=params.mlp_silu_act, proj_bias=params.ops_bias, + yak_mlp=params.yak_mlp, dtype=dtype, device=device, operations=operations ) for _ in range(params.depth) @@ -94,7 +104,7 @@ class Flux(nn.Module): self.single_blocks = nn.ModuleList( [ - SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, modulation=params.global_modulation is False, mlp_silu_act=params.mlp_silu_act, bias=params.ops_bias, dtype=dtype, device=device, operations=operations) + SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, modulation=params.global_modulation is False, mlp_silu_act=params.mlp_silu_act, bias=params.ops_bias, yak_mlp=params.yak_mlp, dtype=dtype, device=device, operations=operations) for _ in range(params.depth_single_blocks) ] ) @@ -150,6 +160,8 @@ class Flux(nn.Module): y = torch.zeros((img.shape[0], self.params.vec_in_dim), device=img.device, dtype=img.dtype) vec = vec + self.vector_in(y[:, :self.params.vec_in_dim]) + if self.txt_norm is not None: + txt = self.txt_norm(txt) txt = self.txt_in(txt) vec_orig = vec @@ -332,8 +344,9 @@ class Flux(nn.Module): txt_ids = torch.zeros((bs, context.shape[1], len(self.params.axes_dim)), device=x.device, dtype=torch.float32) - if len(self.params.axes_dim) == 4: # Flux 2 - txt_ids[:, :, 3] = torch.linspace(0, context.shape[1] - 1, steps=context.shape[1], device=x.device, dtype=torch.float32) + if len(self.params.txt_ids_dims) > 0: + for i in self.params.txt_ids_dims: + txt_ids[:, :, i] = torch.linspace(0, context.shape[1] - 1, steps=context.shape[1], device=x.device, dtype=torch.float32) out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None)) out = out[:, :img_tokens] diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 7afe4a798..7d0517e61 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -208,12 +208,12 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["theta"] = 2000 dit_config["out_channels"] = 128 dit_config["global_modulation"] = True - dit_config["vec_in_dim"] = None dit_config["mlp_silu_act"] = True dit_config["qkv_bias"] = False dit_config["ops_bias"] = False dit_config["default_ref_method"] = "index" dit_config["ref_index_scale"] = 10.0 + dit_config["txt_ids_dims"] = [3] patch_size = 1 else: dit_config["image_model"] = "flux" @@ -223,6 +223,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["theta"] = 10000 dit_config["out_channels"] = 16 dit_config["qkv_bias"] = True + dit_config["txt_ids_dims"] = [] patch_size = 2 dit_config["in_channels"] = 16 @@ -245,6 +246,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): vec_in_key = '{}vector_in.in_layer.weight'.format(key_prefix) if vec_in_key in state_dict_keys: dit_config["vec_in_dim"] = state_dict[vec_in_key].shape[1] + else: + dit_config["vec_in_dim"] = None dit_config["depth"] = count_blocks(state_dict_keys, '{}double_blocks.'.format(key_prefix) + '{}.') dit_config["depth_single_blocks"] = count_blocks(state_dict_keys, '{}single_blocks.'.format(key_prefix) + '{}.') @@ -270,6 +273,11 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["nerf_embedder_dtype"] = torch.float32 else: dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys + dit_config["yak_mlp"] = '{}double_blocks.0.img_mlp.gate_proj.weight'.format(key_prefix) in state_dict_keys + dit_config["txt_norm"] = "{}txt_norm.scale".format(key_prefix) in state_dict_keys + if dit_config["yak_mlp"] and dit_config["txt_norm"]: # Ovis model + dit_config["txt_ids_dims"] = [1, 2] + return dit_config if '{}t5_yproj.weight'.format(key_prefix) in state_dict_keys: #Genmo mochi preview diff --git a/comfy/sd.py b/comfy/sd.py index 9eeb0c45a..f9e5efab5 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -53,6 +53,7 @@ import comfy.text_encoders.omnigen2 import comfy.text_encoders.qwen_image import comfy.text_encoders.hunyuan_image import comfy.text_encoders.z_image +import comfy.text_encoders.ovis import comfy.model_patcher import comfy.lora @@ -956,6 +957,7 @@ class CLIPType(Enum): QWEN_IMAGE = 18 HUNYUAN_IMAGE = 19 HUNYUAN_VIDEO_15 = 20 + OVIS = 21 def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}): @@ -987,6 +989,7 @@ class TEModel(Enum): MISTRAL3_24B = 14 MISTRAL3_24B_PRUNED_FLUX2 = 15 QWEN3_4B = 16 + QWEN3_2B = 17 def detect_te_model(sd): @@ -1020,9 +1023,12 @@ def detect_te_model(sd): if weight.shape[0] == 512: return TEModel.QWEN25_7B if "model.layers.0.post_attention_layernorm.weight" in sd: - if 'model.layers.0.self_attn.q_norm.weight' in sd: - return TEModel.QWEN3_4B weight = sd['model.layers.0.post_attention_layernorm.weight'] + if 'model.layers.0.self_attn.q_norm.weight' in sd: + if weight.shape[0] == 2560: + return TEModel.QWEN3_4B + elif weight.shape[0] == 2048: + return TEModel.QWEN3_2B if weight.shape[0] == 5120: if "model.layers.39.post_attention_layernorm.weight" in sd: return TEModel.MISTRAL3_24B @@ -1150,6 +1156,9 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip elif te_model == TEModel.QWEN3_4B: clip_target.clip = comfy.text_encoders.z_image.te(**llama_detect(clip_data)) clip_target.tokenizer = comfy.text_encoders.z_image.ZImageTokenizer + elif te_model == TEModel.QWEN3_2B: + clip_target.clip = comfy.text_encoders.ovis.te(**llama_detect(clip_data)) + clip_target.tokenizer = comfy.text_encoders.ovis.OvisTokenizer else: # clip_l if clip_type == CLIPType.SD3: diff --git a/comfy/text_encoders/llama.py b/comfy/text_encoders/llama.py index cd4b5f76c..0d07ac8c6 100644 --- a/comfy/text_encoders/llama.py +++ b/comfy/text_encoders/llama.py @@ -100,6 +100,28 @@ class Qwen3_4BConfig: rope_scale = None final_norm: bool = True +@dataclass +class Ovis25_2BConfig: + vocab_size: int = 151936 + hidden_size: int = 2048 + intermediate_size: int = 6144 + num_hidden_layers: int = 28 + num_attention_heads: int = 16 + num_key_value_heads: int = 8 + max_position_embeddings: int = 40960 + rms_norm_eps: float = 1e-6 + rope_theta: float = 1000000.0 + transformer_type: str = "llama" + head_dim = 128 + rms_norm_add = False + mlp_activation = "silu" + qkv_bias = False + rope_dims = None + q_norm = "gemma3" + k_norm = "gemma3" + rope_scale = None + final_norm: bool = True + @dataclass class Qwen25_7BVLI_Config: vocab_size: int = 152064 @@ -542,6 +564,15 @@ class Qwen3_4B(BaseLlama, torch.nn.Module): self.model = Llama2_(config, device=device, dtype=dtype, ops=operations) self.dtype = dtype +class Ovis25_2B(BaseLlama, torch.nn.Module): + def __init__(self, config_dict, dtype, device, operations): + super().__init__() + config = Ovis25_2BConfig(**config_dict) + self.num_layers = config.num_hidden_layers + + self.model = Llama2_(config, device=device, dtype=dtype, ops=operations) + self.dtype = dtype + class Qwen25_7BVLI(BaseLlama, torch.nn.Module): def __init__(self, config_dict, dtype, device, operations): super().__init__() diff --git a/comfy/text_encoders/ovis.py b/comfy/text_encoders/ovis.py new file mode 100644 index 000000000..81c9bd51c --- /dev/null +++ b/comfy/text_encoders/ovis.py @@ -0,0 +1,69 @@ +from transformers import Qwen2Tokenizer +import comfy.text_encoders.llama +from comfy import sd1_clip +import os +import torch +import numbers + +class Qwen3Tokenizer(sd1_clip.SDTokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "qwen25_tokenizer") + super().__init__(tokenizer_path, pad_with_end=False, embedding_size=2048, embedding_key='qwen3_2b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=284, pad_token=151643, tokenizer_data=tokenizer_data) + + +class OvisTokenizer(sd1_clip.SD1Tokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="qwen3_2b", tokenizer=Qwen3Tokenizer) + self.llama_template = "<|im_start|>user\nDescribe the image by detailing the color, quantity, text, shape, size, texture, spatial relationships of the objects and background: {}<|im_end|>\n<|im_start|>assistant\n\n\n\n\n" + + def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, **kwargs): + if llama_template is None: + llama_text = self.llama_template.format(text) + else: + llama_text = llama_template.format(text) + + tokens = super().tokenize_with_weights(llama_text, return_word_ids=return_word_ids, disable_weights=True, **kwargs) + return tokens + +class Ovis25_2BModel(sd1_clip.SDClipModel): + def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options={}): + super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Ovis25_2B, enable_attention_masks=attention_mask, return_attention_masks=False, zero_out_masked=True, model_options=model_options) + + +class OvisTEModel(sd1_clip.SD1ClipModel): + def __init__(self, device="cpu", dtype=None, model_options={}): + super().__init__(device=device, dtype=dtype, name="qwen3_2b", clip_model=Ovis25_2BModel, model_options=model_options) + + def encode_token_weights(self, token_weight_pairs, template_end=-1): + out, pooled = super().encode_token_weights(token_weight_pairs) + tok_pairs = token_weight_pairs["qwen3_2b"][0] + count_im_start = 0 + if template_end == -1: + for i, v in enumerate(tok_pairs): + elem = v[0] + if not torch.is_tensor(elem): + if isinstance(elem, numbers.Integral): + if elem == 4004 and count_im_start < 1: + template_end = i + count_im_start += 1 + + if out.shape[1] > (template_end + 1): + if tok_pairs[template_end + 1][0] == 25: + template_end += 1 + + out = out[:, template_end:] + return out, pooled, {} + + +def te(dtype_llama=None, llama_scaled_fp8=None, llama_quantization_metadata=None): + class OvisTEModel_(OvisTEModel): + def __init__(self, device="cpu", dtype=None, model_options={}): + if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options: + model_options = model_options.copy() + model_options["scaled_fp8"] = llama_scaled_fp8 + if dtype_llama is not None: + dtype = dtype_llama + if llama_quantization_metadata is not None: + model_options["quantization_metadata"] = llama_quantization_metadata + super().__init__(device=device, dtype=dtype, model_options=model_options) + return OvisTEModel_ diff --git a/nodes.py b/nodes.py index 495dec806..d5e5dc228 100644 --- a/nodes.py +++ b/nodes.py @@ -939,7 +939,7 @@ class CLIPLoader: @classmethod def INPUT_TYPES(s): return {"required": { "clip_name": (folder_paths.get_filename_list("text_encoders"), ), - "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image", "flux2"], ), + "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image", "flux2", "ovis"], ), }, "optional": { "device": (["default", "cpu"], {"advanced": True}), From c55dc857d5da5af203caf720ed7056047d382544 Mon Sep 17 00:00:00 2001 From: Christian Byrne Date: Mon, 1 Dec 2025 17:56:38 -0800 Subject: [PATCH 186/231] bump comfyui-frontend-package to 1.33.10 (#11028) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 045b2ac54..f98848e20 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -comfyui-frontend-package==1.32.10 +comfyui-frontend-package==1.33.10 comfyui-workflow-templates==0.7.25 comfyui-embedded-docs==0.3.1 torch From b4a20acc54b0b94dc05a1bd09dc0b54dd12203f1 Mon Sep 17 00:00:00 2001 From: "Dr.Lt.Data" <128333288+ltdrdata@users.noreply.github.com> Date: Tue, 2 Dec 2025 12:32:52 +0900 Subject: [PATCH 187/231] feat: Support ComfyUI-Manager for pip version (#7555) --- comfy/cli_args.py | 7 +++++++ comfy_api/feature_flags.py | 1 + main.py | 30 ++++++++++++++++++++++++++++++ manager_requirements.txt | 1 + nodes.py | 9 +++++++++ server.py | 8 +++++++- 6 files changed, 55 insertions(+), 1 deletion(-) create mode 100644 manager_requirements.txt diff --git a/comfy/cli_args.py b/comfy/cli_args.py index 5f0dfaa10..209fc185b 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -121,6 +121,12 @@ upcast.add_argument("--force-upcast-attention", action="store_true", help="Force upcast.add_argument("--dont-upcast-attention", action="store_true", help="Disable all upcasting of attention. Should be unnecessary except for debugging.") +parser.add_argument("--enable-manager", action="store_true", help="Enable the ComfyUI-Manager feature.") +manager_group = parser.add_mutually_exclusive_group() +manager_group.add_argument("--disable-manager-ui", action="store_true", help="Disables only the ComfyUI-Manager UI and endpoints. Scheduled installations and similar background tasks will still operate.") +manager_group.add_argument("--enable-manager-legacy-ui", action="store_true", help="Enables the legacy UI of ComfyUI-Manager") + + vram_group = parser.add_mutually_exclusive_group() vram_group.add_argument("--gpu-only", action="store_true", help="Store and run everything (text encoders/CLIP models, etc... on the GPU).") vram_group.add_argument("--highvram", action="store_true", help="By default models will be unloaded to CPU memory after being used. This option keeps them in GPU memory.") @@ -168,6 +174,7 @@ parser.add_argument("--multi-user", action="store_true", help="Enables per-user parser.add_argument("--verbose", default='INFO', const='DEBUG', nargs="?", choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], help='Set the logging level') parser.add_argument("--log-stdout", action="store_true", help="Send normal process output to stdout instead of stderr (default).") + # The default built-in provider hosted under web/ DEFAULT_VERSION_STRING = "comfyanonymous/ComfyUI@latest" diff --git a/comfy_api/feature_flags.py b/comfy_api/feature_flags.py index 0d4389a6e..bfb77eb5f 100644 --- a/comfy_api/feature_flags.py +++ b/comfy_api/feature_flags.py @@ -13,6 +13,7 @@ from comfy.cli_args import args SERVER_FEATURE_FLAGS: Dict[str, Any] = { "supports_preview_metadata": True, "max_upload_size": args.max_upload_size * 1024 * 1024, # Convert MB to bytes + "extension": {"manager": {"supports_v4": True}}, } diff --git a/main.py b/main.py index e1b0f1620..0cd815d9e 100644 --- a/main.py +++ b/main.py @@ -15,6 +15,7 @@ from comfy_execution.progress import get_progress_state from comfy_execution.utils import get_executing_context from comfy_api import feature_flags + if __name__ == "__main__": #NOTE: These do not do anything on core ComfyUI, they are for custom nodes. os.environ['HF_HUB_DISABLE_TELEMETRY'] = '1' @@ -22,6 +23,23 @@ if __name__ == "__main__": setup_logger(log_level=args.verbose, use_stdout=args.log_stdout) + +def handle_comfyui_manager_unavailable(): + if not args.windows_standalone_build: + logging.warning(f"\n\nYou appear to be running comfyui-manager from source, this is not recommended. Please install comfyui-manager using the following command:\ncommand:\n\t{sys.executable} -m pip install --pre comfyui_manager\n") + args.enable_manager = False + + +if args.enable_manager: + if importlib.util.find_spec("comfyui_manager"): + import comfyui_manager + + if not comfyui_manager.__file__ or not comfyui_manager.__file__.endswith('__init__.py'): + handle_comfyui_manager_unavailable() + else: + handle_comfyui_manager_unavailable() + + def apply_custom_paths(): # extra model paths extra_model_paths_config_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "extra_model_paths.yaml") @@ -79,6 +97,11 @@ def execute_prestartup_script(): for possible_module in possible_modules: module_path = os.path.join(custom_node_path, possible_module) + + if args.enable_manager: + if comfyui_manager.should_be_disabled(module_path): + continue + if os.path.isfile(module_path) or module_path.endswith(".disabled") or module_path == "__pycache__": continue @@ -101,6 +124,10 @@ def execute_prestartup_script(): logging.info("") apply_custom_paths() + +if args.enable_manager: + comfyui_manager.prestartup() + execute_prestartup_script() @@ -323,6 +350,9 @@ def start_comfyui(asyncio_loop=None): asyncio.set_event_loop(asyncio_loop) prompt_server = server.PromptServer(asyncio_loop) + if args.enable_manager and not args.disable_manager_ui: + comfyui_manager.start() + hook_breaker_ac10a0.save_functions() asyncio_loop.run_until_complete(nodes.init_extra_nodes( init_custom_nodes=(not args.disable_all_custom_nodes) or len(args.whitelist_custom_nodes) > 0, diff --git a/manager_requirements.txt b/manager_requirements.txt new file mode 100644 index 000000000..52cc5389c --- /dev/null +++ b/manager_requirements.txt @@ -0,0 +1 @@ +comfyui_manager==4.0.3b3 diff --git a/nodes.py b/nodes.py index d5e5dc228..4c910a34b 100644 --- a/nodes.py +++ b/nodes.py @@ -43,6 +43,9 @@ import folder_paths import latent_preview import node_helpers +if args.enable_manager: + import comfyui_manager + def before_node_execution(): comfy.model_management.throw_exception_if_processing_interrupted() @@ -2243,6 +2246,12 @@ async def init_external_custom_nodes(): if args.disable_all_custom_nodes and possible_module not in args.whitelist_custom_nodes: logging.info(f"Skipping {possible_module} due to disable_all_custom_nodes and whitelist_custom_nodes") continue + + if args.enable_manager: + if comfyui_manager.should_be_disabled(module_path): + logging.info(f"Blocked by policy: {module_path}") + continue + time_before = time.perf_counter() success = await load_custom_node(module_path, base_node_names, module_parent="custom_nodes") node_import_times.append((time.perf_counter() - time_before, module_path, success)) diff --git a/server.py b/server.py index fca5050bd..e3bd056d9 100644 --- a/server.py +++ b/server.py @@ -44,6 +44,9 @@ from protocol import BinaryEventTypes # Import cache control middleware from middleware.cache_middleware import cache_control +if args.enable_manager: + import comfyui_manager + async def send_socket_catch_exception(function, message): try: await function(message) @@ -212,6 +215,9 @@ class PromptServer(): if args.disable_api_nodes: middlewares.append(create_block_external_middleware()) + if args.enable_manager: + middlewares.append(comfyui_manager.create_middleware()) + max_upload_size = round(args.max_upload_size * 1024 * 1024) self.app = web.Application(client_max_size=max_upload_size, middlewares=middlewares) self.sockets = dict() @@ -599,7 +605,7 @@ class PromptServer(): system_stats = { "system": { - "os": os.name, + "os": sys.platform, "ram_total": ram_total, "ram_free": ram_free, "comfyui_version": __version__, From a17cf1c3871ad582c85c2bb6fddb63ec9c6df0ce Mon Sep 17 00:00:00 2001 From: Yoland Yan <4950057+yoland68@users.noreply.github.com> Date: Mon, 1 Dec 2025 19:40:44 -0800 Subject: [PATCH 188/231] Add @guill as a code owner (#11031) --- CODEOWNERS | 1 + 1 file changed, 1 insertion(+) diff --git a/CODEOWNERS b/CODEOWNERS index b7aca9b26..51acc4986 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -1,3 +1,4 @@ # Admins * @comfyanonymous * @kosinkadink +* @guill From 44baa0b7f32dd0c2ff0a9898aeb6c7929d855cd3 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Tue, 2 Dec 2025 11:46:29 -0800 Subject: [PATCH 189/231] Fix CODEOWNERS formatting to have all on the same line, otherwise only last line applies (#11053) --- CODEOWNERS | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/CODEOWNERS b/CODEOWNERS index 51acc4986..4d5448636 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -1,4 +1,2 @@ # Admins -* @comfyanonymous -* @kosinkadink -* @guill +* @comfyanonymous @kosinkadink @guill From 33d6aec3b70bc6f3e5bba26c85bd8f3bb1380d08 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Tue, 2 Dec 2025 21:50:13 +0200 Subject: [PATCH 190/231] add check for the format arg type in VideoFromComponents.save_to function (#11046) * add check for the format var type in VideoFromComponents.save_to function * convert "format" to VideoContainer enum --- comfy_api/latest/_input_impl/video_types.py | 2 +- comfy_extras/nodes_video.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/comfy_api/latest/_input_impl/video_types.py b/comfy_api/latest/_input_impl/video_types.py index 7231bf13c..a4cd3737d 100644 --- a/comfy_api/latest/_input_impl/video_types.py +++ b/comfy_api/latest/_input_impl/video_types.py @@ -337,7 +337,7 @@ class VideoFromComponents(VideoInput): if codec != VideoCodec.AUTO and codec != VideoCodec.H264: raise ValueError("Only H264 codec is supported for now") extra_kwargs = {} - if format != VideoContainer.AUTO: + if isinstance(format, VideoContainer) and format != VideoContainer.AUTO: extra_kwargs["format"] = format.value with av.open(path, mode='w', options={'movflags': 'use_metadata_tags'}, **extra_kwargs) as output: # Add metadata before writing any streams diff --git a/comfy_extras/nodes_video.py b/comfy_extras/nodes_video.py index 69fabb12e..6cf6e39bf 100644 --- a/comfy_extras/nodes_video.py +++ b/comfy_extras/nodes_video.py @@ -88,7 +88,7 @@ class SaveVideo(io.ComfyNode): ) @classmethod - def execute(cls, video: VideoInput, filename_prefix, format, codec) -> io.NodeOutput: + def execute(cls, video: VideoInput, filename_prefix, format: str, codec) -> io.NodeOutput: width, height = video.get_dimensions() full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path( filename_prefix, @@ -108,7 +108,7 @@ class SaveVideo(io.ComfyNode): file = f"{filename}_{counter:05}_.{VideoContainer.get_extension(format)}" video.save_to( os.path.join(full_output_folder, file), - format=format, + format=VideoContainer(format), codec=codec, metadata=saved_metadata ) From daaceac769a1355ab975758ede064317ea7514b4 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 2 Dec 2025 14:11:58 -0800 Subject: [PATCH 191/231] Hack to make zimage work in fp16. (#11057) --- comfy/ldm/lumina/model.py | 18 +++++++++++------- comfy/supported_models.py | 2 ++ 2 files changed, 13 insertions(+), 7 deletions(-) diff --git a/comfy/ldm/lumina/model.py b/comfy/ldm/lumina/model.py index 7d7e9112c..070b5da09 100644 --- a/comfy/ldm/lumina/model.py +++ b/comfy/ldm/lumina/model.py @@ -22,6 +22,10 @@ def modulate(x, scale): # Core NextDiT Model # ############################################################################# +def clamp_fp16(x): + if x.dtype == torch.float16: + return torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504) + return x class JointAttention(nn.Module): """Multi-head attention module.""" @@ -169,7 +173,7 @@ class FeedForward(nn.Module): # @torch.compile def _forward_silu_gating(self, x1, x3): - return F.silu(x1) * x3 + return clamp_fp16(F.silu(x1) * x3) def forward(self, x): return self.w2(self._forward_silu_gating(self.w1(x), self.w3(x))) @@ -273,27 +277,27 @@ class JointTransformerBlock(nn.Module): scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(adaln_input).chunk(4, dim=1) x = x + gate_msa.unsqueeze(1).tanh() * self.attention_norm2( - self.attention( + clamp_fp16(self.attention( modulate(self.attention_norm1(x), scale_msa), x_mask, freqs_cis, transformer_options=transformer_options, - ) + )) ) x = x + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2( - self.feed_forward( + clamp_fp16(self.feed_forward( modulate(self.ffn_norm1(x), scale_mlp), - ) + )) ) else: assert adaln_input is None x = x + self.attention_norm2( - self.attention( + clamp_fp16(self.attention( self.attention_norm1(x), x_mask, freqs_cis, transformer_options=transformer_options, - ) + )) ) x = x + self.ffn_norm2( self.feed_forward( diff --git a/comfy/supported_models.py b/comfy/supported_models.py index af8120400..afd97160b 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -1027,6 +1027,8 @@ class ZImage(Lumina2): memory_usage_factor = 1.7 + supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32] + def clip_target(self, state_dict={}): pref = self.text_encoder_key_prefix[0] hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen3_4b.transformer.".format(pref)) From 277237ccc1499bac7fcd221a666dfe7a32ac4206 Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Wed, 3 Dec 2025 08:24:19 +1000 Subject: [PATCH 192/231] attention: use flag based OOM fallback (#11038) Exception ref all local variables for the lifetime of exception context. Just set a flag and then if to dump the exception before falling back. --- comfy/ldm/modules/attention.py | 3 +++ comfy/ldm/modules/diffusionmodules/model.py | 3 +++ 2 files changed, 6 insertions(+) diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index 7437e0567..a8800ded0 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -517,6 +517,7 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha @wrap_attn def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs): + exception_fallback = False if skip_reshape: b, _, _, dim_head = q.shape tensor_layout = "HND" @@ -541,6 +542,8 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape= out = sageattn(q, k, v, attn_mask=mask, is_causal=False, tensor_layout=tensor_layout) except Exception as e: logging.error("Error running sage attention: {}, using pytorch attention instead.".format(e)) + exception_fallback = True + if exception_fallback: if tensor_layout == "NHD": q, k, v = map( lambda t: t.transpose(1, 2), diff --git a/comfy/ldm/modules/diffusionmodules/model.py b/comfy/ldm/modules/diffusionmodules/model.py index 4245eedca..de1e01cc8 100644 --- a/comfy/ldm/modules/diffusionmodules/model.py +++ b/comfy/ldm/modules/diffusionmodules/model.py @@ -279,6 +279,7 @@ def pytorch_attention(q, k, v): orig_shape = q.shape B = orig_shape[0] C = orig_shape[1] + oom_fallback = False q, k, v = map( lambda t: t.view(B, 1, C, -1).transpose(2, 3).contiguous(), (q, k, v), @@ -289,6 +290,8 @@ def pytorch_attention(q, k, v): out = out.transpose(2, 3).reshape(orig_shape) except model_management.OOM_EXCEPTION: logging.warning("scaled_dot_product_attention OOMed: switched to slice attention") + oom_fallback = True + if oom_fallback: out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(orig_shape) return out From b94d394a64dd0af06bca44b96c66549bb463331d Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 2 Dec 2025 18:38:31 -0800 Subject: [PATCH 193/231] Support Z Image alibaba pai fun controlnets. (#11062) These are not actual controlnets so put it in the models/model_patches folder and use the ModelPatchLoader + QwenImageDiffsynthControlnet node to use it. --- comfy/ldm/lumina/controlnet.py | 113 ++++++++++++++++++++++++++++++ comfy/ldm/lumina/model.py | 24 ++++--- comfy_extras/nodes_model_patch.py | 101 +++++++++++++++++++++++++- 3 files changed, 229 insertions(+), 9 deletions(-) create mode 100644 comfy/ldm/lumina/controlnet.py diff --git a/comfy/ldm/lumina/controlnet.py b/comfy/ldm/lumina/controlnet.py new file mode 100644 index 000000000..fd7ce3b5c --- /dev/null +++ b/comfy/ldm/lumina/controlnet.py @@ -0,0 +1,113 @@ +import torch +from torch import nn + +from .model import JointTransformerBlock + +class ZImageControlTransformerBlock(JointTransformerBlock): + def __init__( + self, + layer_id: int, + dim: int, + n_heads: int, + n_kv_heads: int, + multiple_of: int, + ffn_dim_multiplier: float, + norm_eps: float, + qk_norm: bool, + modulation=True, + block_id=0, + operation_settings=None, + ): + super().__init__(layer_id, dim, n_heads, n_kv_heads, multiple_of, ffn_dim_multiplier, norm_eps, qk_norm, modulation, z_image_modulation=True, operation_settings=operation_settings) + self.block_id = block_id + if block_id == 0: + self.before_proj = operation_settings.get("operations").Linear(self.dim, self.dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + self.after_proj = operation_settings.get("operations").Linear(self.dim, self.dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + + def forward(self, c, x, **kwargs): + if self.block_id == 0: + c = self.before_proj(c) + x + c = super().forward(c, **kwargs) + c_skip = self.after_proj(c) + return c_skip, c + +class ZImage_Control(torch.nn.Module): + def __init__( + self, + dim: int = 3840, + n_heads: int = 30, + n_kv_heads: int = 30, + multiple_of: int = 256, + ffn_dim_multiplier: float = (8.0 / 3.0), + norm_eps: float = 1e-5, + qk_norm: bool = True, + dtype=None, + device=None, + operations=None, + **kwargs + ): + super().__init__() + operation_settings = {"operations": operations, "device": device, "dtype": dtype} + + self.additional_in_dim = 0 + self.control_in_dim = 16 + n_refiner_layers = 2 + self.n_control_layers = 6 + self.control_layers = nn.ModuleList( + [ + ZImageControlTransformerBlock( + i, + dim, + n_heads, + n_kv_heads, + multiple_of, + ffn_dim_multiplier, + norm_eps, + qk_norm, + block_id=i, + operation_settings=operation_settings, + ) + for i in range(self.n_control_layers) + ] + ) + + all_x_embedder = {} + patch_size = 2 + f_patch_size = 1 + x_embedder = operations.Linear(f_patch_size * patch_size * patch_size * self.control_in_dim, dim, bias=True, device=device, dtype=dtype) + all_x_embedder[f"{patch_size}-{f_patch_size}"] = x_embedder + + self.control_all_x_embedder = nn.ModuleDict(all_x_embedder) + self.control_noise_refiner = nn.ModuleList( + [ + JointTransformerBlock( + layer_id, + dim, + n_heads, + n_kv_heads, + multiple_of, + ffn_dim_multiplier, + norm_eps, + qk_norm, + modulation=True, + z_image_modulation=True, + operation_settings=operation_settings, + ) + for layer_id in range(n_refiner_layers) + ] + ) + + def forward(self, cap_feats, control_context, x_freqs_cis, adaln_input): + patch_size = 2 + f_patch_size = 1 + pH = pW = patch_size + B, C, H, W = control_context.shape + control_context = self.control_all_x_embedder[f"{patch_size}-{f_patch_size}"](control_context.view(B, C, H // pH, pH, W // pW, pW).permute(0, 2, 4, 3, 5, 1).flatten(3).flatten(1, 2)) + + x_attn_mask = None + for layer in self.control_noise_refiner: + control_context = layer(control_context, x_attn_mask, x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input) + return control_context + + def forward_control_block(self, layer_id, control_context, x, x_attn_mask, x_freqs_cis, adaln_input): + return self.control_layers[layer_id](control_context, x, x_mask=x_attn_mask, freqs_cis=x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input=adaln_input) diff --git a/comfy/ldm/lumina/model.py b/comfy/ldm/lumina/model.py index 070b5da09..f1c1a0ec3 100644 --- a/comfy/ldm/lumina/model.py +++ b/comfy/ldm/lumina/model.py @@ -568,7 +568,7 @@ class NextDiT(nn.Module): ).execute(x, timesteps, context, num_tokens, attention_mask, **kwargs) # def forward(self, x, t, cap_feats, cap_mask): - def _forward(self, x, timesteps, context, num_tokens, attention_mask=None, **kwargs): + def _forward(self, x, timesteps, context, num_tokens, attention_mask=None, transformer_options={}, **kwargs): t = 1.0 - timesteps cap_feats = context cap_mask = attention_mask @@ -585,16 +585,24 @@ class NextDiT(nn.Module): cap_feats = self.cap_embedder(cap_feats) # (N, L, D) # todo check if able to batchify w.o. redundant compute + patches = transformer_options.get("patches", {}) transformer_options = kwargs.get("transformer_options", {}) x_is_tensor = isinstance(x, torch.Tensor) - x, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, t, num_tokens, transformer_options=transformer_options) - freqs_cis = freqs_cis.to(x.device) + img, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, t, num_tokens, transformer_options=transformer_options) + freqs_cis = freqs_cis.to(img.device) - for layer in self.layers: - x = layer(x, mask, freqs_cis, adaln_input, transformer_options=transformer_options) + for i, layer in enumerate(self.layers): + img = layer(img, mask, freqs_cis, adaln_input, transformer_options=transformer_options) + if "double_block" in patches: + for p in patches["double_block"]: + out = p({"img": img[:, cap_size[0]:], "txt": img[:, :cap_size[0]], "pe": freqs_cis[:, cap_size[0]:], "vec": adaln_input, "x": x, "block_index": i, "transformer_options": transformer_options}) + if "img" in out: + img[:, cap_size[0]:] = out["img"] + if "txt" in out: + img[:, :cap_size[0]] = out["txt"] - x = self.final_layer(x, adaln_input) - x = self.unpatchify(x, img_size, cap_size, return_tensor=x_is_tensor)[:,:,:h,:w] + img = self.final_layer(img, adaln_input) + img = self.unpatchify(img, img_size, cap_size, return_tensor=x_is_tensor)[:, :, :h, :w] - return -x + return -img diff --git a/comfy_extras/nodes_model_patch.py b/comfy_extras/nodes_model_patch.py index 783c59b6b..c61810dbf 100644 --- a/comfy_extras/nodes_model_patch.py +++ b/comfy_extras/nodes_model_patch.py @@ -6,6 +6,7 @@ import comfy.ops import comfy.model_management import comfy.ldm.common_dit import comfy.latent_formats +import comfy.ldm.lumina.controlnet class BlockWiseControlBlock(torch.nn.Module): @@ -189,6 +190,35 @@ class SigLIPMultiFeatProjModel(torch.nn.Module): return embedding +def z_image_convert(sd): + replace_keys = {".attention.to_out.0.bias": ".attention.out.bias", + ".attention.norm_k.weight": ".attention.k_norm.weight", + ".attention.norm_q.weight": ".attention.q_norm.weight", + ".attention.to_out.0.weight": ".attention.out.weight" + } + + out_sd = {} + for k in sorted(sd.keys()): + w = sd[k] + + k_out = k + if k_out.endswith(".attention.to_k.weight"): + cc = [w] + continue + if k_out.endswith(".attention.to_q.weight"): + cc = [w] + cc + continue + if k_out.endswith(".attention.to_v.weight"): + cc = cc + [w] + w = torch.cat(cc, dim=0) + k_out = k_out.replace(".attention.to_v.weight", ".attention.qkv.weight") + + for r, rr in replace_keys.items(): + k_out = k_out.replace(r, rr) + out_sd[k_out] = w + + return out_sd + class ModelPatchLoader: @classmethod def INPUT_TYPES(s): @@ -211,6 +241,9 @@ class ModelPatchLoader: elif 'feature_embedder.mid_layer_norm.bias' in sd: sd = comfy.utils.state_dict_prefix_replace(sd, {"feature_embedder.": ""}, filter_keys=True) model = SigLIPMultiFeatProjModel(device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast) + elif 'control_all_x_embedder.2-1.weight' in sd: # alipai z image fun controlnet + sd = z_image_convert(sd) + model = comfy.ldm.lumina.controlnet.ZImage_Control(device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast) model.load_state_dict(sd) model = comfy.model_patcher.ModelPatcher(model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device()) @@ -263,6 +296,69 @@ class DiffSynthCnetPatch: def models(self): return [self.model_patch] +class ZImageControlPatch: + def __init__(self, model_patch, vae, image, strength): + self.model_patch = model_patch + self.vae = vae + self.image = image + self.strength = strength + self.encoded_image = self.encode_latent_cond(image) + self.encoded_image_size = (image.shape[1], image.shape[2]) + self.temp_data = None + + def encode_latent_cond(self, image): + latent_image = comfy.latent_formats.Flux().process_in(self.vae.encode(image)) + return latent_image + + def __call__(self, kwargs): + x = kwargs.get("x") + img = kwargs.get("img") + txt = kwargs.get("txt") + pe = kwargs.get("pe") + vec = kwargs.get("vec") + block_index = kwargs.get("block_index") + spacial_compression = self.vae.spacial_compression_encode() + if self.encoded_image is None or self.encoded_image_size != (x.shape[-2] * spacial_compression, x.shape[-1] * spacial_compression): + image_scaled = comfy.utils.common_upscale(self.image.movedim(-1, 1), x.shape[-1] * spacial_compression, x.shape[-2] * spacial_compression, "area", "center") + loaded_models = comfy.model_management.loaded_models(only_currently_used=True) + self.encoded_image = self.encode_latent_cond(image_scaled.movedim(1, -1)) + self.encoded_image_size = (image_scaled.shape[-2], image_scaled.shape[-1]) + comfy.model_management.load_models_gpu(loaded_models) + + cnet_index = (block_index // 5) + cnet_index_float = (block_index / 5) + + kwargs.pop("img") # we do ops in place + kwargs.pop("txt") + + cnet_blocks = self.model_patch.model.n_control_layers + if cnet_index_float > (cnet_blocks - 1): + self.temp_data = None + return kwargs + + if self.temp_data is None or self.temp_data[0] > cnet_index: + self.temp_data = (-1, (None, self.model_patch.model(txt, self.encoded_image.to(img.dtype), pe, vec))) + + while self.temp_data[0] < cnet_index and (self.temp_data[0] + 1) < cnet_blocks: + next_layer = self.temp_data[0] + 1 + self.temp_data = (next_layer, self.model_patch.model.forward_control_block(next_layer, self.temp_data[1][1], img[:, :self.temp_data[1][1].shape[1]], None, pe, vec)) + + if cnet_index_float == self.temp_data[0]: + img[:, :self.temp_data[1][0].shape[1]] += (self.temp_data[1][0] * self.strength) + if cnet_blocks == self.temp_data[0] + 1: + self.temp_data = None + + return kwargs + + def to(self, device_or_dtype): + if isinstance(device_or_dtype, torch.device): + self.encoded_image = self.encoded_image.to(device_or_dtype) + self.temp_data = None + return self + + def models(self): + return [self.model_patch] + class QwenImageDiffsynthControlnet: @classmethod def INPUT_TYPES(s): @@ -289,7 +385,10 @@ class QwenImageDiffsynthControlnet: mask = mask.unsqueeze(2) mask = 1.0 - mask - model_patched.set_model_double_block_patch(DiffSynthCnetPatch(model_patch, vae, image, strength, mask)) + if isinstance(model_patch.model, comfy.ldm.lumina.controlnet.ZImage_Control): + model_patched.set_model_double_block_patch(ZImageControlPatch(model_patch, vae, image, strength)) + else: + model_patched.set_model_double_block_patch(DiffSynthCnetPatch(model_patch, vae, image, strength, mask)) return (model_patched,) From 3f512f5659cfbb3c53999cde6ff557591740252b Mon Sep 17 00:00:00 2001 From: Jim Heising Date: Tue, 2 Dec 2025 19:29:27 -0800 Subject: [PATCH 194/231] Added PATCH method to CORS headers (#11066) Added PATCH http method to access-control-allow-header-methods header because there are now PATCH endpoints exposed in the API. See https://github.com/comfyanonymous/ComfyUI/blob/277237ccc1499bac7fcd221a666dfe7a32ac4206/api_server/routes/internal/internal_routes.py#L34 for an example of an API endpoint that uses the PATCH method. --- server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server.py b/server.py index e3bd056d9..ac4f42222 100644 --- a/server.py +++ b/server.py @@ -98,7 +98,7 @@ def create_cors_middleware(allowed_origin: str): response = await handler(request) response.headers['Access-Control-Allow-Origin'] = allowed_origin - response.headers['Access-Control-Allow-Methods'] = 'POST, GET, DELETE, PUT, OPTIONS' + response.headers['Access-Control-Allow-Methods'] = 'POST, GET, DELETE, PUT, OPTIONS, PATCH' response.headers['Access-Control-Allow-Headers'] = 'Content-Type, Authorization' response.headers['Access-Control-Allow-Credentials'] = 'true' return response From 73f5649196f472d3719e2e7513e0a9d029cc3e38 Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Wed, 3 Dec 2025 13:49:29 +1000 Subject: [PATCH 195/231] Implement temporal rolling VAE (Major VRAM reductions in Hunyuan and Kandinsky) (#10995) * hunyuan upsampler: rework imports Remove the transitive import of VideoConv3d and Resnet and takes these from actual implementation source. * model: remove unused give_pre_end According to git grep, this is not used now, and was not used in the initial commit that introduced it (see below). This semantic is difficult to implement temporal roll VAE for (and would defeat the purpose). Rather than implement the complex if, just delete the unused feature. (venv) rattus@rattus-box2:~/ComfyUI$ git log --oneline 220afe33 (HEAD) Initial commit. (venv) rattus@rattus-box2:~/ComfyUI$ git grep give_pre comfy/ldm/modules/diffusionmodules/model.py: resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False, comfy/ldm/modules/diffusionmodules/model.py: self.give_pre_end = give_pre_end comfy/ldm/modules/diffusionmodules/model.py: if self.give_pre_end: (venv) rattus@rattus-box2:~/ComfyUI$ git co origin/master Previous HEAD position was 220afe33 Initial commit. HEAD is now at 9d8a8179 Enable async offloading by default on Nvidia. (#10953) (venv) rattus@rattus-box2:~/ComfyUI$ git grep give_pre comfy/ldm/modules/diffusionmodules/model.py: resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False, comfy/ldm/modules/diffusionmodules/model.py: self.give_pre_end = give_pre_end comfy/ldm/modules/diffusionmodules/model.py: if self.give_pre_end: * move refiner VAE temporal roller to core Move the carrying conv op to the common VAE code and give it a better name. Roll the carry implementation logic for Resnet into the base class and scrap the Hunyuan specific subclass. * model: Add temporal roll to main VAE decoder If there are no attention layers, its a standard resnet and VideoConv3d is asked for, substitute in the temporal rolloing VAE algorithm. This reduces VAE usage by the temporal dimension (can be huge VRAM savings). * model: Add temporal roll to main VAE encoder If there are no attention layers, its a standard resnet and VideoConv3d is asked for, substitute in the temporal rolling VAE algorithm. This reduces VAE usage by the temporal dimension (can be huge VRAM savings). --- comfy/ldm/hunyuan_video/upsampler.py | 3 +- comfy/ldm/hunyuan_video/vae_refiner.py | 94 +++------ comfy/ldm/modules/diffusionmodules/model.py | 207 ++++++++++++++------ 3 files changed, 174 insertions(+), 130 deletions(-) diff --git a/comfy/ldm/hunyuan_video/upsampler.py b/comfy/ldm/hunyuan_video/upsampler.py index 9f5e91a59..85f515f67 100644 --- a/comfy/ldm/hunyuan_video/upsampler.py +++ b/comfy/ldm/hunyuan_video/upsampler.py @@ -1,7 +1,8 @@ import torch import torch.nn as nn import torch.nn.functional as F -from comfy.ldm.hunyuan_video.vae_refiner import RMS_norm, ResnetBlock, VideoConv3d +from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, VideoConv3d +from comfy.ldm.hunyuan_video.vae_refiner import RMS_norm import model_management, model_patcher class SRResidualCausalBlock3D(nn.Module): diff --git a/comfy/ldm/hunyuan_video/vae_refiner.py b/comfy/ldm/hunyuan_video/vae_refiner.py index 9f750dcc4..ddf77cd0e 100644 --- a/comfy/ldm/hunyuan_video/vae_refiner.py +++ b/comfy/ldm/hunyuan_video/vae_refiner.py @@ -1,42 +1,12 @@ import torch import torch.nn as nn import torch.nn.functional as F -from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, AttnBlock, VideoConv3d, Normalize +from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, AttnBlock, CarriedConv3d, Normalize, conv_carry_causal_3d, torch_cat_if_needed import comfy.ops import comfy.ldm.models.autoencoder import comfy.model_management ops = comfy.ops.disable_weight_init -class NoPadConv3d(nn.Module): - def __init__(self, n_channels, out_channels, kernel_size, stride=1, dilation=1, padding=0, **kwargs): - super().__init__() - self.conv = ops.Conv3d(n_channels, out_channels, kernel_size, stride=stride, dilation=dilation, **kwargs) - - def forward(self, x): - return self.conv(x) - - -def conv_carry_causal_3d(xl, op, conv_carry_in=None, conv_carry_out=None): - - x = xl[0] - xl.clear() - - if conv_carry_out is not None: - to_push = x[:, :, -2:, :, :].clone() - conv_carry_out.append(to_push) - - if isinstance(op, NoPadConv3d): - if conv_carry_in is None: - x = torch.nn.functional.pad(x, (1, 1, 1, 1, 2, 0), mode = 'replicate') - else: - carry_len = conv_carry_in[0].shape[2] - x = torch.cat([conv_carry_in.pop(0), x], dim=2) - x = torch.nn.functional.pad(x, (1, 1, 1, 1, 2 - carry_len, 0), mode = 'replicate') - - out = op(x) - - return out - class RMS_norm(nn.Module): def __init__(self, dim): @@ -49,7 +19,7 @@ class RMS_norm(nn.Module): return F.normalize(x, dim=1) * self.scale * comfy.model_management.cast_to(self.gamma, dtype=x.dtype, device=x.device) class DnSmpl(nn.Module): - def __init__(self, ic, oc, tds=True, refiner_vae=True, op=VideoConv3d): + def __init__(self, ic, oc, tds, refiner_vae, op): super().__init__() fct = 2 * 2 * 2 if tds else 1 * 2 * 2 assert oc % fct == 0 @@ -109,7 +79,7 @@ class DnSmpl(nn.Module): class UpSmpl(nn.Module): - def __init__(self, ic, oc, tus=True, refiner_vae=True, op=VideoConv3d): + def __init__(self, ic, oc, tus, refiner_vae, op): super().__init__() fct = 2 * 2 * 2 if tus else 1 * 2 * 2 self.conv = op(ic, oc * fct, kernel_size=3, stride=1, padding=1) @@ -163,23 +133,6 @@ class UpSmpl(nn.Module): return h + x -class HunyuanRefinerResnetBlock(ResnetBlock): - def __init__(self, in_channels, out_channels, conv_op=NoPadConv3d, norm_op=RMS_norm): - super().__init__(in_channels=in_channels, out_channels=out_channels, temb_channels=0, conv_op=conv_op, norm_op=norm_op) - - def forward(self, x, conv_carry_in=None, conv_carry_out=None): - h = x - h = [ self.swish(self.norm1(x)) ] - h = conv_carry_causal_3d(h, self.conv1, conv_carry_in=conv_carry_in, conv_carry_out=conv_carry_out) - - h = [ self.dropout(self.swish(self.norm2(h))) ] - h = conv_carry_causal_3d(h, self.conv2, conv_carry_in=conv_carry_in, conv_carry_out=conv_carry_out) - - if self.in_channels != self.out_channels: - x = self.nin_shortcut(x) - - return x+h - class Encoder(nn.Module): def __init__(self, in_channels, z_channels, block_out_channels, num_res_blocks, ffactor_spatial, ffactor_temporal, downsample_match_channel=True, refiner_vae=True, **_): @@ -191,7 +144,7 @@ class Encoder(nn.Module): self.refiner_vae = refiner_vae if self.refiner_vae: - conv_op = NoPadConv3d + conv_op = CarriedConv3d norm_op = RMS_norm else: conv_op = ops.Conv3d @@ -206,9 +159,10 @@ class Encoder(nn.Module): for i, tgt in enumerate(block_out_channels): stage = nn.Module() - stage.block = nn.ModuleList([HunyuanRefinerResnetBlock(in_channels=ch if j == 0 else tgt, - out_channels=tgt, - conv_op=conv_op, norm_op=norm_op) + stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt, + out_channels=tgt, + temb_channels=0, + conv_op=conv_op, norm_op=norm_op) for j in range(num_res_blocks)]) ch = tgt if i < depth: @@ -218,9 +172,9 @@ class Encoder(nn.Module): self.down.append(stage) self.mid = nn.Module() - self.mid.block_1 = HunyuanRefinerResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op) + self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op) self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=norm_op) - self.mid.block_2 = HunyuanRefinerResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op) + self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op) self.norm_out = norm_op(ch) self.conv_out = conv_op(ch, z_channels << 1, 3, 1, 1) @@ -246,22 +200,20 @@ class Encoder(nn.Module): conv_carry_out = [] if i == len(x) - 1: conv_carry_out = None + x1 = [ x1 ] x1 = conv_carry_causal_3d(x1, self.conv_in, conv_carry_in, conv_carry_out) for stage in self.down: for blk in stage.block: - x1 = blk(x1, conv_carry_in, conv_carry_out) + x1 = blk(x1, None, conv_carry_in, conv_carry_out) if hasattr(stage, 'downsample'): x1 = stage.downsample(x1, conv_carry_in, conv_carry_out) out.append(x1) conv_carry_in = conv_carry_out - if len(out) > 1: - out = torch.cat(out, dim=2) - else: - out = out[0] + out = torch_cat_if_needed(out, dim=2) x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(out))) del out @@ -288,7 +240,7 @@ class Decoder(nn.Module): self.refiner_vae = refiner_vae if self.refiner_vae: - conv_op = NoPadConv3d + conv_op = CarriedConv3d norm_op = RMS_norm else: conv_op = ops.Conv3d @@ -298,9 +250,9 @@ class Decoder(nn.Module): self.conv_in = conv_op(z_channels, ch, kernel_size=3, stride=1, padding=1) self.mid = nn.Module() - self.mid.block_1 = HunyuanRefinerResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op) + self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op) self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=norm_op) - self.mid.block_2 = HunyuanRefinerResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op) + self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op) self.up = nn.ModuleList() depth = (ffactor_spatial >> 1).bit_length() @@ -308,9 +260,10 @@ class Decoder(nn.Module): for i, tgt in enumerate(block_out_channels): stage = nn.Module() - stage.block = nn.ModuleList([HunyuanRefinerResnetBlock(in_channels=ch if j == 0 else tgt, - out_channels=tgt, - conv_op=conv_op, norm_op=norm_op) + stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt, + out_channels=tgt, + temb_channels=0, + conv_op=conv_op, norm_op=norm_op) for j in range(num_res_blocks + 1)]) ch = tgt if i < depth: @@ -340,7 +293,7 @@ class Decoder(nn.Module): conv_carry_out = None for stage in self.up: for blk in stage.block: - x1 = blk(x1, conv_carry_in, conv_carry_out) + x1 = blk(x1, None, conv_carry_in, conv_carry_out) if hasattr(stage, 'upsample'): x1 = stage.upsample(x1, conv_carry_in, conv_carry_out) @@ -350,10 +303,7 @@ class Decoder(nn.Module): conv_carry_in = conv_carry_out del x - if len(out) > 1: - out = torch.cat(out, dim=2) - else: - out = out[0] + out = torch_cat_if_needed(out, dim=2) if not self.refiner_vae: if z.shape[-3] == 1: diff --git a/comfy/ldm/modules/diffusionmodules/model.py b/comfy/ldm/modules/diffusionmodules/model.py index de1e01cc8..681a55db5 100644 --- a/comfy/ldm/modules/diffusionmodules/model.py +++ b/comfy/ldm/modules/diffusionmodules/model.py @@ -13,6 +13,12 @@ if model_management.xformers_enabled_vae(): import xformers import xformers.ops +def torch_cat_if_needed(xl, dim): + if len(xl) > 1: + return torch.cat(xl, dim) + else: + return xl[0] + def get_timestep_embedding(timesteps, embedding_dim): """ This matches the implementation in Denoising Diffusion Probabilistic Models: @@ -43,6 +49,37 @@ def Normalize(in_channels, num_groups=32): return ops.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True) +class CarriedConv3d(nn.Module): + def __init__(self, n_channels, out_channels, kernel_size, stride=1, dilation=1, padding=0, **kwargs): + super().__init__() + self.conv = ops.Conv3d(n_channels, out_channels, kernel_size, stride=stride, dilation=dilation, **kwargs) + + def forward(self, x): + return self.conv(x) + + +def conv_carry_causal_3d(xl, op, conv_carry_in=None, conv_carry_out=None): + + x = xl[0] + xl.clear() + + if isinstance(op, CarriedConv3d): + if conv_carry_in is None: + x = torch.nn.functional.pad(x, (1, 1, 1, 1, 2, 0), mode = 'replicate') + else: + carry_len = conv_carry_in[0].shape[2] + x = torch.nn.functional.pad(x, (1, 1, 1, 1, 2 - carry_len, 0), mode = 'replicate') + x = torch.cat([conv_carry_in.pop(0), x], dim=2) + + if conv_carry_out is not None: + to_push = x[:, :, -2:, :, :].clone() + conv_carry_out.append(to_push) + + out = op(x) + + return out + + class VideoConv3d(nn.Module): def __init__(self, n_channels, out_channels, kernel_size, stride=1, dilation=1, padding_mode='replicate', padding=1, **kwargs): super().__init__() @@ -89,29 +126,24 @@ class Upsample(nn.Module): stride=1, padding=1) - def forward(self, x): + def forward(self, x, conv_carry_in=None, conv_carry_out=None): scale_factor = self.scale_factor if isinstance(scale_factor, (int, float)): scale_factor = (scale_factor,) * (x.ndim - 2) if x.ndim == 5 and scale_factor[0] > 1.0: - t = x.shape[2] - if t > 1: - a, b = x.split((1, t - 1), dim=2) - del x - b = interpolate_up(b, scale_factor) - else: - a = x - - a = interpolate_up(a.squeeze(2), scale_factor=scale_factor[1:]).unsqueeze(2) - if t > 1: - x = torch.cat((a, b), dim=2) - else: - x = a + results = [] + if conv_carry_in is None: + first = x[:, :, :1, :, :] + results.append(interpolate_up(first.squeeze(2), scale_factor=scale_factor[1:]).unsqueeze(2)) + x = x[:, :, 1:, :, :] + if x.shape[2] > 0: + results.append(interpolate_up(x, scale_factor)) + x = torch_cat_if_needed(results, dim=2) else: x = interpolate_up(x, scale_factor) if self.with_conv: - x = self.conv(x) + x = conv_carry_causal_3d([x], self.conv, conv_carry_in, conv_carry_out) return x @@ -127,17 +159,20 @@ class Downsample(nn.Module): stride=stride, padding=0) - def forward(self, x): + def forward(self, x, conv_carry_in=None, conv_carry_out=None): if self.with_conv: - if x.ndim == 4: + if isinstance(self.conv, CarriedConv3d): + x = conv_carry_causal_3d([x], self.conv, conv_carry_in, conv_carry_out) + elif x.ndim == 4: pad = (0, 1, 0, 1) mode = "constant" x = torch.nn.functional.pad(x, pad, mode=mode, value=0) + x = self.conv(x) elif x.ndim == 5: pad = (1, 1, 1, 1, 2, 0) mode = "replicate" x = torch.nn.functional.pad(x, pad, mode=mode) - x = self.conv(x) + x = self.conv(x) else: x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) return x @@ -183,23 +218,23 @@ class ResnetBlock(nn.Module): stride=1, padding=0) - def forward(self, x, temb=None): + def forward(self, x, temb=None, conv_carry_in=None, conv_carry_out=None): h = x h = self.norm1(h) - h = self.swish(h) - h = self.conv1(h) + h = [ self.swish(h) ] + h = conv_carry_causal_3d(h, self.conv1, conv_carry_in=conv_carry_in, conv_carry_out=conv_carry_out) if temb is not None: h = h + self.temb_proj(self.swish(temb))[:,:,None,None] h = self.norm2(h) h = self.swish(h) - h = self.dropout(h) - h = self.conv2(h) + h = [ self.dropout(h) ] + h = conv_carry_causal_3d(h, self.conv2, conv_carry_in=conv_carry_in, conv_carry_out=conv_carry_out) if self.in_channels != self.out_channels: if self.use_conv_shortcut: - x = self.conv_shortcut(x) + x = conv_carry_causal_3d([x], self.conv_shortcut, conv_carry_in=conv_carry_in, conv_carry_out=conv_carry_out) else: x = self.nin_shortcut(x) @@ -520,9 +555,14 @@ class Encoder(nn.Module): self.num_res_blocks = num_res_blocks self.resolution = resolution self.in_channels = in_channels + self.carried = False if conv3d: - conv_op = VideoConv3d + if not attn_resolutions: + conv_op = CarriedConv3d + self.carried = True + else: + conv_op = VideoConv3d mid_attn_conv_op = ops.Conv3d else: conv_op = ops.Conv2d @@ -535,6 +575,7 @@ class Encoder(nn.Module): stride=1, padding=1) + self.time_compress = 1 curr_res = resolution in_ch_mult = (1,)+tuple(ch_mult) self.in_ch_mult = in_ch_mult @@ -561,10 +602,15 @@ class Encoder(nn.Module): if time_compress is not None: if (self.num_resolutions - 1 - i_level) > math.log2(time_compress): stride = (1, 2, 2) + else: + self.time_compress *= 2 down.downsample = Downsample(block_in, resamp_with_conv, stride=stride, conv_op=conv_op) curr_res = curr_res // 2 self.down.append(down) + if time_compress is not None: + self.time_compress = time_compress + # middle self.mid = nn.Module() self.mid.block_1 = ResnetBlock(in_channels=block_in, @@ -590,15 +636,42 @@ class Encoder(nn.Module): def forward(self, x): # timestep embedding temb = None - # downsampling - h = self.conv_in(x) - for i_level in range(self.num_resolutions): - for i_block in range(self.num_res_blocks): - h = self.down[i_level].block[i_block](h, temb) - if len(self.down[i_level].attn) > 0: - h = self.down[i_level].attn[i_block](h) - if i_level != self.num_resolutions-1: - h = self.down[i_level].downsample(h) + + if self.carried: + xl = [x[:, :, :1, :, :]] + if x.shape[2] > self.time_compress: + tc = self.time_compress + xl += torch.split(x[:, :, 1: 1 + ((x.shape[2] - 1) // tc) * tc, :, :], tc * 2, dim = 2) + x = xl + else: + x = [x] + out = [] + + conv_carry_in = None + + for i, x1 in enumerate(x): + conv_carry_out = [] + if i == len(x) - 1: + conv_carry_out = None + + # downsampling + x1 = [ x1 ] + h1 = conv_carry_causal_3d(x1, self.conv_in, conv_carry_in, conv_carry_out) + + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h1 = self.down[i_level].block[i_block](h1, temb, conv_carry_in, conv_carry_out) + if len(self.down[i_level].attn) > 0: + assert i == 0 #carried should not happen if attn exists + h1 = self.down[i_level].attn[i_block](h1) + if i_level != self.num_resolutions-1: + h1 = self.down[i_level].downsample(h1, conv_carry_in, conv_carry_out) + + out.append(h1) + conv_carry_in = conv_carry_out + + h = torch_cat_if_needed(out, dim=2) + del out # middle h = self.mid.block_1(h, temb) @@ -607,15 +680,15 @@ class Encoder(nn.Module): # end h = self.norm_out(h) - h = nonlinearity(h) - h = self.conv_out(h) + h = [ nonlinearity(h) ] + h = conv_carry_causal_3d(h, self.conv_out) return h class Decoder(nn.Module): def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, - resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False, + resolution, z_channels, tanh_out=False, use_linear_attn=False, conv_out_op=ops.Conv2d, resnet_op=ResnetBlock, attn_op=AttnBlock, @@ -629,12 +702,18 @@ class Decoder(nn.Module): self.num_res_blocks = num_res_blocks self.resolution = resolution self.in_channels = in_channels - self.give_pre_end = give_pre_end self.tanh_out = tanh_out + self.carried = False if conv3d: - conv_op = VideoConv3d - conv_out_op = VideoConv3d + if not attn_resolutions and resnet_op == ResnetBlock: + conv_op = CarriedConv3d + conv_out_op = CarriedConv3d + self.carried = True + else: + conv_op = VideoConv3d + conv_out_op = VideoConv3d + mid_attn_conv_op = ops.Conv3d else: conv_op = ops.Conv2d @@ -709,29 +788,43 @@ class Decoder(nn.Module): temb = None # z to block_in - h = self.conv_in(z) + h = conv_carry_causal_3d([z], self.conv_in) # middle h = self.mid.block_1(h, temb, **kwargs) h = self.mid.attn_1(h, **kwargs) h = self.mid.block_2(h, temb, **kwargs) + if self.carried: + h = torch.split(h, 2, dim=2) + else: + h = [ h ] + out = [] + + conv_carry_in = None + # upsampling - for i_level in reversed(range(self.num_resolutions)): - for i_block in range(self.num_res_blocks+1): - h = self.up[i_level].block[i_block](h, temb, **kwargs) - if len(self.up[i_level].attn) > 0: - h = self.up[i_level].attn[i_block](h, **kwargs) - if i_level != 0: - h = self.up[i_level].upsample(h) + for i, h1 in enumerate(h): + conv_carry_out = [] + if i == len(h) - 1: + conv_carry_out = None + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks+1): + h1 = self.up[i_level].block[i_block](h1, temb, conv_carry_in, conv_carry_out, **kwargs) + if len(self.up[i_level].attn) > 0: + assert i == 0 #carried should not happen if attn exists + h1 = self.up[i_level].attn[i_block](h1, **kwargs) + if i_level != 0: + h1 = self.up[i_level].upsample(h1, conv_carry_in, conv_carry_out) - # end - if self.give_pre_end: - return h + h1 = self.norm_out(h1) + h1 = [ nonlinearity(h1) ] + h1 = conv_carry_causal_3d(h1, self.conv_out, conv_carry_in, conv_carry_out) + if self.tanh_out: + h1 = torch.tanh(h1) + out.append(h1) + conv_carry_in = conv_carry_out - h = self.norm_out(h) - h = nonlinearity(h) - h = self.conv_out(h, **kwargs) - if self.tanh_out: - h = torch.tanh(h) - return h + out = torch_cat_if_needed(out, dim=2) + + return out From c120eee5bacca643062657d2a7efad83c7d4d828 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Tue, 2 Dec 2025 21:17:13 -0800 Subject: [PATCH 196/231] Add MatchType, DynamicCombo, and Autogrow support to V3 Schema (#10832) * Added output_matchtypes to generated json for v3, initial backend support for MatchType, created nodes_logic.py and added SwitchNode * Fixed providing list of allowed_types * Add workaround in validation.py for V3 Combo outputs not working as Combo inputs * Make match type receive_type pass validation * Also add MatchType check to input_type in validation - will likely trigger when connecting to non-lazy stuff * Make sure this PR only has MatchType stuff * Initial work on DynamicCombo * Add get_dynamic function, not yet filled out correctly * Mark Switch node as Beta * Make sure other unfinished dynamic types are not accidentally used * Send DynamicCombo.Option inputs in the same format as normal v1 inputs * add dynamic combo test node * Support validation of inputs and outputs * Add missing input params to DynamicCombo.Input * Add get_all function to inputs for id validation purposes * Fix imports for v3 returning everything when doing io/ui/IO/UI instead of what is in __all__ of _io.py and _ui.py * Modifying behavior of get_dynamic in V3 + serialization so can be used in execution code * Fix v3 schema validation code after changes * Refactor hidden_values for v3 in execution.py to be more general v3_data, add helper functions for dynamic behavior, preparing for restructuring dynamic type into object (not finished yet) * Add nesting of inputs on DynamicCombo during execution * Work with latest frontend commits * Fix cringe arrows * frontend will no longer namespace dynamic inputs widgets so reflect that in code, refactor build_nested_inputs * Prepare Autogrow support for the love of the game * satisfy ruff * Create test nodes for Autogrow to collab with frontend development * Add nested combo to DCTestNode * Remove array support from build_nested_inputs, properly handle missing expected values * Make execution.validate_inputs properly validate required dynamic inputs, renamed dynamic_data to dynamic_paths for clarity * MatchType does not need any DynamicInput/Output features on backend; will increase compatibility with dynamic types * Probably need this for ruff check * Change MatchType to have template be the first and only required param; output id's do nothing right now, so no need * Fix merge regression with LatentUpscaleModel type not being put in __all__ for _io.py, fix invalid type hint for validate_inputs * Make Switch node inputs optional, disallow both inputs from being missing, and still work properly with lazy; when one input is missing, use the other no matter what the switch is set to * Satisfy ruff * Move MatchType code above the types that inherit from DynamicInput * Add DynamicSlot type, awaiting frontend support * Make curr_prefix creation happen in Autogrow, move curr_prefix in DynamicCombo to only be created if input exists in live_inputs * I was confused, fixing accidentally redundant curr_prefix addition in Autogrow * Make sure Autogrow inputs are force_input = True when WidgetInput, fix runtime validation by removing original input from expected inputs, fix min/max bounds, change test nodes slightly * Remove unnecessary id usage in Autogrow test node outputs * Commented out Switch node + test nodes * Remove commented out code from Autogrow * Make TemplatePrefix max more clear, allow max == 1 * Replace all dict[str] with dict[str, Any] * Renamed add_to_dict_live_inputs to expand_schema_for_dynamic * Fixed typo in DynamicSlot input code * note about live_inputs not being present soon in get_v1_info (internal function anyway) * For now, hide DynamicCombo and Autogrow from public interface * Removed comment --- comfy_api/latest/__init__.py | 4 +- comfy_api/latest/_io.py | 416 ++++++++++++++++++++++++++------- comfy_api/latest/_io_public.py | 1 + comfy_api/latest/_ui_public.py | 1 + comfy_api/v0_0_2/__init__.py | 6 +- comfy_execution/validation.py | 6 + comfy_extras/nodes_logic.py | 155 ++++++++++++ execution.py | 40 ++-- nodes.py | 1 + 9 files changed, 525 insertions(+), 105 deletions(-) create mode 100644 comfy_api/latest/_io_public.py create mode 100644 comfy_api/latest/_ui_public.py create mode 100644 comfy_extras/nodes_logic.py diff --git a/comfy_api/latest/__init__.py b/comfy_api/latest/__init__.py index 176ae36e0..0fa01d1e7 100644 --- a/comfy_api/latest/__init__.py +++ b/comfy_api/latest/__init__.py @@ -8,8 +8,8 @@ from comfy_api.internal.async_to_sync import create_sync_class from comfy_api.latest._input import ImageInput, AudioInput, MaskInput, LatentInput, VideoInput from comfy_api.latest._input_impl import VideoFromFile, VideoFromComponents from comfy_api.latest._util import VideoCodec, VideoContainer, VideoComponents, MESH, VOXEL -from . import _io as io -from . import _ui as ui +from . import _io_public as io +from . import _ui_public as ui # from comfy_api.latest._resources import _RESOURCES as resources #noqa: F401 from comfy_execution.utils import get_executing_context from comfy_execution.progress import get_progress_state, PreviewImageTuple diff --git a/comfy_api/latest/_io.py b/comfy_api/latest/_io.py index 79c0722a9..257f07c42 100644 --- a/comfy_api/latest/_io.py +++ b/comfy_api/latest/_io.py @@ -4,6 +4,7 @@ import copy import inspect from abc import ABC, abstractmethod from collections import Counter +from collections.abc import Iterable from dataclasses import asdict, dataclass from enum import Enum from typing import Any, Callable, Literal, TypedDict, TypeVar, TYPE_CHECKING @@ -150,6 +151,9 @@ class _IO_V3: def __init__(self): pass + def validate(self): + pass + @property def io_type(self): return self.Parent.io_type @@ -182,6 +186,9 @@ class Input(_IO_V3): def get_io_type(self): return _StringIOType(self.io_type) + def get_all(self) -> list[Input]: + return [self] + class WidgetInput(Input): ''' Base class for a V3 Input with widget. @@ -814,13 +821,61 @@ class MultiType: else: return super().as_dict() +@comfytype(io_type="COMFY_MATCHTYPE_V3") +class MatchType(ComfyTypeIO): + class Template: + def __init__(self, template_id: str, allowed_types: _ComfyType | list[_ComfyType] = AnyType): + self.template_id = template_id + # account for syntactic sugar + if not isinstance(allowed_types, Iterable): + allowed_types = [allowed_types] + for t in allowed_types: + if not isinstance(t, type): + if not isinstance(t, _ComfyType): + raise ValueError(f"Allowed types must be a ComfyType or a list of ComfyTypes, got {t.__class__.__name__}") + else: + if not issubclass(t, _ComfyType): + raise ValueError(f"Allowed types must be a ComfyType or a list of ComfyTypes, got {t.__name__}") + self.allowed_types = allowed_types + + def as_dict(self): + return { + "template_id": self.template_id, + "allowed_types": ",".join([t.io_type for t in self.allowed_types]), + } + + class Input(Input): + def __init__(self, id: str, template: MatchType.Template, + display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None): + super().__init__(id, display_name, optional, tooltip, lazy, extra_dict) + self.template = template + + def as_dict(self): + return super().as_dict() | prune_dict({ + "template": self.template.as_dict(), + }) + + class Output(Output): + def __init__(self, template: MatchType.Template, id: str=None, display_name: str=None, tooltip: str=None, + is_output_list=False): + super().__init__(id, display_name, tooltip, is_output_list) + self.template = template + + def as_dict(self): + return super().as_dict() | prune_dict({ + "template": self.template.as_dict(), + }) + class DynamicInput(Input, ABC): ''' Abstract class for dynamic input registration. ''' - @abstractmethod def get_dynamic(self) -> list[Input]: - ... + return [] + + def expand_schema_for_dynamic(self, d: dict[str, Any], live_inputs: dict[str, Any], curr_prefix=''): + pass + class DynamicOutput(Output, ABC): ''' @@ -830,99 +885,223 @@ class DynamicOutput(Output, ABC): is_output_list=False): super().__init__(id, display_name, tooltip, is_output_list) - @abstractmethod def get_dynamic(self) -> list[Output]: - ... + return [] @comfytype(io_type="COMFY_AUTOGROW_V3") -class AutogrowDynamic(ComfyTypeI): - Type = list[Any] - class Input(DynamicInput): - def __init__(self, id: str, template_input: Input, min: int=1, max: int=None, - display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None): - super().__init__(id, display_name, optional, tooltip, lazy, extra_dict) - self.template_input = template_input - if min is not None: - assert(min >= 1) - if max is not None: - assert(max >= 1) +class Autogrow(ComfyTypeI): + Type = dict[str, Any] + _MaxNames = 100 # NOTE: max 100 names for sanity + + class _AutogrowTemplate: + def __init__(self, input: Input): + # dynamic inputs are not allowed as the template input + assert(not isinstance(input, DynamicInput)) + self.input = copy.copy(input) + if isinstance(self.input, WidgetInput): + self.input.force_input = True + self.names: list[str] = [] + self.cached_inputs = {} + + def _create_input(self, input: Input, name: str): + new_input = copy.copy(self.input) + new_input.id = name + return new_input + + def _create_cached_inputs(self): + for name in self.names: + self.cached_inputs[name] = self._create_input(self.input, name) + + def get_all(self) -> list[Input]: + return list(self.cached_inputs.values()) + + def as_dict(self): + return prune_dict({ + "input": create_input_dict_v1([self.input]), + }) + + def validate(self): + self.input.validate() + + def expand_schema_for_dynamic(self, d: dict[str, Any], live_inputs: dict[str, Any], curr_prefix=''): + real_inputs = [] + for name, input in self.cached_inputs.items(): + if name in live_inputs: + real_inputs.append(input) + add_to_input_dict_v1(d, real_inputs, live_inputs, curr_prefix) + add_dynamic_id_mapping(d, real_inputs, curr_prefix) + + class TemplatePrefix(_AutogrowTemplate): + def __init__(self, input: Input, prefix: str, min: int=1, max: int=10): + super().__init__(input) + self.prefix = prefix + assert(min >= 0) + assert(max >= 1) + assert(max <= Autogrow._MaxNames) self.min = min self.max = max + self.names = [f"{self.prefix}{i}" for i in range(self.max)] + self._create_cached_inputs() + + def as_dict(self): + return super().as_dict() | prune_dict({ + "prefix": self.prefix, + "min": self.min, + "max": self.max, + }) + + class TemplateNames(_AutogrowTemplate): + def __init__(self, input: Input, names: list[str], min: int=1): + super().__init__(input) + self.names = names[:Autogrow._MaxNames] + assert(min >= 0) + self.min = min + self._create_cached_inputs() + + def as_dict(self): + return super().as_dict() | prune_dict({ + "names": self.names, + "min": self.min, + }) + + class Input(DynamicInput): + def __init__(self, id: str, template: Autogrow.TemplatePrefix | Autogrow.TemplateNames, + display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None): + super().__init__(id, display_name, optional, tooltip, lazy, extra_dict) + self.template = template + + def as_dict(self): + return super().as_dict() | prune_dict({ + "template": self.template.as_dict(), + }) def get_dynamic(self) -> list[Input]: - curr_count = 1 - new_inputs = [] - for i in range(self.min): - new_input = copy.copy(self.template_input) - new_input.id = f"{new_input.id}{curr_count}_${self.id}_ag$" - if new_input.display_name is not None: - new_input.display_name = f"{new_input.display_name}{curr_count}" - new_input.optional = self.optional or new_input.optional - if isinstance(self.template_input, WidgetInput): - new_input.force_input = True - new_inputs.append(new_input) - curr_count += 1 - # pretend to expand up to max - for i in range(curr_count-1, self.max): - new_input = copy.copy(self.template_input) - new_input.id = f"{new_input.id}{curr_count}_${self.id}_ag$" - if new_input.display_name is not None: - new_input.display_name = f"{new_input.display_name}{curr_count}" - new_input.optional = True - if isinstance(self.template_input, WidgetInput): - new_input.force_input = True - new_inputs.append(new_input) - curr_count += 1 - return new_inputs + return self.template.get_all() -@comfytype(io_type="COMFY_COMBODYNAMIC_V3") -class ComboDynamic(ComfyTypeI): - class Input(DynamicInput): - def __init__(self, id: str): - pass + def get_all(self) -> list[Input]: + return [self] + self.template.get_all() -@comfytype(io_type="COMFY_MATCHTYPE_V3") -class MatchType(ComfyTypeIO): - class Template: - def __init__(self, template_id: str, allowed_types: _ComfyType | list[_ComfyType]): - self.template_id = template_id - self.allowed_types = [allowed_types] if isinstance(allowed_types, _ComfyType) else allowed_types + def validate(self): + self.template.validate() + + def expand_schema_for_dynamic(self, d: dict[str, Any], live_inputs: dict[str, Any], curr_prefix=''): + curr_prefix = f"{curr_prefix}{self.id}." + # need to remove self from expected inputs dictionary; replaced by template inputs in frontend + for inner_dict in d.values(): + if self.id in inner_dict: + del inner_dict[self.id] + self.template.expand_schema_for_dynamic(d, live_inputs, curr_prefix) + +@comfytype(io_type="COMFY_DYNAMICCOMBO_V3") +class DynamicCombo(ComfyTypeI): + Type = dict[str, Any] + + class Option: + def __init__(self, key: str, inputs: list[Input]): + self.key = key + self.inputs = inputs def as_dict(self): return { - "template_id": self.template_id, - "allowed_types": "".join(t.io_type for t in self.allowed_types), + "key": self.key, + "inputs": create_input_dict_v1(self.inputs), } class Input(DynamicInput): - def __init__(self, id: str, template: MatchType.Template, + def __init__(self, id: str, options: list[DynamicCombo.Option], display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None): super().__init__(id, display_name, optional, tooltip, lazy, extra_dict) - self.template = template + self.options = options + + def expand_schema_for_dynamic(self, d: dict[str, Any], live_inputs: dict[str, Any], curr_prefix=''): + # check if dynamic input's id is in live_inputs + if self.id in live_inputs: + curr_prefix = f"{curr_prefix}{self.id}." + key = live_inputs[self.id] + selected_option = None + for option in self.options: + if option.key == key: + selected_option = option + break + if selected_option is not None: + add_to_input_dict_v1(d, selected_option.inputs, live_inputs, curr_prefix) + add_dynamic_id_mapping(d, selected_option.inputs, curr_prefix, self) def get_dynamic(self) -> list[Input]: - return [self] + return [input for option in self.options for input in option.inputs] + + def get_all(self) -> list[Input]: + return [self] + [input for option in self.options for input in option.inputs] def as_dict(self): return super().as_dict() | prune_dict({ - "template": self.template.as_dict(), + "options": [o.as_dict() for o in self.options], }) - class Output(DynamicOutput): - def __init__(self, id: str, template: MatchType.Template, display_name: str=None, tooltip: str=None, - is_output_list=False): - super().__init__(id, display_name, tooltip, is_output_list) - self.template = template + def validate(self): + # make sure all nested inputs are validated + for option in self.options: + for input in option.inputs: + input.validate() - def get_dynamic(self) -> list[Output]: - return [self] +@comfytype(io_type="COMFY_DYNAMICSLOT_V3") +class DynamicSlot(ComfyTypeI): + Type = dict[str, Any] + + class Input(DynamicInput): + def __init__(self, slot: Input, inputs: list[Input], + display_name: str=None, tooltip: str=None, lazy: bool=None, extra_dict=None): + assert(not isinstance(slot, DynamicInput)) + self.slot = copy.copy(slot) + self.slot.display_name = slot.display_name if slot.display_name is not None else display_name + optional = True + self.slot.tooltip = slot.tooltip if slot.tooltip is not None else tooltip + self.slot.lazy = slot.lazy if slot.lazy is not None else lazy + self.slot.extra_dict = slot.extra_dict if slot.extra_dict is not None else extra_dict + super().__init__(slot.id, self.slot.display_name, optional, self.slot.tooltip, self.slot.lazy, self.slot.extra_dict) + self.inputs = inputs + self.force_input = None + # force widget inputs to have no widgets, otherwise this would be awkward + if isinstance(self.slot, WidgetInput): + self.force_input = True + self.slot.force_input = True + + def expand_schema_for_dynamic(self, d: dict[str, Any], live_inputs: dict[str, Any], curr_prefix=''): + if self.id in live_inputs: + curr_prefix = f"{curr_prefix}{self.id}." + add_to_input_dict_v1(d, self.inputs, live_inputs, curr_prefix) + add_dynamic_id_mapping(d, [self.slot] + self.inputs, curr_prefix) + + def get_dynamic(self) -> list[Input]: + return [self.slot] + self.inputs + + def get_all(self) -> list[Input]: + return [self] + [self.slot] + self.inputs def as_dict(self): return super().as_dict() | prune_dict({ - "template": self.template.as_dict(), + "slotType": str(self.slot.get_io_type()), + "inputs": create_input_dict_v1(self.inputs), + "forceInput": self.force_input, }) + def validate(self): + self.slot.validate() + for input in self.inputs: + input.validate() + +def add_dynamic_id_mapping(d: dict[str, Any], inputs: list[Input], curr_prefix: str, self: DynamicInput=None): + dynamic = d.setdefault("dynamic_paths", {}) + if self is not None: + dynamic[self.id] = f"{curr_prefix}{self.id}" + for i in inputs: + if not isinstance(i, DynamicInput): + dynamic[f"{i.id}"] = f"{curr_prefix}{i.id}" + +class V3Data(TypedDict): + hidden_inputs: dict[str, Any] + dynamic_paths: dict[str, Any] class HiddenHolder: def __init__(self, unique_id: str, prompt: Any, @@ -984,6 +1163,7 @@ class NodeInfoV1: output_is_list: list[bool]=None output_name: list[str]=None output_tooltips: list[str]=None + output_matchtypes: list[str]=None name: str=None display_name: str=None description: str=None @@ -1061,7 +1241,11 @@ class Schema: '''Validate the schema: - verify ids on inputs and outputs are unique - both internally and in relation to each other ''' - input_ids = [i.id for i in self.inputs] if self.inputs is not None else [] + nested_inputs: list[Input] = [] + if self.inputs is not None: + for input in self.inputs: + nested_inputs.extend(input.get_all()) + input_ids = [i.id for i in nested_inputs] if nested_inputs is not None else [] output_ids = [o.id for o in self.outputs] if self.outputs is not None else [] input_set = set(input_ids) output_set = set(output_ids) @@ -1077,6 +1261,13 @@ class Schema: issues.append(f"Ids must be unique between inputs and outputs, but {intersection} are not.") if len(issues) > 0: raise ValueError("\n".join(issues)) + # validate inputs and outputs + if self.inputs is not None: + for input in self.inputs: + input.validate() + if self.outputs is not None: + for output in self.outputs: + output.validate() def finalize(self): """Add hidden based on selected schema options, and give outputs without ids default ids.""" @@ -1102,19 +1293,10 @@ class Schema: if output.id is None: output.id = f"_{i}_{output.io_type}_" - def get_v1_info(self, cls) -> NodeInfoV1: + def get_v1_info(self, cls, live_inputs: dict[str, Any]=None) -> NodeInfoV1: + # NOTE: live_inputs will not be used anymore very soon and this will be done another way # get V1 inputs - input = { - "required": {} - } - if self.inputs: - for i in self.inputs: - if isinstance(i, DynamicInput): - dynamic_inputs = i.get_dynamic() - for d in dynamic_inputs: - add_to_dict_v1(d, input) - else: - add_to_dict_v1(i, input) + input = create_input_dict_v1(self.inputs, live_inputs) if self.hidden: for hidden in self.hidden: input.setdefault("hidden", {})[hidden.name] = (hidden.value,) @@ -1123,12 +1305,24 @@ class Schema: output_is_list = [] output_name = [] output_tooltips = [] + output_matchtypes = [] + any_matchtypes = False if self.outputs: for o in self.outputs: output.append(o.io_type) output_is_list.append(o.is_output_list) output_name.append(o.display_name if o.display_name else o.io_type) output_tooltips.append(o.tooltip if o.tooltip else None) + # special handling for MatchType + if isinstance(o, MatchType.Output): + output_matchtypes.append(o.template.template_id) + any_matchtypes = True + else: + output_matchtypes.append(None) + + # clear out lists that are all None + if not any_matchtypes: + output_matchtypes = None info = NodeInfoV1( input=input, @@ -1137,6 +1331,7 @@ class Schema: output_is_list=output_is_list, output_name=output_name, output_tooltips=output_tooltips, + output_matchtypes=output_matchtypes, name=self.node_id, display_name=self.display_name, category=self.category, @@ -1182,16 +1377,57 @@ class Schema: return info -def add_to_dict_v1(i: Input, input: dict): +def create_input_dict_v1(inputs: list[Input], live_inputs: dict[str, Any]=None) -> dict: + input = { + "required": {} + } + add_to_input_dict_v1(input, inputs, live_inputs) + return input + +def add_to_input_dict_v1(d: dict[str, Any], inputs: list[Input], live_inputs: dict[str, Any]=None, curr_prefix=''): + for i in inputs: + if isinstance(i, DynamicInput): + add_to_dict_v1(i, d) + if live_inputs is not None: + i.expand_schema_for_dynamic(d, live_inputs, curr_prefix) + else: + add_to_dict_v1(i, d) + +def add_to_dict_v1(i: Input, d: dict, dynamic_dict: dict=None): key = "optional" if i.optional else "required" as_dict = i.as_dict() # for v1, we don't want to include the optional key as_dict.pop("optional", None) - input.setdefault(key, {})[i.id] = (i.get_io_type(), as_dict) + if dynamic_dict is None: + value = (i.get_io_type(), as_dict) + else: + value = (i.get_io_type(), as_dict, dynamic_dict) + d.setdefault(key, {})[i.id] = value def add_to_dict_v3(io: Input | Output, d: dict): d[io.id] = (io.get_io_type(), io.as_dict()) +def build_nested_inputs(values: dict[str, Any], v3_data: V3Data): + paths = v3_data.get("dynamic_paths", None) + if paths is None: + return values + values = values.copy() + result = {} + + for key, path in paths.items(): + parts = path.split(".") + current = result + + for i, p in enumerate(parts): + is_last = (i == len(parts) - 1) + + if is_last: + current[p] = values.pop(key, None) + else: + current = current.setdefault(p, {}) + + values.update(result) + return values class _ComfyNodeBaseInternal(_ComfyNodeInternal): @@ -1311,12 +1547,12 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal): @final @classmethod - def PREPARE_CLASS_CLONE(cls, hidden_inputs: dict) -> type[ComfyNode]: + def PREPARE_CLASS_CLONE(cls, v3_data: V3Data) -> type[ComfyNode]: """Creates clone of real node class to prevent monkey-patching.""" c_type: type[ComfyNode] = cls if is_class(cls) else type(cls) type_clone: type[ComfyNode] = shallow_clone_class(c_type) # set hidden - type_clone.hidden = HiddenHolder.from_dict(hidden_inputs) + type_clone.hidden = HiddenHolder.from_dict(v3_data["hidden_inputs"]) return type_clone @final @@ -1433,14 +1669,18 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal): @final @classmethod - def INPUT_TYPES(cls, include_hidden=True, return_schema=False) -> dict[str, dict] | tuple[dict[str, dict], Schema]: + def INPUT_TYPES(cls, include_hidden=True, return_schema=False, live_inputs=None) -> dict[str, dict] | tuple[dict[str, dict], Schema, V3Data]: schema = cls.FINALIZE_SCHEMA() - info = schema.get_v1_info(cls) + info = schema.get_v1_info(cls, live_inputs) input = info.input if not include_hidden: input.pop("hidden", None) if return_schema: - return input, schema + v3_data: V3Data = {} + dynamic = input.pop("dynamic_paths", None) + if dynamic is not None: + v3_data["dynamic_paths"] = dynamic + return input, schema, v3_data return input @final @@ -1513,7 +1753,7 @@ class ComfyNode(_ComfyNodeBaseInternal): raise NotImplementedError @classmethod - def validate_inputs(cls, **kwargs) -> bool: + def validate_inputs(cls, **kwargs) -> bool | str: """Optionally, define this function to validate inputs; equivalent to V1's VALIDATE_INPUTS.""" raise NotImplementedError @@ -1628,6 +1868,7 @@ __all__ = [ "StyleModel", "Gligen", "UpscaleModel", + "LatentUpscaleModel", "Audio", "Video", "SVG", @@ -1651,6 +1892,10 @@ __all__ = [ "SEGS", "AnyType", "MultiType", + # Dynamic Types + "MatchType", + # "DynamicCombo", + # "Autogrow", # Other classes "HiddenHolder", "Hidden", @@ -1661,4 +1906,5 @@ __all__ = [ "NodeOutput", "add_to_dict_v1", "add_to_dict_v3", + "V3Data", ] diff --git a/comfy_api/latest/_io_public.py b/comfy_api/latest/_io_public.py new file mode 100644 index 000000000..43c7680f3 --- /dev/null +++ b/comfy_api/latest/_io_public.py @@ -0,0 +1 @@ +from ._io import * # noqa: F403 diff --git a/comfy_api/latest/_ui_public.py b/comfy_api/latest/_ui_public.py new file mode 100644 index 000000000..85b11d78b --- /dev/null +++ b/comfy_api/latest/_ui_public.py @@ -0,0 +1 @@ +from ._ui import * # noqa: F403 diff --git a/comfy_api/v0_0_2/__init__.py b/comfy_api/v0_0_2/__init__.py index de0f95001..c4fa1d971 100644 --- a/comfy_api/v0_0_2/__init__.py +++ b/comfy_api/v0_0_2/__init__.py @@ -6,7 +6,7 @@ from comfy_api.latest import ( ) from typing import Type, TYPE_CHECKING from comfy_api.internal.async_to_sync import create_sync_class -from comfy_api.latest import io, ui, ComfyExtension #noqa: F401 +from comfy_api.latest import io, ui, IO, UI, ComfyExtension #noqa: F401 class ComfyAPIAdapter_v0_0_2(ComfyAPI_latest): @@ -42,4 +42,8 @@ __all__ = [ "InputImpl", "Types", "ComfyExtension", + "io", + "IO", + "ui", + "UI", ] diff --git a/comfy_execution/validation.py b/comfy_execution/validation.py index cec105fc9..24c0b4ed7 100644 --- a/comfy_execution/validation.py +++ b/comfy_execution/validation.py @@ -1,4 +1,5 @@ from __future__ import annotations +from comfy_api.latest import IO def validate_node_input( @@ -23,6 +24,11 @@ def validate_node_input( if not received_type != input_type: return True + # If the received type or input_type is a MatchType, we can return True immediately; + # validation for this is handled by the frontend + if received_type == IO.MatchType.io_type or input_type == IO.MatchType.io_type: + return True + # Not equal, and not strings if not isinstance(received_type, str) or not isinstance(input_type, str): return False diff --git a/comfy_extras/nodes_logic.py b/comfy_extras/nodes_logic.py new file mode 100644 index 000000000..95a6ba788 --- /dev/null +++ b/comfy_extras/nodes_logic.py @@ -0,0 +1,155 @@ +from typing import TypedDict +from typing_extensions import override +from comfy_api.latest import ComfyExtension, io +from comfy_api.latest import _io + + + +class SwitchNode(io.ComfyNode): + @classmethod + def define_schema(cls): + template = io.MatchType.Template("switch") + return io.Schema( + node_id="ComfySwitchNode", + display_name="Switch", + category="logic", + is_experimental=True, + inputs=[ + io.Boolean.Input("switch"), + io.MatchType.Input("on_false", template=template, lazy=True, optional=True), + io.MatchType.Input("on_true", template=template, lazy=True, optional=True), + ], + outputs=[ + io.MatchType.Output(template=template, display_name="output"), + ], + ) + + @classmethod + def check_lazy_status(cls, switch, on_false=..., on_true=...): + # We use ... instead of None, as None is passed for connected-but-unevaluated inputs. + # This trick allows us to ignore the value of the switch and still be able to run execute(). + + # One of the inputs may be missing, in which case we need to evaluate the other input + if on_false is ...: + return ["on_true"] + if on_true is ...: + return ["on_false"] + # Normal lazy switch operation + if switch and on_true is None: + return ["on_true"] + if not switch and on_false is None: + return ["on_false"] + + @classmethod + def validate_inputs(cls, switch, on_false=..., on_true=...): + # This check happens before check_lazy_status(), so we can eliminate the case where + # both inputs are missing. + if on_false is ... and on_true is ...: + return "At least one of on_false or on_true must be connected to Switch node" + return True + + @classmethod + def execute(cls, switch, on_true=..., on_false=...) -> io.NodeOutput: + if on_true is ...: + return io.NodeOutput(on_false) + if on_false is ...: + return io.NodeOutput(on_true) + return io.NodeOutput(on_true if switch else on_false) + + +class DCTestNode(io.ComfyNode): + class DCValues(TypedDict): + combo: str + string: str + integer: int + image: io.Image.Type + subcombo: dict[str] + + @classmethod + def define_schema(cls): + return io.Schema( + node_id="DCTestNode", + display_name="DCTest", + category="logic", + is_output_node=True, + inputs=[_io.DynamicCombo.Input("combo", options=[ + _io.DynamicCombo.Option("option1", [io.String.Input("string")]), + _io.DynamicCombo.Option("option2", [io.Int.Input("integer")]), + _io.DynamicCombo.Option("option3", [io.Image.Input("image")]), + _io.DynamicCombo.Option("option4", [ + _io.DynamicCombo.Input("subcombo", options=[ + _io.DynamicCombo.Option("opt1", [io.Float.Input("float_x"), io.Float.Input("float_y")]), + _io.DynamicCombo.Option("opt2", [io.Mask.Input("mask1", optional=True)]), + ]) + ])] + )], + outputs=[io.AnyType.Output()], + ) + + @classmethod + def execute(cls, combo: DCValues) -> io.NodeOutput: + combo_val = combo["combo"] + if combo_val == "option1": + return io.NodeOutput(combo["string"]) + elif combo_val == "option2": + return io.NodeOutput(combo["integer"]) + elif combo_val == "option3": + return io.NodeOutput(combo["image"]) + elif combo_val == "option4": + return io.NodeOutput(f"{combo['subcombo']}") + else: + raise ValueError(f"Invalid combo: {combo_val}") + + +class AutogrowNamesTestNode(io.ComfyNode): + @classmethod + def define_schema(cls): + template = _io.Autogrow.TemplateNames(input=io.Float.Input("float"), names=["a", "b", "c"]) + return io.Schema( + node_id="AutogrowNamesTestNode", + display_name="AutogrowNamesTest", + category="logic", + inputs=[ + _io.Autogrow.Input("autogrow", template=template) + ], + outputs=[io.String.Output()], + ) + + @classmethod + def execute(cls, autogrow: _io.Autogrow.Type) -> io.NodeOutput: + vals = list(autogrow.values()) + combined = ",".join([str(x) for x in vals]) + return io.NodeOutput(combined) + +class AutogrowPrefixTestNode(io.ComfyNode): + @classmethod + def define_schema(cls): + template = _io.Autogrow.TemplatePrefix(input=io.Float.Input("float"), prefix="float", min=1, max=10) + return io.Schema( + node_id="AutogrowPrefixTestNode", + display_name="AutogrowPrefixTest", + category="logic", + inputs=[ + _io.Autogrow.Input("autogrow", template=template) + ], + outputs=[io.String.Output()], + ) + + @classmethod + def execute(cls, autogrow: _io.Autogrow.Type) -> io.NodeOutput: + vals = list(autogrow.values()) + combined = ",".join([str(x) for x in vals]) + return io.NodeOutput(combined) + +class LogicExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + # SwitchNode, + # DCTestNode, + # AutogrowNamesTestNode, + # AutogrowPrefixTestNode, + ] + +async def comfy_entrypoint() -> LogicExtension: + return LogicExtension() diff --git a/execution.py b/execution.py index 17c77beab..c2186ac98 100644 --- a/execution.py +++ b/execution.py @@ -34,7 +34,7 @@ from comfy_execution.validation import validate_node_input from comfy_execution.progress import get_progress_state, reset_progress_state, add_progress_handler, WebUIProgressHandler from comfy_execution.utils import CurrentNodeContext from comfy_api.internal import _ComfyNodeInternal, _NodeOutputInternal, first_real_override, is_class, make_locked_method_func -from comfy_api.latest import io +from comfy_api.latest import io, _io class ExecutionResult(Enum): @@ -76,7 +76,7 @@ class IsChangedCache: return self.is_changed[node_id] # Intentionally do not use cached outputs here. We only want constants in IS_CHANGED - input_data_all, _, hidden_inputs = get_input_data(node["inputs"], class_def, node_id, None) + input_data_all, _, v3_data = get_input_data(node["inputs"], class_def, node_id, None) try: is_changed = await _async_map_node_over_list(self.prompt_id, node_id, class_def, input_data_all, is_changed_name) is_changed = await resolve_map_node_over_list_results(is_changed) @@ -146,8 +146,9 @@ SENSITIVE_EXTRA_DATA_KEYS = ("auth_token_comfy_org", "api_key_comfy_org") def get_input_data(inputs, class_def, unique_id, execution_list=None, dynprompt=None, extra_data={}): is_v3 = issubclass(class_def, _ComfyNodeInternal) + v3_data: io.V3Data = {} if is_v3: - valid_inputs, schema = class_def.INPUT_TYPES(include_hidden=False, return_schema=True) + valid_inputs, schema, v3_data = class_def.INPUT_TYPES(include_hidden=False, return_schema=True, live_inputs=inputs) else: valid_inputs = class_def.INPUT_TYPES() input_data_all = {} @@ -207,7 +208,8 @@ def get_input_data(inputs, class_def, unique_id, execution_list=None, dynprompt= input_data_all[x] = [extra_data.get("auth_token_comfy_org", None)] if h[x] == "API_KEY_COMFY_ORG": input_data_all[x] = [extra_data.get("api_key_comfy_org", None)] - return input_data_all, missing_keys, hidden_inputs_v3 + v3_data["hidden_inputs"] = hidden_inputs_v3 + return input_data_all, missing_keys, v3_data map_node_over_list = None #Don't hook this please @@ -223,7 +225,7 @@ async def resolve_map_node_over_list_results(results): raise exc return [x.result() if isinstance(x, asyncio.Task) else x for x in results] -async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None, hidden_inputs=None): +async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None, v3_data=None): # check if node wants the lists input_is_list = getattr(obj, "INPUT_IS_LIST", False) @@ -259,13 +261,16 @@ async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, f if is_class(obj): type_obj = obj obj.VALIDATE_CLASS() - class_clone = obj.PREPARE_CLASS_CLONE(hidden_inputs) + class_clone = obj.PREPARE_CLASS_CLONE(v3_data) # otherwise, use class instance to populate/reuse some fields else: type_obj = type(obj) type_obj.VALIDATE_CLASS() - class_clone = type_obj.PREPARE_CLASS_CLONE(hidden_inputs) + class_clone = type_obj.PREPARE_CLASS_CLONE(v3_data) f = make_locked_method_func(type_obj, func, class_clone) + # in case of dynamic inputs, restructure inputs to expected nested dict + if v3_data is not None: + inputs = _io.build_nested_inputs(inputs, v3_data) # V1 else: f = getattr(obj, func) @@ -320,8 +325,8 @@ def merge_result_data(results, obj): output.append([o[i] for o in results]) return output -async def get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=None, pre_execute_cb=None, hidden_inputs=None): - return_values = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, hidden_inputs=hidden_inputs) +async def get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=None, pre_execute_cb=None, v3_data=None): + return_values = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data) has_pending_task = any(isinstance(r, asyncio.Task) and not r.done() for r in return_values) if has_pending_task: return return_values, {}, False, has_pending_task @@ -460,7 +465,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, has_subgraph = False else: get_progress_state().start_progress(unique_id) - input_data_all, missing_keys, hidden_inputs = get_input_data(inputs, class_def, unique_id, execution_list, dynprompt, extra_data) + input_data_all, missing_keys, v3_data = get_input_data(inputs, class_def, unique_id, execution_list, dynprompt, extra_data) if server.client_id is not None: server.last_node_id = display_node_id server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id }, server.client_id) @@ -475,7 +480,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, else: lazy_status_present = getattr(obj, "check_lazy_status", None) is not None if lazy_status_present: - required_inputs = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, "check_lazy_status", allow_interrupt=True, hidden_inputs=hidden_inputs) + required_inputs = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, "check_lazy_status", allow_interrupt=True, v3_data=v3_data) required_inputs = await resolve_map_node_over_list_results(required_inputs) required_inputs = set(sum([r for r in required_inputs if isinstance(r,list)], [])) required_inputs = [x for x in required_inputs if isinstance(x,str) and ( @@ -507,7 +512,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, def pre_execute_cb(call_index): # TODO - How to handle this with async functions without contextvars (which requires Python 3.12)? GraphBuilder.set_default_prefix(unique_id, call_index, 0) - output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, hidden_inputs=hidden_inputs) + output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data) if has_pending_tasks: pending_async_nodes[unique_id] = output_data unblock = execution_list.add_external_block(unique_id) @@ -745,18 +750,17 @@ async def validate_inputs(prompt_id, prompt, item, validated): class_type = prompt[unique_id]['class_type'] obj_class = nodes.NODE_CLASS_MAPPINGS[class_type] - class_inputs = obj_class.INPUT_TYPES() - valid_inputs = set(class_inputs.get('required',{})).union(set(class_inputs.get('optional',{}))) - errors = [] valid = True validate_function_inputs = [] validate_has_kwargs = False if issubclass(obj_class, _ComfyNodeInternal): + class_inputs, _, _ = obj_class.INPUT_TYPES(include_hidden=False, return_schema=True, live_inputs=inputs) validate_function_name = "validate_inputs" validate_function = first_real_override(obj_class, validate_function_name) else: + class_inputs = obj_class.INPUT_TYPES() validate_function_name = "VALIDATE_INPUTS" validate_function = getattr(obj_class, validate_function_name, None) if validate_function is not None: @@ -765,6 +769,8 @@ async def validate_inputs(prompt_id, prompt, item, validated): validate_has_kwargs = argspec.varkw is not None received_types = {} + valid_inputs = set(class_inputs.get('required',{})).union(set(class_inputs.get('optional',{}))) + for x in valid_inputs: input_type, input_category, extra_info = get_input_info(obj_class, x, class_inputs) assert extra_info is not None @@ -935,7 +941,7 @@ async def validate_inputs(prompt_id, prompt, item, validated): continue if len(validate_function_inputs) > 0 or validate_has_kwargs: - input_data_all, _, hidden_inputs = get_input_data(inputs, obj_class, unique_id) + input_data_all, _, v3_data = get_input_data(inputs, obj_class, unique_id) input_filtered = {} for x in input_data_all: if x in validate_function_inputs or validate_has_kwargs: @@ -943,7 +949,7 @@ async def validate_inputs(prompt_id, prompt, item, validated): if 'input_types' in validate_function_inputs: input_filtered['input_types'] = [received_types] - ret = await _async_map_node_over_list(prompt_id, unique_id, obj_class, input_filtered, validate_function_name, hidden_inputs=hidden_inputs) + ret = await _async_map_node_over_list(prompt_id, unique_id, obj_class, input_filtered, validate_function_name, v3_data=v3_data) ret = await resolve_map_node_over_list_results(ret) for x in input_filtered: for i, r in enumerate(ret): diff --git a/nodes.py b/nodes.py index 4c910a34b..356aa63df 100644 --- a/nodes.py +++ b/nodes.py @@ -2355,6 +2355,7 @@ async def init_builtin_extra_nodes(): "nodes_easycache.py", "nodes_audio_encoder.py", "nodes_rope.py", + "nodes_logic.py", "nodes_nop.py", ] From 861817d22d2659099811b56005c9eaea18d64c73 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 2 Dec 2025 21:47:51 -0800 Subject: [PATCH 197/231] Fix issue with portable updater. (#11070) This should fix the problem with the portable updater not working with portables created from a separate branch on the repo. This does not affect any current portables who were all created on the master branch. --- .ci/update_windows/update.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/.ci/update_windows/update.py b/.ci/update_windows/update.py index 51a263203..59ece5130 100755 --- a/.ci/update_windows/update.py +++ b/.ci/update_windows/update.py @@ -66,8 +66,10 @@ if branch is None: try: ref = repo.lookup_reference('refs/remotes/origin/master') except: - print("pulling.") # noqa: T201 - pull(repo) + print("fetching.") # noqa: T201 + for remote in repo.remotes: + if remote.name == "origin": + remote.fetch() ref = repo.lookup_reference('refs/remotes/origin/master') repo.checkout(ref) branch = repo.lookup_branch('master') @@ -149,3 +151,4 @@ try: shutil.copy(stable_update_script, stable_update_script_to) except: pass + From 519c9411653df99761053c30e101816e0ca3c24b Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Wed, 3 Dec 2025 17:28:45 +1000 Subject: [PATCH 198/231] Prs/lora reservations (reduce massive Lora reservations especially on Flux2) (#11069) * mp: only count the offload cost of math once This was previously bundling the combined weight storage and computation cost * ops: put all post async transfer compute on the main stream Some models have massive weights that need either complex dequantization or lora patching. Don't do these patchings on the offload stream, instead do them on the main stream to syncrhonize the potentially large vram spikes for these compute processes. This avoids having to assume a worst case scenario of multiple offload streams all spiking VRAM is parallel with whatever the main stream is doing. --- comfy/model_patcher.py | 4 ++-- comfy/ops.py | 39 ++++++++++++++++++++++----------------- 2 files changed, 24 insertions(+), 19 deletions(-) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 3eac77275..df2d8e827 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -704,7 +704,7 @@ class ModelPatcher: lowvram_weight = False - potential_offload = max(offload_buffer, module_offload_mem * (comfy.model_management.NUM_STREAMS + 1)) + potential_offload = max(offload_buffer, module_offload_mem + (comfy.model_management.NUM_STREAMS * module_mem)) lowvram_fits = mem_counter + module_mem + potential_offload < lowvram_model_memory weight_key = "{}.weight".format(n) @@ -883,7 +883,7 @@ class ModelPatcher: break module_offload_mem, module_mem, n, m, params = unload - potential_offload = (comfy.model_management.NUM_STREAMS + 1) * module_offload_mem + potential_offload = module_offload_mem + (comfy.model_management.NUM_STREAMS * module_mem) lowvram_possible = hasattr(m, "comfy_cast_weights") if hasattr(m, "comfy_patched_weights") and m.comfy_patched_weights == True: diff --git a/comfy/ops.py b/comfy/ops.py index 61a2f0754..eae434e68 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -111,22 +111,24 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of if s.bias is not None: bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=bias_has_function, stream=offload_stream) - if bias_has_function: - with wf_context: - for f in s.bias_function: - bias = f(bias) + comfy.model_management.sync_stream(device, offload_stream) + + bias_a = bias + weight_a = weight + + if s.bias is not None: + for f in s.bias_function: + bias = f(bias) if weight_has_function or weight.dtype != dtype: - with wf_context: - weight = weight.to(dtype=dtype) - if isinstance(weight, QuantizedTensor): - weight = weight.dequantize() - for f in s.weight_function: - weight = f(weight) + weight = weight.to(dtype=dtype) + if isinstance(weight, QuantizedTensor): + weight = weight.dequantize() + for f in s.weight_function: + weight = f(weight) - comfy.model_management.sync_stream(device, offload_stream) if offloadable: - return weight, bias, offload_stream + return weight, bias, (offload_stream, weight_a, bias_a) else: #Legacy function signature return weight, bias @@ -135,13 +137,16 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of def uncast_bias_weight(s, weight, bias, offload_stream): if offload_stream is None: return - if weight is not None: - device = weight.device + os, weight_a, bias_a = offload_stream + if os is None: + return + if weight_a is not None: + device = weight_a.device else: - if bias is None: + if bias_a is None: return - device = bias.device - offload_stream.wait_stream(comfy.model_management.current_stream(device)) + device = bias_a.device + os.wait_stream(comfy.model_management.current_stream(device)) class CastWeightBiasOp: From 19f2192d69d13445131b72ad1d87167f59b66fc4 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Wed, 3 Dec 2025 18:37:35 +0200 Subject: [PATCH 199/231] fix(V3-Schema): use empty list defaults for Schema.inputs/outputs/hidden to avoid None issues (#11083) --- comfy_api/latest/_io.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/comfy_api/latest/_io.py b/comfy_api/latest/_io.py index 257f07c42..866c3e0eb 100644 --- a/comfy_api/latest/_io.py +++ b/comfy_api/latest/_io.py @@ -5,7 +5,7 @@ import inspect from abc import ABC, abstractmethod from collections import Counter from collections.abc import Iterable -from dataclasses import asdict, dataclass +from dataclasses import asdict, dataclass, field from enum import Enum from typing import Any, Callable, Literal, TypedDict, TypeVar, TYPE_CHECKING from typing_extensions import NotRequired, final @@ -1199,9 +1199,9 @@ class Schema: """Display name of node.""" category: str = "sd" """The category of the node, as per the "Add Node" menu.""" - inputs: list[Input]=None - outputs: list[Output]=None - hidden: list[Hidden]=None + inputs: list[Input] = field(default_factory=list) + outputs: list[Output] = field(default_factory=list) + hidden: list[Hidden] = field(default_factory=list) description: str="" """Node description, shown as a tooltip when hovering over the node.""" is_input_list: bool = False From 87c104bfc1928f0b018a50f5867f425e10482929 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Wed, 3 Dec 2025 18:55:44 +0200 Subject: [PATCH 200/231] add support for "@image" reference format in Kling Omni API nodes (#11082) --- comfy_api_nodes/apis/kling_api.py | 30 +++++-- comfy_api_nodes/nodes_kling.py | 138 ++++++++++++++++++++++++++++-- 2 files changed, 155 insertions(+), 13 deletions(-) diff --git a/comfy_api_nodes/apis/kling_api.py b/comfy_api_nodes/apis/kling_api.py index 0a3b447c5..d8949f8ac 100644 --- a/comfy_api_nodes/apis/kling_api.py +++ b/comfy_api_nodes/apis/kling_api.py @@ -46,21 +46,41 @@ class TaskStatusVideoResult(BaseModel): url: str | None = Field(None, description="URL for generated video") -class TaskStatusVideoResults(BaseModel): +class TaskStatusImageResult(BaseModel): + index: int = Field(..., description="Image Number,0-9") + url: str = Field(..., description="URL for generated image") + + +class OmniTaskStatusResults(BaseModel): videos: list[TaskStatusVideoResult] | None = Field(None) + images: list[TaskStatusImageResult] | None = Field(None) -class TaskStatusVideoResponseData(BaseModel): +class OmniTaskStatusResponseData(BaseModel): created_at: int | None = Field(None, description="Task creation time") updated_at: int | None = Field(None, description="Task update time") task_status: str | None = None task_status_msg: str | None = Field(None, description="Additional failure reason. Only for polling endpoint.") task_id: str | None = Field(None, description="Task ID") - task_result: TaskStatusVideoResults | None = Field(None) + task_result: OmniTaskStatusResults | None = Field(None) -class TaskStatusVideoResponse(BaseModel): +class OmniTaskStatusResponse(BaseModel): code: int | None = Field(None, description="Error code") message: str | None = Field(None, description="Error message") request_id: str | None = Field(None, description="Request ID") - data: TaskStatusVideoResponseData | None = Field(None) + data: OmniTaskStatusResponseData | None = Field(None) + + +class OmniImageParamImage(BaseModel): + image: str = Field(...) + + +class OmniProImageRequest(BaseModel): + model_name: str = Field(..., description="kling-image-o1") + resolution: str = Field(..., description="'1k' or '2k'") + aspect_ratio: str | None = Field(...) + prompt: str = Field(...) + mode: str = Field("pro") + n: int | None = Field(1, le=9) + image_list: list[OmniImageParamImage] | None = Field(..., max_length=10) diff --git a/comfy_api_nodes/nodes_kling.py b/comfy_api_nodes/nodes_kling.py index 850c44db6..6c840dc47 100644 --- a/comfy_api_nodes/nodes_kling.py +++ b/comfy_api_nodes/nodes_kling.py @@ -6,6 +6,7 @@ For source of truth on the allowed permutations of request fields, please refere import logging import math +import re import torch from typing_extensions import override @@ -49,12 +50,14 @@ from comfy_api_nodes.apis import ( KlingSingleImageEffectModelName, ) from comfy_api_nodes.apis.kling_api import ( + OmniImageParamImage, OmniParamImage, OmniParamVideo, OmniProFirstLastFrameRequest, + OmniProImageRequest, OmniProReferences2VideoRequest, OmniProText2VideoRequest, - TaskStatusVideoResponse, + OmniTaskStatusResponse, ) from comfy_api_nodes.util import ( ApiEndpoint, @@ -210,7 +213,36 @@ VOICES_CONFIG = { } -async def finish_omni_video_task(cls: type[IO.ComfyNode], response: TaskStatusVideoResponse) -> IO.NodeOutput: +def normalize_omni_prompt_references(prompt: str) -> str: + """ + Rewrites Kling Omni-style placeholders used in the app, like: + + @image, @image1, @image2, ... @imageN + @video, @video1, @video2, ... @videoN + + into the API-compatible form: + + <<>>, <<>>, ... + <<>>, <<>>, ... + + This is a UX shim for ComfyUI so users can type the same syntax as in the Kling app. + """ + if not prompt: + return prompt + + def _image_repl(match): + return f"<<>>" + + def _video_repl(match): + return f"<<>>" + + # (? and not @imageFoo + prompt = re.sub(r"(?\d*)(?!\w)", _image_repl, prompt) + return re.sub(r"(?\d*)(?!\w)", _video_repl, prompt) + + +async def finish_omni_video_task(cls: type[IO.ComfyNode], response: OmniTaskStatusResponse) -> IO.NodeOutput: if response.code: raise RuntimeError( f"Kling request failed. Code: {response.code}, Message: {response.message}, Data: {response.data}" @@ -218,8 +250,9 @@ async def finish_omni_video_task(cls: type[IO.ComfyNode], response: TaskStatusVi final_response = await poll_op( cls, ApiEndpoint(path=f"/proxy/kling/v1/videos/omni-video/{response.data.task_id}"), - response_model=TaskStatusVideoResponse, + response_model=OmniTaskStatusResponse, status_extractor=lambda r: (r.data.task_status if r.data else None), + max_poll_attempts=160, ) return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url)) @@ -801,7 +834,7 @@ class OmniProTextToVideoNode(IO.ComfyNode): response = await sync_op( cls, ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"), - response_model=TaskStatusVideoResponse, + response_model=OmniTaskStatusResponse, data=OmniProText2VideoRequest( model_name=model_name, prompt=prompt, @@ -864,6 +897,7 @@ class OmniProFirstLastFrameNode(IO.ComfyNode): end_frame: Input.Image | None = None, reference_images: Input.Image | None = None, ) -> IO.NodeOutput: + prompt = normalize_omni_prompt_references(prompt) validate_string(prompt, min_length=1, max_length=2500) if end_frame is not None and reference_images is not None: raise ValueError("The 'end_frame' input cannot be used simultaneously with 'reference_images'.") @@ -895,7 +929,7 @@ class OmniProFirstLastFrameNode(IO.ComfyNode): response = await sync_op( cls, ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"), - response_model=TaskStatusVideoResponse, + response_model=OmniTaskStatusResponse, data=OmniProFirstLastFrameRequest( model_name=model_name, prompt=prompt, @@ -950,6 +984,7 @@ class OmniProImageToVideoNode(IO.ComfyNode): duration: int, reference_images: Input.Image, ) -> IO.NodeOutput: + prompt = normalize_omni_prompt_references(prompt) validate_string(prompt, min_length=1, max_length=2500) if get_number_of_images(reference_images) > 7: raise ValueError("The maximum number of reference images is 7.") @@ -962,7 +997,7 @@ class OmniProImageToVideoNode(IO.ComfyNode): response = await sync_op( cls, ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"), - response_model=TaskStatusVideoResponse, + response_model=OmniTaskStatusResponse, data=OmniProReferences2VideoRequest( model_name=model_name, prompt=prompt, @@ -1023,6 +1058,7 @@ class OmniProVideoToVideoNode(IO.ComfyNode): keep_original_sound: bool, reference_images: Input.Image | None = None, ) -> IO.NodeOutput: + prompt = normalize_omni_prompt_references(prompt) validate_string(prompt, min_length=1, max_length=2500) validate_video_duration(reference_video, min_duration=3.0, max_duration=10.05) validate_video_dimensions(reference_video, min_width=720, min_height=720, max_width=2160, max_height=2160) @@ -1045,7 +1081,7 @@ class OmniProVideoToVideoNode(IO.ComfyNode): response = await sync_op( cls, ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"), - response_model=TaskStatusVideoResponse, + response_model=OmniTaskStatusResponse, data=OmniProReferences2VideoRequest( model_name=model_name, prompt=prompt, @@ -1103,6 +1139,7 @@ class OmniProEditVideoNode(IO.ComfyNode): keep_original_sound: bool, reference_images: Input.Image | None = None, ) -> IO.NodeOutput: + prompt = normalize_omni_prompt_references(prompt) validate_string(prompt, min_length=1, max_length=2500) validate_video_duration(video, min_duration=3.0, max_duration=10.05) validate_video_dimensions(video, min_width=720, min_height=720, max_width=2160, max_height=2160) @@ -1125,7 +1162,7 @@ class OmniProEditVideoNode(IO.ComfyNode): response = await sync_op( cls, ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"), - response_model=TaskStatusVideoResponse, + response_model=OmniTaskStatusResponse, data=OmniProReferences2VideoRequest( model_name=model_name, prompt=prompt, @@ -1138,6 +1175,90 @@ class OmniProEditVideoNode(IO.ComfyNode): return await finish_omni_video_task(cls, response) +class OmniProImageNode(IO.ComfyNode): + + @classmethod + def define_schema(cls) -> IO.Schema: + return IO.Schema( + node_id="KlingOmniProImageNode", + display_name="Kling Omni Image (Pro)", + category="api node/image/Kling", + description="Create or edit images with the latest model from Kling.", + inputs=[ + IO.Combo.Input("model_name", options=["kling-image-o1"]), + IO.String.Input( + "prompt", + multiline=True, + tooltip="A text prompt describing the image content. " + "This can include both positive and negative descriptions.", + ), + IO.Combo.Input("resolution", options=["1K", "2K"]), + IO.Combo.Input( + "aspect_ratio", + options=["16:9", "9:16", "1:1", "4:3", "3:4", "3:2", "2:3", "21:9"], + ), + IO.Image.Input( + "reference_images", + tooltip="Up to 10 additional reference images.", + optional=True, + ), + ], + outputs=[ + IO.Image.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, + model_name: str, + prompt: str, + resolution: str, + aspect_ratio: str, + reference_images: Input.Image | None = None, + ) -> IO.NodeOutput: + prompt = normalize_omni_prompt_references(prompt) + validate_string(prompt, min_length=1, max_length=2500) + image_list: list[OmniImageParamImage] = [] + if reference_images is not None: + if get_number_of_images(reference_images) > 10: + raise ValueError("The maximum number of reference images is 10.") + for i in reference_images: + validate_image_dimensions(i, min_width=300, min_height=300) + validate_image_aspect_ratio(i, (1, 2.5), (2.5, 1)) + for i in await upload_images_to_comfyapi(cls, reference_images, wait_label="Uploading reference image"): + image_list.append(OmniImageParamImage(image=i)) + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/kling/v1/images/omni-image", method="POST"), + response_model=OmniTaskStatusResponse, + data=OmniProImageRequest( + model_name=model_name, + prompt=prompt, + resolution=resolution.lower(), + aspect_ratio=aspect_ratio, + image_list=image_list if image_list else None, + ), + ) + if response.code: + raise RuntimeError( + f"Kling request failed. Code: {response.code}, Message: {response.message}, Data: {response.data}" + ) + final_response = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/kling/v1/images/omni-image/{response.data.task_id}"), + response_model=OmniTaskStatusResponse, + status_extractor=lambda r: (r.data.task_status if r.data else None), + ) + return IO.NodeOutput(await download_url_to_image_tensor(final_response.data.task_result.images[0].url)) + + class KlingCameraControlT2VNode(IO.ComfyNode): """ Kling Text to Video Camera Control Node. This node is a text to video node, but it supports controlling the camera. @@ -1935,6 +2056,7 @@ class KlingExtension(ComfyExtension): OmniProImageToVideoNode, OmniProVideoToVideoNode, OmniProEditVideoNode, + # OmniProImageNode, # need support from backend ] From 440268d3940eb14a01595439bbc05c4aacde9c72 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Wed, 3 Dec 2025 23:52:31 +0200 Subject: [PATCH 201/231] convert nodes_load_3d.py to V3 schema (#10990) --- comfy_api/latest/_ui.py | 13 +++- comfy_extras/nodes_load_3d.py | 127 ++++++++++++++++------------------ 2 files changed, 71 insertions(+), 69 deletions(-) diff --git a/comfy_api/latest/_ui.py b/comfy_api/latest/_ui.py index b0bbabe2a..6d1bea599 100644 --- a/comfy_api/latest/_ui.py +++ b/comfy_api/latest/_ui.py @@ -3,6 +3,7 @@ from __future__ import annotations import json import os import random +import uuid from io import BytesIO from typing import Type @@ -436,9 +437,19 @@ class PreviewUI3D(_UIOutput): def __init__(self, model_file, camera_info, **kwargs): self.model_file = model_file self.camera_info = camera_info + self.bg_image_path = None + bg_image = kwargs.get("bg_image", None) + if bg_image is not None: + img_array = (bg_image[0].cpu().numpy() * 255).astype(np.uint8) + img = PILImage.fromarray(img_array) + temp_dir = folder_paths.get_temp_directory() + filename = f"bg_{uuid.uuid4().hex}.png" + bg_image_path = os.path.join(temp_dir, filename) + img.save(bg_image_path, compress_level=1) + self.bg_image_path = f"temp/{filename}" def as_dict(self): - return {"result": [self.model_file, self.camera_info]} + return {"result": [self.model_file, self.camera_info, self.bg_image_path]} class PreviewText(_UIOutput): diff --git a/comfy_extras/nodes_load_3d.py b/comfy_extras/nodes_load_3d.py index 54c66ef68..545588ef8 100644 --- a/comfy_extras/nodes_load_3d.py +++ b/comfy_extras/nodes_load_3d.py @@ -2,22 +2,18 @@ import nodes import folder_paths import os -from comfy.comfy_types import IO -from comfy_api.input_impl import VideoFromFile +from typing_extensions import override +from comfy_api.latest import IO, ComfyExtension, InputImpl, UI from pathlib import Path -from PIL import Image -import numpy as np - -import uuid def normalize_path(path): return path.replace('\\', '/') -class Load3D(): +class Load3D(IO.ComfyNode): @classmethod - def INPUT_TYPES(s): + def define_schema(cls): input_dir = os.path.join(folder_paths.get_input_directory(), "3d") os.makedirs(input_dir, exist_ok=True) @@ -30,23 +26,29 @@ class Load3D(): for file_path in input_path.rglob("*") if file_path.suffix.lower() in {'.gltf', '.glb', '.obj', '.fbx', '.stl'} ] + return IO.Schema( + node_id="Load3D", + display_name="Load 3D & Animation", + category="3d", + is_experimental=True, + inputs=[ + IO.Combo.Input("model_file", options=sorted(files), upload=IO.UploadType.model), + IO.Load3D.Input("image"), + IO.Int.Input("width", default=1024, min=1, max=4096, step=1), + IO.Int.Input("height", default=1024, min=1, max=4096, step=1), + ], + outputs=[ + IO.Image.Output(display_name="image"), + IO.Mask.Output(display_name="mask"), + IO.String.Output(display_name="mesh_path"), + IO.Image.Output(display_name="normal"), + IO.Load3DCamera.Output(display_name="camera_info"), + IO.Video.Output(display_name="recording_video"), + ], + ) - return {"required": { - "model_file": (sorted(files), {"file_upload": True}), - "image": ("LOAD_3D", {}), - "width": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}), - "height": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}), - }} - - RETURN_TYPES = ("IMAGE", "MASK", "STRING", "IMAGE", "LOAD3D_CAMERA", IO.VIDEO) - RETURN_NAMES = ("image", "mask", "mesh_path", "normal", "camera_info", "recording_video") - - FUNCTION = "process" - EXPERIMENTAL = True - - CATEGORY = "3d" - - def process(self, model_file, image, **kwargs): + @classmethod + def execute(cls, model_file, image, **kwargs) -> IO.NodeOutput: image_path = folder_paths.get_annotated_filepath(image['image']) mask_path = folder_paths.get_annotated_filepath(image['mask']) normal_path = folder_paths.get_annotated_filepath(image['normal']) @@ -61,58 +63,47 @@ class Load3D(): if image['recording'] != "": recording_video_path = folder_paths.get_annotated_filepath(image['recording']) - video = VideoFromFile(recording_video_path) + video = InputImpl.VideoFromFile(recording_video_path) - return output_image, output_mask, model_file, normal_image, image['camera_info'], video + return IO.NodeOutput(output_image, output_mask, model_file, normal_image, image['camera_info'], video) -class Preview3D(): + process = execute # TODO: remove + + +class Preview3D(IO.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { - "model_file": ("STRING", {"default": "", "multiline": False}), - }, - "optional": { - "camera_info": ("LOAD3D_CAMERA", {}), - "bg_image": ("IMAGE", {}) - }} + def define_schema(cls): + return IO.Schema( + node_id="Preview3D", + display_name="Preview 3D & Animation", + category="3d", + is_experimental=True, + is_output_node=True, + inputs=[ + IO.String.Input("model_file", default="", multiline=False), + IO.Load3DCamera.Input("camera_info", optional=True), + IO.Image.Input("bg_image", optional=True), + ], + outputs=[], + ) - OUTPUT_NODE = True - RETURN_TYPES = () - - CATEGORY = "3d" - - FUNCTION = "process" - EXPERIMENTAL = True - - def process(self, model_file, **kwargs): + @classmethod + def execute(cls, model_file, **kwargs) -> IO.NodeOutput: camera_info = kwargs.get("camera_info", None) bg_image = kwargs.get("bg_image", None) + return IO.NodeOutput(ui=UI.PreviewUI3D(model_file, camera_info, bg_image=bg_image)) - bg_image_path = None - if bg_image is not None: + process = execute # TODO: remove - img_array = (bg_image[0].cpu().numpy() * 255).astype(np.uint8) - img = Image.fromarray(img_array) - temp_dir = folder_paths.get_temp_directory() - filename = f"bg_{uuid.uuid4().hex}.png" - bg_image_path = os.path.join(temp_dir, filename) - img.save(bg_image_path, compress_level=1) +class Load3DExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[IO.ComfyNode]]: + return [ + Load3D, + Preview3D, + ] - bg_image_path = f"temp/{filename}" - return { - "ui": { - "result": [model_file, camera_info, bg_image_path] - } - } - -NODE_CLASS_MAPPINGS = { - "Load3D": Load3D, - "Preview3D": Preview3D, -} - -NODE_DISPLAY_NAME_MAPPINGS = { - "Load3D": "Load 3D & Animation", - "Preview3D": "Preview 3D & Animation", -} +async def comfy_entrypoint() -> Load3DExtension: + return Load3DExtension() From dce518c2b4f99634b5fdde1924d9b0bd468fe1ce Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Thu, 4 Dec 2025 03:35:04 +0200 Subject: [PATCH 202/231] convert nodes_audio.py to V3 schema (#10798) --- comfy_api/latest/_ui.py | 9 +- comfy_extras/nodes_audio.py | 744 ++++++++++++++++++------------------ 2 files changed, 382 insertions(+), 371 deletions(-) diff --git a/comfy_api/latest/_ui.py b/comfy_api/latest/_ui.py index 6d1bea599..5a75a3aae 100644 --- a/comfy_api/latest/_ui.py +++ b/comfy_api/latest/_ui.py @@ -319,9 +319,10 @@ class AudioSaveHelper: for key, value in metadata.items(): output_container.metadata[key] = value + layout = "mono" if waveform.shape[0] == 1 else "stereo" # Set up the output stream with appropriate properties if format == "opus": - out_stream = output_container.add_stream("libopus", rate=sample_rate) + out_stream = output_container.add_stream("libopus", rate=sample_rate, layout=layout) if quality == "64k": out_stream.bit_rate = 64000 elif quality == "96k": @@ -333,7 +334,7 @@ class AudioSaveHelper: elif quality == "320k": out_stream.bit_rate = 320000 elif format == "mp3": - out_stream = output_container.add_stream("libmp3lame", rate=sample_rate) + out_stream = output_container.add_stream("libmp3lame", rate=sample_rate, layout=layout) if quality == "V0": # TODO i would really love to support V3 and V5 but there doesn't seem to be a way to set the qscale level, the property below is a bool out_stream.codec_context.qscale = 1 @@ -342,12 +343,12 @@ class AudioSaveHelper: elif quality == "320k": out_stream.bit_rate = 320000 else: # format == "flac": - out_stream = output_container.add_stream("flac", rate=sample_rate) + out_stream = output_container.add_stream("flac", rate=sample_rate, layout=layout) frame = av.AudioFrame.from_ndarray( waveform.movedim(0, 1).reshape(1, -1).float().numpy(), format="flt", - layout="mono" if waveform.shape[0] == 1 else "stereo", + layout=layout, ) frame.sample_rate = sample_rate frame.pts = 0 diff --git a/comfy_extras/nodes_audio.py b/comfy_extras/nodes_audio.py index 2ed7e0b22..812301fb7 100644 --- a/comfy_extras/nodes_audio.py +++ b/comfy_extras/nodes_audio.py @@ -6,65 +6,80 @@ import torch import comfy.model_management import folder_paths import os -import io -import json -import random import hashlib import node_helpers import logging -from comfy.cli_args import args -from comfy.comfy_types import FileLocator +from typing_extensions import override +from comfy_api.latest import ComfyExtension, IO, UI -class EmptyLatentAudio: - def __init__(self): - self.device = comfy.model_management.intermediate_device() +class EmptyLatentAudio(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="EmptyLatentAudio", + display_name="Empty Latent Audio", + category="latent/audio", + inputs=[ + IO.Float.Input("seconds", default=47.6, min=1.0, max=1000.0, step=0.1), + IO.Int.Input( + "batch_size", default=1, min=1, max=4096, tooltip="The number of latent images in the batch." + ), + ], + outputs=[IO.Latent.Output()], + ) @classmethod - def INPUT_TYPES(s): - return {"required": {"seconds": ("FLOAT", {"default": 47.6, "min": 1.0, "max": 1000.0, "step": 0.1}), - "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096, "tooltip": "The number of latent images in the batch."}), - }} - RETURN_TYPES = ("LATENT",) - FUNCTION = "generate" - - CATEGORY = "latent/audio" - - def generate(self, seconds, batch_size): + def execute(cls, seconds, batch_size) -> IO.NodeOutput: length = round((seconds * 44100 / 2048) / 2) * 2 - latent = torch.zeros([batch_size, 64, length], device=self.device) - return ({"samples":latent, "type": "audio"}, ) + latent = torch.zeros([batch_size, 64, length], device=comfy.model_management.intermediate_device()) + return IO.NodeOutput({"samples":latent, "type": "audio"}) -class ConditioningStableAudio: + generate = execute # TODO: remove + + +class ConditioningStableAudio(IO.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": {"positive": ("CONDITIONING", ), - "negative": ("CONDITIONING", ), - "seconds_start": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1000.0, "step": 0.1}), - "seconds_total": ("FLOAT", {"default": 47.0, "min": 0.0, "max": 1000.0, "step": 0.1}), - }} + def define_schema(cls): + return IO.Schema( + node_id="ConditioningStableAudio", + category="conditioning", + inputs=[ + IO.Conditioning.Input("positive"), + IO.Conditioning.Input("negative"), + IO.Float.Input("seconds_start", default=0.0, min=0.0, max=1000.0, step=0.1), + IO.Float.Input("seconds_total", default=47.0, min=0.0, max=1000.0, step=0.1), + ], + outputs=[ + IO.Conditioning.Output(display_name="positive"), + IO.Conditioning.Output(display_name="negative"), + ], + ) - RETURN_TYPES = ("CONDITIONING","CONDITIONING") - RETURN_NAMES = ("positive", "negative") - - FUNCTION = "append" - - CATEGORY = "conditioning" - - def append(self, positive, negative, seconds_start, seconds_total): + @classmethod + def execute(cls, positive, negative, seconds_start, seconds_total) -> IO.NodeOutput: positive = node_helpers.conditioning_set_values(positive, {"seconds_start": seconds_start, "seconds_total": seconds_total}) negative = node_helpers.conditioning_set_values(negative, {"seconds_start": seconds_start, "seconds_total": seconds_total}) - return (positive, negative) + return IO.NodeOutput(positive, negative) -class VAEEncodeAudio: + append = execute # TODO: remove + + +class VAEEncodeAudio(IO.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "audio": ("AUDIO", ), "vae": ("VAE", )}} - RETURN_TYPES = ("LATENT",) - FUNCTION = "encode" + def define_schema(cls): + return IO.Schema( + node_id="VAEEncodeAudio", + display_name="VAE Encode Audio", + category="latent/audio", + inputs=[ + IO.Audio.Input("audio"), + IO.Vae.Input("vae"), + ], + outputs=[IO.Latent.Output()], + ) - CATEGORY = "latent/audio" - - def encode(self, vae, audio): + @classmethod + def execute(cls, vae, audio) -> IO.NodeOutput: sample_rate = audio["sample_rate"] if 44100 != sample_rate: waveform = torchaudio.functional.resample(audio["waveform"], sample_rate, 44100) @@ -72,213 +87,134 @@ class VAEEncodeAudio: waveform = audio["waveform"] t = vae.encode(waveform.movedim(1, -1)) - return ({"samples":t}, ) + return IO.NodeOutput({"samples":t}) -class VAEDecodeAudio: + encode = execute # TODO: remove + + +class VAEDecodeAudio(IO.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "samples": ("LATENT", ), "vae": ("VAE", )}} - RETURN_TYPES = ("AUDIO",) - FUNCTION = "decode" + def define_schema(cls): + return IO.Schema( + node_id="VAEDecodeAudio", + display_name="VAE Decode Audio", + category="latent/audio", + inputs=[ + IO.Latent.Input("samples"), + IO.Vae.Input("vae"), + ], + outputs=[IO.Audio.Output()], + ) - CATEGORY = "latent/audio" - - def decode(self, vae, samples): + @classmethod + def execute(cls, vae, samples) -> IO.NodeOutput: audio = vae.decode(samples["samples"]).movedim(-1, 1) std = torch.std(audio, dim=[1,2], keepdim=True) * 5.0 std[std < 1.0] = 1.0 audio /= std - return ({"waveform": audio, "sample_rate": 44100}, ) + return IO.NodeOutput({"waveform": audio, "sample_rate": 44100}) + + decode = execute # TODO: remove -def save_audio(self, audio, filename_prefix="ComfyUI", format="flac", prompt=None, extra_pnginfo=None, quality="128k"): - - filename_prefix += self.prefix_append - full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir) - results: list[FileLocator] = [] - - # Prepare metadata dictionary - metadata = {} - if not args.disable_metadata: - if prompt is not None: - metadata["prompt"] = json.dumps(prompt) - if extra_pnginfo is not None: - for x in extra_pnginfo: - metadata[x] = json.dumps(extra_pnginfo[x]) - - # Opus supported sample rates - OPUS_RATES = [8000, 12000, 16000, 24000, 48000] - - for (batch_number, waveform) in enumerate(audio["waveform"].cpu()): - filename_with_batch_num = filename.replace("%batch_num%", str(batch_number)) - file = f"{filename_with_batch_num}_{counter:05}_.{format}" - output_path = os.path.join(full_output_folder, file) - - # Use original sample rate initially - sample_rate = audio["sample_rate"] - - # Handle Opus sample rate requirements - if format == "opus": - if sample_rate > 48000: - sample_rate = 48000 - elif sample_rate not in OPUS_RATES: - # Find the next highest supported rate - for rate in sorted(OPUS_RATES): - if rate > sample_rate: - sample_rate = rate - break - if sample_rate not in OPUS_RATES: # Fallback if still not supported - sample_rate = 48000 - - # Resample if necessary - if sample_rate != audio["sample_rate"]: - waveform = torchaudio.functional.resample(waveform, audio["sample_rate"], sample_rate) - - # Create output with specified format - output_buffer = io.BytesIO() - output_container = av.open(output_buffer, mode='w', format=format) - - # Set metadata on the container - for key, value in metadata.items(): - output_container.metadata[key] = value - - layout = 'mono' if waveform.shape[0] == 1 else 'stereo' - # Set up the output stream with appropriate properties - if format == "opus": - out_stream = output_container.add_stream("libopus", rate=sample_rate, layout=layout) - if quality == "64k": - out_stream.bit_rate = 64000 - elif quality == "96k": - out_stream.bit_rate = 96000 - elif quality == "128k": - out_stream.bit_rate = 128000 - elif quality == "192k": - out_stream.bit_rate = 192000 - elif quality == "320k": - out_stream.bit_rate = 320000 - elif format == "mp3": - out_stream = output_container.add_stream("libmp3lame", rate=sample_rate, layout=layout) - if quality == "V0": - #TODO i would really love to support V3 and V5 but there doesn't seem to be a way to set the qscale level, the property below is a bool - out_stream.codec_context.qscale = 1 - elif quality == "128k": - out_stream.bit_rate = 128000 - elif quality == "320k": - out_stream.bit_rate = 320000 - else: #format == "flac": - out_stream = output_container.add_stream("flac", rate=sample_rate, layout=layout) - - frame = av.AudioFrame.from_ndarray(waveform.movedim(0, 1).reshape(1, -1).float().numpy(), format='flt', layout=layout) - frame.sample_rate = sample_rate - frame.pts = 0 - output_container.mux(out_stream.encode(frame)) - - # Flush encoder - output_container.mux(out_stream.encode(None)) - - # Close containers - output_container.close() - - # Write the output to file - output_buffer.seek(0) - with open(output_path, 'wb') as f: - f.write(output_buffer.getbuffer()) - - results.append({ - "filename": file, - "subfolder": subfolder, - "type": self.type - }) - counter += 1 - - return { "ui": { "audio": results } } - -class SaveAudio: - def __init__(self): - self.output_dir = folder_paths.get_output_directory() - self.type = "output" - self.prefix_append = "" +class SaveAudio(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="SaveAudio", + display_name="Save Audio (FLAC)", + category="audio", + inputs=[ + IO.Audio.Input("audio"), + IO.String.Input("filename_prefix", default="audio/ComfyUI"), + ], + hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo], + is_output_node=True, + ) @classmethod - def INPUT_TYPES(s): - return {"required": { "audio": ("AUDIO", ), - "filename_prefix": ("STRING", {"default": "audio/ComfyUI"}), - }, - "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}, - } + def execute(cls, audio, filename_prefix="ComfyUI", format="flac") -> IO.NodeOutput: + return IO.NodeOutput( + ui=UI.AudioSaveHelper.get_save_audio_ui(audio, filename_prefix=filename_prefix, cls=cls, format=format) + ) - RETURN_TYPES = () - FUNCTION = "save_flac" + save_flac = execute # TODO: remove - OUTPUT_NODE = True - CATEGORY = "audio" - - def save_flac(self, audio, filename_prefix="ComfyUI", format="flac", prompt=None, extra_pnginfo=None): - return save_audio(self, audio, filename_prefix, format, prompt, extra_pnginfo) - -class SaveAudioMP3: - def __init__(self): - self.output_dir = folder_paths.get_output_directory() - self.type = "output" - self.prefix_append = "" +class SaveAudioMP3(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="SaveAudioMP3", + display_name="Save Audio (MP3)", + category="audio", + inputs=[ + IO.Audio.Input("audio"), + IO.String.Input("filename_prefix", default="audio/ComfyUI"), + IO.Combo.Input("quality", options=["V0", "128k", "320k"], default="V0"), + ], + hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo], + is_output_node=True, + ) @classmethod - def INPUT_TYPES(s): - return {"required": { "audio": ("AUDIO", ), - "filename_prefix": ("STRING", {"default": "audio/ComfyUI"}), - "quality": (["V0", "128k", "320k"], {"default": "V0"}), - }, - "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}, - } + def execute(cls, audio, filename_prefix="ComfyUI", format="mp3", quality="128k") -> IO.NodeOutput: + return IO.NodeOutput( + ui=UI.AudioSaveHelper.get_save_audio_ui( + audio, filename_prefix=filename_prefix, cls=cls, format=format, quality=quality + ) + ) - RETURN_TYPES = () - FUNCTION = "save_mp3" + save_mp3 = execute # TODO: remove - OUTPUT_NODE = True - CATEGORY = "audio" - - def save_mp3(self, audio, filename_prefix="ComfyUI", format="mp3", prompt=None, extra_pnginfo=None, quality="128k"): - return save_audio(self, audio, filename_prefix, format, prompt, extra_pnginfo, quality) - -class SaveAudioOpus: - def __init__(self): - self.output_dir = folder_paths.get_output_directory() - self.type = "output" - self.prefix_append = "" +class SaveAudioOpus(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="SaveAudioOpus", + display_name="Save Audio (Opus)", + category="audio", + inputs=[ + IO.Audio.Input("audio"), + IO.String.Input("filename_prefix", default="audio/ComfyUI"), + IO.Combo.Input("quality", options=["64k", "96k", "128k", "192k", "320k"], default="128k"), + ], + hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo], + is_output_node=True, + ) @classmethod - def INPUT_TYPES(s): - return {"required": { "audio": ("AUDIO", ), - "filename_prefix": ("STRING", {"default": "audio/ComfyUI"}), - "quality": (["64k", "96k", "128k", "192k", "320k"], {"default": "128k"}), - }, - "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}, - } + def execute(cls, audio, filename_prefix="ComfyUI", format="opus", quality="V3") -> IO.NodeOutput: + return IO.NodeOutput( + ui=UI.AudioSaveHelper.get_save_audio_ui( + audio, filename_prefix=filename_prefix, cls=cls, format=format, quality=quality + ) + ) - RETURN_TYPES = () - FUNCTION = "save_opus" + save_opus = execute # TODO: remove - OUTPUT_NODE = True - CATEGORY = "audio" - - def save_opus(self, audio, filename_prefix="ComfyUI", format="opus", prompt=None, extra_pnginfo=None, quality="V3"): - return save_audio(self, audio, filename_prefix, format, prompt, extra_pnginfo, quality) - -class PreviewAudio(SaveAudio): - def __init__(self): - self.output_dir = folder_paths.get_temp_directory() - self.type = "temp" - self.prefix_append = "_temp_" + ''.join(random.choice("abcdefghijklmnopqrstupvxyz") for x in range(5)) +class PreviewAudio(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="PreviewAudio", + display_name="Preview Audio", + category="audio", + inputs=[ + IO.Audio.Input("audio"), + ], + hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo], + is_output_node=True, + ) @classmethod - def INPUT_TYPES(s): - return {"required": - {"audio": ("AUDIO", ), }, - "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}, - } + def execute(cls, audio) -> IO.NodeOutput: + return IO.NodeOutput(ui=UI.PreviewAudio(audio, cls=cls)) + + save_flac = execute # TODO: remove + def f32_pcm(wav: torch.Tensor) -> torch.Tensor: """Convert audio to float 32 bits PCM format.""" @@ -316,26 +252,30 @@ def load(filepath: str) -> tuple[torch.Tensor, int]: wav = f32_pcm(wav) return wav, sr -class LoadAudio: +class LoadAudio(IO.ComfyNode): @classmethod - def INPUT_TYPES(s): + def define_schema(cls): input_dir = folder_paths.get_input_directory() files = folder_paths.filter_files_content_types(os.listdir(input_dir), ["audio", "video"]) - return {"required": {"audio": (sorted(files), {"audio_upload": True})}} + return IO.Schema( + node_id="LoadAudio", + display_name="Load Audio", + category="audio", + inputs=[ + IO.Combo.Input("audio", upload=IO.UploadType.audio, options=sorted(files)), + ], + outputs=[IO.Audio.Output()], + ) - CATEGORY = "audio" - - RETURN_TYPES = ("AUDIO", ) - FUNCTION = "load" - - def load(self, audio): + @classmethod + def execute(cls, audio) -> IO.NodeOutput: audio_path = folder_paths.get_annotated_filepath(audio) waveform, sample_rate = load(audio_path) audio = {"waveform": waveform.unsqueeze(0), "sample_rate": sample_rate} - return (audio, ) + return IO.NodeOutput(audio) @classmethod - def IS_CHANGED(s, audio): + def fingerprint_inputs(cls, audio): image_path = folder_paths.get_annotated_filepath(audio) m = hashlib.sha256() with open(image_path, 'rb') as f: @@ -343,46 +283,69 @@ class LoadAudio: return m.digest().hex() @classmethod - def VALIDATE_INPUTS(s, audio): + def validate_inputs(cls, audio): if not folder_paths.exists_annotated_filepath(audio): return "Invalid audio file: {}".format(audio) return True -class RecordAudio: + load = execute # TODO: remove + + +class RecordAudio(IO.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": {"audio": ("AUDIO_RECORD", {})}} + def define_schema(cls): + return IO.Schema( + node_id="RecordAudio", + display_name="Record Audio", + category="audio", + inputs=[ + IO.Custom("AUDIO_RECORD").Input("audio"), + ], + outputs=[IO.Audio.Output()], + ) - CATEGORY = "audio" - - RETURN_TYPES = ("AUDIO", ) - FUNCTION = "load" - - def load(self, audio): + @classmethod + def execute(cls, audio) -> IO.NodeOutput: audio_path = folder_paths.get_annotated_filepath(audio) waveform, sample_rate = load(audio_path) audio = {"waveform": waveform.unsqueeze(0), "sample_rate": sample_rate} - return (audio, ) + return IO.NodeOutput(audio) + + load = execute # TODO: remove -class TrimAudioDuration: +class TrimAudioDuration(IO.ComfyNode): @classmethod - def INPUT_TYPES(cls): - return { - "required": { - "audio": ("AUDIO",), - "start_index": ("FLOAT", {"default": 0.0, "min": -0xffffffffffffffff, "max": 0xffffffffffffffff, "step": 0.01, "tooltip": "Start time in seconds, can be negative to count from the end (supports sub-seconds)."}), - "duration": ("FLOAT", {"default": 60.0, "min": 0.0, "step": 0.01, "tooltip": "Duration in seconds"}), - }, - } + def define_schema(cls): + return IO.Schema( + node_id="TrimAudioDuration", + display_name="Trim Audio Duration", + description="Trim audio tensor into chosen time range.", + category="audio", + inputs=[ + IO.Audio.Input("audio"), + IO.Float.Input( + "start_index", + default=0.0, + min=-0xffffffffffffffff, + max=0xffffffffffffffff, + step=0.01, + tooltip="Start time in seconds, can be negative to count from the end (supports sub-seconds).", + ), + IO.Float.Input( + "duration", + default=60.0, + min=0.0, + step=0.01, + tooltip="Duration in seconds", + ), + ], + outputs=[IO.Audio.Output()], + ) - FUNCTION = "trim" - RETURN_TYPES = ("AUDIO",) - CATEGORY = "audio" - DESCRIPTION = "Trim audio tensor into chosen time range." - - def trim(self, audio, start_index, duration): + @classmethod + def execute(cls, audio, start_index, duration) -> IO.NodeOutput: waveform = audio["waveform"] sample_rate = audio["sample_rate"] audio_length = waveform.shape[-1] @@ -399,23 +362,30 @@ class TrimAudioDuration: if start_frame >= end_frame: raise ValueError("AudioTrim: Start time must be less than end time and be within the audio length.") - return ({"waveform": waveform[..., start_frame:end_frame], "sample_rate": sample_rate},) + return IO.NodeOutput({"waveform": waveform[..., start_frame:end_frame], "sample_rate": sample_rate}) + + trim = execute # TODO: remove -class SplitAudioChannels: +class SplitAudioChannels(IO.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { - "audio": ("AUDIO",), - }} + def define_schema(cls): + return IO.Schema( + node_id="SplitAudioChannels", + display_name="Split Audio Channels", + description="Separates the audio into left and right channels.", + category="audio", + inputs=[ + IO.Audio.Input("audio"), + ], + outputs=[ + IO.Audio.Output(display_name="left"), + IO.Audio.Output(display_name="right"), + ], + ) - RETURN_TYPES = ("AUDIO", "AUDIO") - RETURN_NAMES = ("left", "right") - FUNCTION = "separate" - CATEGORY = "audio" - DESCRIPTION = "Separates the audio into left and right channels." - - def separate(self, audio): + @classmethod + def execute(cls, audio) -> IO.NodeOutput: waveform = audio["waveform"] sample_rate = audio["sample_rate"] @@ -425,7 +395,9 @@ class SplitAudioChannels: left_channel = waveform[..., 0:1, :] right_channel = waveform[..., 1:2, :] - return ({"waveform": left_channel, "sample_rate": sample_rate}, {"waveform": right_channel, "sample_rate": sample_rate}) + return IO.NodeOutput({"waveform": left_channel, "sample_rate": sample_rate}, {"waveform": right_channel, "sample_rate": sample_rate}) + + separate = execute # TODO: remove def match_audio_sample_rates(waveform_1, sample_rate_1, waveform_2, sample_rate_2): @@ -443,21 +415,29 @@ def match_audio_sample_rates(waveform_1, sample_rate_1, waveform_2, sample_rate_ return waveform_1, waveform_2, output_sample_rate -class AudioConcat: +class AudioConcat(IO.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { - "audio1": ("AUDIO",), - "audio2": ("AUDIO",), - "direction": (['after', 'before'], {"default": 'after', "tooltip": "Whether to append audio2 after or before audio1."}), - }} + def define_schema(cls): + return IO.Schema( + node_id="AudioConcat", + display_name="Audio Concat", + description="Concatenates the audio1 to audio2 in the specified direction.", + category="audio", + inputs=[ + IO.Audio.Input("audio1"), + IO.Audio.Input("audio2"), + IO.Combo.Input( + "direction", + options=['after', 'before'], + default="after", + tooltip="Whether to append audio2 after or before audio1.", + ) + ], + outputs=[IO.Audio.Output()], + ) - RETURN_TYPES = ("AUDIO",) - FUNCTION = "concat" - CATEGORY = "audio" - DESCRIPTION = "Concatenates the audio1 to audio2 in the specified direction." - - def concat(self, audio1, audio2, direction): + @classmethod + def execute(cls, audio1, audio2, direction) -> IO.NodeOutput: waveform_1 = audio1["waveform"] waveform_2 = audio2["waveform"] sample_rate_1 = audio1["sample_rate"] @@ -477,26 +457,33 @@ class AudioConcat: elif direction == 'before': concatenated_audio = torch.cat((waveform_2, waveform_1), dim=2) - return ({"waveform": concatenated_audio, "sample_rate": output_sample_rate},) + return IO.NodeOutput({"waveform": concatenated_audio, "sample_rate": output_sample_rate}) + + concat = execute # TODO: remove -class AudioMerge: +class AudioMerge(IO.ComfyNode): @classmethod - def INPUT_TYPES(cls): - return { - "required": { - "audio1": ("AUDIO",), - "audio2": ("AUDIO",), - "merge_method": (["add", "mean", "subtract", "multiply"], {"tooltip": "The method used to combine the audio waveforms."}), - }, - } + def define_schema(cls): + return IO.Schema( + node_id="AudioMerge", + display_name="Audio Merge", + description="Combine two audio tracks by overlaying their waveforms.", + category="audio", + inputs=[ + IO.Audio.Input("audio1"), + IO.Audio.Input("audio2"), + IO.Combo.Input( + "merge_method", + options=["add", "mean", "subtract", "multiply"], + tooltip="The method used to combine the audio waveforms.", + ) + ], + outputs=[IO.Audio.Output()], + ) - FUNCTION = "merge" - RETURN_TYPES = ("AUDIO",) - CATEGORY = "audio" - DESCRIPTION = "Combine two audio tracks by overlaying their waveforms." - - def merge(self, audio1, audio2, merge_method): + @classmethod + def execute(cls, audio1, audio2, merge_method) -> IO.NodeOutput: waveform_1 = audio1["waveform"] waveform_2 = audio2["waveform"] sample_rate_1 = audio1["sample_rate"] @@ -530,85 +517,108 @@ class AudioMerge: if max_val > 1.0: waveform = waveform / max_val - return ({"waveform": waveform, "sample_rate": output_sample_rate},) + return IO.NodeOutput({"waveform": waveform, "sample_rate": output_sample_rate}) + + merge = execute # TODO: remove -class AudioAdjustVolume: +class AudioAdjustVolume(IO.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { - "audio": ("AUDIO",), - "volume": ("INT", {"default": 1.0, "min": -100, "max": 100, "tooltip": "Volume adjustment in decibels (dB). 0 = no change, +6 = double, -6 = half, etc"}), - }} + def define_schema(cls): + return IO.Schema( + node_id="AudioAdjustVolume", + display_name="Audio Adjust Volume", + category="audio", + inputs=[ + IO.Audio.Input("audio"), + IO.Int.Input( + "volume", + default=1, + min=-100, + max=100, + tooltip="Volume adjustment in decibels (dB). 0 = no change, +6 = double, -6 = half, etc", + ) + ], + outputs=[IO.Audio.Output()], + ) - RETURN_TYPES = ("AUDIO",) - FUNCTION = "adjust_volume" - CATEGORY = "audio" - - def adjust_volume(self, audio, volume): + @classmethod + def execute(cls, audio, volume) -> IO.NodeOutput: if volume == 0: - return (audio,) + return IO.NodeOutput(audio) waveform = audio["waveform"] sample_rate = audio["sample_rate"] gain = 10 ** (volume / 20) waveform = waveform * gain - return ({"waveform": waveform, "sample_rate": sample_rate},) + return IO.NodeOutput({"waveform": waveform, "sample_rate": sample_rate}) + + adjust_volume = execute # TODO: remove -class EmptyAudio: +class EmptyAudio(IO.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { - "duration": ("FLOAT", {"default": 60.0, "min": 0.0, "max": 0xffffffffffffffff, "step": 0.01, "tooltip": "Duration of the empty audio clip in seconds"}), - "sample_rate": ("INT", {"default": 44100, "tooltip": "Sample rate of the empty audio clip."}), - "channels": ("INT", {"default": 2, "min": 1, "max": 2, "tooltip": "Number of audio channels (1 for mono, 2 for stereo)."}), - }} + def define_schema(cls): + return IO.Schema( + node_id="EmptyAudio", + display_name="Empty Audio", + category="audio", + inputs=[ + IO.Float.Input( + "duration", + default=60.0, + min=0.0, + max=0xffffffffffffffff, + step=0.01, + tooltip="Duration of the empty audio clip in seconds", + ), + IO.Float.Input( + "sample_rate", + default=44100, + tooltip="Sample rate of the empty audio clip.", + ), + IO.Float.Input( + "channels", + default=2, + min=1, + max=2, + tooltip="Number of audio channels (1 for mono, 2 for stereo).", + ), + ], + outputs=[IO.Audio.Output()], + ) - RETURN_TYPES = ("AUDIO",) - FUNCTION = "create_empty_audio" - CATEGORY = "audio" - - def create_empty_audio(self, duration, sample_rate, channels): + @classmethod + def execute(cls, duration, sample_rate, channels) -> IO.NodeOutput: num_samples = int(round(duration * sample_rate)) waveform = torch.zeros((1, channels, num_samples), dtype=torch.float32) - return ({"waveform": waveform, "sample_rate": sample_rate},) + return IO.NodeOutput({"waveform": waveform, "sample_rate": sample_rate}) + + create_empty_audio = execute # TODO: remove -NODE_CLASS_MAPPINGS = { - "EmptyLatentAudio": EmptyLatentAudio, - "VAEEncodeAudio": VAEEncodeAudio, - "VAEDecodeAudio": VAEDecodeAudio, - "SaveAudio": SaveAudio, - "SaveAudioMP3": SaveAudioMP3, - "SaveAudioOpus": SaveAudioOpus, - "LoadAudio": LoadAudio, - "PreviewAudio": PreviewAudio, - "ConditioningStableAudio": ConditioningStableAudio, - "RecordAudio": RecordAudio, - "TrimAudioDuration": TrimAudioDuration, - "SplitAudioChannels": SplitAudioChannels, - "AudioConcat": AudioConcat, - "AudioMerge": AudioMerge, - "AudioAdjustVolume": AudioAdjustVolume, - "EmptyAudio": EmptyAudio, -} +class AudioExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[IO.ComfyNode]]: + return [ + EmptyLatentAudio, + VAEEncodeAudio, + VAEDecodeAudio, + SaveAudio, + SaveAudioMP3, + SaveAudioOpus, + LoadAudio, + PreviewAudio, + ConditioningStableAudio, + RecordAudio, + TrimAudioDuration, + SplitAudioChannels, + AudioConcat, + AudioMerge, + AudioAdjustVolume, + EmptyAudio, + ] -NODE_DISPLAY_NAME_MAPPINGS = { - "EmptyLatentAudio": "Empty Latent Audio", - "VAEEncodeAudio": "VAE Encode Audio", - "VAEDecodeAudio": "VAE Decode Audio", - "PreviewAudio": "Preview Audio", - "LoadAudio": "Load Audio", - "SaveAudio": "Save Audio (FLAC)", - "SaveAudioMP3": "Save Audio (MP3)", - "SaveAudioOpus": "Save Audio (Opus)", - "RecordAudio": "Record Audio", - "TrimAudioDuration": "Trim Audio Duration", - "SplitAudioChannels": "Split Audio Channels", - "AudioConcat": "Audio Concat", - "AudioMerge": "Audio Merge", - "AudioAdjustVolume": "Audio Adjust Volume", - "EmptyAudio": "Empty Audio", -} +async def comfy_entrypoint() -> AudioExtension: + return AudioExtension() From ecdc8697d53919a9178bf53ef327a110582db8ea Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 3 Dec 2025 19:49:28 -0800 Subject: [PATCH 203/231] Qwen Image Lora training fix from #11090 (#11094) --- comfy_extras/nodes_train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy_extras/nodes_train.py b/comfy_extras/nodes_train.py index cb24ab709..19b8baaf4 100644 --- a/comfy_extras/nodes_train.py +++ b/comfy_extras/nodes_train.py @@ -623,7 +623,7 @@ class TrainLoraNode(io.ComfyNode): noise = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(seed) if multi_res: # use first latent as dummy latent if multi_res - latents = latents[0].repeat(num_images, 1, 1, 1) + latents = latents[0].repeat((num_images,) + ((1,) * (latents[0].ndim - 1))) guider.sample( noise.generate_noise({"samples": latents}), latents, From ea17add3c62197b10fd0b71d9169d339adc55c47 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 3 Dec 2025 20:15:15 -0800 Subject: [PATCH 204/231] Fix case where text encoders where running on the CPU instead of GPU. (#11095) --- comfy/sd.py | 2 ++ comfy/sd1_clip.py | 9 ++++++++- 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/comfy/sd.py b/comfy/sd.py index f9e5efab5..734bd2845 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -193,6 +193,7 @@ class CLIP: self.cond_stage_model.set_clip_options({"projected_pooled": False}) self.load_model() + self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device}) all_hooks.reset() self.patcher.patch_hooks(None) if show_pbar: @@ -240,6 +241,7 @@ class CLIP: self.cond_stage_model.set_clip_options({"projected_pooled": False}) self.load_model() + self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device}) o = self.cond_stage_model.encode_token_weights(tokens) cond, pooled = o[:2] if return_dict: diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index 0fc9ab3db..503a51843 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -147,6 +147,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): self.layer_norm_hidden_state = layer_norm_hidden_state self.return_projected_pooled = return_projected_pooled self.return_attention_masks = return_attention_masks + self.execution_device = None if layer == "hidden": assert layer_idx is not None @@ -163,6 +164,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): def set_clip_options(self, options): layer_idx = options.get("layer", self.layer_idx) self.return_projected_pooled = options.get("projected_pooled", self.return_projected_pooled) + self.execution_device = options.get("execution_device", self.execution_device) if isinstance(self.layer, list) or self.layer == "all": pass elif layer_idx is None or abs(layer_idx) > self.num_layers: @@ -175,6 +177,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): self.layer = self.options_default[0] self.layer_idx = self.options_default[1] self.return_projected_pooled = self.options_default[2] + self.execution_device = None def process_tokens(self, tokens, device): end_token = self.special_tokens.get("end", None) @@ -258,7 +261,11 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): return torch.cat(embeds_out), torch.tensor(attention_masks, device=device, dtype=torch.long), num_tokens, embeds_info def forward(self, tokens): - device = self.transformer.get_input_embeddings().weight.device + if self.execution_device is None: + device = self.transformer.get_input_embeddings().weight.device + else: + device = self.execution_device + embeds, attention_mask, num_tokens, embeds_info = self.process_tokens(tokens, device) attention_mask_model = None From 6be85c7920224b45bbc6417e00147815e78c12a9 Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Thu, 4 Dec 2025 14:28:44 +1000 Subject: [PATCH 205/231] mp: use look-ahead actuals for stream offload VRAM calculation (#11096) TIL that the WAN TE has a 2GB weight followed by 16MB as the next size down. This means that team 8GB VRAM would fully offload the TE in async offload mode as it just multiplied this giant size my the num streams. Do the more complex logic of summing up the upcoming to-load weight sizes to avoid triple counting this massive weight. partial unload does the converse of recording the NS most recent unloads as they go. --- comfy/model_patcher.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index df2d8e827..3dcac3eef 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -699,12 +699,12 @@ class ModelPatcher: offloaded = [] offload_buffer = 0 loading.sort(reverse=True) - for x in loading: + for i, x in enumerate(loading): module_offload_mem, module_mem, n, m, params = x lowvram_weight = False - potential_offload = max(offload_buffer, module_offload_mem + (comfy.model_management.NUM_STREAMS * module_mem)) + potential_offload = max(offload_buffer, module_offload_mem + sum([ x1[1] for x1 in loading[i+1:i+1+comfy.model_management.NUM_STREAMS]])) lowvram_fits = mem_counter + module_mem + potential_offload < lowvram_model_memory weight_key = "{}.weight".format(n) @@ -876,14 +876,18 @@ class ModelPatcher: patch_counter = 0 unload_list = self._load_list() unload_list.sort() + offload_buffer = self.model.model_offload_buffer_memory + if len(unload_list) > 0: + NS = comfy.model_management.NUM_STREAMS + offload_weight_factor = [ min(offload_buffer / (NS + 1), unload_list[0][1]) ] * NS for unload in unload_list: if memory_to_free + offload_buffer - self.model.model_offload_buffer_memory < memory_freed: break module_offload_mem, module_mem, n, m, params = unload - potential_offload = module_offload_mem + (comfy.model_management.NUM_STREAMS * module_mem) + potential_offload = module_offload_mem + sum(offload_weight_factor) lowvram_possible = hasattr(m, "comfy_cast_weights") if hasattr(m, "comfy_patched_weights") and m.comfy_patched_weights == True: @@ -935,6 +939,8 @@ class ModelPatcher: m.comfy_patched_weights = False memory_freed += module_mem offload_buffer = max(offload_buffer, potential_offload) + offload_weight_factor.append(module_mem) + offload_weight_factor.pop(0) logging.debug("freed {}".format(n)) for param in params: From f4bdf5f8302ef10db99644a8672e614ddb29c473 Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Fri, 5 Dec 2025 03:50:04 +1000 Subject: [PATCH 206/231] sd: revise hy VAE VRAM (#11105) This was recently collapsed down to rolling VAE through temporal. Clamp The time dimension. --- comfy/sd.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/comfy/sd.py b/comfy/sd.py index 734bd2845..fe4dd65f8 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -483,8 +483,10 @@ class VAE: self.latent_dim = 3 self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.conv.weight"].shape[1] self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=sd['post_quant_conv.weight'].shape[1]) - self.memory_used_decode = lambda shape, dtype: (1500 * shape[2] * shape[3] * shape[4] * (4 * 8 * 8)) * model_management.dtype_size(dtype) - self.memory_used_encode = lambda shape, dtype: (900 * max(shape[2], 2) * shape[3] * shape[4]) * model_management.dtype_size(dtype) + #This is likely to significantly over-estimate with single image or low frame counts as the + #implementation is able to completely skip caching. Rework if used as an image only VAE + self.memory_used_decode = lambda shape, dtype: (2800 * min(8, ((shape[2] - 1) * 4) + 1) * shape[3] * shape[4] * (8 * 8)) * model_management.dtype_size(dtype) + self.memory_used_encode = lambda shape, dtype: (1400 * min(9, shape[2]) * shape[3] * shape[4]) * model_management.dtype_size(dtype) self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32] elif "decoder.unpatcher3d.wavelets" in sd: self.upscale_ratio = (lambda a: max(0, a * 8 - 7), 8, 8) From 9bc893c5bbd2838bdd15ebd40e3a3e548ce3e4f0 Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Fri, 5 Dec 2025 03:50:36 +1000 Subject: [PATCH 207/231] sd: bump HY1.5 VAE estimate (#11107) Im able to push vram above estimate on partial unload. Bump the estimate. This is experimentally determined with a 720P and 480P datapoint calibrating for 24GB VRAM total. --- comfy/sd.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/sd.py b/comfy/sd.py index fe4dd65f8..03bdb33d5 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -471,7 +471,7 @@ class VAE: decoder_config={'target': "comfy.ldm.hunyuan_video.vae_refiner.Decoder", 'params': ddconfig}) self.memory_used_encode = lambda shape, dtype: (1400 * 9 * shape[-2] * shape[-1]) * model_management.dtype_size(dtype) - self.memory_used_decode = lambda shape, dtype: (2800 * 4 * shape[-2] * shape[-1] * 16 * 16) * model_management.dtype_size(dtype) + self.memory_used_decode = lambda shape, dtype: (3600 * 4 * shape[-2] * shape[-1] * 16 * 16) * model_management.dtype_size(dtype) elif "decoder.conv_in.conv.weight" in sd: ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0} ddconfig["conv3d"] = True From 3c8456223c5f6a41af7d99219b391c8c58acb552 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Fri, 5 Dec 2025 00:05:28 +0200 Subject: [PATCH 208/231] [API Nodes]: fixes and refactor (#11104) * chore(api-nodes): applied ruff's pyupgrade(python3.10) to api-nodes client's to folder * chore(api-nodes): add validate_video_frame_count function from LTX PR * chore(api-nodes): replace deprecated V1 imports * fix(api-nodes): the types returned by the "poll_op" function are now correct. --- comfy_api_nodes/util/__init__.py | 2 + comfy_api_nodes/util/_helpers.py | 14 +-- comfy_api_nodes/util/client.py | 145 ++++++++++++----------- comfy_api_nodes/util/conversions.py | 21 ++-- comfy_api_nodes/util/download_helpers.py | 20 ++-- comfy_api_nodes/util/request_logger.py | 2 - comfy_api_nodes/util/upload_helpers.py | 16 ++- comfy_api_nodes/util/validation_utils.py | 61 ++++++---- 8 files changed, 146 insertions(+), 135 deletions(-) diff --git a/comfy_api_nodes/util/__init__.py b/comfy_api_nodes/util/__init__.py index 80292fb3c..4cc22abfb 100644 --- a/comfy_api_nodes/util/__init__.py +++ b/comfy_api_nodes/util/__init__.py @@ -47,6 +47,7 @@ from .validation_utils import ( validate_string, validate_video_dimensions, validate_video_duration, + validate_video_frame_count, ) __all__ = [ @@ -94,6 +95,7 @@ __all__ = [ "validate_string", "validate_video_dimensions", "validate_video_duration", + "validate_video_frame_count", # Misc functions "get_fs_object_size", ] diff --git a/comfy_api_nodes/util/_helpers.py b/comfy_api_nodes/util/_helpers.py index 328fe5227..491e6b6a8 100644 --- a/comfy_api_nodes/util/_helpers.py +++ b/comfy_api_nodes/util/_helpers.py @@ -2,8 +2,8 @@ import asyncio import contextlib import os import time +from collections.abc import Callable from io import BytesIO -from typing import Callable, Optional, Union from comfy.cli_args import args from comfy.model_management import processing_interrupted @@ -35,12 +35,12 @@ def default_base_url() -> str: async def sleep_with_interrupt( seconds: float, - node_cls: Optional[type[IO.ComfyNode]], - label: Optional[str] = None, - start_ts: Optional[float] = None, - estimated_total: Optional[int] = None, + node_cls: type[IO.ComfyNode] | None, + label: str | None = None, + start_ts: float | None = None, + estimated_total: int | None = None, *, - display_callback: Optional[Callable[[type[IO.ComfyNode], str, int, Optional[int]], None]] = None, + display_callback: Callable[[type[IO.ComfyNode], str, int, int | None], None] | None = None, ): """ Sleep in 1s slices while: @@ -65,7 +65,7 @@ def mimetype_to_extension(mime_type: str) -> str: return mime_type.split("/")[-1].lower() -def get_fs_object_size(path_or_object: Union[str, BytesIO]) -> int: +def get_fs_object_size(path_or_object: str | BytesIO) -> int: if isinstance(path_or_object, str): return os.path.getsize(path_or_object) return len(path_or_object.getvalue()) diff --git a/comfy_api_nodes/util/client.py b/comfy_api_nodes/util/client.py index bf01d7d36..bf37cba5f 100644 --- a/comfy_api_nodes/util/client.py +++ b/comfy_api_nodes/util/client.py @@ -4,10 +4,11 @@ import json import logging import time import uuid +from collections.abc import Callable, Iterable from dataclasses import dataclass from enum import Enum from io import BytesIO -from typing import Any, Callable, Iterable, Literal, Optional, Type, TypeVar, Union +from typing import Any, Literal, TypeVar from urllib.parse import urljoin, urlparse import aiohttp @@ -37,8 +38,8 @@ class ApiEndpoint: path: str, method: Literal["GET", "POST", "PUT", "DELETE", "PATCH"] = "GET", *, - query_params: Optional[dict[str, Any]] = None, - headers: Optional[dict[str, str]] = None, + query_params: dict[str, Any] | None = None, + headers: dict[str, str] | None = None, ): self.path = path self.method = method @@ -52,18 +53,18 @@ class _RequestConfig: endpoint: ApiEndpoint timeout: float content_type: str - data: Optional[dict[str, Any]] - files: Optional[Union[dict[str, Any], list[tuple[str, Any]]]] - multipart_parser: Optional[Callable] + data: dict[str, Any] | None + files: dict[str, Any] | list[tuple[str, Any]] | None + multipart_parser: Callable | None max_retries: int retry_delay: float retry_backoff: float wait_label: str = "Waiting" monitor_progress: bool = True - estimated_total: Optional[int] = None - final_label_on_success: Optional[str] = "Completed" - progress_origin_ts: Optional[float] = None - price_extractor: Optional[Callable[[dict[str, Any]], Optional[float]]] = None + estimated_total: int | None = None + final_label_on_success: str | None = "Completed" + progress_origin_ts: float | None = None + price_extractor: Callable[[dict[str, Any]], float | None] | None = None @dataclass @@ -71,10 +72,10 @@ class _PollUIState: started: float status_label: str = "Queued" is_queued: bool = True - price: Optional[float] = None - estimated_duration: Optional[int] = None + price: float | None = None + estimated_duration: int | None = None base_processing_elapsed: float = 0.0 # sum of completed active intervals - active_since: Optional[float] = None # start time of current active interval (None if queued) + active_since: float | None = None # start time of current active interval (None if queued) _RETRY_STATUS = {408, 429, 500, 502, 503, 504} @@ -87,20 +88,20 @@ async def sync_op( cls: type[IO.ComfyNode], endpoint: ApiEndpoint, *, - response_model: Type[M], - price_extractor: Optional[Callable[[M], Optional[float]]] = None, - data: Optional[BaseModel] = None, - files: Optional[Union[dict[str, Any], list[tuple[str, Any]]]] = None, + response_model: type[M], + price_extractor: Callable[[M | Any], float | None] | None = None, + data: BaseModel | None = None, + files: dict[str, Any] | list[tuple[str, Any]] | None = None, content_type: str = "application/json", timeout: float = 3600.0, - multipart_parser: Optional[Callable] = None, + multipart_parser: Callable | None = None, max_retries: int = 3, retry_delay: float = 1.0, retry_backoff: float = 2.0, wait_label: str = "Waiting for server", - estimated_duration: Optional[int] = None, - final_label_on_success: Optional[str] = "Completed", - progress_origin_ts: Optional[float] = None, + estimated_duration: int | None = None, + final_label_on_success: str | None = "Completed", + progress_origin_ts: float | None = None, monitor_progress: bool = True, ) -> M: raw = await sync_op_raw( @@ -131,22 +132,22 @@ async def poll_op( cls: type[IO.ComfyNode], poll_endpoint: ApiEndpoint, *, - response_model: Type[M], - status_extractor: Callable[[M], Optional[Union[str, int]]], - progress_extractor: Optional[Callable[[M], Optional[int]]] = None, - price_extractor: Optional[Callable[[M], Optional[float]]] = None, - completed_statuses: Optional[list[Union[str, int]]] = None, - failed_statuses: Optional[list[Union[str, int]]] = None, - queued_statuses: Optional[list[Union[str, int]]] = None, - data: Optional[BaseModel] = None, + response_model: type[M], + status_extractor: Callable[[M | Any], str | int | None], + progress_extractor: Callable[[M | Any], int | None] | None = None, + price_extractor: Callable[[M | Any], float | None] | None = None, + completed_statuses: list[str | int] | None = None, + failed_statuses: list[str | int] | None = None, + queued_statuses: list[str | int] | None = None, + data: BaseModel | None = None, poll_interval: float = 5.0, max_poll_attempts: int = 120, timeout_per_poll: float = 120.0, max_retries_per_poll: int = 3, retry_delay_per_poll: float = 1.0, retry_backoff_per_poll: float = 2.0, - estimated_duration: Optional[int] = None, - cancel_endpoint: Optional[ApiEndpoint] = None, + estimated_duration: int | None = None, + cancel_endpoint: ApiEndpoint | None = None, cancel_timeout: float = 10.0, ) -> M: raw = await poll_op_raw( @@ -178,22 +179,22 @@ async def sync_op_raw( cls: type[IO.ComfyNode], endpoint: ApiEndpoint, *, - price_extractor: Optional[Callable[[dict[str, Any]], Optional[float]]] = None, - data: Optional[Union[dict[str, Any], BaseModel]] = None, - files: Optional[Union[dict[str, Any], list[tuple[str, Any]]]] = None, + price_extractor: Callable[[dict[str, Any]], float | None] | None = None, + data: dict[str, Any] | BaseModel | None = None, + files: dict[str, Any] | list[tuple[str, Any]] | None = None, content_type: str = "application/json", timeout: float = 3600.0, - multipart_parser: Optional[Callable] = None, + multipart_parser: Callable | None = None, max_retries: int = 3, retry_delay: float = 1.0, retry_backoff: float = 2.0, wait_label: str = "Waiting for server", - estimated_duration: Optional[int] = None, + estimated_duration: int | None = None, as_binary: bool = False, - final_label_on_success: Optional[str] = "Completed", - progress_origin_ts: Optional[float] = None, + final_label_on_success: str | None = "Completed", + progress_origin_ts: float | None = None, monitor_progress: bool = True, -) -> Union[dict[str, Any], bytes]: +) -> dict[str, Any] | bytes: """ Make a single network request. - If as_binary=False (default): returns JSON dict (or {'_raw': ''} if non-JSON). @@ -229,21 +230,21 @@ async def poll_op_raw( cls: type[IO.ComfyNode], poll_endpoint: ApiEndpoint, *, - status_extractor: Callable[[dict[str, Any]], Optional[Union[str, int]]], - progress_extractor: Optional[Callable[[dict[str, Any]], Optional[int]]] = None, - price_extractor: Optional[Callable[[dict[str, Any]], Optional[float]]] = None, - completed_statuses: Optional[list[Union[str, int]]] = None, - failed_statuses: Optional[list[Union[str, int]]] = None, - queued_statuses: Optional[list[Union[str, int]]] = None, - data: Optional[Union[dict[str, Any], BaseModel]] = None, + status_extractor: Callable[[dict[str, Any]], str | int | None], + progress_extractor: Callable[[dict[str, Any]], int | None] | None = None, + price_extractor: Callable[[dict[str, Any]], float | None] | None = None, + completed_statuses: list[str | int] | None = None, + failed_statuses: list[str | int] | None = None, + queued_statuses: list[str | int] | None = None, + data: dict[str, Any] | BaseModel | None = None, poll_interval: float = 5.0, max_poll_attempts: int = 120, timeout_per_poll: float = 120.0, max_retries_per_poll: int = 3, retry_delay_per_poll: float = 1.0, retry_backoff_per_poll: float = 2.0, - estimated_duration: Optional[int] = None, - cancel_endpoint: Optional[ApiEndpoint] = None, + estimated_duration: int | None = None, + cancel_endpoint: ApiEndpoint | None = None, cancel_timeout: float = 10.0, ) -> dict[str, Any]: """ @@ -261,7 +262,7 @@ async def poll_op_raw( consumed_attempts = 0 # counts only non-queued polls progress_bar = utils.ProgressBar(100) if progress_extractor else None - last_progress: Optional[int] = None + last_progress: int | None = None state = _PollUIState(started=started, estimated_duration=estimated_duration) stop_ticker = asyncio.Event() @@ -420,10 +421,10 @@ async def poll_op_raw( def _display_text( node_cls: type[IO.ComfyNode], - text: Optional[str], + text: str | None, *, - status: Optional[Union[str, int]] = None, - price: Optional[float] = None, + status: str | int | None = None, + price: float | None = None, ) -> None: display_lines: list[str] = [] if status: @@ -440,13 +441,13 @@ def _display_text( def _display_time_progress( node_cls: type[IO.ComfyNode], - status: Optional[Union[str, int]], + status: str | int | None, elapsed_seconds: int, - estimated_total: Optional[int] = None, + estimated_total: int | None = None, *, - price: Optional[float] = None, - is_queued: Optional[bool] = None, - processing_elapsed_seconds: Optional[int] = None, + price: float | None = None, + is_queued: bool | None = None, + processing_elapsed_seconds: int | None = None, ) -> None: if estimated_total is not None and estimated_total > 0 and is_queued is False: pe = processing_elapsed_seconds if processing_elapsed_seconds is not None else elapsed_seconds @@ -488,7 +489,7 @@ def _unpack_tuple(t: tuple) -> tuple[str, Any, str]: raise ValueError("files tuple must be (filename, file[, content_type])") -def _merge_params(endpoint_params: dict[str, Any], method: str, data: Optional[dict[str, Any]]) -> dict[str, Any]: +def _merge_params(endpoint_params: dict[str, Any], method: str, data: dict[str, Any] | None) -> dict[str, Any]: params = dict(endpoint_params or {}) if method.upper() == "GET" and data: for k, v in data.items(): @@ -534,9 +535,9 @@ def _generate_operation_id(method: str, path: str, attempt: int) -> str: def _snapshot_request_body_for_logging( content_type: str, method: str, - data: Optional[dict[str, Any]], - files: Optional[Union[dict[str, Any], list[tuple[str, Any]]]], -) -> Optional[Union[dict[str, Any], str]]: + data: dict[str, Any] | None, + files: dict[str, Any] | list[tuple[str, Any]] | None, +) -> dict[str, Any] | str | None: if method.upper() == "GET": return None if content_type == "multipart/form-data": @@ -586,13 +587,13 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool): attempt = 0 delay = cfg.retry_delay operation_succeeded: bool = False - final_elapsed_seconds: Optional[int] = None - extracted_price: Optional[float] = None + final_elapsed_seconds: int | None = None + extracted_price: float | None = None while True: attempt += 1 stop_event = asyncio.Event() - monitor_task: Optional[asyncio.Task] = None - sess: Optional[aiohttp.ClientSession] = None + monitor_task: asyncio.Task | None = None + sess: aiohttp.ClientSession | None = None operation_id = _generate_operation_id(method, cfg.endpoint.path, attempt) logging.debug("[DEBUG] HTTP %s %s (attempt %d)", method, url, attempt) @@ -887,7 +888,7 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool): ) -def _validate_or_raise(response_model: Type[M], payload: Any) -> M: +def _validate_or_raise(response_model: type[M], payload: Any) -> M: try: return response_model.model_validate(payload) except Exception as e: @@ -902,9 +903,9 @@ def _validate_or_raise(response_model: Type[M], payload: Any) -> M: def _wrap_model_extractor( - response_model: Type[M], - extractor: Optional[Callable[[M], Any]], -) -> Optional[Callable[[dict[str, Any]], Any]]: + response_model: type[M], + extractor: Callable[[M], Any] | None, +) -> Callable[[dict[str, Any]], Any] | None: """Wrap a typed extractor so it can be used by the dict-based poller. Validates the dict into `response_model` before invoking `extractor`. Uses a small per-wrapper cache keyed by `id(dict)` to avoid re-validating @@ -929,10 +930,10 @@ def _wrap_model_extractor( return _wrapped -def _normalize_statuses(values: Optional[Iterable[Union[str, int]]]) -> set[Union[str, int]]: +def _normalize_statuses(values: Iterable[str | int] | None) -> set[str | int]: if not values: return set() - out: set[Union[str, int]] = set() + out: set[str | int] = set() for v in values: nv = _normalize_status_value(v) if nv is not None: @@ -940,7 +941,7 @@ def _normalize_statuses(values: Optional[Iterable[Union[str, int]]]) -> set[Unio return out -def _normalize_status_value(val: Union[str, int, None]) -> Union[str, int, None]: +def _normalize_status_value(val: str | int | None) -> str | int | None: if isinstance(val, str): return val.strip().lower() return val diff --git a/comfy_api_nodes/util/conversions.py b/comfy_api_nodes/util/conversions.py index 971dc57de..c57457580 100644 --- a/comfy_api_nodes/util/conversions.py +++ b/comfy_api_nodes/util/conversions.py @@ -4,7 +4,6 @@ import math import mimetypes import uuid from io import BytesIO -from typing import Optional import av import numpy as np @@ -12,8 +11,7 @@ import torch from PIL import Image from comfy.utils import common_upscale -from comfy_api.latest import Input, InputImpl -from comfy_api.util import VideoCodec, VideoContainer +from comfy_api.latest import Input, InputImpl, Types from ._helpers import mimetype_to_extension @@ -57,7 +55,7 @@ def image_tensor_pair_to_batch(image1: torch.Tensor, image2: torch.Tensor) -> to def tensor_to_bytesio( image: torch.Tensor, - name: Optional[str] = None, + name: str | None = None, total_pixels: int = 2048 * 2048, mime_type: str = "image/png", ) -> BytesIO: @@ -177,8 +175,8 @@ def audio_to_base64_string(audio: Input.Audio, container_format: str = "mp4", co def video_to_base64_string( video: Input.Video, - container_format: VideoContainer = None, - codec: VideoCodec = None + container_format: Types.VideoContainer | None = None, + codec: Types.VideoCodec | None = None, ) -> str: """ Converts a video input to a base64 string. @@ -189,12 +187,11 @@ def video_to_base64_string( codec: Optional codec to use (defaults to video.codec if available) """ video_bytes_io = BytesIO() - - # Use provided format/codec if specified, otherwise use video's own if available - format_to_use = container_format if container_format is not None else getattr(video, 'container', VideoContainer.MP4) - codec_to_use = codec if codec is not None else getattr(video, 'codec', VideoCodec.H264) - - video.save_to(video_bytes_io, format=format_to_use, codec=codec_to_use) + video.save_to( + video_bytes_io, + format=container_format or getattr(video, "container", Types.VideoContainer.MP4), + codec=codec or getattr(video, "codec", Types.VideoCodec.H264), + ) video_bytes_io.seek(0) return base64.b64encode(video_bytes_io.getvalue()).decode("utf-8") diff --git a/comfy_api_nodes/util/download_helpers.py b/comfy_api_nodes/util/download_helpers.py index 14207dc68..3e0d0352d 100644 --- a/comfy_api_nodes/util/download_helpers.py +++ b/comfy_api_nodes/util/download_helpers.py @@ -3,15 +3,15 @@ import contextlib import uuid from io import BytesIO from pathlib import Path -from typing import IO, Optional, Union +from typing import IO from urllib.parse import urljoin, urlparse import aiohttp import torch from aiohttp.client_exceptions import ClientError, ContentTypeError -from comfy_api.input_impl import VideoFromFile from comfy_api.latest import IO as COMFY_IO +from comfy_api.latest import InputImpl from . import request_logger from ._helpers import ( @@ -29,9 +29,9 @@ _RETRY_STATUS = {408, 429, 500, 502, 503, 504} async def download_url_to_bytesio( url: str, - dest: Optional[Union[BytesIO, IO[bytes], str, Path]], + dest: BytesIO | IO[bytes] | str | Path | None, *, - timeout: Optional[float] = None, + timeout: float | None = None, max_retries: int = 5, retry_delay: float = 1.0, retry_backoff: float = 2.0, @@ -71,10 +71,10 @@ async def download_url_to_bytesio( is_path_sink = isinstance(dest, (str, Path)) fhandle = None - session: Optional[aiohttp.ClientSession] = None - stop_evt: Optional[asyncio.Event] = None - monitor_task: Optional[asyncio.Task] = None - req_task: Optional[asyncio.Task] = None + session: aiohttp.ClientSession | None = None + stop_evt: asyncio.Event | None = None + monitor_task: asyncio.Task | None = None + req_task: asyncio.Task | None = None try: with contextlib.suppress(Exception): @@ -234,11 +234,11 @@ async def download_url_to_video_output( timeout: float = None, max_retries: int = 5, cls: type[COMFY_IO.ComfyNode] = None, -) -> VideoFromFile: +) -> InputImpl.VideoFromFile: """Downloads a video from a URL and returns a `VIDEO` output.""" result = BytesIO() await download_url_to_bytesio(video_url, result, timeout=timeout, max_retries=max_retries, cls=cls) - return VideoFromFile(result) + return InputImpl.VideoFromFile(result) async def download_url_as_bytesio( diff --git a/comfy_api_nodes/util/request_logger.py b/comfy_api_nodes/util/request_logger.py index ac52e2eab..e0cb4428d 100644 --- a/comfy_api_nodes/util/request_logger.py +++ b/comfy_api_nodes/util/request_logger.py @@ -1,5 +1,3 @@ -from __future__ import annotations - import datetime import hashlib import json diff --git a/comfy_api_nodes/util/upload_helpers.py b/comfy_api_nodes/util/upload_helpers.py index 0532bea9a..b8d33f4d1 100644 --- a/comfy_api_nodes/util/upload_helpers.py +++ b/comfy_api_nodes/util/upload_helpers.py @@ -4,15 +4,13 @@ import logging import time import uuid from io import BytesIO -from typing import Optional from urllib.parse import urlparse import aiohttp import torch from pydantic import BaseModel, Field -from comfy_api.latest import IO, Input -from comfy_api.util import VideoCodec, VideoContainer +from comfy_api.latest import IO, Input, Types from . import request_logger from ._helpers import is_processing_interrupted, sleep_with_interrupt @@ -32,7 +30,7 @@ from .conversions import ( class UploadRequest(BaseModel): file_name: str = Field(..., description="Filename to upload") - content_type: Optional[str] = Field( + content_type: str | None = Field( None, description="Mime type of the file. For example: image/png, image/jpeg, video/mp4, etc.", ) @@ -56,7 +54,7 @@ async def upload_images_to_comfyapi( Uploads images to ComfyUI API and returns download URLs. To upload multiple images, stack them in the batch dimension first. """ - # if batch, try to upload each file if max_images is greater than 0 + # if batched, try to upload each file if max_images is greater than 0 download_urls: list[str] = [] is_batch = len(image.shape) > 3 batch_len = image.shape[0] if is_batch else 1 @@ -100,9 +98,9 @@ async def upload_video_to_comfyapi( cls: type[IO.ComfyNode], video: Input.Video, *, - container: VideoContainer = VideoContainer.MP4, - codec: VideoCodec = VideoCodec.H264, - max_duration: Optional[int] = None, + container: Types.VideoContainer = Types.VideoContainer.MP4, + codec: Types.VideoCodec = Types.VideoCodec.H264, + max_duration: int | None = None, wait_label: str | None = "Uploading", ) -> str: """ @@ -220,7 +218,7 @@ async def upload_file( return monitor_task = asyncio.create_task(_monitor()) - sess: Optional[aiohttp.ClientSession] = None + sess: aiohttp.ClientSession | None = None try: try: request_logger.log_request_response( diff --git a/comfy_api_nodes/util/validation_utils.py b/comfy_api_nodes/util/validation_utils.py index ec7006aed..f01edea96 100644 --- a/comfy_api_nodes/util/validation_utils.py +++ b/comfy_api_nodes/util/validation_utils.py @@ -1,9 +1,7 @@ import logging -from typing import Optional import torch -from comfy_api.input.video_types import VideoInput from comfy_api.latest import Input @@ -18,10 +16,10 @@ def get_image_dimensions(image: torch.Tensor) -> tuple[int, int]: def validate_image_dimensions( image: torch.Tensor, - min_width: Optional[int] = None, - max_width: Optional[int] = None, - min_height: Optional[int] = None, - max_height: Optional[int] = None, + min_width: int | None = None, + max_width: int | None = None, + min_height: int | None = None, + max_height: int | None = None, ): height, width = get_image_dimensions(image) @@ -37,8 +35,8 @@ def validate_image_dimensions( def validate_image_aspect_ratio( image: torch.Tensor, - min_ratio: Optional[tuple[float, float]] = None, # e.g. (1, 4) - max_ratio: Optional[tuple[float, float]] = None, # e.g. (4, 1) + min_ratio: tuple[float, float] | None = None, # e.g. (1, 4) + max_ratio: tuple[float, float] | None = None, # e.g. (4, 1) *, strict: bool = True, # True -> (min, max); False -> [min, max] ) -> float: @@ -54,8 +52,8 @@ def validate_image_aspect_ratio( def validate_images_aspect_ratio_closeness( first_image: torch.Tensor, second_image: torch.Tensor, - min_rel: float, # e.g. 0.8 - max_rel: float, # e.g. 1.25 + min_rel: float, # e.g. 0.8 + max_rel: float, # e.g. 1.25 *, strict: bool = False, # True -> (min, max); False -> [min, max] ) -> float: @@ -84,8 +82,8 @@ def validate_images_aspect_ratio_closeness( def validate_aspect_ratio_string( aspect_ratio: str, - min_ratio: Optional[tuple[float, float]] = None, # e.g. (1, 4) - max_ratio: Optional[tuple[float, float]] = None, # e.g. (4, 1) + min_ratio: tuple[float, float] | None = None, # e.g. (1, 4) + max_ratio: tuple[float, float] | None = None, # e.g. (4, 1) *, strict: bool = False, # True -> (min, max); False -> [min, max] ) -> float: @@ -97,10 +95,10 @@ def validate_aspect_ratio_string( def validate_video_dimensions( video: Input.Video, - min_width: Optional[int] = None, - max_width: Optional[int] = None, - min_height: Optional[int] = None, - max_height: Optional[int] = None, + min_width: int | None = None, + max_width: int | None = None, + min_height: int | None = None, + max_height: int | None = None, ): try: width, height = video.get_dimensions() @@ -120,8 +118,8 @@ def validate_video_dimensions( def validate_video_duration( video: Input.Video, - min_duration: Optional[float] = None, - max_duration: Optional[float] = None, + min_duration: float | None = None, + max_duration: float | None = None, ): try: duration = video.get_duration() @@ -136,6 +134,23 @@ def validate_video_duration( raise ValueError(f"Video duration must be at most {max_duration}s, got {duration}s") +def validate_video_frame_count( + video: Input.Video, + min_frame_count: int | None = None, + max_frame_count: int | None = None, +): + try: + frame_count = video.get_frame_count() + except Exception as e: + logging.error("Error getting frame count of video: %s", e) + return + + if min_frame_count is not None and min_frame_count > frame_count: + raise ValueError(f"Video frame count must be at least {min_frame_count}, got {frame_count}") + if max_frame_count is not None and frame_count > max_frame_count: + raise ValueError(f"Video frame count must be at most {max_frame_count}, got {frame_count}") + + def get_number_of_images(images): if isinstance(images, torch.Tensor): return images.shape[0] if images.ndim >= 4 else 1 @@ -144,8 +159,8 @@ def get_number_of_images(images): def validate_audio_duration( audio: Input.Audio, - min_duration: Optional[float] = None, - max_duration: Optional[float] = None, + min_duration: float | None = None, + max_duration: float | None = None, ) -> None: sr = int(audio["sample_rate"]) dur = int(audio["waveform"].shape[-1]) / sr @@ -177,7 +192,7 @@ def validate_string( ) -def validate_container_format_is_mp4(video: VideoInput) -> None: +def validate_container_format_is_mp4(video: Input.Video) -> None: """Validates video container format is MP4.""" container_format = video.get_container_format() if container_format not in ["mp4", "mov,mp4,m4a,3gp,3g2,mj2"]: @@ -194,8 +209,8 @@ def _ratio_from_tuple(r: tuple[float, float]) -> float: def _assert_ratio_bounds( ar: float, *, - min_ratio: Optional[tuple[float, float]] = None, - max_ratio: Optional[tuple[float, float]] = None, + min_ratio: tuple[float, float] | None = None, + max_ratio: tuple[float, float] | None = None, strict: bool = True, ) -> None: """Validate a numeric aspect ratio against optional min/max ratio bounds.""" From 35fa091340c60612dfb71cb6822dc23b99a5dac2 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Thu, 4 Dec 2025 19:52:09 -0800 Subject: [PATCH 209/231] Forgot to put this in README. (#11112) --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index 91fb510e1..ed857df9f 100644 --- a/README.md +++ b/README.md @@ -81,6 +81,7 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith - [Hunyuan Video](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_video/) - [Wan 2.1](https://comfyanonymous.github.io/ComfyUI_examples/wan/) - [Wan 2.2](https://comfyanonymous.github.io/ComfyUI_examples/wan22/) + - [Hunyuan Video 1.5](https://docs.comfy.org/tutorials/video/hunyuan/hunyuan-video-1-5) - Audio Models - [Stable Audio](https://comfyanonymous.github.io/ComfyUI_examples/audio/) - [ACE Step](https://comfyanonymous.github.io/ComfyUI_examples/audio/) From 0ec05b1481d12b299bc945dbd407b773cfb66483 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Fri, 5 Dec 2025 11:05:38 -0800 Subject: [PATCH 210/231] Remove line made unnecessary (and wrong) after transformer_options was added to NextDiT's _forward definition (#11118) --- comfy/ldm/lumina/model.py | 1 - 1 file changed, 1 deletion(-) diff --git a/comfy/ldm/lumina/model.py b/comfy/ldm/lumina/model.py index f1c1a0ec3..6c24fed9b 100644 --- a/comfy/ldm/lumina/model.py +++ b/comfy/ldm/lumina/model.py @@ -586,7 +586,6 @@ class NextDiT(nn.Module): cap_feats = self.cap_embedder(cap_feats) # (N, L, D) # todo check if able to batchify w.o. redundant compute patches = transformer_options.get("patches", {}) - transformer_options = kwargs.get("transformer_options", {}) x_is_tensor = isinstance(x, torch.Tensor) img, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, t, num_tokens, transformer_options=transformer_options) freqs_cis = freqs_cis.to(img.device) From 43071e3de3780f984a46549e90935a0bf405e9df Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Fri, 5 Dec 2025 11:35:42 -0800 Subject: [PATCH 211/231] Make old scaled fp8 format use the new mixed quant ops system. (#11000) --- comfy/model_base.py | 14 +- comfy/model_detection.py | 33 +--- comfy/model_patcher.py | 20 +-- comfy/ops.py | 145 +++++++----------- comfy/quant_ops.py | 30 ++-- comfy/sd.py | 68 ++++++-- comfy/sd1_clip.py | 22 +-- comfy/supported_models_base.py | 3 +- comfy/text_encoders/cosmos.py | 12 +- comfy/text_encoders/flux.py | 12 +- comfy/text_encoders/genmo.py | 6 +- comfy/text_encoders/hidream.py | 10 +- comfy/text_encoders/hunyuan_image.py | 12 +- comfy/text_encoders/hunyuan_video.py | 23 ++- comfy/text_encoders/lumina2.py | 6 +- comfy/text_encoders/omnigen2.py | 6 +- comfy/text_encoders/ovis.py | 5 +- comfy/text_encoders/pixart_t5.py | 6 +- comfy/text_encoders/qwen_image.py | 6 +- comfy/text_encoders/sd3_clip.py | 19 +-- comfy/text_encoders/wan.py | 6 +- comfy/text_encoders/z_image.py | 5 +- comfy/utils.py | 66 ++++++++ .../comfy_quant/test_mixed_precision.py | 18 ++- 24 files changed, 278 insertions(+), 275 deletions(-) diff --git a/comfy/model_base.py b/comfy/model_base.py index 9b76c285e..3cedd4f31 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -134,7 +134,7 @@ class BaseModel(torch.nn.Module): if not unet_config.get("disable_unet_model_creation", False): if model_config.custom_operations is None: fp8 = model_config.optimizations.get("fp8", False) - operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, scaled_fp8=model_config.scaled_fp8, model_config=model_config) + operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, model_config=model_config) else: operations = model_config.custom_operations self.diffusion_model = unet_model(**unet_config, device=device, operations=operations) @@ -329,18 +329,6 @@ class BaseModel(torch.nn.Module): extra_sds.append(self.model_config.process_clip_vision_state_dict_for_saving(clip_vision_state_dict)) unet_state_dict = self.diffusion_model.state_dict() - - if self.model_config.scaled_fp8 is not None: - unet_state_dict["scaled_fp8"] = torch.tensor([], dtype=self.model_config.scaled_fp8) - - # Save mixed precision metadata - if hasattr(self.model_config, 'layer_quant_config') and self.model_config.layer_quant_config: - metadata = { - "format_version": "1.0", - "layers": self.model_config.layer_quant_config - } - unet_state_dict["_quantization_metadata"] = metadata - unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict) if self.model_type == ModelType.V_PREDICTION: diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 7d0517e61..fd1907627 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -6,20 +6,6 @@ import math import logging import torch - -def detect_layer_quantization(metadata): - quant_key = "_quantization_metadata" - if metadata is not None and quant_key in metadata: - quant_metadata = metadata.pop(quant_key) - quant_metadata = json.loads(quant_metadata) - if isinstance(quant_metadata, dict) and "layers" in quant_metadata: - logging.info(f"Found quantization metadata (version {quant_metadata.get('format_version', 'unknown')})") - return quant_metadata["layers"] - else: - raise ValueError("Invalid quantization metadata format") - return None - - def count_blocks(state_dict_keys, prefix_string): count = 0 while True: @@ -767,22 +753,11 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal if model_config is None and use_base_if_no_match: model_config = comfy.supported_models_base.BASE(unet_config) - scaled_fp8_key = "{}scaled_fp8".format(unet_key_prefix) - if scaled_fp8_key in state_dict: - scaled_fp8_weight = state_dict.pop(scaled_fp8_key) - model_config.scaled_fp8 = scaled_fp8_weight.dtype - if model_config.scaled_fp8 == torch.float32: - model_config.scaled_fp8 = torch.float8_e4m3fn - if scaled_fp8_weight.nelement() == 2: - model_config.optimizations["fp8"] = False - else: - model_config.optimizations["fp8"] = True - # Detect per-layer quantization (mixed precision) - layer_quant_config = detect_layer_quantization(metadata) - if layer_quant_config: - model_config.layer_quant_config = layer_quant_config - logging.info(f"Detected mixed precision quantization: {len(layer_quant_config)} layers quantized") + quant_config = comfy.utils.detect_layer_quantization(state_dict, unet_key_prefix) + if quant_config: + model_config.quant_config = quant_config + logging.info("Detected mixed precision quantization") return model_config diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 3dcac3eef..215784874 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -126,27 +126,11 @@ class LowVramPatch: def __init__(self, key, patches, convert_func=None, set_func=None): self.key = key self.patches = patches - self.convert_func = convert_func + self.convert_func = convert_func # TODO: remove self.set_func = set_func def __call__(self, weight): - intermediate_dtype = weight.dtype - if self.convert_func is not None: - weight = self.convert_func(weight, inplace=False) - - if intermediate_dtype not in [torch.float32, torch.float16, torch.bfloat16]: #intermediate_dtype has to be one that is supported in math ops - intermediate_dtype = torch.float32 - out = comfy.lora.calculate_weight(self.patches[self.key], weight.to(intermediate_dtype), self.key, intermediate_dtype=intermediate_dtype) - if self.set_func is None: - return comfy.float.stochastic_rounding(out, weight.dtype, seed=string_to_seed(self.key)) - else: - return self.set_func(out, seed=string_to_seed(self.key), return_weight=True) - - out = comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=intermediate_dtype) - if self.set_func is not None: - return self.set_func(out, seed=string_to_seed(self.key), return_weight=True).to(dtype=intermediate_dtype) - else: - return out + return comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=weight.dtype) #The above patch logic may cast up the weight to fp32, and do math. Go with fp32 x 3 LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR = 3 diff --git a/comfy/ops.py b/comfy/ops.py index eae434e68..dc06709a1 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -23,6 +23,7 @@ from comfy.cli_args import args, PerformanceFeature import comfy.float import comfy.rmsnorm import contextlib +import json def run_every_op(): if torch.compiler.is_compiling(): @@ -422,22 +423,12 @@ def fp8_linear(self, input): if input.ndim == 3 or input.ndim == 2: w, bias, offload_stream = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype, offloadable=True) + scale_weight = torch.ones((), device=input.device, dtype=torch.float32) - scale_weight = self.scale_weight - scale_input = self.scale_input - if scale_weight is None: - scale_weight = torch.ones((), device=input.device, dtype=torch.float32) - else: - scale_weight = scale_weight.to(input.device) - - if scale_input is None: - scale_input = torch.ones((), device=input.device, dtype=torch.float32) - input = torch.clamp(input, min=-448, max=448, out=input) - layout_params_weight = {'scale': scale_input, 'orig_dtype': input_dtype} - quantized_input = QuantizedTensor(input.to(dtype).contiguous(), "TensorCoreFP8Layout", layout_params_weight) - else: - scale_input = scale_input.to(input.device) - quantized_input = QuantizedTensor.from_float(input, "TensorCoreFP8Layout", scale=scale_input, dtype=dtype) + scale_input = torch.ones((), device=input.device, dtype=torch.float32) + input = torch.clamp(input, min=-448, max=448, out=input) + layout_params_weight = {'scale': scale_input, 'orig_dtype': input_dtype} + quantized_input = QuantizedTensor(input.to(dtype).contiguous(), "TensorCoreFP8Layout", layout_params_weight) # Wrap weight in QuantizedTensor - this enables unified dispatch # Call F.linear - __torch_dispatch__ routes to fp8_linear handler in quant_ops.py! @@ -458,7 +449,7 @@ class fp8_ops(manual_cast): return None def forward_comfy_cast_weights(self, input): - if not self.training: + if len(self.weight_function) == 0 and len(self.bias_function) == 0: try: out = fp8_linear(self, input) if out is not None: @@ -471,59 +462,6 @@ class fp8_ops(manual_cast): uncast_bias_weight(self, weight, bias, offload_stream) return x -def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None): - logging.info("Using scaled fp8: fp8 matrix mult: {}, scale input: {}".format(fp8_matrix_mult, scale_input)) - class scaled_fp8_op(manual_cast): - class Linear(manual_cast.Linear): - def __init__(self, *args, **kwargs): - if override_dtype is not None: - kwargs['dtype'] = override_dtype - super().__init__(*args, **kwargs) - - def reset_parameters(self): - if not hasattr(self, 'scale_weight'): - self.scale_weight = torch.nn.parameter.Parameter(data=torch.ones((), device=self.weight.device, dtype=torch.float32), requires_grad=False) - - if not scale_input: - self.scale_input = None - - if not hasattr(self, 'scale_input'): - self.scale_input = torch.nn.parameter.Parameter(data=torch.ones((), device=self.weight.device, dtype=torch.float32), requires_grad=False) - return None - - def forward_comfy_cast_weights(self, input): - if fp8_matrix_mult: - out = fp8_linear(self, input) - if out is not None: - return out - - weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True) - - if weight.numel() < input.numel(): #TODO: optimize - x = torch.nn.functional.linear(input, weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype), bias) - else: - x = torch.nn.functional.linear(input * self.scale_weight.to(device=weight.device, dtype=weight.dtype), weight, bias) - uncast_bias_weight(self, weight, bias, offload_stream) - return x - - def convert_weight(self, weight, inplace=False, **kwargs): - if inplace: - weight *= self.scale_weight.to(device=weight.device, dtype=weight.dtype) - return weight - else: - return weight.to(dtype=torch.float32) * self.scale_weight.to(device=weight.device, dtype=torch.float32) - - def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False, **kwargs): - weight = comfy.float.stochastic_rounding(weight / self.scale_weight.to(device=weight.device, dtype=weight.dtype), self.weight.dtype, seed=seed) - if return_weight: - return weight - if inplace_update: - self.weight.data.copy_(weight) - else: - self.weight = torch.nn.Parameter(weight, requires_grad=False) - - return scaled_fp8_op - CUBLAS_IS_AVAILABLE = False try: from cublas_ops import CublasLinear @@ -550,9 +488,9 @@ if CUBLAS_IS_AVAILABLE: from .quant_ops import QuantizedTensor, QUANT_ALGOS -def mixed_precision_ops(layer_quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False): +def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False): class MixedPrecisionOps(manual_cast): - _layer_quant_config = layer_quant_config + _quant_config = quant_config _compute_dtype = compute_dtype _full_precision_mm = full_precision_mm @@ -595,27 +533,36 @@ def mixed_precision_ops(layer_quant_config={}, compute_dtype=torch.bfloat16, ful manually_loaded_keys = [weight_key] - if layer_name not in MixedPrecisionOps._layer_quant_config: + layer_conf = state_dict.pop(f"{prefix}comfy_quant", None) + if layer_conf is not None: + layer_conf = json.loads(layer_conf.numpy().tobytes()) + + if layer_conf is None: self.weight = torch.nn.Parameter(weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), requires_grad=False) else: - quant_format = MixedPrecisionOps._layer_quant_config[layer_name].get("format", None) - if quant_format is None: + self.quant_format = layer_conf.get("format", None) + if not self._full_precision_mm: + self._full_precision_mm = layer_conf.get("full_precision_matrix_mult", False) + + if self.quant_format is None: raise ValueError(f"Unknown quantization format for layer {layer_name}") - qconfig = QUANT_ALGOS[quant_format] + qconfig = QUANT_ALGOS[self.quant_format] self.layout_type = qconfig["comfy_tensor_layout"] weight_scale_key = f"{prefix}weight_scale" + scale = state_dict.pop(weight_scale_key, None) layout_params = { - 'scale': state_dict.pop(weight_scale_key, None), + 'scale': scale, 'orig_dtype': MixedPrecisionOps._compute_dtype, 'block_size': qconfig.get("group_size", None), } - if layout_params['scale'] is not None: + + if scale is not None: manually_loaded_keys.append(weight_scale_key) self.weight = torch.nn.Parameter( - QuantizedTensor(weight.to(device=device), self.layout_type, layout_params), + QuantizedTensor(weight.to(device=device, dtype=qconfig.get("storage_t", None)), self.layout_type, layout_params), requires_grad=False ) @@ -624,7 +571,7 @@ def mixed_precision_ops(layer_quant_config={}, compute_dtype=torch.bfloat16, ful _v = state_dict.pop(param_key, None) if _v is None: continue - setattr(self, param_name, torch.nn.Parameter(_v.to(device=device), requires_grad=False)) + self.register_parameter(param_name, torch.nn.Parameter(_v.to(device=device), requires_grad=False)) manually_loaded_keys.append(param_key) super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) @@ -633,6 +580,16 @@ def mixed_precision_ops(layer_quant_config={}, compute_dtype=torch.bfloat16, ful if key in missing_keys: missing_keys.remove(key) + def state_dict(self, *args, destination=None, prefix="", **kwargs): + sd = super().state_dict(*args, destination=destination, prefix=prefix, **kwargs) + if isinstance(self.weight, QuantizedTensor): + sd["{}weight_scale".format(prefix)] = self.weight._layout_params['scale'] + quant_conf = {"format": self.quant_format} + if self._full_precision_mm: + quant_conf["full_precision_matrix_mult"] = True + sd["{}comfy_quant".format(prefix)] = torch.frombuffer(json.dumps(quant_conf).encode('utf-8'), dtype=torch.uint8) + return sd + def _forward(self, input, weight, bias): return torch.nn.functional.linear(input, weight, bias) @@ -648,9 +605,8 @@ def mixed_precision_ops(layer_quant_config={}, compute_dtype=torch.bfloat16, ful if self._full_precision_mm or self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0: return self.forward_comfy_cast_weights(input, *args, **kwargs) if (getattr(self, 'layout_type', None) is not None and - getattr(self, 'input_scale', None) is not None and not isinstance(input, QuantizedTensor)): - input = QuantizedTensor.from_float(input, self.layout_type, scale=self.input_scale, dtype=self.weight.dtype) + input = QuantizedTensor.from_float(input, self.layout_type, scale=getattr(self, 'input_scale', None), dtype=self.weight.dtype) return self._forward(input, self.weight, self.bias) def convert_weight(self, weight, inplace=False, **kwargs): @@ -661,7 +617,7 @@ def mixed_precision_ops(layer_quant_config={}, compute_dtype=torch.bfloat16, ful def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False, **kwargs): if getattr(self, 'layout_type', None) is not None: - weight = QuantizedTensor.from_float(weight, self.layout_type, scale=None, dtype=self.weight.dtype, stochastic_rounding=seed, inplace_ops=True) + weight = QuantizedTensor.from_float(weight, self.layout_type, scale="recalculate", dtype=self.weight.dtype, stochastic_rounding=seed, inplace_ops=True) else: weight = weight.to(self.weight.dtype) if return_weight: @@ -670,17 +626,28 @@ def mixed_precision_ops(layer_quant_config={}, compute_dtype=torch.bfloat16, ful assert inplace_update is False # TODO: eventually remove the inplace_update stuff self.weight = torch.nn.Parameter(weight, requires_grad=False) + def _apply(self, fn, recurse=True): # This is to get torch.compile + moving weights to another device working + if recurse: + for module in self.children(): + module._apply(fn) + + for key, param in self._parameters.items(): + if param is None: + continue + self.register_parameter(key, torch.nn.Parameter(fn(param), requires_grad=False)) + for key, buf in self._buffers.items(): + if buf is not None: + self._buffers[key] = fn(buf) + return self + return MixedPrecisionOps -def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None, model_config=None): +def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, model_config=None): fp8_compute = comfy.model_management.supports_fp8_compute(load_device) # TODO: if we support more ops this needs to be more granular - if model_config and hasattr(model_config, 'layer_quant_config') and model_config.layer_quant_config: - logging.info(f"Using mixed precision operations: {len(model_config.layer_quant_config)} quantized layers") - return mixed_precision_ops(model_config.layer_quant_config, compute_dtype, full_precision_mm=not fp8_compute) - - if scaled_fp8 is not None: - return scaled_fp8_ops(fp8_matrix_mult=fp8_compute and fp8_optimizations, scale_input=fp8_optimizations, override_dtype=scaled_fp8) + if model_config and hasattr(model_config, 'quant_config') and model_config.quant_config: + logging.info("Using mixed precision operations") + return mixed_precision_ops(model_config.quant_config, compute_dtype, full_precision_mm=not fp8_compute) if ( fp8_compute and diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py index bb1fb860c..571d3f760 100644 --- a/comfy/quant_ops.py +++ b/comfy/quant_ops.py @@ -238,6 +238,9 @@ class QuantizedTensor(torch.Tensor): def is_contiguous(self, *arg, **kwargs): return self._qdata.is_contiguous(*arg, **kwargs) + def storage(self): + return self._qdata.storage() + # ============================================================================== # Generic Utilities (Layout-Agnostic Operations) # ============================================================================== @@ -249,12 +252,6 @@ def _create_transformed_qtensor(qt, transform_fn): def _handle_device_transfer(qt, target_device, target_dtype=None, target_layout=None, op_name="to"): - if target_dtype is not None and target_dtype != qt.dtype: - logging.warning( - f"QuantizedTensor: dtype conversion requested to {target_dtype}, " - f"but not supported for quantized tensors. Ignoring dtype." - ) - if target_layout is not None and target_layout != torch.strided: logging.warning( f"QuantizedTensor: layout change requested to {target_layout}, " @@ -274,6 +271,8 @@ def _handle_device_transfer(qt, target_device, target_dtype=None, target_layout= logging.debug(f"QuantizedTensor.{op_name}: Moving from {current_device} to {target_device}") new_q_data = qt._qdata.to(device=target_device) new_params = _move_layout_params_to_device(qt._layout_params, target_device) + if target_dtype is not None: + new_params["orig_dtype"] = target_dtype new_qt = QuantizedTensor(new_q_data, qt._layout_type, new_params) logging.debug(f"QuantizedTensor.{op_name}: Created new tensor on {target_device}") return new_qt @@ -339,7 +338,9 @@ def generic_copy_(func, args, kwargs): # Copy from another quantized tensor qt_dest._qdata.copy_(src._qdata, non_blocking=non_blocking) qt_dest._layout_type = src._layout_type + orig_dtype = qt_dest._layout_params["orig_dtype"] _copy_layout_params_inplace(src._layout_params, qt_dest._layout_params, non_blocking=non_blocking) + qt_dest._layout_params["orig_dtype"] = orig_dtype else: # Copy from regular tensor - just copy raw data qt_dest._qdata.copy_(src) @@ -397,17 +398,20 @@ class TensorCoreFP8Layout(QuantizedLayout): def quantize(cls, tensor, scale=None, dtype=torch.float8_e4m3fn, stochastic_rounding=0, inplace_ops=False): orig_dtype = tensor.dtype - if scale is None: + if isinstance(scale, str) and scale == "recalculate": scale = torch.amax(tensor.abs()) / torch.finfo(dtype).max - if not isinstance(scale, torch.Tensor): - scale = torch.tensor(scale) - scale = scale.to(device=tensor.device, dtype=torch.float32) + if scale is not None: + if not isinstance(scale, torch.Tensor): + scale = torch.tensor(scale) + scale = scale.to(device=tensor.device, dtype=torch.float32) - if inplace_ops: - tensor *= (1.0 / scale).to(tensor.dtype) + if inplace_ops: + tensor *= (1.0 / scale).to(tensor.dtype) + else: + tensor = tensor * (1.0 / scale).to(tensor.dtype) else: - tensor = tensor * (1.0 / scale).to(tensor.dtype) + scale = torch.ones((), device=tensor.device, dtype=torch.float32) if stochastic_rounding > 0: tensor = comfy.float.stochastic_rounding(tensor, dtype=dtype, seed=stochastic_rounding) diff --git a/comfy/sd.py b/comfy/sd.py index 03bdb33d5..092715d79 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -968,10 +968,8 @@ def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DI clip_data = [] for p in ckpt_paths: sd, metadata = comfy.utils.load_torch_file(p, safe_load=True, return_metadata=True) - if metadata is not None: - quant_metadata = metadata.get("_quantization_metadata", None) - if quant_metadata is not None: - sd["_quantization_metadata"] = quant_metadata + if model_options.get("custom_operations", None) is None: + sd, metadata = comfy.utils.convert_old_quants(sd, model_prefix="", metadata=metadata) clip_data.append(sd) return load_text_encoder_state_dicts(clip_data, embedding_directory=embedding_directory, clip_type=clip_type, model_options=model_options) @@ -1088,7 +1086,7 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=False, clip_g=True, t5=False) clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer elif clip_type == CLIPType.HIDREAM: - clip_target.clip = comfy.text_encoders.hidream.hidream_clip(clip_l=False, clip_g=True, t5=False, llama=False, dtype_t5=None, dtype_llama=None, t5xxl_scaled_fp8=None, llama_scaled_fp8=None) + clip_target.clip = comfy.text_encoders.hidream.hidream_clip(clip_l=False, clip_g=True, t5=False, llama=False, dtype_t5=None, dtype_llama=None) clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer else: clip_target.clip = sdxl_clip.SDXLRefinerClipModel @@ -1112,7 +1110,7 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None) elif clip_type == CLIPType.HIDREAM: clip_target.clip = comfy.text_encoders.hidream.hidream_clip(**t5xxl_detect(clip_data), - clip_l=False, clip_g=False, t5=True, llama=False, dtype_llama=None, llama_scaled_fp8=None) + clip_l=False, clip_g=False, t5=True, llama=False, dtype_llama=None) clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer else: #CLIPType.MOCHI clip_target.clip = comfy.text_encoders.genmo.mochi_te(**t5xxl_detect(clip_data)) @@ -1141,7 +1139,7 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None) elif te_model == TEModel.LLAMA3_8: clip_target.clip = comfy.text_encoders.hidream.hidream_clip(**llama_detect(clip_data), - clip_l=False, clip_g=False, t5=False, llama=True, dtype_t5=None, t5xxl_scaled_fp8=None) + clip_l=False, clip_g=False, t5=False, llama=True, dtype_t5=None) clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer elif te_model == TEModel.QWEN25_3B: clip_target.clip = comfy.text_encoders.omnigen2.te(**llama_detect(clip_data)) @@ -1169,7 +1167,7 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=True, clip_g=False, t5=False) clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer elif clip_type == CLIPType.HIDREAM: - clip_target.clip = comfy.text_encoders.hidream.hidream_clip(clip_l=True, clip_g=False, t5=False, llama=False, dtype_t5=None, dtype_llama=None, t5xxl_scaled_fp8=None, llama_scaled_fp8=None) + clip_target.clip = comfy.text_encoders.hidream.hidream_clip(clip_l=True, clip_g=False, t5=False, llama=False, dtype_t5=None, dtype_llama=None) clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer else: clip_target.clip = sd1_clip.SD1ClipModel @@ -1224,8 +1222,6 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip parameters = 0 for c in clip_data: - if "_quantization_metadata" in c: - c.pop("_quantization_metadata") parameters += comfy.utils.calculate_parameters(c) tokenizer_data, model_options = comfy.text_encoders.long_clipl.model_options_long_clip(c, tokenizer_data, model_options) @@ -1295,6 +1291,10 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c weight_dtype = comfy.utils.weight_dtype(sd, diffusion_model_prefix) load_device = model_management.get_torch_device() + custom_operations = model_options.get("custom_operations", None) + if custom_operations is None: + sd, metadata = comfy.utils.convert_old_quants(sd, diffusion_model_prefix, metadata=metadata) + model_config = model_detection.model_config_from_unet(sd, diffusion_model_prefix, metadata=metadata) if model_config is None: logging.warning("Warning, This is not a checkpoint file, trying to load it as a diffusion model only.") @@ -1303,18 +1303,22 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c return None return (diffusion_model, None, VAE(sd={}), None) # The VAE object is there to throw an exception if it's actually used' - unet_weight_dtype = list(model_config.supported_inference_dtypes) - if model_config.scaled_fp8 is not None: + if model_config.quant_config is not None: weight_dtype = None - model_config.custom_operations = model_options.get("custom_operations", None) + if custom_operations is not None: + model_config.custom_operations = custom_operations + unet_dtype = model_options.get("dtype", model_options.get("weight_dtype", None)) if unet_dtype is None: unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=unet_weight_dtype, weight_dtype=weight_dtype) - manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes) + if model_config.quant_config is not None: + manual_cast_dtype = model_management.unet_manual_cast(None, load_device, model_config.supported_inference_dtypes) + else: + manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes) model_config.set_inference_dtype(unet_dtype, manual_cast_dtype) if model_config.clip_vision_prefix is not None: @@ -1332,6 +1336,27 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c vae = VAE(sd=vae_sd, metadata=metadata) if output_clip: + if te_model_options.get("custom_operations", None) is None: + scaled_fp8_list = [] + for k in list(sd.keys()): # Convert scaled fp8 to mixed ops + if k.endswith(".scaled_fp8"): + scaled_fp8_list.append(k[:-len("scaled_fp8")]) + + if len(scaled_fp8_list) > 0: + out_sd = {} + for k in sd: + skip = False + for pref in scaled_fp8_list: + skip = skip or k.startswith(pref) + if not skip: + out_sd[k] = sd[k] + + for pref in scaled_fp8_list: + quant_sd, qmetadata = comfy.utils.convert_old_quants(sd, pref, metadata={}) + for k in quant_sd: + out_sd[k] = quant_sd[k] + sd = out_sd + clip_target = model_config.clip_target(state_dict=sd) if clip_target is not None: clip_sd = model_config.process_clip_state_dict(sd) @@ -1394,6 +1419,9 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None): if len(temp_sd) > 0: sd = temp_sd + custom_operations = model_options.get("custom_operations", None) + if custom_operations is None: + sd, metadata = comfy.utils.convert_old_quants(sd, "", metadata=metadata) parameters = comfy.utils.calculate_parameters(sd) weight_dtype = comfy.utils.weight_dtype(sd) @@ -1424,7 +1452,7 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None): offload_device = model_management.unet_offload_device() unet_weight_dtype = list(model_config.supported_inference_dtypes) - if model_config.scaled_fp8 is not None: + if model_config.quant_config is not None: weight_dtype = None if dtype is None: @@ -1432,12 +1460,15 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None): else: unet_dtype = dtype - if model_config.layer_quant_config is not None: + if model_config.quant_config is not None: manual_cast_dtype = model_management.unet_manual_cast(None, load_device, model_config.supported_inference_dtypes) else: manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes) model_config.set_inference_dtype(unet_dtype, manual_cast_dtype) - model_config.custom_operations = model_options.get("custom_operations", model_config.custom_operations) + + if custom_operations is not None: + model_config.custom_operations = custom_operations + if model_options.get("fp8_optimizations", False): model_config.optimizations["fp8"] = True @@ -1476,6 +1507,9 @@ def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, m if vae is not None: vae_sd = vae.get_sd() + if metadata is None: + metadata = {} + model_management.load_models_gpu(load_models, force_patch_weights=True) clip_vision_sd = clip_vision.get_sd() if clip_vision is not None else None sd = model.model.state_dict_for_saving(clip_sd, vae_sd, clip_vision_sd) diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index 503a51843..962948dae 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -107,29 +107,17 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): config[k] = v operations = model_options.get("custom_operations", None) - scaled_fp8 = None - quantization_metadata = model_options.get("quantization_metadata", None) + quant_config = model_options.get("quantization_metadata", None) if operations is None: - layer_quant_config = None - if quantization_metadata is not None: - layer_quant_config = json.loads(quantization_metadata).get("layers", None) - - if layer_quant_config is not None: - operations = comfy.ops.mixed_precision_ops(layer_quant_config, dtype, full_precision_mm=True) - logging.info(f"Using MixedPrecisionOps for text encoder: {len(layer_quant_config)} quantized layers") + if quant_config is not None: + operations = comfy.ops.mixed_precision_ops(quant_config, dtype, full_precision_mm=True) + logging.info("Using MixedPrecisionOps for text encoder") else: - # Fallback to scaled_fp8_ops for backward compatibility - scaled_fp8 = model_options.get("scaled_fp8", None) - if scaled_fp8 is not None: - operations = comfy.ops.scaled_fp8_ops(fp8_matrix_mult=False, override_dtype=scaled_fp8) - else: - operations = comfy.ops.manual_cast + operations = comfy.ops.manual_cast self.operations = operations self.transformer = model_class(config, dtype, device, self.operations) - if scaled_fp8 is not None: - self.transformer.scaled_fp8 = torch.nn.Parameter(torch.tensor([], dtype=scaled_fp8)) self.num_layers = self.transformer.num_layers diff --git a/comfy/supported_models_base.py b/comfy/supported_models_base.py index e4bd74514..9fd84d329 100644 --- a/comfy/supported_models_base.py +++ b/comfy/supported_models_base.py @@ -49,8 +49,7 @@ class BASE: manual_cast_dtype = None custom_operations = None - scaled_fp8 = None - layer_quant_config = None # Per-layer quantization configuration for mixed precision + quant_config = None # quantization configuration for mixed precision optimizations = {"fp8": False} @classmethod diff --git a/comfy/text_encoders/cosmos.py b/comfy/text_encoders/cosmos.py index a1adb5242..448381fa9 100644 --- a/comfy/text_encoders/cosmos.py +++ b/comfy/text_encoders/cosmos.py @@ -7,10 +7,10 @@ from transformers import T5TokenizerFast class T5XXLModel(sd1_clip.SDClipModel): def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options={}): textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_old_config_xxl.json") - t5xxl_scaled_fp8 = model_options.get("t5xxl_scaled_fp8", None) - if t5xxl_scaled_fp8 is not None: + t5xxl_quantization_metadata = model_options.get("t5xxl_quantization_metadata", None) + if t5xxl_quantization_metadata is not None: model_options = model_options.copy() - model_options["scaled_fp8"] = t5xxl_scaled_fp8 + model_options["quantization_metadata"] = t5xxl_quantization_metadata super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, zero_out_masked=attention_mask, model_options=model_options) @@ -30,12 +30,12 @@ class CosmosT5Tokenizer(sd1_clip.SD1Tokenizer): super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="t5xxl", tokenizer=T5XXLTokenizer) -def te(dtype_t5=None, t5xxl_scaled_fp8=None): +def te(dtype_t5=None, t5_quantization_metadata=None): class CosmosTEModel_(CosmosT5XXL): def __init__(self, device="cpu", dtype=None, model_options={}): - if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options: + if t5_quantization_metadata is not None: model_options = model_options.copy() - model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8 + model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata if dtype is None: dtype = dtype_t5 super().__init__(device=device, dtype=dtype, model_options=model_options) diff --git a/comfy/text_encoders/flux.py b/comfy/text_encoders/flux.py index 99f4812bb..21d93d757 100644 --- a/comfy/text_encoders/flux.py +++ b/comfy/text_encoders/flux.py @@ -63,12 +63,12 @@ class FluxClipModel(torch.nn.Module): else: return self.t5xxl.load_sd(sd) -def flux_clip(dtype_t5=None, t5xxl_scaled_fp8=None): +def flux_clip(dtype_t5=None, t5_quantization_metadata=None): class FluxClipModel_(FluxClipModel): def __init__(self, device="cpu", dtype=None, model_options={}): - if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options: + if t5_quantization_metadata is not None: model_options = model_options.copy() - model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8 + model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata super().__init__(dtype_t5=dtype_t5, device=device, dtype=dtype, model_options=model_options) return FluxClipModel_ @@ -159,15 +159,13 @@ class Flux2TEModel(sd1_clip.SD1ClipModel): out = out.reshape(out.shape[0], out.shape[1], -1) return out, pooled, extra -def flux2_te(dtype_llama=None, llama_scaled_fp8=None, llama_quantization_metadata=None, pruned=False): +def flux2_te(dtype_llama=None, llama_quantization_metadata=None, pruned=False): class Flux2TEModel_(Flux2TEModel): def __init__(self, device="cpu", dtype=None, model_options={}): - if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options: - model_options = model_options.copy() - model_options["scaled_fp8"] = llama_scaled_fp8 if dtype_llama is not None: dtype = dtype_llama if llama_quantization_metadata is not None: + model_options = model_options.copy() model_options["quantization_metadata"] = llama_quantization_metadata if pruned: model_options = model_options.copy() diff --git a/comfy/text_encoders/genmo.py b/comfy/text_encoders/genmo.py index 9dcf190a2..5daea8135 100644 --- a/comfy/text_encoders/genmo.py +++ b/comfy/text_encoders/genmo.py @@ -26,12 +26,12 @@ class MochiT5Tokenizer(sd1_clip.SD1Tokenizer): super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="t5xxl", tokenizer=T5XXLTokenizer) -def mochi_te(dtype_t5=None, t5xxl_scaled_fp8=None): +def mochi_te(dtype_t5=None, t5_quantization_metadata=None): class MochiTEModel_(MochiT5XXL): def __init__(self, device="cpu", dtype=None, model_options={}): - if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options: + if t5_quantization_metadata is not None: model_options = model_options.copy() - model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8 + model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata if dtype is None: dtype = dtype_t5 super().__init__(device=device, dtype=dtype, model_options=model_options) diff --git a/comfy/text_encoders/hidream.py b/comfy/text_encoders/hidream.py index dbcf52784..600b34480 100644 --- a/comfy/text_encoders/hidream.py +++ b/comfy/text_encoders/hidream.py @@ -142,14 +142,14 @@ class HiDreamTEModel(torch.nn.Module): return self.llama.load_sd(sd) -def hidream_clip(clip_l=True, clip_g=True, t5=True, llama=True, dtype_t5=None, dtype_llama=None, t5xxl_scaled_fp8=None, llama_scaled_fp8=None): +def hidream_clip(clip_l=True, clip_g=True, t5=True, llama=True, dtype_t5=None, dtype_llama=None, t5_quantization_metadata=None, llama_quantization_metadata=None): class HiDreamTEModel_(HiDreamTEModel): def __init__(self, device="cpu", dtype=None, model_options={}): - if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options: + if t5_quantization_metadata is not None: model_options = model_options.copy() - model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8 - if llama_scaled_fp8 is not None and "llama_scaled_fp8" not in model_options: + model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata + if llama_quantization_metadata is not None: model_options = model_options.copy() - model_options["llama_scaled_fp8"] = llama_scaled_fp8 + model_options["llama_quantization_metadata"] = llama_quantization_metadata super().__init__(clip_l=clip_l, clip_g=clip_g, t5=t5, llama=llama, dtype_t5=dtype_t5, dtype_llama=dtype_llama, device=device, dtype=dtype, model_options=model_options) return HiDreamTEModel_ diff --git a/comfy/text_encoders/hunyuan_image.py b/comfy/text_encoders/hunyuan_image.py index ff04726e1..cd198036c 100644 --- a/comfy/text_encoders/hunyuan_image.py +++ b/comfy/text_encoders/hunyuan_image.py @@ -40,10 +40,10 @@ class HunyuanImageTokenizer(QwenImageTokenizer): class Qwen25_7BVLIModel(sd1_clip.SDClipModel): def __init__(self, device="cpu", layer="hidden", layer_idx=-3, dtype=None, attention_mask=True, model_options={}): - llama_scaled_fp8 = model_options.get("qwen_scaled_fp8", None) - if llama_scaled_fp8 is not None: + llama_quantization_metadata = model_options.get("llama_quantization_metadata", None) + if llama_quantization_metadata is not None: model_options = model_options.copy() - model_options["scaled_fp8"] = llama_scaled_fp8 + model_options["quantization_metadata"] = llama_quantization_metadata super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Qwen25_7BVLI, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options) @@ -91,12 +91,12 @@ class HunyuanImageTEModel(QwenImageTEModel): else: return super().load_sd(sd) -def te(byt5=True, dtype_llama=None, llama_scaled_fp8=None): +def te(byt5=True, dtype_llama=None, llama_quantization_metadata=None): class QwenImageTEModel_(HunyuanImageTEModel): def __init__(self, device="cpu", dtype=None, model_options={}): - if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options: + if llama_quantization_metadata is not None: model_options = model_options.copy() - model_options["qwen_scaled_fp8"] = llama_scaled_fp8 + model_options["llama_quantization_metadata"] = llama_quantization_metadata if dtype_llama is not None: dtype = dtype_llama super().__init__(byt5=byt5, device=device, dtype=dtype, model_options=model_options) diff --git a/comfy/text_encoders/hunyuan_video.py b/comfy/text_encoders/hunyuan_video.py index 0110517bb..a9a6c525e 100644 --- a/comfy/text_encoders/hunyuan_video.py +++ b/comfy/text_encoders/hunyuan_video.py @@ -6,7 +6,7 @@ from transformers import LlamaTokenizerFast import torch import os import numbers - +import comfy.utils def llama_detect(state_dict, prefix=""): out = {} @@ -14,12 +14,9 @@ def llama_detect(state_dict, prefix=""): if t5_key in state_dict: out["dtype_llama"] = state_dict[t5_key].dtype - scaled_fp8_key = "{}scaled_fp8".format(prefix) - if scaled_fp8_key in state_dict: - out["llama_scaled_fp8"] = state_dict[scaled_fp8_key].dtype - - if "_quantization_metadata" in state_dict: - out["llama_quantization_metadata"] = state_dict["_quantization_metadata"] + quant = comfy.utils.detect_layer_quantization(state_dict, prefix) + if quant is not None: + out["llama_quantization_metadata"] = quant return out @@ -31,10 +28,10 @@ class LLAMA3Tokenizer(sd1_clip.SDTokenizer): class LLAMAModel(sd1_clip.SDClipModel): def __init__(self, device="cpu", layer="hidden", layer_idx=-3, dtype=None, attention_mask=True, model_options={}, special_tokens={"start": 128000, "pad": 128258}): - llama_scaled_fp8 = model_options.get("llama_scaled_fp8", None) - if llama_scaled_fp8 is not None: + llama_quantization_metadata = model_options.get("llama_quantization_metadata", None) + if llama_quantization_metadata is not None: model_options = model_options.copy() - model_options["scaled_fp8"] = llama_scaled_fp8 + model_options["quantization_metadata"] = llama_quantization_metadata textmodel_json_config = {} vocab_size = model_options.get("vocab_size", None) @@ -161,11 +158,11 @@ class HunyuanVideoClipModel(torch.nn.Module): return self.llama.load_sd(sd) -def hunyuan_video_clip(dtype_llama=None, llama_scaled_fp8=None): +def hunyuan_video_clip(dtype_llama=None, llama_quantization_metadata=None): class HunyuanVideoClipModel_(HunyuanVideoClipModel): def __init__(self, device="cpu", dtype=None, model_options={}): - if llama_scaled_fp8 is not None and "llama_scaled_fp8" not in model_options: + if llama_quantization_metadata is not None: model_options = model_options.copy() - model_options["llama_scaled_fp8"] = llama_scaled_fp8 + model_options["llama_quantization_metadata"] = llama_quantization_metadata super().__init__(dtype_llama=dtype_llama, device=device, dtype=dtype, model_options=model_options) return HunyuanVideoClipModel_ diff --git a/comfy/text_encoders/lumina2.py b/comfy/text_encoders/lumina2.py index fd986e2c1..7a6cfdab2 100644 --- a/comfy/text_encoders/lumina2.py +++ b/comfy/text_encoders/lumina2.py @@ -40,7 +40,7 @@ class LuminaModel(sd1_clip.SD1ClipModel): super().__init__(device=device, dtype=dtype, name=name, clip_model=clip_model, model_options=model_options) -def te(dtype_llama=None, llama_scaled_fp8=None, model_type="gemma2_2b"): +def te(dtype_llama=None, llama_quantization_metadata=None, model_type="gemma2_2b"): if model_type == "gemma2_2b": model = Gemma2_2BModel elif model_type == "gemma3_4b": @@ -48,9 +48,9 @@ def te(dtype_llama=None, llama_scaled_fp8=None, model_type="gemma2_2b"): class LuminaTEModel_(LuminaModel): def __init__(self, device="cpu", dtype=None, model_options={}): - if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options: + if llama_quantization_metadata is not None: model_options = model_options.copy() - model_options["scaled_fp8"] = llama_scaled_fp8 + model_options["quantization_metadata"] = llama_quantization_metadata if dtype_llama is not None: dtype = dtype_llama super().__init__(device=device, dtype=dtype, name=model_type, model_options=model_options, clip_model=model) diff --git a/comfy/text_encoders/omnigen2.py b/comfy/text_encoders/omnigen2.py index 1a01b2dd4..50aa4121f 100644 --- a/comfy/text_encoders/omnigen2.py +++ b/comfy/text_encoders/omnigen2.py @@ -32,12 +32,12 @@ class Omnigen2Model(sd1_clip.SD1ClipModel): super().__init__(device=device, dtype=dtype, name="qwen25_3b", clip_model=Qwen25_3BModel, model_options=model_options) -def te(dtype_llama=None, llama_scaled_fp8=None): +def te(dtype_llama=None, llama_quantization_metadata=None): class Omnigen2TEModel_(Omnigen2Model): def __init__(self, device="cpu", dtype=None, model_options={}): - if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options: + if llama_quantization_metadata is not None: model_options = model_options.copy() - model_options["scaled_fp8"] = llama_scaled_fp8 + model_options["quantization_metadata"] = llama_quantization_metadata if dtype_llama is not None: dtype = dtype_llama super().__init__(device=device, dtype=dtype, model_options=model_options) diff --git a/comfy/text_encoders/ovis.py b/comfy/text_encoders/ovis.py index 81c9bd51c..5754424d2 100644 --- a/comfy/text_encoders/ovis.py +++ b/comfy/text_encoders/ovis.py @@ -55,12 +55,9 @@ class OvisTEModel(sd1_clip.SD1ClipModel): return out, pooled, {} -def te(dtype_llama=None, llama_scaled_fp8=None, llama_quantization_metadata=None): +def te(dtype_llama=None, llama_quantization_metadata=None): class OvisTEModel_(OvisTEModel): def __init__(self, device="cpu", dtype=None, model_options={}): - if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options: - model_options = model_options.copy() - model_options["scaled_fp8"] = llama_scaled_fp8 if dtype_llama is not None: dtype = dtype_llama if llama_quantization_metadata is not None: diff --git a/comfy/text_encoders/pixart_t5.py b/comfy/text_encoders/pixart_t5.py index 5f383de07..e5e5f18be 100644 --- a/comfy/text_encoders/pixart_t5.py +++ b/comfy/text_encoders/pixart_t5.py @@ -30,12 +30,12 @@ class PixArtTokenizer(sd1_clip.SD1Tokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}): super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="t5xxl", tokenizer=T5XXLTokenizer) -def pixart_te(dtype_t5=None, t5xxl_scaled_fp8=None): +def pixart_te(dtype_t5=None, t5_quantization_metadata=None): class PixArtTEModel_(PixArtT5XXL): def __init__(self, device="cpu", dtype=None, model_options={}): - if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options: + if t5_quantization_metadata is not None: model_options = model_options.copy() - model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8 + model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata if dtype is None: dtype = dtype_t5 super().__init__(device=device, dtype=dtype, model_options=model_options) diff --git a/comfy/text_encoders/qwen_image.py b/comfy/text_encoders/qwen_image.py index c0d32a6ef..5c14dec23 100644 --- a/comfy/text_encoders/qwen_image.py +++ b/comfy/text_encoders/qwen_image.py @@ -85,12 +85,12 @@ class QwenImageTEModel(sd1_clip.SD1ClipModel): return out, pooled, extra -def te(dtype_llama=None, llama_scaled_fp8=None): +def te(dtype_llama=None, llama_quantization_metadata=None): class QwenImageTEModel_(QwenImageTEModel): def __init__(self, device="cpu", dtype=None, model_options={}): - if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options: + if llama_quantization_metadata is not None: model_options = model_options.copy() - model_options["scaled_fp8"] = llama_scaled_fp8 + model_options["quantization_metadata"] = llama_quantization_metadata if dtype_llama is not None: dtype = dtype_llama super().__init__(device=device, dtype=dtype, model_options=model_options) diff --git a/comfy/text_encoders/sd3_clip.py b/comfy/text_encoders/sd3_clip.py index ff5d412db..8b153c72b 100644 --- a/comfy/text_encoders/sd3_clip.py +++ b/comfy/text_encoders/sd3_clip.py @@ -6,14 +6,15 @@ import torch import os import comfy.model_management import logging +import comfy.utils class T5XXLModel(sd1_clip.SDClipModel): def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=False, model_options={}): textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_config_xxl.json") - t5xxl_scaled_fp8 = model_options.get("t5xxl_scaled_fp8", None) - if t5xxl_scaled_fp8 is not None: + t5xxl_quantization_metadata = model_options.get("t5xxl_quantization_metadata", None) + if t5xxl_quantization_metadata is not None: model_options = model_options.copy() - model_options["scaled_fp8"] = t5xxl_scaled_fp8 + model_options["quantization_metadata"] = t5xxl_quantization_metadata model_options = {**model_options, "model_name": "t5xxl"} super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options) @@ -25,9 +26,9 @@ def t5_xxl_detect(state_dict, prefix=""): if t5_key in state_dict: out["dtype_t5"] = state_dict[t5_key].dtype - scaled_fp8_key = "{}scaled_fp8".format(prefix) - if scaled_fp8_key in state_dict: - out["t5xxl_scaled_fp8"] = state_dict[scaled_fp8_key].dtype + quant = comfy.utils.detect_layer_quantization(state_dict, prefix) + if quant is not None: + out["t5_quantization_metadata"] = quant return out @@ -156,11 +157,11 @@ class SD3ClipModel(torch.nn.Module): else: return self.t5xxl.load_sd(sd) -def sd3_clip(clip_l=True, clip_g=True, t5=True, dtype_t5=None, t5xxl_scaled_fp8=None, t5_attention_mask=False): +def sd3_clip(clip_l=True, clip_g=True, t5=True, dtype_t5=None, t5_quantization_metadata=None, t5_attention_mask=False): class SD3ClipModel_(SD3ClipModel): def __init__(self, device="cpu", dtype=None, model_options={}): - if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options: + if t5_quantization_metadata is not None: model_options = model_options.copy() - model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8 + model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata super().__init__(clip_l=clip_l, clip_g=clip_g, t5=t5, dtype_t5=dtype_t5, t5_attention_mask=t5_attention_mask, device=device, dtype=dtype, model_options=model_options) return SD3ClipModel_ diff --git a/comfy/text_encoders/wan.py b/comfy/text_encoders/wan.py index d50fa4b28..164a57edd 100644 --- a/comfy/text_encoders/wan.py +++ b/comfy/text_encoders/wan.py @@ -25,12 +25,12 @@ class WanT5Model(sd1_clip.SD1ClipModel): def __init__(self, device="cpu", dtype=None, model_options={}, **kwargs): super().__init__(device=device, dtype=dtype, model_options=model_options, name="umt5xxl", clip_model=UMT5XXlModel, **kwargs) -def te(dtype_t5=None, t5xxl_scaled_fp8=None): +def te(dtype_t5=None, t5_quantization_metadata=None): class WanTEModel(WanT5Model): def __init__(self, device="cpu", dtype=None, model_options={}): - if t5xxl_scaled_fp8 is not None and "scaled_fp8" not in model_options: + if t5_quantization_metadata is not None: model_options = model_options.copy() - model_options["scaled_fp8"] = t5xxl_scaled_fp8 + model_options["quantization_metadata"] = t5_quantization_metadata if dtype_t5 is not None: dtype = dtype_t5 super().__init__(device=device, dtype=dtype, model_options=model_options) diff --git a/comfy/text_encoders/z_image.py b/comfy/text_encoders/z_image.py index bb9273b20..19adde0b7 100644 --- a/comfy/text_encoders/z_image.py +++ b/comfy/text_encoders/z_image.py @@ -34,12 +34,9 @@ class ZImageTEModel(sd1_clip.SD1ClipModel): super().__init__(device=device, dtype=dtype, name="qwen3_4b", clip_model=Qwen3_4BModel, model_options=model_options) -def te(dtype_llama=None, llama_scaled_fp8=None, llama_quantization_metadata=None): +def te(dtype_llama=None, llama_quantization_metadata=None): class ZImageTEModel_(ZImageTEModel): def __init__(self, device="cpu", dtype=None, model_options={}): - if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options: - model_options = model_options.copy() - model_options["scaled_fp8"] = llama_scaled_fp8 if dtype_llama is not None: dtype = dtype_llama if llama_quantization_metadata is not None: diff --git a/comfy/utils.py b/comfy/utils.py index 37485e497..89846bc95 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -29,6 +29,7 @@ import itertools from torch.nn.functional import interpolate from einops import rearrange from comfy.cli_args import args +import json MMAP_TORCH_FILES = args.mmap_torch_files DISABLE_MMAP = args.disable_mmap @@ -1194,3 +1195,68 @@ def unpack_latents(combined_latent, latent_shapes): else: output_tensors = combined_latent return output_tensors + +def detect_layer_quantization(state_dict, prefix): + for k in state_dict: + if k.startswith(prefix) and k.endswith(".comfy_quant"): + logging.info("Found quantization metadata version 1") + return {"mixed_ops": True} + return None + +def convert_old_quants(state_dict, model_prefix="", metadata={}): + if metadata is None: + metadata = {} + + quant_metadata = None + if "_quantization_metadata" not in metadata: + scaled_fp8_key = "{}scaled_fp8".format(model_prefix) + + if scaled_fp8_key in state_dict: + scaled_fp8_weight = state_dict[scaled_fp8_key] + scaled_fp8_dtype = scaled_fp8_weight.dtype + if scaled_fp8_dtype == torch.float32: + scaled_fp8_dtype = torch.float8_e4m3fn + + if scaled_fp8_weight.nelement() == 2: + full_precision_matrix_mult = True + else: + full_precision_matrix_mult = False + + out_sd = {} + layers = {} + for k in list(state_dict.keys()): + if not k.startswith(model_prefix): + out_sd[k] = state_dict[k] + continue + k_out = k + w = state_dict.pop(k) + layer = None + if k_out.endswith(".scale_weight"): + layer = k_out[:-len(".scale_weight")] + k_out = "{}.weight_scale".format(layer) + + if layer is not None: + layer_conf = {"format": "float8_e4m3fn"} # TODO: check if anyone did some non e4m3fn scaled checkpoints + if full_precision_matrix_mult: + layer_conf["full_precision_matrix_mult"] = full_precision_matrix_mult + layers[layer] = layer_conf + + if k_out.endswith(".scale_input"): + layer = k_out[:-len(".scale_input")] + k_out = "{}.input_scale".format(layer) + if w.item() == 1.0: + continue + + out_sd[k_out] = w + + state_dict = out_sd + quant_metadata = {"layers": layers} + else: + quant_metadata = json.loads(metadata["_quantization_metadata"]) + + if quant_metadata is not None: + layers = quant_metadata["layers"] + for k, v in layers.items(): + state_dict["{}.comfy_quant".format(k)] = torch.frombuffer(json.dumps(v).encode('utf-8'), dtype=torch.uint8) + + return state_dict, metadata diff --git a/tests-unit/comfy_quant/test_mixed_precision.py b/tests-unit/comfy_quant/test_mixed_precision.py index 63361309f..3a54941e6 100644 --- a/tests-unit/comfy_quant/test_mixed_precision.py +++ b/tests-unit/comfy_quant/test_mixed_precision.py @@ -2,6 +2,7 @@ import unittest import torch import sys import os +import json # Add comfy to path sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..")) @@ -15,6 +16,7 @@ if not has_gpu(): from comfy import ops from comfy.quant_ops import QuantizedTensor +import comfy.utils class SimpleModel(torch.nn.Module): @@ -94,8 +96,9 @@ class TestMixedPrecisionOps(unittest.TestCase): "layer3.weight_scale": torch.tensor(1.5, dtype=torch.float32), } + state_dict, _ = comfy.utils.convert_old_quants(state_dict, metadata={"_quantization_metadata": json.dumps({"layers": layer_quant_config})}) # Create model and load state dict (strict=False because custom loading pops keys) - model = SimpleModel(operations=ops.mixed_precision_ops(layer_quant_config)) + model = SimpleModel(operations=ops.mixed_precision_ops({})) model.load_state_dict(state_dict, strict=False) # Verify weights are wrapped in QuantizedTensor @@ -115,7 +118,8 @@ class TestMixedPrecisionOps(unittest.TestCase): # Forward pass input_tensor = torch.randn(5, 10, dtype=torch.bfloat16) - output = model(input_tensor) + with torch.inference_mode(): + output = model(input_tensor) self.assertEqual(output.shape, (5, 40)) @@ -141,7 +145,8 @@ class TestMixedPrecisionOps(unittest.TestCase): "layer3.bias": torch.randn(40, dtype=torch.bfloat16), } - model = SimpleModel(operations=ops.mixed_precision_ops(layer_quant_config)) + state_dict1, _ = comfy.utils.convert_old_quants(state_dict1, metadata={"_quantization_metadata": json.dumps({"layers": layer_quant_config})}) + model = SimpleModel(operations=ops.mixed_precision_ops({})) model.load_state_dict(state_dict1, strict=False) # Save state dict @@ -178,7 +183,8 @@ class TestMixedPrecisionOps(unittest.TestCase): "layer3.bias": torch.randn(40, dtype=torch.bfloat16), } - model = SimpleModel(operations=ops.mixed_precision_ops(layer_quant_config)) + state_dict, _ = comfy.utils.convert_old_quants(state_dict, metadata={"_quantization_metadata": json.dumps({"layers": layer_quant_config})}) + model = SimpleModel(operations=ops.mixed_precision_ops({})) model.load_state_dict(state_dict, strict=False) # Add a weight function (simulating LoRA) @@ -215,8 +221,10 @@ class TestMixedPrecisionOps(unittest.TestCase): "layer3.bias": torch.randn(40, dtype=torch.bfloat16), } + state_dict, _ = comfy.utils.convert_old_quants(state_dict, metadata={"_quantization_metadata": json.dumps({"layers": layer_quant_config})}) + # Load should raise KeyError for unknown format in QUANT_FORMAT_MIXINS - model = SimpleModel(operations=ops.mixed_precision_ops(layer_quant_config)) + model = SimpleModel(operations=ops.mixed_precision_ops({})) with self.assertRaises(KeyError): model.load_state_dict(state_dict, strict=False) From 6fd463aec958f02be79a264eafd6c8fe7e52762a Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Fri, 5 Dec 2025 12:33:16 -0800 Subject: [PATCH 212/231] Fix regression when text encoder loaded directly on GPU. (#11129) --- comfy/ops.py | 2 ++ comfy/sd.py | 44 ++++++++++++++++++++++++-------------------- 2 files changed, 26 insertions(+), 20 deletions(-) diff --git a/comfy/ops.py b/comfy/ops.py index dc06709a1..35237c9f7 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -552,6 +552,8 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec weight_scale_key = f"{prefix}weight_scale" scale = state_dict.pop(weight_scale_key, None) + if scale is not None: + scale = scale.to(device) layout_params = { 'scale': scale, 'orig_dtype': MixedPrecisionOps._compute_dtype, diff --git a/comfy/sd.py b/comfy/sd.py index 092715d79..c350322f8 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -98,7 +98,7 @@ def load_lora_for_models(model, clip, lora, strength_model, strength_clip): class CLIP: - def __init__(self, target=None, embedding_directory=None, no_init=False, tokenizer_data={}, parameters=0, model_options={}): + def __init__(self, target=None, embedding_directory=None, no_init=False, tokenizer_data={}, parameters=0, state_dict=[], model_options={}): if no_init: return params = target.params.copy() @@ -129,6 +129,27 @@ class CLIP: self.patcher.hook_mode = comfy.hooks.EnumHookMode.MinVram self.patcher.is_clip = True self.apply_hooks_to_conds = None + if len(state_dict) > 0: + if isinstance(state_dict, list): + for c in state_dict: + m, u = self.load_sd(c) + if len(m) > 0: + logging.warning("clip missing: {}".format(m)) + + if len(u) > 0: + logging.debug("clip unexpected: {}".format(u)) + else: + m, u = self.load_sd(state_dict, full_model=True) + if len(m) > 0: + m_filter = list(filter(lambda a: ".logit_scale" not in a and ".transformer.text_projection.weight" not in a, m)) + if len(m_filter) > 0: + logging.warning("clip missing: {}".format(m)) + else: + logging.debug("clip missing: {}".format(m)) + + if len(u) > 0: + logging.debug("clip unexpected {}:".format(u)) + if params['device'] == load_device: model_management.load_models_gpu([self.patcher], force_full_load=True) self.layer_idx = None @@ -1225,14 +1246,7 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip parameters += comfy.utils.calculate_parameters(c) tokenizer_data, model_options = comfy.text_encoders.long_clipl.model_options_long_clip(c, tokenizer_data, model_options) - clip = CLIP(clip_target, embedding_directory=embedding_directory, parameters=parameters, tokenizer_data=tokenizer_data, model_options=model_options) - for c in clip_data: - m, u = clip.load_sd(c) - if len(m) > 0: - logging.warning("clip missing: {}".format(m)) - - if len(u) > 0: - logging.debug("clip unexpected: {}".format(u)) + clip = CLIP(clip_target, embedding_directory=embedding_directory, parameters=parameters, tokenizer_data=tokenizer_data, state_dict=clip_data, model_options=model_options) return clip def load_gligen(ckpt_path): @@ -1362,17 +1376,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c clip_sd = model_config.process_clip_state_dict(sd) if len(clip_sd) > 0: parameters = comfy.utils.calculate_parameters(clip_sd) - clip = CLIP(clip_target, embedding_directory=embedding_directory, tokenizer_data=clip_sd, parameters=parameters, model_options=te_model_options) - m, u = clip.load_sd(clip_sd, full_model=True) - if len(m) > 0: - m_filter = list(filter(lambda a: ".logit_scale" not in a and ".transformer.text_projection.weight" not in a, m)) - if len(m_filter) > 0: - logging.warning("clip missing: {}".format(m)) - else: - logging.debug("clip missing: {}".format(m)) - - if len(u) > 0: - logging.debug("clip unexpected {}:".format(u)) + clip = CLIP(clip_target, embedding_directory=embedding_directory, tokenizer_data=clip_sd, parameters=parameters, state_dict=clip_sd, model_options=te_model_options) else: logging.warning("no CLIP/text encoder weights in checkpoint, the text encoder model will not be loaded.") From 79d17ba2339aaf4f3422673b3dad24ba4dbd7552 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jukka=20Sepp=C3=A4nen?= <40791699+kijai@users.noreply.github.com> Date: Fri, 5 Dec 2025 22:42:46 +0200 Subject: [PATCH 213/231] Context windows fixes and features (#10975) * Apply cond slice fix * Add FreeNoise * Update context_windows.py * Add option to retain condition by indexes for each window This allows for example Wan/HunyuanVideo image to video to "work" by using the initial start frame for each window, otherwise windows beyond first will be pure T2V generations. * Update context_windows.py * Allow splitting multiple conds into different windows * Add handling for audio_embed * whitespace * Allow freenoise to work on other dims, handle 4D batch timestep Refactor Freenoise function. And fix batch handling as timesteps seem to be expanded to batch size now. * Disable experimental options for now So that the Freenoise and bugfixes can be merged first --------- Co-authored-by: Jedrzej Kosinski Co-authored-by: ozbayb <17261091+ozbayb@users.noreply.github.com> --- comfy/context_windows.py | 104 ++++++++++++++++++++++---- comfy_extras/nodes_context_windows.py | 22 +++++- 2 files changed, 108 insertions(+), 18 deletions(-) diff --git a/comfy/context_windows.py b/comfy/context_windows.py index 041f380f9..5c412d1c2 100644 --- a/comfy/context_windows.py +++ b/comfy/context_windows.py @@ -51,26 +51,36 @@ class ContextHandlerABC(ABC): class IndexListContextWindow(ContextWindowABC): - def __init__(self, index_list: list[int], dim: int=0): + def __init__(self, index_list: list[int], dim: int=0, total_frames: int=0): self.index_list = index_list self.context_length = len(index_list) self.dim = dim + self.total_frames = total_frames + self.center_ratio = (min(index_list) + max(index_list)) / (2 * total_frames) - def get_tensor(self, full: torch.Tensor, device=None, dim=None) -> torch.Tensor: + def get_tensor(self, full: torch.Tensor, device=None, dim=None, retain_index_list=[]) -> torch.Tensor: if dim is None: dim = self.dim if dim == 0 and full.shape[dim] == 1: return full - idx = [slice(None)] * dim + [self.index_list] - return full[idx].to(device) + idx = tuple([slice(None)] * dim + [self.index_list]) + window = full[idx] + if retain_index_list: + idx = tuple([slice(None)] * dim + [retain_index_list]) + window[idx] = full[idx] + return window.to(device) def add_window(self, full: torch.Tensor, to_add: torch.Tensor, dim=None) -> torch.Tensor: if dim is None: dim = self.dim - idx = [slice(None)] * dim + [self.index_list] + idx = tuple([slice(None)] * dim + [self.index_list]) full[idx] += to_add return full + def get_region_index(self, num_regions: int) -> int: + region_idx = int(self.center_ratio * num_regions) + return min(max(region_idx, 0), num_regions - 1) + class IndexListCallbacks: EVALUATE_CONTEXT_WINDOWS = "evaluate_context_windows" @@ -94,7 +104,8 @@ class ContextFuseMethod: ContextResults = collections.namedtuple("ContextResults", ['window_idx', 'sub_conds_out', 'sub_conds', 'window']) class IndexListContextHandler(ContextHandlerABC): - def __init__(self, context_schedule: ContextSchedule, fuse_method: ContextFuseMethod, context_length: int=1, context_overlap: int=0, context_stride: int=1, closed_loop=False, dim=0): + def __init__(self, context_schedule: ContextSchedule, fuse_method: ContextFuseMethod, context_length: int=1, context_overlap: int=0, context_stride: int=1, + closed_loop: bool=False, dim:int=0, freenoise: bool=False, cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False): self.context_schedule = context_schedule self.fuse_method = fuse_method self.context_length = context_length @@ -103,13 +114,18 @@ class IndexListContextHandler(ContextHandlerABC): self.closed_loop = closed_loop self.dim = dim self._step = 0 + self.freenoise = freenoise + self.cond_retain_index_list = [int(x.strip()) for x in cond_retain_index_list.split(",")] if cond_retain_index_list else [] + self.split_conds_to_windows = split_conds_to_windows self.callbacks = {} def should_use_context(self, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]) -> bool: # for now, assume first dim is batch - should have stored on BaseModel in actual implementation if x_in.size(self.dim) > self.context_length: - logging.info(f"Using context windows {self.context_length} for {x_in.size(self.dim)} frames.") + logging.info(f"Using context windows {self.context_length} with overlap {self.context_overlap} for {x_in.size(self.dim)} frames.") + if self.cond_retain_index_list: + logging.info(f"Retaining original cond for indexes: {self.cond_retain_index_list}") return True return False @@ -123,6 +139,11 @@ class IndexListContextHandler(ContextHandlerABC): return None # reuse or resize cond items to match context requirements resized_cond = [] + # if multiple conds, split based on primary region + if self.split_conds_to_windows and len(cond_in) > 1: + region = window.get_region_index(len(cond_in)) + logging.info(f"Splitting conds to windows; using region {region} for window {window[0]}-{window[-1]} with center ratio {window.center_ratio:.3f}") + cond_in = [cond_in[region]] # cond object is a list containing a dict - outer list is irrelevant, so just loop through it for actual_cond in cond_in: resized_actual_cond = actual_cond.copy() @@ -146,12 +167,19 @@ class IndexListContextHandler(ContextHandlerABC): # when in dictionary, look for tensors and CONDCrossAttn [comfy/conds.py] (has cond attr that is a tensor) for cond_key, cond_value in new_cond_item.items(): if isinstance(cond_value, torch.Tensor): - if cond_value.ndim < self.dim and cond_value.size(0) == x_in.size(self.dim): + if (self.dim < cond_value.ndim and cond_value(self.dim) == x_in.size(self.dim)) or \ + (cond_value.ndim < self.dim and cond_value.size(0) == x_in.size(self.dim)): new_cond_item[cond_key] = window.get_tensor(cond_value, device) + # Handle audio_embed (temporal dim is 1) + elif cond_key == "audio_embed" and hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor): + audio_cond = cond_value.cond + if audio_cond.ndim > 1 and audio_cond.size(1) == x_in.size(self.dim): + new_cond_item[cond_key] = cond_value._copy_with(window.get_tensor(audio_cond, device, dim=1)) # if has cond that is a Tensor, check if needs to be subset elif hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor): - if cond_value.cond.ndim < self.dim and cond_value.cond.size(0) == x_in.size(self.dim): - new_cond_item[cond_key] = cond_value._copy_with(window.get_tensor(cond_value.cond, device)) + if (self.dim < cond_value.cond.ndim and cond_value.cond.size(self.dim) == x_in.size(self.dim)) or \ + (cond_value.cond.ndim < self.dim and cond_value.cond.size(0) == x_in.size(self.dim)): + new_cond_item[cond_key] = cond_value._copy_with(window.get_tensor(cond_value.cond, device, retain_index_list=self.cond_retain_index_list)) elif cond_key == "num_video_frames": # for SVD new_cond_item[cond_key] = cond_value._copy_with(cond_value.cond) new_cond_item[cond_key].cond = window.context_length @@ -164,7 +192,7 @@ class IndexListContextHandler(ContextHandlerABC): return resized_cond def set_step(self, timestep: torch.Tensor, model_options: dict[str]): - mask = torch.isclose(model_options["transformer_options"]["sample_sigmas"], timestep, rtol=0.0001) + mask = torch.isclose(model_options["transformer_options"]["sample_sigmas"], timestep[0], rtol=0.0001) matches = torch.nonzero(mask) if torch.numel(matches) == 0: raise Exception("No sample_sigmas matched current timestep; something went wrong.") @@ -173,7 +201,7 @@ class IndexListContextHandler(ContextHandlerABC): def get_context_windows(self, model: BaseModel, x_in: torch.Tensor, model_options: dict[str]) -> list[IndexListContextWindow]: full_length = x_in.size(self.dim) # TODO: choose dim based on model context_windows = self.context_schedule.func(full_length, self, model_options) - context_windows = [IndexListContextWindow(window, dim=self.dim) for window in context_windows] + context_windows = [IndexListContextWindow(window, dim=self.dim, total_frames=full_length) for window in context_windows] return context_windows def execute(self, calc_cond_batch: Callable, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]): @@ -250,8 +278,8 @@ class IndexListContextHandler(ContextHandlerABC): prev_weight = (bias_total / (bias_total + bias)) new_weight = (bias / (bias_total + bias)) # account for dims of tensors - idx_window = [slice(None)] * self.dim + [idx] - pos_window = [slice(None)] * self.dim + [pos] + idx_window = tuple([slice(None)] * self.dim + [idx]) + pos_window = tuple([slice(None)] * self.dim + [pos]) # apply new values conds_final[i][idx_window] = conds_final[i][idx_window] * prev_weight + sub_conds_out[i][pos_window] * new_weight biases_final[i][idx] = bias_total + bias @@ -287,6 +315,28 @@ def create_prepare_sampling_wrapper(model: ModelPatcher): ) +def _sampler_sample_wrapper(executor, guider, sigmas, extra_args, callback, noise, *args, **kwargs): + model_options = extra_args.get("model_options", None) + if model_options is None: + raise Exception("model_options not found in sampler_sample_wrapper; this should never happen, something went wrong.") + handler: IndexListContextHandler = model_options.get("context_handler", None) + if handler is None: + raise Exception("context_handler not found in sampler_sample_wrapper; this should never happen, something went wrong.") + if not handler.freenoise: + return executor(guider, sigmas, extra_args, callback, noise, *args, **kwargs) + noise = apply_freenoise(noise, handler.dim, handler.context_length, handler.context_overlap, extra_args["seed"]) + + return executor(guider, sigmas, extra_args, callback, noise, *args, **kwargs) + + +def create_sampler_sample_wrapper(model: ModelPatcher): + model.add_wrapper_with_key( + comfy.patcher_extension.WrappersMP.SAMPLER_SAMPLE, + "ContextWindows_sampler_sample", + _sampler_sample_wrapper + ) + + def match_weights_to_dim(weights: list[float], x_in: torch.Tensor, dim: int, device=None) -> torch.Tensor: total_dims = len(x_in.shape) weights_tensor = torch.Tensor(weights).to(device=device) @@ -538,3 +588,29 @@ def shift_window_to_end(window: list[int], num_frames: int): for i in range(len(window)): # 2) add end_delta to each val to slide windows to end window[i] = window[i] + end_delta + + +# https://github.com/Kosinkadink/ComfyUI-AnimateDiff-Evolved/blob/90fb1331201a4b29488089e4fbffc0d82cc6d0a9/animatediff/sample_settings.py#L465 +def apply_freenoise(noise: torch.Tensor, dim: int, context_length: int, context_overlap: int, seed: int): + logging.info("Context windows: Applying FreeNoise") + generator = torch.Generator(device='cpu').manual_seed(seed) + latent_video_length = noise.shape[dim] + delta = context_length - context_overlap + + for start_idx in range(0, latent_video_length - context_length, delta): + place_idx = start_idx + context_length + + actual_delta = min(delta, latent_video_length - place_idx) + if actual_delta <= 0: + break + + list_idx = torch.randperm(actual_delta, generator=generator, device='cpu') + start_idx + + source_slice = [slice(None)] * noise.ndim + source_slice[dim] = list_idx + target_slice = [slice(None)] * noise.ndim + target_slice[dim] = slice(place_idx, place_idx + actual_delta) + + noise[tuple(target_slice)] = noise[tuple(source_slice)] + + return noise diff --git a/comfy_extras/nodes_context_windows.py b/comfy_extras/nodes_context_windows.py index 1c3d9e697..3799a9004 100644 --- a/comfy_extras/nodes_context_windows.py +++ b/comfy_extras/nodes_context_windows.py @@ -26,6 +26,9 @@ class ContextWindowsManualNode(io.ComfyNode): io.Boolean.Input("closed_loop", default=False, tooltip="Whether to close the context window loop; only applicable to looped schedules."), io.Combo.Input("fuse_method", options=comfy.context_windows.ContextFuseMethods.LIST_STATIC, default=comfy.context_windows.ContextFuseMethods.PYRAMID, tooltip="The method to use to fuse the context windows."), io.Int.Input("dim", min=0, max=5, default=0, tooltip="The dimension to apply the context windows to."), + io.Boolean.Input("freenoise", default=False, tooltip="Whether to apply FreeNoise noise shuffling, improves window blending."), + #io.String.Input("cond_retain_index_list", default="", tooltip="List of latent indices to retain in the conditioning tensors for each window, for example setting this to '0' will use the initial start image for each window."), + #io.Boolean.Input("split_conds_to_windows", default=False, tooltip="Whether to split multiple conditionings (created by ConditionCombine) to each window based on region index."), ], outputs=[ io.Model.Output(tooltip="The model with context windows applied during sampling."), @@ -34,7 +37,8 @@ class ContextWindowsManualNode(io.ComfyNode): ) @classmethod - def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, context_stride: int, closed_loop: bool, fuse_method: str, dim: int) -> io.Model: + def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, context_stride: int, closed_loop: bool, fuse_method: str, dim: int, freenoise: bool, + cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False) -> io.Model: model = model.clone() model.model_options["context_handler"] = comfy.context_windows.IndexListContextHandler( context_schedule=comfy.context_windows.get_matching_context_schedule(context_schedule), @@ -43,9 +47,15 @@ class ContextWindowsManualNode(io.ComfyNode): context_overlap=context_overlap, context_stride=context_stride, closed_loop=closed_loop, - dim=dim) + dim=dim, + freenoise=freenoise, + cond_retain_index_list=cond_retain_index_list, + split_conds_to_windows=split_conds_to_windows + ) # make memory usage calculation only take into account the context window latents comfy.context_windows.create_prepare_sampling_wrapper(model) + if freenoise: # no other use for this wrapper at this time + comfy.context_windows.create_sampler_sample_wrapper(model) return io.NodeOutput(model) class WanContextWindowsManualNode(ContextWindowsManualNode): @@ -68,14 +78,18 @@ class WanContextWindowsManualNode(ContextWindowsManualNode): io.Int.Input("context_stride", min=1, default=1, tooltip="The stride of the context window; only applicable to uniform schedules."), io.Boolean.Input("closed_loop", default=False, tooltip="Whether to close the context window loop; only applicable to looped schedules."), io.Combo.Input("fuse_method", options=comfy.context_windows.ContextFuseMethods.LIST_STATIC, default=comfy.context_windows.ContextFuseMethods.PYRAMID, tooltip="The method to use to fuse the context windows."), + io.Boolean.Input("freenoise", default=False, tooltip="Whether to apply FreeNoise noise shuffling, improves window blending."), + #io.String.Input("cond_retain_index_list", default="", tooltip="List of latent indices to retain in the conditioning tensors for each window, for example setting this to '0' will use the initial start image for each window."), + #io.Boolean.Input("split_conds_to_windows", default=False, tooltip="Whether to split multiple conditionings (created by ConditionCombine) to each window based on region index."), ] return schema @classmethod - def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, context_stride: int, closed_loop: bool, fuse_method: str) -> io.Model: + def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, context_stride: int, closed_loop: bool, fuse_method: str, freenoise: bool, + cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False) -> io.Model: context_length = max(((context_length - 1) // 4) + 1, 1) # at least length 1 context_overlap = max(((context_overlap - 1) // 4) + 1, 0) # at least overlap 0 - return super().execute(model, context_length, context_overlap, context_schedule, context_stride, closed_loop, fuse_method, dim=2) + return super().execute(model, context_length, context_overlap, context_schedule, context_stride, closed_loop, fuse_method, dim=2, freenoise=freenoise, cond_retain_index_list=cond_retain_index_list, split_conds_to_windows=split_conds_to_windows) class ContextWindowsExtension(ComfyExtension): From 092ee8a5008c8d558b0a72cc7961a31d9cc5400b Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Fri, 5 Dec 2025 15:25:31 -0800 Subject: [PATCH 214/231] Fix some custom nodes. (#11134) --- comfy/supported_models_base.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/comfy/supported_models_base.py b/comfy/supported_models_base.py index 9fd84d329..0e7a829ba 100644 --- a/comfy/supported_models_base.py +++ b/comfy/supported_models_base.py @@ -17,6 +17,7 @@ """ import torch +import logging from . import model_base from . import utils from . import latent_formats @@ -117,3 +118,7 @@ class BASE: def set_inference_dtype(self, dtype, manual_cast_dtype): self.unet_config['dtype'] = dtype self.manual_cast_dtype = manual_cast_dtype + + def __getattr__(self, name): + logging.warning("\nWARNING, you accessed {} from the model config object which doesn't exist. Please fix your code.\n".format(name)) + return None From bed12674a1d2c4bfdfbdd098686390f807996c90 Mon Sep 17 00:00:00 2001 From: "Dr.Lt.Data" <128333288+ltdrdata@users.noreply.github.com> Date: Sat, 6 Dec 2025 08:45:38 +0900 Subject: [PATCH 215/231] docs: add ComfyUI-Manager documentation and update to v4.0.3b4 (#11133) - Add manager setup instructions and command line options to README - Document --enable-manager, --enable-manager-legacy-ui, and --disable-manager-ui flags - Bump comfyui_manager version from 4.0.3b3 to 4.0.3b4 --- README.md | 26 ++++++++++++++++++++++++++ manager_requirements.txt | 2 +- 2 files changed, 27 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index ed857df9f..bae955b1b 100644 --- a/README.md +++ b/README.md @@ -320,6 +320,32 @@ For models compatible with Iluvatar Extension for PyTorch. Here's a step-by-step 1. Install the Iluvatar Corex Toolkit by adhering to the platform-specific instructions on the [Installation](https://support.iluvatar.com/#/DocumentCentre?id=1&nameCenter=2&productId=520117912052801536) 2. Launch ComfyUI by running `python main.py` + +## [ComfyUI-Manager](https://github.com/Comfy-Org/ComfyUI-Manager/tree/manager-v4) + +**ComfyUI-Manager** is an extension that allows you to easily install, update, and manage custom nodes for ComfyUI. + +### Setup + +1. Install the manager dependencies: + ```bash + pip install -r manager_requirements.txt + ``` + +2. Enable the manager with the `--enable-manager` flag when running ComfyUI: + ```bash + python main.py --enable-manager + ``` + +### Command Line Options + +| Flag | Description | +|------|-------------| +| `--enable-manager` | Enable ComfyUI-Manager | +| `--enable-manager-legacy-ui` | Use the legacy manager UI instead of the new UI (requires `--enable-manager`) | +| `--disable-manager-ui` | Disable the manager UI and endpoints while keeping background features like security checks and scheduled installation completion (requires `--enable-manager`) | + + # Running ```python main.py``` diff --git a/manager_requirements.txt b/manager_requirements.txt index 52cc5389c..b95cefb74 100644 --- a/manager_requirements.txt +++ b/manager_requirements.txt @@ -1 +1 @@ -comfyui_manager==4.0.3b3 +comfyui_manager==4.0.3b4 From fd109325db7126f92c2dfb7e6b25310eded8c1f8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jukka=20Sepp=C3=A4nen?= <40791699+kijai@users.noreply.github.com> Date: Sat, 6 Dec 2025 05:20:22 +0200 Subject: [PATCH 216/231] Kandinsky5 model support (#10988) * Add Kandinsky5 model support lite and pro T2V tested to work * Update kandinsky5.py * Fix fp8 * Fix fp8_scaled text encoder * Add transformer_options for attention * Code cleanup, optimizations, use fp32 for all layers originally at fp32 * ImageToVideo -node * Fix I2V, add necessary latent post process nodes * Support text to image model * Support block replace patches (SLG mostly) * Support official LoRAs * Don't scale RoPE for lite model as that just doesn't work... * Update supported_models.py * Rever RoPE scaling to simpler one * Fix typo * Handle latent dim difference for image model in the VAE instead * Add node to use different prompts for clip_l and qwen25_7b * Reduce peak VRAM usage a bit * Further reduce peak VRAM consumption by chunking ffn * Update chunking * Update memory_usage_factor * Code cleanup, don't force the fp32 layers as it has minimal effect * Allow for stronger changes with first frames normalization Default values are too weak for any meaningful changes, these should probably be exposed as advanced node options when that's available. * Add image model's own chat template, remove unused image2video template * Remove hard error in ReplaceVideoLatentFrames -node * Update kandinsky5.py * Update supported_models.py * Fix typos in prompt template They were now fixed in the original repository as well * Update ReplaceVideoLatentFrames Add tooltips Make source optional Better handle negative index * Rename NormalizeVideoLatentFrames -node For bit better clarity what it does * Fix NormalizeVideoLatentStart node out on non-op --- comfy/ldm/kandinsky5/model.py | 407 ++++++++++++++++++++++++++++++ comfy/lora.py | 7 + comfy/model_base.py | 47 ++++ comfy/model_detection.py | 18 ++ comfy/sd.py | 11 + comfy/supported_models.py | 56 +++- comfy/text_encoders/kandinsky5.py | 68 +++++ comfy_api/latest/_io.py | 2 + comfy_extras/nodes_kandinsky5.py | 136 ++++++++++ comfy_extras/nodes_latent.py | 39 ++- nodes.py | 3 +- 11 files changed, 791 insertions(+), 3 deletions(-) create mode 100644 comfy/ldm/kandinsky5/model.py create mode 100644 comfy/text_encoders/kandinsky5.py create mode 100644 comfy_extras/nodes_kandinsky5.py diff --git a/comfy/ldm/kandinsky5/model.py b/comfy/ldm/kandinsky5/model.py new file mode 100644 index 000000000..a653e02fc --- /dev/null +++ b/comfy/ldm/kandinsky5/model.py @@ -0,0 +1,407 @@ +import torch +from torch import nn +import math + +import comfy.ldm.common_dit +from comfy.ldm.modules.attention import optimized_attention +from comfy.ldm.flux.math import apply_rope1 +from comfy.ldm.flux.layers import EmbedND + +def attention(q, k, v, heads, transformer_options={}): + return optimized_attention( + q.transpose(1, 2), + k.transpose(1, 2), + v.transpose(1, 2), + heads=heads, + skip_reshape=True, + transformer_options=transformer_options + ) + +def apply_scale_shift_norm(norm, x, scale, shift): + return torch.addcmul(shift, norm(x), scale + 1.0) + +def apply_gate_sum(x, out, gate): + return torch.addcmul(x, gate, out) + +def get_shift_scale_gate(params): + shift, scale, gate = torch.chunk(params, 3, dim=-1) + return tuple(x.unsqueeze(1) for x in (shift, scale, gate)) + +def get_freqs(dim, max_period=10000.0): + return torch.exp(-math.log(max_period) * torch.arange(start=0, end=dim, dtype=torch.float32) / dim) + + +class TimeEmbeddings(nn.Module): + def __init__(self, model_dim, time_dim, max_period=10000.0, operation_settings=None): + super().__init__() + assert model_dim % 2 == 0 + self.model_dim = model_dim + self.max_period = max_period + self.register_buffer("freqs", get_freqs(model_dim // 2, max_period), persistent=False) + operations = operation_settings.get("operations") + self.in_layer = operations.Linear(model_dim, time_dim, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + self.activation = nn.SiLU() + self.out_layer = operations.Linear(time_dim, time_dim, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + + def forward(self, timestep, dtype): + args = torch.outer(timestep, self.freqs.to(device=timestep.device)) + time_embed = torch.cat([torch.cos(args), torch.sin(args)], dim=-1).to(dtype) + time_embed = self.out_layer(self.activation(self.in_layer(time_embed))) + return time_embed + + +class TextEmbeddings(nn.Module): + def __init__(self, text_dim, model_dim, operation_settings=None): + super().__init__() + operations = operation_settings.get("operations") + self.in_layer = operations.Linear(text_dim, model_dim, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + self.norm = operations.LayerNorm(model_dim, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + + def forward(self, text_embed): + text_embed = self.in_layer(text_embed) + return self.norm(text_embed).type_as(text_embed) + + +class VisualEmbeddings(nn.Module): + def __init__(self, visual_dim, model_dim, patch_size, operation_settings=None): + super().__init__() + self.patch_size = patch_size + operations = operation_settings.get("operations") + self.in_layer = operations.Linear(visual_dim, model_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + + def forward(self, x): + x = x.movedim(1, -1) # B C T H W -> B T H W C + B, T, H, W, dim = x.shape + pt, ph, pw = self.patch_size + + x = x.view( + B, + T // pt, pt, + H // ph, ph, + W // pw, pw, + dim, + ).permute(0, 1, 3, 5, 2, 4, 6, 7).flatten(4, 7) + + return self.in_layer(x) + + +class Modulation(nn.Module): + def __init__(self, time_dim, model_dim, num_params, operation_settings=None): + super().__init__() + self.activation = nn.SiLU() + self.out_layer = operation_settings.get("operations").Linear(time_dim, num_params * model_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + + def forward(self, x): + return self.out_layer(self.activation(x)) + + +class SelfAttention(nn.Module): + def __init__(self, num_channels, head_dim, operation_settings=None): + super().__init__() + assert num_channels % head_dim == 0 + self.num_heads = num_channels // head_dim + self.head_dim = head_dim + + operations = operation_settings.get("operations") + self.to_query = operations.Linear(num_channels, num_channels, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + self.to_key = operations.Linear(num_channels, num_channels, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + self.to_value = operations.Linear(num_channels, num_channels, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + self.query_norm = operations.RMSNorm(head_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + self.key_norm = operations.RMSNorm(head_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + + self.out_layer = operations.Linear(num_channels, num_channels, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + self.num_chunks = 2 + + def _compute_qk(self, x, freqs, proj_fn, norm_fn): + result = proj_fn(x).view(*x.shape[:-1], self.num_heads, -1) + return apply_rope1(norm_fn(result), freqs) + + def _forward(self, x, freqs, transformer_options={}): + q = self._compute_qk(x, freqs, self.to_query, self.query_norm) + k = self._compute_qk(x, freqs, self.to_key, self.key_norm) + v = self.to_value(x).view(*x.shape[:-1], self.num_heads, -1) + out = attention(q, k, v, self.num_heads, transformer_options=transformer_options) + return self.out_layer(out) + + def _forward_chunked(self, x, freqs, transformer_options={}): + def process_chunks(proj_fn, norm_fn): + x_chunks = torch.chunk(x, self.num_chunks, dim=1) + freqs_chunks = torch.chunk(freqs, self.num_chunks, dim=1) + chunks = [] + for x_chunk, freqs_chunk in zip(x_chunks, freqs_chunks): + chunks.append(self._compute_qk(x_chunk, freqs_chunk, proj_fn, norm_fn)) + return torch.cat(chunks, dim=1) + + q = process_chunks(self.to_query, self.query_norm) + k = process_chunks(self.to_key, self.key_norm) + v = self.to_value(x).view(*x.shape[:-1], self.num_heads, -1) + out = attention(q, k, v, self.num_heads, transformer_options=transformer_options) + return self.out_layer(out) + + def forward(self, x, freqs, transformer_options={}): + if x.shape[1] > 8192: + return self._forward_chunked(x, freqs, transformer_options=transformer_options) + else: + return self._forward(x, freqs, transformer_options=transformer_options) + + +class CrossAttention(SelfAttention): + def get_qkv(self, x, context): + q = self.to_query(x).view(*x.shape[:-1], self.num_heads, -1) + k = self.to_key(context).view(*context.shape[:-1], self.num_heads, -1) + v = self.to_value(context).view(*context.shape[:-1], self.num_heads, -1) + return q, k, v + + def forward(self, x, context, transformer_options={}): + q, k, v = self.get_qkv(x, context) + out = attention(self.query_norm(q), self.key_norm(k), v, self.num_heads, transformer_options=transformer_options) + return self.out_layer(out) + + +class FeedForward(nn.Module): + def __init__(self, dim, ff_dim, operation_settings=None): + super().__init__() + operations = operation_settings.get("operations") + self.in_layer = operations.Linear(dim, ff_dim, bias=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + self.activation = nn.GELU() + self.out_layer = operations.Linear(ff_dim, dim, bias=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + self.num_chunks = 4 + + def _forward(self, x): + return self.out_layer(self.activation(self.in_layer(x))) + + def _forward_chunked(self, x): + chunks = torch.chunk(x, self.num_chunks, dim=1) + output_chunks = [] + for chunk in chunks: + output_chunks.append(self._forward(chunk)) + return torch.cat(output_chunks, dim=1) + + def forward(self, x): + if x.shape[1] > 8192: + return self._forward_chunked(x) + else: + return self._forward(x) + + +class OutLayer(nn.Module): + def __init__(self, model_dim, time_dim, visual_dim, patch_size, operation_settings=None): + super().__init__() + self.patch_size = patch_size + self.modulation = Modulation(time_dim, model_dim, 2, operation_settings=operation_settings) + operations = operation_settings.get("operations") + self.norm = operations.LayerNorm(model_dim, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + self.out_layer = operations.Linear(model_dim, math.prod(patch_size) * visual_dim, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + + def forward(self, visual_embed, time_embed): + B, T, H, W, _ = visual_embed.shape + shift, scale = torch.chunk(self.modulation(time_embed), 2, dim=-1) + scale = scale[:, None, None, None, :] + shift = shift[:, None, None, None, :] + visual_embed = apply_scale_shift_norm(self.norm, visual_embed, scale, shift) + x = self.out_layer(visual_embed) + + out_dim = x.shape[-1] // (self.patch_size[0] * self.patch_size[1] * self.patch_size[2]) + x = x.view( + B, T, H, W, + out_dim, + self.patch_size[0], self.patch_size[1], self.patch_size[2] + ) + return x.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(2, 3).flatten(3, 4).flatten(4, 5) + + +class TransformerEncoderBlock(nn.Module): + def __init__(self, model_dim, time_dim, ff_dim, head_dim, operation_settings=None): + super().__init__() + self.text_modulation = Modulation(time_dim, model_dim, 6, operation_settings=operation_settings) + operations = operation_settings.get("operations") + + self.self_attention_norm = operations.LayerNorm(model_dim, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + self.self_attention = SelfAttention(model_dim, head_dim, operation_settings=operation_settings) + + self.feed_forward_norm = operations.LayerNorm(model_dim, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + self.feed_forward = FeedForward(model_dim, ff_dim, operation_settings=operation_settings) + + def forward(self, x, time_embed, freqs, transformer_options={}): + self_attn_params, ff_params = torch.chunk(self.text_modulation(time_embed), 2, dim=-1) + shift, scale, gate = get_shift_scale_gate(self_attn_params) + out = apply_scale_shift_norm(self.self_attention_norm, x, scale, shift) + out = self.self_attention(out, freqs, transformer_options=transformer_options) + x = apply_gate_sum(x, out, gate) + + shift, scale, gate = get_shift_scale_gate(ff_params) + out = apply_scale_shift_norm(self.feed_forward_norm, x, scale, shift) + out = self.feed_forward(out) + x = apply_gate_sum(x, out, gate) + return x + + +class TransformerDecoderBlock(nn.Module): + def __init__(self, model_dim, time_dim, ff_dim, head_dim, operation_settings=None): + super().__init__() + self.visual_modulation = Modulation(time_dim, model_dim, 9, operation_settings=operation_settings) + + operations = operation_settings.get("operations") + self.self_attention_norm = operations.LayerNorm(model_dim, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + self.self_attention = SelfAttention(model_dim, head_dim, operation_settings=operation_settings) + + self.cross_attention_norm = operations.LayerNorm(model_dim, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + self.cross_attention = CrossAttention(model_dim, head_dim, operation_settings=operation_settings) + + self.feed_forward_norm = operations.LayerNorm(model_dim, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + self.feed_forward = FeedForward(model_dim, ff_dim, operation_settings=operation_settings) + + def forward(self, visual_embed, text_embed, time_embed, freqs, transformer_options={}): + self_attn_params, cross_attn_params, ff_params = torch.chunk(self.visual_modulation(time_embed), 3, dim=-1) + # self attention + shift, scale, gate = get_shift_scale_gate(self_attn_params) + visual_out = apply_scale_shift_norm(self.self_attention_norm, visual_embed, scale, shift) + visual_out = self.self_attention(visual_out, freqs, transformer_options=transformer_options) + visual_embed = apply_gate_sum(visual_embed, visual_out, gate) + # cross attention + shift, scale, gate = get_shift_scale_gate(cross_attn_params) + visual_out = apply_scale_shift_norm(self.cross_attention_norm, visual_embed, scale, shift) + visual_out = self.cross_attention(visual_out, text_embed, transformer_options=transformer_options) + visual_embed = apply_gate_sum(visual_embed, visual_out, gate) + # feed forward + shift, scale, gate = get_shift_scale_gate(ff_params) + visual_out = apply_scale_shift_norm(self.feed_forward_norm, visual_embed, scale, shift) + visual_out = self.feed_forward(visual_out) + visual_embed = apply_gate_sum(visual_embed, visual_out, gate) + return visual_embed + + +class Kandinsky5(nn.Module): + def __init__( + self, + in_visual_dim=16, out_visual_dim=16, in_text_dim=3584, in_text_dim2=768, time_dim=512, + model_dim=1792, ff_dim=7168, visual_embed_dim=132, patch_size=(1, 2, 2), num_text_blocks=2, num_visual_blocks=32, + axes_dims=(16, 24, 24), rope_scale_factor=(1.0, 2.0, 2.0), + dtype=None, device=None, operations=None, **kwargs + ): + super().__init__() + head_dim = sum(axes_dims) + self.rope_scale_factor = rope_scale_factor + self.in_visual_dim = in_visual_dim + self.model_dim = model_dim + self.patch_size = patch_size + self.visual_embed_dim = visual_embed_dim + self.dtype = dtype + self.device = device + operation_settings = {"operations": operations, "device": device, "dtype": dtype} + + self.time_embeddings = TimeEmbeddings(model_dim, time_dim, operation_settings=operation_settings) + self.text_embeddings = TextEmbeddings(in_text_dim, model_dim, operation_settings=operation_settings) + self.pooled_text_embeddings = TextEmbeddings(in_text_dim2, time_dim, operation_settings=operation_settings) + self.visual_embeddings = VisualEmbeddings(visual_embed_dim, model_dim, patch_size, operation_settings=operation_settings) + + self.text_transformer_blocks = nn.ModuleList( + [TransformerEncoderBlock(model_dim, time_dim, ff_dim, head_dim, operation_settings=operation_settings) for _ in range(num_text_blocks)] + ) + + self.visual_transformer_blocks = nn.ModuleList( + [TransformerDecoderBlock(model_dim, time_dim, ff_dim, head_dim, operation_settings=operation_settings) for _ in range(num_visual_blocks)] + ) + + self.out_layer = OutLayer(model_dim, time_dim, out_visual_dim, patch_size, operation_settings=operation_settings) + + self.rope_embedder_3d = EmbedND(dim=head_dim, theta=10000.0, axes_dim=axes_dims) + self.rope_embedder_1d = EmbedND(dim=head_dim, theta=10000.0, axes_dim=[head_dim]) + + def rope_encode_1d(self, seq_len, seq_start=0, steps=None, device=None, dtype=None, transformer_options={}): + steps = seq_len if steps is None else steps + seq_ids = torch.linspace(seq_start, seq_start + (seq_len - 1), steps=steps, device=device, dtype=dtype) + seq_ids = seq_ids.reshape(-1, 1).unsqueeze(0) # Shape: (1, steps, 1) + freqs = self.rope_embedder_1d(seq_ids).movedim(1, 2) + return freqs + + def rope_encode_3d(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None, transformer_options={}): + + patch_size = self.patch_size + t_len = ((t + (patch_size[0] // 2)) // patch_size[0]) + h_len = ((h + (patch_size[1] // 2)) // patch_size[1]) + w_len = ((w + (patch_size[2] // 2)) // patch_size[2]) + + if steps_t is None: + steps_t = t_len + if steps_h is None: + steps_h = h_len + if steps_w is None: + steps_w = w_len + + h_start = 0 + w_start = 0 + rope_options = transformer_options.get("rope_options", None) + if rope_options is not None: + t_len = (t_len - 1.0) * rope_options.get("scale_t", 1.0) + 1.0 + h_len = (h_len - 1.0) * rope_options.get("scale_y", 1.0) + 1.0 + w_len = (w_len - 1.0) * rope_options.get("scale_x", 1.0) + 1.0 + + t_start += rope_options.get("shift_t", 0.0) + h_start += rope_options.get("shift_y", 0.0) + w_start += rope_options.get("shift_x", 0.0) + else: + rope_scale_factor = self.rope_scale_factor + if self.model_dim == 4096: # pro video model uses different rope scaling at higher resolutions + if h * w >= 14080: + rope_scale_factor = (1.0, 3.16, 3.16) + + t_len = (t_len - 1.0) / rope_scale_factor[0] + 1.0 + h_len = (h_len - 1.0) / rope_scale_factor[1] + 1.0 + w_len = (w_len - 1.0) / rope_scale_factor[2] + 1.0 + + img_ids = torch.zeros((steps_t, steps_h, steps_w, 3), device=device, dtype=dtype) + img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(t_start, t_start + (t_len - 1), steps=steps_t, device=device, dtype=dtype).reshape(-1, 1, 1) + img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(h_start, h_start + (h_len - 1), steps=steps_h, device=device, dtype=dtype).reshape(1, -1, 1) + img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(w_start, w_start + (w_len - 1), steps=steps_w, device=device, dtype=dtype).reshape(1, 1, -1) + img_ids = img_ids.reshape(1, -1, img_ids.shape[-1]) + + freqs = self.rope_embedder_3d(img_ids).movedim(1, 2) + return freqs + + def forward_orig(self, x, timestep, context, y, freqs, freqs_text, transformer_options={}, **kwargs): + patches_replace = transformer_options.get("patches_replace", {}) + context = self.text_embeddings(context) + time_embed = self.time_embeddings(timestep, x.dtype) + self.pooled_text_embeddings(y) + + for block in self.text_transformer_blocks: + context = block(context, time_embed, freqs_text, transformer_options=transformer_options) + + visual_embed = self.visual_embeddings(x) + visual_shape = visual_embed.shape[:-1] + visual_embed = visual_embed.flatten(1, -2) + + blocks_replace = patches_replace.get("dit", {}) + transformer_options["total_blocks"] = len(self.visual_transformer_blocks) + transformer_options["block_type"] = "double" + for i, block in enumerate(self.visual_transformer_blocks): + transformer_options["block_index"] = i + if ("double_block", i) in blocks_replace: + def block_wrap(args): + return block(x=args["x"], context=args["context"], time_embed=args["time_embed"], freqs=args["freqs"], transformer_options=args.get("transformer_options")) + visual_embed = blocks_replace[("double_block", i)]({"x": visual_embed, "context": context, "time_embed": time_embed, "freqs": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap})["x"] + else: + visual_embed = block(visual_embed, context, time_embed, freqs=freqs, transformer_options=transformer_options) + + visual_embed = visual_embed.reshape(*visual_shape, -1) + return self.out_layer(visual_embed, time_embed) + + def _forward(self, x, timestep, context, y, time_dim_replace=None, transformer_options={}, **kwargs): + bs, c, t_len, h, w = x.shape + x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size) + + if time_dim_replace is not None: + time_dim_replace = comfy.ldm.common_dit.pad_to_patch_size(time_dim_replace, self.patch_size) + x[:, :time_dim_replace.shape[1], :time_dim_replace.shape[2]] = time_dim_replace + + freqs = self.rope_encode_3d(t_len, h, w, device=x.device, dtype=x.dtype, transformer_options=transformer_options) + freqs_text = self.rope_encode_1d(context.shape[1], device=x.device, dtype=x.dtype, transformer_options=transformer_options) + + return self.forward_orig(x, timestep, context, y, freqs, freqs_text, transformer_options=transformer_options, **kwargs) + + def forward(self, x, timestep, context, y, time_dim_replace=None, transformer_options={}, **kwargs): + return comfy.patcher_extension.WrapperExecutor.new_class_executor( + self._forward, + self, + comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options) + ).execute(x, timestep, context, y, time_dim_replace=time_dim_replace, transformer_options=transformer_options, **kwargs) diff --git a/comfy/lora.py b/comfy/lora.py index 3a9077869..e7202ce97 100644 --- a/comfy/lora.py +++ b/comfy/lora.py @@ -322,6 +322,13 @@ def model_lora_keys_unet(model, key_map={}): key_map["diffusion_model.{}".format(key_lora)] = to key_map["lycoris_{}".format(key_lora.replace(".", "_"))] = to + if isinstance(model, comfy.model_base.Kandinsky5): + for k in sdk: + if k.startswith("diffusion_model.") and k.endswith(".weight"): + key_lora = k[len("diffusion_model."):-len(".weight")] + key_map["{}".format(key_lora)] = k + key_map["transformer.{}".format(key_lora)] = k + return key_map diff --git a/comfy/model_base.py b/comfy/model_base.py index 3cedd4f31..0be006cc2 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -47,6 +47,7 @@ import comfy.ldm.chroma_radiance.model import comfy.ldm.ace.model import comfy.ldm.omnigen.omnigen2 import comfy.ldm.qwen_image.model +import comfy.ldm.kandinsky5.model import comfy.model_management import comfy.patcher_extension @@ -1630,3 +1631,49 @@ class HunyuanVideo15_SR_Distilled(HunyuanVideo15): out = super().extra_conds(**kwargs) out['disable_time_r'] = comfy.conds.CONDConstant(False) return out + +class Kandinsky5(BaseModel): + def __init__(self, model_config, model_type=ModelType.FLOW, device=None): + super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.kandinsky5.model.Kandinsky5) + + def encode_adm(self, **kwargs): + return kwargs["pooled_output"] + + def concat_cond(self, **kwargs): + noise = kwargs.get("noise", None) + device = kwargs["device"] + image = torch.zeros_like(noise) + + mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None)) + if mask is None: + mask = torch.zeros_like(noise)[:, :1] + else: + mask = 1.0 - mask + mask = utils.common_upscale(mask.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center") + if mask.shape[-3] < noise.shape[-3]: + mask = torch.nn.functional.pad(mask, (0, 0, 0, 0, 0, noise.shape[-3] - mask.shape[-3]), mode='constant', value=0) + mask = utils.resize_to_batch_size(mask, noise.shape[0]) + + return torch.cat((image, mask), dim=1) + + def extra_conds(self, **kwargs): + out = super().extra_conds(**kwargs) + attention_mask = kwargs.get("attention_mask", None) + if attention_mask is not None: + out['attention_mask'] = comfy.conds.CONDRegular(attention_mask) + cross_attn = kwargs.get("cross_attn", None) + if cross_attn is not None: + out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) + + time_dim_replace = kwargs.get("time_dim_replace", None) + if time_dim_replace is not None: + out['time_dim_replace'] = comfy.conds.CONDRegular(self.process_latent_in(time_dim_replace)) + + return out + +class Kandinsky5Image(Kandinsky5): + def __init__(self, model_config, model_type=ModelType.FLOW, device=None): + super().__init__(model_config, model_type, device=device) + + def concat_cond(self, **kwargs): + return None diff --git a/comfy/model_detection.py b/comfy/model_detection.py index fd1907627..30b33a486 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -611,6 +611,24 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["num_layers"] = count_blocks(state_dict_keys, '{}transformer_blocks.'.format(key_prefix) + '{}.') return dit_config + if '{}visual_transformer_blocks.0.cross_attention.key_norm.weight'.format(key_prefix) in state_dict_keys: # Kandinsky 5 + dit_config = {} + model_dim = state_dict['{}visual_embeddings.in_layer.bias'.format(key_prefix)].shape[0] + dit_config["model_dim"] = model_dim + if model_dim in [4096, 2560]: # pro video and lite image + dit_config["axes_dims"] = (32, 48, 48) + if model_dim == 2560: # lite image + dit_config["rope_scale_factor"] = (1.0, 1.0, 1.0) + elif model_dim == 1792: # lite video + dit_config["axes_dims"] = (16, 24, 24) + dit_config["time_dim"] = state_dict['{}time_embeddings.in_layer.bias'.format(key_prefix)].shape[0] + dit_config["image_model"] = "kandinsky5" + dit_config["ff_dim"] = state_dict['{}visual_transformer_blocks.0.feed_forward.in_layer.weight'.format(key_prefix)].shape[0] + dit_config["visual_embed_dim"] = state_dict['{}visual_embeddings.in_layer.weight'.format(key_prefix)].shape[1] + dit_config["num_text_blocks"] = count_blocks(state_dict_keys, '{}text_transformer_blocks.'.format(key_prefix) + '{}.') + dit_config["num_visual_blocks"] = count_blocks(state_dict_keys, '{}visual_transformer_blocks.'.format(key_prefix) + '{}.') + return dit_config + if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys: return None diff --git a/comfy/sd.py b/comfy/sd.py index c350322f8..754b1703d 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -54,6 +54,7 @@ import comfy.text_encoders.qwen_image import comfy.text_encoders.hunyuan_image import comfy.text_encoders.z_image import comfy.text_encoders.ovis +import comfy.text_encoders.kandinsky5 import comfy.model_patcher import comfy.lora @@ -766,6 +767,8 @@ class VAE: self.throw_exception_if_invalid() pixel_samples = None do_tile = False + if self.latent_dim == 2 and samples_in.ndim == 5: + samples_in = samples_in[:, :, 0] try: memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype) model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload) @@ -983,6 +986,8 @@ class CLIPType(Enum): HUNYUAN_IMAGE = 19 HUNYUAN_VIDEO_15 = 20 OVIS = 21 + KANDINSKY5 = 22 + KANDINSKY5_IMAGE = 23 def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}): @@ -1231,6 +1236,12 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip elif clip_type == CLIPType.HUNYUAN_VIDEO_15: clip_target.clip = comfy.text_encoders.hunyuan_image.te(**llama_detect(clip_data)) clip_target.tokenizer = comfy.text_encoders.hunyuan_video.HunyuanVideo15Tokenizer + elif clip_type == CLIPType.KANDINSKY5: + clip_target.clip = comfy.text_encoders.kandinsky5.te(**llama_detect(clip_data)) + clip_target.tokenizer = comfy.text_encoders.kandinsky5.Kandinsky5Tokenizer + elif clip_type == CLIPType.KANDINSKY5_IMAGE: + clip_target.clip = comfy.text_encoders.kandinsky5.te(**llama_detect(clip_data)) + clip_target.tokenizer = comfy.text_encoders.kandinsky5.Kandinsky5TokenizerImage else: clip_target.clip = sdxl_clip.SDXLClipModel clip_target.tokenizer = sdxl_clip.SDXLTokenizer diff --git a/comfy/supported_models.py b/comfy/supported_models.py index afd97160b..91cc4ef08 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -21,6 +21,7 @@ import comfy.text_encoders.ace import comfy.text_encoders.omnigen2 import comfy.text_encoders.qwen_image import comfy.text_encoders.hunyuan_image +import comfy.text_encoders.kandinsky5 import comfy.text_encoders.z_image from . import supported_models_base @@ -1474,7 +1475,60 @@ class HunyuanVideo15_SR_Distilled(HunyuanVideo): hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref)) return supported_models_base.ClipTarget(comfy.text_encoders.hunyuan_video.HunyuanVideo15Tokenizer, comfy.text_encoders.hunyuan_image.te(**hunyuan_detect)) -models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage, Flux2] +class Kandinsky5(supported_models_base.BASE): + unet_config = { + "image_model": "kandinsky5", + } + + sampling_settings = { + "shift": 10.0, + } + + unet_extra_config = {} + latent_format = latent_formats.HunyuanVideo + + memory_usage_factor = 1.1 #TODO + + supported_inference_dtypes = [torch.bfloat16, torch.float32] + + vae_key_prefix = ["vae."] + text_encoder_key_prefix = ["text_encoders."] + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.Kandinsky5(self, device=device) + return out + + def clip_target(self, state_dict={}): + pref = self.text_encoder_key_prefix[0] + hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref)) + return supported_models_base.ClipTarget(comfy.text_encoders.kandinsky5.Kandinsky5Tokenizer, comfy.text_encoders.kandinsky5.te(**hunyuan_detect)) + + +class Kandinsky5Image(Kandinsky5): + unet_config = { + "image_model": "kandinsky5", + "model_dim": 2560, + "visual_embed_dim": 64, + } + + sampling_settings = { + "shift": 3.0, + } + + latent_format = latent_formats.Flux + memory_usage_factor = 1.1 #TODO + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.Kandinsky5Image(self, device=device) + return out + + def clip_target(self, state_dict={}): + pref = self.text_encoder_key_prefix[0] + hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref)) + return supported_models_base.ClipTarget(comfy.text_encoders.kandinsky5.Kandinsky5TokenizerImage, comfy.text_encoders.kandinsky5.te(**hunyuan_detect)) + + +models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage, Kandinsky5Image, Kandinsky5] models += [SVD_img2vid] diff --git a/comfy/text_encoders/kandinsky5.py b/comfy/text_encoders/kandinsky5.py new file mode 100644 index 000000000..22f991c36 --- /dev/null +++ b/comfy/text_encoders/kandinsky5.py @@ -0,0 +1,68 @@ +from comfy import sd1_clip +from .qwen_image import QwenImageTokenizer, QwenImageTEModel +from .llama import Qwen25_7BVLI + + +class Kandinsky5Tokenizer(QwenImageTokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data) + self.llama_template = "<|im_start|>system\nYou are a prompt engineer. Describe the video in detail.\nDescribe how the camera moves or shakes, describe the zoom and view angle, whether it follows the objects.\nDescribe the location of the video, main characters or objects and their action.\nDescribe the dynamism of the video and presented actions.\nName the visual style of the video: whether it is a professional footage, user generated content, some kind of animation, video game or screen content.\nDescribe the visual effects, postprocessing and transitions if they are presented in the video.\nPay attention to the order of key actions shown in the scene.<|im_end|>\n<|im_start|>user\n{}<|im_end|>" + self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data) + + def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs): + out = super().tokenize_with_weights(text, return_word_ids, **kwargs) + out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids, **kwargs) + + return out + + +class Kandinsky5TokenizerImage(Kandinsky5Tokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data) + self.llama_template = "<|im_start|>system\nYou are a promt engineer. Describe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>" + + +class Qwen25_7BVLIModel(sd1_clip.SDClipModel): + def __init__(self, device="cpu", layer="hidden", layer_idx=-1, dtype=None, attention_mask=True, model_options={}): + llama_scaled_fp8 = model_options.get("qwen_scaled_fp8", None) + if llama_scaled_fp8 is not None: + model_options = model_options.copy() + model_options["scaled_fp8"] = llama_scaled_fp8 + super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=Qwen25_7BVLI, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options) + + +class Kandinsky5TEModel(QwenImageTEModel): + def __init__(self, device="cpu", dtype=None, model_options={}): + super(QwenImageTEModel, self).__init__(device=device, dtype=dtype, name="qwen25_7b", clip_model=Qwen25_7BVLIModel, model_options=model_options) + self.clip_l = sd1_clip.SDClipModel(device=device, dtype=dtype, return_projected_pooled=False, model_options=model_options) + + def encode_token_weights(self, token_weight_pairs): + cond, p, extra = super().encode_token_weights(token_weight_pairs, template_end=-1) + l_out, l_pooled = self.clip_l.encode_token_weights(token_weight_pairs["l"]) + + return cond, l_pooled, extra + + def set_clip_options(self, options): + super().set_clip_options(options) + self.clip_l.set_clip_options(options) + + def reset_clip_options(self): + super().reset_clip_options() + self.clip_l.reset_clip_options() + + def load_sd(self, sd): + if "text_model.encoder.layers.1.mlp.fc1.weight" in sd: + return self.clip_l.load_sd(sd) + else: + return super().load_sd(sd) + +def te(dtype_llama=None, llama_scaled_fp8=None): + class Kandinsky5TEModel_(Kandinsky5TEModel): + def __init__(self, device="cpu", dtype=None, model_options={}): + if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options: + model_options = model_options.copy() + model_options["qwen_scaled_fp8"] = llama_scaled_fp8 + if dtype_llama is not None: + dtype = dtype_llama + super().__init__(device=device, dtype=dtype, model_options=model_options) + return Kandinsky5TEModel_ diff --git a/comfy_api/latest/_io.py b/comfy_api/latest/_io.py index 866c3e0eb..d7cbe68cf 100644 --- a/comfy_api/latest/_io.py +++ b/comfy_api/latest/_io.py @@ -568,6 +568,8 @@ class Conditioning(ComfyTypeIO): '''Used by WAN Camera.''' time_dim_concat: NotRequired[torch.Tensor] '''Used by WAN Phantom Subject.''' + time_dim_replace: NotRequired[torch.Tensor] + '''Used by Kandinsky5 I2V.''' CondList = list[tuple[torch.Tensor, PooledDict]] Type = CondList diff --git a/comfy_extras/nodes_kandinsky5.py b/comfy_extras/nodes_kandinsky5.py new file mode 100644 index 000000000..9cb234be1 --- /dev/null +++ b/comfy_extras/nodes_kandinsky5.py @@ -0,0 +1,136 @@ +import nodes +import node_helpers +import torch +import comfy.model_management +import comfy.utils + +from typing_extensions import override +from comfy_api.latest import ComfyExtension, io + + +class Kandinsky5ImageToVideo(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="Kandinsky5ImageToVideo", + category="conditioning/video_models", + inputs=[ + io.Conditioning.Input("positive"), + io.Conditioning.Input("negative"), + io.Vae.Input("vae"), + io.Int.Input("width", default=768, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("height", default=512, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("length", default=121, min=1, max=nodes.MAX_RESOLUTION, step=4), + io.Int.Input("batch_size", default=1, min=1, max=4096), + io.Image.Input("start_image", optional=True), + ], + outputs=[ + io.Conditioning.Output(display_name="positive"), + io.Conditioning.Output(display_name="negative"), + io.Latent.Output(display_name="latent", tooltip="Empty video latent"), + io.Latent.Output(display_name="cond_latent", tooltip="Clean encoded start images, used to replace the noisy start of the model output latents"), + ], + ) + + @classmethod + def execute(cls, positive, negative, vae, width, height, length, batch_size, start_image=None) -> io.NodeOutput: + latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) + cond_latent_out = {} + if start_image is not None: + start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) + encoded = vae.encode(start_image[:, :, :, :3]) + cond_latent_out["samples"] = encoded + + mask = torch.ones((1, 1, latent.shape[2], latent.shape[-2], latent.shape[-1]), device=start_image.device, dtype=start_image.dtype) + mask[:, :, :((start_image.shape[0] - 1) // 4) + 1] = 0.0 + + positive = node_helpers.conditioning_set_values(positive, {"time_dim_replace": encoded, "concat_mask": mask}) + negative = node_helpers.conditioning_set_values(negative, {"time_dim_replace": encoded, "concat_mask": mask}) + + out_latent = {} + out_latent["samples"] = latent + return io.NodeOutput(positive, negative, out_latent, cond_latent_out) + + +def adaptive_mean_std_normalization(source, reference, clump_mean_low=0.3, clump_mean_high=0.35, clump_std_low=0.35, clump_std_high=0.5): + source_mean = source.mean(dim=(1, 3, 4), keepdim=True) # mean over C, H, W + source_std = source.std(dim=(1, 3, 4), keepdim=True) # std over C, H, W + + reference_mean = torch.clamp(reference.mean(), source_mean - clump_mean_low, source_mean + clump_mean_high) + reference_std = torch.clamp(reference.std(), source_std - clump_std_low, source_std + clump_std_high) + + # normalization + normalized = (source - source_mean) / (source_std + 1e-8) + normalized = normalized * reference_std + reference_mean + + return normalized + + +class NormalizeVideoLatentStart(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="NormalizeVideoLatentStart", + category="conditioning/video_models", + description="Normalizes the initial frames of a video latent to match the mean and standard deviation of subsequent reference frames. Helps reduce differences between the starting frames and the rest of the video.", + inputs=[ + io.Latent.Input("latent"), + io.Int.Input("start_frame_count", default=4, min=1, max=nodes.MAX_RESOLUTION, step=1, tooltip="Number of latent frames to normalize, counted from the start"), + io.Int.Input("reference_frame_count", default=5, min=1, max=nodes.MAX_RESOLUTION, step=1, tooltip="Number of latent frames after the start frames to use as reference"), + ], + outputs=[ + io.Latent.Output(display_name="latent"), + ], + ) + + @classmethod + def execute(cls, latent, start_frame_count, reference_frame_count) -> io.NodeOutput: + if latent["samples"].shape[2] <= 1: + return io.NodeOutput(latent) + s = latent.copy() + samples = latent["samples"].clone() + + first_frames = samples[:, :, :start_frame_count] + reference_frames_data = samples[:, :, start_frame_count:start_frame_count+min(reference_frame_count, samples.shape[2]-1)] + normalized_first_frames = adaptive_mean_std_normalization(first_frames, reference_frames_data) + + samples[:, :, :start_frame_count] = normalized_first_frames + s["samples"] = samples + return io.NodeOutput(s) + + +class CLIPTextEncodeKandinsky5(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="CLIPTextEncodeKandinsky5", + category="advanced/conditioning/kandinsky5", + inputs=[ + io.Clip.Input("clip"), + io.String.Input("clip_l", multiline=True, dynamic_prompts=True), + io.String.Input("qwen25_7b", multiline=True, dynamic_prompts=True), + ], + outputs=[ + io.Conditioning.Output(), + ], + ) + + @classmethod + def execute(cls, clip, clip_l, qwen25_7b) -> io.NodeOutput: + tokens = clip.tokenize(clip_l) + tokens["qwen25_7b"] = clip.tokenize(qwen25_7b)["qwen25_7b"] + + return io.NodeOutput(clip.encode_from_tokens_scheduled(tokens)) + + +class Kandinsky5Extension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + Kandinsky5ImageToVideo, + NormalizeVideoLatentStart, + CLIPTextEncodeKandinsky5, + ] + +async def comfy_entrypoint() -> Kandinsky5Extension: + return Kandinsky5Extension() diff --git a/comfy_extras/nodes_latent.py b/comfy_extras/nodes_latent.py index d2df07ff9..e439b18ef 100644 --- a/comfy_extras/nodes_latent.py +++ b/comfy_extras/nodes_latent.py @@ -4,7 +4,7 @@ import torch import nodes from typing_extensions import override from comfy_api.latest import ComfyExtension, io - +import logging def reshape_latent_to(target_shape, latent, repeat_batch=True): if latent.shape[1:] != target_shape[1:]: @@ -388,6 +388,42 @@ class LatentOperationSharpen(io.ComfyNode): return luminance * sharpened return io.NodeOutput(sharpen) +class ReplaceVideoLatentFrames(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="ReplaceVideoLatentFrames", + category="latent/batch", + inputs=[ + io.Latent.Input("destination", tooltip="The destination latent where frames will be replaced."), + io.Latent.Input("source", optional=True, tooltip="The source latent providing frames to insert into the destination latent. If not provided, the destination latent is returned unchanged."), + io.Int.Input("index", default=0, min=-nodes.MAX_RESOLUTION, max=nodes.MAX_RESOLUTION, step=1, tooltip="The starting latent frame index in the destination latent where the source latent frames will be placed. Negative values count from the end."), + ], + outputs=[ + io.Latent.Output(), + ], + ) + + @classmethod + def execute(cls, destination, index, source=None) -> io.NodeOutput: + if source is None: + return io.NodeOutput(destination) + dest_frames = destination["samples"].shape[2] + source_frames = source["samples"].shape[2] + if index < 0: + index = dest_frames + index + if index > dest_frames: + logging.warning(f"ReplaceVideoLatentFrames: Index {index} is out of bounds for destination latent frames {dest_frames}.") + return io.NodeOutput(destination) + if index + source_frames > dest_frames: + logging.warning(f"ReplaceVideoLatentFrames: Source latent frames {source_frames} do not fit within destination latent frames {dest_frames} at the specified index {index}.") + return io.NodeOutput(destination) + s = source.copy() + s_source = source["samples"] + s_destination = destination["samples"].clone() + s_destination[:, :, index:index + s_source.shape[2]] = s_source + s["samples"] = s_destination + return io.NodeOutput(s) class LatentExtension(ComfyExtension): @override @@ -405,6 +441,7 @@ class LatentExtension(ComfyExtension): LatentApplyOperationCFG, LatentOperationTonemapReinhard, LatentOperationSharpen, + ReplaceVideoLatentFrames ] diff --git a/nodes.py b/nodes.py index 356aa63df..8d28a725d 100644 --- a/nodes.py +++ b/nodes.py @@ -970,7 +970,7 @@ class DualCLIPLoader: def INPUT_TYPES(s): return {"required": { "clip_name1": (folder_paths.get_filename_list("text_encoders"), ), "clip_name2": (folder_paths.get_filename_list("text_encoders"), ), - "type": (["sdxl", "sd3", "flux", "hunyuan_video", "hidream", "hunyuan_image", "hunyuan_video_15"], ), + "type": (["sdxl", "sd3", "flux", "hunyuan_video", "hidream", "hunyuan_image", "hunyuan_video_15", "kandinsky5", "kandinsky5_image"], ), }, "optional": { "device": (["default", "cpu"], {"advanced": True}), @@ -2357,6 +2357,7 @@ async def init_builtin_extra_nodes(): "nodes_rope.py", "nodes_logic.py", "nodes_nop.py", + "nodes_kandinsky5.py", ] import_failed = [] From ae676ed105663bb225153c8bca406f00edf738b4 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Fri, 5 Dec 2025 20:01:19 -0800 Subject: [PATCH 217/231] Fix regression. (#11137) --- comfy/supported_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 91cc4ef08..383c82c3e 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -1529,6 +1529,6 @@ class Kandinsky5Image(Kandinsky5): return supported_models_base.ClipTarget(comfy.text_encoders.kandinsky5.Kandinsky5TokenizerImage, comfy.text_encoders.kandinsky5.te(**hunyuan_detect)) -models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage, Kandinsky5Image, Kandinsky5] +models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5] models += [SVD_img2vid] From 117bf3f2bd9235cb5942a1de10a534c9869c7444 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Sat, 6 Dec 2025 06:22:02 +0200 Subject: [PATCH 218/231] convert nodes_freelunch.py to the V3 schema (#10904) --- comfy_extras/nodes_freelunch.py | 89 +++++++++++++++++---------- comfy_extras/nodes_model_downscale.py | 5 -- 2 files changed, 57 insertions(+), 37 deletions(-) diff --git a/comfy_extras/nodes_freelunch.py b/comfy_extras/nodes_freelunch.py index e3ac58447..3429b731e 100644 --- a/comfy_extras/nodes_freelunch.py +++ b/comfy_extras/nodes_freelunch.py @@ -2,6 +2,8 @@ import torch import logging +from typing_extensions import override +from comfy_api.latest import ComfyExtension, IO def Fourier_filter(x, threshold, scale): # FFT @@ -22,21 +24,26 @@ def Fourier_filter(x, threshold, scale): return x_filtered.to(x.dtype) -class FreeU: +class FreeU(IO.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "model": ("MODEL",), - "b1": ("FLOAT", {"default": 1.1, "min": 0.0, "max": 10.0, "step": 0.01}), - "b2": ("FLOAT", {"default": 1.2, "min": 0.0, "max": 10.0, "step": 0.01}), - "s1": ("FLOAT", {"default": 0.9, "min": 0.0, "max": 10.0, "step": 0.01}), - "s2": ("FLOAT", {"default": 0.2, "min": 0.0, "max": 10.0, "step": 0.01}), - }} - RETURN_TYPES = ("MODEL",) - FUNCTION = "patch" + def define_schema(cls): + return IO.Schema( + node_id="FreeU", + category="model_patches/unet", + inputs=[ + IO.Model.Input("model"), + IO.Float.Input("b1", default=1.1, min=0.0, max=10.0, step=0.01), + IO.Float.Input("b2", default=1.2, min=0.0, max=10.0, step=0.01), + IO.Float.Input("s1", default=0.9, min=0.0, max=10.0, step=0.01), + IO.Float.Input("s2", default=0.2, min=0.0, max=10.0, step=0.01), + ], + outputs=[ + IO.Model.Output(), + ], + ) - CATEGORY = "model_patches/unet" - - def patch(self, model, b1, b2, s1, s2): + @classmethod + def execute(cls, model, b1, b2, s1, s2) -> IO.NodeOutput: model_channels = model.model.model_config.unet_config["model_channels"] scale_dict = {model_channels * 4: (b1, s1), model_channels * 2: (b2, s2)} on_cpu_devices = {} @@ -59,23 +66,31 @@ class FreeU: m = model.clone() m.set_model_output_block_patch(output_block_patch) - return (m, ) + return IO.NodeOutput(m) -class FreeU_V2: + patch = execute # TODO: remove + + +class FreeU_V2(IO.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "model": ("MODEL",), - "b1": ("FLOAT", {"default": 1.3, "min": 0.0, "max": 10.0, "step": 0.01}), - "b2": ("FLOAT", {"default": 1.4, "min": 0.0, "max": 10.0, "step": 0.01}), - "s1": ("FLOAT", {"default": 0.9, "min": 0.0, "max": 10.0, "step": 0.01}), - "s2": ("FLOAT", {"default": 0.2, "min": 0.0, "max": 10.0, "step": 0.01}), - }} - RETURN_TYPES = ("MODEL",) - FUNCTION = "patch" + def define_schema(cls): + return IO.Schema( + node_id="FreeU_V2", + category="model_patches/unet", + inputs=[ + IO.Model.Input("model"), + IO.Float.Input("b1", default=1.3, min=0.0, max=10.0, step=0.01), + IO.Float.Input("b2", default=1.4, min=0.0, max=10.0, step=0.01), + IO.Float.Input("s1", default=0.9, min=0.0, max=10.0, step=0.01), + IO.Float.Input("s2", default=0.2, min=0.0, max=10.0, step=0.01), + ], + outputs=[ + IO.Model.Output(), + ], + ) - CATEGORY = "model_patches/unet" - - def patch(self, model, b1, b2, s1, s2): + @classmethod + def execute(cls, model, b1, b2, s1, s2) -> IO.NodeOutput: model_channels = model.model.model_config.unet_config["model_channels"] scale_dict = {model_channels * 4: (b1, s1), model_channels * 2: (b2, s2)} on_cpu_devices = {} @@ -105,9 +120,19 @@ class FreeU_V2: m = model.clone() m.set_model_output_block_patch(output_block_patch) - return (m, ) + return IO.NodeOutput(m) -NODE_CLASS_MAPPINGS = { - "FreeU": FreeU, - "FreeU_V2": FreeU_V2, -} + patch = execute # TODO: remove + + +class FreelunchExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[IO.ComfyNode]]: + return [ + FreeU, + FreeU_V2, + ] + + +async def comfy_entrypoint() -> FreelunchExtension: + return FreelunchExtension() diff --git a/comfy_extras/nodes_model_downscale.py b/comfy_extras/nodes_model_downscale.py index f7ca9699d..dec2ae841 100644 --- a/comfy_extras/nodes_model_downscale.py +++ b/comfy_extras/nodes_model_downscale.py @@ -53,11 +53,6 @@ class PatchModelAddDownscale(io.ComfyNode): return io.NodeOutput(m) -NODE_DISPLAY_NAME_MAPPINGS = { - # Sampling - "PatchModelAddDownscale": "", -} - class ModelDownscaleExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[io.ComfyNode]]: From 913f86b72740f84f759786a698108840a09b6498 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Sat, 6 Dec 2025 06:24:10 +0200 Subject: [PATCH 219/231] [V3] convert nodes_mask.py to V3 schema (#10669) * convert nodes_mask.py to V3 schema * set "Preview Mask" as display name for MaskPreview --- comfy_extras/nodes_mask.py | 508 +++++++++++++++++++------------------ 1 file changed, 263 insertions(+), 245 deletions(-) diff --git a/comfy_extras/nodes_mask.py b/comfy_extras/nodes_mask.py index a5e405008..290e6f55e 100644 --- a/comfy_extras/nodes_mask.py +++ b/comfy_extras/nodes_mask.py @@ -3,11 +3,10 @@ import scipy.ndimage import torch import comfy.utils import node_helpers -import folder_paths -import random +from typing_extensions import override +from comfy_api.latest import ComfyExtension, IO, UI import nodes -from nodes import MAX_RESOLUTION def composite(destination, source, x, y, mask = None, multiplier = 8, resize_source = False): source = source.to(destination.device) @@ -46,202 +45,213 @@ def composite(destination, source, x, y, mask = None, multiplier = 8, resize_sou destination[..., top:bottom, left:right] = source_portion + destination_portion return destination -class LatentCompositeMasked: +class LatentCompositeMasked(IO.ComfyNode): @classmethod - def INPUT_TYPES(s): - return { - "required": { - "destination": ("LATENT",), - "source": ("LATENT",), - "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}), - "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}), - "resize_source": ("BOOLEAN", {"default": False}), - }, - "optional": { - "mask": ("MASK",), - } - } - RETURN_TYPES = ("LATENT",) - FUNCTION = "composite" + def define_schema(cls): + return IO.Schema( + node_id="LatentCompositeMasked", + category="latent", + inputs=[ + IO.Latent.Input("destination"), + IO.Latent.Input("source"), + IO.Int.Input("x", default=0, min=0, max=nodes.MAX_RESOLUTION, step=8), + IO.Int.Input("y", default=0, min=0, max=nodes.MAX_RESOLUTION, step=8), + IO.Boolean.Input("resize_source", default=False), + IO.Mask.Input("mask", optional=True), + ], + outputs=[IO.Latent.Output()], + ) - CATEGORY = "latent" - - def composite(self, destination, source, x, y, resize_source, mask = None): + @classmethod + def execute(cls, destination, source, x, y, resize_source, mask = None) -> IO.NodeOutput: output = destination.copy() destination = destination["samples"].clone() source = source["samples"] output["samples"] = composite(destination, source, x, y, mask, 8, resize_source) - return (output,) + return IO.NodeOutput(output) -class ImageCompositeMasked: + composite = execute # TODO: remove + + +class ImageCompositeMasked(IO.ComfyNode): @classmethod - def INPUT_TYPES(s): - return { - "required": { - "destination": ("IMAGE",), - "source": ("IMAGE",), - "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}), - "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}), - "resize_source": ("BOOLEAN", {"default": False}), - }, - "optional": { - "mask": ("MASK",), - } - } - RETURN_TYPES = ("IMAGE",) - FUNCTION = "composite" + def define_schema(cls): + return IO.Schema( + node_id="ImageCompositeMasked", + category="image", + inputs=[ + IO.Image.Input("destination"), + IO.Image.Input("source"), + IO.Int.Input("x", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1), + IO.Int.Input("y", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1), + IO.Boolean.Input("resize_source", default=False), + IO.Mask.Input("mask", optional=True), + ], + outputs=[IO.Image.Output()], + ) - CATEGORY = "image" - - def composite(self, destination, source, x, y, resize_source, mask = None): + @classmethod + def execute(cls, destination, source, x, y, resize_source, mask = None) -> IO.NodeOutput: destination, source = node_helpers.image_alpha_fix(destination, source) destination = destination.clone().movedim(-1, 1) output = composite(destination, source.movedim(-1, 1), x, y, mask, 1, resize_source).movedim(1, -1) - return (output,) + return IO.NodeOutput(output) -class MaskToImage: + composite = execute # TODO: remove + + +class MaskToImage(IO.ComfyNode): @classmethod - def INPUT_TYPES(s): - return { - "required": { - "mask": ("MASK",), - } - } + def define_schema(cls): + return IO.Schema( + node_id="MaskToImage", + display_name="Convert Mask to Image", + category="mask", + inputs=[ + IO.Mask.Input("mask"), + ], + outputs=[IO.Image.Output()], + ) - CATEGORY = "mask" - - RETURN_TYPES = ("IMAGE",) - FUNCTION = "mask_to_image" - - def mask_to_image(self, mask): + @classmethod + def execute(cls, mask) -> IO.NodeOutput: result = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])).movedim(1, -1).expand(-1, -1, -1, 3) - return (result,) + return IO.NodeOutput(result) -class ImageToMask: + mask_to_image = execute # TODO: remove + + +class ImageToMask(IO.ComfyNode): @classmethod - def INPUT_TYPES(s): - return { - "required": { - "image": ("IMAGE",), - "channel": (["red", "green", "blue", "alpha"],), - } - } + def define_schema(cls): + return IO.Schema( + node_id="ImageToMask", + display_name="Convert Image to Mask", + category="mask", + inputs=[ + IO.Image.Input("image"), + IO.Combo.Input("channel", options=["red", "green", "blue", "alpha"]), + ], + outputs=[IO.Mask.Output()], + ) - CATEGORY = "mask" - - RETURN_TYPES = ("MASK",) - FUNCTION = "image_to_mask" - - def image_to_mask(self, image, channel): + @classmethod + def execute(cls, image, channel) -> IO.NodeOutput: channels = ["red", "green", "blue", "alpha"] mask = image[:, :, :, channels.index(channel)] - return (mask,) + return IO.NodeOutput(mask) -class ImageColorToMask: + image_to_mask = execute # TODO: remove + + +class ImageColorToMask(IO.ComfyNode): @classmethod - def INPUT_TYPES(s): - return { - "required": { - "image": ("IMAGE",), - "color": ("INT", {"default": 0, "min": 0, "max": 0xFFFFFF, "step": 1, "display": "color"}), - } - } + def define_schema(cls): + return IO.Schema( + node_id="ImageColorToMask", + category="mask", + inputs=[ + IO.Image.Input("image"), + IO.Int.Input("color", default=0, min=0, max=0xFFFFFF, step=1, display_mode=IO.NumberDisplay.number), + ], + outputs=[IO.Mask.Output()], + ) - CATEGORY = "mask" - - RETURN_TYPES = ("MASK",) - FUNCTION = "image_to_mask" - - def image_to_mask(self, image, color): + @classmethod + def execute(cls, image, color) -> IO.NodeOutput: temp = (torch.clamp(image, 0, 1.0) * 255.0).round().to(torch.int) temp = torch.bitwise_left_shift(temp[:,:,:,0], 16) + torch.bitwise_left_shift(temp[:,:,:,1], 8) + temp[:,:,:,2] mask = torch.where(temp == color, 1.0, 0).float() - return (mask,) + return IO.NodeOutput(mask) -class SolidMask: + image_to_mask = execute # TODO: remove + + +class SolidMask(IO.ComfyNode): @classmethod - def INPUT_TYPES(cls): - return { - "required": { - "value": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), - "width": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}), - "height": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}), - } - } + def define_schema(cls): + return IO.Schema( + node_id="SolidMask", + category="mask", + inputs=[ + IO.Float.Input("value", default=1.0, min=0.0, max=1.0, step=0.01), + IO.Int.Input("width", default=512, min=1, max=nodes.MAX_RESOLUTION, step=1), + IO.Int.Input("height", default=512, min=1, max=nodes.MAX_RESOLUTION, step=1), + ], + outputs=[IO.Mask.Output()], + ) - CATEGORY = "mask" - - RETURN_TYPES = ("MASK",) - - FUNCTION = "solid" - - def solid(self, value, width, height): + @classmethod + def execute(cls, value, width, height) -> IO.NodeOutput: out = torch.full((1, height, width), value, dtype=torch.float32, device="cpu") - return (out,) + return IO.NodeOutput(out) -class InvertMask: + solid = execute # TODO: remove + + +class InvertMask(IO.ComfyNode): @classmethod - def INPUT_TYPES(cls): - return { - "required": { - "mask": ("MASK",), - } - } + def define_schema(cls): + return IO.Schema( + node_id="InvertMask", + category="mask", + inputs=[ + IO.Mask.Input("mask"), + ], + outputs=[IO.Mask.Output()], + ) - CATEGORY = "mask" - - RETURN_TYPES = ("MASK",) - - FUNCTION = "invert" - - def invert(self, mask): + @classmethod + def execute(cls, mask) -> IO.NodeOutput: out = 1.0 - mask - return (out,) + return IO.NodeOutput(out) -class CropMask: + invert = execute # TODO: remove + + +class CropMask(IO.ComfyNode): @classmethod - def INPUT_TYPES(cls): - return { - "required": { - "mask": ("MASK",), - "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}), - "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}), - "width": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}), - "height": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}), - } - } + def define_schema(cls): + return IO.Schema( + node_id="CropMask", + category="mask", + inputs=[ + IO.Mask.Input("mask"), + IO.Int.Input("x", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1), + IO.Int.Input("y", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1), + IO.Int.Input("width", default=512, min=1, max=nodes.MAX_RESOLUTION, step=1), + IO.Int.Input("height", default=512, min=1, max=nodes.MAX_RESOLUTION, step=1), + ], + outputs=[IO.Mask.Output()], + ) - CATEGORY = "mask" - - RETURN_TYPES = ("MASK",) - - FUNCTION = "crop" - - def crop(self, mask, x, y, width, height): + @classmethod + def execute(cls, mask, x, y, width, height) -> IO.NodeOutput: mask = mask.reshape((-1, mask.shape[-2], mask.shape[-1])) out = mask[:, y:y + height, x:x + width] - return (out,) + return IO.NodeOutput(out) -class MaskComposite: + crop = execute # TODO: remove + + +class MaskComposite(IO.ComfyNode): @classmethod - def INPUT_TYPES(cls): - return { - "required": { - "destination": ("MASK",), - "source": ("MASK",), - "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}), - "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}), - "operation": (["multiply", "add", "subtract", "and", "or", "xor"],), - } - } + def define_schema(cls): + return IO.Schema( + node_id="MaskComposite", + category="mask", + inputs=[ + IO.Mask.Input("destination"), + IO.Mask.Input("source"), + IO.Int.Input("x", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1), + IO.Int.Input("y", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1), + IO.Combo.Input("operation", options=["multiply", "add", "subtract", "and", "or", "xor"]), + ], + outputs=[IO.Mask.Output()], + ) - CATEGORY = "mask" - - RETURN_TYPES = ("MASK",) - - FUNCTION = "combine" - - def combine(self, destination, source, x, y, operation): + @classmethod + def execute(cls, destination, source, x, y, operation) -> IO.NodeOutput: output = destination.reshape((-1, destination.shape[-2], destination.shape[-1])).clone() source = source.reshape((-1, source.shape[-2], source.shape[-1])) @@ -267,28 +277,29 @@ class MaskComposite: output = torch.clamp(output, 0.0, 1.0) - return (output,) + return IO.NodeOutput(output) -class FeatherMask: + combine = execute # TODO: remove + + +class FeatherMask(IO.ComfyNode): @classmethod - def INPUT_TYPES(cls): - return { - "required": { - "mask": ("MASK",), - "left": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}), - "top": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}), - "right": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}), - "bottom": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}), - } - } + def define_schema(cls): + return IO.Schema( + node_id="FeatherMask", + category="mask", + inputs=[ + IO.Mask.Input("mask"), + IO.Int.Input("left", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1), + IO.Int.Input("top", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1), + IO.Int.Input("right", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1), + IO.Int.Input("bottom", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1), + ], + outputs=[IO.Mask.Output()], + ) - CATEGORY = "mask" - - RETURN_TYPES = ("MASK",) - - FUNCTION = "feather" - - def feather(self, mask, left, top, right, bottom): + @classmethod + def execute(cls, mask, left, top, right, bottom) -> IO.NodeOutput: output = mask.reshape((-1, mask.shape[-2], mask.shape[-1])).clone() left = min(left, output.shape[-1]) @@ -312,26 +323,28 @@ class FeatherMask: feather_rate = (y + 1) / bottom output[:, -y, :] *= feather_rate - return (output,) + return IO.NodeOutput(output) -class GrowMask: + feather = execute # TODO: remove + + +class GrowMask(IO.ComfyNode): @classmethod - def INPUT_TYPES(cls): - return { - "required": { - "mask": ("MASK",), - "expand": ("INT", {"default": 0, "min": -MAX_RESOLUTION, "max": MAX_RESOLUTION, "step": 1}), - "tapered_corners": ("BOOLEAN", {"default": True}), - }, - } + def define_schema(cls): + return IO.Schema( + node_id="GrowMask", + display_name="Grow Mask", + category="mask", + inputs=[ + IO.Mask.Input("mask"), + IO.Int.Input("expand", default=0, min=-nodes.MAX_RESOLUTION, max=nodes.MAX_RESOLUTION, step=1), + IO.Boolean.Input("tapered_corners", default=True), + ], + outputs=[IO.Mask.Output()], + ) - CATEGORY = "mask" - - RETURN_TYPES = ("MASK",) - - FUNCTION = "expand_mask" - - def expand_mask(self, mask, expand, tapered_corners): + @classmethod + def execute(cls, mask, expand, tapered_corners) -> IO.NodeOutput: c = 0 if tapered_corners else 1 kernel = np.array([[c, 1, c], [1, 1, 1], @@ -347,69 +360,74 @@ class GrowMask: output = scipy.ndimage.grey_dilation(output, footprint=kernel) output = torch.from_numpy(output) out.append(output) - return (torch.stack(out, dim=0),) + return IO.NodeOutput(torch.stack(out, dim=0)) -class ThresholdMask: + expand_mask = execute # TODO: remove + + +class ThresholdMask(IO.ComfyNode): @classmethod - def INPUT_TYPES(s): - return { - "required": { - "mask": ("MASK",), - "value": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}), - } - } + def define_schema(cls): + return IO.Schema( + node_id="ThresholdMask", + category="mask", + inputs=[ + IO.Mask.Input("mask"), + IO.Float.Input("value", default=0.5, min=0.0, max=1.0, step=0.01), + ], + outputs=[IO.Mask.Output()], + ) - CATEGORY = "mask" - - RETURN_TYPES = ("MASK",) - FUNCTION = "image_to_mask" - - def image_to_mask(self, mask, value): + @classmethod + def execute(cls, mask, value) -> IO.NodeOutput: mask = (mask > value).float() - return (mask,) + return IO.NodeOutput(mask) + + image_to_mask = execute # TODO: remove + # Mask Preview - original implement from # https://github.com/cubiq/ComfyUI_essentials/blob/9d9f4bedfc9f0321c19faf71855e228c93bd0dc9/mask.py#L81 # upstream requested in https://github.com/Kosinkadink/rfcs/blob/main/rfcs/0000-corenodes.md#preview-nodes -class MaskPreview(nodes.SaveImage): - def __init__(self): - self.output_dir = folder_paths.get_temp_directory() - self.type = "temp" - self.prefix_append = "_temp_" + ''.join(random.choice("abcdefghijklmnopqrstupvxyz") for x in range(5)) - self.compress_level = 4 +class MaskPreview(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="MaskPreview", + display_name="Preview Mask", + category="mask", + description="Saves the input images to your ComfyUI output directory.", + inputs=[ + IO.Mask.Input("mask"), + ], + hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo], + is_output_node=True, + ) @classmethod - def INPUT_TYPES(s): - return { - "required": {"mask": ("MASK",), }, - "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}, - } - - FUNCTION = "execute" - CATEGORY = "mask" - - def execute(self, mask, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None): - preview = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])).movedim(1, -1).expand(-1, -1, -1, 3) - return self.save_images(preview, filename_prefix, prompt, extra_pnginfo) + def execute(cls, mask, filename_prefix="ComfyUI") -> IO.NodeOutput: + return IO.NodeOutput(ui=UI.PreviewMask(mask)) -NODE_CLASS_MAPPINGS = { - "LatentCompositeMasked": LatentCompositeMasked, - "ImageCompositeMasked": ImageCompositeMasked, - "MaskToImage": MaskToImage, - "ImageToMask": ImageToMask, - "ImageColorToMask": ImageColorToMask, - "SolidMask": SolidMask, - "InvertMask": InvertMask, - "CropMask": CropMask, - "MaskComposite": MaskComposite, - "FeatherMask": FeatherMask, - "GrowMask": GrowMask, - "ThresholdMask": ThresholdMask, - "MaskPreview": MaskPreview -} +class MaskExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[IO.ComfyNode]]: + return [ + LatentCompositeMasked, + ImageCompositeMasked, + MaskToImage, + ImageToMask, + ImageColorToMask, + SolidMask, + InvertMask, + CropMask, + MaskComposite, + FeatherMask, + GrowMask, + ThresholdMask, + MaskPreview, + ] -NODE_DISPLAY_NAME_MAPPINGS = { - "ImageToMask": "Convert Image to Mask", - "MaskToImage": "Convert Mask to Image", -} + +async def comfy_entrypoint() -> MaskExtension: + return MaskExtension() From d7a0aef65033bf0fe56e521577a44fac1830b8b3 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Fri, 5 Dec 2025 21:15:21 -0800 Subject: [PATCH 220/231] Set OCL_SET_SVM_SIZE on AMD. (#11139) --- cuda_malloc.py | 27 +++++++++++++++++---------- main.py | 3 +++ 2 files changed, 20 insertions(+), 10 deletions(-) diff --git a/cuda_malloc.py b/cuda_malloc.py index 6520d5123..ee2bc4b69 100644 --- a/cuda_malloc.py +++ b/cuda_malloc.py @@ -63,18 +63,22 @@ def cuda_malloc_supported(): return True +version = "" + +try: + torch_spec = importlib.util.find_spec("torch") + for folder in torch_spec.submodule_search_locations: + ver_file = os.path.join(folder, "version.py") + if os.path.isfile(ver_file): + spec = importlib.util.spec_from_file_location("torch_version_import", ver_file) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + version = module.__version__ +except: + pass + if not args.cuda_malloc: try: - version = "" - torch_spec = importlib.util.find_spec("torch") - for folder in torch_spec.submodule_search_locations: - ver_file = os.path.join(folder, "version.py") - if os.path.isfile(ver_file): - spec = importlib.util.spec_from_file_location("torch_version_import", ver_file) - module = importlib.util.module_from_spec(spec) - spec.loader.exec_module(module) - version = module.__version__ - if int(version[0]) >= 2 and "+cu" in version: # enable by default for torch version 2.0 and up only on cuda torch if PerformanceFeature.AutoTune not in args.fast: # Autotune has issues with cuda malloc args.cuda_malloc = cuda_malloc_supported() @@ -90,3 +94,6 @@ if args.cuda_malloc and not args.disable_cuda_malloc: env_var += ",backend:cudaMallocAsync" os.environ['PYTORCH_CUDA_ALLOC_CONF'] = env_var + +def get_torch_version_noimport(): + return str(version) diff --git a/main.py b/main.py index 0cd815d9e..0d02a087b 100644 --- a/main.py +++ b/main.py @@ -167,6 +167,9 @@ if __name__ == "__main__": os.environ['CUBLAS_WORKSPACE_CONFIG'] = ":4096:8" import cuda_malloc + if "rocm" in cuda_malloc.get_torch_version_noimport(): + os.environ['OCL_SET_SVM_SIZE'] = '262144' # set at the request of AMD + if 'torch' in sys.modules: logging.warning("WARNING: Potential Error in code: Torch already imported, torch should never be imported before this point.") From 76f18e955dcbc88ed13d6802194fd897927f93e5 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Sat, 6 Dec 2025 13:28:08 +0200 Subject: [PATCH 221/231] marked all Pika API nodes a deprecated (#11146) --- comfy_api_nodes/nodes_pika.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/comfy_api_nodes/nodes_pika.py b/comfy_api_nodes/nodes_pika.py index 51148211b..acd88c391 100644 --- a/comfy_api_nodes/nodes_pika.py +++ b/comfy_api_nodes/nodes_pika.py @@ -92,6 +92,7 @@ class PikaImageToVideo(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + is_deprecated=True, ) @classmethod @@ -152,6 +153,7 @@ class PikaTextToVideoNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + is_deprecated=True, ) @classmethod @@ -239,6 +241,7 @@ class PikaScenes(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + is_deprecated=True, ) @classmethod @@ -323,6 +326,7 @@ class PikAdditionsNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + is_deprecated=True, ) @classmethod @@ -399,6 +403,7 @@ class PikaSwapsNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + is_deprecated=True, ) @classmethod @@ -466,6 +471,7 @@ class PikaffectsNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + is_deprecated=True, ) @classmethod @@ -515,6 +521,7 @@ class PikaStartEndFrameNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + is_deprecated=True, ) @classmethod From 7ac7d69d948e75c3a230d1262daab84d75aff895 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jukka=20Sepp=C3=A4nen?= <40791699+kijai@users.noreply.github.com> Date: Sat, 6 Dec 2025 20:09:44 +0200 Subject: [PATCH 222/231] Fix EmptyAudio node input types (#11149) --- comfy_extras/nodes_audio.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/comfy_extras/nodes_audio.py b/comfy_extras/nodes_audio.py index 812301fb7..c7916443c 100644 --- a/comfy_extras/nodes_audio.py +++ b/comfy_extras/nodes_audio.py @@ -573,12 +573,14 @@ class EmptyAudio(IO.ComfyNode): step=0.01, tooltip="Duration of the empty audio clip in seconds", ), - IO.Float.Input( + IO.Int.Input( "sample_rate", default=44100, tooltip="Sample rate of the empty audio clip.", + min=1, + max=192000, ), - IO.Float.Input( + IO.Int.Input( "channels", default=2, min=1, From 50ca97e7765d9bbdbeec31a75f1f6c747d76948c Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sat, 6 Dec 2025 15:36:20 -0800 Subject: [PATCH 223/231] Speed up lora compute and lower memory usage by doing it in fp16. (#11161) --- comfy/model_management.py | 14 ++++++++++++++ comfy/model_patcher.py | 5 +++-- 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index aeddbaefe..40717b1e4 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -1492,6 +1492,20 @@ def extended_fp16_support(): return True +LORA_COMPUTE_DTYPES = {} +def lora_compute_dtype(device): + dtype = LORA_COMPUTE_DTYPES.get(device, None) + if dtype is not None: + return dtype + + if should_use_fp16(device): + dtype = torch.float16 + else: + dtype = torch.float32 + + LORA_COMPUTE_DTYPES[device] = dtype + return dtype + def soft_empty_cache(force=False): global cpu_state if cpu_state == CPUState.MPS: diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 215784874..4f076a6aa 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -614,10 +614,11 @@ class ModelPatcher: if key not in self.backup: self.backup[key] = collections.namedtuple('Dimension', ['weight', 'inplace_update'])(weight.to(device=self.offload_device, copy=inplace_update), inplace_update) + temp_dtype = comfy.model_management.lora_compute_dtype(device_to) if device_to is not None: - temp_weight = comfy.model_management.cast_to_device(weight, device_to, torch.float32, copy=True) + temp_weight = comfy.model_management.cast_to_device(weight, device_to, temp_dtype, copy=True) else: - temp_weight = weight.to(torch.float32, copy=True) + temp_weight = weight.to(temp_dtype, copy=True) if convert_func is not None: temp_weight = convert_func(temp_weight, inplace=True) From 4086acf3c2f0ca3a8861b04f6179fa9f908e3e25 Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Sun, 7 Dec 2025 09:42:09 +1000 Subject: [PATCH 224/231] Fix on-load VRAM OOM (#11144) slow down the CPU on model load to not run ahead. This fixes a VRAM on flux 2 load. I went to try and debug this with the memory trace pickles, which needs --disable-cuda-malloc which made the bug go away. So I tried this synchronize and it worked. The has some very complex interactions with the cuda malloc async and I dont have solid theory on this one yet. Still debugging but this gets us over the OOM for the moment. --- comfy/model_patcher.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 4f076a6aa..5b1ccb824 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -762,6 +762,8 @@ class ModelPatcher: key = "{}.{}".format(n, param) self.unpin_weight(key) self.patch_weight_to_device(key, device_to=device_to) + if comfy.model_management.is_device_cuda(device_to): + torch.cuda.synchronize() logging.debug("lowvram: loaded module regularly {} {}".format(n, m)) m.comfy_patched_weights = True From 329480da5ab32949a411548f821ea60ab3e90dc7 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sat, 6 Dec 2025 17:50:10 -0800 Subject: [PATCH 225/231] Fix qwen scaled fp8 not working with kandinsky. Make basic t2i wf work. (#11162) --- comfy/ldm/kandinsky5/model.py | 8 +++++++- comfy/text_encoders/kandinsky5.py | 12 ++++++------ 2 files changed, 13 insertions(+), 7 deletions(-) diff --git a/comfy/ldm/kandinsky5/model.py b/comfy/ldm/kandinsky5/model.py index a653e02fc..1509de2f8 100644 --- a/comfy/ldm/kandinsky5/model.py +++ b/comfy/ldm/kandinsky5/model.py @@ -387,6 +387,9 @@ class Kandinsky5(nn.Module): return self.out_layer(visual_embed, time_embed) def _forward(self, x, timestep, context, y, time_dim_replace=None, transformer_options={}, **kwargs): + original_dims = x.ndim + if original_dims == 4: + x = x.unsqueeze(2) bs, c, t_len, h, w = x.shape x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size) @@ -397,7 +400,10 @@ class Kandinsky5(nn.Module): freqs = self.rope_encode_3d(t_len, h, w, device=x.device, dtype=x.dtype, transformer_options=transformer_options) freqs_text = self.rope_encode_1d(context.shape[1], device=x.device, dtype=x.dtype, transformer_options=transformer_options) - return self.forward_orig(x, timestep, context, y, freqs, freqs_text, transformer_options=transformer_options, **kwargs) + out = self.forward_orig(x, timestep, context, y, freqs, freqs_text, transformer_options=transformer_options, **kwargs) + if original_dims == 4: + out = out.squeeze(2) + return out def forward(self, x, timestep, context, y, time_dim_replace=None, transformer_options={}, **kwargs): return comfy.patcher_extension.WrapperExecutor.new_class_executor( diff --git a/comfy/text_encoders/kandinsky5.py b/comfy/text_encoders/kandinsky5.py index 22f991c36..be086458c 100644 --- a/comfy/text_encoders/kandinsky5.py +++ b/comfy/text_encoders/kandinsky5.py @@ -24,10 +24,10 @@ class Kandinsky5TokenizerImage(Kandinsky5Tokenizer): class Qwen25_7BVLIModel(sd1_clip.SDClipModel): def __init__(self, device="cpu", layer="hidden", layer_idx=-1, dtype=None, attention_mask=True, model_options={}): - llama_scaled_fp8 = model_options.get("qwen_scaled_fp8", None) - if llama_scaled_fp8 is not None: + llama_quantization_metadata = model_options.get("llama_quantization_metadata", None) + if llama_quantization_metadata is not None: model_options = model_options.copy() - model_options["scaled_fp8"] = llama_scaled_fp8 + model_options["quantization_metadata"] = llama_quantization_metadata super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=Qwen25_7BVLI, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options) @@ -56,12 +56,12 @@ class Kandinsky5TEModel(QwenImageTEModel): else: return super().load_sd(sd) -def te(dtype_llama=None, llama_scaled_fp8=None): +def te(dtype_llama=None, llama_quantization_metadata=None): class Kandinsky5TEModel_(Kandinsky5TEModel): def __init__(self, device="cpu", dtype=None, model_options={}): - if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options: + if llama_quantization_metadata is not None: model_options = model_options.copy() - model_options["qwen_scaled_fp8"] = llama_scaled_fp8 + model_options["llama_quantization_metadata"] = llama_quantization_metadata if dtype_llama is not None: dtype = dtype_llama super().__init__(device=device, dtype=dtype, model_options=model_options) From 56fa7dbe380cb5591c5542f8aa51ce2fc26beedf Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sun, 7 Dec 2025 04:44:55 -0800 Subject: [PATCH 226/231] Properly load the newbie diffusion model. (#11172) There is still one of the text encoders missing and I didn't actually test it. --- comfy/ldm/lumina/model.py | 35 +++++++++++++++++++++++++++++++++++ comfy/model_base.py | 4 ++++ comfy/model_detection.py | 3 +++ 3 files changed, 42 insertions(+) diff --git a/comfy/ldm/lumina/model.py b/comfy/ldm/lumina/model.py index 6c24fed9b..c47df49ca 100644 --- a/comfy/ldm/lumina/model.py +++ b/comfy/ldm/lumina/model.py @@ -377,6 +377,7 @@ class NextDiT(nn.Module): z_image_modulation=False, time_scale=1.0, pad_tokens_multiple=None, + clip_text_dim=None, image_model=None, device=None, dtype=None, @@ -447,6 +448,31 @@ class NextDiT(nn.Module): ), ) + self.clip_text_pooled_proj = None + + if clip_text_dim is not None: + self.clip_text_dim = clip_text_dim + self.clip_text_pooled_proj = nn.Sequential( + operation_settings.get("operations").RMSNorm(clip_text_dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")), + operation_settings.get("operations").Linear( + clip_text_dim, + clip_text_dim, + bias=True, + device=operation_settings.get("device"), + dtype=operation_settings.get("dtype"), + ), + ) + self.time_text_embed = nn.Sequential( + nn.SiLU(), + operation_settings.get("operations").Linear( + min(dim, 1024) + clip_text_dim, + min(dim, 1024), + bias=True, + device=operation_settings.get("device"), + dtype=operation_settings.get("dtype"), + ), + ) + self.layers = nn.ModuleList( [ JointTransformerBlock( @@ -585,6 +611,15 @@ class NextDiT(nn.Module): cap_feats = self.cap_embedder(cap_feats) # (N, L, D) # todo check if able to batchify w.o. redundant compute + if self.clip_text_pooled_proj is not None: + pooled = kwargs.get("clip_text_pooled", None) + if pooled is not None: + pooled = self.clip_text_pooled_proj(pooled) + else: + pooled = torch.zeros((1, self.clip_text_dim), device=x.device, dtype=x.dtype) + + adaln_input = self.time_text_embed(torch.cat((t, pooled), dim=-1)) + patches = transformer_options.get("patches", {}) x_is_tensor = isinstance(x, torch.Tensor) img, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, t, num_tokens, transformer_options=transformer_options) diff --git a/comfy/model_base.py b/comfy/model_base.py index 0be006cc2..6b8a8454d 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1110,6 +1110,10 @@ class Lumina2(BaseModel): if 'num_tokens' not in out: out['num_tokens'] = comfy.conds.CONDConstant(cross_attn.shape[1]) + clip_text_pooled = kwargs["pooled_output"] # Newbie + if clip_text_pooled is not None: + out['clip_text_pooled'] = comfy.conds.CONDRegular(clip_text_pooled) + return out class WAN21(BaseModel): diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 30b33a486..74c547427 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -423,6 +423,9 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["axes_lens"] = [300, 512, 512] dit_config["rope_theta"] = 10000.0 dit_config["ffn_dim_multiplier"] = 4.0 + ctd_weight = state_dict.get('{}clip_text_pooled_proj.0.weight'.format(key_prefix), None) + if ctd_weight is not None: + dit_config["clip_text_dim"] = ctd_weight.shape[0] elif dit_config["dim"] == 3840: # Z image dit_config["n_heads"] = 30 dit_config["n_kv_heads"] = 30 From ec7f65187d85e22ea23345ce0d919e11768f255e Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Mon, 8 Dec 2025 11:21:41 +0200 Subject: [PATCH 227/231] chore(comfy_api): replace absolute imports with relative (#11145) --- comfy_api/latest/__init__.py | 8 ++++---- comfy_api/latest/_input/video_types.py | 2 +- comfy_api/latest/_input_impl/video_types.py | 4 ++-- comfy_api/latest/_io.py | 2 +- comfy_api/latest/_ui.py | 2 +- comfy_api/latest/_util/video_types.py | 2 +- 6 files changed, 10 insertions(+), 10 deletions(-) diff --git a/comfy_api/latest/__init__.py b/comfy_api/latest/__init__.py index 0fa01d1e7..35e1ac853 100644 --- a/comfy_api/latest/__init__.py +++ b/comfy_api/latest/__init__.py @@ -5,9 +5,9 @@ from typing import Type, TYPE_CHECKING from comfy_api.internal import ComfyAPIBase from comfy_api.internal.singleton import ProxiedSingleton from comfy_api.internal.async_to_sync import create_sync_class -from comfy_api.latest._input import ImageInput, AudioInput, MaskInput, LatentInput, VideoInput -from comfy_api.latest._input_impl import VideoFromFile, VideoFromComponents -from comfy_api.latest._util import VideoCodec, VideoContainer, VideoComponents, MESH, VOXEL +from ._input import ImageInput, AudioInput, MaskInput, LatentInput, VideoInput +from ._input_impl import VideoFromFile, VideoFromComponents +from ._util import VideoCodec, VideoContainer, VideoComponents, MESH, VOXEL from . import _io_public as io from . import _ui_public as ui # from comfy_api.latest._resources import _RESOURCES as resources #noqa: F401 @@ -80,7 +80,7 @@ class ComfyExtension(ABC): async def on_load(self) -> None: """ Called when an extension is loaded. - This should be used to initialize any global resources neeeded by the extension. + This should be used to initialize any global resources needed by the extension. """ @abstractmethod diff --git a/comfy_api/latest/_input/video_types.py b/comfy_api/latest/_input/video_types.py index 87c81d73a..e634a0311 100644 --- a/comfy_api/latest/_input/video_types.py +++ b/comfy_api/latest/_input/video_types.py @@ -4,7 +4,7 @@ from fractions import Fraction from typing import Optional, Union, IO import io import av -from comfy_api.util import VideoContainer, VideoCodec, VideoComponents +from .._util import VideoContainer, VideoCodec, VideoComponents class VideoInput(ABC): """ diff --git a/comfy_api/latest/_input_impl/video_types.py b/comfy_api/latest/_input_impl/video_types.py index a4cd3737d..ea35c6062 100644 --- a/comfy_api/latest/_input_impl/video_types.py +++ b/comfy_api/latest/_input_impl/video_types.py @@ -3,14 +3,14 @@ from av.container import InputContainer from av.subtitles.stream import SubtitleStream from fractions import Fraction from typing import Optional -from comfy_api.latest._input import AudioInput, VideoInput +from .._input import AudioInput, VideoInput import av import io import json import numpy as np import math import torch -from comfy_api.latest._util import VideoContainer, VideoCodec, VideoComponents +from .._util import VideoContainer, VideoCodec, VideoComponents def container_to_output_format(container_format: str | None) -> str | None: diff --git a/comfy_api/latest/_io.py b/comfy_api/latest/_io.py index d7cbe68cf..313a5af20 100644 --- a/comfy_api/latest/_io.py +++ b/comfy_api/latest/_io.py @@ -26,7 +26,7 @@ if TYPE_CHECKING: from comfy_api.input import VideoInput from comfy_api.internal import (_ComfyNodeInternal, _NodeOutputInternal, classproperty, copy_class, first_real_override, is_class, prune_dict, shallow_clone_class) -from comfy_api.latest._resources import Resources, ResourcesLocal +from ._resources import Resources, ResourcesLocal from comfy_execution.graph_utils import ExecutionBlocker from ._util import MESH, VOXEL diff --git a/comfy_api/latest/_ui.py b/comfy_api/latest/_ui.py index 5a75a3aae..2babe209a 100644 --- a/comfy_api/latest/_ui.py +++ b/comfy_api/latest/_ui.py @@ -22,7 +22,7 @@ import folder_paths # used for image preview from comfy.cli_args import args -from comfy_api.latest._io import ComfyNode, FolderType, Image, _UIOutput +from ._io import ComfyNode, FolderType, Image, _UIOutput class SavedResult(dict): diff --git a/comfy_api/latest/_util/video_types.py b/comfy_api/latest/_util/video_types.py index c3e3d8e3a..fd3b5a510 100644 --- a/comfy_api/latest/_util/video_types.py +++ b/comfy_api/latest/_util/video_types.py @@ -3,7 +3,7 @@ from dataclasses import dataclass from enum import Enum from fractions import Fraction from typing import Optional -from comfy_api.latest._input import ImageInput, AudioInput +from .._input import ImageInput, AudioInput class VideoCodec(str, Enum): AUTO = "auto" From 058f084371ef2ed0c456118dfdd3d0bfed17259b Mon Sep 17 00:00:00 2001 From: ComfyUI Wiki Date: Mon, 8 Dec 2025 17:22:51 +0800 Subject: [PATCH 228/231] Update workflow templates to v0.7.51 (#11150) * chore: update workflow templates to v0.7.50 * Update template to 0.7.51 --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index f98848e20..12a7c1089 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.33.10 -comfyui-workflow-templates==0.7.25 +comfyui-workflow-templates==0.7.51 comfyui-embedded-docs==0.3.1 torch torchsde From 85c4b4ae262c2de360891dd23c6504da2f5a6014 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Mon, 8 Dec 2025 11:27:02 +0200 Subject: [PATCH 229/231] chore: replace imports of deprecated V1 classes (#11127) --- comfy_api_nodes/apis/veo_api.py | 2 +- comfy_api_nodes/nodes_gemini.py | 19 ++++++++++--------- comfy_api_nodes/nodes_ltxv.py | 17 +++++++---------- comfy_api_nodes/nodes_moonvalley.py | 19 ++++++++----------- comfy_api_nodes/nodes_runway.py | 29 +++++++++++++---------------- comfy_api_nodes/nodes_veo2.py | 12 +++++------- comfy_extras/nodes_video.py | 27 +++++++++++---------------- 7 files changed, 55 insertions(+), 70 deletions(-) diff --git a/comfy_api_nodes/apis/veo_api.py b/comfy_api_nodes/apis/veo_api.py index 8328d1aa4..23ca725b7 100644 --- a/comfy_api_nodes/apis/veo_api.py +++ b/comfy_api_nodes/apis/veo_api.py @@ -85,7 +85,7 @@ class Response1(BaseModel): raiMediaFilteredReasons: Optional[list[str]] = Field( None, description='Reasons why media was filtered by responsible AI policies' ) - videos: Optional[list[Video]] = None + videos: Optional[list[Video]] = Field(None) class VeoGenVidPollResponse(BaseModel): diff --git a/comfy_api_nodes/nodes_gemini.py b/comfy_api_nodes/nodes_gemini.py index 08f7b0f64..0b7422ef7 100644 --- a/comfy_api_nodes/nodes_gemini.py +++ b/comfy_api_nodes/nodes_gemini.py @@ -13,8 +13,7 @@ import torch from typing_extensions import override import folder_paths -from comfy_api.latest import IO, ComfyExtension, Input -from comfy_api.util import VideoCodec, VideoContainer +from comfy_api.latest import IO, ComfyExtension, Input, Types from comfy_api_nodes.apis.gemini_api import ( GeminiContent, GeminiFileData, @@ -68,7 +67,7 @@ class GeminiImageModel(str, Enum): async def create_image_parts( cls: type[IO.ComfyNode], - images: torch.Tensor, + images: Input.Image, image_limit: int = 0, ) -> list[GeminiPart]: image_parts: list[GeminiPart] = [] @@ -154,8 +153,8 @@ def get_text_from_response(response: GeminiGenerateContentResponse) -> str: return "\n".join([part.text for part in parts]) -def get_image_from_response(response: GeminiGenerateContentResponse) -> torch.Tensor: - image_tensors: list[torch.Tensor] = [] +def get_image_from_response(response: GeminiGenerateContentResponse) -> Input.Image: + image_tensors: list[Input.Image] = [] parts = get_parts_by_type(response, "image/png") for part in parts: image_data = base64.b64decode(part.inlineData.data) @@ -293,7 +292,9 @@ class GeminiNode(IO.ComfyNode): def create_video_parts(cls, video_input: Input.Video) -> list[GeminiPart]: """Convert video input to Gemini API compatible parts.""" - base_64_string = video_to_base64_string(video_input, container_format=VideoContainer.MP4, codec=VideoCodec.H264) + base_64_string = video_to_base64_string( + video_input, container_format=Types.VideoContainer.MP4, codec=Types.VideoCodec.H264 + ) return [ GeminiPart( inlineData=GeminiInlineData( @@ -343,7 +344,7 @@ class GeminiNode(IO.ComfyNode): prompt: str, model: str, seed: int, - images: torch.Tensor | None = None, + images: Input.Image | None = None, audio: Input.Audio | None = None, video: Input.Video | None = None, files: list[GeminiPart] | None = None, @@ -542,7 +543,7 @@ class GeminiImage(IO.ComfyNode): prompt: str, model: str, seed: int, - images: torch.Tensor | None = None, + images: Input.Image | None = None, files: list[GeminiPart] | None = None, aspect_ratio: str = "auto", response_modalities: str = "IMAGE+TEXT", @@ -662,7 +663,7 @@ class GeminiImage2(IO.ComfyNode): aspect_ratio: str, resolution: str, response_modalities: str, - images: torch.Tensor | None = None, + images: Input.Image | None = None, files: list[GeminiPart] | None = None, ) -> IO.NodeOutput: validate_string(prompt, strip_whitespace=True, min_length=1) diff --git a/comfy_api_nodes/nodes_ltxv.py b/comfy_api_nodes/nodes_ltxv.py index 0b757a62b..7e61560dc 100644 --- a/comfy_api_nodes/nodes_ltxv.py +++ b/comfy_api_nodes/nodes_ltxv.py @@ -1,12 +1,9 @@ from io import BytesIO -from typing import Optional -import torch from pydantic import BaseModel, Field from typing_extensions import override -from comfy_api.input_impl import VideoFromFile -from comfy_api.latest import IO, ComfyExtension +from comfy_api.latest import IO, ComfyExtension, Input, InputImpl from comfy_api_nodes.util import ( ApiEndpoint, get_number_of_images, @@ -26,9 +23,9 @@ class ExecuteTaskRequest(BaseModel): model: str = Field(...) duration: int = Field(...) resolution: str = Field(...) - fps: Optional[int] = Field(25) - generate_audio: Optional[bool] = Field(True) - image_uri: Optional[str] = Field(None) + fps: int | None = Field(25) + generate_audio: bool | None = Field(True) + image_uri: str | None = Field(None) class TextToVideoNode(IO.ComfyNode): @@ -103,7 +100,7 @@ class TextToVideoNode(IO.ComfyNode): as_binary=True, max_retries=1, ) - return IO.NodeOutput(VideoFromFile(BytesIO(response))) + return IO.NodeOutput(InputImpl.VideoFromFile(BytesIO(response))) class ImageToVideoNode(IO.ComfyNode): @@ -153,7 +150,7 @@ class ImageToVideoNode(IO.ComfyNode): @classmethod async def execute( cls, - image: torch.Tensor, + image: Input.Image, model: str, prompt: str, duration: int, @@ -183,7 +180,7 @@ class ImageToVideoNode(IO.ComfyNode): as_binary=True, max_retries=1, ) - return IO.NodeOutput(VideoFromFile(BytesIO(response))) + return IO.NodeOutput(InputImpl.VideoFromFile(BytesIO(response))) class LtxvApiExtension(ComfyExtension): diff --git a/comfy_api_nodes/nodes_moonvalley.py b/comfy_api_nodes/nodes_moonvalley.py index 7c31d95b3..2771e4790 100644 --- a/comfy_api_nodes/nodes_moonvalley.py +++ b/comfy_api_nodes/nodes_moonvalley.py @@ -1,11 +1,8 @@ import logging -from typing import Optional -import torch from typing_extensions import override -from comfy_api.input import VideoInput -from comfy_api.latest import IO, ComfyExtension +from comfy_api.latest import IO, ComfyExtension, Input from comfy_api_nodes.apis import ( MoonvalleyPromptResponse, MoonvalleyTextToVideoInferenceParams, @@ -61,7 +58,7 @@ def validate_task_creation_response(response) -> None: raise RuntimeError(error_msg) -def validate_video_to_video_input(video: VideoInput) -> VideoInput: +def validate_video_to_video_input(video: Input.Video) -> Input.Video: """ Validates and processes video input for Moonvalley Video-to-Video generation. @@ -82,7 +79,7 @@ def validate_video_to_video_input(video: VideoInput) -> VideoInput: return _validate_and_trim_duration(video) -def _get_video_dimensions(video: VideoInput) -> tuple[int, int]: +def _get_video_dimensions(video: Input.Video) -> tuple[int, int]: """Extracts video dimensions with error handling.""" try: return video.get_dimensions() @@ -106,7 +103,7 @@ def _validate_video_dimensions(width: int, height: int) -> None: raise ValueError(f"Resolution {width}x{height} not supported. Supported: {supported_list}") -def _validate_and_trim_duration(video: VideoInput) -> VideoInput: +def _validate_and_trim_duration(video: Input.Video) -> Input.Video: """Validates video duration and trims to 5 seconds if needed.""" duration = video.get_duration() _validate_minimum_duration(duration) @@ -119,7 +116,7 @@ def _validate_minimum_duration(duration: float) -> None: raise ValueError("Input video must be at least 5 seconds long.") -def _trim_if_too_long(video: VideoInput, duration: float) -> VideoInput: +def _trim_if_too_long(video: Input.Video, duration: float) -> Input.Video: """Trims video to 5 seconds if longer.""" if duration > 5: return trim_video(video, 5) @@ -241,7 +238,7 @@ class MoonvalleyImg2VideoNode(IO.ComfyNode): @classmethod async def execute( cls, - image: torch.Tensor, + image: Input.Image, prompt: str, negative_prompt: str, resolution: str, @@ -362,9 +359,9 @@ class MoonvalleyVideo2VideoNode(IO.ComfyNode): prompt: str, negative_prompt: str, seed: int, - video: Optional[VideoInput] = None, + video: Input.Video | None = None, control_type: str = "Motion Transfer", - motion_intensity: Optional[int] = 100, + motion_intensity: int | None = 100, steps=33, prompt_adherence=4.5, ) -> IO.NodeOutput: diff --git a/comfy_api_nodes/nodes_runway.py b/comfy_api_nodes/nodes_runway.py index 2fdafbbfe..3c55039c9 100644 --- a/comfy_api_nodes/nodes_runway.py +++ b/comfy_api_nodes/nodes_runway.py @@ -11,12 +11,11 @@ User Guides: """ -from typing import Union, Optional -from typing_extensions import override from enum import Enum -import torch +from typing_extensions import override +from comfy_api.latest import IO, ComfyExtension, Input, InputImpl from comfy_api_nodes.apis import ( RunwayImageToVideoRequest, RunwayImageToVideoResponse, @@ -44,8 +43,6 @@ from comfy_api_nodes.util import ( sync_op, poll_op, ) -from comfy_api.input_impl import VideoFromFile -from comfy_api.latest import ComfyExtension, IO PATH_IMAGE_TO_VIDEO = "/proxy/runway/image_to_video" PATH_TEXT_TO_IMAGE = "/proxy/runway/text_to_image" @@ -80,7 +77,7 @@ class RunwayGen3aAspectRatio(str, Enum): field_1280_768 = "1280:768" -def get_video_url_from_task_status(response: TaskStatusResponse) -> Union[str, None]: +def get_video_url_from_task_status(response: TaskStatusResponse) -> str | None: """Returns the video URL from the task status response if it exists.""" if hasattr(response, "output") and len(response.output) > 0: return response.output[0] @@ -89,13 +86,13 @@ def get_video_url_from_task_status(response: TaskStatusResponse) -> Union[str, N def extract_progress_from_task_status( response: TaskStatusResponse, -) -> Union[float, None]: +) -> float | None: if hasattr(response, "progress") and response.progress is not None: return response.progress * 100 return None -def get_image_url_from_task_status(response: TaskStatusResponse) -> Union[str, None]: +def get_image_url_from_task_status(response: TaskStatusResponse) -> str | None: """Returns the image URL from the task status response if it exists.""" if hasattr(response, "output") and len(response.output) > 0: return response.output[0] @@ -103,7 +100,7 @@ def get_image_url_from_task_status(response: TaskStatusResponse) -> Union[str, N async def get_response( - cls: type[IO.ComfyNode], task_id: str, estimated_duration: Optional[int] = None + cls: type[IO.ComfyNode], task_id: str, estimated_duration: int | None = None ) -> TaskStatusResponse: """Poll the task status until it is finished then get the response.""" return await poll_op( @@ -119,8 +116,8 @@ async def get_response( async def generate_video( cls: type[IO.ComfyNode], request: RunwayImageToVideoRequest, - estimated_duration: Optional[int] = None, -) -> VideoFromFile: + estimated_duration: int | None = None, +) -> InputImpl.VideoFromFile: initial_response = await sync_op( cls, endpoint=ApiEndpoint(path=PATH_IMAGE_TO_VIDEO, method="POST"), @@ -193,7 +190,7 @@ class RunwayImageToVideoNodeGen3a(IO.ComfyNode): async def execute( cls, prompt: str, - start_frame: torch.Tensor, + start_frame: Input.Image, duration: str, ratio: str, seed: int, @@ -283,7 +280,7 @@ class RunwayImageToVideoNodeGen4(IO.ComfyNode): async def execute( cls, prompt: str, - start_frame: torch.Tensor, + start_frame: Input.Image, duration: str, ratio: str, seed: int, @@ -381,8 +378,8 @@ class RunwayFirstLastFrameNode(IO.ComfyNode): async def execute( cls, prompt: str, - start_frame: torch.Tensor, - end_frame: torch.Tensor, + start_frame: Input.Image, + end_frame: Input.Image, duration: str, ratio: str, seed: int, @@ -467,7 +464,7 @@ class RunwayTextToImageNode(IO.ComfyNode): cls, prompt: str, ratio: str, - reference_image: Optional[torch.Tensor] = None, + reference_image: Input.Image | None = None, ) -> IO.NodeOutput: validate_string(prompt, min_length=1) diff --git a/comfy_api_nodes/nodes_veo2.py b/comfy_api_nodes/nodes_veo2.py index a54dc13ab..e165b8380 100644 --- a/comfy_api_nodes/nodes_veo2.py +++ b/comfy_api_nodes/nodes_veo2.py @@ -1,11 +1,9 @@ import base64 from io import BytesIO -import torch from typing_extensions import override -from comfy_api.input_impl.video_types import VideoFromFile -from comfy_api.latest import IO, ComfyExtension +from comfy_api.latest import IO, ComfyExtension, Input, InputImpl from comfy_api_nodes.apis.veo_api import ( VeoGenVidPollRequest, VeoGenVidPollResponse, @@ -232,7 +230,7 @@ class VeoVideoGenerationNode(IO.ComfyNode): # Check if video is provided as base64 or URL if hasattr(video, "bytesBase64Encoded") and video.bytesBase64Encoded: - return IO.NodeOutput(VideoFromFile(BytesIO(base64.b64decode(video.bytesBase64Encoded)))) + return IO.NodeOutput(InputImpl.VideoFromFile(BytesIO(base64.b64decode(video.bytesBase64Encoded)))) if hasattr(video, "gcsUri") and video.gcsUri: return IO.NodeOutput(await download_url_to_video_output(video.gcsUri)) @@ -431,8 +429,8 @@ class Veo3FirstLastFrameNode(IO.ComfyNode): aspect_ratio: str, duration: int, seed: int, - first_frame: torch.Tensor, - last_frame: torch.Tensor, + first_frame: Input.Image, + last_frame: Input.Image, model: str, generate_audio: bool, ): @@ -493,7 +491,7 @@ class Veo3FirstLastFrameNode(IO.ComfyNode): if response.videos: video = response.videos[0] if video.bytesBase64Encoded: - return IO.NodeOutput(VideoFromFile(BytesIO(base64.b64decode(video.bytesBase64Encoded)))) + return IO.NodeOutput(InputImpl.VideoFromFile(BytesIO(base64.b64decode(video.bytesBase64Encoded)))) if video.gcsUri: return IO.NodeOutput(await download_url_to_video_output(video.gcsUri)) raise Exception("Video returned but no data or URL was provided") diff --git a/comfy_extras/nodes_video.py b/comfy_extras/nodes_video.py index 6cf6e39bf..c609e03da 100644 --- a/comfy_extras/nodes_video.py +++ b/comfy_extras/nodes_video.py @@ -8,10 +8,7 @@ import json from typing import Optional from typing_extensions import override from fractions import Fraction -from comfy_api.input import AudioInput, ImageInput, VideoInput -from comfy_api.input_impl import VideoFromComponents, VideoFromFile -from comfy_api.util import VideoCodec, VideoComponents, VideoContainer -from comfy_api.latest import ComfyExtension, io, ui +from comfy_api.latest import ComfyExtension, io, ui, Input, InputImpl, Types from comfy.cli_args import args class SaveWEBM(io.ComfyNode): @@ -28,7 +25,6 @@ class SaveWEBM(io.ComfyNode): io.Float.Input("fps", default=24.0, min=0.01, max=1000.0, step=0.01), io.Float.Input("crf", default=32.0, min=0, max=63.0, step=1, tooltip="Higher crf means lower quality with a smaller file size, lower crf means higher quality higher filesize."), ], - outputs=[], hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo], is_output_node=True, ) @@ -79,16 +75,15 @@ class SaveVideo(io.ComfyNode): inputs=[ io.Video.Input("video", tooltip="The video to save."), io.String.Input("filename_prefix", default="video/ComfyUI", tooltip="The prefix for the file to save. This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes."), - io.Combo.Input("format", options=VideoContainer.as_input(), default="auto", tooltip="The format to save the video as."), - io.Combo.Input("codec", options=VideoCodec.as_input(), default="auto", tooltip="The codec to use for the video."), + io.Combo.Input("format", options=Types.VideoContainer.as_input(), default="auto", tooltip="The format to save the video as."), + io.Combo.Input("codec", options=Types.VideoCodec.as_input(), default="auto", tooltip="The codec to use for the video."), ], - outputs=[], hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo], is_output_node=True, ) @classmethod - def execute(cls, video: VideoInput, filename_prefix, format: str, codec) -> io.NodeOutput: + def execute(cls, video: Input.Video, filename_prefix, format: str, codec) -> io.NodeOutput: width, height = video.get_dimensions() full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path( filename_prefix, @@ -105,10 +100,10 @@ class SaveVideo(io.ComfyNode): metadata["prompt"] = cls.hidden.prompt if len(metadata) > 0: saved_metadata = metadata - file = f"{filename}_{counter:05}_.{VideoContainer.get_extension(format)}" + file = f"{filename}_{counter:05}_.{Types.VideoContainer.get_extension(format)}" video.save_to( os.path.join(full_output_folder, file), - format=VideoContainer(format), + format=Types.VideoContainer(format), codec=codec, metadata=saved_metadata ) @@ -135,9 +130,9 @@ class CreateVideo(io.ComfyNode): ) @classmethod - def execute(cls, images: ImageInput, fps: float, audio: Optional[AudioInput] = None) -> io.NodeOutput: + def execute(cls, images: Input.Image, fps: float, audio: Optional[Input.Audio] = None) -> io.NodeOutput: return io.NodeOutput( - VideoFromComponents(VideoComponents(images=images, audio=audio, frame_rate=Fraction(fps))) + InputImpl.VideoFromComponents(Types.VideoComponents(images=images, audio=audio, frame_rate=Fraction(fps))) ) class GetVideoComponents(io.ComfyNode): @@ -159,11 +154,11 @@ class GetVideoComponents(io.ComfyNode): ) @classmethod - def execute(cls, video: VideoInput) -> io.NodeOutput: + def execute(cls, video: Input.Video) -> io.NodeOutput: components = video.get_components() - return io.NodeOutput(components.images, components.audio, float(components.frame_rate)) + class LoadVideo(io.ComfyNode): @classmethod def define_schema(cls): @@ -185,7 +180,7 @@ class LoadVideo(io.ComfyNode): @classmethod def execute(cls, file) -> io.NodeOutput: video_path = folder_paths.get_annotated_filepath(file) - return io.NodeOutput(VideoFromFile(video_path)) + return io.NodeOutput(InputImpl.VideoFromFile(video_path)) @classmethod def fingerprint_inputs(s, file): From c3c6313fc7b24a5811efde7cfe10b7cbbea52663 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Mon, 8 Dec 2025 11:28:17 +0200 Subject: [PATCH 230/231] Added "system_prompt" input to Gemini nodes (#11177) --- comfy_api_nodes/apis/gemini_api.py | 10 +----- comfy_api_nodes/nodes_gemini.py | 52 ++++++++++++++++++++++++++++-- 2 files changed, 51 insertions(+), 11 deletions(-) diff --git a/comfy_api_nodes/apis/gemini_api.py b/comfy_api_nodes/apis/gemini_api.py index a380ecc86..f8edc38c9 100644 --- a/comfy_api_nodes/apis/gemini_api.py +++ b/comfy_api_nodes/apis/gemini_api.py @@ -84,15 +84,7 @@ class GeminiSystemInstructionContent(BaseModel): description="A list of ordered parts that make up a single message. " "Different parts may have different IANA MIME types.", ) - role: GeminiRole = Field( - ..., - description="The identity of the entity that creates the message. " - "The following values are supported: " - "user: This indicates that the message is sent by a real person, typically a user-generated message. " - "model: This indicates that the message is generated by the model. " - "The model value is used to insert messages from model into the conversation during multi-turn conversations. " - "For non-multi-turn conversations, this field can be left blank or unset.", - ) + role: GeminiRole | None = Field(..., description="The role field of systemInstruction may be ignored.") class GeminiFunctionDeclaration(BaseModel): diff --git a/comfy_api_nodes/nodes_gemini.py b/comfy_api_nodes/nodes_gemini.py index 0b7422ef7..ad0f4b4d1 100644 --- a/comfy_api_nodes/nodes_gemini.py +++ b/comfy_api_nodes/nodes_gemini.py @@ -26,6 +26,8 @@ from comfy_api_nodes.apis.gemini_api import ( GeminiMimeType, GeminiPart, GeminiRole, + GeminiSystemInstructionContent, + GeminiTextPart, Modality, ) from comfy_api_nodes.util import ( @@ -42,6 +44,14 @@ from comfy_api_nodes.util import ( GEMINI_BASE_ENDPOINT = "/proxy/vertexai/gemini" GEMINI_MAX_INPUT_FILE_SIZE = 20 * 1024 * 1024 # 20 MB +GEMINI_IMAGE_SYS_PROMPT = ( + "You are an expert image-generation engine. You must ALWAYS produce an image.\n" + "Interpret all user input—regardless of " + "format, intent, or abstraction—as literal visual directives for image composition.\n" + "If a prompt is conversational or lacks specific visual details, " + "you must creatively invent a concrete visual scenario that depicts the concept.\n" + "Prioritize generating the visual representation above any text, formatting, or conversational requests." +) class GeminiModel(str, Enum): @@ -276,6 +286,13 @@ class GeminiNode(IO.ComfyNode): tooltip="Optional file(s) to use as context for the model. " "Accepts inputs from the Gemini Generate Content Input Files node.", ), + IO.String.Input( + "system_prompt", + multiline=True, + default="", + optional=True, + tooltip="Foundational instructions that dictate an AI's behavior.", + ), ], outputs=[ IO.String.Output(), @@ -348,6 +365,7 @@ class GeminiNode(IO.ComfyNode): audio: Input.Audio | None = None, video: Input.Video | None = None, files: list[GeminiPart] | None = None, + system_prompt: str = "", ) -> IO.NodeOutput: validate_string(prompt, strip_whitespace=False) @@ -364,7 +382,10 @@ class GeminiNode(IO.ComfyNode): if files is not None: parts.extend(files) - # Create response + gemini_system_prompt = None + if system_prompt: + gemini_system_prompt = GeminiSystemInstructionContent(parts=[GeminiTextPart(text=system_prompt)], role=None) + response = await sync_op( cls, endpoint=ApiEndpoint(path=f"{GEMINI_BASE_ENDPOINT}/{model}", method="POST"), @@ -374,7 +395,8 @@ class GeminiNode(IO.ComfyNode): role=GeminiRole.user, parts=parts, ) - ] + ], + systemInstruction=gemini_system_prompt, ), response_model=GeminiGenerateContentResponse, price_extractor=calculate_tokens_price, @@ -524,6 +546,13 @@ class GeminiImage(IO.ComfyNode): "'IMAGE+TEXT' to return both the generated image and a text response.", optional=True, ), + IO.String.Input( + "system_prompt", + multiline=True, + default=GEMINI_IMAGE_SYS_PROMPT, + optional=True, + tooltip="Foundational instructions that dictate an AI's behavior.", + ), ], outputs=[ IO.Image.Output(), @@ -547,6 +576,7 @@ class GeminiImage(IO.ComfyNode): files: list[GeminiPart] | None = None, aspect_ratio: str = "auto", response_modalities: str = "IMAGE+TEXT", + system_prompt: str = "", ) -> IO.NodeOutput: validate_string(prompt, strip_whitespace=True, min_length=1) parts: list[GeminiPart] = [GeminiPart(text=prompt)] @@ -560,6 +590,10 @@ class GeminiImage(IO.ComfyNode): if files is not None: parts.extend(files) + gemini_system_prompt = None + if system_prompt: + gemini_system_prompt = GeminiSystemInstructionContent(parts=[GeminiTextPart(text=system_prompt)], role=None) + response = await sync_op( cls, endpoint=ApiEndpoint(path=f"{GEMINI_BASE_ENDPOINT}/{model}", method="POST"), @@ -571,6 +605,7 @@ class GeminiImage(IO.ComfyNode): responseModalities=(["IMAGE"] if response_modalities == "IMAGE" else ["TEXT", "IMAGE"]), imageConfig=None if aspect_ratio == "auto" else image_config, ), + systemInstruction=gemini_system_prompt, ), response_model=GeminiGenerateContentResponse, price_extractor=calculate_tokens_price, @@ -641,6 +676,13 @@ class GeminiImage2(IO.ComfyNode): tooltip="Optional file(s) to use as context for the model. " "Accepts inputs from the Gemini Generate Content Input Files node.", ), + IO.String.Input( + "system_prompt", + multiline=True, + default=GEMINI_IMAGE_SYS_PROMPT, + optional=True, + tooltip="Foundational instructions that dictate an AI's behavior.", + ), ], outputs=[ IO.Image.Output(), @@ -665,6 +707,7 @@ class GeminiImage2(IO.ComfyNode): response_modalities: str, images: Input.Image | None = None, files: list[GeminiPart] | None = None, + system_prompt: str = "", ) -> IO.NodeOutput: validate_string(prompt, strip_whitespace=True, min_length=1) @@ -680,6 +723,10 @@ class GeminiImage2(IO.ComfyNode): if aspect_ratio != "auto": image_config.aspectRatio = aspect_ratio + gemini_system_prompt = None + if system_prompt: + gemini_system_prompt = GeminiSystemInstructionContent(parts=[GeminiTextPart(text=system_prompt)], role=None) + response = await sync_op( cls, ApiEndpoint(path=f"{GEMINI_BASE_ENDPOINT}/{model}", method="POST"), @@ -691,6 +738,7 @@ class GeminiImage2(IO.ComfyNode): responseModalities=(["IMAGE"] if response_modalities == "IMAGE" else ["TEXT", "IMAGE"]), imageConfig=image_config, ), + systemInstruction=gemini_system_prompt, ), response_model=GeminiGenerateContentResponse, price_extractor=calculate_tokens_price, From fd271dedfde6e192a1f1a025521070876e89e04a Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Mon, 8 Dec 2025 11:33:46 +0200 Subject: [PATCH 231/231] [API Nodes] add support for seedance-1-0-pro-fast model (#10947) * feat(api-nodes): add support for seedance-1-0-pro-fast model * feat(api-nodes): add support for seedream-4.5 model --- comfy_api_nodes/apis/bytedance_api.py | 144 +++++++++++++++ comfy_api_nodes/nodes_bytedance.py | 255 ++++++-------------------- 2 files changed, 196 insertions(+), 203 deletions(-) create mode 100644 comfy_api_nodes/apis/bytedance_api.py diff --git a/comfy_api_nodes/apis/bytedance_api.py b/comfy_api_nodes/apis/bytedance_api.py new file mode 100644 index 000000000..77cd76f9b --- /dev/null +++ b/comfy_api_nodes/apis/bytedance_api.py @@ -0,0 +1,144 @@ +from typing import Literal + +from pydantic import BaseModel, Field + + +class Text2ImageTaskCreationRequest(BaseModel): + model: str = Field(...) + prompt: str = Field(...) + response_format: str | None = Field("url") + size: str | None = Field(None) + seed: int | None = Field(0, ge=0, le=2147483647) + guidance_scale: float | None = Field(..., ge=1.0, le=10.0) + watermark: bool | None = Field(True) + + +class Image2ImageTaskCreationRequest(BaseModel): + model: str = Field(...) + prompt: str = Field(...) + response_format: str | None = Field("url") + image: str = Field(..., description="Base64 encoded string or image URL") + size: str | None = Field("adaptive") + seed: int | None = Field(..., ge=0, le=2147483647) + guidance_scale: float | None = Field(..., ge=1.0, le=10.0) + watermark: bool | None = Field(True) + + +class Seedream4Options(BaseModel): + max_images: int = Field(15) + + +class Seedream4TaskCreationRequest(BaseModel): + model: str = Field(...) + prompt: str = Field(...) + response_format: str = Field("url") + image: list[str] | None = Field(None, description="Image URLs") + size: str = Field(...) + seed: int = Field(..., ge=0, le=2147483647) + sequential_image_generation: str = Field("disabled") + sequential_image_generation_options: Seedream4Options = Field(Seedream4Options(max_images=15)) + watermark: bool = Field(True) + + +class ImageTaskCreationResponse(BaseModel): + model: str = Field(...) + created: int = Field(..., description="Unix timestamp (in seconds) indicating time when the request was created.") + data: list = Field([], description="Contains information about the generated image(s).") + error: dict = Field({}, description="Contains `code` and `message` fields in case of error.") + + +class TaskTextContent(BaseModel): + type: str = Field("text") + text: str = Field(...) + + +class TaskImageContentUrl(BaseModel): + url: str = Field(...) + + +class TaskImageContent(BaseModel): + type: str = Field("image_url") + image_url: TaskImageContentUrl = Field(...) + role: Literal["first_frame", "last_frame", "reference_image"] | None = Field(None) + + +class Text2VideoTaskCreationRequest(BaseModel): + model: str = Field(...) + content: list[TaskTextContent] = Field(..., min_length=1) + + +class Image2VideoTaskCreationRequest(BaseModel): + model: str = Field(...) + content: list[TaskTextContent | TaskImageContent] = Field(..., min_length=2) + + +class TaskCreationResponse(BaseModel): + id: str = Field(...) + + +class TaskStatusError(BaseModel): + code: str = Field(...) + message: str = Field(...) + + +class TaskStatusResult(BaseModel): + video_url: str = Field(...) + + +class TaskStatusResponse(BaseModel): + id: str = Field(...) + model: str = Field(...) + status: Literal["queued", "running", "cancelled", "succeeded", "failed"] = Field(...) + error: TaskStatusError | None = Field(None) + content: TaskStatusResult | None = Field(None) + + +RECOMMENDED_PRESETS = [ + ("1024x1024 (1:1)", 1024, 1024), + ("864x1152 (3:4)", 864, 1152), + ("1152x864 (4:3)", 1152, 864), + ("1280x720 (16:9)", 1280, 720), + ("720x1280 (9:16)", 720, 1280), + ("832x1248 (2:3)", 832, 1248), + ("1248x832 (3:2)", 1248, 832), + ("1512x648 (21:9)", 1512, 648), + ("2048x2048 (1:1)", 2048, 2048), + ("Custom", None, None), +] + +RECOMMENDED_PRESETS_SEEDREAM_4 = [ + ("2048x2048 (1:1)", 2048, 2048), + ("2304x1728 (4:3)", 2304, 1728), + ("1728x2304 (3:4)", 1728, 2304), + ("2560x1440 (16:9)", 2560, 1440), + ("1440x2560 (9:16)", 1440, 2560), + ("2496x1664 (3:2)", 2496, 1664), + ("1664x2496 (2:3)", 1664, 2496), + ("3024x1296 (21:9)", 3024, 1296), + ("4096x4096 (1:1)", 4096, 4096), + ("Custom", None, None), +] + +# The time in this dictionary are given for 10 seconds duration. +VIDEO_TASKS_EXECUTION_TIME = { + "seedance-1-0-lite-t2v-250428": { + "480p": 40, + "720p": 60, + "1080p": 90, + }, + "seedance-1-0-lite-i2v-250428": { + "480p": 40, + "720p": 60, + "1080p": 90, + }, + "seedance-1-0-pro-250528": { + "480p": 70, + "720p": 85, + "1080p": 115, + }, + "seedance-1-0-pro-fast-251015": { + "480p": 50, + "720p": 65, + "1080p": 100, + }, +} diff --git a/comfy_api_nodes/nodes_bytedance.py b/comfy_api_nodes/nodes_bytedance.py index caced471e..57c0218d0 100644 --- a/comfy_api_nodes/nodes_bytedance.py +++ b/comfy_api_nodes/nodes_bytedance.py @@ -1,13 +1,27 @@ import logging import math -from enum import Enum -from typing import Literal, Optional, Union import torch -from pydantic import BaseModel, Field from typing_extensions import override -from comfy_api.latest import IO, ComfyExtension +from comfy_api.latest import IO, ComfyExtension, Input +from comfy_api_nodes.apis.bytedance_api import ( + RECOMMENDED_PRESETS, + RECOMMENDED_PRESETS_SEEDREAM_4, + VIDEO_TASKS_EXECUTION_TIME, + Image2ImageTaskCreationRequest, + Image2VideoTaskCreationRequest, + ImageTaskCreationResponse, + Seedream4Options, + Seedream4TaskCreationRequest, + TaskCreationResponse, + TaskImageContent, + TaskImageContentUrl, + TaskStatusResponse, + TaskTextContent, + Text2ImageTaskCreationRequest, + Text2VideoTaskCreationRequest, +) from comfy_api_nodes.util import ( ApiEndpoint, download_url_to_image_tensor, @@ -29,162 +43,6 @@ BYTEPLUS_TASK_ENDPOINT = "/proxy/byteplus/api/v3/contents/generations/tasks" BYTEPLUS_TASK_STATUS_ENDPOINT = "/proxy/byteplus/api/v3/contents/generations/tasks" # + /{task_id} -class Text2ImageModelName(str, Enum): - seedream_3 = "seedream-3-0-t2i-250415" - - -class Image2ImageModelName(str, Enum): - seededit_3 = "seededit-3-0-i2i-250628" - - -class Text2VideoModelName(str, Enum): - seedance_1_pro = "seedance-1-0-pro-250528" - seedance_1_lite = "seedance-1-0-lite-t2v-250428" - - -class Image2VideoModelName(str, Enum): - """note(August 31): Pro model only supports FirstFrame: https://docs.byteplus.com/en/docs/ModelArk/1520757""" - - seedance_1_pro = "seedance-1-0-pro-250528" - seedance_1_lite = "seedance-1-0-lite-i2v-250428" - - -class Text2ImageTaskCreationRequest(BaseModel): - model: Text2ImageModelName = Text2ImageModelName.seedream_3 - prompt: str = Field(...) - response_format: Optional[str] = Field("url") - size: Optional[str] = Field(None) - seed: Optional[int] = Field(0, ge=0, le=2147483647) - guidance_scale: Optional[float] = Field(..., ge=1.0, le=10.0) - watermark: Optional[bool] = Field(True) - - -class Image2ImageTaskCreationRequest(BaseModel): - model: Image2ImageModelName = Image2ImageModelName.seededit_3 - prompt: str = Field(...) - response_format: Optional[str] = Field("url") - image: str = Field(..., description="Base64 encoded string or image URL") - size: Optional[str] = Field("adaptive") - seed: Optional[int] = Field(..., ge=0, le=2147483647) - guidance_scale: Optional[float] = Field(..., ge=1.0, le=10.0) - watermark: Optional[bool] = Field(True) - - -class Seedream4Options(BaseModel): - max_images: int = Field(15) - - -class Seedream4TaskCreationRequest(BaseModel): - model: str = Field("seedream-4-0-250828") - prompt: str = Field(...) - response_format: str = Field("url") - image: Optional[list[str]] = Field(None, description="Image URLs") - size: str = Field(...) - seed: int = Field(..., ge=0, le=2147483647) - sequential_image_generation: str = Field("disabled") - sequential_image_generation_options: Seedream4Options = Field(Seedream4Options(max_images=15)) - watermark: bool = Field(True) - - -class ImageTaskCreationResponse(BaseModel): - model: str = Field(...) - created: int = Field(..., description="Unix timestamp (in seconds) indicating time when the request was created.") - data: list = Field([], description="Contains information about the generated image(s).") - error: dict = Field({}, description="Contains `code` and `message` fields in case of error.") - - -class TaskTextContent(BaseModel): - type: str = Field("text") - text: str = Field(...) - - -class TaskImageContentUrl(BaseModel): - url: str = Field(...) - - -class TaskImageContent(BaseModel): - type: str = Field("image_url") - image_url: TaskImageContentUrl = Field(...) - role: Optional[Literal["first_frame", "last_frame", "reference_image"]] = Field(None) - - -class Text2VideoTaskCreationRequest(BaseModel): - model: Text2VideoModelName = Text2VideoModelName.seedance_1_pro - content: list[TaskTextContent] = Field(..., min_length=1) - - -class Image2VideoTaskCreationRequest(BaseModel): - model: Image2VideoModelName = Image2VideoModelName.seedance_1_pro - content: list[Union[TaskTextContent, TaskImageContent]] = Field(..., min_length=2) - - -class TaskCreationResponse(BaseModel): - id: str = Field(...) - - -class TaskStatusError(BaseModel): - code: str = Field(...) - message: str = Field(...) - - -class TaskStatusResult(BaseModel): - video_url: str = Field(...) - - -class TaskStatusResponse(BaseModel): - id: str = Field(...) - model: str = Field(...) - status: Literal["queued", "running", "cancelled", "succeeded", "failed"] = Field(...) - error: Optional[TaskStatusError] = Field(None) - content: Optional[TaskStatusResult] = Field(None) - - -RECOMMENDED_PRESETS = [ - ("1024x1024 (1:1)", 1024, 1024), - ("864x1152 (3:4)", 864, 1152), - ("1152x864 (4:3)", 1152, 864), - ("1280x720 (16:9)", 1280, 720), - ("720x1280 (9:16)", 720, 1280), - ("832x1248 (2:3)", 832, 1248), - ("1248x832 (3:2)", 1248, 832), - ("1512x648 (21:9)", 1512, 648), - ("2048x2048 (1:1)", 2048, 2048), - ("Custom", None, None), -] - -RECOMMENDED_PRESETS_SEEDREAM_4 = [ - ("2048x2048 (1:1)", 2048, 2048), - ("2304x1728 (4:3)", 2304, 1728), - ("1728x2304 (3:4)", 1728, 2304), - ("2560x1440 (16:9)", 2560, 1440), - ("1440x2560 (9:16)", 1440, 2560), - ("2496x1664 (3:2)", 2496, 1664), - ("1664x2496 (2:3)", 1664, 2496), - ("3024x1296 (21:9)", 3024, 1296), - ("4096x4096 (1:1)", 4096, 4096), - ("Custom", None, None), -] - -# The time in this dictionary are given for 10 seconds duration. -VIDEO_TASKS_EXECUTION_TIME = { - "seedance-1-0-lite-t2v-250428": { - "480p": 40, - "720p": 60, - "1080p": 90, - }, - "seedance-1-0-lite-i2v-250428": { - "480p": 40, - "720p": 60, - "1080p": 90, - }, - "seedance-1-0-pro-250528": { - "480p": 70, - "720p": 85, - "1080p": 115, - }, -} - - def get_image_url_from_response(response: ImageTaskCreationResponse) -> str: if response.error: error_msg = f"ByteDance request failed. Code: {response.error['code']}, message: {response.error['message']}" @@ -194,13 +52,6 @@ def get_image_url_from_response(response: ImageTaskCreationResponse) -> str: return response.data[0]["url"] -def get_video_url_from_task_status(response: TaskStatusResponse) -> Union[str, None]: - """Returns the video URL from the task status response if it exists.""" - if hasattr(response, "content") and response.content: - return response.content.video_url - return None - - class ByteDanceImageNode(IO.ComfyNode): @classmethod @@ -211,12 +62,7 @@ class ByteDanceImageNode(IO.ComfyNode): category="api node/image/ByteDance", description="Generate images using ByteDance models via api based on prompt", inputs=[ - IO.Combo.Input( - "model", - options=Text2ImageModelName, - default=Text2ImageModelName.seedream_3, - tooltip="Model name", - ), + IO.Combo.Input("model", options=["seedream-3-0-t2i-250415"]), IO.String.Input( "prompt", multiline=True, @@ -335,12 +181,7 @@ class ByteDanceImageEditNode(IO.ComfyNode): category="api node/image/ByteDance", description="Edit images using ByteDance models via api based on prompt", inputs=[ - IO.Combo.Input( - "model", - options=Image2ImageModelName, - default=Image2ImageModelName.seededit_3, - tooltip="Model name", - ), + IO.Combo.Input("model", options=["seededit-3-0-i2i-250628"]), IO.Image.Input( "image", tooltip="The base image to edit", @@ -394,7 +235,7 @@ class ByteDanceImageEditNode(IO.ComfyNode): async def execute( cls, model: str, - image: torch.Tensor, + image: Input.Image, prompt: str, seed: int, guidance_scale: float, @@ -434,7 +275,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode): inputs=[ IO.Combo.Input( "model", - options=["seedream-4-0-250828"], + options=["seedream-4-5-251128", "seedream-4-0-250828"], tooltip="Model name", ), IO.String.Input( @@ -459,7 +300,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode): default=2048, min=1024, max=4096, - step=64, + step=8, tooltip="Custom width for image. Value is working only if `size_preset` is set to `Custom`", optional=True, ), @@ -468,7 +309,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode): default=2048, min=1024, max=4096, - step=64, + step=8, tooltip="Custom height for image. Value is working only if `size_preset` is set to `Custom`", optional=True, ), @@ -532,7 +373,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode): cls, model: str, prompt: str, - image: torch.Tensor = None, + image: Input.Image | None = None, size_preset: str = RECOMMENDED_PRESETS_SEEDREAM_4[0][0], width: int = 2048, height: int = 2048, @@ -555,6 +396,18 @@ class ByteDanceSeedreamNode(IO.ComfyNode): raise ValueError( f"Custom size out of range: {w}x{h}. " "Both width and height must be between 1024 and 4096 pixels." ) + out_num_pixels = w * h + mp_provided = out_num_pixels / 1_000_000.0 + if "seedream-4-5" in model and out_num_pixels < 3686400: + raise ValueError( + f"Minimum image resolution that Seedream 4.5 can generate is 3.68MP, " + f"but {mp_provided:.2f}MP provided." + ) + if "seedream-4-0" in model and out_num_pixels < 921600: + raise ValueError( + f"Minimum image resolution that the selected model can generate is 0.92MP, " + f"but {mp_provided:.2f}MP provided." + ) n_input_images = get_number_of_images(image) if image is not None else 0 if n_input_images > 10: raise ValueError(f"Maximum of 10 reference images are supported, but {n_input_images} received.") @@ -607,9 +460,8 @@ class ByteDanceTextToVideoNode(IO.ComfyNode): inputs=[ IO.Combo.Input( "model", - options=Text2VideoModelName, - default=Text2VideoModelName.seedance_1_pro, - tooltip="Model name", + options=["seedance-1-0-pro-250528", "seedance-1-0-lite-t2v-250428", "seedance-1-0-pro-fast-251015"], + default="seedance-1-0-pro-fast-251015", ), IO.String.Input( "prompt", @@ -714,9 +566,8 @@ class ByteDanceImageToVideoNode(IO.ComfyNode): inputs=[ IO.Combo.Input( "model", - options=Image2VideoModelName, - default=Image2VideoModelName.seedance_1_pro, - tooltip="Model name", + options=["seedance-1-0-pro-250528", "seedance-1-0-lite-t2v-250428", "seedance-1-0-pro-fast-251015"], + default="seedance-1-0-pro-fast-251015", ), IO.String.Input( "prompt", @@ -787,7 +638,7 @@ class ByteDanceImageToVideoNode(IO.ComfyNode): cls, model: str, prompt: str, - image: torch.Tensor, + image: Input.Image, resolution: str, aspect_ratio: str, duration: int, @@ -833,9 +684,8 @@ class ByteDanceFirstLastFrameNode(IO.ComfyNode): inputs=[ IO.Combo.Input( "model", - options=[model.value for model in Image2VideoModelName], - default=Image2VideoModelName.seedance_1_lite.value, - tooltip="Model name", + options=["seedance-1-0-pro-250528", "seedance-1-0-lite-i2v-250428"], + default="seedance-1-0-lite-i2v-250428", ), IO.String.Input( "prompt", @@ -910,8 +760,8 @@ class ByteDanceFirstLastFrameNode(IO.ComfyNode): cls, model: str, prompt: str, - first_frame: torch.Tensor, - last_frame: torch.Tensor, + first_frame: Input.Image, + last_frame: Input.Image, resolution: str, aspect_ratio: str, duration: int, @@ -968,9 +818,8 @@ class ByteDanceImageReferenceNode(IO.ComfyNode): inputs=[ IO.Combo.Input( "model", - options=[Image2VideoModelName.seedance_1_lite.value], - default=Image2VideoModelName.seedance_1_lite.value, - tooltip="Model name", + options=["seedance-1-0-pro-250528", "seedance-1-0-lite-i2v-250428"], + default="seedance-1-0-lite-i2v-250428", ), IO.String.Input( "prompt", @@ -1034,7 +883,7 @@ class ByteDanceImageReferenceNode(IO.ComfyNode): cls, model: str, prompt: str, - images: torch.Tensor, + images: Input.Image, resolution: str, aspect_ratio: str, duration: int, @@ -1069,8 +918,8 @@ class ByteDanceImageReferenceNode(IO.ComfyNode): async def process_video_task( cls: type[IO.ComfyNode], - payload: Union[Text2VideoTaskCreationRequest, Image2VideoTaskCreationRequest], - estimated_duration: Optional[int], + payload: Text2VideoTaskCreationRequest | Image2VideoTaskCreationRequest, + estimated_duration: int | None, ) -> IO.NodeOutput: initial_response = await sync_op( cls, @@ -1085,7 +934,7 @@ async def process_video_task( estimated_duration=estimated_duration, response_model=TaskStatusResponse, ) - return IO.NodeOutput(await download_url_to_video_output(get_video_url_from_task_status(response))) + return IO.NodeOutput(await download_url_to_video_output(response.content.video_url)) def raise_if_text_params(prompt: str, text_params: list[str]) -> None: