mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-07 10:17:31 +08:00
style: align documentation with codebase conventions
Strip verbose docstrings and section banners to match existing minimal documentation style used throughout the codebase. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
311a2d59e4
commit
26f34d8642
@ -118,25 +118,6 @@ class Types:
|
||||
|
||||
|
||||
class Caching:
|
||||
"""
|
||||
External cache provider API.
|
||||
|
||||
Enables sharing cached results across multiple ComfyUI instances.
|
||||
|
||||
Example usage:
|
||||
from comfy_api.latest import Caching
|
||||
|
||||
class MyCacheProvider(Caching.CacheProvider):
|
||||
async def on_lookup(self, context):
|
||||
# Check external storage for cached result
|
||||
...
|
||||
|
||||
async def on_store(self, context, value):
|
||||
# Store result to external storage
|
||||
...
|
||||
|
||||
Caching.register_provider(MyCacheProvider())
|
||||
"""
|
||||
# Import from comfy_execution.cache_provider (source of truth)
|
||||
from comfy_execution.cache_provider import (
|
||||
CacheProvider,
|
||||
|
||||
@ -1,29 +1,3 @@
|
||||
"""
|
||||
External Cache Provider API.
|
||||
|
||||
This module provides a public API for external cache providers, enabling
|
||||
shared caching across multiple ComfyUI instances.
|
||||
|
||||
Public API is also available via:
|
||||
from comfy_api.latest import Caching
|
||||
|
||||
Example usage:
|
||||
from comfy_execution.cache_provider import (
|
||||
CacheProvider, CacheContext, CacheValue, register_cache_provider
|
||||
)
|
||||
|
||||
class MyCacheProvider(CacheProvider):
|
||||
async def on_lookup(self, context: CacheContext) -> Optional[CacheValue]:
|
||||
# Check external storage for cached result
|
||||
...
|
||||
|
||||
async def on_store(self, context: CacheContext, value: CacheValue) -> None:
|
||||
# Store result to external storage
|
||||
...
|
||||
|
||||
register_cache_provider(MyCacheProvider())
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Optional, Tuple, List
|
||||
from dataclasses import dataclass
|
||||
@ -37,134 +11,64 @@ import threading
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Data Classes
|
||||
# ============================================================
|
||||
|
||||
@dataclass
|
||||
class CacheContext:
|
||||
"""Context passed to provider methods."""
|
||||
prompt_id: str # Current prompt execution ID
|
||||
node_id: str # Node being cached
|
||||
class_type: str # Node class type (e.g., "KSampler")
|
||||
cache_key_hash: str # SHA256 hex digest for external storage key
|
||||
prompt_id: str
|
||||
node_id: str
|
||||
class_type: str
|
||||
cache_key_hash: str # SHA256 hex digest
|
||||
|
||||
|
||||
@dataclass
|
||||
class CacheValue:
|
||||
"""
|
||||
Value stored/retrieved from external cache.
|
||||
outputs: list
|
||||
ui: dict = None
|
||||
|
||||
The ui field is optional - implementations may choose to skip it
|
||||
(e.g., if it contains non-portable data like local file paths).
|
||||
"""
|
||||
outputs: list # The tensor/value outputs
|
||||
ui: dict = None # Optional UI data (may be skipped by implementations)
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Provider Interface
|
||||
# ============================================================
|
||||
|
||||
class CacheProvider(ABC):
|
||||
"""
|
||||
Abstract base class for external cache providers.
|
||||
|
||||
Async Safety:
|
||||
Provider methods are called from async context. Implementations
|
||||
can use async I/O directly.
|
||||
|
||||
Error Handling:
|
||||
All methods are wrapped in try/except by the caller. Exceptions
|
||||
are logged but never propagate to break execution.
|
||||
|
||||
Performance Guidelines:
|
||||
- on_lookup: Should complete quickly (including any network I/O)
|
||||
- on_store: Dispatched via asyncio.create_task (non-blocking)
|
||||
- should_cache: Should be fast, called frequently
|
||||
"""Abstract base class for external cache providers.
|
||||
Exceptions from provider methods are caught by the caller and never break execution.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def on_lookup(self, context: CacheContext) -> Optional[CacheValue]:
|
||||
"""
|
||||
Check external storage for cached result.
|
||||
|
||||
Called AFTER local cache miss (local-first for performance).
|
||||
|
||||
Returns:
|
||||
CacheValue if found externally, None otherwise.
|
||||
|
||||
Important:
|
||||
- Return None on any error (don't raise)
|
||||
- Validate data integrity before returning
|
||||
"""
|
||||
"""Called on local cache miss. Return CacheValue if found, None otherwise."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def on_store(self, context: CacheContext, value: CacheValue) -> None:
|
||||
"""
|
||||
Store value to external cache.
|
||||
|
||||
Called AFTER value is stored in local cache.
|
||||
Dispatched via asyncio.create_task (non-blocking).
|
||||
|
||||
Important:
|
||||
- Should not block execution
|
||||
- Handle serialization failures gracefully
|
||||
"""
|
||||
"""Called after local store. Dispatched via asyncio.create_task."""
|
||||
pass
|
||||
|
||||
def should_cache(self, context: CacheContext, value: Optional[CacheValue] = None) -> bool:
|
||||
"""
|
||||
Filter which nodes should be externally cached.
|
||||
|
||||
Called before on_lookup (value=None) and on_store (value provided).
|
||||
Return False to skip external caching for this node.
|
||||
|
||||
Implementations can filter based on context.class_type, value size,
|
||||
or any custom logic. Use _estimate_value_size() to get value size.
|
||||
|
||||
Default: Returns True (cache everything).
|
||||
"""
|
||||
"""Return False to skip external caching for this node. Default: True."""
|
||||
return True
|
||||
|
||||
def on_prompt_start(self, prompt_id: str) -> None:
|
||||
"""Called when prompt execution begins. Optional."""
|
||||
pass
|
||||
|
||||
def on_prompt_end(self, prompt_id: str) -> None:
|
||||
"""Called when prompt execution ends. Optional."""
|
||||
pass
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Provider Registry
|
||||
# ============================================================
|
||||
|
||||
_providers: List[CacheProvider] = []
|
||||
_providers_lock = threading.Lock()
|
||||
_providers_snapshot: Optional[Tuple[CacheProvider, ...]] = None
|
||||
|
||||
|
||||
def register_cache_provider(provider: CacheProvider) -> None:
|
||||
"""
|
||||
Register an external cache provider.
|
||||
|
||||
Providers are called in registration order. First provider to return
|
||||
a result from on_lookup wins.
|
||||
"""
|
||||
"""Register an external cache provider. Providers are called in registration order."""
|
||||
global _providers_snapshot
|
||||
with _providers_lock:
|
||||
if provider in _providers:
|
||||
_logger.warning(f"Provider {provider.__class__.__name__} already registered")
|
||||
return
|
||||
_providers.append(provider)
|
||||
_providers_snapshot = None # Invalidate cache
|
||||
_providers_snapshot = None
|
||||
_logger.info(f"Registered cache provider: {provider.__class__.__name__}")
|
||||
|
||||
|
||||
def unregister_cache_provider(provider: CacheProvider) -> None:
|
||||
"""Remove a previously registered provider."""
|
||||
global _providers_snapshot
|
||||
with _providers_lock:
|
||||
try:
|
||||
@ -176,7 +80,6 @@ def unregister_cache_provider(provider: CacheProvider) -> None:
|
||||
|
||||
|
||||
def _get_cache_providers() -> Tuple[CacheProvider, ...]:
|
||||
"""Get registered providers (cached for performance). Internal."""
|
||||
global _providers_snapshot
|
||||
snapshot = _providers_snapshot
|
||||
if snapshot is not None:
|
||||
@ -189,30 +92,19 @@ def _get_cache_providers() -> Tuple[CacheProvider, ...]:
|
||||
|
||||
|
||||
def _has_cache_providers() -> bool:
|
||||
"""Fast check if any providers registered (no lock). Internal."""
|
||||
return bool(_providers)
|
||||
|
||||
|
||||
def _clear_cache_providers() -> None:
|
||||
"""Remove all providers. Useful for testing. Internal."""
|
||||
global _providers_snapshot
|
||||
with _providers_lock:
|
||||
_providers.clear()
|
||||
_providers_snapshot = None
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Internal Utilities
|
||||
# ============================================================
|
||||
|
||||
def _canonicalize(obj: Any) -> Any:
|
||||
"""
|
||||
Convert an object to a canonical, JSON-serializable form.
|
||||
|
||||
This ensures deterministic ordering regardless of Python's hash randomization.
|
||||
Frozensets in particular have non-deterministic iteration order between
|
||||
Python sessions, so consistent serialization requires explicit sorting.
|
||||
"""
|
||||
# Convert to canonical JSON-serializable form with deterministic ordering.
|
||||
# Frozensets have non-deterministic iteration order between Python sessions.
|
||||
if isinstance(obj, frozenset):
|
||||
# Sort frozenset items for deterministic ordering
|
||||
return ("__frozenset__", sorted(
|
||||
@ -243,17 +135,8 @@ def _canonicalize(obj: Any) -> Any:
|
||||
|
||||
|
||||
def _serialize_cache_key(cache_key: Any) -> Optional[str]:
|
||||
"""
|
||||
Serialize cache key to a hex digest string for external storage.
|
||||
|
||||
Returns SHA256 hex string suitable as an external storage key,
|
||||
or None if serialization fails entirely (fail-closed).
|
||||
|
||||
Note: Uses canonicalize + JSON serialization instead of pickle because
|
||||
pickle is NOT deterministic across Python sessions due to hash randomization
|
||||
affecting frozenset iteration order. Consistent hashing is required so that
|
||||
different instances compute the same key for identical inputs.
|
||||
"""
|
||||
# Returns deterministic SHA256 hex digest, or None on failure.
|
||||
# Uses JSON (not pickle) because pickle is non-deterministic across sessions.
|
||||
try:
|
||||
canonical = _canonicalize(cache_key)
|
||||
json_str = json.dumps(canonical, sort_keys=True, separators=(',', ':'))
|
||||
@ -270,12 +153,8 @@ def _serialize_cache_key(cache_key: Any) -> Optional[str]:
|
||||
|
||||
|
||||
def _contains_nan(obj: Any) -> bool:
|
||||
"""
|
||||
Check if cache key contains NaN (indicates uncacheable node).
|
||||
|
||||
NaN != NaN in Python, so local cache never hits. But serialized
|
||||
NaN would match, causing incorrect external hits. Must skip these.
|
||||
"""
|
||||
# NaN != NaN so local cache never hits, but serialized NaN would match.
|
||||
# Skip external caching for keys containing NaN.
|
||||
if isinstance(obj, float):
|
||||
try:
|
||||
return math.isnan(obj)
|
||||
@ -296,7 +175,6 @@ def _contains_nan(obj: Any) -> bool:
|
||||
|
||||
|
||||
def _estimate_value_size(value: CacheValue) -> int:
|
||||
"""Estimate serialized size in bytes. Useful for size-based filtering. Internal."""
|
||||
try:
|
||||
import torch
|
||||
except ImportError:
|
||||
|
||||
@ -156,7 +156,6 @@ class BasicCache:
|
||||
self.cache = {}
|
||||
self.subcaches = {}
|
||||
|
||||
# External cache provider support
|
||||
self._is_subcache = False
|
||||
self._current_prompt_id = ''
|
||||
|
||||
@ -202,7 +201,6 @@ class BasicCache:
|
||||
pass
|
||||
|
||||
def get_local(self, node_id):
|
||||
"""Sync local-only cache lookup (no external providers)."""
|
||||
if not self.initialized:
|
||||
return None
|
||||
cache_key = self.cache_key_set.get_data_key(node_id)
|
||||
@ -211,7 +209,6 @@ class BasicCache:
|
||||
return None
|
||||
|
||||
def set_local(self, node_id, value):
|
||||
"""Sync local-only cache store (no external providers)."""
|
||||
assert self.initialized
|
||||
cache_key = self.cache_key_set.get_data_key(node_id)
|
||||
self.cache[cache_key] = value
|
||||
@ -221,7 +218,6 @@ class BasicCache:
|
||||
cache_key = self.cache_key_set.get_data_key(node_id)
|
||||
self.cache[cache_key] = value
|
||||
|
||||
# Notify external providers
|
||||
await self._notify_providers_store(node_id, cache_key, value)
|
||||
|
||||
async def _get_immediate(self, node_id):
|
||||
@ -229,26 +225,22 @@ class BasicCache:
|
||||
return None
|
||||
cache_key = self.cache_key_set.get_data_key(node_id)
|
||||
|
||||
# Check local cache first (fast path)
|
||||
if cache_key in self.cache:
|
||||
return self.cache[cache_key]
|
||||
|
||||
# Check external providers on local miss
|
||||
external_result = await self._check_providers_lookup(node_id, cache_key)
|
||||
if external_result is not None:
|
||||
self.cache[cache_key] = external_result # Warm local cache
|
||||
self.cache[cache_key] = external_result
|
||||
return external_result
|
||||
|
||||
return None
|
||||
|
||||
async def _notify_providers_store(self, node_id, cache_key, value):
|
||||
"""Notify external providers of cache store (non-blocking)."""
|
||||
from comfy_execution.cache_provider import (
|
||||
_has_cache_providers, _get_cache_providers,
|
||||
CacheValue, _contains_nan, _logger
|
||||
)
|
||||
|
||||
# Fast exit conditions
|
||||
if self._is_subcache:
|
||||
return
|
||||
if not _has_cache_providers():
|
||||
@ -272,7 +264,6 @@ class BasicCache:
|
||||
|
||||
@staticmethod
|
||||
async def _safe_provider_store(provider, context, cache_value):
|
||||
"""Wrapper for async provider.on_store with error handling."""
|
||||
from comfy_execution.cache_provider import _logger
|
||||
try:
|
||||
await provider.on_store(context, cache_value)
|
||||
@ -280,7 +271,6 @@ class BasicCache:
|
||||
_logger.warning(f"Cache provider {provider.__class__.__name__} async store error: {e}")
|
||||
|
||||
async def _check_providers_lookup(self, node_id, cache_key):
|
||||
"""Check external providers for cached result."""
|
||||
from comfy_execution.cache_provider import (
|
||||
_has_cache_providers, _get_cache_providers,
|
||||
CacheValue, _contains_nan, _logger
|
||||
@ -309,7 +299,6 @@ class BasicCache:
|
||||
if not isinstance(result.outputs, (list, tuple)):
|
||||
_logger.warning(f"Provider {provider.__class__.__name__} returned invalid outputs")
|
||||
continue
|
||||
# Import CacheEntry here to avoid circular import at module level
|
||||
from execution import CacheEntry
|
||||
return CacheEntry(ui=result.ui or {}, outputs=list(result.outputs))
|
||||
except Exception as e:
|
||||
@ -318,11 +307,9 @@ class BasicCache:
|
||||
return None
|
||||
|
||||
def _is_external_cacheable_value(self, value):
|
||||
"""Check if value is a CacheEntry suitable for external caching (not objects cache)."""
|
||||
return hasattr(value, 'outputs') and hasattr(value, 'ui')
|
||||
|
||||
def _get_class_type(self, node_id):
|
||||
"""Get class_type for a node."""
|
||||
if not self.initialized or not self.dynprompt:
|
||||
return ''
|
||||
try:
|
||||
@ -331,7 +318,6 @@ class BasicCache:
|
||||
return ''
|
||||
|
||||
def _build_context(self, node_id, cache_key):
|
||||
"""Build CacheContext with hash. Returns None if hashing fails."""
|
||||
from comfy_execution.cache_provider import CacheContext, _serialize_cache_key, _logger
|
||||
try:
|
||||
cache_key_hash = _serialize_cache_key(cache_key)
|
||||
@ -352,8 +338,8 @@ class BasicCache:
|
||||
subcache = self.subcaches.get(subcache_key, None)
|
||||
if subcache is None:
|
||||
subcache = BasicCache(self.key_class)
|
||||
subcache._is_subcache = True # Mark as subcache - excludes from external caching
|
||||
subcache._current_prompt_id = self._current_prompt_id # Propagate prompt ID
|
||||
subcache._is_subcache = True
|
||||
subcache._current_prompt_id = self._current_prompt_id
|
||||
self.subcaches[subcache_key] = subcache
|
||||
await subcache.set_prompt(self.dynprompt, children_ids, self.is_changed_cache)
|
||||
return subcache
|
||||
|
||||
@ -685,7 +685,6 @@ class PromptExecutor:
|
||||
self.add_message("execution_error", mes, broadcast=False)
|
||||
|
||||
def _notify_prompt_lifecycle(self, event: str, prompt_id: str):
|
||||
"""Notify external cache providers of prompt lifecycle events."""
|
||||
from comfy_execution.cache_provider import _has_cache_providers, _get_cache_providers, _logger
|
||||
|
||||
if not _has_cache_providers():
|
||||
@ -716,11 +715,9 @@ class PromptExecutor:
|
||||
self.status_messages = []
|
||||
self.add_message("execution_start", { "prompt_id": prompt_id}, broadcast=False)
|
||||
|
||||
# Set prompt ID on caches for external provider integration
|
||||
for cache in self.caches.all:
|
||||
cache._current_prompt_id = prompt_id
|
||||
|
||||
# Notify external cache providers of prompt start
|
||||
self._notify_prompt_lifecycle("start", prompt_id)
|
||||
|
||||
try:
|
||||
@ -785,7 +782,6 @@ class PromptExecutor:
|
||||
if comfy.model_management.DISABLE_SMART_MEMORY:
|
||||
comfy.model_management.unload_all_models()
|
||||
finally:
|
||||
# Notify external cache providers of prompt end
|
||||
self._notify_prompt_lifecycle("end", prompt_id)
|
||||
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user