# Mirror of https://github.com/comfyanonymous/ComfyUI.git
# Synced 2026-03-16 22:58:19 +08:00
# 290 lines, 9.7 KiB, Python
"""Unit tests for cache-signature canonicalization hardening."""
|
|
|
|
import asyncio
|
|
import importlib
|
|
import sys
|
|
import types
|
|
|
|
import pytest
|
|
|
|
|
|
class _DummyNode:
    """Minimal node stub used to satisfy cache-signature class lookups.

    The only member is `INPUT_TYPES`, which reports an empty "required" schema.
    """

    @staticmethod
    def INPUT_TYPES():
        """Return a minimal empty input schema for unit tests.

        A new dict is built on every call, so callers never share state.
        """
        return {"required": {}}
|
|
|
|
|
|
class _FakeDynPrompt:
    """Small DynamicPrompt stand-in with only the methods these tests need.

    Nodes are kept in the mapping handed to the constructor and looked up by id.
    """

    def __init__(self, nodes_by_id):
        """Store test nodes by id.

        The mapping is kept by reference, not copied.
        """
        self._nodes_by_id = nodes_by_id

    def has_node(self, node_id):
        """Return whether the fake prompt contains the requested node."""
        return node_id in self._nodes_by_id

    def get_node(self, node_id):
        """Return the stored node payload for the requested id.

        Raises whatever the underlying mapping raises for a missing key
        (``KeyError`` for a plain dict).
        """
        return self._nodes_by_id[node_id]
|
|
|
|
|
|
class _FakeIsChangedCache:
    """Async stub for `is_changed` lookups used by cache-key generation.

    Responses are canned: each node id maps to the value `get` will return.
    """

    def __init__(self, values):
        """Store canned `is_changed` responses keyed by node id."""
        self._values = values

    async def get(self, node_id):
        """Return the canned `is_changed` value for a node.

        Declared `async` so callers can await it; the lookup itself is a plain
        mapping access and raises the mapping's error for an unknown id.
        """
        return self._values[node_id]
|
|
|
|
|
|
class _OpaqueValue:
    """Hashable opaque object used to exercise fail-closed unordered hashing paths.

    The class defines nothing, so instances keep the default identity-based
    equality and hashing inherited from `object`.
    """
|
|
|
|
|
|
@pytest.fixture
def caching_module(monkeypatch):
    """Import `comfy_execution.caching` with lightweight stub dependencies.

    Empty stand-in modules for ``torch``, ``psutil``, ``nodes`` and
    ``comfy_execution.graph`` are installed in ``sys.modules`` through
    ``monkeypatch`` before the caching module is imported.

    Returns:
        A ``(module, nodes_module)`` tuple: the freshly imported caching module
        and the stub ``nodes`` module, whose ``NODE_CLASS_MAPPINGS`` dict tests
        can populate.
    """
    torch_module = types.ModuleType("torch")
    psutil_module = types.ModuleType("psutil")
    nodes_module = types.ModuleType("nodes")
    nodes_module.NODE_CLASS_MAPPINGS = {}
    graph_module = types.ModuleType("comfy_execution.graph")

    class DynamicPrompt:
        """Placeholder graph type so the caching module can import cleanly."""

    graph_module.DynamicPrompt = DynamicPrompt

    monkeypatch.setitem(sys.modules, "torch", torch_module)
    monkeypatch.setitem(sys.modules, "psutil", psutil_module)
    monkeypatch.setitem(sys.modules, "nodes", nodes_module)
    monkeypatch.setitem(sys.modules, "comfy_execution.graph", graph_module)
    # Drop any previously imported copy so the import below re-executes the
    # module against the stubs installed above.
    monkeypatch.delitem(sys.modules, "comfy_execution.caching", raising=False)

    module = importlib.import_module("comfy_execution.caching")
    # NOTE(review): the reload right after a fresh import looks redundant;
    # kept as-is -- confirm whether it guards against some cached state.
    module = importlib.reload(module)
    return module, nodes_module
