diff --git a/app/node_replace_manager.py b/app/node_replace_manager.py index d9aab5b22..8be55c10c 100644 --- a/app/node_replace_manager.py +++ b/app/node_replace_manager.py @@ -1,5 +1,9 @@ from __future__ import annotations +import json +import logging +import os + from aiohttp import web from typing import TYPE_CHECKING, TypedDict @@ -7,7 +11,6 @@ if TYPE_CHECKING: from comfy_api.latest._io_public import NodeReplace from comfy_execution.graph_utils import is_link -import nodes class NodeStruct(TypedDict): inputs: dict[str, str | int | float | bool | tuple[str, int]] @@ -43,6 +46,7 @@ class NodeReplaceManager: return old_node_id in self._replacements def apply_replacements(self, prompt: dict[str, NodeStruct]): + import nodes connections: dict[str, list[tuple[str, str, int]]] = {} need_replacement: set[str] = set() for node_number, node_struct in prompt.items(): @@ -94,6 +98,60 @@ class NodeReplaceManager: previous_input = prompt[conn_node_number]["inputs"][conn_input_id] previous_input[1] = new_output_idx + def load_from_json(self, module_dir: str, module_name: str, _node_replace_class=None): + """Load node_replacements.json from a custom node directory and register replacements. + + Custom node authors can ship a node_replacements.json file in their repo root + to define node replacements declaratively. The file format matches the output + of NodeReplace.as_dict(), keyed by old_node_id. + + Fail-open: all errors are logged and skipped so a malformed file never + prevents the custom node from loading. + """ + replacements_path = os.path.join(module_dir, "node_replacements.json") + if not os.path.isfile(replacements_path): + return + + try: + with open(replacements_path, "r", encoding="utf-8") as f: + data = json.load(f) + + if not isinstance(data, dict): + logging.warning(f"node_replacements.json in {module_name} must be a JSON object, skipping.") + return + + if _node_replace_class is None: + from comfy_api.latest._io import NodeReplace + _node_replace_class = NodeReplace + + count = 0 + for old_node_id, replacements in data.items(): + if not isinstance(replacements, list): + logging.warning(f"node_replacements.json in {module_name}: value for '{old_node_id}' must be a list, skipping.") + continue + for entry in replacements: + if not isinstance(entry, dict): + continue + new_node_id = entry.get("new_node_id", "") + if not new_node_id: + logging.warning(f"node_replacements.json in {module_name}: entry for '{old_node_id}' missing 'new_node_id', skipping.") + continue + self.register(_node_replace_class( + new_node_id=new_node_id, + old_node_id=entry.get("old_node_id", old_node_id), + old_widget_ids=entry.get("old_widget_ids"), + input_mapping=entry.get("input_mapping"), + output_mapping=entry.get("output_mapping"), + )) + count += 1 + + if count > 0: + logging.info(f"Loaded {count} node replacement(s) from {module_name}/node_replacements.json") + except json.JSONDecodeError as e: + logging.warning(f"Failed to parse node_replacements.json in {module_name}: {e}") + except Exception as e: + logging.warning(f"Failed to load node_replacements.json from {module_name}: {e}") + def as_dict(self): """Serialize all replacements to dict.""" return { diff --git a/nodes.py b/nodes.py index 299b3d758..4eb82cd45 100644 --- a/nodes.py +++ b/nodes.py @@ -2228,6 +2228,12 @@ async def load_custom_node(module_path: str, ignore=set(), module_parent="custom LOADED_MODULE_DIRS[module_name] = os.path.abspath(module_dir) + # Only load node_replacements.json from directory-based custom nodes (proper packs). + # Single-file .py nodes share a parent dir, so checking there would be incorrect. + if os.path.isdir(module_path): + from server import PromptServer + PromptServer.instance.node_replace_manager.load_from_json(module_dir, module_name) + try: from comfy_config import config_parser diff --git a/tests/test_node_replacements_json.py b/tests/test_node_replacements_json.py new file mode 100644 index 000000000..c773c3302 --- /dev/null +++ b/tests/test_node_replacements_json.py @@ -0,0 +1,217 @@ +"""Tests for NodeReplaceManager.load_from_json — auto-registration of +node_replacements.json from custom node directories.""" +import json +import os +import tempfile +import unittest + +from app.node_replace_manager import NodeReplaceManager + + +class SimpleNodeReplace: + """Lightweight stand-in for comfy_api.latest._io.NodeReplace (avoids torch import).""" + def __init__(self, new_node_id, old_node_id, old_widget_ids=None, + input_mapping=None, output_mapping=None): + self.new_node_id = new_node_id + self.old_node_id = old_node_id + self.old_widget_ids = old_widget_ids + self.input_mapping = input_mapping + self.output_mapping = output_mapping + + def as_dict(self): + return { + "new_node_id": self.new_node_id, + "old_node_id": self.old_node_id, + "old_widget_ids": self.old_widget_ids, + "input_mapping": list(self.input_mapping) if self.input_mapping else None, + "output_mapping": list(self.output_mapping) if self.output_mapping else None, + } + + +class TestLoadFromJson(unittest.TestCase): + """Test auto-registration of node_replacements.json from custom node directories.""" + + def setUp(self): + self.tmpdir = tempfile.mkdtemp() + self.manager = NodeReplaceManager() + + def _write_json(self, data): + path = os.path.join(self.tmpdir, "node_replacements.json") + with open(path, "w") as f: + json.dump(data, f) + + def _load(self): + self.manager.load_from_json(self.tmpdir, "test-node-pack", _node_replace_class=SimpleNodeReplace) + + def test_no_file_does_nothing(self): + """No node_replacements.json — should silently do nothing.""" + self._load() + self.assertEqual(self.manager.as_dict(), {}) + + def test_empty_object(self): + """Empty {} — should do nothing.""" + self._write_json({}) + self._load() + self.assertEqual(self.manager.as_dict(), {}) + + def test_single_replacement(self): + """Single replacement entry registers correctly.""" + self._write_json({ + "OldNode": [{ + "new_node_id": "NewNode", + "old_node_id": "OldNode", + "input_mapping": [{"new_id": "model", "old_id": "ckpt_name"}], + "output_mapping": [{"new_idx": 0, "old_idx": 0}], + }] + }) + self._load() + result = self.manager.as_dict() + self.assertIn("OldNode", result) + self.assertEqual(len(result["OldNode"]), 1) + entry = result["OldNode"][0] + self.assertEqual(entry["new_node_id"], "NewNode") + self.assertEqual(entry["old_node_id"], "OldNode") + self.assertEqual(entry["input_mapping"], [{"new_id": "model", "old_id": "ckpt_name"}]) + self.assertEqual(entry["output_mapping"], [{"new_idx": 0, "old_idx": 0}]) + + def test_multiple_replacements(self): + """Multiple old_node_ids each with entries.""" + self._write_json({ + "NodeA": [{"new_node_id": "NodeB", "old_node_id": "NodeA"}], + "NodeC": [{"new_node_id": "NodeD", "old_node_id": "NodeC"}], + }) + self._load() + result = self.manager.as_dict() + self.assertEqual(len(result), 2) + self.assertIn("NodeA", result) + self.assertIn("NodeC", result) + + def test_multiple_alternatives_for_same_node(self): + """Multiple replacement options for the same old node.""" + self._write_json({ + "OldNode": [ + {"new_node_id": "AltA", "old_node_id": "OldNode"}, + {"new_node_id": "AltB", "old_node_id": "OldNode"}, + ] + }) + self._load() + result = self.manager.as_dict() + self.assertEqual(len(result["OldNode"]), 2) + + def test_null_mappings(self): + """Null input/output mappings (trivial replacement).""" + self._write_json({ + "OldNode": [{ + "new_node_id": "NewNode", + "old_node_id": "OldNode", + "input_mapping": None, + "output_mapping": None, + }] + }) + self._load() + entry = self.manager.as_dict()["OldNode"][0] + self.assertIsNone(entry["input_mapping"]) + self.assertIsNone(entry["output_mapping"]) + + def test_old_node_id_defaults_to_key(self): + """If old_node_id is missing from entry, uses the dict key.""" + self._write_json({ + "OldNode": [{"new_node_id": "NewNode"}] + }) + self._load() + entry = self.manager.as_dict()["OldNode"][0] + self.assertEqual(entry["old_node_id"], "OldNode") + + def test_invalid_json_skips(self): + """Invalid JSON file — should warn and skip, not crash.""" + path = os.path.join(self.tmpdir, "node_replacements.json") + with open(path, "w") as f: + f.write("{invalid json") + self._load() + self.assertEqual(self.manager.as_dict(), {}) + + def test_non_object_json_skips(self): + """JSON array instead of object — should warn and skip.""" + self._write_json([1, 2, 3]) + self._load() + self.assertEqual(self.manager.as_dict(), {}) + + def test_non_list_value_skips(self): + """Value is not a list — should warn and skip that key.""" + self._write_json({ + "OldNode": "not a list", + "GoodNode": [{"new_node_id": "NewNode", "old_node_id": "GoodNode"}], + }) + self._load() + result = self.manager.as_dict() + self.assertNotIn("OldNode", result) + self.assertIn("GoodNode", result) + + def test_with_old_widget_ids(self): + """old_widget_ids are passed through.""" + self._write_json({ + "OldNode": [{ + "new_node_id": "NewNode", + "old_node_id": "OldNode", + "old_widget_ids": ["width", "height"], + }] + }) + self._load() + entry = self.manager.as_dict()["OldNode"][0] + self.assertEqual(entry["old_widget_ids"], ["width", "height"]) + + def test_set_value_in_input_mapping(self): + """input_mapping with set_value entries.""" + self._write_json({ + "OldNode": [{ + "new_node_id": "NewNode", + "old_node_id": "OldNode", + "input_mapping": [ + {"new_id": "method", "set_value": "lanczos"}, + {"new_id": "size", "old_id": "dimension"}, + ], + }] + }) + self._load() + entry = self.manager.as_dict()["OldNode"][0] + self.assertEqual(len(entry["input_mapping"]), 2) + + def test_missing_new_node_id_skipped(self): + """Entry without new_node_id is skipped.""" + self._write_json({ + "OldNode": [ + {"old_node_id": "OldNode"}, + {"new_node_id": "", "old_node_id": "OldNode"}, + {"new_node_id": "ValidNew", "old_node_id": "OldNode"}, + ] + }) + self._load() + result = self.manager.as_dict() + self.assertEqual(len(result["OldNode"]), 1) + self.assertEqual(result["OldNode"][0]["new_node_id"], "ValidNew") + + def test_non_dict_entry_skipped(self): + """Non-dict entries in the list are silently skipped.""" + self._write_json({ + "OldNode": [ + "not a dict", + {"new_node_id": "NewNode", "old_node_id": "OldNode"}, + ] + }) + self._load() + result = self.manager.as_dict() + self.assertEqual(len(result["OldNode"]), 1) + + def test_has_replacement_after_load(self): + """Manager reports has_replacement correctly after JSON load.""" + self._write_json({ + "OldNode": [{"new_node_id": "NewNode", "old_node_id": "OldNode"}], + }) + self.assertFalse(self.manager.has_replacement("OldNode")) + self._load() + self.assertTrue(self.manager.has_replacement("OldNode")) + self.assertFalse(self.manager.has_replacement("UnknownNode")) + + +if __name__ == "__main__": + unittest.main()