ComfyUI/comfy_execution/cache_provider.py
Deep Mehta 6540aa0400 feat: Add CacheProvider API for external distributed caching
Introduces a public API for external cache providers, enabling distributed
caching across multiple ComfyUI instances (e.g., Kubernetes pods).

New files:
- comfy_execution/cache_provider.py: CacheProvider ABC, CacheContext/CacheValue
  dataclasses, thread-safe provider registry, serialization utilities

Modified files:
- comfy_execution/caching.py: Add provider hooks to BasicCache (_notify_providers_store,
  _check_providers_lookup), subcache exclusion, prompt ID propagation
- execution.py: Add prompt lifecycle hooks (on_prompt_start/on_prompt_end) to
  PromptExecutor, set _current_prompt_id on caches

Key features:
- Local-first caching (check local before external for performance)
- NaN detection to prevent incorrect external cache hits
- Subcache exclusion (ephemeral subgraph results not cached externally)
- Thread-safe provider snapshot caching
- Graceful error handling (provider errors logged, never break execution)

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
2026-01-19 16:43:13 +05:30

268 lines
8.2 KiB
Python

"""
External Cache Provider API for distributed caching.
This module provides a public API for external cache providers, enabling
distributed caching across multiple ComfyUI instances (e.g., Kubernetes pods).
Example usage:
from comfy_execution.cache_provider import (
CacheProvider, CacheContext, CacheValue, register_cache_provider
)
class MyRedisProvider(CacheProvider):
def on_lookup(self, context: CacheContext) -> Optional[CacheValue]:
# Check Redis/GCS for cached result
...
def on_store(self, context: CacheContext, value: CacheValue) -> None:
# Store to Redis/GCS (can be async internally)
...
register_cache_provider(MyRedisProvider())
"""
from abc import ABC, abstractmethod
from typing import Any, Optional, Tuple, List
from dataclasses import dataclass
import logging
import threading
import hashlib
import pickle
import math
# Module-level logger; registry changes and cache-key serialization failures
# are reported through it.
logger = logging.getLogger(__name__)
# ============================================================
# Data Classes
# ============================================================
@dataclass
class CacheContext:
    """
    Describes the cache entry a provider hook is being asked about.

    Passed to CacheProvider.on_lookup, on_store and should_cache.
    """
    prompt_id: str  # ID of the prompt currently being executed
    node_id: str  # ID of the node whose output is being cached
    class_type: str  # Node class type (e.g., "KSampler")
    cache_key: Any  # Raw cache key (frozenset structure)
    cache_key_bytes: bytes  # SHA256 digest used as the external storage key (presumably from serialize_cache_key — confirm at call site)
@dataclass
class CacheValue:
    """
    Payload stored to, or retrieved from, an external cache.

    Only node outputs are carried. UI data is intentionally excluded: it
    contains file paths local to one instance (pod) that are not portable
    across instances.
    """
    outputs: list  # The node's tensor/value outputs
# ============================================================
# Provider Interface
# ============================================================
class CacheProvider(ABC):
    """
    Base class that every external cache backend derives from.

    Subclasses must implement ``on_lookup`` and ``on_store``. The remaining
    hooks already have permissive / no-op defaults and may be overridden.

    Thread safety:
        Hooks may be invoked concurrently from several threads, so
        implementations have to be thread-safe.

    Error handling:
        Callers wrap every hook in try/except. Exceptions are logged and are
        never allowed to interrupt execution.

    Performance guidelines:
        - on_lookup: target < 500 ms, network round-trip included.
        - on_store: may hand work off internally (fire-and-forget).
        - should_cache: invoked often; keep it < 1 ms.
    """

    @abstractmethod
    def on_lookup(self, context: CacheContext) -> Optional[CacheValue]:
        """
        Look for a stored result for ``context`` in external storage.

        Only consulted after the local cache has missed (local-first).

        Returns:
            The externally cached CacheValue, or None when there is none.

        Implementations should:
            - return None rather than raise on any failure;
            - verify the integrity of fetched data before returning it.
        """
        ...

    @abstractmethod
    def on_store(self, context: CacheContext, value: CacheValue) -> None:
        """
        Persist ``value`` under ``context`` in external storage.

        Invoked once the value is already in the local cache.

        Implementations should:
            - never block execution (internal async / fire-and-forget is fine);
            - cope gracefully with serialization failures.
        """
        ...

    def should_cache(self, context: CacheContext, value: Optional[CacheValue] = None) -> bool:
        """
        Decide whether this node takes part in external caching.

        Consulted before on_lookup (``value`` is None) and before on_store
        (``value`` is the result about to be stored). Return False to skip
        external caching for the node.

        Typical filters:
            - by ``context.class_type``: only expensive nodes (KSampler, VAEDecode);
            - by size: skip small values (< 1 MB).

        The base implementation accepts everything.
        """
        return True

    def on_prompt_start(self, prompt_id: str) -> None:
        """Optional hook fired when execution of a prompt begins; no-op by default."""
        ...

    def on_prompt_end(self, prompt_id: str) -> None:
        """Optional hook fired when execution of a prompt ends; no-op by default."""
        ...
