ComfyUI/tests-unit/execution_test/caching_test.py

177 lines
5.6 KiB
Python

"""Unit tests for cache-signature sanitization and hash conversion hardening."""
import asyncio
import importlib
import sys
import types
import pytest
class _DummyNode:
"""Minimal node stub used to satisfy cache-signature class lookups."""
@staticmethod
def INPUT_TYPES():
"""Return a minimal empty input schema for unit tests."""
return {"required": {}}
class _FakeDynPrompt:
"""Small DynamicPrompt stand-in with only the methods these tests need."""
def __init__(self, nodes_by_id):
"""Store test nodes by id."""
self._nodes_by_id = nodes_by_id
def has_node(self, node_id):
"""Return whether the fake prompt contains the requested node."""
return node_id in self._nodes_by_id
def get_node(self, node_id):
"""Return the stored node payload for the requested id."""
return self._nodes_by_id[node_id]
class _FakeIsChangedCache:
"""Async stub for `is_changed` lookups used by cache-key generation."""
def __init__(self, values):
"""Store canned `is_changed` responses keyed by node id."""
self._values = values
async def get(self, node_id):
"""Return the canned `is_changed` value for a node."""
return self._values[node_id]
class _OpaqueValue:
"""Hashable opaque object used to exercise fail-closed unordered hashing paths."""
def _contains_unhashable(value, unhashable_type):
"""Return whether a nested built-in structure contains an Unhashable sentinel."""
if isinstance(value, unhashable_type):
return True
value_type = type(value)
if value_type is dict:
return any(
_contains_unhashable(key, unhashable_type) or _contains_unhashable(item, unhashable_type)
for key, item in value.items()
)
if value_type in (list, tuple, set, frozenset):
return any(_contains_unhashable(item, unhashable_type) for item in value)
return False
@pytest.fixture
def caching_module(monkeypatch):
"""Import `comfy_execution.caching` with lightweight stub dependencies."""
torch_module = types.ModuleType("torch")
psutil_module = types.ModuleType("psutil")
nodes_module = types.ModuleType("nodes")
nodes_module.NODE_CLASS_MAPPINGS = {}
graph_module = types.ModuleType("comfy_execution.graph")
class DynamicPrompt:
"""Placeholder graph type so the caching module can import cleanly."""
pass
graph_module.DynamicPrompt = DynamicPrompt
monkeypatch.setitem(sys.modules, "torch", torch_module)
monkeypatch.setitem(sys.modules, "psutil", psutil_module)
monkeypatch.setitem(sys.modules, "nodes", nodes_module)
monkeypatch.setitem(sys.modules, "comfy_execution.graph", graph_module)
monkeypatch.delitem(sys.modules, "comfy_execution.caching", raising=False)
module = importlib.import_module("comfy_execution.caching")
module = importlib.reload(module)
return module, nodes_module
def test_sanitize_signature_input_handles_shared_builtin_substructures(caching_module):
"""Shared built-in substructures should sanitize without collapsing to Unhashable."""
caching, _ = caching_module
shared = [{"value": 1}, {"value": 2}]
sanitized = caching._sanitize_signature_input([shared, shared])
assert isinstance(sanitized, list)
assert sanitized[0] == sanitized[1]
assert sanitized[0][0]["value"] == 1
assert sanitized[0][1]["value"] == 2
def test_to_hashable_handles_shared_builtin_substructures(caching_module):
"""Repeated sanitized content should hash stably for shared substructures."""
caching, _ = caching_module
shared = [{"value": 1}, {"value": 2}]
sanitized = caching._sanitize_signature_input([shared, shared])
hashable = caching.to_hashable(sanitized)
assert hashable[0] == "list"
assert hashable[1][0] == hashable[1][1]
assert hashable[1][0][0] == "list"
def test_sanitize_signature_input_fails_closed_for_ambiguous_dict_ordering(caching_module):
"""Ambiguous dict sort ties should fail closed instead of depending on input order."""
caching, _ = caching_module
ambiguous = {
_OpaqueValue(): _OpaqueValue(),
_OpaqueValue(): _OpaqueValue(),
}
sanitized = caching._sanitize_signature_input(ambiguous)
assert isinstance(sanitized, caching.Unhashable)
@pytest.mark.parametrize(
"container_factory",
[
set,
frozenset,
],
)
def test_to_hashable_fails_closed_for_ambiguous_unordered_values(caching_module, container_factory):
"""Ambiguous unordered values should fail closed instead of depending on iteration order."""
caching, _ = caching_module
container = container_factory({_OpaqueValue(), _OpaqueValue()})
hashable = caching.to_hashable(container)
assert isinstance(hashable, caching.Unhashable)
def test_get_node_signature_sanitizes_full_signature(caching_module, monkeypatch):
"""Recursive `is_changed` payloads should be sanitized inside the full node signature."""
caching, nodes_module = caching_module
monkeypatch.setitem(nodes_module.NODE_CLASS_MAPPINGS, "UnitTestNode", _DummyNode)
is_changed_value = []
is_changed_value.append(is_changed_value)
dynprompt = _FakeDynPrompt(
{
"node": {
"class_type": "UnitTestNode",
"inputs": {"value": 5},
}
}
)
key_set = caching.CacheKeySetInputSignature(
dynprompt,
["node"],
_FakeIsChangedCache({"node": is_changed_value}),
)
signature = asyncio.run(key_set.get_node_signature(dynprompt, "node"))
assert signature[0] == "list"
assert _contains_unhashable(signature, caching.Unhashable)