From d731cb6ae15660db0bd6d796f94809a651886b0f Mon Sep 17 00:00:00 2001 From: Deep Mehta Date: Tue, 17 Mar 2026 20:57:32 -0700 Subject: [PATCH 1/3] feat: auto-register node replacements from custom node JSON files MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Custom node authors can now ship a `node_replacements.json` in their repo root to define replacements declaratively. During node loading, ComfyUI reads these files and registers entries via the existing NodeReplaceManager — no Python registration code needed. This enables two use cases: 1. Authors deprecate/rename nodes with a migration path for old workflows 2. Authors offer their nodes as drop-in replacements for other packs --- nodes.py | 50 ++++++ tests/test_node_replacements_json.py | 219 +++++++++++++++++++++++++++ 2 files changed, 269 insertions(+) create mode 100644 tests/test_node_replacements_json.py diff --git a/nodes.py b/nodes.py index 03dcc9d4a..ad2c27077 100644 --- a/nodes.py +++ b/nodes.py @@ -2202,6 +2202,54 @@ def get_module_name(module_path: str) -> str: return base_path +def load_node_replacements_json(module_dir: str, module_name: str): + """Load node_replacements.json from a custom node directory and register replacements. + + Custom node authors can ship a node_replacements.json file in their repo root + to define node replacements declaratively, without writing Python registration code. + The file format matches the output of NodeReplace.as_dict(), keyed by old_node_id. + """ + replacements_path = os.path.join(module_dir, "node_replacements.json") + if not os.path.isfile(replacements_path): + return + + try: + with open(replacements_path, "r", encoding="utf-8") as f: + data = json.load(f) + + if not isinstance(data, dict): + logging.warning(f"node_replacements.json in {module_name} must be a JSON object, skipping.") + return + + from server import PromptServer + from comfy_api.latest._io import NodeReplace + + manager = PromptServer.instance.node_replace_manager + count = 0 + for old_node_id, replacements in data.items(): + if not isinstance(replacements, list): + logging.warning(f"node_replacements.json in {module_name}: value for '{old_node_id}' must be a list, skipping.") + continue + for entry in replacements: + if not isinstance(entry, dict): + continue + manager.register(NodeReplace( + new_node_id=entry.get("new_node_id", ""), + old_node_id=entry.get("old_node_id", old_node_id), + old_widget_ids=entry.get("old_widget_ids"), + input_mapping=entry.get("input_mapping"), + output_mapping=entry.get("output_mapping"), + )) + count += 1 + + if count > 0: + logging.info(f"Loaded {count} node replacement(s) from {module_name}/node_replacements.json") + except json.JSONDecodeError as e: + logging.warning(f"Failed to parse node_replacements.json in {module_name}: {e}") + except Exception as e: + logging.warning(f"Failed to load node_replacements.json from {module_name}: {e}") + + async def load_custom_node(module_path: str, ignore=set(), module_parent="custom_nodes") -> bool: module_name = get_module_name(module_path) if os.path.isfile(module_path): @@ -2226,6 +2274,8 @@ async def load_custom_node(module_path: str, ignore=set(), module_parent="custom LOADED_MODULE_DIRS[module_name] = os.path.abspath(module_dir) + load_node_replacements_json(module_dir, module_name) + try: from comfy_config import config_parser diff --git a/tests/test_node_replacements_json.py b/tests/test_node_replacements_json.py new file mode 100644 index 000000000..ec601f0b3 --- /dev/null +++ b/tests/test_node_replacements_json.py @@ -0,0 +1,219 @@ +"""Tests for auto-registration of node_replacements.json from custom node directories.""" +import json +import os +import tempfile +import unittest +from unittest.mock import MagicMock + +# We can't import nodes.py directly (torch dependency), so we test the +# load_node_replacements_json logic by re-creating it from the same source. +# This validates the JSON parsing and NodeReplace construction logic. + + +class MockNodeReplace: + """Mirrors comfy_api.latest._io.NodeReplace for testing.""" + def __init__(self, new_node_id, old_node_id, old_widget_ids=None, + input_mapping=None, output_mapping=None): + self.new_node_id = new_node_id + self.old_node_id = old_node_id + self.old_widget_ids = old_widget_ids + self.input_mapping = input_mapping + self.output_mapping = output_mapping + + +def load_node_replacements_json(module_dir, module_name, manager, NodeReplace=MockNodeReplace): + """Standalone version of the function from nodes.py for testing.""" + import logging + replacements_path = os.path.join(module_dir, "node_replacements.json") + if not os.path.isfile(replacements_path): + return + + try: + with open(replacements_path, "r", encoding="utf-8") as f: + data = json.load(f) + + if not isinstance(data, dict): + logging.warning(f"node_replacements.json in {module_name} must be a JSON object, skipping.") + return + + count = 0 + for old_node_id, replacements in data.items(): + if not isinstance(replacements, list): + logging.warning(f"node_replacements.json in {module_name}: value for '{old_node_id}' must be a list, skipping.") + continue + for entry in replacements: + if not isinstance(entry, dict): + continue + manager.register(NodeReplace( + new_node_id=entry.get("new_node_id", ""), + old_node_id=entry.get("old_node_id", old_node_id), + old_widget_ids=entry.get("old_widget_ids"), + input_mapping=entry.get("input_mapping"), + output_mapping=entry.get("output_mapping"), + )) + count += 1 + + if count > 0: + logging.info(f"Loaded {count} node replacement(s) from {module_name}/node_replacements.json") + except json.JSONDecodeError as e: + logging.warning(f"Failed to parse node_replacements.json in {module_name}: {e}") + except Exception as e: + logging.warning(f"Failed to load node_replacements.json from {module_name}: {e}") + + +class TestLoadNodeReplacementsJson(unittest.TestCase): + """Test auto-registration of node_replacements.json from custom node directories.""" + + def setUp(self): + self.tmpdir = tempfile.mkdtemp() + self.mock_manager = MagicMock() + + def _write_json(self, data): + path = os.path.join(self.tmpdir, "node_replacements.json") + with open(path, "w") as f: + json.dump(data, f) + + def _load(self): + load_node_replacements_json(self.tmpdir, "test-node-pack", self.mock_manager) + + def test_no_file_does_nothing(self): + """No node_replacements.json — should silently do nothing.""" + self._load() + self.mock_manager.register.assert_not_called() + + def test_empty_object(self): + """Empty {} — should do nothing.""" + self._write_json({}) + self._load() + self.mock_manager.register.assert_not_called() + + def test_single_replacement(self): + """Single replacement entry registers correctly.""" + self._write_json({ + "OldNode": [{ + "new_node_id": "NewNode", + "old_node_id": "OldNode", + "input_mapping": [{"new_id": "model", "old_id": "ckpt_name"}], + "output_mapping": [{"new_idx": 0, "old_idx": 0}], + }] + }) + self._load() + self.mock_manager.register.assert_called_once() + registered = self.mock_manager.register.call_args[0][0] + self.assertEqual(registered.new_node_id, "NewNode") + self.assertEqual(registered.old_node_id, "OldNode") + self.assertEqual(registered.input_mapping, [{"new_id": "model", "old_id": "ckpt_name"}]) + self.assertEqual(registered.output_mapping, [{"new_idx": 0, "old_idx": 0}]) + + def test_multiple_replacements(self): + """Multiple old_node_ids each with entries.""" + self._write_json({ + "NodeA": [{"new_node_id": "NodeB", "old_node_id": "NodeA"}], + "NodeC": [{"new_node_id": "NodeD", "old_node_id": "NodeC"}], + }) + self._load() + self.assertEqual(self.mock_manager.register.call_count, 2) + + def test_multiple_alternatives_for_same_node(self): + """Multiple replacement options for the same old node.""" + self._write_json({ + "OldNode": [ + {"new_node_id": "AltA", "old_node_id": "OldNode"}, + {"new_node_id": "AltB", "old_node_id": "OldNode"}, + ] + }) + self._load() + self.assertEqual(self.mock_manager.register.call_count, 2) + + def test_null_mappings(self): + """Null input/output mappings (trivial replacement).""" + self._write_json({ + "OldNode": [{ + "new_node_id": "NewNode", + "old_node_id": "OldNode", + "input_mapping": None, + "output_mapping": None, + }] + }) + self._load() + registered = self.mock_manager.register.call_args[0][0] + self.assertIsNone(registered.input_mapping) + self.assertIsNone(registered.output_mapping) + + def test_old_node_id_defaults_to_key(self): + """If old_node_id is missing from entry, uses the dict key.""" + self._write_json({ + "OldNode": [{"new_node_id": "NewNode"}] + }) + self._load() + registered = self.mock_manager.register.call_args[0][0] + self.assertEqual(registered.old_node_id, "OldNode") + + def test_invalid_json_skips(self): + """Invalid JSON file — should warn and skip, not crash.""" + path = os.path.join(self.tmpdir, "node_replacements.json") + with open(path, "w") as f: + f.write("{invalid json") + self._load() + self.mock_manager.register.assert_not_called() + + def test_non_object_json_skips(self): + """JSON array instead of object — should warn and skip.""" + self._write_json([1, 2, 3]) + self._load() + self.mock_manager.register.assert_not_called() + + def test_non_list_value_skips(self): + """Value is not a list — should warn and skip that key.""" + self._write_json({ + "OldNode": "not a list", + "GoodNode": [{"new_node_id": "NewNode", "old_node_id": "GoodNode"}], + }) + self._load() + self.assertEqual(self.mock_manager.register.call_count, 1) + + def test_with_old_widget_ids(self): + """old_widget_ids are passed through.""" + self._write_json({ + "OldNode": [{ + "new_node_id": "NewNode", + "old_node_id": "OldNode", + "old_widget_ids": ["width", "height"], + }] + }) + self._load() + registered = self.mock_manager.register.call_args[0][0] + self.assertEqual(registered.old_widget_ids, ["width", "height"]) + + def test_set_value_in_input_mapping(self): + """input_mapping with set_value entries.""" + self._write_json({ + "OldNode": [{ + "new_node_id": "NewNode", + "old_node_id": "OldNode", + "input_mapping": [ + {"new_id": "method", "set_value": "lanczos"}, + {"new_id": "size", "old_id": "dimension"}, + ], + }] + }) + self._load() + registered = self.mock_manager.register.call_args[0][0] + self.assertEqual(len(registered.input_mapping), 2) + self.assertEqual(registered.input_mapping[0]["set_value"], "lanczos") + self.assertEqual(registered.input_mapping[1]["old_id"], "dimension") + + def test_non_dict_entry_skipped(self): + """Non-dict entries in the list are silently skipped.""" + self._write_json({ + "OldNode": [ + "not a dict", + {"new_node_id": "NewNode", "old_node_id": "OldNode"}, + ] + }) + self._load() + self.assertEqual(self.mock_manager.register.call_count, 1) + + +if __name__ == "__main__": + unittest.main() From 62ec9a32389f4f7e46ffa40d1b153db8e013dc3f Mon Sep 17 00:00:00 2001 From: Deep Mehta Date: Mon, 23 Mar 2026 14:47:03 -0700 Subject: [PATCH 2/3] fix: skip single-file nodes and validate new_node_id Two fixes from code review: 1. Only load node_replacements.json from directory-based custom nodes. Single-file .py nodes share a parent dir (custom_nodes/), so checking there would incorrectly pick up a stray file. 2. Skip entries with missing or empty new_node_id instead of registering a replacement pointing to nothing. --- nodes.py | 11 +++++++++-- tests/test_node_replacements_json.py | 20 +++++++++++++++++++- 2 files changed, 28 insertions(+), 3 deletions(-) diff --git a/nodes.py b/nodes.py index faf908538..c6e829eac 100644 --- a/nodes.py +++ b/nodes.py @@ -2233,8 +2233,12 @@ def load_node_replacements_json(module_dir: str, module_name: str): for entry in replacements: if not isinstance(entry, dict): continue + new_node_id = entry.get("new_node_id", "") + if not new_node_id: + logging.warning(f"node_replacements.json in {module_name}: entry for '{old_node_id}' missing 'new_node_id', skipping.") + continue manager.register(NodeReplace( - new_node_id=entry.get("new_node_id", ""), + new_node_id=new_node_id, old_node_id=entry.get("old_node_id", old_node_id), old_widget_ids=entry.get("old_widget_ids"), input_mapping=entry.get("input_mapping"), @@ -2274,7 +2278,10 @@ async def load_custom_node(module_path: str, ignore=set(), module_parent="custom LOADED_MODULE_DIRS[module_name] = os.path.abspath(module_dir) - load_node_replacements_json(module_dir, module_name) + # Only load node_replacements.json from directory-based custom nodes (proper packs). + # Single-file .py nodes share a parent dir, so checking there would be incorrect. + if os.path.isdir(module_path): + load_node_replacements_json(module_dir, module_name) try: from comfy_config import config_parser diff --git a/tests/test_node_replacements_json.py b/tests/test_node_replacements_json.py index ec601f0b3..101c4caa4 100644 --- a/tests/test_node_replacements_json.py +++ b/tests/test_node_replacements_json.py @@ -44,8 +44,12 @@ def load_node_replacements_json(module_dir, module_name, manager, NodeReplace=Mo for entry in replacements: if not isinstance(entry, dict): continue + new_node_id = entry.get("new_node_id", "") + if not new_node_id: + logging.warning(f"node_replacements.json in {module_name}: entry for '{old_node_id}' missing 'new_node_id', skipping.") + continue manager.register(NodeReplace( - new_node_id=entry.get("new_node_id", ""), + new_node_id=new_node_id, old_node_id=entry.get("old_node_id", old_node_id), old_widget_ids=entry.get("old_widget_ids"), input_mapping=entry.get("input_mapping"), @@ -203,6 +207,20 @@ class TestLoadNodeReplacementsJson(unittest.TestCase): self.assertEqual(registered.input_mapping[0]["set_value"], "lanczos") self.assertEqual(registered.input_mapping[1]["old_id"], "dimension") + def test_missing_new_node_id_skipped(self): + """Entry without new_node_id is skipped.""" + self._write_json({ + "OldNode": [ + {"old_node_id": "OldNode"}, + {"new_node_id": "", "old_node_id": "OldNode"}, + {"new_node_id": "ValidNew", "old_node_id": "OldNode"}, + ] + }) + self._load() + self.assertEqual(self.mock_manager.register.call_count, 1) + registered = self.mock_manager.register.call_args[0][0] + self.assertEqual(registered.new_node_id, "ValidNew") + def test_non_dict_entry_skipped(self): """Non-dict entries in the list are silently skipped.""" self._write_json({ From 9837dd368ac289ef19c1d02d513a3a4ec4d67b80 Mon Sep 17 00:00:00 2001 From: Deep Mehta Date: Wed, 25 Mar 2026 22:12:23 -0700 Subject: [PATCH 3/3] refactor: move load_from_json into NodeReplaceManager Address review feedback from Kosinkadink: 1. Move JSON loading logic from nodes.py into NodeReplaceManager as load_from_json() method for better encapsulation and testability 2. Tests now exercise the real NodeReplaceManager (no duplicated logic) 3. Defer `import nodes` in apply_replacements to avoid torch at import 4. nodes.py call site simplified to one line: PromptServer.instance.node_replace_manager.load_from_json(...) --- app/node_replace_manager.py | 60 ++++++++++- nodes.py | 55 +---------- tests/test_node_replacements_json.py | 142 ++++++++++++--------------- 3 files changed, 122 insertions(+), 135 deletions(-) diff --git a/app/node_replace_manager.py b/app/node_replace_manager.py index d9aab5b22..8be55c10c 100644 --- a/app/node_replace_manager.py +++ b/app/node_replace_manager.py @@ -1,5 +1,9 @@ from __future__ import annotations +import json +import logging +import os + from aiohttp import web from typing import TYPE_CHECKING, TypedDict @@ -7,7 +11,6 @@ if TYPE_CHECKING: from comfy_api.latest._io_public import NodeReplace from comfy_execution.graph_utils import is_link -import nodes class NodeStruct(TypedDict): inputs: dict[str, str | int | float | bool | tuple[str, int]] @@ -43,6 +46,7 @@ class NodeReplaceManager: return old_node_id in self._replacements def apply_replacements(self, prompt: dict[str, NodeStruct]): + import nodes connections: dict[str, list[tuple[str, str, int]]] = {} need_replacement: set[str] = set() for node_number, node_struct in prompt.items(): @@ -94,6 +98,60 @@ class NodeReplaceManager: previous_input = prompt[conn_node_number]["inputs"][conn_input_id] previous_input[1] = new_output_idx + def load_from_json(self, module_dir: str, module_name: str, _node_replace_class=None): + """Load node_replacements.json from a custom node directory and register replacements. + + Custom node authors can ship a node_replacements.json file in their repo root + to define node replacements declaratively. The file format matches the output + of NodeReplace.as_dict(), keyed by old_node_id. + + Fail-open: all errors are logged and skipped so a malformed file never + prevents the custom node from loading. + """ + replacements_path = os.path.join(module_dir, "node_replacements.json") + if not os.path.isfile(replacements_path): + return + + try: + with open(replacements_path, "r", encoding="utf-8") as f: + data = json.load(f) + + if not isinstance(data, dict): + logging.warning(f"node_replacements.json in {module_name} must be a JSON object, skipping.") + return + + if _node_replace_class is None: + from comfy_api.latest._io import NodeReplace + _node_replace_class = NodeReplace + + count = 0 + for old_node_id, replacements in data.items(): + if not isinstance(replacements, list): + logging.warning(f"node_replacements.json in {module_name}: value for '{old_node_id}' must be a list, skipping.") + continue + for entry in replacements: + if not isinstance(entry, dict): + continue + new_node_id = entry.get("new_node_id", "") + if not new_node_id: + logging.warning(f"node_replacements.json in {module_name}: entry for '{old_node_id}' missing 'new_node_id', skipping.") + continue + self.register(_node_replace_class( + new_node_id=new_node_id, + old_node_id=entry.get("old_node_id", old_node_id), + old_widget_ids=entry.get("old_widget_ids"), + input_mapping=entry.get("input_mapping"), + output_mapping=entry.get("output_mapping"), + )) + count += 1 + + if count > 0: + logging.info(f"Loaded {count} node replacement(s) from {module_name}/node_replacements.json") + except json.JSONDecodeError as e: + logging.warning(f"Failed to parse node_replacements.json in {module_name}: {e}") + except Exception as e: + logging.warning(f"Failed to load node_replacements.json from {module_name}: {e}") + def as_dict(self): """Serialize all replacements to dict.""" return { diff --git a/nodes.py b/nodes.py index c6e829eac..b3edf71b1 100644 --- a/nodes.py +++ b/nodes.py @@ -2202,58 +2202,6 @@ def get_module_name(module_path: str) -> str: return base_path -def load_node_replacements_json(module_dir: str, module_name: str): - """Load node_replacements.json from a custom node directory and register replacements. - - Custom node authors can ship a node_replacements.json file in their repo root - to define node replacements declaratively, without writing Python registration code. - The file format matches the output of NodeReplace.as_dict(), keyed by old_node_id. - """ - replacements_path = os.path.join(module_dir, "node_replacements.json") - if not os.path.isfile(replacements_path): - return - - try: - with open(replacements_path, "r", encoding="utf-8") as f: - data = json.load(f) - - if not isinstance(data, dict): - logging.warning(f"node_replacements.json in {module_name} must be a JSON object, skipping.") - return - - from server import PromptServer - from comfy_api.latest._io import NodeReplace - - manager = PromptServer.instance.node_replace_manager - count = 0 - for old_node_id, replacements in data.items(): - if not isinstance(replacements, list): - logging.warning(f"node_replacements.json in {module_name}: value for '{old_node_id}' must be a list, skipping.") - continue - for entry in replacements: - if not isinstance(entry, dict): - continue - new_node_id = entry.get("new_node_id", "") - if not new_node_id: - logging.warning(f"node_replacements.json in {module_name}: entry for '{old_node_id}' missing 'new_node_id', skipping.") - continue - manager.register(NodeReplace( - new_node_id=new_node_id, - old_node_id=entry.get("old_node_id", old_node_id), - old_widget_ids=entry.get("old_widget_ids"), - input_mapping=entry.get("input_mapping"), - output_mapping=entry.get("output_mapping"), - )) - count += 1 - - if count > 0: - logging.info(f"Loaded {count} node replacement(s) from {module_name}/node_replacements.json") - except json.JSONDecodeError as e: - logging.warning(f"Failed to parse node_replacements.json in {module_name}: {e}") - except Exception as e: - logging.warning(f"Failed to load node_replacements.json from {module_name}: {e}") - - async def load_custom_node(module_path: str, ignore=set(), module_parent="custom_nodes") -> bool: module_name = get_module_name(module_path) if os.path.isfile(module_path): @@ -2281,7 +2229,8 @@ async def load_custom_node(module_path: str, ignore=set(), module_parent="custom # Only load node_replacements.json from directory-based custom nodes (proper packs). # Single-file .py nodes share a parent dir, so checking there would be incorrect. if os.path.isdir(module_path): - load_node_replacements_json(module_dir, module_name) + from server import PromptServer + PromptServer.instance.node_replace_manager.load_from_json(module_dir, module_name) try: from comfy_config import config_parser diff --git a/tests/test_node_replacements_json.py b/tests/test_node_replacements_json.py index 101c4caa4..c773c3302 100644 --- a/tests/test_node_replacements_json.py +++ b/tests/test_node_replacements_json.py @@ -1,17 +1,15 @@ -"""Tests for auto-registration of node_replacements.json from custom node directories.""" +"""Tests for NodeReplaceManager.load_from_json — auto-registration of +node_replacements.json from custom node directories.""" import json import os import tempfile import unittest -from unittest.mock import MagicMock -# We can't import nodes.py directly (torch dependency), so we test the -# load_node_replacements_json logic by re-creating it from the same source. -# This validates the JSON parsing and NodeReplace construction logic. +from app.node_replace_manager import NodeReplaceManager -class MockNodeReplace: - """Mirrors comfy_api.latest._io.NodeReplace for testing.""" +class SimpleNodeReplace: + """Lightweight stand-in for comfy_api.latest._io.NodeReplace (avoids torch import).""" def __init__(self, new_node_id, old_node_id, old_widget_ids=None, input_mapping=None, output_mapping=None): self.new_node_id = new_node_id @@ -20,57 +18,22 @@ class MockNodeReplace: self.input_mapping = input_mapping self.output_mapping = output_mapping - -def load_node_replacements_json(module_dir, module_name, manager, NodeReplace=MockNodeReplace): - """Standalone version of the function from nodes.py for testing.""" - import logging - replacements_path = os.path.join(module_dir, "node_replacements.json") - if not os.path.isfile(replacements_path): - return - - try: - with open(replacements_path, "r", encoding="utf-8") as f: - data = json.load(f) - - if not isinstance(data, dict): - logging.warning(f"node_replacements.json in {module_name} must be a JSON object, skipping.") - return - - count = 0 - for old_node_id, replacements in data.items(): - if not isinstance(replacements, list): - logging.warning(f"node_replacements.json in {module_name}: value for '{old_node_id}' must be a list, skipping.") - continue - for entry in replacements: - if not isinstance(entry, dict): - continue - new_node_id = entry.get("new_node_id", "") - if not new_node_id: - logging.warning(f"node_replacements.json in {module_name}: entry for '{old_node_id}' missing 'new_node_id', skipping.") - continue - manager.register(NodeReplace( - new_node_id=new_node_id, - old_node_id=entry.get("old_node_id", old_node_id), - old_widget_ids=entry.get("old_widget_ids"), - input_mapping=entry.get("input_mapping"), - output_mapping=entry.get("output_mapping"), - )) - count += 1 - - if count > 0: - logging.info(f"Loaded {count} node replacement(s) from {module_name}/node_replacements.json") - except json.JSONDecodeError as e: - logging.warning(f"Failed to parse node_replacements.json in {module_name}: {e}") - except Exception as e: - logging.warning(f"Failed to load node_replacements.json from {module_name}: {e}") + def as_dict(self): + return { + "new_node_id": self.new_node_id, + "old_node_id": self.old_node_id, + "old_widget_ids": self.old_widget_ids, + "input_mapping": list(self.input_mapping) if self.input_mapping else None, + "output_mapping": list(self.output_mapping) if self.output_mapping else None, + } -class TestLoadNodeReplacementsJson(unittest.TestCase): +class TestLoadFromJson(unittest.TestCase): """Test auto-registration of node_replacements.json from custom node directories.""" def setUp(self): self.tmpdir = tempfile.mkdtemp() - self.mock_manager = MagicMock() + self.manager = NodeReplaceManager() def _write_json(self, data): path = os.path.join(self.tmpdir, "node_replacements.json") @@ -78,18 +41,18 @@ class TestLoadNodeReplacementsJson(unittest.TestCase): json.dump(data, f) def _load(self): - load_node_replacements_json(self.tmpdir, "test-node-pack", self.mock_manager) + self.manager.load_from_json(self.tmpdir, "test-node-pack", _node_replace_class=SimpleNodeReplace) def test_no_file_does_nothing(self): """No node_replacements.json — should silently do nothing.""" self._load() - self.mock_manager.register.assert_not_called() + self.assertEqual(self.manager.as_dict(), {}) def test_empty_object(self): """Empty {} — should do nothing.""" self._write_json({}) self._load() - self.mock_manager.register.assert_not_called() + self.assertEqual(self.manager.as_dict(), {}) def test_single_replacement(self): """Single replacement entry registers correctly.""" @@ -102,12 +65,14 @@ class TestLoadNodeReplacementsJson(unittest.TestCase): }] }) self._load() - self.mock_manager.register.assert_called_once() - registered = self.mock_manager.register.call_args[0][0] - self.assertEqual(registered.new_node_id, "NewNode") - self.assertEqual(registered.old_node_id, "OldNode") - self.assertEqual(registered.input_mapping, [{"new_id": "model", "old_id": "ckpt_name"}]) - self.assertEqual(registered.output_mapping, [{"new_idx": 0, "old_idx": 0}]) + result = self.manager.as_dict() + self.assertIn("OldNode", result) + self.assertEqual(len(result["OldNode"]), 1) + entry = result["OldNode"][0] + self.assertEqual(entry["new_node_id"], "NewNode") + self.assertEqual(entry["old_node_id"], "OldNode") + self.assertEqual(entry["input_mapping"], [{"new_id": "model", "old_id": "ckpt_name"}]) + self.assertEqual(entry["output_mapping"], [{"new_idx": 0, "old_idx": 0}]) def test_multiple_replacements(self): """Multiple old_node_ids each with entries.""" @@ -116,7 +81,10 @@ class TestLoadNodeReplacementsJson(unittest.TestCase): "NodeC": [{"new_node_id": "NodeD", "old_node_id": "NodeC"}], }) self._load() - self.assertEqual(self.mock_manager.register.call_count, 2) + result = self.manager.as_dict() + self.assertEqual(len(result), 2) + self.assertIn("NodeA", result) + self.assertIn("NodeC", result) def test_multiple_alternatives_for_same_node(self): """Multiple replacement options for the same old node.""" @@ -127,7 +95,8 @@ class TestLoadNodeReplacementsJson(unittest.TestCase): ] }) self._load() - self.assertEqual(self.mock_manager.register.call_count, 2) + result = self.manager.as_dict() + self.assertEqual(len(result["OldNode"]), 2) def test_null_mappings(self): """Null input/output mappings (trivial replacement).""" @@ -140,9 +109,9 @@ class TestLoadNodeReplacementsJson(unittest.TestCase): }] }) self._load() - registered = self.mock_manager.register.call_args[0][0] - self.assertIsNone(registered.input_mapping) - self.assertIsNone(registered.output_mapping) + entry = self.manager.as_dict()["OldNode"][0] + self.assertIsNone(entry["input_mapping"]) + self.assertIsNone(entry["output_mapping"]) def test_old_node_id_defaults_to_key(self): """If old_node_id is missing from entry, uses the dict key.""" @@ -150,8 +119,8 @@ class TestLoadNodeReplacementsJson(unittest.TestCase): "OldNode": [{"new_node_id": "NewNode"}] }) self._load() - registered = self.mock_manager.register.call_args[0][0] - self.assertEqual(registered.old_node_id, "OldNode") + entry = self.manager.as_dict()["OldNode"][0] + self.assertEqual(entry["old_node_id"], "OldNode") def test_invalid_json_skips(self): """Invalid JSON file — should warn and skip, not crash.""" @@ -159,13 +128,13 @@ class TestLoadNodeReplacementsJson(unittest.TestCase): with open(path, "w") as f: f.write("{invalid json") self._load() - self.mock_manager.register.assert_not_called() + self.assertEqual(self.manager.as_dict(), {}) def test_non_object_json_skips(self): """JSON array instead of object — should warn and skip.""" self._write_json([1, 2, 3]) self._load() - self.mock_manager.register.assert_not_called() + self.assertEqual(self.manager.as_dict(), {}) def test_non_list_value_skips(self): """Value is not a list — should warn and skip that key.""" @@ -174,7 +143,9 @@ class TestLoadNodeReplacementsJson(unittest.TestCase): "GoodNode": [{"new_node_id": "NewNode", "old_node_id": "GoodNode"}], }) self._load() - self.assertEqual(self.mock_manager.register.call_count, 1) + result = self.manager.as_dict() + self.assertNotIn("OldNode", result) + self.assertIn("GoodNode", result) def test_with_old_widget_ids(self): """old_widget_ids are passed through.""" @@ -186,8 +157,8 @@ class TestLoadNodeReplacementsJson(unittest.TestCase): }] }) self._load() - registered = self.mock_manager.register.call_args[0][0] - self.assertEqual(registered.old_widget_ids, ["width", "height"]) + entry = self.manager.as_dict()["OldNode"][0] + self.assertEqual(entry["old_widget_ids"], ["width", "height"]) def test_set_value_in_input_mapping(self): """input_mapping with set_value entries.""" @@ -202,10 +173,8 @@ class TestLoadNodeReplacementsJson(unittest.TestCase): }] }) self._load() - registered = self.mock_manager.register.call_args[0][0] - self.assertEqual(len(registered.input_mapping), 2) - self.assertEqual(registered.input_mapping[0]["set_value"], "lanczos") - self.assertEqual(registered.input_mapping[1]["old_id"], "dimension") + entry = self.manager.as_dict()["OldNode"][0] + self.assertEqual(len(entry["input_mapping"]), 2) def test_missing_new_node_id_skipped(self): """Entry without new_node_id is skipped.""" @@ -217,9 +186,9 @@ class TestLoadNodeReplacementsJson(unittest.TestCase): ] }) self._load() - self.assertEqual(self.mock_manager.register.call_count, 1) - registered = self.mock_manager.register.call_args[0][0] - self.assertEqual(registered.new_node_id, "ValidNew") + result = self.manager.as_dict() + self.assertEqual(len(result["OldNode"]), 1) + self.assertEqual(result["OldNode"][0]["new_node_id"], "ValidNew") def test_non_dict_entry_skipped(self): """Non-dict entries in the list are silently skipped.""" @@ -230,7 +199,18 @@ class TestLoadNodeReplacementsJson(unittest.TestCase): ] }) self._load() - self.assertEqual(self.mock_manager.register.call_count, 1) + result = self.manager.as_dict() + self.assertEqual(len(result["OldNode"]), 1) + + def test_has_replacement_after_load(self): + """Manager reports has_replacement correctly after JSON load.""" + self._write_json({ + "OldNode": [{"new_node_id": "NewNode", "old_node_id": "OldNode"}], + }) + self.assertFalse(self.manager.has_replacement("OldNode")) + self._load() + self.assertTrue(self.manager.has_replacement("OldNode")) + self.assertFalse(self.manager.has_replacement("UnknownNode")) if __name__ == "__main__":