|
|
|
|
|
|
def test_signature_to_hashable_handles_shared_builtin_substructures(caching_module):
    """Shared built-in substructures should canonicalize without collapsing to Unhashable."""
    caching, _ = caching_module
    # The same list object appears twice in the outer list.
    shared = [{"value": 1}, {"value": 2}]

    signature = caching._signature_to_hashable([shared, shared])

    assert signature[0] == "list"
    # Both occurrences of the shared list must canonicalize identically.
    assert signature[1][0] == signature[1][1]
    assert signature[1][0][0] == "list"
    assert signature[1][0][1][0] == ("dict", (("value", 1),))
    assert signature[1][0][1][1] == ("dict", (("value", 2),))
|
|
|
|
|
|
def test_signature_to_hashable_fails_closed_on_opaque_values(caching_module):
    """Opaque values should collapse the full signature to Unhashable immediately."""
    caching, _ = caching_module

    # A bare `object()` sits next to an ordinary string inside the list.
    signature = caching._signature_to_hashable(["safe", object()])

    assert isinstance(signature, caching.Unhashable)
|
|
|
|
|
|
def test_signature_to_hashable_stops_descending_after_failure(caching_module, monkeypatch):
    """Once canonicalization fails, later recursive descent should stop immediately."""
    caching, _ = caching_module
    original = caching._signature_to_hashable_impl
    marker = object()
    marker_seen = False

    def tracking_canonicalize(obj, *args, **kwargs):
        """Track whether recursion reaches the nested marker after failure."""
        nonlocal marker_seen
        if obj is marker:
            marker_seen = True
        return original(obj, *args, **kwargs)

    monkeypatch.setattr(caching, "_signature_to_hashable_impl", tracking_canonicalize)

    # The opaque object comes first; the marker is nested in the element after it.
    signature = caching._signature_to_hashable([object(), [marker]])

    assert isinstance(signature, caching.Unhashable)
    assert marker_seen is False
|
|
|
|
|
|
def test_signature_to_hashable_snapshots_list_before_recursing(caching_module, monkeypatch):
    """List canonicalization should read a point-in-time snapshot before recursive descent."""
    caching, _ = caching_module
    original = caching._signature_to_hashable_impl
    marker = ("marker",)
    values = [marker, 2]

    def mutating_canonicalize(obj, *args, **kwargs):
        """Mutate the live list during recursion to verify snapshot-based traversal."""
        if obj is marker:
            # Overwrite the second element while the first is being processed.
            values[1] = 3
        return original(obj, *args, **kwargs)

    monkeypatch.setattr(caching, "_signature_to_hashable_impl", mutating_canonicalize)

    signature = caching._signature_to_hashable(values)

    # The signature carries the pre-mutation value (2) ...
    assert signature == ("list", (("tuple", ("marker",)), 2))
    # ... even though the live list was mutated during the call.
    assert values[1] == 3
|
|
|
|
|
|
def test_signature_to_hashable_snapshots_dict_before_recursing(caching_module, monkeypatch):
    """Dict canonicalization should read a point-in-time snapshot before recursive descent."""
    caching, _ = caching_module
    original = caching._signature_to_hashable_impl
    marker = ("marker",)
    values = {"first": marker, "second": 2}

    def mutating_canonicalize(obj, *args, **kwargs):
        """Mutate the live dict during recursion to verify snapshot-based traversal."""
        if obj is marker:
            # Overwrite another entry while the "first" value is being processed.
            values["second"] = 3
        return original(obj, *args, **kwargs)

    monkeypatch.setattr(caching, "_signature_to_hashable_impl", mutating_canonicalize)

    signature = caching._signature_to_hashable(values)

    # The signature carries the pre-mutation value (2) ...
    assert signature == ("dict", (("first", ("tuple", ("marker",))), ("second", 2)))
    # ... even though the live dict was mutated during the call.
    assert values["second"] == 3
|
|
|
|
|
|
@pytest.mark.parametrize(
    "container_factory",
    [
        lambda marker: [marker],
        lambda marker: (marker,),
        lambda marker: {marker},
        lambda marker: frozenset({marker}),
        lambda marker: {marker: "value"},
    ],
)
def test_signature_to_hashable_fails_closed_on_runtimeerror(caching_module, monkeypatch, container_factory):
    """Traversal RuntimeError should degrade canonicalization to Unhashable.

    Parametrized over list, tuple, set, frozenset and dict (marker as key)
    containers, each holding the marker that triggers the error.
    """
    caching, _ = caching_module
    original = caching._signature_to_hashable_impl
    marker = object()

    def raising_canonicalize(obj, *args, **kwargs):
        """Raise a traversal RuntimeError for the marker value and delegate otherwise."""
        if obj is marker:
            raise RuntimeError("container changed during iteration")
        return original(obj, *args, **kwargs)

    monkeypatch.setattr(caching, "_signature_to_hashable_impl", raising_canonicalize)

    signature = caching._signature_to_hashable(container_factory(marker))

    assert isinstance(signature, caching.Unhashable)
|
|
|
|
|
|
def test_to_hashable_handles_shared_builtin_substructures(caching_module):
    """The legacy helper should still hash sanitized built-ins stably when used directly."""
    caching, _ = caching_module
    # The same list object appears twice in the outer list.
    shared = [{"value": 1}, {"value": 2}]

    sanitized = [shared, shared]
    hashable = caching.to_hashable(sanitized)

    assert hashable[0] == "list"
    # Both occurrences of the shared list must convert identically.
    assert hashable[1][0] == hashable[1][1]
    assert hashable[1][0][0] == "list"
|
|
|
|
|
|
@pytest.mark.parametrize(
    "container_factory",
    [
        set,
        frozenset,
    ],
)
def test_to_hashable_fails_closed_on_runtimeerror(caching_module, monkeypatch, container_factory):
    """Traversal RuntimeError should degrade unordered hash conversion to Unhashable.

    Parametrized over `set` and `frozenset`.
    """
    caching, _ = caching_module

    def raising_sort_key(obj, *args, **kwargs):
        """Raise a traversal RuntimeError while unordered values are canonicalized."""
        raise RuntimeError("container changed during iteration")

    # Replace the module's sort-key helper so every call to it raises.
    monkeypatch.setattr(caching, "_sanitized_sort_key", raising_sort_key)

    hashable = caching.to_hashable(container_factory({"value"}))

    assert isinstance(hashable, caching.Unhashable)
|
|
|
|
|
|
def test_signature_to_hashable_fails_closed_for_ambiguous_dict_ordering(caching_module):
    """Ambiguous dict sort ties should fail closed instead of depending on input order."""
    caching, _ = caching_module
    # Two entries whose keys and values are all distinct opaque instances.
    ambiguous = {
        _OpaqueValue(): _OpaqueValue(),
        _OpaqueValue(): _OpaqueValue(),
    }

    sanitized = caching._signature_to_hashable(ambiguous)

    assert isinstance(sanitized, caching.Unhashable)
|
|
|
|
|
|
@pytest.mark.parametrize(
    "container_factory",
    [
        set,
        frozenset,
    ],
)
def test_to_hashable_fails_closed_for_ambiguous_unordered_values(caching_module, container_factory):
    """Ambiguous unordered values should fail closed instead of depending on iteration order.

    Parametrized over `set` and `frozenset`.
    """
    caching, _ = caching_module
    # Two distinct opaque instances inside an unordered container.
    container = container_factory({_OpaqueValue(), _OpaqueValue()})

    hashable = caching.to_hashable(container)

    assert isinstance(hashable, caching.Unhashable)
|
|
|
|
|
|
def test_get_node_signature_returns_top_level_unhashable_for_tainted_signature(caching_module, monkeypatch):
    """Tainted full signatures should fail closed before `to_hashable()` runs."""
    caching, nodes_module = caching_module
    monkeypatch.setitem(nodes_module.NODE_CLASS_MAPPINGS, "UnitTestNode", _DummyNode)
    # Any call to `to_hashable` fails the test outright.
    monkeypatch.setattr(
        caching,
        "to_hashable",
        lambda *_args, **_kwargs: pytest.fail("to_hashable should not run for tainted signatures"),
    )

    # A list that contains itself: a self-referential `is_changed` value.
    is_changed_value = []
    is_changed_value.append(is_changed_value)

    dynprompt = _FakeDynPrompt(
        {
            "node": {
                "class_type": "UnitTestNode",
                "inputs": {"value": 5},
            }
        }
    )
    key_set = caching.CacheKeySetInputSignature(
        dynprompt,
        ["node"],
        _FakeIsChangedCache({"node": is_changed_value}),
    )

    signature = asyncio.run(key_set.get_node_signature(dynprompt, "node"))

    assert isinstance(signature, caching.Unhashable)
|