diff --git a/comfy_execution/graph.py b/comfy_execution/graph.py
index 0d811e354..c8cc72c95 100644
--- a/comfy_execution/graph.py
+++ b/comfy_execution/graph.py
@@ -1,6 +1,12 @@
-from __future__ import annotations
-from typing import Type, Literal
+# graph.py — grouped/batched scheduler on top of the updated ExecutionList
+# Implements model-class batching to reduce device/context swaps while preserving
+# the new execution_cache behavior added upstream.
+
+from __future__ import annotations
+from typing import Type, Literal, Optional
+
+import os
 
 import nodes
 import asyncio
 import inspect
@@ -10,15 +16,19 @@ from comfy.comfy_types.node_typing import ComfyNodeABC, InputTypeDict, InputType
 # NOTE: ExecutionBlocker code got moved to graph_utils.py to prevent torch being imported too soon during unit tests
 ExecutionBlocker = ExecutionBlocker
 
+
 class DependencyCycleError(Exception):
     pass
 
+
 class NodeInputError(Exception):
     pass
 
+
 class NodeNotFoundError(Exception):
     pass
 
+
 class DynamicPrompt:
     def __init__(self, original_prompt):
         # The original prompt provided by the user
@@ -62,6 +72,7 @@ class DynamicPrompt:
     def get_original_prompt(self):
         return self.original_prompt
 
+
 def get_input_info(
     class_def: Type[ComfyNodeABC],
     input_name: str,
@@ -99,12 +110,13 @@ def get_input_info(
         extra_info = {}
     return input_type, input_category, extra_info
 
+
 class TopologicalSort:
     def __init__(self, dynprompt):
         self.dynprompt = dynprompt
         self.pendingNodes = {}
-        self.blockCount = {} # Number of nodes this node is directly blocked by
-        self.blocking = {} # Which nodes are blocked by this node
+        self.blockCount = {}  # Number of nodes this node is directly blocked by
+        self.blocking = {}  # Which nodes are blocked by this node
         self.externalBlocks = 0
         self.unblockedEvent = asyncio.Event()
 
@@ -165,6 +177,7 @@ class TopologicalSort:
         assert node_id in self.blockCount, "Can't add external block to a node that isn't pending"
         self.externalBlocks += 1
         self.blockCount[node_id] += 1
+
         def unblock():
             self.externalBlocks -= 1
             self.blockCount[node_id] -= 1
@@ -186,36 +199,49 @@
     def is_empty(self):
         return len(self.pendingNodes) == 0
 
+
 class ExecutionList(TopologicalSort):
     """
-    ExecutionList implements a topological dissolve of the graph. After a node is staged for execution,
-    it can still be returned to the graph after having further dependencies added.
+    ExecutionList implements a topological dissolve of the graph with batching.
+    After a node is staged for execution, it can still be returned to the graph
+    after having further dependencies added.
+
+    Batching: we favor running nodes of the same class_type back-to-back
+    to reduce device/context thrash (e.g., model swaps). Within a batch we still
+    apply UX-friendly priorities (output/async early, VAEDecode→preview, etc.).
     """
+
     def __init__(self, dynprompt, output_cache):
         super().__init__(dynprompt)
         self.output_cache = output_cache
-        self.staged_node_id = None
+        self.staged_node_id: Optional[str] = None
+
+        # Upstream execution cache (kept intact)
         self.execution_cache = {}
         self.execution_cache_listeners = {}
 
+        # Batching state
+        self._current_group_class: Optional[str] = None
+
+    # ----------------------------- cache ---------------------------------
     def is_cached(self, node_id):
         return self.output_cache.get(node_id) is not None
 
     def cache_link(self, from_node_id, to_node_id):
-        if not to_node_id in self.execution_cache:
+        if to_node_id not in self.execution_cache:
             self.execution_cache[to_node_id] = {}
         self.execution_cache[to_node_id][from_node_id] = self.output_cache.get(from_node_id)
-        if not from_node_id in self.execution_cache_listeners:
+        if from_node_id not in self.execution_cache_listeners:
             self.execution_cache_listeners[from_node_id] = set()
         self.execution_cache_listeners[from_node_id].add(to_node_id)
 
     def get_cache(self, from_node_id, to_node_id):
-        if not to_node_id in self.execution_cache:
+        if to_node_id not in self.execution_cache:
             return None
         value = self.execution_cache[to_node_id].get(from_node_id)
         if value is None:
             return None
-        #Write back to the main cache on touch.
+        # Write back to the main cache on touch.
         self.output_cache.set(from_node_id, value)
         return value
 
@@ -229,16 +255,93 @@ class ExecutionList(TopologicalSort):
         super().add_strong_link(from_node_id, from_socket, to_node_id)
         self.cache_link(from_node_id, to_node_id)
 
+    # --------------------------- group utils ------------------------------
+    def _pick_largest_group(self, node_list):
+        """Return the class_type with the most representatives in node_list.
+        Ties are resolved deterministically by class name."""
+        counts = {}
+        for nid in node_list:
+            ctype = self.dynprompt.get_node(nid)["class_type"]
+            counts[ctype] = counts.get(ctype, 0) + 1
+        # max by (count, class_name) for deterministic tie-break
+        return max(counts.items(), key=lambda kv: (kv[1], kv[0]))[0]
+
+    def _filter_by_group(self, node_list, group_cls):
+        """Keep only nodes that belong to the given class."""
+        return [nid for nid in node_list if self.dynprompt.get_node(nid)["class_type"] == group_cls]
+
+    # ------------------------- node classification ------------------------
+    def _is_output(self, node_id):
+        class_type = self.dynprompt.get_node(node_id)["class_type"]
+        class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
+        return getattr(class_def, 'OUTPUT_NODE', False) is True
+
+    def _is_async(self, node_id):
+        class_type = self.dynprompt.get_node(node_id)["class_type"]
+        class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
+        return inspect.iscoroutinefunction(getattr(class_def, class_def.FUNCTION))
+
+    # ------------------------- UX within a batch --------------------------
+    def _pick_in_batch_with_ux(self, candidates):
+        """
+        Original UX heuristics, but applied *within* the current batch.
+        """
+        # 1) Output nodes ASAP
+        for nid in candidates:
+            if self._is_output(nid):
+                return nid
+        # 1b) Async nodes early to overlap
+        for nid in candidates:
+            if self._is_async(nid):
+                return nid
+        # 2) decoder-before-preview pattern (within the batch)
+        for nid in candidates:
+            for blocked in self.blocking[nid]:
+                if self._is_output(blocked):
+                    return nid
+        # 3) VAELoader -> VAEDecode -> preview (within the batch)
+        for nid in candidates:
+            for blocked in self.blocking[nid]:
+                for blocked2 in self.blocking[blocked]:
+                    if self._is_output(blocked2):
+                        return nid
+        # 4) Otherwise, first candidate
+        return candidates[0]
+
+    # ------------------------- batch-aware picking ------------------------
+    def ux_friendly_pick_node(self, available):
+        """
+        Choose which ready node to execute next, honoring the current batch.
+        When the current batch runs dry, switch to the largest ready group.
+        """
+
+        # Ensure current batch is still present; otherwise pick a new largest group.
+        has_current = (
+            self._current_group_class is not None and
+            any(self.dynprompt.get_node(nid)["class_type"] == self._current_group_class for nid in available)
+        )
+        if not has_current:
+            new_group = self._pick_largest_group(available)
+            self._current_group_class = new_group
+
+        # Restrict to nodes of the current batch
+        candidates = self._filter_by_group(available, self._current_group_class)
+        return self._pick_in_batch_with_ux(candidates)
+
+    # --------------------------- staging / run ----------------------------
     async def stage_node_execution(self):
         assert self.staged_node_id is None
         if self.is_empty():
             return None, None, None
+
         available = self.get_ready_nodes()
+
+        # If nothing ready but there are external blockers, wait for unblocks.
         while len(available) == 0 and self.externalBlocks > 0:
-            # Wait for an external block to be released
             await self.unblockedEvent.wait()
             self.unblockedEvent.clear()
             available = self.get_ready_nodes()
+
         if len(available) == 0:
             cycled_nodes = self.get_nodes_in_cycle()
             # Because cycles composed entirely of static nodes are caught during initial validation,
@@ -259,64 +362,30 @@
             }
             return None, error_details, ex
 
+        # Batch-aware pick
         self.staged_node_id = self.ux_friendly_pick_node(available)
         return self.staged_node_id, None, None
 
-    def ux_friendly_pick_node(self, node_list):
-        # If an output node is available, do that first.
-        # Technically this has no effect on the overall length of execution, but it feels better as a user
-        # for a PreviewImage to display a result as soon as it can
-        # Some other heuristics could probably be used here to improve the UX further.
-        def is_output(node_id):
-            class_type = self.dynprompt.get_node(node_id)["class_type"]
-            class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
-            if hasattr(class_def, 'OUTPUT_NODE') and class_def.OUTPUT_NODE == True:
-                return True
-            return False
-
-        # If an available node is async, do that first.
-        # This will execute the asynchronous function earlier, reducing the overall time.
-        def is_async(node_id):
-            class_type = self.dynprompt.get_node(node_id)["class_type"]
-            class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
-            return inspect.iscoroutinefunction(getattr(class_def, class_def.FUNCTION))
-
-        for node_id in node_list:
-            if is_output(node_id) or is_async(node_id):
-                return node_id
-
-        #This should handle the VAEDecode -> preview case
-        for node_id in node_list:
-            for blocked_node_id in self.blocking[node_id]:
-                if is_output(blocked_node_id):
-                    return node_id
-
-        #This should handle the VAELoader -> VAEDecode -> preview case
-        for node_id in node_list:
-            for blocked_node_id in self.blocking[node_id]:
-                for blocked_node_id1 in self.blocking[blocked_node_id]:
-                    if is_output(blocked_node_id1):
-                        return node_id
-
-        #TODO: this function should be improved
-        return node_list[0]
-
     def unstage_node_execution(self):
+        # If a node execution resolves to PENDING, return it to the pool
+        # but keep the current batch so we continue batching next time.
         assert self.staged_node_id is not None
         self.staged_node_id = None
 
     def complete_node_execution(self):
         node_id = self.staged_node_id
         self.pop_node(node_id)
+        # Maintain current batch; it will switch automatically when empty.
         self.execution_cache.pop(node_id, None)
        self.execution_cache_listeners.pop(node_id, None)
         self.staged_node_id = None
 
+    # ------------------------- cycle detection ----------------------------
     def get_nodes_in_cycle(self):
         # We'll dissolve the graph in reverse topological order to leave only the nodes in the cycle.
         # We're skipping some of the performance optimizations from the original TopologicalSort to keep
         # the code simple (and because having a cycle in the first place is a catastrophic error)
-        blocked_by = { node_id: {} for node_id in self.pendingNodes }
+        blocked_by = {node_id: {} for node_id in self.pendingNodes}
         for from_node_id in self.blocking:
             for to_node_id in self.blocking[from_node_id]:
                 if True in self.blocking[from_node_id][to_node_id].values():
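A few standalone sketches of the scheduling logic above follow. They are illustrative only: the class names, node ids, and helper functions are hypothetical stand-ins, not code from this patch.

First, the deterministic tie-break in _pick_largest_group. Taking max over (count, class_name) pairs means a larger group always wins, and equal-sized groups fall back to the lexicographically greatest class name, so the same prompt batches in the same order on every run:

def pick_largest_group(class_types):
    # Count ready nodes per class, mirroring the counting loop in the patch.
    counts = {}
    for ctype in class_types:
        counts[ctype] = counts.get(ctype, 0) + 1
    # max by (count, class_name): the bigger group wins; ties go to the
    # lexicographically greatest class name, deterministically.
    return max(counts.items(), key=lambda kv: (kv[1], kv[0]))[0]

assert pick_largest_group(["KSampler", "KSampler", "VAEDecode"]) == "KSampler"
# 1-vs-1 tie: "VAEDecode" sorts after "CLIPTextEncode", so it wins.
assert pick_largest_group(["CLIPTextEncode", "VAEDecode"]) == "VAEDecode"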
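Second, the batch-switching rule in ux_friendly_pick_node: the scheduler stays on the current class until no ready node carries it, then re-elects the largest ready group. A stripped-down picker over (node_id, class_type) pairs, using "first candidate" in place of the in-batch UX heuristics:

def pick(available, state):
    classes = {nid: ctype for nid, ctype in available}
    if state.get("group") not in classes.values():
        # Current batch ran dry: switch to the largest ready group.
        counts = {}
        for ctype in classes.values():
            counts[ctype] = counts.get(ctype, 0) + 1
        state["group"] = max(counts.items(), key=lambda kv: (kv[1], kv[0]))[0]
    # Within the batch, take the first matching node (UX rules omitted here).
    return next(nid for nid, ctype in available if ctype == state["group"])

state = {}
ready = [("1", "LoraLoader"), ("2", "LoraLoader"), ("3", "VAEDecode")]
order = []
while ready:
    nid = pick(ready, state)
    order.append(nid)
    ready = [(n, c) for n, c in ready if n != nid]
print(order)  # ['1', '2', '3']: both LoraLoaders run before the VAEDecode

With the pre-batching picker, a ready VAEDecode feeding a preview would have jumped the queue; here it waits until the LoraLoader batch is exhausted, which is the device-swap trade-off the header comment describes.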
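Third, the in-batch priority order of _pick_in_batch_with_ux. A sketch over a plain blocking map (node -> set of nodes it blocks) and a set of output nodes; rule 1b (prefer async nodes) is omitted because it needs real node classes:

def pick_in_batch(candidates, blocking, outputs):
    for nid in candidates:            # 1) output nodes ASAP
        if nid in outputs:
            return nid
    for nid in candidates:            # 2) node directly blocking an output
        if any(b in outputs for b in blocking[nid]):
            return nid
    for nid in candidates:            # 3) node two hops above an output
        if any(b2 in outputs for b in blocking[nid] for b2 in blocking[b]):
            return nid
    return candidates[0]              # 4) fallback: first candidate

blocking = {
    "vae_loader": {"decode"},  # hypothetical VAELoader -> VAEDecode -> preview chain
    "decode": {"preview"},
    "preview": set(),
    "sampler": set(),
}
outputs = {"preview"}
# The loader beats the sampler: it sits two hops upstream of an output (rule 3).
assert pick_in_batch(["sampler", "vae_loader"], blocking, outputs) == "vae_loader"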