diff --git a/comfy_api/latest/__init__.py b/comfy_api/latest/__init__.py index b0fa14ff6..1a1b6d162 100644 --- a/comfy_api/latest/__init__.py +++ b/comfy_api/latest/__init__.py @@ -106,6 +106,42 @@ class Types: MESH = MESH VOXEL = VOXEL + +class Caching: + """ + External cache provider API for distributed caching. + + Enables sharing cached results across multiple ComfyUI instances + (e.g., Kubernetes pods) without monkey-patching internal methods. + + Example usage: + from comfy_api.latest import Caching + + class MyRedisProvider(Caching.CacheProvider): + def on_lookup(self, context): + # Check Redis for cached result + ... + + def on_store(self, context, value): + # Store to Redis (can be async internally) + ... + + Caching.register_provider(MyRedisProvider()) + """ + # Import from comfy_execution.cache_provider (source of truth) + from comfy_execution.cache_provider import ( + CacheProvider, + CacheContext, + CacheValue, + register_cache_provider as register_provider, + unregister_cache_provider as unregister_provider, + get_cache_providers as get_providers, + has_cache_providers as has_providers, + clear_cache_providers as clear_providers, + estimate_value_size, + ) + + ComfyAPI = ComfyAPI_latest # Create a synchronous version of the API @@ -125,6 +161,7 @@ __all__ = [ "Input", "InputImpl", "Types", + "Caching", "ComfyExtension", "io", "IO", diff --git a/comfy_execution/cache_provider.py b/comfy_execution/cache_provider.py new file mode 100644 index 000000000..a00e8cb15 --- /dev/null +++ b/comfy_execution/cache_provider.py @@ -0,0 +1,319 @@ +""" +External Cache Provider API for distributed caching. + +This module provides a public API for external cache providers, enabling +distributed caching across multiple ComfyUI instances (e.g., Kubernetes pods). + +Public API is also available via: + from comfy_api.latest import Caching + +Example usage: + from comfy_execution.cache_provider import ( + CacheProvider, CacheContext, CacheValue, register_cache_provider + ) + + class MyRedisProvider(CacheProvider): + def on_lookup(self, context: CacheContext) -> Optional[CacheValue]: + # Check Redis/GCS for cached result + ... + + def on_store(self, context: CacheContext, value: CacheValue) -> None: + # Store to Redis/GCS (can be async internally) + ... + + register_cache_provider(MyRedisProvider()) +""" + +from abc import ABC, abstractmethod +from typing import Any, Optional, Tuple, List +from dataclasses import dataclass +import hashlib +import json +import logging +import math +import pickle +import threading + +logger = logging.getLogger(__name__) + + +# ============================================================ +# Data Classes +# ============================================================ + +@dataclass +class CacheContext: + """Context passed to provider methods.""" + prompt_id: str # Current prompt execution ID + node_id: str # Node being cached + class_type: str # Node class type (e.g., "KSampler") + cache_key: Any # Raw cache key (frozenset structure) + cache_key_bytes: bytes # SHA256 hash for external storage key + + +@dataclass +class CacheValue: + """ + Value stored/retrieved from external cache. + + The ui field is optional - implementations may choose to skip it + (e.g., if it contains non-portable data like local file paths). 
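+    Providers that omit ui can return CacheValue(outputs=..., ui=None);
+    on an external hit the caller substitutes an empty dict before
+    building the local cache entry.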
+    """
+    outputs: list  # The tensor/value outputs
+    ui: Optional[dict] = None  # Optional UI data (may be skipped by implementations)
+
+
+# ============================================================
+# Provider Interface
+# ============================================================
+
+class CacheProvider(ABC):
+    """
+    Abstract base class for external cache providers.
+
+    Thread Safety:
+        Providers may be called from multiple threads. Implementations
+        must be thread-safe.
+
+    Error Handling:
+        All methods are wrapped in try/except by the caller. Exceptions
+        are logged but never propagate to break execution.
+
+    Performance Guidelines:
+        - on_lookup: Should complete in <500ms (including network)
+        - on_store: Can be async internally (fire-and-forget)
+        - should_cache: Should be fast (<1ms), called frequently
+    """
+
+    @abstractmethod
+    def on_lookup(self, context: CacheContext) -> Optional[CacheValue]:
+        """
+        Check external storage for a cached result.
+
+        Called AFTER a local cache miss (local-first for performance).
+
+        Returns:
+            CacheValue if found externally, None otherwise.
+
+        Important:
+            - Return None on any error (don't raise)
+            - Validate data integrity before returning
+        """
+        pass
+
+    @abstractmethod
+    def on_store(self, context: CacheContext, value: CacheValue) -> None:
+        """
+        Store a value to the external cache.
+
+        Called AFTER the value is stored in the local cache.
+
+        Important:
+            - Can be fire-and-forget (async internally)
+            - Should never block execution
+            - Handle serialization failures gracefully
+        """
+        pass
+
+    def should_cache(self, context: CacheContext, value: Optional[CacheValue] = None) -> bool:
+        """
+        Filter which nodes should be externally cached.
+
+        Called before on_lookup (value=None) and on_store (value provided).
+        Return False to skip external caching for this node.
+
+        Implementations can filter based on context.class_type, value size,
+        or any custom logic. Use estimate_value_size() to get the value size.
+
+        Default: Returns True (cache everything).
+        """
+        return True
+
+    def on_prompt_start(self, prompt_id: str) -> None:
+        """Called when prompt execution begins. Optional."""
+        pass
+
+    def on_prompt_end(self, prompt_id: str) -> None:
+        """Called when prompt execution ends. Optional."""
+        pass
+
+
+# ============================================================
+# Provider Registry
+# ============================================================
+
+_providers: List[CacheProvider] = []
+_providers_lock = threading.Lock()
+_providers_snapshot: Optional[Tuple[CacheProvider, ...]] = None
+
+
+def register_cache_provider(provider: CacheProvider) -> None:
+    """
+    Register an external cache provider.
+
+    Providers are called in registration order. First provider to return
+    a result from on_lookup wins.
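+
+    Registering the same provider instance twice is a no-op: a warning is
+    logged and the duplicate registration is ignored.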
+    """
+    global _providers_snapshot
+    with _providers_lock:
+        if provider in _providers:
+            logger.warning(f"Provider {provider.__class__.__name__} already registered")
+            return
+        _providers.append(provider)
+        _providers_snapshot = None  # Invalidate cache
+        logger.info(f"Registered cache provider: {provider.__class__.__name__}")
+
+
+def unregister_cache_provider(provider: CacheProvider) -> None:
+    """Remove a previously registered provider."""
+    global _providers_snapshot
+    with _providers_lock:
+        try:
+            _providers.remove(provider)
+            _providers_snapshot = None
+            logger.info(f"Unregistered cache provider: {provider.__class__.__name__}")
+        except ValueError:
+            logger.warning(f"Provider {provider.__class__.__name__} was not registered")
+
+
+def get_cache_providers() -> Tuple[CacheProvider, ...]:
+    """Get registered providers (cached for performance)."""
+    global _providers_snapshot
+    snapshot = _providers_snapshot
+    if snapshot is not None:
+        return snapshot
+    with _providers_lock:
+        if _providers_snapshot is not None:
+            return _providers_snapshot
+        _providers_snapshot = tuple(_providers)
+        return _providers_snapshot
+
+
+def has_cache_providers() -> bool:
+    """Fast, lock-free check for whether any providers are registered."""
+    return bool(_providers)
+
+
+def clear_cache_providers() -> None:
+    """Remove all providers. Useful for testing."""
+    global _providers_snapshot
+    with _providers_lock:
+        _providers.clear()
+        _providers_snapshot = None
+
+
+# ============================================================
+# Utilities
+# ============================================================
+
+def _canonicalize(obj: Any) -> Any:
+    """
+    Convert an object to a canonical, JSON-serializable form.
+
+    This ensures deterministic ordering regardless of Python's hash randomization,
+    which is critical for cross-pod cache key consistency. Frozensets in particular
+    have non-deterministic iteration order between Python sessions.
+    """
+    if isinstance(obj, frozenset):
+        # Sort frozenset items for deterministic ordering
+        return ("__frozenset__", sorted(
+            [_canonicalize(item) for item in obj],
+            key=lambda x: json.dumps(x, sort_keys=True)
+        ))
+    elif isinstance(obj, set):
+        return ("__set__", sorted(
+            [_canonicalize(item) for item in obj],
+            key=lambda x: json.dumps(x, sort_keys=True)
+        ))
+    elif isinstance(obj, tuple):
+        return ("__tuple__", [_canonicalize(item) for item in obj])
+    elif isinstance(obj, list):
+        return [_canonicalize(item) for item in obj]
+    elif isinstance(obj, dict):
+        # Sort by the stringified key so mixed-type keys (e.g., int and str)
+        # cannot raise TypeError, and ordering matches the emitted str(k) keys
+        return {str(k): _canonicalize(v) for k, v in sorted(obj.items(), key=lambda kv: str(kv[0]))}
+    elif isinstance(obj, (int, float, str, bool, type(None))):
+        return obj
+    elif isinstance(obj, bytes):
+        return ("__bytes__", obj.hex())
+    elif hasattr(obj, 'value'):
+        # Handle Unhashable class from ComfyUI
+        return ("__unhashable__", _canonicalize(getattr(obj, 'value', None)))
+    else:
+        # For other types, use repr as fallback
+        return ("__repr__", repr(obj))
+
+
+def serialize_cache_key(cache_key: Any) -> bytes:
+    """
+    Serialize a cache key to bytes for external storage.
+
+    Returns a SHA256 hash suitable for Redis/database keys.
+
+    Note: Uses canonicalize + JSON serialization instead of pickle because
+    pickle is NOT deterministic across Python sessions due to hash randomization
+    affecting frozenset iteration order. This is critical for distributed caching
+    where different pods need to compute the same hash for identical inputs.
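+
+    Example (illustrative key scheme, not part of this API):
+        key = "comfy:outputs:" + serialize_cache_key(cache_key).hex()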
+ """ + try: + canonical = _canonicalize(cache_key) + json_str = json.dumps(canonical, sort_keys=True, separators=(',', ':')) + return hashlib.sha256(json_str.encode('utf-8')).digest() + except Exception as e: + logger.warning(f"Failed to serialize cache key: {e}") + # Fallback to pickle (non-deterministic but better than nothing) + try: + serialized = pickle.dumps(cache_key, protocol=4) + return hashlib.sha256(serialized).digest() + except Exception: + return hashlib.sha256(str(id(cache_key)).encode()).digest() + + +def contains_nan(obj: Any) -> bool: + """ + Check if cache key contains NaN (indicates uncacheable node). + + NaN != NaN in Python, so local cache never hits. But serialized + NaN would match, causing incorrect external hits. Must skip these. + """ + if isinstance(obj, float): + try: + return math.isnan(obj) + except (TypeError, ValueError): + return False + if hasattr(obj, 'value'): # Unhashable class + val = getattr(obj, 'value', None) + if isinstance(val, float): + try: + return math.isnan(val) + except (TypeError, ValueError): + return False + if isinstance(obj, (frozenset, tuple, list, set)): + return any(contains_nan(item) for item in obj) + if isinstance(obj, dict): + return any(contains_nan(k) or contains_nan(v) for k, v in obj.items()) + return False + + +def estimate_value_size(value: CacheValue) -> int: + """Estimate serialized size in bytes. Useful for size-based filtering.""" + try: + import torch + except ImportError: + return 0 + + total = 0 + + def estimate(obj): + nonlocal total + if isinstance(obj, torch.Tensor): + total += obj.numel() * obj.element_size() + elif isinstance(obj, dict): + for v in obj.values(): + estimate(v) + elif isinstance(obj, (list, tuple)): + for item in obj: + estimate(item) + + for output in value.outputs: + estimate(output) + return total diff --git a/comfy_execution/caching.py b/comfy_execution/caching.py index 326a279fc..c666c4dc1 100644 --- a/comfy_execution/caching.py +++ b/comfy_execution/caching.py @@ -155,6 +155,10 @@ class BasicCache: self.cache = {} self.subcaches = {} + # External cache provider support + self._is_subcache = False + self._current_prompt_id = '' + async def set_prompt(self, dynprompt, node_ids, is_changed_cache): self.dynprompt = dynprompt self.cache_key_set = self.key_class(dynprompt, node_ids, is_changed_cache) @@ -201,20 +205,123 @@ class BasicCache: cache_key = self.cache_key_set.get_data_key(node_id) self.cache[cache_key] = value + # Notify external providers + self._notify_providers_store(node_id, cache_key, value) + def _get_immediate(self, node_id): if not self.initialized: return None cache_key = self.cache_key_set.get_data_key(node_id) + + # Check local cache first (fast path) if cache_key in self.cache: return self.cache[cache_key] - else: + + # Check external providers on local miss + external_result = self._check_providers_lookup(node_id, cache_key) + if external_result is not None: + self.cache[cache_key] = external_result # Warm local cache + return external_result + + return None + + def _notify_providers_store(self, node_id, cache_key, value): + """Notify external providers of cache store.""" + from comfy_execution.cache_provider import ( + has_cache_providers, get_cache_providers, + CacheContext, CacheValue, + serialize_cache_key, contains_nan, logger + ) + + # Fast exit conditions + if self._is_subcache: + return + if not has_cache_providers(): + return + if not self._is_external_cacheable_value(value): + return + if contains_nan(cache_key): + return + + context = CacheContext( + 
prompt_id=self._current_prompt_id, + node_id=node_id, + class_type=self._get_class_type(node_id), + cache_key=cache_key, + cache_key_bytes=serialize_cache_key(cache_key) + ) + cache_value = CacheValue(outputs=value.outputs, ui=value.ui) + + for provider in get_cache_providers(): + try: + if provider.should_cache(context, cache_value): + provider.on_store(context, cache_value) + except Exception as e: + logger.warning(f"Cache provider {provider.__class__.__name__} error on store: {e}") + + def _check_providers_lookup(self, node_id, cache_key): + """Check external providers for cached result.""" + from comfy_execution.cache_provider import ( + has_cache_providers, get_cache_providers, + CacheContext, CacheValue, + serialize_cache_key, contains_nan, logger + ) + + if self._is_subcache: return None + if not has_cache_providers(): + return None + if contains_nan(cache_key): + return None + + context = CacheContext( + prompt_id=self._current_prompt_id, + node_id=node_id, + class_type=self._get_class_type(node_id), + cache_key=cache_key, + cache_key_bytes=serialize_cache_key(cache_key) + ) + + for provider in get_cache_providers(): + try: + if not provider.should_cache(context): + continue + result = provider.on_lookup(context) + if result is not None: + if not isinstance(result, CacheValue): + logger.warning(f"Provider {provider.__class__.__name__} returned invalid type") + continue + if not isinstance(result.outputs, (list, tuple)): + logger.warning(f"Provider {provider.__class__.__name__} returned invalid outputs") + continue + # Import CacheEntry here to avoid circular import at module level + from execution import CacheEntry + return CacheEntry(ui=result.ui or {}, outputs=list(result.outputs)) + except Exception as e: + logger.warning(f"Cache provider {provider.__class__.__name__} error on lookup: {e}") + + return None + + def _is_external_cacheable_value(self, value): + """Check if value is a CacheEntry suitable for external caching (not objects cache).""" + return hasattr(value, 'outputs') and hasattr(value, 'ui') + + def _get_class_type(self, node_id): + """Get class_type for a node.""" + if not self.initialized or not self.dynprompt: + return '' + try: + return self.dynprompt.get_node(node_id).get('class_type', '') + except Exception: + return '' async def _ensure_subcache(self, node_id, children_ids): subcache_key = self.cache_key_set.get_subcache_key(node_id) subcache = self.subcaches.get(subcache_key, None) if subcache is None: subcache = BasicCache(self.key_class) + subcache._is_subcache = True # Mark as subcache - excludes from external caching + subcache._current_prompt_id = self._current_prompt_id # Propagate prompt ID self.subcaches[subcache_key] = subcache await subcache.set_prompt(self.dynprompt, children_ids, self.is_changed_cache) return subcache diff --git a/execution.py b/execution.py index 3dbab82e6..248ac3da3 100644 --- a/execution.py +++ b/execution.py @@ -683,6 +683,22 @@ class PromptExecutor: } self.add_message("execution_error", mes, broadcast=False) + def _notify_prompt_lifecycle(self, event: str, prompt_id: str): + """Notify external cache providers of prompt lifecycle events.""" + from comfy_execution.cache_provider import has_cache_providers, get_cache_providers, logger + + if not has_cache_providers(): + return + + for provider in get_cache_providers(): + try: + if event == "start": + provider.on_prompt_start(prompt_id) + elif event == "end": + provider.on_prompt_end(prompt_id) + except Exception as e: + logger.warning(f"Cache provider 
{provider.__class__.__name__} error on {event}: {e}") + def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]): asyncio.run(self.execute_async(prompt, prompt_id, extra_data, execute_outputs)) @@ -699,66 +715,77 @@ class PromptExecutor: self.status_messages = [] self.add_message("execution_start", { "prompt_id": prompt_id}, broadcast=False) - with torch.inference_mode(): - dynamic_prompt = DynamicPrompt(prompt) - reset_progress_state(prompt_id, dynamic_prompt) - add_progress_handler(WebUIProgressHandler(self.server)) - is_changed_cache = IsChangedCache(prompt_id, dynamic_prompt, self.caches.outputs) - for cache in self.caches.all: - await cache.set_prompt(dynamic_prompt, prompt.keys(), is_changed_cache) - cache.clean_unused() + # Set prompt ID on caches for external provider integration + for cache in self.caches.all: + cache._current_prompt_id = prompt_id - cached_nodes = [] - for node_id in prompt: - if self.caches.outputs.get(node_id) is not None: - cached_nodes.append(node_id) + # Notify external cache providers of prompt start + self._notify_prompt_lifecycle("start", prompt_id) - comfy.model_management.cleanup_models_gc() - self.add_message("execution_cached", - { "nodes": cached_nodes, "prompt_id": prompt_id}, - broadcast=False) - pending_subgraph_results = {} - pending_async_nodes = {} # TODO - Unify this with pending_subgraph_results - ui_node_outputs = {} - executed = set() - execution_list = ExecutionList(dynamic_prompt, self.caches.outputs) - current_outputs = self.caches.outputs.all_node_ids() - for node_id in list(execute_outputs): - execution_list.add_node(node_id) + try: + with torch.inference_mode(): + dynamic_prompt = DynamicPrompt(prompt) + reset_progress_state(prompt_id, dynamic_prompt) + add_progress_handler(WebUIProgressHandler(self.server)) + is_changed_cache = IsChangedCache(prompt_id, dynamic_prompt, self.caches.outputs) + for cache in self.caches.all: + await cache.set_prompt(dynamic_prompt, prompt.keys(), is_changed_cache) + cache.clean_unused() - while not execution_list.is_empty(): - node_id, error, ex = await execution_list.stage_node_execution() - if error is not None: - self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex) - break + cached_nodes = [] + for node_id in prompt: + if self.caches.outputs.get(node_id) is not None: + cached_nodes.append(node_id) - assert node_id is not None, "Node ID should not be None at this point" - result, error, ex = await execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_node_outputs) - self.success = result != ExecutionResult.FAILURE - if result == ExecutionResult.FAILURE: - self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex) - break - elif result == ExecutionResult.PENDING: - execution_list.unstage_node_execution() - else: # result == ExecutionResult.SUCCESS: - execution_list.complete_node_execution() - self.caches.outputs.poll(ram_headroom=self.cache_args["ram"]) - else: - # Only execute when the while-loop ends without break - self.add_message("execution_success", { "prompt_id": prompt_id }, broadcast=False) + comfy.model_management.cleanup_models_gc() + self.add_message("execution_cached", + { "nodes": cached_nodes, "prompt_id": prompt_id}, + broadcast=False) + pending_subgraph_results = {} + pending_async_nodes = {} # TODO - Unify this with pending_subgraph_results + ui_node_outputs = {} + 
executed = set() + execution_list = ExecutionList(dynamic_prompt, self.caches.outputs) + current_outputs = self.caches.outputs.all_node_ids() + for node_id in list(execute_outputs): + execution_list.add_node(node_id) - ui_outputs = {} - meta_outputs = {} - for node_id, ui_info in ui_node_outputs.items(): - ui_outputs[node_id] = ui_info["output"] - meta_outputs[node_id] = ui_info["meta"] - self.history_result = { - "outputs": ui_outputs, - "meta": meta_outputs, - } - self.server.last_node_id = None - if comfy.model_management.DISABLE_SMART_MEMORY: - comfy.model_management.unload_all_models() + while not execution_list.is_empty(): + node_id, error, ex = await execution_list.stage_node_execution() + if error is not None: + self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex) + break + + assert node_id is not None, "Node ID should not be None at this point" + result, error, ex = await execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_node_outputs) + self.success = result != ExecutionResult.FAILURE + if result == ExecutionResult.FAILURE: + self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex) + break + elif result == ExecutionResult.PENDING: + execution_list.unstage_node_execution() + else: # result == ExecutionResult.SUCCESS: + execution_list.complete_node_execution() + self.caches.outputs.poll(ram_headroom=self.cache_args["ram"]) + else: + # Only execute when the while-loop ends without break + self.add_message("execution_success", { "prompt_id": prompt_id }, broadcast=False) + + ui_outputs = {} + meta_outputs = {} + for node_id, ui_info in ui_node_outputs.items(): + ui_outputs[node_id] = ui_info["output"] + meta_outputs[node_id] = ui_info["meta"] + self.history_result = { + "outputs": ui_outputs, + "meta": meta_outputs, + } + self.server.last_node_id = None + if comfy.model_management.DISABLE_SMART_MEMORY: + comfy.model_management.unload_all_models() + finally: + # Notify external cache providers of prompt end + self._notify_prompt_lifecycle("end", prompt_id) async def validate_inputs(prompt_id, prompt, item, validated): diff --git a/tests-unit/execution_test/test_cache_provider.py b/tests-unit/execution_test/test_cache_provider.py new file mode 100644 index 000000000..c7484e1a1 --- /dev/null +++ b/tests-unit/execution_test/test_cache_provider.py @@ -0,0 +1,370 @@ +"""Tests for external cache provider API.""" + +import importlib.util +import pytest +from typing import Optional + + +def _torch_available() -> bool: + """Check if PyTorch is available.""" + return importlib.util.find_spec("torch") is not None + + +from comfy_execution.cache_provider import ( + CacheProvider, + CacheContext, + CacheValue, + register_cache_provider, + unregister_cache_provider, + get_cache_providers, + has_cache_providers, + clear_cache_providers, + serialize_cache_key, + contains_nan, + estimate_value_size, + _canonicalize, +) + + +class TestCanonicalize: + """Test _canonicalize function for deterministic ordering.""" + + def test_frozenset_ordering_is_deterministic(self): + """Frozensets should produce consistent canonical form regardless of iteration order.""" + # Create two frozensets with same content + fs1 = frozenset([("a", 1), ("b", 2), ("c", 3)]) + fs2 = frozenset([("c", 3), ("a", 1), ("b", 2)]) + + result1 = _canonicalize(fs1) + result2 = _canonicalize(fs2) + + assert result1 == result2 + 
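+    # Sanity check: the canonical form must itself be JSON-serializable,
+    # since serialize_cache_key() passes it straight to json.dumps().
+    def test_canonical_form_is_json_serializable(self):
+        import json
+        fs = frozenset([("a", 1), ("b", (2, 3)), ("c", frozenset([4, 5]))])
+        json.dumps(_canonicalize(fs), sort_keys=True)
+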
+ def test_nested_frozenset_ordering(self): + """Nested frozensets should also be deterministically ordered.""" + inner1 = frozenset([1, 2, 3]) + inner2 = frozenset([3, 2, 1]) + + fs1 = frozenset([("key", inner1)]) + fs2 = frozenset([("key", inner2)]) + + result1 = _canonicalize(fs1) + result2 = _canonicalize(fs2) + + assert result1 == result2 + + def test_dict_ordering(self): + """Dicts should be sorted by key.""" + d1 = {"z": 1, "a": 2, "m": 3} + d2 = {"a": 2, "m": 3, "z": 1} + + result1 = _canonicalize(d1) + result2 = _canonicalize(d2) + + assert result1 == result2 + + def test_tuple_preserved(self): + """Tuples should be marked and preserved.""" + t = (1, 2, 3) + result = _canonicalize(t) + + assert result[0] == "__tuple__" + assert result[1] == [1, 2, 3] + + def test_list_preserved(self): + """Lists should be recursively canonicalized.""" + lst = [{"b": 2, "a": 1}, frozenset([3, 2, 1])] + result = _canonicalize(lst) + + # First element should be dict with sorted keys + assert result[0] == {"a": 1, "b": 2} + # Second element should be canonicalized frozenset + assert result[1][0] == "__frozenset__" + + def test_primitives_unchanged(self): + """Primitive types should pass through unchanged.""" + assert _canonicalize(42) == 42 + assert _canonicalize(3.14) == 3.14 + assert _canonicalize("hello") == "hello" + assert _canonicalize(True) is True + assert _canonicalize(None) is None + + def test_bytes_converted(self): + """Bytes should be converted to hex string.""" + b = b"\x00\xff" + result = _canonicalize(b) + + assert result[0] == "__bytes__" + assert result[1] == "00ff" + + def test_set_ordering(self): + """Sets should be sorted like frozensets.""" + s1 = {3, 1, 2} + s2 = {1, 2, 3} + + result1 = _canonicalize(s1) + result2 = _canonicalize(s2) + + assert result1 == result2 + assert result1[0] == "__set__" + + +class TestSerializeCacheKey: + """Test serialize_cache_key for deterministic hashing.""" + + def test_same_content_same_hash(self): + """Same content should produce same hash.""" + key1 = frozenset([("node_1", frozenset([("input", "value")]))]) + key2 = frozenset([("node_1", frozenset([("input", "value")]))]) + + hash1 = serialize_cache_key(key1) + hash2 = serialize_cache_key(key2) + + assert hash1 == hash2 + + def test_different_content_different_hash(self): + """Different content should produce different hash.""" + key1 = frozenset([("node_1", "value_a")]) + key2 = frozenset([("node_1", "value_b")]) + + hash1 = serialize_cache_key(key1) + hash2 = serialize_cache_key(key2) + + assert hash1 != hash2 + + def test_returns_bytes(self): + """Should return bytes (SHA256 digest).""" + key = frozenset([("test", 123)]) + result = serialize_cache_key(key) + + assert isinstance(result, bytes) + assert len(result) == 32 # SHA256 produces 32 bytes + + def test_complex_nested_structure(self): + """Complex nested structures should hash deterministically.""" + # Note: frozensets can only contain hashable types, so we use + # nested frozensets of tuples to represent dict-like structures + key = frozenset([ + ("node_1", frozenset([ + ("input_a", ("tuple", "value")), + ("input_b", frozenset([("nested", "dict")])), + ])), + ("node_2", frozenset([ + ("param", 42), + ])), + ]) + + # Hash twice to verify determinism + hash1 = serialize_cache_key(key) + hash2 = serialize_cache_key(key) + + assert hash1 == hash2 + + def test_dict_in_cache_key(self): + """Dicts passed directly to serialize_cache_key should work.""" + # This tests the _canonicalize function's ability to handle dicts + key = {"node_1": 
{"input": "value"}, "node_2": 42} + + hash1 = serialize_cache_key(key) + hash2 = serialize_cache_key(key) + + assert hash1 == hash2 + assert isinstance(hash1, bytes) + assert len(hash1) == 32 + + +class TestContainsNan: + """Test contains_nan utility function.""" + + def test_nan_float_detected(self): + """NaN floats should be detected.""" + assert contains_nan(float('nan')) is True + + def test_regular_float_not_nan(self): + """Regular floats should not be detected as NaN.""" + assert contains_nan(3.14) is False + assert contains_nan(0.0) is False + assert contains_nan(-1.5) is False + + def test_infinity_not_nan(self): + """Infinity is not NaN.""" + assert contains_nan(float('inf')) is False + assert contains_nan(float('-inf')) is False + + def test_nan_in_list(self): + """NaN in list should be detected.""" + assert contains_nan([1, 2, float('nan'), 4]) is True + assert contains_nan([1, 2, 3, 4]) is False + + def test_nan_in_tuple(self): + """NaN in tuple should be detected.""" + assert contains_nan((1, float('nan'))) is True + assert contains_nan((1, 2, 3)) is False + + def test_nan_in_frozenset(self): + """NaN in frozenset should be detected.""" + assert contains_nan(frozenset([1, float('nan')])) is True + assert contains_nan(frozenset([1, 2, 3])) is False + + def test_nan_in_dict_value(self): + """NaN in dict value should be detected.""" + assert contains_nan({"key": float('nan')}) is True + assert contains_nan({"key": 42}) is False + + def test_nan_in_nested_structure(self): + """NaN in deeply nested structure should be detected.""" + nested = {"level1": [{"level2": (1, 2, float('nan'))}]} + assert contains_nan(nested) is True + + def test_non_numeric_types(self): + """Non-numeric types should not be NaN.""" + assert contains_nan("string") is False + assert contains_nan(None) is False + assert contains_nan(True) is False + + +class TestEstimateValueSize: + """Test estimate_value_size utility function.""" + + def test_empty_outputs(self): + """Empty outputs should have zero size.""" + value = CacheValue(outputs=[]) + assert estimate_value_size(value) == 0 + + @pytest.mark.skipif( + not _torch_available(), + reason="PyTorch not available" + ) + def test_tensor_size_estimation(self): + """Tensor size should be estimated correctly.""" + import torch + + # 1000 float32 elements = 4000 bytes + tensor = torch.zeros(1000, dtype=torch.float32) + value = CacheValue(outputs=[[tensor]]) + + size = estimate_value_size(value) + assert size == 4000 + + @pytest.mark.skipif( + not _torch_available(), + reason="PyTorch not available" + ) + def test_nested_tensor_in_dict(self): + """Tensors nested in dicts should be counted.""" + import torch + + tensor = torch.zeros(100, dtype=torch.float32) # 400 bytes + value = CacheValue(outputs=[[{"samples": tensor}]]) + + size = estimate_value_size(value) + assert size == 400 + + +class TestProviderRegistry: + """Test cache provider registration and retrieval.""" + + def setup_method(self): + """Clear providers before each test.""" + clear_cache_providers() + + def teardown_method(self): + """Clear providers after each test.""" + clear_cache_providers() + + def test_register_provider(self): + """Provider should be registered successfully.""" + provider = MockCacheProvider() + register_cache_provider(provider) + + assert has_cache_providers() is True + providers = get_cache_providers() + assert len(providers) == 1 + assert providers[0] is provider + + def test_unregister_provider(self): + """Provider should be unregistered successfully.""" + provider = 
MockCacheProvider() + register_cache_provider(provider) + unregister_cache_provider(provider) + + assert has_cache_providers() is False + + def test_multiple_providers(self): + """Multiple providers can be registered.""" + provider1 = MockCacheProvider() + provider2 = MockCacheProvider() + + register_cache_provider(provider1) + register_cache_provider(provider2) + + providers = get_cache_providers() + assert len(providers) == 2 + + def test_duplicate_registration_ignored(self): + """Registering same provider twice should be ignored.""" + provider = MockCacheProvider() + + register_cache_provider(provider) + register_cache_provider(provider) # Should be ignored + + providers = get_cache_providers() + assert len(providers) == 1 + + def test_clear_providers(self): + """clear_cache_providers should remove all providers.""" + provider1 = MockCacheProvider() + provider2 = MockCacheProvider() + + register_cache_provider(provider1) + register_cache_provider(provider2) + clear_cache_providers() + + assert has_cache_providers() is False + assert len(get_cache_providers()) == 0 + + +class TestCacheContext: + """Test CacheContext dataclass.""" + + def test_context_creation(self): + """CacheContext should be created with all fields.""" + context = CacheContext( + prompt_id="prompt-123", + node_id="node-456", + class_type="KSampler", + cache_key=frozenset([("test", "value")]), + cache_key_bytes=b"hash_bytes", + ) + + assert context.prompt_id == "prompt-123" + assert context.node_id == "node-456" + assert context.class_type == "KSampler" + assert context.cache_key == frozenset([("test", "value")]) + assert context.cache_key_bytes == b"hash_bytes" + + +class TestCacheValue: + """Test CacheValue dataclass.""" + + def test_value_creation(self): + """CacheValue should be created with outputs.""" + outputs = [[{"samples": "tensor_data"}]] + value = CacheValue(outputs=outputs) + + assert value.outputs == outputs + + +class MockCacheProvider(CacheProvider): + """Mock cache provider for testing.""" + + def __init__(self): + self.lookups = [] + self.stores = [] + + def on_lookup(self, context: CacheContext) -> Optional[CacheValue]: + self.lookups.append(context) + return None + + def on_store(self, context: CacheContext, value: CacheValue) -> None: + self.stores.append((context, value))
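
Example usage: a minimal Redis-backed provider built on the API above. This is a
sketch, not part of the change itself: the redis client, key prefix, TTL, and
size cap are illustrative assumptions, while CacheProvider, CacheContext,
CacheValue, estimate_value_size, and register_provider are the API added in
comfy_execution/cache_provider.py and re-exported via comfy_api.latest.Caching:

    import logging
    import pickle

    import redis  # assumption: redis-py is installed; any shared KV store works the same way

    from comfy_api.latest import Caching

    logger = logging.getLogger(__name__)

    KEY_PREFIX = b"comfy:outputs:"        # illustrative namespace
    TTL_SECONDS = 24 * 60 * 60            # illustrative expiry: one day
    MAX_VALUE_BYTES = 512 * 1024 * 1024   # illustrative size cap for stores


    class RedisCacheProvider(Caching.CacheProvider):
        """Sketch only. The executor already wraps every hook in try/except,
        but swallowing errors here keeps lookups local-first and non-fatal."""

        def __init__(self, url="redis://localhost:6379/0"):
            self._client = redis.Redis.from_url(url)

        def _key(self, context):
            # cache_key_bytes is the deterministic SHA256 digest computed by
            # serialize_cache_key(), so it is safe to share across pods.
            return KEY_PREFIX + context.cache_key_bytes

        def should_cache(self, context, value=None):
            # Size-based filter on store (value is None during lookups).
            if value is not None and Caching.estimate_value_size(value) > MAX_VALUE_BYTES:
                return False
            return True

        def on_lookup(self, context):
            try:
                blob = self._client.get(self._key(context))
                if blob is None:
                    return None
                # Trust boundary: only our own pods write these keys.
                outputs, ui = pickle.loads(blob)
                return Caching.CacheValue(outputs=outputs, ui=ui)
            except Exception as e:
                logger.warning("external lookup failed for node %s: %s", context.node_id, e)
                return None  # fall back to normal execution

        def on_store(self, context, value):
            # Note: pickled CUDA tensors unpickle onto the same device; a real
            # implementation should move tensors to CPU before serializing.
            try:
                blob = pickle.dumps((value.outputs, value.ui))
                self._client.set(self._key(context), blob, ex=TTL_SECONDS)
            except Exception as e:
                logger.warning("external store failed for node %s: %s", context.node_id, e)


    Caching.register_provider(RedisCacheProvider())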