# ============================================================
# Provider Registry
# ============================================================
# Registered providers, in registration order. Every mutation in this module
# happens while holding _providers_lock.
_providers: List[CacheProvider] = []
# Guards mutation of _providers and the (re)building of _providers_snapshot.
_providers_lock = threading.Lock()
# Cached immutable copy of _providers handed out by get_cache_providers();
# reset to None whenever the registry changes so it is rebuilt lazily.
_providers_snapshot: Optional[Tuple[CacheProvider, ...]] = None
def register_cache_provider(provider: CacheProvider) -> None:
    """
    Add ``provider`` to the external cache provider registry.

    Providers are consulted in the order they were registered; the first one
    whose on_lookup yields a result wins. Registering a provider that is
    already present logs a warning and leaves the registry untouched.
    """
    global _providers_snapshot
    with _providers_lock:
        if provider not in _providers:
            _providers.append(provider)
            # The cached tuple is now stale; get_cache_providers() rebuilds it.
            _providers_snapshot = None
            logger.info(f"Registered cache provider: {provider.__class__.__name__}")
        else:
            logger.warning(f"Provider {provider.__class__.__name__} already registered")
def unregister_cache_provider(provider: CacheProvider) -> None:
    """
    Drop ``provider`` from the registry.

    Logs a warning (and changes nothing) if the provider was never registered.
    """
    global _providers_snapshot
    with _providers_lock:
        try:
            _providers.remove(provider)
        except ValueError:
            logger.warning(f"Provider {provider.__class__.__name__} was not registered")
        else:
            # Registry changed: force the snapshot to be rebuilt on next read.
            _providers_snapshot = None
            logger.info(f"Unregistered cache provider: {provider.__class__.__name__}")
def get_cache_providers() -> Tuple[CacheProvider, ...]:
    """
    Return the registered providers as an immutable tuple.

    The tuple is cached in ``_providers_snapshot``. The common path reads it
    without taking the lock; only when it has been invalidated is the lock
    acquired, the cache re-checked, and a fresh tuple built.
    """
    global _providers_snapshot
    current = _providers_snapshot
    if current is None:
        with _providers_lock:
            # Another thread may have rebuilt it while we waited for the lock.
            if _providers_snapshot is None:
                _providers_snapshot = tuple(_providers)
            current = _providers_snapshot
    return current
def has_cache_providers() -> bool:
    """Cheap check for whether any provider is registered; does not take the lock."""
    return len(_providers) > 0
def clear_cache_providers() -> None:
    """Unregister every provider at once (handy in tests)."""
    global _providers_snapshot
    with _providers_lock:
        del _providers[:]
        _providers_snapshot = None
# ============================================================
# Utilities
# ============================================================
def _frame(chunk: bytes) -> bytes:
    """Prefix ``chunk`` with its 8-byte big-endian length so concatenations stay unambiguous."""
    return len(chunk).to_bytes(8, "big") + chunk


def _canonical_bytes(obj: Any) -> bytes:
    """
    Encode ``obj`` into bytes that depend only on its value, not on process state.

    A plain ``pickle.dumps`` is not stable across processes for cache keys:
      * set/frozenset and dict pickles follow iteration order. Set order
        depends on element hashes, which are randomized per process for
        str/bytes, and on insertion history.
      * pickle memoizes by object identity, so ``(s, s)`` and
        ``(s, equal_copy_of_s)`` pickle to different bytes.

    The exact built-in container types are therefore encoded structurally:
    a one-byte type tag, the element count, then the length-prefixed element
    encodings. Set/frozenset elements and dict items are sorted by their
    encodings, so the result does not depend on iteration order. Any other
    object, including subclasses of these containers, is treated as a leaf
    and pickled on its own with protocol 4.
    """
    kind = type(obj)
    if kind is tuple or kind is list:
        tag = b"T" if kind is tuple else b"L"
        parts = [_canonical_bytes(item) for item in obj]
    elif kind is frozenset or kind is set:
        tag = b"F" if kind is frozenset else b"S"
        parts = sorted(_canonical_bytes(item) for item in obj)
    elif kind is dict:
        tag = b"D"
        parts = sorted(
            _frame(_canonical_bytes(k)) + _frame(_canonical_bytes(v))
            for k, v in obj.items()
        )
    else:
        return b"P" + pickle.dumps(obj, protocol=4)
    return tag + len(parts).to_bytes(8, "big") + b"".join(_frame(p) for p in parts)


def serialize_cache_key(cache_key: Any) -> bytes:
    """
    Hash ``cache_key`` into a 32-byte SHA-256 digest suitable as a Redis/database key.

    The digest is computed over a canonical encoding of the key, so equal keys
    yield the same digest in every process. This is what lets separate
    ComfyUI instances share entries.

    If the key cannot be encoded (e.g. it contains an unpicklable object), a
    warning is logged and 32 random bytes are returned. Such a key therefore,
    with overwhelming probability, never matches any stored entry.
    """
    try:
        return hashlib.sha256(_canonical_bytes(cache_key)).digest()
    except Exception as e:
        logger.warning(f"Failed to serialize cache key: {e}")
        # Do not derive the fallback from id(): ids are reused after garbage
        # collection, so an unrelated later key could collide and produce a
        # false external cache hit.
        import os
        return os.urandom(32)
def contains_nan(obj: Any) -> bool:
    """
    Report whether ``obj`` (typically a nested cache key) holds a float NaN.

    Because NaN != NaN, a key containing NaN never matches in the local cache,
    which keeps uncacheable nodes from hitting. Once serialized, however, NaN
    would compare equal and cause incorrect external hits, so such keys must
    be excluded from external caching.

    Inspected, in order:
      * a bare ``float``;
      * an object exposing a ``value`` attribute that is a ``float``. Only
        that float decides the result; a non-float ``value`` falls through;
      * ``dict`` keys and values, and ``frozenset``/``tuple``/``list``/``set``
        elements, recursively.
    Anything else is considered NaN-free.
    """
    candidate = obj
    if not isinstance(candidate, float) and hasattr(obj, 'value'):
        # e.g. an "Unhashable" placeholder object wrapping its value
        candidate = getattr(obj, 'value', None)
    if isinstance(candidate, float):
        try:
            return math.isnan(candidate)
        except (TypeError, ValueError):
            return False

    if isinstance(obj, dict):
        return any(contains_nan(key) or contains_nan(val) for key, val in obj.items())
    if isinstance(obj, (frozenset, tuple, list, set)):
        return any(map(contains_nan, obj))
    return False
def estimate_value_size(value: CacheValue) -> int:
    """
    Roughly estimate the serialized size of ``value`` in bytes.

    Only ``torch.Tensor`` payloads are counted, as ``numel() * element_size()``.
    They are found by walking ``value.outputs`` through nested dicts (values
    only), lists and tuples; other objects contribute 0. Returns 0 if torch
    cannot be imported. Intended for size-based filtering, e.g. in
    ``CacheProvider.should_cache``.

    Each distinct tensor or container object is visited once. Pickle writes an
    object referenced several times only once, so a tensor repeated across
    output slots is not over-counted. The same tracking also stops the walk on
    self-referencing containers instead of recursing without bound.
    """
    try:
        import torch
    except ImportError:
        return 0

    # ids of objects already visited; safe because everything walked stays
    # alive (referenced from value.outputs) for the duration of this call.
    visited = set()

    def _walk(obj) -> int:
        if id(obj) in visited:
            return 0
        if isinstance(obj, torch.Tensor):
            visited.add(id(obj))
            return obj.numel() * obj.element_size()
        if isinstance(obj, dict):
            visited.add(id(obj))
            return sum(_walk(v) for v in obj.values())
        if isinstance(obj, (list, tuple)):
            visited.add(id(obj))
            return sum(_walk(item) for item in obj)
        return 0

    return sum(_walk(output) for output in value.outputs)