feat(cache-provider): add on_set_prompt lifecycle hook for providers

Adds a new on_set_prompt() lifecycle hook on CacheProvider that fires
after the cache key set is prepared for a new prompt. Dispatched via
asyncio.create_task with errors swallowed (same fail-safe pattern as
on_store / on_lookup).

Why: BasicCache's lifecycle notifications to external providers were
incomplete. set_prompt is a key per-prompt event that providers need
visibility into — for example, to reset per-prompt timing/state used
for cost-aware caching policies (a provider can set t=0 here, then
measure elapsed at each on_store to estimate compute saved by a hit).

Backward-compatible: default implementation is a no-op, existing
providers compile and run unchanged. Providers that need the per-prompt
boundary override on_set_prompt().
This commit is contained in:
Deep Mehta 2026-05-20 21:15:35 -07:00
parent 95fdc6cf91
commit fcbe7db46f
2 changed files with 31 additions and 0 deletions

View File

@ -21,6 +21,10 @@ class CacheProvider(ABC):
Exceptions from provider methods are caught by the caller and never break execution. Exceptions from provider methods are caught by the caller and never break execution.
""" """
async def on_set_prompt(self) -> None:
"""Called after prompt cache keys are prepared. Dispatched via asyncio.create_task."""
pass
@abstractmethod @abstractmethod
async def on_lookup(self, context: CacheContext) -> Optional[CacheValue]: async def on_lookup(self, context: CacheContext) -> Optional[CacheValue]:
"""Called on local cache miss. Return CacheValue if found, None otherwise.""" """Called on local cache miss. Return CacheValue if found, None otherwise."""

View File

@ -164,6 +164,7 @@ class BasicCache:
await self.cache_key_set.add_keys(node_ids) await self.cache_key_set.add_keys(node_ids)
self.is_changed_cache = is_changed_cache self.is_changed_cache = is_changed_cache
self.initialized = True self.initialized = True
await self._notify_providers_set_prompt()
def all_node_ids(self): def all_node_ids(self):
assert self.initialized assert self.initialized
@ -263,6 +264,24 @@ class BasicCache:
except Exception as e: except Exception as e:
_logger.warning(f"Cache provider {provider.__class__.__name__} error on store: {e}") _logger.warning(f"Cache provider {provider.__class__.__name__} error on store: {e}")
async def _notify_providers_set_prompt(self):
from comfy_execution.cache_provider import (
_has_cache_providers, _get_cache_providers, _logger
)
if not self.enable_providers:
return
if not _has_cache_providers():
return
for provider in _get_cache_providers():
try:
task = asyncio.create_task(self._safe_provider_set_prompt(provider))
self._pending_store_tasks.add(task)
task.add_done_callback(self._pending_store_tasks.discard)
except Exception as e:
_logger.warning(f"Cache provider {provider.__class__.__name__} error on set_prompt: {e}")
@staticmethod @staticmethod
async def _safe_provider_store(provider, context, cache_value): async def _safe_provider_store(provider, context, cache_value):
from comfy_execution.cache_provider import _logger from comfy_execution.cache_provider import _logger
@ -271,6 +290,14 @@ class BasicCache:
except Exception as e: except Exception as e:
_logger.warning(f"Cache provider {provider.__class__.__name__} async store error: {e}") _logger.warning(f"Cache provider {provider.__class__.__name__} async store error: {e}")
@staticmethod
async def _safe_provider_set_prompt(provider):
from comfy_execution.cache_provider import _logger
try:
await provider.on_set_prompt()
except Exception as e:
_logger.warning(f"Cache provider {provider.__class__.__name__} async set_prompt error: {e}")
async def _check_providers_lookup(self, node_id, cache_key): async def _check_providers_lookup(self, node_id, cache_key):
from comfy_execution.cache_provider import ( from comfy_execution.cache_provider import (
_has_cache_providers, _get_cache_providers, _has_cache_providers, _get_cache_providers,