mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-30 05:23:37 +08:00
Address review feedback from Kosinkadink: (1) move the JSON loading logic from nodes.py into NodeReplaceManager as a load_from_json() method for better encapsulation and testability; (2) tests now exercise the real NodeReplaceManager (no duplicated logic); (3) defer `import nodes` in apply_replacements to avoid importing torch at import time; (4) simplify the nodes.py call site to one line: PromptServer.instance.node_replace_manager.load_from_json(...)
218 lines
7.9 KiB
Python
218 lines
7.9 KiB
Python
"""Tests for NodeReplaceManager.load_from_json — auto-registration of
|
|
node_replacements.json from custom node directories."""
|
|
import json
|
|
import os
|
|
import tempfile
|
|
import unittest
|
|
|
|
from app.node_replace_manager import NodeReplaceManager
|
|
|
|
|
|
class SimpleNodeReplace:
    """Torch-free substitute for ``comfy_api.latest._io.NodeReplace`` used in tests.

    Stores the constructor arguments verbatim as attributes and exposes
    ``as_dict()`` so the manager's serialized output can be inspected.
    """

    def __init__(self, new_node_id, old_node_id, old_widget_ids=None,
                 input_mapping=None, output_mapping=None):
        self.new_node_id, self.old_node_id = new_node_id, old_node_id
        self.old_widget_ids = old_widget_ids
        self.input_mapping = input_mapping
        self.output_mapping = output_mapping

    def as_dict(self):
        """Return a plain-dict view of this replacement.

        Mapping sequences are copied into lists; a mapping that is ``None`` or
        empty is reported as ``None``.
        """
        def _as_list_or_none(seq):
            # Falsy (None / empty) collapses to None; anything else is copied.
            return list(seq) if seq else None

        return dict(
            new_node_id=self.new_node_id,
            old_node_id=self.old_node_id,
            old_widget_ids=self.old_widget_ids,
            input_mapping=_as_list_or_none(self.input_mapping),
            output_mapping=_as_list_or_none(self.output_mapping),
        )
