diff --git a/comfy_api/latest/__init__.py b/comfy_api/latest/__init__.py index 4e3035f43..7c91d44f8 100644 --- a/comfy_api/latest/__init__.py +++ b/comfy_api/latest/__init__.py @@ -118,25 +118,6 @@ class Types: class Caching: - """ - External cache provider API. - - Enables sharing cached results across multiple ComfyUI instances. - - Example usage: - from comfy_api.latest import Caching - - class MyCacheProvider(Caching.CacheProvider): - async def on_lookup(self, context): - # Check external storage for cached result - ... - - async def on_store(self, context, value): - # Store result to external storage - ... - - Caching.register_provider(MyCacheProvider()) - """ # Import from comfy_execution.cache_provider (source of truth) from comfy_execution.cache_provider import ( CacheProvider, diff --git a/comfy_execution/cache_provider.py b/comfy_execution/cache_provider.py index ed1f907ae..fe27e6d85 100644 --- a/comfy_execution/cache_provider.py +++ b/comfy_execution/cache_provider.py @@ -1,29 +1,3 @@ -""" -External Cache Provider API. - -This module provides a public API for external cache providers, enabling -shared caching across multiple ComfyUI instances. - -Public API is also available via: - from comfy_api.latest import Caching - -Example usage: - from comfy_execution.cache_provider import ( - CacheProvider, CacheContext, CacheValue, register_cache_provider - ) - - class MyCacheProvider(CacheProvider): - async def on_lookup(self, context: CacheContext) -> Optional[CacheValue]: - # Check external storage for cached result - ... - - async def on_store(self, context: CacheContext, value: CacheValue) -> None: - # Store result to external storage - ... - - register_cache_provider(MyCacheProvider()) -""" - from abc import ABC, abstractmethod from typing import Any, Optional, Tuple, List from dataclasses import dataclass @@ -37,134 +11,64 @@ import threading _logger = logging.getLogger(__name__) -# ============================================================ -# Data Classes -# ============================================================ - @dataclass class CacheContext: - """Context passed to provider methods.""" - prompt_id: str # Current prompt execution ID - node_id: str # Node being cached - class_type: str # Node class type (e.g., "KSampler") - cache_key_hash: str # SHA256 hex digest for external storage key + prompt_id: str + node_id: str + class_type: str + cache_key_hash: str # SHA256 hex digest @dataclass class CacheValue: - """ - Value stored/retrieved from external cache. + outputs: list + ui: dict = None - The ui field is optional - implementations may choose to skip it - (e.g., if it contains non-portable data like local file paths). - """ - outputs: list # The tensor/value outputs - ui: dict = None # Optional UI data (may be skipped by implementations) - - -# ============================================================ -# Provider Interface -# ============================================================ class CacheProvider(ABC): - """ - Abstract base class for external cache providers. - - Async Safety: - Provider methods are called from async context. Implementations - can use async I/O directly. - - Error Handling: - All methods are wrapped in try/except by the caller. Exceptions - are logged but never propagate to break execution. - - Performance Guidelines: - - on_lookup: Should complete quickly (including any network I/O) - - on_store: Dispatched via asyncio.create_task (non-blocking) - - should_cache: Should be fast, called frequently + """Abstract base class for external cache providers. + Exceptions from provider methods are caught by the caller and never break execution. """ @abstractmethod async def on_lookup(self, context: CacheContext) -> Optional[CacheValue]: - """ - Check external storage for cached result. - - Called AFTER local cache miss (local-first for performance). - - Returns: - CacheValue if found externally, None otherwise. - - Important: - - Return None on any error (don't raise) - - Validate data integrity before returning - """ + """Called on local cache miss. Return CacheValue if found, None otherwise.""" pass @abstractmethod async def on_store(self, context: CacheContext, value: CacheValue) -> None: - """ - Store value to external cache. - - Called AFTER value is stored in local cache. - Dispatched via asyncio.create_task (non-blocking). - - Important: - - Should not block execution - - Handle serialization failures gracefully - """ + """Called after local store. Dispatched via asyncio.create_task.""" pass def should_cache(self, context: CacheContext, value: Optional[CacheValue] = None) -> bool: - """ - Filter which nodes should be externally cached. - - Called before on_lookup (value=None) and on_store (value provided). - Return False to skip external caching for this node. - - Implementations can filter based on context.class_type, value size, - or any custom logic. Use _estimate_value_size() to get value size. - - Default: Returns True (cache everything). - """ + """Return False to skip external caching for this node. Default: True.""" return True def on_prompt_start(self, prompt_id: str) -> None: - """Called when prompt execution begins. Optional.""" pass def on_prompt_end(self, prompt_id: str) -> None: - """Called when prompt execution ends. Optional.""" pass -# ============================================================ -# Provider Registry -# ============================================================ - _providers: List[CacheProvider] = [] _providers_lock = threading.Lock() _providers_snapshot: Optional[Tuple[CacheProvider, ...]] = None def register_cache_provider(provider: CacheProvider) -> None: - """ - Register an external cache provider. - - Providers are called in registration order. First provider to return - a result from on_lookup wins. - """ + """Register an external cache provider. Providers are called in registration order.""" global _providers_snapshot with _providers_lock: if provider in _providers: _logger.warning(f"Provider {provider.__class__.__name__} already registered") return _providers.append(provider) - _providers_snapshot = None # Invalidate cache + _providers_snapshot = None _logger.info(f"Registered cache provider: {provider.__class__.__name__}") def unregister_cache_provider(provider: CacheProvider) -> None: - """Remove a previously registered provider.""" global _providers_snapshot with _providers_lock: try: @@ -176,7 +80,6 @@ def unregister_cache_provider(provider: CacheProvider) -> None: def _get_cache_providers() -> Tuple[CacheProvider, ...]: - """Get registered providers (cached for performance). Internal.""" global _providers_snapshot snapshot = _providers_snapshot if snapshot is not None: @@ -189,30 +92,19 @@ def _get_cache_providers() -> Tuple[CacheProvider, ...]: def _has_cache_providers() -> bool: - """Fast check if any providers registered (no lock). Internal.""" return bool(_providers) def _clear_cache_providers() -> None: - """Remove all providers. Useful for testing. Internal.""" global _providers_snapshot with _providers_lock: _providers.clear() _providers_snapshot = None -# ============================================================ -# Internal Utilities -# ============================================================ - def _canonicalize(obj: Any) -> Any: - """ - Convert an object to a canonical, JSON-serializable form. - - This ensures deterministic ordering regardless of Python's hash randomization. - Frozensets in particular have non-deterministic iteration order between - Python sessions, so consistent serialization requires explicit sorting. - """ + # Convert to canonical JSON-serializable form with deterministic ordering. + # Frozensets have non-deterministic iteration order between Python sessions. if isinstance(obj, frozenset): # Sort frozenset items for deterministic ordering return ("__frozenset__", sorted( @@ -243,17 +135,8 @@ def _canonicalize(obj: Any) -> Any: def _serialize_cache_key(cache_key: Any) -> Optional[str]: - """ - Serialize cache key to a hex digest string for external storage. - - Returns SHA256 hex string suitable as an external storage key, - or None if serialization fails entirely (fail-closed). - - Note: Uses canonicalize + JSON serialization instead of pickle because - pickle is NOT deterministic across Python sessions due to hash randomization - affecting frozenset iteration order. Consistent hashing is required so that - different instances compute the same key for identical inputs. - """ + # Returns deterministic SHA256 hex digest, or None on failure. + # Uses JSON (not pickle) because pickle is non-deterministic across sessions. try: canonical = _canonicalize(cache_key) json_str = json.dumps(canonical, sort_keys=True, separators=(',', ':')) @@ -270,12 +153,8 @@ def _serialize_cache_key(cache_key: Any) -> Optional[str]: def _contains_nan(obj: Any) -> bool: - """ - Check if cache key contains NaN (indicates uncacheable node). - - NaN != NaN in Python, so local cache never hits. But serialized - NaN would match, causing incorrect external hits. Must skip these. - """ + # NaN != NaN so local cache never hits, but serialized NaN would match. + # Skip external caching for keys containing NaN. if isinstance(obj, float): try: return math.isnan(obj) @@ -296,7 +175,6 @@ def _contains_nan(obj: Any) -> bool: def _estimate_value_size(value: CacheValue) -> int: - """Estimate serialized size in bytes. Useful for size-based filtering. Internal.""" try: import torch except ImportError: diff --git a/comfy_execution/caching.py b/comfy_execution/caching.py index d7c1f72b3..3b987846b 100644 --- a/comfy_execution/caching.py +++ b/comfy_execution/caching.py @@ -156,7 +156,6 @@ class BasicCache: self.cache = {} self.subcaches = {} - # External cache provider support self._is_subcache = False self._current_prompt_id = '' @@ -202,7 +201,6 @@ class BasicCache: pass def get_local(self, node_id): - """Sync local-only cache lookup (no external providers).""" if not self.initialized: return None cache_key = self.cache_key_set.get_data_key(node_id) @@ -211,7 +209,6 @@ class BasicCache: return None def set_local(self, node_id, value): - """Sync local-only cache store (no external providers).""" assert self.initialized cache_key = self.cache_key_set.get_data_key(node_id) self.cache[cache_key] = value @@ -221,7 +218,6 @@ class BasicCache: cache_key = self.cache_key_set.get_data_key(node_id) self.cache[cache_key] = value - # Notify external providers await self._notify_providers_store(node_id, cache_key, value) async def _get_immediate(self, node_id): @@ -229,26 +225,22 @@ class BasicCache: return None cache_key = self.cache_key_set.get_data_key(node_id) - # Check local cache first (fast path) if cache_key in self.cache: return self.cache[cache_key] - # Check external providers on local miss external_result = await self._check_providers_lookup(node_id, cache_key) if external_result is not None: - self.cache[cache_key] = external_result # Warm local cache + self.cache[cache_key] = external_result return external_result return None async def _notify_providers_store(self, node_id, cache_key, value): - """Notify external providers of cache store (non-blocking).""" from comfy_execution.cache_provider import ( _has_cache_providers, _get_cache_providers, CacheValue, _contains_nan, _logger ) - # Fast exit conditions if self._is_subcache: return if not _has_cache_providers(): @@ -272,7 +264,6 @@ class BasicCache: @staticmethod async def _safe_provider_store(provider, context, cache_value): - """Wrapper for async provider.on_store with error handling.""" from comfy_execution.cache_provider import _logger try: await provider.on_store(context, cache_value) @@ -280,7 +271,6 @@ class BasicCache: _logger.warning(f"Cache provider {provider.__class__.__name__} async store error: {e}") async def _check_providers_lookup(self, node_id, cache_key): - """Check external providers for cached result.""" from comfy_execution.cache_provider import ( _has_cache_providers, _get_cache_providers, CacheValue, _contains_nan, _logger @@ -309,7 +299,6 @@ class BasicCache: if not isinstance(result.outputs, (list, tuple)): _logger.warning(f"Provider {provider.__class__.__name__} returned invalid outputs") continue - # Import CacheEntry here to avoid circular import at module level from execution import CacheEntry return CacheEntry(ui=result.ui or {}, outputs=list(result.outputs)) except Exception as e: @@ -318,11 +307,9 @@ class BasicCache: return None def _is_external_cacheable_value(self, value): - """Check if value is a CacheEntry suitable for external caching (not objects cache).""" return hasattr(value, 'outputs') and hasattr(value, 'ui') def _get_class_type(self, node_id): - """Get class_type for a node.""" if not self.initialized or not self.dynprompt: return '' try: @@ -331,7 +318,6 @@ class BasicCache: return '' def _build_context(self, node_id, cache_key): - """Build CacheContext with hash. Returns None if hashing fails.""" from comfy_execution.cache_provider import CacheContext, _serialize_cache_key, _logger try: cache_key_hash = _serialize_cache_key(cache_key) @@ -352,8 +338,8 @@ class BasicCache: subcache = self.subcaches.get(subcache_key, None) if subcache is None: subcache = BasicCache(self.key_class) - subcache._is_subcache = True # Mark as subcache - excludes from external caching - subcache._current_prompt_id = self._current_prompt_id # Propagate prompt ID + subcache._is_subcache = True + subcache._current_prompt_id = self._current_prompt_id self.subcaches[subcache_key] = subcache await subcache.set_prompt(self.dynprompt, children_ids, self.is_changed_cache) return subcache diff --git a/execution.py b/execution.py index e5d1a4786..914b936bf 100644 --- a/execution.py +++ b/execution.py @@ -685,7 +685,6 @@ class PromptExecutor: self.add_message("execution_error", mes, broadcast=False) def _notify_prompt_lifecycle(self, event: str, prompt_id: str): - """Notify external cache providers of prompt lifecycle events.""" from comfy_execution.cache_provider import _has_cache_providers, _get_cache_providers, _logger if not _has_cache_providers(): @@ -716,11 +715,9 @@ class PromptExecutor: self.status_messages = [] self.add_message("execution_start", { "prompt_id": prompt_id}, broadcast=False) - # Set prompt ID on caches for external provider integration for cache in self.caches.all: cache._current_prompt_id = prompt_id - # Notify external cache providers of prompt start self._notify_prompt_lifecycle("start", prompt_id) try: @@ -785,7 +782,6 @@ class PromptExecutor: if comfy.model_management.DISABLE_SMART_MEMORY: comfy.model_management.unload_all_models() finally: - # Notify external cache providers of prompt end self._notify_prompt_lifecycle("end", prompt_id)