From 14184a0918786c6190310c420f0c826c3845404a Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Thu, 12 Feb 2026 22:04:52 -0800 Subject: [PATCH] Add apply_replacements to NodeReplaceManager to apply registered node replacements when executing /prompt endpoint --- app/node_replace_manager.py | 66 ++++++++++++++++++++++++++++++++++++- server.py | 2 ++ 2 files changed, 67 insertions(+), 1 deletion(-) diff --git a/app/node_replace_manager.py b/app/node_replace_manager.py index 92a0949ab..f6cdeceb4 100644 --- a/app/node_replace_manager.py +++ b/app/node_replace_manager.py @@ -2,10 +2,26 @@ from __future__ import annotations from aiohttp import web -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, TypedDict if TYPE_CHECKING: from comfy_api.latest._node_replace import NodeReplace +from nodes import NODE_CLASS_MAPPINGS + +class NodeStruct(TypedDict): + inputs: dict[str, str | int | float | tuple[str, int]] + class_type: str + _meta: dict[str, str] + +def copy_node_struct(node_struct: NodeStruct, empty_inputs: bool = False) -> NodeStruct: + new_node_struct = node_struct.copy() + if empty_inputs: + new_node_struct["inputs"] = {} + else: + new_node_struct["inputs"] = node_struct["inputs"].copy() + new_node_struct["_meta"] = node_struct["_meta"].copy() + return new_node_struct + class NodeReplaceManager: """Manages node replacement registrations.""" @@ -25,6 +41,54 @@ class NodeReplaceManager: """Check if a replacement exists for an old node ID.""" return old_node_id in self._replacements + def apply_replacements(self, prompt: dict[str, NodeStruct]): + connections: dict[str, list[tuple[str, str, int]]] = {} + need_replacement: set[str] = set() + for node_number, node_struct in prompt.items(): + class_type = node_struct["class_type"] + # need replacement if not in NODE_CLASS_MAPPINGS and has replacement + if class_type not in NODE_CLASS_MAPPINGS.keys() and self.has_replacement(class_type): + need_replacement.add(node_number) + # keep track of connections + for input_id, input_value in node_struct["inputs"].items(): + if isinstance(input_value, list): + conn_number = input_value[0] + connections.setdefault(conn_number, []).append((node_number, input_id, input_value[1])) + if len(need_replacement) > 0: + for node_number in need_replacement: + node_struct = prompt[node_number] + class_type = node_struct["class_type"] + replacements = self.get_replacement(class_type) + if replacements is None: + continue + # just use the first replacement + replacement = replacements[0] + new_node_id = replacement.new_node_id + # first, replace node id (class_type) + new_node_struct = copy_node_struct(node_struct, empty_inputs=True) + new_node_struct["class_type"] = new_node_id + # second, replace inputs + if replacement.input_mapping is not None: + for input_map in replacement.input_mapping: + if "set_value" in input_map: + new_node_struct["inputs"][input_map["new_id"]] = input_map["set_value"] + elif "old_id" in input_map: + new_node_struct["inputs"][input_map["new_id"]] = node_struct["inputs"][input_map["old_id"]] + # finalize input replacement + prompt[node_number] = new_node_struct + # third, replace outputs + if replacement.output_mapping is not None: + # re-mapping outputs requires changing the input values of nodes that receive connections from this one + # so we need to find all nodes that receive connections from this one + if node_number in connections: + for conns in connections[node_number]: + conn_node_number, conn_input_id, old_output_idx = conns + for output_map in replacement.output_mapping: + if output_map["old_idx"] == old_output_idx: + new_output_idx = output_map["new_idx"] + previous_input = prompt[conn_node_number]["inputs"][conn_input_id] + previous_input[1] = new_output_idx + def as_dict(self): """Serialize all replacements to dict.""" return { diff --git a/server.py b/server.py index 362d06e86..8882e43c4 100644 --- a/server.py +++ b/server.py @@ -889,6 +889,8 @@ class PromptServer(): if "partial_execution_targets" in json_data: partial_execution_targets = json_data["partial_execution_targets"] + self.node_replace_manager.apply_replacements(prompt) + valid = await execution.validate_prompt(prompt_id, prompt, partial_execution_targets) extra_data = {} if "extra_data" in json_data: