diff --git a/comfy_api/latest/__init__.py b/comfy_api/latest/__init__.py
index 1a1b6d162..5df1b4c2f 100644
--- a/comfy_api/latest/__init__.py
+++ b/comfy_api/latest/__init__.py
@@ -118,12 +118,12 @@ class Caching:
         from comfy_api.latest import Caching

         class MyRedisProvider(Caching.CacheProvider):
-            def on_lookup(self, context):
+            async def on_lookup(self, context):
                 # Check Redis for cached result
                 ...

-            def on_store(self, context, value):
-                # Store to Redis (can be async internally)
+            async def on_store(self, context, value):
+                # Store to Redis
                 ...

         Caching.register_provider(MyRedisProvider())
@@ -135,10 +135,6 @@ class Caching:
         CacheValue,
         register_cache_provider as register_provider,
         unregister_cache_provider as unregister_provider,
-        get_cache_providers as get_providers,
-        has_cache_providers as has_providers,
-        clear_cache_providers as clear_providers,
-        estimate_value_size,
     )
diff --git a/comfy_execution/cache_provider.py b/comfy_execution/cache_provider.py
index a00e8cb15..79a5b71e8 100644
--- a/comfy_execution/cache_provider.py
+++ b/comfy_execution/cache_provider.py
@@ -13,12 +13,12 @@ Example usage:
     )

     class MyRedisProvider(CacheProvider):
-        def on_lookup(self, context: CacheContext) -> Optional[CacheValue]:
+        async def on_lookup(self, context: CacheContext) -> Optional[CacheValue]:
            # Check Redis/GCS for cached result
            ...

-        def on_store(self, context: CacheContext, value: CacheValue) -> None:
-            # Store to Redis/GCS (can be async internally)
+        async def on_store(self, context: CacheContext, value: CacheValue) -> None:
+            # Store to Redis/GCS
            ...

     register_cache_provider(MyRedisProvider())
@@ -34,7 +34,7 @@ import math
 import pickle
 import threading

-logger = logging.getLogger(__name__)
+_logger = logging.getLogger(__name__)


 # ============================================================
@@ -47,8 +47,7 @@ class CacheContext:
     prompt_id: str          # Current prompt execution ID
     node_id: str            # Node being cached
     class_type: str         # Node class type (e.g., "KSampler")
-    cache_key: Any          # Raw cache key (frozenset structure)
-    cache_key_bytes: bytes  # SHA256 hash for external storage key
+    cache_key_hash: str     # SHA256 hex digest for external storage key


 @dataclass
@@ -71,9 +70,9 @@ class CacheProvider(ABC):
     """
     Abstract base class for external cache providers.

-    Thread Safety:
-        Providers may be called from multiple threads. Implementations
-        must be thread-safe.
+    Async Safety:
+        Provider methods are called from async context. Implementations
+        can use async I/O (aiohttp, asyncpg, etc.) directly.

     Error Handling:
         All methods are wrapped in try/except by the caller. Exceptions
@@ -81,12 +80,12 @@ class CacheProvider(ABC):

     Performance Guidelines:
         - on_lookup: Should complete in <500ms (including network)
-        - on_store: Can be async internally (fire-and-forget)
+        - on_store: Fire-and-forget via asyncio.create_task
         - should_cache: Should be fast (<1ms), called frequently
     """

     @abstractmethod
-    def on_lookup(self, context: CacheContext) -> Optional[CacheValue]:
+    async def on_lookup(self, context: CacheContext) -> Optional[CacheValue]:
         """
         Check external storage for cached result.
@@ -102,14 +101,14 @@ class CacheProvider(ABC):
         pass

     @abstractmethod
-    def on_store(self, context: CacheContext, value: CacheValue) -> None:
+    async def on_store(self, context: CacheContext, value: CacheValue) -> None:
         """
         Store value to external cache.

         Called AFTER value is stored in local cache.
+        Dispatched as asyncio.create_task (fire-and-forget).

         Important:
-        - Can be fire-and-forget (async internally)
         - Should never block execution
         - Handle serialization failures gracefully
         """
@@ -123,7 +122,7 @@ class CacheProvider(ABC):
         Return False to skip external caching for this node.

         Implementations can filter based on context.class_type, value size,
-        or any custom logic. Use estimate_value_size() to get value size.
+        or any custom logic. Use _estimate_value_size() to get value size.

         Default: Returns True (cache everything).
         """
@@ -157,11 +156,11 @@ def register_cache_provider(provider: CacheProvider) -> None:
     global _providers_snapshot
     with _providers_lock:
         if provider in _providers:
-            logger.warning(f"Provider {provider.__class__.__name__} already registered")
+            _logger.warning(f"Provider {provider.__class__.__name__} already registered")
            return
         _providers.append(provider)
         _providers_snapshot = None  # Invalidate cache
-        logger.info(f"Registered cache provider: {provider.__class__.__name__}")
+        _logger.info(f"Registered cache provider: {provider.__class__.__name__}")


 def unregister_cache_provider(provider: CacheProvider) -> None:
@@ -171,13 +170,13 @@ def unregister_cache_provider(provider: CacheProvider) -> None:
         try:
             _providers.remove(provider)
             _providers_snapshot = None
-            logger.info(f"Unregistered cache provider: {provider.__class__.__name__}")
+            _logger.info(f"Unregistered cache provider: {provider.__class__.__name__}")
         except ValueError:
-            logger.warning(f"Provider {provider.__class__.__name__} was not registered")
+            _logger.warning(f"Provider {provider.__class__.__name__} was not registered")


-def get_cache_providers() -> Tuple[CacheProvider, ...]:
-    """Get registered providers (cached for performance)."""
+def _get_cache_providers() -> Tuple[CacheProvider, ...]:
+    """Get registered providers (cached for performance). Internal."""
     global _providers_snapshot
     snapshot = _providers_snapshot
     if snapshot is not None:
@@ -189,13 +188,13 @@ def get_cache_providers() -> Tuple[CacheProvider, ...]:
     return _providers_snapshot


-def has_cache_providers() -> bool:
-    """Fast check if any providers registered (no lock)."""
+def _has_cache_providers() -> bool:
+    """Fast check if any providers registered (no lock). Internal."""
     return bool(_providers)


-def clear_cache_providers() -> None:
-    """Remove all providers. Useful for testing."""
+def _clear_cache_providers() -> None:
+    """Remove all providers. Useful for testing. Internal."""
     global _providers_snapshot
     with _providers_lock:
         _providers.clear()
@@ -203,7 +202,7 @@

 # ============================================================
-# Utilities
+# Internal Utilities
 # ============================================================

 def _canonicalize(obj: Any) -> Any:
@@ -243,11 +242,11 @@ def _canonicalize(obj: Any) -> Any:
     return ("__repr__", repr(obj))


-def serialize_cache_key(cache_key: Any) -> bytes:
+def _serialize_cache_key(cache_key: Any) -> str:
     """
-    Serialize cache key to bytes for external storage.
+    Serialize cache key to a hex digest string for external storage.

-    Returns SHA256 hash suitable for Redis/database keys.
+    Returns SHA256 hex string suitable for Redis/database keys.

     Note:
         Uses canonicalize + JSON serialization instead of pickle because pickle
         is NOT deterministic across Python sessions due to hash randomization
@@ -257,18 +256,18 @@
     try:
         canonical = _canonicalize(cache_key)
         json_str = json.dumps(canonical, sort_keys=True, separators=(',', ':'))
-        return hashlib.sha256(json_str.encode('utf-8')).digest()
+        return hashlib.sha256(json_str.encode('utf-8')).hexdigest()
     except Exception as e:
-        logger.warning(f"Failed to serialize cache key: {e}")
+        _logger.warning(f"Failed to serialize cache key: {e}")
         # Fallback to pickle (non-deterministic but better than nothing)
         try:
             serialized = pickle.dumps(cache_key, protocol=4)
-            return hashlib.sha256(serialized).digest()
+            return hashlib.sha256(serialized).hexdigest()
         except Exception:
-            return hashlib.sha256(str(id(cache_key)).encode()).digest()
+            return hashlib.sha256(str(id(cache_key)).encode()).hexdigest()


-def contains_nan(obj: Any) -> bool:
+def _contains_nan(obj: Any) -> bool:
     """
     Check if cache key contains NaN (indicates uncacheable node).
@@ -288,14 +287,14 @@
     except (TypeError, ValueError):
         return False
     if isinstance(obj, (frozenset, tuple, list, set)):
-        return any(contains_nan(item) for item in obj)
+        return any(_contains_nan(item) for item in obj)
     if isinstance(obj, dict):
-        return any(contains_nan(k) or contains_nan(v) for k, v in obj.items())
+        return any(_contains_nan(k) or _contains_nan(v) for k, v in obj.items())
     return False


-def estimate_value_size(value: CacheValue) -> int:
-    """Estimate serialized size in bytes. Useful for size-based filtering."""
+def _estimate_value_size(value: CacheValue) -> int:
+    """Estimate serialized size in bytes. Useful for size-based filtering. Internal."""
     try:
         import torch
     except ImportError:
diff --git a/comfy_execution/caching.py b/comfy_execution/caching.py
index c666c4dc1..63edb9ad0 100644
--- a/comfy_execution/caching.py
+++ b/comfy_execution/caching.py
@@ -1,3 +1,4 @@
+import asyncio
 import bisect
 import gc
 import itertools
@@ -200,15 +201,15 @@ class BasicCache:
     def poll(self, **kwargs):
         pass

-    def _set_immediate(self, node_id, value):
+    async def _set_immediate(self, node_id, value):
         assert self.initialized
         cache_key = self.cache_key_set.get_data_key(node_id)
         self.cache[cache_key] = value
         # Notify external providers
-        self._notify_providers_store(node_id, cache_key, value)
+        await self._notify_providers_store(node_id, cache_key, value)

-    def _get_immediate(self, node_id):
+    async def _get_immediate(self, node_id):
         if not self.initialized:
             return None
         cache_key = self.cache_key_set.get_data_key(node_id)
@@ -218,87 +219,88 @@ class BasicCache:
             return self.cache[cache_key]

         # Check external providers on local miss
-        external_result = self._check_providers_lookup(node_id, cache_key)
+        external_result = await self._check_providers_lookup(node_id, cache_key)
         if external_result is not None:
             self.cache[cache_key] = external_result  # Warm local cache
             return external_result
         return None

-    def _notify_providers_store(self, node_id, cache_key, value):
-        """Notify external providers of cache store."""
+    async def _notify_providers_store(self, node_id, cache_key, value):
+        """Notify external providers of cache store (fire-and-forget)."""
         from comfy_execution.cache_provider import (
-            has_cache_providers, get_cache_providers,
+            _has_cache_providers, _get_cache_providers,
             CacheContext, CacheValue,
-            serialize_cache_key, contains_nan, logger
+            _serialize_cache_key, _contains_nan, _logger
         )

         # Fast exit conditions
         if self._is_subcache:
             return
-        if not has_cache_providers():
+        if not _has_cache_providers():
             return
         if not self._is_external_cacheable_value(value):
             return
-        if contains_nan(cache_key):
+        if _contains_nan(cache_key):
             return

-        context = CacheContext(
-            prompt_id=self._current_prompt_id,
-            node_id=node_id,
-            class_type=self._get_class_type(node_id),
-            cache_key=cache_key,
-            cache_key_bytes=serialize_cache_key(cache_key)
-        )
+        context = self._build_context(node_id, cache_key)
+        if context is None:
+            return
         cache_value = CacheValue(outputs=value.outputs, ui=value.ui)

-        for provider in get_cache_providers():
+        for provider in _get_cache_providers():
             try:
                 if provider.should_cache(context, cache_value):
-                    provider.on_store(context, cache_value)
+                    asyncio.create_task(self._safe_provider_store(provider, context, cache_value))
             except Exception as e:
-                logger.warning(f"Cache provider {provider.__class__.__name__} error on store: {e}")
+                _logger.warning(f"Cache provider {provider.__class__.__name__} error on store: {e}")

-    def _check_providers_lookup(self, node_id, cache_key):
+    @staticmethod
+    async def _safe_provider_store(provider, context, cache_value):
+        """Wrapper for fire-and-forget provider.on_store with error handling."""
+        from comfy_execution.cache_provider import _logger
+        try:
+            await provider.on_store(context, cache_value)
+        except Exception as e:
+            _logger.warning(f"Cache provider {provider.__class__.__name__} async store error: {e}")
+
+    async def _check_providers_lookup(self, node_id, cache_key):
         """Check external providers for cached result."""
         from comfy_execution.cache_provider import (
-            has_cache_providers, get_cache_providers,
+            _has_cache_providers, _get_cache_providers,
             CacheContext, CacheValue,
-            serialize_cache_key, contains_nan, logger
+            _contains_nan, _logger
         )

         if self._is_subcache:
             return None
-        if not has_cache_providers():
+        if not _has_cache_providers():
             return None
-        if contains_nan(cache_key):
+        if _contains_nan(cache_key):
             return None

-        context = CacheContext(
-            prompt_id=self._current_prompt_id,
-            node_id=node_id,
-            class_type=self._get_class_type(node_id),
-            cache_key=cache_key,
-            cache_key_bytes=serialize_cache_key(cache_key)
-        )
+        context = self._build_context(node_id, cache_key)
+        if context is None:
+            return None

-        for provider in get_cache_providers():
+        for provider in _get_cache_providers():
             try:
                 if not provider.should_cache(context):
                     continue
-                result = provider.on_lookup(context)
+                result = await provider.on_lookup(context)
                 if result is not None:
                     if not isinstance(result, CacheValue):
-                        logger.warning(f"Provider {provider.__class__.__name__} returned invalid type")
+                        _logger.warning(f"Provider {provider.__class__.__name__} returned invalid type")
                         continue
                     if not isinstance(result.outputs, (list, tuple)):
-                        logger.warning(f"Provider {provider.__class__.__name__} returned invalid outputs")
+                        _logger.warning(f"Provider {provider.__class__.__name__} returned invalid outputs")
                         continue
                     # Import CacheEntry here to avoid circular import at module level
                     from execution import CacheEntry
                     return CacheEntry(ui=result.ui or {}, outputs=list(result.outputs))
             except Exception as e:
-                logger.warning(f"Cache provider {provider.__class__.__name__} error on lookup: {e}")
+                _logger.warning(f"Cache provider {provider.__class__.__name__} error on lookup: {e}")

         return None
@@ -315,6 +317,16 @@ class BasicCache:
         except Exception:
             return ''

+    def _build_context(self, node_id, cache_key):
+        """Build CacheContext with hash. Returns None if hashing fails on NaN."""
+        from comfy_execution.cache_provider import CacheContext, _serialize_cache_key
+        return CacheContext(
+            prompt_id=self._current_prompt_id,
+            node_id=node_id,
+            class_type=self._get_class_type(node_id),
+            cache_key_hash=_serialize_cache_key(cache_key)
+        )
+
     async def _ensure_subcache(self, node_id, children_ids):
         subcache_key = self.cache_key_set.get_subcache_key(node_id)
         subcache = self.subcaches.get(subcache_key, None)
@@ -364,16 +376,16 @@ class HierarchicalCache(BasicCache):
             return None
         return cache

-    def get(self, node_id):
+    async def get(self, node_id):
         cache = self._get_cache_for(node_id)
         if cache is None:
             return None
-        return cache._get_immediate(node_id)
+        return await cache._get_immediate(node_id)

-    def set(self, node_id, value):
+    async def set(self, node_id, value):
         cache = self._get_cache_for(node_id)
         assert cache is not None
-        cache._set_immediate(node_id, value)
+        await cache._set_immediate(node_id, value)

     async def ensure_subcache_for(self, node_id, children_ids):
         cache = self._get_cache_for(node_id)
@@ -394,10 +406,10 @@ class NullCache:
     def poll(self, **kwargs):
         pass

-    def get(self, node_id):
+    async def get(self, node_id):
         return None

-    def set(self, node_id, value):
+    async def set(self, node_id, value):
         pass

     async def ensure_subcache_for(self, node_id, children_ids):
@@ -429,18 +441,18 @@ class LRUCache(BasicCache):
                 del self.children[key]
         self._clean_subcaches()

-    def get(self, node_id):
+    async def get(self, node_id):
         self._mark_used(node_id)
-        return self._get_immediate(node_id)
+        return await self._get_immediate(node_id)

     def _mark_used(self, node_id):
         cache_key = self.cache_key_set.get_data_key(node_id)
         if cache_key is not None:
             self.used_generation[cache_key] = self.generation

-    def set(self, node_id, value):
+    async def set(self, node_id, value):
         self._mark_used(node_id)
-        return self._set_immediate(node_id, value)
+        return await self._set_immediate(node_id, value)

     async def ensure_subcache_for(self, node_id, children_ids):
         # Just uses subcaches for tracking 'live' nodes
@@ -480,13 +492,13 @@ class RAMPressureCache(LRUCache):
     def clean_unused(self):
         self._clean_subcaches()

-    def set(self, node_id, value):
+    async def set(self, node_id, value):
         self.timestamps[self.cache_key_set.get_data_key(node_id)] = time.time()
-        super().set(node_id, value)
+        await super().set(node_id, value)

-    def get(self, node_id):
+    async def get(self, node_id):
         self.timestamps[self.cache_key_set.get_data_key(node_id)] = time.time()
-        return super().get(node_id)
+        return await super().get(node_id)

     def poll(self, ram_headroom):
         def _ram_gb():
diff --git a/execution.py b/execution.py
index 204339af8..b4c5deabd 100644
--- a/execution.py
+++ b/execution.py
@@ -414,7 +414,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
     inputs = dynprompt.get_node(unique_id)['inputs']
     class_type = dynprompt.get_node(unique_id)['class_type']
     class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
-    cached = caches.outputs.get(unique_id)
+    cached = await caches.outputs.get(unique_id)
     if cached is not None:
         if server.client_id is not None:
             cached_ui = cached.ui or {}
@@ -470,10 +470,10 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
         server.last_node_id = display_node_id
         server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id }, server.client_id)

-        obj = caches.objects.get(unique_id)
+        obj = await caches.objects.get(unique_id)
         if obj is None:
             obj = class_def()
-            caches.objects.set(unique_id, obj)
+            await caches.objects.set(unique_id, obj)

         if issubclass(class_def, _ComfyNodeInternal):
             lazy_status_present = first_real_override(class_def, "check_lazy_status") is not None
@@ -575,7 +575,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
         cache_entry = CacheEntry(ui=ui_outputs.get(unique_id), outputs=output_data)
         execution_list.cache_update(unique_id, cache_entry)
-        caches.outputs.set(unique_id, cache_entry)
+        await caches.outputs.set(unique_id, cache_entry)

     except comfy.model_management.InterruptProcessingException as iex:
         logging.info("Processing interrupted")
@@ -720,7 +720,7 @@ class PromptExecutor:
         cached_nodes = []
         for node_id in prompt:
-            if self.caches.outputs.get(node_id) is not None:
+            if await self.caches.outputs.get(node_id) is not None:
                 cached_nodes.append(node_id)

         comfy.model_management.cleanup_models_gc()
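For reviewers, a minimal sketch of a provider written against the public surface kept by the comfy_api/latest/__init__.py hunk above (Caching.CacheProvider, Caching.register_provider / unregister_provider, and the new context.cache_key_hash field). The InMemoryProvider name and its dict storage are illustrative only and are not part of this diff:

    from comfy_api.latest import Caching

    class InMemoryProvider(Caching.CacheProvider):
        # Keeps CacheValue objects in a plain dict, keyed by the SHA256 hex digest.
        def __init__(self):
            self._store = {}

        async def on_lookup(self, context):
            # Called on a local-cache miss; return a CacheValue or None.
            return self._store.get(context.cache_key_hash)

        async def on_store(self, context, value):
            # Dispatched fire-and-forget via asyncio.create_task by the executor.
            self._store[context.cache_key_hash] = value

    provider = InMemoryProvider()
    Caching.register_provider(provider)
    # ... later, e.g. in a test teardown:
    Caching.unregister_provider(provider)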
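A slightly fuller sketch against comfy_execution.cache_provider directly, showing should_cache filtering and how blocking I/O can be kept off the event loop with asyncio.to_thread. The DiskCacheProvider name, the cache directory, the pickle serialization, and the should_cache signature with an optional value argument (inferred from the two call sites in caching.py, which pass one and two arguments respectively) are assumptions, not part of this diff:

    import asyncio
    import os
    import pickle

    from comfy_execution.cache_provider import (
        CacheContext,
        CacheProvider,
        CacheValue,
        register_cache_provider,
    )

    CACHE_DIR = "/tmp/comfy_external_cache"  # placeholder location

    class DiskCacheProvider(CacheProvider):
        def _path(self, context: CacheContext) -> str:
            # cache_key_hash is a stable SHA256 hex digest, safe to use as a filename.
            return os.path.join(CACHE_DIR, context.cache_key_hash + ".pkl")

        async def on_lookup(self, context: CacheContext):
            path = self._path(context)
            if not os.path.exists(path):
                return None  # miss: executor falls through to normal execution
            def _read():
                with open(path, "rb") as f:
                    return pickle.load(f)
            try:
                # Push blocking file I/O to a worker thread so the loop stays responsive.
                return await asyncio.to_thread(_read)
            except Exception:
                return None

        async def on_store(self, context: CacheContext, value: CacheValue) -> None:
            os.makedirs(CACHE_DIR, exist_ok=True)
            def _write():
                with open(self._path(context), "wb") as f:
                    pickle.dump(value, f)  # assumes the node outputs are picklable
            await asyncio.to_thread(_write)

        def should_cache(self, context: CacheContext, value=None) -> bool:
            # Cheap filter; called on both the store and lookup paths.
            return context.class_type not in ("PreviewImage", "SaveImage")

    register_cache_provider(DiskCacheProvider())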