mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-06 07:12:30 +08:00
refactor: move load_from_json into NodeReplaceManager
Address review feedback from Kosinkadink: 1. Move JSON loading logic from nodes.py into NodeReplaceManager as load_from_json() method for better encapsulation and testability 2. Tests now exercise the real NodeReplaceManager (no duplicated logic) 3. Defer `import nodes` in apply_replacements to avoid torch at import 4. nodes.py call site simplified to one line: PromptServer.instance.node_replace_manager.load_from_json(...)
This commit is contained in:
parent
62ec9a3238
commit
9837dd368a
@ -1,5 +1,9 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
|
||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
|
|
||||||
from typing import TYPE_CHECKING, TypedDict
|
from typing import TYPE_CHECKING, TypedDict
|
||||||
@ -7,7 +11,6 @@ if TYPE_CHECKING:
|
|||||||
from comfy_api.latest._io_public import NodeReplace
|
from comfy_api.latest._io_public import NodeReplace
|
||||||
|
|
||||||
from comfy_execution.graph_utils import is_link
|
from comfy_execution.graph_utils import is_link
|
||||||
import nodes
|
|
||||||
|
|
||||||
class NodeStruct(TypedDict):
|
class NodeStruct(TypedDict):
|
||||||
inputs: dict[str, str | int | float | bool | tuple[str, int]]
|
inputs: dict[str, str | int | float | bool | tuple[str, int]]
|
||||||
@ -43,6 +46,7 @@ class NodeReplaceManager:
|
|||||||
return old_node_id in self._replacements
|
return old_node_id in self._replacements
|
||||||
|
|
||||||
def apply_replacements(self, prompt: dict[str, NodeStruct]):
|
def apply_replacements(self, prompt: dict[str, NodeStruct]):
|
||||||
|
import nodes
|
||||||
connections: dict[str, list[tuple[str, str, int]]] = {}
|
connections: dict[str, list[tuple[str, str, int]]] = {}
|
||||||
need_replacement: set[str] = set()
|
need_replacement: set[str] = set()
|
||||||
for node_number, node_struct in prompt.items():
|
for node_number, node_struct in prompt.items():
|
||||||
@ -94,6 +98,60 @@ class NodeReplaceManager:
|
|||||||
previous_input = prompt[conn_node_number]["inputs"][conn_input_id]
|
previous_input = prompt[conn_node_number]["inputs"][conn_input_id]
|
||||||
previous_input[1] = new_output_idx
|
previous_input[1] = new_output_idx
|
||||||
|
|
||||||
|
def load_from_json(self, module_dir: str, module_name: str, _node_replace_class=None):
|
||||||
|
"""Load node_replacements.json from a custom node directory and register replacements.
|
||||||
|
|
||||||
|
Custom node authors can ship a node_replacements.json file in their repo root
|
||||||
|
to define node replacements declaratively. The file format matches the output
|
||||||
|
of NodeReplace.as_dict(), keyed by old_node_id.
|
||||||
|
|
||||||
|
Fail-open: all errors are logged and skipped so a malformed file never
|
||||||
|
prevents the custom node from loading.
|
||||||
|
"""
|
||||||
|
replacements_path = os.path.join(module_dir, "node_replacements.json")
|
||||||
|
if not os.path.isfile(replacements_path):
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
with open(replacements_path, "r", encoding="utf-8") as f:
|
||||||
|
data = json.load(f)
|
||||||
|
|
||||||
|
if not isinstance(data, dict):
|
||||||
|
logging.warning(f"node_replacements.json in {module_name} must be a JSON object, skipping.")
|
||||||
|
return
|
||||||
|
|
||||||
|
if _node_replace_class is None:
|
||||||
|
from comfy_api.latest._io import NodeReplace
|
||||||
|
_node_replace_class = NodeReplace
|
||||||
|
|
||||||
|
count = 0
|
||||||
|
for old_node_id, replacements in data.items():
|
||||||
|
if not isinstance(replacements, list):
|
||||||
|
logging.warning(f"node_replacements.json in {module_name}: value for '{old_node_id}' must be a list, skipping.")
|
||||||
|
continue
|
||||||
|
for entry in replacements:
|
||||||
|
if not isinstance(entry, dict):
|
||||||
|
continue
|
||||||
|
new_node_id = entry.get("new_node_id", "")
|
||||||
|
if not new_node_id:
|
||||||
|
logging.warning(f"node_replacements.json in {module_name}: entry for '{old_node_id}' missing 'new_node_id', skipping.")
|
||||||
|
continue
|
||||||
|
self.register(_node_replace_class(
|
||||||
|
new_node_id=new_node_id,
|
||||||
|
old_node_id=entry.get("old_node_id", old_node_id),
|
||||||
|
old_widget_ids=entry.get("old_widget_ids"),
|
||||||
|
input_mapping=entry.get("input_mapping"),
|
||||||
|
output_mapping=entry.get("output_mapping"),
|
||||||
|
))
|
||||||
|
count += 1
|
||||||
|
|
||||||
|
if count > 0:
|
||||||
|
logging.info(f"Loaded {count} node replacement(s) from {module_name}/node_replacements.json")
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
logging.warning(f"Failed to parse node_replacements.json in {module_name}: {e}")
|
||||||
|
except Exception as e:
|
||||||
|
logging.warning(f"Failed to load node_replacements.json from {module_name}: {e}")
|
||||||
|
|
||||||
def as_dict(self):
|
def as_dict(self):
|
||||||
"""Serialize all replacements to dict."""
|
"""Serialize all replacements to dict."""
|
||||||
return {
|
return {
|
||||||
|
|||||||
55
nodes.py
55
nodes.py
@ -2202,58 +2202,6 @@ def get_module_name(module_path: str) -> str:
|
|||||||
return base_path
|
return base_path
|
||||||
|
|
||||||
|
|
||||||
def load_node_replacements_json(module_dir: str, module_name: str):
|
|
||||||
"""Load node_replacements.json from a custom node directory and register replacements.
|
|
||||||
|
|
||||||
Custom node authors can ship a node_replacements.json file in their repo root
|
|
||||||
to define node replacements declaratively, without writing Python registration code.
|
|
||||||
The file format matches the output of NodeReplace.as_dict(), keyed by old_node_id.
|
|
||||||
"""
|
|
||||||
replacements_path = os.path.join(module_dir, "node_replacements.json")
|
|
||||||
if not os.path.isfile(replacements_path):
|
|
||||||
return
|
|
||||||
|
|
||||||
try:
|
|
||||||
with open(replacements_path, "r", encoding="utf-8") as f:
|
|
||||||
data = json.load(f)
|
|
||||||
|
|
||||||
if not isinstance(data, dict):
|
|
||||||
logging.warning(f"node_replacements.json in {module_name} must be a JSON object, skipping.")
|
|
||||||
return
|
|
||||||
|
|
||||||
from server import PromptServer
|
|
||||||
from comfy_api.latest._io import NodeReplace
|
|
||||||
|
|
||||||
manager = PromptServer.instance.node_replace_manager
|
|
||||||
count = 0
|
|
||||||
for old_node_id, replacements in data.items():
|
|
||||||
if not isinstance(replacements, list):
|
|
||||||
logging.warning(f"node_replacements.json in {module_name}: value for '{old_node_id}' must be a list, skipping.")
|
|
||||||
continue
|
|
||||||
for entry in replacements:
|
|
||||||
if not isinstance(entry, dict):
|
|
||||||
continue
|
|
||||||
new_node_id = entry.get("new_node_id", "")
|
|
||||||
if not new_node_id:
|
|
||||||
logging.warning(f"node_replacements.json in {module_name}: entry for '{old_node_id}' missing 'new_node_id', skipping.")
|
|
||||||
continue
|
|
||||||
manager.register(NodeReplace(
|
|
||||||
new_node_id=new_node_id,
|
|
||||||
old_node_id=entry.get("old_node_id", old_node_id),
|
|
||||||
old_widget_ids=entry.get("old_widget_ids"),
|
|
||||||
input_mapping=entry.get("input_mapping"),
|
|
||||||
output_mapping=entry.get("output_mapping"),
|
|
||||||
))
|
|
||||||
count += 1
|
|
||||||
|
|
||||||
if count > 0:
|
|
||||||
logging.info(f"Loaded {count} node replacement(s) from {module_name}/node_replacements.json")
|
|
||||||
except json.JSONDecodeError as e:
|
|
||||||
logging.warning(f"Failed to parse node_replacements.json in {module_name}: {e}")
|
|
||||||
except Exception as e:
|
|
||||||
logging.warning(f"Failed to load node_replacements.json from {module_name}: {e}")
|
|
||||||
|
|
||||||
|
|
||||||
async def load_custom_node(module_path: str, ignore=set(), module_parent="custom_nodes") -> bool:
|
async def load_custom_node(module_path: str, ignore=set(), module_parent="custom_nodes") -> bool:
|
||||||
module_name = get_module_name(module_path)
|
module_name = get_module_name(module_path)
|
||||||
if os.path.isfile(module_path):
|
if os.path.isfile(module_path):
|
||||||
@ -2281,7 +2229,8 @@ async def load_custom_node(module_path: str, ignore=set(), module_parent="custom
|
|||||||
# Only load node_replacements.json from directory-based custom nodes (proper packs).
|
# Only load node_replacements.json from directory-based custom nodes (proper packs).
|
||||||
# Single-file .py nodes share a parent dir, so checking there would be incorrect.
|
# Single-file .py nodes share a parent dir, so checking there would be incorrect.
|
||||||
if os.path.isdir(module_path):
|
if os.path.isdir(module_path):
|
||||||
load_node_replacements_json(module_dir, module_name)
|
from server import PromptServer
|
||||||
|
PromptServer.instance.node_replace_manager.load_from_json(module_dir, module_name)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from comfy_config import config_parser
|
from comfy_config import config_parser
|
||||||
|
|||||||
@ -1,17 +1,15 @@
|
|||||||
"""Tests for auto-registration of node_replacements.json from custom node directories."""
|
"""Tests for NodeReplaceManager.load_from_json — auto-registration of
|
||||||
|
node_replacements.json from custom node directories."""
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import tempfile
|
import tempfile
|
||||||
import unittest
|
import unittest
|
||||||
from unittest.mock import MagicMock
|
|
||||||
|
|
||||||
# We can't import nodes.py directly (torch dependency), so we test the
|
from app.node_replace_manager import NodeReplaceManager
|
||||||
# load_node_replacements_json logic by re-creating it from the same source.
|
|
||||||
# This validates the JSON parsing and NodeReplace construction logic.
|
|
||||||
|
|
||||||
|
|
||||||
class MockNodeReplace:
|
class SimpleNodeReplace:
|
||||||
"""Mirrors comfy_api.latest._io.NodeReplace for testing."""
|
"""Lightweight stand-in for comfy_api.latest._io.NodeReplace (avoids torch import)."""
|
||||||
def __init__(self, new_node_id, old_node_id, old_widget_ids=None,
|
def __init__(self, new_node_id, old_node_id, old_widget_ids=None,
|
||||||
input_mapping=None, output_mapping=None):
|
input_mapping=None, output_mapping=None):
|
||||||
self.new_node_id = new_node_id
|
self.new_node_id = new_node_id
|
||||||
@ -20,57 +18,22 @@ class MockNodeReplace:
|
|||||||
self.input_mapping = input_mapping
|
self.input_mapping = input_mapping
|
||||||
self.output_mapping = output_mapping
|
self.output_mapping = output_mapping
|
||||||
|
|
||||||
|
def as_dict(self):
|
||||||
def load_node_replacements_json(module_dir, module_name, manager, NodeReplace=MockNodeReplace):
|
return {
|
||||||
"""Standalone version of the function from nodes.py for testing."""
|
"new_node_id": self.new_node_id,
|
||||||
import logging
|
"old_node_id": self.old_node_id,
|
||||||
replacements_path = os.path.join(module_dir, "node_replacements.json")
|
"old_widget_ids": self.old_widget_ids,
|
||||||
if not os.path.isfile(replacements_path):
|
"input_mapping": list(self.input_mapping) if self.input_mapping else None,
|
||||||
return
|
"output_mapping": list(self.output_mapping) if self.output_mapping else None,
|
||||||
|
}
|
||||||
try:
|
|
||||||
with open(replacements_path, "r", encoding="utf-8") as f:
|
|
||||||
data = json.load(f)
|
|
||||||
|
|
||||||
if not isinstance(data, dict):
|
|
||||||
logging.warning(f"node_replacements.json in {module_name} must be a JSON object, skipping.")
|
|
||||||
return
|
|
||||||
|
|
||||||
count = 0
|
|
||||||
for old_node_id, replacements in data.items():
|
|
||||||
if not isinstance(replacements, list):
|
|
||||||
logging.warning(f"node_replacements.json in {module_name}: value for '{old_node_id}' must be a list, skipping.")
|
|
||||||
continue
|
|
||||||
for entry in replacements:
|
|
||||||
if not isinstance(entry, dict):
|
|
||||||
continue
|
|
||||||
new_node_id = entry.get("new_node_id", "")
|
|
||||||
if not new_node_id:
|
|
||||||
logging.warning(f"node_replacements.json in {module_name}: entry for '{old_node_id}' missing 'new_node_id', skipping.")
|
|
||||||
continue
|
|
||||||
manager.register(NodeReplace(
|
|
||||||
new_node_id=new_node_id,
|
|
||||||
old_node_id=entry.get("old_node_id", old_node_id),
|
|
||||||
old_widget_ids=entry.get("old_widget_ids"),
|
|
||||||
input_mapping=entry.get("input_mapping"),
|
|
||||||
output_mapping=entry.get("output_mapping"),
|
|
||||||
))
|
|
||||||
count += 1
|
|
||||||
|
|
||||||
if count > 0:
|
|
||||||
logging.info(f"Loaded {count} node replacement(s) from {module_name}/node_replacements.json")
|
|
||||||
except json.JSONDecodeError as e:
|
|
||||||
logging.warning(f"Failed to parse node_replacements.json in {module_name}: {e}")
|
|
||||||
except Exception as e:
|
|
||||||
logging.warning(f"Failed to load node_replacements.json from {module_name}: {e}")
|
|
||||||
|
|
||||||
|
|
||||||
class TestLoadNodeReplacementsJson(unittest.TestCase):
|
class TestLoadFromJson(unittest.TestCase):
|
||||||
"""Test auto-registration of node_replacements.json from custom node directories."""
|
"""Test auto-registration of node_replacements.json from custom node directories."""
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.tmpdir = tempfile.mkdtemp()
|
self.tmpdir = tempfile.mkdtemp()
|
||||||
self.mock_manager = MagicMock()
|
self.manager = NodeReplaceManager()
|
||||||
|
|
||||||
def _write_json(self, data):
|
def _write_json(self, data):
|
||||||
path = os.path.join(self.tmpdir, "node_replacements.json")
|
path = os.path.join(self.tmpdir, "node_replacements.json")
|
||||||
@ -78,18 +41,18 @@ class TestLoadNodeReplacementsJson(unittest.TestCase):
|
|||||||
json.dump(data, f)
|
json.dump(data, f)
|
||||||
|
|
||||||
def _load(self):
|
def _load(self):
|
||||||
load_node_replacements_json(self.tmpdir, "test-node-pack", self.mock_manager)
|
self.manager.load_from_json(self.tmpdir, "test-node-pack", _node_replace_class=SimpleNodeReplace)
|
||||||
|
|
||||||
def test_no_file_does_nothing(self):
|
def test_no_file_does_nothing(self):
|
||||||
"""No node_replacements.json — should silently do nothing."""
|
"""No node_replacements.json — should silently do nothing."""
|
||||||
self._load()
|
self._load()
|
||||||
self.mock_manager.register.assert_not_called()
|
self.assertEqual(self.manager.as_dict(), {})
|
||||||
|
|
||||||
def test_empty_object(self):
|
def test_empty_object(self):
|
||||||
"""Empty {} — should do nothing."""
|
"""Empty {} — should do nothing."""
|
||||||
self._write_json({})
|
self._write_json({})
|
||||||
self._load()
|
self._load()
|
||||||
self.mock_manager.register.assert_not_called()
|
self.assertEqual(self.manager.as_dict(), {})
|
||||||
|
|
||||||
def test_single_replacement(self):
|
def test_single_replacement(self):
|
||||||
"""Single replacement entry registers correctly."""
|
"""Single replacement entry registers correctly."""
|
||||||
@ -102,12 +65,14 @@ class TestLoadNodeReplacementsJson(unittest.TestCase):
|
|||||||
}]
|
}]
|
||||||
})
|
})
|
||||||
self._load()
|
self._load()
|
||||||
self.mock_manager.register.assert_called_once()
|
result = self.manager.as_dict()
|
||||||
registered = self.mock_manager.register.call_args[0][0]
|
self.assertIn("OldNode", result)
|
||||||
self.assertEqual(registered.new_node_id, "NewNode")
|
self.assertEqual(len(result["OldNode"]), 1)
|
||||||
self.assertEqual(registered.old_node_id, "OldNode")
|
entry = result["OldNode"][0]
|
||||||
self.assertEqual(registered.input_mapping, [{"new_id": "model", "old_id": "ckpt_name"}])
|
self.assertEqual(entry["new_node_id"], "NewNode")
|
||||||
self.assertEqual(registered.output_mapping, [{"new_idx": 0, "old_idx": 0}])
|
self.assertEqual(entry["old_node_id"], "OldNode")
|
||||||
|
self.assertEqual(entry["input_mapping"], [{"new_id": "model", "old_id": "ckpt_name"}])
|
||||||
|
self.assertEqual(entry["output_mapping"], [{"new_idx": 0, "old_idx": 0}])
|
||||||
|
|
||||||
def test_multiple_replacements(self):
|
def test_multiple_replacements(self):
|
||||||
"""Multiple old_node_ids each with entries."""
|
"""Multiple old_node_ids each with entries."""
|
||||||
@ -116,7 +81,10 @@ class TestLoadNodeReplacementsJson(unittest.TestCase):
|
|||||||
"NodeC": [{"new_node_id": "NodeD", "old_node_id": "NodeC"}],
|
"NodeC": [{"new_node_id": "NodeD", "old_node_id": "NodeC"}],
|
||||||
})
|
})
|
||||||
self._load()
|
self._load()
|
||||||
self.assertEqual(self.mock_manager.register.call_count, 2)
|
result = self.manager.as_dict()
|
||||||
|
self.assertEqual(len(result), 2)
|
||||||
|
self.assertIn("NodeA", result)
|
||||||
|
self.assertIn("NodeC", result)
|
||||||
|
|
||||||
def test_multiple_alternatives_for_same_node(self):
|
def test_multiple_alternatives_for_same_node(self):
|
||||||
"""Multiple replacement options for the same old node."""
|
"""Multiple replacement options for the same old node."""
|
||||||
@ -127,7 +95,8 @@ class TestLoadNodeReplacementsJson(unittest.TestCase):
|
|||||||
]
|
]
|
||||||
})
|
})
|
||||||
self._load()
|
self._load()
|
||||||
self.assertEqual(self.mock_manager.register.call_count, 2)
|
result = self.manager.as_dict()
|
||||||
|
self.assertEqual(len(result["OldNode"]), 2)
|
||||||
|
|
||||||
def test_null_mappings(self):
|
def test_null_mappings(self):
|
||||||
"""Null input/output mappings (trivial replacement)."""
|
"""Null input/output mappings (trivial replacement)."""
|
||||||
@ -140,9 +109,9 @@ class TestLoadNodeReplacementsJson(unittest.TestCase):
|
|||||||
}]
|
}]
|
||||||
})
|
})
|
||||||
self._load()
|
self._load()
|
||||||
registered = self.mock_manager.register.call_args[0][0]
|
entry = self.manager.as_dict()["OldNode"][0]
|
||||||
self.assertIsNone(registered.input_mapping)
|
self.assertIsNone(entry["input_mapping"])
|
||||||
self.assertIsNone(registered.output_mapping)
|
self.assertIsNone(entry["output_mapping"])
|
||||||
|
|
||||||
def test_old_node_id_defaults_to_key(self):
|
def test_old_node_id_defaults_to_key(self):
|
||||||
"""If old_node_id is missing from entry, uses the dict key."""
|
"""If old_node_id is missing from entry, uses the dict key."""
|
||||||
@ -150,8 +119,8 @@ class TestLoadNodeReplacementsJson(unittest.TestCase):
|
|||||||
"OldNode": [{"new_node_id": "NewNode"}]
|
"OldNode": [{"new_node_id": "NewNode"}]
|
||||||
})
|
})
|
||||||
self._load()
|
self._load()
|
||||||
registered = self.mock_manager.register.call_args[0][0]
|
entry = self.manager.as_dict()["OldNode"][0]
|
||||||
self.assertEqual(registered.old_node_id, "OldNode")
|
self.assertEqual(entry["old_node_id"], "OldNode")
|
||||||
|
|
||||||
def test_invalid_json_skips(self):
|
def test_invalid_json_skips(self):
|
||||||
"""Invalid JSON file — should warn and skip, not crash."""
|
"""Invalid JSON file — should warn and skip, not crash."""
|
||||||
@ -159,13 +128,13 @@ class TestLoadNodeReplacementsJson(unittest.TestCase):
|
|||||||
with open(path, "w") as f:
|
with open(path, "w") as f:
|
||||||
f.write("{invalid json")
|
f.write("{invalid json")
|
||||||
self._load()
|
self._load()
|
||||||
self.mock_manager.register.assert_not_called()
|
self.assertEqual(self.manager.as_dict(), {})
|
||||||
|
|
||||||
def test_non_object_json_skips(self):
|
def test_non_object_json_skips(self):
|
||||||
"""JSON array instead of object — should warn and skip."""
|
"""JSON array instead of object — should warn and skip."""
|
||||||
self._write_json([1, 2, 3])
|
self._write_json([1, 2, 3])
|
||||||
self._load()
|
self._load()
|
||||||
self.mock_manager.register.assert_not_called()
|
self.assertEqual(self.manager.as_dict(), {})
|
||||||
|
|
||||||
def test_non_list_value_skips(self):
|
def test_non_list_value_skips(self):
|
||||||
"""Value is not a list — should warn and skip that key."""
|
"""Value is not a list — should warn and skip that key."""
|
||||||
@ -174,7 +143,9 @@ class TestLoadNodeReplacementsJson(unittest.TestCase):
|
|||||||
"GoodNode": [{"new_node_id": "NewNode", "old_node_id": "GoodNode"}],
|
"GoodNode": [{"new_node_id": "NewNode", "old_node_id": "GoodNode"}],
|
||||||
})
|
})
|
||||||
self._load()
|
self._load()
|
||||||
self.assertEqual(self.mock_manager.register.call_count, 1)
|
result = self.manager.as_dict()
|
||||||
|
self.assertNotIn("OldNode", result)
|
||||||
|
self.assertIn("GoodNode", result)
|
||||||
|
|
||||||
def test_with_old_widget_ids(self):
|
def test_with_old_widget_ids(self):
|
||||||
"""old_widget_ids are passed through."""
|
"""old_widget_ids are passed through."""
|
||||||
@ -186,8 +157,8 @@ class TestLoadNodeReplacementsJson(unittest.TestCase):
|
|||||||
}]
|
}]
|
||||||
})
|
})
|
||||||
self._load()
|
self._load()
|
||||||
registered = self.mock_manager.register.call_args[0][0]
|
entry = self.manager.as_dict()["OldNode"][0]
|
||||||
self.assertEqual(registered.old_widget_ids, ["width", "height"])
|
self.assertEqual(entry["old_widget_ids"], ["width", "height"])
|
||||||
|
|
||||||
def test_set_value_in_input_mapping(self):
|
def test_set_value_in_input_mapping(self):
|
||||||
"""input_mapping with set_value entries."""
|
"""input_mapping with set_value entries."""
|
||||||
@ -202,10 +173,8 @@ class TestLoadNodeReplacementsJson(unittest.TestCase):
|
|||||||
}]
|
}]
|
||||||
})
|
})
|
||||||
self._load()
|
self._load()
|
||||||
registered = self.mock_manager.register.call_args[0][0]
|
entry = self.manager.as_dict()["OldNode"][0]
|
||||||
self.assertEqual(len(registered.input_mapping), 2)
|
self.assertEqual(len(entry["input_mapping"]), 2)
|
||||||
self.assertEqual(registered.input_mapping[0]["set_value"], "lanczos")
|
|
||||||
self.assertEqual(registered.input_mapping[1]["old_id"], "dimension")
|
|
||||||
|
|
||||||
def test_missing_new_node_id_skipped(self):
|
def test_missing_new_node_id_skipped(self):
|
||||||
"""Entry without new_node_id is skipped."""
|
"""Entry without new_node_id is skipped."""
|
||||||
@ -217,9 +186,9 @@ class TestLoadNodeReplacementsJson(unittest.TestCase):
|
|||||||
]
|
]
|
||||||
})
|
})
|
||||||
self._load()
|
self._load()
|
||||||
self.assertEqual(self.mock_manager.register.call_count, 1)
|
result = self.manager.as_dict()
|
||||||
registered = self.mock_manager.register.call_args[0][0]
|
self.assertEqual(len(result["OldNode"]), 1)
|
||||||
self.assertEqual(registered.new_node_id, "ValidNew")
|
self.assertEqual(result["OldNode"][0]["new_node_id"], "ValidNew")
|
||||||
|
|
||||||
def test_non_dict_entry_skipped(self):
|
def test_non_dict_entry_skipped(self):
|
||||||
"""Non-dict entries in the list are silently skipped."""
|
"""Non-dict entries in the list are silently skipped."""
|
||||||
@ -230,7 +199,18 @@ class TestLoadNodeReplacementsJson(unittest.TestCase):
|
|||||||
]
|
]
|
||||||
})
|
})
|
||||||
self._load()
|
self._load()
|
||||||
self.assertEqual(self.mock_manager.register.call_count, 1)
|
result = self.manager.as_dict()
|
||||||
|
self.assertEqual(len(result["OldNode"]), 1)
|
||||||
|
|
||||||
|
def test_has_replacement_after_load(self):
|
||||||
|
"""Manager reports has_replacement correctly after JSON load."""
|
||||||
|
self._write_json({
|
||||||
|
"OldNode": [{"new_node_id": "NewNode", "old_node_id": "OldNode"}],
|
||||||
|
})
|
||||||
|
self.assertFalse(self.manager.has_replacement("OldNode"))
|
||||||
|
self._load()
|
||||||
|
self.assertTrue(self.manager.has_replacement("OldNode"))
|
||||||
|
self.assertFalse(self.manager.has_replacement("UnknownNode"))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user