fix: remove unused imports (ruff) and update tests for internal API

- Remove unused CacheContext and _serialize_cache_key imports from
  caching.py (now handled by _build_context helper)
- Update test_cache_provider.py to use _-prefixed internal names
- Update tests for new CacheContext.cache_key_hash field (str)
- Make MockCacheProvider methods async to match ABC

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Deep Mehta 2026-03-03 12:51:23 -08:00
parent 4cbe4fe4c7
commit da514866d6
2 changed files with 63 additions and 67 deletions

View File

@@ -230,8 +230,7 @@ class BasicCache:
"""Notify external providers of cache store (fire-and-forget).""" """Notify external providers of cache store (fire-and-forget)."""
from comfy_execution.cache_provider import ( from comfy_execution.cache_provider import (
_has_cache_providers, _get_cache_providers, _has_cache_providers, _get_cache_providers,
CacheContext, CacheValue, CacheValue, _contains_nan, _logger
_serialize_cache_key, _contains_nan, _logger
) )
# Fast exit conditions # Fast exit conditions
@@ -269,8 +268,7 @@ class BasicCache:
"""Check external providers for cached result.""" """Check external providers for cached result."""
from comfy_execution.cache_provider import ( from comfy_execution.cache_provider import (
_has_cache_providers, _get_cache_providers, _has_cache_providers, _get_cache_providers,
CacheContext, CacheValue, CacheValue, _contains_nan, _logger
_contains_nan, _logger
) )
if self._is_subcache: if self._is_subcache:

View File

