mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-08 10:47:32 +08:00
fix: generalize self-inequality check, fail-closed canonicalization
Address review feedback from guill: - Rename _contains_nan to _contains_self_unequal, use not (x == x) instead of math.isnan to catch any self-unequal value - Remove Unhashable and repr() fallbacks from _canonicalize; raise ValueError for unknown types so _serialize_cache_key returns None and external caching is skipped (fail-closed) - Update tests for renamed function and new fail-closed behavior Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
8ed3386d3b
commit
15a23ad5f6
@ -2,7 +2,6 @@ from typing import Any, Optional, Tuple, List
|
|||||||
import hashlib
|
import hashlib
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import math
|
|
||||||
import threading
|
import threading
|
||||||
|
|
||||||
# Public types — source of truth is comfy_api.latest._caching
|
# Public types — source of truth is comfy_api.latest._caching
|
||||||
@ -57,8 +56,9 @@ def _clear_cache_providers() -> None:
|
|||||||
def _canonicalize(obj: Any) -> Any:
|
def _canonicalize(obj: Any) -> Any:
|
||||||
# Convert to canonical JSON-serializable form with deterministic ordering.
|
# Convert to canonical JSON-serializable form with deterministic ordering.
|
||||||
# Frozensets have non-deterministic iteration order between Python sessions.
|
# Frozensets have non-deterministic iteration order between Python sessions.
|
||||||
|
# Raises ValueError for non-cacheable types (Unhashable, unknown) so that
|
||||||
|
# _serialize_cache_key returns None and external caching is skipped.
|
||||||
if isinstance(obj, frozenset):
|
if isinstance(obj, frozenset):
|
||||||
# Sort frozenset items for deterministic ordering
|
|
||||||
return ("__frozenset__", sorted(
|
return ("__frozenset__", sorted(
|
||||||
[_canonicalize(item) for item in obj],
|
[_canonicalize(item) for item in obj],
|
||||||
key=lambda x: json.dumps(x, sort_keys=True)
|
key=lambda x: json.dumps(x, sort_keys=True)
|
||||||
@ -78,12 +78,8 @@ def _canonicalize(obj: Any) -> Any:
|
|||||||
return obj
|
return obj
|
||||||
elif isinstance(obj, bytes):
|
elif isinstance(obj, bytes):
|
||||||
return ("__bytes__", obj.hex())
|
return ("__bytes__", obj.hex())
|
||||||
elif hasattr(obj, 'value'):
|
|
||||||
# Handle Unhashable class from ComfyUI
|
|
||||||
return ("__unhashable__", _canonicalize(getattr(obj, 'value', None)))
|
|
||||||
else:
|
else:
|
||||||
# For other types, use repr as fallback
|
raise ValueError(f"Cannot canonicalize type: {type(obj).__name__}")
|
||||||
return ("__repr__", repr(obj))
|
|
||||||
|
|
||||||
|
|
||||||
def _serialize_cache_key(cache_key: Any) -> Optional[str]:
|
def _serialize_cache_key(cache_key: Any) -> Optional[str]:
|
||||||
@ -98,25 +94,20 @@ def _serialize_cache_key(cache_key: Any) -> Optional[str]:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def _contains_nan(obj: Any) -> bool:
|
def _contains_self_unequal(obj: Any) -> bool:
|
||||||
# NaN != NaN so local cache never hits, but serialized NaN would match.
|
# Local cache matches by ==. Values where not (x == x) (NaN, etc.) will
|
||||||
# Skip external caching for keys containing NaN.
|
# never hit locally, but serialized form would match externally. Skip these.
|
||||||
if isinstance(obj, float):
|
try:
|
||||||
try:
|
if not (obj == obj):
|
||||||
return math.isnan(obj)
|
return True
|
||||||
except (TypeError, ValueError):
|
except Exception:
|
||||||
return False
|
return True
|
||||||
if hasattr(obj, 'value'): # Unhashable class
|
|
||||||
val = getattr(obj, 'value', None)
|
|
||||||
if isinstance(val, float):
|
|
||||||
try:
|
|
||||||
return math.isnan(val)
|
|
||||||
except (TypeError, ValueError):
|
|
||||||
return False
|
|
||||||
if isinstance(obj, (frozenset, tuple, list, set)):
|
if isinstance(obj, (frozenset, tuple, list, set)):
|
||||||
return any(_contains_nan(item) for item in obj)
|
return any(_contains_self_unequal(item) for item in obj)
|
||||||
if isinstance(obj, dict):
|
if isinstance(obj, dict):
|
||||||
return any(_contains_nan(k) or _contains_nan(v) for k, v in obj.items())
|
return any(_contains_self_unequal(k) or _contains_self_unequal(v) for k, v in obj.items())
|
||||||
|
if hasattr(obj, 'value'):
|
||||||
|
return _contains_self_unequal(obj.value)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -238,7 +238,7 @@ class BasicCache:
|
|||||||
async def _notify_providers_store(self, node_id, cache_key, value):
|
async def _notify_providers_store(self, node_id, cache_key, value):
|
||||||
from comfy_execution.cache_provider import (
|
from comfy_execution.cache_provider import (
|
||||||
_has_cache_providers, _get_cache_providers,
|
_has_cache_providers, _get_cache_providers,
|
||||||
CacheValue, _contains_nan, _logger
|
CacheValue, _contains_self_unequal, _logger
|
||||||
)
|
)
|
||||||
|
|
||||||
if self._is_subcache:
|
if self._is_subcache:
|
||||||
@ -247,7 +247,7 @@ class BasicCache:
|
|||||||
return
|
return
|
||||||
if not self._is_external_cacheable_value(value):
|
if not self._is_external_cacheable_value(value):
|
||||||
return
|
return
|
||||||
if _contains_nan(cache_key):
|
if _contains_self_unequal(cache_key):
|
||||||
return
|
return
|
||||||
|
|
||||||
context = self._build_context(node_id, cache_key)
|
context = self._build_context(node_id, cache_key)
|
||||||
@ -273,14 +273,14 @@ class BasicCache:
|
|||||||
async def _check_providers_lookup(self, node_id, cache_key):
|
async def _check_providers_lookup(self, node_id, cache_key):
|
||||||
from comfy_execution.cache_provider import (
|
from comfy_execution.cache_provider import (
|
||||||
_has_cache_providers, _get_cache_providers,
|
_has_cache_providers, _get_cache_providers,
|
||||||
CacheValue, _contains_nan, _logger
|
CacheValue, _contains_self_unequal, _logger
|
||||||
)
|
)
|
||||||
|
|
||||||
if self._is_subcache:
|
if self._is_subcache:
|
||||||
return None
|
return None
|
||||||
if not _has_cache_providers():
|
if not _has_cache_providers():
|
||||||
return None
|
return None
|
||||||
if _contains_nan(cache_key):
|
if _contains_self_unequal(cache_key):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
context = self._build_context(node_id, cache_key)
|
context = self._build_context(node_id, cache_key)
|
||||||
|
|||||||
@ -20,7 +20,7 @@ from comfy_execution.cache_provider import (
|
|||||||
_has_cache_providers,
|
_has_cache_providers,
|
||||||
_clear_cache_providers,
|
_clear_cache_providers,
|
||||||
_serialize_cache_key,
|
_serialize_cache_key,
|
||||||
_contains_nan,
|
_contains_self_unequal,
|
||||||
_estimate_value_size,
|
_estimate_value_size,
|
||||||
_canonicalize,
|
_canonicalize,
|
||||||
)
|
)
|
||||||
@ -108,6 +108,21 @@ class TestCanonicalize:
|
|||||||
assert result1 == result2
|
assert result1 == result2
|
||||||
assert result1[0] == "__set__"
|
assert result1[0] == "__set__"
|
||||||
|
|
||||||
|
def test_unknown_type_raises(self):
|
||||||
|
"""Unknown types should raise ValueError (fail-closed)."""
|
||||||
|
class CustomObj:
|
||||||
|
pass
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
_canonicalize(CustomObj())
|
||||||
|
|
||||||
|
def test_object_with_value_attr_raises(self):
|
||||||
|
"""Objects with .value attribute (Unhashable-like) should raise ValueError."""
|
||||||
|
class FakeUnhashable:
|
||||||
|
def __init__(self):
|
||||||
|
self.value = float('nan')
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
_canonicalize(FakeUnhashable())
|
||||||
|
|
||||||
|
|
||||||
class TestSerializeCacheKey:
|
class TestSerializeCacheKey:
|
||||||
"""Test _serialize_cache_key for deterministic hashing."""
|
"""Test _serialize_cache_key for deterministic hashing."""
|
||||||
@ -162,7 +177,6 @@ class TestSerializeCacheKey:
|
|||||||
|
|
||||||
def test_dict_in_cache_key(self):
|
def test_dict_in_cache_key(self):
|
||||||
"""Dicts passed directly to _serialize_cache_key should work."""
|
"""Dicts passed directly to _serialize_cache_key should work."""
|
||||||
# This tests the _canonicalize function's ability to handle dicts
|
|
||||||
key = {"node_1": {"input": "value"}, "node_2": 42}
|
key = {"node_1": {"input": "value"}, "node_2": 42}
|
||||||
|
|
||||||
hash1 = _serialize_cache_key(key)
|
hash1 = _serialize_cache_key(key)
|
||||||
@ -172,55 +186,75 @@ class TestSerializeCacheKey:
|
|||||||
assert isinstance(hash1, str)
|
assert isinstance(hash1, str)
|
||||||
assert len(hash1) == 64
|
assert len(hash1) == 64
|
||||||
|
|
||||||
|
def test_unknown_type_returns_none(self):
|
||||||
|
"""Non-cacheable types should return None (fail-closed)."""
|
||||||
|
class CustomObj:
|
||||||
|
pass
|
||||||
|
assert _serialize_cache_key(CustomObj()) is None
|
||||||
|
|
||||||
class TestContainsNan:
|
|
||||||
"""Test _contains_nan utility function."""
|
class TestContainsSelfUnequal:
|
||||||
|
"""Test _contains_self_unequal utility function."""
|
||||||
|
|
||||||
def test_nan_float_detected(self):
|
def test_nan_float_detected(self):
|
||||||
"""NaN floats should be detected."""
|
"""NaN floats should be detected (not equal to itself)."""
|
||||||
assert _contains_nan(float('nan')) is True
|
assert _contains_self_unequal(float('nan')) is True
|
||||||
|
|
||||||
def test_regular_float_not_nan(self):
|
def test_regular_float_not_detected(self):
|
||||||
"""Regular floats should not be detected as NaN."""
|
"""Regular floats are equal to themselves."""
|
||||||
assert _contains_nan(3.14) is False
|
assert _contains_self_unequal(3.14) is False
|
||||||
assert _contains_nan(0.0) is False
|
assert _contains_self_unequal(0.0) is False
|
||||||
assert _contains_nan(-1.5) is False
|
assert _contains_self_unequal(-1.5) is False
|
||||||
|
|
||||||
def test_infinity_not_nan(self):
|
def test_infinity_not_detected(self):
|
||||||
"""Infinity is not NaN."""
|
"""Infinity is equal to itself."""
|
||||||
assert _contains_nan(float('inf')) is False
|
assert _contains_self_unequal(float('inf')) is False
|
||||||
assert _contains_nan(float('-inf')) is False
|
assert _contains_self_unequal(float('-inf')) is False
|
||||||
|
|
||||||
def test_nan_in_list(self):
|
def test_nan_in_list(self):
|
||||||
"""NaN in list should be detected."""
|
"""NaN in list should be detected."""
|
||||||
assert _contains_nan([1, 2, float('nan'), 4]) is True
|
assert _contains_self_unequal([1, 2, float('nan'), 4]) is True
|
||||||
assert _contains_nan([1, 2, 3, 4]) is False
|
assert _contains_self_unequal([1, 2, 3, 4]) is False
|
||||||
|
|
||||||
def test_nan_in_tuple(self):
|
def test_nan_in_tuple(self):
|
||||||
"""NaN in tuple should be detected."""
|
"""NaN in tuple should be detected."""
|
||||||
assert _contains_nan((1, float('nan'))) is True
|
assert _contains_self_unequal((1, float('nan'))) is True
|
||||||
assert _contains_nan((1, 2, 3)) is False
|
assert _contains_self_unequal((1, 2, 3)) is False
|
||||||
|
|
||||||
def test_nan_in_frozenset(self):
|
def test_nan_in_frozenset(self):
|
||||||
"""NaN in frozenset should be detected."""
|
"""NaN in frozenset should be detected."""
|
||||||
assert _contains_nan(frozenset([1, float('nan')])) is True
|
assert _contains_self_unequal(frozenset([1, float('nan')])) is True
|
||||||
assert _contains_nan(frozenset([1, 2, 3])) is False
|
assert _contains_self_unequal(frozenset([1, 2, 3])) is False
|
||||||
|
|
||||||
def test_nan_in_dict_value(self):
|
def test_nan_in_dict_value(self):
|
||||||
"""NaN in dict value should be detected."""
|
"""NaN in dict value should be detected."""
|
||||||
assert _contains_nan({"key": float('nan')}) is True
|
assert _contains_self_unequal({"key": float('nan')}) is True
|
||||||
assert _contains_nan({"key": 42}) is False
|
assert _contains_self_unequal({"key": 42}) is False
|
||||||
|
|
||||||
def test_nan_in_nested_structure(self):
|
def test_nan_in_nested_structure(self):
|
||||||
"""NaN in deeply nested structure should be detected."""
|
"""NaN in deeply nested structure should be detected."""
|
||||||
nested = {"level1": [{"level2": (1, 2, float('nan'))}]}
|
nested = {"level1": [{"level2": (1, 2, float('nan'))}]}
|
||||||
assert _contains_nan(nested) is True
|
assert _contains_self_unequal(nested) is True
|
||||||
|
|
||||||
def test_non_numeric_types(self):
|
def test_non_numeric_types(self):
|
||||||
"""Non-numeric types should not be NaN."""
|
"""Non-numeric types should not be self-unequal."""
|
||||||
assert _contains_nan("string") is False
|
assert _contains_self_unequal("string") is False
|
||||||
assert _contains_nan(None) is False
|
assert _contains_self_unequal(None) is False
|
||||||
assert _contains_nan(True) is False
|
assert _contains_self_unequal(True) is False
|
||||||
|
|
||||||
|
def test_object_with_nan_value_attribute(self):
|
||||||
|
"""Objects wrapping NaN in .value should be detected."""
|
||||||
|
class NanWrapper:
|
||||||
|
def __init__(self):
|
||||||
|
self.value = float('nan')
|
||||||
|
assert _contains_self_unequal(NanWrapper()) is True
|
||||||
|
|
||||||
|
def test_custom_self_unequal_object(self):
|
||||||
|
"""Custom objects where not (x == x) should be detected."""
|
||||||
|
class NeverEqual:
|
||||||
|
def __eq__(self, other):
|
||||||
|
return False
|
||||||
|
assert _contains_self_unequal(NeverEqual()) is True
|
||||||
|
|
||||||
|
|
||||||
class TestEstimateValueSize:
|
class TestEstimateValueSize:
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user