refactor: move public types to comfy_api, eager provider snapshot

Address review feedback:
- Move CacheProvider/CacheContext/CacheValue definitions to
  comfy_api/latest/_caching.py (source of truth for public API)
- comfy_execution/cache_provider.py re-exports types from there
- Build _providers_snapshot eagerly on register/unregister instead
  of lazy memoization in _get_cache_providers

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Deep Mehta 2026-03-03 18:19:23 -08:00
parent c73e3c9619
commit 8ed3386d3b
3 changed files with 55 additions and 60 deletions

View File

@ -135,11 +135,10 @@ class Caching:
Caching.register_provider(MyCacheProvider())
"""
# Import from comfy_execution.cache_provider (source of truth)
# Public types — defined in comfy_api.latest._caching (source of truth)
from ._caching import CacheProvider, CacheContext, CacheValue
# Registry functions — implementation in comfy_execution
from comfy_execution.cache_provider import (
CacheProvider,
CacheContext,
CacheValue,
register_cache_provider as register_provider,
unregister_cache_provider as unregister_provider,
)

View File

@ -0,0 +1,43 @@
from abc import ABC, abstractmethod
from typing import Optional
from dataclasses import dataclass
@dataclass
class CacheContext:
prompt_id: str
node_id: str
class_type: str
cache_key_hash: str # SHA256 hex digest
@dataclass
class CacheValue:
outputs: list
ui: dict = None
class CacheProvider(ABC):
"""Abstract base class for external cache providers.
Exceptions from provider methods are caught by the caller and never break execution.
"""
@abstractmethod
async def on_lookup(self, context: CacheContext) -> Optional[CacheValue]:
"""Called on local cache miss. Return CacheValue if found, None otherwise."""
pass
@abstractmethod
async def on_store(self, context: CacheContext, value: CacheValue) -> None:
"""Called after local store. Dispatched via asyncio.create_task."""
pass
def should_cache(self, context: CacheContext, value: Optional[CacheValue] = None) -> bool:
"""Return False to skip external caching for this node. Default: True."""
return True
def on_prompt_start(self, prompt_id: str) -> None:
pass
def on_prompt_end(self, prompt_id: str) -> None:
pass

View File

@ -1,58 +1,19 @@
from abc import ABC, abstractmethod
from typing import Any, Optional, Tuple, List
from dataclasses import dataclass
import hashlib
import json
import logging
import math
import threading
# Public types — source of truth is comfy_api.latest._caching
from comfy_api.latest._caching import CacheProvider, CacheContext, CacheValue
_logger = logging.getLogger(__name__)
@dataclass
class CacheContext:
prompt_id: str
node_id: str
class_type: str
cache_key_hash: str # SHA256 hex digest
@dataclass
class CacheValue:
outputs: list
ui: dict = None
class CacheProvider(ABC):
"""Abstract base class for external cache providers.
Exceptions from provider methods are caught by the caller and never break execution.
"""
@abstractmethod
async def on_lookup(self, context: CacheContext) -> Optional[CacheValue]:
"""Called on local cache miss. Return CacheValue if found, None otherwise."""
pass
@abstractmethod
async def on_store(self, context: CacheContext, value: CacheValue) -> None:
"""Called after local store. Dispatched via asyncio.create_task."""
pass
def should_cache(self, context: CacheContext, value: Optional[CacheValue] = None) -> bool:
"""Return False to skip external caching for this node. Default: True."""
return True
def on_prompt_start(self, prompt_id: str) -> None:
pass
def on_prompt_end(self, prompt_id: str) -> None:
pass
_providers: List[CacheProvider] = []
_providers_lock = threading.Lock()
_providers_snapshot: Optional[Tuple[CacheProvider, ...]] = None
_providers_snapshot: Tuple[CacheProvider, ...] = ()
def register_cache_provider(provider: CacheProvider) -> None:
@ -63,7 +24,7 @@ def register_cache_provider(provider: CacheProvider) -> None:
_logger.warning(f"Provider {provider.__class__.__name__} already registered")
return
_providers.append(provider)
_providers_snapshot = None
_providers_snapshot = tuple(_providers)
_logger.info(f"Registered cache provider: {provider.__class__.__name__}")
@ -72,33 +33,25 @@ def unregister_cache_provider(provider: CacheProvider) -> None:
with _providers_lock:
try:
_providers.remove(provider)
_providers_snapshot = None
_providers_snapshot = tuple(_providers)
_logger.info(f"Unregistered cache provider: {provider.__class__.__name__}")
except ValueError:
_logger.warning(f"Provider {provider.__class__.__name__} was not registered")
def _get_cache_providers() -> Tuple[CacheProvider, ...]:
global _providers_snapshot
snapshot = _providers_snapshot
if snapshot is not None:
return snapshot
with _providers_lock:
if _providers_snapshot is not None:
return _providers_snapshot
_providers_snapshot = tuple(_providers)
return _providers_snapshot
return _providers_snapshot
def _has_cache_providers() -> bool:
return bool(_providers)
return bool(_providers_snapshot)
def _clear_cache_providers() -> None:
global _providers_snapshot
with _providers_lock:
_providers.clear()
_providers_snapshot = None
_providers_snapshot = ()
def _canonicalize(obj: Any) -> Any: