diff --git a/comfy_execution/cache_provider.py b/comfy_execution/cache_provider.py index 1e597465a..16890dfc4 100644 --- a/comfy_execution/cache_provider.py +++ b/comfy_execution/cache_provider.py @@ -2,7 +2,6 @@ from typing import Any, Optional, Tuple, List import hashlib import json import logging -import math import threading # Public types — source of truth is comfy_api.latest._caching @@ -57,8 +56,9 @@ def _clear_cache_providers() -> None: def _canonicalize(obj: Any) -> Any: # Convert to canonical JSON-serializable form with deterministic ordering. # Frozensets have non-deterministic iteration order between Python sessions. + # Raises ValueError for non-cacheable types (Unhashable, unknown) so that + # _serialize_cache_key returns None and external caching is skipped. if isinstance(obj, frozenset): - # Sort frozenset items for deterministic ordering return ("__frozenset__", sorted( [_canonicalize(item) for item in obj], key=lambda x: json.dumps(x, sort_keys=True) @@ -78,12 +78,8 @@ def _canonicalize(obj: Any) -> Any: return obj elif isinstance(obj, bytes): return ("__bytes__", obj.hex()) - elif hasattr(obj, 'value'): - # Handle Unhashable class from ComfyUI - return ("__unhashable__", _canonicalize(getattr(obj, 'value', None))) else: - # For other types, use repr as fallback - return ("__repr__", repr(obj)) + raise ValueError(f"Cannot canonicalize type: {type(obj).__name__}") def _serialize_cache_key(cache_key: Any) -> Optional[str]: @@ -98,25 +94,20 @@ def _serialize_cache_key(cache_key: Any) -> Optional[str]: return None -def _contains_nan(obj: Any) -> bool: - # NaN != NaN so local cache never hits, but serialized NaN would match. - # Skip external caching for keys containing NaN. - if isinstance(obj, float): - try: - return math.isnan(obj) - except (TypeError, ValueError): - return False - if hasattr(obj, 'value'): # Unhashable class - val = getattr(obj, 'value', None) - if isinstance(val, float): - try: - return math.isnan(val) - except (TypeError, ValueError): - return False +def _contains_self_unequal(obj: Any) -> bool: + # Local cache matches by ==. Values where not (x == x) (NaN, etc.) will + # never hit locally, but serialized form would match externally. Skip these. + try: + if not (obj == obj): + return True + except Exception: + return True if isinstance(obj, (frozenset, tuple, list, set)): - return any(_contains_nan(item) for item in obj) + return any(_contains_self_unequal(item) for item in obj) if isinstance(obj, dict): - return any(_contains_nan(k) or _contains_nan(v) for k, v in obj.items()) + return any(_contains_self_unequal(k) or _contains_self_unequal(v) for k, v in obj.items()) + if hasattr(obj, 'value'): + return _contains_self_unequal(obj.value) return False diff --git a/comfy_execution/caching.py b/comfy_execution/caching.py index 3b987846b..a479c6522 100644 --- a/comfy_execution/caching.py +++ b/comfy_execution/caching.py @@ -238,7 +238,7 @@ class BasicCache: async def _notify_providers_store(self, node_id, cache_key, value): from comfy_execution.cache_provider import ( _has_cache_providers, _get_cache_providers, - CacheValue, _contains_nan, _logger + CacheValue, _contains_self_unequal, _logger ) if self._is_subcache: @@ -247,7 +247,7 @@ class BasicCache: return if not self._is_external_cacheable_value(value): return - if _contains_nan(cache_key): + if _contains_self_unequal(cache_key): return context = self._build_context(node_id, cache_key) @@ -273,14 +273,14 @@ class BasicCache: async def _check_providers_lookup(self, node_id, cache_key): from comfy_execution.cache_provider import ( _has_cache_providers, _get_cache_providers, - CacheValue, _contains_nan, _logger + CacheValue, _contains_self_unequal, _logger ) if self._is_subcache: return None if not _has_cache_providers(): return None - if _contains_nan(cache_key): + if _contains_self_unequal(cache_key): return None context = self._build_context(node_id, cache_key) diff --git a/tests-unit/execution_test/test_cache_provider.py b/tests-unit/execution_test/test_cache_provider.py index a11673610..ccbad88ce 100644 --- a/tests-unit/execution_test/test_cache_provider.py +++ b/tests-unit/execution_test/test_cache_provider.py @@ -20,7 +20,7 @@ from comfy_execution.cache_provider import ( _has_cache_providers, _clear_cache_providers, _serialize_cache_key, - _contains_nan, + _contains_self_unequal, _estimate_value_size, _canonicalize, ) @@ -108,6 +108,21 @@ class TestCanonicalize: assert result1 == result2 assert result1[0] == "__set__" + def test_unknown_type_raises(self): + """Unknown types should raise ValueError (fail-closed).""" + class CustomObj: + pass + with pytest.raises(ValueError): + _canonicalize(CustomObj()) + + def test_object_with_value_attr_raises(self): + """Objects with .value attribute (Unhashable-like) should raise ValueError.""" + class FakeUnhashable: + def __init__(self): + self.value = float('nan') + with pytest.raises(ValueError): + _canonicalize(FakeUnhashable()) + class TestSerializeCacheKey: """Test _serialize_cache_key for deterministic hashing.""" @@ -162,7 +177,6 @@ class TestSerializeCacheKey: def test_dict_in_cache_key(self): """Dicts passed directly to _serialize_cache_key should work.""" - # This tests the _canonicalize function's ability to handle dicts key = {"node_1": {"input": "value"}, "node_2": 42} hash1 = _serialize_cache_key(key) @@ -172,55 +186,75 @@ class TestSerializeCacheKey: assert isinstance(hash1, str) assert len(hash1) == 64 + def test_unknown_type_returns_none(self): + """Non-cacheable types should return None (fail-closed).""" + class CustomObj: + pass + assert _serialize_cache_key(CustomObj()) is None -class TestContainsNan: - """Test _contains_nan utility function.""" + +class TestContainsSelfUnequal: + """Test _contains_self_unequal utility function.""" def test_nan_float_detected(self): - """NaN floats should be detected.""" - assert _contains_nan(float('nan')) is True + """NaN floats should be detected (not equal to itself).""" + assert _contains_self_unequal(float('nan')) is True - def test_regular_float_not_nan(self): - """Regular floats should not be detected as NaN.""" - assert _contains_nan(3.14) is False - assert _contains_nan(0.0) is False - assert _contains_nan(-1.5) is False + def test_regular_float_not_detected(self): + """Regular floats are equal to themselves.""" + assert _contains_self_unequal(3.14) is False + assert _contains_self_unequal(0.0) is False + assert _contains_self_unequal(-1.5) is False - def test_infinity_not_nan(self): - """Infinity is not NaN.""" - assert _contains_nan(float('inf')) is False - assert _contains_nan(float('-inf')) is False + def test_infinity_not_detected(self): + """Infinity is equal to itself.""" + assert _contains_self_unequal(float('inf')) is False + assert _contains_self_unequal(float('-inf')) is False def test_nan_in_list(self): """NaN in list should be detected.""" - assert _contains_nan([1, 2, float('nan'), 4]) is True - assert _contains_nan([1, 2, 3, 4]) is False + assert _contains_self_unequal([1, 2, float('nan'), 4]) is True + assert _contains_self_unequal([1, 2, 3, 4]) is False def test_nan_in_tuple(self): """NaN in tuple should be detected.""" - assert _contains_nan((1, float('nan'))) is True - assert _contains_nan((1, 2, 3)) is False + assert _contains_self_unequal((1, float('nan'))) is True + assert _contains_self_unequal((1, 2, 3)) is False def test_nan_in_frozenset(self): """NaN in frozenset should be detected.""" - assert _contains_nan(frozenset([1, float('nan')])) is True - assert _contains_nan(frozenset([1, 2, 3])) is False + assert _contains_self_unequal(frozenset([1, float('nan')])) is True + assert _contains_self_unequal(frozenset([1, 2, 3])) is False def test_nan_in_dict_value(self): """NaN in dict value should be detected.""" - assert _contains_nan({"key": float('nan')}) is True - assert _contains_nan({"key": 42}) is False + assert _contains_self_unequal({"key": float('nan')}) is True + assert _contains_self_unequal({"key": 42}) is False def test_nan_in_nested_structure(self): """NaN in deeply nested structure should be detected.""" nested = {"level1": [{"level2": (1, 2, float('nan'))}]} - assert _contains_nan(nested) is True + assert _contains_self_unequal(nested) is True def test_non_numeric_types(self): - """Non-numeric types should not be NaN.""" - assert _contains_nan("string") is False - assert _contains_nan(None) is False - assert _contains_nan(True) is False + """Non-numeric types should not be self-unequal.""" + assert _contains_self_unequal("string") is False + assert _contains_self_unequal(None) is False + assert _contains_self_unequal(True) is False + + def test_object_with_nan_value_attribute(self): + """Objects wrapping NaN in .value should be detected.""" + class NanWrapper: + def __init__(self): + self.value = float('nan') + assert _contains_self_unequal(NanWrapper()) is True + + def test_custom_self_unequal_object(self): + """Custom objects where not (x == x) should be detected.""" + class NeverEqual: + def __eq__(self, other): + return False + assert _contains_self_unequal(NeverEqual()) is True class TestEstimateValueSize: