from __future__ import annotations from aiohttp import web from typing import TYPE_CHECKING, TypedDict if TYPE_CHECKING: from comfy_api.latest._node_replace import NodeReplace from nodes import NODE_CLASS_MAPPINGS class NodeStruct(TypedDict): inputs: dict[str, str | int | float | tuple[str, int]] class_type: str _meta: dict[str, str] def copy_node_struct(node_struct: NodeStruct, empty_inputs: bool = False) -> NodeStruct: new_node_struct = node_struct.copy() if empty_inputs: new_node_struct["inputs"] = {} else: new_node_struct["inputs"] = node_struct["inputs"].copy() new_node_struct["_meta"] = node_struct["_meta"].copy() return new_node_struct class NodeReplaceManager: """Manages node replacement registrations.""" def __init__(self): self._replacements: dict[str, list[NodeReplace]] = {} def register(self, node_replace: NodeReplace): """Register a node replacement mapping.""" self._replacements.setdefault(node_replace.old_node_id, []).append(node_replace) def get_replacement(self, old_node_id: str) -> list[NodeReplace] | None: """Get replacements for an old node ID.""" return self._replacements.get(old_node_id) def has_replacement(self, old_node_id: str) -> bool: """Check if a replacement exists for an old node ID.""" return old_node_id in self._replacements def apply_replacements(self, prompt: dict[str, NodeStruct]): connections: dict[str, list[tuple[str, str, int]]] = {} need_replacement: set[str] = set() for node_number, node_struct in prompt.items(): class_type = node_struct["class_type"] # need replacement if not in NODE_CLASS_MAPPINGS and has replacement if class_type not in NODE_CLASS_MAPPINGS.keys() and self.has_replacement(class_type): need_replacement.add(node_number) # keep track of connections for input_id, input_value in node_struct["inputs"].items(): if isinstance(input_value, list): conn_number = input_value[0] connections.setdefault(conn_number, []).append((node_number, input_id, input_value[1])) if len(need_replacement) > 0: for node_number in need_replacement: node_struct = prompt[node_number] class_type = node_struct["class_type"] replacements = self.get_replacement(class_type) if replacements is None: continue # just use the first replacement replacement = replacements[0] new_node_id = replacement.new_node_id # first, replace node id (class_type) new_node_struct = copy_node_struct(node_struct, empty_inputs=True) new_node_struct["class_type"] = new_node_id # second, replace inputs if replacement.input_mapping is not None: for input_map in replacement.input_mapping: if "set_value" in input_map: new_node_struct["inputs"][input_map["new_id"]] = input_map["set_value"] elif "old_id" in input_map: new_node_struct["inputs"][input_map["new_id"]] = node_struct["inputs"][input_map["old_id"]] # finalize input replacement prompt[node_number] = new_node_struct # third, replace outputs if replacement.output_mapping is not None: # re-mapping outputs requires changing the input values of nodes that receive connections from this one # so we need to find all nodes that receive connections from this one if node_number in connections: for conns in connections[node_number]: conn_node_number, conn_input_id, old_output_idx = conns for output_map in replacement.output_mapping: if output_map["old_idx"] == old_output_idx: new_output_idx = output_map["new_idx"] previous_input = prompt[conn_node_number]["inputs"][conn_input_id] previous_input[1] = new_output_idx def as_dict(self): """Serialize all replacements to dict.""" return { k: [v.as_dict() for v in v_list] for k, v_list in self._replacements.items() } def add_routes(self, routes): @routes.get("/node_replacements") async def get_node_replacements(request): return web.json_response(self.as_dict())