diff --git a/app/node_replace_manager.py b/app/node_replace_manager.py index d9aab5b22..8be55c10c 100644 --- a/app/node_replace_manager.py +++ b/app/node_replace_manager.py @@ -1,5 +1,9 @@ from __future__ import annotations +import json +import logging +import os + from aiohttp import web from typing import TYPE_CHECKING, TypedDict @@ -7,7 +11,6 @@ if TYPE_CHECKING: from comfy_api.latest._io_public import NodeReplace from comfy_execution.graph_utils import is_link -import nodes class NodeStruct(TypedDict): inputs: dict[str, str | int | float | bool | tuple[str, int]] @@ -43,6 +46,7 @@ class NodeReplaceManager: return old_node_id in self._replacements def apply_replacements(self, prompt: dict[str, NodeStruct]): + import nodes connections: dict[str, list[tuple[str, str, int]]] = {} need_replacement: set[str] = set() for node_number, node_struct in prompt.items(): @@ -94,6 +98,60 @@ class NodeReplaceManager: previous_input = prompt[conn_node_number]["inputs"][conn_input_id] previous_input[1] = new_output_idx + def load_from_json(self, module_dir: str, module_name: str, _node_replace_class=None): + """Load node_replacements.json from a custom node directory and register replacements. + + Custom node authors can ship a node_replacements.json file in their repo root + to define node replacements declaratively. The file format matches the output + of NodeReplace.as_dict(), keyed by old_node_id. + + Fail-open: all errors are logged and skipped so a malformed file never + prevents the custom node from loading. + """ + replacements_path = os.path.join(module_dir, "node_replacements.json") + if not os.path.isfile(replacements_path): + return + + try: + with open(replacements_path, "r", encoding="utf-8") as f: + data = json.load(f) + + if not isinstance(data, dict): + logging.warning(f"node_replacements.json in {module_name} must be a JSON object, skipping.") + return + + if _node_replace_class is None: + from comfy_api.latest._io import NodeReplace + _node_replace_class = NodeReplace + + count = 0 + for old_node_id, replacements in data.items(): + if not isinstance(replacements, list): + logging.warning(f"node_replacements.json in {module_name}: value for '{old_node_id}' must be a list, skipping.") + continue + for entry in replacements: + if not isinstance(entry, dict): + continue + new_node_id = entry.get("new_node_id", "") + if not new_node_id: + logging.warning(f"node_replacements.json in {module_name}: entry for '{old_node_id}' missing 'new_node_id', skipping.") + continue + self.register(_node_replace_class( + new_node_id=new_node_id, + old_node_id=entry.get("old_node_id", old_node_id), + old_widget_ids=entry.get("old_widget_ids"), + input_mapping=entry.get("input_mapping"), + output_mapping=entry.get("output_mapping"), + )) + count += 1 + + if count > 0: + logging.info(f"Loaded {count} node replacement(s) from {module_name}/node_replacements.json") + except json.JSONDecodeError as e: + logging.warning(f"Failed to parse node_replacements.json in {module_name}: {e}") + except Exception as e: + logging.warning(f"Failed to load node_replacements.json from {module_name}: {e}") + def as_dict(self): """Serialize all replacements to dict.""" return { diff --git a/nodes.py b/nodes.py index c6e829eac..b3edf71b1 100644 --- a/nodes.py +++ b/nodes.py @@ -2202,58 +2202,6 @@ def get_module_name(module_path: str) -> str: return base_path -def load_node_replacements_json(module_dir: str, module_name: str): - """Load node_replacements.json from a custom node directory and register replacements. - - Custom node authors can ship a node_replacements.json file in their repo root - to define node replacements declaratively, without writing Python registration code. - The file format matches the output of NodeReplace.as_dict(), keyed by old_node_id. - """ - replacements_path = os.path.join(module_dir, "node_replacements.json") - if not os.path.isfile(replacements_path): - return - - try: - with open(replacements_path, "r", encoding="utf-8") as f: - data = json.load(f) - - if not isinstance(data, dict): - logging.warning(f"node_replacements.json in {module_name} must be a JSON object, skipping.") - return - - from server import PromptServer - from comfy_api.latest._io import NodeReplace - - manager = PromptServer.instance.node_replace_manager - count = 0 - for old_node_id, replacements in data.items(): - if not isinstance(replacements, list): - logging.warning(f"node_replacements.json in {module_name}: value for '{old_node_id}' must be a list, skipping.") - continue - for entry in replacements: - if not isinstance(entry, dict): - continue - new_node_id = entry.get("new_node_id", "") - if not new_node_id: - logging.warning(f"node_replacements.json in {module_name}: entry for '{old_node_id}' missing 'new_node_id', skipping.") - continue - manager.register(NodeReplace( - new_node_id=new_node_id, - old_node_id=entry.get("old_node_id", old_node_id), - old_widget_ids=entry.get("old_widget_ids"), - input_mapping=entry.get("input_mapping"), - output_mapping=entry.get("output_mapping"), - )) - count += 1 - - if count > 0: - logging.info(f"Loaded {count} node replacement(s) from {module_name}/node_replacements.json") - except json.JSONDecodeError as e: - logging.warning(f"Failed to parse node_replacements.json in {module_name}: {e}") - except Exception as e: - logging.warning(f"Failed to load node_replacements.json from {module_name}: {e}") - - async def load_custom_node(module_path: str, ignore=set(), module_parent="custom_nodes") -> bool: module_name = get_module_name(module_path) if os.path.isfile(module_path): @@ -2281,7 +2229,8 @@ async def load_custom_node(module_path: str, ignore=set(), module_parent="custom # Only load node_replacements.json from directory-based custom nodes (proper packs). # Single-file .py nodes share a parent dir, so checking there would be incorrect. if os.path.isdir(module_path): - load_node_replacements_json(module_dir, module_name) + from server import PromptServer + PromptServer.instance.node_replace_manager.load_from_json(module_dir, module_name) try: from comfy_config import config_parser diff --git a/tests/test_node_replacements_json.py b/tests/test_node_replacements_json.py index 101c4caa4..c773c3302 100644 --- a/tests/test_node_replacements_json.py +++ b/tests/test_node_replacements_json.py @@ -1,17 +1,15 @@ -"""Tests for auto-registration of node_replacements.json from custom node directories.""" +"""Tests for NodeReplaceManager.load_from_json — auto-registration of +node_replacements.json from custom node directories.""" import json import os import tempfile import unittest -from unittest.mock import MagicMock -# We can't import nodes.py directly (torch dependency), so we test the -# load_node_replacements_json logic by re-creating it from the same source. -# This validates the JSON parsing and NodeReplace construction logic. +from app.node_replace_manager import NodeReplaceManager -class MockNodeReplace: - """Mirrors comfy_api.latest._io.NodeReplace for testing.""" +class SimpleNodeReplace: + """Lightweight stand-in for comfy_api.latest._io.NodeReplace (avoids torch import).""" def __init__(self, new_node_id, old_node_id, old_widget_ids=None, input_mapping=None, output_mapping=None): self.new_node_id = new_node_id @@ -20,57 +18,22 @@ class MockNodeReplace: self.input_mapping = input_mapping self.output_mapping = output_mapping - -def load_node_replacements_json(module_dir, module_name, manager, NodeReplace=MockNodeReplace): - """Standalone version of the function from nodes.py for testing.""" - import logging - replacements_path = os.path.join(module_dir, "node_replacements.json") - if not os.path.isfile(replacements_path): - return - - try: - with open(replacements_path, "r", encoding="utf-8") as f: - data = json.load(f) - - if not isinstance(data, dict): - logging.warning(f"node_replacements.json in {module_name} must be a JSON object, skipping.") - return - - count = 0 - for old_node_id, replacements in data.items(): - if not isinstance(replacements, list): - logging.warning(f"node_replacements.json in {module_name}: value for '{old_node_id}' must be a list, skipping.") - continue - for entry in replacements: - if not isinstance(entry, dict): - continue - new_node_id = entry.get("new_node_id", "") - if not new_node_id: - logging.warning(f"node_replacements.json in {module_name}: entry for '{old_node_id}' missing 'new_node_id', skipping.") - continue - manager.register(NodeReplace( - new_node_id=new_node_id, - old_node_id=entry.get("old_node_id", old_node_id), - old_widget_ids=entry.get("old_widget_ids"), - input_mapping=entry.get("input_mapping"), - output_mapping=entry.get("output_mapping"), - )) - count += 1 - - if count > 0: - logging.info(f"Loaded {count} node replacement(s) from {module_name}/node_replacements.json") - except json.JSONDecodeError as e: - logging.warning(f"Failed to parse node_replacements.json in {module_name}: {e}") - except Exception as e: - logging.warning(f"Failed to load node_replacements.json from {module_name}: {e}") + def as_dict(self): + return { + "new_node_id": self.new_node_id, + "old_node_id": self.old_node_id, + "old_widget_ids": self.old_widget_ids, + "input_mapping": list(self.input_mapping) if self.input_mapping else None, + "output_mapping": list(self.output_mapping) if self.output_mapping else None, + } -class TestLoadNodeReplacementsJson(unittest.TestCase): +class TestLoadFromJson(unittest.TestCase): """Test auto-registration of node_replacements.json from custom node directories.""" def setUp(self): self.tmpdir = tempfile.mkdtemp() - self.mock_manager = MagicMock() + self.manager = NodeReplaceManager() def _write_json(self, data): path = os.path.join(self.tmpdir, "node_replacements.json") @@ -78,18 +41,18 @@ class TestLoadNodeReplacementsJson(unittest.TestCase): json.dump(data, f) def _load(self): - load_node_replacements_json(self.tmpdir, "test-node-pack", self.mock_manager) + self.manager.load_from_json(self.tmpdir, "test-node-pack", _node_replace_class=SimpleNodeReplace) def test_no_file_does_nothing(self): """No node_replacements.json — should silently do nothing.""" self._load() - self.mock_manager.register.assert_not_called() + self.assertEqual(self.manager.as_dict(), {}) def test_empty_object(self): """Empty {} — should do nothing.""" self._write_json({}) self._load() - self.mock_manager.register.assert_not_called() + self.assertEqual(self.manager.as_dict(), {}) def test_single_replacement(self): """Single replacement entry registers correctly.""" @@ -102,12 +65,14 @@ class TestLoadNodeReplacementsJson(unittest.TestCase): }] }) self._load() - self.mock_manager.register.assert_called_once() - registered = self.mock_manager.register.call_args[0][0] - self.assertEqual(registered.new_node_id, "NewNode") - self.assertEqual(registered.old_node_id, "OldNode") - self.assertEqual(registered.input_mapping, [{"new_id": "model", "old_id": "ckpt_name"}]) - self.assertEqual(registered.output_mapping, [{"new_idx": 0, "old_idx": 0}]) + result = self.manager.as_dict() + self.assertIn("OldNode", result) + self.assertEqual(len(result["OldNode"]), 1) + entry = result["OldNode"][0] + self.assertEqual(entry["new_node_id"], "NewNode") + self.assertEqual(entry["old_node_id"], "OldNode") + self.assertEqual(entry["input_mapping"], [{"new_id": "model", "old_id": "ckpt_name"}]) + self.assertEqual(entry["output_mapping"], [{"new_idx": 0, "old_idx": 0}]) def test_multiple_replacements(self): """Multiple old_node_ids each with entries.""" @@ -116,7 +81,10 @@ class TestLoadNodeReplacementsJson(unittest.TestCase): "NodeC": [{"new_node_id": "NodeD", "old_node_id": "NodeC"}], }) self._load() - self.assertEqual(self.mock_manager.register.call_count, 2) + result = self.manager.as_dict() + self.assertEqual(len(result), 2) + self.assertIn("NodeA", result) + self.assertIn("NodeC", result) def test_multiple_alternatives_for_same_node(self): """Multiple replacement options for the same old node.""" @@ -127,7 +95,8 @@ class TestLoadNodeReplacementsJson(unittest.TestCase): ] }) self._load() - self.assertEqual(self.mock_manager.register.call_count, 2) + result = self.manager.as_dict() + self.assertEqual(len(result["OldNode"]), 2) def test_null_mappings(self): """Null input/output mappings (trivial replacement).""" @@ -140,9 +109,9 @@ class TestLoadNodeReplacementsJson(unittest.TestCase): }] }) self._load() - registered = self.mock_manager.register.call_args[0][0] - self.assertIsNone(registered.input_mapping) - self.assertIsNone(registered.output_mapping) + entry = self.manager.as_dict()["OldNode"][0] + self.assertIsNone(entry["input_mapping"]) + self.assertIsNone(entry["output_mapping"]) def test_old_node_id_defaults_to_key(self): """If old_node_id is missing from entry, uses the dict key.""" @@ -150,8 +119,8 @@ class TestLoadNodeReplacementsJson(unittest.TestCase): "OldNode": [{"new_node_id": "NewNode"}] }) self._load() - registered = self.mock_manager.register.call_args[0][0] - self.assertEqual(registered.old_node_id, "OldNode") + entry = self.manager.as_dict()["OldNode"][0] + self.assertEqual(entry["old_node_id"], "OldNode") def test_invalid_json_skips(self): """Invalid JSON file — should warn and skip, not crash.""" @@ -159,13 +128,13 @@ class TestLoadNodeReplacementsJson(unittest.TestCase): with open(path, "w") as f: f.write("{invalid json") self._load() - self.mock_manager.register.assert_not_called() + self.assertEqual(self.manager.as_dict(), {}) def test_non_object_json_skips(self): """JSON array instead of object — should warn and skip.""" self._write_json([1, 2, 3]) self._load() - self.mock_manager.register.assert_not_called() + self.assertEqual(self.manager.as_dict(), {}) def test_non_list_value_skips(self): """Value is not a list — should warn and skip that key.""" @@ -174,7 +143,9 @@ class TestLoadNodeReplacementsJson(unittest.TestCase): "GoodNode": [{"new_node_id": "NewNode", "old_node_id": "GoodNode"}], }) self._load() - self.assertEqual(self.mock_manager.register.call_count, 1) + result = self.manager.as_dict() + self.assertNotIn("OldNode", result) + self.assertIn("GoodNode", result) def test_with_old_widget_ids(self): """old_widget_ids are passed through.""" @@ -186,8 +157,8 @@ class TestLoadNodeReplacementsJson(unittest.TestCase): }] }) self._load() - registered = self.mock_manager.register.call_args[0][0] - self.assertEqual(registered.old_widget_ids, ["width", "height"]) + entry = self.manager.as_dict()["OldNode"][0] + self.assertEqual(entry["old_widget_ids"], ["width", "height"]) def test_set_value_in_input_mapping(self): """input_mapping with set_value entries.""" @@ -202,10 +173,8 @@ class TestLoadNodeReplacementsJson(unittest.TestCase): }] }) self._load() - registered = self.mock_manager.register.call_args[0][0] - self.assertEqual(len(registered.input_mapping), 2) - self.assertEqual(registered.input_mapping[0]["set_value"], "lanczos") - self.assertEqual(registered.input_mapping[1]["old_id"], "dimension") + entry = self.manager.as_dict()["OldNode"][0] + self.assertEqual(len(entry["input_mapping"]), 2) def test_missing_new_node_id_skipped(self): """Entry without new_node_id is skipped.""" @@ -217,9 +186,9 @@ class TestLoadNodeReplacementsJson(unittest.TestCase): ] }) self._load() - self.assertEqual(self.mock_manager.register.call_count, 1) - registered = self.mock_manager.register.call_args[0][0] - self.assertEqual(registered.new_node_id, "ValidNew") + result = self.manager.as_dict() + self.assertEqual(len(result["OldNode"]), 1) + self.assertEqual(result["OldNode"][0]["new_node_id"], "ValidNew") def test_non_dict_entry_skipped(self): """Non-dict entries in the list are silently skipped.""" @@ -230,7 +199,18 @@ class TestLoadNodeReplacementsJson(unittest.TestCase): ] }) self._load() - self.assertEqual(self.mock_manager.register.call_count, 1) + result = self.manager.as_dict() + self.assertEqual(len(result["OldNode"]), 1) + + def test_has_replacement_after_load(self): + """Manager reports has_replacement correctly after JSON load.""" + self._write_json({ + "OldNode": [{"new_node_id": "NewNode", "old_node_id": "OldNode"}], + }) + self.assertFalse(self.manager.has_replacement("OldNode")) + self._load() + self.assertTrue(self.manager.has_replacement("OldNode")) + self.assertFalse(self.manager.has_replacement("UnknownNode")) if __name__ == "__main__":