mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-07 18:27:40 +08:00
fix: remove unused imports (ruff) and update tests for internal API
- Remove unused CacheContext and _serialize_cache_key imports from caching.py (now handled by _build_context helper) - Update test_cache_provider.py to use _-prefixed internal names - Update tests for new CacheContext.cache_key_hash field (str) - Make MockCacheProvider methods async to match ABC Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
4cbe4fe4c7
commit
da514866d6
@ -230,8 +230,7 @@ class BasicCache:
|
|||||||
"""Notify external providers of cache store (fire-and-forget)."""
|
"""Notify external providers of cache store (fire-and-forget)."""
|
||||||
from comfy_execution.cache_provider import (
|
from comfy_execution.cache_provider import (
|
||||||
_has_cache_providers, _get_cache_providers,
|
_has_cache_providers, _get_cache_providers,
|
||||||
CacheContext, CacheValue,
|
CacheValue, _contains_nan, _logger
|
||||||
_serialize_cache_key, _contains_nan, _logger
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Fast exit conditions
|
# Fast exit conditions
|
||||||
@ -269,8 +268,7 @@ class BasicCache:
|
|||||||
"""Check external providers for cached result."""
|
"""Check external providers for cached result."""
|
||||||
from comfy_execution.cache_provider import (
|
from comfy_execution.cache_provider import (
|
||||||
_has_cache_providers, _get_cache_providers,
|
_has_cache_providers, _get_cache_providers,
|
||||||
CacheContext, CacheValue,
|
CacheValue, _contains_nan, _logger
|
||||||
_contains_nan, _logger
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if self._is_subcache:
|
if self._is_subcache:
|
||||||
|
|||||||
@ -16,12 +16,12 @@ from comfy_execution.cache_provider import (
|
|||||||
CacheValue,
|
CacheValue,
|
||||||
register_cache_provider,
|
register_cache_provider,
|
||||||
unregister_cache_provider,
|
unregister_cache_provider,
|
||||||
get_cache_providers,
|
_get_cache_providers,
|
||||||
has_cache_providers,
|
_has_cache_providers,
|
||||||
clear_cache_providers,
|
_clear_cache_providers,
|
||||||
serialize_cache_key,
|
_serialize_cache_key,
|
||||||
contains_nan,
|
_contains_nan,
|
||||||
estimate_value_size,
|
_estimate_value_size,
|
||||||
_canonicalize,
|
_canonicalize,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -110,15 +110,15 @@ class TestCanonicalize:
|
|||||||
|
|
||||||
|
|
||||||
class TestSerializeCacheKey:
|
class TestSerializeCacheKey:
|
||||||
"""Test serialize_cache_key for deterministic hashing."""
|
"""Test _serialize_cache_key for deterministic hashing."""
|
||||||
|
|
||||||
def test_same_content_same_hash(self):
|
def test_same_content_same_hash(self):
|
||||||
"""Same content should produce same hash."""
|
"""Same content should produce same hash."""
|
||||||
key1 = frozenset([("node_1", frozenset([("input", "value")]))])
|
key1 = frozenset([("node_1", frozenset([("input", "value")]))])
|
||||||
key2 = frozenset([("node_1", frozenset([("input", "value")]))])
|
key2 = frozenset([("node_1", frozenset([("input", "value")]))])
|
||||||
|
|
||||||
hash1 = serialize_cache_key(key1)
|
hash1 = _serialize_cache_key(key1)
|
||||||
hash2 = serialize_cache_key(key2)
|
hash2 = _serialize_cache_key(key2)
|
||||||
|
|
||||||
assert hash1 == hash2
|
assert hash1 == hash2
|
||||||
|
|
||||||
@ -127,18 +127,18 @@ class TestSerializeCacheKey:
|
|||||||
key1 = frozenset([("node_1", "value_a")])
|
key1 = frozenset([("node_1", "value_a")])
|
||||||
key2 = frozenset([("node_1", "value_b")])
|
key2 = frozenset([("node_1", "value_b")])
|
||||||
|
|
||||||
hash1 = serialize_cache_key(key1)
|
hash1 = _serialize_cache_key(key1)
|
||||||
hash2 = serialize_cache_key(key2)
|
hash2 = _serialize_cache_key(key2)
|
||||||
|
|
||||||
assert hash1 != hash2
|
assert hash1 != hash2
|
||||||
|
|
||||||
def test_returns_bytes(self):
|
def test_returns_hex_string(self):
|
||||||
"""Should return bytes (SHA256 digest)."""
|
"""Should return hex string (SHA256 hex digest)."""
|
||||||
key = frozenset([("test", 123)])
|
key = frozenset([("test", 123)])
|
||||||
result = serialize_cache_key(key)
|
result = _serialize_cache_key(key)
|
||||||
|
|
||||||
assert isinstance(result, bytes)
|
assert isinstance(result, str)
|
||||||
assert len(result) == 32 # SHA256 produces 32 bytes
|
assert len(result) == 64 # SHA256 hex digest is 64 chars
|
||||||
|
|
||||||
def test_complex_nested_structure(self):
|
def test_complex_nested_structure(self):
|
||||||
"""Complex nested structures should hash deterministically."""
|
"""Complex nested structures should hash deterministically."""
|
||||||
@ -155,81 +155,81 @@ class TestSerializeCacheKey:
|
|||||||
])
|
])
|
||||||
|
|
||||||
# Hash twice to verify determinism
|
# Hash twice to verify determinism
|
||||||
hash1 = serialize_cache_key(key)
|
hash1 = _serialize_cache_key(key)
|
||||||
hash2 = serialize_cache_key(key)
|
hash2 = _serialize_cache_key(key)
|
||||||
|
|
||||||
assert hash1 == hash2
|
assert hash1 == hash2
|
||||||
|
|
||||||
def test_dict_in_cache_key(self):
|
def test_dict_in_cache_key(self):
|
||||||
"""Dicts passed directly to serialize_cache_key should work."""
|
"""Dicts passed directly to _serialize_cache_key should work."""
|
||||||
# This tests the _canonicalize function's ability to handle dicts
|
# This tests the _canonicalize function's ability to handle dicts
|
||||||
key = {"node_1": {"input": "value"}, "node_2": 42}
|
key = {"node_1": {"input": "value"}, "node_2": 42}
|
||||||
|
|
||||||
hash1 = serialize_cache_key(key)
|
hash1 = _serialize_cache_key(key)
|
||||||
hash2 = serialize_cache_key(key)
|
hash2 = _serialize_cache_key(key)
|
||||||
|
|
||||||
assert hash1 == hash2
|
assert hash1 == hash2
|
||||||
assert isinstance(hash1, bytes)
|
assert isinstance(hash1, str)
|
||||||
assert len(hash1) == 32
|
assert len(hash1) == 64
|
||||||
|
|
||||||
|
|
||||||
class TestContainsNan:
|
class TestContainsNan:
|
||||||
"""Test contains_nan utility function."""
|
"""Test _contains_nan utility function."""
|
||||||
|
|
||||||
def test_nan_float_detected(self):
|
def test_nan_float_detected(self):
|
||||||
"""NaN floats should be detected."""
|
"""NaN floats should be detected."""
|
||||||
assert contains_nan(float('nan')) is True
|
assert _contains_nan(float('nan')) is True
|
||||||
|
|
||||||
def test_regular_float_not_nan(self):
|
def test_regular_float_not_nan(self):
|
||||||
"""Regular floats should not be detected as NaN."""
|
"""Regular floats should not be detected as NaN."""
|
||||||
assert contains_nan(3.14) is False
|
assert _contains_nan(3.14) is False
|
||||||
assert contains_nan(0.0) is False
|
assert _contains_nan(0.0) is False
|
||||||
assert contains_nan(-1.5) is False
|
assert _contains_nan(-1.5) is False
|
||||||
|
|
||||||
def test_infinity_not_nan(self):
|
def test_infinity_not_nan(self):
|
||||||
"""Infinity is not NaN."""
|
"""Infinity is not NaN."""
|
||||||
assert contains_nan(float('inf')) is False
|
assert _contains_nan(float('inf')) is False
|
||||||
assert contains_nan(float('-inf')) is False
|
assert _contains_nan(float('-inf')) is False
|
||||||
|
|
||||||
def test_nan_in_list(self):
|
def test_nan_in_list(self):
|
||||||
"""NaN in list should be detected."""
|
"""NaN in list should be detected."""
|
||||||
assert contains_nan([1, 2, float('nan'), 4]) is True
|
assert _contains_nan([1, 2, float('nan'), 4]) is True
|
||||||
assert contains_nan([1, 2, 3, 4]) is False
|
assert _contains_nan([1, 2, 3, 4]) is False
|
||||||
|
|
||||||
def test_nan_in_tuple(self):
|
def test_nan_in_tuple(self):
|
||||||
"""NaN in tuple should be detected."""
|
"""NaN in tuple should be detected."""
|
||||||
assert contains_nan((1, float('nan'))) is True
|
assert _contains_nan((1, float('nan'))) is True
|
||||||
assert contains_nan((1, 2, 3)) is False
|
assert _contains_nan((1, 2, 3)) is False
|
||||||
|
|
||||||
def test_nan_in_frozenset(self):
|
def test_nan_in_frozenset(self):
|
||||||
"""NaN in frozenset should be detected."""
|
"""NaN in frozenset should be detected."""
|
||||||
assert contains_nan(frozenset([1, float('nan')])) is True
|
assert _contains_nan(frozenset([1, float('nan')])) is True
|
||||||
assert contains_nan(frozenset([1, 2, 3])) is False
|
assert _contains_nan(frozenset([1, 2, 3])) is False
|
||||||
|
|
||||||
def test_nan_in_dict_value(self):
|
def test_nan_in_dict_value(self):
|
||||||
"""NaN in dict value should be detected."""
|
"""NaN in dict value should be detected."""
|
||||||
assert contains_nan({"key": float('nan')}) is True
|
assert _contains_nan({"key": float('nan')}) is True
|
||||||
assert contains_nan({"key": 42}) is False
|
assert _contains_nan({"key": 42}) is False
|
||||||
|
|
||||||
def test_nan_in_nested_structure(self):
|
def test_nan_in_nested_structure(self):
|
||||||
"""NaN in deeply nested structure should be detected."""
|
"""NaN in deeply nested structure should be detected."""
|
||||||
nested = {"level1": [{"level2": (1, 2, float('nan'))}]}
|
nested = {"level1": [{"level2": (1, 2, float('nan'))}]}
|
||||||
assert contains_nan(nested) is True
|
assert _contains_nan(nested) is True
|
||||||
|
|
||||||
def test_non_numeric_types(self):
|
def test_non_numeric_types(self):
|
||||||
"""Non-numeric types should not be NaN."""
|
"""Non-numeric types should not be NaN."""
|
||||||
assert contains_nan("string") is False
|
assert _contains_nan("string") is False
|
||||||
assert contains_nan(None) is False
|
assert _contains_nan(None) is False
|
||||||
assert contains_nan(True) is False
|
assert _contains_nan(True) is False
|
||||||
|
|
||||||
|
|
||||||
class TestEstimateValueSize:
|
class TestEstimateValueSize:
|
||||||
"""Test estimate_value_size utility function."""
|
"""Test _estimate_value_size utility function."""
|
||||||
|
|
||||||
def test_empty_outputs(self):
|
def test_empty_outputs(self):
|
||||||
"""Empty outputs should have zero size."""
|
"""Empty outputs should have zero size."""
|
||||||
value = CacheValue(outputs=[])
|
value = CacheValue(outputs=[])
|
||||||
assert estimate_value_size(value) == 0
|
assert _estimate_value_size(value) == 0
|
||||||
|
|
||||||
@pytest.mark.skipif(
|
@pytest.mark.skipif(
|
||||||
not _torch_available(),
|
not _torch_available(),
|
||||||
@ -243,7 +243,7 @@ class TestEstimateValueSize:
|
|||||||
tensor = torch.zeros(1000, dtype=torch.float32)
|
tensor = torch.zeros(1000, dtype=torch.float32)
|
||||||
value = CacheValue(outputs=[[tensor]])
|
value = CacheValue(outputs=[[tensor]])
|
||||||
|
|
||||||
size = estimate_value_size(value)
|
size = _estimate_value_size(value)
|
||||||
assert size == 4000
|
assert size == 4000
|
||||||
|
|
||||||
@pytest.mark.skipif(
|
@pytest.mark.skipif(
|
||||||
@ -257,7 +257,7 @@ class TestEstimateValueSize:
|
|||||||
tensor = torch.zeros(100, dtype=torch.float32) # 400 bytes
|
tensor = torch.zeros(100, dtype=torch.float32) # 400 bytes
|
||||||
value = CacheValue(outputs=[[{"samples": tensor}]])
|
value = CacheValue(outputs=[[{"samples": tensor}]])
|
||||||
|
|
||||||
size = estimate_value_size(value)
|
size = _estimate_value_size(value)
|
||||||
assert size == 400
|
assert size == 400
|
||||||
|
|
||||||
|
|
||||||
@ -266,19 +266,19 @@ class TestProviderRegistry:
|
|||||||
|
|
||||||
def setup_method(self):
|
def setup_method(self):
|
||||||
"""Clear providers before each test."""
|
"""Clear providers before each test."""
|
||||||
clear_cache_providers()
|
_clear_cache_providers()
|
||||||
|
|
||||||
def teardown_method(self):
|
def teardown_method(self):
|
||||||
"""Clear providers after each test."""
|
"""Clear providers after each test."""
|
||||||
clear_cache_providers()
|
_clear_cache_providers()
|
||||||
|
|
||||||
def test_register_provider(self):
|
def test_register_provider(self):
|
||||||
"""Provider should be registered successfully."""
|
"""Provider should be registered successfully."""
|
||||||
provider = MockCacheProvider()
|
provider = MockCacheProvider()
|
||||||
register_cache_provider(provider)
|
register_cache_provider(provider)
|
||||||
|
|
||||||
assert has_cache_providers() is True
|
assert _has_cache_providers() is True
|
||||||
providers = get_cache_providers()
|
providers = _get_cache_providers()
|
||||||
assert len(providers) == 1
|
assert len(providers) == 1
|
||||||
assert providers[0] is provider
|
assert providers[0] is provider
|
||||||
|
|
||||||
@ -288,7 +288,7 @@ class TestProviderRegistry:
|
|||||||
register_cache_provider(provider)
|
register_cache_provider(provider)
|
||||||
unregister_cache_provider(provider)
|
unregister_cache_provider(provider)
|
||||||
|
|
||||||
assert has_cache_providers() is False
|
assert _has_cache_providers() is False
|
||||||
|
|
||||||
def test_multiple_providers(self):
|
def test_multiple_providers(self):
|
||||||
"""Multiple providers can be registered."""
|
"""Multiple providers can be registered."""
|
||||||
@ -298,7 +298,7 @@ class TestProviderRegistry:
|
|||||||
register_cache_provider(provider1)
|
register_cache_provider(provider1)
|
||||||
register_cache_provider(provider2)
|
register_cache_provider(provider2)
|
||||||
|
|
||||||
providers = get_cache_providers()
|
providers = _get_cache_providers()
|
||||||
assert len(providers) == 2
|
assert len(providers) == 2
|
||||||
|
|
||||||
def test_duplicate_registration_ignored(self):
|
def test_duplicate_registration_ignored(self):
|
||||||
@ -308,20 +308,20 @@ class TestProviderRegistry:
|
|||||||
register_cache_provider(provider)
|
register_cache_provider(provider)
|
||||||
register_cache_provider(provider) # Should be ignored
|
register_cache_provider(provider) # Should be ignored
|
||||||
|
|
||||||
providers = get_cache_providers()
|
providers = _get_cache_providers()
|
||||||
assert len(providers) == 1
|
assert len(providers) == 1
|
||||||
|
|
||||||
def test_clear_providers(self):
|
def test_clear_providers(self):
|
||||||
"""clear_cache_providers should remove all providers."""
|
"""_clear_cache_providers should remove all providers."""
|
||||||
provider1 = MockCacheProvider()
|
provider1 = MockCacheProvider()
|
||||||
provider2 = MockCacheProvider()
|
provider2 = MockCacheProvider()
|
||||||
|
|
||||||
register_cache_provider(provider1)
|
register_cache_provider(provider1)
|
||||||
register_cache_provider(provider2)
|
register_cache_provider(provider2)
|
||||||
clear_cache_providers()
|
_clear_cache_providers()
|
||||||
|
|
||||||
assert has_cache_providers() is False
|
assert _has_cache_providers() is False
|
||||||
assert len(get_cache_providers()) == 0
|
assert len(_get_cache_providers()) == 0
|
||||||
|
|
||||||
|
|
||||||
class TestCacheContext:
|
class TestCacheContext:
|
||||||
@ -333,15 +333,13 @@ class TestCacheContext:
|
|||||||
prompt_id="prompt-123",
|
prompt_id="prompt-123",
|
||||||
node_id="node-456",
|
node_id="node-456",
|
||||||
class_type="KSampler",
|
class_type="KSampler",
|
||||||
cache_key=frozenset([("test", "value")]),
|
cache_key_hash="abcdef1234567890" * 4,
|
||||||
cache_key_bytes=b"hash_bytes",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
assert context.prompt_id == "prompt-123"
|
assert context.prompt_id == "prompt-123"
|
||||||
assert context.node_id == "node-456"
|
assert context.node_id == "node-456"
|
||||||
assert context.class_type == "KSampler"
|
assert context.class_type == "KSampler"
|
||||||
assert context.cache_key == frozenset([("test", "value")])
|
assert context.cache_key_hash == "abcdef1234567890" * 4
|
||||||
assert context.cache_key_bytes == b"hash_bytes"
|
|
||||||
|
|
||||||
|
|
||||||
class TestCacheValue:
|
class TestCacheValue:
|
||||||
@ -362,9 +360,9 @@ class MockCacheProvider(CacheProvider):
|
|||||||
self.lookups = []
|
self.lookups = []
|
||||||
self.stores = []
|
self.stores = []
|
||||||
|
|
||||||
def on_lookup(self, context: CacheContext) -> Optional[CacheValue]:
|
async def on_lookup(self, context: CacheContext) -> Optional[CacheValue]:
|
||||||
self.lookups.append(context)
|
self.lookups.append(context)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def on_store(self, context: CacheContext, value: CacheValue) -> None:
|
async def on_store(self, context: CacheContext, value: CacheValue) -> None:
|
||||||
self.stores.append((context, value))
|
self.stores.append((context, value))
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user