diff --git a/execution.py b/execution.py
index 4b4f63c80..b536aa50d 100644
--- a/execution.py
+++ b/execution.py
@@ -46,6 +46,167 @@ class ExecutionResult(Enum):
 class DuplicateNodeError(Exception):
     pass
 
+
+# ======================================================================================
+# ADDED: Node grouping helpers for "input-type locality" execution ordering
+# --------------------------------------------------------------------------------------
+# We cluster ready-to-run nodes by a signature derived from:
+#   - Declared INPUT_TYPES (required/optional socket types)
+#   - Upstream linked RETURN_TYPES (when available from prompt links)
+#
+# This is a SCHEDULING optimization only:
+#   - It must not change correctness or dependency satisfaction.
+#   - It only reorders nodes that ExecutionList already deems ready/executable.
+#   - It is stable, to avoid churn and to preserve deterministic behavior.
+#
+# IMPORTANT: ExecutionList is imported from comfy_execution.graph; we avoid invasive
+# changes by using a small subclass + defensive introspection of its internal queues.
+# ======================================================================================
+
+def _safe_stringify_type(t):
+    try:
+        return str(t)
+    except Exception:
+        return repr(t)
+
+def _node_input_signature_from_prompt(prompt: dict, node_id: str):
+    """
+    Build a stable, hashable signature representing a node's *input requirements*.
+
+    Includes:
+      - Declared input socket types via INPUT_TYPES() (required + optional)
+      - Linked upstream output RETURN_TYPES, when an input is a link
+
+    This signature is used ONLY for grouping/sorting ready nodes. Every entry,
+    including the fallbacks below, is a tuple of strings: signatures from
+    different nodes must always compare cleanly against each other, and mixing
+    bare strings with tuples would raise TypeError under Python 3 ordering.
+    """
+    node = prompt.get(node_id)
+    if node is None:
+        return (("missing", str(node_id)),)
+
+    class_type = node.get("class_type")
+    class_def = nodes.NODE_CLASS_MAPPINGS.get(class_type)
+    if class_def is None:
+        return (("unknown-class", _safe_stringify_type(class_type), str(node_id)),)
+
+    sig = []
+
+    # Declared socket types (required/optional)
+    try:
+        input_types = class_def.INPUT_TYPES()
+    except Exception:
+        input_types = {}
+
+    for cat in ("required", "optional"):
+        cat_dict = input_types.get(cat, {})
+        if isinstance(cat_dict, dict):
+            # Sort keys for stability
+            for k in sorted(cat_dict.keys()):
+                v = cat_dict[k]
+                sig.append(("decl", cat, k, _safe_stringify_type(v)))
+
+    # Linked upstream return types (helps cluster by latent/model flows)
+    inputs = node.get("inputs", {}) or {}
+    if isinstance(inputs, dict):
+        for k in sorted(inputs.keys()):
+            v = inputs[k]
+            if is_link(v) and isinstance(v, (list, tuple)) and len(v) == 2:
+                src_id, out_idx = v[0], v[1]
+                src_node = prompt.get(src_id)
+                if src_node is None:
+                    sig.append(("link", k, ""))
+                    continue
+                src_class_type = src_node.get("class_type")
+                src_class_def = nodes.NODE_CLASS_MAPPINGS.get(src_class_type)
+                if src_class_def is None:
+                    sig.append(("link", k, "", _safe_stringify_type(src_class_type)))
+                    continue
+                ret_types = getattr(src_class_def, "RETURN_TYPES", ())
+                try:
+                    if isinstance(out_idx, int) and 0 <= out_idx < len(ret_types):
+                        sig.append(("link", k, _safe_stringify_type(ret_types[out_idx])))
+                    else:
+                        sig.append(("link", k, "", _safe_stringify_type(out_idx)))
+                except Exception:
+                    sig.append(("link", k, ""))
+
+    return tuple(sig)
+
+def _try_group_sort_execution_list_ready_nodes(execution_list: ExecutionList, prompt: dict):
+    """
+    Attempt to reorder the ExecutionList's *ready* nodes in place, grouping by input signature.
+
+    This is intentionally defensive because ExecutionList is external; we only touch
+    well-known/observed internal attributes when they match expected shapes.
+
+    Supported patterns (best-effort):
+      - execution_list.nodes_to_execute  : list of node ids
+      - execution_list._nodes_to_execute : list of node ids (fallback)
+
+    We DO NOT rewrite heaps/tuples with priority keys, because that risks breaking
+    invariants. If the internal structure is not a simple list of node ids, we do nothing.
+    """
+    # Candidate attribute names that (in some ComfyUI revisions) hold ready-to-run node IDs
+    candidates = ("nodes_to_execute", "_nodes_to_execute")
+    for attr in candidates:
+        if not hasattr(execution_list, attr):
+            continue
+        value = getattr(execution_list, attr)
+
+        # Only operate on a plain list of node ids (strings/ints)
+        if isinstance(value, list) and all(isinstance(x, (str, int)) for x in value):
+            # Stable grouping sort:
+            #   primary:   signature (to cluster similar input requirements)
+            #   secondary: original position (stability within a group)
+            indexed = list(enumerate(value))
+            indexed.sort(
+                key=lambda it: (
+                    # signature key
+                    _node_input_signature_from_prompt(prompt, str(it[1])),
+                    # keep stable within the same signature
+                    it[0],
+                )
+            )
+            new_list = [node_id for _, node_id in indexed]
+            setattr(execution_list, attr, new_list)
+            return True
+
+    return False
+
+
+class GroupedExecutionList(ExecutionList):
+    """
+    ADDED: Thin wrapper around ExecutionList that reorders *ready* nodes before staging
+    to improve model/tensor locality (reduce VRAM/RAM chatter).
+
+    This does not change dependency logic; it only reorders nodes that are already ready.
+    """
+
+    def _apply_group_sort_if_possible(self):
+        try:
+            # dynprompt.original_prompt is the canonical prompt graph dict
+            dynprompt = getattr(self, "dynprompt", None)
+            prompt_dict = None
+            if dynprompt is not None:
+                prompt_dict = getattr(dynprompt, "original_prompt", None)
+            if isinstance(prompt_dict, dict):
+                _try_group_sort_execution_list_ready_nodes(self, prompt_dict)
+        except Exception:
+            # A scheduling hint must never break execution
+            pass
+
+    # NOTE: stage_node_execution is awaited by the caller in this file, so we keep it async-compatible.
+    async def stage_node_execution(self):
+        # Group-sort the ready list *before* choosing the next node
+        self._apply_group_sort_if_possible()
+        return await super().stage_node_execution()
+
+    def add_node(self, node_id):
+        # Keep original behavior, then regroup for future staging
+        super().add_node(node_id)
+        self._apply_group_sort_if_possible()
+
+
 class IsChangedCache:
     def __init__(self, prompt_id: str, dynprompt: DynamicPrompt, outputs_cache: BasicCache):
         self.prompt_id = prompt_id
@@ -707,7 +868,13 @@ class PromptExecutor:
         pending_async_nodes = {} # TODO - Unify this with pending_subgraph_results
         ui_node_outputs = {}
         executed = set()
-        execution_list = ExecutionList(dynamic_prompt, self.caches.outputs)
+
+        # ==================================================================================
+        # CHANGED: Use GroupedExecutionList to group ready-to-run nodes by input signature.
+        # This reduces VRAM/RAM chatter when workflows reuse the same models/tensor types.
+        # ==================================================================================
+        execution_list = GroupedExecutionList(dynamic_prompt, self.caches.outputs)
+
         current_outputs = self.caches.outputs.all_node_ids()
         for node_id in list(execute_outputs):
             execution_list.add_node(node_id)
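As a sanity check on the ordering logic above, here is a minimal, self-contained sketch of the grouping sort on a toy ready-list. The `_sigs` table, the `_signature` helper, and the node ids are hypothetical stand-ins for `_node_input_signature_from_prompt` and real prompt nodes; only the key shape `(signature, original position)` mirrors the patch.

# Hypothetical signatures: nodes "3" and "7" share a MODEL-typed input, "5" differs.
_sigs = {
    "3": (("decl", "required", "model", "MODEL"),),
    "5": (("decl", "required", "image", "IMAGE"),),
    "7": (("decl", "required", "model", "MODEL"),),
}

def _signature(node_id):
    # Stand-in for _node_input_signature_from_prompt(prompt, node_id)
    return _sigs[node_id]

ready = ["3", "5", "7"]
indexed = list(enumerate(ready))
# Same key shape as the patch: cluster by signature, keep original order within a group.
indexed.sort(key=lambda it: (_signature(it[1]), it[0]))
print([node_id for _, node_id in indexed])  # ['5', '3', '7'] -- the two MODEL nodes end up adjacent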
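The fallback returns in `_node_input_signature_from_prompt` are normalized to tuples of string-tuples because Python 3 refuses to order tuples whose corresponding elements have different types; a bare-string fallback mixed in with regular signatures would make the ready-list sort raise. A short illustration of the failure mode the normalization avoids:

good = (("decl", "required", "model", "MODEL"),)  # regular signature entry
bad = ("", "42")                                  # str in slot 0 instead of a tuple
try:
    sorted([good, bad])
except TypeError as err:
    print(err)  # '<' not supported between instances of 'str' and 'tuple'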