@@ -16,12 +16,12 @@ from comfy_execution.cache_provider import (
CacheValue, CacheValue,
register_cache_provider, register_cache_provider,
unregister_cache_provider, unregister_cache_provider,
get_cache_providers, _get_cache_providers,
has_cache_providers, _has_cache_providers,
clear_cache_providers, _clear_cache_providers,
serialize_cache_key, _serialize_cache_key,
contains_nan, _contains_nan,
estimate_value_size, _estimate_value_size,
_canonicalize, _canonicalize,
) )
@@ -110,15 +110,15 @@ class TestCanonicalize:
class TestSerializeCacheKey: class TestSerializeCacheKey:
"""Test serialize_cache_key for deterministic hashing.""" """Test _serialize_cache_key for deterministic hashing."""
def test_same_content_same_hash(self): def test_same_content_same_hash(self):
"""Same content should produce same hash.""" """Same content should produce same hash."""
key1 = frozenset([("node_1", frozenset([("input", "value")]))]) key1 = frozenset([("node_1", frozenset([("input", "value")]))])
key2 = frozenset([("node_1", frozenset([("input", "value")]))]) key2 = frozenset([("node_1", frozenset([("input", "value")]))])
hash1 = serialize_cache_key(key1) hash1 = _serialize_cache_key(key1)
hash2 = serialize_cache_key(key2) hash2 = _serialize_cache_key(key2)
assert hash1 == hash2 assert hash1 == hash2
@@ -127,18 +127,18 @@ class TestSerializeCacheKey:
key1 = frozenset([("node_1", "value_a")]) key1 = frozenset([("node_1", "value_a")])
key2 = frozenset([("node_1", "value_b")]) key2 = frozenset([("node_1", "value_b")])
hash1 = serialize_cache_key(key1) hash1 = _serialize_cache_key(key1)
hash2 = serialize_cache_key(key2) hash2 = _serialize_cache_key(key2)
assert hash1 != hash2 assert hash1 != hash2
def test_returns_bytes(self): def test_returns_hex_string(self):
"""Should return bytes (SHA256 digest).""" """Should return hex string (SHA256 hex digest)."""
key = frozenset([("test", 123)]) key = frozenset([("test", 123)])
result = serialize_cache_key(key) result = _serialize_cache_key(key)
assert isinstance(result, bytes) assert isinstance(result, str)
assert len(result) == 32 # SHA256 produces 32 bytes assert len(result) == 64 # SHA256 hex digest is 64 chars
def test_complex_nested_structure(self): def test_complex_nested_structure(self):
"""Complex nested structures should hash deterministically.""" """Complex nested structures should hash deterministically."""
@@ -155,81 +155,81 @@ class TestSerializeCacheKey:
]) ])
# Hash twice to verify determinism # Hash twice to verify determinism
hash1 = serialize_cache_key(key) hash1 = _serialize_cache_key(key)
hash2 = serialize_cache_key(key) hash2 = _serialize_cache_key(key)
assert hash1 == hash2 assert hash1 == hash2
def test_dict_in_cache_key(self): def test_dict_in_cache_key(self):
"""Dicts passed directly to serialize_cache_key should work.""" """Dicts passed directly to _serialize_cache_key should work."""
# This tests the _canonicalize function's ability to handle dicts # This tests the _canonicalize function's ability to handle dicts
key = {"node_1": {"input": "value"}, "node_2": 42} key = {"node_1": {"input": "value"}, "node_2": 42}
hash1 = serialize_cache_key(key) hash1 = _serialize_cache_key(key)
hash2 = serialize_cache_key(key) hash2 = _serialize_cache_key(key)
assert hash1 == hash2 assert hash1 == hash2
assert isinstance(hash1, bytes) assert isinstance(hash1, str)
assert len(hash1) == 32 assert len(hash1) == 64
class TestContainsNan: class TestContainsNan:
"""Test contains_nan utility function.""" """Test _contains_nan utility function."""
def test_nan_float_detected(self): def test_nan_float_detected(self):
"""NaN floats should be detected.""" """NaN floats should be detected."""
assert contains_nan(float('nan')) is True assert _contains_nan(float('nan')) is True
def test_regular_float_not_nan(self): def test_regular_float_not_nan(self):
"""Regular floats should not be detected as NaN.""" """Regular floats should not be detected as NaN."""
assert contains_nan(3.14) is False assert _contains_nan(3.14) is False
assert contains_nan(0.0) is False assert _contains_nan(0.0) is False
assert contains_nan(-1.5) is False assert _contains_nan(-1.5) is False
def test_infinity_not_nan(self): def test_infinity_not_nan(self):
"""Infinity is not NaN.""" """Infinity is not NaN."""
assert contains_nan(float('inf')) is False assert _contains_nan(float('inf')) is False
assert contains_nan(float('-inf')) is False assert _contains_nan(float('-inf')) is False
def test_nan_in_list(self): def test_nan_in_list(self):
"""NaN in list should be detected.""" """NaN in list should be detected."""
assert contains_nan([1, 2, float('nan'), 4]) is True assert _contains_nan([1, 2, float('nan'), 4]) is True
assert contains_nan([1, 2, 3, 4]) is False assert _contains_nan([1, 2, 3, 4]) is False
def test_nan_in_tuple(self): def test_nan_in_tuple(self):
"""NaN in tuple should be detected.""" """NaN in tuple should be detected."""
assert contains_nan((1, float('nan'))) is True assert _contains_nan((1, float('nan'))) is True
assert contains_nan((1, 2, 3)) is False assert _contains_nan((1, 2, 3)) is False
def test_nan_in_frozenset(self): def test_nan_in_frozenset(self):
"""NaN in frozenset should be detected.""" """NaN in frozenset should be detected."""
assert contains_nan(frozenset([1, float('nan')])) is True assert _contains_nan(frozenset([1, float('nan')])) is True
assert contains_nan(frozenset([1, 2, 3])) is False assert _contains_nan(frozenset([1, 2, 3])) is False
def test_nan_in_dict_value(self): def test_nan_in_dict_value(self):
"""NaN in dict value should be detected.""" """NaN in dict value should be detected."""
assert contains_nan({"key": float('nan')}) is True assert _contains_nan({"key": float('nan')}) is True
assert contains_nan({"key": 42}) is False assert _contains_nan({"key": 42}) is False
def test_nan_in_nested_structure(self): def test_nan_in_nested_structure(self):
"""NaN in deeply nested structure should be detected.""" """NaN in deeply nested structure should be detected."""
nested = {"level1": [{"level2": (1, 2, float('nan'))}]} nested = {"level1": [{"level2": (1, 2, float('nan'))}]}
assert contains_nan(nested) is True assert _contains_nan(nested) is True
def test_non_numeric_types(self): def test_non_numeric_types(self):
"""Non-numeric types should not be NaN.""" """Non-numeric types should not be NaN."""
assert contains_nan("string") is False assert _contains_nan("string") is False
assert contains_nan(None) is False assert _contains_nan(None) is False
assert contains_nan(True) is False assert _contains_nan(True) is False
class TestEstimateValueSize: class TestEstimateValueSize:
"""Test estimate_value_size utility function.""" """Test _estimate_value_size utility function."""
def test_empty_outputs(self): def test_empty_outputs(self):
"""Empty outputs should have zero size.""" """Empty outputs should have zero size."""
value = CacheValue(outputs=[]) value = CacheValue(outputs=[])
assert estimate_value_size(value) == 0 assert _estimate_value_size(value) == 0
@pytest.mark.skipif( @pytest.mark.skipif(
not _torch_available(), not _torch_available(),
@@ -243,7 +243,7 @@ class TestEstimateValueSize:
tensor = torch.zeros(1000, dtype=torch.float32) tensor = torch.zeros(1000, dtype=torch.float32)
value = CacheValue(outputs=[[tensor]]) value = CacheValue(outputs=[[tensor]])
size = estimate_value_size(value) size = _estimate_value_size(value)
assert size == 4000 assert size == 4000
@pytest.mark.skipif( @pytest.mark.skipif(
@@ -257,7 +257,7 @@ class TestEstimateValueSize:
tensor = torch.zeros(100, dtype=torch.float32) # 400 bytes tensor = torch.zeros(100, dtype=torch.float32) # 400 bytes
value = CacheValue(outputs=[[{"samples": tensor}]]) value = CacheValue(outputs=[[{"samples": tensor}]])
size = estimate_value_size(value) size = _estimate_value_size(value)
assert size == 400 assert size == 400
@@ -266,19 +266,19 @@ class TestProviderRegistry:
def setup_method(self): def setup_method(self):
"""Clear providers before each test.""" """Clear providers before each test."""
clear_cache_providers() _clear_cache_providers()
def teardown_method(self): def teardown_method(self):
"""Clear providers after each test.""" """Clear providers after each test."""
clear_cache_providers() _clear_cache_providers()
def test_register_provider(self): def test_register_provider(self):
"""Provider should be registered successfully.""" """Provider should be registered successfully."""
provider = MockCacheProvider() provider = MockCacheProvider()
register_cache_provider(provider) register_cache_provider(provider)
assert has_cache_providers() is True assert _has_cache_providers() is True
providers = get_cache_providers() providers = _get_cache_providers()
assert len(providers) == 1 assert len(providers) == 1
assert providers[0] is provider assert providers[0] is provider
@@ -288,7 +288,7 @@ class TestProviderRegistry:
register_cache_provider(provider) register_cache_provider(provider)
unregister_cache_provider(provider) unregister_cache_provider(provider)
assert has_cache_providers() is False assert _has_cache_providers() is False
def test_multiple_providers(self): def test_multiple_providers(self):
"""Multiple providers can be registered.""" """Multiple providers can be registered."""
@@ -298,7 +298,7 @@ class TestProviderRegistry:
register_cache_provider(provider1) register_cache_provider(provider1)
register_cache_provider(provider2) register_cache_provider(provider2)
providers = get_cache_providers() providers = _get_cache_providers()
assert len(providers) == 2 assert len(providers) == 2
def test_duplicate_registration_ignored(self): def test_duplicate_registration_ignored(self):
@@ -308,20 +308,20 @@ class TestProviderRegistry:
register_cache_provider(provider) register_cache_provider(provider)
register_cache_provider(provider) # Should be ignored register_cache_provider(provider) # Should be ignored
providers = get_cache_providers() providers = _get_cache_providers()
assert len(providers) == 1 assert len(providers) == 1
def test_clear_providers(self): def test_clear_providers(self):
"""clear_cache_providers should remove all providers.""" """_clear_cache_providers should remove all providers."""
provider1 = MockCacheProvider() provider1 = MockCacheProvider()
provider2 = MockCacheProvider() provider2 = MockCacheProvider()
register_cache_provider(provider1) register_cache_provider(provider1)
register_cache_provider(provider2) register_cache_provider(provider2)
clear_cache_providers() _clear_cache_providers()
assert has_cache_providers() is False assert _has_cache_providers() is False
assert len(get_cache_providers()) == 0 assert len(_get_cache_providers()) == 0
class TestCacheContext: class TestCacheContext:
@@ -333,15 +333,13 @@ class TestCacheContext:
prompt_id="prompt-123", prompt_id="prompt-123",
node_id="node-456", node_id="node-456",
class_type="KSampler", class_type="KSampler",
cache_key=frozenset([("test", "value")]), cache_key_hash="abcdef1234567890" * 4,
cache_key_bytes=b"hash_bytes",
) )
assert context.prompt_id == "prompt-123" assert context.prompt_id == "prompt-123"
assert context.node_id == "node-456" assert context.node_id == "node-456"
assert context.class_type == "KSampler" assert context.class_type == "KSampler"
assert context.cache_key == frozenset([("test", "value")]) assert context.cache_key_hash == "abcdef1234567890" * 4
assert context.cache_key_bytes == b"hash_bytes"
class TestCacheValue: class TestCacheValue:
@@ -362,9 +360,9 @@ class MockCacheProvider(CacheProvider):
self.lookups = [] self.lookups = []
self.stores = [] self.stores = []
def on_lookup(self, context: CacheContext) -> Optional[CacheValue]: async def on_lookup(self, context: CacheContext) -> Optional[CacheValue]:
self.lookups.append(context) self.lookups.append(context)
return None return None
def on_store(self, context: CacheContext, value: CacheValue) -> None: async def on_store(self, context: CacheContext, value: CacheValue) -> None:
self.stores.append((context, value)) self.stores.append((context, value))