refactor: move public types to comfy_api, eager provider snapshot

Address review feedback:
- Move CacheProvider/CacheContext/CacheValue definitions to
  comfy_api/latest/_caching.py (source of truth for public API)
- comfy_execution/cache_provider.py re-exports types from there
- Build _providers_snapshot eagerly on register/unregister instead
  of lazy memoization in _get_cache_providers

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Deep Mehta 2026-03-03 18:19:23 -08:00
parent c73e3c9619
commit 8ed3386d3b
3 changed files with 55 additions and 60 deletions

View File

@ -135,11 +135,10 @@ class Caching:
Caching.register_provider(MyCacheProvider()) Caching.register_provider(MyCacheProvider())
""" """
# Import from comfy_execution.cache_provider (source of truth) # Public types — defined in comfy_api.latest._caching (source of truth)
from ._caching import CacheProvider, CacheContext, CacheValue
# Registry functions — implementation in comfy_execution
from comfy_execution.cache_provider import ( from comfy_execution.cache_provider import (
CacheProvider,
CacheContext,
CacheValue,
register_cache_provider as register_provider, register_cache_provider as register_provider,
unregister_cache_provider as unregister_provider, unregister_cache_provider as unregister_provider,
) )

View File

@ -0,0 +1,43 @@
from abc import ABC, abstractmethod
from typing import Optional
from dataclasses import dataclass
@dataclass
class CacheContext:
prompt_id: str
node_id: str
class_type: str
cache_key_hash: str # SHA256 hex digest
@dataclass
class CacheValue:
outputs: list
ui: dict = None
class CacheProvider(ABC):
"""Abstract base class for external cache providers.
Exceptions from provider methods are caught by the caller and never break execution.
"""
@abstractmethod
async def on_lookup(self, context: CacheContext) -> Optional[CacheValue]:
"""Called on local cache miss. Return CacheValue if found, None otherwise."""
pass
@abstractmethod
async def on_store(self, context: CacheContext, value: CacheValue) -> None:
"""Called after local store. Dispatched via asyncio.create_task."""
pass
def should_cache(self, context: CacheContext, value: Optional[CacheValue] = None) -> bool:
"""Return False to skip external caching for this node. Default: True."""
return True
def on_prompt_start(self, prompt_id: str) -> None:
pass
def on_prompt_end(self, prompt_id: str) -> None:
pass

View File

@ -1,58 +1,19 @@
from abc import ABC, abstractmethod
from typing import Any, Optional, Tuple, List from typing import Any, Optional, Tuple, List
from dataclasses import dataclass
import hashlib import hashlib
import json import json
import logging import logging
import math import math
import threading import threading
# Public types — source of truth is comfy_api.latest._caching
from comfy_api.latest._caching import CacheProvider, CacheContext, CacheValue
_logger = logging.getLogger(__name__) _logger = logging.getLogger(__name__)
@dataclass
class CacheContext:
prompt_id: str
node_id: str
class_type: str
cache_key_hash: str # SHA256 hex digest
@dataclass
class CacheValue:
outputs: list
ui: dict = None
class CacheProvider(ABC):
"""Abstract base class for external cache providers.
Exceptions from provider methods are caught by the caller and never break execution.
"""
@abstractmethod
async def on_lookup(self, context: CacheContext) -> Optional[CacheValue]:
"""Called on local cache miss. Return CacheValue if found, None otherwise."""
pass
@abstractmethod
async def on_store(self, context: CacheContext, value: CacheValue) -> None:
"""Called after local store. Dispatched via asyncio.create_task."""
pass
def should_cache(self, context: CacheContext, value: Optional[CacheValue] = None) -> bool:
"""Return False to skip external caching for this node. Default: True."""
return True
def on_prompt_start(self, prompt_id: str) -> None:
pass
def on_prompt_end(self, prompt_id: str) -> None:
pass
_providers: List[CacheProvider] = [] _providers: List[CacheProvider] = []
_providers_lock = threading.Lock() _providers_lock = threading.Lock()
_providers_snapshot: Optional[Tuple[CacheProvider, ...]] = None _providers_snapshot: Tuple[CacheProvider, ...] = ()
def register_cache_provider(provider: CacheProvider) -> None: def register_cache_provider(provider: CacheProvider) -> None:
@ -63,7 +24,7 @@ def register_cache_provider(provider: CacheProvider) -> None:
_logger.warning(f"Provider {provider.__class__.__name__} already registered") _logger.warning(f"Provider {provider.__class__.__name__} already registered")
return return
_providers.append(provider) _providers.append(provider)
_providers_snapshot = None _providers_snapshot = tuple(_providers)
_logger.info(f"Registered cache provider: {provider.__class__.__name__}") _logger.info(f"Registered cache provider: {provider.__class__.__name__}")
@ -72,33 +33,25 @@ def unregister_cache_provider(provider: CacheProvider) -> None:
with _providers_lock: with _providers_lock:
try: try:
_providers.remove(provider) _providers.remove(provider)
_providers_snapshot = None _providers_snapshot = tuple(_providers)
_logger.info(f"Unregistered cache provider: {provider.__class__.__name__}") _logger.info(f"Unregistered cache provider: {provider.__class__.__name__}")
except ValueError: except ValueError:
_logger.warning(f"Provider {provider.__class__.__name__} was not registered") _logger.warning(f"Provider {provider.__class__.__name__} was not registered")
def _get_cache_providers() -> Tuple[CacheProvider, ...]: def _get_cache_providers() -> Tuple[CacheProvider, ...]:
global _providers_snapshot return _providers_snapshot
snapshot = _providers_snapshot
if snapshot is not None:
return snapshot
with _providers_lock:
if _providers_snapshot is not None:
return _providers_snapshot
_providers_snapshot = tuple(_providers)
return _providers_snapshot
def _has_cache_providers() -> bool: def _has_cache_providers() -> bool:
return bool(_providers) return bool(_providers_snapshot)
def _clear_cache_providers() -> None: def _clear_cache_providers() -> None:
global _providers_snapshot global _providers_snapshot
with _providers_lock: with _providers_lock:
_providers.clear() _providers.clear()
_providers_snapshot = None _providers_snapshot = ()
def _canonicalize(obj: Any) -> Any: def _canonicalize(obj: Any) -> Any: