refactor: async CacheProvider API + reduce public surface

- Make on_lookup/on_store async on CacheProvider ABC
- Simplify CacheContext: replace cache_key + cache_key_bytes with
  cache_key_hash (str hex digest)
- Make registry/utility functions internal (_prefix)
- Trim comfy_api.latest.Caching exports to core API only
- Make cache get/set async throughout caching.py hierarchy
- Use asyncio.create_task for fire-and-forget on_store
- Add NaN gating before provider calls in Core
- Add await to 5 cache call sites in execution.py

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Deep Mehta 2026-03-03 12:34:25 -08:00
parent 0141af0786
commit 4cbe4fe4c7
4 changed files with 107 additions and 100 deletions
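Under the new API described in the commit message, a provider is just two async methods keyed off `context.cache_key_hash`. The following is a minimal sketch based only on the ABC shown in this diff; the in-memory dict is a stand-in for a real backend such as Redis:

# Sketch of a provider under the new async API (assumes the
# comfy_execution.cache_provider module as changed in this commit).
from typing import Optional

from comfy_execution.cache_provider import (
    CacheProvider,
    CacheContext,
    CacheValue,
    register_cache_provider,
)

class InMemoryProvider(CacheProvider):
    """Toy backend: a dict keyed by cache_key_hash (stand-in for Redis/GCS)."""

    def __init__(self):
        self._store = {}

    async def on_lookup(self, context: CacheContext) -> Optional[CacheValue]:
        # Called from async context; real backends can await network I/O here.
        return self._store.get(context.cache_key_hash)

    async def on_store(self, context: CacheContext, value: CacheValue) -> None:
        # Dispatched fire-and-forget via asyncio.create_task by the caller.
        self._store[context.cache_key_hash] = value

register_cache_provider(InMemoryProvider())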

View File

@@ -118,12 +118,12 @@ class Caching:
         from comfy_api.latest import Caching

         class MyRedisProvider(Caching.CacheProvider):
-            def on_lookup(self, context):
+            async def on_lookup(self, context):
                 # Check Redis for cached result
                 ...

-            def on_store(self, context, value):
-                # Store to Redis (can be async internally)
+            async def on_store(self, context, value):
+                # Store to Redis
                 ...

         Caching.register_provider(MyRedisProvider())
@@ -135,10 +135,6 @@ class Caching:
         CacheValue,
         register_cache_provider as register_provider,
         unregister_cache_provider as unregister_provider,
-        get_cache_providers as get_providers,
-        has_cache_providers as has_providers,
-        clear_cache_providers as clear_providers,
-        estimate_value_size,
     )

View File

@@ -13,12 +13,12 @@ Example usage:
     )

     class MyRedisProvider(CacheProvider):
-        def on_lookup(self, context: CacheContext) -> Optional[CacheValue]:
+        async def on_lookup(self, context: CacheContext) -> Optional[CacheValue]:
            # Check Redis/GCS for cached result
            ...

-        def on_store(self, context: CacheContext, value: CacheValue) -> None:
-            # Store to Redis/GCS (can be async internally)
+        async def on_store(self, context: CacheContext, value: CacheValue) -> None:
+            # Store to Redis/GCS
            ...

     register_cache_provider(MyRedisProvider())
@@ -34,7 +34,7 @@ import math
 import pickle
 import threading

-logger = logging.getLogger(__name__)
+_logger = logging.getLogger(__name__)

 # ============================================================
@@ -47,8 +47,7 @@ class CacheContext:
     prompt_id: str    # Current prompt execution ID
     node_id: str      # Node being cached
     class_type: str   # Node class type (e.g., "KSampler")
-    cache_key: Any    # Raw cache key (frozenset structure)
-    cache_key_bytes: bytes  # SHA256 hash for external storage key
+    cache_key_hash: str  # SHA256 hex digest for external storage key

 @dataclass
@@ -71,9 +70,9 @@ class CacheProvider(ABC):
     """
     Abstract base class for external cache providers.

-    Thread Safety:
-        Providers may be called from multiple threads. Implementations
-        must be thread-safe.
+    Async Safety:
+        Provider methods are called from async context. Implementations
+        can use async I/O (aiohttp, asyncpg, etc.) directly.

     Error Handling:
         All methods are wrapped in try/except by the caller. Exceptions
@@ -81,12 +80,12 @@ class CacheProvider(ABC):
     Performance Guidelines:
         - on_lookup: Should complete in <500ms (including network)
-        - on_store: Can be async internally (fire-and-forget)
+        - on_store: Fire-and-forget via asyncio.create_task
         - should_cache: Should be fast (<1ms), called frequently
     """

     @abstractmethod
-    def on_lookup(self, context: CacheContext) -> Optional[CacheValue]:
+    async def on_lookup(self, context: CacheContext) -> Optional[CacheValue]:
         """
         Check external storage for cached result.
@@ -102,14 +101,14 @@ class CacheProvider(ABC):
         pass

     @abstractmethod
-    def on_store(self, context: CacheContext, value: CacheValue) -> None:
+    async def on_store(self, context: CacheContext, value: CacheValue) -> None:
         """
         Store value to external cache.

         Called AFTER value is stored in local cache.
+        Dispatched as asyncio.create_task (fire-and-forget).

         Important:
-            - Can be fire-and-forget (async internally)
             - Should never block execution
             - Handle serialization failures gracefully
         """
@@ -123,7 +122,7 @@ class CacheProvider(ABC):
         Return False to skip external caching for this node.

         Implementations can filter based on context.class_type, value size,
-        or any custom logic. Use estimate_value_size() to get value size.
+        or any custom logic. Use _estimate_value_size() to get value size.

         Default: Returns True (cache everything).
         """
@@ -157,11 +156,11 @@ def register_cache_provider(provider: CacheProvider) -> None:
     global _providers_snapshot
     with _providers_lock:
         if provider in _providers:
-            logger.warning(f"Provider {provider.__class__.__name__} already registered")
+            _logger.warning(f"Provider {provider.__class__.__name__} already registered")
             return
         _providers.append(provider)
         _providers_snapshot = None  # Invalidate cache
-    logger.info(f"Registered cache provider: {provider.__class__.__name__}")
+    _logger.info(f"Registered cache provider: {provider.__class__.__name__}")

 def unregister_cache_provider(provider: CacheProvider) -> None:
@@ -171,13 +170,13 @@ def unregister_cache_provider(provider: CacheProvider) -> None:
     try:
         _providers.remove(provider)
         _providers_snapshot = None
-        logger.info(f"Unregistered cache provider: {provider.__class__.__name__}")
+        _logger.info(f"Unregistered cache provider: {provider.__class__.__name__}")
     except ValueError:
-        logger.warning(f"Provider {provider.__class__.__name__} was not registered")
+        _logger.warning(f"Provider {provider.__class__.__name__} was not registered")

-def get_cache_providers() -> Tuple[CacheProvider, ...]:
-    """Get registered providers (cached for performance)."""
+def _get_cache_providers() -> Tuple[CacheProvider, ...]:
+    """Get registered providers (cached for performance). Internal."""
     global _providers_snapshot
     snapshot = _providers_snapshot
     if snapshot is not None:
@@ -189,13 +188,13 @@ def get_cache_providers() -> Tuple[CacheProvider, ...]:
     return _providers_snapshot

-def has_cache_providers() -> bool:
-    """Fast check if any providers registered (no lock)."""
+def _has_cache_providers() -> bool:
+    """Fast check if any providers registered (no lock). Internal."""
     return bool(_providers)

-def clear_cache_providers() -> None:
-    """Remove all providers. Useful for testing."""
+def _clear_cache_providers() -> None:
+    """Remove all providers. Useful for testing. Internal."""
     global _providers_snapshot
     with _providers_lock:
         _providers.clear()
@@ -203,7 +202,7 @@ def clear_cache_providers() -> None:
 # ============================================================
-# Utilities
+# Internal Utilities
 # ============================================================

 def _canonicalize(obj: Any) -> Any:
@@ -243,11 +242,11 @@ def _canonicalize(obj: Any) -> Any:
     return ("__repr__", repr(obj))

-def serialize_cache_key(cache_key: Any) -> bytes:
+def _serialize_cache_key(cache_key: Any) -> str:
     """
-    Serialize cache key to bytes for external storage.
+    Serialize cache key to a hex digest string for external storage.

-    Returns SHA256 hash suitable for Redis/database keys.
+    Returns SHA256 hex string suitable for Redis/database keys.

     Note: Uses canonicalize + JSON serialization instead of pickle because
     pickle is NOT deterministic across Python sessions due to hash randomization
@@ -257,18 +256,18 @@ def serialize_cache_key(cache_key: Any) -> bytes:
     try:
         canonical = _canonicalize(cache_key)
         json_str = json.dumps(canonical, sort_keys=True, separators=(',', ':'))
-        return hashlib.sha256(json_str.encode('utf-8')).digest()
+        return hashlib.sha256(json_str.encode('utf-8')).hexdigest()
     except Exception as e:
-        logger.warning(f"Failed to serialize cache key: {e}")
+        _logger.warning(f"Failed to serialize cache key: {e}")
         # Fallback to pickle (non-deterministic but better than nothing)
         try:
             serialized = pickle.dumps(cache_key, protocol=4)
-            return hashlib.sha256(serialized).digest()
+            return hashlib.sha256(serialized).hexdigest()
         except Exception:
-            return hashlib.sha256(str(id(cache_key)).encode()).digest()
+            return hashlib.sha256(str(id(cache_key)).encode()).hexdigest()

-def contains_nan(obj: Any) -> bool:
+def _contains_nan(obj: Any) -> bool:
     """
     Check if cache key contains NaN (indicates uncacheable node).
@@ -288,14 +287,14 @@ def contains_nan(obj: Any) -> bool:
     except (TypeError, ValueError):
         return False
     if isinstance(obj, (frozenset, tuple, list, set)):
-        return any(contains_nan(item) for item in obj)
+        return any(_contains_nan(item) for item in obj)
     if isinstance(obj, dict):
-        return any(contains_nan(k) or contains_nan(v) for k, v in obj.items())
+        return any(_contains_nan(k) or _contains_nan(v) for k, v in obj.items())
     return False

-def estimate_value_size(value: CacheValue) -> int:
-    """Estimate serialized size in bytes. Useful for size-based filtering."""
+def _estimate_value_size(value: CacheValue) -> int:
+    """Estimate serialized size in bytes. Useful for size-based filtering. Internal."""
     try:
         import torch
     except ImportError:

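Because `_serialize_cache_key` now returns a hex digest, external storage keys are plain strings. The following self-contained sketch shows the canonicalize-then-hash idea in isolation; the toy `canonicalize` here only covers containers, whereas the real `_canonicalize` above handles more types:

import hashlib
import json

def canonicalize(obj):
    # Toy stand-in for the module's _canonicalize: order-insensitive
    # containers become sorted lists so equal keys hash identically.
    if isinstance(obj, (frozenset, set)):
        return sorted((canonicalize(x) for x in obj), key=repr)
    if isinstance(obj, (list, tuple)):
        return [canonicalize(x) for x in obj]
    if isinstance(obj, dict):
        return {str(k): canonicalize(v) for k, v in obj.items()}
    return obj

def key_hash(cache_key) -> str:
    json_str = json.dumps(canonicalize(cache_key), sort_keys=True, separators=(',', ':'))
    return hashlib.sha256(json_str.encode('utf-8')).hexdigest()

# Equal logical keys produce the same digest in every session, unlike
# pickle, whose byte output can vary with set iteration order.
assert key_hash(frozenset({"a", "b"})) == key_hash(frozenset({"b", "a"}))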
View File

@@ -1,3 +1,4 @@
+import asyncio
 import bisect
 import gc
 import itertools
@@ -200,15 +201,15 @@ class BasicCache:
     def poll(self, **kwargs):
         pass

-    def _set_immediate(self, node_id, value):
+    async def _set_immediate(self, node_id, value):
         assert self.initialized
         cache_key = self.cache_key_set.get_data_key(node_id)
         self.cache[cache_key] = value
         # Notify external providers
-        self._notify_providers_store(node_id, cache_key, value)
+        await self._notify_providers_store(node_id, cache_key, value)

-    def _get_immediate(self, node_id):
+    async def _get_immediate(self, node_id):
         if not self.initialized:
             return None
         cache_key = self.cache_key_set.get_data_key(node_id)
@@ -218,87 +219,88 @@ class BasicCache:
             return self.cache[cache_key]
         # Check external providers on local miss
-        external_result = self._check_providers_lookup(node_id, cache_key)
+        external_result = await self._check_providers_lookup(node_id, cache_key)
         if external_result is not None:
             self.cache[cache_key] = external_result  # Warm local cache
             return external_result
         return None

-    def _notify_providers_store(self, node_id, cache_key, value):
-        """Notify external providers of cache store."""
+    async def _notify_providers_store(self, node_id, cache_key, value):
+        """Notify external providers of cache store (fire-and-forget)."""
         from comfy_execution.cache_provider import (
-            has_cache_providers, get_cache_providers,
+            _has_cache_providers, _get_cache_providers,
             CacheContext, CacheValue,
-            serialize_cache_key, contains_nan, logger
+            _serialize_cache_key, _contains_nan, _logger
         )
         # Fast exit conditions
         if self._is_subcache:
             return
-        if not has_cache_providers():
+        if not _has_cache_providers():
             return
         if not self._is_external_cacheable_value(value):
             return
-        if contains_nan(cache_key):
+        if _contains_nan(cache_key):
             return
-        context = CacheContext(
-            prompt_id=self._current_prompt_id,
-            node_id=node_id,
-            class_type=self._get_class_type(node_id),
-            cache_key=cache_key,
-            cache_key_bytes=serialize_cache_key(cache_key)
-        )
+        context = self._build_context(node_id, cache_key)
+        if context is None:
+            return
         cache_value = CacheValue(outputs=value.outputs, ui=value.ui)
-        for provider in get_cache_providers():
+        for provider in _get_cache_providers():
             try:
                 if provider.should_cache(context, cache_value):
-                    provider.on_store(context, cache_value)
+                    asyncio.create_task(self._safe_provider_store(provider, context, cache_value))
             except Exception as e:
-                logger.warning(f"Cache provider {provider.__class__.__name__} error on store: {e}")
+                _logger.warning(f"Cache provider {provider.__class__.__name__} error on store: {e}")

+    @staticmethod
+    async def _safe_provider_store(provider, context, cache_value):
+        """Wrapper for fire-and-forget provider.on_store with error handling."""
+        from comfy_execution.cache_provider import _logger
+        try:
+            await provider.on_store(context, cache_value)
+        except Exception as e:
+            _logger.warning(f"Cache provider {provider.__class__.__name__} async store error: {e}")

-    def _check_providers_lookup(self, node_id, cache_key):
+    async def _check_providers_lookup(self, node_id, cache_key):
         """Check external providers for cached result."""
         from comfy_execution.cache_provider import (
-            has_cache_providers, get_cache_providers,
+            _has_cache_providers, _get_cache_providers,
             CacheContext, CacheValue,
-            serialize_cache_key, contains_nan, logger
+            _contains_nan, _logger
         )
         if self._is_subcache:
             return None
-        if not has_cache_providers():
+        if not _has_cache_providers():
             return None
-        if contains_nan(cache_key):
+        if _contains_nan(cache_key):
             return None
-        context = CacheContext(
-            prompt_id=self._current_prompt_id,
-            node_id=node_id,
-            class_type=self._get_class_type(node_id),
-            cache_key=cache_key,
-            cache_key_bytes=serialize_cache_key(cache_key)
-        )
-        for provider in get_cache_providers():
+        context = self._build_context(node_id, cache_key)
+        if context is None:
+            return None
+        for provider in _get_cache_providers():
             try:
                 if not provider.should_cache(context):
                     continue
-                result = provider.on_lookup(context)
+                result = await provider.on_lookup(context)
                 if result is not None:
                     if not isinstance(result, CacheValue):
-                        logger.warning(f"Provider {provider.__class__.__name__} returned invalid type")
+                        _logger.warning(f"Provider {provider.__class__.__name__} returned invalid type")
                         continue
                     if not isinstance(result.outputs, (list, tuple)):
-                        logger.warning(f"Provider {provider.__class__.__name__} returned invalid outputs")
+                        _logger.warning(f"Provider {provider.__class__.__name__} returned invalid outputs")
                         continue
                     # Import CacheEntry here to avoid circular import at module level
                     from execution import CacheEntry
                     return CacheEntry(ui=result.ui or {}, outputs=list(result.outputs))
             except Exception as e:
-                logger.warning(f"Cache provider {provider.__class__.__name__} error on lookup: {e}")
+                _logger.warning(f"Cache provider {provider.__class__.__name__} error on lookup: {e}")
         return None
@@ -315,6 +317,16 @@ class BasicCache:
         except Exception:
             return ''

+    def _build_context(self, node_id, cache_key):
+        """Build CacheContext with hash. Returns None if hashing fails on NaN."""
+        from comfy_execution.cache_provider import CacheContext, _serialize_cache_key
+        return CacheContext(
+            prompt_id=self._current_prompt_id,
+            node_id=node_id,
+            class_type=self._get_class_type(node_id),
+            cache_key_hash=_serialize_cache_key(cache_key)
+        )
+
     async def _ensure_subcache(self, node_id, children_ids):
         subcache_key = self.cache_key_set.get_subcache_key(node_id)
         subcache = self.subcaches.get(subcache_key, None)
@@ -364,16 +376,16 @@ class HierarchicalCache(BasicCache):
             return None
         return cache

-    def get(self, node_id):
+    async def get(self, node_id):
         cache = self._get_cache_for(node_id)
         if cache is None:
             return None
-        return cache._get_immediate(node_id)
+        return await cache._get_immediate(node_id)

-    def set(self, node_id, value):
+    async def set(self, node_id, value):
         cache = self._get_cache_for(node_id)
         assert cache is not None
-        cache._set_immediate(node_id, value)
+        await cache._set_immediate(node_id, value)

     async def ensure_subcache_for(self, node_id, children_ids):
         cache = self._get_cache_for(node_id)
@@ -394,10 +406,10 @@ class NullCache:
     def poll(self, **kwargs):
         pass

-    def get(self, node_id):
+    async def get(self, node_id):
         return None

-    def set(self, node_id, value):
+    async def set(self, node_id, value):
         pass

     async def ensure_subcache_for(self, node_id, children_ids):
@@ -429,18 +441,18 @@ class LRUCache(BasicCache):
             del self.children[key]
         self._clean_subcaches()

-    def get(self, node_id):
+    async def get(self, node_id):
         self._mark_used(node_id)
-        return self._get_immediate(node_id)
+        return await self._get_immediate(node_id)

     def _mark_used(self, node_id):
         cache_key = self.cache_key_set.get_data_key(node_id)
         if cache_key is not None:
             self.used_generation[cache_key] = self.generation

-    def set(self, node_id, value):
+    async def set(self, node_id, value):
         self._mark_used(node_id)
-        return self._set_immediate(node_id, value)
+        return await self._set_immediate(node_id, value)

     async def ensure_subcache_for(self, node_id, children_ids):
         # Just uses subcaches for tracking 'live' nodes
@@ -480,13 +492,13 @@ class RAMPressureCache(LRUCache):
     def clean_unused(self):
         self._clean_subcaches()

-    def set(self, node_id, value):
+    async def set(self, node_id, value):
         self.timestamps[self.cache_key_set.get_data_key(node_id)] = time.time()
-        super().set(node_id, value)
+        await super().set(node_id, value)

-    def get(self, node_id):
+    async def get(self, node_id):
         self.timestamps[self.cache_key_set.get_data_key(node_id)] = time.time()
-        return super().get(node_id)
+        return await super().get(node_id)

     def poll(self, ram_headroom):
         def _ram_gb():

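One detail of the fire-and-forget dispatch in `_notify_providers_store` above: the asyncio documentation advises holding a strong reference to tasks created with `asyncio.create_task`, since the event loop keeps only a weak reference and an unreferenced task can be garbage-collected before it finishes. A sketch of that pattern (the `_store_tasks` set and `dispatch_store` helper are hypothetical, not part of this commit):

import asyncio

# Hypothetical variant of the bare create_task call above: keep a strong
# reference to each in-flight store task and drop it when the task finishes.
_store_tasks: set[asyncio.Task] = set()

def dispatch_store(coro) -> None:
    task = asyncio.create_task(coro)
    _store_tasks.add(task)
    task.add_done_callback(_store_tasks.discard)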
View File

@@ -414,7 +414,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
     inputs = dynprompt.get_node(unique_id)['inputs']
     class_type = dynprompt.get_node(unique_id)['class_type']
     class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
-    cached = caches.outputs.get(unique_id)
+    cached = await caches.outputs.get(unique_id)
     if cached is not None:
         if server.client_id is not None:
             cached_ui = cached.ui or {}
@@ -470,10 +470,10 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
         server.last_node_id = display_node_id
         server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id }, server.client_id)

-    obj = caches.objects.get(unique_id)
+    obj = await caches.objects.get(unique_id)
     if obj is None:
         obj = class_def()
-        caches.objects.set(unique_id, obj)
+        await caches.objects.set(unique_id, obj)

     if issubclass(class_def, _ComfyNodeInternal):
         lazy_status_present = first_real_override(class_def, "check_lazy_status") is not None
@@ -575,7 +575,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
         cache_entry = CacheEntry(ui=ui_outputs.get(unique_id), outputs=output_data)
         execution_list.cache_update(unique_id, cache_entry)
-        caches.outputs.set(unique_id, cache_entry)
+        await caches.outputs.set(unique_id, cache_entry)
     except comfy.model_management.InterruptProcessingException as iex:
         logging.info("Processing interrupted")
@@ -720,7 +720,7 @@ class PromptExecutor:
         cached_nodes = []
         for node_id in prompt:
-            if self.caches.outputs.get(node_id) is not None:
+            if await self.caches.outputs.get(node_id) is not None:
                 cached_nodes.append(node_id)

         comfy.model_management.cleanup_models_gc()