|
|
|
|
|
|
class TestLoadFromJson(unittest.TestCase):
    """Test auto-registration of node_replacements.json from custom node directories."""

    def setUp(self):
        # Each test gets a fresh directory standing in for a custom node pack.
        # TemporaryDirectory + addCleanup removes it (and any JSON written into
        # it) even when the test fails; a bare mkdtemp() leaked one directory
        # per test.
        tmp = tempfile.TemporaryDirectory()
        self.addCleanup(tmp.cleanup)
        self.tmpdir = tmp.name
        self.manager = NodeReplaceManager()

    def _json_path(self):
        """Return the node_replacements.json path inside the temp node-pack dir."""
        return os.path.join(self.tmpdir, "node_replacements.json")

    def _write_text(self, text):
        """Write raw *text* (possibly malformed JSON) to node_replacements.json."""
        with open(self._json_path(), "w", encoding="utf-8") as f:
            f.write(text)

    def _write_json(self, data):
        """Serialize *data* as JSON into node_replacements.json."""
        self._write_text(json.dumps(data))

    def _load(self):
        """Run the real loader against the temp dir, injecting the torch-free stand-in class."""
        self.manager.load_from_json(self.tmpdir, "test-node-pack", _node_replace_class=SimpleNodeReplace)

    def test_no_file_does_nothing(self):
        """No node_replacements.json — should silently do nothing."""
        self._load()
        self.assertEqual(self.manager.as_dict(), {})

    def test_empty_object(self):
        """Empty {} — should do nothing."""
        self._write_json({})
        self._load()
        self.assertEqual(self.manager.as_dict(), {})

    def test_single_replacement(self):
        """Single replacement entry registers correctly."""
        self._write_json({
            "OldNode": [{
                "new_node_id": "NewNode",
                "old_node_id": "OldNode",
                "input_mapping": [{"new_id": "model", "old_id": "ckpt_name"}],
                "output_mapping": [{"new_idx": 0, "old_idx": 0}],
            }]
        })
        self._load()
        result = self.manager.as_dict()
        self.assertIn("OldNode", result)
        self.assertEqual(len(result["OldNode"]), 1)
        entry = result["OldNode"][0]
        self.assertEqual(entry["new_node_id"], "NewNode")
        self.assertEqual(entry["old_node_id"], "OldNode")
        self.assertEqual(entry["input_mapping"], [{"new_id": "model", "old_id": "ckpt_name"}])
        self.assertEqual(entry["output_mapping"], [{"new_idx": 0, "old_idx": 0}])

    def test_multiple_replacements(self):
        """Multiple old_node_ids each with entries."""
        self._write_json({
            "NodeA": [{"new_node_id": "NodeB", "old_node_id": "NodeA"}],
            "NodeC": [{"new_node_id": "NodeD", "old_node_id": "NodeC"}],
        })
        self._load()
        result = self.manager.as_dict()
        self.assertEqual(len(result), 2)
        self.assertIn("NodeA", result)
        self.assertIn("NodeC", result)

    def test_multiple_alternatives_for_same_node(self):
        """Multiple replacement options for the same old node."""
        self._write_json({
            "OldNode": [
                {"new_node_id": "AltA", "old_node_id": "OldNode"},
                {"new_node_id": "AltB", "old_node_id": "OldNode"},
            ]
        })
        self._load()
        result = self.manager.as_dict()
        self.assertEqual(len(result["OldNode"]), 2)

    def test_null_mappings(self):
        """Null input/output mappings (trivial replacement)."""
        self._write_json({
            "OldNode": [{
                "new_node_id": "NewNode",
                "old_node_id": "OldNode",
                "input_mapping": None,
                "output_mapping": None,
            }]
        })
        self._load()
        entry = self.manager.as_dict()["OldNode"][0]
        self.assertIsNone(entry["input_mapping"])
        self.assertIsNone(entry["output_mapping"])

    def test_old_node_id_defaults_to_key(self):
        """If old_node_id is missing from entry, uses the dict key."""
        self._write_json({
            "OldNode": [{"new_node_id": "NewNode"}]
        })
        self._load()
        entry = self.manager.as_dict()["OldNode"][0]
        self.assertEqual(entry["old_node_id"], "OldNode")

    def test_invalid_json_skips(self):
        """Invalid JSON file — should warn and skip, not crash."""
        self._write_text("{invalid json")
        self._load()
        self.assertEqual(self.manager.as_dict(), {})

    def test_non_object_json_skips(self):
        """JSON array instead of object — should warn and skip."""
        self._write_json([1, 2, 3])
        self._load()
        self.assertEqual(self.manager.as_dict(), {})

    def test_non_list_value_skips(self):
        """Value is not a list — should warn and skip that key."""
        self._write_json({
            "OldNode": "not a list",
            "GoodNode": [{"new_node_id": "NewNode", "old_node_id": "GoodNode"}],
        })
        self._load()
        result = self.manager.as_dict()
        self.assertNotIn("OldNode", result)
        self.assertIn("GoodNode", result)

    def test_with_old_widget_ids(self):
        """old_widget_ids are passed through."""
        self._write_json({
            "OldNode": [{
                "new_node_id": "NewNode",
                "old_node_id": "OldNode",
                "old_widget_ids": ["width", "height"],
            }]
        })
        self._load()
        entry = self.manager.as_dict()["OldNode"][0]
        self.assertEqual(entry["old_widget_ids"], ["width", "height"])

    def test_set_value_in_input_mapping(self):
        """input_mapping with set_value entries."""
        self._write_json({
            "OldNode": [{
                "new_node_id": "NewNode",
                "old_node_id": "OldNode",
                "input_mapping": [
                    {"new_id": "method", "set_value": "lanczos"},
                    {"new_id": "size", "old_id": "dimension"},
                ],
            }]
        })
        self._load()
        entry = self.manager.as_dict()["OldNode"][0]
        self.assertEqual(len(entry["input_mapping"]), 2)

    def test_missing_new_node_id_skipped(self):
        """Entry without new_node_id is skipped."""
        self._write_json({
            "OldNode": [
                {"old_node_id": "OldNode"},
                {"new_node_id": "", "old_node_id": "OldNode"},
                {"new_node_id": "ValidNew", "old_node_id": "OldNode"},
            ]
        })
        self._load()
        result = self.manager.as_dict()
        self.assertEqual(len(result["OldNode"]), 1)
        self.assertEqual(result["OldNode"][0]["new_node_id"], "ValidNew")

    def test_non_dict_entry_skipped(self):
        """Non-dict entries in the list are silently skipped."""
        self._write_json({
            "OldNode": [
                "not a dict",
                {"new_node_id": "NewNode", "old_node_id": "OldNode"},
            ]
        })
        self._load()
        result = self.manager.as_dict()
        self.assertEqual(len(result["OldNode"]), 1)

    def test_has_replacement_after_load(self):
        """Manager reports has_replacement correctly after JSON load."""
        self._write_json({
            "OldNode": [{"new_node_id": "NewNode", "old_node_id": "OldNode"}],
        })
        self.assertFalse(self.manager.has_replacement("OldNode"))
        self._load()
        self.assertTrue(self.manager.has_replacement("OldNode"))
        self.assertFalse(self.manager.has_replacement("UnknownNode"))
|
|
|
|
|
|
if __name__ == "__main__":
    # Executed directly (not via a test runner): discover and run the
    # TestCase classes in this module. __name__ is "__main__" here, which is
    # the same module unittest.main() targets by default.
    unittest.main(module=__name__)
|