diff --git a/app/node_replace_manager.py b/app/node_replace_manager.py new file mode 100644 index 000000000..3b1b7ab36 --- /dev/null +++ b/app/node_replace_manager.py @@ -0,0 +1,23 @@ +from __future__ import annotations + +from aiohttp import web + +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from comfy_api.latest._node_replace import NodeReplace + +REGISTERED_NODE_REPLACEMENTS: dict[str, list[NodeReplace]] = {} + +def register_node_replacement(node_replace: NodeReplace): + REGISTERED_NODE_REPLACEMENTS.setdefault(node_replace.old_node_id, []).append(node_replace) + +def registered_as_dict(): + return { + k: [v.as_dict() for v in v_list] for k, v_list in REGISTERED_NODE_REPLACEMENTS.items() + } + +class NodeReplaceManager: + def add_routes(self, routes): + @routes.get("/node_replacements") + async def get_node_replacements(request): + return web.json_response(registered_as_dict()) diff --git a/comfy_api/latest/__init__.py b/comfy_api/latest/__init__.py index b0fa14ff6..24f6f438b 100644 --- a/comfy_api/latest/__init__.py +++ b/comfy_api/latest/__init__.py @@ -10,6 +10,7 @@ from ._input_impl import VideoFromFile, VideoFromComponents from ._util import VideoCodec, VideoContainer, VideoComponents, MESH, VOXEL from . import _io_public as io from . import _ui_public as ui +from . import _node_replace_public as node_replace from comfy_execution.utils import get_executing_context from comfy_execution.progress import get_progress_state, PreviewImageTuple from PIL import Image @@ -130,4 +131,5 @@ __all__ = [ "IO", "ui", "UI", + "node_replace", ] diff --git a/comfy_api/latest/_node_replace.py b/comfy_api/latest/_node_replace.py new file mode 100644 index 000000000..8e90eebb7 --- /dev/null +++ b/comfy_api/latest/_node_replace.py @@ -0,0 +1,109 @@ +from __future__ import annotations + +from typing import Any +import app.node_replace_manager + +def register_node_replacement(node_replace: NodeReplace): + """ + Register node replacement. + """ + app.node_replace_manager.register_node_replacement(node_replace) + + +class NodeReplace: + """ + Defines a possible node replacement, mapping inputs and outputs of the old node to the new node. + + Also supports assigning specific values to the input widgets of the new node. + """ + def __init__(self, + new_node_id: str, + old_node_id: str, + old_widget_ids: list[str] | None=None, + input_mapping: list[InputMap] | None=None, + output_mapping: list[OutputMap] | None=None, + ): + self.new_node_id = new_node_id + self.old_node_id = old_node_id + self.old_widget_ids = old_widget_ids + self.input_mapping = input_mapping + self.output_mapping = output_mapping + + def as_dict(self): + """ + Create serializable representation of the node replacement. + """ + return { + "new_node_id": self.new_node_id, + "old_node_id": self.old_node_id, + "old_widget_ids": self.old_widget_ids, + "input_mapping": [m.as_dict() for m in self.input_mapping] if self.input_mapping else None, + "output_mapping": [m.as_dict() for m in self.output_mapping] if self.output_mapping else None, + } + + +class InputMap: + """ + Map inputs of node replacement. + + Use InputMap.OldId or InputMap.SetValue for mapping purposes. + """ + class _Assign: + def __init__(self, assign_type: str): + self.assign_type = assign_type + + def as_dict(self): + return { + "assign_type": self.assign_type, + } + + class OldId(_Assign): + """ + Connect the input of the old node with given id to new node when replacing. + """ + def __init__(self, old_id: str): + super().__init__("old_id") + self.old_id = old_id + + def as_dict(self): + return super().as_dict() | { + "old_id": self.old_id, + } + + class SetValue(_Assign): + """ + Use the given value for the input of the new node when replacing; assumes input is a widget. + """ + def __init__(self, value: Any): + super().__init__("set_value") + self.value = value + + def as_dict(self): + return super().as_dict() | { + "value": self.value, + } + + def __init__(self, new_id: str, assign: OldId | SetValue): + self.new_id = new_id + self.assign = assign + + def as_dict(self): + return { + "new_id": self.new_id, + "assign": self.assign.as_dict(), + } + + +class OutputMap: + """ + Map outputs of node replacement via indexes, as that's how outputs are stored. + """ + def __init__(self, new_idx: int, old_idx: int): + self.new_idx = new_idx + self.old_idx = old_idx + + def as_dict(self): + return { + "new_idx": self.new_idx, + "old_idx": self.old_idx, + } diff --git a/comfy_api/latest/_node_replace_public.py b/comfy_api/latest/_node_replace_public.py new file mode 100644 index 000000000..cd18168ae --- /dev/null +++ b/comfy_api/latest/_node_replace_public.py @@ -0,0 +1 @@ +from ._node_replace import * # noqa: F403 diff --git a/comfy_api/v0_0_2/__init__.py b/comfy_api/v0_0_2/__init__.py index c4fa1d971..0d4d567da 100644 --- a/comfy_api/v0_0_2/__init__.py +++ b/comfy_api/v0_0_2/__init__.py @@ -6,7 +6,7 @@ from comfy_api.latest import ( ) from typing import Type, TYPE_CHECKING from comfy_api.internal.async_to_sync import create_sync_class -from comfy_api.latest import io, ui, IO, UI, ComfyExtension #noqa: F401 +from comfy_api.latest import io, ui, IO, UI, ComfyExtension, node_replace #noqa: F401 class ComfyAPIAdapter_v0_0_2(ComfyAPI_latest): @@ -46,4 +46,5 @@ __all__ = [ "IO", "ui", "UI", + "node_replace", ] diff --git a/comfy_extras/nodes_post_processing.py b/comfy_extras/nodes_post_processing.py index a52a90e2c..573cbb5b7 100644 --- a/comfy_extras/nodes_post_processing.py +++ b/comfy_extras/nodes_post_processing.py @@ -655,6 +655,54 @@ class BatchImagesMasksLatentsNode(io.ComfyNode): batched = batch_masks(values) return io.NodeOutput(batched) + +from comfy_api.latest import node_replace + +def register_replacements(): + register_replacements_longeredge() + register_replacements_batchimages() + register_replacements_upscaleimage() + +def register_replacements_longeredge(): + # No dynamic inputs here + node_replace.register_node_replacement(node_replace.NodeReplace( + new_node_id="ImageScaleToMaxDimension", + old_node_id="ResizeImagesByLongerEdge", + old_widget_ids=["longer_edge"], + input_mapping=[ + node_replace.InputMap(new_id="image", assign=node_replace.InputMap.OldId("images")), + node_replace.InputMap(new_id="largest_size", assign=node_replace.InputMap.OldId("longer_edge")), + node_replace.InputMap(new_id="upscale_method", assign=node_replace.InputMap.SetValue("lanczos")), + ], + # just to test the frontend output_mapping code, does nothing really here + output_mapping=[node_replace.OutputMap(new_idx=0, old_idx=0)], + )) + +def register_replacements_batchimages(): + # BatchImages node uses Autogrow + node_replace.register_node_replacement(node_replace.NodeReplace( + new_node_id="BatchImagesNode", + old_node_id="ImageBatch", + input_mapping=[ + node_replace.InputMap(new_id="images.image0", assign=node_replace.InputMap.OldId("image1")), + node_replace.InputMap(new_id="images.image1", assign=node_replace.InputMap.OldId("image2")), + ], + )) + +def register_replacements_upscaleimage(): + # ResizeImageMaskNode uses DynamicCombo + node_replace.register_node_replacement(node_replace.NodeReplace( + new_node_id="ResizeImageMaskNode", + old_node_id="ImageScaleBy", + old_widget_ids=["upscale_method", "scale_by"], + input_mapping=[ + node_replace.InputMap(new_id="input", assign=node_replace.InputMap.OldId("image")), + node_replace.InputMap(new_id="resize_type", assign=node_replace.InputMap.SetValue("scale by multiplier")), + node_replace.InputMap(new_id="resize_type.multiplier", assign=node_replace.InputMap.OldId("scale_by")), + node_replace.InputMap(new_id="scale_method", assign=node_replace.InputMap.OldId("upscale_method")), + ], + )) + class PostProcessingExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[io.ComfyNode]]: diff --git a/server.py b/server.py index 2300393b2..362d06e86 100644 --- a/server.py +++ b/server.py @@ -40,6 +40,7 @@ from app.user_manager import UserManager from app.model_manager import ModelFileManager from app.custom_node_manager import CustomNodeManager from app.subgraph_manager import SubgraphManager +from app.node_replace_manager import NodeReplaceManager from typing import Optional, Union from api_server.routes.internal.internal_routes import InternalRoutes from protocol import BinaryEventTypes @@ -204,6 +205,7 @@ class PromptServer(): self.model_file_manager = ModelFileManager() self.custom_node_manager = CustomNodeManager() self.subgraph_manager = SubgraphManager() + self.node_replace_manager = NodeReplaceManager() self.internal_routes = InternalRoutes(self) self.supports = ["custom_nodes_from_web"] self.prompt_queue = execution.PromptQueue(self) @@ -995,6 +997,7 @@ class PromptServer(): self.model_file_manager.add_routes(self.routes) self.custom_node_manager.add_routes(self.routes, self.app, nodes.LOADED_MODULE_DIRS.items()) self.subgraph_manager.add_routes(self.routes, nodes.LOADED_MODULE_DIRS.items()) + self.node_replace_manager.add_routes(self.routes) self.app.add_subapp('/internal', self.internal_routes.get_app()) # Prefix every route with /api for easier matching for delegation.