style: align documentation with codebase conventions

Strip verbose docstrings and section banners to match the existing minimal
documentation style used throughout the codebase.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Deep Mehta 2026-03-03 17:24:07 -08:00
parent 311a2d59e4
commit 26f34d8642
4 changed files with 22 additions and 181 deletions

View File

@ -118,25 +118,6 @@ class Types:
class Caching: class Caching:
"""
External cache provider API.
Enables sharing cached results across multiple ComfyUI instances.
Example usage:
from comfy_api.latest import Caching
class MyCacheProvider(Caching.CacheProvider):
async def on_lookup(self, context):
# Check external storage for cached result
...
async def on_store(self, context, value):
# Store result to external storage
...
Caching.register_provider(MyCacheProvider())
"""
# Import from comfy_execution.cache_provider (source of truth) # Import from comfy_execution.cache_provider (source of truth)
from comfy_execution.cache_provider import ( from comfy_execution.cache_provider import (
CacheProvider, CacheProvider,

View File

@ -1,29 +1,3 @@
"""
External Cache Provider API.
This module provides a public API for external cache providers, enabling
shared caching across multiple ComfyUI instances.
Public API is also available via:
from comfy_api.latest import Caching
Example usage:
from comfy_execution.cache_provider import (
CacheProvider, CacheContext, CacheValue, register_cache_provider
)
class MyCacheProvider(CacheProvider):
async def on_lookup(self, context: CacheContext) -> Optional[CacheValue]:
# Check external storage for cached result
...
async def on_store(self, context: CacheContext, value: CacheValue) -> None:
# Store result to external storage
...
register_cache_provider(MyCacheProvider())
"""
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any, Optional, Tuple, List from typing import Any, Optional, Tuple, List
from dataclasses import dataclass from dataclasses import dataclass
@ -37,134 +11,64 @@ import threading
_logger = logging.getLogger(__name__) _logger = logging.getLogger(__name__)
# ============================================================
# Data Classes
# ============================================================
@dataclass @dataclass
class CacheContext: class CacheContext:
"""Context passed to provider methods.""" prompt_id: str
prompt_id: str # Current prompt execution ID node_id: str
node_id: str # Node being cached class_type: str
class_type: str # Node class type (e.g., "KSampler") cache_key_hash: str # SHA256 hex digest
cache_key_hash: str # SHA256 hex digest for external storage key
@dataclass @dataclass
class CacheValue: class CacheValue:
""" outputs: list
Value stored/retrieved from external cache. ui: dict = None
The ui field is optional - implementations may choose to skip it
(e.g., if it contains non-portable data like local file paths).
"""
outputs: list # The tensor/value outputs
ui: dict = None # Optional UI data (may be skipped by implementations)
# ============================================================
# Provider Interface
# ============================================================
class CacheProvider(ABC): class CacheProvider(ABC):
""" """Abstract base class for external cache providers.
Abstract base class for external cache providers. Exceptions from provider methods are caught by the caller and never break execution.
Async Safety:
Provider methods are called from async context. Implementations
can use async I/O directly.
Error Handling:
All methods are wrapped in try/except by the caller. Exceptions
are logged but never propagate to break execution.
Performance Guidelines:
- on_lookup: Should complete quickly (including any network I/O)
- on_store: Dispatched via asyncio.create_task (non-blocking)
- should_cache: Should be fast, called frequently
""" """
@abstractmethod @abstractmethod
async def on_lookup(self, context: CacheContext) -> Optional[CacheValue]: async def on_lookup(self, context: CacheContext) -> Optional[CacheValue]:
""" """Called on local cache miss. Return CacheValue if found, None otherwise."""
Check external storage for cached result.
Called AFTER local cache miss (local-first for performance).
Returns:
CacheValue if found externally, None otherwise.
Important:
- Return None on any error (don't raise)
- Validate data integrity before returning
"""
pass pass
@abstractmethod @abstractmethod
async def on_store(self, context: CacheContext, value: CacheValue) -> None: async def on_store(self, context: CacheContext, value: CacheValue) -> None:
""" """Called after local store. Dispatched via asyncio.create_task."""
Store value to external cache.
Called AFTER value is stored in local cache.
Dispatched via asyncio.create_task (non-blocking).
Important:
- Should not block execution
- Handle serialization failures gracefully
"""
pass pass
def should_cache(self, context: CacheContext, value: Optional[CacheValue] = None) -> bool: def should_cache(self, context: CacheContext, value: Optional[CacheValue] = None) -> bool:
""" """Return False to skip external caching for this node. Default: True."""
Filter which nodes should be externally cached.
Called before on_lookup (value=None) and on_store (value provided).
Return False to skip external caching for this node.
Implementations can filter based on context.class_type, value size,
or any custom logic. Use _estimate_value_size() to get value size.
Default: Returns True (cache everything).
"""
return True return True
def on_prompt_start(self, prompt_id: str) -> None: def on_prompt_start(self, prompt_id: str) -> None:
"""Called when prompt execution begins. Optional."""
pass pass
def on_prompt_end(self, prompt_id: str) -> None: def on_prompt_end(self, prompt_id: str) -> None:
"""Called when prompt execution ends. Optional."""
pass pass
# ============================================================
# Provider Registry
# ============================================================
_providers: List[CacheProvider] = [] _providers: List[CacheProvider] = []
_providers_lock = threading.Lock() _providers_lock = threading.Lock()
_providers_snapshot: Optional[Tuple[CacheProvider, ...]] = None _providers_snapshot: Optional[Tuple[CacheProvider, ...]] = None
def register_cache_provider(provider: CacheProvider) -> None: def register_cache_provider(provider: CacheProvider) -> None:
""" """Register an external cache provider. Providers are called in registration order."""
Register an external cache provider.
Providers are called in registration order. First provider to return
a result from on_lookup wins.
"""
global _providers_snapshot global _providers_snapshot
with _providers_lock: with _providers_lock:
if provider in _providers: if provider in _providers:
_logger.warning(f"Provider {provider.__class__.__name__} already registered") _logger.warning(f"Provider {provider.__class__.__name__} already registered")
return return
_providers.append(provider) _providers.append(provider)
_providers_snapshot = None # Invalidate cache _providers_snapshot = None
_logger.info(f"Registered cache provider: {provider.__class__.__name__}") _logger.info(f"Registered cache provider: {provider.__class__.__name__}")
def unregister_cache_provider(provider: CacheProvider) -> None: def unregister_cache_provider(provider: CacheProvider) -> None:
"""Remove a previously registered provider."""
global _providers_snapshot global _providers_snapshot
with _providers_lock: with _providers_lock:
try: try:
@ -176,7 +80,6 @@ def unregister_cache_provider(provider: CacheProvider) -> None:
def _get_cache_providers() -> Tuple[CacheProvider, ...]: def _get_cache_providers() -> Tuple[CacheProvider, ...]:
"""Get registered providers (cached for performance). Internal."""
global _providers_snapshot global _providers_snapshot
snapshot = _providers_snapshot snapshot = _providers_snapshot
if snapshot is not None: if snapshot is not None:
@ -189,30 +92,19 @@ def _get_cache_providers() -> Tuple[CacheProvider, ...]:
def _has_cache_providers() -> bool: def _has_cache_providers() -> bool:
"""Fast check if any providers registered (no lock). Internal."""
return bool(_providers) return bool(_providers)
def _clear_cache_providers() -> None: def _clear_cache_providers() -> None:
"""Remove all providers. Useful for testing. Internal."""
global _providers_snapshot global _providers_snapshot
with _providers_lock: with _providers_lock:
_providers.clear() _providers.clear()
_providers_snapshot = None _providers_snapshot = None
# ============================================================
# Internal Utilities
# ============================================================
def _canonicalize(obj: Any) -> Any: def _canonicalize(obj: Any) -> Any:
""" # Convert to canonical JSON-serializable form with deterministic ordering.
Convert an object to a canonical, JSON-serializable form. # Frozensets have non-deterministic iteration order between Python sessions.
This ensures deterministic ordering regardless of Python's hash randomization.
Frozensets in particular have non-deterministic iteration order between
Python sessions, so consistent serialization requires explicit sorting.
"""
if isinstance(obj, frozenset): if isinstance(obj, frozenset):
# Sort frozenset items for deterministic ordering # Sort frozenset items for deterministic ordering
return ("__frozenset__", sorted( return ("__frozenset__", sorted(
@ -243,17 +135,8 @@ def _canonicalize(obj: Any) -> Any:
def _serialize_cache_key(cache_key: Any) -> Optional[str]: def _serialize_cache_key(cache_key: Any) -> Optional[str]:
""" # Returns deterministic SHA256 hex digest, or None on failure.
Serialize cache key to a hex digest string for external storage. # Uses JSON (not pickle) because pickle is non-deterministic across sessions.
Returns SHA256 hex string suitable as an external storage key,
or None if serialization fails entirely (fail-closed).
Note: Uses canonicalize + JSON serialization instead of pickle because
pickle is NOT deterministic across Python sessions due to hash randomization
affecting frozenset iteration order. Consistent hashing is required so that
different instances compute the same key for identical inputs.
"""
try: try:
canonical = _canonicalize(cache_key) canonical = _canonicalize(cache_key)
json_str = json.dumps(canonical, sort_keys=True, separators=(',', ':')) json_str = json.dumps(canonical, sort_keys=True, separators=(',', ':'))
@ -270,12 +153,8 @@ def _serialize_cache_key(cache_key: Any) -> Optional[str]:
def _contains_nan(obj: Any) -> bool: def _contains_nan(obj: Any) -> bool:
""" # NaN != NaN so local cache never hits, but serialized NaN would match.
Check if cache key contains NaN (indicates uncacheable node). # Skip external caching for keys containing NaN.
NaN != NaN in Python, so local cache never hits. But serialized
NaN would match, causing incorrect external hits. Must skip these.
"""
if isinstance(obj, float): if isinstance(obj, float):
try: try:
return math.isnan(obj) return math.isnan(obj)
@ -296,7 +175,6 @@ def _contains_nan(obj: Any) -> bool:
def _estimate_value_size(value: CacheValue) -> int: def _estimate_value_size(value: CacheValue) -> int:
"""Estimate serialized size in bytes. Useful for size-based filtering. Internal."""
try: try:
import torch import torch
except ImportError: except ImportError:

View File

@ -156,7 +156,6 @@ class BasicCache:
self.cache = {} self.cache = {}
self.subcaches = {} self.subcaches = {}
# External cache provider support
self._is_subcache = False self._is_subcache = False
self._current_prompt_id = '' self._current_prompt_id = ''
@ -202,7 +201,6 @@ class BasicCache:
pass pass
def get_local(self, node_id): def get_local(self, node_id):
"""Sync local-only cache lookup (no external providers)."""
if not self.initialized: if not self.initialized:
return None return None
cache_key = self.cache_key_set.get_data_key(node_id) cache_key = self.cache_key_set.get_data_key(node_id)
@ -211,7 +209,6 @@ class BasicCache:
return None return None
def set_local(self, node_id, value): def set_local(self, node_id, value):
"""Sync local-only cache store (no external providers)."""
assert self.initialized assert self.initialized
cache_key = self.cache_key_set.get_data_key(node_id) cache_key = self.cache_key_set.get_data_key(node_id)
self.cache[cache_key] = value self.cache[cache_key] = value
@ -221,7 +218,6 @@ class BasicCache:
cache_key = self.cache_key_set.get_data_key(node_id) cache_key = self.cache_key_set.get_data_key(node_id)
self.cache[cache_key] = value self.cache[cache_key] = value
# Notify external providers
await self._notify_providers_store(node_id, cache_key, value) await self._notify_providers_store(node_id, cache_key, value)
async def _get_immediate(self, node_id): async def _get_immediate(self, node_id):
@ -229,26 +225,22 @@ class BasicCache:
return None return None
cache_key = self.cache_key_set.get_data_key(node_id) cache_key = self.cache_key_set.get_data_key(node_id)
# Check local cache first (fast path)
if cache_key in self.cache: if cache_key in self.cache:
return self.cache[cache_key] return self.cache[cache_key]
# Check external providers on local miss
external_result = await self._check_providers_lookup(node_id, cache_key) external_result = await self._check_providers_lookup(node_id, cache_key)
if external_result is not None: if external_result is not None:
self.cache[cache_key] = external_result # Warm local cache self.cache[cache_key] = external_result
return external_result return external_result
return None return None
async def _notify_providers_store(self, node_id, cache_key, value): async def _notify_providers_store(self, node_id, cache_key, value):
"""Notify external providers of cache store (non-blocking)."""
from comfy_execution.cache_provider import ( from comfy_execution.cache_provider import (
_has_cache_providers, _get_cache_providers, _has_cache_providers, _get_cache_providers,
CacheValue, _contains_nan, _logger CacheValue, _contains_nan, _logger
) )
# Fast exit conditions
if self._is_subcache: if self._is_subcache:
return return
if not _has_cache_providers(): if not _has_cache_providers():
@ -272,7 +264,6 @@ class BasicCache:
@staticmethod @staticmethod
async def _safe_provider_store(provider, context, cache_value): async def _safe_provider_store(provider, context, cache_value):
"""Wrapper for async provider.on_store with error handling."""
from comfy_execution.cache_provider import _logger from comfy_execution.cache_provider import _logger
try: try:
await provider.on_store(context, cache_value) await provider.on_store(context, cache_value)
@ -280,7 +271,6 @@ class BasicCache:
_logger.warning(f"Cache provider {provider.__class__.__name__} async store error: {e}") _logger.warning(f"Cache provider {provider.__class__.__name__} async store error: {e}")
async def _check_providers_lookup(self, node_id, cache_key): async def _check_providers_lookup(self, node_id, cache_key):
"""Check external providers for cached result."""
from comfy_execution.cache_provider import ( from comfy_execution.cache_provider import (
_has_cache_providers, _get_cache_providers, _has_cache_providers, _get_cache_providers,
CacheValue, _contains_nan, _logger CacheValue, _contains_nan, _logger
@ -309,7 +299,6 @@ class BasicCache:
if not isinstance(result.outputs, (list, tuple)): if not isinstance(result.outputs, (list, tuple)):
_logger.warning(f"Provider {provider.__class__.__name__} returned invalid outputs") _logger.warning(f"Provider {provider.__class__.__name__} returned invalid outputs")
continue continue
# Import CacheEntry here to avoid circular import at module level
from execution import CacheEntry from execution import CacheEntry
return CacheEntry(ui=result.ui or {}, outputs=list(result.outputs)) return CacheEntry(ui=result.ui or {}, outputs=list(result.outputs))
except Exception as e: except Exception as e:
@ -318,11 +307,9 @@ class BasicCache:
return None return None
def _is_external_cacheable_value(self, value): def _is_external_cacheable_value(self, value):
"""Check if value is a CacheEntry suitable for external caching (not objects cache)."""
return hasattr(value, 'outputs') and hasattr(value, 'ui') return hasattr(value, 'outputs') and hasattr(value, 'ui')
def _get_class_type(self, node_id): def _get_class_type(self, node_id):
"""Get class_type for a node."""
if not self.initialized or not self.dynprompt: if not self.initialized or not self.dynprompt:
return '' return ''
try: try:
@ -331,7 +318,6 @@ class BasicCache:
return '' return ''
def _build_context(self, node_id, cache_key): def _build_context(self, node_id, cache_key):
"""Build CacheContext with hash. Returns None if hashing fails."""
from comfy_execution.cache_provider import CacheContext, _serialize_cache_key, _logger from comfy_execution.cache_provider import CacheContext, _serialize_cache_key, _logger
try: try:
cache_key_hash = _serialize_cache_key(cache_key) cache_key_hash = _serialize_cache_key(cache_key)
@ -352,8 +338,8 @@ class BasicCache:
subcache = self.subcaches.get(subcache_key, None) subcache = self.subcaches.get(subcache_key, None)
if subcache is None: if subcache is None:
subcache = BasicCache(self.key_class) subcache = BasicCache(self.key_class)
subcache._is_subcache = True # Mark as subcache - excludes from external caching subcache._is_subcache = True
subcache._current_prompt_id = self._current_prompt_id # Propagate prompt ID subcache._current_prompt_id = self._current_prompt_id
self.subcaches[subcache_key] = subcache self.subcaches[subcache_key] = subcache
await subcache.set_prompt(self.dynprompt, children_ids, self.is_changed_cache) await subcache.set_prompt(self.dynprompt, children_ids, self.is_changed_cache)
return subcache return subcache

View File

@ -685,7 +685,6 @@ class PromptExecutor:
self.add_message("execution_error", mes, broadcast=False) self.add_message("execution_error", mes, broadcast=False)
def _notify_prompt_lifecycle(self, event: str, prompt_id: str): def _notify_prompt_lifecycle(self, event: str, prompt_id: str):
"""Notify external cache providers of prompt lifecycle events."""
from comfy_execution.cache_provider import _has_cache_providers, _get_cache_providers, _logger from comfy_execution.cache_provider import _has_cache_providers, _get_cache_providers, _logger
if not _has_cache_providers(): if not _has_cache_providers():
@ -716,11 +715,9 @@ class PromptExecutor:
self.status_messages = [] self.status_messages = []
self.add_message("execution_start", { "prompt_id": prompt_id}, broadcast=False) self.add_message("execution_start", { "prompt_id": prompt_id}, broadcast=False)
# Set prompt ID on caches for external provider integration
for cache in self.caches.all: for cache in self.caches.all:
cache._current_prompt_id = prompt_id cache._current_prompt_id = prompt_id
# Notify external cache providers of prompt start
self._notify_prompt_lifecycle("start", prompt_id) self._notify_prompt_lifecycle("start", prompt_id)
try: try:
@ -785,7 +782,6 @@ class PromptExecutor:
if comfy.model_management.DISABLE_SMART_MEMORY: if comfy.model_management.DISABLE_SMART_MEMORY:
comfy.model_management.unload_all_models() comfy.model_management.unload_all_models()
finally: finally:
# Notify external cache providers of prompt end
self._notify_prompt_lifecycle("end", prompt_id) self._notify_prompt_lifecycle("end", prompt_id)