""" External Cache Provider API for distributed caching. This module provides a public API for external cache providers, enabling distributed caching across multiple ComfyUI instances (e.g., Kubernetes pods). Example usage: from comfy_execution.cache_provider import ( CacheProvider, CacheContext, CacheValue, register_cache_provider ) class MyRedisProvider(CacheProvider): def on_lookup(self, context: CacheContext) -> Optional[CacheValue]: # Check Redis/GCS for cached result ... def on_store(self, context: CacheContext, value: CacheValue) -> None: # Store to Redis/GCS (can be async internally) ... register_cache_provider(MyRedisProvider()) """ from abc import ABC, abstractmethod from typing import Any, Optional, Tuple, List from dataclasses import dataclass import logging import threading import hashlib import pickle import math logger = logging.getLogger(__name__) # ============================================================ # Data Classes # ============================================================ @dataclass class CacheContext: """Context passed to provider methods.""" prompt_id: str # Current prompt execution ID node_id: str # Node being cached class_type: str # Node class type (e.g., "KSampler") cache_key: Any # Raw cache key (frozenset structure) cache_key_bytes: bytes # SHA256 hash for external storage key @dataclass class CacheValue: """ Value stored/retrieved from external cache. Note: UI data is intentionally excluded - it contains pod-local file paths that aren't portable across instances. """ outputs: list # The tensor/value outputs # ============================================================ # Provider Interface # ============================================================ class CacheProvider(ABC): """ Abstract base class for external cache providers. Thread Safety: Providers may be called from multiple threads. Implementations must be thread-safe. Error Handling: All methods are wrapped in try/except by the caller. Exceptions are logged but never propagate to break execution. Performance Guidelines: - on_lookup: Should complete in <500ms (including network) - on_store: Can be async internally (fire-and-forget) - should_cache: Should be fast (<1ms), called frequently """ @abstractmethod def on_lookup(self, context: CacheContext) -> Optional[CacheValue]: """ Check external storage for cached result. Called AFTER local cache miss (local-first for performance). Returns: CacheValue if found externally, None otherwise. Important: - Return None on any error (don't raise) - Validate data integrity before returning """ pass @abstractmethod def on_store(self, context: CacheContext, value: CacheValue) -> None: """ Store value to external cache. Called AFTER value is stored in local cache. Important: - Can be fire-and-forget (async internally) - Should never block execution - Handle serialization failures gracefully """ pass def should_cache(self, context: CacheContext, value: Optional[CacheValue] = None) -> bool: """ Filter which nodes should be externally cached. Called before on_lookup (value=None) and on_store (value provided). Return False to skip external caching for this node. Common filters: - By class_type: Only expensive nodes (KSampler, VAEDecode) - By size: Skip small values (< 1MB) Default: Returns True (cache everything). """ return True def on_prompt_start(self, prompt_id: str) -> None: """Called when prompt execution begins. Optional.""" pass def on_prompt_end(self, prompt_id: str) -> None: """Called when prompt execution ends. 
Optional.""" pass # ============================================================ # Provider Registry # ============================================================ _providers: List[CacheProvider] = [] _providers_lock = threading.Lock() _providers_snapshot: Optional[Tuple[CacheProvider, ...]] = None def register_cache_provider(provider: CacheProvider) -> None: """ Register an external cache provider. Providers are called in registration order. First provider to return a result from on_lookup wins. """ global _providers_snapshot with _providers_lock: if provider in _providers: logger.warning(f"Provider {provider.__class__.__name__} already registered") return _providers.append(provider) _providers_snapshot = None # Invalidate cache logger.info(f"Registered cache provider: {provider.__class__.__name__}") def unregister_cache_provider(provider: CacheProvider) -> None: """Remove a previously registered provider.""" global _providers_snapshot with _providers_lock: try: _providers.remove(provider) _providers_snapshot = None logger.info(f"Unregistered cache provider: {provider.__class__.__name__}") except ValueError: logger.warning(f"Provider {provider.__class__.__name__} was not registered") def get_cache_providers() -> Tuple[CacheProvider, ...]: """Get registered providers (cached for performance).""" global _providers_snapshot snapshot = _providers_snapshot if snapshot is not None: return snapshot with _providers_lock: if _providers_snapshot is not None: return _providers_snapshot _providers_snapshot = tuple(_providers) return _providers_snapshot def has_cache_providers() -> bool: """Fast check if any providers registered (no lock).""" return bool(_providers) def clear_cache_providers() -> None: """Remove all providers. Useful for testing.""" global _providers_snapshot with _providers_lock: _providers.clear() _providers_snapshot = None # ============================================================ # Utilities # ============================================================ def serialize_cache_key(cache_key: Any) -> bytes: """ Serialize cache key to bytes for external storage. Returns SHA256 hash suitable for Redis/database keys. """ try: serialized = pickle.dumps(cache_key, protocol=4) return hashlib.sha256(serialized).digest() except Exception as e: logger.warning(f"Failed to serialize cache key: {e}") return hashlib.sha256(str(id(cache_key)).encode()).digest() def contains_nan(obj: Any) -> bool: """ Check if cache key contains NaN (indicates uncacheable node). NaN != NaN in Python, so local cache never hits. But serialized NaN would match, causing incorrect external hits. Must skip these. """ if isinstance(obj, float): try: return math.isnan(obj) except (TypeError, ValueError): return False if hasattr(obj, 'value'): # Unhashable class val = getattr(obj, 'value', None) if isinstance(val, float): try: return math.isnan(val) except (TypeError, ValueError): return False if isinstance(obj, (frozenset, tuple, list, set)): return any(contains_nan(item) for item in obj) if isinstance(obj, dict): return any(contains_nan(k) or contains_nan(v) for k, v in obj.items()) return False def estimate_value_size(value: CacheValue) -> int: """Estimate serialized size in bytes. 

# ============================================================
# Utilities
# ============================================================

def serialize_cache_key(cache_key: Any) -> bytes:
    """
    Serialize a cache key to bytes for external storage.

    Returns a SHA256 hash suitable for Redis/database keys.
    """
    try:
        serialized = pickle.dumps(cache_key, protocol=4)
        return hashlib.sha256(serialized).digest()
    except Exception as e:
        logger.warning(f"Failed to serialize cache key: {e}")
        # Fallback: id() is process-local, so this key will never match across
        # instances; it only keeps the call from failing.
        return hashlib.sha256(str(id(cache_key)).encode()).digest()


def contains_nan(obj: Any) -> bool:
    """
    Check if a cache key contains NaN (indicates an uncacheable node).

    NaN != NaN in Python, so the local cache never hits on such keys.
    A serialized NaN, however, compares equal byte-for-byte, which would
    cause incorrect external hits. These keys must be skipped.
    """
    if isinstance(obj, float):
        try:
            return math.isnan(obj)
        except (TypeError, ValueError):
            return False
    if hasattr(obj, 'value'):  # Unhashable wrapper class exposing .value
        val = getattr(obj, 'value', None)
        if isinstance(val, float):
            try:
                return math.isnan(val)
            except (TypeError, ValueError):
                return False
    if isinstance(obj, (frozenset, tuple, list, set)):
        return any(contains_nan(item) for item in obj)
    if isinstance(obj, dict):
        return any(contains_nan(k) or contains_nan(v) for k, v in obj.items())
    return False


def estimate_value_size(value: CacheValue) -> int:
    """Estimate serialized size in bytes. Useful for size-based filtering."""
    try:
        import torch
    except ImportError:
        return 0

    total = 0

    def estimate(obj):
        nonlocal total
        if isinstance(obj, torch.Tensor):
            total += obj.numel() * obj.element_size()
        elif isinstance(obj, dict):
            for v in obj.values():
                estimate(v)
        elif isinstance(obj, (list, tuple)):
            for item in obj:
                estimate(item)

    for output in value.outputs:
        estimate(output)
    return total
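
if __name__ == "__main__":
    # Tiny smoke test / usage demo (illustrative only): exercise the
    # utilities, register the example provider, store a value, and look it
    # up again. The cache key and outputs below are toy values, not real
    # node data.
    logging.basicConfig(level=logging.INFO)

    key = frozenset({("seed", 42), ("steps", 20)})
    assert not contains_nan(key)
    assert contains_nan(frozenset({("cfg", float("nan"))}))

    ctx = CacheContext(
        prompt_id="demo-prompt",
        node_id="3",
        class_type="KSampler",
        cache_key=key,
        cache_key_bytes=serialize_cache_key(key),
    )
    value = CacheValue(outputs=[[1, 2, 3]])

    register_cache_provider(ExampleInMemoryProvider())
    _example_store(ctx, value)
    print("lookup hit:", _example_lookup(ctx) is not None)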