diff --git a/comfy_api/latest/__init__.py b/comfy_api/latest/__init__.py index 742641dcb..84d520eb2 100644 --- a/comfy_api/latest/__init__.py +++ b/comfy_api/latest/__init__.py @@ -135,11 +135,10 @@ class Caching: Caching.register_provider(MyCacheProvider()) """ - # Import from comfy_execution.cache_provider (source of truth) + # Public types — defined in comfy_api.latest._caching (source of truth) + from ._caching import CacheProvider, CacheContext, CacheValue + # Registry functions — implementation in comfy_execution from comfy_execution.cache_provider import ( - CacheProvider, - CacheContext, - CacheValue, register_cache_provider as register_provider, unregister_cache_provider as unregister_provider, ) diff --git a/comfy_api/latest/_caching.py b/comfy_api/latest/_caching.py new file mode 100644 index 000000000..686c99969 --- /dev/null +++ b/comfy_api/latest/_caching.py @@ -0,0 +1,43 @@ +from abc import ABC, abstractmethod +from typing import Optional +from dataclasses import dataclass + + +@dataclass +class CacheContext: + prompt_id: str + node_id: str + class_type: str + cache_key_hash: str # SHA256 hex digest + + +@dataclass +class CacheValue: + outputs: list + ui: dict = None + + +class CacheProvider(ABC): + """Abstract base class for external cache providers. + Exceptions from provider methods are caught by the caller and never break execution. + """ + + @abstractmethod + async def on_lookup(self, context: CacheContext) -> Optional[CacheValue]: + """Called on local cache miss. Return CacheValue if found, None otherwise.""" + pass + + @abstractmethod + async def on_store(self, context: CacheContext, value: CacheValue) -> None: + """Called after local store. Dispatched via asyncio.create_task.""" + pass + + def should_cache(self, context: CacheContext, value: Optional[CacheValue] = None) -> bool: + """Return False to skip external caching for this node. Default: True.""" + return True + + def on_prompt_start(self, prompt_id: str) -> None: + pass + + def on_prompt_end(self, prompt_id: str) -> None: + pass diff --git a/comfy_execution/cache_provider.py b/comfy_execution/cache_provider.py index efa901446..1e597465a 100644 --- a/comfy_execution/cache_provider.py +++ b/comfy_execution/cache_provider.py @@ -1,58 +1,19 @@ -from abc import ABC, abstractmethod from typing import Any, Optional, Tuple, List -from dataclasses import dataclass import hashlib import json import logging import math import threading +# Public types — source of truth is comfy_api.latest._caching +from comfy_api.latest._caching import CacheProvider, CacheContext, CacheValue + _logger = logging.getLogger(__name__) -@dataclass -class CacheContext: - prompt_id: str - node_id: str - class_type: str - cache_key_hash: str # SHA256 hex digest - - -@dataclass -class CacheValue: - outputs: list - ui: dict = None - - -class CacheProvider(ABC): - """Abstract base class for external cache providers. - Exceptions from provider methods are caught by the caller and never break execution. - """ - - @abstractmethod - async def on_lookup(self, context: CacheContext) -> Optional[CacheValue]: - """Called on local cache miss. Return CacheValue if found, None otherwise.""" - pass - - @abstractmethod - async def on_store(self, context: CacheContext, value: CacheValue) -> None: - """Called after local store. Dispatched via asyncio.create_task.""" - pass - - def should_cache(self, context: CacheContext, value: Optional[CacheValue] = None) -> bool: - """Return False to skip external caching for this node. Default: True.""" - return True - - def on_prompt_start(self, prompt_id: str) -> None: - pass - - def on_prompt_end(self, prompt_id: str) -> None: - pass - - _providers: List[CacheProvider] = [] _providers_lock = threading.Lock() -_providers_snapshot: Optional[Tuple[CacheProvider, ...]] = None +_providers_snapshot: Tuple[CacheProvider, ...] = () def register_cache_provider(provider: CacheProvider) -> None: @@ -63,7 +24,7 @@ def register_cache_provider(provider: CacheProvider) -> None: _logger.warning(f"Provider {provider.__class__.__name__} already registered") return _providers.append(provider) - _providers_snapshot = None + _providers_snapshot = tuple(_providers) _logger.info(f"Registered cache provider: {provider.__class__.__name__}") @@ -72,33 +33,25 @@ def unregister_cache_provider(provider: CacheProvider) -> None: with _providers_lock: try: _providers.remove(provider) - _providers_snapshot = None + _providers_snapshot = tuple(_providers) _logger.info(f"Unregistered cache provider: {provider.__class__.__name__}") except ValueError: _logger.warning(f"Provider {provider.__class__.__name__} was not registered") def _get_cache_providers() -> Tuple[CacheProvider, ...]: - global _providers_snapshot - snapshot = _providers_snapshot - if snapshot is not None: - return snapshot - with _providers_lock: - if _providers_snapshot is not None: - return _providers_snapshot - _providers_snapshot = tuple(_providers) - return _providers_snapshot + return _providers_snapshot def _has_cache_providers() -> bool: - return bool(_providers) + return bool(_providers_snapshot) def _clear_cache_providers() -> None: global _providers_snapshot with _providers_lock: _providers.clear() - _providers_snapshot = None + _providers_snapshot = () def _canonicalize(obj: Any) -> Any: