Node Replacement API (#12014)

This commit is contained in:
Jedrzej Kosinski 2026-02-15 02:12:30 -08:00 committed by GitHub
parent ce4a1ab48d
commit 596ed68691
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 291 additions and 2 deletions

105
app/node_replace_manager.py Normal file
View File

@ -0,0 +1,105 @@
from __future__ import annotations
from aiohttp import web
from typing import TYPE_CHECKING, TypedDict
if TYPE_CHECKING:
from comfy_api.latest._io_public import NodeReplace
from comfy_execution.graph_utils import is_link
import nodes
class NodeStruct(TypedDict):
inputs: dict[str, str | int | float | bool | tuple[str, int]]
class_type: str
_meta: dict[str, str]
def copy_node_struct(node_struct: NodeStruct, empty_inputs: bool = False) -> NodeStruct:
new_node_struct = node_struct.copy()
if empty_inputs:
new_node_struct["inputs"] = {}
else:
new_node_struct["inputs"] = node_struct["inputs"].copy()
new_node_struct["_meta"] = node_struct["_meta"].copy()
return new_node_struct
class NodeReplaceManager:
"""Manages node replacement registrations."""
def __init__(self):
self._replacements: dict[str, list[NodeReplace]] = {}
def register(self, node_replace: NodeReplace):
"""Register a node replacement mapping."""
self._replacements.setdefault(node_replace.old_node_id, []).append(node_replace)
def get_replacement(self, old_node_id: str) -> list[NodeReplace] | None:
"""Get replacements for an old node ID."""
return self._replacements.get(old_node_id)
def has_replacement(self, old_node_id: str) -> bool:
"""Check if a replacement exists for an old node ID."""
return old_node_id in self._replacements
def apply_replacements(self, prompt: dict[str, NodeStruct]):
connections: dict[str, list[tuple[str, str, int]]] = {}
need_replacement: set[str] = set()
for node_number, node_struct in prompt.items():
class_type = node_struct["class_type"]
# need replacement if not in NODE_CLASS_MAPPINGS and has replacement
if class_type not in nodes.NODE_CLASS_MAPPINGS.keys() and self.has_replacement(class_type):
need_replacement.add(node_number)
# keep track of connections
for input_id, input_value in node_struct["inputs"].items():
if is_link(input_value):
conn_number = input_value[0]
connections.setdefault(conn_number, []).append((node_number, input_id, input_value[1]))
for node_number in need_replacement:
node_struct = prompt[node_number]
class_type = node_struct["class_type"]
replacements = self.get_replacement(class_type)
if replacements is None:
continue
# just use the first replacement
replacement = replacements[0]
new_node_id = replacement.new_node_id
# if replacement is not a valid node, skip trying to replace it as will only cause confusion
if new_node_id not in nodes.NODE_CLASS_MAPPINGS.keys():
continue
# first, replace node id (class_type)
new_node_struct = copy_node_struct(node_struct, empty_inputs=True)
new_node_struct["class_type"] = new_node_id
# TODO: consider replacing display_name in _meta as well for error reporting purposes; would need to query node schema
# second, replace inputs
if replacement.input_mapping is not None:
for input_map in replacement.input_mapping:
if "set_value" in input_map:
new_node_struct["inputs"][input_map["new_id"]] = input_map["set_value"]
elif "old_id" in input_map:
new_node_struct["inputs"][input_map["new_id"]] = node_struct["inputs"][input_map["old_id"]]
# finalize input replacement
prompt[node_number] = new_node_struct
# third, replace outputs
if replacement.output_mapping is not None:
# re-mapping outputs requires changing the input values of nodes that receive connections from this one
if node_number in connections:
for conns in connections[node_number]:
conn_node_number, conn_input_id, old_output_idx = conns
for output_map in replacement.output_mapping:
if output_map["old_idx"] == old_output_idx:
new_output_idx = output_map["new_idx"]
previous_input = prompt[conn_node_number]["inputs"][conn_input_id]
previous_input[1] = new_output_idx
def as_dict(self):
"""Serialize all replacements to dict."""
return {
k: [v.as_dict() for v in v_list]
for k, v_list in self._replacements.items()
}
def add_routes(self, routes):
@routes.get("/node_replacements")
async def get_node_replacements(request):
return web.json_response(self.as_dict())

View File

@ -14,6 +14,7 @@ SERVER_FEATURE_FLAGS: dict[str, Any] = {
"supports_preview_metadata": True,
"max_upload_size": args.max_upload_size * 1024 * 1024, # Convert MB to bytes
"extension": {"manager": {"supports_v4": True}},
"node_replacements": True,
}

View File

@ -21,6 +21,17 @@ class ComfyAPI_latest(ComfyAPIBase):
VERSION = "latest"
STABLE = False
def __init__(self):
super().__init__()
self.node_replacement = self.NodeReplacement()
self.execution = self.Execution()
class NodeReplacement(ProxiedSingleton):
async def register(self, node_replace: io.NodeReplace) -> None:
"""Register a node replacement mapping."""
from server import PromptServer
PromptServer.instance.node_replace_manager.register(node_replace)
class Execution(ProxiedSingleton):
async def set_progress(
self,
@ -73,8 +84,6 @@ class ComfyAPI_latest(ComfyAPIBase):
image=to_display,
)
execution: Execution
class ComfyExtension(ABC):
async def on_load(self) -> None:
"""

View File

@ -2030,6 +2030,68 @@ class _UIOutput(ABC):
...
class InputMapOldId(TypedDict):
"""Map an old node input to a new node input by ID."""
new_id: str
old_id: str
class InputMapSetValue(TypedDict):
"""Set a specific value for a new node input."""
new_id: str
set_value: Any
InputMap = InputMapOldId | InputMapSetValue
"""
Input mapping for node replacement. Type is inferred by dictionary keys:
- {"new_id": str, "old_id": str} - maps old input to new input
- {"new_id": str, "set_value": Any} - sets a specific value for new input
"""
class OutputMap(TypedDict):
"""Map outputs of node replacement via indexes."""
new_idx: int
old_idx: int
class NodeReplace:
"""
Defines a possible node replacement, mapping inputs and outputs of the old node to the new node.
Also supports assigning specific values to the input widgets of the new node.
Args:
new_node_id: The class name of the new replacement node.
old_node_id: The class name of the deprecated node.
old_widget_ids: Ordered list of input IDs for widgets that may not have an input slot
connected. The workflow JSON stores widget values by their relative position index,
not by ID. This list maps those positional indexes to input IDs, enabling the
replacement system to correctly identify widget values during node migration.
input_mapping: List of input mappings from old node to new node.
output_mapping: List of output mappings from old node to new node.
"""
def __init__(self,
new_node_id: str,
old_node_id: str,
old_widget_ids: list[str] | None=None,
input_mapping: list[InputMap] | None=None,
output_mapping: list[OutputMap] | None=None,
):
self.new_node_id = new_node_id
self.old_node_id = old_node_id
self.old_widget_ids = old_widget_ids
self.input_mapping = input_mapping
self.output_mapping = output_mapping
def as_dict(self):
"""Create serializable representation of the node replacement."""
return {
"new_node_id": self.new_node_id,
"old_node_id": self.old_node_id,
"old_widget_ids": self.old_widget_ids,
"input_mapping": list(self.input_mapping) if self.input_mapping else None,
"output_mapping": list(self.output_mapping) if self.output_mapping else None,
}
__all__ = [
"FolderType",
"UploadType",
@ -2121,4 +2183,5 @@ __all__ = [
"ImageCompare",
"PriceBadgeDepends",
"PriceBadge",
"NodeReplace",
]

View File

@ -655,6 +655,7 @@ class BatchImagesMasksLatentsNode(io.ComfyNode):
batched = batch_masks(values)
return io.NodeOutput(batched)
class PostProcessingExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:

View File

@ -0,0 +1,103 @@
from comfy_api.latest import ComfyExtension, io, ComfyAPI
api = ComfyAPI()
async def register_replacements():
"""Register all built-in node replacements."""
await register_replacements_longeredge()
await register_replacements_batchimages()
await register_replacements_upscaleimage()
await register_replacements_controlnet()
await register_replacements_load3d()
await register_replacements_preview3d()
await register_replacements_svdimg2vid()
await register_replacements_conditioningavg()
async def register_replacements_longeredge():
# No dynamic inputs here
await api.node_replacement.register(io.NodeReplace(
new_node_id="ImageScaleToMaxDimension",
old_node_id="ResizeImagesByLongerEdge",
old_widget_ids=["longer_edge"],
input_mapping=[
{"new_id": "image", "old_id": "images"},
{"new_id": "largest_size", "old_id": "longer_edge"},
{"new_id": "upscale_method", "set_value": "lanczos"},
],
# just to test the frontend output_mapping code, does nothing really here
output_mapping=[{"new_idx": 0, "old_idx": 0}],
))
async def register_replacements_batchimages():
# BatchImages node uses Autogrow
await api.node_replacement.register(io.NodeReplace(
new_node_id="BatchImagesNode",
old_node_id="ImageBatch",
input_mapping=[
{"new_id": "images.image0", "old_id": "image1"},
{"new_id": "images.image1", "old_id": "image2"},
],
))
async def register_replacements_upscaleimage():
# ResizeImageMaskNode uses DynamicCombo
await api.node_replacement.register(io.NodeReplace(
new_node_id="ResizeImageMaskNode",
old_node_id="ImageScaleBy",
old_widget_ids=["upscale_method", "scale_by"],
input_mapping=[
{"new_id": "input", "old_id": "image"},
{"new_id": "resize_type", "set_value": "scale by multiplier"},
{"new_id": "resize_type.multiplier", "old_id": "scale_by"},
{"new_id": "scale_method", "old_id": "upscale_method"},
],
))
async def register_replacements_controlnet():
# T2IAdapterLoader → ControlNetLoader
await api.node_replacement.register(io.NodeReplace(
new_node_id="ControlNetLoader",
old_node_id="T2IAdapterLoader",
input_mapping=[
{"new_id": "control_net_name", "old_id": "t2i_adapter_name"},
],
))
async def register_replacements_load3d():
# Load3DAnimation merged into Load3D
await api.node_replacement.register(io.NodeReplace(
new_node_id="Load3D",
old_node_id="Load3DAnimation",
))
async def register_replacements_preview3d():
# Preview3DAnimation merged into Preview3D
await api.node_replacement.register(io.NodeReplace(
new_node_id="Preview3D",
old_node_id="Preview3DAnimation",
))
async def register_replacements_svdimg2vid():
# Typo fix: SDV → SVD
await api.node_replacement.register(io.NodeReplace(
new_node_id="SVD_img2vid_Conditioning",
old_node_id="SDV_img2vid_Conditioning",
))
async def register_replacements_conditioningavg():
# Typo fix: trailing space in node name
await api.node_replacement.register(io.NodeReplace(
new_node_id="ConditioningAverage",
old_node_id="ConditioningAverage ",
))
class NodeReplacementsExtension(ComfyExtension):
async def on_load(self) -> None:
await register_replacements()
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return []
async def comfy_entrypoint() -> NodeReplacementsExtension:
return NodeReplacementsExtension()

View File

@ -2264,6 +2264,7 @@ async def load_custom_node(module_path: str, ignore=set(), module_parent="custom
if not isinstance(extension, ComfyExtension):
logging.warning(f"comfy_entrypoint in {module_path} did not return a ComfyExtension, skipping.")
return False
await extension.on_load()
node_list = await extension.get_node_list()
if not isinstance(node_list, list):
logging.warning(f"comfy_entrypoint in {module_path} did not return a list of nodes, skipping.")
@ -2435,6 +2436,7 @@ async def init_builtin_extra_nodes():
"nodes_lora_debug.py",
"nodes_color.py",
"nodes_toolkit.py",
"nodes_replacements.py",
]
import_failed = []

View File

@ -40,6 +40,7 @@ from app.user_manager import UserManager
from app.model_manager import ModelFileManager
from app.custom_node_manager import CustomNodeManager
from app.subgraph_manager import SubgraphManager
from app.node_replace_manager import NodeReplaceManager
from typing import Optional, Union
from api_server.routes.internal.internal_routes import InternalRoutes
from protocol import BinaryEventTypes
@ -204,6 +205,7 @@ class PromptServer():
self.model_file_manager = ModelFileManager()
self.custom_node_manager = CustomNodeManager()
self.subgraph_manager = SubgraphManager()
self.node_replace_manager = NodeReplaceManager()
self.internal_routes = InternalRoutes(self)
self.supports = ["custom_nodes_from_web"]
self.prompt_queue = execution.PromptQueue(self)
@ -887,6 +889,8 @@ class PromptServer():
if "partial_execution_targets" in json_data:
partial_execution_targets = json_data["partial_execution_targets"]
self.node_replace_manager.apply_replacements(prompt)
valid = await execution.validate_prompt(prompt_id, prompt, partial_execution_targets)
extra_data = {}
if "extra_data" in json_data:
@ -995,6 +999,7 @@ class PromptServer():
self.model_file_manager.add_routes(self.routes)
self.custom_node_manager.add_routes(self.routes, self.app, nodes.LOADED_MODULE_DIRS.items())
self.subgraph_manager.add_routes(self.routes, nodes.LOADED_MODULE_DIRS.items())
self.node_replace_manager.add_routes(self.routes)
self.app.add_subapp('/internal', self.internal_routes.get_app())
# Prefix every route with /api for easier matching for delegation.